mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Update TensorRT-LLM (#1492)
* Update TensorRT-LLM --------- Co-authored-by: Loki <lokravi@amazon.com>
This commit is contained in:
parent
71d8d4d3dc
commit
66ef1df492
2
.gitignore
vendored
2
.gitignore
vendored
@ -32,6 +32,8 @@ cpp/.ccache/
|
||||
tensorrt_llm/libs
|
||||
tensorrt_llm/bindings.pyi
|
||||
tensorrt_llm/bindings/*.pyi
|
||||
*docs/cpp_docs*
|
||||
*docs/source/_cpp_gen*
|
||||
|
||||
# Testing
|
||||
.coverage.*
|
||||
|
||||
2
3rdparty/cutlass
vendored
2
3rdparty/cutlass
vendored
@ -1 +1 @@
|
||||
Subproject commit a8f2c80db0564c74f4efccac71993b971dfc448b
|
||||
Subproject commit 7d49e6c7e2f8896c47f586706e67e1fb215529dc
|
||||
460
README.md
460
README.md
@ -11,7 +11,7 @@ TensorRT-LLM
|
||||
[](./setup.py)
|
||||
[](./LICENSE)
|
||||
|
||||
[Architecture](./docs/source/architecture.md) | [Results](./docs/source/performance.md) | [Examples](./examples/) | [Documentation](./docs/source/)
|
||||
[Architecture](./docs/source/architecture/overview.md) | [Results](./docs/source/performance/perf-overview.md) | [Examples](./examples/) | [Documentation](./docs/source/)
|
||||
|
||||
---
|
||||
<div align="left">
|
||||
@ -29,42 +29,13 @@ TensorRT-LLM
|
||||
* [2023/10/17] [Large Language Models up to 4x Faster on RTX With TensorRT-LLM for Windows
|
||||
](https://blogs.nvidia.com/blog/2023/10/17/tensorrt-llm-windows-stable-diffusion-rtx/)
|
||||
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [TensorRT-LLM](#tensorrt-llm)
|
||||
- [Latest News](#latest-news)
|
||||
- [Table of Contents](#table-of-contents)
|
||||
- [TensorRT-LLM Overview](#tensorrt-llm-overview)
|
||||
- [Installation](#installation)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Support Matrix](#support-matrix)
|
||||
- [Devices](#devices)
|
||||
- [Precision](#precision)
|
||||
- [Key Features](#key-features)
|
||||
- [Models](#models)
|
||||
- [Performance](#performance)
|
||||
- [Advanced Topics](#advanced-topics)
|
||||
- [Quantization](#quantization)
|
||||
- [In-flight Batching](#in-flight-batching)
|
||||
- [Attention](#attention)
|
||||
- [Graph Rewriting](#graph-rewriting)
|
||||
- [Benchmark](#benchmark)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
- [Release notes](#release-notes)
|
||||
- [Change Log](#change-log)
|
||||
- [Versions 0.9.0](#versions-090)
|
||||
- [For history change log, please see CHANGELOG.md.](#for-history-change-log-please-see-changelogmd)
|
||||
- [Known Issues](#known-issues)
|
||||
- [Report Issues](#report-issues)
|
||||
|
||||
## TensorRT-LLM Overview
|
||||
|
||||
TensorRT-LLM provides users with an easy-to-use Python API to define Large
|
||||
TensorRT-LLM is an easy-to-use Python API to define Large
|
||||
Language Models (LLMs) and build
|
||||
[TensorRT](https://developer.nvidia.com/tensorrt) engines that contain
|
||||
state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.
|
||||
TensorRT-LLM also contains components to create Python and C++ runtimes that
|
||||
TensorRT-LLM contains components to create Python and C++ runtimes that
|
||||
execute those TensorRT engines. It also includes a
|
||||
[backend](https://github.com/triton-inference-server/tensorrtllm_backend)
|
||||
for integration with the
|
||||
@ -76,8 +47,8 @@ multiple nodes with multiple GPUs (using
|
||||
and/or
|
||||
[Pipeline Parallelism](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/nlp/nemo_megatron/parallelisms.html#pipeline-parallelism)).
|
||||
|
||||
The Python API of TensorRT-LLM is architectured to look similar to the
|
||||
[PyTorch](https://pytorch.org) API. It provides users with a
|
||||
The TensorRT-LLM Python API architecture looks similar to the
|
||||
[PyTorch](https://pytorch.org) API. It provides a
|
||||
[functional](./tensorrt_llm/functional.py) module containing functions like
|
||||
`einsum`, `softmax`, `matmul` or `view`. The [layers](./tensorrt_llm/layers)
|
||||
module bundles useful building blocks to assemble LLMs; like an `Attention`
|
||||
@ -86,422 +57,21 @@ like `GPTAttention` or `BertAttention`, can be found in the
|
||||
[models](./tensorrt_llm/models) module.
|
||||
|
||||
TensorRT-LLM comes with several popular models pre-defined. They can easily be
|
||||
modified and extended to fit custom needs. See below for a list of supported
|
||||
[models](#Models).
|
||||
modified and extended to fit custom needs. Refer to the [Support Matrix](https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html) for a list of supported models.
|
||||
|
||||
To maximize performance and reduce memory footprint, TensorRT-LLM allows the
|
||||
models to be executed using different quantization modes (see
|
||||
[`examples/gpt`](./examples/gpt) for concrete examples). TensorRT-LLM supports
|
||||
models to be executed using different quantization modes (refer to
|
||||
[`support matrix`](https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html#software)). TensorRT-LLM supports
|
||||
INT4 or INT8 weights (and FP16 activations; a.k.a. INT4/INT8 weight-only) as
|
||||
well as a complete implementation of the
|
||||
[SmoothQuant](https://arxiv.org/abs/2211.10438) technique.
|
||||
|
||||
For a more detailed presentation of the software architecture and the key
|
||||
concepts used in TensorRT-LLM, we recommend you to read the following
|
||||
[document](./docs/source/architecture.md).
|
||||
## Getting Started
|
||||
|
||||
## Installation
|
||||
To get started with TensorRT-LLM, visit our documentation:
|
||||
|
||||
After installing the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit),
|
||||
please run the following commands to install TensorRT-LLM for x86_64 users.
|
||||
|
||||
```bash
|
||||
# Obtain and start the basic docker image environment.
|
||||
docker run --rm --runtime=nvidia --gpus all --entrypoint /bin/bash -it nvidia/cuda:12.1.0-devel-ubuntu22.04
|
||||
|
||||
# Install dependencies, TensorRT-LLM requires Python 3.10
|
||||
apt-get update && apt-get -y install python3.10 python3-pip openmpi-bin libopenmpi-dev
|
||||
|
||||
# Install the latest preview version (corresponding to the main branch) of TensorRT-LLM.
|
||||
# If you want to install the stable version (corresponding to the release branch), please
|
||||
# remove the `--pre` option.
|
||||
pip3 install tensorrt_llm -U --pre --extra-index-url https://pypi.nvidia.com
|
||||
|
||||
# Check installation
|
||||
python3 -c "import tensorrt_llm"
|
||||
```
|
||||
|
||||
For developers who have the best performance requirements, debugging needs, or use the aarch64 architecture,
|
||||
please refer to the instructions for [building from source code](docs/source/build_from_source.md).
|
||||
|
||||
For Windows installation, see [`Windows`](windows/README.md).
|
||||
|
||||
## Quick Start
|
||||
|
||||
Please be sure to complete the [installation steps](#installation) before proceeding with the following steps.
|
||||
|
||||
To create a TensorRT engine for an existing model, there are 3 steps:
|
||||
|
||||
1. Download pre-trained weights,
|
||||
2. Build a fully-optimized engine of the model,
|
||||
3. Deploy the engine, in other words, run the fully-optimized model.
|
||||
|
||||
The following sections show how to use TensorRT-LLM to run the
|
||||
[BLOOM-560m](https://huggingface.co/bigscience/bloom-560m) model.
|
||||
|
||||
***0. In the BLOOM folder***
|
||||
|
||||
Inside the Docker container, you have to install the requirements:
|
||||
|
||||
```bash
|
||||
pip install -r examples/bloom/requirements.txt
|
||||
git lfs install
|
||||
```
|
||||
|
||||
***1. Download the model weights from HuggingFace***
|
||||
|
||||
From the BLOOM example folder, you must download the weights of the model.
|
||||
|
||||
```bash
|
||||
cd examples/bloom
|
||||
rm -rf ./bloom/560M
|
||||
mkdir -p ./bloom/560M && git clone https://huggingface.co/bigscience/bloom-560m ./bloom/560M
|
||||
|
||||
```
|
||||
***2. Build the engine***
|
||||
|
||||
```bash
|
||||
# Single GPU on BLOOM 560M
|
||||
python convert_checkpoint.py --model_dir ./bloom/560M/ \
|
||||
--dtype float16 \
|
||||
--output_dir ./bloom/560M/trt_ckpt/fp16/1-gpu/
|
||||
# May need to add trtllm-build to PATH, export PATH=/usr/local/bin:$PATH
|
||||
trtllm-build --checkpoint_dir ./bloom/560M/trt_ckpt/fp16/1-gpu/ \
|
||||
--gemm_plugin float16 \
|
||||
--output_dir ./bloom/560M/trt_engines/fp16/1-gpu/
|
||||
```
|
||||
|
||||
See the BLOOM [example](examples/bloom) for more details and options regarding the `trtllm-build` command.
|
||||
|
||||
***3. Run***
|
||||
|
||||
The `../summarize.py` script can be used to perform the summarization of articles
|
||||
from the CNN Daily dataset:
|
||||
|
||||
```bash
|
||||
python ../summarize.py --test_trt_llm \
|
||||
--hf_model_dir ./bloom/560M/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./bloom/560M/trt_engines/fp16/1-gpu/
|
||||
```
|
||||
|
||||
More details about the script and how to run the BLOOM model can be found in
|
||||
the example [folder](examples/bloom). Many more [models](#models) than BLOOM
|
||||
are implemented in TensorRT-LLM. They can be found in the
|
||||
[examples](./examples/) directory.
|
||||
|
||||
Beyond local execution, you can also use the NVIDIA Triton Inference Server to create a production-ready deployment of your LLM as described in this [blog](https://developer.nvidia.com/blog/optimizing-inference-on-llms-with-tensorrt-llm-now-publicly-available/).
|
||||
|
||||
## Support Matrix
|
||||
|
||||
TensorRT-LLM optimizes the performance of a range of well-known models on
|
||||
NVIDIA GPUs. The following sections provide a list of supported GPU
|
||||
architectures as well as important features implemented in TensorRT-LLM.
|
||||
|
||||
### Devices
|
||||
|
||||
TensorRT-LLM supports the following architectures:
|
||||
|
||||
* [NVIDIA Hopper](https://www.nvidia.com/en-us/data-center/technologies/hopper-architecture/) (SM90), for example, H200, H100, H20
|
||||
* [NVIDIA Ada Lovelace](https://www.nvidia.com/en-us/geforce/ada-lovelace-architecture/) (SM89), for example, L40S, L20, L4
|
||||
* [NVIDIA Ampere](https://www.nvidia.com/en-us/data-center/ampere-architecture/) (SM80, SM86), for example, A100, A30, A10G
|
||||
* [NVIDIA Turing](https://www.nvidia.com/en-us/geforce/turing/) (SM75), for example, T4
|
||||
* [NVIDIA Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) (SM70 - experimental), for example, V100
|
||||
|
||||
|
||||
It is important to note that TensorRT-LLM is expected to work on all GPUs based on the Volta, Turing, Ampere, Hopper, and Ada Lovelace architectures. Certain limitations may apply.
|
||||
|
||||
### Precision
|
||||
|
||||
Various numerical precisions are supported in TensorRT-LLM. The support for
|
||||
some of those numerical features require specific architectures:
|
||||
|
||||
| | FP32 | FP16 | BF16 | FP8 | INT8 | INT4 |
|
||||
| :------------------ | :--- | :--- | :--- | :--- | :---- | :---- |
|
||||
| Volta (SM70) | Y | Y | N | N | Y (1) | Y (2) |
|
||||
| Turing (SM75) | Y | Y | N | N | Y (1) | Y (2) |
|
||||
| Ampere (SM80, SM86) | Y | Y | Y | N | Y | Y (3) |
|
||||
| Ada-Lovelace (SM89) | Y | Y | Y | Y | Y | Y |
|
||||
| Hopper (SM90) | Y | Y | Y | Y | Y | Y |
|
||||
|
||||
(1) INT8 SmoothQuant is not supported on SM70 and SM75.<br>
|
||||
(2) INT4 AWQ and GPTQ are not supported on SM < 80.<br>
|
||||
(3) INT4 AWQ and GPTQ with FP8 activations require SM >= 89.
|
||||
|
||||
In this release of TensorRT-LLM, the support for FP8 and quantized data types
|
||||
(INT8 or INT4) is not implemented for all the models. See the
|
||||
[precision](./docs/source/precision.md) document and the
|
||||
[examples](./examples/.) folder for additional details.
|
||||
|
||||
### Key Features
|
||||
|
||||
TensorRT-LLM contains examples that implement the following features.
|
||||
|
||||
* Multi-head Attention([MHA](https://arxiv.org/abs/1706.03762))
|
||||
* Multi-query Attention ([MQA](https://arxiv.org/abs/1911.02150))
|
||||
* Group-query Attention([GQA](https://arxiv.org/abs/2307.09288))
|
||||
* In-flight Batching
|
||||
* Paged KV Cache for the Attention
|
||||
* Tensor Parallelism
|
||||
* Pipeline Parallelism
|
||||
* INT4/INT8 Weight-Only Quantization (W4A16 & W8A16)
|
||||
* [SmoothQuant](https://arxiv.org/abs/2211.10438)
|
||||
* [GPTQ](https://arxiv.org/abs/2210.17323)
|
||||
* [AWQ](https://arxiv.org/abs/2306.00978)
|
||||
* [FP8](https://arxiv.org/abs/2209.05433)
|
||||
* Greedy-search
|
||||
* Beam-search
|
||||
* RoPE
|
||||
|
||||
In this release of TensorRT-LLM, some of the features are not enabled for all
|
||||
the models listed in the [examples](examples/.) folder.
|
||||
|
||||
### Models
|
||||
|
||||
The list of supported models is:
|
||||
|
||||
* [Baichuan](examples/baichuan)
|
||||
* [BART](examples/enc_dec)
|
||||
* [BERT](examples/bert)
|
||||
* [Blip2](examples/blip2)
|
||||
* [BLOOM](examples/bloom)
|
||||
* [ChatGLM](examples/chatglm)
|
||||
* [FairSeq NMT](examples/enc_dec/nmt)
|
||||
* [Falcon](examples/falcon)
|
||||
* [Flan-T5](examples/enc_dec)
|
||||
* [GPT](examples/gpt)
|
||||
* [GPT-J](examples/gptj)
|
||||
* [GPT-Nemo](examples/gpt)
|
||||
* [GPT-NeoX](examples/gptneox)
|
||||
* [InternLM](examples/internlm)
|
||||
* [LLaMA](examples/llama)
|
||||
* [LLaMA-v2](examples/llama)
|
||||
* [Mamba](examples/mamba)
|
||||
* [mBART](examples/enc_dec)
|
||||
* [Medusa](examples/medusa)
|
||||
* [Mistral](examples/llama#mistral-v01)
|
||||
* [MPT](examples/mpt)
|
||||
* [mT5](examples/enc_dec)
|
||||
* [OPT](examples/opt)
|
||||
* [Phi-1.5/Phi-2](examples/phi)
|
||||
* [Qwen](examples/qwen)
|
||||
* [Replit Code](examples/mpt)
|
||||
* [RoBERTa](examples/bert)
|
||||
* [SantaCoder](examples/gpt)
|
||||
* [StarCoder1/StarCoder2](examples/gpt)
|
||||
* [T5](examples/enc_dec)
|
||||
* [Whisper](examples/whisper)
|
||||
|
||||
Note: [Encoder-Decoder](examples/enc_dec/) provides general encoder-decoder
|
||||
functionality that supports many encoder-decoder models such as T5 family, BART family, Whisper family, NMT family, etc. We
|
||||
unroll the exact model names in the list above to let users find specific
|
||||
models easier.
|
||||
|
||||
The list of supported multi-modal models is:
|
||||
|
||||
* [BLIP2 w/ OPT-2.7B](examples/multimodal)
|
||||
* [BLIP2 w/ T5-XL](examples/multimodal)
|
||||
* [LLaVA-v1.5-7B](examples/multimodal)
|
||||
* [Nougat family](examples/multimodal) Nougat-small, Nougat-base
|
||||
|
||||
Note: Multi-modal provides general multi-modal functionality that supports many multi-modal architectures such as BLIP family, LLaVA family, etc. We unroll the exact model names in the list above to let users find specific models easier.
|
||||
|
||||
## Performance
|
||||
|
||||
Please refer to the [performance](./docs/source/performance.md) page for
|
||||
performance numbers. That page contains measured numbers for four variants of
|
||||
popular models (GPT-J, LLAMA-7B, LLAMA-70B, Falcon-180B), measured on the H100,
|
||||
L40S and A100 GPU(s).
|
||||
|
||||
## Advanced Topics
|
||||
|
||||
### Quantization
|
||||
|
||||
This [document](./docs/source/precision.md) describes the different
|
||||
quantization methods implemented in TensorRT-LLM and contains a support matrix
|
||||
for the different models.
|
||||
|
||||
### In-flight Batching
|
||||
|
||||
TensorRT-LLM supports in-flight batching of requests (also known as continuous
|
||||
batching or iteration-level batching). It's a
|
||||
[technique](./docs/source/batch_manager.md) that aims at reducing wait
|
||||
times in queues, eliminating the need for padding requests and allowing for
|
||||
higher GPU utilization.
|
||||
|
||||
### Attention
|
||||
|
||||
TensorRT-LLM implements several variants of the Attention mechanism that
|
||||
appears in most the Large Language Models. This
|
||||
[document](./docs/source/gpt_attention.md) summarizes those implementations and
|
||||
how they are optimized in TensorRT-LLM.
|
||||
|
||||
### Graph Rewriting
|
||||
|
||||
TensorRT-LLM uses a declarative approach to define neural networks and contains
|
||||
techniques to optimize the underlying graph. For more details, please refer to
|
||||
[doc](./docs/source/graph-rewriting.md)
|
||||
|
||||
### Benchmark
|
||||
|
||||
TensorRT-LLM provides [C++](./benchmarks/cpp/README.md) and
|
||||
[Python](./benchmarks/python/README.md) tools to perform benchmarking. Note,
|
||||
however, that it is recommended to use the C++ version.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
* If you encounter accuracy issues in the generated text, you may want to increase
|
||||
the internal precision in the attention layer. For that, pass the `--context_fmha_fp32_acc enable` to
|
||||
`trtllm-build`.
|
||||
|
||||
* It's recommended to add options `–shm-size=1g –ulimit memlock=-1` to the
|
||||
docker or nvidia-docker run command. Otherwise you may see NCCL errors when
|
||||
running multiple GPU inferences. See
|
||||
https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#errors
|
||||
for details.
|
||||
|
||||
* When building models, memory-related issues such as
|
||||
```
|
||||
[09/23/2023-03:13:00] [TRT] [E] 9: GPTLMHeadModel/layers/0/attention/qkv/PLUGIN_V2_Gemm_0: could not find any supported formats consistent with input/output data types
|
||||
[09/23/2023-03:13:00] [TRT] [E] 9: [pluginV2Builder.cpp::reportPluginError::24] Error Code 9: Internal Error (GPTLMHeadModel/layers/0/attention/qkv/PLUGIN_V2_Gemm_0: could not find any supported formats consistent with input/output data types)
|
||||
```
|
||||
may happen. One possible solution is to reduce the amount of memory needed by
|
||||
reducing the maximum batch size, input and output lengths. Another option is to
|
||||
enable plugins, for example: `--gpt_attention_plugin`.
|
||||
|
||||
* MPI + Slurm
|
||||
|
||||
TensorRT-LLM is a
|
||||
[MPI](https://en.wikipedia.org/wiki/Message_Passing_Interface)-aware package
|
||||
that uses [`mpi4py`](https://mpi4py.readthedocs.io/en/stable/). If you are
|
||||
running scripts in a [Slurm](https://slurm.schedmd.com/) environment, you might
|
||||
encounter interferences:
|
||||
```
|
||||
--------------------------------------------------------------------------
|
||||
PMI2_Init failed to initialize. Return code: 14
|
||||
--------------------------------------------------------------------------
|
||||
--------------------------------------------------------------------------
|
||||
The application appears to have been direct launched using "srun",
|
||||
but OMPI was not built with SLURM's PMI support and therefore cannot
|
||||
execute. There are several options for building PMI support under
|
||||
SLURM, depending upon the SLURM version you are using:
|
||||
|
||||
version 16.05 or later: you can use SLURM's PMIx support. This
|
||||
requires that you configure and build SLURM --with-pmix.
|
||||
|
||||
Versions earlier than 16.05: you must use either SLURM's PMI-1 or
|
||||
PMI-2 support. SLURM builds PMI-1 by default, or you can manually
|
||||
install PMI-2. You must then build Open MPI using --with-pmi pointing
|
||||
to the SLURM PMI library location.
|
||||
|
||||
Please configure as appropriate and try again.
|
||||
--------------------------------------------------------------------------
|
||||
```
|
||||
As a rule of thumb, if you are running TensorRT-LLM interactively on a Slurm
|
||||
node, prefix your commands with `mpirun -n 1` to run TensorRT-LLM in a
|
||||
dedicated MPI environment, not the one provided by your Slurm allocation.
|
||||
|
||||
For example: `mpirun -n 1 python3 examples/run.py ...`
|
||||
|
||||
## Release notes
|
||||
|
||||
* TensorRT-LLM requires TensorRT 9.3 and 24.02 containers.
|
||||
|
||||
### Change Log
|
||||
|
||||
#### Versions 0.9.0
|
||||
|
||||
* Model Support
|
||||
- Support distil-whisper, thanks to the contribution from @Bhuvanesh09 in PR #1061
|
||||
- Support HuggingFace StarCoder2
|
||||
- Support VILA
|
||||
- Support Smaug-72B-v0.1
|
||||
- Migrate BLIP-2 examples to `examples/multimodal`
|
||||
* Features
|
||||
- Add support to context chunking to work with KV cache reuse
|
||||
- Enable different rewind tokens per sequence for Medusa
|
||||
- BART LoRA support (limited to the Python runtime)
|
||||
- Enable multi-LoRA for BART LoRA
|
||||
- Support `early_stopping=False` in beam search for C++ Runtime
|
||||
- Add logits post processor to the batch manager (see docs/source/batch_manager.md#logits-post-processor-optional)
|
||||
- Support import and convert HuggingFace Gemma checkpoints, thanks for the contribution from @mfuntowicz in #1147
|
||||
- Support loading Gemma from HuggingFace
|
||||
- Support auto parallelism planner for high-level API and unified builder workflow
|
||||
- Support run `GptSession` without OpenMPI #1220
|
||||
- [BREAKING CHANGE] TopP sampling optimization with deterministic AIR TopP algorithm is enabled by default
|
||||
- Medusa IFB support
|
||||
- [Experimental] Support FP8 FMHA, note that the performance is not optimal, and we will keep optimizing it
|
||||
- [BREAKING CHANGE] Support embedding sharing for Gemma
|
||||
- More head sizes support for LLaMA-like models
|
||||
- Ampere (sm80, sm86), Ada (sm89), Hopper(sm90) all support head sizes [32, 40, 64, 80, 96, 104, 128, 160, 256] now.
|
||||
- OOTB functionality support
|
||||
- T5
|
||||
- Mixtral 8x7B
|
||||
* API
|
||||
- C++ `executor` API
|
||||
- Add Python bindings, see documentation and examples in `examples/bindings`
|
||||
- Add advanced and multi-GPU examples for Python binding of `executor` C++ API, see `examples/bindings/README.md`
|
||||
- Add documents for C++ `executor` API, see `docs/source/executor.md`
|
||||
- High-level API (refer to `examples/high-level-api/README.md` for guidance)
|
||||
- [BREAKING CHANGE] Reuse the `QuantConfig` used in `trtllm-build` tool, support broader quantization features
|
||||
- Support in `LLM()` API to accept engines built by `trtllm-build` command
|
||||
- Add support for TensorRT-LLM checkpoint as model input
|
||||
- Refine `SamplingConfig` used in `LLM.generate` or `LLM.generate_async` APIs, with the support of beam search, a variety of penalties, and more features
|
||||
- Add support for the StreamingLLM feature, enable it by setting `LLM(streaming_llm=...)`
|
||||
- Migrate Mixtral to high level API and unified builder workflow
|
||||
- [BREAKING CHANGE] Refactored Qwen model to the unified build workflow, see `examples/qwen/README.md` for the latest commands
|
||||
- [BREAKING CHANGE] Move LLaMA convert checkpoint script from examples directory into the core library
|
||||
- [BREAKING CHANGE] Refactor GPT with unified building workflow, see `examples/gpt/README.md` for the latest commands
|
||||
- [BREAKING CHANGE] Removed all the lora related flags from convert_checkpoint.py script and the checkpoint content to `trtllm-build` command, to generalize the feature better to more models
|
||||
- [BREAKING CHANGE] Removed the use_prompt_tuning flag and options from convert_checkpoint.py script and the checkpoint content, to generalize the feature better to more models. Use the `trtllm-build --max_prompt_embedding_table_size` instead.
|
||||
- [BREAKING CHANGE] Changed the `trtllm-build --world_size` flag to `--auto_parallel` flag, the option is used for auto parallel planner only.
|
||||
- [BREAKING CHANGE] `AsyncLLMEngine` is removed, `tensorrt_llm.GenerationExecutor` class is refactored to work with both explicitly launching with `mpirun` in the application level, and accept an MPI communicator created by `mpi4py`
|
||||
- [BREAKING CHANGE] `examples/server` are removed, see `examples/app` instead.
|
||||
- [BREAKING CHANGE] Remove LoRA related parameters from convert checkpoint scripts
|
||||
- [BREAKING CHANGE] Simplify Qwen convert checkpoint script
|
||||
- [BREAKING CHANGE] Remove `model` parameter from `gptManagerBenchmark` and `gptSessionBenchmark`
|
||||
* Bug fixes
|
||||
- Fix a weight-only quant bug for Whisper to make sure that the `encoder_input_len_range` is not 0, thanks to the contribution from @Eddie-Wang1120 in #992
|
||||
- Fix the issue that log probabilities in Python runtime are not returned #983
|
||||
- Multi-GPU fixes for multimodal examples #1003
|
||||
- Fix wrong `end_id` issue for Qwen #987
|
||||
- Fix a non-stopping generation issue #1118 #1123
|
||||
- Fix wrong link in examples/mixtral/README.md #1181
|
||||
- Fix LLaMA2-7B bad results when int8 kv cache and per-channel int8 weight only are enabled #967
|
||||
- Fix wrong `head_size` when importing Gemma model from HuggingFace Hub, thanks for the contribution from @mfuntowicz in #1148
|
||||
- Fix ChatGLM2-6B building failure on INT8 #1239
|
||||
- Fix wrong relative path in Baichuan documentation #1242
|
||||
- Fix wrong `SamplingConfig` tensors in `ModelRunnerCpp` #1183
|
||||
- Fix error when converting SmoothQuant LLaMA #1267
|
||||
- Fix the issue that `examples/run.py` only load one line from `--input_file`
|
||||
- Fix the issue that `ModelRunnerCpp` does not transfer `SamplingConfig` tensor fields correctly #1183
|
||||
* Benchmark
|
||||
- Add emulated static batching in `gptManagerBenchmark`
|
||||
- Support arbitrary dataset from HuggingFace for C++ benchmarks, see “Prepare dataset” section in `benchmarks/cpp/README.md`
|
||||
- Add percentile latency report to `gptManagerBenchmark`
|
||||
* Performance
|
||||
- Optimize `gptDecoderBatch` to support batched sampling
|
||||
- Enable FMHA for models in BART, Whisper and NMT family
|
||||
- Remove router tensor parallelism to improve performance for MoE models, thanks to the contribution from @megha95 in #1091
|
||||
- Improve custom all-reduce kernel
|
||||
* Infra
|
||||
- Base Docker image for TensorRT-LLM is updated to `nvcr.io/nvidia/pytorch:24.02-py3`
|
||||
- Base Docker image for TensorRT-LLM backend is updated to `nvcr.io/nvidia/tritonserver:24.02-py3`
|
||||
- The dependent TensorRT version is updated to 9.3
|
||||
- The dependent PyTorch version is updated to 2.2
|
||||
- The dependent CUDA version is updated to 12.3.2 (a.k.a. 12.3 Update 2)
|
||||
|
||||
#### For history change log, please see [CHANGELOG.md](./CHANGELOG.md).
|
||||
|
||||
### Known Issues
|
||||
|
||||
* On windows, running context FMHA plugin with FP16 accumulation on LLaMA, Mistral and Phi models suffers from poor accuracy and the resulting inference output may be garbled. The suggestion to workaround these is to enable FP32 accumulation when building the models, i.e. passing the options `--context_fmha disable --context_fmha_fp32_acc enable` to `trtllm-build` command as a work-around, and this should be fixed in the next version
|
||||
|
||||
* The hang reported in issue
|
||||
[#149](https://github.com/triton-inference-server/tensorrtllm_backend/issues/149)
|
||||
has not been reproduced by the TensorRT-LLM team. If it is caused by a bug
|
||||
in TensorRT-LLM, that bug may be present in that release
|
||||
|
||||
### Report Issues
|
||||
|
||||
You can use GitHub issues to report issues with TensorRT-LLM.
|
||||
- [Quick Start Guide](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)
|
||||
- [Release Notes](https://nvidia.github.io/TensorRT-LLM/release-notes.html)
|
||||
- [Installation Guide for Linux](https://nvidia.github.io/TensorRT-LLM/installation/linux.html)
|
||||
- [Installation Guide for Windows](https://nvidia.github.io/TensorRT-LLM/installation/windows.html)
|
||||
- [Supported Hardware, Models, and other Software](https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html)
|
||||
|
||||
@ -225,9 +225,7 @@ python examples/llama/convert_checkpoint.py --model_dir ${MODEL_CHECKPOINT} \
|
||||
--output_dir ${CONVERTED_CHECKPOINT} \
|
||||
--dtype ${DTYPE} \
|
||||
--tp_size ${TP} \
|
||||
--pp_size 1 \
|
||||
--lora_target_modules attn_qkv \
|
||||
--max_lora_rank ${MAX_LORA_RANK}
|
||||
--pp_size 1
|
||||
|
||||
${HOME}/.local/bin/trtllm-build \
|
||||
--checkpoint_dir ${CONVERTED_CHECKPOINT} \
|
||||
@ -235,13 +233,11 @@ ${HOME}/.local/bin/trtllm-build \
|
||||
--max_batch_size ${MAX_BATCH} \
|
||||
--max_input_len $MAX_LEN \
|
||||
--max_output_len $MAX_LEN \
|
||||
--gpt_attention_plugin float16 \
|
||||
--paged_kv_cache enable \
|
||||
--remove_input_padding enable \
|
||||
--gemm_plugin float16 \
|
||||
--lora_plugin float16 \
|
||||
--use_paged_context_fmha enable \
|
||||
--use_custom_all_reduce disable
|
||||
--lora_target_modules attn_qkv \
|
||||
--max_lora_rank ${MAX_LORA_RANK}
|
||||
|
||||
NUM_LORAS=(8 16 24 32 64 128 256)
|
||||
NUM_REQUESTS=1024
|
||||
|
||||
@ -14,6 +14,14 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*****************************************************************************
|
||||
*
|
||||
* GptSession is going to be deprecated soon.
|
||||
* Please do not add new functionality in this file!
|
||||
*
|
||||
*****************************************************************************/
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/mpiUtils.h"
|
||||
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
|
||||
|
||||
@ -1127,6 +1127,39 @@ _allowed_configs = {
|
||||
max_output_len=200,
|
||||
builder_opt=None,
|
||||
)),
|
||||
"qwen1.5_7b_chat":
|
||||
ModelConfig(name="qwen1.5_7b_chat",
|
||||
family="qwen2",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(num_layers=32,
|
||||
num_heads=32,
|
||||
hidden_size=4096,
|
||||
vocab_size=151936,
|
||||
hidden_act='silu',
|
||||
n_positions=8192,
|
||||
inter_size=11008,
|
||||
max_batch_size=128,
|
||||
max_input_len=512,
|
||||
max_output_len=200,
|
||||
builder_opt=None,
|
||||
bias=False)),
|
||||
"qwen1.5_14b_chat":
|
||||
ModelConfig(name="qwen1.5_14b_chat",
|
||||
family="qwen2",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(
|
||||
num_layers=40,
|
||||
num_heads=40,
|
||||
hidden_size=5120,
|
||||
vocab_size=152064,
|
||||
hidden_act='silu',
|
||||
n_positions=8192,
|
||||
inter_size=13696,
|
||||
max_batch_size=64,
|
||||
max_input_len=512,
|
||||
max_output_len=200,
|
||||
builder_opt=None,
|
||||
)),
|
||||
"mamba_2.8b":
|
||||
ModelConfig(name="mamba_2.8b",
|
||||
family="mamba",
|
||||
|
||||
@ -232,6 +232,7 @@ def build_gpt(args):
|
||||
builder_config_extra_kwargs['mamba_expand'] = build_config[
|
||||
'mamba_expand']
|
||||
builder_config_extra_kwargs['max_beam_width'] = max_beam_width
|
||||
builder_config_extra_kwargs['layer_types'] = ['recurrent']
|
||||
builder_config = builder.create_builder_config(
|
||||
name=args.model,
|
||||
precision=args.dtype,
|
||||
@ -715,6 +716,51 @@ def build_gpt(args):
|
||||
build_config["moe_num_experts"],
|
||||
'moe_top_k':
|
||||
build_config["moe_top_k"],
|
||||
'qwen_type':
|
||||
'qwen',
|
||||
}
|
||||
config = PretrainedConfig.from_dict(config)
|
||||
tensorrt_llm_model = tensorrt_llm.models.QWenForCausalLM(config)
|
||||
elif family == "qwen2":
|
||||
config = {
|
||||
'architecture':
|
||||
'QWenForCausalLM',
|
||||
'dtype':
|
||||
args.dtype,
|
||||
'num_hidden_layers':
|
||||
build_config['num_layers'],
|
||||
'num_attention_heads':
|
||||
build_config['num_heads'],
|
||||
'num_key_value_heads':
|
||||
build_config['num_heads'] if build_config['num_kv_heads'] is None
|
||||
else build_config['num_kv_heads'],
|
||||
'hidden_size':
|
||||
build_config['hidden_size'],
|
||||
'intermediate_size':
|
||||
build_config['inter_size'],
|
||||
'vocab_size':
|
||||
build_config['vocab_size'],
|
||||
'position_embedding_type':
|
||||
'rope_gpt_neox',
|
||||
'max_position_embeddings':
|
||||
build_config['n_positions'],
|
||||
'hidden_act':
|
||||
build_config['hidden_act'],
|
||||
'quantization': {
|
||||
'group_size': 128,
|
||||
'quant_algo': quant_algo,
|
||||
'kv_cache_quant_algo': kv_cache_quant_algo
|
||||
},
|
||||
'mapping': {
|
||||
'world_size': world_size,
|
||||
'tp_size': world_size
|
||||
},
|
||||
'moe_num_experts':
|
||||
build_config["moe_num_experts"],
|
||||
'moe_top_k':
|
||||
build_config["moe_top_k"],
|
||||
'qwen_type':
|
||||
'qwen2',
|
||||
}
|
||||
config = PretrainedConfig.from_dict(config)
|
||||
tensorrt_llm_model = tensorrt_llm.models.QWenForCausalLM(config)
|
||||
|
||||
@ -21,7 +21,7 @@
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/batch_manager/schedulerPolicy.h"
|
||||
#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"
|
||||
#include "tensorrt_llm/runtime/gptModelConfig.h"
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
|
||||
#include <atomic>
|
||||
@ -79,9 +79,13 @@ public:
|
||||
virtual ~GptManager();
|
||||
|
||||
protected:
|
||||
/* Synchronizes the decoder */
|
||||
virtual BatchManagerErrorCode_t forwardSync();
|
||||
|
||||
/* Invokes one step of backend
|
||||
Updates state of all requests */
|
||||
virtual BatchManagerErrorCode_t step(RequestList& activeRequests, std::set<uint64_t>& activeRequestsIds);
|
||||
virtual BatchManagerErrorCode_t forwardAsync(
|
||||
RequestList& activeRequests, std::unordered_set<uint64_t>& activeRequestsIds);
|
||||
|
||||
private:
|
||||
[[nodiscard]] SizeType getMaxInputLen() const;
|
||||
@ -89,7 +93,7 @@ private:
|
||||
[[nodiscard]] SizeType getMaxNumSequences() const;
|
||||
|
||||
void validateLlmRequest(
|
||||
LlmRequest& newReq, runtime::GptModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) const;
|
||||
LlmRequest& newReq, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) const;
|
||||
static std::shared_ptr<LlmRequest> fillLlmRequest(std::shared_ptr<InferenceRequest> newReq);
|
||||
static std::shared_ptr<std::vector<TokenIdType>> getReqInputTokens(std::shared_ptr<InferenceRequest> newReq);
|
||||
static SizeType getMaxNewTokens(std::shared_ptr<InferenceRequest> newReq);
|
||||
@ -108,7 +112,7 @@ private:
|
||||
// List of live requests
|
||||
RequestList mActiveRequests;
|
||||
// IDs of live requests
|
||||
std::set<uint64_t> mActiveRequestsIds;
|
||||
std::unordered_set<uint64_t> mActiveRequestsIds;
|
||||
// Boolean that controls if prompt should be included in output tokens for non-streaming
|
||||
bool mExcludeInputInOutput;
|
||||
|
||||
|
||||
@ -63,6 +63,8 @@ public:
|
||||
&& hostCacheSize == other.hostCacheSize && onboardBlocks == other.onboardBlocks;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, KvCacheConfig const& self);
|
||||
|
||||
std::optional<SizeType> maxTokens;
|
||||
std::optional<SizeType> maxAttentionWindow;
|
||||
std::optional<SizeType> sinkTokenLength;
|
||||
|
||||
@ -18,15 +18,16 @@
|
||||
|
||||
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h" // TODO forward declare
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
#include "tensorrt_llm/kernels/kvCacheIndex.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/gptModelConfig.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <list>
|
||||
@ -89,15 +90,15 @@ struct KvCacheStats
|
||||
class KVCacheBlock
|
||||
{
|
||||
public:
|
||||
using OffsetType = std::int32_t;
|
||||
using IdType = std::int32_t;
|
||||
|
||||
explicit KVCacheBlock(OffsetType blockIdx, OffsetType blocksInPrimaryPool);
|
||||
explicit KVCacheBlock(IdType blockId, kernels::KVCacheIndex blockIdx);
|
||||
|
||||
void startScheduling();
|
||||
|
||||
[[nodiscard]] OffsetType getBlockIdx() const;
|
||||
[[nodiscard]] IdType getBlockId() const;
|
||||
|
||||
[[nodiscard]] OffsetType getMemoryPoolBlockOffset() const;
|
||||
[[nodiscard]] kernels::KVCacheIndex::UnderlyingType getMemoryPoolBlockIndex() const;
|
||||
|
||||
[[nodiscard]] bool isPrimary() const;
|
||||
|
||||
@ -143,11 +144,12 @@ public:
|
||||
[[nodiscard]] bool isShared() const;
|
||||
|
||||
private:
|
||||
// Linear index of block in pool
|
||||
OffsetType mBlockIdx;
|
||||
// Linear ID of block independent of pool
|
||||
IdType mBlockId;
|
||||
|
||||
// Block in memory pool backing this block
|
||||
OffsetType mMemoryPoolBlockOffset;
|
||||
// Index of block in memory pool backing this block
|
||||
// Choice of pool is encoded into the type
|
||||
kernels::KVCacheIndex mMemoryPoolBlockIndex;
|
||||
|
||||
// Number of references to the block
|
||||
SizeType mRefCount;
|
||||
@ -169,9 +171,6 @@ private:
|
||||
|
||||
// Flag indicating if block is full
|
||||
bool mIsFull;
|
||||
|
||||
// Flag indicating mMemoryPoolBlockOffset refers to secondary pool
|
||||
static constexpr OffsetType secondaryPoolFlag = static_cast<OffsetType>(1) << (8 * sizeof(OffsetType) - 1);
|
||||
};
|
||||
|
||||
class GenerationRequest
|
||||
@ -220,14 +219,14 @@ public:
|
||||
return mCacheBlockIds;
|
||||
}
|
||||
|
||||
void addCacheBlock(SizeType beamIdx, SizeType blockIdx)
|
||||
void addCacheBlock(SizeType beamIdx, KVCacheBlock::IdType blockId)
|
||||
{
|
||||
mCacheBlockIds.at(beamIdx).push_back(blockIdx);
|
||||
mCacheBlockIds.at(beamIdx).push_back(blockId);
|
||||
}
|
||||
|
||||
void changeCacheBlock(SizeType beamIdx, SizeType pagedBlockIdx, SizeType blockIdx)
|
||||
void changeCacheBlock(SizeType beamIdx, SizeType pagedBlockIdx, KVCacheBlock::IdType blockId)
|
||||
{
|
||||
mCacheBlockIds.at(beamIdx).at(pagedBlockIdx) = blockIdx;
|
||||
mCacheBlockIds.at(beamIdx).at(pagedBlockIdx) = blockId;
|
||||
}
|
||||
|
||||
void clearCacheBlocks()
|
||||
@ -264,7 +263,7 @@ private:
|
||||
// Number of beams
|
||||
SizeType mBeamWidth;
|
||||
// List of blocks allocated for each beam of the sequence
|
||||
std::vector<std::vector<SizeType>> mCacheBlockIds;
|
||||
std::vector<std::vector<KVCacheBlock::IdType>> mCacheBlockIds;
|
||||
// Number of tokens already in kv cache before context phase.
|
||||
// A value > 0 indicates cached kv cache blocks were reused.
|
||||
// One value per beam.
|
||||
@ -348,7 +347,7 @@ public:
|
||||
|
||||
[[nodiscard]] SizeType getMaxNumBlocks() const noexcept
|
||||
{
|
||||
return static_cast<SizeType>(mAllBlocksByIdx.size());
|
||||
return static_cast<SizeType>(mAllBlocksById.size());
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType getTokensPerBlock() const noexcept
|
||||
@ -356,7 +355,8 @@ public:
|
||||
return mTokensPerBlock;
|
||||
}
|
||||
|
||||
//! \brief Get size of one field in one layer in one block.
|
||||
//! \brief Get size of one K/V cache block in one layer.
|
||||
//! @details Volume of [numKvHeads, tokensPerBlock, sizePerHead]
|
||||
[[nodiscard]] SizeType getBlockSize() const
|
||||
{
|
||||
return mBlockSize;
|
||||
@ -372,10 +372,10 @@ public:
|
||||
return mSecondaryPool;
|
||||
}
|
||||
|
||||
//! \brief Get offset in pool to K or V block.
|
||||
//! \param blockIdx the blockIdx as returned by getBlockIdx()
|
||||
//! \brief Get index in pool to K or V block.
|
||||
//! \param blockId the blockId as returned by getBlockId()
|
||||
//! \param fieldIdx either 0 (K) or 1 (V),
|
||||
[[nodiscard]] SizeType getKOrVBlockOffset(SizeType blockIdx, SizeType fieldIdx) const;
|
||||
[[nodiscard]] kernels::KVCacheIndex getKOrVBlockIndex(KVCacheBlock::IdType blockId, SizeType fieldIdx) const;
|
||||
|
||||
//! \brief Bring offloaded block from secondary to primary memory.
|
||||
//! \details Does nothing of block is already in primary memory.
|
||||
@ -442,7 +442,7 @@ private:
|
||||
// Number of tokens per one block
|
||||
SizeType mTokensPerBlock;
|
||||
// List of all blocks by idx
|
||||
std::vector<BlockPtr> mAllBlocksByIdx;
|
||||
std::vector<BlockPtr> mAllBlocksById;
|
||||
// Dummy block acting as root for BlockToken searches
|
||||
BlockPtr mCachedBlocksRoot;
|
||||
// Statistics for block allocations/reuse
|
||||
@ -452,7 +452,6 @@ private:
|
||||
class KVCacheManager
|
||||
{
|
||||
public:
|
||||
using OffsetType = KVCacheBlock::OffsetType;
|
||||
using SizeType = tensorrt_llm::runtime::SizeType;
|
||||
using SequencesPtr = GenerationRequest::SharedPtr;
|
||||
using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
|
||||
@ -495,12 +494,6 @@ public:
|
||||
return kvCacheStats;
|
||||
}
|
||||
|
||||
// Volume of [numKvHeads, tokensPerBlock, sizePerHead]
|
||||
[[nodiscard]] SizeType getBlockSize() const
|
||||
{
|
||||
return mBlockManager.getBlockSize();
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType getMaxBlocksPerSeq() const
|
||||
{
|
||||
return mMaxBlocksPerSeq;
|
||||
@ -544,21 +537,21 @@ public:
|
||||
runtime::ITensor& output, SizeType outputSlotOffset, SizeType seqSlotIdx, SizeType beamWidth) const;
|
||||
|
||||
// Volume of [2, numKvHeads, tokensPerBlock, sizePerHead]
|
||||
[[nodiscard]] static SizeType constexpr calculatePageSize(tensorrt_llm::runtime::GptModelConfig const& modelConfig)
|
||||
[[nodiscard]] static SizeType constexpr calculatePageSize(tensorrt_llm::runtime::ModelConfig const& modelConfig)
|
||||
{
|
||||
return 2 * modelConfig.getNbKvHeads() * modelConfig.getTokensPerBlock() * modelConfig.getSizePerHead();
|
||||
}
|
||||
|
||||
// numLayers * 2 * numKvHeads * sizePerHead
|
||||
[[nodiscard]] static SizeType constexpr calculateCacheSizePerToken(
|
||||
tensorrt_llm::runtime::GptModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig)
|
||||
tensorrt_llm::runtime::ModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig)
|
||||
{
|
||||
return modelConfig.getNbLayers(worldConfig.getPipelineParallelism()) * 2 * modelConfig.getNbKvHeads()
|
||||
return modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism()) * 2 * modelConfig.getNbKvHeads()
|
||||
* modelConfig.getSizePerHead();
|
||||
}
|
||||
|
||||
[[nodiscard]] static std::tuple<SizeType, SizeType> const calculateMaxNumBlocks(KvCacheConfig const& config,
|
||||
nvinfer1::DataType dtype, tensorrt_llm::runtime::GptModelConfig const& modelConfig,
|
||||
nvinfer1::DataType dtype, tensorrt_llm::runtime::ModelConfig const& modelConfig,
|
||||
tensorrt_llm::runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);
|
||||
|
||||
[[nodiscard]] SizeType getNumPrepopulatedTokens(SizeType batchSlotIdx, SizeType beamIdx) const
|
||||
@ -576,8 +569,8 @@ public:
|
||||
void rewindKVCache(SizeType seqSlotIdx, SizeType rewindLengths);
|
||||
|
||||
private:
|
||||
void setOffsets(OffsetType* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType seqSlotIdx, SizeType beamIdx,
|
||||
SizeType blockIdx, SizeType blockId) const;
|
||||
void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType seqSlotIdx,
|
||||
SizeType beamIdx, SizeType blockIdx, KVCacheBlock::IdType blockId) const;
|
||||
|
||||
void resetBlockOffsets(SizeType seqSlotIdx, SizeType beamWidth);
|
||||
void cacheBlockOffsets(GenerationRequest const& seq, SizeType seqSlotIdx);
|
||||
@ -586,8 +579,6 @@ private:
|
||||
void updateToken(SizeType seqSlotIdx, bool addToken);
|
||||
|
||||
private:
|
||||
// Number of layers
|
||||
SizeType mNumLayers;
|
||||
// Maximum number of sequences
|
||||
SizeType mMaxNumSequences;
|
||||
// Maximum beam width
|
||||
@ -607,8 +598,8 @@ private:
|
||||
BlockManager mBlockManager;
|
||||
// List of all sequences
|
||||
std::vector<SequencesPtr> mSequences;
|
||||
// buffer for block offsets for all managed sequences
|
||||
runtime::ITensor::SharedPtr mSequenceBlockOffsets;
|
||||
// buffer for block indices for all managed sequences
|
||||
runtime::ITensor::SharedPtr mSequenceBlockIndices;
|
||||
// Whether to cache KV pages for reuse
|
||||
bool mEnableBlockReuse;
|
||||
};
|
||||
|
||||
@ -92,6 +92,7 @@ public:
|
||||
, mCumLogProbs(samplingConfig.beamWidth)
|
||||
, mDraftTokens(draftTokens.value_or(std::make_shared<VecTokens>()))
|
||||
, mDraftLogits(draftLogits)
|
||||
, mNumTokensPerIteration(1)
|
||||
, mReturnContextLogits(returnContextLogits)
|
||||
, mReturnGenerationLogits(returnGenerationLogits)
|
||||
, mExcludeInputFromOutput(excludeInputFromOutput)
|
||||
@ -189,9 +190,9 @@ public:
|
||||
{
|
||||
auto const maxNewTokens = maxSequenceLen - mPromptLen;
|
||||
TLLM_LOG_WARNING(
|
||||
"Number of requested output tokens (%d) exceeds maximum sequence length (%d). "
|
||||
"Prompt length + number of requested output tokens (%d + %d) exceeds maximum sequence length (%d). "
|
||||
"Number of requested output tokens is changed to (%d).",
|
||||
mMaxNewTokens, maxSequenceLen, maxNewTokens);
|
||||
mPromptLen, mMaxNewTokens, maxSequenceLen, maxNewTokens);
|
||||
mMaxNewTokens = maxNewTokens;
|
||||
}
|
||||
|
||||
@ -494,6 +495,16 @@ public:
|
||||
return mDraftTokens->size();
|
||||
}
|
||||
|
||||
void setNumTokensPerIteration(SizeType numTokensPerIteration)
|
||||
{
|
||||
mNumTokensPerIteration = numTokensPerIteration;
|
||||
}
|
||||
|
||||
SizeType getNumTokensPerIteration() const
|
||||
{
|
||||
return mNumTokensPerIteration;
|
||||
}
|
||||
|
||||
void setReturnContextLogits(bool const returnContextLogits)
|
||||
{
|
||||
mReturnContextLogits = returnContextLogits;
|
||||
@ -766,6 +777,7 @@ protected:
|
||||
VecLogProbs mCumLogProbs; // [beamSize]
|
||||
std::shared_ptr<VecTokens> mDraftTokens;
|
||||
std::optional<TensorPtr> mDraftLogits;
|
||||
SizeType mNumTokensPerIteration;
|
||||
|
||||
// Save logits
|
||||
bool mReturnContextLogits;
|
||||
|
||||
@ -12,10 +12,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/batch_manager/common.h"
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/batch_manager/peftCacheManagerConfig.h"
|
||||
#include "tensorrt_llm/runtime/gptModelConfig.h"
|
||||
#include "tensorrt_llm/runtime/loraCache.h"
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
#include "tensorrt_llm/runtime/workerPool.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
#include <NvInferRuntimeBase.h>
|
||||
@ -23,6 +24,7 @@
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
@ -39,7 +41,7 @@ class BasePeftCacheManager
|
||||
{
|
||||
public:
|
||||
using LlmRequestPtr = std::shared_ptr<LlmRequest>;
|
||||
using RequestTable = std::map<uint64_t, LlmRequestPtr>;
|
||||
using RequestVector = std::vector<LlmRequestPtr>;
|
||||
using PeftTable = std::map<uint64_t, std::shared_ptr<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>>;
|
||||
|
||||
/**
|
||||
@ -50,13 +52,14 @@ public:
|
||||
virtual void addRequestPeft(LlmRequestPtr llmRequest, bool tryGpuCache = true) = 0;
|
||||
|
||||
/**
|
||||
* \brief ensures device cache has all the weights needed to execute batch as specified by requestTable.
|
||||
* \brief ensures device cache has all the weights needed to execute batch as specified by requests.
|
||||
* This acts as sync for the copy tasks started by addRequestPeft
|
||||
* \param[in] requestTable: current request table
|
||||
* \param[in] contextRequests: current context requests
|
||||
* \param[in] genRequests: current generation requests
|
||||
* \param[in] resetGpuCache: reset (make all tasks evictable)
|
||||
* \returns -- a PeftTable
|
||||
*/
|
||||
virtual PeftTable ensureBatch(RequestTable const& requestTable, bool resetGpuCache = false) = 0;
|
||||
virtual PeftTable ensureBatch(ScheduledRequests const& scheduledRequests, bool resetGpuCache = false) = 0;
|
||||
|
||||
/**
|
||||
* \brief mark all the tasks in device cache as done
|
||||
@ -77,12 +80,12 @@ public:
|
||||
class PeftCacheManager : public BasePeftCacheManager
|
||||
{
|
||||
public:
|
||||
PeftCacheManager(PeftCacheManagerConfig const& config, runtime::GptModelConfig const& modelConfig,
|
||||
PeftCacheManager(PeftCacheManagerConfig const& config, runtime::ModelConfig const& modelConfig,
|
||||
runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);
|
||||
|
||||
void addRequestPeft(std::shared_ptr<LlmRequest> llmRequest, bool tryGpuCache = true) override;
|
||||
|
||||
PeftTable ensureBatch(RequestTable const& requestTable, bool resetGpuCache = false) override;
|
||||
PeftTable ensureBatch(ScheduledRequests const& scheduledRequests, bool resetGpuCache = false) override;
|
||||
|
||||
[[nodiscard]] bool isTaskCached(uint64_t taskId) const;
|
||||
|
||||
@ -116,7 +119,7 @@ public:
|
||||
runtime::BufferManager const& bufferManager);
|
||||
|
||||
static std::pair<runtime::LoraCachePageManagerConfig, runtime::LoraCachePageManagerConfig> getPageManagerConfig(
|
||||
PeftCacheManagerConfig const& config, runtime::GptModelConfig const& modelConfig,
|
||||
PeftCacheManagerConfig const& config, runtime::ModelConfig const& modelConfig,
|
||||
runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);
|
||||
|
||||
private:
|
||||
@ -133,9 +136,9 @@ private:
|
||||
std::unordered_map<uint64_t, std::unordered_set<uint64_t>> mTaskIdToPausedReqIds;
|
||||
|
||||
std::tuple<std::map<uint64_t, std::future<void>>, std::map<uint64_t, std::vector<uint64_t>>> getTaskMaps(
|
||||
RequestTable const& requestTable);
|
||||
ScheduledRequests const& scheduledRequests);
|
||||
|
||||
runtime::GptModelConfig mModelConfig;
|
||||
runtime::ModelConfig mModelConfig;
|
||||
runtime::WorldConfig mWorldConfig;
|
||||
|
||||
int mDevice{-1};
|
||||
@ -145,7 +148,7 @@ class NoOpPeftCacheManager : public BasePeftCacheManager
|
||||
{
|
||||
void addRequestPeft(std::shared_ptr<LlmRequest> llmRequest, bool tryGpuCache = true) override;
|
||||
|
||||
PeftTable ensureBatch(RequestTable const& requestTable, bool resetGpuCache = false) override;
|
||||
PeftTable ensureBatch(ScheduledRequests const& scheduledRequests, bool resetGpuCache = false) override;
|
||||
|
||||
void resetDeviceCache() override;
|
||||
|
||||
|
||||
@ -60,6 +60,7 @@ struct PeftCacheManagerConfig
|
||||
, optimalAdapterSize(cfg.getOptimalAdapterSize())
|
||||
, maxAdapterSize(cfg.getMaxAdapterSize())
|
||||
, numPutWorkers(cfg.getNumPutWorkers())
|
||||
, numEnsureWorkers(cfg.getNumEnsureWorkers())
|
||||
, numCopyStreams(cfg.getNumCopyStreams())
|
||||
, maxPagesPerBlockHost(cfg.getMaxPagesPerBlockHost())
|
||||
, maxPagesPerBlockDevice(cfg.getMaxPagesPerBlockDevice())
|
||||
|
||||
@ -31,4 +31,6 @@ SchedulerPolicy execToBatchManagerSchedPolicy(executor::SchedulerPolicy policy);
|
||||
|
||||
executor::SchedulerPolicy batchManagerToExecSchedPolicy(SchedulerPolicy policy);
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, SchedulerPolicy policy);
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager::batch_scheduler
|
||||
|
||||
@ -57,7 +57,9 @@ public:
|
||||
explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig)
|
||||
: TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()), false,
|
||||
executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(),
|
||||
executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(), std::nullopt,
|
||||
executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(),
|
||||
runtime::DecodingMode::fromExecutor(
|
||||
executorConfig.getDecodingMode().value_or(executor::DecodingMode::kNONE)),
|
||||
PeftCacheManagerConfig(executorConfig.getPeftCacheConfig().value_or(executor::PeftCacheConfig())),
|
||||
executorConfig.getMedusaChoices())
|
||||
{
|
||||
@ -70,6 +72,8 @@ public:
|
||||
&& enableChunkedContext == other.enableChunkedContext && decodingMode == other.decodingMode;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, TrtGptModelOptionalParams const& self);
|
||||
|
||||
KvCacheConfig kvCacheConfig;
|
||||
|
||||
bool enableTrtOverlap;
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
#include "tensorrt_llm/runtime/utils/multiDeviceUtils.h"
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
@ -36,6 +35,11 @@
|
||||
|
||||
#define MPICHECK(cmd) TLLM_MPI_CHECK(cmd)
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
class IBuffer;
|
||||
}
|
||||
|
||||
// A wrapper module of the MPI library.
|
||||
namespace tensorrt_llm::mpi
|
||||
{
|
||||
@ -234,18 +238,11 @@ public:
|
||||
|
||||
std::shared_ptr<MpiRequest> bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const;
|
||||
|
||||
std::shared_ptr<MpiRequest> bcastAsync(runtime::IBuffer& buf, int root) const
|
||||
{
|
||||
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
|
||||
return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
|
||||
}
|
||||
std::shared_ptr<MpiRequest> bcastAsync(runtime::IBuffer& buf, int root) const;
|
||||
|
||||
void bcast(void* buffer, size_t size, MpiType dtype, int root) const;
|
||||
|
||||
void bcast(runtime::IBuffer& buf, int root) const
|
||||
{
|
||||
bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
|
||||
}
|
||||
void bcast(runtime::IBuffer& buf, int root) const;
|
||||
|
||||
template <typename T>
|
||||
void bcastValue(T& value, int root) const
|
||||
@ -281,11 +278,7 @@ public:
|
||||
|
||||
void send(void const* buffer, std::size_t size, MpiType dtype, int dest, int tag) const;
|
||||
|
||||
void send(runtime::IBuffer const& buf, int dest, int tag) const
|
||||
{
|
||||
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
|
||||
send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
|
||||
}
|
||||
void send(runtime::IBuffer const& buf, int dest, int tag) const;
|
||||
|
||||
template <typename T>
|
||||
void send(T const& value, int dest, int tag) const
|
||||
@ -302,11 +295,7 @@ public:
|
||||
|
||||
MPI_Status recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const;
|
||||
|
||||
MPI_Status recv(runtime::IBuffer& buf, int source, int tag) const
|
||||
{
|
||||
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
|
||||
return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag);
|
||||
}
|
||||
MPI_Status recv(runtime::IBuffer& buf, int source, int tag) const;
|
||||
|
||||
template <typename T>
|
||||
MPI_Status recv(T& value, int source, int tag) const
|
||||
|
||||
@ -29,6 +29,11 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::mpi
|
||||
{
|
||||
class MpiComm;
|
||||
}
|
||||
|
||||
namespace tensorrt_llm::executor
|
||||
{
|
||||
|
||||
@ -310,6 +315,7 @@ public:
|
||||
[[nodiscard]] Result getResult() const;
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> mImpl;
|
||||
};
|
||||
@ -323,6 +329,8 @@ public:
|
||||
[[nodiscard]] SchedulerPolicy getPolicy() const;
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
|
||||
/// @brief The scheduler policy. See SchedulerPolicy.
|
||||
SchedulerPolicy mPolicy;
|
||||
};
|
||||
@ -346,6 +354,8 @@ public:
|
||||
[[nodiscard]] bool getOnboardBlocks() const;
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
|
||||
/// @brief Controls if KV cache blocks can be reused for different requests
|
||||
bool mEnableBlockReuse;
|
||||
|
||||
@ -378,6 +388,26 @@ SizeType const kDefaultIterStatsMaxIterations = 1000;
|
||||
// Per request stats may have additional overhead due to going through all requests. Turned off by default.
|
||||
SizeType const kDefaultRequestStatsMaxIterations = 0;
|
||||
|
||||
class OrchestratorConfig
|
||||
{
|
||||
public:
|
||||
explicit OrchestratorConfig(bool isOrchestrator = true, std::string workerExecutablePath = "",
|
||||
std::shared_ptr<mpi::MpiComm> orchLeaderComm = nullptr);
|
||||
|
||||
[[nodiscard]] bool getIsOrchestrator() const;
|
||||
[[nodiscard]] std::string getWorkerExecutablePath() const;
|
||||
[[nodiscard]] std::shared_ptr<mpi::MpiComm> getOrchLeaderComm() const;
|
||||
|
||||
void setIsOrchestrator(bool isOrchestrator);
|
||||
void setWorkerExecutablePath(std::string const& workerExecutablePath);
|
||||
void setOrchLeaderComm(std::shared_ptr<mpi::MpiComm> const& orchLeaderComm);
|
||||
|
||||
private:
|
||||
bool mIsOrchestrator;
|
||||
std::string mWorkerExecutablePath;
|
||||
std::shared_ptr<mpi::MpiComm> mOrchLeaderComm;
|
||||
};
|
||||
|
||||
/// @brief A configuration class for the parallel execution parameters
|
||||
/// Currently only supports commType = CommunicationType::kMPI
|
||||
class ParallelConfig
|
||||
@ -392,19 +422,24 @@ public:
|
||||
explicit ParallelConfig(CommunicationType commType = CommunicationType::kMPI,
|
||||
CommunicationMode commMode = CommunicationMode::kLEADER,
|
||||
std::optional<std::vector<SizeType>> deviceIds = std::nullopt,
|
||||
std::optional<std::vector<SizeType>> participantIds = std::nullopt);
|
||||
std::optional<std::vector<SizeType>> participantIds = std::nullopt,
|
||||
std::optional<OrchestratorConfig> const& orchestratorConfig = std::nullopt);
|
||||
|
||||
[[nodiscard]] CommunicationType getCommunicationType() const;
|
||||
[[nodiscard]] CommunicationMode getCommunicationMode() const;
|
||||
[[nodiscard]] std::optional<std::vector<SizeType>> getDeviceIds() const;
|
||||
[[nodiscard]] std::optional<std::vector<SizeType>> getParticipantIds() const;
|
||||
[[nodiscard]] std::optional<OrchestratorConfig> getOrchestratorConfig() const;
|
||||
|
||||
void setCommunicationType(CommunicationType type);
|
||||
void setCommunicationMode(CommunicationMode mode);
|
||||
void setDeviceIds(std::vector<SizeType> const& deviceIds);
|
||||
void setParticipantIds(std::vector<SizeType> const& participantIds);
|
||||
void setOrchestratorConfig(OrchestratorConfig const& orchestratorConfig);
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
|
||||
/// @brief The type of communication protocol used. Default is MPI.
|
||||
CommunicationType mCommType;
|
||||
|
||||
@ -416,6 +451,9 @@ private:
|
||||
|
||||
/// @brief The participant ids (MPI ranks for example) used for executing this model
|
||||
std::optional<std::vector<SizeType>> mParticipantIds;
|
||||
|
||||
/// @brief Optional orchestrator configuration
|
||||
std::optional<OrchestratorConfig> mOrchestratorConfig;
|
||||
};
|
||||
|
||||
/// @brief config for PeftCacheManager
|
||||
@ -428,6 +466,8 @@ public:
|
||||
SizeType maxPagesPerBlockDevice = 8, std::optional<float> const& deviceCachePercent = std::nullopt,
|
||||
std::optional<size_t> const& hostCacheSize = std::nullopt);
|
||||
|
||||
bool operator==(PeftCacheConfig const& other) const;
|
||||
|
||||
[[nodiscard]] SizeType getNumHostModuleLayer() const;
|
||||
[[nodiscard]] SizeType getNumDeviceModuleLayer() const;
|
||||
[[nodiscard]] SizeType getOptimalAdapterSize() const;
|
||||
@ -441,6 +481,8 @@ public:
|
||||
[[nodiscard]] std::optional<size_t> getHostCacheSize() const;
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
|
||||
// number of max sized 1-layer 1-module adapterSize=1 sets of weights that can be stored in host cache
|
||||
SizeType mNumHostModuleLayer;
|
||||
// number of max sized 1-layer 1-module sets of weights that can be stored in host cache
|
||||
@ -460,7 +502,7 @@ private:
|
||||
// Number of cache pages per allocation block (device)
|
||||
SizeType mMaxPagesPerBlockDevice;
|
||||
// percent of memory after engine load to use for cache
|
||||
std::optional<float> mDeviceCachePercent;
|
||||
std::optional<FloatType> mDeviceCachePercent;
|
||||
// size in bytes to use for host cache
|
||||
std::optional<size_t> mHostCacheSize;
|
||||
};
|
||||
@ -477,7 +519,8 @@ public:
|
||||
std::optional<ParallelConfig> parallelConfig = std::nullopt,
|
||||
std::optional<PeftCacheConfig> const& peftCacheConfig = std::nullopt,
|
||||
std::optional<LogitsPostProcessorMap> logitsPostProcessorMap = std::nullopt,
|
||||
std::optional<MedusaChoices> medusaChoices = std::nullopt);
|
||||
std::optional<MedusaChoices> medusaChoices = std::nullopt,
|
||||
std::optional<DecodingMode> decodingMode = std::nullopt);
|
||||
|
||||
[[nodiscard]] SizeType getMaxBeamWidth() const;
|
||||
[[nodiscard]] SchedulerConfig getSchedulerConfig() const;
|
||||
@ -491,6 +534,7 @@ public:
|
||||
[[nodiscard]] std::optional<PeftCacheConfig> getPeftCacheConfig() const;
|
||||
[[nodiscard]] std::optional<LogitsPostProcessorMap> getLogitsPostProcessorMap() const;
|
||||
[[nodiscard]] std::optional<MedusaChoices> getMedusaChoices() const;
|
||||
[[nodiscard]] std::optional<DecodingMode> getDecodingMode() const;
|
||||
|
||||
void setMaxBeamWidth(SizeType maxBeamWidth);
|
||||
void setSchedulerConfig(SchedulerConfig const& schedulerConfig);
|
||||
@ -504,8 +548,11 @@ public:
|
||||
void setPeftCacheConfig(PeftCacheConfig const& peftCacheConfig);
|
||||
void setLogitsPostProcessorMap(LogitsPostProcessorMap const& logitsPostProcessorMap);
|
||||
void setMedusaChoices(MedusaChoices const& medusaChoices);
|
||||
void setDecodingMode(DecodingMode decodingMode);
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
|
||||
/// @brief The beam width value of requests that will be sent to the executor
|
||||
SizeType mMaxBeamWidth;
|
||||
|
||||
@ -535,6 +582,7 @@ private:
|
||||
std::optional<PeftCacheConfig> mPeftCacheConfig;
|
||||
std::optional<LogitsPostProcessorMap> mLogitsPostProcessorMap;
|
||||
std::optional<MedusaChoices> mMedusaChoices;
|
||||
std::optional<DecodingMode> mDecodingMode;
|
||||
};
|
||||
|
||||
/// @brief The executor is responsible for receiving new requests and sending responses, and running the inference
|
||||
|
||||
117
cpp/include/tensorrt_llm/executor/serialization.h
Normal file
117
cpp/include/tensorrt_llm/executor/serialization.h
Normal file
@ -0,0 +1,117 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/executor/tensor.h"
|
||||
#include "tensorrt_llm/executor/types.h"
|
||||
#include <istream>
|
||||
#include <ostream>
|
||||
|
||||
namespace tensorrt_llm::executor
|
||||
{
|
||||
|
||||
class Serialization
|
||||
{
|
||||
public:
|
||||
// SamplingConfig
|
||||
[[nodiscard]] static SamplingConfig deserializeSamplingConfig(std::istream& is);
|
||||
static void serialize(SamplingConfig const& config, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(SamplingConfig const& config);
|
||||
|
||||
// OutputConfig
|
||||
[[nodiscard]] static OutputConfig deserializeOutputConfig(std::istream& is);
|
||||
static void serialize(OutputConfig const& config, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(OutputConfig const& config);
|
||||
|
||||
// SpeculativeDecodingConfig
|
||||
[[nodiscard]] static SpeculativeDecodingConfig deserializeSpeculativeDecodingConfig(std::istream& is);
|
||||
static void serialize(SpeculativeDecodingConfig const& config, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(SpeculativeDecodingConfig const& config);
|
||||
|
||||
// PromptTuningConfig
|
||||
[[nodiscard]] static PromptTuningConfig deserializePromptTuningConfig(std::istream& is);
|
||||
static void serialize(PromptTuningConfig const& config, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(PromptTuningConfig const& config);
|
||||
|
||||
// LoraConfig
|
||||
[[nodiscard]] static LoraConfig deserializeLoraConfig(std::istream& is);
|
||||
static void serialize(LoraConfig const& config, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(LoraConfig const& config);
|
||||
|
||||
// Request
|
||||
[[nodiscard]] static Request deserializeRequest(std::istream& is);
|
||||
static void serialize(Request const& request, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(Request const& request);
|
||||
|
||||
// Tensor
|
||||
[[nodiscard]] static Tensor deserializeTensor(std::istream& is);
|
||||
static void serialize(Tensor const& tensor, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(Tensor const& tensor);
|
||||
|
||||
// Result
|
||||
[[nodiscard]] static Result deserializeResult(std::istream& is);
|
||||
static void serialize(Result const& result, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(Result const& result);
|
||||
|
||||
// Response
|
||||
[[nodiscard]] static Response deserializeResponse(std::istream& is);
|
||||
static void serialize(Response const& response, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(Response const& response);
|
||||
|
||||
// Vector of responses
|
||||
static std::vector<Response> deserializeResponses(std::vector<char>& buffer);
|
||||
static std::vector<char> serialize(std::vector<Response> const& responses);
|
||||
|
||||
// KvCacheConfig
|
||||
static KvCacheConfig deserializeKvCacheConfig(std::istream& is);
|
||||
static void serialize(KvCacheConfig const& kvCacheConfig, std::ostream& os);
|
||||
static size_t serializedSize(KvCacheConfig const& kvCacheConfig);
|
||||
|
||||
// SchedulerConfig
|
||||
static SchedulerConfig deserializeSchedulerConfig(std::istream& is);
|
||||
static void serialize(SchedulerConfig const& schedulerConfig, std::ostream& os);
|
||||
static size_t serializedSize(SchedulerConfig const& schedulerConfig);
|
||||
|
||||
// ParallelConfig
|
||||
static ParallelConfig deserializeParallelConfig(std::istream& is);
|
||||
static void serialize(ParallelConfig const& parallelConfig, std::ostream& os);
|
||||
static size_t serializedSize(ParallelConfig const& parallelConfig);
|
||||
|
||||
// PeftCacheConfig
|
||||
static PeftCacheConfig deserializePeftCacheConfig(std::istream& is);
|
||||
static void serialize(PeftCacheConfig const& peftCacheConfig, std::ostream& os);
|
||||
static size_t serializedSize(PeftCacheConfig const& peftCacheConfig);
|
||||
|
||||
// OrchestratorConfig
|
||||
static OrchestratorConfig deserializeOrchestratorConfig(std::istream& is);
|
||||
static void serialize(OrchestratorConfig const& orchestratorConfig, std::ostream& os);
|
||||
static size_t serializedSize(OrchestratorConfig const& orchestratorConfig);
|
||||
|
||||
// ExecutorConfig
|
||||
static ExecutorConfig deserializeExecutorConfig(std::istream& is);
|
||||
static void serialize(ExecutorConfig const& executorConfig, std::ostream& os);
|
||||
static size_t serializedSize(ExecutorConfig const& executorConfig);
|
||||
|
||||
// String
|
||||
static std::string deserializeString(std::istream& is);
|
||||
|
||||
// ModelType
|
||||
static ModelType deserializeModelType(std::istream& is);
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::executor
|
||||
@ -191,6 +191,9 @@ enum class CommunicationMode
|
||||
kLEADER, // With the leader mode, only the leader can enqueue requests. The requests will be
|
||||
// broadcasted to the workers. All participants can get response via awaitResponses. The leader is the
|
||||
// first participant in the provided participant IDS, or 0 if participant ID is not provided
|
||||
kORCHESTRATOR, // With the orchestrator mode, only the orchestrator can enqueue requests and await responses. The
|
||||
// requests will be broadcasted to the workers. The orchestrator will spawn new processes for the
|
||||
// execution of the model
|
||||
};
|
||||
|
||||
/// @brief Struct that holds the stats of a KV cache manager
|
||||
@ -305,4 +308,17 @@ struct RequestStatsPerIteration
|
||||
std::vector<RequestStats> requestStats;
|
||||
};
|
||||
|
||||
/// @brief Decoding mode
|
||||
enum class DecodingMode
|
||||
{
|
||||
/// @brief No mode specified. Config will be determined from the beam width of the first request at runtime
|
||||
/// TopKTopP if beamWidth == 1, BeamSearch otherwise
|
||||
kNONE,
|
||||
kTOP_K,
|
||||
kTOP_P,
|
||||
kBEAM_SEARCH,
|
||||
kMEDUSA,
|
||||
kTOP_K_TOP_P,
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::executor
|
||||
|
||||
@ -16,6 +16,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace runtime
|
||||
@ -54,37 +56,37 @@ public:
|
||||
return DecodingMode{kMedusa};
|
||||
}
|
||||
|
||||
bool constexpr isNone()
|
||||
bool constexpr isNone() const
|
||||
{
|
||||
return mState == 0;
|
||||
}
|
||||
|
||||
bool constexpr isTopK()
|
||||
bool constexpr isTopK() const
|
||||
{
|
||||
return anyBitSet(kTopK);
|
||||
}
|
||||
|
||||
bool constexpr isTopP()
|
||||
bool constexpr isTopP() const
|
||||
{
|
||||
return anyBitSet(kTopP);
|
||||
}
|
||||
|
||||
bool constexpr isTopKorTopP()
|
||||
bool constexpr isTopKorTopP() const
|
||||
{
|
||||
return anyBitSet(kTopKTopP);
|
||||
}
|
||||
|
||||
bool constexpr isTopKandTopP()
|
||||
bool constexpr isTopKandTopP() const
|
||||
{
|
||||
return allBitSet(kTopKTopP);
|
||||
}
|
||||
|
||||
bool constexpr isBeamSearch()
|
||||
bool constexpr isBeamSearch() const
|
||||
{
|
||||
return anyBitSet(kBeamSearch);
|
||||
}
|
||||
|
||||
bool constexpr isMedusa()
|
||||
bool constexpr isMedusa() const
|
||||
{
|
||||
return anyBitSet(kMedusa);
|
||||
}
|
||||
@ -96,6 +98,28 @@ public:
|
||||
return mState == other.mState;
|
||||
}
|
||||
|
||||
static DecodingMode fromExecutor(executor::DecodingMode decodingMode)
|
||||
{
|
||||
switch (decodingMode)
|
||||
{
|
||||
case executor::DecodingMode::kNONE: return DecodingMode::None();
|
||||
|
||||
case executor::DecodingMode::kTOP_K: return DecodingMode::TopK();
|
||||
|
||||
case executor::DecodingMode::kTOP_P: return DecodingMode::TopP();
|
||||
|
||||
case executor::DecodingMode::kBEAM_SEARCH: return DecodingMode::BeamSearch();
|
||||
|
||||
case executor::DecodingMode::kMEDUSA: return DecodingMode::Medusa();
|
||||
|
||||
case executor::DecodingMode::kTOP_K_TOP_P: return DecodingMode::TopKTopP();
|
||||
|
||||
default: TLLM_THROW("Invalid decoding mode"); break;
|
||||
}
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, DecodingMode other);
|
||||
|
||||
private:
|
||||
constexpr DecodingMode(UnderlyingType state)
|
||||
: mState(state)
|
||||
|
||||
@ -29,17 +29,21 @@ class DecodingOutput
|
||||
public:
|
||||
using TensorPtr = ITensor::SharedPtr;
|
||||
|
||||
// BS: batch_size, BM: beam_width, MSL: max_seq_length
|
||||
// All TensorPtr without special comments are on gpu
|
||||
|
||||
class BeamHypotheses
|
||||
{
|
||||
public:
|
||||
TensorPtr outputIdsTgt; // [batchSize, 2 * beamWidth, maxSeqLen]
|
||||
TensorPtr sequenceLengthsTgt; // [batchSize, 2 * beamWidth]
|
||||
TensorPtr cumLogProbs; // [batchSize, 2 * beamWidth]
|
||||
TensorPtr normedScores; // [batchSize, 2 * beamWidth]
|
||||
TensorPtr logProbs; // [batchSize, 2 * beamWidth, maxSeqLen]
|
||||
TensorPtr minNormedScores; // [batchSize]
|
||||
TensorPtr numBeams; // [batchSize]
|
||||
TensorPtr isDone; // [batchSize]
|
||||
// The same as cpp/tensorrt_llm/kernels/beamSearchKernels.h
|
||||
TensorPtr outputIdsCBA; // [BS, BM*2, MSL]
|
||||
TensorPtr sequenceLengthsCBA; // [BS, BM]
|
||||
TensorPtr cumLogProbsCBA; // [BS, BM*2]
|
||||
TensorPtr normedScoresCBA; // [BS, BM*2]
|
||||
TensorPtr logProbsCBA; // [BS, BM*2, MSL]
|
||||
TensorPtr minNormedScoresCBA; // [BS]
|
||||
TensorPtr numBeamsCBA; // [BS]
|
||||
TensorPtr batchDones; // [BS]
|
||||
|
||||
void empty(BufferManager& manager);
|
||||
|
||||
@ -61,27 +65,26 @@ public:
|
||||
}
|
||||
|
||||
// mandatory parameters
|
||||
TensorPtr ids; // [batchSize, beamWidth, maxSeqLen], on gpu, must contain previously generated token ids for all
|
||||
// steps before DecodingInput.step
|
||||
TensorPtr newTokensSteps; // [maxTokensPerStep, batchSize, beamWidth] new tokens at each generated token of
|
||||
// maxTokensPerStep, on gpu.
|
||||
TensorPtr newTokens; // [batchSize, beamWidth] usually a view of newTokensSteps for the current token, on gpu.
|
||||
std::vector<TensorPtr> newTokensVec; // vector of size maxTokensPerStep with tensor [batchSize, beamWidth].
|
||||
// Vector of views on newTokensSteps for each token. Elements are on gpu.
|
||||
TensorPtr ids; // [BS, BM, MSL], contains previously generated token ids for all
|
||||
// steps before DecodingInput.step
|
||||
TensorPtr newTokensSteps; // [maxTokensPerStep, BS, BM] new tokens at each generated token of
|
||||
// maxTokensPerStep
|
||||
TensorPtr newTokens; // [BS, BM] usually a view of newTokensSteps for the current token
|
||||
std::vector<TensorPtr> newTokensVec; // vector of size maxTokensPerStep with tensor [BS, BM].
|
||||
// Vector of views on newTokensSteps for each token
|
||||
|
||||
// optional parameters
|
||||
TensorPtr finished; // [batchSize, beamWidth],
|
||||
// Set to true by decoding if any of the stop conditions are met or if DecodingInput.finished is
|
||||
// true. In beam search and to determine whether to stop according to
|
||||
// DecodingInput.sequenceLimitLength, on gpu
|
||||
TensorPtr finishedSum; // [batchSize], the sum of finished sequences per request, in pinned memory
|
||||
TensorPtr finished; // [BS, BM], set to true by decoding if any of the stop conditions are met or if
|
||||
// DecodingInput.finished is true. In beam search and to determine whether to stop according to
|
||||
// DecodingInput.sequenceLimitLength
|
||||
TensorPtr finishedSum; // [BS], the sum of finished sequences per request, in pinned memory
|
||||
|
||||
// mandatory parameters for beam search
|
||||
TensorPtr logProbs; // [batchSize, beamWidth, maxSeqLen], must be float*, on gpu
|
||||
TensorPtr cumLogProbs; // [batchSize, beamWidth], optional for sampling, on gpu
|
||||
TensorPtr parentIds; // [batchSize, beamWidth, maxSeqLen], on gpu
|
||||
TensorPtr lengths; // [batchSize, beamWidth], total sequence lengths including padding, on gpu
|
||||
TensorPtr cacheIndirection; // [batchSize, beamWidth, maxSeqLen], k/v indirection for next generation step, on gpu
|
||||
TensorPtr logProbs; // [BS, BM, MSL], must be float*
|
||||
TensorPtr cumLogProbs; // [BS, BM], optional for sampling
|
||||
TensorPtr parentIds; // [BS, BM, MSL]
|
||||
TensorPtr lengths; // [BS, BM], total sequence lengths including padding
|
||||
TensorPtr cacheIndirection; // [BS, BM, MSL], k/v indirection for next generation step
|
||||
|
||||
BeamHypotheses beamHypotheses;
|
||||
|
||||
@ -89,10 +92,10 @@ public:
|
||||
class MedusaOutputs
|
||||
{
|
||||
public:
|
||||
TensorPtr medusaNextDraftTokens; // [maxBatchSize, maxTokensPerStep], on gpu
|
||||
TensorPtr medusaAcceptedTokensLen; // [maxBatchSize], on gpu
|
||||
TensorPtr medusaAcceptedLengthsCumSum; // [maxBatchSize + 1], on gpu
|
||||
TensorPtr medusaPathsOffsets; // [maxBatchSize * maxNumHeads], on gpu
|
||||
TensorPtr medusaNextDraftTokens; // [maxBatchSize, maxTokensPerStep]
|
||||
TensorPtr medusaAcceptedTokensLen; // [maxBatchSize]
|
||||
TensorPtr medusaAcceptedLengthsCumSum; // [maxBatchSize + 1]
|
||||
TensorPtr medusaPathsOffsets; // [maxBatchSize * maxNumHeads]
|
||||
};
|
||||
|
||||
std::optional<MedusaOutputs> medusaOutputs;
|
||||
|
||||
@ -21,7 +21,7 @@
|
||||
#include "tensorrt_llm/runtime/decodingInput.h"
|
||||
#include "tensorrt_llm/runtime/decodingMode.h"
|
||||
#include "tensorrt_llm/runtime/decodingOutput.h"
|
||||
#include "tensorrt_llm/runtime/gptModelConfig.h"
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
#include "tensorrt_llm/runtime/samplingConfig.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
#include <curand_kernel.h>
|
||||
|
||||
@ -43,7 +43,7 @@ public:
|
||||
//! Setup the decoder before calling `forward()`
|
||||
void setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
|
||||
SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep, bool fusedDecoder,
|
||||
nvinfer1::DataType dtype, GptModelConfig const& modelConfig) override;
|
||||
nvinfer1::DataType dtype, ModelConfig const& modelConfig) override;
|
||||
|
||||
void newBatch(
|
||||
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig) override;
|
||||
@ -182,7 +182,7 @@ private:
|
||||
void allocateMedusaBuffers();
|
||||
|
||||
//! @brief Setup buffers for medusa decoding.
|
||||
void setupMedusa(GptModelConfig const& modelConfig);
|
||||
void setupMedusa(ModelConfig const& modelConfig);
|
||||
|
||||
//! @brief Setups decoder internal tensors for new speculative decoding request
|
||||
void newRequestSpeculativeDecoding(
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/gptModelConfig.h"
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
|
||||
#include <filesystem>
|
||||
@ -32,13 +32,13 @@ class GptJsonConfig
|
||||
{
|
||||
public:
|
||||
GptJsonConfig(std::string name, std::string version, std::string precision, SizeType tensorParallelism,
|
||||
SizeType pipelineParallelism, GptModelConfig const& modelConfig)
|
||||
SizeType pipelineParallelism, ModelConfig const& modelConfig)
|
||||
: mName(std::move(name))
|
||||
, mVersion(std::move(version))
|
||||
, mPrecision(std::move(precision))
|
||||
, mTensorParallelism{tensorParallelism}
|
||||
, mPipelineParallelism{pipelineParallelism}
|
||||
, mGptModelConfig(modelConfig)
|
||||
, mModelConfig(modelConfig)
|
||||
{
|
||||
}
|
||||
|
||||
@ -48,9 +48,9 @@ public:
|
||||
|
||||
static GptJsonConfig parse(std::filesystem::path const& path);
|
||||
|
||||
[[nodiscard]] GptModelConfig getModelConfig() const
|
||||
[[nodiscard]] ModelConfig getModelConfig() const
|
||||
{
|
||||
return mGptModelConfig;
|
||||
return mModelConfig;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string const& getName() const
|
||||
@ -96,7 +96,7 @@ private:
|
||||
std::string const mPrecision;
|
||||
SizeType const mTensorParallelism;
|
||||
SizeType const mPipelineParallelism;
|
||||
GptModelConfig const mGptModelConfig;
|
||||
ModelConfig const mModelConfig;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -14,6 +14,13 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*****************************************************************************
|
||||
*
|
||||
* GptSession is going to be deprecated soon.
|
||||
* Please do not add new functionality in this file!
|
||||
*
|
||||
*****************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
|
||||
@ -23,8 +30,8 @@
|
||||
#include "tensorrt_llm/runtime/decodingMode.h"
|
||||
#include "tensorrt_llm/runtime/generationInput.h"
|
||||
#include "tensorrt_llm/runtime/generationOutput.h"
|
||||
#include "tensorrt_llm/runtime/gptModelConfig.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
#include "tensorrt_llm/runtime/samplingConfig.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
|
||||
@ -150,17 +157,17 @@ public:
|
||||
//! @param engineBuffer The compiled TensorRT engine (const void*),
|
||||
//! @param engineSize The size in bytes of the TensorRT engine (size_t),
|
||||
//! @param logger The optional logger.
|
||||
GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
|
||||
GptSession(Config const& sessionConfig, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
|
||||
void const* engineBuffer, std::size_t engineSize, LoggerPtr logger = nullptr);
|
||||
|
||||
GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
|
||||
GptSession(Config const& sessionConfig, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
|
||||
std::vector<uint8_t> const& engineBuffer, LoggerPtr logger = nullptr)
|
||||
: GptSession(
|
||||
sessionConfig, modelConfig, worldConfig, engineBuffer.data(), engineBuffer.size(), std::move(logger))
|
||||
{
|
||||
}
|
||||
|
||||
GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
|
||||
GptSession(Config const& sessionConfig, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
|
||||
std::string const& engineFile, LoggerPtr logger = nullptr)
|
||||
: GptSession(sessionConfig, modelConfig, worldConfig, utils::loadEngine(engineFile), std::move(logger))
|
||||
{
|
||||
@ -170,7 +177,7 @@ public:
|
||||
|
||||
[[nodiscard]] BufferManager const& getBufferManager() const;
|
||||
|
||||
[[nodiscard]] GptModelConfig const& getModelConfig() const
|
||||
[[nodiscard]] ModelConfig const& getModelConfig() const
|
||||
{
|
||||
return mModelConfig;
|
||||
}
|
||||
@ -335,7 +342,7 @@ private:
|
||||
friend class batch_manager::TrtGptModelV1;
|
||||
|
||||
private:
|
||||
GptModelConfig const mModelConfig;
|
||||
ModelConfig const mModelConfig;
|
||||
WorldConfig const mWorldConfig;
|
||||
int mDevice{-1};
|
||||
std::shared_ptr<NcclCommunicator> mPipelineComm;
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
|
||||
#include "tensorrt_llm/common/arrayView.h"
|
||||
#include "tensorrt_llm/common/dataType.h"
|
||||
#include "tensorrt_llm/kernels/kvCacheIndex.h"
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
@ -307,6 +308,12 @@ struct TRTDataType<__nv_fp8_e4m3>
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct TRTDataType<kernels::KVCacheIndex>
|
||||
{
|
||||
static constexpr auto value = TRTDataType<kernels::KVCacheIndex::UnderlyingType>::value;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TRTDataType<void*>
|
||||
{
|
||||
|
||||
@ -75,7 +75,7 @@ public:
|
||||
//! Setup the decoder before calling `forward()`, also calls reshapeBuffers
|
||||
virtual void setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth,
|
||||
SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep,
|
||||
bool fusedDecoder, nvinfer1::DataType dtype, GptModelConfig const& modelConfig)
|
||||
bool fusedDecoder, nvinfer1::DataType dtype, ModelConfig const& modelConfig)
|
||||
= 0;
|
||||
|
||||
//! @brief Initialize the decoder with new batch of inputs.
|
||||
|
||||
@ -18,10 +18,10 @@
|
||||
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/gptModelConfig.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/loraCachePageManagerConfig.h"
|
||||
#include "tensorrt_llm/runtime/loraModule.h"
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
#include <NvInferRuntimeBase.h>
|
||||
#include <deque>
|
||||
@ -159,11 +159,11 @@ public:
|
||||
|
||||
/**
|
||||
* param[in] pageManagerConfig: a LoraCachePageManagerConfig
|
||||
* param[in] modelConfig: a GptModelConfig
|
||||
* param[in] modelConfig: a ModelConfig
|
||||
* param[in] worldConfig: a WorldConfig
|
||||
* param[in] bufferManager: a BufferManager only used to allocate page blocks
|
||||
*/
|
||||
LoraCache(LoraCachePageManagerConfig const& pageManagerConfig, GptModelConfig const& modelConfig,
|
||||
LoraCache(LoraCachePageManagerConfig const& pageManagerConfig, ModelConfig const& modelConfig,
|
||||
WorldConfig const& worldConfig, BufferManager const& bufferManager);
|
||||
|
||||
/**
|
||||
@ -277,7 +277,7 @@ public:
|
||||
* \brief Copy task weights to cache pages.
|
||||
* \param[in] weights: task weights
|
||||
* \param[in] config: task config tensor
|
||||
* \param[in] modelConfig: a GptModelConfig
|
||||
* \param[in] modelConfig: a ModelConfig
|
||||
* \param[in] worldConfig: a WorldConfig
|
||||
* \param[in] modelIdToModel: map from lora module id to LoraModule
|
||||
* \param[in] manager: a BufferManager the manager to use to perform the copies
|
||||
@ -286,7 +286,7 @@ public:
|
||||
* \returns -- list of cache Values objects
|
||||
*/
|
||||
static std::vector<LoraCache::TaskLayerModuleConfig> copyToPages(TensorPtr weights, TensorPtr config,
|
||||
GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
|
||||
ModelConfig const& modelConfig, WorldConfig const& worldConfig,
|
||||
std::unordered_map<SizeType, LoraModule> moduleIdToModel, BufferManager const& manager,
|
||||
std::vector<TensorPtr> const& pages, std::vector<std::size_t> const& pageIds);
|
||||
|
||||
@ -385,7 +385,7 @@ private:
|
||||
};
|
||||
|
||||
LoraCachePageManagerConfig mPageManagerConfig;
|
||||
GptModelConfig mModelConfig;
|
||||
ModelConfig mModelConfig;
|
||||
WorldConfig mWorldConfig;
|
||||
|
||||
// Protects mCachePageManager
|
||||
|
||||
@ -32,7 +32,7 @@ struct MambaConfig
|
||||
SizeType expand = 0;
|
||||
};
|
||||
|
||||
class GptModelConfig
|
||||
class ModelConfig
|
||||
{
|
||||
public:
|
||||
enum class ModelVariant : std::int32_t
|
||||
@ -42,10 +42,11 @@ public:
|
||||
kMamba = 2, // https://github.com/state-spaces/mamba
|
||||
};
|
||||
|
||||
explicit GptModelConfig(
|
||||
SizeType vocabSize, SizeType nbLayers, SizeType nbHeads, SizeType hiddenSize, nvinfer1::DataType dtype)
|
||||
explicit ModelConfig(SizeType vocabSize, SizeType nbAttentionLayers, SizeType nbSsmLayers, SizeType nbHeads,
|
||||
SizeType hiddenSize, nvinfer1::DataType dtype)
|
||||
: mVocabSize(vocabSize)
|
||||
, mNbLayers(nbLayers)
|
||||
, mNbAttentionLayers(nbAttentionLayers)
|
||||
, mNbSsmLayers(nbSsmLayers)
|
||||
, mNbHeads(nbHeads)
|
||||
, mNbKvHeads(nbHeads)
|
||||
, mHiddenSize(hiddenSize)
|
||||
@ -71,6 +72,7 @@ public:
|
||||
, mMaxDraftLen(0)
|
||||
, mUseContextFMHAForGeneration(false)
|
||||
, mPagedContextFMHA(false)
|
||||
, mUseXQA{false}
|
||||
, mUseLoraPlugin(false)
|
||||
, mMlpHiddenSize(0)
|
||||
, mMedusaModule(std::nullopt)
|
||||
@ -87,10 +89,16 @@ public:
|
||||
return (mVocabSize + worldSize - 1) / worldSize * worldSize;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType constexpr getNbLayers(SizeType pipelineParallelism = 1) const
|
||||
[[nodiscard]] SizeType constexpr getNbAttentionLayers(SizeType pipelineParallelism = 1) const
|
||||
{
|
||||
TLLM_CHECK(mNbLayers % pipelineParallelism == 0);
|
||||
return mNbLayers / pipelineParallelism;
|
||||
TLLM_CHECK(mNbAttentionLayers % pipelineParallelism == 0);
|
||||
return mNbAttentionLayers / pipelineParallelism;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType constexpr getNbSsmLayers(SizeType pipelineParallelism = 1) const
|
||||
{
|
||||
TLLM_CHECK(mNbSsmLayers % pipelineParallelism == 0);
|
||||
return mNbSsmLayers / pipelineParallelism;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType constexpr getNbHeads() const noexcept
|
||||
@ -344,6 +352,16 @@ public:
|
||||
return mPagedContextFMHA;
|
||||
}
|
||||
|
||||
void constexpr useXQA(bool useXQA) noexcept
|
||||
{
|
||||
mUseXQA = useXQA;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr useXQA() const noexcept
|
||||
{
|
||||
return mUseXQA;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr useLoraPlugin() const noexcept
|
||||
{
|
||||
return mUseLoraPlugin;
|
||||
@ -354,7 +372,7 @@ public:
|
||||
mUseLoraPlugin = useLoraPlugin;
|
||||
}
|
||||
|
||||
std::vector<LoraModule> const& getLoraModules() const noexcept
|
||||
[[nodiscard]] std::vector<LoraModule> const& getLoraModules() const noexcept
|
||||
{
|
||||
return mLoraModules;
|
||||
}
|
||||
@ -442,7 +460,8 @@ public:
|
||||
|
||||
private:
|
||||
SizeType mVocabSize;
|
||||
SizeType mNbLayers;
|
||||
SizeType mNbAttentionLayers;
|
||||
SizeType mNbSsmLayers;
|
||||
SizeType mNbHeads;
|
||||
SizeType mNbKvHeads;
|
||||
SizeType mHiddenSize;
|
||||
@ -471,6 +490,7 @@ private:
|
||||
|
||||
bool mUseContextFMHAForGeneration;
|
||||
bool mPagedContextFMHA;
|
||||
bool mUseXQA;
|
||||
|
||||
bool mUseLoraPlugin;
|
||||
std::vector<LoraModule> mLoraModules;
|
||||
@ -17,6 +17,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/layers/defaultDecodingParams.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
|
||||
#include <functional>
|
||||
@ -36,25 +37,21 @@ private:
|
||||
|
||||
template <typename T>
|
||||
static OptVec<T> fuseValues(
|
||||
std::vector<SamplingConfig> const& configs, std::function<OptVec<T>(SizeType ci)> accessor)
|
||||
std::vector<SamplingConfig> const& configs, std::function<OptVec<T>(size_t ci)> accessor, T defaultValue)
|
||||
{
|
||||
std::vector<T> values;
|
||||
auto const hasValues = accessor(0).has_value();
|
||||
for (size_t ci = 0; ci < configs.size(); ++ci)
|
||||
{
|
||||
auto value = defaultValue;
|
||||
auto const& configValue = accessor(ci);
|
||||
TLLM_CHECK(hasValues == configValue.has_value());
|
||||
if (hasValues)
|
||||
if (configValue.has_value())
|
||||
{
|
||||
TLLM_CHECK(configValue.value().size() == 1);
|
||||
values.push_back(configValue.value().front());
|
||||
value = configValue.value().front();
|
||||
}
|
||||
values.push_back(value);
|
||||
}
|
||||
|
||||
if (!hasValues)
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
return std::make_optional<std::vector<T>>(values);
|
||||
}
|
||||
|
||||
@ -72,26 +69,52 @@ public:
|
||||
TLLM_CHECK(configs.size() > 0);
|
||||
beamWidth = configs.front().beamWidth;
|
||||
normalizeLogProbs = configs.front().normalizeLogProbs;
|
||||
temperature = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].temperature; });
|
||||
minLength = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].minLength; });
|
||||
repetitionPenalty
|
||||
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].repetitionPenalty; });
|
||||
presencePenalty
|
||||
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].presencePenalty; });
|
||||
topK = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].topK; });
|
||||
topP = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].topP; });
|
||||
randomSeed = fuseValues<uint64_t>(configs, [&configs](SizeType ci) { return configs[ci].randomSeed; });
|
||||
topPDecay = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].topPDecay; });
|
||||
topPMin = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].topPMin; });
|
||||
topPResetIds = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].topPResetIds; });
|
||||
beamSearchDiversityRate
|
||||
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].beamSearchDiversityRate; });
|
||||
lengthPenalty = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].lengthPenalty; });
|
||||
earlyStopping = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].earlyStopping; });
|
||||
draftAcceptanceThreshold
|
||||
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].draftAcceptanceThreshold; });
|
||||
topKMedusaHeads = fuseValues<std::vector<runtime::SizeType>>(
|
||||
configs, [&configs](SizeType ci) { return configs[ci].topKMedusaHeads; });
|
||||
temperature = fuseValues<FloatType>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].temperature; },
|
||||
layers::DefaultDecodingParams::getTemperature());
|
||||
minLength = fuseValues<SizeType32>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].minLength; },
|
||||
layers::DefaultDecodingParams::getMinLength());
|
||||
repetitionPenalty = fuseValues<FloatType>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].repetitionPenalty; },
|
||||
layers::DefaultDecodingParams::getRepetitionPenalty());
|
||||
presencePenalty = fuseValues<FloatType>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].presencePenalty; },
|
||||
layers::DefaultDecodingParams::getPresencePenalty());
|
||||
frequencyPenalty = fuseValues<FloatType>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].frequencyPenalty; },
|
||||
layers::DefaultDecodingParams::getFrequencyPenalty());
|
||||
topK = fuseValues<SizeType32>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].topK; }, layers::DefaultDecodingParams::getTopK());
|
||||
topP = fuseValues<FloatType>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].topP; }, layers::DefaultDecodingParams::getTopP());
|
||||
randomSeed = fuseValues<uint64_t>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].randomSeed; },
|
||||
layers::DefaultDecodingParams::getSeed());
|
||||
topPDecay = fuseValues<FloatType>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].topPDecay; },
|
||||
layers::DefaultDecodingParams::getTopPDecay());
|
||||
topPMin = fuseValues<FloatType>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].topPMin; },
|
||||
layers::DefaultDecodingParams::getTopPMin());
|
||||
topPResetIds = fuseValues<TokenIdType>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].topPResetIds; },
|
||||
layers::DefaultDecodingParams::getTopPResetId());
|
||||
beamSearchDiversityRate = fuseValues<FloatType>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].beamSearchDiversityRate; },
|
||||
layers::DefaultDecodingParams::getBeamSearchDiversity());
|
||||
lengthPenalty = fuseValues<FloatType>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].lengthPenalty; },
|
||||
layers::DefaultDecodingParams::getLengthPenalty());
|
||||
earlyStopping = fuseValues<SizeType32>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].earlyStopping; },
|
||||
layers::DefaultDecodingParams::getEarlyStopping());
|
||||
topKMedusaHeads = fuseValues<std::vector<SizeType32>>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].topKMedusaHeads; },
|
||||
layers::DefaultDecodingParams::getTopKMedusaHeads());
|
||||
// Only used for tests.
|
||||
draftAcceptanceThreshold = fuseValues<FloatType>(
|
||||
configs, [&configs](size_t ci) { return configs[ci].draftAcceptanceThreshold; }, 0);
|
||||
}
|
||||
|
||||
explicit SamplingConfig(executor::SamplingConfig const& samplingConfig,
|
||||
@ -148,13 +171,13 @@ public:
|
||||
// beam search layer
|
||||
OptVec<FloatType> beamSearchDiversityRate; // [1] or [batch_size]
|
||||
OptVec<FloatType> lengthPenalty; // [1] or [batch_size]
|
||||
OptVec<SizeType> earlyStopping; // [1] or [batch_size]
|
||||
OptVec<SizeType32> earlyStopping; // [1] or [batch_size]
|
||||
|
||||
// speculative decoding, only the first value is used (in gptDecoderBatch.cpp)
|
||||
OptVec<FloatType> draftAcceptanceThreshold; // [1] or [batch_size]
|
||||
|
||||
// medusa params
|
||||
OptVec<std::vector<runtime::SizeType>> topKMedusaHeads; // [batchSize, maxMedusaHeads]
|
||||
OptVec<std::vector<runtime::SizeType32>> topKMedusaHeads; // [batchSize, maxMedusaHeads]
|
||||
|
||||
std::optional<bool> normalizeLogProbs;
|
||||
|
||||
|
||||
@ -30,6 +30,7 @@ add_subdirectory(common)
|
||||
add_subdirectory(kernels)
|
||||
add_subdirectory(layers)
|
||||
add_subdirectory(runtime)
|
||||
add_subdirectory(executor_worker)
|
||||
|
||||
set(BATCH_MANAGER_TARGET tensorrt_llm_batch_manager_static)
|
||||
set(BATCH_MANAGER_TARGET_ARCH "unknown")
|
||||
@ -196,8 +197,9 @@ set(TRTLLM_LINK_LIBS
|
||||
kernels_src
|
||||
context_attention_src
|
||||
decoder_attention_src
|
||||
cutlass2_src
|
||||
cutlass3_src
|
||||
fpA_intB_gemm_src
|
||||
moe_gemm_src
|
||||
cutlass_src
|
||||
layers_src
|
||||
runtime_src)
|
||||
|
||||
@ -218,44 +220,31 @@ set_target_properties(
|
||||
PROPERTIES CXX_STANDARD "17" CXX_STANDARD_REQUIRED "YES" CXX_EXTENSIONS "NO"
|
||||
LINK_FLAGS "${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}")
|
||||
|
||||
function(link_whole_archive TARGET LIBRARY_TO_LINK)
|
||||
if(WIN32)
|
||||
target_link_libraries(${TARGET} PUBLIC $<TARGET_FILE:${LIBRARY_TO_LINK}>)
|
||||
set_target_properties(
|
||||
${TARGET} PROPERTIES LINK_FLAGS "/WHOLEARCHIVE:${LIBRARY_TO_LINK}")
|
||||
else()
|
||||
# Assume everything else is like gcc
|
||||
target_link_libraries(
|
||||
${TARGET} PRIVATE "-Wl,--whole-archive" $<TARGET_FILE:${LIBRARY_TO_LINK}>
|
||||
"-Wl,--no-whole-archive")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
target_link_libraries(${SHARED_TARGET} PUBLIC ${TRTLLM_LINK_LIBS})
|
||||
|
||||
if(WIN32)
|
||||
target_link_libraries(${SHARED_TARGET}
|
||||
PUBLIC $<TARGET_FILE:${BATCH_MANAGER_TARGET}>)
|
||||
set_target_properties(
|
||||
${SHARED_TARGET} PROPERTIES LINK_FLAGS
|
||||
"/WHOLEARCHIVE:${BATCH_MANAGER_TARGET}")
|
||||
else()
|
||||
# Assume everything else is like gcc
|
||||
target_link_libraries(
|
||||
${SHARED_TARGET}
|
||||
PRIVATE "-Wl,--whole-archive" $<TARGET_FILE:${BATCH_MANAGER_TARGET}>
|
||||
"-Wl,--no-whole-archive")
|
||||
endif()
|
||||
|
||||
if(WIN32)
|
||||
target_link_libraries(${SHARED_TARGET}
|
||||
PUBLIC $<TARGET_FILE:${EXECUTOR_TARGET}>)
|
||||
set_target_properties(
|
||||
${SHARED_TARGET} PROPERTIES LINK_FLAGS "/WHOLEARCHIVE:${EXECUTOR_TARGET}")
|
||||
else()
|
||||
# Assume everything else is like gcc
|
||||
target_link_libraries(
|
||||
${SHARED_TARGET}
|
||||
PRIVATE "-Wl,--whole-archive" $<TARGET_FILE:${EXECUTOR_TARGET}>
|
||||
"-Wl,--no-whole-archive")
|
||||
endif()
|
||||
link_whole_archive(${SHARED_TARGET} ${BATCH_MANAGER_TARGET})
|
||||
link_whole_archive(${SHARED_TARGET} ${EXECUTOR_TARGET})
|
||||
# Cyclic dependency of batch manager on TRT-LLM
|
||||
target_link_libraries(${BATCH_MANAGER_TARGET} INTERFACE ${SHARED_TARGET})
|
||||
# Cyclic dependency of executor on TRT-LLM
|
||||
target_link_libraries(${EXECUTOR_TARGET} INTERFACE ${SHARED_TARGET})
|
||||
|
||||
add_dependencies(${SHARED_TARGET} check_symbol)
|
||||
add_dependencies(${SHARED_TARGET} check_symbol_executor)
|
||||
|
||||
# Cyclic dependency of batch manager on TRT-LLM
|
||||
target_link_libraries(${BATCH_MANAGER_TARGET} INTERFACE ${SHARED_TARGET})
|
||||
|
||||
# Cyclic dependency of executor on TRT-LLM
|
||||
target_link_libraries(${EXECUTOR_TARGET} INTERFACE ${SHARED_TARGET})
|
||||
|
||||
if(BUILD_PYT)
|
||||
add_subdirectory(thop)
|
||||
endif()
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6bd5ec7130a703889eb51fe6591c93a079ded644ca089099efe5e3d72474838e
|
||||
size 2896708
|
||||
oid sha256:d8a083974ff58e74dec95d1ad438bf84be9adeedeb20b5e7254fe56d6a4bf40c
|
||||
size 2997970
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d25d35be9ec13d1f0a0b9f3ed40362879d9ac50bdfcdcb827990554a26ff5c10
|
||||
size 2923694
|
||||
oid sha256:40cace20ce33a945ed12a2a2e382053aa90113d8bed2623c985dbb60b943251e
|
||||
size 3034874
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
cafe56cc4a916b91ea338a8412c79fef libtensorrt_llm_batch_manager_static.a
|
||||
3274866669694da8f09e30388939b7dd libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
165fe125d6bf55090d8a7dec012d08f8d0e7a54b commit
|
||||
7c5e14e8ed4e3e0641a8aefa659a03c0 libtensorrt_llm_batch_manager_static.a
|
||||
79a986633cb1f0dc6621423bbbf21727 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
83029c1606a00e0e4aaf5ea2de17867a6e5ddd9b commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:27dbbdae087a946d1762f11efe953a1b1b282e27747708145c405e9380fce287
|
||||
size 2822910
|
||||
oid sha256:913f548b9f66aaea93baaa40bd7ca37f4fb0b52f5ed0778b1fe52c136141433c
|
||||
size 2916334
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:622724d6b9219dd3d4710a822ca92d497c466cdc34149258f9559c08f4470f8e
|
||||
size 2796594
|
||||
oid sha256:8dd40bb9cafae379971b365c8206fd20addb7816c64953456568110e5f694b0e
|
||||
size 2900610
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:296c78f2c29774fab2145465a9a515a7e4aaedde96ba3c3f6fa5af91fa92dee6
|
||||
size 18976374
|
||||
oid sha256:889f62ee370c0a00c1ccfc26e82fcd1410413e44e6d955aca12a90c906e89239
|
||||
size 18428048
|
||||
|
||||
@ -62,6 +62,7 @@ CUDADriverWrapper::CUDADriverWrapper()
|
||||
*(void**) (&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
|
||||
*(void**) (&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
|
||||
*(void**) (&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
|
||||
*(void**) (&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
|
||||
}
|
||||
|
||||
CUDADriverWrapper::~CUDADriverWrapper()
|
||||
@ -143,5 +144,14 @@ CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX,
|
||||
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
|
||||
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
|
||||
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
|
||||
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const
|
||||
{
|
||||
return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides,
|
||||
boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill);
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -70,6 +70,11 @@ public:
|
||||
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
|
||||
CUstream hStream, void** kernelParams, void** extra) const;
|
||||
|
||||
CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank,
|
||||
void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim,
|
||||
cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
|
||||
CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const;
|
||||
|
||||
private:
|
||||
void* handle;
|
||||
CUresult (*_cuGetErrorName)(CUresult, char const**);
|
||||
@ -89,6 +94,10 @@ private:
|
||||
CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
|
||||
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
|
||||
CUstream hStream, void** kernelParams, void** extra);
|
||||
CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
|
||||
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
|
||||
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
|
||||
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
|
||||
};
|
||||
|
||||
inline void cuErrCheck_(CUresult stat, CUDADriverWrapper const& wrap, char const* file, int line)
|
||||
|
||||
@ -22,21 +22,39 @@
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
static std::optional<int32_t> getIntEnv(char const* name)
|
||||
{
|
||||
char const* const env = std::getenv(name);
|
||||
if (env == nullptr)
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
int32_t const val = std::stoi(env);
|
||||
if (val <= 0)
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
return {val};
|
||||
};
|
||||
|
||||
// XQA kernels (optimized kernels for generation phase).
|
||||
bool forceXQAKernels()
|
||||
{
|
||||
char const* force_xqa_env_var = getenv("TRTLLM_FORCE_XQA");
|
||||
static bool forceXQA = false;
|
||||
if (force_xqa_env_var != nullptr)
|
||||
{
|
||||
if (force_xqa_env_var[0] == '1' && force_xqa_env_var[1] == '\0')
|
||||
{
|
||||
forceXQA = true;
|
||||
}
|
||||
}
|
||||
static bool const forceXQA = (getIntEnv("TRTLLM_FORCE_XQA").value_or(0) != 0);
|
||||
return forceXQA;
|
||||
}
|
||||
|
||||
int32_t xqaMaxNbCtaPerKVHeadFactor()
|
||||
{
|
||||
return envXqaNbCtaPerKVHead().value_or(8);
|
||||
}
|
||||
|
||||
std::optional<int32_t> envXqaNbCtaPerKVHead()
|
||||
{
|
||||
static std::optional<int32_t> const ret = getIntEnv("TRTLLM_XQA_BLOCKS_PER_SEQUENCE");
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Tune the number of blocks per sequence for accuracy/performance purpose.
|
||||
bool getEnvMmhaMultiblockDebug()
|
||||
{
|
||||
|
||||
@ -16,6 +16,8 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
@ -23,6 +25,14 @@ namespace tensorrt_llm::common
|
||||
// XQA kernels (optimized kernels for generation phase).
|
||||
bool forceXQAKernels();
|
||||
|
||||
// max number of CTAs for each KV head, multiple CTAs for one KV head is multi-block mode.
|
||||
// this number defines the maximum number when reaches both max_batch_size and max_beam_width.
|
||||
// If batch_size or beam_width doesn't reach maximum value, it is possible to have more CTAs per KV head than this
|
||||
// value.
|
||||
int32_t xqaMaxNbCtaPerKVHeadFactor();
|
||||
|
||||
std::optional<int32_t> envXqaNbCtaPerKVHead();
|
||||
|
||||
// Tune the number of blocks per sequence for accuracy/performance purpose.
|
||||
bool getEnvMmhaMultiblockDebug();
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
|
||||
#include <csignal>
|
||||
#include <mpi.h>
|
||||
@ -35,7 +36,7 @@ namespace tensorrt_llm::mpi
|
||||
|
||||
MPI_Datatype getMpiDtype(MpiType dtype)
|
||||
{
|
||||
static const std::unordered_map<MpiType, MPI_Datatype> dtype_map{
|
||||
static std::unordered_map<MpiType, MPI_Datatype> const dtype_map{
|
||||
|
||||
{MpiType::kBYTE, MPI_BYTE},
|
||||
{MpiType::kHALF, MPI_UINT16_T},
|
||||
@ -57,7 +58,7 @@ MPI_Datatype getMpiDtype(MpiType dtype)
|
||||
|
||||
MPI_Op getMpiOp(MpiOp op)
|
||||
{
|
||||
static const std::unordered_map<MpiOp, MPI_Op> op_map{
|
||||
static std::unordered_map<MpiOp, MPI_Op> const op_map{
|
||||
{MpiOp::NULLOP, MPI_OP_NULL},
|
||||
{MpiOp::MAX, MPI_MAX},
|
||||
{MpiOp::MIN, MPI_MIN},
|
||||
@ -122,16 +123,33 @@ std::shared_ptr<MpiRequest> MpiComm::bcastAsync(void* buffer, size_t size, MpiTy
|
||||
return r;
|
||||
}
|
||||
|
||||
std::shared_ptr<MpiRequest> MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const
|
||||
{
|
||||
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
|
||||
return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
|
||||
}
|
||||
|
||||
void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
|
||||
{
|
||||
MPICHECK(MPI_Bcast(buffer, size, getMpiDtype(dtype), root, mComm));
|
||||
}
|
||||
|
||||
void MpiComm::bcast(runtime::IBuffer& buf, int root) const
|
||||
{
|
||||
bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
|
||||
}
|
||||
|
||||
void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
|
||||
{
|
||||
MPICHECK(MPI_Send(buffer, size, getMpiDtype(dtype), dest, tag, mComm));
|
||||
}
|
||||
|
||||
void MpiComm::send(runtime::IBuffer const& buf, int dest, int tag) const
|
||||
{
|
||||
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
|
||||
send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
|
||||
}
|
||||
|
||||
MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const
|
||||
{
|
||||
MPI_Status status{};
|
||||
@ -139,6 +157,12 @@ MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, i
|
||||
return status;
|
||||
}
|
||||
|
||||
MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, int tag) const
|
||||
{
|
||||
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
|
||||
return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag);
|
||||
}
|
||||
|
||||
MpiComm MpiComm::split(int color, int key) const
|
||||
{
|
||||
MPI_Comm splitComm;
|
||||
|
||||
@ -63,7 +63,7 @@ int8_t* nextWorkspacePtrWithAlignment(
|
||||
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment);
|
||||
}
|
||||
|
||||
size_t calculateTotalWorkspaceSize(size_t* workspaces, int count, const uintptr_t alignment = kCudaMemAlign)
|
||||
size_t calculateTotalWorkspaceSize(size_t const* workspaces, int count, const uintptr_t alignment = kCudaMemAlign)
|
||||
{
|
||||
size_t total = 0;
|
||||
for (int i = 0; i < count; i++)
|
||||
|
||||
@ -30,6 +30,7 @@
|
||||
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_silu.h"
|
||||
#include "cutlass_extensions/epilogue/thread/fused_activations.h"
|
||||
#include <cutlass/epilogue/fusion/operations.hpp>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
@ -48,6 +49,10 @@ struct EpilogueOpBiasFtGelu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpBias
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefaultSilu
|
||||
{
|
||||
};
|
||||
@ -60,10 +65,6 @@ struct EpilogueOpDefaultFtGelu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpBias
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefault
|
||||
{
|
||||
};
|
||||
@ -71,6 +72,7 @@ struct EpilogueOpDefault
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator, typename Op>
|
||||
struct Epilogue
|
||||
{
|
||||
static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag");
|
||||
};
|
||||
|
||||
constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling;
|
||||
|
||||
@ -36,10 +36,11 @@ namespace kernel
|
||||
template <typename TypeA, typename TypeB, typename arch, typename Enable = void>
|
||||
struct MixedGemmArchTraits
|
||||
{
|
||||
static_assert(dependent_false<arch>, "Unrecognised parameterization");
|
||||
};
|
||||
|
||||
template <typename arch>
|
||||
struct MixedGemmArchTraits<float, float, arch>
|
||||
template <typename Arch>
|
||||
struct MixedGemmArchTraits<float, float, Arch>
|
||||
{
|
||||
static constexpr int Stages = 2;
|
||||
using OperatorClass = cutlass::arch::OpClassSimt;
|
||||
@ -66,7 +67,7 @@ struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm70,
|
||||
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
|
||||
{
|
||||
private:
|
||||
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm70>;
|
||||
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm70>;
|
||||
|
||||
public:
|
||||
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
|
||||
@ -92,7 +93,7 @@ struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm75,
|
||||
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
|
||||
{
|
||||
private:
|
||||
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm75>;
|
||||
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm75>;
|
||||
|
||||
public:
|
||||
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
|
||||
@ -116,7 +117,7 @@ struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm80,
|
||||
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
|
||||
{
|
||||
private:
|
||||
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm80>;
|
||||
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm80>;
|
||||
|
||||
public:
|
||||
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
|
||||
@ -133,6 +134,34 @@ public:
|
||||
using Operator = typename LayoutDetails::Operator;
|
||||
};
|
||||
|
||||
// ======================= Ada Traits ==============================
|
||||
template <typename TypeA, typename TypeB>
|
||||
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
|
||||
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|
||||
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value
|
||||
#ifdef ENABLE_FP8
|
||||
|| cutlass::platform::is_same<TypeA, cutlass::float_e4m3_t>::value>::type
|
||||
#endif
|
||||
>
|
||||
{
|
||||
private:
|
||||
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;
|
||||
|
||||
public:
|
||||
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
|
||||
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using AccType = float;
|
||||
using LayoutB = typename LayoutDetails::Layout;
|
||||
|
||||
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
|
||||
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;
|
||||
|
||||
using Operator = typename LayoutDetails::Operator;
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
@ -546,8 +546,10 @@ struct GemmFpAIntB
|
||||
run_kernel<arch::Sm70>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
|
||||
run_kernel<arch::Sm75>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
|
||||
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ == 890)
|
||||
run_kernel<arch::Sm89>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 900)
|
||||
CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
|
||||
#else
|
||||
|
||||
@ -42,16 +42,16 @@ namespace gemm
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
template <typename TypeB, typename Arch, typename Enable = void>
|
||||
template <typename TypeA, typename TypeB, typename Arch, typename Enable = void>
|
||||
struct LayoutDetailsB
|
||||
{
|
||||
};
|
||||
|
||||
// Volta specialiations. Volta will dequantize before STS, so we need a different operator
|
||||
template <typename TypeB>
|
||||
struct LayoutDetailsB<TypeB, arch::Sm70>
|
||||
template <typename TypeA, typename TypeB>
|
||||
struct LayoutDetailsB<TypeA, TypeB, arch::Sm70>
|
||||
{
|
||||
static constexpr int ThreadblockK = 64;
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 8;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
@ -59,19 +59,19 @@ struct LayoutDetailsB<TypeB, arch::Sm70>
|
||||
|
||||
// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
|
||||
// TODO - Switch this to column major for weights since gemms should be more performant.
|
||||
template <typename Arch>
|
||||
struct LayoutDetailsB<half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 64;
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
template <typename Arch>
|
||||
struct LayoutDetailsB<bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 64;
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
@ -79,11 +79,12 @@ struct LayoutDetailsB<bfloat16_t, Arch, typename platform::enable_if<Arch::kMinC
|
||||
|
||||
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
|
||||
// which signals that we want to dequantize after loading from smem.
|
||||
template <typename Arch>
|
||||
struct LayoutDetailsB < uint8_t,
|
||||
Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB < TypeA,
|
||||
uint8_t, Arch,
|
||||
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 64;
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
|
||||
private:
|
||||
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
|
||||
@ -95,11 +96,12 @@ public:
|
||||
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
|
||||
};
|
||||
|
||||
template <typename Arch>
|
||||
struct LayoutDetailsB < uint4b_t,
|
||||
Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB < TypeA,
|
||||
uint4b_t, Arch,
|
||||
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 64;
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
|
||||
private:
|
||||
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
|
||||
@ -111,19 +113,19 @@ public:
|
||||
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
|
||||
};
|
||||
|
||||
template <typename Arch>
|
||||
struct LayoutDetailsB<uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 64;
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
template <typename Arch>
|
||||
struct LayoutDetailsB<uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 64;
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
|
||||
@ -35,6 +35,8 @@
|
||||
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
|
||||
#include "cutlass_extensions/tile_interleaved_layout.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
@ -502,8 +504,7 @@ public:
|
||||
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 900)
|
||||
run_kernel<arch::Sm80>(
|
||||
params, shared_storage); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#else
|
||||
static_assert(
|
||||
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
|
||||
|
||||
@ -38,11 +38,12 @@ namespace threadblock
|
||||
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
|
||||
typename Enable = void>
|
||||
struct DefaultScaleIterators;
|
||||
struct DefaultScaleIteratorsMultistage;
|
||||
|
||||
// Fine grained iterators
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
|
||||
struct DefaultScaleIterators<MmaShape, Element, Layout, QuantOp, Alignment, std::enable_if_t<isFinegrained(QuantOp)>>
|
||||
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
|
||||
std::enable_if_t<isFinegrained(QuantOp)>>
|
||||
{
|
||||
using IteratorScale
|
||||
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
|
||||
@ -53,7 +54,8 @@ struct DefaultScaleIterators<MmaShape, Element, Layout, QuantOp, Alignment, std:
|
||||
|
||||
// Per column iterators
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
|
||||
struct DefaultScaleIterators<MmaShape, Element, Layout, QuantOp, Alignment, std::enable_if_t<!isFinegrained(QuantOp)>>
|
||||
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
|
||||
std::enable_if_t<!isFinegrained(QuantOp)>>
|
||||
{
|
||||
// ThreadMap for scale iterator
|
||||
static_assert((MmaShape::kN % Alignment) == 0, "");
|
||||
@ -73,7 +75,7 @@ public:
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Type for elementA
|
||||
/// Type for element A
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
@ -105,9 +107,9 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Stages in GEMM
|
||||
int kStages,
|
||||
///
|
||||
/// Operator performed by GEMM
|
||||
typename Operator_,
|
||||
///
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
|
||||
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
|
||||
@ -116,8 +118,9 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
|
||||
ArchTag::kMinComputeCapability >= 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
|
||||
{
|
||||
|
||||
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
|
||||
"Element A must be fp16 or bf16");
|
||||
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
|
||||
|| platform::is_same<ElementA, float_e4m3_t>::value,
|
||||
"Element A must be fp16, fp8 or bf16");
|
||||
|
||||
using OperatorInfo = arch::DetagOperator<Operator_>;
|
||||
using Operator = typename OperatorInfo::Operator;
|
||||
@ -155,7 +158,7 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
using ScaleIterators = DefaultScaleIterators<typename MmaCore::Shape, ElementScale, LayoutScale,
|
||||
using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
|
||||
OperatorInfo::QuantOp, kAlignmentScale>;
|
||||
|
||||
// Define iterators over tiles from the scale operand
|
||||
@ -163,7 +166,7 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
|
||||
|
||||
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
|
||||
|
||||
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementA, ElementB,
|
||||
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
|
||||
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
@ -173,6 +176,7 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
|
||||
typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
|
||||
};
|
||||
|
||||
// Specialization to handle column major interleave B
|
||||
template <
|
||||
/// Type for element A
|
||||
typename ElementA,
|
||||
@ -206,9 +210,9 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Stages in GEMM
|
||||
int kStages,
|
||||
///
|
||||
/// Operator performed by GEMM
|
||||
typename Operator_,
|
||||
///
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
|
||||
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
|
||||
@ -217,8 +221,9 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
|
||||
ArchTag::kMinComputeCapability >= 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
|
||||
{
|
||||
|
||||
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
|
||||
"Element A must be fp16 or bf16");
|
||||
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
|
||||
|| platform::is_same<ElementA, float_e4m3_t>::value,
|
||||
"Element A must be fp16, fp8 or bf16");
|
||||
|
||||
using OperatorInfo = arch::DetagOperator<Operator_>;
|
||||
using Operator = typename OperatorInfo::Operator;
|
||||
@ -274,7 +279,7 @@ public:
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<GmemIteratorShape, ElementB,
|
||||
layout::ColumnMajor, 0, GmemThreadMapB, AccessTypeB>;
|
||||
|
||||
using ScaleIterators = DefaultScaleIterators<typename MmaCore::Shape, ElementScale, LayoutScale,
|
||||
using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
|
||||
OperatorInfo::QuantOp, kAlignmentScale>;
|
||||
|
||||
// Define iterators over tiles from the scale operand
|
||||
@ -282,7 +287,7 @@ public:
|
||||
|
||||
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
|
||||
|
||||
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementA, ElementB,
|
||||
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
|
||||
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
|
||||
@ -35,6 +35,42 @@ namespace threadblock
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
|
||||
typename Enable = void>
|
||||
struct DefaultScaleIteratorsPipelined;
|
||||
|
||||
// TODO: Fine grained iterators
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
|
||||
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
|
||||
std::enable_if_t<isFinegrained(QuantOp)>>
|
||||
{
|
||||
};
|
||||
|
||||
// Per column iterators
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
|
||||
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
|
||||
std::enable_if_t<!isFinegrained(QuantOp)>>
|
||||
{
|
||||
static_assert((MmaShape::kN % Alignment) == 0, "");
|
||||
|
||||
private:
|
||||
// ThreadMap for scale iterator
|
||||
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
|
||||
MmaShape::kN / Alignment, Alignment>;
|
||||
using SmemScaleType = half_t;
|
||||
|
||||
public:
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
|
||||
Element, Layout, 0, IteratorScaleThreadMap, Alignment>;
|
||||
|
||||
using SmemIteratorScale
|
||||
= cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>, SmemScaleType,
|
||||
Layout, 0, IteratorScaleThreadMap, Alignment>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Type for element A
|
||||
typename ElementA,
|
||||
@ -86,8 +122,7 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
|
||||
static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
|
||||
|
||||
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
|
||||
static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
|
||||
using MmaCoreElementA = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
|
||||
using MmaCoreElementA = half_t;
|
||||
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
|
||||
|
||||
// Define the MmaCore components
|
||||
@ -105,21 +140,13 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
|
||||
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, ElementB, LayoutB, 0,
|
||||
typename MmaCore::IteratorThreadMapB, kAlignmentB>;
|
||||
|
||||
// ThreadMap for scale iterator
|
||||
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
|
||||
using IteratorScaleThreadMap
|
||||
= transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
|
||||
MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>;
|
||||
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
|
||||
OperatorInfo::QuantOp, kAlignmentScale>;
|
||||
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale
|
||||
= cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
||||
ElementScale, LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>;
|
||||
using IteratorScale = typename ScaleIterators::IteratorScale;
|
||||
|
||||
using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
|
||||
using SmemIteratorScale
|
||||
= cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
||||
SmemScaleType, LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>;
|
||||
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
|
||||
|
||||
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
|
||||
|
||||
@ -182,8 +209,7 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
|
||||
static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
|
||||
|
||||
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
|
||||
static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
|
||||
using MmaCoreElementA = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
|
||||
using MmaCoreElementA = half_t;
|
||||
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
|
||||
|
||||
// Define the MmaCore components
|
||||
@ -225,15 +251,13 @@ public:
|
||||
= transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
|
||||
MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>;
|
||||
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale
|
||||
= cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
||||
ElementScale, LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>;
|
||||
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
|
||||
OperatorInfo::QuantOp, kAlignmentScale>;
|
||||
|
||||
using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
|
||||
using SmemIteratorScale
|
||||
= cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
||||
SmemScaleType, LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>;
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale = typename ScaleIterators::IteratorScale;
|
||||
|
||||
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
|
||||
|
||||
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
|
||||
|
||||
|
||||
@ -29,7 +29,7 @@ namespace threadblock
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
@ -77,7 +77,7 @@ public:
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
@ -124,6 +124,9 @@ public:
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
|
||||
/// (stage>=3)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
@ -176,7 +179,8 @@ public:
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage
|
||||
/// (stage>=3)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
@ -228,6 +232,63 @@ public:
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
|
||||
/// (stage>=3)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
||||
|
||||
using Mma = DqMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t,
|
||||
layout::RowMajor, kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
|
||||
ThreadblockShape, WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
|
||||
// large tile when not enough shared mem is present to do 3+ stage
|
||||
template <
|
||||
|
||||
@ -86,7 +86,7 @@ template <
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
/// Layout of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
@ -189,8 +189,9 @@ private:
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
Dequantizer warp_dequantizer_;
|
||||
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
|
||||
static constexpr bool RequiresTileInterleave
|
||||
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
||||
@ -218,7 +219,7 @@ public:
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
/// The group size for quantization
|
||||
int group_size,
|
||||
int const group_size,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
@ -279,10 +280,21 @@ public:
|
||||
}
|
||||
else if (iterator_scale.group_size_ == 128)
|
||||
{
|
||||
if (iterator_scale.row_groupsize64_ & 0x1)
|
||||
if constexpr (Shape::kK == 128)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
else if constexpr (Shape::kK == 64)
|
||||
{
|
||||
if (iterator_scale.row_groupsize64_ & 0x1)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
|
||||
}
|
||||
}
|
||||
|
||||
iterator_scale.row_groupsize64_++;
|
||||
@ -291,8 +303,8 @@ public:
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, IteratorScale& iterator_scale,
|
||||
int group_start_A = 0, int group_start_B = 0)
|
||||
void copy_tiles_and_advance(
|
||||
IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
|
||||
{
|
||||
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
|
||||
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
||||
@ -414,8 +426,6 @@ public:
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
|
||||
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
@ -586,8 +596,17 @@ public:
|
||||
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros);
|
||||
|
||||
run_warp_mma(
|
||||
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
|
||||
using FragmentOperandB = cutlass::Array<ElementA, Operator::FragmentB::kElements>;
|
||||
constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
|
||||
static_assert(ConversionVectorWidth == FragmentOperandB::kElements);
|
||||
|
||||
using Converter
|
||||
= cutlass::NumericArrayConverter<ElementA, ElementScale, ConversionVectorWidth, RoundStyle>;
|
||||
|
||||
FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
|
||||
run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
|
||||
warp_tileB_k_compute_offset);
|
||||
|
||||
// Issue global->shared copies for the this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations - 1)
|
||||
@ -597,8 +616,7 @@ public:
|
||||
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
|
||||
|
||||
copy_tiles_and_advance(
|
||||
iterator_A, iterator_B, iterator_scale, group_start_iteration_A, group_start_iteration_B);
|
||||
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
|
||||
|
||||
// This is the first group of a given stage, so we issue the loads for the B scales immediately.
|
||||
if (group_start_iteration_B == 0)
|
||||
@ -613,8 +631,7 @@ public:
|
||||
group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
|
||||
|
||||
copy_tiles_and_advance(
|
||||
iterator_A, iterator_B, iterator_scale, group_start_iteration_A, group_start_iteration_B);
|
||||
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
@ -190,8 +190,9 @@ private:
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
Dequantizer warp_dequantizer_;
|
||||
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
|
||||
static constexpr bool RequiresTileInterleave
|
||||
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
||||
@ -482,7 +483,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
@ -548,8 +549,17 @@ public:
|
||||
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
|
||||
|
||||
run_warp_mma(
|
||||
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
|
||||
using FragmentOperandB = cutlass::Array<ElementA, Operator::FragmentB::kElements>;
|
||||
constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
|
||||
static_assert(ConversionVectorWidth == FragmentOperandB::kElements);
|
||||
|
||||
using Converter
|
||||
= cutlass::NumericArrayConverter<ElementA, ElementScale, ConversionVectorWidth, RoundStyle>;
|
||||
|
||||
FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
|
||||
run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
|
||||
warp_tileB_k_compute_offset);
|
||||
|
||||
// Issue global->shared copies for the this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations - 1)
|
||||
@ -573,7 +583,8 @@ public:
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 -
|
||||
// #committed)
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
|
||||
@ -161,8 +161,9 @@ private:
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
Dequantizer warp_dequantizer_;
|
||||
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
|
||||
static constexpr bool RequiresTileInterleave
|
||||
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
||||
|
||||
@ -82,7 +82,7 @@ private:
|
||||
using ComputeInstructionShape = InstructionShape_;
|
||||
|
||||
// Chosen so we get K=16 for int8 and K=32 for int4.
|
||||
static constexpr int LoadInstructionK = 8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value;
|
||||
static constexpr int LoadInstructionK = 128 / sizeof_bits<ElementB>::value;
|
||||
|
||||
// Shape for loading the narrow data type from shared memory
|
||||
using LoadInstructionShape = GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;
|
||||
|
||||
@ -46,6 +46,7 @@
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
#include "cutlass/arch/mma_sm80.h"
|
||||
#include "cutlass/arch/mma_sm89.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/warp/mma.h"
|
||||
@ -131,12 +132,16 @@ public:
|
||||
&& platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
|
||||
|| (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
|
||||
&& platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
|
||||
&& ArchTag::kMinComputeCapability >= 80),
|
||||
"MmaTensorOpCvtBToA only supports underlying HMMA");
|
||||
&& ArchTag::kMinComputeCapability >= 80)
|
||||
|| (platform::is_same<typename ArchMmaOperator::ElementA, float_e4m3_t>::value
|
||||
&& platform::is_same<typename ArchMmaOperator::ElementB, float_e4m3_t>::value
|
||||
&& ArchTag::kMinComputeCapability >= 89),
|
||||
"MmaTensorOpCvtBToA only supports underlying HMMA/QMMA");
|
||||
|
||||
static_assert(platform::is_same<ElementA, half_t>::value
|
||||
|| (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80),
|
||||
"MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+");
|
||||
|| (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80)
|
||||
|| (platform::is_same<ElementA, float_e4m3_t>::value && ArchTag::kMinComputeCapability >= 89),
|
||||
"MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada");
|
||||
|
||||
/// Indicates class of matrix operator
|
||||
using OperatorClass = arch::OpClassTensorOp;
|
||||
|
||||
@ -367,7 +367,8 @@ public:
|
||||
void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
|
||||
{
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
using ExpandedMmaOperandB
|
||||
= Array<typename FragmentDequantizedOperand::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
|
||||
== FragmentDequantizedOperand::kElements,
|
||||
"");
|
||||
@ -409,7 +410,8 @@ public:
|
||||
FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
|
||||
{
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
using ExpandedMmaOperandB
|
||||
= Array<typename FragmentDequantizedOperand::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
|
||||
== FragmentDequantizedOperand::kElements,
|
||||
"");
|
||||
|
||||
@ -111,6 +111,16 @@ enum class ClusterShape
|
||||
|
||||
struct CutlassGemmConfig
|
||||
{
|
||||
enum CandidateConfigTypeParam : int
|
||||
{
|
||||
NONE = 0,
|
||||
WEIGHT_ONLY = 1u << 0,
|
||||
SIMT_ONLY = 1u << 1,
|
||||
INT8_ONLY = 1u << 2,
|
||||
HOPPER = 1u << 3,
|
||||
GROUPED_GEMM = 1u << 4,
|
||||
};
|
||||
|
||||
CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
|
||||
SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
|
||||
int split_k_factor = -1;
|
||||
@ -121,6 +131,7 @@ struct CutlassGemmConfig
|
||||
MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO;
|
||||
EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO;
|
||||
ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
|
||||
bool is_sm90 = false;
|
||||
|
||||
CutlassGemmConfig() {}
|
||||
|
||||
@ -129,6 +140,7 @@ struct CutlassGemmConfig
|
||||
, split_k_style(split_k_style)
|
||||
, split_k_factor(split_k_factor)
|
||||
, stages(stages)
|
||||
, is_sm90(false)
|
||||
{
|
||||
}
|
||||
|
||||
@ -138,6 +150,7 @@ struct CutlassGemmConfig
|
||||
, mainloop_schedule(mainloop_schedule)
|
||||
, epilogue_schedule(epilogue_schedule)
|
||||
, cluster_shape(cluster_shape)
|
||||
, is_sm90(true)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:667eb7aaa018b36c8aee79be8c3e9432a29ba33a32c1e1e423e15809a57a40b0
|
||||
size 851008
|
||||
oid sha256:f1990679ad8fbfbcb2b063eb7cef689a5111776cd4bef0af7f792a8ce0d46277
|
||||
size 1202412
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:deb973f60a1a3623f5a014366638bddb71d76ebd7e78f9526c98d873ee4b1b5d
|
||||
size 863464
|
||||
oid sha256:4e6aaf013256a78494afa1f46de2109b6acb572ee1dc27b1a80cb915868c3162
|
||||
size 1218938
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
326f6910f53b9272872d7630a8ce2eea libtensorrt_llm_executor_static.a
|
||||
82f32c59f88aff8eee61aeb1b328ecf1 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
165fe125d6bf55090d8a7dec012d08f8d0e7a54b commit
|
||||
f91339ae7a9840c71f672d960bfd5446 libtensorrt_llm_executor_static.a
|
||||
66f3b01a5c61c22127d263fa97fec0ec libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
83029c1606a00e0e4aaf5ea2de17867a6e5ddd9b commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fdd06e455ac1527ab7c1da4748b88e1e1d98d04c0cbc94ffbf1bf5968b095ed6
|
||||
size 890380
|
||||
oid sha256:bd96ad0a662f7a989a8f2d04581a28504a56f3f642f787d13dd2edbcbb50890c
|
||||
size 1225668
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5f9841068c9ece43d9d7cb95af8dffb1b1d6384399b21cb0ce4f5330846daf7d
|
||||
size 843698
|
||||
oid sha256:e27cf9ea138ba3ab62833e02a6a594ef209316eac0c7031e07c5f1c6cd6fb341
|
||||
size 1180420
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b43a9234c39d36b580b1bdcae5676698a6f11be911ae1f8168db8e2dc89c4d04
|
||||
size 9865220
|
||||
oid sha256:feecaaadbdc3ab649e7b9fca4076b8b745ca4891244edfa62fe09beee4560d34
|
||||
size 12198304
|
||||
|
||||
26
cpp/tensorrt_llm/executor_worker/CMakeLists.txt
Normal file
26
cpp/tensorrt_llm/executor_worker/CMakeLists.txt
Normal file
@ -0,0 +1,26 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
|
||||
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain a copy of
|
||||
# the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations under
|
||||
# the License.
|
||||
set(SRCS executorWorker.cpp)
|
||||
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
|
||||
set(EXECUTOR_WORKER_TARGET executorWorker)
|
||||
|
||||
add_executable(${EXECUTOR_WORKER_TARGET} ${SRCS})
|
||||
|
||||
target_link_libraries(${EXECUTOR_WORKER_TARGET}
|
||||
PUBLIC ${SHARED_TARGET} nvinfer_plugin_tensorrt_llm)
|
||||
|
||||
target_compile_features(${EXECUTOR_WORKER_TARGET} PRIVATE cxx_std_17)
|
||||
80
cpp/tensorrt_llm/executor_worker/executorWorker.cpp
Normal file
80
cpp/tensorrt_llm/executor_worker/executorWorker.cpp
Normal file
@ -0,0 +1,80 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mpi.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/mpiUtils.h"
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/executor/serialization.h"
|
||||
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
|
||||
|
||||
namespace tle = tensorrt_llm::executor;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// Register the TRT-LLM plugins
|
||||
initTrtLlmPlugins();
|
||||
|
||||
tensorrt_llm::mpi::initialize(tensorrt_llm::mpi::MpiThreadSupport::THREAD_MULTIPLE);
|
||||
|
||||
MPI_Comm parentComm;
|
||||
MPI_Comm_get_parent(&parentComm);
|
||||
if (parentComm == MPI_COMM_NULL)
|
||||
{
|
||||
TLLM_LOG_ERROR("TRT-LLM worker has no parent!");
|
||||
return -1;
|
||||
}
|
||||
|
||||
int size;
|
||||
MPI_Comm_remote_size(parentComm, &size);
|
||||
if (size != 1)
|
||||
{
|
||||
TLLM_LOG_ERROR("Parent size is %d, must be 1", size);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Since parentComm is an intercommunicator, input root
|
||||
// is the rank of the parent process in his group
|
||||
// (always 0 as the parent size is checked before)
|
||||
|
||||
// Receive from the parent the executor configuration
|
||||
int64_t bufferSize;
|
||||
MPICHECK(MPI_Bcast(&bufferSize, 1, MPI_INT64_T, 0, parentComm));
|
||||
std::vector<char> buffer(bufferSize);
|
||||
MPICHECK(MPI_Bcast(buffer.data(), bufferSize, MPI_CHAR, 0, parentComm));
|
||||
std::istringstream is(std::string(buffer.begin(), buffer.end()));
|
||||
auto modelPath = tle::Serialization::deserializeString(is);
|
||||
auto modelType = tle::Serialization::deserializeModelType(is);
|
||||
auto executorConfig = tle::Serialization::deserializeExecutorConfig(is);
|
||||
|
||||
// Create the orchestrator config for workers
|
||||
auto orchLeaderComm = std::make_shared<tensorrt_llm::mpi::MpiComm>(parentComm, true);
|
||||
auto parallelConfig = executorConfig.getParallelConfig();
|
||||
TLLM_CHECK_WITH_INFO(parallelConfig.has_value(), "Parallel config should have a value.");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
parallelConfig.value().getOrchestratorConfig().has_value(), "Orchestrator config should have a value.");
|
||||
auto orchConfig = parallelConfig.value().getOrchestratorConfig().value();
|
||||
TLLM_CHECK_WITH_INFO(parallelConfig.has_value(), "Parallel config should have a value.");
|
||||
auto newOrchConfig = tle::OrchestratorConfig(false, orchConfig.getWorkerExecutablePath(), orchLeaderComm);
|
||||
parallelConfig.value().setOrchestratorConfig(newOrchConfig);
|
||||
executorConfig.setParallelConfig(parallelConfig.value());
|
||||
// In orchestrator mode, the spawned threads will wait for termination signal from orchestrator
|
||||
auto executor = tle::Executor(modelPath, modelType, executorConfig);
|
||||
|
||||
TLLM_LOG_INFO("Executor worker exiting");
|
||||
|
||||
return 0;
|
||||
}
|
||||
@ -24,18 +24,18 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template <typename T, int MAX_K>
|
||||
void topK_softMax_kernelLauncher(
|
||||
template <typename T, int PAD_K>
|
||||
void topKSoftMaxKernelLauncher(
|
||||
T const* logits, T const* bias, void* workspace, BeamHypotheses& bh, cudaStream_t stream);
|
||||
|
||||
#define CASE_K(MAX_K) \
|
||||
topK_softMax_kernelLauncher<T, MAX_K>(logits, bias, workspace, bh, stream); \
|
||||
#define CASE_K(PAD_K) \
|
||||
topKSoftMaxKernelLauncher<T, PAD_K>(logits, bias, workspace, bh, stream); \
|
||||
break;
|
||||
|
||||
template <typename T>
|
||||
void invokeTopkSoftMax(T const* logits, T const* bias, void* workspace, BeamHypotheses& bh, cudaStream_t stream)
|
||||
{
|
||||
switch (padToNextPowerOfTwo(bh.beam_width))
|
||||
switch (padToNextPowerOfTwo(bh.nBeamWidth)) // PAD_K must be a compilation-time constant
|
||||
{
|
||||
case 1:
|
||||
case 2:
|
||||
@ -52,8 +52,9 @@ void invokeTopkSoftMax(T const* logits, T const* bias, void* workspace, BeamHypo
|
||||
CASE_K(64)
|
||||
#endif // FAST_BUILD
|
||||
default:
|
||||
throw std::runtime_error(fmtstr(
|
||||
"%s:%d Topk kernel of beam search does not support beam_width=%d", __FILE__, __LINE__, bh.beam_width));
|
||||
throw std::runtime_error(
|
||||
fmtstr("%s:%d Maximum beam width supported for beam search (%d) is larger than beam_width now use (%d)",
|
||||
__FILE__, __LINE__, nMaxBeamWidth, bh.nBeamWidth));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -22,64 +22,64 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
static constexpr int nMaxBeamWidth = 64; // max beam width supported now
|
||||
static constexpr int nSmallTopKBlockSize = 256;
|
||||
static constexpr int nSmallTopKMaxVocParts = 128;
|
||||
static constexpr int nBlockSizeForSmallBeamWidth = 256;
|
||||
static constexpr int nMaxVocabPartForStage1FastKernel = 128;
|
||||
|
||||
struct BeamHypotheses
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
// BS: batch_size, BM: beam_width, mSL: max_seq_length
|
||||
// BS: batch_size, BM: beam_width, MSL: max_seq_length
|
||||
// %%: parameter name when dynamic_decoder.forward() / gather_tree() are called in [generation.py] (python workflow)
|
||||
|
||||
// Candidate beams: When a beam generates end_id or its sequence length reaches mSL, it becomes a candidate beam to be selected finally.
|
||||
// Candidate-Beam-Array (CBA): Arrays (size: BM*2) to place the candidate beams and related information
|
||||
// Candidate beams: a beam which generates end_id or its sequence length reaches MSL
|
||||
// Candidate-Beam-Array (CBA): The arrays (size: BM*2) to place the candidate beams and related information
|
||||
|
||||
// Scalar values
|
||||
bool is_return_normed_score{true}; // return normed_score / cum_log_probs, useless yet
|
||||
int batch_size{0}; //
|
||||
int beam_width{0}; //
|
||||
int ite{0}; // index of local_batch, always be 0 when pp_size==1
|
||||
int local_batch_size{0}; //
|
||||
int max_seq_len{0}; //
|
||||
int vocab_size{0}; // vocab_size_padded
|
||||
bool bReturnNormedScore{false}; // return normed_score / cum_log_probs, useless yet
|
||||
int nBatchSize{0}; //
|
||||
int nBeamWidth{0}; //
|
||||
int nIte{0}; // index of local_batch, always be 0 when pp_size==1
|
||||
int nBatchSizeLocal{0}; //
|
||||
int nMaxSeqLen{0}; //
|
||||
int nVocabSize{0}; // vocab_size_padded
|
||||
|
||||
// Pointers from SamplingConfig
|
||||
float const* diversity_rates{nullptr}; // [BS]
|
||||
float const* length_penalties{nullptr}; // [BS]
|
||||
int const* early_stoppings{nullptr}; // [BS]
|
||||
float const* diversityRates{nullptr}; // [BS]
|
||||
float const* lengthPenalties{nullptr}; // [BS]
|
||||
int const* earlyStoppings{nullptr}; // [BS]
|
||||
|
||||
// Pointers from input
|
||||
int const* input_lengths{nullptr}; // [BS, BM] %% context_length
|
||||
int const* end_ids{nullptr}; // [BS, BM] %% self.end_ids
|
||||
int const* inputLengths{nullptr}; // [BS, BM] %% context_length
|
||||
int const* endIds{nullptr}; // [BS, BM] %% self.end_ids
|
||||
|
||||
// Pointers for output
|
||||
int* final_output_ids{nullptr}; // [BS, BM, mSL] %% self.output_ids
|
||||
float* log_probs{nullptr}; // [mSL, BS, BM] %% self.log_probs_tiled
|
||||
int* seq_len{nullptr}; // [BS, BM] %% self.sequence_length_buffer
|
||||
float* cum_log_probs{nullptr}; // [BS, BM] %% self.cum_log_probs
|
||||
int* outputIds{nullptr}; // [BS, BM, MSL] %% self.output_ids
|
||||
float* logProbs{nullptr}; // [MSL, BS, BM] %% self.log_probs_tiled
|
||||
int* sequenceLengths{nullptr}; // [BS, BM] %% self.sequence_length_buffer
|
||||
float* cumLogProbs{nullptr}; // [BS, BM] %% self.cum_log_probs
|
||||
|
||||
// Pointers of CBA
|
||||
int* output_ids_cba{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_output_ids_cba
|
||||
float* log_probs_cba{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_log_probs_cba
|
||||
int* seq_len_cba{nullptr}; // [BS, BM*2] %% self.beam_hyps_seq_len_cba
|
||||
float* cum_log_probs_cba{nullptr}; // [BS, BM*2] %% self.beam_hyps_cum_log_probs_cba
|
||||
float* normed_scores_cba{nullptr}; // [BS, BM*2] %% self.beam_hyps_normed_scores_cba
|
||||
int* num_beams{nullptr}; // [BS] %% self.beam_hyps_num_beams number of beams in CBA
|
||||
float* min_normed_scores{nullptr}; // [BS] %% self.beam_hyps_min_normed_scores worst score in CBA
|
||||
int* outputIdsCBA{nullptr}; // [BS, BM*2, MSL] %% self.beam_hyps_output_ids_cba
|
||||
float* logProbsCBA{nullptr}; // [BS, BM*2, MSL] %% self.beam_hyps_log_probs_cba
|
||||
int* sequenceLengthsCBA{nullptr}; // [BS, BM*2] %% self.beam_hyps_seq_len_cba
|
||||
float* cumLogProbsCBA{nullptr}; // [BS, BM*2] %% self.beam_hyps_cum_log_probs_cba
|
||||
float* normedScoresCBA{nullptr}; // [BS, BM*2] %% self.beam_hyps_normed_scores_cba
|
||||
int* numBeamsCBA{nullptr}; // [BS] %% self.beam_hyps_num_beams number of beams in CBA
|
||||
float* minNormedScoresCBA{nullptr}; // [BS] %% self.beam_hyps_min_normed_scores worst score in CBA
|
||||
|
||||
// Pointers related to beam search process, they are initialized in those two functions:
|
||||
// [gptDecoder.cpp] GptDecoder<T>::forward or [dynamicDecodeOp.cpp] FtDynamicDecode<T>::forward
|
||||
bool* is_done{nullptr}; // [BS] %% self.beam_hyps_is_done whether a whole batch is finished
|
||||
FinishedState* finished; // [BS*BM] %% self.finished whether and how a beam is finished
|
||||
bool* batchDones{nullptr}; // [BS] %% self.beam_hyps_is_done whether a whole batch is finished
|
||||
FinishedState* finished{nullptr}; // [BS*BM] %% self.finished whether and how a beam is finished
|
||||
|
||||
// Pointers for backtrack of the beams, they are relocated in [dynamicDecodeLayer.cpp] DynamicDecodeLayer<T>::prepareIdsPtrs
|
||||
int** output_ids_ptr{nullptr}; // [BS][BM, mSL] %% self.output_ids
|
||||
int** parent_ids_ptr{nullptr}; // [BS][BM, mSL] %% self.parent_ids
|
||||
int** outputIdsPtr{nullptr}; // [BS][BM, MSL] %% self.output_ids
|
||||
int** parentIdsPtr{nullptr}; // [BS][BM, MSL] %% self.parent_ids
|
||||
|
||||
// Pointers for gather_tree(), read the unfinished beams from them and write to CBA for the final selection
|
||||
int const* output_ids_src{nullptr}; // [BS, BM, mSL] %% self.output_ids
|
||||
int const* parent_ids_src{nullptr}; // [BS, BM, mSL] %% self.parent_ids
|
||||
int const* outputIdsUnfinish{nullptr}; // [BS, BM, MSL] %% self.output_ids
|
||||
int const* parentIdsUnfinish{nullptr}; // [BS, BM, MSL] %% self.parent_ids
|
||||
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -730,11 +730,8 @@ bool FusedMHARunnerV2::isValid(int s) const
|
||||
// static function to check if fmha is supported when building plugins
|
||||
bool MHARunner::fmha_supported(int const headSize, int const sm)
|
||||
{
|
||||
|
||||
return (headSize == 32 || headSize == 40 || headSize == 64 || headSize == 80 || headSize == 96 || headSize == 104
|
||||
|| headSize == 128 || headSize == 160 || headSize == 192 || headSize == 256);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
|
||||
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
@ -15,9 +15,6 @@
|
||||
# the License.
|
||||
#
|
||||
|
||||
file(GLOB_RECURSE SRC_CPP *.cpp)
|
||||
file(GLOB_RECURSE SRC_CU *.cu)
|
||||
|
||||
# The Python executable will only be defined if building with Torch support. If
|
||||
# not, we need to find it here.
|
||||
if(NOT Python3_EXECUTABLE)
|
||||
@ -43,10 +40,12 @@ set_directory_properties(
|
||||
PROPERTIES CMAKE_CONFIGURE_DEPENDS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/python/generate_kernels.py)
|
||||
|
||||
set(INSTANTIATION_GENERATION_DIR
|
||||
${CMAKE_CURRENT_BINARY_DIR}/cutlass_instantiations)
|
||||
execute_process(
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/python/
|
||||
COMMAND ${Python3_EXECUTABLE} generate_kernels.py -o
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
${INSTANTIATION_GENERATION_DIR}
|
||||
RESULT_VARIABLE _KERNEL_GEN_SUCCESS)
|
||||
|
||||
if(NOT _KERNEL_GEN_SUCCESS MATCHES 0)
|
||||
@ -56,41 +55,73 @@ if(NOT _KERNEL_GEN_SUCCESS MATCHES 0)
|
||||
)
|
||||
endif()
|
||||
|
||||
file(GLOB_RECURSE CU_INSTANTIATIONS ${CMAKE_CURRENT_BINARY_DIR}/*.cu)
|
||||
# Get the sources for Mixed Input GEMM launchers
|
||||
file(GLOB_RECURSE MIXED_CU_INSTANTIATIONS
|
||||
${INSTANTIATION_GENERATION_DIR}/gemm/*.cu)
|
||||
file(GLOB_RECURSE MIXED_SRC_CPP fpA_intB_gemm/*.cpp)
|
||||
file(GLOB_RECURSE MIXED_SRC_CU fpA_intB_gemm/*.cu)
|
||||
|
||||
add_library(cutlass2_src STATIC ${SRC_CPP} ${SRC_CU})
|
||||
set_property(TARGET cutlass2_src PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET cutlass2_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
|
||||
# Get the sources for MOE Grouped GEMM launchers
|
||||
file(GLOB_RECURSE GROUPED_CU_INSTANTIATIONS
|
||||
${INSTANTIATION_GENERATION_DIR}/gemm_grouped/*.cu)
|
||||
file(GLOB_RECURSE GROUPED_SRC_CPP moe_gemm/*.cpp)
|
||||
file(GLOB_RECURSE GROUPED_SRC_CU moe_gemm/*.cu)
|
||||
|
||||
add_library(cutlass3_src STATIC ${CU_INSTANTIATIONS})
|
||||
set_property(TARGET cutlass3_src PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET cutlass3_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
|
||||
# Get the sources for all the remaining sources
|
||||
file(GLOB_RECURSE SRC_CPP *.cpp)
|
||||
file(GLOB_RECURSE SRC_CU *.cu)
|
||||
set(ALL_SRCS ${SRC_CPP};${SRC_CU})
|
||||
list(FILTER ALL_SRCS EXCLUDE REGEX "fpA_intB_gemm/.*")
|
||||
list(FILTER ALL_SRCS EXCLUDE REGEX "moe_gemm/.*")
|
||||
|
||||
# Note - we deliberately do not include 90a PTX (even when 9.0+PTX is
|
||||
# specified). This is because sm_90a has arch conditional instructions that are
|
||||
# not forward compatible. As a result, it does not make sense to embed PTX into
|
||||
# the binary anyway.
|
||||
if("9.0" IN_LIST TORCH_CUDA_ARCH_LIST
|
||||
OR "9.0+PTX" IN_LIST TORCH_CUDA_ARCH_LIST
|
||||
OR "90-real" IN_LIST CMAKE_CUDA_ARCHITECTURES_NATIVE)
|
||||
message(STATUS "MANUALLY APPENDING FLAG TO COMPILE FOR SM_90a.")
|
||||
target_compile_options(
|
||||
cutlass3_src
|
||||
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_90a,code=sm_90a>)
|
||||
message(
|
||||
STATUS
|
||||
"Mixed srcs ${MIXED_SRC_CPP} ${MIXED_SRC_CU} ${MIXED_CU_INSTANTIATIONS}")
|
||||
message(
|
||||
STATUS
|
||||
"Group srcs ${GROUPED_SRC_CU} ${GROUPED_SRC_CPP} ${GROUPED_CU_INSTANTIATIONS}"
|
||||
)
|
||||
message(STATUS "All srcs ${ALL_SRCS}")
|
||||
|
||||
# Hopper kernels require cuda lib for TMA APIs
|
||||
target_link_libraries(cutlass3_src PRIVATE CUDA::cuda_driver)
|
||||
add_library(cutlass_src STATIC ${ALL_SRCS})
|
||||
set_property(TARGET cutlass_src PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET cutlass_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
|
||||
|
||||
# No kernels should be parsed, unless hopper is specified. This is a build
|
||||
# time improvement
|
||||
target_compile_definitions(cutlass3_src
|
||||
PRIVATE COMPILE_HOPPER_MIXED_INPUT_GEMMS)
|
||||
endif()
|
||||
add_library(fpA_intB_gemm_src STATIC ${MIXED_SRC_CPP} ${MIXED_SRC_CU}
|
||||
${MIXED_CU_INSTANTIATIONS})
|
||||
add_library(moe_gemm_src STATIC ${GROUPED_SRC_CU} ${GROUPED_SRC_CPP}
|
||||
${GROUPED_CU_INSTANTIATIONS})
|
||||
foreach(target_name fpA_intB_gemm_src;moe_gemm_src)
|
||||
set_property(TARGET ${target_name} PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET ${target_name} PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
|
||||
|
||||
# Suppress GCC note: the ABI for passing parameters with 64-byte alignment has
|
||||
# changed in GCC 4.6 This note appears for kernels using TMA and clutters the
|
||||
# compilation output.
|
||||
if(NOT WIN32)
|
||||
target_compile_options(
|
||||
cutlass3_src PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-psabi>)
|
||||
endif()
|
||||
# Note - we deliberately do not include 90a PTX (even when 9.0+PTX is
|
||||
# specified). This is because sm_90a has arch conditional instructions that
|
||||
# are not forward compatible. As a result, it does not make sense to embed PTX
|
||||
# into the binary anyway.
|
||||
if("9.0" IN_LIST TORCH_CUDA_ARCH_LIST
|
||||
OR "9.0+PTX" IN_LIST TORCH_CUDA_ARCH_LIST
|
||||
OR "90-real" IN_LIST CMAKE_CUDA_ARCHITECTURES_NATIVE)
|
||||
|
||||
message(STATUS "MANUALLY APPENDING FLAG TO COMPILE FOR SM_90a.")
|
||||
target_compile_options(
|
||||
${target_name}
|
||||
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_90a,code=sm_90a>)
|
||||
|
||||
# Hopper kernels require cuda lib for TMA APIs
|
||||
target_link_libraries(${target_name} PRIVATE CUDA::cuda_driver)
|
||||
|
||||
# No kernels should be parsed, unless hopper is specified. This is a build
|
||||
# time improvement
|
||||
target_compile_definitions(${target_name} PRIVATE COMPILE_HOPPER_TMA_GEMMS)
|
||||
endif()
|
||||
|
||||
# Suppress GCC note: the ABI for passing parameters with 64-byte alignment has
|
||||
# changed in GCC 4.6 This note appears for kernels using TMA and clutters the
|
||||
# compilation output.
|
||||
if(NOT WIN32)
|
||||
target_compile_options(
|
||||
${target_name} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-psabi>)
|
||||
endif()
|
||||
|
||||
endforeach()
|
||||
|
||||
@ -24,6 +24,7 @@
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
|
||||
#ifndef _WIN32
|
||||
#pragma GCC diagnostic pop
|
||||
@ -65,7 +66,7 @@ TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
|
||||
case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: return TileShape{128, 128};
|
||||
case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: return TileShape{128, 256};
|
||||
case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: return TileShape{256, 128};
|
||||
default: throw std::runtime_error("[TensorRT-LLm Error][get_grid_shape_for_config] Invalid config");
|
||||
default: TLLM_THROW("[get_grid_shape_for_config] Invalid config");
|
||||
}
|
||||
}
|
||||
|
||||
@ -110,7 +111,7 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k,
|
||||
}
|
||||
|
||||
std::vector<CutlassTileConfig> get_candidate_tiles(
|
||||
int const sm, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only)
|
||||
int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
|
||||
{
|
||||
enum class CutlassGemmType : char
|
||||
{
|
||||
@ -121,15 +122,15 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
|
||||
};
|
||||
|
||||
CutlassGemmType gemm_type = CutlassGemmType::Default;
|
||||
if (simt_configs_only)
|
||||
if (config_type_param & CutlassGemmConfig::SIMT_ONLY)
|
||||
{
|
||||
gemm_type = CutlassGemmType::Simt;
|
||||
}
|
||||
else if (is_weight_only)
|
||||
else if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY)
|
||||
{
|
||||
gemm_type = CutlassGemmType::WeightOnly;
|
||||
}
|
||||
else if (int8_configs_only)
|
||||
else if (config_type_param & CutlassGemmConfig::INT8_ONLY)
|
||||
{
|
||||
gemm_type = CutlassGemmType::Int8;
|
||||
}
|
||||
@ -170,39 +171,21 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
|
||||
}
|
||||
|
||||
std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
|
||||
int const sm, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only)
|
||||
int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config)
|
||||
{
|
||||
enum class CutlassGemmType : char
|
||||
if (config & CutlassGemmConfig::GROUPED_GEMM)
|
||||
{
|
||||
Default,
|
||||
WeightOnly,
|
||||
Simt,
|
||||
Int8
|
||||
};
|
||||
|
||||
CutlassGemmType gemm_type = CutlassGemmType::Default;
|
||||
if (simt_configs_only)
|
||||
{
|
||||
gemm_type = CutlassGemmType::Simt;
|
||||
return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x256x128B};
|
||||
}
|
||||
else if (is_weight_only)
|
||||
else
|
||||
{
|
||||
gemm_type = CutlassGemmType::WeightOnly;
|
||||
}
|
||||
else if (int8_configs_only)
|
||||
{
|
||||
gemm_type = CutlassGemmType::Int8;
|
||||
}
|
||||
|
||||
switch (gemm_type)
|
||||
{
|
||||
case CutlassGemmType::WeightOnly:
|
||||
return {CutlassTileConfigSM90::CtaShape64x16x128B, CutlassTileConfigSM90::CtaShape64x32x128B,
|
||||
CutlassTileConfigSM90::CtaShape64x64x128B, CutlassTileConfigSM90::CtaShape64x128x128B,
|
||||
CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x16x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B};
|
||||
default: throw std::runtime_error("get_candidate_tiles_sm90 only supports WeightOnly now.");
|
||||
}
|
||||
}
|
||||
|
||||
@ -226,13 +209,12 @@ bool supports_mcast_along_n(const CutlassTileConfigSM90 tile)
|
||||
return valid_tiles.count(tile) == 1;
|
||||
}
|
||||
|
||||
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, bool const is_weight_only, bool const simt_configs_only,
|
||||
bool const int8_configs_only, int const max_split_k, bool const enable_hopper_gmma)
|
||||
std::vector<CutlassGemmConfig> get_candidate_configs(
|
||||
int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
|
||||
{
|
||||
if (sm == 90 && enable_hopper_gmma)
|
||||
if (sm == 90 && (config_type_param & CutlassGemmConfig::HOPPER))
|
||||
{
|
||||
std::vector<CutlassTileConfigSM90> tiles
|
||||
= get_candidate_tiles_sm90(sm, is_weight_only, simt_configs_only, int8_configs_only);
|
||||
std::vector<CutlassTileConfigSM90> tiles = get_candidate_tiles_sm90(sm, config_type_param);
|
||||
|
||||
std::vector<CutlassGemmConfig> candidate_configs;
|
||||
for (auto const& tile_config : tiles)
|
||||
@ -266,10 +248,10 @@ std::vector<CutlassGemmConfig> get_candidate_configs(int sm, bool const is_weigh
|
||||
}
|
||||
return candidate_configs;
|
||||
}
|
||||
std::vector<CutlassTileConfig> tiles
|
||||
= get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only);
|
||||
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
|
||||
|
||||
std::vector<CutlassGemmConfig> candidate_configs;
|
||||
bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY;
|
||||
int const min_stages = int8_configs_only ? 3 : 2;
|
||||
int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
|
||||
for (auto const& tile_config : tiles)
|
||||
@ -299,8 +281,8 @@ CutlassGemmConfig estimate_best_config_from_occupancies(std::vector<CutlassGemmC
|
||||
|
||||
if (occupancies.size() != candidate_configs.size())
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"[TensorRT-LLm Error][estimate_best_config_from_occupancies] occpancies and "
|
||||
TLLM_THROW(
|
||||
"[estimate_best_config_from_occupancies] occpancies and "
|
||||
"candidate configs vectors must have equal length.");
|
||||
}
|
||||
|
||||
@ -374,7 +356,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(std::vector<CutlassGemmC
|
||||
|
||||
if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic)
|
||||
{
|
||||
throw std::runtime_error("[TensorRT-LLm Error] Heurisitc failed to find a valid config.");
|
||||
TLLM_THROW("Heurisitc failed to find a valid config.");
|
||||
}
|
||||
|
||||
return best_config;
|
||||
|
||||
@ -26,9 +26,8 @@ namespace kernels
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
|
||||
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(int sm,
|
||||
bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only = false,
|
||||
int const max_split_k = 1, bool const enable_hopper_gmma = false);
|
||||
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(
|
||||
int sm, int const max_split_k, tensorrt_llm::cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam const);
|
||||
|
||||
tensorrt_llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies(
|
||||
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> const& candidate_configs,
|
||||
|
||||
@ -30,16 +30,6 @@ namespace kernels
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
|
||||
int get_bits_in_quant_type(QuantType quant_type)
|
||||
{
|
||||
switch (quant_type)
|
||||
{
|
||||
case QuantType::INT8_WEIGHT_ONLY: return 8;
|
||||
case QuantType::PACKED_INT4_WEIGHT_ONLY: return 4;
|
||||
default: TLLM_CHECK_WITH_INFO(false, "Invalid quant_type"); return -1;
|
||||
}
|
||||
}
|
||||
|
||||
struct LayoutDetails
|
||||
{
|
||||
enum class Layout
|
||||
@ -96,11 +86,11 @@ struct getLayoutDetails<cutlass::layout::ColumnMajorTileInterleave<RowsPerTile,
|
||||
}
|
||||
};
|
||||
|
||||
template <typename cutlassArch, typename TypeB>
|
||||
template <typename cutlassArch, typename TypeA, typename TypeB>
|
||||
LayoutDetails getLayoutDetailsForArchAndQuantType()
|
||||
{
|
||||
|
||||
using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB<TypeB, cutlassArch>;
|
||||
using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB<TypeA, TypeB, cutlassArch>;
|
||||
using LayoutB = typename CompileTraits::Layout;
|
||||
using MmaOperator = typename CompileTraits::Operator;
|
||||
LayoutDetails details = getLayoutDetails<LayoutB>()();
|
||||
@ -111,18 +101,20 @@ LayoutDetails getLayoutDetailsForArchAndQuantType()
|
||||
template <typename cutlassArch>
|
||||
LayoutDetails getLayoutDetailsForArch(QuantType quant_type)
|
||||
{
|
||||
int const bits_per_weight_element = get_weight_quant_bits(quant_type);
|
||||
LayoutDetails details;
|
||||
if (quant_type == QuantType::INT8_WEIGHT_ONLY)
|
||||
switch (quant_type)
|
||||
{
|
||||
details = getLayoutDetailsForArchAndQuantType<cutlassArch, uint8_t>();
|
||||
}
|
||||
else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY)
|
||||
{
|
||||
details = getLayoutDetailsForArchAndQuantType<cutlassArch, cutlass::uint4b_t>();
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false, "Unsupported quantization type");
|
||||
case QuantType::W8_A16:
|
||||
details = getLayoutDetailsForArchAndQuantType<cutlassArch, cutlass::half_t, uint8_t>();
|
||||
break;
|
||||
case QuantType::W4_A16:
|
||||
details = getLayoutDetailsForArchAndQuantType<cutlassArch, cutlass::half_t, cutlass::uint4b_t>();
|
||||
break;
|
||||
case QuantType::W4_AFP8:
|
||||
details = getLayoutDetailsForArchAndQuantType<cutlassArch, cutlass::float_e4m3_t, cutlass::uint4b_t>();
|
||||
break;
|
||||
default: TLLM_THROW("Unsupported quantization type");
|
||||
}
|
||||
return details;
|
||||
}
|
||||
@ -137,7 +129,7 @@ LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
|
||||
{
|
||||
return getLayoutDetailsForArch<cutlass::arch::Sm75>(quant_type);
|
||||
}
|
||||
else if (arch >= 80 && arch <= 89)
|
||||
else if (arch >= 80 && arch < 90)
|
||||
{
|
||||
return getLayoutDetailsForArch<cutlass::arch::Sm80>(quant_type);
|
||||
}
|
||||
@ -152,25 +144,54 @@ LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
|
||||
}
|
||||
}
|
||||
|
||||
// Permutes the rows of B for Turing and Ampere. Throws an error for other architectures.
|
||||
// Permutes the rows of B in a way that is compatible with Turing+ architectures.
|
||||
//
|
||||
// Throws an error for other architectures.
|
||||
// The data is permuted such that:
|
||||
// For int8, each group of 16 rows is permuted using the map below:
|
||||
// For W8_A16, each group of 16 rows is permuted using the map below:
|
||||
// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15
|
||||
// For int4, each group of 32 rows is permuted using the map below:
|
||||
// For W4_A16, each group of 32 rows is permuted using the map below:
|
||||
// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31
|
||||
// For W4_A8, see the map in the code. The idea is similar to above.
|
||||
// The goal of this permutation is to ensure data ends up in the correct threads after
|
||||
// we execute LDSM. It counteracts the effect of the data being of different widths.
|
||||
// For more information about the expected layouts, see the MMA section in the PTX docs.
|
||||
std::vector<int> get_permutation_map(QuantType quant_type)
|
||||
{
|
||||
|
||||
if (quant_type == QuantType::W8_A16)
|
||||
{
|
||||
return {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
|
||||
}
|
||||
else if (quant_type == QuantType::W4_A16)
|
||||
{
|
||||
return {0, 1, 8, 9, 16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27, 4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15,
|
||||
22, 23, 30, 31};
|
||||
}
|
||||
else if (quant_type == QuantType::W4_AFP8)
|
||||
{
|
||||
return {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
|
||||
28, 29, 30, 31};
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Invalid quantization type for LDSM permutation");
|
||||
}
|
||||
}
|
||||
|
||||
void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor,
|
||||
std::vector<size_t> const& shape, QuantType quant_type, const int64_t arch_version)
|
||||
std::vector<size_t> const& shape, QuantType quant_type, int64_t const arch_version)
|
||||
{
|
||||
|
||||
// We only want to run this step for weight only quant.
|
||||
TLLM_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY || quant_type == QuantType::INT8_WEIGHT_ONLY);
|
||||
std::vector<int> row_permutation = get_permutation_map(quant_type);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
|
||||
const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
|
||||
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
|
||||
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
|
||||
|
||||
int const BITS_PER_ELT = get_bits_in_quant_type(quant_type);
|
||||
int const BITS_PER_ELT = get_weight_quant_bits(quant_type);
|
||||
int const K = 16 / BITS_PER_ELT;
|
||||
int const ELTS_PER_BYTE = 8 / BITS_PER_ELT;
|
||||
int const ELTS_PER_REG = 32 / BITS_PER_ELT;
|
||||
@ -194,7 +215,8 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t con
|
||||
fmtstr("Invalid shape for quantized tensor. On turing/Ampere, the number of cols must be a multiple of %d.",
|
||||
MMA_SHAPE_N));
|
||||
|
||||
// The code is written as below so it works for both int8 and packed int4.
|
||||
TLLM_CHECK_WITH_INFO(size_t(B_ROWS_PER_MMA) == row_permutation.size(), "Unexpected number of LDSM rows permuted.");
|
||||
|
||||
for (int expert = 0; expert < num_experts; ++expert)
|
||||
{
|
||||
const int64_t matrix_offset = expert * int64_t(num_rows) * int64_t(num_vec_cols);
|
||||
@ -206,8 +228,7 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t con
|
||||
for (int write_col = 0; write_col < num_vec_cols; ++write_col)
|
||||
{
|
||||
int const write_row = base_row + tile_row;
|
||||
int const tile_read_row
|
||||
= 8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
|
||||
int const tile_read_row = row_permutation[tile_row];
|
||||
int const read_row = base_row + tile_read_row;
|
||||
int const read_col = write_col;
|
||||
|
||||
@ -229,7 +250,7 @@ template <QuantType quant_type>
|
||||
void subbyte_transpose_impl(
|
||||
int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, std::vector<size_t> const& shape)
|
||||
{
|
||||
int const bits_per_elt = get_bits_in_quant_type(quant_type);
|
||||
constexpr int bits_per_elt = get_weight_quant_bits(quant_type);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
|
||||
const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
|
||||
@ -243,8 +264,7 @@ void subbyte_transpose_impl(
|
||||
uint8_t const* input_byte_ptr = reinterpret_cast<uint8_t const*>(quantized_tensor);
|
||||
uint8_t* output_byte_ptr = reinterpret_cast<uint8_t*>(transposed_quantized_tensor);
|
||||
|
||||
static_assert(quant_type == QuantType::INT8_WEIGHT_ONLY || quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY, "");
|
||||
static constexpr int ELTS_PER_BYTE = quant_type == QuantType::INT8_WEIGHT_ONLY ? 1 : 2;
|
||||
static constexpr int ELTS_PER_BYTE = 8 / bits_per_elt;
|
||||
|
||||
static constexpr int M_TILE_L1 = 64;
|
||||
static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE;
|
||||
@ -294,7 +314,7 @@ void subbyte_transpose_impl(
|
||||
}
|
||||
}
|
||||
|
||||
if (quant_type == QuantType::INT8_WEIGHT_ONLY)
|
||||
if constexpr (bits_per_elt == 8)
|
||||
{
|
||||
for (int ii = 0; ii < M_TILE_L1; ++ii)
|
||||
{
|
||||
@ -304,7 +324,7 @@ void subbyte_transpose_impl(
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY)
|
||||
else if constexpr (bits_per_elt == 4)
|
||||
{
|
||||
|
||||
for (int ii = 0; ii < M_TILE_L1; ++ii)
|
||||
@ -368,14 +388,17 @@ void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quanti
|
||||
std::vector<size_t> const& shape, QuantType quant_type)
|
||||
{
|
||||
|
||||
if (quant_type == QuantType::INT8_WEIGHT_ONLY)
|
||||
if (quant_type == QuantType::W8_A16)
|
||||
{
|
||||
subbyte_transpose_impl<QuantType::INT8_WEIGHT_ONLY>(transposed_quantized_tensor, quantized_tensor, shape);
|
||||
subbyte_transpose_impl<QuantType::W8_A16>(transposed_quantized_tensor, quantized_tensor, shape);
|
||||
}
|
||||
else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY)
|
||||
else if (quant_type == QuantType::W4_A16)
|
||||
{
|
||||
subbyte_transpose_impl<QuantType::PACKED_INT4_WEIGHT_ONLY>(
|
||||
transposed_quantized_tensor, quantized_tensor, shape);
|
||||
subbyte_transpose_impl<QuantType::W4_A16>(transposed_quantized_tensor, quantized_tensor, shape);
|
||||
}
|
||||
else if (quant_type == QuantType::W4_AFP8)
|
||||
{
|
||||
subbyte_transpose_impl<QuantType::W4_AFP8>(transposed_quantized_tensor, quantized_tensor, shape);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -464,12 +487,16 @@ void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const siz
|
||||
|
||||
void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type)
|
||||
{
|
||||
if (quant_type == QuantType::INT8_WEIGHT_ONLY)
|
||||
if (quant_type == QuantType::W8_A16)
|
||||
{
|
||||
add_bias_and_interleave_int8s_inplace(tensor, num_elts);
|
||||
}
|
||||
else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY)
|
||||
else if (quant_type == QuantType::W4_A16 || quant_type == QuantType::W4_AFP8)
|
||||
{
|
||||
// W4_AFP8 uses the same preprocessor as W4_A16 because the FP8 data must
|
||||
// be converted to FP16 before the scales can be applied using CUDA cores.
|
||||
// As a result, we still want permute the data so that it is well aligned
|
||||
// for conversion to FP16.
|
||||
add_bias_and_interleave_int4s_inplace(tensor, num_elts);
|
||||
}
|
||||
else
|
||||
@ -482,15 +509,12 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, int8_t
|
||||
std::vector<size_t> const& shape, QuantType quant_type, LayoutDetails details)
|
||||
{
|
||||
|
||||
// We only want to run this step for weight only quant.
|
||||
TLLM_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY || quant_type == QuantType::INT8_WEIGHT_ONLY);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
|
||||
const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
|
||||
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
|
||||
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
|
||||
|
||||
int const BITS_PER_ELT = get_bits_in_quant_type(quant_type);
|
||||
int const BITS_PER_ELT = get_weight_quant_bits(quant_type);
|
||||
int const elts_in_int32 = 32 / BITS_PER_ELT;
|
||||
|
||||
int const rows_per_tile = details.rows_per_column_tile;
|
||||
@ -551,7 +575,7 @@ void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, in
|
||||
num_elts *= dim;
|
||||
}
|
||||
|
||||
const size_t num_bytes = num_elts * get_bits_in_quant_type(quant_type) / 8;
|
||||
const size_t num_bytes = num_elts * get_weight_quant_bits(quant_type) / 8;
|
||||
|
||||
std::vector<int8_t> src_buf(num_bytes);
|
||||
std::vector<int8_t> dst_buf(num_bytes);
|
||||
@ -633,9 +657,11 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
|
||||
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
|
||||
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
|
||||
|
||||
int const bits_in_type = get_bits_in_quant_type(quant_type);
|
||||
int const bits_in_type = get_weight_quant_bits(quant_type);
|
||||
int const bytes_per_out_col = num_cols * bits_in_type / 8;
|
||||
|
||||
int const bits_per_weigtht_element = get_weight_quant_bits(quant_type);
|
||||
|
||||
std::vector<int8_t> weight_buf;
|
||||
if (unprocessed_quantized_weight == nullptr)
|
||||
{
|
||||
@ -685,15 +711,15 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
|
||||
for (int jj = 0; jj < bytes_per_out_col; ++jj)
|
||||
{
|
||||
|
||||
if (quant_type == QuantType::INT8_WEIGHT_ONLY)
|
||||
if (bits_per_weigtht_element == 8)
|
||||
{
|
||||
float const col_scale = per_col_max[jj];
|
||||
float const weight_elt = float(current_weight_row[jj]);
|
||||
float const scaled_weight = round(weight_elt / col_scale);
|
||||
float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f;
|
||||
const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight)));
|
||||
current_quantized_weight_row[jj] = clipped_weight;
|
||||
}
|
||||
else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY)
|
||||
else if (bits_per_weigtht_element == 4)
|
||||
{
|
||||
|
||||
// We will pack two int4 elements per iteration of the inner loop.
|
||||
@ -705,7 +731,7 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
|
||||
{
|
||||
float const col_scale = per_col_max[input_idx];
|
||||
float const weight_elt = float(current_weight_row[input_idx]);
|
||||
float const scaled_weight = round(weight_elt / col_scale);
|
||||
float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f;
|
||||
int int_weight = int(scaled_weight);
|
||||
const int8_t clipped_weight = std::max(-8, std::min(7, int_weight));
|
||||
|
||||
|
||||
@ -31,10 +31,21 @@ namespace cutlass_kernels
|
||||
|
||||
enum class QuantType
|
||||
{
|
||||
INT8_WEIGHT_ONLY,
|
||||
PACKED_INT4_WEIGHT_ONLY
|
||||
W8_A16,
|
||||
W4_A16,
|
||||
W4_AFP8
|
||||
};
|
||||
int get_bits_in_quant_type(QuantType quant_type);
|
||||
|
||||
constexpr int get_weight_quant_bits(QuantType quant_type)
|
||||
{
|
||||
switch (quant_type)
|
||||
{
|
||||
case QuantType::W8_A16: return 8;
|
||||
case QuantType::W4_A16: return 4;
|
||||
case QuantType::W4_AFP8: return 4;
|
||||
default: TLLM_CHECK_WITH_INFO(false, "Invalid quant_type"); return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// Shapes here can be 2 or 3D. 2-D shapes are [num_rows, num_cols]
|
||||
// 3-D shapes are [num_experts, num_rows, num_cols]
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
||||
@ -24,7 +24,7 @@ namespace cutlass_kernels
|
||||
{
|
||||
#ifdef ENABLE_FP8
|
||||
template class CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, /*Activation Type*/
|
||||
cutlass::int4b_t, /*Weight Type*/
|
||||
cutlass::uint4b_t, /*Weight Type*/
|
||||
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, /*Scale and Zero Type*/
|
||||
half, /*Bias type Type*/
|
||||
half /*Output type Type*/
|
||||
|
||||
@ -24,7 +24,7 @@ namespace cutlass_kernels
|
||||
{
|
||||
#ifdef ENABLE_FP8
|
||||
template class CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, /*Activation Type*/
|
||||
cutlass::int4b_t, /*Weight Type*/
|
||||
cutlass::uint4b_t, /*Weight Type*/
|
||||
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, half, /*Scale and Zero Type*/
|
||||
half, /*Bias type Type*/
|
||||
half /*Output type Type*/
|
||||
|
||||
@ -24,7 +24,7 @@ namespace cutlass_kernels
|
||||
{
|
||||
#ifdef ENABLE_FP8
|
||||
template class CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, /*Activation Type*/
|
||||
cutlass::int4b_t, /*Weight Type*/
|
||||
cutlass::uint4b_t, /*Weight Type*/
|
||||
cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, half, /*Scale and Zero Type*/
|
||||
half, /*Bias type Type*/
|
||||
half /*Output type Type*/
|
||||
|
||||
@ -37,6 +37,7 @@
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h"
|
||||
|
||||
@ -50,65 +51,59 @@ namespace kernels
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
|
||||
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
|
||||
typename ThreadblockShape, typename WarpShape, int Stages>
|
||||
void generic_mixed_gemm_kernelLauncher(T const* A, WeightType const* B, T const* weight_scales,
|
||||
T const* weight_zero_points, T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size,
|
||||
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename ThreadblockShape,
|
||||
typename WarpShape, int Stages>
|
||||
void generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
|
||||
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
|
||||
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|
||||
|| cutlass::platform::is_same<T, float>::value,
|
||||
static_assert(
|
||||
#ifdef ENABLE_FP8
|
||||
cutlass::platform::is_same<ActivationType, __nv_fp8_e4m3>::value ||
|
||||
#endif
|
||||
cutlass::platform::is_same<ActivationType, __nv_bfloat16>::value
|
||||
|| cutlass::platform::is_same<ActivationType, half>::value
|
||||
|| cutlass::platform::is_same<ActivationType, float>::value,
|
||||
"Specialized for bfloat16, half, float");
|
||||
#else
|
||||
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
|
||||
static_assert(cutlass::platform::is_same<ActivationType, half>::value
|
||||
|| cutlass::platform::is_same<ActivationType, float>::value,
|
||||
"Specialized for half, float");
|
||||
#endif
|
||||
|
||||
static_assert(cutlass::platform::is_same<T, WeightType>::value
|
||||
static_assert(cutlass::platform::is_same<ActivationType, WeightType>::value
|
||||
|| cutlass::platform::is_same<WeightType, uint8_t>::value
|
||||
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
|
||||
"");
|
||||
|
||||
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
|
||||
using ElementType_ =
|
||||
typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
|
||||
#ifdef ENABLE_BF16
|
||||
using ElementType =
|
||||
typename cutlass::platform::conditional<cutlass::platform::is_same<ElementType_, __nv_bfloat16>::value,
|
||||
cutlass::bfloat16_t, ElementType_>::type;
|
||||
#else
|
||||
using ElementType = ElementType_;
|
||||
#endif
|
||||
|
||||
using CutlassWeightType_ =
|
||||
typename cutlass::platform::conditional<cutlass::platform::is_same<WeightType, half>::value, cutlass::half_t,
|
||||
WeightType>::type;
|
||||
#ifdef ENABLE_BF16
|
||||
using CutlassWeightType =
|
||||
typename cutlass::platform::conditional<cutlass::platform::is_same<CutlassWeightType_, __nv_bfloat16>::value,
|
||||
cutlass::bfloat16_t, CutlassWeightType_>::type;
|
||||
#else
|
||||
using CutlassWeightType = CutlassWeightType_;
|
||||
#endif
|
||||
using CutlassActivationType = typename TllmToCutlassTypeAdapter<ActivationType>::type;
|
||||
using CutlassWeightType = typename TllmToCutlassTypeAdapter<WeightType>::type;
|
||||
using CutlassScaleZeroType = typename TllmToCutlassTypeAdapter<ScaleZeroType>::type;
|
||||
using CutlassBiasType = typename TllmToCutlassTypeAdapter<BiasType>::type;
|
||||
using CutlassOutputType = typename TllmToCutlassTypeAdapter<OutputType>::type;
|
||||
|
||||
// We need separate config for each architecture since we will target different tensorcore instructions. For float,
|
||||
// we do not target TCs.
|
||||
using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
|
||||
using MixedGemmArchTraits
|
||||
= cutlass::gemm::kernel::MixedGemmArchTraits<CutlassActivationType, CutlassWeightType, arch>;
|
||||
using ElementAccumulator = typename MixedGemmArchTraits::AccType;
|
||||
|
||||
using EpilogueOp = typename tkc::Epilogue<ElementType, MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator,
|
||||
EpilogueTag>::Op;
|
||||
constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<CutlassOutputType>::value;
|
||||
using EpilogueOp =
|
||||
typename tkc::Epilogue<CutlassOutputType, ElementsPerAccessC, ElementAccumulator, EpilogueTag>::Op;
|
||||
|
||||
using Operator = typename MixedGemmArchTraits::Operator;
|
||||
using TaggedOperator = typename cutlass::arch::TagOperator<Operator, QuantOp>::TaggedOperator;
|
||||
|
||||
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<ElementType, cutlass::layout::RowMajor,
|
||||
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<CutlassActivationType, cutlass::layout::RowMajor,
|
||||
MixedGemmArchTraits::ElementsPerAccessA, CutlassWeightType, typename MixedGemmArchTraits::LayoutB,
|
||||
MixedGemmArchTraits::ElementsPerAccessB, ElementType, cutlass::layout::RowMajor, ElementAccumulator,
|
||||
MixedGemmArchTraits::ElementsPerAccessB, CutlassOutputType, cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, arch, ThreadblockShape, WarpShape,
|
||||
typename MixedGemmArchTraits::InstructionShape, EpilogueOp,
|
||||
typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, Stages, true,
|
||||
@ -138,6 +133,13 @@ void generic_mixed_gemm_kernelLauncher(T const* A, WeightType const* B, T const*
|
||||
|
||||
if constexpr (cutlass::isFinegrained(QuantOp))
|
||||
{
|
||||
if constexpr (cutlass::platform::is_same<CutlassActivationType, float_e4m3_t>::value)
|
||||
{
|
||||
if (group_size != 128)
|
||||
{
|
||||
throw std::runtime_error("Only group size 128 supported for fine grained W4A(fp)8 kernels.");
|
||||
}
|
||||
}
|
||||
if (group_size != 64 && group_size != 128)
|
||||
{
|
||||
throw std::runtime_error("Only group size 64 and 128 supported for fine grained kernels.");
|
||||
@ -173,12 +175,14 @@ void generic_mixed_gemm_kernelLauncher(T const* A, WeightType const* B, T const*
|
||||
|
||||
int const ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0;
|
||||
ElementAccumulator output_op_beta = (biases == nullptr) ? ElementAccumulator(0.f) : ElementAccumulator(1.f);
|
||||
typename Gemm::Arguments args({m, n, k}, group_size, {reinterpret_cast<ElementType*>(const_cast<T*>(A)), k},
|
||||
typename Gemm::Arguments args({m, n, k}, group_size,
|
||||
{reinterpret_cast<CutlassActivationType*>(const_cast<ActivationType*>(A)), k},
|
||||
{reinterpret_cast<CutlassWeightType*>(const_cast<WeightType*>(B)), ldb},
|
||||
{reinterpret_cast<ElementType*>(const_cast<T*>(weight_scales)), ld_scale_zero},
|
||||
{reinterpret_cast<ElementType*>(const_cast<T*>(weight_zero_points)), ld_scale_zero},
|
||||
{reinterpret_cast<ElementType*>(const_cast<T*>(biases)), 0}, {reinterpret_cast<ElementType*>(C), n},
|
||||
gemm_config.split_k_factor, {ElementAccumulator(alpha), output_op_beta});
|
||||
{reinterpret_cast<CutlassScaleZeroType*>(const_cast<ScaleZeroType*>(weight_scales)), ld_scale_zero},
|
||||
{reinterpret_cast<CutlassScaleZeroType*>(const_cast<ScaleZeroType*>(weight_zero_points)), ld_scale_zero},
|
||||
{reinterpret_cast<CutlassBiasType*>(const_cast<BiasType*>(biases)), 0},
|
||||
{reinterpret_cast<CutlassOutputType*>(C), n}, gemm_config.split_k_factor,
|
||||
{ElementAccumulator(alpha), output_op_beta});
|
||||
|
||||
// This assertion is enabled because because for the column interleaved layout, K MUST be a multiple of
|
||||
// threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the
|
||||
@ -227,13 +231,14 @@ void generic_mixed_gemm_kernelLauncher(T const* A, WeightType const* B, T const*
|
||||
|
||||
// This filters out invalid template combinations that we DON'T want instantiated in CUTLASS. For example,
|
||||
// instantiating SM=75, Stages=3 is invalid so we would need to filter that out. Fine grained
|
||||
// quanitzation is only supported on Ampere+ GPUs.
|
||||
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
|
||||
typename ThreadblockShape, typename WarpShape, int Stages>
|
||||
void filter_and_run_mixed_gemm(T const* A, WeightType const* B, T const* weight_scales, T const* weight_zero_points,
|
||||
T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size,
|
||||
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
// quanitzation is only supported on Ampere+ GPUs. FP8 GEMM is only supported on Ada+ GPUs.
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename ThreadblockShape,
|
||||
typename WarpShape, int Stages>
|
||||
void filter_and_run_mixed_gemm(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
|
||||
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
|
||||
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
@ -251,39 +256,55 @@ void filter_and_run_mixed_gemm(T const* A, WeightType const* B, T const* weight_
|
||||
+ std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages);
|
||||
throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg);
|
||||
}
|
||||
else if constexpr (Stages == 2 && arch::kMinComputeCapability >= 89)
|
||||
{
|
||||
// Multistage only supported on Ampere
|
||||
std::string err_msg = "Cutlass fpA_intB gemm not supported for arch "
|
||||
+ std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages);
|
||||
throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg);
|
||||
}
|
||||
else if constexpr (cutlass::platform::is_same<ActivationType, __nv_fp8_e4m3>::value
|
||||
&& arch::kMinComputeCapability < 89)
|
||||
{
|
||||
// FP8 activation type only supported on Ada+ GPUs
|
||||
std::string err_msg = "Cutlass fpA_intB gemm not supported for arch "
|
||||
+ std::to_string(arch::kMinComputeCapability) + " with activation type set to FP8";
|
||||
throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg);
|
||||
}
|
||||
else
|
||||
{
|
||||
generic_mixed_gemm_kernelLauncher<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape,
|
||||
Stages>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config,
|
||||
workspace, workspace_bytes, stream, occupancy);
|
||||
generic_mixed_gemm_kernelLauncher<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch,
|
||||
QuantOp, EpilogueTag, ThreadblockShape, WarpShape, Stages>(A, B, weight_scales, weight_zero_points, biases,
|
||||
alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
|
||||
typename ThreadblockShape, typename WarpShape>
|
||||
void dispatch_gemm_config(T const* A, WeightType const* B, T const* weight_scales, T const* weight_zero_points,
|
||||
T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size,
|
||||
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename ThreadblockShape,
|
||||
typename WarpShape>
|
||||
void dispatch_gemm_config(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
|
||||
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
|
||||
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
switch (gemm_config.stages)
|
||||
{
|
||||
case 2:
|
||||
filter_and_run_mixed_gemm<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B,
|
||||
weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace,
|
||||
workspace_bytes, stream, occupancy);
|
||||
filter_and_run_mixed_gemm<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
|
||||
EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m,
|
||||
n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case 3:
|
||||
filter_and_run_mixed_gemm<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape, 3>(A, B,
|
||||
weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace,
|
||||
workspace_bytes, stream, occupancy);
|
||||
filter_and_run_mixed_gemm<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
|
||||
EpilogueTag, ThreadblockShape, WarpShape, 3>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m,
|
||||
n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case 4:
|
||||
filter_and_run_mixed_gemm<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape, 4>(A, B,
|
||||
weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace,
|
||||
workspace_bytes, stream, occupancy);
|
||||
filter_and_run_mixed_gemm<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
|
||||
EpilogueTag, ThreadblockShape, WarpShape, 4>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m,
|
||||
n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
default:
|
||||
std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages);
|
||||
@ -315,55 +336,56 @@ void dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, Scal
|
||||
constexpr bool all_types_are_the_same = std::is_same_v<ActivationType, ScaleZeroType>
|
||||
&& std::is_same_v<ActivationType, BiasType> && std::is_same_v<ActivationType, OutputType>;
|
||||
|
||||
constexpr bool is_valid_pre_hopper = all_types_are_the_same && !any_is_fp8;
|
||||
constexpr bool is_valid_pre_hopper = (all_types_are_the_same && !any_is_fp8) || (arch::kMinComputeCapability >= 89);
|
||||
|
||||
if constexpr (is_valid_pre_hopper)
|
||||
{
|
||||
// Note that SIMT configs are omitted here since they are not supported for fpA_intB.
|
||||
// We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the
|
||||
// best for mixed type gemms.
|
||||
constexpr int tile_shape_k = 128 * 8 / cutlass::sizeof_bits<ActivationType>::value;
|
||||
switch (gemm_config.tile_config)
|
||||
{
|
||||
case tkc::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
|
||||
cutlass::gemm::GemmShape<16, 128, 64>, cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales,
|
||||
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
|
||||
stream, occupancy);
|
||||
dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
|
||||
EpilogueTag, cutlass::gemm::GemmShape<16, 128, tile_shape_k>,
|
||||
cutlass::gemm::GemmShape<16, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases,
|
||||
alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case tkc::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
|
||||
cutlass::gemm::GemmShape<16, 256, 64>, cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales,
|
||||
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
|
||||
stream, occupancy);
|
||||
dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
|
||||
EpilogueTag, cutlass::gemm::GemmShape<16, 256, tile_shape_k>,
|
||||
cutlass::gemm::GemmShape<16, 64, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases,
|
||||
alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
||||
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
|
||||
cutlass::gemm::GemmShape<32, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales,
|
||||
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
|
||||
stream, occupancy);
|
||||
dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
|
||||
EpilogueTag, cutlass::gemm::GemmShape<32, 128, tile_shape_k>,
|
||||
cutlass::gemm::GemmShape<32, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha,
|
||||
C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
|
||||
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
|
||||
cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales,
|
||||
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
|
||||
stream, occupancy);
|
||||
dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
|
||||
EpilogueTag, cutlass::gemm::GemmShape<64, 128, tile_shape_k>,
|
||||
cutlass::gemm::GemmShape<64, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha,
|
||||
C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales,
|
||||
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
|
||||
stream, occupancy);
|
||||
dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
|
||||
EpilogueTag, cutlass::gemm::GemmShape<128, 128, tile_shape_k>,
|
||||
cutlass::gemm::GemmShape<128, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases,
|
||||
alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case tkc::CutlassTileConfig::Undefined:
|
||||
@ -430,12 +452,26 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
|
||||
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
|
||||
}
|
||||
else if (sm_ >= 80 && sm_ < 90)
|
||||
else if (sm_ >= 80 && sm_ < 89)
|
||||
{
|
||||
dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, cutlass::arch::Sm80,
|
||||
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
|
||||
}
|
||||
else if (sm_ == 89)
|
||||
{
|
||||
#if ENABLE_FP8 && ((__CUDACC_VER_MAJOR__ < 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4))
|
||||
if constexpr (cutlass::platform::is_same<ActivationType, __nv_fp8_e4m3>::value)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"[TensorRT-LLM Error][CutlassFpAIntBGemmRunner][dispatch_to_arch] INT4xFP8 GEMM for Ada needs "
|
||||
"CUDA>=12.4");
|
||||
}
|
||||
#endif
|
||||
dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, cutlass::arch::Sm89,
|
||||
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
|
||||
}
|
||||
else if (sm_ == 90)
|
||||
{
|
||||
sm90_dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
|
||||
@ -520,8 +556,14 @@ std::vector<tkc::CutlassGemmConfig>
|
||||
CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::getConfigs() const
|
||||
{
|
||||
static constexpr bool is_weight_only = !std::is_same<ActivationType, WeightType>::value;
|
||||
std::vector<tkc::CutlassGemmConfig> candidateConfigs
|
||||
= get_candidate_configs(sm_, is_weight_only, false, false, SPLIT_K_LIMIT, true);
|
||||
tkc::CutlassGemmConfig::CandidateConfigTypeParam config_type_param
|
||||
= tkc::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER;
|
||||
if (is_weight_only)
|
||||
{
|
||||
config_type_param = static_cast<tkc::CutlassGemmConfig::CandidateConfigTypeParam>(
|
||||
config_type_param | tkc::CutlassGemmConfig::CandidateConfigTypeParam::WEIGHT_ONLY);
|
||||
}
|
||||
std::vector<tkc::CutlassGemmConfig> candidateConfigs = get_candidate_configs(sm_, SPLIT_K_LIMIT, config_type_param);
|
||||
return candidateConfigs;
|
||||
}
|
||||
|
||||
|
||||
@ -79,7 +79,7 @@ void sm90_dispatch_epilogue_schedules(ActivationType const* A, WeightType const*
|
||||
|
||||
2x2x1 cluster shape is only supported when both the M and N tiles are at least 128.
|
||||
|
||||
We make the above restrictions are to improve compilation speed in TRT-LLM by pruning kernels
|
||||
We make the above restrictions to improve compilation speed in TRT-LLM, by pruning kernels
|
||||
that may not be very useful in practice.
|
||||
*/
|
||||
template <typename CTAShape, typename ClusterShape>
|
||||
|
||||
@ -66,7 +66,7 @@ void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
|
||||
#ifdef COMPILE_HOPPER_MIXED_INPUT_GEMMS
|
||||
#ifdef COMPILE_HOPPER_TMA_GEMMS
|
||||
using CutlassActivationType = typename TllmToCutlassTypeAdapter<ActivationType>::type;
|
||||
|
||||
// For FAST_BUILD, only instantiate kernels with 128x128x128B with 1x1x1 cluster shape.
|
||||
@ -286,11 +286,11 @@ void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType
|
||||
}
|
||||
#endif // FAST_BUILD
|
||||
|
||||
#else // COMPILE_HOPPER_MIXED_INPUT_GEMMS
|
||||
#else // COMPILE_HOPPER_TMA_GEMMS
|
||||
throw std::runtime_error(
|
||||
"[TensorRT-LLm Error][fpA_intB Runner] Please recompile with support for hopper by passing 90-real as an arch "
|
||||
"to build_wheel.py.");
|
||||
#endif // COMPILE_HOPPER_MIXED_INPUT_GEMMS
|
||||
#endif // COMPILE_HOPPER_TMA_GEMMS
|
||||
}
|
||||
|
||||
} // namespace cutlass_kernels
|
||||
|
||||
@ -365,10 +365,15 @@ void CutlassInt8GemmRunner<T>::gemm(int8_t const* A, int8_t const* B, tk::QuantM
|
||||
template <typename T>
|
||||
std::vector<tkc::CutlassGemmConfig> CutlassInt8GemmRunner<T>::getConfigs() const
|
||||
{
|
||||
static constexpr bool isWeightOnly = false;
|
||||
std::vector<tkc::CutlassGemmConfig> candidateConfigs
|
||||
= get_candidate_configs(mSm, isWeightOnly, mSm <= 70, /* SIMT configs */
|
||||
true, SPLIT_K_LIMIT); /* INT8 configs */
|
||||
|
||||
auto config_type_param = tkc::CutlassGemmConfig::CandidateConfigTypeParam::INT8_ONLY;
|
||||
if (mSm <= 70)
|
||||
{
|
||||
config_type_param = static_cast<tkc::CutlassGemmConfig::CandidateConfigTypeParam>(
|
||||
config_type_param | tkc::CutlassGemmConfig::CandidateConfigTypeParam::SIMT_ONLY);
|
||||
}
|
||||
|
||||
std::vector<tkc::CutlassGemmConfig> candidateConfigs = get_candidate_configs(mSm, SPLIT_K_LIMIT, config_type_param);
|
||||
return candidateConfigs;
|
||||
}
|
||||
|
||||
|
||||
@ -0,0 +1,36 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
|
||||
// Keep in sync with the signature generated by generate_kernels.py
|
||||
template <typename T, typename WeightType, typename EpilogueTag, typename TileShape, typename ClusterShape, bool BIAS>
|
||||
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
|
||||
int multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size);
|
||||
|
||||
} // namespace cutlass_kernels
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,304 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass_extensions/compute_occupancy.h"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
||||
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
|
||||
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <math.h>
|
||||
#include <sstream>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
|
||||
// Hopper helper class for defining all the cutlass helper types
|
||||
template <typename T, typename WeightType, typename EpilogueTag, typename TileShape, typename ClusterShape, bool BIAS>
|
||||
struct HopperGroupedGemmInfo
|
||||
{
|
||||
using Arch = cutlass::arch::Sm90;
|
||||
|
||||
// TODO Update once mixed input support is added
|
||||
static_assert(cutlass::platform::is_same<T, WeightType>::value,
|
||||
"CUTLASS does not currently have specialised SM90 support for quantized operations");
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
constexpr static bool IsFP8
|
||||
= cutlass::platform::is_same<T, __nv_fp8_e4m3>::value || cutlass::platform::is_same<T, __nv_fp8_e5m2>::value;
|
||||
#else
|
||||
constexpr static bool IsFP8 = false;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|
||||
|| cutlass::platform::is_same<T, float>::value || IsFP8,
|
||||
"Specialized for bfloat16, half, float, fp8");
|
||||
#else
|
||||
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value || IsFP8,
|
||||
"Specialized for half, float, fp8");
|
||||
#endif
|
||||
|
||||
static_assert(cutlass::platform::is_same<T, WeightType>::value
|
||||
|| cutlass::platform::is_same<WeightType, uint8_t>::value
|
||||
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value
|
||||
|| cutlass::platform::is_same<WeightType, cutlass::float_e4m3_t>::value
|
||||
|| cutlass::platform::is_same<WeightType, cutlass::float_e5m2_t>::value,
|
||||
"Unexpected quantization type");
|
||||
|
||||
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
|
||||
using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
|
||||
|
||||
using CutlassWeightTypeMaybeUint4 = typename TllmToCutlassTypeAdapter<WeightType>::type;
|
||||
// For legacy reasons we convert unsigned 8-bit to signed
|
||||
using CutlassWeightTypeMaybeUint8
|
||||
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint4, cutlass::uint4b_t>, cutlass::int4b_t,
|
||||
CutlassWeightTypeMaybeUint4>;
|
||||
using CutlassWeightType
|
||||
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint8, uint8_t>, int8_t, CutlassWeightTypeMaybeUint8>;
|
||||
|
||||
using ElementA = ElementType;
|
||||
using ElementB = CutlassWeightType;
|
||||
|
||||
template <class Element>
|
||||
using CutlassOutputTypeAdaptor_t = typename TllmToCutlassTypeAdapter<
|
||||
HopperGroupedGemmInput::OutputTypeAdaptor_t<typename CutlassToTllmTypeAdapter<Element>::type>>::type;
|
||||
using ElementD = CutlassOutputTypeAdaptor_t<ElementType>;
|
||||
|
||||
using ElementC = std::conditional_t<BIAS, ElementType, void>;
|
||||
using ElementCNoVoid = std::conditional_t<BIAS, ElementType, ElementD>;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
|
||||
// A matrix configuration - this is transposed and swapped with B
|
||||
using LayoutA = HopperGroupedGemmInput::LayoutA;
|
||||
constexpr static int AlignmentA
|
||||
= 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units
|
||||
// of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration - this is transposed and swapped with A
|
||||
using LayoutB = HopperGroupedGemmInput::LayoutB; // Layout type for B matrix operand
|
||||
constexpr static int AlignmentB
|
||||
= 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units
|
||||
// of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using LayoutC = HopperGroupedGemmInput::LayoutC; // Layout type for C matrix operand
|
||||
// Note we use ElementType here deliberately, so we don't break when BIAS is disabled
|
||||
constexpr static int AlignmentC
|
||||
= 128 / cutlass::sizeof_bits<ElementType>::value; // Memory access granularity/alignment of C matrix in units
|
||||
// of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using LayoutD = HopperGroupedGemmInput::LayoutD;
|
||||
constexpr static int AlignmentD
|
||||
= 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix
|
||||
// in units of elements (up to 16 bytes)
|
||||
|
||||
static_assert(cutlass::platform::is_same<EpilogueTag, tensorrt_llm::cutlass_extensions::EpilogueOpDefault>::value,
|
||||
"Hopper Grouped GEMM specialisation doesn't support fused activation");
|
||||
|
||||
using EpilogueOp
|
||||
= cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
|
||||
|
||||
// TODO Add mode for fused activation once CUTLASS adds support
|
||||
// using EpilogueSchedule = cutlass::platform::conditional_t<
|
||||
// cutlass::platform::is_same<EpilogueOp, EpilogueOpDefault>::value,
|
||||
// cutlass::epilogue::PtrArrayNoSmemWarpSpecialized,
|
||||
// cutlass::epilogue::?????????????????? /// <<<<<< what supports activations
|
||||
// >;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< //
|
||||
Arch, cutlass::arch::OpClassTensorOp, //
|
||||
TileShape, ClusterShape, //
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, //
|
||||
ElementAccumulator, ElementAccumulator, //
|
||||
ElementC, LayoutC*, AlignmentC, //
|
||||
ElementD, LayoutD*, AlignmentD, //
|
||||
EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
using StageCountAutoCarveout = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>;
|
||||
|
||||
using KernelSchedule
|
||||
= std::conditional_t<IsFP8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< //
|
||||
Arch, cutlass::arch::OpClassTensorOp, //
|
||||
CutlassWeightType, LayoutB*, AlignmentB, // A & B swapped here
|
||||
ElementType, LayoutA*, AlignmentA, //
|
||||
ElementAccumulator, //
|
||||
TileShape, ClusterShape, //
|
||||
StageCountAutoCarveout, KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<HopperGroupedGemmInput::ProblemShape, CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
};
|
||||
|
||||
// Hopper specialised version
|
||||
template <typename T, typename WeightType, typename EpilogueTag, typename TileShape, typename ClusterShape, bool BIAS>
|
||||
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
|
||||
int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size)
|
||||
{
|
||||
#ifdef COMPILE_HOPPER_TMA_GEMMS
|
||||
using namespace cute;
|
||||
// For FAST_BUILD, only instantiate kernels with 128x128x128B with 1x1x1 cluster shape.
|
||||
#ifdef FAST_BUILD
|
||||
constexpr int TILE_K = 128 * 8 / cutlass::sizeof_bits<WeightType>::value;
|
||||
using SupportedCtaShape = Shape<_128, _128, cute::Int<TILE_K>>;
|
||||
using SupportedCgaShape = Shape<_1, _1, _1>;
|
||||
|
||||
if constexpr (cute::is_same_v<SupportedCtaShape, TileShape> && cute::is_same_v<SupportedCgaShape, ClusterShape>)
|
||||
#endif // FAST_BUILD
|
||||
{
|
||||
using GemmInfo = HopperGroupedGemmInfo<T, WeightType, EpilogueTag, TileShape, ClusterShape, BIAS>;
|
||||
|
||||
using ElementAccumulator = typename GemmInfo::ElementAccumulator;
|
||||
using ElementA = typename GemmInfo::ElementA;
|
||||
using ElementB = typename GemmInfo::ElementB;
|
||||
using ElementC = typename GemmInfo::ElementC;
|
||||
using ElementCNoVoid = typename GemmInfo::ElementCNoVoid;
|
||||
using ElementD = typename GemmInfo::ElementD;
|
||||
|
||||
using CollectiveMainloop = typename GemmInfo::CollectiveMainloop;
|
||||
using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue;
|
||||
using GemmKernel = typename GemmInfo::GemmKernel;
|
||||
using GemmGrouped = typename GemmInfo::GemmGrouped;
|
||||
|
||||
if (kernel_occupancy != nullptr)
|
||||
{
|
||||
*kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel, true>();
|
||||
return;
|
||||
}
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = multi_processor_count;
|
||||
|
||||
GemmGrouped gemm;
|
||||
|
||||
if (workspace_size != nullptr)
|
||||
{
|
||||
// Make a mock problem shape with just the minimal information actually required to get the workspace size
|
||||
// This makes some assumptions about CUTLASS's implementation which is suboptimal. We have a check later to
|
||||
// catch future cutlass updates causing silent breakages, but that is not fool proof.
|
||||
// The alternative is to wait until we have data and then dynamically allocate the workspace
|
||||
typename HopperGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, nullptr};
|
||||
|
||||
typename GemmGrouped::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped, shape_info, {}, {}, hw_info};
|
||||
*workspace_size = gemm.get_workspace_size(args);
|
||||
return;
|
||||
}
|
||||
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
TLLM_CHECK(hopper_input.stride_a);
|
||||
TLLM_CHECK(hopper_input.stride_b);
|
||||
TLLM_CHECK(hopper_input.stride_d);
|
||||
TLLM_CHECK(hopper_input.ptr_a);
|
||||
TLLM_CHECK(hopper_input.ptr_b);
|
||||
TLLM_CHECK(hopper_input.ptr_d);
|
||||
|
||||
const MainloopArguments mainloop_params = {reinterpret_cast<ElementB const**>(hopper_input.ptr_b),
|
||||
hopper_input.stride_b, reinterpret_cast<ElementA const**>(hopper_input.ptr_a), hopper_input.stride_a};
|
||||
|
||||
typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{
|
||||
ElementAccumulator(1.f), hopper_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)};
|
||||
epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array;
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
// TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS
|
||||
const EpilogueArguments epilogue_params
|
||||
= {epilogue_scalars, reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c), hopper_input.stride_c,
|
||||
reinterpret_cast<ElementD**>(hopper_input.ptr_d), hopper_input.stride_d};
|
||||
|
||||
typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info,
|
||||
mainloop_params, epilogue_params, hw_info};
|
||||
|
||||
size_t calculated_ws_size = gemm.get_workspace_size(args);
|
||||
TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size,
|
||||
"Workspace is size %zu but only %zu were allocated", calculated_ws_size, hopper_input.gemm_workspace_size);
|
||||
|
||||
auto can_implement = gemm.can_implement(args);
|
||||
TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess,
|
||||
"Grouped GEMM kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)));
|
||||
|
||||
auto init_status = gemm.initialize(args, hopper_input.gemm_workspace);
|
||||
TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess,
|
||||
"Failed to initialize cutlass variable batched gemm. Error: "
|
||||
+ std::string(cutlassGetStatusString(init_status)));
|
||||
|
||||
auto run_status = gemm.run(stream);
|
||||
TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess,
|
||||
"Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
#ifdef FAST_BUILD
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Configuration was disabled by FAST_BUILD");
|
||||
}
|
||||
#endif
|
||||
|
||||
#else // COMPILE_HOPPER_TMA_GEMMS
|
||||
TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py.");
|
||||
#endif // COMPILE_HOPPER_TMA_GEMMS
|
||||
}
|
||||
|
||||
} // namespace cutlass_kernels
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -16,13 +16,119 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "tensorrt_llm/common/workspace.h"
|
||||
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h"
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <optional>
|
||||
|
||||
#include <cutlass/gemm/group_array_problem_shape.hpp>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
|
||||
struct HopperGroupedGemmInput
|
||||
{
|
||||
|
||||
template <class Tag>
|
||||
using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
|
||||
cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;
|
||||
static_assert(std::is_same_v<cutlass::layout::RowMajor, TransposeLayoutTag<cutlass::layout::ColumnMajor>>);
|
||||
static_assert(std::is_same_v<cutlass::layout::ColumnMajor, TransposeLayoutTag<cutlass::layout::RowMajor>>);
|
||||
|
||||
// Layout for A and B is transposed and then swapped in the implementation
|
||||
// This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM
|
||||
using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for A matrix operand
|
||||
using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand
|
||||
using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for C matrix operand
|
||||
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
|
||||
|
||||
using StrideA
|
||||
= std::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutA*>>; // Use B because they will be swapped
|
||||
using StrideB
|
||||
= std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped
|
||||
using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;
|
||||
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;
|
||||
|
||||
template <class T>
|
||||
constexpr static bool IsFP8_v = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
|
||||
template <class T>
|
||||
using OutputTypeAdaptor_t = std::conditional_t<IsFP8_v<T>, float, T>;
|
||||
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<cute::Shape<int64_t, int64_t, int64_t>>;
|
||||
|
||||
ProblemShape shape_info{};
|
||||
StrideA* stride_a = nullptr;
|
||||
StrideB* stride_b = nullptr;
|
||||
StrideC* stride_c = nullptr;
|
||||
StrideD* stride_d = nullptr;
|
||||
|
||||
void const** ptr_a = nullptr;
|
||||
void const** ptr_b = nullptr;
|
||||
void const** ptr_c = nullptr;
|
||||
void** ptr_d = nullptr;
|
||||
|
||||
float const** alpha_scale_ptr_array = nullptr;
|
||||
|
||||
uint8_t* gemm_workspace = nullptr;
|
||||
size_t gemm_workspace_size = 0;
|
||||
|
||||
static auto workspaceBuffers(int num_experts)
|
||||
{
|
||||
size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts;
|
||||
size_t stride_a_size = sizeof(StrideA) * num_experts;
|
||||
size_t stride_b_size = sizeof(StrideB) * num_experts;
|
||||
size_t stride_c_size = sizeof(StrideC) * num_experts;
|
||||
size_t stride_d_size = sizeof(StrideD) * num_experts;
|
||||
|
||||
size_t ptr_buf_size = sizeof(void*) * num_experts;
|
||||
size_t scale_buf_size = sizeof(float**) * num_experts;
|
||||
|
||||
return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size,
|
||||
ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size};
|
||||
}
|
||||
|
||||
static size_t workspaceSize(int num_experts)
|
||||
{
|
||||
auto buffers = workspaceBuffers(num_experts);
|
||||
return tensorrt_llm::common::calculateTotalWorkspaceSize(buffers.data(), buffers.size());
|
||||
}
|
||||
|
||||
void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size)
|
||||
{
|
||||
auto buffers = workspaceBuffers(num_experts);
|
||||
std::array<int8_t*, 10> pointers{};
|
||||
TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers");
|
||||
for (int i = 0; i < buffers.size(); i++)
|
||||
{
|
||||
pointers[i] = start_ptr;
|
||||
start_ptr = tensorrt_llm::common::nextWorkspacePtr(start_ptr, buffers[i]);
|
||||
}
|
||||
|
||||
shape_info.num_groups = num_experts;
|
||||
shape_info.problem_shapes = reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(pointers[0]);
|
||||
shape_info.host_problem_shapes = nullptr;
|
||||
stride_a = reinterpret_cast<StrideA*>(pointers[1]);
|
||||
stride_b = reinterpret_cast<StrideB*>(pointers[2]);
|
||||
stride_c = reinterpret_cast<StrideC*>(pointers[3]);
|
||||
stride_d = reinterpret_cast<StrideD*>(pointers[4]);
|
||||
|
||||
ptr_a = reinterpret_cast<void const**>(pointers[5]);
|
||||
ptr_b = reinterpret_cast<void const**>(pointers[6]);
|
||||
ptr_c = reinterpret_cast<void const**>(pointers[7]);
|
||||
ptr_d = reinterpret_cast<void**>(pointers[8]);
|
||||
|
||||
alpha_scale_ptr_array = reinterpret_cast<float const**>(pointers[9]);
|
||||
|
||||
this->gemm_workspace = reinterpret_cast<uint8_t*>(gemm_workspace);
|
||||
this->gemm_workspace_size = gemm_workspace_size;
|
||||
}
|
||||
|
||||
bool isValid() const
|
||||
{
|
||||
return stride_a != nullptr && ptr_a != nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// Note update moe.py to match
|
||||
enum class ActivationType
|
||||
{
|
||||
@ -53,24 +159,33 @@ public:
|
||||
}
|
||||
|
||||
void moeGemmBiasAct(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
ActivationType activation_type, cudaStream_t stream);
|
||||
int64_t* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n,
|
||||
int64_t gemm_k, int num_experts, ActivationType activation_type, cudaStream_t stream);
|
||||
|
||||
void moeGemm(T const* A, WeightType const* B, T const* weight_scales, T* C, int64_t* total_rows_before_expert,
|
||||
int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream);
|
||||
HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cudaStream_t stream);
|
||||
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs();
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs() const;
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> getHopperConfigs() const;
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> getAmpereConfigs() const;
|
||||
|
||||
bool isHopperSpecialised() const;
|
||||
bool supportsHopperSpecialisation() const;
|
||||
|
||||
size_t calcMaxWorkspaceSize(int num_experts) const;
|
||||
|
||||
private:
|
||||
template <typename EpilogueTag>
|
||||
void dispatchToArch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr);
|
||||
int64_t* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n,
|
||||
int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, cudaStream_t stream,
|
||||
int* occupancy = nullptr);
|
||||
|
||||
template <typename EpilogueTag>
|
||||
void runGemm(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cudaStream_t stream);
|
||||
int64_t* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n,
|
||||
int64_t gemm_k, int num_experts, cudaStream_t stream);
|
||||
|
||||
private:
|
||||
int sm_;
|
||||
|
||||
@ -0,0 +1,25 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
#ifdef ENABLE_FP8
|
||||
template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3>;
|
||||
// template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>;
|
||||
#endif
|
||||
} // namespace tensorrt_llm
|
||||
@ -24,6 +24,20 @@
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass_extensions/compute_occupancy.h"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
||||
@ -35,7 +49,14 @@
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
|
||||
#include "moe_gemm_kernels_template_sm90.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <math.h>
|
||||
@ -43,6 +64,8 @@
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels::cutlass_kernels
|
||||
{
|
||||
|
||||
// ============================= Variable batched Gemm things ===========================
|
||||
template <typename T, typename WeightType, typename arch, typename EpilogueTag, typename ThreadblockShape,
|
||||
@ -66,27 +89,12 @@ void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, T const* weig
|
||||
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
|
||||
"");
|
||||
|
||||
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
|
||||
using ElementType_ =
|
||||
typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
|
||||
#ifdef ENABLE_BF16
|
||||
using ElementType =
|
||||
typename cutlass::platform::conditional<cutlass::platform::is_same<ElementType_, __nv_bfloat16>::value,
|
||||
cutlass::bfloat16_t, ElementType_>::type;
|
||||
#else
|
||||
using ElementType = ElementType_;
|
||||
#endif
|
||||
static_assert(!cutlass::platform::is_same<arch, cutlass::arch::Sm90>::value,
|
||||
"Sm90 architecture should use specialised kernels");
|
||||
|
||||
using CutlassWeightType_ =
|
||||
typename cutlass::platform::conditional<cutlass::platform::is_same<WeightType, half>::value, cutlass::half_t,
|
||||
WeightType>::type;
|
||||
#ifdef ENABLE_BF16
|
||||
using CutlassWeightType =
|
||||
typename cutlass::platform::conditional<cutlass::platform::is_same<CutlassWeightType_, __nv_bfloat16>::value,
|
||||
cutlass::bfloat16_t, CutlassWeightType_>::type;
|
||||
#else
|
||||
using CutlassWeightType = CutlassWeightType_;
|
||||
#endif
|
||||
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
|
||||
using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
|
||||
using CutlassWeightType = typename TllmToCutlassTypeAdapter<WeightType>::type;
|
||||
|
||||
// We need separate config for each architecture since we will target different tensorcore instructions. For float,
|
||||
// we do not target TCs.
|
||||
@ -147,50 +155,29 @@ void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, T const* weig
|
||||
"Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename arch, typename EpilogueTag, typename ThreadblockShape,
|
||||
typename WarpShape, int Stages, typename Enable = void>
|
||||
struct dispatch_stages
|
||||
{
|
||||
static void dispatch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
{
|
||||
TLLM_THROW("Cutlass fpA_intB gemm. Not instantiated for arch %d with stages set to %d",
|
||||
arch::kMinComputeCapability, Stages);
|
||||
}
|
||||
};
|
||||
} // namespace kernels::cutlass_kernels
|
||||
|
||||
template <typename T, typename WeightType, typename arch, typename EpilogueTag, typename ThreadblockShape,
|
||||
typename WarpShape>
|
||||
struct dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>
|
||||
template <typename T, typename WeightType, typename Arch, typename EpilogueTag, typename ThreadblockShape,
|
||||
typename WarpShape, int Stages>
|
||||
static void dispatch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
{
|
||||
static void dispatch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
static_assert(!std::is_same_v<Arch, cutlass::arch::Sm90>, "Use TMA specialised functions for arch SM90");
|
||||
constexpr bool isFp8 = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
|
||||
if constexpr ((Stages == 2 || Arch::kMinComputeCapability >= 80) && !isFp8)
|
||||
{
|
||||
genericMoeGemmKernelLauncher<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B,
|
||||
weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config,
|
||||
multi_processor_count, stream, occupancy);
|
||||
kernels::cutlass_kernels::genericMoeGemmKernelLauncher<T, WeightType, Arch, EpilogueTag, ThreadblockShape,
|
||||
WarpShape, Stages>(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k,
|
||||
num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename WeightType, typename EpilogueTag, typename ThreadblockShape, typename WarpShape,
|
||||
int Stages>
|
||||
struct dispatch_stages<T, WeightType, cutlass::arch::Sm80, EpilogueTag, ThreadblockShape, WarpShape, Stages,
|
||||
typename std::enable_if<(Stages > 2)>::type>
|
||||
{
|
||||
static void dispatch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
else
|
||||
{
|
||||
genericMoeGemmKernelLauncher<T, WeightType, cutlass::arch::Sm80, EpilogueTag, ThreadblockShape, WarpShape,
|
||||
Stages>(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts,
|
||||
gemm_config, multi_processor_count, stream, occupancy);
|
||||
TLLM_THROW(
|
||||
"Cutlass gemm. Not instantiated for arch %d with stages set to %d", Arch::kMinComputeCapability, Stages);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename arch, typename EpilogueTag, typename ThreadblockShape,
|
||||
typename WarpShape>
|
||||
@ -202,19 +189,19 @@ void dispatchGemmConfig(T const* A, WeightType const* B, T const* weight_scales,
|
||||
switch (gemm_config.stages)
|
||||
{
|
||||
case 2:
|
||||
using DispatcherStages2 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>;
|
||||
DispatcherStages2::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k,
|
||||
num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
dispatch<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B, weight_scales, biases, C,
|
||||
total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream,
|
||||
occupancy);
|
||||
break;
|
||||
case 3:
|
||||
using DispatcherStages3 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 3>;
|
||||
DispatcherStages3::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k,
|
||||
num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
dispatch<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 3>(A, B, weight_scales, biases, C,
|
||||
total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream,
|
||||
occupancy);
|
||||
break;
|
||||
case 4:
|
||||
using DispatcherStages4 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 4>;
|
||||
DispatcherStages4::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k,
|
||||
num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
dispatch<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 4>(A, B, weight_scales, biases, C,
|
||||
total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream,
|
||||
occupancy);
|
||||
break;
|
||||
default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break;
|
||||
}
|
||||
@ -226,7 +213,7 @@ template <typename T, typename WeightType, typename arch, typename EpilogueTag,
|
||||
typename std::enable_if<!std::is_same<T, float>::value && std::is_same<T, WeightType>::value>::type* = nullptr>
|
||||
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
{
|
||||
switch (gemm_config.tile_config)
|
||||
@ -279,7 +266,7 @@ template <typename T, typename WeightType, typename arch, typename EpilogueTag,
|
||||
typename std::enable_if<!std::is_same<T, float>::value && !std::is_same<T, WeightType>::value>::type* = nullptr>
|
||||
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
{
|
||||
switch (gemm_config.tile_config)
|
||||
@ -330,7 +317,7 @@ template <typename T, typename WeightType, typename arch, typename EpilogueTag,
|
||||
typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
|
||||
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
{
|
||||
switch (gemm_config.tile_config)
|
||||
@ -349,15 +336,78 @@ void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_s
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType>::getConfigs()
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType>::getConfigs() const
|
||||
{
|
||||
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
|
||||
static constexpr bool only_simt_configs = std::is_same<T, float>::value;
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> candidate_configs
|
||||
= kernels::cutlass_kernels::get_candidate_configs(sm_, is_weight_only, only_simt_configs);
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> candidate_configs = getAmpereConfigs();
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> hopper_configs = getHopperConfigs();
|
||||
std::copy(hopper_configs.begin(), hopper_configs.end(), std::back_inserter(candidate_configs));
|
||||
|
||||
return candidate_configs;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType>::getAmpereConfigs() const
|
||||
{
|
||||
using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
|
||||
static constexpr auto weight_only_flag
|
||||
= std::is_same<T, WeightType>::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY;
|
||||
static constexpr auto simt_only_flag
|
||||
= std::is_same<T, float>::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE;
|
||||
int const max_split_k = 1;
|
||||
int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM;
|
||||
int const enable_hopper = CutlassGemmConfig::NONE;
|
||||
|
||||
auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
|
||||
weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper);
|
||||
|
||||
if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType>())
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> ampere_configs
|
||||
= kernels::cutlass_kernels::get_candidate_configs(sm_, max_split_k, config_type_param);
|
||||
return ampere_configs;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType>::getHopperConfigs() const
|
||||
{
|
||||
using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
|
||||
static constexpr auto weight_only_flag
|
||||
= std::is_same<T, WeightType>::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY;
|
||||
static constexpr auto simt_only_flag
|
||||
= std::is_same<T, float>::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE;
|
||||
int const max_split_k = 1;
|
||||
int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM;
|
||||
int const enable_hopper = CutlassGemmConfig::HOPPER;
|
||||
|
||||
auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
|
||||
weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper);
|
||||
|
||||
if (!kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> hopper_configs
|
||||
= kernels::cutlass_kernels::get_candidate_configs(sm_, max_split_k, config_type_param);
|
||||
return hopper_configs;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
bool MoeGemmRunner<T, WeightType>::isHopperSpecialised() const
|
||||
{
|
||||
bool config_is_sm90 = best_config_ && best_config_->is_sm90;
|
||||
return supportsHopperSpecialisation() && config_is_sm90;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
bool MoeGemmRunner<T, WeightType>::supportsHopperSpecialisation() const
|
||||
{
|
||||
return sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>();
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
MoeGemmRunner<T, WeightType>::MoeGemmRunner()
|
||||
{
|
||||
@ -371,33 +421,72 @@ MoeGemmRunner<T, WeightType>::MoeGemmRunner()
|
||||
template <typename T, typename WeightType>
|
||||
template <typename EpilogueTag>
|
||||
void MoeGemmRunner<T, WeightType>::dispatchToArch<EpilogueTag>(T const* A, WeightType const* B, T const* weight_scales,
|
||||
T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
|
||||
int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy)
|
||||
T const* biases, T* C, int64_t* total_rows_before_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
|
||||
int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config,
|
||||
cudaStream_t stream, int* occupancy)
|
||||
{
|
||||
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
sm_ == 90 || !hopper_input.isValid(), "Hopper input information is set for non specialised implementation");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
sm_ == 90 || !gemm_config.is_sm90, "Hopper configuration provided for non-Hopper architecture");
|
||||
|
||||
if (sm_ >= 70 && sm_ < 75)
|
||||
{
|
||||
dispatchMoeGemmToCutlass<T, WeightType, cutlass::arch::Sm70, EpilogueTag>(A, B, weight_scales, biases, C,
|
||||
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_,
|
||||
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_,
|
||||
stream, occupancy);
|
||||
}
|
||||
else if (sm_ >= 75 && sm_ < 80)
|
||||
{
|
||||
dispatchMoeGemmToCutlass<T, WeightType, cutlass::arch::Sm75, EpilogueTag>(A, B, weight_scales, biases, C,
|
||||
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_,
|
||||
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_,
|
||||
stream, occupancy);
|
||||
}
|
||||
else if (sm_ >= 80 && sm_ < 90)
|
||||
{
|
||||
dispatchMoeGemmToCutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(A, B, weight_scales, biases, C,
|
||||
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_,
|
||||
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_,
|
||||
stream, occupancy);
|
||||
}
|
||||
else if (sm_ >= 90)
|
||||
{
|
||||
// TODO Update the arch to Sm90 once CUTLASS hopper specialisations are available
|
||||
dispatchMoeGemmToCutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(A, B, weight_scales, biases, C,
|
||||
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_,
|
||||
stream, occupancy);
|
||||
if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>())
|
||||
{
|
||||
// We allow both SM90 and SM80 configurations to coexist because for some cases with small numbers of tokens
|
||||
// SM80 is faster. We check here to see which is selected
|
||||
if (gemm_config.is_sm90)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(biases != nullptr || hopper_input.ptr_c == nullptr,
|
||||
"Input biases and hopper input disagree if bias is enabled");
|
||||
TLLM_CHECK_WITH_INFO(hopper_input.isValid(), "Calling SM90 configuration with invalid hopper config");
|
||||
|
||||
dispatchMoeGemmSelectTileShapeSM90<T, WeightType, EpilogueTag>(
|
||||
hopper_input, num_experts, gemm_config, multi_processor_count_, stream, occupancy, nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fallthrough to SM80 impl below
|
||||
}
|
||||
|
||||
// Do Ampere case instead
|
||||
if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType, EpilogueTag>())
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(!hopper_input.isValid(),
|
||||
"Non-specialised Hopper implementation is being rerouted to fallback implementation so input "
|
||||
"information is not required");
|
||||
TLLM_CHECK_WITH_INFO(!gemm_config.is_sm90,
|
||||
"GEMM config is for SM90 configuration, but this configuration is not valid for Hppper");
|
||||
|
||||
dispatchMoeGemmToCutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(A, B, weight_scales, biases, C,
|
||||
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_,
|
||||
stream, occupancy);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Should only hit by FP8 configs during GEMM profiling pass. Never at runtime
|
||||
TLLM_THROW("Configuration expects SM80 but configuration is not supported by SM80 kernels");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -405,59 +494,74 @@ void MoeGemmRunner<T, WeightType>::dispatchToArch<EpilogueTag>(T const* A, Weigh
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
size_t MoeGemmRunner<T, WeightType>::calcMaxWorkspaceSize(int num_experts) const
|
||||
{
|
||||
if (!supportsHopperSpecialisation())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
TLLM_CHECK_WITH_INFO((kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>()),
|
||||
"Configuration is specialised for Hopper but not supported");
|
||||
if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
|
||||
{
|
||||
auto configs = getHopperConfigs();
|
||||
size_t max_size = 0;
|
||||
bool has_config = false;
|
||||
for (auto conf : configs)
|
||||
{
|
||||
try
|
||||
{
|
||||
size_t size = calcMaxWorkspaceSizeSM90<T, WeightType>(num_experts, conf, multi_processor_count_);
|
||||
max_size = std::max(max_size, size);
|
||||
has_config = true;
|
||||
}
|
||||
catch (tensorrt_llm::common::TllmException const& e)
|
||||
{
|
||||
TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size");
|
||||
}
|
||||
}
|
||||
TLLM_CHECK_WITH_INFO(has_config, "Could not find valid config when calculating workspace size");
|
||||
return max_size;
|
||||
}
|
||||
|
||||
assert(false); // Unreachable
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
template <typename EpilogueTag>
|
||||
void MoeGemmRunner<T, WeightType>::runGemm<EpilogueTag>(T const* A, WeightType const* B, T const* weight_scales,
|
||||
T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
|
||||
int num_experts, cudaStream_t stream)
|
||||
T const* biases, T* C, int64_t* total_rows_before_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
|
||||
int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream)
|
||||
{
|
||||
auto chosen_conf = this->best_config_;
|
||||
if (!chosen_conf)
|
||||
{
|
||||
auto candidate_configs = getConfigs();
|
||||
std::vector<int> occupancies(candidate_configs.size());
|
||||
|
||||
for (size_t ii = 0; ii < candidate_configs.size(); ++ii)
|
||||
{
|
||||
dispatchToArch<EpilogueTag>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n,
|
||||
gemm_k, num_experts, candidate_configs[ii], stream, &occupancies[ii]);
|
||||
}
|
||||
|
||||
static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs.
|
||||
static constexpr int split_k_limit = 1; // MoE GEMM does not support split-k.
|
||||
|
||||
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
|
||||
chosen_conf = kernels::cutlass_kernels::estimate_best_config_from_occupancies(candidate_configs, occupancies,
|
||||
total_rows, gemm_n, gemm_k, num_experts, split_k_limit, workspace_bytes, multi_processor_count_,
|
||||
is_weight_only);
|
||||
}
|
||||
assert(chosen_conf);
|
||||
dispatchToArch<EpilogueTag>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k,
|
||||
num_experts, *chosen_conf, stream);
|
||||
TLLM_CHECK_WITH_INFO(this->best_config_, "No MOE GEMM config set at runtime");
|
||||
auto chosen_conf = *this->best_config_;
|
||||
dispatchToArch<EpilogueTag>(A, B, weight_scales, biases, C, total_rows_before_expert, hopper_input, total_rows,
|
||||
gemm_n, gemm_k, num_experts, chosen_conf, stream);
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
void MoeGemmRunner<T, WeightType>::moeGemmBiasAct(T const* A, WeightType const* B, T const* weight_scales,
|
||||
T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
|
||||
int num_experts, ActivationType activation_type, cudaStream_t stream)
|
||||
T const* biases, T* C, int64_t* total_rows_before_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
|
||||
int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, cudaStream_t stream)
|
||||
{
|
||||
switch (activation_type)
|
||||
{
|
||||
case ActivationType::Relu:
|
||||
runGemm<cutlass_extensions::EpilogueOpDefaultReLU>(
|
||||
A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream);
|
||||
runGemm<cutlass_extensions::EpilogueOpDefaultReLU>(A, B, weight_scales, biases, C, total_rows_before_expert,
|
||||
hopper_input, total_rows, gemm_n, gemm_k, num_experts, stream);
|
||||
break;
|
||||
case ActivationType::Gelu:
|
||||
runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(
|
||||
A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream);
|
||||
runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(A, B, weight_scales, biases, C, total_rows_before_expert,
|
||||
hopper_input, total_rows, gemm_n, gemm_k, num_experts, stream);
|
||||
break;
|
||||
case ActivationType::Silu:
|
||||
runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(
|
||||
A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream);
|
||||
runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(A, B, weight_scales, biases, C, total_rows_before_expert,
|
||||
hopper_input, total_rows, gemm_n, gemm_k, num_experts, stream);
|
||||
break;
|
||||
case ActivationType::Identity:
|
||||
runGemm<cutlass_extensions::EpilogueOpDefault>(
|
||||
A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream);
|
||||
runGemm<cutlass_extensions::EpilogueOpDefault>(A, B, weight_scales, biases, C, total_rows_before_expert,
|
||||
hopper_input, total_rows, gemm_n, gemm_k, num_experts, stream);
|
||||
break;
|
||||
case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break;
|
||||
default: TLLM_THROW("Invalid activation type."); break;
|
||||
@ -466,11 +570,11 @@ void MoeGemmRunner<T, WeightType>::moeGemmBiasAct(T const* A, WeightType const*
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
void MoeGemmRunner<T, WeightType>::moeGemm(T const* A, WeightType const* B, T const* weight_scales, T* C,
|
||||
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cudaStream_t stream)
|
||||
int64_t* total_rows_before_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n,
|
||||
int64_t gemm_k, int num_experts, cudaStream_t stream)
|
||||
{
|
||||
runGemm<cutlass_extensions::EpilogueOpDefault>(
|
||||
A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream);
|
||||
runGemm<cutlass_extensions::EpilogueOpDefault>(A, B, weight_scales, nullptr, C, total_rows_before_expert,
|
||||
hopper_input, total_rows, gemm_n, gemm_k, num_experts, stream);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -0,0 +1,214 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Ignore CUTLASS warnings about type punning
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass_extensions/compute_occupancy.h"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
||||
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
|
||||
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <math.h>
|
||||
#include <sstream>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
|
||||
template <typename T, typename WeightType, typename EpilogueTag, typename TileShape, typename ClusterShape>
|
||||
void dispatchMoeGemmSelectBiasSM90(HopperGroupedGemmInput hopper_input, int num_experts, int multi_processor_count,
|
||||
cudaStream_t stream, int* occupancy, size_t* workspace_size)
|
||||
{
|
||||
static_assert(kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>(),
|
||||
"Invalid hopper configuration invoked, fallback to Sm80");
|
||||
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
workspace_size || hopper_input.isValid(), "Hopper specialisation is missing additional input information");
|
||||
|
||||
// auto func = hopper_input.ptr_c ?
|
||||
// kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper<T, WeightType,
|
||||
// cutlass::arch::Sm90, EpilogueTag, true>
|
||||
// :
|
||||
// kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper<T,
|
||||
// WeightType,
|
||||
// cutlass::arch::Sm90, EpilogueTag, false>;
|
||||
// TODO(dastokes) Re-enable bias when CUTLASS supports it
|
||||
auto func = kernels::cutlass_kernels::sm90_generic_moe_gemm_kernelLauncher<T, WeightType, EpilogueTag, TileShape,
|
||||
ClusterShape, false>;
|
||||
func(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size);
|
||||
}
|
||||
|
||||
/*
|
||||
1x1x1 cluster shape is are supported for any tile shape.
|
||||
|
||||
2x1x1 cluster shape is only supported for when the M tile is at least 128.
|
||||
|
||||
1x2x1 cluster shape is only supported when the N tile is at least 128.
|
||||
|
||||
2x2x1 cluster shape is only supported when both the M and N tiles are at least 128.
|
||||
|
||||
We make the above restrictions are to improve compilation speed in TRT-LLM by pruning kernels
|
||||
that may not be very useful in practice.
|
||||
*/
|
||||
template <typename CTAShape, typename ClusterShape>
|
||||
constexpr bool are_tile_shapes_supported()
|
||||
{
|
||||
using namespace cute;
|
||||
constexpr int cta_m = get<0>(CTAShape{});
|
||||
constexpr int cta_n = get<1>(CTAShape{});
|
||||
constexpr int cga_m = get<0>(ClusterShape{});
|
||||
constexpr int cga_n = get<1>(ClusterShape{});
|
||||
|
||||
if constexpr (cga_m == _1{} && cga_n == _1{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename EpilogueTag, typename TileShape>
|
||||
void dispatchMoeGemmSelectClusterShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
|
||||
size_t* workspace_size)
|
||||
{
|
||||
using namespace cute;
|
||||
switch (gemm_config.cluster_shape)
|
||||
{
|
||||
#define SHAPE_CASE(M, N, K) \
|
||||
case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K: \
|
||||
{ \
|
||||
using ClusterShape = Shape<_##M, _##N, _##K>; \
|
||||
if constexpr (are_tile_shapes_supported<TileShape, ClusterShape>()) \
|
||||
{ \
|
||||
dispatchMoeGemmSelectBiasSM90<T, WeightType, EpilogueTag, TileShape, ClusterShape>( \
|
||||
hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); \
|
||||
break; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
TLLM_THROW("Unsupported tile and cluster shape combination"); \
|
||||
} \
|
||||
}
|
||||
|
||||
SHAPE_CASE(1, 1, 1)
|
||||
SHAPE_CASE(1, 2, 1)
|
||||
|
||||
SHAPE_CASE(2, 1, 1)
|
||||
SHAPE_CASE(2, 2, 1)
|
||||
|
||||
#undef SHAPE_CASE
|
||||
default: TLLM_THROW("Unsupported config for MoE gemm.");
|
||||
}
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
template <typename T, typename WeightType, typename EpilogueTag>
|
||||
void dispatchMoeGemmSelectTileShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
|
||||
size_t* workspace_size)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
switch (gemm_config.tile_config_sm90)
|
||||
{
|
||||
#define SHAPE_CASE(M, N, K) \
|
||||
case cutlass_extensions::CutlassTileConfigSM90::CtaShape##M##x##N##x##K##B: \
|
||||
{ \
|
||||
constexpr int KtileBytes = K / sizeof(T); \
|
||||
using KTileDim = Int<KtileBytes>; \
|
||||
using TileShape = Shape<_##M, _##N, KTileDim>; \
|
||||
dispatchMoeGemmSelectClusterShapeSM90<T, WeightType, EpilogueTag, TileShape>( \
|
||||
hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, workspace_size); \
|
||||
break; \
|
||||
}
|
||||
|
||||
SHAPE_CASE(128, 16, 128)
|
||||
SHAPE_CASE(128, 32, 128)
|
||||
SHAPE_CASE(128, 64, 128)
|
||||
SHAPE_CASE(128, 128, 128)
|
||||
SHAPE_CASE(128, 256, 128)
|
||||
|
||||
#undef SHAPE_CASE
|
||||
case cutlass_extensions::CutlassTileConfigSM90::Undefined: TLLM_THROW("GEMM config undefined."); break;
|
||||
case cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic:
|
||||
TLLM_THROW("GEMM config should have already been set by heuristic.");
|
||||
break;
|
||||
default: TLLM_THROW("Unsupported config for MoE gemm."); break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
size_t calcMaxWorkspaceSizeSM90(
|
||||
int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count)
|
||||
{
|
||||
size_t count;
|
||||
// Most of the values are ignored for WS size calculation. We reuse the function to reduce the template bloat
|
||||
dispatchMoeGemmSelectTileShapeSM90<T, WeightType, cutlass_extensions::EpilogueOpDefault>(
|
||||
HopperGroupedGemmInput{}, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count);
|
||||
return count;
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/mma_sm90.h"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
|
||||
namespace tensorrt_llm::kernels::cutlass_kernels
|
||||
{
|
||||
|
||||
// Hopper arch
|
||||
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
|
||||
constexpr bool isValidHopperMOESpecialisation()
|
||||
{
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
return cutlass::platform::is_same<T, WeightType>::value
|
||||
&& cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value;
|
||||
#else
|
||||
return false; // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED is set when Hopper kernels are enabled
|
||||
#endif
|
||||
}
|
||||
|
||||
// Hopper arch
|
||||
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
|
||||
constexpr bool isValidAmpereMOESpecialisation()
|
||||
{
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) and defined(ENABLE_FP8)
|
||||
constexpr bool is_fp8
|
||||
= cutlass::platform::is_same<T, __nv_fp8_e4m3>::value || cutlass::platform::is_same<T, __nv_fp8_e5m2>::value;
|
||||
return !is_fp8;
|
||||
#else
|
||||
return true; // Default to true
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::kernels::cutlass_kernels
|
||||
@ -33,12 +33,14 @@ class TrtLlm_QuantOp(enum.Enum):
|
||||
per_column_scale_only = enum_auto()
|
||||
finegrained_scale_only = enum_auto()
|
||||
finegrained_scale_and_zeros = enum_auto()
|
||||
none = enum_auto()
|
||||
|
||||
|
||||
QuantOpNames = {
|
||||
TrtLlm_QuantOp.per_column_scale_only: "cs",
|
||||
TrtLlm_QuantOp.finegrained_scale_only: "fgs",
|
||||
TrtLlm_QuantOp.finegrained_scale_and_zeros: "fgsz"
|
||||
TrtLlm_QuantOp.finegrained_scale_and_zeros: "fgsz",
|
||||
TrtLlm_QuantOp.none: "noquant"
|
||||
}
|
||||
|
||||
QuantOpTag = {
|
||||
@ -47,7 +49,8 @@ QuantOpTag = {
|
||||
TrtLlm_QuantOp.finegrained_scale_only:
|
||||
"cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY",
|
||||
TrtLlm_QuantOp.finegrained_scale_and_zeros:
|
||||
"cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS"
|
||||
"cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS",
|
||||
TrtLlm_QuantOp.none: "void"
|
||||
}
|
||||
|
||||
################################################################################
|
||||
@ -56,7 +59,8 @@ QuantOpTag = {
|
||||
CudaTypeName = {
|
||||
DataType.e4m3: "__nv_fp8_e4m3",
|
||||
DataType.bf16: "__nv_bfloat16",
|
||||
DataType.f16: "half"
|
||||
DataType.f16: "half",
|
||||
DataType.f32: "float"
|
||||
}
|
||||
|
||||
|
||||
@ -64,22 +68,10 @@ CudaTypeName = {
|
||||
# A data structure holding all info to instantiate gemm launchers in TRT LLM.
|
||||
class TrtLlm_GemmLauncher:
|
||||
|
||||
def __init__(self,
|
||||
gemm_kind,
|
||||
arch,
|
||||
act_type,
|
||||
weight_type,
|
||||
scalezero_type,
|
||||
bias_type,
|
||||
output_type,
|
||||
quant_op,
|
||||
epi_tag,
|
||||
cta_shape,
|
||||
warp_shape,
|
||||
stages,
|
||||
cga_shape=None,
|
||||
mainloop_schedule=None,
|
||||
epi_schedule=None):
|
||||
def __init__(self, gemm_kind, arch, act_type, weight_type, scalezero_type,
|
||||
bias_type, output_type, quant_op, epi_tag, cta_shape,
|
||||
warp_shape, stages, cga_shape, mainloop_schedule,
|
||||
epi_schedule):
|
||||
self.gemm_kind = gemm_kind
|
||||
self.arch = arch
|
||||
self.act_type = act_type
|
||||
@ -124,9 +116,7 @@ def tuple_to_cute_shape(shape):
|
||||
|
||||
|
||||
def instantiate_operation(operation):
|
||||
|
||||
act_tag = CudaTypeName[operation.act_type]
|
||||
weight_tag = DataTypeTag[operation.weight_type]
|
||||
scale_zero_tag = CudaTypeName[operation.scalezero_type]
|
||||
bias_tag = CudaTypeName[operation.bias_type]
|
||||
out_tag = CudaTypeName[operation.output_type]
|
||||
@ -138,18 +128,21 @@ def instantiate_operation(operation):
|
||||
cute_cga_shape = tuple_to_cute_shape(operation.cga_shape)
|
||||
|
||||
kernel_sched = KernelScheduleTag[operation.mainloop_schedule]
|
||||
|
||||
# Here, we must append MixedInput depending on the schedule, since we know the types are different.
|
||||
# It is a work around since the CUTLASS library did not have the MixedInput schedules at the time of writing.
|
||||
if operation.mainloop_schedule in [
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.TmaWarpSpecializedPingpong,
|
||||
KernelScheduleType.TmaWarpSpecialized
|
||||
]:
|
||||
kernel_sched += "MixedInput"
|
||||
epi_sched = EpilogueScheduleTag[operation.epi_schedule]
|
||||
|
||||
instantiation = f"""
|
||||
if operation.gemm_kind == GemmKind.Gemm:
|
||||
if operation.mainloop_schedule in [
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.TmaWarpSpecializedPingpong,
|
||||
KernelScheduleType.TmaWarpSpecialized
|
||||
] and DataTypeSize[operation.act_type] != DataTypeSize[
|
||||
operation.weight_type]:
|
||||
# Here, we must append MixedInput depending on the schedule, since we know the types are different.
|
||||
# It is a work around since the CUTLASS library did not have the MixedInput schedules at the time of writing.
|
||||
kernel_sched += "MixedInput"
|
||||
|
||||
weight_tag = DataTypeTag[operation.weight_type]
|
||||
instantiation = f"""
|
||||
template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag},
|
||||
{quant_op}, {epi_tag},
|
||||
{cute_cta_shape}, {cute_cga_shape},
|
||||
@ -157,12 +150,28 @@ template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {s
|
||||
const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float,
|
||||
{out_tag}*, int, int, int, const int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
|
||||
);
|
||||
"""
|
||||
elif operation.gemm_kind == GemmKind.Grouped:
|
||||
# Similar to MixedInput above, we must modify the tags for grouped gemm as CUTLASS library does not have the updated schedules
|
||||
assert operation.mainloop_schedule in [
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum
|
||||
]
|
||||
assert operation.epi_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
|
||||
kernel_sched.replace("::Kernel", "::KernelGrouped")
|
||||
epi_sched += "Grouped"
|
||||
|
||||
weight_tag = CudaTypeName[operation.weight_type]
|
||||
|
||||
instantiation = f"""
|
||||
template void sm90_generic_moe_gemm_kernelLauncher<{act_tag}, {weight_tag},
|
||||
{epi_tag}, {cute_cta_shape}, {cute_cga_shape}, false>
|
||||
(HopperGroupedGemmInput, int, int, cudaStream_t, int*, size_t*);
|
||||
"""
|
||||
return instantiation
|
||||
|
||||
|
||||
def get_file_content(launcher_inl_files, operations):
|
||||
|
||||
include_list = list()
|
||||
for file in launcher_inl_files:
|
||||
include_list.append(f"#include \"{file}\"")
|
||||
@ -191,11 +200,12 @@ namespace cutlass_kernels
|
||||
|
||||
|
||||
def write_file(launcher_inl_files, operations, output_file):
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
with open(output_file, mode="w") as f:
|
||||
f.write(get_file_content(launcher_inl_files, operations))
|
||||
|
||||
|
||||
def is_op_valid(op):
|
||||
def is_gemm_op_valid(op):
|
||||
tile_m, tile_n, _ = op.cta_shape
|
||||
cga_m, cga_n, _ = op.cga_shape
|
||||
|
||||
@ -214,14 +224,41 @@ def is_op_valid(op):
|
||||
return False
|
||||
|
||||
|
||||
def is_grouped_gemm_op_valid(op):
|
||||
if not is_gemm_op_valid(op):
|
||||
return False
|
||||
|
||||
if op.epi_tag != TrtLlm_EpilogueTag.epilogue_op_default:
|
||||
return False
|
||||
|
||||
if op.epi_schedule != EpilogueScheduleType.NoSmemWarpSpecialized:
|
||||
return False
|
||||
|
||||
if op.mainloop_schedule not in [
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum
|
||||
]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def is_op_valid(op):
|
||||
if op.gemm_kind == GemmKind.Gemm:
|
||||
return is_gemm_op_valid(op)
|
||||
if op.gemm_kind == GemmKind.Grouped:
|
||||
return is_grouped_gemm_op_valid(op)
|
||||
|
||||
|
||||
################################################################################
|
||||
def generate_sm90_operations():
|
||||
def generate_sm90_mixed_gemm_operations():
|
||||
arch = 90
|
||||
|
||||
# For legacy reasons, we use unsigned types for fp16 / bf16 activations.
|
||||
# For legacy reasons, we use unsigned types for the weights. The instanitated template
|
||||
# will remap those back to the signed type.
|
||||
# Takes the form (activation_type, weight_type, scalezero_type, bias_type, output_type)
|
||||
supported_dtypes = [
|
||||
(DataType.e4m3, DataType.s4, DataType.f16, DataType.f16, DataType.f16),
|
||||
(DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
|
||||
(DataType.f16, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
|
||||
(DataType.bf16, DataType.u4, DataType.bf16, DataType.bf16,
|
||||
DataType.bf16),
|
||||
@ -260,11 +297,56 @@ def generate_sm90_operations():
|
||||
mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative if use_coop else KernelScheduleType.TmaWarpSpecializedPingpong
|
||||
epi_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative if use_coop else EpilogueScheduleType.TmaWarpSpecialized
|
||||
|
||||
operation = TrtLlm_GemmLauncher(GemmKind.Gemm, arch, *dtype_combo, quant_op, epi_tag, cta_shape_mnk, \
|
||||
warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule)
|
||||
fpA_intB_operation = TrtLlm_GemmLauncher(GemmKind.Gemm, arch, *dtype_combo, quant_op, epi_tag, cta_shape_mnk, \
|
||||
warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule)
|
||||
|
||||
if is_op_valid(operation):
|
||||
operations.append(operation)
|
||||
if is_op_valid(fpA_intB_operation):
|
||||
operations.append(fpA_intB_operation)
|
||||
|
||||
return operations
|
||||
|
||||
|
||||
def generate_sm90_grouped_gemm_operations():
|
||||
arch = 90
|
||||
supported_dtypes = [
|
||||
DataType.f16, DataType.bf16, DataType.f32, DataType.e4m3
|
||||
]
|
||||
quant_ops = [TrtLlm_QuantOp.none]
|
||||
epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
|
||||
M_TILES = [128] # Currently M tile must be 128 for Grouped GEMM
|
||||
N_TILES = [16, 32, 64, 128, 256]
|
||||
cta_shapes_mn = product(M_TILES, N_TILES)
|
||||
|
||||
warp_shape = [0, 0, 0] # ignored except for naming
|
||||
stages = 0 # auto
|
||||
|
||||
cga_shapes = product([1, 2], [1, 2], [1])
|
||||
|
||||
partial_args = product(supported_dtypes, quant_ops, epi_tags, cta_shapes_mn,
|
||||
cga_shapes)
|
||||
|
||||
operations = list()
|
||||
for dtype, quant_op, epi_tag, cta_shape_mn, cga_shape in partial_args:
|
||||
max_k_bits = 128 * 8
|
||||
cta_shape_k = max_k_bits // DataTypeSize[dtype]
|
||||
cta_shape_mnk = cta_shape_mn + (cta_shape_k, )
|
||||
|
||||
mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative if dtype else KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum
|
||||
epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized
|
||||
|
||||
moe_gemm_operation = TrtLlm_GemmLauncher(
|
||||
GemmKind.Grouped, arch, dtype, dtype, dtype, dtype, dtype, quant_op,
|
||||
epi_tag, cta_shape_mnk, warp_shape, stages, cga_shape,
|
||||
mainloop_schedule, epi_schedule)
|
||||
|
||||
if is_op_valid(moe_gemm_operation):
|
||||
operations.append(moe_gemm_operation)
|
||||
return operations
|
||||
|
||||
|
||||
def generate_sm90_operations():
|
||||
operations = generate_sm90_mixed_gemm_operations()
|
||||
operations.extend(generate_sm90_grouped_gemm_operations())
|
||||
return operations
|
||||
|
||||
|
||||
@ -284,7 +366,10 @@ if __name__ == "__main__":
|
||||
# Get the absolute path of the provided directory
|
||||
output_dir = os.path.abspath(args.output_dir)
|
||||
|
||||
hopper_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"
|
||||
fpA_intB_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"
|
||||
moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl"
|
||||
|
||||
inl_map = {GemmKind.Gemm: [fpA_intB_inl], GemmKind.Grouped: [moe_gemm_inl]}
|
||||
|
||||
# The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads.
|
||||
# Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve.
|
||||
@ -298,7 +383,9 @@ if __name__ == "__main__":
|
||||
|
||||
file_counter = 1
|
||||
for key, value in op_groups.items():
|
||||
gemm_kind, _, _ = key
|
||||
out_file = os.path.join(
|
||||
output_dir, f"cutlass_kernel_file_{file_counter}.generated.cu")
|
||||
write_file([hopper_inl], value, out_file)
|
||||
output_dir, GemmKindNames[gemm_kind],
|
||||
f"cutlass_kernel_file_{file_counter}.generated.cu")
|
||||
write_file(inl_map[gemm_kind], value, out_file)
|
||||
file_counter += 1
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user