Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00

Update TensorRT-LLM (#1688)

* Update TensorRT-LLM

---------

Co-authored-by: IbrahimAmin <ibrahimamin532@gmail.com>
Co-authored-by: Fabian Joswig <fjosw@users.noreply.github.com>
Co-authored-by: Pzzzzz <hello-cd.plus@hotmail.com>
Co-authored-by: CoderHam <hemant@cohere.com>
Co-authored-by: Konstantin Lopuhin <kostia.lopuhin@gmail.com>

This commit is contained in:
parent 5d8ca2faf7
commit f430a4b447

.gitignore (vendored, 1 line changed)
@@ -8,7 +8,6 @@ __pycache__/
build*/
*.egg-info/
.coverage
*.csv
*.onnx
tmp/
venv/
CHANGELOG.md (210 lines removed — file deleted)
@@ -1,210 +0,0 @@
# Change Log

## Versions 0.8.0

* Model Support
  - Phi-1.5/2.0
  - Mamba support (see examples/mamba/README.md)
    - The support is limited to beam width = 1 and single-node single-GPU
  - Nougat support (see examples/multimodal/README.md#nougat)
  - Qwen-VL support (see examples/qwenvl/README.md)
  - RoBERTa support, thanks to the contribution from @erenup
  - Skywork model support
  - Add example for multimodal models (BLIP with OPT or T5, LlaVA)
* Features
  - Chunked context support (see docs/source/gpt_attention.md#chunked-context)
  - LoRA support for C++ runtime (see docs/source/lora.md)
  - Medusa decoding support (see examples/medusa/README.md)
    - The support is limited to Python runtime for Ampere or newer GPUs with fp16 and bf16 accuracy, and the `temperature` parameter of sampling configuration should be 0
  - StreamingLLM support for LLaMA (see docs/source/gpt_attention.md#streamingllm)
  - Support for batch manager to return logits from context and/or generation phases
    - Include support in the Triton backend
  - Support AWQ and GPTQ for QWEN
  - Support ReduceScatter plugin
  - Support for combining `repetition_penalty` and `presence_penalty` #274
  - Support for `frequency_penalty` #275
  - OOTB functionality support:
    - Baichuan
    - InternLM
    - Qwen
    - BART
  - LLaMA
    - Support enabling INT4-AWQ along with FP8 KV Cache
    - Support BF16 for weight-only plugin
  - Baichuan
    - P-tuning support
    - INT4-AWQ and INT4-GPTQ support
  - Decoder iteration-level profiling improvements
  - Add `masked_select` and `cumsum` function for modeling
  - Smooth Quantization support for ChatGLM2-6B / ChatGLM3-6B / ChatGLM2-6B-32K
  - Add Weight-Only Support To Whisper #794, thanks to the contribution from @Eddie-Wang1120
  - Support FP16 fMHA on NVIDIA V100 GPU
* API
  - Add a set of High-level APIs for end-to-end generation tasks (see examples/high-level-api/README.md)
  - **[BREAKING CHANGES]** Migrate models to the new build workflow, including LLaMA, Mistral, Mixtral, InternLM, ChatGLM, Falcon, GPT-J, GPT-NeoX, Medusa, MPT, Baichuan and Phi (see docs/source/checkpoint.md)
  - **[BREAKING CHANGES]** Deprecate `LayerNorm` and `RMSNorm` plugins and removed corresponding build parameters
  - **[BREAKING CHANGES]** Remove optional parameter `maxNumSequences` for GPT manager
* Bug fixes
  - Fix the first token being abnormal issue when `--gather_all_token_logits` is enabled #639
  - Fix LLaMA with LoRA enabled build failure #673
  - Fix InternLM SmoothQuant build failure #705
  - Fix Bloom int8_kv_cache functionality #741
  - Fix crash in `gptManagerBenchmark` #649
  - Fix Blip2 build error #695
  - Add pickle support for `InferenceRequest` #701
  - Fix Mixtral-8x7b build failure with custom_all_reduce #825
  - Fix INT8 GEMM shape #935
  - Minor bug fixes
* Performance
  - **[BREAKING CHANGES]** Increase default `freeGpuMemoryFraction` parameter from 0.85 to 0.9 for higher throughput
  - **[BREAKING CHANGES]** Disable `enable_trt_overlap` argument for GPT manager by default
  - Performance optimization of beam search kernel
  - Add bfloat16 and paged kv cache support for optimized generation MQA/GQA kernels
  - Custom AllReduce plugins performance optimization
  - Top-P sampling performance optimization
  - LoRA performance optimization
  - Custom allreduce performance optimization by introducing a ping-pong buffer to avoid an extra synchronization cost
  - Integrate XQA kernels for GPT-J (beamWidth=4)
* Documentation
  - Batch manager arguments documentation updates
  - Add documentation for best practices for tuning the performance of TensorRT-LLM (See docs/source/perf_best_practices.md)
  - Add documentation for Falcon AWQ support (See examples/falcon/README.md)
  - Update to the `docs/source/checkpoint.md` documentation
  - Update AWQ INT4 weight only quantization documentation for GPT-J
  - Add blog: Speed up inference with SOTA quantization techniques in TRT-LLM
  - Refine TensorRT-LLM backend README structure #133
  - Typo fix #739

## Versions 0.7.0 / 0.7.1

* Models
  - BART and mBART support in encoder-decoder models
  - FairSeq Neural Machine Translation (NMT) family
  - Mixtral-8x7B model
    - Support weight loading for HuggingFace Mixtral model
  - OpenAI Whisper
  - Mixture of Experts support
  - MPT - Int4 AWQ / SmoothQuant support
  - Baichuan FP8 quantization support
* Features
  - [Preview] Speculative decoding
  - Add Python binding for `GptManager`
  - Add a Python class `ModelRunnerCpp` that wraps C++ `gptSession`
  - System prompt caching
  - Enable split-k for weight-only cutlass kernels
  - FP8 KV cache support for XQA kernel
  - New Python builder API and `trtllm-build` command (already applied to [blip2](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/blip2) and [OPT](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/opt#3-build-tensorrt-engines))
  - Support `StoppingCriteria` and `LogitsProcessor` in Python generate API (thanks to the contribution from @zhang-ge-hao)
  - fMHA support for chunked attention and paged kv cache
* Bug fixes
  - Fix tokenizer usage in quantize.py #288, thanks to the contribution from @0xymoro
  - Fix LLaMa with LoRA error #637
  - Fix LLaMA GPTQ failure #580
  - Fix Python binding for InferenceRequest issue #528
  - Fix CodeLlama SQ accuracy issue #453
* Performance
  - MMHA optimization for MQA and GQA
  - LoRA optimization: cutlass grouped gemm
  - Optimize Hopper warp specialized kernels
  - Optimize AllReduce for parallel attention on Falcon and GPT-J
  - Enable split-k for weight-only cutlass kernel when SM>=75
* Documentation
  - Add [documentation for convert/build workflow](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/checkpoint.md)

## Versions 0.6.0 / 0.6.1

* Models
  * ChatGLM3
  * InternLM (contributed by @wangruohui)
  * Mistral 7B (developed in collaboration with Mistral.AI)
  * MQA/GQA support to MPT (and GPT) models (contributed by @bheilbrun)
  * Qwen (contributed by @Tlntin and @zhaohb)
  * Replit Code V-1.5 3B (external contribution)
  * T5, mT5, Flan-T5 (Python runtime only)

* Features
  * Add runtime statistics related to active requests and KV cache utilization from the batch manager (see the [batch manager](docs/source/batch_manager.md) documentation)
  * Add `sequence_length` tensor to support proper lengths in beam-search (when beam-width > 1 - see [tensorrt_llm/batch_manager/GptManager.h](cpp/include/tensorrt_llm/batch_manager/GptManager.h))
  * BF16 support for encoder-decoder models (Python runtime - see [examples/enc_dec](examples/enc_dec/README.md))
  * Improvements to memory utilization (CPU and GPU - including memory leaks)
  * Improved error reporting and memory consumption
  * Improved support for stop and bad words
  * INT8 SmoothQuant and INT8 KV Cache support for the Baichuan models (see [examples/baichuan](examples/baichuan/README.md))
  * INT4 AWQ Tensor Parallelism support and INT8 KV cache + AWQ/weight-only support for the GPT-J model (see [examples/gptj](examples/gptj/README.md))
  * INT4 AWQ support for the Falcon models (see [examples/falcon](examples/falcon/README.md))
  * LoRA support (functional preview only - limited to the Python runtime, only QKV support and not optimized in terms of runtime performance) for the GPT model (see the [Run LoRA with the Nemo checkpoint](examples/gpt/README.md#Run-LoRA-with-the-Nemo-checkpoint) in the GPT example)
  * Multi-GPU support for encoder-decoder models (Python runtime - see [examples/enc_dec](examples/enc_dec/README.md))
  * New heuristic for launching the Multi-block Masked MHA kernel (similar to FlashDecoding - see [decoderMaskedMultiheadAttentionLaunch.h](cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h))
  * Prompt-Tuning support for GPT and LLaMA models (see the [Prompt-tuning](examples/gpt/README.md#Prompt-tuning) Section in the GPT example)
  * Performance optimizations in various CUDA kernels
  * Possibility to exclude input tokens from the output (see `excludeInputInOutput` in [`GptManager`](cpp/include/tensorrt_llm/batch_manager/GptManager.h))
  * Python binding for the C++ runtime (GptSession - see [`pybind`](cpp/tensorrt_llm/pybind))
  * Support for different micro batch sizes for context and generation phases with pipeline parallelism (see `GptSession::Config::ctxMicroBatchSize` and `GptSession::Config::genMicroBatchSize` in [tensorrt_llm/runtime/gptSession.h](cpp/include/tensorrt_llm/runtime/gptSession.h))
  * Support for "remove input padding" for encoder-decoder models (see [examples/enc_dec](examples/enc_dec/README.md))
  * Support for context and generation logits (see `mComputeContextLogits` and `mComputeGenerationLogits` in [tensorrt_llm/runtime/gptModelConfig.h](cpp/include/tensorrt_llm/runtime/gptModelConfig.h))
  * Support for `logProbs` and `cumLogProbs` (see `"output_log_probs"` and `"cum_log_probs"` in [`GptManager`](cpp/include/tensorrt_llm/batch_manager/GptManager.h))
  * Update to CUTLASS 3.x

* Bug fixes
  * Fix for ChatGLM2 #93 and #138
  * Fix tensor names error "RuntimeError: Tensor names (`host_max_kv_cache_length`) in engine are not the same as expected in the main branch" #369
  * Fix weights split issue in BLOOM when `world_size = 2` ("array split does not result in an equal division") #374
  * Fix SmoothQuant multi-GPU failure with tensor parallelism is 2 #267
  * Fix a crash in GenerationSession if stream keyword argument is not None #202
  * Fix a typo when calling PyNVML API [BUG] code bug #410
  * Fix bugs related to the improper management of the `end_id` for various models [C++ and Python]
  * Fix memory leaks [C++ code and Python models]
  * Fix the std::alloc error when running the gptManagerBenchmark -- issue gptManagerBenchmark std::bad_alloc error #66
  * Fix a bug in pipeline parallelism when beam-width > 1
  * Fix a bug with Llama GPTQ due to improper support of GQA
  * Fix issue #88
  * Fix an issue with the Huggingface Transformers version #16
  * Fix link jump in windows readme.md #30 - by @yuanlehome
  * Fix typo in batchScheduler.h #56 - by @eltociear
  * Fix typo #58 - by @RichardScottOZ
  * Fix Multi-block MMHA: Difference between `max_batch_size` in the engine builder and `max_num_sequences` in TrtGptModelOptionalParams? #65
  * Fix the log message to be more accurate on KV cache #224
  * Fix Windows release wheel installation: Failed to install the release wheel for Windows using pip #261
  * Fix missing torch dependencies: [BUG] The batch_manage.a choice error in --cpp-only when torch's cxx_abi version is different with gcc #151
  * Fix linking error during compiling google-test & benchmarks #277
  * Fix logits dtype for Baichuan and ChatGLM: segmentation fault caused by the lack of bfloat16 #335
  * Minor bug fixes

## Version 0.5.0

* TensorRT-LLM v0.5.0 is the first public release.
@@ -6,9 +6,9 @@ TensorRT-LLM

[](https://nvidia.github.io/TensorRT-LLM/)
[](https://www.python.org/downloads/release/python-31012/)
[](https://developer.nvidia.com/cuda-downloads)
[](https://developer.nvidia.com/cuda-downloads)
[](https://developer.nvidia.com/tensorrt)
[](./setup.py)
[](./tensorrt_llm/version.py)
[](./LICENSE)

[Architecture](./docs/source/architecture/overview.md) | [Results](./docs/source/performance/perf-overview.md) | [Examples](./examples/) | [Documentation](./docs/source/)
@@ -906,7 +906,7 @@ public:
            [this](uint64_t requestId, std::list<NamedTensor> const& response_tensors, bool final_response,
                std::string const& errMsg)
            { return sendResponse(requestId, response_tensors, final_response, errMsg); },
            nullptr, iterationDataCallback, optionalParams, terminateReqId, std::nullopt, excludeInputInOutput);
            nullptr, iterationDataCallback, optionalParams, terminateReqId, excludeInputInOutput);
    }

    ~GptServer()
@@ -18,6 +18,8 @@ from typing import Optional, Tuple
import click
from pydantic import BaseModel, field_validator
from transformers import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from utils.prepare_real_data import dataset
from utils.prepare_synthetic_data import token_norm_dist
@@ -27,10 +29,12 @@ class RootArgs(BaseModel):
    output: str
    random_seed: int
    task_id: int
    std_out: bool
    rand_task_id: Optional[Tuple[int, int]]

    @field_validator('tokenizer')
    def get_tokenizer(cls, v: str):
    def get_tokenizer(cls,
                      v: str) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
        try:
            tokenizer = AutoTokenizer.from_pretrained(v, padding_side='left')
        except EnvironmentError as e:

@@ -53,6 +57,11 @@ class RootArgs(BaseModel):
              type=str,
              help="Output json filename.",
              default="preprocessed_dataset.json")
@click.option(
    "--stdout",
    is_flag=True,
    help="Print output to stdout with a JSON dataset entry on each line.",
    default=False)
@click.option("--random-seed",
              required=False,
              type=int,

@@ -80,6 +89,7 @@ def cli(ctx, **kwargs):

    ctx.obj = RootArgs(tokenizer=kwargs['tokenizer'],
                       output=kwargs['output'],
                       std_out=kwargs['stdout'],
                       random_seed=kwargs['random_seed'],
                       task_id=kwargs['task_id'],
                       rand_task_id=kwargs['rand_task_id'])
@@ -6,7 +6,7 @@ from typing import Optional
import click
from datasets import load_dataset
from pydantic import BaseModel, model_validator
from utils.utils import dataset_dump, get_norm_dist_tokens
from utils.utils import dataset_dump, get_norm_dist_tokens, print_dataset


def validate_output_len_dist(ctx, param, value):

@@ -220,11 +220,19 @@ def dataset(root_args, **kwargs):
    logging.debug(f"Input lengths: {[len(i) for i in input_ids]}")
    logging.debug(f"Output lengths: {output_lens}")

    dataset_dump(
        input_lens, input_ids, output_lens, task_ids, {
            "workload_type": "dataset",
            "tokenizer": root_args.tokenizer.__class__.__name__,
            "num_requests": len(input_ids),
            "max_input_len": max(input_lens),
            "max_output_len": max(output_lens)
        }, root_args.output)
    if not root_args.std_out:
        dataset_dump(
            input_lens, input_ids, output_lens, task_ids, {
                "workload_type": "dataset",
                "tokenizer": root_args.tokenizer.__class__.__name__,
                "num_requests": len(input_ids),
                "max_input_len": max(input_lens),
                "max_output_len": max(output_lens)
            }, root_args.output)
    else:
        print_dataset(
            task_ids,
            input_ids,
            output_lens,
            tokenizer=None,
        )
@ -1,7 +1,8 @@
|
||||
import random
|
||||
|
||||
import click
|
||||
from utils.utils import dataset_dump, gen_random_tokens, get_norm_dist_tokens
|
||||
from utils.utils import (dataset_dump, gen_random_tokens, get_norm_dist_tokens,
|
||||
print_dataset)
|
||||
|
||||
|
||||
@click.command()
|
||||
@ -55,15 +56,21 @@ def token_norm_dist(root_args, **kwargs):
|
||||
min_id, max_id = root_args.rand_task_id
|
||||
task_ids = [random.randint(min_id, max_id) for _ in range(num_reqs)]
|
||||
|
||||
dataset_dump(
|
||||
input_lens, input_ids, output_lens, task_ids, {
|
||||
"workload_type": "token-norm-dist",
|
||||
"input_mean": kwargs['input_mean'],
|
||||
"input_stdev": kwargs['input_stdev'],
|
||||
"output_mean": kwargs['output_mean'],
|
||||
"output_stdev": kwargs['output_stdev'],
|
||||
"num_requests": kwargs['num_requests'],
|
||||
"tokenize_vocabsize": root_args.tokenizer.vocab_size,
|
||||
"max_input_len": max_input_len,
|
||||
"max_output_len": max_output_len
|
||||
}, root_args.output)
|
||||
if not root_args.std_out:
|
||||
dataset_dump(
|
||||
input_lens, input_ids, output_lens, task_ids, {
|
||||
"workload_type": "token-norm-dist",
|
||||
"input_mean": kwargs['input_mean'],
|
||||
"input_stdev": kwargs['input_stdev'],
|
||||
"output_mean": kwargs['output_mean'],
|
||||
"output_stdev": kwargs['output_stdev'],
|
||||
"num_requests": kwargs['num_requests'],
|
||||
"tokenize_vocabsize": root_args.tokenizer.vocab_size,
|
||||
"max_input_len": max_input_len,
|
||||
"max_output_len": max_output_len
|
||||
}, root_args.output)
|
||||
else:
|
||||
print_dataset(
|
||||
input_ids,
|
||||
output_lens,
|
||||
)
|
||||
|
||||
@ -43,7 +43,17 @@ def dataset_dump(input_lens, input_ids, output_lens, task_ids, metadata,
|
||||
task_id=task_ids[i]))
|
||||
workload = Workload(metadata=metadata, samples=samples)
|
||||
with open(output_file, 'w') as f:
|
||||
json.dump(workload.dict(), f)
|
||||
json.dump(workload.model_dump(), f)
|
||||
|
||||
|
||||
def print_dataset(input_ids, output_lens):
|
||||
for i, input_tokens in enumerate(input_ids):
|
||||
d = {
|
||||
"task_id": i,
|
||||
"logits": input_tokens,
|
||||
"output_tokens": output_lens[i]
|
||||
}
|
||||
print(json.dumps(d, separators=(',', ':'), ensure_ascii=False))
|
||||
|
||||
|
||||
def get_list_of_delays(delay_dist, mean_time_bet_reqs, num_reqs, random_seed):
|
||||
|
||||
@ -19,12 +19,12 @@ from argparse import ArgumentParser
|
||||
import torch
|
||||
# isort: on
|
||||
from cuda import cuda, cudart
|
||||
from mpi4py import MPI
|
||||
from polygraphy.backend.trt import CreateConfig, EngineFromNetwork
|
||||
|
||||
import tensorrt_llm as tllm
|
||||
from tensorrt_llm import Mapping, Tensor
|
||||
from tensorrt_llm._ipc_utils import peer_access
|
||||
from tensorrt_llm._utils import OMPI_COMM_TYPE_HOST, mpi_comm
|
||||
from tensorrt_llm.functional import AllReduceStrategy, allreduce
|
||||
from tensorrt_llm.plugin.plugin import current_all_reduce_helper
|
||||
|
||||
@ -35,11 +35,14 @@ def allreduce_benchmark(dtype: str,
|
||||
tllm.logger.set_level('error')
|
||||
world_size = tllm.mpi_world_size()
|
||||
rank = tllm.mpi_rank()
|
||||
local_comm = mpi_comm().Split_type(split_type=OMPI_COMM_TYPE_HOST)
|
||||
local_rank = local_comm.Get_rank()
|
||||
gpus_per_node = local_comm.Get_size()
|
||||
|
||||
torch.cuda.set_device(rank)
|
||||
cudart.cudaSetDevice(rank)
|
||||
torch.cuda.set_device(local_rank)
|
||||
cudart.cudaSetDevice(local_rank)
|
||||
|
||||
mapping = Mapping(world_size, rank, world_size, world_size)
|
||||
mapping = Mapping(world_size, rank, gpus_per_node, world_size)
|
||||
|
||||
if world_size == 1:
|
||||
raise RuntimeError("Benchmark must run with mpi_world_size > 1")
|
||||
@ -58,8 +61,10 @@ def allreduce_benchmark(dtype: str,
|
||||
input = torch.ones(size, dtype=torch_dtype, device="cuda")
|
||||
|
||||
for strategy in [
|
||||
AllReduceStrategy.NCCL, AllReduceStrategy.ONESHOT,
|
||||
AllReduceStrategy.TWOSHOT
|
||||
AllReduceStrategy.AUTO,
|
||||
AllReduceStrategy.NCCL,
|
||||
AllReduceStrategy.ONESHOT,
|
||||
AllReduceStrategy.TWOSHOT,
|
||||
]:
|
||||
builder = tllm.Builder()
|
||||
net = builder.create_network()
|
||||
@ -81,9 +86,9 @@ def allreduce_benchmark(dtype: str,
|
||||
current = allreduce(current, mapping.tp_group, strategy)
|
||||
output = current.trt_tensor
|
||||
|
||||
network.mark_output(output)
|
||||
output.name = 'output'
|
||||
output.dtype = tllm.str_dtype_to_trt(dtype)
|
||||
network.mark_output(output)
|
||||
|
||||
build_engine = EngineFromNetwork(
|
||||
(builder.trt_builder, net.trt_network),
|
||||
@ -103,7 +108,7 @@ def allreduce_benchmark(dtype: str,
|
||||
_, stop = cuda.cuEventCreate(0)
|
||||
runtimes = []
|
||||
with peer_access(mapping):
|
||||
MPI.COMM_WORLD.barrier()
|
||||
tllm.mpi_barrier()
|
||||
|
||||
for _ in range(10):
|
||||
cuda.cuEventRecord(start, stream.cuda_stream)
|
||||
|
||||
@@ -61,6 +61,7 @@ class BuildConfig:
    parallel_attention: bool = None
    new_decoder_architecture: bool = None
    state_size: int = 0
    state_dtype: Optional[str] = None
    conv_kernel: int = 0
    layer_types: List[str] = field(default_factory=list)
    rnn_hidden_size: int = 0

@@ -479,6 +480,7 @@ _allowed_configs = {
        build_config=BuildConfig(
            num_layers=32,
            num_heads=32,
            num_kv_heads=8,
            hidden_size=4096,
            vocab_size=32000,
            hidden_act='swiglu',

@@ -503,7 +505,7 @@ _allowed_configs = {
            hidden_act='gelu',
            n_positions=1024,
            rotary_dim=64,
            max_batch_size=256,
            max_batch_size=128,
            max_input_len=512,
            max_output_len=200,
            builder_opt=None,

@@ -593,7 +595,7 @@ _allowed_configs = {
            vocab_size=250880,
            hidden_act=None,
            n_positions=2048,
            max_batch_size=8,
            max_batch_size=32,
            max_input_len=1024,
            max_output_len=1024,
            builder_opt=None,

@@ -1327,6 +1329,7 @@ _allowed_configs = {
            layer_types=["recurrent", "recurrent", "attention"],
            rnn_hidden_size=2560,
            logits_soft_cap=30.0,
            state_dtype="float32",
        )),
}
@@ -883,6 +883,9 @@ def build_gpt(args):
        config = PretrainedConfig.from_dict(config)
        tensorrt_llm_model = tensorrt_llm.models.RecurrentGemmaForCausalLM(
            config)
        tensorrt_llm_model = optimize_model(tensorrt_llm_model,
                                            use_fused_mlp=True,
                                            use_fused_rg_lru=True)

    else:
        raise Exception(f'Unexpected model: {args.model}')

@@ -894,15 +897,15 @@ def build_gpt(args):

    # Plugins
    if args.mode in ['plugin', 'plugin-ifb']:
        network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
        network.plugin_config.gpt_attention_plugin = args.dtype
        network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
        network.plugin_config.enable_remove_input_padding()
        network.plugin_config.set_lookup_plugin(dtype=args.dtype)
        network.plugin_config.set_moe_plugin(dtype=args.dtype)
        network.plugin_config.set_mamba_conv1d_plugin(dtype=args.dtype)
        network.plugin_config.remove_input_padding = True
        network.plugin_config.lookup_plugin = args.dtype
        network.plugin_config.moe_plugin = args.dtype
        network.plugin_config.mamba_conv1d_plugin = args.dtype

        if args.quantization is None or "fp8" not in args.quantization:
            network.plugin_config.set_gemm_plugin(dtype=args.dtype)
            network.plugin_config.gemm_plugin = args.dtype

        # Quantization plugins.
        use_smooth_quant = quant_mode.has_act_and_weight_quant()

@@ -910,21 +913,19 @@ def build_gpt(args):
        if use_smooth_quant:
            network.plugin_config.set_smooth_quant_plugins(dtype=args.dtype)
        elif use_weight_only:
            network.plugin_config.set_weight_only_quant_matmul_plugin(
                dtype=args.dtype)
            network.plugin_config.weight_only_quant_matmul_plugin = args.dtype
        elif family == "llama" and quant_mode.has_act_and_weight_quant():
            # RMS norm plugin for SmoothQuant
            network.plugin_config.set_rmsnorm_quantization_plugin(
                dtype=args.dtype)
            network.plugin_config.rmsnorm_quantization_plugin = args.dtype

        # Inflight batching
        if args.mode == 'plugin-ifb':
            network.plugin_config.enable_paged_kv_cache()
            network.plugin_config.enable_paged_state()
            network.plugin_config.paged_state = True
    elif args.mode == 'ootb-except-mha':
        network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
        network.plugin_config.gpt_attention_plugin = args.dtype
        network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
        network.plugin_config.enable_remove_input_padding()
        network.plugin_config.remove_input_padding = True

    if world_size > 1:
        network.plugin_config.set_nccl_plugin(

@@ -1056,12 +1057,12 @@ def build_bert(args):

    # Plugins
    if args.mode == 'plugin':
        network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
        network.plugin_config.set_gemm_plugin(dtype=args.dtype)
        network.plugin_config.enable_qk_half_accum()
        network.plugin_config.bert_attention_plugin = args.dtype
        network.plugin_config.gemm_plugin = args.dtype
        network.plugin_config.attention_qk_half_accumulation = True
        network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
    elif args.mode == 'ootb-except-mha':
        network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
        network.plugin_config.bert_attention_plugin = args.dtype
        network.plugin_config.set_context_fmha(ContextFMHAType.enabled)

    if world_size > 1:

@@ -1226,13 +1227,29 @@ def enc_dec_build_helper(component, config, args):

    if component == 'encoder':
        if family == 'whisper':
            tllm_model = tensorrt_llm.models.WhisperEncoder(
                n_mels=config['n_mels'],
                n_ctx=1500,  # n_audio_ctx
                n_state=config['hidden_size'],
                n_head=config['num_heads'],
                n_layer=config['num_layers'],
                dtype=dtype)
            pretrained_config = PretrainedConfig.from_dict({
                'architecture': "WhisperEncoder",
                'dtype': dtype,
                'num_hidden_layers': config['num_layers'],
                'num_attention_heads': config['num_heads'],
                'hidden_size': config['hidden_size'],
                'n_mels': config['n_mels'],
                'n_audio_ctx': 1500,
                'vocab_size': config['vocab_size'],
                'hidden_act': "gelu",
                'num_languages': 100,
            })
            tllm_model = tensorrt_llm.models.WhisperEncoder(pretrained_config)
            if use_weight_only:
                tllm_model = quantize(tllm_model, quant_config)
        else:

@@ -1387,15 +1404,14 @@ def enc_dec_build_helper(component, config, args):

    # Plugins
    if args.mode == 'plugin':
        network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
        network.plugin_config.set_gemm_plugin(dtype=args.dtype)
        network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
        network.plugin_config.bert_attention_plugin = args.dtype
        network.plugin_config.gemm_plugin = args.dtype
        network.plugin_config.gpt_attention_plugin = args.dtype
        if use_weight_only:
            network.plugin_config.set_weight_only_quant_matmul_plugin(
                dtype=args.dtype)
            network.plugin_config.weight_only_quant_matmul_plugin = args.dtype
    elif args.mode == 'ootb-except-mha':
        network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
        network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
        network.plugin_config.bert_attention_plugin = args.dtype
        network.plugin_config.gpt_attention_plugin = args.dtype

    if world_size > 1:
        network.plugin_config.set_nccl_plugin(

@@ -103,6 +103,7 @@ class EncDecBenchmark(BaseBenchmark):
                remove_input_padding=config["plugin_config"]["remove_input_padding"],
                cross_attention=config["builder_config"]["cross_attention"],
                skip_cross_qkv=config["builder_config"]["skip_cross_qkv"],
                has_position_embedding=config["builder_config"]["has_position_embedding"],
                has_token_type_embedding=config["builder_config"]

@@ -111,6 +111,7 @@ class GPTBenchmark(BaseBenchmark):
            self.use_mamba_conv1d_plugin = True
        elif args.mode == 'ootb-except-mha':
            self.use_gpt_attention_plugin = True
            self.remove_input_padding = True

        engine_buffer, build_time = build_gpt(args)
        self.weights_size_approx = engine_buffer.nbytes

@@ -123,6 +124,15 @@ class GPTBenchmark(BaseBenchmark):
        if not hasattr(self, 'num_kv_heads') or self.num_kv_heads is None:
            self.num_kv_heads = self.num_heads

        rnn_config_items = [
            'conv_kernel', 'layer_types', 'rnn_hidden_size', 'state_size',
            'state_dtype'
        ]
        rnn_configs_kwargs = {}
        for item in rnn_config_items:
            if hasattr(self, item):
                rnn_configs_kwargs[item] = getattr(self, item)

        model_config = tensorrt_llm.runtime.ModelConfig(
            max_batch_size=self.max_batch_size,
            max_beam_width=self.num_beams,

@@ -143,13 +153,8 @@ class GPTBenchmark(BaseBenchmark):
            tokens_per_block=self.tokens_per_block if hasattr(
                self, 'tokens_per_block') else 64,
            mamba_conv1d_plugin=self.use_mamba_conv1d_plugin,
            conv_kernel=self.conv_kernel if hasattr(self, 'conv_kernel') else 0,
            state_size=self.state_size if hasattr(self, 'state_size') else 0,
            layer_types=self.layer_types
            if hasattr(self, 'layer_types') else [],
            rnn_hidden_size=self.rnn_hidden_size if hasattr(
                self, 'rnn_hidden_size') else 0,
            gpu_weights_percent=list(sorted(gpu_weights_percents))[0],
            **rnn_configs_kwargs,
        )
        if args.model == 'chatglm_6b':
            self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
@@ -1,6 +1,7 @@
# TensorRT-LLM Benchmarking

**WORK IN PROGRESS**
> [!WARNING] Work in Progress
> This benchmarking suite is a current work in progress and is prone to large changes.

This package is the official benchmarking suite for TensorRT-LLM. This benchmark will be updated
as development of TensorRT-LLM continues.

@@ -9,7 +10,7 @@ as development of TensorRT-LLM continues.

From this folder, run `pip install -r requirements.txt` to install the extra dependencies required for this tool.

### Available Model Options
### Available Build and Benchmark Options

The following model options are available for benchmarking models.

@@ -17,7 +18,9 @@ The following model options are available for benchmarking models.
| :- | :-: | :-: | :- |
| `--model` | Y | - | The name of the model to benchmark. |
| `--dtype` | N | `float16` | The datatype of the weights. |
| `--max-batch-size` | Y | - | The batch size to build the engine with for the benchmark. |
| `--kv-dtype` | N | `float16` | The datatype to store the KV Cache in. |
| `--kv-cache-free-gpu-mem-fraction` | N | `0.98` | The percentage of free memory that the KV cache is allowed to occupy. |
| `--quantization` | N | `None` | The quantization algorithm to be used when benchmarking. See the [documentation](https://nvidia.github.io/TensorRT-LLM/precision.html) for more information. |
| `--workspace` | N | `/tmp` | The directory to store benchmarking intermediate files. |
| `--tensor-parallel-size` | N | `1` | Number of tensor parallel shards to run the benchmark with. |

@@ -35,7 +38,8 @@ The following model options are available for benchmarking models.

#### Supported Quantization Modes

TensorRT-LLM supports a number of quantization modes. For more information about quantization, see the [documentation](https://nvidia.github.io/TensorRT-LLM/precision.html).
TensorRT-LLM supports a number of quantization modes. For more information about quantization, see the
[documentation](https://nvidia.github.io/TensorRT-LLM/precision.html).

- None (no quantization applied)
- W8A16

@@ -54,7 +58,7 @@ In order to benchmark a static batch for a network, run a command like the following:

```shell
cd tensorrt_llm_bench/
python benchmark.py --model tiiuae/falcon-7b static --isl 128 --osl 128 --batch 1
python benchmark.py --model tiiuae/falcon-7b static --isl 128 --osl 128 --max-batch-size 1
```

This command line will build a unique engine for the configuration and run the benchmark using

@@ -64,18 +68,167 @@ the `gptSessionBenchmark` binary. You need to build the TensorRT-LLM wheel with
python3 ./scripts/build_wheel.py --benchmarks <other options>
```

The complete list of arguments are given here:
If you've already compiled the wheel without benchmarks, you can build the benchmarking binaries after the fact with:

```shell
pushd cpp/build/
make -j benchmarks
popd
```

The complete list of arguments for static benchmarking is as follows:
| Option | Required | Default | Description |
| :- | :-: | :-: | :- |
| `--batch` | Y | - | The batch size to benchmark. |
| `--isl` | Y | - | The input sequence length to pass in during benchmark. |
| `--osl` | Y | - | The output sequence length to generate in the benchmark. |
| `--gpt-session-path` | N | `../../cpp/build/benchmarks/gptSessionBenchmark` | The path to the built gptSessionBenchmark binary. |
| `--max-tokens-in-kv-cache` | N | `None` | The maximum number of tokens to store in the KV Cache during benchmarking. |
| `--kv-cache-mem-percent` | N | `0.9` | The percentage of free memory that the KV cache is allowed to occupy. |
| `--warm-up-runs` | N | `2` | The number of warm-up runs to run before benchmarking actual results. |
| `--num-runs` | N | `10` | The number of runs to generate benchmarking results from. |
| `--duration` | N | `60` | The minimum iteration time, in seconds, to measure. |

> [!WARNING]
> `gptSession` will be deprecated for the 1.0 release of TensorRT-LLM. This command line will change accordingly, and the benchmarks will be updated to match.


## Inflight Benchmarking with a Dataset

This section covers how to benchmark TensorRT-LLM using inflight batching.

### Workflow

The workflow for inflight batching is slightly different than the [static scenario](#static-benchmarking-a-network) because it requires a workload of requests instead of a single static batch. The workflow for benchmarking using inflight batching is:

1. Prepare a dataset to drive the inflight batching benchmark.
2. Run the `inflight` benchmarking subcommand and provide the dataset from step 1.

#### Preparing a Dataset

The inflight benchmark utilizes a fixed JSON schema so that it is simple and
straightforward to specify requests. The schema is defined as follows:

| Key | Required | Type | Description |
| :- | :-: | :-: | :- |
| `task_id` | Y | String | Unique identifier for the request. |
| `prompt` | N* | String | Input text for a generation request. |
| `logits` | N* | List[Integer] | List of logits that make up the request prompt. |
| `output_tokens` | Y | Integer | Number of generated tokens for this request. |

> [!NOTE] Prompt and logits are mutually exclusive*
> While having both `prompt` and `logits` is not required, at least one must be provided.
> If `logits` are specified, the `prompt` entry is ignored for request generation.

Examples of valid entries for the inflight benchmark are:

- Entries with a human-readable prompt and no logits.
  ```json
  {"task_id": 1, "prompt": "Generate an infinite response to the following: This is the song that never ends, it goes on and on my friend.", "output_tokens": 1000}
  {"task_id": 2, "prompt": "Generate an infinite response to the following: Na, na, na, na", "output_tokens": 1000}
  ```

- Entries which contain logits.
  ```json
  {"task_id":0,"logits":[863,22056,25603,11943,8932,13195,3132,25032,21747,22213],"output_tokens":128}
  {"task_id":1,"logits":[14480,13598,15585,6591,1252,8259,30990,26778,7063,30065,21764,11023,1418],"output_tokens":128}
  ```

> [!INFO] A whole entry is on a line!
> To make the passing of data simpler, a complete JSON entry is on each line so that the benchmarker
> can simply read a line and assume a complete entry. When creating a dataset, be sure that a complete
> JSON entry is on every line.
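As a quick illustration of the schema above, the following standalone sketch (not part of the repository) writes a small JSON-lines dataset with one complete entry per line; the output file name `dataset.jsonl` and the sample prompt and token ids are arbitrary values chosen for the example.

```python
import json

# Illustrative only: entries following the schema described above.
# Each entry needs `task_id`, `output_tokens`, and either `prompt` or `logits`.
entries = [
    {"task_id": 0, "prompt": "Summarize the plot of a space opera.", "output_tokens": 128},
    {"task_id": 1, "logits": [863, 22056, 25603, 11943], "output_tokens": 128},
]

# Write one complete JSON object per line, as the benchmarker expects.
with open("dataset.jsonl", "w") as f:
    for entry in entries:
        f.write(json.dumps(entry, separators=(",", ":")) + "\n")
```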
#### Using `prepare_dataset` to Create Synthetic Datasets

In order to prepare a synthetic dataset, you can use the provided script in the `benchmarks/cpp`
directory. For example, to generate a synthetic dataset of 1000 requests with a uniform ISL/OSL of
128/128 for [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b), simply run:

```shell
benchmarks/cpp/prepare_dataset.py --tokenizer meta-llama/Llama-2-7b-hf token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000 --stdout
```

You can pipe the above command to a file to reuse the same dataset, or simply pipe its output to the
benchmark script (example below).

### Running a Dataset with the Benchmarker

Once you've generated a dataset (see [above](#preparing-a-dataset)), you can run the benchmarker
in one of two ways:

```shell
benchmarks/suite/tensorrt_llm_bench/benchmark.py --model $HF_MODEL_NAME --max-batch-size $BATCH_SIZE < $DATASET_PATH
```

> [!INFO] Alternative to piping.
> There is also a `--dataset` option for `benchmark.py` that can be used instead of piping a file.

or

```shell
benchmarks/cpp/prepare_dataset.py --tokenizer $HF_MODEL_NAME --input-mean $ISL --output-mean $OSL --num-requests $NUM_REQUESTS --stdout | benchmarks/suite/tensorrt_llm_bench/benchmark.py --model $HF_MODEL_NAME --max-batch-size $BATCH_SIZE --request-rate $REQUEST_RATE
```

#### How the Benchmarker Works

The benchmarker reads a data file or standard input (stdin) as a stream where each line contains
a complete JSON request entry. The process the benchmarker follows is outlined below (a short sketch of step 1 appears after the list):

1. Iterate over all input requests. If `logits` is specified, construct the request using the specified
list of logits. Otherwise, tokenize the `prompt` using the tokenizer specified by `--model $HF_MODEL_NAME`.
2. Build the TensorRT-LLM engine.
3. Submit the dataset to the TensorRT-LLM `Executor` API at the request rate specified by `--request-rate $REQUEST_RATE`.
4. Wait for all requests to return, compute statistics, and then report out results.
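A minimal sketch of step 1, assuming a Hugging Face checkpoint name and JSON-lines input on stdin; this is an illustration of how each entry could be turned into token ids and an output length, not the suite's actual implementation.

```python
import json
import sys

from transformers import AutoTokenizer

# Illustrative only: the model name is an assumption for the example.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

requests = []
for line in sys.stdin:
    entry = json.loads(line)
    if "logits" in entry:
        # Pre-tokenized request: use the provided token ids directly.
        input_ids = entry["logits"]
    else:
        # Otherwise tokenize the human-readable prompt.
        input_ids = tokenizer.encode(entry["prompt"])
    requests.append((entry["task_id"], input_ids, entry["output_tokens"]))
```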
When the benchmark runs successfully, you will see a report of the run similar to the following:

```
[RANK 0] Submitting requests...
[RANK 0] Completed request submission.
[RANK 0] Calculating results.
[RANK 0] Reporting...
[RANK 0] JSON: {'benchmark_cmd': '', 'binary': '', 'build_cmd': 'trtllm-build --output_dir /tmp/meta-llama/llama-2-7b-hf --model_config /tmp/generated_config.json --workers 1 --max_batch_size 1024 --max_input_len 128 --max_output_len 128 --max_num_tokens 8000 --context_fmha enable --gpt_attention_plugin float16 --paged_kv_cache enable --multiple_profiles enable --gemm_plugin float16', 'first_token_latency': 0.0, 'inflight_batching': True, 'kv_mem_fraction': 0.98, 'latency_units': 'ms', 'max_batch_size': 1024, 'max_tokens': 8000, 'model': 'meta-llama/Llama-2-7b-hf', 'peak_gpu_mem_units': 'GB', 'peak_gpu_mem': 0.0, 'scheduler': 'Max Utilization', 'throughput_units': 'tokens/second', 'throughput': 17634.422523488243, 'time_per_output_token': 0.0, 'total_input_tokens': 128000, 'total_latency': 7.258530855178833, 'total_output_tokens': 128000}
===========================================================
= METADATA
===========================================================
Model:                  meta-llama/Llama-2-7b-hf
TP Size:                1
PP Size:                1
Scheduling Policy:      Max Utilization
In-flight Batcher?:     True
Dtype:                  float16
KV Cache Dtype:         FP8
Quantization:           FP8
KV Memory Percentage:   98.0%

===========================================================
= ENGINE DETAILS
===========================================================
Engine Directory:       /tmp/meta-llama/llama-2-7b-hf
Max Batch Size:         1024
Total Input Length:     128000
Total Output Length:    128000
Max Tokens:             8000

===========================================================
= STATISTICS
===========================================================
Throughput (tokens/second):     17634.422523488243
Total Latency (ms):             7258.5309
First Token Latency (ms):       0.0
Token-to-token Latency (ms):    0.0
Peak GPU Memory Usage (GB):     0.0

===========================================================
= COMMANDS
===========================================================
Build: trtllm-build --output_dir /tmp/meta-llama/llama-2-7b-hf --model_config /tmp/generated_config.json --workers 1 --max_batch_size 1024 --max_input_len 128 --max_output_len 128 --max_num_tokens 8000 --context_fmha enable --gpt_attention_plugin float16 --paged_kv_cache enable --multiple_profiles enable --gemm_plugin float16
Benchmark:

[RANK 0] Terminating.
```

> [!WARNING] Some statistics are not reported.
> There are some statistics that are not reported in the summary (typically shown as 0.0). These statistics
> are not available currently.


That's it! -- you've successfully benchmarked TensorRT-LLM!
@@ -2,9 +2,9 @@ from pathlib import Path
from typing import get_args

import click
from ifb import executor_benchmark
from static import static_benchmark
from utils import (VALID_CACHE_DTYPES, VALID_COMPUTE_DTYPES, VALID_MODELS,
                   VALID_QUANT_ALGOS)
from utils import VALID_CACHE_DTYPES, VALID_COMPUTE_DTYPES, VALID_QUANT_ALGOS
from utils.dataclasses import BenchmarkConfig


@@ -13,9 +13,16 @@ from utils.dataclasses import BenchmarkConfig
    "--model",
    "-m",
    required=True,
    type=click.Choice(tuple(get_args(VALID_MODELS))),
    type=str,
    help="The Huggingface name of the model to benchmark.",
)
@click.option(
    "--max-batch-size",
    hidden=True,
    default=0,
    type=int,
    help="Maximum batch size to build the benchmark engine with.",
)
@click.option(
    "--kv-dtype",
    type=click.Choice(tuple(get_args(VALID_CACHE_DTYPES))),

@@ -38,9 +45,11 @@ from utils.dataclasses import BenchmarkConfig
    "documentations for more information.\n"
    " - https://nvidia.github.io/TensorRT-LLM/precision.html"
    " - https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/quantization-in-TRT-LLM.md"
))
),
)
@click.option(
    "--workspace",
    "-w",
    required=False,
    type=click.Path(writable=True, readable=True),
    default="/tmp",

@@ -62,26 +71,46 @@ from utils.dataclasses import BenchmarkConfig
    required=False,
    help="Number of pipeline parallel shards to run the benchmark with.",
)
@click.option(
    "--kv-cache-free-gpu-mem-fraction",
    "-kv-mem",
    type=float,
    default=0.98,
    help="The percentage of free memory that the KV Cache is allowed to occupy.",
)
@click.option(
    "--build-opts",
    type=str,
    default="",
    required=False,
    hidden=True,
    help="Passthrough options for trtllm-build to fine-tuning build commands.")
@click.pass_context
def benchmark(
    ctx,
    model: str,
    max_batch_size: int,
    workspace: Path,
    dtype: str,
    kv_dtype: str,
    quantization: str,
    tensor_parallel_size: int,
    pipeline_parallel_size: int,
    kv_cache_free_gpu_mem_fraction: float,
    build_opts: str,
):
    """Utility for using TRT-LLM for benchmarking networks from Huggingface."""
    ctx.obj = BenchmarkConfig(
        model=model,
        max_batch_size=max_batch_size,
        workspace=Path(workspace),
        dtype=dtype,
        cache_dtype=kv_dtype,
        quantization=quantization,
        tensor_parallel=tensor_parallel_size,
        pipeline_parallel=pipeline_parallel_size,
        kv_cache_mem_percentage=kv_cache_free_gpu_mem_fraction,
        build_overrides=build_opts.split(),
    )

    # Create the workspace where we plan to store intermediate files.

@@ -90,6 +119,7 @@ def benchmark(

# Add nested subcommands to main benchmark CLI.
benchmark.add_command(static_benchmark)
benchmark.add_command(executor_benchmark)

if __name__ == "__main__":
    benchmark()
benchmarks/suite/tensorrt_llm_bench/benchmarkers/__init__.py (new file, 30 lines)
@@ -0,0 +1,30 @@
from typing import List, Protocol

from utils.dataclasses import BenchmarkResults, InferenceRequest


class Benchmarker(Protocol):
    """Protocol for defining benchmarking classes for building/benchmarking."""

    def build(self) -> None:
        """Build a model to be benchmarked."""
        ...

    def benchmark(self) -> BenchmarkResults:
        """Benchmark the constructed model container by a benchmarker."""
        ...


class DatasetBenchmarker(Protocol):

    def benchmark_dataset(self,
                          dataset: List[InferenceRequest]) -> BenchmarkResults:
        """_summary_

        Args:
            dataset (List[InferenceRequest]): List of inference requests to benchmark.

        Returns:
            BenchmarkResults: The results of the benchmark run.
        """
        ...
@@ -0,0 +1,146 @@
from datetime import timedelta
from time import sleep, time
from typing import List

from mpi4py.MPI import COMM_WORLD
from transformers import PreTrainedTokenizer
from utils.dataclasses import BenchmarkConfig, BenchmarkResults
from utils.enums import IFBSchedulingPolicy, ResultsSchedulingPolicy

from tensorrt_llm.bindings.executor import (Executor, ExecutorConfig,
                                            KvCacheConfig, ModelType,
                                            OutputConfig, Request,
                                            SchedulerConfig)

from . import InferenceRequest


class PybindExecutorBenchmarker:
    """Utility class for running inflight benchmarks via the Executor API."""

    def __init__(
        self,
        config: BenchmarkConfig,
    ):
        """Initialize a gptSessionBenchmark instance.

        Args:
            config (BenchmarkConfig): Benchmark configuration for build/run.
        """
        self.config: BenchmarkConfig = config

    @staticmethod
    def get_request(request: InferenceRequest,
                    tokenizer: PreTrainedTokenizer) -> Request:
        return Request(
            input_token_ids=request.logits,
            max_new_tokens=request.output_tokens,
            stop_words=[],
            bad_words=[],
            streaming=False,
            output_config=OutputConfig(exclude_input_from_output=True),
            pad_id=tokenizer.pad_token_id,
            end_id=tokenizer.eos_token_id,
        )

    def initialize_executor(self) -> Executor:
        """
        Initialize an Executor instance.

        Returns:
            Executor: An instance of a TensorRT-LLM Executor.
        """
        policy = IFBSchedulingPolicy(self.config.scheduling_policy).value
        executor_config: ExecutorConfig = ExecutorConfig(
            max_beam_width=1,
            enable_chunked_context=self.config.chunking,
            scheduler_config=SchedulerConfig(
                capacity_scheduler_policy=policy, ),
            kv_cache_config=KvCacheConfig(
                free_gpu_memory_fraction=self.config.kv_cache_mem_percentage, ),
        )

        executor: Executor = Executor(
            model_path=self.config.engine_path,
            model_type=ModelType.DECODER_ONLY,
            executor_config=executor_config,
        )

        return executor

    def benchmark_dataset(self, rate: int,
                          dataset: List[InferenceRequest]) -> BenchmarkResults:
        """Benchmark the Executor Pybind interface.

        Args:
            dataset (List[InferenceRequest]): List of inference requests to
            benchmark with.

        Returns:
            BenchmarkResults: Final results from running the specified dataset.
        """
        request_ids = []
        num_finished = 0
        num_errored = 0
        num_input_tokens = 0
        num_output_tokens = 0
        delay = 1.0 / float(rate)
        last_request = len(dataset) - 1
        bench_result = None

        executor = self.initialize_executor()
        if executor.can_enqueue_requests():
            print(f"[RANK {COMM_WORLD.rank}] Submitting requests...")
            start = time()
            for i, request in enumerate(dataset):
                sleep_time = delay if i != last_request else 0
                request_ids.append(executor.enqueue_request(request))
                num_input_tokens += len(request.input_token_ids)
                sleep(sleep_time)
            print(f"[RANK {COMM_WORLD.rank}] Completed request submission.")

            while num_finished <= last_request:
                responses = executor.await_responses(timeout=timedelta(
                    milliseconds=1))
                for response in responses:
                    has_error = response.has_error()
                    num_finished += 1
                    num_errored += 1 if has_error else 0

                    if not has_error:
                        result = response.result
                        for out_tokens in result.output_token_ids:
                            num_output_tokens += len(out_tokens)
            end = time()
            print(f"[RANK {COMM_WORLD.rank}] Calculating results.")
            e2e_time = end - start
            e2e_time * 1000.0
            policy = ResultsSchedulingPolicy(
                IFBSchedulingPolicy(self.config.scheduling_policy).value)

            bench_result = BenchmarkResults(
                model=self.config.model,
                dtype=self.config.dtype.value,
                quantization=str(self.config.quantization.value),
                max_batch_size=self.config.max_batch_size,
                total_input_tokens=num_input_tokens,
                total_output_tokens=num_output_tokens,
                tp_size=self.config.tensor_parallel,
                pp_size=self.config.pipeline_parallel,
                kv_mem_fraction=self.config.kv_cache_mem_percentage,
                scheduler=policy.value,
                max_tokens=self.config.max_tokens,
                inflight_batching=True,
                total_latency=e2e_time,
                first_token_latency=0,
                time_per_output_token=0,
                latency_units="ms",
                throughput=num_output_tokens / e2e_time,
                throughput_units="tokens/second",
                peak_gpu_mem=0.0,
                peak_gpu_mem_units="GB",
                build_cmd="",
                benchmark_cmd="",
            )

        return bench_result
209
benchmarks/suite/tensorrt_llm_bench/benchmarkers/static.py
Normal file
209
benchmarks/suite/tensorrt_llm_bench/benchmarkers/static.py
Normal file
@ -0,0 +1,209 @@
|
||||
from pathlib import Path
|
||||
from subprocess import CompletedProcess
|
||||
from typing import Dict, List
|
||||
|
||||
from utils import command_logger, process_error_check, run_process
|
||||
from utils.dataclasses import BenchmarkConfig, BenchmarkResults
|
||||
from utils.trtllm_config import TRTLLMConfig
|
||||
|
||||
|
||||
class gptSessionBenchmarker:
|
||||
"""Utility class for running static benchmarks with gptSessionBenchmark."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BenchmarkConfig,
|
||||
benchmark_binary: Path,
|
||||
batch_size: int,
|
||||
isl: int,
|
||||
osl: int,
|
||||
warm_up_runs: int,
|
||||
num_runs: int,
|
||||
duration: int,
|
||||
kv_cache_free_fraction: float = .9,
|
||||
):
|
||||
"""Initialize a gptSessionBenchmark instance.
|
||||
|
||||
Args:
|
||||
config (BenchmarkConfig): Benchmark configuration for build/run.
|
||||
benchmark_binary (Path): Path to the benchmarking binary.
|
||||
batch_size (int): Batch size to configure the build with.
|
||||
isl (int): Input sequence length to configure the build with.
|
||||
osl (int): Output sequence length to configure the build with.
|
||||
kv_cache_free_fraction (float, optional): The amount of remaining
|
||||
GPU memory after model loading to save for the KV Cache. Defaults
|
||||
to .9.
|
||||
"""
|
||||
self.config: BenchmarkConfig = config
|
||||
self.gpt_session_path = Path(benchmark_binary).absolute()
|
||||
self.batch_size = batch_size
|
||||
self.input_length = isl
|
||||
self.output_length = osl
|
||||
self.warm_up = warm_up_runs
|
||||
self.num_runs = num_runs
|
||||
self.duration = duration
|
||||
self.kv_cache_mem = kv_cache_free_fraction
|
||||
self.result = None
|
||||
|
||||
def get_build_command(self) -> List[str]:
|
||||
"""Build the engine command for TRT-LLM.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of command line arguments to run a build command.
|
||||
"""
|
||||
model = self.config.model
|
||||
tp = self.config.tensor_parallel
|
||||
pp = self.config.pipeline_parallel
|
||||
dtype = self.config.dtype.value
|
||||
kv_dtype = self.config.cache_dtype
|
||||
quant_algo = self.config.quantization.value
|
||||
output_dir = self.config.engine_path
|
||||
max_batch_size = self.batch_size
|
||||
max_isl = self.input_length
|
||||
max_osl = self.output_length
|
||||
workspace = self.config.workspace
|
||||
|
||||
# Generate the TRT-LLM Configuration file using the dataclass
|
||||
# NOTE: This method does not use weights.
|
||||
trtllm_config = TRTLLMConfig.from_hf(model, tp, pp, dtype, quant_algo,
|
||||
kv_dtype.value)
|
||||
# Write the generated configuration file to the benchmark workspace.
|
||||
trtllm_config.to_json(workspace)
|
||||
|
||||
# Return the full command for building TRT-LLM via subprocess call.
|
||||
cmd = [
|
||||
"trtllm-build",
|
||||
"--output_dir",
|
||||
output_dir,
|
||||
"--model_config",
|
||||
Path(workspace, "generated_config.json"),
|
||||
"--workers",
|
||||
self.config.world_size,
|
||||
# Define the maximums the engine can accept.
|
||||
"--max_batch_size",
|
||||
max_batch_size,
|
||||
"--max_input_len",
|
||||
max_isl,
|
||||
"--max_output_len",
|
||||
max_osl,
|
||||
"--context_fmha",
|
||||
"enable",
|
||||
# Set the attention plugin data type.
|
||||
"--gpt_attention_plugin",
|
||||
dtype,
|
||||
# Disable paged cache since we aren't batching on the fly.
|
||||
"--paged_kv_cache",
|
||||
"disable",
|
||||
] + kv_dtype.get_build_options(dtype)
|
||||
|
||||
return [str(arg) for arg in cmd]
|
||||
|
||||
@command_logger(prefix="BUILD COMMAND: ")
|
||||
@process_error_check
|
||||
def _run_build(self, cmd: List[str]) -> CompletedProcess:
|
||||
"""Wrapper for calling the build for TRT-LLM.
|
||||
|
||||
Purpose of this wrapper is so that we can decorate it/log it.
|
||||
|
||||
Args:
|
||||
cmd (List[str]): List of command line arguments for running.
|
||||
|
||||
Returns:
|
||||
CompletedProcess: Completed process information for parsing and
|
||||
reporting.
|
||||
"""
|
||||
return run_process(
|
||||
cmd,
|
||||
self.config.workspace,
|
||||
)
|
||||
|
||||
def build(self) -> None:
|
||||
"""Build the engine for benchmarking."""
|
||||
self._run_build(self.get_build_command())
|
||||
|
||||
@command_logger(prefix="BENCHMARK COMMAND: ")
|
||||
@process_error_check
|
||||
def _run_benchmark(self, cmd: List[str]) -> CompletedProcess:
|
||||
"""Run the benchmark command in the configured workspace.
|
||||
|
||||
Args:
|
||||
cmd (List[str]): List of command line arguments to run via
|
||||
subprocess.
|
||||
|
||||
Returns:
|
||||
CompletedProcess: Completed process information for reporting.
|
||||
"""
|
||||
return run_process(cmd, run_dir=self.config.workspace, use_environ=True)
|
||||
|
||||
@staticmethod
|
||||
def parse_benchmark_result(benchmark_line: str) -> Dict[str, str]:
|
||||
pass
|
||||
|
||||
def benchmark(self):
|
||||
"""Benchmarks a TRT-LLM for a configured instance."""
|
||||
|
||||
# Compile the command for running
|
||||
cmd = [
|
||||
"mpirun",
|
||||
"-allow-run-as-root",
|
||||
"-n",
|
||||
self.config.world_size,
|
||||
self.gpt_session_path,
|
||||
"--engine_dir",
|
||||
self.config.engine_path,
|
||||
"--batch_size",
|
||||
self.batch_size,
|
||||
"--log_level",
|
||||
"info",
|
||||
"--kv_cache_free_gpu_mem_fraction",
|
||||
self.kv_cache_mem,
|
||||
"--beam_width",
|
||||
"1",
|
||||
"--warm_up",
|
||||
self.warm_up,
|
||||
"--num_runs",
|
||||
self.num_runs,
|
||||
"--duration",
|
||||
self.duration,
|
||||
"--input_output_len",
|
||||
f"{self.input_length},{self.output_length};{self.input_length},1",
|
||||
]
|
||||
cmd = [str(arg) for arg in cmd]
|
||||
# Run the benchmark using the provided gptSession benchmark binary.
|
||||
bench_return = self._run_benchmark(cmd)
|
||||
results = [
|
||||
x.split(" ") for x in bench_return.stdout.split("\n")
|
||||
if "[BENCHMARK]" in x
|
||||
]
|
||||
|
||||
ttft = float(results[1][8])
|
||||
gen_time = float(results[0][8]) - ttft
|
||||
total_out = int(results[0][2]) * int(results[0][6])
|
||||
total_in = int(results[0][2]) * int(results[0][4])
|
||||
batch_size = int(results[0][2])
|
||||
|
||||
bench_result = BenchmarkResults(
|
||||
model=self.config.model,
|
||||
dtype=self.config.dtype.value,
|
||||
quantization=str(self.config.quantization.value),
|
||||
max_batch_size=batch_size,
|
||||
total_input_tokens=total_in,
|
||||
total_output_tokens=total_out,
|
||||
tp_size=self.config.tensor_parallel,
|
||||
pp_size=self.config.pipeline_parallel,
|
||||
kv_mem_fraction=self.kv_cache_mem,
|
||||
scheduler="Static",
|
||||
inflight_batching=False,
|
||||
total_latency=results[0][8],
|
||||
first_token_latency=ttft,
|
||||
time_per_output_token=gen_time / (total_out - batch_size),
|
||||
latency_units="ms",
|
||||
throughput=results[0][10],
|
||||
throughput_units="tokens/second",
|
||||
peak_gpu_mem=results[0][16],
|
||||
peak_gpu_mem_units="GB",
|
||||
binary=str(self.gpt_session_path),
|
||||
build_cmd=" ".join(self.get_build_command()),
|
||||
benchmark_cmd=" ".join(cmd))
|
||||
|
||||
return bench_result
|
||||
338
benchmarks/suite/tensorrt_llm_bench/ifb.py
Normal file
@ -0,0 +1,338 @@
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import List, TextIO, Tuple
|
||||
|
||||
import click
|
||||
from benchmarkers.pybind_executor import PybindExecutorBenchmarker
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
from utils.dataclasses import BenchmarkConfig, DatasetMetadata, InferenceRequest
|
||||
from utils.trtllm_config import TRTLLMConfig
|
||||
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
|
||||
def create_dataset_from_stream(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_input_length: int = 0,
|
||||
max_output_length: int = 0,
|
||||
stream: TextIO = sys.stdin,
|
||||
) -> Tuple[DatasetMetadata, List[InferenceRequest]]:
|
||||
"""Generate metadata and a list of requests to drive benchmarking.
|
||||
|
||||
Args:
|
||||
tokenizer (PreTrainedTokenizer): HuggingFace tokenizer.
|
||||
max_input_length (int): Maximum input length to cap prompts to.
max_output_length (int): Maximum output length to cap requested outputs to.
stream (TextIO): Stream of JSON request lines to read from (stdin by default).
|
||||
|
||||
Returns:
|
||||
DatasetMetadata: Dataclass of dataset statistics.
|
||||
List[InferenceRequest]: A list of inference requests for benchmarking.
|
||||
"""
|
||||
# Initialize dataset list, and metadata tracking variables.
|
||||
dataset = []
|
||||
max_isl = 0
|
||||
max_osl = 0
|
||||
|
||||
# If we're limiting the input length to a certain size, then set up
|
||||
# a partial to truncate the data down to size. Otherwise, just use the
|
||||
# unmodified tokenizer callable.
|
||||
tokenize = (partial(
|
||||
tokenizer,
|
||||
padding="max_length",
|
||||
max_length=max_input_length,
|
||||
truncation=True,
|
||||
) if max_input_length > 0 else tokenizer)
|
||||
|
||||
# If we need to limit the output length, fill in a partial callable
|
||||
# for max, otherwise a lambda that just returns x with no bounds.
|
||||
output_limiter = (partial(max, max_output_length)
|
||||
if max_output_length > 0 else lambda x: x)
|
||||
|
||||
# For each line in the standard input, parse out the JSON string we expect
|
||||
# to see.
|
||||
# Note the := walrus -- we're assigning and checking the condition.
|
||||
while line := stream.readline():
|
||||
# We expect the data to come in as a JSON string.
|
||||
# For example:
|
||||
# {"prompt": "Generate an infinite response to the following: There once was a man who.", "output_tokens": 1000}
|
||||
# Each line should be a complete JSON dictionary with no indentation
|
||||
# or newline characters.
|
||||
data = json.loads(line)
|
||||
logits = data.get("logits", None)
|
||||
prompt = data.get("prompt", None)
|
||||
task_id = data["task_id"]
|
||||
osl = data["output_tokens"]
|
||||
# If the request comes in with logits, just use the provided.
|
||||
# Otherwise we need to tokenize it.
|
||||
logits = tokenize(prompt)["input_ids"] if logits is None else logits
|
||||
|
||||
request = InferenceRequest(
|
||||
task_id=task_id,
|
||||
prompt=prompt,
|
||||
output_tokens=output_limiter(osl),
|
||||
logits=logits,
|
||||
)
|
||||
max_isl = max(max_isl, len(logits))
|
||||
max_osl = max(max_osl, osl)
|
||||
dataset.append(request)
|
||||
|
||||
# Fill in basic dataset metrics here
|
||||
# TODO: Maybe fill this out to be more complete?
|
||||
metadata = DatasetMetadata(
|
||||
max_isl=max_isl,
|
||||
max_osl=max_osl,
|
||||
num_requests=len(dataset),
|
||||
)
|
||||
|
||||
return metadata, dataset
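
Note: a self-contained sketch of the line format create_dataset_from_stream() expects on its input stream; the prompt text, token ids, and counts below are made-up placeholders, and only the key names ("task_id", "prompt", "logits", "output_tokens") come from the parsing code above:

import io
import json

requests = [
    {"task_id": 0, "prompt": "There once was a man who", "output_tokens": 128},
    {"task_id": 1, "logits": [101, 2045, 2320, 2001], "output_tokens": 64},
]
stream = io.StringIO("\n".join(json.dumps(r) for r in requests) + "\n")

# Same walrus-driven read loop as above: one complete JSON object per line.
while line := stream.readline():
    data = json.loads(line)
    print(data["task_id"], data.get("prompt"), data.get("logits"), data["output_tokens"])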
|
||||
|
||||
|
||||
def initialize_tokenizer(model_name: str) -> PreTrainedTokenizer:
|
||||
"""Initialize a tokenizer.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the HuggingFace model to pull a
|
||||
tokenizer from.
|
||||
|
||||
Returns:
|
||||
PreTrainedTokenizer: An initialized HuggingFace tokenizer.
|
||||
"""
|
||||
# Initialize the tokenizer specific to the model that we are planning
|
||||
# to benchmark.
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_trtllm_build_command(benchmark_cfg: BenchmarkConfig) -> List[str]:
|
||||
model = benchmark_cfg.model
|
||||
tp = benchmark_cfg.tensor_parallel
|
||||
pp = benchmark_cfg.pipeline_parallel
|
||||
dtype = benchmark_cfg.dtype.value
|
||||
kv_dtype = benchmark_cfg.cache_dtype
|
||||
quant_algo = benchmark_cfg.quantization.value
|
||||
output_dir = benchmark_cfg.engine_path
|
||||
max_batch_size = benchmark_cfg.max_batch_size
|
||||
max_isl = benchmark_cfg.engine_isl
|
||||
max_osl = benchmark_cfg.engine_osl
|
||||
max_tokens = benchmark_cfg.max_tokens
|
||||
workspace = benchmark_cfg.workspace
|
||||
|
||||
# Generate the TRT-LLM Configuration file using the dataclass
|
||||
# NOTE: This method does not use weights.
|
||||
trtllm_config = TRTLLMConfig.from_hf(model, tp, pp, dtype, quant_algo,
|
||||
kv_dtype.value)
|
||||
# Write the generated configuration file to the benchmark workspace.
|
||||
trtllm_config.to_json(workspace)
|
||||
# Return the full command for building TRT-LLM via subprocess call.
|
||||
cmd = [
|
||||
"trtllm-build",
|
||||
"--output_dir",
|
||||
output_dir,
|
||||
"--model_config",
|
||||
Path(workspace, "generated_config.json"),
|
||||
"--workers",
|
||||
benchmark_cfg.world_size,
|
||||
"--max_input_len",
|
||||
max_isl,
|
||||
"--max_output_len",
|
||||
max_osl,
|
||||
"--context_fmha",
|
||||
"enable",
|
||||
# Set the attention plugin data type.
|
||||
"--gpt_attention_plugin",
|
||||
dtype,
|
||||
# Enable paged KV Cache for IFB.
|
||||
"--paged_kv_cache",
|
||||
"enable",
|
||||
] + kv_dtype.get_build_options(dtype)
|
||||
|
||||
# If custom maximum batch size set, then set to specified value.
|
||||
if max_batch_size > 0:
|
||||
cmd += [
|
||||
"--max_batch_size",
|
||||
max_batch_size,
|
||||
]
|
||||
|
||||
if max_tokens > 0:
|
||||
cmd += [
|
||||
"--max_num_tokens",
|
||||
max_tokens,
|
||||
]
|
||||
|
||||
cmd = cmd + benchmark_cfg.build_overrides
|
||||
|
||||
return cmd
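
Note: after the caller's str() conversion, the list assembled above serializes into a trtllm-build invocation roughly like the sketch below; every path and size is a placeholder, and --gemm_plugin only appears for non-FP8 KV caches (via get_build_options):

# Placeholder values; the flag set mirrors get_trtllm_build_command() above.
example_cmd = [
    "trtllm-build",
    "--output_dir", "/tmp/workspace/engine",
    "--model_config", "/tmp/workspace/generated_config.json",
    "--workers", "1",
    "--max_input_len", "2048",
    "--max_output_len", "512",
    "--context_fmha", "enable",
    "--gpt_attention_plugin", "float16",
    "--paged_kv_cache", "enable",
    "--gemm_plugin", "float16",
    "--max_batch_size", "64",
    "--max_num_tokens", "8192",
]
print(" ".join(example_cmd))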
|
||||
|
||||
|
||||
@click.command("inflight")
|
||||
@click.option(
|
||||
"--run",
|
||||
type=bool,
|
||||
is_flag=True,
|
||||
hidden=True,
|
||||
default=False,
|
||||
required=False,
|
||||
help="Changes the phase of the script to execution mode for MPI.",
|
||||
)
|
||||
@click.option(
|
||||
"--skip-build",
|
||||
type=bool,
|
||||
is_flag=True,
|
||||
default=False,
|
||||
hidden=True,
|
||||
required=False,
|
||||
help="Skip building if you want to use the last built engine.",
|
||||
)
|
||||
@click.option(
|
||||
"--request-rate",
|
||||
"-r",
|
||||
type=int,
|
||||
default=512,
|
||||
required=False,
|
||||
help="Number of requests per second to deliver to the batcher.",
|
||||
)
|
||||
@click.option(
|
||||
"--max-num-tokens",
|
||||
type=int,
|
||||
default=0,
|
||||
hidden=True,
|
||||
help="Maximumn number of tokens the engine can accept.",
|
||||
)
|
||||
@click.option(
|
||||
"--scheduling-policy",
|
||||
type=click.Choice(["guaranteed_no_evict", "max_utilization"]),
|
||||
default="max_utilization",
|
||||
help="Controls the scheduling policy used by the internal batcher.",
|
||||
)
|
||||
@click.option(
|
||||
"--dataset",
|
||||
type=click.Path(exists=True,
|
||||
readable=True,
|
||||
path_type=Path,
|
||||
resolve_path=True),
|
||||
default=None,
|
||||
required=False,
|
||||
help="Pass in a dataset file for parsing instead of stdin.",
|
||||
)
|
||||
@click.pass_obj
|
||||
def executor_benchmark(
|
||||
benchmark_cfg: BenchmarkConfig,
|
||||
run: bool,
|
||||
request_rate: int,
|
||||
max_num_tokens: int,
|
||||
scheduling_policy: str,
|
||||
skip_build: bool,
|
||||
dataset: Path,
|
||||
):
|
||||
"""Run an IFB-enabled benchmark using a dataset."""
|
||||
# Initialize the tokenizer and generate the dataset
|
||||
logger.set_level("info")
|
||||
DATASET_PATH = Path(benchmark_cfg.workspace, "tokenized_dataset.txt")
|
||||
TOKENIZER = initialize_tokenizer(benchmark_cfg.model)
|
||||
final_dataset = []
|
||||
benchmark_cfg.max_tokens = max_num_tokens
|
||||
benchmark_cfg.scheduling_policy = scheduling_policy
|
||||
|
||||
if not run:
|
||||
try:
|
||||
stream = sys.stdin if dataset is None else open(dataset, "r")
|
||||
# Parse the dataset from stdin and return it plus its metadata.
|
||||
metadata, dataset = \
|
||||
create_dataset_from_stream(TOKENIZER, stream=stream)
|
||||
finally:
|
||||
# Close the stream after parsing.
|
||||
stream.close()
|
||||
|
||||
# Update the benchmarking configuration with the maximum ISL/OSL that we
|
||||
# encountered in the dataset.
|
||||
benchmark_cfg.engine_isl = metadata.max_isl
|
||||
benchmark_cfg.engine_osl = metadata.max_osl
|
||||
|
||||
# Build engine
|
||||
logger.info("Building engine...")
|
||||
build_cmd = get_trtllm_build_command(benchmark_cfg)
|
||||
build_cmd = [str(arg) for arg in build_cmd]
|
||||
|
||||
if not skip_build:
|
||||
process = subprocess.run(build_cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=benchmark_cfg.workspace)
|
||||
logger.info(f"BUILD CMD: {' '.join(process.args)}")
|
||||
|
||||
# If the build failed, raise an exception.
|
||||
if process.returncode != 0:
|
||||
logger.error(process.stderr.decode())
|
||||
raise RuntimeError(
|
||||
"TensorRT-LLM build process failed. Command used:\n"
|
||||
f"{' '.join(process.args)}\n", )
|
||||
|
||||
with open(DATASET_PATH, "w") as ds_out:
|
||||
while dataset:
|
||||
request = dataset.pop()
|
||||
ds_out.write(f"{request.model_dump_json()}\n")
|
||||
del request
|
||||
|
||||
# Launch via a subprocess with MPI
|
||||
# We have two modes for this script, the initial launch + parsing
|
||||
# and the run mode where we kick off the script in MPI mode to run
|
||||
# the benchmark itself.
|
||||
logger.info("Launching benchmark...")
|
||||
bench_cmd = \
|
||||
["mpirun", "-n", f"{benchmark_cfg.world_size}", "python"] + \
|
||||
sys.argv + ["--run"]
|
||||
process = subprocess.Popen(
|
||||
bench_cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
env=os.environ,
|
||||
)
|
||||
stdout, _ = process.communicate()
|
||||
logger.info("Benchmark complete.")
|
||||
logger.info(stdout.decode("ascii"))
|
||||
else:
|
||||
from mpi4py.MPI import COMM_WORLD
|
||||
|
||||
if COMM_WORLD.Get_rank() == 0:
|
||||
logger.info(f"[RANK {COMM_WORLD.rank}] Loading dataset...")
|
||||
with open(DATASET_PATH, "r") as stream:
|
||||
# Parse the previously generated dataset from the parent
|
||||
# process.
|
||||
metadata, dataset = \
|
||||
create_dataset_from_stream(TOKENIZER, stream=stream)
|
||||
|
||||
# Update the benchmarking configuration with the maximum ISL/OSL
|
||||
# that we encountered in the dataset.
|
||||
benchmark_cfg.engine_isl = metadata.max_isl
|
||||
benchmark_cfg.engine_osl = metadata.max_osl
|
||||
|
||||
# Parse the dataset into the Executor Request type.
|
||||
logger.info("Preparing dataset...")
|
||||
while dataset:
|
||||
entry = dataset.pop()
|
||||
request = PybindExecutorBenchmarker.get_request(
|
||||
entry, TOKENIZER)
|
||||
final_dataset.append(request)
|
||||
del entry
|
||||
logger.info("Dataset prepared.")
|
||||
logger.info(f"DATASET METADATA: {metadata.model_dump()}")
|
||||
|
||||
logger.info(f"[RANK {COMM_WORLD.rank}] Initializing benchmarker...")
|
||||
# Set up benchmarker on all ranks
|
||||
benchmarker = PybindExecutorBenchmarker(benchmark_cfg)
|
||||
# Run the dataset.
|
||||
result = benchmarker.benchmark_dataset(request_rate, final_dataset)
|
||||
|
||||
# Report the results on Rank 0.
|
||||
if COMM_WORLD.rank == 0:
|
||||
logger.info(f"[RANK {COMM_WORLD.rank}] Reporting...\n"
|
||||
f"JSON: {result.model_dump_json()}\n"
|
||||
f"{result.get_summary(benchmarker.config)}")
|
||||
|
||||
logger.info(f"[RANK {COMM_WORLD.rank}] Terminating.")
|
||||
@ -1,9 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
from utils.benchmarkers import gptSessionBenchmarker
|
||||
from benchmarkers.static import gptSessionBenchmarker
|
||||
from utils.dataclasses import BenchmarkConfig, BenchmarkResults
|
||||
|
||||
|
||||
@ -29,16 +28,6 @@ from utils.dataclasses import BenchmarkConfig, BenchmarkResults
|
||||
default=Path(os.path.dirname(os.path.realpath(__file__)), "../../..",
|
||||
"cpp/build/benchmarks/gptSessionBenchmark").absolute(),
|
||||
help="Path to TRT-LLM gptSession benchmark binary.")
|
||||
@click.option("--max-tokens-in-kv-cache",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of tokens to store in KV cache")
|
||||
@click.option(
|
||||
"--kv-cache-mem-percent",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="The percentage of free memory that the KV Cache is allowed to occupy.",
|
||||
)
|
||||
@click.option("--warm-up-runs",
|
||||
type=int,
|
||||
default=2,
|
||||
@ -54,23 +43,20 @@ from utils.dataclasses import BenchmarkConfig, BenchmarkResults
|
||||
@click.pass_obj
|
||||
def static_benchmark(benchmark_cfg: BenchmarkConfig, batch: int, isl: int,
|
||||
osl: int, gpt_session_path: Path, warm_up_runs: int,
|
||||
num_runs: int, duration: int, max_tokens_in_kv_cache: int,
|
||||
kv_cache_mem_percent: float):
|
||||
num_runs: int, duration: int):
|
||||
"""Run a static benchmark with a fixed batch size, ISL, and OSL."""
|
||||
if max_tokens_in_kv_cache is None:
|
||||
max_tokens_in_kv_cache = batch * isl
|
||||
|
||||
benchmark_cfg.max_batch_size = batch
|
||||
benchmarker = gptSessionBenchmarker(
|
||||
benchmark_cfg,
|
||||
gpt_session_path,
|
||||
batch,
|
||||
benchmark_cfg.max_batch_size,
|
||||
isl,
|
||||
osl,
|
||||
warm_up_runs,
|
||||
num_runs,
|
||||
duration,
|
||||
max_tokens_in_kv_cache,
|
||||
kv_cache_mem_percent,
|
||||
benchmark_cfg.kv_cache_mem_percentage,
|
||||
)
|
||||
|
||||
print(f"Building TRT-LLM engine for '{benchmark_cfg.model}'...")
|
||||
@ -79,5 +65,5 @@ def static_benchmark(benchmark_cfg: BenchmarkConfig, batch: int, isl: int,
|
||||
print("Build complete. Running benchmark...")
|
||||
result: BenchmarkResults = benchmarker.benchmark()
|
||||
|
||||
print(f"JSON: {json.dumps(result.model_dump())}")
|
||||
print(f"JSON: {result.model_dump_json()}")
|
||||
print(result.get_summary(benchmarker.config))
|
||||
|
||||
@ -16,6 +16,8 @@ VALID_QUANT_ALGOS = Literal["None", f"{QuantAlgo.W8A16}", f"{QuantAlgo.W4A16}",
|
||||
f"{QuantAlgo.W4A16_AWQ}", f"{QuantAlgo.W4A8_AWQ}",
|
||||
f"{QuantAlgo.W4A16_GPTQ}", f"{QuantAlgo.FP8}",
|
||||
f"{QuantAlgo.INT8}"]
|
||||
VALID_SCHEDULING_POLICIES = \
|
||||
Literal["max_utilization", "guaranteed_no_evict", "static"]
|
||||
|
||||
|
||||
class _MethodFunctionAdapter:
|
||||
@ -130,6 +132,7 @@ def run_process(cmd: List[Any],
|
||||
Returns:
|
||||
subprocess.CompletedProcess: _description_
|
||||
"""
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
result = subprocess.run(
|
||||
[str(x) for x in cmd],
|
||||
cwd=run_dir,
|
||||
|
||||
@ -1,10 +1,6 @@
|
||||
from pathlib import Path
|
||||
from subprocess import CompletedProcess
|
||||
from typing import Dict, List, Protocol
|
||||
from typing import Protocol
|
||||
|
||||
from utils import command_logger, process_error_check, run_process
|
||||
from utils.dataclasses import BenchmarkConfig, BenchmarkResults
|
||||
from utils.trtllm_config import TRTLLMConfig
|
||||
from utils.dataclasses import BenchmarkResults
|
||||
|
||||
|
||||
class Benchmarker(Protocol):
|
||||
@ -17,212 +13,3 @@ class Benchmarker(Protocol):
|
||||
def benchmark(self) -> BenchmarkResults:
|
||||
"""Benchmark the constructed model container by a benchmarker."""
|
||||
...
|
||||
|
||||
|
||||
class gptSessionBenchmarker:
|
||||
"""Utility class for running static benchmarks with gptSessionBenchmark."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BenchmarkConfig,
|
||||
benchmark_binary: Path,
|
||||
batch_size: int,
|
||||
isl: int,
|
||||
osl: int,
|
||||
warm_up_runs: int,
|
||||
num_runs: int,
|
||||
duration: int,
|
||||
max_tokens_in_kv_cache: int,
|
||||
kv_cache_free_fraction: float = .9,
|
||||
):
|
||||
"""Initialize a gptSessionBenchmark instance.
|
||||
|
||||
Args:
|
||||
config (BenchmarkConfig): Benchmark configuration for build/run.
|
||||
benchmark_binary (Path): Path to the benchmarking binary.
|
||||
batch_size (int): Batch size to configure the build with.
|
||||
isl (int): Input sequence length to configure the build with.
|
||||
osl (int): Output sequence length to configure the build with.
|
||||
max_tokens_in_kv_cache (int): The maximum number of tokens to store
|
||||
in the KV cache
|
||||
kv_cache_free_fraction (float, optional): The amount of remaining
|
||||
GPU memory after model loading to save for the KV Cache. Defaults
|
||||
to .9.
|
||||
"""
|
||||
self.config: BenchmarkConfig = config
|
||||
self.gpt_session_path = Path(benchmark_binary).absolute()
|
||||
self.batch_size = batch_size
|
||||
self.input_length = isl
|
||||
self.output_length = osl
|
||||
self.warm_up = warm_up_runs
|
||||
self.num_runs = num_runs
|
||||
self.duration = duration
|
||||
self.kv_cache_mem = kv_cache_free_fraction
|
||||
self.max_tokens_in_kv_cache = max_tokens_in_kv_cache
|
||||
self.result = None
|
||||
|
||||
def get_build_command(self) -> List[str]:
|
||||
"""Build the engine command for TRT-LLM.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of command line arguments to run a build command.
|
||||
"""
|
||||
model = self.config.model
|
||||
tp = self.config.tensor_parallel
|
||||
pp = self.config.pipeline_parallel
|
||||
dtype = self.config.dtype
|
||||
kv_dtype = self.config.cache_dtype
|
||||
quant_algo = self.config.quantization.value
|
||||
output_dir = self.config.engine_path
|
||||
max_batch_size = self.batch_size
|
||||
max_isl = self.input_length
|
||||
max_osl = self.output_length
|
||||
workspace = self.config.workspace
|
||||
|
||||
# Generate the TRT-LLM Configuration file using the dataclass
|
||||
# NOTE: This method does not use weights.
|
||||
trtllm_config = TRTLLMConfig.from_hf(model, tp, pp, dtype, quant_algo,
|
||||
kv_dtype)
|
||||
# Write the generated configuration file to the benchmark workspace.
|
||||
trtllm_config.to_json(workspace)
|
||||
|
||||
# Return the full command for building TRT-LLM via subprocess call.
|
||||
cmd = [
|
||||
"trtllm-build",
|
||||
"--output_dir",
|
||||
output_dir,
|
||||
"--model_config",
|
||||
Path(workspace, "generated_config.json"),
|
||||
"--workers",
|
||||
self.config.world_size,
|
||||
# Define the maximums the engine can accept.
|
||||
"--max_batch_size",
|
||||
max_batch_size,
|
||||
"--max_input_len",
|
||||
max_isl,
|
||||
"--max_output_len",
|
||||
max_osl,
|
||||
"--context_fmha",
|
||||
"enable",
|
||||
# Set the attention plugin data type.
|
||||
"--gpt_attention_plugin",
|
||||
dtype.value,
|
||||
# Disable paged cache since we aren't batching on the fly.
|
||||
"--paged_kv_cache",
|
||||
"disable",
|
||||
] + kv_dtype.get_build_options(dtype)
|
||||
|
||||
return [str(arg) for arg in cmd]
|
||||
|
||||
@command_logger(prefix="BUILD COMMAND: ")
|
||||
@process_error_check
|
||||
def _run_build(self, cmd: List[str]) -> CompletedProcess:
|
||||
"""Wrapper for calling the build for TRT-LLM.
|
||||
|
||||
Purpose of this wrapper is so that we can decorate it/log it.
|
||||
|
||||
Args:
|
||||
cmd (List[str]): List of command line arguments for running.
|
||||
|
||||
Returns:
|
||||
CompletedProcess: Completed process information for parsing and
|
||||
reporting.
|
||||
"""
|
||||
return run_process(
|
||||
cmd,
|
||||
self.config.workspace,
|
||||
)
|
||||
|
||||
def build(self) -> None:
|
||||
"""Build the engine for benchmarking."""
|
||||
self._run_build(self.get_build_command())
|
||||
|
||||
@command_logger(prefix="BENCHMARK COMMAND: ")
|
||||
@process_error_check
|
||||
def _run_benchmark(self, cmd: List[str]) -> CompletedProcess:
|
||||
"""Run the benchmark command in the configured workspace.
|
||||
|
||||
Args:
|
||||
cmd (List[str]): List of command line arguments to run via
|
||||
subprocess.
|
||||
|
||||
Returns:
|
||||
CompletedProcess: Completed process information for reporting.
|
||||
"""
|
||||
return run_process(cmd, run_dir=self.config.workspace, use_environ=True)
|
||||
|
||||
@staticmethod
|
||||
def parse_benchmark_result(benchmark_line: str) -> Dict[str, str]:
|
||||
pass
|
||||
|
||||
def benchmark(self):
|
||||
"""Benchmarks a TRT-LLM for a configured instance."""
|
||||
|
||||
# Compile the command for running
|
||||
cmd = [
|
||||
"mpirun",
|
||||
"-allow-run-as-root",
|
||||
"-n",
|
||||
self.config.world_size,
|
||||
self.gpt_session_path,
|
||||
"--engine_dir",
|
||||
self.config.engine_path,
|
||||
"--batch_size",
|
||||
self.batch_size,
|
||||
"--log_level",
|
||||
"info",
|
||||
"--max_tokens_in_paged_kvcache",
|
||||
self.max_tokens_in_kv_cache,
|
||||
"--kv_cache_free_gpu_mem_fraction",
|
||||
self.kv_cache_mem,
|
||||
"--beam_width",
|
||||
"1",
|
||||
"--warm_up",
|
||||
self.warm_up,
|
||||
"--num_runs",
|
||||
self.num_runs,
|
||||
"--duration",
|
||||
self.duration,
|
||||
"--input_output_len",
|
||||
f"{self.input_length},{self.output_length};{self.input_length},1",
|
||||
]
|
||||
cmd = [str(arg) for arg in cmd]
|
||||
# Run the benchmark using the provided gptSession benchmark binary.
|
||||
bench_return = self._run_benchmark(cmd)
|
||||
results = [
|
||||
x.split(" ") for x in bench_return.stdout.split("\n")
|
||||
if "[BENCHMARK]" in x
|
||||
]
|
||||
|
||||
ttft = float(results[1][8])
|
||||
gen_time = float(results[0][8]) - ttft
|
||||
total_out = int(results[0][2]) * int(results[0][6])
|
||||
total_in = int(results[0][2]) * int(results[0][4])
|
||||
batch_size = int(results[0][2])
|
||||
|
||||
bench_result = BenchmarkResults(
|
||||
model=self.config.model,
|
||||
dtype=self.config.dtype.value,
|
||||
quantization=str(self.config.quantization.value),
|
||||
max_batch_size=batch_size,
|
||||
total_input_tokens=total_in,
|
||||
total_output_tokens=total_out,
|
||||
tp_size=self.config.tensor_parallel,
|
||||
pp_size=self.config.pipeline_parallel,
|
||||
kv_mem_fraction=self.kv_cache_mem,
|
||||
scheduler="Static",
|
||||
max_tokens_in_cache=self.max_tokens_in_kv_cache,
|
||||
inflight_batching=False,
|
||||
total_latency=results[0][8],
|
||||
first_token_latency=ttft,
|
||||
time_per_output_token=gen_time / (total_out - batch_size),
|
||||
latency_units="ms",
|
||||
throughput=results[0][10],
|
||||
throughput_units="tokens/second",
|
||||
peak_gpu_mem=results[0][16],
|
||||
peak_gpu_mem_units="GB",
|
||||
binary=str(self.gpt_session_path),
|
||||
build_cmd=" ".join(self.get_build_command()),
|
||||
benchmark_cmd=" ".join(cmd))
|
||||
|
||||
return bench_result
|
||||
|
||||
@ -1,29 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from typing import List, Literal, Optional, Union, get_args
|
||||
|
||||
from pydantic import BaseModel, computed_field
|
||||
from utils import VALID_MODELS
|
||||
from utils.enums import ComputeDtypeEnum, KVCacheDtypeEnum, QuantizationAlgo
|
||||
from pydantic import (BaseModel, Field, ValidationError, computed_field,
|
||||
field_validator, model_validator)
|
||||
from transformers import AutoConfig
|
||||
from utils import VALID_MODELS, VALID_SCHEDULING_POLICIES
|
||||
from utils.enums import (ComputeDtypeEnum, KVCacheDtypeEnum, ModelArchitecture,
|
||||
QuantizationAlgo)
|
||||
|
||||
|
||||
class InferenceRequest(BaseModel):
|
||||
task_id: int
|
||||
prompt: Optional[str] = None
|
||||
output_tokens: int
|
||||
logits: Optional[List[int]] = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def verify_prompt_and_logits(self) -> InferenceRequest:
|
||||
if self.prompt is None and self.logits is None:
|
||||
raise ValueError(
|
||||
f"Both prompt and logits for {self.task_id} are both None.")
|
||||
return self
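
Note: a short usage sketch of the validator above; it assumes the suite's package layout (utils/dataclasses.py) is importable, and the field values are placeholders:

from pydantic import ValidationError
from utils.dataclasses import InferenceRequest

ok = InferenceRequest(task_id=0, prompt="There once was a man who", output_tokens=16)
print(ok.model_dump_json())

try:
    # Neither prompt nor logits: verify_prompt_and_logits rejects this.
    InferenceRequest(task_id=1, output_tokens=16)
except ValidationError as err:
    print("rejected:", err)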
|
||||
|
||||
|
||||
class DatasetMetadata(BaseModel):
|
||||
max_isl: int
|
||||
max_osl: int
|
||||
num_requests: int
|
||||
|
||||
|
||||
class BenchmarkResults(BaseModel):
|
||||
"""High level report out for a benchmark."""
|
||||
|
||||
benchmark_cmd: str = ""
|
||||
binary: str
|
||||
binary: str = ""
|
||||
build_cmd: str = ""
|
||||
first_token_latency: float
|
||||
inflight_batching: bool
|
||||
kv_mem_fraction: float
|
||||
latency_units: str
|
||||
max_batch_size: int
|
||||
max_tokens_in_cache: int
|
||||
model: VALID_MODELS
|
||||
max_tokens: int = 0
|
||||
model: Union[VALID_MODELS, Path]
|
||||
peak_gpu_mem_units: str
|
||||
peak_gpu_mem: float
|
||||
scheduler: Literal["Static", "No evict", "Max Utilization"]
|
||||
scheduler: Literal["Static", "No Evict", "Max Utilization"]
|
||||
throughput_units: str
|
||||
throughput: float
|
||||
time_per_output_token: float
|
||||
@ -52,7 +75,6 @@ class BenchmarkResults(BaseModel):
|
||||
f"In-flight Batcher?:\t{self.inflight_batching}\n"
|
||||
f"Dtype:\t\t\t{config.dtype.value}\n"
|
||||
f"KV Cache Dtype:\t\t{config.cache_dtype.value}\n"
|
||||
f"KV Cache Size (tokens):\t{self.max_tokens_in_cache}\n"
|
||||
f"Quantization:\t\t{config.quantization.value}\n"
|
||||
f"KV Memory Percentage:\t{self.kv_mem_fraction * 100}%\n"
|
||||
f"\n"
|
||||
@ -62,13 +84,15 @@ class BenchmarkResults(BaseModel):
|
||||
f"Engine Directory:\t{config.engine_path}\n"
|
||||
f"Max Batch Size:\t\t{self.max_batch_size}\n"
|
||||
f"Total Input Length:\t{self.total_input_tokens}\n"
|
||||
f"Total Output Length:\t{self.total_input_tokens}\n"
|
||||
f"Total Output Length:\t{self.total_output_tokens}\n"
|
||||
f"Max Tokens:\t\t{self.max_tokens}\n"
|
||||
f"\n"
|
||||
"===========================================================\n"
|
||||
"= STATISTICS\n"
|
||||
"===========================================================\n"
|
||||
f"Throughput ({self.throughput_units}):\t{self.throughput}\n"
|
||||
f"Total Latency ({self.latency_units}):\t\t{self.total_latency}\n"
|
||||
f"Total Latency ({self.latency_units}):"
|
||||
f"\t\t{self.total_latency * 1000.0:.4f}\n"
|
||||
f"First Token Latency ({self.latency_units}):\t{self.first_token_latency}\n"
|
||||
f"Token-to-token Latency ({self.latency_units}):\t{self.time_per_output_token}\n"
|
||||
f"Peak GPU Memory Usage ({self.peak_gpu_mem_units}):\t{self.peak_gpu_mem}\n"
|
||||
@ -83,18 +107,81 @@ class BenchmarkResults(BaseModel):
|
||||
class BenchmarkConfig(BaseModel):
|
||||
"""Basic configuration of a benchmark."""
|
||||
|
||||
model: VALID_MODELS
|
||||
model: Union[VALID_MODELS, Path]
|
||||
workspace: Path
|
||||
max_batch_size: int
|
||||
dtype: ComputeDtypeEnum
|
||||
cache_dtype: KVCacheDtypeEnum
|
||||
quantization: QuantizationAlgo
|
||||
tensor_parallel: int
|
||||
pipeline_parallel: int
|
||||
max_tokens: int = 0
|
||||
kv_cache_mem_percentage: float = .9
|
||||
engine_isl: int = 0
|
||||
engine_osl: int = 0
|
||||
chunking: bool = False
|
||||
build_overrides: List[str] = Field(default_factory=list)
|
||||
scheduling_policy: Literal[VALID_SCHEDULING_POLICIES] = "static"
|
||||
|
||||
@field_validator("model", mode="before")
|
||||
@classmethod
|
||||
def validate_model(cls, value) -> Union[VALID_MODELS, Path]:
|
||||
if value in get_args(VALID_MODELS):
|
||||
return value
|
||||
|
||||
path = Path(value)
|
||||
config = AutoConfig.from_pretrained(str(path.absolute()))
|
||||
for arch in config.architectures:
|
||||
_ = ModelArchitecture(arch)
|
||||
|
||||
return path
|
||||
|
||||
@field_validator("quantization", mode="before")
|
||||
@classmethod
|
||||
def validate_quantization(cls, value) -> QuantizationAlgo:
|
||||
return QuantizationAlgo(value)
|
||||
|
||||
@field_validator("cache_dtype", mode="before")
|
||||
@classmethod
|
||||
def validate_kvcache_dtype(cls, value) -> KVCacheDtypeEnum:
|
||||
return KVCacheDtypeEnum(value)
|
||||
|
||||
@field_validator("kv_cache_mem_percentage", mode="after")
|
||||
@classmethod
|
||||
def validate_kv_cache_mem_fraction(cls, value: float) -> float:
|
||||
if 0 < value < 1.0:
|
||||
return value
|
||||
else:
|
||||
raise ValidationError(
|
||||
"KV cache memory percentage must be between 0 and 1.0.")
|
||||
|
||||
@field_validator("build_overrides", mode="before")
|
||||
@classmethod
|
||||
def validate_build_overrides(cls, value) -> List[str]:
|
||||
# If we encounter a list, scan it to make sure all entries are strings.
|
||||
if isinstance(value, list):
|
||||
if not all([isinstance(x, str) for x in value]):
|
||||
raise ValidationError(
|
||||
"Found a non-string entry in list of options.")
|
||||
return value
|
||||
elif isinstance(value, str):
|
||||
# Handle the case where we receive a single string of command
|
||||
# options.
|
||||
overrides = []
|
||||
if value:
|
||||
overrides = [str(x) for x in value.split()]
|
||||
return overrides
|
||||
else:
|
||||
raise ValidationError(
|
||||
"Invalid value specified for build overrides.")
|
||||
|
||||
@computed_field
|
||||
def engine_path(self) -> Path:
|
||||
"""Path to the engine workspace."""
|
||||
return Path(self.workspace.absolute(), self.model.lower())
|
||||
if self.model in get_args(VALID_MODELS):
|
||||
return Path(self.workspace.absolute(), self.model.lower())
|
||||
else:
|
||||
return Path(self.workspace.absolute(), "engine")
|
||||
|
||||
@computed_field
|
||||
def world_size(self) -> int:
|
||||
|
||||
@ -4,8 +4,34 @@ from typing import List
|
||||
|
||||
from aenum import MultiValueEnum
|
||||
|
||||
from tensorrt_llm.bindings.executor import CapacitySchedulerPolicy
|
||||
from tensorrt_llm.quantization.mode import QuantAlgo
|
||||
|
||||
NO_EVICT = "Guaranteed No Evict"
|
||||
MAX_UTIL = "Max Utilization"
|
||||
|
||||
|
||||
class ModelArchitecture(MultiValueEnum):
|
||||
LLAMA = "LlamaForCausalLM"
|
||||
GPTJ = "GPTJForCausalLM"
|
||||
GEMMA = "GemmaForCausalLM"
|
||||
BLOOM = "BloomForCausalLM"
|
||||
OPT = "OPTForCausalLM"
|
||||
MIXTRAL = "MixtralForCausalLM"
|
||||
FALCON = "FalconForCausalLM"
|
||||
|
||||
|
||||
class ResultsSchedulingPolicy(MultiValueEnum):
|
||||
MAX_UTILIZATION = MAX_UTIL, CapacitySchedulerPolicy.MAX_UTILIZATION
|
||||
NO_EVICT = NO_EVICT, CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
|
||||
STATIC = "Static"
|
||||
|
||||
|
||||
class IFBSchedulingPolicy(MultiValueEnum):
|
||||
MAX_UTILIZATION = CapacitySchedulerPolicy.MAX_UTILIZATION, MAX_UTIL, "max_utilization"
|
||||
NO_EVICT = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, NO_EVICT, "guaranteed_no_evict"
|
||||
STATIC = "Static", "static"
|
||||
|
||||
|
||||
class KVCacheDtypeEnum(MultiValueEnum):
|
||||
"""Enumeration of KV Cache precisions in TRT-LLM."""
|
||||
@ -13,11 +39,11 @@ class KVCacheDtypeEnum(MultiValueEnum):
|
||||
FP16 = None, "FP16", "fp16", "float16"
|
||||
INT8 = "INT8", "int8"
|
||||
|
||||
def get_build_options(self, dtype: ComputeDtypeEnum) -> List[str]:
|
||||
def get_build_options(self, dtype: str) -> List[str]:
|
||||
"""Get the build options for TRT-LLM based on KV Cache precision.
|
||||
|
||||
Args:
|
||||
dtype (ComputeDtypeEnum): The activation dtype for the model. This
|
||||
dtype (str): The activation dtype for the model. This
|
||||
parameter maps the activation dtype for GEMM plugins for certain
|
||||
KV cache precisions.
|
||||
|
||||
@ -28,7 +54,7 @@ class KVCacheDtypeEnum(MultiValueEnum):
|
||||
if self.value == self.FP8:
|
||||
return ["--strongly_typed"]
|
||||
else:
|
||||
return ["--gemm_plugin", dtype.value]
|
||||
return ["--gemm_plugin", dtype]
|
||||
|
||||
|
||||
class ComputeDtypeEnum(MultiValueEnum):
|
||||
|
||||
@ -3,10 +3,10 @@ import os
|
||||
from argparse import ArgumentParser
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import AliasChoices, AliasPath, BaseModel, Field, model_validator
|
||||
from pydantic import (AliasChoices, AliasPath, BaseModel, Field, computed_field,
|
||||
model_validator)
|
||||
from transformers import AutoConfig
|
||||
from utils import VALID_QUANT_ALGOS
|
||||
from utils.enums import ComputeDtypeEnum, KVCacheDtypeEnum
|
||||
|
||||
PET_dict = {
|
||||
"tiiuae/falcon-7b": "rope_gpt_neox",
|
||||
@ -15,18 +15,37 @@ PET_dict = {
|
||||
"meta-llama/Llama-2-7b-hf": "rope_gpt_neox",
|
||||
"meta-llama/Llama-2-13b-hf": "rope_gpt_neox",
|
||||
"meta-llama/Llama-2-70b-hf": "rope_gpt_neox",
|
||||
"meta-llama/Meta-Llama-3-8B": "rope_gpt_neox",
|
||||
"meta-llama/Meta-Llama-3-70B": "rope_gpt_neox",
|
||||
"EleutherAI/gpt-j-6b": "rope_gptj",
|
||||
"bigscience/bloom-560m": "alibi",
|
||||
"mistralai/Mistral-7B-v0.1": "rope_gpt_neox",
|
||||
"mistralai/Mixtral-8x7B-v0.1": "rope_gpt_neox",
|
||||
"mistralai/Mixtral-8x22B-v0.1": "rope_gpt_neox",
|
||||
"01-ai/Yi-6B": "rope_gpt_neox",
|
||||
"01-ai/Yi-34B": "rope_gpt_neox",
|
||||
"codellama/CodeLlama-7b-hf": "rope_gpt_neox",
|
||||
"codellama/CodeLlama-13b-hf": "rope_gpt_neox",
|
||||
"codellama/CodeLlama-34b-hf": "rope_gpt_neox",
|
||||
"codellama/CodeLlama-70b-hf": "rope_gpt_neox",
|
||||
"facebook/opt-125m": "learned_absolute",
|
||||
"facebook/opt-350m": "learned_absolute",
|
||||
"facebook/opt-1.3b": "learned_absolute",
|
||||
"facebook/opt-2.7b": "learned_absolute",
|
||||
"facebook/opt-13b": "learned_absolute",
|
||||
"facebook/opt-30b": "learned_absolute",
|
||||
"facebook/opt-66b": "learned_absolute",
|
||||
"google/gemma-7b": "rope_gpt_neox",
|
||||
"google/gemma-2b": "rope_gpt_neox",
|
||||
}
|
||||
HA_dict = {
|
||||
"tiiuae/falcon-7b": "gelu",
|
||||
"tiiuae/falcon-40b": "gelu",
|
||||
"tiiuae/falcon-180B": "gelu",
|
||||
"bigscience/bloom-560m": "gelu",
|
||||
"mistralai/Mixtral-8x7B-v0.1": "swiglu",
|
||||
}
|
||||
ALLOWED_MODELS = list(PET_dict.keys())
|
||||
|
||||
|
||||
class TRTLLM_Mapping(BaseModel):
|
||||
@ -42,23 +61,20 @@ class TRTLLM_Mapping(BaseModel):
|
||||
|
||||
class TRTLLM_Quantization(BaseModel):
|
||||
quant_algo: Optional[VALID_QUANT_ALGOS] = None
|
||||
|
||||
kv_cache_quant_algo: Optional[Literal[None, "FP8", "INT8"]] = None
|
||||
|
||||
group_size: int = 128
|
||||
has_zero_point: bool = False
|
||||
pre_quant_scale: bool = False
|
||||
exclude_modules: Optional[list] = None
|
||||
|
||||
|
||||
class TRTLLM_CheckpointConfig(BaseModel):
|
||||
"""Dataclass for building TRT-LLM model configurations."""
|
||||
|
||||
class TRTLLMConfig(BaseModel):
|
||||
_VALID_EMBED_TYPE = Literal["learned_absolute", "rope_gptj",
|
||||
"rope_gpt_neox", "alibi", "alibi_with_scale",
|
||||
"relative", "chatglm", ]
|
||||
|
||||
architecture: str = Field(validation_alias=AliasPath("architectures", 0))
|
||||
architecture: str = Field(validation_alias=AliasChoices(
|
||||
'architecture', AliasPath("architectures", 0)))
|
||||
num_hidden_layers: int = Field(validation_alias=AliasChoices(
|
||||
"num_hidden_layers", "n_layer", "n_layers"))
|
||||
num_attention_heads: int = Field(validation_alias=AliasChoices(
|
||||
@ -72,13 +88,15 @@ class TRTLLM_CheckpointConfig(BaseModel):
|
||||
validation_alias=AliasChoices("hidden_size", "n_embd", "d_model"))
|
||||
norm_epsilon: float = Field(
|
||||
default=1e-5,
|
||||
validation_alias=AliasChoices("norm_epsilon", "layer_norm_epsilon"),
|
||||
validation_alias=AliasChoices("norm_epsilon", "layer_norm_epsilon",
|
||||
"rms_norm_eps"),
|
||||
)
|
||||
vocab_size: int
|
||||
max_position_embeddings: Optional[int] = Field(
|
||||
default=None,
|
||||
validation_alias=AliasChoices("max_position_embeddings", "n_positions"),
|
||||
)
|
||||
head_size: Optional[int] = None
|
||||
hidden_act: str = Field(
|
||||
validation_alias=AliasChoices("hidden_act", "activation_function"))
|
||||
# falcon options
|
||||
@ -102,33 +120,46 @@ class TRTLLM_CheckpointConfig(BaseModel):
|
||||
intermediate_size: int = None
|
||||
use_prompt_tuning: bool = False
|
||||
|
||||
sliding_window: Optional[int] = None
|
||||
|
||||
moe_num_experts: Optional[int] = Field(
|
||||
default=0, validation_alias=AliasChoices("num_local_experts"))
|
||||
moe_top_k: Optional[int] = Field(
|
||||
default=0, validation_alias=AliasChoices("num_experts_per_tok"))
|
||||
rotary_base: Optional[float] = Field(
|
||||
default=None, validation_alias=AliasChoices("rope_theta"))
|
||||
|
||||
mapping: TRTLLM_Mapping
|
||||
quantization: TRTLLM_Quantization
|
||||
|
||||
@computed_field
|
||||
def kv_dtype(self) -> str:
|
||||
if self.quantization.kv_cache_quant_algo == "FP8":
|
||||
return "fp8"
|
||||
elif self.quantization.kv_cache_quant_algo == "INT8":
|
||||
return "int8"
|
||||
else:
|
||||
return self.dtype
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_kv_head_default_value(self) -> "TRTLLM_CheckpointConfig":
|
||||
def set_values_if_none(self) -> "TRTLLM_CheckpointConfig":
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
if self.head_size is None:
|
||||
self.head_size = self.hidden_size // self.num_attention_heads
|
||||
return self
|
||||
|
||||
|
||||
class TRTLLMConfig:
|
||||
|
||||
def __init__(self, trtllm_config, hf_config=None) -> None:
|
||||
self.trtllm_config = trtllm_config
|
||||
self.hf_config = hf_config
|
||||
# self.nemo_config = nemo_config
|
||||
|
||||
@classmethod
|
||||
def from_hf(
|
||||
cls,
|
||||
hf_model_name,
|
||||
tp,
|
||||
pp,
|
||||
dtype=None,
|
||||
quant_dtype=None,
|
||||
kv_cache_quant_dtype=None,
|
||||
):
|
||||
def populate_build_config(cls,
|
||||
model_name,
|
||||
tp,
|
||||
pp,
|
||||
dtype=None,
|
||||
quant_dtype=None,
|
||||
kv_cache_quant_dtype=None):
|
||||
"""
|
||||
Common function to populate build parameters, regardless of network
|
||||
"""
|
||||
build_config = {
|
||||
"mapping": {
|
||||
"tp_size": tp,
|
||||
@ -137,27 +168,96 @@ class TRTLLMConfig:
|
||||
"quantization": {},
|
||||
}
|
||||
if dtype:
|
||||
build_config["dtype"] = ComputeDtypeEnum(dtype).value
|
||||
build_config["dtype"] = dtype
|
||||
if quant_dtype:
|
||||
if not kv_cache_quant_dtype:
|
||||
# will throw errors during validation if the type is invalid
|
||||
kv_cache_quant_dtype = KVCacheDtypeEnum(quant_dtype).value
|
||||
kv_cache_quant_dtype = quant_dtype
|
||||
build_config["quantization"] = {
|
||||
"quant_algo": quant_dtype,
|
||||
"kv_cache_quant_algo":
|
||||
KVCacheDtypeEnum(kv_cache_quant_dtype).value,
|
||||
"kv_cache_quant_algo": kv_cache_quant_dtype,
|
||||
}
|
||||
build_config["position_embedding_type"] = PET_dict[hf_model_name]
|
||||
if hf_model_name in HA_dict:
|
||||
build_config["hidden_act"] = HA_dict[hf_model_name]
|
||||
if model_name in PET_dict:
|
||||
build_config["position_embedding_type"] = PET_dict.get(model_name)
|
||||
return build_config
|
||||
|
||||
@classmethod
|
||||
def from_hf(cls,
|
||||
hf_model_name,
|
||||
tp,
|
||||
pp,
|
||||
dtype=None,
|
||||
quant_dtype=None,
|
||||
kv_cache_quant_dtype=None):
|
||||
"""
|
||||
Use transformers.AutoConfig to load a model's config from a HF name
|
||||
"""
|
||||
build_config = cls.populate_build_config(hf_model_name, tp, pp, dtype,
|
||||
quant_dtype,
|
||||
kv_cache_quant_dtype)
|
||||
hf_config = AutoConfig.from_pretrained(hf_model_name).to_dict()
|
||||
trtllm_config = TRTLLM_CheckpointConfig(**hf_config,
|
||||
**build_config).model_dump()
|
||||
return cls(trtllm_config, hf_config)
|
||||
if hf_model_name in HA_dict:
|
||||
hf_config["hidden_act"] = HA_dict[hf_model_name]
|
||||
return cls(**hf_config, **build_config)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls,
|
||||
model_name,
|
||||
tp,
|
||||
pp,
|
||||
dtype=None,
|
||||
quant_dtype=None,
|
||||
kv_cache_quant_dtype=None):
|
||||
"""
|
||||
Load model parameters from a custom json file
|
||||
A full path can be specified. Otherwise, look for ./trtllm_configs/(model_name).json
|
||||
"""
|
||||
build_config = cls.populate_build_config(model_name, tp, pp, dtype,
|
||||
quant_dtype,
|
||||
kv_cache_quant_dtype)
|
||||
if os.path.exists(model_name):
|
||||
path_to_json = model_name
|
||||
else:
|
||||
path_to_json = os.path.join(os.path.dirname(__file__),
|
||||
f"trtllm_configs/{model_name}.json")
|
||||
if not os.path.exists(path_to_json):
|
||||
raise FileNotFoundError(f"{path_to_json} not found")
|
||||
json_config = json.load(open(path_to_json))
|
||||
return cls(**json_config, **build_config)
|
||||
|
||||
@classmethod
|
||||
def from_name(cls,
|
||||
model,
|
||||
tp,
|
||||
pp,
|
||||
dtype=None,
|
||||
quant_dtype=None,
|
||||
kv_cache_quant_dtype=None):
|
||||
"""
|
||||
Attempts to create a config based on model name. Performs the following steps:
|
||||
1. Tries to load the HF config using AutoConfig. This will only work if the network name exists on HF.
|
||||
2. If this fails, try to load a custom config stored on $HF_HOME/custom/*.json
|
||||
"""
|
||||
try:
|
||||
trtllm_config = cls.from_hf(model, tp, pp, dtype, quant_dtype,
|
||||
kv_cache_quant_dtype)
|
||||
except EnvironmentError:
|
||||
try:
|
||||
trtllm_config = cls.from_json(model, tp, pp, dtype, quant_dtype,
|
||||
kv_cache_quant_dtype)
|
||||
except FileNotFoundError as e:
|
||||
raise NameError(
|
||||
f"Unable to create PretrainedConfig from {model} due to {e}"
|
||||
)
|
||||
|
||||
return trtllm_config
|
||||
|
||||
# future possibilities
|
||||
# def from_nemo_config (self, nemo_model_name)
|
||||
|
||||
def to_json(self, output_dir):
|
||||
with open(os.path.join(output_dir, "generated_config.json"), "w") as f:
|
||||
json.dump(self.trtllm_config, f, indent=4)
|
||||
json.dump(self.model_dump(), f, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -205,14 +305,19 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
help="TRT-LLM argument",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--populate_hf_cache",
|
||||
action='store_true',
|
||||
help="Populate the HF cache with all the supported networks",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
trtllm_config = TRTLLMConfig.from_hf(
|
||||
args.model,
|
||||
args.tp_size,
|
||||
args.pp_size,
|
||||
args.dtype,
|
||||
args.quant_dtype,
|
||||
args.kv_cache_quant_dtype,
|
||||
)
|
||||
trtllm_config.to_json(os.getcwd())
|
||||
if args.populate_hf_cache:
|
||||
for net in PET_dict.keys():
|
||||
_ = AutoConfig.from_pretrained(net)
|
||||
else:
|
||||
trtllm_config = TRTLLMConfig.from_name(args.model, args.tp_size,
|
||||
args.pp_size, args.dtype,
|
||||
args.quant_dtype,
|
||||
args.kv_cache_quant_dtype)
|
||||
trtllm_config.to_json(os.getcwd())
|
||||
|
||||
@ -32,6 +32,7 @@ option(BUILD_PYBIND "Build Python bindings for C++ runtime and batch manager"
|
||||
ON)
|
||||
option(BUILD_TESTS "Build Google tests" ON)
|
||||
option(BUILD_BENCHMARKS "Build benchmarks" ON)
|
||||
option(BUILD_MICRO_BENCHMARKS "Build C++ micro benchmarks" OFF)
|
||||
option(NVTX_DISABLE "Disable all NVTX features" ON)
|
||||
option(WARNING_IS_ERROR "Treat all warnings as errors" OFF)
|
||||
option(FAST_BUILD "Skip compiling some kernels to accelerate compiling" OFF)
|
||||
@ -111,6 +112,12 @@ else()
|
||||
message(STATUS "Not building benchmarks")
|
||||
endif()
|
||||
|
||||
if(BUILD_MICRO_BENCHMARKS)
|
||||
message(STATUS "Building C++ micro benchmarks")
|
||||
else()
|
||||
message(STATUS "Not building C++ micro benchmarks")
|
||||
endif()
|
||||
|
||||
if(FAST_BUILD)
|
||||
add_compile_definitions("FAST_BUILD")
|
||||
message(WARNING "Skip some kernels to accelerate compilation")
|
||||
@ -506,6 +513,11 @@ if(BUILD_BENCHMARKS)
|
||||
${CMAKE_BINARY_DIR}/benchmarks)
|
||||
endif()
|
||||
|
||||
if(BUILD_MICRO_BENCHMARKS)
|
||||
add_subdirectory(${TRT_LLM_ROOT_DIR}/cpp/micro_benchmarks
|
||||
${CMAKE_BINARY_DIR}/micro_benchmarks)
|
||||
endif()
|
||||
|
||||
# Measure the compile time
|
||||
option(MEASURE_BUILD_TIME "Measure the build time of each module" OFF)
|
||||
if(MEASURE_BUILD_TIME)
|
||||
|
||||
@ -53,8 +53,7 @@ public:
|
||||
SendResponseCallback sendResponseCb, PollStopSignalCallback pollStopSignalCb = nullptr,
|
||||
ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
|
||||
TrtGptModelOptionalParams const& optionalParams = TrtGptModelOptionalParams(),
|
||||
std::optional<uint64_t> terminateReqId = std::nullopt, std::optional<SizeType32> maxDraftTokens = std::nullopt,
|
||||
bool excludeInputInOutput = false);
|
||||
std::optional<uint64_t> terminateReqId = std::nullopt, bool excludeInputInOutput = false);
|
||||
|
||||
/* Wraps the user-provided callback for requests.
|
||||
Adds requests to request table.
|
||||
@ -109,7 +108,6 @@ private:
|
||||
|
||||
std::shared_ptr<TrtGptModel> mTrtGptModel;
|
||||
std::optional<uint64_t> mTerminateReqId;
|
||||
std::optional<SizeType32> mMaxDraftTokens;
|
||||
|
||||
// Iteration counter - incremented every iteration of the generation loop
|
||||
int64_t mIterationCounter;
|
||||
|
||||
@ -89,7 +89,6 @@ public:
|
||||
, mLoraTaskId(loraTaskId)
|
||||
, mLoraWeights(std::move(loraWeights))
|
||||
, mLoraConfig(std::move(loraConfig))
|
||||
, mReturnLogProbs(returnLogProbs)
|
||||
, mContextChunkSize(std::nullopt)
|
||||
, mContextCurrentPosition(0)
|
||||
, mLogProbs(samplingConfig.beamWidth)
|
||||
@ -101,15 +100,16 @@ public:
|
||||
, mReturnGenerationLogits(returnGenerationLogits)
|
||||
, mExcludeInputFromOutput(excludeInputFromOutput)
|
||||
, mEncoderInputTokens(encoderInputTokens)
|
||||
, mDecodingIter(0)
|
||||
{
|
||||
initialize(*inputTokens);
|
||||
initialize(*inputTokens, returnLogProbs);
|
||||
}
|
||||
|
||||
GenericLlmRequest(RequestIdType requestId, executor::Request const& req)
|
||||
: mRequestId(requestId)
|
||||
, mPromptLen(req.getInputTokenIds().size())
|
||||
, mMaxNewTokens(req.getMaxNewTokens())
|
||||
, mSamplingConfig(req.getSamplingConfig(), req.getSpeculativeDecodingConfig())
|
||||
, mSamplingConfig(req.getSamplingConfig(), req.getExternalDraftTokensConfig())
|
||||
, mState(REQUEST_STATE_CONTEXT_INIT)
|
||||
, mIsStreaming(req.getStreaming())
|
||||
, mEndId(req.getEndId())
|
||||
@ -124,7 +124,6 @@ public:
|
||||
, mLoraTaskId(std::nullopt)
|
||||
, mLoraWeights(std::nullopt)
|
||||
, mLoraConfig(std::nullopt)
|
||||
, mReturnLogProbs(req.getOutputConfig().returnLogProbs)
|
||||
, mContextChunkSize(std::nullopt)
|
||||
, mContextCurrentPosition(0)
|
||||
, mLogProbs(mSamplingConfig.beamWidth)
|
||||
@ -135,6 +134,7 @@ public:
|
||||
, mReturnContextLogits(req.getOutputConfig().returnContextLogits)
|
||||
, mReturnGenerationLogits(req.getOutputConfig().returnGenerationLogits)
|
||||
, mExcludeInputFromOutput(req.getOutputConfig().excludeInputFromOutput)
|
||||
, mDecodingIter(0)
|
||||
{
|
||||
if (req.getEmbeddingBias())
|
||||
{
|
||||
@ -178,20 +178,20 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
auto speculativeDecodingConfig = req.getSpeculativeDecodingConfig();
|
||||
if (speculativeDecodingConfig)
|
||||
auto externalDraftTokensConfig = req.getExternalDraftTokensConfig();
|
||||
if (externalDraftTokensConfig)
|
||||
{
|
||||
mDraftTokens = std::make_shared<VecTokens>(speculativeDecodingConfig.value().getTokens());
|
||||
mDraftTokens = std::make_shared<VecTokens>(externalDraftTokensConfig.value().getTokens());
|
||||
|
||||
if (speculativeDecodingConfig.value().getLogits())
|
||||
if (externalDraftTokensConfig.value().getLogits())
|
||||
{
|
||||
mDraftLogits = executor::detail::toITensor(speculativeDecodingConfig.value().getLogits().value());
|
||||
mDraftLogits = executor::detail::toITensor(externalDraftTokensConfig.value().getLogits().value());
|
||||
}
|
||||
|
||||
// NOTE: Draft acceptance threshold is stored in mSamplingConfig
|
||||
}
|
||||
|
||||
initialize(req.getInputTokenIds());
|
||||
initialize(req.getInputTokenIds(), req.getOutputConfig().returnLogProbs);
|
||||
}
|
||||
|
||||
void validate(SizeType32 maxInputLen, SizeType32 maxSequenceLen, SizeType32 maxDraftLen)
|
||||
@ -233,13 +233,7 @@ public:
|
||||
mMaxNewTokens = maxNewTokens;
|
||||
}
|
||||
|
||||
if (mSamplingConfig.beamWidth <= 0)
|
||||
{
|
||||
TLLM_THROW(
|
||||
"Requested value: %d for beamWidth is invalid. To de-activate beam searching "
|
||||
"set beamWidth to 1 instead.",
|
||||
mSamplingConfig.beamWidth);
|
||||
}
|
||||
TLLM_CHECK_WITH_INFO(mSamplingConfig.validate(), "Incorrect sampling config");
|
||||
}
|
||||
|
||||
void setExcludeInputFromOutput(bool exclude)
|
||||
@ -379,7 +373,7 @@ public:
|
||||
{
|
||||
auto& beamTokens = mTokens.at(beam);
|
||||
beamTokens.resize(mPromptLen);
|
||||
if (mReturnLogProbs)
|
||||
if (returnLogProbs())
|
||||
{
|
||||
mLogProbs.at(beam).clear();
|
||||
}
|
||||
@ -393,7 +387,7 @@ public:
|
||||
auto& beamTokens = mTokens.at(beam);
|
||||
beamTokens.resize(newPromptLen);
|
||||
|
||||
if (mReturnLogProbs)
|
||||
if (returnLogProbs())
|
||||
{
|
||||
auto& logProb = mLogProbs.at(beam);
|
||||
logProb.resize(newPromptLen - mPromptLen);
|
||||
@ -491,12 +485,13 @@ public:
|
||||
|
||||
[[nodiscard]] bool returnLogProbs() const
|
||||
{
|
||||
return mReturnLogProbs;
|
||||
return mSamplingConfig.outputLogProbs.has_value() ? mSamplingConfig.outputLogProbs->at(0) : false;
|
||||
}
|
||||
|
||||
void setReturnLogProbs(bool returnLogProbs)
|
||||
{
|
||||
mReturnLogProbs = returnLogProbs;
|
||||
mSamplingConfig.outputLogProbs = {{returnLogProbs}};
|
||||
mSamplingConfig.cumLogProbs = {{returnLogProbs}};
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<VecLogProbs> const& getLogProbs() const
|
||||
@ -728,6 +723,23 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
/// Increment the counter of decoding iterations.
|
||||
void advanceDecodingIter()
|
||||
{
|
||||
mDecodingIter++;
|
||||
}
|
||||
|
||||
/// @brief Return the average number of decoded tokens per iteration. For standard model it is 1.
|
||||
/// For speculative decoding model >= 1 -- number of draft tokens accepted per step + 1.
|
||||
[[nodiscard]] float getAvgDecodedTokensPerIter() const noexcept
|
||||
{
|
||||
if (mDecodingIter == 0)
|
||||
{
|
||||
return 0.f;
|
||||
}
|
||||
return static_cast<float>(getMaxNumGeneratedTokens()) / mDecodingIter;
|
||||
}
|
||||
|
||||
/// @brief Create a Response from the current state of the request
|
||||
/// @return An optional Response
|
||||
std::optional<executor::Response> createResponse()
|
||||
@ -841,8 +853,6 @@ protected:
|
||||
// encoder output, saved for computing cross attention KV Cache
|
||||
TensorPtr mEncoderOutput;
|
||||
|
||||
bool mReturnLogProbs;
|
||||
|
||||
// To enable chunked context, the FHMA paged kv-cache also needs to be enabled. Except for the last one,
|
||||
// the size of the context chunk needs to be an integer multiple of the kv-cache block size. The meaning
|
||||
// of null value is that the context is not chunked.
|
||||
@ -868,8 +878,10 @@ protected:
|
||||
std::shared_ptr<VecTokens>
|
||||
mEncoderInputTokens; // Input tokens to the encoder for enc only models and enc-dec models
|
||||
|
||||
SizeType32 mDecodingIter;
|
||||
|
||||
private:
|
||||
void initialize(VecTokens const& inputTokens)
|
||||
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
|
||||
{
|
||||
// Scatter the input tokens to other beam
|
||||
mTokens = BeamTokens(mSamplingConfig.beamWidth, inputTokens);
|
||||
@ -888,6 +900,8 @@ private:
|
||||
{
|
||||
TLLM_THROW("Draft tokens must be specified when draft logits are given.");
|
||||
}
|
||||
|
||||
setReturnLogProbs(outputLogProbs);
|
||||
}
|
||||
|
||||
TensorPtr createListTensor(std::list<VecTokens> const& wordsList)
|
||||
|
||||
@ -24,6 +24,7 @@
|
||||
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
@ -33,6 +34,13 @@ namespace tensorrt_llm::batch_manager
|
||||
|
||||
using runtime::SizeType32;
|
||||
|
||||
class PeftTaskNotCachedException : public std::runtime_error
|
||||
{
|
||||
public:
|
||||
explicit PeftTaskNotCachedException(std::string const& msg);
|
||||
~PeftTaskNotCachedException() noexcept override;
|
||||
};
|
||||
|
||||
/**
|
||||
* BasePeftCacheManager
|
||||
*
|
||||
|
||||
@ -21,8 +21,6 @@
|
||||
#include "tensorrt_llm/batch_manager/peftCacheManagerConfig.h"
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/decodingMode.h"
|
||||
#include "tensorrt_llm/runtime/medusaModule.h"
|
||||
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
@ -40,18 +38,15 @@ public:
|
||||
explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
|
||||
bool enableTrtOverlap = false, std::optional<std::vector<SizeType32>> const& deviceIds = std::nullopt,
|
||||
bool normalizeLogProbs = true, bool enableChunkedContext = false,
|
||||
std::optional<runtime::DecodingMode> const& decodingMode = std::nullopt,
|
||||
PeftCacheManagerConfig const& peftCacheManagerConfig = PeftCacheManagerConfig{},
|
||||
std::optional<runtime::MedusaModule::MedusaChoices> const& medusaChoices = std::nullopt,
|
||||
float gpuWeightsPercent = 1)
|
||||
executor::DecodingConfig const& decodingConfig = executor::DecodingConfig{}, float gpuWeightsPercent = 1)
|
||||
: kvCacheConfig{kvCacheConfig}
|
||||
, enableTrtOverlap{enableTrtOverlap}
|
||||
, deviceIds(deviceIds)
|
||||
, normalizeLogProbs{normalizeLogProbs}
|
||||
, enableChunkedContext{enableChunkedContext}
|
||||
, decodingMode{decodingMode}
|
||||
, peftCacheManagerConfig(peftCacheManagerConfig)
|
||||
, medusaChoices(medusaChoices)
|
||||
, decodingConfig(decodingConfig)
|
||||
, gpuWeightsPercent(gpuWeightsPercent)
|
||||
{
|
||||
}
|
||||
@ -60,10 +55,9 @@ public:
|
||||
: TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()), false,
|
||||
executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(),
|
||||
executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(),
|
||||
runtime::DecodingMode::fromExecutor(
|
||||
executorConfig.getDecodingMode().value_or(executor::DecodingMode::kNONE)),
|
||||
PeftCacheManagerConfig(executorConfig.getPeftCacheConfig().value_or(executor::PeftCacheConfig())),
|
||||
executorConfig.getMedusaChoices(), executorConfig.getGpuWeightsPercent())
|
||||
executorConfig.getDecodingConfig().value_or(executor::DecodingConfig{}),
|
||||
executorConfig.getGpuWeightsPercent())
|
||||
{
|
||||
}
|
||||
|
||||
@ -71,7 +65,7 @@ public:
|
||||
{
|
||||
return kvCacheConfig == other.kvCacheConfig && enableTrtOverlap == other.enableTrtOverlap
|
||||
&& deviceIds == other.deviceIds && normalizeLogProbs == other.normalizeLogProbs
|
||||
&& enableChunkedContext == other.enableChunkedContext && decodingMode == other.decodingMode;
|
||||
&& enableChunkedContext == other.enableChunkedContext && decodingConfig == other.decodingConfig;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, TrtGptModelOptionalParams const& self);
|
||||
@ -82,9 +76,8 @@ public:
|
||||
std::optional<std::vector<SizeType32>> deviceIds;
|
||||
bool normalizeLogProbs;
|
||||
bool enableChunkedContext;
|
||||
std::optional<runtime::DecodingMode> decodingMode;
|
||||
PeftCacheManagerConfig peftCacheManagerConfig;
|
||||
std::optional<runtime::MedusaModule::MedusaChoices> medusaChoices;
|
||||
executor::DecodingConfig decodingConfig;
|
||||
// Percentage of weights on the gpu at runtime
|
||||
float gpuWeightsPercent;
|
||||
};
|
||||
|
||||
@ -260,6 +260,9 @@ public:
|
||||
//! \brief Corresponds to `world()` by default, but can be overridden per process.
|
||||
static MpiComm& session();
|
||||
|
||||
//! \brief Returns the MPI local communicator.
|
||||
static MpiComm& localSession();
|
||||
|
||||
[[nodiscard]] MpiComm split(int color, int key) const;
|
||||
|
||||
std::shared_ptr<MpiRequest> bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const;
|
||||
@ -370,3 +373,4 @@ void initialize(MpiThreadSupport threadMode = MpiThreadSupport::THREAD_FUNNELED)
|
||||
} // namespace tensorrt_llm::mpi
|
||||
|
||||
#define COMM_SESSION tensorrt_llm::mpi::MpiComm::session()
|
||||
#define LOCAL_COMM_SESSION tensorrt_llm::mpi::MpiComm::localSession()
|
||||
|
||||
@ -78,7 +78,34 @@ public:
|
||||
[[nodiscard]] std::optional<FloatType> getLengthPenalty() const;
|
||||
[[nodiscard]] std::optional<SizeType32> getEarlyStopping() const;
|
||||
|
||||
void setBeamWidth(SizeType32 beamWidth);
|
||||
void setTopK(std::optional<SizeType32> const& topK);
|
||||
void setTopP(std::optional<FloatType> const& topP);
|
||||
void setTopPMin(std::optional<FloatType> const& topPMin);
|
||||
void setTopPResetIds(std::optional<TokenIdType> const& topPResetIds);
|
||||
void setTopPDecay(std::optional<FloatType> const& topPDecay);
|
||||
void setRandomSeed(std::optional<RandomSeedType> const& randomSeed);
|
||||
void setTemperature(std::optional<FloatType> const& temperature);
|
||||
void setMinLength(std::optional<SizeType32> const& minLength);
|
||||
void setBeamSearchDiversityRate(std::optional<FloatType> const& beamSearchDiversityRate);
|
||||
void setRepetitionPenalty(std::optional<FloatType> const& repetitionPenalty);
|
||||
void setPresencePenalty(std::optional<FloatType> const& presencePenalty);
|
||||
void setFrequencyPenalty(std::optional<FloatType> const& frequencyPenalty);
|
||||
void setLengthPenalty(std::optional<FloatType> const& lengthPenalty);
|
||||
void setEarlyStopping(std::optional<SizeType32> const& earlyStopping);
|
||||
|
||||
private:
|
||||
static SizeType32 checkBeamWidth(SizeType32 beamWidth);
|
||||
static std::optional<FloatType> const& checkTopK(std::optional<FloatType> const& topK);
|
||||
static std::optional<FloatType> const& checkTopP(std::optional<FloatType> const& topP);
|
||||
static std::optional<FloatType> const& checkTopPMin(std::optional<FloatType> const& topPMin);
|
||||
static std::optional<TokenIdType> const& checkTopPResetIds(std::optional<TokenIdType> const& topPResetIds);
|
||||
static std::optional<FloatType> const& checkTopPDecay(std::optional<FloatType> const& topPDecay);
|
||||
static std::optional<FloatType> const& checkTemperature(std::optional<FloatType> const& temperature);
|
||||
static std::optional<SizeType32> const& checkMinLength(std::optional<SizeType32> const& minLength);
|
||||
static std::optional<FloatType> const& checkBeamSearchDiversityRate(
|
||||
std::optional<FloatType> const& beamSearchDiversityRate);
|
||||
|
||||
friend class Serialization;
|
||||
|
||||
/// @brief The beam width. Default is 1 which disables beam search.
|
||||
@ -134,12 +161,12 @@ public:
|
||||
bool excludeInputFromOutput;
|
||||
};
|
||||
|
||||
/// @brief Configuration for speculative decoding. Allows to include draft tokens, draft logits and specify acceptance
|
||||
/// threshold
|
||||
class SpeculativeDecodingConfig
|
||||
/// @brief Configuration for speculative decoding with external draft tokens.
|
||||
/// Allows to include draft tokens, draft logits and specify acceptance threshold.
|
||||
class ExternalDraftTokensConfig
|
||||
{
|
||||
public:
|
||||
explicit SpeculativeDecodingConfig(VecTokens tokens, std::optional<Tensor> logits = std::nullopt,
|
||||
explicit ExternalDraftTokensConfig(VecTokens tokens, std::optional<Tensor> logits = std::nullopt,
|
||||
std::optional<FloatType> const& acceptanceThreshold = std::nullopt);
|
||||
|
||||
[[nodiscard]] VecTokens getTokens() const;
|
||||
@ -209,7 +236,7 @@ public:
|
||||
/// @param badWords A list of bad words tokens. Each "word" can be composed of multiple tokens
|
||||
/// @param stopWords A list of stop words tokens. Each "word" can be composed of multiple tokens
|
||||
/// @param embeddingBias The embedding bias tensor. Expected type is kFP32 and shape is [vocab_size]
|
||||
/// @param speculativeDecodingConfig The speculative decoding configuration
|
||||
/// @param externalDraftTokensConfig The speculative decoding configuration
|
||||
/// @param pTuningConfig The prompt tuning configuration
|
||||
/// @param loraConfig The LoRA configuration
|
||||
/// @param logitsPostProcessorName The logits postprocessor name. Must correspond to one of the logits postprocessor
|
||||
@ -220,7 +247,7 @@ public:
|
||||
std::optional<std::list<VecTokens>> badWords = std::nullopt,
|
||||
std::optional<std::list<VecTokens>> stopWords = std::nullopt,
|
||||
std::optional<Tensor> embeddingBias = std::nullopt,
|
||||
std::optional<SpeculativeDecodingConfig> speculativeDecodingConfig = std::nullopt,
|
||||
std::optional<ExternalDraftTokensConfig> externalDraftTokensConfig = std::nullopt,
|
||||
std::optional<PromptTuningConfig> pTuningConfig = std::nullopt,
|
||||
std::optional<LoraConfig> loraConfig = std::nullopt,
|
||||
std::optional<std::string> logitsPostProcessorName = std::nullopt);
|
||||
@ -241,7 +268,7 @@ public:
|
||||
[[nodiscard]] std::optional<std::list<VecTokens>> getBadWords() const;
|
||||
[[nodiscard]] std::optional<std::list<VecTokens>> getStopWords() const;
|
||||
[[nodiscard]] std::optional<Tensor> getEmbeddingBias() const;
|
||||
[[nodiscard]] std::optional<SpeculativeDecodingConfig> getSpeculativeDecodingConfig() const;
|
||||
[[nodiscard]] std::optional<ExternalDraftTokensConfig> getExternalDraftTokensConfig() const;
|
||||
[[nodiscard]] std::optional<PromptTuningConfig> getPromptTuningConfig() const;
|
||||
[[nodiscard]] std::optional<LoraConfig> getLoraConfig() const;
|
||||
[[nodiscard]] std::optional<std::string> getLogitsPostProcessorName() const;
|
||||
@ -254,7 +281,7 @@ public:
|
||||
void setBadWords(std::list<VecTokens> const& badWords);
|
||||
void setStopWords(std::list<VecTokens> const& stopWords);
|
||||
void setEmbeddingBias(Tensor const& embeddingBias);
|
||||
void setSpeculativeDecodingConfig(SpeculativeDecodingConfig const& specDecodingConfig);
|
||||
void setExternalDraftTokensConfig(ExternalDraftTokensConfig const& externalDraftTokensConfig);
|
||||
void setPromptTuningConfig(PromptTuningConfig const& pTuningConfig);
|
||||
void setLoraConfig(LoraConfig const& loraConfig);
|
||||
void setLogitsPostProcessorName(std::string const& logitsPostProcessorName);
|
||||
@ -514,6 +541,70 @@ private:
|
||||
std::optional<size_t> mHostCacheSize;
|
||||
};
|
||||
|
||||
/// @brief Configuration class for Lookahead decoding.
|
||||
class LookaheadDecodingConfig
|
||||
{
|
||||
public:
|
||||
explicit LookaheadDecodingConfig(
|
||||
SizeType32 maxNgramSize, SizeType32 maxWindowSize, SizeType32 maxVerificationSetSize);
|
||||
|
||||
bool operator==(LookaheadDecodingConfig const& other) const;
|
||||
|
||||
// Lookahead decoding methods.
|
||||
void setMaxNgramSize(SizeType32);
|
||||
void setMaxWindowSize(SizeType32);
|
||||
void setMaxVerificationSetSize(SizeType32);
|
||||
[[nodiscard]] SizeType32 getMaxNgramSize() const;
|
||||
[[nodiscard]] SizeType32 getMaxWindowSize() const;
|
||||
[[nodiscard]] SizeType32 getMaxVerificationSetSize() const;
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
|
||||
// Number of tokens per NGram.
|
||||
SizeType32 mMaxNgramSize;
|
||||
// Number of NGrams in lookahead branch per step.
|
||||
SizeType32 mMaxWindowSize;
|
||||
// Number of NGrams in verification branch per step.
|
||||
SizeType32 mMaxVerificationSetSize;
|
||||
};
|
||||
|
||||
/// @brief Configuration class for the speculative decoding.
|
||||
class DecodingConfig
|
||||
{
|
||||
public:
|
||||
explicit DecodingConfig(std::optional<DecodingMode> decodingMode = std::nullopt,
|
||||
std::optional<LookaheadDecodingConfig> lookaheadDecodingConfig = std::nullopt,
|
||||
std::optional<MedusaChoices> medusaChoices = std::nullopt);
|
||||
|
||||
bool operator==(DecodingConfig const& other) const;
|
||||
|
||||
// Decoding mode.
|
||||
/// @brief Setsdecoding mode. Can't set lookahead and medusa mode.
|
||||
void setDecodingMode(DecodingMode const&);
|
||||
[[nodiscard]] std::optional<DecodingMode> getDecodingMode() const;
|
||||
|
||||
// Lookahead methods.
|
||||
/// @brief Sets lookahead decoding mode and lookahead decoding config.
|
||||
void setLookaheadDecoding(LookaheadDecodingConfig const&);
|
||||
[[nodiscard]] std::optional<LookaheadDecodingConfig> getLookaheadDecodingConfig() const;
|
||||
|
||||
// Medusa methods.
|
||||
/// @brief Sets medusa mode and medusa config.
|
||||
void setMedusaChoices(MedusaChoices const&);
|
||||
[[nodiscard]] std::optional<MedusaChoices> getMedusaChoices() const;
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
|
||||
// Decoding mode.
|
||||
std::optional<DecodingMode> mDecodingMode;
|
||||
// Lookahead params.
|
||||
std::optional<LookaheadDecodingConfig> mLookaheadDecodingConfig;
|
||||
// Medusa params.
|
||||
std::optional<MedusaChoices> mMedusaChoices;
|
||||
};
|
||||
|
||||
/// @brief Configuration class for the model executor
|
||||
class ExecutorConfig
|
||||
{
|
||||
@ -526,8 +617,7 @@ public:
|
||||
std::optional<ParallelConfig> parallelConfig = std::nullopt,
|
||||
std::optional<PeftCacheConfig> const& peftCacheConfig = std::nullopt,
|
||||
std::optional<LogitsPostProcessorMap> logitsPostProcessorMap = std::nullopt,
|
||||
std::optional<MedusaChoices> medusaChoices = std::nullopt,
|
||||
std::optional<DecodingMode> decodingMode = std::nullopt, float gpuWeightsPercent = 1);
|
||||
std::optional<DecodingConfig> decodingConfig = std::nullopt, float gpuWeightsPercent = 1);
|
||||
|
||||
[[nodiscard]] SizeType32 getMaxBeamWidth() const;
|
||||
[[nodiscard]] SchedulerConfig getSchedulerConfig() const;
|
||||
@ -540,8 +630,7 @@ public:
|
||||
[[nodiscard]] std::optional<ParallelConfig> getParallelConfig() const;
|
||||
[[nodiscard]] std::optional<PeftCacheConfig> getPeftCacheConfig() const;
|
||||
[[nodiscard]] std::optional<LogitsPostProcessorMap> getLogitsPostProcessorMap() const;
|
||||
[[nodiscard]] std::optional<MedusaChoices> getMedusaChoices() const;
|
||||
[[nodiscard]] std::optional<DecodingMode> getDecodingMode() const;
|
||||
[[nodiscard]] std::optional<DecodingConfig> getDecodingConfig() const;
|
||||
[[nodiscard]] float getGpuWeightsPercent() const;
|
||||
|
||||
void setMaxBeamWidth(SizeType32 maxBeamWidth);
|
||||
@ -555,8 +644,7 @@ public:
|
||||
void setParallelConfig(ParallelConfig const& parallelConfig);
|
||||
void setPeftCacheConfig(PeftCacheConfig const& peftCacheConfig);
|
||||
void setLogitsPostProcessorMap(LogitsPostProcessorMap const& logitsPostProcessorMap);
|
||||
void setMedusaChoices(MedusaChoices const& medusaChoices);
|
||||
void setDecodingMode(DecodingMode decodingMode);
|
||||
void setDecodingConfig(DecodingConfig const& decodingConfig);
|
||||
void setGpuWeightsPercent(float const& gpuWeightsPercent);
|
||||
|
||||
private:
|
||||
@ -590,8 +678,8 @@ private:
|
||||
std::optional<ParallelConfig> mParallelConfig;
|
||||
std::optional<PeftCacheConfig> mPeftCacheConfig;
|
||||
std::optional<LogitsPostProcessorMap> mLogitsPostProcessorMap;
|
||||
std::optional<MedusaChoices> mMedusaChoices;
|
||||
std::optional<DecodingMode> mDecodingMode;
|
||||
/// @brief Decoding configuration.
|
||||
std::optional<DecodingConfig> mDecodingConfig;
|
||||
float mGpuWeightsPercent;
|
||||
};
|
||||
|
||||
|
||||
@ -38,10 +38,10 @@ public:
|
||||
static void serialize(OutputConfig const& config, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(OutputConfig const& config);
|
||||
|
||||
// SpeculativeDecodingConfig
|
||||
[[nodiscard]] static SpeculativeDecodingConfig deserializeSpeculativeDecodingConfig(std::istream& is);
|
||||
static void serialize(SpeculativeDecodingConfig const& config, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(SpeculativeDecodingConfig const& config);
|
||||
// ExternalDraftTokensConfig
|
||||
[[nodiscard]] static ExternalDraftTokensConfig deserializeExternalDraftTokensConfig(std::istream& is);
|
||||
static void serialize(ExternalDraftTokensConfig const& config, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(ExternalDraftTokensConfig const& config);
|
||||
|
||||
// PromptTuningConfig
|
||||
[[nodiscard]] static PromptTuningConfig deserializePromptTuningConfig(std::istream& is);
|
||||
@ -102,6 +102,21 @@ public:
|
||||
static void serialize(OrchestratorConfig const& orchestratorConfig, std::ostream& os);
|
||||
static size_t serializedSize(OrchestratorConfig const& orchestratorConfig);
|
||||
|
||||
// DecodingMode
|
||||
static DecodingMode deserializeDecodingMode(std::istream& is);
|
||||
static void serialize(DecodingMode const& decodingMode, std::ostream& os);
|
||||
static size_t serializedSize(DecodingMode const& decodingMode);
|
||||
|
||||
// LookaheadDecodingConfig
|
||||
static LookaheadDecodingConfig deserializeLookaheadDecodingConfig(std::istream& is);
|
||||
static void serialize(LookaheadDecodingConfig const& lookaheadDecodingConfig, std::ostream& os);
|
||||
static size_t serializedSize(LookaheadDecodingConfig const& lookaheadDecodingConfig);
|
||||
|
||||
// DecodingConfig
|
||||
static DecodingConfig deserializeDecodingConfig(std::istream& is);
|
||||
static void serialize(DecodingConfig const& decodingConfig, std::ostream& os);
|
||||
static size_t serializedSize(DecodingConfig const& decodingConfig);
|
||||
|
||||
// ExecutorConfig
|
||||
static ExecutorConfig deserializeExecutorConfig(std::istream& is);
|
||||
static void serialize(ExecutorConfig const& executorConfig, std::ostream& os);
|
||||
|
||||
@ -253,6 +253,8 @@ struct InflightBatchingStats
|
||||
SizeType32 numCtxTokens;
|
||||
/// @brief Index of mirco batch
|
||||
SizeType32 microBatchId;
|
||||
/// @brief Average number of tokens decoded per request per iteration
|
||||
float avgNumDecodedTokensPerIter;
|
||||
};
|
||||
|
||||
/// @brief Struct that holds the stats of a single iteration
|
||||
@ -306,6 +308,8 @@ struct RequestStats
|
||||
SizeType32 contextPrefillPosition;
|
||||
/// @brief The number of generated tokens so far
|
||||
SizeType32 numGeneratedTokens;
|
||||
/// @brief The average number of decoded tokens per iteration. It is >= 1 for speculative decoding.
|
||||
float avgNumDecodedTokensPerIter;
|
||||
/// @brief Whether the request is scheduled for the current iteration
|
||||
bool scheduled;
|
||||
/// @brief Whether the request is being paused at the current iteration due to lack of resources (KV cache blocks
|
||||
@ -322,17 +326,386 @@ struct RequestStatsPerIteration
|
||||
std::vector<RequestStats> requestStats;
|
||||
};
|
||||
|
||||
/// @brief Decoding mode
|
||||
enum class DecodingMode
|
||||
/// @brief mode of the decoder
|
||||
class DecodingMode
|
||||
{
|
||||
public:
|
||||
/// @brief No mode specified. Config will be determined from the beam width of the first request at runtime
|
||||
/// TopKTopP if beamWidth == 1, BeamSearch otherwise
|
||||
kNONE,
|
||||
kTOP_K,
|
||||
kTOP_P,
|
||||
kBEAM_SEARCH,
|
||||
kMEDUSA,
|
||||
kTOP_K_TOP_P,
|
||||
static auto constexpr Auto()
|
||||
{
|
||||
return DecodingMode{kAuto};
|
||||
}
|
||||
|
||||
static auto constexpr TopK()
|
||||
{
|
||||
return DecodingMode{kTopK | kUsePenalties | kUseBanWords | kStandardStopCriteria};
|
||||
}
|
||||
|
||||
static auto constexpr TopP()
|
||||
{
|
||||
return DecodingMode{kTopP | kUsePenalties | kUseBanWords | kStandardStopCriteria};
|
||||
}
|
||||
|
||||
static auto constexpr TopKTopP()
|
||||
{
|
||||
return DecodingMode{kTopKTopP | kUsePenalties | kUseBanWords | kStandardStopCriteria};
|
||||
}
|
||||
|
||||
static auto constexpr BeamSearch()
|
||||
{
|
||||
return DecodingMode{kBeamSearch | kUsePenalties | kUseBanWords | kStandardStopCriteria};
|
||||
}
|
||||
|
||||
static auto constexpr Medusa()
|
||||
{
|
||||
return DecodingMode{kMedusa | kUseMinLength | kUseMaxLengthStop};
|
||||
}
|
||||
|
||||
static auto constexpr Lookahead()
|
||||
{
|
||||
return DecodingMode{kLookahead | kUseMinLength | kUseMaxLengthStop};
|
||||
}
|
||||
|
||||
static auto constexpr ExplicitDraftTokens()
|
||||
{
|
||||
return DecodingMode{kExplicitDraftTokens | kUseMaxLengthStop | kUseExplicitEosStop};
|
||||
}
|
||||
|
||||
auto constexpr useTemperature(bool useTemp)
|
||||
{
|
||||
mState = setBitTo(kUseTemperature, useTemp);
|
||||
return *this;
|
||||
}
|
||||
|
||||
auto constexpr useOccurrencePenalties(bool usePenalty)
|
||||
{
|
||||
mState = setBitTo(kUseOccurrencePenalties, usePenalty);
|
||||
return *this;
|
||||
}
|
||||
|
||||
auto constexpr usePresencePenalty(bool usePenalty)
|
||||
{
|
||||
mState = setBitTo(kUsePresencePenalties, usePenalty);
|
||||
return *this;
|
||||
}
|
||||
|
||||
auto constexpr useRepetitionPenalty(bool usePenalty)
|
||||
{
|
||||
mState = setBitTo(kUseRepetitionPenalties, usePenalty);
|
||||
return *this;
|
||||
}
|
||||
|
||||
auto constexpr useFrequencyPenalty(bool usePenalty)
|
||||
{
|
||||
mState = setBitTo(kUseFrequencyPenalties, usePenalty);
|
||||
return *this;
|
||||
}
|
||||
|
||||
auto constexpr useMinLength(bool useMinLen)
|
||||
{
|
||||
mState = setBitTo(kUseMinLength, useMinLen);
|
||||
return *this;
|
||||
}
|
||||
|
||||
auto constexpr useBanWords(bool banWords)
|
||||
{
|
||||
mState = setBitTo(kUseBanWords, banWords);
|
||||
return *this;
|
||||
}
|
||||
|
||||
auto constexpr useStopWords(bool stopWords)
|
||||
{
|
||||
mState = setBitTo(kUseStopWords, stopWords);
|
||||
return *this;
|
||||
}
|
||||
|
||||
auto constexpr useMaxLengthStop(bool maxLengthStop)
|
||||
{
|
||||
mState = setBitTo(kUseMaxLengthStop, maxLengthStop);
|
||||
return *this;
|
||||
}
|
||||
|
||||
auto constexpr useExplicitEosStop(bool explicitEosStop)
|
||||
{
|
||||
mState = setBitTo(kUseExplicitEosStop, explicitEosStop);
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool constexpr isAuto() const
|
||||
{
|
||||
return anyBitSet(kAuto);
|
||||
}
|
||||
|
||||
bool constexpr isTopK() const
|
||||
{
|
||||
return anyBitSet(kTopK);
|
||||
}
|
||||
|
||||
bool constexpr isTopP() const
|
||||
{
|
||||
return anyBitSet(kTopP);
|
||||
}
|
||||
|
||||
bool constexpr isTopKorTopP() const
|
||||
{
|
||||
return anyBitSet(kTopKTopP);
|
||||
}
|
||||
|
||||
bool constexpr isTopKandTopP() const
|
||||
{
|
||||
return allBitSet(kTopKTopP);
|
||||
}
|
||||
|
||||
bool constexpr isBeamSearch() const
|
||||
{
|
||||
return anyBitSet(kBeamSearch);
|
||||
}
|
||||
|
||||
bool constexpr isMedusa() const
|
||||
{
|
||||
return anyBitSet(kMedusa);
|
||||
}
|
||||
|
||||
bool constexpr isLookahead() const
|
||||
{
|
||||
return anyBitSet(kLookahead);
|
||||
}
|
||||
|
||||
bool constexpr isExplicitDraftTokens() const
|
||||
{
|
||||
return anyBitSet(kExplicitDraftTokens);
|
||||
}
|
||||
|
||||
bool constexpr isUseTemperature() const
|
||||
{
|
||||
return anyBitSet(kUseTemperature);
|
||||
}
|
||||
|
||||
bool constexpr isUsePresencePenalty() const
|
||||
{
|
||||
return anyBitSet(kUsePresencePenalties);
|
||||
}
|
||||
|
||||
bool constexpr isUseFrequencyPenalty() const
|
||||
{
|
||||
return anyBitSet(kUseFrequencyPenalties);
|
||||
}
|
||||
|
||||
bool constexpr isUseRepetitionPenalty() const
|
||||
{
|
||||
return anyBitSet(kUseRepetitionPenalties);
|
||||
}
|
||||
|
||||
bool constexpr isUseMinLength() const
|
||||
{
|
||||
return anyBitSet(kUseMinLength);
|
||||
}
|
||||
|
||||
bool constexpr isUseOccurrencePenalty() const
|
||||
{
|
||||
return anyBitSet(kUseOccurrencePenalties);
|
||||
}
|
||||
|
||||
bool constexpr isUsePenalty() const
|
||||
{
|
||||
return anyBitSet(kUsePenalties);
|
||||
}
|
||||
|
||||
bool constexpr isUseBanWords() const
|
||||
{
|
||||
return anyBitSet(kUseBanWords);
|
||||
}
|
||||
|
||||
bool constexpr isUseStopWords() const
|
||||
{
|
||||
return anyBitSet(kUseStopWords);
|
||||
}
|
||||
|
||||
bool constexpr isUseMaxLengthStop() const
|
||||
{
|
||||
return anyBitSet(kUseMaxLengthStop);
|
||||
}
|
||||
|
||||
bool constexpr isUseExplicitEosStop() const
|
||||
{
|
||||
return anyBitSet(kUseExplicitEosStop);
|
||||
}
|
||||
|
||||
bool constexpr isUseStopCriteria() const
|
||||
{
|
||||
return anyBitSet(kStandardStopCriteria | kUseExplicitEosStop);
|
||||
}
|
||||
|
||||
using UnderlyingType = uint32_t;
|
||||
|
||||
bool operator==(DecodingMode const& other) const
|
||||
{
|
||||
return mState == other.mState;
|
||||
}
|
||||
|
||||
constexpr DecodingMode(UnderlyingType state)
|
||||
: mState(state)
|
||||
{
|
||||
}
|
||||
|
||||
constexpr UnderlyingType getState() const
|
||||
{
|
||||
return mState;
|
||||
}
|
||||
|
||||
private:
|
||||
// No mode specified. Config will be determined from the beam width of the first request at runtime
|
||||
// TopKTopP if beamWidth == 1, BeamSearch otherwise
|
||||
static UnderlyingType constexpr kUseRepetitionPenalties{1u << 0};
|
||||
static UnderlyingType constexpr kUseFrequencyPenalties{1u << 1};
|
||||
static UnderlyingType constexpr kUsePresencePenalties{1u << 2};
|
||||
static UnderlyingType constexpr kUseTemperature{1u << 3};
|
||||
static UnderlyingType constexpr kUseMinLength{1u << 4};
|
||||
static UnderlyingType constexpr kUseBanWords{1u << 5};
|
||||
static UnderlyingType constexpr kUseStopWords{1u << 6};
|
||||
static UnderlyingType constexpr kUseMaxLengthStop{1u << 7};
|
||||
static UnderlyingType constexpr kUseExplicitEosStop{1u << 8};
|
||||
static UnderlyingType constexpr kStandardStopCriteria{kUseStopWords | kUseMaxLengthStop};
|
||||
static UnderlyingType constexpr kUseOccurrencePenalties{
|
||||
kUseRepetitionPenalties | kUseFrequencyPenalties | kUsePresencePenalties};
|
||||
static UnderlyingType constexpr kUsePenalties{kUseOccurrencePenalties | kUseTemperature | kUseMinLength};
|
||||
static SizeType32 constexpr kNumFlags{9};
|
||||
static UnderlyingType constexpr kAuto{1u << (kNumFlags + 0)};
|
||||
static UnderlyingType constexpr kTopK{1u << (kNumFlags + 1)};
|
||||
static UnderlyingType constexpr kTopP{1u << (kNumFlags + 2)};
|
||||
static UnderlyingType constexpr kBeamSearch{1u << (kNumFlags + 3)};
|
||||
static UnderlyingType constexpr kMedusa{1u << (kNumFlags + 4)};
|
||||
static UnderlyingType constexpr kLookahead{1u << (kNumFlags + 5)};
|
||||
static UnderlyingType constexpr kExplicitDraftTokens{1u << (kNumFlags + 6)};
|
||||
static UnderlyingType constexpr kTopKTopP{kTopK | kTopP};
|
||||
|
||||
bool constexpr anyBitSet(UnderlyingType bits) const
|
||||
{
|
||||
return (mState & bits) != 0;
|
||||
}
|
||||
|
||||
bool constexpr allBitSet(UnderlyingType bits) const
|
||||
{
|
||||
return (mState & bits) == bits;
|
||||
}
|
||||
|
||||
UnderlyingType constexpr setBitTo(UnderlyingType state, bool x)
|
||||
{
|
||||
return (mState & (~state)) | (state * static_cast<UnderlyingType>(x));
|
||||
}
|
||||
|
||||
UnderlyingType mState{};
|
||||
};
|
||||
|
||||
static_assert(DecodingMode::Auto().isAuto());
|
||||
static_assert(!DecodingMode::Auto().isUseBanWords());
|
||||
static_assert(!DecodingMode::Auto().isUseOccurrencePenalty());
|
||||
static_assert(!DecodingMode::Auto().isUseStopCriteria());
|
||||
static_assert(!DecodingMode::Auto().isTopK());
|
||||
static_assert(!DecodingMode::Auto().isTopP());
|
||||
static_assert(!DecodingMode::Auto().isBeamSearch());
|
||||
static_assert(!DecodingMode::Auto().isMedusa());
|
||||
static_assert(!DecodingMode::Auto().isLookahead());
|
||||
static_assert(!DecodingMode::Auto().isExplicitDraftTokens());
|
||||
|
||||
static_assert(DecodingMode::TopK().isTopK());
|
||||
static_assert(DecodingMode::TopK().isTopKorTopP());
|
||||
static_assert(DecodingMode::TopK().isUseBanWords());
|
||||
static_assert(DecodingMode::TopK().isUseOccurrencePenalty());
|
||||
static_assert(DecodingMode::TopK().isUseStopCriteria());
|
||||
static_assert(!DecodingMode::TopK().useRepetitionPenalty(false).isUseRepetitionPenalty());
|
||||
static_assert(DecodingMode::TopK().useRepetitionPenalty(false).isUseOccurrencePenalty());
|
||||
static_assert(!DecodingMode::TopK()
|
||||
.useRepetitionPenalty(false)
|
||||
.usePresencePenalty(false)
|
||||
.useFrequencyPenalty(false)
|
||||
.isUseOccurrencePenalty());
|
||||
static_assert(!DecodingMode::TopK().isTopKandTopP());
|
||||
static_assert(!DecodingMode::TopK().isTopP());
|
||||
static_assert(!DecodingMode::TopK().isBeamSearch());
|
||||
static_assert(!DecodingMode::TopK().isMedusa());
|
||||
static_assert(!DecodingMode::TopK().isLookahead());
|
||||
static_assert(!DecodingMode::TopK().isAuto());
|
||||
static_assert(!DecodingMode::TopK().isExplicitDraftTokens());
|
||||
|
||||
static_assert(DecodingMode::TopP().isTopP());
|
||||
static_assert(DecodingMode::TopP().isTopKorTopP());
|
||||
static_assert(DecodingMode::TopP().isUseBanWords());
|
||||
static_assert(DecodingMode::TopP().isUseOccurrencePenalty());
|
||||
static_assert(DecodingMode::TopP().isUseStopCriteria());
|
||||
static_assert(!DecodingMode::TopP().isTopKandTopP());
|
||||
static_assert(!DecodingMode::TopP().isTopK());
|
||||
static_assert(!DecodingMode::TopP().isBeamSearch());
|
||||
static_assert(!DecodingMode::TopP().isMedusa());
|
||||
static_assert(!DecodingMode::TopP().isLookahead());
|
||||
static_assert(!DecodingMode::TopP().isAuto());
|
||||
static_assert(!DecodingMode::TopP().isExplicitDraftTokens());
|
||||
|
||||
static_assert(DecodingMode::TopKTopP().isTopK());
|
||||
static_assert(DecodingMode::TopKTopP().isTopP());
|
||||
static_assert(DecodingMode::TopKTopP().isTopKorTopP());
|
||||
static_assert(DecodingMode::TopKTopP().isTopKandTopP());
|
||||
static_assert(DecodingMode::TopKTopP().isUseBanWords());
|
||||
static_assert(DecodingMode::TopKTopP().isUseOccurrencePenalty());
|
||||
static_assert(DecodingMode::TopKTopP().isUseStopCriteria());
|
||||
static_assert(!DecodingMode::TopKTopP().isBeamSearch());
|
||||
static_assert(!DecodingMode::TopKTopP().isMedusa());
|
||||
static_assert(!DecodingMode::TopKTopP().isLookahead());
|
||||
static_assert(!DecodingMode::TopKTopP().isAuto());
|
||||
static_assert(!DecodingMode::TopKTopP().isExplicitDraftTokens());
|
||||
|
||||
static_assert(DecodingMode::BeamSearch().isBeamSearch());
|
||||
static_assert(DecodingMode::BeamSearch().isUseStopCriteria());
|
||||
static_assert(!DecodingMode::BeamSearch().isTopKorTopP());
|
||||
static_assert(!DecodingMode::BeamSearch().isMedusa());
|
||||
static_assert(!DecodingMode::BeamSearch().isLookahead());
|
||||
static_assert(!DecodingMode::BeamSearch().isAuto());
|
||||
static_assert(!DecodingMode::BeamSearch().isExplicitDraftTokens());
|
||||
|
||||
static_assert(!DecodingMode::Medusa().isTopK());
|
||||
static_assert(!DecodingMode::Medusa().isTopKorTopP());
|
||||
static_assert(!DecodingMode::Medusa().isTopKandTopP());
|
||||
static_assert(!DecodingMode::Medusa().isTopP());
|
||||
static_assert(!DecodingMode::Medusa().isBeamSearch());
|
||||
static_assert(!DecodingMode::Medusa().isLookahead());
|
||||
static_assert(!DecodingMode::Medusa().isAuto());
|
||||
static_assert(!DecodingMode::Medusa().isUseBanWords());
|
||||
static_assert(!DecodingMode::Medusa().isUseOccurrencePenalty());
|
||||
static_assert(!DecodingMode::Medusa().isExplicitDraftTokens());
|
||||
static_assert(DecodingMode::Medusa().isUseStopCriteria());
|
||||
static_assert(!DecodingMode::Medusa().isUseStopWords());
|
||||
static_assert(!DecodingMode::Medusa().isUseExplicitEosStop());
|
||||
static_assert(DecodingMode::Medusa().isUsePenalty());
|
||||
static_assert(DecodingMode::Medusa().isUseMinLength());
|
||||
static_assert(DecodingMode::Medusa().isMedusa());
|
||||
|
||||
static_assert(!DecodingMode::Lookahead().isAuto());
|
||||
static_assert(!DecodingMode::Lookahead().isTopK());
|
||||
static_assert(!DecodingMode::Lookahead().isTopKorTopP());
|
||||
static_assert(!DecodingMode::Lookahead().isTopKandTopP());
|
||||
static_assert(!DecodingMode::Lookahead().isTopP());
|
||||
static_assert(!DecodingMode::Lookahead().isBeamSearch());
|
||||
static_assert(!DecodingMode::Lookahead().isMedusa());
|
||||
static_assert(!DecodingMode::Lookahead().isExplicitDraftTokens());
|
||||
static_assert(DecodingMode::Lookahead().isUseStopCriteria());
|
||||
static_assert(!DecodingMode::Lookahead().isUseStopWords());
|
||||
static_assert(!DecodingMode::Lookahead().isUseExplicitEosStop());
|
||||
static_assert(DecodingMode::Lookahead().isLookahead());
|
||||
|
||||
static_assert(!DecodingMode::ExplicitDraftTokens().isAuto());
|
||||
static_assert(!DecodingMode::ExplicitDraftTokens().isTopK());
|
||||
static_assert(!DecodingMode::ExplicitDraftTokens().isTopKorTopP());
|
||||
static_assert(!DecodingMode::ExplicitDraftTokens().isTopKandTopP());
|
||||
static_assert(!DecodingMode::ExplicitDraftTokens().isTopP());
|
||||
static_assert(!DecodingMode::ExplicitDraftTokens().isBeamSearch());
|
||||
static_assert(!DecodingMode::ExplicitDraftTokens().isMedusa());
|
||||
static_assert(!DecodingMode::ExplicitDraftTokens().isLookahead());
|
||||
static_assert(!DecodingMode::ExplicitDraftTokens().isUsePenalty());
|
||||
static_assert(DecodingMode::ExplicitDraftTokens().isUseStopCriteria());
|
||||
static_assert(DecodingMode::ExplicitDraftTokens().isUseMaxLengthStop());
|
||||
static_assert(DecodingMode::ExplicitDraftTokens().isUseExplicitEosStop());
|
||||
static_assert(!DecodingMode::ExplicitDraftTokens().isUseStopWords());
|
||||
static_assert(!DecodingMode::ExplicitDraftTokens().isUseBanWords());
|
||||
static_assert(DecodingMode::ExplicitDraftTokens().isExplicitDraftTokens());
|
||||
} // namespace tensorrt_llm::executor
|
||||
|
||||
@ -81,10 +81,10 @@ public:
|
||||
class MedusaInputs
|
||||
{
|
||||
public:
|
||||
TensorPtr medusaPaths; // [maxBatchSize, maxTokensPerStep, maxMedusaHeads + 1], on gpu
|
||||
TensorPtr medusaTreeIds; // [maxBatchSize, maxTokensPerStep], on gpu
|
||||
TensorPtr medusaPaths; // [maxBatchSize, maxTokensPerStep, maxMedusaHeads + 1], on gpu
|
||||
TensorPtr medusaTreeIds; // [maxBatchSize, maxTokensPerStep], on gpu
|
||||
std::vector<std::vector<TensorPtr>>
|
||||
medusaLogits; // [maxBatchSize][maxMedusaHeads][tokensPerStep, vocabSizePadded], on gpu
|
||||
medusaLogits; // [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded], on gpu
|
||||
TensorPtr medusaCurTokensPerStep; // [maxBatchSize], on gpu
|
||||
TensorPtr medusaTargetTokensPerStep; // [maxBatchSize], on gpu
|
||||
};
|
||||
|
||||
@ -1,190 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
|
||||
class DecodingMode
|
||||
{
|
||||
public:
|
||||
static auto constexpr None()
|
||||
{
|
||||
return DecodingMode{kNone};
|
||||
}
|
||||
|
||||
static auto constexpr TopK()
|
||||
{
|
||||
return DecodingMode{kTopK};
|
||||
}
|
||||
|
||||
static auto constexpr TopP()
|
||||
{
|
||||
return DecodingMode{kTopP};
|
||||
}
|
||||
|
||||
static auto constexpr TopKTopP()
|
||||
{
|
||||
return DecodingMode{kTopKTopP};
|
||||
}
|
||||
|
||||
static auto constexpr BeamSearch()
|
||||
{
|
||||
return DecodingMode{kBeamSearch};
|
||||
}
|
||||
|
||||
static auto constexpr Medusa()
|
||||
{
|
||||
return DecodingMode{kMedusa};
|
||||
}
|
||||
|
||||
bool constexpr isNone() const
|
||||
{
|
||||
return mState == 0;
|
||||
}
|
||||
|
||||
bool constexpr isTopK() const
|
||||
{
|
||||
return anyBitSet(kTopK);
|
||||
}
|
||||
|
||||
bool constexpr isTopP() const
|
||||
{
|
||||
return anyBitSet(kTopP);
|
||||
}
|
||||
|
||||
bool constexpr isTopKorTopP() const
|
||||
{
|
||||
return anyBitSet(kTopKTopP);
|
||||
}
|
||||
|
||||
bool constexpr isTopKandTopP() const
|
||||
{
|
||||
return allBitSet(kTopKTopP);
|
||||
}
|
||||
|
||||
bool constexpr isBeamSearch() const
|
||||
{
|
||||
return anyBitSet(kBeamSearch);
|
||||
}
|
||||
|
||||
bool constexpr isMedusa() const
|
||||
{
|
||||
return anyBitSet(kMedusa);
|
||||
}
|
||||
|
||||
using UnderlyingType = uint8_t;
|
||||
|
||||
bool operator==(DecodingMode const& other) const
|
||||
{
|
||||
return mState == other.mState;
|
||||
}
|
||||
|
||||
static DecodingMode fromExecutor(executor::DecodingMode decodingMode)
|
||||
{
|
||||
switch (decodingMode)
|
||||
{
|
||||
case executor::DecodingMode::kNONE: return DecodingMode::None();
|
||||
|
||||
case executor::DecodingMode::kTOP_K: return DecodingMode::TopK();
|
||||
|
||||
case executor::DecodingMode::kTOP_P: return DecodingMode::TopP();
|
||||
|
||||
case executor::DecodingMode::kBEAM_SEARCH: return DecodingMode::BeamSearch();
|
||||
|
||||
case executor::DecodingMode::kMEDUSA: return DecodingMode::Medusa();
|
||||
|
||||
case executor::DecodingMode::kTOP_K_TOP_P: return DecodingMode::TopKTopP();
|
||||
|
||||
default: TLLM_THROW("Invalid decoding mode"); break;
|
||||
}
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, DecodingMode other);
|
||||
|
||||
private:
|
||||
constexpr DecodingMode(UnderlyingType state)
|
||||
: mState(state)
|
||||
{
|
||||
}
|
||||
|
||||
// No mode specified. Config will be determined from the beam width of the first request at runtime
|
||||
// TopKTopP if beamWidth == 1, BeamSearch otherwise
|
||||
static UnderlyingType constexpr kNone{0};
|
||||
static UnderlyingType constexpr kTopK{1u << 0};
|
||||
static UnderlyingType constexpr kTopP{1u << 1};
|
||||
static UnderlyingType constexpr kBeamSearch{1u << 2};
|
||||
static UnderlyingType constexpr kMedusa{1u << 3};
|
||||
static UnderlyingType constexpr kTopKTopP{kTopK | kTopP};
|
||||
|
||||
bool constexpr anyBitSet(UnderlyingType bits) const
|
||||
{
|
||||
return (mState & bits) != 0;
|
||||
}
|
||||
|
||||
bool constexpr allBitSet(UnderlyingType bits) const
|
||||
{
|
||||
return (mState & bits) == bits;
|
||||
}
|
||||
|
||||
UnderlyingType mState{};
|
||||
};
|
||||
|
||||
static_assert(DecodingMode::None().isNone());
|
||||
static_assert(!DecodingMode::None().isTopK());
|
||||
static_assert(!DecodingMode::None().isTopP());
|
||||
static_assert(!DecodingMode::None().isBeamSearch());
|
||||
static_assert(!DecodingMode::None().isMedusa());
|
||||
|
||||
static_assert(DecodingMode::TopK().isTopK());
|
||||
static_assert(DecodingMode::TopK().isTopKorTopP());
|
||||
static_assert(!DecodingMode::TopK().isTopKandTopP());
|
||||
static_assert(!DecodingMode::TopK().isTopP());
|
||||
static_assert(!DecodingMode::TopK().isBeamSearch());
|
||||
static_assert(!DecodingMode::TopK().isMedusa());
|
||||
|
||||
static_assert(DecodingMode::TopP().isTopP());
|
||||
static_assert(DecodingMode::TopP().isTopKorTopP());
|
||||
static_assert(!DecodingMode::TopP().isTopKandTopP());
|
||||
static_assert(!DecodingMode::TopP().isTopK());
|
||||
static_assert(!DecodingMode::TopP().isBeamSearch());
|
||||
static_assert(!DecodingMode::TopP().isMedusa());
|
||||
|
||||
static_assert(DecodingMode::TopKTopP().isTopK());
|
||||
static_assert(DecodingMode::TopKTopP().isTopP());
|
||||
static_assert(DecodingMode::TopKTopP().isTopKorTopP());
|
||||
static_assert(DecodingMode::TopKTopP().isTopKandTopP());
|
||||
static_assert(!DecodingMode::TopKTopP().isBeamSearch());
|
||||
static_assert(!DecodingMode::TopKTopP().isMedusa());
|
||||
|
||||
static_assert(DecodingMode::BeamSearch().isBeamSearch());
|
||||
static_assert(!DecodingMode::BeamSearch().isTopKorTopP());
|
||||
static_assert(!DecodingMode::BeamSearch().isMedusa());
|
||||
|
||||
static_assert(!DecodingMode::Medusa().isTopK());
|
||||
static_assert(!DecodingMode::Medusa().isTopKorTopP());
|
||||
static_assert(!DecodingMode::Medusa().isTopKandTopP());
|
||||
static_assert(!DecodingMode::Medusa().isTopP());
|
||||
static_assert(!DecodingMode::Medusa().isBeamSearch());
|
||||
static_assert(DecodingMode::Medusa().isMedusa());
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace tensorrt_llm
|
||||
@ -88,17 +88,17 @@ public:
|
||||
|
||||
BeamHypotheses beamHypotheses;
|
||||
|
||||
// Medusa
|
||||
class MedusaOutputs
|
||||
// Speculative decoding
|
||||
class SpeculativeDecodingOutputs
|
||||
{
|
||||
public:
|
||||
TensorPtr medusaNextDraftTokens; // [maxBatchSize, maxTokensPerStep]
|
||||
TensorPtr medusaAcceptedTokensLen; // [maxBatchSize]
|
||||
TensorPtr medusaAcceptedLengthsCumSum; // [maxBatchSize + 1]
|
||||
TensorPtr medusaPathsOffsets; // [maxBatchSize * maxNumHeads]
|
||||
TensorPtr nextDraftTokens; // [maxBatchSize, maxDraftTokens]
|
||||
TensorPtr acceptedTokensLen; // [maxBatchSize]
|
||||
TensorPtr acceptedLengthsCumSum; // [maxBatchSize + 1]
|
||||
TensorPtr pathsOffsets; // [maxBatchSize, maxAcceptedDraftTokensPerStep]
|
||||
};
|
||||
|
||||
std::optional<MedusaOutputs> medusaOutputs;
|
||||
std::optional<SpeculativeDecodingOutputs> speculativeDecodingOutputs;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -16,10 +16,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/executor/types.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/decodingInput.h"
|
||||
#include "tensorrt_llm/runtime/decodingMode.h"
|
||||
#include "tensorrt_llm/runtime/decodingOutput.h"
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
#include "tensorrt_llm/runtime/samplingConfig.h"
|
||||
@ -50,14 +50,14 @@ public:
|
||||
|
||||
virtual ~IGptDecoder() = default;
|
||||
|
||||
virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType32 maxSequenceLength,
|
||||
virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize,
|
||||
std::optional<TensorPtr> const& batchSlots = std::nullopt)
|
||||
= 0;
|
||||
|
||||
virtual bool forward(DecodingOutput& output, DecodingInput const& input) = 0;
|
||||
|
||||
virtual void forwardAsync(DecodingOutput& output, DecodingInput const& input) = 0;
|
||||
|
||||
virtual void forwardSync(DecodingOutput& output, DecodingInput const& input) = 0;
|
||||
|
||||
virtual void gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput,
|
||||
DecodingInput const& decodingInput, BufferManager const& manager)
|
||||
= 0;
|
||||
@ -74,10 +74,10 @@ public:
|
||||
SizeType32 vocabSize, SizeType32 vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
|
||||
curandState_t* curandState, BufferManager::CudaStreamPtr const& stream);
|
||||
|
||||
static std::unique_ptr<IGptDecoder> create(DecodingMode const& mode, nvinfer1::DataType dtype, size_t maxBatchSize,
|
||||
size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
|
||||
static std::unique_ptr<IGptDecoder> create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
|
||||
size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
|
||||
BufferManager::CudaStreamPtr const& stream, std::optional<runtime::SizeType32> maxTokensPerStep = std::nullopt,
|
||||
std::optional<runtime::SizeType32> maxNumMedusaHeads = std::nullopt);
|
||||
std::optional<runtime::SizeType32> maxAcceptedDraftTokensPerStep = std::nullopt);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@ -88,18 +88,18 @@ public:
|
||||
using CudaStreamPtr = BufferManager::CudaStreamPtr;
|
||||
using TensorPtr = std::shared_ptr<ITensor>;
|
||||
|
||||
GptDecoder(DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize,
|
||||
GptDecoder(executor::DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize,
|
||||
size_t vocabSizePadded, size_t maxSequenceLength, CudaStreamPtr const& stream,
|
||||
std::optional<runtime::SizeType32> maxTokensPerStep = std::nullopt,
|
||||
std::optional<runtime::SizeType32> maxNumMedusaHeads = std::nullopt);
|
||||
std::optional<runtime::SizeType32> maxAcceptedDraftTokensPerStep = std::nullopt);
|
||||
|
||||
void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType32 maxSequenceLength,
|
||||
void setup(SamplingConfig const& samplingConfig, size_t batchSize,
|
||||
std::optional<TensorPtr> const& batchSlots = std::nullopt) override;
|
||||
|
||||
bool forward(DecodingOutput& output, DecodingInput const& input) override;
|
||||
|
||||
void forwardAsync(DecodingOutput& output, DecodingInput const& input) override;
|
||||
|
||||
void forwardSync(DecodingOutput& output, DecodingInput const& input) override;
|
||||
|
||||
void gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput, DecodingInput const& decodingInput,
|
||||
BufferManager const& manager) override;
|
||||
|
||||
@ -119,19 +119,19 @@ private:
|
||||
size_t mMaxBatchSize;
|
||||
};
|
||||
|
||||
inline std::unique_ptr<IGptDecoder> IGptDecoder::create(DecodingMode const& mode, nvinfer1::DataType dtype,
|
||||
inline std::unique_ptr<IGptDecoder> IGptDecoder::create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
|
||||
size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
|
||||
BufferManager::CudaStreamPtr const& stream, std::optional<runtime::SizeType32> maxTokensPerStep,
|
||||
std::optional<runtime::SizeType32> maxNumMedusaHeads)
|
||||
std::optional<runtime::SizeType32> maxAcceptedDraftTokensPerStep)
|
||||
{
|
||||
switch (dtype)
|
||||
{
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
return std::make_unique<GptDecoder<float>>(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
|
||||
maxSequenceLength, stream, maxTokensPerStep, maxNumMedusaHeads);
|
||||
maxSequenceLength, stream, maxTokensPerStep, maxAcceptedDraftTokensPerStep);
|
||||
case nvinfer1::DataType::kHALF:
|
||||
return std::make_unique<GptDecoder<half>>(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
|
||||
maxSequenceLength, stream, maxTokensPerStep, maxNumMedusaHeads);
|
||||
maxSequenceLength, stream, maxTokensPerStep, maxAcceptedDraftTokensPerStep);
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
@ -41,7 +41,7 @@ public:
|
||||
GptDecoderBatch(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream);
|
||||
|
||||
//! Setup the decoder before calling `forward()`
|
||||
void setup(DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
|
||||
void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
|
||||
SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, SizeType32 maxSequenceLength,
|
||||
SizeType32 maxTokensPerStep, bool fusedDecoder, nvinfer1::DataType dtype,
|
||||
ModelConfig const& modelConfig) override;
|
||||
@ -56,6 +56,9 @@ public:
|
||||
|
||||
void forwardSync(decoder_batch::Token const& token) override;
|
||||
|
||||
void forwardSync(
|
||||
decoder_batch::Token const& token, decoder_batch::Output& output, decoder_batch::Input const& input) override;
|
||||
|
||||
void forwardAsync(decoder::Output& output, decoder::Input const& input) override;
|
||||
|
||||
void forwardSync() override;
|
||||
@ -154,22 +157,22 @@ public:
|
||||
return mFinishedSum;
|
||||
}
|
||||
|
||||
//! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu
|
||||
//! @returns [batchSize, maxDraftTokens], predicted draft tokens for next step, on gpu
|
||||
[[nodiscard]] TensorPtr getNextDraftTokens() const override
|
||||
{
|
||||
return mJointDecodingOutput->medusaOutputs->medusaNextDraftTokens;
|
||||
return mJointDecodingOutput->speculativeDecodingOutputs->nextDraftTokens;
|
||||
}
|
||||
|
||||
//! @returns [batchSize + 1], exclusive sum of accepted draft token lengths, on gpu
|
||||
[[nodiscard]] TensorPtr getMedusaAcceptedLengthsCumSum() const override
|
||||
[[nodiscard]] TensorPtr getSpecDecodingAcceptedLengthsCumSum() const override
|
||||
{
|
||||
return mJointDecodingOutput->medusaOutputs->medusaAcceptedLengthsCumSum;
|
||||
return mJointDecodingOutput->speculativeDecodingOutputs->acceptedLengthsCumSum;
|
||||
}
|
||||
|
||||
//! @returns [batchSize * maxMedusaHeads], accepted paths packed into continuous tensor, on gpu
|
||||
[[nodiscard]] TensorPtr getMedusaAcceptedPackedPaths() const override
|
||||
//! @returns [batchSize, maxAcceptedDraftTokensPerStep], accepted paths packed into continuous tensor, on gpu
|
||||
[[nodiscard]] TensorPtr getSpecDecodingAcceptedPackedPaths() const override
|
||||
{
|
||||
return mJointDecodingOutput->medusaOutputs->medusaPathsOffsets;
|
||||
return mJointDecodingOutput->speculativeDecodingOutputs->pathsOffsets;
|
||||
}
|
||||
|
||||
private:
|
||||
@ -189,16 +192,30 @@ private:
|
||||
void newRequestSpeculativeDecoding(
|
||||
SizeType32 batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);
|
||||
|
||||
//! @brief Setups decoder internal tensors for new request in Draft model Sps mode
|
||||
void newRequestDraftTokensExternal(
|
||||
SizeType32 batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);
|
||||
|
||||
//! @brief Setups decoder internal tensors for new Medusa request
|
||||
void newRequestMedusa(SizeType32 batchIdx, decoder_batch::Request const& request);
|
||||
|
||||
//! @brief Asynchronously calls unfused decoder for whole batch in loop
|
||||
void forwardAsyncUnfusedDecoder(
|
||||
SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input, CudaEvent const& eventStart);
|
||||
//! @brief Setups decoder internal tensors for new Lookahead request
|
||||
void newRequestLookahead(SizeType32 batchIdx, decoder_batch::Request const& request);
|
||||
|
||||
//! @brief Asynchronously calls fused decoder for whole batch
|
||||
void forwardAsyncFusedDecoder(
|
||||
SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input, CudaEvent const& eventStart);
|
||||
//! @brief Updates finished state on host for all active requests
|
||||
void updateFinished(decoder_batch::Token const& token);
|
||||
|
||||
//! @brief Calls unfused or fused decoders for tokens per engine step
|
||||
void forwardDispatch(
|
||||
decoder_batch::Output& output, decoder_batch::Input const& input, std::optional<CudaEvent> const& eventStart);
|
||||
|
||||
//! @brief Calls unfused decoder for whole batch in loop
|
||||
void forwardUnfusedDecoder(SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input,
|
||||
std::optional<CudaEvent> const& eventStart);
|
||||
|
||||
//! @brief Calls fused decoder for whole batch
|
||||
void forwardFusedDecoder(SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input,
|
||||
std::optional<CudaEvent> const& eventStart);
|
||||
|
||||
private:
|
||||
std::size_t const mVocabSize;
|
||||
@ -255,6 +272,6 @@ private:
|
||||
SizeType32 mMaxTokensPerDecoderStep{};
|
||||
|
||||
bool mFusedDecoder{false};
|
||||
bool mUseMedusa{false};
|
||||
SpeculativeDecodingMode mSpeculativeDecodingMode;
|
||||
};
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -24,10 +24,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
|
||||
#include "tensorrt_llm/executor/types.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/cudaEvent.h"
|
||||
#include "tensorrt_llm/runtime/decodingMode.h"
|
||||
#include "tensorrt_llm/runtime/generationInput.h"
|
||||
#include "tensorrt_llm/runtime/generationOutput.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
@ -112,7 +112,7 @@ public:
|
||||
// The micro batch size to be used in generation phase.
|
||||
// Batches entered in `GptSession::generation` will be split into smaller micro batches of this size.
|
||||
std::optional<SizeType32> genMicroBatchSize = std::nullopt;
|
||||
std::optional<DecodingMode> decodingMode = std::nullopt;
|
||||
std::optional<executor::DecodingMode> decodingMode = std::nullopt;
|
||||
bool normalizeLogProbs = true;
|
||||
};
|
||||
|
||||
@ -255,7 +255,7 @@ private:
|
||||
void createBuffers(SizeType32 numMicroBatches);
|
||||
void createDecoders(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 maxAttentionWindow,
|
||||
SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, nvinfer1::DataType logitsType, bool decoderPerRequest,
|
||||
SizeType32 numMicroBatches, DecodingMode const& decodingMode);
|
||||
SizeType32 numMicroBatches, executor::DecodingMode const& decodingMode);
|
||||
void createKvCacheManager(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 maxAttentionWindow,
|
||||
SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, KvCacheConfig const& config);
|
||||
void createCustomAllReduceWorkspace(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 maxSequenceLength);
|
||||
|
||||
@ -44,8 +44,6 @@ public:
|
||||
, inputLen(inputLen)
|
||||
, maxNewTokens{maxNewTokens}
|
||||
, endId{endId}
|
||||
, computeCumLogProbs(false)
|
||||
, computeLogProbs(false)
|
||||
, generatedTokensPerEngineStep(1)
|
||||
{
|
||||
}
|
||||
@ -64,11 +62,9 @@ public:
|
||||
TensorPtr badWordsList; // [2, badWordsLength], on gpu
|
||||
TensorPtr stopWordsList; // [2, stopWordsLength], on gpu
|
||||
|
||||
bool computeCumLogProbs; // boolean that controls if cumLogProbs should be computed for that request
|
||||
bool computeLogProbs; // boolean that controls if cumLogProbs should be computed for that request
|
||||
SizeType32 generatedTokensPerEngineStep;
|
||||
TensorPtr medusaPaths; // [tokensPerStep, medusaHeads + 1], on gpu
|
||||
TensorPtr medusaTreeIds; // [tokensPerStep], on gpu
|
||||
TensorPtr medusaPaths; // [maxDraftTokens + 1, maxAcceptedDraftTokensPerStep + 1], on gpu
|
||||
TensorPtr medusaTreeIds; // [maxDraftTokens + 1], on gpu
|
||||
};
|
||||
|
||||
class Input
|
||||
@ -112,7 +108,7 @@ public:
|
||||
TensorConstPtr cacheIndirection; // [batchSize, maxBeamWidth, maxSeqLen] - indices into KV cache of different rays
|
||||
// within one beam for beam search, on gpu
|
||||
std::vector<std::vector<TensorConstPtr>>
|
||||
medusaLogits; // [maxBatchSize][maxNumHeads][tokensPerStep, vocabSizePadded]
|
||||
predictedDraftLogits; // [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded]
|
||||
};
|
||||
|
||||
using Output = decoder::Output;
|
||||
@ -142,6 +138,11 @@ public:
|
||||
//! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
|
||||
virtual TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0;
|
||||
|
||||
//! @brief Call decoder forwardSync and wait for the call to `forwardAsync` associated with a token to complete.
|
||||
virtual void forwardSync(
|
||||
decoder_batch::Token const& token, decoder_batch::Output& output, decoder_batch::Input const& input)
|
||||
= 0;
|
||||
|
||||
//! @brief Wait for the call to `forwardAsync` associated with a token to complete.
|
||||
virtual void forwardSync(decoder_batch::Token const& token) = 0;
|
||||
|
||||
@ -188,10 +189,10 @@ public:
|
||||
virtual TensorPtr getNextDraftTokens() const = 0;
|
||||
|
||||
//! @returns [batchSize + 1], exclusive sum of accepted draft token lengths, on gpu
|
||||
virtual TensorPtr getMedusaAcceptedLengthsCumSum() const = 0;
|
||||
virtual TensorPtr getSpecDecodingAcceptedLengthsCumSum() const = 0;
|
||||
|
||||
//! @returns [batchSize * maxMedusaHeads], accepted paths packed into continuous tensor, on gpu
|
||||
virtual TensorPtr getMedusaAcceptedPackedPaths() const = 0;
|
||||
//! @returns [batchSize, maxAcceptedDraftTokensPerStep], accepted paths packed into continuous tensor, on gpu
|
||||
virtual TensorPtr getSpecDecodingAcceptedPackedPaths() const = 0;
|
||||
|
||||
protected:
|
||||
IGptDecoderBatch() = default;
|
||||
|
||||
@ -16,8 +16,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/executor/types.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/decodingMode.h"
|
||||
#include "tensorrt_llm/runtime/generationInput.h"
|
||||
#include "tensorrt_llm/runtime/generationOutput.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
@ -73,7 +73,7 @@ public:
|
||||
using TensorPtr = std::shared_ptr<ITensor>;
|
||||
|
||||
//! Setup the decoder before calling `forward()`, also calls reshapeBuffers
|
||||
virtual void setup(DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
|
||||
virtual void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
|
||||
SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, SizeType32 maxSequenceLength,
|
||||
SizeType32 maxTokensPerStep, bool fusedDecoder, nvinfer1::DataType dtype, ModelConfig const& modelConfig)
|
||||
= 0;
|
||||
|
||||
@ -55,6 +55,7 @@ private:
|
||||
SizeType32 mTpRank;
|
||||
std::vector<void*> mCommPtrs;
|
||||
BufferPtr mBuffer;
|
||||
bool mOpenIpc;
|
||||
};
|
||||
|
||||
class AllReduceBuffers
|
||||
|
||||
@ -18,8 +18,10 @@

#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/lookaheadModule.h"
#include "tensorrt_llm/runtime/loraModule.h"
#include "tensorrt_llm/runtime/medusaModule.h"
#include "tensorrt_llm/runtime/speculativeDecodingMode.h"
#include <NvInferRuntime.h>

namespace tensorrt_llm::runtime
@ -81,10 +83,10 @@ public:
, mUseXQA{false}
, mUseLoraPlugin(false)
, mMlpHiddenSize(0)
, mMedusaModule(std::nullopt)
, mUseCrossAttention(false)
, mUsePositionEmbedding(true) // TODO: remove these two properties?
, mUseTokenTypeEmbedding(false)
, mSpeculativeDecodingMode(SpeculativeDecodingMode::None())
{
}

@ -441,19 +443,38 @@ public:
mMaxLoraRank = maxLoraRank;
}

[[nodiscard]] bool constexpr useMedusa() const noexcept
void setSpeculativeDecodingMode(SpeculativeDecodingMode mode) noexcept
{
return mMedusaModule.has_value();
mSpeculativeDecodingMode = mode;
}

[[nodiscard]] std::optional<MedusaModule> getMedusaModule() const noexcept
[[nodiscard]] bool hasSpeculativeDecodingModule() const noexcept
{
return mMedusaModule;
return mSpeculativeDecodingModule != nullptr;
}

void setMedusaModule(MedusaModule const& medusaModule) noexcept
[[nodiscard]] SpeculativeDecodingModule const& getSpeculativeDecodingModule() const noexcept
{
mMedusaModule = medusaModule;
TLLM_CHECK_WITH_INFO(mSpeculativeDecodingModule, "Speculative decoding module is not set");
return *mSpeculativeDecodingModule;
}

[[nodiscard]] std::shared_ptr<SpeculativeDecodingModule const> getSpeculativeDecodingModulePtr() const noexcept
{
TLLM_CHECK_WITH_INFO(mSpeculativeDecodingModule, "Speculative decoding module is not set");
return mSpeculativeDecodingModule;
}

[[nodiscard]] std::shared_ptr<SpeculativeDecodingModule> getSpeculativeDecodingModulePtr() noexcept
{
TLLM_CHECK_WITH_INFO(mSpeculativeDecodingModule, "Speculative decoding module is not set");
return mSpeculativeDecodingModule;
}

void setSpeculativeDecodingModule(
std::shared_ptr<SpeculativeDecodingModule> const& speculativeDecodingModule) noexcept
{
mSpeculativeDecodingModule = speculativeDecodingModule;
}

[[nodiscard]] nvinfer1::DataType getKvDataType() const noexcept
@ -508,6 +529,11 @@ public:
mLayerTypes = layerTypes;
}

[[nodiscard]] SpeculativeDecodingMode getSpeculativeDecodingMode() const noexcept
{
return mSpeculativeDecodingMode;
}

private:
SizeType32 mVocabSize;
SizeType32 mNbAttentionLayers;
@ -546,7 +572,6 @@ private:
SizeType32 mMlpHiddenSize;
SizeType32 mMaxLoraRank;

std::optional<MedusaModule> mMedusaModule;
std::optional<RnnConfig> mRnnConfig;

// Configs related to encoder / enc-dec models
@ -556,6 +581,9 @@ private:
SizeType32 mFfnHiddenSize; // indicates encoder output hidden size

std::vector<LayerType> mLayerTypes;
// Speculative decoding members
std::shared_ptr<SpeculativeDecodingModule> mSpeculativeDecodingModule;
SpeculativeDecodingMode mSpeculativeDecodingMode;
};

} // namespace tensorrt_llm::runtime

@ -16,6 +16,7 @@

#pragma once

#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/runtime/common.h"
@ -75,6 +76,35 @@ private:
template <typename T>
using Vec = std::vector<T>;

template <typename T>
bool validateVec(std::string name, OptVec<T> const& vec, T min, std::optional<T> max = std::nullopt)
{
bool valid{true};
if (vec)
{
valid = std::all_of(vec->begin(), vec->end(),
[min, max](T elem)
{ return min < elem && ((max.has_value() && elem <= max.value()) || (!max.has_value())); });
if (!valid)
{
std::stringstream ss;
ss << "Incorrect sampling param. " << name << " is out of range (";
ss << min << ", ";
if (max.has_value())
{
ss << max.value();
}
else
{
ss << "inf";
}
ss << "]";
TLLM_LOG_WARNING(valid, ss.str());
}
}
return valid;
}

public:
explicit SamplingConfig(SizeType32 beamWidth = 1)
: beamWidth{beamWidth}
@ -129,19 +159,24 @@ public:
topKMedusaHeads = fuseValues<std::vector<SizeType32>>(
configs, [&configs](size_t ci) { return configs[ci].topKMedusaHeads; },
layers::DefaultDecodingParams::getTopKMedusaHeads());
outputLogProbs = fuseValues<bool>(
configs, [&configs](size_t ci) { return configs[ci].outputLogProbs; }, false);
cumLogProbs = fuseValues<bool>(
configs, [&configs](size_t ci) { return configs[ci].cumLogProbs; }, false);
// Only used for tests.
draftAcceptanceThreshold = fuseValues<FloatType>(
configs, [&configs](size_t ci) { return configs[ci].draftAcceptanceThreshold; }, 0);
}

explicit SamplingConfig(executor::SamplingConfig const& samplingConfig,
std::optional<executor::SpeculativeDecodingConfig> const& specDecodingConfig)
std::optional<executor::ExternalDraftTokensConfig> const& externalDraftTokensConfig)
: beamWidth{samplingConfig.getBeamWidth()}
{

if (specDecodingConfig && specDecodingConfig.value().getAcceptanceThreshold())
if (externalDraftTokensConfig && externalDraftTokensConfig.value().getAcceptanceThreshold())
{
draftAcceptanceThreshold = Vec<FloatType>{specDecodingConfig.value().getAcceptanceThreshold().value()};
draftAcceptanceThreshold
= Vec<FloatType>{externalDraftTokensConfig.value().getAcceptanceThreshold().value()};
}

#define SET_FROM_OPTIONAL(varName, VarName, VarType) \
@ -168,15 +203,69 @@ public:
#undef SET_FROM_OPTIONAL
}

bool validate()
{
auto constexpr fltEpsilon = std::numeric_limits<float>::epsilon();

bool valid{true};

valid &= (beamWidth > 0);
if (!valid)
{
TLLM_LOG_WARNING(
"Requested beam width %d is incorrect. Must be > 0. To de-activate beam searching set beamWidth to 1.",
beamWidth);
}
valid &= validateVec("topK", topK, -1);
valid &= validateVec("topP", topP, -fltEpsilon, {1.f});
valid &= validateVec("topPMin", topPMin, 0.f, {1.f});
valid &= validateVec("topPDecay", topPDecay, 0.f, {1.f});
valid &= validateVec("topPResetIds", topPResetIds, -1);

valid &= validateVec("temperature", temperature, -fltEpsilon);
valid &= validateVec("repetitionPenalty", repetitionPenalty, 0.f);
valid &= validateVec("minLength", minLength, -1);

valid &= validateVec("beamSearchDiversityRate", beamSearchDiversityRate, -fltEpsilon);

// Detect greedy sampling and overwrite params.
if (temperature)
{
for (size_t ti = 0; ti < temperature->size(); ++ti)
{
if (temperature->at(ti) == 0.f)
{
temperature->at(ti) = 1.0f;

if (topK)
{
topK->at(ti) = 1;
}
if (topP)
{
topP->at(ti) = 1.f;
}
}
}
}

return valid;
}

public:
SizeType32 beamWidth;

// penalties
OptVec<FloatType> temperature; // [1] or [batch_size] on cpu
OptVec<SizeType32> minLength; // [1] or [batch_size] on cpu
OptVec<FloatType> repetitionPenalty; // [1] or [batch_size] on cpu
OptVec<FloatType> presencePenalty; // [1] or [batch_size] on cpu
OptVec<FloatType> frequencyPenalty; // [1] or [batch_size] on cpu

// probs
OptVec<bool> outputLogProbs;
OptVec<bool> cumLogProbs;

// sampling layers
OptVec<SizeType32> topK; // [1] or [batch_size] on cpu
OptVec<FloatType> topP; // [1] or [batch_size] on cpu
@ -207,7 +296,8 @@ public:
&& topPResetIds == other.topPResetIds && beamSearchDiversityRate == other.beamSearchDiversityRate
&& lengthPenalty == other.lengthPenalty && earlyStopping == other.earlyStopping
&& draftAcceptanceThreshold == other.draftAcceptanceThreshold && topKMedusaHeads == other.topKMedusaHeads
&& normalizeLogProbs == other.normalizeLogProbs;
&& normalizeLogProbs == other.normalizeLogProbs && outputLogProbs == other.outputLogProbs
&& cumLogProbs == other.cumLogProbs;
}
};

@ -30,9 +30,9 @@ public:
return SpeculativeDecodingMode{kNone};
}

static auto constexpr DraftModel()
static auto constexpr DraftTokensExternal()
{
return SpeculativeDecodingMode{kDraftModel};
return SpeculativeDecodingMode{kDraftTokensExternal};
}

static auto constexpr Medusa()
@ -50,9 +50,9 @@ public:
return anyBitSet(kNone);
}

bool constexpr isDraftModel() const
bool constexpr isDraftTokensExternal() const
{
return anyBitSet(kDraftModel);
return anyBitSet(kDraftTokensExternal);
}

bool constexpr isMedusa() const
@ -100,7 +100,7 @@ public:
private:
// No speculative decoding is used.
static UnderlyingType constexpr kNone{1u << 0};
static UnderlyingType constexpr kDraftModel{1u << 1};
static UnderlyingType constexpr kDraftTokensExternal{1u << 1};
static UnderlyingType constexpr kMedusa{1u << 2};
static UnderlyingType constexpr kLookaheadDecoding{1u << 3};

@ -118,23 +118,23 @@ private:
};

static_assert(SpeculativeDecodingMode::None().isNone());
static_assert(!SpeculativeDecodingMode::None().isDraftModel());
static_assert(!SpeculativeDecodingMode::None().isDraftTokensExternal());
static_assert(!SpeculativeDecodingMode::None().isMedusa());
static_assert(!SpeculativeDecodingMode::None().isLookaheadDecoding());

static_assert(SpeculativeDecodingMode::DraftModel().isDraftModel());
static_assert(!SpeculativeDecodingMode::DraftModel().isNone());
static_assert(!SpeculativeDecodingMode::DraftModel().isMedusa());
static_assert(!SpeculativeDecodingMode::DraftModel().isLookaheadDecoding());
static_assert(SpeculativeDecodingMode::DraftTokensExternal().isDraftTokensExternal());
static_assert(!SpeculativeDecodingMode::DraftTokensExternal().isNone());
static_assert(!SpeculativeDecodingMode::DraftTokensExternal().isMedusa());
static_assert(!SpeculativeDecodingMode::DraftTokensExternal().isLookaheadDecoding());

static_assert(SpeculativeDecodingMode::Medusa().isMedusa());
static_assert(!SpeculativeDecodingMode::Medusa().isNone());
static_assert(!SpeculativeDecodingMode::Medusa().isDraftModel());
static_assert(!SpeculativeDecodingMode::Medusa().isDraftTokensExternal());
static_assert(!SpeculativeDecodingMode::Medusa().isLookaheadDecoding());

static_assert(SpeculativeDecodingMode::LookaheadDecoding().isLookaheadDecoding());
static_assert(!SpeculativeDecodingMode::LookaheadDecoding().isNone());
static_assert(!SpeculativeDecodingMode::LookaheadDecoding().isDraftModel());
static_assert(!SpeculativeDecodingMode::LookaheadDecoding().isDraftTokensExternal());
static_assert(!SpeculativeDecodingMode::LookaheadDecoding().isMedusa());

} // namespace runtime

@ -82,6 +82,11 @@ public:
return mDeviceIds[mRank % getGpusPerGroup()];
}

[[nodiscard]] SizeType32 getDeviceOf(SizeType32 rank) const noexcept
{
return mDeviceIds[rank % getGpusPerGroup()];
}

[[nodiscard]] SizeType32 constexpr getPipelineParallelRank() const noexcept
{
return mRank / mTensorParallelism;
@ -92,6 +97,21 @@ public:
return mRank % mTensorParallelism;
}

[[nodiscard]] SizeType32 constexpr getLocalRank() const noexcept
{
return mRank % mGpusPerNode;
}

[[nodiscard]] SizeType32 constexpr getNodeRank() const noexcept
{
return mRank / mGpusPerNode;
}

[[nodiscard]] SizeType32 constexpr getNodeRankOf(SizeType32 rank) const noexcept
{
return rank / mGpusPerNode;
}

[[nodiscard]] bool constexpr isFirstPipelineParallelRank() const noexcept
{
return getPipelineParallelRank() == 0;
@ -109,6 +129,7 @@ public:
}

[[nodiscard]] std::vector<SizeType32> getPipelineParallelGroup() const;
[[nodiscard]] std::vector<SizeType32> getTensorParallelGroup() const;

static WorldConfig mpi(SizeType32 gpusPerNode = kDefaultGpusPerNode,
std::optional<SizeType32> tensorParallelism = std::nullopt,

48
cpp/micro_benchmarks/CMakeLists.txt
Normal file
@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.

# Google Benchmark Preparation - Same as google test ../tests/CMakeLists.txt
# Google Benchmark is provided under Apache-2.0 license
include(FetchContent)
FetchContent_Declare(
googlebenchmark
GIT_REPOSITORY https://github.com/google/benchmark.git
GIT_TAG v1.8.3)
FetchContent_MakeAvailable(googlebenchmark)

add_custom_target(micro_benchmarks)

include_directories(
${PROJECT_SOURCE_DIR}/tensorrt_llm/cutlass_extensions/include
${PROJECT_SOURCE_DIR}/include)

set(TOP_LEVEL_DIR "${PROJECT_SOURCE_DIR}/..")

function(add_benchmark test_name test_src)
add_executable(${test_name} ${test_src})

message("Linking with ${SHARED_TARGET}")
target_link_libraries(${test_name} PUBLIC ${SHARED_TARGET}
benchmark::benchmark)

target_compile_features(${test_name} PRIVATE cxx_std_17)
target_compile_definitions(${test_name}
PUBLIC TOP_LEVEL_DIR="${TOP_LEVEL_DIR}")

add_dependencies(micro_benchmarks ${test_name})
endfunction()

add_benchmark(mixtureOfExpertsBackendBenchmark
mixtureOfExpertsBackendBenchmarkLauncher.cu)
36
cpp/micro_benchmarks/README.md
Normal file
@ -0,0 +1,36 @@
# Micro Benchmarks

This folder contains benchmarks for specific components in TRT-LLM,
using [google-benchmark](https://github.com/google/benchmark/tree/main).

## Building

To build, add the `--micro_benchmark` flag to `build_wheel.py` or pass `-DBUILD_MICRO_BENCHMARKS=ON` to CMake.

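For example, a typical invocation might look like the following (a sketch only; the exact location of `build_wheel.py` in your checkout is an assumption):

```bash
# Build the wheel with the micro benchmarks enabled (assumed script path)
python3 scripts/build_wheel.py --micro_benchmark

# Or, for a plain CMake build of the C++ runtime
cmake -DBUILD_MICRO_BENCHMARKS=ON ..
```
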
## Benchmark Documentation

### Mixture Of Experts Backend Benchmark

Target `mixtureOfExpertsBackendBenchmark`

This benchmark covers the backend used by the `MixtureOfExperts` plugin. It allows you to benchmark different MOE
configurations without building a TRT engine.

Usage:

```bash
./mixtureOfExpertsBackendBenchmark

# or

./mixtureOfExpertsBackendBenchmark --benchmark_file <JSON benchmark definition>
```

For more information see:

```
./mixtureOfExpertsBackendBenchmark --help
```

The `gen-moe-workload-file.py` helper script can generate workload files for MOE benchmarks. This is useful
for sharing or comparing configurations, such as when generating a reproduction case for a performance bug.
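
As a sketch of that workflow, using the generator script added in this change and its default output name:

```bash
# Emit a workload file with the default Mixtral-style configurations
python3 gen-moe-benchmark-file.py moe-benchmark-file.json

# Replay it through the benchmark binary
./mixtureOfExpertsBackendBenchmark --benchmark_file moe-benchmark-file.json
```
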
92
cpp/micro_benchmarks/gen-moe-benchmark-file.py
Normal file
@ -0,0 +1,92 @@
import argparse

template = '''{{
    "num_experts": {num_experts},
    "k": {k},
    "hidden_size": {hidden_size},
    "inter_size": {inter_size},
    "tp_size": {tp_size},
    "ep_size": {ep_size},
    "world_rank": {world_rank},
    "num_tokens": {num_tokens},
    "bias": 0,
    "act_fn": {act_fn},
    "norm_mode": {norm_mode},
    {dtype_string}
    {routing_string}
    "tactic_id": {tactic_id}
}}'''


def make_dtype_string(dtypes=None):
    if dtypes is None:
        return ""
    if not isinstance(dtypes, list):
        dtypes = [dtypes]
    join_term = '","'  # Include quotes because they should be strings
    return f'"dtypes": ["{join_term.join(dtypes)}"],'


def make_routing_string(name=None, values=None):
    if values is None and name is None:
        return ""
    if values is None:
        return f'"routing_values": "{name}",'

    values = f'"routing_values": [{",".join(values)}],'
    if name is not None:
        values += f' "routing_values_name": "{name}",'

    return values


def populate_benchmark_config(**kwargs):
    return template.format(**kwargs)


# Default Mixtral configurations
num_experts = 8
k = 2
hidden_size = 4096
inter_size = 14336
tp_size = 4
ep_size = 1
world_rank = 0
act_fn = 3
norm_mode = 1
dtype_string = make_dtype_string()  # All dtypes
routing_string = make_routing_string(
    name="balanced")  # Use the default uniform distribution
tactic_id = '"auto"'

configs = []
for num_tokens in [1, 8, 64, 2048, 65536]:
    configs.append(
        populate_benchmark_config(
            num_experts=num_experts,
            k=k,
            hidden_size=hidden_size,
            inter_size=inter_size,
            tp_size=tp_size,
            ep_size=ep_size,
            world_rank=world_rank,
            num_tokens=num_tokens,
            act_fn=act_fn,
            norm_mode=norm_mode,
            dtype_string=dtype_string,
            routing_string=routing_string,
            tactic_id=tactic_id,
        ))

full_string = "[\n" + ",\n".join(configs) + "\n]"

parser = argparse.ArgumentParser()
parser.add_argument('filename',
                    type=str,
                    help='The name of the file to generate',
                    nargs='?',
                    default="moe-benchmark-file.json")
args = parser.parse_args()

with open(args.filename, "w+") as f:
    f.write(full_string)
482
cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h
Normal file
@ -0,0 +1,482 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <benchmark/benchmark.h>
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h"
|
||||
#include "tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cuda.h>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
using namespace tensorrt_llm::kernels;
|
||||
using namespace tensorrt_llm::common;
|
||||
using namespace tensorrt_llm::runtime;
|
||||
using namespace tensorrt_llm::cutlass_extensions;
|
||||
|
||||
static BufferManager::CudaStreamPtr streamPtr;
|
||||
static std::unique_ptr<BufferManager> bufferManager;
|
||||
static int deviceCount;
|
||||
static char* workloadFile = nullptr;
|
||||
|
||||
constexpr bool VERBOSE = false;
|
||||
|
||||
namespace
|
||||
{
|
||||
// Abstract class for routing config
|
||||
struct RoutingConfig
|
||||
{
|
||||
virtual void setRouting(float* routing_output, int64_t num_experts, int64_t k, int64_t num_tokens) = 0;
|
||||
virtual std::string getName() = 0;
|
||||
};
|
||||
|
||||
struct LoadBalancedRoutingConfig : public RoutingConfig
|
||||
{
|
||||
std::string getName() override
|
||||
{
|
||||
return "balanced";
|
||||
}
|
||||
|
||||
void setRouting(float* routing_output, int64_t num_experts, int64_t k, int64_t num_tokens) override
|
||||
{
|
||||
nvinfer1::DataType type = nvinfer1::DataType::kFLOAT;
|
||||
makeLoadBalancedRoutingConfiguration(routing_output, num_experts, num_tokens, k, type, streamPtr->get());
|
||||
check_cuda_error(cudaStreamSynchronize(streamPtr->get()));
|
||||
}
|
||||
};
|
||||
|
||||
struct VectoredRoutingConfig : public RoutingConfig
|
||||
{
|
||||
std::vector<float> config;
|
||||
std::pair<int64_t, int64_t> shape;
|
||||
std::string name;
|
||||
|
||||
VectoredRoutingConfig(std::vector<float> config, std::pair<int64_t, int64_t> shape, std::string name = "vectored")
|
||||
: config(config)
|
||||
, shape(shape)
|
||||
, name(name)
|
||||
{
|
||||
}
|
||||
|
||||
std::string getName() override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
void setRouting(float* routing_output, int64_t num_experts, int64_t k, int64_t num_tokens) override
|
||||
{
|
||||
assert(shape.second == num_experts);
|
||||
for (int64_t i = 0; i < num_tokens; i += shape.first)
|
||||
{
|
||||
int num_to_copy = std::min(num_tokens - i, shape.first);
|
||||
check_cuda_error(cudaMemcpyAsync(routing_output + i * num_experts, config.data(),
|
||||
num_to_copy * num_experts * sizeof(float), cudaMemcpyHostToDevice, streamPtr->get()));
|
||||
}
|
||||
check_cuda_error(cudaStreamSynchronize(streamPtr->get()));
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace
|
||||
|
||||
constexpr int LOAD_BALANCED_ROUTING_CONFIG = 0;
|
||||
std::vector<std::shared_ptr<RoutingConfig>> routingConfigCache{
|
||||
std::static_pointer_cast<RoutingConfig>(std::make_shared<LoadBalancedRoutingConfig>())};
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
using SafeFP8 = __nv_fp8_e4m3;
|
||||
#else
|
||||
using SafeFP8 = void;
|
||||
#endif
|
||||
|
||||
template <class TypeTuple_>
|
||||
class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
|
||||
{
|
||||
public:
|
||||
using DataType = typename TypeTuple_::DataType;
|
||||
using WeightType = typename TypeTuple_::WeightType;
|
||||
using OutputType = typename TypeTuple_::OutputType;
|
||||
constexpr static bool INT4 = std::is_same_v<WeightType, cutlass::uint4b_t>;
|
||||
constexpr static bool FP8 = std::is_same_v<DataType, SafeFP8>;
|
||||
constexpr static bool INT_QUANT = !std::is_same_v<DataType, WeightType>;
|
||||
using WeightStorage = std::conditional_t<INT_QUANT, uint8_t, WeightType>;
|
||||
constexpr static int WEIGHT_ELEM_PER_BYTE = INT4 ? 2 : 1;
|
||||
int const BASE_HIDDEN_SIZE = 64 / sizeof(WeightType) * WEIGHT_ELEM_PER_BYTE;
|
||||
|
||||
std::vector<BufferManager::IBufferPtr> managed_buffers;
|
||||
float* mInputProbabilities{};
|
||||
DataType* mInputTensor{};
|
||||
|
||||
int64_t mHiddenSize{};
|
||||
int64_t mNumExperts{};
|
||||
int64_t mK{};
|
||||
|
||||
constexpr static nvinfer1::DataType toDTypeID()
|
||||
{
|
||||
if (FP8)
|
||||
return nvinfer1::DataType::kFP8;
|
||||
if (INT_QUANT && INT4)
|
||||
return nvinfer1::DataType::kUINT8; // Hack to distinguish int4, use unsigned
|
||||
if (INT_QUANT)
|
||||
return nvinfer1::DataType::kINT8;
|
||||
if (std::is_same_v<DataType, float>)
|
||||
return nvinfer1::DataType::kFLOAT;
|
||||
if (std::is_same_v<DataType, half>)
|
||||
return nvinfer1::DataType::kHALF;
|
||||
#ifdef ENABLE_BF16
|
||||
if (std::is_same_v<DataType, nv_bfloat16>)
|
||||
return nvinfer1::DataType::kBF16;
|
||||
#endif
|
||||
return nvinfer1::DataType::kBOOL;
|
||||
};
|
||||
|
||||
static bool shouldSkip()
|
||||
{
|
||||
#ifndef ENABLE_FP8
|
||||
static_assert(!FP8, "FP8 Tests enabled on unsupported CUDA version");
|
||||
#endif
|
||||
bool should_skip_unsupported_fp8 = getSMVersion() < 90 && FP8;
|
||||
return should_skip_unsupported_fp8;
|
||||
}
|
||||
|
||||
// Deprecated, just here to suppress warnings
|
||||
void SetUp(benchmark::State const& s) override
|
||||
{
|
||||
abort();
|
||||
}
|
||||
|
||||
void TearDown(benchmark::State const& s) override
|
||||
{
|
||||
abort();
|
||||
}
|
||||
|
||||
cudaEvent_t mStartEvent, mEndEvent;
|
||||
|
||||
void SetUp(benchmark::State& s) override
|
||||
{
|
||||
assert(bufferManager);
|
||||
if (shouldSkip())
|
||||
{
|
||||
s.SkipWithMessage("GPU does not support FP8");
|
||||
}
|
||||
|
||||
// Makes sure nothing from a previous iteration hangs around
|
||||
check_cuda_error(cudaDeviceSynchronize());
|
||||
check_cuda_error(cudaEventCreate(&mStartEvent));
|
||||
check_cuda_error(cudaEventCreate(&mEndEvent));
|
||||
}
|
||||
|
||||
void TearDown(benchmark::State& s) override
|
||||
{
|
||||
managed_buffers.clear();
|
||||
|
||||
check_cuda_error(cudaEventDestroy(mStartEvent));
|
||||
check_cuda_error(cudaEventDestroy(mEndEvent));
|
||||
check_cuda_error(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
CutlassMoeFCRunner<DataType, WeightType, OutputType> mMoERunner{};
|
||||
char* mWorkspace{};
|
||||
float* mScaleProbs{};
|
||||
WeightStorage* mExpertWeight1{};
|
||||
WeightStorage* mExpertWeight2{};
|
||||
DataType* mExpertIntScale1{};
|
||||
DataType* mExpertIntScale2{};
|
||||
|
||||
float* mExpertFP8Scale1{};
|
||||
float* mExpertFP8Scale2{};
|
||||
float* mExpertFP8Scale3{};
|
||||
|
||||
DataType* mExpertBias1{};
|
||||
DataType* mExpertBias2{};
|
||||
|
||||
OutputType* mFinalOutput{};
|
||||
int* mSourceToExpandedMap;
|
||||
int* mSelectedExpert;
|
||||
int64_t mInterSize{};
|
||||
int64_t mTotalTokens{};
|
||||
|
||||
bool mUseBias = true;
|
||||
|
||||
bool mIsGated = false;
|
||||
int mGatedMultiplier = 1;
|
||||
|
||||
tensorrt_llm::ActivationType mActType = tensorrt_llm::ActivationType::Relu;
|
||||
MOEExpertScaleNormalizationMode mNormMode = MOEExpertScaleNormalizationMode::NONE;
|
||||
|
||||
QuantParams mQuantParams{};
|
||||
|
||||
std::optional<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> mSelectedConfig = std::nullopt;
|
||||
|
||||
template <class T>
|
||||
T* allocBuffer(size_t size)
|
||||
{
|
||||
auto i_buffer = bufferManager->gpu(size * sizeof(T));
|
||||
check_cuda_error(cudaGetLastError());
|
||||
managed_buffers.emplace_back(std::move(i_buffer));
|
||||
T* ptr = static_cast<T*>(managed_buffers.back()->data());
|
||||
check_cuda_error(cudaMemsetAsync(ptr, 0x0, size * sizeof(T), streamPtr->get()));
|
||||
return ptr;
|
||||
}
|
||||
|
||||
void initBuffersPermute(int64_t num_tokens, int64_t hidden_size, int64_t inter_size, int64_t num_experts, int64_t k,
|
||||
int64_t routing_config)
|
||||
{
|
||||
assert(hidden_size % BASE_HIDDEN_SIZE == 0);
|
||||
|
||||
managed_buffers.clear();
|
||||
|
||||
mTotalTokens = num_tokens;
|
||||
mHiddenSize = hidden_size;
|
||||
mInterSize = inter_size;
|
||||
mNumExperts = num_experts;
|
||||
mK = k;
|
||||
mIsGated = tensorrt_llm::isGatedActivation(mActType);
|
||||
mGatedMultiplier = mIsGated ? 2 : 1;
|
||||
auto const gated_inter = mInterSize * mGatedMultiplier;
|
||||
|
||||
size_t workspace_size
|
||||
= mMoERunner.getWorkspaceSize(mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mActType, {});
|
||||
|
||||
mWorkspace = allocBuffer<char>(workspace_size);
|
||||
size_t const expert_matrix_size = mNumExperts * mHiddenSize * mInterSize;
|
||||
|
||||
mExpertWeight1
|
||||
= allocBuffer<WeightStorage>(expert_matrix_size * mGatedMultiplier / WEIGHT_ELEM_PER_BYTE - 8192);
|
||||
mExpertWeight2 = allocBuffer<WeightStorage>(expert_matrix_size / WEIGHT_ELEM_PER_BYTE - 8192);
|
||||
|
||||
mExpertBias1 = nullptr;
|
||||
mExpertBias2 = nullptr;
|
||||
if (mUseBias)
|
||||
{
|
||||
mExpertBias1 = allocBuffer<DataType>(mNumExperts * gated_inter);
|
||||
mExpertBias2 = allocBuffer<DataType>(mNumExperts * mHiddenSize);
|
||||
}
|
||||
|
||||
if constexpr (INT_QUANT)
|
||||
{
|
||||
mExpertIntScale1 = allocBuffer<DataType>(mNumExperts * gated_inter);
|
||||
mExpertIntScale2 = allocBuffer<DataType>(mNumExperts * mHiddenSize);
|
||||
|
||||
mQuantParams = QuantParams::Int(mExpertIntScale1, mExpertIntScale2);
|
||||
}
|
||||
else if constexpr (FP8)
|
||||
{
|
||||
mExpertFP8Scale1 = allocBuffer<float>(mNumExperts);
|
||||
mExpertFP8Scale2 = allocBuffer<float>(1);
|
||||
mExpertFP8Scale3 = allocBuffer<float>(mNumExperts);
|
||||
|
||||
mQuantParams = QuantParams::FP8(mExpertFP8Scale1, mExpertFP8Scale2, mExpertFP8Scale3);
|
||||
}
|
||||
|
||||
mInputProbabilities = allocBuffer<float>(mTotalTokens * mNumExperts);
|
||||
mScaleProbs = allocBuffer<float>(mTotalTokens * mK);
|
||||
mInputTensor = allocBuffer<DataType>(mTotalTokens * mHiddenSize);
|
||||
mFinalOutput = allocBuffer<OutputType>(mTotalTokens * mHiddenSize);
|
||||
|
||||
mSourceToExpandedMap = allocBuffer<int>(mTotalTokens * mK);
|
||||
mSelectedExpert = allocBuffer<int>(mTotalTokens * mK);
|
||||
|
||||
auto tactic = routingConfigCache.at(routing_config);
|
||||
tactic->setRouting(mInputProbabilities, mNumExperts, mK, mTotalTokens);
|
||||
|
||||
check_cuda_error(cudaStreamSynchronize(streamPtr->get()));
|
||||
}
|
||||
|
||||
float benchmarkLoop(MOEParallelismConfig parallelism_config)
|
||||
{
|
||||
check_cuda_error(cudaEventRecord(mStartEvent, streamPtr->get()));
|
||||
runMoEPermute(parallelism_config);
|
||||
check_cuda_error(cudaEventRecord(mEndEvent, streamPtr->get()));
|
||||
check_cuda_error(cudaStreamSynchronize(streamPtr->get()));
|
||||
|
||||
float ms;
|
||||
check_cuda_error(cudaEventElapsedTime(&ms, mStartEvent, mEndEvent));
|
||||
return ms;
|
||||
}
|
||||
|
||||
// An imprecise benchmark pass for picking the best tactic.
|
||||
// Runs for 3 iterations or 1 second and picks the best option
|
||||
int pickBestTactic(MOEParallelismConfig parallelism_config)
|
||||
{
|
||||
auto tactics = mMoERunner.getTactics();
|
||||
|
||||
float best_time = INFINITY;
|
||||
int best_idx = -1;
|
||||
for (int tidx = 0; tidx < tactics.size(); tidx++)
|
||||
{
|
||||
try
|
||||
{
|
||||
// Set the tactic
|
||||
auto const& t = tactics[tidx];
|
||||
mMoERunner.setTactic(t);
|
||||
|
||||
// Warm-Up run
|
||||
benchmarkLoop(parallelism_config);
|
||||
|
||||
float const max_time_ms = 1000.f;
|
||||
int const max_iters = 3;
|
||||
float time = 0.f;
|
||||
int iter = 0;
|
||||
while (iter < max_iters && time < max_time_ms)
|
||||
{
|
||||
time += benchmarkLoop(parallelism_config);
|
||||
iter++;
|
||||
}
|
||||
// Get average time per iteration
|
||||
time /= static_cast<float>(iter);
|
||||
|
||||
// Update the best
|
||||
if (time < best_time)
|
||||
{
|
||||
best_idx = tidx;
|
||||
best_time = time;
|
||||
}
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
// Sync to tidy up
|
||||
if (VERBOSE)
|
||||
std::cout << "Tactic failed to run with: " << e.what() << std::endl;
|
||||
check_cuda_error(cudaDeviceSynchronize());
|
||||
// skip invalid tactic
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if all tactics failed
|
||||
if (best_idx < 0)
|
||||
return -1;
|
||||
|
||||
auto const& best_tactic = tactics[best_idx];
|
||||
mMoERunner.setTactic(best_tactic);
|
||||
return best_idx;
|
||||
}
|
||||
|
||||
int setTactic(int tactic_idx, MOEParallelismConfig parallelism_config)
|
||||
{
|
||||
if (tactic_idx == -1)
|
||||
{
|
||||
return pickBestTactic(parallelism_config);
|
||||
}
|
||||
|
||||
auto tactics = mMoERunner.getTactics();
|
||||
if (tactic_idx < 0 || tactic_idx >= tactics.size())
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
auto selected_tactic = tactics[tactic_idx];
|
||||
mMoERunner.setTactic(selected_tactic);
|
||||
return tactic_idx;
|
||||
}
|
||||
|
||||
void runMoEPermute(MOEParallelismConfig parallelism_config)
|
||||
{
|
||||
auto stream = streamPtr->get();
|
||||
mMoERunner.runMoe(mInputTensor, mInputProbabilities, mExpertWeight1, mExpertBias1, mActType, mExpertWeight2,
|
||||
mExpertBias2, mQuantParams, mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace,
|
||||
mFinalOutput, nullptr, mTotalTokens, mScaleProbs, mSourceToExpandedMap, mSelectedExpert, parallelism_config,
|
||||
mNormMode, stream);
|
||||
}
|
||||
|
||||
void runBenchmark(benchmark::State& state);
|
||||
};
|
||||
|
||||
template <class TypeTuple_>
|
||||
void MixtureOfExpertsBenchmark<TypeTuple_>::runBenchmark(benchmark::State& state)
|
||||
{
|
||||
int const num_experts = state.range(0);
|
||||
int const top_k = state.range(1);
|
||||
int const hidden_size = state.range(2);
|
||||
int const inter_size = state.range(3);
|
||||
int const tp_size = state.range(4);
|
||||
int const ep_size = state.range(5);
|
||||
int const world_rank = state.range(6);
|
||||
int const num_tokens = state.range(7);
|
||||
mUseBias = state.range(8);
|
||||
mActType = static_cast<tensorrt_llm::ActivationType>(state.range(9));
|
||||
mNormMode = static_cast<MOEExpertScaleNormalizationMode>(state.range(10));
|
||||
int tactic_idx = state.range(11);
|
||||
int const routing_config = state.range(12);
|
||||
|
||||
state.counters["num_experts"] = num_experts;
|
||||
state.counters["top_k"] = top_k;
|
||||
state.counters["hidden_size"] = hidden_size;
|
||||
state.counters["inter_size"] = inter_size;
|
||||
state.counters["tp_size"] = tp_size;
|
||||
state.counters["ep_size"] = ep_size;
|
||||
state.counters["world_rank"] = world_rank;
|
||||
state.counters["num_tokens"] = num_tokens;
|
||||
state.counters["use_bias"] = (int) mUseBias;
|
||||
state.counters["act_fn"] = (int) mActType;
|
||||
state.counters["norm_mode"] = (int) mNormMode;
|
||||
state.counters["routing_config"] = (int) routing_config;
|
||||
state.counters["dtype"] = (int) toDTypeID();
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "Experts,K,Hidden,Inter,TP,EP,Rank,Tokens,Bias,Actfn,Norm Mode,Tactic,Routing=";
|
||||
for (auto v : {num_experts, top_k, hidden_size, inter_size, tp_size, ep_size, world_rank, num_tokens,
|
||||
(int) mUseBias, (int) mActType, (int) mNormMode, tactic_idx})
|
||||
{
|
||||
ss << v << ",";
|
||||
}
|
||||
ss << routingConfigCache.at(routing_config)->getName();
|
||||
// state.SetLabel(ss.str());
|
||||
|
||||
// Always use EP size for moe config until we support TP+EP, we just divide the inter size for TP
|
||||
MOEParallelismConfig parallelism_config = MOEParallelismConfig::ExpertParallelism(ep_size, world_rank / tp_size);
|
||||
initBuffersPermute(num_tokens, hidden_size, inter_size / tp_size, num_experts, top_k, routing_config);
|
||||
|
||||
// Parse the tactic, does checks for "auto" mode and out of range
|
||||
tactic_idx = setTactic(tactic_idx, parallelism_config);
|
||||
if (tactic_idx < 0)
|
||||
{
|
||||
state.SkipWithMessage("Out of range tactic");
|
||||
return;
|
||||
}
|
||||
if (VERBOSE)
|
||||
{
|
||||
auto tactics = mMoERunner.getTactics();
|
||||
std::cout << "Selected " << tactic_idx << "/" << tactics.size() << "\n"
|
||||
<< tactics[tactic_idx].toString() << std::endl;
|
||||
}
|
||||
state.counters["tactic_idx"] = tactic_idx;
|
||||
|
||||
for (auto _ : state)
|
||||
{
|
||||
float ms = benchmarkLoop(parallelism_config);
|
||||
state.SetIterationTime(ms / 1000.f);
|
||||
}
|
||||
|
||||
state.SetItemsProcessed(state.iterations() * num_tokens);
|
||||
|
||||
// Cleanup all the benchmark state
|
||||
managed_buffers.clear();
|
||||
check_cuda_error(cudaDeviceSynchronize());
|
||||
}
|
||||
786
cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu
Normal file
@ -0,0 +1,786 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Include the fixture with the actual benchmark code
|
||||
#include "mixtureOfExpertsBackendBenchmarkFixture.h"
|
||||
|
||||
/*
|
||||
* Below is all the setup for parameterising the benchmarks
|
||||
*/
|
||||
|
||||
template <class DataType_, class WeightType_ = DataType_, class OutputType_ = DataType_>
|
||||
struct WeightParams
|
||||
{
|
||||
using DataType = DataType_;
|
||||
using WeightType = WeightType_;
|
||||
using OutputType = OutputType_;
|
||||
};
|
||||
|
||||
#define BENCHMARK_BASIC(atype, wtype, otype) \
|
||||
BENCHMARK_TEMPLATE_DEFINE_F(MixtureOfExpertsBenchmark, Basic_##atype##_##wtype, WeightParams<atype, wtype, otype>) \
|
||||
(benchmark::State & state) \
|
||||
{ \
|
||||
runBenchmark(state); \
|
||||
}
|
||||
|
||||
#define BENCHMARK_BASIC_DO_REGISTER(atype, wtype, otype) \
|
||||
BENCHMARK_REGISTER_F(MixtureOfExpertsBenchmark, Basic_##atype##_##wtype) \
|
||||
->Apply(argGen<MixtureOfExpertsBenchmark<WeightParams<atype, wtype, otype>>>)
|
||||
|
||||
template <class BenchClass>
|
||||
auto listAllTactics()
|
||||
{
|
||||
int const sm = getSMVersion();
|
||||
using RunnerType = decltype(BenchClass::mMoERunner);
|
||||
return RunnerType::getTactics(sm);
|
||||
}
|
||||
|
||||
template <class BenchClass>
|
||||
int parseTacticToId(nlohmann::json tactic_config)
|
||||
{
|
||||
bool is_sm90 = tactic_config.at("is_sm90").get<bool>();
|
||||
int tile_shape_id = -1;
|
||||
std::array<int, 3> tile_shape;
|
||||
if (tactic_config.at("tile_shape").is_array())
|
||||
tactic_config.at("tile_shape").get_to(tile_shape);
|
||||
else
|
||||
tile_shape_id = tactic_config.at("tile_shape").get<int>();
|
||||
|
||||
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> confs = listAllTactics<BenchClass>();
|
||||
|
||||
try
|
||||
{
|
||||
for (int i = 0; i < confs.size(); i++)
|
||||
{
|
||||
auto const& c = confs[i];
|
||||
if (c.is_sm90 != is_sm90)
|
||||
continue;
|
||||
|
||||
if (!is_sm90)
|
||||
{
|
||||
int stages = tactic_config.at("stages").get<int>();
|
||||
if (c.stages != stages)
|
||||
continue;
|
||||
}
|
||||
|
||||
if (tile_shape_id != -1)
|
||||
{
|
||||
int comp = is_sm90 ? (int) c.tile_config_sm90 : (int) c.tile_config;
|
||||
if (tile_shape_id != comp)
|
||||
continue;
|
||||
if (is_sm90 && (int) c.cluster_shape != tactic_config.at("cluster_shape").get<int>())
|
||||
continue;
|
||||
|
||||
// Found matching config
|
||||
return i;
|
||||
}
|
||||
|
||||
// Handle if the user provided a shape instead of the enum value
|
||||
if (is_sm90)
|
||||
{
|
||||
using Kv = uint64_t;
|
||||
constexpr static auto K = [](int m, int n) { return (uint64_t(m) << 32) | uint64_t(n); };
|
||||
static std::unordered_map<Kv, CutlassTileConfigSM90> const tile_map{
|
||||
{K(64, 16), CutlassTileConfigSM90::CtaShape64x16x128B},
|
||||
{K(64, 32), CutlassTileConfigSM90::CtaShape64x32x128B},
|
||||
{K(64, 64), CutlassTileConfigSM90::CtaShape64x64x128B},
|
||||
{K(64, 128), CutlassTileConfigSM90::CtaShape64x128x128B},
|
||||
{K(64, 256), CutlassTileConfigSM90::CtaShape64x256x128B},
|
||||
|
||||
{K(128, 16), CutlassTileConfigSM90::CtaShape128x16x128B},
|
||||
{K(128, 32), CutlassTileConfigSM90::CtaShape128x32x128B},
|
||||
{K(128, 64), CutlassTileConfigSM90::CtaShape128x64x128B},
|
||||
{K(128, 128), CutlassTileConfigSM90::CtaShape128x128x128B},
|
||||
{K(128, 256), CutlassTileConfigSM90::CtaShape128x256x128B},
|
||||
};
|
||||
|
||||
if (c.tile_config_sm90 != tile_map.at(K(tile_shape[0], tile_shape[1])))
|
||||
continue;
|
||||
|
||||
static std::unordered_map<Kv, ClusterShape> const cluster_map{
|
||||
// CTA configs for M=64
|
||||
{K(1, 1), ClusterShape::ClusterShape_1x1x1},
|
||||
{K(2, 1), ClusterShape::ClusterShape_2x1x1},
|
||||
{K(1, 2), ClusterShape::ClusterShape_1x2x1},
|
||||
{K(2, 2), ClusterShape::ClusterShape_2x2x1},
|
||||
};
|
||||
|
||||
std::array<int, 3> cluster_shape;
|
||||
tactic_config.at("cluster_shape").get_to(cluster_shape);
|
||||
|
||||
if (c.cluster_shape != cluster_map.at(K(cluster_shape[0], cluster_shape[1])))
|
||||
continue;
|
||||
|
||||
// Found matching config
|
||||
return i;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::array<int, 3> warp_shape;
|
||||
tactic_config.at("warp_shape").get_to(warp_shape);
|
||||
|
||||
using Kv = uint64_t;
|
||||
constexpr static auto K = [](std::array<int, 3> a, std::array<int, 3> b)
|
||||
{
|
||||
uint64_t sum = 0;
|
||||
for (auto v : a)
|
||||
sum = sum * 512 + v;
|
||||
for (auto v : b)
|
||||
sum = sum * 256 + v;
|
||||
return sum;
|
||||
};
|
||||
static std::unordered_map<Kv, CutlassTileConfig> tile_map{
|
||||
{K({128, 128, 8}, {64, 64, 8}), CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8},
|
||||
|
||||
{K({16, 128, 64}, {16, 32, 64}), CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64},
|
||||
{K({32, 128, 64}, {32, 32, 64}), CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64},
|
||||
|
||||
{K({64, 128, 64}, {32, 64, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64},
|
||||
{K({64, 64, 128}, {32, 64, 64}), CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64},
|
||||
{K({64, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64},
|
||||
|
||||
{K({128, 64, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64},
|
||||
{K({128, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64},
|
||||
{K({128, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64},
|
||||
{K({128, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64},
|
||||
{K({128, 256, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64},
|
||||
|
||||
{K({256, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64},
|
||||
|
||||
{K({16, 256, 64}, {16, 64, 64}), CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64}
|
||||
|
||||
};
|
||||
if (c.tile_config != tile_map.at(K(tile_shape, warp_shape)))
|
||||
continue;
|
||||
|
||||
// Found matching config
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (std::out_of_range const& e)
|
||||
{
|
||||
std::cerr << "Warning: error parsing tactic " << tactic_config.dump(2) << std::endl;
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
template <class BenchClass>
|
||||
void parseTacticToVectorID(nlohmann::json& tactic, std::vector<int>& tactic_ids)
|
||||
{
|
||||
|
||||
if (tactic.is_number_integer())
|
||||
{
|
||||
tactic_ids.push_back(tactic.get<int>());
|
||||
}
|
||||
else if (tactic.is_array())
|
||||
{
|
||||
for (auto c : tactic)
|
||||
{
|
||||
parseTacticToVectorID<BenchClass>(c, tactic_ids);
|
||||
}
|
||||
}
|
||||
else if (tactic.is_object())
|
||||
{
|
||||
tactic_ids.push_back(parseTacticToId<BenchClass>(tactic));
|
||||
}
|
||||
else if (tactic.is_string())
|
||||
{
|
||||
assert(tactic.is_string());
|
||||
auto tactic_name = tactic.get<std::string>();
|
||||
if (tactic_name == "all")
|
||||
{
|
||||
auto all_tactics = listAllTactics<BenchClass>();
|
||||
tactic_ids.resize(all_tactics.size());
|
||||
std::iota(tactic_ids.begin(), tactic_ids.end(), 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(tactic.get<std::string>() == "auto");
|
||||
tactic_ids.push_back(-1);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("Invalid tactic format");
|
||||
}
|
||||
}
|
||||
|
||||
// This interdependence of globals could be better, but it works ok for this limited case.
|
||||
std::unordered_map<std::string, std::pair<int, int>> name_info_map{
|
||||
{routingConfigCache[LOAD_BALANCED_ROUTING_CONFIG]->getName(), {-1, LOAD_BALANCED_ROUTING_CONFIG}},
|
||||
};
|
||||
|
||||
int getNameCacheIdx(std::string const& name)
|
||||
{
|
||||
if (name_info_map.find(name) == name_info_map.end())
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
return name_info_map.at(name).second;
|
||||
}
|
||||
|
||||
void setNameCacheIdx(std::string const& name, int id)
|
||||
{
|
||||
name_info_map.at(name).second = id;
|
||||
}
|
||||
|
||||
// This is suboptimal for large benchmark files as we reread it for every data type
|
||||
template <class BenchClass>
|
||||
void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
|
||||
{
|
||||
/*
|
||||
* File schema
|
||||
*
|
||||
* [
|
||||
* {
|
||||
* "num_experts": int,
|
||||
* "k": int,
|
||||
* "hidden_size": int,
|
||||
* "inter_size": int,
|
||||
* "tp_size": int, (optional)
|
||||
* "ep_size": int, (optional)
|
||||
* "world_rank": int, (optional)
|
||||
* "num_tokens": int,
|
||||
* "bias": int,
|
||||
* "act_fn": int,
|
||||
* "norm_mode": int,
|
||||
* "tactic_id": tactic, (see below)
|
||||
* "dtypes": [string, ...], (optional)
|
||||
* "routing_values_name": string, (optional)
|
||||
* "routing_values": [float, ...], or string, (optional, length is a multiple of num_experts)
|
||||
* },
|
||||
* ...
|
||||
* ]
|
||||
*
|
||||
* Explanation:
|
||||
*
|
||||
* - "num_experts" - The number of experts
|
||||
* - "k" - The top k
|
||||
* - "hidden_size" - The hidden size
|
||||
* - "inter_size" - The inter size
|
||||
* - "tp_size" - The TP size
|
||||
* - "ep_size" - The EP size
|
||||
* - "world_rank" - The world rank = ep_rank * tp_size + tp_rank
|
||||
* - "num_tokens" - The total number of tokens to benchmark
|
||||
* - "bias" - If bias should be used, 0 = no bias, 1 = bias
|
||||
* - "act_fn" - The enum value of the activation function. See
|
||||
* "cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
* - "norm_mode" - The normalization mode. 0 = NONE, 1 = RENORM. See
|
||||
* "cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
*
|
||||
* - "tactic_id"
|
||||
* Valid tactics are:
|
||||
* - An object:
|
||||
* {
|
||||
* "is_sm90": bool,
|
||||
* "tile_shape": [int, int, int] or int,
|
||||
* "cluster_shape": [int, int, int] or int, (required for sm90, type must be an int if tile_shape is an int)
|
||||
* "warp_shape": [int, int, int], (required for non-sm90 if tile_shape is an array)
|
||||
* "stages": int, (required for non-sm90)
|
||||
* },
|
||||
* - An integer: corresponds to an index in the tactics array. WARNING this is not stable between test
|
||||
* configurations
|
||||
* - An array: of integers or objects, forms a list of tactics to sweep
|
||||
* - The string "all": This will sweep through all possible tactics
|
||||
* - The string "auto": This runs a short benchmark to pick the fastest tactic before each benchmark case. Useful
|
||||
* for quick perf tests, prefer a full sweep and manually setting the tactic for more accurate results
|
||||
*
|
||||
* - dtypes - A list of dtypes to run this config through.
|
||||
* Allowed values are: fp8, int4, int8, float, half, bfloat16
|
||||
* If this argument is omitted all dtypes will be run. Note, not all tactics are supported for all dtypes,
|
||||
* unsupported tactics will be skipped with a warning.
|
||||
*
|
||||
* - "routing_values_name" - a name to help identify the routing pattern. This can be used by later configs to reuse
|
||||
* the config
|
||||
* - "routing_values" - a flat array of routing values to define a new config, or a string referencing the name of a
|
||||
* previous config. Defaults to "balanced", which is short-hand for a uniform expert distribution
|
||||
* These define the routing values used as input to the moe backend, and is intended to allow comparing different
|
||||
* routing behaviours.
|
||||
* When defining an array, it must have `T*num_experts` floating point values. Each set of
|
||||
* `num_experts` values defines the input for a single token. If `num_tokens` is greater than `T` it will repeat
|
||||
* from the beginning
|
||||
*
|
||||
*/
|
||||
|
||||
std::ifstream file{workloadFile};
|
||||
std::stringstream buffer;
|
||||
buffer << file.rdbuf();
|
||||
auto file_contents = buffer.str();
|
||||
if (VERBOSE)
|
||||
std::cout << "Loaded benchmark file: " << file_contents << std::endl;
|
||||
auto source_data = nlohmann::json::parse(file_contents);
|
||||
|
||||
int i = 0;
|
||||
for (auto run_config : source_data)
|
||||
{
|
||||
if (VERBOSE)
|
||||
std::cout << "Parsing run config: " << run_config.dump(2) << std::endl;
|
||||
std::string config_name = "config_" + std::to_string(i);
|
||||
|
||||
// WARNING: Process the routing configuration immediately, so we can guarantee all configs get processed for all
|
||||
// data types. We should not skip any test cases as a later test config may depend on this config
|
||||
if (run_config.contains("routing_values_name"))
|
||||
{
|
||||
run_config["routing_values_name"].get_to(config_name);
|
||||
if (!run_config.contains("routing_values"))
|
||||
{
|
||||
throw std::invalid_argument("Setting routing value configuration name but missing routing values");
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<int> routing_config;
|
||||
auto res = name_info_map.emplace(config_name, std::pair{i, -1});
|
||||
// We must check i is not equal since this function gets called for each data type
|
||||
if (!res.second && res.first->second.first != i)
|
||||
{
|
||||
throw std::invalid_argument("Redefinition of routing_values_name " + config_name + " at config "
|
||||
+ std::to_string(i) + ". First declared at " + std::to_string(res.first->second.first));
|
||||
}
|
||||
else if (!res.second)
|
||||
{
|
||||
// Reuse the existing config from a previous parse
|
||||
routing_config = getNameCacheIdx(config_name);
|
||||
}
|
||||
i++;
|
||||
|
||||
int num_experts = run_config.at("num_experts").get<int>();
|
||||
|
||||
if (run_config.contains("routing_values") && !routing_config)
|
||||
{
|
||||
if (run_config["routing_values"].is_string())
|
||||
{
|
||||
routing_config = getNameCacheIdx(run_config["routing_values"].get<std::string>());
|
||||
if (routing_config < 0)
|
||||
{
|
||||
throw std::invalid_argument("Invalid routing value, could not find valid config");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (config_name.empty())
|
||||
{
|
||||
throw std::invalid_argument("Explicit routing configurations must specify a name");
|
||||
}
|
||||
std::vector<float> routing_values;
|
||||
run_config["routing_values"].get_to(routing_values);
|
||||
|
||||
int shape = routing_values.size() / num_experts;
|
||||
routingConfigCache.push_back(std::make_shared<VectoredRoutingConfig>(
|
||||
std::move(routing_values), std::pair<int, int>(shape, num_experts), config_name));
|
||||
routing_config = routingConfigCache.size() - 1;
|
||||
}
|
||||
|
||||
auto conf = routingConfigCache[*routing_config];
|
||||
auto conf_vec = std::dynamic_pointer_cast<VectoredRoutingConfig>(conf);
|
||||
if (conf->getName() != "balanced" && (!conf_vec || conf_vec->shape.second != num_experts))
|
||||
{
|
||||
throw std::invalid_argument("Incompatible config selected. Expected " + std::to_string(num_experts)
|
||||
+ " experts in routing configuration. "
|
||||
+ ((conf_vec) ? "Found: " + std::to_string(conf_vec->shape.second)
|
||||
: "Found incompatible routing config type"));
|
||||
}
|
||||
}
|
||||
// Use the selected config or fall back to balanced
|
||||
routing_config = routing_config.value_or(LOAD_BALANCED_ROUTING_CONFIG);
|
||||
setNameCacheIdx(config_name, *routing_config);
|
||||
|
||||
// Filter out the types we don't care about testing
|
||||
if (run_config.contains("dtypes"))
|
||||
{
|
||||
std::vector<std::string> dtypes;
|
||||
run_config["dtypes"].get_to(dtypes);
|
||||
|
||||
auto hasDtype = [&](char const* d)
|
||||
{ return std::any_of(dtypes.begin(), dtypes.end(), [&](auto const& n) { return n == d; }); };
|
||||
|
||||
if (BenchClass::FP8 && !hasDtype("fp8"))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
else if (BenchClass::INT4 && !hasDtype("int4"))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
else if (!BenchClass::INT4 && BenchClass::INT_QUANT && !hasDtype("int8"))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
else if (std::is_same_v<typename BenchClass::WeightType, float> && !hasDtype("float")
|
||||
&& !hasDtype("float32"))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
else if (std::is_same_v<typename BenchClass::WeightType, half> && !hasDtype("float16") && !hasDtype("half"))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
else if (std::is_same_v<typename BenchClass::WeightType, __nv_bfloat16> && !hasDtype("bfloat16")
|
||||
&& !hasDtype("bf16"))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Do this after filtering datatypes as tactics only make sense if we know the data type
|
||||
std::vector<int> tactic_ids{};
|
||||
parseTacticToVectorID<BenchClass>(run_config["tactic_id"], tactic_ids);
|
||||
if (tactic_ids.empty())
|
||||
{
|
||||
std::cerr << "Warning: Skipping benchmark, no such tactic: " << run_config["tactic"].dump(2) << std::endl;
|
||||
static bool printed = false;
|
||||
if (!printed)
|
||||
{
|
||||
printed = true;
|
||||
std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n";
|
||||
auto confs = listAllTactics<BenchClass>();
|
||||
for (auto c : confs)
|
||||
std::cerr << c.toString();
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
auto get_or = [&](auto name, auto def)
|
||||
{ return run_config.contains(name) ? run_config[name].template get<decltype(def)>() : def; };
|
||||
int tp_size = get_or("tp_size", 1);
|
||||
int ep_size = get_or("ep_size", 1);
|
||||
int world_rank = get_or("world_rank", 0);
|
||||
int bias = get_or("bias", 0);
|
||||
|
||||
for (auto tactic_id : tactic_ids)
|
||||
{
|
||||
benchmark->Args({num_experts, //
|
||||
run_config.at("k").get<int>(), //
|
||||
run_config.at("hidden_size").get<int>(), //
|
||||
run_config.at("inter_size").get<int>(), //
|
||||
tp_size, ep_size, world_rank, //
|
||||
run_config.at("num_tokens").get<int>(), //
|
||||
bias, //
|
||||
run_config.at("act_fn").get<int>(), //
|
||||
run_config.at("norm_mode").get<int>(), //
|
||||
tactic_id, //
|
||||
*routing_config});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class BenchClass>
|
||||
void argGenHardcoded(benchmark::internal::Benchmark* benchmark)
|
||||
{
|
||||
auto num_tactics = listAllTactics<BenchClass>().size();
|
||||
for (auto [tp, ep] : std::vector<std::pair<int, int>>{{8, 1}, {1, 8}, {2, 4}, {4, 2}})
|
||||
{
|
||||
for (int i = 0; i < num_tactics; i++)
|
||||
{
|
||||
for (auto tokens : {1, 64, 2048})
|
||||
{
|
||||
benchmark->Args({
|
||||
16, // Experts
|
||||
2, // K
|
||||
15360, // hidden
|
||||
30720, // inter
|
||||
tp, // TP Size
|
||||
ep, // EP Size
|
||||
0, // World Rank
|
||||
tokens, // Num tokens
|
||||
0, // bias
|
||||
(int) tensorrt_llm::ActivationType::Gelu, // Act fn
|
||||
(int) MOEExpertScaleNormalizationMode::RENORMALIZE, // Norm mode
|
||||
i, // Tactic ID. Index into getTactics() function result, see argGenLoadFile() for examples
|
||||
LOAD_BALANCED_ROUTING_CONFIG // Routing configuration id
|
||||
});
|
||||
|
||||
benchmark->Args({
|
||||
16, // Experts
|
||||
2, // K
|
||||
15360, // hidden
|
||||
20480, // inter
|
||||
tp, // TP Size
|
||||
ep, // EP Size
|
||||
0, // World Rank
|
||||
tokens, // Num tokens
|
||||
0, // bias
|
||||
(int) tensorrt_llm::ActivationType::Swiglu, // Act fn
|
||||
(int) MOEExpertScaleNormalizationMode::RENORMALIZE, // Norm mode
|
||||
i, // Tactic ID. Index into getTactics() function result, see argGenLoadFile() for examples
|
||||
LOAD_BALANCED_ROUTING_CONFIG // Routing configuration id
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
    auto num_experts = {1, 8, 9, 64, 65, 257}; // {1, 8, 64, 65, 1024};
    auto top_k = {1, 2, 3, 16};                // {1, 2, 3, 42};
    auto hidden_size = {32};                   // {8, 32, 96, 256, 1024};
    auto inter_size_mul = {4.f};               // {7.f/2.f, 4.f};
    auto num_tokens = {200};                   // {1, 20, 200};
    auto use_bias = {0, 1};                    // {0, 1};
    auto activation_type = {tensorrt_llm::ActivationType::Gelu};
    // {tensorrt_llm::ActivationType::Relu, tensorrt_llm::ActivationType::Gelu,
    //  tensorrt_llm::ActivationType::Silu, tensorrt_llm::ActivationType::Geglu,
    //  tensorrt_llm::ActivationType::Swiglu};
    auto norm_mode = {MOEExpertScaleNormalizationMode::NONE};
    auto cutlass_tactic = {0};                             // {0, 1, 2};
    auto routing_config = {LOAD_BALANCED_ROUTING_CONFIG};  // {0, 1, 2};

    for (auto num_expert : num_experts)
        for (auto k : top_k)
            if (k <= num_expert)
                for (auto size : hidden_size)
                    for (auto inter_mul : inter_size_mul)
                    {
                        auto inter_size = static_cast<int>(size * inter_mul);
                        for (auto tokens : num_tokens)
                            for (auto bias : use_bias)
                                for (auto act : activation_type)
                                    for (auto norm : norm_mode)
                                        for (auto tactic : cutlass_tactic)
                                            for (auto routing : routing_config)
                                                benchmark->Args({num_expert, k, size, inter_size, 1, 1, 0, tokens, bias,
                                                    (int) act, (int) norm, tactic, routing});
                    }
}

template <class BenchClass>
void argGen(benchmark::internal::Benchmark* benchmark)
{
    if (VERBOSE)
    {
        std::cout << "List of all tactics for dtype " << (int) BenchClass::toDTypeID() << ":\n";
        int i = 0;
        for (auto& t : listAllTactics<BenchClass>())
        {
            std::cout << "Tactic " << i << ":\n";
            std::cout << t.toString() << std::endl;

            i++;
        }
    }

    // Generic setup
    benchmark->UseManualTime();
    benchmark->ArgNames({"Num Experts", "K", "Hidden Size", "Inter Size", "TP Size", "EP Size", "World Rank",
        "Num Tokens", "Use Bias", "Activation Function", "Norm Mode", "Tactic ID", "Routing ID"});

    if (workloadFile)
        argGenLoadFile<BenchClass>(benchmark);
    else
        argGenHardcoded<BenchClass>(benchmark);
}

BENCHMARK_BASIC(float, float, float)
BENCHMARK_BASIC(half, half, half)
using uint8 = uint8_t;
BENCHMARK_BASIC(half, uint8, half)
using cutlass::uint4b_t;
BENCHMARK_BASIC(half, uint4b_t, half)
#ifdef ENABLE_BF16
BENCHMARK_BASIC(nv_bfloat16, nv_bfloat16, nv_bfloat16)
#endif
#ifdef ENABLE_FP8
BENCHMARK_BASIC(SafeFP8, SafeFP8, half)
#endif

void delayedRegisterBenchmark()
{
    BENCHMARK_BASIC_DO_REGISTER(half, half, half);
#ifdef ENABLE_FP8
    BENCHMARK_BASIC_DO_REGISTER(SafeFP8, SafeFP8, half);
#endif
    if (workloadFile)
    {
        // Extra ones we don't want for hardcoded runs
        BENCHMARK_BASIC_DO_REGISTER(float, float, float);
        BENCHMARK_BASIC_DO_REGISTER(half, uint8, half);
        BENCHMARK_BASIC_DO_REGISTER(half, uint4b_t, half);
#ifdef ENABLE_BF16
        BENCHMARK_BASIC_DO_REGISTER(nv_bfloat16, nv_bfloat16, nv_bfloat16);
#endif
    }
}

void doCleanup()
{
    bufferManager.reset();
    streamPtr.reset();
}

void help()
{
    std::cout << "Usage: mixtureOfExpertsBackendBenchmark [--input_file <file>] [benchmark options]\n";
    std::cout
        << "--input_file\t\tA JSON file describing the benchmark configurations\n\n"
        << "File schema\n"
           "[\n"
           " {\n"
           " \"num_experts\": int,\n"
           " \"k\": int,\n"
           " \"hidden_size\": int,\n"
           " \"inter_size\": int,\n"
           " \"tp_size\": int, (optional)\n"
           " \"ep_size\": int, (optional)\n"
           " \"world_rank\": int, (optional)\n"
           " \"num_tokens\": int,\n"
           " \"bias\": int,\n"
           " \"act_fn\": int,\n"
           " \"norm_mode\": int,\n"
           " \"tactic_id\": tactic, (see below)\n"
           " \"dtypes\": [string, ...], (optional)\n"
           " \"routing_values_name\": string, (optional)\n"
           " \"routing_values\": [float, ...], or string, (optional, length is a multiple of num_experts)\n"
           " },\n"
           " ...\n"
           "]\n"
           "Explanation:\n"
           "- \"num_experts\" - The number of experts\n"
           "- \"k\" - The top k\n"
           "- \"hidden_size\" - The hidden size\n"
           "- \"inter_size\" - The inter size\n"
           "- \"tp_size\" - The TP size to use\n"
           "- \"ep_size\" - The EP size to use\n"
           "- \"world_rank\" - The world rank = ep_rank * tp_size + tp_rank\n"
           "- \"num_tokens\" - The total number of tokens to benchmark\n"
           "- \"bias\" - If bias should be used, 0 = no bias, 1 = bias\n"
           "- \"act_fn\" - The enum value of the activation function. See\n"
           "\"cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h\"\n"
           "- \"norm_mode\" - The normalization mode. 0 = NONE, 1 = RENORM. See\n"
           "\"cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h\"\n"
           "- \"tactic_id\"\n"
           "Valid tactics are:\n"
           " - An object:\n"
           " {\n"
           " \"is_sm90\": bool,\n"
           " \"tile_shape\": [int, int, int] or int,\n"
           " \"cluster_shape\": [int, int, int] or int, (required for sm90, type must be an int if tile_shape is "
           "an int)\n"
           " \"warp_shape\": [int, int, int], (required for non-sm90 if tile_shape is an array)\n"
           " \"stages\": int, (required for non-sm90)\n"
           " },\n"
           " - An integer: corresponds to an index in the tactics array. WARNING this is not stable between test "
           "configurations\n"
           " - An array: of integers or objects, forms a list of tactics to sweep\n"
           " - The string \"all\": This will sweep through all possible tactics\n"
           " - The string \"auto\": This runs a short benchmark to pick the fastest tactic before each benchmark case. "
           "Useful for quick perf tests; prefer a full sweep and manually setting the tactic for more accurate results\n"
           "- \"dtypes\" - A list of dtypes to run this config through.\n"
           "Allowed values are: fp8, int4, int8, float, half, bfloat16\n"
           "If this argument is omitted all dtypes will be run. Note, not all tactics are supported for all dtypes,\n"
           "unsupported tactics will be skipped with a warning.\n"
           "- \"routing_values_name\" - a name to help identify the routing pattern. This can be used by later "
           "benchmarks to reuse the config\n"
           "- \"routing_values\" - a flat array of routing values to define a new config, or a string referencing "
           "the name of a\n"
           "previous config. Defaults to \"balanced\", which is short-hand for a uniform expert distribution\n"
           "These define the routing values used as input to the moe backend, and are intended to allow comparing "
           "different routing behaviours.\n"
           "When defining an array, it must have `T*num_experts` floating point values. Each set of\n"
           "`num_experts` values defines the input for a single token. If `num_tokens` is greater than `T` it will "
           "repeat from the beginning\n\n";

    std::cout << "benchmark options:\n";
    benchmark::PrintDefaultHelp();
}

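// Illustrative example (not part of the repository): a minimal workload file matching the schema
// documented in help() might look like the following. The integer values for "act_fn" are
// placeholders; look up the real enum values in
// cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h before use.
//
// [
//   {
//     "num_experts": 8,
//     "k": 2,
//     "hidden_size": 4096,
//     "inter_size": 14336,
//     "tp_size": 1,
//     "ep_size": 1,
//     "num_tokens": 256,
//     "bias": 0,
//     "act_fn": 0,
//     "norm_mode": 1,
//     "tactic_id": "auto",
//     "dtypes": ["half", "fp8"]
//   }
// ]
//
// Such a file (hypothetical name moe_workload.json) would then be passed as:
//   ./mixtureOfExpertsBackendBenchmark --input_file moe_workload.json
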
void gbenchCustomHelp()
{
    help();
    // google-benchmark calls exit() so we need to clean up manually
    doCleanup();
}

int parseArgsAndRunBench(int argc, char** argv)
{
    try
    {
        int shift = 0;
        for (int i = 1; i < argc; i++)
        {
            argv[i - shift] = argv[i];
            if (strcmp("--input_file", argv[i]) == 0)
            {
                i += 1;
                if (i == argc)
                {
                    std::cerr << "Missing file name for input_file\n";
                    return -1;
                }
                workloadFile = argv[i];
                if (workloadFile[0] == '-')
                {
                    std::cerr << "Workload file " << workloadFile << " not a valid file name\n";
                    return -2;
                }
                shift += 2;
            }
            else if (strcmp("--help", argv[i]) == 0 || strcmp("-h", argv[i]) == 0)
            {
                help();
                return 0;
            }
        }
        argc -= shift;

        // Delay benchmark registration until we know whether the user passed a config file
        delayedRegisterBenchmark();

        benchmark::Initialize(&argc, argv, &gbenchCustomHelp);

        if (argc > 1)
        {
            help();
            std::cout << std::flush; // Force flush
            // Print the error second, so it's easy to see
            std::cerr << "\nUnrecognised argument: " << argv[1] << std::endl;
            return -4;
        }

        benchmark::RunSpecifiedBenchmarks();
        benchmark::Shutdown();

        return 0;
    }
    catch (std::exception const& e)
    {
        std::cerr << "Exiting benchmarks with exception: " << e.what() << std::endl;
        return -3;
    }
}

int main(int argc, char** argv)
{
    deviceCount = getDeviceCount();
    if (deviceCount < 0)
        return 0;
    streamPtr = std::make_shared<CudaStream>();
    bufferManager = std::make_unique<BufferManager>(streamPtr);

    int res = -1;
    try
    {
        res = parseArgsAndRunBench(argc, argv);
    }
    catch (std::exception const& e)
    {
        std::cout << "Benchmark exited with unhandled exception: " << e.what() << std::endl;
    }

    doCleanup();
    return res;
}

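// Illustrative invocation (an assumption, not taken from this file): standard Google Benchmark
// flags can be combined with the custom --input_file option, e.g.
//   ./mixtureOfExpertsBackendBenchmark --input_file moe_workload.json --benchmark_filter=.*half.*
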
@ -204,12 +204,22 @@ else()
|
||||
"kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/libtensorrt_llm_nvrtc_wrapper.so"
|
||||
)
|
||||
else()
|
||||
set(NVRTC_WRAPPER_LIB_BINARY_REL_DIR
|
||||
"kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper"
|
||||
)
|
||||
set(NVRTC_WRAPPER_DLL_NAME "tensorrt_llm_nvrtc_wrapper.dll")
|
||||
set(NVRTC_WRAPPER_LIB_NAME "tensorrt_llm_nvrtc_wrapper.lib")
|
||||
|
||||
set(NVRTC_WRAPPER_LIB_SOURCE_REL_LOC
|
||||
"kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${NVRTC_WRAPPER_TARGET_ARCH}/libtensorrt_llm_nvrtc_wrapper.dll"
|
||||
"${NVRTC_WRAPPER_LIB_BINARY_REL_DIR}/${NVRTC_WRAPPER_TARGET_ARCH}/${NVRTC_WRAPPER_DLL_NAME}"
|
||||
)
|
||||
set(NVRTC_WRAPPER_LIB_BINARY_REL_LOC
|
||||
"kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/libtensorrt_llm_nvrtc_wrapper.dll"
|
||||
"${NVRTC_WRAPPER_LIB_BINARY_REL_DIR}/${NVRTC_WRAPPER_DLL_NAME}")
|
||||
set(NVRTC_WRAPPER_IMPLIB_SOURCE_REL_LOC
|
||||
"${NVRTC_WRAPPER_LIB_BINARY_REL_DIR}/${NVRTC_WRAPPER_TARGET_ARCH}/${NVRTC_WRAPPER_LIB_NAME}"
|
||||
)
|
||||
set(NVRTC_WRAPPER_IMPLIB_BINARY_REL_LOC
|
||||
"${NVRTC_WRAPPER_LIB_BINARY_REL_DIR}/${NVRTC_WRAPPER_LIB_NAME}")
|
||||
endif()
|
||||
set(NVRTC_WRAPPER_LIB_LOC
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/${NVRTC_WRAPPER_LIB_SOURCE_REL_LOC}")
|
||||
@ -218,6 +228,15 @@ else()
|
||||
${NVRTC_WRAPPER_LIB_BINARY_REL_LOC} COPYONLY)
|
||||
set_property(TARGET ${NVRTC_WRAPPER_TARGET} PROPERTY IMPORTED_LOCATION
|
||||
${NVRTC_WRAPPER_LIB_LOC})
|
||||
if(WIN32)
|
||||
set(NVRTC_WRAPPER_IMPLIB_LOC
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/${NVRTC_WRAPPER_IMPLIB_SOURCE_REL_LOC}")
|
||||
configure_file(${NVRTC_WRAPPER_IMPLIB_SOURCE_REL_LOC}
|
||||
${NVRTC_WRAPPER_IMPLIB_BINARY_REL_LOC} COPYONLY)
|
||||
set_property(TARGET ${NVRTC_WRAPPER_TARGET}
|
||||
PROPERTY IMPORTED_IMPLIB ${NVRTC_WRAPPER_IMPLIB_LOC})
|
||||
endif()
|
||||
|
||||
file(SIZE ${NVRTC_WRAPPER_LIB_LOC} NVRTC_WRAPPER_LIB_SIZE)
|
||||
if(NVRTC_WRAPPER_LIB_SIZE LESS 1024)
|
||||
message(
|
||||
@ -234,7 +253,6 @@ set(TRTLLM_LINK_LIBS
|
||||
${TRT_LIB}
|
||||
common_src
|
||||
kernels_src
|
||||
context_attention_src
|
||||
decoder_attention_src
|
||||
fpA_intB_gemm_src
|
||||
moe_gemm_src
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6a75d7ce014817aaf87168c9db2381273fb2c91cd997663e90b476fe1c7d6503
|
||||
size 3308922
|
||||
oid sha256:2eaa6076f6ae6341710ca84097a3083f57235093bfc3bb1cd04f2bce06e1a99e
|
||||
size 3412616
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6a75d7ce014817aaf87168c9db2381273fb2c91cd997663e90b476fe1c7d6503
|
||||
size 3308922
|
||||
oid sha256:2eaa6076f6ae6341710ca84097a3083f57235093bfc3bb1cd04f2bce06e1a99e
|
||||
size 3412616
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
ed2ee8d73a5d374e800f653169bf293e libtensorrt_llm_batch_manager_static.a
|
||||
ed2ee8d73a5d374e800f653169bf293e libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
05aaf1a0fb2f0115af107b00aa839a6601f6a873 commit
|
||||
4b12adf3182aabf1df0b5dac217509e0 libtensorrt_llm_batch_manager_static.a
|
||||
4b12adf3182aabf1df0b5dac217509e0 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
fc46fa01e555f9f97387340e46e9571fabf73988 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:00e2d6ee8efd00e27dd8da61be576ba7978d885a055d591c90f600b334356846
|
||||
size 3211414
|
||||
oid sha256:e4fb588419244c8c07a9ce949edd7ed4e3dded008ed82aa993ff69e524394be9
|
||||
size 3310186
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b6b65183b0aa3f40f68aa13105da9dc00fb75b8bf8892813e46a09e3f0743570
|
||||
size 3186478
|
||||
oid sha256:4b31c50b2879f57022e788700ebf3a86fa8f30133a01533b03c7bc15d64ad364
|
||||
size 3283536
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7d4a3bc5160666612e529f21c61dbd9d0f1b387662768f76b9351f877108f84b
|
||||
size 19840380
|
||||
oid sha256:8f458ce861720a4ce15c9cedeae4bb6c1f6a8a98f0fced35198c9802feaddb10
|
||||
size 20305606
|
||||
|
||||
@ -207,8 +207,9 @@ __device__ __nv_bfloat16 atomicMaxExtd(__nv_bfloat16* address, __nv_bfloat16 val

    return __ushort_as_bfloat16(old);
#else
    asm volatile(" brkpt;\n");
    return 0;
    assert(0);
    asm volatile("brkpt;\n" ::);
    return __nv_bfloat16(0);
#endif
}

@ -598,8 +598,9 @@ __device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
    return __hmax(val.x, val.y);
#else
    asm volatile(" brkpt;\n");
    return 0;
    assert(0);
    asm volatile("brkpt;\n" ::);
    return __nv_bfloat16(0);
#endif
}
#endif

@ -55,23 +55,23 @@ std::optional<int32_t> envXqaNbCtaPerKVHead()
    return ret;
}

bool getEnvDisableXQAJIT()
bool getEnvEnableXQAJIT()
{
    static bool init = false;
    static bool disableXQAJIT = false;
    static bool enableXQAJIT = false;
    if (!init)
    {
        init = true;
        char const* disable_xqa_jit_var = std::getenv("TRTLLM_DISABLE_XQA_JIT");
        if (disable_xqa_jit_var)
        char const* enable_xqa_jit_var = std::getenv("TRTLLM_ENABLE_XQA_JIT");
        if (enable_xqa_jit_var)
        {
            if (disable_xqa_jit_var[0] == '1' && disable_xqa_jit_var[1] == '\0')
            if (enable_xqa_jit_var[0] == '1' && enable_xqa_jit_var[1] == '\0')
            {
                disableXQAJIT = true;
                enableXQAJIT = true;
            }
        }
    }
    return disableXQAJIT;
    return enableXQAJIT;
}

// Tune the number of blocks per sequence for accuracy/performance purpose.

@ -33,8 +33,8 @@ int32_t xqaMaxNbCtaPerKVHeadFactor();

std::optional<int32_t> envXqaNbCtaPerKVHead();

// Whether XQA JIT is disabled.
bool getEnvDisableXQAJIT();
// Whether XQA JIT is enabled.
bool getEnvEnableXQAJIT();

// Tune the number of blocks per sequence for accuracy/performance purpose.
bool getEnvMmhaMultiblockDebug();

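// Usage sketch (based only on the lines above, not on other call sites): getEnvEnableXQAJIT() now
// gates the JIT path on the opt-in TRTLLM_ENABLE_XQA_JIT environment variable rather than the old
// opt-out TRTLLM_DISABLE_XQA_JIT, so a run would enable it with something like
//   TRTLLM_ENABLE_XQA_JIT=1 ./my_trtllm_app   (hypothetical binary name)
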
@ -112,6 +112,9 @@ void initialize(MpiThreadSupport threadMode)

        auto previousHandler = std::signal(SIGABRT, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); });
        TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed");

        // ensure local MPI communicator is initialized
        MpiComm::localSession();
    }
#endif // ENABLE_MULTI_DEVICE
    mpiInitialized = true;
@ -271,6 +274,24 @@ MpiComm& MpiComm::session()
    return commSession;
}

MpiComm getLocalSession()
{
#if ENABLE_MULTI_DEVICE
    MPI_Comm localComm;
    MPI_Comm_split_type(MPI_COMM_WORLD, OMPI_COMM_TYPE_HOST, 0, MPI_INFO_NULL, &localComm);
    MpiComm localSession{localComm, false};
#else
    MpiComm localSession{MPI_COMM_WORLD, false};
#endif // ENABLE_MULTI_DEVICE
    return localSession;
}

MpiComm& MpiComm::localSession()
{
    static MpiComm localSession = getLocalSession();
    return localSession;
}

MpiComm::MpiComm(MPI_Comm g, bool freeComm)
    : mComm{g}
    , mFreeComm{freeComm}

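// Usage sketch (an assumption, not part of this diff): MpiComm::localSession() wraps a communicator
// containing only the ranks on the same host (split with OMPI_COMM_TYPE_HOST above), which is the
// kind of communicator typically used to derive a node-local rank for GPU selection.
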
@ -25,6 +25,7 @@
|
||||
#include "cutlass_extensions/tile_interleaved_layout.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
|
||||
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
@ -39,11 +40,22 @@ template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantO
|
||||
typename Enable = void>
|
||||
struct DefaultScaleIteratorsPipelined;
|
||||
|
||||
// TODO: Fine grained iterators
|
||||
// Fine grained iterators
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
|
||||
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
|
||||
std::enable_if_t<isFinegrained(QuantOp)>>
|
||||
{
|
||||
private:
|
||||
using SmemScaleType = half_t;
|
||||
|
||||
public:
|
||||
using IteratorScale
|
||||
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
|
||||
Layout, 0, Alignment>;
|
||||
|
||||
using SmemIteratorScale
|
||||
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>,
|
||||
SmemScaleType, Layout, 0, Alignment>;
|
||||
};
|
||||
|
||||
// Per column iterators
|
||||
@ -206,7 +218,6 @@ struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, Ele
|
||||
|
||||
using OperatorInfo = arch::DetagOperator<Operator_>;
|
||||
using Operator = typename OperatorInfo::Operator;
|
||||
static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
|
||||
|
||||
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
|
||||
using MmaCoreElementA = half_t;
|
||||
|
||||
@ -80,7 +80,7 @@ template <
|
||||
typename SmemIteratorB_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Data type for the scales
|
||||
/// Iterators over scales in global memory
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
|
||||
@ -80,13 +80,13 @@ template <
|
||||
typename SmemIteratorB_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Data type for the scales
|
||||
/// Iterators over scales in global memory
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
/// Layout of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
|
||||
@ -95,304 +95,12 @@ template <
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class DqMmaPipelined : public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
|
||||
{
|
||||
public:
|
||||
///< Base class
|
||||
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
|
||||
|
||||
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
using Policy = Policy_; ///< Policy describing tuning details
|
||||
|
||||
using IteratorScale = IteratorScale_;
|
||||
using ElementScale = typename IteratorScale::Element;
|
||||
using LayoutScale = typename IteratorScale::Layout;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorScale = SmemIteratorScale_;
|
||||
|
||||
using TransformBAfterLDG = TransformBAfterLDG_;
|
||||
using TransformBAfterLDS = TransformBAfterLDS_;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA = typename IteratorA::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
|
||||
/// Fragment of operand Scale loaded from global memory;
|
||||
using FragmentScale = typename IteratorScale::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy::Operator::ArchTag;
|
||||
|
||||
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
|
||||
typename SmemIteratorScale::Fragment::Element, LayoutScale, 32, QuantOp>;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
// statically assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
|
||||
|
||||
private:
|
||||
using WarpFragmentA = typename Operator::FragmentA;
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
Dequantizer warp_dequantizer_;
|
||||
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
|
||||
static constexpr bool RequiresTileInterleave
|
||||
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
||||
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
||||
"Layout K must match threadblockK");
|
||||
|
||||
protected:
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
|
||||
SmemIteratorScale smem_iterator_scale_;
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DqMmaPipelined(typename Base::SharedStorage&
|
||||
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
        int const group_size,   ///< Unused; present only so this specialization matches the fine-grained
                                ///< signature and compiles. This DqMmaPipelined path is only used for sm<80,
                                ///< so the extra argument has no effect on sm>=80 builds.
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
|
||||
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
|
||||
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
|
||||
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
|
||||
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
|
||||
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
|
||||
{
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
|
||||
FragmentC& accum, ///< destination accumulator tile
|
||||
IteratorA iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB iterator_B, ///< iterator over B operand in global memory
|
||||
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
|
||||
FragmentC const& src_accum)
|
||||
{ ///< source accumulator tile
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
TransformBAfterLDG ldg_converter;
|
||||
TransformBAfterLDS lds_converter;
|
||||
|
||||
using TransformA
|
||||
= NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
|
||||
|
||||
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
|
||||
typename FragmentScale::Element, FragmentScale::kElements>;
|
||||
|
||||
// These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
|
||||
// to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
|
||||
TransformA transformA;
|
||||
TransformScale transformScale;
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
FragmentA tb_frag_A;
|
||||
FragmentB tb_frag_B;
|
||||
FragmentScale tb_frag_scales;
|
||||
|
||||
using WarpFragmentScale = typename Dequantizer::FragmentScale;
|
||||
WarpFragmentScale warp_frag_scales;
|
||||
|
||||
tb_frag_A.clear();
|
||||
tb_frag_B.clear();
|
||||
tb_frag_scales.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
iterator_scale.load(tb_frag_scales);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
|
||||
this->smem_iterator_A_.store(transformA(tb_frag_A));
|
||||
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
|
||||
this->smem_iterator_scale_.store(transformScale(tb_frag_scales));
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
warp_dequantizer_.load(warp_frag_scales);
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA warp_frag_A[2];
|
||||
WarpFragmentB warp_frag_B[2];
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
Operator warp_mma;
|
||||
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tightest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > 0; --gemm_k_iterations)
|
||||
{
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
|
||||
{
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations - 1)
|
||||
{
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(transformA(tb_frag_A));
|
||||
|
||||
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1)
|
||||
{
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
|
||||
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
|
||||
{
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(
|
||||
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
|
||||
if (warp_mma_k == 0)
|
||||
{
|
||||
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 2);
|
||||
}
|
||||
|
||||
typename TransformBAfterLDS::result_type converted_frag_B
|
||||
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
|
||||
run_warp_mma(
|
||||
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
typename Enable = void>
|
||||
class DqMmaPipelined;
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h"
|
||||
|
||||
@ -0,0 +1,486 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
||||
#include "cutlass_extensions/gemm_configs.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Iterators over scales in global memory
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Converter for B matrix applied immediately after the LDG (before STS)
|
||||
typename TransformBAfterLDG_,
|
||||
/// Converter for B matrix applied immediately after the LDS
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_>
|
||||
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
|
||||
ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
|
||||
std::enable_if_t<isFinegrained(QuantOp_)>>
|
||||
: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
|
||||
{
|
||||
public:
|
||||
///< Base class
|
||||
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
|
||||
|
||||
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
using Policy = Policy_; ///< Policy describing tuning details
|
||||
|
||||
using IteratorScale = IteratorScale_;
|
||||
using ElementScale = typename IteratorScale::Element;
|
||||
using LayoutScale = typename IteratorScale::Layout;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorScale = SmemIteratorScale_;
|
||||
|
||||
using TransformBAfterLDG = TransformBAfterLDG_;
|
||||
using TransformBAfterLDS = TransformBAfterLDS_;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA = typename IteratorA::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
|
||||
/// Fragment of operand Scale loaded from global memory;
|
||||
using FragmentScale = typename IteratorScale::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy::Operator::ArchTag;
|
||||
|
||||
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
|
||||
typename SmemIteratorScale::Element, LayoutScale, 32, QuantOp>;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
// statically assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
|
||||
|
||||
static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, "");
|
||||
static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");
|
||||
|
||||
private:
|
||||
using WarpFragmentA = typename Operator::FragmentA;
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
Dequantizer warp_dequantizer_;
|
||||
|
||||
using WarpFragmentScale = typename Dequantizer::FragmentScale;
|
||||
using WarpFragmentZero = typename Dequantizer::FragmentZero;
|
||||
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
|
||||
static constexpr bool RequiresTileInterleave
|
||||
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
||||
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
||||
"Layout K must match threadblockK");
|
||||
|
||||
protected:
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
|
||||
SmemIteratorScale smem_iterator_scale_;
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DqMmaPipelined(typename Base::SharedStorage&
|
||||
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
int const group_size, ///< The group size for quantization
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
|
||||
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
|
||||
{shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
|
||||
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
|
||||
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
|
||||
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
|
||||
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
|
||||
shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
|
||||
{
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_scales_and_advance(IteratorScale& iterator_scale)
|
||||
{
|
||||
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Element,
|
||||
typename FragmentScale::Element, FragmentScale::kElements>;
|
||||
|
||||
FragmentScale tb_frag_scales;
|
||||
FragmentScale tb_frag_zeros;
|
||||
tb_frag_scales.clear();
|
||||
tb_frag_zeros.clear();
|
||||
|
||||
TransformScale transformScale;
|
||||
|
||||
using FragmentElement = typename FragmentScale::Element;
|
||||
|
||||
auto gmem_scale_ptr = iterator_scale.get_scale();
|
||||
auto gmem_zero_ptr = iterator_scale.get_zero();
|
||||
|
||||
arch::global_load<FragmentScale, sizeof(FragmentScale)>(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid());
|
||||
|
||||
if (gmem_zero_ptr != nullptr)
|
||||
{
|
||||
arch::global_load<FragmentScale, sizeof(FragmentScale)>(
|
||||
tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid());
|
||||
}
|
||||
|
||||
typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales);
|
||||
typename TransformScale::result_type tb_frag_zeros_fp16;
|
||||
if (gmem_zero_ptr != nullptr)
|
||||
tb_frag_zeros_fp16 = transformScale(tb_frag_zeros);
|
||||
|
||||
auto frag_scale_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_scales_fp16);
|
||||
auto frag_zero_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_zeros_fp16);
|
||||
auto smem_scale_ptr = this->smem_iterator_scale_.get_scale();
|
||||
auto smem_zero_ptr = this->smem_iterator_scale_.get_zero();
|
||||
|
||||
if (iterator_scale.valid())
|
||||
{
|
||||
auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr);
|
||||
arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_scale_ptr_fp16);
|
||||
|
||||
if (gmem_zero_ptr != nullptr)
|
||||
{
|
||||
smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr);
|
||||
arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_zero_ptr_fp16);
|
||||
}
|
||||
}
|
||||
|
||||
if (iterator_scale.group_size_ == 64)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
else if (iterator_scale.group_size_ == 128)
|
||||
{
|
||||
if constexpr (Shape::kK == 128)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
else if constexpr (Shape::kK == 64)
|
||||
{
|
||||
if (iterator_scale.row_groupsize64_ & 0x1)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
|
||||
}
|
||||
}
|
||||
|
||||
iterator_scale.row_groupsize64_++;
|
||||
|
||||
this->smem_iterator_scale_.add_tile_offset({1, 0});
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
|
||||
FragmentC& accum, ///< destination accumulator tile
|
||||
IteratorA iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB iterator_B, ///< iterator over B operand in global memory
|
||||
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
|
||||
FragmentC const& src_accum)
|
||||
{ ///< source accumulator tile
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
TransformBAfterLDG ldg_converter;
|
||||
TransformBAfterLDS lds_converter;
|
||||
|
||||
using TransformA
|
||||
= NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
|
||||
|
||||
// These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
|
||||
// to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
|
||||
TransformA transformA;
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
FragmentA tb_frag_A;
|
||||
FragmentB tb_frag_B;
|
||||
|
||||
tb_frag_A.clear();
|
||||
tb_frag_B.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
|
||||
this->smem_iterator_A_.store(transformA(tb_frag_A));
|
||||
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
copy_scales_and_advance(iterator_scale);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA warp_frag_A[2];
|
||||
WarpFragmentB warp_frag_B[2];
|
||||
WarpFragmentScale warp_frag_scales;
|
||||
WarpFragmentZero warp_frag_zero;
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
||||
|
||||
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
warp_dequantizer_.add_pointer_offset(Shape::kN);
|
||||
|
||||
Operator warp_mma;
|
||||
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_scale.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tightest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > 0; --gemm_k_iterations)
|
||||
{
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
|
||||
{
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations - 1)
|
||||
{
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(transformA(tb_frag_A));
|
||||
|
||||
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1)
|
||||
{
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
|
||||
warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
|
||||
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
|
||||
{
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(
|
||||
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
|
||||
if (warp_mma_k == 0)
|
||||
{
|
||||
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
|
||||
copy_scales_and_advance(iterator_scale);
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_scale.clear_mask(gemm_k_iterations <= 2);
|
||||
}
|
||||
|
||||
typename TransformBAfterLDS::result_type converted_frag_B
|
||||
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero);
|
||||
run_warp_mma(
|
||||
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
|
||||
}
|
||||
|
||||
// Load the scales needed for the next tile iteration
|
||||
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
|
||||
// Update internal pointer to the set of scales in shared memory
|
||||
warp_dequantizer_.add_pointer_offset(Shape::kN);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,399 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
||||
#include "cutlass_extensions/gemm_configs.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace threadblock
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Iterates over scales in global memory
    typename IteratorScale_,
    /// Iterates over scales in shared memory
    typename SmemIteratorScale_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Converter for B matrix applied immediately after the LDG (before STS)
    typename TransformBAfterLDG_,
    /// Converter for B matrix applied immediately after the LDS
    typename TransformBAfterLDS_,
    /// The quantization operator being used
    WeightOnlyQuantOp QuantOp_>
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
    ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
    std::enable_if_t<!isFinegrained(QuantOp_)>>
    : public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
{
public:
    ///< Base class
    using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;

    using Shape = Shape_;         ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
    using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
    using ElementC = ElementC_;   ///< Data type of accumulator matrix
    using LayoutC = LayoutC_;     ///< Layout of accumulator matrix
    using Policy = Policy_;       ///< Policy describing tuning details

    using IteratorScale = IteratorScale_;
    using ElementScale = typename IteratorScale::Element;
    using LayoutScale = typename IteratorScale::Layout;

    using SmemIteratorA = SmemIteratorA_;
    using SmemIteratorB = SmemIteratorB_;
    using SmemIteratorScale = SmemIteratorScale_;

    using TransformBAfterLDG = TransformBAfterLDG_;
    using TransformBAfterLDS = TransformBAfterLDS_;

    static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;

    //
    // Dependent types
    //

    /// Fragment of operand A loaded from global memory
    using FragmentA = typename IteratorA::Fragment;

    /// Fragment of operand B loaded from global memory
    using FragmentB = typename IteratorB::Fragment;

    /// Fragment of operand Scale loaded from global memory
    using FragmentScale = typename IteratorScale::Fragment;

    /// Fragment of accumulator tile
    using FragmentC = typename Policy::Operator::FragmentC;

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Obtain the arch tag from the warp-level operator
    using ArchTag = typename Policy::Operator::ArchTag;

    using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
        typename SmemIteratorScale::Fragment::Element, LayoutScale, 32, QuantOp>;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;

    // Statically assert that kStages for DqMmaPipelined is two (double-buffered pipeline)
    static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");

private:
    using WarpFragmentA = typename Operator::FragmentA;
    using WarpFragmentB = typename Operator::FragmentB;
    Dequantizer warp_dequantizer_;

    using ElementA = typename IteratorA::Element;
    using ElementB = typename IteratorB::Element;
    using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;

    static constexpr bool RequiresTileInterleave
        = layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
    static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
        "Layout K must match threadblockK");

protected:
    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

    /// Iterator to write threadblock-scoped tile of scale operand to shared memory
    SmemIteratorScale smem_iterator_scale_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    DqMmaPipelined(typename Base::SharedStorage&
            shared_storage,       ///< Shared storage needed for internal use by threadblock-scoped GEMM
        int const group_size,     ///< Unused; present only so the interface matches the fine-grained
                                  ///< variants and the code compiles. DqMmaPipelined is only enabled for
                                  ///< sm<80, so carrying this argument does not affect sm>=80 builds.
        int thread_idx,           ///< ID within the threadblock
        int warp_idx,             ///< ID of warp
        int lane_idx              ///< ID of each thread within a warp
        )
        : Base(shared_storage, thread_idx, warp_idx, lane_idx)
        , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
              (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
        , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
        , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
        , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
    {

        // Compute warp location within threadblock tile by mapping the warp_id to
        // three coordinates:
        //   _m: the warp's position within the threadblock along the M dimension
        //   _n: the warp's position within the threadblock along the N dimension
        //   _k: the warp's position within the threadblock along the K dimension

        int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
        int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

        int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
        int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

        // Add per-warp offsets in units of warp-level tiles
        this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
        this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
    }
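
    // Worked example (editorial illustration, not part of the original source): with
    // Base::WarpCount = {kM = 2, kN = 2, kK = 1} a threadblock holds four warps, and for
    // warp_idx = 3 the mapping above gives
    //   warp_idx_mn = 3 % 4 = 3,  warp_idx_k = 3 / 4 = 0,
    //   warp_idx_m  = 3 % 2 = 1,  warp_idx_n  = 3 / 2 = 1,
    // i.e. that warp owns the (m = 1, n = 1) warp tile of the k = 0 slice.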

    /// Perform a threadblock-scoped matrix multiply-accumulate
    CUTLASS_DEVICE
    void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
        FragmentC& accum,                  ///< destination accumulator tile
        IteratorA iterator_A,              ///< iterator over A operand in global memory
        IteratorB iterator_B,              ///< iterator over B operand in global memory
        IteratorScale iterator_scale,      ///< iterator over scale operand in global memory
        FragmentC const& src_accum)        ///< source accumulator tile
    {

        //
        // Prologue
        //
        TransformBAfterLDG ldg_converter;
        TransformBAfterLDS lds_converter;

        using TransformA
            = NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;

        using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
            typename FragmentScale::Element, FragmentScale::kElements>;

        // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
        // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
        TransformA transformA;
        TransformScale transformScale;

        // Perform accumulation in the 'd' output operand
        accum = src_accum;

        FragmentA tb_frag_A;
        FragmentB tb_frag_B;
        FragmentScale tb_frag_scales;

        using WarpFragmentScale = typename Dequantizer::FragmentScale;
        WarpFragmentScale warp_frag_scales;

        tb_frag_A.clear();
        tb_frag_B.clear();
        tb_frag_scales.clear();

        // The last kblock is loaded in the prologue
        iterator_A.load(tb_frag_A);
        iterator_B.load(tb_frag_B);
        iterator_scale.load(tb_frag_scales);

        ++iterator_A;
        ++iterator_B;

        this->smem_iterator_A_.store(transformA(tb_frag_A));
        this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
        this->smem_iterator_scale_.store(transformScale(tb_frag_scales));

        ++this->smem_iterator_A_;
        ++this->smem_iterator_B_;

        __syncthreads();

        warp_dequantizer_.load(warp_frag_scales);

        // Pair of fragments used to overlap shared memory loads and math instructions
        WarpFragmentA warp_frag_A[2];
        WarpFragmentB warp_frag_B[2];

        this->warp_tile_iterator_A_.set_kgroup_index(0);
        this->warp_tile_iterator_B_.set_kgroup_index(0);

        this->warp_tile_iterator_A_.load(warp_frag_A[0]);
        this->warp_tile_iterator_B_.load(warp_frag_B[0]);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        Operator warp_mma;

        int smem_write_stage_idx = 1;

        // Avoid reading out of bounds
        iterator_A.clear_mask(gemm_k_iterations <= 1);
        iterator_B.clear_mask(gemm_k_iterations <= 1);

        // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
        // shared memory loads (which have the tightest latency requirement).

        //
        // Mainloop
        //

        // Note: The main loop does not support Base::kWarpGemmIterations == 2.
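
        // Editorial note (illustration, not part of the original source): the global iterators
        // always prefetch one threadblock tile ahead. The prologue masked off further loads when
        // only one k iteration exists, and inside the loop the masks are cleared once
        // gemm_k_iterations <= 2, because the tile requested at that point is the last valid one;
        // the loads issued during the final iteration then become no-ops instead of reading out
        // of bounds.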
        CUTLASS_GEMM_LOOP
        for (; gemm_k_iterations > 0; --gemm_k_iterations)
        {
            //
            // Loop over GEMM K dimension
            //

            CUTLASS_PRAGMA_UNROLL
            for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
            {

                // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
                // as the case may be.

                if (warp_mma_k == Base::kWarpGemmIterations - 1)
                {

                    // Write fragments to shared memory
                    this->smem_iterator_A_.store(transformA(tb_frag_A));

                    this->smem_iterator_B_.store(ldg_converter(tb_frag_B));

                    __syncthreads();

                    ++this->smem_iterator_A_;
                    ++this->smem_iterator_B_;

                    // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
                    if (smem_write_stage_idx == 1)
                    {
                        this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
                        this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
                    }
                    else
                    {
                        this->warp_tile_iterator_A_.add_tile_offset(
                            {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
                        this->warp_tile_iterator_B_.add_tile_offset(
                            {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
                    }

                    smem_write_stage_idx ^= 1;
                }

                this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
                this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
                ++this->warp_tile_iterator_A_;

                int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
                int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
                // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
                if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
                {
                    this->warp_tile_iterator_B_.set_kgroup_index(
                        (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
                    this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
                    ++this->warp_tile_iterator_B_;
                }

                if (warp_mma_k == 0)
                {

                    iterator_A.load(tb_frag_A);
                    iterator_B.load(tb_frag_B);

                    ++iterator_A;
                    ++iterator_B;

                    // Avoid reading out of bounds if this was the last loop iteration
                    iterator_A.clear_mask(gemm_k_iterations <= 2);
                    iterator_B.clear_mask(gemm_k_iterations <= 2);
                }

                typename TransformBAfterLDS::result_type converted_frag_B
                    = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
                warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
                run_warp_mma(
                    warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
            }
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
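
The mainloop above is a classic two-stage (double-buffered) software pipeline: while the warps compute on the tile already resident in one shared-memory stage, the next global-memory tile is staged into the other, and smem_write_stage_idx ^= 1 swaps the roles every iteration. A minimal host-side sketch of that control flow, with plain integers standing in for the CUTLASS iterators and fragments (an editorial illustration, not code from this commit), looks like this:

#include <cstdio>

// Two-stage (double-buffered) pipeline skeleton: the assignment into stage[] stands in
// for the global->shared copy, the printf for the warp-level MMA on the other stage.
int main()
{
    constexpr int kStages = 2;
    int stage[kStages] = {};

    stage[0] = 0;            // prologue: the first k-block is loaded before the loop
    int write_stage_idx = 1; // the next global load targets the other stage

    for (int k = 1; k <= 4; ++k)
    {
        stage[write_stage_idx] = k; // overlap: stage the next k-block while computing
        int const read_stage_idx = write_stage_idx ^ 1;
        std::printf("compute on stage %d holding k-block %d\n", read_stage_idx, stage[read_stage_idx]);
        write_stage_idx ^= 1;       // mirrors smem_write_stage_idx ^= 1 above
    }
    return 0;
}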
@@ -16,6 +16,10 @@

#pragma once

#include <cassert>
#include <sstream>
#include <string>

namespace tensorrt_llm
{
namespace cutlass_extensions
@@ -153,6 +157,32 @@ struct CutlassGemmConfig
        , is_sm90(true)
    {
    }

    std::string toString()
    {
        std::stringstream tactic;
        tactic << "Cutlass GEMM Tactic";
        if (tile_config_sm90 != tensorrt_llm::cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
        {
            assert(is_sm90 && "Invalid cutlass GEMM config");
            tactic << "\n\tstyle=TMA"
                   << "\n\ttile shape ID: " << (int) tile_config_sm90 << "\n\tcluster shape ID: " << (int) cluster_shape
                   << "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule;
        }
        else if (tile_config != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
        {
            assert(!is_sm90 && "Invalid cutlass GEMM config");
            tactic << "\n\tstyle=compatible"
                   << "\n\ttile shape ID: " << (int) tile_config << "\n\tstages: " << (int) stages
                   << "\n\tsplit k: " << (int) split_k_factor;
        }
        else
        {
            tactic << "\n\tundefined";
        }
        tactic << "\n";
        return tactic.str();
    }
};

} // namespace cutlass_extensions

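The toString helper above is plain stringstream formatting over the config's enum and integer fields. A self-contained sketch of the same pattern is shown below; MiniGemmConfig and TileConfig are hypothetical stand-ins for illustration, not the real CutlassGemmConfig types:

#include <sstream>
#include <string>

// Hypothetical stand-in for the real tile-config enum; for illustration only.
enum class TileConfig
{
    ChooseWithHeuristic,
    CtaShape128x128x64
};

struct MiniGemmConfig
{
    TileConfig tile_config = TileConfig::ChooseWithHeuristic;
    int stages = 2;
    int split_k_factor = 1;

    std::string toString() const
    {
        std::stringstream tactic;
        tactic << "GEMM Tactic";
        if (tile_config != TileConfig::ChooseWithHeuristic)
        {
            tactic << "\n\ttile shape ID: " << (int) tile_config << "\n\tstages: " << stages
                   << "\n\tsplit k: " << split_k_factor;
        }
        else
        {
            tactic << "\n\tundefined";
        }
        tactic << "\n";
        return tactic.str();
    }
};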
@@ -89,6 +89,8 @@ public:

    using AccessType = AlignedArray<Element, kAlignment>;

    using Fragment = cutlass::Array<Element, kAlignment>;

    // For compatibility with existing iterator interface
    struct Params
    {

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8ff03e99e17e64c9f559e4586dec3983d438857c0050a34a417e4d86d56fbe2a
size 1251854
oid sha256:dc6db658e67f3ea6a11bc37be2b90090f23b831719559a3202bf8b2b397f6f3b
size 1334290

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8ff03e99e17e64c9f559e4586dec3983d438857c0050a34a417e4d86d56fbe2a
size 1251854
oid sha256:dc6db658e67f3ea6a11bc37be2b90090f23b831719559a3202bf8b2b397f6f3b
size 1334290

@@ -1,3 +1,3 @@
54670adde093baff8b031869bdeeeb1b libtensorrt_llm_executor_static.a
54670adde093baff8b031869bdeeeb1b libtensorrt_llm_executor_static.pre_cxx11.a
05aaf1a0fb2f0115af107b00aa839a6601f6a873 commit
8e5f1d6bb88c80004b4260aa2d022420 libtensorrt_llm_executor_static.a
8e5f1d6bb88c80004b4260aa2d022420 libtensorrt_llm_executor_static.pre_cxx11.a
fc46fa01e555f9f97387340e46e9571fabf73988 commit
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:431dc6352dcb332821aab031ccbd887e6a60591a5ea276a9ffd3df1f28463326
size 1271014
oid sha256:f0421ca1e637adfdebc9718c47537ed81b55cad4f7fbd062b1d83ca0ab7ebbe5
size 1371948

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:86319542d275570a0c66622d4656b88f3d153c6861db0e53f17f29d47e0a30c9
size 1227362
oid sha256:46b6253eef9136f91d2877e9baa827a8ff229b54b6fb1f2717fb6c85a7ffa047
size 1306830

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8ed99448579b40e0046eca5c8989151a66579f8fbccef9cda4ee7fc2ffd2245b
size 12076106
oid sha256:fa1bde9d020eac84321954b0cb393ec7af63393e29947d6c147700063b0267da
size 12726212

@@ -80,6 +80,8 @@ int main(int argc, char* argv[])
    // In orchestrator mode, the spawned threads will wait for the termination signal from the orchestrator
    auto executor = tle::Executor(modelPath, modelType, executorConfig);

    // Wait for all workers to have created their instances
    MPI_Barrier(parentComm);
    TLLM_LOG_INFO("Executor instance created by worker");

#endif // ENABLE_MULTI_DEVICE

@@ -18,12 +18,9 @@
file(GLOB_RECURSE SRC_CPP *.cpp)
file(GLOB_RECURSE SRC_CU *.cu)

# Exclude files in the cutlass_kernels, contextFusedMultiHeadAttention and
# unfusedAttentionKernels folder
# Exclude files in the cutlass_kernels and unfusedAttentionKernels folder
list(FILTER SRC_CPP EXCLUDE REGEX "cutlass_kernels/.*")
list(FILTER SRC_CU EXCLUDE REGEX "cutlass_kernels/.*")
list(FILTER SRC_CPP EXCLUDE REGEX "contextFusedMultiHeadAttention/.*")
list(FILTER SRC_CU EXCLUDE REGEX "contextFusedMultiHeadAttention/.*")
list(FILTER SRC_CPP EXCLUDE REGEX "decoderMaskedMultiheadAttention/.*")
list(FILTER SRC_CU EXCLUDE REGEX "decoderMaskedMultiheadAttention/.*")

@@ -37,4 +34,3 @@ set_property(TARGET kernels_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

add_subdirectory(cutlass_kernels)
add_subdirectory(decoderMaskedMultiheadAttention)
add_subdirectory(contextFusedMultiHeadAttention)

@@ -1,24 +0,0 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#

file(GLOB_RECURSE SRC_CPP *.cpp)
file(GLOB_RECURSE SRC_CU *.cu)

add_library(context_attention_src OBJECT ${SRC_CPP} ${SRC_CU})
set_property(TARGET context_attention_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET context_attention_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS
                                                   ON)
Diffs for 19 files were suppressed because they are too large.
Some files were not shown because too many files have changed in this diff.