Update TensorRT-LLM (#1688)

* Update TensorRT-LLM

---------

Co-authored-by: IbrahimAmin <ibrahimamin532@gmail.com>
Co-authored-by: Fabian Joswig <fjosw@users.noreply.github.com>
Co-authored-by: Pzzzzz <hello-cd.plus@hotmail.com>
Co-authored-by: CoderHam <hemant@cohere.com>
Co-authored-by: Konstantin Lopuhin <kostia.lopuhin@gmail.com>
Kaiyu Xie 2024-05-28 20:07:49 +08:00 committed by GitHub
parent 5d8ca2faf7
commit f430a4b447
529 changed files with 1163954 additions and 9131 deletions

.gitignore
View File

@ -8,7 +8,6 @@ __pycache__/
build*/
*.egg-info/
.coverage
*.csv
*.onnx
tmp/
venv/

View File

@ -1,210 +0,0 @@
# Change Log
## Version 0.8.0
* Model Support
- Phi-1.5/2.0
- Mamba support (see examples/mamba/README.md)
- The support is limited to beam width = 1 and single-node single-GPU
- Nougat support (see examples/multimodal/README.md#nougat)
- Qwen-VL support (see examples/qwenvl/README.md)
- RoBERTa support, thanks to the contribution from @erenup
- Skywork model support
- Add example for multimodal models (BLIP with OPT or T5, LLaVA)
* Features
- Chunked context support (see docs/source/gpt_attention.md#chunked-context)
- LoRA support for C++ runtime (see docs/source/lora.md)
- Medusa decoding support (see examples/medusa/README.md)
- The support is limited to Python runtime for Ampere or newer GPUs with fp16 and bf16 accuracy, and the `temperature` parameter of sampling configuration should be 0
- StreamingLLM support for LLaMA (see docs/source/gpt_attention.md#streamingllm)
- Support for batch manager to return logits from context and/or generation phases
- Include support in the Triton backend
- Support AWQ and GPTQ for QWEN
- Support ReduceScatter plugin
- Support for combining `repetition_penalty` and `presence_penalty` #274
- Support for `frequency_penalty` #275
- OOTB functionality support:
- Baichuan
- InternLM
- Qwen
- BART
- LLaMA
- Support enabling INT4-AWQ along with FP8 KV Cache
- Support BF16 for weight-only plugin
- Baichuan
- P-tuning support
- INT4-AWQ and INT4-GPTQ support
- Decoder iteration-level profiling improvements
- Add `masked_select` and `cumsum` function for modeling
- Smooth Quantization support for ChatGLM2-6B / ChatGLM3-6B / ChatGLM2-6B-32K
- Add Weight-Only Support To Whisper #794, thanks to the contribution from @Eddie-Wang1120
- Support FP16 fMHA on NVIDIA V100 GPU
* API
- Add a set of High-level APIs for end-to-end generation tasks (see examples/high-level-api/README.md)
- **[BREAKING CHANGES]** Migrate models to the new build workflow, including LLaMA, Mistral, Mixtral, InternLM, ChatGLM, Falcon, GPT-J, GPT-NeoX, Medusa, MPT, Baichuan and Phi (see docs/source/checkpoint.md)
- **[BREAKING CHANGES]** Deprecate `LayerNorm` and `RMSNorm` plugins and removed corresponding build parameters
- **[BREAKING CHANGES]** Remove optional parameter `maxNumSequences` for GPT manager
* Bug fixes
- Fix the first token being abnormal issue when `--gather_all_token_logits` is enabled #639
- Fix LLaMA with LoRA enabled build failure #673
- Fix InternLM SmoothQuant build failure #705
- Fix Bloom int8_kv_cache functionality #741
- Fix crash in `gptManagerBenchmark` #649
- Fix Blip2 build error #695
- Add pickle support for `InferenceRequest` #701
- Fix Mixtral-8x7b build failure with custom_all_reduce #825
- Fix INT8 GEMM shape #935
- Minor bug fixes
* Performance
- **[BREAKING CHANGES]** Increase default `freeGpuMemoryFraction` parameter from 0.85 to 0.9 for higher throughput
- **[BREAKING CHANGES]** Disable `enable_trt_overlap` argument for GPT manager by default
- Performance optimization of beam search kernel
- Add bfloat16 and paged kv cache support for optimized generation MQA/GQA kernels
- Custom AllReduce plugins performance optimization
- Top-P sampling performance optimization
- LoRA performance optimization
- Custom allreduce performance optimization by introducing a ping-pong buffer to avoid an extra synchronization cost
- Integrate XQA kernels for GPT-J (beamWidth=4)
* Documentation
- Batch manager arguments documentation updates
- Add documentation for best practices for tuning the performance of TensorRT-LLM (See docs/source/perf_best_practices.md)
- Add documentation for Falcon AWQ support (See examples/falcon/README.md)
- Update to the `docs/source/checkpoint.md` documentation
- Update AWQ INT4 weight only quantization documentation for GPT-J
- Add blog: Speed up inference with SOTA quantization techniques in TRT-LLM
- Refine TensorRT-LLM backend README structure #133
- Typo fix #739
## Versions 0.7.0 / 0.7.1
* Models
- BART and mBART support in encoder-decoder models
- FairSeq Neural Machine Translation (NMT) family
- Mixtral-8x7B model
- Support weight loading for HuggingFace Mixtral model
- OpenAI Whisper
- Mixture of Experts support
- MPT - Int4 AWQ / SmoothQuant support
- Baichuan FP8 quantization support
* Features
- [Preview] Speculative decoding
- Add Python binding for `GptManager`
- Add a Python class `ModelRunnerCpp` that wraps C++ `gptSession`
- System prompt caching
- Enable split-k for weight-only cutlass kernels
- FP8 KV cache support for XQA kernel
- New Python builder API and `trtllm-build` command (already applied to [blip2](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/blip2) and [OPT](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/opt#3-build-tensorrt-engines))
- Support `StoppingCriteria` and `LogitsProcessor` in Python generate API (thanks to the contribution from @zhang-ge-hao)
- fMHA support for chunked attention and paged kv cache
* Bug fixes
- Fix tokenizer usage in quantize.py #288, thanks to the contribution from @0xymoro
- Fix LLaMa with LoRA error #637
- Fix LLaMA GPTQ failure #580
- Fix Python binding for InferenceRequest issue #528
- Fix CodeLlama SQ accuracy issue #453
* Performance
- MMHA optimization for MQA and GQA
- LoRA optimization: cutlass grouped gemm
- Optimize Hopper warp specialized kernels
- Optimize AllReduce for parallel attention on Falcon and GPT-J
- Enable split-k for weight-only cutlass kernel when SM>=75
* Documentation
- Add [documentation for convert/build workflow](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/checkpoint.md)
## Versions 0.6.0 / 0.6.1
* Models
* ChatGLM3
* InternLM (contributed by @wangruohui)
* Mistral 7B (developed in collaboration with Mistral.AI)
* MQA/GQA support to MPT (and GPT) models (contributed by @bheilbrun)
* Qwen (contributed by @Tlntin and @zhaohb)
* Replit Code V-1.5 3B (external contribution)
* T5, mT5, Flan-T5 (Python runtime only)
* Features
* Add runtime statistics related to active requests and KV cache
utilization from the batch manager (see
the [batch manager](docs/source/batch_manager.md) documentation)
* Add `sequence_length` tensor to support proper lengths in beam-search
(when beam-width > 1 - see
[tensorrt_llm/batch_manager/GptManager.h](cpp/include/tensorrt_llm/batch_manager/GptManager.h))
* BF16 support for encoder-decoder models (Python runtime - see
[examples/enc_dec](examples/enc_dec/README.md))
* Improvements to memory utilization (CPU and GPU - including memory
leaks)
* Improved error reporting and memory consumption
* Improved support for stop and bad words
* INT8 SmoothQuant and INT8 KV Cache support for the Baichuan models (see
[examples/baichuan](examples/baichuan/README.md))
* INT4 AWQ Tensor Parallelism support and INT8 KV cache + AWQ/weight-only
support for the GPT-J model (see [examples/gptj](examples/gptj/README.md))
* INT4 AWQ support for the Falcon models
(see [examples/falcon](examples/falcon/README.md))
* LoRA support (functional preview only - limited to the Python runtime,
only QKV support and not optimized in terms of runtime performance) for
the GPT model (see the
[Run LoRA with the Nemo checkpoint](examples/gpt/README.md#Run-LoRA-with-the-Nemo-checkpoint)
in the GPT example)
* Multi-GPU support for encoder-decoder models (Python runtime - see
[examples/enc_dec](examples/enc_dec/README.md))
* New heuristic for launching the Multi-block Masked MHA kernel (similar
to FlashDecoding - see
[decoderMaskedMultiheadAttentionLaunch.h](cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h))
* Prompt-Tuning support for GPT and LLaMA models (see the
[Prompt-tuning](examples/gpt/README.md#Prompt-tuning) Section in the GPT example)
* Performance optimizations in various CUDA kernels
* Possibility to exclude input tokens from the output (see `excludeInputInOutput` in
[`GptManager`](cpp/include/tensorrt_llm/batch_manager/GptManager.h))
* Python binding for the C++ runtime (GptSession - see [`pybind`](cpp/tensorrt_llm/pybind))
* Support for different micro batch sizes for context and generation
phases with pipeline parallelism (see `GptSession::Config::ctxMicroBatchSize` and
`GptSession::Config::genMicroBatchSize` in
[tensorrt_llm/runtime/gptSession.h](cpp/include/tensorrt_llm/runtime/gptSession.h))
* Support for "remove input padding" for encoder-decoder models (see
[examples/enc_dec](examples/enc_dec/README.md))
* Support for context and generation logits (see `mComputeContextLogits` and
`mComputeGenerationLogits` in
[tensorrt_llm/runtime/gptModelConfig.h](cpp/include/tensorrt_llm/runtime/gptModelConfig.h))
* Support for `logProbs` and `cumLogProbs` (see `"output_log_probs"` and
`"cum_log_probs"` in [`GptManager`](cpp/include/tensorrt_llm/batch_manager/GptManager.h))
* Update to CUTLASS 3.x
* Bug fixes
* Fix for ChatGLM2 #93 and #138
* Fix tensor names error "RuntimeError: Tensor names
(`host_max_kv_cache_length`) in engine are not the same as expected in
the main branch" #369
* Fix weights split issue in BLOOM when `world_size = 2` ("array split
does not result in an equal division") #374
* Fix SmoothQuant multi-GPU failure when tensor parallelism is 2 #267
* Fix a crash in GenerationSession if stream keyword argument is not None
#202
* Fix a typo when calling PyNVML API [BUG] code bug #410
* Fix bugs related to the improper management of the `end_id` for various
models [C++ and Python]
* Fix memory leaks [C++ code and Python models]
* Fix the std::alloc error when running the gptManagerBenchmark -- issue
gptManagerBenchmark std::bad_alloc error #66
* Fix a bug in pipeline parallelism when beam-width > 1
* Fix a bug with Llama GPTQ due to improper support of GQA
* Fix issue #88
* Fix an issue with the Huggingface Transformers version #16
* Fix link jump in windows readme.md #30 - by @yuanlehome
* Fix typo in batchScheduler.h #56 - by @eltociear
* Fix typo #58 - by @RichardScottOZ
* Fix Multi-block MMHA: Difference between `max_batch_size` in the engine
builder and `max_num_sequences` in TrtGptModelOptionalParams? #65
* Fix the log message to be more accurate on KV cache #224
* Fix Windows release wheel installation: Failed to install the release
wheel for Windows using pip #261
* Fix missing torch dependencies: [BUG] The batch_manage.a choice error
in --cpp-only when torch's cxx_abi version is different with gcc #151
* Fix linking error during compiling google-test & benchmarks #277
* Fix logits dtype for Baichuan and ChatGLM: segmentation fault caused by
the lack of bfloat16 #335
* Minor bug fixes
## Version 0.5.0
* TensorRT-LLM v0.5.0 is the first public release.

View File

@ -6,9 +6,9 @@ TensorRT-LLM
[![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://nvidia.github.io/TensorRT-LLM/)
[![python](https://img.shields.io/badge/python-3.10.12-green)](https://www.python.org/downloads/release/python-31012/)
[![cuda](https://img.shields.io/badge/cuda-12.4.0-green)](https://developer.nvidia.com/cuda-downloads)
[![cuda](https://img.shields.io/badge/cuda-12.4.1-green)](https://developer.nvidia.com/cuda-downloads)
[![trt](https://img.shields.io/badge/TRT-10.0.1-green)](https://developer.nvidia.com/tensorrt)
[![version](https://img.shields.io/badge/release-0.10.0.dev-green)](./setup.py)
[![version](https://img.shields.io/badge/release-0.11.0.dev-green)](./tensorrt_llm/version.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)
[Architecture](./docs/source/architecture/overview.md)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Results](./docs/source/performance/perf-overview.md)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Examples](./examples/)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Documentation](./docs/source/)

View File

@ -906,7 +906,7 @@ public:
[this](uint64_t requestId, std::list<NamedTensor> const& response_tensors, bool final_response,
std::string const& errMsg)
{ return sendResponse(requestId, response_tensors, final_response, errMsg); },
nullptr, iterationDataCallback, optionalParams, terminateReqId, std::nullopt, excludeInputInOutput);
nullptr, iterationDataCallback, optionalParams, terminateReqId, excludeInputInOutput);
}
~GptServer()

View File

@ -18,6 +18,8 @@ from typing import Optional, Tuple
import click
from pydantic import BaseModel, field_validator
from transformers import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from utils.prepare_real_data import dataset
from utils.prepare_synthetic_data import token_norm_dist
@ -27,10 +29,12 @@ class RootArgs(BaseModel):
output: str
random_seed: int
task_id: int
std_out: bool
rand_task_id: Optional[Tuple[int, int]]
@field_validator('tokenizer')
def get_tokenizer(cls, v: str):
def get_tokenizer(cls,
v: str) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
try:
tokenizer = AutoTokenizer.from_pretrained(v, padding_side='left')
except EnvironmentError as e:
@ -53,6 +57,11 @@ class RootArgs(BaseModel):
type=str,
help="Output json filename.",
default="preprocessed_dataset.json")
@click.option(
"--stdout",
is_flag=True,
help="Print output to stdout with a JSON dataset entry on each line.",
default=False)
@click.option("--random-seed",
required=False,
type=int,
@ -80,6 +89,7 @@ def cli(ctx, **kwargs):
ctx.obj = RootArgs(tokenizer=kwargs['tokenizer'],
output=kwargs['output'],
std_out=kwargs['stdout'],
random_seed=kwargs['random_seed'],
task_id=kwargs['task_id'],
rand_task_id=kwargs['rand_task_id'])

View File

@ -6,7 +6,7 @@ from typing import Optional
import click
from datasets import load_dataset
from pydantic import BaseModel, model_validator
from utils.utils import dataset_dump, get_norm_dist_tokens
from utils.utils import dataset_dump, get_norm_dist_tokens, print_dataset
def validate_output_len_dist(ctx, param, value):
@ -220,11 +220,19 @@ def dataset(root_args, **kwargs):
logging.debug(f"Input lengths: {[len(i) for i in input_ids]}")
logging.debug(f"Output lengths: {output_lens}")
dataset_dump(
input_lens, input_ids, output_lens, task_ids, {
"workload_type": "dataset",
"tokenizer": root_args.tokenizer.__class__.__name__,
"num_requests": len(input_ids),
"max_input_len": max(input_lens),
"max_output_len": max(output_lens)
}, root_args.output)
if not root_args.std_out:
dataset_dump(
input_lens, input_ids, output_lens, task_ids, {
"workload_type": "dataset",
"tokenizer": root_args.tokenizer.__class__.__name__,
"num_requests": len(input_ids),
"max_input_len": max(input_lens),
"max_output_len": max(output_lens)
}, root_args.output)
else:
print_dataset(
task_ids,
input_ids,
output_lens,
tokenizer=None,
)

View File

@ -1,7 +1,8 @@
import random
import click
from utils.utils import dataset_dump, gen_random_tokens, get_norm_dist_tokens
from utils.utils import (dataset_dump, gen_random_tokens, get_norm_dist_tokens,
print_dataset)
@click.command()
@ -55,15 +56,21 @@ def token_norm_dist(root_args, **kwargs):
min_id, max_id = root_args.rand_task_id
task_ids = [random.randint(min_id, max_id) for _ in range(num_reqs)]
dataset_dump(
input_lens, input_ids, output_lens, task_ids, {
"workload_type": "token-norm-dist",
"input_mean": kwargs['input_mean'],
"input_stdev": kwargs['input_stdev'],
"output_mean": kwargs['output_mean'],
"output_stdev": kwargs['output_stdev'],
"num_requests": kwargs['num_requests'],
"tokenize_vocabsize": root_args.tokenizer.vocab_size,
"max_input_len": max_input_len,
"max_output_len": max_output_len
}, root_args.output)
if not root_args.std_out:
dataset_dump(
input_lens, input_ids, output_lens, task_ids, {
"workload_type": "token-norm-dist",
"input_mean": kwargs['input_mean'],
"input_stdev": kwargs['input_stdev'],
"output_mean": kwargs['output_mean'],
"output_stdev": kwargs['output_stdev'],
"num_requests": kwargs['num_requests'],
"tokenize_vocabsize": root_args.tokenizer.vocab_size,
"max_input_len": max_input_len,
"max_output_len": max_output_len
}, root_args.output)
else:
print_dataset(
input_ids,
output_lens,
)

View File

@ -43,7 +43,17 @@ def dataset_dump(input_lens, input_ids, output_lens, task_ids, metadata,
task_id=task_ids[i]))
workload = Workload(metadata=metadata, samples=samples)
with open(output_file, 'w') as f:
json.dump(workload.dict(), f)
json.dump(workload.model_dump(), f)
def print_dataset(input_ids, output_lens):
for i, input_tokens in enumerate(input_ids):
d = {
"task_id": i,
"logits": input_tokens,
"output_tokens": output_lens[i]
}
print(json.dumps(d, separators=(',', ':'), ensure_ascii=False))
def get_list_of_delays(delay_dist, mean_time_bet_reqs, num_reqs, random_seed):

View File

@ -19,12 +19,12 @@ from argparse import ArgumentParser
import torch
# isort: on
from cuda import cuda, cudart
from mpi4py import MPI
from polygraphy.backend.trt import CreateConfig, EngineFromNetwork
import tensorrt_llm as tllm
from tensorrt_llm import Mapping, Tensor
from tensorrt_llm._ipc_utils import peer_access
from tensorrt_llm._utils import OMPI_COMM_TYPE_HOST, mpi_comm
from tensorrt_llm.functional import AllReduceStrategy, allreduce
from tensorrt_llm.plugin.plugin import current_all_reduce_helper
@ -35,11 +35,14 @@ def allreduce_benchmark(dtype: str,
tllm.logger.set_level('error')
world_size = tllm.mpi_world_size()
rank = tllm.mpi_rank()
local_comm = mpi_comm().Split_type(split_type=OMPI_COMM_TYPE_HOST)
local_rank = local_comm.Get_rank()
gpus_per_node = local_comm.Get_size()
torch.cuda.set_device(rank)
cudart.cudaSetDevice(rank)
torch.cuda.set_device(local_rank)
cudart.cudaSetDevice(local_rank)
mapping = Mapping(world_size, rank, world_size, world_size)
mapping = Mapping(world_size, rank, gpus_per_node, world_size)
if world_size == 1:
raise RuntimeError("Benchmark must run with mpi_world_size > 1")
@ -58,8 +61,10 @@ def allreduce_benchmark(dtype: str,
input = torch.ones(size, dtype=torch_dtype, device="cuda")
for strategy in [
AllReduceStrategy.NCCL, AllReduceStrategy.ONESHOT,
AllReduceStrategy.TWOSHOT
AllReduceStrategy.AUTO,
AllReduceStrategy.NCCL,
AllReduceStrategy.ONESHOT,
AllReduceStrategy.TWOSHOT,
]:
builder = tllm.Builder()
net = builder.create_network()
@ -81,9 +86,9 @@ def allreduce_benchmark(dtype: str,
current = allreduce(current, mapping.tp_group, strategy)
output = current.trt_tensor
network.mark_output(output)
output.name = 'output'
output.dtype = tllm.str_dtype_to_trt(dtype)
network.mark_output(output)
build_engine = EngineFromNetwork(
(builder.trt_builder, net.trt_network),
@ -103,7 +108,7 @@ def allreduce_benchmark(dtype: str,
_, stop = cuda.cuEventCreate(0)
runtimes = []
with peer_access(mapping):
MPI.COMM_WORLD.barrier()
tllm.mpi_barrier()
for _ in range(10):
cuda.cuEventRecord(start, stream.cuda_stream)

View File

@ -61,6 +61,7 @@ class BuildConfig:
parallel_attention: bool = None
new_decoder_architecture: bool = None
state_size: int = 0
state_dtype: Optional[str] = None
conv_kernel: int = 0
layer_types: List[str] = field(default_factory=list)
rnn_hidden_size: int = 0
@ -479,6 +480,7 @@ _allowed_configs = {
build_config=BuildConfig(
num_layers=32,
num_heads=32,
num_kv_heads=8,
hidden_size=4096,
vocab_size=32000,
hidden_act='swiglu',
@ -503,7 +505,7 @@ _allowed_configs = {
hidden_act='gelu',
n_positions=1024,
rotary_dim=64,
max_batch_size=256,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
builder_opt=None,
@ -593,7 +595,7 @@ _allowed_configs = {
vocab_size=250880,
hidden_act=None,
n_positions=2048,
max_batch_size=8,
max_batch_size=32,
max_input_len=1024,
max_output_len=1024,
builder_opt=None,
@ -1327,6 +1329,7 @@ _allowed_configs = {
layer_types=["recurrent", "recurrent", "attention"],
rnn_hidden_size=2560,
logits_soft_cap=30.0,
state_dtype="float32",
)),
}

View File

@ -883,6 +883,9 @@ def build_gpt(args):
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.RecurrentGemmaForCausalLM(
config)
tensorrt_llm_model = optimize_model(tensorrt_llm_model,
use_fused_mlp=True,
use_fused_rg_lru=True)
else:
raise Exception(f'Unexpected model: {args.model}')
@ -894,15 +897,15 @@ def build_gpt(args):
# Plugins
if args.mode in ['plugin', 'plugin-ifb']:
network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
network.plugin_config.gpt_attention_plugin = args.dtype
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
network.plugin_config.enable_remove_input_padding()
network.plugin_config.set_lookup_plugin(dtype=args.dtype)
network.plugin_config.set_moe_plugin(dtype=args.dtype)
network.plugin_config.set_mamba_conv1d_plugin(dtype=args.dtype)
network.plugin_config.remove_input_padding = True
network.plugin_config.lookup_plugin = args.dtype
network.plugin_config.moe_plugin = args.dtype
network.plugin_config.mamba_conv1d_plugin = args.dtype
if args.quantization is None or "fp8" not in args.quantization:
network.plugin_config.set_gemm_plugin(dtype=args.dtype)
network.plugin_config.gemm_plugin = args.dtype
# Quantization plugins.
use_smooth_quant = quant_mode.has_act_and_weight_quant()
@ -910,21 +913,19 @@ def build_gpt(args):
if use_smooth_quant:
network.plugin_config.set_smooth_quant_plugins(dtype=args.dtype)
elif use_weight_only:
network.plugin_config.set_weight_only_quant_matmul_plugin(
dtype=args.dtype)
network.plugin_config.weight_only_quant_matmul_plugin = args.dtype
elif family == "llama" and quant_mode.has_act_and_weight_quant():
# RMS norm plugin for SmoothQuant
network.plugin_config.set_rmsnorm_quantization_plugin(
dtype=args.dtype)
network.plugin_config.rmsnorm_quantization_plugin = args.dtype
# Inflight batching
if args.mode == 'plugin-ifb':
network.plugin_config.enable_paged_kv_cache()
network.plugin_config.enable_paged_state()
network.plugin_config.paged_state = True
elif args.mode == 'ootb-except-mha':
network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
network.plugin_config.gpt_attention_plugin = args.dtype
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
network.plugin_config.enable_remove_input_padding()
network.plugin_config.remove_input_padding = True
if world_size > 1:
network.plugin_config.set_nccl_plugin(
@ -1056,12 +1057,12 @@ def build_bert(args):
# Plugins
if args.mode == 'plugin':
network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
network.plugin_config.set_gemm_plugin(dtype=args.dtype)
network.plugin_config.enable_qk_half_accum()
network.plugin_config.bert_attention_plugin = args.dtype
network.plugin_config.gemm_plugin = args.dtype
network.plugin_config.attention_qk_half_accumulation = True
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
elif args.mode == 'ootb-except-mha':
network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
network.plugin_config.bert_attention_plugin = args.dtype
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
if world_size > 1:
@ -1226,13 +1227,29 @@ def enc_dec_build_helper(component, config, args):
if component == 'encoder':
if family == 'whisper':
tllm_model = tensorrt_llm.models.WhisperEncoder(
n_mels=config['n_mels'],
n_ctx=1500, # n_audio_ctx
n_state=config['hidden_size'],
n_head=config['num_heads'],
n_layer=config['num_layers'],
dtype=dtype)
pretrained_config = PretrainedConfig.from_dict({
'architecture':
"WhisperEncoder",
'dtype':
dtype,
'num_hidden_layers':
config['num_layers'],
'num_attention_heads':
config['num_heads'],
'hidden_size':
config['hidden_size'],
'n_mels':
config['n_mels'],
'n_audio_ctx':
1500,
'vocab_size':
config['vocab_size'],
'hidden_act':
"gelu",
'num_languages':
100,
})
tllm_model = tensorrt_llm.models.WhisperEncoder(pretrained_config)
if use_weight_only:
tllm_model = quantize(tllm_model, quant_config)
else:
@ -1387,15 +1404,14 @@ def enc_dec_build_helper(component, config, args):
# Plugins
if args.mode == 'plugin':
network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
network.plugin_config.set_gemm_plugin(dtype=args.dtype)
network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
network.plugin_config.bert_attention_plugin = args.dtype
network.plugin_config.gemm_plugin = args.dtype
network.plugin_config.gpt_attention_plugin = args.dtype
if use_weight_only:
network.plugin_config.set_weight_only_quant_matmul_plugin(
dtype=args.dtype)
network.plugin_config.weight_only_quant_matmul_plugin = args.dtype
elif args.mode == 'ootb-except-mha':
network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
network.plugin_config.bert_attention_plugin = args.dtype
network.plugin_config.gpt_attention_plugin = args.dtype
if world_size > 1:
network.plugin_config.set_nccl_plugin(

View File

@ -103,6 +103,7 @@ class EncDecBenchmark(BaseBenchmark):
remove_input_padding=config["plugin_config"]
["remove_input_padding"],
cross_attention=config["builder_config"]["cross_attention"],
skip_cross_qkv=config["builder_config"]["skip_cross_qkv"],
has_position_embedding=config["builder_config"]
["has_position_embedding"],
has_token_type_embedding=config["builder_config"]

View File

@ -111,6 +111,7 @@ class GPTBenchmark(BaseBenchmark):
self.use_mamba_conv1d_plugin = True
elif args.mode == 'ootb-except-mha':
self.use_gpt_attention_plugin = True
self.remove_input_padding = True
engine_buffer, build_time = build_gpt(args)
self.weights_size_approx = engine_buffer.nbytes
@ -123,6 +124,15 @@ class GPTBenchmark(BaseBenchmark):
if not hasattr(self, 'num_kv_heads') or self.num_kv_heads is None:
self.num_kv_heads = self.num_heads
rnn_config_items = [
'conv_kernel', 'layer_types', 'rnn_hidden_size', 'state_size',
'state_dtype'
]
rnn_configs_kwargs = {}
for item in rnn_config_items:
if hasattr(self, item):
rnn_configs_kwargs[item] = getattr(self, item)
model_config = tensorrt_llm.runtime.ModelConfig(
max_batch_size=self.max_batch_size,
max_beam_width=self.num_beams,
@ -143,13 +153,8 @@ class GPTBenchmark(BaseBenchmark):
tokens_per_block=self.tokens_per_block if hasattr(
self, 'tokens_per_block') else 64,
mamba_conv1d_plugin=self.use_mamba_conv1d_plugin,
conv_kernel=self.conv_kernel if hasattr(self, 'conv_kernel') else 0,
state_size=self.state_size if hasattr(self, 'state_size') else 0,
layer_types=self.layer_types
if hasattr(self, 'layer_types') else [],
rnn_hidden_size=self.rnn_hidden_size if hasattr(
self, 'rnn_hidden_size') else 0,
gpu_weights_percent=list(sorted(gpu_weights_percents))[0],
**rnn_configs_kwargs,
)
if args.model == 'chatglm_6b':
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(

View File

@ -1,6 +1,7 @@
# TensorRT-LLM Benchmarking
**WORK IN PROGRESS**
> [!WARNING] Work in Progress
> This benchmarking suite is currently a work in progress and is prone to large changes.
This package is the official benchmarking suite for TensorRT-LLM. This benchmark will be updated
as development of TensorRT-LLM continues.
@ -9,7 +10,7 @@ as development of TensorRT-LLM continues.
From this folder, run `pip install -r requirements.txt` to install the extra dependencies required for this tool.
### Available Model Options
### Available Build and Benchmark Options
The following model options are available for benchmarking models.
@ -17,7 +18,9 @@ The following model options are available for benchmarking models.
| :- | :-: | :-: | :- |
| `--model` | Y | - | The name of the model to benchmark. |
| `--dtype` | N | `float16` | The datatype of the weights. |
| `--max-batch-size` | Y | - | The batch size to build the engine with for the benchmark. |
| `--kv-dtype` | N | `float16` | The datatype to store the KV Cache in. |
| `--kv-cache-free-gpu-mem-fraction` | N | `0.98` | The percentage of free memory that the KV cache is allowed to occupy. |
| `--quantization` | N | `None` | The quantization algorithm to be used when benchmarking. See the [documentation](https://nvidia.github.io/TensorRT-LLM/precision.html) for more information. |
| `--workspace` | N | `/tmp` | The directory to store benchmarking intermediate files. |
| `--tensor-parallel-size` | N | `1` | Number of tensor parallel shards to run the benchmark with. |
@ -35,7 +38,8 @@ The following model options are available for benchmarking models.
#### Support Quantization Modes
TensorRT-LLM supports a number of quantization modes. For more information about quantization, see the [documentation](https://nvidia.github.io/TensorRT-LLM/precision.html).
TensorRT-LLM supports a number of quantization modes. For more information about quantization, see the
[documentation](https://nvidia.github.io/TensorRT-LLM/precision.html).
- None (no quantization applied)
- W8A16
@ -54,7 +58,7 @@ In order to benchmark a static batch for a network, run a command like the follo
```shell
cd tensorrt_llm_bench/
python benchmark.py --model tiiuae/falcon-7b static --isl 128 --osl 128 --batch 1
python benchmark.py --model tiiuae/falcon-7b static --isl 128 --osl 128 --max-batch-size 1
```
This command line will build a unique engine for the configuration and run the benchmark using
@ -64,18 +68,167 @@ the `gptSessionBenchmark` binary. You need to build the TensorRT-LLM wheel with
python3 ./scripts/build_wheel.py --benchmarks <other options>
```
The complete list of arguments are given here:
If you've already compiled the wheel without benchmarks, you can build the benchmarking binaries with the following after the fact:
```shell
pushd cpp/build/
make -j benchmarks
popd
```
The complete list of arguments for static benchmarking are as follows:
| Option | Required | Default | Description |
| :- | :-: | :-: | :- |
| `--batch` | Y | - | The batch size to benchmark. |
| `--isl` | Y | - | The input sequence length to pass in during benchmark. |
| `--osl` | Y | - | The output sequence length to generate in the benchmark. |
| `--gpt-session-path` | N | `../../cpp/build/benchmarks/gptSessionBenchmark` | The path to the built gptSessionBenchmark binary. |
| `--max-tokens-in-kv-cache` | N | `None` | The maximum number of tokens to store in the KV Cache during benchmarking. |
| `--kv-cache-mem-percent` | N | `0.9` | The percentage of free memory that the KV cache is allowed to occupy. |
| `--warm-up-runs` | N | `2` | The number of warm up runs to run before benchmarking actual results. |
| `--num-runs` | N | `10` | The number of runs to generate benchmarking results from. |
| `--duration` | N | `60` | The minimum iteration time, in seconds, to measure. |
> [!WARNING]
> `gptSession` will be deprecated for the 1.0 release of TensorRT-LLM. This command line will change to match, and the benchmarks will be updated accordingly.
## Inflight Benchmarking with a Dataset
This section covers how to benchmark TensorRT-LLM using inflight batching.
### Workflow
The workflow for inflight batching is slightly different than the [static scenario](#static-benchmarking-a-network) as it requires a workload of requests instead of a single static batch. The following is the workflow for benchmarking using inflight batching:
1. Prepare a dataset to drive the inflight batching benchmark.
2. Run the `inflight` benchmarking subcommand and provide the dataset from step 1.
#### Preparing a Dataset
The inflight benchmark utilizes a fixed JSON schema so that it is simple and
straightforward to specify requests. The schema is defined as follows:
| Key | Required | Type | Description |
| :- | :-: | :-: | :- |
| `task_id`| Y | String | Unique identifier for the request. |
| `prompt` | N* | String | Input text for a generation request. |
| `logits` | N* | List[Integer] | List of logits that make up the request prompt. |
| `output_tokens` | Y | Integer | Number of generated tokens for this request. |
> [!NOTE] *About `prompt` and `logits`
> While specifying both `prompt` and `logits` is allowed, at least one of them is required.
> If `logits` is specified, the `prompt` entry is ignored for request generation.
Examples of valid entries for the inflight benchmark are:
- Entries with a human-readable prompt and no logits.
```json
{"task_id": 1, "prompt": "Generate an infinite response to the following: This is the song that never ends, it goes on and on my friend.", "output_tokens": 1000}
{"task_id": 2, "prompt": "Generate an infinite response to the following: Na, na, na, na", "output_tokens": 1000}
```
- Entries which contain logits.
```json
{"task_id":0,"logits":[863,22056,25603,11943,8932,13195,3132,25032,21747,22213],"output_tokens":128}
{"task_id":1,"logits":[14480,13598,15585,6591,1252,8259,30990,26778,7063,30065,21764,11023,1418],"output_tokens":128}
```
> [!INFO] A whole entry is on a line!
> To make the passing of data simpler, a complete JSON entry is on each line so that the benchmarker
> can simply read a line and assume a complete entry. When creating a dataset, be sure that a complete
> JSON entry is on every line.
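As a quick reference, the sketch below shows one way to emit entries in this schema programmatically. It is a minimal illustration only (the file name and prompt text are made up); it simply mirrors the JSON Lines format the benchmarker expects.

```python
import json

# Illustrative entries only: one uses a human-readable prompt, the other precomputed token ids.
entries = [
    {"task_id": 0, "prompt": "Generate a short story about benchmarking.", "output_tokens": 128},
    {"task_id": 1, "logits": [863, 22056, 25603, 11943], "output_tokens": 128},
]

# Write one complete JSON entry per line (JSON Lines), as the benchmarker expects.
with open("my_dataset.txt", "w") as f:
    for entry in entries:
        f.write(json.dumps(entry, separators=(",", ":"), ensure_ascii=False) + "\n")
```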
#### Using `prepare_dataset` to Create Synthetic Datasets
In order to prepare a synthetic dataset, you can use the provided script in the `benchmarks/cpp`
directory. For example, to generate a synthetic dataset of 1000 requests with a uniform ISL/OSL of
128/128 for [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b), simply run:
```shell
benchmarks/cpp/prepare_dataset.py --tokenizer meta-llama/Llama-2-7b-hf token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000 --stdout
```
You can pipe the above command to a file to reuse the same dataset, or simply pipe its output to the
benchmark script (example below).
### Running a Dataset with the Benchmarker
Once you've generated a dataset (see [above](#preparing-a-dataset)), you can run the benchmarker
in one of two ways:
```shell
benchmarks/suite/tensorrt_llm_bench/benchmark.py --model $HF_MODEL_NAME --max-batch-size $BATCH_SIZE < $DATASET_PATH
```
> [!INFO] Alternative to piping.
> There is also a `--dataset` option for `benchmark.py` that can be used instead of piping a file.
or
```shell
benchmarks/cpp/prepare_dataset.py --tokenizer $HF_MODEL_NAME --input-mean $ISL --output-mean $OSL --num-requests $NUM_REQUESTS --stdout | benchmarks/suite/tensorrt_llm_bench/benchmark.py --model $HF_MODEL_NAME --max-batch-size $BATCH_SIZE --request-rate $REQUEST_RATE
```
#### How the Benchmarker Works
The benchmarker will read in a data file or standard input (stdin) as a stream where a single line contains
a complete JSON request entry. The process that the benchmarker follows is:
1. Iterate over all input requests. If `logits` is specified, construct the request using the specified
list of logits. Otherwise, tokenize the `prompt` with the tokenizer specified by `--model $HF_MODEL_NAME` (a simplified sketch follows this list).
2. Build the TensorRT-LLM engine.
3. Submit the dataset to the TensorRT-LLM `Executor` API at the request rate specified by `--request-rate $REQUEST_RATE`.
4. Wait for all requests to return, compute statistics, then report out results.
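Conceptually, steps 1 and 3 boil down to the loop sketched below. This is a simplified illustration, not the benchmarker's actual implementation (which uses the `Executor` bindings); the `submit` stub, the request rate, and the model name are placeholder assumptions.

```python
import json
import sys
import time

from transformers import AutoTokenizer

REQUEST_RATE = 512                         # illustrative; corresponds to --request-rate
MODEL_NAME = "meta-llama/Llama-2-7b-hf"    # illustrative; corresponds to --model

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")


def submit(token_ids, output_tokens):
    """Placeholder for handing a request to the TensorRT-LLM Executor API."""


delay = 1.0 / REQUEST_RATE
for line in sys.stdin:
    entry = json.loads(line)                    # one complete JSON entry per line
    token_ids = entry.get("logits")             # prefer precomputed token ids if present
    if token_ids is None:
        token_ids = tokenizer(entry["prompt"])["input_ids"]
    submit(token_ids, entry["output_tokens"])   # step 3: enqueue the request
    time.sleep(delay)                           # pace submissions at the request rate
```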
When the benchmark runs successfully, you will see a report out of the run similar to the following:
```
[RANK 0] Submitting requests...
[RANK 0] Completed request submission.
[RANK 0] Calculating results.
[RANK 0] Reporting...
[RANK 0] JSON: {'benchmark_cmd': '', 'binary': '', 'build_cmd': 'trtllm-build --output_dir /tmp/meta-llama/llama-2-7b-hf --model_config /tmp/generated_config.json --workers 1 --max_batch_size 1024 --max_input_len 128 --max_output_len 128 --max_num_tokens 8000 --context_fmha enable --gpt_attention_plugin float16 --paged_kv_cache enable --multiple_profiles enable --gemm_plugin float16', 'first_token_latency': 0.0, 'inflight_batching': True, 'kv_mem_fraction': 0.98, 'latency_units': 'ms', 'max_batch_size': 1024, 'max_tokens': 8000, 'model': 'meta-llama/Llama-2-7b-hf', 'peak_gpu_mem_units': 'GB', 'peak_gpu_mem': 0.0, 'scheduler': 'Max Utilization', 'throughput_units': 'tokens/second', 'throughput': 17634.422523488243, 'time_per_output_token': 0.0, 'total_input_tokens': 128000, 'total_latency': 7.258530855178833, 'total_output_tokens': 128000}
===========================================================
= METADATA
===========================================================
Model: meta-llama/Llama-2-7b-hf
TP Size: 1
PP Size: 1
Scheduling Policy: Max Utilization
In-flight Batcher?: True
Dtype: float16
KV Cache Dtype: FP8
Quantization: FP8
KV Memory Percentage: 98.0%
===========================================================
= ENGINE DETAILS
===========================================================
Engine Directory: /tmp/meta-llama/llama-2-7b-hf
Max Batch Size: 1024
Total Input Length: 128000
Total Output Length: 128000
Max Tokens: 8000
===========================================================
= STATISTICS
===========================================================
Throughput (tokens/second): 17634.422523488243
Total Latency (ms): 7258.5309
First Token Latency (ms): 0.0
Token-to-token Latency (ms): 0.0
Peak GPU Memory Usage (GB): 0.0
===========================================================
= COMMANDS
===========================================================
Build: trtllm-build --output_dir /tmp/meta-llama/llama-2-7b-hf --model_config /tmp/generated_config.json --workers 1 --max_batch_size 1024 --max_input_len 128 --max_output_len 128 --max_num_tokens 8000 --context_fmha enable --gpt_attention_plugin float16 --paged_kv_cache enable --multiple_profiles enable --gemm_plugin float16
Benchmark:
[RANK 0] Terminating.
```
> [!WARNING] Some statistics are not reported.
> Some statistics in the summary are reported as 0.0; these statistics
> are not currently available.
That's it! -- you've successfully benchmarked TensorRT-LLM!

View File

@ -2,9 +2,9 @@ from pathlib import Path
from typing import get_args
import click
from ifb import executor_benchmark
from static import static_benchmark
from utils import (VALID_CACHE_DTYPES, VALID_COMPUTE_DTYPES, VALID_MODELS,
VALID_QUANT_ALGOS)
from utils import VALID_CACHE_DTYPES, VALID_COMPUTE_DTYPES, VALID_QUANT_ALGOS
from utils.dataclasses import BenchmarkConfig
@ -13,9 +13,16 @@ from utils.dataclasses import BenchmarkConfig
"--model",
"-m",
required=True,
type=click.Choice(tuple(get_args(VALID_MODELS))),
type=str,
help="The Huggingface name of the model to benchmark.",
)
@click.option(
"--max-batch-size",
hidden=True,
default=0,
type=int,
help="Maximum batch size to build the benchmark engine with.",
)
@click.option(
"--kv-dtype",
type=click.Choice(tuple(get_args(VALID_CACHE_DTYPES))),
@ -38,9 +45,11 @@ from utils.dataclasses import BenchmarkConfig
"documentations for more information.\n"
" - https://nvidia.github.io/TensorRT-LLM/precision.html"
" - https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/quantization-in-TRT-LLM.md"
))
),
)
@click.option(
"--workspace",
"-w",
required=False,
type=click.Path(writable=True, readable=True),
default="/tmp",
@ -62,26 +71,46 @@ from utils.dataclasses import BenchmarkConfig
required=False,
help="Number of pipeline parallel shards to run the benchmark with.",
)
@click.option(
"--kv-cache-free-gpu-mem-fraction",
"-kv-mem",
type=float,
default=0.98,
help="The percentage of free memory that the KV Cache is allowed to occupy.",
)
@click.option(
"--build-opts",
type=str,
default="",
required=False,
hidden=True,
help="Passthrough options for trtllm-build to fine-tuning build commands.")
@click.pass_context
def benchmark(
ctx,
model: str,
max_batch_size: int,
workspace: Path,
dtype: str,
kv_dtype: str,
quantization: str,
tensor_parallel_size: int,
pipeline_parallel_size: int,
kv_cache_free_gpu_mem_fraction: float,
build_opts: str,
):
"""Utility for using TRT-LLM for benchmarking networks from Huggingface."""
ctx.obj = BenchmarkConfig(
model=model,
max_batch_size=max_batch_size,
workspace=Path(workspace),
dtype=dtype,
cache_dtype=kv_dtype,
quantization=quantization,
tensor_parallel=tensor_parallel_size,
pipeline_parallel=pipeline_parallel_size,
kv_cache_mem_percentage=kv_cache_free_gpu_mem_fraction,
build_overrides=build_opts.split(),
)
# Create the workspace where we plan to store intermediate files.
@ -90,6 +119,7 @@ def benchmark(
# Add nested subcommands to main benchmark CLI.
benchmark.add_command(static_benchmark)
benchmark.add_command(executor_benchmark)
if __name__ == "__main__":
benchmark()

View File

@ -0,0 +1,30 @@
from typing import List, Protocol
from utils.dataclasses import BenchmarkResults, InferenceRequest
class Benchmarker(Protocol):
"""Protocol for defining benchmarking classes for building/benchmarking."""
def build(self) -> None:
"""Build a model to be benchmarked."""
...
def benchmark(self) -> BenchmarkResults:
"""Benchmark the constructed model container by a benchmarker."""
...
class DatasetBenchmarker(Protocol):
def benchmark_dataset(self,
dataset: List[InferenceRequest]) -> BenchmarkResults:
"""_summary_
Args:
dataset (List[InferenceRequest]): List of inference requests to benchmark.
Returns:
BenchmarkResults: The results of the benchmark run.
"""
...

View File

@ -0,0 +1,146 @@
from datetime import timedelta
from time import sleep, time
from typing import List
from mpi4py.MPI import COMM_WORLD
from transformers import PreTrainedTokenizer
from utils.dataclasses import BenchmarkConfig, BenchmarkResults
from utils.enums import IFBSchedulingPolicy, ResultsSchedulingPolicy
from tensorrt_llm.bindings.executor import (Executor, ExecutorConfig,
KvCacheConfig, ModelType,
OutputConfig, Request,
SchedulerConfig)
from . import InferenceRequest
class PybindExecutorBenchmarker:
"""Utility class for running inflight benchmarks via the Executor API."""
def __init__(
self,
config: BenchmarkConfig,
):
"""Initialize a gptSessionBenchmark instance.
Args:
config (BenchmarkConfig): Benchmark configuration for build/run.
"""
self.config: BenchmarkConfig = config
@staticmethod
def get_request(request: InferenceRequest,
tokenizer: PreTrainedTokenizer) -> Request:
return Request(
input_token_ids=request.logits,
max_new_tokens=request.output_tokens,
stop_words=[],
bad_words=[],
streaming=False,
output_config=OutputConfig(exclude_input_from_output=True),
pad_id=tokenizer.pad_token_id,
end_id=tokenizer.eos_token_id,
)
def initialize_executor(self) -> Executor:
"""
Initialize an Executor instance.
Returns:
Executor: An instance of a TensorRT-LLM Executor.
"""
policy = IFBSchedulingPolicy(self.config.scheduling_policy).value
executor_config: ExecutorConfig = ExecutorConfig(
max_beam_width=1,
enable_chunked_context=self.config.chunking,
scheduler_config=SchedulerConfig(
capacity_scheduler_policy=policy, ),
kv_cache_config=KvCacheConfig(
free_gpu_memory_fraction=self.config.kv_cache_mem_percentage, ),
)
executor: Executor = Executor(
model_path=self.config.engine_path,
model_type=ModelType.DECODER_ONLY,
executor_config=executor_config,
)
return executor
def benchmark_dataset(self, rate: int,
dataset: List[InferenceRequest]) -> BenchmarkResults:
"""Benchmark the Executor Pybind interface.
Args:
dataset (List[InferenceRequest]): List of inference requests to
benchmark with.
Returns:
BenchmarkResults: Final results from running the specified dataset.
"""
request_ids = []
num_finished = 0
num_errored = 0
num_input_tokens = 0
num_output_tokens = 0
delay = 1.0 / float(rate)
last_request = len(dataset) - 1
bench_result = None
executor = self.initialize_executor()
if executor.can_enqueue_requests():
print(f"[RANK {COMM_WORLD.rank}] Submitting requests...")
start = time()
for i, request in enumerate(dataset):
sleep_time = delay if i != last_request else 0
request_ids.append(executor.enqueue_request(request))
num_input_tokens += len(request.input_token_ids)
sleep(sleep_time)
print(f"[RANK {COMM_WORLD.rank}] Completed request submission.")
while num_finished <= last_request:
responses = executor.await_responses(timeout=timedelta(
milliseconds=1))
for response in responses:
has_error = response.has_error()
num_finished += 1
num_errored += 1 if has_error else 0
if not has_error:
result = response.result
for out_tokens in result.output_token_ids:
num_output_tokens += len(out_tokens)
end = time()
print(f"[RANK {COMM_WORLD.rank}] Calculating results.")
e2e_time = end - start
e2e_time * 1000.0
policy = ResultsSchedulingPolicy(
IFBSchedulingPolicy(self.config.scheduling_policy).value)
bench_result = BenchmarkResults(
model=self.config.model,
dtype=self.config.dtype.value,
quantization=str(self.config.quantization.value),
max_batch_size=self.config.max_batch_size,
total_input_tokens=num_input_tokens,
total_output_tokens=num_output_tokens,
tp_size=self.config.tensor_parallel,
pp_size=self.config.pipeline_parallel,
kv_mem_fraction=self.config.kv_cache_mem_percentage,
scheduler=policy.value,
max_tokens=self.config.max_tokens,
inflight_batching=True,
total_latency=e2e_time,
first_token_latency=0,
time_per_output_token=0,
latency_units="ms",
throughput=num_output_tokens / e2e_time,
throughput_units="tokens/second",
peak_gpu_mem=0.0,
peak_gpu_mem_units="GB",
build_cmd="",
benchmark_cmd="",
)
return bench_result

View File

@ -0,0 +1,209 @@
from pathlib import Path
from subprocess import CompletedProcess
from typing import Dict, List
from utils import command_logger, process_error_check, run_process
from utils.dataclasses import BenchmarkConfig, BenchmarkResults
from utils.trtllm_config import TRTLLMConfig
class gptSessionBenchmarker:
"""Utility class for running static benchmarks with gptSessionBenchmark."""
def __init__(
self,
config: BenchmarkConfig,
benchmark_binary: Path,
batch_size: int,
isl: int,
osl: int,
warm_up_runs: int,
num_runs: int,
duration: int,
kv_cache_free_fraction: float = .9,
):
"""Initialize a gptSessionBenchmark instance.
Args:
config (BenchmarkConfig): Benchmark configuration for build/run.
benchmark_binary (Path): Path to the benchmarking binary.
batch_size (int): Batch size to configure the build with.
isl (int): Input sequence length to configure the build with.
osl (int): Output sequence length to configure the build with.
kv_cache_free_fraction (float, optional): The amount of remaining
GPU memory after model loading to save for the KV Cache. Defaults
to .9.
"""
self.config: BenchmarkConfig = config
self.gpt_session_path = Path(benchmark_binary).absolute()
self.batch_size = batch_size
self.input_length = isl
self.output_length = osl
self.warm_up = warm_up_runs
self.num_runs = num_runs
self.duration = duration
self.kv_cache_mem = kv_cache_free_fraction
self.result = None
def get_build_command(self) -> List[str]:
"""Build the engine command for TRT-LLM.
Returns:
List[str]: A list of command line arguments to run a build command.
"""
model = self.config.model
tp = self.config.tensor_parallel
pp = self.config.pipeline_parallel
dtype = self.config.dtype.value
kv_dtype = self.config.cache_dtype
quant_algo = self.config.quantization.value
output_dir = self.config.engine_path
max_batch_size = self.batch_size
max_isl = self.input_length
max_osl = self.output_length
workspace = self.config.workspace
# Generate the TRT-LLM Configuration file using the dataclass
# NOTE: This method does not use weights.
trtllm_config = TRTLLMConfig.from_hf(model, tp, pp, dtype, quant_algo,
kv_dtype.value)
# Write the generated configuration file to the benchmark workspace.
trtllm_config.to_json(workspace)
# Return the full command for building TRT-LLM via subprocess call.
cmd = [
"trtllm-build",
"--output_dir",
output_dir,
"--model_config",
Path(workspace, "generated_config.json"),
"--workers",
self.config.world_size,
# Define the maximums the engine can accept.
"--max_batch_size",
max_batch_size,
"--max_input_len",
max_isl,
"--max_output_len",
max_osl,
"--context_fmha",
"enable",
# Set the attention plugin data type.
"--gpt_attention_plugin",
dtype,
# Disable paged cache since we aren't batching on the fly.
"--paged_kv_cache",
"disable",
] + kv_dtype.get_build_options(dtype)
return [str(arg) for arg in cmd]
@command_logger(prefix="BUILD COMMAND: ")
@process_error_check
def _run_build(self, cmd: List[str]) -> CompletedProcess:
"""Wrapper for calling the build for TRT-LLM.
Purpose of this wrapper is so that we can decorate it/log it.
Args:
cmd (List[str]): List of command line arguments for running.
Returns:
CompletedProcess: Completed process information for parsing and
reporting.
"""
return run_process(
cmd,
self.config.workspace,
)
def build(self) -> None:
"""Build the engine for benchmarking."""
self._run_build(self.get_build_command())
@command_logger(prefix="BENCHMARK COMMAND: ")
@process_error_check
def _run_benchmark(self, cmd: List[str]) -> CompletedProcess:
"""Run the benchmark command in the configured workspace.
Args:
cmd (List[str]): List of command line arguments to run via
subprocess.
Returns:
CompletedProcess: Completed process information for reporting.
"""
return run_process(cmd, run_dir=self.config.workspace, use_environ=True)
@staticmethod
def parse_benchmark_result(benchmark_line: str) -> Dict[str, str]:
pass
def benchmark(self):
"""Benchmarks a TRT-LLM for a configured instance."""
# Compile the command for running
cmd = [
"mpirun",
"-allow-run-as-root",
"-n",
self.config.world_size,
self.gpt_session_path,
"--engine_dir",
self.config.engine_path,
"--batch_size",
self.batch_size,
"--log_level",
"info",
"--kv_cache_free_gpu_mem_fraction",
self.kv_cache_mem,
"--beam_width",
"1",
"--warm_up",
self.warm_up,
"--num_runs",
self.num_runs,
"--duration",
self.duration,
"--input_output_len",
f"{self.input_length},{self.output_length};{self.input_length},1",
]
cmd = [str(arg) for arg in cmd]
# Run the benchmark using the provided gptSession benchmark binary.
bench_return = self._run_benchmark(cmd)
results = [
x.split(" ") for x in bench_return.stdout.split("\n")
if "[BENCHMARK]" in x
]
ttft = float(results[1][8])
gen_time = float(results[0][8]) - ttft
total_out = int(results[0][2]) * int(results[0][6])
total_in = int(results[0][2]) * int(results[0][4])
batch_size = int(results[0][2])
bench_result = BenchmarkResults(
model=self.config.model,
dtype=self.config.dtype.value,
quantization=str(self.config.quantization.value),
max_batch_size=batch_size,
total_input_tokens=total_in,
total_output_tokens=total_out,
tp_size=self.config.tensor_parallel,
pp_size=self.config.pipeline_parallel,
kv_mem_fraction=self.kv_cache_mem,
scheduler="Static",
inflight_batching=False,
total_latency=results[0][8],
first_token_latency=ttft,
time_per_output_token=gen_time / (total_out - batch_size),
latency_units="ms",
throughput=results[0][10],
throughput_units="tokens/second",
peak_gpu_mem=results[0][16],
peak_gpu_mem_units="GB",
binary=str(self.gpt_session_path),
build_cmd=" ".join(self.get_build_command()),
benchmark_cmd=" ".join(cmd))
return bench_result

View File

@ -0,0 +1,338 @@
import json
import os
import subprocess
import sys
from functools import partial
from pathlib import Path
from typing import List, TextIO, Tuple
import click
from benchmarkers.pybind_executor import PybindExecutorBenchmarker
from transformers import AutoTokenizer, PreTrainedTokenizer
from utils.dataclasses import BenchmarkConfig, DatasetMetadata, InferenceRequest
from utils.trtllm_config import TRTLLMConfig
from tensorrt_llm.logger import logger
def create_dataset_from_stream(
tokenizer: PreTrainedTokenizer,
max_input_length: int = 0,
max_output_length: int = 0,
stream: TextIO = sys.stdin,
) -> Tuple[DatasetMetadata, List[InferenceRequest]]:
"""Generate metadata and a list of requests to drive benchmarking.
Args:
tokenizer (PreTrainedTokenizer): HuggingFace tokenizer.
max_input_length (int): Maximum input length to cap prompts to.
Returns:
DatasetMetadata: Dataclass of dataset statistics.
List[InferenceRequest]: A list of inference requests for benchmarking.
"""
# Initialize dataset list, and metadata tracking variables.
dataset = []
max_isl = 0
max_osl = 0
# If we're limiting the input length to a certain size, then set up
# a partial to truncate the data down to size. Otherwise, just use the
# unmodified tokenizer callable.
tokenize = (partial(
tokenizer,
padding="max_length",
max_length=max_input_length,
truncation=True,
) if max_input_length > 0 else tokenizer)
# If we need to limit the output length, fill in a partial callable
# for max, otherwise a lambda that just returns x with no bounds.
output_limiter = (partial(max, max_output_length)
if max_output_length > 0 else lambda x: x)
# For each line in the standard input, parse out the JSON string we expect
# to see.
# Note the := walrus -- we're assigning and checking the condition.
while line := stream.readline():
# We expect the data to come in as a JSON string.
# For example:
# {"prompt": "Generate an infinite response to the following: There once was a man who.", "output_tokens": 1000}
# Each line should be a complete JSON dictionary with no indentation
# or newline characters.
data = json.loads(line)
logits = data.get("logits", None)
prompt = data.get("prompt", None)
task_id = data["task_id"]
osl = data["output_tokens"]
# If the request comes in with logits, just use the provided.
# Otherwise we need to tokenize it.
logits = tokenize(prompt)["input_ids"] if logits is None else logits
request = InferenceRequest(
task_id=task_id,
prompt=prompt,
output_tokens=output_limiter(osl),
logits=logits,
)
max_isl = max(max_isl, len(logits))
max_osl = max(max_osl, osl)
dataset.append(request)
# Fill in basic dataset metrics here
# TODO: Maybe fill this out to be more complete?
metadata = DatasetMetadata(
max_isl=max_isl,
max_osl=max_osl,
num_requests=len(dataset),
)
return metadata, dataset
def initialize_tokenizer(model_name: str) -> PreTrainedTokenizer:
"""Initialize a tokenizer.
Args:
model_name (str): The name of the HuggingFace model to pull a
tokenizer from.
Returns:
PreTrainedTokenizer: An initialized HuggingFace tokenizer.
"""
# Initialize the tokenizer specific to the model that we are planning
# to benchmark.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
if tokenizer.pad_token_id is None:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
return tokenizer
def get_trtllm_build_command(benchmark_cfg: BenchmarkConfig) -> List[str]:
model = benchmark_cfg.model
tp = benchmark_cfg.tensor_parallel
pp = benchmark_cfg.pipeline_parallel
dtype = benchmark_cfg.dtype.value
kv_dtype = benchmark_cfg.cache_dtype
quant_algo = benchmark_cfg.quantization.value
output_dir = benchmark_cfg.engine_path
max_batch_size = benchmark_cfg.max_batch_size
max_isl = benchmark_cfg.engine_isl
max_osl = benchmark_cfg.engine_osl
max_tokens = benchmark_cfg.max_tokens
workspace = benchmark_cfg.workspace
# Generate the TRT-LLM Configuration file using the dataclass
# NOTE: This method does not use weights.
trtllm_config = TRTLLMConfig.from_hf(model, tp, pp, dtype, quant_algo,
kv_dtype.value)
# Write the generated configuration file to the benchmark workspace.
trtllm_config.to_json(workspace)
# Return the full command for building TRT-LLM via subprocess call.
cmd = [
"trtllm-build",
"--output_dir",
output_dir,
"--model_config",
Path(workspace, "generated_config.json"),
"--workers",
benchmark_cfg.world_size,
"--max_input_len",
max_isl,
"--max_output_len",
max_osl,
"--context_fmha",
"enable",
# Set the attention plugin data type.
"--gpt_attention_plugin",
dtype,
# Enable paged KV Cache for IFB.
"--paged_kv_cache",
"enable",
] + kv_dtype.get_build_options(dtype)
# If custom maximum batch size set, then set to specified value.
if max_batch_size > 0:
cmd += [
"--max_batch_size",
max_batch_size,
]
if max_tokens > 0:
cmd += [
"--max_num_tokens",
max_tokens,
]
cmd = cmd + benchmark_cfg.build_overrides
return cmd
@click.command("inflight")
@click.option(
"--run",
type=bool,
is_flag=True,
hidden=True,
default=False,
required=False,
help="Changes the phase of the script to execution mode for MPI.",
)
@click.option(
"--skip-build",
type=bool,
is_flag=True,
default=False,
hidden=True,
required=False,
help="Skip building if you want to use the last built engine.",
)
@click.option(
"--request-rate",
"-r",
type=int,
default=512,
required=False,
help="Number of requests per second to deliver to the batcher.",
)
@click.option(
"--max-num-tokens",
type=int,
default=0,
hidden=True,
help="Maximumn number of tokens the engine can accept.",
)
@click.option(
"--scheduling-policy",
type=click.Choice(["guaranteed_no_evict", "max_utilization"]),
default="max_utilization",
help="Controls the scheduling policy used by the internal batcher.",
)
@click.option(
"--dataset",
type=click.Path(exists=True,
readable=True,
path_type=Path,
resolve_path=True),
default=None,
required=False,
help="Pass in a dataset file for parsing instead of stdin.",
)
@click.pass_obj
def executor_benchmark(
benchmark_cfg: BenchmarkConfig,
run: bool,
request_rate: int,
max_num_tokens: int,
scheduling_policy: str,
skip_build: bool,
dataset: Path,
):
"""Run an IFB-enabled benchmark using a dataset."""
# Initialize the tokenizer and generate the dataset
logger.set_level("info")
DATASET_PATH = Path(benchmark_cfg.workspace, "tokenized_dataset.txt")
TOKENIZER = initialize_tokenizer(benchmark_cfg.model)
final_dataset = []
benchmark_cfg.max_tokens = max_num_tokens
benchmark_cfg.scheduling_policy = scheduling_policy
if not run:
try:
stream = sys.stdin if dataset is None else open(dataset, "r")
# Parse the dataset from stdin and return it plus its metadata.
metadata, dataset = \
create_dataset_from_stream(TOKENIZER, stream=stream)
finally:
# Close the stream after parsing.
stream.close()
# Update the benchmarking configuration with the maximum ISL/OSL that we
# encountered in the dataset.
benchmark_cfg.engine_isl = metadata.max_isl
benchmark_cfg.engine_osl = metadata.max_osl
# Build engine
logger.info("Building engine...")
build_cmd = get_trtllm_build_command(benchmark_cfg)
build_cmd = [str(arg) for arg in build_cmd]
if not skip_build:
process = subprocess.run(build_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=benchmark_cfg.workspace)
logger.info(f"BUILD CMD: {' '.join(process.args)}")
# If the build failed, raise an exception.
if process.returncode != 0:
logger.error(process.stderr.decode())
raise RuntimeError(
"TensorRT-LLM build process failed. Command used:\n"
f"{' '.join(process.args)}\n", )
with open(DATASET_PATH, "w") as ds_out:
while dataset:
request = dataset.pop()
ds_out.write(f"{request.model_dump_json()}\n")
del request
# Launch via a subprocess with MPI
# This script has two modes: the initial launch + parsing phase, and the
# run mode, where we relaunch the script under MPI to run the benchmark.
logger.info("Launching benchmark...")
bench_cmd = \
["mpirun", "-n", f"{benchmark_cfg.world_size}", "python"] + \
sys.argv + ["--run"]
process = subprocess.Popen(
bench_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env=os.environ,
)
stdout, _ = process.communicate()
logger.info("Benchmark complete.")
logger.info(stdout.decode("ascii"))
else:
from mpi4py.MPI import COMM_WORLD
if COMM_WORLD.Get_rank() == 0:
logger.info(f"[RANK {COMM_WORLD.rank}] Loading dataset...")
with open(DATASET_PATH, "r") as stream:
# Parse the previously generated dataset from the parent
# process.
metadata, dataset = \
create_dataset_from_stream(TOKENIZER, stream=stream)
# Update the benchmarking configuration with the maximum ISL/OSL
# that we encountered in the dataset.
benchmark_cfg.engine_isl = metadata.max_isl
benchmark_cfg.engine_osl = metadata.max_osl
# Parse the dataset into the Executor Request type.
logger.info("Preparing dataset...")
while dataset:
entry = dataset.pop()
request = PybindExecutorBenchmarker.get_request(
entry, TOKENIZER)
final_dataset.append(request)
del entry
logger.info("Dataset prepared.")
logger.info(f"DATASET METADATA: {metadata.model_dump()}")
logger.info(f"[RANK {COMM_WORLD.rank}] Initializing benchmarker...")
# Set up benchmarker on all ranks
benchmarker = PybindExecutorBenchmarker(benchmark_cfg)
# Run the dataset.
result = benchmarker.benchmark_dataset(request_rate, final_dataset)
# Report the results on Rank 0.
if COMM_WORLD.rank == 0:
logger.info(f"[RANK {COMM_WORLD.rank}] Reporting...\n"
f"JSON: {result.model_dump_json()}\n"
f"{result.get_summary(benchmarker.config)}")
logger.info(f"[RANK {COMM_WORLD.rank}] Terminating.")

View File

@ -1,9 +1,8 @@
import json
import os
from pathlib import Path
import click
from utils.benchmarkers import gptSessionBenchmarker
from benchmarkers.static import gptSessionBenchmarker
from utils.dataclasses import BenchmarkConfig, BenchmarkResults
@ -29,16 +28,6 @@ from utils.dataclasses import BenchmarkConfig, BenchmarkResults
default=Path(os.path.dirname(os.path.realpath(__file__)), "../../..",
"cpp/build/benchmarks/gptSessionBenchmark").absolute(),
help="Path to TRT-LLM gptSession benchmark binary.")
@click.option("--max-tokens-in-kv-cache",
type=int,
default=None,
help="Maximum number of tokens to store in KV cache")
@click.option(
"--kv-cache-mem-percent",
type=float,
default=0.9,
help="The percentage of free memory that the KV Cache is allowed to occupy.",
)
@click.option("--warm-up-runs",
type=int,
default=2,
@ -54,23 +43,20 @@ from utils.dataclasses import BenchmarkConfig, BenchmarkResults
@click.pass_obj
def static_benchmark(benchmark_cfg: BenchmarkConfig, batch: int, isl: int,
osl: int, gpt_session_path: Path, warm_up_runs: int,
num_runs: int, duration: int, max_tokens_in_kv_cache: int,
kv_cache_mem_percent: float):
num_runs: int, duration: int):
"""Run a static benchmark with a fixed batch size, ISL, and OSL."""
if max_tokens_in_kv_cache is None:
max_tokens_in_kv_cache = batch * isl
benchmark_cfg.max_batch_size = batch
benchmarker = gptSessionBenchmarker(
benchmark_cfg,
gpt_session_path,
batch,
benchmark_cfg.max_batch_size,
isl,
osl,
warm_up_runs,
num_runs,
duration,
max_tokens_in_kv_cache,
kv_cache_mem_percent,
benchmark_cfg.kv_cache_mem_percentage,
)
print(f"Building TRT-LLM engine for '{benchmark_cfg.model}'...")
@ -79,5 +65,5 @@ def static_benchmark(benchmark_cfg: BenchmarkConfig, batch: int, isl: int,
print("Build complete. Running benchmark...")
result: BenchmarkResults = benchmarker.benchmark()
print(f"JSON: {json.dumps(result.model_dump())}")
print(f"JSON: {result.model_dump_json()}")
print(result.get_summary(benchmarker.config))

View File

@ -16,6 +16,8 @@ VALID_QUANT_ALGOS = Literal["None", f"{QuantAlgo.W8A16}", f"{QuantAlgo.W4A16}",
f"{QuantAlgo.W4A16_AWQ}", f"{QuantAlgo.W4A8_AWQ}",
f"{QuantAlgo.W4A16_GPTQ}", f"{QuantAlgo.FP8}",
f"{QuantAlgo.INT8}"]
VALID_SCHEDULING_POLICIES = \
Literal["max_utilization", "guaranteed_no_evict", "static"]
class _MethodFunctionAdapter:
@ -130,6 +132,7 @@ def run_process(cmd: List[Any],
Returns:
subprocess.CompletedProcess: _description_
"""
run_dir.mkdir(parents=True, exist_ok=True)
result = subprocess.run(
[str(x) for x in cmd],
cwd=run_dir,

View File

@ -1,10 +1,6 @@
from pathlib import Path
from subprocess import CompletedProcess
from typing import Dict, List, Protocol
from typing import Protocol
from utils import command_logger, process_error_check, run_process
from utils.dataclasses import BenchmarkConfig, BenchmarkResults
from utils.trtllm_config import TRTLLMConfig
from utils.dataclasses import BenchmarkResults
class Benchmarker(Protocol):
@ -17,212 +13,3 @@ class Benchmarker(Protocol):
def benchmark(self) -> BenchmarkResults:
"""Benchmark the constructed model container by a benchmarker."""
...
class gptSessionBenchmarker:
"""Utility class for running static benchmarks with gptSessionBenchmark."""
def __init__(
self,
config: BenchmarkConfig,
benchmark_binary: Path,
batch_size: int,
isl: int,
osl: int,
warm_up_runs: int,
num_runs: int,
duration: int,
max_tokens_in_kv_cache: int,
kv_cache_free_fraction: float = .9,
):
"""Initialize a gptSessionBenchmark instance.
Args:
config (BenchmarkConfig): Benchmark configuration for build/run.
benchmark_binary (Path): Path to the benchmarking binary.
batch_size (int): Batch size to configure the build with.
isl (int): Input sequence length to configure the build with.
osl (int): Output sequence length to configure the build with.
max_tokens_in_kv_cache (int): The maximum number of tokens to store
in the KV cache
kv_cache_free_fraction (float, optional): The amount of remaining
GPU memory after model loading to save for the KV Cache. Defaults
to .9.
"""
self.config: BenchmarkConfig = config
self.gpt_session_path = Path(benchmark_binary).absolute()
self.batch_size = batch_size
self.input_length = isl
self.output_length = osl
self.warm_up = warm_up_runs
self.num_runs = num_runs
self.duration = duration
self.kv_cache_mem = kv_cache_free_fraction
self.max_tokens_in_kv_cache = max_tokens_in_kv_cache
self.result = None
def get_build_command(self) -> List[str]:
"""Build the engine command for TRT-LLM.
Returns:
List[str]: A list of command line arguments to run a build command.
"""
model = self.config.model
tp = self.config.tensor_parallel
pp = self.config.pipeline_parallel
dtype = self.config.dtype
kv_dtype = self.config.cache_dtype
quant_algo = self.config.quantization.value
output_dir = self.config.engine_path
max_batch_size = self.batch_size
max_isl = self.input_length
max_osl = self.output_length
workspace = self.config.workspace
# Generate the TRT-LLM Configuration file using the dataclass
# NOTE: This method does not use weights.
trtllm_config = TRTLLMConfig.from_hf(model, tp, pp, dtype, quant_algo,
kv_dtype)
# Write the generated configuration file to the benchmark workspace.
trtllm_config.to_json(workspace)
# Return the full command for building TRT-LLM via subprocess call.
cmd = [
"trtllm-build",
"--output_dir",
output_dir,
"--model_config",
Path(workspace, "generated_config.json"),
"--workers",
self.config.world_size,
# Define the maximums the engine can accept.
"--max_batch_size",
max_batch_size,
"--max_input_len",
max_isl,
"--max_output_len",
max_osl,
"--context_fmha",
"enable",
# Set the attention plugin data type.
"--gpt_attention_plugin",
dtype.value,
# Disable paged cache since we aren't batching on the fly.
"--paged_kv_cache",
"disable",
] + kv_dtype.get_build_options(dtype)
return [str(arg) for arg in cmd]
@command_logger(prefix="BUILD COMMAND: ")
@process_error_check
def _run_build(self, cmd: List[str]) -> CompletedProcess:
"""Wrapper for calling the build for TRT-LLM.
Purpose of this wrapper is so that we can decorate it/log it.
Args:
cmd (List[str]): List of command line arguments for running.
Returns:
CompletedProcess: Completed process information for parsing and
reporting.
"""
return run_process(
cmd,
self.config.workspace,
)
def build(self) -> None:
"""Build the engine for benchmarking."""
self._run_build(self.get_build_command())
@command_logger(prefix="BENCHMARK COMMAND: ")
@process_error_check
def _run_benchmark(self, cmd: List[str]) -> CompletedProcess:
"""Run the benchmark command in the configured workspace.
Args:
cmd (List[str]): List of command line arguments to run via
subprocess.
Returns:
CompletedProcess: Completed process information for reporting.
"""
return run_process(cmd, run_dir=self.config.workspace, use_environ=True)
@staticmethod
def parse_benchmark_result(benchmark_line: str) -> Dict[str, str]:
pass
def benchmark(self):
"""Benchmarks a TRT-LLM for a configured instance."""
# Compile the command for running
cmd = [
"mpirun",
"-allow-run-as-root",
"-n",
self.config.world_size,
self.gpt_session_path,
"--engine_dir",
self.config.engine_path,
"--batch_size",
self.batch_size,
"--log_level",
"info",
"--max_tokens_in_paged_kvcache",
self.max_tokens_in_kv_cache,
"--kv_cache_free_gpu_mem_fraction",
self.kv_cache_mem,
"--beam_width",
"1",
"--warm_up",
self.warm_up,
"--num_runs",
self.num_runs,
"--duration",
self.duration,
"--input_output_len",
f"{self.input_length},{self.output_length};{self.input_length},1",
]
cmd = [str(arg) for arg in cmd]
# Run the benchmark using the provided gptSession benchmark binary.
bench_return = self._run_benchmark(cmd)
results = [
x.split(" ") for x in bench_return.stdout.split("\n")
if "[BENCHMARK]" in x
]
ttft = float(results[1][8])
gen_time = float(results[0][8]) - ttft
total_out = int(results[0][2]) * int(results[0][6])
total_in = int(results[0][2]) * int(results[0][4])
batch_size = int(results[0][2])
bench_result = BenchmarkResults(
model=self.config.model,
dtype=self.config.dtype.value,
quantization=str(self.config.quantization.value),
max_batch_size=batch_size,
total_input_tokens=total_in,
total_output_tokens=total_out,
tp_size=self.config.tensor_parallel,
pp_size=self.config.pipeline_parallel,
kv_mem_fraction=self.kv_cache_mem,
scheduler="Static",
max_tokens_in_cache=self.max_tokens_in_kv_cache,
inflight_batching=False,
total_latency=results[0][8],
first_token_latency=ttft,
time_per_output_token=gen_time / (total_out - batch_size),
latency_units="ms",
throughput=results[0][10],
throughput_units="tokens/second",
peak_gpu_mem=results[0][16],
peak_gpu_mem_units="GB",
binary=str(self.gpt_session_path),
build_cmd=" ".join(self.get_build_command()),
benchmark_cmd=" ".join(cmd))
return bench_result

View File

@ -1,29 +1,52 @@
from __future__ import annotations
from pathlib import Path
from typing import Literal
from typing import List, Literal, Optional, Union, get_args
from pydantic import BaseModel, computed_field
from utils import VALID_MODELS
from utils.enums import ComputeDtypeEnum, KVCacheDtypeEnum, QuantizationAlgo
from pydantic import (BaseModel, Field, ValidationError, computed_field,
field_validator, model_validator)
from transformers import AutoConfig
from utils import VALID_MODELS, VALID_SCHEDULING_POLICIES
from utils.enums import (ComputeDtypeEnum, KVCacheDtypeEnum, ModelArchitecture,
QuantizationAlgo)
class InferenceRequest(BaseModel):
task_id: int
prompt: Optional[str] = None
output_tokens: int
logits: Optional[List[int]] = None
@model_validator(mode="after")
def verify_prompt_and_logits(self) -> InferenceRequest:
if self.prompt is None and self.logits is None:
raise ValueError(
f"Both prompt and logits for {self.task_id} are both None.")
return self
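A short sketch of what this validator enforces; the task IDs, prompt, and token values are illustrative:

```python
from pydantic import ValidationError

InferenceRequest(task_id=0, prompt="Hello", output_tokens=128)            # accepted
InferenceRequest(task_id=1, output_tokens=128, logits=[101, 2023, 2003])  # accepted
try:
    InferenceRequest(task_id=2, output_tokens=128)  # neither prompt nor logits
except ValidationError as err:
    print(err)  # rejected: both prompt and logits are None
```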
class DatasetMetadata(BaseModel):
max_isl: int
max_osl: int
num_requests: int
class BenchmarkResults(BaseModel):
"""High level report out for a benchmark."""
benchmark_cmd: str = ""
binary: str
binary: str = ""
build_cmd: str = ""
first_token_latency: float
inflight_batching: bool
kv_mem_fraction: float
latency_units: str
max_batch_size: int
max_tokens_in_cache: int
model: VALID_MODELS
max_tokens: int = 0
model: Union[VALID_MODELS, Path]
peak_gpu_mem_units: str
peak_gpu_mem: float
scheduler: Literal["Static", "No evict", "Max Utilization"]
scheduler: Literal["Static", "No Evict", "Max Utilization"]
throughput_units: str
throughput: float
time_per_output_token: float
@ -52,7 +75,6 @@ class BenchmarkResults(BaseModel):
f"In-flight Batcher?:\t{self.inflight_batching}\n"
f"Dtype:\t\t\t{config.dtype.value}\n"
f"KV Cache Dtype:\t\t{config.cache_dtype.value}\n"
f"KV Cache Size (tokens):\t{self.max_tokens_in_cache}\n"
f"Quantization:\t\t{config.quantization.value}\n"
f"KV Memory Percentage:\t{self.kv_mem_fraction * 100}%\n"
f"\n"
@ -62,13 +84,15 @@ class BenchmarkResults(BaseModel):
f"Engine Directory:\t{config.engine_path}\n"
f"Max Batch Size:\t\t{self.max_batch_size}\n"
f"Total Input Length:\t{self.total_input_tokens}\n"
f"Total Output Length:\t{self.total_input_tokens}\n"
f"Total Output Length:\t{self.total_output_tokens}\n"
f"Max Tokens:\t\t{self.max_tokens}\n"
f"\n"
"===========================================================\n"
"= STATISTICS\n"
"===========================================================\n"
f"Throughput ({self.throughput_units}):\t{self.throughput}\n"
f"Total Latency ({self.latency_units}):\t\t{self.total_latency}\n"
f"Total Latency ({self.latency_units}):"
f"\t\t{self.total_latency * 1000.0:.4f}\n"
f"First Token Latency ({self.latency_units}):\t{self.first_token_latency}\n"
f"Token-to-token Latency ({self.latency_units}):\t{self.time_per_output_token}\n"
f"Peak GPU Memory Usage ({self.peak_gpu_mem_units}):\t{self.peak_gpu_mem}\n"
@ -83,18 +107,81 @@ class BenchmarkResults(BaseModel):
class BenchmarkConfig(BaseModel):
"""Basic configuration of a benchmark."""
model: VALID_MODELS
model: Union[VALID_MODELS, Path]
workspace: Path
max_batch_size: int
dtype: ComputeDtypeEnum
cache_dtype: KVCacheDtypeEnum
quantization: QuantizationAlgo
tensor_parallel: int
pipeline_parallel: int
max_tokens: int = 0
kv_cache_mem_percentage: float = .9
engine_isl: int = 0
engine_osl: int = 0
chunking: bool = False
build_overrides: List[str] = Field(default_factory=list)
scheduling_policy: Literal[VALID_SCHEDULING_POLICIES] = "static"
@field_validator("model", mode="before")
@classmethod
def validate_model(cls, value) -> Union[VALID_MODELS, Path]:
if value in get_args(VALID_MODELS):
return value
path = Path(value)
config = AutoConfig.from_pretrained(str(path.absolute()))
for arch in config.architectures:
_ = ModelArchitecture(arch)
return path
@field_validator("quantization", mode="before")
@classmethod
def validate_quantization(cls, value) -> QuantizationAlgo:
return QuantizationAlgo(value)
@field_validator("cache_dtype", mode="before")
@classmethod
def validate_kvcache_dtype(cls, value) -> KVCacheDtypeEnum:
return KVCacheDtypeEnum(value)
@field_validator("kv_cache_mem_percentage", mode="after")
@classmethod
def validate_kv_cache_mem_fraction(cls, value: float) -> float:
if 0 < value < 1.0:
return value
else:
raise ValueError(
"KV cache memory percentage must be between 0 and 1.0.")
@field_validator("build_overrides", mode="before")
@classmethod
def validate_build_overrides(cls, value) -> List[str]:
# If we encounter a list, scan it to make sure all entries are strings.
if isinstance(value, list):
if not all([isinstance(x, str) for x in value]):
raise ValueError(
"Found a non-string entry in the list of options.")
return value
elif isinstance(value, str):
# Handle the case where we receive a single string of command
# options.
overrides = []
if value:
overrides = [str(x) for x in value.split()]
return overrides
else:
raise ValueError(
"Invalid value specified for build overrides.")
@computed_field
def engine_path(self) -> Path:
"""Path to the engine workspace."""
return Path(self.workspace.absolute(), self.model.lower())
if self.model in get_args(VALID_MODELS):
return Path(self.workspace.absolute(), self.model.lower())
else:
return Path(self.workspace.absolute(), "engine")
@computed_field
def world_size(self) -> int:

View File

@ -4,8 +4,34 @@ from typing import List
from aenum import MultiValueEnum
from tensorrt_llm.bindings.executor import CapacitySchedulerPolicy
from tensorrt_llm.quantization.mode import QuantAlgo
NO_EVICT = "Guaranteed No Evict"
MAX_UTIL = "Max Utilization"
class ModelArchitecture(MultiValueEnum):
LLAMA = "LlamaForCausalLM"
GPTJ = "GPTJForCausalLM"
GEMMA = "GemmaForCausalLM"
BLOOM = "BloomForCausalLM"
OPT = "OPTForCausalLM"
MIXTRAL = "MixtralForCausalLM"
FALCON = "FalconForCausalLM"
class ResultsSchedulingPolicy(MultiValueEnum):
MAX_UTILIZTION = MAX_UTIL, CapacitySchedulerPolicy.MAX_UTILIZATION
NO_EVICT = NO_EVICT, CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
STATIC = "Static"
class IFBSchedulingPolicy(MultiValueEnum):
MAX_UTILIZTION = CapacitySchedulerPolicy.MAX_UTILIZATION, MAX_UTIL, "max_utilization"
NO_EVICT = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, NO_EVICT, "guaranteed_no_evict"
STATIC = "Static", "static"
class KVCacheDtypeEnum(MultiValueEnum):
"""Enumeration of KV Cache precisions in TRT-LLM."""
@ -13,11 +39,11 @@ class KVCacheDtypeEnum(MultiValueEnum):
FP16 = None, "FP16", "fp16", "float16"
INT8 = "INT8", "int8"
def get_build_options(self, dtype: ComputeDtypeEnum) -> List[str]:
def get_build_options(self, dtype: str) -> List[str]:
"""Get the build options for TRT-LLM based on KV Cache precision.
Args:
dtype (ComputeDtypeEnum): The activation dtype for the model. This
dtype (str): The activation dtype for the model. This
parameter maps the activation dtype for GEMM plugins for certain
KV cache precisions.
@ -28,7 +54,7 @@ class KVCacheDtypeEnum(MultiValueEnum):
if self.value == self.FP8:
return ["--strongly_typed"]
else:
return ["--gemm_plugin", dtype.value]
return ["--gemm_plugin", dtype]
class ComputeDtypeEnum(MultiValueEnum):

View File

@ -3,10 +3,10 @@ import os
from argparse import ArgumentParser
from typing import Literal, Optional
from pydantic import AliasChoices, AliasPath, BaseModel, Field, model_validator
from pydantic import (AliasChoices, AliasPath, BaseModel, Field, computed_field,
model_validator)
from transformers import AutoConfig
from utils import VALID_QUANT_ALGOS
from utils.enums import ComputeDtypeEnum, KVCacheDtypeEnum
PET_dict = {
"tiiuae/falcon-7b": "rope_gpt_neox",
@ -15,18 +15,37 @@ PET_dict = {
"meta-llama/Llama-2-7b-hf": "rope_gpt_neox",
"meta-llama/Llama-2-13b-hf": "rope_gpt_neox",
"meta-llama/Llama-2-70b-hf": "rope_gpt_neox",
"meta-llama/Meta-Llama-3-8B": "rope_gpt_neox",
"meta-llama/Meta-Llama-3-70B": "rope_gpt_neox",
"EleutherAI/gpt-j-6b": "rope_gptj",
"bigscience/bloom-560m": "alibi",
"mistralai/Mistral-7B-v0.1": "rope_gpt_neox",
"mistralai/Mixtral-8x7B-v0.1": "rope_gpt_neox",
"mistralai/Mixtral-8x22B-v0.1": "rope_gpt_neox",
"01-ai/Yi-6B": "rope_gpt_neox",
"01-ai/Yi-34B": "rope_gpt_neox",
"codellama/CodeLlama-7b-hf": "rope_gpt_neox",
"codellama/CodeLlama-13b-hf": "rope_gpt_neox",
"codellama/CodeLlama-34b-hf": "rope_gpt_neox",
"codellama/CodeLlama-70b-hf": "rope_gpt_neox",
"facebook/opt-125m": "learned_absolute",
"facebook/opt-350m": "learned_absolute",
"facebook/opt-1.3b": "learned_absolute",
"facebook/opt-2.7b": "learned_absolute",
"facebook/opt-13b": "learned_absolute",
"facebook/opt-30b": "learned_absolute",
"facebook/opt-66b": "learned_absolute",
"google/gemma-7b": "rope_gpt_neox",
"google/gemma-2b": "rope_gpt_neox",
}
HA_dict = {
"tiiuae/falcon-7b": "gelu",
"tiiuae/falcon-40b": "gelu",
"tiiuae/falcon-180B": "gelu",
"bigscience/bloom-560m": "gelu",
"mistralai/Mixtral-8x7B-v0.1": "swiglu",
}
ALLOWED_MODELS = list(PET_dict.keys())
class TRTLLM_Mapping(BaseModel):
@ -42,23 +61,20 @@ class TRTLLM_Mapping(BaseModel):
class TRTLLM_Quantization(BaseModel):
quant_algo: Optional[VALID_QUANT_ALGOS] = None
kv_cache_quant_algo: Optional[Literal[None, "FP8", "INT8"]] = None
group_size: int = 128
has_zero_point: bool = False
pre_quant_scale: bool = False
exclude_modules: Optional[list] = None
class TRTLLM_CheckpointConfig(BaseModel):
"""Dataclass for building TRT-LLM model configurations."""
class TRTLLMConfig(BaseModel):
_VALID_EMBED_TYPE = Literal["learned_absolute", "rope_gptj",
"rope_gpt_neox", "alibi", "alibi_with_scale",
"relative", "chatglm", ]
architecture: str = Field(validation_alias=AliasPath("architectures", 0))
architecture: str = Field(validation_alias=AliasChoices(
'architecture', AliasPath("architectures", 0)))
num_hidden_layers: int = Field(validation_alias=AliasChoices(
"num_hidden_layers", "n_layer", "n_layers"))
num_attention_heads: int = Field(validation_alias=AliasChoices(
@ -72,13 +88,15 @@ class TRTLLM_CheckpointConfig(BaseModel):
validation_alias=AliasChoices("hidden_size", "n_embd", "d_model"))
norm_epsilon: float = Field(
default=1e-5,
validation_alias=AliasChoices("norm_epsilon", "layer_norm_epsilon"),
validation_alias=AliasChoices("norm_epsilon", "layer_norm_epsilon",
"rms_norm_eps"),
)
vocab_size: int
max_position_embeddings: Optional[int] = Field(
default=None,
validation_alias=AliasChoices("max_position_embeddings", "n_positions"),
)
head_size: Optional[int] = None
hidden_act: str = Field(
validation_alias=AliasChoices("hidden_act", "activation_function"))
# falcon options
@ -102,33 +120,46 @@ class TRTLLM_CheckpointConfig(BaseModel):
intermediate_size: int = None
use_prompt_tuning: bool = False
sliding_window: Optional[int] = None
moe_num_experts: Optional[int] = Field(
default=0, validation_alias=AliasChoices("num_local_experts"))
moe_top_k: Optional[int] = Field(
default=0, validation_alias=AliasChoices("num_experts_per_tok"))
rotary_base: Optional[float] = Field(
default=None, validation_alias=AliasChoices("rope_theta"))
mapping: TRTLLM_Mapping
quantization: TRTLLM_Quantization
@computed_field
def kv_dtype(self) -> str:
if self.quantization.kv_cache_quant_algo == "FP8":
return "fp8"
elif self.quantization.kv_cache_quant_algo == "INT8":
return "int8"
else:
return self.dtype
@model_validator(mode="after")
def set_kv_head_default_value(self) -> "TRTLLM_CheckpointConfig":
def set_values_if_none(self) -> "TRTLLM_CheckpointConfig":
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.head_size is None:
self.head_size = self.hidden_size // self.num_attention_heads
return self
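As a worked example of the defaulting above; the numbers are assumptions typical of a 7B LLaMA-style configuration, not taken from a specific checkpoint:

```python
hidden_size, num_attention_heads = 4096, 32     # illustrative values
num_key_value_heads = num_attention_heads       # default when the config omits it -> 32
head_size = hidden_size // num_attention_heads  # 4096 // 32 == 128
```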
class TRTLLMConfig:
def __init__(self, trtllm_config, hf_config=None) -> None:
self.trtllm_config = trtllm_config
self.hf_config = hf_config
# self.nemo_config = nemo_config
@classmethod
def from_hf(
cls,
hf_model_name,
tp,
pp,
dtype=None,
quant_dtype=None,
kv_cache_quant_dtype=None,
):
def populate_build_config(cls,
model_name,
tp,
pp,
dtype=None,
quant_dtype=None,
kv_cache_quant_dtype=None):
"""
Common function to populate build parameters, regardless of network
"""
build_config = {
"mapping": {
"tp_size": tp,
@ -137,27 +168,96 @@ class TRTLLMConfig:
"quantization": {},
}
if dtype:
build_config["dtype"] = ComputeDtypeEnum(dtype).value
build_config["dtype"] = dtype
if quant_dtype:
if not kv_cache_quant_dtype:
# will throw errors during validation if the type is invalid
kv_cache_quant_dtype = KVCacheDtypeEnum(quant_dtype).value
kv_cache_quant_dtype = quant_dtype
build_config["quantization"] = {
"quant_algo": quant_dtype,
"kv_cache_quant_algo":
KVCacheDtypeEnum(kv_cache_quant_dtype).value,
"kv_cache_quant_algo": kv_cache_quant_dtype,
}
build_config["position_embedding_type"] = PET_dict[hf_model_name]
if hf_model_name in HA_dict:
build_config["hidden_act"] = HA_dict[hf_model_name]
if model_name in PET_dict:
build_config["position_embedding_type"] = PET_dict.get(model_name)
return build_config
@classmethod
def from_hf(cls,
hf_model_name,
tp,
pp,
dtype=None,
quant_dtype=None,
kv_cache_quant_dtype=None):
"""
Use transformers.AutoConfig to load a model's config from a HF name
"""
build_config = cls.populate_build_config(hf_model_name, tp, pp, dtype,
quant_dtype,
kv_cache_quant_dtype)
hf_config = AutoConfig.from_pretrained(hf_model_name).to_dict()
trtllm_config = TRTLLM_CheckpointConfig(**hf_config,
**build_config).model_dump()
return cls(trtllm_config, hf_config)
if hf_model_name in HA_dict:
hf_config["hidden_act"] = HA_dict[hf_model_name]
return cls(**hf_config, **build_config)
@classmethod
def from_json(cls,
model_name,
tp,
pp,
dtype=None,
quant_dtype=None,
kv_cache_quant_dtype=None):
"""
Load model parameters from a custom json file
A full path can be specified. Otherwise, look for ./trtllm_configs/(model_name).json
"""
build_config = cls.populate_build_config(model_name, tp, pp, dtype,
quant_dtype,
kv_cache_quant_dtype)
if os.path.exists(model_name):
path_to_json = model_name
else:
path_to_json = os.path.join(os.path.dirname(__file__),
f"trtllm_configs/{model_name}.json")
if not os.path.exists(path_to_json):
raise FileNotFoundError(f"{path_to_json} not found")
json_config = json.load(open(path_to_json))
return cls(**json_config, **build_config)
@classmethod
def from_name(cls,
model,
tp,
pp,
dtype=None,
quant_dtype=None,
kv_cache_quant_dtype=None):
"""
Attempts to create a config based on model name. Performs the following steps:
1. Tries to load the HF config using AutoConfig. This will only work if the network name exists on HF.
2. If this fails, tries to load a custom config from ./trtllm_configs/(model_name).json (see from_json).
"""
try:
trtllm_config = cls.from_hf(model, tp, pp, dtype, quant_dtype,
kv_cache_quant_dtype)
except EnvironmentError:
try:
trtllm_config = cls.from_json(model, tp, pp, dtype, quant_dtype,
kv_cache_quant_dtype)
except FileNotFoundError as e:
raise NameError(
f"Unable to create PretrainedConfig from {model} due to {e}"
)
return trtllm_config
# future possibilities
# def from_nemo_config (self, nemo_model_name)
def to_json(self, output_dir):
with open(os.path.join(output_dir, "generated_config.json"), "w") as f:
json.dump(self.trtllm_config, f, indent=4)
json.dump(self.model_dump(), f, indent=4)
if __name__ == "__main__":
@ -205,14 +305,19 @@ if __name__ == "__main__":
type=str,
help="TRT-LLM argument",
)
parser.add_argument(
"--populate_hf_cache",
action='store_true',
help="Populate the HF cache with all the supported networks",
)
args = parser.parse_args()
trtllm_config = TRTLLMConfig.from_hf(
args.model,
args.tp_size,
args.pp_size,
args.dtype,
args.quant_dtype,
args.kv_cache_quant_dtype,
)
trtllm_config.to_json(os.getcwd())
if args.populate_hf_cache:
for net in PET_dict.keys():
_ = AutoConfig.from_pretrained(net)
else:
trtllm_config = TRTLLMConfig.from_name(args.model, args.tp_size,
args.pp_size, args.dtype,
args.quant_dtype,
args.kv_cache_quant_dtype)
trtllm_config.to_json(os.getcwd())

View File

@ -32,6 +32,7 @@ option(BUILD_PYBIND "Build Python bindings for C++ runtime and batch manager"
ON)
option(BUILD_TESTS "Build Google tests" ON)
option(BUILD_BENCHMARKS "Build benchmarks" ON)
option(BUILD_MICRO_BENCHMARKS "Build C++ micro benchmarks" OFF)
option(NVTX_DISABLE "Disable all NVTX features" ON)
option(WARNING_IS_ERROR "Treat all warnings as errors" OFF)
option(FAST_BUILD "Skip compiling some kernels to accelerate compiling" OFF)
@ -111,6 +112,12 @@ else()
message(STATUS "Not building benchmarks")
endif()
if(BUILD_MICRO_BENCHMARKS)
message(STATUS "Building C++ micro benchmarks")
else()
message(STATUS "Not building C++ micro benchmarks")
endif()
if(FAST_BUILD)
add_compile_definitions("FAST_BUILD")
message(WARNING "Skip some kernels to accelerate compilation")
@ -506,6 +513,11 @@ if(BUILD_BENCHMARKS)
${CMAKE_BINARY_DIR}/benchmarks)
endif()
if(BUILD_MICRO_BENCHMARKS)
add_subdirectory(${TRT_LLM_ROOT_DIR}/cpp/micro_benchmarks
${CMAKE_BINARY_DIR}/micro_benchmarks)
endif()
# Measure the compile time
option(MEASURE_BUILD_TIME "Measure the build time of each module" OFF)
if(MEASURE_BUILD_TIME)

View File

@ -53,8 +53,7 @@ public:
SendResponseCallback sendResponseCb, PollStopSignalCallback pollStopSignalCb = nullptr,
ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
TrtGptModelOptionalParams const& optionalParams = TrtGptModelOptionalParams(),
std::optional<uint64_t> terminateReqId = std::nullopt, std::optional<SizeType32> maxDraftTokens = std::nullopt,
bool excludeInputInOutput = false);
std::optional<uint64_t> terminateReqId = std::nullopt, bool excludeInputInOutput = false);
/* Wraps the user-provided callback for requests.
Adds requests to request table.
@ -109,7 +108,6 @@ private:
std::shared_ptr<TrtGptModel> mTrtGptModel;
std::optional<uint64_t> mTerminateReqId;
std::optional<SizeType32> mMaxDraftTokens;
// Iteration counter - incremented every iteration of the generation loop
int64_t mIterationCounter;

View File

@ -89,7 +89,6 @@ public:
, mLoraTaskId(loraTaskId)
, mLoraWeights(std::move(loraWeights))
, mLoraConfig(std::move(loraConfig))
, mReturnLogProbs(returnLogProbs)
, mContextChunkSize(std::nullopt)
, mContextCurrentPosition(0)
, mLogProbs(samplingConfig.beamWidth)
@ -101,15 +100,16 @@ public:
, mReturnGenerationLogits(returnGenerationLogits)
, mExcludeInputFromOutput(excludeInputFromOutput)
, mEncoderInputTokens(encoderInputTokens)
, mDecodingIter(0)
{
initialize(*inputTokens);
initialize(*inputTokens, returnLogProbs);
}
GenericLlmRequest(RequestIdType requestId, executor::Request const& req)
: mRequestId(requestId)
, mPromptLen(req.getInputTokenIds().size())
, mMaxNewTokens(req.getMaxNewTokens())
, mSamplingConfig(req.getSamplingConfig(), req.getSpeculativeDecodingConfig())
, mSamplingConfig(req.getSamplingConfig(), req.getExternalDraftTokensConfig())
, mState(REQUEST_STATE_CONTEXT_INIT)
, mIsStreaming(req.getStreaming())
, mEndId(req.getEndId())
@ -124,7 +124,6 @@ public:
, mLoraTaskId(std::nullopt)
, mLoraWeights(std::nullopt)
, mLoraConfig(std::nullopt)
, mReturnLogProbs(req.getOutputConfig().returnLogProbs)
, mContextChunkSize(std::nullopt)
, mContextCurrentPosition(0)
, mLogProbs(mSamplingConfig.beamWidth)
@ -135,6 +134,7 @@ public:
, mReturnContextLogits(req.getOutputConfig().returnContextLogits)
, mReturnGenerationLogits(req.getOutputConfig().returnGenerationLogits)
, mExcludeInputFromOutput(req.getOutputConfig().excludeInputFromOutput)
, mDecodingIter(0)
{
if (req.getEmbeddingBias())
{
@ -178,20 +178,20 @@ public:
}
}
auto speculativeDecodingConfig = req.getSpeculativeDecodingConfig();
if (speculativeDecodingConfig)
auto externalDraftTokensConfig = req.getExternalDraftTokensConfig();
if (externalDraftTokensConfig)
{
mDraftTokens = std::make_shared<VecTokens>(speculativeDecodingConfig.value().getTokens());
mDraftTokens = std::make_shared<VecTokens>(externalDraftTokensConfig.value().getTokens());
if (speculativeDecodingConfig.value().getLogits())
if (externalDraftTokensConfig.value().getLogits())
{
mDraftLogits = executor::detail::toITensor(speculativeDecodingConfig.value().getLogits().value());
mDraftLogits = executor::detail::toITensor(externalDraftTokensConfig.value().getLogits().value());
}
// NOTE: Draft acceptance threshold is stored in mSamplingConfig
}
initialize(req.getInputTokenIds());
initialize(req.getInputTokenIds(), req.getOutputConfig().returnLogProbs);
}
void validate(SizeType32 maxInputLen, SizeType32 maxSequenceLen, SizeType32 maxDraftLen)
@ -233,13 +233,7 @@ public:
mMaxNewTokens = maxNewTokens;
}
if (mSamplingConfig.beamWidth <= 0)
{
TLLM_THROW(
"Requested value: %d for beamWidth is invalid. To de-activate beam searching "
"set beamWidth to 1 instead.",
mSamplingConfig.beamWidth);
}
TLLM_CHECK_WITH_INFO(mSamplingConfig.validate(), "Incorrect sampling config");
}
void setExcludeInputFromOutput(bool exclude)
@ -379,7 +373,7 @@ public:
{
auto& beamTokens = mTokens.at(beam);
beamTokens.resize(mPromptLen);
if (mReturnLogProbs)
if (returnLogProbs())
{
mLogProbs.at(beam).clear();
}
@ -393,7 +387,7 @@ public:
auto& beamTokens = mTokens.at(beam);
beamTokens.resize(newPromptLen);
if (mReturnLogProbs)
if (returnLogProbs())
{
auto& logProb = mLogProbs.at(beam);
logProb.resize(newPromptLen - mPromptLen);
@ -491,12 +485,13 @@ public:
[[nodiscard]] bool returnLogProbs() const
{
return mReturnLogProbs;
return mSamplingConfig.outputLogProbs.has_value() ? mSamplingConfig.outputLogProbs->at(0) : false;
}
void setReturnLogProbs(bool returnLogProbs)
{
mReturnLogProbs = returnLogProbs;
mSamplingConfig.outputLogProbs = {{returnLogProbs}};
mSamplingConfig.cumLogProbs = {{returnLogProbs}};
}
[[nodiscard]] std::vector<VecLogProbs> const& getLogProbs() const
@ -728,6 +723,23 @@ public:
}
}
/// Increment the counter of decoding iterations.
void advanceDecodingIter()
{
mDecodingIter++;
}
/// @brief Return the average number of decoded tokens per iteration. For a standard model it is 1.
/// For speculative decoding models it is >= 1: the number of draft tokens accepted per step plus 1.
[[nodiscard]] float getAvgDecodedTokensPerIter() const noexcept
{
if (mDecodingIter == 0)
{
return 0.f;
}
return static_cast<float>(getMaxNumGeneratedTokens()) / mDecodingIter;
}
/// @brief Create a Response from the current state of the request
/// @return An optional Response
std::optional<executor::Response> createResponse()
@ -841,8 +853,6 @@ protected:
// encoder output, saved for computing cross attention KV Cache
TensorPtr mEncoderOutput;
bool mReturnLogProbs;
// To enable chunked context, the FHMA paged kv-cache also needs to be enabled. Except for the last one,
// the size of the context chunk needs to be an integer multiple of the kv-cache block size. The meaning
// of null value is that the context is not chunked.
@ -868,8 +878,10 @@ protected:
std::shared_ptr<VecTokens>
mEncoderInputTokens; // Input tokens to the encoder for enc only models and enc-dec models
SizeType32 mDecodingIter;
private:
void initialize(VecTokens const& inputTokens)
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
{
// Scatter the input tokens to other beam
mTokens = BeamTokens(mSamplingConfig.beamWidth, inputTokens);
@ -888,6 +900,8 @@ private:
{
TLLM_THROW("Draft tokens must be specified when draft logits are given.");
}
setReturnLogProbs(outputLogProbs);
}
TensorPtr createListTensor(std::list<VecTokens> const& wordsList)

View File

@ -24,6 +24,7 @@
#include <future>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@ -33,6 +34,13 @@ namespace tensorrt_llm::batch_manager
using runtime::SizeType32;
class PeftTaskNotCachedException : public std::runtime_error
{
public:
explicit PeftTaskNotCachedException(std::string const& msg);
~PeftTaskNotCachedException() noexcept override;
};
/**
* BasePeftCacheManager
*

View File

@ -21,8 +21,6 @@
#include "tensorrt_llm/batch_manager/peftCacheManagerConfig.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/medusaModule.h"
#include <optional>
#include <vector>
@ -40,18 +38,15 @@ public:
explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
bool enableTrtOverlap = false, std::optional<std::vector<SizeType32>> const& deviceIds = std::nullopt,
bool normalizeLogProbs = true, bool enableChunkedContext = false,
std::optional<runtime::DecodingMode> const& decodingMode = std::nullopt,
PeftCacheManagerConfig const& peftCacheManagerConfig = PeftCacheManagerConfig{},
std::optional<runtime::MedusaModule::MedusaChoices> const& medusaChoices = std::nullopt,
float gpuWeightsPercent = 1)
executor::DecodingConfig const& decodingConfig = executor::DecodingConfig{}, float gpuWeightsPercent = 1)
: kvCacheConfig{kvCacheConfig}
, enableTrtOverlap{enableTrtOverlap}
, deviceIds(deviceIds)
, normalizeLogProbs{normalizeLogProbs}
, enableChunkedContext{enableChunkedContext}
, decodingMode{decodingMode}
, peftCacheManagerConfig(peftCacheManagerConfig)
, medusaChoices(medusaChoices)
, decodingConfig(decodingConfig)
, gpuWeightsPercent(gpuWeightsPercent)
{
}
@ -60,10 +55,9 @@ public:
: TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()), false,
executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(),
executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(),
runtime::DecodingMode::fromExecutor(
executorConfig.getDecodingMode().value_or(executor::DecodingMode::kNONE)),
PeftCacheManagerConfig(executorConfig.getPeftCacheConfig().value_or(executor::PeftCacheConfig())),
executorConfig.getMedusaChoices(), executorConfig.getGpuWeightsPercent())
executorConfig.getDecodingConfig().value_or(executor::DecodingConfig{}),
executorConfig.getGpuWeightsPercent())
{
}
@ -71,7 +65,7 @@ public:
{
return kvCacheConfig == other.kvCacheConfig && enableTrtOverlap == other.enableTrtOverlap
&& deviceIds == other.deviceIds && normalizeLogProbs == other.normalizeLogProbs
&& enableChunkedContext == other.enableChunkedContext && decodingMode == other.decodingMode;
&& enableChunkedContext == other.enableChunkedContext && decodingConfig == other.decodingConfig;
}
friend std::ostream& operator<<(std::ostream& os, TrtGptModelOptionalParams const& self);
@ -82,9 +76,8 @@ public:
std::optional<std::vector<SizeType32>> deviceIds;
bool normalizeLogProbs;
bool enableChunkedContext;
std::optional<runtime::DecodingMode> decodingMode;
PeftCacheManagerConfig peftCacheManagerConfig;
std::optional<runtime::MedusaModule::MedusaChoices> medusaChoices;
executor::DecodingConfig decodingConfig;
// Percentage of weights on the gpu at runtime
float gpuWeightsPercent;
};

View File

@ -260,6 +260,9 @@ public:
//! \brief Corresponds to `world()` by default, but can be overridden per process.
static MpiComm& session();
//! \brief Returns the MPI local communicator.
static MpiComm& localSession();
[[nodiscard]] MpiComm split(int color, int key) const;
std::shared_ptr<MpiRequest> bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const;
@ -370,3 +373,4 @@ void initialize(MpiThreadSupport threadMode = MpiThreadSupport::THREAD_FUNNELED)
} // namespace tensorrt_llm::mpi
#define COMM_SESSION tensorrt_llm::mpi::MpiComm::session()
#define LOCAL_COMM_SESSION tensorrt_llm::mpi::MpiComm::localSession()

View File

@ -78,7 +78,34 @@ public:
[[nodiscard]] std::optional<FloatType> getLengthPenalty() const;
[[nodiscard]] std::optional<SizeType32> getEarlyStopping() const;
void setBeamWidth(SizeType32 beamWidth);
void setTopK(std::optional<SizeType32> const& topK);
void setTopP(std::optional<FloatType> const& topP);
void setTopPMin(std::optional<FloatType> const& topPMin);
void setTopPResetIds(std::optional<TokenIdType> const& topPResetIds);
void setTopPDecay(std::optional<FloatType> const& topPDecay);
void setRandomSeed(std::optional<RandomSeedType> const& randomSeed);
void setTemperature(std::optional<FloatType> const& temperature);
void setMinLength(std::optional<SizeType32> const& minLength);
void setBeamSearchDiversityRate(std::optional<FloatType> const& beamSearchDiversityRate);
void setRepetitionPenalty(std::optional<FloatType> const& repetitionPenalty);
void setPresencePenalty(std::optional<FloatType> const& presencePenalty);
void setFrequencyPenalty(std::optional<FloatType> const& frequencyPenalty);
void setLengthPenalty(std::optional<FloatType> const& lengthPenalty);
void setEarlyStopping(std::optional<SizeType32> const& earlyStopping);
private:
static SizeType32 checkBeamWidth(SizeType32 beamWidth);
static std::optional<FloatType> const& checkTopK(std::optional<FloatType> const& topK);
static std::optional<FloatType> const& checkTopP(std::optional<FloatType> const& topP);
static std::optional<FloatType> const& checkTopPMin(std::optional<FloatType> const& topPMin);
static std::optional<TokenIdType> const& checkTopPResetIds(std::optional<TokenIdType> const& topPResetIds);
static std::optional<FloatType> const& checkTopPDecay(std::optional<FloatType> const& topPDecay);
static std::optional<FloatType> const& checkTemperature(std::optional<FloatType> const& temperature);
static std::optional<SizeType32> const& checkMinLength(std::optional<SizeType32> const& minLength);
static std::optional<FloatType> const& checkBeamSearchDiversityRate(
std::optional<FloatType> const& beamSearchDiversityRate);
friend class Serialization;
/// @brief The beam width. Default is 1 which disables beam search.
@ -134,12 +161,12 @@ public:
bool excludeInputFromOutput;
};
/// @brief Configuration for speculative decoding. Allows to include draft tokens, draft logits and specify acceptance
/// threshold
class SpeculativeDecodingConfig
/// @brief Configuration for speculative decoding with external draft tokens.
/// Allows including draft tokens and draft logits, and specifying the acceptance threshold.
class ExternalDraftTokensConfig
{
public:
explicit SpeculativeDecodingConfig(VecTokens tokens, std::optional<Tensor> logits = std::nullopt,
explicit ExternalDraftTokensConfig(VecTokens tokens, std::optional<Tensor> logits = std::nullopt,
std::optional<FloatType> const& acceptanceThreshold = std::nullopt);
[[nodiscard]] VecTokens getTokens() const;
@ -209,7 +236,7 @@ public:
/// @param badWords A list of bad words tokens. Each "word" can be composed of multiple tokens
/// @param stopWords A list of stop words tokens. Each "word" can be composed of multiple tokens
/// @param embeddingBias The embedding bias tensor. Expected type is kFP32 and shape is [vocab_size]
/// @param speculativeDecodingConfig The speculative decoding configuration
/// @param externalDraftTokensConfig The speculative decoding configuration
/// @param pTuningConfig The prompt tuning configuration
/// @param loraConfig The LoRA configuration
/// @param logitsPostProcessorName The logits postprocessor name. Must correspond to one of the logits postprocessor
@ -220,7 +247,7 @@ public:
std::optional<std::list<VecTokens>> badWords = std::nullopt,
std::optional<std::list<VecTokens>> stopWords = std::nullopt,
std::optional<Tensor> embeddingBias = std::nullopt,
std::optional<SpeculativeDecodingConfig> speculativeDecodingConfig = std::nullopt,
std::optional<ExternalDraftTokensConfig> externalDraftTokensConfig = std::nullopt,
std::optional<PromptTuningConfig> pTuningConfig = std::nullopt,
std::optional<LoraConfig> loraConfig = std::nullopt,
std::optional<std::string> logitsPostProcessorName = std::nullopt);
@ -241,7 +268,7 @@ public:
[[nodiscard]] std::optional<std::list<VecTokens>> getBadWords() const;
[[nodiscard]] std::optional<std::list<VecTokens>> getStopWords() const;
[[nodiscard]] std::optional<Tensor> getEmbeddingBias() const;
[[nodiscard]] std::optional<SpeculativeDecodingConfig> getSpeculativeDecodingConfig() const;
[[nodiscard]] std::optional<ExternalDraftTokensConfig> getExternalDraftTokensConfig() const;
[[nodiscard]] std::optional<PromptTuningConfig> getPromptTuningConfig() const;
[[nodiscard]] std::optional<LoraConfig> getLoraConfig() const;
[[nodiscard]] std::optional<std::string> getLogitsPostProcessorName() const;
@ -254,7 +281,7 @@ public:
void setBadWords(std::list<VecTokens> const& badWords);
void setStopWords(std::list<VecTokens> const& stopWords);
void setEmbeddingBias(Tensor const& embeddingBias);
void setSpeculativeDecodingConfig(SpeculativeDecodingConfig const& specDecodingConfig);
void setExternalDraftTokensConfig(ExternalDraftTokensConfig const& externalDraftTokensConfig);
void setPromptTuningConfig(PromptTuningConfig const& pTuningConfig);
void setLoraConfig(LoraConfig const& loraConfig);
void setLogitsPostProcessorName(std::string const& logitsPostProcessorName);
@ -514,6 +541,70 @@ private:
std::optional<size_t> mHostCacheSize;
};
/// @brief Configuration class for Lookahead decoding.
class LookaheadDecodingConfig
{
public:
explicit LookaheadDecodingConfig(
SizeType32 maxNgramSize, SizeType32 maxWindowSize, SizeType32 maxVerificationSetSize);
bool operator==(LookaheadDecodingConfig const& other) const;
// Lookahead decoding methods.
void setMaxNgramSize(SizeType32);
void setMaxWindowSize(SizeType32);
void setMaxVerificationSetSize(SizeType32);
[[nodiscard]] SizeType32 getMaxNgramSize() const;
[[nodiscard]] SizeType32 getMaxWindowSize() const;
[[nodiscard]] SizeType32 getMaxVerificationSetSize() const;
private:
friend class Serialization;
// Number of tokens per NGram.
SizeType32 mMaxNgramSize;
// Number of NGrams in lookahead branch per step.
SizeType32 mMaxWindowSize;
// Number of NGrams in verification branch per step.
SizeType32 mMaxVerificationSetSize;
};
/// @brief Configuration class for speculative decoding.
class DecodingConfig
{
public:
explicit DecodingConfig(std::optional<DecodingMode> decodingMode = std::nullopt,
std::optional<LookaheadDecodingConfig> lookaheadDecodingConfig = std::nullopt,
std::optional<MedusaChoices> medusaChoices = std::nullopt);
bool operator==(DecodingConfig const& other) const;
// Decoding mode.
/// @brief Sets decoding mode. Cannot be used to set the lookahead or Medusa modes.
void setDecodingMode(DecodingMode const&);
[[nodiscard]] std::optional<DecodingMode> getDecodingMode() const;
// Lookahead methods.
/// @brief Sets lookahead decoding mode and lookahead decoding config.
void setLookaheadDecoding(LookaheadDecodingConfig const&);
[[nodiscard]] std::optional<LookaheadDecodingConfig> getLookaheadDecodingConfig() const;
// Medusa methods.
/// @brief Sets medusa mode and medusa config.
void setMedusaChoices(MedusaChoices const&);
[[nodiscard]] std::optional<MedusaChoices> getMedusaChoices() const;
private:
friend class Serialization;
// Decoding mode.
std::optional<DecodingMode> mDecodingMode;
// Lookahead params.
std::optional<LookaheadDecodingConfig> mLookaheadDecodingConfig;
// Medusa params.
std::optional<MedusaChoices> mMedusaChoices;
};
/// @brief Configuration class for the model executor
class ExecutorConfig
{
@ -526,8 +617,7 @@ public:
std::optional<ParallelConfig> parallelConfig = std::nullopt,
std::optional<PeftCacheConfig> const& peftCacheConfig = std::nullopt,
std::optional<LogitsPostProcessorMap> logitsPostProcessorMap = std::nullopt,
std::optional<MedusaChoices> medusaChoices = std::nullopt,
std::optional<DecodingMode> decodingMode = std::nullopt, float gpuWeightsPercent = 1);
std::optional<DecodingConfig> decodingConfig = std::nullopt, float gpuWeightsPercent = 1);
[[nodiscard]] SizeType32 getMaxBeamWidth() const;
[[nodiscard]] SchedulerConfig getSchedulerConfig() const;
@ -540,8 +630,7 @@ public:
[[nodiscard]] std::optional<ParallelConfig> getParallelConfig() const;
[[nodiscard]] std::optional<PeftCacheConfig> getPeftCacheConfig() const;
[[nodiscard]] std::optional<LogitsPostProcessorMap> getLogitsPostProcessorMap() const;
[[nodiscard]] std::optional<MedusaChoices> getMedusaChoices() const;
[[nodiscard]] std::optional<DecodingMode> getDecodingMode() const;
[[nodiscard]] std::optional<DecodingConfig> getDecodingConfig() const;
[[nodiscard]] float getGpuWeightsPercent() const;
void setMaxBeamWidth(SizeType32 maxBeamWidth);
@ -555,8 +644,7 @@ public:
void setParallelConfig(ParallelConfig const& parallelConfig);
void setPeftCacheConfig(PeftCacheConfig const& peftCacheConfig);
void setLogitsPostProcessorMap(LogitsPostProcessorMap const& logitsPostProcessorMap);
void setMedusaChoices(MedusaChoices const& medusaChoices);
void setDecodingMode(DecodingMode decodingMode);
void setDecodingConfig(DecodingConfig const& decodingConfig);
void setGpuWeightsPercent(float const& gpuWeightsPercent);
private:
@ -590,8 +678,8 @@ private:
std::optional<ParallelConfig> mParallelConfig;
std::optional<PeftCacheConfig> mPeftCacheConfig;
std::optional<LogitsPostProcessorMap> mLogitsPostProcessorMap;
std::optional<MedusaChoices> mMedusaChoices;
std::optional<DecodingMode> mDecodingMode;
/// @brief Decoding configuration.
std::optional<DecodingConfig> mDecodingConfig;
float mGpuWeightsPercent;
};

View File

@ -38,10 +38,10 @@ public:
static void serialize(OutputConfig const& config, std::ostream& os);
[[nodiscard]] static size_t serializedSize(OutputConfig const& config);
// SpeculativeDecodingConfig
[[nodiscard]] static SpeculativeDecodingConfig deserializeSpeculativeDecodingConfig(std::istream& is);
static void serialize(SpeculativeDecodingConfig const& config, std::ostream& os);
[[nodiscard]] static size_t serializedSize(SpeculativeDecodingConfig const& config);
// ExternalDraftTokensConfig
[[nodiscard]] static ExternalDraftTokensConfig deserializeExternalDraftTokensConfig(std::istream& is);
static void serialize(ExternalDraftTokensConfig const& config, std::ostream& os);
[[nodiscard]] static size_t serializedSize(ExternalDraftTokensConfig const& config);
// PromptTuningConfig
[[nodiscard]] static PromptTuningConfig deserializePromptTuningConfig(std::istream& is);
@ -102,6 +102,21 @@ public:
static void serialize(OrchestratorConfig const& orchestratorConfig, std::ostream& os);
static size_t serializedSize(OrchestratorConfig const& orchestratorConfig);
// DecodingMode
static DecodingMode deserializeDecodingMode(std::istream& is);
static void serialize(DecodingMode const& decodingMode, std::ostream& os);
static size_t serializedSize(DecodingMode const& decodingMode);
// LookaheadDecodingConfig
static LookaheadDecodingConfig deserializeLookaheadDecodingConfig(std::istream& is);
static void serialize(LookaheadDecodingConfig const& lookaheadDecodingConfig, std::ostream& os);
static size_t serializedSize(LookaheadDecodingConfig const& lookaheadDecodingConfig);
// DecodingConfig
static DecodingConfig deserializeDecodingConfig(std::istream& is);
static void serialize(DecodingConfig const& decodingConfig, std::ostream& os);
static size_t serializedSize(DecodingConfig const& decodingConfig);
// ExecutorConfig
static ExecutorConfig deserializeExecutorConfig(std::istream& is);
static void serialize(ExecutorConfig const& executorConfig, std::ostream& os);

View File

@ -253,6 +253,8 @@ struct InflightBatchingStats
SizeType32 numCtxTokens;
/// @brief Index of micro batch
SizeType32 microBatchId;
/// @brief Average number of tokens decoded per request per iteration
float avgNumDecodedTokensPerIter;
};
/// @brief Struct that holds the stats of a single iteration
@ -306,6 +308,8 @@ struct RequestStats
SizeType32 contextPrefillPosition;
/// @brief The number of generated tokens so far
SizeType32 numGeneratedTokens;
/// @brief The average number of decoded tokens per iteration. It is >= 1 for speculative decoding.
float avgNumDecodedTokensPerIter;
/// @brief Whether the request is scheduled for the current iteration
bool scheduled;
/// @brief Whether the request is being paused at the current iteration due to lack of resources (KV cache blocks
@ -322,17 +326,386 @@ struct RequestStatsPerIteration
std::vector<RequestStats> requestStats;
};
/// @brief Decoding mode
enum class DecodingMode
/// @brief Mode of the decoder
class DecodingMode
{
public:
/// @brief No mode specified. Config will be determined from the beam width of the first request at runtime
/// TopKTopP if beamWidth == 1, BeamSearch otherwise
kNONE,
kTOP_K,
kTOP_P,
kBEAM_SEARCH,
kMEDUSA,
kTOP_K_TOP_P,
static auto constexpr Auto()
{
return DecodingMode{kAuto};
}
static auto constexpr TopK()
{
return DecodingMode{kTopK | kUsePenalties | kUseBanWords | kStandardStopCriteria};
}
static auto constexpr TopP()
{
return DecodingMode{kTopP | kUsePenalties | kUseBanWords | kStandardStopCriteria};
}
static auto constexpr TopKTopP()
{
return DecodingMode{kTopKTopP | kUsePenalties | kUseBanWords | kStandardStopCriteria};
}
static auto constexpr BeamSearch()
{
return DecodingMode{kBeamSearch | kUsePenalties | kUseBanWords | kStandardStopCriteria};
}
static auto constexpr Medusa()
{
return DecodingMode{kMedusa | kUseMinLength | kUseMaxLengthStop};
}
static auto constexpr Lookahead()
{
return DecodingMode{kLookahead | kUseMinLength | kUseMaxLengthStop};
}
static auto constexpr ExplicitDraftTokens()
{
return DecodingMode{kExplicitDraftTokens | kUseMaxLengthStop | kUseExplicitEosStop};
}
auto constexpr useTemperature(bool useTemp)
{
mState = setBitTo(kUseTemperature, useTemp);
return *this;
}
auto constexpr useOccurrencePenalties(bool usePenalty)
{
mState = setBitTo(kUseOccurrencePenalties, usePenalty);
return *this;
}
auto constexpr usePresencePenalty(bool usePenalty)
{
mState = setBitTo(kUsePresencePenalties, usePenalty);
return *this;
}
auto constexpr useRepetitionPenalty(bool usePenalty)
{
mState = setBitTo(kUseRepetitionPenalties, usePenalty);
return *this;
}
auto constexpr useFrequencyPenalty(bool usePenalty)
{
mState = setBitTo(kUseFrequencyPenalties, usePenalty);
return *this;
}
auto constexpr useMinLength(bool useMinLen)
{
mState = setBitTo(kUseMinLength, useMinLen);
return *this;
}
auto constexpr useBanWords(bool banWords)
{
mState = setBitTo(kUseBanWords, banWords);
return *this;
}
auto constexpr useStopWords(bool stopWords)
{
mState = setBitTo(kUseStopWords, stopWords);
return *this;
}
auto constexpr useMaxLengthStop(bool maxLengthStop)
{
mState = setBitTo(kUseMaxLengthStop, maxLengthStop);
return *this;
}
auto constexpr useExplicitEosStop(bool explicitEosStop)
{
mState = setBitTo(kUseExplicitEosStop, explicitEosStop);
return *this;
}
bool constexpr isAuto() const
{
return anyBitSet(kAuto);
}
bool constexpr isTopK() const
{
return anyBitSet(kTopK);
}
bool constexpr isTopP() const
{
return anyBitSet(kTopP);
}
bool constexpr isTopKorTopP() const
{
return anyBitSet(kTopKTopP);
}
bool constexpr isTopKandTopP() const
{
return allBitSet(kTopKTopP);
}
bool constexpr isBeamSearch() const
{
return anyBitSet(kBeamSearch);
}
bool constexpr isMedusa() const
{
return anyBitSet(kMedusa);
}
bool constexpr isLookahead() const
{
return anyBitSet(kLookahead);
}
bool constexpr isExplicitDraftTokens() const
{
return anyBitSet(kExplicitDraftTokens);
}
bool constexpr isUseTemperature() const
{
return anyBitSet(kUseTemperature);
}
bool constexpr isUsePresencePenalty() const
{
return anyBitSet(kUsePresencePenalties);
}
bool constexpr isUseFrequencyPenalty() const
{
return anyBitSet(kUseFrequencyPenalties);
}
bool constexpr isUseRepetitionPenalty() const
{
return anyBitSet(kUseRepetitionPenalties);
}
bool constexpr isUseMinLength() const
{
return anyBitSet(kUseMinLength);
}
bool constexpr isUseOccurrencePenalty() const
{
return anyBitSet(kUseOccurrencePenalties);
}
bool constexpr isUsePenalty() const
{
return anyBitSet(kUsePenalties);
}
bool constexpr isUseBanWords() const
{
return anyBitSet(kUseBanWords);
}
bool constexpr isUseStopWords() const
{
return anyBitSet(kUseStopWords);
}
bool constexpr isUseMaxLengthStop() const
{
return anyBitSet(kUseMaxLengthStop);
}
bool constexpr isUseExplicitEosStop() const
{
return anyBitSet(kUseExplicitEosStop);
}
bool constexpr isUseStopCriteria() const
{
return anyBitSet(kStandardStopCriteria | kUseExplicitEosStop);
}
using UnderlyingType = uint32_t;
bool operator==(DecodingMode const& other) const
{
return mState == other.mState;
}
constexpr DecodingMode(UnderlyingType state)
: mState(state)
{
}
constexpr UnderlyingType getState() const
{
return mState;
}
private:
// No mode specified. Config will be determined from the beam width of the first request at runtime
// TopKTopP if beamWidth == 1, BeamSearch otherwise
static UnderlyingType constexpr kUseRepetitionPenalties{1u << 0};
static UnderlyingType constexpr kUseFrequencyPenalties{1u << 1};
static UnderlyingType constexpr kUsePresencePenalties{1u << 2};
static UnderlyingType constexpr kUseTemperature{1u << 3};
static UnderlyingType constexpr kUseMinLength{1u << 4};
static UnderlyingType constexpr kUseBanWords{1u << 5};
static UnderlyingType constexpr kUseStopWords{1u << 6};
static UnderlyingType constexpr kUseMaxLengthStop{1u << 7};
static UnderlyingType constexpr kUseExplicitEosStop{1u << 8};
static UnderlyingType constexpr kStandardStopCriteria{kUseStopWords | kUseMaxLengthStop};
static UnderlyingType constexpr kUseOccurrencePenalties{
kUseRepetitionPenalties | kUseFrequencyPenalties | kUsePresencePenalties};
static UnderlyingType constexpr kUsePenalties{kUseOccurrencePenalties | kUseTemperature | kUseMinLength};
static SizeType32 constexpr kNumFlags{9};
static UnderlyingType constexpr kAuto{1u << (kNumFlags + 0)};
static UnderlyingType constexpr kTopK{1u << (kNumFlags + 1)};
static UnderlyingType constexpr kTopP{1u << (kNumFlags + 2)};
static UnderlyingType constexpr kBeamSearch{1u << (kNumFlags + 3)};
static UnderlyingType constexpr kMedusa{1u << (kNumFlags + 4)};
static UnderlyingType constexpr kLookahead{1u << (kNumFlags + 5)};
static UnderlyingType constexpr kExplicitDraftTokens{1u << (kNumFlags + 6)};
static UnderlyingType constexpr kTopKTopP{kTopK | kTopP};
bool constexpr anyBitSet(UnderlyingType bits) const
{
return (mState & bits) != 0;
}
bool constexpr allBitSet(UnderlyingType bits) const
{
return (mState & bits) == bits;
}
// Clears the bits selected by `state`, then sets all of them when `x` is true.
UnderlyingType constexpr setBitTo(UnderlyingType state, bool x)
{
return (mState & (~state)) | (state * static_cast<UnderlyingType>(x));
}
UnderlyingType mState{};
};
static_assert(DecodingMode::Auto().isAuto());
static_assert(!DecodingMode::Auto().isUseBanWords());
static_assert(!DecodingMode::Auto().isUseOccurrencePenalty());
static_assert(!DecodingMode::Auto().isUseStopCriteria());
static_assert(!DecodingMode::Auto().isTopK());
static_assert(!DecodingMode::Auto().isTopP());
static_assert(!DecodingMode::Auto().isBeamSearch());
static_assert(!DecodingMode::Auto().isMedusa());
static_assert(!DecodingMode::Auto().isLookahead());
static_assert(!DecodingMode::Auto().isExplicitDraftTokens());
static_assert(DecodingMode::TopK().isTopK());
static_assert(DecodingMode::TopK().isTopKorTopP());
static_assert(DecodingMode::TopK().isUseBanWords());
static_assert(DecodingMode::TopK().isUseOccurrencePenalty());
static_assert(DecodingMode::TopK().isUseStopCriteria());
static_assert(!DecodingMode::TopK().useRepetitionPenalty(false).isUseRepetitionPenalty());
static_assert(DecodingMode::TopK().useRepetitionPenalty(false).isUseOccurrencePenalty());
static_assert(!DecodingMode::TopK()
.useRepetitionPenalty(false)
.usePresencePenalty(false)
.useFrequencyPenalty(false)
.isUseOccurrencePenalty());
static_assert(!DecodingMode::TopK().isTopKandTopP());
static_assert(!DecodingMode::TopK().isTopP());
static_assert(!DecodingMode::TopK().isBeamSearch());
static_assert(!DecodingMode::TopK().isMedusa());
static_assert(!DecodingMode::TopK().isLookahead());
static_assert(!DecodingMode::TopK().isAuto());
static_assert(!DecodingMode::TopK().isExplicitDraftTokens());
static_assert(DecodingMode::TopP().isTopP());
static_assert(DecodingMode::TopP().isTopKorTopP());
static_assert(DecodingMode::TopP().isUseBanWords());
static_assert(DecodingMode::TopP().isUseOccurrencePenalty());
static_assert(DecodingMode::TopP().isUseStopCriteria());
static_assert(!DecodingMode::TopP().isTopKandTopP());
static_assert(!DecodingMode::TopP().isTopK());
static_assert(!DecodingMode::TopP().isBeamSearch());
static_assert(!DecodingMode::TopP().isMedusa());
static_assert(!DecodingMode::TopP().isLookahead());
static_assert(!DecodingMode::TopP().isAuto());
static_assert(!DecodingMode::TopP().isExplicitDraftTokens());
static_assert(DecodingMode::TopKTopP().isTopK());
static_assert(DecodingMode::TopKTopP().isTopP());
static_assert(DecodingMode::TopKTopP().isTopKorTopP());
static_assert(DecodingMode::TopKTopP().isTopKandTopP());
static_assert(DecodingMode::TopKTopP().isUseBanWords());
static_assert(DecodingMode::TopKTopP().isUseOccurrencePenalty());
static_assert(DecodingMode::TopKTopP().isUseStopCriteria());
static_assert(!DecodingMode::TopKTopP().isBeamSearch());
static_assert(!DecodingMode::TopKTopP().isMedusa());
static_assert(!DecodingMode::TopKTopP().isLookahead());
static_assert(!DecodingMode::TopKTopP().isAuto());
static_assert(!DecodingMode::TopKTopP().isExplicitDraftTokens());
static_assert(DecodingMode::BeamSearch().isBeamSearch());
static_assert(DecodingMode::BeamSearch().isUseStopCriteria());
static_assert(!DecodingMode::BeamSearch().isTopKorTopP());
static_assert(!DecodingMode::BeamSearch().isMedusa());
static_assert(!DecodingMode::BeamSearch().isLookahead());
static_assert(!DecodingMode::BeamSearch().isAuto());
static_assert(!DecodingMode::BeamSearch().isExplicitDraftTokens());
static_assert(!DecodingMode::Medusa().isTopK());
static_assert(!DecodingMode::Medusa().isTopKorTopP());
static_assert(!DecodingMode::Medusa().isTopKandTopP());
static_assert(!DecodingMode::Medusa().isTopP());
static_assert(!DecodingMode::Medusa().isBeamSearch());
static_assert(!DecodingMode::Medusa().isLookahead());
static_assert(!DecodingMode::Medusa().isAuto());
static_assert(!DecodingMode::Medusa().isUseBanWords());
static_assert(!DecodingMode::Medusa().isUseOccurrencePenalty());
static_assert(!DecodingMode::Medusa().isExplicitDraftTokens());
static_assert(DecodingMode::Medusa().isUseStopCriteria());
static_assert(!DecodingMode::Medusa().isUseStopWords());
static_assert(!DecodingMode::Medusa().isUseExplicitEosStop());
static_assert(DecodingMode::Medusa().isUsePenalty());
static_assert(DecodingMode::Medusa().isUseMinLength());
static_assert(DecodingMode::Medusa().isMedusa());
static_assert(!DecodingMode::Lookahead().isAuto());
static_assert(!DecodingMode::Lookahead().isTopK());
static_assert(!DecodingMode::Lookahead().isTopKorTopP());
static_assert(!DecodingMode::Lookahead().isTopKandTopP());
static_assert(!DecodingMode::Lookahead().isTopP());
static_assert(!DecodingMode::Lookahead().isBeamSearch());
static_assert(!DecodingMode::Lookahead().isMedusa());
static_assert(!DecodingMode::Lookahead().isExplicitDraftTokens());
static_assert(DecodingMode::Lookahead().isUseStopCriteria());
static_assert(!DecodingMode::Lookahead().isUseStopWords());
static_assert(!DecodingMode::Lookahead().isUseExplicitEosStop());
static_assert(DecodingMode::Lookahead().isLookahead());
static_assert(!DecodingMode::ExplicitDraftTokens().isAuto());
static_assert(!DecodingMode::ExplicitDraftTokens().isTopK());
static_assert(!DecodingMode::ExplicitDraftTokens().isTopKorTopP());
static_assert(!DecodingMode::ExplicitDraftTokens().isTopKandTopP());
static_assert(!DecodingMode::ExplicitDraftTokens().isTopP());
static_assert(!DecodingMode::ExplicitDraftTokens().isBeamSearch());
static_assert(!DecodingMode::ExplicitDraftTokens().isMedusa());
static_assert(!DecodingMode::ExplicitDraftTokens().isLookahead());
static_assert(!DecodingMode::ExplicitDraftTokens().isUsePenalty());
static_assert(DecodingMode::ExplicitDraftTokens().isUseStopCriteria());
static_assert(DecodingMode::ExplicitDraftTokens().isUseMaxLengthStop());
static_assert(DecodingMode::ExplicitDraftTokens().isUseExplicitEosStop());
static_assert(!DecodingMode::ExplicitDraftTokens().isUseStopWords());
static_assert(!DecodingMode::ExplicitDraftTokens().isUseBanWords());
static_assert(DecodingMode::ExplicitDraftTokens().isExplicitDraftTokens());
} // namespace tensorrt_llm::executor
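The query helpers above are plain bit tests against `mState`, so a mode can be composed and inspected entirely at compile time, exactly like the `static_assert`s that follow the class. A minimal sketch (assuming the class is reachable via `tensorrt_llm/executor/types.h`, the include used elsewhere in this change):

```cpp
#include "tensorrt_llm/executor/types.h"

using tensorrt_llm::executor::DecodingMode;

// Start from the TopK factory and clear the three occurrence penalties; bits
// that TopK() sets for ban words and stop criteria are left untouched.
constexpr auto kMode = DecodingMode::TopK()
                           .useRepetitionPenalty(false)
                           .usePresencePenalty(false)
                           .useFrequencyPenalty(false);

static_assert(kMode.isTopKorTopP());
static_assert(!kMode.isUseOccurrencePenalty()); // all occurrence-penalty bits cleared
static_assert(kMode.isUseBanWords());           // unrelated flags unaffected
static_assert(kMode.isUseStopCriteria());
```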

View File

@ -81,10 +81,10 @@ public:
class MedusaInputs
{
public:
TensorPtr medusaPaths; // [maxBatchSize, maxTokensPerStep, maxMedusaHeads + 1], on gpu
TensorPtr medusaTreeIds; // [maxBatchSize, maxTokensPerStep], on gpu
TensorPtr medusaPaths; // [maxBatchSize, maxTokensPerStep, maxMedusaHeads + 1], on gpu
TensorPtr medusaTreeIds; // [maxBatchSize, maxTokensPerStep], on gpu
std::vector<std::vector<TensorPtr>>
medusaLogits; // [maxBatchSize][maxMedusaHeads][tokensPerStep, vocabSizePadded], on gpu
medusaLogits; // [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded], on gpu
TensorPtr medusaCurTokensPerStep; // [maxBatchSize], on gpu
TensorPtr medusaTargetTokensPerStep; // [maxBatchSize], on gpu
};

View File

@ -1,190 +0,0 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/executor/executor.h"
namespace tensorrt_llm
{
namespace runtime
{
class DecodingMode
{
public:
static auto constexpr None()
{
return DecodingMode{kNone};
}
static auto constexpr TopK()
{
return DecodingMode{kTopK};
}
static auto constexpr TopP()
{
return DecodingMode{kTopP};
}
static auto constexpr TopKTopP()
{
return DecodingMode{kTopKTopP};
}
static auto constexpr BeamSearch()
{
return DecodingMode{kBeamSearch};
}
static auto constexpr Medusa()
{
return DecodingMode{kMedusa};
}
bool constexpr isNone() const
{
return mState == 0;
}
bool constexpr isTopK() const
{
return anyBitSet(kTopK);
}
bool constexpr isTopP() const
{
return anyBitSet(kTopP);
}
bool constexpr isTopKorTopP() const
{
return anyBitSet(kTopKTopP);
}
bool constexpr isTopKandTopP() const
{
return allBitSet(kTopKTopP);
}
bool constexpr isBeamSearch() const
{
return anyBitSet(kBeamSearch);
}
bool constexpr isMedusa() const
{
return anyBitSet(kMedusa);
}
using UnderlyingType = uint8_t;
bool operator==(DecodingMode const& other) const
{
return mState == other.mState;
}
static DecodingMode fromExecutor(executor::DecodingMode decodingMode)
{
switch (decodingMode)
{
case executor::DecodingMode::kNONE: return DecodingMode::None();
case executor::DecodingMode::kTOP_K: return DecodingMode::TopK();
case executor::DecodingMode::kTOP_P: return DecodingMode::TopP();
case executor::DecodingMode::kBEAM_SEARCH: return DecodingMode::BeamSearch();
case executor::DecodingMode::kMEDUSA: return DecodingMode::Medusa();
case executor::DecodingMode::kTOP_K_TOP_P: return DecodingMode::TopKTopP();
default: TLLM_THROW("Invalid decoding mode"); break;
}
}
friend std::ostream& operator<<(std::ostream& os, DecodingMode other);
private:
constexpr DecodingMode(UnderlyingType state)
: mState(state)
{
}
// No mode specified. Config will be determined from the beam width of the first request at runtime
// TopKTopP if beamWidth == 1, BeamSearch otherwise
static UnderlyingType constexpr kNone{0};
static UnderlyingType constexpr kTopK{1u << 0};
static UnderlyingType constexpr kTopP{1u << 1};
static UnderlyingType constexpr kBeamSearch{1u << 2};
static UnderlyingType constexpr kMedusa{1u << 3};
static UnderlyingType constexpr kTopKTopP{kTopK | kTopP};
bool constexpr anyBitSet(UnderlyingType bits) const
{
return (mState & bits) != 0;
}
bool constexpr allBitSet(UnderlyingType bits) const
{
return (mState & bits) == bits;
}
UnderlyingType mState{};
};
static_assert(DecodingMode::None().isNone());
static_assert(!DecodingMode::None().isTopK());
static_assert(!DecodingMode::None().isTopP());
static_assert(!DecodingMode::None().isBeamSearch());
static_assert(!DecodingMode::None().isMedusa());
static_assert(DecodingMode::TopK().isTopK());
static_assert(DecodingMode::TopK().isTopKorTopP());
static_assert(!DecodingMode::TopK().isTopKandTopP());
static_assert(!DecodingMode::TopK().isTopP());
static_assert(!DecodingMode::TopK().isBeamSearch());
static_assert(!DecodingMode::TopK().isMedusa());
static_assert(DecodingMode::TopP().isTopP());
static_assert(DecodingMode::TopP().isTopKorTopP());
static_assert(!DecodingMode::TopP().isTopKandTopP());
static_assert(!DecodingMode::TopP().isTopK());
static_assert(!DecodingMode::TopP().isBeamSearch());
static_assert(!DecodingMode::TopP().isMedusa());
static_assert(DecodingMode::TopKTopP().isTopK());
static_assert(DecodingMode::TopKTopP().isTopP());
static_assert(DecodingMode::TopKTopP().isTopKorTopP());
static_assert(DecodingMode::TopKTopP().isTopKandTopP());
static_assert(!DecodingMode::TopKTopP().isBeamSearch());
static_assert(!DecodingMode::TopKTopP().isMedusa());
static_assert(DecodingMode::BeamSearch().isBeamSearch());
static_assert(!DecodingMode::BeamSearch().isTopKorTopP());
static_assert(!DecodingMode::BeamSearch().isMedusa());
static_assert(!DecodingMode::Medusa().isTopK());
static_assert(!DecodingMode::Medusa().isTopKorTopP());
static_assert(!DecodingMode::Medusa().isTopKandTopP());
static_assert(!DecodingMode::Medusa().isTopP());
static_assert(!DecodingMode::Medusa().isBeamSearch());
static_assert(DecodingMode::Medusa().isMedusa());
} // namespace runtime
} // namespace tensorrt_llm

View File

@ -88,17 +88,17 @@ public:
BeamHypotheses beamHypotheses;
// Medusa
class MedusaOutputs
// Speculative decoding
class SpeculativeDecodingOutputs
{
public:
TensorPtr medusaNextDraftTokens; // [maxBatchSize, maxTokensPerStep]
TensorPtr medusaAcceptedTokensLen; // [maxBatchSize]
TensorPtr medusaAcceptedLengthsCumSum; // [maxBatchSize + 1]
TensorPtr medusaPathsOffsets; // [maxBatchSize * maxNumHeads]
TensorPtr nextDraftTokens; // [maxBatchSize, maxDraftTokens]
TensorPtr acceptedTokensLen; // [maxBatchSize]
TensorPtr acceptedLengthsCumSum; // [maxBatchSize + 1]
TensorPtr pathsOffsets; // [maxBatchSize, maxAcceptedDraftTokensPerStep]
};
std::optional<MedusaOutputs> medusaOutputs;
std::optional<SpeculativeDecodingOutputs> speculativeDecodingOutputs;
};
} // namespace tensorrt_llm::runtime

View File

@ -16,10 +16,10 @@
#pragma once
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/decodingOutput.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
@ -50,14 +50,14 @@ public:
virtual ~IGptDecoder() = default;
virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType32 maxSequenceLength,
virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize,
std::optional<TensorPtr> const& batchSlots = std::nullopt)
= 0;
virtual bool forward(DecodingOutput& output, DecodingInput const& input) = 0;
virtual void forwardAsync(DecodingOutput& output, DecodingInput const& input) = 0;
virtual void forwardSync(DecodingOutput& output, DecodingInput const& input) = 0;
virtual void gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput,
DecodingInput const& decodingInput, BufferManager const& manager)
= 0;
@ -74,10 +74,10 @@ public:
SizeType32 vocabSize, SizeType32 vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
curandState_t* curandState, BufferManager::CudaStreamPtr const& stream);
static std::unique_ptr<IGptDecoder> create(DecodingMode const& mode, nvinfer1::DataType dtype, size_t maxBatchSize,
size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
static std::unique_ptr<IGptDecoder> create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
BufferManager::CudaStreamPtr const& stream, std::optional<runtime::SizeType32> maxTokensPerStep = std::nullopt,
std::optional<runtime::SizeType32> maxNumMedusaHeads = std::nullopt);
std::optional<runtime::SizeType32> maxAcceptedDraftTokensPerStep = std::nullopt);
};
template <typename T>
@ -88,18 +88,18 @@ public:
using CudaStreamPtr = BufferManager::CudaStreamPtr;
using TensorPtr = std::shared_ptr<ITensor>;
GptDecoder(DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize,
GptDecoder(executor::DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize,
size_t vocabSizePadded, size_t maxSequenceLength, CudaStreamPtr const& stream,
std::optional<runtime::SizeType32> maxTokensPerStep = std::nullopt,
std::optional<runtime::SizeType32> maxNumMedusaHeads = std::nullopt);
std::optional<runtime::SizeType32> maxAcceptedDraftTokensPerStep = std::nullopt);
void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType32 maxSequenceLength,
void setup(SamplingConfig const& samplingConfig, size_t batchSize,
std::optional<TensorPtr> const& batchSlots = std::nullopt) override;
bool forward(DecodingOutput& output, DecodingInput const& input) override;
void forwardAsync(DecodingOutput& output, DecodingInput const& input) override;
void forwardSync(DecodingOutput& output, DecodingInput const& input) override;
void gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput, DecodingInput const& decodingInput,
BufferManager const& manager) override;
@ -119,19 +119,19 @@ private:
size_t mMaxBatchSize;
};
inline std::unique_ptr<IGptDecoder> IGptDecoder::create(DecodingMode const& mode, nvinfer1::DataType dtype,
inline std::unique_ptr<IGptDecoder> IGptDecoder::create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
BufferManager::CudaStreamPtr const& stream, std::optional<runtime::SizeType32> maxTokensPerStep,
std::optional<runtime::SizeType32> maxNumMedusaHeads)
std::optional<runtime::SizeType32> maxAcceptedDraftTokensPerStep)
{
switch (dtype)
{
case nvinfer1::DataType::kFLOAT:
return std::make_unique<GptDecoder<float>>(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
maxSequenceLength, stream, maxTokensPerStep, maxNumMedusaHeads);
maxSequenceLength, stream, maxTokensPerStep, maxAcceptedDraftTokensPerStep);
case nvinfer1::DataType::kHALF:
return std::make_unique<GptDecoder<half>>(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
maxSequenceLength, stream, maxTokensPerStep, maxNumMedusaHeads);
maxSequenceLength, stream, maxTokensPerStep, maxAcceptedDraftTokensPerStep);
default: return nullptr;
}
}

View File

@ -41,7 +41,7 @@ public:
GptDecoderBatch(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream);
//! Set up the decoder before calling `forward()`
void setup(DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, SizeType32 maxSequenceLength,
SizeType32 maxTokensPerStep, bool fusedDecoder, nvinfer1::DataType dtype,
ModelConfig const& modelConfig) override;
@ -56,6 +56,9 @@ public:
void forwardSync(decoder_batch::Token const& token) override;
void forwardSync(
decoder_batch::Token const& token, decoder_batch::Output& output, decoder_batch::Input const& input) override;
void forwardAsync(decoder::Output& output, decoder::Input const& input) override;
void forwardSync() override;
@ -154,22 +157,22 @@ public:
return mFinishedSum;
}
//! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu
//! @returns [batchSize, maxDraftTokens], predicted draft tokens for next step, on gpu
[[nodiscard]] TensorPtr getNextDraftTokens() const override
{
return mJointDecodingOutput->medusaOutputs->medusaNextDraftTokens;
return mJointDecodingOutput->speculativeDecodingOutputs->nextDraftTokens;
}
//! @returns [batchSize + 1], exclusive sum of accepted draft token lengths, on gpu
[[nodiscard]] TensorPtr getMedusaAcceptedLengthsCumSum() const override
[[nodiscard]] TensorPtr getSpecDecodingAcceptedLengthsCumSum() const override
{
return mJointDecodingOutput->medusaOutputs->medusaAcceptedLengthsCumSum;
return mJointDecodingOutput->speculativeDecodingOutputs->acceptedLengthsCumSum;
}
//! @returns [batchSize * maxMedusaHeads], accepted paths packed into continuous tensor, on gpu
[[nodiscard]] TensorPtr getMedusaAcceptedPackedPaths() const override
//! @returns [batchSize, maxAcceptedDraftTokensPerStep], accepted paths packed into continuous tensor, on gpu
[[nodiscard]] TensorPtr getSpecDecodingAcceptedPackedPaths() const override
{
return mJointDecodingOutput->medusaOutputs->medusaPathsOffsets;
return mJointDecodingOutput->speculativeDecodingOutputs->pathsOffsets;
}
private:
@ -189,16 +192,30 @@ private:
void newRequestSpeculativeDecoding(
SizeType32 batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);
//! @brief Sets up decoder internal tensors for a new request in draft-model Sps mode
void newRequestDraftTokensExternal(
SizeType32 batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);
//! @brief Sets up decoder internal tensors for a new Medusa request
void newRequestMedusa(SizeType32 batchIdx, decoder_batch::Request const& request);
//! @brief Asynchronously calls unfused decoder for whole batch in loop
void forwardAsyncUnfusedDecoder(
SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input, CudaEvent const& eventStart);
//! @brief Sets up decoder internal tensors for a new Lookahead request
void newRequestLookahead(SizeType32 batchIdx, decoder_batch::Request const& request);
//! @brief Asynchronously calls fused decoder for whole batch
void forwardAsyncFusedDecoder(
SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input, CudaEvent const& eventStart);
//! @brief Updates finished state on host for all active requests
void updateFinished(decoder_batch::Token const& token);
//! @brief Calls unfused or fused decoders for tokens per engine step
void forwardDispatch(
decoder_batch::Output& output, decoder_batch::Input const& input, std::optional<CudaEvent> const& eventStart);
//! @brief Calls unfused decoder for whole batch in loop
void forwardUnfusedDecoder(SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input,
std::optional<CudaEvent> const& eventStart);
//! @brief Calls fused decoder for whole batch
void forwardFusedDecoder(SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input,
std::optional<CudaEvent> const& eventStart);
private:
std::size_t const mVocabSize;
@ -255,6 +272,6 @@ private:
SizeType32 mMaxTokensPerDecoderStep{};
bool mFusedDecoder{false};
bool mUseMedusa{false};
SpeculativeDecodingMode mSpeculativeDecodingMode;
};
} // namespace tensorrt_llm::runtime
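For readers unfamiliar with the layout behind `getSpecDecodingAcceptedLengthsCumSum()`: the `[batchSize + 1]` tensor is an exclusive prefix sum over the per-request accepted draft-token lengths, so consecutive entries bracket each request's slice of the packed paths/tokens. A host-side sketch with made-up lengths (illustration only, not the library's kernel):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

int main()
{
    // Accepted draft-token lengths for a batch of three requests (made up).
    std::vector<int32_t> acceptedLengths{2, 0, 3};

    // Exclusive prefix sum -> {0, 2, 2, 5}; entries [r] and [r + 1] delimit
    // the slice of the packed tensor that belongs to request r.
    std::vector<int32_t> cumSum(acceptedLengths.size() + 1, 0);
    for (std::size_t r = 0; r < acceptedLengths.size(); ++r)
    {
        cumSum[r + 1] = cumSum[r] + acceptedLengths[r];
    }
    return 0;
}
```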

View File

@ -24,10 +24,10 @@
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/generationInput.h"
#include "tensorrt_llm/runtime/generationOutput.h"
#include "tensorrt_llm/runtime/iTensor.h"
@ -112,7 +112,7 @@ public:
// The micro batch size to be used in generation phase.
// Batches entered in `GptSession::generation` will be split into smaller micro batches of this size.
std::optional<SizeType32> genMicroBatchSize = std::nullopt;
std::optional<DecodingMode> decodingMode = std::nullopt;
std::optional<executor::DecodingMode> decodingMode = std::nullopt;
bool normalizeLogProbs = true;
};
@ -255,7 +255,7 @@ private:
void createBuffers(SizeType32 numMicroBatches);
void createDecoders(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 maxAttentionWindow,
SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, nvinfer1::DataType logitsType, bool decoderPerRequest,
SizeType32 numMicroBatches, DecodingMode const& decodingMode);
SizeType32 numMicroBatches, executor::DecodingMode const& decodingMode);
void createKvCacheManager(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 maxAttentionWindow,
SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, KvCacheConfig const& config);
void createCustomAllReduceWorkspace(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 maxSequenceLength);

View File

@ -44,8 +44,6 @@ public:
, inputLen(inputLen)
, maxNewTokens{maxNewTokens}
, endId{endId}
, computeCumLogProbs(false)
, computeLogProbs(false)
, generatedTokensPerEngineStep(1)
{
}
@ -64,11 +62,9 @@ public:
TensorPtr badWordsList; // [2, badWordsLength], on gpu
TensorPtr stopWordsList; // [2, stopWordsLength], on gpu
bool computeCumLogProbs; // boolean that controls if cumLogProbs should be computed for that request
bool computeLogProbs; // boolean that controls if cumLogProbs should be computed for that request
SizeType32 generatedTokensPerEngineStep;
TensorPtr medusaPaths; // [tokensPerStep, medusaHeads + 1], on gpu
TensorPtr medusaTreeIds; // [tokensPerStep], on gpu
TensorPtr medusaPaths; // [maxDraftTokens + 1, maxAcceptedDraftTokensPerStep + 1], on gpu
TensorPtr medusaTreeIds; // [maxDraftTokens + 1], on gpu
};
class Input
@ -112,7 +108,7 @@ public:
TensorConstPtr cacheIndirection; // [batchSize, maxBeamWidth, maxSeqLen] - indices into KV cache of different rays
// within one beam for beam search, on gpu
std::vector<std::vector<TensorConstPtr>>
medusaLogits; // [maxBatchSize][maxNumHeads][tokensPerStep, vocabSizePadded]
predictedDraftLogits; // [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded]
};
using Output = decoder::Output;
@ -142,6 +138,11 @@ public:
//! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
virtual TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0;
//! @brief Call decoder forwardSync and wait for the call to `forwardAsync` associated with a token to complete.
virtual void forwardSync(
decoder_batch::Token const& token, decoder_batch::Output& output, decoder_batch::Input const& input)
= 0;
//! @brief Wait for the call to `forwardAsync` associated with a token to complete.
virtual void forwardSync(decoder_batch::Token const& token) = 0;
@ -188,10 +189,10 @@ public:
virtual TensorPtr getNextDraftTokens() const = 0;
//! @returns [batchSize + 1], exclusive sum of accepted draft token lengths, on gpu
virtual TensorPtr getMedusaAcceptedLengthsCumSum() const = 0;
virtual TensorPtr getSpecDecodingAcceptedLengthsCumSum() const = 0;
//! @returns [batchSize * maxMedusaHeads], accepted paths packed into continuous tensor, on gpu
virtual TensorPtr getMedusaAcceptedPackedPaths() const = 0;
//! @returns [batchSize, maxAcceptedDraftTokensPerStep], accepted paths packed into continuous tensor, on gpu
virtual TensorPtr getSpecDecodingAcceptedPackedPaths() const = 0;
protected:
IGptDecoderBatch() = default;

View File

@ -16,8 +16,8 @@
#pragma once
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/generationInput.h"
#include "tensorrt_llm/runtime/generationOutput.h"
#include "tensorrt_llm/runtime/iTensor.h"
@ -73,7 +73,7 @@ public:
using TensorPtr = std::shared_ptr<ITensor>;
//! Set up the decoder before calling `forward()`; this also calls `reshapeBuffers`
virtual void setup(DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
virtual void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, SizeType32 maxSequenceLength,
SizeType32 maxTokensPerStep, bool fusedDecoder, nvinfer1::DataType dtype, ModelConfig const& modelConfig)
= 0;

View File

@ -55,6 +55,7 @@ private:
SizeType32 mTpRank;
std::vector<void*> mCommPtrs;
BufferPtr mBuffer;
bool mOpenIpc;
};
class AllReduceBuffers

View File

@ -18,8 +18,10 @@
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/lookaheadModule.h"
#include "tensorrt_llm/runtime/loraModule.h"
#include "tensorrt_llm/runtime/medusaModule.h"
#include "tensorrt_llm/runtime/speculativeDecodingMode.h"
#include <NvInferRuntime.h>
namespace tensorrt_llm::runtime
@ -81,10 +83,10 @@ public:
, mUseXQA{false}
, mUseLoraPlugin(false)
, mMlpHiddenSize(0)
, mMedusaModule(std::nullopt)
, mUseCrossAttention(false)
, mUsePositionEmbedding(true) // TODO: remove these two properties?
, mUseTokenTypeEmbedding(false)
, mSpeculativeDecodingMode(SpeculativeDecodingMode::None())
{
}
@ -441,19 +443,38 @@ public:
mMaxLoraRank = maxLoraRank;
}
[[nodiscard]] bool constexpr useMedusa() const noexcept
void setSpeculativeDecodingMode(SpeculativeDecodingMode mode) noexcept
{
return mMedusaModule.has_value();
mSpeculativeDecodingMode = mode;
}
[[nodiscard]] std::optional<MedusaModule> getMedusaModule() const noexcept
[[nodiscard]] bool hasSpeculativeDecodingModule() const noexcept
{
return mMedusaModule;
return mSpeculativeDecodingModule != nullptr;
}
void setMedusaModule(MedusaModule const& medusaModule) noexcept
[[nodiscard]] SpeculativeDecodingModule const& getSpeculativeDecodingModule() const noexcept
{
mMedusaModule = medusaModule;
TLLM_CHECK_WITH_INFO(mSpeculativeDecodingModule, "Speculative decoding module is not set");
return *mSpeculativeDecodingModule;
}
[[nodiscard]] std::shared_ptr<SpeculativeDecodingModule const> getSpeculativeDecodingModulePtr() const noexcept
{
TLLM_CHECK_WITH_INFO(mSpeculativeDecodingModule, "Speculative decoding module is not set");
return mSpeculativeDecodingModule;
}
[[nodiscard]] std::shared_ptr<SpeculativeDecodingModule> getSpeculativeDecodingModulePtr() noexcept
{
TLLM_CHECK_WITH_INFO(mSpeculativeDecodingModule, "Speculative decoding module is not set");
return mSpeculativeDecodingModule;
}
void setSpeculativeDecodingModule(
std::shared_ptr<SpeculativeDecodingModule> const& speculativeDecodingModule) noexcept
{
mSpeculativeDecodingModule = speculativeDecodingModule;
}
[[nodiscard]] nvinfer1::DataType getKvDataType() const noexcept
@ -508,6 +529,11 @@ public:
mLayerTypes = layerTypes;
}
[[nodiscard]] SpeculativeDecodingMode getSpeculativeDecodingMode() const noexcept
{
return mSpeculativeDecodingMode;
}
private:
SizeType32 mVocabSize;
SizeType32 mNbAttentionLayers;
@ -546,7 +572,6 @@ private:
SizeType32 mMlpHiddenSize;
SizeType32 mMaxLoraRank;
std::optional<MedusaModule> mMedusaModule;
std::optional<RnnConfig> mRnnConfig;
// Configs related to encoder / enc-dec models
@ -556,6 +581,9 @@ private:
SizeType32 mFfnHiddenSize; // indicates encoder output hidden size
std::vector<LayerType> mLayerTypes;
// Speculative decoding members
std::shared_ptr<SpeculativeDecodingModule> mSpeculativeDecodingModule;
SpeculativeDecodingMode mSpeculativeDecodingMode;
};
} // namespace tensorrt_llm::runtime
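A short, illustrative sketch of how the new accessors fit together. The `ModelConfig` instance and the concrete `SpeculativeDecodingModule` are assumed to be created elsewhere (e.g. during engine loading); this is not the actual loading code path:

```cpp
#include <memory>

#include "tensorrt_llm/runtime/modelConfig.h"

using namespace tensorrt_llm::runtime;

void enableMedusa(ModelConfig& modelConfig, std::shared_ptr<SpeculativeDecodingModule> const& module)
{
    modelConfig.setSpeculativeDecodingMode(SpeculativeDecodingMode::Medusa());
    modelConfig.setSpeculativeDecodingModule(module);

    if (modelConfig.getSpeculativeDecodingMode().isMedusa())
    {
        // Throws via TLLM_CHECK_WITH_INFO if the module was not set.
        auto const& specModule = modelConfig.getSpeculativeDecodingModule();
        (void) specModule;
    }
}
```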

View File

@ -16,6 +16,7 @@
#pragma once
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/runtime/common.h"
@ -75,6 +76,35 @@ private:
template <typename T>
using Vec = std::vector<T>;
template <typename T>
bool validateVec(std::string name, OptVec<T> const& vec, T min, std::optional<T> max = std::nullopt)
{
bool valid{true};
if (vec)
{
valid = std::all_of(vec->begin(), vec->end(),
[min, max](T elem)
{ return min < elem && ((max.has_value() && elem <= max.value()) || (!max.has_value())); });
if (!valid)
{
std::stringstream ss;
ss << "Incorrect sampling param. " << name << " is out of range (";
ss << min << ", ";
if (max.has_value())
{
ss << max.value();
}
else
{
ss << "inf";
}
ss << "]";
TLLM_LOG_WARNING(ss.str());
}
}
return valid;
}
public:
explicit SamplingConfig(SizeType32 beamWidth = 1)
: beamWidth{beamWidth}
@ -129,19 +159,24 @@ public:
topKMedusaHeads = fuseValues<std::vector<SizeType32>>(
configs, [&configs](size_t ci) { return configs[ci].topKMedusaHeads; },
layers::DefaultDecodingParams::getTopKMedusaHeads());
outputLogProbs = fuseValues<bool>(
configs, [&configs](size_t ci) { return configs[ci].outputLogProbs; }, false);
cumLogProbs = fuseValues<bool>(
configs, [&configs](size_t ci) { return configs[ci].cumLogProbs; }, false);
// Only used for tests.
draftAcceptanceThreshold = fuseValues<FloatType>(
configs, [&configs](size_t ci) { return configs[ci].draftAcceptanceThreshold; }, 0);
}
explicit SamplingConfig(executor::SamplingConfig const& samplingConfig,
std::optional<executor::SpeculativeDecodingConfig> const& specDecodingConfig)
std::optional<executor::ExternalDraftTokensConfig> const& externalDraftTokensConfig)
: beamWidth{samplingConfig.getBeamWidth()}
{
if (specDecodingConfig && specDecodingConfig.value().getAcceptanceThreshold())
if (externalDraftTokensConfig && externalDraftTokensConfig.value().getAcceptanceThreshold())
{
draftAcceptanceThreshold = Vec<FloatType>{specDecodingConfig.value().getAcceptanceThreshold().value()};
draftAcceptanceThreshold
= Vec<FloatType>{externalDraftTokensConfig.value().getAcceptanceThreshold().value()};
}
#define SET_FROM_OPTIONAL(varName, VarName, VarType) \
@ -168,15 +203,69 @@ public:
#undef SET_FROM_OPTIONAL
}
bool validate()
{
auto constexpr fltEpsilon = std::numeric_limits<float>::epsilon();
bool valid{true};
valid &= (beamWidth > 0);
if (!valid)
{
TLLM_LOG_WARNING(
"Requested beam width %d is incorrect. Must be > 0. To deactivate beam search, set beamWidth to 1.",
beamWidth);
}
valid &= validateVec("topK", topK, -1);
valid &= validateVec("topP", topP, -fltEpsilon, {1.f});
valid &= validateVec("topPMin", topPMin, 0.f, {1.f});
valid &= validateVec("topPDecay", topPDecay, 0.f, {1.f});
valid &= validateVec("topPResetIds", topPResetIds, -1);
valid &= validateVec("temperature", temperature, -fltEpsilon);
valid &= validateVec("repetitionPenalty", repetitionPenalty, 0.f);
valid &= validateVec("minLength", minLength, -1);
valid &= validateVec("beamSearchDiversityRate", beamSearchDiversityRate, -fltEpsilon);
// Detect greedy sampling and overwrite params.
if (temperature)
{
for (size_t ti = 0; ti < temperature->size(); ++ti)
{
if (temperature->at(ti) == 0.f)
{
temperature->at(ti) = 1.0f;
if (topK)
{
topK->at(ti) = 1;
}
if (topP)
{
topP->at(ti) = 1.f;
}
}
}
}
return valid;
}
public:
SizeType32 beamWidth;
// penalties
OptVec<FloatType> temperature; // [1] or [batch_size] on cpu
OptVec<SizeType32> minLength; // [1] or [batch_size] on cpu
OptVec<FloatType> repetitionPenalty; // [1] or [batch_size] on cpu
OptVec<FloatType> presencePenalty; // [1] or [batch_size] on cpu
OptVec<FloatType> frequencyPenalty; // [1] or [batch_size] on cpu
// probs
OptVec<bool> outputLogProbs;
OptVec<bool> cumLogProbs;
// sampling layers
OptVec<SizeType32> topK; // [1] or [batch_size] on cpu
OptVec<FloatType> topP; // [1] or [batch_size] on cpu
@ -207,7 +296,8 @@ public:
&& topPResetIds == other.topPResetIds && beamSearchDiversityRate == other.beamSearchDiversityRate
&& lengthPenalty == other.lengthPenalty && earlyStopping == other.earlyStopping
&& draftAcceptanceThreshold == other.draftAcceptanceThreshold && topKMedusaHeads == other.topKMedusaHeads
&& normalizeLogProbs == other.normalizeLogProbs;
&& normalizeLogProbs == other.normalizeLogProbs && outputLogProbs == other.outputLogProbs
&& cumLogProbs == other.cumLogProbs;
}
};
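A minimal usage sketch of `validate()` (assuming the header is `tensorrt_llm/runtime/samplingConfig.h` and that `FloatType`/`SizeType32` resolve to `float`/`int32_t`): a per-request temperature of 0 is interpreted as a greedy request and rewritten to temperature 1, topK 1, topP 1.

```cpp
#include <cstdint>
#include <vector>

#include "tensorrt_llm/runtime/samplingConfig.h"

using tensorrt_llm::runtime::SamplingConfig;

bool exampleValidate()
{
    SamplingConfig config(/*beamWidth=*/1);
    config.temperature = std::vector<float>{0.f};  // greedy request
    config.topK = std::vector<std::int32_t>{50};
    config.topP = std::vector<float>{0.9f};

    bool const ok = config.validate();
    // After validation: temperature == {1.f}, topK == {1}, topP == {1.f}.
    return ok;
}
```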

View File

@ -30,9 +30,9 @@ public:
return SpeculativeDecodingMode{kNone};
}
static auto constexpr DraftModel()
static auto constexpr DraftTokensExternal()
{
return SpeculativeDecodingMode{kDraftModel};
return SpeculativeDecodingMode{kDraftTokensExternal};
}
static auto constexpr Medusa()
@ -50,9 +50,9 @@ public:
return anyBitSet(kNone);
}
bool constexpr isDraftModel() const
bool constexpr isDraftTokensExternal() const
{
return anyBitSet(kDraftModel);
return anyBitSet(kDraftTokensExternal);
}
bool constexpr isMedusa() const
@ -100,7 +100,7 @@ public:
private:
// No speculative decoding is used.
static UnderlyingType constexpr kNone{1u << 0};
static UnderlyingType constexpr kDraftModel{1u << 1};
static UnderlyingType constexpr kDraftTokensExternal{1u << 1};
static UnderlyingType constexpr kMedusa{1u << 2};
static UnderlyingType constexpr kLookaheadDecoding{1u << 3};
@ -118,23 +118,23 @@ private:
};
static_assert(SpeculativeDecodingMode::None().isNone());
static_assert(!SpeculativeDecodingMode::None().isDraftModel());
static_assert(!SpeculativeDecodingMode::None().isDraftTokensExternal());
static_assert(!SpeculativeDecodingMode::None().isMedusa());
static_assert(!SpeculativeDecodingMode::None().isLookaheadDecoding());
static_assert(SpeculativeDecodingMode::DraftModel().isDraftModel());
static_assert(!SpeculativeDecodingMode::DraftModel().isNone());
static_assert(!SpeculativeDecodingMode::DraftModel().isMedusa());
static_assert(!SpeculativeDecodingMode::DraftModel().isLookaheadDecoding());
static_assert(SpeculativeDecodingMode::DraftTokensExternal().isDraftTokensExternal());
static_assert(!SpeculativeDecodingMode::DraftTokensExternal().isNone());
static_assert(!SpeculativeDecodingMode::DraftTokensExternal().isMedusa());
static_assert(!SpeculativeDecodingMode::DraftTokensExternal().isLookaheadDecoding());
static_assert(SpeculativeDecodingMode::Medusa().isMedusa());
static_assert(!SpeculativeDecodingMode::Medusa().isNone());
static_assert(!SpeculativeDecodingMode::Medusa().isDraftModel());
static_assert(!SpeculativeDecodingMode::Medusa().isDraftTokensExternal());
static_assert(!SpeculativeDecodingMode::Medusa().isLookaheadDecoding());
static_assert(SpeculativeDecodingMode::LookaheadDecoding().isLookaheadDecoding());
static_assert(!SpeculativeDecodingMode::LookaheadDecoding().isNone());
static_assert(!SpeculativeDecodingMode::LookaheadDecoding().isDraftModel());
static_assert(!SpeculativeDecodingMode::LookaheadDecoding().isDraftTokensExternal());
static_assert(!SpeculativeDecodingMode::LookaheadDecoding().isMedusa());
} // namespace runtime

View File

@ -82,6 +82,11 @@ public:
return mDeviceIds[mRank % getGpusPerGroup()];
}
[[nodiscard]] SizeType32 getDeviceOf(SizeType32 rank) const noexcept
{
return mDeviceIds[rank % getGpusPerGroup()];
}
[[nodiscard]] SizeType32 constexpr getPipelineParallelRank() const noexcept
{
return mRank / mTensorParallelism;
@ -92,6 +97,21 @@ public:
return mRank % mTensorParallelism;
}
[[nodiscard]] SizeType32 constexpr getLocalRank() const noexcept
{
return mRank % mGpusPerNode;
}
[[nodiscard]] SizeType32 constexpr getNodeRank() const noexcept
{
return mRank / mGpusPerNode;
}
[[nodiscard]] SizeType32 constexpr getNodeRankOf(SizeType32 rank) const noexcept
{
return rank / mGpusPerNode;
}
[[nodiscard]] bool constexpr isFirstPipelineParallelRank() const noexcept
{
return getPipelineParallelRank() == 0;
@ -109,6 +129,7 @@ public:
}
[[nodiscard]] std::vector<SizeType32> getPipelineParallelGroup() const;
[[nodiscard]] std::vector<SizeType32> getTensorParallelGroup() const;
static WorldConfig mpi(SizeType32 gpusPerNode = kDefaultGpusPerNode,
std::optional<SizeType32> tensorParallelism = std::nullopt,
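A worked example of the node/local rank arithmetic behind `getNodeRank()`, `getNodeRankOf()` and `getLocalRank()`, written out as plain integer math rather than through `WorldConfig` itself: with 8 GPUs per node, global rank 10 lives on node 1 as local rank 2.

```cpp
#include <cassert>

int main()
{
    int const gpusPerNode = 8;
    int const rank = 10;

    assert(rank / gpusPerNode == 1); // getNodeRank() / getNodeRankOf(rank)
    assert(rank % gpusPerNode == 2); // getLocalRank()
    return 0;
}
```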

View File

@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
# Google Benchmark Preparation - Same as google test ../tests/CMakeLists.txt
# Google Benchmark is provided under Apache-2.0 license
include(FetchContent)
FetchContent_Declare(
googlebenchmark
GIT_REPOSITORY https://github.com/google/benchmark.git
GIT_TAG v1.8.3)
FetchContent_MakeAvailable(googlebenchmark)
add_custom_target(micro_benchmarks)
include_directories(
${PROJECT_SOURCE_DIR}/tensorrt_llm/cutlass_extensions/include
${PROJECT_SOURCE_DIR}/include)
set(TOP_LEVEL_DIR "${PROJECT_SOURCE_DIR}/..")
function(add_benchmark test_name test_src)
add_executable(${test_name} ${test_src})
message("Linking with ${SHARED_TARGET}")
target_link_libraries(${test_name} PUBLIC ${SHARED_TARGET}
benchmark::benchmark)
target_compile_features(${test_name} PRIVATE cxx_std_17)
target_compile_definitions(${test_name}
PUBLIC TOP_LEVEL_DIR="${TOP_LEVEL_DIR}")
add_dependencies(micro_benchmarks ${test_name})
endfunction()
add_benchmark(mixtureOfExpertsBackendBenchmark
mixtureOfExpertsBackendBenchmarkLauncher.cu)

View File

@ -0,0 +1,36 @@
# Micro Benchmarks
This folder contains benchmarks for specific components in TRT-LLM,
using [google-benchmark](https://github.com/google/benchmark/tree/main).
## Building
To build, add the `--micro_benchmark` flag to `build_wheel.py`, or pass `-DBUILD_MICRO_BENCHMARKS=ON` to CMake.
## Benchmark Documentation
### Mixture Of Experts Backend Benchmark
Target `mixtureOfExpertsBackendBenchmark`
This benchmark covers the backend used by the `MixtureOfExperts` plugin. It allows you to benchmark different MOE
configurations without building a TRT engine.
Usage:
```bash
./mixtureOfExpertsBackendBenchmark
# or
./mixtureOfExpertsBackendBenchmark --benchmark_file <JSON benchmark definition>
```
For more information see:
```
./mixtureOfExpertsBackendBenchmark --help
```
`gen-moe-workload-file.py` is a helper script that generates workload files for MOE benchmarks. This is useful
for sharing or comparing configurations, for example when producing a reproduction case for a performance bug.

View File

@ -0,0 +1,92 @@
import argparse
template = '''{{
"num_experts": {num_experts},
"k": {k},
"hidden_size": {hidden_size},
"inter_size": {inter_size},
"tp_size": {tp_size},
"ep_size": {ep_size},
"world_rank": {world_rank},
"num_tokens": {num_tokens},
"bias": 0,
"act_fn": {act_fn},
"norm_mode": {norm_mode},
{dtype_string}
{routing_string}
"tactic_id": {tactic_id}
}}'''
def make_dtype_string(dtypes=None):
if dtypes is None:
return ""
if not isinstance(dtypes, list):
dtypes = [dtypes]
join_term = '","' # Include quotes because they should be strings
return f'"dtypes": ["{join_term.join(dtypes)}"],'
def make_routing_string(name=None, values=None):
if values is None and name is None:
return ""
if values is None:
return f'"routing_values": "{name}",'
values = f'"routing_values": [{",".join(values)}],'
if name is not None:
values += f' "routing_values_name": "{name}",'
return values
def populate_benchmark_config(**kwargs):
return template.format(**kwargs)
# Default Mixtral configurations
num_experts = 8
k = 2
hidden_size = 4096
inter_size = 14336
tp_size = 4
ep_size = 1
world_rank = 0
act_fn = 3
norm_mode = 1
dtype_string = make_dtype_string() # All dtypes
routing_string = make_routing_string(
name="balanced") # Use the default uniform distribution
tactic_id = '"auto"'
configs = []
for num_tokens in [1, 8, 64, 2048, 65536]:
configs.append(
populate_benchmark_config(
num_experts=num_experts,
k=k,
hidden_size=hidden_size,
inter_size=inter_size,
tp_size=tp_size,
ep_size=ep_size,
world_rank=world_rank,
num_tokens=num_tokens,
act_fn=act_fn,
norm_mode=norm_mode,
dtype_string=dtype_string,
routing_string=routing_string,
tactic_id=tactic_id,
))
full_string = "[\n" + ",\n".join(configs) + "\n]"
parser = argparse.ArgumentParser()
parser.add_argument('filename',
type=str,
help='The name of the file to generate',
nargs='?',
default="moe-benchmark-file.json")
args = parser.parse_args()
with open(args.filename, "w+") as f:
f.write(full_string)

View File

@ -0,0 +1,482 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <benchmark/benchmark.h>
#include <nlohmann/json.hpp>
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h"
#include "tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include <algorithm>
#include <cuda.h>
#include <memory>
#include <numeric>
#include <random>
#include <type_traits>
#include <unordered_map>
#include <vector>
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::common;
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::cutlass_extensions;
static BufferManager::CudaStreamPtr streamPtr;
static std::unique_ptr<BufferManager> bufferManager;
static int deviceCount;
static char* workloadFile = nullptr;
constexpr bool VERBOSE = false;
namespace
{
// Abstract class for routing config
struct RoutingConfig
{
virtual void setRouting(float* routing_output, int64_t num_experts, int64_t k, int64_t num_tokens) = 0;
virtual std::string getName() = 0;
};
struct LoadBalancedRoutingConfig : public RoutingConfig
{
std::string getName() override
{
return "balanced";
}
void setRouting(float* routing_output, int64_t num_experts, int64_t k, int64_t num_tokens) override
{
nvinfer1::DataType type = nvinfer1::DataType::kFLOAT;
makeLoadBalancedRoutingConfiguration(routing_output, num_experts, num_tokens, k, type, streamPtr->get());
check_cuda_error(cudaStreamSynchronize(streamPtr->get()));
}
};
struct VectoredRoutingConfig : public RoutingConfig
{
std::vector<float> config;
std::pair<int64_t, int64_t> shape;
std::string name;
VectoredRoutingConfig(std::vector<float> config, std::pair<int64_t, int64_t> shape, std::string name = "vectored")
: config(config)
, shape(shape)
, name(name)
{
}
std::string getName() override
{
return name;
}
void setRouting(float* routing_output, int64_t num_experts, int64_t k, int64_t num_tokens) override
{
assert(shape.second == num_experts);
for (int64_t i = 0; i < num_tokens; i += shape.first)
{
int num_to_copy = std::min(num_tokens - i, shape.first);
check_cuda_error(cudaMemcpyAsync(routing_output + i * num_experts, config.data(),
num_to_copy * num_experts * sizeof(float), cudaMemcpyHostToDevice, streamPtr->get()));
}
check_cuda_error(cudaStreamSynchronize(streamPtr->get()));
}
};
}; // namespace
constexpr int LOAD_BALANCED_ROUTING_CONFIG = 0;
std::vector<std::shared_ptr<RoutingConfig>> routingConfigCache{
std::static_pointer_cast<RoutingConfig>(std::make_shared<LoadBalancedRoutingConfig>())};
#ifdef ENABLE_FP8
using SafeFP8 = __nv_fp8_e4m3;
#else
using SafeFP8 = void;
#endif
template <class TypeTuple_>
class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
{
public:
using DataType = typename TypeTuple_::DataType;
using WeightType = typename TypeTuple_::WeightType;
using OutputType = typename TypeTuple_::OutputType;
constexpr static bool INT4 = std::is_same_v<WeightType, cutlass::uint4b_t>;
constexpr static bool FP8 = std::is_same_v<DataType, SafeFP8>;
constexpr static bool INT_QUANT = !std::is_same_v<DataType, WeightType>;
using WeightStorage = std::conditional_t<INT_QUANT, uint8_t, WeightType>;
constexpr static int WEIGHT_ELEM_PER_BYTE = INT4 ? 2 : 1;
int const BASE_HIDDEN_SIZE = 64 / sizeof(WeightType) * WEIGHT_ELEM_PER_BYTE;
std::vector<BufferManager::IBufferPtr> managed_buffers;
float* mInputProbabilities{};
DataType* mInputTensor{};
int64_t mHiddenSize{};
int64_t mNumExperts{};
int64_t mK{};
constexpr static nvinfer1::DataType toDTypeID()
{
if (FP8)
return nvinfer1::DataType::kFP8;
if (INT_QUANT && INT4)
return nvinfer1::DataType::kUINT8; // Hack to distinguish int4, use unsigned
if (INT_QUANT)
return nvinfer1::DataType::kINT8;
if (std::is_same_v<DataType, float>)
return nvinfer1::DataType::kFLOAT;
if (std::is_same_v<DataType, half>)
return nvinfer1::DataType::kHALF;
#ifdef ENABLE_BF16
if (std::is_same_v<DataType, nv_bfloat16>)
return nvinfer1::DataType::kBF16;
#endif
return nvinfer1::DataType::kBOOL;
};
static bool shouldSkip()
{
#ifndef ENABLE_FP8
static_assert(!FP8, "FP8 Tests enabled on unsupported CUDA version");
#endif
bool should_skip_unsupported_fp8 = getSMVersion() < 90 && FP8;
return should_skip_unsupported_fp8;
}
// Deprecated, just here to suppress warnings
void SetUp(benchmark::State const& s) override
{
abort();
}
void TearDown(benchmark::State const& s) override
{
abort();
}
cudaEvent_t mStartEvent, mEndEvent;
void SetUp(benchmark::State& s) override
{
assert(bufferManager);
if (shouldSkip())
{
s.SkipWithMessage("GPU does not support FP8");
}
// Makes sure nothing from a previous iteration hangs around
check_cuda_error(cudaDeviceSynchronize());
check_cuda_error(cudaEventCreate(&mStartEvent));
check_cuda_error(cudaEventCreate(&mEndEvent));
}
void TearDown(benchmark::State& s) override
{
managed_buffers.clear();
check_cuda_error(cudaEventDestroy(mStartEvent));
check_cuda_error(cudaEventDestroy(mEndEvent));
check_cuda_error(cudaDeviceSynchronize());
}
CutlassMoeFCRunner<DataType, WeightType, OutputType> mMoERunner{};
char* mWorkspace{};
float* mScaleProbs{};
WeightStorage* mExpertWeight1{};
WeightStorage* mExpertWeight2{};
DataType* mExpertIntScale1{};
DataType* mExpertIntScale2{};
float* mExpertFP8Scale1{};
float* mExpertFP8Scale2{};
float* mExpertFP8Scale3{};
DataType* mExpertBias1{};
DataType* mExpertBias2{};
OutputType* mFinalOutput{};
int* mSourceToExpandedMap;
int* mSelectedExpert;
int64_t mInterSize{};
int64_t mTotalTokens{};
bool mUseBias = true;
bool mIsGated = false;
int mGatedMultiplier = 1;
tensorrt_llm::ActivationType mActType = tensorrt_llm::ActivationType::Relu;
MOEExpertScaleNormalizationMode mNormMode = MOEExpertScaleNormalizationMode::NONE;
QuantParams mQuantParams{};
std::optional<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> mSelectedConfig = std::nullopt;
template <class T>
T* allocBuffer(size_t size)
{
auto i_buffer = bufferManager->gpu(size * sizeof(T));
check_cuda_error(cudaGetLastError());
managed_buffers.emplace_back(std::move(i_buffer));
T* ptr = static_cast<T*>(managed_buffers.back()->data());
check_cuda_error(cudaMemsetAsync(ptr, 0x0, size * sizeof(T), streamPtr->get()));
return ptr;
}
void initBuffersPermute(int64_t num_tokens, int64_t hidden_size, int64_t inter_size, int64_t num_experts, int64_t k,
int64_t routing_config)
{
assert(hidden_size % BASE_HIDDEN_SIZE == 0);
managed_buffers.clear();
mTotalTokens = num_tokens;
mHiddenSize = hidden_size;
mInterSize = inter_size;
mNumExperts = num_experts;
mK = k;
mIsGated = tensorrt_llm::isGatedActivation(mActType);
mGatedMultiplier = mIsGated ? 2 : 1;
auto const gated_inter = mInterSize * mGatedMultiplier;
size_t workspace_size
= mMoERunner.getWorkspaceSize(mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mActType, {});
mWorkspace = allocBuffer<char>(workspace_size);
size_t const expert_matrix_size = mNumExperts * mHiddenSize * mInterSize;
mExpertWeight1
= allocBuffer<WeightStorage>(expert_matrix_size * mGatedMultiplier / WEIGHT_ELEM_PER_BYTE);
mExpertWeight2 = allocBuffer<WeightStorage>(expert_matrix_size / WEIGHT_ELEM_PER_BYTE);
mExpertBias1 = nullptr;
mExpertBias2 = nullptr;
if (mUseBias)
{
mExpertBias1 = allocBuffer<DataType>(mNumExperts * gated_inter);
mExpertBias2 = allocBuffer<DataType>(mNumExperts * mHiddenSize);
}
if constexpr (INT_QUANT)
{
mExpertIntScale1 = allocBuffer<DataType>(mNumExperts * gated_inter);
mExpertIntScale2 = allocBuffer<DataType>(mNumExperts * mHiddenSize);
mQuantParams = QuantParams::Int(mExpertIntScale1, mExpertIntScale2);
}
else if constexpr (FP8)
{
mExpertFP8Scale1 = allocBuffer<float>(mNumExperts);
mExpertFP8Scale2 = allocBuffer<float>(1);
mExpertFP8Scale3 = allocBuffer<float>(mNumExperts);
mQuantParams = QuantParams::FP8(mExpertFP8Scale1, mExpertFP8Scale2, mExpertFP8Scale3);
}
mInputProbabilities = allocBuffer<float>(mTotalTokens * mNumExperts);
mScaleProbs = allocBuffer<float>(mTotalTokens * mK);
mInputTensor = allocBuffer<DataType>(mTotalTokens * mHiddenSize);
mFinalOutput = allocBuffer<OutputType>(mTotalTokens * mHiddenSize);
mSourceToExpandedMap = allocBuffer<int>(mTotalTokens * mK);
mSelectedExpert = allocBuffer<int>(mTotalTokens * mK);
auto tactic = routingConfigCache.at(routing_config);
tactic->setRouting(mInputProbabilities, mNumExperts, mK, mTotalTokens);
check_cuda_error(cudaStreamSynchronize(streamPtr->get()));
}
float benchmarkLoop(MOEParallelismConfig parallelism_config)
{
check_cuda_error(cudaEventRecord(mStartEvent, streamPtr->get()));
runMoEPermute(parallelism_config);
check_cuda_error(cudaEventRecord(mEndEvent, streamPtr->get()));
check_cuda_error(cudaStreamSynchronize(streamPtr->get()));
float ms;
check_cuda_error(cudaEventElapsedTime(&ms, mStartEvent, mEndEvent));
return ms;
}
// An imprecise benchmark pass for picking the best tactic.
// Runs for 3 iterations or 1 second and picks the best option
int pickBestTactic(MOEParallelismConfig parallelism_config)
{
auto tactics = mMoERunner.getTactics();
float best_time = INFINITY;
int best_idx = -1;
for (int tidx = 0; tidx < tactics.size(); tidx++)
{
try
{
// Set the tactic
auto const& t = tactics[tidx];
mMoERunner.setTactic(t);
// Warm-Up run
benchmarkLoop(parallelism_config);
float const max_time_ms = 1000.f;
int const max_iters = 3;
float time = 0.f;
int iter = 0;
while (iter < max_iters && time < max_time_ms)
{
time += benchmarkLoop(parallelism_config);
iter++;
}
// Get average time per iteration
time /= static_cast<float>(iter);
// Update the best
if (time < best_time)
{
best_idx = tidx;
best_time = time;
}
}
catch (std::exception const& e)
{
// Sync to tidy up
if (VERBOSE)
std::cout << "Tactic failed to run with: " << e.what() << std::endl;
check_cuda_error(cudaDeviceSynchronize());
// skip invalid tactic
continue;
}
}
// Check if all tactics failed
if (best_idx < 0)
return -1;
auto const& best_tactic = tactics[best_idx];
mMoERunner.setTactic(best_tactic);
return best_idx;
}
int setTactic(int tactic_idx, MOEParallelismConfig parallelism_config)
{
if (tactic_idx == -1)
{
return pickBestTactic(parallelism_config);
}
auto tactics = mMoERunner.getTactics();
if (tactic_idx < 0 || tactic_idx >= tactics.size())
{
return -1;
}
auto selected_tactic = tactics[tactic_idx];
mMoERunner.setTactic(selected_tactic);
return tactic_idx;
}
void runMoEPermute(MOEParallelismConfig parallelism_config)
{
auto stream = streamPtr->get();
mMoERunner.runMoe(mInputTensor, mInputProbabilities, mExpertWeight1, mExpertBias1, mActType, mExpertWeight2,
mExpertBias2, mQuantParams, mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace,
mFinalOutput, nullptr, mTotalTokens, mScaleProbs, mSourceToExpandedMap, mSelectedExpert, parallelism_config,
mNormMode, stream);
}
void runBenchmark(benchmark::State& state);
};
template <class TypeTuple_>
void MixtureOfExpertsBenchmark<TypeTuple_>::runBenchmark(benchmark::State& state)
{
int const num_experts = state.range(0);
int const top_k = state.range(1);
int const hidden_size = state.range(2);
int const inter_size = state.range(3);
int const tp_size = state.range(4);
int const ep_size = state.range(5);
int const world_rank = state.range(6);
int const num_tokens = state.range(7);
mUseBias = state.range(8);
mActType = static_cast<tensorrt_llm::ActivationType>(state.range(9));
mNormMode = static_cast<MOEExpertScaleNormalizationMode>(state.range(10));
int tactic_idx = state.range(11);
int const routing_config = state.range(12);
state.counters["num_experts"] = num_experts;
state.counters["top_k"] = top_k;
state.counters["hidden_size"] = hidden_size;
state.counters["inter_size"] = inter_size;
state.counters["tp_size"] = tp_size;
state.counters["ep_size"] = ep_size;
state.counters["world_rank"] = world_rank;
state.counters["num_tokens"] = num_tokens;
state.counters["use_bias"] = (int) mUseBias;
state.counters["act_fn"] = (int) mActType;
state.counters["norm_mode"] = (int) mNormMode;
state.counters["routing_config"] = (int) routing_config;
state.counters["dtype"] = (int) toDTypeID();
std::stringstream ss;
ss << "Experts,K,Hidden,Inter,TP,EP,Rank,Tokens,Bias,Actfn,Norm Mode,Tactic,Routing=";
for (auto v : {num_experts, top_k, hidden_size, inter_size, tp_size, ep_size, world_rank, num_tokens,
(int) mUseBias, (int) mActType, (int) mNormMode, tactic_idx})
{
ss << v << ",";
}
ss << routingConfigCache.at(routing_config)->getName();
// state.SetLabel(ss.str());
// Always use EP size for the MoE config until we support TP+EP; for TP we just divide the inter size
MOEParallelismConfig parallelism_config = MOEParallelismConfig::ExpertParallelism(ep_size, world_rank / tp_size);
initBuffersPermute(num_tokens, hidden_size, inter_size / tp_size, num_experts, top_k, routing_config);
// Parse the tactic, does checks for "auto" mode and out of range
tactic_idx = setTactic(tactic_idx, parallelism_config);
if (tactic_idx < 0)
{
state.SkipWithMessage("Out of range tactic");
return;
}
if (VERBOSE)
{
auto tactics = mMoERunner.getTactics();
std::cout << "Selected " << tactic_idx << "/" << tactics.size() << "\n"
<< tactics[tactic_idx].toString() << std::endl;
}
state.counters["tactic_idx"] = tactic_idx;
for (auto _ : state)
{
float ms = benchmarkLoop(parallelism_config);
state.SetIterationTime(ms / 1000.f);
}
state.SetItemsProcessed(state.iterations() * num_tokens);
// Clean up all the benchmark state
managed_buffers.clear();
check_cuda_error(cudaDeviceSynchronize());
}

View File

@ -0,0 +1,786 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Include the fixture with the actual benchmark code
#include "mixtureOfExpertsBackendBenchmarkFixture.h"
/*
* Below is all the setup for parameterising the benchmarks
*/
template <class DataType_, class WeightType_ = DataType_, class OutputType_ = DataType_>
struct WeightParams
{
using DataType = DataType_;
using WeightType = WeightType_;
using OutputType = OutputType_;
};
#define BENCHMARK_BASIC(atype, wtype, otype) \
BENCHMARK_TEMPLATE_DEFINE_F(MixtureOfExpertsBenchmark, Basic_##atype##_##wtype, WeightParams<atype, wtype, otype>) \
(benchmark::State & state) \
{ \
runBenchmark(state); \
}
#define BENCHMARK_BASIC_DO_REGISTER(atype, wtype, otype) \
BENCHMARK_REGISTER_F(MixtureOfExpertsBenchmark, Basic_##atype##_##wtype) \
->Apply(argGen<MixtureOfExpertsBenchmark<WeightParams<atype, wtype, otype>>>)
template <class BenchClass>
auto listAllTactics()
{
int const sm = getSMVersion();
using RunnerType = decltype(BenchClass::mMoERunner);
return RunnerType::getTactics(sm);
}
template <class BenchClass>
int parseTacticToId(nlohmann::json tactic_config)
{
bool is_sm90 = tactic_config.at("is_sm90").get<bool>();
int tile_shape_id = -1;
std::array<int, 3> tile_shape;
if (tactic_config.at("tile_shape").is_array())
tactic_config.at("tile_shape").get_to(tile_shape);
else
tile_shape_id = tactic_config.at("tile_shape").get<int>();
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> confs = listAllTactics<BenchClass>();
try
{
for (int i = 0; i < confs.size(); i++)
{
auto const& c = confs[i];
if (c.is_sm90 != is_sm90)
continue;
if (!is_sm90)
{
int stages = tactic_config.at("stages").get<int>();
if (c.stages != stages)
continue;
}
if (tile_shape_id != -1)
{
int comp = is_sm90 ? (int) c.tile_config_sm90 : (int) c.tile_config;
if (tile_shape_id != comp)
continue;
if (is_sm90 && (int) c.cluster_shape != tactic_config.at("cluster_shape").get<int>())
continue;
// Found matching config
return i;
}
// Handle the case where the user provided a shape instead of the enum value
if (is_sm90)
{
using Kv = uint64_t;
constexpr static auto K = [](int m, int n) { return (uint64_t(m) << 32) | uint64_t(n); };
static std::unordered_map<Kv, CutlassTileConfigSM90> const tile_map{
{K(64, 16), CutlassTileConfigSM90::CtaShape64x16x128B},
{K(64, 32), CutlassTileConfigSM90::CtaShape64x32x128B},
{K(64, 64), CutlassTileConfigSM90::CtaShape64x64x128B},
{K(64, 128), CutlassTileConfigSM90::CtaShape64x128x128B},
{K(64, 256), CutlassTileConfigSM90::CtaShape64x256x128B},
{K(128, 16), CutlassTileConfigSM90::CtaShape128x16x128B},
{K(128, 32), CutlassTileConfigSM90::CtaShape128x32x128B},
{K(128, 64), CutlassTileConfigSM90::CtaShape128x64x128B},
{K(128, 128), CutlassTileConfigSM90::CtaShape128x128x128B},
{K(128, 256), CutlassTileConfigSM90::CtaShape128x256x128B},
};
if (c.tile_config_sm90 != tile_map.at(K(tile_shape[0], tile_shape[1])))
continue;
static std::unordered_map<Kv, ClusterShape> const cluster_map{
// Cluster shapes supported by this shape-based lookup
{K(1, 1), ClusterShape::ClusterShape_1x1x1},
{K(2, 1), ClusterShape::ClusterShape_2x1x1},
{K(1, 2), ClusterShape::ClusterShape_1x2x1},
{K(2, 2), ClusterShape::ClusterShape_2x2x1},
};
std::array<int, 3> cluster_shape;
tactic_config.at("cluster_shape").get_to(cluster_shape);
if (c.cluster_shape != cluster_map.at(K(cluster_shape[0], cluster_shape[1])))
continue;
// Found matching config
return i;
}
else
{
std::array<int, 3> warp_shape;
tactic_config.at("warp_shape").get_to(warp_shape);
using Kv = uint64_t;
constexpr static auto K = [](std::array<int, 3> a, std::array<int, 3> b)
{
uint64_t sum = 0;
for (auto v : a)
sum = sum * 512 + v;
for (auto v : b)
sum = sum * 256 + v;
return sum;
};
static std::unordered_map<Kv, CutlassTileConfig> tile_map{
{K({128, 128, 8}, {64, 64, 8}), CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8},
{K({16, 128, 64}, {16, 32, 64}), CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64},
{K({32, 128, 64}, {32, 32, 64}), CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64},
{K({64, 128, 64}, {32, 64, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64},
{K({64, 64, 128}, {32, 64, 64}), CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64},
{K({64, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64},
{K({128, 64, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64},
{K({128, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64},
{K({128, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64},
{K({128, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64},
{K({128, 256, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64},
{K({256, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64},
{K({16, 256, 64}, {16, 64, 64}), CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64}
};
if (c.tile_config != tile_map.at(K(tile_shape, warp_shape)))
continue;
// Found matching config
return i;
}
}
}
catch (std::out_of_range const& e)
{
std::cerr << "Warning: error parsing tactic " << tactic_config.dump(2) << std::endl;
}
return -1;
}
template <class BenchClass>
void parseTacticToVectorID(nlohmann::json& tactic, std::vector<int>& tactic_ids)
{
if (tactic.is_number_integer())
{
tactic_ids.push_back(tactic.get<int>());
}
else if (tactic.is_array())
{
for (auto c : tactic)
{
parseTacticToVectorID<BenchClass>(c, tactic_ids);
}
}
else if (tactic.is_object())
{
tactic_ids.push_back(parseTacticToId<BenchClass>(tactic));
}
else if (tactic.is_string())
{
assert(tactic.is_string());
auto tactic_name = tactic.get<std::string>();
if (tactic_name == "all")
{
auto all_tactics = listAllTactics<BenchClass>();
tactic_ids.resize(all_tactics.size());
std::iota(tactic_ids.begin(), tactic_ids.end(), 0);
}
else
{
assert(tactic.get<std::string>() == "auto");
tactic_ids.push_back(-1);
}
}
else
{
throw std::invalid_argument("Invalid tactic format");
}
}
// This interdependence of globals could be better, but it works ok for this limited case.
std::unordered_map<std::string, std::pair<int, int>> name_info_map{
{routingConfigCache[LOAD_BALANCED_ROUTING_CONFIG]->getName(), {-1, LOAD_BALANCED_ROUTING_CONFIG}},
};
int getNameCacheIdx(std::string const& name)
{
if (name_info_map.find(name) == name_info_map.end())
{
return -1;
}
return name_info_map.at(name).second;
}
void setNameCacheIdx(std::string const& name, int id)
{
name_info_map.at(name).second = id;
}
// This is suboptimal for large benchmark files, as we re-read the file for every data type
template <class BenchClass>
void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
{
/*
* File schema
*
* [
* {
* "num_experts": int,
* "k": int,
* "hidden_size": int,
* "inter_size": int,
* "tp_size": int, (optional)
* "ep_size": int, (optional)
* "world_rank": int, (optional)
* "num_tokens": int,
* "bias": int,
* "act_fn": int,
* "norm_mode": int,
* "tactic_id": tactic, (see below)
* "dtypes": [string, ...], (optional)
* "routing_values_name": string, (optional)
* "routing_values": [float, ...], or string, (optional, length is a multiple of num_experts)
* },
* ...
* ]
*
* Explanation:
*
* - "num_experts" - The number of experts
* - "k" - The top k
* - "hidden_size" - The hidden size
* - "inter_size" - The inter size
* - "tp_size" - The TP size
* - "ep_size" - The EP size
* - "world_rank" - The world rank = ep_rank * tp_size + tp_rank
* - "num_tokens" - The total number of tokens to benchmark
* - "bias" - If bias should be used, 0 = no bias, 1 = bias
* - "act_fn" - The enum value of the activation function. See
* "cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
* - "norm_mode" - The normalization mode. 0 = NONE, 1 = RENORM. See
* "cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
*
* - "tactic_id"
* Valid tactics are:
* - An object:
* {
* "is_sm90": bool,
* "tile_shape": [int, int, int] or int,
* "cluster_shape": [int, int, int] or int, (required for sm90, type must be an int if tile_shape is an int)
* "warp_shape": [int, int, int], (required for non-sm90 if tile_shape is an array)
* "stages": int, (required for non-sm90)
* },
* - An integer: corresponds to an index in the tactics array. WARNING this is not stable between test
* configurations
* - An array: of integers or objects, forms a list of tactics to sweep
* - The string "all": This will sweep through all possible tactics
* - The string "auto": This runs a short benchmark to pick the fastest tactic before each benchmark case. Useful
* for quick perf tests, prefer a full sweep and manually setting the tactic for more accurate results
*
* - "dtypes" - A list of dtypes to run this config through.
* Allowed values are: fp8, int4, int8, float, half, bfloat16
* If this argument is omitted, all dtypes will be run. Note that not all tactics are supported for all dtypes;
* unsupported tactics will be skipped with a warning.
*
* - "routing_values_name" - a name to help identify the routing pattern. This can be used by later configs to reuse
* the config
* - "routing_values" - a flat array of routing values to define a new config, or a string referencing the name of a
* previous config. Defaults to "balanced", which is short-hand for a uniform expert distribution.
* These define the routing values used as input to the MoE backend, and are intended to allow comparing different
* routing behaviours.
* When defining an array, it must have `T*num_experts` floating point values. Each set of
* `num_experts` values defines the input for a single token. If `num_tokens` is greater than `T` it will repeat
* from the beginning
*
*/
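/*
 * A minimal example workload, assuming hypothetical model shapes; the act_fn and norm_mode enum values and the
 * non-sm90 tactic shown are illustrative only, check the headers referenced above for the real mappings:
 *
 * [
 *   {
 *     "num_experts": 8, "k": 2, "hidden_size": 4096, "inter_size": 14336,
 *     "num_tokens": 128, "bias": 0, "act_fn": 3, "norm_mode": 1,
 *     "tactic_id": "auto", "dtypes": ["half"]
 *   },
 *   {
 *     "num_experts": 8, "k": 2, "hidden_size": 4096, "inter_size": 14336,
 *     "num_tokens": 2048, "bias": 0, "act_fn": 3, "norm_mode": 1,
 *     "tactic_id": {"is_sm90": false, "tile_shape": [128, 128, 64], "warp_shape": [64, 64, 64], "stages": 4},
 *     "dtypes": ["int8"]
 *   }
 * ]
 */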
std::ifstream file{workloadFile};
std::stringstream buffer;
buffer << file.rdbuf();
auto file_contents = buffer.str();
if (VERBOSE)
std::cout << "Loaded benchmark file: " << file_contents << std::endl;
auto source_data = nlohmann::json::parse(file_contents);
int i = 0;
for (auto run_config : source_data)
{
if (VERBOSE)
std::cout << "Parsing run config: " << run_config.dump(2) << std::endl;
std::string config_name = "config_" + std::to_string(i);
// WARNING: Process the routing configuration immediately, so we can guarantee all configs get processed for all
// data types. We should not skip any test cases as a later test config may depend on this config
if (run_config.contains("routing_values_name"))
{
run_config["routing_values_name"].get_to(config_name);
if (!run_config.contains("routing_values"))
{
throw std::invalid_argument("Setting routing value configuration name but missing routing values");
}
}
std::optional<int> routing_config;
auto res = name_info_map.emplace(config_name, std::pair{i, -1});
// We must check that i differs, since this function gets called once for each data type
if (!res.second && res.first->second.first != i)
{
throw std::invalid_argument("Redefinition of routing_values_name " + config_name + " at config "
+ std::to_string(i) + ". First declared at " + std::to_string(res.first->second.first));
}
else if (!res.second)
{
// Reuse the existing config from a previous parse
routing_config = getNameCacheIdx(config_name);
}
i++;
int num_experts = run_config.at("num_experts").get<int>();
if (run_config.contains("routing_values") && !routing_config)
{
if (run_config["routing_values"].is_string())
{
routing_config = getNameCacheIdx(run_config["routing_values"].get<std::string>());
if (routing_config < 0)
{
throw std::invalid_argument("Invalid routing value, could not find valid config");
}
}
else
{
if (config_name.empty())
{
throw std::invalid_argument("Explicit routing configurations must specify a name");
}
std::vector<float> routing_values;
run_config["routing_values"].get_to(routing_values);
int shape = routing_values.size() / num_experts;
routingConfigCache.push_back(std::make_shared<VectoredRoutingConfig>(
std::move(routing_values), std::pair<int, int>(shape, num_experts), config_name));
routing_config = routingConfigCache.size() - 1;
}
auto conf = routingConfigCache[*routing_config];
auto conf_vec = std::dynamic_pointer_cast<VectoredRoutingConfig>(conf);
if (conf->getName() != "balanced" && (!conf_vec || conf_vec->shape.second != num_experts))
{
throw std::invalid_argument("Incompatible config selected. Expected " + std::to_string(num_experts)
+ " experts in routing configuration. "
+ ((conf_vec) ? "Found: " + std::to_string(conf_vec->shape.second)
: "Found incompatible routing config type"));
}
}
// Use the selected config or fall back to balanced
routing_config = routing_config.value_or(LOAD_BALANCED_ROUTING_CONFIG);
setNameCacheIdx(config_name, *routing_config);
// Filter out the types we don't care about testing
if (run_config.contains("dtypes"))
{
std::vector<std::string> dtypes;
run_config["dtypes"].get_to(dtypes);
auto hasDtype = [&](char const* d)
{ return std::any_of(dtypes.begin(), dtypes.end(), [&](auto const& n) { return n == d; }); };
if (BenchClass::FP8 && !hasDtype("fp8"))
{
continue;
}
else if (BenchClass::INT4 && !hasDtype("int4"))
{
continue;
}
else if (!BenchClass::INT4 && BenchClass::INT_QUANT && !hasDtype("int8"))
{
continue;
}
else if (std::is_same_v<typename BenchClass::WeightType, float> && !hasDtype("float")
&& !hasDtype("float32"))
{
continue;
}
else if (std::is_same_v<typename BenchClass::WeightType, half> && !hasDtype("float16") && !hasDtype("half"))
{
continue;
}
else if (std::is_same_v<typename BenchClass::WeightType, __nv_bfloat16> && !hasDtype("bfloat16")
&& !hasDtype("bf16"))
{
continue;
}
}
// Do this after filtering data types, as tactics only make sense once we know the data type
std::vector<int> tactic_ids{};
parseTacticToVectorID<BenchClass>(run_config["tactic_id"], tactic_ids);
if (tactic_ids.empty())
{
std::cerr << "Warning: Skipping benchmark, no such tactic: " << run_config["tactic"].dump(2) << std::endl;
static bool printed = false;
if (!printed)
{
printed = true;
std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n";
auto confs = listAllTactics<BenchClass>();
for (auto c : confs)
std::cerr << c.toString();
}
continue;
}
auto get_or = [&](auto name, auto def)
{ return run_config.contains(name) ? run_config[name].template get<decltype(def)>() : def; };
int tp_size = get_or("tp_size", 1);
int ep_size = get_or("ep_size", 1);
int world_rank = get_or("world_rank", 0);
int bias = get_or("bias", 0);
for (auto tactic_id : tactic_ids)
{
benchmark->Args({num_experts, //
run_config.at("k").get<int>(), //
run_config.at("hidden_size").get<int>(), //
run_config.at("inter_size").get<int>(), //
tp_size, ep_size, world_rank, //
run_config.at("num_tokens").get<int>(), //
bias, //
run_config.at("act_fn").get<int>(), //
run_config.at("norm_mode").get<int>(), //
tactic_id, //
*routing_config});
}
}
}
template <class BenchClass>
void argGenHardcoded(benchmark::internal::Benchmark* benchmark)
{
auto num_tactics = listAllTactics<BenchClass>().size();
for (auto [tp, ep] : std::vector<std::pair<int, int>>{{8, 1}, {1, 8}, {2, 4}, {4, 2}})
{
for (int i = 0; i < num_tactics; i++)
{
for (auto tokens : {1, 64, 2048})
{
benchmark->Args({
16, // Experts
2, // K
15360, // hidden
30720, // inter
tp, // TP Size
ep, // EP Size
0, // World Rank
tokens, // Num tokens
0, // bias
(int) tensorrt_llm::ActivationType::Gelu, // Act fn
(int) MOEExpertScaleNormalizationMode::RENORMALIZE, // Norm mode
i, // Tactic ID. Index into getTactics() function result, see argGenLoadFile() for examples
LOAD_BALANCED_ROUTING_CONFIG // Routing configuration id
});
benchmark->Args({
16, // Experts
2, // K
15360, // hidden
20480, // inter
tp, // TP Size
ep, // EP Size
0, // World Rank
tokens, // Num tokens
0, // bias
(int) tensorrt_llm::ActivationType::Swiglu, // Act fn
(int) MOEExpertScaleNormalizationMode::RENORMALIZE, // Norm mode
i, // Tactic ID. Index into getTactics() function result, see argGenLoadFile() for examples
LOAD_BALANCED_ROUTING_CONFIG // Routing configuration id
});
}
}
}
return;
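// NOTE: the early return above means the exhaustive sweep below never runs; it is presumably kept as a reference
// for the wider parameter space.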
auto num_experts = {1, 8, 9, 64, 65, 257}; // {1, 8, 64, 65, 1024};
auto top_k = {1, 2, 3, 16}; // {1, 2, 3, 42};
auto hidden_size = {32}; // {8, 32, 96, 256, 1024};
auto inter_size_mul = {4.f}; // {7.f/2.f, 4.f};
auto num_tokens = {200}; // {1, 20, 200};
auto use_bias = {0, 1}; // {0, 1};
auto activation_type = {tensorrt_llm::ActivationType::Gelu};
// {tensorrt_llm::ActivationType::Relu, tensorrt_llm::ActivationType::Gelu,
// tensorrt_llm::ActivationType::Silu, tensorrt_llm::ActivationType::Geglu,
// tensorrt_llm::ActivationType::Swiglu};
auto norm_mode = {MOEExpertScaleNormalizationMode::NONE};
auto cutlass_tactic = {0}; // {0, 1, 2};
auto routing_config = {LOAD_BALANCED_ROUTING_CONFIG}; // {0, 1, 2};
for (auto num_expert : num_experts)
for (auto k : top_k)
if (k <= num_expert)
for (auto size : hidden_size)
for (auto inter_mul : inter_size_mul)
{
auto inter_size = static_cast<int>(size * inter_mul);
for (auto tokens : num_tokens)
for (auto bias : use_bias)
for (auto act : activation_type)
for (auto norm : norm_mode)
for (auto tactic : cutlass_tactic)
for (auto routing : routing_config)
benchmark->Args({num_expert, k, size, inter_size, 1, 1, 0, tokens, bias,
(int) act, (int) norm, tactic, routing});
}
}
template <class BenchClass>
void argGen(benchmark::internal::Benchmark* benchmark)
{
if (VERBOSE)
{
std::cout << "List of all tactics for dtype " << (int) BenchClass::toDTypeID() << ":\n";
int i = 0;
for (auto& t : listAllTactics<BenchClass>())
{
std::cout << "Tactic " << i << ":\n";
std::cout << t.toString() << std::endl;
i++;
}
}
// Generic setup
benchmark->UseManualTime();
benchmark->ArgNames({"Num Experts", "K", "Hidden Size", "Inter Size", "TP Size", "EP Size", "World Rank",
"Num Tokens", "Use Bias", "Activation Function", "Norm Mode", "Tactic ID", "Routing ID"});
if (workloadFile)
argGenLoadFile<BenchClass>(benchmark);
else
argGenHardcoded<BenchClass>(benchmark);
}
BENCHMARK_BASIC(float, float, float)
BENCHMARK_BASIC(half, half, half)
using uint8 = uint8_t;
BENCHMARK_BASIC(half, uint8, half)
using cutlass::uint4b_t;
BENCHMARK_BASIC(half, uint4b_t, half)
#ifdef ENABLE_BF16
BENCHMARK_BASIC(nv_bfloat16, nv_bfloat16, nv_bfloat16)
#endif
#ifdef ENABLE_FP8
BENCHMARK_BASIC(SafeFP8, SafeFP8, half)
#endif
void delayedRegisterBenchmark()
{
BENCHMARK_BASIC_DO_REGISTER(half, half, half);
#ifdef ENABLE_FP8
BENCHMARK_BASIC_DO_REGISTER(SafeFP8, SafeFP8, half);
#endif
if (workloadFile)
{
// Extra ones we don't want for hardcoded runs
BENCHMARK_BASIC_DO_REGISTER(float, float, float);
BENCHMARK_BASIC_DO_REGISTER(half, uint8, half);
BENCHMARK_BASIC_DO_REGISTER(half, uint4b_t, half);
#ifdef ENABLE_BF16
BENCHMARK_BASIC_DO_REGISTER(nv_bfloat16, nv_bfloat16, nv_bfloat16);
#endif
}
}
void doCleanup()
{
bufferManager.reset();
streamPtr.reset();
}
void help()
{
std::cout << "Usage: mixtureOfExpertsBackendBenchmark [--input_file <file>] [benchmark options]\n";
std::cout
<< "--input_file\t\tA JSON file describing the benchmark configurations\n\n"
<< "File schema\n"
"[\n"
" {\n"
" \"num_experts\": int,\n"
" \"k\": int,\n"
" \"hidden_size\": int,\n"
" \"inter_size\": int,\n"
" \"tp_size\": int, (optional)\n"
" \"ep_size\": int, (optional)\n"
" \"world_rank\": int, (optional)\n"
" \"num_tokens\": int,\n"
" \"bias\": int,\n"
" \"act_fn\": int,\n"
" \"norm_mode\": int,\n"
" \"tactic_id\": tactic, (see below)\n"
" \"dtypes\": [string, ...], (optional)\n"
" \"routing_values_name\": string, (optional)\n"
" \"routing_values\": [float, ...], or string, (optional, length is a multiple of num_experts)\n"
" },\n"
" ...\n"
"]\n"
"Explanation:\n"
"- \"num_experts\" - The number of experts\n"
"- \"k\" - The top k\n"
"- \"hidden_size\" - The hidden size\n"
"- \"inter_size\" - The inter size\n"
"- \"tp_size\" - The TP size to use\n"
"- \"ep_size\" - The EP size to use\n"
"- \"world_rank\" - The world rank = ep_rank * tp_size + tp_rank\n"
"- \"num_tokens\" - The total number of tokens to benchmark\n"
"- \"bias\" - If bias should be used, 0 = no bias, 1 = bias\n"
"- \"act_fn\" - The enum value of the activation function. See\n"
"\"cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h\"\n"
"- \"norm_mode\" - The normalization mode. 0 = NONE, 1 = RENORM. See\n"
"\"cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h\"\n"
"- \"tactic_id\"\n"
"Valid tactics are:\n"
" - An object:\n"
" {\n"
" \"is_sm90\": bool,\n"
" \"tile_shape\": [int, int, int] or int,\n"
" \"cluster_shape\": [int, int, int] or int, (required for sm90, type must be an int if tile_shape is "
"an int)\n"
" \"warp_shape\": [int, int, int], (required for non-sm90 if tile_shape is an array)\n"
" \"stages\": int, (required for non-sm90)\n"
" },\n"
" - An integer: corresponds to an index in the tactics array. WARNING this is not stable between test "
"configurations\n"
" - An array: of integers or objects, forms a list of tactics to sweep\n"
" - The string \"all\": This will sweep through all possible tactics\n"
" - The string \"auto\": This runs a short benchmark to pick the fastest tactic before each benchmark case. "
"Useful for quick perf tests, prefer a full sweep and manually setting the tactic for more accurate results"
"- dtypes - A list of dtypes to run this config through.\n"
"Allowed values are: fp8, int4, int8, float, half, bfloat16\n"
"If this argument is omitted all dtypes will be run. Note, not all tactics are supported for all dtypes,\n"
"unsupported tactics will be skipped with a warning.\n"
"- \"routing_values_name\" - a name to help identify the routing pattern. This can be used by later "
"benchmarks to reuse the config\n"
"- \"routing_values\" - a flat array of routing values to define a new config, or a string referencing "
"the name of a\n"
"previous config. Defaults to \"balanced\", which is short-hand for a uniform expert distribution\n"
"These define the routing values used as input to the moe backend, and is intended to allow comparing "
"different routing behaviours.\n"
"When defining an array, it must have `T*num_experts` floating point values. Each set of\n"
"`num_experts` values defines the input for a single token. If `num_tokens` is greater than `T` it will "
"repeat from the beginning\n\n";
std::cout << "benchmark options:\n";
benchmark::PrintDefaultHelp();
}
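// Example invocation (hypothetical file names), combining the custom --input_file flag with standard
// google-benchmark options:
//   ./mixtureOfExpertsBackendBenchmark --input_file moe_workloads.json --benchmark_out=results.json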
void gbenchCustomHelp()
{
help();
// google-benchmark calls exit() so we need to cleanup manually
doCleanup();
}
int parseArgsAndRunBench(int argc, char** argv)
{
try
{
int shift = 0;
for (int i = 1; i < argc; i++)
{
argv[i - shift] = argv[i];
if (strcmp("--input_file", argv[i]) == 0)
{
i += 1;
if (i == argc)
{
std::cerr << "Missing file name for input_file\n";
return -1;
}
workloadFile = argv[i];
if (workloadFile[0] == '-')
{
std::cerr << "Workload file " << workloadFile << " not a valid file name\n";
return -2;
}
shift += 2;
}
else if (strcmp("--help", argv[i]) == 0 || strcmp("-h", argv[i]) == 0)
{
help();
return 0;
}
}
argc -= shift;
// Delay after we know if the user passed a config file
delayedRegisterBenchmark();
benchmark::Initialize(&argc, argv, &gbenchCustomHelp);
if (argc > 1)
{
help();
std::cout << std::flush; // Force flush
// Print the error second, so it's easy to see
std::cerr << "\nUnrecognised argument: " << argv[1] << std::endl;
return -4;
}
benchmark::RunSpecifiedBenchmarks();
benchmark::Shutdown();
return 0;
}
catch (std::exception const& e)
{
std::cerr << "Exiting benchmarks with exception: " << e.what() << std::endl;
return -3;
}
}
int main(int argc, char** argv)
{
deviceCount = getDeviceCount();
if (deviceCount < 0)
return 0;
streamPtr = std::make_shared<CudaStream>();
bufferManager = std::make_unique<BufferManager>(streamPtr);
int res = -1;
try
{
res = parseArgsAndRunBench(argc, argv);
}
catch (std::exception const& e)
{
std::cout << "Benchmark exited with unhandled exception: " << e.what() << std::endl;
}
doCleanup();
return res;
}

View File

@ -204,12 +204,22 @@ else()
"kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/libtensorrt_llm_nvrtc_wrapper.so"
)
else()
set(NVRTC_WRAPPER_LIB_BINARY_REL_DIR
"kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper"
)
set(NVRTC_WRAPPER_DLL_NAME "tensorrt_llm_nvrtc_wrapper.dll")
set(NVRTC_WRAPPER_LIB_NAME "tensorrt_llm_nvrtc_wrapper.lib")
set(NVRTC_WRAPPER_LIB_SOURCE_REL_LOC
"kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${NVRTC_WRAPPER_TARGET_ARCH}/libtensorrt_llm_nvrtc_wrapper.dll"
"${NVRTC_WRAPPER_LIB_BINARY_REL_DIR}/${NVRTC_WRAPPER_TARGET_ARCH}/${NVRTC_WRAPPER_DLL_NAME}"
)
set(NVRTC_WRAPPER_LIB_BINARY_REL_LOC
"kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/libtensorrt_llm_nvrtc_wrapper.dll"
"${NVRTC_WRAPPER_LIB_BINARY_REL_DIR}/${NVRTC_WRAPPER_DLL_NAME}")
set(NVRTC_WRAPPER_IMPLIB_SOURCE_REL_LOC
"${NVRTC_WRAPPER_LIB_BINARY_REL_DIR}/${NVRTC_WRAPPER_TARGET_ARCH}/${NVRTC_WRAPPER_LIB_NAME}"
)
set(NVRTC_WRAPPER_IMPLIB_BINARY_REL_LOC
"${NVRTC_WRAPPER_LIB_BINARY_REL_DIR}/${NVRTC_WRAPPER_LIB_NAME}")
endif()
set(NVRTC_WRAPPER_LIB_LOC
"${CMAKE_CURRENT_SOURCE_DIR}/${NVRTC_WRAPPER_LIB_SOURCE_REL_LOC}")
@ -218,6 +228,15 @@ else()
${NVRTC_WRAPPER_LIB_BINARY_REL_LOC} COPYONLY)
set_property(TARGET ${NVRTC_WRAPPER_TARGET} PROPERTY IMPORTED_LOCATION
${NVRTC_WRAPPER_LIB_LOC})
if(WIN32)
set(NVRTC_WRAPPER_IMPLIB_LOC
"${CMAKE_CURRENT_SOURCE_DIR}/${NVRTC_WRAPPER_IMPLIB_SOURCE_REL_LOC}")
configure_file(${NVRTC_WRAPPER_IMPLIB_SOURCE_REL_LOC}
${NVRTC_WRAPPER_IMPLIB_BINARY_REL_LOC} COPYONLY)
set_property(TARGET ${NVRTC_WRAPPER_TARGET}
PROPERTY IMPORTED_IMPLIB ${NVRTC_WRAPPER_IMPLIB_LOC})
endif()
file(SIZE ${NVRTC_WRAPPER_LIB_LOC} NVRTC_WRAPPER_LIB_SIZE)
if(NVRTC_WRAPPER_LIB_SIZE LESS 1024)
message(
@ -234,7 +253,6 @@ set(TRTLLM_LINK_LIBS
${TRT_LIB}
common_src
kernels_src
context_attention_src
decoder_attention_src
fpA_intB_gemm_src
moe_gemm_src

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6a75d7ce014817aaf87168c9db2381273fb2c91cd997663e90b476fe1c7d6503
size 3308922
oid sha256:2eaa6076f6ae6341710ca84097a3083f57235093bfc3bb1cd04f2bce06e1a99e
size 3412616

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6a75d7ce014817aaf87168c9db2381273fb2c91cd997663e90b476fe1c7d6503
size 3308922
oid sha256:2eaa6076f6ae6341710ca84097a3083f57235093bfc3bb1cd04f2bce06e1a99e
size 3412616

View File

@ -1,3 +1,3 @@
ed2ee8d73a5d374e800f653169bf293e libtensorrt_llm_batch_manager_static.a
ed2ee8d73a5d374e800f653169bf293e libtensorrt_llm_batch_manager_static.pre_cxx11.a
05aaf1a0fb2f0115af107b00aa839a6601f6a873 commit
4b12adf3182aabf1df0b5dac217509e0 libtensorrt_llm_batch_manager_static.a
4b12adf3182aabf1df0b5dac217509e0 libtensorrt_llm_batch_manager_static.pre_cxx11.a
fc46fa01e555f9f97387340e46e9571fabf73988 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:00e2d6ee8efd00e27dd8da61be576ba7978d885a055d591c90f600b334356846
size 3211414
oid sha256:e4fb588419244c8c07a9ce949edd7ed4e3dded008ed82aa993ff69e524394be9
size 3310186

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b6b65183b0aa3f40f68aa13105da9dc00fb75b8bf8892813e46a09e3f0743570
size 3186478
oid sha256:4b31c50b2879f57022e788700ebf3a86fa8f30133a01533b03c7bc15d64ad364
size 3283536

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7d4a3bc5160666612e529f21c61dbd9d0f1b387662768f76b9351f877108f84b
size 19840380
oid sha256:8f458ce861720a4ce15c9cedeae4bb6c1f6a8a98f0fced35198c9802feaddb10
size 20305606

View File

@ -207,8 +207,9 @@ __device__ __nv_bfloat16 atomicMaxExtd(__nv_bfloat16* address, __nv_bfloat16 val
return __ushort_as_bfloat16(old);
#else
asm volatile(" brkpt;\n");
return 0;
assert(0);
asm volatile("brkpt;\n" ::);
return __nv_bfloat16(0);
#endif
}

View File

@ -598,8 +598,9 @@ __device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
return __hmax(val.x, val.y);
#else
asm volatile(" brkpt;\n");
return 0;
assert(0);
asm volatile("brkpt;\n" ::);
return __nv_bfloat16(0);
#endif
}
#endif

View File

@ -55,23 +55,23 @@ std::optional<int32_t> envXqaNbCtaPerKVHead()
return ret;
}
bool getEnvDisableXQAJIT()
bool getEnvEnableXQAJIT()
{
static bool init = false;
static bool disableXQAJIT = false;
static bool enableXQAJIT = false;
if (!init)
{
init = true;
char const* disable_xqa_jit_var = std::getenv("TRTLLM_DISABLE_XQA_JIT");
if (disable_xqa_jit_var)
char const* enable_xqa_jit_var = std::getenv("TRTLLM_ENABLE_XQA_JIT");
if (enable_xqa_jit_var)
{
if (disable_xqa_jit_var[0] == '1' && disable_xqa_jit_var[1] == '\0')
if (enable_xqa_jit_var[0] == '1' && enable_xqa_jit_var[1] == '\0')
{
disableXQAJIT = true;
enableXQAJIT = true;
}
}
}
return disableXQAJIT;
return enableXQAJIT;
}
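// Example: setting TRTLLM_ENABLE_XQA_JIT=1 in the environment opts in to the XQA JIT path; any other value, or
// leaving it unset, keeps it disabled.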
// Tune the number of blocks per sequence for accuracy/performance purpose.

View File

@ -33,8 +33,8 @@ int32_t xqaMaxNbCtaPerKVHeadFactor();
std::optional<int32_t> envXqaNbCtaPerKVHead();
// Whether XQA JIT is disabled.
bool getEnvDisableXQAJIT();
// Whether XQA JIT is enabled.
bool getEnvEnableXQAJIT();
// Tune the number of blocks per sequence for accuracy/performance purpose.
bool getEnvMmhaMultiblockDebug();

View File

@ -112,6 +112,9 @@ void initialize(MpiThreadSupport threadMode)
auto previousHandler = std::signal(SIGABRT, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); });
TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed");
// ensure local MPI communicator is initialized
MpiComm::localSession();
}
#endif // ENABLE_MULTI_DEVICE
mpiInitialized = true;
@ -271,6 +274,24 @@ MpiComm& MpiComm::session()
return commSession;
}
MpiComm getLocalSession()
{
#if ENABLE_MULTI_DEVICE
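// Group all ranks that share a physical host into one local communicator; when multi-device support is
// disabled the local session simply aliases MPI_COMM_WORLD below.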
MPI_Comm localComm;
MPI_Comm_split_type(MPI_COMM_WORLD, OMPI_COMM_TYPE_HOST, 0, MPI_INFO_NULL, &localComm);
MpiComm localSession{localComm, false};
#else
MpiComm localSession{MPI_COMM_WORLD, false};
#endif // ENABLE_MULTI_DEVICE
return localSession;
}
MpiComm& MpiComm::localSession()
{
static MpiComm localSession = getLocalSession();
return localSession;
}
MpiComm::MpiComm(MPI_Comm g, bool freeComm)
: mComm{g}
, mFreeComm{freeComm}

View File

@ -25,6 +25,7 @@
#include "cutlass_extensions/tile_interleaved_layout.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"
namespace cutlass
{
@ -39,11 +40,22 @@ template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantO
typename Enable = void>
struct DefaultScaleIteratorsPipelined;
// TODO: Fine grained iterators
// Fine grained iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<isFinegrained(QuantOp)>>
{
private:
using SmemScaleType = half_t;
public:
using IteratorScale
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
Layout, 0, Alignment>;
using SmemIteratorScale
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>,
SmemScaleType, Layout, 0, Alignment>;
};
// Per column iterators
@ -206,7 +218,6 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
using MmaCoreElementA = half_t;

View File

@ -80,7 +80,7 @@ template <
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Data type for the scales
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,

View File

@ -80,13 +80,13 @@ template <
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Data type for the scales
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,

View File

@ -95,304 +95,12 @@ template <
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_,
/// Used for partial specialization
typename Enable = bool>
class DqMmaPipelined : public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
{
public:
///< Base class
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
using TransformBAfterLDG = TransformBAfterLDG_;
using TransformBAfterLDS = TransformBAfterLDS_;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of operand Scale loaded from global memory;
using FragmentScale = typename IteratorScale::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
typename SmemIteratorScale::Fragment::Element, LayoutScale, 32, QuantOp>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// statically assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaPipelined(typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
int const group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation
///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this
///< argument is not added, it does not affect compilation for sm>=80.
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
FragmentC const& src_accum)
{ ///< source accumulator tile
//
// Prologue
//
TransformBAfterLDG ldg_converter;
TransformBAfterLDS lds_converter;
using TransformA
= NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
typename FragmentScale::Element, FragmentScale::kElements>;
// These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
// to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
TransformA transformA;
TransformScale transformScale;
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
FragmentScale tb_frag_scales;
using WarpFragmentScale = typename Dequantizer::FragmentScale;
WarpFragmentScale warp_frag_scales;
tb_frag_A.clear();
tb_frag_B.clear();
tb_frag_scales.clear();
// The last kblock is loaded in the prolog
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
iterator_scale.load(tb_frag_scales);
++iterator_A;
++iterator_B;
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
this->smem_iterator_scale_.store(transformScale(tb_frag_scales));
++this->smem_iterator_A_;
++this->smem_iterator_B_;
__syncthreads();
warp_dequantizer_.load(warp_frag_scales);
// Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
Operator warp_mma;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations)
{
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
// as the case may be.
if (warp_mma_k == Base::kWarpGemmIterations - 1)
{
// Write fragments to shared memory
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
__syncthreads();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
if (smem_write_stage_idx == 1)
{
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
}
else
{
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
if (warp_mma_k == 0)
{
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
}
typename TransformBAfterLDS::result_type converted_frag_B
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
run_warp_mma(
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
typename Enable = void>
class DqMmaPipelined;
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h"

View File

@ -0,0 +1,486 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
#include "cutlass_extensions/gemm_configs.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Converter for B matrix applied immediately after the LDG (before STS)
typename TransformBAfterLDG_,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_>
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
std::enable_if_t<isFinegrained(QuantOp_)>>
: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
{
public:
///< Base class
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
using TransformBAfterLDG = TransformBAfterLDG_;
using TransformBAfterLDS = TransformBAfterLDS_;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of operand Scale loaded from global memory;
using FragmentScale = typename IteratorScale::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
typename SmemIteratorScale::Element, LayoutScale, 32, QuantOp>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// statically assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, "");
static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using WarpFragmentScale = typename Dequantizer::FragmentScale;
using WarpFragmentZero = typename Dequantizer::FragmentZero;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaPipelined(typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
int const group_size, ///< The group size for quantization
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
{shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
void copy_scales_and_advance(IteratorScale& iterator_scale)
{
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Element,
typename FragmentScale::Element, FragmentScale::kElements>;
FragmentScale tb_frag_scales;
FragmentScale tb_frag_zeros;
tb_frag_scales.clear();
tb_frag_zeros.clear();
TransformScale transformScale;
using FragmentElement = typename FragmentScale::Element;
auto gmem_scale_ptr = iterator_scale.get_scale();
auto gmem_zero_ptr = iterator_scale.get_zero();
arch::global_load<FragmentScale, sizeof(FragmentScale)>(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid());
if (gmem_zero_ptr != nullptr)
{
arch::global_load<FragmentScale, sizeof(FragmentScale)>(
tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid());
}
typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales);
typename TransformScale::result_type tb_frag_zeros_fp16;
if (gmem_zero_ptr != nullptr)
tb_frag_zeros_fp16 = transformScale(tb_frag_zeros);
auto frag_scale_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_scales_fp16);
auto frag_zero_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_zeros_fp16);
auto smem_scale_ptr = this->smem_iterator_scale_.get_scale();
auto smem_zero_ptr = this->smem_iterator_scale_.get_zero();
if (iterator_scale.valid())
{
auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr);
arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_scale_ptr_fp16);
if (gmem_zero_ptr != nullptr)
{
smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr);
arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_zero_ptr_fp16);
}
}
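// Advance the global-memory scale/zero iterator by one row per quantization group: with group_size 64 every
// k-tile consumes a group; with group_size 128 and a 64-wide k-tile only every second tile does (tracked via
// row_groupsize64_ below).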
if (iterator_scale.group_size_ == 64)
{
iterator_scale.add_tile_offset({1, 0});
}
else if (iterator_scale.group_size_ == 128)
{
if constexpr (Shape::kK == 128)
{
iterator_scale.add_tile_offset({1, 0});
}
else if constexpr (Shape::kK == 64)
{
if (iterator_scale.row_groupsize64_ & 0x1)
{
iterator_scale.add_tile_offset({1, 0});
}
}
else
{
static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
}
}
iterator_scale.row_groupsize64_++;
this->smem_iterator_scale_.add_tile_offset({1, 0});
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
FragmentC const& src_accum)
{ ///< source accumulator tile
//
// Prologue
//
TransformBAfterLDG ldg_converter;
TransformBAfterLDS lds_converter;
using TransformA
= NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
// These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
// to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
TransformA transformA;
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
tb_frag_A.clear();
tb_frag_B.clear();
// The last kblock is loaded in the prolog
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
++this->smem_iterator_A_;
++this->smem_iterator_B_;
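// Stage the first tile of scales (and zeros, if present) into shared memory and advance the scale iterator.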
copy_scales_and_advance(iterator_scale);
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
WarpFragmentScale warp_frag_scales;
WarpFragmentZero warp_frag_zero;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
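// Update internal pointer to the set of scales in shared memory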
warp_dequantizer_.add_pointer_offset(Shape::kN);
Operator warp_mma;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
iterator_scale.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations)
{
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// Load warp-level tiles from shared memory, wrapping to the k offset if this is
// the last group.
if (warp_mma_k == Base::kWarpGemmIterations - 1)
{
// Write fragments to shared memory
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
__syncthreads();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
if (smem_write_stage_idx == 1)
{
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
}
else
{
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
if (warp_mma_k == 0)
{
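// On the first warp-level iteration of this k-tile, kick off the global loads for the
// next threadblock tile so they overlap with the warp-level math below.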
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
copy_scales_and_advance(iterator_scale);
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
iterator_scale.clear_mask(gemm_k_iterations <= 2);
}
typename TransformBAfterLDS::result_type converted_frag_B
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero);
run_warp_mma(
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
}
// Load the scales needed for the next tile iteration
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
// Update internal pointer to the set of scales in shared memory
warp_dequantizer_.add_pointer_offset(Shape::kN);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,399 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
#include "cutlass_extensions/gemm_configs.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the threadblock-scoped matrix product, dequantizing the weight-only-quantized B operand as it is loaded.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Iterates over tiles of the scale operand in global memory
typename IteratorScale_,
/// Iterates over tiles of the scale operand in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Converter for B matrix applied immediately after the LDG (before STS)
typename TransformBAfterLDG_,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_>
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
std::enable_if_t<!isFinegrained(QuantOp_)>>
: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
{
public:
///< Base class
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
using TransformBAfterLDG = TransformBAfterLDG_;
using TransformBAfterLDS = TransformBAfterLDS_;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of operand Scale loaded from global memory;
using FragmentScale = typename IteratorScale::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
typename SmemIteratorScale::Fragment::Element, LayoutScale, 32, QuantOp>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// Statically assert that kStages for DqMmaPipelined is two (double-buffered pipeline)
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaPipelined(typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
int const group_size, ///< Unused; accepted only so the constructor signature matches the fine-grained
///< variants. This DqMmaPipelined specialization is only enabled for sm<80, so the
///< extra argument has no effect on sm>=80 builds.
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
FragmentC const& src_accum) ///< source accumulator tile
{
//
// Prologue
//
TransformBAfterLDG ldg_converter;
TransformBAfterLDS lds_converter;
using TransformA
= NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
typename FragmentScale::Element, FragmentScale::kElements>;
// These transforms mainly handle the case where we have bfloat16 activations and weights in GMEM and want
// to issue HMMA on architectures older than Ampere. We convert to FP16 before the STS.
TransformA transformA;
TransformScale transformScale;
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
FragmentScale tb_frag_scales;
using WarpFragmentScale = typename Dequantizer::FragmentScale;
WarpFragmentScale warp_frag_scales;
tb_frag_A.clear();
tb_frag_B.clear();
tb_frag_scales.clear();
// The last kblock is loaded in the prolog
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
iterator_scale.load(tb_frag_scales);
++iterator_A;
++iterator_B;
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
this->smem_iterator_scale_.store(transformScale(tb_frag_scales));
++this->smem_iterator_A_;
++this->smem_iterator_B_;
__syncthreads();
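// Scales are now resident in shared memory; load the fragment used to dequantize this warp's tiles of B.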
warp_dequantizer_.load(warp_frag_scales);
// Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
Operator warp_mma;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations)
{
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// Load warp-level tiles from shared memory, wrapping to the k offset if this is
// the last group.
if (warp_mma_k == Base::kWarpGemmIterations - 1)
{
// Write fragments to shared memory
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
__syncthreads();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
if (smem_write_stage_idx == 1)
{
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
}
else
{
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
if (warp_mma_k == 0)
{
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
}
typename TransformBAfterLDS::result_type converted_frag_B
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
run_warp_mma(
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -16,6 +16,10 @@
#pragma once
#include <cassert>
#include <sstream>
#include <string>
namespace tensorrt_llm
{
namespace cutlass_extensions
@ -153,6 +157,32 @@ struct CutlassGemmConfig
, is_sm90(true)
{
}
std::string toString()
{
std::stringstream tactic;
tactic << "Cutlass GEMM Tactic";
if (tile_config_sm90 != tensorrt_llm::cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
{
assert(is_sm90 && "Invalid cutlass GEMM config");
tactic << "\n\tstyle=TMA"
<< "\n\ttile shape ID: " << (int) tile_config_sm90 << "\n\tcluster shape ID: " << (int) cluster_shape
<< "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule;
}
else if (tile_config != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
{
assert(!is_sm90 && "Invalid cutlass GEMM config");
tactic << "\n\tstyle=compatible"
<< "\n\ttile shape ID: " << (int) tile_config << "\n\tstages: " << (int) stages
<< "\n\tsplit k: " << (int) split_k_factor;
}
else
{
tactic << "\n\tundefined";
}
tactic << "\n";
return tactic.str();
}
};
} // namespace cutlass_extensions
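As an aside, a minimal usage sketch for the new toString() helper, illustrative only: it assumes CutlassGemmConfig is default-constructible with both tile configs left at ChooseWithHeuristic, which is not shown in this hunk.

#include <iostream>
#include "cutlass_extensions/gemm_configs.h"

int main()
{
    // With no tactic selected yet, toString() falls through to the "undefined" branch.
    tensorrt_llm::cutlass_extensions::CutlassGemmConfig config{};
    std::cout << config.toString();
    return 0;
}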

View File

@ -89,6 +89,8 @@ public:
using AccessType = AlignedArray<Element, kAlignment>;
using Fragment = cutlass::Array<Element, kAlignment>;
// For compatibility with existing iterator interface
struct Params
{

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8ff03e99e17e64c9f559e4586dec3983d438857c0050a34a417e4d86d56fbe2a
size 1251854
oid sha256:dc6db658e67f3ea6a11bc37be2b90090f23b831719559a3202bf8b2b397f6f3b
size 1334290

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8ff03e99e17e64c9f559e4586dec3983d438857c0050a34a417e4d86d56fbe2a
size 1251854
oid sha256:dc6db658e67f3ea6a11bc37be2b90090f23b831719559a3202bf8b2b397f6f3b
size 1334290

View File

@ -1,3 +1,3 @@
54670adde093baff8b031869bdeeeb1b libtensorrt_llm_executor_static.a
54670adde093baff8b031869bdeeeb1b libtensorrt_llm_executor_static.pre_cxx11.a
05aaf1a0fb2f0115af107b00aa839a6601f6a873 commit
8e5f1d6bb88c80004b4260aa2d022420 libtensorrt_llm_executor_static.a
8e5f1d6bb88c80004b4260aa2d022420 libtensorrt_llm_executor_static.pre_cxx11.a
fc46fa01e555f9f97387340e46e9571fabf73988 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:431dc6352dcb332821aab031ccbd887e6a60591a5ea276a9ffd3df1f28463326
size 1271014
oid sha256:f0421ca1e637adfdebc9718c47537ed81b55cad4f7fbd062b1d83ca0ab7ebbe5
size 1371948

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:86319542d275570a0c66622d4656b88f3d153c6861db0e53f17f29d47e0a30c9
size 1227362
oid sha256:46b6253eef9136f91d2877e9baa827a8ff229b54b6fb1f2717fb6c85a7ffa047
size 1306830

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8ed99448579b40e0046eca5c8989151a66579f8fbccef9cda4ee7fc2ffd2245b
size 12076106
oid sha256:fa1bde9d020eac84321954b0cb393ec7af63393e29947d6c147700063b0267da
size 12726212

View File

@ -80,6 +80,8 @@ int main(int argc, char* argv[])
// In orchestrator mode, the spawned threads will wait for termination signal from orchestrator
auto executor = tle::Executor(modelPath, modelType, executorConfig);
// Wait for all workers to have created their instances
MPI_Barrier(parentComm);
TLLM_LOG_INFO("Executor instance created by worker");
#endif // ENABLE_MULTI_DEVICE

View File

@ -18,12 +18,9 @@
file(GLOB_RECURSE SRC_CPP *.cpp)
file(GLOB_RECURSE SRC_CU *.cu)
# Exclude files in the cutlass_kernels, contextFusedMultiHeadAttention and
# unfusedAttentionKernels folder
# Exclude files in the cutlass_kernels and unfusedAttentionKernels folder
list(FILTER SRC_CPP EXCLUDE REGEX "cutlass_kernels/.*")
list(FILTER SRC_CU EXCLUDE REGEX "cutlass_kernels/.*")
list(FILTER SRC_CPP EXCLUDE REGEX "contextFusedMultiHeadAttention/.*")
list(FILTER SRC_CU EXCLUDE REGEX "contextFusedMultiHeadAttention/.*")
list(FILTER SRC_CPP EXCLUDE REGEX "decoderMaskedMultiheadAttention/.*")
list(FILTER SRC_CU EXCLUDE REGEX "decoderMaskedMultiheadAttention/.*")
@ -37,4 +34,3 @@ set_property(TARGET kernels_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_subdirectory(cutlass_kernels)
add_subdirectory(decoderMaskedMultiheadAttention)
add_subdirectory(contextFusedMultiHeadAttention)

View File

@ -1,24 +0,0 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#
file(GLOB_RECURSE SRC_CPP *.cpp)
file(GLOB_RECURSE SRC_CU *.cu)
add_library(context_attention_src OBJECT ${SRC_CPP} ${SRC_CU})
set_property(TARGET context_attention_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET context_attention_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS
ON)

Some files were not shown because too many files have changed in this diff.