Update TensorRT-LLM (#613)

* Update TensorRT-LLM

---------

Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Co-authored-by: zhang-ge-hao <842720660@qq.com>
Kaiyu Xie 2023-12-08 17:49:24 +08:00 committed by GitHub
parent 42af740db5
commit f7eca56161
213 changed files with 31369 additions and 15451 deletions

View File

@ -44,4 +44,5 @@ repos:
- id: codespell
args:
- --skip=".git,3rdparty"
- --exclude-file=examples/whisper/tokenizer.py
- --ignore-words-list=rouge,inout,atleast,strat

View File

@ -257,7 +257,7 @@ The list of supported models is:
* [InternLM](examples/internlm)
* [LLaMA](examples/llama)
* [LLaMA-v2](examples/llama)
* [Mistral](examples/llama)
* [Mistral](examples/llama#mistral-v01)
* [MPT](examples/mpt)
* [mT5](examples/enc_dec)
* [OPT](examples/opt)
@ -266,9 +266,10 @@ The list of supported models is:
* [SantaCoder](examples/gpt)
* [StarCoder](examples/gpt)
* [T5](examples/enc_dec)
* [Whisper](examples/whisper)
Note: [Encoder-Decoder](examples/enc_dec/) provides general encoder-decoder
support that contains many encoder-decoder models such as T5, Flan-T5, etc. We
functionality that supports many encoder-decoder models such as T5 family, BART family, Whisper family, etc. We
unroll the exact model names in the list above so that users can find specific
models more easily.
@ -372,7 +373,7 @@ For example: `mpirun -n 1 python3 examples/gpt/build.py ...`
### Change Log
#### Version 0.6.0
#### Version 0.6.1
* Models
* ChatGLM3

View File

@ -257,6 +257,7 @@ int main(int argc, char* argv[])
options.add_options()("ctx_micro_batch_size", "Batch size for context phase.", cxxopts::value<int>());
options.add_options()("gen_micro_batch_size", "Batch size for generation phase.", cxxopts::value<int>());
options.add_options()("max_attention_window", "Max kv cache length per sequence.", cxxopts::value<int>());
options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
options.add_options()(
"kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value<float>());
@ -352,6 +353,11 @@ int main(int argc, char* argv[])
{
sessionConfig.kvCacheConfig.maxTokens = result["max_tokens_in_paged_kvcache"].as<int>();
}
// Argument: Max attention window (max KV cache length per sequence)
if (result.count("max_attention_window"))
{
sessionConfig.kvCacheConfig.maxAttentionWindow = result["max_attention_window"].as<int>();
}
// Argument: K-V Cache Free Gpu Mem Fraction
if (result.count("kv_cache_free_gpu_mem_fraction"))
{

View File

@ -12,6 +12,7 @@ The benchmark implementation and entrypoint can be found in [`benchmarks/python/
* [`benchmarks/python/base_benchmark.py`](./base_benchmark.py) to implement the base class for benchmark.
* [`benchmarks/python/gpt_benchmark.py`](./gpt_benchmark.py) to implement benchmark scripts for GPT and GPT-like(LLaMA/OPT/GPT-J/SmoothQuant-GPT) models.
* [`benchmarks/python/bert_benchmark.py`](./bert_benchmark.py) to implement benchmark scripts for BERT models.
* [`benchmarks/python/enc_dec_benchmark.py`](./enc_dec_benchmark.py) to implement benchmark scripts for Encoder-Decoder models.
## Usage

View File

@ -27,7 +27,7 @@ class BuildConfig(BaseModel, extra=Extra.allow):
hidden_act: Optional[str]
n_positions: int
max_batch_size: int
max_input_len: int
max_input_len: Optional[int] = None
num_kv_heads: Optional[int] = None
max_output_len: Optional[int] = None
max_beam_width: int = 1
@ -54,10 +54,25 @@ class BuildConfig(BaseModel, extra=Extra.allow):
moe_top_k: int = None
class EncDecBuildConfig(BuildConfig, extra=Extra.allow):
num_decoder_layers: Optional[int] = None
head_size: Optional[int] = None
ffn_hidden_size: Optional[int] = None
num_buckets: Optional[int] = None
max_distance: Optional[int] = None
max_encoder_input_len: Optional[int] = None
max_decoder_input_len: Optional[int] = None
def __post_init__(self) -> None:
assert self.head_size is not None
assert self.ffn_hidden_size is not None
assert self.num_buckets is not None
class ModelConfig(BaseModel):
name: str
family: str
benchmark_type: Literal["gpt", "bert"]
benchmark_type: Literal["gpt", "bert", "enc_dec"]
build_config: BuildConfig
@ -263,7 +278,7 @@ _allowed_configs = {
hidden_size=4096,
vocab_size=32000,
hidden_act='silu',
n_positions=2048,
n_positions=4096,
inter_size=11008,
max_batch_size=128,
max_input_len=512,
@ -280,7 +295,7 @@ _allowed_configs = {
hidden_size=5120,
vocab_size=32000,
hidden_act='silu',
n_positions=2048,
n_positions=4096,
inter_size=13824,
max_batch_size=128,
max_input_len=512,
@ -315,7 +330,7 @@ _allowed_configs = {
hidden_size=8192,
vocab_size=32000,
hidden_act='silu',
n_positions=2048,
n_positions=4096,
inter_size=28672,
max_batch_size=64,
max_input_len=512,
@ -332,7 +347,7 @@ _allowed_configs = {
hidden_size=8192,
vocab_size=32000,
hidden_act='silu',
n_positions=2048,
n_positions=4096,
inter_size=28672,
max_batch_size=16,
max_input_len=8000,
@ -349,7 +364,7 @@ _allowed_configs = {
hidden_size=8192,
vocab_size=32000,
hidden_act='silu',
n_positions=2048,
n_positions=4096,
inter_size=28672,
max_batch_size=64,
max_input_len=200,
@ -366,7 +381,7 @@ _allowed_configs = {
hidden_size=8192,
vocab_size=32000,
hidden_act='silu',
n_positions=2048,
n_positions=4096,
inter_size=28672,
max_batch_size=128,
max_input_len=512,
@ -588,7 +603,7 @@ _allowed_configs = {
bias=False,
use_alibi=False,
parallel_attention=True,
new_decoder_architecture=False,
new_decoder_architecture=True,
)),
"falcon_180b":
ModelConfig(name="falcon_180b",
@ -609,7 +624,217 @@ _allowed_configs = {
bias=False,
use_alibi=False,
parallel_attention=True,
new_decoder_architecture=False,
new_decoder_architecture=True,
)),
"t5_small":
ModelConfig(name="t5_small",
family="t5",
benchmark_type="enc_dec",
build_config=EncDecBuildConfig(
num_layers=6,
num_heads=8,
head_size=64,
ffn_hidden_size=2048,
hidden_size=512,
vocab_size=32128,
hidden_act="relu",
n_positions=512,
num_buckets=32,
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
builder_opt=None,
)),
"t5_base":
ModelConfig(name="t5_base",
family="t5",
benchmark_type="enc_dec",
build_config=EncDecBuildConfig(
num_layers=12,
num_heads=12,
head_size=64,
ffn_hidden_size=3072,
hidden_size=768,
vocab_size=32128,
hidden_act="relu",
n_positions=512,
num_buckets=32,
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
builder_opt=None,
)),
"t5_large":
ModelConfig(name="t5_large",
family="t5",
benchmark_type="enc_dec",
build_config=EncDecBuildConfig(
num_layers=24,
num_heads=16,
head_size=64,
ffn_hidden_size=4096,
hidden_size=1024,
vocab_size=32128,
hidden_act="relu",
n_positions=512,
num_buckets=32,
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
builder_opt=None,
)),
"t5_3b":
ModelConfig(name="t5_3b",
family="t5",
benchmark_type="enc_dec",
build_config=EncDecBuildConfig(
num_layers=24,
num_heads=32,
head_size=128,
ffn_hidden_size=16384,
hidden_size=1024,
vocab_size=32128,
hidden_act="relu",
n_positions=512,
num_buckets=32,
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
builder_opt=None,
)),
"t5_11b":
ModelConfig(name="t5_11b",
family="t5",
benchmark_type="enc_dec",
build_config=EncDecBuildConfig(
num_layers=24,
num_heads=128,
head_size=128,
ffn_hidden_size=65536,
hidden_size=1024,
vocab_size=32128,
hidden_act="relu",
n_positions=512,
num_buckets=32,
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
builder_opt=None,
)),
"flan_t5_small":
ModelConfig(name="flan_t5_small",
family="t5",
benchmark_type="enc_dec",
build_config=EncDecBuildConfig(
num_layers=8,
num_decoder_layers=8,
num_heads=6,
head_size=64,
ffn_hidden_size=1024,
hidden_size=512,
vocab_size=32128,
hidden_act="gated-gelu",
n_positions=512,
num_buckets=32,
max_distance=128,
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
builder_opt=None,
)),
"flan_t5_base":
ModelConfig(name="flan_t5_base",
family="t5",
benchmark_type="enc_dec",
build_config=EncDecBuildConfig(
num_layers=12,
num_decoder_layers=12,
num_heads=12,
head_size=64,
ffn_hidden_size=2048,
hidden_size=768,
vocab_size=32128,
hidden_act="gated-gelu",
n_positions=512,
num_buckets=32,
max_distance=128,
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
builder_opt=None,
)),
"flan_t5_large":
ModelConfig(name="flan_t5_large",
family="t5",
benchmark_type="enc_dec",
build_config=EncDecBuildConfig(
num_layers=24,
num_decoder_layers=24,
num_heads=16,
head_size=64,
ffn_hidden_size=2816,
hidden_size=1024,
vocab_size=32128,
hidden_act="gated-gelu",
n_positions=512,
num_buckets=32,
max_distance=128,
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
builder_opt=None,
)),
"flan_t5_xl":
ModelConfig(name="flan_t5_xl",
family="t5",
benchmark_type="enc_dec",
build_config=EncDecBuildConfig(
num_layers=24,
num_decoder_layers=24,
num_heads=32,
head_size=64,
ffn_hidden_size=5120,
hidden_size=2048,
vocab_size=32128,
hidden_act="gated-gelu",
n_positions=512,
num_buckets=32,
max_distance=128,
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
builder_opt=None,
)),
"flan_t5_xxl":
ModelConfig(name="flan_t5_xxl",
family="t5",
benchmark_type="enc_dec",
build_config=EncDecBuildConfig(
num_layers=24,
num_decoder_layers=24,
num_heads=64,
head_size=64,
ffn_hidden_size=10240,
hidden_size=4096,
vocab_size=32128,
hidden_act="gelu_new",
n_positions=0,
num_buckets=32,
max_distance=128,
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
builder_opt=None,
)),
}
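
For orientation, `get_allowed_models(benchmark_type=...)`, which `benchmark.py` uses below to dispatch to the right benchmark class, can be read as a simple filter over this table. A minimal sketch, assuming only the `_allowed_configs` structure shown above:

```python
from typing import List, Optional


def get_allowed_models(benchmark_type: Optional[str] = None) -> List[str]:
    # Return the model names whose ModelConfig matches the requested
    # benchmark_type ("gpt", "bert", or "enc_dec"); return all names when
    # no type is given. This is a sketch of the assumed behavior, not the
    # exact implementation in allowed_configs.py.
    if benchmark_type is None:
        return list(_allowed_configs.keys())
    return [
        name for name, model_config in _allowed_configs.items()
        if model_config.benchmark_type == benchmark_type
    ]
```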

View File

@ -59,12 +59,18 @@ def serialize_engine(engine, path):
class BaseBenchmark(object):
def __init__(self, engine_dir, model_name, dtype):
def __init__(self,
engine_dir,
model_name,
dtype,
rank,
world_size,
serial_build: bool = False):
self.engine_dir = engine_dir
self.model_name = model_name
self.dtype = dtype
self.runtime_rank = tensorrt_llm.mpi_rank()
self.world_size = tensorrt_llm.mpi_world_size()
self.runtime_rank = rank
self.world_size = world_size
self.engine_model_name = model_name
self.quant_mode = QuantMode(0)
self.enable_fp8 = False
@ -100,8 +106,9 @@ class BaseBenchmark(object):
self.runtime_mapping = tensorrt_llm.Mapping(world_size=self.world_size,
rank=self.runtime_rank,
tp_size=self.world_size)
torch.cuda.set_device(self.runtime_rank %
self.runtime_mapping.gpus_per_node)
if not serial_build:
torch.cuda.set_device(self.runtime_rank %
self.runtime_mapping.gpus_per_node)
self.csv_filename = "" # lazy init

View File

@ -141,6 +141,24 @@ def parse_arguments():
help=
('If this option is specified, it will override the max input len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_encoder_input_len',
type=int,
default=None,
help=
('This argument is only for encoder-decoder models. '
'If this option is specified, it will override the max encoder input len of TRT engines to the specified value instead of using the pre-defined one. '
'By default, when this option is not used, the pre-defined max encoder input len is used.'
))
parser.add_argument(
'--max_decoder_input_len',
type=int,
default=None,
help=
('This argument is only for encoder-decoder models. '
'If this option is specified, it will override the max decoder input len of TRT engines to the specified value instead of using the pre-defined one. '
'By default, when this option is not used, the pre-defined max decoder input len is used.'
))
parser.add_argument(
'--max_output_len',
type=int,
@ -181,6 +199,33 @@ def parse_arguments():
'int4_weight_only_awq', 'int4_weight_only_gptq'
],
help="Optimize the model with specified quantization recipe")
parser.add_argument(
'--build_only',
default=False,
action='store_true',
help=
"Build engine only and skip inference, this can help to benchmark the build time on single gpu node for multi GPU model, where the inference is not possible"
)
parser.add_argument('--serial_build',
default=False,
action='store_true',
help="Build engines serially")
parser.add_argument(
'--rank',
type=int,
default=None,
help=
"The rank of the model to be built, only used when --build_only and --serial_build is specified"
)
parser.add_argument(
'--world_size',
type=int,
default=None,
help=
"The number of gpus to be used for inference, only used when --build_only and --serial_build is specified"
)
return parser.parse_args()
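
A usage sketch of the new build flags (not part of this patch): with `--serial_build` and `--build_only`, the engines of a multi-GPU model can be built one rank at a time in a single process, with `--rank`/`--world_size` standing in for the MPI environment. The model name below is a placeholder, and `--model`, `--batch_size`, and `--input_output_len` are assumed to be the pre-existing benchmark options:

```python
import subprocess

WORLD_SIZE = 4       # assumed tensor-parallel size for this illustration
MODEL = "llama_70b"  # placeholder; any multi-GPU entry in _allowed_configs

# Build the engine for every rank serially on a single GPU node. --build_only
# skips inference, which is not possible when ranks are built one at a time.
for rank in range(WORLD_SIZE):
    subprocess.run(
        [
            "python3", "benchmark.py",
            "--model", MODEL,
            "--batch_size", "8",
            "--input_output_len", "128,128",
            "--build_only", "--serial_build",
            "--rank", str(rank),
            "--world_size", str(WORLD_SIZE),
        ],
        check=True,
    )

# The benchmark itself is then run with all ranks in parallel, e.g. under
# mpirun, as in the repository's other multi-GPU examples.
```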
@ -193,9 +238,11 @@ def main(args):
from allowed_configs import get_allowed_models
from benchmark_profiler import BenchmarkProfiler
from bert_benchmark import BERTBenchmark
from enc_dec_benchmark import EncDecBenchmark
from gpt_benchmark import GPTBenchmark
from mem_monitor import MemoryMonitor
import tensorrt_llm
from tensorrt_llm.logger import logger
logger.set_level(args.log_level)
@ -206,20 +253,40 @@ def main(args):
# Input length (for BERT-like models)
input_len_options = args.input_len.split(';')
input_len_options = [int(i) for i in input_len_options]
# Input-output length combination (for GPT-like models)
# Input-output length combination (for GPT-like models and enc_dec models)
in_out_len_options = args.input_output_len.split(';')
in_out_len_options = [[int(i) for i in io.split(',')]
for io in in_out_len_options]
if args.serial_build and not args.build_only:
raise Exception(
f"--serial_build must be used with --build_only, always need to parallel build to do inference in the same process"
)
if args.build_only and args.serial_build and args.rank is not None and args.world_size is not None:
rank = args.rank
world_size = args.world_size
else:
rank = tensorrt_llm.mpi_rank()
world_size = tensorrt_llm.mpi_world_size()
benchmark_profiler = None
if args.model in get_allowed_models(benchmark_type="gpt"):
benchmark_profiler = BenchmarkProfiler()
benchmarker = GPTBenchmark(args, batch_size_options, in_out_len_options)
benchmarker = GPTBenchmark(args, batch_size_options, in_out_len_options,
rank, world_size)
elif args.model in get_allowed_models(benchmark_type="bert"):
benchmarker = BERTBenchmark(args, batch_size_options, input_len_options)
benchmarker = BERTBenchmark(args, batch_size_options, input_len_options,
rank, world_size)
elif args.model in get_allowed_models(benchmark_type="enc_dec"):
benchmarker = EncDecBenchmark(args, batch_size_options,
in_out_len_options, rank, world_size)
else:
raise Exception(f'Unexpected model: {args.model}')
if args.build_only:
return
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
benchmarker.print_report_header(args.csv,

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
# isort: off
import torch
@ -30,8 +29,9 @@ from tensorrt_llm.runtime import TensorInfo
class BERTBenchmark(BaseBenchmark):
def __init__(self, args, batch_sizes, in_lens):
super().__init__(args.engine_dir, args.model, args.dtype)
def __init__(self, args, batch_sizes, in_lens, rank, world_size):
super().__init__(args.engine_dir, args.model, args.dtype, rank,
world_size, args.serial_build)
self.batch_sizes = batch_sizes
self.in_lens = in_lens
self.build_time = 0
@ -49,12 +49,17 @@ class BERTBenchmark(BaseBenchmark):
setattr(self, key, value)
if args.force_num_layer_1:
self.num_layers = 1
if args.max_batch_size is not None:
self.max_batch_size = args.max_batch_size
if args.max_input_len is not None:
self.max_input_len = args.max_input_len
start = time.time()
engine_buffer = build_bert(args)
self.build_time = round(time.time() - start, 2)
engine_buffer, build_time = build_bert(args)
self.build_time = build_time
assert engine_buffer is not None
if args.build_only:
return
self.session = tensorrt_llm.runtime.Session.from_serialized_engine(
engine_buffer)

View File

@ -15,6 +15,7 @@
import argparse
import multiprocessing as mp
import os
import time
from collections import OrderedDict
import tensorrt as trt
@ -28,7 +29,7 @@ from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.layers import PositionEmbeddingType
from tensorrt_llm.logger import logger
from tensorrt_llm.models import quantize_model
from tensorrt_llm.models import PretrainedConfig, quantize_model
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.quantization import QuantMode
@ -117,6 +118,25 @@ def parse_arguments():
default=False,
action='store_true',
help='Quick sanity check with num_layer=1.')
parser.add_argument('--serial_build',
default=False,
action='store_true',
help="Build engines serially")
parser.add_argument(
'--rank',
type=int,
default=None,
help=
"The rank of the model to be built, only used when --serial_build is specified"
)
parser.add_argument(
'--world_size',
type=int,
default=None,
help=
"The number of gpus to be used for inference, only used when --serial_build is specified"
)
return parser.parse_args()
@ -194,9 +214,15 @@ def build_gpt(args):
build_config['num_layers'] = 1
# More parameters
world_size = tensorrt_llm.mpi_world_size()
runtime_rank = tensorrt_llm.mpi_rank()
torch.cuda.set_device(runtime_rank)
if args.serial_build and args.rank is not None and args.world_size is not None:
runtime_rank = args.rank
world_size = args.world_size
else:
runtime_rank = tensorrt_llm.mpi_rank()
world_size = tensorrt_llm.mpi_world_size()
if not args.serial_build:
torch.cuda.set_device(runtime_rank)
num_kv_heads = build_config['num_heads'] \
if build_config['num_kv_heads'] is None else build_config['num_kv_heads']
apply_query_key_layer_scaling = False
@ -265,18 +291,26 @@ def build_gpt(args):
moe_layer_config=tensorrt_llm.moe_config.MoeLayerConfig(
build_config["moe_num_experts"], build_config["moe_top_k"]))
elif family == "opt":
tensorrt_llm_model = tensorrt_llm.models.OPTLMHeadModel(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size), # TP only
pre_norm=build_config['pre_norm'],
do_layer_norm_before=build_config['do_layer_norm_before'])
config = {
'architecture': 'OPTForCausalLM',
'dtype': args.dtype,
'vocab_size': build_config['vocab_size'],
'hidden_size': build_config['hidden_size'],
'num_hidden_layers': build_config['num_layers'],
'num_attention_heads': build_config['num_heads'],
'hidden_act': build_config['hidden_act'],
'max_position_embeddings': build_config['n_positions'],
'mapping': {
'world_size': world_size,
'tp_size': world_size
},
'use_parallel_embedding': False,
'share_embedding_table': False,
'embedding_sharding_dim': 0,
'do_layer_norm_before': build_config['do_layer_norm_before']
}
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.OPTForCausalLM(config)
elif family == "llama":
tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(
num_layers=build_config['num_layers'],
@ -467,8 +501,10 @@ def build_gpt(args):
tensorrt_llm.graph_rewriting.optimize(network)
# Network -> Engine
start = time.time()
engine = builder.build_engine(network, builder_config)
assert engine is not None, f'Failed to build engine for rank {runtime_rank}'
build_time = round(time.time() - start, 2)
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
@ -478,7 +514,7 @@ def build_gpt(args):
config_path = os.path.join(args.output_dir, 'config.json')
builder_config.plugin_config = network.plugin_config
builder.save_config(builder_config, config_path)
return engine
return engine, build_time
def build_bert(args):
@ -487,9 +523,15 @@ def build_bert(args):
build_config['num_layers'] = 1
# More parameters
world_size = tensorrt_llm.mpi_world_size()
runtime_rank = tensorrt_llm.mpi_rank()
torch.cuda.set_device(runtime_rank)
if args.serial_build and args.rank is not None and args.world_size is not None:
runtime_rank = args.rank
world_size = args.world_size
else:
runtime_rank = tensorrt_llm.mpi_rank()
world_size = tensorrt_llm.mpi_world_size()
if not args.serial_build:
torch.cuda.set_device(runtime_rank)
num_kv_heads = build_config['num_heads'] \
if build_config['num_kv_heads'] is None else build_config['num_kv_heads']
max_batch_size = build_config['max_batch_size'] \
@ -575,8 +617,10 @@ def build_bert(args):
hidden_states.mark_output('hidden_states', hidden_states_dtype)
# Network -> Engine
start = time.time()
engine = builder.build_engine(network, builder_config)
assert engine is not None, f'Failed to build engine for rank {runtime_rank}'
build_time = round(time.time() - start, 2)
if args.output_dir is not None:
if not os.path.exists(args.output_dir):
@ -587,7 +631,7 @@ def build_bert(args):
config_path = os.path.join(args.output_dir, 'config.json')
builder_config.plugin_config = network.plugin_config
builder.save_config(builder_config, config_path)
return engine
return engine, build_time
def main(args):

View File

@ -0,0 +1,341 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import torch
from base_benchmark import BaseBenchmark, get_engine_name
import tensorrt_llm
from tensorrt_llm._utils import trt_dtype_to_torch
class EncDecBenchmark(BaseBenchmark):
def __init__(self, args, batch_sizes, in_out_lens, rank, world_size):
self.engine_dir = args.engine_dir
self.model_name = args.model
self.mode = args.mode
self.enable_fp8 = False # hardcode for enc-dec models
self.dtype = args.dtype
self.output_dir = args.output_dir
self.runtime_rank = rank
self.world_size = world_size
self.csv_filename = "" # lazy init
self.batch_sizes = batch_sizes
self.in_out_lens = in_out_lens
self.num_beams = args.num_beams
self.build_time = 0
# In the current implementation, the encoder and decoder share the same
# name, builder config, and plugin config, but they may differ in the
# future, so separate variables are used for the encoder and decoder here.
self.encoder_engine_model_name = args.model
self.decoder_engine_model_name = args.model
if self.engine_dir is not None:
def read_config(component):
config_path = os.path.join(self.engine_dir, component,
"config.json")
with open(config_path, "r") as f:
config = json.load(f)
# Sanity checks
config_dtype = config["builder_config"]["precision"]
assert (
self.dtype == config_dtype
), f"Engine dtype ({config_dtype}) != Runtime dtype ({self.dtype})"
world_size = config["builder_config"]["tensor_parallel"]
assert (
world_size == self.world_size
), f"Engine world size ({world_size}) != Runtime world size ({self.world_size})"
tp_size = config["builder_config"]["tensor_parallel"]
# TP only for benchmarking
assert (
tp_size == self.world_size
), f"Engine tensor parallel size ({tp_size}) should be equal to world size ({self.world_size})"
assert (
config["plugin_config"]["remove_input_padding"] == False
), "remove_input_padding should be False for enc-dec benchmarks"
num_heads = config["builder_config"]["num_heads"]
assert (num_heads % tp_size) == 0
# Get model config
num_heads = num_heads // tp_size
hidden_size = config["builder_config"]["hidden_size"] // tp_size
num_kv_heads = config["builder_config"].get(
"num_kv_heads", config["builder_config"]["num_heads"])
num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size
model_config = tensorrt_llm.runtime.ModelConfig(
num_heads=num_heads,
num_kv_heads=num_kv_heads,
hidden_size=hidden_size,
head_size=config["builder_config"]["head_size"],
vocab_size=config["builder_config"]["vocab_size"],
num_layers=config["builder_config"]["num_layers"],
gpt_attention_plugin=config["plugin_config"]
["gpt_attention_plugin"],
remove_input_padding=config["plugin_config"]
["remove_input_padding"],
cross_attention=config["builder_config"]["cross_attention"],
has_position_embedding=config["builder_config"]
["has_position_embedding"],
has_token_type_embedding=config["builder_config"]
["has_token_type_embedding"],
use_custom_all_reduce=config["plugin_config"].get(
"use_custom_all_reduce", False),
dtype=config_dtype,
)
# get builder config
builder_config = dict()
for key, value in config["builder_config"].items():
if key == "name":
engine_model_name = value
else:
builder_config[key] = value
# get plugin config
plugin_config = dict()
for key, value in config["plugin_config"].items():
# Same effect as self.use_foo_plugin = config.json["foo_plugin"]
if "plugin" in key:
key = "use_" + key
plugin_config[key] = value
return engine_model_name, model_config, builder_config, plugin_config
(
self.encoder_engine_model_name,
self.encoder_model_config,
self.encoder_builder_config,
self.encoder_plugin_config,
) = read_config("encoder")
(
self.decoder_engine_model_name,
self.decoder_model_config,
self.decoder_builder_config,
self.decoder_plugin_config,
) = read_config("decoder")
self.encoder_engine_name = get_engine_name(
self.encoder_engine_model_name,
self.dtype,
self.world_size,
self.runtime_rank,
)
self.decoder_engine_name = get_engine_name(
self.decoder_engine_model_name,
self.dtype,
self.world_size,
self.runtime_rank,
)
self.encoder_runtime_mapping = tensorrt_llm.Mapping(
world_size=self.world_size,
rank=self.runtime_rank,
tp_size=self.world_size,
gpus_per_node=self.encoder_builder_config.get("gpus_per_node", 8),
)
self.decoder_runtime_mapping = tensorrt_llm.Mapping(
world_size=self.world_size,
rank=self.runtime_rank,
tp_size=self.world_size,
gpus_per_node=self.encoder_builder_config.get("gpus_per_node", 8),
)
if not args.serial_build:
torch.cuda.set_device(self.runtime_rank %
self.encoder_runtime_mapping.gpus_per_node)
self.device = torch.cuda.current_device()
if self.engine_dir is not None:
# Deserialize engine from engine directory
self.encoder_serialize_path = os.path.join(self.engine_dir,
"encoder",
self.encoder_engine_name)
with open(self.encoder_serialize_path, "rb") as f:
encoder_engine_buffer = f.read()
self.decoder_serialize_path = os.path.join(self.engine_dir,
"decoder",
self.decoder_engine_name)
with open(self.decoder_serialize_path, "rb") as f:
decoder_engine_buffer = f.read()
else:
# TODO: Build engine
assert False, "Engine directory is currently required for enc-dec benchmarks"
encoder_engine_buffer = None
decoder_engine_buffer = None
assert encoder_engine_buffer is not None
assert decoder_engine_buffer is not None
# session setup
self.encoder_session = tensorrt_llm.runtime.Session.from_serialized_engine(
encoder_engine_buffer)
self.decoder_session = tensorrt_llm.runtime.GenerationSession(
self.decoder_model_config,
decoder_engine_buffer,
self.decoder_runtime_mapping,
)
def get_config(self):
max_batch_size = self.encoder_builder_config["max_batch_size"]
for inlen, outlen in self.in_out_lens:
if (inlen > self.encoder_builder_config["max_encoder_input_len"]
or outlen > self.encoder_builder_config["max_output_len"]):
print(
f"[WARNING] check inlen({inlen}) <= max_inlen({self.max_input_len}) and "
f"outlen({outlen}) <= max_outlen({self.max_output_len}) failed, skipping."
)
continue
for batch_size in self.batch_sizes:
if batch_size > max_batch_size:
print(
f"[WARNING] check batch_size({batch_size}) "
f"<= max_batch_size({max_batch_size}) failed, skipping."
)
continue
yield (batch_size, inlen, outlen)
def prepare_inputs(self, config):
batch_size, encoder_input_len = config[0], config[1]
encoder_input_ids = (torch.randint(
100, (batch_size, encoder_input_len)).int().cuda())
# For now, just hardcode the decoder_start_token_id to 0 for t5 models.
decoder_start_token_id = 0
decoder_input_ids = torch.IntTensor([[decoder_start_token_id]
]).to(self.device)
decoder_input_ids = decoder_input_ids.repeat(
(encoder_input_ids.shape[0], 1))
# In padding mode, keep the input and just compute the actual lengths and the max length.
# Note: the 1st token should always be counted, even if it is pad_token_id (0); e.g., the decoder start id in enc-dec models can be a single pad_token_id, which should still be counted.
encoder_input_lengths = ((1 + (encoder_input_ids[:, 1:] != 0).sum(
dim=1).type(torch.IntTensor).to(self.device)).clone().detach().to(
dtype=torch.int32, device=self.device))
decoder_input_lengths = ((1 + (decoder_input_ids[:, 1:] != 0).sum(
dim=1).type(torch.IntTensor).to(self.device)).clone().detach().to(
dtype=torch.int32, device=self.device))
stream = torch.cuda.current_stream().cuda_stream
return (
encoder_input_ids,
encoder_input_lengths,
decoder_input_ids,
decoder_input_lengths,
stream,
)
def run(self, inputs, config, benchmark_profiler=None):
output_len = config[2]
(
encoder_input_ids,
encoder_input_lengths,
decoder_input_ids,
decoder_input_lengths,
stream,
) = inputs
hidden_size = (self.encoder_model_config.hidden_size *
self.encoder_runtime_mapping.tp_size)
hidden_states_shape = (
encoder_input_ids.shape[0],
encoder_input_ids.shape[1],
hidden_size,
)
hidden_states_dtype = lambda name: trt_dtype_to_torch(
self.encoder_session.engine.get_tensor_dtype(name))
# input tensors
inputs = {}
inputs["input_ids"] = encoder_input_ids.contiguous()
inputs["input_lengths"] = encoder_input_lengths
inputs["max_input_length"] = torch.empty(
(self.encoder_builder_config["max_encoder_input_len"], ),
dtype=hidden_states_dtype("max_input_length"),
device=self.device,
).contiguous()
# output tensors
outputs = {}
outputs["encoder_output"] = torch.empty(
hidden_states_shape,
dtype=hidden_states_dtype("encoder_output"),
device=self.device,
).contiguous()
# run encoder
self.encoder_session.set_shapes(inputs)
ok = self.encoder_session.run(inputs, outputs, stream)
assert ok, "Runtime execution failed"
torch.cuda.synchronize()
# run decoder
sampling_config = tensorrt_llm.runtime.SamplingConfig(
end_id=1, pad_id=0, num_beams=self.num_beams, min_length=output_len)
self.decoder_session.setup(
decoder_input_lengths.size(0),
torch.max(decoder_input_lengths).item(),
output_len,
beam_width=self.num_beams,
max_kv_cache_length=None,
encoder_max_input_length=torch.max(encoder_input_lengths).item(),
)
torch.cuda.synchronize()
self.decoder_session.decode(
decoder_input_ids,
decoder_input_lengths,
sampling_config,
encoder_output=outputs["encoder_output"],
encoder_input_lengths=encoder_input_lengths,
)
torch.cuda.synchronize()
def report(self,
config,
latency,
percentile95,
percentile99,
peak_gpu_used,
csv,
benchmark_profiler=None):
# Note: theoretically, the encoder and decoder can have different configs,
# but the current implementation assumes they are the same. In the future,
# report_dict may get a dedicated structure for enc-dec models.
report_dict = super().get_report_dict()
batch_size, encoder_input_len, output_len = config[0], config[
1], config[2]
tokens_per_sec = round(batch_size * output_len / (latency / 1000), 2)
report_dict["num_heads"] = self.encoder_model_config.num_heads
report_dict["num_kv_heads"] = self.encoder_model_config.num_kv_heads
report_dict["num_layers"] = self.encoder_model_config.num_layers
report_dict["hidden_size"] = self.encoder_model_config.hidden_size
report_dict["vocab_size"] = self.encoder_model_config.vocab_size
report_dict["batch_size"] = batch_size
report_dict["input_length"] = encoder_input_len
report_dict["output_length"] = output_len
report_dict["latency(ms)"] = latency
report_dict["build_time(s)"] = self.build_time
report_dict["tokens_per_sec"] = tokens_per_sec
report_dict["percentile95(ms)"] = percentile95
report_dict["percentile99(ms)"] = percentile99
report_dict["gpu_peak_mem(gb)"] = peak_gpu_used
if self.runtime_rank == 0:
if csv:
line = ",".join([str(v) for v in report_dict.values()])
print(line)
with open(self.get_csv_filename(), "a") as file:
file.write(line + "\n")
else:
kv_pairs = [f"{k} {v}" for k, v in report_dict.items()]
line = "[BENCHMARK] " + " ".join(kv_pairs)
print(line)
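
For orientation, `benchmark.py` drives this class roughly as in the sketch below (a simplified, hypothetical driver; profiling and reporting are elided, and `benchmarker` is assumed to be an `EncDecBenchmark` constructed with a valid `--engine_dir` containing `encoder/` and `decoder/` engines):

```python
import torch


def run_enc_dec_benchmark(benchmarker, warm_up=2, num_runs=10):
    # get_config() yields the (batch_size, inlen, outlen) combinations that
    # fit within the built engines; each one is timed with CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for config in benchmarker.get_config():
        inputs = benchmarker.prepare_inputs(config)
        for _ in range(warm_up):          # warm-up iterations are not timed
            benchmarker.run(inputs, config)
        latencies = []
        for _ in range(num_runs):
            start.record()
            benchmarker.run(inputs, config)
            end.record()
            torch.cuda.synchronize()
            latencies.append(start.elapsed_time(end))  # milliseconds
        print(config, sum(latencies) / len(latencies))
```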

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from math import ceil
import torch
@ -27,8 +26,9 @@ from build import build_gpt, get_quant_mode # isort:skip
class GPTBenchmark(BaseBenchmark):
def __init__(self, args, batch_sizes, in_out_lens):
super().__init__(args.engine_dir, args.model, args.dtype)
def __init__(self, args, batch_sizes, in_out_lens, rank, world_size):
super().__init__(args.engine_dir, args.model, args.dtype, rank,
world_size, args.serial_build)
self.batch_sizes = batch_sizes
self.in_out_lens = in_out_lens
self.num_beams = args.num_beams
@ -49,6 +49,13 @@ class GPTBenchmark(BaseBenchmark):
setattr(self, key, value)
if args.force_num_layer_1:
self.num_layers = 1
if args.max_batch_size is not None:
self.max_batch_size = args.max_batch_size
if args.max_input_len is not None:
self.max_input_len = args.max_input_len
if args.max_output_len is not None:
self.max_output_len = args.max_output_len
self.quant_mode, _, _, _ = get_quant_mode(args.quantization)
self.enable_fp8 = self.quant_mode.has_fp8_qdq()
self.fp8_kv_cache = self.quant_mode.has_fp8_kv_cache()
@ -62,11 +69,12 @@ class GPTBenchmark(BaseBenchmark):
elif args.mode == 'ootb-except-mha':
self.use_gpt_attention_plugin = True
start = time.time()
engine_buffer = build_gpt(args)
self.build_time = round(time.time() - start, 2)
engine_buffer, build_time = build_gpt(args)
self.build_time = build_time
assert engine_buffer is not None
if args.build_only:
return
if not hasattr(self, 'num_kv_heads') or self.num_kv_heads is None:
self.num_kv_heads = self.num_heads
@ -194,7 +202,8 @@ class GPTBenchmark(BaseBenchmark):
'generation_step_count')
token_per_step = batch_size * self.num_beams
total_tokens = generation_step_count * token_per_step
report_dict["generation_time(ms)"] = generation_time_ms / iter_count
report_dict["generation_time(ms)"] = round(
generation_time_ms / iter_count, 3)
report_dict["total_generated_tokens"] = total_tokens / iter_count
tokens_per_second = round(
total_tokens * 1000.0 / generation_time_ms, 3)

View File

@ -17,7 +17,6 @@
#pragma once
#include "tensorrt_llm/batch_manager/namedTensor.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
@ -57,6 +56,9 @@ auto constexpr kReturnLogProbsTensorName = "return_log_probs";
auto constexpr kPromptEmbeddingTableName = "prompt_embedding_table";
auto constexpr kPromptVocabSizeName = "prompt_vocab_size";
// Obsolete names for backward compatibility
auto constexpr kInputLengthsTensorName = "input_lengths";
// Output tensors
auto constexpr kOutputIdsTensorName = "output_ids";
auto constexpr kSequenceLengthTensorName = "sequence_length";
@ -131,6 +133,8 @@ public:
inference_request::kReturnLogProbsTensorName,
inference_request::kPromptEmbeddingTableName,
inference_request::kPromptVocabSizeName,
// obsolete names for backward compatibility
inference_request::kInputLengthsTensorName,
};
#define TENSOR_GETTER_SETTER(funcName, tensorName) \
@ -191,11 +195,8 @@ public:
protected:
static void validateTensorName(std::string const& tensorName)
{
// TODO (martinma): Throw an exception if the tensor name is not valid.
if (std::find(kTensorNames.begin(), kTensorNames.end(), tensorName) == kTensorNames.end())
{
TLLM_LOG_WARNING("Invalid tensor name in InferenceRequest: %s", tensorName.c_str());
}
TLLM_CHECK_WITH_INFO(std::find(kTensorNames.begin(), kTensorNames.end(), tensorName) != kTensorNames.end(),
"Invalid tensor name: %s", tensorName.c_str());
}
uint64_t mRequestId;

View File

@ -30,17 +30,19 @@ public:
using SizeType = tensorrt_llm::runtime::SizeType;
explicit KvCacheConfig(std::optional<SizeType> maxTokens = std::nullopt,
std::optional<SizeType> maxKvCacheLength = std::nullopt,
std::optional<float> freeGpuMemoryFraction = std::nullopt)
std::optional<SizeType> maxAttentionWindow = std::nullopt,
std::optional<float> freeGpuMemoryFraction = std::nullopt, bool enableBlockReuse = false)
: maxTokens{maxTokens}
, maxKvCacheLength{maxKvCacheLength}
, maxAttentionWindow{maxAttentionWindow}
, freeGpuMemoryFraction{freeGpuMemoryFraction}
, enableBlockReuse(enableBlockReuse)
{
}
std::optional<SizeType> maxTokens;
std::optional<SizeType> maxKvCacheLength;
std::optional<SizeType> maxAttentionWindow;
std::optional<float> freeGpuMemoryFraction;
bool enableBlockReuse;
static constexpr auto kDefaultGpuMemFraction = 0.85f;
};

View File

@ -17,8 +17,9 @@
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/llmRequest.h" // TODO forward declare
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
#include "tensorrt_llm/runtime/iTensor.h"
@ -26,14 +27,52 @@
#include <NvInferRuntime.h>
#include <cstdint>
#include <functional>
#include <list>
#include <memory>
#include <optional>
#include <unordered_map>
#include <vector>
namespace std
{
// Implement std::hash function object for vector<TokenIdType>.
// This allows us to use unordered_map with vector<TokenIdType> as key.
// Based on https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector/72073933#72073933
template <>
struct hash<vector<int32_t>>
{
size_t operator()(vector<int32_t> const& vec) const noexcept
{
size_t seed = vec.size();
for (auto x : vec)
{
uint32_t y = static_cast<uint32_t>(x);
y = ((y >> 16) ^ y) * 0x45d9f3b;
y = ((y >> 16) ^ y) * 0x45d9f3b;
y = (y >> 16) ^ y;
seed ^= y + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};
} // namespace std
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class KVCacheBlock;
using SizeType = tensorrt_llm::runtime::SizeType;
using TokenIdType = tensorrt_llm::runtime::TokenIdType;
using VecTokens = std::vector<TokenIdType>;
using BeamTokens = std::vector<VecTokens>;
using BlockPtr = std::shared_ptr<KVCacheBlock>;
using FreeBlocksQueue = std::list<BlockPtr>;
using NextBlockMap = std::unordered_map<VecTokens, BlockPtr>;
struct KvCacheStats
{
@ -49,8 +88,6 @@ struct KvCacheStats
class KVCacheBlock
{
public:
using SizeType = tensorrt_llm::runtime::SizeType;
explicit KVCacheBlock(SizeType blockIdx);
void startScheduling();
@ -67,6 +104,29 @@ public:
[[nodiscard]] bool hasSchedulingRefs() const;
void setTokens(VecTokens& tokens, bool isFull);
[[nodiscard]] VecTokens const& getTokens() const;
void setFreeBlockIterator(FreeBlocksQueue::iterator freeBlockIterator);
void resetFreeBlockIterator();
[[nodiscard]] std::optional<FreeBlocksQueue::iterator> const& getFreeBlockIterator() const;
void setPrevBlock(BlockPtr prevBlock);
void addNextBlock(VecTokens const& tokens, BlockPtr block);
void removeNextBlock(VecTokens const& tokens);
[[nodiscard]] BlockPtr findMatchingBlock(VecTokens const& tokens) const;
//! \brief Free block from previous block if present.
void freeLeafBlock();
[[nodiscard]] bool isFull() const;
private:
// Linear index of block in pool
SizeType mBlockIdx;
@ -76,6 +136,21 @@ private:
// Number of references to the block
SizeType mSchedulingRefCount;
// Key of this block in the mNextBlocks map of the block pointed to by mPrevBlock
VecTokens mTokens;
// Previous block in sequence
BlockPtr mPrevBlock;
// Next block(s) in sequence(s)
NextBlockMap mNextBlocks;
// Iterator pointing to this block in mFreeBlocks.
std::optional<FreeBlocksQueue::iterator> mFreeBlockIterator;
// Flag indicating if block is full
bool mIsFull;
};
class GenerationRequest
@ -84,32 +159,22 @@ public:
using SizeType = tensorrt_llm::runtime::SizeType;
using SharedPtr = std::shared_ptr<GenerationRequest>;
GenerationRequest(SizeType batchSlotIdx, SizeType numTokens, SizeType beamWidth)
: mBatchSlotIdx(batchSlotIdx)
explicit GenerationRequest(SizeType seqSlotIdx, SizeType numTokens, SizeType beamWidth)
: mSeqSlotIdx(seqSlotIdx)
, mNumTokens(numTokens)
, mBeamWidth(beamWidth)
, mCacheBlockIds(beamWidth)
{
}
void setBatchSlotIdx(SizeType batchSlotIdx)
{
mBatchSlotIdx = batchSlotIdx;
}
void setNumTokens(SizeType numTokens)
{
mNumTokens = numTokens;
}
void addToken()
{
mNumTokens++;
}
[[nodiscard]] SizeType getBatchSlotIdx() const
[[nodiscard]] SizeType getSequenceSlotIdx() const
{
return mBatchSlotIdx;
return mSeqSlotIdx;
}
[[nodiscard]] SizeType getNumTokens() const
@ -140,15 +205,29 @@ public:
}
}
void setNumPrepopulatedTokens(std::vector<int> numPrepopulatedTokens)
{
mNumPrepopulatedTokens = std::move(numPrepopulatedTokens);
}
[[nodiscard]] std::vector<int> const& getNumPrepopulatedTokens() const
{
return mNumPrepopulatedTokens;
}
private:
// Index of sequence in the batch
SizeType mBatchSlotIdx;
// Slot id of the sequence
SizeType mSeqSlotIdx;
// Current number of generated tokens
SizeType mNumTokens;
// Number of beams
SizeType mBeamWidth;
// List of blocks allocated for each beam of the sequence
std::vector<std::vector<SizeType>> mCacheBlockIds;
// Number of tokens already in kv cache before context phase.
// A value > 0 indicates cached kv cache blocks were reused.
// One value per beam.
std::vector<int> mNumPrepopulatedTokens;
};
// BlockManager manages overall metadata of KVCacheBlocks in a layer of the
@ -161,23 +240,34 @@ private:
// Block shape is [2, num_heads, tokens_per_block, head_size].
// BlockManager maintains a list of free blocks at any time.
// Alloc pops off the block at the front, and Free pushes it back to the vector.
// BlockManager maintains a vector of lists of batchSlotIdx to allocated blocks
// BlockManager maintains a vector of lists of seqSlotIdx to allocated blocks
// per sequence. This can be used to Free all blocks belonging to a sequence.
class BlockManager
{
public:
using SizeType = tensorrt_llm::runtime::SizeType;
explicit BlockManager(SizeType blocksInPool);
explicit BlockManager(SizeType blocksInPool, SizeType tokensPerBlock);
~BlockManager();
void startScheduling();
//! \brief Assign blocks for new sequence. Try to reuse blocks.
void addSequence(GenerationRequest& sequence, std::shared_ptr<LlmRequest> const& llmRequest);
//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
void addSequence(GenerationRequest& sequence, SizeType inputLength, bool enableCyclicKvCache);
//! \brief Allocate new block for each beam of the sequence.
//! \details Might free cached blocks if no free blocks are available.
void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams = false);
void freeAllBlocks(GenerationRequest& sequence);
//! \brief Release blocks of the sequence. Store blocks for reuse if llmRequest is provided.
void releaseBlocks(GenerationRequest& sequence, std::shared_ptr<LlmRequest> const& llmRequest = nullptr);
// Simulate freeing all blocks for that sequence to check impact on number of free blocks
void schedulingFreeAllBlocks(GenerationRequest& sequence);
//! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
void schedulingReleaseBlocks(GenerationRequest& sequence);
[[nodiscard]] SizeType getNumFreeBlocks() const
{
@ -186,7 +276,7 @@ public:
[[nodiscard]] SizeType getNumAllocatedBlocks() const
{
return mAllocatedBlocks.size();
return getMaxNumBlocks() - getNumFreeBlocks();
}
[[nodiscard]] bool hasFreeBlocks(SizeType numRequired = 1) const
@ -199,13 +289,58 @@ public:
return mSchedulingNumFreeBlocks >= numRequired;
}
[[nodiscard]] SizeType getMaxNumBlocks() const
{
return static_cast<SizeType>(mAllBlocksByIdx.size());
}
[[nodiscard]] SizeType getTokensPerBlock() const
{
return mTokensPerBlock;
}
private:
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType beamIdx, SizeType seqSlotIdx);
//! \brief Store blocks in cached blocks.
//! \param blockedTokens Tokens of each block.
//! \param blockIds Id of each block.
void storeBlocks(std::list<VecTokens> blockedTokens, std::vector<SizeType> const& blockIds);
//! \brief Try to load blocks from cache. Allocate new blocks if necessary.
//! \param blockedTokens Tokens of each block.
//! \param sequence Sequence to which blocks are assigned.
//! \param beamIdx Beam of sequence to which blocks are assigned.
//! \param seqSlotIdx Sequence slot of sequence to which blocks are assigned.
//! \return Number of matched tokens from loaded blocks.
SizeType loadOrAllocateBlocks(
std::list<VecTokens> blockedTokens, GenerationRequest& sequence, SizeType beamIdx, SizeType seqSlotIdx);
//! \brief Find block least likely to be reused, free it if necessary and return.
[[nodiscard]] BlockPtr getFreeBlock();
//! \brief Claim block if it is in free blocks list.
void claimBlock(KVCacheBlock& block);
//! \brief Free block from previous block and claim it from free blocks list.
void claimLeafBlock(KVCacheBlock& block);
private:
// List of free blocks
std::list<KVCacheBlock> mFreeBlocks;
FreeBlocksQueue mFreeBlocks;
// List of allocated blocks for each sequences
std::vector<std::vector<KVCacheBlock>> mAllocatedBlocks;
std::vector<std::vector<BlockPtr>> mAllocatedBlocksPerSeq;
// Used to keep track of number of free blocks during scheduling
SizeType mSchedulingNumFreeBlocks;
// Number of tokens per block
SizeType mTokensPerBlock;
// List of all blocks by idx
std::vector<BlockPtr> mAllBlocksByIdx;
// Dummy block acting as root for BlockToken searches
BlockPtr mCachedBlocksRoot;
// Statistics for block allocations/reuse
std::size_t mAllocTotalBlocks, mAllocNewBlocks, mReusedBlocks;
};
class KVCacheManager
@ -216,19 +351,20 @@ public:
using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
KVCacheManager(SizeType numLayers, SizeType numHeads, SizeType numKvHeads, SizeType hiddenSize,
SizeType tokensPerBlock, SizeType maxNumBlocks, SizeType maxBatchSize, SizeType maxBeamWidth,
SizeType maxBlocksPerSeq, SizeType maxKvCacheLength, nvinfer1::DataType dtype, CudaStreamPtr stream);
SizeType tokensPerBlock, SizeType maxNumBlocks, SizeType maxNumSequences, SizeType maxBeamWidth,
SizeType maxBlocksPerSeq, SizeType maxAttentionWindow, nvinfer1::DataType dtype, CudaStreamPtr stream,
bool enableBlockReuse = false);
void startScheduling();
[[nodiscard]] SizeType getTokensPerBlock() const
{
return mTokensPerBlock;
return mBlockManager.getTokensPerBlock();
}
[[nodiscard]] SizeType getMaxNumBlocks() const
{
return mMaxNumBlocks;
return mBlockManager.getMaxNumBlocks();
}
[[nodiscard]] SizeType getUsedNumBlocks() const
@ -267,32 +403,33 @@ public:
/// iterations
/// @param req The request for which we need to calculate the number of needed KV cache blocks
/// @return The number of blocks
SizeType getNeededBlocksOneStep(const LlmRequest& req, bool twoStepsLookAhead) const;
SizeType getNeededBlocksOneStep(LlmRequest const& req, bool twoStepsLookAhead) const;
/// @brief Function that computes the number of KV cache blocks needed to advance a request to completion (i.e. for
/// maxNewTokens)
/// @param req The request for which we need to calculate the number of needed KV cache blocks
/// @return The number of blocks
SizeType getNeededBlocksToCompletion(const LlmRequest& req) const;
SizeType getNeededBlocksToCompletion(LlmRequest const& req) const;
[[nodiscard]] std::vector<runtime::ITensor::SharedPtr> const& getMemoryPools() const
{
return mPools;
}
void addToken(SizeType batchSlotIdx);
void addToken(SizeType seqSlotIdx);
void addSequence(SizeType batchSlotIdx, SizeType inputLength, SizeType beamWidth);
void addSequence(SizeType seqSlotIdx, SizeType inputLength, SizeType beamWidth,
std::shared_ptr<LlmRequest> const& llmRequest = nullptr);
void removeSequence(SizeType batchSlotIdx);
void removeSequence(SizeType seqSlotIdx, std::shared_ptr<LlmRequest> const& llmRequest = nullptr);
void schedulingRemoveSequence(SizeType batchSlotIdx);
void schedulingRemoveSequence(SizeType seqSlotIdx);
void getBlockPointersOfBatch(
runtime::ITensor& dstPointers, SizeType firstBatchSlotIdx, SizeType batchSize, SizeType beamWidth) const;
void copyBlockPointers(
runtime::ITensor& dstPointers, SizeType dstSlotOffset, SizeType batchSlotIdx, SizeType beamWidth) const;
runtime::ITensor& dstPointers, SizeType dstSlotOffset, SizeType seqSlotIdx, SizeType beamWidth) const;
// Volume of [2, numKvHeads, tokensPerBlock, sizePerHead]
[[nodiscard]] static SizeType constexpr calculatePageSize(tensorrt_llm::runtime::GptModelConfig const& modelConfig)
@ -313,26 +450,22 @@ public:
runtime::BufferManager const& bufferManager);
private:
void resetBlockPointers(SizeType batchSlotIdx, SizeType beamWidth);
void cacheNewBlockPointer(const GenerationRequest& seq, SizeType batchSlotIdx);
void resetBlockPointers(SizeType seqSlotIdx, SizeType beamWidth);
void cacheBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
void cacheNewBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
private:
// Number of elements per block
SizeType mBlockSize;
// Number of tokens per block
SizeType mTokensPerBlock;
// Total maximum number of blocks
SizeType mMaxNumBlocks;
// Maximum size of batch
SizeType mMaxBatchSize;
// Maximum number of sequences
SizeType mMaxNumSequences;
// Maximum beam width
SizeType mMaxBeamWidth;
// Maximum number of blocks per sequence
SizeType mMaxBlocksPerSeq;
// Maximum kv cache length per sequence
// Enable cyclic kv cache when it exceeds
SizeType mMaxKvCacheLength;
SizeType mMaxAttentionWindow;
// Pools
std::vector<runtime::ITensor::SharedPtr> mPools;
// Block manager
@ -341,7 +474,9 @@ private:
std::vector<SequencesPtr> mSequences;
// buffer for block pointers for all managed sequences
runtime::ITensor::SharedPtr mSequenceBlockPointers;
runtime::BufferManager mManager;
// Buffer manager
runtime::BufferManager mBufferManager;
// Whether to cache KV pages for reuse
bool mEnableBlockReuse;
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
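
To make the block sizing described above concrete: each cache block holds keys and values for `tokens_per_block` tokens of every KV head of one layer, i.e. a volume of `[2, num_kv_heads, tokens_per_block, head_size]` elements (matching the `calculatePageSize` comment). A back-of-the-envelope sketch with illustrative numbers, not taken from this patch:

```python
def kv_cache_block_bytes(num_kv_heads: int, tokens_per_block: int,
                         head_size: int, bytes_per_elem: int = 2) -> int:
    # Volume of [2, num_kv_heads, tokens_per_block, head_size]; the factor 2
    # accounts for keys and values being stored in the same block.
    return 2 * num_kv_heads * tokens_per_block * head_size * bytes_per_elem


# Illustrative numbers: 8 KV heads, 64 tokens per block, head size 128, fp16.
per_layer_block = kv_cache_block_bytes(8, 64, 128)   # 262,144 bytes per block
per_token_per_layer = per_layer_block // 64           # 4,096 bytes per token
# With maxAttentionWindow tokens kept per sequence, each beam of a sequence
# needs roughly ceil(maxAttentionWindow / tokens_per_block) such blocks per layer.
print(per_layer_block, per_token_per_layer)
```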

View File

@ -48,11 +48,10 @@ public:
using BeamTokens = std::vector<VecTokens>;
using TensorPtr = TTensor;
GenericLlmRequest(RequestIdType requestId, SizeType maxNewTokens,
std::shared_ptr<std::vector<TokenIdType>> inputTokens, runtime::SamplingConfig samplingConfig, bool isStreaming,
std::optional<SizeType> endId = std::nullopt, std::optional<SizeType> padId = std::nullopt,
std::optional<TensorPtr> embeddingBias = std::nullopt, std::optional<TensorPtr> badWordsList = std::nullopt,
std::optional<TensorPtr> stopWordsList = std::nullopt,
GenericLlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false,
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
@ -65,7 +64,7 @@ public:
, mIsStreaming(isStreaming)
, mEndId(endId)
, mPadId(padId)
, mBatchSlot(-1)
, mSeqSlot(-1)
, mOrigPromptLen(inputTokens->size())
, mEmbeddingBias(embeddingBias)
, mBadWordsList(badWordsList)
@ -132,14 +131,21 @@ public:
/// @brief Get the tokens at a given beam index
/// @param beam The beam index
/// @return A vector of tokens for this beam index, includes the prompt
std::vector<TokenIdType> const& getTokens(SizeType beam) const
VecTokens const& getTokens(SizeType beam) const
{
return mTokens.at(beam);
}
/// @brief Get all tokens (input+output) for all beams
/// @return A vector of vector of tokens.
BeamTokens const& getTokens() const
{
return mTokens;
}
/// @brief Get the draft tokens
/// @return shared_ptr to vector of draft tokens
std::shared_ptr<std::vector<TokenIdType>> const& getDraftTokens() const
std::shared_ptr<VecTokens> const& getDraftTokens() const
{
return mDraftTokens;
}
@ -176,7 +182,7 @@ public:
/// @brief Add new generated tokens to the vector of tokens
/// @param beamTokens A vector containing the tokens to add for each beam index
/// beamTokens is expected to be of size beamWidth
void addNewTokens(const std::vector<TokenIdType>& beamTokens)
void addNewTokens(VecTokens const& beamTokens)
{
assert(static_cast<size_t>(mSamplingConfig.beamWidth) == beamTokens.size());
for (std::size_t beam = 0; beam < beamTokens.size(); ++beam)
@ -236,7 +242,7 @@ public:
mPromptLen = newPromptLen;
}
mState = REQUEST_STATE_CONTEXT_INIT;
mBatchSlot = -1;
mSeqSlot = -1;
}
/// @brief Get the maximum position of the tokens returned to the client. Use to ensure we don't return to
@ -335,7 +341,7 @@ public:
bool mIsStreaming;
std::optional<SizeType> mEndId;
std::optional<SizeType> mPadId;
SizeType mBatchSlot;
SizeType mSeqSlot;
protected:
SizeType mOrigPromptLen;
@ -369,7 +375,7 @@ public:
using BeamTokens = Base::BeamTokens;
using VecTokens = Base::VecTokens;
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<std::vector<TokenIdType>> inputTokens,
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,

View File

@ -34,11 +34,13 @@ public:
explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
std::optional<SizeType> maxNumSequences = std::nullopt, bool enableTrtOverlap = true,
bool useContextFMHAForGeneration = false)
bool useContextFMHAForGeneration = false,
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds = std::nullopt)
: kvCacheConfig{kvCacheConfig}
, maxNumSequences{maxNumSequences}
, enableTrtOverlap{enableTrtOverlap}
, useContextFMHAForGeneration(useContextFMHAForGeneration)
, userSpecifiedDeviceIds(userSpecifiedDeviceIds)
{
}
@ -46,6 +48,7 @@ public:
std::optional<SizeType> maxNumSequences;
bool enableTrtOverlap;
bool useContextFMHAForGeneration;
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds;
};
} // namespace tensorrt_llm::batch_manager

View File

@ -29,10 +29,11 @@ class DecodingInput
public:
using TensorPtr = std::shared_ptr<ITensor const>;
DecodingInput(SizeType maxLength, SizeType maxKvCacheLength, SizeType batchSize, TensorPtr logits, TensorPtr endIds)
DecodingInput(
SizeType maxLength, SizeType maxAttentionWindow, SizeType batchSize, TensorPtr logits, TensorPtr endIds)
: step{maxLength}
, maxLength{maxLength}
, maxKvCacheLength{maxKvCacheLength}
, maxAttentionWindow{maxAttentionWindow}
, batchSize{batchSize}
, logits{std::move(logits)}
, endIds{std::move(endIds)}
@ -44,7 +45,7 @@ public:
// mandatory parameters
SizeType step;
SizeType maxLength;
SizeType maxKvCacheLength;
SizeType maxAttentionWindow;
SizeType batchSize;
TensorPtr logits; // [batchSize, beamWidth, vocabSizePadded], on gpu
TensorPtr endIds; // [batchSize * beamWidth], on gpu

View File

@ -45,7 +45,7 @@ public:
GptDecoderBatch(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream);
//! Setup the decoder before calling `forward()`
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength,
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength,
SizeType maxTokensPerStep, nvinfer1::DataType dtype) override;
//! @brief Initialize the decoder at `batchIdx` with a new `request`.
@ -200,7 +200,7 @@ private:
TensorPtr mTargetProbs; // [batchSize, maxDraftTokens+1, beamWidth, vocabPadded], temporary data for speculative
// decoding accept by logits kernel, on gpu
SizeType mMaxSequenceLength{};
SizeType mMaxKvCacheLength{};
SizeType mMaxAttentionWindow{};
SizeType mActualBatchSize{};
SizeType mMaxTokensPerStep{};
};

View File

@ -32,9 +32,10 @@ namespace tensorrt_llm::runtime
class GptJsonConfig
{
public:
GptJsonConfig(std::string name, std::string precision, SizeType tensorParallelism, SizeType pipelineParallelism,
GptModelConfig const& modelConfig)
GptJsonConfig(std::string name, std::string version, std::string precision, SizeType tensorParallelism,
SizeType pipelineParallelism, GptModelConfig const& modelConfig)
: mName(std::move(name))
, mVersion(std::move(version))
, mPrecision(std::move(precision))
, mTensorParallelism{tensorParallelism}
, mPipelineParallelism{pipelineParallelism}
@ -58,6 +59,11 @@ public:
return mName;
}
[[nodiscard]] std::string const& getVersion() const
{
return mVersion;
}
[[nodiscard]] std::string const& getPrecision() const
{
return mPrecision;
@ -87,6 +93,7 @@ public:
private:
std::string const mName;
std::string const mVersion;
std::string const mPrecision;
SizeType const mTensorParallelism;
SizeType const mPipelineParallelism;
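
A hedged sketch of the new getVersion() accessor next to the existing ones; how the GptJsonConfig instance is obtained (e.g. parsed from an engine's config.json) is assumed and not shown in this hunk.

// Sketch only: assumes a GptJsonConfig instance named jsonConfig and #include <iostream>.
std::cout << "engine '" << jsonConfig.getName() << "', version " << jsonConfig.getVersion()
          << ", precision " << jsonConfig.getPrecision() << std::endl;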

View File

@ -46,6 +46,7 @@ public:
, mTokensPerBlock{64}
, mQuantMode{common::QuantMode::none()}
, mMaxBatchSize(0)
, mMaxBeamWidth(0)
, mMaxInputLen(0)
, mMaxOutputLen(0)
, mMaxNumTokens(std::nullopt)
@ -169,6 +170,16 @@ public:
mMaxBatchSize = maxBatchSize;
}
[[nodiscard]] SizeType constexpr getMaxBeamWidth() const noexcept
{
return mMaxBeamWidth;
}
void constexpr setMaxBeamWidth(SizeType maxBeamWidth) noexcept
{
mMaxBeamWidth = maxBeamWidth;
}
[[nodiscard]] SizeType constexpr getMaxInputLen() const noexcept
{
return mMaxInputLen;
@ -259,6 +270,11 @@ public:
mMaxDraftLen = maxDraftLen;
}
[[nodiscard]] SizeType getMaxDraftLen() const
{
return mMaxDraftLen;
}
[[nodiscard]] SizeType constexpr getMaxTokensPerStep() const noexcept
{
return mMaxDraftLen + 1;
@ -277,6 +293,7 @@ private:
SizeType mTokensPerBlock;
common::QuantMode mQuantMode;
SizeType mMaxBatchSize;
SizeType mMaxBeamWidth;
SizeType mMaxInputLen;
SizeType mMaxOutputLen;
std::optional<SizeType> mMaxNumTokens;
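
The new accessors above expose getMaxBeamWidth()/setMaxBeamWidth() and getMaxDraftLen(), with getMaxTokensPerStep() returning maxDraftLen + 1 — presumably the speculative-decoding draft tokens plus one regular token per step. A standalone sketch of that relation:

#include <cassert>

int maxTokensPerStep(int maxDraftLen)
{
    // Mirrors GptModelConfig::getMaxTokensPerStep() above: one regular token plus the draft tokens.
    return maxDraftLen + 1;
}

int main()
{
    assert(maxTokensPerStep(0) == 1); // no speculative decoding: one token per step
    assert(maxTokensPerStep(4) == 5); // four draft tokens handled per step
    return 0;
}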

View File

@ -143,9 +143,9 @@ private:
void createContexts(SizeType numBatchesCtx, SizeType numBatchesGen, bool useCudaGraphs);
void createBuffers(SizeType numMicroBatches);
void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength,
void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength,
nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches);
void createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength,
void createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow,
SizeType maxSequenceLength, KvCacheConfig const& config);
void createCustomAllReduceWorkspace(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength);
@ -261,7 +261,7 @@ private:
std::vector<std::shared_ptr<IpcMemory>> mIpcMemoryHandles;
SizeType mDecoderMaxSequenceLength{};
SizeType mDecoderMaxKvCacheLength{};
SizeType mDecoderMaxAttentionWindow{};
LoggerPtr mLogger;
std::shared_ptr<TllmRuntime> mRuntime;

View File

@ -74,7 +74,7 @@ public:
using TensorPtr = std::shared_ptr<ITensor>;
//! Setup the decoder before calling `forward()`, also calls reshapeBuffers
virtual void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength,
virtual void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
= 0;

View File

@ -29,12 +29,13 @@ class WorldConfig
public:
static SizeType constexpr kDefaultGpusPerNode = 8;
constexpr explicit WorldConfig(SizeType tensorParallelism = 1, SizeType pipelineParallelism = 1, SizeType rank = 0,
SizeType gpusPerNode = kDefaultGpusPerNode)
explicit WorldConfig(SizeType tensorParallelism = 1, SizeType pipelineParallelism = 1, SizeType rank = 0,
SizeType gpusPerNode = kDefaultGpusPerNode, std::vector<SizeType> deviceIds = {})
: mTensorParallelism{tensorParallelism}
, mPipelineParallelism{pipelineParallelism}
, mRank{rank}
, mGpusPerNode{gpusPerNode}
, mDeviceIds{deviceIds}
{
}
@ -73,8 +74,12 @@ public:
return mGpusPerNode;
}
[[nodiscard]] SizeType constexpr getDevice() const noexcept
[[nodiscard]] SizeType getDevice() const noexcept
{
if (mDeviceIds.size())
{
return mDeviceIds[mRank % mGpusPerNode];
}
return mRank % mGpusPerNode;
}
@ -110,17 +115,20 @@ public:
static WorldConfig mpi(nvinfer1::ILogger& logger, SizeType gpusPerNode = kDefaultGpusPerNode,
std::optional<SizeType> tensorParallelism = std::nullopt,
std::optional<SizeType> pipelineParallelism = std::nullopt);
std::optional<SizeType> pipelineParallelism = std::nullopt,
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds = std::nullopt);
static WorldConfig mpi(SizeType gpusPerNode = kDefaultGpusPerNode,
std::optional<SizeType> tensorParallelism = std::nullopt,
std::optional<SizeType> pipelineParallelism = std::nullopt);
std::optional<SizeType> pipelineParallelism = std::nullopt,
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds = std::nullopt);
private:
SizeType mTensorParallelism;
SizeType mPipelineParallelism;
SizeType mRank;
SizeType mGpusPerNode;
std::vector<SizeType> mDeviceIds;
};
} // namespace tensorrt_llm::runtime
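
The reworked getDevice() above maps an MPI rank to a CUDA device either through the new user-supplied list or through the default modulo rule. A standalone sketch of that selection (function and variable names are local to the sketch):

#include <cstdio>
#include <vector>

// Mirrors WorldConfig::getDevice(): prefer a user-specified device list when one was given.
// Like the original, it indexes with rank % gpusPerNode, so the list should cover one node's ranks.
int selectDevice(int rank, int gpusPerNode, const std::vector<int>& deviceIds)
{
    if (!deviceIds.empty())
    {
        return deviceIds[rank % gpusPerNode];
    }
    return rank % gpusPerNode;
}

int main()
{
    std::printf("default mapping : rank 1 -> GPU %d\n", selectDevice(1, 8, {}));     // prints 1
    std::printf("user mapping    : rank 1 -> GPU %d\n", selectDevice(1, 8, {4, 5})); // prints 5
    return 0;
}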

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba982afff27c597c9f5f25bec4ed37debd883c7be2107b47776a014075899fbd
size 1719266
oid sha256:7d9f7d0f7dee2c48a424ff8873c2fd1298a27850f870657734641f2eb1190faf
size 1791038

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04ec1f2f45dde1ef6b6b0f605e79715eebed38b19b4d833fcb668d2cb71f8a03
size 1733118
oid sha256:fa79a0d563fc01a0cb2fe94dcb626ff4e5b736284d9244313cbe7aa0261dd48e
size 1806500

View File

@ -1,3 +1,3 @@
aab384dfc59de5df4c7ecf53e30d03e9 libtensorrt_llm_batch_manager_static.a
e0074afa6959c896f1cbc7ab90872058 libtensorrt_llm_batch_manager_static.pre_cxx11.a
c7450aa071e91659a3e2855c0cca21021f96ada8 commit
d9723ab671c9fc3889cc624a58def81a libtensorrt_llm_batch_manager_static.a
4b6773c990e8a59f1c716d88505b84a2 libtensorrt_llm_batch_manager_static.pre_cxx11.a
9a136bb59c51bbae09221c1667e23529ed05c752 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:546c9e2b79cb3cf2623876902ef2d40c65925157d43850b2505eedf274e060a1
size 1638840
oid sha256:6a7b872fe6ee63a4342c3cd17b3557d74c72e537dbf0d4ddf132a2c40e000e57
size 1709462

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:935a706ce0d107f8c226566a50946a0f0e35ce926c98b7a12b000b3d72e5f0b6
size 1635602
oid sha256:c83f7c0e4fc22b32df669ada2b99b88f0f7faac935a251fe7a20030e2b364cc8
size 1705432

View File

@ -1,2 +1,2 @@
8e0c5b31d579f4118b84a34ffb00c15a libtensorrt_llm_batch_manager_static.a
885ea7b9f594d7aa9cc9018527b95f6d libtensorrt_llm_batch_manager_static.pre_cxx11.a
583141c3003a08acebc7054d024bee89 libtensorrt_llm_batch_manager_static.a
03e5360f9b8074b8273500898581212f libtensorrt_llm_batch_manager_static.pre_cxx11.a

View File

@ -102,12 +102,12 @@ struct Multihead_attention_params_base
int batch_size = 0;
// The beam width
int beam_width = 0;
// By default, max_kv_cache_length == cyclic_kv_cache_length
// By default, max_attention_window_size == cyclic_attention_window_size
// unless each layer has a different cyclic kv cache length.
// Max cache capacity (used to allocate KV cache)
int max_kv_cache_length = 0;
int max_attention_window_size = 0;
// Cyclic kv cache capacity (used to get the cyclic kv cache position for new tokens)
int cyclic_kv_cache_length = 0;
int cyclic_attention_window_size = 0;
// The number of heads (H).
int num_heads = 0;
// Controls MHA/MQA/GQA

View File

@ -19,15 +19,27 @@ namespace tensorrt_llm
namespace kernels
{
// clang-format off
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin[];
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin_len;
static const struct XQAKernelMetaInfo
@ -42,10 +54,16 @@ static const struct XQAKernelMetaInfo
unsigned int mCubinSize;
const char* mFuncName;
} sXqaKernelMetaInfo[] = {
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_90_cubin_len, "kernel_mha"}
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin_len, "kernel_mha"}
};
// clang-format on
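
The table above now enumerates cubins per (KV-cache data type, SM) pair rather than per SM alone. A standalone sketch of the kind of lookup a loader could perform over such a table; the enum, table contents, and function name here are illustrative stand-ins, not the actual cubin symbols:

#include <cstdio>

enum class KvDtype { FP16, INT8, E4M3 };

struct KernelMeta
{
    KvDtype kvDtype;
    int sm;
    const char* funcName;
};

// Illustrative subset of the (kv dtype, SM) combinations listed in the table above.
static const KernelMeta kKernels[] = {
    {KvDtype::FP16, 80, "kernel_mha"}, {KvDtype::INT8, 80, "kernel_mha"},
    {KvDtype::FP16, 89, "kernel_mha"}, {KvDtype::INT8, 89, "kernel_mha"}, {KvDtype::E4M3, 89, "kernel_mha"},
};

const KernelMeta* findKernel(KvDtype kvDtype, int sm)
{
    for (const auto& k : kKernels)
    {
        if (k.kvDtype == kvDtype && k.sm == sm)
        {
            return &k;
        }
    }
    return nullptr; // no specialized cubin for this combination
}

int main()
{
    std::printf("fp8 kv on SM89 supported: %s\n", findKernel(KvDtype::E4M3, 89) ? "yes" : "no");
    std::printf("fp8 kv on SM80 supported: %s\n", findKernel(KvDtype::E4M3, 80) ? "yes" : "no");
    return 0;
}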

View File

@ -44,8 +44,8 @@ inline size_t smem_size_in_bytes(const Multihead_attention_params<T, DO_CROSS_AT
using Tk = typename kernel_type_t<T>::Type;
// The amount of shared memory needed to store the Q*K^T values in float.
const int max_timesteps = DO_CROSS_ATTENTION
? params.cyclic_kv_cache_length
: min((DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep), params.cyclic_kv_cache_length);
? params.cyclic_attention_window_size
: min((DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep), params.cyclic_attention_window_size);
const auto qk_elts = static_cast<std::size_t>(divUp(max_timesteps + 1, 4)); // explicit cast because of the sign
const auto qk_sz = qk_elts * 16;
@ -110,6 +110,9 @@ inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params<
params.seq_len_tile = std::min(params.seq_len_tile, max_seq_len_tile);
TLLM_CHECK_WITH_INFO(
params.seq_len_tile <= block_size, "The number of blocks per sequence may not exceed the thread block size.");
// We should consider the new timestep.
params.timesteps_per_block = mmha::divUp(tlength + 1, params.seq_len_tile);

View File

@ -1273,9 +1273,9 @@ __global__ void masked_multihead_attention_kernel(
// The maximum sequence length in the cyclic kv_cache, i.e., an upper bound on L.
// Note that the maximum sequence length supported by the model might be greater than this.
// Note max_kv_cache_length is maximum of cyclic_kv_cache_length among all layers.
// Note max_attention_window_size is maximum of cyclic_attention_window_size among all layers.
// By default, you can assume that they are the same.
const auto cyclic_kv_cache_len = static_cast<unsigned>(params.cyclic_kv_cache_length);
const auto cyclic_kv_cache_len = static_cast<unsigned>(params.cyclic_attention_window_size);
// The current timestep (including paddings).
// It is only used to calculate the smem stride.
const auto timestep = static_cast<unsigned>(DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep);
@ -1796,11 +1796,12 @@ __global__ void masked_multihead_attention_kernel(
: divUp(static_cast<unsigned>(kv_loop_length), K_PER_WARP) * K_PER_WARP;
// Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
// Note max_kv_cache_length is maximum of cyclic_kv_cache_length among all layers.
// Note max_attention_window_size is maximum of cyclic_attention_window_size among all layers.
// By default, you can assume that they are the same.
const auto bi_seq_len_offset = static_cast<std::size_t>(batch_beam_idx) * params.max_kv_cache_length;
// Beam indices are based on the max_kv_cache_length while each layer may have different cyclic_kv_cache_length
// So we need to rebuild the beam_indices if max_kv_cache_length is not equal to cyclic_kv_cache_length.
const auto bi_seq_len_offset = static_cast<std::size_t>(batch_beam_idx) * params.max_attention_window_size;
// Beam indices are based on max_attention_window_size, while each layer may have a different
// cyclic_attention_window_size, so we need to rebuild the beam_indices if max_attention_window_size is not equal
// to cyclic_attention_window_size.
const int* beam_indices = HAS_BEAMS ? &params.cache_indir[bi_seq_len_offset] : nullptr;
const auto c_tile_times_timesteps_per_block = c_tile * timesteps_per_block; // 0 if !MULTI_BLOCK_FLAG
@ -2596,39 +2597,38 @@ __global__ void masked_multihead_attention_kernel(
T* out_oi_smem = reinterpret_cast<T*>(smem_);
const auto o_idx = chunk_index<T, V_vec_k, THREADS_PER_VALUE>(tidx);
// The partial output region this thread takes care of
const auto oo = o_idx.x;
// Init partial out for accumulation.
V_vec_k zero_k;
zero(zero_k);
V_vec_k thread_accumulated_out = zero_k;
// The hidden dimensions computed by this particular thread. (refer to vi)
const auto oi = o_idx.y;
// Within the bound.
const bool within_bound = oo < gridDim.z;
// The partial output region this thread takes care of
const auto oo = o_idx.x;
// Load partial output
int thread_partial_out_offset = oo * params.batch_size * num_heads * params.hidden_size_per_head;
// Load partial max (different to thread_partial_max since the threadIdx rule changes here)
float thread_partial_max_for_out = within_bound ? params.partial_max[bhi_seq_len_tile + oo] : final_max;
// Load the partial outputs.
V_vec_k zero_k;
zero(zero_k);
V_vec_k thread_partial_out = within_bound
? *reinterpret_cast<const V_vec_k*>(&params.partial_out[thread_partial_out_offset + bhi * Dh + oi])
: zero_k;
Tk factor_compute;
convert_from_float(&factor_compute, __expf(thread_partial_max_for_out - final_max));
thread_partial_out = mul<V_vec_k, Tk, V_vec_k>(factor_compute, thread_partial_out);
// Make sure we can start writing to shared memory.
__syncthreads();
// The reduction iteration should start with a number which is a power of 2
const auto reduction_iteration = static_cast<int>(cuda::std::bit_ceil(gridDim.z));
// Each thread may handle more than one partial output.
for (int tile_idx = o_idx.x; tile_idx < gridDim.z; tile_idx += V_PER_ITER)
{
// Load partial output
int thread_partial_out_offset = tile_idx * params.batch_size * num_heads * params.hidden_size_per_head;
// Load partial max (different to thread_partial_max since the threadIdx rule changes here)
float thread_partial_max_for_out = params.partial_max[bhi_seq_len_tile + tile_idx];
// Load the partial outputs.
V_vec_k thread_partial_out
= *reinterpret_cast<const V_vec_k*>(&params.partial_out[thread_partial_out_offset + bhi * Dh + oi]);
// Apply the correction factor.
Tk factor_compute;
convert_from_float(&factor_compute, __expf(thread_partial_max_for_out - final_max));
thread_partial_out = mul<V_vec_k, Tk, V_vec_k>(factor_compute, thread_partial_out);
thread_accumulated_out = add(thread_partial_out, thread_accumulated_out);
}
// Run the final reduction amongst the different groups computing different partial outputs.
#pragma unroll
for (int active_groups = reduction_iteration; active_groups >= 2; active_groups /= 2)
for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2)
{
// The midpoint in the number of active groups.
@ -2637,15 +2637,15 @@ __global__ void masked_multihead_attention_kernel(
// The upper part of active threads store to shared memory.
if (oo >= midpoint && oo < active_groups && (Dh == Dh_MAX || oi < Dh))
{
*reinterpret_cast<V_vec_k*>(&out_oi_smem[(oo - midpoint) * Dh + oi]) = thread_partial_out;
*reinterpret_cast<V_vec_k*>(&out_oi_smem[(oo - midpoint) * Dh + oi]) = thread_accumulated_out;
}
__syncthreads();
// The bottom warps update their values.
if (oo < midpoint && (Dh == Dh_MAX || oi < Dh))
{
thread_partial_out
= add(thread_partial_out, *reinterpret_cast<const V_vec_k*>(&out_oi_smem[oo * Dh + oi]));
thread_accumulated_out
= add(thread_accumulated_out, *reinterpret_cast<const V_vec_k*>(&out_oi_smem[oo * Dh + oi]));
}
__syncthreads();
}
@ -2661,8 +2661,8 @@ __global__ void masked_multihead_attention_kernel(
Tk inv_sum_compute;
convert_from_float(&inv_sum_compute, inv_sum);
thread_partial_out = mul<V_vec_k, Tk, V_vec_k>(inv_sum_compute, thread_partial_out);
*reinterpret_cast<V_vec_k*>(&params.out[bhi * Dh + oi]) = thread_partial_out;
thread_accumulated_out = mul<V_vec_k, Tk, V_vec_k>(inv_sum_compute, thread_accumulated_out);
*reinterpret_cast<V_vec_k*>(&params.out[bhi * Dh + oi]) = thread_accumulated_out;
}
// Reset qk_current_smem and block_counter for the next timestep

View File

@ -110,11 +110,11 @@ public:
const void* qkv;
#ifdef USE_KV_SCALE
const float* kv_scale_orig_quant = nullptr;
const float* kv_scale_quant_orig = nullptr;
#endif
// Max 3K size
KVCache<HasBeam> cacheList[kMaxBatchSizePerWave];
int batch_size;
const float* kv_scale_quant_orig = nullptr;
void* scratch = nullptr;
};
@ -144,21 +144,42 @@ int buildXQALaunchParams(XQALaunchParam<HasBeam>& launchParams, const XQAParams&
launchParams.num_k_heads = params.num_kv_heads;
#ifdef USE_KV_SCALE
launchParams.kv_scale_orig_quant = params.kv_scale_orig_quant;
launchParams.kv_scale_quant_orig = params.kv_scale_quant_orig;
#endif
launchParams.kv_scale_quant_orig = params.kv_scale_quant_orig;
launchParams.batch_size = micro_batch_size;
launchParams.scratch = params.workspaces;
int max_context_length = 0;
int max_past_kv_length = 0;
if (params.host_context_lengths)
{
// TODO: remove this logic, maybe use xqaParams.sequence_lengths inside kernel.
max_context_length
= *std::max_element(params.host_context_lengths, params.host_context_lengths + params.batch_size);
max_past_kv_length = *std::max_element(
params.host_past_key_value_lengths, params.host_past_key_value_lengths + params.batch_size);
}
for (int i = 0; i < micro_batch_size; i++)
{
int batch_idx = start_batch_idx + i;
launchParams.cacheList[i].data = kv_linear_buffer.getKBlockPtr(batch_idx * params.beam_width, 0);
// the kernel_mha use KV from KVCache, so need plus 1 here.
launchParams.cacheList[i].size = params.host_past_key_value_lengths[batch_idx] + 1;
launchParams.cacheList[i].capacity = params.max_kv_cache_length;
int current_len = 0;
// TODO: remove this logic, maybe use xqaParams.sequence_lengths inside kernel.
if (params.host_context_lengths)
{
// the kernel_mha use KV from KVCache, so need plus 1 here.
current_len = params.host_context_lengths[batch_idx] + max_past_kv_length - max_context_length + 1;
}
else
{
current_len = params.host_past_key_value_lengths[batch_idx] + 1;
}
launchParams.cacheList[i].size = current_len;
launchParams.cacheList[i].capacity = params.max_attention_window_size;
if constexpr (HasBeam)
{
launchParams.cacheList[i].cacheInDir
= params.cache_indir + batch_idx * params.beam_width * params.max_kv_cache_length;
= params.cache_indir + batch_idx * params.beam_width * params.max_attention_window_size;
}
}
return micro_batch_size;
@ -175,6 +196,14 @@ public:
, mKernelMeta(&sXqaKernelMetaInfo[0])
, mSM(sm)
{
const char* enable_xqa_env_var = getenv("TRTLLM_FORCE_XQA");
if (enable_xqa_env_var != nullptr)
{
if (enable_xqa_env_var[0] == '1' && enable_xqa_env_var[1] == '\0')
{
mForceXQA = true;
}
}
}
void loadXQAKernels()
@ -238,8 +267,12 @@ public:
return findIter != mFunctions.end();
}
static bool mayHavePerfGain(const XQAParams& xqaParams, int multiprocessor_count)
bool mayHavePerfGain(const XQAParams& xqaParams, int multiprocessor_count) const
{
if (mForceXQA)
{
return true;
}
int num_kv_heads = xqaParams.num_kv_heads;
int batch_size = static_cast<int>(xqaParams.batch_size);
int multi_block_count = 1;
@ -270,7 +303,7 @@ public:
invokeApplyBiasRopeUpdateKVCache<T, KVLinearBuffer, true>(static_cast<T*>(const_cast<void*>(xqaParams.qkv)),
nullptr, kv_linear_buffer, static_cast<const T*>(xqaParams.qkv_bias), xqaParams.sequence_lengths, nullptr,
nullptr, xqaParams.batch_size, 1, xqaParams.cyclic_kv_cache_length, xqaParams.batch_size * beam_width,
nullptr, xqaParams.batch_size, 1, xqaParams.cyclic_attention_window_size, xqaParams.batch_size * beam_width,
xqaParams.num_q_heads, xqaParams.num_kv_heads, xqaParams.head_size, xqaParams.rotary_embedding_dim,
xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type, xqaParams.rotary_embedding_scale,
xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type, (float*) nullptr, 0,
@ -291,8 +324,9 @@ public:
XQALaunchParam<HAS_BEAM> launchParams;
int micro_batch_size = buildXQALaunchParams(
launchParams, xqaParams, kv_linear_buffer, start_batch_idx, multiprocessor_count);
void* kernelParams[] = {&launchParams.num_k_heads, &launchParams.output, &launchParams.qkv,
&launchParams.cacheList, &launchParams.batch_size, &launchParams.scratch, nullptr};
void* kernelParams[]
= {&launchParams.num_k_heads, &launchParams.output, &launchParams.qkv, &launchParams.cacheList,
&launchParams.batch_size, &launchParams.kv_scale_quant_orig, &launchParams.scratch, nullptr};
int multi_block = 1;
if (xqaParams.multi_block_mode)
{
@ -352,6 +386,8 @@ protected:
unsigned int mSM;
std::unordered_map<const unsigned long long*, CUmodule> mModules;
bool mForceXQA = false;
struct XQAKernelFuncInfo
{
unsigned int mMetaInfoIndex;
@ -420,7 +456,7 @@ public:
bool shouldUse(const XQAParams& xqaParams)
{
return xqaKernel->supportConfig(xqaParams) && XQAKernelList::mayHavePerfGain(xqaParams, mMultiProcessorCount);
return xqaKernel->supportConfig(xqaParams) && xqaKernel->mayHavePerfGain(xqaParams, mMultiProcessorCount);
}
void run(const XQAParams& xqa_params, KVLinearBuffer& kv_linear_buffer, const cudaStream_t& stream,

View File

@ -57,11 +57,12 @@ struct XQAParams
const float* kv_scale_orig_quant = nullptr;
const float* kv_scale_quant_orig = nullptr;
const int32_t* host_past_key_value_lengths = nullptr;
const int32_t* host_context_lengths = nullptr;
void* workspaces = nullptr;
uint32_t batch_size = 0;
int32_t beam_width = 0;
int32_t max_kv_cache_length = 0;
int32_t cyclic_kv_cache_length = 0;
int32_t max_attention_window_size = 0;
int32_t cyclic_attention_window_size = 0;
int timestep = 0;
const void* qkv_bias;
const int32_t* sequence_lengths; //
@ -128,21 +129,10 @@ public:
SUPPORT_RETURN_FALSE("rotary_embedding_base");
if (xqaParams.rotary_embedding_scale_type != tensorrt_llm::kernels::RotaryScalingType::kNONE)
SUPPORT_RETURN_FALSE("rotary_embedding_scale_type");
if (xqaParams.rotary_embedding_scale != 1.0f)
SUPPORT_RETURN_FALSE("rotary_embedding_scale");
if (xqaParams.position_embedding_type != tensorrt_llm::kernels::PositionEmbeddingType::kROPE_GPT_NEOX)
SUPPORT_RETURN_FALSE("position_embedding_type");
// xqaParams.remove_padding
if (xqaParams.mask_type != tensorrt_llm::kernels::AttentionMaskType::CAUSAL)
SUPPORT_RETURN_FALSE("mask_type");
if (xqaParams.paged_kv_cache)
SUPPORT_RETURN_FALSE("paged_kv_cache");
if (xqaParams.kv_cache_quant_mode != tensorrt_llm::common::QuantMode::int8KvCache()
&& xqaParams.kv_cache_quant_mode != tensorrt_llm::common::QuantMode::fp8KvCache()
&& xqaParams.kv_cache_quant_mode != tensorrt_llm::common::QuantMode::none())
SUPPORT_RETURN_FALSE("kv_cache_quant_mode");
if (xqaParams.kv_cache_quant_mode != tensorrt_llm::common::QuantMode::none())
SUPPORT_RETURN_FALSE("kv_cache_quant_mode");
if (xqaParams.qkv_bias_enabled)
SUPPORT_RETURN_FALSE("qkv_bias_enabled");
if (xqaParams.cross_attention)
@ -152,8 +142,8 @@ public:
SUPPORT_RETURN_FALSE("host_past_key_value_lengths");
if (xqaParams.beam_width != 1)
SUPPORT_RETURN_FALSE("beam_width");
if (xqaParams.cyclic_kv_cache_length != xqaParams.max_kv_cache_length)
SUPPORT_RETURN_FALSE("cyclic_kv_cache_length != max_kv_cache_length");
if (xqaParams.cyclic_attention_window_size != xqaParams.max_attention_window_size)
SUPPORT_RETURN_FALSE("cyclic_attention_window_size != max_attention_window_size");
return shouldUseImpl(xqaParams);
}
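
A standalone sketch tying together the TRTLLM_FORCE_XQA override added earlier in this diff with two of the checks shown above (beam width 1, cyclic window equal to max window); the real shouldUse()/shouldUseImpl() decision covers many more conditions, so treat this as illustrative only.

#include <cstdio>
#include <cstdlib>

// Matches the parsing added earlier in this diff: only the exact string "1" forces XQA on.
bool forceXqaFromEnv()
{
    const char* value = std::getenv("TRTLLM_FORCE_XQA");
    return value != nullptr && value[0] == '1' && value[1] == '\0';
}

// Simplified stand-in for the eligibility check: the real one tests many more parameters.
bool mightUseXqa(int beamWidth, int cyclicAttentionWindowSize, int maxAttentionWindowSize)
{
    if (beamWidth != 1)
        return false;
    if (cyclicAttentionWindowSize != maxAttentionWindowSize)
        return false;
    return true;
}

int main()
{
    std::printf("forced: %s, eligible: %s\n", forceXqaFromEnv() ? "yes" : "no",
        mightUseXqa(1, 4096, 4096) ? "yes" : "no");
    return 0; // run with TRTLLM_FORCE_XQA=1 to force the kernel despite the heuristic
}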

View File

@ -134,7 +134,7 @@ __global__ void computePaddingOffsets(int* paddingOffsets, const int* seqOffsets
template <typename AttentionMaskDataType>
__global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, const int* seqOffsets, int maxSeqLength,
int maxKvCacheLength, AttentionMaskType attentionMaskType)
int attentionWindowSize, AttentionMaskType attentionMaskType)
{
// The index of the sequence in the batch.
int batchIdx = blockIdx.y;
@ -174,7 +174,7 @@ __global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, const
case AttentionMaskType::CAUSAL:
isValid = rowIdx < seqLength && colIdx < seqLength && colIdx <= rowIdx;
// Sliding_window_causal when there are not enough kv cache.
isValid = isValid && colIdx >= max(0, rowIdx - maxKvCacheLength);
isValid = isValid && colIdx >= max(0, rowIdx - attentionWindowSize);
// seq_length==4, max_seq_len==5
// 1 0 0 0 0
// 1 1 0 0 0
@ -182,7 +182,7 @@ __global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, const
// 1 1 1 1 0
// 0 0 0 0 0
// seq_length==6, max_seq_len==6, max_kv_cache_length = 2
// seq_length==6, max_seq_len==6, max_attention_window_size = 2
// 1 0 0 0 0 0
// 1 1 0 0 0 0
// 1 1 1 0 0 0
@ -248,7 +248,7 @@ void invokeBuildDecoderInfo(const BuildDecoderInfoParams<T>& params, cudaStream_
}
dim3 grid(blocksPerSeq, params.batchSize);
computeAttentionMask<<<grid, THREADS_PER_BLOCK, 0, stream>>>(params.attentionMask, params.seqQOffsets,
params.maxSeqLength, params.maxKvCacheLength, params.attentionMaskType);
params.maxSeqLength, params.attentionWindowSize, params.attentionMaskType);
}
}
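
The CAUSAL branch above combines the usual lower-triangular rule with a sliding-window cutoff at rowIdx - attentionWindowSize. A standalone host-side sketch that reproduces the mask for the seq_length==6, window==2 example in the comments:

#include <algorithm>
#include <cstdio>

// Mirrors the CAUSAL branch above: lower-triangular mask limited to a sliding attention window.
bool isValidCausal(int rowIdx, int colIdx, int seqLength, int attentionWindowSize)
{
    bool isValid = rowIdx < seqLength && colIdx < seqLength && colIdx <= rowIdx;
    return isValid && colIdx >= std::max(0, rowIdx - attentionWindowSize);
}

int main()
{
    const int seqLength = 6;
    const int attentionWindowSize = 2;
    for (int rowIdx = 0; rowIdx < seqLength; ++rowIdx)
    {
        for (int colIdx = 0; colIdx < seqLength; ++colIdx)
        {
            std::printf("%d ", isValidCausal(rowIdx, colIdx, seqLength, attentionWindowSize) ? 1 : 0);
        }
        std::printf("\n");
    }
    return 0; // reproduces the seq_length==6, max_attention_window_size==2 pattern shown above
}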

View File

@ -82,7 +82,7 @@ struct BuildDecoderInfoParams
int maxSeqLength;
// The kv cache capacity.
// We will apply the limited_length_causal mask when there are not enough kv cache.
int maxKvCacheLength;
int attentionWindowSize;
// The number of tokens in total. It's \sum_{ii=0}^{batchSize} seqLengths[ii].
int numTokens;
// The type of attention.

View File

@ -0,0 +1,160 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/kernel/gemm_grouped.h"
#include "groupGemm.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/cudaUtils.h"
namespace tensorrt_llm
{
namespace kernels
{
template <int M1, int N1, int K1, int M2, int N2, int K2>
void run_cutlass_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA, std::vector<void*> ptrB,
std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize, void* cublasWorkSpace,
int64_t cublasWorkspaceSize, cudaStream_t stream)
{
using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = float;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
const int kAlignmentA = 8;
const int kAlignmentB = 8;
int problem_count = problem_sizes.size();
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<ElementA, LayoutA,
cutlass::ComplexTransform::kNone, kAlignmentA, ElementB, LayoutB, cutlass::ComplexTransform::kNone, kAlignmentB,
ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<M1, N1, K1>, cutlass::gemm::GemmShape<M2, N2, K2>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
// NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
// This parameter is passed in at present to match the APIs of other kernels. The parameter
// is unused within the kernel.
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
4, // kStages
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>::GemmKernel;
using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;
float alpha = 1.0f;
float beta = 0.0f;
typename Gemm::EpilogueOutputOp::Params epilogue_op(alpha, beta);
auto gemm_coord_size = tensorrt_llm::common::divUp(problem_count * sizeof(cutlass::gemm::GemmCoord), 16) * 16;
auto ptr_size = tensorrt_llm::common::divUp(problem_count * sizeof(half*), 16) * 16;
auto ldd_size = tensorrt_llm::common::divUp(problem_count * sizeof(int64_t), 16) * 16;
char* host_workspace = (char*) std::malloc(workSpaceSize);
cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(host_workspace);
ElementA** ptr_A_host = reinterpret_cast<ElementA**>(host_workspace + gemm_coord_size);
ElementB** ptr_B_host = reinterpret_cast<ElementB**>(host_workspace + gemm_coord_size + ptr_size);
ElementOutput** ptr_C_host = reinterpret_cast<ElementOutput**>(host_workspace + gemm_coord_size + 2 * ptr_size);
ElementOutput** ptr_D_host = reinterpret_cast<ElementOutput**>(host_workspace + gemm_coord_size + 3 * ptr_size);
int64_t* lda_host = reinterpret_cast<int64_t*>(host_workspace + gemm_coord_size + 4 * ptr_size + 0 * ldd_size);
int64_t* ldb_host = reinterpret_cast<int64_t*>(host_workspace + gemm_coord_size + 4 * ptr_size + 1 * ldd_size);
int64_t* ldc_host = reinterpret_cast<int64_t*>(host_workspace + gemm_coord_size + 4 * ptr_size + 2 * ldd_size);
int64_t* ldd_host = reinterpret_cast<int64_t*>(host_workspace + gemm_coord_size + 4 * ptr_size + 3 * ldd_size);
for (int32_t i = 0; i < problem_count; ++i)
{
problem_sizes_host[i] = problem_sizes.at(i);
ptr_A_host[i] = (ElementA*) ptrA.at(i);
ptr_B_host[i] = (ElementB*) ptrB.at(i);
ptr_C_host[i] = (ElementOutput*) ptrC.at(i);
ptr_D_host[i] = (ElementOutput*) ptrD.at(i);
auto problem = problem_sizes.at(i);
lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
ldd_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
}
cutlass::gemm::GemmCoord* problem_sizes_device = reinterpret_cast<cutlass::gemm::GemmCoord*>(workspace);
ElementA** ptr_A = reinterpret_cast<ElementA**>((char*) workspace + gemm_coord_size);
ElementB** ptr_B = reinterpret_cast<ElementB**>((char*) workspace + gemm_coord_size + ptr_size);
ElementOutput** ptr_C = reinterpret_cast<ElementOutput**>((char*) workspace + gemm_coord_size + 2 * ptr_size);
ElementOutput** ptr_D = reinterpret_cast<ElementOutput**>((char*) workspace + gemm_coord_size + 3 * ptr_size);
int64_t* lda = reinterpret_cast<int64_t*>((char*) workspace + gemm_coord_size + 4 * ptr_size + 0 * ldd_size);
int64_t* ldb = reinterpret_cast<int64_t*>((char*) workspace + gemm_coord_size + 4 * ptr_size + 1 * ldd_size);
int64_t* ldc = reinterpret_cast<int64_t*>((char*) workspace + gemm_coord_size + 4 * ptr_size + 2 * ldd_size);
int64_t* ldd = reinterpret_cast<int64_t*>((char*) workspace + gemm_coord_size + 4 * ptr_size + 3 * ldd_size);
TLLM_CHECK(((char*) ldc_host - (char*) host_workspace) == ((char*) ldc - (char*) workspace));
tensorrt_llm::common::cudaAutoCpy((int8_t*) workspace, (int8_t*) host_workspace, workSpaceSize, stream);
int threadblock_count = Gemm::sufficient(problem_sizes.data(), problem_count);
typename Gemm::Arguments args(problem_sizes_device, problem_count, threadblock_count, epilogue_op, ptr_A, ptr_B,
ptr_C, ptr_D, lda, ldb, ldc, ldd, problem_sizes.data());
// Initialize the GEMM object
Gemm gemm;
size_t workspace_size = gemm.get_workspace_size(args);
TLLM_CHECK(workspace_size <= cublasWorkspaceSize);
cutlass::Status status = gemm.initialize(args, cublasWorkSpace);
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to initialize CUTLASS Grouped GEMM kernel.");
// Run the grouped GEMM object
status = gemm.run(stream);
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to run CUTLASS Grouped GEMM kernel.");
std::free(host_workspace);
}
void run_cutlass_1(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream)
{
// For lora in, which has smaller N
run_cutlass_<128, 32, 32, 32, 32, 32>(
problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace, workSpaceSize, cublasWorkSpace, cublasWorkspaceSize, stream);
}
void run_cutlass_2(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream)
{
// For lora out, which has larger N
run_cutlass_<128, 128, 32, 64, 64, 32>(
problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace, workSpaceSize, cublasWorkSpace, cublasWorkspaceSize, stream);
}
} // namespace kernels
} // namespace tensorrt_llm
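
run_cutlass_ above packs the problem shapes, the A/B/C/D pointer arrays, and the four leading-dimension arrays into a single 16-byte-aligned workspace before copying it to the device. A standalone sketch of that size calculation; the 12-byte GemmCoord size used in main() is an assumption of this sketch, not taken from the CUTLASS headers:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Round a byte count up to a 16-byte boundary, like divUp(x, 16) * 16 in the code above.
std::size_t alignUp16(std::size_t bytes)
{
    return (bytes + 15) / 16 * 16;
}

// Workspace layout used above: [problem sizes][ptrA][ptrB][ptrC][ptrD][lda][ldb][ldc][ldd],
// with each region padded to a 16-byte boundary.
std::size_t groupedGemmWorkspaceSize(std::size_t problemCount, std::size_t gemmCoordBytes)
{
    const std::size_t gemmCoordSize = alignUp16(problemCount * gemmCoordBytes);
    const std::size_t ptrSize = alignUp16(problemCount * sizeof(void*));
    const std::size_t lddSize = alignUp16(problemCount * sizeof(int64_t));
    return gemmCoordSize + 4 * ptrSize + 4 * lddSize;
}

int main()
{
    // Assumption: cutlass::gemm::GemmCoord is three ints (12 bytes) for this estimate.
    std::printf("workspace for 8 GEMMs: %zu bytes\n", groupedGemmWorkspaceSize(8, 12));
    return 0;
}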

View File

@ -0,0 +1,35 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/gemm_coord.h"
namespace tensorrt_llm
{
namespace kernels
{
void run_cutlass_1(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream);
void run_cutlass_2(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -411,7 +411,7 @@ template <typename T, int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const T* input, const bool* finished, T* output, int* indices, int* source_row,
const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
static constexpr unsigned long MAX_BYTES_PER_LDG = 16;
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS);
using Constants = detail::TopkConstants<T, EXPERTS, BYTES_PER_LDG>;

View File

@ -1628,8 +1628,8 @@ INSTANTIATE_TRANSPOSE_4D(half);
template <typename T, typename T_cache, typename KVCacheBuffer>
__global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCacheBuffer kvCacheBuffer,
const int headNum, const int sizePerHead, const int seqLen, const int maxKvCacheLen, const float* kvScaleOrigQuant,
const int* sequence_lengths)
const int headNum, const int sizePerHead, const int seqLen, const int attentionWindowSize,
const float* kvScaleOrigQuant, const int* sequence_lengths)
{
// We allow only fp32/fp16/bf16 as input types
static_assert(sizeof(T) == 4 || sizeof(T) == 2, "");
@ -1655,9 +1655,9 @@ __global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCac
// Get linear token index
int tokenIdx = idx / sizePerHeadDivX;
// Apply cyclic kv cache if tokenIdx >= max_kv_cache_length.
// which means we will drop the tokens in the beginning if seqLen > max_kv_cache_length.
const int tokenIdxLowerBound = max(sequence_lengths[batchIdx] - maxKvCacheLen, 0);
// Apply cyclic kv cache if tokenIdx >= max_attention_window_size,
// which means we will drop the tokens at the beginning if seqLen > max_attention_window_size.
const int tokenIdxLowerBound = max(sequence_lengths[batchIdx] - attentionWindowSize, 0);
// Get channel index
const int channelIdx = idx % sizePerHeadDivX;
if (tokenIdx >= sequence_lengths[batchIdx] || tokenIdx < tokenIdxLowerBound)
@ -1665,8 +1665,8 @@ __global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCac
return;
}
// Apply cyclic kv cache if tokenIdx >= max_kv_cache_length.
tokenIdx = tokenIdx % maxKvCacheLen;
// Apply cyclic kv cache if tokenIdx >= max_attention_window_size.
tokenIdx = tokenIdx % attentionWindowSize;
// Get pointer to the dst block given sequence, head and token ids
auto valDst = handle_k ? reinterpret_cast<T_dst*>(kvCacheBuffer.getKBlockPtr(batchIdx, tokenIdx))
@ -1702,7 +1702,7 @@ __global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCac
template <typename T, typename KVCacheBuffer>
void invokeTranspose4dBatchMajor(const T* kSrc, const T* vSrc, KVCacheBuffer& kvTable, const int localBatchSize,
const int seqLen, const int maxKvCacheLen, const int sizePerHead, const int localHeadNum,
const int seqLen, const int attentionWindowSize, const int sizePerHead, const int localHeadNum,
const KvCacheDataType cache_type, const float* kvScaleOrigQuant, const int* sequence_lengths, cudaStream_t stream)
{
// Block handles both K and V tile.
@ -1714,26 +1714,26 @@ void invokeTranspose4dBatchMajor(const T* kSrc, const T* vSrc, KVCacheBuffer& kv
if (cache_type == KvCacheDataType::INT8)
{
transpose4dBatchMajorKVCache<T, int8_t, KVCacheBuffer><<<gridSz, blockSz, 0, stream>>>(
kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, maxKvCacheLen, kvScaleOrigQuant, sequence_lengths);
transpose4dBatchMajorKVCache<T, int8_t, KVCacheBuffer><<<gridSz, blockSz, 0, stream>>>(kSrc, vSrc, kvTable,
localHeadNum, sizePerHead, seqLen, attentionWindowSize, kvScaleOrigQuant, sequence_lengths);
}
#ifdef ENABLE_FP8
else if (cache_type == KvCacheDataType::FP8)
{
transpose4dBatchMajorKVCache<T, __nv_fp8_e4m3, KVCacheBuffer><<<gridSz, blockSz, 0, stream>>>(
kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, maxKvCacheLen, kvScaleOrigQuant, sequence_lengths);
transpose4dBatchMajorKVCache<T, __nv_fp8_e4m3, KVCacheBuffer><<<gridSz, blockSz, 0, stream>>>(kSrc, vSrc,
kvTable, localHeadNum, sizePerHead, seqLen, attentionWindowSize, kvScaleOrigQuant, sequence_lengths);
}
#endif // ENABLE_FP8
else
{
transpose4dBatchMajorKVCache<T, T, KVCacheBuffer><<<gridSz, blockSz, 0, stream>>>(
kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, maxKvCacheLen, kvScaleOrigQuant, sequence_lengths);
transpose4dBatchMajorKVCache<T, T, KVCacheBuffer><<<gridSz, blockSz, 0, stream>>>(kSrc, vSrc, kvTable,
localHeadNum, sizePerHead, seqLen, attentionWindowSize, kvScaleOrigQuant, sequence_lengths);
}
}
#define INSTANTIATE_TRANSPOSE_4D_BATCH_MAJOR_KV_CACHE_TYPE(T, KVCacheBuffer) \
template void invokeTranspose4dBatchMajor(const T* kSrc, const T* vSrc, KVCacheBuffer& kvTable, \
const int localBatchSize, const int seqLen, const int maxKvCacheLen, const int sizePerHead, \
const int localBatchSize, const int seqLen, const int attentionWindowSize, const int sizePerHead, \
const int localHeadNum, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, \
const int* sequence_lengths, cudaStream_t stream)
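
invokeTranspose4dBatchMajor above writes K/V into the cyclic cache: source tokens older than sequence_length - attention_window_size are skipped, and the remaining ones wrap modulo the window. A standalone sketch of that per-token decision:

#include <algorithm>
#include <cstdio>

// Returns the destination slot for a source token, or -1 if it is skipped.
int kvDestinationSlot(int tokenIdx, int sequenceLength, int attentionWindowSize)
{
    const int tokenIdxLowerBound = std::max(sequenceLength - attentionWindowSize, 0);
    if (tokenIdx >= sequenceLength || tokenIdx < tokenIdxLowerBound)
    {
        return -1; // out of range or already evicted from the cyclic cache
    }
    return tokenIdx % attentionWindowSize;
}

int main()
{
    // sequenceLength 6, window 4: tokens 0 and 1 are dropped, tokens 2..5 occupy slots 2, 3, 0, 1.
    for (int tokenIdx = 0; tokenIdx < 6; ++tokenIdx)
    {
        std::printf("token %d -> slot %d\n", tokenIdx, kvDestinationSlot(tokenIdx, 6, 4));
    }
    return 0;
}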

View File

@ -105,7 +105,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const
template <typename T, typename KVCacheBuffer>
void invokeTranspose4dBatchMajor(const T* k_src, const T* v_src, KVCacheBuffer& kvTable, const int local_batch_size,
const int seq_len, const int max_kv_cache_len, const int size_per_head, const int local_head_num,
const int seq_len, const int max_attention_window_size, const int size_per_head, const int local_head_num,
const KvCacheDataType cache_type, const float* kvScaleOrigQuant, const int* sequence_lengths, cudaStream_t stream);
// NOTE: this kernel is in-place, QKV will be modified, if other kernels need that, may need copy or use before it.

View File

@ -32,41 +32,41 @@ namespace layers
__global__ void update_indir_cache_kernel(int* tgt_indir_cache, const int* src_indir_cache, const int** parent_ids,
const FinishedState* finished, const int* sequence_lengths, const int* input_lengths, int batch_dim,
int local_batch_size, int beam_width, int max_kv_cache_length, int max_seq_len)
int local_batch_size, int beam_width, int max_attention_window, int max_seq_len)
{
int time_step = threadIdx.x + blockIdx.x * blockDim.x;
int bb_id = threadIdx.y + blockIdx.y * blockDim.y; // should be just blockIdx.y?
const int current_step{sequence_lengths[bb_id] - 1}; // the sequence_lengths is updated, need to minus 1
const int batch_id = bb_id / beam_width;
const int beam_id = bb_id % beam_width;
if (bb_id >= beam_width * local_batch_size || time_step < (max_seq_len - max_kv_cache_length)
if (bb_id >= beam_width * local_batch_size || time_step < (max_seq_len - max_attention_window)
|| finished[bb_id].isFinished())
{
return;
}
int time_step_circ = time_step % max_kv_cache_length;
int time_step_circ = time_step % max_attention_window;
// for the parent_ids, we will still keep it for all past tokens (i.e. max_seq_len)
const int src_beam = parent_ids[batch_id][beam_id * max_seq_len + current_step];
// for the indir tables, we have the cyclic kv cache.
const uint32_t tgt_offset
= batch_id * beam_width * max_kv_cache_length + beam_id * max_kv_cache_length + time_step_circ;
= batch_id * beam_width * max_attention_window + beam_id * max_attention_window + time_step_circ;
const uint32_t src_offset
= batch_id * beam_width * max_kv_cache_length + src_beam * max_kv_cache_length + time_step_circ;
= batch_id * beam_width * max_attention_window + src_beam * max_attention_window + time_step_circ;
tgt_indir_cache[tgt_offset] = (time_step == current_step) ? beam_id : src_indir_cache[src_offset];
}
void update_indir_cache_kernelLauncher(int* tgt_indir_cache, const int* src_indir_cache, const int** parent_ids,
const FinishedState* finished, const int* sequence_lengths, const int* input_lengths, int batch_dim,
int local_batch_size, int beam_width, int max_seq_len, int max_kv_cache_length, cudaStream_t stream)
int local_batch_size, int beam_width, int max_seq_len, int max_attention_window, cudaStream_t stream)
{
const dim3 block(32);
// Update indirections steps [input_length[bb_id], sequence_lengths[bb_id]], included
const dim3 grid((max_seq_len + block.x - 1) / block.x, local_batch_size * beam_width);
update_indir_cache_kernel<<<grid, block, 0, stream>>>(tgt_indir_cache, src_indir_cache, parent_ids, finished,
sequence_lengths, input_lengths, batch_dim, local_batch_size, beam_width, max_kv_cache_length, max_seq_len);
sequence_lengths, input_lengths, batch_dim, local_batch_size, beam_width, max_attention_window, max_seq_len);
}
template <typename T>
@ -191,7 +191,7 @@ void BaseBeamSearchLayer<T>::forward(BeamSearchOutputParams& outputs, ForwardPar
reinterpret_cast<const FinishedState*>(
outputs.finished->template getPtr<const FinishedState::UnderlyingType>()),
sequence_length, input_lengths, batch_size, local_batch_size, beam_width, max_seq_len,
params.max_kv_cache_length, stream_);
params.max_attention_window, stream_);
sync_check_cuda_error();
}
sync_check_cuda_error();
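
The kernel above lays the indirection tables out as [batch, beam, cyclic position], so beam swaps only touch positions still inside the attention window. A standalone sketch of the offset arithmetic, with illustrative numbers:

#include <cstdint>
#include <cstdio>

// Mirrors the tgt/src offset computation in update_indir_cache_kernel above.
uint32_t indirOffset(int batchId, int beamId, int timeStep, int beamWidth, int maxAttentionWindow)
{
    const int timeStepCirc = timeStep % maxAttentionWindow; // cyclic kv cache position
    return static_cast<uint32_t>(batchId) * beamWidth * maxAttentionWindow
        + static_cast<uint32_t>(beamId) * maxAttentionWindow + timeStepCirc;
}

int main()
{
    // batch 1, beam 2 of width 4, time step 10 with a window of 8 -> cyclic slot 2.
    std::printf("%u\n", indirOffset(/*batchId=*/1, /*beamId=*/2, /*timeStep=*/10,
                            /*beamWidth=*/4, /*maxAttentionWindow=*/8)); // prints 50
    return 0;
}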

View File

@ -56,16 +56,16 @@ public:
{
public:
ForwardParams(int step, int ite, tc::Tensor logits, tc::Tensor endIds, tc::Tensor src_cache_indirection,
int max_kv_cache_length, int max_seq_len)
int max_attention_window, int max_seq_len)
: SoftmaxParams(step, ite, std::move(logits), std::move(endIds))
, src_cache_indirection{std::move(src_cache_indirection)}
, max_kv_cache_length{max_kv_cache_length}
, max_attention_window{max_attention_window}
, max_seq_len{max_seq_len}
{
}
// mandatory parameters
int max_kv_cache_length;
int max_attention_window;
int max_seq_len;
tc::Tensor src_cache_indirection; // [local_batch_size, beam_width, max_seq_len]

View File

@ -296,7 +296,7 @@ void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const&
auto const end_id_offset
= end_ids.slice({dynamic_decode_batch_size}, dynamic_ite * dynamic_decode_batch_size);
typename BaseBeamSearchLayer<T>::ForwardParams dynamic_decode_input_tensors{step, ite, logits_offset,
end_id_offset, *params.src_cache_indirection, static_cast<std::int32_t>(params.max_kv_cache_length),
end_id_offset, *params.src_cache_indirection, static_cast<std::int32_t>(params.max_attention_window),
static_cast<std::int32_t>(max_seq_len)};
dynamic_decode_input_tensors.embedding_bias = params.embedding_bias;

View File

@ -79,12 +79,12 @@ public:
class ForwardParams
{
public:
ForwardParams(int step, int ite, int maxInputLength, int maxKvCacheLength, int localBatchSize,
ForwardParams(int step, int ite, int maxInputLength, int maxAttentionWindow, int localBatchSize,
tc::Tensor logits, tc::Tensor endIds)
: step{step}
, ite{ite}
, max_input_length{maxInputLength}
, max_kv_cache_length{maxKvCacheLength}
, max_attention_window{maxAttentionWindow}
, local_batch_size{localBatchSize}
, logits{std::move(logits)}
, end_ids{std::move(endIds)}
@ -95,7 +95,7 @@ public:
int step;
int ite;
int max_input_length;
int max_kv_cache_length;
int max_attention_window;
int local_batch_size;
tc::Tensor logits; // [batch_size, beam_width, vocab_size_padded], on gpu
tc::Tensor end_ids; // [batch_size], on gpu

View File

@ -84,8 +84,8 @@ struct FusedQKVMaskedAttentionDispatchParams
float rotary_embedding_scale;
int rotary_embedding_max_positions;
PositionEmbeddingType position_embedding_type;
int max_kv_cache_length;
int cyclic_kv_cache_length;
int max_attention_window;
int cyclic_attention_window_size;
const int* input_lengths;
int step;
float q_scaling;
@ -182,11 +182,12 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(
xqaParams.kv_scale_orig_quant = generationsParams.kv_scale_orig_quant;
xqaParams.kv_scale_quant_orig = generationsParams.kv_scale_quant_orig;
xqaParams.host_past_key_value_lengths = generationsParams.host_past_key_value_lengths;
xqaParams.host_context_lengths = generationsParams.host_context_lengths;
xqaParams.workspaces = generationsParams.workspace;
xqaParams.batch_size = generationsParams.num_requests;
xqaParams.beam_width = generationsParams.beam_width;
xqaParams.max_kv_cache_length = generationsParams.max_kv_cache_length;
xqaParams.cyclic_kv_cache_length = generationsParams.cyclic_kv_cache_length;
xqaParams.max_attention_window_size = generationsParams.max_attention_window;
xqaParams.cyclic_attention_window_size = generationsParams.cyclic_attention_window_size;
xqaParams.timestep = generationsParams.past_kv_length;
xqaParams.qkv_bias = generationsParams.qkv_bias;
xqaParams.sequence_lengths = generationsParams.sequence_lengths;
@ -241,8 +242,8 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
params.cache_indir = input_params.cache_indir;
params.batch_size = input_params.inference_batch_size;
params.beam_width = input_params.beam_width;
params.max_kv_cache_length = input_params.max_kv_cache_length;
params.cyclic_kv_cache_length = input_params.cyclic_kv_cache_length;
params.max_attention_window_size = input_params.max_attention_window;
params.cyclic_attention_window_size = input_params.cyclic_attention_window_size;
params.length_per_sample = input_params.sequence_lengths; // max_input_length + current output length
// timestep for shared memory size calculation and rotary embedding computation
params.timestep = input_params.step - 1;
@ -550,7 +551,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
{
using BufferDataType = typename KVCacheBufferDataType<KVCacheBuffer>::Type;
kv_cache_buffer = KVCacheBuffer(params.batch_size, 1,
isCrossAttention() ? params.cross_qkv_length : params.max_kv_cache_length,
isCrossAttention() ? params.cross_qkv_length : params.max_attention_window,
num_kv_heads * head_size * elem_size);
kv_cache_buffer.data = reinterpret_cast<BufferDataType*>(params.key_value_cache);
}
@ -651,7 +652,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
decoder_params.seqKVLengths = isCrossAttention() ? params.encoder_input_lengths : params.kv_seq_lengths;
decoder_params.batchSize = params.batch_size;
decoder_params.maxSeqLength = isCrossAttention() ? params.cross_qkv_length : params.input_seq_length;
decoder_params.maxKvCacheLength = params.cyclic_kv_cache_length;
decoder_params.attentionWindowSize = params.cyclic_attention_window_size;
decoder_params.numTokens = params.num_tokens;
decoder_params.attentionMaskType = mMaskType;
invokeBuildDecoderInfo(decoder_params, stream);
@ -712,7 +713,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
invokeApplyBiasRopeUpdateKVCache(const_cast<T*>(params.attention_input), q_buf_2_, kv_cache_buffer,
const_cast<T*>(params.qkv_bias), params.q_seq_lengths, params.kv_seq_lengths,
mRemovePadding ? padding_offset : nullptr, params.batch_size, params.input_seq_length,
params.cyclic_kv_cache_length, params.num_tokens, mNumHeads, mNumKVHeads, getHeadSize(),
params.cyclic_attention_window_size, params.num_tokens, mNumHeads, mNumKVHeads, getHeadSize(),
mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType, mRotaryEmbeddingScale,
mRotaryEmbeddingMaxPositions, position_embedding_type, (float*) nullptr, 0, cache_type,
params.kv_scale_orig_quant, enablePagedKVContextFMHA, stream);
@ -733,9 +734,9 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
// - cu_kv_seqlens: the cumulative kv sequence lengths, needed for variable sequence length.
// the token will pay attention to previous tokens while starting from max(0, rowIdx -
// cyclic_kv_cache_length);
// cyclic_attention_window_size);
mFMHARunner->setup_paged_kv(params.batch_size, params.input_seq_length, params.max_past_kv_len,
blocks_per_context_sequence, mTokensPerBlock, params.cyclic_kv_cache_length, params.num_tokens,
blocks_per_context_sequence, mTokensPerBlock, params.cyclic_attention_window_size, params.num_tokens,
isALiBi(), isAliBiWithScale(), mTpSize, mTpRank);
mFMHARunner->run_paged_kv(q_buf_2_, paged_kv_tma_desc, host_kv_cache_block_ptrs,
reinterpret_cast<KVBlockArray&>(kv_cache_buffer), cu_q_seqlens, cu_kv_seqlens, params.context_buf,
@ -744,8 +745,8 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
else
{
// the token will pay attention to previous tokens while starting from max(0, rowIdx -
// cyclic_kv_cache_length);
mFMHARunner->setup(params.batch_size, params.input_seq_length, params.cyclic_kv_cache_length,
// cyclic_attention_window_size);
mFMHARunner->setup(params.batch_size, params.input_seq_length, params.cyclic_attention_window_size,
params.num_tokens, isALiBi(), isAliBiWithScale(), mTpSize, mTpRank);
mFMHARunner->run(const_cast<T*>(params.attention_input), cu_q_seqlens, params.context_buf, stream);
}
@ -795,7 +796,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
{
invokeTranspose4dBatchMajor(k_buf_2_, v_buf_2_, kv_cache_buffer, params.batch_size,
isCrossAttention() ? params.cross_qkv_length : params.input_seq_length,
isCrossAttention() ? params.cross_qkv_length : params.cyclic_kv_cache_length, getHeadSize(),
isCrossAttention() ? params.cross_qkv_length : params.cyclic_attention_window_size, getHeadSize(),
mNumKVHeads, cache_type, params.kv_scale_orig_quant,
isCrossAttention() ? params.encoder_input_lengths : params.q_seq_lengths, stream);
}
@ -891,7 +892,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
// [num_heads, num_buckets], with necessary params (max_distance, num_buckets) passed at the end
invokeAddRelativeAttentionBiasUnaligned(qk_buf_float_, relative_attention_bias, params.batch_size,
mNumHeads, attention_seq_len_1,
isCrossAttention() ? params.cross_qkv_length : params.cyclic_kv_cache_length, stream,
isCrossAttention() ? params.cross_qkv_length : params.cyclic_attention_window_size, stream,
max_distance > 0, relative_attention_bias_stride, max_distance, true /* bidirectional */);
}
@ -918,8 +919,9 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
// already max_output_len + 1. In implicit mode, relative_attention_bias is relative_attention_table
// [num_heads, num_buckets], with necessary params (max_distance, num_buckets) passed at the end
invokeAddRelativeAttentionBiasUnaligned(qk_buf_, relative_attention_bias, params.batch_size, mNumHeads,
attention_seq_len_1, isCrossAttention() ? params.cross_qkv_length : params.cyclic_kv_cache_length,
stream, max_distance > 0, relative_attention_bias_stride, max_distance, true /* bidirectional */);
attention_seq_len_1,
isCrossAttention() ? params.cross_qkv_length : params.cyclic_attention_window_size, stream,
max_distance > 0, relative_attention_bias_stride, max_distance, true /* bidirectional */);
}
MaskedSoftmaxParam<T, T> param;
@ -1078,7 +1080,7 @@ int GPTAttentionPluginCommon::enqueueGeneration(
{
using BufferDataType = typename KVCacheBufferDataType<KVCacheBuffer>::Type;
kv_cache_buffer
= KVCacheBuffer(batch_beam, 1, params.max_kv_cache_length, num_kv_heads * head_size * elem_size);
= KVCacheBuffer(batch_beam, 1, params.max_attention_window, num_kv_heads * head_size * elem_size);
kv_cache_buffer.data = reinterpret_cast<BufferDataType*>(params.key_value_cache);
}
}
@ -1098,8 +1100,8 @@ int GPTAttentionPluginCommon::enqueueGeneration(
}
int timestep = params.past_kv_length;
const int max_timesteps
= mCrossAttention ? params.cyclic_kv_cache_length : std::min(timestep, params.cyclic_kv_cache_length);
const int max_timesteps = mCrossAttention ? params.cyclic_attention_window_size
: std::min(timestep, params.cyclic_attention_window_size);
int estimated_min_multi_block_count
= estimate_min_multi_block_count<T>(max_timesteps, mMaxSharedMemoryPerBlockOptin - 2048);
@ -1157,8 +1159,8 @@ int GPTAttentionPluginCommon::enqueueGeneration(
dispatch_params.size_per_head = getHeadSize();
dispatch_params.rotary_embedding_dim = mRotaryEmbeddingDim;
dispatch_params.position_embedding_type = mPositionEmbeddingType;
dispatch_params.max_kv_cache_length = params.max_kv_cache_length;
dispatch_params.cyclic_kv_cache_length = params.cyclic_kv_cache_length;
dispatch_params.max_attention_window = params.max_attention_window;
dispatch_params.cyclic_attention_window_size = params.cyclic_attention_window_size;
dispatch_params.input_lengths = params.context_lengths;
dispatch_params.step = step;
dispatch_params.q_scaling = q_scaling;
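
The hunks above rename the two window parameters without changing their relationship: max_attention_window is the allocated KV-cache capacity per sequence, while cyclic_attention_window_size bounds how far back a token attends and where a new token's K/V entry lands once the window wraps. The standalone sketch below only illustrates that relationship; KvCachePosition and locateKvEntry are hypothetical names, not plugin APIs, and the modulo wrap-around is an assumption for the common single-window case.

#include <algorithm>
#include <cstdint>

// Hypothetical helper (illustration only): given how many tokens were already
// generated, report how many past timesteps the generation kernel attends to
// and which cyclic slot receives the newest token's K/V entry.
struct KvCachePosition
{
    int32_t attendedTimesteps; // length of the attended history
    int32_t writeSlot;         // slot inside the cyclic KV buffer
};

KvCachePosition locateKvEntry(int32_t timestep, int32_t cyclicAttentionWindowSize, bool crossAttention)
{
    KvCachePosition pos{};
    // Mirrors the max_timesteps clamp above: self-attention never looks further back
    // than the cyclic window, while cross-attention always spans the encoder output.
    pos.attendedTimesteps
        = crossAttention ? cyclicAttentionWindowSize : std::min(timestep, cyclicAttentionWindowSize);
    // Once the window is full, new tokens overwrite the oldest slot (assumed wrap-around).
    pos.writeSlot = timestep % cyclicAttentionWindowSize;
    return pos;
}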

View File

@ -87,12 +87,12 @@ protected:
T const* qkv_bias;
int32_t input_seq_length; // padded input length
int32_t max_past_kv_len;
// By default, max_kv_cache_length == cyclic_kv_cache_length
// By default, max_attention_window == cyclic_attention_window_size
// unless each layer has a different cyclic kv cache length.
// Max cache capacity (used to allocate KV cache)
int32_t max_kv_cache_length;
int32_t max_attention_window;
// Cyclic kv cache capacity (used to get the cyclic kv cache position for new tokens)
int32_t cyclic_kv_cache_length;
int32_t cyclic_attention_window_size;
int32_t const* q_seq_lengths;
int32_t const* kv_seq_lengths;
float const* kv_scale_orig_quant;
@ -134,12 +134,12 @@ protected:
T* context_buf;
void* key_value_cache;
void* block_pointers;
// By default, max_kv_cache_length == cyclic_kv_cache_length
// By default, max_attention_window == cyclic_attention_window_size
// unless each layer has a different cyclic kv cache length.
// Max cache capacity (used to allocate KV cache)
int32_t max_kv_cache_length;
int32_t max_attention_window;
// Cyclic kv cache capacity (used to get the cyclic kv cache position for new tokens)
int32_t cyclic_kv_cache_length;
int32_t cyclic_attention_window_size;
int32_t num_requests;
int32_t max_blocks_per_sequence;
int32_t const* cache_indir;
@ -150,6 +150,7 @@ protected:
int relative_attention_bias_stride = 0;
// optional when cross attention
int32_t const* encoder_input_lengths = nullptr;
int32_t const* host_context_lengths = nullptr;
};
template <typename T, typename KVCacheBuffer>

View File

@ -66,7 +66,7 @@ bool GPTAttentionPlugin::isEntryUsed(const IdxEntry& entry) const
case IdxEntry::QKV_TENSOR: return true;
case IdxEntry::SEQUENCE_LENGTH: return useKVCache();
case IdxEntry::HOST_PAST_KEY_VALUE_LENGTHS: return useKVCache();
case IdxEntry::HOST_MAX_KV_CACHE_LENGTH: return true;
case IdxEntry::HOST_MAX_ATTENTION_WINDOW: return true;
case IdxEntry::CONTEXT_LENGTHS: return true;
case IdxEntry::CACHE_INDIR: return useKVCache();
case IdxEntry::REQUEST_TYPES: return true;
@ -132,7 +132,7 @@ bool GPTAttentionPlugin::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept
{
if (pos == getIdx(IdxEntry::CONTEXT_LENGTHS) || pos == getIdx(IdxEntry::REQUEST_TYPES)
|| pos == getIdx(IdxEntry::HOST_MAX_KV_CACHE_LENGTH))
|| pos == getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW))
{
return inOut[pos].type == nvinfer1::DataType::kINT32;
}
@ -319,17 +319,17 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
const int beamWidth = useKVCache() ? inputDesc[getIdx(IdxEntry::CACHE_INDIR)].dims.d[1] : 1;
// Commonly, cyclic kv cache length, and max kv cache length will be the same
// unless each layer has different max kv cache length.
// Commonly, cyclic_attention_window_size and max_attention_window_size will be the same
// unless each layer has different attention window sizes.
// the kv_cache capacity.
const int max_kv_cache_length = isCrossAttention()
const int max_attention_window_size = isCrossAttention()
? max_encoder_context_len
: (useKVCache() ? inputDesc[getIdx(IdxEntry::CACHE_INDIR)].dims.d[2] : 0);
// The cyclic_kv_cache_length will determine the cyclic kv cache position of new tokens.
// Note that this cyclic_kv_cache_length might be smaller than the actual kv cache capacity (max_kv_cache_length).
const int cyclic_kv_cache_length = isCrossAttention()
// The cyclic_attention_window_size will determine the cyclic kv cache position of new tokens.
// Note that this cyclic_attention_window_size might be smaller than the actual kv cache capacity.
const int cyclic_attention_window_size = isCrossAttention()
? max_encoder_context_len
: reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_MAX_KV_CACHE_LENGTH)])[0];
: reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW)])[0];
const float* kv_scale_orig_quant = nullptr;
const float* kv_scale_quant_orig = nullptr;
@ -395,9 +395,9 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
}
EnqueueContextParams<T, KVCacheBuffer> enqueue_params{attention_input, qkv_bias, max_context_q_len,
max_context_kv_len, max_kv_cache_length, cyclic_kv_cache_length, context_q_lengths, sequence_kv_length,
kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes, context_buf_, key_value_cache, block_pointers,
host_block_pointers, batch_size, localNbTokens, max_blocks_per_sequence, workspace};
max_context_kv_len, max_attention_window_size, cyclic_attention_window_size, context_q_lengths,
sequence_kv_length, kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes, context_buf_, key_value_cache,
block_pointers, host_block_pointers, batch_size, localNbTokens, max_blocks_per_sequence, workspace};
if (isRelativePosition())
{
enqueue_params.relative_attention_bias
@ -425,11 +425,14 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
const int* cache_indir
= beamWidth == 1 ? nullptr : reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::CACHE_INDIR)]);
const int* host_context_lengths
= mRemovePadding ? reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_CONTEXT_LENGTH)]) : nullptr;
EnqueueGenerationParams<T, KVCacheBuffer> enqueue_params{attention_input, qkv_bias, sequence_kv_length,
max_context_kv_len, beamWidth, context_q_lengths, kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes,
context_buf_, key_value_cache, block_pointers, max_kv_cache_length, cyclic_kv_cache_length, num_requests,
max_blocks_per_sequence, cache_indir, workspace, max_context_kv_len_list};
context_buf_, key_value_cache, block_pointers, max_attention_window_size, cyclic_attention_window_size,
num_requests, max_blocks_per_sequence, cache_indir, workspace, max_context_kv_len_list};
enqueue_params.host_context_lengths = host_context_lengths;
if (isRelativePosition())
{
enqueue_params.relative_attention_bias

View File

@ -46,7 +46,7 @@ namespace tensorrt_llm::plugins
// enable_remove_input_padding
// 1. sequence_length [batch_size] (optional)
// 2. host_past_key_value_lengths [batch_size] (int32) (optional)
// 3. host_max_kv_cache_lengths [1] (int32)
// 3. host_max_attention_window_sizes [1] (int32)
// 4. context_lengths [batch_size]
// 5. cache_indir [num_gen_requests, beam_width, memory_max_len] (required in beamsearch) (optional)
// 6. host_request_types [batch_size] int32. 0: context; 1: generation: 2: none. When not in inflight-batching
@ -141,7 +141,7 @@ private:
QKV_TENSOR,
SEQUENCE_LENGTH,
HOST_PAST_KEY_VALUE_LENGTHS,
HOST_MAX_KV_CACHE_LENGTH,
HOST_MAX_ATTENTION_WINDOW,
CONTEXT_LENGTHS,
CACHE_INDIR,
REQUEST_TYPES,
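
The enum above pairs with isEntryUsed/getIdx: optional inputs (such as the KV-cache-related tensors) may be absent, so the position of every later tensor has to be derived by counting only the entries that are actually present. The reduced, self-contained sketch below shows that indexing idea; the enum values and the usage rule are simplified stand-ins for the plugin's full list, not its actual implementation.

#include <cstdint>

// Illustrative reduced version of the index bookkeeping: entries that are not
// used for the current configuration (e.g. no KV cache) are skipped, so every
// later tensor's input index shifts accordingly.
enum class IdxEntry : int32_t
{
    QKV_TENSOR,
    SEQUENCE_LENGTH,
    HOST_PAST_KEY_VALUE_LENGTHS,
    HOST_MAX_ATTENTION_WINDOW,
    CONTEXT_LENGTHS,
    ENTRY_COUNT
};

bool isEntryUsed(IdxEntry entry, bool useKvCache)
{
    switch (entry)
    {
    case IdxEntry::SEQUENCE_LENGTH:
    case IdxEntry::HOST_PAST_KEY_VALUE_LENGTHS: return useKvCache;
    default: return true;
    }
}

// Returns the runtime input index of an entry by counting the used entries before it.
int32_t getIdx(IdxEntry entry, bool useKvCache)
{
    int32_t idx = 0;
    for (int32_t e = 0; e < static_cast<int32_t>(entry); ++e)
    {
        if (isEntryUsed(static_cast<IdxEntry>(e), useKvCache))
        {
            ++idx;
        }
    }
    return idx;
}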

View File

@ -15,9 +15,16 @@
* limitations under the License.
*/
#include "loraPlugin.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/groupGemm.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/cublasVersionCheck.h"
#include <algorithm>
using namespace nvinfer1;
using namespace tensorrt_llm::common;
using tensorrt_llm::plugins::LoraPluginCreator;
@ -31,6 +38,10 @@ static const char* LORA_PLUGIN_NAME{"Lora"};
PluginFieldCollection LoraPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> LoraPluginCreator::mPluginAttributes;
// TODO: should be managed in a better way
static std::vector<cublasHandle_t> cublas_handles;
static std::vector<cudaStream_t> streams;
// TODO should reuse the function in gemmPlugin
void _getProblemParams(cublasOperation_t& transa, cublasOperation_t& transb, int& m, int& n, int& k, int& lda, int& ldb,
int& ldc, bool transA, bool transB, int M, int N, int K)
@ -283,6 +294,34 @@ void LoraPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, in
mGemmId.k = K;
}
size_t getLowRankWorkSpaceSize(int nbReq, int maxContextLength, int maxLowRank, int typeSize)
{
return (size_t) divUp(nbReq * maxContextLength * maxLowRank * typeSize, 16) * 16;
}
size_t getCutlassWorkSpaceSize(int nbReq)
{
auto gemm_coord_size = divUp(nbReq * sizeof(cutlass::gemm::GemmCoord), 16) * 16;
auto ptr_size = 4 * divUp(nbReq * sizeof(half*), 16) * 16;
auto ldd_size = 4 * divUp(nbReq * sizeof(int64_t), 16) * 16;
return gemm_coord_size + ptr_size + ldd_size;
}
LoraPlugin::~LoraPlugin()
{
for (int i = 0; i < streams.size(); i++)
{
if (i != 0)
{
cudaStreamDestroy(streams.at(i));
}
cublasDestroy(cublas_handles.at(i));
}
streams.clear();
cublas_handles.clear();
}
size_t LoraPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept
{
@ -291,15 +330,30 @@ size_t LoraPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, in
auto const type = inputs[getInputTensorIdx()].type;
auto const typeSize = tensorrt_llm::runtime::BufferDataType(type).getSize();
size_t const lowRankWorkSpaceSize = nbReq * mMaxContextLength * mMaxLowRank * typeSize;
return CUBLAS_WORKSPACE_SIZE + getLowRankWorkSpaceSize(nbReq, mMaxContextLength, mMaxLowRank, typeSize)
+ getCutlassWorkSpaceSize(nbReq);
}
return CUBLAS_WORKSPACE_SIZE + lowRankWorkSpaceSize;
void runCublasGemmEx(const int M, const int N, const int K, const bool transA, const bool transB, const void* act,
const void* weight, void* output, cublasHandle_t cublas_handle)
{
float a = 1.0f;
float b = 0.0f;
void* alpha = &a;
void* beta = &b;
cublasOperation_t transa, transb;
int m, n, k;
int lda, ldb, ldc;
_getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, transA, transB, M, N, K);
tensorrt_llm::common::check_cuda_error(cublasGemmEx(cublas_handle, transa, transb, m, n, k, alpha, weight,
CUDA_R_16F, lda, act, CUDA_R_16F, ldb, beta, output, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
}
int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
// inputs
// input [-1, K] (view as 2D)
// host_request_type [batch_size] on cpu
@ -309,10 +363,9 @@ int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf
// outputs
// output [-1, N] (view as 2D)
auto const typeSize = tensorrt_llm::runtime::BufferDataType(mType).getSize();
void* cublasWorkSpace = workspace;
void* lowRankWorkSpace = static_cast<char*>(cublasWorkSpace) + CUBLAS_WORKSPACE_SIZE;
TLLM_CHECK(mType == DataType::kHALF); // Only support on half now, will extend to more data type in near future.
auto const typeSize = tensorrt_llm::runtime::BufferDataType(mType).getSize();
setGemmConfig();
auto const batch_size = inputDesc[getLoraRanksIdx()].dims.d[0];
auto const lora_ranks = static_cast<int32_t const*>(inputs[getLoraRanksIdx()]);
@ -321,55 +374,230 @@ int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf
= mRemoveInputPadding ? static_cast<int32_t const*>(inputs[getHostContextLengthsIdx()]) : nullptr;
RequestType const* reqTypes = static_cast<RequestType const*>(inputs[getHostRequestTypesIdx()]);
void* cublasWorkSpace = workspace;
void* lowRankWorkSpace = static_cast<char*>(cublasWorkSpace) + CUBLAS_WORKSPACE_SIZE;
void* cutlassWorkSpace = static_cast<char*>(lowRankWorkSpace)
+ getLowRankWorkSpaceSize(batch_size, mMaxContextLength, mMaxLowRank, typeSize);
size_t cutlassWorkSpaceSize = getCutlassWorkSpaceSize(batch_size);
size_t handled_token_num = 0;
// TODO only initialize output buffer when the lora rank is -1
const int nbDimsA = inputDesc[0].dims.nbDims;
bool useUnifyGemm = false;
for (int batchIdx = 0; batchIdx < batch_size; batchIdx++)
{
const RequestType reqType = reqTypes[batchIdx];
const auto M = (reqType != RequestType::kCONTEXT)
? 1
: (mRemoveInputPadding ? host_context_lengths[batchIdx] : inputDesc[0].dims.d[1]);
const auto lora_rank = lora_ranks[batchIdx];
if (lora_rank <= 0)
if (lora_weights_ptr[2 * batchIdx] != lora_weights_ptr[0]
|| lora_weights_ptr[2 * batchIdx + 1] != lora_weights_ptr[1] || lora_ranks[batchIdx] == 0)
{
const auto N = outputDesc[0].dims.d[outputDesc[0].dims.nbDims - 1];
void* output = static_cast<void*>(static_cast<char*>(outputs[0]) + handled_token_num * N * typeSize);
if (typeSize == 2)
{
deviceFill((half*) output, M * N, (half) 0.0f, stream);
}
else
{
deviceFill((float*) output, M * N, 0.0f, stream);
}
useUnifyGemm = false;
}
else
{
// the input shape should be [1, token_num, K] under remove_input_padding,
// [batch, seqlen, K] under non-remove_input_padding
auto bestTactic = mPluginProfiler->getBestConfig(M, mGemmId);
}
const int nbDimsA = inputDesc[0].dims.nbDims;
bool UseGroupedGemm = true;
if (useUnifyGemm)
{
const RequestType reqType = reqTypes[0];
int M = 0;
for (int batchIdx = 0; batchIdx < batch_size; batchIdx++)
{
M += (reqType != RequestType::kCONTEXT)
? 1
: (mRemoveInputPadding ? host_context_lengths[batchIdx] : inputDesc[0].dims.d[1]);
}
const auto lora_rank = lora_ranks[0];
auto bestTactic = mPluginProfiler->getBestConfig(M, mGemmId);
const auto N = lora_rank;
TLLM_CHECK_WITH_INFO(N <= mMaxLowRank,
fmtstr("Invalid low_rank (%d). low_rank must be smaller than mMaxLowRank (%d)", N, mMaxLowRank));
const auto K = mTransA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1]; // input hidden size
const auto N2 = outputDesc[0].dims.d[nbDimsA - 1];
// [M, K] -> [M, N] -> [M, N2]
void* lora_in_weight = reinterpret_cast<void*>(lora_weights_ptr[0]);
void* lora_out_weight = reinterpret_cast<void*>(lora_weights_ptr[1]);
const void* input = inputs[0];
void* output = outputs[0];
_runGemm(M, N, K, mTransA, mTransB, mType, mCublasWrapper, input, lora_in_weight, lowRankWorkSpace, bestTactic,
cublasWorkSpace, stream);
_runGemm(M, N2, N, mTransA, mTransB, mType, mCublasWrapper, lowRankWorkSpace, lora_out_weight, output,
bestTactic, cublasWorkSpace, stream);
}
else if (UseGroupedGemm)
{
int handled_token_num = 0;
std::vector<cutlass::gemm::GemmCoord> problem_sizes;
problem_sizes.reserve(batch_size);
std::vector<void*> ptrA;
ptrA.reserve(batch_size);
std::vector<void*> ptrB;
ptrB.reserve(batch_size);
std::vector<void*> ptrC;
ptrC.reserve(batch_size);
std::vector<void*> ptrD;
ptrD.reserve(batch_size);
std::vector<cutlass::gemm::GemmCoord> problem_sizes_2;
problem_sizes_2.reserve(batch_size);
std::vector<void*> ptrA_2;
ptrA_2.reserve(batch_size);
std::vector<void*> ptrB_2;
ptrB_2.reserve(batch_size);
std::vector<void*> ptrC_2;
ptrC_2.reserve(batch_size);
std::vector<void*> ptrD_2;
ptrD_2.reserve(batch_size);
for (int batchIdx = 0; batchIdx < batch_size; batchIdx++)
{
const RequestType reqType = reqTypes[batchIdx];
const auto M = (reqType != RequestType::kCONTEXT)
? 1
: (mRemoveInputPadding ? host_context_lengths[batchIdx] : inputDesc[0].dims.d[1]);
const auto lora_rank = lora_ranks[batchIdx];
const auto N = lora_rank;
if (N > 0)
{
TLLM_CHECK_WITH_INFO(N <= mMaxLowRank,
fmtstr("Invalid low_rank (%d). low_rank must be smaller than mMaxLowRank (%d)", N, mMaxLowRank));
const auto K = mTransA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1]; // input hidden size
TLLM_CHECK_WITH_INFO(N <= mMaxLowRank,
fmtstr("Invalid low_rank (%d). low_rank must be smaller than mMaxLowRank (%d)", N, mMaxLowRank));
const auto K = mTransA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1]; // input hidden size
const auto N2 = outputDesc[0].dims.d[nbDimsA - 1];
// [M, K] -> [M, N] -> [M, N2]
cutlass::gemm::GemmCoord problem(M, N, K);
problem_sizes.push_back(problem);
void* lora_in_weight = reinterpret_cast<void*>(lora_weights_ptr[batchIdx * 2 + 0]);
void* lora_out_weight = reinterpret_cast<void*>(lora_weights_ptr[batchIdx * 2 + 1]);
const void* input
= static_cast<const void*>(static_cast<const char*>(inputs[0]) + handled_token_num * K * typeSize);
void* output = static_cast<void*>(static_cast<char*>(outputs[0]) + handled_token_num * N2 * typeSize);
_runGemm(M, N, K, mTransA, mTransB, mType, mCublasWrapper, input, lora_in_weight, lowRankWorkSpace,
bestTactic, cublasWorkSpace, stream);
ptrA.push_back(static_cast<void*>(
static_cast<char*>(const_cast<void*>(inputs[0])) + handled_token_num * K * typeSize));
ptrB.push_back(reinterpret_cast<void*>(lora_weights_ptr[batchIdx * 2 + 0]));
ptrC.push_back(static_cast<void*>(
static_cast<char*>(lowRankWorkSpace) + handled_token_num * mMaxLowRank * typeSize));
ptrD.push_back(static_cast<void*>(
static_cast<char*>(lowRankWorkSpace) + handled_token_num * mMaxLowRank * typeSize));
_runGemm(M, N2, N, mTransA, mTransB, mType, mCublasWrapper, lowRankWorkSpace, lora_out_weight, output,
bestTactic, cublasWorkSpace, stream);
const auto N2 = outputDesc[0].dims.d[nbDimsA - 1];
cutlass::gemm::GemmCoord problem_2(M, N2, N);
problem_sizes_2.push_back(problem_2);
ptrA_2.push_back(static_cast<void*>(
static_cast<char*>(lowRankWorkSpace) + handled_token_num * mMaxLowRank * typeSize));
ptrB_2.push_back(reinterpret_cast<void*>(lora_weights_ptr[batchIdx * 2 + 1]));
ptrC_2.push_back(
static_cast<void*>(static_cast<char*>(outputs[0]) + handled_token_num * N2 * typeSize));
ptrD_2.push_back(
static_cast<void*>(static_cast<char*>(outputs[0]) + handled_token_num * N2 * typeSize));
}
handled_token_num += M;
}
handled_token_num += M;
tensorrt_llm::kernels::run_cutlass_1(problem_sizes, ptrA, ptrB, ptrC, ptrD, cutlassWorkSpace,
cutlassWorkSpaceSize, cublasWorkSpace, CUBLAS_WORKSPACE_SIZE, stream);
sync_check_cuda_error();
tensorrt_llm::kernels::run_cutlass_2(problem_sizes_2, ptrA_2, ptrB_2, ptrC_2, ptrD_2, cutlassWorkSpace,
cutlassWorkSpaceSize, cublasWorkSpace, CUBLAS_WORKSPACE_SIZE, stream);
sync_check_cuda_error();
}
else
{
if (streams.size() != batch_size)
{
for (int i = 0; i < batch_size; i++)
{
// TLLM_LOG_INFO("allocate %d stream and handle", i);
cudaStream_t stream_;
cublasHandle_t handle;
if (i == 0)
{
stream_ = stream;
}
else
{
cudaStreamCreate(&stream_);
}
cublasCreate(&handle);
cublasSetStream(handle, stream_);
cublas_handles.push_back(handle);
streams.push_back(stream_);
cudaStreamSynchronize(stream_);
}
}
if (!graphCreated)
{
cudaGraph_t graph;
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
std::vector<cudaEvent_t> events;
events.reserve(batch_size);
for (int batchIdx = 0; batchIdx < batch_size; batchIdx++)
{
cublasSetStream(cublas_handles[batchIdx], streams[batchIdx]);
cudaEventCreate(&events[batchIdx]);
const RequestType reqType = reqTypes[batchIdx];
const auto M = (reqType != RequestType::kCONTEXT)
? 1
: (mRemoveInputPadding ? host_context_lengths[batchIdx] : inputDesc[0].dims.d[1]);
const auto lora_rank = lora_ranks[batchIdx];
if (lora_rank <= 0)
{
const auto N = outputDesc[0].dims.d[outputDesc[0].dims.nbDims - 1];
void* output
= static_cast<void*>(static_cast<char*>(outputs[0]) + handled_token_num * N * typeSize);
if (typeSize == 2)
{
deviceFill((half*) output, M * N, (half) 0.0f, streams[batchIdx]);
}
else
{
deviceFill((float*) output, M * N, 0.0f, streams[batchIdx]);
}
}
else
{
// the input shape should be [1, token_num, K] under remove_input_padding,
// [batch, seqlen, K] under non-remove_input_padding
auto bestTactic = mPluginProfiler->getBestConfig(M, mGemmId);
const auto N = lora_rank;
TLLM_CHECK_WITH_INFO(N <= mMaxLowRank,
fmtstr(
"Invalid low_rank (%d). low_rank must be smaller than mMaxLowRank (%d)", N, mMaxLowRank));
const auto K
= mTransA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1]; // input hidden size
const auto N2 = outputDesc[0].dims.d[nbDimsA - 1];
// [M, K] -> [M, N] -> [M, N2]
void* lora_in_weight = reinterpret_cast<void*>(lora_weights_ptr[batchIdx * 2 + 0]);
void* lora_out_weight = reinterpret_cast<void*>(lora_weights_ptr[batchIdx * 2 + 1]);
const void* input = static_cast<const void*>(
static_cast<const char*>(inputs[0]) + handled_token_num * K * typeSize);
void* output
= static_cast<void*>(static_cast<char*>(outputs[0]) + handled_token_num * N2 * typeSize);
if (batchIdx > 0)
{
cudaStreamWaitEvent(streams[batchIdx], events[0]);
}
runCublasGemmEx(
M, N, K, mTransA, mTransB, input, lora_in_weight, lowRankWorkSpace, cublas_handles[batchIdx]);
runCublasGemmEx(M, N2, N, mTransA, mTransB, lowRankWorkSpace, lora_out_weight, output,
cublas_handles[batchIdx]);
cudaEventRecord(events[batchIdx], streams[batchIdx]);
}
handled_token_num += M;
cudaStreamWaitEvent(stream, events[batchIdx], 0);
}
cudaStreamEndCapture(stream, &graph);
cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
graphCreated = true;
}
cudaGraphLaunch(instance, stream);
}
return 0;
}
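
The enqueue path above boils down to two chained GEMMs per request: the input [M, K] is first projected through the low-rank "in" weight to [M, r] (staged in lowRankWorkSpace), then expanded through the "out" weight to [M, N2]. The sketch below is a plain CPU reference of that math under the assumption of row-major float matrices; LoraRequestRef and runLoraReference are illustrative names, not plugin APIs.

#include <algorithm>
#include <cstddef>
#include <vector>

// Illustrative CPU reference of the per-request LoRA computation:
// out[M x N2] = (in[M x K] * A[K x r]) * B[r x N2], with the intermediate
// [M x r] product staged in a shared scratch buffer, like lowRankWorkSpace above.
struct LoraRequestRef
{
    int M;                // tokens in this request
    int K;                // input hidden size
    int r;                // low rank for this request (0 means "no adapter": output stays zero)
    int N2;               // output hidden size
    const float* input;   // [M x K]
    const float* weightA; // [K x r]
    const float* weightB; // [r x N2]
    float* output;        // [M x N2]
};

static void matmul(const float* a, const float* b, float* c, int m, int n, int k)
{
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
        {
            float acc = 0.f;
            for (int p = 0; p < k; ++p)
                acc += a[i * k + p] * b[p * n + j];
            c[i * n + j] = acc;
        }
}

void runLoraReference(const std::vector<LoraRequestRef>& requests, std::vector<float>& scratch)
{
    for (const auto& req : requests)
    {
        if (req.r <= 0)
        {
            // Matches the zero-fill branch above for requests without a LoRA adapter.
            std::fill(req.output, req.output + static_cast<size_t>(req.M) * req.N2, 0.f);
            continue;
        }
        scratch.assign(static_cast<size_t>(req.M) * req.r, 0.f);
        matmul(req.input, req.weightA, scratch.data(), req.M, req.r, req.K);   // [M x K] * [K x r]
        matmul(scratch.data(), req.weightB, req.output, req.M, req.N2, req.r); // [M x r] * [r x N2]
    }
}

Both GPU paths above realize the same math: the grouped-GEMM branch packs the per-request (M, r, K) and (M, N2, r) problems into problem_sizes and ptrA..ptrD and issues two cutlass grouped launches, while the fallback runs per-request cublasGemmEx calls on side streams captured into a CUDA graph.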

View File

@ -44,7 +44,7 @@ public:
LoraPlugin(const void* data, size_t length, const PluginProfilerPtr& profiler);
~LoraPlugin() override = default;
~LoraPlugin(); // override = default;
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
@ -132,6 +132,9 @@ private:
GemmIdCublas mGemmId{};
PluginProfilerPtr mPluginProfiler;
bool graphCreated = false;
cudaGraphExec_t instance;
};
class LoraPluginCreator : public BaseCreator

View File

@ -19,12 +19,16 @@
#include "namedTensor.h"
#include "tensorrt_llm/batch_manager/GptManager.h"
#include "tensorrt_llm/batch_manager/callbacks.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/pybind/utils/pathCaster.h"
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/ops/tensor.h>
#include <functional>
#include <memory>
#include <optional>
@ -51,8 +55,8 @@ py::object GptManager::enter()
void GptManager::exit(py::handle type, py::handle value, py::handle traceback)
{
// NOTE: we must release the GIL here. GptManager has spawned a thread for the execution loop. That thread must be
// able to do forward progress for the shutdown process to succeed. For that, we must manually release the GIL while
// waiting in `process.join()`.
// able to do forward progress for the shutdown process to succeed. It takes the GIL during its callbacks, so
// we release it now. Note that we shouldn't do anything related to python objects after that.
py::gil_scoped_release release;
shutdown();
}
@ -85,4 +89,21 @@ tb::SendResponseCallback callbackAdapter(SendResponseCallback callback)
callback(id, pythonList, isOk, errMsg);
};
}
void GptManager::initBindings(py::module_& m)
{
py::class_<GptManager>(m, "GptManager")
.def(py::init<std::filesystem::path const&, tb::TrtGptModelType, int32_t, tb::batch_scheduler::SchedulerPolicy,
GetInferenceRequestsCallback, SendResponseCallback, tb::PollStopSignalCallback,
tb::ReturnBatchManagerStatsCallback, const tb::TrtGptModelOptionalParams&, std::optional<uint64_t>>(),
py::arg("trt_engine_path"), py::arg("model_type"), py::arg("max_beam_width"), py::arg("scheduler_policy"),
py::arg("get_inference_requests_cb"), py::arg("send_response_cb"), py::arg("poll_stop_signal_cb") = nullptr,
py::arg("return_batch_manager_stats_cb") = nullptr,
py::arg_v("optional_params", tb::TrtGptModelOptionalParams(), "TrtGptModelOptionalParams"),
py::arg("terminate_req_id") = std::nullopt)
.def("shutdown", &GptManager::exit)
.def("__enter__", &GptManager::enter)
.def("__exit__", &GptManager::exit);
}
} // namespace tensorrt_llm::pybind::batch_manager
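
The rewritten comment above captures the constraint behind the gil_scoped_release: the manager's execution-loop thread takes the GIL inside its callbacks, so the Python-facing shutdown must drop the GIL before blocking and must not touch Python objects afterwards. A minimal standalone illustration of that pattern follows; blockingShutdown is a placeholder for joining the worker thread, not a TensorRT-LLM function.

#include <pybind11/pybind11.h>
#include <chrono>
#include <thread>

namespace py = pybind11;

// Stand-in for a blocking C++ shutdown: in GptManager this is where the
// execution-loop thread is joined, and that thread grabs the GIL in its callbacks.
void blockingShutdown()
{
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
}

PYBIND11_MODULE(gil_example, m)
{
    m.def("shutdown",
        []()
        {
            // Release the GIL so any worker thread that needs it can make progress
            // while we block; after this point, no Python objects may be touched.
            py::gil_scoped_release release;
            blockingShutdown();
        });
}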

View File

@ -25,30 +25,31 @@
#include <ATen/ops/tensor.h>
#include <functional>
namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;
namespace tensorrt_llm::pybind::batch_manager
{
using GetInferenceRequestsCallback = std::function<std::list<InferenceRequest>(int32_t)>;
using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, const std::string&)>;
tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback callback);
tb::SendResponseCallback callbackAdapter(SendResponseCallback callback);
tensorrt_llm::batch_manager::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback callback);
tensorrt_llm::batch_manager::SendResponseCallback callbackAdapter(SendResponseCallback callback);
class GptManager : tb::GptManager
class GptManager : tensorrt_llm::batch_manager::GptManager
{
public:
GptManager(std::filesystem::path const& trtEnginePath, tb::TrtGptModelType modelType, int32_t maxBeamWidth,
tb::batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,
SendResponseCallback sendResponseCb, tb::PollStopSignalCallback pollStopSignalCb = nullptr,
tb::ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
const tb::TrtGptModelOptionalParams& optionalParams = tb::TrtGptModelOptionalParams(),
GptManager(std::filesystem::path const& trtEnginePath, tensorrt_llm::batch_manager::TrtGptModelType modelType,
int32_t maxBeamWidth, tensorrt_llm::batch_manager::batch_scheduler::SchedulerPolicy schedulerPolicy,
GetInferenceRequestsCallback getInferenceRequestsCb, SendResponseCallback sendResponseCb,
tensorrt_llm::batch_manager::PollStopSignalCallback pollStopSignalCb = nullptr,
tensorrt_llm::batch_manager::ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
const tensorrt_llm::batch_manager::TrtGptModelOptionalParams& optionalParams
= tensorrt_llm::batch_manager::TrtGptModelOptionalParams(),
std::optional<uint64_t> terminateReqId = std::nullopt);
py::object enter();
void exit(py::handle type, py::handle value, py::handle traceback);
pybind11::object enter();
void exit(pybind11::handle type, pybind11::handle value, pybind11::handle traceback);
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::pybind::batch_manager

View File

@ -20,6 +20,11 @@
#include "tensorrt_llm/runtime/torchView.h"
#include <memory>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime;
@ -48,5 +53,46 @@ std::shared_ptr<tb::InferenceRequest> InferenceRequest::toTrtLlm() const
tensorMap[name] = tr::TorchView::of(tensor.value());
}
}
return std::make_shared<tb::InferenceRequest>(tensorMap, mRequestId);
auto inferenceRequest = std::make_shared<tb::InferenceRequest>(tensorMap, mRequestId);
inferenceRequest->setIsStreaming(isStreaming());
return inferenceRequest;
}
void InferenceRequest::initBindings(py::module_& m)
{
py::class_<InferenceRequest>(m, "InferenceRequest")
.def(py::init<uint64_t>())
.def(py::init<uint64_t, InferenceRequest::TensorMap const&>(), "deprecated: use direct tensor access instead")
.def_property("input_ids", &InferenceRequest::getInputIdsUnchecked, &InferenceRequest::setInputIds)
.def_property(
"draft_input_ids", &InferenceRequest::getDraftInputIdsUnchecked, &InferenceRequest::setDraftInputIds)
.def_property("draft_logits", &InferenceRequest::getDraftLogitsUnchecked, &InferenceRequest::setDraftLogits)
.def_property("max_new_tokens", &InferenceRequest::getMaxNewTokensUnchecked, &InferenceRequest::setMaxNewTokens)
.def_property("beam_width", &InferenceRequest::getBeamWidthUnchecked, &InferenceRequest::setBeamWidth)
.def_property("end_id", &InferenceRequest::getEndIdUnchecked, &InferenceRequest::setEndId)
.def_property("pad_id", &InferenceRequest::getPadIdUnchecked, &InferenceRequest::setPadId)
.def_property("bad_words_list", &InferenceRequest::getBadWordsListUnchecked, &InferenceRequest::setBadWordsList)
.def_property(
"stop_words_list", &InferenceRequest::getStopWordsListUnchecked, &InferenceRequest::setStopWordsList)
.def_property(
"embedding_bias", &InferenceRequest::getEmbeddingBiasUnchecked, &InferenceRequest::setEmbeddingBias)
.def_property("temperature", &InferenceRequest::getTemperatureUnchecked, &InferenceRequest::setTemperature)
.def_property("runtime_top_k", &InferenceRequest::getRuntimeTopKUnchecked, &InferenceRequest::setRuntimeTopK)
.def_property("runtime_top_p", &InferenceRequest::getRuntimeTopPUnchecked, &InferenceRequest::setRuntimeTopP)
.def_property(
"length_penalty", &InferenceRequest::getLengthPenaltyUnchecked, &InferenceRequest::setLengthPenalty)
.def_property("repetition_penalty", &InferenceRequest::getRepetitionPenaltyUnchecked,
&InferenceRequest::setRepetitionPenalty)
.def_property("min_length", &InferenceRequest::getMinLengthUnchecked, &InferenceRequest::setMinLength)
.def_property(
"presence_penalty", &InferenceRequest::getPresencePenaltyUnchecked, &InferenceRequest::setPresencePenalty)
.def_property("random_seed", &InferenceRequest::getRandomSeedUnchecked, &InferenceRequest::setRandomSeed)
.def_property(
"return_log_probs", &InferenceRequest::getReturnLogProbsUnchecked, &InferenceRequest::setReturnLogProbs)
.def_property("prompt_embedding_table", &InferenceRequest::getPromptEmbeddingTableUnchecked,
&InferenceRequest::setPromptEmbeddingTable)
.def_property(
"prompt_vocab_size", &InferenceRequest::getPromptVocabSizeUnchecked, &InferenceRequest::setPromptVocabSize)
.def_property("is_streaming", &InferenceRequest::isStreaming, &InferenceRequest::setIsStreaming)
.def_property_readonly("request_id", &InferenceRequest::getRequestId);
}

View File

@ -19,9 +19,9 @@
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/pybind/batch_manager/namedTensor.h"
#include "tensorrt_llm/runtime/common.h"
#include <ATen/ATen.h>
#include <pybind11/pybind11.h>
#include <memory>
#include <optional>
@ -53,6 +53,7 @@ public:
}
[[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::InferenceRequest> toTrtLlm() const;
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::pybind::batch_manager

View File

@ -21,6 +21,11 @@
#include "tensorrt_llm/runtime/torchView.h"
#include <memory>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime;
@ -53,3 +58,61 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, mPromptVocabSize, mReturnLogProbs,
mDraftTokens, draftLogits);
}
void LlmRequest::initBindings(py::module_& m)
{
py::class_<LlmRequest>(m, "LlmRequest")
.def(py::init<LlmRequest::RequestIdType, LlmRequest::SizeType, LlmRequest::VecTokens, tr::SamplingConfig, bool,
std::optional<LlmRequest::SizeType>, std::optional<LlmRequest::SizeType>,
std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
std::optional<LlmRequest::SizeType>, bool, std::optional<LlmRequest::VecTokens>,
std::optional<LlmRequest::TensorPtr>>(),
py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt,
py::arg("stop_words_list") = std::nullopt, py::arg("prompt_embedding_table") = std::nullopt,
py::arg("prompt_vocab_size") = std::nullopt, py::arg("return_log_probs") = false,
py::arg("draft_tokens") = std::nullopt, py::arg("draft_logits") = std::nullopt)
.def("get_num_tokens", &LlmRequest::getNumTokens, py::arg("beam"))
.def_property_readonly("max_beam_num_tokens", &LlmRequest::getMaxBeamNumTokens)
.def("get_token", &LlmRequest::getToken, py::arg("beam"), py::arg("pos"))
.def("get_tokens", py::overload_cast<LlmRequest::SizeType>(&LlmRequest::getTokens, py::const_), py::arg("beam"))
.def("get_tokens", py::overload_cast<>(&LlmRequest::getTokens, py::const_))
.def_property_readonly("max_num_generated_tokens", &LlmRequest::getMaxNumGeneratedTokens)
.def("add_new_token", &LlmRequest::addNewToken, py::arg("token"), py::arg("beam"))
.def("add_new_tokens", &LlmRequest::addNewTokens, py::arg("beam_tokens"))
.def("set_generated_tokens", &LlmRequest::setGeneratedTokens, py::arg("generated_beam_tokens"))
.def("pause", &LlmRequest::pause, py::arg("max_input_len"))
.def_property("max_sent_token_pos", &LlmRequest::getMaxSentTokenPos, &LlmRequest::setMaxSentTokenPos)
.def_property_readonly("prompt_embedding_table", &LlmRequest::getPromptEmbeddingTable)
.def_property_readonly("prompt_vocab_size", &LlmRequest::getPromptVocabSize)
.def_property_readonly("embedding_bias", &LlmRequest::getEmbeddingBias)
.def_property_readonly("bad_words_list", &LlmRequest::getBadWordsList)
.def_property_readonly("stop_words_list", &LlmRequest::getStopWordsList)
.def_readwrite("request_id", &LlmRequest::mRequestId)
.def_readwrite("prompt_len", &LlmRequest::mPromptLen)
.def_readwrite("max_new_tokens", &LlmRequest::mMaxNewTokens)
.def_readwrite("sampling_config", &LlmRequest::mSamplingConfig)
.def_readwrite("state", &LlmRequest::mState)
.def_readwrite("is_streaming", &LlmRequest::mIsStreaming)
.def_readwrite("end_id", &LlmRequest::mEndId)
.def_readwrite("pad_id", &LlmRequest::mPadId)
.def_readwrite("seq_slot", &LlmRequest::mSeqSlot)
.def_property_readonly("return_log_probs", &LlmRequest::returnLogProbs)
.def_property_readonly("log_probs", py::overload_cast<>(&LlmRequest::getLogProbs, py::const_))
.def("get_log_probs", py::overload_cast<SizeType>(&LlmRequest::getLogProbs, py::const_))
.def("set_log_probs", &LlmRequest::setLogProbs, py::arg("log_probs"), py::arg("beam"))
.def_property_readonly("cum_log_probs", &LlmRequest::getCumLogProbs)
.def("set_cum_log_prob", &LlmRequest::setCumLogProb, py::arg("cum_log_prob"), py::arg("beam"))
.def_property_readonly("orig_prompt_len", &LlmRequest::getOrigPromptLen)
.def("has_draft_tokens", &LlmRequest::hasDraftTokens)
.def_property(
"draft_tokens", [](LlmRequest& self) { return *self.getDraftTokens(); },
[](LlmRequest& self, LlmRequest::VecTokens& draftTokens)
{ self.setDraftTokens(std::make_shared<LlmRequest::VecTokens>(std::move(draftTokens))); })
.def_property(
"draft_logits", [](LlmRequest& self) { return self.getDraftLogits(); },
[](LlmRequest& self, LlmRequest::TensorPtr& logits)
{ self.setDraftLogits(std::make_optional<LlmRequest::TensorPtr>(logits)); });
}

View File

@ -18,13 +18,12 @@
#pragma once
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/assert.h"
#include <ATen/ATen.h>
#include <ATen/ops/tensor.h>
#include <memory>
#include <optional>
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::batch_manager
{
@ -58,6 +57,7 @@ public:
}
[[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> toTrtLlm() const;
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::pybind::batch_manager

View File

@ -16,9 +16,14 @@
*/
#include "namedTensor.h"
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/runtime/torch.h"
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;
namespace tensorrt_llm::pybind::batch_manager
@ -29,4 +34,12 @@ NamedTensor::NamedTensor(const tb::NamedTensor& cppNamedTensor)
{
}
void NamedTensor::initBindings(py::module_& m)
{
py::class_<NamedTensor>(m, "NamedTensor")
.def(py::init<NamedTensor::TensorPtr, std::string>(), py::arg("tensor"), py::arg("name"))
.def_readwrite("tensor", &NamedTensor::tensor)
.def_readonly("name", &NamedTensor::name);
}
} // namespace tensorrt_llm::pybind::batch_manager

View File

@ -18,28 +18,19 @@
#pragma once
#include "tensorrt_llm/batch_manager/namedTensor.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include <ATen/ATen.h>
#include <ATen/core/ATen_fwd.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/tensor.h>
#include <c10/core/DeviceType.h>
#include <c10/util/ArrayRef.h>
#include <memory>
#include <optional>
namespace tb = tensorrt_llm::batch_manager;
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::batch_manager
{
class NamedTensor : public tb::GenericNamedTensor<std::optional<at::Tensor>>
class NamedTensor : public tensorrt_llm::batch_manager::GenericNamedTensor<std::optional<at::Tensor>>
{
public:
using Base = tb::GenericNamedTensor<std::optional<at::Tensor>>;
using Base = tensorrt_llm::batch_manager::GenericNamedTensor<std::optional<at::Tensor>>;
using TensorPtr = Base::TensorPtr;
NamedTensor(TensorPtr _tensor, std::string _name)
@ -48,7 +39,8 @@ public:
explicit NamedTensor(std::string _name)
: Base(std::move(_name)){};
explicit NamedTensor(const tb::NamedTensor& cppNamedTensor);
explicit NamedTensor(const tensorrt_llm::batch_manager::NamedTensor& cppNamedTensor);
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::pybind::batch_manager

View File

@ -58,45 +58,18 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
{
m.doc() = "TensorRT-LLM Python bindings for C++ runtime";
py::class_<tpr::PromptTuningParams>(m, "PromptTuningParams")
.def(py::init<tpr::PromptTuningParams::TensorPtr, tpr::PromptTuningParams::TensorPtr,
tpr::PromptTuningParams::TensorPtr>(),
py::arg("embedding_table") = py::none(), py::arg("tasks") = py::none(), py::arg("vocab_size") = py::none())
.def_readwrite("embedding_table", &tpr::PromptTuningParams::embeddingTable)
.def_readwrite("tasks", &tpr::PromptTuningParams::tasks)
.def_readwrite("vocab_size", &tpr::PromptTuningParams::vocabSize)
.def_readwrite("prompt_tuning_enabled", &tpr::PromptTuningParams::promptTuningEnabled);
py::class_<tpr::GenerationInput>(m, "GenerationInput")
.def(py::init<SizeType, SizeType, tpr::GenerationInput::TensorPtr, tpr::GenerationInput::TensorPtr, bool>(),
py::arg("end_id"), py::arg("pad_id"), py::arg("ids"), py::arg("lengths"), py::arg("packed") = false)
.def_readwrite("end_id", &tpr::GenerationInput::endId)
.def_readwrite("pad_id", &tpr::GenerationInput::padId)
.def_readwrite("ids", &tpr::GenerationInput::ids)
.def_readwrite("lengths", &tpr::GenerationInput::lengths)
.def_readwrite("packed", &tpr::GenerationInput::packed)
.def_readwrite("embedding_bias", &tpr::GenerationInput::embeddingBias)
.def_readwrite("bad_words_list", &tpr::GenerationInput::badWordsList)
.def_readwrite("stop_words_list", &tpr::GenerationInput::stopWordsList)
.def_readwrite("max_new_tokens", &tpr::GenerationInput::maxNewTokens)
.def_readwrite("prompt_tuning_params", &tpr::GenerationInput::promptTuningParams);
py::class_<tpr::GenerationOutput>(m, "GenerationOutput")
.def(py::init<tpr::GenerationOutput::TensorPtr, tpr::GenerationOutput::TensorPtr>(), py::arg("ids"),
py::arg("lengths"))
.def_readwrite("ids", &tpr::GenerationOutput::ids)
.def_readwrite("lengths", &tpr::GenerationOutput::lengths)
.def_readwrite("log_probs", &tpr::GenerationOutput::logProbs)
.def_readwrite("context_logits", &tpr::GenerationOutput::contextLogits)
.def_readwrite("on_token_generated", &tpr::GenerationOutput::onTokenGenerated);
tpr::PromptTuningParams::initBindings(m);
tpr::GenerationInput::initBindings(m);
tpr::GenerationOutput::initBindings(m);
py::class_<tbk::KvCacheConfig>(m, "KvCacheConfig")
.def(py::init<std::optional<SizeType>, std::optional<SizeType>, std::optional<float>>(),
py::arg("max_tokens") = py::none(), py::arg("max_kv_cache_length") = py::none(),
py::arg("free_gpu_memory_fraction") = py::none())
.def(py::init<std::optional<SizeType>, std::optional<SizeType>, std::optional<float>, bool>(),
py::arg("max_tokens") = py::none(), py::arg("max_attention_window") = py::none(),
py::arg("free_gpu_memory_fraction") = py::none(), py::arg("enable_block_reuse") = false)
.def_readwrite("max_tokens", &tbk::KvCacheConfig::maxTokens)
.def_readwrite("max_kv_cache_length", &tbk::KvCacheConfig::maxKvCacheLength)
.def_readwrite("free_gpu_memory_fraction", &tbk::KvCacheConfig::freeGpuMemoryFraction);
.def_readwrite("max_attention_window", &tbk::KvCacheConfig::maxAttentionWindow)
.def_readwrite("free_gpu_memory_fraction", &tbk::KvCacheConfig::freeGpuMemoryFraction)
.def_readwrite("enable_block_reuse", &tbk::KvCacheConfig::enableBlockReuse);
py::class_<tr::GptSession::Config>(m, "GptSessionConfig")
.def(py::init<SizeType, SizeType, SizeType>(), py::arg("max_batch_size"), py::arg("max_beam_width"),
@ -184,9 +157,13 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def_property("quant_mode", &tr::GptModelConfig::getQuantMode, &tr::GptModelConfig::setQuantMode)
.def_property_readonly("supports_inflight_batching", &tr::GptModelConfig::supportsInflightBatching)
.def_property("max_batch_size", &tr::GptModelConfig::getMaxBatchSize, &tr::GptModelConfig::setMaxBatchSize)
.def_property("max_beam_width", &tr::GptModelConfig::getMaxBeamWidth, &tr::GptModelConfig::setMaxBeamWidth)
.def_property("max_input_len", &tr::GptModelConfig::getMaxInputLen, &tr::GptModelConfig::setMaxInputLen)
.def_property("max_output_len", &tr::GptModelConfig::getMaxOutputLen, &tr::GptModelConfig::setMaxOutputLen)
.def_property("max_num_tokens", &tr::GptModelConfig::getMaxNumTokens, &tr::GptModelConfig::setMaxNumTokens)
.def_property("max_prompt_embedding_table_size", &tr::GptModelConfig::getMaxPromptEmbeddingTableSize,
&tr::GptModelConfig::setMaxPromptEmbeddingTableSize)
.def_property_readonly("use_prompt_tuning", &tr::GptModelConfig::usePromptTuning)
.def_property("compute_context_logits",
py::overload_cast<>(&tr::GptModelConfig::computeContextLogits, py::const_),
py::overload_cast<bool>(&tr::GptModelConfig::computeContextLogits))
@ -212,9 +189,10 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def_property_readonly("pipeline_parallel_rank", &tr::WorldConfig::getPipelineParallelRank)
.def_property_readonly("tensor_parallel_rank", &tr::WorldConfig::getTensorParallelRank)
.def_static("mpi",
py::overload_cast<SizeType, std::optional<SizeType>, std::optional<SizeType>>(&tr::WorldConfig::mpi),
py::overload_cast<SizeType, std::optional<SizeType>, std::optional<SizeType>,
std::optional<std::vector<SizeType>>>(&tr::WorldConfig::mpi),
py::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode, py::arg("tensor_parallelism") = py::none(),
py::arg("pipeline_parallelism") = py::none());
py::arg("pipeline_parallelism") = py::none(), py::arg("user_specified_device_ids") = py::none());
py::class_<tr::SamplingConfig>(m, "SamplingConfig")
.def(py::init<SizeType>(), py::arg("beam_width") = 1)
@ -233,14 +211,15 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def_readwrite("length_penalty", &tr::SamplingConfig::lengthPenalty);
py::class_<tr::GptJsonConfig>(m, "GptJsonConfig")
.def(py::init<std::string, std::string, SizeType, SizeType, tr::GptModelConfig>(), py::arg("name"),
py::arg("precision"), py::arg("tensor_parallelism"), py::arg("pipeline_parallelism"),
.def(py::init<std::string, std::string, std::string, SizeType, SizeType, tr::GptModelConfig>(), py::arg("name"),
py::arg("version"), py::arg("precision"), py::arg("tensor_parallelism"), py::arg("pipeline_parallelism"),
py::arg("model_config"))
.def_static("parse", py::overload_cast<std::string const&>(&tr::GptJsonConfig::parse), py::arg("json"))
.def_static(
"parse_file", py::overload_cast<std::filesystem::path const&>(&tr::GptJsonConfig::parse), py::arg("path"))
.def_property_readonly("model_config", &tr::GptJsonConfig::getModelConfig)
.def_property_readonly("name", &tr::GptJsonConfig::getName)
.def_property_readonly("version", &tr::GptJsonConfig::getVersion)
.def_property_readonly("precision", &tr::GptJsonConfig::getPrecision)
.def_property_readonly("tensor_parallelism", &tr::GptJsonConfig::getTensorParallelism)
.def_property_readonly("pipeline_parallelism", &tr::GptJsonConfig::getPipelineParallelism)
@ -254,6 +233,14 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
py::arg("world_config"));
py::class_<tr::GptSession>(m, "GptSession")
.def(py::init(
[](tr::GptSession::Config const& config, tr::GptModelConfig const& modelConfig,
tr::WorldConfig const& worldConfig, py::bytearray const& bytes)
{
auto buf = static_cast<std::string>(bytes);
return tr::GptSession{config, modelConfig, worldConfig, buf.data(), buf.size()};
}),
py::arg("config"), py::arg("model_config"), py::arg("world_config"), py::arg("engine_buffer"))
.def(py::init<tr::GptSession::Config, tr::GptModelConfig, tr::WorldConfig, std::string>(), py::arg("config"),
py::arg("model_config"), py::arg("world_config"), py::arg("engine_file"))
.def_property_readonly("model_config", &tr::GptSession::getModelConfig)
@ -272,65 +259,8 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.value("REQUEST_STATE_GENERATION_IN_PROGRESS", tb::LlmRequestState_t::REQUEST_STATE_GENERATION_IN_PROGRESS)
.value("REQUEST_STATE_GENERATION_COMPLETE", tb::LlmRequestState_t::REQUEST_STATE_GENERATION_COMPLETE);
using LlmRequest = tpb::LlmRequest;
py::class_<LlmRequest>(m, "LlmRequest")
.def(py::init<LlmRequest::RequestIdType, LlmRequest::SizeType, std::vector<LlmRequest::TokenIdType>,
tr::SamplingConfig, bool, std::optional<LlmRequest::SizeType>, std::optional<LlmRequest::SizeType>,
std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
std::optional<LlmRequest::SizeType>, bool, std::optional<LlmRequest::VecTokens>,
std::optional<LlmRequest::TensorPtr>>(),
py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt,
py::arg("stop_words_list") = std::nullopt, py::arg("prompt_embedding_table") = std::nullopt,
py::arg("prompt_vocab_size") = std::nullopt, py::arg("return_log_probs") = false,
py::arg("draft_tokens") = std::nullopt, py::arg("draft_logits") = std::nullopt)
.def("get_num_tokens", &LlmRequest::getNumTokens, py::arg("beam"))
.def_property_readonly("max_beam_num_tokens", &LlmRequest::getMaxBeamNumTokens)
.def("get_token", &LlmRequest::getToken, py::arg("beam"), py::arg("pos"))
.def("get_tokens", &LlmRequest::getTokens, py::arg("beam"))
.def_property_readonly("max_num_generated_tokens", &LlmRequest::getMaxNumGeneratedTokens)
.def("add_new_token", &LlmRequest::addNewToken, py::arg("token"), py::arg("beam"))
.def("add_new_tokens", &LlmRequest::addNewTokens, py::arg("beam_tokens"))
.def("set_generated_tokens", &LlmRequest::setGeneratedTokens, py::arg("generated_beam_tokens"))
.def("pause", &LlmRequest::pause, py::arg("max_input_len"))
.def_property("max_sent_token_pos", &LlmRequest::getMaxSentTokenPos, &LlmRequest::setMaxSentTokenPos)
.def_property_readonly("prompt_embedding_table", &LlmRequest::getPromptEmbeddingTable)
.def_property_readonly("prompt_vocab_size", &LlmRequest::getPromptVocabSize)
.def_property_readonly("embedding_bias", &LlmRequest::getEmbeddingBias)
.def_property_readonly("bad_words_list", &LlmRequest::getBadWordsList)
.def_property_readonly("stop_words_list", &LlmRequest::getStopWordsList)
.def_readwrite("request_id", &LlmRequest::mRequestId)
.def_readwrite("prompt_len", &LlmRequest::mPromptLen)
.def_readwrite("max_new_tokens", &LlmRequest::mMaxNewTokens)
.def_readwrite("sampling_config", &LlmRequest::mSamplingConfig)
.def_readwrite("state", &LlmRequest::mState)
.def_readwrite("is_streaming", &LlmRequest::mIsStreaming)
.def_readwrite("end_id", &LlmRequest::mEndId)
.def_readwrite("pad_id", &LlmRequest::mPadId)
.def_readwrite("batch_slot", &LlmRequest::mBatchSlot)
.def_property_readonly("return_log_probs", &LlmRequest::returnLogProbs)
.def_property_readonly("log_probs", py::overload_cast<>(&LlmRequest::getLogProbs, py::const_))
.def("get_log_probs", py::overload_cast<SizeType>(&LlmRequest::getLogProbs, py::const_))
.def("set_log_probs", &LlmRequest::setLogProbs, py::arg("log_probs"), py::arg("beam"))
.def_property_readonly("cum_log_probs", &LlmRequest::getCumLogProbs)
.def("set_cum_log_prob", &LlmRequest::setCumLogProb, py::arg("cum_log_prob"), py::arg("beam"))
.def_property_readonly("orig_prompt_len", &LlmRequest::getOrigPromptLen)
.def("has_draft_tokens", &LlmRequest::hasDraftTokens)
.def_property(
"draft_tokens", [](LlmRequest& self) { return *self.getDraftTokens(); },
[](LlmRequest& self, LlmRequest::VecTokens& draftTokens)
{ self.setDraftTokens(std::make_shared<LlmRequest::VecTokens>(std::move(draftTokens))); })
.def_property(
"draft_logits", [](LlmRequest& self) { return self.getDraftLogits(); },
[](LlmRequest& self, LlmRequest::TensorPtr& logits)
{ self.setDraftLogits(std::make_optional<LlmRequest::TensorPtr>(logits)); });
py::class_<tpb::NamedTensor>(m, "NamedTensor")
.def(py::init<tpb::NamedTensor::TensorPtr, std::string>(), py::arg("tensor"), py::arg("name"))
.def_readwrite("tensor", &tpb::NamedTensor::tensor)
.def_readonly("name", &tpb::NamedTensor::name);
tpb::NamedTensor::initBindings(m);
tpb::LlmRequest::initBindings(m);
auto tensorNames = m.def_submodule("tensor_names");
// Input tensor names
@ -362,42 +292,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
tensorNames.attr("OUTPUT_LOG_PROBS") = py::str(tb::inference_request::kLogProbsTensorName);
tensorNames.attr("CUM_LOG_PROBS") = py::str(tb::inference_request::kCumLogProbsTensorName);
using InferenceRequest = tpb::InferenceRequest;
py::class_<InferenceRequest>(m, "InferenceRequest")
.def(py::init<uint64_t>())
.def(py::init<uint64_t, InferenceRequest::TensorMap const&>(), "deprecated: use direct tensor access instead")
.def_property("input_ids", &InferenceRequest::getInputIdsUnchecked, &InferenceRequest::setInputIds)
.def_property(
"draft_input_ids", &InferenceRequest::getDraftInputIdsUnchecked, &InferenceRequest::setDraftInputIds)
.def_property("draft_logits", &InferenceRequest::getDraftLogitsUnchecked, &InferenceRequest::setDraftLogits)
.def_property("max_new_tokens", &InferenceRequest::getMaxNewTokensUnchecked, &InferenceRequest::setMaxNewTokens)
.def_property("beam_width", &InferenceRequest::getBeamWidthUnchecked, &InferenceRequest::setBeamWidth)
.def_property("end_id", &InferenceRequest::getEndIdUnchecked, &InferenceRequest::setEndId)
.def_property("pad_id", &InferenceRequest::getPadIdUnchecked, &InferenceRequest::setPadId)
.def_property("bad_words_list", &InferenceRequest::getBadWordsListUnchecked, &InferenceRequest::setBadWordsList)
.def_property(
"stop_words_list", &InferenceRequest::getStopWordsListUnchecked, &InferenceRequest::setStopWordsList)
.def_property(
"embedding_bias", &InferenceRequest::getEmbeddingBiasUnchecked, &InferenceRequest::setEmbeddingBias)
.def_property("temperature", &InferenceRequest::getTemperatureUnchecked, &InferenceRequest::setTemperature)
.def_property("runtime_top_k", &InferenceRequest::getRuntimeTopKUnchecked, &InferenceRequest::setRuntimeTopK)
.def_property("runtime_top_p", &InferenceRequest::getRuntimeTopPUnchecked, &InferenceRequest::setRuntimeTopP)
.def_property(
"length_penalty", &InferenceRequest::getLengthPenaltyUnchecked, &InferenceRequest::setLengthPenalty)
.def_property("repetition_penalty", &InferenceRequest::getRepetitionPenaltyUnchecked,
&InferenceRequest::setRepetitionPenalty)
.def_property("min_length", &InferenceRequest::getMinLengthUnchecked, &InferenceRequest::setMinLength)
.def_property(
"presence_penalty", &InferenceRequest::getPresencePenaltyUnchecked, &InferenceRequest::setPresencePenalty)
.def_property("random_seed", &InferenceRequest::getRandomSeedUnchecked, &InferenceRequest::setRandomSeed)
.def_property(
"return_log_probs", &InferenceRequest::getReturnLogProbsUnchecked, &InferenceRequest::setReturnLogProbs)
.def_property("prompt_embedding_table", &InferenceRequest::getPromptEmbeddingTableUnchecked,
&InferenceRequest::setPromptEmbeddingTable)
.def_property(
"prompt_vocab_size", &InferenceRequest::getPromptVocabSizeUnchecked, &InferenceRequest::setPromptVocabSize)
.def_property("is_streaming", &InferenceRequest::isStreaming, &InferenceRequest::setIsStreaming)
.def_property_readonly("request_id", &InferenceRequest::getRequestId);
tpb::InferenceRequest::initBindings(m);
py::enum_<tb::TrtGptModelType>(m, "TrtGptModelType")
.value("V1", tb::TrtGptModelType::V1)
@ -409,22 +304,14 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.value("GUARANTEED_NO_EVICT", tbb::SchedulerPolicy::GUARANTEED_NO_EVICT);
py::class_<tb::TrtGptModelOptionalParams>(m, "TrtGptModelOptionalParams")
.def(py::init<tbk::KvCacheConfig, std::optional<SizeType>, bool>(),
py::arg("kv_cache_config") = tbk::KvCacheConfig{}, py::arg("max_num_sequences") = py::none(),
py::arg("enable_trt_overlap") = true)
.def(py::init<tbk::KvCacheConfig, std::optional<SizeType>, bool, bool>(),
py::arg_v("kv_cache_config", tbk::KvCacheConfig{}, "KvCacheConfig()"),
py::arg("max_num_sequences") = py::none(), py::arg("enable_trt_overlap") = true,
py::arg("use_context_fmha_for_generation") = false)
.def_readwrite("kv_cache_config", &tb::TrtGptModelOptionalParams::kvCacheConfig)
.def_readwrite("max_num_sequences", &tb::TrtGptModelOptionalParams::maxNumSequences)
.def_readwrite("enable_trt_overlap", &tb::TrtGptModelOptionalParams::enableTrtOverlap);
.def_readwrite("enable_trt_overlap", &tb::TrtGptModelOptionalParams::enableTrtOverlap)
.def_readwrite("use_context_fmha_for_generation", &tb::TrtGptModelOptionalParams::useContextFMHAForGeneration);
py::class_<tpb::GptManager>(m, "GptManager")
.def(py::init<std::filesystem::path const&, tb::TrtGptModelType, int32_t, tb::batch_scheduler::SchedulerPolicy,
tpb::GetInferenceRequestsCallback, tpb::SendResponseCallback, tb::PollStopSignalCallback,
tb::ReturnBatchManagerStatsCallback, const tb::TrtGptModelOptionalParams&, std::optional<uint64_t>>(),
py::arg("trt_engine_path"), py::arg("model_type"), py::arg("max_beam_width"), py::arg("scheduler_policy"),
py::arg("get_inference_requests_cb"), py::arg("send_response_cb"), py::arg("poll_stop_signal_cb") = nullptr,
py::arg("return_batch_manager_stats_cb") = nullptr,
py::arg("optional_params") = tb::TrtGptModelOptionalParams(), py::arg("terminate_req_id") = std::nullopt)
.def("shutdown", &tpb::GptManager::exit)
.def("__enter__", &tpb::GptManager::enter)
.def("__exit__", &tpb::GptManager::exit);
tpb::GptManager::initBindings(m);
}
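
After this refactor, PYBIND11_MODULE mostly delegates to per-class initBindings hooks instead of declaring every py::class_ inline, which keeps each binding next to the wrapper it describes. A stripped-down illustration of the pattern is below; Foo and Bar are placeholder types, not TensorRT-LLM classes.

#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

// Each wrapper class owns its own binding code via a static hook...
struct Foo
{
    std::string name;

    static void initBindings(py::module_& m)
    {
        py::class_<Foo>(m, "Foo")
            .def(py::init<>())
            .def_readwrite("name", &Foo::name);
    }
};

struct Bar
{
    int value = 0;

    static void initBindings(py::module_& m)
    {
        py::class_<Bar>(m, "Bar")
            .def(py::init<>())
            .def_readwrite("value", &Bar::value);
    }
};

// ...so the module entry point stays a short list of registrations,
// mirroring the tpr::*::initBindings(m) and tpb::*::initBindings(m) calls above.
PYBIND11_MODULE(bindings_example, m)
{
    Foo::initBindings(m);
    Bar::initBindings(m);
}

This keeps bindings.cpp from growing into one monolithic file as more classes are exposed, at the cost of each wrapper header declaring a static initBindings entry point.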

View File

@ -19,6 +19,11 @@
#include "tensorrt_llm/runtime/generationInput.h"
#include "tensorrt_llm/runtime/torchView.h"
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
namespace tr = tensorrt_llm::runtime;
using namespace tensorrt_llm::pybind::runtime;
@ -36,6 +41,17 @@ std::shared_ptr<tr::PromptTuningParams> PromptTuningParams::toTrtLlm() const
return ptt;
}
void PromptTuningParams::initBindings(pybind11::module_& m)
{
py::class_<PromptTuningParams>(m, "PromptTuningParams")
.def(py::init<PromptTuningParams::TensorPtr, PromptTuningParams::TensorPtr, PromptTuningParams::TensorPtr>(),
py::arg("embedding_table") = py::none(), py::arg("tasks") = py::none(), py::arg("vocab_size") = py::none())
.def_readwrite("embedding_table", &PromptTuningParams::embeddingTable)
.def_readwrite("tasks", &PromptTuningParams::tasks)
.def_readwrite("vocab_size", &PromptTuningParams::vocabSize)
.def_readwrite("prompt_tuning_enabled", &PromptTuningParams::promptTuningEnabled);
}
std::shared_ptr<tr::GenerationInput> GenerationInput::toTrtLlm() const
{
auto input = std::make_shared<tr::GenerationInput>(
@ -52,3 +68,20 @@ std::shared_ptr<tr::GenerationInput> GenerationInput::toTrtLlm() const
return input;
}
void GenerationInput::initBindings(pybind11::module_& m)
{
py::class_<GenerationInput>(m, "GenerationInput")
.def(py::init<SizeType, SizeType, GenerationInput::TensorPtr, GenerationInput::TensorPtr, bool>(),
py::arg("end_id"), py::arg("pad_id"), py::arg("ids"), py::arg("lengths"), py::arg("packed") = false)
.def_readwrite("end_id", &GenerationInput::endId)
.def_readwrite("pad_id", &GenerationInput::padId)
.def_readwrite("ids", &GenerationInput::ids)
.def_readwrite("lengths", &GenerationInput::lengths)
.def_readwrite("packed", &GenerationInput::packed)
.def_readwrite("embedding_bias", &GenerationInput::embeddingBias)
.def_readwrite("bad_words_list", &GenerationInput::badWordsList)
.def_readwrite("stop_words_list", &GenerationInput::stopWordsList)
.def_readwrite("max_new_tokens", &GenerationInput::maxNewTokens)
.def_readwrite("prompt_tuning_params", &GenerationInput::promptTuningParams);
}

View File

@ -17,15 +17,14 @@
#pragma once
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/generationInput.h"
#include <ATen/ATen.h>
#include <ATen/ops/tensor.h>
#include <memory>
#include <optional>
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::runtime
{
@ -46,6 +45,7 @@ public:
}
[[nodiscard]] std::shared_ptr<tensorrt_llm::runtime::PromptTuningParams> toTrtLlm() const;
static void initBindings(pybind11::module_& m);
};
class GenerationInput
@ -62,5 +62,6 @@ public:
}
[[nodiscard]] std::shared_ptr<tensorrt_llm::runtime::GenerationInput> toTrtLlm() const;
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::pybind::runtime

View File

@ -19,6 +19,11 @@
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchView.h"
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
namespace tr = tensorrt_llm::runtime;
using namespace tensorrt_llm::pybind::runtime;
@ -35,6 +40,10 @@ std::shared_ptr<tr::GenerationOutput> GenerationOutput::toTrtLlm() const
{
output->contextLogits = tr::TorchView::of(contextLogits.value());
}
if (generationLogits)
{
output->generationLogits = tr::TorchView::of(generationLogits.value());
}
if (onTokenGenerated)
{
@ -44,3 +53,15 @@ std::shared_ptr<tr::GenerationOutput> GenerationOutput::toTrtLlm() const
}
return output;
}
void GenerationOutput::initBindings(py::module_& m)
{
py::class_<GenerationOutput>(m, "GenerationOutput")
.def(py::init<GenerationOutput::TensorPtr, GenerationOutput::TensorPtr>(), py::arg("ids"), py::arg("lengths"))
.def_readwrite("ids", &GenerationOutput::ids)
.def_readwrite("lengths", &GenerationOutput::lengths)
.def_readwrite("log_probs", &GenerationOutput::logProbs)
.def_readwrite("context_logits", &GenerationOutput::contextLogits)
.def_readwrite("generation_logits", &GenerationOutput::generationLogits)
.def_readwrite("on_token_generated", &GenerationOutput::onTokenGenerated);
}
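The `toTrtLlm()` conversion above wraps each optional torch tensor with `TorchView::of` before handing it to the C++ runtime, and the new `generation_logits` field follows the same path as `context_logits`. A minimal sketch of that conversion pattern, assuming the caller already holds a runtime-side `GenerationOutput`; the function name is illustrative, not part of this change:

```cpp
#include "tensorrt_llm/runtime/generationOutput.h"
#include "tensorrt_llm/runtime/torchView.h"

#include <ATen/ATen.h>
#include <optional>

namespace tr = tensorrt_llm::runtime;

// Wrap an optional torch tensor as an ITensor view (no copy) and attach it to the
// runtime-side GenerationOutput, mirroring the generationLogits branch above.
void attachGenerationLogits(std::optional<at::Tensor> const& generationLogits, tr::GenerationOutput& output)
{
    if (generationLogits)
    {
        output.generationLogits = tr::TorchView::of(generationLogits.value());
    }
}
```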

View File

@ -20,6 +20,7 @@
#include <ATen/ATen.h>
#include <optional>
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::runtime
{
@ -36,6 +37,7 @@ public:
}
[[nodiscard]] std::shared_ptr<tensorrt_llm::runtime::GenerationOutput> toTrtLlm() const;
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::pybind::runtime

View File

@ -100,7 +100,7 @@ typename tl::DynamicDecodeLayer<T>::ForwardParams prepareInputs(DecodingInput co
auto constexpr ite = 0; // no pipeline parallelism
typename tl::DynamicDecodeLayer<T>::ForwardParams forwardParams{input.step, ite, input.maxLength,
input.maxKvCacheLength, input.batchSize, tcc::toTllmTensor(*input.logits), tcc::toTllmTensor(*input.endIds)};
input.maxAttentionWindow, input.batchSize, tcc::toTllmTensor(*input.logits), tcc::toTllmTensor(*input.endIds)};
if (input.cacheIndirection)
{

View File

@ -114,7 +114,7 @@ GptDecoderBatch::GptDecoderBatch(
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength,
void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
@ -125,7 +125,7 @@ void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeTy
mActualBatchSize = maxBatchSize;
mGeneratedTokensPerStep.resize(maxBatchSize);
mMaxSequenceLength = maxSequenceLength;
mMaxKvCacheLength = maxKvCacheLength;
mMaxAttentionWindow = maxAttentionWindow;
mMaxTokensPerStep = maxTokensPerStep;
auto const maxBatchSizeShape = ITensor::makeShape({maxBatchSize});
@ -248,7 +248,7 @@ void GptDecoderBatch::newRequest(
TensorPtr endIdTensorPtr{ITensor::slice(constPointerCast(dJointInput.endIds), batchIdx, localBatchSize)};
kernels::invokeFill(*endIdTensorPtr, endId, *stream);
dInput = std::make_unique<DecodingInput>(
inputLength, mMaxKvCacheLength, localBatchSize, dJointInput.logits, endIdTensorPtr);
inputLength, mMaxAttentionWindow, localBatchSize, dJointInput.logits, endIdTensorPtr);
// Here, we need to add leading 1 dimension since decoderInput expects batchSize as leading dim
// and decoder_batch::Request doesn't have batch dimension

View File

@ -74,75 +74,190 @@ GptJsonConfig parseJson(InputType&& i)
auto constexpr ingoreComments = true;
auto json = nlohmann::json::parse(i, nullptr, allowExceptions, ingoreComments);
auto const& builderConfig = json.at("builder_config");
auto const name = builderConfig.at("name").template get<std::string>();
auto const precision = builderConfig.at("precision").template get<std::string>();
auto const tensorParallelism = builderConfig.at("tensor_parallel").template get<SizeType>();
auto const pipelineParallelism = parseJsonFieldOr(builderConfig, "pipeline_parallel", 1);
auto const numHeads = builderConfig.at("num_heads").template get<SizeType>() / tensorParallelism;
auto const hiddenSize = builderConfig.at("hidden_size").template get<SizeType>() / tensorParallelism;
auto const vocabSize = builderConfig.at("vocab_size").template get<SizeType>();
auto const numLayers = builderConfig.at("num_layers").template get<SizeType>();
auto dataType = nvinfer1::DataType::kFLOAT;
if (!precision.compare("float32"))
dataType = nvinfer1::DataType::kFLOAT;
else if (!precision.compare("float16"))
dataType = nvinfer1::DataType::kHALF;
else if (!precision.compare("bfloat16"))
dataType = nvinfer1::DataType::kBF16;
else
TLLM_CHECK_WITH_INFO(false, tc::fmtstr("Model data type '%s' not supported", precision.c_str()));
auto const quantMode = tc::QuantMode(parseJsonFieldOr(builderConfig, "quant_mode", tc::QuantMode::none().value()));
// TODO:
// Code crashes when numKvHeads <= 0. Clamping it to a minimum of 1 prevents that; make sure this is the best fix.
auto const numKvHeads = std::max(
parseJsonFieldOr(builderConfig, "num_kv_heads", numHeads * tensorParallelism) / tensorParallelism, 1);
auto const maxBatchSize = parseJsonFieldOr(builderConfig, "max_batch_size", 0);
auto const maxInputLen = parseJsonFieldOr(builderConfig, "max_input_len", 0);
auto const maxOutputLen = parseJsonFieldOr(builderConfig, "max_output_len", 0);
auto const maxDraftLen = parseJsonFieldOr(builderConfig, "max_draft_len", 0);
auto const maxNumTokens = parseJsonFieldOptional<SizeType>(builderConfig, "max_num_tokens");
auto const maxPromptEmbeddingTableSize
= parseJsonFieldOr<SizeType>(builderConfig, "max_prompt_embedding_table_size", 0);
auto const computeContextLogits = parseJsonFieldOr(builderConfig, "gather_all_token_logits", false);
auto const computeGenerationLogits = parseJsonFieldOr(builderConfig, "gather_all_token_logits", false);
auto const& pluginConfig = json.at("plugin_config");
auto const pagedKvCache = pluginConfig.at("paged_kv_cache");
auto const tokensPerBlock = pluginConfig.at("tokens_per_block");
auto const& gptAttentionPlugin = pluginConfig.at("gpt_attention_plugin");
auto const useGptAttentionPlugin = !gptAttentionPlugin.is_boolean() || gptAttentionPlugin.template get<bool>();
auto const removeInputPadding = pluginConfig.at("remove_input_padding").template get<bool>();
auto const useCustomAllReduce = pluginConfig.at("use_custom_all_reduce").template get<bool>();
auto modelConfig = GptModelConfig{vocabSize, numLayers, numHeads, hiddenSize, dataType};
modelConfig.useGptAttentionPlugin(useGptAttentionPlugin);
modelConfig.usePackedInput(removeInputPadding);
modelConfig.usePagedKvCache(pagedKvCache);
modelConfig.useCustomAllReduce(useCustomAllReduce);
modelConfig.setTokensPerBlock(tokensPerBlock);
modelConfig.setQuantMode(quantMode);
modelConfig.setNbKvHeads(numKvHeads);
modelConfig.computeContextLogits(computeContextLogits);
modelConfig.computeGenerationLogits(computeGenerationLogits);
modelConfig.setMaxBatchSize(maxBatchSize);
modelConfig.setMaxInputLen(maxInputLen);
modelConfig.setMaxOutputLen(maxOutputLen);
modelConfig.setMaxNumTokens(maxNumTokens);
modelConfig.setMaxDraftLen(maxDraftLen);
modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);
if (name == std::string("chatglm_6b") || name == std::string("glm_10b"))
auto engine_version = parseJsonFieldOr(json, "version", std::string("none"));
if (engine_version == std::string("none"))
{
modelConfig.setModelVariant(GptModelConfig::ModelVariant::kGlm);
// kGlm is only for ChatGLM-6B and GLM-10B
}
auto const& builderConfig = json.at("builder_config");
auto const name = builderConfig.at("name").template get<std::string>();
auto const precision = builderConfig.at("precision").template get<std::string>();
auto const tensorParallelism = builderConfig.at("tensor_parallel").template get<SizeType>();
auto const pipelineParallelism = parseJsonFieldOr(builderConfig, "pipeline_parallel", 1);
auto const numHeads = builderConfig.at("num_heads").template get<SizeType>() / tensorParallelism;
auto const hiddenSize = builderConfig.at("hidden_size").template get<SizeType>() / tensorParallelism;
auto const vocabSize = builderConfig.at("vocab_size").template get<SizeType>();
auto const numLayers = builderConfig.at("num_layers").template get<SizeType>();
return GptJsonConfig{name, precision, tensorParallelism, pipelineParallelism, modelConfig};
auto dataType = nvinfer1::DataType::kFLOAT;
if (!precision.compare("float32"))
dataType = nvinfer1::DataType::kFLOAT;
else if (!precision.compare("float16"))
dataType = nvinfer1::DataType::kHALF;
else if (!precision.compare("bfloat16"))
dataType = nvinfer1::DataType::kBF16;
else
TLLM_CHECK_WITH_INFO(false, tc::fmtstr("Model data type '%s' not supported", precision.c_str()));
auto const quantMode
= tc::QuantMode(parseJsonFieldOr(builderConfig, "quant_mode", tc::QuantMode::none().value()));
// TODO:
// Code crashes when numKvHeads <= 0. Clamping it to a minimum of 1 prevents that; make sure this is the best fix.
auto const numKvHeads = std::max(
parseJsonFieldOr(builderConfig, "num_kv_heads", numHeads * tensorParallelism) / tensorParallelism, 1);
auto const maxBatchSize = parseJsonFieldOr(builderConfig, "max_batch_size", 0);
auto const maxBeamWidth = parseJsonFieldOr(builderConfig, "max_beam_width", 0);
auto const maxInputLen = parseJsonFieldOr(builderConfig, "max_input_len", 0);
auto const maxOutputLen = parseJsonFieldOr(builderConfig, "max_output_len", 0);
auto const maxDraftLen = parseJsonFieldOr(builderConfig, "max_draft_len", 0);
auto const maxNumTokens = parseJsonFieldOptional<SizeType>(builderConfig, "max_num_tokens");
auto const maxPromptEmbeddingTableSize
= parseJsonFieldOr<SizeType>(builderConfig, "max_prompt_embedding_table_size", 0);
auto const computeContextLogits = parseJsonFieldOr(builderConfig, "gather_all_token_logits", false);
auto const computeGenerationLogits = parseJsonFieldOr(builderConfig, "gather_all_token_logits", false);
auto const& pluginConfig = json.at("plugin_config");
auto const pagedKvCache = pluginConfig.at("paged_kv_cache");
auto const tokensPerBlock = pluginConfig.at("tokens_per_block");
auto const& gptAttentionPlugin = pluginConfig.at("gpt_attention_plugin");
auto const useGptAttentionPlugin = !gptAttentionPlugin.is_boolean() || gptAttentionPlugin.template get<bool>();
auto const removeInputPadding = pluginConfig.at("remove_input_padding").template get<bool>();
auto const useCustomAllReduce = pluginConfig.at("use_custom_all_reduce").template get<bool>();
auto modelConfig = GptModelConfig{vocabSize, numLayers, numHeads, hiddenSize, dataType};
modelConfig.useGptAttentionPlugin(useGptAttentionPlugin);
modelConfig.usePackedInput(removeInputPadding);
modelConfig.usePagedKvCache(pagedKvCache);
modelConfig.useCustomAllReduce(useCustomAllReduce);
modelConfig.setTokensPerBlock(tokensPerBlock);
modelConfig.setQuantMode(quantMode);
modelConfig.setNbKvHeads(numKvHeads);
modelConfig.computeContextLogits(computeContextLogits);
modelConfig.computeGenerationLogits(computeGenerationLogits);
modelConfig.setMaxBatchSize(maxBatchSize);
modelConfig.setMaxBeamWidth(maxBeamWidth);
modelConfig.setMaxInputLen(maxInputLen);
modelConfig.setMaxOutputLen(maxOutputLen);
modelConfig.setMaxNumTokens(maxNumTokens);
modelConfig.setMaxDraftLen(maxDraftLen);
modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);
if (name == std::string("chatglm_6b") || name == std::string("glm_10b"))
{
modelConfig.setModelVariant(GptModelConfig::ModelVariant::kGlm);
// kGlm is only for ChatGLM-6B and GLM-10B
}
return GptJsonConfig{name, engine_version, precision, tensorParallelism, pipelineParallelism, modelConfig};
}
else
{
auto const& pretrainedConfig = json.at("pretrained_config");
auto const& buildConfig = json.at("build_config");
auto const architecture = pretrainedConfig.at("architecture").template get<std::string>();
auto const name = architecture;
auto const dtype = pretrainedConfig.at("dtype").template get<std::string>();
auto const& mapping = pretrainedConfig.at("mapping");
auto const tpSize = mapping.at("tp_size").template get<SizeType>();
auto const ppSize = parseJsonFieldOr(mapping, "pp_size", 1);
auto const numAttentionHeads = pretrainedConfig.at("num_attention_heads").template get<SizeType>() / tpSize;
auto const hiddenSize = pretrainedConfig.at("hidden_size").template get<SizeType>() / tpSize;
auto const vocabSize = pretrainedConfig.at("vocab_size").template get<SizeType>();
auto const numHiddenLayers = pretrainedConfig.at("num_hidden_layers").template get<SizeType>();
auto dataType = nvinfer1::DataType::kFLOAT;
if (!dtype.compare("float32"))
dataType = nvinfer1::DataType::kFLOAT;
else if (!dtype.compare("float16"))
dataType = nvinfer1::DataType::kHALF;
else if (!dtype.compare("bfloat16"))
dataType = nvinfer1::DataType::kBF16;
else
TLLM_CHECK_WITH_INFO(false, tc::fmtstr("Model data type '%s' not supported", dtype.c_str()));
auto const& quantization = pretrainedConfig.at("quantization");
auto useSmoothQuant = parseJsonFieldOr(quantization, "use_smooth_quant", false);
auto perChannel = parseJsonFieldOr(quantization, "per_channel", false);
auto perToken = parseJsonFieldOr(quantization, "per_token", false);
// TODO: Unused parameters
// auto perGroup = parseJsonFieldOr(quantization, "per_group", false);
// auto groupSize = parseJsonFieldOr(quantization, "group_size", 128);
auto int8KvCache = parseJsonFieldOr(quantization, "int8_kv_cache", false);
auto enableFp8 = parseJsonFieldOr(quantization, "enable_fp8", false);
auto fp8KvCache = parseJsonFieldOr(quantization, "fp8_kv_cache", false);
auto useWeightOnly = parseJsonFieldOr(quantization, "use_weight_only", false);
auto weightOnlyPrecision = parseJsonFieldOr(quantization, "weight_only_precision", std::string("int8"));
bool quantizeWeights = false;
bool quantizeActivations = false;
if (useSmoothQuant)
{
quantizeWeights = true;
quantizeActivations = true;
}
else if (useWeightOnly)
{
quantizeWeights = true;
perToken = false;
perChannel = false;
}
bool useInt4Weights = (weightOnlyPrecision == std::string("int4"));
auto const quantMode = tc::QuantMode::fromDescription(quantizeWeights, quantizeActivations, perToken,
perChannel, useInt4Weights, int8KvCache, fp8KvCache, enableFp8);
// TODO:
// Code crashes when numKvHeads <= 0. Clamping it to a minimum of 1 prevents that; make sure this is the best fix.
auto const numKVHeads = pretrainedConfig.at("num_key_value_heads").template get<SizeType>();
auto const numKeyValueHeads = std::max(numKVHeads / tpSize, 1);
auto const maxBatchSize = parseJsonFieldOr(buildConfig, "max_batch_size", 0);
auto const maxBeamWidth = parseJsonFieldOr(buildConfig, "max_beam_width", 0);
auto const maxInputLen = parseJsonFieldOr(buildConfig, "max_input_len", 0);
auto const maxOutputLen = parseJsonFieldOr(buildConfig, "max_output_len", 0);
auto const maxDraftLen = parseJsonFieldOr(buildConfig, "max_draft_len", 0);
auto const maxNumTokens = parseJsonFieldOptional<SizeType>(buildConfig, "max_num_tokens");
auto const maxPromptEmbeddingTableSize
= parseJsonFieldOr<SizeType>(buildConfig, "max_prompt_embedding_table_size", 0);
auto const computeContextLogits = parseJsonFieldOr(buildConfig, "gather_all_token_logits", false);
auto const computeGenerationLogits = parseJsonFieldOr(buildConfig, "gather_all_token_logits", false);
auto const& pluginConfig = buildConfig.at("plugin_config");
auto const pagedKvCache = pluginConfig.at("paged_kv_cache");
auto const tokensPerBlock = pluginConfig.at("tokens_per_block");
auto const& gptAttentionPlugin = pluginConfig.at("gpt_attention_plugin");
auto const useGptAttentionPlugin = !gptAttentionPlugin.is_boolean() || gptAttentionPlugin.template get<bool>();
auto const removeInputPadding = pluginConfig.at("remove_input_padding").template get<bool>();
auto const useCustomAllReduce = pluginConfig.at("use_custom_all_reduce").template get<bool>();
auto modelConfig = GptModelConfig{vocabSize, numHiddenLayers, numAttentionHeads, hiddenSize, dataType};
modelConfig.useGptAttentionPlugin(useGptAttentionPlugin);
modelConfig.usePackedInput(removeInputPadding);
modelConfig.usePagedKvCache(pagedKvCache);
modelConfig.useCustomAllReduce(useCustomAllReduce);
modelConfig.setTokensPerBlock(tokensPerBlock);
modelConfig.setQuantMode(quantMode);
modelConfig.setNbKvHeads(numKeyValueHeads);
modelConfig.computeContextLogits(computeContextLogits);
modelConfig.computeGenerationLogits(computeGenerationLogits);
modelConfig.setMaxBatchSize(maxBatchSize);
modelConfig.setMaxBeamWidth(maxBeamWidth);
modelConfig.setMaxInputLen(maxInputLen);
modelConfig.setMaxOutputLen(maxOutputLen);
modelConfig.setMaxNumTokens(maxNumTokens);
modelConfig.setMaxDraftLen(maxDraftLen);
modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);
// TODO: Verify the architecture field in ChatGLM models
if (name == std::string("ChatGLMModel") || name == std::string("GLMModel"))
{
modelConfig.setModelVariant(GptModelConfig::ModelVariant::kGlm);
// kGlm is only for ChatGLM-6B and GLM-10B
}
return GptJsonConfig{name, engine_version, dtype, tpSize, ppSize, modelConfig};
}
}
} // namespace
@ -153,8 +268,15 @@ std::string GptJsonConfig::engineFilename(WorldConfig const& worldConfig, std::s
TLLM_CHECK_WITH_INFO(
getPipelineParallelism() == worldConfig.getPipelineParallelism(), "pipeline parallelism mismatch");
auto pp = worldConfig.isPipelineParallel() ? "_pp" + std::to_string(worldConfig.getPipelineParallelism()) : "";
return model + "_" + getPrecision() + "_tp" + std::to_string(worldConfig.getTensorParallelism()) + pp + "_rank"
+ std::to_string(worldConfig.getRank()) + ".engine";
if (getVersion() == std::string("none"))
{
return model + "_" + getPrecision() + "_tp" + std::to_string(worldConfig.getTensorParallelism()) + pp + "_rank"
+ std::to_string(worldConfig.getRank()) + ".engine";
}
else
{
return "rank" + std::to_string(worldConfig.getRank()) + ".engine";
}
}
GptJsonConfig GptJsonConfig::parse(std::string const& json)
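`parseJson` now branches on a top-level `version` field: engines without it keep the legacy `builder_config`/`plugin_config` layout, while versioned engines are described by `pretrained_config` and `build_config`, and `engineFilename` collapses to plain `rank<N>.engine` for them. A minimal sketch of parsing a versioned config; the JSON values are illustrative and only cover the fields read above, not a real engine produced by this commit:

```cpp
#include "tensorrt_llm/runtime/gptJsonConfig.h"

#include <string>

using tensorrt_llm::runtime::GptJsonConfig;

int main()
{
    // Illustrative, minimal config in the new versioned layout.
    auto const json = std::string{R"({
      "version": "0.6.1",
      "pretrained_config": {
        "architecture": "LlamaForCausalLM",
        "dtype": "float16",
        "mapping": {"tp_size": 1, "pp_size": 1},
        "num_attention_heads": 32,
        "num_key_value_heads": 32,
        "hidden_size": 4096,
        "vocab_size": 32000,
        "num_hidden_layers": 32,
        "quantization": {"use_weight_only": false}
      },
      "build_config": {
        "max_batch_size": 8,
        "max_beam_width": 1,
        "max_input_len": 1024,
        "max_output_len": 512,
        "plugin_config": {
          "gpt_attention_plugin": "float16",
          "remove_input_padding": true,
          "paged_kv_cache": true,
          "tokens_per_block": 128,
          "use_custom_all_reduce": false
        }
      }
    })"};

    auto const config = GptJsonConfig::parse(json);
    // Versioned engines report the version string; legacy engines report "none".
    return config.getVersion() == "none" ? 1 : 0;
}
```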

View File

@ -117,7 +117,7 @@ void GptSession::createBuffers(SizeType numMicroBatches)
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength,
void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow,
SizeType maxSequenceLength, nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
@ -135,13 +135,13 @@ void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType
mDecoders.emplace_back(std::make_shared<StatefulGptDecoder>(vocabSize, vocabSizePadded, stream));
constexpr SizeType maxTokensPerStep = 1;
mDecoders.back()->setup(
batchSize, beamWidth, maxKvCacheLength, maxSequenceLength, maxTokensPerStep, logitsType);
batchSize, beamWidth, maxAttentionWindow, maxSequenceLength, maxTokensPerStep, logitsType);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptSession::createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength,
void GptSession::createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow,
SizeType maxSequenceLength, KvCacheConfig const& config)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
@ -169,11 +169,11 @@ void GptSession::createKvCacheManager(SizeType batchSize, SizeType beamWidth, Si
= bmkv::KVCacheManager::getMaxNumTokens(config, kvDtype, mModelConfig, mWorldConfig, getBufferManager());
TLLM_LOG_INFO("Using %d tokens in paged KV cache.", maxNumTokens);
auto const maxNumBlocks = tc::ceilDiv(maxNumTokens, tokensPerBlock);
auto const maxBlocksPerSeq = tc::ceilDiv(maxSequenceLength, tokensPerBlock);
auto const maxBlocksPerSeq = tc::ceilDiv(std::min(maxSequenceLength, maxAttentionWindow), tokensPerBlock);
mKvCacheManager
= std::make_shared<bmkv::KVCacheManager>(localNbLayers, nbHeads, nbKvHeads, hiddenSize, tokensPerBlock,
maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, maxKvCacheLength, kvDtype, mRuntime->getStreamPtr());
maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, maxAttentionWindow, kvDtype, mRuntime->getStreamPtr());
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
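With this change the paged KV cache is sized by the attention window rather than the full sequence length, so a sliding window directly shrinks the per-sequence block count. A small worked example of the `maxBlocksPerSeq` computation above, with illustrative sizes:

```cpp
#include <algorithm>
#include <cstdio>

int main()
{
    // Illustrative values only.
    int const maxSequenceLength = 4096;
    int const maxAttentionWindow = 1024;
    int const tokensPerBlock = 128;

    // ceilDiv(min(maxSequenceLength, maxAttentionWindow), tokensPerBlock)
    int const maxBlocksPerSeq
        = (std::min(maxSequenceLength, maxAttentionWindow) + tokensPerBlock - 1) / tokensPerBlock;

    // Prints 8; sizing by the full sequence length would reserve 32 blocks per sequence.
    std::printf("%d\n", maxBlocksPerSeq);
    return 0;
}
```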
@ -234,8 +234,8 @@ void GptSession::setup(Config const& sessionConfig)
auto const maxBatchSize = sessionConfig.maxBatchSize;
auto const maxBeamWidth = sessionConfig.maxBeamWidth;
auto const maxSequenceLength = sessionConfig.maxSequenceLength;
auto const maxKvCacheLength = sessionConfig.kvCacheConfig.maxKvCacheLength.has_value()
? std::min(sessionConfig.kvCacheConfig.maxKvCacheLength.value(), maxSequenceLength)
auto const maxAttentionWindow = sessionConfig.kvCacheConfig.maxAttentionWindow.has_value()
? std::min(sessionConfig.kvCacheConfig.maxAttentionWindow.value(), maxSequenceLength)
: maxSequenceLength;
mMicroBatchConfig = MicroBatchConfig(maxBatchSize, mWorldConfig.getPipelineParallelism(),
@ -249,18 +249,18 @@ void GptSession::setup(Config const& sessionConfig)
// gptDecoderBatch does not resize buffers, but allows smaller batchSize and beamWidth.
// TODO refactor batch manager to remove dependency on maxSequenceLength.
mDecoderMaxSequenceLength = maxSequenceLength;
mDecoderMaxKvCacheLength = maxKvCacheLength;
mDecoderMaxAttentionWindow = maxAttentionWindow;
if (mModelConfig.usePagedKvCache())
{
createKvCacheManager(
maxBatchSize, maxBeamWidth, maxKvCacheLength, maxSequenceLength, sessionConfig.kvCacheConfig);
maxBatchSize, maxBeamWidth, maxAttentionWindow, maxSequenceLength, sessionConfig.kvCacheConfig);
}
if (mWorldConfig.isLastPipelineParallelRank())
{
auto const logitsType = mRuntime->getEngine().getTensorDataType("logits");
createDecoders(mMicroBatchConfig.genBatchSize, maxBeamWidth, maxKvCacheLength, maxSequenceLength, logitsType,
createDecoders(mMicroBatchConfig.genBatchSize, maxBeamWidth, maxAttentionWindow, maxSequenceLength, logitsType,
sessionConfig.decoderPerRequest, mMicroBatchConfig.numGenBatches);
}
@ -280,7 +280,7 @@ void GptSession::setup(Config const& sessionConfig)
{
// we don't know maxInputLength yet and ignore it for pre-allocation
buffers->generationConfig = RuntimeBuffers::GenerationConfig{
mMicroBatchConfig.genBatchSize, maxBeamWidth, 0, maxKvCacheLength, maxSequenceLength};
mMicroBatchConfig.genBatchSize, maxBeamWidth, 0, maxAttentionWindow, maxSequenceLength};
buffers->reshape(mModelConfig, mWorldConfig);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
@ -295,9 +295,9 @@ void GptSession::kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId,
TLLM_CHECK(contextLengthsHost);
auto const contextLengthsPtr = bufferCast<SizeType const>(*contextLengthsHost);
auto const contextLengthsSize = static_cast<SizeType>(contextLengthsHost->getSize());
for (SizeType batchIdx = firstBatchIdx; batchIdx < firstBatchIdx + contextLengthsSize; ++batchIdx)
for (SizeType batchIdx = 0; batchIdx < contextLengthsSize; ++batchIdx)
{
mKvCacheManager->addSequence(batchIdx, contextLengthsPtr[batchIdx], beamWidth);
mKvCacheManager->addSequence(firstBatchIdx + batchIdx, contextLengthsPtr[batchIdx], beamWidth);
}
}
}
@ -621,7 +621,7 @@ void GptSession::generateBatched(std::vector<GenerationOutput>& microBatchesOutp
auto const& microBatchInputs = microBatchesInputs.at(microBatchId);
auto& buffers = *mBuffers.at(microBatchId);
buffers.initFromInput(*microBatchInputs.ids, microBatchInputs.lengths, microBatchInputs.packed, beamWidth,
mDecoderMaxKvCacheLength, mDecoderMaxSequenceLength, manager);
mDecoderMaxAttentionWindow, mDecoderMaxSequenceLength, manager);
buffers.reshape(mModelConfig, mWorldConfig);
buffers.reset(manager);
}
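Callers opt into the sliding window through `KvCacheConfig::maxAttentionWindow`; `setup()` clamps it to `maxSequenceLength` and then propagates it to the KV cache manager, the decoders, and the runtime buffers. A minimal sketch of a session config, assuming `GptSession::Config` is constructed from batch size, beam width, and sequence length as its use in this file suggests (values are illustrative):

```cpp
#include "tensorrt_llm/runtime/gptSession.h"

using tensorrt_llm::runtime::GptSession;

GptSession::Config makeSessionConfig()
{
    // Limits must match the built engine; values here are illustrative.
    GptSession::Config sessionConfig{/*maxBatchSize=*/8, /*maxBeamWidth=*/1, /*maxSequenceLength=*/4096};

    // Restrict the attention window; setup() uses min(maxAttentionWindow, maxSequenceLength).
    sessionConfig.kvCacheConfig.maxAttentionWindow = 1024;
    sessionConfig.decoderPerRequest = false;
    return sessionConfig;
}
```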

View File

@ -29,8 +29,8 @@ using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITensor const& inputIds,
ITensor const& inputLengthsHost, bool const inputPacked, SizeType const beamWidth, SizeType const maxKvCacheLength,
SizeType const maxSequenceLength)
ITensor const& inputLengthsHost, bool const inputPacked, SizeType const beamWidth,
SizeType const maxAttentionWindow, SizeType const maxSequenceLength)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto const batchSize = static_cast<SizeType>(inputLengthsHost.getSize());
@ -58,7 +58,8 @@ RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITe
"generated.");
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
return GenerationConfig{batchSize, beamWidth, maxInputLength, maxKvCacheLength, maxSequenceLength, inputLengthSum};
return GenerationConfig{
batchSize, beamWidth, maxInputLength, maxAttentionWindow, maxSequenceLength, inputLengthSum};
}
void RuntimeBuffers::clear()
@ -158,7 +159,7 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
if (modelConfig.useGptAttentionPlugin())
{
pastKeyValueLengths = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
maxKvCacheLengths
maxAttentionWindows
= utils::createBufferVector(runtime, localNbLayers, MemoryType::kCPU, nvinfer1::DataType::kINT32);
}
else
@ -185,7 +186,7 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
}
void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inputLengths, bool inputPacked,
SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength, BufferManager& manager)
SizeType beamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength, BufferManager& manager)
{
contextLengthsDevice = inputLengths;
contextLengthsHost->reshape(inputLengths->getShape());
@ -193,7 +194,7 @@ void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inp
manager.getStream().synchronize(); // wait for context lengths to be copied to host
generationConfig = RuntimeBuffers::GenerationConfig::fromInput(
inputIds, *contextLengthsHost, inputPacked, beamWidth, maxKvCacheLength, maxSequenceLength);
inputIds, *contextLengthsHost, inputPacked, beamWidth, maxAttentionWindow, maxSequenceLength);
}
void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
@ -203,7 +204,7 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
auto const batchSize = generationConfig.batchSize;
auto const beamWidth = generationConfig.beamWidth;
auto const maxInputLength = generationConfig.maxInputLength;
auto const maxKvCacheLength = generationConfig.maxKvCacheLength;
auto const maxAttentionWindow = generationConfig.maxAttentionWindow;
if (worldConfig.isLastPipelineParallelRank() && !modelConfig.computeContextLogits())
{
@ -214,14 +215,14 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
lastTokenIds->reshape(ITensor::makeShape({batchSize}));
auto kvCacheReserve = ITensor::makeShape(
{batchSize, 2, modelConfig.getNbKvHeads(), maxKvCacheLength, modelConfig.getSizePerHead()});
{batchSize, 2, modelConfig.getNbKvHeads(), maxAttentionWindow, modelConfig.getSizePerHead()});
auto kvCacheShape
= ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(), maxInputLength, modelConfig.getSizePerHead()});
if (modelConfig.usePagedKvCache())
{
auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
auto const tokensPerBlock = modelConfig.getTokensPerBlock();
auto const maxBlocksPerSeq = (maxKvCacheLength + tokensPerBlock - 1) / tokensPerBlock;
auto const maxBlocksPerSeq = (maxAttentionWindow + tokensPerBlock - 1) / tokensPerBlock;
// reserve batchSize * beamWidth and resize to batchSize
auto cacheBlockPointersShape = ITensor::makeShape({localNbLayers, batchSize * beamWidth, 2, maxBlocksPerSeq});
@ -240,7 +241,7 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
{
pastKeyValueLengths->reshape(ITensor::makeShape({batchSize}));
requestTypes->reshape(ITensor::makeShape({batchSize}));
utils::reshapeBufferVector(maxKvCacheLengths, ITensor::makeShape({1}));
utils::reshapeBufferVector(maxAttentionWindows, ITensor::makeShape({1}));
}
else
{
@ -250,7 +251,7 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
utils::reshapeBufferVector(presentKeysVals, kvCacheShape);
}
auto const cacheIndirShape = ITensor::makeShape({batchSize, beamWidth, maxKvCacheLength});
auto const cacheIndirShape = ITensor::makeShape({batchSize, beamWidth, maxAttentionWindow});
cacheIndirectionDecoderInput->reshape(cacheIndirShape);
cacheIndirectionDecoderOutput->reshape(cacheIndirShape);
@ -334,7 +335,7 @@ std::vector<RuntimeBuffers> RuntimeBuffers::split(
if (modelConfig.useGptAttentionPlugin())
{
buffers.pastKeyValueLengths = ITensor::slice(pastKeyValueLengths, offset, batchSize);
buffers.maxKvCacheLengths = maxKvCacheLengths;
buffers.maxAttentionWindows = maxAttentionWindows;
buffers.requestTypes = ITensor::slice(requestTypes, offset, batchSize);
}
else
@ -548,10 +549,10 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
TLLM_CHECK(requestTypes->getSize() == static_cast<std::size_t>(batchSize));
std::fill_n(RequestTypesPtr, batchSize, 0);
// Set maxKvCacheLengths buffer to the same value currently.
// Set maxAttentionWindows buffer to the same value currently.
for (auto layer = 0; layer < localNbLayers; ++layer)
{
bufferCast<SizeType>(*maxKvCacheLengths[layer])[0] = generationConfig.maxKvCacheLength;
bufferCast<SizeType>(*maxAttentionWindows[layer])[0] = generationConfig.maxAttentionWindow;
}
auto const& inputShape = inputIds->getShape();
@ -823,7 +824,7 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
inputBuffers.insert_or_assign("host_past_key_value_lengths", pastKeyValueLengths);
inputBuffers.insert_or_assign("host_request_types", requestTypes);
inputBuffers.insert_or_assign("sequence_length", sequenceLengths);
utils::insertTensorVector(inputBuffers, "host_max_kv_cache_length_", maxKvCacheLengths, firstLayerId);
utils::insertTensorVector(inputBuffers, "host_max_attention_window_size_", maxAttentionWindows, firstLayerId);
if (modelConfig.usePackedInput())
{

View File

@ -49,11 +49,11 @@ public:
GenerationConfig() = default;
explicit GenerationConfig(SizeType batchSize, SizeType beamWidth, SizeType maxInputLength,
SizeType maxKvCacheLength, SizeType maxSeqLength, SizeType inputLengthSum = SizeType(0))
SizeType maxAttentionWindow, SizeType maxSeqLength, SizeType inputLengthSum = SizeType(0))
: batchSize{batchSize}
, beamWidth{beamWidth}
, maxInputLength{maxInputLength}
, maxKvCacheLength{maxKvCacheLength}
, maxAttentionWindow{maxAttentionWindow}
, maxSeqLength{maxSeqLength}
, inputLengthSum{inputLengthSum}
{
@ -62,12 +62,12 @@ public:
SizeType batchSize{};
SizeType beamWidth{};
SizeType maxInputLength{};
SizeType maxKvCacheLength{};
SizeType maxAttentionWindow{};
SizeType maxSeqLength{};
SizeType inputLengthSum{}; // Initialized only if inputPacked is set to true in fromInput.
static GenerationConfig fromInput(ITensor const& inputIds, ITensor const& inputLengths, bool inputPacked,
SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength);
SizeType beamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength);
};
public:
@ -89,10 +89,10 @@ public:
TensorPtr requestTypes; // with attention plugin. Host tensor
std::vector<TensorPtr> presentKeysVals;
std::vector<TensorPtr> presentKeysValsAlt; // without attention plugin
std::vector<TensorPtr> maxKvCacheLengths; // with attention plugin, host tensor
TensorPtr kvCacheBlockPointersHost; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
TensorPtr kvCacheBlockPointersDevice; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
std::vector<TensorPtr> presentKeysValsAlt; // without attention plugin
std::vector<TensorPtr> maxAttentionWindows; // with attention plugin, host tensor
TensorPtr kvCacheBlockPointersHost; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
TensorPtr kvCacheBlockPointersDevice; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
// References to tmp buffers
TensorPtr newTokens;
@ -126,7 +126,7 @@ public:
void create(TllmRuntime& runtime, GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
void initFromInput(ITensor const& inputIds, TensorPtr const& inputLengths, bool inputPacked, SizeType beamWidth,
SizeType maxKvCacheLength, SizeType maxSequenceLength, BufferManager& manager);
SizeType maxAttentionWindow, SizeType maxSequenceLength, BufferManager& manager);
//! \brief Reshape buffers based on current GenerationConfig
void reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig);

View File

@ -908,6 +908,7 @@ void tileTensor(ITensor& output, ITensor const& input, SizeType beamWidth, CudaS
case nvinfer1::DataType::kINT32: invokeTileTensor<SizeType>(output, input, beamWidth, stream); break;
case nvinfer1::DataType::kFLOAT: invokeTileTensor<float>(output, input, beamWidth, stream); break;
case nvinfer1::DataType::kHALF: invokeTileTensor<half>(output, input, beamWidth, stream); break;
case nvinfer1::DataType::kBF16: invokeTileTensor<__nv_bfloat16>(output, input, beamWidth, stream); break;
case nvinfer1::DataType::kINT8: invokeTileTensor<int8_t>(output, input, beamWidth, stream); break;
case nvinfer1::DataType::kFP8: invokeTileTensor<__nv_fp8_e4m3>(output, input, beamWidth, stream); break;
default: TLLM_CHECK_WITH_INFO(false, "data type not supported");

View File

@ -64,19 +64,19 @@ StatefulGptDecoder::StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabS
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void StatefulGptDecoder::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength,
void StatefulGptDecoder::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(maxTokensPerStep == 1);
mDecoder = IGptDecoder::create(dtype, mVocabSize, mVocabSizePadded, mStream);
reshapeBuffers(maxBatchSize, maxBeamWidth, maxKvCacheLength, maxSequenceLength);
reshapeBuffers(maxBatchSize, maxBeamWidth, maxAttentionWindow, maxSequenceLength);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void StatefulGptDecoder::reshapeBuffers(
SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength)
SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(batchSize > 0);
@ -84,7 +84,7 @@ void StatefulGptDecoder::reshapeBuffers(
TLLM_CHECK(maxSequenceLength > 0);
mMaxSequenceLength = maxSequenceLength;
mMaxKvCacheLength = maxKvCacheLength;
mMaxAttentionWindow = maxAttentionWindow;
auto const batchSizeShape = ITensor::makeShape({batchSize});
auto const batchSizeXbeamWidth = ITensor::makeShape({batchSize, beamWidth});
@ -137,7 +137,7 @@ void StatefulGptDecoder::newBatch(
auto const batchSize = inputLengthsShape.d[0];
auto const beamWidth = samplingConfig.beamWidth;
reshapeBuffers(batchSize, beamWidth, mMaxKvCacheLength, mMaxSequenceLength);
reshapeBuffers(batchSize, beamWidth, mMaxAttentionWindow, mMaxSequenceLength);
mDecoder->setup(samplingConfig, batchSize, mMaxSequenceLength);
// sanity checks, should always be true after reshape
@ -167,7 +167,7 @@ void StatefulGptDecoder::newBatch(
// inputs
auto& dInput = *mDecodingInput;
dInput.maxLength = maxInputLength;
dInput.maxKvCacheLength = mMaxKvCacheLength;
dInput.maxAttentionWindow = mMaxAttentionWindow;
dInput.batchSize = batchSize;
kernels::invokeFill(const_cast<ITensor&>(*dInput.endIds), endId, *stream);
dInput.embeddingBias = inputs.embeddingBias;

View File

@ -39,7 +39,7 @@ public:
StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream);
//! Setup the decoder before calling `forward()`
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength,
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength,
SizeType maxTokensPerStep, nvinfer1::DataType dtype) override;
//! @brief Initialize the decoder with new batch of inputs.
@ -98,7 +98,8 @@ public:
}
private:
void reshapeBuffers(SizeType batchSize, SizeType beamWidth, SizeType mMaxKvCacheLength, SizeType maxSequenceLength);
void reshapeBuffers(
SizeType batchSize, SizeType beamWidth, SizeType mMaxAttentionWindow, SizeType maxSequenceLength);
private:
std::size_t const mVocabSize;
@ -116,6 +117,6 @@ private:
SizeType mNbSteps;
SizeType mMaxSequenceLength{};
SizeType mMaxKvCacheLength{};
SizeType mMaxAttentionWindow{};
};
} // namespace tensorrt_llm::runtime
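The decoder interface carries the renamed parameter as well: `setup()` now takes `maxAttentionWindow`, which the decoder stores and later forwards through `DecodingInput`. A minimal sketch of setting up the stateful decoder; the include paths, the default-constructed `CudaStream`, and all sizes are assumptions for illustration:

```cpp
#include "tensorrt_llm/runtime/cudaStream.h"         // path assumed
#include "tensorrt_llm/runtime/statefulGptDecoder.h" // path assumed

#include <NvInferRuntime.h>
#include <memory>

namespace tr = tensorrt_llm::runtime;

void setupDecoder()
{
    auto stream = std::make_shared<tr::CudaStream>();

    // Illustrative sizes; vocabSizePadded normally matches the engine's padded vocabulary.
    tr::StatefulGptDecoder decoder(/*vocabSize=*/32000, /*vocabSizePadded=*/32000, stream);

    // maxTokensPerStep must be 1 for the stateful decoder, as checked in setup() above.
    decoder.setup(/*maxBatchSize=*/8, /*maxBeamWidth=*/1, /*maxAttentionWindow=*/1024,
        /*maxSequenceLength=*/4096, /*maxTokensPerStep=*/1, nvinfer1::DataType::kHALF);
}
```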

View File

@ -70,11 +70,6 @@ public:
return mTensor.is_cuda() ? MemoryType::kGPU : mTensor.is_pinned() ? MemoryType::kPINNED : MemoryType::kCPU;
}
void resize(std::size_t newSize) override
{
ITensor::resize(newSize);
}
void release() override
{
resize(0);
@ -87,9 +82,19 @@ public:
void reshape(Shape const& dims) override
{
TLLM_CHECK(volumeNonNegative(dims) <= getCapacity());
mTensor.resize_(TorchUtils::shape(dims));
try
{
mTensor.resize_(TorchUtils::shape(dims));
}
catch (c10::Error const& e)
{
TLLM_THROW("%s", e.what_without_backtrace());
}
mDims = dims;
if (auto const newSize = volumeNonNegative(dims); mCapacity < newSize)
{
mCapacity = newSize;
}
}
private:

View File

@ -21,10 +21,12 @@
#include "tensorrt_llm/runtime/tllmLogger.h"
#include "tensorrt_llm/runtime/utils/multiDeviceUtils.h"
#include <algorithm>
#include <csignal>
#include <cstdlib>
#include <mpi.h>
#include <mutex>
#include <numeric>
using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
@ -73,7 +75,7 @@ bool WorldConfig::validConfig(nvinfer1::ILogger& logger, SizeType tensorParallel
}
WorldConfig WorldConfig::mpi(nvinfer1::ILogger& logger, SizeType gpusPerNode, std::optional<SizeType> tensorParallelism,
std::optional<SizeType> pipelineParallelism)
std::optional<SizeType> pipelineParallelism, std::optional<std::vector<SizeType>> userSpecifiedDeviceIds)
{
initMpi(logger);
@ -85,14 +87,60 @@ WorldConfig WorldConfig::mpi(nvinfer1::ILogger& logger, SizeType gpusPerNode, st
auto pp = pipelineParallelism.value_or(1);
auto tp = tensorParallelism.value_or(mpiSize / pp);
TLLM_CHECK(mpiSize == tp * pp);
return WorldConfig{tp, pp, mpiRank, gpusPerNode};
// Pass the user-specified device list to the WorldConfig. Otherwise, create a default list of device ids.
std::vector<SizeType> deviceIds(mpiSize);
std::iota(deviceIds.begin(), deviceIds.end(), 0);
if (userSpecifiedDeviceIds)
{
TLLM_CHECK(static_cast<SizeType>(userSpecifiedDeviceIds.value().size()) == tp * pp);
deviceIds = userSpecifiedDeviceIds.value();
// For a user-provided device list, verify that:
// 1) its size does not exceed the number of CUDA-visible devices,
// 2) all deviceIds are within range,
// 3) all ids are unique,
// 4) the deviceIds are contiguous; only warn if they are not.
TLLM_CHECK((gpusPerNode >= static_cast<SizeType>(deviceIds.size()))
&& (gpusPerNode > *std::max_element(deviceIds.begin(), deviceIds.end()))
&& *std::min_element(deviceIds.begin(), deviceIds.end()) >= 0);
gpusPerNode = deviceIds.size();
// Sort first so that std::unique also catches non-adjacent duplicates.
std::sort(deviceIds.begin(), deviceIds.end());
auto it = std::unique(deviceIds.begin(), deviceIds.end());
TLLM_CHECK(std::distance(deviceIds.begin(), it) == gpusPerNode);
// If the deviceIds are not contiguous, throw a warning
bool isContiguous = true;
for (SizeType i = 1; i < static_cast<SizeType>(deviceIds.size()); ++i)
{
if (deviceIds[i] != deviceIds[i - 1] + 1)
{
isContiguous = false;
break;
}
}
if (!isContiguous)
{
logger.log(nvinfer1::ILogger::Severity::kWARNING, "The user specified device IDs are not contiguous!");
}
std::stringstream ss;
ss << "Using user-specificed devices: [";
for (auto& id : deviceIds)
{
ss << id << ",";
}
ss << "]";
logger.log(nvinfer1::ILogger::Severity::kINFO, ss.str().c_str());
}
return WorldConfig{tp, pp, mpiRank, gpusPerNode, deviceIds};
}
WorldConfig WorldConfig::mpi(
SizeType gpusPerNode, std::optional<SizeType> tensorParallelism, std::optional<SizeType> pipelineParallelism)
WorldConfig WorldConfig::mpi(SizeType gpusPerNode, std::optional<SizeType> tensorParallelism,
std::optional<SizeType> pipelineParallelism, std::optional<std::vector<SizeType>> userSpecifiedDeviceIds)
{
TllmLogger logger{};
return mpi(logger, gpusPerNode, tensorParallelism, pipelineParallelism);
return mpi(logger, gpusPerNode, tensorParallelism, pipelineParallelism, userSpecifiedDeviceIds);
}
std::vector<SizeType> WorldConfig::getPipelineParallelGroup() const
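The new overloads let callers pin ranks to specific GPUs: the list must hold exactly `tp * pp` unique ids in `[0, gpusPerNode)`, and a non-contiguous list only triggers a warning. A minimal sketch with an illustrative mapping of four ranks onto the upper half of an 8-GPU node:

```cpp
#include "tensorrt_llm/runtime/worldConfig.h"

#include <vector>

using tensorrt_llm::runtime::SizeType;
using tensorrt_llm::runtime::WorldConfig;

WorldConfig makeWorldConfig()
{
    // Illustrative: tp=2, pp=2 pinned to GPUs 4-7 of an 8-GPU node (run with mpirun -n 4).
    std::vector<SizeType> deviceIds{4, 5, 6, 7};
    return WorldConfig::mpi(/*gpusPerNode=*/8, /*tensorParallelism=*/2, /*pipelineParallelism=*/2, deviceIds);
}
```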

View File

@ -138,7 +138,7 @@ void FtDynamicDecode<T>::setup(size_t batch_size, size_t beam_width, th::optiona
template <typename T>
void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width, hidden_size)
int step, int max_input_length, int max_kv_cache_length, uint64_t ite, int local_batch_size, th::Tensor end_id,
int step, int max_input_length, int max_attention_window, uint64_t ite, int local_batch_size, th::Tensor end_id,
th::optional<th::Tensor> embedding_bias_opt, th::optional<th::Tensor> input_lengths_opt,
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
@ -158,7 +158,7 @@ void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width,
auto const& logits_converted = convert_tensor<float>(logits);
auto const& end_ids_converted = convert_tensor<int>(end_id);
typename tensorrt_llm::layers::DynamicDecodeLayer<T>::ForwardParams forwardParams{step, static_cast<int>(ite),
max_input_length, max_kv_cache_length, local_batch_size, logits_converted, end_ids_converted};
max_input_length, max_attention_window, local_batch_size, logits_converted, end_ids_converted};
safeUpdate<int>(src_cache_indirection_opt, forwardParams.src_cache_indirection);
safeUpdate<int>(sequence_limit_length_opt, forwardParams.sequence_limit_length);
@ -275,7 +275,7 @@ void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional
}
th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max_input_length,
int64_t max_kv_cache_length, int64_t ite, int64_t local_batch_size, th::Tensor end_id,
int64_t max_attention_window, int64_t ite, int64_t local_batch_size, th::Tensor end_id,
th::optional<th::Tensor> embedding_bias_opt,
th::optional<th::Tensor> input_lengths_opt, // length of input contexts.
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
@ -341,7 +341,7 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max
dynamic_decode_->forward(
// Inputs
logits, static_cast<int>(step), static_cast<int>(max_input_length), static_cast<int>(max_kv_cache_length),
logits, static_cast<int>(step), static_cast<int>(max_input_length), static_cast<int>(max_attention_window),
static_cast<uint32_t>(ite), static_cast<int>(local_batch_size), end_id, embedding_bias_opt, input_lengths_opt,
sequence_limit_length_opt, stop_words_list_opt, bad_words_list_opt, no_repeat_ngram_size_opt,
src_cache_indirection_opt,

View File

@ -39,7 +39,7 @@ public:
= 0;
virtual void forward(th::Tensor& logits, // (batch_size, beam_width, hidden_size)
int step, int max_input_length, int max_kv_cache_length, uint64_t ite, int local_batch_size, th::Tensor end_id,
int step, int max_input_length, int max_attention_window, uint64_t ite, int local_batch_size, th::Tensor end_id,
th::optional<th::Tensor> embedding_bias_opt, th::optional<th::Tensor> input_lengths_opt,
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
@ -77,7 +77,7 @@ public:
th::optional<th::Tensor> top_p_reset_ids_opt) override;
void forward(th::Tensor& logits, // (batch_size, beam_width, hidden_size)
int step, int max_input_length, int max_kv_cache_length, uint64_t ite, int local_batch_size, th::Tensor end_id,
int step, int max_input_length, int max_attention_window, uint64_t ite, int local_batch_size, th::Tensor end_id,
th::optional<th::Tensor> embedding_bias_opt, th::optional<th::Tensor> input_lengths_opt,
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
@ -121,7 +121,7 @@ public:
th::optional<th::Tensor> top_p_reset_ids_opt);
th::Tensor forward(th::Tensor logits, // (batch_size, beam_width, vocab_size)
int64_t step, int64_t max_input_length, int64_t max_kv_cache_length, int64_t ite, int64_t local_batch_size,
int64_t step, int64_t max_input_length, int64_t max_attention_window, int64_t ite, int64_t local_batch_size,
th::Tensor end_id, th::optional<th::Tensor> embedding_bias_opt,
th::optional<th::Tensor> input_lengths_opt, // length of input contexts.
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,

View File

@ -4,13 +4,33 @@ This document explains how to build and run the C++ tests, and the included [res
Windows users: Be sure to set DLL paths as specified in [Extra Steps for C++ Runtime Usage](../../windows/README.md#extra-steps-for-c-runtime-usage).
## Compile
## All-in-one script
The script [test_cpp.py](resources/scripts/test_cpp.py) can be executed to build TRT-LLM, build engines, generate expected outputs and run C++ tests all in one go.
To get an overview of the parameters call:
```bash
python3 cpp/tests/resources/scripts/test_cpp.py -h
```
It is possible to choose a single model for end-to-end tests or skip models that should not be tested.
An example call may look like this:
```bash
CPP_BUILD_DIR=cpp/build
MODEL_CACHE=/path/to/model_cache
python3 cpp/tests/resources/scripts/test_cpp.py -a "80-real;86-real" --build_dir ${CPP_BUILD_DIR} --trt_root /usr/local/tensorrt --model_cache ${MODEL_CACHE} --only_gptj
```
## Manual steps
### Compile
From the top-level directory call:
```bash
CPP_BUILD_DIR=cpp/build
python3 scripts/build_wheel.py -a "80-real;86-real" --build_dir ${CPP_BUILD_DIR}
python3 scripts/build_wheel.py -a "80-real;86-real" --build_dir ${CPP_BUILD_DIR} --trt_root /usr/local/tensorrt
pip install -r requirements-dev.txt --extra-index-url https://pypi.ngc.nvidia.com
pip install build/tensorrt_llm*.whl
cd $CPP_BUILD_DIR && make -j$(nproc) google-tests
@ -22,11 +42,11 @@ Single tests can be executed from `CPP_BUILD_DIR/tests`, e.g.
./$CPP_BUILD_DIR/tests/allocatorTest
```
## End-to-end tests
### End-to-end tests
`gptSessionTest`, `gptManagerTest` and `trtGptModelRealDecoderTest` require pre-built TensorRT engines, which are loaded in the tests. They also require data files which are stored in [cpp/tests/resources/data](resources/data).
### Build engines
#### Build engines
[Scripts](resources/scripts) are provided that download the GPT2 and GPT-J models from Huggingface and convert them to TensorRT engines.
The weights and built engines are stored under [cpp/tests/resources/models](resources/models).
@ -45,7 +65,7 @@ It is possible to build engines with tensor and pipeline parallelism for LLaMA u
PYTHONPATH=examples/llama python3 cpp/tests/resources/scripts/build_llama_engines.py --only_multi_gpu
```
### Generate expected output
#### Generate expected output
End-to-end tests read inputs and expected outputs from Numpy files located at [cpp/tests/resources/data](resources/data). The expected outputs can be generated using [scripts](resources/scripts) which employ the Python runtime to run the built engines:
@ -56,7 +76,7 @@ PYTHONPATH=examples:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_exp
PYTHONPATH=examples:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_chatglm_output.py
```
### Generate data with tensor and pipeline parallelism
#### Generate data with tensor and pipeline parallelism
It is possible to generate tensor and pipeline parallelism data for LLaMA using 4 GPUs. To generate results from the top-level directory:
@ -64,7 +84,7 @@ It is possible to generate tensor and pipeline parallelism data for LLaMA using
PYTHONPATH=examples mpirun -n 4 python3 cpp/tests/resources/scripts/generate_expected_llama_output.py --only_multi_gpu
```
### Run test
#### Run test
After building the engines and generating the expected output, execute the tests
@ -72,7 +92,7 @@ After building the engines and generating the expected output execute the tests
./$CPP_BUILD_DIR/tests/gptSessionTest
```
## Run all tests with ctest
### Run all tests with ctest
To run all tests and produce an XML report, call

View File

@ -292,8 +292,9 @@ float compare(void* _pa, void* _pb, int size, float scale)
#if defined(ENABLE_BF16)
if constexpr (std::is_same_v<T, __nv_bfloat16>)
{
// bfloat16 has fewer mantissa digits than float16, so the cumulative error will be larger.
diff_thres *= 2.f;
// bfloat16 has fewer mantissa digits than float16 (10 bits for fp16 but only 7 bits for bf16), so the cumulative
// error will be larger.
diff_thres *= 3.f;
}
else
#endif
@ -308,8 +309,7 @@ float compare(void* _pa, void* _pb, int size, float scale)
template <typename T1, typename T2>
void random_fill(std::vector<T1>& vec, T2 minv, T2 maxv)
{
std::random_device rd;
std::mt19937 gen(rd());
std::mt19937 gen(20231205);
std::uniform_real_distribution<float> dis(static_cast<float>(minv), static_cast<float>(maxv));
for (auto& v : vec)
{

View File

@ -98,7 +98,10 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1):
run_command(["git", "lfs", "pull", "--include", model_file_name],
cwd=hf_dir)
(hf_dir / "model.safetensors").unlink(missing_ok=True)
safetensor_file = hf_dir / "model.safetensors"
has_safetensor = safetensor_file.exists()
if has_safetensor:
safetensor_file.rename(str(safetensor_file) + ".bak")
assert (hf_dir / model_file_name).is_file()
@ -158,6 +161,9 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1):
engine_dir / 'fp16-plugin-packed-paged-gather' / tp_pp_dir,
tp_size, '--gather_all_token_logits', *ifb_args)
if has_safetensor:
_pl.Path(str(safetensor_file) + ".bak").rename(safetensor_file)
print("Done.")

View File

@ -32,7 +32,6 @@ def build_engine(weight_dir: _pl.Path, engine_dir: _pl.Path, *args):
'--dtype=float16',
'--logits_dtype=float16',
'--use_gemm_plugin=float16',
'--use_layernorm_plugin=float16',
'--max_batch_size=32',
'--max_input_len=40',
'--max_output_len=20',

View File

@ -55,7 +55,7 @@ def generate_output(engine: str,
str(output_dir / (output_name + '.csv')), '--max_output_len',
str(max_output_len), '--num_beams',
str(num_beams), '--output_logits_npy',
str(output_logits_npy)
str(output_logits_npy), '--use_py_session'
])
run.main(args)

View File

@ -52,7 +52,7 @@ def generate_output(engine: str,
str(output_dir / (output_name + '.npy')), '--output_csv',
str(output_dir / (output_name + '.csv')), '--max_output_len',
str(max_output_len), '--num_beams',
str(num_beams)
str(num_beams), '--use_py_session'
])
run.main(args)

View File

@ -52,7 +52,7 @@ def generate_output(engine: str,
str(output_dir / (output_name + '.npy')), '--output_csv',
str(output_dir / (output_name + '.csv')), '--max_output_len',
str(max_output_len), '--num_beams',
str(num_beams)
str(num_beams), '--use_py_session'
])
run.main(args)

View File

@ -68,7 +68,7 @@ def build_trt_llm(python_exe: str,
python_exe, "scripts/build_wheel.py", "--cuda_architectures",
cuda_architectures, "--build_dir",
str(build_dir), "--dist_dir",
str(dist_dir)
str(dist_dir), "--python_bindings"
]
if trt_root is not None:
build_wheel += ["--trt_root", str(trt_root)]

Some files were not shown because too many files have changed in this diff Show More