Update TensorRT-LLM (#1793)

Co-authored-by: DreamGenX <x@dreamgen.com>
Co-authored-by: Ace-RR <78812427+Ace-RR@users.noreply.github.com>
Co-authored-by: bprus <39293131+bprus@users.noreply.github.com>
Co-authored-by: janpetrov <janpetrov@icloud.com>
石晓伟 2024-06-18 18:18:23 +08:00 committed by GitHub
parent db4edea1e1
commit 2a115dae84
318 changed files with 8672 additions and 4814 deletions

View File

@ -232,7 +232,7 @@ ${HOME}/.local/bin/trtllm-build \
--output_dir ${LORA_ENGINE} \
--max_batch_size ${MAX_BATCH} \
--max_input_len $MAX_LEN \
--max_output_len $MAX_LEN \
--max_seq_len $((2*${MAX_LEN})) \
--gemm_plugin float16 \
--lora_plugin float16 \
--use_paged_context_fmha enable \

View File

@ -17,6 +17,7 @@
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/rawEngine.h"
#include "tensorrt_llm/runtime/tllmLogger.h"
#include "tensorrt_llm/runtime/tllmRuntime.h"
#include "tensorrt_llm/runtime/worldConfig.h"
@ -78,11 +79,10 @@ void benchmarkBert(std::string const& modelName, std::filesystem::path const& da
{
auto const worldConfig = WorldConfig::mpi();
auto const enginePath = dataPath / engineFilename(dataPath, worldConfig, modelName);
auto engineBlob = loadEngine(enginePath.string());
for (float gpuWeightsPercent : gpuWeightsPercents)
{
auto rt = std::make_shared<TllmRuntime>(engineBlob.data(), engineBlob.size(), gpuWeightsPercent, *logger);
auto rt = std::make_shared<TllmRuntime>(RawEngine(enginePath), logger.get(), gpuWeightsPercent);
rt->addContext(0);
for (auto inLen : inLens)
{
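The BERT benchmark now hands the engine path to the runtime via RawEngine instead of loading the serialized blob itself. A minimal sketch of the new construction path, assuming the headers included above; the helper signature is illustrative and the caller is assumed to own the logger:

#include "tensorrt_llm/runtime/rawEngine.h"
#include "tensorrt_llm/runtime/tllmLogger.h"
#include "tensorrt_llm/runtime/tllmRuntime.h"

#include <filesystem>
#include <memory>

// Hypothetical helper mirroring the updated benchmark code above.
std::shared_ptr<tensorrt_llm::runtime::TllmRuntime> makeRuntime(
    std::filesystem::path const& enginePath, tensorrt_llm::runtime::TllmLogger& logger, float gpuWeightsPercent)
{
    using namespace tensorrt_llm::runtime;
    // Path-based construction; the runtime reads the engine file itself,
    // so no separate loadEngine() blob is needed anymore.
    return std::make_shared<TllmRuntime>(RawEngine(enginePath), &logger, gpuWeightsPercent);
}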

View File

@ -150,6 +150,7 @@ struct BenchmarkParams
bool streaming{false};
bool enableExpDelays{false};
std::optional<float> requestRate{std::nullopt};
std::optional<SizeType32> maxBatchSize{std::nullopt};
int randomSeed = 430;
std::optional<int> maxAttentionWindow{std::nullopt};
@ -785,6 +786,10 @@ public:
executorConfig.setPeftCacheConfig(peftCacheConfig);
executorConfig.setBatchingType(
modelType == TrtGptModelType::V1 ? texec::BatchingType::kSTATIC : texec::BatchingType::kINFLIGHT);
if (benchmarkParams.maxBatchSize)
{
executorConfig.setMaxBatchSize(benchmarkParams.maxBatchSize.value());
}
mExecutor = std::make_unique<texec::Executor>(trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig);
@ -1339,6 +1344,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
optionalParams.kvCacheConfig.onboardBlocks = benchmarkParams.kvOnboardBlocks;
optionalParams.gpuWeightsPercent = benchmarkParams.gpuWeightsPercent;
optionalParams.maxBeamWidth = beamWidth;
optionalParams.maxBatchSize = benchmarkParams.maxBatchSize;
optionalParams.schedulerConfig = texec::SchedulerConfig{capacitySchedulerPolicy};
auto const jsonConfig = GptJsonConfig::parse(engineDir / "config.json");
@ -1628,6 +1634,7 @@ int main(int argc, char* argv[])
options.add_options()("request_rate",
"request rate in reqs/sec. Skipping this arg or negative value will trigger offline/0-delay.",
cxxopts::value<float>());
options.add_options()("max_batch_size", "The max runtime batch size when benchmarking", cxxopts::value<int>());
options.add_options()("enable_trt_overlap", "Overlap TRT context preparation and execution",
cxxopts::value<bool>()->default_value("false"));
options.add_options()("enable_exp_delays", "Enables exponential delay distr to mimic real world request arrival",
@ -1777,6 +1784,12 @@ int main(int argc, char* argv[])
benchmarkParams.requestRate = result["request_rate"].as<float>();
}
// Argument: request rate
if (result.count("max_batch_size"))
{
benchmarkParams.maxBatchSize = result["max_batch_size"].as<int>();
}
benchmarkParams.enableExpDelays = result["enable_exp_delays"].as<bool>();
// Argument: Enable batch stats output
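The new --max_batch_size benchmark flag caps the runtime batch size below the engine's build-time limit by forwarding it to the executor configuration, as the hunks above show. A minimal sketch of the corresponding API calls, with an illustrative value and a caller-supplied engine directory:

#include "tensorrt_llm/executor/executor.h"

#include <filesystem>
#include <memory>

namespace texec = tensorrt_llm::executor;

std::unique_ptr<texec::Executor> makeExecutor(std::filesystem::path const& engineDir)
{
    texec::ExecutorConfig executorConfig;
    // Cap the runtime batch size below the engine's build-time limit;
    // 64 is an illustrative value, not a recommendation.
    executorConfig.setMaxBatchSize(64);
    return std::make_unique<texec::Executor>(engineDir, texec::ModelType::kDECODER_ONLY, executorConfig);
}

In the benchmark itself the same cap is applied by passing, for example, --max_batch_size 64 on the command line.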

View File

@ -32,7 +32,7 @@ class BuildConfig:
max_batch_size: int
max_input_len: Optional[int] = None
num_kv_heads: Optional[int] = None
max_output_len: Optional[int] = None
max_seq_len: Optional[int] = None
max_beam_width: int = 1
# TRT builder_optimization_level from 0 to 5
builder_opt: Optional[int] = None
@ -89,7 +89,7 @@ class EncDecBuildConfig:
normalize_before: Optional[bool] = None
max_encoder_input_len: Optional[int] = None
max_decoder_input_len: Optional[int] = None
max_output_len: Optional[int] = None
max_seq_len: Optional[int] = None
builder_opt: Optional[int] = None
n_mels: Optional[int] = None
skip_cross_qkv: bool = False
@ -122,7 +122,7 @@ _allowed_configs = {
n_positions=1024,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"gpt_1.5b":
@ -138,7 +138,7 @@ _allowed_configs = {
n_positions=1024,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"gpt_175b":
@ -154,7 +154,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"gpt_350m_moe":
@ -170,7 +170,7 @@ _allowed_configs = {
n_positions=1024,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
moe_num_experts=8,
moe_top_k=1,
@ -188,7 +188,7 @@ _allowed_configs = {
n_positions=1024,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
quantization="int8_sq_per_tensor",
)),
@ -205,7 +205,7 @@ _allowed_configs = {
n_positions=1024,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
quantization="int8_sq_per_token_channel",
)),
@ -222,7 +222,7 @@ _allowed_configs = {
n_positions=1024,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
position_embedding_type='rope_gpt_neox',
rotary_pct=0.5,
@ -241,7 +241,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
pre_norm=False,
do_layer_norm_before=False,
@ -259,7 +259,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
pre_norm=False,
do_layer_norm_before=True,
@ -277,7 +277,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
pre_norm=False,
do_layer_norm_before=True,
@ -295,7 +295,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
pre_norm=False,
do_layer_norm_before=True,
@ -313,7 +313,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
pre_norm=True,
do_layer_norm_before=True,
@ -332,7 +332,7 @@ _allowed_configs = {
n_positions=8192,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"starcoder2_3b":
@ -351,7 +351,7 @@ _allowed_configs = {
rotary_pct=1.0,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"llama_7b":
@ -368,7 +368,7 @@ _allowed_configs = {
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"llama_13b":
@ -385,7 +385,7 @@ _allowed_configs = {
inter_size=13824,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"llama_30b":
@ -402,7 +402,7 @@ _allowed_configs = {
inter_size=17920,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"llama_70b":
@ -420,7 +420,7 @@ _allowed_configs = {
inter_size=28672,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"llama_70b_long_context":
@ -437,7 +437,7 @@ _allowed_configs = {
inter_size=28672,
max_batch_size=16,
max_input_len=8000,
max_output_len=200,
max_seq_len=8200,
builder_opt=None,
enable_multi_block_mode=True)),
"llama_70b_long_generation":
@ -454,7 +454,7 @@ _allowed_configs = {
inter_size=28672,
max_batch_size=64,
max_input_len=200,
max_output_len=16384,
max_seq_len=16584,
builder_opt=None,
enable_multi_block_mode=True)),
"llama_70b_sq_per_tensor":
@ -471,7 +471,7 @@ _allowed_configs = {
inter_size=28672,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
quantization="int8_sq_per_tensor")),
"mixtral_8x7b":
@ -489,7 +489,7 @@ _allowed_configs = {
inter_size=14336,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
moe_num_experts=8,
moe_top_k=2,
@ -508,7 +508,7 @@ _allowed_configs = {
rotary_dim=64,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"gptneox_20b":
@ -525,7 +525,7 @@ _allowed_configs = {
rotary_dim=24,
max_batch_size=16,
max_input_len=512,
max_output_len=512,
max_seq_len=1024,
builder_opt=None,
)),
"chatglm_6b":
@ -543,7 +543,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
remove_input_padding=False,
)),
@ -562,7 +562,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
remove_input_padding=False,
)),
@ -581,7 +581,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
remove_input_padding=False,
)),
@ -600,7 +600,7 @@ _allowed_configs = {
n_positions=1024,
max_batch_size=128,
max_input_len=1024,
max_output_len=256,
max_seq_len=1280,
builder_opt=None,
remove_input_padding=False,
)),
@ -617,7 +617,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=32,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
builder_opt=None,
)),
"bloom_176b":
@ -633,7 +633,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=8,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
builder_opt=None,
)),
"bert_base":
@ -703,7 +703,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
builder_opt=None,
bias=True,
use_alibi=True,
@ -724,7 +724,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
bias=False,
use_alibi=False,
@ -745,7 +745,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
bias=False,
use_alibi=False,
@ -766,7 +766,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=8,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
builder_opt=None,
bias=False,
use_alibi=False,
@ -791,7 +791,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"t5_base":
@ -812,7 +812,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"t5_large":
@ -833,7 +833,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"t5_3b":
@ -854,7 +854,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"t5_11b":
@ -875,7 +875,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"flan_t5_small":
@ -897,7 +897,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"flan_t5_base":
@ -919,7 +919,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"flan_t5_large":
@ -941,7 +941,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"flan_t5_xl":
@ -963,7 +963,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"flan_t5_xxl":
@ -985,7 +985,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"bart_large_cnn":
@ -1008,7 +1008,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"mbart_large_50_many_to_one_mmt":
@ -1030,7 +1030,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"baichuan_7b":
@ -1047,7 +1047,7 @@ _allowed_configs = {
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"baichuan2_7b_chat":
@ -1064,7 +1064,7 @@ _allowed_configs = {
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"baichuan_13b_chat":
@ -1081,7 +1081,7 @@ _allowed_configs = {
inter_size=13696,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"baichuan2_13b_chat":
@ -1098,7 +1098,7 @@ _allowed_configs = {
inter_size=13696,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"internlm_chat_7b":
@ -1116,7 +1116,7 @@ _allowed_configs = {
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
bias=True,
)),
@ -1135,7 +1135,7 @@ _allowed_configs = {
inter_size=13824,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
bias=False,
)),
@ -1152,7 +1152,7 @@ _allowed_configs = {
inter_size=22016,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
bias=False)),
"qwen_14b_chat":
@ -1169,7 +1169,7 @@ _allowed_configs = {
inter_size=27392,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"qwen1.5_7b_chat":
@ -1185,7 +1185,7 @@ _allowed_configs = {
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
bias=False)),
"qwen1.5_14b_chat":
@ -1202,7 +1202,7 @@ _allowed_configs = {
inter_size=13696,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"mamba_2.8b":
@ -1218,7 +1218,7 @@ _allowed_configs = {
n_positions=8192,
max_batch_size=64,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
state_size=16,
conv_kernel=4,
rnn_hidden_size=5120,
@ -1238,7 +1238,7 @@ _allowed_configs = {
n_positions=8192,
max_batch_size=64,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
state_size=16,
conv_kernel=4,
rnn_hidden_size=4096,
@ -1258,7 +1258,7 @@ _allowed_configs = {
n_positions=8192,
max_batch_size=64,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
state_size=16,
conv_kernel=4,
rnn_hidden_size=3072,
@ -1278,7 +1278,7 @@ _allowed_configs = {
n_positions=8192,
max_batch_size=64,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
state_size=16,
conv_kernel=4,
rnn_hidden_size=2048,
@ -1298,7 +1298,7 @@ _allowed_configs = {
n_positions=8192,
max_batch_size=64,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
state_size=16,
conv_kernel=4,
rnn_hidden_size=1536,
@ -1323,7 +1323,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1500,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"recurrentgemma_2b":
@ -1341,7 +1341,7 @@ _allowed_configs = {
n_positions=8192,
max_batch_size=64,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
position_embedding_type='rope_gpt_neox',
rotary_pct=0.5,
conv_kernel=4,

View File

@ -184,6 +184,15 @@ def parse_arguments():
help=
('If this option is specified, it will override the max output len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_seq_len',
'--max_decoder_seq_len',
dest='max_seq_len',
type=int,
default=None,
help=
('If this option is specified, it will override the max sequence len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_batch_size',
type=int,
@ -351,6 +360,21 @@ def main(args):
rank = tensorrt_llm.mpi_rank()
world_size = tensorrt_llm.mpi_world_size()
if args.max_output_len:
logger.warning(
'--max_output_len has been deprecated in favor of --max_seq_len')
if args.max_input_len:
if args.max_seq_len:
logger.warning(
'--max_seq_len has been overwritten due to --max_output_len being specified'
)
args.max_seq_len = args.max_input_len + args.max_output_len
else:
raise Exception(
f"--max_output_len is specified but not --max_input_len")
del args.max_output_len
# TODO: Re-enable memory monitor for multi-gpu benchmarks.
# Current Mem Monitor will cause benchmark script hang
# because MPI does not work well with multiprocessing.

View File

@ -136,6 +136,15 @@ def parse_arguments():
help=
('If this option is specified, it will override the max output len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_seq_len',
'--max_decoder_seq_len',
dest='max_seq_len',
type=int,
default=None,
help=
('If this option is specified, it will override the max sequence len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_batch_size',
type=int,
@ -254,8 +263,24 @@ def build_gpt(args):
if args.max_batch_size is None else args.max_batch_size
max_input_len = build_config['max_input_len'] \
if args.max_input_len is None else args.max_input_len
max_output_len = build_config['max_output_len'] \
if args.max_output_len is None else args.max_output_len
if args.max_output_len:
logger.warning(
'--max_output_len has been deprecated in favor of --max_seq_len')
if args.max_input_len:
if args.max_seq_len:
logger.warning(
'--max_seq_len has been overwritten due to --max_output_len being specified'
)
args.max_seq_len = args.max_input_len + args.max_output_len
else:
raise Exception(
f"max_output_len is specified but not max_input_len")
del args.max_output_len
max_seq_len = build_config['max_seq_len'] \
if args.max_seq_len is None else args.max_seq_len
max_beam_width = build_config['max_beam_width'] \
if args.max_beam_width is None else args.max_beam_width
@ -308,7 +333,7 @@ def build_gpt(args):
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_input_len=max_input_len,
max_output_len=max_output_len,
max_seq_len=max_seq_len,
max_num_tokens=max_num_tokens,
int8=(quant_mode.has_act_and_weight_quant()
or quant_mode.is_int8_weight_only()),
@ -675,7 +700,6 @@ def build_gpt(args):
config['quantization'].update({
'has_zero_point': False,
'pre_quant_scale': True,
'exclude_modules': [],
})
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(config)
@ -759,7 +783,6 @@ def build_gpt(args):
"group_size": 128,
"has_zero_point": False,
"pre_quant_scale": True,
"exclude_modules": [],
})
elif 'gptq' in args.quantization:
config['quantization'].update({
@ -968,14 +991,14 @@ def build_gpt(args):
# Forward
print(
f"max_batch_size: {max_batch_size}, max_input_len: {max_input_len}, max_output_len: {max_output_len}, max_beam_width: {max_beam_width}"
f"max_batch_size: {max_batch_size}, max_input_len: {max_input_len}, max_seq_len: {max_seq_len}, max_beam_width: {max_beam_width}"
)
# NOTE: all other models use PretrainedModel.prepare_inputs(...)
# except RecurrentGemmaForCausalLM and MambaForCausalLM
inputs = tensorrt_llm_model.prepare_inputs(
max_batch_size=max_batch_size,
max_input_len=max_input_len,
max_seq_len=max_input_len + max_output_len,
max_seq_len=max_seq_len,
max_num_tokens=max_num_tokens,
use_cache=True,
max_beam_width=max_beam_width,
@ -1231,7 +1254,7 @@ def enc_dec_build_helper(component, config, args):
max_batch_size=config['max_batch_size'],
max_beam_width=config['max_beam_width'],
max_decoder_input_len=config['max_decoder_input_len'],
max_output_len=config['max_output_len'],
max_seq_len=config['max_seq_len'],
max_encoder_input_len=config['max_encoder_input_len'],
opt_level=config['builder_opt'],
cross_attention=(component == 'decoder'),
@ -1473,7 +1496,7 @@ def enc_dec_build_helper(component, config, args):
max_batch_size=config['max_batch_size'],
max_beam_width=config['max_beam_width'],
max_decoder_input_len=config['max_decoder_input_len'],
max_seq_len=config['max_output_len'],
max_seq_len=config['max_seq_len'],
max_encoder_input_len=1500, # n_audio_ctx
)
tllm_model(**inputs)
@ -1482,7 +1505,7 @@ def enc_dec_build_helper(component, config, args):
max_batch_size=config['max_batch_size'],
max_beam_width=config['max_beam_width'],
max_decoder_input_len=config['max_decoder_input_len'],
max_seq_len=config['max_output_len'],
max_seq_len=config['max_seq_len'],
max_encoder_input_len=config['max_encoder_input_len'],
)
@ -1548,8 +1571,24 @@ def build_enc_dec(args):
build_config['max_encoder_input_len'] = build_config['max_encoder_input_len'] \
if args.max_input_len is None else args.max_input_len
build_config['max_decoder_input_len'] = 1
build_config['max_output_len'] = build_config['max_output_len'] \
if args.max_output_len is None else args.max_output_len
if args.max_output_len:
logger.warning(
'--max_output_len has been deprecated in favor of --max_seq_len')
if args.max_input_len:
if args.max_seq_len:
logger.warning(
'--max_seq_len has been overwritten due to --max_output_len being specified'
)
args.max_seq_len = args.max_input_len + args.max_output_len
else:
raise Exception(
f"max_output_len is specified but not max_input_len")
del args.max_output_len
build_config['max_seq_len'] = build_config['max_seq_len'] \
if args.max_seq_len is None else args.max_seq_len
build_config[
'max_beam_width'] = 1 if args.max_beam_width is None else args.max_beam_width

View File

@ -115,7 +115,7 @@ class EncDecBenchmark(BaseBenchmark):
self.max_batch_size = config["builder_config"]["max_batch_size"]
self.max_input_len = config["builder_config"][
"max_encoder_input_len"]
self.max_output_len = config["builder_config"]["max_output_len"]
self.max_seq_len = config["builder_config"]["max_seq_len"]
self.n_mels = config["builder_config"][
'n_mels'] if 'whisper' in self.model_name else 0
@ -180,8 +180,8 @@ class EncDecBenchmark(BaseBenchmark):
if args.max_batch_size is None else args.max_batch_size
self.max_input_len = build_config['max_encoder_input_len'] \
if args.max_input_len is None else args.max_input_len
self.max_output_len = build_config['max_output_len'] \
if args.max_output_len is None else args.max_output_len
self.max_seq_len = build_config['max_seq_len'] \
if args.max_seq_len is None else args.max_seq_len
self.n_mels = build_config[
'n_mels'] if 'whisper' in self.model_name else 0
# Build engine
@ -218,10 +218,11 @@ class EncDecBenchmark(BaseBenchmark):
f"[WARNING] whisper benchmark is input_len=1500, no text prompt, output_len=arbitrary"
)
for inlen, outlen in self.in_out_lens:
if (inlen > self.max_input_len or outlen > self.max_output_len):
if (inlen > self.max_input_len
or inlen + outlen > self.max_seq_len):
print(
f"[WARNING] check inlen({inlen}) <= max_inlen({self.max_input_len}) and "
f"outlen({outlen}) <= max_outlen({self.max_output_len}) failed, skipping."
f"inlen({inlen}) + outlen({outlen}) <= max_seqlen({self.max_seq_len}) failed, skipping."
)
continue
for batch_size in self.batch_sizes:

View File

@ -88,8 +88,8 @@ class GPTBenchmark(BaseBenchmark):
self.max_batch_size = args.max_batch_size
if args.max_input_len is not None:
self.max_input_len = args.max_input_len
if args.max_output_len is not None:
self.max_output_len = args.max_output_len
if args.max_seq_len is not None:
self.max_seq_len = args.max_seq_len
self.quant_config = get_quant_config(args.quantization)
self.quant_mode = self.quant_config.quant_mode
@ -209,10 +209,10 @@ class GPTBenchmark(BaseBenchmark):
def get_config(self):
for inlen, outlen in self.in_out_lens:
if inlen > self.max_input_len or outlen > self.max_output_len:
if inlen > self.max_input_len or inlen + outlen > self.max_seq_len:
print(
f'[WARNING] check inlen({inlen}) <= max_inlen({self.max_input_len}) and '
f'outlen({outlen}) <= max_outlen({self.max_output_len}) failed, skipping.'
f'[WARNING] check inlen({inlen}) <= max_inlen({self.max_input_len}) or '
f'seqlen({inlen + outlen}) <= max_seq_len({self.max_seq_len}) failed, skipping.'
)
continue
for batch_size in self.batch_sizes:
@ -314,7 +314,7 @@ class GPTBenchmark(BaseBenchmark):
output_length=outlen,
max_batch_size=self.build_config.max_batch_size,
max_input_len=self.build_config.max_input_len,
max_output_len=self.build_config.max_output_len,
max_seq_len=self.build_config.max_seq_len,
max_beam_width=self.build_config.max_beam_width)
for k, v in build_args.items():
tensorrt_llm.logger.info(f"{prefix} {k}:{v}")

View File

@ -84,8 +84,8 @@ class gptSessionBenchmarker:
max_batch_size,
"--max_input_len",
max_isl,
"--max_output_len",
max_osl,
"--max_seq_len",
max_osl + max_isl,
"--context_fmha",
"enable",
# Set the attention plugin data type.

View File

@ -140,8 +140,8 @@ def get_trtllm_build_command(benchmark_cfg: BenchmarkConfig) -> List[str]:
benchmark_cfg.world_size,
"--max_input_len",
max_isl,
"--max_output_len",
max_osl,
"--max_seq_len",
max_osl + max_isl,
"--context_fmha",
"enable",
# Set the attention plugin data type.

View File

@ -82,6 +82,9 @@ struct KvCacheStats
SizeType32 freeNumBlocks;
SizeType32 usedNumBlocks;
SizeType32 toksPerBlock;
SizeType32 allocTotalBlocks;
SizeType32 allocNewBlocks;
SizeType32 reusedBlocks;
};
// Basic building block of a paged KV cache - a single
@ -329,6 +332,16 @@ public:
return mFreePrimaryBlocks.size();
}
[[nodiscard]] SizeType32 getNumAllocTotalBlocks() const
{
return mAllocTotalBlocks;
}
[[nodiscard]] SizeType32 getNumAllocNewBlocks() const
{
return mAllocNewBlocks;
}
[[nodiscard]] SizeType32 getNumReusedBlocks() const noexcept
{
return mReusedBlocks;
@ -496,6 +509,21 @@ public:
return mBlockManager.getNumFreeBlocks();
}
[[nodiscard]] SizeType32 getNumAllocTotalBlocks() const
{
return mBlockManager.getNumAllocTotalBlocks();
}
[[nodiscard]] SizeType32 getNumAllocNewBlocks() const
{
return mBlockManager.getNumAllocNewBlocks();
}
[[nodiscard]] SizeType32 getNumReusedBlocks() const noexcept
{
return mBlockManager.getNumReusedBlocks();
}
[[nodiscard]] KvCacheStats getKvCacheStats() const
{
KvCacheStats kvCacheStats;
@ -503,6 +531,9 @@ public:
kvCacheStats.freeNumBlocks = getNumFreeBlocks();
kvCacheStats.usedNumBlocks = getUsedNumBlocks();
kvCacheStats.toksPerBlock = getTokensPerBlock();
kvCacheStats.allocTotalBlocks = getNumAllocTotalBlocks();
kvCacheStats.allocNewBlocks = getNumAllocNewBlocks();
kvCacheStats.reusedBlocks = getNumReusedBlocks();
return kvCacheStats;
}
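The three new counters expose block-allocation and reuse activity alongside the existing fields. A minimal sketch of reading them through getKvCacheStats(), assuming the usual kvCacheManager.h include path, the kv_cache_manager namespace, and KVCacheManager as the manager class name:

#include "tensorrt_llm/batch_manager/kvCacheManager.h"  // assumed include path for this header

#include <cstdio>

namespace kvc = tensorrt_llm::batch_manager::kv_cache_manager;  // assumed namespace

void logKvCacheStats(kvc::KVCacheManager const& kvCacheManager)
{
    kvc::KvCacheStats const stats = kvCacheManager.getKvCacheStats();
    // freeNumBlocks, usedNumBlocks and toksPerBlock are pre-existing fields;
    // the last three counters are new in this change.
    std::printf("free=%d used=%d toksPerBlock=%d allocTotal=%d allocNew=%d reused=%d\n",
        stats.freeNumBlocks, stats.usedNumBlocks, stats.toksPerBlock,
        stats.allocTotalBlocks, stats.allocNewBlocks, stats.reusedBlocks);
}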

View File

@ -62,7 +62,7 @@ public:
using VecLogProbs = std::vector<float>;
using BeamTokens = std::vector<VecTokens>;
using TensorPtr = TTensor;
using LogitsPostProcessor = std::function<void(RequestIdType, TensorPtr&, BeamTokens const&, TStream)>;
using LogitsPostProcessor = std::function<void(RequestIdType, TensorPtr&, BeamTokens const&, TStream const&)>;
GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
@ -76,6 +76,7 @@ public:
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
bool applyLogitsPostProcessorBatched = false,
std::optional<std::shared_ptr<VecTokens>> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false)
: mRequestId(requestId)
, mPromptLen(inputTokens->size())
@ -86,6 +87,7 @@ public:
, mEndId(endId)
, mPadId(padId)
, mLogitsPostProcessor(logitsPostProcessor)
, mApplyLogitsPostProcessorBatched(applyLogitsPostProcessorBatched)
, mOrigPromptLen(mPromptLen)
, mMaxSentTokenPos(mPromptLen - 1)
, mEmbeddingBias(std::move(embeddingBias))
@ -679,7 +681,7 @@ public:
void allocContextLogitsHost(SizeType32 vocabSizePadded, nvinfer1::DataType logitsDataType)
{
mContextLogitsHost = runtime::BufferManager::pinned(
mContextLogitsHost = runtime::BufferManager::pinnedPool(
runtime::ITensor::makeShape({mPromptLen, vocabSizePadded}), logitsDataType);
}
@ -695,13 +697,13 @@ public:
void allocGenerationLogitsHost(SizeType32 vocabSizePadded, nvinfer1::DataType logitsDataType)
{
mGenerationLogitsHost = runtime::BufferManager::pinned(
mGenerationLogitsHost = runtime::BufferManager::pinnedPool(
runtime::ITensor::makeShape({mSamplingConfig.beamWidth, mMaxNewTokens, vocabSizePadded}), logitsDataType);
}
void allocTargetModelAcceptedTokenLogitsHost(SizeType32 vocabSizePadded, nvinfer1::DataType logitsDataType)
{
mGenerationLogitsHost = runtime::BufferManager::pinned(
mGenerationLogitsHost = runtime::BufferManager::pinnedPool(
runtime::ITensor::makeShape({getNumDraftTokens() + 1, vocabSizePadded}), logitsDataType);
}
@ -948,6 +950,7 @@ public:
std::optional<TokenIdType> mPadId;
std::optional<SizeType32> mSeqSlot;
std::optional<LogitsPostProcessor> mLogitsPostProcessor;
bool mApplyLogitsPostProcessorBatched;
protected:
BeamTokens mTokens;
@ -1073,20 +1076,24 @@ public:
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
bool applyLogitsPostProcessorBatched = false,
std::optional<std::shared_ptr<VecTokens>> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false)
: Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList),
std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens), std::move(draftLogits),
excludeInputFromOutput, std::move(logitsPostProcessor), std::move(encoderInputTokens), returnEncoderOutput)
excludeInputFromOutput, std::move(logitsPostProcessor), applyLogitsPostProcessorBatched,
std::move(encoderInputTokens), returnEncoderOutput)
{
}
LlmRequest(RequestIdType requestId, executor::Request const& Request,
std::optional<Base::LogitsPostProcessor> logitsPostProcessor = std::nullopt)
std::optional<Base::LogitsPostProcessor> logitsPostProcessor = std::nullopt,
bool applyLogitsPostProcessorBatched = false)
: Base(requestId, Request)
{
mLogitsPostProcessor = std::move(logitsPostProcessor);
mApplyLogitsPostProcessorBatched = applyLogitsPostProcessorBatched;
}
void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager)

View File

@ -41,7 +41,7 @@ public:
bool normalizeLogProbs = true, bool enableChunkedContext = false,
PeftCacheManagerConfig const& peftCacheManagerConfig = PeftCacheManagerConfig{},
executor::DecodingConfig decodingConfig = executor::DecodingConfig{}, float gpuWeightsPercent = 1,
std::optional<SizeType32> maxBeamWidth = std::nullopt,
std::optional<SizeType32> maxBeamWidth = std::nullopt, std::optional<SizeType32> maxBatchSize = std::nullopt,
executor::SchedulerConfig const& schedulerConfig = executor::SchedulerConfig{})
: kvCacheConfig{kvCacheConfig}
, enableTrtOverlap{enableTrtOverlap}
@ -52,6 +52,7 @@ public:
, decodingConfig(std::move(decodingConfig))
, gpuWeightsPercent(gpuWeightsPercent)
, maxBeamWidth(maxBeamWidth)
, maxBatchSize(maxBatchSize)
, schedulerConfig{schedulerConfig}
{
}
@ -62,7 +63,7 @@ public:
executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(),
PeftCacheManagerConfig(executorConfig.getPeftCacheConfig().value_or(executor::PeftCacheConfig())),
executorConfig.getDecodingConfig().value_or(executor::DecodingConfig{}),
executorConfig.getGpuWeightsPercent(), executorConfig.getMaxBeamWidth(),
executorConfig.getGpuWeightsPercent(), executorConfig.getMaxBeamWidth(), executorConfig.getMaxBatchSize(),
executorConfig.getSchedulerConfig())
{
}
@ -87,6 +88,7 @@ public:
// Percentage of weights on the gpu at runtime
float gpuWeightsPercent;
std::optional<SizeType32> maxBeamWidth;
std::optional<SizeType32> maxBatchSize;
executor::SchedulerConfig schedulerConfig;
};

View File

@ -263,6 +263,9 @@ public:
std::optional<std::string> logitsPostProcessorName = std::nullopt,
std::optional<VecTokens> encoderInputTokenIds = std::nullopt);
/// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
static auto constexpr kBatchedPostProcessorName = "batched";
Request(Request const& other);
Request(Request&& other) noexcept;
Request& operator=(Request const& other);
@ -403,6 +406,14 @@ public:
[[nodiscard]] std::optional<size_t> getHostCacheSize() const;
[[nodiscard]] bool getOnboardBlocks() const;
void setEnableBlockReuse(bool enableBlockReuse);
void setMaxTokens(SizeType32 maxTokens);
void setMaxAttentionWindow(SizeType32 maxAttentionWindow);
void setSinkTokenLength(SizeType32 sinkTokenLength);
void setFreeGpuMemoryFraction(FloatType freeGpuMemoryFraction);
void setHostCacheSize(size_t hostCacheSize);
void setOnboardBlocks(bool onboardBlocks);
private:
friend class Serialization;
@ -557,49 +568,39 @@ private:
std::optional<size_t> mHostCacheSize;
};
/// @brief Configuration class for Lookahead decoding.
class LookaheadDecodingConfig
struct LookaheadDecodingConfig
{
public:
explicit LookaheadDecodingConfig(
SizeType32 maxNgramSize, SizeType32 maxWindowSize, SizeType32 maxVerificationSetSize);
LookaheadDecodingConfig(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize);
explicit LookaheadDecodingConfig()
: LookaheadDecodingConfig(1, 1, 0)
{
}
bool operator==(LookaheadDecodingConfig const& other) const;
[[nodiscard]] std::tuple<SizeType32 const, SizeType32 const, SizeType32 const> get() const;
[[nodiscard]] SizeType32 getWindowSize() const;
[[nodiscard]] SizeType32 getNgramSize() const;
[[nodiscard]] SizeType32 getVerificationSetSize() const;
void setMaxNgramSize(SizeType32);
void setMaxWindowSize(SizeType32);
void setMaxVerificationSetSize(SizeType32);
[[nodiscard]] SizeType32 getMaxNgramSize() const;
[[nodiscard]] SizeType32 getMaxWindowSize() const;
[[nodiscard]] SizeType32 getMaxVerificationSetSize() const;
/// @brief return <maxDecodingTokens, maxPathLen, maxDraftTokens, maxDraftPathLen>
std::tuple<SizeType32, SizeType32, SizeType32, SizeType32> calculateSpeculativeResource() const;
/// @brief return true when `this` can be executed on resources defined by `that`
bool isLE(LookaheadDecodingConfig const& that) const;
/// @brief return true when the parameter combination is valid.
static bool isLegal(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize) noexcept;
private:
friend class Serialization;
// Number of tokens per NGram.
SizeType32 mMaxNgramSize;
// Number of NGrams in lookahead branch per step.
SizeType32 mMaxWindowSize;
SizeType32 mWindowSize;
// Number of tokens per NGram.
SizeType32 mNgramSize;
// Number of NGrams in verification branch per step.
SizeType32 mMaxVerificationSetSize;
};
/// @brief Configuration class for explicit draft tokens decoding.
class ExplicitDraftTokensConfig
{
public:
explicit ExplicitDraftTokensConfig(float temperature);
bool operator==(ExplicitDraftTokensConfig const& other) const;
void setTemperature(float);
[[nodiscard]] float getTemperature() const;
private:
friend class Serialization;
// Sampling temperature.
float mTemperature;
SizeType32 mVerificationSetSize;
};
/// @brief Configuration class for the speculative decoding.
@ -608,8 +609,7 @@ class DecodingConfig
public:
explicit DecodingConfig(std::optional<DecodingMode> decodingMode = std::nullopt,
std::optional<LookaheadDecodingConfig> lookaheadDecodingConfig = std::nullopt,
std::optional<MedusaChoices> medusaChoices = std::nullopt,
std::optional<ExplicitDraftTokensConfig> explicitDraftTokensConfig = std::nullopt);
std::optional<MedusaChoices> medusaChoices = std::nullopt);
bool operator==(DecodingConfig const& other) const;
@ -620,7 +620,7 @@ public:
// Lookahead methods.
/// @brief Sets lookahead decoding mode and config.
void setLookaheadDecoding(LookaheadDecodingConfig const&);
void setLookaheadDecoding(LookaheadDecodingConfig const& lookaheadDecodingConfig);
[[nodiscard]] std::optional<LookaheadDecodingConfig> getLookaheadDecodingConfig() const;
// Medusa methods.
@ -628,11 +628,6 @@ public:
void setMedusaChoices(MedusaChoices const&);
[[nodiscard]] std::optional<MedusaChoices> getMedusaChoices() const;
// ExplicitDraftTokens decoding methods.
/// @brief Sets explicit draft tokens decoding mode and config.
void setExplicitDraftTokens(ExplicitDraftTokensConfig const&);
[[nodiscard]] std::optional<ExplicitDraftTokensConfig> getExplicitDraftTokensConfig() const;
private:
friend class Serialization;
@ -642,8 +637,6 @@ private:
std::optional<LookaheadDecodingConfig> mLookaheadDecodingConfig;
// Medusa params.
std::optional<MedusaChoices> mMedusaChoices;
// Explicit draft tokens params.
std::optional<ExplicitDraftTokensConfig> mExplicitDraftTokensConfig;
};
/// @brief Configuration class for the model executor
@ -654,10 +647,11 @@ public:
KvCacheConfig const& kvCacheConfig = KvCacheConfig(), bool enableChunkedContext = false,
bool normalizeLogProbs = true, SizeType32 iterStatsMaxIterations = kDefaultIterStatsMaxIterations,
SizeType32 requestStatsMaxIterations = kDefaultRequestStatsMaxIterations,
BatchingType batchingType = BatchingType::kINFLIGHT,
BatchingType batchingType = BatchingType::kINFLIGHT, std::optional<SizeType32> maxBatchSize = std::nullopt,
std::optional<ParallelConfig> parallelConfig = std::nullopt,
std::optional<PeftCacheConfig> const& peftCacheConfig = std::nullopt,
std::optional<LogitsPostProcessorMap> logitsPostProcessorMap = std::nullopt,
std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched = std::nullopt,
std::optional<DecodingConfig> decodingConfig = std::nullopt, float gpuWeightsPercent = 1);
[[nodiscard]] SizeType32 getMaxBeamWidth() const;
@ -668,13 +662,16 @@ public:
[[nodiscard]] SizeType32 getIterStatsMaxIterations() const;
[[nodiscard]] SizeType32 getRequestStatsMaxIterations() const;
[[nodiscard]] BatchingType getBatchingType() const;
[[nodiscard]] std::optional<SizeType32> getMaxBatchSize() const;
[[nodiscard]] std::optional<ParallelConfig> getParallelConfig() const;
[[nodiscard]] std::optional<PeftCacheConfig> getPeftCacheConfig() const;
[[nodiscard]] std::optional<LogitsPostProcessorMap> getLogitsPostProcessorMap() const;
[[nodiscard]] std::optional<LogitsPostProcessorBatched> getLogitsPostProcessorBatched() const;
[[nodiscard]] std::optional<DecodingConfig> getDecodingConfig() const;
[[nodiscard]] float getGpuWeightsPercent() const;
void setMaxBeamWidth(SizeType32 maxBeamWidth);
void setMaxBatchSize(SizeType32 maxBatchSize);
void setSchedulerConfig(SchedulerConfig const& schedulerConfig);
void setKvCacheConfig(KvCacheConfig const& kvCacheConfig);
void setEnableChunkedContext(bool enableChunkedContext);
@ -685,6 +682,7 @@ public:
void setParallelConfig(ParallelConfig const& parallelConfig);
void setPeftCacheConfig(PeftCacheConfig const& peftCacheConfig);
void setLogitsPostProcessorMap(LogitsPostProcessorMap const& logitsPostProcessorMap);
void setLogitsPostProcessorBatched(LogitsPostProcessorBatched const& logitsPostProcessorBatched);
void setDecodingConfig(DecodingConfig const& decodingConfig);
void setGpuWeightsPercent(float const& gpuWeightsPercent);
@ -715,10 +713,14 @@ private:
/// @brief The type of batching strategy to use. See BatchingType.
BatchingType mBatchingType;
/// @brief The max batch size of requests
std::optional<SizeType32> mMaxBatchSize;
/// @brief The parallel execution configuration.
std::optional<ParallelConfig> mParallelConfig;
std::optional<PeftCacheConfig> mPeftCacheConfig;
std::optional<LogitsPostProcessorMap> mLogitsPostProcessorMap;
std::optional<LogitsPostProcessorBatched> mLogitsPostProcessorBatched;
/// @brief Decoding configuration.
std::optional<DecodingConfig> mDecodingConfig;
float mGpuWeightsPercent;

View File

@ -112,11 +112,6 @@ public:
static void serialize(LookaheadDecodingConfig const& lookaheadDecodingConfig, std::ostream& os);
static size_t serializedSize(LookaheadDecodingConfig const& lookaheadDecodingConfig);
// ExplicitDraftTokensConfig
static ExplicitDraftTokensConfig deserializeExplicitDraftTokensConfig(std::istream& is);
static void serialize(ExplicitDraftTokensConfig const& ExplicitDraftTokensConfig, std::ostream& os);
static size_t serializedSize(ExplicitDraftTokensConfig const& ExplicitDraftTokensConfig);
// DecodingConfig
static DecodingConfig deserializeDecodingConfig(std::istream& is);
static void serialize(DecodingConfig const& decodingConfig, std::ostream& os);

View File

@ -53,8 +53,10 @@ using IterationType = std::uint64_t;
using RandomSeedType = std::uint64_t;
using VecLogProbs = std::vector<FloatType>;
using StreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
using LogitsPostProcessor = std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr&)>;
using LogitsPostProcessor = std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr const&)>;
using LogitsPostProcessorMap = std::unordered_map<std::string, LogitsPostProcessor>;
using LogitsPostProcessorBatched = std::function<void(std::vector<IdType> const&, std::vector<Tensor>&,
std::vector<std::reference_wrapper<BeamTokens const>> const&, StreamPtr const&)>;
using MedusaChoices = std::vector<std::vector<SizeType32>>;
enum class DataType
@ -224,6 +226,12 @@ struct KvCacheStats
SizeType32 usedNumBlocks;
/// @brief Number of tokens per block
SizeType32 tokensPerBlock;
/// @brief Number of total allocated blocks
SizeType32 allocTotalBlocks;
/// @brief Number of newly allocated blocks
SizeType32 allocNewBlocks;
/// @brief Number of reused blocks
SizeType32 reusedBlocks;
};
/// @brief Struct that holds the stats of static batching models for a single iteration
@ -267,6 +275,8 @@ struct IterationStats
std::string timestamp;
/// @brief Iteration id
IterationType iter;
/// @brief Iteration latency (ms)
double iterLatencyMS;
/// @brief Number of active requests
SizeType32 numActiveRequests;
/// @brief Number of max active requests
@ -717,6 +727,8 @@ static_assert(!DecodingMode::Lookahead().isBeamSearch());
static_assert(!DecodingMode::Lookahead().isMedusa());
static_assert(!DecodingMode::Lookahead().isExplicitDraftTokens());
static_assert(DecodingMode::Lookahead().isUseStopCriteria());
static_assert(DecodingMode::Lookahead().isUseStopWords());
static_assert(DecodingMode::Lookahead().isUseExplicitEosStop());
static_assert(DecodingMode::Lookahead().isLookahead());
static_assert(!DecodingMode::ExplicitDraftTokens().isAuto());
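The LogitsPostProcessorBatched alias above receives all opted-in requests in a single call, while the per-request callbacks now take the stream by const reference. A minimal sketch of registering a batched callback with the ExecutorConfig setter added in the previous header; the callback body is a placeholder and any real logits-editing logic is left out:

#include "tensorrt_llm/executor/executor.h"

#include <functional>
#include <vector>

namespace texec = tensorrt_llm::executor;

void registerBatchedLogitsPostProcessor(texec::ExecutorConfig& executorConfig)
{
    texec::LogitsPostProcessorBatched batchedCb
        = [](std::vector<texec::IdType> const& reqIds, std::vector<texec::Tensor>& logits,
              std::vector<std::reference_wrapper<texec::BeamTokens const>> const& beamTokens,
              texec::StreamPtr const& stream)
    {
        // Placeholder: a real implementation would modify each entry of `logits`
        // in place on `stream`, e.g. banning or biasing tokens per request id.
        (void) reqIds;
        (void) logits;
        (void) beamTokens;
        (void) stream;
    };
    executorConfig.setLogitsPostProcessorBatched(batchedCb);
    // Requests dispatch to this callback by using Request::kBatchedPostProcessorName
    // ("batched") as their logitsPostProcessorName.
}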

View File

@ -29,13 +29,13 @@ class DecodingInput
public:
using TensorPtr = std::shared_ptr<ITensor const>;
DecodingInput(SizeType32 maxLength, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength,
SizeType32 maxBatchSize, TensorPtr logits, TensorPtr endIds)
DecodingInput(SizeType32 maxLength, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, SizeType32 batchSize,
TensorPtr logits, TensorPtr endIds)
: step{maxLength}
, maxLength{maxLength}
, maxAttentionWindow{maxAttentionWindow}
, sinkTokenLength{sinkTokenLength}
, maxBatchSize{maxBatchSize}
, batchSize{batchSize}
, maxStopWordsLen{0}
, maxBadWordsLen{0}
, logits{std::move(logits)}
@ -50,46 +50,68 @@ public:
SizeType32 maxLength;
SizeType32 maxAttentionWindow;
SizeType32 sinkTokenLength;
SizeType32 maxBatchSize;
SizeType32 batchSize;
SizeType32 maxStopWordsLen; // The maximum value in the `stopWordsLens` tensor
SizeType32 maxBadWordsLen; // The maximum value in the `badWordsLens` tensor
TensorPtr logits; // [batchSize, beamWidth, vocabSizePadded], on gpu
std::optional<std::vector<TensorPtr>>
logitsVec; // vector of size [batchSize] contains logits of size [beamWidth, vocabSizePadded], on gpu
TensorPtr endIds; // [maxBatchSize * beamWidth], on gpu
TensorPtr endIds; // [batchSize * beamWidth], on gpu
// optional parameters
TensorPtr finished; // [maxBatchSize, beamWidth], finished states at current iteration.
TensorPtr finished; // [batchSize, beamWidth], finished states at current iteration.
// If true for some request, the decoding step of it is skipped, on gpu
TensorPtr sequenceLimitLength; // [maxBatchSize], on gpu
TensorPtr embeddingBias; // [maxBatchSize, vocabSizePadded], on gpu
TensorPtr lengths; // [maxBatchSize, beamWidth], on gpu
TensorPtr badWordsList; // [2, badWordsLength] or [maxBatchSize, 2, badWordsLength], on gpu
TensorPtr badWordsPtrs; // [maxBatchSize][2, badWordsLength], on gpu
TensorPtr badWordsLens; // [maxBatchSize], on gpu
TensorPtr stopWordsList; // [maxBatchSize, 2, stopWordsLength], on gpu
TensorPtr stopWordsPtrs; // [maxBatchSize][2, stopWordsLength], on gpu
TensorPtr stopWordsLens; // [maxBatchSize], on gpu
TensorPtr noRepeatNgramSize; // [maxBatchSize], on gpu
TensorPtr sequenceLimitLength; // [batchSize], on gpu
TensorPtr embeddingBias; // [batchSize, vocabSizePadded], on gpu
TensorPtr lengths; // [batchSize, beamWidth], on gpu
TensorPtr badWordsList; // [2, badWordsLength] or [batchSize, 2, badWordsLength], on gpu
TensorPtr badWordsPtrs; // [batchSize][2, badWordsLength], on gpu
TensorPtr badWordsLens; // [batchSize], on gpu
TensorPtr stopWordsList; // [batchSize, 2, stopWordsLength], on gpu
TensorPtr stopWordsPtrs; // [batchSize][2, stopWordsLength], on gpu
TensorPtr stopWordsLens; // [batchSize], on gpu
TensorPtr noRepeatNgramSize; // [batchSize], on gpu
TensorPtr
batchSlots; // [batchSize], optional, address map of the linear batch id to to the seq slots, int32_t, pinned
// parameters for beam search
TensorPtr cacheIndirection; // [maxBatchSize, beamWidth, maxSeqLen] - the k/v cache index for beam search, on gpu
TensorPtr cacheIndirection; // [batchSize, beamWidth, maxSeqLen] - the k/v cache index for beam search, on gpu
// Medusa
class MedusaInputs
{
public:
TensorPtr medusaPaths; // [maxBatchSize, maxTokensPerStep, maxMedusaHeads + 1], on gpu
TensorPtr medusaTreeIds; // [maxBatchSize, maxTokensPerStep], on gpu
TensorPtr medusaPaths; // [batchSize, maxTokensPerStep, maxMedusaHeads + 1], on gpu
TensorPtr medusaTreeIds; // [batchSize, maxTokensPerStep], on gpu
std::vector<std::vector<TensorPtr>>
medusaLogits; // [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded], on gpu
TensorPtr medusaCurTokensPerStep; // [maxBatchSize], on gpu
TensorPtr medusaTargetTokensPerStep; // [maxBatchSize], on gpu
medusaLogits; // [batchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded], on gpu
TensorPtr medusaCurTokensPerStep; // [batchSize], on gpu
TensorPtr medusaTargetTokensPerStep; // [batchSize], on gpu
};
class ExplicitDraftTokensInputs
{
public:
TensorPtr nextDraftTokens; // [batchSize, maxNumPaths, maxPathLen]
TensorPtr nextFlatTokens; // [batchSize * maxDecodingTokens]
TensorPtr nextDraftIndices; // [batchSize, maxNumPaths, maxPathLen]
TensorPtr nextDraftProbs; // [batchSize, maxNumPaths, maxDraftPathLen, vocabSize]
TensorPtr lastDraftTokens; // [batchSize, maxNumPaths, maxPathLen]
TensorPtr lastDraftIndices; // [batchSize, maxNumPaths, maxPathLen]
TensorPtr masks; // [batchSize, maxDecodingTokens, maxDecodingTokens], bool
TensorPtr packedPositionIds; // [batchSize * maxDecodingTokens]
TensorPtr bestPathLengths; // [batchSize]
TensorPtr bestPathIndices; // [batchSize]
TensorPtr nextGenerationLengths; // [batchSize]
TensorPtr lastPositionIdsBase; // [batchSize]
TensorPtr lastGenerationLengths; // [batchSize]
TensorPtr maxGenLengthDevice; // [1]
TensorPtr seqSlots; // [batchSize]
};
std::optional<MedusaInputs> medusaInputs;
std::optional<ExplicitDraftTokensInputs> explicitDraftTokensInputs;
};
} // namespace tensorrt_llm::runtime

View File

@ -18,6 +18,7 @@
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <optional>
#include <utility>
@ -94,12 +95,15 @@ public:
public:
TensorPtr nextDraftTokens; // [maxBatchSize, maxDraftTokens]
TensorPtr nextDraftTokensLen; // [maxBatchSize]
TensorPtr prevDraftTokensLen; // [maxBatchSize]
TensorPtr acceptedTokensLen; // [maxBatchSize]
TensorPtr acceptedLengthsCumSum; // [maxBatchSize + 1]
TensorPtr pathsOffsets; // [maxBatchSize, maxAcceptedDraftTokensPerStep]
};
std::optional<SpeculativeDecodingOutputs> speculativeDecodingOutputs;
std::optional<ExplicitDraftTokensBuffers::Inputs> explicitDraftTokensBuffers;
};
} // namespace tensorrt_llm::runtime

View File

@ -0,0 +1,142 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/explicitDraftTokensModule.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/tllmRuntime.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <cstddef>
namespace tensorrt_llm::runtime
{
class ExplicitDraftTokensBuffers
{
public:
using SizeType32 = runtime::SizeType32;
using ITensor = runtime::ITensor;
using BufferPtr = runtime::IBuffer::SharedPtr;
using TensorPtr = runtime::ITensor::SharedPtr;
using TensorMap = runtime::StringPtrMap<runtime::ITensor>;
class Inputs
{
public:
//! [batchSize]
TensorPtr temperatures;
//! [batchSize]
TensorPtr positionIdsBase;
//! [batchSize] or [numGenSequences]
TensorPtr generationLengths;
//! [batchSize]
TensorPtr randomDataSample;
//! [batchSize, maxNumPaths, maxPathDraftLen] or [numGenSequences, maxNumPaths, maxPathDraftLen]
TensorPtr randomDataValidation;
//! [batchSize, maxNumPaths, maxPathLen] or [numGenSequences, maxNumPaths, maxPathLen]
TensorPtr draftTokens;
//! [batchSize, maxNumPaths, maxPathLen] or [numGenSequences, maxNumPaths, maxPathLen]
TensorPtr draftIndices;
//! [batchSize, maxNumPaths, maxPathDraftLen, vocabSize]
//! or [numGenSequences, maxNumPaths, maxPathDraftLen, vocabSize]
TensorPtr draftProbs;
//! [batchSize, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
//! or [numGenSequences, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
TensorPtr packedMasks;
//! [batchSize] or [numGenSequences]
TensorPtr positionIds;
// [1], on pinned
TensorPtr maxGenLengthHost;
void create(SizeType32 maxNumSequences, runtime::TllmRuntime const& runtime,
runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig);
};
class EngineInputs : public Inputs
{
public:
//! [numSequences], on gpu
TensorPtr requestTypesDevice;
//! [numGenSequences]
TensorPtr positionOffsets;
} engineInputs;
class EngineOutputs
{
public:
//! [batchSize]
TensorPtr nextGenerationLengths;
//! [batchSize]
TensorPtr nextPositionOffsets;
//! [batchSize, maxDecodingTokens, maxDecodingTokens], bool
TensorPtr masks;
//! [batchSize, maxNumPaths, maxPathLen]
TensorPtr nextDraftTokens;
//! [batchSize, maxNumPaths, maxPathLen]
TensorPtr nextDraftIndices;
//! [batchSize, maxNumPaths, maxDraftPathLen, vocabSize]
TensorPtr nextDraftProbs;
//! [batchSize * maxDecodingTokens]
TensorPtr nextFlatTokens;
//! [batchSize]
TensorPtr bestPathLengths;
//! [batchSize]
TensorPtr bestPathIndices;
//! [1]
TensorPtr maxGenToken;
//! [1]
TensorPtr totalGenToken;
//! [batchSize * maxDecodingTokens]
TensorPtr packedPositionIds;
} engineOutputs;
public:
ExplicitDraftTokensBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, runtime::BufferManager const& manager,
runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
executor::DecodingConfig const& decodingConfig, runtime::TllmRuntime const& runtime);
void reshape(SizeType32 numCtxSequences, SizeType32 numGenSequences, runtime::ModelConfig const& modelConfig);
void setFromInputs(SizeType32 numCtxSequences, SizeType32 numGenSequences, runtime::ITensor const& requestTypes,
ITensor const& seqSlots, ExplicitDraftTokensBuffers::Inputs const& decoderBuffers,
ITensor const& contextPositionIds, runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
runtime::WorldConfig const& worldConfig) const;
void insertInputTensors(
TensorMap& inputBuffers, TensorMap& outputBuffers, runtime::WorldConfig const& worldConfig) const;
private:
template <typename T>
void setFromInputs(SizeType32 numCtxSequences, SizeType32 numGenSequences, SizeType32 vocabSizePadded,
ITensor const& seqSlots, ExplicitDraftTokensBuffers::Inputs const& draftBuffers,
ITensor const& contextPositionIds, runtime::ExplicitDraftTokensModule const& explicitDraftTokensModule,
runtime::CudaStream const& stream) const;
public:
// helper tensors
std::size_t scanTempStorageBytes{0};
BufferPtr scanTempStorage;
TensorPtr cumSumGenerationLengths;
};
} // namespace tensorrt_llm::runtime

View File

@ -18,18 +18,15 @@
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingOutput.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <NvInferRuntime.h>
#include <curand_kernel.h>
#include <memory>
#include <NvInferRuntime.h>
namespace tensorrt_llm
{
@ -43,6 +40,8 @@ class DynamicDecodeLayer;
namespace runtime
{
class SpeculativeDecodingModule;
class IGptDecoder
{
public:
@ -51,7 +50,8 @@ public:
virtual ~IGptDecoder() = default;
virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize,
std::optional<TensorPtr> const& batchSlots = std::nullopt)
std::optional<TensorPtr> const& batchSlots = std::nullopt,
std::optional<DecodingOutput> const& output = std::nullopt)
= 0;
virtual void forwardAsync(DecodingOutput& output, DecodingInput const& input) = 0;
@ -93,7 +93,8 @@ public:
std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModule = nullptr);
void setup(SamplingConfig const& samplingConfig, size_t batchSize,
std::optional<TensorPtr> const& batchSlots = std::nullopt) override;
std::optional<TensorPtr> const& batchSlots = std::nullopt,
std::optional<DecodingOutput> const& output = std::nullopt) override;
void forwardAsync(DecodingOutput& output, DecodingInput const& input) override;
@ -133,7 +134,9 @@ inline std::unique_ptr<IGptDecoder> IGptDecoder::create(executor::DecodingMode c
case nvinfer1::DataType::kHALF:
return std::make_unique<GptDecoder<half>>(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
maxSequenceLength, stream, speculativeDecodingModule);
default: TLLM_THROW("Unsupported decoder data type. Use either kFLOAT or kHALF."); return nullptr;
default:
TLLM_THROW("Unsupported decoder data type: %d. Use either kFLOAT or kHALF.", static_cast<int>(dtype));
return nullptr;
}
}
} // namespace runtime

View File

@ -38,6 +38,12 @@ public:
using TensorPtr = ITensor::SharedPtr;
using SharedConstPtr = ITensor::SharedConstPtr;
enum class ForwardType
{
kASYNC,
kSYNC
};
GptDecoderBatch(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream,
SpeculativeDecodingMode const& speculativeDecodingMode);
@ -47,6 +53,8 @@ public:
SizeType32 maxTokensPerStep, bool fusedDecoder, nvinfer1::DataType dtype,
ModelConfig const& modelConfig) override;
void setupExplicitDraftTokens(ExplicitDraftTokensBuffers::Inputs explicitDraftTokensBuffers) override;
void newBatch(
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig) override;
@ -164,6 +172,12 @@ public:
return mJointDecodingOutput->speculativeDecodingOutputs->nextDraftTokens;
}
//! @returns [batchSize], predicted draft tokens lengths for previous step, on gpu
[[nodiscard]] TensorPtr getPrevDraftTokensLengths() const override
{
return mJointDecodingOutput->speculativeDecodingOutputs->prevDraftTokensLen;
}
//! @returns [batchSize], predicted draft tokens lengths for next step, on gpu
[[nodiscard]] TensorPtr getNextDraftTokensLengths() const override
{
@ -171,13 +185,13 @@ public:
}
//! @returns [batchSize + 1], exclusive sum of accepted draft token lengths, on gpu
[[nodiscard]] TensorPtr getSpecDecodingAcceptedLengthsCumSum() const override
[[nodiscard]] TensorPtr getAcceptedLengthsCumSum() const override
{
return mJointDecodingOutput->speculativeDecodingOutputs->acceptedLengthsCumSum;
}
//! @returns [batchSize, maxAcceptedDraftTokensPerStep], accepted paths packed into continuous tensor, on gpu
[[nodiscard]] TensorPtr getSpecDecodingAcceptedPackedPaths() const override
[[nodiscard]] TensorPtr getAcceptedPackedPaths() const override
{
return mJointDecodingOutput->speculativeDecodingOutputs->pathsOffsets;
}
@ -215,17 +229,19 @@ private:
//! @brief Updates finished state on host for all active requests
void updateFinished(decoder_batch::Token const& token);
//! @brief Sets inputs for explicit draft tokens.
void setExplicitDraftTokensInputs(decoder_batch::Input const& input);
//! @brief Calls unfused or fused decoders for tokens per engine step
void forwardDispatch(
decoder_batch::Output& output, decoder_batch::Input const& input, std::optional<CudaEvent> const& eventStart);
void forwardDispatch(decoder_batch::Output& output, decoder_batch::Input const& input, ForwardType forwardType);
//! @brief Calls unfused decoder for whole batch in loop
void forwardUnfusedDecoder(SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input,
std::optional<CudaEvent> const& eventStart);
void forwardUnfusedDecoder(
SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input, ForwardType forwardType);
//! @brief Calls fused decoder for whole batch
void forwardFusedDecoder(SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input,
std::optional<CudaEvent> const& eventStart);
void forwardFusedDecoder(
SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input, ForwardType forwardType);
private:
std::size_t const mVocabSize;

View File

@ -32,6 +32,7 @@
#include "tensorrt_llm/runtime/generationOutput.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/rawEngine.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
@ -158,28 +159,34 @@ public:
//! @param sessionConfig Configuration of the session,
//! @param modelConfig Description of the model,
//! @param worldConfig Description of the environment,
//! @param engineBuffer The compiled TensorRT engine (const void*),
//! @param engineSize The size in bytes of the TensorRT engine (size_t),
//! @param rawEngine The compiled TensorRT engine,
//! @param logger The optional logger.
GptSession(Config const& sessionConfig, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
void const* engineBuffer, std::size_t engineSize, LoggerPtr logger = nullptr);
RawEngine const& rawEngine, LoggerPtr logger = nullptr);
GptSession(Config const& sessionConfig, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
void const* engineBuffer, std::size_t engineSize, LoggerPtr logger = nullptr)
: GptSession(sessionConfig, modelConfig, worldConfig, RawEngine(engineBuffer, engineSize), std::move(logger))
{
}
GptSession(Config const& sessionConfig, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
std::vector<uint8_t> const& engineBuffer, LoggerPtr logger = nullptr)
: GptSession(
sessionConfig, modelConfig, worldConfig, engineBuffer.data(), engineBuffer.size(), std::move(logger))
: GptSession(sessionConfig, modelConfig, worldConfig, RawEngine(engineBuffer.data(), engineBuffer.size()),
std::move(logger))
{
}
GptSession(Config const& sessionConfig, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
std::string const& engineFile, LoggerPtr logger = nullptr)
: GptSession(sessionConfig, modelConfig, worldConfig, utils::loadEngine(engineFile), std::move(logger))
: GptSession(sessionConfig, modelConfig, worldConfig, RawEngine(engineFile), std::move(logger))
{
}
[[nodiscard]] nvinfer1::ILogger& getLogger() const;
[[nodiscard]] BufferManager const& getBufferManager() const;
[[nodiscard]] BufferManager::CudaStreamPtr getRuntimeStreamPtr() const;
[[nodiscard]] ModelConfig const& getModelConfig() const
{

View File

@ -18,8 +18,10 @@
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h"
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/request.h"
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
#include <memory>
@ -31,41 +33,6 @@ namespace tensorrt_llm::runtime
namespace decoder_batch
{
class Request
{
public:
using ConstTensorPtr = ITensor::SharedConstPtr;
using TensorPtr = ITensor::SharedPtr;
using BufferPtr = IBuffer::SharedPtr;
explicit Request(ConstTensorPtr ids, SizeType32 inputLen, std::optional<SizeType32> maxNewTokens = std::nullopt,
std::optional<SizeType32> endId = std::nullopt)
: ids{std::move(ids)}
, inputLen(inputLen)
, maxNewTokens{maxNewTokens}
, endId{endId}
, generatedTokensPerEngineStep(1)
{
}
// mandatory parameters
ConstTensorPtr ids; // [inputSeqLen], the input sequence of token ids, on gpu
SizeType32 inputLen; // the input length without draft tokens
// optional parameters
std::optional<SizeType32> maxNewTokens; // maximum number of tokens to generate for this request
std::optional<SizeType32> endId; // end token id
BufferPtr draftTokens; // [generatedTokensPerStep - 1], on gpu, draft tokens from speculative decoding
std::optional<TensorPtr>
draftLogits; // [generatedTokensPerStep - 1, vocabSize], on gpu, draft tokens from speculative decoding
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
TensorPtr badWordsList; // [2, badWordsLength], on gpu
TensorPtr stopWordsList; // [2, stopWordsLength], on gpu
SizeType32 generatedTokensPerEngineStep;
TensorPtr medusaPaths; // [maxDraftTokens + 1, maxAcceptedDraftTokensPerStep + 1], on gpu
TensorPtr medusaTreeIds; // [maxDraftTokens + 1], on gpu
};
class Input
{
@ -109,6 +76,11 @@ public:
// within one beam for beam search, on gpu
std::vector<std::vector<TensorConstPtr>>
predictedDraftLogits; // [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded]
TensorConstPtr seqSlots; // [batchSize]
// explicit draft tokens data.
std::optional<ExplicitDraftTokensBuffers::EngineOutputs> explicitDraftTokensInputs;
std::optional<ExplicitDraftTokensBuffers::EngineInputs> explicitDraftTokensLastInputs;
};
using Output = decoder::Output;
@ -136,6 +108,9 @@ public:
using TensorPtr = std::shared_ptr<ITensor>;
using TokenPtr = std::unique_ptr<decoder_batch::Token const>;
//! @brief Setup buffers for ExplicitDraftTokens decoding.
virtual void setupExplicitDraftTokens(ExplicitDraftTokensBuffers::Inputs explicitDraftTokensBuffers) = 0;
//! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
virtual TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0;
@ -189,14 +164,17 @@ public:
//! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu
virtual TensorPtr getNextDraftTokens() const = 0;
//! @returns [batchSize], predicted draft tokens lengths for previous step, on gpu
virtual TensorPtr getPrevDraftTokensLengths() const = 0;
//! @returns [batchSize], predicted draft tokens lengths for next step, on gpu
virtual TensorPtr getNextDraftTokensLengths() const = 0;
//! @returns [batchSize + 1], exclusive sum of accepted draft token lengths, on gpu
virtual TensorPtr getSpecDecodingAcceptedLengthsCumSum() const = 0;
virtual TensorPtr getAcceptedLengthsCumSum() const = 0;
//! @returns [batchSize, maxAcceptedDraftTokensPerStep], accepted paths packed into continuous tensor, on gpu
virtual TensorPtr getSpecDecodingAcceptedPackedPaths() const = 0;
virtual TensorPtr getAcceptedPackedPaths() const = 0;
protected:
IGptDecoderBatch() = default;

View File

@ -16,7 +16,12 @@
#pragma once
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/request.h"
#include "tensorrt_llm/runtime/speculativeDecodingModule.h"
#include <memory>
namespace tensorrt_llm::runtime
{
@ -24,8 +29,9 @@ namespace tensorrt_llm::runtime
class LookaheadModule : public SpeculativeDecodingModule
{
public:
explicit LookaheadModule(SizeType32 maxAcceptedTokens, SizeType32 maxDraftTokens) noexcept
: SpeculativeDecodingModule(maxAcceptedTokens, maxDraftTokens, maxDraftTokens)
explicit LookaheadModule(SizeType32 maxDraftPathLen, SizeType32 maxDecodingDraftTokens) noexcept
: SpeculativeDecodingModule(maxDraftPathLen, maxDecodingDraftTokens, maxDecodingDraftTokens)
, mExecutionConfig()
{
}
@ -33,5 +39,19 @@ public:
: LookaheadModule(0, 0)
{
}
void setExecutionConfig(executor::LookaheadDecodingConfig const& config)
{
mExecutionConfig = config;
}
executor::LookaheadDecodingConfig const getExecutionConfig() const
{
return mExecutionConfig;
}
private:
executor::LookaheadDecodingConfig mExecutionConfig;
};
} // namespace tensorrt_llm::runtime
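A brief usage sketch of the new execution-config plumbing; the include path and the draft-length values are assumptions, and the runtime config is taken as a parameter.
#include "tensorrt_llm/runtime/lookaheadModule.h" // assumed include path
void configureLookaheadSketch(tensorrt_llm::executor::LookaheadDecodingConfig const& runtimeConfig)
{
    using tensorrt_llm::runtime::LookaheadModule;
    LookaheadModule module(/*maxDraftPathLen=*/4, /*maxDecodingDraftTokens=*/63); // illustrative sizes
    module.setExecutionConfig(runtimeConfig);
    auto const config = module.getExecutionConfig(); // returns a copy of the stored config
}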

View File

@ -21,7 +21,9 @@
#include "tensorrt_llm/runtime/loraModule.h"
#include "tensorrt_llm/runtime/speculativeDecodingMode.h"
#include "tensorrt_llm/runtime/speculativeDecodingModule.h"
#include <NvInferRuntime.h>
#include <array>
namespace tensorrt_llm::runtime
{
@ -29,6 +31,12 @@ namespace tensorrt_llm::runtime
class ModelConfig
{
public:
// See `split_point` defined in `tensorrt_llm/models/generation_mixin.py`.
// The split points are tuned to get better perf, if we need to let

// users tune that, we can support that by writing and reading the
// points in `config.json`.
static constexpr std::array kOPT_PROFILES_SPLIT_POINTS{64, 128, 256, 512, 1024};
enum class ModelVariant : std::int32_t
{
kGpt = 0,
@ -88,9 +96,16 @@ public:
, mUsePositionEmbedding(false)
, mUseTokenTypeEmbedding(false)
, mSpeculativeDecodingMode(SpeculativeDecodingMode::None())
, mLogitsDtype(nvinfer1::DataType::kFLOAT)
, mUseShapeInference(true)
{
}
[[nodiscard]] static std::vector<SizeType32> getOptProfilesSplitPoints() noexcept
{
return {kOPT_PROFILES_SPLIT_POINTS.begin(), kOPT_PROFILES_SPLIT_POINTS.end()};
}
[[nodiscard]] SizeType32 constexpr getVocabSize() const noexcept
{
return mVocabSize;
@ -555,6 +570,26 @@ public:
return mSpeculativeDecodingMode;
}
void setLogitsDtype(nvinfer1::DataType inputDtype) noexcept
{
mLogitsDtype = inputDtype;
}
[[nodiscard]] nvinfer1::DataType constexpr getLogitsDtype() const noexcept
{
return mLogitsDtype;
}
void setUseShapeInference(bool useShapeInference) noexcept
{
mUseShapeInference = useShapeInference;
}
[[nodiscard]] bool useShapeInference() const noexcept
{
return mUseShapeInference;
}
private:
SizeType32 mVocabSize;
SizeType32 mNbAttentionLayers;
@ -608,6 +643,10 @@ private:
// Speculative decoding members
std::shared_ptr<SpeculativeDecodingModule> mSpeculativeDecodingModule;
SpeculativeDecodingMode mSpeculativeDecodingMode;
// Logits datatype
nvinfer1::DataType mLogitsDtype;
bool mUseShapeInference;
};
} // namespace tensorrt_llm::runtime
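A short sketch exercising the new ModelConfig accessors; the ModelConfig instance and the chosen values are illustrative and assumed to come from the caller.
#include "tensorrt_llm/runtime/modelConfig.h"
void tweakModelConfigSketch(tensorrt_llm::runtime::ModelConfig& modelConfig)
{
    // Record the engine's logits datatype (defaults to kFLOAT in the constructor above).
    modelConfig.setLogitsDtype(nvinfer1::DataType::kHALF);
    // Shape inference is on by default; it can now be switched off.
    modelConfig.setUseShapeInference(false);
    // Comes back as {64, 128, 256, 512, 1024} from kOPT_PROFILES_SPLIT_POINTS.
    auto const splitPoints = tensorrt_llm::runtime::ModelConfig::getOptProfilesSplitPoints();
}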

View File

@ -0,0 +1,98 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/assert.h"
#include <NvInferRuntime.h>
#include <filesystem>
namespace tensorrt_llm::runtime
{
class RawEngine
{
public:
enum Type
{
FilePath,
AddressWithSize,
HostMemory
};
explicit RawEngine(std::filesystem::path enginePath) noexcept
: mType(FilePath)
, mEnginePath(std::move(enginePath))
{
}
explicit RawEngine(void const* engineAddr, std::size_t engineSize) noexcept
: mType(AddressWithSize)
, mEngineAddr(engineAddr)
, mEngineSize(engineSize)
{
}
explicit RawEngine(nvinfer1::IHostMemory const* engineBuffer) noexcept
: mType(HostMemory)
, mEngineBuffer(engineBuffer)
{
}
[[nodiscard]] Type getType() const
{
return mType;
}
[[nodiscard]] std::filesystem::path getPath() const
{
TLLM_CHECK(mType == FilePath);
return mEnginePath;
}
[[nodiscard]] void const* getAddress() const
{
TLLM_CHECK(mType == AddressWithSize);
return mEngineAddr;
}
[[nodiscard]] std::size_t getSize() const
{
TLLM_CHECK(mType == AddressWithSize);
return mEngineSize;
}
[[nodiscard]] nvinfer1::IHostMemory const* getHostMemory() const
{
TLLM_CHECK(mType == HostMemory);
return mEngineBuffer;
}
private:
Type mType;
std::filesystem::path mEnginePath;
struct
{
void const* mEngineAddr{};
std::size_t mEngineSize{};
};
nvinfer1::IHostMemory const* mEngineBuffer{};
};
} // namespace tensorrt_llm::runtime
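A usage sketch for the three RawEngine flavors; the engine path and buffers are placeholders. The GptSession overloads shown earlier simply wrap their legacy pointer/size or file-path arguments in a RawEngine like this.
#include "tensorrt_llm/runtime/rawEngine.h"
#include <cstdint>
#include <vector>
void rawEngineSketch(std::vector<uint8_t> const& engineBlob, nvinfer1::IHostMemory const* serializedEngine)
{
    using tensorrt_llm::runtime::RawEngine;
    RawEngine fromFile{std::filesystem::path{"model.engine"}};  // Type::FilePath (path is a placeholder)
    RawEngine fromBuffer{engineBlob.data(), engineBlob.size()}; // Type::AddressWithSize
    RawEngine fromHostMemory{serializedEngine};                 // Type::HostMemory
    if (fromFile.getType() == RawEngine::FilePath)
    {
        auto const path = fromFile.getPath(); // TLLM_CHECKs that the stored type matches
    }
}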

View File

@ -0,0 +1,68 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
#include <optional>
namespace tensorrt_llm::runtime::decoder_batch
{
class Request
{
public:
using ConstTensorPtr = ITensor::SharedConstPtr;
using TensorPtr = ITensor::SharedPtr;
using BufferPtr = IBuffer::SharedPtr;
explicit Request(ConstTensorPtr ids, SizeType32 inputLen, std::optional<SizeType32> maxNewTokens = std::nullopt,
std::optional<SizeType32> endId = std::nullopt)
: ids{std::move(ids)}
, inputLen(inputLen)
, maxNewTokens{maxNewTokens}
, endId{endId}
, generatedTokensPerEngineStep(1)
{
}
// mandatory parameters
ConstTensorPtr ids; // [inputSeqLen], the input sequence of token ids, on gpu
SizeType32 inputLen; // the input length without draft tokens
// optional parameters
std::optional<SizeType32> maxNewTokens; // maximum number of tokens to generate for this request
std::optional<SizeType32> endId; // end token id
BufferPtr draftTokens; // [generatedTokensPerStep - 1], on gpu, draft tokens from speculative decoding
std::optional<TensorPtr>
draftLogits; // [generatedTokensPerStep - 1, vocabSize], on gpu, draft tokens from speculative decoding
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
TensorPtr badWordsList; // [2, badWordsLength], on gpu
TensorPtr stopWordsList; // [2, stopWordsLength], on gpu
SizeType32 generatedTokensPerEngineStep;
TensorPtr medusaPaths; // [maxDraftTokens + 1, maxAcceptedDraftTokensPerStep + 1], on gpu
TensorPtr medusaTreeIds; // [maxDraftTokens + 1], on gpu
std::optional<executor::LookaheadDecodingConfig> lookaheadRuntimeConfig;
};
} // namespace tensorrt_llm::runtime::decoder_batch
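A construction sketch for the relocated Request type; the token-id tensor comes from the caller, the maxNewTokens/endId values are illustrative, and the optional speculative-decoding fields (draftTokens, medusa*, lookaheadRuntimeConfig) are left unset.
#include "tensorrt_llm/runtime/request.h"
namespace tdb = tensorrt_llm::runtime::decoder_batch;
tdb::Request makeRequestSketch(tdb::Request::ConstTensorPtr inputIds, tensorrt_llm::runtime::SizeType32 inputLen)
{
    tdb::Request request{std::move(inputIds), inputLen, /*maxNewTokens=*/128, /*endId=*/2};
    request.generatedTokensPerEngineStep = 1; // default; larger values are used for speculative decoding
    return request;
}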

View File

@ -75,6 +75,11 @@ public:
return anyBitSet(kExplicitDraftTokens);
}
[[nodiscard]] bool constexpr updatesPositionIds() const
{
return anyBitSet(kLookaheadDecoding | kExplicitDraftTokens);
}
[[nodiscard]] bool constexpr requiresAttentionMask() const
{
return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens);
@ -101,6 +106,12 @@ public:
return anyBitSet(kMedusa);
}
[[nodiscard]] bool constexpr needsDecoderPrologue() const
{
// Lookahead decoding will potentially require it too.
return anyBitSet(kExplicitDraftTokens);
}
using UnderlyingType = std::uint8_t;
bool operator==(SpeculativeDecodingMode const& other) const

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3769cb4ad108cb9898a03b25e91781bcb5576b85397fbd7f673843abba27272e
size 3977112
oid sha256:1fec0fdc00c076761ec48eb5e2ea93473a329e844a8091e26c6e3e02fd14a8b1
size 3931604

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3769cb4ad108cb9898a03b25e91781bcb5576b85397fbd7f673843abba27272e
size 3977112
oid sha256:1fec0fdc00c076761ec48eb5e2ea93473a329e844a8091e26c6e3e02fd14a8b1
size 3931604

View File

@ -1,3 +1,3 @@
359da6357b9948425f249c226166408a libtensorrt_llm_batch_manager_static.a
359da6357b9948425f249c226166408a libtensorrt_llm_batch_manager_static.pre_cxx11.a
8d4b145290d5984494a1fa6e380d01456534dc62 commit
93adf3003d7c422586a9bf892367371d libtensorrt_llm_batch_manager_static.a
93adf3003d7c422586a9bf892367371d libtensorrt_llm_batch_manager_static.pre_cxx11.a
c0bd2b69c932257678a2aad9bd8baba4b291795e commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3841fcf17899aa8cb75a01a5d0ee8c99e4e078399e4bb8a1201f9d53445d09cf
size 3869232
oid sha256:bd757c26886a3ffd6947615d9f2829434e94839b693007a64b47c6b5c26416e4
size 3812158

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:99d1e58c95ea4267129b7a3ac95b65dc72b5e006b3168d07b213a1f9712930de
size 3835982
oid sha256:87321383075adf2d87cfbdc8a12a3d3815ef058d5da9b6aaa8d7d3f3263af439
size 3773896

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e61e4199962b639502aba50adca548e79d6332e658c10ab717b2ec019d28ed45
size 22213850
oid sha256:58cdc0a330f8bfb7b50e3202aeac47bde0835b1dc600b4bfdcd2b30801e66e03
size 22381766

View File

@ -162,7 +162,7 @@ struct CutlassGemmConfig
{
}
std::string toString()
std::string toString() const
{
std::stringstream tactic;
tactic << "Cutlass GEMM Tactic";

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2a00e1d3526af9fe7877c5e3362b32244309ccfac8fd720d1020c966d13b71c9
size 1372862
oid sha256:18a967eaa1e9a7164e0b104a84b13ea95404f7c7c278375feb2513d5f063bafe
size 1396404

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2a00e1d3526af9fe7877c5e3362b32244309ccfac8fd720d1020c966d13b71c9
size 1372862
oid sha256:18a967eaa1e9a7164e0b104a84b13ea95404f7c7c278375feb2513d5f063bafe
size 1396404

View File

@ -1,3 +1,3 @@
a35a65a41062edf23a898ad42cdce31c libtensorrt_llm_executor_static.a
a35a65a41062edf23a898ad42cdce31c libtensorrt_llm_executor_static.pre_cxx11.a
8d4b145290d5984494a1fa6e380d01456534dc62 commit
7d12b9c04cb6738bb5f7747a88b00c1c libtensorrt_llm_executor_static.a
7d12b9c04cb6738bb5f7747a88b00c1c libtensorrt_llm_executor_static.pre_cxx11.a
c0bd2b69c932257678a2aad9bd8baba4b291795e commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e37e2b2f28ac1ae37c22fac7c93394c6fba6e94e27403c0904e47eeb6cd4bf5c
size 1412454
oid sha256:e503b4cfb1c842850287a359ffed23a1773a67a96475d365b66d757a283ac218
size 1448772

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0b801135ba31f7ea63de5deb1880a45b68b2bc9fa45403e7204f6b7a153bd3ee
size 1346882
oid sha256:f8c80cf7aca2b135a656a060456fb30a820e459b4b36560162b02fa65121ef50
size 1375430

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9334de5c0a470731f8dd63f68e60ef320268d838547e5e6cbf537bf5c231eb6f
size 12962386
oid sha256:cc65971d6d74260cb49b354aa4b0b82f92863cc722fbf206bf8a4919a4897532
size 14031364

View File

@ -97,6 +97,7 @@ struct PackedOn16Bytes<__nv_bfloat16>
{
using Type = PackedBFloat16;
};
#endif
// add two 128b data
@ -600,7 +601,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params)
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true, bool PUSH_MODE = false, bool Bias = false,
bool Residual = false>
static __global__ void twoShotAllReduceKernel(AllReduceParams params)
static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params)
{
// Suppose that two GPUs participate in the AR exchange, and we start two blocks.
// The message is partitioned into chunks as detailed below:
@ -674,7 +675,7 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
size_t offset_rank = ii * params.elts_per_rank + local_offset;
size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
if (offset_rank >= params.elts_total)
{
continue;
@ -829,7 +830,6 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
*/
while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE)
{
blocks_per_grid += 1;
@ -863,7 +863,8 @@ template <typename T, int RANKS_PER_NODE, bool PUSH_MODE = false, bool USE_MEMCP
void AllReduceNormKernelLaunch(AllReduceStrategyType algo, AllReduceStrategyConfig config, AllReduceFusionOp fusionOp,
AllReduceParams& params, cudaStream_t stream)
{
TLLM_CHECK(fusionOp == AllReduceFusionOp::RESIDUAL_RMS_NORM);
TLLM_CHECK_WITH_INFO(fusionOp == AllReduceFusionOp::RESIDUAL_RMS_NORM, "Unsupported AllReduceFusionOp: %d",
static_cast<int>(fusionOp));
if (algo == AllReduceStrategyType::ONESHOT)
{
reduce_fusion::one_shot_all_reduce_norm_kernel_launcher<T, RANKS_PER_NODE, Bias, Affine>(params, stream);
@ -1019,7 +1020,6 @@ AllReduceParams AllReduceParams::deserialize(int32_t const* buffer, size_t tpSiz
}
params.barrier_flag = flag_value;
params.ranks_per_node = tpSize;
params.rank = tpRank;
params.local_rank = tpRank;
return params;

View File

@ -30,7 +30,7 @@ namespace tensorrt_llm::kernels
constexpr size_t WARP_SIZE = 32;
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
constexpr size_t MAX_RANKS_PER_NODE = 8;
constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
constexpr size_t DEFAULT_BLOCK_SIZE = 512;
// Warning: python definition is in tensorrt_llm/functional.py
// they must be kept in sync
@ -82,7 +82,7 @@ struct AllReduceParams
size_t elts_per_rank;
size_t elts_per_block;
size_t rank_offset;
size_t ranks_per_node, rank, local_rank;
size_t ranks_per_node, local_rank;
uint32_t barrier_flag;
uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];

View File

@ -173,6 +173,10 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config)
{
#ifdef FAST_BUILD
// Fast build disables all configs except this one for SM90
return {CutlassTileConfigSM90::CtaShape128x128x128B};
#else
if (config & CutlassGemmConfig::GROUPED_GEMM)
{
return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B,
@ -187,26 +191,35 @@ std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B};
}
#endif
}
// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely to improve
// compilation speed.
bool supports_mcast_along_m(const CutlassTileConfigSM90 tile)
{
#ifdef FAST_BUILD
return false;
#else
std::set<CutlassTileConfigSM90> valid_tiles{CutlassTileConfigSM90::CtaShape128x16x128B,
CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B};
return valid_tiles.count(tile) == 1;
#endif
}
// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= 128. This is purely to improve
// compilation speed.
bool supports_mcast_along_n(const CutlassTileConfigSM90 tile)
{
#ifdef FAST_BUILD
return false;
#else
std::set<CutlassTileConfigSM90> valid_tiles{CutlassTileConfigSM90::CtaShape64x128x128B,
CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
CutlassTileConfigSM90::CtaShape128x256x128B};
return valid_tiles.count(tile) == 1;
#endif
}
std::vector<CutlassGemmConfig> get_candidate_configs(

View File

@ -16,6 +16,7 @@
#pragma once
#include "cute/tensor.hpp"
#include "cutlass_extensions/gemm_configs.h"
#include "tensorrt_llm/common/cudaUtils.h"
@ -26,13 +27,31 @@ namespace kernels
namespace cutlass_kernels
{
template <class TileShape, class ClusterShape, class ActivationType>
struct should_filter_sm90_gemm_problem_shape
{
#ifdef FAST_BUILD
constexpr static int TILE_K = 128 * 8 / cutlass::sizeof_bits<ActivationType>::value;
using SupportedCtaShape = cute::Shape<cute::_128, cute::_128, cute::Int<TILE_K>>;
using SupportedCgaShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
constexpr static bool value
= !cute::is_same_v<SupportedCtaShape, TileShape> || !cute::is_same_v<SupportedCgaShape, ClusterShape>;
#else
constexpr static bool value = false;
#endif
};
template <class TileShape, class ClusterShape, class ActivationType>
constexpr static bool should_filter_sm90_gemm_problem_shape_v
= should_filter_sm90_gemm_problem_shape<TileShape, ClusterShape, ActivationType>::value;
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(
int sm, int const max_split_k, tensorrt_llm::cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam const);
tensorrt_llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies(
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> const& candidate_configs,
std::vector<int> const& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts,
int const split_k_limit, const size_t workspace_bytes, int const multi_processor_count, int const is_weight_only);
std::vector<int> const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const num_experts,
int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only);
} // namespace cutlass_kernels
} // namespace kernels

View File

@ -41,6 +41,7 @@
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h"
@ -69,15 +70,8 @@ void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType
#ifdef COMPILE_HOPPER_TMA_GEMMS
using CutlassActivationType = typename TllmToCutlassTypeAdapter<ActivationType>::type;
// For FAST_BUILD, only instantiate kernels with 128x128x128B with 1x1x1 cluster shape.
#ifdef FAST_BUILD
constexpr int TILE_K = 128 * 8 / cutlass::sizeof_bits<CutlassActivationType>::value;
using SupportedCtaShape = Shape<_128, _128, cute::Int<TILE_K>>;
using SupportedCgaShape = Shape<_1, _1, _1>;
if constexpr (cute::is_same_v<SupportedCtaShape, CTAShape> && cute::is_same_v<SupportedCgaShape, ClusterShape>)
if constexpr (!should_filter_sm90_gemm_problem_shape_v<CTAShape, ClusterShape, ActivationType>)
{
#endif // FAST_BUILD
using CutlassWeightType__ = typename TllmToCutlassTypeAdapter<WeightType>::type;
// We need to remap this since SM90 uses a different layout for the weight matrix.
using CutlassWeightType_ = std::conditional_t<std::is_same_v<CutlassWeightType__, cutlass::uint4b_t>,
@ -278,13 +272,17 @@ void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType
= "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status));
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg);
}
#ifdef FAST_BUILD
}
else
{
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] Config not compiled with FAST_BUILD.");
std::stringstream ss;
ss << "[TensorRT-LLm Error][fpA_intB Runner] Config (" << (int64_t) cute::size<0>(CTAShape{}) << ","
<< (int64_t) cute::size<1>(CTAShape{}) << "," << (int64_t) cute::size<2>(CTAShape{}) << ") ("
<< (int64_t) cute::size<0>(ClusterShape{}) << "," << (int64_t) cute::size<1>(ClusterShape{}) << ","
<< (int64_t) cute::size<2>(ClusterShape{}) << ") not compiled with FAST_BUILD.";
throw std::runtime_error(ss.str());
}
#endif // FAST_BUILD
#else // COMPILE_HOPPER_TMA_GEMMS
throw std::runtime_error(

View File

@ -197,14 +197,7 @@ void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, i
{
#ifdef COMPILE_HOPPER_TMA_GEMMS
using namespace cute;
// For FAST_BUILD, only instantiate kernels with 128x128x128B with 1x1x1 cluster shape.
#ifdef FAST_BUILD
constexpr int TILE_K = 128 * 8 / cutlass::sizeof_bits<WeightType>::value;
using SupportedCtaShape = Shape<_128, _128, cute::Int<TILE_K>>;
using SupportedCgaShape = Shape<_1, _1, _1>;
if constexpr (cute::is_same_v<SupportedCtaShape, TileShape> && cute::is_same_v<SupportedCgaShape, ClusterShape>)
#endif // FAST_BUILD
if constexpr (!should_filter_sm90_gemm_problem_shape_v<TileShape, ClusterShape, T>)
{
using GemmInfo = HopperGroupedGemmInfo<T, WeightType, EpilogueTag, TileShape, ClusterShape, BIAS>;
@ -287,12 +280,10 @@ void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, i
"Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
sync_check_cuda_error();
}
#ifdef FAST_BUILD
else
{
TLLM_THROW("Configuration was disabled by FAST_BUILD");
}
#endif
#else // COMPILE_HOPPER_TMA_GEMMS
TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py.");

View File

@ -30,7 +30,6 @@ namespace tensorrt_llm
struct HopperGroupedGemmInput
{
template <class Tag>
using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;
@ -180,7 +179,7 @@ public:
bool supportsHopperSpecialisation() const;
[[nodiscard]] bool isFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const;
size_t calcMaxWorkspaceSize(int num_experts) const;
size_t getMaxWorkspaceSize(int num_experts) const;
[[nodiscard]] int getSM() const;
@ -197,9 +196,12 @@ private:
int64_t gemm_k, int num_experts, bool use_fused_moe, cudaStream_t stream);
private:
int sm_;
int multi_processor_count_;
int sm_{};
int multi_processor_count_{};
mutable int num_experts_ = 0;
mutable size_t gemm_workspace_size_ = 0;
std::optional<cutlass_extensions::CutlassGemmConfig> best_config_{};
size_t calcMaxWorkspaceSize(int num_experts) const;
};
} // namespace tensorrt_llm

View File

@ -48,6 +48,8 @@
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
@ -533,6 +535,18 @@ void MoeGemmRunner<T, WeightType>::dispatchToArch<EpilogueTag>(T const* A, Weigh
}
}
template <typename T, typename WeightType>
size_t MoeGemmRunner<T, WeightType>::getMaxWorkspaceSize(int num_experts) const
{
if (num_experts != num_experts_)
{
TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, num_experts_);
num_experts_ = num_experts;
gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts);
}
return gemm_workspace_size_;
}
template <typename T, typename WeightType>
size_t MoeGemmRunner<T, WeightType>::calcMaxWorkspaceSize(int num_experts) const
{

View File

@ -0,0 +1,48 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Common utils to be shared between Precompiled and JIT implementation.
*/
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h"
namespace tensorrt_llm
{
namespace kernels
{
XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParams)
{
unsigned int head_size = xqaParams.head_size;
unsigned int num_q_heads = xqaParams.num_q_heads;
unsigned int num_kv_heads = xqaParams.num_kv_heads;
TLLM_CHECK_WITH_INFO(num_q_heads % num_kv_heads == 0, "numQHeads should be multiple of numKVHeads.");
unsigned int num_q_heads_over_kv = num_q_heads / num_kv_heads;
unsigned int beam_width = xqaParams.beam_width;
// Use mTileSize = 16 kernels when qSeqLen <= 16.
unsigned int qSeqLen = static_cast<unsigned int>(xqaParams.generation_input_length);
unsigned int mTileSize = qSeqLen <= 16 ? 16 : 32;
// MultiQueryToken kernels can support any num_q_heads_over_kv that is power of 2.
unsigned int kernel_num_q_heads_over_kv = xqaParams.multi_query_tokens ? 0 : num_q_heads_over_kv;
// MultiQueryToken kernels can handle either 16/32 for M direction per CTA.
unsigned int kernel_m_tilesize = xqaParams.multi_query_tokens ? mTileSize : num_q_heads_over_kv;
return {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, kernel_m_tilesize,
xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0, xqaParams.paged_kv_cache,
xqaParams.multi_query_tokens};
}
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -17,6 +17,7 @@
*/
#pragma once
#include "decoderXQAConstants.h"
#include "tensorrt_llm/common/cudaDriverWrapper.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/workspace.h"
@ -73,6 +74,8 @@ struct XQAKernelRuntimeHashKey
}
};
XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParams);
struct XQAKernelRuntimeHasher
{
size_t operator()(XQAKernelRuntimeHashKey const& s) const

View File

@ -29,7 +29,7 @@ namespace jit
{
CubinObj::CubinObj(void const* buffer_, size_t buffer_size)
: mDriver(tensorrt_llm::common::CUDADriverWrapper::getInstance())
: mInitialized(false)
{
uint8_t const* buffer = static_cast<uint8_t const*>(buffer_);
size_t remaining_buffer_size = buffer_size;
@ -37,8 +37,37 @@ CubinObj::CubinObj(void const* buffer_, size_t buffer_size)
mContent.resize(len);
TLLM_CHECK(len <= remaining_buffer_size);
memcpy(mContent.data(), buffer, len);
}
initialize(mContent.c_str(), "kernel_mha");
CubinObj::CubinObj(std::string const& content)
: mContent(content)
, mInitialized(false)
{
}
CubinObj::CubinObj(CubinObj const& other)
{
// Only uninitialized CubinObj can be copy-constructed.
TLLM_CHECK(!other.mInitialized);
this->mContent = other.mContent;
this->mInitialized = false;
}
CubinObj& CubinObj::operator=(CubinObj const& other)
{
if (this == &other)
{
return *this;
}
// Only uninitialized CubinObj can be copy-assigned.
TLLM_CHECK(!other.mInitialized);
this->mContent = other.mContent;
this->mInitialized = false;
return *this;
}
size_t CubinObj::getSerializationSize() const noexcept
@ -59,47 +88,45 @@ void CubinObj::serialize(void* buffer_, size_t buffer_size) const noexcept
memcpy(buffer, mContent.c_str(), len);
}
CubinObj::CubinObj(std::string const& content)
: mDriver(tensorrt_llm::common::CUDADriverWrapper::getInstance())
, mContent(content)
, mModule(nullptr)
, mFunction(nullptr)
, mSharedMemBytes(0)
{
initialize(mContent.c_str(), "kernel_mha");
}
void CubinObj::launch(dim3 gridDim, dim3 blockDim, CUstream hStream, void** kernelParams)
{
TLLM_CHECK(mInitialized);
cuErrCheck(mDriver->cuLaunchKernel(mFunction, gridDim.x, gridDim.y, gridDim.z, blockDim.x, blockDim.y, blockDim.z,
mSharedMemBytes, hStream, kernelParams, /*extra=*/nullptr),
mDriver);
}
void CubinObj::initialize(char const* content, char const* funcName)
void CubinObj::initialize()
{
cuErrCheck(mDriver->cuModuleLoadData(&mModule, content), mDriver);
TLLM_CHECK(mModule != nullptr);
cuErrCheck(mDriver->cuModuleGetFunction(&mFunction, mModule, funcName), mDriver);
TLLM_CHECK(mFunction != nullptr);
// Populate mSharedMemBytes.
CUdeviceptr shmem_dev_ptr = 0;
cuErrCheck(mDriver->cuModuleGetGlobal(&shmem_dev_ptr, nullptr, mModule, "smemSize"), mDriver);
TLLM_CHECK(shmem_dev_ptr != 0);
cuErrCheck(mDriver->cuMemcpyDtoH(&mSharedMemBytes, shmem_dev_ptr, sizeof(unsigned int)), mDriver);
TLLM_CHECK(mSharedMemBytes > 0);
/* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */
if (mSharedMemBytes >= 46 * 1024)
if (!mInitialized)
{
cuErrCheck(
mDriver->cuFuncSetAttribute(mFunction, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, mSharedMemBytes),
mDriver);
}
mDriver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
mModule = nullptr;
cuErrCheck(mDriver->cuModuleLoadData(&mModule, mContent.c_str()), mDriver);
TLLM_CHECK(mModule != nullptr);
mFunction = nullptr;
cuErrCheck(mDriver->cuModuleGetFunction(&mFunction, mModule, kFuncName), mDriver);
TLLM_CHECK(mFunction != nullptr);
sync_check_cuda_error();
// Populate mSharedMemBytes.
CUdeviceptr shmem_dev_ptr = 0;
cuErrCheck(mDriver->cuModuleGetGlobal(&shmem_dev_ptr, nullptr, mModule, kSmemName), mDriver);
TLLM_CHECK(shmem_dev_ptr != 0);
cuErrCheck(mDriver->cuMemcpyDtoH(&mSharedMemBytes, shmem_dev_ptr, sizeof(unsigned int)), mDriver);
TLLM_CHECK(mSharedMemBytes > 0);
/* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */
if (mSharedMemBytes >= 46 * 1024)
{
cuErrCheck(mDriver->cuFuncSetAttribute(
mFunction, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, mSharedMemBytes),
mDriver);
}
sync_check_cuda_error();
mInitialized = true;
}
}
} // namespace jit
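The intended lifecycle, pieced together from the comments and checks above, is roughly: obtain an uninitialized CubinObj (copyable and serializable), copy it if needed, call initialize() once, then launch(). A hedged sketch; the include path is assumed and the launch configuration is taken from the caller.
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/cubinObj.h" // assumed include path
#include <cuda.h>          // CUstream
#include <cuda_runtime.h>  // dim3
void cubinLifecycleSketch(tensorrt_llm::kernels::jit::CubinObj const& uninitializedCubin, dim3 gridDim,
    dim3 blockDim, CUstream stream, void** kernelParams)
{
    // Copying is only legal while the source is still uninitialized.
    tensorrt_llm::kernels::jit::CubinObj localCopy = uninitializedCubin;
    // Loads the module, resolves kernel_mha, and reads smemSize; a no-op on repeated calls.
    localCopy.initialize();
    localCopy.launch(gridDim, blockDim, stream, kernelParams);
}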

View File

@ -31,20 +31,35 @@ class CubinObj
public:
// Default constructor constructs an empty unusable CubinObj instance.
CubinObj() = default;
CubinObj(std::string const& content);
// Constructs from raw cubin content.
explicit CubinObj(std::string const& content);
// Deserializes from a serialization buffer.
CubinObj(void const* buffer, size_t buffer_size);
CubinObj(CubinObj const& other);
CubinObj& operator=(CubinObj const& other);
// CubinObj can be move-constructed/assigned.
CubinObj(CubinObj&& other) = default;
CubinObj& operator=(CubinObj&& other) = default;
// Should be called at least once before calling launch().
void initialize();
void launch(dim3 gridDim, dim3 blockDim, CUstream hStream, void** kernelParams);
// It is safe to call getSerializationSize()/serialize() before calling initialize().
size_t getSerializationSize() const noexcept;
void serialize(void* buffer, size_t buffer_size) const noexcept;
private:
void initialize(char const* content, char const* funcName);
std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> mDriver;
static constexpr char const* kFuncName = "kernel_mha";
static constexpr char const* kSmemName = "smemSize";
// Constructors should populate mContent.
std::string mContent;
// Fields below are undefined prior to initialize() call.
bool mInitialized;
std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> mDriver;
CUmodule mModule;
CUfunction mFunction;
unsigned int mSharedMemBytes;

View File

@ -29,7 +29,7 @@ namespace kernels
namespace jit
{
// A collection of CubinObjs, with caching functionality.
// A thread-safe collection of CubinObjs, with caching functionality.
template <typename Key, class Hash = std::hash<Key>>
class CubinObjRegistryTemplate
{
@ -64,6 +64,7 @@ public:
std::unique_ptr<CubinObjRegistryTemplate<Key, Hash>> clone() const noexcept
{
std::lock_guard<std::mutex> lock(mMutex);
auto result = std::make_unique<CubinObjRegistryTemplate<Key, Hash>>();
for (auto const& p : mMap)
{
@ -74,6 +75,7 @@ public:
size_t getSerializationSize() const noexcept
{
std::lock_guard<std::mutex> lock(mMutex);
size_t result = sizeof(uint32_t);
for (auto&& p : mMap)
{
@ -85,6 +87,7 @@ public:
void serialize(void* buffer_, size_t buffer_size) const noexcept
{
std::lock_guard<std::mutex> lock(mMutex);
size_t remaining_buffer_size = buffer_size;
uint8_t* buffer = static_cast<uint8_t*>(buffer_);
uint32_t n = mMap.size();
@ -108,31 +111,61 @@ public:
TLLM_CHECK(remaining_buffer_size == 0);
}
// Returns directly if the Cubin already exists in the registry, otherwise call compileEngine to compile it.
//
// compileEngine may be nullptr.
CubinObj* getCubin(Key const& key, CompileEngine* compileEngine)
// Compiles and inserts the cubin if not found in mMap. Does nothing otherwise.
void insertCubinIfNotExists(Key const& key, CompileEngine* compileEngine)
{
TLLM_CHECK(compileEngine != nullptr);
std::lock_guard<std::mutex> lock(mMutex);
auto iter = mMap.find(key);
if (iter != mMap.end())
{
return &(iter->second);
return;
}
TLLM_CHECK_WITH_INFO(compileEngine != nullptr, "Key not found; compileEngine shouldn't be nullptr.");
CubinObj obj = compileEngine->compile();
auto insertResultIter = mMap.insert({key, std::move(obj)}).first;
return &(insertResultIter->second);
mMap.insert({key, std::move(obj)});
return;
}
void insertCubin(Key const& key, CubinObj&& obj)
{
std::lock_guard<std::mutex> lock(mMutex);
mMap.insert({key, std::forward<CubinObj>(obj)});
}
CubinObj* getCubin(Key const& key)
{
std::lock_guard<std::mutex> lock(mMutex);
auto iter = mMap.find(key);
if (iter != mMap.end())
{
return &iter->second;
}
else
{
return nullptr;
}
}
void merge(CubinObjRegistryTemplate<Key, Hash> const& other)
{
for (auto&& p : other.mMap)
{
mMap.insert(p);
}
}
void clear()
{
std::lock_guard<std::mutex> lock(mMutex);
mMap.clear();
}
private:
std::unordered_map<Key, CubinObj, Hash> mMap;
mutable std::mutex mMutex;
};
using CubinObjKey = XQAKernelFullHashKey;

View File

@ -36,24 +36,6 @@ XQAKernelRuntimeHashKey getRuntimeHashKeyFromKernelMeta(XQAKernelMetaInfo const&
kernelMeta.mMTileSize, kernelMeta.mTokensPerPage, kernelMeta.mPagedKVCache, kernelMeta.mMultiQueryTokens};
}
XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParams)
{
unsigned int head_size = xqaParams.head_size;
int num_q_heads = xqaParams.num_q_heads;
int num_kv_heads = xqaParams.num_kv_heads;
TLLM_CHECK_WITH_INFO(num_q_heads % num_kv_heads == 0, "numQHeads should be multiple of numKVHeads.");
unsigned int num_q_heads_over_kv = num_q_heads / num_kv_heads;
unsigned int beam_width = xqaParams.beam_width;
// MultiQueryToken kernels can support any num_q_heads_over_kv that is power of 2.
unsigned int kernel_num_q_heads_over_kv = xqaParams.multi_query_tokens ? 0 : num_q_heads_over_kv;
// MultiQueryToken kernels can handle either 16/32 for M direction per CTA.
unsigned int m_tilesize = xqaParams.multi_query_tokens ? 16 : num_q_heads_over_kv;
return {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, m_tilesize,
xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0, xqaParams.paged_kv_cache,
xqaParams.multi_query_tokens};
}
} // anonymous namespace
namespace tensorrt_llm
@ -66,7 +48,6 @@ DecoderXQAImplJIT::DecoderXQAImplJIT(DecoderXQARunner* runner)
, mDriver(tensorrt_llm::common::CUDADriverWrapper::getInstance())
, mForceXQA(tensorrt_llm::common::forceXQAKernels())
, mSM(tensorrt_llm::common::getSMVersion())
, mCubinObjRegistry(runner->mResource->getCubinObjRegistry())
{
initSupportedConfigs();
}
@ -140,8 +121,24 @@ void DecoderXQAImplJIT::prepare(XQAParams const& xqaParams)
jit::CompileEngine compileEngine(mSM, xqaParams);
// Discard getCubin() result.
mCubinObjRegistry->getCubin(key, &compileEngine);
auto registryGlobal = DecoderXQARunner::getResourceGlobal()->getCubinObjRegistry();
jit::CubinObj* uninitializedCubin = registryGlobal->getCubin(key);
if (uninitializedCubin != nullptr)
{
// Inference time. Prepare for inference.
if (mInitializedCubinObjRegistry.getCubin(key) == nullptr)
{
// Make a copy and initialize it.
jit::CubinObj initializedCubin = *uninitializedCubin;
initializedCubin.initialize();
mInitializedCubinObjRegistry.insertCubin(key, std::move(initializedCubin));
}
}
else
{
// Engine-build time. Compile the cubin and place it into CubinObjRegistry.
registryGlobal->insertCubinIfNotExists(key, &compileEngine);
}
}
void DecoderXQAImplJIT::runWithKVLinearBuffer(
@ -204,9 +201,13 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
BuildDecoderInfoParams<T> decoder_params;
memset(&decoder_params, 0, sizeof(decoder_params));
decoder_params.seqQOffsets = launchParams.cu_seq_lens;
decoder_params.seqQLengths = xqaParams.spec_decoding_generation_lengths;
decoder_params.seqKVLengths = xqaParams.sequence_lengths;
decoder_params.batchSize = int(batch_beam_size);
decoder_params.maxQSeqLength = xqaParams.generation_input_length;
decoder_params.removePadding = xqaParams.multi_query_tokens;
TLLM_CHECK_WITH_INFO(!xqaParams.multi_query_tokens || xqaParams.spec_decoding_generation_lengths != nullptr,
"Spec_decoding_generation_lengths must be provided.");
// Rotary embedding inv_freq buffer.
decoder_params.rotaryEmbeddingScale = xqaParams.rotary_embedding_scale;
decoder_params.rotaryEmbeddingBase = xqaParams.rotary_embedding_base;
@ -222,16 +223,18 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
// NOTE: MHA kernels should read kv cache that has already been appended with new tokens' kv cache.
void* xqa_q_input_ptr = ioScratch;
QKVPreprocessingParams<T, KVCacheBuffer> preprocessingParms{static_cast<T*>(const_cast<void*>(xqaParams.qkv)),
nullptr, static_cast<T*>(xqa_q_input_ptr), kv_cache_buffer, static_cast<T const*>(xqaParams.qkv_bias), nullptr,
xqaParams.sequence_lengths, nullptr, launchParams.rotary_inv_freq_buf, (float2 const*) nullptr,
xqaParams.kv_scale_orig_quant, xqaParams.spec_decoding_position_offsets, int(batch_beam_size),
xqaParams.generation_input_length, xqaParams.timestep, xqaParams.cyclic_attention_window_size,
xqaParams.sink_token_length, int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length),
xqaParams.num_q_heads, xqaParams.num_kv_heads, xqaParams.num_q_heads / xqaParams.num_kv_heads,
xqaParams.head_size, xqaParams.rotary_embedding_dim, xqaParams.rotary_embedding_base,
xqaParams.rotary_embedding_scale_type, xqaParams.rotary_embedding_scale,
xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type, xqaParams.position_shift_enabled,
cache_type, true, false, multiprocessor_count};
nullptr, static_cast<T*>(xqa_q_input_ptr), kv_cache_buffer, static_cast<T const*>(xqaParams.qkv_bias),
xqaParams.spec_decoding_generation_lengths, xqaParams.sequence_lengths,
xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr, launchParams.rotary_inv_freq_buf,
(float2 const*) nullptr, xqaParams.kv_scale_orig_quant, xqaParams.spec_decoding_position_offsets,
int(batch_beam_size), xqaParams.generation_input_length, xqaParams.timestep,
xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length,
int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length), xqaParams.num_q_heads,
xqaParams.num_kv_heads, xqaParams.num_q_heads / xqaParams.num_kv_heads, xqaParams.head_size,
xqaParams.rotary_embedding_dim, xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type,
xqaParams.rotary_embedding_scale, xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type,
xqaParams.position_shift_enabled, cache_type, true, false, multiprocessor_count, xqaParams.rotary_vision_start,
xqaParams.rotary_vision_length};
invokeQKVPreprocessing<T, KVCacheBuffer>(preprocessingParms, stream);
sync_check_cuda_error();
@ -245,7 +248,8 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
unsigned int kernel_m_tilesize = xqaParams.multi_query_tokens ? mTileSize : num_q_heads_over_kv;
jit::CubinObjKey key = getCubinObjKeyFromXQAParams(xqaParams);
jit::CubinObj* cubinObj = mCubinObjRegistry->getCubin(key, /*compileEngine=*/nullptr);
jit::CubinObj* cubinObj = mInitializedCubinObjRegistry.getCubin(key);
TLLM_CHECK(cubinObj != nullptr);
if (xqaParams.multi_query_tokens)
{
@ -275,8 +279,8 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
}
else
{
// mha_sm90.cu kernels. Default to false because it is not available in JIT path for now.
bool const isGmmaKernel = false;
bool const isGmmaKernel = (mSM == kSM_90 && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3
&& xqaParams.beam_width == 1);
constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 11;
uint32_t const maxNbKernelParams = (isGmmaKernel ? 11 : 10);
uint32_t idxNextParam = 0;

View File

@ -60,7 +60,8 @@ private:
bool mForceXQA;
int mSM;
jit::CubinObjRegistry* mCubinObjRegistry;
jit::CubinObjRegistry mInitializedCubinObjRegistry;
jit::CubinObjKey getCubinObjKeyFromXQAParams(XQAParams const& xqaParams) const;
//! The first prototype just takes whatever is available from the Precompiled cubins.

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f51307e90efbdd3dadc404efafb3b8a96ddbdb89a9068eba0b9676656be7d46d
size 80202640
oid sha256:8de0cd3bd46925e008f263b3f6c78c17f198578f74e23bc90661bec5a9acfbb1
size 80250768

View File

@ -1,2 +1,2 @@
b3823dd8e1d7f154019fb7dc24172ff4 libtensorrt_llm_nvrtc_wrapper.so
8d4b145290d5984494a1fa6e380d01456534dc62 commit
5b6c74ce66f62d2a58aa9cac16f11ad6 libtensorrt_llm_nvrtc_wrapper.so
c0bd2b69c932257678a2aad9bd8baba4b291795e commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:08384c1d7a80a86d888f6f23a5687ccb102b1a510b66db8dbcc3169127e4e88a
size 83472488
oid sha256:bbf358364915d5b023a6d0574cde0f602c104d24efe0bf5c04eeee4610a2413e
size 83541760

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:22867facd7d8dfa699618884d2e6912b1a2a7afedc299aa91e14b36353d6b8bd
size 1011200
oid sha256:84319476e8ecf9666f40f69355f19ec3b585fc0987f940be14af9e11e3f524c3
size 1080832

View File

@ -213,14 +213,7 @@ public:
// Use mTileSize = 16 kernels when qSeqLen <= 16.
unsigned int qSeqLen = static_cast<unsigned int>(xqaParams.generation_input_length);
unsigned int mTileSize = qSeqLen <= 16 ? 16 : 32;
// MultiQueryToken kernels can support any num_q_heads_over_kv that is power of 2.
unsigned int kernel_num_q_heads_over_kv = xqaParams.multi_query_tokens ? 0 : num_q_heads_over_kv;
// MultiQueryToken kernels can handle either 16/32 for M direction per CTA.
unsigned int kernel_m_tilesize = xqaParams.multi_query_tokens ? mTileSize : num_q_heads_over_kv;
XQAKernelRuntimeHashKey hash_key{xqaParams.kv_cache_data_type, head_size, beam_width,
kernel_num_q_heads_over_kv, kernel_m_tilesize,
xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0,
xqaParams.paged_kv_cache, xqaParams.multi_query_tokens};
XQAKernelRuntimeHashKey hash_key = getRuntimeHashKeyFromXQAParams(xqaParams);
auto const findIter = mFunctions.find(hash_key);
TLLM_CHECK_WITH_INFO(findIter != mFunctions.end(), "XQAKernelFunc not found.");
@ -310,28 +303,6 @@ public:
}
}
private:
static uint32_t getElemBytes(CUtensorMapDataType_enum dataType)
{
switch (dataType)
{
case CU_TENSOR_MAP_DATA_TYPE_UINT8: return 1;
case CU_TENSOR_MAP_DATA_TYPE_UINT16: return 2;
case CU_TENSOR_MAP_DATA_TYPE_UINT32: return 4;
case CU_TENSOR_MAP_DATA_TYPE_INT32: return 4;
case CU_TENSOR_MAP_DATA_TYPE_UINT64: return 8;
case CU_TENSOR_MAP_DATA_TYPE_INT64: return 8;
case CU_TENSOR_MAP_DATA_TYPE_FLOAT16: return 2;
case CU_TENSOR_MAP_DATA_TYPE_FLOAT32: return 4;
case CU_TENSOR_MAP_DATA_TYPE_FLOAT64: return 8;
case CU_TENSOR_MAP_DATA_TYPE_BFLOAT16: return 2;
case CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ: return 4;
case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32: return 4;
case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ: return 4;
}
throw std::runtime_error("unsupported data type");
}
protected:
std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> mDriver;

View File

@ -36,10 +36,9 @@ namespace tensorrt_llm
namespace kernels
{
DecoderXQARunner::DecoderXQARunner(Resource* resource, const XQADataType data_type, int num_heads, int num_kv_heads,
int head_size, bool multi_block_mode)
: mResource(resource)
, mDataType(data_type)
DecoderXQARunner::DecoderXQARunner(
const XQADataType data_type, int num_heads, int num_kv_heads, int head_size, bool multi_block_mode)
: mDataType(data_type)
, mNumHeads(num_heads)
, mNumKVHeads(num_kv_heads)
, mHeadSize(head_size)
@ -104,18 +103,12 @@ size_t DecoderXQARunner::getWorkspaceSize(int max_batch_beam_size, int max_num_t
DecoderXQAImpl* DecoderXQARunner::getImplFromXQAParams(XQAParams const& xqaParams)
{
if (tensorrt_llm::common::getSMVersion() == kSM_90)
{
// Always use Precompiled impl for sm90 until Hopper XQA source gets integrated to JIT codepath.
return mPrecompiledImpl.get();
}
if (xqaParams.multi_query_tokens)
{
// Use precompiled cubin for medusa, because medusa cubins are generated from a different CUDA source file than
// non-medusa.
return mPrecompiledImpl.get();
}
if (tensorrt_llm::common::getEnvEnableXQAJIT())
{
return mJITImpl.get();
@ -143,6 +136,12 @@ void DecoderXQARunner::run(
return getImplFromXQAParams(xqa_params)->run(xqa_params, kv_cache_buffer, stream);
}
DecoderXQARunner::Resource* DecoderXQARunner::getResourceGlobal()
{
static DecoderXQARunner::Resource sResource;
return &sResource;
}
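A small sketch of how a serialized Resource (for example one restored from an engine) can be folded into the new process-wide registry; the include path is assumed, and buffer/bufferSize are assumed to hold a previously serialized Resource.
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h" // assumed include path
#include <cstddef>
void restoreXQAResourceSketch(void const* buffer, std::size_t bufferSize)
{
    using tensorrt_llm::kernels::DecoderXQARunner;
    // Deserialize a Resource and merge its cubin registry into the global one.
    DecoderXQARunner::Resource restored(buffer, bufferSize);
    DecoderXQARunner::getResourceGlobal()->merge(restored);
}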
template void DecoderXQARunner::run(
XQAParams const& xqa_params, KVLinearBuffer const& kv_linear_buffer, cudaStream_t const& stream);
template void DecoderXQARunner::run(

View File

@ -77,10 +77,8 @@ struct XQADispatchHelper<__nv_bfloat16, KVBlockArray>
class DecoderXQARunner
{
public:
// Resources for constructing a DecoderXQARunner object.
class Resource;
DecoderXQARunner(Resource* resource, const XQADataType data_type, int num_heads, int num_kv_heads, int head_size,
bool multi_block_mode);
DecoderXQARunner(
const XQADataType data_type, int num_heads, int num_kv_heads, int head_size, bool multi_block_mode);
~DecoderXQARunner();
/**
@ -169,6 +167,9 @@ public:
this->run(xqa_params, kv_cache_buffer, stream);
}
class Resource;
static Resource* getResourceGlobal();
private:
bool shouldUseImpl(XQAParams const& xqa_params, bool for_configure_plugin);
void prepareForRun(XQAParams const& xqa_params);
@ -178,8 +179,6 @@ private:
static constexpr int kMaxBeamWidth = 4;
Resource* mResource;
XQADataType mDataType;
int mNumHeads;
int mNumKVHeads;
@ -206,11 +205,21 @@ public:
Resource(void const* buffer, size_t buffer_size);
~Resource() = default;
void merge(Resource const& other)
{
getCubinObjRegistry()->merge(*other.getCubinObjRegistry());
}
jit::CubinObjRegistry* getCubinObjRegistry()
{
return mCubinObjRegistry.get();
}
jit::CubinObjRegistry const* getCubinObjRegistry() const
{
return mCubinObjRegistry.get();
}
size_t getSerializationSize() const noexcept;
void serialize(void* buffer, size_t buffer_size) const noexcept;

View File

@ -339,7 +339,7 @@ __global__ void insertUnfinishedPathKernel(BeamHypotheses bh)
// Other parameters
bh.sequenceLengthsCBA[dstBeam] = bh.sequenceLengths[srcBeam];
bh.normedScoresCBA[dstBeam]
= applyLengthPenalty(bh.cumLogProbs[srcBeam], step - bh.inputLengths[srcBeam], bh.lengthPenalties[bid]);
= applyLengthPenalty(bh.cumLogProbs[srcBeam], step - bh.inputLengths[srcBeam] + 1, bh.lengthPenalties[bid]);
bh.cumLogProbsCBA[dstBeam] = bh.cumLogProbs[srcBeam];
bh.numBeamsCBA[bid]++;
}

View File

@ -1070,7 +1070,7 @@ std::vector<size_t> CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::getWo
size_t const sorter_size = CubKeyValueSorter::getWorkspaceSize(num_rows, num_experts);
size_t const fc2_result_size = permuted_elems * gemm_output_dtype; // May be an intermediate type for quantization
size_t const hopper_size = using_hopper ? HopperGroupedGemmInput::workspaceSize(num_experts_per_node) : 0;
size_t const gemm_workspace_size = moe_gemm_runner_.calcMaxWorkspaceSize(num_experts_per_node);
size_t const gemm_workspace_size = moe_gemm_runner_.getMaxWorkspaceSize(num_experts_per_node);
std::vector<size_t> workspace{source_rows_size, permuted_rows_size, permuted_experts_size, permuted_data_size,
total_rows_before_expert_size, softmax_out_size, glu_inter_size,
@ -1085,7 +1085,7 @@ size_t CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::getWorkspaceSize(i
ActivationType activation_type, MOEParallelismConfig parallelism_config) const
{
int const ep_size = parallelism_config.ep_size;
TLLM_CHECK_WITH_INFO(num_experts % ep_size == 0, "Number of experts must be a multiple of tp size");
TLLM_CHECK_WITH_INFO(num_experts % ep_size == 0, "Number of experts must be a multiple of ep size");
auto workspace = getWorkspaceBufferSizes(
num_rows, hidden_size, inter_size, num_experts, num_experts / ep_size, k, activation_type);
return tensorrt_llm::common::calculateTotalWorkspaceSize(workspace.data(), workspace.size());

View File

@ -52,15 +52,6 @@ private:
int num_bits_;
};
enum class MOEParallelismMode : int
{
NONE = 0, //!< Ignore parallelism and duplicate the work across all nodes
EXPERT_PARALLELISM, //!< Divide the experts between each node. The number of experts must be a multiple of
//!< parallelism
TENSOR_PARALLELISM, //!< Divide the weight matrices between the nodes. The hidden dimension must be a multiple of
//!< parallelism
};
enum class MOEExpertScaleNormalizationMode : int
{
NONE = 0, //!< Run the softmax on all scales and select the topk
@ -91,20 +82,23 @@ enum class MOEExpertScaleNormalizationMode : int
*/
struct MOEParallelismConfig
{
constexpr static MOEParallelismConfig TensorParallelism(int tp_size, int tp_rank)
int tp_size = 1;
int tp_rank = 0;
int ep_size = 1;
int ep_rank = 0;
bool operator==(MOEParallelismConfig const& other) const
{
return {tp_size, tp_rank, 1, 0};
return tp_size == other.tp_size && tp_rank == other.tp_rank && ep_size == other.ep_size
&& ep_rank == other.ep_rank;
}
constexpr static MOEParallelismConfig ExpertParallelism(int ep_size, int ep_rank)
friend std::ostream& operator<<(std::ostream& os, MOEParallelismConfig const& config)
{
return {1, 0, ep_size, ep_rank};
os << "tp_size: " << config.tp_size << ", tp_rank: " << config.tp_rank << ", ep_size: " << config.ep_size
<< ", ep_rank: " << config.ep_rank;
return os;
}
int const tp_size = 1;
int const tp_rank = 0;
int const ep_size = 1;
int const ep_rank = 0;
};
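
MOEParallelismConfig now exposes plain tp/ep members (with equality and stream operators) instead of the TensorParallelism/ExpertParallelism factory helpers. A small sketch of filling a combined TP x EP configuration from a flat rank; the TP-major rank layout used here is an assumption for illustration, not necessarily the runtime's mapping:

#include <iostream>

// Minimal stand-in mirroring the public members of MOEParallelismConfig, for illustration only.
struct ParallelismConfigSketch
{
    int tp_size = 1;
    int tp_rank = 0;
    int ep_size = 1;
    int ep_rank = 0;
};

// Hypothetical mapping from a flat rank to combined TP x EP coordinates, assuming TP-major layout.
ParallelismConfigSketch makeConfig(int worldRank, int tpSize, int epSize)
{
    ParallelismConfigSketch config;
    config.tp_size = tpSize;
    config.ep_size = epSize;
    config.tp_rank = worldRank % tpSize;
    config.ep_rank = worldRank / tpSize;
    return config;
}

int main()
{
    auto const config = makeConfig(/*worldRank=*/5, /*tpSize=*/2, /*epSize=*/4);
    std::cout << "tp_rank: " << config.tp_rank << ", ep_rank: " << config.ep_rank << std::endl;
    return 0;
}
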
struct QuantParams

View File

@ -14,11 +14,7 @@
* limitations under the License.
*/
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/kernels/speculativeDecoding/explicitDraftTokensKernels.h"
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
@ -33,35 +29,34 @@ using namespace tensorrt_llm::runtime;
namespace tensorrt_llm::kernels::speculative_decoding
{
size_t invokeScanSpecDecodingGenerationLengths(void* __restrict__ scanTempStorage, size_t scanTempStorageBytes,
SizeType32 const* __restrict__ specDecodingGenerationLengths,
SizeType32* __restrict__ maxSpecDecodingGenerationLengths, SizeType32 batchSize, cudaStream_t stream)
size_t invokeScanGenerationLengths(void* __restrict__ scanTempStorage, size_t scanTempStorageBytes,
SizeType32 const* __restrict__ generationLengths, SizeType32* __restrict__ scannedGenerationLengths,
SizeType32 batchSize, cudaStream_t stream)
{
cub::DeviceScan::InclusiveSum(scanTempStorage, scanTempStorageBytes, specDecodingGenerationLengths,
maxSpecDecodingGenerationLengths, batchSize, stream);
cub::DeviceScan::InclusiveSum(
scanTempStorage, scanTempStorageBytes, generationLengths, scannedGenerationLengths, batchSize, stream);
return scanTempStorageBytes;
}
size_t invokeReduceMaxSpecDecodingGenerationLengths(void* __restrict__ reduceMaxTempStorage,
size_t reduceTempStorageBytes, SizeType32 const* __restrict__ specDecodingGenerationLengths,
SizeType32* __restrict__ scannedSpecDecodingGenerationLengths, SizeType32 batchSize, cudaStream_t stream)
size_t invokeReduceMaxGenerationLengths(void* __restrict__ reduceMaxTempStorage, size_t reduceTempStorageBytes,
SizeType32 const* __restrict__ generationLengths, SizeType32* __restrict__ maxGenerationLengths,
SizeType32 batchSize, cudaStream_t stream)
{
cub::DeviceReduce::Max(reduceMaxTempStorage, reduceTempStorageBytes, specDecodingGenerationLengths,
scannedSpecDecodingGenerationLengths, batchSize, stream);
cub::DeviceReduce::Max(
reduceMaxTempStorage, reduceTempStorageBytes, generationLengths, maxGenerationLengths, batchSize, stream);
return reduceTempStorageBytes;
}
// inclusive prefix sum specDecodingGenerationLengths and reduce max specDecodingGenerationLengths
void invokeScanReduceSpecDecodingGenerationLengths(SizeType32 batchSize,
SizeType32 const* __restrict__ specDecodingGenerationLengths, void* __restrict__ scanTempStorage,
size_t scanTempStorageBytes, SizeType32* __restrict__ scanedSpecDecodingGenerationLengths,
void* __restrict__ reduceMaxTempStorage, size_t reduceMaxTempStorageBytes,
SizeType32* maxSpecDecodingGenerationLengths, cudaStream_t stream)
// inclusive prefix sum generationLengths and reduce max generationLengths
void invokeScanReduceGenerationLengths(SizeType32 batchSize, SizeType32 const* __restrict__ generationLengths,
void* __restrict__ scanTempStorage, size_t scanTempStorageBytes, SizeType32* __restrict__ scanedGenerationLengths,
void* __restrict__ reduceMaxTempStorage, size_t reduceMaxTempStorageBytes, SizeType32* maxGenerationLengths,
cudaStream_t stream)
{
invokeScanSpecDecodingGenerationLengths(scanTempStorage, scanTempStorageBytes, specDecodingGenerationLengths,
scanedSpecDecodingGenerationLengths, batchSize, stream);
invokeReduceMaxSpecDecodingGenerationLengths(reduceMaxTempStorage, reduceMaxTempStorageBytes,
specDecodingGenerationLengths, maxSpecDecodingGenerationLengths, batchSize, stream);
invokeScanGenerationLengths(
scanTempStorage, scanTempStorageBytes, generationLengths, scanedGenerationLengths, batchSize, stream);
invokeReduceMaxGenerationLengths(
reduceMaxTempStorage, reduceMaxTempStorageBytes, generationLengths, maxGenerationLengths, batchSize, stream);
}
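
invokeScanReduceGenerationLengths produces two results from the per-request generation lengths: an inclusive prefix sum (used as per-request output offsets) and the batch-wide maximum (used to size the mask rows). A host-side C++ reference of the same two results, for illustration only:

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    // Per-request generation lengths, as the decoder would report them.
    std::vector<int> generationLengths{3, 1, 4, 2};

    // Inclusive prefix sum: element i is the total token count of requests 0..i,
    // which the packed-mask kernels use as per-request output offsets.
    std::vector<int> scannedLengths(generationLengths.size());
    std::inclusive_scan(generationLengths.begin(), generationLengths.end(), scannedLengths.begin());

    // Maximum generation length across the batch, used to size the mask rows.
    int const maxLength = *std::max_element(generationLengths.begin(), generationLengths.end());

    std::printf("cumulative: %d %d %d %d, max: %d\n",
        scannedLengths[0], scannedLengths[1], scannedLengths[2], scannedLengths[3], maxLength);
    return 0;
}
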
////////////////////////
@ -100,27 +95,25 @@ __device__ SizeType32 positivePowerOfTwo(SizeType32 n)
return res;
}
__global__ void getSpecDecodingPackedMask(SizeType32 const* __restrict__ specDecodingCumGenerationLengths,
SizeType32 const* __restrict__ specDecodingMaxGenerationLengths, bool const* __restrict__ specDecodingMask,
SizeType32 const* __restrict__ batchSlots, SizeType32 maxDraftTokens,
SizeType32* __restrict__ specDecodingPackedMask)
__global__ void getPackedMask(SizeType32 const* __restrict__ cumGenerationLengths,
SizeType32 const* __restrict__ maxGenerationLengths, bool const* __restrict__ mask,
SizeType32 const* __restrict__ batchSlots, SizeType32 maxDraftTokens, SizeType32* __restrict__ packedMask)
{
auto const batchIdx = static_cast<SizeType32>(blockIdx.y);
auto const tokenIdx = static_cast<SizeType32>(blockIdx.x);
auto const numTokens = (batchIdx == 0)
? specDecodingCumGenerationLengths[0]
: specDecodingCumGenerationLengths[batchIdx] - specDecodingCumGenerationLengths[batchIdx - 1];
auto const numTokens = (batchIdx == 0) ? cumGenerationLengths[0]
: cumGenerationLengths[batchIdx] - cumGenerationLengths[batchIdx - 1];
if (tokenIdx >= numTokens)
{
return;
}
auto const maxGenerationLength = specDecodingMaxGenerationLengths[0];
auto const maxGenerationLength = maxGenerationLengths[0];
auto const numPackedMasks = divUp(maxDraftTokens, 32);
auto const outputStartId = batchSlots ? (batchSlots[batchIdx] * (maxDraftTokens + 1))
: ((batchIdx == 0) ? 0 : specDecodingCumGenerationLengths[batchIdx - 1]);
auto* outputPtr = specDecodingPackedMask + (outputStartId + tokenIdx) * numPackedMasks;
: ((batchIdx == 0) ? 0 : cumGenerationLengths[batchIdx - 1]);
auto* outputPtr = packedMask + (outputStartId + tokenIdx) * numPackedMasks;
if (tokenIdx == 0)
{
for (auto maskId = static_cast<SizeType32>(threadIdx.x); maskId < numPackedMasks;
@ -132,18 +125,18 @@ __global__ void getSpecDecodingPackedMask(SizeType32 const* __restrict__ specDec
}
else
{
bool const* specDecodingMaskPtr = specDecodingMask + batchIdx * maxGenerationLength * maxGenerationLength
+ tokenIdx * maxGenerationLength + 1;
extern __shared__ char shSpecDecodingMask[];
bool const* maskPtr
= mask + batchIdx * maxGenerationLength * maxGenerationLength + tokenIdx * maxGenerationLength + 1;
extern __shared__ char shMask[];
if (threadIdx.x == 0)
{
shSpecDecodingMask[maxGenerationLength - 1] = '1';
shMask[maxGenerationLength - 1] = '1';
}
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < maxGenerationLength - 1;
ti += static_cast<SizeType32>(blockDim.x))
{
auto const shIndex = maxGenerationLength - 1 - ti - 1;
shSpecDecodingMask[shIndex] = specDecodingMaskPtr[ti] ? '1' : '0';
shMask[shIndex] = maskPtr[ti] ? '1' : '0';
}
__syncthreads();
for (auto maskId = static_cast<SizeType32>(threadIdx.x); maskId < numPackedMasks;
@ -156,19 +149,19 @@ __global__ void getSpecDecodingPackedMask(SizeType32 const* __restrict__ specDec
}
else
{
auto const shSpecDecodingMaskIndexStart
auto const shMaskIndexStart
= ((maxGenerationLength - (maskId + 1) * 32) < 0) ? 0 : (maxGenerationLength - (maskId + 1) * 32);
auto const shSpecDecodingMaskIndexEnd = maxGenerationLength - (maskId * 32 + 1) + 1;
auto const shMaskIndexEnd = maxGenerationLength - (maskId * 32 + 1) + 1;
auto const validNumBits = shSpecDecodingMaskIndexEnd - shSpecDecodingMaskIndexStart;
auto const firstBit1 = (shSpecDecodingMask[shSpecDecodingMaskIndexStart] == '1') ? true : false;
auto const validNumBits = shMaskIndexEnd - shMaskIndexStart;
auto const firstBit1 = (shMask[shMaskIndexStart] == '1') ? true : false;
SizeType32 mask31bits = 0;
if (validNumBits != 1)
{
for (auto i = shSpecDecodingMaskIndexStart + 1; i < shSpecDecodingMaskIndexEnd; i++)
for (auto i = shMaskIndexStart + 1; i < shMaskIndexEnd; i++)
{
auto const index = (validNumBits - 1) - (i - shSpecDecodingMaskIndexStart - 1) - 1;
mask31bits += (shSpecDecodingMask[i] == '1') ? positivePowerOfTwo(index) : 0;
auto const index = (validNumBits - 1) - (i - shMaskIndexStart - 1) - 1;
mask31bits += (shMask[i] == '1') ? positivePowerOfTwo(index) : 0;
}
}
SizeType32 mask32bits;
@ -187,19 +180,47 @@ __global__ void getSpecDecodingPackedMask(SizeType32 const* __restrict__ specDec
}
} // namespace
void invokeConvertSpecDecodingMaskToPackedMask(SizeType32 batchSize,
SizeType32 const* __restrict__ specDecodingCumGenerationLengths,
SizeType32 const* __restrict__ specDecodingMaxGenerationLengths, bool const* __restrict__ specDecodingMask,
void invokeConvertMaskToPackedMask(SizeType32 batchSize, SizeType32 const* __restrict__ cumGenerationLengths,
SizeType32 const* __restrict__ maxGenerationLengths, bool const* __restrict__ mask,
SizeType32 const* __restrict__ batchSlots, SizeType32 maxDraftTokens, SizeType32 maxGenerationLength,
SizeType32* __restrict__ specDecodingPackedMask, cudaStream_t stream)
SizeType32* __restrict__ packedMask, cudaStream_t stream)
{
dim3 block(32);
dim3 grid(maxGenerationLength, batchSize);
size_t shmSize = maxGenerationLength * sizeof(char);
getSpecDecodingPackedMask<<<grid, block, shmSize, stream>>>(specDecodingCumGenerationLengths,
specDecodingMaxGenerationLengths, specDecodingMask, batchSlots, maxDraftTokens, specDecodingPackedMask);
getPackedMask<<<grid, block, shmSize, stream>>>(
cumGenerationLengths, maxGenerationLengths, mask, batchSlots, maxDraftTokens, packedMask);
}
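
getPackedMask compresses each boolean mask row into 32-bit words. A simplified host-side sketch of the packing idea, least-significant bit first; the kernel's reversed shared-memory traversal and its special handling of the first token are deliberately not reproduced here:

#include <cstdint>
#include <cstdio>
#include <vector>

// Packs a row of booleans into 32-bit words, least significant bit first.
std::vector<uint32_t> packMaskRow(std::vector<bool> const& row)
{
    std::vector<uint32_t> packed((row.size() + 31) / 32, 0u);
    for (size_t i = 0; i < row.size(); ++i)
    {
        if (row[i])
        {
            packed[i / 32] |= (1u << (i % 32));
        }
    }
    return packed;
}

int main()
{
    std::vector<bool> const row{true, false, true, true};
    auto const packed = packMaskRow(row);
    std::printf("packed word 0: 0x%x\n", packed[0]); // bits 0, 2, 3 set -> 0xd
    return 0;
}
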
namespace
{
template <typename T>
__global__ void fillContextBuffers(FillContextExplicitDraftTokensParams<T> params)
{
auto const bid = static_cast<SizeType32>(blockIdx.x);
auto const batchSlot = params.batchSlots ? params.batchSlots[bid] : bid;
if (threadIdx.x == 0)
{
// Generate new random data for sampling.
params.randDataSample[batchSlot] = static_cast<T>(curand_uniform(params.curandState + batchSlot));
// Copy temperature.
params.outputTemperatures[batchSlot] = __frcp_rn(params.inputTemperatures[batchSlot]);
}
}
} // namespace
template <typename T>
void invokeFillContextBuffers(FillContextExplicitDraftTokensParams<T> const& params, cudaStream_t stream)
{
SizeType32 constexpr BLOCK_SIZE = 32;
fillContextBuffers<<<params.batchSize, BLOCK_SIZE, 0, stream>>>(params);
}
template void invokeFillContextBuffers(FillContextExplicitDraftTokensParams<float> const& params, cudaStream_t stream);
template void invokeFillContextBuffers(FillContextExplicitDraftTokensParams<half> const& params, cudaStream_t stream);
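
Both fillContextBuffers and extractExplicitDraftTokens now store the reciprocal of the input temperature (via __frcp_rn), so downstream sampling can scale logits with a multiply instead of a divide. A host-side sketch of the equivalence; the exact consumer of outputTemperatures is not shown in this diff:

#include <cstdio>

int main()
{
    float const temperature = 0.7f;
    float const invTemperature = 1.0f / temperature; // device code computes this with __frcp_rn

    float const logit = 2.5f;
    // Dividing by the temperature and multiplying by its reciprocal give the same
    // scaled logit (up to rounding), but the multiply is cheaper per token.
    std::printf("logit / T = %f, logit * (1/T) = %f\n", logit / temperature, logit * invTemperature);
    return 0;
}
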
namespace
{
template <typename T>
@ -216,6 +237,8 @@ __global__ void extractExplicitDraftTokens(ExtractExplicitDraftTokensParams<T> p
auto const bestPathIdx = params.bestPathIndices[bid];
// Get current seq len (w/o newly accepted tokens).
auto const curSeqLen = params.sequenceLengths[batchSlot];
// `last*` tensors do not have data for context requests.
auto const lastTensorBid = bid - params.numContextRequests;
// Get output ids.
auto* outputIdsRequest = params.outputIds + batchSlot * params.maxSeqLen;
@ -237,7 +260,8 @@ __global__ void extractExplicitDraftTokens(ExtractExplicitDraftTokensParams<T> p
{
// Read 1:bestPathLength slice of last draft tokens at best path idx.
// This tensor comes directly from engine and has linear batch index.
auto const pathOffset = flat_index3(bid, bestPathIdx, ti + 1, params.numPaths, params.maxPathLength);
auto const pathOffset
= flat_index3(lastTensorBid, bestPathIdx, ti + 1, params.numPaths, params.maxPathLength);
// Read accepted token from last draft tokens.
acceptedToken = params.lastDraftTokens[pathOffset];
}
@ -253,6 +277,11 @@ __global__ void extractExplicitDraftTokens(ExtractExplicitDraftTokensParams<T> p
= params.nextDraftTokens[bid * params.numPaths * params.maxPathLength + ti];
params.unpackedNextDraftIndices[batchSlot * params.numPaths * params.maxPathLength + ti]
= params.inputUnpackedNextDraftIndices[bid * params.numPaths * params.maxPathLength + ti];
if (lastTensorBid >= 0)
{
params.outputLastDraftIndices[batchSlot * params.numPaths * params.maxPathLength + ti]
= params.lastDraftIndices[lastTensorBid * params.numPaths * params.maxPathLength + ti];
}
}
auto const numNextDraftTokens = (bid == 0)
@ -274,14 +303,13 @@ __global__ void extractExplicitDraftTokens(ExtractExplicitDraftTokensParams<T> p
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numNextDraftTokens;
ti += static_cast<SizeType32>(blockDim.x))
{
params.outputPositionIds[batchSlot * maxDecodingTokens + ti] = params.packedPositionIds[startId + ti];
params.outputPositionIds[batchSlot * maxDecodingTokens + ti] = params.packedPositionIds[startId + ti] - 1;
}
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < params.numPaths * (params.maxPathLength - 1);
ti += static_cast<SizeType32>(blockDim.x))
{
// Generate new random data for token verification.
// This tensor goes directly to engine and has linear batch index.
auto const offset = flat_index2(batchSlot, ti, params.numPaths * (params.maxPathLength - 1));
params.randDataVerification[offset] = static_cast<T>(curand_uniform(params.curandState + batchSlot));
}
@ -291,23 +319,31 @@ __global__ void extractExplicitDraftTokens(ExtractExplicitDraftTokensParams<T> p
if (threadIdx.x == 0)
{
// Update pos id base.
// This tensor goes directly to engine and has linear batch index.
params.outputPositionIdsBase[batchSlot] = params.inputPositionIdsBase[bid] + bestPathLength;
// Set number of accepted tokens at this iteration.
params.acceptedLengths[batchSlot] = bestPathLength;
// Save the current draft length as the previous draft length.
params.prevDraftLengths[batchSlot] = params.nextDraftLengths[batchSlot];
// Set number of draft tokens for the next iteration.
params.nextDraftLengths[batchSlot] = numNextDraftTokens - 1;
// Set number of tokens passed to the engine per request for the next iteration.
params.outputGenerationLengths[batchSlot] = numNextDraftTokens;
// Generate new random data for sampling.
// This tensor goes directly to engine and has linear batch index.
params.randDataSample[batchSlot] = static_cast<T>(curand_uniform(params.curandState + batchSlot));
// Increase seqLen by accepted len.
params.sequenceLengths[batchSlot] = curSeqLen + bestPathLength;
// Copy temperature.
params.outputTemperatures[batchSlot] = params.inputTemperatures[batchSlot];
params.outputTemperatures[batchSlot] = __frcp_rn(params.inputTemperatures[batchSlot]);
// Copy best path index.
params.outputBestPathIndices[batchSlot] = bestPathIdx;
}
}
} // namespace
@ -328,12 +364,13 @@ namespace
{
template <typename VecT>
__global__ void copyProbs(uint8_t const* srcData, uint8_t* dstData, SizeType32 const* inputBatchSlots,
SizeType32 const* outputBatchSlots, SizeType32 sizeInBytes)
SizeType32 const* outputBatchSlots, SizeType32 sizeInBytes, SizeType32 inputBatchIdxOffset)
{
auto constexpr VEC_ELTS = static_cast<SizeType32>(sizeof(VecT));
auto const bid = static_cast<SizeType32>(blockIdx.y);
auto const intputBatchSlot = inputBatchSlots ? inputBatchSlots[bid] : bid;
auto const outputBatchSlot = outputBatchSlots ? outputBatchSlots[bid] : bid;
auto const inputBid = static_cast<SizeType32>(blockIdx.y) + inputBatchIdxOffset;
auto const outputBid = static_cast<SizeType32>(blockIdx.y);
auto const intputBatchSlot = inputBatchSlots ? inputBatchSlots[inputBid] : inputBid;
auto const outputBatchSlot = outputBatchSlots ? outputBatchSlots[outputBid] : outputBid;
auto const srcStartIdx = intputBatchSlot * sizeInBytes;
auto const dstStartIdx = outputBatchSlot * sizeInBytes;
auto const tidx = (static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VEC_ELTS;
@ -351,7 +388,8 @@ __global__ void copyProbs(uint8_t const* srcData, uint8_t* dstData, SizeType32 c
} // namespace
void invokeCopyProbs(uint8_t const* srcDataPtr, uint8_t* dstDataPtr, SizeType32 const* inputBatchSlots,
SizeType32 const* outputBatchSlots, SizeType32 batchSize, SizeType32 copyRowSizeInBytes, cudaStream_t stream)
SizeType32 const* outputBatchSlots, SizeType32 batchSize, SizeType32 inputBatchIdxOffset,
SizeType32 copyRowSizeInBytes, cudaStream_t stream)
{
auto copyProbsInvocation = copyProbs<uint8_t>;
if (copyRowSizeInBytes % 16 == 0)
@ -375,7 +413,7 @@ void invokeCopyProbs(uint8_t const* srcDataPtr, uint8_t* dstDataPtr, SizeType32
SizeType32 constexpr BLOCKS_PER_ROW{32};
dim3 const gridSize{BLOCKS_PER_ROW, static_cast<uint32_t>(batchSize)};
copyProbsInvocation<<<gridSize, blockSize, 0, stream>>>(
srcDataPtr, dstDataPtr, inputBatchSlots, outputBatchSlots, copyRowSizeInBytes);
srcDataPtr, dstDataPtr, inputBatchSlots, outputBatchSlots, copyRowSizeInBytes, inputBatchIdxOffset);
}
template <typename T>
@ -386,12 +424,41 @@ void invokeCopyProbs(ExtractExplicitDraftTokensParams<T> const& params, cudaStre
auto const numCopyElems = params.numPaths * (params.maxPathLength - 1) * params.vocabSize;
auto const copyRowSizeInBytes = numCopyElems * sizeof(T);
invokeCopyProbs(srcDataPtr, dstDataPtr, nullptr, params.batchSlots, params.batchSize, copyRowSizeInBytes, stream);
invokeCopyProbs(
srcDataPtr, dstDataPtr, nullptr, params.batchSlots, params.batchSize, 0, copyRowSizeInBytes, stream);
}
template void invokeCopyProbs(ExtractExplicitDraftTokensParams<float> const& params, cudaStream_t stream);
template void invokeCopyProbs(ExtractExplicitDraftTokensParams<half> const& params, cudaStream_t stream);
namespace
{
template <typename T>
__global__ void packGenerationLengths(PackExplicitDraftTokensParams<T> params)
{
auto const batchIdx = static_cast<SizeType32>(blockIdx.x);
auto const batchSlot = params.batchSlots ? params.batchSlots[batchIdx] : batchIdx;
auto const isGenerationRequest = batchIdx >= params.numContextRequests;
auto const genIdx = batchIdx - params.numContextRequests;
if (threadIdx.x == 0 && isGenerationRequest)
{
params.outputGenerationLengths[genIdx] = params.inputGenerationLengths[batchSlot];
}
}
} // namespace
template <typename T>
void invokePackGenerationLengths(PackExplicitDraftTokensParams<T> const& params, cudaStream_t stream)
{
SizeType32 constexpr BLOCK_SIZE = 32;
packGenerationLengths<<<params.batchSize, BLOCK_SIZE, 0, stream>>>(params);
}
template void invokePackGenerationLengths(PackExplicitDraftTokensParams<float> const& params, cudaStream_t stream);
template void invokePackGenerationLengths(PackExplicitDraftTokensParams<half> const& params, cudaStream_t stream);
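
packGenerationLengths (and the guarded blocks in packExplicitDraftTokens below) assume the forward batch is ordered with context requests first and generation requests after them, so a generation-only row index is recovered as batchIdx - numContextRequests. A tiny host-side sketch of that indexing:

#include <cstdio>

int main()
{
    int const numContextRequests = 2;
    int const batchSize = 5; // 2 context requests followed by 3 generation requests

    for (int batchIdx = 0; batchIdx < batchSize; ++batchIdx)
    {
        bool const isGenerationRequest = batchIdx >= numContextRequests;
        if (isGenerationRequest)
        {
            // Generation-only outputs are indexed without the leading context requests.
            int const genIdx = batchIdx - numContextRequests;
            std::printf("batchIdx %d -> genIdx %d\n", batchIdx, genIdx);
        }
        else
        {
            std::printf("batchIdx %d is a context request\n", batchIdx);
        }
    }
    return 0;
}
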
namespace
{
template <typename T>
@ -400,61 +467,79 @@ __global__ void packExplicitDraftTokens(PackExplicitDraftTokensParams<T> params)
auto const batchIdx = static_cast<SizeType32>(blockIdx.x);
auto const batchSlot = params.batchSlots ? params.batchSlots[batchIdx] : batchIdx;
auto const isGenerationRequest = batchIdx >= params.numContextRequests;
auto const genIdx = batchIdx - params.numContextRequests;
if (threadIdx.x == 0)
{
params.outputPositionIdsBase[batchIdx] = params.inputPositionIdsBase[batchSlot];
params.outputGenerationLengths[batchIdx] = params.inputGenerationLengths[batchSlot];
params.outputRandomDataSample[batchIdx] = params.inputRandomDataSample[batchSlot];
params.outputTemperatures[batchIdx] = params.inputTemperatures[batchSlot];
}
// Copy random validation data.
auto const numDecodingDraftTokens = params.numPaths * (params.maxPathLength - 1);
auto outputRandomDataValidation = params.outputRandomDataValidation + batchIdx * numDecodingDraftTokens;
auto inputRandomDataValidation = params.inputRandomDataValidation + batchSlot * numDecodingDraftTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numDecodingDraftTokens;
ti += static_cast<SizeType32>(blockDim.x))
if (isGenerationRequest)
{
outputRandomDataValidation[ti] = inputRandomDataValidation[ti];
auto outputRandomDataValidation = params.outputRandomDataValidation + genIdx * numDecodingDraftTokens;
auto const inputRandomDataValidation = params.inputRandomDataValidation + batchSlot * numDecodingDraftTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numDecodingDraftTokens;
ti += static_cast<SizeType32>(blockDim.x))
{
outputRandomDataValidation[ti] = inputRandomDataValidation[ti];
}
}
// Copy draft tokens and indices
auto const numUnpackedTokens = numDecodingDraftTokens + params.numPaths;
auto outputNextDraftTokens = params.outputNextDraftTokens + batchIdx * numUnpackedTokens;
auto outputNextDraftIndices = params.outputNextDraftIndices + batchIdx * numUnpackedTokens;
auto const inputNextDraftTokens = params.inputNextDraftTokens + batchSlot * numUnpackedTokens;
auto const inputNextDraftIndices = params.inputNextDraftIndices + batchSlot * numUnpackedTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numUnpackedTokens;
ti += static_cast<SizeType32>(blockDim.x))
if (isGenerationRequest)
{
outputNextDraftTokens[ti] = inputNextDraftTokens[ti];
outputNextDraftIndices[ti] = inputNextDraftIndices[ti];
auto const numUnpackedTokens = numDecodingDraftTokens + params.numPaths;
auto outputNextDraftTokens = params.outputNextDraftTokens + genIdx * numUnpackedTokens;
auto outputNextDraftIndices = params.outputNextDraftIndices + genIdx * numUnpackedTokens;
auto const inputNextDraftTokens = params.inputNextDraftTokens + batchSlot * numUnpackedTokens;
auto const inputNextDraftIndices = params.inputNextDraftIndices + batchSlot * numUnpackedTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numUnpackedTokens;
ti += static_cast<SizeType32>(blockDim.x))
{
outputNextDraftTokens[ti] = inputNextDraftTokens[ti];
outputNextDraftIndices[ti] = inputNextDraftIndices[ti];
}
}
auto const maxGenerationLength = params.maxGenerationLength[0];
auto const maxDecodingTokens = numDecodingDraftTokens + 1;
auto const numPackedMasks = divUp(maxGenerationLength, 32);
auto const outputMaskStartId = (batchIdx == 0) ? 0 : params.cumSumGenerationLengths[batchIdx - 1];
auto const numTokens = (batchIdx == 0)
auto const outputMaskStartId = (genIdx == 0) ? 0 : params.cumSumGenerationLengths[genIdx - 1];
auto const numTokens = (genIdx == 0)
? params.cumSumGenerationLengths[0]
: params.cumSumGenerationLengths[batchIdx] - params.cumSumGenerationLengths[batchIdx - 1];
: params.cumSumGenerationLengths[genIdx] - params.cumSumGenerationLengths[genIdx - 1];
// Copy packed masks.
// Masks are placed next to each other with offsets of cumSumGenerationLengths[bi-1]
auto const inputPackedMask = params.inputPackedMask + batchSlot * numPackedMasks * maxDecodingTokens;
auto outputPackedMask = params.outputPackedMask + outputMaskStartId * numPackedMasks;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numTokens * numPackedMasks;
ti += static_cast<SizeType32>(blockDim.x))
if (isGenerationRequest)
{
outputPackedMask[ti] = inputPackedMask[ti];
auto const inputPackedMask = params.inputPackedMask + batchSlot * numPackedMasks * maxDecodingTokens;
auto outputPackedMask = params.outputPackedMask + outputMaskStartId * numPackedMasks;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numTokens * numPackedMasks;
ti += static_cast<SizeType32>(blockDim.x))
{
outputPackedMask[ti] = inputPackedMask[ti];
}
}
// Copy position ids and position offsets. Copy only maxGenerationLength entries.
auto outputPositionOffsets = params.outputPositionOffsets + batchIdx * maxGenerationLength;
auto const inputPositionOffsets = params.inputPositionOffsets + batchSlot * maxDecodingTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < maxGenerationLength;
ti += static_cast<SizeType32>(blockDim.x))
if (isGenerationRequest)
{
outputPositionOffsets[ti] = inputPositionOffsets[ti];
auto const basePosId = params.outputPositionIdsBase[batchIdx];
auto outputPositionOffsets = params.outputPositionOffsets + genIdx * maxGenerationLength;
auto outputPositionIds = params.outputPositionIds + genIdx * maxGenerationLength;
auto const inputPositionIds = params.inputPositionIds + batchSlot * maxDecodingTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < maxGenerationLength;
ti += static_cast<SizeType32>(blockDim.x))
{
auto const posId = inputPositionIds[ti];
outputPositionIds[params.numContextTokens + ti] = posId;
outputPositionOffsets[ti] = posId - basePosId + 1;
}
}
}
} // namespace
@ -477,7 +562,8 @@ void invokeCopyProbs(PackExplicitDraftTokensParams<T> const& params, cudaStream_
auto const numCopyElems = params.numPaths * (params.maxPathLength - 1) * params.vocabSize;
auto const copyRowSizeInBytes = numCopyElems * sizeof(T);
invokeCopyProbs(srcDataPtr, dstDataPtr, params.batchSlots, nullptr, params.batchSize, copyRowSizeInBytes, stream);
invokeCopyProbs(srcDataPtr, dstDataPtr, params.batchSlots, nullptr, params.numGenerationRequests,
params.numContextRequests, copyRowSizeInBytes, stream);
}
template void invokeCopyProbs(PackExplicitDraftTokensParams<float> const& params, cudaStream_t stream);

View File

@ -16,6 +16,7 @@
#pragma once
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/kernels/speculativeDecoding/common.h"
#include "tensorrt_llm/runtime/common.h"
#include <cuda_fp16.h>
@ -25,6 +26,38 @@
namespace tensorrt_llm::kernels::speculative_decoding
{
template <typename T>
struct FillContextExplicitDraftTokensParams
{
//! [maxBatchSize]
T* randDataSample{nullptr};
//! [maxBatchSize]
T* outputTemperatures{nullptr};
//! [maxBatchSize]
float const* inputTemperatures{nullptr};
//! [maxBatchSize]
curandState_t* curandState{nullptr};
//! [forwardBatchSize]
runtime::SizeType32 const* batchSlots{nullptr};
runtime::SizeType32 batchSize{0};
void checkParams() const
{
TLLM_CHECK(randDataSample);
TLLM_CHECK(outputTemperatures);
TLLM_CHECK(inputTemperatures);
TLLM_CHECK(curandState);
TLLM_CHECK(batchSlots);
TLLM_CHECK(batchSize > 0);
}
};
//! @brief Sets the reciprocal temperature and generates random values for sampling.
template <typename T>
void invokeFillContextBuffers(FillContextExplicitDraftTokensParams<T> const& params, cudaStream_t stream);
template <typename T>
struct ExtractExplicitDraftTokensParams
{
@ -43,10 +76,18 @@ struct ExtractExplicitDraftTokensParams
//! [maxBatchSize]
runtime::SizeType32* acceptedLengths{nullptr};
//! [maxBatchSize]
runtime::SizeType32* prevDraftLengths{nullptr};
//! [maxBatchSize]
runtime::SizeType32* nextDraftLengths{nullptr};
//! [maxBatchSize]
runtime::SizeType32* sequenceLengths{nullptr};
//! [maxBatchSize]
runtime::SizeType32* outputGenerationLengths{nullptr};
//! [maxBatchSize]
runtime::SizeType32* outputBestPathIndices{nullptr};
//! [maxBatchSize, maxNumPaths, maxPathLength]
runtime::SizeType32* outputLastDraftIndices{nullptr};
//! [maxBatchSize]
T* randDataSample{nullptr};
//! [maxBatchSize, maxNumPaths, maxPathDraftLength]
T* randDataVerification{nullptr};
@ -58,7 +99,7 @@ struct ExtractExplicitDraftTokensParams
runtime::SizeType32 const* batchSlots{nullptr};
//! [forwardBatchSize, maxNumPaths, maxPathLength]
runtime::TokenIdType const* nextDraftTokens{nullptr};
//! [forwardBatchSize, maxNumPaths, maxPathLength]
//! [forwardBatchSize, maxNumPaths, maxPathLength], optional
runtime::TokenIdType const* lastDraftTokens{nullptr};
//! [forwardBatchSize, maxNumPaths, maxPathLength]
runtime::SizeType32 const* inputUnpackedNextDraftIndices{nullptr};
@ -74,17 +115,75 @@ struct ExtractExplicitDraftTokensParams
runtime::TokenIdType const* nextFlatTokens{nullptr};
//! [forwardBatchSize]
runtime::SizeType32 const* generationLengthInclusiveSum{nullptr};
//! [forwardBatchSize]
runtime::SizeType32 const* lastGenerationLengths{nullptr};
//! [maxBatchSize, maxNumPaths, maxPathLength]
runtime::SizeType32 const* lastDraftIndices{nullptr};
//! [forwardBatchSize, maxNumPaths, maxPathDraftLength, maxVocabSize]
T const* nextDraftProbs{nullptr};
//! [maxBatchSize]
float const* inputTemperatures{nullptr};
//! [maxBatchSize]
curandState_t* curandState{nullptr};
runtime::SizeType32 batchSize;
runtime::SizeType32 numPaths;
runtime::SizeType32 maxPathLength;
runtime::SizeType32 maxSeqLen;
runtime::SizeType32 vocabSize;
runtime::SizeType32 batchSize{0};
runtime::SizeType32 numPaths{0};
runtime::SizeType32 maxPathLength{0};
runtime::SizeType32 maxSeqLen{0};
runtime::SizeType32 vocabSize{0};
runtime::SizeType32 numContextRequests{0};
runtime::SizeType32 numGenerationRequests{0};
void checkParams() const
{
TLLM_CHECK(outputIds);
TLLM_CHECK(outputPositionIdsBase);
TLLM_CHECK(inputPositionIdsBase);
TLLM_CHECK(outputPositionIds);
TLLM_CHECK(packedPositionIds);
TLLM_CHECK(outputTemperatures);
TLLM_CHECK(inputTemperatures);
TLLM_CHECK(outputDraftProbs);
TLLM_CHECK(nextDraftProbs);
TLLM_CHECK(outputNextDraftTokens);
TLLM_CHECK(unpackedNextDraftTokens);
TLLM_CHECK(unpackedNextDraftIndices);
TLLM_CHECK(inputUnpackedNextDraftIndices);
TLLM_CHECK(outputLastDraftIndices);
TLLM_CHECK(bestPathIndices);
TLLM_CHECK(outputBestPathIndices);
TLLM_CHECK(curandState);
TLLM_CHECK(batchSlots);
TLLM_CHECK(nextDraftTokens);
TLLM_CHECK(nextFlatTokens);
TLLM_CHECK(generationLengthInclusiveSum);
TLLM_CHECK(bestPathLengths);
TLLM_CHECK(randDataSample);
TLLM_CHECK(randDataVerification);
TLLM_CHECK(acceptedLengths);
TLLM_CHECK(nextDraftLengths);
TLLM_CHECK(prevDraftLengths);
TLLM_CHECK(sequenceLengths);
TLLM_CHECK(outputGenerationLengths);
TLLM_CHECK(batchSize > 0);
TLLM_CHECK(numPaths > 0);
TLLM_CHECK(maxPathLength > 0);
TLLM_CHECK(maxSeqLen > 0);
TLLM_CHECK(vocabSize > 0);
TLLM_CHECK(numContextRequests >= 0);
TLLM_CHECK(numGenerationRequests >= 0);
TLLM_CHECK(numContextRequests + numGenerationRequests != 0);
}
};
//! @brief Modifies `outputIds` and `sequenceLengths` according to the accepted tokens
@ -146,10 +245,12 @@ struct PackExplicitDraftTokensParams
//! [forwardBatchSize, maxGenerationLength, divUp(maxGenerationLength, 32)]
int32_t const* inputPackedMask{nullptr};
//! [forwardBatchSize, maxGenerationLength]
runtime::SizeType32* outputPositionIds{nullptr};
//! [forwardBatchSize, maxGenerationLength]
runtime::SizeType32* outputPositionOffsets{nullptr};
//! [maxBatchSize, maxGenerationLength]
runtime::SizeType32 const* inputPositionOffsets{nullptr};
runtime::SizeType32 const* inputPositionIds{nullptr};
//! [forwardBatchSize, maxNumPaths, maxPathDraftLength, maxVocabSize]
T* outputDraftProbs{nullptr};
@ -161,12 +262,59 @@ struct PackExplicitDraftTokensParams
//! [maxBatchSize]
T const* inputTemperatures{nullptr};
runtime::SizeType32 batchSize;
runtime::SizeType32 numPaths;
runtime::SizeType32 maxPathLength;
runtime::SizeType32 vocabSize;
runtime::SizeType32 batchSize{0};
runtime::SizeType32 numPaths{0};
runtime::SizeType32 maxPathLength{0};
runtime::SizeType32 vocabSize{0};
runtime::SizeType32 numContextTokens{0};
runtime::SizeType32 numContextRequests{0};
runtime::SizeType32 numGenerationRequests{0};
void checkParams() const
{
TLLM_CHECK(batchSlots);
TLLM_CHECK(cumSumGenerationLengths);
TLLM_CHECK(maxGenerationLength);
TLLM_CHECK(inputPositionIdsBase);
TLLM_CHECK(inputGenerationLengths);
TLLM_CHECK(outputRandomDataSample);
TLLM_CHECK(inputRandomDataSample);
TLLM_CHECK(inputRandomDataValidation);
TLLM_CHECK(inputNextDraftTokens);
TLLM_CHECK(inputNextDraftIndices);
TLLM_CHECK(inputPackedMask);
TLLM_CHECK(inputPositionIds);
TLLM_CHECK(inputDraftProbs);
TLLM_CHECK(outputTemperatures);
TLLM_CHECK(inputTemperatures);
TLLM_CHECK(batchSize > 0);
TLLM_CHECK(numPaths > 0);
TLLM_CHECK(maxPathLength > 0);
TLLM_CHECK(vocabSize > 0);
TLLM_CHECK(numContextRequests >= 0);
TLLM_CHECK(numGenerationRequests >= 0);
TLLM_CHECK(
(numContextTokens == 0 && numContextRequests == 0) || (numContextTokens > 0 && numContextRequests > 0));
TLLM_CHECK(numContextRequests + numGenerationRequests != 0);
}
};
//! @brief Copy all rows at `batchSlots[batchIdx]` from the `inputGenerationLengths` tensor to `batchIdx` rows of the
//! `outputGenerationLengths` tensor.
template <typename T>
void invokePackGenerationLengths(PackExplicitDraftTokensParams<T> const& params, cudaStream_t stream);
//! @brief Copy all rows at `batchSlots[batchIdx]` from `input*` tensors to `batchIdx` rows at `output*` tensor.
template <typename T>
void invokePackExplicitDraftTokens(PackExplicitDraftTokensParams<T> const& params, cudaStream_t stream);
@ -175,27 +323,24 @@ void invokePackExplicitDraftTokens(PackExplicitDraftTokensParams<T> const& param
template <typename T>
void invokeCopyProbs(PackExplicitDraftTokensParams<T> const& params, cudaStream_t stream);
size_t invokeScanSpecDecodingGenerationLengths(void* __restrict__ reduceMaxTempStorage, size_t reduceTempStorageBytes,
runtime::SizeType32 const* __restrict__ specDecodingGenerationLengths,
runtime::SizeType32* __restrict__ scannedSpecDecodingGenerationLengths, runtime::SizeType32 batchSize,
cudaStream_t stream);
size_t invokeReduceMaxSpecDecodingGenerationLengths(void* __restrict__ reduceMaxTempStorage,
size_t reduceTempStorageBytes, runtime::SizeType32 const* __restrict__ specDecodingGenerationLengths,
runtime::SizeType32* __restrict__ scannedSpecDecodingGenerationLengths, runtime::SizeType32 batchSize,
cudaStream_t stream);
size_t invokeScanGenerationLengths(void* __restrict__ scanTempStorage, size_t scanTempStorageBytes,
runtime::SizeType32 const* __restrict__ generationLengths,
runtime::SizeType32* __restrict__ scannedGenerationLengths, runtime::SizeType32 batchSize, cudaStream_t stream);
size_t invokeReduceMaxGenerationLengths(void* __restrict__ reduceMaxTempStorage, size_t reduceTempStorageBytes,
runtime::SizeType32 const* __restrict__ generationLengths, runtime::SizeType32* __restrict__ maxGenerationLengths,
runtime::SizeType32 batchSize, cudaStream_t stream);
// inclusive prefix sum specDecodingGenerationLengths
void invokeScanReduceSpecDecodingGenerationLengths(runtime::SizeType32 batchSize,
runtime::SizeType32 const* __restrict__ specDecodingGenerationLengths, void* __restrict__ scanTempStorage,
size_t scanTempStorageBytes, runtime::SizeType32* __restrict__ scanedSpecDecodingGenerationLengths,
// inclusive prefix sum generationLengths
void invokeScanReduceGenerationLengths(runtime::SizeType32 batchSize,
runtime::SizeType32 const* __restrict__ generationLengths, void* __restrict__ scanTempStorage,
size_t scanTempStorageBytes, runtime::SizeType32* __restrict__ scanedGenerationLengths,
void* __restrict__ reduceMaxTempStorage, size_t reduceMaxTempStorageBytes,
runtime::SizeType32* maxSpecDecodingGenerationLengths, cudaStream_t stream);
runtime::SizeType32* maxGenerationLengths, cudaStream_t stream);
void invokeConvertSpecDecodingMaskToPackedMask(runtime::SizeType32 batchSize,
runtime::SizeType32 const* __restrict__ specDecodingCumGenerationLengths,
runtime::SizeType32 const* __restrict__ specDecodingMaxGenerationLengths, bool const* __restrict__ specDecodingMask,
void invokeConvertMaskToPackedMask(runtime::SizeType32 batchSize,
runtime::SizeType32 const* __restrict__ cumGenerationLengths,
runtime::SizeType32 const* __restrict__ maxGenerationLengths, bool const* __restrict__ mask,
runtime::SizeType32 const* __restrict__ batchSlots, runtime::SizeType32 maxDraftTokens,
runtime::SizeType32 maxGenerationLength, runtime::SizeType32* __restrict__ specDecodingPackedMask,
cudaStream_t stream);
runtime::SizeType32 maxGenerationLength, runtime::SizeType32* __restrict__ packedMask, cudaStream_t stream);
} // namespace tensorrt_llm::kernels::speculative_decoding

View File

@ -25,99 +25,107 @@
namespace tensorrt_llm::kernels::speculative_decoding
{
static constexpr int kUpdateKVCacheKernelShmSize = 16384;
using namespace tensorrt_llm::runtime;
static constexpr SizeType32 kUpdateKVCacheKernelShmSize = 16384;
namespace
{
template <typename KVCacheBuffer, int MaxLayerCount, typename MoveEltType>
template <typename KVCacheBuffer, SizeType32 MaxLayerCount, typename MoveEltType>
__global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheBuffer, MaxLayerCount> kvCacheBuffers,
int const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices,
int32_t const* pastKeyValueLengths, int rewindDraftTokenCommonCount, int const* rewindDraftTokenSeparateAdjustments,
int const* seqSlotRemapping, int eltCountPerHead)
SizeType32 const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices,
SizeType32 const* pastKeyValueLengths, SizeType32 rewindDraftTokenCommonCount,
SizeType32 const* rewindDraftTokenSeparateAdjustments, SizeType32 const* seqSlotRemapping,
SizeType32 const* batchSlots, SizeType32 eltCountPerHead)
{
int seqIdx = blockIdx.x;
int headIdx = blockIdx.y;
int layerIdx = blockIdx.z;
int warpIdx = threadIdx.x / 32;
int warpCount = blockDim.x / 32;
int laneIdx = threadIdx.x & 0x1f;
int seqDraftTokenStart = seqAcceptedDraftTokenOffsets[seqIdx];
int seqDraftTokenEnd = seqAcceptedDraftTokenOffsets[seqIdx + 1];
auto const seqIdx = static_cast<SizeType32>(blockIdx.x);
auto const headIdx = static_cast<SizeType32>(blockIdx.y);
auto const layerIdx = static_cast<SizeType32>(blockIdx.z);
auto const warpIdx = static_cast<SizeType32>(threadIdx.x / 32);
auto const warpCount = static_cast<SizeType32>(blockDim.x / 32);
auto const laneIdx = static_cast<SizeType32>(threadIdx.x & 0x1f);
auto const seqDraftTokenStart = seqAcceptedDraftTokenOffsets[seqIdx];
auto const seqDraftTokenEnd = seqAcceptedDraftTokenOffsets[seqIdx + 1];
auto const seqSlot = seqSlotRemapping == nullptr ? seqIdx : seqSlotRemapping[seqIdx];
int seqDraftCount = seqDraftTokenEnd - seqDraftTokenStart;
int maxEltCountPerMove = kUpdateKVCacheKernelShmSize / sizeof(MoveEltType) / seqDraftCount;
int eltCountPerMove = min(maxEltCountPerMove, eltCountPerHead);
auto const seqDraftCount = seqDraftTokenEnd - seqDraftTokenStart;
auto const maxEltCountPerMove
= static_cast<SizeType32>(kUpdateKVCacheKernelShmSize / sizeof(MoveEltType) / seqDraftCount);
auto const eltCountPerMove = min(maxEltCountPerMove, eltCountPerHead);
if (seqDraftCount == 0 || eltCountPerMove == 0)
{
return;
}
KVCacheBuffer& kvCacheBuffer = kvCacheBuffers[layerIdx];
int tokenStartIdx = pastKeyValueLengths[seqSlot] - rewindDraftTokenCommonCount;
auto tokenStartIdx = pastKeyValueLengths[seqSlot] - rewindDraftTokenCommonCount;
if (rewindDraftTokenSeparateAdjustments != nullptr)
{
tokenStartIdx -= rewindDraftTokenSeparateAdjustments[seqSlot];
auto const batchSlot = batchSlots == nullptr ? seqIdx : batchSlots[seqIdx];
tokenStartIdx -= rewindDraftTokenSeparateAdjustments[batchSlot];
}
__shared__ char loadSmemBuffer[kUpdateKVCacheKernelShmSize];
auto* eltLoadSmemBuffer = reinterpret_cast<MoveEltType*>(&loadSmemBuffer[0]);
for (int startChannelOffset = 0; startChannelOffset < eltCountPerHead; startChannelOffset += eltCountPerMove)
for (SizeType32 startChannelOffset = 0; startChannelOffset < eltCountPerHead; startChannelOffset += eltCountPerMove)
{
int eltCountCurrentMove = min(eltCountPerMove, eltCountPerHead - startChannelOffset);
SizeType32 eltCountCurrentMove = min(eltCountPerMove, eltCountPerHead - startChannelOffset);
// load K
for (int tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
for (SizeType32 tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
{
int tokenPos = packedAcceptedDraftTokensIndices[seqDraftTokenStart + tokenIdx];
auto const tokenPos = packedAcceptedDraftTokensIndices[seqDraftTokenStart + tokenIdx];
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenIdx * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto const tokenKVPosition = tokenStartIdx + tokenPos;
auto* kPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getKBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
for (SizeType32 loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
int kvLocationIdx = kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
auto const channelIdx = loadChannelIdx + startChannelOffset;
auto const kvLocationIdx
= kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
tokenSmemBuffer[loadChannelIdx] = kPtr[kvLocationIdx];
}
}
__syncthreads();
// store K
for (int tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
for (SizeType32 tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
{
int tokenPos = tokenIdx;
auto const tokenPos = tokenIdx;
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenIdx * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto const tokenKVPosition = tokenStartIdx + tokenPos;
auto* kPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getKBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
for (SizeType32 loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
int kvLocationIdx = kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
auto const channelIdx = loadChannelIdx + startChannelOffset;
auto const kvLocationIdx
= kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
kPtr[kvLocationIdx] = tokenSmemBuffer[loadChannelIdx];
}
}
__syncthreads();
// load V
for (int tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
for (SizeType32 tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
{
int tokenPos = packedAcceptedDraftTokensIndices[seqDraftTokenStart + tokenIdx];
auto const tokenPos = packedAcceptedDraftTokensIndices[seqDraftTokenStart + tokenIdx];
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenIdx * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto const tokenKVPosition = tokenStartIdx + tokenPos;
auto* vPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getVBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
for (SizeType32 loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
int kvLocationIdx = kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
auto const channelIdx = loadChannelIdx + startChannelOffset;
auto const kvLocationIdx
= kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
tokenSmemBuffer[loadChannelIdx] = vPtr[kvLocationIdx];
}
}
__syncthreads();
// store V
for (int tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
for (SizeType32 tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
{
int tokenPos = tokenIdx;
auto const tokenPos = tokenIdx;
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenPos * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto const tokenKVPosition = tokenStartIdx + tokenPos;
auto* vPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getVBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
for (SizeType32 loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
int kvLocationIdx = kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
auto const channelIdx = loadChannelIdx + startChannelOffset;
auto const kvLocationIdx
= kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
vPtr[kvLocationIdx] = tokenSmemBuffer[loadChannelIdx];
}
}
@ -126,12 +134,13 @@ __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheB
}
} // namespace
template <typename KVCacheBuffer, int MaxLayerCount>
template <typename KVCacheBuffer, SizeType32 MaxLayerCount>
void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
int const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices,
int32_t const* pastKeyValueLengths, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping,
cudaStream_t stream)
SizeType32 const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices,
SizeType32 const* pastKeyValueLengths, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32 rewindDraftTokenCommonCount,
SizeType32 const* rewindDraftTokenSeparateAdjustments, SizeType32 const* seqSlotRemapping,
SizeType32 const* batchSlots, cudaStream_t stream)
{
// make sure the kernel launch parameter buffer is large enough
static_assert(MaxLayerCount * sizeof(KVCacheBuffer) <= 3072);
@ -139,22 +148,22 @@ void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
{
return;
}
int alignedBytes = 16;
SizeType32 alignedBytes = 16;
while (alignedBytes > 0 && (sizeInBytesPerKVHead % alignedBytes != 0))
{
alignedBytes >>= 1;
}
TLLM_CHECK_WITH_INFO(alignedBytes > 0, "alignedBytes should be positive");
int eltCountPerHead = sizeInBytesPerKVHead / alignedBytes;
SizeType32 eltCountPerHead = sizeInBytesPerKVHead / alignedBytes;
dim3 grid(seqCount, numKVHeads, layerCount);
dim3 block(128, 1, 1);
std::array<KVCacheBuffer, MaxLayerCount> kvCacheBufferArray;
for (int i = 0; i < layerCount; i++)
for (SizeType32 i = 0; i < layerCount; i++)
{
kvCacheBufferArray[i] = kvCacheBuffers[i];
}
void (*pKernelFunc)(std::array<KVCacheBuffer, MaxLayerCount>, int const*, IndexType const*, int32_t const*, int,
int const*, int const*, int)
void (*pKernelFunc)(std::array<KVCacheBuffer, MaxLayerCount>, SizeType32 const*, IndexType const*,
SizeType32 const*, SizeType32, SizeType32 const*, SizeType32 const*, SizeType32 const*, SizeType32)
= nullptr;
switch (alignedBytes)
{
@ -170,7 +179,7 @@ void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
}
case 4:
{
pKernelFunc = &updateKVCacheDraftTokenLocationBatchedKernel<KVCacheBuffer, MaxLayerCount, int32_t>;
pKernelFunc = &updateKVCacheDraftTokenLocationBatchedKernel<KVCacheBuffer, MaxLayerCount, SizeType32>;
break;
}
case 2:
@ -187,7 +196,7 @@ void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
}
pKernelFunc<<<grid, block, 0, stream>>>(kvCacheBufferArray, seqAcceptedDraftTokenOffsets,
packedAcceptedDraftTokensIndices, pastKeyValueLengths, rewindDraftTokenCommonCount,
rewindDraftTokenSeparateAdjustments, seqSlotRemapping, eltCountPerHead);
rewindDraftTokenSeparateAdjustments, seqSlotRemapping, batchSlots, eltCountPerHead);
TLLM_CUDA_CHECK(cudaGetLastError());
}
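
The launcher picks the widest power-of-two move width (16, 8, 4, 2 or 1 bytes) that evenly divides sizeInBytesPerKVHead and dispatches the kernel templated on that element type. A host-side sketch of the selection loop:

#include <cstdio>

// Returns the widest power-of-two chunk size (<= 16 bytes) that evenly divides the
// per-head byte count; mirrors the alignedBytes loop in updateKVCacheDraftTokenLocationBatched.
int selectMoveWidth(int sizeInBytesPerKVHead)
{
    int alignedBytes = 16;
    while (alignedBytes > 0 && (sizeInBytesPerKVHead % alignedBytes != 0))
    {
        alignedBytes >>= 1;
    }
    return alignedBytes;
}

int main()
{
    // 128 half-precision channels per head -> 256 bytes -> 16-byte (int4-like) moves.
    std::printf("256 bytes per head -> %d-byte moves\n", selectMoveWidth(256));
    // 24 bytes is divisible by 8 but not 16 -> 8-byte moves.
    std::printf("24 bytes per head  -> %d-byte moves\n", selectMoveWidth(24));
    // An odd byte count falls back to byte-wise moves.
    std::printf("17 bytes per head  -> %d-byte moves\n", selectMoveWidth(17));
    return 0;
}
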
@ -209,54 +218,59 @@ void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
* @param stream : CUDA stream to use.
*/
template <typename KVCacheBuffer>
void updateKVCacheDraftTokenLocation(KVCacheBuffer const* kvCacheBuffers, int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int layerCount, int seqCount,
int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments,
int const* seqSlotRemapping, cudaStream_t stream)
void updateKVCacheDraftTokenLocation(KVCacheBuffer const* kvCacheBuffers,
SizeType32 const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices,
SizeType32 const* pastKeyValueLengths, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32 rewindDraftTokenCommonCount,
SizeType32 const* rewindDraftTokenSeparateAdjustments, SizeType32 const* seqSlotRemapping,
SizeType32 const* batchSlots, cudaStream_t stream)
{
int startLayer = 0;
static constexpr int kMaxLayersPerIter = 32;
SizeType32 startLayer = 0;
static constexpr SizeType32 kMaxLayersPerIter = 32;
while (startLayer < layerCount)
{
int microBatchLayerCount = std::min(layerCount - startLayer, kMaxLayersPerIter);
SizeType32 microBatchLayerCount = std::min(layerCount - startLayer, kMaxLayersPerIter);
updateKVCacheDraftTokenLocationBatched<KVCacheBuffer, kMaxLayersPerIter>(kvCacheBuffers + startLayer,
seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices, pastKeyValueLengths, microBatchLayerCount,
seqCount, numKVHeads, sizeInBytesPerKVHead, rewindDraftTokenCommonCount,
rewindDraftTokenSeparateAdjustments, seqSlotRemapping, stream);
rewindDraftTokenSeparateAdjustments, seqSlotRemapping, batchSlots, stream);
startLayer += microBatchLayerCount;
}
}
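
updateKVCacheDraftTokenLocation walks the layers in chunks of at most kMaxLayersPerIter (32), apparently to keep the by-value std::array of KV-cache descriptors within the kernel launch parameter budget guarded by the static_assert above. A minimal sketch of the chunking pattern:

#include <algorithm>
#include <cstdio>

int main()
{
    int const layerCount = 80;            // e.g. a large model
    int constexpr kMaxLayersPerIter = 32; // same chunk size as the kernel launcher

    // Process the layers in fixed-size chunks, the same way the launcher
    // slices them into per-launch micro-batches.
    for (int startLayer = 0; startLayer < layerCount; startLayer += kMaxLayersPerIter)
    {
        int const microBatchLayerCount = std::min(layerCount - startLayer, kMaxLayersPerIter);
        std::printf("launch for layers [%d, %d)\n", startLayer, startLayer + microBatchLayerCount);
    }
    return 0;
}
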
void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping,
int maxKVCacheLen, cudaStream_t stream)
void updateLinearKVCacheDraftTokenLocation(SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, SizeType32 const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32 rewindDraftTokenCommonCount,
SizeType32 const* rewindDraftTokenSeparateAdjustments, SizeType32 const* seqSlotRemapping, SizeType32 maxKVCacheLen,
cudaStream_t stream)
{
std::vector<KVLinearBuffer> kvLinearBuffers;
kvLinearBuffers.reserve(layerCount);
auto const sizePerToken = numKVHeads * sizeInBytesPerKVHead;
for (int i = 0; i < layerCount; i++)
for (SizeType32 i = 0; i < layerCount; i++)
{
kvLinearBuffers.emplace_back(
seqCount, maxKVCacheLen, sizePerToken, maxKVCacheLen, 0, false, pastKeyValueList[i]);
}
updateKVCacheDraftTokenLocation(kvLinearBuffers.data(), seqAcceptedDraftTokenOffsets,
packedAcceptedDraftTokensIndices, pastKeyValueLengths, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, seqSlotRemapping, stream);
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, seqSlotRemapping, nullptr, stream);
}
void updateKVBlockArrayDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
void updateKVBlockArrayDraftTokenLocation(SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, SizeType32 const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32 rewindDraftTokenCommonCount,
SizeType32 const* rewindDraftTokenSeparateAdjustments, SizeType32 const* seqSlotRemapping,
SizeType32 const* batchSlots, SizeType32 maxKVCacheLen, SizeType32 maxBlocksPerSeq, SizeType32 tokensPerBlock,
cudaStream_t stream)
{
std::vector<KVBlockArray> kvBlockArrays;
kvBlockArrays.reserve(layerCount);
auto const bytesPerToken = numKVHeads * sizeInBytesPerKVHead;
auto const bytesPerBlock = tokensPerBlock * bytesPerToken;
for (int layerIdx = 0; layerIdx < layerCount; layerIdx++)
for (SizeType32 layerIdx = 0; layerIdx < layerCount; layerIdx++)
{
auto const layerOffset = layerIdx * 2 * bytesPerBlock;
auto* const primaryPoolPointer
@ -269,49 +283,52 @@ void updateKVBlockArrayDraftTokenLocation(int const* seqAcceptedDraftTokenOffset
}
updateKVCacheDraftTokenLocation(kvBlockArrays.data(), seqAcceptedDraftTokenOffsets,
packedAcceptedDraftTokensIndices, pastKeyValueLengths, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, seqSlotRemapping, stream);
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, seqSlotRemapping, batchSlots, stream);
}
void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCount, int const* seqSlotRemapping, int maxKVCacheLen, cudaStream_t stream)
void updateLinearKVCacheDraftTokenLocationCommonRewind(SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, SizeType32 const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32 rewindDraftTokenCount, SizeType32 const* seqSlotRemapping,
SizeType32 maxKVCacheLen, cudaStream_t stream)
{
updateLinearKVCacheDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pastKeyValueList, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCount, nullptr, seqSlotRemapping, maxKVCacheLen, stream);
}
void updateKVBlockArrayDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCount, int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock,
cudaStream_t stream)
void updateKVBlockArrayDraftTokenLocationCommonRewind(SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, SizeType32 const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32 rewindDraftTokenCount, SizeType32 const* seqSlotRemapping,
SizeType32 maxKVCacheLen, SizeType32 maxBlocksPerSeq, SizeType32 tokensPerBlock, cudaStream_t stream)
{
updateKVBlockArrayDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pointerArray, offsetArray, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCount, nullptr, seqSlotRemapping, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
rewindDraftTokenCount, nullptr, seqSlotRemapping, nullptr, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock,
stream);
}
void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int* rewindDraftTokenCounts, int const* seqSlotRemapping, int maxKVCacheLen, cudaStream_t stream)
void updateLinearKVCacheDraftTokenLocationSeparateRewind(SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, SizeType32 const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32* rewindDraftTokenCounts, SizeType32 const* seqSlotRemapping,
SizeType32 maxKVCacheLen, cudaStream_t stream)
{
updateLinearKVCacheDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pastKeyValueList, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead, 0,
rewindDraftTokenCounts, seqSlotRemapping, maxKVCacheLen, stream);
}
void updateKVBlockArrayDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int* rewindDraftTokenCounts, int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq,
int tokensPerBlock, cudaStream_t stream)
void updateKVBlockArrayDraftTokenLocationSeparateRewind(SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, SizeType32 const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32 const* rewindDraftTokenCounts, SizeType32 const* seqSlotRemapping,
SizeType32 maxKVCacheLen, SizeType32 maxBlocksPerSeq, SizeType32 tokensPerBlock, cudaStream_t stream)
{
updateKVBlockArrayDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pointerArray, offsetArray, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead, 0,
rewindDraftTokenCounts, seqSlotRemapping, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
rewindDraftTokenCounts, seqSlotRemapping, nullptr, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
}
} // namespace tensorrt_llm::kernels::speculative_decoding
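A hypothetical call of the SizeType32-based common-rewind entry point, matching the definition above; every buffer and scalar below is a placeholder prepared by the caller, not a library-provided name:

// Hypothetical call site; all arguments are placeholders allocated/filled elsewhere.
tensorrt_llm::kernels::speculative_decoding::updateKVBlockArrayDraftTokenLocationCommonRewind(
    seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices, pastKeyValueLengths,
    pointerArray, offsetArray, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
    rewindDraftTokenCount, seqSlotRemapping, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock,
    stream);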
View File
@ -17,6 +17,7 @@
#pragma once
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/runtime/common.h"
#include <cstdint>
#include <cuda_runtime_api.h>
@ -44,11 +45,11 @@ using IndexType = int;
* @param maxKVCacheLen : Maximum length of each KV cache
* @param stream : CUDA stream to use.
*/
void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
KVLinearBuffer::DataType* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads,
int sizeInBytesPerKVHead, int rewindDraftTokenCount, int const* seqSlotRemapping, int maxKVCacheLen,
cudaStream_t stream);
void updateLinearKVCacheDraftTokenLocationCommonRewind(runtime::SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, runtime::SizeType32 const* pastKeyValueLengths,
KVLinearBuffer::DataType* const* pastKeyValueList, runtime::SizeType32 layerCount, runtime::SizeType32 seqCount,
runtime::SizeType32 numKVHeads, runtime::SizeType32 sizeInBytesPerKVHead, runtime::SizeType32 rewindDraftTokenCount,
runtime::SizeType32 const* seqSlotRemapping, runtime::SizeType32 maxKVCacheLen, cudaStream_t stream);
/*!
* Update Block KV cache using common rewind count.
@ -72,10 +73,12 @@ void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDra
* @param tokensPerBlock : Tokens per block of Block KV cache
* @param stream : CUDA stream to use.
*/
void updateKVBlockArrayDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCount, int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock,
void updateKVBlockArrayDraftTokenLocationCommonRewind(runtime::SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, runtime::SizeType32 const* pastKeyValueLengths,
void* const* pointerArray, KVBlockArray::DataType* offsetArray, runtime::SizeType32 layerCount,
runtime::SizeType32 seqCount, runtime::SizeType32 numKVHeads, runtime::SizeType32 sizeInBytesPerKVHead,
runtime::SizeType32 rewindDraftTokenCount, runtime::SizeType32 const* seqSlotRemapping,
runtime::SizeType32 maxKVCacheLen, runtime::SizeType32 maxBlocksPerSeq, runtime::SizeType32 tokensPerBlock,
cudaStream_t stream);
/*!
@ -98,11 +101,12 @@ void updateKVBlockArrayDraftTokenLocationCommonRewind(int const* seqAcceptedDraf
* @param maxKVCacheLen : Maximum length of each KV cache
* @param stream : CUDA stream to use.
*/
void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
KVLinearBuffer::DataType* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads,
int sizeInBytesPerKVHead, int* rewindDraftTokenCounts, int const* seqSlotRemapping, int maxKVCacheLen,
cudaStream_t stream);
void updateLinearKVCacheDraftTokenLocationSeparateRewind(runtime::SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, runtime::SizeType32 const* pastKeyValueLengths,
KVLinearBuffer::DataType* const* pastKeyValueList, runtime::SizeType32 layerCount, runtime::SizeType32 seqCount,
runtime::SizeType32 numKVHeads, runtime::SizeType32 sizeInBytesPerKVHead,
runtime::SizeType32* rewindDraftTokenCounts, runtime::SizeType32 const* seqSlotRemapping,
runtime::SizeType32 maxKVCacheLen, cudaStream_t stream);
/*!
* Update Block KV cache using separate rewind count for each sequence.
@ -127,11 +131,13 @@ void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedD
* @param tokensPerBlock : Tokens per block of Block KV cache
* @param stream : CUDA stream to use.
*/
void updateKVBlockArrayDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int* rewindDraftTokenCounts, int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq,
int tokensPerBlock, cudaStream_t stream);
void updateKVBlockArrayDraftTokenLocationSeparateRewind(runtime::SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, runtime::SizeType32 const* pastKeyValueLengths,
void* const* pointerArray, KVBlockArray::DataType* offsetArray, runtime::SizeType32 layerCount,
runtime::SizeType32 seqCount, runtime::SizeType32 numKVHeads, runtime::SizeType32 sizeInBytesPerKVHead,
runtime::SizeType32* rewindDraftTokenCounts, runtime::SizeType32 const* seqSlotRemapping,
runtime::SizeType32 maxKVCacheLen, runtime::SizeType32 maxBlocksPerSeq, runtime::SizeType32 tokensPerBlock,
cudaStream_t stream);
/*!
* Update Linear KV cache using both common rewind and separate rewind count for each sequence. The common
@ -156,11 +162,12 @@ void updateKVBlockArrayDraftTokenLocationSeparateRewind(int const* seqAcceptedDr
* @param maxKVCacheLen : Maximum length of each KV cache
* @param stream : CUDA stream to use.
*/
void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
KVLinearBuffer::DataType* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads,
int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments,
int const* seqSlotRemapping, int maxKVCacheLen, cudaStream_t stream);
void updateLinearKVCacheDraftTokenLocation(runtime::SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, runtime::SizeType32 const* pastKeyValueLengths,
KVLinearBuffer::DataType* const* pastKeyValueList, runtime::SizeType32 layerCount, runtime::SizeType32 seqCount,
runtime::SizeType32 numKVHeads, runtime::SizeType32 sizeInBytesPerKVHead,
runtime::SizeType32 rewindDraftTokenCommonCount, runtime::SizeType32 const* rewindDraftTokenSeparateAdjustments,
runtime::SizeType32 const* seqSlotRemapping, runtime::SizeType32 maxKVCacheLen, cudaStream_t stream);
/*!
* Update Block KV cache using both common rewind and separate rewind count for each sequence. The common
@ -178,20 +185,24 @@ void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffse
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCommonCount : Common token count to rewind
* @param rewindDraftTokenSeparateAdjustments : Pointer to an array of length seqCount, each element indicates the
* rewind adjustment for one sequence.
* rewind adjustment for one sequence, indexed through batchSlots.
* @param seqSlotRemapping : mapping from batch index to index of the seqSlot in the sorted seqSlot buffer
* e.g. for requests [0, 1, 2] with seqSlots [5, 3, 4], seqSlotRemapping is [1, 2, 0]
* Required to match seqAcceptedDraftTokenOffsets and packedAcceptedDraftTokensIndices from gptDecoderBatch
* and pointerArray and pastKeyValueLengths from runtimeBuffers.
* @param batchSlots : [seqCount] indices of sequences in the seq slots.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param maxBlocksPerSeq : Maximum blocks per sequence of Block KV cache.
* @param tokensPerBlock : Tokens per block of Block KV cache
* @param stream : CUDA stream to use.
*/
void updateKVBlockArrayDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream);
void updateKVBlockArrayDraftTokenLocation(runtime::SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, runtime::SizeType32 const* pastKeyValueLengths,
void* const* pointerArray, KVBlockArray::DataType* offsetArray, runtime::SizeType32 layerCount,
runtime::SizeType32 seqCount, runtime::SizeType32 numKVHeads, runtime::SizeType32 sizeInBytesPerKVHead,
runtime::SizeType32 rewindDraftTokenCommonCount, runtime::SizeType32 const* rewindDraftTokenSeparateAdjustments,
runtime::SizeType32 const* seqSlotRemapping, runtime::SizeType32 const* batchSlots,
runtime::SizeType32 maxKVCacheLen, runtime::SizeType32 maxBlocksPerSeq, runtime::SizeType32 tokensPerBlock,
cudaStream_t stream);
} // namespace tensorrt_llm::kernels::speculative_decoding
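The seqSlotRemapping example in the comment above ([5, 3, 4] -> [1, 2, 0]) is an argsort over the seqSlots buffer. A minimal host-side sketch, not part of the library, that reproduces that example:

#include <algorithm>
#include <numeric>
#include <vector>

// Sketch only: builds a remapping that reproduces the documented example,
// i.e. for seqSlots = {5, 3, 4} it returns {1, 2, 0}.
std::vector<int> buildSeqSlotRemapping(std::vector<int> const& seqSlots)
{
    std::vector<int> remapping(seqSlots.size());
    std::iota(remapping.begin(), remapping.end(), 0); // 0, 1, 2, ...
    std::sort(remapping.begin(), remapping.end(),
        [&seqSlots](int a, int b) { return seqSlots[a] < seqSlots[b]; });
    return remapping; // argsort of seqSlots
}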
View File
@ -131,7 +131,8 @@ void invokeStopWordsCriterion(TokenIdType const** outputIds, SizeType32 const**
}
__global__ void lengthCriterion(FinishedState* finished, SizeType32* finishedSum, SizeType32 const* sequenceLimitLength,
SizeType32* sequenceLengths, SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 beamWidth)
SizeType32* sequenceLengths, SizeType32* numNewTokens, SizeType32 const* batchSlots, SizeType32 batchSize,
SizeType32 beamWidth)
{
SizeType32 threadFinishedCount = 0;
auto const batchIdx = blockIdx.x;
@ -144,10 +145,15 @@ __global__ void lengthCriterion(FinishedState* finished, SizeType32* finishedSum
auto finishState = finished[batchSlotBeamWidthIdx];
if (sequenceLengths[batchSlotBeamWidthIdx] >= sequenceLimitLength[batchSlot])
auto const numTokensToLimit = sequenceLimitLength[batchSlot] - sequenceLengths[batchSlotBeamWidthIdx];
if (numTokensToLimit <= 0)
{
finishState.setFinishedMaxLength();
sequenceLengths[batchSlotBeamWidthIdx] = sequenceLimitLength[batchSlot];
if (numNewTokens)
{
numNewTokens[batchSlot] = numNewTokens[batchSlot] + numTokensToLimit;
}
}
threadFinishedCount += finishState.isFinished() ? 1 : 0;
finished[batchSlotBeamWidthIdx] = finishState;
@ -174,8 +180,8 @@ __global__ void lengthCriterion(FinishedState* finished, SizeType32* finishedSum
}
void invokeLengthCriterion(FinishedState* finished, SizeType32* finishedSum, SizeType32 const* sequenceLimitLength,
SizeType32* sequenceLengths, SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 beamWidth,
cudaStream_t stream)
SizeType32* sequenceLengths, SizeType32* numNewTokens, SizeType32 const* batchSlots, SizeType32 batchSize,
SizeType32 beamWidth, cudaStream_t stream)
{
// Check if we have attained the sequence length limit. If so, stop the
// sequence. In addition, check if all sequences are stopped and return the
@ -184,12 +190,12 @@ void invokeLengthCriterion(FinishedState* finished, SizeType32* finishedSum, Siz
dim3 grid{static_cast<uint32_t>(batchSize)};
lengthCriterion<<<grid, block, 0, stream>>>(
finished, finishedSum, sequenceLimitLength, sequenceLengths, batchSlots, batchSize, beamWidth);
finished, finishedSum, sequenceLimitLength, sequenceLengths, numNewTokens, batchSlots, batchSize, beamWidth);
sync_check_cuda_error();
}
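The numNewTokens bookkeeping added above can be mirrored on the host for reference. A minimal sketch, assuming plain std::vector buffers in place of the device tensors and beamWidth == 1; for a sequence that is 2 tokens over its limit, numTokensToLimit is -2 and 2 tokens are removed from its per-step count:

#include <cstddef>
#include <vector>

// Host-side mirror of the length criterion above (illustration only, beamWidth == 1).
void lengthCriterionHost(std::vector<bool>& finishedMaxLength, std::vector<int>& sequenceLengths,
    std::vector<int>& numNewTokens, std::vector<int> const& sequenceLimitLengths)
{
    for (std::size_t i = 0; i < sequenceLengths.size(); ++i)
    {
        int const numTokensToLimit = sequenceLimitLengths[i] - sequenceLengths[i];
        if (numTokensToLimit <= 0)
        {
            finishedMaxLength[i] = true;                  // corresponds to setFinishedMaxLength()
            sequenceLengths[i] = sequenceLimitLengths[i]; // clamp to the limit
            numNewTokens[i] += numTokensToLimit;          // drop the overshoot from this step
        }
    }
}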
__global__ void explicitEOSCriterion(TokenIdType const** outputIds, TokenIdType const* endIds, FinishedState* finished,
SizeType32* sequenceLengths, SizeType32 const* tokensPerStep, SizeType32 const* batchSlots, SizeType32 batchSize,
SizeType32* sequenceLengths, SizeType32* numNewTokens, SizeType32 const* batchSlots, SizeType32 batchSize,
SizeType32 maxTokensPerStep)
{
auto const batchIdx = blockIdx.x * blockDim.x + threadIdx.x;
@ -204,7 +210,7 @@ __global__ void explicitEOSCriterion(TokenIdType const** outputIds, TokenIdType
return;
}
auto const numTokens = tokensPerStep != nullptr ? tokensPerStep[batchSlot] : maxTokensPerStep;
auto const numTokens = numNewTokens != nullptr ? numNewTokens[batchSlot] : maxTokensPerStep;
auto const endId = endIds[batchSlot];
auto const sequenceLength = sequenceLengths[batchSlot];
@ -217,12 +223,17 @@ __global__ void explicitEOSCriterion(TokenIdType const** outputIds, TokenIdType
{
finished[batchSlot].setFinishedEOS();
sequenceLengths[batchSlot] = max(0, pos);
if (numNewTokens)
{
numNewTokens[batchSlot] = pos - posStart;
}
return;
}
}
}
void invokeExplicitEOSCriterion(TokenIdType const** outputIds, TokenIdType const* endIds, FinishedState* finished,
SizeType32* sequenceLengths, SizeType32 const* tokensPerStep, SizeType32 const* batchSlots, SizeType32 batchSize,
SizeType32* sequenceLengths, SizeType32* numNewTokens, SizeType32 const* batchSlots, SizeType32 batchSize,
SizeType32 beamWidth, SizeType32 maxTokensPerStep, cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(beamWidth == 1, "Explicit EOS criterion does not support beam search");
@ -233,7 +244,7 @@ void invokeExplicitEOSCriterion(TokenIdType const** outputIds, TokenIdType const
grid.x = divUp(batchSize, blockSize);
explicitEOSCriterion<<<grid, blockSize, 0, stream>>>(
outputIds, endIds, finished, sequenceLengths, tokensPerStep, batchSlots, batchSize, maxTokensPerStep);
outputIds, endIds, finished, sequenceLengths, numNewTokens, batchSlots, batchSize, maxTokensPerStep);
sync_check_cuda_error();
}
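The truncation done by explicitEOSCriterion can be sketched on the host as well. This is a hedged reconstruction, not the kernel: it assumes the tokens produced this step are the last numNewTokens entries of outputIds and that the sequence is cut at the first endId found in that window (beamWidth == 1, as the invoke function above requires):

#include <vector>

// Host-side sketch of the explicit-EOS stop criterion for one sequence.
void explicitEosHost(std::vector<int> const& outputIds, int endId, bool& finishedEos,
    int& sequenceLength, int& numNewTokens)
{
    int const posStart = sequenceLength - numNewTokens; // first token of this step (assumption)
    for (int pos = posStart; pos < sequenceLength; ++pos)
    {
        if (outputIds[pos] == endId)
        {
            finishedEos = true;                 // corresponds to setFinishedEOS()
            sequenceLength = pos > 0 ? pos : 0; // drop the EOS token and everything after it
            numNewTokens = pos - posStart;      // only tokens before the EOS count for this step
            return;
        }
    }
}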
View File
@ -62,14 +62,16 @@ void invokeStopWordsCriterion(runtime::TokenIdType const** outputIds, runtime::S
//! \param sequenceLimitLength input buffer [maxBatchSize]. Maximum sequence length.
//! \param sequenceLengths input/output buffer [maxBatchSize, beamWidth].
//! Current sequence lengths of the request tokens.
//! \param numNewTokens input/output buffer [maxBatchSize], optional. Number of tokens generated in the current step for each request.
//! If nullptr, the per-step token counts are left unchanged when the length limit is hit.
//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool
//! \param batchSize batch size
//! \param beamWidth beam width
//! \param stream stream
void invokeLengthCriterion(FinishedState* finished, runtime::SizeType32* finishedSum,
runtime::SizeType32 const* sequenceLimitLength, runtime::SizeType32* sequenceLengths,
runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth,
cudaStream_t stream);
runtime::SizeType32* numNewTokens, runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize,
runtime::SizeType32 beamWidth, cudaStream_t stream);
//! \brief Sets finished states based on the endIds and adjusts the sequence length to the length before the first EOS token.
//! Does not support beamWidth > 1 for now.
@ -81,7 +83,7 @@ void invokeLengthCriterion(FinishedState* finished, runtime::SizeType32* finishe
//! [maxBatchSize, beamWidth]. Finished states. Set to FinishedState::FINISHED_EOS if any new tokens contain EOS.
//! \param sequenceLengths input/output buffer [maxBatchSize, beamWidth].
//! Current sequence lengths of the request tokens.
//! \param tokensPerStep input buffer [maxBatchSize], optional. Number of tokens per step for each request.
//! \param numNewTokens input/output buffer [maxBatchSize], optional. Number of tokens per step for each request.
//! It is assumed that all requests have maxTokensPerStep tokens per step if nullptr.
//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool
//! \param batchSize batch size
@ -89,7 +91,7 @@ void invokeLengthCriterion(FinishedState* finished, runtime::SizeType32* finishe
//! \param maxTokensPerStep maximum number of tokens decoded per step
//! \param stream stream
void invokeExplicitEOSCriterion(runtime::TokenIdType const** outputIds, runtime::TokenIdType const* endIds,
FinishedState* finished, runtime::SizeType32* sequenceLengths, runtime::SizeType32 const* tokensPerStep,
FinishedState* finished, runtime::SizeType32* sequenceLengths, runtime::SizeType32* numNewTokens,
runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth,
runtime::SizeType32 maxTokensPerStep, cudaStream_t stream);
} // namespace kernels
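A hypothetical call site for the updated invokeLengthCriterion signature; the buffer variables are placeholders allocated elsewhere, and passing nullptr for numNewTokens keeps the previous behaviour of not tracking per-step token counts:

// Hypothetical call; all buffers are placeholders, not library-provided names.
tensorrt_llm::kernels::invokeLengthCriterion(finishedStates, finishedSumDevice,
    sequenceLimitLengths, sequenceLengths, /*numNewTokens=*/nullptr, batchSlots,
    batchSize, beamWidth, stream);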
View File
@ -16,22 +16,16 @@
*/
#include "tensorrt_llm/layers/banWordsLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/banBadWords.h"
#include "tensorrt_llm/kernels/banRepeatNgram.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include <algorithm>
using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{
template <typename T>
@ -97,30 +91,31 @@ void BanWordsLayer<T>::freeBuffer()
template <typename T>
void BanWordsLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> baseSetupParams)
std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto setupParams = std::dynamic_pointer_cast<DynamicDecodeSetupParams>(baseSetupParams);
std::vector<SizeType32> batchSlotsVec(batchSize);
std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0);
auto batchSlotsHost = batchSlots ? batchSlots : batchSlotsVec.data();
auto const& penaltyParams = setupParams->penaltyParams;
auto const& banWordsParams = setupParams->banWordsParams;
TLLM_CHECK_WITH_INFO(banWordsParams, "banWordsParams for setup is not set");
bool const useNoRepeatNgramSize
= mDecodingMode.isUseNoRepeatNgramSize() && penaltyParams.noRepeatNgramSize.has_value();
= mDecodingMode.isUseNoRepeatNgramSize() && banWordsParams->noRepeatNgramSize.has_value();
FillBuffers const fillBuffers{batchSize, mDecoderDomain.getBatchSize(), mStream};
mUseNoRepeatNgramSize |= useNoRepeatNgramSize;
if (mUseNoRepeatNgramSize)
{
fillBuffers(penaltyParams.noRepeatNgramSize, DefaultDecodingParams::getNoRepeatNgramSize(), mNoRepeatNgramSize,
mNoRepeatNgramSizeDevice, batchSlotsHost, std::make_pair(0.f, std::numeric_limits<float>::max()),
"no_repeat_ngram_size");
fillBuffers(banWordsParams->noRepeatNgramSize, DefaultDecodingParams::getNoRepeatNgramSize(),
mNoRepeatNgramSize, mNoRepeatNgramSizeDevice, batchSlotsHost,
std::make_pair(0.f, std::numeric_limits<float>::max()), "no_repeat_ngram_size");
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void BanWordsLayer<T>::banRepeatNGrams(Tensor& logits, std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& inputs, SizeType32 const* batchSlots,
void BanWordsLayer<T>::banRepeatNGrams(Tensor& logits, std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<DecodingInputs> const& inputs, SizeType32 const* batchSlots,
SizeType32 const* noRepeatNgramSizeDevice, DecoderDomain const& decoderDomain, SizeType32 maxSeqLen,
bool useNoRepeatNgramSize, cudaStream_t stream)
{
@ -129,11 +124,11 @@ void BanWordsLayer<T>::banRepeatNGrams(Tensor& logits, std::shared_ptr<DynamicDe
auto const maxStep = maxSeqLen;
if (useNoRepeatNgramSize)
{
invokeBanRepeatNgram(logits.template getPtr<T>(), outputs->output_ids_ptr.template getPtr<TokenIdType const*>(),
invokeBanRepeatNgram(logits.template getPtr<T>(), outputs->outputIdsPtr.template getPtr<TokenIdType const*>(),
reinterpret_cast<FinishedState*>(
inputs->finished.value_or(Tensor{}).template getPtr<FinishedState::UnderlyingType>()),
outputs->parent_ids_ptr.template getPtr<SizeType32 const*>(), batchSlots,
outputs->sequence_length->template getPtr<SizeType32>(), decoderDomain.getBatchSize(),
outputs->parentIdsPtr.template getPtr<SizeType32 const*>(), batchSlots,
outputs->sequenceLength->template getPtr<SizeType32>(), decoderDomain.getBatchSize(),
decoderDomain.getBeamWidth(), maxSeqLen, noRepeatNgramSizeDevice, decoderDomain.getVocabSizePadded(),
maxStep, stream);
}
@ -141,39 +136,40 @@ void BanWordsLayer<T>::banRepeatNGrams(Tensor& logits, std::shared_ptr<DynamicDe
}
template <typename T>
void BanWordsLayer<T>::banBadWords(Tensor& logits, std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& inputs, SizeType32 const* batchSlots,
DecoderDomain const& decoderDomain, SizeType32 maxSeqLen, cudaStream_t stream)
void BanWordsLayer<T>::banBadWords(Tensor& logits, std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<DecodingInputs> const& inputs, SizeType32 const* batchSlots, DecoderDomain const& decoderDomain,
SizeType32 maxSeqLen, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const maxBadWordsLength = inputs->max_bad_words_len;
auto const maxBadWordsLength = inputs->banWordsInputs->maxBadWordsLen;
if (maxBadWordsLength)
{
auto const** badWordsPtr = inputs->bad_words_ptr->template getPtr<TokenIdType const*>();
auto const* badWordsLens = inputs->bad_words_lengths->template getPtr<SizeType32>();
auto const** badWordsPtr = inputs->banWordsInputs->badWordsPtr->template getPtr<TokenIdType const*>();
auto const* badWordsLens = inputs->banWordsInputs->badWordsLengths->template getPtr<SizeType32>();
invokeBanBadWords((T*) logits.template getPtr<T>(),
outputs->output_ids_ptr.template getPtr<TokenIdType const*>(),
decoderDomain.getBeamWidth() > 1 ? outputs->parent_ids_ptr.template getPtr<SizeType32 const*>() : nullptr,
invokeBanBadWords((T*) logits.template getPtr<T>(), outputs->outputIdsPtr.template getPtr<TokenIdType const*>(),
decoderDomain.getBeamWidth() > 1 ? outputs->parentIdsPtr.template getPtr<SizeType32 const*>() : nullptr,
batchSlots, decoderDomain.getBatchSize(), decoderDomain.getBeamWidth(), badWordsPtr, badWordsLens,
maxBadWordsLength, decoderDomain.getVocabSizePadded(),
outputs->sequence_length->template getPtr<SizeType32>(), maxSeqLen, stream);
outputs->sequenceLength->template getPtr<SizeType32>(), maxSeqLen, stream);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
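Conceptually, banning a bad word masks the logit of its final token whenever the already generated suffix matches the rest of the word. A minimal host-side illustration of that idea for a single sequence (beamWidth == 1); this is not the invokeBanBadWords kernel:

#include <cstddef>
#include <limits>
#include <vector>

// Illustration only: ban bad words for one sequence by forcing the logit of the
// word's last token to -inf when the generated tail matches its preceding tokens.
void banBadWordsHost(std::vector<float>& logits, std::vector<int> const& outputIds,
    std::vector<std::vector<int>> const& badWords)
{
    for (auto const& word : badWords)
    {
        if (word.empty() || word.size() - 1 > outputIds.size())
        {
            continue;
        }
        bool prefixMatches = true;
        for (std::size_t i = 0; i + 1 < word.size(); ++i)
        {
            if (outputIds[outputIds.size() - (word.size() - 1) + i] != word[i])
            {
                prefixMatches = false;
                break;
            }
        }
        if (prefixMatches)
        {
            logits[word.back()] = -std::numeric_limits<float>::infinity();
        }
    }
}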
template <typename T>
void BanWordsLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto inputs = std::dynamic_pointer_cast<DynamicDecodeInputParams>(baseInputs);
auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
auto inputs = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);
auto outputs = std::dynamic_pointer_cast<BaseDecodingOutputs>(baseOutputs);
TLLM_CHECK_WITH_INFO(inputs->banWordsInputs, "banWordsInputs for forward is not set");
auto const localDecoderDomain = getLocalDecoderDomain(inputs, mDecoderDomain);
auto const maxSeqLen = outputs->output_ids.shape[outputs->output_ids.shape.size() - 1];
auto batchSlots = inputs->batch_slots ? inputs->batch_slots->template getPtr<SizeType32 const>() : nullptr;
auto const maxSeqLen = outputs->outputIds.shape[outputs->outputIds.shape.size() - 1];
auto batchSlots = inputs->batchSlots ? inputs->batchSlots->template getPtr<SizeType32 const>() : nullptr;
banRepeatNGrams(inputs->logits.value(), outputs, inputs, batchSlots, mNoRepeatNgramSizeDevice, localDecoderDomain,
maxSeqLen, mUseNoRepeatNgramSize, mStream);
@ -185,5 +181,4 @@ void BanWordsLayer<T>::forwardAsync(
template class BanWordsLayer<float>;
template class BanWordsLayer<half>;
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers
View File
@ -17,17 +17,14 @@
#pragma once
#include <curand_kernel.h>
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/runtime/iTensor.h"
namespace tensorrt_llm
{
namespace layers
#include <curand_kernel.h>
namespace tensorrt_llm::layers
{
//! \brief Layer to ban specific words from being sampled.
@ -45,20 +42,21 @@ public:
~BanWordsLayer() override;
void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> baseSetupParams) override;
std::shared_ptr<BaseSetupParams> const& baseSetupParams) override;
//! \brief Modifies 'outputs->logits' in-place with -INF for banned words
void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;
private:
void initialize();
void allocateBuffer();
void freeBuffer();
static void banBadWords(tc::Tensor& logits, std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& params, runtime::SizeType32 const* batchSlots,
static void banBadWords(tc::Tensor& logits, std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<DecodingInputs> const& inputs, runtime::SizeType32 const* batchSlots,
DecoderDomain const& decoderDomain, runtime::SizeType32 maxSeqLen, cudaStream_t stream);
static void banRepeatNGrams(tc::Tensor& logits, std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& inputs, runtime::SizeType32 const* batchSlots,
static void banRepeatNGrams(tc::Tensor& logits, std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<DecodingInputs> const& inputs, runtime::SizeType32 const* batchSlots,
runtime::SizeType32 const* noRepeatNgramSizeDevice, DecoderDomain const& decoderDomain,
runtime::SizeType32 maxSeqLen, bool useNoRepeatNgramSize, cudaStream_t stream);
@ -76,5 +74,4 @@ private:
bool mUseNoRepeatNgramSize{false};
};
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers
View File
@ -17,14 +17,14 @@
#pragma once
#include "tensorrt_llm/common/allocator.h"
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/common/cudaAllocator.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
namespace tensorrt_llm
{
namespace layers
#include <utility>
namespace tensorrt_llm::layers
{
class BaseLayer
@ -35,8 +35,17 @@ public:
BaseLayer(DecoderDomain const& decoderDomain, cudaStream_t stream,
std::shared_ptr<tensorrt_llm::common::IAllocator> allocator)
: mStream(stream)
: mBufferManager(nullptr)
, mStream(stream)
, mAllocator(std::move(allocator))
, mDecoderDomain(std::move(decoderDomain))
{
}
BaseLayer(DecoderDomain const& decoderDomain, std::shared_ptr<runtime::BufferManager> const& bufferManager)
: mBufferManager(bufferManager)
, mStream(mBufferManager->getStream().get())
, mAllocator(std::make_shared<tensorrt_llm::common::CudaAllocator>(*mBufferManager))
, mDecoderDomain(decoderDomain)
{
}
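The new constructor derives both the stream and the allocator from the BufferManager, so a call site only has to provide the manager. A hedged sketch of that wiring; CudaStream and BufferManager are the runtime classes from the includes above, and the derived-layer type is purely illustrative:

// Sketch: one stream-owning manager replaces the raw stream + allocator pair.
auto stream = std::make_shared<tensorrt_llm::runtime::CudaStream>();
auto manager = std::make_shared<tensorrt_llm::runtime::BufferManager>(stream);
// Inside BaseLayer, mStream == manager->getStream().get() and mAllocator is a
// CudaAllocator built on the same manager (see the initializer list above).
SomeDerivedLayer<float> layer{decoderDomain, manager}; // hypothetical derived layer type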
@ -79,29 +88,36 @@ public:
//! \param setupParams shared pointer to params inherited from BaseSetupParams
// clang-format on
virtual void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth,
runtime::SizeType32 const* batchSlots, std::shared_ptr<BaseSetupParams> setupParams)
runtime::SizeType32 const* batchSlots, std::shared_ptr<BaseSetupParams> const& setupParams)
= 0;
// clang-format off
//! \brief Virtual function to execute layer async on GPU.
//! There must be no stream synchronization inside this function.
//!
//! \param outputs shared pointer to params inherited from BaseOutputParams
//! \param outputs shared pointer to params inherited from BaseDecodingOutputs
//! \param inputs shared pointer to params inherited from BaseForwardParams
// clang-format on
virtual void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) = 0;
virtual void forwardAsync(
std::shared_ptr<BaseDecodingOutputs> const& outputs, std::shared_ptr<BaseDecodingInputs> const& inputs)
= 0;
// clang-format off
//! \brief Virtual function to execute layer synchronously on CPU / GPU.
//! It is allowed (but not necessary) to synchronize on stream inside this function.
//! It is intended mainly for prototyping.
//!
//! \param outputs shared pointer to params inherited from BaseOutputParams
//! \param outputs shared pointer to params inherited from BaseDecodingOutputs
//! \param inputs shared pointer to params inherited from BaseForwardParams
// clang-format on
virtual void forwardSync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) {}
virtual void forwardSync(
std::shared_ptr<BaseDecodingOutputs> const& outputs, std::shared_ptr<BaseDecodingInputs> const& inputs)
{
}
protected:
// Buffer Manager
std::shared_ptr<runtime::BufferManager> mBufferManager;
// Cuda stream
cudaStream_t mStream;
// Memory allocator
@ -119,5 +135,4 @@ protected:
bool mIsAllocateBuffer{false};
};
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers
View File
@ -14,19 +14,17 @@
* limitations under the License.
*/
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/beamSearchKernels.h"
#include "tensorrt_llm/layers/beamSearchLayer.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include <limits>
using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{
template <typename T>
@ -47,7 +45,7 @@ BeamSearchLayer<T>::~BeamSearchLayer()
template <typename T>
void BeamSearchLayer<T>::setup(runtime::SizeType32 const batchSize, runtime::SizeType32 const beamWidth,
runtime::SizeType32 const* batchSlots, std::shared_ptr<BaseSetupParams> baseSetupParams)
runtime::SizeType32 const* batchSlots, std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK_WITH_INFO(
@ -65,12 +63,12 @@ void BeamSearchLayer<T>::setup(runtime::SizeType32 const batchSize, runtime::Siz
auto constexpr fltEpsilon = std::numeric_limits<float>::epsilon();
FillBuffers const fillBuffers{batchSize, batchSize, mStream};
fillBuffers(setupParams->beam_search_diversity_rate, DefaultDecodingParams::getBeamSearchDiversity(),
fillBuffers(setupParams->beamSearchDiversityRate, DefaultDecodingParams::getBeamSearchDiversity(),
mDiversityRateHost, mDiversityRateDevice, (int*) nullptr, std::make_pair(-fltEpsilon, fltMax),
"diveristy rate");
fillBuffers(setupParams->length_penalty, DefaultDecodingParams::getLengthPenalty(), mLengthPenaltyHost,
fillBuffers(setupParams->lengthPenalty, DefaultDecodingParams::getLengthPenalty(), mLengthPenaltyHost,
mLengthPenaltyDevice, (int*) nullptr, std::make_pair(fltMin, fltMax), "length penalty");
fillBuffers(setupParams->early_stopping, DefaultDecodingParams::getEarlyStopping(), mEarlyStoppingHost,
fillBuffers(setupParams->earlyStopping, DefaultDecodingParams::getEarlyStopping(), mEarlyStoppingHost,
mEarlyStoppingDevice, (int*) nullptr, std::make_pair(fltMin, fltMax), "early stopping");
mHasDiffRuntimeArgs = setupParams->hasDiffRuntimeArgs;
@ -80,7 +78,7 @@ void BeamSearchLayer<T>::setup(runtime::SizeType32 const batchSize, runtime::Siz
__global__ void updateCacheIndirectionKernel(
int* tgtCI, int const* srcCI, BeamHypotheses bh, int const nMaxAttentionWindow, int const nSinkTokenLength)
{
// Update indirections from steps `bh.inputLength[indexBatchBeam]` to step `sequence_lengths[indexBatchBeam]`
// Update indirections from steps `bh.inputLength[indexBatchBeam]` to step `sequenceLengths[indexBatchBeam]`
int const step = threadIdx.x + blockIdx.x * blockDim.x;
int const indexBatchBeam = blockIdx.y;
int const nBS{bh.nBatchSize};
@ -88,7 +86,7 @@ __global__ void updateCacheIndirectionKernel(
int const nMSL{bh.nMaxSeqLen};
int const indexBatch = indexBatchBeam / nBM;
int const indexBeam = indexBatchBeam % nBM;
int const lastStep{bh.sequenceLengths[indexBatchBeam] - 1}; // the sequence_lengths is updated, need to minus 1
int const lastStep{bh.sequenceLengths[indexBatchBeam] - 1}; // sequenceLengths is already updated, so subtract 1
// Return early when the indexBatchBeam or step is out of the bound
// No update for the indices of context part since KV Cache is shared
@ -111,38 +109,38 @@ __global__ void updateCacheIndirectionKernel(
template <typename T>
void BeamSearchLayer<T>::forwardAsyncSingleRequest(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto ip = std::dynamic_pointer_cast<BeamSearchInputParams>(baseInputs);
auto op = std::dynamic_pointer_cast<BeamSearchOutputParams>(baseOutputs);
auto op = std::dynamic_pointer_cast<BeamSearchOutputs>(baseOutputs);
TLLM_CHECK_WITH_INFO(op->beamHypotheses, std::string("Output BeamHypotheses is not set."));
TLLM_CHECK_WITH_INFO(op->sequence_length->template getPtr<int>() != nullptr || mLengthPenaltyDevice == nullptr,
TLLM_CHECK_WITH_INFO(op->sequenceLength->template getPtr<int>() != nullptr || mLengthPenaltyDevice == nullptr,
std::string("Current sequence lengths must be set for length penalty computation."));
TLLM_CHECK_WITH_INFO(ip->ite == 0, "Pipeline Parallelism is not supported yet !");
BeamHypotheses& bh{*op->beamHypotheses};
// bh's members already initialized in op: *CBA, batchDones
// bh's members not used in function: outputIds, logProbs, outputIdsUnfinish, parentIdsUnfinish
bh.nMaxBatchSize = static_cast<std::int32_t>(op->output_ids_ptr.shape[0]);
bh.nBatchSize = ip->logits.shape[0];
bh.nBeamWidth = static_cast<std::int32_t>(op->output_ids_ptr.shape[1]);
bh.nMaxBatchSize = static_cast<std::int32_t>(op->outputIdsPtr.shape[0]);
bh.nBatchSize = ip->localBatchSize;
bh.nBeamWidth = static_cast<std::int32_t>(op->outputIdsPtr.shape[1]);
bh.nIte = ip->ite;
bh.nMaxSeqLen = static_cast<std::int32_t>(op->output_ids_ptr.shape[2]);
bh.nMaxSeqLen = static_cast<std::int32_t>(op->outputIdsPtr.shape[2]);
bh.nVocabSize = mVocabSizePadded;
bh.diversityRates = mDiversityRateDevice;
bh.lengthPenalties = mLengthPenaltyDevice;
bh.earlyStoppings = mEarlyStoppingDevice;
bh.inputLengths = ip->input_lengths->template getPtr<int const>();
bh.endIds = ip->end_ids.template getPtr<int const>();
bh.logProbsTiled = (op->output_log_probs) ? op->output_log_probs->template getPtr<float>() : nullptr;
bh.sequenceLengths = op->sequence_length->template getPtr<int>();
bh.cumLogProbs = op->cum_log_probs->template getPtr<float>();
bh.inputLengths = ip->inputLengths->template getPtr<int const>();
bh.endIds = ip->endIds.template getPtr<int const>();
bh.logProbsTiled = (op->outputLogProbs) ? op->outputLogProbs->template getPtr<float>() : nullptr;
bh.sequenceLengths = op->sequenceLength->template getPtr<int>();
bh.cumLogProbs = op->cumLogProbs->template getPtr<float>();
bh.finished = reinterpret_cast<FinishedState*>(op->finished->template getPtr<FinishedState::UnderlyingType>());
bh.outputIdsPtr = op->output_ids_ptr.template getPtr<int*>();
bh.parentIdsPtr = op->parent_ids_ptr.template getPtr<int*>();
bh.outputIdsPtr = op->outputIdsPtr.template getPtr<int*>();
bh.parentIdsPtr = op->parentIdsPtr.template getPtr<int*>();
T const* logits = ip->logits.template getPtr<T>();
T const* bias = static_cast<T const*>(nullptr);
@ -155,11 +153,11 @@ void BeamSearchLayer<T>::forwardAsyncSingleRequest(
if (bh.nBeamWidth > 1)
{
auto tgtCI = op->tgt_cache_indirection.template getPtr<int>();
auto srcCI = ip->src_cache_indirection.template getPtr<int const>();
auto tgtCI = op->tgtCacheIndirection.template getPtr<int>();
auto srcCI = ip->srcCacheIndirection.template getPtr<int const>();
dim3 const grid(roundUp(bh.nMaxSeqLen, 32), bh.nBatchSize * bh.nBeamWidth);
updateCacheIndirectionKernel<<<grid, 32, 0, mStream>>>(
tgtCI, srcCI, bh, ip->max_attention_window, ip->sink_token_length);
tgtCI, srcCI, bh, ip->maxAttentionWindow, ip->sinkTokenLength);
sync_check_cuda_error();
}
@ -168,32 +166,29 @@ void BeamSearchLayer<T>::forwardAsyncSingleRequest(
template <typename T>
void BeamSearchLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
auto params = std::dynamic_pointer_cast<DynamicDecodeInputParams>(baseInputs);
auto outputs = std::dynamic_pointer_cast<BeamSearchOutputs>(baseOutputs);
auto params = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);
auto const localDecoderDomain = getLocalDecoderDomain(params, mDecoderDomain);
auto batchSlots = params->batch_slots ? params->batch_slots->template getPtr<SizeType32 const>() : nullptr;
auto const maxSeqLen = outputs->output_ids.shape[outputs->output_ids.shape.size() - 1];
auto batchSlots = params->batchSlots ? params->batchSlots->template getPtr<SizeType32 const>() : nullptr;
auto const maxSeqLen = outputs->outputIds.shape[outputs->outputIds.shape.size() - 1];
auto const ite = params->ite;
auto const step = params->step;
// common inputs
auto const& endIds = params->end_ids;
auto const localBatchSize = static_cast<std::size_t>(params->local_batch_size);
auto const& endIds = params->endIds;
auto const localBatchSize = static_cast<std::size_t>(params->localBatchSize);
TLLM_CHECK_WITH_INFO(localDecoderDomain.getBeamWidth() > 1,
"Decoding mode is beam search, but beamWidth <= 1 (%d <= 1)", localDecoderDomain.getBeamWidth());
TLLM_CHECK_WITH_INFO(
params->src_cache_indirection.has_value(), "src_cache_indirection is mandatory in beam search.");
TLLM_CHECK_WITH_INFO(
outputs->tgt_cache_indirection.has_value(), "tgt_cache_indirection is mandatory in beam search.");
TLLM_CHECK_WITH_INFO(outputs->parent_ids.has_value(), "parent_ids tensor is mandatory in beam search.");
TLLM_CHECK_WITH_INFO(params->srcCacheIndirection.has_value(), "srcCacheIndirection is mandatory in beam search.");
TLLM_CHECK_WITH_INFO(outputs->parentIds.has_value(), "parentIds tensor is mandatory in beam search.");
TLLM_CHECK_WITH_INFO(outputs->finished.has_value(), "finished tensor is mandatory in beam search.");
TLLM_CHECK_WITH_INFO(outputs->cum_log_probs.has_value(), "cum_log_probs tensor is mandatory in beam search.");
TLLM_CHECK_WITH_INFO(outputs->cumLogProbs.has_value(), "cumLogProbs tensor is mandatory in beam search.");
// Compute one by one if there are different runtime arguments
// due to Batch-Beam-Search is not supported yet, so we need to compute
@ -211,30 +206,32 @@ void BeamSearchLayer<T>::forwardAsync(
auto const end_id_offset = endIds.slice({dynamic_decode_batch_size}, dynamic_ite * dynamic_decode_batch_size);
auto forwardParams = std::make_shared<BeamSearchInputParams>(step, ite, logits_offset, end_id_offset,
*params->src_cache_indirection, static_cast<std::int32_t>(params->max_attention_window),
static_cast<std::int32_t>(params->sink_token_length), static_cast<std::int32_t>(maxSeqLen));
*params->srcCacheIndirection, static_cast<std::int32_t>(params->maxAttentionWindow),
static_cast<std::int32_t>(params->sinkTokenLength), static_cast<std::int32_t>(maxSeqLen),
dynamic_decode_batch_size);
if (params->input_lengths)
if (params->inputLengths)
{
forwardParams->input_lengths = params->input_lengths->slice(
forwardParams->inputLengths = params->inputLengths->slice(
{dynamic_decode_batch_size * localDecoderDomain.getBeamWidth()}, dynamic_id_offset);
}
auto outputParams = std::make_shared<BeamSearchOutputParams>(
outputs->output_ids, outputs->parent_ids.value(), outputs->tgt_cache_indirection.value());
auto outputParams = std::make_shared<BeamSearchOutputs>(outputs->outputIds);
outputParams->output_ids_ptr = std::move(outputs->output_ids_ptr);
outputParams->parent_ids_ptr = std::move(outputs->parent_ids_ptr);
outputParams->sequence_length = outputs->sequence_length->slice(
outputParams->parentIds = std::move(outputs->parentIds);
outputParams->tgtCacheIndirection = std::move(outputs->tgtCacheIndirection);
outputParams->outputIdsPtr = std::move(outputs->outputIdsPtr);
outputParams->parentIdsPtr = std::move(outputs->parentIdsPtr);
outputParams->sequenceLength = outputs->sequenceLength->slice(
{dynamic_decode_batch_size * localDecoderDomain.getBeamWidth()}, dynamic_id_offset);
outputParams->finished = outputs->finished->slice(
{dynamic_decode_batch_size * localDecoderDomain.getBeamWidth()}, dynamic_id_offset);
outputParams->cum_log_probs = outputs->cum_log_probs->slice(
outputParams->cumLogProbs = outputs->cumLogProbs->slice(
{dynamic_decode_batch_size * localDecoderDomain.getBeamWidth()}, dynamic_id_offset);
outputParams->output_log_probs = outputs->output_log_probs_tiled; // notice: use tiled tensor
outputParams->outputLogProbs = outputs->outputLogProbsTiled; // note: uses the tiled tensor
outputParams->beamHypotheses = std::move(outputs->beamHypotheses);
// beam_search_diversity_rate is only supported when using BeamHypotheses
// beamSearchDiversityRate is only supported when using BeamHypotheses
forwardAsyncSingleRequest(outputParams, forwardParams);
} // end of dynamic_ite
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@ -275,5 +272,4 @@ void BeamSearchLayer<T>::freeBuffer()
template class BeamSearchLayer<float>;
template class BeamSearchLayer<half>;
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers
View File
@ -17,71 +17,40 @@
#pragma once
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/kernels/beamSearchKernels.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/runtime/common.h"
#include <utility>
#include <optional>
#include <utility>
namespace tc = tensorrt_llm::common;
namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{
// BS: batch_size, lBS: local_batch_size, BM: beam_width, mSL: max_seq_length
class BeamSearchSetupParams : public BaseSetupParams
{
public:
std::optional<std::vector<float>> beam_search_diversity_rate; // [BS] on cpu
std::optional<std::vector<float>> length_penalty; // [BS] on cpu
std::optional<std::vector<int>> early_stopping; // [BS] on cpu
bool hasDiffRuntimeArgs{false};
};
class BeamSearchInputParams : public BaseInputParams
class BeamSearchInputParams : public DecodingInputs
{
public:
explicit BeamSearchInputParams(runtime::SizeType32 step, runtime::SizeType32 ite, tc::Tensor logits,
tc::Tensor endIds, tc::Tensor src_cache_indirection, runtime::SizeType32 max_attention_window,
runtime::SizeType32 sink_token_length, runtime::SizeType32 max_seq_len)
: BaseInputParams(step, ite, std::move(endIds))
tc::Tensor endIds, tc::Tensor srcCacheIndirection, runtime::SizeType32 maxAttentionWindow,
runtime::SizeType32 sinkTokenLength, runtime::SizeType32 maxSeqLen, runtime::SizeType32 localBatchSize)
: DecodingInputs(std::move(endIds), step, ite, localBatchSize)
, logits{std::move(logits)}
, max_attention_window{max_attention_window}
, sink_token_length{sink_token_length}
, max_seq_len{max_seq_len}
, src_cache_indirection{std::move(src_cache_indirection)}
, maxAttentionWindow{maxAttentionWindow}
, sinkTokenLength{sinkTokenLength}
, maxSeqLen{maxSeqLen}
, srcCacheIndirection{std::move(srcCacheIndirection)}
{
}
// mandatory parameters
tc::Tensor logits; // [maxBatchSize, beamWidth, vocabSizePadded]
runtime::SizeType32 max_attention_window;
runtime::SizeType32 sink_token_length;
runtime::SizeType32 max_seq_len;
tc::Tensor src_cache_indirection; // [BS, BM, mSL]
std::optional<tc::Tensor> input_lengths; // [BS, BM]
};
class BeamSearchOutputParams : public BaseOutputParams
{
public:
explicit BeamSearchOutputParams(tc::Tensor outputIds, tc::Tensor parentIds, tc::Tensor tgt_cache_indirection)
: BaseOutputParams{std::move(outputIds)}
, parent_ids{std::move(parentIds)}
, tgt_cache_indirection{std::move(tgt_cache_indirection)}
{
}
std::shared_ptr<kernels::BeamHypotheses> beamHypotheses;
tc::Tensor parent_ids; // [BS, BM, mSL]
tc::Tensor tgt_cache_indirection; // [BS, BM, mSL]
tc::Tensor parent_ids_ptr; // [BS][BM, mSL]
runtime::SizeType32 maxAttentionWindow;
runtime::SizeType32 sinkTokenLength;
runtime::SizeType32 maxSeqLen;
tc::Tensor srcCacheIndirection; // [BS, BM, mSL]
std::optional<tc::Tensor> inputLengths; // [BS, BM]
};
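A hypothetical construction of the reworked input params; all tensor and scalar variables below are placeholders prepared by the caller, and the argument order follows the constructor above (see also the call site shown earlier in the diff):

// Hypothetical usage; every variable here is a placeholder.
auto inputs = std::make_shared<BeamSearchInputParams>(step, ite, logitsSlice, endIdsSlice,
    srcCacheIndirection, maxAttentionWindow, sinkTokenLength, maxSeqLen, localBatchSize);
inputs->inputLengths = inputLengthsSlice; // optional [BS, BM]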
template <typename T>
@ -94,15 +63,17 @@ public:
~BeamSearchLayer() override;
void setup(runtime::SizeType32 const batch_size, runtime::SizeType32 const beamWidth,
runtime::SizeType32 const* batchSlots, std::shared_ptr<BaseSetupParams> setupParams) override;
void setup(runtime::SizeType32 const batchSize, runtime::SizeType32 const beamWidth,
runtime::SizeType32 const* batchSlots, std::shared_ptr<BaseSetupParams> const& setupParams) override;
void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;
private:
void forwardAsyncSingleRequest(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs);
void forwardAsyncSingleRequest(
std::shared_ptr<BaseDecodingOutputs> const& outputs, std::shared_ptr<BaseDecodingInputs> const& inputs);
void allocateBuffer(runtime::SizeType32 const batch_size, runtime::SizeType32 const beam_width);
void allocateBuffer(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth);
void freeBuffer();
private:
@ -126,5 +97,4 @@ private:
bool mHasDiffRuntimeArgs{false};
};
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers
View File
@ -16,16 +16,13 @@
*/
#include "tensorrt_llm/layers/decodingLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
#include "tensorrt_llm/layers/beamSearchLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/layers/explicitDraftTokensLayer.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/layers/medusaDecodingLayer.h"
#include "tensorrt_llm/layers/samplingLayer.h"
#include <algorithm>
using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;
@ -59,15 +56,14 @@ bool allSame(std::optional<std::vector<T>> const& vOpt)
bool hasDiffRuntimeArgs(std::shared_ptr<tensorrt_llm::layers::DynamicDecodeSetupParams> const& params)
{
return !allSame(params->penaltyParams.frequencyPenalty) || !allSame(params->penaltyParams.presencePenalty)
|| !allSame(params->penaltyParams.repetitionPenalty) || !allSame(params->penaltyParams.temperature)
|| !allSame(params->penaltyParams.minLength) || !allSame(params->penaltyParams.noRepeatNgramSize);
// return !allSame(params->penaltyParams.frequencyPenalty) || !allSame(params->penaltyParams.presencePenalty)
// || !allSame(params->penaltyParams.repetitionPenalty) || !allSame(params->penaltyParams.temperature)
// || !allSame(params->penaltyParams.minLength) || !allSame(params->banWordsInputs.noRepeatNgramSize);
return false;
}
} // namespace
namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{
template <typename T>
DecodingLayer<T>::DecodingLayer(executor::DecodingMode const& mode, DecoderDomain const& decoderDomain,
@ -110,61 +106,40 @@ DecodingLayer<T>::DecodingLayer(executor::DecodingMode const& mode, DecoderDomai
template <typename T>
void DecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> baseSetupParams)
std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto setupParams = std::dynamic_pointer_cast<DynamicDecodeSetupParams>(baseSetupParams);
TLLM_CHECK_WITH_INFO(setupParams->decodingParams, "decodingParams for setup is not set");
if (mDecodingMode.isTopKorTopP())
{ // sampling layers
TLLM_CHECK_WITH_INFO(
beamWidth == 1, "Decoding mode is TopK and/or TopP, but beamWidth != 1 (%d != 1)", beamWidth);
auto samplingParams = std::make_shared<SamplingSetupParams>();
samplingParams->runtime_top_k = setupParams->samplingParams.runtime_top_k;
samplingParams->runtime_top_p = setupParams->samplingParams.runtime_top_p;
samplingParams->randomSeed = setupParams->randomSeed;
samplingParams->top_p_decay = setupParams->samplingParams.top_p_decay;
samplingParams->top_p_min = setupParams->samplingParams.top_p_min;
samplingParams->top_p_reset_ids = setupParams->samplingParams.top_p_reset_ids;
samplingParams->normalize_log_probs = setupParams->samplingParams.normalize_log_probs;
samplingParams->outputLogProbs = setupParams->samplingParams.outputLogProbs;
samplingParams->cumLogProbs = setupParams->samplingParams.cumLogProbs;
mDecodingLayer->setup(batchSize, beamWidth, batchSlots, samplingParams);
mDecodingLayer->setup(batchSize, beamWidth, batchSlots, setupParams->decodingParams);
}
else if (mDecodingMode.isBeamSearch())
{ // beam search layer
TLLM_CHECK_WITH_INFO(beamWidth > 1, "Decoding mode is beam search, but beamWidth <= 1 (%d <= 1)", beamWidth);
auto beamSearchParams = std::make_shared<BeamSearchSetupParams>();
beamSearchParams->beam_search_diversity_rate = setupParams->beamSearchParams.beam_search_diversity_rate;
beamSearchParams->length_penalty = setupParams->beamSearchParams.length_penalty;
beamSearchParams->early_stopping = setupParams->beamSearchParams.early_stopping;
beamSearchParams->hasDiffRuntimeArgs = hasDiffRuntimeArgs(setupParams);
mDecodingLayer->setup(batchSize, beamWidth, nullptr, beamSearchParams);
mDecodingLayer->setup(batchSize, beamWidth, nullptr, setupParams->decodingParams);
}
else if (mDecodingMode.isMedusa())
{
auto medusaSetupParams = std::make_shared<MedusaSetupParams>();
medusaSetupParams->runtimeTopK = setupParams->samplingParams.runtime_top_k;
medusaSetupParams->runtimeHeadsTopK = setupParams->medusaParams.topKMedusaHeads;
medusaSetupParams->randomSeed = setupParams->randomSeed;
mDecodingLayer->setup(batchSize, beamWidth, batchSlots, medusaSetupParams);
TLLM_CHECK_WITH_INFO(beamWidth == 1, "Decoding mode is Medusa, but beamWidth != 1 (%d != 1)", beamWidth);
mDecodingLayer->setup(batchSize, beamWidth, batchSlots, setupParams->decodingParams);
}
else if (mDecodingMode.isLookahead())
{
TLLM_CHECK_WITH_INFO(beamWidth == 1, "Decoding mode is Lookahead, but beamWidth != 1 (%d != 1)", beamWidth);
// TODO(nkorobov) add lookahead layer
}
else if (mDecodingMode.isExplicitDraftTokens())
{
auto explicitDraftTokensSetupParams = std::make_shared<ExplicitDraftTokensSetupParams>();
explicitDraftTokensSetupParams->temperature = setupParams->penaltyParams.temperature;
explicitDraftTokensSetupParams->randomSeed = setupParams->randomSeed;
mDecodingLayer->setup(batchSize, /* beamWidth */ 1, batchSlots, explicitDraftTokensSetupParams);
TLLM_CHECK_WITH_INFO(
beamWidth == 1, "Decoding mode is ExplicitDraftTokens, but beamWidth != 1 (%d != 1)", beamWidth);
mDecodingLayer->setup(batchSize, beamWidth, batchSlots, setupParams->decodingParams);
}
else
{
@ -178,7 +153,7 @@ void DecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeTyp
template <typename T>
void DecodingLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto [outputParams, inputParams] = prepareParams(baseOutputs, baseInputs);
@ -188,7 +163,7 @@ void DecodingLayer<T>::forwardAsync(
template <typename T>
void DecodingLayer<T>::forwardSync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto [outputParams, inputParams] = prepareParams(baseOutputs, baseInputs);
@ -197,31 +172,30 @@ void DecodingLayer<T>::forwardSync(
}
template <typename T>
std::tuple<std::shared_ptr<BaseOutputParams>, std::shared_ptr<BaseInputParams>> DecodingLayer<T>::prepareParams(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs) const
std::tuple<std::shared_ptr<BaseDecodingOutputs>, std::shared_ptr<BaseDecodingInputs>> DecodingLayer<T>::prepareParams(
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs,
std::shared_ptr<BaseDecodingInputs> const& baseInputs) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
auto params = std::dynamic_pointer_cast<DynamicDecodeInputParams>(baseInputs);
auto params = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);
auto const localDecoderDomain = getLocalDecoderDomain(params, mDecoderDomain);
auto const maxSeqLen = outputs->output_ids.shape[outputs->output_ids.shape.size() - 1];
auto const& endIds = params->end_ids;
auto const maxSeqLen = baseOutputs->outputIds.shape[baseOutputs->outputIds.shape.size() - 1];
auto const& endIds = params->endIds;
std::shared_ptr<BaseOutputParams> preparedOutputs;
std::shared_ptr<BaseInputParams> preparedInputs;
std::shared_ptr<BaseDecodingOutputs> preparedOutputs;
std::shared_ptr<BaseDecodingInputs> preparedInputs;
// dynamic decode GPT
if (mDecodingMode.isBeamSearch())
{
preparedInputs = baseInputs;
preparedOutputs = baseOutputs;
}
else if (mDecodingMode.isTopKorTopP())
{ // beamWidth == 1
{
auto const ite = params->ite;
auto const step = params->step;
auto const localBatchSize = static_cast<std::size_t>(params->local_batch_size);
auto const localBatchSize = static_cast<std::size_t>(params->localBatchSize);
TLLM_CHECK_WITH_INFO(localDecoderDomain.getBeamWidth() == 1,
"Decoding mode is TopK and/or TopP, but beamWidth != 1 (%d != 1)", localDecoderDomain.getBeamWidth());
@ -231,60 +205,29 @@ std::tuple<std::shared_ptr<BaseOutputParams>, std::shared_ptr<BaseInputParams>>
Tensor const logitsSlice{params->logits->slice(
{localBatchSize, static_cast<size_t>(localDecoderDomain.getBeamWidth()), params->logits->shape[2]}, 0)};
Tensor const endIdSlice{endIds.slice({localBatchSize}, 0)};
auto decodeInputs = std::make_shared<SamplingInputParams>(
step, ite, logitsSlice, endIdSlice, static_cast<SizeType32>(maxSeqLen));
auto decodeInputs = std::make_shared<SamplingInputs>(endIdSlice, step, ite, localBatchSize);
decodeInputs->finished = params->finished;
if (params->input_lengths)
decodeInputs->logits = logitsSlice;
if (params->inputLengths)
{
auto& inputLengths = params->input_lengths.value();
decodeInputs->input_lengths
auto& inputLengths = params->inputLengths.value();
decodeInputs->inputLengths
= inputLengths.slice({localBatchSize, static_cast<size_t>(localDecoderDomain.getBeamWidth())}, 0);
}
decodeInputs->batch_slots = params->batch_slots;
auto decodeOutputs = std::make_shared<SamplingOutputParams>(outputs->output_ids);
decodeOutputs->output_ids_ptr = std::move(outputs->output_ids_ptr);
if (outputs->sequence_length)
{
decodeOutputs->sequence_length
= outputs->sequence_length->slice({localBatchSize * localDecoderDomain.getBeamWidth()}, 0);
}
if (outputs->finished)
{
decodeOutputs->finished = outputs->finished->slice({localBatchSize * localDecoderDomain.getBeamWidth()}, 0);
}
if (outputs->cum_log_probs)
{
decodeOutputs->cum_log_probs
= outputs->cum_log_probs->slice({localBatchSize * localDecoderDomain.getBeamWidth()}, 0);
}
if (outputs->output_log_probs_tiled)
{
Tensor& output_log_probs = outputs->output_log_probs_tiled.value();
decodeOutputs->output_log_probs
= output_log_probs.slice({1, localBatchSize * localDecoderDomain.getBeamWidth()}, 0);
}
decodeInputs->batchSlots = params->batchSlots;
preparedInputs = decodeInputs;
preparedOutputs = decodeOutputs;
preparedOutputs = baseOutputs;
}
else if (mDecodingMode.isMedusa())
{
TLLM_CHECK_WITH_INFO(localDecoderDomain.getBeamWidth() == 1,
"Decoding mode is Medusa, but beamWidth != 1 (%d != 1)", localDecoderDomain.getBeamWidth());
auto medusaInputParams = std::make_shared<MedusaInputParams>(params->logits.value(), endIds);
medusaInputParams->finished = outputs->finished.value();
medusaInputParams->batch_slots = params->batch_slots;
medusaInputParams->paths = params->medusaInputs->medusaPaths;
medusaInputParams->medusaLogits = params->medusaInputs->medusaLogits;
medusaInputParams->medusaCurTokensPerStep = params->medusaInputs->medusaCurTokensPerStep;
medusaInputParams->medusaTargetTokensPerStep = params->medusaInputs->medusaTargetTokensPerStep;
medusaInputParams->treeIds = params->medusaInputs->medusaTreeIds;
preparedInputs = medusaInputParams;
preparedInputs = baseInputs;
preparedOutputs = baseOutputs;
}
else if (mDecodingMode.isLookahead())
@ -311,5 +254,4 @@ std::tuple<std::shared_ptr<BaseOutputParams>, std::shared_ptr<BaseInputParams>>
template class DecodingLayer<float>;
template class DecodingLayer<half>;
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

View File

@ -17,22 +17,13 @@
#pragma once
#include <curand_kernel.h>
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/beamSearchLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/layers/explicitDraftTokensLayer.h"
#include "tensorrt_llm/layers/medusaDecodingLayer.h"
#include "tensorrt_llm/layers/samplingLayer.h"
namespace tc = tensorrt_llm::common;
#include <curand_kernel.h>
namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{
//! \brief Layer performs token decoding using sampling (beamWidth=1), beam search (beamWidth>1) or Medusa.
@ -46,19 +37,21 @@ public:
~DecodingLayer() override = default;
void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> setupParams) override;
std::shared_ptr<BaseSetupParams> const& setupParams) override;
//! \brief Calls a single SamplingLayer::forwardAsync or MedusaDecodingLayer::forwardAsync in batched mode,
//! or runs BeamSearchLayer::forwardAsync in a loop for each request.
//! Modifies outputs->logits in-place.
void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;
//! \brief Calls forwardSync of the configured decoding layer.
void forwardSync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardSync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;
private:
std::tuple<std::shared_ptr<BaseOutputParams>, std::shared_ptr<BaseInputParams>> prepareParams(
std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) const;
[[nodiscard]] std::tuple<std::shared_ptr<BaseDecodingOutputs>, std::shared_ptr<BaseDecodingInputs>> prepareParams(
std::shared_ptr<BaseDecodingOutputs> const& outputs, std::shared_ptr<BaseDecodingInputs> const& inputs) const;
private:
using BaseLayer::mWorkspaceSize;
@ -74,5 +67,4 @@ private:
std::unique_ptr<BaseLayer> mDecodingLayer;
};
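// Typical call sequence for the layer declared above (an illustrative sketch; how the layer,
// setupParams, outputs and inputs are constructed is assumed and not shown here):
//   layer->setup(batchSize, beamWidth, batchSlots, setupParams); // per-request decoding configuration
//   layer->forwardAsync(outputs, inputs);                        // enqueue decoding work on the stream
//   layer->forwardSync(outputs, inputs);                         // follow-up work after the async part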
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

View File

@ -16,12 +16,16 @@
#pragma once
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/kernels/beamSearchKernels.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/request.h"
#include <tensorrt_llm/common/tensor.h>
#include <tensorrt_llm/runtime/common.h>
#include <tensorrt_llm/runtime/speculativeDecodingModule.h>
#include <optional>
#include <utility>
#include <vector>
namespace tc = tensorrt_llm::common;
@ -42,10 +46,11 @@ namespace tensorrt_llm::layers
//! It is passed through `setup` method.
//! 3. `forwardBatchSize` for layers forwarding for a batch of existing active requests.
//! it is passed through `forwardAsync` and `forwardSync` methods.
//! `setup` and `forward` always provide `batch_slots` indexed by
//! `setup` and `forward` always provide `batchSlots` indexed by
//! local batch index ranging in [0, setupBatchSize) or [0, forwardBatchSize),
//! holding the global batch index ranging in [0, maxBatchSize).
//! In case of beam search, maxBatchSize = forwardBatchSize = 1.
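//!
//! Illustrative sketch of this indexing (the buffer name below is an assumption for the
//! example, not part of the interface):
//!   for (runtime::SizeType32 bi = 0; bi < forwardBatchSize; ++bi)
//!   {
//!       auto const batchSlot = batchSlots[bi];  // local index -> global slot in [0, maxBatchSize)
//!       perRequestState[batchSlot] += 1;        // per-request buffers are sized maxBatchSize and
//!                                               // addressed by the global slot, not by bi
//!   }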
class DecoderDomain
{
public:
@ -56,7 +61,7 @@ public:
, mBeamWidth(beamWidth)
, mVocabSize(vocabSize)
, mVocabSizePadded(vocabSizePadded.value_or(vocabSize))
, mSpeculativeDecodingModule(speculativeDecodingModule)
, mSpeculativeDecodingModule(std::move(speculativeDecodingModule))
{
}
@ -91,6 +96,11 @@ public:
return mSpeculativeDecodingModule;
}
[[nodiscard]] std::shared_ptr<runtime::SpeculativeDecodingModule const> getSpeculativeDecodingModulePtr() const
{
return mSpeculativeDecodingModule;
}
private:
runtime::SizeType32 mBatchSize;
runtime::SizeType32 mBeamWidth;
@ -102,245 +112,444 @@ private:
class BaseSetupParams
{
public:
virtual ~BaseSetupParams() {}
virtual ~BaseSetupParams() = default;
};
// Penalty layer
class PenaltySetupParams : public BaseSetupParams
{
public:
std::optional<std::vector<float>> temperature; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<runtime::SizeType32>> minLength; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<float>> repetitionPenalty; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<float>> presencePenalty; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<float>> frequencyPenalty; // [1] or [setupBatchSize] on cpu
};
// Ban words layer
class BanWordsSetupParams : public BaseSetupParams
{
public:
std::optional<std::vector<runtime::SizeType32>> noRepeatNgramSize; // [1] or [setupBatchSize] on cpu
};
class DecodingSetupParams : public BaseSetupParams
{
public:
virtual ~DecodingSetupParams() = default;
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<bool>> outputLogProbs; // [setupBatchSize]
std::optional<std::vector<bool>> cumLogProbs; // [setupBatchSize]
};
class SamplingSetupParams : public DecodingSetupParams
{
public:
// baseSamplingLayer
std::optional<std::vector<runtime::SizeType32>> runtimeTopK; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<float>> runtimeTopP; // [1] or [setupBatchSize] on cpu
// topPSamplingLayer
std::optional<std::vector<float>> topPDecay; // [setupBatchSize], must be in [0, 1]
std::optional<std::vector<float>> topPMin; // [setupBatchSize], must be in [0, 1]
std::optional<std::vector<runtime::TokenIdType>> topPResetIds; // [setupBatchSize]
std::optional<bool> normalizeLogProbs;
};
class BeamSearchSetupParams : public DecodingSetupParams
{
public:
// BeamSearchLayer
std::optional<std::vector<float>> beamSearchDiversityRate; // [setupBatchSize] on cpu
std::optional<std::vector<float>> lengthPenalty; // [setupBatchSize] on cpu
std::optional<std::vector<int>> earlyStopping; // [setupBatchSize] on cpu
bool hasDiffRuntimeArgs{false};
};
class MedusaSetupParams : public DecodingSetupParams
{
public:
// Medusa params
std::optional<std::vector<runtime::SizeType32>> runtimeTopK; // [setupBatchSize] on cpu
std::optional<std::vector<std::vector<runtime::SizeType32>>> runtimeHeadsTopK; // [setupBatchSize, maxMedusaHeads]
};
class ExplicitDraftTokensSetupParams : public DecodingSetupParams
{
public:
std::optional<std::vector<float>> temperature; // [setupBatchSize] on cpu
// Hack to init some data for the context phase in the setup.
tc::Tensor randomDataSample; // [maxBatchSize], on gpu
tc::Tensor temperatures; // [maxBatchSize], on gpu
};
class DynamicDecodeSetupParams : public BaseSetupParams
{
public:
// Penalty layer
struct PenaltyParams
{
std::optional<std::vector<float>> temperature; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<runtime::SizeType32>> minLength; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<float>> repetitionPenalty; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<float>> presencePenalty; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<float>> frequencyPenalty; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<runtime::SizeType32>> noRepeatNgramSize; // [1] or [setupBatchSize] on cpu
};
std::shared_ptr<PenaltySetupParams> penaltyParams;
struct SamplingParams
{
// baseSamplingLayer
std::optional<std::vector<runtime::SizeType32>> runtime_top_k; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<float>> runtime_top_p; // [1] or [setupBatchSize] on cpu
std::shared_ptr<BanWordsSetupParams> banWordsParams;
// topPSamplingLayer
std::optional<std::vector<float>> top_p_decay; // [setupBatchSize], must between [0, 1]
std::optional<std::vector<float>> top_p_min; // [setupBatchSize], must between [0, 1]
std::optional<std::vector<runtime::TokenIdType>> top_p_reset_ids; // [setupBatchSize]
std::optional<bool> normalize_log_probs;
std::optional<std::vector<bool>> outputLogProbs; // [setupBatchSize]
std::optional<std::vector<bool>> cumLogProbs; // [setupBatchSize]
};
struct BeamSearchParams
{
// BeamSearchLayer
std::optional<std::vector<float>> beam_search_diversity_rate; // [setupBatchSize] on cpu
std::optional<std::vector<float>> length_penalty; // [setupBatchSize] on cpu
std::optional<std::vector<int>> early_stopping; // [setupBatchSize] on cpu
};
struct MedusaParams
{
// Medusa params
std::optional<std::vector<std::vector<runtime::SizeType32>>>
topKMedusaHeads; // [setupBatchSize, maxMedusaHeads]
};
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [setupBatchSize] on cpu
PenaltyParams penaltyParams;
SamplingParams samplingParams;
BeamSearchParams beamSearchParams;
MedusaParams medusaParams;
std::shared_ptr<DecodingSetupParams> decodingParams;
};
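// Illustrative sketch of populating the refactored setup params above (the concrete values
// below are assumptions for the example, not defaults):
//   auto setupParams = std::make_shared<DynamicDecodeSetupParams>();
//   setupParams->penaltyParams = std::make_shared<PenaltySetupParams>();
//   setupParams->penaltyParams->temperature = std::vector<float>{0.7f};
//   setupParams->banWordsParams = std::make_shared<BanWordsSetupParams>();
//   auto samplingParams = std::make_shared<SamplingSetupParams>();
//   samplingParams->runtimeTopK = std::vector<runtime::SizeType32>{40};
//   setupParams->decodingParams = samplingParams; // SamplingSetupParams derives from DecodingSetupParams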
class BaseInputParams
class LookaheadSetupParams : public DecodingSetupParams
{
public:
explicit BaseInputParams(runtime::SizeType32 step, runtime::SizeType32 ite, tc::Tensor endIds)
: step{step}
, ite{ite}
, end_ids{std::move(endIds)}
std::vector<runtime::ITensor::SharedConstPtr> prompt; // [batchSize][maxSeqLen] on cpu
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [batchSize] on cpu
std::vector<executor::LookaheadDecodingConfig> algoConfigs; // [1 or batchSize] on cpu
};
class BaseDecodingInputs
{
public:
BaseDecodingInputs(runtime::SizeType32 localBatchSize)
: localBatchSize(localBatchSize)
{
}
virtual ~BaseInputParams() {}
virtual ~BaseDecodingInputs() = default;
// mandatory parameters
runtime::SizeType32 localBatchSize;
};
// Ban words inputs
class BanWordsDecodingInputs : public BaseDecodingInputs
{
public:
BanWordsDecodingInputs(runtime::SizeType32 localBatchSize)
: BaseDecodingInputs(localBatchSize)
{
}
runtime::SizeType32 maxBadWordsLen{0};
//! [maxBatchSize][2, bad_words_length], on gpu
std::optional<tc::Tensor> badWordsPtr;
//! [maxBatchSize], on gpu
std::optional<tc::Tensor> badWordsLengths;
};
// Stop criteria inputs
class StopCriteriaDecodingInputs : public BaseDecodingInputs
{
public:
StopCriteriaDecodingInputs(runtime::SizeType32 localBatchSize)
: BaseDecodingInputs(localBatchSize)
{
}
runtime::SizeType32 maxStopWordsLen{0};
//! [maxBatchSize], on gpu
std::optional<tc::Tensor> sequenceLimitLength;
//! [maxBatchSize][2, stop_words_length], on gpu
std::optional<tc::Tensor> stopWordsPtr;
//! [maxBatchSize], on gpu
std::optional<tc::Tensor> stopWordsLengths;
};
class DecodingInputs : public BaseDecodingInputs
{
public:
DecodingInputs(tc::Tensor endIds, runtime::SizeType32 step = 0, runtime::SizeType32 ite = 0,
runtime::SizeType32 localBatchSize = 0, runtime::SizeType32 maxAttentionWindow = 0,
runtime::SizeType32 sinkTokenLength = 0)
: BaseDecodingInputs(localBatchSize)
, endIds{std::move(endIds)}
, step{step}
, ite{ite}
, maxAttentionWindow{maxAttentionWindow}
, sinkTokenLength{sinkTokenLength}
{
}
//! [maxBatchSize]
tc::Tensor endIds;
// used only for python runtime
runtime::SizeType32 step;
runtime::SizeType32 ite;
tc::Tensor end_ids; // [maxBatchSize]
std::optional<tc::Tensor> batch_slots; // [forwardBatchSize], on pinned memory
std::optional<tc::Tensor> finished; // [maxBatchSize, maxBeamWidth]
};
class DynamicDecodeInputParams : public BaseInputParams
{
public:
DynamicDecodeInputParams(runtime::SizeType32 step, runtime::SizeType32 ite, runtime::SizeType32 maxInputLength,
runtime::SizeType32 maxAttentionWindow, runtime::SizeType32 sinkTokenLength, runtime::SizeType32 localBatchSize,
tc::Tensor endIds)
: BaseInputParams(step, ite, std::move(endIds))
, max_input_length{maxInputLength}
, max_attention_window{maxAttentionWindow}
, sink_token_length{sinkTokenLength}
, local_batch_size{localBatchSize}
, max_stop_words_len{0}
, max_bad_words_len{0}
{
}
// mandatory parameters
runtime::SizeType32 max_input_length;
runtime::SizeType32 max_attention_window;
runtime::SizeType32 sink_token_length;
runtime::SizeType32 local_batch_size;
runtime::SizeType32 max_stop_words_len;
runtime::SizeType32 max_bad_words_len;
runtime::SizeType32 maxAttentionWindow;
runtime::SizeType32 sinkTokenLength;
// One of these two fields has to be set
// DynamicDecodeLayer::forward checks for it
// Need both of these fields to support legacy code during transition period to the batched decoder
std::optional<tc::Tensor> logits; // [maxBatchSize, beamWidth, vocabSizePadded]
std::optional<std::vector<tc::Tensor>> logits_vec; // [forwardBatchSize][beamWidth, vocabSizePadded], on gpu
//! One of these two fields has to be set
//! DynamicDecodeLayer::forward checks for it
//! Need both of these fields to support legacy code during transition period to the batched decoder
//! [forwardBatchSize, beamWidth, vocabSizePadded]
std::optional<tc::Tensor> logits;
//! [forwardBatchSize][beamWidth, vocabSizePadded], on gpu
std::optional<std::vector<tc::Tensor>> logitsVec;
// optional parameters
std::optional<tc::Tensor> src_cache_indirection; // [forwardBatchSize, maxBeamWidth, maxSeqLen] - the k/v cache
// index for beam search, mandatory for beam search, on gpu
std::optional<tc::Tensor> sequence_limit_length; // [maxBatchSize], on gpu
std::optional<tc::Tensor> embedding_bias; // [vocabSizePadded], on gpu
std::optional<tc::Tensor> input_lengths; // [maxBatchSize, maxBeamWidth], on gpu
std::optional<tc::Tensor> bad_words_ptr; // [maxBatchSize][2, bad_words_length], on gpu
std::optional<tc::Tensor> bad_words_lengths; // [maxBatchSize], on gpu
std::optional<tc::Tensor> stop_words_ptr; // [maxBatchSize][2, stop_words_length], on gpu
std::optional<tc::Tensor> stop_words_lengths; // [maxBatchSize], on gpu
//! the indices of the selected beams, mandatory for beam search, on gpu
//! [forwardBatchSize, maxBeamWidth, maxSeqLen]
std::optional<tc::Tensor> srcCacheIndirection;
//! [vocabSizePadded], on gpu
std::optional<tc::Tensor> embeddingBias;
//! [maxBatchSize, maxBeamWidth], on gpu
std::optional<tc::Tensor> inputLengths;
//! [forwardBatchSize], on pinned memory
std::optional<tc::Tensor> batchSlots;
//! [maxBatchSize, maxBeamWidth]
std::optional<tc::Tensor> finished;
//! [maxBatchSize], on gpu
std::optional<tc::Tensor> curTokensPerStep;
// Medusa inputs
class MedusaInputs
{
public:
tc::Tensor medusaCurTokensPerStep; // [maxBatchSize], optional, on gpu
tc::Tensor medusaTargetTokensPerStep; // [maxBatchSize], optional, on gpu
tc::Tensor medusaPaths; // [maxBatchSize, maxPathLen, maxPathLen]
// optional, on gpu
tc::Tensor medusaTreeIds; // [maxBatchSize, maxDecodingTokens], optional, on gpu
std::vector<std::vector<tc::Tensor>> medusaLogits; // [maxBatchSize][maxDraftPathLen]
// [maxDecodingTokens, vocabSizePadded], optional, on gpu
};
std::shared_ptr<BanWordsDecodingInputs> banWordsInputs;
// Explicit draft tokens inputs
// FIXME(nkorobov): this should be ExplicitDraftTokensBuffers?
class ExplicitDraftTokensInputs
{
public:
};
std::optional<MedusaInputs> medusaInputs;
std::optional<ExplicitDraftTokensInputs> explicitDraftTokensInputs;
std::shared_ptr<StopCriteriaDecodingInputs> stopCriteriaInputs;
};
class BaseOutputParams
class SamplingInputs : public DecodingInputs
{
public:
explicit BaseOutputParams(tc::Tensor outputIds)
: output_ids{std::move(outputIds)}
explicit SamplingInputs(
tc::Tensor endIds, runtime::SizeType32 step, runtime::SizeType32 ite, runtime::SizeType32 localBatchSize)
: DecodingInputs{std::move(endIds), step, ite, localBatchSize}
{
}
virtual ~BaseOutputParams() {}
//! optional parameters
//! [localBatchSize]
curandState_t* curandStates{};
//! Pointer to the workspace for sampling computation
void* samplingWorkspace{};
//! Flag to mark that logits tensor contains probabilities
bool probsComputed{};
};
// Medusa inputs
class MedusaDecodingInputs : public DecodingInputs
{
public:
explicit MedusaDecodingInputs(tc::Tensor endIds, runtime::SizeType32 localBatchSize)
: DecodingInputs(std::move(endIds), 0, 0, localBatchSize)
{
}
//! [maxBatchSize], on gpu
tc::Tensor targetTokensPerStep;
//! [maxBatchSize, maxPathLen, maxPathLen], on gpu
tc::Tensor paths;
//! [maxBatchSize, maxDecodingTokens], on gpu
tc::Tensor treeIds;
//! [maxBatchSize][maxDraftPathLen][maxDecodingTokens, vocabSizePadded], on gpu
std::vector<std::vector<tc::Tensor>> medusaLogits;
};
// Explicit draft tokens inputs
class ExplicitDraftTokensInputs : public DecodingInputs
{
public:
explicit ExplicitDraftTokensInputs(tc::Tensor endIds, runtime::SizeType32 batchSize)
: DecodingInputs(std::move(endIds), 0, 0, batchSize)
{
}
//! Draft tokens for the next iteration. The first token in each path is the last accepted token at current
//! iteration. E.g. if forwardBatchSize == 1, maxNumPaths == 2, maxPathLen== 3, [[[0, 1, 2], [0, 1, 10]]]
tc::Tensor nextDraftTokens; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Compressed form of `nextDraftTokens`, where common prefixes are collapsed.
//! Using example above [0, 1, 2, 10]
tc::Tensor nextFlatTokens; // [forwardBatchSize * maxDecodingTokens], gpu
//! Indices of draft tokens in the compressed `nextFlatTokens` for the next iteration.
//! Using example above, [[[0, 1, 2], [0, 1, 3]]]
tc::Tensor nextDraftIndices; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Probabilities of the next draft tokens.
tc::Tensor nextDraftProbs; // [forwardBatchSize, maxNumPaths, maxDraftPathLen, vocabSize], gpu
//! Same as `nextDraftTokens`, but for current iteration.
//! Current accepted tokens obtained as `lastDraftTokens[bi][bestPathIndices[bi]][1:bestPathLengths[bi]]`.
tc::Tensor lastDraftTokens; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Same as `nextDraftIndices`, but for current iteration.
tc::Tensor lastDraftIndices; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Boolean attention masks.
//! maxDecodingTokens' = generationLengths.max()
tc::Tensor masks; // [forwardBatchSize, maxDecodingTokens', maxDecodingTokens'], gpu
//! Position ids relative to `positionIdsBase`, in the same flattened layout as `nextFlatTokens`, for the next draft indices.
//! Using example above, [0, 1, 2, 3]
tc::Tensor packedPosIds; // [forwardBatchSize * maxDecodingTokens], gpu
//! Lengths of the accepted paths for each request. It is 1 for the context phase (only 1 primary token is accepted).
tc::Tensor bestPathLengths; // [forwardBatchSize], gpu
//! Indices of the accepted paths for each request. It is 0 for context phase.
tc::Tensor bestPathIndices; // [forwardBatchSize], gpu
//! Number of the draft tokens for the next iteration.
tc::Tensor generationLengths; // [forwardBatchSize], gpu
//! Baseline for the position ids.
tc::Tensor positionIdsBase; // [forwardBatchSize], gpu
//! Generation length for the previous stage.
tc::Tensor lastGenerationLengths; // [forwardBatchSize], gpu
//! Maximum number of generated tokens for the next step across whole batch
tc::Tensor maxGenLengthDevice; // [1], on gpu
//! Maps linear indices of the engine outputs to their seqSlot.
//! It is not the same as batchSlots: it maps the ordered engine outputs to the respective seqSlot,
//! while batchSlots is just a list of active seqSlots.
tc::Tensor seqSlots; // [forwardBatchSize], on gpu
};
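// Illustrative sketch relating the compressed and unpacked forms documented above, using the
// values from the comments (a standalone host-side example, not engine code):
//   std::vector<runtime::TokenIdType> nextFlatTokens{0, 1, 2, 10};            // shared prefix [0, 1] stored once
//   std::vector<std::vector<runtime::SizeType32>> nextDraftIndices{{0, 1, 2}, {0, 1, 3}};
//   std::vector<std::vector<runtime::TokenIdType>> unpacked;                  // -> [[0, 1, 2], [0, 1, 10]]
//   for (auto const& path : nextDraftIndices)
//   {
//       std::vector<runtime::TokenIdType> tokens;
//       for (auto idx : path)
//       {
//           tokens.push_back(nextFlatTokens[idx]);
//       }
//       unpacked.push_back(tokens);
//   }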
class LookaheadDecodingInputs : public DecodingInputs
{
using TensorConstPtr = runtime::ITensor::SharedConstPtr;
public:
explicit LookaheadDecodingInputs(tc::Tensor endIds)
: DecodingInputs{std::move(endIds)}
//, logits{logits}
{
}
// TODO(liweim) reuse base logits and curTokensPerStep.
// TensorConstPtr logits; // [batchSize, maxTokensPerStep, vocabSizePadded] on gpu
// TensorConstPtr tokensPerStep; // [maxBatchSize] on gpu
};
class BaseDecodingOutputs
{
public:
explicit BaseDecodingOutputs(tc::Tensor outputIds)
: outputIds{std::move(outputIds)}
{
}
virtual ~BaseDecodingOutputs() = default;
// mandatory parameters
tc::Tensor output_ids; // [maxBatchSize, maxSeqLen]
tc::Tensor outputIds; // [maxBatchSize, maxSeqLen]
// optional parameters
std::optional<tc::Tensor> finished; // [maxBatchSize * maxBeamWidth], optional
std::optional<tc::Tensor> sequence_length; // [maxBatchSize * maxBeamWidth], optional
std::optional<tc::Tensor> cum_log_probs; // [maxBatchSize * maxBeamWidth], necessary in beam search
std::optional<tc::Tensor> output_log_probs; // [maxBatchSize, maxBeamWidth, maxSeqLen], must be float*, optional
std::optional<tc::Tensor> parent_ids; // [maxBatchSize, maxBeamWidth, maxSeqLen], necessary in beam search
//! [maxBatchSize * maxBeamWidth], optional
std::optional<tc::Tensor> finished;
//! [maxBatchSize * maxBeamWidth], optional
std::optional<tc::Tensor> sequenceLength;
//! [maxBatchSize * maxBeamWidth], necessary in beam search
std::optional<tc::Tensor> cumLogProbs;
//! [maxBatchSize, maxBeamWidth, maxSeqLen], must be float*, optional
std::optional<tc::Tensor> outputLogProbs;
//! [maxBatchSize, maxBeamWidth, maxSeqLen], necessary in beam search
std::optional<tc::Tensor> parentIds;
tc::Tensor output_ids_ptr; // [maxBatchSize] int* (2-d array), each int* has [maxBeamWidth, maxSeqLen]
//! [maxBatchSize] int* (2-d array), each int* has [maxBeamWidth, maxSeqLen]
tc::Tensor outputIdsPtr;
//! [maxBatchSize] int* (2-d array), each int* has [maxBeamWidth, maxSeqLen]
tc::Tensor parentIdsPtr;
//!
//! \brief SpeculativeDecodingOutputs outputs.
//!
//! For one example sequence [a, b] [c] <x, y, z>, where, [a, b, c] is the accepted sequence,
//! [c] is the last accepted token, and <x, y, z> is the draft tokens from `nextDraftTokens` saved by last step.
//! [c]'s position id is known, only position ids for <x, y, z> need to be provided in `nextDraftPosIds`.
//! LLM inputs {c, x, y, z} and generates {c', x', y', z'}.
//!
//! {c'} is always accepted, and in this example {x', z'} are accepted as well.
//! The accepted tokens [c', x', z'] are saved in `output_ids` in-place, starting from `sequence_length`.
//! The `acceptedLength` is 3, and the number of accepted draft tokens is 2.
//! `sequence_length` is also increased by `acceptedLength` in-place.
//! The pathsOffset is {0, 1, 3} for {c', x', z'}.
//! [] for accepted, <> for draft, {} for input/output.
//!
//! For a batchSlots {1, 3}, `acceptedLengthsCumSum` is an exclusive sum of `acceptedLength` over the batch,
//! the `acceptedLengths` may be {3, 5}, `acceptedLengthsCumSum` is {0, 3, 8}.
class SpeculativeDecodingOutputs
{
public:
tc::Tensor nextDraftTokens; // [maxBatchSize, maxDecodingDraftTokens], draft tokens for the next step
tc::Tensor nextDraftPosIds; // [maxBatchSize, maxDecodingDraftTokens], draft token position IDs
tc::Tensor nextDraftLengths; // [maxBatchSize], next step draft tokens lengths
tc::Tensor acceptedLengths; // [maxBatchSize], lengths of the accepted draft tokens + 1.
tc::Tensor acceptedLengthsCumSum; // [maxBatchSize + 1] accumulative sum along batchSlots.
tc::Tensor pathsOffsets; // [maxBatchSize, maxPathLen]
tc::Tensor packedMasks; // [maxBatchSize, maxDecodingTokens, divUp(maxDecodingTokens, 32)]
};
class ExplicitDraftTokensOutputs : public SpeculativeDecodingOutputs
{
public:
//! Draft tokens for the next iteration. The first token in each path is the last accepted token at current
//! iteration. E.g. if batchSize == 1, maxNumPaths == 2, maxPathLen== 3, [[[0, 1, 2], [0, 1, 10]]]
tc::Tensor unpackedNextDraftTokens; // [maxBatchSize, maxNumPaths, maxPathLen] on gpu
//! Indices of draft tokens in the compressed `nextFlatTokens` for the next iteration.
//! Using example above, [[[0, 1, 2], [0, 1, 3]]]
tc::Tensor unpackedNextDraftIndices; // [maxBatchSize, maxNumPaths, maxPathLen] on gpu
//! Probabilities of the next draft tokens.
tc::Tensor nextDraftProbs; // [maxBatchSize, maxNumPaths, maxPathDraftLen, vocabSize] on gpu
//! Baseline for the position ids.
tc::Tensor positionIdsBase; // [maxBatchSize] on gpu
//! Randomly sampled data (between 0.f and 1.f)
tc::Tensor randomDataSample; // [maxBatchSize] on gpu
//! Randomly sampled data (between 0.f and 1.f)
tc::Tensor randomDataValidation; // [maxBatchSize, maxNumPaths, maxDraftPathLen] on gpu
//! Sampling temperature.
tc::Tensor temperatures; // [maxBatchSize] on gpu
};
std::optional<SpeculativeDecodingOutputs> speculativeDecodingOutputs;
std::optional<ExplicitDraftTokensOutputs> explicitDraftTokensOutputs;
};
class DynamicDecodeOutputParams : public BaseOutputParams
{
public:
explicit DynamicDecodeOutputParams(tc::Tensor outputIds)
: BaseOutputParams{std::move(outputIds)}
{
}
// mandatory parameters
// Tokens predicted at current iteration.
tc::Tensor newTokens; // [maxBatchSize, maxBeamWidth]
// optional parameters
std::optional<tc::Tensor> finished_sum; // [1] in pinned host memory
std::optional<tc::Tensor> output_log_probs_tiled; // [maxSeqLen, maxBatchSize, maxBeamWidth], must be float*
std::optional<tc::Tensor>
tgt_cache_indirection; // [forwardBatchSize, maxBeamWidth, maxSeqLen], the k/v cache index for beam search
std::unique_ptr<kernels::BeamHypotheses> beamHypotheses; // structure maintains some pointers of beam search
tc::Tensor parent_ids_ptr; // [maxBatchSize] int* (2-d array), each int* has [maxBeamWidth, maxSeqLen]
// optional parameters
//! Number of tokens predicted at current iteration.
//! [maxBatchSize]
std::optional<tc::Tensor> numNewTokens;
//! [1] in pinned host memory
std::optional<tc::Tensor> finishedSum;
//! [maxSeqLen, maxBatchSize, maxBeamWidth], must be float*
std::optional<tc::Tensor> outputLogProbsTiled;
};
class BeamSearchOutputs : public BaseDecodingOutputs
{
public:
explicit BeamSearchOutputs(tc::Tensor outputIds)
: BaseDecodingOutputs{std::move(outputIds)}
{
}
//! the k/v cache index for beam search
//! [forwardBatchSize, maxBeamWidth, maxSeqLen]
tc::Tensor tgtCacheIndirection;
//! structure maintains some pointers of beam search
std::unique_ptr<kernels::BeamHypotheses> beamHypotheses;
};
//!
//! \brief SpeculativeDecodingOutputs outputs.
//!
//! For an example sequence [a, b] [c] <x, y, z>, where [a, b, c] is the accepted sequence,
//! [c] is the last accepted token, and <x, y, z> is the draft tokens from `nextDraftTokens` saved by last step.
//! [c]'s position id is known, only position ids for <x, y, z> need to be provided in `nextDraftPosIds`.
//! LLM inputs {c, x, y, z} and generates {c', x', y', z'}.
//!
//! {c'} is always accepted, and in this example {x', z'} are accepted as well.
//! The accepted tokens [c', x', z'] are saved in `outputIds` in-place, starting from `sequenceLength`.
//! The `acceptedLength` is 3, and the number of accepted draft tokens is 2.
//! `sequenceLength` is also increased by `acceptedLength` in-place.
//! The pathsOffset is {0, 1, 3} for {c', x', z'}.
//! [] for accepted, <> for draft, {} for input/output.
//!
//! For a batchSlots of {1, 3}, `numNewTokensCumSum` is an exclusive sum of `numNewTokens` over the batch;
//! e.g. if `numNewTokens` is {3, 5}, then `numNewTokensCumSum` is {0, 3, 8}.
//!
//! `nextDraftLengths` and `prevDraftLengths` are needed for methods that support variable
//! draft length. `nextDraftLengths` must contain the number of draft tokens per request for the next iteration.
//! `prevDraftLengths` must contain the number of draft tokens used in the current iteration.
//!
//! `pathsOffsets` is needed for KV cache rewind. It contains the positions of the accepted draft tokens in the
//! flattened tensor of draft tokens. E.g. if for sequence {c, x, y, z} only `y` and `z` were accepted,
//! `pathsOffsets` contains [1, 2]. `pathsOffsets` is a flattened tensor for the whole batch.
//!
//! The order of `pathsOffsets` and `numNewTokensCumSum` must be aligned, such that
//! `pathsOffsets[numNewTokensCumSum[bi]:numNewTokensCumSum[bi+1]]` is the slice of offsets for the `bi`th request.
//! Furthermore, the order of requests is important and must be aligned with sorted `RuntimeBuffers::seqSlots`,
//! such that the request with the smaller `seqSlot` comes earlier in the tensors.
//! However, this condition usually holds if the method does not expect anything from the engine other than logits.
class SpeculativeDecodingOutputs : public BaseDecodingOutputs
{
public:
explicit SpeculativeDecodingOutputs(tc::Tensor outputIds)
: BaseDecodingOutputs{std::move(outputIds)}
{
}
//! Draft tokens for the next step
// [maxBatchSize, maxDecodingDraftTokens]
tc::Tensor nextDraftTokens;
//! Draft token position IDs
//! [maxBatchSize, maxDecodingDraftTokens]
tc::Tensor nextDraftPosIds;
//! Prev step draft tokens lengths, should be filled only for variable draft length speculative decoding mode
//! [maxBatchSize]
tc::Tensor prevDraftLengths;
//! Next step draft tokens lengths, should be filled only for variable draft length speculative decoding mode
//! [maxBatchSize]
tc::Tensor nextDraftLengths;
//! Accumulative sum along batchSlots.
//! [maxBatchSize + 1]
tc::Tensor numNewTokensCumSum;
//! [maxBatchSize * maxPathLen]
tc::Tensor pathsOffsets;
//! [maxBatchSize, maxDecodingTokens, divUp(maxDecodingTokens, 32)]
tc::Tensor packedMasks;
};
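// Illustrative sketch of the numNewTokensCumSum / pathsOffsets layout documented above
// (standalone host-side example, not engine code):
//   std::vector<runtime::SizeType32> numNewTokens{3, 5};      // per-request accepted lengths
//   std::vector<runtime::SizeType32> numNewTokensCumSum{0};   // exclusive sum -> {0, 3, 8}
//   for (auto n : numNewTokens)
//   {
//       numNewTokensCumSum.push_back(numNewTokensCumSum.back() + n);
//   }
//   // offsets for request bi: pathsOffsets[numNewTokensCumSum[bi]] .. pathsOffsets[numNewTokensCumSum[bi + 1] - 1]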
class ExplicitDraftTokensOutputs : public SpeculativeDecodingOutputs
{
public:
explicit ExplicitDraftTokensOutputs(tc::Tensor outputIds)
: SpeculativeDecodingOutputs{std::move(outputIds)}
{
}
//! Draft tokens for the next iteration. The first token in each path is the last accepted token at current
//! iteration. E.g. if batchSize == 1, maxNumPaths == 2, maxPathLen== 3, [[[0, 1, 2], [0, 1, 10]]]
tc::Tensor unpackedNextDraftTokens; // [maxBatchSize, maxNumPaths, maxPathLen] on gpu
//! Indices of draft tokens in the compressed `nextFlatTokens` for the next iteration.
//! Using example above, [[[0, 1, 2], [0, 1, 3]]]
tc::Tensor unpackedNextDraftIndices; // [maxBatchSize, maxNumPaths, maxPathLen] on gpu
//! Probabilities of the next draft tokens.
tc::Tensor nextDraftProbs; // [maxBatchSize, maxNumPaths, maxPathDraftLen, vocabSize] on gpu
//! Baseline for the position ids.
tc::Tensor positionIdsBase; // [maxBatchSize] on gpu
//! Randomly sampled data (between 0.f and 1.f)
tc::Tensor randomDataSample; // [maxBatchSize] on gpu
//! Randomly sampled data (between 0.f and 1.f)
tc::Tensor randomDataValidation; // [maxBatchSize, maxNumPaths, maxDraftPathLen] on gpu
//! Sampling temperature.
tc::Tensor temperatures; // [maxBatchSize] on gpu
//! Next generation lengths.
tc::Tensor generationLengths; // [maxBatchSize] on gpu
//! Maximum number of generated tokens for the next step across whole batch
tc::Tensor maxGenLengthHost; // [1] on pinned
};
} // namespace tensorrt_llm::layers

View File

@ -15,22 +15,20 @@
*/
#include "tensorrt_llm/layers/dynamicDecodeLayer.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/layers/beamSearchLayer.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/layers/layersFactory.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include <optional>
#include <utility>
using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{
template <typename T>
@ -111,17 +109,18 @@ void DynamicDecodeLayer<T>::initializeLayers()
template <typename T>
void DynamicDecodeLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> baseSetupParams)
std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto setupParams = std::dynamic_pointer_cast<DynamicDecodeSetupParams>(baseSetupParams);
if (setupParams->samplingParams.outputLogProbs)
TLLM_CHECK_WITH_INFO(setupParams->decodingParams, "decodingParams for setup is not set");
if (setupParams->decodingParams->outputLogProbs)
{
// FIXME(nkorobov): monotonically growing
mOutputLogProbs = std::any_of(setupParams->samplingParams.outputLogProbs->begin(),
setupParams->samplingParams.outputLogProbs->end(),
mOutputLogProbs = std::any_of(setupParams->decodingParams->outputLogProbs->begin(),
setupParams->decodingParams->outputLogProbs->end(),
[this](bool outputLogProbs) { return this->mOutputLogProbs | outputLogProbs; });
}
@ -153,20 +152,19 @@ void DynamicDecodeLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, Si
template <typename T>
void DynamicDecodeLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto params = std::dynamic_pointer_cast<DynamicDecodeInputParams>(baseInputs);
auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
auto params = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);
TLLM_CHECK_WITH_INFO(mDecodingMode.isExplicitDraftTokens() || params->logits || params->logits_vec,
"If not explicit Draft Tokens mode, either logits or logits_vec have to be specified.");
TLLM_CHECK_WITH_INFO(mDecodingMode.isExplicitDraftTokens() || params->logits || params->logitsVec,
"If not explicit Draft Tokens mode, either logits or logitsVec have to be specified.");
TLLM_CHECK_WITH_INFO(
outputs->sequence_length.has_value(), "sequence_length tensor is required in DynamicDecoderLayer.");
baseOutputs->sequenceLength.has_value(), "sequenceLength tensor is required in DynamicDecoderLayer.");
auto const localDecoderDomain = getLocalDecoderDomain(params, mDecoderDomain);
auto const maxSeqLen = outputs->output_ids.shape[outputs->output_ids.shape.size() - 1];
auto const maxSeqLen = baseOutputs->outputIds.shape[baseOutputs->outputIds.shape.size() - 1];
TLLM_CHECK_WITH_INFO((mConfiguredBeamWidth == 1 && localDecoderDomain.getBeamWidth() == 1)
|| (mConfiguredBeamWidth > 1 && localDecoderDomain.getBeamWidth() > 1
@ -185,12 +183,12 @@ void DynamicDecodeLayer<T>::forwardAsync(
std::vector<SizeType32> batchSlotsVec(localDecoderDomain.getBatchSize());
std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0);
auto batchSlotsHost
= params->batch_slots ? params->batch_slots->template getPtr<SizeType32 const>() : batchSlotsVec.data();
auto batchSlots = params->batch_slots ? params->batch_slots->template getPtr<SizeType32 const>() : nullptr;
= params->batchSlots ? params->batchSlots->template getPtr<SizeType32 const>() : batchSlotsVec.data();
auto batchSlots = params->batchSlots ? params->batchSlots->template getPtr<SizeType32 const>() : nullptr;
mCyclicStep = mCyclicStep % mRuntimeMaxSeqLen;
prepareIdsPtrs(
outputs, batchSlotsHost, localDecoderDomain.getBatchSize(), localDecoderDomain.getBeamWidth(), maxSeqLen);
baseOutputs, batchSlotsHost, localDecoderDomain.getBatchSize(), localDecoderDomain.getBeamWidth(), maxSeqLen);
for (auto& layer : mLayers)
{
@ -198,7 +196,7 @@ void DynamicDecodeLayer<T>::forwardAsync(
}
// Copy nextIds and transpose logits when needed
prepareOutputData(outputs, params, mIdsPtrHost, batchSlots, localDecoderDomain.getBatchSize(),
prepareOutputData(baseOutputs, params, mIdsPtrHost, batchSlots, localDecoderDomain.getBatchSize(),
mDecoderDomain.getBatchSize(), localDecoderDomain.getBeamWidth(), maxSeqLen,
mDecoderDomain.getMaxDecodingTokens(), mCyclicStep, mOutputLogProbs, mStream);
@ -210,7 +208,7 @@ void DynamicDecodeLayer<T>::forwardAsync(
template <typename T>
void DynamicDecodeLayer<T>::forwardSync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
for (auto& layer : mLayers)
@ -221,7 +219,7 @@ void DynamicDecodeLayer<T>::forwardSync(
}
template <typename T>
void DynamicDecodeLayer<T>::prepareIdsPtrs(std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
void DynamicDecodeLayer<T>::prepareIdsPtrs(std::shared_ptr<BaseDecodingOutputs> const& outputs,
SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 beamWidth, SizeType32 maxSeqLen)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@ -231,7 +229,7 @@ void DynamicDecodeLayer<T>::prepareIdsPtrs(std::shared_ptr<DynamicDecodeOutputPa
{
auto const batchSlot = batchSlots[bi];
idsPtrHost[batchSlot]
= outputs->output_ids.template getPtrWithOffset<TokenIdType>(batchSlot * beamWidth * maxSeqLen);
= outputs->outputIds.template getPtrWithOffset<TokenIdType>(batchSlot * beamWidth * maxSeqLen);
}
for (SizeType32 bi = 0; bi < batchSize; bi++)
{
@ -239,7 +237,7 @@ void DynamicDecodeLayer<T>::prepareIdsPtrs(std::shared_ptr<DynamicDecodeOutputPa
if (beamWidth > 1)
{
idsPtrHost[mDecoderDomain.getBatchSize() + batchSlot]
= outputs->parent_ids.value().template getPtrWithOffset<SizeType32>(bi * beamWidth * maxSeqLen);
= outputs->parentIds.value().template getPtrWithOffset<SizeType32>(bi * beamWidth * maxSeqLen);
}
else
{
@ -247,11 +245,11 @@ void DynamicDecodeLayer<T>::prepareIdsPtrs(std::shared_ptr<DynamicDecodeOutputPa
}
}
outputs->output_ids_ptr = Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR,
outputs->outputIdsPtr = Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR,
{static_cast<size_t>(mDecoderDomain.getBatchSize()), static_cast<size_t>(beamWidth),
static_cast<size_t>(maxSeqLen)},
idsPtrHost);
outputs->parent_ids_ptr = Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR,
outputs->parentIdsPtr = Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR,
{static_cast<size_t>(mDecoderDomain.getBatchSize()), static_cast<size_t>(beamWidth),
static_cast<size_t>(maxSeqLen)},
idsPtrHost + mDecoderDomain.getBatchSize());
@ -259,29 +257,28 @@ void DynamicDecodeLayer<T>::prepareIdsPtrs(std::shared_ptr<DynamicDecodeOutputPa
}
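// Illustrative sketch of the pointer table prepared above (host-side pseudo-usage; the
// index names s/b/t are assumptions for the example):
//   TokenIdType* base = outputs->outputIds.getPtr<TokenIdType>();
//   // entry for global slot s points at that request's [beamWidth, maxSeqLen] block
//   TokenIdType* slotPtr = base + s * beamWidth * maxSeqLen;   // == idsPtrHost[s]
//   // token of beam b at step t for request s
//   TokenIdType token = slotPtr[b * maxSeqLen + t];
// The second half of idsPtrHost (offset by mDecoderDomain.getBatchSize()) holds the
// analogous pointers into parentIds when beamWidth > 1.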
template <typename T>
void DynamicDecodeLayer<T>::prepareOutputData(std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& params, runtime::ITensor::SharedPtr const& idsPtrsHost,
void DynamicDecodeLayer<T>::prepareOutputData(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<DecodingInputs> const& params, runtime::ITensor::SharedPtr const& idsPtrsHost,
SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 beamWidth,
SizeType32 maxSeqLen, SizeType32 maxTokensPerStep, SizeType32 cyclicStep, bool outputLogProbs, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto idsPtrHostSlice = ITensor::slice(idsPtrsHost, cyclicStep, 1);
auto idsPtrHost = reinterpret_cast<TokenIdType**>(runtime::bufferCast<int64_t>(*idsPtrHostSlice));
auto const numNewTokens = outputs->speculativeDecodingOutputs
? outputs->speculativeDecodingOutputs->acceptedLengths.template getPtr<SizeType32 const>()
: nullptr;
auto const numNewTokens
= outputs->numNewTokens ? outputs->numNewTokens->template getPtr<SizeType32 const>() : nullptr;
invokeCopyNextStepIds(outputs->newTokens.template getPtr<TokenIdType>(), idsPtrHost,
outputs->sequence_length->template getPtr<SizeType32>(), numNewTokens, batchSlots, batchSize, maxBatchSize,
outputs->sequenceLength->template getPtr<SizeType32>(), numNewTokens, batchSlots, batchSize, maxBatchSize,
beamWidth, maxSeqLen, maxTokensPerStep, stream);
// Transpose output log probs from [maxSeqLen, batchSize, beamWidth] to [batchSize, beamWidth, maxSeqLen]
if (outputLogProbs && outputs->output_log_probs_tiled)
if (outputLogProbs && outputs->outputLogProbsTiled)
{
auto logProbsMaxSeqLen = outputs->output_log_probs_tiled.value().shape[0];
auto logProbsMaxSeqLen = outputs->outputLogProbsTiled.value().shape[0];
invokeTransposeLogProbs(outputs->output_log_probs.value().template getPtr<float>(),
outputs->output_log_probs_tiled.value().template getPtr<float>(),
outputs->sequence_length->template getPtr<SizeType32>(), batchSlots, batchSize, maxBatchSize, beamWidth,
invokeTransposeLogProbs(outputs->outputLogProbs.value().template getPtr<float>(),
outputs->outputLogProbsTiled.value().template getPtr<float>(),
outputs->sequenceLength->template getPtr<SizeType32>(), batchSlots, batchSize, maxBatchSize, beamWidth,
logProbsMaxSeqLen, stream);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@ -290,5 +287,4 @@ void DynamicDecodeLayer<T>::prepareOutputData(std::shared_ptr<DynamicDecodeOutpu
template class DynamicDecodeLayer<float>;
template class DynamicDecodeLayer<half>;
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

View File

@ -16,35 +16,14 @@
#pragma once
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/layers/banWordsLayer.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/beamSearchLayer.h"
#include "tensorrt_llm/layers/decodingLayer.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/layers/medusaDecodingLayer.h"
#include "tensorrt_llm/layers/penaltyLayer.h"
#include "tensorrt_llm/layers/samplingLayer.h"
#include "tensorrt_llm/layers/stopCriteriaLayer.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
namespace tc = tensorrt_llm::common;
namespace tensorrt_llm
{
namespace kernels
{
struct BeamHypotheses;
}
namespace layers
namespace tensorrt_llm::layers
{
template <typename T>
@ -59,11 +38,13 @@ public:
~DynamicDecodeLayer() override;
void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> setupParams) override;
std::shared_ptr<BaseSetupParams> const& setupParams) override;
void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;
void forwardSync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardSync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;
// Function is only used by test.
// It is guaranteed by LayersFactory that the first layer is the Penalty layer.
@ -79,11 +60,10 @@ private:
void initialize();
void initializeLayers();
void prepareIdsPtrs(std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth,
runtime::SizeType32 maxSeqLen);
static void prepareOutputData(std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& params, runtime::ITensor::SharedPtr const& idsPtrsHost,
void prepareIdsPtrs(std::shared_ptr<BaseDecodingOutputs> const& outputs, runtime::SizeType32 const* batchSlots,
runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 maxSeqLen);
static void prepareOutputData(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<DecodingInputs> const& params, runtime::ITensor::SharedPtr const& idsPtrsHost,
runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize, runtime::SizeType32 maxBatchSize,
runtime::SizeType32 beamWidth, runtime::SizeType32 maxSeqLen, runtime::SizeType32 maxTokensPerStep,
runtime::SizeType32 cyclicStep, bool outputLogProbs, cudaStream_t stream);
@ -109,5 +89,4 @@ private:
runtime::SizeType32 mConfiguredBeamWidth{-1};
};
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

View File

@ -18,13 +18,10 @@
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/kernels/penaltyKernels.h"
#include "tensorrt_llm/kernels/penaltyTypes.h"
#include "tensorrt_llm/kernels/speculativeDecoding/explicitDraftTokensKernels.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include <algorithm>
@ -67,23 +64,28 @@ void ExplicitDraftTokensLayer<T>::allocateBuffer()
mTemperature.resize(mDecoderDomain.getBatchSize());
mScanWorkspaceSizeInBytes = invokeScanSpecDecodingGenerationLengths(
mScanWorkspaceSizeInBytes = invokeScanGenerationLengths(
nullptr, mScanWorkspaceSizeInBytes, nullptr, nullptr, mDecoderDomain.getBatchSize(), mStream);
mReduceWorkspaceSizeInBytes = invokeReduceMaxSpecDecodingGenerationLengths(
mReduceWorkspaceSizeInBytes = invokeReduceMaxGenerationLengths(
nullptr, mReduceWorkspaceSizeInBytes, nullptr, nullptr, mDecoderDomain.getBatchSize(), mStream);
mWorkspaceSizeInBytes = std::max(mScanWorkspaceSizeInBytes, mReduceWorkspaceSizeInBytes);
std::array<size_t, 6> deviceBufferSizes
std::array<size_t, 8> deviceBufferSizes
= {sizeof(curandState_t) * mDecoderDomain.getBatchSize(), sizeof(uint64_t) * mDecoderDomain.getBatchSize(),
mWorkspaceSizeInBytes, sizeof(SizeType32) * mDecoderDomain.getBatchSize(), sizeof(SizeType32),
sizeof(float) * mDecoderDomain.getBatchSize()};
sizeof(float) * mDecoderDomain.getBatchSize(), sizeof(SizeType32) * mDecoderDomain.getBatchSize(),
sizeof(SizeType32) * mDecoderDomain.getBatchSize()
* mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths()
* mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen()};
mCurandStatesDevice = mAllocator->reMalloc(mCurandStatesDevice, deviceBufferSizes[0], false);
mRandomSeedsDevice = mAllocator->reMalloc(mRandomSeedsDevice, deviceBufferSizes[1], false);
mWorkspaceDevice = mAllocator->reMalloc(mWorkspaceDevice, deviceBufferSizes[2], false);
mGenerationLengthInclusiveSum = mAllocator->reMalloc(mGenerationLengthInclusiveSum, deviceBufferSizes[3], false);
mMaxGenerationLength = mAllocator->reMalloc(mMaxGenerationLength, deviceBufferSizes[4], false);
mTemperatureDevice = mAllocator->reMalloc(mTemperatureDevice, deviceBufferSizes[5], false);
mBestPathIndicesSlots = mAllocator->reMalloc(mBestPathIndicesSlots, deviceBufferSizes[6], false);
mLastDraftIndicesSlots = mAllocator->reMalloc(mLastDraftIndicesSlots, deviceBufferSizes[7], false);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -99,13 +101,15 @@ void ExplicitDraftTokensLayer<T>::freeBuffer()
mAllocator->free((void**) (&mGenerationLengthInclusiveSum));
mAllocator->free((void**) (&mMaxGenerationLength));
mAllocator->free((void**) (&mTemperatureDevice));
mAllocator->free((void**) (&mBestPathIndicesSlots));
mAllocator->free((void**) (&mLastDraftIndicesSlots));
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void ExplicitDraftTokensLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> baseSetupParams)
std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@ -139,17 +143,40 @@ void ExplicitDraftTokensLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWid
fillBuffers(setupParams->temperature, DefaultDecodingParams::getTemperature(), mTemperature, mTemperatureDevice,
batchSlots, getLimitsPenalty(DecodingPenaltyType::Temperature), "temperature penalty");
fillContextBuffers(batchSize, batchSlots, *setupParams);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void ExplicitDraftTokensLayer<T>::fillContextBuffers(
SizeType32 batchSize, SizeType32 const* batchSlots, ExplicitDraftTokensSetupParams const& setupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
FillContextExplicitDraftTokensParams<T> params;
params.randDataSample = setupParams.randomDataSample.template getPtr<T>();
params.outputTemperatures = setupParams.temperatures.template getPtr<T>();
params.inputTemperatures = mTemperatureDevice;
params.curandState = mCurandStatesDevice;
params.batchSlots = batchSlots;
params.batchSize = batchSize;
params.checkParams();
invokeFillContextBuffers(params, mStream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void ExplicitDraftTokensLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto inputs = std::dynamic_pointer_cast<ExplicitDraftTokensInputParams>(baseInputs);
auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
auto inputs = std::dynamic_pointer_cast<ExplicitDraftTokensInputs>(baseInputs);
auto outputs = std::dynamic_pointer_cast<ExplicitDraftTokensOutputs>(baseOutputs);
// DO NOT CHANGE THE ORDER.
@ -167,23 +194,22 @@ void ExplicitDraftTokensLayer<T>::forwardAsync(
template <typename T>
void ExplicitDraftTokensLayer<T>::convertPackedMask(
DynamicDecodeOutputParams const& outputs, ExplicitDraftTokensInputParams const& inputs)
ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto batchSlots = inputs.batch_slots->template getPtr<SizeType32 const>();
auto batchSlots = inputs.seqSlots.template getPtr<SizeType32 const>();
auto masksDevice = inputs.masks.template getPtr<bool const>();
auto specDecodingGenerationLengths = inputs.specDecodingGenerationLengths.template getPtr<SizeType32 const>();
auto packedMasksDevice = outputs.explicitDraftTokensOutputs->packedMasks.template getPtr<SizeType32>();
auto generationLengths = inputs.generationLengths.template getPtr<SizeType32 const>();
auto packedMasksDevice = outputs.packedMasks.template getPtr<SizeType32>();
auto const batchSize = inputs.batch_slots->shape[0];
auto const batchSize = inputs.localBatchSize;
invokeScanReduceSpecDecodingGenerationLengths(batchSize, specDecodingGenerationLengths, mWorkspaceDevice,
mScanWorkspaceSizeInBytes, mGenerationLengthInclusiveSum, mWorkspaceDevice, mReduceWorkspaceSizeInBytes,
mMaxGenerationLength, mStream);
invokeScanReduceGenerationLengths(batchSize, generationLengths, mWorkspaceDevice, mScanWorkspaceSizeInBytes,
mGenerationLengthInclusiveSum, mWorkspaceDevice, mReduceWorkspaceSizeInBytes, mMaxGenerationLength, mStream);
invokeConvertSpecDecodingMaskToPackedMask(batchSize, mGenerationLengthInclusiveSum, mMaxGenerationLength,
masksDevice, batchSlots, mDecoderDomain.getSpeculativeDecodingModule()->getMaxDecodingDraftTokens(),
invokeConvertMaskToPackedMask(batchSize, mGenerationLengthInclusiveSum, mMaxGenerationLength, masksDevice,
batchSlots, mDecoderDomain.getSpeculativeDecodingModule()->getMaxDecodingDraftTokens(),
mDecoderDomain.getSpeculativeDecodingModule()->getMaxDecodingTokens(), packedMasksDevice, mStream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@ -191,32 +217,34 @@ void ExplicitDraftTokensLayer<T>::convertPackedMask(
template <typename T>
void ExplicitDraftTokensLayer<T>::splitInputDataToBatchSlots(
DynamicDecodeOutputParams const& outputs, ExplicitDraftTokensInputParams const& inputs)
ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = inputs.batch_slots->shape[0];
auto const maxSeqLen = outputs.output_ids.shape[outputs.output_ids.shape.size() - 1];
auto const batchSize = inputs.localBatchSize;
auto const maxSeqLen = outputs.outputIds.shape[outputs.outputIds.shape.size() - 1];
ExtractExplicitDraftTokensParams<T> params;
params.outputIds = outputs.output_ids.template getPtr<TokenIdType>();
params.outputPositionIdsBase = outputs.explicitDraftTokensOutputs->positionIdsBase.template getPtr<SizeType32>();
params.outputPositionIds = outputs.explicitDraftTokensOutputs->nextDraftPosIds.template getPtr<SizeType32>();
params.outputNextDraftTokens = outputs.explicitDraftTokensOutputs->nextDraftTokens.template getPtr<TokenIdType>();
params.unpackedNextDraftTokens
= outputs.explicitDraftTokensOutputs->unpackedNextDraftTokens.template getPtr<TokenIdType>();
params.unpackedNextDraftIndices
= outputs.explicitDraftTokensOutputs->unpackedNextDraftIndices.template getPtr<SizeType32>();
params.acceptedLengths = outputs.explicitDraftTokensOutputs->acceptedLengths.template getPtr<SizeType32>();
params.nextDraftLengths = outputs.explicitDraftTokensOutputs->nextDraftLengths.template getPtr<SizeType32>();
params.sequenceLengths = outputs.sequence_length->template getPtr<SizeType32>();
params.randDataSample = outputs.explicitDraftTokensOutputs->randomDataSample.template getPtr<T>();
params.randDataVerification = outputs.explicitDraftTokensOutputs->randomDataValidation.template getPtr<T>();
params.outputDraftProbs = outputs.explicitDraftTokensOutputs->nextDraftProbs.template getPtr<T>();
params.outputTemperatures = outputs.explicitDraftTokensOutputs->temperatures.template getPtr<T>();
params.outputIds = outputs.outputIds.template getPtr<TokenIdType>();
params.outputPositionIdsBase = outputs.positionIdsBase.template getPtr<SizeType32>();
params.outputPositionIds = outputs.nextDraftPosIds.template getPtr<SizeType32>();
params.outputNextDraftTokens = outputs.nextDraftTokens.template getPtr<TokenIdType>();
params.unpackedNextDraftTokens = outputs.unpackedNextDraftTokens.template getPtr<TokenIdType>();
params.unpackedNextDraftIndices = outputs.unpackedNextDraftIndices.template getPtr<SizeType32>();
params.acceptedLengths = outputs.numNewTokens->template getPtr<SizeType32>();
params.nextDraftLengths = outputs.nextDraftLengths.template getPtr<SizeType32>();
params.prevDraftLengths = outputs.prevDraftLengths.template getPtr<SizeType32>();
params.sequenceLengths = outputs.sequenceLength->template getPtr<SizeType32>();
params.randDataSample = outputs.randomDataSample.template getPtr<T>();
params.randDataVerification = outputs.randomDataValidation.template getPtr<T>();
params.outputDraftProbs = outputs.nextDraftProbs.template getPtr<T>();
params.outputTemperatures = outputs.temperatures.template getPtr<T>();
params.outputGenerationLengths = outputs.generationLengths.template getPtr<SizeType32>();
params.outputBestPathIndices = mBestPathIndicesSlots;
params.outputLastDraftIndices = mLastDraftIndicesSlots;
params.batchSlots = inputs.batch_slots->template getPtr<SizeType32 const>();
params.batchSlots = inputs.seqSlots.template getPtr<SizeType32 const>();
params.nextDraftTokens = inputs.nextDraftTokens.template getPtr<TokenIdType const>();
params.lastDraftTokens = inputs.lastDraftTokens.template getPtr<TokenIdType const>();
params.inputUnpackedNextDraftIndices = inputs.nextDraftIndices.template getPtr<SizeType32 const>();
@ -226,15 +254,32 @@ void ExplicitDraftTokensLayer<T>::splitInputDataToBatchSlots(
params.packedPositionIds = inputs.packedPosIds.template getPtr<SizeType32 const>();
params.nextFlatTokens = inputs.nextFlatTokens.template getPtr<TokenIdType const>();
params.nextDraftProbs = inputs.nextDraftProbs.template getPtr<T const>();
params.lastGenerationLengths = inputs.lastGenerationLengths.template getPtr<SizeType32 const>();
params.generationLengthInclusiveSum = mGenerationLengthInclusiveSum;
params.lastDraftIndices = inputs.lastDraftIndices.template getPtr<SizeType32 const>();
params.inputTemperatures = mTemperatureDevice;
params.curandState = mCurandStatesDevice;
params.batchSize = batchSize;
params.numPaths = mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths();
params.maxPathLength = mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen();
params.maxSeqLen = maxSeqLen;
params.vocabSize = mDecoderDomain.getVocabSizePadded();
params.numContextRequests = batchSize - inputs.lastDraftTokens.shape[0];
params.numGenerationRequests = inputs.lastDraftTokens.shape[0];
params.checkParams();
// Copy max generation length
cudaMemcpyAsync(outputs.maxGenLengthHost.template getPtr<SizeType32>(),
inputs.maxGenLengthDevice.template getPtr<SizeType32 const>(), sizeof(SizeType32), cudaMemcpyDeviceToHost,
mStream);
invokeExtractExplicitDraftTokens(params, mStream);
@ -245,28 +290,25 @@ void ExplicitDraftTokensLayer<T>::splitInputDataToBatchSlots(
template <typename T>
void ExplicitDraftTokensLayer<T>::packAcceptedPaths(
DynamicDecodeOutputParams const& outputs, ExplicitDraftTokensInputParams const& inputs)
ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = inputs.batch_slots->shape[0];
auto const batchSize = inputs.localBatchSize;
auto paths = inputs.lastDraftIndices.template getPtr<SizeType32 const>();
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType32 const>() : nullptr;
auto acceptedLengths = outputs.explicitDraftTokensOutputs->acceptedLengths.template getPtr<SizeType32 const>();
auto acceptedLengthsCumSum
= outputs.explicitDraftTokensOutputs->acceptedLengthsCumSum.template getPtr<SizeType32>();
auto pathsOffsets = outputs.explicitDraftTokensOutputs->pathsOffsets.template getPtr<SizeType32>();
auto bestPathIndices = inputs.bestPathIndices.template getPtr<SizeType32 const>();
auto numNewTokens = outputs.numNewTokens->template getPtr<SizeType32 const>();
auto numNewTokensCumSum = outputs.numNewTokensCumSum.template getPtr<SizeType32>();
auto pathsOffsets = outputs.pathsOffsets.template getPtr<SizeType32>();
auto batchSlots = inputs.batchSlots->template getPtr<SizeType32 const>();
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for ExplicitDraftTokensLayer");
TLLM_CHECK_WITH_INFO(acceptedLengths != nullptr, "Accepted lengths must be provided for ExplicitDraftTokensLayer");
TLLM_CHECK_WITH_INFO(numNewTokens != nullptr, "Accepted lengths must be provided for ExplicitDraftTokensLayer");
TLLM_CHECK_WITH_INFO(
acceptedLengthsCumSum != nullptr, "acceptedLengthsCumSum must be provided for ExplicitDraftTokensLayer");
numNewTokensCumSum != nullptr, "numNewTokensCumSum must be provided for ExplicitDraftTokensLayer");
TLLM_CHECK_WITH_INFO(pathsOffsets != nullptr, "pathsOffsets must be provided for ExplicitDraftTokensLayer");
invokePackAcceptedPaths(acceptedLengthsCumSum, pathsOffsets, acceptedLengths, bestPathIndices, paths, batchSlots,
batchSize, mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths(),
mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen(), true, mStream);
invokePackAcceptedPaths(numNewTokensCumSum, pathsOffsets, numNewTokens, mBestPathIndicesSlots,
mLastDraftIndicesSlots, batchSlots, batchSize, mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths(),
mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen(), false, mStream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

View File

@ -16,68 +16,14 @@
#pragma once
#include <curand_kernel.h>
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
namespace tc = tensorrt_llm::common;
#include <curand_kernel.h>
namespace tensorrt_llm
namespace tensorrt_llm::layers
{
namespace layers
{
class ExplicitDraftTokensSetupParams : public BaseSetupParams
{
public:
std::optional<std::vector<float>> temperature; // [setupBatchSize] on cpu
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [setupBatchSize] on cpu
};
class ExplicitDraftTokensInputParams : public BaseInputParams
{
public:
explicit ExplicitDraftTokensInputParams()
: BaseInputParams{0, 0, tc::Tensor()}
{
}
//! Draft tokens for the next iteration. The first token in each path is the last accepted token at the current
//! iteration. E.g. if forwardBatchSize == 1, maxNumPaths == 2, maxPathLen == 3, [[[0, 1, 2], [0, 1, 10]]]
tc::Tensor nextDraftTokens; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Compressed form of `nextDraftTokens`, where common prefixes are collapsed.
//! Using example above [0, 1, 2, 10]
tc::Tensor nextFlatTokens; // [forwardBatchSize * maxDecodingTokens], gpu
//! Indices of draft tokens in the compressed `nextFlatTokens` for the next iteration.
//! Using example above, [[[0, 1, 2], [0, 1, 3]]]
tc::Tensor nextDraftIndices; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Probabilities of the next draft tokens.
tc::Tensor nextDraftProbs; // [forwardBatchSize, maxNumPaths, maxDraftPathLen, vocabSize], gpu
//! Same as `nextDraftTokens`, but for the current iteration.
//! The currently accepted tokens are obtained as `lastDraftTokens[bi][bestPathIndices[bi]][1:bestPathLengths[bi]]`.
tc::Tensor lastDraftTokens; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Same as `nextDraftIndices`, but for the current iteration.
tc::Tensor lastDraftIndices; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Boolean attention masks.
//! maxDecodingTokens' = specDecodingGenerationLengths.max()
tc::Tensor masks; // [forwardBatchSize, maxDecodingTokens', maxDecodingTokens'], gpu
//! Position ids relative to `positionIdsBase`, packed the same way as `nextFlatTokens` but holding the positions of the next draft tokens.
//! Using example above, [0, 1, 2, 3]
tc::Tensor packedPosIds; // [forwardBatchSize * maxDecodingTokens], gpu
//! Lengths of the accepted paths for each request. It is 1 for the context phase (only 1 primary token is accepted).
tc::Tensor bestPathLengths; // [forwardBatchSize], gpu
//! Indices of the accepted paths for each request. It is 0 for the context phase.
tc::Tensor bestPathIndices; // [forwardBatchSize], gpu
//! Number of draft tokens for the next iteration.
tc::Tensor specDecodingGenerationLengths; // [forwardBatchSize], gpu
//! Baseline for the position ids.
tc::Tensor positionIdsBase; // [forwardBatchSize], gpu
};
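The compressed layout described in the comments above can be reproduced with a small host-side sketch. The snippet below is illustrative only (none of these helpers exist in TensorRT-LLM); it assumes a trie-style map keyed on (parent flat index, token) and recovers the [[[0, 1, 2], [0, 1, 10]]] -> [0, 1, 2, 10] / [[[0, 1, 2], [0, 1, 3]]] example given for `nextDraftTokens`, `nextFlatTokens` and `nextDraftIndices`.

// Illustrative sketch only: flatten draft paths that share common prefixes
// into a compact token list (cf. nextFlatTokens) and per-path indices into
// that list (cf. nextDraftIndices). Not the TensorRT-LLM implementation.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <utility>
#include <vector>

using TokenIdType = std::int32_t;
using SizeType32 = std::int32_t;

void flattenDraftPaths(std::vector<std::vector<TokenIdType>> const& paths,
    std::vector<TokenIdType>& flatTokens, std::vector<std::vector<SizeType32>>& indices)
{
    // Maps (parent flat index, token) -> flat index so that prefixes shared
    // across paths are stored only once.
    std::map<std::pair<SizeType32, TokenIdType>, SizeType32> trie;
    flatTokens.clear();
    indices.assign(paths.size(), {});
    for (std::size_t pi = 0; pi < paths.size(); ++pi)
    {
        SizeType32 parent = -1; // virtual root before the first token
        for (TokenIdType token : paths[pi])
        {
            auto key = std::make_pair(parent, token);
            auto it = trie.find(key);
            if (it == trie.end())
            {
                it = trie.emplace(key, static_cast<SizeType32>(flatTokens.size())).first;
                flatTokens.push_back(token);
            }
            indices[pi].push_back(it->second);
            parent = it->second;
        }
    }
}

int main()
{
    // Example from the comments: maxNumPaths == 2, maxPathLen == 3.
    std::vector<std::vector<TokenIdType>> paths{{0, 1, 2}, {0, 1, 10}};
    std::vector<TokenIdType> flatTokens;
    std::vector<std::vector<SizeType32>> indices;
    flattenDraftPaths(paths, flatTokens, indices);

    for (auto t : flatTokens) { std::cout << t << ' '; } // prints: 0 1 2 10
    std::cout << '\n';
    for (auto const& path : indices)
    {
        for (auto i : path) { std::cout << i << ' '; }    // prints: 0 1 2 / 0 1 3
        std::cout << '\n';
    }
    return 0;
}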
//! \brief Decoding layer for the speculative decoding technique in which all tokens are generated, decoded and
//! accepted in the engine.
@ -94,20 +40,23 @@ public:
~ExplicitDraftTokensLayer() override;
void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> setupParams) override;
std::shared_ptr<BaseSetupParams> const& setupParams) override;
void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;
private:
void allocateBuffer();
void freeBuffer();
void convertPackedMask(DynamicDecodeOutputParams const& outputs, ExplicitDraftTokensInputParams const& inputs);
void fillContextBuffers(
SizeType32 batchSize, SizeType32 const* batchSlots, ExplicitDraftTokensSetupParams const& params);
void splitInputDataToBatchSlots(
DynamicDecodeOutputParams const& outputs, ExplicitDraftTokensInputParams const& inputs);
void convertPackedMask(ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs);
void packAcceptedPaths(DynamicDecodeOutputParams const& outputs, ExplicitDraftTokensInputParams const& inputs);
void splitInputDataToBatchSlots(ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs);
void packAcceptedPaths(ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs);
private:
using Base::mStream;
@ -129,9 +78,10 @@ private:
SizeType32* mGenerationLengthInclusiveSum{nullptr};
SizeType32* mMaxGenerationLength{nullptr};
float* mTemperatureDevice{nullptr};
SizeType32* mBestPathIndicesSlots{nullptr};
SizeType32* mLastDraftIndicesSlots{nullptr};
std::vector<float> mTemperature;
};
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

View File

@ -92,35 +92,32 @@ inline bool allOfBatchSlots(
}
inline DecoderDomain getLocalDecoderDomain(
std::shared_ptr<BaseInputParams> baseInputs, DecoderDomain const& globalDecoderDomain)
std::shared_ptr<BaseDecodingInputs> baseInputs, DecoderDomain const& globalDecoderDomain)
{
auto inputs = std::dynamic_pointer_cast<DynamicDecodeInputParams>(baseInputs);
runtime::SizeType32 batchSize{0};
auto inputs = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);
runtime::SizeType32 batchSize{baseInputs->localBatchSize};
runtime::SizeType32 beamWidth{0};
runtime::SizeType32 vocabSize{0};
if (inputs->logits)
{
auto const& logitsShape = inputs->logits->shape;
TLLM_CHECK(logitsShape.size() == 3 || logitsShape.size() == 4);
batchSize = logitsShape[0];
auto const idxOffset = logitsShape.size() - 3;
beamWidth = logitsShape[idxOffset + 1];
vocabSize = logitsShape[idxOffset + 2];
}
else if (inputs->logits_vec)
else if (inputs->logitsVec)
{
TLLM_CHECK(inputs->logits_vec->size());
auto const& logitsShape = inputs->logits_vec.value()[0].shape;
TLLM_CHECK(inputs->logitsVec->size());
auto const& logitsShape = inputs->logitsVec.value()[0].shape;
TLLM_CHECK(logitsShape.size() == 3 || logitsShape.size() == 4);
auto const idxOffset = logitsShape.size() - 3;
batchSize = inputs->logits_vec->size();
beamWidth = logitsShape[idxOffset + 1];
vocabSize = logitsShape[idxOffset + 2];
}
else if (inputs->batch_slots)
else if (inputs->batchSlots)
{
auto const& batchSlotsShape = inputs->batch_slots->shape;
batchSize = batchSlotsShape[0];
auto const& batchSlotsShape = inputs->batchSlots->shape;
beamWidth = globalDecoderDomain.getBeamWidth();
vocabSize = globalDecoderDomain.getVocabSize();
}

View File

@ -15,7 +15,12 @@
*/
#include "tensorrt_llm/layers/lookaheadAlgorithm.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
#include "tensorrt_llm/runtime/lookaheadModule.h"
#include <tuple>
namespace tensorrt_llm::layers
{
@ -31,10 +36,13 @@ void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeT
mW = w;
mN = n;
mG = g;
std::tie(std::ignore, std::ignore, mRuntimeMaxDraftLen, std::ignore)
= executor::LookaheadDecodingConfig(mW, mN, mG).calculateSpeculativeResource();
mPoolManager.setup(mG);
mPoolManager.accept(prompt, mN);
mGoldenTokens = ITensor::slice(mGoldenTokensMax, 0, mN * 2 - 1);
mPrefills = ITensor::slice(mPrefillsMax, 0, mN - 2);
mPrefills = ITensor::slice(mPrefillsMax, 0, mN <= 1 ? 0 : mN - 2);
mKeyTokens = ITensor::slice(mKeyTokensMax, 0, mW);
mPastTokens = ITensor::slice(mPastTokensMax, 0, mW * (mN - 1));
mPastTokens->reshape(ITensor::makeShape({mW, mN - 1}));
@ -48,10 +56,14 @@ void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeT
std::for_each(pastRange.begin(), pastRange.end(), [](auto& a) { a = -1; });
for (SizeType32 i = 0; i < mW; i++)
{
randToken(pastRange[i * (mN - 1)]);
if (mN - 1 > 0)
{
randToken(pastRange[i * (mN - 1)]);
}
}
std::copy(std::prev(promptRange.end(), mN - 1), promptRange.end(), goldRange.begin());
mFilling = 1;
mGuessTokens = ITensor::slice(mGuessTokensMax, 0, 0);
mFilling = (mN - 1) > 0 ? 1 : 0;
PRINT_TOKENS(prompt);
PRINT_TOKENS(mPrefills);
PRINT_TOKENS(mPastTokens);
@ -75,6 +87,8 @@ void LookaheadAlgorithm::accept(TensorConstPtr const& generatedTokens)
runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens, TensorPtr const& positionIds,
TensorPtr const& samplingMask, runtime::SizeType32 offset)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
SizeType32 prefill = mN - 2 - mFilling;
SizeType32 len = prefill + mFilling * mW;
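// Illustrative example: with mN == 4, mW == 5 and mFilling == 1, prefill == 1 and len == 6.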
TLLM_CHECK(len <= ITensor::volume(draftTokens->getShape()));
@ -132,8 +146,9 @@ runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens,
samplingMaskRange[wj * mFilling + mFilling - 1 - 1] = true;
}
}
TLLM_LOG_DEBUG("prefill=%d, offset=%d", prefill, offset);
PRINT_VALUES(positionIds);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return len;
}
@ -171,10 +186,19 @@ void LookaheadAlgorithm::prepare(TensorPtr const& draftTokens, TensorPtr const&
TensorPtr const& samplingMask, TensorPtr const& length, TensorConstPtr const& offsetPtr,
TensorConstPtr const& lastTokenPtr)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (mRuntimeMaxDraftLen == 0)
{
(BufferRange<SizeType32>(*length))[0] = 0;
return;
}
auto lastToken = BufferRange<TokenIdType const>(*lastTokenPtr)[0];
auto offset = BufferRange<SizeType32 const>(*offsetPtr)[0];
SizeType32 inputLen = ITensor::volume(draftTokens->getShape());
TLLM_CHECK(inputLen >= mRuntimeMaxDraftLen);
BufferRange<TokenIdType> draftRange(*draftTokens);
BufferRange<TokenIdType> positionRange(*positionIds);
@ -182,33 +206,39 @@ void LookaheadAlgorithm::prepare(TensorPtr const& draftTokens, TensorPtr const&
SizeType32 filledLen = 0;
filledLen += lookahead(ITensor::slice(draftTokens, filledLen, inputLen - filledLen),
ITensor::slice(positionIds, filledLen, inputLen - filledLen),
ITensor::slice(samplingMask, filledLen, inputLen - filledLen), offset);
filledLen += lookahead(ITensor::slice(draftTokens, filledLen, mRuntimeMaxDraftLen - filledLen),
ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen),
ITensor::slice(samplingMask, filledLen, mRuntimeMaxDraftLen - filledLen), offset);
auto guessStart = filledLen;
filledLen += guess(ITensor::slice(draftTokens, filledLen, inputLen - filledLen),
ITensor::slice(positionIds, filledLen, inputLen - filledLen),
ITensor::slice(samplingMask, filledLen, inputLen - filledLen), offset, lastToken);
filledLen += guess(ITensor::slice(draftTokens, filledLen, mRuntimeMaxDraftLen - filledLen),
ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen),
ITensor::slice(samplingMask, filledLen, mRuntimeMaxDraftLen - filledLen), offset, lastToken);
auto guessEnd = filledLen;
mGuessTokens = ITensor::slice(mGuessTokensMax, 0, guessEnd - guessStart);
std::copy(draftRange.begin() + guessStart, draftRange.begin() + guessEnd,
BufferRange<TokenIdType>(*mGuessTokens).begin());
(BufferRange<SizeType32>(*length))[0] = filledLen;
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void LookaheadAlgorithm::verify(TensorPtr const& accepted, TensorPtr const& acceptedOffsets,
TensorPtr const& acceptedLength, TokenIdType newLastToken, TensorConstPtr const& goldenTokens,
TensorConstPtr const& endToken)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(ITensor::volume(goldenTokens->getShape()) == ITensor::volume(mGuessTokens->getShape()));
BufferRange<TokenIdType const> goldRange(*goldenTokens);
BufferRange<TokenIdType> guessTokensRange(*mGuessTokens);
auto guessSize = ITensor::volume(mGuessTokens->getShape());
SizeType32 guesses = guessSize / (mN - 1), hit = 0, maxHit = 0, hitIdx = 0;
SizeType32 guesses = (mN - 1 > 0) ? (guessSize / (mN - 1)) : 0;
SizeType32 hit = 0, maxHit = 0, hitIdx = 0;
for (SizeType32 i = 0; i < guesses; i++)
{
SizeType32 hit = 0;
@ -248,6 +278,8 @@ void LookaheadAlgorithm::verify(TensorPtr const& accepted, TensorPtr const& acce
}
*BufferRange<SizeType32>(*acceptedLength).begin() = maxHit + 1;
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
//! The lookahead Jacobi matrix has a prefilling phase and a maintenance phase.
@ -293,6 +325,8 @@ void LookaheadAlgorithm::verify(TensorPtr const& accepted, TensorPtr const& acce
void LookaheadAlgorithm::update(TensorPtr const& acceptedTokens, TensorPtr const& acceptedOffsets,
TensorPtr const& acceptedLength, TensorConstPtr const& sampledTokens, TensorConstPtr const& endToken)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(ITensor::volume(acceptedTokens->getShape()) >= mN);
BufferRange<TokenIdType const> sampledRange(*sampledTokens);
BufferRange<TokenIdType> keyRange(*mKeyTokens);
@ -312,7 +346,7 @@ void LookaheadAlgorithm::update(TensorPtr const& acceptedTokens, TensorPtr const
pastRange[i * (mN - 1) + mFilling] = keyRange[i];
}
}
else
else if (mN > 1)
{
for (SizeType32 i = 0; i < mW; i++)
{
@ -329,8 +363,9 @@ void LookaheadAlgorithm::update(TensorPtr const& acceptedTokens, TensorPtr const
auto guessSize = ITensor::volume(mGuessTokens->getShape());
auto outputSize = ITensor::volume(sampledTokens->getShape());
auto lookSize = 1 + mN - 2 - mFilling + mFilling * mW;
auto lookSize = 1 + (mN > 1 ? mN - 2 : 0) - mFilling + mFilling * mW;
TLLM_CHECK(guessSize + lookSize == outputSize);
TensorConstPtr goldenTokens = ITensor::slice(sampledTokens, lookSize, guessSize);
verify(acceptedTokens, acceptedOffsets, acceptedLength, newLastToken, goldenTokens, endToken);
@ -341,6 +376,8 @@ void LookaheadAlgorithm::update(TensorPtr const& acceptedTokens, TensorPtr const
{
mFilling++;
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
} // namespace tensorrt_llm::layers

View File

@ -44,7 +44,8 @@ public:
, mId(id)
, mGoldenTokensMax(
runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxN * 2 - 1}), nvinfer1::DataType::kINT32))
, mPrefillsMax(runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxN - 2}), nvinfer1::DataType::kINT32))
, mPrefillsMax(runtime::BufferManager::cpu(
runtime::ITensor::makeShape({(maxN <= 1 ? 0 : maxN - 2)}), nvinfer1::DataType::kINT32))
, mKeyTokensMax(runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW}), nvinfer1::DataType::kINT32))
, mPastTokensMax(
runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW * (maxN - 1)}), nvinfer1::DataType::kINT32))
@ -125,6 +126,7 @@ private:
runtime::SizeType32 mW{0};
runtime::SizeType32 mN{0};
runtime::SizeType32 mG{0};
runtime::SizeType32 mRuntimeMaxDraftLen{0};
//! in prefilling mode when mFilling < mN-1.
runtime::SizeType32 mFilling;

View File

@ -0,0 +1,364 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/layers/lookaheadDecodingLayer.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/lookaheadAlgorithm.h"
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <cstdint>
#include <ios>
#include <memory>
#include <tuple>
namespace tensorrt_llm::layers
{
using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;
template <typename T>
LookaheadDecodingLayer<T>::CpuAlgorithmResources::CpuAlgorithmResources(DecoderDomain const& decoderDomain)
{
auto maxBatchSize = decoderDomain.getBatchSize();
auto lookaheadModule
= std::dynamic_pointer_cast<LookaheadModule const>(decoderDomain.getSpeculativeDecodingModule());
auto const [maxW, maxN, maxG] = lookaheadModule->getExecutionConfig().get();
for (runtime::SizeType32 id = 0; id < maxBatchSize; id++)
{
mAlgos.emplace_back(maxW, maxN, maxG, id);
}
SizeType32 maxTokensPerStep, maxNumNewTokens, maxDraftLen;
std::tie(maxTokensPerStep, maxNumNewTokens, maxDraftLen, std::ignore)
= executor::LookaheadDecodingConfig(maxW, maxN, maxG).calculateSpeculativeResource();
auto const maxBatchShape1D = ITensor::makeShape({maxBatchSize});
mBatchSlots = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
mTargetTokens
= BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxTokensPerStep}), nvinfer1::DataType::kINT32);
mTokensPerStep = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
mEndIds = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
mOutputIds = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxNumNewTokens}), nvinfer1::DataType::kINT32);
mPathsOffsets = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxNumNewTokens}), nvinfer1::DataType::kINT32);
mNumNewTokens = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
mNumNewTokensCumSum = BufferManager::cpu(ITensor::makeShape({maxBatchSize + 1}), nvinfer1::DataType::kINT32);
mNextDraftTokens = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32);
mNextDraftPosIds = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32);
auto divUp32 = [](SizeType32 x) { return x / 32 + ((x % 32) ? 1 : 0); };
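// divUp32(x) == ceil(x / 32): the boolean mask bits of each decoding step are packed into 32-bit words,
// e.g. divUp32(33) == 2.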
mPackedMasks = BufferManager::cpu(
ITensor::makeShape({maxBatchSize, maxTokensPerStep, divUp32(maxTokensPerStep)}), nvinfer1::DataType::kINT32);
mSamplingMask = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kBOOL);
mNextDraftLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
mSequenceLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
}
template <typename T>
LookaheadDecodingLayer<T>::LookaheadDecodingLayer(
DecoderDomain const& decoderDomain, std::shared_ptr<runtime::BufferManager> const& bufferManager)
: BaseLayer(decoderDomain, bufferManager)
, mCpuAlgo(std::make_optional<CpuAlgorithmResources>(decoderDomain))
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const maxBatchSize = mDecoderDomain.getBatchSize();
auto const maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens();
auto const vocabSizePadded = mDecoderDomain.getVocabSizePadded();
auto const maxTopK = 1;
auto const maxBatchShape1D = ITensor::makeShape({maxBatchSize});
auto const maxBatchShape2D = ITensor::makeShape({maxBatchSize, maxTokensPerStep});
mWorkspaceSize = getTopKWorkspaceSize<T>(maxBatchSize, maxTokensPerStep, maxTopK, vocabSizePadded);
TLLM_LOG_DEBUG("mWorkspaceSize=%d", mWorkspaceSize);
mSamplingWorkspaceDevice
= mBufferManager->gpu(ITensor::makeShape({static_cast<int32_t>(mWorkspaceSize)}), nvinfer1::DataType::kINT8);
mTargetTokensDevice = mBufferManager->gpu(maxBatchShape2D, nvinfer1::DataType::kINT32);
mRandomSeedsDevice = mBufferManager->gpu(maxBatchShape1D, nvinfer1::DataType::kINT64);
mSamplingMaskDevice = mBufferManager->gpu(maxBatchShape2D, nvinfer1::DataType::kBOOL);
mCurandStatesDevice = mBufferManager->gpu(maxBatchShape1D, nvinfer1::DataType::kINT8);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void LookaheadDecodingLayer<T>::setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth,
runtime::SizeType32 const* batchSlots, std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto setupParams = std::dynamic_pointer_cast<LookaheadSetupParams>(baseSetupParams);
if (mCpuAlgo)
{
auto& algoConfigs = setupParams->algoConfigs;
TLLM_CHECK_WITH_INFO(algoConfigs.size() == 1 || algoConfigs.size() == batchSize,
"Lookahead runtime configuration size should be either 1 or batchSize");
for (runtime::SizeType32 bi = 0; bi < batchSize; bi++)
{
SizeType32 gbi = batchSlots[bi];
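// A single entry in algoConfigs is broadcast to every slot; otherwise each request uses its own entry.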
SizeType32 bi1orN = (algoConfigs.size() == 1) ? 0 : bi;
TLLM_LOG_DEBUG("CPU ALGO [ %d ] setup", gbi);
PRINT_TOKENS(setupParams->prompt[bi]);
auto [w, n, g] = algoConfigs[bi1orN].get();
SizeType32 runtimeTokensPerStep;
std::tie(runtimeTokensPerStep, std::ignore, std::ignore, std::ignore)
= executor::LookaheadDecodingConfig(w, n, g).calculateSpeculativeResource();
TLLM_CHECK_WITH_INFO(runtimeTokensPerStep <= mDecoderDomain.getMaxDecodingTokens(),
"runtime w(%d) n(%d) g(%d) exceeds maxTokensPerStep(%d)", w, n, g,
mDecoderDomain.getMaxDecodingTokens());
mCpuAlgo->mAlgos[gbi].setup(setupParams->prompt[bi], w, n, g);
}
}
auto curandStatesDevicePtr = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStatesDevice));
if (setupParams->randomSeed)
{
auto& randomSeed = setupParams->randomSeed.value();
if (randomSeed.size() == 1)
{
invokeCurandInitialize(curandStatesDevicePtr, batchSlots, batchSize, randomSeed.front(), mStream);
sync_check_cuda_error();
}
else
{
TLLM_CHECK_WITH_INFO(randomSeed.size() == batchSize, "Random seed vector size mismatch.");
cudaAutoCpy(bufferCast<uint64_t>(*mRandomSeedsDevice), randomSeed.data(), batchSize, mStream);
invokeCurandBatchInitialize(
curandStatesDevicePtr, batchSlots, batchSize, bufferCast<uint64_t>(*mRandomSeedsDevice), mStream);
sync_check_cuda_error();
}
}
else
{
invokeCurandInitialize(curandStatesDevicePtr, batchSlots, batchSize, DefaultDecodingParams::getSeed(), mStream);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void LookaheadDecodingLayer<T>::forwardAsync(
std::shared_ptr<BaseDecodingOutputs> const& outputParams, std::shared_ptr<BaseDecodingInputs> const& inputParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto inputs = std::dynamic_pointer_cast<LookaheadDecodingInputs>(inputParams);
auto outputs = std::dynamic_pointer_cast<SpeculativeDecodingOutputs>(outputParams);
auto batchSize = inputs->localBatchSize;
TLLM_CHECK_WITH_INFO(inputs->batchSlots, "Batch slots must be provided for LookaheadDecoding");
TLLM_CHECK_WITH_INFO(inputs->curTokensPerStep, "curTokensPerStep must be provided for LookaheadDecoding");
TLLM_CHECK_WITH_INFO(outputs->sequenceLength, "sequenceLength must be provided for LookaheadDecoding");
// TODO(liweim) to be confirmed.
TLLM_CHECK(inputs->logits);
mBufferManager->copy(
inputs->batchSlots->template getPtr<SizeType32 const>(), *mCpuAlgo->mBatchSlots, runtime::MemoryType::kGPU);
mBufferManager->copy(inputs->curTokensPerStep->template getPtr<SizeType32 const>(), *mCpuAlgo->mTokensPerStep,
runtime::MemoryType::kGPU);
mBufferManager->copy(
inputs->endIds.template getPtr<TokenIdType const>(), *mCpuAlgo->mEndIds, runtime::MemoryType::kGPU);
mBufferManager->copy(outputs->sequenceLength->template getPtr<SizeType32 const>(), *mCpuAlgo->mSequenceLengths,
runtime::MemoryType::kGPU);
TopKSamplingKernelParams<T> params;
params.maxBatchSize = mDecoderDomain.getBatchSize();
params.batchSize = batchSize;
params.maxTopK = 1;
params.returnAllTopK = true;
params.maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens();
params.maxSeqLen = mDecoderDomain.getMaxDecodingTokens();
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
params.batchSlots = inputs->batchSlots->template getPtr<SizeType32 const>();
TLLM_LOG_DEBUG("batchSize = %d", batchSize);
params.logProbs = inputs->logits ? inputs->logits->template getPtr<T>() : nullptr;
params.outputIds = bufferCast<TokenIdType>(*mTargetTokensDevice);
params.workspace = bufferCast<int8_t>(*mSamplingWorkspaceDevice);
params.curandState = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStatesDevice));
params.tokensPerStep = inputs->curTokensPerStep->template getPtr<SizeType32 const>();
TLLM_LOG_DEBUG(
"invokeBatchTopKSampling: maxBatchSize=%d, batchSize=%d, maxTopK=%d, maxTokensPerStep=%d, maxSeqLen=%d, "
"vocabSizePadded=%d",
params.maxBatchSize, params.batchSize, params.maxTopK, params.maxTokensPerStep, params.maxSeqLen,
params.vocabSizePadded);
// Sample multiple tokens per request and store them to a separate buffer to be accepted/rejected later
// Sequence length is not modified, endIds is not checked, outputLogProbs are not supported.
// Finished state is not set.
invokeBatchTopKSampling(params, mStream);
mBufferManager->copy(*mTargetTokensDevice, *mCpuAlgo->mTargetTokens);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void LookaheadDecodingLayer<T>::forwardSync(
std::shared_ptr<BaseDecodingOutputs> const& outputParams, std::shared_ptr<BaseDecodingInputs> const& inputParams)
{
if (mCpuAlgo)
{
forwardSyncCPU(outputParams, inputParams);
}
}
template <typename T>
void LookaheadDecodingLayer<T>::posIdsToMask(TensorPtr mask, TensorConstPtr posIds)
{
auto len = ITensor::volume(posIds->getShape());
TLLM_CHECK(mask->getShape().d[0] > len);
TLLM_CHECK(mask->getShape().d[1] * 32 > len);
auto posIdsRange = BufferRange<SizeType32 const>(*posIds);
auto maskLocation = BufferLocation<SizeType32>(*mask);
for (auto i = 0; i < maskLocation.size(); i++)
{
maskLocation[i] = 0;
}
maskLocation.at(0, 0) = 1;
auto setBit = [](SizeType32& x, SizeType32 idx) { x |= (1 << idx); };
if (len > 0)
{
std::vector<std::pair<SizeType32, SizeType32>> stack;
stack.push_back(std::make_pair(0, posIdsRange[0] - 1));
for (auto i = 1; i < len + 1; i++)
{
auto cur = posIdsRange[i - 1];
while (stack.size() > 0 && cur <= stack.back().second)
{
stack.pop_back();
}
TLLM_CHECK(stack.size() > 0 ? cur == stack.back().second + 1 : true);
stack.push_back(std::make_pair(i, cur));
for (auto prev : stack)
{
setBit(maskLocation.at(i, prev.first / 32), prev.first % 32);
}
}
}
}
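The stack walk in posIdsToMask is easier to follow on a toy input. The standalone sketch below is illustrative only: it mirrors the ancestor-stack logic above on host memory for two lookahead branches of depth 3 and prints the packed mask words, assuming row 0 is the last accepted token.

// Illustrative sketch only: build a packed ancestor ("tree") attention mask
// from draft-token position ids, mirroring the stack walk in posIdsToMask.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main()
{
    // Toy example: two lookahead branches of depth 3 hanging off the last
    // accepted token (row 0). Position ids are relative to that token.
    std::vector<std::int32_t> posIds{1, 2, 3, 1, 2, 3};
    auto const len = static_cast<std::int32_t>(posIds.size());

    // One 32-bit word per row is enough for this toy size (len + 1 <= 32).
    std::vector<std::uint32_t> mask(len + 1, 0u);
    mask[0] = 1u; // the accepted token attends to itself

    // Stack of (row index, position id) along the current root-to-leaf chain.
    std::vector<std::pair<std::int32_t, std::int32_t>> stack;
    stack.emplace_back(0, posIds[0] - 1);
    for (std::int32_t i = 1; i <= len; ++i)
    {
        auto const cur = posIds[i - 1];
        while (!stack.empty() && cur <= stack.back().second)
        {
            stack.pop_back(); // leave the previous branch
        }
        assert(stack.empty() || cur == stack.back().second + 1);
        stack.emplace_back(i, cur);
        for (auto const& prev : stack)
        {
            mask[i] |= (1u << prev.first); // row i attends all its ancestors
        }
    }

    for (std::int32_t i = 0; i <= len; ++i)
    {
        std::cout << "row " << i << ": 0x" << std::hex << mask[i] << std::dec << '\n';
    }
    return 0;
}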
template <typename T>
void LookaheadDecodingLayer<T>::forwardSyncCPU(
std::shared_ptr<BaseDecodingOutputs> const& outputParams, std::shared_ptr<BaseDecodingInputs> const& inputParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto inputs = std::dynamic_pointer_cast<LookaheadDecodingInputs>(inputParams);
auto outputs = std::dynamic_pointer_cast<SpeculativeDecodingOutputs>(outputParams);
auto const batchSize = inputs->localBatchSize;
TensorPtr outputIds(wrap(outputs->outputIds));
BufferRange<SizeType32 const> tokensPerStepRange(*mCpuAlgo->mTokensPerStep);
BufferRange<SizeType32> numNewTokensRange(*mCpuAlgo->mNumNewTokens);
BufferRange<SizeType32> numNewTokensCumSumRange(*mCpuAlgo->mNumNewTokensCumSum);
BufferRange<SizeType32> batchSlotsRange(*mCpuAlgo->mBatchSlots);
BufferRange<SizeType32> nextDraftLengthsRange(*mCpuAlgo->mNextDraftLengths);
BufferRange<SizeType32> sequenceLengthsRange(*mCpuAlgo->mSequenceLengths);
for (SizeType32 bi = 0; bi < batchSize; bi++)
{
SizeType32 gbi = batchSlotsRange[bi];
LookaheadAlgorithm& theAlgo(mCpuAlgo->mAlgos[gbi]);
SizeType32 const tokensPerStep = tokensPerStepRange[gbi];
TensorPtr sampledTokens = ITensor::slice(mCpuAlgo->mTargetTokens, {gbi, 0}, tokensPerStep);
if (tokensPerStep == 1)
{ // The first step of the generation phase has no draft tokens.
theAlgo.accept(sampledTokens);
mBufferManager->copy(*sampledTokens, *ITensor::slice(mCpuAlgo->mOutputIds, {gbi, 0}, tokensPerStep));
BufferLocation<SizeType32>(*mCpuAlgo->mPathsOffsets).at(gbi, 0) = 0;
numNewTokensRange[gbi] = tokensPerStep;
BufferLocation<SizeType32>(*mCpuAlgo->mNextDraftLengths).at(gbi) = 0;
}
else
{
theAlgo.update( //
ITensor::at(mCpuAlgo->mOutputIds, {gbi}), //
ITensor::at(mCpuAlgo->mPathsOffsets, {gbi}), //
ITensor::at(mCpuAlgo->mNumNewTokens, {gbi}), //
sampledTokens, //
ITensor::at(mCpuAlgo->mEndIds, {gbi}));
}
auto maxNumNewTokens = mCpuAlgo->mOutputIds->getShape().d[1];
mBufferManager->copy(*ITensor::at(mCpuAlgo->mOutputIds, {gbi}),
*ITensor::slice(outputIds, {gbi, sequenceLengthsRange[gbi]}, maxNumNewTokens));
sequenceLengthsRange[gbi] += numNewTokensRange[gbi];
theAlgo.prepare( //
ITensor::at(mCpuAlgo->mNextDraftTokens, {gbi}), //
ITensor::at(mCpuAlgo->mNextDraftPosIds, {gbi}), //
ITensor::at(mCpuAlgo->mSamplingMask, {gbi}), //
ITensor::at(mCpuAlgo->mNextDraftLengths, {gbi}), //
ITensor::at(mCpuAlgo->mSequenceLengths, {gbi}), //
ITensor::at(mCpuAlgo->mOutputIds, {gbi, numNewTokensRange[gbi] - 1}));
posIdsToMask( //
ITensor::at(mCpuAlgo->mPackedMasks, {gbi}), //
ITensor::slice(mCpuAlgo->mNextDraftPosIds, {gbi, 0}, nextDraftLengthsRange[gbi]));
}
numNewTokensCumSumRange[0] = 0;
for (SizeType32 i = 0; i < numNewTokensRange.size(); i++)
{
numNewTokensCumSumRange[i + 1] = numNewTokensCumSumRange[i] + numNewTokensRange[i];
}
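// Illustrative example: numNewTokensRange == [2, 1, 3] yields numNewTokensCumSumRange == [0, 2, 3, 6],
// i.e. the exclusive offsets of each slot's accepted tokens in the flattened output.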
TLLM_CHECK(outputs->numNewTokens);
mBufferManager->copy(*mCpuAlgo->mSequenceLengths, //
const_cast<void*>(outputs->sequenceLength.value().data), runtime::MemoryType::kGPU);
mBufferManager->copy(*mCpuAlgo->mPathsOffsets, //
const_cast<void*>(outputs->pathsOffsets.data), runtime::MemoryType::kGPU);
mBufferManager->copy(*mCpuAlgo->mNumNewTokens, //
const_cast<void*>(outputs->numNewTokens->data), runtime::MemoryType::kGPU);
mBufferManager->copy(*mCpuAlgo->mNumNewTokensCumSum, //
const_cast<void*>(outputs->numNewTokensCumSum.data), runtime::MemoryType::kGPU);
mBufferManager->copy(*mCpuAlgo->mNextDraftTokens, //
const_cast<void*>(outputs->nextDraftTokens.data), runtime::MemoryType::kGPU);
mBufferManager->copy(*mCpuAlgo->mNextDraftPosIds, //
const_cast<void*>(outputs->nextDraftPosIds.data), runtime::MemoryType::kGPU);
mBufferManager->copy(*mCpuAlgo->mPackedMasks, //
const_cast<void*>(outputs->packedMasks.data), runtime::MemoryType::kGPU);
mBufferManager->copy(*mCpuAlgo->mNextDraftLengths, //
const_cast<void*>(outputs->nextDraftLengths.data), runtime::MemoryType::kGPU);
// TODO(liweim) do we need this?
// mBufferManager->getStream().synchronize();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template class LookaheadDecodingLayer<float>;
template class LookaheadDecodingLayer<half>;
} // namespace tensorrt_llm::layers

View File

@ -0,0 +1,97 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "lookaheadAlgorithm.h"
#include "tensorrt_llm/common/cudaAllocator.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/lookaheadModule.h"
namespace tensorrt_llm::layers
{
//! \brief LookaheadDecodingLayer
template <typename T>
class LookaheadDecodingLayer : public BaseLayer
{
public:
using Base = BaseLayer;
using TensorPtr = runtime::ITensor::SharedPtr;
using TensorConstPtr = runtime::ITensor::SharedConstPtr;
using Base::mBufferManager;
LookaheadDecodingLayer(
DecoderDomain const& decoderDomain, std::shared_ptr<runtime::BufferManager> const& bufferManager);
~LookaheadDecodingLayer() override {}
void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> const& baseSetupParams) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputParams,
std::shared_ptr<BaseDecodingInputs> const& inputParams) override;
void forwardSync(std::shared_ptr<BaseDecodingOutputs> const& outputParams,
std::shared_ptr<BaseDecodingInputs> const& inputParams) override;
private:
void forwardSyncCPU(std::shared_ptr<BaseDecodingOutputs> const& outputParams,
std::shared_ptr<BaseDecodingInputs> const& inputParams);
void posIdsToMask(TensorPtr mask, TensorConstPtr posIds);
private:
using Base::mStream;
using Base::mAllocator;
using Base::mWorkspaceSize;
using Base::mDecoderDomain;
TensorPtr mCurandStatesDevice;
TensorPtr mSamplingWorkspaceDevice;
TensorPtr mTargetTokensDevice;
TensorPtr mRandomSeedsDevice;
TensorPtr mSamplingMaskDevice;
struct CpuAlgorithmResources
{
explicit CpuAlgorithmResources(DecoderDomain const& decoderDomain);
std::vector<LookaheadAlgorithm> mAlgos;
TensorPtr mBatchSlots;
TensorPtr mTargetTokens;
TensorPtr mTokensPerStep;
TensorPtr mEndIds;
TensorPtr mOutputIds;
TensorPtr mPathsOffsets;
TensorPtr mNumNewTokens;
TensorPtr mNumNewTokensCumSum;
TensorPtr mNextDraftTokens;
TensorPtr mNextDraftPosIds;
TensorPtr mPackedMasks;
TensorPtr mSamplingMask;
TensorPtr mNextDraftLengths;
TensorPtr mSequenceLengths;
};
std::optional<CpuAlgorithmResources> mCpuAlgo;
};
} // namespace tensorrt_llm::layers

View File

@ -1,74 +0,0 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sstream>
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
namespace tensorrt_llm::layers
{
using namespace tensorrt_llm::runtime;
using TensorPtr = ITensor::SharedPtr;
ITensor::UniquePtr slice(
ITensor::SharedPtr tensor, std::initializer_list<SizeType32> const& offsetDims, size_t const sizeDim)
{
auto shape = tensor->getShape();
TLLM_CHECK(offsetDims.size() > 0);
TLLM_CHECK(shape.nbDims >= offsetDims.size());
std::vector<size_t> volumes(shape.nbDims);
int i;
volumes[shape.nbDims - 1] = 1;
for (i = shape.nbDims - 2; i >= 0; i--)
{
volumes[i] = shape.d[i + 1] * volumes[i + 1];
}
size_t offset = 0;
i = 0;
for (auto itd = offsetDims.begin(); itd != offsetDims.end(); itd++)
{
TLLM_CHECK(0 <= (*itd) && (*itd) < shape.d[i]);
offset += (*itd) * volumes[i++];
}
ITensor::Shape dims;
dims.nbDims = shape.nbDims - offsetDims.size() + 1;
dims.d[0] = sizeDim;
for (i = 1; i < dims.nbDims; i++)
{
dims.d[i] = shape.d[i - 1 + offsetDims.size()];
}
size_t size = ITensor::volume(dims);
return std::make_unique<TensorView>(std::move(tensor), offset, size, dims);
}
ITensor::UniquePtr slice(ITensor::SharedPtr tensor, std::initializer_list<SizeType32> const& offsetDims)
{
auto result = slice(tensor, offsetDims, 1);
if (result->getShape().nbDims > 1)
{
result->squeeze(0);
}
return result;
}
} // namespace tensorrt_llm::layers

View File

@ -240,10 +240,18 @@ public:
{
buf << token;
}
buf << (i == size - 1 ? ']' : ',');
if (i != size - 1)
{
buf << ',';
}
}
buf << ']';
};
if (shape.nbDims == 1)
if (shape.nbDims == 0)
{
buf << "[]";
}
else if (shape.nbDims == 1)
{
line(tensorRange.begin(), shape.d[0]);
}
@ -277,10 +285,19 @@ public:
buf << '[';
for (SizeType32 i = 0; i < size; i++)
{
buf << array[i] << (i == size - 1 ? ']' : ',');
buf << array[i];
if (i != size - 1)
{
buf << ',';
}
}
buf << ']';
};
if (shape.nbDims == 1)
if (shape.nbDims == 0)
{
buf << "[]";
}
else if (shape.nbDims == 1)
{
line(tensorRange.begin(), shape.d[0]);
}
@ -305,15 +322,24 @@ public:
{
switch (mTensor.getDataType())
{
case nvinfer1::DataType::kBOOL: return values<bool>();
case nvinfer1::DataType::kFLOAT: return values<float>();
case nvinfer1::DataType::kINT8: return values<std::int8_t>();
case nvinfer1::DataType::kINT32: return values<std::int32_t>();
case nvinfer1::DataType::kINT64: return values<std::int64_t>();
case nvinfer1::DataType::kUINT8: return values<std::uint8_t>();
default: return std::string("Unsupported data type");
default: return std::string(mName + ": Unsupported data type");
}
}
std::string shape(void)
{
using namespace tensorrt_llm::runtime;
std::ostringstream buf;
buf << mName << ": " << mTensor.getShape();
return buf.str();
}
void print_tokens(void)
{
TLLM_LOG_DEBUG(tokens());
@ -324,6 +350,11 @@ public:
TLLM_LOG_DEBUG(values());
}
void print_shape(void)
{
TLLM_LOG_DEBUG(shape());
}
private:
runtime::ITensor const& mTensor;
std::string mName;
@ -332,5 +363,6 @@ private:
#define D(x) tensorrt_llm::layers::DebugTensor(x, #x)
#define PRINT_TOKENS(x) D(x).print_tokens()
#define PRINT_VALUES(x) D(x).print_values()
#define PRINT_SHAPE(x) D(x).print_shape()
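// Example usage (illustrative): `PRINT_SHAPE(mPastTokens);` expands to
// tensorrt_llm::layers::DebugTensor(mPastTokens, "mPastTokens").print_shape(), which logs the tensor name
// followed by its shape at debug verbosity; PRINT_TOKENS and PRINT_VALUES log the tensor contents likewise.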
} // namespace tensorrt_llm::layers

View File

@ -15,6 +15,7 @@
*/
#include "tensorrt_llm/layers/lookaheadPoolManager.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
namespace tensorrt_llm::layers
@ -24,13 +25,18 @@ using namespace tensorrt_llm::runtime;
void LookaheadPoolManager::setup(SizeType32 guessSetSize)
{
TLLM_CHECK(guessSetSize > 0 && guessSetSize <= mGuessSetSizeMax);
TLLM_CHECK(guessSetSize >= 0 && guessSetSize <= mGuessSetSizeMax);
mGuessSetSize = guessSetSize;
mTokenMap.clear();
}
void LookaheadPoolManager::insertOne(Key key, TensorConstPtr const& ngram)
{
if (TLLM_UNLIKELY(ITensor::volume(ngram->getShape()) == 0 || mGuessSetSize == 0))
{
return;
}
auto search = mTokenMap.find(key);
if (search != mTokenMap.end())
{
@ -41,7 +47,7 @@ void LookaheadPoolManager::insertOne(Key key, TensorConstPtr const& ngram)
BufferRange<TokenIdType const> itemRange(*item);
return std::equal(ngramRange.begin(), ngramRange.end(), itemRange.begin());
});
if (mGuessSetSize >= 0 && search->second.size() >= mGuessSetSize)
if (mGuessSetSize > 0 && search->second.size() >= mGuessSetSize)
{
search->second.pop_front();
}
@ -104,7 +110,6 @@ void LookaheadPoolManager::update(TensorConstPtr const& keyTokens, TensorConstPt
BufferRange<TokenIdType const> sourceRange(*source);
BufferRange<TokenIdType> ngramRange(*ngram);
std::copy(sourceRange.begin(), sourceRange.end(), ngramRange.begin());
insertOne(keyRange[wi], ngram);
}
}
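The pool behaviour visible in these hunks (a per-key, de-duplicated FIFO of n-grams capped at the guess-set size) can be sketched with plain STL containers. The class and names below are illustrative only and are not the LookaheadPoolManager API.

// Illustrative sketch only: a bounded, de-duplicated FIFO pool of n-grams
// keyed by their first token, mirroring the behaviour visible in
// LookaheadPoolManager::insertOne (the oldest entry is evicted when the
// per-key list reaches the configured guess-set size).
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <iostream>
#include <unordered_map>
#include <vector>

using TokenIdType = std::int32_t;
using Ngram = std::vector<TokenIdType>;

class NgramPool
{
public:
    explicit NgramPool(std::size_t guessSetSize)
        : mGuessSetSize(guessSetSize)
    {
    }

    void insertOne(TokenIdType key, Ngram const& ngram)
    {
        if (ngram.empty() || mGuessSetSize == 0)
        {
            return;
        }
        auto& list = mTokenMap[key];
        // Drop an identical n-gram so it is re-inserted as the newest entry.
        list.erase(std::remove(list.begin(), list.end(), ngram), list.end());
        if (list.size() >= mGuessSetSize)
        {
            list.pop_front(); // evict the oldest entry
        }
        list.push_back(ngram);
    }

    std::deque<Ngram> const& guesses(TokenIdType key) const
    {
        static std::deque<Ngram> const empty;
        auto it = mTokenMap.find(key);
        return it == mTokenMap.end() ? empty : it->second;
    }

private:
    std::size_t mGuessSetSize;
    std::unordered_map<TokenIdType, std::deque<Ngram>> mTokenMap;
};

int main()
{
    NgramPool pool(/*guessSetSize=*/2);
    pool.insertOne(5, {6, 7, 8});
    pool.insertOne(5, {6, 9, 8});
    pool.insertOne(5, {6, 7, 8}); // duplicate: moved to the back, nothing evicted
    pool.insertOne(5, {1, 2, 3}); // pool is full: {6, 9, 8} is evicted
    for (auto const& g : pool.guesses(5))
    {
        for (auto t : g) { std::cout << t << ' '; }
        std::cout << '\n';
    }
    return 0;
}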

View File

@ -31,9 +31,7 @@ using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::kernels::speculative_decoding;
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{
template <typename T>
@ -144,7 +142,7 @@ void MedusaDecodingLayer<T>::freeBuffer()
template <typename T>
void MedusaDecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> baseSetupParams)
std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@ -272,12 +270,12 @@ void MedusaDecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, S
template <typename T>
void MedusaDecodingLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto inputs = std::dynamic_pointer_cast<MedusaInputParams>(baseInputs);
auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
auto inputs = std::dynamic_pointer_cast<MedusaDecodingInputs>(baseInputs);
auto outputs = std::dynamic_pointer_cast<SpeculativeDecodingOutputs>(baseOutputs);
samplePrimeHeadTokens(*outputs, *inputs);
@ -294,16 +292,16 @@ void MedusaDecodingLayer<T>::forwardAsync(
template <typename T>
void MedusaDecodingLayer<T>::samplePrimeHeadTokens(
DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs)
SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = inputs.logits.shape[0];
auto const batchSize = inputs.logits->shape[0];
auto logits = inputs.logits.template getPtr<T>();
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType32 const>() : nullptr;
auto sequenceLengths = outputs.sequence_length ? outputs.sequence_length->template getPtr<SizeType32>() : nullptr;
auto tokensPerStepDevice = inputs.medusaCurTokensPerStep.template getPtr<SizeType32>();
auto logits = inputs.logits->template getPtr<T>();
auto batchSlots = inputs.batchSlots ? inputs.batchSlots->template getPtr<SizeType32 const>() : nullptr;
auto sequenceLengths = outputs.sequenceLength ? outputs.sequenceLength->template getPtr<SizeType32>() : nullptr;
auto tokensPerStepDevice = inputs.curTokensPerStep->template getPtr<SizeType32>();
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");
@ -333,22 +331,22 @@ void MedusaDecodingLayer<T>::samplePrimeHeadTokens(
template <typename T>
void MedusaDecodingLayer<T>::acceptDraftTokens(
DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs)
SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = inputs.logits.shape[0];
auto const maxSeqLen = outputs.output_ids.shape[outputs.output_ids.shape.size() - 1];
auto const batchSize = inputs.logits->shape[0];
auto const maxSeqLen = outputs.outputIds.shape[outputs.outputIds.shape.size() - 1];
auto outputIds = outputs.output_ids.template getPtr<TokenIdType>();
auto endIds = inputs.end_ids.template getPtr<TokenIdType const>();
auto outputIds = outputs.outputIds.template getPtr<TokenIdType>();
auto endIds = inputs.endIds.template getPtr<TokenIdType const>();
auto paths = inputs.paths.template getPtr<SizeType32 const>();
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType32 const>() : nullptr;
auto sequenceLengths = outputs.sequence_length ? outputs.sequence_length->template getPtr<SizeType32>() : nullptr;
auto acceptedLengths = outputs.speculativeDecodingOutputs->acceptedLengths.template getPtr<SizeType32>();
auto curTokensPerStepDevice = inputs.medusaCurTokensPerStep.template getPtr<SizeType32>();
auto targetTokensPerStepDevice = inputs.medusaTargetTokensPerStep.template getPtr<SizeType32>();
auto batchSlots = inputs.batchSlots ? inputs.batchSlots->template getPtr<SizeType32 const>() : nullptr;
auto sequenceLengths = outputs.sequenceLength ? outputs.sequenceLength->template getPtr<SizeType32>() : nullptr;
auto numNewTokens = outputs.numNewTokens->template getPtr<SizeType32>();
auto curTokensPerStepDevice = inputs.curTokensPerStep->template getPtr<SizeType32>();
auto targetTokensPerStepDevice = inputs.targetTokensPerStep.template getPtr<SizeType32>();
auto const maxDraftPathLen = mDecoderDomain.getSpeculativeDecodingModule()->getMaxDraftPathLen();
@ -362,12 +360,12 @@ void MedusaDecodingLayer<T>::acceptDraftTokens(
}
}
auto draftIds = outputs.speculativeDecodingOutputs->nextDraftTokens.template getPtr<TokenIdType>();
auto draftIds = outputs.nextDraftTokens.template getPtr<TokenIdType>();
TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(acceptedLengths != nullptr, "Accepted lengths must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(numNewTokens != nullptr, "Accepted lengths must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(
curTokensPerStepDevice != nullptr, "Current tokens per step must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(
@ -379,7 +377,7 @@ void MedusaDecodingLayer<T>::acceptDraftTokens(
// Compare draft tokens from outputIds with sampled target tokens at mTargetTokensDevice using paths.
// Select the longest accepted path, modify outputIds in-place, increment sequenceLengths accordingly.
// Fill mMedusaSelectedLogitsPtrsDevice with respective Medusa logits
acceptDraftTokensByIdsWithPaths(outputIds, draftIds, mTargetTokensDevice, sequenceLengths, acceptedLengths,
acceptDraftTokensByIdsWithPaths(outputIds, draftIds, mTargetTokensDevice, sequenceLengths, numNewTokens,
finishedStates, batchSlots, paths, endIds,
reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaInputLogitsPtrs)),
const_cast<T const**>(mMedusaSelectedLogitsPtrsDevice), curTokensPerStepDevice, targetTokensPerStepDevice,
@ -391,13 +389,13 @@ void MedusaDecodingLayer<T>::acceptDraftTokens(
template <typename T>
void MedusaDecodingLayer<T>::sampleNewDraftTokens(
DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs)
SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = inputs.logits.shape[0];
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType32 const>() : nullptr;
auto sequenceLengths = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<SizeType32>() : nullptr;
auto const batchSize = inputs.logits->shape[0];
auto batchSlots = inputs.batchSlots ? inputs.batchSlots->template getPtr<SizeType32 const>() : nullptr;
auto sequenceLengths = (outputs.sequenceLength) ? outputs.sequenceLength->template getPtr<SizeType32>() : nullptr;
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");
@ -449,18 +447,18 @@ void MedusaDecodingLayer<T>::sampleNewDraftTokens(
template <typename T>
void MedusaDecodingLayer<T>::scatterNewDraftTokens(
DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs)
SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = inputs.logits.shape[0];
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType32 const>()
: static_cast<SizeType32*>(nullptr);
auto const batchSize = inputs.logits->shape[0];
auto batchSlots = inputs.batchSlots ? inputs.batchSlots->template getPtr<SizeType32 const>()
: static_cast<SizeType32*>(nullptr);
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
auto draftIds = outputs.speculativeDecodingOutputs->nextDraftTokens.template getPtr<TokenIdType>();
auto tokensPerStepDevice = inputs.medusaCurTokensPerStep.template getPtr<SizeType32>();
auto draftIds = outputs.nextDraftTokens.template getPtr<TokenIdType>();
auto tokensPerStepDevice = inputs.curTokensPerStep->template getPtr<SizeType32>();
auto treeIds = inputs.treeIds.template getPtr<SizeType32>();
TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(tokensPerStepDevice != nullptr, "Tokens per step must be provided for MedusaDecoding");
@ -474,23 +472,22 @@ void MedusaDecodingLayer<T>::scatterNewDraftTokens(
template <typename T>
void MedusaDecodingLayer<T>::packAcceptedPaths(
DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs)
SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = inputs.logits.shape[0];
auto const batchSize = inputs.logits->shape[0];
auto paths = inputs.paths.template getPtr<SizeType32 const>();
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType32 const>() : nullptr;
auto acceptedLengths = outputs.speculativeDecodingOutputs->acceptedLengths.template getPtr<SizeType32>();
auto acceptedLengthsCumSum
= outputs.speculativeDecodingOutputs->acceptedLengthsCumSum.template getPtr<SizeType32>();
auto pathsOffsets = outputs.speculativeDecodingOutputs->pathsOffsets.template getPtr<SizeType32>();
auto batchSlots = inputs.batchSlots ? inputs.batchSlots->template getPtr<SizeType32 const>() : nullptr;
auto numNewTokens = outputs.numNewTokens->template getPtr<SizeType32>();
auto numNewTokensCumSum = outputs.numNewTokensCumSum.template getPtr<SizeType32>();
auto pathsOffsets = outputs.pathsOffsets.template getPtr<SizeType32>();
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(acceptedLengths != nullptr, "Accepted lengths must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(acceptedLengthsCumSum != nullptr, "acceptedLengthsCumSum must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(numNewTokens != nullptr, "Accepted lengths must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(numNewTokensCumSum != nullptr, "numNewTokensCumSum must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(pathsOffsets != nullptr, "pathsOffsets must be provided for MedusaDecoding");
invokePackAcceptedPaths(acceptedLengthsCumSum, pathsOffsets, acceptedLengths, mBestPathIdsDevice, paths, batchSlots,
invokePackAcceptedPaths(numNewTokensCumSum, pathsOffsets, numNewTokens, mBestPathIdsDevice, paths, batchSlots,
batchSize, mDecoderDomain.getMaxDecodingTokens(),
mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen(), false, mStream);
@ -500,5 +497,4 @@ void MedusaDecodingLayer<T>::packAcceptedPaths(
template class MedusaDecodingLayer<float>;
template class MedusaDecodingLayer<half>;
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

View File

@ -19,46 +19,13 @@
#include <curand_kernel.h>
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
namespace tc = tensorrt_llm::common;
namespace tensorrt_llm
namespace tensorrt_llm::layers
{
namespace layers
{
class MedusaSetupParams : public BaseSetupParams
{
public:
std::optional<std::vector<int32_t>> runtimeTopK; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<std::vector<int32_t>>> runtimeHeadsTopK; // [setupBatchSize, maxDraftPathLen] on cpu
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [setupBatchSize] on cpu
};
class MedusaInputParams : public BaseInputParams
{
public:
explicit MedusaInputParams(tc::Tensor logits, tc::Tensor endIds)
: BaseInputParams{0, 0, std::move(endIds)}
, logits{std::move(logits)}
{
}
tc::Tensor logits; // [maxBatchSize, beamWidth, vocabSizePadded]
tc::Tensor paths; // [maxBatchSize, maxDecodingTokens, maxPathLen] on gpu
std::vector<std::vector<tc::Tensor>>
medusaLogits; // [maxBatchSize][maxDraftPathLen][maxDecodingTokens, vocabSize] on gpu
tc::Tensor medusaCurTokensPerStep; // [maxBatchSize] on gpu
tc::Tensor medusaTargetTokensPerStep; // [maxBatchSize] on gpu
tc::Tensor treeIds; // [maxBatchSize, maxDecodingTokens] on gpu
};
//! \brief
template <typename T>
@ -74,19 +41,20 @@ public:
~MedusaDecodingLayer() override;
void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> setupParams) override;
std::shared_ptr<BaseSetupParams> const& setupParams) override;
void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;
private:
void allocateBuffer();
void freeBuffer();
void samplePrimeHeadTokens(DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs);
void acceptDraftTokens(DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs);
void sampleNewDraftTokens(DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs);
void scatterNewDraftTokens(DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs);
void packAcceptedPaths(DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs);
void samplePrimeHeadTokens(SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs);
void acceptDraftTokens(SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs);
void sampleNewDraftTokens(SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs);
void scatterNewDraftTokens(SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs);
void packAcceptedPaths(SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs);
private:
using Base::mStream;
@ -118,5 +86,4 @@ private:
std::vector<runtime::SizeType32> mCummulativeTopK;
};
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

View File

@ -17,9 +17,11 @@
#include "tensorrt_llm/layers/penaltyLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/kernels/penaltyKernels.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include <algorithm>
@ -27,9 +29,7 @@ using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{
template <typename T>
@ -182,7 +182,7 @@ void PenaltyLayer<T>::freeBuffer()
template <typename T>
void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> baseSetupParams)
std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@ -207,14 +207,15 @@ void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType
FillBuffers const fillBuffers{batchSize, mDecoderDomain.getBatchSize(), mStream};
auto const& penaltyParams = setupParams->penaltyParams;
TLLM_CHECK_WITH_INFO(penaltyParams, "penaltyParams for setup is not set");
bool const useTemperature = mDecodingMode.isUseTemperature() && penaltyParams.temperature.has_value();
bool const useTemperature = mDecodingMode.isUseTemperature() && penaltyParams->temperature.has_value();
bool const useRepetitionPenalty
= mDecodingMode.isUseRepetitionPenalty() && penaltyParams.repetitionPenalty.has_value();
bool const usePresencePenalty = mDecodingMode.isUsePresencePenalty() && penaltyParams.presencePenalty.has_value();
= mDecodingMode.isUseRepetitionPenalty() && penaltyParams->repetitionPenalty.has_value();
bool const usePresencePenalty = mDecodingMode.isUsePresencePenalty() && penaltyParams->presencePenalty.has_value();
bool const useFrequencyPenalty
= mDecodingMode.isUseFrequencyPenalty() && penaltyParams.frequencyPenalty.has_value();
bool const useMinLength = mDecodingMode.isUseMinLength() && penaltyParams.minLength.has_value();
= mDecodingMode.isUseFrequencyPenalty() && penaltyParams->frequencyPenalty.has_value();
bool const useMinLength = mDecodingMode.isUseMinLength() && penaltyParams->minLength.has_value();
// FIXME(nkorobov): once one of the requests has some penalty, we will always have to compute it.
// To avoid that we need to scan through all active requests at each iteration.
mUseTemperature |= useTemperature;
@ -225,31 +226,31 @@ void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType
if (mUseTemperature)
{
fillBuffers(penaltyParams.temperature, DefaultDecodingParams::getTemperature(), mTemperature,
fillBuffers(penaltyParams->temperature, DefaultDecodingParams::getTemperature(), mTemperature,
mTemperatureDevice, batchSlotsHost, getLimitsPenalty(DecodingPenaltyType::Temperature),
"temperature penalty");
}
if (mUseRepetitionPenalty)
{
fillBuffers(penaltyParams.repetitionPenalty, DefaultDecodingParams::getRepetitionPenalty(), mRepetitionPenalty,
fillBuffers(penaltyParams->repetitionPenalty, DefaultDecodingParams::getRepetitionPenalty(), mRepetitionPenalty,
mRepetitionPenaltyDevice, batchSlotsHost, getLimitsPenalty(DecodingPenaltyType::Repetition),
"repetition penalty");
}
if (mUsePresencePenalty)
{
fillBuffers(penaltyParams.presencePenalty, DefaultDecodingParams::getPresencePenalty(), mPresencePenalty,
fillBuffers(penaltyParams->presencePenalty, DefaultDecodingParams::getPresencePenalty(), mPresencePenalty,
mPresencePenaltyDevice, batchSlotsHost, getLimitsPenalty(DecodingPenaltyType::Presence),
"presence penalty");
}
if (mUseFrequencyPenalty)
{
fillBuffers(penaltyParams.frequencyPenalty, DefaultDecodingParams::getFrequencyPenalty(), mFrequencyPenalty,
fillBuffers(penaltyParams->frequencyPenalty, DefaultDecodingParams::getFrequencyPenalty(), mFrequencyPenalty,
mFrequencyPenaltyDevice, batchSlotsHost, getLimitsPenalty(DecodingPenaltyType::Frequency),
"frequency penalty");
}
if (mUseMinLength)
{
fillBuffers(penaltyParams.minLength, DefaultDecodingParams::getMinLength(), mMinLength, mMinLengthDevice,
fillBuffers(penaltyParams->minLength, DefaultDecodingParams::getMinLength(), mMinLength, mMinLengthDevice,
batchSlotsHost, getLimitsPenalty(DecodingPenaltyType::MinLength), "min length");
}
@ -258,21 +259,21 @@ void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType
template <typename T>
void PenaltyLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
auto params = std::dynamic_pointer_cast<DynamicDecodeInputParams>(baseInputs);
auto outputs = std::dynamic_pointer_cast<BaseDecodingOutputs>(baseOutputs);
auto params = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);
auto const localDecoderDomain = getLocalDecoderDomain(params, mDecoderDomain);
auto const maxSeqLen = outputs->output_ids.shape[outputs->output_ids.shape.size() - 1];
auto batchSlots = params->batch_slots ? params->batch_slots->template getPtr<SizeType32 const>() : nullptr;
auto const maxSeqLen = outputs->outputIds.shape[outputs->outputIds.shape.size() - 1];
auto batchSlots = params->batchSlots ? params->batchSlots->template getPtr<SizeType32 const>() : nullptr;
std::vector<SizeType32> batchSlotsVec(localDecoderDomain.getBatchSize());
std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0);
auto batchSlotsHost
= params->batch_slots ? params->batch_slots->template getPtr<SizeType32 const>() : batchSlotsVec.data();
= params->batchSlots ? params->batchSlots->template getPtr<SizeType32 const>() : batchSlotsVec.data();
if (!mLogitsPtrsHost->data())
{
@ -288,12 +289,12 @@ void PenaltyLayer<T>::forwardAsync(
auto logitsPtrsHostData = reinterpret_cast<T const**>(runtime::bufferCast<int64_t>(*logitsPtrsHost));
for (SizeType32 bi = 0; bi < localDecoderDomain.getBatchSize(); bi++)
{
if (params->logits_vec)
if (params->logitsVec)
{
TLLM_CHECK_WITH_INFO(params->logits_vec->size() == localDecoderDomain.getBatchSize(),
"Logits vector size (%lu) is not equal to the batchSize (%d)", params->logits_vec->size(),
TLLM_CHECK_WITH_INFO(params->logitsVec->size() == localDecoderDomain.getBatchSize(),
"Logits vector size (%lu) is not equal to the batchSize (%d)", params->logitsVec->size(),
localDecoderDomain.getBatchSize());
logitsPtrsHostData[bi] = params->logits_vec.value()[bi].template getPtr<T>();
logitsPtrsHostData[bi] = params->logitsVec.value()[bi].template getPtr<T>();
}
else
{
@ -303,12 +304,11 @@ void PenaltyLayer<T>::forwardAsync(
}
SizeType32 const* inputLengths = nullptr;
if (params->input_lengths)
if (params->inputLengths)
{
auto& input_lengths = params->input_lengths.value();
inputLengths = input_lengths.template getPtr<SizeType32 const>();
inputLengths = params->inputLengths->template getPtr<SizeType32 const>();
}
auto* embeddingBias = params->embedding_bias ? params->embedding_bias->template getPtr<T const>() : nullptr;
auto* embeddingBias = params->embeddingBias ? params->embeddingBias->template getPtr<T const>() : nullptr;
#define GET_PENALTIES(capital_name, type) \
(mUse##capital_name \
&& !allOfBatchSlots(batchSlotsHost, m##capital_name.data(), localDecoderDomain.getBatchSize(), \
@ -324,9 +324,8 @@ void PenaltyLayer<T>::forwardAsync(
#undef GET_PENALTIES
auto const tokensPerStep = params->medusaInputs
? params->medusaInputs->medusaCurTokensPerStep.template getPtr<SizeType32 const>()
: nullptr;
auto const tokensPerStep
= params->curTokensPerStep ? params->curTokensPerStep->template getPtr<SizeType32 const>() : nullptr;
InvokeBatchApplyPenaltyParams<T> penaltyParams;
penaltyParams.inputLogits = reinterpret_cast<T const* const*>(logitsPtrsHostData);
@ -343,12 +342,12 @@ void PenaltyLayer<T>::forwardAsync(
penaltyParams.maxSeqLen = maxSeqLen;
penaltyParams.vocabSize = mDecoderDomain.getVocabSize();
penaltyParams.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
penaltyParams.outputIdsPtr = outputs->output_ids_ptr.template getPtr<TokenIdType const*>();
penaltyParams.parentIdsPtr = outputs->parent_ids_ptr.template getPtr<SizeType32 const*>();
penaltyParams.outputIdsPtr = outputs->outputIdsPtr.template getPtr<TokenIdType const*>();
penaltyParams.parentIdsPtr = outputs->parentIdsPtr.template getPtr<SizeType32 const*>();
penaltyParams.inputLengths = inputLengths;
penaltyParams.sequenceLengths = outputs->sequence_length->template getPtr<SizeType32 const>();
penaltyParams.sequenceLengths = outputs->sequenceLength->template getPtr<SizeType32 const>();
penaltyParams.minLengths = minLengths;
penaltyParams.endIds = params->end_ids.template getPtr<TokenIdType const>();
penaltyParams.endIds = params->endIds.template getPtr<TokenIdType const>();
penaltyParams.batchSlots = batchSlots;
penaltyParams.maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens();
penaltyParams.tokensPerStep = tokensPerStep;
@ -376,5 +375,4 @@ void PenaltyLayer<T>::forwardAsync(
template class PenaltyLayer<float>;
template class PenaltyLayer<half>;
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers
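
The hunks above only re-plumb how the penalty parameters reach invokeBatchApplyPenaltyParams; the penalty math itself is unchanged and not shown in this diff. For orientation, here is a host-side sketch of the conventional temperature, repetition, presence, and frequency penalties that these parameters control; it follows the usual definitions, not necessarily the exact order of operations inside the kernel.

#include <unordered_map>
#include <vector>

// Conventional penalty definitions, applied to one request's logits on the
// host: temperature divides every logit; the repetition penalty rescales the
// logits of already-seen tokens (divide if positive, multiply if negative);
// the presence and frequency penalties subtract a flat and a count-scaled
// amount from seen tokens.
void applyPenalties(std::vector<float>& logits, std::vector<int> const& previousTokens,
    float temperature, float repetitionPenalty, float presencePenalty, float frequencyPenalty)
{
    std::unordered_map<int, int> tokenCounts;
    for (int token : previousTokens)
    {
        ++tokenCounts[token];
    }
    for (float& logit : logits)
    {
        logit /= temperature;
    }
    for (auto const& [token, count] : tokenCounts)
    {
        float& logit = logits[token];
        logit = logit > 0.f ? logit / repetitionPenalty : logit * repetitionPenalty;
        logit -= presencePenalty + frequencyPenalty * static_cast<float>(count);
    }
}

int main()
{
    std::vector<float> logits{2.0f, -1.0f, 0.5f};
    applyPenalties(logits, /*previousTokens=*/{0, 0, 2}, /*temperature=*/0.8f,
        /*repetitionPenalty=*/1.1f, /*presencePenalty=*/0.2f, /*frequencyPenalty=*/0.1f);
    return 0;
}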

View File

@ -19,17 +19,12 @@
#include <curand_kernel.h>
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iTensor.h"
namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{
//! \brief Layer applies penalties to the logits. Supports:
@ -48,10 +43,11 @@ public:
~PenaltyLayer() override;
void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> setupParams) override;
std::shared_ptr<BaseSetupParams> const& setupParams) override;
//! \brief Modifies 'outputs->logits' in-place with -INF for banned words
void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;
T* getRuntimeLogitsDevice()
{
@ -103,5 +99,4 @@ private:
runtime::ITensor::SharedPtr mLogitsPtrsHost;
};
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

View File

@ -19,7 +19,8 @@
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
#include "tensorrt_llm/layers/topKSamplingLayer.h"
#include "tensorrt_llm/layers/topPSamplingLayer.h"
#include <algorithm>
@ -27,9 +28,7 @@ using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{
template <typename T>
SamplingLayer<T>::SamplingLayer(executor::DecodingMode const& mode, DecoderDomain const& decoderDomain,
@ -111,7 +110,7 @@ void SamplingLayer<T>::freeBuffer()
template <typename T>
void SamplingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> baseSetupParams)
std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@ -168,20 +167,17 @@ void SamplingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeTyp
template <typename T>
void SamplingLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& outputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto inputs = std::dynamic_pointer_cast<SamplingInputParams>(baseInputs);
auto outputs = std::dynamic_pointer_cast<SamplingOutputParams>(baseOutputs);
auto inputs = std::dynamic_pointer_cast<SamplingInputs>(baseInputs);
auto const batchSize = inputs->logits.shape[0];
auto const batchSize = inputs->logits->shape[0];
auto logits = inputs->logits.template getPtr<T>();
auto endIds = inputs->end_ids.template getPtr<int const>();
auto batchSlots = inputs->batch_slots ? inputs->batch_slots->template getPtr<int const>() : nullptr;
float* cumLogProbs = (outputs->cum_log_probs) ? outputs->cum_log_probs->template getPtr<float>() : nullptr;
float* outputLogProbs = (outputs->output_log_probs) ? outputs->output_log_probs->template getPtr<float>() : nullptr;
auto logits = inputs->logits->template getPtr<T>();
auto endIds = inputs->endIds.template getPtr<int const>();
auto batchSlots = inputs->batchSlots ? inputs->batchSlots->template getPtr<int const>() : nullptr;
FinishedState* finishedInput = (inputs->finished)
? reinterpret_cast<FinishedState*>(inputs->finished->template getPtr<FinishedState::UnderlyingType>())
@ -192,9 +188,9 @@ void SamplingLayer<T>::forwardAsync(
// Compute probabilities either for TopP or if cumLogProbs or outputLogProbs are specified
bool const skipSoftMax = skipTopP && !mOutputLogProbs && !mCumLogProbs;
inputs->curand_states = mCurandStatesDevice;
inputs->sampling_workspace = mSamplingWorkspaceDevice;
inputs->probs_computed = !skipSoftMax;
inputs->curandStates = mCurandStatesDevice;
inputs->samplingWorkspace = mSamplingWorkspaceDevice;
inputs->probsComputed = !skipSoftMax;
if (!skipSoftMax)
{
invokeAddBiasSoftMax(logits, (T**) nullptr, logits, (T*) (nullptr), endIds, finishedInput, batchSlots,
@ -205,7 +201,7 @@ void SamplingLayer<T>::forwardAsync(
for (auto&& layer : mSamplingLayers)
{
layer->forwardAsync(baseOutputs, baseInputs);
layer->forwardAsync(outputs, baseInputs);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@ -214,5 +210,4 @@ void SamplingLayer<T>::forwardAsync(
template class SamplingLayer<float>;
template class SamplingLayer<half>;
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers
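
The forwardAsync hunk above keeps the existing shortcut of skipping softmax unless top-p sampling, cumulative log-probs, or per-token log-probs are requested (skipSoftMax). A small standalone illustration of the underlying fact: softmax is strictly monotonic, so the top-k candidate set computed on raw logits is identical to the one computed on probabilities, and normalization only matters once the candidate set is fixed.

#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

// Returns the indices of the k largest values.
std::vector<int> topKIndices(std::vector<float> const& values, int k)
{
    std::vector<int> idx(values.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
        [&](int a, int b) { return values[a] > values[b]; });
    idx.resize(k);
    return idx;
}

int main()
{
    std::vector<float> const logits{1.2f, -0.3f, 3.1f, 0.7f};

    // Softmax over the logits (max-subtracted for numerical stability).
    std::vector<float> probs(logits.size());
    float const maxLogit = *std::max_element(logits.begin(), logits.end());
    float sum = 0.f;
    for (size_t i = 0; i < logits.size(); ++i)
    {
        probs[i] = std::exp(logits[i] - maxLogit);
        sum += probs[i];
    }
    for (float& p : probs)
    {
        p /= sum;
    }

    // Same top-k set either way, since softmax preserves ordering.
    auto const fromLogits = topKIndices(logits, 2);
    auto const fromProbs = topKIndices(probs, 2);
    return fromLogits == fromProbs ? 0 : 1;
}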

View File

@ -17,21 +17,14 @@
#pragma once
#include <curand_kernel.h>
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/layers/samplingParams.h"
#include "tensorrt_llm/layers/topKSamplingLayer.h"
#include "tensorrt_llm/layers/topPSamplingLayer.h"
#include "tensorrt_llm/runtime/common.h"
namespace tc = tensorrt_llm::common;
#include <curand_kernel.h>
namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{
//! \brief Top class for sampling layers.
@ -48,9 +41,10 @@ public:
~SamplingLayer() override = default;
void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> setupParams) override;
std::shared_ptr<BaseSetupParams> const& setupParams) override;
void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;
private:
using Base::mWorkspaceSize;
@ -82,5 +76,4 @@ private:
void freeBuffer();
};
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

View File

@ -1,77 +0,0 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/layers/decodingParams.h"
#include <tensorrt_llm/common/tensor.h>
#include <tensorrt_llm/runtime/common.h>
#include <optional>
#include <vector>
namespace tc = tensorrt_llm::common;
namespace tensorrt_llm::layers
{
class SamplingSetupParams : public BaseSetupParams
{
public:
std::optional<std::vector<runtime::SizeType32>> runtime_top_k; // [1] or [batchSize] on cpu
std::optional<std::vector<float>> runtime_top_p; // [1] or [batchSize] on cpu
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [batchSize] on cpu
std::optional<std::vector<float>> top_p_decay; // [batchSize], must between [0, 1]
std::optional<std::vector<float>> top_p_min; // [batchSize], must between [0, 1]
std::optional<std::vector<runtime::TokenIdType>> top_p_reset_ids; // [batchSize]
std::optional<std::vector<bool>> outputLogProbs; // [batchSize]
std::optional<std::vector<bool>> cumLogProbs; // [batchSize]
std::optional<bool> normalize_log_probs;
};
class SamplingInputParams : public BaseInputParams
{
public:
explicit SamplingInputParams(runtime::SizeType32 step, runtime::SizeType32 ite, tc::Tensor logits,
tc::Tensor end_ids, runtime::SizeType32 max_seq_len)
: BaseInputParams{step, ite, std::move(end_ids)}
, logits{std::move(logits)}
, max_seq_len{max_seq_len}
{
}
// mandatory parameters
tc::Tensor logits; // [local_batch_size, beam_width, vocab_size_padded]
runtime::SizeType32 max_seq_len;
// optional parameters
std::optional<tc::Tensor> input_lengths; // [localBatchSize]
curandState_t* curand_states; // [localBatchSize]
// Pointer to the workspace for sampling computation
void* sampling_workspace;
// Flag to mark that logits tensor contains probabilities
bool probs_computed;
};
class SamplingOutputParams : public BaseOutputParams
{
public:
explicit SamplingOutputParams(tc::Tensor outputIds)
: BaseOutputParams{std::move(outputIds)}
{
}
};
} // namespace tensorrt_llm::layers

View File

@ -17,19 +17,14 @@
#include "tensorrt_llm/layers/stopCriteriaLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/stopCriteriaKernels.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include <algorithm>
using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{
template <typename T>
@ -44,7 +39,7 @@ StopCriteriaLayer<T>::StopCriteriaLayer(executor::DecodingMode const& mode, Deco
template <typename T>
void StopCriteriaLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> setupParams)
std::shared_ptr<BaseSetupParams> const& setupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@ -52,16 +47,18 @@ void StopCriteriaLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, Siz
template <typename T>
void StopCriteriaLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto inputs = std::dynamic_pointer_cast<DynamicDecodeInputParams>(baseInputs);
auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
auto inputs = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);
auto outputs = std::dynamic_pointer_cast<BaseDecodingOutputs>(baseOutputs);
auto const localDecoderDomain = getLocalDecoderDomain(inputs, mDecoderDomain);
auto const maxSeqLen = outputs->output_ids.shape[outputs->output_ids.shape.size() - 1];
auto batchSlots = inputs->batch_slots ? inputs->batch_slots->template getPtr<SizeType32 const>() : nullptr;
auto const maxSeqLen = outputs->outputIds.shape[outputs->outputIds.shape.size() - 1];
auto batchSlots = inputs->batchSlots ? inputs->batchSlots->template getPtr<SizeType32 const>() : nullptr;
TLLM_CHECK_WITH_INFO(inputs->stopCriteriaInputs, "stopCriteriaInputs for forward is not set");
if (mDecodingMode.isUseStopWords())
{
@ -80,60 +77,61 @@ void StopCriteriaLayer<T>::forwardAsync(
}
template <typename T>
void StopCriteriaLayer<T>::checkStopWordsStopCriteria(std::shared_ptr<DynamicDecodeOutputParams>& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& inputs, SizeType32 const* batchSlots,
DecoderDomain const& decoderDomain, SizeType32 maxSeqLen, cudaStream_t stream)
void StopCriteriaLayer<T>::checkStopWordsStopCriteria(std::shared_ptr<BaseDecodingOutputs>& outputs,
std::shared_ptr<DecodingInputs> const& inputs, SizeType32 const* batchSlots, DecoderDomain const& decoderDomain,
SizeType32 maxSeqLen, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const maxStopWordsLength = inputs->max_stop_words_len;
auto const maxStopWordsLength = inputs->stopCriteriaInputs->maxStopWordsLen;
if (maxStopWordsLength)
{
auto numNewTokens = outputs->speculativeDecodingOutputs
? outputs->speculativeDecodingOutputs->acceptedLengths.template getPtr<SizeType32>()
: nullptr;
invokeStopWordsCriterion(outputs->output_ids_ptr.template getPtr<TokenIdType const*>(),
outputs->parent_ids_ptr.template getPtr<SizeType32 const*>(),
inputs->stop_words_ptr->template getPtr<TokenIdType const*>(),
auto numNewTokens = outputs->numNewTokens ? outputs->numNewTokens->template getPtr<SizeType32>() : nullptr;
invokeStopWordsCriterion(outputs->outputIdsPtr.template getPtr<TokenIdType const*>(),
outputs->parentIdsPtr.template getPtr<SizeType32 const*>(),
inputs->stopCriteriaInputs->stopWordsPtr->template getPtr<TokenIdType const*>(),
reinterpret_cast<FinishedState*>(outputs->finished->template getPtr<FinishedState::UnderlyingType>()),
outputs->sequence_length->template getPtr<SizeType32>(), batchSlots,
inputs->stop_words_lengths->template getPtr<SizeType32 const>(), numNewTokens, maxStopWordsLength,
decoderDomain.getBatchSize(), decoderDomain.getBeamWidth(), maxSeqLen, stream);
outputs->sequenceLength->template getPtr<SizeType32>(), batchSlots,
inputs->stopCriteriaInputs->stopWordsLengths->template getPtr<SizeType32 const>(), numNewTokens,
maxStopWordsLength, decoderDomain.getBatchSize(), decoderDomain.getBeamWidth(), maxSeqLen, stream);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void StopCriteriaLayer<T>::checkMaxLengthStopCriteria(std::shared_ptr<DynamicDecodeOutputParams>& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& inputs, SizeType32 const* batchSlots,
DecoderDomain const& decoderDomain, SizeType32 maxSeqLen, cudaStream_t stream)
void StopCriteriaLayer<T>::checkMaxLengthStopCriteria(std::shared_ptr<BaseDecodingOutputs>& outputs,
std::shared_ptr<DecodingInputs> const& inputs, SizeType32 const* batchSlots, DecoderDomain const& decoderDomain,
SizeType32 maxSeqLen, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (inputs->sequence_limit_length)
if (inputs->stopCriteriaInputs->sequenceLimitLength)
{
auto numNewTokens = outputs->numNewTokens ? outputs->numNewTokens->template getPtr<SizeType32>() : nullptr;
invokeLengthCriterion(
reinterpret_cast<FinishedState*>(outputs->finished->template getPtr<FinishedState::UnderlyingType>()),
outputs->finished_sum ? outputs->finished_sum->template getPtr<SizeType32>() : nullptr,
inputs->sequence_limit_length->template getPtr<SizeType32 const>(),
outputs->sequence_length->template getPtr<SizeType32>(), batchSlots, decoderDomain.getBatchSize(),
decoderDomain.getBeamWidth(), stream);
outputs->finishedSum ? outputs->finishedSum->template getPtr<SizeType32>() : nullptr,
inputs->stopCriteriaInputs->sequenceLimitLength->template getPtr<SizeType32 const>(),
outputs->sequenceLength->template getPtr<SizeType32>(), numNewTokens, batchSlots,
decoderDomain.getBatchSize(), decoderDomain.getBeamWidth(), stream);
sync_check_cuda_error();
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void StopCriteriaLayer<T>::checkEosToken(std::shared_ptr<DynamicDecodeOutputParams>& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& inputs, SizeType32 const* batchSlots,
DecoderDomain const& decoderDomain, SizeType32 maxSeqLen, cudaStream_t stream)
void StopCriteriaLayer<T>::checkEosToken(std::shared_ptr<BaseDecodingOutputs>& outputs,
std::shared_ptr<DecodingInputs> const& inputs, SizeType32 const* batchSlots, DecoderDomain const& decoderDomain,
SizeType32 maxSeqLen, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
invokeExplicitEOSCriterion(outputs->output_ids_ptr.template getPtr<TokenIdType const*>(),
inputs->end_ids.template getPtr<TokenIdType const>(),
auto numNewTokens = outputs->numNewTokens ? outputs->numNewTokens->template getPtr<SizeType32>() : nullptr;
invokeExplicitEOSCriterion(outputs->outputIdsPtr.template getPtr<TokenIdType const*>(),
inputs->endIds.template getPtr<TokenIdType const>(),
reinterpret_cast<FinishedState*>(outputs->finished->template getPtr<FinishedState::UnderlyingType>()),
outputs->sequence_length->template getPtr<SizeType32>(),
// FIXME(nkorobov): add tokens per step tensor when necessary
/* tokensPerStep */ nullptr, batchSlots, decoderDomain.getBatchSize(), decoderDomain.getBeamWidth(),
decoderDomain.getMaxDecodingTokens(), stream);
outputs->sequenceLength->template getPtr<SizeType32>(), numNewTokens, batchSlots, decoderDomain.getBatchSize(),
decoderDomain.getBeamWidth(), decoderDomain.getMaxDecodingTokens(), stream);
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -141,5 +139,4 @@ void StopCriteriaLayer<T>::checkEosToken(std::shared_ptr<DynamicDecodeOutputPara
template class StopCriteriaLayer<float>;
template class StopCriteriaLayer<half>;
} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers
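
The stop-criteria hunks above thread a per-request numNewTokens pointer into the length and explicit-EOS checks. The motivation is speculative decoding: a single step may accept several draft tokens, so a request that crosses its sequence-length limit mid-step also needs its count of newly accepted tokens trimmed. Below is an illustrative host-side sketch with a stand-in struct, not the CUDA kernels from this diff.

#include <algorithm>
#include <cstdint>

struct RequestState
{
    int32_t sequenceLength;      // tokens held for this request after the step
    int32_t sequenceLimitLength; // hard cap for this request
    int32_t numNewTokens;        // tokens accepted in this step (1 without drafting)
    bool finished{false};
};

void applyLengthCriterion(RequestState& req)
{
    if (req.sequenceLength >= req.sequenceLimitLength)
    {
        // A speculative step may overshoot the limit by several tokens; clamp
        // both the sequence length and the number of tokens counted as new.
        auto const overshoot = req.sequenceLength - req.sequenceLimitLength;
        req.numNewTokens = std::max(0, req.numNewTokens - overshoot);
        req.sequenceLength = req.sequenceLimitLength;
        req.finished = true;
    }
}

int main()
{
    RequestState req{/*sequenceLength=*/103, /*sequenceLimitLength=*/100, /*numNewTokens=*/5};
    applyLengthCriterion(req); // leaves sequenceLength == 100, numNewTokens == 2
    return req.finished ? 0 : 1;
}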

Some files were not shown because too many files have changed in this diff