Update TensorRT-LLM (#1492)

* Update TensorRT-LLM

---------

Co-authored-by: Loki <lokravi@amazon.com>
Kaiyu Xie 2024-04-24 14:44:22 +08:00 committed by GitHub
parent 71d8d4d3dc
commit 66ef1df492
GPG Key ID: B5690EEEBB952194
319 changed files with 21392 additions and 37330 deletions

2
.gitignore vendored
View File

@ -32,6 +32,8 @@ cpp/.ccache/
tensorrt_llm/libs
tensorrt_llm/bindings.pyi
tensorrt_llm/bindings/*.pyi
*docs/cpp_docs*
*docs/source/_cpp_gen*
# Testing
.coverage.*

2
3rdparty/cutlass vendored

@ -1 +1 @@
Subproject commit a8f2c80db0564c74f4efccac71993b971dfc448b
Subproject commit 7d49e6c7e2f8896c47f586706e67e1fb215529dc

460
README.md
View File

@ -11,7 +11,7 @@ TensorRT-LLM
[![version](https://img.shields.io/badge/release-0.9.0-green)](./setup.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)
[Architecture](./docs/source/architecture.md)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Results](./docs/source/performance.md)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Examples](./examples/)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Documentation](./docs/source/)
[Architecture](./docs/source/architecture/overview.md)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Results](./docs/source/performance/perf-overview.md)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Examples](./examples/)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Documentation](./docs/source/)
---
<div align="left">
@ -29,42 +29,13 @@ TensorRT-LLM
* [2023/10/17] [Large Language Models up to 4x Faster on RTX With TensorRT-LLM for Windows
](https://blogs.nvidia.com/blog/2023/10/17/tensorrt-llm-windows-stable-diffusion-rtx/)
## Table of Contents
- [TensorRT-LLM](#tensorrt-llm)
- [Latest News](#latest-news)
- [Table of Contents](#table-of-contents)
- [TensorRT-LLM Overview](#tensorrt-llm-overview)
- [Installation](#installation)
- [Quick Start](#quick-start)
- [Support Matrix](#support-matrix)
- [Devices](#devices)
- [Precision](#precision)
- [Key Features](#key-features)
- [Models](#models)
- [Performance](#performance)
- [Advanced Topics](#advanced-topics)
- [Quantization](#quantization)
- [In-flight Batching](#in-flight-batching)
- [Attention](#attention)
- [Graph Rewriting](#graph-rewriting)
- [Benchmark](#benchmark)
- [Troubleshooting](#troubleshooting)
- [Release notes](#release-notes)
- [Change Log](#change-log)
- [Versions 0.9.0](#versions-090)
- [For history change log, please see CHANGELOG.md.](#for-history-change-log-please-see-changelogmd)
- [Known Issues](#known-issues)
- [Report Issues](#report-issues)
## TensorRT-LLM Overview
TensorRT-LLM provides users with an easy-to-use Python API to define Large
TensorRT-LLM is an easy-to-use Python API to define Large
Language Models (LLMs) and build
[TensorRT](https://developer.nvidia.com/tensorrt) engines that contain
state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.
TensorRT-LLM also contains components to create Python and C++ runtimes that
TensorRT-LLM contains components to create Python and C++ runtimes that
execute those TensorRT engines. It also includes a
[backend](https://github.com/triton-inference-server/tensorrtllm_backend)
for integration with the
@ -76,8 +47,8 @@ multiple nodes with multiple GPUs (using
and/or
[Pipeline Parallelism](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/nlp/nemo_megatron/parallelisms.html#pipeline-parallelism)).
The Python API of TensorRT-LLM is architectured to look similar to the
[PyTorch](https://pytorch.org) API. It provides users with a
The TensorRT-LLM Python API architecture looks similar to the
[PyTorch](https://pytorch.org) API. It provides a
[functional](./tensorrt_llm/functional.py) module containing functions like
`einsum`, `softmax`, `matmul` or `view`. The [layers](./tensorrt_llm/layers)
module bundles useful building blocks to assemble LLMs; like an `Attention`
@ -86,422 +57,21 @@ like `GPTAttention` or `BertAttention`, can be found in the
[models](./tensorrt_llm/models) module.
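To make that module layout concrete, the following minimal sketch only imports the modules named above and checks that the documented primitives are exposed. It assumes a working `tensorrt_llm` installation and nothing else; it does not build a network or an engine.
```python
# Minimal sketch of the module layout described above (no engine is built).
import tensorrt_llm
from tensorrt_llm import functional, layers, models

# The functional module exposes PyTorch-like primitives.
for name in ("einsum", "softmax", "matmul", "view"):
    print(f"functional.{name} available:", hasattr(functional, name))

# The layers module bundles building blocks such as Attention; the models
# module holds the pre-defined model definitions mentioned below.
print("layers.Attention available:", hasattr(layers, "Attention"))
print("number of symbols in models:", len(dir(models)))
```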
TensorRT-LLM comes with several popular models pre-defined. They can easily be
modified and extended to fit custom needs. See below for a list of supported
[models](#Models).
modified and extended to fit custom needs. Refer to the [Support Matrix](https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html) for a list of supported models.
To maximize performance and reduce memory footprint, TensorRT-LLM allows the
models to be executed using different quantization modes (see
[`examples/gpt`](./examples/gpt) for concrete examples). TensorRT-LLM supports
models to be executed using different quantization modes (refer to
[`support matrix`](https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html#software)). TensorRT-LLM supports
INT4 or INT8 weights (and FP16 activations; a.k.a. INT4/INT8 weight-only) as
well as a complete implementation of the
[SmoothQuant](https://arxiv.org/abs/2211.10438) technique.
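As a hedged illustration of where a quantization mode is expressed, the fragment below mirrors the `quantization` section of the TensorRT-LLM checkpoint config that appears in the benchmark changes further down in this diff; the `W8A16` identifier for INT8 weight-only (and leaving the KV cache unquantized) is an assumption used for illustration, not a prescribed setting.
```python
# Sketch only: a checkpoint-config fragment with an INT8 weight-only
# quantization section, modeled on the PretrainedConfig dict used in the
# benchmark changes further down. Field values are illustrative assumptions.
quantization = {
    "quant_algo": "W8A16",        # INT8 weights, FP16 activations (assumed identifier)
    "kv_cache_quant_algo": None,  # keep the KV cache unquantized
    "group_size": 128,            # only relevant for group-wise schemes such as AWQ/GPTQ
}
config_fragment = {
    "dtype": "float16",
    "quantization": quantization,
}
print(config_fragment)
```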
For a more detailed presentation of the software architecture and the key
concepts used in TensorRT-LLM, we recommend you to read the following
[document](./docs/source/architecture.md).
## Getting Started
## Installation
To get started with TensorRT-LLM, visit our documentation:
After installing the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit),
please run the following commands to install TensorRT-LLM for x86_64 users.
```bash
# Obtain and start the basic docker image environment.
docker run --rm --runtime=nvidia --gpus all --entrypoint /bin/bash -it nvidia/cuda:12.1.0-devel-ubuntu22.04
# Install dependencies; TensorRT-LLM requires Python 3.10
apt-get update && apt-get -y install python3.10 python3-pip openmpi-bin libopenmpi-dev
# Install the latest preview version (corresponding to the main branch) of TensorRT-LLM.
# If you want to install the stable version (corresponding to the release branch), please
# remove the `--pre` option.
pip3 install tensorrt_llm -U --pre --extra-index-url https://pypi.nvidia.com
# Check installation
python3 -c "import tensorrt_llm"
```
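As an optional follow-up to the import check above, the short sketch below prints the installed version so it can be compared against the release badge at the top of this README; the only assumption is that the wheel exposes `__version__`.
```python
# Optional sanity check after `pip3 install tensorrt_llm`.
import tensorrt_llm

# A `--pre` install typically reports a .dev pre-release version string.
print(tensorrt_llm.__version__)
```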
Developers who need the best performance, need to debug, or use the aarch64 architecture
should refer to the instructions for [building from source code](docs/source/build_from_source.md).
For Windows installation, see [`Windows`](windows/README.md).
## Quick Start
Please be sure to complete the [installation steps](#installation) before proceeding with the following steps.
To create a TensorRT engine for an existing model, there are 3 steps:
1. Download pre-trained weights,
2. Build a fully-optimized engine of the model,
3. Deploy the engine, in other words, run the fully-optimized model.
The following sections show how to use TensorRT-LLM to run the
[BLOOM-560m](https://huggingface.co/bigscience/bloom-560m) model.
***0. In the BLOOM folder***
Inside the Docker container, you have to install the requirements:
```bash
pip install -r examples/bloom/requirements.txt
git lfs install
```
***1. Download the model weights from HuggingFace***
From the BLOOM example folder, you must download the weights of the model.
```bash
cd examples/bloom
rm -rf ./bloom/560M
mkdir -p ./bloom/560M && git clone https://huggingface.co/bigscience/bloom-560m ./bloom/560M
```
***2. Build the engine***
```bash
# Single GPU on BLOOM 560M
python convert_checkpoint.py --model_dir ./bloom/560M/ \
--dtype float16 \
--output_dir ./bloom/560M/trt_ckpt/fp16/1-gpu/
# You may need to add trtllm-build to PATH: export PATH=/usr/local/bin:$PATH
trtllm-build --checkpoint_dir ./bloom/560M/trt_ckpt/fp16/1-gpu/ \
--gemm_plugin float16 \
--output_dir ./bloom/560M/trt_engines/fp16/1-gpu/
```
See the BLOOM [example](examples/bloom) for more details and options regarding the `trtllm-build` command.
***3. Run***
The `../summarize.py` script can be used to perform the summarization of articles
from the CNN Daily dataset:
```bash
python ../summarize.py --test_trt_llm \
--hf_model_dir ./bloom/560M/ \
--data_type fp16 \
--engine_dir ./bloom/560M/trt_engines/fp16/1-gpu/
```
More details about the script and how to run the BLOOM model can be found in
the example [folder](examples/bloom). Many more [models](#models) than BLOOM
are implemented in TensorRT-LLM. They can be found in the
[examples](./examples/) directory.
Beyond local execution, you can also use the NVIDIA Triton Inference Server to create a production-ready deployment of your LLM as described in this [blog](https://developer.nvidia.com/blog/optimizing-inference-on-llms-with-tensorrt-llm-now-publicly-available/).
## Support Matrix
TensorRT-LLM optimizes the performance of a range of well-known models on
NVIDIA GPUs. The following sections provide a list of supported GPU
architectures as well as important features implemented in TensorRT-LLM.
### Devices
TensorRT-LLM supports the following architectures:
* [NVIDIA Hopper](https://www.nvidia.com/en-us/data-center/technologies/hopper-architecture/) (SM90), for example, H200, H100, H20
* [NVIDIA Ada Lovelace](https://www.nvidia.com/en-us/geforce/ada-lovelace-architecture/) (SM89), for example, L40S, L20, L4
* [NVIDIA Ampere](https://www.nvidia.com/en-us/data-center/ampere-architecture/) (SM80, SM86), for example, A100, A30, A10G
* [NVIDIA Turing](https://www.nvidia.com/en-us/geforce/turing/) (SM75), for example, T4
* [NVIDIA Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) (SM70 - experimental), for example, V100
It is important to note that TensorRT-LLM is expected to work on all GPUs based on the Volta, Turing, Ampere, Hopper, and Ada Lovelace architectures. Certain limitations may apply.
### Precision
Various numerical precisions are supported in TensorRT-LLM. Support for
some of those numerical features requires specific architectures:
| | FP32 | FP16 | BF16 | FP8 | INT8 | INT4 |
| :------------------ | :--- | :--- | :--- | :--- | :---- | :---- |
| Volta (SM70) | Y | Y | N | N | Y (1) | Y (2) |
| Turing (SM75) | Y | Y | N | N | Y (1) | Y (2) |
| Ampere (SM80, SM86) | Y | Y | Y | N | Y | Y (3) |
| Ada-Lovelace (SM89) | Y | Y | Y | Y | Y | Y |
| Hopper (SM90) | Y | Y | Y | Y | Y | Y |
(1) INT8 SmoothQuant is not supported on SM70 and SM75.<br>
(2) INT4 AWQ and GPTQ are not supported on SM < 80.<br>
(3) INT4 AWQ and GPTQ with FP8 activations require SM >= 89.
In this release of TensorRT-LLM, the support for FP8 and quantized data types
(INT8 or INT4) is not implemented for all the models. See the
[precision](./docs/source/precision.md) document and the
[examples](./examples/.) folder for additional details.
### Key Features
TensorRT-LLM contains examples that implement the following features.
* Multi-head Attention ([MHA](https://arxiv.org/abs/1706.03762))
* Multi-query Attention ([MQA](https://arxiv.org/abs/1911.02150))
* Group-query Attention ([GQA](https://arxiv.org/abs/2307.09288))
* In-flight Batching
* Paged KV Cache for the Attention
* Tensor Parallelism
* Pipeline Parallelism
* INT4/INT8 Weight-Only Quantization (W4A16 & W8A16)
* [SmoothQuant](https://arxiv.org/abs/2211.10438)
* [GPTQ](https://arxiv.org/abs/2210.17323)
* [AWQ](https://arxiv.org/abs/2306.00978)
* [FP8](https://arxiv.org/abs/2209.05433)
* Greedy-search
* Beam-search
* RoPE
In this release of TensorRT-LLM, some of the features are not enabled for all
the models listed in the [examples](examples/.) folder.
### Models
The list of supported models is:
* [Baichuan](examples/baichuan)
* [BART](examples/enc_dec)
* [BERT](examples/bert)
* [Blip2](examples/blip2)
* [BLOOM](examples/bloom)
* [ChatGLM](examples/chatglm)
* [FairSeq NMT](examples/enc_dec/nmt)
* [Falcon](examples/falcon)
* [Flan-T5](examples/enc_dec)
* [GPT](examples/gpt)
* [GPT-J](examples/gptj)
* [GPT-Nemo](examples/gpt)
* [GPT-NeoX](examples/gptneox)
* [InternLM](examples/internlm)
* [LLaMA](examples/llama)
* [LLaMA-v2](examples/llama)
* [Mamba](examples/mamba)
* [mBART](examples/enc_dec)
* [Medusa](examples/medusa)
* [Mistral](examples/llama#mistral-v01)
* [MPT](examples/mpt)
* [mT5](examples/enc_dec)
* [OPT](examples/opt)
* [Phi-1.5/Phi-2](examples/phi)
* [Qwen](examples/qwen)
* [Replit Code](examples/mpt)
* [RoBERTa](examples/bert)
* [SantaCoder](examples/gpt)
* [StarCoder1/StarCoder2](examples/gpt)
* [T5](examples/enc_dec)
* [Whisper](examples/whisper)
Note: [Encoder-Decoder](examples/enc_dec/) provides general encoder-decoder
functionality that supports many encoder-decoder models such as T5 family, BART family, Whisper family, NMT family, etc. We
unroll the exact model names in the list above so users can find specific
models more easily.
The list of supported multi-modal models is:
* [BLIP2 w/ OPT-2.7B](examples/multimodal)
* [BLIP2 w/ T5-XL](examples/multimodal)
* [LLaVA-v1.5-7B](examples/multimodal)
* [Nougat family](examples/multimodal) Nougat-small, Nougat-base
Note: Multi-modal provides general multi-modal functionality that supports many multi-modal architectures such as the BLIP family, LLaVA family, etc. We unroll the exact model names in the list above so users can find specific models more easily.
## Performance
Please refer to the [performance](./docs/source/performance.md) page for
performance numbers. That page contains measured numbers for four variants of
popular models (GPT-J, LLAMA-7B, LLAMA-70B, Falcon-180B), measured on the H100,
L40S and A100 GPU(s).
## Advanced Topics
### Quantization
This [document](./docs/source/precision.md) describes the different
quantization methods implemented in TensorRT-LLM and contains a support matrix
for the different models.
### In-flight Batching
TensorRT-LLM supports in-flight batching of requests (also known as continuous
batching or iteration-level batching). It's a
[technique](./docs/source/batch_manager.md) that aims at reducing wait
times in queues, eliminating the need for padding requests and allowing for
higher GPU utilization.
### Attention
TensorRT-LLM implements several variants of the Attention mechanism that
appears in most Large Language Models. This
[document](./docs/source/gpt_attention.md) summarizes those implementations and
how they are optimized in TensorRT-LLM.
### Graph Rewriting
TensorRT-LLM uses a declarative approach to define neural networks and contains
techniques to optimize the underlying graph. For more details, refer to the
[Graph Rewriting documentation](./docs/source/graph-rewriting.md).
### Benchmark
TensorRT-LLM provides [C++](./benchmarks/cpp/README.md) and
[Python](./benchmarks/python/README.md) tools to perform benchmarking. Note,
however, that it is recommended to use the C++ version.
## Troubleshooting
* If you encounter accuracy issues in the generated text, you may want to increase
the internal precision in the attention layer. For that, pass `--context_fmha_fp32_acc enable` to
`trtllm-build`.
* It's recommended to add the options `--shm-size=1g --ulimit memlock=-1` to the
docker or nvidia-docker run command. Otherwise you may see NCCL errors when
running multiple GPU inferences. See
https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#errors
for details.
* When building models, memory-related issues such as
```
[09/23/2023-03:13:00] [TRT] [E] 9: GPTLMHeadModel/layers/0/attention/qkv/PLUGIN_V2_Gemm_0: could not find any supported formats consistent with input/output data types
[09/23/2023-03:13:00] [TRT] [E] 9: [pluginV2Builder.cpp::reportPluginError::24] Error Code 9: Internal Error (GPTLMHeadModel/layers/0/attention/qkv/PLUGIN_V2_Gemm_0: could not find any supported formats consistent with input/output data types)
```
may happen. One possible solution is to reduce the amount of memory needed by
reducing the maximum batch size, input and output lengths. Another option is to
enable plugins, for example: `--gpt_attention_plugin`.
* MPI + Slurm
TensorRT-LLM is a
[MPI](https://en.wikipedia.org/wiki/Message_Passing_Interface)-aware package
that uses [`mpi4py`](https://mpi4py.readthedocs.io/en/stable/). If you are
running scripts in a [Slurm](https://slurm.schedmd.com/) environment, you might
encounter interference:
```
--------------------------------------------------------------------------
PMI2_Init failed to initialize. Return code: 14
--------------------------------------------------------------------------
--------------------------------------------------------------------------
The application appears to have been direct launched using "srun",
but OMPI was not built with SLURM's PMI support and therefore cannot
execute. There are several options for building PMI support under
SLURM, depending upon the SLURM version you are using:
version 16.05 or later: you can use SLURM's PMIx support. This
requires that you configure and build SLURM --with-pmix.
Versions earlier than 16.05: you must use either SLURM's PMI-1 or
PMI-2 support. SLURM builds PMI-1 by default, or you can manually
install PMI-2. You must then build Open MPI using --with-pmi pointing
to the SLURM PMI library location.
Please configure as appropriate and try again.
--------------------------------------------------------------------------
```
As a rule of thumb, if you are running TensorRT-LLM interactively on a Slurm
node, prefix your commands with `mpirun -n 1` to run TensorRT-LLM in a
dedicated MPI environment, not the one provided by your Slurm allocation.
For example: `mpirun -n 1 python3 examples/run.py ...`
## Release notes
* TensorRT-LLM requires TensorRT 9.3 and 24.02 containers.
### Change Log
#### Versions 0.9.0
* Model Support
- Support distil-whisper, thanks to the contribution from @Bhuvanesh09 in PR #1061
- Support HuggingFace StarCoder2
- Support VILA
- Support Smaug-72B-v0.1
- Migrate BLIP-2 examples to `examples/multimodal`
* Features
- Add support for context chunking to work with KV cache reuse
- Enable different rewind tokens per sequence for Medusa
- BART LoRA support (limited to the Python runtime)
- Enable multi-LoRA for BART LoRA
- Support `early_stopping=False` in beam search for C++ Runtime
- Add logits post processor to the batch manager (see docs/source/batch_manager.md#logits-post-processor-optional)
- Support importing and converting HuggingFace Gemma checkpoints, thanks to the contribution from @mfuntowicz in #1147
- Support loading Gemma from HuggingFace
- Support auto parallelism planner for high-level API and unified builder workflow
- Support running `GptSession` without OpenMPI #1220
- [BREAKING CHANGE] TopP sampling optimization with deterministic AIR TopP algorithm is enabled by default
- Medusa IFB support
- [Experimental] Support FP8 FMHA, note that the performance is not optimal, and we will keep optimizing it
- [BREAKING CHANGE] Support embedding sharing for Gemma
- More head sizes support for LLaMA-like models
- Ampere (sm80, sm86), Ada (sm89), Hopper(sm90) all support head sizes [32, 40, 64, 80, 96, 104, 128, 160, 256] now.
- OOTB functionality support
- T5
- Mixtral 8x7B
* API
- C++ `executor` API
- Add Python bindings, see documentation and examples in `examples/bindings`
- Add advanced and multi-GPU examples for Python binding of `executor` C++ API, see `examples/bindings/README.md`
- Add documents for C++ `executor` API, see `docs/source/executor.md`
- High-level API (refer to `examples/high-level-api/README.md` for guidance)
- [BREAKING CHANGE] Reuse the `QuantConfig` used in `trtllm-build` tool, support broader quantization features
- Support in `LLM()` API to accept engines built by `trtllm-build` command
- Add support for TensorRT-LLM checkpoint as model input
- Refine `SamplingConfig` used in `LLM.generate` or `LLM.generate_async` APIs, with the support of beam search, a variety of penalties, and more features
- Add support for the StreamingLLM feature, enable it by setting `LLM(streaming_llm=...)`
- Migrate Mixtral to high level API and unified builder workflow
- [BREAKING CHANGE] Refactored Qwen model to the unified build workflow, see `examples/qwen/README.md` for the latest commands
- [BREAKING CHANGE] Move LLaMA convert checkpoint script from examples directory into the core library
- [BREAKING CHANGE] Refactor GPT with unified building workflow, see `examples/gpt/README.md` for the latest commands
- [BREAKING CHANGE] Removed all the lora related flags from convert_checkpoint.py script and the checkpoint content to `trtllm-build` command, to generalize the feature better to more models
- [BREAKING CHANGE] Removed the use_prompt_tuning flag and options from convert_checkpoint.py script and the checkpoint content, to generalize the feature better to more models. Use the `trtllm-build --max_prompt_embedding_table_size` instead.
- [BREAKING CHANGE] Changed the `trtllm-build --world_size` flag to `--auto_parallel` flag, the option is used for auto parallel planner only.
- [BREAKING CHANGE] `AsyncLLMEngine` is removed, `tensorrt_llm.GenerationExecutor` class is refactored to work with both explicitly launching with `mpirun` in the application level, and accept an MPI communicator created by `mpi4py`
- [BREAKING CHANGE] `examples/server` are removed, see `examples/app` instead.
- [BREAKING CHANGE] Remove LoRA related parameters from convert checkpoint scripts
- [BREAKING CHANGE] Simplify Qwen convert checkpoint script
- [BREAKING CHANGE] Remove `model` parameter from `gptManagerBenchmark` and `gptSessionBenchmark`
* Bug fixes
- Fix a weight-only quant bug for Whisper to make sure that the `encoder_input_len_range` is not 0, thanks to the contribution from @Eddie-Wang1120 in #992
- Fix the issue that log probabilities in Python runtime are not returned #983
- Multi-GPU fixes for multimodal examples #1003
- Fix wrong `end_id` issue for Qwen #987
- Fix a non-stopping generation issue #1118 #1123
- Fix wrong link in examples/mixtral/README.md #1181
- Fix LLaMA2-7B bad results when int8 kv cache and per-channel int8 weight only are enabled #967
- Fix wrong `head_size` when importing Gemma model from HuggingFace Hub, thanks to the contribution from @mfuntowicz in #1148
- Fix ChatGLM2-6B building failure on INT8 #1239
- Fix wrong relative path in Baichuan documentation #1242
- Fix wrong `SamplingConfig` tensors in `ModelRunnerCpp` #1183
- Fix error when converting SmoothQuant LLaMA #1267
- Fix the issue that `examples/run.py` only loads one line from `--input_file`
- Fix the issue that `ModelRunnerCpp` does not transfer `SamplingConfig` tensor fields correctly #1183
* Benchmark
- Add emulated static batching in `gptManagerBenchmark`
- Support arbitrary dataset from HuggingFace for C++ benchmarks, see “Prepare dataset” section in `benchmarks/cpp/README.md`
- Add percentile latency report to `gptManagerBenchmark`
* Performance
- Optimize `gptDecoderBatch` to support batched sampling
- Enable FMHA for models in BART, Whisper and NMT family
- Remove router tensor parallelism to improve performance for MoE models, thanks to the contribution from @megha95 in #1091
- Improve custom all-reduce kernel
* Infra
- Base Docker image for TensorRT-LLM is updated to `nvcr.io/nvidia/pytorch:24.02-py3`
- Base Docker image for TensorRT-LLM backend is updated to `nvcr.io/nvidia/tritonserver:24.02-py3`
- The dependent TensorRT version is updated to 9.3
- The dependent PyTorch version is updated to 2.2
- The dependent CUDA version is updated to 12.3.2 (a.k.a. 12.3 Update 2)
#### For history change log, please see [CHANGELOG.md](./CHANGELOG.md).
### Known Issues
* On Windows, running the context FMHA plugin with FP16 accumulation on LLaMA, Mistral, and Phi models suffers from poor accuracy, and the resulting inference output may be garbled. As a workaround, enable FP32 accumulation when building the models, i.e., pass `--context_fmha disable --context_fmha_fp32_acc enable` to the `trtllm-build` command. This should be fixed in the next version.
* The hang reported in issue
[#149](https://github.com/triton-inference-server/tensorrtllm_backend/issues/149)
has not been reproduced by the TensorRT-LLM team. If it is caused by a bug
in TensorRT-LLM, that bug may be present in that release.
### Report Issues
You can use GitHub issues to report issues with TensorRT-LLM.
- [Quick Start Guide](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)
- [Release Notes](https://nvidia.github.io/TensorRT-LLM/release-notes.html)
- [Installation Guide for Linux](https://nvidia.github.io/TensorRT-LLM/installation/linux.html)
- [Installation Guide for Windows](https://nvidia.github.io/TensorRT-LLM/installation/windows.html)
- [Supported Hardware, Models, and other Software](https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html)

View File

@ -225,9 +225,7 @@ python examples/llama/convert_checkpoint.py --model_dir ${MODEL_CHECKPOINT} \
--output_dir ${CONVERTED_CHECKPOINT} \
--dtype ${DTYPE} \
--tp_size ${TP} \
--pp_size 1 \
--lora_target_modules attn_qkv \
--max_lora_rank ${MAX_LORA_RANK}
--pp_size 1
${HOME}/.local/bin/trtllm-build \
--checkpoint_dir ${CONVERTED_CHECKPOINT} \
@ -235,13 +233,11 @@ ${HOME}/.local/bin/trtllm-build \
--max_batch_size ${MAX_BATCH} \
--max_input_len $MAX_LEN \
--max_output_len $MAX_LEN \
--gpt_attention_plugin float16 \
--paged_kv_cache enable \
--remove_input_padding enable \
--gemm_plugin float16 \
--lora_plugin float16 \
--use_paged_context_fmha enable \
--use_custom_all_reduce disable
--lora_target_modules attn_qkv \
--max_lora_rank ${MAX_LORA_RANK}
NUM_LORAS=(8 16 24 32 64 128 256)
NUM_REQUESTS=1024

View File

@ -14,6 +14,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*****************************************************************************
*
* GptSession is going to be deprecated soon.
* Please do not add new functionality in this file!
*
*****************************************************************************/
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"

View File

@ -1127,6 +1127,39 @@ _allowed_configs = {
max_output_len=200,
builder_opt=None,
)),
"qwen1.5_7b_chat":
ModelConfig(name="qwen1.5_7b_chat",
family="qwen2",
benchmark_type="gpt",
build_config=BuildConfig(num_layers=32,
num_heads=32,
hidden_size=4096,
vocab_size=151936,
hidden_act='silu',
n_positions=8192,
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
builder_opt=None,
bias=False)),
"qwen1.5_14b_chat":
ModelConfig(name="qwen1.5_14b_chat",
family="qwen2",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=40,
num_heads=40,
hidden_size=5120,
vocab_size=152064,
hidden_act='silu',
n_positions=8192,
inter_size=13696,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
builder_opt=None,
)),
"mamba_2.8b":
ModelConfig(name="mamba_2.8b",
family="mamba",

View File

@ -232,6 +232,7 @@ def build_gpt(args):
builder_config_extra_kwargs['mamba_expand'] = build_config[
'mamba_expand']
builder_config_extra_kwargs['max_beam_width'] = max_beam_width
builder_config_extra_kwargs['layer_types'] = ['recurrent']
builder_config = builder.create_builder_config(
name=args.model,
precision=args.dtype,
@ -715,6 +716,51 @@ def build_gpt(args):
build_config["moe_num_experts"],
'moe_top_k':
build_config["moe_top_k"],
'qwen_type':
'qwen',
}
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.QWenForCausalLM(config)
elif family == "qwen2":
config = {
'architecture':
'QWenForCausalLM',
'dtype':
args.dtype,
'num_hidden_layers':
build_config['num_layers'],
'num_attention_heads':
build_config['num_heads'],
'num_key_value_heads':
build_config['num_heads'] if build_config['num_kv_heads'] is None
else build_config['num_kv_heads'],
'hidden_size':
build_config['hidden_size'],
'intermediate_size':
build_config['inter_size'],
'vocab_size':
build_config['vocab_size'],
'position_embedding_type':
'rope_gpt_neox',
'max_position_embeddings':
build_config['n_positions'],
'hidden_act':
build_config['hidden_act'],
'quantization': {
'group_size': 128,
'quant_algo': quant_algo,
'kv_cache_quant_algo': kv_cache_quant_algo
},
'mapping': {
'world_size': world_size,
'tp_size': world_size
},
'moe_num_experts':
build_config["moe_num_experts"],
'moe_top_k':
build_config["moe_top_k"],
'qwen_type':
'qwen2',
}
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.QWenForCausalLM(config)

View File

@ -21,7 +21,7 @@
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/schedulerPolicy.h"
#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <atomic>
@ -79,9 +79,13 @@ public:
virtual ~GptManager();
protected:
/* Synchronizes the decoder */
virtual BatchManagerErrorCode_t forwardSync();
/* Invokes one step of backend
Updates state of all requests */
virtual BatchManagerErrorCode_t step(RequestList& activeRequests, std::set<uint64_t>& activeRequestsIds);
virtual BatchManagerErrorCode_t forwardAsync(
RequestList& activeRequests, std::unordered_set<uint64_t>& activeRequestsIds);
private:
[[nodiscard]] SizeType getMaxInputLen() const;
@ -89,7 +93,7 @@ private:
[[nodiscard]] SizeType getMaxNumSequences() const;
void validateLlmRequest(
LlmRequest& newReq, runtime::GptModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) const;
LlmRequest& newReq, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) const;
static std::shared_ptr<LlmRequest> fillLlmRequest(std::shared_ptr<InferenceRequest> newReq);
static std::shared_ptr<std::vector<TokenIdType>> getReqInputTokens(std::shared_ptr<InferenceRequest> newReq);
static SizeType getMaxNewTokens(std::shared_ptr<InferenceRequest> newReq);
@ -108,7 +112,7 @@ private:
// List of live requests
RequestList mActiveRequests;
// IDs of live requests
std::set<uint64_t> mActiveRequestsIds;
std::unordered_set<uint64_t> mActiveRequestsIds;
// Boolean that controls if prompt should be included in output tokens for non-streaming
bool mExcludeInputInOutput;

View File

@ -63,6 +63,8 @@ public:
&& hostCacheSize == other.hostCacheSize && onboardBlocks == other.onboardBlocks;
}
friend std::ostream& operator<<(std::ostream& os, KvCacheConfig const& self);
std::optional<SizeType> maxTokens;
std::optional<SizeType> maxAttentionWindow;
std::optional<SizeType> sinkTokenLength;

View File

@ -18,15 +18,16 @@
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/batch_manager/llmRequest.h" // TODO forward declare
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/kvCacheIndex.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <NvInferRuntime.h>
#include <cstdint>
#include <functional>
#include <list>
@ -89,15 +90,15 @@ struct KvCacheStats
class KVCacheBlock
{
public:
using OffsetType = std::int32_t;
using IdType = std::int32_t;
explicit KVCacheBlock(OffsetType blockIdx, OffsetType blocksInPrimaryPool);
explicit KVCacheBlock(IdType blockId, kernels::KVCacheIndex blockIdx);
void startScheduling();
[[nodiscard]] OffsetType getBlockIdx() const;
[[nodiscard]] IdType getBlockId() const;
[[nodiscard]] OffsetType getMemoryPoolBlockOffset() const;
[[nodiscard]] kernels::KVCacheIndex::UnderlyingType getMemoryPoolBlockIndex() const;
[[nodiscard]] bool isPrimary() const;
@ -143,11 +144,12 @@ public:
[[nodiscard]] bool isShared() const;
private:
// Linear index of block in pool
OffsetType mBlockIdx;
// Linear ID of block independent of pool
IdType mBlockId;
// Block in memory pool backing this block
OffsetType mMemoryPoolBlockOffset;
// Index of block in memory pool backing this block
// Choice of pool is encoded into the type
kernels::KVCacheIndex mMemoryPoolBlockIndex;
// Number of references to the block
SizeType mRefCount;
@ -169,9 +171,6 @@ private:
// Flag indicating if block is full
bool mIsFull;
// Flag indicating mMemoryPoolBlockOffset refers to secondary pool
static constexpr OffsetType secondaryPoolFlag = static_cast<OffsetType>(1) << (8 * sizeof(OffsetType) - 1);
};
class GenerationRequest
@ -220,14 +219,14 @@ public:
return mCacheBlockIds;
}
void addCacheBlock(SizeType beamIdx, SizeType blockIdx)
void addCacheBlock(SizeType beamIdx, KVCacheBlock::IdType blockId)
{
mCacheBlockIds.at(beamIdx).push_back(blockIdx);
mCacheBlockIds.at(beamIdx).push_back(blockId);
}
void changeCacheBlock(SizeType beamIdx, SizeType pagedBlockIdx, SizeType blockIdx)
void changeCacheBlock(SizeType beamIdx, SizeType pagedBlockIdx, KVCacheBlock::IdType blockId)
{
mCacheBlockIds.at(beamIdx).at(pagedBlockIdx) = blockIdx;
mCacheBlockIds.at(beamIdx).at(pagedBlockIdx) = blockId;
}
void clearCacheBlocks()
@ -264,7 +263,7 @@ private:
// Number of beams
SizeType mBeamWidth;
// List of blocks allocated for each beam of the sequence
std::vector<std::vector<SizeType>> mCacheBlockIds;
std::vector<std::vector<KVCacheBlock::IdType>> mCacheBlockIds;
// Number of tokens already in kv cache before context phase.
// A value > 0 indicates cached kv cache blocks were reused.
// One value per beam.
@ -348,7 +347,7 @@ public:
[[nodiscard]] SizeType getMaxNumBlocks() const noexcept
{
return static_cast<SizeType>(mAllBlocksByIdx.size());
return static_cast<SizeType>(mAllBlocksById.size());
}
[[nodiscard]] SizeType getTokensPerBlock() const noexcept
@ -356,7 +355,8 @@ public:
return mTokensPerBlock;
}
//! \brief Get size of one field in one layer in one block.
//! \brief Get size of one K/V cache block in one layer.
//! @details Volume of [numKvHeads, tokensPerBlock, sizePerHead]
[[nodiscard]] SizeType getBlockSize() const
{
return mBlockSize;
@ -372,10 +372,10 @@ public:
return mSecondaryPool;
}
//! \brief Get offset in pool to K or V block.
//! \param blockIdx the blockIdx as returned by getBlockIdx()
//! \brief Get index in pool to K or V block.
//! \param blockId the blockId as returned by getBlockId()
//! \param fieldIdx either 0 (K) or 1 (V),
[[nodiscard]] SizeType getKOrVBlockOffset(SizeType blockIdx, SizeType fieldIdx) const;
[[nodiscard]] kernels::KVCacheIndex getKOrVBlockIndex(KVCacheBlock::IdType blockId, SizeType fieldIdx) const;
//! \brief Bring offloaded block from secondary to primary memory.
//! \details Does nothing if block is already in primary memory.
@ -442,7 +442,7 @@ private:
// Number of tokens per one block
SizeType mTokensPerBlock;
// List of all blocks by idx
std::vector<BlockPtr> mAllBlocksByIdx;
std::vector<BlockPtr> mAllBlocksById;
// Dummy block acting as root for BlockToken searches
BlockPtr mCachedBlocksRoot;
// Statistics for block allocations/reuse
@ -452,7 +452,6 @@ private:
class KVCacheManager
{
public:
using OffsetType = KVCacheBlock::OffsetType;
using SizeType = tensorrt_llm::runtime::SizeType;
using SequencesPtr = GenerationRequest::SharedPtr;
using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
@ -495,12 +494,6 @@ public:
return kvCacheStats;
}
// Volume of [numKvHeads, tokensPerBlock, sizePerHead]
[[nodiscard]] SizeType getBlockSize() const
{
return mBlockManager.getBlockSize();
}
[[nodiscard]] SizeType getMaxBlocksPerSeq() const
{
return mMaxBlocksPerSeq;
@ -544,21 +537,21 @@ public:
runtime::ITensor& output, SizeType outputSlotOffset, SizeType seqSlotIdx, SizeType beamWidth) const;
// Volume of [2, numKvHeads, tokensPerBlock, sizePerHead]
[[nodiscard]] static SizeType constexpr calculatePageSize(tensorrt_llm::runtime::GptModelConfig const& modelConfig)
[[nodiscard]] static SizeType constexpr calculatePageSize(tensorrt_llm::runtime::ModelConfig const& modelConfig)
{
return 2 * modelConfig.getNbKvHeads() * modelConfig.getTokensPerBlock() * modelConfig.getSizePerHead();
}
// numLayers * 2 * numKvHeads * sizePerHead
[[nodiscard]] static SizeType constexpr calculateCacheSizePerToken(
tensorrt_llm::runtime::GptModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig)
tensorrt_llm::runtime::ModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig)
{
return modelConfig.getNbLayers(worldConfig.getPipelineParallelism()) * 2 * modelConfig.getNbKvHeads()
return modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism()) * 2 * modelConfig.getNbKvHeads()
* modelConfig.getSizePerHead();
}
[[nodiscard]] static std::tuple<SizeType, SizeType> const calculateMaxNumBlocks(KvCacheConfig const& config,
nvinfer1::DataType dtype, tensorrt_llm::runtime::GptModelConfig const& modelConfig,
nvinfer1::DataType dtype, tensorrt_llm::runtime::ModelConfig const& modelConfig,
tensorrt_llm::runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);
[[nodiscard]] SizeType getNumPrepopulatedTokens(SizeType batchSlotIdx, SizeType beamIdx) const
@ -576,8 +569,8 @@ public:
void rewindKVCache(SizeType seqSlotIdx, SizeType rewindLengths);
private:
void setOffsets(OffsetType* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType seqSlotIdx, SizeType beamIdx,
SizeType blockIdx, SizeType blockId) const;
void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType seqSlotIdx,
SizeType beamIdx, SizeType blockIdx, KVCacheBlock::IdType blockId) const;
void resetBlockOffsets(SizeType seqSlotIdx, SizeType beamWidth);
void cacheBlockOffsets(GenerationRequest const& seq, SizeType seqSlotIdx);
@ -586,8 +579,6 @@ private:
void updateToken(SizeType seqSlotIdx, bool addToken);
private:
// Number of layers
SizeType mNumLayers;
// Maximum number of sequences
SizeType mMaxNumSequences;
// Maximum beam width
@ -607,8 +598,8 @@ private:
BlockManager mBlockManager;
// List of all sequences
std::vector<SequencesPtr> mSequences;
// buffer for block offsets for all managed sequences
runtime::ITensor::SharedPtr mSequenceBlockOffsets;
// buffer for block indices for all managed sequences
runtime::ITensor::SharedPtr mSequenceBlockIndices;
// Whether to cache KV pages for reuse
bool mEnableBlockReuse;
};

View File

@ -92,6 +92,7 @@ public:
, mCumLogProbs(samplingConfig.beamWidth)
, mDraftTokens(draftTokens.value_or(std::make_shared<VecTokens>()))
, mDraftLogits(draftLogits)
, mNumTokensPerIteration(1)
, mReturnContextLogits(returnContextLogits)
, mReturnGenerationLogits(returnGenerationLogits)
, mExcludeInputFromOutput(excludeInputFromOutput)
@ -189,9 +190,9 @@ public:
{
auto const maxNewTokens = maxSequenceLen - mPromptLen;
TLLM_LOG_WARNING(
"Number of requested output tokens (%d) exceeds maximum sequence length (%d). "
"Prompt length + number of requested output tokens (%d + %d) exceeds maximum sequence length (%d). "
"Number of requested output tokens is changed to (%d).",
mMaxNewTokens, maxSequenceLen, maxNewTokens);
mPromptLen, mMaxNewTokens, maxSequenceLen, maxNewTokens);
mMaxNewTokens = maxNewTokens;
}
@ -494,6 +495,16 @@ public:
return mDraftTokens->size();
}
void setNumTokensPerIteration(SizeType numTokensPerIteration)
{
mNumTokensPerIteration = numTokensPerIteration;
}
SizeType getNumTokensPerIteration() const
{
return mNumTokensPerIteration;
}
void setReturnContextLogits(bool const returnContextLogits)
{
mReturnContextLogits = returnContextLogits;
@ -766,6 +777,7 @@ protected:
VecLogProbs mCumLogProbs; // [beamSize]
std::shared_ptr<VecTokens> mDraftTokens;
std::optional<TensorPtr> mDraftLogits;
SizeType mNumTokensPerIteration;
// Save logits
bool mReturnContextLogits;

View File

@ -12,10 +12,11 @@
#pragma once
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/peftCacheManagerConfig.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
#include "tensorrt_llm/runtime/loraCache.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/workerPool.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <NvInferRuntimeBase.h>
@ -23,6 +24,7 @@
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace tensorrt_llm::batch_manager
{
@ -39,7 +41,7 @@ class BasePeftCacheManager
{
public:
using LlmRequestPtr = std::shared_ptr<LlmRequest>;
using RequestTable = std::map<uint64_t, LlmRequestPtr>;
using RequestVector = std::vector<LlmRequestPtr>;
using PeftTable = std::map<uint64_t, std::shared_ptr<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>>;
/**
@ -50,13 +52,14 @@ public:
virtual void addRequestPeft(LlmRequestPtr llmRequest, bool tryGpuCache = true) = 0;
/**
* \brief ensures device cache has all the weights needed to execute batch as specified by requestTable.
* \brief ensures device cache has all the weights needed to execute batch as specified by requests.
* This acts as sync for the copy tasks started by addRequestPeft
* \param[in] requestTable: current request table
* \param[in] contextRequests: current context requests
* \param[in] genRequests: current generation requests
* \param[in] resetGpuCache: reset (make all tasks evictable)
* \returns -- a PeftTable
*/
virtual PeftTable ensureBatch(RequestTable const& requestTable, bool resetGpuCache = false) = 0;
virtual PeftTable ensureBatch(ScheduledRequests const& scheduledRequests, bool resetGpuCache = false) = 0;
/**
* \brief mark all the tasks in device cache as done
@ -77,12 +80,12 @@ public:
class PeftCacheManager : public BasePeftCacheManager
{
public:
PeftCacheManager(PeftCacheManagerConfig const& config, runtime::GptModelConfig const& modelConfig,
PeftCacheManager(PeftCacheManagerConfig const& config, runtime::ModelConfig const& modelConfig,
runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);
void addRequestPeft(std::shared_ptr<LlmRequest> llmRequest, bool tryGpuCache = true) override;
PeftTable ensureBatch(RequestTable const& requestTable, bool resetGpuCache = false) override;
PeftTable ensureBatch(ScheduledRequests const& scheduledRequests, bool resetGpuCache = false) override;
[[nodiscard]] bool isTaskCached(uint64_t taskId) const;
@ -116,7 +119,7 @@ public:
runtime::BufferManager const& bufferManager);
static std::pair<runtime::LoraCachePageManagerConfig, runtime::LoraCachePageManagerConfig> getPageManagerConfig(
PeftCacheManagerConfig const& config, runtime::GptModelConfig const& modelConfig,
PeftCacheManagerConfig const& config, runtime::ModelConfig const& modelConfig,
runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);
private:
@ -133,9 +136,9 @@ private:
std::unordered_map<uint64_t, std::unordered_set<uint64_t>> mTaskIdToPausedReqIds;
std::tuple<std::map<uint64_t, std::future<void>>, std::map<uint64_t, std::vector<uint64_t>>> getTaskMaps(
RequestTable const& requestTable);
ScheduledRequests const& scheduledRequests);
runtime::GptModelConfig mModelConfig;
runtime::ModelConfig mModelConfig;
runtime::WorldConfig mWorldConfig;
int mDevice{-1};
@ -145,7 +148,7 @@ class NoOpPeftCacheManager : public BasePeftCacheManager
{
void addRequestPeft(std::shared_ptr<LlmRequest> llmRequest, bool tryGpuCache = true) override;
PeftTable ensureBatch(RequestTable const& requestTable, bool resetGpuCache = false) override;
PeftTable ensureBatch(ScheduledRequests const& scheduledRequests, bool resetGpuCache = false) override;
void resetDeviceCache() override;

View File

@ -60,6 +60,7 @@ struct PeftCacheManagerConfig
, optimalAdapterSize(cfg.getOptimalAdapterSize())
, maxAdapterSize(cfg.getMaxAdapterSize())
, numPutWorkers(cfg.getNumPutWorkers())
, numEnsureWorkers(cfg.getNumEnsureWorkers())
, numCopyStreams(cfg.getNumCopyStreams())
, maxPagesPerBlockHost(cfg.getMaxPagesPerBlockHost())
, maxPagesPerBlockDevice(cfg.getMaxPagesPerBlockDevice())

View File

@ -31,4 +31,6 @@ SchedulerPolicy execToBatchManagerSchedPolicy(executor::SchedulerPolicy policy);
executor::SchedulerPolicy batchManagerToExecSchedPolicy(SchedulerPolicy policy);
std::ostream& operator<<(std::ostream& os, SchedulerPolicy policy);
} // namespace tensorrt_llm::batch_manager::batch_scheduler

View File

@ -57,7 +57,9 @@ public:
explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig)
: TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()), false,
executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(),
executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(), std::nullopt,
executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(),
runtime::DecodingMode::fromExecutor(
executorConfig.getDecodingMode().value_or(executor::DecodingMode::kNONE)),
PeftCacheManagerConfig(executorConfig.getPeftCacheConfig().value_or(executor::PeftCacheConfig())),
executorConfig.getMedusaChoices())
{
@ -70,6 +72,8 @@ public:
&& enableChunkedContext == other.enableChunkedContext && decodingMode == other.decodingMode;
}
friend std::ostream& operator<<(std::ostream& os, TrtGptModelOptionalParams const& self);
KvCacheConfig kvCacheConfig;
bool enableTrtOverlap;

View File

@ -17,7 +17,6 @@
#pragma once
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/utils/multiDeviceUtils.h"
#ifdef ENABLE_FP8
@ -36,6 +35,11 @@
#define MPICHECK(cmd) TLLM_MPI_CHECK(cmd)
namespace tensorrt_llm::runtime
{
class IBuffer;
}
// A wrapper module of the MPI library.
namespace tensorrt_llm::mpi
{
@ -234,18 +238,11 @@ public:
std::shared_ptr<MpiRequest> bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const;
std::shared_ptr<MpiRequest> bcastAsync(runtime::IBuffer& buf, int root) const
{
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}
std::shared_ptr<MpiRequest> bcastAsync(runtime::IBuffer& buf, int root) const;
void bcast(void* buffer, size_t size, MpiType dtype, int root) const;
void bcast(runtime::IBuffer& buf, int root) const
{
bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}
void bcast(runtime::IBuffer& buf, int root) const;
template <typename T>
void bcastValue(T& value, int root) const
@ -281,11 +278,7 @@ public:
void send(void const* buffer, std::size_t size, MpiType dtype, int dest, int tag) const;
void send(runtime::IBuffer const& buf, int dest, int tag) const
{
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
}
void send(runtime::IBuffer const& buf, int dest, int tag) const;
template <typename T>
void send(T const& value, int dest, int tag) const
@ -302,11 +295,7 @@ public:
MPI_Status recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const;
MPI_Status recv(runtime::IBuffer& buf, int source, int tag) const
{
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag);
}
MPI_Status recv(runtime::IBuffer& buf, int source, int tag) const;
template <typename T>
MPI_Status recv(T& value, int source, int tag) const

View File

@ -29,6 +29,11 @@
#include <string>
#include <vector>
namespace tensorrt_llm::mpi
{
class MpiComm;
}
namespace tensorrt_llm::executor
{
@ -310,6 +315,7 @@ public:
[[nodiscard]] Result getResult() const;
private:
friend class Serialization;
class Impl;
std::unique_ptr<Impl> mImpl;
};
@ -323,6 +329,8 @@ public:
[[nodiscard]] SchedulerPolicy getPolicy() const;
private:
friend class Serialization;
/// @brief The scheduler policy. See SchedulerPolicy.
SchedulerPolicy mPolicy;
};
@ -346,6 +354,8 @@ public:
[[nodiscard]] bool getOnboardBlocks() const;
private:
friend class Serialization;
/// @brief Controls if KV cache blocks can be reused for different requests
bool mEnableBlockReuse;
@ -378,6 +388,26 @@ SizeType const kDefaultIterStatsMaxIterations = 1000;
// Per request stats may have additional overhead due to going through all requests. Turned off by default.
SizeType const kDefaultRequestStatsMaxIterations = 0;
class OrchestratorConfig
{
public:
explicit OrchestratorConfig(bool isOrchestrator = true, std::string workerExecutablePath = "",
std::shared_ptr<mpi::MpiComm> orchLeaderComm = nullptr);
[[nodiscard]] bool getIsOrchestrator() const;
[[nodiscard]] std::string getWorkerExecutablePath() const;
[[nodiscard]] std::shared_ptr<mpi::MpiComm> getOrchLeaderComm() const;
void setIsOrchestrator(bool isOrchestrator);
void setWorkerExecutablePath(std::string const& workerExecutablePath);
void setOrchLeaderComm(std::shared_ptr<mpi::MpiComm> const& orchLeaderComm);
private:
bool mIsOrchestrator;
std::string mWorkerExecutablePath;
std::shared_ptr<mpi::MpiComm> mOrchLeaderComm;
};
/// @brief A configuration class for the parallel execution parameters
/// Currently only supports commType = CommunicationType::kMPI
class ParallelConfig
@ -392,19 +422,24 @@ public:
explicit ParallelConfig(CommunicationType commType = CommunicationType::kMPI,
CommunicationMode commMode = CommunicationMode::kLEADER,
std::optional<std::vector<SizeType>> deviceIds = std::nullopt,
std::optional<std::vector<SizeType>> participantIds = std::nullopt);
std::optional<std::vector<SizeType>> participantIds = std::nullopt,
std::optional<OrchestratorConfig> const& orchestratorConfig = std::nullopt);
[[nodiscard]] CommunicationType getCommunicationType() const;
[[nodiscard]] CommunicationMode getCommunicationMode() const;
[[nodiscard]] std::optional<std::vector<SizeType>> getDeviceIds() const;
[[nodiscard]] std::optional<std::vector<SizeType>> getParticipantIds() const;
[[nodiscard]] std::optional<OrchestratorConfig> getOrchestratorConfig() const;
void setCommunicationType(CommunicationType type);
void setCommunicationMode(CommunicationMode mode);
void setDeviceIds(std::vector<SizeType> const& deviceIds);
void setParticipantIds(std::vector<SizeType> const& participantIds);
void setOrchestratorConfig(OrchestratorConfig const& orchestratorConfig);
private:
friend class Serialization;
/// @brief The type of communication protocol used. Default is MPI.
CommunicationType mCommType;
@ -416,6 +451,9 @@ private:
/// @brief The participant ids (MPI ranks for example) used for executing this model
std::optional<std::vector<SizeType>> mParticipantIds;
/// @brief Optional orchestrator configuration
std::optional<OrchestratorConfig> mOrchestratorConfig;
};
/// @brief config for PeftCacheManager
@ -428,6 +466,8 @@ public:
SizeType maxPagesPerBlockDevice = 8, std::optional<float> const& deviceCachePercent = std::nullopt,
std::optional<size_t> const& hostCacheSize = std::nullopt);
bool operator==(PeftCacheConfig const& other) const;
[[nodiscard]] SizeType getNumHostModuleLayer() const;
[[nodiscard]] SizeType getNumDeviceModuleLayer() const;
[[nodiscard]] SizeType getOptimalAdapterSize() const;
@ -441,6 +481,8 @@ public:
[[nodiscard]] std::optional<size_t> getHostCacheSize() const;
private:
friend class Serialization;
// number of max sized 1-layer 1-module adapterSize=1 sets of weights that can be stored in host cache
SizeType mNumHostModuleLayer;
// number of max sized 1-layer 1-module sets of weights that can be stored in host cache
@ -460,7 +502,7 @@ private:
// Number of cache pages per allocation block (device)
SizeType mMaxPagesPerBlockDevice;
// percent of memory after engine load to use for cache
std::optional<float> mDeviceCachePercent;
std::optional<FloatType> mDeviceCachePercent;
// size in bytes to use for host cache
std::optional<size_t> mHostCacheSize;
};
@ -477,7 +519,8 @@ public:
std::optional<ParallelConfig> parallelConfig = std::nullopt,
std::optional<PeftCacheConfig> const& peftCacheConfig = std::nullopt,
std::optional<LogitsPostProcessorMap> logitsPostProcessorMap = std::nullopt,
std::optional<MedusaChoices> medusaChoices = std::nullopt);
std::optional<MedusaChoices> medusaChoices = std::nullopt,
std::optional<DecodingMode> decodingMode = std::nullopt);
[[nodiscard]] SizeType getMaxBeamWidth() const;
[[nodiscard]] SchedulerConfig getSchedulerConfig() const;
@ -491,6 +534,7 @@ public:
[[nodiscard]] std::optional<PeftCacheConfig> getPeftCacheConfig() const;
[[nodiscard]] std::optional<LogitsPostProcessorMap> getLogitsPostProcessorMap() const;
[[nodiscard]] std::optional<MedusaChoices> getMedusaChoices() const;
[[nodiscard]] std::optional<DecodingMode> getDecodingMode() const;
void setMaxBeamWidth(SizeType maxBeamWidth);
void setSchedulerConfig(SchedulerConfig const& schedulerConfig);
@ -504,8 +548,11 @@ public:
void setPeftCacheConfig(PeftCacheConfig const& peftCacheConfig);
void setLogitsPostProcessorMap(LogitsPostProcessorMap const& logitsPostProcessorMap);
void setMedusaChoices(MedusaChoices const& medusaChoices);
void setDecodingMode(DecodingMode decodingMode);
private:
friend class Serialization;
/// @brief The beam width value of requests that will be sent to the executor
SizeType mMaxBeamWidth;
@ -535,6 +582,7 @@ private:
std::optional<PeftCacheConfig> mPeftCacheConfig;
std::optional<LogitsPostProcessorMap> mLogitsPostProcessorMap;
std::optional<MedusaChoices> mMedusaChoices;
std::optional<DecodingMode> mDecodingMode;
};
/// @brief The executor is responsible for receiving new requests and sending responses, and running the inference

View File

@ -0,0 +1,117 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include <istream>
#include <ostream>
namespace tensorrt_llm::executor
{
class Serialization
{
public:
// SamplingConfig
[[nodiscard]] static SamplingConfig deserializeSamplingConfig(std::istream& is);
static void serialize(SamplingConfig const& config, std::ostream& os);
[[nodiscard]] static size_t serializedSize(SamplingConfig const& config);
// OutputConfig
[[nodiscard]] static OutputConfig deserializeOutputConfig(std::istream& is);
static void serialize(OutputConfig const& config, std::ostream& os);
[[nodiscard]] static size_t serializedSize(OutputConfig const& config);
// SpeculativeDecodingConfig
[[nodiscard]] static SpeculativeDecodingConfig deserializeSpeculativeDecodingConfig(std::istream& is);
static void serialize(SpeculativeDecodingConfig const& config, std::ostream& os);
[[nodiscard]] static size_t serializedSize(SpeculativeDecodingConfig const& config);
// PromptTuningConfig
[[nodiscard]] static PromptTuningConfig deserializePromptTuningConfig(std::istream& is);
static void serialize(PromptTuningConfig const& config, std::ostream& os);
[[nodiscard]] static size_t serializedSize(PromptTuningConfig const& config);
// LoraConfig
[[nodiscard]] static LoraConfig deserializeLoraConfig(std::istream& is);
static void serialize(LoraConfig const& config, std::ostream& os);
[[nodiscard]] static size_t serializedSize(LoraConfig const& config);
// Request
[[nodiscard]] static Request deserializeRequest(std::istream& is);
static void serialize(Request const& request, std::ostream& os);
[[nodiscard]] static size_t serializedSize(Request const& request);
// Tensor
[[nodiscard]] static Tensor deserializeTensor(std::istream& is);
static void serialize(Tensor const& tensor, std::ostream& os);
[[nodiscard]] static size_t serializedSize(Tensor const& tensor);
// Result
[[nodiscard]] static Result deserializeResult(std::istream& is);
static void serialize(Result const& result, std::ostream& os);
[[nodiscard]] static size_t serializedSize(Result const& result);
// Response
[[nodiscard]] static Response deserializeResponse(std::istream& is);
static void serialize(Response const& response, std::ostream& os);
[[nodiscard]] static size_t serializedSize(Response const& response);
// Vector of responses
static std::vector<Response> deserializeResponses(std::vector<char>& buffer);
static std::vector<char> serialize(std::vector<Response> const& responses);
// KvCacheConfig
static KvCacheConfig deserializeKvCacheConfig(std::istream& is);
static void serialize(KvCacheConfig const& kvCacheConfig, std::ostream& os);
static size_t serializedSize(KvCacheConfig const& kvCacheConfig);
// SchedulerConfig
static SchedulerConfig deserializeSchedulerConfig(std::istream& is);
static void serialize(SchedulerConfig const& schedulerConfig, std::ostream& os);
static size_t serializedSize(SchedulerConfig const& schedulerConfig);
// ParallelConfig
static ParallelConfig deserializeParallelConfig(std::istream& is);
static void serialize(ParallelConfig const& parallelConfig, std::ostream& os);
static size_t serializedSize(ParallelConfig const& parallelConfig);
// PeftCacheConfig
static PeftCacheConfig deserializePeftCacheConfig(std::istream& is);
static void serialize(PeftCacheConfig const& peftCacheConfig, std::ostream& os);
static size_t serializedSize(PeftCacheConfig const& peftCacheConfig);
// OrchestratorConfig
static OrchestratorConfig deserializeOrchestratorConfig(std::istream& is);
static void serialize(OrchestratorConfig const& orchestratorConfig, std::ostream& os);
static size_t serializedSize(OrchestratorConfig const& orchestratorConfig);
// ExecutorConfig
static ExecutorConfig deserializeExecutorConfig(std::istream& is);
static void serialize(ExecutorConfig const& executorConfig, std::ostream& os);
static size_t serializedSize(ExecutorConfig const& executorConfig);
// String
static std::string deserializeString(std::istream& is);
// ModelType
static ModelType deserializeModelType(std::istream& is);
};
} // namespace tensorrt_llm::executor
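A minimal usage sketch of the Serialization API declared above: each config type pairs a serialize() overload writing to a std::ostream with a deserializeXxx() counterpart reading from a std::istream, and serializedSize() reports the byte count up front. The header path "tensorrt_llm/executor/serialization.h" is an assumption.
#include <sstream>
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serialization.h" // assumed header name for the class above
using namespace tensorrt_llm::executor;
// Round-trips a SamplingConfig through an in-memory stream; the other config/request/response
// types work the same way via their deserializeXxx() counterparts.
SamplingConfig roundTrip(SamplingConfig const& config)
{
    std::stringstream ss;
    Serialization::serialize(config, ss);
    // serializedSize(config) could instead be used to pre-size a flat char buffer.
    return Serialization::deserializeSamplingConfig(ss);
}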

View File

@ -191,6 +191,9 @@ enum class CommunicationMode
kLEADER, // With the leader mode, only the leader can enqueue requests. The requests will be
// broadcast to the workers. All participants can get responses via awaitResponses. The leader is the
// first participant in the provided participant IDs, or 0 if no participant IDs are provided
kORCHESTRATOR, // With the orchestrator mode, only the orchestrator can enqueue requests and await responses. The
// requests will be broadcast to the workers. The orchestrator spawns new processes for the
// execution of the model
};
/// @brief Struct that holds the stats of a KV cache manager
@ -305,4 +308,17 @@ struct RequestStatsPerIteration
std::vector<RequestStats> requestStats;
};
/// @brief Decoding mode
enum class DecodingMode
{
/// @brief No mode specified. The mode is determined at runtime from the beam width of the first request:
/// TopKTopP if beamWidth == 1, BeamSearch otherwise
kNONE,
kTOP_K,
kTOP_P,
kBEAM_SEARCH,
kMEDUSA,
kTOP_K_TOP_P,
};
} // namespace tensorrt_llm::executor
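A small sketch of the kNONE semantics documented above: when no mode is specified, the effective mode is picked at runtime from the first request's beam width. The helper name below is illustrative, not a library function.
#include "tensorrt_llm/executor/types.h"
namespace te = tensorrt_llm::executor;
// Mirrors the documented default: sampling (TopK+TopP) for beamWidth == 1, beam search otherwise.
te::DecodingMode resolveDefaultDecodingMode(int beamWidth)
{
    return beamWidth == 1 ? te::DecodingMode::kTOP_K_TOP_P : te::DecodingMode::kBEAM_SEARCH;
}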

View File

@ -16,6 +16,8 @@
#pragma once
#include "tensorrt_llm/executor/executor.h"
namespace tensorrt_llm
{
namespace runtime
@ -54,37 +56,37 @@ public:
return DecodingMode{kMedusa};
}
bool constexpr isNone()
bool constexpr isNone() const
{
return mState == 0;
}
bool constexpr isTopK()
bool constexpr isTopK() const
{
return anyBitSet(kTopK);
}
bool constexpr isTopP()
bool constexpr isTopP() const
{
return anyBitSet(kTopP);
}
bool constexpr isTopKorTopP()
bool constexpr isTopKorTopP() const
{
return anyBitSet(kTopKTopP);
}
bool constexpr isTopKandTopP()
bool constexpr isTopKandTopP() const
{
return allBitSet(kTopKTopP);
}
bool constexpr isBeamSearch()
bool constexpr isBeamSearch() const
{
return anyBitSet(kBeamSearch);
}
bool constexpr isMedusa()
bool constexpr isMedusa() const
{
return anyBitSet(kMedusa);
}
@ -96,6 +98,28 @@ public:
return mState == other.mState;
}
static DecodingMode fromExecutor(executor::DecodingMode decodingMode)
{
switch (decodingMode)
{
case executor::DecodingMode::kNONE: return DecodingMode::None();
case executor::DecodingMode::kTOP_K: return DecodingMode::TopK();
case executor::DecodingMode::kTOP_P: return DecodingMode::TopP();
case executor::DecodingMode::kBEAM_SEARCH: return DecodingMode::BeamSearch();
case executor::DecodingMode::kMEDUSA: return DecodingMode::Medusa();
case executor::DecodingMode::kTOP_K_TOP_P: return DecodingMode::TopKTopP();
default: TLLM_THROW("Invalid decoding mode"); break;
}
}
friend std::ostream& operator<<(std::ostream& os, DecodingMode other);
private:
constexpr DecodingMode(UnderlyingType state)
: mState(state)

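A short sketch combining the new const-qualified flag accessors with the fromExecutor() bridge shown above; the include path "tensorrt_llm/runtime/decodingMode.h" is an assumption.
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/decodingMode.h" // assumed header name
// Converts the executor-facing enum into the runtime bit-flag class and queries it.
bool usesSampling(tensorrt_llm::executor::DecodingMode execMode)
{
    auto const mode = tensorrt_llm::runtime::DecodingMode::fromExecutor(execMode);
    return mode.isTopKorTopP(); // true for kTOP_K, kTOP_P and kTOP_K_TOP_P
}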
View File

@ -29,17 +29,21 @@ class DecodingOutput
public:
using TensorPtr = ITensor::SharedPtr;
// BS: batch_size, BM: beam_width, MSL: max_seq_length
// Unless noted otherwise, all TensorPtr members are on the GPU
class BeamHypotheses
{
public:
TensorPtr outputIdsTgt; // [batchSize, 2 * beamWidth, maxSeqLen]
TensorPtr sequenceLengthsTgt; // [batchSize, 2 * beamWidth]
TensorPtr cumLogProbs; // [batchSize, 2 * beamWidth]
TensorPtr normedScores; // [batchSize, 2 * beamWidth]
TensorPtr logProbs; // [batchSize, 2 * beamWidth, maxSeqLen]
TensorPtr minNormedScores; // [batchSize]
TensorPtr numBeams; // [batchSize]
TensorPtr isDone; // [batchSize]
// Same layout as in cpp/tensorrt_llm/kernels/beamSearchKernels.h
TensorPtr outputIdsCBA; // [BS, BM*2, MSL]
TensorPtr sequenceLengthsCBA; // [BS, BM]
TensorPtr cumLogProbsCBA; // [BS, BM*2]
TensorPtr normedScoresCBA; // [BS, BM*2]
TensorPtr logProbsCBA; // [BS, BM*2, MSL]
TensorPtr minNormedScoresCBA; // [BS]
TensorPtr numBeamsCBA; // [BS]
TensorPtr batchDones; // [BS]
void empty(BufferManager& manager);
@ -61,27 +65,26 @@ public:
}
// mandatory parameters
TensorPtr ids; // [batchSize, beamWidth, maxSeqLen], on gpu, must contain previously generated token ids for all
// steps before DecodingInput.step
TensorPtr newTokensSteps; // [maxTokensPerStep, batchSize, beamWidth] new tokens at each generated token of
// maxTokensPerStep, on gpu.
TensorPtr newTokens; // [batchSize, beamWidth] usually a view of newTokensSteps for the current token, on gpu.
std::vector<TensorPtr> newTokensVec; // vector of size maxTokensPerStep with tensor [batchSize, beamWidth].
// Vector of views on newTokensSteps for each token. Elements are on gpu.
TensorPtr ids; // [BS, BM, MSL], contains previously generated token ids for all
// steps before DecodingInput.step
TensorPtr newTokensSteps; // [maxTokensPerStep, BS, BM] new tokens at each generated token of
// maxTokensPerStep
TensorPtr newTokens; // [BS, BM] usually a view of newTokensSteps for the current token
std::vector<TensorPtr> newTokensVec; // vector of size maxTokensPerStep with tensor [BS, BM].
// Vector of views on newTokensSteps for each token
// optional parameters
TensorPtr finished; // [batchSize, beamWidth],
// Set to true by decoding if any of the stop conditions are met or if DecodingInput.finished is
// true. In beam search and to determine whether to stop according to
// DecodingInput.sequenceLimitLength, on gpu
TensorPtr finishedSum; // [batchSize], the sum of finished sequences per request, in pinned memory
TensorPtr finished; // [BS, BM], set to true by decoding if any of the stop conditions are met or if
// DecodingInput.finished is true. Used in beam search and to determine whether to stop according to
// DecodingInput.sequenceLimitLength
TensorPtr finishedSum; // [BS], the sum of finished sequences per request, in pinned memory
// mandatory parameters for beam search
TensorPtr logProbs; // [batchSize, beamWidth, maxSeqLen], must be float*, on gpu
TensorPtr cumLogProbs; // [batchSize, beamWidth], optional for sampling, on gpu
TensorPtr parentIds; // [batchSize, beamWidth, maxSeqLen], on gpu
TensorPtr lengths; // [batchSize, beamWidth], total sequence lengths including padding, on gpu
TensorPtr cacheIndirection; // [batchSize, beamWidth, maxSeqLen], k/v indirection for next generation step, on gpu
TensorPtr logProbs; // [BS, BM, MSL], must be float*
TensorPtr cumLogProbs; // [BS, BM], optional for sampling
TensorPtr parentIds; // [BS, BM, MSL]
TensorPtr lengths; // [BS, BM], total sequence lengths including padding
TensorPtr cacheIndirection; // [BS, BM, MSL], k/v indirection for next generation step
BeamHypotheses beamHypotheses;
@ -89,10 +92,10 @@ public:
class MedusaOutputs
{
public:
TensorPtr medusaNextDraftTokens; // [maxBatchSize, maxTokensPerStep], on gpu
TensorPtr medusaAcceptedTokensLen; // [maxBatchSize], on gpu
TensorPtr medusaAcceptedLengthsCumSum; // [maxBatchSize + 1], on gpu
TensorPtr medusaPathsOffsets; // [maxBatchSize * maxNumHeads], on gpu
TensorPtr medusaNextDraftTokens; // [maxBatchSize, maxTokensPerStep]
TensorPtr medusaAcceptedTokensLen; // [maxBatchSize]
TensorPtr medusaAcceptedLengthsCumSum; // [maxBatchSize + 1]
TensorPtr medusaPathsOffsets; // [maxBatchSize * maxNumHeads]
};
std::optional<MedusaOutputs> medusaOutputs;

View File

@ -21,7 +21,7 @@
#include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/decodingOutput.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <curand_kernel.h>

View File

@ -43,7 +43,7 @@ public:
//! Setup the decoder before calling `forward()`
void setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep, bool fusedDecoder,
nvinfer1::DataType dtype, GptModelConfig const& modelConfig) override;
nvinfer1::DataType dtype, ModelConfig const& modelConfig) override;
void newBatch(
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig) override;
@ -182,7 +182,7 @@ private:
void allocateMedusaBuffers();
//! @brief Setup buffers for medusa decoding.
void setupMedusa(GptModelConfig const& modelConfig);
void setupMedusa(ModelConfig const& modelConfig);
//! @brief Setups decoder internal tensors for new speculative decoding request
void newRequestSpeculativeDecoding(

View File

@ -17,7 +17,7 @@
#pragma once
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <filesystem>
@ -32,13 +32,13 @@ class GptJsonConfig
{
public:
GptJsonConfig(std::string name, std::string version, std::string precision, SizeType tensorParallelism,
SizeType pipelineParallelism, GptModelConfig const& modelConfig)
SizeType pipelineParallelism, ModelConfig const& modelConfig)
: mName(std::move(name))
, mVersion(std::move(version))
, mPrecision(std::move(precision))
, mTensorParallelism{tensorParallelism}
, mPipelineParallelism{pipelineParallelism}
, mGptModelConfig(modelConfig)
, mModelConfig(modelConfig)
{
}
@ -48,9 +48,9 @@ public:
static GptJsonConfig parse(std::filesystem::path const& path);
[[nodiscard]] GptModelConfig getModelConfig() const
[[nodiscard]] ModelConfig getModelConfig() const
{
return mGptModelConfig;
return mModelConfig;
}
[[nodiscard]] std::string const& getName() const
@ -96,7 +96,7 @@ private:
std::string const mPrecision;
SizeType const mTensorParallelism;
SizeType const mPipelineParallelism;
GptModelConfig const mGptModelConfig;
ModelConfig const mModelConfig;
};
} // namespace tensorrt_llm::runtime
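A minimal sketch of the renamed accessor: parse() loads the engine's JSON config and getModelConfig() now returns a ModelConfig. The "<engineDir>/config.json" layout and the header name are assumptions.
#include <filesystem>
#include "tensorrt_llm/runtime/gptJsonConfig.h" // assumed header name
using namespace tensorrt_llm::runtime;
ModelConfig loadModelConfig(std::filesystem::path const& engineDir)
{
    // parse() is the static factory declared above; config.json is the conventional file name.
    auto const jsonConfig = GptJsonConfig::parse(engineDir / "config.json");
    return jsonConfig.getModelConfig();
}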

View File

@ -14,6 +14,13 @@
* limitations under the License.
*/
/*****************************************************************************
*
* GptSession is going to be deprecated soon.
* Please do not add new functionality in this file!
*
*****************************************************************************/
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
@ -23,8 +30,8 @@
#include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/generationInput.h"
#include "tensorrt_llm/runtime/generationOutput.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
@ -150,17 +157,17 @@ public:
//! @param engineBuffer The compiled TensorRT engine (const void*),
//! @param engineSize The size in bytes of the TensorRT engine (size_t),
//! @param logger The optional logger.
GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
GptSession(Config const& sessionConfig, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
void const* engineBuffer, std::size_t engineSize, LoggerPtr logger = nullptr);
GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
GptSession(Config const& sessionConfig, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
std::vector<uint8_t> const& engineBuffer, LoggerPtr logger = nullptr)
: GptSession(
sessionConfig, modelConfig, worldConfig, engineBuffer.data(), engineBuffer.size(), std::move(logger))
{
}
GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
GptSession(Config const& sessionConfig, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
std::string const& engineFile, LoggerPtr logger = nullptr)
: GptSession(sessionConfig, modelConfig, worldConfig, utils::loadEngine(engineFile), std::move(logger))
{
@ -170,7 +177,7 @@ public:
[[nodiscard]] BufferManager const& getBufferManager() const;
[[nodiscard]] GptModelConfig const& getModelConfig() const
[[nodiscard]] ModelConfig const& getModelConfig() const
{
return mModelConfig;
}
@ -335,7 +342,7 @@ private:
friend class batch_manager::TrtGptModelV1;
private:
GptModelConfig const mModelConfig;
ModelConfig const mModelConfig;
WorldConfig const mWorldConfig;
int mDevice{-1};
std::shared_ptr<NcclCommunicator> mPipelineComm;

View File

@ -18,6 +18,7 @@
#include "tensorrt_llm/common/arrayView.h"
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/kernels/kvCacheIndex.h"
#include <NvInferRuntime.h>
@ -307,6 +308,12 @@ struct TRTDataType<__nv_fp8_e4m3>
};
#endif
template <>
struct TRTDataType<kernels::KVCacheIndex>
{
static constexpr auto value = TRTDataType<kernels::KVCacheIndex::UnderlyingType>::value;
};
template <>
struct TRTDataType<void*>
{

View File

@ -75,7 +75,7 @@ public:
//! Setup the decoder before calling `forward()`, also calls reshapeBuffers
virtual void setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth,
SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep,
bool fusedDecoder, nvinfer1::DataType dtype, GptModelConfig const& modelConfig)
bool fusedDecoder, nvinfer1::DataType dtype, ModelConfig const& modelConfig)
= 0;
//! @brief Initialize the decoder with new batch of inputs.

View File

@ -18,10 +18,10 @@
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/loraCachePageManagerConfig.h"
#include "tensorrt_llm/runtime/loraModule.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <NvInferRuntimeBase.h>
#include <deque>
@ -159,11 +159,11 @@ public:
/**
* param[in] pageManagerConfig: a LoraCachePageManagerConfig
* param[in] modelConfig: a GptModelConfig
* param[in] modelConfig: a ModelConfig
* param[in] worldConfig: a WorldConfig
* param[in] bufferManager: a BufferManager only used to allocate page blocks
*/
LoraCache(LoraCachePageManagerConfig const& pageManagerConfig, GptModelConfig const& modelConfig,
LoraCache(LoraCachePageManagerConfig const& pageManagerConfig, ModelConfig const& modelConfig,
WorldConfig const& worldConfig, BufferManager const& bufferManager);
/**
@ -277,7 +277,7 @@ public:
* \brief Copy task weights to cache pages.
* \param[in] weights: task weights
* \param[in] config: task config tensor
* \param[in] modelConfig: a GptModelConfig
* \param[in] modelConfig: a ModelConfig
* \param[in] worldConfig: a WorldConfig
* \param[in] modelIdToModel: map from lora module id to LoraModule
* \param[in] manager: a BufferManager the manager to use to perform the copies
@ -286,7 +286,7 @@ public:
* \returns -- list of cache Values objects
*/
static std::vector<LoraCache::TaskLayerModuleConfig> copyToPages(TensorPtr weights, TensorPtr config,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
ModelConfig const& modelConfig, WorldConfig const& worldConfig,
std::unordered_map<SizeType, LoraModule> moduleIdToModel, BufferManager const& manager,
std::vector<TensorPtr> const& pages, std::vector<std::size_t> const& pageIds);
@ -385,7 +385,7 @@ private:
};
LoraCachePageManagerConfig mPageManagerConfig;
GptModelConfig mModelConfig;
ModelConfig mModelConfig;
WorldConfig mWorldConfig;
// Protects mCachePageManager

View File

@ -32,7 +32,7 @@ struct MambaConfig
SizeType expand = 0;
};
class GptModelConfig
class ModelConfig
{
public:
enum class ModelVariant : std::int32_t
@ -42,10 +42,11 @@ public:
kMamba = 2, // https://github.com/state-spaces/mamba
};
explicit GptModelConfig(
SizeType vocabSize, SizeType nbLayers, SizeType nbHeads, SizeType hiddenSize, nvinfer1::DataType dtype)
explicit ModelConfig(SizeType vocabSize, SizeType nbAttentionLayers, SizeType nbSsmLayers, SizeType nbHeads,
SizeType hiddenSize, nvinfer1::DataType dtype)
: mVocabSize(vocabSize)
, mNbLayers(nbLayers)
, mNbAttentionLayers(nbAttentionLayers)
, mNbSsmLayers(nbSsmLayers)
, mNbHeads(nbHeads)
, mNbKvHeads(nbHeads)
, mHiddenSize(hiddenSize)
@ -71,6 +72,7 @@ public:
, mMaxDraftLen(0)
, mUseContextFMHAForGeneration(false)
, mPagedContextFMHA(false)
, mUseXQA{false}
, mUseLoraPlugin(false)
, mMlpHiddenSize(0)
, mMedusaModule(std::nullopt)
@ -87,10 +89,16 @@ public:
return (mVocabSize + worldSize - 1) / worldSize * worldSize;
}
[[nodiscard]] SizeType constexpr getNbLayers(SizeType pipelineParallelism = 1) const
[[nodiscard]] SizeType constexpr getNbAttentionLayers(SizeType pipelineParallelism = 1) const
{
TLLM_CHECK(mNbLayers % pipelineParallelism == 0);
return mNbLayers / pipelineParallelism;
TLLM_CHECK(mNbAttentionLayers % pipelineParallelism == 0);
return mNbAttentionLayers / pipelineParallelism;
}
[[nodiscard]] SizeType constexpr getNbSsmLayers(SizeType pipelineParallelism = 1) const
{
TLLM_CHECK(mNbSsmLayers % pipelineParallelism == 0);
return mNbSsmLayers / pipelineParallelism;
}
[[nodiscard]] SizeType constexpr getNbHeads() const noexcept
@ -344,6 +352,16 @@ public:
return mPagedContextFMHA;
}
void constexpr useXQA(bool useXQA) noexcept
{
mUseXQA = useXQA;
}
[[nodiscard]] bool constexpr useXQA() const noexcept
{
return mUseXQA;
}
[[nodiscard]] bool constexpr useLoraPlugin() const noexcept
{
return mUseLoraPlugin;
@ -354,7 +372,7 @@ public:
mUseLoraPlugin = useLoraPlugin;
}
std::vector<LoraModule> const& getLoraModules() const noexcept
[[nodiscard]] std::vector<LoraModule> const& getLoraModules() const noexcept
{
return mLoraModules;
}
@ -442,7 +460,8 @@ public:
private:
SizeType mVocabSize;
SizeType mNbLayers;
SizeType mNbAttentionLayers;
SizeType mNbSsmLayers;
SizeType mNbHeads;
SizeType mNbKvHeads;
SizeType mHiddenSize;
@ -471,6 +490,7 @@ private:
bool mUseContextFMHAForGeneration;
bool mPagedContextFMHA;
bool mUseXQA;
bool mUseLoraPlugin;
std::vector<LoraModule> mLoraModules;

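A brief sketch of the split layer accessors introduced above: attention and SSM (Mamba) layers are now counted separately, and both getters divide by the pipeline-parallel degree. The helper below is illustrative only.
#include "tensorrt_llm/runtime/modelConfig.h"
using tensorrt_llm::runtime::ModelConfig;
using tensorrt_llm::runtime::SizeType;
// Returns the number of layers a single pipeline rank owns; both accessors TLLM_CHECK that the
// layer count is divisible by the pipeline-parallel degree.
SizeType layersPerRank(ModelConfig const& config, SizeType pipelineParallelism)
{
    return config.getNbAttentionLayers(pipelineParallelism) + config.getNbSsmLayers(pipelineParallelism);
}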
View File

@ -17,6 +17,7 @@
#pragma once
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/runtime/common.h"
#include <functional>
@ -36,25 +37,21 @@ private:
template <typename T>
static OptVec<T> fuseValues(
std::vector<SamplingConfig> const& configs, std::function<OptVec<T>(SizeType ci)> accessor)
std::vector<SamplingConfig> const& configs, std::function<OptVec<T>(size_t ci)> accessor, T defaultValue)
{
std::vector<T> values;
auto const hasValues = accessor(0).has_value();
for (size_t ci = 0; ci < configs.size(); ++ci)
{
auto value = defaultValue;
auto const& configValue = accessor(ci);
TLLM_CHECK(hasValues == configValue.has_value());
if (hasValues)
if (configValue.has_value())
{
TLLM_CHECK(configValue.value().size() == 1);
values.push_back(configValue.value().front());
value = configValue.value().front();
}
values.push_back(value);
}
if (!hasValues)
{
return std::nullopt;
}
return std::make_optional<std::vector<T>>(values);
}
@ -72,26 +69,52 @@ public:
TLLM_CHECK(configs.size() > 0);
beamWidth = configs.front().beamWidth;
normalizeLogProbs = configs.front().normalizeLogProbs;
temperature = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].temperature; });
minLength = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].minLength; });
repetitionPenalty
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].repetitionPenalty; });
presencePenalty
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].presencePenalty; });
topK = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].topK; });
topP = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].topP; });
randomSeed = fuseValues<uint64_t>(configs, [&configs](SizeType ci) { return configs[ci].randomSeed; });
topPDecay = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].topPDecay; });
topPMin = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].topPMin; });
topPResetIds = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].topPResetIds; });
beamSearchDiversityRate
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].beamSearchDiversityRate; });
lengthPenalty = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].lengthPenalty; });
earlyStopping = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].earlyStopping; });
draftAcceptanceThreshold
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].draftAcceptanceThreshold; });
topKMedusaHeads = fuseValues<std::vector<runtime::SizeType>>(
configs, [&configs](SizeType ci) { return configs[ci].topKMedusaHeads; });
temperature = fuseValues<FloatType>(
configs, [&configs](size_t ci) { return configs[ci].temperature; },
layers::DefaultDecodingParams::getTemperature());
minLength = fuseValues<SizeType32>(
configs, [&configs](size_t ci) { return configs[ci].minLength; },
layers::DefaultDecodingParams::getMinLength());
repetitionPenalty = fuseValues<FloatType>(
configs, [&configs](size_t ci) { return configs[ci].repetitionPenalty; },
layers::DefaultDecodingParams::getRepetitionPenalty());
presencePenalty = fuseValues<FloatType>(
configs, [&configs](size_t ci) { return configs[ci].presencePenalty; },
layers::DefaultDecodingParams::getPresencePenalty());
frequencyPenalty = fuseValues<FloatType>(
configs, [&configs](size_t ci) { return configs[ci].frequencyPenalty; },
layers::DefaultDecodingParams::getFrequencyPenalty());
topK = fuseValues<SizeType32>(
configs, [&configs](size_t ci) { return configs[ci].topK; }, layers::DefaultDecodingParams::getTopK());
topP = fuseValues<FloatType>(
configs, [&configs](size_t ci) { return configs[ci].topP; }, layers::DefaultDecodingParams::getTopP());
randomSeed = fuseValues<uint64_t>(
configs, [&configs](size_t ci) { return configs[ci].randomSeed; },
layers::DefaultDecodingParams::getSeed());
topPDecay = fuseValues<FloatType>(
configs, [&configs](size_t ci) { return configs[ci].topPDecay; },
layers::DefaultDecodingParams::getTopPDecay());
topPMin = fuseValues<FloatType>(
configs, [&configs](size_t ci) { return configs[ci].topPMin; },
layers::DefaultDecodingParams::getTopPMin());
topPResetIds = fuseValues<TokenIdType>(
configs, [&configs](size_t ci) { return configs[ci].topPResetIds; },
layers::DefaultDecodingParams::getTopPResetId());
beamSearchDiversityRate = fuseValues<FloatType>(
configs, [&configs](size_t ci) { return configs[ci].beamSearchDiversityRate; },
layers::DefaultDecodingParams::getBeamSearchDiversity());
lengthPenalty = fuseValues<FloatType>(
configs, [&configs](size_t ci) { return configs[ci].lengthPenalty; },
layers::DefaultDecodingParams::getLengthPenalty());
earlyStopping = fuseValues<SizeType32>(
configs, [&configs](size_t ci) { return configs[ci].earlyStopping; },
layers::DefaultDecodingParams::getEarlyStopping());
topKMedusaHeads = fuseValues<std::vector<SizeType32>>(
configs, [&configs](size_t ci) { return configs[ci].topKMedusaHeads; },
layers::DefaultDecodingParams::getTopKMedusaHeads());
// Only used for tests.
draftAcceptanceThreshold = fuseValues<FloatType>(
configs, [&configs](size_t ci) { return configs[ci].draftAcceptanceThreshold; }, 0);
}
explicit SamplingConfig(executor::SamplingConfig const& samplingConfig,
@ -148,13 +171,13 @@ public:
// beam search layer
OptVec<FloatType> beamSearchDiversityRate; // [1] or [batch_size]
OptVec<FloatType> lengthPenalty; // [1] or [batch_size]
OptVec<SizeType> earlyStopping; // [1] or [batch_size]
OptVec<SizeType32> earlyStopping; // [1] or [batch_size]
// speculative decoding, only the first value is used (in gptDecoderBatch.cpp)
OptVec<FloatType> draftAcceptanceThreshold; // [1] or [batch_size]
// medusa params
OptVec<std::vector<runtime::SizeType>> topKMedusaHeads; // [batchSize, maxMedusaHeads]
OptVec<std::vector<runtime::SizeType32>> topKMedusaHeads; // [batchSize, maxMedusaHeads]
std::optional<bool> normalizeLogProbs;

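A standalone sketch of the new fuseValues() semantics: a per-request optional contributes its own value when set and the decoder default otherwise, instead of the old all-or-nothing behaviour. Names mirror the header, but this is a simplified illustration rather than the library code.
#include <functional>
#include <optional>
#include <vector>
template <typename T>
using OptVec = std::optional<std::vector<T>>;
template <typename T>
OptVec<T> fuseValues(size_t numConfigs, std::function<OptVec<T>(size_t)> accessor, T defaultValue)
{
    std::vector<T> values;
    for (size_t ci = 0; ci < numConfigs; ++ci)
    {
        auto value = defaultValue;
        if (auto const& configValue = accessor(ci); configValue.has_value())
        {
            value = configValue.value().front(); // per-request configs carry a single value
        }
        values.push_back(value);
    }
    return std::make_optional(values);
}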
View File

@ -30,6 +30,7 @@ add_subdirectory(common)
add_subdirectory(kernels)
add_subdirectory(layers)
add_subdirectory(runtime)
add_subdirectory(executor_worker)
set(BATCH_MANAGER_TARGET tensorrt_llm_batch_manager_static)
set(BATCH_MANAGER_TARGET_ARCH "unknown")
@ -196,8 +197,9 @@ set(TRTLLM_LINK_LIBS
kernels_src
context_attention_src
decoder_attention_src
cutlass2_src
cutlass3_src
fpA_intB_gemm_src
moe_gemm_src
cutlass_src
layers_src
runtime_src)
@ -218,44 +220,31 @@ set_target_properties(
PROPERTIES CXX_STANDARD "17" CXX_STANDARD_REQUIRED "YES" CXX_EXTENSIONS "NO"
LINK_FLAGS "${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}")
function(link_whole_archive TARGET LIBRARY_TO_LINK)
if(WIN32)
target_link_libraries(${TARGET} PUBLIC $<TARGET_FILE:${LIBRARY_TO_LINK}>)
set_target_properties(
${TARGET} PROPERTIES LINK_FLAGS "/WHOLEARCHIVE:${LIBRARY_TO_LINK}")
else()
# Assume everything else is like gcc
target_link_libraries(
${TARGET} PRIVATE "-Wl,--whole-archive" $<TARGET_FILE:${LIBRARY_TO_LINK}>
"-Wl,--no-whole-archive")
endif()
endfunction()
target_link_libraries(${SHARED_TARGET} PUBLIC ${TRTLLM_LINK_LIBS})
if(WIN32)
target_link_libraries(${SHARED_TARGET}
PUBLIC $<TARGET_FILE:${BATCH_MANAGER_TARGET}>)
set_target_properties(
${SHARED_TARGET} PROPERTIES LINK_FLAGS
"/WHOLEARCHIVE:${BATCH_MANAGER_TARGET}")
else()
# Assume everything else is like gcc
target_link_libraries(
${SHARED_TARGET}
PRIVATE "-Wl,--whole-archive" $<TARGET_FILE:${BATCH_MANAGER_TARGET}>
"-Wl,--no-whole-archive")
endif()
if(WIN32)
target_link_libraries(${SHARED_TARGET}
PUBLIC $<TARGET_FILE:${EXECUTOR_TARGET}>)
set_target_properties(
${SHARED_TARGET} PROPERTIES LINK_FLAGS "/WHOLEARCHIVE:${EXECUTOR_TARGET}")
else()
# Assume everything else is like gcc
target_link_libraries(
${SHARED_TARGET}
PRIVATE "-Wl,--whole-archive" $<TARGET_FILE:${EXECUTOR_TARGET}>
"-Wl,--no-whole-archive")
endif()
link_whole_archive(${SHARED_TARGET} ${BATCH_MANAGER_TARGET})
link_whole_archive(${SHARED_TARGET} ${EXECUTOR_TARGET})
# Cyclic dependency of batch manager on TRT-LLM
target_link_libraries(${BATCH_MANAGER_TARGET} INTERFACE ${SHARED_TARGET})
# Cyclic dependency of executor on TRT-LLM
target_link_libraries(${EXECUTOR_TARGET} INTERFACE ${SHARED_TARGET})
add_dependencies(${SHARED_TARGET} check_symbol)
add_dependencies(${SHARED_TARGET} check_symbol_executor)
# Cyclic dependency of batch manager on TRT-LLM
target_link_libraries(${BATCH_MANAGER_TARGET} INTERFACE ${SHARED_TARGET})
# Cyclic dependency of executor on TRT-LLM
target_link_libraries(${EXECUTOR_TARGET} INTERFACE ${SHARED_TARGET})
if(BUILD_PYT)
add_subdirectory(thop)
endif()

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6bd5ec7130a703889eb51fe6591c93a079ded644ca089099efe5e3d72474838e
size 2896708
oid sha256:d8a083974ff58e74dec95d1ad438bf84be9adeedeb20b5e7254fe56d6a4bf40c
size 2997970

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d25d35be9ec13d1f0a0b9f3ed40362879d9ac50bdfcdcb827990554a26ff5c10
size 2923694
oid sha256:40cace20ce33a945ed12a2a2e382053aa90113d8bed2623c985dbb60b943251e
size 3034874

View File

@ -1,3 +1,3 @@
cafe56cc4a916b91ea338a8412c79fef libtensorrt_llm_batch_manager_static.a
3274866669694da8f09e30388939b7dd libtensorrt_llm_batch_manager_static.pre_cxx11.a
165fe125d6bf55090d8a7dec012d08f8d0e7a54b commit
7c5e14e8ed4e3e0641a8aefa659a03c0 libtensorrt_llm_batch_manager_static.a
79a986633cb1f0dc6621423bbbf21727 libtensorrt_llm_batch_manager_static.pre_cxx11.a
83029c1606a00e0e4aaf5ea2de17867a6e5ddd9b commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:27dbbdae087a946d1762f11efe953a1b1b282e27747708145c405e9380fce287
size 2822910
oid sha256:913f548b9f66aaea93baaa40bd7ca37f4fb0b52f5ed0778b1fe52c136141433c
size 2916334

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:622724d6b9219dd3d4710a822ca92d497c466cdc34149258f9559c08f4470f8e
size 2796594
oid sha256:8dd40bb9cafae379971b365c8206fd20addb7816c64953456568110e5f694b0e
size 2900610

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:296c78f2c29774fab2145465a9a515a7e4aaedde96ba3c3f6fa5af91fa92dee6
size 18976374
oid sha256:889f62ee370c0a00c1ccfc26e82fcd1410413e44e6d955aca12a90c906e89239
size 18428048

View File

@ -62,6 +62,7 @@ CUDADriverWrapper::CUDADriverWrapper()
*(void**) (&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
*(void**) (&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
*(void**) (&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
*(void**) (&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
}
CUDADriverWrapper::~CUDADriverWrapper()
@ -143,5 +144,14 @@ CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX,
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
}
CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const
{
return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides,
boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill);
}
} // namespace common
} // namespace tensorrt_llm

View File

@ -70,6 +70,11 @@ public:
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
CUstream hStream, void** kernelParams, void** extra) const;
CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank,
void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim,
cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const;
private:
void* handle;
CUresult (*_cuGetErrorName)(CUresult, char const**);
@ -89,6 +94,10 @@ private:
CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
CUstream hStream, void** kernelParams, void** extra);
CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
};
inline void cuErrCheck_(CUresult stat, CUDADriverWrapper const& wrap, char const* file, int line)

View File

@ -22,21 +22,39 @@
namespace tensorrt_llm::common
{
static std::optional<int32_t> getIntEnv(char const* name)
{
char const* const env = std::getenv(name);
if (env == nullptr)
{
return std::nullopt;
}
int32_t const val = std::stoi(env);
if (val <= 0)
{
return std::nullopt;
}
return {val};
};
// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels()
{
char const* force_xqa_env_var = getenv("TRTLLM_FORCE_XQA");
static bool forceXQA = false;
if (force_xqa_env_var != nullptr)
{
if (force_xqa_env_var[0] == '1' && force_xqa_env_var[1] == '\0')
{
forceXQA = true;
}
}
static bool const forceXQA = (getIntEnv("TRTLLM_FORCE_XQA").value_or(0) != 0);
return forceXQA;
}
int32_t xqaMaxNbCtaPerKVHeadFactor()
{
return envXqaNbCtaPerKVHead().value_or(8);
}
std::optional<int32_t> envXqaNbCtaPerKVHead()
{
static std::optional<int32_t> const ret = getIntEnv("TRTLLM_XQA_BLOCKS_PER_SEQUENCE");
return ret;
}
// Tune the number of blocks per sequence for accuracy/performance purposes.
bool getEnvMmhaMultiblockDebug()
{

View File

@ -16,6 +16,8 @@
*/
#pragma once
#include <cstdint>
#include <optional>
namespace tensorrt_llm::common
{
@ -23,6 +25,14 @@ namespace tensorrt_llm::common
// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels();
// Maximum number of CTAs per KV head; running multiple CTAs for one KV head is the multi-block mode.
// This limit applies when both max_batch_size and max_beam_width are reached.
// If batch_size or beam_width is below its maximum, more CTAs per KV head than this value are possible.
int32_t xqaMaxNbCtaPerKVHeadFactor();
std::optional<int32_t> envXqaNbCtaPerKVHead();
// Tune the number of blocks per sequence for accuracy/performance purposes.
bool getEnvMmhaMultiblockDebug();

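A small sketch of how a dispatcher might combine the helpers declared above; the clamping policy and the "envUtils.h" header name are assumptions made for illustration.
#include <algorithm>
#include <cstdint>
#include "tensorrt_llm/common/envUtils.h" // assumed header name for these declarations
int32_t pickNbCtaPerKVHead(int32_t heuristicValue)
{
    namespace tc = tensorrt_llm::common;
    // An explicit TRTLLM_XQA_BLOCKS_PER_SEQUENCE setting wins; otherwise clamp the heuristic to
    // the per-KV-head maximum (8 unless overridden).
    return tc::envXqaNbCtaPerKVHead().value_or(std::min(heuristicValue, tc::xqaMaxNbCtaPerKVHeadFactor()));
}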
View File

@ -19,6 +19,7 @@
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include <csignal>
#include <mpi.h>
@ -35,7 +36,7 @@ namespace tensorrt_llm::mpi
MPI_Datatype getMpiDtype(MpiType dtype)
{
static const std::unordered_map<MpiType, MPI_Datatype> dtype_map{
static std::unordered_map<MpiType, MPI_Datatype> const dtype_map{
{MpiType::kBYTE, MPI_BYTE},
{MpiType::kHALF, MPI_UINT16_T},
@ -57,7 +58,7 @@ MPI_Datatype getMpiDtype(MpiType dtype)
MPI_Op getMpiOp(MpiOp op)
{
static const std::unordered_map<MpiOp, MPI_Op> op_map{
static std::unordered_map<MpiOp, MPI_Op> const op_map{
{MpiOp::NULLOP, MPI_OP_NULL},
{MpiOp::MAX, MPI_MAX},
{MpiOp::MIN, MPI_MIN},
@ -122,16 +123,33 @@ std::shared_ptr<MpiRequest> MpiComm::bcastAsync(void* buffer, size_t size, MpiTy
return r;
}
std::shared_ptr<MpiRequest> MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const
{
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}
void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
{
MPICHECK(MPI_Bcast(buffer, size, getMpiDtype(dtype), root, mComm));
}
void MpiComm::bcast(runtime::IBuffer& buf, int root) const
{
bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}
void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
{
MPICHECK(MPI_Send(buffer, size, getMpiDtype(dtype), dest, tag, mComm));
}
void MpiComm::send(runtime::IBuffer const& buf, int dest, int tag) const
{
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
}
MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const
{
MPI_Status status{};
@ -139,6 +157,12 @@ MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, i
return status;
}
MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, int tag) const
{
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag);
}
MpiComm MpiComm::split(int color, int key) const
{
MPI_Comm splitComm;

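A minimal sketch of the raw-pointer broadcast shown above; the new IBuffer overloads do the same for host-side runtime buffers. The "mpiUtils.h" header name is an assumption.
#include <cstdint>
#include "tensorrt_llm/common/mpiUtils.h" // assumed header name
namespace tlm = tensorrt_llm::mpi;
// Rank `root` shares a seed with all other ranks in the communicator.
void shareSeed(tlm::MpiComm const& comm, uint64_t& seed, int root = 0)
{
    comm.bcast(&seed, sizeof(seed), tlm::MpiType::kBYTE, root);
}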
View File

@ -63,7 +63,7 @@ int8_t* nextWorkspacePtrWithAlignment(
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment);
}
size_t calculateTotalWorkspaceSize(size_t* workspaces, int count, const uintptr_t alignment = kCudaMemAlign)
size_t calculateTotalWorkspaceSize(size_t const* workspaces, int count, const uintptr_t alignment = kCudaMemAlign)
{
size_t total = 0;
for (int i = 0; i < count; i++)

View File

@ -30,6 +30,7 @@
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "cutlass_extensions/epilogue/thread/fused_activations.h"
#include <cutlass/epilogue/fusion/operations.hpp>
namespace tensorrt_llm
{
@ -48,6 +49,10 @@ struct EpilogueOpBiasFtGelu
{
};
struct EpilogueOpBias
{
};
struct EpilogueOpDefaultSilu
{
};
@ -60,10 +65,6 @@ struct EpilogueOpDefaultFtGelu
{
};
struct EpilogueOpBias
{
};
struct EpilogueOpDefault
{
};
@ -71,6 +72,7 @@ struct EpilogueOpDefault
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator, typename Op>
struct Epilogue
{
static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag");
};
constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling;

View File

@ -36,10 +36,11 @@ namespace kernel
template <typename TypeA, typename TypeB, typename arch, typename Enable = void>
struct MixedGemmArchTraits
{
static_assert(dependent_false<arch>, "Unrecognised parameterization");
};
template <typename arch>
struct MixedGemmArchTraits<float, float, arch>
template <typename Arch>
struct MixedGemmArchTraits<float, float, Arch>
{
static constexpr int Stages = 2;
using OperatorClass = cutlass::arch::OpClassSimt;
@ -66,7 +67,7 @@ struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm70,
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm70>;
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm70>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
@ -92,7 +93,7 @@ struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm75,
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm75>;
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm75>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
@ -116,7 +117,7 @@ struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm80,
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm80>;
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm80>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
@ -133,6 +134,34 @@ public:
using Operator = typename LayoutDetails::Operator;
};
// ======================= Ada Traits ==============================
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value
#ifdef ENABLE_FP8
|| cutlass::platform::is_same<TypeA, cutlass::float_e4m3_t>::value>::type
#endif
>
{
private:
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;
using Operator = typename LayoutDetails::Operator;
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@ -546,8 +546,10 @@ struct GemmFpAIntB
run_kernel<arch::Sm70>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ == 890)
run_kernel<arch::Sm89>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
#else

View File

@ -42,16 +42,16 @@ namespace gemm
namespace kernel
{
template <typename TypeB, typename Arch, typename Enable = void>
template <typename TypeA, typename TypeB, typename Arch, typename Enable = void>
struct LayoutDetailsB
{
};
// Volta specialiations. Volta will dequantize before STS, so we need a different operator
template <typename TypeB>
struct LayoutDetailsB<TypeB, arch::Sm70>
template <typename TypeA, typename TypeB>
struct LayoutDetailsB<TypeA, TypeB, arch::Sm70>
{
static constexpr int ThreadblockK = 64;
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 8;
using Operator = cutlass::arch::OpMultiplyAdd;
@ -59,19 +59,19 @@ struct LayoutDetailsB<TypeB, arch::Sm70>
// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
// TODO - Switch this to column major for weights since gemms should be more performant.
template <typename Arch>
struct LayoutDetailsB<half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 64;
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename Arch>
struct LayoutDetailsB<bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 64;
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
@ -79,11 +79,12 @@ struct LayoutDetailsB<bfloat16_t, Arch, typename platform::enable_if<Arch::kMinC
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template <typename Arch>
struct LayoutDetailsB < uint8_t,
Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
template <typename TypeA, typename Arch>
struct LayoutDetailsB < TypeA,
uint8_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
{
static constexpr int ThreadblockK = 64;
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
private:
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
@ -95,11 +96,12 @@ public:
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
template <typename Arch>
struct LayoutDetailsB < uint4b_t,
Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
template <typename TypeA, typename Arch>
struct LayoutDetailsB < TypeA,
uint4b_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
{
static constexpr int ThreadblockK = 64;
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
private:
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
@ -111,19 +113,19 @@ public:
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
template <typename Arch>
struct LayoutDetailsB<uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
static constexpr int ThreadblockK = 64;
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename Arch>
struct LayoutDetailsB<uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
static constexpr int ThreadblockK = 64;
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;

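A compile-time sketch of the widened LayoutDetailsB interface: ThreadblockK now depends on the activation type TypeA instead of being fixed at 64, so fp16 activations keep K = 64 while fp8 activations (float_e4m3_t on the Ada path) get 128 from the same formula. The include of the trait header itself is assumed to come from the file above.
#include <cutlass/arch/arch.h>
#include <cutlass/numeric_types.h>
// Assumes the LayoutDetailsB header above (and its dependencies) is already on the include path.
// fp16 activations: ThreadblockK = 128 * 8 / 16 = 64, matching the previous hard-coded value.
static_assert(
    cutlass::gemm::kernel::LayoutDetailsB<cutlass::half_t, cutlass::uint4b_t, cutlass::arch::Sm80>::ThreadblockK == 64,
    "ThreadblockK follows 128 * 8 / sizeof_bits<TypeA>");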
View File

@ -35,6 +35,8 @@
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
@ -502,8 +504,7 @@ public:
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
run_kernel<arch::Sm80>(
params, shared_storage); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
run_kernel<arch::Sm80>(params, shared_storage);
#else
static_assert(
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");

View File

@ -38,11 +38,12 @@ namespace threadblock
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
typename Enable = void>
struct DefaultScaleIterators;
struct DefaultScaleIteratorsMultistage;
// Fine grained iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIterators<MmaShape, Element, Layout, QuantOp, Alignment, std::enable_if_t<isFinegrained(QuantOp)>>
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<isFinegrained(QuantOp)>>
{
using IteratorScale
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
@ -53,7 +54,8 @@ struct DefaultScaleIterators<MmaShape, Element, Layout, QuantOp, Alignment, std:
// Per column iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIterators<MmaShape, Element, Layout, QuantOp, Alignment, std::enable_if_t<!isFinegrained(QuantOp)>>
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<!isFinegrained(QuantOp)>>
{
// ThreadMap for scale iterator
static_assert((MmaShape::kN % Alignment) == 0, "");
@ -73,7 +75,7 @@ public:
////////////////////////////////////////////////////////////////////////////////
template <
/// Type for elementA
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
@ -105,9 +107,9 @@ template <
typename InstructionShape,
/// Stages in GEMM
int kStages,
///
/// Operator performed by GEMM
typename Operator_,
///
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
@ -116,8 +118,9 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
ArchTag::kMinComputeCapability >= 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
|| platform::is_same<ElementA, float_e4m3_t>::value,
"Element A must be fp16, fp8 or bf16");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
@ -155,7 +158,7 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, LayoutB, 0, ThreadMapB,
AccessTypeB>;
using ScaleIterators = DefaultScaleIterators<typename MmaCore::Shape, ElementScale, LayoutScale,
using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
// Define iterators over tiles from the scale operand
@ -163,7 +166,7 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementA, ElementB,
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
// Define the threadblock-scoped pipelined matrix multiply
@ -173,6 +176,7 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
};
// Specialization to handle column major interleave B
template <
/// Type for element A
typename ElementA,
@ -206,9 +210,9 @@ template <
typename InstructionShape,
/// Stages in GEMM
int kStages,
///
/// Operator performed by GEMM
typename Operator_,
///
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
@ -217,8 +221,9 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
ArchTag::kMinComputeCapability >= 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
|| platform::is_same<ElementA, float_e4m3_t>::value,
"Element A must be fp16, fp8 or bf16");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
@ -274,7 +279,7 @@ public:
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<GmemIteratorShape, ElementB,
layout::ColumnMajor, 0, GmemThreadMapB, AccessTypeB>;
using ScaleIterators = DefaultScaleIterators<typename MmaCore::Shape, ElementScale, LayoutScale,
using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
// Define iterators over tiles from the scale operand
@ -282,7 +287,7 @@ public:
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementA, ElementB,
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
// Define the threadblock-scoped pipelined matrix multiply

View File

@ -35,6 +35,42 @@ namespace threadblock
////////////////////////////////////////////////////////////////////////////////
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
typename Enable = void>
struct DefaultScaleIteratorsPipelined;
// TODO: Fine grained iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<isFinegrained(QuantOp)>>
{
};
// Per column iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<!isFinegrained(QuantOp)>>
{
static_assert((MmaShape::kN % Alignment) == 0, "");
private:
// ThreadMap for scale iterator
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
MmaShape::kN / Alignment, Alignment>;
using SmemScaleType = half_t;
public:
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
Element, Layout, 0, IteratorScaleThreadMap, Alignment>;
using SmemIteratorScale
= cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>, SmemScaleType,
Layout, 0, IteratorScaleThreadMap, Alignment>;
};
////////////////////////////////////////////////////////////////////////////////
template <
/// Type for element A
typename ElementA,
@ -86,8 +122,7 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
using MmaCoreElementA = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
using MmaCoreElementA = half_t;
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
// Define the MmaCore components
@ -105,21 +140,13 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, ElementB, LayoutB, 0,
typename MmaCore::IteratorThreadMapB, kAlignmentB>;
// ThreadMap for scale iterator
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
using IteratorScaleThreadMap
= transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>;
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale
= cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
ElementScale, LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>;
using IteratorScale = typename ScaleIterators::IteratorScale;
using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
using SmemIteratorScale
= cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
SmemScaleType, LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
@ -182,8 +209,7 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
using MmaCoreElementA = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
using MmaCoreElementA = half_t;
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
// Define the MmaCore components
@ -225,15 +251,13 @@ public:
= transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale
= cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
ElementScale, LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>;
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
using SmemIteratorScale
= cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
SmemScaleType, LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;

View File

@ -29,7 +29,7 @@ namespace threadblock
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2)
template <
/// Layout type for A matrix operand
typename LayoutA,
@ -77,7 +77,7 @@ public:
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2)
template <
/// Layout type for A matrix operand
typename LayoutA,
@ -124,6 +124,9 @@ public:
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
@ -176,7 +179,8 @@ public:
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
@ -228,6 +232,63 @@ public:
using ThreadblockMma = typename Mma::ThreadblockMma;
};
#ifdef ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t,
layout::RowMajor, kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
ThreadblockShape, WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
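// Explanatory note (not in the original source): this specialization routes fp8 activations with
// packed uint4 weights through the same DqMma pipeline; the scales and the dequantized B fragments
// stay in half_t, which is why kAlignmentScale above is derived from sizeof_bits<half_t> rather
// than from the fp8 activation type.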
#endif
// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
// large tile when not enough shared mem is present to do 3+ stage
template <

View File

@ -86,7 +86,7 @@ template <
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
@ -189,8 +189,9 @@ private:
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
@ -218,7 +219,7 @@ public:
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage,
/// The group size for quantization
int group_size,
int const group_size,
///< ID within the threadblock
int thread_idx,
///< ID of warp
@ -279,10 +280,21 @@ public:
}
else if (iterator_scale.group_size_ == 128)
{
if (iterator_scale.row_groupsize64_ & 0x1)
if constexpr (Shape::kK == 128)
{
iterator_scale.add_tile_offset({1, 0});
}
else if constexpr (Shape::kK == 64)
{
if (iterator_scale.row_groupsize64_ & 0x1)
{
iterator_scale.add_tile_offset({1, 0});
}
}
else
{
static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
}
}
iterator_scale.row_groupsize64_++;
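// Explanatory note: with group_size == 128 the scale row advances once per 128 elements of K, so a
// threadblock k-tile of 128 advances the scale iterator on every mainloop iteration, while a
// k-tile of 64 advances it only on every other iteration, tracked by the parity of
// row_groupsize64_.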
@ -291,8 +303,8 @@ public:
}
CUTLASS_DEVICE
void copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, IteratorScale& iterator_scale,
int group_start_A = 0, int group_start_B = 0)
void copy_tiles_and_advance(
IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
{
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
@ -414,8 +426,6 @@ public:
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
@ -586,8 +596,17 @@ public:
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros);
run_warp_mma(
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
using FragmentOperandB = cutlass::Array<ElementA, Operator::FragmentB::kElements>;
constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
static_assert(ConversionVectorWidth == FragmentOperandB::kElements);
using Converter
= cutlass::NumericArrayConverter<ElementA, ElementScale, ConversionVectorWidth, RoundStyle>;
FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
warp_tileB_k_compute_offset);
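// Explanatory note: the dequantized B fragment is produced in the scale precision (ElementScale,
// e.g. half for the W4A8 path), so it is converted element-wise to the A-operand type expected by
// the underlying tensor-core MMA before run_warp_mma is issued.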
// Issue global->shared copies for this stage
if (warp_mma_k < Base::kWarpGemmIterations - 1)
@ -597,8 +616,7 @@ public:
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(
iterator_A, iterator_B, iterator_scale, group_start_iteration_A, group_start_iteration_B);
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
// This is the first group of a given stage, so we issue the loads for the B scales immediately.
if (group_start_iteration_B == 0)
@ -613,8 +631,7 @@ public:
group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(
iterator_A, iterator_B, iterator_scale, group_start_iteration_A, group_start_iteration_B);
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();

View File

@ -190,8 +190,9 @@ private:
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
@ -482,7 +483,7 @@ public:
}
}
// Waits until kStages-2 stages have committed.
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
@ -548,8 +549,17 @@ public:
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
run_warp_mma(
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
using FragmentOperandB = cutlass::Array<ElementA, Operator::FragmentB::kElements>;
constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
static_assert(ConversionVectorWidth == FragmentOperandB::kElements);
using Converter
= cutlass::NumericArrayConverter<ElementA, ElementScale, ConversionVectorWidth, RoundStyle>;
FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
warp_tileB_k_compute_offset);
// Issue global->shared copies for this stage
if (warp_mma_k < Base::kWarpGemmIterations - 1)
@ -573,7 +583,8 @@ public:
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Waits until kStages-2 stages have committed.
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 -
// #committed)
arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();

View File

@ -161,8 +161,9 @@ private:
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;

View File

@ -82,7 +82,7 @@ private:
using ComputeInstructionShape = InstructionShape_;
// Chosen so we get K=16 for int8 and K=32 for int4.
static constexpr int LoadInstructionK = 8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value;
static constexpr int LoadInstructionK = 128 / sizeof_bits<ElementB>::value;
// Shape for loading the narrow data type from shared memory
using LoadInstructionShape = GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;
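// Explanatory note: 128 / sizeof_bits<ElementB> gives K = 16 for 8-bit weights and K = 32 for
// 4-bit weights, matching the comment above, and unlike the previous ElementA-based formula it
// stays correct when the activation type shrinks to fp8.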

View File

@ -46,6 +46,7 @@
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"
#include "cutlass/arch/mma_sm89.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"
@ -131,12 +132,16 @@ public:
&& platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
|| (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
&& platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
&& ArchTag::kMinComputeCapability >= 80),
"MmaTensorOpCvtBToA only supports underlying HMMA");
&& ArchTag::kMinComputeCapability >= 80)
|| (platform::is_same<typename ArchMmaOperator::ElementA, float_e4m3_t>::value
&& platform::is_same<typename ArchMmaOperator::ElementB, float_e4m3_t>::value
&& ArchTag::kMinComputeCapability >= 89),
"MmaTensorOpCvtBToA only supports underlying HMMA/QMMA");
static_assert(platform::is_same<ElementA, half_t>::value
|| (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80),
"MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+");
|| (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80)
|| (platform::is_same<ElementA, float_e4m3_t>::value && ArchTag::kMinComputeCapability >= 89),
"MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada");
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassTensorOp;

View File

@ -367,7 +367,8 @@ public:
void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
{
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
using ExpandedMmaOperandB
= Array<typename FragmentDequantizedOperand::Element, kExpansionFactor * _MmaOperandB::kElements>;
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
== FragmentDequantizedOperand::kElements,
"");
@ -409,7 +410,8 @@ public:
FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
{
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
using ExpandedMmaOperandB
= Array<typename FragmentDequantizedOperand::Element, kExpansionFactor * _MmaOperandB::kElements>;
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
== FragmentDequantizedOperand::kElements,
"");

View File

@ -111,6 +111,16 @@ enum class ClusterShape
struct CutlassGemmConfig
{
enum CandidateConfigTypeParam : int
{
NONE = 0,
WEIGHT_ONLY = 1u << 0,
SIMT_ONLY = 1u << 1,
INT8_ONLY = 1u << 2,
HOPPER = 1u << 3,
GROUPED_GEMM = 1u << 4,
};
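// Illustrative usage (hypothetical call site, not part of this header): the values are bit masks
// and may be OR-ed together when requesting candidate configs, e.g.
//   auto params = CutlassGemmConfig::CandidateConfigTypeParam(
//       CutlassGemmConfig::WEIGHT_ONLY | CutlassGemmConfig::HOPPER);
//   auto configs = get_candidate_configs(sm, max_split_k, params);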
CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
int split_k_factor = -1;
@ -121,6 +131,7 @@ struct CutlassGemmConfig
MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO;
EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO;
ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
bool is_sm90 = false;
CutlassGemmConfig() {}
@ -129,6 +140,7 @@ struct CutlassGemmConfig
, split_k_style(split_k_style)
, split_k_factor(split_k_factor)
, stages(stages)
, is_sm90(false)
{
}
@ -138,6 +150,7 @@ struct CutlassGemmConfig
, mainloop_schedule(mainloop_schedule)
, epilogue_schedule(epilogue_schedule)
, cluster_shape(cluster_shape)
, is_sm90(true)
{
}
};

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:667eb7aaa018b36c8aee79be8c3e9432a29ba33a32c1e1e423e15809a57a40b0
size 851008
oid sha256:f1990679ad8fbfbcb2b063eb7cef689a5111776cd4bef0af7f792a8ce0d46277
size 1202412

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:deb973f60a1a3623f5a014366638bddb71d76ebd7e78f9526c98d873ee4b1b5d
size 863464
oid sha256:4e6aaf013256a78494afa1f46de2109b6acb572ee1dc27b1a80cb915868c3162
size 1218938

View File

@ -1,3 +1,3 @@
326f6910f53b9272872d7630a8ce2eea libtensorrt_llm_executor_static.a
82f32c59f88aff8eee61aeb1b328ecf1 libtensorrt_llm_executor_static.pre_cxx11.a
165fe125d6bf55090d8a7dec012d08f8d0e7a54b commit
f91339ae7a9840c71f672d960bfd5446 libtensorrt_llm_executor_static.a
66f3b01a5c61c22127d263fa97fec0ec libtensorrt_llm_executor_static.pre_cxx11.a
83029c1606a00e0e4aaf5ea2de17867a6e5ddd9b commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fdd06e455ac1527ab7c1da4748b88e1e1d98d04c0cbc94ffbf1bf5968b095ed6
size 890380
oid sha256:bd96ad0a662f7a989a8f2d04581a28504a56f3f642f787d13dd2edbcbb50890c
size 1225668

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f9841068c9ece43d9d7cb95af8dffb1b1d6384399b21cb0ce4f5330846daf7d
size 843698
oid sha256:e27cf9ea138ba3ab62833e02a6a594ef209316eac0c7031e07c5f1c6cd6fb341
size 1180420

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b43a9234c39d36b580b1bdcae5676698a6f11be911ae1f8168db8e2dc89c4d04
size 9865220
oid sha256:feecaaadbdc3ab649e7b9fca4076b8b745ca4891244edfa62fe09beee4560d34
size 12198304

View File

@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
set(SRCS executorWorker.cpp)
include_directories(${PROJECT_SOURCE_DIR}/include)
set(EXECUTOR_WORKER_TARGET executorWorker)
add_executable(${EXECUTOR_WORKER_TARGET} ${SRCS})
target_link_libraries(${EXECUTOR_WORKER_TARGET}
PUBLIC ${SHARED_TARGET} nvinfer_plugin_tensorrt_llm)
target_compile_features(${EXECUTOR_WORKER_TARGET} PRIVATE cxx_std_17)

View File

@ -0,0 +1,80 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mpi.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serialization.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
namespace tle = tensorrt_llm::executor;
int main(int argc, char* argv[])
{
// Register the TRT-LLM plugins
initTrtLlmPlugins();
tensorrt_llm::mpi::initialize(tensorrt_llm::mpi::MpiThreadSupport::THREAD_MULTIPLE);
MPI_Comm parentComm;
MPI_Comm_get_parent(&parentComm);
if (parentComm == MPI_COMM_NULL)
{
TLLM_LOG_ERROR("TRT-LLM worker has no parent!");
return -1;
}
int size;
MPI_Comm_remote_size(parentComm, &size);
if (size != 1)
{
TLLM_LOG_ERROR("Parent size is %d, must be 1", size);
return -1;
}
// Since parentComm is an intercommunicator, the input root
// is the rank of the parent process in its group
// (always 0, as the parent size was checked above)
// Receive from the parent the executor configuration
int64_t bufferSize;
MPICHECK(MPI_Bcast(&bufferSize, 1, MPI_INT64_T, 0, parentComm));
std::vector<char> buffer(bufferSize);
MPICHECK(MPI_Bcast(buffer.data(), bufferSize, MPI_CHAR, 0, parentComm));
std::istringstream is(std::string(buffer.begin(), buffer.end()));
auto modelPath = tle::Serialization::deserializeString(is);
auto modelType = tle::Serialization::deserializeModelType(is);
auto executorConfig = tle::Serialization::deserializeExecutorConfig(is);
// Create the orchestrator config for workers
auto orchLeaderComm = std::make_shared<tensorrt_llm::mpi::MpiComm>(parentComm, true);
auto parallelConfig = executorConfig.getParallelConfig();
TLLM_CHECK_WITH_INFO(parallelConfig.has_value(), "Parallel config should have a value.");
TLLM_CHECK_WITH_INFO(
parallelConfig.value().getOrchestratorConfig().has_value(), "Orchestrator config should have a value.");
auto orchConfig = parallelConfig.value().getOrchestratorConfig().value();
auto newOrchConfig = tle::OrchestratorConfig(false, orchConfig.getWorkerExecutablePath(), orchLeaderComm);
parallelConfig.value().setOrchestratorConfig(newOrchConfig);
executorConfig.setParallelConfig(parallelConfig.value());
// In orchestrator mode, the spawned threads will wait for termination signal from orchestrator
auto executor = tle::Executor(modelPath, modelType, executorConfig);
TLLM_LOG_INFO("Executor worker exiting");
return 0;
}
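// Explanatory note: the orchestrator side is expected to launch this binary via MPI_Comm_spawn and
// to issue matching MPI_Bcast calls on the resulting intercommunicator for the serialized size and
// payload received above.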

View File

@ -24,18 +24,18 @@ namespace tensorrt_llm
namespace kernels
{
template <typename T, int MAX_K>
void topK_softMax_kernelLauncher(
template <typename T, int PAD_K>
void topKSoftMaxKernelLauncher(
T const* logits, T const* bias, void* workspace, BeamHypotheses& bh, cudaStream_t stream);
#define CASE_K(MAX_K) \
topK_softMax_kernelLauncher<T, MAX_K>(logits, bias, workspace, bh, stream); \
#define CASE_K(PAD_K) \
topKSoftMaxKernelLauncher<T, PAD_K>(logits, bias, workspace, bh, stream); \
break;
template <typename T>
void invokeTopkSoftMax(T const* logits, T const* bias, void* workspace, BeamHypotheses& bh, cudaStream_t stream)
{
switch (padToNextPowerOfTwo(bh.beam_width))
switch (padToNextPowerOfTwo(bh.nBeamWidth)) // PAD_K must be a compilation-time constant
{
case 1:
case 2:
@ -52,8 +52,9 @@ void invokeTopkSoftMax(T const* logits, T const* bias, void* workspace, BeamHypo
CASE_K(64)
#endif // FAST_BUILD
default:
throw std::runtime_error(fmtstr(
"%s:%d Topk kernel of beam search does not support beam_width=%d", __FILE__, __LINE__, bh.beam_width));
throw std::runtime_error(
fmtstr("%s:%d Maximum beam width supported for beam search (%d) is larger than beam_width now use (%d)",
__FILE__, __LINE__, nMaxBeamWidth, bh.nBeamWidth));
}
}
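// Dispatch sketch (illustration only): padToNextPowerOfTwo is assumed to round the runtime beam
// width up to the next power of two, so e.g. a beam width of 5 dispatches the PAD_K = 8
// instantiation and 64 stays at PAD_K = 64; anything above nMaxBeamWidth (64) falls through to the
// runtime_error above.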

View File

@ -22,64 +22,64 @@ namespace tensorrt_llm
namespace kernels
{
static constexpr int nMaxBeamWidth = 64; // max beam width supported now
static constexpr int nSmallTopKBlockSize = 256;
static constexpr int nSmallTopKMaxVocParts = 128;
static constexpr int nBlockSizeForSmallBeamWidth = 256;
static constexpr int nMaxVocabPartForStage1FastKernel = 128;
struct BeamHypotheses
{
// clang-format off
// BS: batch_size, BM: beam_width, mSL: max_seq_length
// BS: batch_size, BM: beam_width, MSL: max_seq_length
// %%: parameter name when dynamic_decoder.forward() / gather_tree() are called in [generation.py] (python workflow)
// Candidate beams: When a beam generates end_id or its sequence length reaches mSL, it becomes a candidate beam to be selected finally.
// Candidate-Beam-Array (CBA): Arrays (size: BM*2) to place the candidate beams and related information
// Candidate beams: a beam which generates end_id or whose sequence length reaches MSL
// Candidate-Beam-Array (CBA): the arrays (size: BM*2) that hold the candidate beams and related information
// Scalar values
bool is_return_normed_score{true}; // return normed_score / cum_log_probs, useless yet
int batch_size{0}; //
int beam_width{0}; //
int ite{0}; // index of local_batch, always be 0 when pp_size==1
int local_batch_size{0}; //
int max_seq_len{0}; //
int vocab_size{0}; // vocab_size_padded
bool bReturnNormedScore{false}; // return normed_score / cum_log_probs, not used yet
int nBatchSize{0}; //
int nBeamWidth{0}; //
int nIte{0}; // index of local_batch, always be 0 when pp_size==1
int nBatchSizeLocal{0}; //
int nMaxSeqLen{0}; //
int nVocabSize{0}; // vocab_size_padded
// Pointers from SamplingConfig
float const* diversity_rates{nullptr}; // [BS]
float const* length_penalties{nullptr}; // [BS]
int const* early_stoppings{nullptr}; // [BS]
float const* diversityRates{nullptr}; // [BS]
float const* lengthPenalties{nullptr}; // [BS]
int const* earlyStoppings{nullptr}; // [BS]
// Pointers from input
int const* input_lengths{nullptr}; // [BS, BM] %% context_length
int const* end_ids{nullptr}; // [BS, BM] %% self.end_ids
int const* inputLengths{nullptr}; // [BS, BM] %% context_length
int const* endIds{nullptr}; // [BS, BM] %% self.end_ids
// Pointers for output
int* final_output_ids{nullptr}; // [BS, BM, mSL] %% self.output_ids
float* log_probs{nullptr}; // [mSL, BS, BM] %% self.log_probs_tiled
int* seq_len{nullptr}; // [BS, BM] %% self.sequence_length_buffer
float* cum_log_probs{nullptr}; // [BS, BM] %% self.cum_log_probs
int* outputIds{nullptr}; // [BS, BM, MSL] %% self.output_ids
float* logProbs{nullptr}; // [MSL, BS, BM] %% self.log_probs_tiled
int* sequenceLengths{nullptr}; // [BS, BM] %% self.sequence_length_buffer
float* cumLogProbs{nullptr}; // [BS, BM] %% self.cum_log_probs
// Pointers of CBA
int* output_ids_cba{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_output_ids_cba
float* log_probs_cba{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_log_probs_cba
int* seq_len_cba{nullptr}; // [BS, BM*2] %% self.beam_hyps_seq_len_cba
float* cum_log_probs_cba{nullptr}; // [BS, BM*2] %% self.beam_hyps_cum_log_probs_cba
float* normed_scores_cba{nullptr}; // [BS, BM*2] %% self.beam_hyps_normed_scores_cba
int* num_beams{nullptr}; // [BS] %% self.beam_hyps_num_beams number of beams in CBA
float* min_normed_scores{nullptr}; // [BS] %% self.beam_hyps_min_normed_scores worst score in CBA
int* outputIdsCBA{nullptr}; // [BS, BM*2, MSL] %% self.beam_hyps_output_ids_cba
float* logProbsCBA{nullptr}; // [BS, BM*2, MSL] %% self.beam_hyps_log_probs_cba
int* sequenceLengthsCBA{nullptr}; // [BS, BM*2] %% self.beam_hyps_seq_len_cba
float* cumLogProbsCBA{nullptr}; // [BS, BM*2] %% self.beam_hyps_cum_log_probs_cba
float* normedScoresCBA{nullptr}; // [BS, BM*2] %% self.beam_hyps_normed_scores_cba
int* numBeamsCBA{nullptr}; // [BS] %% self.beam_hyps_num_beams number of beams in CBA
float* minNormedScoresCBA{nullptr}; // [BS] %% self.beam_hyps_min_normed_scores worst score in CBA
// Pointers related to the beam search process; they are initialized in these two functions:
// [gptDecoder.cpp] GptDecoder<T>::forward or [dynamicDecodeOp.cpp] FtDynamicDecode<T>::forward
bool* is_done{nullptr}; // [BS] %% self.beam_hyps_is_done whether a whole batch is finished
FinishedState* finished; // [BS*BM] %% self.finished whether and how a beam is finished
bool* batchDones{nullptr}; // [BS] %% self.beam_hyps_is_done whether a whole batch is finished
FinishedState* finished{nullptr}; // [BS*BM] %% self.finished whether and how a beam is finished
// Pointers for backtracking the beams; they are relocated in [dynamicDecodeLayer.cpp] DynamicDecodeLayer<T>::prepareIdsPtrs
int** output_ids_ptr{nullptr}; // [BS][BM, mSL] %% self.output_ids
int** parent_ids_ptr{nullptr}; // [BS][BM, mSL] %% self.parent_ids
int** outputIdsPtr{nullptr}; // [BS][BM, MSL] %% self.output_ids
int** parentIdsPtr{nullptr}; // [BS][BM, MSL] %% self.parent_ids
// Pointers for gather_tree(): the unfinished beams are read from these and written to the CBA for the final selection
int const* output_ids_src{nullptr}; // [BS, BM, mSL] %% self.output_ids
int const* parent_ids_src{nullptr}; // [BS, BM, mSL] %% self.parent_ids
int const* outputIdsUnfinish{nullptr}; // [BS, BM, MSL] %% self.output_ids
int const* parentIdsUnfinish{nullptr}; // [BS, BM, MSL] %% self.parent_ids
// clang-format on
};
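// Shape example for the notation above (illustration): with BS = 2, BM = 4 and MSL = 1024,
// outputIds is laid out as [2, 4, 1024], logProbs as [1024, 2, 4], and the CBA arrays hold up to
// BM*2 = 8 candidate beams per batch entry.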

View File

@ -730,11 +730,8 @@ bool FusedMHARunnerV2::isValid(int s) const
// static function to check if fmha is supported when building plugins
bool MHARunner::fmha_supported(int const headSize, int const sm)
{
return (headSize == 32 || headSize == 40 || headSize == 64 || headSize == 80 || headSize == 96 || headSize == 104
|| headSize == 128 || headSize == 160 || headSize == 192 || headSize == 256);
return false;
}
} // namespace kernels

View File

@ -1,5 +1,5 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
@ -15,9 +15,6 @@
# the License.
#
file(GLOB_RECURSE SRC_CPP *.cpp)
file(GLOB_RECURSE SRC_CU *.cu)
# The Python executable will only be defined if building with Torch support. If
# not, we need to find it here.
if(NOT Python3_EXECUTABLE)
@ -43,10 +40,12 @@ set_directory_properties(
PROPERTIES CMAKE_CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/python/generate_kernels.py)
set(INSTANTIATION_GENERATION_DIR
${CMAKE_CURRENT_BINARY_DIR}/cutlass_instantiations)
execute_process(
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/python/
COMMAND ${Python3_EXECUTABLE} generate_kernels.py -o
${CMAKE_CURRENT_BINARY_DIR}
${INSTANTIATION_GENERATION_DIR}
RESULT_VARIABLE _KERNEL_GEN_SUCCESS)
if(NOT _KERNEL_GEN_SUCCESS MATCHES 0)
@ -56,41 +55,73 @@ if(NOT _KERNEL_GEN_SUCCESS MATCHES 0)
)
endif()
file(GLOB_RECURSE CU_INSTANTIATIONS ${CMAKE_CURRENT_BINARY_DIR}/*.cu)
# Get the sources for Mixed Input GEMM launchers
file(GLOB_RECURSE MIXED_CU_INSTANTIATIONS
${INSTANTIATION_GENERATION_DIR}/gemm/*.cu)
file(GLOB_RECURSE MIXED_SRC_CPP fpA_intB_gemm/*.cpp)
file(GLOB_RECURSE MIXED_SRC_CU fpA_intB_gemm/*.cu)
add_library(cutlass2_src STATIC ${SRC_CPP} ${SRC_CU})
set_property(TARGET cutlass2_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cutlass2_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
# Get the sources for MOE Grouped GEMM launchers
file(GLOB_RECURSE GROUPED_CU_INSTANTIATIONS
${INSTANTIATION_GENERATION_DIR}/gemm_grouped/*.cu)
file(GLOB_RECURSE GROUPED_SRC_CPP moe_gemm/*.cpp)
file(GLOB_RECURSE GROUPED_SRC_CU moe_gemm/*.cu)
add_library(cutlass3_src STATIC ${CU_INSTANTIATIONS})
set_property(TARGET cutlass3_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cutlass3_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
# Get the sources for all the remaining sources
file(GLOB_RECURSE SRC_CPP *.cpp)
file(GLOB_RECURSE SRC_CU *.cu)
set(ALL_SRCS ${SRC_CPP};${SRC_CU})
list(FILTER ALL_SRCS EXCLUDE REGEX "fpA_intB_gemm/.*")
list(FILTER ALL_SRCS EXCLUDE REGEX "moe_gemm/.*")
# Note - we deliberately do not include 90a PTX (even when 9.0+PTX is
# specified). This is because sm_90a has arch conditional instructions that are
# not forward compatible. As a result, it does not make sense to embed PTX into
# the binary anyway.
if("9.0" IN_LIST TORCH_CUDA_ARCH_LIST
OR "9.0+PTX" IN_LIST TORCH_CUDA_ARCH_LIST
OR "90-real" IN_LIST CMAKE_CUDA_ARCHITECTURES_NATIVE)
message(STATUS "MANUALLY APPENDING FLAG TO COMPILE FOR SM_90a.")
target_compile_options(
cutlass3_src
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_90a,code=sm_90a>)
message(
STATUS
"Mixed srcs ${MIXED_SRC_CPP} ${MIXED_SRC_CU} ${MIXED_CU_INSTANTIATIONS}")
message(
STATUS
"Group srcs ${GROUPED_SRC_CU} ${GROUPED_SRC_CPP} ${GROUPED_CU_INSTANTIATIONS}"
)
message(STATUS "All srcs ${ALL_SRCS}")
# Hopper kernels require cuda lib for TMA APIs
target_link_libraries(cutlass3_src PRIVATE CUDA::cuda_driver)
add_library(cutlass_src STATIC ${ALL_SRCS})
set_property(TARGET cutlass_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cutlass_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
# No kernels should be parsed, unless hopper is specified. This is a build
# time improvement
target_compile_definitions(cutlass3_src
PRIVATE COMPILE_HOPPER_MIXED_INPUT_GEMMS)
endif()
add_library(fpA_intB_gemm_src STATIC ${MIXED_SRC_CPP} ${MIXED_SRC_CU}
${MIXED_CU_INSTANTIATIONS})
add_library(moe_gemm_src STATIC ${GROUPED_SRC_CU} ${GROUPED_SRC_CPP}
${GROUPED_CU_INSTANTIATIONS})
foreach(target_name fpA_intB_gemm_src;moe_gemm_src)
set_property(TARGET ${target_name} PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET ${target_name} PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
# Suppress GCC note: the ABI for passing parameters with 64-byte alignment has
# changed in GCC 4.6 This note appears for kernels using TMA and clutters the
# compilation output.
if(NOT WIN32)
target_compile_options(
cutlass3_src PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-psabi>)
endif()
# Note - we deliberately do not include 90a PTX (even when 9.0+PTX is
# specified). This is because sm_90a has arch conditional instructions that
# are not forward compatible. As a result, it does not make sense to embed PTX
# into the binary anyway.
if("9.0" IN_LIST TORCH_CUDA_ARCH_LIST
OR "9.0+PTX" IN_LIST TORCH_CUDA_ARCH_LIST
OR "90-real" IN_LIST CMAKE_CUDA_ARCHITECTURES_NATIVE)
message(STATUS "MANUALLY APPENDING FLAG TO COMPILE FOR SM_90a.")
target_compile_options(
${target_name}
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_90a,code=sm_90a>)
# Hopper kernels require cuda lib for TMA APIs
target_link_libraries(${target_name} PRIVATE CUDA::cuda_driver)
# No kernels should be parsed unless Hopper is specified. This is a build-time
# improvement.
target_compile_definitions(${target_name} PRIVATE COMPILE_HOPPER_TMA_GEMMS)
endif()
# Suppress GCC note: the ABI for passing parameters with 64-byte alignment has
# changed in GCC 4.6. This note appears for kernels using TMA and clutters the
# compilation output.
if(NOT WIN32)
target_compile_options(
${target_name} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-psabi>)
endif()
endforeach()

View File

@ -24,6 +24,7 @@
#include "cutlass/gemm/gemm.h"
#include "cutlass/numeric_types.h"
#include "tensorrt_llm/common/assert.h"
#ifndef _WIN32
#pragma GCC diagnostic pop
@ -65,7 +66,7 @@ TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: return TileShape{128, 128};
case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: return TileShape{128, 256};
case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: return TileShape{256, 128};
default: throw std::runtime_error("[TensorRT-LLm Error][get_grid_shape_for_config] Invalid config");
default: TLLM_THROW("[get_grid_shape_for_config] Invalid config");
}
}
@ -110,7 +111,7 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k,
}
std::vector<CutlassTileConfig> get_candidate_tiles(
int const sm, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only)
int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
{
enum class CutlassGemmType : char
{
@ -121,15 +122,15 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
};
CutlassGemmType gemm_type = CutlassGemmType::Default;
if (simt_configs_only)
if (config_type_param & CutlassGemmConfig::SIMT_ONLY)
{
gemm_type = CutlassGemmType::Simt;
}
else if (is_weight_only)
else if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY)
{
gemm_type = CutlassGemmType::WeightOnly;
}
else if (int8_configs_only)
else if (config_type_param & CutlassGemmConfig::INT8_ONLY)
{
gemm_type = CutlassGemmType::Int8;
}
@ -170,39 +171,21 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
}
std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
int const sm, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only)
int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config)
{
enum class CutlassGemmType : char
if (config & CutlassGemmConfig::GROUPED_GEMM)
{
Default,
WeightOnly,
Simt,
Int8
};
CutlassGemmType gemm_type = CutlassGemmType::Default;
if (simt_configs_only)
{
gemm_type = CutlassGemmType::Simt;
return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B,
CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
CutlassTileConfigSM90::CtaShape128x256x128B};
}
else if (is_weight_only)
else
{
gemm_type = CutlassGemmType::WeightOnly;
}
else if (int8_configs_only)
{
gemm_type = CutlassGemmType::Int8;
}
switch (gemm_type)
{
case CutlassGemmType::WeightOnly:
return {CutlassTileConfigSM90::CtaShape64x16x128B, CutlassTileConfigSM90::CtaShape64x32x128B,
CutlassTileConfigSM90::CtaShape64x64x128B, CutlassTileConfigSM90::CtaShape64x128x128B,
CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x16x128B,
CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B};
default: throw std::runtime_error("get_candidate_tiles_sm90 only supports WeightOnly now.");
}
}
@ -226,13 +209,12 @@ bool supports_mcast_along_n(const CutlassTileConfigSM90 tile)
return valid_tiles.count(tile) == 1;
}
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, bool const is_weight_only, bool const simt_configs_only,
bool const int8_configs_only, int const max_split_k, bool const enable_hopper_gmma)
std::vector<CutlassGemmConfig> get_candidate_configs(
int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
{
if (sm == 90 && enable_hopper_gmma)
if (sm == 90 && (config_type_param & CutlassGemmConfig::HOPPER))
{
std::vector<CutlassTileConfigSM90> tiles
= get_candidate_tiles_sm90(sm, is_weight_only, simt_configs_only, int8_configs_only);
std::vector<CutlassTileConfigSM90> tiles = get_candidate_tiles_sm90(sm, config_type_param);
std::vector<CutlassGemmConfig> candidate_configs;
for (auto const& tile_config : tiles)
@ -266,10 +248,10 @@ std::vector<CutlassGemmConfig> get_candidate_configs(int sm, bool const is_weigh
}
return candidate_configs;
}
std::vector<CutlassTileConfig> tiles
= get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only);
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
std::vector<CutlassGemmConfig> candidate_configs;
bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY;
int const min_stages = int8_configs_only ? 3 : 2;
int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
for (auto const& tile_config : tiles)
@ -299,8 +281,8 @@ CutlassGemmConfig estimate_best_config_from_occupancies(std::vector<CutlassGemmC
if (occupancies.size() != candidate_configs.size())
{
throw std::runtime_error(
"[TensorRT-LLm Error][estimate_best_config_from_occupancies] occpancies and "
TLLM_THROW(
"[estimate_best_config_from_occupancies] occpancies and "
"candidate configs vectors must have equal length.");
}
@ -374,7 +356,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(std::vector<CutlassGemmC
if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic)
{
throw std::runtime_error("[TensorRT-LLm Error] Heurisitc failed to find a valid config.");
TLLM_THROW("Heurisitc failed to find a valid config.");
}
return best_config;

View File

@ -26,9 +26,8 @@ namespace kernels
namespace cutlass_kernels
{
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(int sm,
bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only = false,
int const max_split_k = 1, bool const enable_hopper_gmma = false);
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(
int sm, int const max_split_k, tensorrt_llm::cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam const);
tensorrt_llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies(
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> const& candidate_configs,

View File

@ -30,16 +30,6 @@ namespace kernels
namespace cutlass_kernels
{
int get_bits_in_quant_type(QuantType quant_type)
{
switch (quant_type)
{
case QuantType::INT8_WEIGHT_ONLY: return 8;
case QuantType::PACKED_INT4_WEIGHT_ONLY: return 4;
default: TLLM_CHECK_WITH_INFO(false, "Invalid quant_type"); return -1;
}
}
struct LayoutDetails
{
enum class Layout
@ -96,11 +86,11 @@ struct getLayoutDetails<cutlass::layout::ColumnMajorTileInterleave<RowsPerTile,
}
};
template <typename cutlassArch, typename TypeB>
template <typename cutlassArch, typename TypeA, typename TypeB>
LayoutDetails getLayoutDetailsForArchAndQuantType()
{
using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB<TypeB, cutlassArch>;
using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB<TypeA, TypeB, cutlassArch>;
using LayoutB = typename CompileTraits::Layout;
using MmaOperator = typename CompileTraits::Operator;
LayoutDetails details = getLayoutDetails<LayoutB>()();
@ -111,18 +101,20 @@ LayoutDetails getLayoutDetailsForArchAndQuantType()
template <typename cutlassArch>
LayoutDetails getLayoutDetailsForArch(QuantType quant_type)
{
int const bits_per_weight_element = get_weight_quant_bits(quant_type);
LayoutDetails details;
if (quant_type == QuantType::INT8_WEIGHT_ONLY)
switch (quant_type)
{
details = getLayoutDetailsForArchAndQuantType<cutlassArch, uint8_t>();
}
else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY)
{
details = getLayoutDetailsForArchAndQuantType<cutlassArch, cutlass::uint4b_t>();
}
else
{
TLLM_CHECK_WITH_INFO(false, "Unsupported quantization type");
case QuantType::W8_A16:
details = getLayoutDetailsForArchAndQuantType<cutlassArch, cutlass::half_t, uint8_t>();
break;
case QuantType::W4_A16:
details = getLayoutDetailsForArchAndQuantType<cutlassArch, cutlass::half_t, cutlass::uint4b_t>();
break;
case QuantType::W4_AFP8:
details = getLayoutDetailsForArchAndQuantType<cutlassArch, cutlass::float_e4m3_t, cutlass::uint4b_t>();
break;
default: TLLM_THROW("Unsupported quantization type");
}
return details;
}
@ -137,7 +129,7 @@ LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
{
return getLayoutDetailsForArch<cutlass::arch::Sm75>(quant_type);
}
else if (arch >= 80 && arch <= 89)
else if (arch >= 80 && arch < 90)
{
return getLayoutDetailsForArch<cutlass::arch::Sm80>(quant_type);
}
@ -152,25 +144,54 @@ LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
}
}
// Permutes the rows of B for Turing and Ampere. Throws an error for other architectures.
// Permutes the rows of B in a way that is compatible with Turing+ architectures.
//
// Throws an error for other architectures.
// The data is permuted such that:
// For int8, each group of 16 rows is permuted using the map below:
// For W8_A16, each group of 16 rows is permuted using the map below:
// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15
// For int4, each group of 32 rows is permuted using the map below:
// For W4_A16, each group of 32 rows is permuted using the map below:
// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31
// For W4_AFP8, see the map in the code. The idea is similar to above.
// The goal of this permutation is to ensure data ends up in the correct threads after
// we execute LDSM. It counteracts the effect of the data being of different widths.
// For more information about the expected layouts, see the MMA section in the PTX docs.
std::vector<int> get_permutation_map(QuantType quant_type)
{
if (quant_type == QuantType::W8_A16)
{
return {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
}
else if (quant_type == QuantType::W4_A16)
{
return {0, 1, 8, 9, 16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27, 4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15,
22, 23, 30, 31};
}
else if (quant_type == QuantType::W4_AFP8)
{
return {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
28, 29, 30, 31};
}
else
{
TLLM_THROW("Invalid quantization type for LDSM permutation");
}
}
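// Illustrative sketch only (not part of this file): how the map above is consumed by
// permute_B_rows_for_mixed_gemm below. Within each group of rows, destination row r is read from
// source row map[r]:
//
//   std::vector<int> const map = get_permutation_map(QuantType::W8_A16);
//   for (size_t r = 0; r < map.size(); ++r)
//   {
//       copy_row(/*dst=*/base_row + r, /*src=*/base_row + map[r]); // hypothetical row copy helper
//   }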
void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor,
std::vector<size_t> const& shape, QuantType quant_type, const int64_t arch_version)
std::vector<size_t> const& shape, QuantType quant_type, int64_t const arch_version)
{
// We only want to run this step for weight only quant.
TLLM_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY || quant_type == QuantType::INT8_WEIGHT_ONLY);
std::vector<int> row_permutation = get_permutation_map(quant_type);
TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
int const BITS_PER_ELT = get_bits_in_quant_type(quant_type);
int const BITS_PER_ELT = get_weight_quant_bits(quant_type);
int const K = 16 / BITS_PER_ELT;
int const ELTS_PER_BYTE = 8 / BITS_PER_ELT;
int const ELTS_PER_REG = 32 / BITS_PER_ELT;
@ -194,7 +215,8 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t con
fmtstr("Invalid shape for quantized tensor. On turing/Ampere, the number of cols must be a multiple of %d.",
MMA_SHAPE_N));
// The code is written as below so it works for both int8 and packed int4.
TLLM_CHECK_WITH_INFO(size_t(B_ROWS_PER_MMA) == row_permutation.size(), "Unexpected number of LDSM rows permuted.");
for (int expert = 0; expert < num_experts; ++expert)
{
const int64_t matrix_offset = expert * int64_t(num_rows) * int64_t(num_vec_cols);
@ -206,8 +228,7 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t con
for (int write_col = 0; write_col < num_vec_cols; ++write_col)
{
int const write_row = base_row + tile_row;
int const tile_read_row
= 8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
int const tile_read_row = row_permutation[tile_row];
int const read_row = base_row + tile_read_row;
int const read_col = write_col;
@ -229,7 +250,7 @@ template <QuantType quant_type>
void subbyte_transpose_impl(
int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, std::vector<size_t> const& shape)
{
int const bits_per_elt = get_bits_in_quant_type(quant_type);
constexpr int bits_per_elt = get_weight_quant_bits(quant_type);
TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
@ -243,8 +264,7 @@ void subbyte_transpose_impl(
uint8_t const* input_byte_ptr = reinterpret_cast<uint8_t const*>(quantized_tensor);
uint8_t* output_byte_ptr = reinterpret_cast<uint8_t*>(transposed_quantized_tensor);
static_assert(quant_type == QuantType::INT8_WEIGHT_ONLY || quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY, "");
static constexpr int ELTS_PER_BYTE = quant_type == QuantType::INT8_WEIGHT_ONLY ? 1 : 2;
static constexpr int ELTS_PER_BYTE = 8 / bits_per_elt;
static constexpr int M_TILE_L1 = 64;
static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE;
@ -294,7 +314,7 @@ void subbyte_transpose_impl(
}
}
if (quant_type == QuantType::INT8_WEIGHT_ONLY)
if constexpr (bits_per_elt == 8)
{
for (int ii = 0; ii < M_TILE_L1; ++ii)
{
@ -304,7 +324,7 @@ void subbyte_transpose_impl(
}
}
}
else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY)
else if constexpr (bits_per_elt == 4)
{
for (int ii = 0; ii < M_TILE_L1; ++ii)
@ -368,14 +388,17 @@ void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quanti
std::vector<size_t> const& shape, QuantType quant_type)
{
if (quant_type == QuantType::INT8_WEIGHT_ONLY)
if (quant_type == QuantType::W8_A16)
{
subbyte_transpose_impl<QuantType::INT8_WEIGHT_ONLY>(transposed_quantized_tensor, quantized_tensor, shape);
subbyte_transpose_impl<QuantType::W8_A16>(transposed_quantized_tensor, quantized_tensor, shape);
}
else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY)
else if (quant_type == QuantType::W4_A16)
{
subbyte_transpose_impl<QuantType::PACKED_INT4_WEIGHT_ONLY>(
transposed_quantized_tensor, quantized_tensor, shape);
subbyte_transpose_impl<QuantType::W4_A16>(transposed_quantized_tensor, quantized_tensor, shape);
}
else if (quant_type == QuantType::W4_AFP8)
{
subbyte_transpose_impl<QuantType::W4_AFP8>(transposed_quantized_tensor, quantized_tensor, shape);
}
else
{
@ -464,12 +487,16 @@ void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const siz
void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type)
{
if (quant_type == QuantType::INT8_WEIGHT_ONLY)
if (quant_type == QuantType::W8_A16)
{
add_bias_and_interleave_int8s_inplace(tensor, num_elts);
}
else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY)
else if (quant_type == QuantType::W4_A16 || quant_type == QuantType::W4_AFP8)
{
// W4_AFP8 uses the same preprocessor as W4_A16 because the FP8 data must
// be converted to FP16 before the scales can be applied using CUDA cores.
// As a result, we still want to permute the data so that it is well aligned
// for conversion to FP16.
add_bias_and_interleave_int4s_inplace(tensor, num_elts);
}
else
@ -482,15 +509,12 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, int8_t
std::vector<size_t> const& shape, QuantType quant_type, LayoutDetails details)
{
// We only want to run this step for weight only quant.
TLLM_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY || quant_type == QuantType::INT8_WEIGHT_ONLY);
TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
int const BITS_PER_ELT = get_bits_in_quant_type(quant_type);
int const BITS_PER_ELT = get_weight_quant_bits(quant_type);
int const elts_in_int32 = 32 / BITS_PER_ELT;
int const rows_per_tile = details.rows_per_column_tile;
@ -551,7 +575,7 @@ void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, in
num_elts *= dim;
}
const size_t num_bytes = num_elts * get_bits_in_quant_type(quant_type) / 8;
const size_t num_bytes = num_elts * get_weight_quant_bits(quant_type) / 8;
std::vector<int8_t> src_buf(num_bytes);
std::vector<int8_t> dst_buf(num_bytes);
@ -633,9 +657,11 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
int const bits_in_type = get_bits_in_quant_type(quant_type);
int const bits_in_type = get_weight_quant_bits(quant_type);
int const bytes_per_out_col = num_cols * bits_in_type / 8;
int const bits_per_weight_element = get_weight_quant_bits(quant_type);
std::vector<int8_t> weight_buf;
if (unprocessed_quantized_weight == nullptr)
{
@ -685,15 +711,15 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
for (int jj = 0; jj < bytes_per_out_col; ++jj)
{
if (quant_type == QuantType::INT8_WEIGHT_ONLY)
if (bits_per_weight_element == 8)
{
float const col_scale = per_col_max[jj];
float const weight_elt = float(current_weight_row[jj]);
float const scaled_weight = round(weight_elt / col_scale);
float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f;
const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight)));
current_quantized_weight_row[jj] = clipped_weight;
}
else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY)
else if (bits_per_weight_element == 4)
{
// We will pack two int4 elements per iteration of the inner loop.
@ -705,7 +731,7 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
{
float const col_scale = per_col_max[input_idx];
float const weight_elt = float(current_weight_row[input_idx]);
float const scaled_weight = round(weight_elt / col_scale);
float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f;
int int_weight = int(scaled_weight);
const int8_t clipped_weight = std::max(-8, std::min(7, int_weight));

View File

@ -31,10 +31,21 @@ namespace cutlass_kernels
enum class QuantType
{
INT8_WEIGHT_ONLY,
PACKED_INT4_WEIGHT_ONLY
W8_A16,
W4_A16,
W4_AFP8
};
int get_bits_in_quant_type(QuantType quant_type);
constexpr int get_weight_quant_bits(QuantType quant_type)
{
switch (quant_type)
{
case QuantType::W8_A16: return 8;
case QuantType::W4_A16: return 4;
case QuantType::W4_AFP8: return 4;
default: TLLM_CHECK_WITH_INFO(false, "Invalid quant_type"); return -1;
}
}
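// Usage sketch: get_weight_quant_bits is constexpr, so bit widths can be checked at compile time,
// e.g. static_assert(get_weight_quant_bits(QuantType::W4_AFP8) == 4, "");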
// Shapes here can be 2 or 3D. 2-D shapes are [num_rows, num_cols]
// 3-D shapes are [num_experts, num_rows, num_cols]

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -24,7 +24,7 @@ namespace cutlass_kernels
{
#ifdef ENABLE_FP8
template class CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, /*Activation Type*/
cutlass::int4b_t, /*Weight Type*/
cutlass::uint4b_t, /*Weight Type*/
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, /*Scale and Zero Type*/
half, /*Bias type Type*/
half /*Output type Type*/

View File

@ -24,7 +24,7 @@ namespace cutlass_kernels
{
#ifdef ENABLE_FP8
template class CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, /*Activation Type*/
cutlass::int4b_t, /*Weight Type*/
cutlass::uint4b_t, /*Weight Type*/
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, half, /*Scale and Zero Type*/
half, /*Bias type Type*/
half /*Output type Type*/

View File

@ -24,7 +24,7 @@ namespace cutlass_kernels
{
#ifdef ENABLE_FP8
template class CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, /*Activation Type*/
cutlass::int4b_t, /*Weight Type*/
cutlass::uint4b_t, /*Weight Type*/
cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, half, /*Scale and Zero Type*/
half, /*Bias type Type*/
half /*Output type Type*/

View File

@ -37,6 +37,7 @@
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h"
@ -50,65 +51,59 @@ namespace kernels
namespace cutlass_kernels
{
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
typename ThreadblockShape, typename WarpShape, int Stages>
void generic_mixed_gemm_kernelLauncher(T const* A, WeightType const* B, T const* weight_scales,
T const* weight_zero_points, T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size,
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
int* occupancy = nullptr)
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename ThreadblockShape,
typename WarpShape, int Stages>
void generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
#ifdef ENABLE_BF16
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|| cutlass::platform::is_same<T, float>::value,
static_assert(
#ifdef ENABLE_FP8
cutlass::platform::is_same<ActivationType, __nv_fp8_e4m3>::value ||
#endif
cutlass::platform::is_same<ActivationType, __nv_bfloat16>::value
|| cutlass::platform::is_same<ActivationType, half>::value
|| cutlass::platform::is_same<ActivationType, float>::value,
"Specialized for bfloat16, half, float");
#else
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
static_assert(cutlass::platform::is_same<ActivationType, half>::value
|| cutlass::platform::is_same<ActivationType, float>::value,
"Specialized for half, float");
#endif
static_assert(cutlass::platform::is_same<T, WeightType>::value
static_assert(cutlass::platform::is_same<ActivationType, WeightType>::value
|| cutlass::platform::is_same<WeightType, uint8_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
"");
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
using ElementType_ =
typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
#ifdef ENABLE_BF16
using ElementType =
typename cutlass::platform::conditional<cutlass::platform::is_same<ElementType_, __nv_bfloat16>::value,
cutlass::bfloat16_t, ElementType_>::type;
#else
using ElementType = ElementType_;
#endif
using CutlassWeightType_ =
typename cutlass::platform::conditional<cutlass::platform::is_same<WeightType, half>::value, cutlass::half_t,
WeightType>::type;
#ifdef ENABLE_BF16
using CutlassWeightType =
typename cutlass::platform::conditional<cutlass::platform::is_same<CutlassWeightType_, __nv_bfloat16>::value,
cutlass::bfloat16_t, CutlassWeightType_>::type;
#else
using CutlassWeightType = CutlassWeightType_;
#endif
using CutlassActivationType = typename TllmToCutlassTypeAdapter<ActivationType>::type;
using CutlassWeightType = typename TllmToCutlassTypeAdapter<WeightType>::type;
using CutlassScaleZeroType = typename TllmToCutlassTypeAdapter<ScaleZeroType>::type;
using CutlassBiasType = typename TllmToCutlassTypeAdapter<BiasType>::type;
using CutlassOutputType = typename TllmToCutlassTypeAdapter<OutputType>::type;
// We need separate config for each architecture since we will target different tensorcore instructions. For float,
// we do not target TCs.
using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
using MixedGemmArchTraits
= cutlass::gemm::kernel::MixedGemmArchTraits<CutlassActivationType, CutlassWeightType, arch>;
using ElementAccumulator = typename MixedGemmArchTraits::AccType;
using EpilogueOp = typename tkc::Epilogue<ElementType, MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator,
EpilogueTag>::Op;
constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<CutlassOutputType>::value;
using EpilogueOp =
typename tkc::Epilogue<CutlassOutputType, ElementsPerAccessC, ElementAccumulator, EpilogueTag>::Op;
using Operator = typename MixedGemmArchTraits::Operator;
using TaggedOperator = typename cutlass::arch::TagOperator<Operator, QuantOp>::TaggedOperator;
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<ElementType, cutlass::layout::RowMajor,
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<CutlassActivationType, cutlass::layout::RowMajor,
MixedGemmArchTraits::ElementsPerAccessA, CutlassWeightType, typename MixedGemmArchTraits::LayoutB,
MixedGemmArchTraits::ElementsPerAccessB, ElementType, cutlass::layout::RowMajor, ElementAccumulator,
MixedGemmArchTraits::ElementsPerAccessB, CutlassOutputType, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, arch, ThreadblockShape, WarpShape,
typename MixedGemmArchTraits::InstructionShape, EpilogueOp,
typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, Stages, true,
@ -138,6 +133,13 @@ void generic_mixed_gemm_kernelLauncher(T const* A, WeightType const* B, T const*
if constexpr (cutlass::isFinegrained(QuantOp))
{
if constexpr (cutlass::platform::is_same<CutlassActivationType, float_e4m3_t>::value)
{
if (group_size != 128)
{
throw std::runtime_error("Only group size 128 supported for fine grained W4A(fp)8 kernels.");
}
}
if (group_size != 64 && group_size != 128)
{
throw std::runtime_error("Only group size 64 and 128 supported for fine grained kernels.");
@ -173,12 +175,14 @@ void generic_mixed_gemm_kernelLauncher(T const* A, WeightType const* B, T const*
int const ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0;
ElementAccumulator output_op_beta = (biases == nullptr) ? ElementAccumulator(0.f) : ElementAccumulator(1.f);
typename Gemm::Arguments args({m, n, k}, group_size, {reinterpret_cast<ElementType*>(const_cast<T*>(A)), k},
typename Gemm::Arguments args({m, n, k}, group_size,
{reinterpret_cast<CutlassActivationType*>(const_cast<ActivationType*>(A)), k},
{reinterpret_cast<CutlassWeightType*>(const_cast<WeightType*>(B)), ldb},
{reinterpret_cast<ElementType*>(const_cast<T*>(weight_scales)), ld_scale_zero},
{reinterpret_cast<ElementType*>(const_cast<T*>(weight_zero_points)), ld_scale_zero},
{reinterpret_cast<ElementType*>(const_cast<T*>(biases)), 0}, {reinterpret_cast<ElementType*>(C), n},
gemm_config.split_k_factor, {ElementAccumulator(alpha), output_op_beta});
{reinterpret_cast<CutlassScaleZeroType*>(const_cast<ScaleZeroType*>(weight_scales)), ld_scale_zero},
{reinterpret_cast<CutlassScaleZeroType*>(const_cast<ScaleZeroType*>(weight_zero_points)), ld_scale_zero},
{reinterpret_cast<CutlassBiasType*>(const_cast<BiasType*>(biases)), 0},
{reinterpret_cast<CutlassOutputType*>(C), n}, gemm_config.split_k_factor,
{ElementAccumulator(alpha), output_op_beta});
// This assertion is enabled because for the column interleaved layout, K MUST be a multiple of
// threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the
@ -227,13 +231,14 @@ void generic_mixed_gemm_kernelLauncher(T const* A, WeightType const* B, T const*
// This filters out invalid template combinations that we DON'T want instantiated in CUTLASS. For example,
// instantiating SM=75, Stages=3 is invalid so we would need to filter that out. Fine grained
// quantization is only supported on Ampere+ GPUs.
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
typename ThreadblockShape, typename WarpShape, int Stages>
void filter_and_run_mixed_gemm(T const* A, WeightType const* B, T const* weight_scales, T const* weight_zero_points,
T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size,
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
int* occupancy = nullptr)
// quantization is only supported on Ampere+ GPUs. FP8 GEMM is only supported on Ada+ GPUs.
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename ThreadblockShape,
typename WarpShape, int Stages>
void filter_and_run_mixed_gemm(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
@ -251,39 +256,55 @@ void filter_and_run_mixed_gemm(T const* A, WeightType const* B, T const* weight_
+ std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages);
throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg);
}
else if constexpr (Stages == 2 && arch::kMinComputeCapability >= 89)
{
// Two-stage kernels are not instantiated for Ada and newer; only multistage kernels are built there
std::string err_msg = "Cutlass fpA_intB gemm not supported for arch "
+ std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages);
throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg);
}
else if constexpr (cutlass::platform::is_same<ActivationType, __nv_fp8_e4m3>::value
&& arch::kMinComputeCapability < 89)
{
// FP8 activation type only supported on Ada+ GPUs
std::string err_msg = "Cutlass fpA_intB gemm not supported for arch "
+ std::to_string(arch::kMinComputeCapability) + " with activation type set to FP8";
throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg);
}
else
{
generic_mixed_gemm_kernelLauncher<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape,
Stages>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config,
workspace, workspace_bytes, stream, occupancy);
generic_mixed_gemm_kernelLauncher<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch,
QuantOp, EpilogueTag, ThreadblockShape, WarpShape, Stages>(A, B, weight_scales, weight_zero_points, biases,
alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
}
}
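filter_and_run_mixed_gemm combines compile-time filtering (if constexpr on the stage count, architecture, and activation type) with a runtime exception, so unsupported combinations never instantiate a kernel yet still report a clear error if dispatched. A minimal sketch of that pattern with hypothetical limits, independent of the TensorRT-LLM types:

#include <stdexcept>
#include <string>

// Sketch of compile-time filtering with a runtime error path (hypothetical example,
// not the TensorRT-LLM code): invalid <MinArch, Stages> combinations never
// instantiate the heavy kernel body, but still compile into a branch that throws.
template <int MinArch, int Stages>
void filter_and_run()
{
    if constexpr (Stages > 2 && MinArch < 80)
    {
        // In this sketch, multistage pipelines require Ampere or newer.
        throw std::runtime_error(
            "Not instantiated for arch " + std::to_string(MinArch)
            + " with stages set to " + std::to_string(Stages));
    }
    else
    {
        // run_kernel<MinArch, Stages>() would go here; omitted in this sketch.
    }
}

int main()
{
    filter_and_run<80, 3>(); // valid combination: falls through to the kernel branch
    try { filter_and_run<75, 3>(); } catch (std::exception const&) { /* rejected */ }
    return 0;
}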
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
typename ThreadblockShape, typename WarpShape>
void dispatch_gemm_config(T const* A, WeightType const* B, T const* weight_scales, T const* weight_zero_points,
T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size,
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
int* occupancy = nullptr)
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename ThreadblockShape,
typename WarpShape>
void dispatch_gemm_config(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
switch (gemm_config.stages)
{
case 2:
filter_and_run_mixed_gemm<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B,
weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace,
workspace_bytes, stream, occupancy);
filter_and_run_mixed_gemm<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m,
n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case 3:
filter_and_run_mixed_gemm<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape, 3>(A, B,
weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace,
workspace_bytes, stream, occupancy);
filter_and_run_mixed_gemm<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
EpilogueTag, ThreadblockShape, WarpShape, 3>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m,
n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case 4:
filter_and_run_mixed_gemm<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape, 4>(A, B,
weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace,
workspace_bytes, stream, occupancy);
filter_and_run_mixed_gemm<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
EpilogueTag, ThreadblockShape, WarpShape, 4>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m,
n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
default:
std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages);
@ -315,55 +336,56 @@ void dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, Scal
constexpr bool all_types_are_the_same = std::is_same_v<ActivationType, ScaleZeroType>
&& std::is_same_v<ActivationType, BiasType> && std::is_same_v<ActivationType, OutputType>;
constexpr bool is_valid_pre_hopper = all_types_are_the_same && !any_is_fp8;
constexpr bool is_valid_pre_hopper = (all_types_are_the_same && !any_is_fp8) || (arch::kMinComputeCapability >= 89);
if constexpr (is_valid_pre_hopper)
{
// Note that SIMT configs are omitted here since they are not supported for fpA_intB.
// We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the
// best for mixed type gemms.
constexpr int tile_shape_k = 128 * 8 / cutlass::sizeof_bits<ActivationType>::value;
switch (gemm_config.tile_config)
{
case tkc::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<16, 128, 64>, cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales,
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
stream, occupancy);
dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
EpilogueTag, cutlass::gemm::GemmShape<16, 128, tile_shape_k>,
cutlass::gemm::GemmShape<16, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases,
alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
}
break;
case tkc::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<16, 256, 64>, cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales,
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
stream, occupancy);
dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
EpilogueTag, cutlass::gemm::GemmShape<16, 256, tile_shape_k>,
cutlass::gemm::GemmShape<16, 64, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases,
alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
}
break;
case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<32, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales,
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
stream, occupancy);
dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
EpilogueTag, cutlass::gemm::GemmShape<32, 128, tile_shape_k>,
cutlass::gemm::GemmShape<32, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha,
C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales,
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
stream, occupancy);
dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
EpilogueTag, cutlass::gemm::GemmShape<64, 128, tile_shape_k>,
cutlass::gemm::GemmShape<64, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha,
C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales,
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
stream, occupancy);
dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, arch, QuantOp,
EpilogueTag, cutlass::gemm::GemmShape<128, 128, tile_shape_k>,
cutlass::gemm::GemmShape<128, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases,
alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
}
break;
case tkc::CutlassTileConfig::Undefined:
@ -430,12 +452,26 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
}
else if (sm_ >= 80 && sm_ < 90)
else if (sm_ >= 80 && sm_ < 89)
{
dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, cutlass::arch::Sm80,
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
}
else if (sm_ == 89)
{
#if ENABLE_FP8 && ((__CUDACC_VER_MAJOR__ < 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4))
if constexpr (cutlass::platform::is_same<ActivationType, __nv_fp8_e4m3>::value)
{
throw std::runtime_error(
"[TensorRT-LLM Error][CutlassFpAIntBGemmRunner][dispatch_to_arch] INT4xFP8 GEMM for Ada needs "
"CUDA>=12.4");
}
#endif
dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, cutlass::arch::Sm89,
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
}
else if (sm_ == 90)
{
sm90_dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
@ -520,8 +556,14 @@ std::vector<tkc::CutlassGemmConfig>
CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::getConfigs() const
{
static constexpr bool is_weight_only = !std::is_same<ActivationType, WeightType>::value;
std::vector<tkc::CutlassGemmConfig> candidateConfigs
= get_candidate_configs(sm_, is_weight_only, false, false, SPLIT_K_LIMIT, true);
tkc::CutlassGemmConfig::CandidateConfigTypeParam config_type_param
= tkc::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER;
if (is_weight_only)
{
config_type_param = static_cast<tkc::CutlassGemmConfig::CandidateConfigTypeParam>(
config_type_param | tkc::CutlassGemmConfig::CandidateConfigTypeParam::WEIGHT_ONLY);
}
std::vector<tkc::CutlassGemmConfig> candidateConfigs = get_candidate_configs(sm_, SPLIT_K_LIMIT, config_type_param);
return candidateConfigs;
}

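The dispatch hunk above replaces the fixed K extent of 64 in the threadblock and warp shapes with 128 * 8 / cutlass::sizeof_bits<ActivationType>::value, which keeps each tile's K dimension at 128 bytes for any activation width. A small standalone sketch of that arithmetic, with a local sizeof_bits stand-in and unsigned integer types standing in for half and FP8:

#include <cstdint>
#include <cstdio>

// Stand-in for cutlass::sizeof_bits<T>: bit width of the activation element.
template <typename T>
struct sizeof_bits { static constexpr int value = 8 * sizeof(T); };

// 128 bytes of K per threadblock tile, expressed in elements of the activation type.
template <typename ActivationType>
constexpr int tile_shape_k = 128 * 8 / sizeof_bits<ActivationType>::value;

int main()
{
    // 16-bit activations (half/bfloat16): 64 elements, matching the old fixed shapes.
    std::printf("16-bit: %d\n", tile_shape_k<std::uint16_t>);
    // 8-bit activations (e.g. FP8): 128 elements, which the new code now selects.
    std::printf("8-bit:  %d\n", tile_shape_k<std::uint8_t>);
    return 0;
}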
View File

@ -79,7 +79,7 @@ void sm90_dispatch_epilogue_schedules(ActivationType const* A, WeightType const*
2x2x1 cluster shape is only supported when both the M and N tiles are at least 128.
We make the above restrictions are to improve compilation speed in TRT-LLM by pruning kernels
We make the above restrictions to improve compilation speed in TRT-LLM, by pruning kernels
that may not be very useful in practice.
*/
template <typename CTAShape, typename ClusterShape>

View File

@ -66,7 +66,7 @@ void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
#ifdef COMPILE_HOPPER_MIXED_INPUT_GEMMS
#ifdef COMPILE_HOPPER_TMA_GEMMS
using CutlassActivationType = typename TllmToCutlassTypeAdapter<ActivationType>::type;
// For FAST_BUILD, only instantiate kernels with 128x128x128B with 1x1x1 cluster shape.
@ -286,11 +286,11 @@ void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType
}
#endif // FAST_BUILD
#else // COMPILE_HOPPER_MIXED_INPUT_GEMMS
#else // COMPILE_HOPPER_TMA_GEMMS
throw std::runtime_error(
"[TensorRT-LLm Error][fpA_intB Runner] Please recompile with support for hopper by passing 90-real as an arch "
"to build_wheel.py.");
#endif // COMPILE_HOPPER_MIXED_INPUT_GEMMS
#endif // COMPILE_HOPPER_TMA_GEMMS
}
} // namespace cutlass_kernels

View File

@ -365,10 +365,15 @@ void CutlassInt8GemmRunner<T>::gemm(int8_t const* A, int8_t const* B, tk::QuantM
template <typename T>
std::vector<tkc::CutlassGemmConfig> CutlassInt8GemmRunner<T>::getConfigs() const
{
static constexpr bool isWeightOnly = false;
std::vector<tkc::CutlassGemmConfig> candidateConfigs
= get_candidate_configs(mSm, isWeightOnly, mSm <= 70, /* SIMT configs */
true, SPLIT_K_LIMIT); /* INT8 configs */
auto config_type_param = tkc::CutlassGemmConfig::CandidateConfigTypeParam::INT8_ONLY;
if (mSm <= 70)
{
config_type_param = static_cast<tkc::CutlassGemmConfig::CandidateConfigTypeParam>(
config_type_param | tkc::CutlassGemmConfig::CandidateConfigTypeParam::SIMT_ONLY);
}
std::vector<tkc::CutlassGemmConfig> candidateConfigs = get_candidate_configs(mSm, SPLIT_K_LIMIT, config_type_param);
return candidateConfigs;
}

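Both getConfigs rewrites in this commit (the fpA_intB runner above and the INT8 runner here) follow the same pattern: start from a base CandidateConfigTypeParam flag, OR in additional flags, and cast the result back to the enum type before calling get_candidate_configs. A minimal sketch of that enum bitmask idiom, with hypothetical flag values since gemm_configs.h is not part of this diff:

#include <cstdio>

// Hypothetical flag enum mirroring the CandidateConfigTypeParam usage pattern.
struct GemmConfig
{
    enum CandidateConfigTypeParam : int
    {
        NONE        = 0,
        WEIGHT_ONLY = 1 << 0,
        SIMT_ONLY   = 1 << 1,
        INT8_ONLY   = 1 << 2,
        HOPPER      = 1 << 3,
    };
};

int main()
{
    auto param = GemmConfig::CandidateConfigTypeParam::INT8_ONLY;
    bool const is_old_sm = true; // e.g. sm <= 70 in the runner above
    if (is_old_sm)
    {
        // Plain enums need an explicit cast after bitwise OR, hence the
        // static_cast in the runners above.
        param = static_cast<GemmConfig::CandidateConfigTypeParam>(
            param | GemmConfig::CandidateConfigTypeParam::SIMT_ONLY);
    }
    std::printf("param mask = %d\n", static_cast<int>(param));
    return 0;
}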
View File

@ -0,0 +1,36 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include <cuda_runtime_api.h>
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
// Keep in sync with the signature generated by generate_kernels.py
template <typename T, typename WeightType, typename EpilogueTag, typename TileShape, typename ClusterShape, bool BIAS>
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
int multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size);
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,304 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#pragma GCC diagnostic pop
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <sstream>
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
// Hopper helper class for defining all the cutlass helper types
template <typename T, typename WeightType, typename EpilogueTag, typename TileShape, typename ClusterShape, bool BIAS>
struct HopperGroupedGemmInfo
{
using Arch = cutlass::arch::Sm90;
// TODO Update once mixed input support is added
static_assert(cutlass::platform::is_same<T, WeightType>::value,
"CUTLASS does not currently have specialised SM90 support for quantized operations");
#ifdef ENABLE_FP8
constexpr static bool IsFP8
= cutlass::platform::is_same<T, __nv_fp8_e4m3>::value || cutlass::platform::is_same<T, __nv_fp8_e5m2>::value;
#else
constexpr static bool IsFP8 = false;
#endif
#ifdef ENABLE_BF16
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|| cutlass::platform::is_same<T, float>::value || IsFP8,
"Specialized for bfloat16, half, float, fp8");
#else
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value || IsFP8,
"Specialized for half, float, fp8");
#endif
static_assert(cutlass::platform::is_same<T, WeightType>::value
|| cutlass::platform::is_same<WeightType, uint8_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::float_e4m3_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::float_e5m2_t>::value,
"Unexpected quantization type");
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
using CutlassWeightTypeMaybeUint4 = typename TllmToCutlassTypeAdapter<WeightType>::type;
// For legacy reasons we convert unsigned 8-bit to signed
using CutlassWeightTypeMaybeUint8
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint4, cutlass::uint4b_t>, cutlass::int4b_t,
CutlassWeightTypeMaybeUint4>;
using CutlassWeightType
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint8, uint8_t>, int8_t, CutlassWeightTypeMaybeUint8>;
using ElementA = ElementType;
using ElementB = CutlassWeightType;
template <class Element>
using CutlassOutputTypeAdaptor_t = typename TllmToCutlassTypeAdapter<
HopperGroupedGemmInput::OutputTypeAdaptor_t<typename CutlassToTllmTypeAdapter<Element>::type>>::type;
using ElementD = CutlassOutputTypeAdaptor_t<ElementType>;
using ElementC = std::conditional_t<BIAS, ElementType, void>;
using ElementCNoVoid = std::conditional_t<BIAS, ElementType, ElementD>;
using ElementAccumulator = float;
// A matrix configuration - this is transposed and swapped with B
using LayoutA = HopperGroupedGemmInput::LayoutA;
constexpr static int AlignmentA
= 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units
// of elements (up to 16 bytes)
// B matrix configuration - this is transposed and swapped with A
using LayoutB = HopperGroupedGemmInput::LayoutB; // Layout type for B matrix operand
constexpr static int AlignmentB
= 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units
// of elements (up to 16 bytes)
// C matrix configuration
using LayoutC = HopperGroupedGemmInput::LayoutC; // Layout type for C matrix operand
// Note we use ElementType here deliberately, so we don't break when BIAS is disabled
constexpr static int AlignmentC
= 128 / cutlass::sizeof_bits<ElementType>::value; // Memory access granularity/alignment of C matrix in units
// of elements (up to 16 bytes)
// D matrix configuration
using LayoutD = HopperGroupedGemmInput::LayoutD;
constexpr static int AlignmentD
= 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix
// in units of elements (up to 16 bytes)
static_assert(cutlass::platform::is_same<EpilogueTag, tensorrt_llm::cutlass_extensions::EpilogueOpDefault>::value,
"Hopper Grouped GEMM specialisation doesn't support fused activation");
using EpilogueOp
= cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
// TODO Add mode for fused activation once CUTLASS adds support
// using EpilogueSchedule = cutlass::platform::conditional_t<
// cutlass::platform::is_same<EpilogueOp, EpilogueOpDefault>::value,
// cutlass::epilogue::PtrArrayNoSmemWarpSpecialized,
// cutlass::epilogue::?????????????????? /// <<<<<< what supports activations
// >;
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< //
Arch, cutlass::arch::OpClassTensorOp, //
TileShape, ClusterShape, //
cutlass::epilogue::collective::EpilogueTileAuto, //
ElementAccumulator, ElementAccumulator, //
ElementC, LayoutC*, AlignmentC, //
ElementD, LayoutD*, AlignmentD, //
EpilogueSchedule>::CollectiveOp;
using StageCountAutoCarveout = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>;
using KernelSchedule
= std::conditional_t<IsFP8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< //
Arch, cutlass::arch::OpClassTensorOp, //
CutlassWeightType, LayoutB*, AlignmentB, // A & B swapped here
ElementType, LayoutA*, AlignmentA, //
ElementAccumulator, //
TileShape, ClusterShape, //
StageCountAutoCarveout, KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<HopperGroupedGemmInput::ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
// Hopper specialised version
template <typename T, typename WeightType, typename EpilogueTag, typename TileShape, typename ClusterShape, bool BIAS>
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size)
{
#ifdef COMPILE_HOPPER_TMA_GEMMS
using namespace cute;
// For FAST_BUILD, only instantiate kernels with 128x128x128B with 1x1x1 cluster shape.
#ifdef FAST_BUILD
constexpr int TILE_K = 128 * 8 / cutlass::sizeof_bits<WeightType>::value;
using SupportedCtaShape = Shape<_128, _128, cute::Int<TILE_K>>;
using SupportedCgaShape = Shape<_1, _1, _1>;
if constexpr (cute::is_same_v<SupportedCtaShape, TileShape> && cute::is_same_v<SupportedCgaShape, ClusterShape>)
#endif // FAST_BUILD
{
using GemmInfo = HopperGroupedGemmInfo<T, WeightType, EpilogueTag, TileShape, ClusterShape, BIAS>;
using ElementAccumulator = typename GemmInfo::ElementAccumulator;
using ElementA = typename GemmInfo::ElementA;
using ElementB = typename GemmInfo::ElementB;
using ElementC = typename GemmInfo::ElementC;
using ElementCNoVoid = typename GemmInfo::ElementCNoVoid;
using ElementD = typename GemmInfo::ElementD;
using CollectiveMainloop = typename GemmInfo::CollectiveMainloop;
using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue;
using GemmKernel = typename GemmInfo::GemmKernel;
using GemmGrouped = typename GemmInfo::GemmGrouped;
if (kernel_occupancy != nullptr)
{
*kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel, true>();
return;
}
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = multi_processor_count;
GemmGrouped gemm;
if (workspace_size != nullptr)
{
// Make a mock problem shape with just the minimal information actually required to get the workspace size
// This makes some assumptions about CUTLASS's implementation which is suboptimal. We have a check later to
// catch future cutlass updates causing silent breakages, but that is not foolproof.
// The alternative is to wait until we have data and then dynamically allocate the workspace
typename HopperGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, nullptr};
typename GemmGrouped::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped, shape_info, {}, {}, hw_info};
*workspace_size = gemm.get_workspace_size(args);
return;
}
using MainloopArguments = typename CollectiveMainloop::Arguments;
TLLM_CHECK(hopper_input.stride_a);
TLLM_CHECK(hopper_input.stride_b);
TLLM_CHECK(hopper_input.stride_d);
TLLM_CHECK(hopper_input.ptr_a);
TLLM_CHECK(hopper_input.ptr_b);
TLLM_CHECK(hopper_input.ptr_d);
const MainloopArguments mainloop_params = {reinterpret_cast<ElementB const**>(hopper_input.ptr_b),
hopper_input.stride_b, reinterpret_cast<ElementA const**>(hopper_input.ptr_a), hopper_input.stride_a};
typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{
ElementAccumulator(1.f), hopper_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)};
epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
// TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS
const EpilogueArguments epilogue_params
= {epilogue_scalars, reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c), hopper_input.stride_c,
reinterpret_cast<ElementD**>(hopper_input.ptr_d), hopper_input.stride_d};
typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info,
mainloop_params, epilogue_params, hw_info};
size_t calculated_ws_size = gemm.get_workspace_size(args);
TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size,
"Workspace is size %zu but only %zu were allocated", calculated_ws_size, hopper_input.gemm_workspace_size);
auto can_implement = gemm.can_implement(args);
TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess,
"Grouped GEMM kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)));
auto init_status = gemm.initialize(args, hopper_input.gemm_workspace);
TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess,
"Failed to initialize cutlass variable batched gemm. Error: "
+ std::string(cutlassGetStatusString(init_status)));
auto run_status = gemm.run(stream);
TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess,
"Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
sync_check_cuda_error();
}
#ifdef FAST_BUILD
else
{
TLLM_THROW("Configuration was disabled by FAST_BUILD");
}
#endif
#else // COMPILE_HOPPER_TMA_GEMMS
TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py.");
#endif // COMPILE_HOPPER_TMA_GEMMS
}
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm

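One detail in the launcher above worth calling out is the epilogue scalar pair: alpha is fixed at 1 and beta is chosen as 1 or 0 depending on whether hopper_input.ptr_c is set, so the bias/C operand contributes only when it exists. A tiny standalone sketch of that linear-combination convention (D = alpha * accumulator + beta * C), not the CUTLASS epilogue itself:

#include <cstdio>

// Sketch of the epilogue scalar choice made above: beta is 0 when no C/bias
// operand is provided, so its contribution drops out of the result.
float linear_combination(float accumulator, float const* c, float alpha)
{
    float const beta = (c != nullptr) ? 1.0f : 0.0f;
    float const c_value = (c != nullptr) ? *c : 0.0f;
    return alpha * accumulator + beta * c_value;
}

int main()
{
    float bias = 0.5f;
    std::printf("with bias:    %f\n", linear_combination(2.0f, &bias, 1.0f));   // 2.5
    std::printf("without bias: %f\n", linear_combination(2.0f, nullptr, 1.0f)); // 2.0
    return 0;
}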
View File

@ -16,13 +16,119 @@
*/
#pragma once
#include "tensorrt_llm/common/workspace.h"
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h"
#include <cuda_runtime_api.h>
#include <optional>
#include <cutlass/gemm/group_array_problem_shape.hpp>
namespace tensorrt_llm
{
struct HopperGroupedGemmInput
{
template <class Tag>
using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;
static_assert(std::is_same_v<cutlass::layout::RowMajor, TransposeLayoutTag<cutlass::layout::ColumnMajor>>);
static_assert(std::is_same_v<cutlass::layout::ColumnMajor, TransposeLayoutTag<cutlass::layout::RowMajor>>);
// Layout for A and B is transposed and then swapped in the implementation
// This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM
using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for A matrix operand
using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand
using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for C matrix operand
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
using StrideA
= std::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutA*>>; // Use B because they will be swapped
using StrideB
= std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped
using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;
template <class T>
constexpr static bool IsFP8_v = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
template <class T>
using OutputTypeAdaptor_t = std::conditional_t<IsFP8_v<T>, float, T>;
using ProblemShape = cutlass::gemm::GroupProblemShape<cute::Shape<int64_t, int64_t, int64_t>>;
ProblemShape shape_info{};
StrideA* stride_a = nullptr;
StrideB* stride_b = nullptr;
StrideC* stride_c = nullptr;
StrideD* stride_d = nullptr;
void const** ptr_a = nullptr;
void const** ptr_b = nullptr;
void const** ptr_c = nullptr;
void** ptr_d = nullptr;
float const** alpha_scale_ptr_array = nullptr;
uint8_t* gemm_workspace = nullptr;
size_t gemm_workspace_size = 0;
static auto workspaceBuffers(int num_experts)
{
size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts;
size_t stride_a_size = sizeof(StrideA) * num_experts;
size_t stride_b_size = sizeof(StrideB) * num_experts;
size_t stride_c_size = sizeof(StrideC) * num_experts;
size_t stride_d_size = sizeof(StrideD) * num_experts;
size_t ptr_buf_size = sizeof(void*) * num_experts;
size_t scale_buf_size = sizeof(float**) * num_experts;
return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size,
ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size};
}
static size_t workspaceSize(int num_experts)
{
auto buffers = workspaceBuffers(num_experts);
return tensorrt_llm::common::calculateTotalWorkspaceSize(buffers.data(), buffers.size());
}
void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size)
{
auto buffers = workspaceBuffers(num_experts);
std::array<int8_t*, 10> pointers{};
TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers");
for (int i = 0; i < buffers.size(); i++)
{
pointers[i] = start_ptr;
start_ptr = tensorrt_llm::common::nextWorkspacePtr(start_ptr, buffers[i]);
}
shape_info.num_groups = num_experts;
shape_info.problem_shapes = reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(pointers[0]);
shape_info.host_problem_shapes = nullptr;
stride_a = reinterpret_cast<StrideA*>(pointers[1]);
stride_b = reinterpret_cast<StrideB*>(pointers[2]);
stride_c = reinterpret_cast<StrideC*>(pointers[3]);
stride_d = reinterpret_cast<StrideD*>(pointers[4]);
ptr_a = reinterpret_cast<void const**>(pointers[5]);
ptr_b = reinterpret_cast<void const**>(pointers[6]);
ptr_c = reinterpret_cast<void const**>(pointers[7]);
ptr_d = reinterpret_cast<void**>(pointers[8]);
alpha_scale_ptr_array = reinterpret_cast<float const**>(pointers[9]);
this->gemm_workspace = reinterpret_cast<uint8_t*>(gemm_workspace);
this->gemm_workspace_size = gemm_workspace_size;
}
bool isValid() const
{
return stride_a != nullptr && ptr_a != nullptr;
}
};
// Note update moe.py to match
enum class ActivationType
{
@ -53,24 +159,33 @@ public:
}
void moeGemmBiasAct(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
ActivationType activation_type, cudaStream_t stream);
int64_t* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n,
int64_t gemm_k, int num_experts, ActivationType activation_type, cudaStream_t stream);
void moeGemm(T const* A, WeightType const* B, T const* weight_scales, T* C, int64_t* total_rows_before_expert,
int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream);
HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cudaStream_t stream);
std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs();
std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs() const;
std::vector<cutlass_extensions::CutlassGemmConfig> getHopperConfigs() const;
std::vector<cutlass_extensions::CutlassGemmConfig> getAmpereConfigs() const;
bool isHopperSpecialised() const;
bool supportsHopperSpecialisation() const;
size_t calcMaxWorkspaceSize(int num_experts) const;
private:
template <typename EpilogueTag>
void dispatchToArch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr);
int64_t* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n,
int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, cudaStream_t stream,
int* occupancy = nullptr);
template <typename EpilogueTag>
void runGemm(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cudaStream_t stream);
int64_t* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n,
int64_t gemm_k, int num_experts, cudaStream_t stream);
private:
int sm_;

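The HopperGroupedGemmInput header above derives its stride types by transposing and swapping the operand layouts, relying on the identity (A * B)^T = B^T * A^T. Below is a compile-time sketch of the TransposeLayoutTag helper with local stand-ins for the cutlass::layout tags; the static_asserts mirror the self-checks in the header:

#include <type_traits>

// Minimal sketch of the layout-transpose helper: map RowMajor <-> ColumnMajor at
// compile time so the A and B operands can be swapped in the implementation.
struct RowMajor {};
struct ColumnMajor {};

template <class Tag>
using TransposeLayoutTag
    = std::conditional_t<std::is_same_v<Tag, RowMajor>, ColumnMajor, RowMajor>;

// Transposing twice is the identity, as asserted in the header itself.
static_assert(std::is_same_v<RowMajor, TransposeLayoutTag<ColumnMajor>>);
static_assert(std::is_same_v<ColumnMajor, TransposeLayoutTag<RowMajor>>);

int main() { return 0; }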
View File

@ -0,0 +1,25 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
namespace tensorrt_llm
{
#ifdef ENABLE_FP8
template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3>;
// template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>;
#endif
} // namespace tensorrt_llm

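This new .cu file follows the project's instantiation layout: the template definitions stay in moe_gemm_kernels_template.h and each enabled type combination gets an explicit instantiation in its own translation unit, guarded by a feature macro. A generic sketch of that split using hypothetical names (everything below is illustrative, not the MoeGemmRunner interface):

// Sketch only: definitions that would normally live in a *_template.h header.
template <typename Activation, typename Weight>
class RunnerSketch
{
public:
    int run() const { return static_cast<int>(sizeof(Activation) + sizeof(Weight)); }
};

// Sketch only: a per-type .cu file would hold just the explicit instantiation,
// optionally guarded by a feature macro as ENABLE_FP8 is above.
#ifdef ENABLE_SKETCH_FP16
template class RunnerSketch<short, short>; // short stands in for a 16-bit type
#endif

int main()
{
    RunnerSketch<float, float> runner;
    return runner.run() > 0 ? 0 : 1;
}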
View File

@ -24,6 +24,20 @@
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
@ -35,7 +49,14 @@
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "moe_gemm_kernels_template_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
@ -43,6 +64,8 @@
namespace tensorrt_llm
{
namespace kernels::cutlass_kernels
{
// ============================= Variable batched Gemm things ===========================
template <typename T, typename WeightType, typename arch, typename EpilogueTag, typename ThreadblockShape,
@ -66,27 +89,12 @@ void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, T const* weig
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
"");
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
using ElementType_ =
typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
#ifdef ENABLE_BF16
using ElementType =
typename cutlass::platform::conditional<cutlass::platform::is_same<ElementType_, __nv_bfloat16>::value,
cutlass::bfloat16_t, ElementType_>::type;
#else
using ElementType = ElementType_;
#endif
static_assert(!cutlass::platform::is_same<arch, cutlass::arch::Sm90>::value,
"Sm90 architecture should use specialised kernels");
using CutlassWeightType_ =
typename cutlass::platform::conditional<cutlass::platform::is_same<WeightType, half>::value, cutlass::half_t,
WeightType>::type;
#ifdef ENABLE_BF16
using CutlassWeightType =
typename cutlass::platform::conditional<cutlass::platform::is_same<CutlassWeightType_, __nv_bfloat16>::value,
cutlass::bfloat16_t, CutlassWeightType_>::type;
#else
using CutlassWeightType = CutlassWeightType_;
#endif
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
using CutlassWeightType = typename TllmToCutlassTypeAdapter<WeightType>::type;
// We need separate config for each architecture since we will target different tensorcore instructions. For float,
// we do not target TCs.
@ -147,50 +155,29 @@ void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, T const* weig
"Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
}
template <typename T, typename WeightType, typename arch, typename EpilogueTag, typename ThreadblockShape,
typename WarpShape, int Stages, typename Enable = void>
struct dispatch_stages
{
static void dispatch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
int* occupancy = nullptr)
{
TLLM_THROW("Cutlass fpA_intB gemm. Not instantiated for arch %d with stages set to %d",
arch::kMinComputeCapability, Stages);
}
};
} // namespace kernels::cutlass_kernels
template <typename T, typename WeightType, typename arch, typename EpilogueTag, typename ThreadblockShape,
typename WarpShape>
struct dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>
template <typename T, typename WeightType, typename Arch, typename EpilogueTag, typename ThreadblockShape,
typename WarpShape, int Stages>
static void dispatch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
int* occupancy = nullptr)
{
static void dispatch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
int* occupancy = nullptr)
static_assert(!std::is_same_v<Arch, cutlass::arch::Sm90>, "Use TMA specialised functions for arch SM90");
constexpr bool isFp8 = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
if constexpr ((Stages == 2 || Arch::kMinComputeCapability >= 80) && !isFp8)
{
genericMoeGemmKernelLauncher<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B,
weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config,
multi_processor_count, stream, occupancy);
kernels::cutlass_kernels::genericMoeGemmKernelLauncher<T, WeightType, Arch, EpilogueTag, ThreadblockShape,
WarpShape, Stages>(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k,
num_experts, gemm_config, multi_processor_count, stream, occupancy);
}
};
template <typename T, typename WeightType, typename EpilogueTag, typename ThreadblockShape, typename WarpShape,
int Stages>
struct dispatch_stages<T, WeightType, cutlass::arch::Sm80, EpilogueTag, ThreadblockShape, WarpShape, Stages,
typename std::enable_if<(Stages > 2)>::type>
{
static void dispatch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
int* occupancy = nullptr)
else
{
genericMoeGemmKernelLauncher<T, WeightType, cutlass::arch::Sm80, EpilogueTag, ThreadblockShape, WarpShape,
Stages>(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts,
gemm_config, multi_processor_count, stream, occupancy);
TLLM_THROW(
"Cutlass gemm. Not instantiated for arch %d with stages set to %d", Arch::kMinComputeCapability, Stages);
}
};
}
template <typename T, typename WeightType, typename arch, typename EpilogueTag, typename ThreadblockShape,
typename WarpShape>
@ -202,19 +189,19 @@ void dispatchGemmConfig(T const* A, WeightType const* B, T const* weight_scales,
switch (gemm_config.stages)
{
case 2:
using DispatcherStages2 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>;
DispatcherStages2::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k,
num_experts, gemm_config, multi_processor_count, stream, occupancy);
dispatch<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B, weight_scales, biases, C,
total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream,
occupancy);
break;
case 3:
using DispatcherStages3 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 3>;
DispatcherStages3::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k,
num_experts, gemm_config, multi_processor_count, stream, occupancy);
dispatch<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 3>(A, B, weight_scales, biases, C,
total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream,
occupancy);
break;
case 4:
using DispatcherStages4 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 4>;
DispatcherStages4::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k,
num_experts, gemm_config, multi_processor_count, stream, occupancy);
dispatch<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 4>(A, B, weight_scales, biases, C,
total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream,
occupancy);
break;
default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break;
}
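The refactor above folds the old dispatch_stages specializations into a single dispatch function, but dispatchGemmConfig still works the same way: it maps the runtime gemm_config.stages value onto a compile-time template parameter through a switch. A standalone sketch of that runtime-to-compile-time bridge with hypothetical kernel names:

#include <cstdio>
#include <stdexcept>
#include <string>

// Compile-time kernel: Stages is a template parameter, so each stage count gets
// its own instantiation, just like the CUTLASS kernels above.
template <int Stages>
void kernel_sketch()
{
    std::printf("running %d-stage kernel\n", Stages);
}

// Runtime-to-compile-time bridge: the switch enumerates the supported stage counts.
void dispatch_stages_sketch(int stages)
{
    switch (stages)
    {
    case 2: kernel_sketch<2>(); break;
    case 3: kernel_sketch<3>(); break;
    case 4: kernel_sketch<4>(); break;
    default:
        throw std::runtime_error("unsupported stage count " + std::to_string(stages));
    }
}

int main()
{
    dispatch_stages_sketch(3);
    return 0;
}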
@ -226,7 +213,7 @@ template <typename T, typename WeightType, typename arch, typename EpilogueTag,
typename std::enable_if<!std::is_same<T, float>::value && std::is_same<T, WeightType>::value>::type* = nullptr>
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
int* occupancy = nullptr)
{
switch (gemm_config.tile_config)
@ -279,7 +266,7 @@ template <typename T, typename WeightType, typename arch, typename EpilogueTag,
typename std::enable_if<!std::is_same<T, float>::value && !std::is_same<T, WeightType>::value>::type* = nullptr>
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
int* occupancy = nullptr)
{
switch (gemm_config.tile_config)
@ -330,7 +317,7 @@ template <typename T, typename WeightType, typename arch, typename EpilogueTag,
typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C,
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
int* occupancy = nullptr)
{
switch (gemm_config.tile_config)
@ -349,15 +336,78 @@ void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_s
}
template <typename T, typename WeightType>
std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType>::getConfigs()
std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType>::getConfigs() const
{
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
static constexpr bool only_simt_configs = std::is_same<T, float>::value;
std::vector<cutlass_extensions::CutlassGemmConfig> candidate_configs
= kernels::cutlass_kernels::get_candidate_configs(sm_, is_weight_only, only_simt_configs);
std::vector<cutlass_extensions::CutlassGemmConfig> candidate_configs = getAmpereConfigs();
std::vector<cutlass_extensions::CutlassGemmConfig> hopper_configs = getHopperConfigs();
std::copy(hopper_configs.begin(), hopper_configs.end(), std::back_inserter(candidate_configs));
return candidate_configs;
}
template <typename T, typename WeightType>
std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType>::getAmpereConfigs() const
{
using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
static constexpr auto weight_only_flag
= std::is_same<T, WeightType>::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY;
static constexpr auto simt_only_flag
= std::is_same<T, float>::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE;
int const max_split_k = 1;
int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM;
int const enable_hopper = CutlassGemmConfig::NONE;
auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper);
if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType>())
{
return {};
}
std::vector<cutlass_extensions::CutlassGemmConfig> ampere_configs
= kernels::cutlass_kernels::get_candidate_configs(sm_, max_split_k, config_type_param);
return ampere_configs;
}
template <typename T, typename WeightType>
std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType>::getHopperConfigs() const
{
using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
static constexpr auto weight_only_flag
= std::is_same<T, WeightType>::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY;
static constexpr auto simt_only_flag
= std::is_same<T, float>::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE;
int const max_split_k = 1;
int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM;
int const enable_hopper = CutlassGemmConfig::HOPPER;
auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper);
if (!kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
{
return {};
}
std::vector<cutlass_extensions::CutlassGemmConfig> hopper_configs
= kernels::cutlass_kernels::get_candidate_configs(sm_, max_split_k, config_type_param);
return hopper_configs;
}
template <typename T, typename WeightType>
bool MoeGemmRunner<T, WeightType>::isHopperSpecialised() const
{
bool config_is_sm90 = best_config_ && best_config_->is_sm90;
return supportsHopperSpecialisation() && config_is_sm90;
}
template <typename T, typename WeightType>
bool MoeGemmRunner<T, WeightType>::supportsHopperSpecialisation() const
{
return sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>();
}
template <typename T, typename WeightType>
MoeGemmRunner<T, WeightType>::MoeGemmRunner()
{
@ -371,33 +421,72 @@ MoeGemmRunner<T, WeightType>::MoeGemmRunner()
template <typename T, typename WeightType>
template <typename EpilogueTag>
void MoeGemmRunner<T, WeightType>::dispatchToArch<EpilogueTag>(T const* A, WeightType const* B, T const* weight_scales,
T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy)
T const* biases, T* C, int64_t* total_rows_before_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config,
cudaStream_t stream, int* occupancy)
{
TLLM_CHECK_WITH_INFO(
sm_ == 90 || !hopper_input.isValid(), "Hopper input information is set for a non-specialised implementation");
TLLM_CHECK_WITH_INFO(
sm_ == 90 || !gemm_config.is_sm90, "Hopper configuration provided for non-Hopper architecture");
if (sm_ >= 70 && sm_ < 75)
{
dispatchMoeGemmToCutlass<T, WeightType, cutlass::arch::Sm70, EpilogueTag>(A, B, weight_scales, biases, C,
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_,
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_,
stream, occupancy);
}
else if (sm_ >= 75 && sm_ < 80)
{
dispatchMoeGemmToCutlass<T, WeightType, cutlass::arch::Sm75, EpilogueTag>(A, B, weight_scales, biases, C,
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_,
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_,
stream, occupancy);
}
else if (sm_ >= 80 && sm_ < 90)
{
dispatchMoeGemmToCutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(A, B, weight_scales, biases, C,
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_,
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_,
stream, occupancy);
}
else if (sm_ >= 90)
{
// TODO Update the arch to Sm90 once CUTLASS hopper specialisations are available
dispatchMoeGemmToCutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(A, B, weight_scales, biases, C,
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_,
stream, occupancy);
if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>())
{
// We allow both SM90 and SM80 configurations to coexist because for cases with a small number of tokens
// SM80 can be faster. We check here to see which one was selected.
if (gemm_config.is_sm90)
{
TLLM_CHECK_WITH_INFO(biases != nullptr || hopper_input.ptr_c == nullptr,
"Hopper input has a bias pointer set, but no biases were provided");
TLLM_CHECK_WITH_INFO(hopper_input.isValid(), "Calling SM90 configuration with invalid hopper config");
dispatchMoeGemmSelectTileShapeSM90<T, WeightType, EpilogueTag>(
hopper_input, num_experts, gemm_config, multi_processor_count_, stream, occupancy, nullptr);
return;
}
// Fallthrough to SM80 impl below
}
// Do Ampere case instead
if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType, EpilogueTag>())
{
TLLM_CHECK_WITH_INFO(!hopper_input.isValid(),
"Hopper input should not be set when rerouting to the fallback (SM80) implementation");
TLLM_CHECK_WITH_INFO(!gemm_config.is_sm90,
"GEMM config is an SM90 configuration, but the Hopper specialisation is not valid for this case");
dispatchMoeGemmToCutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(A, B, weight_scales, biases, C,
total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_,
stream, occupancy);
}
else
{
// Should only be hit by FP8 configs during the GEMM profiling pass, never at runtime
TLLM_THROW("Configuration expects SM80 but configuration is not supported by SM80 kernels");
}
}
else
{
@ -405,59 +494,74 @@ void MoeGemmRunner<T, WeightType>::dispatchToArch<EpilogueTag>(T const* A, Weigh
}
}
template <typename T, typename WeightType>
size_t MoeGemmRunner<T, WeightType>::calcMaxWorkspaceSize(int num_experts) const
{
if (!supportsHopperSpecialisation())
{
return 0;
}
TLLM_CHECK_WITH_INFO((kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>()),
"Configuration is specialised for Hopper but not supported");
if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
{
auto configs = getHopperConfigs();
size_t max_size = 0;
bool has_config = false;
for (auto conf : configs)
{
try
{
size_t size = calcMaxWorkspaceSizeSM90<T, WeightType>(num_experts, conf, multi_processor_count_);
max_size = std::max(max_size, size);
has_config = true;
}
catch (tensorrt_llm::common::TllmException const& e)
{
TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size");
}
}
TLLM_CHECK_WITH_INFO(has_config, "Could not find valid config when calculating workspace size");
return max_size;
}
assert(false); // Unreachable: the check above guarantees the constexpr branch is taken
return 0;
}
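calcMaxWorkspaceSize() reduces over every Hopper candidate so that one allocation can serve whichever config the profiler eventually picks. A minimal usage sketch, assuming a constructed MoeGemmRunner and ignoring error checking; allocateMoeWorkspace is a hypothetical helper, not part of the API.
#include <cuda_runtime.h>
#include <cstddef>
// Illustrative only: size the workspace for the worst-case Hopper config up front,
// so the buffer does not need to be reallocated when the selected config changes.
template <typename T, typename WeightType>
void* allocateMoeWorkspace(tensorrt_llm::MoeGemmRunner<T, WeightType> const& runner, int num_experts)
{
    std::size_t const bytes = runner.calcMaxWorkspaceSize(num_experts); // 0 when the Hopper path is unavailable
    void* workspace = nullptr;
    if (bytes > 0)
    {
        cudaMalloc(&workspace, bytes);
    }
    return workspace;
}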
template <typename T, typename WeightType>
template <typename EpilogueTag>
void MoeGemmRunner<T, WeightType>::runGemm<EpilogueTag>(T const* A, WeightType const* B, T const* weight_scales,
T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
int num_experts, cudaStream_t stream)
T const* biases, T* C, int64_t* total_rows_before_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream)
{
auto chosen_conf = this->best_config_;
if (!chosen_conf)
{
auto candidate_configs = getConfigs();
std::vector<int> occupancies(candidate_configs.size());
for (size_t ii = 0; ii < candidate_configs.size(); ++ii)
{
dispatchToArch<EpilogueTag>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n,
gemm_k, num_experts, candidate_configs[ii], stream, &occupancies[ii]);
}
static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs.
static constexpr int split_k_limit = 1; // MoE GEMM does not support split-k.
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
chosen_conf = kernels::cutlass_kernels::estimate_best_config_from_occupancies(candidate_configs, occupancies,
total_rows, gemm_n, gemm_k, num_experts, split_k_limit, workspace_bytes, multi_processor_count_,
is_weight_only);
}
assert(chosen_conf);
dispatchToArch<EpilogueTag>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k,
num_experts, *chosen_conf, stream);
TLLM_CHECK_WITH_INFO(this->best_config_, "No MOE GEMM config set at runtime");
auto chosen_conf = *this->best_config_;
dispatchToArch<EpilogueTag>(A, B, weight_scales, biases, C, total_rows_before_expert, hopper_input, total_rows,
gemm_n, gemm_k, num_experts, chosen_conf, stream);
}
template <typename T, typename WeightType>
void MoeGemmRunner<T, WeightType>::moeGemmBiasAct(T const* A, WeightType const* B, T const* weight_scales,
T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
int num_experts, ActivationType activation_type, cudaStream_t stream)
T const* biases, T* C, int64_t* total_rows_before_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, cudaStream_t stream)
{
switch (activation_type)
{
case ActivationType::Relu:
runGemm<cutlass_extensions::EpilogueOpDefaultReLU>(
A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream);
runGemm<cutlass_extensions::EpilogueOpDefaultReLU>(A, B, weight_scales, biases, C, total_rows_before_expert,
hopper_input, total_rows, gemm_n, gemm_k, num_experts, stream);
break;
case ActivationType::Gelu:
runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(
A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream);
runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(A, B, weight_scales, biases, C, total_rows_before_expert,
hopper_input, total_rows, gemm_n, gemm_k, num_experts, stream);
break;
case ActivationType::Silu:
runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(
A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream);
runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(A, B, weight_scales, biases, C, total_rows_before_expert,
hopper_input, total_rows, gemm_n, gemm_k, num_experts, stream);
break;
case ActivationType::Identity:
runGemm<cutlass_extensions::EpilogueOpDefault>(
A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream);
runGemm<cutlass_extensions::EpilogueOpDefault>(A, B, weight_scales, biases, C, total_rows_before_expert,
hopper_input, total_rows, gemm_n, gemm_k, num_experts, stream);
break;
case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break;
default: TLLM_THROW("Invalid activation type."); break;
@ -466,11 +570,11 @@ void MoeGemmRunner<T, WeightType>::moeGemmBiasAct(T const* A, WeightType const*
template <typename T, typename WeightType>
void MoeGemmRunner<T, WeightType>::moeGemm(T const* A, WeightType const* B, T const* weight_scales, T* C,
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cudaStream_t stream)
int64_t* total_rows_before_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n,
int64_t gemm_k, int num_experts, cudaStream_t stream)
{
runGemm<cutlass_extensions::EpilogueOpDefault>(
A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream);
runGemm<cutlass_extensions::EpilogueOpDefault>(A, B, weight_scales, nullptr, C, total_rows_before_expert,
hopper_input, total_rows, gemm_n, gemm_k, num_experts, stream);
}
} // namespace tensorrt_llm

View File

@ -0,0 +1,214 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#pragma GCC diagnostic pop
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <sstream>
namespace tensorrt_llm
{
template <typename T, typename WeightType, typename EpilogueTag, typename TileShape, typename ClusterShape>
void dispatchMoeGemmSelectBiasSM90(HopperGroupedGemmInput hopper_input, int num_experts, int multi_processor_count,
cudaStream_t stream, int* occupancy, size_t* workspace_size)
{
static_assert(kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>(),
"Invalid hopper configuration invoked, fallback to Sm80");
TLLM_CHECK_WITH_INFO(
workspace_size || hopper_input.isValid(), "Hopper specialisation is missing additional input information");
// auto func = hopper_input.ptr_c ?
// kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper<T, WeightType,
// cutlass::arch::Sm90, EpilogueTag, true>
// :
// kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper<T,
// WeightType,
// cutlass::arch::Sm90, EpilogueTag, false>;
// TODO(dastokes) Re-enable bias when CUTLASS supports it
auto func = kernels::cutlass_kernels::sm90_generic_moe_gemm_kernelLauncher<T, WeightType, EpilogueTag, TileShape,
ClusterShape, false>;
func(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size);
}
/*
The 1x1x1 cluster shape is supported for any tile shape.
The 2x1x1 cluster shape is only supported when the M tile is at least 128.
The 1x2x1 cluster shape is only supported when the N tile is at least 128.
The 2x2x1 cluster shape is only supported when both the M and N tiles are at least 128.
We apply these restrictions to improve compilation speed in TRT-LLM by pruning kernels
that are unlikely to be useful in practice.
*/
template <typename CTAShape, typename ClusterShape>
constexpr bool are_tile_shapes_supported()
{
using namespace cute;
constexpr int cta_m = get<0>(CTAShape{});
constexpr int cta_n = get<1>(CTAShape{});
constexpr int cga_m = get<0>(ClusterShape{});
constexpr int cga_n = get<1>(ClusterShape{});
if constexpr (cga_m == _1{} && cga_n == _1{})
{
return true;
}
else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{})
{
return true;
}
else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{})
{
return true;
}
else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{})
{
return true;
}
else
{
return false;
}
}
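Because are_tile_shapes_supported() is constexpr, the pruning rules above can be exercised at compile time. The static_asserts below are illustrative only and are not part of the file; they assume this header (and CUTLASS's cute) is on the include path.
// Illustrative compile-time checks of the pruning rules (not part of the file):
static_assert(tensorrt_llm::are_tile_shapes_supported<cute::Shape<cute::_128, cute::_16, cute::_64>,
                  cute::Shape<cute::_1, cute::_1, cute::_1>>(),
    "1x1x1 clusters are allowed for any tile shape");
static_assert(!tensorrt_llm::are_tile_shapes_supported<cute::Shape<cute::_128, cute::_16, cute::_64>,
                  cute::Shape<cute::_1, cute::_2, cute::_1>>(),
    "1x2x1 clusters require an N tile of at least 128");
static_assert(tensorrt_llm::are_tile_shapes_supported<cute::Shape<cute::_128, cute::_128, cute::_64>,
                  cute::Shape<cute::_2, cute::_2, cute::_1>>(),
    "2x2x1 clusters require both M and N tiles of at least 128");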
template <typename T, typename WeightType, typename EpilogueTag, typename TileShape>
void dispatchMoeGemmSelectClusterShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
size_t* workspace_size)
{
using namespace cute;
switch (gemm_config.cluster_shape)
{
#define SHAPE_CASE(M, N, K) \
case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K: \
{ \
using ClusterShape = Shape<_##M, _##N, _##K>; \
if constexpr (are_tile_shapes_supported<TileShape, ClusterShape>()) \
{ \
dispatchMoeGemmSelectBiasSM90<T, WeightType, EpilogueTag, TileShape, ClusterShape>( \
hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); \
break; \
} \
else \
{ \
TLLM_THROW("Unsupported tile and cluster shape combination"); \
} \
}
SHAPE_CASE(1, 1, 1)
SHAPE_CASE(1, 2, 1)
SHAPE_CASE(2, 1, 1)
SHAPE_CASE(2, 2, 1)
#undef SHAPE_CASE
default: TLLM_THROW("Unsupported config for MoE gemm.");
}
}
template <typename T, typename WeightType, typename EpilogueTag>
void dispatchMoeGemmSelectTileShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
size_t* workspace_size)
{
using namespace cute;
switch (gemm_config.tile_config_sm90)
{
#define SHAPE_CASE(M, N, K) \
case cutlass_extensions::CutlassTileConfigSM90::CtaShape##M##x##N##x##K##B: \
{ \
constexpr int KtileBytes = K / sizeof(T); \
using KTileDim = Int<KtileBytes>; \
using TileShape = Shape<_##M, _##N, KTileDim>; \
dispatchMoeGemmSelectClusterShapeSM90<T, WeightType, EpilogueTag, TileShape>( \
hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, workspace_size); \
break; \
}
SHAPE_CASE(128, 16, 128)
SHAPE_CASE(128, 32, 128)
SHAPE_CASE(128, 64, 128)
SHAPE_CASE(128, 128, 128)
SHAPE_CASE(128, 256, 128)
#undef SHAPE_CASE
case cutlass_extensions::CutlassTileConfigSM90::Undefined: TLLM_THROW("GEMM config undefined."); break;
case cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic:
TLLM_THROW("GEMM config should have already been set by heuristic.");
break;
default: TLLM_THROW("Unsupported config for MoE gemm."); break;
}
}
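The trailing "B" in the SM90 tile names (for example CtaShape128x64x128B) means the K extent is specified in bytes, and the SHAPE_CASE macro above converts it to an element count with K / sizeof(T). A couple of worked values, illustrative only and not part of the file:
// Illustrative only: element count of a 128 B K tile for the supported activation types.
static_assert(128 / sizeof(float) == 32, "float: 32-element K tile");
static_assert(128 / sizeof(half) == 64, "half / bf16: 64-element K tile");
// A 1-byte FP8 type (e.g. __nv_fp8_e4m3) would get a 128-element K tile.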
template <typename T, typename WeightType>
size_t calcMaxWorkspaceSizeSM90(
int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count)
{
size_t count;
// Most of the values are ignored for WS size calculation. We reuse the function to reduce the template bloat
dispatchMoeGemmSelectTileShapeSM90<T, WeightType, cutlass_extensions::EpilogueOpDefault>(
HopperGroupedGemmInput{}, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count);
return count;
}
} // namespace tensorrt_llm

View File

@ -0,0 +1,50 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/arch/mma_sm90.h"
#include "cutlass_extensions/epilogue_helpers.h"
namespace tensorrt_llm::kernels::cutlass_kernels
{
// Hopper arch
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
constexpr bool isValidHopperMOESpecialisation()
{
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
return cutlass::platform::is_same<T, WeightType>::value
&& cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value;
#else
return false; // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED is set when Hopper kernels are enabled
#endif
}
// Ampere arch
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
constexpr bool isValidAmpereMOESpecialisation()
{
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) and defined(ENABLE_FP8)
constexpr bool is_fp8
= cutlass::platform::is_same<T, __nv_fp8_e4m3>::value || cutlass::platform::is_same<T, __nv_fp8_e5m2>::value;
return !is_fp8;
#else
return true; // Default to true
#endif
}
} // namespace tensorrt_llm::kernels::cutlass_kernels
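A minimal sketch of how these two traits are intended to be consumed; it mirrors the structure of MoeGemmRunner::dispatchToArch, with the real kernel launches replaced by comments. selectMoeGemmPathExample is a hypothetical name used purely for illustration.
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
// Illustrative only: gate template instantiation on the traits so SM90-only code is never
// instantiated for type combinations that do not support it.
template <typename T, typename WeightType>
void selectMoeGemmPathExample(bool use_sm90_config)
{
    using namespace tensorrt_llm::kernels::cutlass_kernels;
    if constexpr (isValidHopperMOESpecialisation<T, WeightType>())
    {
        if (use_sm90_config)
        {
            // launch the TMA warp-specialised SM90 grouped GEMM
            return;
        }
        // otherwise fall through: for small token counts the SM80 path can be faster
    }
    if constexpr (isValidAmpereMOESpecialisation<T, WeightType>())
    {
        // launch the SM80-and-below fallback kernels
    }
    else
    {
        // e.g. FP8 activations without the Hopper path: no valid kernel for this combination
    }
}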

View File

@ -33,12 +33,14 @@ class TrtLlm_QuantOp(enum.Enum):
per_column_scale_only = enum_auto()
finegrained_scale_only = enum_auto()
finegrained_scale_and_zeros = enum_auto()
none = enum_auto()
QuantOpNames = {
TrtLlm_QuantOp.per_column_scale_only: "cs",
TrtLlm_QuantOp.finegrained_scale_only: "fgs",
TrtLlm_QuantOp.finegrained_scale_and_zeros: "fgsz"
TrtLlm_QuantOp.finegrained_scale_and_zeros: "fgsz",
TrtLlm_QuantOp.none: "noquant"
}
QuantOpTag = {
@ -47,7 +49,8 @@ QuantOpTag = {
TrtLlm_QuantOp.finegrained_scale_only:
"cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY",
TrtLlm_QuantOp.finegrained_scale_and_zeros:
"cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS"
"cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS",
TrtLlm_QuantOp.none: "void"
}
################################################################################
@ -56,7 +59,8 @@ QuantOpTag = {
CudaTypeName = {
DataType.e4m3: "__nv_fp8_e4m3",
DataType.bf16: "__nv_bfloat16",
DataType.f16: "half"
DataType.f16: "half",
DataType.f32: "float"
}
@ -64,22 +68,10 @@ CudaTypeName = {
# A data structure holding all info to instantiate gemm launchers in TRT LLM.
class TrtLlm_GemmLauncher:
def __init__(self,
gemm_kind,
arch,
act_type,
weight_type,
scalezero_type,
bias_type,
output_type,
quant_op,
epi_tag,
cta_shape,
warp_shape,
stages,
cga_shape=None,
mainloop_schedule=None,
epi_schedule=None):
def __init__(self, gemm_kind, arch, act_type, weight_type, scalezero_type,
bias_type, output_type, quant_op, epi_tag, cta_shape,
warp_shape, stages, cga_shape, mainloop_schedule,
epi_schedule):
self.gemm_kind = gemm_kind
self.arch = arch
self.act_type = act_type
@ -124,9 +116,7 @@ def tuple_to_cute_shape(shape):
def instantiate_operation(operation):
act_tag = CudaTypeName[operation.act_type]
weight_tag = DataTypeTag[operation.weight_type]
scale_zero_tag = CudaTypeName[operation.scalezero_type]
bias_tag = CudaTypeName[operation.bias_type]
out_tag = CudaTypeName[operation.output_type]
@ -138,18 +128,21 @@ def instantiate_operation(operation):
cute_cga_shape = tuple_to_cute_shape(operation.cga_shape)
kernel_sched = KernelScheduleTag[operation.mainloop_schedule]
# Here, we must append MixedInput depending on the schedule, since we know the types are different.
# It is a work around since the CUTLASS library did not have the MixedInput schedules at the time of writing.
if operation.mainloop_schedule in [
KernelScheduleType.TmaWarpSpecializedCooperative,
KernelScheduleType.TmaWarpSpecializedPingpong,
KernelScheduleType.TmaWarpSpecialized
]:
kernel_sched += "MixedInput"
epi_sched = EpilogueScheduleTag[operation.epi_schedule]
instantiation = f"""
if operation.gemm_kind == GemmKind.Gemm:
if operation.mainloop_schedule in [
KernelScheduleType.TmaWarpSpecializedCooperative,
KernelScheduleType.TmaWarpSpecializedPingpong,
KernelScheduleType.TmaWarpSpecialized
] and DataTypeSize[operation.act_type] != DataTypeSize[
operation.weight_type]:
# Here, we must append MixedInput depending on the schedule, since we know the types are different.
# It is a workaround since the CUTLASS library did not have the MixedInput schedules at the time of writing.
kernel_sched += "MixedInput"
weight_tag = DataTypeTag[operation.weight_type]
instantiation = f"""
template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag},
{quant_op}, {epi_tag},
{cute_cta_shape}, {cute_cga_shape},
@ -157,12 +150,28 @@ template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {s
const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float,
{out_tag}*, int, int, int, const int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
);
"""
elif operation.gemm_kind == GemmKind.Grouped:
# Similar to MixedInput above, we must modify the tags for grouped GEMM, as the CUTLASS library does not have the updated schedules
assert operation.mainloop_schedule in [
KernelScheduleType.TmaWarpSpecializedCooperative,
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum
]
assert operation.epi_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
kernel_sched = kernel_sched.replace("::Kernel", "::KernelGrouped")
epi_sched += "Grouped"
weight_tag = CudaTypeName[operation.weight_type]
instantiation = f"""
template void sm90_generic_moe_gemm_kernelLauncher<{act_tag}, {weight_tag},
{epi_tag}, {cute_cta_shape}, {cute_cga_shape}, false>
(HopperGroupedGemmInput, int, int, cudaStream_t, int*, size_t*);
"""
return instantiation
def get_file_content(launcher_inl_files, operations):
include_list = list()
for file in launcher_inl_files:
include_list.append(f"#include \"{file}\"")
@ -191,11 +200,12 @@ namespace cutlass_kernels
def write_file(launcher_inl_files, operations, output_file):
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, mode="w") as f:
f.write(get_file_content(launcher_inl_files, operations))
def is_op_valid(op):
def is_gemm_op_valid(op):
tile_m, tile_n, _ = op.cta_shape
cga_m, cga_n, _ = op.cga_shape
@ -214,14 +224,41 @@ def is_op_valid(op):
return False
def is_grouped_gemm_op_valid(op):
if not is_gemm_op_valid(op):
return False
if op.epi_tag != TrtLlm_EpilogueTag.epilogue_op_default:
return False
if op.epi_schedule != EpilogueScheduleType.NoSmemWarpSpecialized:
return False
if op.mainloop_schedule not in [
KernelScheduleType.TmaWarpSpecializedCooperative,
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum
]:
return False
return True
def is_op_valid(op):
if op.gemm_kind == GemmKind.Gemm:
return is_gemm_op_valid(op)
if op.gemm_kind == GemmKind.Grouped:
return is_grouped_gemm_op_valid(op)
################################################################################
def generate_sm90_operations():
def generate_sm90_mixed_gemm_operations():
arch = 90
# For legacy reasons, we use unsigned types for fp16 / bf16 activations.
# For legacy reasons, we use unsigned types for the weights. The instantiated template
# will remap those back to the signed type.
# Takes the form (activation_type, weight_type, scalezero_type, bias_type, output_type)
supported_dtypes = [
(DataType.e4m3, DataType.s4, DataType.f16, DataType.f16, DataType.f16),
(DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
(DataType.f16, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
(DataType.bf16, DataType.u4, DataType.bf16, DataType.bf16,
DataType.bf16),
@ -260,11 +297,56 @@ def generate_sm90_operations():
mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative if use_coop else KernelScheduleType.TmaWarpSpecializedPingpong
epi_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative if use_coop else EpilogueScheduleType.TmaWarpSpecialized
operation = TrtLlm_GemmLauncher(GemmKind.Gemm, arch, *dtype_combo, quant_op, epi_tag, cta_shape_mnk, \
warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule)
fpA_intB_operation = TrtLlm_GemmLauncher(GemmKind.Gemm, arch, *dtype_combo, quant_op, epi_tag, cta_shape_mnk, \
warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule)
if is_op_valid(operation):
operations.append(operation)
if is_op_valid(fpA_intB_operation):
operations.append(fpA_intB_operation)
return operations
def generate_sm90_grouped_gemm_operations():
arch = 90
supported_dtypes = [
DataType.f16, DataType.bf16, DataType.f32, DataType.e4m3
]
quant_ops = [TrtLlm_QuantOp.none]
epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
M_TILES = [128] # Currently M tile must be 128 for Grouped GEMM
N_TILES = [16, 32, 64, 128, 256]
cta_shapes_mn = product(M_TILES, N_TILES)
warp_shape = [0, 0, 0] # ignored except for naming
stages = 0 # auto
cga_shapes = product([1, 2], [1, 2], [1])
partial_args = product(supported_dtypes, quant_ops, epi_tags, cta_shapes_mn,
cga_shapes)
operations = list()
for dtype, quant_op, epi_tag, cta_shape_mn, cga_shape in partial_args:
max_k_bits = 128 * 8
cta_shape_k = max_k_bits // DataTypeSize[dtype]
cta_shape_mnk = cta_shape_mn + (cta_shape_k, )
mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative if dtype != DataType.e4m3 else KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum
epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized
moe_gemm_operation = TrtLlm_GemmLauncher(
GemmKind.Grouped, arch, dtype, dtype, dtype, dtype, dtype, quant_op,
epi_tag, cta_shape_mnk, warp_shape, stages, cga_shape,
mainloop_schedule, epi_schedule)
if is_op_valid(moe_gemm_operation):
operations.append(moe_gemm_operation)
return operations
def generate_sm90_operations():
operations = generate_sm90_mixed_gemm_operations()
operations.extend(generate_sm90_grouped_gemm_operations())
return operations
@ -284,7 +366,10 @@ if __name__ == "__main__":
# Get the absolute path of the provided directory
output_dir = os.path.abspath(args.output_dir)
hopper_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"
fpA_intB_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"
moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl"
inl_map = {GemmKind.Gemm: [fpA_intB_inl], GemmKind.Grouped: [moe_gemm_inl]}
# The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads.
# Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve.
@ -298,7 +383,9 @@ if __name__ == "__main__":
file_counter = 1
for key, value in op_groups.items():
gemm_kind, _, _ = key
out_file = os.path.join(
output_dir, f"cutlass_kernel_file_{file_counter}.generated.cu")
write_file([hopper_inl], value, out_file)
output_dir, GemmKindNames[gemm_kind],
f"cutlass_kernel_file_{file_counter}.generated.cu")
write_file(inl_map[gemm_kind], value, out_file)
file_counter += 1

Some files were not shown because too many files have changed in this diff.