Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
Update TensorRT-LLM (#1793)

Co-authored-by: DreamGenX <x@dreamgen.com>
Co-authored-by: Ace-RR <78812427+Ace-RR@users.noreply.github.com>
Co-authored-by: bprus <39293131+bprus@users.noreply.github.com>
Co-authored-by: janpetrov <janpetrov@icloud.com>

parent db4edea1e1, commit 2a115dae84
@@ -232,7 +232,7 @@ ${HOME}/.local/bin/trtllm-build \
--output_dir ${LORA_ENGINE} \
--max_batch_size ${MAX_BATCH} \
--max_input_len $MAX_LEN \
--max_output_len $MAX_LEN \
--max_seq_len $((2*${MAX_LEN})) \
--gemm_plugin float16 \
--lora_plugin float16 \
--use_paged_context_fmha enable \
@@ -17,6 +17,7 @@
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/rawEngine.h"
#include "tensorrt_llm/runtime/tllmLogger.h"
#include "tensorrt_llm/runtime/tllmRuntime.h"
#include "tensorrt_llm/runtime/worldConfig.h"
@@ -78,11 +79,10 @@ void benchmarkBert(std::string const& modelName, std::filesystem::path const& da
{
auto const worldConfig = WorldConfig::mpi();
auto const enginePath = dataPath / engineFilename(dataPath, worldConfig, modelName);
auto engineBlob = loadEngine(enginePath.string());

for (float gpuWeightsPercent : gpuWeightsPercents)
{
auto rt = std::make_shared<TllmRuntime>(engineBlob.data(), engineBlob.size(), gpuWeightsPercent, *logger);
auto rt = std::make_shared<TllmRuntime>(RawEngine(enginePath), logger.get(), gpuWeightsPercent);
rt->addContext(0);
for (auto inLen : inLens)
{
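The hunk above swaps the blob-based TllmRuntime constructor for one that takes a RawEngine and a raw logger pointer. Below is a minimal, hedged sketch of the new call shape, assuming the runtime namespace and a caller-owned logger; enginePath and gpuWeightsPercent are placeholders supplied by the caller.

```cpp
#include <filesystem>
#include <memory>

#include "tensorrt_llm/runtime/rawEngine.h"
#include "tensorrt_llm/runtime/tllmLogger.h"
#include "tensorrt_llm/runtime/tllmRuntime.h"

using namespace tensorrt_llm::runtime; // namespace assumed

// Sketch only: the engine is now referenced by path through RawEngine instead of a
// pre-loaded serialized blob, and the logger is passed as a raw pointer.
std::shared_ptr<TllmRuntime> makeRuntime(std::filesystem::path const& enginePath,
    std::shared_ptr<TllmLogger> const& logger, float gpuWeightsPercent)
{
    auto rt = std::make_shared<TllmRuntime>(RawEngine(enginePath), logger.get(), gpuWeightsPercent);
    rt->addContext(0); // one execution context, as in the benchmark loop above
    return rt;
}
```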
@@ -150,6 +150,7 @@ struct BenchmarkParams
bool streaming{false};
bool enableExpDelays{false};
std::optional<float> requestRate{std::nullopt};
std::optional<SizeType32> maxBatchSize{std::nullopt};
int randomSeed = 430;
std::optional<int> maxAttentionWindow{std::nullopt};
@@ -785,6 +786,10 @@ public:
executorConfig.setPeftCacheConfig(peftCacheConfig);
executorConfig.setBatchingType(
modelType == TrtGptModelType::V1 ? texec::BatchingType::kSTATIC : texec::BatchingType::kINFLIGHT);
if (benchmarkParams.maxBatchSize)
{
executorConfig.setMaxBatchSize(benchmarkParams.maxBatchSize.value());
}

mExecutor = std::make_unique<texec::Executor>(trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig);
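This hunk lets the benchmark cap the runtime batch size through the executor API instead of relying only on the engine's build-time limit. A hedged sketch of the same calls in isolation; the default-constructed config and the batch-size value are illustrative assumptions.

```cpp
#include <filesystem>
#include <memory>

#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

std::unique_ptr<texec::Executor> makeExecutor(std::filesystem::path const& trtEnginePath)
{
    texec::ExecutorConfig executorConfig;                            // defaults assumed sufficient here
    executorConfig.setBatchingType(texec::BatchingType::kINFLIGHT);  // in-flight batching, as above
    executorConfig.setMaxBatchSize(64);                              // runtime override added by this change; 64 is an example
    return std::make_unique<texec::Executor>(
        trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig);
}
```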
@@ -1339,6 +1344,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
optionalParams.kvCacheConfig.onboardBlocks = benchmarkParams.kvOnboardBlocks;
optionalParams.gpuWeightsPercent = benchmarkParams.gpuWeightsPercent;
optionalParams.maxBeamWidth = beamWidth;
optionalParams.maxBatchSize = benchmarkParams.maxBatchSize;
optionalParams.schedulerConfig = texec::SchedulerConfig{capacitySchedulerPolicy};

auto const jsonConfig = GptJsonConfig::parse(engineDir / "config.json");
@@ -1628,6 +1634,7 @@ int main(int argc, char* argv[])
options.add_options()("request_rate",
"request rate in reqs/sec. Skipping this arg or negative value will trigger offline/0-delay.",
cxxopts::value<float>());
options.add_options()("max_batch_size", "The max runtime batch size when benchmarking", cxxopts::value<int>());
options.add_options()("enable_trt_overlap", "Overlap TRT context preparation and execution",
cxxopts::value<bool>()->default_value("false"));
options.add_options()("enable_exp_delays", "Enables exponential delay distr to mimic real world request arrival",
@@ -1777,6 +1784,12 @@ int main(int argc, char* argv[])
benchmarkParams.requestRate = result["request_rate"].as<float>();
}

// Argument: request rate
if (result.count("max_batch_size"))
{
benchmarkParams.maxBatchSize = result["max_batch_size"].as<int>();
}

benchmarkParams.enableExpDelays = result["enable_exp_delays"].as<bool>();

// Argument: Enable batch stats output
@@ -32,7 +32,7 @@ class BuildConfig:
max_batch_size: int
max_input_len: Optional[int] = None
num_kv_heads: Optional[int] = None
max_output_len: Optional[int] = None
max_seq_len: Optional[int] = None
max_beam_width: int = 1
# TRT builder_optimization_level from 0 to 5
builder_opt: Optional[int] = None
@@ -89,7 +89,7 @@ class EncDecBuildConfig:
normalize_before: Optional[bool] = None
max_encoder_input_len: Optional[int] = None
max_decoder_input_len: Optional[int] = None
max_output_len: Optional[int] = None
max_seq_len: Optional[int] = None
builder_opt: Optional[int] = None
n_mels: Optional[int] = None
skip_cross_qkv: bool = False
@@ -122,7 +122,7 @@ _allowed_configs = {
n_positions=1024,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"gpt_1.5b":
@@ -138,7 +138,7 @@ _allowed_configs = {
n_positions=1024,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"gpt_175b":
@@ -154,7 +154,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"gpt_350m_moe":
@@ -170,7 +170,7 @@ _allowed_configs = {
n_positions=1024,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
moe_num_experts=8,
moe_top_k=1,
@@ -188,7 +188,7 @@ _allowed_configs = {
n_positions=1024,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
quantization="int8_sq_per_tensor",
)),
@@ -205,7 +205,7 @@ _allowed_configs = {
n_positions=1024,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
quantization="int8_sq_per_token_channel",
)),
@@ -222,7 +222,7 @@ _allowed_configs = {
n_positions=1024,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
position_embedding_type='rope_gpt_neox',
rotary_pct=0.5,
@@ -241,7 +241,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
pre_norm=False,
do_layer_norm_before=False,
@@ -259,7 +259,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
pre_norm=False,
do_layer_norm_before=True,
@@ -277,7 +277,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
pre_norm=False,
do_layer_norm_before=True,
@@ -295,7 +295,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
pre_norm=False,
do_layer_norm_before=True,
@@ -313,7 +313,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
pre_norm=True,
do_layer_norm_before=True,
@@ -332,7 +332,7 @@ _allowed_configs = {
n_positions=8192,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"starcoder2_3b":
@@ -351,7 +351,7 @@ _allowed_configs = {
rotary_pct=1.0,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"llama_7b":
@@ -368,7 +368,7 @@ _allowed_configs = {
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"llama_13b":
@@ -385,7 +385,7 @@ _allowed_configs = {
inter_size=13824,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"llama_30b":
@@ -402,7 +402,7 @@ _allowed_configs = {
inter_size=17920,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"llama_70b":
@@ -420,7 +420,7 @@ _allowed_configs = {
inter_size=28672,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"llama_70b_long_context":
@@ -437,7 +437,7 @@ _allowed_configs = {
inter_size=28672,
max_batch_size=16,
max_input_len=8000,
max_output_len=200,
max_seq_len=8200,
builder_opt=None,
enable_multi_block_mode=True)),
"llama_70b_long_generation":
@@ -454,7 +454,7 @@ _allowed_configs = {
inter_size=28672,
max_batch_size=64,
max_input_len=200,
max_output_len=16384,
max_seq_len=16584,
builder_opt=None,
enable_multi_block_mode=True)),
"llama_70b_sq_per_tensor":
@@ -471,7 +471,7 @@ _allowed_configs = {
inter_size=28672,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
quantization="int8_sq_per_tensor")),
"mixtral_8x7b":
@@ -489,7 +489,7 @@ _allowed_configs = {
inter_size=14336,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
moe_num_experts=8,
moe_top_k=2,
@@ -508,7 +508,7 @@ _allowed_configs = {
rotary_dim=64,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"gptneox_20b":
@@ -525,7 +525,7 @@ _allowed_configs = {
rotary_dim=24,
max_batch_size=16,
max_input_len=512,
max_output_len=512,
max_seq_len=1024,
builder_opt=None,
)),
"chatglm_6b":
@@ -543,7 +543,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
remove_input_padding=False,
)),
@@ -562,7 +562,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
remove_input_padding=False,
)),
@@ -581,7 +581,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
remove_input_padding=False,
)),
@@ -600,7 +600,7 @@ _allowed_configs = {
n_positions=1024,
max_batch_size=128,
max_input_len=1024,
max_output_len=256,
max_seq_len=1280,
builder_opt=None,
remove_input_padding=False,
)),
@@ -617,7 +617,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=32,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
builder_opt=None,
)),
"bloom_176b":
@@ -633,7 +633,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=8,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
builder_opt=None,
)),
"bert_base":
@@ -703,7 +703,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=256,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
builder_opt=None,
bias=True,
use_alibi=True,
@@ -724,7 +724,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
bias=False,
use_alibi=False,
@@ -745,7 +745,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
bias=False,
use_alibi=False,
@@ -766,7 +766,7 @@ _allowed_configs = {
n_positions=2048,
max_batch_size=8,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
builder_opt=None,
bias=False,
use_alibi=False,
@@ -791,7 +791,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"t5_base":
@@ -812,7 +812,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"t5_large":
@@ -833,7 +833,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"t5_3b":
@@ -854,7 +854,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"t5_11b":
@@ -875,7 +875,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"flan_t5_small":
@@ -897,7 +897,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"flan_t5_base":
@@ -919,7 +919,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"flan_t5_large":
@@ -941,7 +941,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"flan_t5_xl":
@@ -963,7 +963,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"flan_t5_xxl":
@@ -985,7 +985,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"bart_large_cnn":
@@ -1008,7 +1008,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"mbart_large_50_many_to_one_mmt":
@@ -1030,7 +1030,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"baichuan_7b":
@@ -1047,7 +1047,7 @@ _allowed_configs = {
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"baichuan2_7b_chat":
@@ -1064,7 +1064,7 @@ _allowed_configs = {
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"baichuan_13b_chat":
@@ -1081,7 +1081,7 @@ _allowed_configs = {
inter_size=13696,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"baichuan2_13b_chat":
@@ -1098,7 +1098,7 @@ _allowed_configs = {
inter_size=13696,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"internlm_chat_7b":
@@ -1116,7 +1116,7 @@ _allowed_configs = {
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
bias=True,
)),
@@ -1135,7 +1135,7 @@ _allowed_configs = {
inter_size=13824,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
bias=False,
)),
@@ -1152,7 +1152,7 @@ _allowed_configs = {
inter_size=22016,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
bias=False)),
"qwen_14b_chat":
@@ -1169,7 +1169,7 @@ _allowed_configs = {
inter_size=27392,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"qwen1.5_7b_chat":
@@ -1185,7 +1185,7 @@ _allowed_configs = {
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
bias=False)),
"qwen1.5_14b_chat":
@@ -1202,7 +1202,7 @@ _allowed_configs = {
inter_size=13696,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
max_seq_len=712,
builder_opt=None,
)),
"mamba_2.8b":
@@ -1218,7 +1218,7 @@ _allowed_configs = {
n_positions=8192,
max_batch_size=64,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
state_size=16,
conv_kernel=4,
rnn_hidden_size=5120,
@@ -1238,7 +1238,7 @@ _allowed_configs = {
n_positions=8192,
max_batch_size=64,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
state_size=16,
conv_kernel=4,
rnn_hidden_size=4096,
@@ -1258,7 +1258,7 @@ _allowed_configs = {
n_positions=8192,
max_batch_size=64,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
state_size=16,
conv_kernel=4,
rnn_hidden_size=3072,
@@ -1278,7 +1278,7 @@ _allowed_configs = {
n_positions=8192,
max_batch_size=64,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
state_size=16,
conv_kernel=4,
rnn_hidden_size=2048,
@@ -1298,7 +1298,7 @@ _allowed_configs = {
n_positions=8192,
max_batch_size=64,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
state_size=16,
conv_kernel=4,
rnn_hidden_size=1536,
@@ -1323,7 +1323,7 @@ _allowed_configs = {
max_batch_size=8,
max_encoder_input_len=1500,
max_decoder_input_len=1,
max_output_len=200,
max_seq_len=201,
builder_opt=None,
)),
"recurrentgemma_2b":
@@ -1341,7 +1341,7 @@ _allowed_configs = {
n_positions=8192,
max_batch_size=64,
max_input_len=1024,
max_output_len=1024,
max_seq_len=2048,
position_embedding_type='rope_gpt_neox',
rotary_pct=0.5,
conv_kernel=4,
@@ -184,6 +184,15 @@ def parse_arguments():
help=
('If this option is specified, it will override the max output len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_seq_len',
'--max_decoder_seq_len',
dest='max_seq_len',
type=int,
default=None,
help=
('If this option is specified, it will override the max sequence len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_batch_size',
type=int,
@@ -351,6 +360,21 @@ def main(args):
rank = tensorrt_llm.mpi_rank()
world_size = tensorrt_llm.mpi_world_size()

if args.max_output_len:
logger.warning(
'--max_output_len has been deprecated in favor of --max_seq_len')
if args.max_input_len:
if args.max_seq_len:
logger.warning(
'--max_seq_len has been overwritten due to --max_output_len being specified'
)
args.max_seq_len = args.max_input_len + args.max_output_len
else:
raise Exception(
f"--max_output_len is specified but not --max_input_len")

del args.max_output_len

# TODO: Re-enable memory monitor for multi-gpu benchmarks.
# Current Mem Monitor will cause benchmark script hang
# because MPI does not work well with multiprocessing.
@@ -136,6 +136,15 @@ def parse_arguments():
help=
('If this option is specified, it will override the max output len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_seq_len',
'--max_decoder_seq_len',
dest='max_seq_len',
type=int,
default=None,
help=
('If this option is specified, it will override the max sequence len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_batch_size',
type=int,
@@ -254,8 +263,24 @@ def build_gpt(args):
if args.max_batch_size is None else args.max_batch_size
max_input_len = build_config['max_input_len'] \
if args.max_input_len is None else args.max_input_len
max_output_len = build_config['max_output_len'] \
if args.max_output_len is None else args.max_output_len

if args.max_output_len:
logger.warning(
'--max_output_len has been deprecated in favor of --max_seq_len')
if args.max_input_len:
if args.max_seq_len:
logger.warning(
'--max_seq_len has been overwritten due to --max_output_len being specified'
)
args.max_seq_len = args.max_input_len + args.max_output_len
else:
raise Exception(
f"max_output_len is specified but not max_input_len")

del args.max_output_len

max_seq_len = build_config['max_seq_len'] \
if args.max_seq_len is None else args.max_seq_len
max_beam_width = build_config['max_beam_width'] \
if args.max_beam_width is None else args.max_beam_width
@@ -308,7 +333,7 @@ def build_gpt(args):
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_input_len=max_input_len,
max_output_len=max_output_len,
max_seq_len=max_seq_len,
max_num_tokens=max_num_tokens,
int8=(quant_mode.has_act_and_weight_quant()
or quant_mode.is_int8_weight_only()),
@@ -675,7 +700,6 @@ def build_gpt(args):
config['quantization'].update({
'has_zero_point': False,
'pre_quant_scale': True,
'exclude_modules': [],
})
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(config)
@@ -759,7 +783,6 @@ def build_gpt(args):
"group_size": 128,
"has_zero_point": False,
"pre_quant_scale": True,
"exclude_modules": [],
})
elif 'gptq' in args.quantization:
config['quantization'].update({
@@ -968,14 +991,14 @@ def build_gpt(args):

# Forward
print(
f"max_batch_size: {max_batch_size}, max_input_len: {max_input_len}, max_output_len: {max_output_len}, max_beam_width: {max_beam_width}"
f"max_batch_size: {max_batch_size}, max_input_len: {max_input_len}, max_seq_len: {max_seq_len}, max_beam_width: {max_beam_width}"
)
# NOTE: all other models use PretrainedModel.prepare_inputs(...)
# except RecurrentGemmaForCausalLM and MambaForCausalLM
inputs = tensorrt_llm_model.prepare_inputs(
max_batch_size=max_batch_size,
max_input_len=max_input_len,
max_seq_len=max_input_len + max_output_len,
max_seq_len=max_seq_len,
max_num_tokens=max_num_tokens,
use_cache=True,
max_beam_width=max_beam_width,
@@ -1231,7 +1254,7 @@ def enc_dec_build_helper(component, config, args):
max_batch_size=config['max_batch_size'],
max_beam_width=config['max_beam_width'],
max_decoder_input_len=config['max_decoder_input_len'],
max_output_len=config['max_output_len'],
max_seq_len=config['max_seq_len'],
max_encoder_input_len=config['max_encoder_input_len'],
opt_level=config['builder_opt'],
cross_attention=(component == 'decoder'),
@@ -1473,7 +1496,7 @@ def enc_dec_build_helper(component, config, args):
max_batch_size=config['max_batch_size'],
max_beam_width=config['max_beam_width'],
max_decoder_input_len=config['max_decoder_input_len'],
max_seq_len=config['max_output_len'],
max_seq_len=config['max_seq_len'],
max_encoder_input_len=1500,  # n_audio_ctx
)
tllm_model(**inputs)
@@ -1482,7 +1505,7 @@ def enc_dec_build_helper(component, config, args):
max_batch_size=config['max_batch_size'],
max_beam_width=config['max_beam_width'],
max_decoder_input_len=config['max_decoder_input_len'],
max_seq_len=config['max_output_len'],
max_seq_len=config['max_seq_len'],
max_encoder_input_len=config['max_encoder_input_len'],
)

@@ -1548,8 +1571,24 @@ def build_enc_dec(args):
build_config['max_encoder_input_len'] = build_config['max_encoder_input_len'] \
if args.max_input_len is None else args.max_input_len
build_config['max_decoder_input_len'] = 1
build_config['max_output_len'] = build_config['max_output_len'] \
if args.max_output_len is None else args.max_output_len

if args.max_output_len:
logger.warning(
'--max_output_len has been deprecated in favor of --max_seq_len')
if args.max_input_len:
if args.max_seq_len:
logger.warning(
'--max_seq_len has been overwritten due to --max_output_len being specified'
)
args.max_seq_len = args.max_input_len + args.max_output_len
else:
raise Exception(
f"max_output_len is specified but not max_input_len")

del args.max_output_len

build_config['max_seq_len'] = build_config['max_seq_len'] \
if args.max_seq_len is None else args.max_seq_len
build_config[
'max_beam_width'] = 1 if args.max_beam_width is None else args.max_beam_width
@@ -115,7 +115,7 @@ class EncDecBenchmark(BaseBenchmark):
self.max_batch_size = config["builder_config"]["max_batch_size"]
self.max_input_len = config["builder_config"][
"max_encoder_input_len"]
self.max_output_len = config["builder_config"]["max_output_len"]
self.max_seq_len = config["builder_config"]["max_seq_len"]
self.n_mels = config["builder_config"][
'n_mels'] if 'whisper' in self.model_name else 0

@@ -180,8 +180,8 @@ class EncDecBenchmark(BaseBenchmark):
if args.max_batch_size is None else args.max_batch_size
self.max_input_len = build_config['max_encoder_input_len'] \
if args.max_input_len is None else args.max_input_len
self.max_output_len = build_config['max_output_len'] \
if args.max_output_len is None else args.max_output_len
self.max_seq_len = build_config['max_seq_len'] \
if args.max_seq_len is None else args.max_seq_len
self.n_mels = build_config[
'n_mels'] if 'whisper' in self.model_name else 0
# Build engine
@@ -218,10 +218,11 @@ class EncDecBenchmark(BaseBenchmark):
f"[WARNING] whisper benchmark is input_len=1500, no text prompt, output_len=arbitrary"
)
for inlen, outlen in self.in_out_lens:
if (inlen > self.max_input_len or outlen > self.max_output_len):
if (inlen > self.max_input_len
or inlen + outlen > self.max_seq_len):
print(
f"[WARNING] check inlen({inlen}) <= max_inlen({self.max_input_len}) and "
f"outlen({outlen}) <= max_outlen({self.max_output_len}) failed, skipping."
f"inlen({inlen}) + outlen({outlen}) <= max_seqlen({self.max_seq_len}) failed, skipping."
)
continue
for batch_size in self.batch_sizes:
@@ -88,8 +88,8 @@ class GPTBenchmark(BaseBenchmark):
self.max_batch_size = args.max_batch_size
if args.max_input_len is not None:
self.max_input_len = args.max_input_len
if args.max_output_len is not None:
self.max_output_len = args.max_output_len
if args.max_seq_len is not None:
self.max_seq_len = args.max_seq_len

self.quant_config = get_quant_config(args.quantization)
self.quant_mode = self.quant_config.quant_mode
@@ -209,10 +209,10 @@ class GPTBenchmark(BaseBenchmark):

def get_config(self):
for inlen, outlen in self.in_out_lens:
if inlen > self.max_input_len or outlen > self.max_output_len:
if inlen > self.max_input_len or inlen + outlen > self.max_seq_len:
print(
f'[WARNING] check inlen({inlen}) <= max_inlen({self.max_input_len}) and '
f'outlen({outlen}) <= max_outlen({self.max_output_len}) failed, skipping.'
f'[WARNING] check inlen({inlen}) <= max_inlen({self.max_input_len}) or '
f'seqlen({inlen + outlen}) <= max_seq_len({self.max_seq_len}) failed, skipping.'
)
continue
for batch_size in self.batch_sizes:
@@ -314,7 +314,7 @@ class GPTBenchmark(BaseBenchmark):
output_length=outlen,
max_batch_size=self.build_config.max_batch_size,
max_input_len=self.build_config.max_input_len,
max_output_len=self.build_config.max_output_len,
max_seq_len=self.build_config.max_seq_len,
max_beam_width=self.build_config.max_beam_width)
for k, v in build_args.items():
tensorrt_llm.logger.info(f"{prefix} {k}:{v}")
@@ -84,8 +84,8 @@ class gptSessionBenchmarker:
max_batch_size,
"--max_input_len",
max_isl,
"--max_output_len",
max_osl,
"--max_seq_len",
max_osl + max_isl,
"--context_fmha",
"enable",
# Set the attention plugin data type.

@@ -140,8 +140,8 @@ def get_trtllm_build_command(benchmark_cfg: BenchmarkConfig) -> List[str]:
benchmark_cfg.world_size,
"--max_input_len",
max_isl,
"--max_output_len",
max_osl,
"--max_seq_len",
max_osl + max_isl,
"--context_fmha",
"enable",
# Set the attention plugin data type.
@@ -82,6 +82,9 @@ struct KvCacheStats
SizeType32 freeNumBlocks;
SizeType32 usedNumBlocks;
SizeType32 toksPerBlock;
SizeType32 allocTotalBlocks;
SizeType32 allocNewBlocks;
SizeType32 reusedBlocks;
};

// Basic building block of a paged KV cache - a single
@@ -329,6 +332,16 @@ public:
return mFreePrimaryBlocks.size();
}

[[nodiscard]] SizeType32 getNumAllocTotalBlocks() const
{
return mAllocTotalBlocks;
}

[[nodiscard]] SizeType32 getNumAllocNewBlocks() const
{
return mAllocNewBlocks;
}

[[nodiscard]] SizeType32 getNumReusedBlocks() const noexcept
{
return mReusedBlocks;
@@ -496,6 +509,21 @@ public:
return mBlockManager.getNumFreeBlocks();
}

[[nodiscard]] SizeType32 getNumAllocTotalBlocks() const
{
return mBlockManager.getNumAllocTotalBlocks();
}

[[nodiscard]] SizeType32 getNumAllocNewBlocks() const
{
return mBlockManager.getNumAllocNewBlocks();
}

[[nodiscard]] SizeType32 getNumReusedBlocks() const noexcept
{
return mBlockManager.getNumReusedBlocks();
}

[[nodiscard]] KvCacheStats getKvCacheStats() const
{
KvCacheStats kvCacheStats;
@@ -503,6 +531,9 @@ public:
kvCacheStats.freeNumBlocks = getNumFreeBlocks();
kvCacheStats.usedNumBlocks = getUsedNumBlocks();
kvCacheStats.toksPerBlock = getTokensPerBlock();
kvCacheStats.allocTotalBlocks = getNumAllocTotalBlocks();
kvCacheStats.allocNewBlocks = getNumAllocNewBlocks();
kvCacheStats.reusedBlocks = getNumReusedBlocks();

return kvCacheStats;
}
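The new getters surface block-allocation counters through getKvCacheStats(). A hedged sketch of reading them follows; the manager class name, namespace, and header path are assumptions made for illustration, since only the member functions are visible in this hunk.

```cpp
#include <cstdio>

#include "tensorrt_llm/batch_manager/kvCacheManager.h" // header path assumed

namespace kvm = tensorrt_llm::batch_manager::kv_cache_manager; // namespace assumed

void logKvCacheBlockStats(kvm::KVCacheManager const& manager) // class name assumed
{
    auto const stats = manager.getKvCacheStats();
    // The three counters introduced by this change:
    std::printf("kv-cache blocks: allocTotal=%d allocNew=%d reused=%d\n",
        stats.allocTotalBlocks, stats.allocNewBlocks, stats.reusedBlocks);
}
```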
@@ -62,7 +62,7 @@ public:
using VecLogProbs = std::vector<float>;
using BeamTokens = std::vector<VecTokens>;
using TensorPtr = TTensor;
using LogitsPostProcessor = std::function<void(RequestIdType, TensorPtr&, BeamTokens const&, TStream)>;
using LogitsPostProcessor = std::function<void(RequestIdType, TensorPtr&, BeamTokens const&, TStream const&)>;

GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
@@ -76,6 +76,7 @@ public:
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
bool applyLogitsPostProcessorBatched = false,
std::optional<std::shared_ptr<VecTokens>> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false)
: mRequestId(requestId)
, mPromptLen(inputTokens->size())
@@ -86,6 +87,7 @@ public:
, mEndId(endId)
, mPadId(padId)
, mLogitsPostProcessor(logitsPostProcessor)
, mApplyLogitsPostProcessorBatched(applyLogitsPostProcessorBatched)
, mOrigPromptLen(mPromptLen)
, mMaxSentTokenPos(mPromptLen - 1)
, mEmbeddingBias(std::move(embeddingBias))
@@ -679,7 +681,7 @@ public:

void allocContextLogitsHost(SizeType32 vocabSizePadded, nvinfer1::DataType logitsDataType)
{
mContextLogitsHost = runtime::BufferManager::pinned(
mContextLogitsHost = runtime::BufferManager::pinnedPool(
runtime::ITensor::makeShape({mPromptLen, vocabSizePadded}), logitsDataType);
}

@@ -695,13 +697,13 @@ public:

void allocGenerationLogitsHost(SizeType32 vocabSizePadded, nvinfer1::DataType logitsDataType)
{
mGenerationLogitsHost = runtime::BufferManager::pinned(
mGenerationLogitsHost = runtime::BufferManager::pinnedPool(
runtime::ITensor::makeShape({mSamplingConfig.beamWidth, mMaxNewTokens, vocabSizePadded}), logitsDataType);
}

void allocTargetModelAcceptedTokenLogitsHost(SizeType32 vocabSizePadded, nvinfer1::DataType logitsDataType)
{
mGenerationLogitsHost = runtime::BufferManager::pinned(
mGenerationLogitsHost = runtime::BufferManager::pinnedPool(
runtime::ITensor::makeShape({getNumDraftTokens() + 1, vocabSizePadded}), logitsDataType);
}

@@ -948,6 +950,7 @@ public:
std::optional<TokenIdType> mPadId;
std::optional<SizeType32> mSeqSlot;
std::optional<LogitsPostProcessor> mLogitsPostProcessor;
bool mApplyLogitsPostProcessorBatched;

protected:
BeamTokens mTokens;
@@ -1073,20 +1076,24 @@ public:
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
bool applyLogitsPostProcessorBatched = false,
std::optional<std::shared_ptr<VecTokens>> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false)
: Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList),
std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens), std::move(draftLogits),
excludeInputFromOutput, std::move(logitsPostProcessor), std::move(encoderInputTokens), returnEncoderOutput)
excludeInputFromOutput, std::move(logitsPostProcessor), applyLogitsPostProcessorBatched,
std::move(encoderInputTokens), returnEncoderOutput)
{
}

LlmRequest(RequestIdType requestId, executor::Request const& Request,
std::optional<Base::LogitsPostProcessor> logitsPostProcessor = std::nullopt)
std::optional<Base::LogitsPostProcessor> logitsPostProcessor = std::nullopt,
bool applyLogitsPostProcessorBatched = false)
: Base(requestId, Request)
{
mLogitsPostProcessor = std::move(logitsPostProcessor);
mApplyLogitsPostProcessorBatched = applyLogitsPostProcessorBatched;
}

void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager)
@@ -41,7 +41,7 @@ public:
bool normalizeLogProbs = true, bool enableChunkedContext = false,
PeftCacheManagerConfig const& peftCacheManagerConfig = PeftCacheManagerConfig{},
executor::DecodingConfig decodingConfig = executor::DecodingConfig{}, float gpuWeightsPercent = 1,
std::optional<SizeType32> maxBeamWidth = std::nullopt,
std::optional<SizeType32> maxBeamWidth = std::nullopt, std::optional<SizeType32> maxBatchSize = std::nullopt,
executor::SchedulerConfig const& schedulerConfig = executor::SchedulerConfig{})
: kvCacheConfig{kvCacheConfig}
, enableTrtOverlap{enableTrtOverlap}
@@ -52,6 +52,7 @@ public:
, decodingConfig(std::move(decodingConfig))
, gpuWeightsPercent(gpuWeightsPercent)
, maxBeamWidth(maxBeamWidth)
, maxBatchSize(maxBatchSize)
, schedulerConfig{schedulerConfig}
{
}
@@ -62,7 +63,7 @@ public:
executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(),
PeftCacheManagerConfig(executorConfig.getPeftCacheConfig().value_or(executor::PeftCacheConfig())),
executorConfig.getDecodingConfig().value_or(executor::DecodingConfig{}),
executorConfig.getGpuWeightsPercent(), executorConfig.getMaxBeamWidth(),
executorConfig.getGpuWeightsPercent(), executorConfig.getMaxBeamWidth(), executorConfig.getMaxBatchSize(),
executorConfig.getSchedulerConfig())
{
}
@@ -87,6 +88,7 @@ public:
// Percentage of weights on the gpu at runtime
float gpuWeightsPercent;
std::optional<SizeType32> maxBeamWidth;
std::optional<SizeType32> maxBatchSize;
executor::SchedulerConfig schedulerConfig;
};
@@ -263,6 +263,9 @@ public:
std::optional<std::string> logitsPostProcessorName = std::nullopt,
std::optional<VecTokens> encoderInputTokenIds = std::nullopt);

/// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
static auto constexpr kBatchedPostProcessorName = "batched";

Request(Request const& other);
Request(Request&& other) noexcept;
Request& operator=(Request const& other);
@@ -403,6 +406,14 @@ public:
[[nodiscard]] std::optional<size_t> getHostCacheSize() const;
[[nodiscard]] bool getOnboardBlocks() const;

void setEnableBlockReuse(bool enableBlockReuse);
void setMaxTokens(SizeType32 maxTokens);
void setMaxAttentionWindow(SizeType32 maxAttentionWindow);
void setSinkTokenLength(SizeType32 sinkTokenLength);
void setFreeGpuMemoryFraction(FloatType freeGpuMemoryFraction);
void setHostCacheSize(size_t hostCacheSize);
void setOnboardBlocks(bool onboardBlocks);

private:
friend class Serialization;
@@ -557,49 +568,39 @@ private:
std::optional<size_t> mHostCacheSize;
};

/// @brief Configuration class for Lookahead decoding.
class LookaheadDecodingConfig
struct LookaheadDecodingConfig
{
public:
explicit LookaheadDecodingConfig(
SizeType32 maxNgramSize, SizeType32 maxWindowSize, SizeType32 maxVerificationSetSize);
LookaheadDecodingConfig(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize);

explicit LookaheadDecodingConfig()
: LookaheadDecodingConfig(1, 1, 0)
{
}

bool operator==(LookaheadDecodingConfig const& other) const;
[[nodiscard]] std::tuple<SizeType32 const, SizeType32 const, SizeType32 const> get() const;
[[nodiscard]] SizeType32 getWindowSize() const;
[[nodiscard]] SizeType32 getNgramSize() const;
[[nodiscard]] SizeType32 getVerificationSetSize() const;

void setMaxNgramSize(SizeType32);
void setMaxWindowSize(SizeType32);
void setMaxVerificationSetSize(SizeType32);
[[nodiscard]] SizeType32 getMaxNgramSize() const;
[[nodiscard]] SizeType32 getMaxWindowSize() const;
[[nodiscard]] SizeType32 getMaxVerificationSetSize() const;
/// @brief return <maxDecodingTokens, maxPathLen, maxDraftTokens, maxDraftPathLen>
std::tuple<SizeType32, SizeType32, SizeType32, SizeType32> calculateSpeculativeResource() const;

/// @brief return true when `this` can be executed on resources defined by `that`
bool isLE(LookaheadDecodingConfig const& that) const;

/// @brief return true when the parameter combination is valid.
static bool isLegal(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize) noexcept;

private:
friend class Serialization;

// Number of tokens per NGram.
SizeType32 mMaxNgramSize;
// Number of NGrams in lookahead branch per step.
SizeType32 mMaxWindowSize;
SizeType32 mWindowSize;
// Number of tokens per NGram.
SizeType32 mNgramSize;
// Number of NGrams in verification branch per step.
SizeType32 mMaxVerificationSetSize;
};

/// @brief Configuration class for explicit draft tokens decoding.
class ExplicitDraftTokensConfig
{
public:
explicit ExplicitDraftTokensConfig(float temperature);

bool operator==(ExplicitDraftTokensConfig const& other) const;

void setTemperature(float);
[[nodiscard]] float getTemperature() const;

private:
friend class Serialization;

// Sampling temperature.
float mTemperature;
SizeType32 mVerificationSetSize;
};

/// @brief Configuration class for the speculative decoding.
@@ -608,8 +609,7 @@ class DecodingConfig
public:
explicit DecodingConfig(std::optional<DecodingMode> decodingMode = std::nullopt,
std::optional<LookaheadDecodingConfig> lookaheadDecodingConfig = std::nullopt,
std::optional<MedusaChoices> medusaChoices = std::nullopt,
std::optional<ExplicitDraftTokensConfig> explicitDraftTokensConfig = std::nullopt);
std::optional<MedusaChoices> medusaChoices = std::nullopt);

bool operator==(DecodingConfig const& other) const;

@@ -620,7 +620,7 @@ public:

// Lookahead methods.
/// @brief Sets lookahead decoding mode and config.
void setLookaheadDecoding(LookaheadDecodingConfig const&);
void setLookaheadDecoding(LookaheadDecodingConfig const& lookaheadDecodingConfig);
[[nodiscard]] std::optional<LookaheadDecodingConfig> getLookaheadDecodingConfig() const;

// Medusa methods.
@@ -628,11 +628,6 @@ public:
void setMedusaChoices(MedusaChoices const&);
[[nodiscard]] std::optional<MedusaChoices> getMedusaChoices() const;

// ExplicitDraftTokens decoding methods.
/// @brief Sets explicit draft tokens decoding mode and config.
void setExplicitDraftTokens(ExplicitDraftTokensConfig const&);
[[nodiscard]] std::optional<ExplicitDraftTokensConfig> getExplicitDraftTokensConfig() const;

private:
friend class Serialization;

@@ -642,8 +637,6 @@ private:
std::optional<LookaheadDecodingConfig> mLookaheadDecodingConfig;
// Medusa params.
std::optional<MedusaChoices> mMedusaChoices;
// Explicit draft tokens params.
std::optional<ExplicitDraftTokensConfig> mExplicitDraftTokensConfig;
};

/// @brief Configuration class for the model executor
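After the refactor above, the lookahead parameters are ordered (windowSize, ngramSize, verificationSetSize) and the explicit-draft-tokens option moves out of DecodingConfig. A hedged sketch of configuring lookahead decoding with the new shape; the numeric values are placeholders and are not validated against isLegal().

```cpp
#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

texec::DecodingConfig makeLookaheadDecodingConfig()
{
    // New parameter order: window size, ngram size, verification set size (example values).
    texec::LookaheadDecodingConfig lookahead(/*windowSize=*/4, /*ngramSize=*/3, /*verificationSetSize=*/4);

    texec::DecodingConfig decodingConfig;
    decodingConfig.setLookaheadDecoding(lookahead);
    return decodingConfig;
}
```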
@@ -654,10 +647,11 @@ public:
KvCacheConfig const& kvCacheConfig = KvCacheConfig(), bool enableChunkedContext = false,
bool normalizeLogProbs = true, SizeType32 iterStatsMaxIterations = kDefaultIterStatsMaxIterations,
SizeType32 requestStatsMaxIterations = kDefaultRequestStatsMaxIterations,
BatchingType batchingType = BatchingType::kINFLIGHT,
BatchingType batchingType = BatchingType::kINFLIGHT, std::optional<SizeType32> maxBatchSize = std::nullopt,
std::optional<ParallelConfig> parallelConfig = std::nullopt,
std::optional<PeftCacheConfig> const& peftCacheConfig = std::nullopt,
std::optional<LogitsPostProcessorMap> logitsPostProcessorMap = std::nullopt,
std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched = std::nullopt,
std::optional<DecodingConfig> decodingConfig = std::nullopt, float gpuWeightsPercent = 1);

[[nodiscard]] SizeType32 getMaxBeamWidth() const;
@@ -668,13 +662,16 @@ public:
[[nodiscard]] SizeType32 getIterStatsMaxIterations() const;
[[nodiscard]] SizeType32 getRequestStatsMaxIterations() const;
[[nodiscard]] BatchingType getBatchingType() const;
[[nodiscard]] std::optional<SizeType32> getMaxBatchSize() const;
[[nodiscard]] std::optional<ParallelConfig> getParallelConfig() const;
[[nodiscard]] std::optional<PeftCacheConfig> getPeftCacheConfig() const;
[[nodiscard]] std::optional<LogitsPostProcessorMap> getLogitsPostProcessorMap() const;
[[nodiscard]] std::optional<LogitsPostProcessorBatched> getLogitsPostProcessorBatched() const;
[[nodiscard]] std::optional<DecodingConfig> getDecodingConfig() const;
[[nodiscard]] float getGpuWeightsPercent() const;

void setMaxBeamWidth(SizeType32 maxBeamWidth);
void setMaxBatchSize(SizeType32 maxBatchSize);
void setSchedulerConfig(SchedulerConfig const& schedulerConfig);
void setKvCacheConfig(KvCacheConfig const& kvCacheConfig);
void setEnableChunkedContext(bool enableChunkedContext);
@@ -685,6 +682,7 @@ public:
void setParallelConfig(ParallelConfig const& parallelConfig);
void setPeftCacheConfig(PeftCacheConfig const& peftCacheConfig);
void setLogitsPostProcessorMap(LogitsPostProcessorMap const& logitsPostProcessorMap);
void setLogitsPostProcessorBatched(LogitsPostProcessorBatched const& logitsPostProcessorBatched);
void setDecodingConfig(DecodingConfig const& decodingConfig);
void setGpuWeightsPercent(float const& gpuWeightsPercent);

@@ -715,10 +713,14 @@ private:
/// @brief The type of batching strategy to use. See BatchingType.
BatchingType mBatchingType;

/// @brief The max batch size of requests
std::optional<SizeType32> mMaxBatchSize;

/// @brief The parallel execution configuration.
std::optional<ParallelConfig> mParallelConfig;
std::optional<PeftCacheConfig> mPeftCacheConfig;
std::optional<LogitsPostProcessorMap> mLogitsPostProcessorMap;
std::optional<LogitsPostProcessorBatched> mLogitsPostProcessorBatched;
/// @brief Decoding configuration.
std::optional<DecodingConfig> mDecodingConfig;
float mGpuWeightsPercent;
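ExecutorConfig now carries an optional max batch size and an optional batched logits post-processor next to the per-request map. A hedged sketch of registering a batched callback with the alias from the executor types header; the callback body is a placeholder.

```cpp
#include <functional>
#include <vector>

#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

void registerBatchedPostProcessor(texec::ExecutorConfig& executorConfig)
{
    texec::LogitsPostProcessorBatched batched
        = [](std::vector<texec::IdType> const& requestIds, std::vector<texec::Tensor>& logits,
              std::vector<std::reference_wrapper<texec::BeamTokens const>> const& beamTokens,
              texec::StreamPtr const& stream)
    {
        // Placeholder: a real callback would edit each logits tensor in place on `stream`.
        (void) requestIds; (void) logits; (void) beamTokens; (void) stream;
    };
    executorConfig.setLogitsPostProcessorBatched(batched);
}
```

Per the Request change earlier in this diff, an individual request opts into this callback by using the kBatchedPostProcessorName ("batched") as its logits post-processor name.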
@@ -112,11 +112,6 @@ public:
static void serialize(LookaheadDecodingConfig const& lookaheadDecodingConfig, std::ostream& os);
static size_t serializedSize(LookaheadDecodingConfig const& lookaheadDecodingConfig);

// ExplicitDraftTokensConfig
static ExplicitDraftTokensConfig deserializeExplicitDraftTokensConfig(std::istream& is);
static void serialize(ExplicitDraftTokensConfig const& ExplicitDraftTokensConfig, std::ostream& os);
static size_t serializedSize(ExplicitDraftTokensConfig const& ExplicitDraftTokensConfig);

// DecodingConfig
static DecodingConfig deserializeDecodingConfig(std::istream& is);
static void serialize(DecodingConfig const& decodingConfig, std::ostream& os);
@@ -53,8 +53,10 @@ using IterationType = std::uint64_t;
using RandomSeedType = std::uint64_t;
using VecLogProbs = std::vector<FloatType>;
using StreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
using LogitsPostProcessor = std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr&)>;
using LogitsPostProcessor = std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr const&)>;
using LogitsPostProcessorMap = std::unordered_map<std::string, LogitsPostProcessor>;
using LogitsPostProcessorBatched = std::function<void(std::vector<IdType> const&, std::vector<Tensor>&,
std::vector<std::reference_wrapper<BeamTokens const>> const&, StreamPtr const&)>;
using MedusaChoices = std::vector<std::vector<SizeType32>>;

enum class DataType
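The per-request LogitsPostProcessor alias now takes the CUDA stream by const reference. A hedged sketch of a callback matching the updated signature, registered in a LogitsPostProcessorMap; the map key is arbitrary.

```cpp
#include <string>
#include <unordered_map>

#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

texec::LogitsPostProcessorMap makeLogitsPostProcessors()
{
    texec::LogitsPostProcessor noop = [](texec::IdType requestId, texec::Tensor& logits,
                                          texec::BeamTokens const& beamTokens, texec::StreamPtr const& stream)
    {
        // StreamPtr is now passed as const& (previously a mutable reference).
        (void) requestId; (void) logits; (void) beamTokens; (void) stream;
    };
    return {{"noop", noop}};
}
```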
@@ -224,6 +226,12 @@ struct KvCacheStats
SizeType32 usedNumBlocks;
/// @brief Number of tokens per block
SizeType32 tokensPerBlock;
/// @brief Number of total allocated block
SizeType32 allocTotalBlocks;
/// @brief Number of newly allocated block
SizeType32 allocNewBlocks;
/// @brief Number of reused block
SizeType32 reusedBlocks;
};

/// @brief Struct that holds the stats of static batching models for a single iteration
@@ -267,6 +275,8 @@ struct IterationStats
std::string timestamp;
/// @brief Iteration id
IterationType iter;
/// @brief Iteration latency (ms)
double iterLatencyMS;
/// @brief Number of active requests
SizeType32 numActiveRequests;
/// @brief Number of max active requests
@@ -717,6 +727,8 @@ static_assert(!DecodingMode::Lookahead().isBeamSearch());
static_assert(!DecodingMode::Lookahead().isMedusa());
static_assert(!DecodingMode::Lookahead().isExplicitDraftTokens());
static_assert(DecodingMode::Lookahead().isUseStopCriteria());
static_assert(DecodingMode::Lookahead().isUseStopWords());
static_assert(DecodingMode::Lookahead().isUseExplicitEosStop());
static_assert(DecodingMode::Lookahead().isLookahead());

static_assert(!DecodingMode::ExplicitDraftTokens().isAuto());
@ -29,13 +29,13 @@ class DecodingInput
|
||||
public:
|
||||
using TensorPtr = std::shared_ptr<ITensor const>;
|
||||
|
||||
DecodingInput(SizeType32 maxLength, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength,
|
||||
SizeType32 maxBatchSize, TensorPtr logits, TensorPtr endIds)
|
||||
DecodingInput(SizeType32 maxLength, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, SizeType32 batchSize,
|
||||
TensorPtr logits, TensorPtr endIds)
|
||||
: step{maxLength}
|
||||
, maxLength{maxLength}
|
||||
, maxAttentionWindow{maxAttentionWindow}
|
||||
, sinkTokenLength{sinkTokenLength}
|
||||
, maxBatchSize{maxBatchSize}
|
||||
, batchSize{batchSize}
|
||||
, maxStopWordsLen{0}
|
||||
, maxBadWordsLen{0}
|
||||
, logits{std::move(logits)}
|
||||
@ -50,46 +50,68 @@ public:
|
||||
SizeType32 maxLength;
|
||||
SizeType32 maxAttentionWindow;
|
||||
SizeType32 sinkTokenLength;
|
||||
SizeType32 maxBatchSize;
|
||||
SizeType32 batchSize;
|
||||
SizeType32 maxStopWordsLen; // The maximum value in the `stopWordsLens` tensor
|
||||
SizeType32 maxBadWordsLen; // The maximum value in the `badWordsLens` tensor
|
||||
TensorPtr logits; // [batchSize, beamWidth, vocabSizePadded], on gpu
|
||||
std::optional<std::vector<TensorPtr>>
|
||||
logitsVec; // vector of size [batchSize] contains logits of size [beamWidth, vocabSizePadded], on gpu
|
||||
TensorPtr endIds; // [maxBatchSize * beamWidth], on gpu
|
||||
TensorPtr endIds; // [batchSize * beamWidth], on gpu
|
||||
|
||||
// optional parameters
|
||||
TensorPtr finished; // [maxBatchSize, beamWidth], finished states at current iteration.
|
||||
TensorPtr finished; // [batchSize, beamWidth], finished states at current iteration.
|
||||
// If true for some request, the decoding step of it is skipped, on gpu
|
||||
TensorPtr sequenceLimitLength; // [maxBatchSize], on gpu
|
||||
TensorPtr embeddingBias; // [maxBatchSize, vocabSizePadded], on gpu
|
||||
TensorPtr lengths; // [maxBatchSize, beamWidth], on gpu
|
||||
TensorPtr badWordsList; // [2, badWordsLength] or [maxBatchSize, 2, badWordsLength], on gpu
|
||||
TensorPtr badWordsPtrs; // [maxBatchSize][2, badWordsLength], on gpu
|
||||
TensorPtr badWordsLens; // [maxBatchSize], on gpu
|
||||
TensorPtr stopWordsList; // [maxBatchSize, 2, stopWordsLength], on gpu
|
||||
TensorPtr stopWordsPtrs; // [maxBatchSize][2, stopWordsLength], on gpu
|
||||
TensorPtr stopWordsLens; // [maxBatchSize], on gpu
|
||||
TensorPtr noRepeatNgramSize; // [maxBatchSize], on gpu
|
||||
TensorPtr sequenceLimitLength; // [batchSize], on gpu
|
||||
TensorPtr embeddingBias; // [batchSize, vocabSizePadded], on gpu
|
||||
TensorPtr lengths; // [batchSize, beamWidth], on gpu
|
||||
TensorPtr badWordsList; // [2, badWordsLength] or [batchSize, 2, badWordsLength], on gpu
|
||||
TensorPtr badWordsPtrs; // [batchSize][2, badWordsLength], on gpu
|
||||
TensorPtr badWordsLens; // [batchSize], on gpu
|
||||
TensorPtr stopWordsList; // [batchSize, 2, stopWordsLength], on gpu
|
||||
TensorPtr stopWordsPtrs; // [batchSize][2, stopWordsLength], on gpu
|
||||
TensorPtr stopWordsLens; // [batchSize], on gpu
|
||||
TensorPtr noRepeatNgramSize; // [batchSize], on gpu
|
||||
TensorPtr
|
||||
batchSlots; // [batchSize], optional, address map of the linear batch id to to the seq slots, int32_t, pinned
|
||||
|
||||
// parameters for beam search
|
||||
TensorPtr cacheIndirection; // [maxBatchSize, beamWidth, maxSeqLen] - the k/v cache index for beam search, on gpu
|
||||
TensorPtr cacheIndirection; // [batchSize, beamWidth, maxSeqLen] - the k/v cache index for beam search, on gpu
|
||||
|
||||
// Medusa
|
||||
class MedusaInputs
|
||||
{
|
||||
public:
|
||||
TensorPtr medusaPaths; // [maxBatchSize, maxTokensPerStep, maxMedusaHeads + 1], on gpu
|
||||
TensorPtr medusaTreeIds; // [maxBatchSize, maxTokensPerStep], on gpu
|
||||
TensorPtr medusaPaths; // [batchSize, maxTokensPerStep, maxMedusaHeads + 1], on gpu
|
||||
TensorPtr medusaTreeIds; // [batchSize, maxTokensPerStep], on gpu
|
||||
std::vector<std::vector<TensorPtr>>
|
||||
medusaLogits; // [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded], on gpu
|
||||
TensorPtr medusaCurTokensPerStep; // [maxBatchSize], on gpu
|
||||
TensorPtr medusaTargetTokensPerStep; // [maxBatchSize], on gpu
|
||||
medusaLogits; // [batchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded], on gpu
|
||||
TensorPtr medusaCurTokensPerStep; // [batchSize], on gpu
|
||||
TensorPtr medusaTargetTokensPerStep; // [batchSize], on gpu
|
||||
};
|
||||
|
||||
class ExplicitDraftTokensInputs
|
||||
{
|
||||
public:
|
||||
TensorPtr nextDraftTokens; // [batchSize, maxNumPaths, maxPathLen]
|
||||
TensorPtr nextFlatTokens; // [batchSize * maxDecodingTokens]
|
||||
TensorPtr nextDraftIndices; // [batchSize, maxNumPaths, maxPathLen]
|
||||
TensorPtr nextDraftProbs; // [batchSize, maxNumPaths, maxDraftPathLen, vocabSize]
|
||||
TensorPtr lastDraftTokens; // [batchSize, maxNumPaths, maxPathLen]
|
||||
TensorPtr lastDraftIndices; // [batchSize, maxNumPaths, maxPathLen]
|
||||
TensorPtr masks; // [batchSize, maxDecodingTokens, maxDecodingTokens], bool
|
||||
TensorPtr packedPositionIds; // [batchSize * maxDecodingTokens]
|
||||
TensorPtr bestPathLengths; // [batchSize]
|
||||
TensorPtr bestPathIndices; // [batchSize]
|
||||
TensorPtr nextGenerationLengths; // [batchSize]
|
||||
TensorPtr lastPositionIdsBase; // [batchSize]
|
||||
TensorPtr lastGenerationLengths; // [batchSize]
|
||||
TensorPtr maxGenLengthDevice; // [1]
|
||||
TensorPtr seqSlots; // [batchSize]
|
||||
};
|
||||
|
||||
std::optional<MedusaInputs> medusaInputs;
|
||||
|
||||
std::optional<ExplicitDraftTokensInputs> explicitDraftTokensInputs;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
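The batchSlots field above is what connects the linear batch index used by DecodingInput to the persistent decoder sequence slots. The following stand-alone sketch is illustrative only: it uses plain std::vector instead of ITensor, and the slot layout and per-slot data are invented, not the real decoder state.

#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    // Hypothetical example: 3 active requests occupy non-contiguous decoder slots.
    std::vector<std::int32_t> batchSlots{5, 0, 2};      // [batchSize], linear id -> seq slot
    std::vector<std::int32_t> perSlotData(8, 0);         // [maxBatchSize] worth of slot-indexed data
    perSlotData[5] = 128;
    perSlotData[0] = 64;
    perSlotData[2] = 256;

    // Code reading per-request data dereferences through the slot map.
    for (std::size_t bi = 0; bi < batchSlots.size(); ++bi)
    {
        std::cout << "request " << bi << " -> slot " << batchSlots[bi]
                  << ", value " << perSlotData[batchSlots[bi]] << '\n';
    }
    return 0;
}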
@ -18,6 +18,7 @@

#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <optional>
#include <utility>
@ -94,12 +95,15 @@ public:
public:
    TensorPtr nextDraftTokens;       // [maxBatchSize, maxDraftTokens]
    TensorPtr nextDraftTokensLen;    // [maxBatchSize]
    TensorPtr prevDraftTokensLen;    // [maxBatchSize]
    TensorPtr acceptedTokensLen;     // [maxBatchSize]
    TensorPtr acceptedLengthsCumSum; // [maxBatchSize + 1]
    TensorPtr pathsOffsets;          // [maxBatchSize, maxAcceptedDraftTokensPerStep]
    };

    std::optional<SpeculativeDecodingOutputs> speculativeDecodingOutputs;

    std::optional<ExplicitDraftTokensBuffers::Inputs> explicitDraftTokensBuffers;
};

} // namespace tensorrt_llm::runtime
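acceptedLengthsCumSum above is documented as an exclusive prefix sum with maxBatchSize + 1 entries, which is the usual way to index a packed per-request buffer such as pathsOffsets. A minimal, self-contained sketch of that indexing scheme follows; the lengths are made up, not real decoder output.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    // Hypothetical accepted draft-token lengths for 4 requests.
    std::vector<std::int32_t> acceptedTokensLen{2, 0, 3, 1};

    // Exclusive prefix sum: cumSum[i] is where request i's packed data starts.
    std::vector<std::int32_t> acceptedLengthsCumSum(acceptedTokensLen.size() + 1, 0);
    std::partial_sum(acceptedTokensLen.begin(), acceptedTokensLen.end(), acceptedLengthsCumSum.begin() + 1);

    // A packed buffer like pathsOffsets holds sum(acceptedTokensLen) entries back to back.
    std::vector<std::int32_t> pathsOffsets(acceptedLengthsCumSum.back(), -1);

    for (std::size_t bi = 0; bi < acceptedTokensLen.size(); ++bi)
    {
        auto const begin = acceptedLengthsCumSum[bi];
        auto const end = acceptedLengthsCumSum[bi + 1];
        std::cout << "request " << bi << " owns packed range [" << begin << ", " << end << ")\n";
    }
    return 0;
}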
cpp/include/tensorrt_llm/runtime/explicitDraftTokensBuffers.h (new file, 142 lines)
@ -0,0 +1,142 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/runtime/explicitDraftTokensModule.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
#include "tensorrt_llm/runtime/tllmRuntime.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
|
||||
class ExplicitDraftTokensBuffers
|
||||
{
|
||||
public:
|
||||
using SizeType32 = runtime::SizeType32;
|
||||
using ITensor = runtime::ITensor;
|
||||
using BufferPtr = runtime::IBuffer::SharedPtr;
|
||||
using TensorPtr = runtime::ITensor::SharedPtr;
|
||||
using TensorMap = runtime::StringPtrMap<runtime::ITensor>;
|
||||
|
||||
class Inputs
|
||||
{
|
||||
public:
|
||||
//! [batchSize]
|
||||
TensorPtr temperatures;
|
||||
//! [batchSize]
|
||||
TensorPtr positionIdsBase;
|
||||
//! [batchSize] or [numGenSequences]
|
||||
TensorPtr generationLengths;
|
||||
//! [batchSize]
|
||||
TensorPtr randomDataSample;
|
||||
//! [batchSize, maxNumPaths, maxPathDraftLen] or [numGenSequences, maxNumPaths, maxPathDraftLen]
|
||||
TensorPtr randomDataValidation;
|
||||
//! [batchSize, maxNumPaths, maxPathLen] or [numGenSequences, maxNumPaths, maxPathLen]
|
||||
TensorPtr draftTokens;
|
||||
//! [batchSize, maxNumPaths, maxPathLen] or [numGenSequences, maxNumPaths, maxPathLen]
|
||||
TensorPtr draftIndices;
|
||||
//! [batchSize, maxNumPaths, maxPathDraftLen, vocabSize]
|
||||
//! or [numGenSequences, maxNumPaths, maxPathDraftLen, vocabSize]
|
||||
TensorPtr draftProbs;
|
||||
//! [batchSize, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
|
||||
//! or [numGenSequences, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
|
||||
TensorPtr packedMasks;
|
||||
//! [batchSize] or [numGenSequences]
|
||||
TensorPtr positionIds;
|
||||
// [1], on pinned
|
||||
TensorPtr maxGenLengthHost;
|
||||
|
||||
void create(SizeType32 maxNumSequences, runtime::TllmRuntime const& runtime,
|
||||
runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig);
|
||||
};
|
||||
|
||||
class EngineInputs : public Inputs
|
||||
{
|
||||
public:
|
||||
//! [numSequences], on gpu
|
||||
TensorPtr requestTypesDevice;
|
||||
//! [numGenSequences]
|
||||
TensorPtr positionOffsets;
|
||||
} engineInputs;
|
||||
|
||||
class EngineOutputs
|
||||
{
|
||||
public:
|
||||
//! [batchSize]
|
||||
TensorPtr nextGenerationLengths;
|
||||
//! [batchSize]
|
||||
TensorPtr nextPositionOffsets;
|
||||
//! [batchSize, maxDecodingTokens, maxDecodingTokens], bool
|
||||
TensorPtr masks;
|
||||
|
||||
//! [batchSize, maxNumPaths, maxPathLen]
|
||||
TensorPtr nextDraftTokens;
|
||||
//! [batchSize, maxNumPaths, maxPathLen]
|
||||
TensorPtr nextDraftIndices;
|
||||
//! [batchSize, maxNumPaths, maxDraftPathLen, vocabSize]
|
||||
TensorPtr nextDraftProbs;
|
||||
|
||||
//! [batchSize * maxDecodingTokens]
|
||||
TensorPtr nextFlatTokens;
|
||||
//! [batchSize]
|
||||
TensorPtr bestPathLengths;
|
||||
//! [batchSize]
|
||||
TensorPtr bestPathIndices;
|
||||
//! [1]
|
||||
TensorPtr maxGenToken;
|
||||
//! [1]
|
||||
TensorPtr totalGenToken;
|
||||
//! [batchSize * maxDecodingTokens]
|
||||
TensorPtr packedPositionIds;
|
||||
} engineOutputs;
|
||||
|
||||
public:
|
||||
ExplicitDraftTokensBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, runtime::BufferManager const& manager,
|
||||
runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
|
||||
executor::DecodingConfig const& decodingConfig, runtime::TllmRuntime const& runtime);
|
||||
|
||||
void reshape(SizeType32 numCtxSequences, SizeType32 numGenSequences, runtime::ModelConfig const& modelConfig);
|
||||
|
||||
void setFromInputs(SizeType32 numCtxSequences, SizeType32 numGenSequences, runtime::ITensor const& requestTypes,
|
||||
ITensor const& seqSlots, ExplicitDraftTokensBuffers::Inputs const& decoderBuffers,
|
||||
ITensor const& contextPositionIds, runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
|
||||
runtime::WorldConfig const& worldConfig) const;
|
||||
|
||||
void insertInputTensors(
|
||||
TensorMap& inputBuffers, TensorMap& outputBuffers, runtime::WorldConfig const& worldConfig) const;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void setFromInputs(SizeType32 numCtxSequences, SizeType32 numGenSequences, SizeType32 vocabSizePadded,
|
||||
ITensor const& seqSlots, ExplicitDraftTokensBuffers::Inputs const& draftBuffers,
|
||||
ITensor const& contextPositionIds, runtime::ExplicitDraftTokensModule const& explicitDraftTokensModule,
|
||||
runtime::CudaStream const& stream) const;
|
||||
|
||||
public:
|
||||
// helper tensors
|
||||
std::size_t scanTempStorageBytes{0};
|
||||
BufferPtr scanTempStorage;
|
||||
TensorPtr cumSumGenerationLengths;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
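The packedMasks tensor in Inputs stores each row of the boolean [maxDecodingTokens x maxDecodingTokens] attention mask as ceil(maxDecodingTokens / 32) 32-bit words. Below is a small, self-contained sketch of one plausible packing; the bit order within each word is an assumption here, and the real packing is done by CUDA kernels rather than host code.

#include <cstdint>
#include <vector>

// Packs one boolean mask row into ceil(n / 32) uint32 words. Bit (i % 32) of word
// (i / 32) corresponds to mask position i in this sketch; the kernels' exact bit
// order is not shown in the header above.
std::vector<std::uint32_t> packMaskRow(std::vector<bool> const& row)
{
    std::vector<std::uint32_t> packed((row.size() + 31) / 32, 0u);
    for (std::size_t i = 0; i < row.size(); ++i)
    {
        if (row[i])
        {
            packed[i / 32] |= (1u << (i % 32));
        }
    }
    return packed;
}

int main()
{
    // Hypothetical mask row for maxDecodingTokens = 40: first 5 positions visible.
    std::vector<bool> row(40, false);
    for (int i = 0; i < 5; ++i)
    {
        row[i] = true;
    }
    auto packed = packMaskRow(row); // 2 words: 0x0000001f, 0x00000000
    return packed[0] == 0x1fu ? 0 : 1;
}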
@ -18,18 +18,15 @@
|
||||
|
||||
#include "tensorrt_llm/executor/types.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/decodingInput.h"
|
||||
#include "tensorrt_llm/runtime/decodingOutput.h"
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
#include "tensorrt_llm/runtime/samplingConfig.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
#include <curand_kernel.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
|
||||
@ -43,6 +40,8 @@ class DynamicDecodeLayer;
|
||||
namespace runtime
|
||||
{
|
||||
|
||||
class SpeculativeDecodingModule;
|
||||
|
||||
class IGptDecoder
|
||||
{
|
||||
public:
|
||||
@ -51,7 +50,8 @@ public:
|
||||
virtual ~IGptDecoder() = default;
|
||||
|
||||
virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize,
|
||||
std::optional<TensorPtr> const& batchSlots = std::nullopt)
|
||||
std::optional<TensorPtr> const& batchSlots = std::nullopt,
|
||||
std::optional<DecodingOutput> const& output = std::nullopt)
|
||||
= 0;
|
||||
|
||||
virtual void forwardAsync(DecodingOutput& output, DecodingInput const& input) = 0;
|
||||
@ -93,7 +93,8 @@ public:
|
||||
std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModule = nullptr);
|
||||
|
||||
void setup(SamplingConfig const& samplingConfig, size_t batchSize,
|
||||
std::optional<TensorPtr> const& batchSlots = std::nullopt) override;
|
||||
std::optional<TensorPtr> const& batchSlots = std::nullopt,
|
||||
std::optional<DecodingOutput> const& output = std::nullopt) override;
|
||||
|
||||
void forwardAsync(DecodingOutput& output, DecodingInput const& input) override;
|
||||
|
||||
@ -133,7 +134,9 @@ inline std::unique_ptr<IGptDecoder> IGptDecoder::create(executor::DecodingMode c
|
||||
case nvinfer1::DataType::kHALF:
|
||||
return std::make_unique<GptDecoder<half>>(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
|
||||
maxSequenceLength, stream, speculativeDecodingModule);
|
||||
default: TLLM_THROW("Unsupported decoder data type. Use either kFLOAT or kHALF."); return nullptr;
|
||||
default:
|
||||
TLLM_THROW("Unsupported decoder data type: %d. Use either kFLOAT or kHALF.", static_cast<int>(dtype));
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
} // namespace runtime
|
||||
|
||||
@ -38,6 +38,12 @@ public:
|
||||
using TensorPtr = ITensor::SharedPtr;
|
||||
using SharedConstPtr = ITensor::SharedConstPtr;
|
||||
|
||||
enum class ForwardType
|
||||
{
|
||||
kASYNC,
|
||||
kSYNC
|
||||
};
|
||||
|
||||
GptDecoderBatch(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream,
|
||||
SpeculativeDecodingMode const& speculativeDecodingMode);
|
||||
|
||||
@ -47,6 +53,8 @@ public:
|
||||
SizeType32 maxTokensPerStep, bool fusedDecoder, nvinfer1::DataType dtype,
|
||||
ModelConfig const& modelConfig) override;
|
||||
|
||||
void setupExplicitDraftTokens(ExplicitDraftTokensBuffers::Inputs explicitDraftTokensBuffers) override;
|
||||
|
||||
void newBatch(
|
||||
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig) override;
|
||||
|
||||
@ -164,6 +172,12 @@ public:
|
||||
return mJointDecodingOutput->speculativeDecodingOutputs->nextDraftTokens;
|
||||
}
|
||||
|
||||
//! @returns [batchSize], predicted draft tokens lengths for previous step, on gpu
|
||||
[[nodiscard]] TensorPtr getPrevDraftTokensLengths() const override
|
||||
{
|
||||
return mJointDecodingOutput->speculativeDecodingOutputs->prevDraftTokensLen;
|
||||
}
|
||||
|
||||
//! @returns [batchSize], predicted draft tokens lengths for next step, on gpu
|
||||
[[nodiscard]] TensorPtr getNextDraftTokensLengths() const override
|
||||
{
|
||||
@ -171,13 +185,13 @@ public:
|
||||
}
|
||||
|
||||
//! @returns [batchSize + 1], exclusive sum of accepted draft token lengths, on gpu
|
||||
[[nodiscard]] TensorPtr getSpecDecodingAcceptedLengthsCumSum() const override
|
||||
[[nodiscard]] TensorPtr getAcceptedLengthsCumSum() const override
|
||||
{
|
||||
return mJointDecodingOutput->speculativeDecodingOutputs->acceptedLengthsCumSum;
|
||||
}
|
||||
|
||||
//! @returns [batchSize, maxAcceptedDraftTokensPerStep], accepted paths packed into continuous tensor, on gpu
|
||||
[[nodiscard]] TensorPtr getSpecDecodingAcceptedPackedPaths() const override
|
||||
[[nodiscard]] TensorPtr getAcceptedPackedPaths() const override
|
||||
{
|
||||
return mJointDecodingOutput->speculativeDecodingOutputs->pathsOffsets;
|
||||
}
|
||||
@ -215,17 +229,19 @@ private:
|
||||
//! @brief Updates finished state on host for all active requests
|
||||
void updateFinished(decoder_batch::Token const& token);
|
||||
|
||||
//! @brief Sets inputs for explicit draft tokens.
|
||||
void setExplicitDraftTokensInputs(decoder_batch::Input const& input);
|
||||
|
||||
//! @brief Calls unfused or fused decoders for tokens per engine step
|
||||
void forwardDispatch(
|
||||
decoder_batch::Output& output, decoder_batch::Input const& input, std::optional<CudaEvent> const& eventStart);
|
||||
void forwardDispatch(decoder_batch::Output& output, decoder_batch::Input const& input, ForwardType forwardType);
|
||||
|
||||
//! @brief Calls unfused decoder for whole batch in loop
|
||||
void forwardUnfusedDecoder(SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input,
|
||||
std::optional<CudaEvent> const& eventStart);
|
||||
void forwardUnfusedDecoder(
|
||||
SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input, ForwardType forwardType);
|
||||
|
||||
//! @brief Calls fused decoder for whole batch
|
||||
void forwardFusedDecoder(SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input,
|
||||
std::optional<CudaEvent> const& eventStart);
|
||||
void forwardFusedDecoder(
|
||||
SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input, ForwardType forwardType);
|
||||
|
||||
private:
|
||||
std::size_t const mVocabSize;
|
||||
|
||||
@ -32,6 +32,7 @@
#include "tensorrt_llm/runtime/generationOutput.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/rawEngine.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"

@ -158,28 +159,34 @@ public:
    //! @param sessionConfig Configuration of the session,
    //! @param modelConfig Description of the model,
    //! @param worldConfig Description of the environment,
    //! @param engineBuffer The compiled TensorRT engine (const void*),
    //! @param engineSize The size in bytes of the TensorRT engine (size_t),
    //! @param rawEngine The compiled TensorRT engine,
    //! @param logger The optional logger.
    GptSession(Config const& sessionConfig, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
        void const* engineBuffer, std::size_t engineSize, LoggerPtr logger = nullptr);
        RawEngine const& rawEngine, LoggerPtr logger = nullptr);

    GptSession(Config const& sessionConfig, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
        void const* engineBuffer, std::size_t engineSize, LoggerPtr logger = nullptr)
        : GptSession(sessionConfig, modelConfig, worldConfig, RawEngine(engineBuffer, engineSize), std::move(logger))
    {
    }

    GptSession(Config const& sessionConfig, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
        std::vector<uint8_t> const& engineBuffer, LoggerPtr logger = nullptr)
        : GptSession(
            sessionConfig, modelConfig, worldConfig, engineBuffer.data(), engineBuffer.size(), std::move(logger))
        : GptSession(sessionConfig, modelConfig, worldConfig, RawEngine(engineBuffer.data(), engineBuffer.size()),
            std::move(logger))
    {
    }

    GptSession(Config const& sessionConfig, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
        std::string const& engineFile, LoggerPtr logger = nullptr)
        : GptSession(sessionConfig, modelConfig, worldConfig, utils::loadEngine(engineFile), std::move(logger))
        : GptSession(sessionConfig, modelConfig, worldConfig, RawEngine(engineFile), std::move(logger))
    {
    }

    [[nodiscard]] nvinfer1::ILogger& getLogger() const;

    [[nodiscard]] BufferManager const& getBufferManager() const;
    [[nodiscard]] BufferManager::CudaStreamPtr getRuntimeStreamPtr() const;

    [[nodiscard]] ModelConfig const& getModelConfig() const
    {
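With the RawEngine-based constructors above, the convenience overloads simply forward a file path, a byte buffer, or an address/size pair. The snippet below is a hedged sketch of the file-path variant; the Config initializer and the way modelConfig/worldConfig are obtained (normally from the engine's config.json, e.g. via GptJsonConfig) are assumptions, not shown by this diff.

#include "tensorrt_llm/runtime/gptSession.h"

#include <memory>
#include <string>

using namespace tensorrt_llm::runtime;

// Sketch only: the Config constructor arguments below are illustrative.
std::unique_ptr<GptSession> makeSession(
    ModelConfig const& modelConfig, WorldConfig const& worldConfig, std::string const& enginePath)
{
    GptSession::Config sessionConfig{/*maxBatchSize=*/8, /*maxBeamWidth=*/1, /*maxSequenceLength=*/2048};
    // The std::string overload wraps the path in RawEngine(engineFile) internally.
    return std::make_unique<GptSession>(sessionConfig, modelConfig, worldConfig, enginePath);
}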
@ -18,8 +18,10 @@
|
||||
|
||||
#include "tensorrt_llm/runtime/cudaEvent.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h"
|
||||
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/request.h"
|
||||
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
|
||||
|
||||
#include <memory>
|
||||
@ -31,41 +33,6 @@ namespace tensorrt_llm::runtime
|
||||
|
||||
namespace decoder_batch
|
||||
{
|
||||
class Request
|
||||
{
|
||||
public:
|
||||
using ConstTensorPtr = ITensor::SharedConstPtr;
|
||||
using TensorPtr = ITensor::SharedPtr;
|
||||
using BufferPtr = IBuffer::SharedPtr;
|
||||
|
||||
explicit Request(ConstTensorPtr ids, SizeType32 inputLen, std::optional<SizeType32> maxNewTokens = std::nullopt,
|
||||
std::optional<SizeType32> endId = std::nullopt)
|
||||
: ids{std::move(ids)}
|
||||
, inputLen(inputLen)
|
||||
, maxNewTokens{maxNewTokens}
|
||||
, endId{endId}
|
||||
, generatedTokensPerEngineStep(1)
|
||||
{
|
||||
}
|
||||
|
||||
// mandatory parameters
|
||||
ConstTensorPtr ids; // [inputSeqLen], the input sequence of token ids, on gpu
|
||||
SizeType32 inputLen; // the input length without draft tokens
|
||||
|
||||
// optional parameters
|
||||
std::optional<SizeType32> maxNewTokens; // maximum number of tokens to generate for this request
|
||||
std::optional<SizeType32> endId; // end token id
|
||||
BufferPtr draftTokens; // [generatedTokensPerStep - 1], on gpu, draft tokens from speculative decoding
|
||||
std::optional<TensorPtr>
|
||||
draftLogits; // [generatedTokensPerStep - 1, vocabSize], on gpu, draft tokens from speculative decoding
|
||||
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
|
||||
TensorPtr badWordsList; // [2, badWordsLength], on gpu
|
||||
TensorPtr stopWordsList; // [2, stopWordsLength], on gpu
|
||||
|
||||
SizeType32 generatedTokensPerEngineStep;
|
||||
TensorPtr medusaPaths; // [maxDraftTokens + 1, maxAcceptedDraftTokensPerStep + 1], on gpu
|
||||
TensorPtr medusaTreeIds; // [maxDraftTokens + 1], on gpu
|
||||
};
|
||||
|
||||
class Input
|
||||
{
|
||||
@ -109,6 +76,11 @@ public:
|
||||
// within one beam for beam search, on gpu
|
||||
std::vector<std::vector<TensorConstPtr>>
|
||||
predictedDraftLogits; // [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded]
|
||||
TensorConstPtr seqSlots; // [batchSize]
|
||||
|
||||
// explicit draft tokens data.
|
||||
std::optional<ExplicitDraftTokensBuffers::EngineOutputs> explicitDraftTokensInputs;
|
||||
std::optional<ExplicitDraftTokensBuffers::EngineInputs> explicitDraftTokensLastInputs;
|
||||
};
|
||||
|
||||
using Output = decoder::Output;
|
||||
@ -136,6 +108,9 @@ public:
|
||||
using TensorPtr = std::shared_ptr<ITensor>;
|
||||
using TokenPtr = std::unique_ptr<decoder_batch::Token const>;
|
||||
|
||||
//! @brief Setup buffers for ExplicitDraftTokens decoding.
|
||||
virtual void setupExplicitDraftTokens(ExplicitDraftTokensBuffers::Inputs explicitDraftTokensBuffers) = 0;
|
||||
|
||||
//! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
|
||||
virtual TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0;
|
||||
|
||||
@ -189,14 +164,17 @@ public:
|
||||
//! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu
|
||||
virtual TensorPtr getNextDraftTokens() const = 0;
|
||||
|
||||
//! @returns [batchSize], predicted draft tokens lengths for previous step, on gpu
|
||||
virtual TensorPtr getPrevDraftTokensLengths() const = 0;
|
||||
|
||||
//! @returns [batchSize], predicted draft tokens lengths for next step, on gpu
|
||||
virtual TensorPtr getNextDraftTokensLengths() const = 0;
|
||||
|
||||
//! @returns [batchSize + 1], exclusive sum of accepted draft token lengths, on gpu
|
||||
virtual TensorPtr getSpecDecodingAcceptedLengthsCumSum() const = 0;
|
||||
virtual TensorPtr getAcceptedLengthsCumSum() const = 0;
|
||||
|
||||
//! @returns [batchSize, maxAcceptedDraftTokensPerStep], accepted paths packed into continuous tensor, on gpu
|
||||
virtual TensorPtr getSpecDecodingAcceptedPackedPaths() const = 0;
|
||||
virtual TensorPtr getAcceptedPackedPaths() const = 0;
|
||||
|
||||
protected:
|
||||
IGptDecoderBatch() = default;
|
||||
|
||||
@ -16,7 +16,12 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
#include "tensorrt_llm/runtime/request.h"
|
||||
#include "tensorrt_llm/runtime/speculativeDecodingModule.h"
|
||||
#include <memory>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
@ -24,8 +29,9 @@ namespace tensorrt_llm::runtime
|
||||
class LookaheadModule : public SpeculativeDecodingModule
|
||||
{
|
||||
public:
|
||||
explicit LookaheadModule(SizeType32 maxAcceptedTokens, SizeType32 maxDraftTokens) noexcept
|
||||
: SpeculativeDecodingModule(maxAcceptedTokens, maxDraftTokens, maxDraftTokens)
|
||||
explicit LookaheadModule(SizeType32 maxDraftPathLen, SizeType32 maxDecodingDraftTokens) noexcept
|
||||
: SpeculativeDecodingModule(maxDraftPathLen, maxDecodingDraftTokens, maxDecodingDraftTokens)
|
||||
, mExecutionConfig()
|
||||
{
|
||||
}
|
||||
|
||||
@ -33,5 +39,19 @@ public:
|
||||
: LookaheadModule(0, 0)
|
||||
{
|
||||
}
|
||||
|
||||
void setExecutionConfig(executor::LookaheadDecodingConfig const& config)
|
||||
{
|
||||
mExecutionConfig = config;
|
||||
}
|
||||
|
||||
executor::LookaheadDecodingConfig const getExecutionConfig() const
|
||||
{
|
||||
return mExecutionConfig;
|
||||
}
|
||||
|
||||
private:
|
||||
executor::LookaheadDecodingConfig mExecutionConfig;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -21,7 +21,9 @@
|
||||
#include "tensorrt_llm/runtime/loraModule.h"
|
||||
#include "tensorrt_llm/runtime/speculativeDecodingMode.h"
|
||||
#include "tensorrt_llm/runtime/speculativeDecodingModule.h"
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
#include <array>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
@ -29,6 +31,12 @@ namespace tensorrt_llm::runtime
|
||||
class ModelConfig
|
||||
{
|
||||
public:
|
||||
// See `split_point` defined in `tensorrt_llm/models/generation_mixin.py`.
|
||||
// The split points are tuned to get better perf, if we need to let
|
||||
// users tune that, we can support that by writing and reading the
|
||||
// points in `config.json`.
|
||||
static constexpr std::array kOPT_PROFILES_SPLIT_POINTS{64, 128, 256, 512, 1024};
|
||||
|
||||
enum class ModelVariant : std::int32_t
|
||||
{
|
||||
kGpt = 0,
|
||||
@ -88,9 +96,16 @@ public:
|
||||
, mUsePositionEmbedding(false)
|
||||
, mUseTokenTypeEmbedding(false)
|
||||
, mSpeculativeDecodingMode(SpeculativeDecodingMode::None())
|
||||
, mLogitsDtype(nvinfer1::DataType::kFLOAT)
|
||||
, mUseShapeInference(true)
|
||||
{
|
||||
}
|
||||
|
||||
[[nodiscard]] static std::vector<SizeType32> getOptProfilesSplitPoints() noexcept
|
||||
{
|
||||
return {kOPT_PROFILES_SPLIT_POINTS.begin(), kOPT_PROFILES_SPLIT_POINTS.end()};
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 constexpr getVocabSize() const noexcept
|
||||
{
|
||||
return mVocabSize;
|
||||
@ -555,6 +570,26 @@ public:
|
||||
return mSpeculativeDecodingMode;
|
||||
}
|
||||
|
||||
void setLogitsDtype(nvinfer1::DataType inputDtype) noexcept
|
||||
{
|
||||
mLogitsDtype = inputDtype;
|
||||
}
|
||||
|
||||
[[nodiscard]] nvinfer1::DataType constexpr getLogitsDtype() const noexcept
|
||||
{
|
||||
return mLogitsDtype;
|
||||
}
|
||||
|
||||
void setUseShapeInference(bool useShapeInference) noexcept
|
||||
{
|
||||
mUseShapeInference = useShapeInference;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool useShapeInference() const noexcept
|
||||
{
|
||||
return mUseShapeInference;
|
||||
}
|
||||
|
||||
private:
|
||||
SizeType32 mVocabSize;
|
||||
SizeType32 mNbAttentionLayers;
|
||||
@ -608,6 +643,10 @@ private:
|
||||
// Speculative decoding members
|
||||
std::shared_ptr<SpeculativeDecodingModule> mSpeculativeDecodingModule;
|
||||
SpeculativeDecodingMode mSpeculativeDecodingMode;
|
||||
|
||||
// Logits datatype
|
||||
nvinfer1::DataType mLogitsDtype;
|
||||
bool mUseShapeInference;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
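kOPT_PROFILES_SPLIT_POINTS above ({64, 128, 256, 512, 1024}) defines the boundaries at which the build splits optimization profiles. The self-contained sketch below shows how such split points partition a token-count range into buckets; the lower_bound selection rule is an illustration, not the exact logic in generation_mixin.py.

#include <algorithm>
#include <iostream>
#include <vector>

int main()
{
    std::vector<int> const splitPoints{64, 128, 256, 512, 1024};

    // Illustrative bucket lookup: the index of the first split point >= numTokens
    // identifies which optimization-profile range the value falls into.
    auto profileIndex = [&](int numTokens)
    {
        auto it = std::lower_bound(splitPoints.begin(), splitPoints.end(), numTokens);
        return static_cast<int>(it - splitPoints.begin());
    };

    for (int numTokens : {1, 64, 100, 300, 4096})
    {
        std::cout << numTokens << " tokens -> profile bucket " << profileIndex(numTokens) << '\n';
    }
    return 0;
}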
cpp/include/tensorrt_llm/runtime/rawEngine.h (new file, 98 lines)
@ -0,0 +1,98 @@
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/assert.h"

#include <NvInferRuntime.h>
#include <filesystem>

namespace tensorrt_llm::runtime
{

class RawEngine
{
public:
    enum Type
    {
        FilePath,
        AddressWithSize,
        HostMemory
    };

    explicit RawEngine(std::filesystem::path enginePath) noexcept
        : mType(FilePath)
        , mEnginePath(std::move(enginePath))
    {
    }

    explicit RawEngine(void const* engineAddr, std::size_t engineSize) noexcept
        : mType(AddressWithSize)
        , mEngineAddr(engineAddr)
        , mEngineSize(engineSize)
    {
    }

    explicit RawEngine(nvinfer1::IHostMemory const* engineBuffer) noexcept
        : mType(HostMemory)
        , mEngineBuffer(engineBuffer)
    {
    }

    [[nodiscard]] Type getType() const
    {
        return mType;
    }

    [[nodiscard]] std::filesystem::path getPath() const
    {
        TLLM_CHECK(mType == FilePath);
        return mEnginePath;
    }

    [[nodiscard]] void const* getAddress() const
    {
        TLLM_CHECK(mType == AddressWithSize);
        return mEngineAddr;
    }

    [[nodiscard]] std::size_t getSize() const
    {
        TLLM_CHECK(mType == AddressWithSize);
        return mEngineSize;
    }

    [[nodiscard]] nvinfer1::IHostMemory const* getHostMemory() const
    {
        TLLM_CHECK(mType == HostMemory);
        return mEngineBuffer;
    }

private:
    Type mType;
    std::filesystem::path mEnginePath;

    struct
    {
        void const* mEngineAddr{};
        std::size_t mEngineSize{};
    };

    nvinfer1::IHostMemory const* mEngineBuffer{};
};

} // namespace tensorrt_llm::runtime
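RawEngine is a small tagged holder, so the runtime can accept an engine by file path, by raw address/size, or as nvinfer1::IHostMemory without separate overloads everywhere. A short sketch of constructing and inspecting it follows; the path and buffer below are made up, and the code relies only on the accessors declared above.

#include "tensorrt_llm/runtime/rawEngine.h"

#include <cstdint>
#include <filesystem>
#include <vector>

using tensorrt_llm::runtime::RawEngine;

void rawEngineExamples()
{
    // 1) Refer to an engine file on disk (hypothetical path).
    RawEngine fromPath{std::filesystem::path{"/tmp/llama_fp16.engine"}};
    if (fromPath.getType() == RawEngine::FilePath)
    {
        auto const path = fromPath.getPath();
        (void) path;
    }

    // 2) Refer to an already-loaded serialized engine in host memory.
    std::vector<std::uint8_t> blob(1024); // stand-in for a real serialized engine
    RawEngine fromBuffer{blob.data(), blob.size()};
    std::size_t const size = fromBuffer.getSize(); // asserts AddressWithSize internally
    (void) size;
}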
cpp/include/tensorrt_llm/runtime/request.h (new file, 68 lines)
@ -0,0 +1,68 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/runtime/cudaEvent.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
|
||||
#include <optional>
|
||||
|
||||
namespace tensorrt_llm::runtime::decoder_batch
|
||||
{
|
||||
|
||||
class Request
|
||||
{
|
||||
public:
|
||||
using ConstTensorPtr = ITensor::SharedConstPtr;
|
||||
using TensorPtr = ITensor::SharedPtr;
|
||||
using BufferPtr = IBuffer::SharedPtr;
|
||||
|
||||
explicit Request(ConstTensorPtr ids, SizeType32 inputLen, std::optional<SizeType32> maxNewTokens = std::nullopt,
|
||||
std::optional<SizeType32> endId = std::nullopt)
|
||||
: ids{std::move(ids)}
|
||||
, inputLen(inputLen)
|
||||
, maxNewTokens{maxNewTokens}
|
||||
, endId{endId}
|
||||
, generatedTokensPerEngineStep(1)
|
||||
{
|
||||
}
|
||||
|
||||
// mandatory parameters
|
||||
ConstTensorPtr ids; // [inputSeqLen], the input sequence of token ids, on gpu
|
||||
SizeType32 inputLen; // the input length without draft tokens
|
||||
|
||||
// optional parameters
|
||||
std::optional<SizeType32> maxNewTokens; // maximum number of tokens to generate for this request
|
||||
std::optional<SizeType32> endId; // end token id
|
||||
BufferPtr draftTokens; // [generatedTokensPerStep - 1], on gpu, draft tokens from speculative decoding
|
||||
std::optional<TensorPtr>
|
||||
draftLogits; // [generatedTokensPerStep - 1, vocabSize], on gpu, draft tokens from speculative decoding
|
||||
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
|
||||
TensorPtr badWordsList; // [2, badWordsLength], on gpu
|
||||
TensorPtr stopWordsList; // [2, stopWordsLength], on gpu
|
||||
|
||||
SizeType32 generatedTokensPerEngineStep;
|
||||
TensorPtr medusaPaths; // [maxDraftTokens + 1, maxAcceptedDraftTokensPerStep + 1], on gpu
|
||||
TensorPtr medusaTreeIds; // [maxDraftTokens + 1], on gpu
|
||||
std::optional<executor::LookaheadDecodingConfig> lookaheadRuntimeConfig;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime::decoder_batch
|
||||
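decoder_batch::Request, now moved into its own header, bundles the per-request decoding inputs. The sketch below shows the intended construction pattern; the token-id tensor is assumed to be created elsewhere (typically through BufferManager), and the maxNewTokens/endId values are illustrative.

#include "tensorrt_llm/runtime/request.h"

using namespace tensorrt_llm::runtime;

// Sketch: `promptIds` is a [inputSeqLen] GPU tensor produced elsewhere.
decoder_batch::Request makeRequest(ITensor::SharedConstPtr promptIds, SizeType32 inputLen)
{
    decoder_batch::Request request{std::move(promptIds), inputLen, /*maxNewTokens=*/128, /*endId=*/2};
    request.generatedTokensPerEngineStep = 1; // plain autoregressive decoding, no draft tokens
    return request;
}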
@ -75,6 +75,11 @@ public:
        return anyBitSet(kExplicitDraftTokens);
    }

    [[nodiscard]] bool constexpr updatesPositionIds() const
    {
        return anyBitSet(kLookaheadDecoding | kExplicitDraftTokens);
    }

    [[nodiscard]] bool constexpr requiresAttentionMask() const
    {
        return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens);
@ -101,6 +106,12 @@ public:
        return anyBitSet(kMedusa);
    }

    [[nodiscard]] bool constexpr needsDecoderPrologue() const
    {
        // Potentially lookahead should require it too.
        return anyBitSet(kExplicitDraftTokens);
    }

    using UnderlyingType = std::uint8_t;

    bool operator==(SpeculativeDecodingMode const& other) const
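Predicates such as updatesPositionIds() and needsDecoderPrologue() above are simple tests over a std::uint8_t bit set. The snippet below is a generic, self-contained illustration of that anyBitSet pattern; the flag values are invented and are not the real constants defined elsewhere in speculativeDecodingMode.h.

#include <cstdint>

namespace demo
{
using UnderlyingType = std::uint8_t;
constexpr UnderlyingType kNone = 0;
constexpr UnderlyingType kMedusa = 1u << 1;                // made-up value
constexpr UnderlyingType kLookaheadDecoding = 1u << 2;     // made-up value
constexpr UnderlyingType kExplicitDraftTokens = 1u << 3;   // made-up value

struct Mode
{
    UnderlyingType state{kNone};

    [[nodiscard]] constexpr bool anyBitSet(UnderlyingType bits) const
    {
        return (state & bits) != 0;
    }

    [[nodiscard]] constexpr bool updatesPositionIds() const
    {
        return anyBitSet(kLookaheadDecoding | kExplicitDraftTokens);
    }
};
} // namespace demo

static_assert(demo::Mode{demo::kExplicitDraftTokens}.updatesPositionIds());
static_assert(!demo::Mode{demo::kMedusa}.updatesPositionIds());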
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3769cb4ad108cb9898a03b25e91781bcb5576b85397fbd7f673843abba27272e
|
||||
size 3977112
|
||||
oid sha256:1fec0fdc00c076761ec48eb5e2ea93473a329e844a8091e26c6e3e02fd14a8b1
|
||||
size 3931604
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3769cb4ad108cb9898a03b25e91781bcb5576b85397fbd7f673843abba27272e
|
||||
size 3977112
|
||||
oid sha256:1fec0fdc00c076761ec48eb5e2ea93473a329e844a8091e26c6e3e02fd14a8b1
|
||||
size 3931604
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
359da6357b9948425f249c226166408a libtensorrt_llm_batch_manager_static.a
|
||||
359da6357b9948425f249c226166408a libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
8d4b145290d5984494a1fa6e380d01456534dc62 commit
|
||||
93adf3003d7c422586a9bf892367371d libtensorrt_llm_batch_manager_static.a
|
||||
93adf3003d7c422586a9bf892367371d libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
c0bd2b69c932257678a2aad9bd8baba4b291795e commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3841fcf17899aa8cb75a01a5d0ee8c99e4e078399e4bb8a1201f9d53445d09cf
|
||||
size 3869232
|
||||
oid sha256:bd757c26886a3ffd6947615d9f2829434e94839b693007a64b47c6b5c26416e4
|
||||
size 3812158
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:99d1e58c95ea4267129b7a3ac95b65dc72b5e006b3168d07b213a1f9712930de
|
||||
size 3835982
|
||||
oid sha256:87321383075adf2d87cfbdc8a12a3d3815ef058d5da9b6aaa8d7d3f3263af439
|
||||
size 3773896
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e61e4199962b639502aba50adca548e79d6332e658c10ab717b2ec019d28ed45
|
||||
size 22213850
|
||||
oid sha256:58cdc0a330f8bfb7b50e3202aeac47bde0835b1dc600b4bfdcd2b30801e66e03
|
||||
size 22381766
|
||||
|
||||
@ -162,7 +162,7 @@ struct CutlassGemmConfig
    {
    }

    std::string toString()
    std::string toString() const
    {
        std::stringstream tactic;
        tactic << "Cutlass GEMM Tactic";
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2a00e1d3526af9fe7877c5e3362b32244309ccfac8fd720d1020c966d13b71c9
|
||||
size 1372862
|
||||
oid sha256:18a967eaa1e9a7164e0b104a84b13ea95404f7c7c278375feb2513d5f063bafe
|
||||
size 1396404
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2a00e1d3526af9fe7877c5e3362b32244309ccfac8fd720d1020c966d13b71c9
|
||||
size 1372862
|
||||
oid sha256:18a967eaa1e9a7164e0b104a84b13ea95404f7c7c278375feb2513d5f063bafe
|
||||
size 1396404
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
a35a65a41062edf23a898ad42cdce31c libtensorrt_llm_executor_static.a
|
||||
a35a65a41062edf23a898ad42cdce31c libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
8d4b145290d5984494a1fa6e380d01456534dc62 commit
|
||||
7d12b9c04cb6738bb5f7747a88b00c1c libtensorrt_llm_executor_static.a
|
||||
7d12b9c04cb6738bb5f7747a88b00c1c libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
c0bd2b69c932257678a2aad9bd8baba4b291795e commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e37e2b2f28ac1ae37c22fac7c93394c6fba6e94e27403c0904e47eeb6cd4bf5c
|
||||
size 1412454
|
||||
oid sha256:e503b4cfb1c842850287a359ffed23a1773a67a96475d365b66d757a283ac218
|
||||
size 1448772
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0b801135ba31f7ea63de5deb1880a45b68b2bc9fa45403e7204f6b7a153bd3ee
|
||||
size 1346882
|
||||
oid sha256:f8c80cf7aca2b135a656a060456fb30a820e459b4b36560162b02fa65121ef50
|
||||
size 1375430
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9334de5c0a470731f8dd63f68e60ef320268d838547e5e6cbf537bf5c231eb6f
|
||||
size 12962386
|
||||
oid sha256:cc65971d6d74260cb49b354aa4b0b82f92863cc722fbf206bf8a4919a4897532
|
||||
size 14031364
|
||||
|
||||
@ -97,6 +97,7 @@ struct PackedOn16Bytes<__nv_bfloat16>
|
||||
{
|
||||
using Type = PackedBFloat16;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
// add two 128b data
|
||||
@ -600,7 +601,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params)
|
||||
|
||||
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true, bool PUSH_MODE = false, bool Bias = false,
|
||||
bool Residual = false>
|
||||
static __global__ void twoShotAllReduceKernel(AllReduceParams params)
|
||||
static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params)
|
||||
{
|
||||
// Suppose that two GPUs participate in the AR exchange, and we start two blocks.
|
||||
// The message is partitioned into chunks as detailed below:
|
||||
@ -674,7 +675,7 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
|
||||
{
|
||||
size_t offset_rank = ii * params.elts_per_rank + local_offset;
|
||||
size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
|
||||
if (offset_rank >= params.elts_total)
|
||||
{
|
||||
continue;
|
||||
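The fix above replaces the loop index with the actual rank id (ranks[ii]) when computing offset_rank: each GPU's chunk starts at its own rank's offset, not at the loop position. A tiny host-side illustration of that chunk arithmetic follows; the sizes and rank order are made up and no CUDA is involved.

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // Hypothetical numbers: 4 ranks, 1024 elements total.
    std::size_t const eltsTotal = 1024;
    std::size_t const ranksPerNode = 4;
    std::size_t const eltsPerRank = eltsTotal / ranksPerNode;
    std::vector<std::size_t> const ranks{2, 3, 0, 1}; // rank order as seen by this GPU
    std::size_t const localOffset = 17;               // some offset within a chunk

    for (std::size_t ii = 0; ii < ranksPerNode; ++ii)
    {
        // Fixed indexing: use the rank id stored in ranks[ii], not the loop index ii.
        std::size_t const offsetRank = ranks[ii] * eltsPerRank + localOffset;
        if (offsetRank >= eltsTotal)
        {
            continue;
        }
        std::cout << "read chunk of rank " << ranks[ii] << " at element " << offsetRank << '\n';
    }
    return 0;
}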
@ -829,7 +830,6 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
|
||||
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
|
||||
blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
|
||||
*/
|
||||
|
||||
while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE)
|
||||
{
|
||||
blocks_per_grid += 1;
|
||||
@ -863,7 +863,8 @@ template <typename T, int RANKS_PER_NODE, bool PUSH_MODE = false, bool USE_MEMCP
|
||||
void AllReduceNormKernelLaunch(AllReduceStrategyType algo, AllReduceStrategyConfig config, AllReduceFusionOp fusionOp,
|
||||
AllReduceParams& params, cudaStream_t stream)
|
||||
{
|
||||
TLLM_CHECK(fusionOp == AllReduceFusionOp::RESIDUAL_RMS_NORM);
|
||||
TLLM_CHECK_WITH_INFO(fusionOp == AllReduceFusionOp::RESIDUAL_RMS_NORM, "Unsupported AllReduceFusionOp: %d",
|
||||
static_cast<int>(fusionOp));
|
||||
if (algo == AllReduceStrategyType::ONESHOT)
|
||||
{
|
||||
reduce_fusion::one_shot_all_reduce_norm_kernel_launcher<T, RANKS_PER_NODE, Bias, Affine>(params, stream);
|
||||
@ -1019,7 +1020,6 @@ AllReduceParams AllReduceParams::deserialize(int32_t const* buffer, size_t tpSiz
|
||||
}
|
||||
params.barrier_flag = flag_value;
|
||||
params.ranks_per_node = tpSize;
|
||||
params.rank = tpRank;
|
||||
params.local_rank = tpRank;
|
||||
|
||||
return params;
|
||||
|
||||
@ -30,7 +30,7 @@ namespace tensorrt_llm::kernels
|
||||
constexpr size_t WARP_SIZE = 32;
|
||||
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
|
||||
constexpr size_t MAX_RANKS_PER_NODE = 8;
|
||||
constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
|
||||
constexpr size_t DEFAULT_BLOCK_SIZE = 512;
|
||||
|
||||
// Warning: python definition is in tensorrt_llm/functional.py
|
||||
// they must be kept in sync
|
||||
@ -82,7 +82,7 @@ struct AllReduceParams
|
||||
size_t elts_per_rank;
|
||||
size_t elts_per_block;
|
||||
size_t rank_offset;
|
||||
size_t ranks_per_node, rank, local_rank;
|
||||
size_t ranks_per_node, local_rank;
|
||||
uint32_t barrier_flag;
|
||||
uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
|
||||
uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
|
||||
|
||||
@ -173,6 +173,10 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
|
||||
std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
|
||||
int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config)
|
||||
{
|
||||
#ifdef FAST_BUILD
|
||||
// Fast build disables all configs except this one for SM90
|
||||
return {CutlassTileConfigSM90::CtaShape128x128x128B};
|
||||
#else
|
||||
if (config & CutlassGemmConfig::GROUPED_GEMM)
|
||||
{
|
||||
return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B,
|
||||
@ -187,26 +191,35 @@ std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
|
||||
CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B};
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely to improve
|
||||
// compilation speed.
|
||||
bool supports_mcast_along_m(const CutlassTileConfigSM90 tile)
|
||||
{
|
||||
#ifdef FAST_BUILD
|
||||
return false;
|
||||
#else
|
||||
std::set<CutlassTileConfigSM90> valid_tiles{CutlassTileConfigSM90::CtaShape128x16x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B};
|
||||
return valid_tiles.count(tile) == 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= 128. This is purely to improve
|
||||
// compilation speed.
|
||||
bool supports_mcast_along_n(const CutlassTileConfigSM90 tile)
|
||||
{
|
||||
#ifdef FAST_BUILD
|
||||
return false;
|
||||
#else
|
||||
std::set<CutlassTileConfigSM90> valid_tiles{CutlassTileConfigSM90::CtaShape64x128x128B,
|
||||
CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x256x128B};
|
||||
return valid_tiles.count(tile) == 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
std::vector<CutlassGemmConfig> get_candidate_configs(
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass_extensions/gemm_configs.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
|
||||
@ -26,13 +27,31 @@ namespace kernels
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
|
||||
template <class TileShape, class ClusterShape, class ActivationType>
|
||||
struct should_filter_sm90_gemm_problem_shape
|
||||
{
|
||||
#ifdef FAST_BUILD
|
||||
constexpr static int TILE_K = 128 * 8 / cutlass::sizeof_bits<ActivationType>::value;
|
||||
using SupportedCtaShape = cute::Shape<cute::_128, cute::_128, cute::Int<TILE_K>>;
|
||||
using SupportedCgaShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
|
||||
constexpr static bool value
|
||||
= !cute::is_same_v<SupportedCtaShape, TileShape> || !cute::is_same_v<SupportedCgaShape, ClusterShape>;
|
||||
#else
|
||||
constexpr static bool value = false;
|
||||
#endif
|
||||
};
|
||||
template <class TileShape, class ClusterShape, class ActivationType>
|
||||
constexpr static bool should_filter_sm90_gemm_problem_shape_v
|
||||
= should_filter_sm90_gemm_problem_shape<TileShape, ClusterShape, ActivationType>::value;
|
||||
|
||||
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(
|
||||
int sm, int const max_split_k, tensorrt_llm::cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam const);
|
||||
|
||||
tensorrt_llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies(
|
||||
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> const& candidate_configs,
|
||||
std::vector<int> const& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts,
|
||||
int const split_k_limit, const size_t workspace_bytes, int const multi_processor_count, int const is_weight_only);
|
||||
std::vector<int> const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const num_experts,
|
||||
int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only);
|
||||
|
||||
} // namespace cutlass_kernels
|
||||
} // namespace kernels
|
||||
|
||||
@ -41,6 +41,7 @@
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h"
|
||||
|
||||
@ -69,15 +70,8 @@ void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType
|
||||
#ifdef COMPILE_HOPPER_TMA_GEMMS
|
||||
using CutlassActivationType = typename TllmToCutlassTypeAdapter<ActivationType>::type;
|
||||
|
||||
// For FAST_BUILD, only instantiate kernels with 128x128x128B with 1x1x1 cluster shape.
|
||||
#ifdef FAST_BUILD
|
||||
constexpr int TILE_K = 128 * 8 / cutlass::sizeof_bits<CutlassActivationType>::value;
|
||||
using SupportedCtaShape = Shape<_128, _128, cute::Int<TILE_K>>;
|
||||
using SupportedCgaShape = Shape<_1, _1, _1>;
|
||||
|
||||
if constexpr (cute::is_same_v<SupportedCtaShape, CTAShape> && cute::is_same_v<SupportedCgaShape, ClusterShape>)
|
||||
if constexpr (!should_filter_sm90_gemm_problem_shape_v<CTAShape, ClusterShape, ActivationType>)
|
||||
{
|
||||
#endif // FAST_BUILD
|
||||
using CutlassWeightType__ = typename TllmToCutlassTypeAdapter<WeightType>::type;
|
||||
// We need to remap this since SM90 uses a different layout for the weight matrix.
|
||||
using CutlassWeightType_ = std::conditional_t<std::is_same_v<CutlassWeightType__, cutlass::uint4b_t>,
|
||||
@ -278,13 +272,17 @@ void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType
|
||||
= "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status));
|
||||
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg);
|
||||
}
|
||||
#ifdef FAST_BUILD
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] Config not compiled with FAST_BUILD.");
|
||||
std::stringstream ss;
|
||||
ss << "[TensorRT-LLm Error][fpA_intB Runner] Config (" << (int64_t) cute::size<0>(CTAShape{}) << ","
|
||||
<< (int64_t) cute::size<1>(CTAShape{}) << "," << (int64_t) cute::size<2>(CTAShape{}) << ") ("
|
||||
<< (int64_t) cute::size<0>(ClusterShape{}) << "," << (int64_t) cute::size<1>(ClusterShape{}) << ","
|
||||
<< (int64_t) cute::size<2>(ClusterShape{}) << ") not compiled with FAST_BUILD.";
|
||||
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
#endif // FAST_BUILD
|
||||
|
||||
#else // COMPILE_HOPPER_TMA_GEMMS
|
||||
throw std::runtime_error(
|
||||
|
||||
@ -197,14 +197,7 @@ void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, i
|
||||
{
|
||||
#ifdef COMPILE_HOPPER_TMA_GEMMS
|
||||
using namespace cute;
|
||||
// For FAST_BUILD, only instantiate kernels with 128x128x128B with 1x1x1 cluster shape.
|
||||
#ifdef FAST_BUILD
|
||||
constexpr int TILE_K = 128 * 8 / cutlass::sizeof_bits<WeightType>::value;
|
||||
using SupportedCtaShape = Shape<_128, _128, cute::Int<TILE_K>>;
|
||||
using SupportedCgaShape = Shape<_1, _1, _1>;
|
||||
|
||||
if constexpr (cute::is_same_v<SupportedCtaShape, TileShape> && cute::is_same_v<SupportedCgaShape, ClusterShape>)
|
||||
#endif // FAST_BUILD
|
||||
if constexpr (!should_filter_sm90_gemm_problem_shape_v<TileShape, ClusterShape, T>)
|
||||
{
|
||||
using GemmInfo = HopperGroupedGemmInfo<T, WeightType, EpilogueTag, TileShape, ClusterShape, BIAS>;
|
||||
|
||||
@ -287,12 +280,10 @@ void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, i
|
||||
"Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
#ifdef FAST_BUILD
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Configuration was disabled by FAST_BUILD");
|
||||
}
|
||||
#endif
|
||||
|
||||
#else // COMPILE_HOPPER_TMA_GEMMS
|
||||
TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py.");
|
||||
|
||||
@ -30,7 +30,6 @@ namespace tensorrt_llm
|
||||
|
||||
struct HopperGroupedGemmInput
|
||||
{
|
||||
|
||||
template <class Tag>
|
||||
using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
|
||||
cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;
|
||||
@ -180,7 +179,7 @@ public:
|
||||
bool supportsHopperSpecialisation() const;
|
||||
[[nodiscard]] bool isFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const;
|
||||
|
||||
size_t calcMaxWorkspaceSize(int num_experts) const;
|
||||
size_t getMaxWorkspaceSize(int num_experts) const;
|
||||
|
||||
[[nodiscard]] int getSM() const;
|
||||
|
||||
@ -197,9 +196,12 @@ private:
|
||||
int64_t gemm_k, int num_experts, bool use_fused_moe, cudaStream_t stream);
|
||||
|
||||
private:
|
||||
int sm_;
|
||||
int multi_processor_count_;
|
||||
int sm_{};
|
||||
int multi_processor_count_{};
|
||||
mutable int num_experts_ = 0;
|
||||
mutable size_t gemm_workspace_size_ = 0;
|
||||
std::optional<cutlass_extensions::CutlassGemmConfig> best_config_{};
|
||||
size_t calcMaxWorkspaceSize(int num_experts) const;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -48,6 +48,8 @@
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
|
||||
@ -533,6 +535,18 @@ void MoeGemmRunner<T, WeightType>::dispatchToArch<EpilogueTag>(T const* A, Weigh
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
size_t MoeGemmRunner<T, WeightType>::getMaxWorkspaceSize(int num_experts) const
|
||||
{
|
||||
if (num_experts != num_experts_)
|
||||
{
|
||||
TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, num_experts_);
|
||||
num_experts_ = num_experts;
|
||||
gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts);
|
||||
}
|
||||
return gemm_workspace_size_;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
size_t MoeGemmRunner<T, WeightType>::calcMaxWorkspaceSize(int num_experts) const
|
||||
{
|
||||
|
||||
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
* Common utils to be shared between Precompiled and JIT implementation.
|
||||
*/
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParams)
|
||||
{
|
||||
unsigned int head_size = xqaParams.head_size;
|
||||
unsigned int num_q_heads = xqaParams.num_q_heads;
|
||||
unsigned int num_kv_heads = xqaParams.num_kv_heads;
|
||||
TLLM_CHECK_WITH_INFO(num_q_heads % num_kv_heads == 0, "numQHeads should be multiple of numKVHeads.");
|
||||
unsigned int num_q_heads_over_kv = num_q_heads / num_kv_heads;
|
||||
unsigned int beam_width = xqaParams.beam_width;
|
||||
|
||||
// Use mTileSize = 16 kernels when qSeqLen <= 16.
|
||||
unsigned int qSeqLen = static_cast<unsigned int>(xqaParams.generation_input_length);
|
||||
unsigned int mTileSize = qSeqLen <= 16 ? 16 : 32;
|
||||
// MultiQueryToken kernels can support any num_q_heads_over_kv that is power of 2.
|
||||
unsigned int kernel_num_q_heads_over_kv = xqaParams.multi_query_tokens ? 0 : num_q_heads_over_kv;
|
||||
// MultiQueryToken kernels can handle either 16/32 for M direction per CTA.
|
||||
unsigned int kernel_m_tilesize = xqaParams.multi_query_tokens ? mTileSize : num_q_heads_over_kv;
|
||||
|
||||
return {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, kernel_m_tilesize,
|
||||
xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0, xqaParams.paged_kv_cache,
|
||||
xqaParams.multi_query_tokens};
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
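getRuntimeHashKeyFromXQAParams condenses the runtime-relevant XQA parameters into a hashable key. Below is a hedged sketch of using it together with XQAKernelRuntimeHasher to index a kernel table; the map value type is a placeholder, and it assumes the key type provides the equality operator used by the existing kernel maps.

#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h"

#include <string>
#include <unordered_map>

using namespace tensorrt_llm::kernels;

// Placeholder value type; the real registry stores prepared kernels/cubins.
using KernelTable = std::unordered_map<XQAKernelRuntimeHashKey, std::string, XQAKernelRuntimeHasher>;

std::string const* findKernel(KernelTable const& table, XQAParams const& xqaParams)
{
    auto const key = getRuntimeHashKeyFromXQAParams(xqaParams);
    auto it = table.find(key);
    return it == table.end() ? nullptr : &it->second;
}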
@ -17,6 +17,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
#include "decoderXQAConstants.h"
|
||||
#include "tensorrt_llm/common/cudaDriverWrapper.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/common/workspace.h"
|
||||
@ -73,6 +74,8 @@ struct XQAKernelRuntimeHashKey
|
||||
}
|
||||
};
|
||||
|
||||
XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParams);
|
||||
|
||||
struct XQAKernelRuntimeHasher
|
||||
{
|
||||
size_t operator()(XQAKernelRuntimeHashKey const& s) const
|
||||
|
||||
@ -29,7 +29,7 @@ namespace jit
|
||||
{
|
||||
|
||||
CubinObj::CubinObj(void const* buffer_, size_t buffer_size)
|
||||
: mDriver(tensorrt_llm::common::CUDADriverWrapper::getInstance())
|
||||
: mInitialized(false)
|
||||
{
|
||||
uint8_t const* buffer = static_cast<uint8_t const*>(buffer_);
|
||||
size_t remaining_buffer_size = buffer_size;
|
||||
@ -37,8 +37,37 @@ CubinObj::CubinObj(void const* buffer_, size_t buffer_size)
|
||||
mContent.resize(len);
|
||||
TLLM_CHECK(len <= remaining_buffer_size);
|
||||
memcpy(mContent.data(), buffer, len);
|
||||
}
|
||||
|
||||
initialize(mContent.c_str(), "kernel_mha");
|
||||
CubinObj::CubinObj(std::string const& content)
|
||||
: mContent(content)
|
||||
, mInitialized(false)
|
||||
{
|
||||
}
|
||||
|
||||
CubinObj::CubinObj(CubinObj const& other)
|
||||
{
|
||||
// Only uninitialized CubinObj can be copy-constructed.
|
||||
TLLM_CHECK(!other.mInitialized);
|
||||
|
||||
this->mContent = other.mContent;
|
||||
this->mInitialized = false;
|
||||
}
|
||||
|
||||
CubinObj& CubinObj::operator=(CubinObj const& other)
|
||||
{
|
||||
if (this == &other)
|
||||
{
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Only uninitialized CubinObj can be copy-assigned.
|
||||
TLLM_CHECK(!other.mInitialized);
|
||||
|
||||
this->mContent = other.mContent;
|
||||
this->mInitialized = false;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
size_t CubinObj::getSerializationSize() const noexcept
|
||||
@ -59,47 +88,45 @@ void CubinObj::serialize(void* buffer_, size_t buffer_size) const noexcept
|
||||
memcpy(buffer, mContent.c_str(), len);
|
||||
}
|
||||
|
||||
CubinObj::CubinObj(std::string const& content)
|
||||
: mDriver(tensorrt_llm::common::CUDADriverWrapper::getInstance())
|
||||
, mContent(content)
|
||||
, mModule(nullptr)
|
||||
, mFunction(nullptr)
|
||||
, mSharedMemBytes(0)
|
||||
{
|
||||
initialize(mContent.c_str(), "kernel_mha");
|
||||
}
|
||||
|
||||
void CubinObj::launch(dim3 gridDim, dim3 blockDim, CUstream hStream, void** kernelParams)
|
||||
{
|
||||
TLLM_CHECK(mInitialized);
|
||||
cuErrCheck(mDriver->cuLaunchKernel(mFunction, gridDim.x, gridDim.y, gridDim.z, blockDim.x, blockDim.y, blockDim.z,
|
||||
mSharedMemBytes, hStream, kernelParams, /*extra=*/nullptr),
|
||||
mDriver);
|
||||
}

void CubinObj::initialize(char const* content, char const* funcName)
void CubinObj::initialize()
{
cuErrCheck(mDriver->cuModuleLoadData(&mModule, content), mDriver);
TLLM_CHECK(mModule != nullptr);
cuErrCheck(mDriver->cuModuleGetFunction(&mFunction, mModule, funcName), mDriver);
TLLM_CHECK(mFunction != nullptr);

// Populate mSharedMemBytes.
CUdeviceptr shmem_dev_ptr = 0;
cuErrCheck(mDriver->cuModuleGetGlobal(&shmem_dev_ptr, nullptr, mModule, "smemSize"), mDriver);
TLLM_CHECK(shmem_dev_ptr != 0);
cuErrCheck(mDriver->cuMemcpyDtoH(&mSharedMemBytes, shmem_dev_ptr, sizeof(unsigned int)), mDriver);

TLLM_CHECK(mSharedMemBytes > 0);

/* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */
if (mSharedMemBytes >= 46 * 1024)
if (!mInitialized)
{
cuErrCheck(
mDriver->cuFuncSetAttribute(mFunction, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, mSharedMemBytes),
mDriver);
}
mDriver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
mModule = nullptr;
cuErrCheck(mDriver->cuModuleLoadData(&mModule, mContent.c_str()), mDriver);
TLLM_CHECK(mModule != nullptr);
mFunction = nullptr;
cuErrCheck(mDriver->cuModuleGetFunction(&mFunction, mModule, kFuncName), mDriver);
TLLM_CHECK(mFunction != nullptr);

sync_check_cuda_error();
// Populate mSharedMemBytes.
CUdeviceptr shmem_dev_ptr = 0;
cuErrCheck(mDriver->cuModuleGetGlobal(&shmem_dev_ptr, nullptr, mModule, kSmemName), mDriver);
TLLM_CHECK(shmem_dev_ptr != 0);
cuErrCheck(mDriver->cuMemcpyDtoH(&mSharedMemBytes, shmem_dev_ptr, sizeof(unsigned int)), mDriver);

TLLM_CHECK(mSharedMemBytes > 0);

/* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */
if (mSharedMemBytes >= 46 * 1024)
{
cuErrCheck(mDriver->cuFuncSetAttribute(
mFunction, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, mSharedMemBytes),
mDriver);
}

sync_check_cuda_error();
mInitialized = true;
}
}

} // namespace jit

@ -31,20 +31,35 @@ class CubinObj
public:
// Default constructor constructs an empty unusable CubinObj instance.
CubinObj() = default;
CubinObj(std::string const& content);
// Constructs from raw cubin content.
explicit CubinObj(std::string const& content);
// Deserializes from a serialization buffer.
CubinObj(void const* buffer, size_t buffer_size);

CubinObj(CubinObj const& other);
CubinObj& operator=(CubinObj const& other);

// CubinObj can be move-constructed/assigned.
CubinObj(CubinObj&& other) = default;
CubinObj& operator=(CubinObj&& other) = default;

// Should be called at least once before calling launch().
void initialize();
void launch(dim3 gridDim, dim3 blockDim, CUstream hStream, void** kernelParams);

// It is safe to call getSerializeSize()/serialize() before calling initialize().
size_t getSerializationSize() const noexcept;
void serialize(void* buffer, size_t buffer_size) const noexcept;

private:
void initialize(char const* content, char const* funcName);

std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> mDriver;

static constexpr char const* kFuncName = "kernel_mha";
static constexpr char const* kSmemName = "smemSize";
// Constructors should populate mContent.
std::string mContent;

// Fields below are undefined prior to initialize() call.
bool mInitialized;
std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> mDriver;
CUmodule mModule;
CUfunction mFunction;
unsigned int mSharedMemBytes;

@ -29,7 +29,7 @@ namespace kernels
namespace jit
{

// A collection of CubinObjs, with caching functionality.
// A thread-safe collection of CubinObjs, with caching functionality.
template <typename Key, class Hash = std::hash<Key>>
class CubinObjRegistryTemplate
{
@ -64,6 +64,7 @@ public:

std::unique_ptr<CubinObjRegistryTemplate<Key, Hash>> clone() const noexcept
{
std::lock_guard<std::mutex> lock(mMutex);
auto result = std::make_unique<CubinObjRegistryTemplate<Key, Hash>>();
for (auto const& p : mMap)
{
@ -74,6 +75,7 @@ public:

size_t getSerializationSize() const noexcept
{
std::lock_guard<std::mutex> lock(mMutex);
size_t result = sizeof(uint32_t);
for (auto&& p : mMap)
{
@ -85,6 +87,7 @@ public:

void serialize(void* buffer_, size_t buffer_size) const noexcept
{
std::lock_guard<std::mutex> lock(mMutex);
size_t remaining_buffer_size = buffer_size;
uint8_t* buffer = static_cast<uint8_t*>(buffer_);
uint32_t n = mMap.size();
@ -108,31 +111,61 @@ public:
TLLM_CHECK(remaining_buffer_size == 0);
}

// Returns directly if the Cubin already exists in the registry, otherwise call compileEngine to compile it.
//
// compileEngine may be nullptr.
CubinObj* getCubin(Key const& key, CompileEngine* compileEngine)
// Compiles and inserts the cubin if not found in mMap. Does nothing otherwise.
void insertCubinIfNotExists(Key const& key, CompileEngine* compileEngine)
{
TLLM_CHECK(compileEngine != nullptr);

std::lock_guard<std::mutex> lock(mMutex);

auto iter = mMap.find(key);
if (iter != mMap.end())
{
return &(iter->second);
return;
}

TLLM_CHECK_WITH_INFO(compileEngine != nullptr, "Key not found; compileEngine shouldn't be nullptr.");

CubinObj obj = compileEngine->compile();
auto insertResultIter = mMap.insert({key, std::move(obj)}).first;
return &(insertResultIter->second);
mMap.insert({key, std::move(obj)});
return;
}

void insertCubin(Key const& key, CubinObj&& obj)
{
std::lock_guard<std::mutex> lock(mMutex);
mMap.insert({key, std::forward<CubinObj>(obj)});
}

CubinObj* getCubin(Key const& key)
{
std::lock_guard<std::mutex> lock(mMutex);
auto iter = mMap.find(key);
if (iter != mMap.end())
{
return &iter->second;
}
else
{
return nullptr;
}
}

void merge(CubinObjRegistryTemplate<Key, Hash> const& other)
{
for (auto&& p : other.mMap)
{
mMap.insert(p);
}
}

void clear()
{
std::lock_guard<std::mutex> lock(mMutex);
mMap.clear();
}

private:
std::unordered_map<Key, CubinObj, Hash> mMap;
mutable std::mutex mMutex;
};

using CubinObjKey = XQAKernelFullHashKey;
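
As an illustration of the build-time/run-time split this registry enables, a minimal usage sketch could look like the following (assuming a CompileEngine that produces a cubin for the key; the variable names are illustrative, not part of this change):

// Illustrative sketch only; not code from this commit.
// Engine-build time: compile once and cache the raw cubin in the shared registry.
registry.insertCubinIfNotExists(key, &compileEngine);
// Run time: look up the cached cubin, copy it (allowed only before initialize()), and load it.
if (jit::CubinObj* raw = registry.getCubin(key))
{
    jit::CubinObj ready = *raw;
    ready.initialize(); // loads the CUDA module and resolves kernel_mha
    ready.launch(gridDim, blockDim, stream, kernelParams);
}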

@ -36,24 +36,6 @@ XQAKernelRuntimeHashKey getRuntimeHashKeyFromKernelMeta(XQAKernelMetaInfo const&
kernelMeta.mMTileSize, kernelMeta.mTokensPerPage, kernelMeta.mPagedKVCache, kernelMeta.mMultiQueryTokens};
}

XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParams)
{
unsigned int head_size = xqaParams.head_size;
int num_q_heads = xqaParams.num_q_heads;
int num_kv_heads = xqaParams.num_kv_heads;
TLLM_CHECK_WITH_INFO(num_q_heads % num_kv_heads == 0, "numQHeads should be multiple of numKVHeads.");
unsigned int num_q_heads_over_kv = num_q_heads / num_kv_heads;
unsigned int beam_width = xqaParams.beam_width;
// MultiQueryToken kernels can support any num_q_heads_over_kv that is power of 2.
unsigned int kernel_num_q_heads_over_kv = xqaParams.multi_query_tokens ? 0 : num_q_heads_over_kv;
// MultiQueryToken kernels can handle either 16/32 for M direction per CTA.
unsigned int m_tilesize = xqaParams.multi_query_tokens ? 16 : num_q_heads_over_kv;

return {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, m_tilesize,
xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0, xqaParams.paged_kv_cache,
xqaParams.multi_query_tokens};
}

} // anonymous namespace

namespace tensorrt_llm
@ -66,7 +48,6 @@ DecoderXQAImplJIT::DecoderXQAImplJIT(DecoderXQARunner* runner)
, mDriver(tensorrt_llm::common::CUDADriverWrapper::getInstance())
, mForceXQA(tensorrt_llm::common::forceXQAKernels())
, mSM(tensorrt_llm::common::getSMVersion())
, mCubinObjRegistry(runner->mResource->getCubinObjRegistry())
{
initSupportedConfigs();
}
@ -140,8 +121,24 @@ void DecoderXQAImplJIT::prepare(XQAParams const& xqaParams)

jit::CompileEngine compileEngine(mSM, xqaParams);

// Discard getCubin() result.
mCubinObjRegistry->getCubin(key, &compileEngine);
auto registryGlobal = DecoderXQARunner::getResourceGlobal()->getCubinObjRegistry();
jit::CubinObj* uninitializedCubin = registryGlobal->getCubin(key);
if (uninitializedCubin != nullptr)
{
// Inference time. Prepare for the inference.
if (mInitializedCubinObjRegistry.getCubin(key) == nullptr)
{
// Make a copy and initialize it.
jit::CubinObj initializedCubin = *uninitializedCubin;
initializedCubin.initialize();
mInitializedCubinObjRegistry.insertCubin(key, std::move(initializedCubin));
}
}
else
{
// Engine-build time. Compile the cubin and place it into CubinObjRegistry.
registryGlobal->insertCubinIfNotExists(key, &compileEngine);
}
}

void DecoderXQAImplJIT::runWithKVLinearBuffer(
@ -204,9 +201,13 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
BuildDecoderInfoParams<T> decoder_params;
memset(&decoder_params, 0, sizeof(decoder_params));
decoder_params.seqQOffsets = launchParams.cu_seq_lens;
decoder_params.seqQLengths = xqaParams.spec_decoding_generation_lengths;
decoder_params.seqKVLengths = xqaParams.sequence_lengths;
decoder_params.batchSize = int(batch_beam_size);
decoder_params.maxQSeqLength = xqaParams.generation_input_length;
decoder_params.removePadding = xqaParams.multi_query_tokens;
TLLM_CHECK_WITH_INFO(!xqaParams.multi_query_tokens || xqaParams.spec_decoding_generation_lengths != nullptr,
"Spec_decoding_generation_lengths must be provided.");
// Rotary embedding inv_freq buffer.
decoder_params.rotaryEmbeddingScale = xqaParams.rotary_embedding_scale;
decoder_params.rotaryEmbeddingBase = xqaParams.rotary_embedding_base;
@ -222,16 +223,18 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
// NOTE: MHA kernels should read kv cache that has already been appended with new tokens' kv cache.
void* xqa_q_input_ptr = ioScratch;
QKVPreprocessingParams<T, KVCacheBuffer> preprocessingParms{static_cast<T*>(const_cast<void*>(xqaParams.qkv)),
nullptr, static_cast<T*>(xqa_q_input_ptr), kv_cache_buffer, static_cast<T const*>(xqaParams.qkv_bias), nullptr,
xqaParams.sequence_lengths, nullptr, launchParams.rotary_inv_freq_buf, (float2 const*) nullptr,
xqaParams.kv_scale_orig_quant, xqaParams.spec_decoding_position_offsets, int(batch_beam_size),
xqaParams.generation_input_length, xqaParams.timestep, xqaParams.cyclic_attention_window_size,
xqaParams.sink_token_length, int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length),
xqaParams.num_q_heads, xqaParams.num_kv_heads, xqaParams.num_q_heads / xqaParams.num_kv_heads,
xqaParams.head_size, xqaParams.rotary_embedding_dim, xqaParams.rotary_embedding_base,
xqaParams.rotary_embedding_scale_type, xqaParams.rotary_embedding_scale,
xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type, xqaParams.position_shift_enabled,
cache_type, true, false, multiprocessor_count};
nullptr, static_cast<T*>(xqa_q_input_ptr), kv_cache_buffer, static_cast<T const*>(xqaParams.qkv_bias),
xqaParams.spec_decoding_generation_lengths, xqaParams.sequence_lengths,
xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr, launchParams.rotary_inv_freq_buf,
(float2 const*) nullptr, xqaParams.kv_scale_orig_quant, xqaParams.spec_decoding_position_offsets,
int(batch_beam_size), xqaParams.generation_input_length, xqaParams.timestep,
xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length,
int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length), xqaParams.num_q_heads,
xqaParams.num_kv_heads, xqaParams.num_q_heads / xqaParams.num_kv_heads, xqaParams.head_size,
xqaParams.rotary_embedding_dim, xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type,
xqaParams.rotary_embedding_scale, xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type,
xqaParams.position_shift_enabled, cache_type, true, false, multiprocessor_count, xqaParams.rotary_vision_start,
xqaParams.rotary_vision_length};

invokeQKVPreprocessing<T, KVCacheBuffer>(preprocessingParms, stream);
sync_check_cuda_error();
@ -245,7 +248,8 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
unsigned int kernel_m_tilesize = xqaParams.multi_query_tokens ? mTileSize : num_q_heads_over_kv;

jit::CubinObjKey key = getCubinObjKeyFromXQAParams(xqaParams);
jit::CubinObj* cubinObj = mCubinObjRegistry->getCubin(key, /*compileEngine=*/nullptr);
jit::CubinObj* cubinObj = mInitializedCubinObjRegistry.getCubin(key);
TLLM_CHECK(cubinObj != nullptr);

if (xqaParams.multi_query_tokens)
{
@ -275,8 +279,8 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
}
else
{
// mha_sm90.cu kernels. Default to false because it is not available in JIT path for now.
bool const isGmmaKernel = false;
bool const isGmmaKernel = (mSM == kSM_90 && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3
&& xqaParams.beam_width == 1);
constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 11;
uint32_t const maxNbKernelParams = (isGmmaKernel ? 11 : 10);
uint32_t idxNextParam = 0;

@ -60,7 +60,8 @@ private:
bool mForceXQA;
int mSM;

jit::CubinObjRegistry* mCubinObjRegistry;
jit::CubinObjRegistry mInitializedCubinObjRegistry;

jit::CubinObjKey getCubinObjKeyFromXQAParams(XQAParams const& xqaParams) const;

//! The first prototype just takes whatever available from the Precompiled cubins.

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f51307e90efbdd3dadc404efafb3b8a96ddbdb89a9068eba0b9676656be7d46d
size 80202640
oid sha256:8de0cd3bd46925e008f263b3f6c78c17f198578f74e23bc90661bec5a9acfbb1
size 80250768

@ -1,2 +1,2 @@
b3823dd8e1d7f154019fb7dc24172ff4 libtensorrt_llm_nvrtc_wrapper.so
8d4b145290d5984494a1fa6e380d01456534dc62 commit
5b6c74ce66f62d2a58aa9cac16f11ad6 libtensorrt_llm_nvrtc_wrapper.so
c0bd2b69c932257678a2aad9bd8baba4b291795e commit
@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:08384c1d7a80a86d888f6f23a5687ccb102b1a510b66db8dbcc3169127e4e88a
size 83472488
oid sha256:bbf358364915d5b023a6d0574cde0f602c104d24efe0bf5c04eeee4610a2413e
size 83541760

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:22867facd7d8dfa699618884d2e6912b1a2a7afedc299aa91e14b36353d6b8bd
size 1011200
oid sha256:84319476e8ecf9666f40f69355f19ec3b585fc0987f940be14af9e11e3f524c3
size 1080832

@ -213,14 +213,7 @@ public:
// Use mTileSize = 16 kernels when qSeqLen <= 16.
unsigned int qSeqLen = static_cast<unsigned int>(xqaParams.generation_input_length);
unsigned int mTileSize = qSeqLen <= 16 ? 16 : 32;
// MultiQueryToken kernels can support any num_q_heads_over_kv that is power of 2.
unsigned int kernel_num_q_heads_over_kv = xqaParams.multi_query_tokens ? 0 : num_q_heads_over_kv;
// MultiQueryToken kernels can handle either 16/32 for M direction per CTA.
unsigned int kernel_m_tilesize = xqaParams.multi_query_tokens ? mTileSize : num_q_heads_over_kv;
XQAKernelRuntimeHashKey hash_key{xqaParams.kv_cache_data_type, head_size, beam_width,
kernel_num_q_heads_over_kv, kernel_m_tilesize,
xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0,
xqaParams.paged_kv_cache, xqaParams.multi_query_tokens};
XQAKernelRuntimeHashKey hash_key = getRuntimeHashKeyFromXQAParams(xqaParams);
auto const findIter = mFunctions.find(hash_key);

TLLM_CHECK_WITH_INFO(findIter != mFunctions.end(), "XQAKernelFunc not found.");
@ -310,28 +303,6 @@ public:
}
}

private:
static uint32_t getElemBytes(CUtensorMapDataType_enum dataType)
{
switch (dataType)
{
case CU_TENSOR_MAP_DATA_TYPE_UINT8: return 1;
case CU_TENSOR_MAP_DATA_TYPE_UINT16: return 2;
case CU_TENSOR_MAP_DATA_TYPE_UINT32: return 4;
case CU_TENSOR_MAP_DATA_TYPE_INT32: return 4;
case CU_TENSOR_MAP_DATA_TYPE_UINT64: return 8;
case CU_TENSOR_MAP_DATA_TYPE_INT64: return 8;
case CU_TENSOR_MAP_DATA_TYPE_FLOAT16: return 2;
case CU_TENSOR_MAP_DATA_TYPE_FLOAT32: return 4;
case CU_TENSOR_MAP_DATA_TYPE_FLOAT64: return 8;
case CU_TENSOR_MAP_DATA_TYPE_BFLOAT16: return 2;
case CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ: return 4;
case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32: return 4;
case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ: return 4;
}
throw std::runtime_error("unsupported data type");
}

protected:
std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> mDriver;


@ -36,10 +36,9 @@ namespace tensorrt_llm
namespace kernels
{

DecoderXQARunner::DecoderXQARunner(Resource* resource, const XQADataType data_type, int num_heads, int num_kv_heads,
int head_size, bool multi_block_mode)
: mResource(resource)
, mDataType(data_type)
DecoderXQARunner::DecoderXQARunner(
const XQADataType data_type, int num_heads, int num_kv_heads, int head_size, bool multi_block_mode)
: mDataType(data_type)
, mNumHeads(num_heads)
, mNumKVHeads(num_kv_heads)
, mHeadSize(head_size)
@ -104,18 +103,12 @@ size_t DecoderXQARunner::getWorkspaceSize(int max_batch_beam_size, int max_num_t

DecoderXQAImpl* DecoderXQARunner::getImplFromXQAParams(XQAParams const& xqaParams)
{
if (tensorrt_llm::common::getSMVersion() == kSM_90)
{
// Always use Precompiled impl for sm90 until Hopper XQA source gets integrated to JIT codepath.
return mPrecompiledImpl.get();
}
if (xqaParams.multi_query_tokens)
{
// Use precompiled cubin for medusa, because medusa cubins are generated from a different CUDA source file than
// non-medusa.
return mPrecompiledImpl.get();
}

if (tensorrt_llm::common::getEnvEnableXQAJIT())
{
return mJITImpl.get();
@ -143,6 +136,12 @@ void DecoderXQARunner::run(
return getImplFromXQAParams(xqa_params)->run(xqa_params, kv_cache_buffer, stream);
}

DecoderXQARunner::Resource* DecoderXQARunner::getResourceGlobal()
{
static DecoderXQARunner::Resource sResource;
return &sResource;
}

template void DecoderXQARunner::run(
XQAParams const& xqa_params, KVLinearBuffer const& kv_linear_buffer, cudaStream_t const& stream);
template void DecoderXQARunner::run(

@ -77,10 +77,8 @@ struct XQADispatchHelper<__nv_bfloat16, KVBlockArray>
class DecoderXQARunner
{
public:
// Resources for constructing a DecoderXQARunner object.
class Resource;
DecoderXQARunner(Resource* resource, const XQADataType data_type, int num_heads, int num_kv_heads, int head_size,
bool multi_block_mode);
DecoderXQARunner(
const XQADataType data_type, int num_heads, int num_kv_heads, int head_size, bool multi_block_mode);
~DecoderXQARunner();

/**
@ -169,6 +167,9 @@ public:
this->run(xqa_params, kv_cache_buffer, stream);
}

class Resource;
static Resource* getResourceGlobal();

private:
bool shouldUseImpl(XQAParams const& xqa_params, bool for_configure_plugin);
void prepareForRun(XQAParams const& xqa_params);
@ -178,8 +179,6 @@ private:

static constexpr int kMaxBeamWidth = 4;

Resource* mResource;

XQADataType mDataType;
int mNumHeads;
int mNumKVHeads;
@ -206,11 +205,21 @@ public:
Resource(void const* buffer, size_t buffer_size);
~Resource() = default;

void merge(Resource const& other)
{
getCubinObjRegistry()->merge(*other.getCubinObjRegistry());
}

jit::CubinObjRegistry* getCubinObjRegistry()
{
return mCubinObjRegistry.get();
}

jit::CubinObjRegistry const* getCubinObjRegistry() const
{
return mCubinObjRegistry.get();
}

size_t getSerializationSize() const noexcept;
void serialize(void* buffer, size_t buffer_size) const noexcept;


@ -339,7 +339,7 @@ __global__ void insertUnfinishedPathKernel(BeamHypotheses bh)
// Other parameters
bh.sequenceLengthsCBA[dstBeam] = bh.sequenceLengths[srcBeam];
bh.normedScoresCBA[dstBeam]
= applyLengthPenalty(bh.cumLogProbs[srcBeam], step - bh.inputLengths[srcBeam], bh.lengthPenalties[bid]);
= applyLengthPenalty(bh.cumLogProbs[srcBeam], step - bh.inputLengths[srcBeam] + 1, bh.lengthPenalties[bid]);
bh.cumLogProbsCBA[dstBeam] = bh.cumLogProbs[srcBeam];
bh.numBeamsCBA[bid]++;
}

@ -1070,7 +1070,7 @@ std::vector<size_t> CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::getWo
size_t const sorter_size = CubKeyValueSorter::getWorkspaceSize(num_rows, num_experts);
size_t const fc2_result_size = permuted_elems * gemm_output_dtype; // May be an intermediate type for quantization
size_t const hopper_size = using_hopper ? HopperGroupedGemmInput::workspaceSize(num_experts_per_node) : 0;
size_t const gemm_workspace_size = moe_gemm_runner_.calcMaxWorkspaceSize(num_experts_per_node);
size_t const gemm_workspace_size = moe_gemm_runner_.getMaxWorkspaceSize(num_experts_per_node);

std::vector<size_t> workspace{source_rows_size, permuted_rows_size, permuted_experts_size, permuted_data_size,
total_rows_before_expert_size, softmax_out_size, glu_inter_size,
@ -1085,7 +1085,7 @@ size_t CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::getWorkspaceSize(i
ActivationType activation_type, MOEParallelismConfig parallelism_config) const
{
int const ep_size = parallelism_config.ep_size;
TLLM_CHECK_WITH_INFO(num_experts % ep_size == 0, "Number of experts must be a multiple of tp size");
TLLM_CHECK_WITH_INFO(num_experts % ep_size == 0, "Number of experts must be a multiple of ep size");
auto workspace = getWorkspaceBufferSizes(
num_rows, hidden_size, inter_size, num_experts, num_experts / ep_size, k, activation_type);
return tensorrt_llm::common::calculateTotalWorkspaceSize(workspace.data(), workspace.size());

@ -52,15 +52,6 @@ private:
int num_bits_;
};

enum class MOEParallelismMode : int
{
NONE = 0, //!< Ignore parallelism and duplicate the work across all nodes
EXPERT_PARALLELISM, //!< Divide the experts between each node. The number of experts must be a multiple of
//!< parallelism
TENSOR_PARALLELISM, //!< Divide the weight matrices between the nodes. The hidden dimension must be a multiple of
//!< parallelism
};

enum class MOEExpertScaleNormalizationMode : int
{
NONE = 0, //!< Run the softmax on all scales and select the topk
@ -91,20 +82,23 @@ enum class MOEExpertScaleNormalizationMode : int
*/
struct MOEParallelismConfig
{
constexpr static MOEParallelismConfig TensorParallelism(int tp_size, int tp_rank)
int tp_size = 1;
int tp_rank = 0;
int ep_size = 1;
int ep_rank = 0;

bool operator==(MOEParallelismConfig const& other) const
{
return {tp_size, tp_rank, 1, 0};
return tp_size == other.tp_size && tp_rank == other.tp_rank && ep_size == other.ep_size
&& ep_rank == other.ep_rank;
}

constexpr static MOEParallelismConfig ExpertParallelism(int ep_size, int ep_rank)
friend std::ostream& operator<<(std::ostream& os, MOEParallelismConfig const& config)
{
return {1, 0, ep_size, ep_rank};
os << "tp_size: " << config.tp_size << ", tp_rank: " << config.tp_rank << ", ep_size: " << config.ep_size
<< ", ep_rank: " << config.ep_rank;
return os;
}

int const tp_size = 1;
int const tp_rank = 0;
int const ep_size = 1;
int const ep_rank = 0;
};
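
With the factory helpers replaced by plain members, a parallelism config can be built by aggregate initialization; a small illustrative sketch (the values are hypothetical, not from this change):

// Illustrative sketch only; rough equivalents of the removed TensorParallelism()/ExpertParallelism() helpers.
MOEParallelismConfig tpOnly{/*tp_size=*/4, /*tp_rank=*/1, /*ep_size=*/1, /*ep_rank=*/0};
MOEParallelismConfig epOnly{/*tp_size=*/1, /*tp_rank=*/0, /*ep_size=*/8, /*ep_rank=*/3};
// operator== compares all four fields; operator<< prints them for logging.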

struct QuantParams

@ -14,11 +14,7 @@
* limitations under the License.
*/

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/kernels/speculativeDecoding/explicitDraftTokensKernels.h"
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
@ -33,35 +29,34 @@ using namespace tensorrt_llm::runtime;

namespace tensorrt_llm::kernels::speculative_decoding
{
size_t invokeScanSpecDecodingGenerationLengths(void* __restrict__ scanTempStorage, size_t scanTempStorageBytes,
SizeType32 const* __restrict__ specDecodingGenerationLengths,
SizeType32* __restrict__ maxSpecDecodingGenerationLengths, SizeType32 batchSize, cudaStream_t stream)
size_t invokeScanGenerationLengths(void* __restrict__ scanTempStorage, size_t scanTempStorageBytes,
SizeType32 const* __restrict__ generationLengths, SizeType32* __restrict__ scannedGenerationLengths,
SizeType32 batchSize, cudaStream_t stream)
{
cub::DeviceScan::InclusiveSum(scanTempStorage, scanTempStorageBytes, specDecodingGenerationLengths,
maxSpecDecodingGenerationLengths, batchSize, stream);
cub::DeviceScan::InclusiveSum(
scanTempStorage, scanTempStorageBytes, generationLengths, scannedGenerationLengths, batchSize, stream);
return scanTempStorageBytes;
}

size_t invokeReduceMaxSpecDecodingGenerationLengths(void* __restrict__ reduceMaxTempStorage,
size_t reduceTempStorageBytes, SizeType32 const* __restrict__ specDecodingGenerationLengths,
SizeType32* __restrict__ scannedSpecDecodingGenerationLengths, SizeType32 batchSize, cudaStream_t stream)
size_t invokeReduceMaxGenerationLengths(void* __restrict__ reduceMaxTempStorage, size_t reduceTempStorageBytes,
SizeType32 const* __restrict__ generationLengths, SizeType32* __restrict__ maxGenerationLengths,
SizeType32 batchSize, cudaStream_t stream)
{
cub::DeviceReduce::Max(reduceMaxTempStorage, reduceTempStorageBytes, specDecodingGenerationLengths,
scannedSpecDecodingGenerationLengths, batchSize, stream);
cub::DeviceReduce::Max(
reduceMaxTempStorage, reduceTempStorageBytes, generationLengths, maxGenerationLengths, batchSize, stream);
return reduceTempStorageBytes;
}

// inclusive prefix sum specDecodingGenerationLengths and reduce max specDecodingGenerationLengths
void invokeScanReduceSpecDecodingGenerationLengths(SizeType32 batchSize,
SizeType32 const* __restrict__ specDecodingGenerationLengths, void* __restrict__ scanTempStorage,
size_t scanTempStorageBytes, SizeType32* __restrict__ scanedSpecDecodingGenerationLengths,
void* __restrict__ reduceMaxTempStorage, size_t reduceMaxTempStorageBytes,
SizeType32* maxSpecDecodingGenerationLengths, cudaStream_t stream)
// inclusive prefix sum generationLengths and reduce max generationLengths
void invokeScanReduceGenerationLengths(SizeType32 batchSize, SizeType32 const* __restrict__ generationLengths,
void* __restrict__ scanTempStorage, size_t scanTempStorageBytes, SizeType32* __restrict__ scanedGenerationLengths,
void* __restrict__ reduceMaxTempStorage, size_t reduceMaxTempStorageBytes, SizeType32* maxGenerationLengths,
cudaStream_t stream)
{
invokeScanSpecDecodingGenerationLengths(scanTempStorage, scanTempStorageBytes, specDecodingGenerationLengths,
scanedSpecDecodingGenerationLengths, batchSize, stream);
invokeReduceMaxSpecDecodingGenerationLengths(reduceMaxTempStorage, reduceMaxTempStorageBytes,
specDecodingGenerationLengths, maxSpecDecodingGenerationLengths, batchSize, stream);
invokeScanGenerationLengths(
scanTempStorage, scanTempStorageBytes, generationLengths, scanedGenerationLengths, batchSize, stream);
invokeReduceMaxGenerationLengths(
reduceMaxTempStorage, reduceMaxTempStorageBytes, generationLengths, maxGenerationLengths, batchSize, stream);
}
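
These wrappers follow the usual CUB two-phase pattern: the same routine is first called with a null temp buffer to query the required size, then again with the allocated buffer. A standalone sketch of that pattern (buffer names are illustrative, not from this change):

// Illustrative CUB usage pattern, independent of the wrappers above.
size_t tempBytes = 0;
cub::DeviceScan::InclusiveSum(nullptr, tempBytes, generationLengths, scannedGenerationLengths, batchSize, stream);
void* dTemp = nullptr;
cudaMalloc(&dTemp, tempBytes);
cub::DeviceScan::InclusiveSum(dTemp, tempBytes, generationLengths, scannedGenerationLengths, batchSize, stream);
cudaFree(dTemp);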

////////////////////////
@ -100,27 +95,25 @@ __device__ SizeType32 positivePowerOfTwo(SizeType32 n)
return res;
}

__global__ void getSpecDecodingPackedMask(SizeType32 const* __restrict__ specDecodingCumGenerationLengths,
SizeType32 const* __restrict__ specDecodingMaxGenerationLengths, bool const* __restrict__ specDecodingMask,
SizeType32 const* __restrict__ batchSlots, SizeType32 maxDraftTokens,
SizeType32* __restrict__ specDecodingPackedMask)
__global__ void getPackedMask(SizeType32 const* __restrict__ cumGenerationLengths,
SizeType32 const* __restrict__ maxGenerationLengths, bool const* __restrict__ mask,
SizeType32 const* __restrict__ batchSlots, SizeType32 maxDraftTokens, SizeType32* __restrict__ packedMask)
{
auto const batchIdx = static_cast<SizeType32>(blockIdx.y);
auto const tokenIdx = static_cast<SizeType32>(blockIdx.x);

auto const numTokens = (batchIdx == 0)
? specDecodingCumGenerationLengths[0]
: specDecodingCumGenerationLengths[batchIdx] - specDecodingCumGenerationLengths[batchIdx - 1];
auto const numTokens = (batchIdx == 0) ? cumGenerationLengths[0]
: cumGenerationLengths[batchIdx] - cumGenerationLengths[batchIdx - 1];
if (tokenIdx >= numTokens)
{
return;
}

auto const maxGenerationLength = specDecodingMaxGenerationLengths[0];
auto const maxGenerationLength = maxGenerationLengths[0];
auto const numPackedMasks = divUp(maxDraftTokens, 32);
auto const outputStartId = batchSlots ? (batchSlots[batchIdx] * (maxDraftTokens + 1))
: ((batchIdx == 0) ? 0 : specDecodingCumGenerationLengths[batchIdx - 1]);
auto* outputPtr = specDecodingPackedMask + (outputStartId + tokenIdx) * numPackedMasks;
: ((batchIdx == 0) ? 0 : cumGenerationLengths[batchIdx - 1]);
auto* outputPtr = packedMask + (outputStartId + tokenIdx) * numPackedMasks;
if (tokenIdx == 0)
{
for (auto maskId = static_cast<SizeType32>(threadIdx.x); maskId < numPackedMasks;
@ -132,18 +125,18 @@ __global__ void getSpecDecodingPackedMask(SizeType32 const* __restrict__ specDec
}
else
{
bool const* specDecodingMaskPtr = specDecodingMask + batchIdx * maxGenerationLength * maxGenerationLength
+ tokenIdx * maxGenerationLength + 1;
extern __shared__ char shSpecDecodingMask[];
bool const* maskPtr
= mask + batchIdx * maxGenerationLength * maxGenerationLength + tokenIdx * maxGenerationLength + 1;
extern __shared__ char shMask[];
if (threadIdx.x == 0)
{
shSpecDecodingMask[maxGenerationLength - 1] = '1';
shMask[maxGenerationLength - 1] = '1';
}
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < maxGenerationLength - 1;
ti += static_cast<SizeType32>(blockDim.x))
{
auto const shIndex = maxGenerationLength - 1 - ti - 1;
shSpecDecodingMask[shIndex] = specDecodingMaskPtr[ti] ? '1' : '0';
shMask[shIndex] = maskPtr[ti] ? '1' : '0';
}
__syncthreads();
for (auto maskId = static_cast<SizeType32>(threadIdx.x); maskId < numPackedMasks;
@ -156,19 +149,19 @@ __global__ void getSpecDecodingPackedMask(SizeType32 const* __restrict__ specDec
}
else
{
auto const shSpecDecodingMaskIndexStart
auto const shMaskIndexStart
= ((maxGenerationLength - (maskId + 1) * 32) < 0) ? 0 : (maxGenerationLength - (maskId + 1) * 32);
auto const shSpecDecodingMaskIndexEnd = maxGenerationLength - (maskId * 32 + 1) + 1;
auto const shMaskIndexEnd = maxGenerationLength - (maskId * 32 + 1) + 1;

auto const validNumBits = shSpecDecodingMaskIndexEnd - shSpecDecodingMaskIndexStart;
auto const firstBit1 = (shSpecDecodingMask[shSpecDecodingMaskIndexStart] == '1') ? true : false;
auto const validNumBits = shMaskIndexEnd - shMaskIndexStart;
auto const firstBit1 = (shMask[shMaskIndexStart] == '1') ? true : false;
SizeType32 mask31bits = 0;
if (validNumBits != 1)
{
for (auto i = shSpecDecodingMaskIndexStart + 1; i < shSpecDecodingMaskIndexEnd; i++)
for (auto i = shMaskIndexStart + 1; i < shMaskIndexEnd; i++)
{
auto const index = (validNumBits - 1) - (i - shSpecDecodingMaskIndexStart - 1) - 1;
mask31bits += (shSpecDecodingMask[i] == '1') ? positivePowerOfTwo(index) : 0;
auto const index = (validNumBits - 1) - (i - shMaskIndexStart - 1) - 1;
mask31bits += (shMask[i] == '1') ? positivePowerOfTwo(index) : 0;
}
}
SizeType32 mask32bits;
@ -187,19 +180,47 @@ __global__ void getSpecDecodingPackedMask(SizeType32 const* __restrict__ specDec
}
} // namespace

void invokeConvertSpecDecodingMaskToPackedMask(SizeType32 batchSize,
SizeType32 const* __restrict__ specDecodingCumGenerationLengths,
SizeType32 const* __restrict__ specDecodingMaxGenerationLengths, bool const* __restrict__ specDecodingMask,
void invokeConvertMaskToPackedMask(SizeType32 batchSize, SizeType32 const* __restrict__ cumGenerationLengths,
SizeType32 const* __restrict__ maxGenerationLengths, bool const* __restrict__ mask,
SizeType32 const* __restrict__ batchSlots, SizeType32 maxDraftTokens, SizeType32 maxGenerationLength,
SizeType32* __restrict__ specDecodingPackedMask, cudaStream_t stream)
SizeType32* __restrict__ packedMask, cudaStream_t stream)
{
dim3 block(32);
dim3 grid(maxGenerationLength, batchSize);
size_t shmSize = maxGenerationLength * sizeof(char);
getSpecDecodingPackedMask<<<grid, block, shmSize, stream>>>(specDecodingCumGenerationLengths,
specDecodingMaxGenerationLengths, specDecodingMask, batchSlots, maxDraftTokens, specDecodingPackedMask);
getPackedMask<<<grid, block, shmSize, stream>>>(
cumGenerationLengths, maxGenerationLengths, mask, batchSlots, maxDraftTokens, packedMask);
}

namespace
{
template <typename T>
__global__ void fillContextBuffers(FillContextExplicitDraftTokensParams<T> params)
{
auto const bid = static_cast<SizeType32>(blockIdx.x);
auto const batchSlot = params.batchSlots ? params.batchSlots[bid] : bid;

if (threadIdx.x == 0)
{
// Generate new random data for sampling.
params.randDataSample[batchSlot] = static_cast<T>(curand_uniform(params.curandState + batchSlot));

// Copy temperature.
params.outputTemperatures[batchSlot] = __frcp_rn(params.inputTemperatures[batchSlot]);
}
}
} // namespace

template <typename T>
void invokeFillContextBuffers(FillContextExplicitDraftTokensParams<T> const& params, cudaStream_t stream)
{
SizeType32 constexpr BLOCK_SIZE = 32;
fillContextBuffers<<<params.batchSize, BLOCK_SIZE, 0, stream>>>(params);
}

template void invokeFillContextBuffers(FillContextExplicitDraftTokensParams<float> const& params, cudaStream_t stream);
template void invokeFillContextBuffers(FillContextExplicitDraftTokensParams<half> const& params, cudaStream_t stream);

namespace
{
template <typename T>
@ -216,6 +237,8 @@ __global__ void extractExplicitDraftTokens(ExtractExplicitDraftTokensParams<T> p
auto const bestPathIdx = params.bestPathIndices[bid];
// Get current seq len (w/o newly accepted tokens).
auto const curSeqLen = params.sequenceLengths[batchSlot];
// `last*` tensors do not have data for context requests.
auto const lastTensorBid = bid - params.numContextRequests;

// Get output ids.
auto* outputIdsRequest = params.outputIds + batchSlot * params.maxSeqLen;
@ -237,7 +260,8 @@ __global__ void extractExplicitDraftTokens(ExtractExplicitDraftTokensParams<T> p
{
// Read 1:bestPathLength slice of last draft tokens at best path idx.
// This tensor comes directly from engine and has linear batch index.
auto const pathOffset = flat_index3(bid, bestPathIdx, ti + 1, params.numPaths, params.maxPathLength);
auto const pathOffset
= flat_index3(lastTensorBid, bestPathIdx, ti + 1, params.numPaths, params.maxPathLength);
// Read accepted token from last draft tokens.
acceptedToken = params.lastDraftTokens[pathOffset];
}
@ -253,6 +277,11 @@ __global__ void extractExplicitDraftTokens(ExtractExplicitDraftTokensParams<T> p
= params.nextDraftTokens[bid * params.numPaths * params.maxPathLength + ti];
params.unpackedNextDraftIndices[batchSlot * params.numPaths * params.maxPathLength + ti]
= params.inputUnpackedNextDraftIndices[bid * params.numPaths * params.maxPathLength + ti];
if (lastTensorBid >= 0)
{
params.outputLastDraftIndices[batchSlot * params.numPaths * params.maxPathLength + ti]
= params.lastDraftIndices[lastTensorBid * params.numPaths * params.maxPathLength + ti];
}
}

auto const numNextDraftTokens = (bid == 0)
@ -274,14 +303,13 @@ __global__ void extractExplicitDraftTokens(ExtractExplicitDraftTokensParams<T> p
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numNextDraftTokens;
ti += static_cast<SizeType32>(blockDim.x))
{
params.outputPositionIds[batchSlot * maxDecodingTokens + ti] = params.packedPositionIds[startId + ti];
params.outputPositionIds[batchSlot * maxDecodingTokens + ti] = params.packedPositionIds[startId + ti] - 1;
}

for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < params.numPaths * (params.maxPathLength - 1);
ti += static_cast<SizeType32>(blockDim.x))
{
// Generate new random data for token verification.
// This tensor goes directly to engine and has linear batch index.
auto const offset = flat_index2(batchSlot, ti, params.numPaths * (params.maxPathLength - 1));
params.randDataVerification[offset] = static_cast<T>(curand_uniform(params.curandState + batchSlot));
}
@ -291,23 +319,31 @@ __global__ void extractExplicitDraftTokens(ExtractExplicitDraftTokensParams<T> p
if (threadIdx.x == 0)
{
// Update pos id base.
// This tensor goes directly to engine and has linear batch index.
params.outputPositionIdsBase[batchSlot] = params.inputPositionIdsBase[bid] + bestPathLength;

// Set number of accepted tokens at this iteration.
params.acceptedLengths[batchSlot] = bestPathLength;

// Set number of draft tokens for the next iteration.
params.prevDraftLengths[batchSlot] = params.nextDraftLengths[batchSlot];

// Set number of draft tokens for the next iteration.
params.nextDraftLengths[batchSlot] = numNextDraftTokens - 1;

// Set number of tokens passed to the engine per request for the next iteration.
params.outputGenerationLengths[batchSlot] = numNextDraftTokens;

// Generate new random data for sampling.
// This tensor goes directly to engine and has linear batch index.
params.randDataSample[batchSlot] = static_cast<T>(curand_uniform(params.curandState + batchSlot));

// Increase seqLen by accepted len.
params.sequenceLengths[batchSlot] = curSeqLen + bestPathLength;

// Copy temperature.
params.outputTemperatures[batchSlot] = params.inputTemperatures[batchSlot];
params.outputTemperatures[batchSlot] = __frcp_rn(params.inputTemperatures[batchSlot]);

// Copy best path index.
params.outputBestPathIndices[batchSlot] = bestPathIdx;
}
}
} // namespace
@ -328,12 +364,13 @@ namespace
{
template <typename VecT>
__global__ void copyProbs(uint8_t const* srcData, uint8_t* dstData, SizeType32 const* inputBatchSlots,
SizeType32 const* outputBatchSlots, SizeType32 sizeInBytes)
SizeType32 const* outputBatchSlots, SizeType32 sizeInBytes, SizeType32 inputBatchIdxOffset)
{
auto constexpr VEC_ELTS = static_cast<SizeType32>(sizeof(VecT));
auto const bid = static_cast<SizeType32>(blockIdx.y);
auto const intputBatchSlot = inputBatchSlots ? inputBatchSlots[bid] : bid;
auto const outputBatchSlot = outputBatchSlots ? outputBatchSlots[bid] : bid;
auto const inputBid = static_cast<SizeType32>(blockIdx.y) + inputBatchIdxOffset;
auto const outputBid = static_cast<SizeType32>(blockIdx.y);
auto const intputBatchSlot = inputBatchSlots ? inputBatchSlots[inputBid] : inputBid;
auto const outputBatchSlot = outputBatchSlots ? outputBatchSlots[outputBid] : outputBid;
auto const srcStartIdx = intputBatchSlot * sizeInBytes;
auto const dstStartIdx = outputBatchSlot * sizeInBytes;
auto const tidx = (static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VEC_ELTS;
@ -351,7 +388,8 @@ __global__ void copyProbs(uint8_t const* srcData, uint8_t* dstData, SizeType32 c
} // namespace

void invokeCopyProbs(uint8_t const* srcDataPtr, uint8_t* dstDataPtr, SizeType32 const* inputBatchSlots,
SizeType32 const* outputBatchSlots, SizeType32 batchSize, SizeType32 copyRowSizeInBytes, cudaStream_t stream)
SizeType32 const* outputBatchSlots, SizeType32 batchSize, SizeType32 inputBatchIdxOffset,
SizeType32 copyRowSizeInBytes, cudaStream_t stream)
{
auto copyProbsInvocation = copyProbs<uint8_t>;
if (copyRowSizeInBytes % 16 == 0)
@ -375,7 +413,7 @@ void invokeCopyProbs(uint8_t const* srcDataPtr, uint8_t* dstDataPtr, SizeType32
SizeType32 constexpr BLOCKS_PER_ROW{32};
dim3 const gridSize{BLOCKS_PER_ROW, static_cast<uint32_t>(batchSize)};
copyProbsInvocation<<<gridSize, blockSize, 0, stream>>>(
srcDataPtr, dstDataPtr, inputBatchSlots, outputBatchSlots, copyRowSizeInBytes);
srcDataPtr, dstDataPtr, inputBatchSlots, outputBatchSlots, copyRowSizeInBytes, inputBatchIdxOffset);
}

template <typename T>
@ -386,12 +424,41 @@ void invokeCopyProbs(ExtractExplicitDraftTokensParams<T> const& params, cudaStre
auto const numCopyElems = params.numPaths * (params.maxPathLength - 1) * params.vocabSize;
auto const copyRowSizeInBytes = numCopyElems * sizeof(T);

invokeCopyProbs(srcDataPtr, dstDataPtr, nullptr, params.batchSlots, params.batchSize, copyRowSizeInBytes, stream);
invokeCopyProbs(
srcDataPtr, dstDataPtr, nullptr, params.batchSlots, params.batchSize, 0, copyRowSizeInBytes, stream);
}

template void invokeCopyProbs(ExtractExplicitDraftTokensParams<float> const& params, cudaStream_t stream);
template void invokeCopyProbs(ExtractExplicitDraftTokensParams<half> const& params, cudaStream_t stream);

namespace
{
template <typename T>
__global__ void packGenerationLengths(PackExplicitDraftTokensParams<T> params)
{
auto const batchIdx = static_cast<SizeType32>(blockIdx.x);
auto const batchSlot = params.batchSlots ? params.batchSlots[batchIdx] : batchIdx;

auto const isGenerationRequest = batchIdx >= params.numContextRequests;
auto const genIdx = batchIdx - params.numContextRequests;

if (threadIdx.x == 0 && isGenerationRequest)
{
params.outputGenerationLengths[genIdx] = params.inputGenerationLengths[batchSlot];
}
}
} // namespace

template <typename T>
void invokePackGenerationLengths(PackExplicitDraftTokensParams<T> const& params, cudaStream_t stream)
{
SizeType32 constexpr BLOCK_SIZE = 32;
packGenerationLengths<<<params.batchSize, BLOCK_SIZE, 0, stream>>>(params);
}

template void invokePackGenerationLengths(PackExplicitDraftTokensParams<float> const& params, cudaStream_t stream);
template void invokePackGenerationLengths(PackExplicitDraftTokensParams<half> const& params, cudaStream_t stream);

namespace
{
template <typename T>
@ -400,61 +467,79 @@ __global__ void packExplicitDraftTokens(PackExplicitDraftTokensParams<T> params)
auto const batchIdx = static_cast<SizeType32>(blockIdx.x);
auto const batchSlot = params.batchSlots ? params.batchSlots[batchIdx] : batchIdx;

auto const isGenerationRequest = batchIdx >= params.numContextRequests;
auto const genIdx = batchIdx - params.numContextRequests;

if (threadIdx.x == 0)
{
params.outputPositionIdsBase[batchIdx] = params.inputPositionIdsBase[batchSlot];
params.outputGenerationLengths[batchIdx] = params.inputGenerationLengths[batchSlot];
params.outputRandomDataSample[batchIdx] = params.inputRandomDataSample[batchSlot];
params.outputTemperatures[batchIdx] = params.inputTemperatures[batchSlot];
}

// Copy random validation data.
auto const numDecodingDraftTokens = params.numPaths * (params.maxPathLength - 1);
auto outputRandomDataValidation = params.outputRandomDataValidation + batchIdx * numDecodingDraftTokens;
auto inputRandomDataValidation = params.inputRandomDataValidation + batchSlot * numDecodingDraftTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numDecodingDraftTokens;
ti += static_cast<SizeType32>(blockDim.x))
if (isGenerationRequest)
{
outputRandomDataValidation[ti] = inputRandomDataValidation[ti];
auto outputRandomDataValidation = params.outputRandomDataValidation + genIdx * numDecodingDraftTokens;
auto const inputRandomDataValidation = params.inputRandomDataValidation + batchSlot * numDecodingDraftTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numDecodingDraftTokens;
ti += static_cast<SizeType32>(blockDim.x))
{
outputRandomDataValidation[ti] = inputRandomDataValidation[ti];
}
}

// Copy draft tokens and indices
auto const numUnpackedTokens = numDecodingDraftTokens + params.numPaths;
auto outputNextDraftTokens = params.outputNextDraftTokens + batchIdx * numUnpackedTokens;
auto outputNextDraftIndices = params.outputNextDraftIndices + batchIdx * numUnpackedTokens;
auto const inputNextDraftTokens = params.inputNextDraftTokens + batchSlot * numUnpackedTokens;
auto const inputNextDraftIndices = params.inputNextDraftIndices + batchSlot * numUnpackedTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numUnpackedTokens;
ti += static_cast<SizeType32>(blockDim.x))
if (isGenerationRequest)
{
outputNextDraftTokens[ti] = inputNextDraftTokens[ti];
outputNextDraftIndices[ti] = inputNextDraftIndices[ti];
auto const numUnpackedTokens = numDecodingDraftTokens + params.numPaths;
auto outputNextDraftTokens = params.outputNextDraftTokens + genIdx * numUnpackedTokens;
auto outputNextDraftIndices = params.outputNextDraftIndices + genIdx * numUnpackedTokens;
auto const inputNextDraftTokens = params.inputNextDraftTokens + batchSlot * numUnpackedTokens;
auto const inputNextDraftIndices = params.inputNextDraftIndices + batchSlot * numUnpackedTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numUnpackedTokens;
ti += static_cast<SizeType32>(blockDim.x))
{
outputNextDraftTokens[ti] = inputNextDraftTokens[ti];
outputNextDraftIndices[ti] = inputNextDraftIndices[ti];
}
}

auto const maxGenerationLength = params.maxGenerationLength[0];
auto const maxDecodingTokens = numDecodingDraftTokens + 1;
auto const numPackedMasks = divUp(maxGenerationLength, 32);
auto const outputMaskStartId = (batchIdx == 0) ? 0 : params.cumSumGenerationLengths[batchIdx - 1];
auto const numTokens = (batchIdx == 0)
auto const outputMaskStartId = (genIdx == 0) ? 0 : params.cumSumGenerationLengths[genIdx - 1];
auto const numTokens = (genIdx == 0)
? params.cumSumGenerationLengths[0]
: params.cumSumGenerationLengths[batchIdx] - params.cumSumGenerationLengths[batchIdx - 1];
: params.cumSumGenerationLengths[genIdx] - params.cumSumGenerationLengths[genIdx - 1];
// Copy packed masks.
// Masks are placed next to each other with offsets of cumSumGenerationLengths[bi-1]
auto const inputPackedMask = params.inputPackedMask + batchSlot * numPackedMasks * maxDecodingTokens;
auto outputPackedMask = params.outputPackedMask + outputMaskStartId * numPackedMasks;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numTokens * numPackedMasks;
ti += static_cast<SizeType32>(blockDim.x))
if (isGenerationRequest)
{
outputPackedMask[ti] = inputPackedMask[ti];
auto const inputPackedMask = params.inputPackedMask + batchSlot * numPackedMasks * maxDecodingTokens;
auto outputPackedMask = params.outputPackedMask + outputMaskStartId * numPackedMasks;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < numTokens * numPackedMasks;
ti += static_cast<SizeType32>(blockDim.x))
{
outputPackedMask[ti] = inputPackedMask[ti];
}
}

// Copy pos offsets. Copy only for maxGenerationLength
auto outputPositionOffsets = params.outputPositionOffsets + batchIdx * maxGenerationLength;
auto const inputPositionOffsets = params.inputPositionOffsets + batchSlot * maxDecodingTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < maxGenerationLength;
ti += static_cast<SizeType32>(blockDim.x))
if (isGenerationRequest)
{
outputPositionOffsets[ti] = inputPositionOffsets[ti];
auto const basePosId = params.outputPositionIdsBase[batchIdx];
auto outputPositionOffsets = params.outputPositionOffsets + genIdx * maxGenerationLength;
auto outputPositionIds = params.outputPositionIds + genIdx * maxGenerationLength;
auto const inputPositionIds = params.inputPositionIds + batchSlot * maxDecodingTokens;
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < maxGenerationLength;
ti += static_cast<SizeType32>(blockDim.x))
{
auto const posId = inputPositionIds[ti];
outputPositionIds[params.numContextTokens + ti] = posId;
outputPositionOffsets[ti] = posId - basePosId + 1;
}
}
}
} // namespace
@ -477,7 +562,8 @@ void invokeCopyProbs(PackExplicitDraftTokensParams<T> const& params, cudaStream_
auto const numCopyElems = params.numPaths * (params.maxPathLength - 1) * params.vocabSize;
auto const copyRowSizeInBytes = numCopyElems * sizeof(T);

invokeCopyProbs(srcDataPtr, dstDataPtr, params.batchSlots, nullptr, params.batchSize, copyRowSizeInBytes, stream);
invokeCopyProbs(srcDataPtr, dstDataPtr, params.batchSlots, nullptr, params.numGenerationRequests,
params.numContextRequests, copyRowSizeInBytes, stream);
}

template void invokeCopyProbs(PackExplicitDraftTokensParams<float> const& params, cudaStream_t stream);

@ -16,6 +16,7 @@

#pragma once

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/kernels/speculativeDecoding/common.h"
#include "tensorrt_llm/runtime/common.h"
#include <cuda_fp16.h>
@ -25,6 +26,38 @@
namespace tensorrt_llm::kernels::speculative_decoding
{

template <typename T>
struct FillContextExplicitDraftTokensParams
{
//! [maxBatchSize]
T* randDataSample{nullptr};
//! [maxBatchSize]
T* outputTemperatures{nullptr};
//! [maxBatchSize]
float const* inputTemperatures{nullptr};
//! [maxBatchSize]
curandState_t* curandState{nullptr};
//! [forwardBatchSize]
runtime::SizeType32 const* batchSlots{nullptr};

runtime::SizeType32 batchSize{0};

void checkParams() const
{
TLLM_CHECK(randDataSample);
TLLM_CHECK(outputTemperatures);
TLLM_CHECK(inputTemperatures);
TLLM_CHECK(curandState);
TLLM_CHECK(batchSlots);

TLLM_CHECK(batchSize > 0);
}
};

//! @brief Sets temperature and generates random variable for sampling.
template <typename T>
void invokeFillContextBuffers(FillContextExplicitDraftTokensParams<T> const& params, cudaStream_t stream);

template <typename T>
struct ExtractExplicitDraftTokensParams
{
@ -43,10 +76,18 @@ struct ExtractExplicitDraftTokensParams
//! [maxBatchSize]
runtime::SizeType32* acceptedLengths{nullptr};
//! [maxBatchSize]
runtime::SizeType32* prevDraftLengths{nullptr};
//! [maxBatchSize]
runtime::SizeType32* nextDraftLengths{nullptr};
//! [maxBatchSize]
runtime::SizeType32* sequenceLengths{nullptr};
//! [maxBatchSize]
runtime::SizeType32* outputGenerationLengths{nullptr};
//! [maxBatchSize]
runtime::SizeType32* outputBestPathIndices{nullptr};
//! [maxBatchSize, maxNumPaths, maxPathLength]
runtime::SizeType32* outputLastDraftIndices{nullptr};
//! [maxBatchSize]
T* randDataSample{nullptr};
//! [maxBatchSize, maxNumPaths, maxPathDraftLength]
T* randDataVerification{nullptr};
@ -58,7 +99,7 @@ struct ExtractExplicitDraftTokensParams
runtime::SizeType32 const* batchSlots{nullptr};
//! [forwardBatchSize, maxNumPaths, maxPathLength]
runtime::TokenIdType const* nextDraftTokens{nullptr};
//! [forwardBatchSize, maxNumPaths, maxPathLength]
//! [forwardBatchSize, maxNumPaths, maxPathLength], optional
runtime::TokenIdType const* lastDraftTokens{nullptr};
//! [forwardBatchSize, maxNumPaths, maxPathLength]
runtime::SizeType32 const* inputUnpackedNextDraftIndices{nullptr};
@ -74,17 +115,75 @@ struct ExtractExplicitDraftTokensParams
runtime::TokenIdType const* nextFlatTokens{nullptr};
//! [forwardBatchSize]
runtime::SizeType32 const* generationLengthInclusiveSum{nullptr};
//! [forwardBatchSize]
runtime::SizeType32 const* lastGenerationLengths{nullptr};
//! [maxBatchSize, maxNumPaths, maxPathLength]
runtime::SizeType32 const* lastDraftIndices{nullptr};
//! [forwardBatchSize, maxNumPaths, maxPathDraftLength, maxVocabSize]
T const* nextDraftProbs{nullptr};
//! [maxBatchSize]
float const* inputTemperatures{nullptr};
//! [maxBatchSize]
curandState_t* curandState{nullptr};
runtime::SizeType32 batchSize;
runtime::SizeType32 numPaths;
runtime::SizeType32 maxPathLength;
runtime::SizeType32 maxSeqLen;
runtime::SizeType32 vocabSize;
runtime::SizeType32 batchSize{0};
runtime::SizeType32 numPaths{0};
runtime::SizeType32 maxPathLength{0};
runtime::SizeType32 maxSeqLen{0};
runtime::SizeType32 vocabSize{0};
runtime::SizeType32 numContextRequests{0};
runtime::SizeType32 numGenerationRequests{0};

void checkParams() const
|
||||
{
|
||||
TLLM_CHECK(outputIds);
|
||||
|
||||
TLLM_CHECK(outputPositionIdsBase);
|
||||
TLLM_CHECK(inputPositionIdsBase);
|
||||
|
||||
TLLM_CHECK(outputPositionIds);
|
||||
TLLM_CHECK(packedPositionIds);
|
||||
|
||||
TLLM_CHECK(outputTemperatures);
|
||||
TLLM_CHECK(inputTemperatures);
|
||||
|
||||
TLLM_CHECK(outputDraftProbs);
|
||||
TLLM_CHECK(nextDraftProbs);
|
||||
|
||||
TLLM_CHECK(outputNextDraftTokens);
|
||||
TLLM_CHECK(unpackedNextDraftTokens);
|
||||
|
||||
TLLM_CHECK(unpackedNextDraftIndices);
|
||||
TLLM_CHECK(inputUnpackedNextDraftIndices);
|
||||
|
||||
TLLM_CHECK(outputLastDraftIndices);
|
||||
|
||||
TLLM_CHECK(bestPathIndices);
|
||||
TLLM_CHECK(outputBestPathIndices);
|
||||
|
||||
TLLM_CHECK(curandState);
|
||||
TLLM_CHECK(batchSlots);
|
||||
TLLM_CHECK(nextDraftTokens);
|
||||
TLLM_CHECK(nextFlatTokens);
|
||||
TLLM_CHECK(generationLengthInclusiveSum);
|
||||
TLLM_CHECK(bestPathLengths);
|
||||
|
||||
TLLM_CHECK(randDataSample);
|
||||
TLLM_CHECK(randDataVerification);
|
||||
TLLM_CHECK(acceptedLengths);
|
||||
TLLM_CHECK(nextDraftLengths);
|
||||
TLLM_CHECK(prevDraftLengths);
|
||||
TLLM_CHECK(sequenceLengths);
|
||||
TLLM_CHECK(outputGenerationLengths);
|
||||
|
||||
TLLM_CHECK(batchSize > 0);
|
||||
TLLM_CHECK(numPaths > 0);
|
||||
TLLM_CHECK(maxPathLength > 0);
|
||||
TLLM_CHECK(maxSeqLen > 0);
|
||||
TLLM_CHECK(vocabSize > 0);
|
||||
TLLM_CHECK(numContextRequests >= 0);
|
||||
TLLM_CHECK(numGenerationRequests >= 0);
|
||||
TLLM_CHECK(numContextRequests + numGenerationRequests != 0);
|
||||
}
|
||||
};
|
||||
|
||||
//! @brief Modifies `outputIds` and `sequenceLengths` according to the accepted tokens
|
||||
@ -146,10 +245,12 @@ struct PackExplicitDraftTokensParams
|
||||
//! [forwardBatchSize, maxGenerationLength, divUp(maxGenerationLength, 32)]
|
||||
int32_t const* inputPackedMask{nullptr};
|
||||
|
||||
//! [forwardBatchSize, maxGenerationLength]
|
||||
runtime::SizeType32* outputPositionIds{nullptr};
|
||||
//! [forwardBatchSize, maxGenerationLength]
|
||||
runtime::SizeType32* outputPositionOffsets{nullptr};
|
||||
//! [maxBatchSize, maxGenerationLength]
|
||||
runtime::SizeType32 const* inputPositionOffsets{nullptr};
|
||||
runtime::SizeType32 const* inputPositionIds{nullptr};
|
||||
|
||||
//! [forwardBatchSize, maxNumPaths, maxPathDraftLength, maxVocabSize]
|
||||
T* outputDraftProbs{nullptr};
|
||||
@ -161,12 +262,59 @@ struct PackExplicitDraftTokensParams
|
||||
//! [maxBatchSize]
|
||||
T const* inputTemperatures{nullptr};
|
||||
|
||||
runtime::SizeType32 batchSize;
|
||||
runtime::SizeType32 numPaths;
|
||||
runtime::SizeType32 maxPathLength;
|
||||
runtime::SizeType32 vocabSize;
|
||||
runtime::SizeType32 batchSize{0};
|
||||
runtime::SizeType32 numPaths{0};
|
||||
runtime::SizeType32 maxPathLength{0};
|
||||
runtime::SizeType32 vocabSize{0};
|
||||
runtime::SizeType32 numContextTokens{0};
|
||||
runtime::SizeType32 numContextRequests{0};
|
||||
runtime::SizeType32 numGenerationRequests{0};
|
||||
|
||||
void checkParams() const
|
||||
{
|
||||
TLLM_CHECK(batchSlots);
|
||||
TLLM_CHECK(cumSumGenerationLengths);
|
||||
TLLM_CHECK(maxGenerationLength);
|
||||
|
||||
TLLM_CHECK(inputPositionIdsBase);
|
||||
|
||||
TLLM_CHECK(inputGenerationLengths);
|
||||
|
||||
TLLM_CHECK(outputRandomDataSample);
|
||||
TLLM_CHECK(inputRandomDataSample);
|
||||
|
||||
TLLM_CHECK(inputRandomDataValidation);
|
||||
|
||||
TLLM_CHECK(inputNextDraftTokens);
|
||||
|
||||
TLLM_CHECK(inputNextDraftIndices);
|
||||
|
||||
TLLM_CHECK(inputPackedMask);
|
||||
|
||||
TLLM_CHECK(inputPositionIds);
|
||||
|
||||
TLLM_CHECK(inputDraftProbs);
|
||||
|
||||
TLLM_CHECK(outputTemperatures);
|
||||
TLLM_CHECK(inputTemperatures);
|
||||
|
||||
TLLM_CHECK(batchSize > 0);
|
||||
TLLM_CHECK(numPaths > 0);
|
||||
TLLM_CHECK(maxPathLength > 0);
|
||||
TLLM_CHECK(vocabSize > 0);
|
||||
TLLM_CHECK(numContextRequests >= 0);
|
||||
TLLM_CHECK(numGenerationRequests >= 0);
|
||||
TLLM_CHECK(
|
||||
(numContextTokens == 0 && numContextRequests == 0) || (numContextTokens > 0 && numContextRequests > 0));
|
||||
TLLM_CHECK(numContextRequests + numGenerationRequests != 0);
|
||||
}
|
||||
};
|
||||
|
||||
//! @brief Copy all rows at `batchSlots[batchIdx]` from `inputGenerationLengths` tensors to `batchIdx` rows at
|
||||
//! `outputGenerationLengths` tensor.
|
||||
template <typename T>
|
||||
void invokePackGenerationLengths(PackExplicitDraftTokensParams<T> const& params, cudaStream_t stream);
|
||||
|
||||
//! @brief Copy all rows at `batchSlots[batchIdx]` from `input*` tensors to `batchIdx` rows at `output*` tensor.
|
||||
template <typename T>
|
||||
void invokePackExplicitDraftTokens(PackExplicitDraftTokensParams<T> const& params, cudaStream_t stream);
|
||||
@ -175,27 +323,24 @@ void invokePackExplicitDraftTokens(PackExplicitDraftTokensParams<T> const& param
|
||||
template <typename T>
|
||||
void invokeCopyProbs(PackExplicitDraftTokensParams<T> const& params, cudaStream_t stream);
|
||||
|
||||
size_t invokeScanSpecDecodingGenerationLengths(void* __restrict__ reduceMaxTempStorage, size_t reduceTempStorageBytes,
|
||||
runtime::SizeType32 const* __restrict__ specDecodingGenerationLengths,
|
||||
runtime::SizeType32* __restrict__ scannedSpecDecodingGenerationLengths, runtime::SizeType32 batchSize,
|
||||
cudaStream_t stream);
|
||||
size_t invokeReduceMaxSpecDecodingGenerationLengths(void* __restrict__ reduceMaxTempStorage,
|
||||
size_t reduceTempStorageBytes, runtime::SizeType32 const* __restrict__ specDecodingGenerationLengths,
|
||||
runtime::SizeType32* __restrict__ scannedSpecDecodingGenerationLengths, runtime::SizeType32 batchSize,
|
||||
cudaStream_t stream);
|
||||
size_t invokeScanGenerationLengths(void* __restrict__ scanTempStorage, size_t scanTempStorageBytes,
|
||||
runtime::SizeType32 const* __restrict__ generationLengths,
|
||||
runtime::SizeType32* __restrict__ scannedGenerationLengths, runtime::SizeType32 batchSize, cudaStream_t stream);
|
||||
size_t invokeReduceMaxGenerationLengths(void* __restrict__ reduceMaxTempStorage, size_t reduceTempStorageBytes,
|
||||
runtime::SizeType32 const* __restrict__ generationLengths, runtime::SizeType32* __restrict__ maxGenerationLengths,
|
||||
runtime::SizeType32 batchSize, cudaStream_t stream);
|
||||
|
||||
// inclusive prefix sum specDecodingGenerationLengths
|
||||
void invokeScanReduceSpecDecodingGenerationLengths(runtime::SizeType32 batchSize,
|
||||
runtime::SizeType32 const* __restrict__ specDecodingGenerationLengths, void* __restrict__ scanTempStorage,
|
||||
size_t scanTempStorageBytes, runtime::SizeType32* __restrict__ scanedSpecDecodingGenerationLengths,
|
||||
// inclusive prefix sum generationLengths
|
||||
void invokeScanReduceGenerationLengths(runtime::SizeType32 batchSize,
|
||||
runtime::SizeType32 const* __restrict__ generationLengths, void* __restrict__ scanTempStorage,
|
||||
size_t scanTempStorageBytes, runtime::SizeType32* __restrict__ scanedGenerationLengths,
|
||||
void* __restrict__ reduceMaxTempStorage, size_t reduceMaxTempStorageBytes,
|
||||
runtime::SizeType32* maxSpecDecodingGenerationLengths, cudaStream_t stream);
|
||||
runtime::SizeType32* maxGenerationLengths, cudaStream_t stream);
|
||||
|
||||
void invokeConvertSpecDecodingMaskToPackedMask(runtime::SizeType32 batchSize,
|
||||
runtime::SizeType32 const* __restrict__ specDecodingCumGenerationLengths,
|
||||
runtime::SizeType32 const* __restrict__ specDecodingMaxGenerationLengths, bool const* __restrict__ specDecodingMask,
|
||||
void invokeConvertMaskToPackedMask(runtime::SizeType32 batchSize,
|
||||
runtime::SizeType32 const* __restrict__ cumGenerationLengths,
|
||||
runtime::SizeType32 const* __restrict__ maxGenerationLengths, bool const* __restrict__ mask,
|
||||
runtime::SizeType32 const* __restrict__ batchSlots, runtime::SizeType32 maxDraftTokens,
|
||||
runtime::SizeType32 maxGenerationLength, runtime::SizeType32* __restrict__ specDecodingPackedMask,
|
||||
cudaStream_t stream);
|
||||
runtime::SizeType32 maxGenerationLength, runtime::SizeType32* __restrict__ packedMask, cudaStream_t stream);
|
||||
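The two operations exposed by the declarations above, an inclusive prefix sum over per-request generation lengths and the packing of a boolean mask into 32-bit words, are easiest to see in a plain host-side analogue. This is a sketch only: the bit ordering and flat layout below are assumptions for illustration and are not claimed to match the CUDA kernels.

// Illustrative host-side analogue, not the CUDA implementation.
#include <cstdint>
#include <vector>

std::vector<int32_t> inclusiveScan(std::vector<int32_t> const& lengths)
{
    std::vector<int32_t> scanned(lengths.size());
    int32_t sum = 0;
    for (size_t i = 0; i < lengths.size(); ++i)
    {
        sum += lengths[i];
        scanned[i] = sum; // inclusive: element i includes lengths[i]
    }
    return scanned;
}

std::vector<uint32_t> packMask(std::vector<bool> const& mask)
{
    // Each group of 32 booleans becomes one packed word; bit i of word w is mask[32*w + i].
    std::vector<uint32_t> packed((mask.size() + 31) / 32, 0u);
    for (size_t i = 0; i < mask.size(); ++i)
    {
        if (mask[i])
        {
            packed[i / 32] |= (1u << (i % 32));
        }
    }
    return packed;
}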

} // namespace tensorrt_llm::kernels::speculative_decoding

@ -25,99 +25,107 @@
namespace tensorrt_llm::kernels::speculative_decoding
{

static constexpr int kUpdateKVCacheKernelShmSize = 16384;
using namespace tensorrt_llm::runtime;
static constexpr SizeType32 kUpdateKVCacheKernelShmSize = 16384;

namespace
{
template <typename KVCacheBuffer, int MaxLayerCount, typename MoveEltType>
template <typename KVCacheBuffer, SizeType32 MaxLayerCount, typename MoveEltType>
__global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheBuffer, MaxLayerCount> kvCacheBuffers,
int const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices,
int32_t const* pastKeyValueLengths, int rewindDraftTokenCommonCount, int const* rewindDraftTokenSeparateAdjustments,
int const* seqSlotRemapping, int eltCountPerHead)
SizeType32 const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices,
SizeType32 const* pastKeyValueLengths, SizeType32 rewindDraftTokenCommonCount,
SizeType32 const* rewindDraftTokenSeparateAdjustments, SizeType32 const* seqSlotRemapping,
SizeType32 const* batchSlots, SizeType32 eltCountPerHead)
{
int seqIdx = blockIdx.x;
int headIdx = blockIdx.y;
int layerIdx = blockIdx.z;
int warpIdx = threadIdx.x / 32;
int warpCount = blockDim.x / 32;
int laneIdx = threadIdx.x & 0x1f;
int seqDraftTokenStart = seqAcceptedDraftTokenOffsets[seqIdx];
int seqDraftTokenEnd = seqAcceptedDraftTokenOffsets[seqIdx + 1];
auto const seqIdx = static_cast<SizeType32>(blockIdx.x);
auto const headIdx = static_cast<SizeType32>(blockIdx.y);
auto const layerIdx = static_cast<SizeType32>(blockIdx.z);
auto const warpIdx = static_cast<SizeType32>(threadIdx.x / 32);
auto const warpCount = static_cast<SizeType32>(blockDim.x / 32);
auto const laneIdx = static_cast<SizeType32>(threadIdx.x & 0x1f);
auto const seqDraftTokenStart = seqAcceptedDraftTokenOffsets[seqIdx];
auto const seqDraftTokenEnd = seqAcceptedDraftTokenOffsets[seqIdx + 1];
auto const seqSlot = seqSlotRemapping == nullptr ? seqIdx : seqSlotRemapping[seqIdx];
int seqDraftCount = seqDraftTokenEnd - seqDraftTokenStart;
int maxEltCountPerMove = kUpdateKVCacheKernelShmSize / sizeof(MoveEltType) / seqDraftCount;
int eltCountPerMove = min(maxEltCountPerMove, eltCountPerHead);
auto const seqDraftCount = seqDraftTokenEnd - seqDraftTokenStart;
auto const maxEltCountPerMove
= static_cast<SizeType32>(kUpdateKVCacheKernelShmSize / sizeof(MoveEltType) / seqDraftCount);
auto const eltCountPerMove = min(maxEltCountPerMove, eltCountPerHead);
if (seqDraftCount == 0 || eltCountPerMove == 0)
{
return;
}
KVCacheBuffer& kvCacheBuffer = kvCacheBuffers[layerIdx];
int tokenStartIdx = pastKeyValueLengths[seqSlot] - rewindDraftTokenCommonCount;
auto tokenStartIdx = pastKeyValueLengths[seqSlot] - rewindDraftTokenCommonCount;
if (rewindDraftTokenSeparateAdjustments != nullptr)
{
tokenStartIdx -= rewindDraftTokenSeparateAdjustments[seqSlot];
auto const batchSlot = batchSlots == nullptr ? seqIdx : batchSlots[seqIdx];
tokenStartIdx -= rewindDraftTokenSeparateAdjustments[batchSlot];
}
__shared__ char loadSmemBuffer[kUpdateKVCacheKernelShmSize];
auto* eltLoadSmemBuffer = reinterpret_cast<MoveEltType*>(&loadSmemBuffer[0]);
for (int startChannelOffset = 0; startChannelOffset < eltCountPerHead; startChannelOffset += eltCountPerMove)
for (SizeType32 startChannelOffset = 0; startChannelOffset < eltCountPerHead; startChannelOffset += eltCountPerMove)
{
int eltCountCurrentMove = min(eltCountPerMove, eltCountPerHead - startChannelOffset);
SizeType32 eltCountCurrentMove = min(eltCountPerMove, eltCountPerHead - startChannelOffset);
// load K
for (int tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
for (SizeType32 tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
{
int tokenPos = packedAcceptedDraftTokensIndices[seqDraftTokenStart + tokenIdx];
auto const tokenPos = packedAcceptedDraftTokensIndices[seqDraftTokenStart + tokenIdx];
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenIdx * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto const tokenKVPosition = tokenStartIdx + tokenPos;
auto* kPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getKBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
for (SizeType32 loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
int kvLocationIdx = kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
auto const channelIdx = loadChannelIdx + startChannelOffset;
auto const kvLocationIdx
= kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
tokenSmemBuffer[loadChannelIdx] = kPtr[kvLocationIdx];
}
}
__syncthreads();
// store K
for (int tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
for (SizeType32 tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
{
int tokenPos = tokenIdx;
auto const tokenPos = tokenIdx;
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenIdx * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto const tokenKVPosition = tokenStartIdx + tokenPos;
auto* kPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getKBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
for (SizeType32 loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
int kvLocationIdx = kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
auto const channelIdx = loadChannelIdx + startChannelOffset;
auto const kvLocationIdx
= kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
kPtr[kvLocationIdx] = tokenSmemBuffer[loadChannelIdx];
}
}
__syncthreads();
// load V
for (int tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
for (SizeType32 tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
{
int tokenPos = packedAcceptedDraftTokensIndices[seqDraftTokenStart + tokenIdx];
auto const tokenPos = packedAcceptedDraftTokensIndices[seqDraftTokenStart + tokenIdx];
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenIdx * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto const tokenKVPosition = tokenStartIdx + tokenPos;
auto* vPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getVBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
for (SizeType32 loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
int kvLocationIdx = kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
auto const channelIdx = loadChannelIdx + startChannelOffset;
auto const kvLocationIdx
= kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
tokenSmemBuffer[loadChannelIdx] = vPtr[kvLocationIdx];
}
}
__syncthreads();
// store V
for (int tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
for (SizeType32 tokenIdx = warpIdx; tokenIdx < seqDraftCount; tokenIdx += warpCount)
{
int tokenPos = tokenIdx;
auto const tokenPos = tokenIdx;
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenPos * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto const tokenKVPosition = tokenStartIdx + tokenPos;
auto* vPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getVBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
for (SizeType32 loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
int kvLocationIdx = kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
auto const channelIdx = loadChannelIdx + startChannelOffset;
auto const kvLocationIdx
= kvCacheBuffer.getKVLocalIdx(tokenKVPosition, headIdx, eltCountPerHead, channelIdx);
vPtr[kvLocationIdx] = tokenSmemBuffer[loadChannelIdx];
}
}
@ -126,12 +134,13 @@ __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheB
}
} // namespace

template <typename KVCacheBuffer, int MaxLayerCount>
template <typename KVCacheBuffer, SizeType32 MaxLayerCount>
void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
int const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices,
int32_t const* pastKeyValueLengths, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping,
cudaStream_t stream)
SizeType32 const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices,
SizeType32 const* pastKeyValueLengths, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32 rewindDraftTokenCommonCount,
SizeType32 const* rewindDraftTokenSeparateAdjustments, SizeType32 const* seqSlotRemapping,
SizeType32 const* batchSlots, cudaStream_t stream)
{
// make sure launch buffer is enough
static_assert(MaxLayerCount * sizeof(KVCacheBuffer) <= 3072);
@ -139,22 +148,22 @@ void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
{
return;
}
int alignedBytes = 16;
SizeType32 alignedBytes = 16;
while (alignedBytes > 0 && (sizeInBytesPerKVHead % alignedBytes != 0))
{
alignedBytes >>= 1;
}
TLLM_CHECK_WITH_INFO(alignedBytes > 0, "alignedByte should be positive");
int eltCountPerHead = sizeInBytesPerKVHead / alignedBytes;
SizeType32 eltCountPerHead = sizeInBytesPerKVHead / alignedBytes;
dim3 grid(seqCount, numKVHeads, layerCount);
dim3 block(128, 1, 1);
std::array<KVCacheBuffer, MaxLayerCount> kvCacheBufferArray;
for (int i = 0; i < layerCount; i++)
for (SizeType32 i = 0; i < layerCount; i++)
{
kvCacheBufferArray[i] = kvCacheBuffers[i];
}
void (*pKernelFunc)(std::array<KVCacheBuffer, MaxLayerCount>, int const*, IndexType const*, int32_t const*, int,
int const*, int const*, int)
void (*pKernelFunc)(std::array<KVCacheBuffer, MaxLayerCount>, SizeType32 const*, IndexType const*,
SizeType32 const*, SizeType32, SizeType32 const*, SizeType32 const*, SizeType32 const*, SizeType32)
= nullptr;
switch (alignedBytes)
{
@ -170,7 +179,7 @@ void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
}
case 4:
{
pKernelFunc = &updateKVCacheDraftTokenLocationBatchedKernel<KVCacheBuffer, MaxLayerCount, int32_t>;
pKernelFunc = &updateKVCacheDraftTokenLocationBatchedKernel<KVCacheBuffer, MaxLayerCount, SizeType32>;
break;
}
case 2:
@ -187,7 +196,7 @@ void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
}
pKernelFunc<<<grid, block, 0, stream>>>(kvCacheBufferArray, seqAcceptedDraftTokenOffsets,
packedAcceptedDraftTokensIndices, pastKeyValueLengths, rewindDraftTokenCommonCount,
rewindDraftTokenSeparateAdjustments, seqSlotRemapping, eltCountPerHead);
rewindDraftTokenSeparateAdjustments, seqSlotRemapping, batchSlots, eltCountPerHead);
TLLM_CUDA_CHECK(cudaGetLastError());
}

@ -209,54 +218,59 @@ void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
* @param stream : CUDA stream to use.
*/
template <typename KVCacheBuffer>
void updateKVCacheDraftTokenLocation(KVCacheBuffer const* kvCacheBuffers, int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int layerCount, int seqCount,
int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments,
int const* seqSlotRemapping, cudaStream_t stream)
void updateKVCacheDraftTokenLocation(KVCacheBuffer const* kvCacheBuffers,
SizeType32 const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices,
SizeType32 const* pastKeyValueLengths, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32 rewindDraftTokenCommonCount,
SizeType32 const* rewindDraftTokenSeparateAdjustments, SizeType32 const* seqSlotRemapping,
SizeType32 const* batchSlots, cudaStream_t stream)
{
int startLayer = 0;
static constexpr int kMaxLayersPerIter = 32;
SizeType32 startLayer = 0;
static constexpr SizeType32 kMaxLayersPerIter = 32;
while (startLayer < layerCount)
{
int microBatchLayerCount = std::min(layerCount - startLayer, kMaxLayersPerIter);
SizeType32 microBatchLayerCount = std::min(layerCount - startLayer, kMaxLayersPerIter);
updateKVCacheDraftTokenLocationBatched<KVCacheBuffer, kMaxLayersPerIter>(kvCacheBuffers + startLayer,
seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices, pastKeyValueLengths, microBatchLayerCount,
seqCount, numKVHeads, sizeInBytesPerKVHead, rewindDraftTokenCommonCount,
rewindDraftTokenSeparateAdjustments, seqSlotRemapping, stream);
rewindDraftTokenSeparateAdjustments, seqSlotRemapping, batchSlots, stream);
startLayer += microBatchLayerCount;
}
}

void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping,
int maxKVCacheLen, cudaStream_t stream)
void updateLinearKVCacheDraftTokenLocation(SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, SizeType32 const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32 rewindDraftTokenCommonCount,
SizeType32 const* rewindDraftTokenSeparateAdjustments, SizeType32 const* seqSlotRemapping, SizeType32 maxKVCacheLen,
cudaStream_t stream)
{
std::vector<KVLinearBuffer> kvLinearBuffers;
kvLinearBuffers.reserve(layerCount);
auto const sizePerToken = numKVHeads * sizeInBytesPerKVHead;
for (int i = 0; i < layerCount; i++)
for (SizeType32 i = 0; i < layerCount; i++)
{
kvLinearBuffers.emplace_back(
seqCount, maxKVCacheLen, sizePerToken, maxKVCacheLen, 0, false, pastKeyValueList[i]);
}
updateKVCacheDraftTokenLocation(kvLinearBuffers.data(), seqAcceptedDraftTokenOffsets,
packedAcceptedDraftTokensIndices, pastKeyValueLengths, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, seqSlotRemapping, stream);
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, seqSlotRemapping, nullptr, stream);
}

void updateKVBlockArrayDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
void updateKVBlockArrayDraftTokenLocation(SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, SizeType32 const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32 rewindDraftTokenCommonCount,
SizeType32 const* rewindDraftTokenSeparateAdjustments, SizeType32 const* seqSlotRemapping,
SizeType32 const* batchSlots, SizeType32 maxKVCacheLen, SizeType32 maxBlocksPerSeq, SizeType32 tokensPerBlock,
cudaStream_t stream)
{
std::vector<KVBlockArray> kvBlockArrays;
kvBlockArrays.reserve(layerCount);
auto const bytesPerToken = numKVHeads * sizeInBytesPerKVHead;
auto const bytesPerBlock = tokensPerBlock * bytesPerToken;
for (int layerIdx = 0; layerIdx < layerCount; layerIdx++)
for (SizeType32 layerIdx = 0; layerIdx < layerCount; layerIdx++)
{
auto const layerOffset = layerIdx * 2 * bytesPerBlock;
auto* const primaryPoolPointer
@ -269,49 +283,52 @@ void updateKVBlockArrayDraftTokenLocation(int const* seqAcceptedDraftTokenOffset
}
updateKVCacheDraftTokenLocation(kvBlockArrays.data(), seqAcceptedDraftTokenOffsets,
packedAcceptedDraftTokensIndices, pastKeyValueLengths, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, seqSlotRemapping, stream);
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, seqSlotRemapping, batchSlots, stream);
}

void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCount, int const* seqSlotRemapping, int maxKVCacheLen, cudaStream_t stream)
void updateLinearKVCacheDraftTokenLocationCommonRewind(SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, SizeType32 const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32 rewindDraftTokenCount, SizeType32 const* seqSlotRemapping,
SizeType32 maxKVCacheLen, cudaStream_t stream)
{
updateLinearKVCacheDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pastKeyValueList, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCount, nullptr, seqSlotRemapping, maxKVCacheLen, stream);
}

void updateKVBlockArrayDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCount, int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock,
cudaStream_t stream)
void updateKVBlockArrayDraftTokenLocationCommonRewind(SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, SizeType32 const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32 rewindDraftTokenCount, SizeType32 const* seqSlotRemapping,
SizeType32 maxKVCacheLen, SizeType32 maxBlocksPerSeq, SizeType32 tokensPerBlock, cudaStream_t stream)
{
updateKVBlockArrayDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pointerArray, offsetArray, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCount, nullptr, seqSlotRemapping, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
rewindDraftTokenCount, nullptr, seqSlotRemapping, nullptr, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock,
stream);
}

void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int* rewindDraftTokenCounts, int const* seqSlotRemapping, int maxKVCacheLen, cudaStream_t stream)
void updateLinearKVCacheDraftTokenLocationSeparateRewind(SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, SizeType32 const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32* rewindDraftTokenCounts, SizeType32 const* seqSlotRemapping,
SizeType32 maxKVCacheLen, cudaStream_t stream)
{
updateLinearKVCacheDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pastKeyValueList, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead, 0,
rewindDraftTokenCounts, seqSlotRemapping, maxKVCacheLen, stream);
}

void updateKVBlockArrayDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int* rewindDraftTokenCounts, int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq,
int tokensPerBlock, cudaStream_t stream)
void updateKVBlockArrayDraftTokenLocationSeparateRewind(SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, SizeType32 const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, SizeType32 layerCount, SizeType32 seqCount, SizeType32 numKVHeads,
SizeType32 sizeInBytesPerKVHead, SizeType32 const* rewindDraftTokenCounts, SizeType32 const* seqSlotRemapping,
SizeType32 maxKVCacheLen, SizeType32 maxBlocksPerSeq, SizeType32 tokensPerBlock, cudaStream_t stream)
{
updateKVBlockArrayDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pointerArray, offsetArray, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead, 0,
rewindDraftTokenCounts, seqSlotRemapping, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
rewindDraftTokenCounts, seqSlotRemapping, nullptr, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
}

} // namespace tensorrt_llm::kernels::speculative_decoding

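The alignment selection inside updateKVCacheDraftTokenLocationBatched starts from 16-byte moves and halves until the per-head byte size divides evenly, which determines how wide each shared-memory copy can be. A small host-side sketch of that arithmetic follows; the function name and sample values are hypothetical and only mirror the loop shown above.

// Illustrative sketch of the power-of-two alignment selection.
#include <cassert>

int largestAlignedMove(int sizeInBytesPerKVHead)
{
    int alignedBytes = 16;
    while (alignedBytes > 0 && (sizeInBytesPerKVHead % alignedBytes != 0))
    {
        alignedBytes >>= 1; // halve until it divides the per-head size
    }
    return alignedBytes;
}

int main()
{
    assert(largestAlignedMove(256) == 16); // e.g. 128 half-precision channels
    assert(largestAlignedMove(96) == 16);  // 96 = 6 * 16
    assert(largestAlignedMove(20) == 4);   // falls back to 4-byte moves
    return 0;
}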
@ -17,6 +17,7 @@
#pragma once

#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/runtime/common.h"
#include <cstdint>
#include <cuda_runtime_api.h>

@ -44,11 +45,11 @@ using IndexType = int;
* @param maxKVCacheLen : Maximum length of each KV cache
* @param stream : CUDA stream to use.
*/
void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
KVLinearBuffer::DataType* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads,
int sizeInBytesPerKVHead, int rewindDraftTokenCount, int const* seqSlotRemapping, int maxKVCacheLen,
cudaStream_t stream);
void updateLinearKVCacheDraftTokenLocationCommonRewind(runtime::SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, runtime::SizeType32 const* pastKeyValueLengths,
KVLinearBuffer::DataType* const* pastKeyValueList, runtime::SizeType32 layerCount, runtime::SizeType32 seqCount,
runtime::SizeType32 numKVHeads, runtime::SizeType32 sizeInBytesPerKVHead, runtime::SizeType32 rewindDraftTokenCount,
runtime::SizeType32 const* seqSlotRemapping, runtime::SizeType32 maxKVCacheLen, cudaStream_t stream);

/*!
* Update Block KV cache using common rewind count.
@ -72,10 +73,12 @@ void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDra
* @param tokensPerBlock : Tokens per block of Block KV cache
* @param stream : CUDA stream to use.
*/
void updateKVBlockArrayDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCount, int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock,
void updateKVBlockArrayDraftTokenLocationCommonRewind(runtime::SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, runtime::SizeType32 const* pastKeyValueLengths,
void* const* pointerArray, KVBlockArray::DataType* offsetArray, runtime::SizeType32 layerCount,
runtime::SizeType32 seqCount, runtime::SizeType32 numKVHeads, runtime::SizeType32 sizeInBytesPerKVHead,
runtime::SizeType32 rewindDraftTokenCount, runtime::SizeType32 const* seqSlotRemapping,
runtime::SizeType32 maxKVCacheLen, runtime::SizeType32 maxBlocksPerSeq, runtime::SizeType32 tokensPerBlock,
cudaStream_t stream);

/*!
@ -98,11 +101,12 @@ void updateKVBlockArrayDraftTokenLocationCommonRewind(int const* seqAcceptedDraf
* @param maxKVCacheLen : Maximum length of each KV cache
* @param stream : CUDA stream to use.
*/
void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
KVLinearBuffer::DataType* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads,
int sizeInBytesPerKVHead, int* rewindDraftTokenCounts, int const* seqSlotRemapping, int maxKVCacheLen,
cudaStream_t stream);
void updateLinearKVCacheDraftTokenLocationSeparateRewind(runtime::SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, runtime::SizeType32 const* pastKeyValueLengths,
KVLinearBuffer::DataType* const* pastKeyValueList, runtime::SizeType32 layerCount, runtime::SizeType32 seqCount,
runtime::SizeType32 numKVHeads, runtime::SizeType32 sizeInBytesPerKVHead,
runtime::SizeType32* rewindDraftTokenCounts, runtime::SizeType32 const* seqSlotRemapping,
runtime::SizeType32 maxKVCacheLen, cudaStream_t stream);

/*!
* Update Block KV cache using separate rewind count for each sequence.
@ -127,11 +131,13 @@ void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedD
* @param tokensPerBlock : Tokens per block of Block KV cache
* @param stream : CUDA stream to use.
*/
void updateKVBlockArrayDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int* rewindDraftTokenCounts, int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq,
int tokensPerBlock, cudaStream_t stream);
void updateKVBlockArrayDraftTokenLocationSeparateRewind(runtime::SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, runtime::SizeType32 const* pastKeyValueLengths,
void* const* pointerArray, KVBlockArray::DataType* offsetArray, runtime::SizeType32 layerCount,
runtime::SizeType32 seqCount, runtime::SizeType32 numKVHeads, runtime::SizeType32 sizeInBytesPerKVHead,
runtime::SizeType32* rewindDraftTokenCounts, runtime::SizeType32 const* seqSlotRemapping,
runtime::SizeType32 maxKVCacheLen, runtime::SizeType32 maxBlocksPerSeq, runtime::SizeType32 tokensPerBlock,
cudaStream_t stream);

/*!
* Update Linear KV cache using both common rewind and separate rewind count for each sequence. The common
@ -156,11 +162,12 @@ void updateKVBlockArrayDraftTokenLocationSeparateRewind(int const* seqAcceptedDr
* @param maxKVCacheLen : Maximum length of each KV cache
* @param stream : CUDA stream to use.
*/
void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
KVLinearBuffer::DataType* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads,
int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments,
int const* seqSlotRemapping, int maxKVCacheLen, cudaStream_t stream);
void updateLinearKVCacheDraftTokenLocation(runtime::SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, runtime::SizeType32 const* pastKeyValueLengths,
KVLinearBuffer::DataType* const* pastKeyValueList, runtime::SizeType32 layerCount, runtime::SizeType32 seqCount,
runtime::SizeType32 numKVHeads, runtime::SizeType32 sizeInBytesPerKVHead,
runtime::SizeType32 rewindDraftTokenCommonCount, runtime::SizeType32 const* rewindDraftTokenSeparateAdjustments,
runtime::SizeType32 const* seqSlotRemapping, runtime::SizeType32 maxKVCacheLen, cudaStream_t stream);

/*!
* Update Block KV cache using both common rewind and separate rewind count for each sequence. The common
@ -178,20 +185,24 @@ void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffse
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCommonCount : Common token count to rewind
* @param rewindDraftTokenSeparateAdjustments : Pointer to an array of length seqCount, each element indicated the
* rewind adjustment for one sequence.
* rewind adjustment for one sequence, indexed through batchSlots.
* @param seqSlotRemapping mapping from batch index to index of the seqSlot in the sorted seqSlot buffer
* e.g. for requests [0, 1, 2] with seqSlots [5, 3, 4], seqSlotRemapping is [1, 2, 0]
* Required to match seqAcceptedDraftTokenOffsets and packedAcceptedDraftTokensIndices from gptDecoderBatch
* and pointerArray and pastKeyValueLengths from runtimeBuffers.
* @param batchSlots : [seqCount] indices of sequences in the seq slots.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param maxBlocksPerSeq : Maximum blocks per sequence of Block KV cache.
* @param tokensPerBlock : Tokens per block of Block KV cache
* @param stream : CUDA stream to use.
*/
void updateKVBlockArrayDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, void* const* pointerArray,
KVBlockArray::DataType* offsetArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream);
void updateKVBlockArrayDraftTokenLocation(runtime::SizeType32 const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, runtime::SizeType32 const* pastKeyValueLengths,
void* const* pointerArray, KVBlockArray::DataType* offsetArray, runtime::SizeType32 layerCount,
runtime::SizeType32 seqCount, runtime::SizeType32 numKVHeads, runtime::SizeType32 sizeInBytesPerKVHead,
runtime::SizeType32 rewindDraftTokenCommonCount, runtime::SizeType32 const* rewindDraftTokenSeparateAdjustments,
runtime::SizeType32 const* seqSlotRemapping, runtime::SizeType32 const* batchSlots,
runtime::SizeType32 maxKVCacheLen, runtime::SizeType32 maxBlocksPerSeq, runtime::SizeType32 tokensPerBlock,
cudaStream_t stream);

} // namespace tensorrt_llm::kernels::speculative_decoding
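The seqSlotRemapping comment above gives the example that seqSlots [5, 3, 4] for requests [0, 1, 2] yields the remapping [1, 2, 0]. One way to reproduce that example is an argsort of the seqSlots array; the sketch below is hypothetical helper code under that reading and is not part of the library.

// Illustrative only: reproduces the comment's example [5, 3, 4] -> [1, 2, 0].
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

std::vector<int> seqSlotRemappingOf(std::vector<int> const& seqSlots)
{
    std::vector<int> remap(seqSlots.size());
    std::iota(remap.begin(), remap.end(), 0); // start from identity [0, 1, 2, ...]
    // Sort batch indices by their seqSlot value; remap[i] is then the batch index
    // whose seqSlot comes i-th in the sorted seqSlot buffer.
    std::sort(remap.begin(), remap.end(),
        [&seqSlots](int a, int b) { return seqSlots[a] < seqSlots[b]; });
    return remap;
}

int main()
{
    assert((seqSlotRemappingOf({5, 3, 4}) == std::vector<int>{1, 2, 0}));
    return 0;
}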

@ -131,7 +131,8 @@ void invokeStopWordsCriterion(TokenIdType const** outputIds, SizeType32 const**
}

__global__ void lengthCriterion(FinishedState* finished, SizeType32* finishedSum, SizeType32 const* sequenceLimitLength,
SizeType32* sequenceLengths, SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 beamWidth)
SizeType32* sequenceLengths, SizeType32* numNewTokens, SizeType32 const* batchSlots, SizeType32 batchSize,
SizeType32 beamWidth)
{
SizeType32 threadFinishedCount = 0;
auto const batchIdx = blockIdx.x;
@ -144,10 +145,15 @@ __global__ void lengthCriterion(FinishedState* finished, SizeType32* finishedSum

auto finishState = finished[batchSlotBeamWidthIdx];

if (sequenceLengths[batchSlotBeamWidthIdx] >= sequenceLimitLength[batchSlot])
auto const numTokensToLimit = sequenceLimitLength[batchSlot] - sequenceLengths[batchSlotBeamWidthIdx];
if (numTokensToLimit <= 0)
{
finishState.setFinishedMaxLength();
sequenceLengths[batchSlotBeamWidthIdx] = sequenceLimitLength[batchSlot];
if (numNewTokens)
{
numNewTokens[batchSlot] = numNewTokens[batchSlot] + numTokensToLimit;
}
}
threadFinishedCount += finishState.isFinished() ? 1 : 0;
finished[batchSlotBeamWidthIdx] = finishState;
@ -174,8 +180,8 @@ __global__ void lengthCriterion(FinishedState* finished, SizeType32* finishedSum
}

void invokeLengthCriterion(FinishedState* finished, SizeType32* finishedSum, SizeType32 const* sequenceLimitLength,
SizeType32* sequenceLengths, SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 beamWidth,
cudaStream_t stream)
SizeType32* sequenceLengths, SizeType32* numNewTokens, SizeType32 const* batchSlots, SizeType32 batchSize,
SizeType32 beamWidth, cudaStream_t stream)
{
// Check if we have attained the sequence length limit. If so, stop the
// sequence. In addition, check if all sequences are stopped and return the
@ -184,12 +190,12 @@ void invokeLengthCriterion(FinishedState* finished, SizeType32* finishedSum, Siz
dim3 grid{static_cast<uint32_t>(batchSize)};

lengthCriterion<<<grid, block, 0, stream>>>(
finished, finishedSum, sequenceLimitLength, sequenceLengths, batchSlots, batchSize, beamWidth);
finished, finishedSum, sequenceLimitLength, sequenceLengths, numNewTokens, batchSlots, batchSize, beamWidth);
sync_check_cuda_error();
}

__global__ void explicitEOSCriterion(TokenIdType const** outputIds, TokenIdType const* endIds, FinishedState* finished,
SizeType32* sequenceLengths, SizeType32 const* tokensPerStep, SizeType32 const* batchSlots, SizeType32 batchSize,
SizeType32* sequenceLengths, SizeType32* numNewTokens, SizeType32 const* batchSlots, SizeType32 batchSize,
SizeType32 maxTokensPerStep)
{
auto const batchIdx = blockIdx.x * blockDim.x + threadIdx.x;
@ -204,7 +210,7 @@ __global__ void explicitEOSCriterion(TokenIdType const** outputIds, TokenIdType
return;
}

auto const numTokens = tokensPerStep != nullptr ? tokensPerStep[batchSlot] : maxTokensPerStep;
auto const numTokens = numNewTokens != nullptr ? numNewTokens[batchSlot] : maxTokensPerStep;
auto const endId = endIds[batchSlot];
auto const sequenceLength = sequenceLengths[batchSlot];

@ -217,12 +223,17 @@ __global__ void explicitEOSCriterion(TokenIdType const** outputIds, TokenIdType
{
finished[batchSlot].setFinishedEOS();
sequenceLengths[batchSlot] = max(0, pos);
if (numNewTokens)
{
numNewTokens[batchSlot] = pos - posStart;
}
return;
}
}
}

void invokeExplicitEOSCriterion(TokenIdType const** outputIds, TokenIdType const* endIds, FinishedState* finished,
SizeType32* sequenceLengths, SizeType32 const* tokensPerStep, SizeType32 const* batchSlots, SizeType32 batchSize,
SizeType32* sequenceLengths, SizeType32* numNewTokens, SizeType32 const* batchSlots, SizeType32 batchSize,
SizeType32 beamWidth, SizeType32 maxTokensPerStep, cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(beamWidth == 1, "Explicit EOS criterion does not support beam search");
@ -233,7 +244,7 @@ void invokeExplicitEOSCriterion(TokenIdType const** outputIds, TokenIdType const
grid.x = divUp(batchSize, blockSize);

explicitEOSCriterion<<<grid, blockSize, 0, stream>>>(
outputIds, endIds, finished, sequenceLengths, tokensPerStep, batchSlots, batchSize, maxTokensPerStep);
outputIds, endIds, finished, sequenceLengths, numNewTokens, batchSlots, batchSize, maxTokensPerStep);
sync_check_cuda_error();
}


@ -62,14 +62,16 @@ void invokeStopWordsCriterion(runtime::TokenIdType const** outputIds, runtime::S
//! \param sequenceLimitLength input buffer [maxBatchSize]. Maximum sequence length.
//! \param sequenceLengths input/output buffer [maxBatchSize, beamWidth].
//! Current sequence lengths of the request tokens.
//! \param numNewTokens output buffer [maxBatchSize], optional. Number of tokens per step for each request.
//! It is assumed that all requests have maxTokensPerStep tokens per step if nullptr.
//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool
//! \param batchSize batch size
//! \param beamWidth beam width
//! \param stream stream
void invokeLengthCriterion(FinishedState* finished, runtime::SizeType32* finishedSum,
runtime::SizeType32 const* sequenceLimitLength, runtime::SizeType32* sequenceLengths,
runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth,
cudaStream_t stream);
runtime::SizeType32* numNewTokens, runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize,
runtime::SizeType32 beamWidth, cudaStream_t stream);

//! \brief Sets finished states based on the endIds and adjusts sequence length to length before the first EOS token.
//! Does not support beamWidth > 1 for now.
@ -81,7 +83,7 @@ void invokeLengthCriterion(FinishedState* finishe
//! [maxBatchSize, beamWidth]. Finished states. Set to FinishedState::FINISHED_EOS if any new tokens contain EOS.
//! \param sequenceLengths input/output buffer [maxBatchSize, beamWidth].
//! Current sequence lengths of the request tokens.
//! \param tokensPerStep input buffer [maxBatchSize], optional. Number of tokens per step for each request.
//! \param numNewTokens input/output buffer [maxBatchSize], optional. Number of tokens per step for each request.
//! It is assumed that all requests have maxTokensPerStep tokens per step if nullptr.
//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool
//! \param batchSize batch size
@ -89,7 +91,7 @@ void invokeLengthCriterion(FinishedState* finishe
//! \param maxTokensPerStep maximum number of tokens decoded per step
//! \param stream stream
void invokeExplicitEOSCriterion(runtime::TokenIdType const** outputIds, runtime::TokenIdType const* endIds,
FinishedState* finished, runtime::SizeType32* sequenceLengths, runtime::SizeType32 const* tokensPerStep,
FinishedState* finished, runtime::SizeType32* sequenceLengths, runtime::SizeType32* numNewTokens,
runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth,
runtime::SizeType32 maxTokensPerStep, cudaStream_t stream);
} // namespace kernels
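The new length-criterion arithmetic above computes numTokensToLimit as the remaining budget and, once it is non-positive, clamps the sequence length and shrinks numNewTokens by the overshoot. A host-side analogue of that update follows; plain ints and the function name are hypothetical stand-ins for the device buffers and kernel.

// Illustrative host-side analogue of the length-criterion update.
#include <cassert>

void applyLengthCriterion(int limit, int& sequenceLength, int& numNewTokens, bool& finishedMaxLength)
{
    int const numTokensToLimit = limit - sequenceLength;
    if (numTokensToLimit <= 0)
    {
        finishedMaxLength = true;
        sequenceLength = limit;                          // never report past the limit
        numNewTokens = numNewTokens + numTokensToLimit;  // drop the overshooting tokens
    }
}

int main()
{
    int seqLen = 12, newTokens = 4; // 4 tokens decoded this step, limit is 10
    bool finished = false;
    applyLengthCriterion(10, seqLen, newTokens, finished);
    assert(finished && seqLen == 10 && newTokens == 2);
    return 0;
}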

@ -16,22 +16,16 @@
*/

#include "tensorrt_llm/layers/banWordsLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/banBadWords.h"
#include "tensorrt_llm/kernels/banRepeatNgram.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"

#include <algorithm>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{

template <typename T>
@ -97,30 +91,31 @@ void BanWordsLayer<T>::freeBuffer()

template <typename T>
void BanWordsLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> baseSetupParams)
std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto setupParams = std::dynamic_pointer_cast<DynamicDecodeSetupParams>(baseSetupParams);
std::vector<SizeType32> batchSlotsVec(batchSize);
std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0);
auto batchSlotsHost = batchSlots ? batchSlots : batchSlotsVec.data();
auto const& penaltyParams = setupParams->penaltyParams;
auto const& banWordsParams = setupParams->banWordsParams;
TLLM_CHECK_WITH_INFO(banWordsParams, "banWordsParams for setup is not set");
bool const useNoRepeatNgramSize
= mDecodingMode.isUseNoRepeatNgramSize() && penaltyParams.noRepeatNgramSize.has_value();
= mDecodingMode.isUseNoRepeatNgramSize() && banWordsParams->noRepeatNgramSize.has_value();
FillBuffers const fillBuffers{batchSize, mDecoderDomain.getBatchSize(), mStream};
mUseNoRepeatNgramSize |= useNoRepeatNgramSize;
if (mUseNoRepeatNgramSize)
{
fillBuffers(penaltyParams.noRepeatNgramSize, DefaultDecodingParams::getNoRepeatNgramSize(), mNoRepeatNgramSize,
mNoRepeatNgramSizeDevice, batchSlotsHost, std::make_pair(0.f, std::numeric_limits<float>::max()),
"no_repeat_ngram_size");
fillBuffers(banWordsParams->noRepeatNgramSize, DefaultDecodingParams::getNoRepeatNgramSize(),
mNoRepeatNgramSize, mNoRepeatNgramSizeDevice, batchSlotsHost,
std::make_pair(0.f, std::numeric_limits<float>::max()), "no_repeat_ngram_size");
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void BanWordsLayer<T>::banRepeatNGrams(Tensor& logits, std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& inputs, SizeType32 const* batchSlots,
void BanWordsLayer<T>::banRepeatNGrams(Tensor& logits, std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<DecodingInputs> const& inputs, SizeType32 const* batchSlots,
SizeType32 const* noRepeatNgramSizeDevice, DecoderDomain const& decoderDomain, SizeType32 maxSeqLen,
bool useNoRepeatNgramSize, cudaStream_t stream)
{
@ -129,11 +124,11 @@ void BanWordsLayer<T>::banRepeatNGrams(Tensor& logits, std::shared_ptr<DynamicDe
auto const maxStep = maxSeqLen;
if (useNoRepeatNgramSize)
{
invokeBanRepeatNgram(logits.template getPtr<T>(), outputs->output_ids_ptr.template getPtr<TokenIdType const*>(),
invokeBanRepeatNgram(logits.template getPtr<T>(), outputs->outputIdsPtr.template getPtr<TokenIdType const*>(),
reinterpret_cast<FinishedState*>(
inputs->finished.value_or(Tensor{}).template getPtr<FinishedState::UnderlyingType>()),
outputs->parent_ids_ptr.template getPtr<SizeType32 const*>(), batchSlots,
outputs->sequence_length->template getPtr<SizeType32>(), decoderDomain.getBatchSize(),
outputs->parentIdsPtr.template getPtr<SizeType32 const*>(), batchSlots,
outputs->sequenceLength->template getPtr<SizeType32>(), decoderDomain.getBatchSize(),
decoderDomain.getBeamWidth(), maxSeqLen, noRepeatNgramSizeDevice, decoderDomain.getVocabSizePadded(),
maxStep, stream);
}
@ -141,39 +136,40 @@ void BanWordsLayer<T>::banRepeatNGrams(Tensor& logits, std::shared_ptr<DynamicDe
}

template <typename T>
void BanWordsLayer<T>::banBadWords(Tensor& logits, std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& inputs, SizeType32 const* batchSlots,
DecoderDomain const& decoderDomain, SizeType32 maxSeqLen, cudaStream_t stream)
void BanWordsLayer<T>::banBadWords(Tensor& logits, std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<DecodingInputs> const& inputs, SizeType32 const* batchSlots, DecoderDomain const& decoderDomain,
SizeType32 maxSeqLen, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const maxBadWordsLength = inputs->max_bad_words_len;
auto const maxBadWordsLength = inputs->banWordsInputs->maxBadWordsLen;
if (maxBadWordsLength)
{
auto const** badWordsPtr = inputs->bad_words_ptr->template getPtr<TokenIdType const*>();
auto const* badWordsLens = inputs->bad_words_lengths->template getPtr<SizeType32>();
auto const** badWordsPtr = inputs->banWordsInputs->badWordsPtr->template getPtr<TokenIdType const*>();
auto const* badWordsLens = inputs->banWordsInputs->badWordsLengths->template getPtr<SizeType32>();

invokeBanBadWords((T*) logits.template getPtr<T>(),
outputs->output_ids_ptr.template getPtr<TokenIdType const*>(),
decoderDomain.getBeamWidth() > 1 ? outputs->parent_ids_ptr.template getPtr<SizeType32 const*>() : nullptr,
invokeBanBadWords((T*) logits.template getPtr<T>(), outputs->outputIdsPtr.template getPtr<TokenIdType const*>(),
decoderDomain.getBeamWidth() > 1 ? outputs->parentIdsPtr.template getPtr<SizeType32 const*>() : nullptr,
batchSlots, decoderDomain.getBatchSize(), decoderDomain.getBeamWidth(), badWordsPtr, badWordsLens,
maxBadWordsLength, decoderDomain.getVocabSizePadded(),
outputs->sequence_length->template getPtr<SizeType32>(), maxSeqLen, stream);
outputs->sequenceLength->template getPtr<SizeType32>(), maxSeqLen, stream);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void BanWordsLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto inputs = std::dynamic_pointer_cast<DynamicDecodeInputParams>(baseInputs);
auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
auto inputs = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);
auto outputs = std::dynamic_pointer_cast<BaseDecodingOutputs>(baseOutputs);

TLLM_CHECK_WITH_INFO(inputs->banWordsInputs, "banWordsInputs for forward is not set");

auto const localDecoderDomain = getLocalDecoderDomain(inputs, mDecoderDomain);
auto const maxSeqLen = outputs->output_ids.shape[outputs->output_ids.shape.size() - 1];
auto batchSlots = inputs->batch_slots ? inputs->batch_slots->template getPtr<SizeType32 const>() : nullptr;
auto const maxSeqLen = outputs->outputIds.shape[outputs->outputIds.shape.size() - 1];
auto batchSlots = inputs->batchSlots ? inputs->batchSlots->template getPtr<SizeType32 const>() : nullptr;

banRepeatNGrams(inputs->logits.value(), outputs, inputs, batchSlots, mNoRepeatNgramSizeDevice, localDecoderDomain,
maxSeqLen, mUseNoRepeatNgramSize, mStream);
@ -185,5 +181,4 @@ void BanWordsLayer<T>::forwardAsync(
template class BanWordsLayer<float>;
template class BanWordsLayer<half>;

} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers
|
||||
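A rough usage sketch of the new input plumbing (not part of this diff; the tensors, outputs object, and surrounding variables are illustrative assumptions): the caller now attaches a BanWordsDecodingInputs to the DecodingInputs it passes to BanWordsLayer<T>::forwardAsync, instead of setting max_bad_words_len and bad_words_ptr directly on the input params.

// Sketch only: assumes endIds, logits, batchSlots and outputs were created elsewhere with matching shapes.
auto banWords = std::make_shared<BanWordsDecodingInputs>(localBatchSize);
banWords->maxBadWordsLen = maxBadWordsLen;         // longest bad-words list in the batch
banWords->badWordsPtr = badWordsPtrTensor;         // [maxBatchSize][2, bad_words_length], on gpu
banWords->badWordsLengths = badWordsLengthsTensor; // [maxBatchSize], on gpu

auto inputs = std::make_shared<DecodingInputs>(endIds, step, ite, localBatchSize);
inputs->logits = logits;           // [forwardBatchSize, beamWidth, vocabSizePadded]
inputs->batchSlots = batchSlots;   // [forwardBatchSize], pinned memory
inputs->banWordsInputs = banWords; // forwardAsync asserts this via TLLM_CHECK_WITH_INFO

banWordsLayer->forwardAsync(outputs, inputs);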
@ -17,17 +17,14 @@

#pragma once

#include <curand_kernel.h>

#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/runtime/iTensor.h"

namespace tensorrt_llm
{
namespace layers
#include <curand_kernel.h>

namespace tensorrt_llm::layers
{

//! \brief Layer to ban specific words from being sampled.
@ -45,20 +42,21 @@ public:
~BanWordsLayer() override;

void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> baseSetupParams) override;
std::shared_ptr<BaseSetupParams> const& baseSetupParams) override;

//! \brief Modifies 'outputs->logits' in-place with -INF for banned words
void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;

private:
void initialize();
void allocateBuffer();
void freeBuffer();
static void banBadWords(tc::Tensor& logits, std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& params, runtime::SizeType32 const* batchSlots,
static void banBadWords(tc::Tensor& logits, std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<DecodingInputs> const& inputs, runtime::SizeType32 const* batchSlots,
DecoderDomain const& decoderDomain, runtime::SizeType32 maxSeqLen, cudaStream_t stream);
static void banRepeatNGrams(tc::Tensor& logits, std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& inputs, runtime::SizeType32 const* batchSlots,
static void banRepeatNGrams(tc::Tensor& logits, std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<DecodingInputs> const& inputs, runtime::SizeType32 const* batchSlots,
runtime::SizeType32 const* noRepeatNgramSizeDevice, DecoderDomain const& decoderDomain,
runtime::SizeType32 maxSeqLen, bool useNoRepeatNgramSize, cudaStream_t stream);

@ -76,5 +74,4 @@ private:
bool mUseNoRepeatNgramSize{false};
};

} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

@ -17,14 +17,14 @@
#pragma once

#include "tensorrt_llm/common/allocator.h"
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/common/cudaAllocator.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"

namespace tensorrt_llm
{
namespace layers
#include <utility>

namespace tensorrt_llm::layers
{

class BaseLayer
@ -35,8 +35,17 @@ public:

BaseLayer(DecoderDomain const& decoderDomain, cudaStream_t stream,
std::shared_ptr<tensorrt_llm::common::IAllocator> allocator)
: mStream(stream)
: mBufferManager(nullptr)
, mStream(stream)
, mAllocator(std::move(allocator))
, mDecoderDomain(std::move(decoderDomain))
{
}

BaseLayer(DecoderDomain const& decoderDomain, std::shared_ptr<runtime::BufferManager> const& bufferManager)
: mBufferManager(bufferManager)
, mStream(mBufferManager->getStream().get())
, mAllocator(std::make_shared<tensorrt_llm::common::CudaAllocator>(*mBufferManager))
, mDecoderDomain(decoderDomain)
{
}
@ -79,29 +88,36 @@ public:
//! \param setupParams shared pointer to params inherited from BaseSetupParams
// clang-format on
virtual void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth,
runtime::SizeType32 const* batchSlots, std::shared_ptr<BaseSetupParams> setupParams)
runtime::SizeType32 const* batchSlots, std::shared_ptr<BaseSetupParams> const& setupParams)
= 0;

// clang-format off
//! \brief Virtual function to execute layer async on GPU.
//! There must be no stream synchronization inside this function.
//!
//! \param outputs shared pointer to params inherited from BaseOutputParams
//! \param outputs shared pointer to params inherited from BaseDecodingOutputs
//! \param inputs shared pointer to params inherited from BaseForwardParams
// clang-format on
virtual void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) = 0;
virtual void forwardAsync(
std::shared_ptr<BaseDecodingOutputs> const& outputs, std::shared_ptr<BaseDecodingInputs> const& inputs)
= 0;

// clang-format off
//! \brief Virtual function to execute layer synchronously on CPU / GPU.
//! It is allowed (but not necessary) to synchronize on stream inside this function.
//! It is targeted mainly for prototyping.
//!
//! \param outputs shared pointer to params inherited from BaseOutputParams
//! \param outputs shared pointer to params inherited from BaseDecodingOutputs
//! \param inputs shared pointer to params inherited from BaseForwardParams
// clang-format on
virtual void forwardSync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) {}
virtual void forwardSync(
std::shared_ptr<BaseDecodingOutputs> const& outputs, std::shared_ptr<BaseDecodingInputs> const& inputs)
{
}

protected:
// Buffer Manager
std::shared_ptr<runtime::BufferManager> mBufferManager;
// Cuda stream
cudaStream_t mStream;
// Memory allocator
@ -119,5 +135,4 @@ protected:
bool mIsAllocateBuffer{false};
};

} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

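A minimal construction sketch of the new BufferManager-based BaseLayer constructor (illustrative, not in this diff): MyLayer stands in for a hypothetical BaseLayer subclass, and the DecoderDomain argument values are assumptions; real derived layers may still use the stream-plus-allocator constructor during the transition.

// Sketch only: stream and allocator are derived from the BufferManager internally.
auto stream = std::make_shared<tensorrt_llm::runtime::CudaStream>();
auto bufferManager = std::make_shared<tensorrt_llm::runtime::BufferManager>(stream);
DecoderDomain domain(/*batchSize=*/8, /*beamWidth=*/1, /*vocabSize=*/32000);
MyLayer<float> layer(domain, bufferManager); // old form: MyLayer<float>(domain, stream->get(), allocator)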
@ -14,19 +14,17 @@
* limitations under the License.
*/

#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/beamSearchKernels.h"
#include "tensorrt_llm/layers/beamSearchLayer.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"

#include <limits>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;

namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{

template <typename T>
@ -47,7 +45,7 @@ BeamSearchLayer<T>::~BeamSearchLayer()

template <typename T>
void BeamSearchLayer<T>::setup(runtime::SizeType32 const batchSize, runtime::SizeType32 const beamWidth,
runtime::SizeType32 const* batchSlots, std::shared_ptr<BaseSetupParams> baseSetupParams)
runtime::SizeType32 const* batchSlots, std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK_WITH_INFO(
@ -65,12 +63,12 @@ void BeamSearchLayer<T>::setup(runtime::SizeType32 const batchSize, runtime::Siz
auto constexpr fltEpsilon = std::numeric_limits<float>::epsilon();

FillBuffers const fillBuffers{batchSize, batchSize, mStream};
fillBuffers(setupParams->beam_search_diversity_rate, DefaultDecodingParams::getBeamSearchDiversity(),
fillBuffers(setupParams->beamSearchDiversityRate, DefaultDecodingParams::getBeamSearchDiversity(),
mDiversityRateHost, mDiversityRateDevice, (int*) nullptr, std::make_pair(-fltEpsilon, fltMax),
"diveristy rate");
fillBuffers(setupParams->length_penalty, DefaultDecodingParams::getLengthPenalty(), mLengthPenaltyHost,
fillBuffers(setupParams->lengthPenalty, DefaultDecodingParams::getLengthPenalty(), mLengthPenaltyHost,
mLengthPenaltyDevice, (int*) nullptr, std::make_pair(fltMin, fltMax), "length penalty");
fillBuffers(setupParams->early_stopping, DefaultDecodingParams::getEarlyStopping(), mEarlyStoppingHost,
fillBuffers(setupParams->earlyStopping, DefaultDecodingParams::getEarlyStopping(), mEarlyStoppingHost,
mEarlyStoppingDevice, (int*) nullptr, std::make_pair(fltMin, fltMax), "early stopping");
mHasDiffRuntimeArgs = setupParams->hasDiffRuntimeArgs;

@ -80,7 +78,7 @@ void BeamSearchLayer<T>::setup(runtime::SizeType32 const batchSize, runtime::Siz
__global__ void updateCacheIndirectionKernel(
int* tgtCI, int const* srcCI, BeamHypotheses bh, int const nMaxAttentionWindow, int const nSinkTokenLength)
{
// Update indirections from steps `bh.inputLength[indexBatchBeam]` to step `sequence_lengths[indexBatchBeam]`
// Update indirections from steps `bh.inputLength[indexBatchBeam]` to step `sequenceLengths[indexBatchBeam]`
int const step = threadIdx.x + blockIdx.x * blockDim.x;
int const indexBatchBeam = blockIdx.y;
int const nBS{bh.nBatchSize};
@ -88,7 +86,7 @@ __global__ void updateCacheIndirectionKernel(
int const nMSL{bh.nMaxSeqLen};
int const indexBatch = indexBatchBeam / nBM;
int const indexBeam = indexBatchBeam % nBM;
int const lastStep{bh.sequenceLengths[indexBatchBeam] - 1}; // the sequence_lengths is updated, need to minus 1
int const lastStep{bh.sequenceLengths[indexBatchBeam] - 1}; // the sequenceLengths is updated, need to minus 1

// Return early when the indexBatchBeam or step is out of the bound
// No update for the indices of context part since KV Cache is shared
@ -111,38 +109,38 @@ __global__ void updateCacheIndirectionKernel(

template <typename T>
void BeamSearchLayer<T>::forwardAsyncSingleRequest(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto ip = std::dynamic_pointer_cast<BeamSearchInputParams>(baseInputs);
auto op = std::dynamic_pointer_cast<BeamSearchOutputParams>(baseOutputs);
auto op = std::dynamic_pointer_cast<BeamSearchOutputs>(baseOutputs);

TLLM_CHECK_WITH_INFO(op->beamHypotheses, std::string("Output BeamHypotheses is not set."));
TLLM_CHECK_WITH_INFO(op->sequence_length->template getPtr<int>() != nullptr || mLengthPenaltyDevice == nullptr,
TLLM_CHECK_WITH_INFO(op->sequenceLength->template getPtr<int>() != nullptr || mLengthPenaltyDevice == nullptr,
std::string("Current sequence lengths must be set for length penalty computation."));
TLLM_CHECK_WITH_INFO(ip->ite == 0, "Pipeline Parallelism is not supported yet !");

BeamHypotheses& bh{*op->beamHypotheses};
// bh's members already initialized in op: *CBA, batchDones
// bh's members not used in function: outputIds, logProbs, outputIdsUnfinish, parentIdsUnfinish
bh.nMaxBatchSize = static_cast<std::int32_t>(op->output_ids_ptr.shape[0]);
bh.nBatchSize = ip->logits.shape[0];
bh.nBeamWidth = static_cast<std::int32_t>(op->output_ids_ptr.shape[1]);
bh.nMaxBatchSize = static_cast<std::int32_t>(op->outputIdsPtr.shape[0]);
bh.nBatchSize = ip->localBatchSize;
bh.nBeamWidth = static_cast<std::int32_t>(op->outputIdsPtr.shape[1]);
bh.nIte = ip->ite;
bh.nMaxSeqLen = static_cast<std::int32_t>(op->output_ids_ptr.shape[2]);
bh.nMaxSeqLen = static_cast<std::int32_t>(op->outputIdsPtr.shape[2]);
bh.nVocabSize = mVocabSizePadded;
bh.diversityRates = mDiversityRateDevice;
bh.lengthPenalties = mLengthPenaltyDevice;
bh.earlyStoppings = mEarlyStoppingDevice;
bh.inputLengths = ip->input_lengths->template getPtr<int const>();
bh.endIds = ip->end_ids.template getPtr<int const>();
bh.logProbsTiled = (op->output_log_probs) ? op->output_log_probs->template getPtr<float>() : nullptr;
bh.sequenceLengths = op->sequence_length->template getPtr<int>();
bh.cumLogProbs = op->cum_log_probs->template getPtr<float>();
bh.inputLengths = ip->inputLengths->template getPtr<int const>();
bh.endIds = ip->endIds.template getPtr<int const>();
bh.logProbsTiled = (op->outputLogProbs) ? op->outputLogProbs->template getPtr<float>() : nullptr;
bh.sequenceLengths = op->sequenceLength->template getPtr<int>();
bh.cumLogProbs = op->cumLogProbs->template getPtr<float>();
bh.finished = reinterpret_cast<FinishedState*>(op->finished->template getPtr<FinishedState::UnderlyingType>());
bh.outputIdsPtr = op->output_ids_ptr.template getPtr<int*>();
bh.parentIdsPtr = op->parent_ids_ptr.template getPtr<int*>();
bh.outputIdsPtr = op->outputIdsPtr.template getPtr<int*>();
bh.parentIdsPtr = op->parentIdsPtr.template getPtr<int*>();

T const* logits = ip->logits.template getPtr<T>();
T const* bias = static_cast<T const*>(nullptr);
@ -155,11 +153,11 @@ void BeamSearchLayer<T>::forwardAsyncSingleRequest(

if (bh.nBeamWidth > 1)
{
auto tgtCI = op->tgt_cache_indirection.template getPtr<int>();
auto srcCI = ip->src_cache_indirection.template getPtr<int const>();
auto tgtCI = op->tgtCacheIndirection.template getPtr<int>();
auto srcCI = ip->srcCacheIndirection.template getPtr<int const>();
dim3 const grid(roundUp(bh.nMaxSeqLen, 32), bh.nBatchSize * bh.nBeamWidth);
updateCacheIndirectionKernel<<<grid, 32, 0, mStream>>>(
tgtCI, srcCI, bh, ip->max_attention_window, ip->sink_token_length);
tgtCI, srcCI, bh, ip->maxAttentionWindow, ip->sinkTokenLength);
sync_check_cuda_error();
}

@ -168,32 +166,29 @@ void BeamSearchLayer<T>::forwardAsyncSingleRequest(

template <typename T>
void BeamSearchLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
auto params = std::dynamic_pointer_cast<DynamicDecodeInputParams>(baseInputs);
auto outputs = std::dynamic_pointer_cast<BeamSearchOutputs>(baseOutputs);
auto params = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);

auto const localDecoderDomain = getLocalDecoderDomain(params, mDecoderDomain);

auto batchSlots = params->batch_slots ? params->batch_slots->template getPtr<SizeType32 const>() : nullptr;
auto const maxSeqLen = outputs->output_ids.shape[outputs->output_ids.shape.size() - 1];
auto batchSlots = params->batchSlots ? params->batchSlots->template getPtr<SizeType32 const>() : nullptr;
auto const maxSeqLen = outputs->outputIds.shape[outputs->outputIds.shape.size() - 1];
auto const ite = params->ite;
auto const step = params->step;

// common inputs
auto const& endIds = params->end_ids;
auto const localBatchSize = static_cast<std::size_t>(params->local_batch_size);
auto const& endIds = params->endIds;
auto const localBatchSize = static_cast<std::size_t>(params->localBatchSize);

TLLM_CHECK_WITH_INFO(localDecoderDomain.getBeamWidth() > 1,
"Decoding mode is beam search, but beamWidth <= 1 (%d <= 1)", localDecoderDomain.getBeamWidth());
TLLM_CHECK_WITH_INFO(
params->src_cache_indirection.has_value(), "src_cache_indirection is mandatory in beam search.");
TLLM_CHECK_WITH_INFO(
outputs->tgt_cache_indirection.has_value(), "tgt_cache_indirection is mandatory in beam search.");
TLLM_CHECK_WITH_INFO(outputs->parent_ids.has_value(), "parent_ids tensor is mandatory in beam search.");
TLLM_CHECK_WITH_INFO(params->srcCacheIndirection.has_value(), "srcCacheIndirection is mandatory in beam search.");
TLLM_CHECK_WITH_INFO(outputs->parentIds.has_value(), "parentIds tensor is mandatory in beam search.");
TLLM_CHECK_WITH_INFO(outputs->finished.has_value(), "finished tensor is mandatory in beam search.");
TLLM_CHECK_WITH_INFO(outputs->cum_log_probs.has_value(), "cum_log_probs tensor is mandatory in beam search.");
TLLM_CHECK_WITH_INFO(outputs->cumLogProbs.has_value(), "cumLogProbs tensor is mandatory in beam search.");

// Compute one by one if there are different runtime arguments
// due to Batch-Beam-Search is not supported yet, so we need to compute
@ -211,30 +206,32 @@ void BeamSearchLayer<T>::forwardAsync(
auto const end_id_offset = endIds.slice({dynamic_decode_batch_size}, dynamic_ite * dynamic_decode_batch_size);

auto forwardParams = std::make_shared<BeamSearchInputParams>(step, ite, logits_offset, end_id_offset,
*params->src_cache_indirection, static_cast<std::int32_t>(params->max_attention_window),
static_cast<std::int32_t>(params->sink_token_length), static_cast<std::int32_t>(maxSeqLen));
*params->srcCacheIndirection, static_cast<std::int32_t>(params->maxAttentionWindow),
static_cast<std::int32_t>(params->sinkTokenLength), static_cast<std::int32_t>(maxSeqLen),
dynamic_decode_batch_size);

if (params->input_lengths)
if (params->inputLengths)
{
forwardParams->input_lengths = params->input_lengths->slice(
forwardParams->inputLengths = params->inputLengths->slice(
{dynamic_decode_batch_size * localDecoderDomain.getBeamWidth()}, dynamic_id_offset);
}

auto outputParams = std::make_shared<BeamSearchOutputParams>(
outputs->output_ids, outputs->parent_ids.value(), outputs->tgt_cache_indirection.value());
auto outputParams = std::make_shared<BeamSearchOutputs>(outputs->outputIds);

outputParams->output_ids_ptr = std::move(outputs->output_ids_ptr);
outputParams->parent_ids_ptr = std::move(outputs->parent_ids_ptr);
outputParams->sequence_length = outputs->sequence_length->slice(
outputParams->parentIds = std::move(outputs->parentIds);
outputParams->tgtCacheIndirection = std::move(outputs->tgtCacheIndirection);
outputParams->outputIdsPtr = std::move(outputs->outputIdsPtr);
outputParams->parentIdsPtr = std::move(outputs->parentIdsPtr);
outputParams->sequenceLength = outputs->sequenceLength->slice(
{dynamic_decode_batch_size * localDecoderDomain.getBeamWidth()}, dynamic_id_offset);
outputParams->finished = outputs->finished->slice(
{dynamic_decode_batch_size * localDecoderDomain.getBeamWidth()}, dynamic_id_offset);
outputParams->cum_log_probs = outputs->cum_log_probs->slice(
outputParams->cumLogProbs = outputs->cumLogProbs->slice(
{dynamic_decode_batch_size * localDecoderDomain.getBeamWidth()}, dynamic_id_offset);
outputParams->output_log_probs = outputs->output_log_probs_tiled; // notice: use tiled tensor
outputParams->outputLogProbs = outputs->outputLogProbsTiled; // notice: use tiled tensor
outputParams->beamHypotheses = std::move(outputs->beamHypotheses);

// beam_search_diversity_rate is only supported when using BeamHypotheses
// beamSearchDiversityRate is only supported when using BeamHypotheses
forwardAsyncSingleRequest(outputParams, forwardParams);
} // end of dynamic_ite
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@ -275,5 +272,4 @@ void BeamSearchLayer<T>::freeBuffer()
template class BeamSearchLayer<float>;
template class BeamSearchLayer<half>;

} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

@ -17,71 +17,40 @@
#pragma once

#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/kernels/beamSearchKernels.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/runtime/common.h"

#include <utility>

#include <optional>
#include <utility>

namespace tc = tensorrt_llm::common;

namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{

// BS: batch_size, lBS: local_batch_size, BM: beam_width, mSL: max_seq_length
class BeamSearchSetupParams : public BaseSetupParams
{
public:
std::optional<std::vector<float>> beam_search_diversity_rate; // [BS] on cpu
std::optional<std::vector<float>> length_penalty; // [BS] on cpu
std::optional<std::vector<int>> early_stopping; // [BS] on cpu
bool hasDiffRuntimeArgs{false};
};

class BeamSearchInputParams : public BaseInputParams
class BeamSearchInputParams : public DecodingInputs
{
public:
explicit BeamSearchInputParams(runtime::SizeType32 step, runtime::SizeType32 ite, tc::Tensor logits,
tc::Tensor endIds, tc::Tensor src_cache_indirection, runtime::SizeType32 max_attention_window,
runtime::SizeType32 sink_token_length, runtime::SizeType32 max_seq_len)
: BaseInputParams(step, ite, std::move(endIds))
tc::Tensor endIds, tc::Tensor srcCacheIndirection, runtime::SizeType32 maxAttentionWindow,
runtime::SizeType32 sinkTokenLength, runtime::SizeType32 maxSeqLen, runtime::SizeType32 localBatchSize)
: DecodingInputs(std::move(endIds), step, ite, localBatchSize)
, logits{std::move(logits)}
, max_attention_window{max_attention_window}
, sink_token_length{sink_token_length}
, max_seq_len{max_seq_len}
, src_cache_indirection{std::move(src_cache_indirection)}
, maxAttentionWindow{maxAttentionWindow}
, sinkTokenLength{sinkTokenLength}
, maxSeqLen{maxSeqLen}
, srcCacheIndirection{std::move(srcCacheIndirection)}
{
}

// mandatory parameters
tc::Tensor logits; // [maxBatchSize, beamWidth, vocabSizePadded]
runtime::SizeType32 max_attention_window;
runtime::SizeType32 sink_token_length;
runtime::SizeType32 max_seq_len;
tc::Tensor src_cache_indirection; // [BS, BM, mSL]
std::optional<tc::Tensor> input_lengths; // [BS, BM]
};

class BeamSearchOutputParams : public BaseOutputParams
{
public:
explicit BeamSearchOutputParams(tc::Tensor outputIds, tc::Tensor parentIds, tc::Tensor tgt_cache_indirection)
: BaseOutputParams{std::move(outputIds)}
, parent_ids{std::move(parentIds)}
, tgt_cache_indirection{std::move(tgt_cache_indirection)}
{
}

std::shared_ptr<kernels::BeamHypotheses> beamHypotheses;
tc::Tensor parent_ids; // [BS, BM, mSL]
tc::Tensor tgt_cache_indirection; // [BS, BM, mSL]
tc::Tensor parent_ids_ptr; // [BS][BM, mSL]
runtime::SizeType32 maxAttentionWindow;
runtime::SizeType32 sinkTokenLength;
runtime::SizeType32 maxSeqLen;
tc::Tensor srcCacheIndirection; // [BS, BM, mSL]
std::optional<tc::Tensor> inputLengths; // [BS, BM]
};

template <typename T>
@ -94,15 +63,17 @@ public:

~BeamSearchLayer() override;

void setup(runtime::SizeType32 const batch_size, runtime::SizeType32 const beamWidth,
runtime::SizeType32 const* batchSlots, std::shared_ptr<BaseSetupParams> setupParams) override;
void setup(runtime::SizeType32 const batchSize, runtime::SizeType32 const beamWidth,
runtime::SizeType32 const* batchSlots, std::shared_ptr<BaseSetupParams> const& setupParams) override;

void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;

private:
void forwardAsyncSingleRequest(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs);
void forwardAsyncSingleRequest(
std::shared_ptr<BaseDecodingOutputs> const& outputs, std::shared_ptr<BaseDecodingInputs> const& inputs);

void allocateBuffer(runtime::SizeType32 const batch_size, runtime::SizeType32 const beam_width);
void allocateBuffer(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth);
void freeBuffer();

private:
@ -126,5 +97,4 @@ private:
bool mHasDiffRuntimeArgs{false};
};

} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

@ -16,16 +16,13 @@
*/

#include "tensorrt_llm/layers/decodingLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
#include "tensorrt_llm/layers/beamSearchLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/layers/explicitDraftTokensLayer.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/layers/medusaDecodingLayer.h"
#include "tensorrt_llm/layers/samplingLayer.h"

#include <algorithm>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;
@ -59,15 +56,14 @@ bool allSame(std::optional<std::vector<T>> const& vOpt)

bool hasDiffRuntimeArgs(std::shared_ptr<tensorrt_llm::layers::DynamicDecodeSetupParams> const& params)
{
return !allSame(params->penaltyParams.frequencyPenalty) || !allSame(params->penaltyParams.presencePenalty)
|| !allSame(params->penaltyParams.repetitionPenalty) || !allSame(params->penaltyParams.temperature)
|| !allSame(params->penaltyParams.minLength) || !allSame(params->penaltyParams.noRepeatNgramSize);
// return !allSame(params->penaltyParams.frequencyPenalty) || !allSame(params->penaltyParams.presencePenalty)
// || !allSame(params->penaltyParams.repetitionPenalty) || !allSame(params->penaltyParams.temperature)
// || !allSame(params->penaltyParams.minLength) || !allSame(params->banWordsInputs.noRepeatNgramSize);
return false;
}
} // namespace

namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{
template <typename T>
DecodingLayer<T>::DecodingLayer(executor::DecodingMode const& mode, DecoderDomain const& decoderDomain,
@ -110,61 +106,40 @@ DecodingLayer<T>::DecodingLayer(executor::DecodingMode const& mode, DecoderDomai

template <typename T>
void DecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> baseSetupParams)
std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto setupParams = std::dynamic_pointer_cast<DynamicDecodeSetupParams>(baseSetupParams);

TLLM_CHECK_WITH_INFO(setupParams->decodingParams, "decodingParams for setup is not set");

if (mDecodingMode.isTopKorTopP())
{ // sampling layers
TLLM_CHECK_WITH_INFO(
beamWidth == 1, "Decoding mode is TopK and/or TopP, but beamWidth != 1 (%d != 1)", beamWidth);
auto samplingParams = std::make_shared<SamplingSetupParams>();

samplingParams->runtime_top_k = setupParams->samplingParams.runtime_top_k;
samplingParams->runtime_top_p = setupParams->samplingParams.runtime_top_p;
samplingParams->randomSeed = setupParams->randomSeed;

samplingParams->top_p_decay = setupParams->samplingParams.top_p_decay;
samplingParams->top_p_min = setupParams->samplingParams.top_p_min;
samplingParams->top_p_reset_ids = setupParams->samplingParams.top_p_reset_ids;
samplingParams->normalize_log_probs = setupParams->samplingParams.normalize_log_probs;
samplingParams->outputLogProbs = setupParams->samplingParams.outputLogProbs;
samplingParams->cumLogProbs = setupParams->samplingParams.cumLogProbs;

mDecodingLayer->setup(batchSize, beamWidth, batchSlots, samplingParams);
mDecodingLayer->setup(batchSize, beamWidth, batchSlots, setupParams->decodingParams);
}
else if (mDecodingMode.isBeamSearch())
{ // beam search layer
TLLM_CHECK_WITH_INFO(beamWidth > 1, "Decoding mode is beam search, but beamWidth <= 1 (%d <= 1)", beamWidth);
auto beamSearchParams = std::make_shared<BeamSearchSetupParams>();

beamSearchParams->beam_search_diversity_rate = setupParams->beamSearchParams.beam_search_diversity_rate;
beamSearchParams->length_penalty = setupParams->beamSearchParams.length_penalty;
beamSearchParams->early_stopping = setupParams->beamSearchParams.early_stopping;
beamSearchParams->hasDiffRuntimeArgs = hasDiffRuntimeArgs(setupParams);

mDecodingLayer->setup(batchSize, beamWidth, nullptr, beamSearchParams);
mDecodingLayer->setup(batchSize, beamWidth, nullptr, setupParams->decodingParams);
}
else if (mDecodingMode.isMedusa())
{
auto medusaSetupParams = std::make_shared<MedusaSetupParams>();
medusaSetupParams->runtimeTopK = setupParams->samplingParams.runtime_top_k;
medusaSetupParams->runtimeHeadsTopK = setupParams->medusaParams.topKMedusaHeads;
medusaSetupParams->randomSeed = setupParams->randomSeed;
mDecodingLayer->setup(batchSize, beamWidth, batchSlots, medusaSetupParams);
TLLM_CHECK_WITH_INFO(beamWidth == 1, "Decoding mode is Medusa, but beamWidth != 1 (%d != 1)", beamWidth);
mDecodingLayer->setup(batchSize, beamWidth, batchSlots, setupParams->decodingParams);
}
else if (mDecodingMode.isLookahead())
{
TLLM_CHECK_WITH_INFO(beamWidth == 1, "Decoding mode is Lookahead, but beamWidth != 1 (%d != 1)", beamWidth);
// TODO(nkorobov) add lookahead layer
}
else if (mDecodingMode.isExplicitDraftTokens())
{
auto explicitDraftTokensSetupParams = std::make_shared<ExplicitDraftTokensSetupParams>();
explicitDraftTokensSetupParams->temperature = setupParams->penaltyParams.temperature;
explicitDraftTokensSetupParams->randomSeed = setupParams->randomSeed;
mDecodingLayer->setup(batchSize, /* beamWidth */ 1, batchSlots, explicitDraftTokensSetupParams);
TLLM_CHECK_WITH_INFO(
beamWidth == 1, "Decoding mode is ExplicitDraftTokens, but beamWidth != 1 (%d != 1)", beamWidth);
mDecodingLayer->setup(batchSize, beamWidth, batchSlots, setupParams->decodingParams);
}
else
{
@ -178,7 +153,7 @@ void DecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeTyp

template <typename T>
void DecodingLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto [outputParams, inputParams] = prepareParams(baseOutputs, baseInputs);
@ -188,7 +163,7 @@ void DecodingLayer<T>::forwardAsync(

template <typename T>
void DecodingLayer<T>::forwardSync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto [outputParams, inputParams] = prepareParams(baseOutputs, baseInputs);
@ -197,31 +172,30 @@ void DecodingLayer<T>::forwardSync(
}

template <typename T>
std::tuple<std::shared_ptr<BaseOutputParams>, std::shared_ptr<BaseInputParams>> DecodingLayer<T>::prepareParams(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs) const
std::tuple<std::shared_ptr<BaseDecodingOutputs>, std::shared_ptr<BaseDecodingInputs>> DecodingLayer<T>::prepareParams(
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs,
std::shared_ptr<BaseDecodingInputs> const& baseInputs) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
auto params = std::dynamic_pointer_cast<DynamicDecodeInputParams>(baseInputs);
auto params = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);

auto const localDecoderDomain = getLocalDecoderDomain(params, mDecoderDomain);
auto const maxSeqLen = outputs->output_ids.shape[outputs->output_ids.shape.size() - 1];
auto const& endIds = params->end_ids;
auto const maxSeqLen = baseOutputs->outputIds.shape[baseOutputs->outputIds.shape.size() - 1];
auto const& endIds = params->endIds;

std::shared_ptr<BaseOutputParams> preparedOutputs;
std::shared_ptr<BaseInputParams> preparedInputs;
std::shared_ptr<BaseDecodingOutputs> preparedOutputs;
std::shared_ptr<BaseDecodingInputs> preparedInputs;

// dynamic decode GPT
if (mDecodingMode.isBeamSearch())
{
preparedInputs = baseInputs;
preparedOutputs = baseOutputs;
}
else if (mDecodingMode.isTopKorTopP())
{ // beamWidth == 1
{
auto const ite = params->ite;
auto const step = params->step;
auto const localBatchSize = static_cast<std::size_t>(params->local_batch_size);
auto const localBatchSize = static_cast<std::size_t>(params->localBatchSize);

TLLM_CHECK_WITH_INFO(localDecoderDomain.getBeamWidth() == 1,
"Decoding mode is TopK and/or TopP, but beamWidth != 1 (%d != 1)", localDecoderDomain.getBeamWidth());
@ -231,60 +205,29 @@ std::tuple<std::shared_ptr<BaseOutputParams>, std::shared_ptr<BaseInputParams>>
Tensor const logitsSlice{params->logits->slice(
{localBatchSize, static_cast<size_t>(localDecoderDomain.getBeamWidth()), params->logits->shape[2]}, 0)};
Tensor const endIdSlice{endIds.slice({localBatchSize}, 0)};
auto decodeInputs = std::make_shared<SamplingInputParams>(
step, ite, logitsSlice, endIdSlice, static_cast<SizeType32>(maxSeqLen));
auto decodeInputs = std::make_shared<SamplingInputs>(endIdSlice, step, ite, localBatchSize);

decodeInputs->finished = params->finished;

if (params->input_lengths)
decodeInputs->logits = logitsSlice;

if (params->inputLengths)
{
auto& inputLengths = params->input_lengths.value();
decodeInputs->input_lengths
auto& inputLengths = params->inputLengths.value();
decodeInputs->inputLengths
= inputLengths.slice({localBatchSize, static_cast<size_t>(localDecoderDomain.getBeamWidth())}, 0);
}
decodeInputs->batch_slots = params->batch_slots;

auto decodeOutputs = std::make_shared<SamplingOutputParams>(outputs->output_ids);
decodeOutputs->output_ids_ptr = std::move(outputs->output_ids_ptr);
if (outputs->sequence_length)
{
decodeOutputs->sequence_length
= outputs->sequence_length->slice({localBatchSize * localDecoderDomain.getBeamWidth()}, 0);
}
if (outputs->finished)
{
decodeOutputs->finished = outputs->finished->slice({localBatchSize * localDecoderDomain.getBeamWidth()}, 0);
}
if (outputs->cum_log_probs)
{
decodeOutputs->cum_log_probs
= outputs->cum_log_probs->slice({localBatchSize * localDecoderDomain.getBeamWidth()}, 0);
}
if (outputs->output_log_probs_tiled)
{
Tensor& output_log_probs = outputs->output_log_probs_tiled.value();
decodeOutputs->output_log_probs
= output_log_probs.slice({1, localBatchSize * localDecoderDomain.getBeamWidth()}, 0);
}
decodeInputs->batchSlots = params->batchSlots;

preparedInputs = decodeInputs;
preparedOutputs = decodeOutputs;
preparedOutputs = baseOutputs;
}
else if (mDecodingMode.isMedusa())
{
TLLM_CHECK_WITH_INFO(localDecoderDomain.getBeamWidth() == 1,
"Decoding mode is Medusa, but beamWidth != 1 (%d != 1)", localDecoderDomain.getBeamWidth());

auto medusaInputParams = std::make_shared<MedusaInputParams>(params->logits.value(), endIds);
medusaInputParams->finished = outputs->finished.value();
medusaInputParams->batch_slots = params->batch_slots;
medusaInputParams->paths = params->medusaInputs->medusaPaths;
medusaInputParams->medusaLogits = params->medusaInputs->medusaLogits;
medusaInputParams->medusaCurTokensPerStep = params->medusaInputs->medusaCurTokensPerStep;
medusaInputParams->medusaTargetTokensPerStep = params->medusaInputs->medusaTargetTokensPerStep;
medusaInputParams->treeIds = params->medusaInputs->medusaTreeIds;

preparedInputs = medusaInputParams;
preparedInputs = baseInputs;
preparedOutputs = baseOutputs;
}
else if (mDecodingMode.isLookahead())
@ -311,5 +254,4 @@ std::tuple<std::shared_ptr<BaseOutputParams>, std::shared_ptr<BaseInputParams>>
template class DecodingLayer<float>;
template class DecodingLayer<half>;

} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

|
||||
|
||||
#pragma once
|
||||
|
||||
#include <curand_kernel.h>
|
||||
|
||||
#include "tensorrt_llm/common/tensor.h"
|
||||
#include "tensorrt_llm/executor/types.h"
|
||||
#include "tensorrt_llm/layers/baseLayer.h"
|
||||
#include "tensorrt_llm/layers/beamSearchLayer.h"
|
||||
#include "tensorrt_llm/layers/decodingParams.h"
|
||||
#include "tensorrt_llm/layers/explicitDraftTokensLayer.h"
|
||||
#include "tensorrt_llm/layers/medusaDecodingLayer.h"
|
||||
#include "tensorrt_llm/layers/samplingLayer.h"
|
||||
|
||||
namespace tc = tensorrt_llm::common;
|
||||
#include <curand_kernel.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace layers
|
||||
namespace tensorrt_llm::layers
|
||||
{
|
||||
|
||||
//! \brief Layer performs token decoding using sampling (beamWidth=1), beam search (beamWidth>1) or Medusa.
|
||||
@ -46,19 +37,21 @@ public:
|
||||
~DecodingLayer() override = default;
|
||||
|
||||
void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
|
||||
std::shared_ptr<BaseSetupParams> setupParams) override;
|
||||
std::shared_ptr<BaseSetupParams> const& setupParams) override;
|
||||
|
||||
//! \brief Calls single SamplingLayer::forwardAsync or MedusaDecodingLayer::forwardAsync in batched mode
|
||||
//! or runs BeamSearchLayer::forwardAsync in the loop for each request.
|
||||
//! Modifies outputs->logits in-place.
|
||||
void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
|
||||
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
|
||||
std::shared_ptr<BaseDecodingInputs> const& inputs) override;
|
||||
|
||||
//! \brief Calls forwardSync of configired decoding layer.
|
||||
void forwardSync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
|
||||
void forwardSync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
|
||||
std::shared_ptr<BaseDecodingInputs> const& inputs) override;
|
||||
|
||||
private:
|
||||
std::tuple<std::shared_ptr<BaseOutputParams>, std::shared_ptr<BaseInputParams>> prepareParams(
|
||||
std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) const;
|
||||
[[nodiscard]] std::tuple<std::shared_ptr<BaseDecodingOutputs>, std::shared_ptr<BaseDecodingInputs>> prepareParams(
|
||||
std::shared_ptr<BaseDecodingOutputs> const& outputs, std::shared_ptr<BaseDecodingInputs> const& inputs) const;
|
||||
|
||||
private:
|
||||
using BaseLayer::mWorkspaceSize;
|
||||
@ -74,5 +67,4 @@ private:
|
||||
std::unique_ptr<BaseLayer> mDecodingLayer;
|
||||
};
|
||||
|
||||
} // namespace layers
|
||||
} // namespace tensorrt_llm
|
||||
} // namespace tensorrt_llm::layers
|
||||
|
||||
@ -16,12 +16,16 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/kernels/beamSearchKernels.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/request.h"
|
||||
#include <tensorrt_llm/common/tensor.h>
|
||||
#include <tensorrt_llm/runtime/common.h>
|
||||
#include <tensorrt_llm/runtime/speculativeDecodingModule.h>
|
||||
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace tc = tensorrt_llm::common;
|
||||
@ -42,10 +46,11 @@ namespace tensorrt_llm::layers
|
||||
//! It is passed through `setup` method.
|
||||
//! 3. `forwardBatchSize` for layers forwarding for a batch of existing active requests.
|
||||
//! it is passed through `forwardAsync` and `forwardSync` methods.
|
||||
//! `setup` and `forward` always provide `batch_slots` indexed by
|
||||
//! `setup` and `forward` always provide `batchSlots` indexed by
|
||||
//! local batch index ranging in [0, setupBatchSize) or [0, forwardBatchSize),
|
||||
//! holding the global batch index ranging in [0, maxBatchSize).
|
||||
//! In case of beam search, maxBatchSize = forwardBatchSize = 1.
|
||||
|
||||
class DecoderDomain
|
||||
{
|
||||
public:
|
||||
@ -56,7 +61,7 @@ public:
|
||||
, mBeamWidth(beamWidth)
|
||||
, mVocabSize(vocabSize)
|
||||
, mVocabSizePadded(vocabSizePadded.value_or(vocabSize))
|
||||
, mSpeculativeDecodingModule(speculativeDecodingModule)
|
||||
, mSpeculativeDecodingModule(std::move(speculativeDecodingModule))
|
||||
{
|
||||
}
|
||||
|
||||
@ -91,6 +96,11 @@ public:
|
||||
return mSpeculativeDecodingModule;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::shared_ptr<runtime::SpeculativeDecodingModule const> getSpeculativeDecodingModulePtr() const
|
||||
{
|
||||
return mSpeculativeDecodingModule;
|
||||
}
|
||||
|
||||
private:
|
||||
runtime::SizeType32 mBatchSize;
|
||||
runtime::SizeType32 mBeamWidth;
|
||||
@ -102,245 +112,444 @@ private:
|
||||
class BaseSetupParams
|
||||
{
|
||||
public:
|
||||
virtual ~BaseSetupParams() {}
|
||||
virtual ~BaseSetupParams() = default;
|
||||
};
|
||||
|
||||
// Penalty layer
|
||||
class PenaltySetupParams : public BaseSetupParams
|
||||
{
|
||||
public:
|
||||
std::optional<std::vector<float>> temperature; // [1] or [setupBatchSize] on cpu
|
||||
std::optional<std::vector<runtime::SizeType32>> minLength; // [1] or [setupBatchSize] on cpu
|
||||
std::optional<std::vector<float>> repetitionPenalty; // [1] or [setupBatchSize] on cpu
|
||||
std::optional<std::vector<float>> presencePenalty; // [1] or [setupBatchSize] on cpu
|
||||
std::optional<std::vector<float>> frequencyPenalty; // [1] or [setupBatchSize] on cpu
|
||||
};
|
||||
|
||||
// Ban words layer
|
||||
class BanWordsSetupParams : public BaseSetupParams
|
||||
{
|
||||
public:
|
||||
std::optional<std::vector<runtime::SizeType32>> noRepeatNgramSize; // [1] or [setupBatchSize] on cpu
|
||||
};
|
||||
|
||||
class DecodingSetupParams : public BaseSetupParams
|
||||
{
|
||||
public:
|
||||
virtual ~DecodingSetupParams() = default;
|
||||
|
||||
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [setupBatchSize] on cpu
|
||||
std::optional<std::vector<bool>> outputLogProbs; // [setupBatchSize]
|
||||
std::optional<std::vector<bool>> cumLogProbs; // [setupBatchSize]
|
||||
};
|
||||
|
||||
class SamplingSetupParams : public DecodingSetupParams
|
||||
{
|
||||
public:
|
||||
// baseSamplingLayer
|
||||
std::optional<std::vector<runtime::SizeType32>> runtimeTopK; // [1] or [setupBatchSize] on cpu
|
||||
std::optional<std::vector<float>> runtimeTopP; // [1] or [setupBatchSize] on cpu
|
||||
|
||||
// topPSamplingLayer
|
||||
std::optional<std::vector<float>> topPDecay; // [setupBatchSize], must between [0, 1]
|
||||
std::optional<std::vector<float>> topPMin; // [setupBatchSize], must between [0, 1]
|
||||
std::optional<std::vector<runtime::TokenIdType>> topPResetIds; // [setupBatchSize]
|
||||
std::optional<bool> normalizeLogProbs;
|
||||
};
|
||||
|
||||
class BeamSearchSetupParams : public DecodingSetupParams
|
||||
{
|
||||
public:
|
||||
// BeamSearchLayer
|
||||
std::optional<std::vector<float>> beamSearchDiversityRate; // [setupBatchSize] on cpu
|
||||
std::optional<std::vector<float>> lengthPenalty; // [setupBatchSize] on cpu
|
||||
std::optional<std::vector<int>> earlyStopping; // [setupBatchSize] on cpu
|
||||
bool hasDiffRuntimeArgs{false};
|
||||
};
|
||||
|
||||
class MedusaSetupParams : public DecodingSetupParams
|
||||
{
|
||||
public:
|
||||
// Medusa params
|
||||
std::optional<std::vector<runtime::SizeType32>> runtimeTopK; // [setupBatchSize] on cpu
|
||||
std::optional<std::vector<std::vector<runtime::SizeType32>>> runtimeHeadsTopK; // [setupBatchSize, maxMedusaHeads]
|
||||
};
|
||||
|
||||
class ExplicitDraftTokensSetupParams : public DecodingSetupParams
|
||||
{
|
||||
public:
|
||||
std::optional<std::vector<float>> temperature; // [setupBatchSize] on cpu
|
||||
// Hack to init some data for the context phase in the setup.
|
||||
tc::Tensor randomDataSample; // [maxBatchSize], on gpu
|
||||
tc::Tensor temperatures; // [maxBatchSize], on gpu
|
||||
};
|
||||
|
||||
class DynamicDecodeSetupParams : public BaseSetupParams
{
public:
// Penalty layer
struct PenaltyParams
{
std::optional<std::vector<float>> temperature; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<runtime::SizeType32>> minLength; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<float>> repetitionPenalty; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<float>> presencePenalty; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<float>> frequencyPenalty; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<runtime::SizeType32>> noRepeatNgramSize; // [1] or [setupBatchSize] on cpu
};
std::shared_ptr<PenaltySetupParams> penaltyParams;

struct SamplingParams
{
// baseSamplingLayer
std::optional<std::vector<runtime::SizeType32>> runtime_top_k; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<float>> runtime_top_p; // [1] or [setupBatchSize] on cpu
std::shared_ptr<BanWordsSetupParams> banWordsParams;

// topPSamplingLayer
std::optional<std::vector<float>> top_p_decay; // [setupBatchSize], must between [0, 1]
std::optional<std::vector<float>> top_p_min; // [setupBatchSize], must between [0, 1]
std::optional<std::vector<runtime::TokenIdType>> top_p_reset_ids; // [setupBatchSize]
std::optional<bool> normalize_log_probs;
std::optional<std::vector<bool>> outputLogProbs; // [setupBatchSize]
std::optional<std::vector<bool>> cumLogProbs; // [setupBatchSize]
};

struct BeamSearchParams
{
// BeamSearchLayer
std::optional<std::vector<float>> beam_search_diversity_rate; // [setupBatchSize] on cpu
std::optional<std::vector<float>> length_penalty; // [setupBatchSize] on cpu
std::optional<std::vector<int>> early_stopping; // [setupBatchSize] on cpu
};

struct MedusaParams
{
// Medusa params
std::optional<std::vector<std::vector<runtime::SizeType32>>>
topKMedusaHeads; // [setupBatchSize, maxMedusaHeads]
};

std::optional<std::vector<uint64_t>> randomSeed; // [1] or [setupBatchSize] on cpu

PenaltyParams penaltyParams;

SamplingParams samplingParams;

BeamSearchParams beamSearchParams;

MedusaParams medusaParams;
std::shared_ptr<DecodingSetupParams> decodingParams;
};

class BaseInputParams
class LookaheadSetupParams : public DecodingSetupParams
{
public:
explicit BaseInputParams(runtime::SizeType32 step, runtime::SizeType32 ite, tc::Tensor endIds)
: step{step}
, ite{ite}
, end_ids{std::move(endIds)}
std::vector<runtime::ITensor::SharedConstPtr> prompt; // [batchSize][maxSeqLen] on cpu
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [batchSize] on cpu
std::vector<executor::LookaheadDecodingConfig> algoConfigs; // [1 or batchSize] on cpu
};

class BaseDecodingInputs
{
public:
BaseDecodingInputs(runtime::SizeType32 localBatchSize)
: localBatchSize(localBatchSize)
{
}

virtual ~BaseInputParams() {}
virtual ~BaseDecodingInputs() = default;

// mandatory parameters
runtime::SizeType32 localBatchSize;
};

// Ban words inputs
class BanWordsDecodingInputs : public BaseDecodingInputs
{
public:
BanWordsDecodingInputs(runtime::SizeType32 localBatchSize)
: BaseDecodingInputs(localBatchSize)
{
}

runtime::SizeType32 maxBadWordsLen{0};
//! [maxBatchSize][2, bad_words_length], on gpu
std::optional<tc::Tensor> badWordsPtr;
//! [maxBatchSize], on gpu
std::optional<tc::Tensor> badWordsLengths;
};

// Stop criteria inputs
class StopCriteriaDecodingInputs : public BaseDecodingInputs
{
public:
StopCriteriaDecodingInputs(runtime::SizeType32 localBatchSize)
: BaseDecodingInputs(localBatchSize)
{
}

runtime::SizeType32 maxStopWordsLen{0};
//! [maxBatchSize], on gpu
std::optional<tc::Tensor> sequenceLimitLength;
//! [maxBatchSize][2, stop_words_length], on gpu
std::optional<tc::Tensor> stopWordsPtr;
//! [maxBatchSize], on gpu
std::optional<tc::Tensor> stopWordsLengths;
};

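For orientation only (illustrative values, not part of this diff; decodingLayer, batchSlots and batchSize are assumed to exist): the reworked DynamicDecodeSetupParams now carries its per-concern pieces as shared pointers, and DecodingLayer<T>::setup expects a mode-specific object in decodingParams, e.g. a SamplingSetupParams for TopK/TopP decoding.

// Sketch only.
auto setupParams = std::make_shared<DynamicDecodeSetupParams>();
setupParams->penaltyParams = std::make_shared<PenaltySetupParams>();
setupParams->penaltyParams->temperature = std::vector<float>{0.8f};
setupParams->banWordsParams = std::make_shared<BanWordsSetupParams>();
setupParams->banWordsParams->noRepeatNgramSize = std::vector<runtime::SizeType32>{3};

auto sampling = std::make_shared<SamplingSetupParams>();
sampling->runtimeTopK = std::vector<runtime::SizeType32>{40};
sampling->runtimeTopP = std::vector<float>{0.9f};
setupParams->decodingParams = sampling; // DecodingLayer<T>::setup asserts this is set

decodingLayer->setup(batchSize, /*beamWidth=*/1, batchSlots, setupParams);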
class DecodingInputs : public BaseDecodingInputs
|
||||
{
|
||||
public:
|
||||
DecodingInputs(tc::Tensor endIds, runtime::SizeType32 step = 0, runtime::SizeType32 ite = 0,
|
||||
runtime::SizeType32 localBatchSize = 0, runtime::SizeType32 maxAttentionWindow = 0,
|
||||
runtime::SizeType32 sinkTokenLength = 0)
|
||||
: BaseDecodingInputs(localBatchSize)
|
||||
, endIds{std::move(endIds)}
|
||||
, step{step}
|
||||
, ite{ite}
|
||||
, maxAttentionWindow{maxAttentionWindow}
|
||||
, sinkTokenLength{sinkTokenLength}
|
||||
{
|
||||
}
|
||||
|
||||
//! [maxBatchSize]
|
||||
tc::Tensor endIds;
|
||||
|
||||
// used only for python runtime
|
||||
runtime::SizeType32 step;
|
||||
runtime::SizeType32 ite;
|
||||
tc::Tensor end_ids; // [maxBatchSize]
|
||||
std::optional<tc::Tensor> batch_slots; // [forwardBatchSize], on pinned memory
|
||||
std::optional<tc::Tensor> finished; // [maxBatchSize, maxBeamWidth]
|
||||
};
|
||||
|
||||
class DynamicDecodeInputParams : public BaseInputParams
{
public:
DynamicDecodeInputParams(runtime::SizeType32 step, runtime::SizeType32 ite, runtime::SizeType32 maxInputLength,
runtime::SizeType32 maxAttentionWindow, runtime::SizeType32 sinkTokenLength, runtime::SizeType32 localBatchSize,
tc::Tensor endIds)
: BaseInputParams(step, ite, std::move(endIds))
, max_input_length{maxInputLength}
, max_attention_window{maxAttentionWindow}
, sink_token_length{sinkTokenLength}
, local_batch_size{localBatchSize}
, max_stop_words_len{0}
, max_bad_words_len{0}
{
}

// mandatory parameters
runtime::SizeType32 max_input_length;
runtime::SizeType32 max_attention_window;
runtime::SizeType32 sink_token_length;
runtime::SizeType32 local_batch_size;
runtime::SizeType32 max_stop_words_len;
runtime::SizeType32 max_bad_words_len;
runtime::SizeType32 maxAttentionWindow;
runtime::SizeType32 sinkTokenLength;

// One of these two fields has to be set
// DynamicDecodeLayer::forward checks for it
// Need both of these fields to support legacy code during transition period to the batched decoder
std::optional<tc::Tensor> logits; // [maxBatchSize, beamWidth, vocabSizePadded]
std::optional<std::vector<tc::Tensor>> logits_vec; // [forwardBatchSize][beamWidth, vocabSizePadded], on gpu
//! One of these two fields has to be set
//! DynamicDecodeLayer::forward checks for it
//! Need both of these fields to support legacy code during transition period to the batched decoder
//! [forwardBatchSize, beamWidth, vocabSizePadded]
std::optional<tc::Tensor> logits;
//! [forwardBatchSize][beamWidth, vocabSizePadded], on gpu
std::optional<std::vector<tc::Tensor>> logitsVec;

// optional parameters
std::optional<tc::Tensor> src_cache_indirection; // [forwardBatchSize, maxBeamWidth, maxSeqLen] - the k/v cache
// index for beam search, mandatory for beam search, on gpu
std::optional<tc::Tensor> sequence_limit_length; // [maxBatchSize], on gpu
std::optional<tc::Tensor> embedding_bias; // [vocabSizePadded], on gpu
std::optional<tc::Tensor> input_lengths; // [maxBatchSize, maxBeamWidth], on gpu
std::optional<tc::Tensor> bad_words_ptr; // [maxBatchSize][2, bad_words_length], on gpu
std::optional<tc::Tensor> bad_words_lengths; // [maxBatchSize], on gpu
std::optional<tc::Tensor> stop_words_ptr; // [maxBatchSize][2, stop_words_length], on gpu
std::optional<tc::Tensor> stop_words_lengths; // [maxBatchSize], on gpu
//! the indices of the selected beams, mandatory for beam search, on gpu
//! [forwardBatchSize, maxBeamWidth, maxSeqLen]
std::optional<tc::Tensor> srcCacheIndirection;
//! [vocabSizePadded], on gpu
std::optional<tc::Tensor> embeddingBias;
//! [maxBatchSize, maxBeamWidth], on gpu
std::optional<tc::Tensor> inputLengths;
//! [forwardBatchSize], on pinned memory
std::optional<tc::Tensor> batchSlots;
//! [maxBatchSize, maxBeamWidth]
std::optional<tc::Tensor> finished;
//! [maxBatchSize], on gpu
std::optional<tc::Tensor> curTokensPerStep;

// Medusa inputs
class MedusaInputs
{
public:
tc::Tensor medusaCurTokensPerStep; // [maxBatchSize], optional, on gpu
tc::Tensor medusaTargetTokensPerStep; // [maxBatchSize], optional, on gpu
tc::Tensor medusaPaths; // [maxBatchSize, maxPathLen, maxPathLen]
// optional, on gpu
tc::Tensor medusaTreeIds; // [maxBatchSize, maxDecodingTokens], optional, on gpu
std::vector<std::vector<tc::Tensor>> medusaLogits; // [maxBatchSize][maxDraftPathLen]
// [maxDecodingTokens, vocabSizePadded], optional, on gpu
};
std::shared_ptr<BanWordsDecodingInputs> banWordsInputs;

// Explicit draft tokens inputs
// FIXME(nkorobov): this should be ExplicitDraftTokensBuffers?
class ExplicitDraftTokensInputs
{
public:
};

std::optional<MedusaInputs> medusaInputs;

std::optional<ExplicitDraftTokensInputs> explicitDraftTokensInputs;
std::shared_ptr<StopCriteriaDecodingInputs> stopCriteriaInputs;
};

class BaseOutputParams
class SamplingInputs : public DecodingInputs
{
public:
explicit BaseOutputParams(tc::Tensor outputIds)
: output_ids{std::move(outputIds)}
explicit SamplingInputs(
tc::Tensor endIds, runtime::SizeType32 step, runtime::SizeType32 ite, runtime::SizeType32 localBatchSize)
: DecodingInputs{std::move(endIds), step, ite, localBatchSize}
{
}

virtual ~BaseOutputParams() {}
//! optional parameters
//! [localBatchSize]
curandState_t* curandStates{};
//! Pointer to the workspace for sampling computation
void* samplingWorkspace{};
//! Flag to mark that logits tensor contains probabilities
bool probsComputed{};
};

// Medusa inputs
class MedusaDecodingInputs : public DecodingInputs
{
public:
explicit MedusaDecodingInputs(tc::Tensor endIds, runtime::SizeType32 localBatchSize)
: DecodingInputs(std::move(endIds), 0, 0, localBatchSize)
{
}

//! [maxBatchSize], on gpu
tc::Tensor targetTokensPerStep;
//! [maxBatchSize, maxPathLen, maxPathLen], on gpu
tc::Tensor paths;
//! [maxBatchSize, maxDecodingTokens], on gpu
tc::Tensor treeIds;
//! [maxBatchSize][maxDraftPathLen][maxDecodingTokens, vocabSizePadded], on gpu
std::vector<std::vector<tc::Tensor>> medusaLogits;
};

// Explicit draft tokens inputs
class ExplicitDraftTokensInputs : public DecodingInputs
{
public:
explicit ExplicitDraftTokensInputs(tc::Tensor endIds, runtime::SizeType32 batchSize)
: DecodingInputs(std::move(endIds), 0, 0, batchSize)
{
}

//! Draft tokens for the next iteration. The first token in each path is the last accepted token at current
//! iteration. E.g. if forwardBatchSize == 1, maxNumPaths == 2, maxPathLen == 3, [[[0, 1, 2], [0, 1, 10]]]
tc::Tensor nextDraftTokens; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Compressed form of `nextDraftTokens`, where common prefixes are collapsed.
//! Using example above [0, 1, 2, 10]
tc::Tensor nextFlatTokens; // [forwardBatchSize * maxDecodingTokens], gpu
//! Indices of draft tokens in the compressed `nextFlatTokens` for the next iteration.
//! Using example above, [[[0, 1, 2], [0, 1, 3]]]
tc::Tensor nextDraftIndices; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Probabilities of the next draft tokens.
tc::Tensor nextDraftProbs; // [forwardBatchSize, maxNumPaths, maxDraftPathLen, vocabSize], gpu
//! Same as `nextDraftTokens`, but for current iteration.
//! Current accepted tokens obtained as `lastDraftTokens[bi][bestPathIndices[bi]][1:bestPathLengths[bi]]`.
tc::Tensor lastDraftTokens; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Same as `nextDraftIndices`, but for current iteration.
tc::Tensor lastDraftIndices; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Boolean attention masks.
//! maxDecodingTokens' = generationLengths.max()
tc::Tensor masks; // [forwardBatchSize, maxDecodingTokens', maxDecodingTokens'], gpu
//! Relative to `positionIdsBase` position ids. Same as `nextFlatTokens` for next draft indices.
//! Using example above, [0, 1, 2, 3]
tc::Tensor packedPosIds; // [forwardBatchSize * maxDecodingTokens], gpu
//! Lengths of the accepted paths for each request. It is 1 for context phase (only 1 primary token is accepted).
tc::Tensor bestPathLengths; // [forwardBatchSize], gpu
//! Indices of the accepted paths for each request. It is 0 for context phase.
tc::Tensor bestPathIndices; // [forwardBatchSize], gpu
//! Number of the draft tokens for the next iteration.
tc::Tensor generationLengths; // [forwardBatchSize], gpu
//! Baseline for the position ids.
tc::Tensor positionIdsBase; // [forwardBatchSize], gpu
//! Generation length for the previous stage.
tc::Tensor lastGenerationLengths; // [forwardBatchSize], gpu
//! Maximum number of generated tokens for the next step across whole batch
tc::Tensor maxGenLengthDevice; // [1], on gpu
//! Address map to map from linear indices of the engine outputs to seqSlot.
//! It is not the same as batchSlots because it maps the ordered engine outputs to the respective seqSlot,
//! while batchSlots is just a list of active seqSlots.
tc::Tensor seqSlots; // [forwardBatchSize], on gpu
};
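
The indexing contract between `nextDraftTokens`, `nextFlatTokens` and `nextDraftIndices` documented above can be restated compactly. Below is a minimal host-side sketch (illustrative only, not part of this change; it uses std::vector stand-ins for the GPU tensors and a hypothetical helper name) that checks the relation for a single request:

#include <cassert>
#include <vector>

// For one request: draftTokens[p][i] must equal flatTokens[draftIndices[p][i]].
// With the example above, paths [[0, 1, 2], [0, 1, 10]] share the prefix [0, 1],
// so flatTokens is [0, 1, 2, 10] and draftIndices is [[0, 1, 2], [0, 1, 3]].
inline void checkFlatTokenIndexing(std::vector<std::vector<int>> const& draftTokens, // [maxNumPaths][maxPathLen]
    std::vector<int> const& flatTokens,                                              // [maxDecodingTokens]
    std::vector<std::vector<int>> const& draftIndices)                               // [maxNumPaths][maxPathLen]
{
    for (size_t p = 0; p < draftTokens.size(); ++p)
    {
        for (size_t i = 0; i < draftTokens[p].size(); ++i)
        {
            // Each (path, position) pair looks its token up in the compressed flat buffer.
            assert(draftTokens[p][i] == flatTokens[draftIndices[p][i]]);
        }
    }
}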

class LookaheadDecodingInputs : public DecodingInputs
{
using TensorConstPtr = runtime::ITensor::SharedConstPtr;

public:
explicit LookaheadDecodingInputs(tc::Tensor endIds)
: DecodingInputs{std::move(endIds)}
//, logits{logits}
{
}
// TODO(liweim) reuse base logits and curTokensPerStep.
// TensorConstPtr logits; // [batchSize, maxTokensPerStep, vocabSizePadded] on gpu
// TensorConstPtr tokensPerStep; // [maxBatchSize] on gpu
};

class BaseDecodingOutputs
{
public:
explicit BaseDecodingOutputs(tc::Tensor outputIds)
: outputIds{std::move(outputIds)}
{
}

virtual ~BaseDecodingOutputs() = default;

// mandatory parameters
tc::Tensor output_ids; // [maxBatchSize, maxSeqLen]
tc::Tensor outputIds; // [maxBatchSize, maxSeqLen]

// optional parameters
std::optional<tc::Tensor> finished; // [maxBatchSize * maxBeamWidth], optional
std::optional<tc::Tensor> sequence_length; // [maxBatchSize * maxBeamWidth], optional
std::optional<tc::Tensor> cum_log_probs; // [maxBatchSize * maxBeamWidth], necessary in beam search
std::optional<tc::Tensor> output_log_probs; // [maxBatchSize, maxBeamWidth, maxSeqLen], must be float*, optional
std::optional<tc::Tensor> parent_ids; // [maxBatchSize, maxBeamWidth, maxSeqLen], necessary in beam search
//! [maxBatchSize * maxBeamWidth], optional
std::optional<tc::Tensor> finished;
//! [maxBatchSize * maxBeamWidth], optional
std::optional<tc::Tensor> sequenceLength;
//! [maxBatchSize * maxBeamWidth], necessary in beam search
std::optional<tc::Tensor> cumLogProbs;
//! [maxBatchSize, maxBeamWidth, maxSeqLen], must be float*, optional
std::optional<tc::Tensor> outputLogProbs;
//! [maxBatchSize, maxBeamWidth, maxSeqLen], necessary in beam search
std::optional<tc::Tensor> parentIds;

tc::Tensor output_ids_ptr; // [maxBatchSize] int* (2-d array), each int* has [maxBeamWidth, maxSeqLen]
//! [maxBatchSize] int* (2-d array), each int* has [maxBeamWidth, maxSeqLen]
tc::Tensor outputIdsPtr;
//! [maxBatchSize] int* (2-d array), each int* has [maxBeamWidth, maxSeqLen]
tc::Tensor parentIdsPtr;

//!
//! \brief SpeculativeDecodingOutputs outputs.
//!
//! For one example sequence [a, b] [c] <x, y, z>, where [a, b, c] is the accepted sequence,
//! [c] is the last accepted token, and <x, y, z> are the draft tokens from `nextDraftTokens` saved by last step.
//! [c]'s position id is known, only position ids for <x, y, z> need to be provided in `nextDraftPosIds`.
//! The LLM takes {c, x, y, z} as input and generates {c', x', y', z'}.
//!
//! {c'} is always accepted and {x', z'} is supposed to be accepted.
//! The accepted tokens [c', x', z'] are saved in `output_ids` in-place, starting from `sequence_length`.
//! The `acceptedLength` is 3, and the accepted draft tokens length is 2.
//! `sequence_length` is also increased by `acceptedLength` in-place.
//! The pathsOffset is {0, 1, 3} for {c', x', z'}.
//! [] for accepted, <> for draft, {} for input/output.
//!
//! For a batchSlots {1, 3}, `acceptedLengthsCumSum` is an exclusive sum of `acceptedLength` over the batch,
//! the `acceptedLengths` may be {3, 5}, `acceptedLengthsCumSum` is {0, 3, 8}.
class SpeculativeDecodingOutputs
{
public:
tc::Tensor nextDraftTokens; // [maxBatchSize, maxDecodingDraftTokens], draft tokens for the next step
tc::Tensor nextDraftPosIds; // [maxBatchSize, maxDecodingDraftTokens], draft token position IDs
tc::Tensor nextDraftLengths; // [maxBatchSize], next step draft tokens lengths
tc::Tensor acceptedLengths; // [maxBatchSize], lengths of the accepted draft tokens + 1.
tc::Tensor acceptedLengthsCumSum; // [maxBatchSize + 1] accumulative sum along batchSlots.
tc::Tensor pathsOffsets; // [maxBatchSize, maxPathLen]
tc::Tensor packedMasks; // [maxBatchSize, maxDecodingTokens, divUp(maxDecodingTokens, 32)]
};

class ExplicitDraftTokensOutputs : public SpeculativeDecodingOutputs
{
public:
//! Draft tokens for the next iteration. The first token in each path is the last accepted token at current
//! iteration. E.g. if batchSize == 1, maxNumPaths == 2, maxPathLen == 3, [[[0, 1, 2], [0, 1, 10]]]
tc::Tensor unpackedNextDraftTokens; // [maxBatchSize, maxNumPaths, maxPathLen] on gpu
//! Indices of draft tokens in the compressed `nextFlatTokens` for the next iteration.
//! Using example above, [[[0, 1, 2], [0, 1, 3]]]
tc::Tensor unpackedNextDraftIndices; // [maxBatchSize, maxNumPaths, maxPathLen] on gpu
//! Probabilities of the next draft tokens.
tc::Tensor nextDraftProbs; // [maxBatchSize, maxNumPaths, maxPathDraftLen, vocabSize] on gpu
//! Baseline for the position ids.
tc::Tensor positionIdsBase; // [maxBatchSize] on gpu
//! Randomly sampled data (between 0.f and 1.f)
tc::Tensor randomDataSample; // [maxBatchSize] on gpu
//! Randomly sampled data (between 0.f and 1.f)
tc::Tensor randomDataValidation; // [maxBatchSize, maxNumPaths, maxDraftPathLen] on gpu
//! Sampling temperature.
tc::Tensor temperatures; // [maxBatchSize] on gpu
};

std::optional<SpeculativeDecodingOutputs> speculativeDecodingOutputs;

std::optional<ExplicitDraftTokensOutputs> explicitDraftTokensOutputs;
};

class DynamicDecodeOutputParams : public BaseOutputParams
{
public:
explicit DynamicDecodeOutputParams(tc::Tensor outputIds)
: BaseOutputParams{std::move(outputIds)}
{
}

// mandatory parameters
// Tokens predicted at current iteration.
tc::Tensor newTokens; // [maxBatchSize, maxBeamWidth]
// optional parameters
std::optional<tc::Tensor> finished_sum; // [1] in pinned host memory
std::optional<tc::Tensor> output_log_probs_tiled; // [maxSeqLen, maxBatchSize, maxBeamWidth], must be float*
std::optional<tc::Tensor>
tgt_cache_indirection; // [forwardBatchSize, maxBeamWidth, maxSeqLen], the k/v cache index for beam search
std::unique_ptr<kernels::BeamHypotheses> beamHypotheses; // structure maintains some pointers of beam search

tc::Tensor parent_ids_ptr; // [maxBatchSize] int* (2-d array), each int* has [maxBeamWidth, maxSeqLen]
// optional parameters
//! Number of tokens predicted at current iteration.
//! [maxBatchSize]
std::optional<tc::Tensor> numNewTokens;
//! [1] in pinned host memory
std::optional<tc::Tensor> finishedSum;
//! [maxSeqLen, maxBatchSize, maxBeamWidth], must be float*
std::optional<tc::Tensor> outputLogProbsTiled;
};

class BeamSearchOutputs : public BaseDecodingOutputs
{
public:
explicit BeamSearchOutputs(tc::Tensor outputIds)
: BaseDecodingOutputs{std::move(outputIds)}
{
}

//! the k/v cache index for beam search
//! [forwardBatchSize, maxBeamWidth, maxSeqLen]
tc::Tensor tgtCacheIndirection;
//! structure maintains some pointers of beam search
std::unique_ptr<kernels::BeamHypotheses> beamHypotheses;
};

//!
//! \brief SpeculativeDecodingOutputs outputs.
//!
//! For one example sequence [a, b] [c] <x, y, z>, where [a, b, c] is the accepted sequence,
//! [c] is the last accepted token, and <x, y, z> are the draft tokens from `nextDraftTokens` saved by last step.
//! [c]'s position id is known, only position ids for <x, y, z> need to be provided in `nextDraftPosIds`.
//! The LLM takes {c, x, y, z} as input and generates {c', x', y', z'}.
//!
//! {c'} is always accepted and {x', z'} is supposed to be accepted.
//! The accepted tokens [c', x', z'] are saved in `outputIds` in-place, starting from `sequenceLength`.
//! The `acceptedLength` is 3, and the accepted draft tokens length is 2.
//! `sequenceLength` is also increased by `acceptedLength` in-place.
//! The pathsOffset is {0, 1, 3} for {c', x', z'}.
//! [] for accepted, <> for draft, {} for input/output.
//!
//! For a batchSlots {1, 3}, `numNewTokensCumSum` is an exclusive sum of `numNewTokens` over the batch,
//! the `numNewTokens` may be {3, 5}, `numNewTokensCumSum` is {0, 3, 8}.
//!
//! `nextDraftLengths` and `prevDraftLengths` are needed for methods that support variable
//! draft length. `nextDraftLengths` must contain the number of draft tokens per request for the next iteration.
//! `prevDraftLengths` must contain the number of draft tokens used in the current iteration.
//!
//! `pathsOffsets` is needed for KV cache rewind. It contains the positions of the accepted draft tokens in the
//! flattened tensor of draft tokens. E.g. if for sequence {c, x, y, z} only `y` and `z` were accepted,
//! `pathsOffsets` contains [1, 2]. `pathsOffsets` is a flattened tensor for the whole batch.
//!
//! The order of `pathsOffsets` and `numNewTokensCumSum` must be aligned, such that
//! `pathsOffset[numNewTokensCumSum[bi]:numNewTokensCumSum[bi+1]]` is the slice of offsets for the `bi`th request.
//! Furthermore, the order of requests is important and must be aligned with sorted `RuntimeBuffers::seqSlots`
//! such that the request with smaller `seqSlot` stays earlier in the tensors.
//! However, this condition usually holds if the method does not expect anything from the engine other than logits.
class SpeculativeDecodingOutputs : public BaseDecodingOutputs
{
public:
explicit SpeculativeDecodingOutputs(tc::Tensor outputIds)
: BaseDecodingOutputs{std::move(outputIds)}
{
}

//! Draft tokens for the next step
// [maxBatchSize, maxDecodingDraftTokens]
tc::Tensor nextDraftTokens;
//! Draft token position IDs
//! [maxBatchSize, maxDecodingDraftTokens]
tc::Tensor nextDraftPosIds;
//! Prev step draft tokens lengths, should be filled only for variable draft length speculative decoding mode
//! [maxBatchSize]
tc::Tensor prevDraftLengths;
//! Next step draft tokens lengths, should be filled only for variable draft length speculative decoding mode
//! [maxBatchSize]
tc::Tensor nextDraftLengths;
//! Accumulative sum along batchSlots.
//! [maxBatchSize + 1]
tc::Tensor numNewTokensCumSum;
//! [maxBatchSize * maxPathLen]
tc::Tensor pathsOffsets;
//! [maxBatchSize, maxDecodingTokens, divUp(maxDecodingTokens, 32)]
tc::Tensor packedMasks;
};
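
The exclusive-sum contract described in the comment block above is easy to get wrong; the following host-side sketch (an illustration only, not part of this change, using a hypothetical helper name and std::vector stand-ins for the device tensors) shows how `numNewTokensCumSum` is derived from `numNewTokens` and how it slices `pathsOffsets`:

#include <vector>

// Exclusive sum over the active requests: with numNewTokens = {3, 5} this returns {0, 3, 8}.
// pathsOffsets[cumSum[bi]] .. pathsOffsets[cumSum[bi + 1] - 1] then hold the offsets of the
// accepted draft tokens of request bi.
inline std::vector<int> exclusiveSumNumNewTokens(std::vector<int> const& numNewTokens)
{
    std::vector<int> cumSum(numNewTokens.size() + 1, 0);
    for (size_t bi = 0; bi < numNewTokens.size(); ++bi)
    {
        // cumSum[bi + 1] marks where request bi's slice of pathsOffsets ends.
        cumSum[bi + 1] = cumSum[bi] + numNewTokens[bi];
    }
    return cumSum;
}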

class ExplicitDraftTokensOutputs : public SpeculativeDecodingOutputs
{
public:
explicit ExplicitDraftTokensOutputs(tc::Tensor outputIds)
: SpeculativeDecodingOutputs{std::move(outputIds)}
{
}

//! Draft tokens for the next iteration. The first token in each path is the last accepted token at current
//! iteration. E.g. if batchSize == 1, maxNumPaths == 2, maxPathLen == 3, [[[0, 1, 2], [0, 1, 10]]]
tc::Tensor unpackedNextDraftTokens; // [maxBatchSize, maxNumPaths, maxPathLen] on gpu
//! Indices of draft tokens in the compressed `nextFlatTokens` for the next iteration.
//! Using example above, [[[0, 1, 2], [0, 1, 3]]]
tc::Tensor unpackedNextDraftIndices; // [maxBatchSize, maxNumPaths, maxPathLen] on gpu
//! Probabilities of the next draft tokens.
tc::Tensor nextDraftProbs; // [maxBatchSize, maxNumPaths, maxPathDraftLen, vocabSize] on gpu
//! Baseline for the position ids.
tc::Tensor positionIdsBase; // [maxBatchSize] on gpu
//! Randomly sampled data (between 0.f and 1.f)
tc::Tensor randomDataSample; // [maxBatchSize] on gpu
//! Randomly sampled data (between 0.f and 1.f)
tc::Tensor randomDataValidation; // [maxBatchSize, maxNumPaths, maxDraftPathLen] on gpu
//! Sampling temperature.
tc::Tensor temperatures; // [maxBatchSize] on gpu
//! Next generation lengths.
tc::Tensor generationLengths; // [maxBatchSize] on gpu
//! Maximum number of generated tokens for the next step across whole batch
tc::Tensor maxGenLengthHost; // [1] on pinned
};

} // namespace tensorrt_llm::layers

@ -15,22 +15,20 @@
*/

#include "tensorrt_llm/layers/dynamicDecodeLayer.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/layers/beamSearchLayer.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/layers/layersFactory.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaStream.h"

#include <optional>
#include <utility>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{

template <typename T>
@ -111,17 +109,18 @@ void DynamicDecodeLayer<T>::initializeLayers()

template <typename T>
void DynamicDecodeLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> baseSetupParams)
std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto setupParams = std::dynamic_pointer_cast<DynamicDecodeSetupParams>(baseSetupParams);

if (setupParams->samplingParams.outputLogProbs)
TLLM_CHECK_WITH_INFO(setupParams->decodingParams, "decodingParams for setup is not set");
if (setupParams->decodingParams->outputLogProbs)
{
// FIXME(nkorobov): monotonically growing
mOutputLogProbs = std::any_of(setupParams->samplingParams.outputLogProbs->begin(),
setupParams->samplingParams.outputLogProbs->end(),
mOutputLogProbs = std::any_of(setupParams->decodingParams->outputLogProbs->begin(),
setupParams->decodingParams->outputLogProbs->end(),
[this](bool outputLogProbs) { return this->mOutputLogProbs | outputLogProbs; });
}

@ -153,20 +152,19 @@ void DynamicDecodeLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, Si

template <typename T>
void DynamicDecodeLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto params = std::dynamic_pointer_cast<DynamicDecodeInputParams>(baseInputs);
auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
auto params = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);

TLLM_CHECK_WITH_INFO(mDecodingMode.isExplicitDraftTokens() || params->logits || params->logits_vec,
"If not explicit Draft Tokens mode, either logits or logits_vec have to be specified.");
TLLM_CHECK_WITH_INFO(mDecodingMode.isExplicitDraftTokens() || params->logits || params->logitsVec,
"If not explicit Draft Tokens mode, either logits or logitsVec have to be specified.");
TLLM_CHECK_WITH_INFO(
outputs->sequence_length.has_value(), "sequence_length tensor is required in DynamicDecoderLayer.");
baseOutputs->sequenceLength.has_value(), "sequenceLength tensor is required in DynamicDecoderLayer.");

auto const localDecoderDomain = getLocalDecoderDomain(params, mDecoderDomain);
auto const maxSeqLen = outputs->output_ids.shape[outputs->output_ids.shape.size() - 1];
auto const maxSeqLen = baseOutputs->outputIds.shape[baseOutputs->outputIds.shape.size() - 1];

TLLM_CHECK_WITH_INFO((mConfiguredBeamWidth == 1 && localDecoderDomain.getBeamWidth() == 1)
|| (mConfiguredBeamWidth > 1 && localDecoderDomain.getBeamWidth() > 1
@ -185,12 +183,12 @@ void DynamicDecodeLayer<T>::forwardAsync(
std::vector<SizeType32> batchSlotsVec(localDecoderDomain.getBatchSize());
std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0);
auto batchSlotsHost
= params->batch_slots ? params->batch_slots->template getPtr<SizeType32 const>() : batchSlotsVec.data();
auto batchSlots = params->batch_slots ? params->batch_slots->template getPtr<SizeType32 const>() : nullptr;
= params->batchSlots ? params->batchSlots->template getPtr<SizeType32 const>() : batchSlotsVec.data();
auto batchSlots = params->batchSlots ? params->batchSlots->template getPtr<SizeType32 const>() : nullptr;

mCyclicStep = mCyclicStep % mRuntimeMaxSeqLen;
prepareIdsPtrs(
outputs, batchSlotsHost, localDecoderDomain.getBatchSize(), localDecoderDomain.getBeamWidth(), maxSeqLen);
baseOutputs, batchSlotsHost, localDecoderDomain.getBatchSize(), localDecoderDomain.getBeamWidth(), maxSeqLen);

for (auto& layer : mLayers)
{
@ -198,7 +196,7 @@ void DynamicDecodeLayer<T>::forwardAsync(
}

// Copy nextIds and transpose logits when needed
prepareOutputData(outputs, params, mIdsPtrHost, batchSlots, localDecoderDomain.getBatchSize(),
prepareOutputData(baseOutputs, params, mIdsPtrHost, batchSlots, localDecoderDomain.getBatchSize(),
mDecoderDomain.getBatchSize(), localDecoderDomain.getBeamWidth(), maxSeqLen,
mDecoderDomain.getMaxDecodingTokens(), mCyclicStep, mOutputLogProbs, mStream);

@ -210,7 +208,7 @@ void DynamicDecodeLayer<T>::forwardAsync(

template <typename T>
void DynamicDecodeLayer<T>::forwardSync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
for (auto& layer : mLayers)
@ -221,7 +219,7 @@ void DynamicDecodeLayer<T>::forwardSync(
}

template <typename T>
void DynamicDecodeLayer<T>::prepareIdsPtrs(std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
void DynamicDecodeLayer<T>::prepareIdsPtrs(std::shared_ptr<BaseDecodingOutputs> const& outputs,
SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 beamWidth, SizeType32 maxSeqLen)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@ -231,7 +229,7 @@ void DynamicDecodeLayer<T>::prepareIdsPtrs(std::shared_ptr<DynamicDecodeOutputPa
{
auto const batchSlot = batchSlots[bi];
idsPtrHost[batchSlot]
= outputs->output_ids.template getPtrWithOffset<TokenIdType>(batchSlot * beamWidth * maxSeqLen);
= outputs->outputIds.template getPtrWithOffset<TokenIdType>(batchSlot * beamWidth * maxSeqLen);
}
for (SizeType32 bi = 0; bi < batchSize; bi++)
{
@ -239,7 +237,7 @ void DynamicDecodeLayer<T>::prepareIdsPtrs(std::shared_ptr<DynamicDecodeOutputPa
if (beamWidth > 1)
{
idsPtrHost[mDecoderDomain.getBatchSize() + batchSlot]
= outputs->parent_ids.value().template getPtrWithOffset<SizeType32>(bi * beamWidth * maxSeqLen);
= outputs->parentIds.value().template getPtrWithOffset<SizeType32>(bi * beamWidth * maxSeqLen);
}
else
{
@ -247,11 +245,11 @@ void DynamicDecodeLayer<T>::prepareIdsPtrs(std::shared_ptr<DynamicDecodeOutputPa
}
}

outputs->output_ids_ptr = Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR,
outputs->outputIdsPtr = Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR,
{static_cast<size_t>(mDecoderDomain.getBatchSize()), static_cast<size_t>(beamWidth),
static_cast<size_t>(maxSeqLen)},
idsPtrHost);
outputs->parent_ids_ptr = Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR,
outputs->parentIdsPtr = Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR,
{static_cast<size_t>(mDecoderDomain.getBatchSize()), static_cast<size_t>(beamWidth),
static_cast<size_t>(maxSeqLen)},
idsPtrHost + mDecoderDomain.getBatchSize());
@ -259,29 +257,28@@ void DynamicDecodeLayer<T>::prepareIdsPtrs(std::shared_ptr<DynamicDecodeOutputPa
}

template <typename T>
void DynamicDecodeLayer<T>::prepareOutputData(std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& params, runtime::ITensor::SharedPtr const& idsPtrsHost,
void DynamicDecodeLayer<T>::prepareOutputData(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<DecodingInputs> const& params, runtime::ITensor::SharedPtr const& idsPtrsHost,
SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 beamWidth,
SizeType32 maxSeqLen, SizeType32 maxTokensPerStep, SizeType32 cyclicStep, bool outputLogProbs, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto idsPtrHostSlice = ITensor::slice(idsPtrsHost, cyclicStep, 1);
auto idsPtrHost = reinterpret_cast<TokenIdType**>(runtime::bufferCast<int64_t>(*idsPtrHostSlice));
auto const numNewTokens = outputs->speculativeDecodingOutputs
? outputs->speculativeDecodingOutputs->acceptedLengths.template getPtr<SizeType32 const>()
: nullptr;
auto const numNewTokens
= outputs->numNewTokens ? outputs->numNewTokens->template getPtr<SizeType32 const>() : nullptr;
invokeCopyNextStepIds(outputs->newTokens.template getPtr<TokenIdType>(), idsPtrHost,
outputs->sequence_length->template getPtr<SizeType32>(), numNewTokens, batchSlots, batchSize, maxBatchSize,
outputs->sequenceLength->template getPtr<SizeType32>(), numNewTokens, batchSlots, batchSize, maxBatchSize,
beamWidth, maxSeqLen, maxTokensPerStep, stream);

// Transpose output log probs from [maxSeqLen, batchSize, beamWidth] to [batchSize, beamWidth, maxSeqLen]
if (outputLogProbs && outputs->output_log_probs_tiled)
if (outputLogProbs && outputs->outputLogProbsTiled)
{
auto logProbsMaxSeqLen = outputs->output_log_probs_tiled.value().shape[0];
auto logProbsMaxSeqLen = outputs->outputLogProbsTiled.value().shape[0];

invokeTransposeLogProbs(outputs->output_log_probs.value().template getPtr<float>(),
outputs->output_log_probs_tiled.value().template getPtr<float>(),
outputs->sequence_length->template getPtr<SizeType32>(), batchSlots, batchSize, maxBatchSize, beamWidth,
invokeTransposeLogProbs(outputs->outputLogProbs.value().template getPtr<float>(),
outputs->outputLogProbsTiled.value().template getPtr<float>(),
outputs->sequenceLength->template getPtr<SizeType32>(), batchSlots, batchSize, maxBatchSize, beamWidth,
logProbsMaxSeqLen, stream);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@ -290,5 +287,4 @@ void DynamicDecodeLayer<T>::prepareOutputData(std::shared_ptr<DynamicDecodeOutpu
template class DynamicDecodeLayer<float>;
template class DynamicDecodeLayer<half>;

} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

@ -16,35 +16,14 @@

#pragma once

#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/layers/banWordsLayer.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/beamSearchLayer.h"
#include "tensorrt_llm/layers/decodingLayer.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/layers/medusaDecodingLayer.h"
#include "tensorrt_llm/layers/penaltyLayer.h"
#include "tensorrt_llm/layers/samplingLayer.h"
#include "tensorrt_llm/layers/stopCriteriaLayer.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/iTensor.h"

#include <optional>
#include <string>
#include <unordered_map>
#include <utility>

namespace tc = tensorrt_llm::common;

namespace tensorrt_llm
{
namespace kernels
{
struct BeamHypotheses;
}

namespace layers
namespace tensorrt_llm::layers
{

template <typename T>
@ -59,11 +38,13 @@ public:
~DynamicDecodeLayer() override;

void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> setupParams) override;
std::shared_ptr<BaseSetupParams> const& setupParams) override;

void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;

void forwardSync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardSync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;

// Function is only used by test.
// It is guaranteed by LayersFactory that the first layer is the Penalty layer.
@ -79,11 +60,10 @@ private:
void initialize();
void initializeLayers();

void prepareIdsPtrs(std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth,
runtime::SizeType32 maxSeqLen);
static void prepareOutputData(std::shared_ptr<DynamicDecodeOutputParams> const& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& params, runtime::ITensor::SharedPtr const& idsPtrsHost,
void prepareIdsPtrs(std::shared_ptr<BaseDecodingOutputs> const& outputs, runtime::SizeType32 const* batchSlots,
runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 maxSeqLen);
static void prepareOutputData(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<DecodingInputs> const& params, runtime::ITensor::SharedPtr const& idsPtrsHost,
runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize, runtime::SizeType32 maxBatchSize,
runtime::SizeType32 beamWidth, runtime::SizeType32 maxSeqLen, runtime::SizeType32 maxTokensPerStep,
runtime::SizeType32 cyclicStep, bool outputLogProbs, cudaStream_t stream);
@ -109,5 +89,4 @@ private:
runtime::SizeType32 mConfiguredBeamWidth{-1};
};

} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

@ -18,13 +18,10 @@
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/kernels/penaltyKernels.h"
#include "tensorrt_llm/kernels/penaltyTypes.h"
#include "tensorrt_llm/kernels/speculativeDecoding/explicitDraftTokensKernels.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iBuffer.h"

#include <algorithm>

@ -67,23 +64,28 @@ void ExplicitDraftTokensLayer<T>::allocateBuffer()

mTemperature.resize(mDecoderDomain.getBatchSize());

mScanWorkspaceSizeInBytes = invokeScanSpecDecodingGenerationLengths(
mScanWorkspaceSizeInBytes = invokeScanGenerationLengths(
nullptr, mScanWorkspaceSizeInBytes, nullptr, nullptr, mDecoderDomain.getBatchSize(), mStream);
mReduceWorkspaceSizeInBytes = invokeReduceMaxSpecDecodingGenerationLengths(
mReduceWorkspaceSizeInBytes = invokeReduceMaxGenerationLengths(
nullptr, mReduceWorkspaceSizeInBytes, nullptr, nullptr, mDecoderDomain.getBatchSize(), mStream);

mWorkspaceSizeInBytes = std::max(mScanWorkspaceSizeInBytes, mReduceWorkspaceSizeInBytes);

std::array<size_t, 6> deviceBufferSizes
std::array<size_t, 8> deviceBufferSizes
= {sizeof(curandState_t) * mDecoderDomain.getBatchSize(), sizeof(uint64_t) * mDecoderDomain.getBatchSize(),
mWorkspaceSizeInBytes, sizeof(SizeType32) * mDecoderDomain.getBatchSize(), sizeof(SizeType32),
sizeof(float) * mDecoderDomain.getBatchSize()};
sizeof(float) * mDecoderDomain.getBatchSize(), sizeof(SizeType32) * mDecoderDomain.getBatchSize(),
sizeof(SizeType32) * mDecoderDomain.getBatchSize()
* mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths()
* mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen()};
mCurandStatesDevice = mAllocator->reMalloc(mCurandStatesDevice, deviceBufferSizes[0], false);
mRandomSeedsDevice = mAllocator->reMalloc(mRandomSeedsDevice, deviceBufferSizes[1], false);
mWorkspaceDevice = mAllocator->reMalloc(mWorkspaceDevice, deviceBufferSizes[2], false);
mGenerationLengthInclusiveSum = mAllocator->reMalloc(mGenerationLengthInclusiveSum, deviceBufferSizes[3], false);
mMaxGenerationLength = mAllocator->reMalloc(mMaxGenerationLength, deviceBufferSizes[4], false);
mTemperatureDevice = mAllocator->reMalloc(mTemperatureDevice, deviceBufferSizes[5], false);
mBestPathIndicesSlots = mAllocator->reMalloc(mBestPathIndicesSlots, deviceBufferSizes[6], false);
mLastDraftIndicesSlots = mAllocator->reMalloc(mLastDraftIndicesSlots, deviceBufferSizes[7], false);

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -99,13 +101,15 @@ void ExplicitDraftTokensLayer<T>::freeBuffer()
mAllocator->free((void**) (&mGenerationLengthInclusiveSum));
mAllocator->free((void**) (&mMaxGenerationLength));
mAllocator->free((void**) (&mTemperatureDevice));
mAllocator->free((void**) (&mBestPathIndicesSlots));
mAllocator->free((void**) (&mLastDraftIndicesSlots));

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void ExplicitDraftTokensLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> baseSetupParams)
std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

@ -139,17 +143,40 @@ void ExplicitDraftTokensLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWid
fillBuffers(setupParams->temperature, DefaultDecodingParams::getTemperature(), mTemperature, mTemperatureDevice,
batchSlots, getLimitsPenalty(DecodingPenaltyType::Temperature), "temperature penalty");

fillContextBuffers(batchSize, batchSlots, *setupParams);

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void ExplicitDraftTokensLayer<T>::fillContextBuffers(
SizeType32 batchSize, SizeType32 const* batchSlots, ExplicitDraftTokensSetupParams const& setupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

FillContextExplicitDraftTokensParams<T> params;
params.randDataSample = setupParams.randomDataSample.template getPtr<T>();
params.outputTemperatures = setupParams.temperatures.template getPtr<T>();
params.inputTemperatures = mTemperatureDevice;
params.curandState = mCurandStatesDevice;
params.batchSlots = batchSlots;
params.batchSize = batchSize;

params.checkParams();

invokeFillContextBuffers(params, mStream);

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void ExplicitDraftTokensLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto inputs = std::dynamic_pointer_cast<ExplicitDraftTokensInputParams>(baseInputs);
auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
auto inputs = std::dynamic_pointer_cast<ExplicitDraftTokensInputs>(baseInputs);
auto outputs = std::dynamic_pointer_cast<ExplicitDraftTokensOutputs>(baseOutputs);

// DO NOT CHANGE THE ORDER.

@ -167,23 +194,22 @@ void ExplicitDraftTokensLayer<T>::forwardAsync(

template <typename T>
void ExplicitDraftTokensLayer<T>::convertPackedMask(
DynamicDecodeOutputParams const& outputs, ExplicitDraftTokensInputParams const& inputs)
ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto batchSlots = inputs.batch_slots->template getPtr<SizeType32 const>();
auto batchSlots = inputs.seqSlots.template getPtr<SizeType32 const>();
auto masksDevice = inputs.masks.template getPtr<bool const>();
auto specDecodingGenerationLengths = inputs.specDecodingGenerationLengths.template getPtr<SizeType32 const>();
auto packedMasksDevice = outputs.explicitDraftTokensOutputs->packedMasks.template getPtr<SizeType32>();
auto generationLengths = inputs.generationLengths.template getPtr<SizeType32 const>();
auto packedMasksDevice = outputs.packedMasks.template getPtr<SizeType32>();

auto const batchSize = inputs.batch_slots->shape[0];
auto const batchSize = inputs.localBatchSize;

invokeScanReduceSpecDecodingGenerationLengths(batchSize, specDecodingGenerationLengths, mWorkspaceDevice,
mScanWorkspaceSizeInBytes, mGenerationLengthInclusiveSum, mWorkspaceDevice, mReduceWorkspaceSizeInBytes,
mMaxGenerationLength, mStream);
invokeScanReduceGenerationLengths(batchSize, generationLengths, mWorkspaceDevice, mScanWorkspaceSizeInBytes,
mGenerationLengthInclusiveSum, mWorkspaceDevice, mReduceWorkspaceSizeInBytes, mMaxGenerationLength, mStream);

invokeConvertSpecDecodingMaskToPackedMask(batchSize, mGenerationLengthInclusiveSum, mMaxGenerationLength,
masksDevice, batchSlots, mDecoderDomain.getSpeculativeDecodingModule()->getMaxDecodingDraftTokens(),
invokeConvertMaskToPackedMask(batchSize, mGenerationLengthInclusiveSum, mMaxGenerationLength, masksDevice,
batchSlots, mDecoderDomain.getSpeculativeDecodingModule()->getMaxDecodingDraftTokens(),
mDecoderDomain.getSpeculativeDecodingModule()->getMaxDecodingTokens(), packedMasksDevice, mStream);

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@ -191,32 +217,34@@ void ExplicitDraftTokensLayer<T>::convertPackedMask(

template <typename T>
void ExplicitDraftTokensLayer<T>::splitInputDataToBatchSlots(
DynamicDecodeOutputParams const& outputs, ExplicitDraftTokensInputParams const& inputs)
ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto const batchSize = inputs.batch_slots->shape[0];
auto const maxSeqLen = outputs.output_ids.shape[outputs.output_ids.shape.size() - 1];
auto const batchSize = inputs.localBatchSize;
auto const maxSeqLen = outputs.outputIds.shape[outputs.outputIds.shape.size() - 1];

ExtractExplicitDraftTokensParams<T> params;

params.outputIds = outputs.output_ids.template getPtr<TokenIdType>();
params.outputPositionIdsBase = outputs.explicitDraftTokensOutputs->positionIdsBase.template getPtr<SizeType32>();
params.outputPositionIds = outputs.explicitDraftTokensOutputs->nextDraftPosIds.template getPtr<SizeType32>();
params.outputNextDraftTokens = outputs.explicitDraftTokensOutputs->nextDraftTokens.template getPtr<TokenIdType>();
params.unpackedNextDraftTokens
= outputs.explicitDraftTokensOutputs->unpackedNextDraftTokens.template getPtr<TokenIdType>();
params.unpackedNextDraftIndices
= outputs.explicitDraftTokensOutputs->unpackedNextDraftIndices.template getPtr<SizeType32>();
params.acceptedLengths = outputs.explicitDraftTokensOutputs->acceptedLengths.template getPtr<SizeType32>();
params.nextDraftLengths = outputs.explicitDraftTokensOutputs->nextDraftLengths.template getPtr<SizeType32>();
params.sequenceLengths = outputs.sequence_length->template getPtr<SizeType32>();
params.randDataSample = outputs.explicitDraftTokensOutputs->randomDataSample.template getPtr<T>();
params.randDataVerification = outputs.explicitDraftTokensOutputs->randomDataValidation.template getPtr<T>();
params.outputDraftProbs = outputs.explicitDraftTokensOutputs->nextDraftProbs.template getPtr<T>();
params.outputTemperatures = outputs.explicitDraftTokensOutputs->temperatures.template getPtr<T>();
params.outputIds = outputs.outputIds.template getPtr<TokenIdType>();
params.outputPositionIdsBase = outputs.positionIdsBase.template getPtr<SizeType32>();
params.outputPositionIds = outputs.nextDraftPosIds.template getPtr<SizeType32>();
params.outputNextDraftTokens = outputs.nextDraftTokens.template getPtr<TokenIdType>();
params.unpackedNextDraftTokens = outputs.unpackedNextDraftTokens.template getPtr<TokenIdType>();
params.unpackedNextDraftIndices = outputs.unpackedNextDraftIndices.template getPtr<SizeType32>();
params.acceptedLengths = outputs.numNewTokens->template getPtr<SizeType32>();
params.nextDraftLengths = outputs.nextDraftLengths.template getPtr<SizeType32>();
params.prevDraftLengths = outputs.prevDraftLengths.template getPtr<SizeType32>();
params.sequenceLengths = outputs.sequenceLength->template getPtr<SizeType32>();
params.randDataSample = outputs.randomDataSample.template getPtr<T>();
params.randDataVerification = outputs.randomDataValidation.template getPtr<T>();
params.outputDraftProbs = outputs.nextDraftProbs.template getPtr<T>();
params.outputTemperatures = outputs.temperatures.template getPtr<T>();
params.outputGenerationLengths = outputs.generationLengths.template getPtr<SizeType32>();
params.outputBestPathIndices = mBestPathIndicesSlots;
params.outputLastDraftIndices = mLastDraftIndicesSlots;

params.batchSlots = inputs.batch_slots->template getPtr<SizeType32 const>();
params.batchSlots = inputs.seqSlots.template getPtr<SizeType32 const>();
params.nextDraftTokens = inputs.nextDraftTokens.template getPtr<TokenIdType const>();
params.lastDraftTokens = inputs.lastDraftTokens.template getPtr<TokenIdType const>();
params.inputUnpackedNextDraftIndices = inputs.nextDraftIndices.template getPtr<SizeType32 const>();
@ -226,15 +254,32 @@ void ExplicitDraftTokensLayer<T>::splitInputDataToBatchSlots(
params.packedPositionIds = inputs.packedPosIds.template getPtr<SizeType32 const>();
params.nextFlatTokens = inputs.nextFlatTokens.template getPtr<TokenIdType const>();
params.nextDraftProbs = inputs.nextDraftProbs.template getPtr<T const>();
params.lastGenerationLengths = inputs.lastGenerationLengths.template getPtr<SizeType32 const>();
params.generationLengthInclusiveSum = mGenerationLengthInclusiveSum;
params.lastDraftIndices = inputs.lastDraftIndices.template getPtr<SizeType32 const>();
params.inputTemperatures = mTemperatureDevice;
params.curandState = mCurandStatesDevice;
params.curandState = mCurandStatesDevice;
params.batchSize = batchSize;
params.numPaths = mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths();
params.maxPathLength = mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen();
params.maxSeqLen = maxSeqLen;
params.vocabSize = mDecoderDomain.getVocabSizePadded();
params.numContextRequests = batchSize - inputs.lastDraftTokens.shape[0];
params.numGenerationRequests = inputs.lastDraftTokens.shape[0];

params.checkParams();

// Copy max generation length
cudaMemcpyAsync(outputs.maxGenLengthHost.template getPtr<SizeType32>(),
inputs.maxGenLengthDevice.template getPtr<SizeType32 const>(), sizeof(SizeType32), cudaMemcpyDeviceToHost,
mStream);

params.checkParams();

// Copy max generation length
cudaMemcpyAsync(outputs.maxGenLengthHost.template getPtr<SizeType32>(),
inputs.maxGenLengthDevice.template getPtr<SizeType32 const>(), sizeof(SizeType32), cudaMemcpyDeviceToHost,
mStream);

invokeExtractExplicitDraftTokens(params, mStream);

@ -245,28 +290,25 @@ void ExplicitDraftTokensLayer<T>::splitInputDataToBatchSlots(

template <typename T>
void ExplicitDraftTokensLayer<T>::packAcceptedPaths(
DynamicDecodeOutputParams const& outputs, ExplicitDraftTokensInputParams const& inputs)
ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto const batchSize = inputs.batch_slots->shape[0];
auto const batchSize = inputs.localBatchSize;

auto paths = inputs.lastDraftIndices.template getPtr<SizeType32 const>();
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType32 const>() : nullptr;
auto acceptedLengths = outputs.explicitDraftTokensOutputs->acceptedLengths.template getPtr<SizeType32 const>();
auto acceptedLengthsCumSum
= outputs.explicitDraftTokensOutputs->acceptedLengthsCumSum.template getPtr<SizeType32>();
auto pathsOffsets = outputs.explicitDraftTokensOutputs->pathsOffsets.template getPtr<SizeType32>();
auto bestPathIndices = inputs.bestPathIndices.template getPtr<SizeType32 const>();
auto numNewTokens = outputs.numNewTokens->template getPtr<SizeType32 const>();
auto numNewTokensCumSum = outputs.numNewTokensCumSum.template getPtr<SizeType32>();
auto pathsOffsets = outputs.pathsOffsets.template getPtr<SizeType32>();
auto batchSlots = inputs.batchSlots->template getPtr<SizeType32 const>();

TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for ExplicitDraftTokensLayer");
TLLM_CHECK_WITH_INFO(acceptedLengths != nullptr, "Accepted lengths must be provided for ExplicitDraftTokensLayer");
TLLM_CHECK_WITH_INFO(numNewTokens != nullptr, "Accepted lengths must be provided for ExplicitDraftTokensLayer");
TLLM_CHECK_WITH_INFO(
acceptedLengthsCumSum != nullptr, "acceptedLengthsCumSum must be provided for ExplicitDraftTokensLayer");
numNewTokensCumSum != nullptr, "numNewTokensCumSum must be provided for ExplicitDraftTokensLayer");
TLLM_CHECK_WITH_INFO(pathsOffsets != nullptr, "pathsOffsets must be provided for ExplicitDraftTokensLayer");
invokePackAcceptedPaths(acceptedLengthsCumSum, pathsOffsets, acceptedLengths, bestPathIndices, paths, batchSlots,
batchSize, mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths(),
mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen(), true, mStream);
invokePackAcceptedPaths(numNewTokensCumSum, pathsOffsets, numNewTokens, mBestPathIndicesSlots,
mLastDraftIndicesSlots, batchSlots, batchSize, mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths(),
mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen(), false, mStream);

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

@ -16,68 +16,14 @@

#pragma once

#include <curand_kernel.h>

#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"

namespace tc = tensorrt_llm::common;
#include <curand_kernel.h>

namespace tensorrt_llm
namespace tensorrt_llm::layers
{
namespace layers
{

class ExplicitDraftTokensSetupParams : public BaseSetupParams
{
public:
std::optional<std::vector<float>> temperature; // [setupBatchSize] on cpu
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [setupBatchSize] on cpu
};

class ExplicitDraftTokensInputParams : public BaseInputParams
{
public:
explicit ExplicitDraftTokensInputParams()
: BaseInputParams{0, 0, tc::Tensor()}
{
}

//! Draft tokens for the next iteration. The first token in each path is the last accepted token at current
//! iteration. E.g. if forwardBatchSize == 1, maxNumPaths == 2, maxPathLen == 3, [[[0, 1, 2], [0, 1, 10]]]
tc::Tensor nextDraftTokens; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Compressed form of `nextDraftTokens`, where common prefixes are collapsed.
//! Using example above [0, 1, 2, 10]
tc::Tensor nextFlatTokens; // [forwardBatchSize * maxDecodingTokens], gpu
//! Indices of draft tokens in the compressed `nextFlatTokens` for the next iteration.
//! Using example above, [[[0, 1, 2], [0, 1, 3]]]
tc::Tensor nextDraftIndices; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Probabilities of the next draft tokens.
tc::Tensor nextDraftProbs; // [forwardBatchSize, maxNumPaths, maxDraftPathLen, vocabSize], gpu
//! Same as `nextDraftTokens`, but for current iteration.
//! Current accepted tokens obtained as `lastDraftTokens[bi][bestPathIndices[bi]][1:bestPathLengths[bi]]`.
tc::Tensor lastDraftTokens; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Same as `nextDraftIndices`, but for current iteration.
tc::Tensor lastDraftIndices; // [forwardBatchSize, maxNumPaths, maxPathLen], gpu
//! Boolean attention masks.
//! maxDecodingTokens' = specDecodingGenerationLengths.max()
tc::Tensor masks; // [forwardBatchSize, maxDecodingTokens', maxDecodingTokens'], gpu
//! Relative to `positionIdsBase` position ids. Same as `nextFlatTokens` for next draft indices.
//! Using example above, [0, 1, 2, 3]
tc::Tensor packedPosIds; // [forwardBatchSize * maxDecodingTokens], gpu
//! Lengths of the accepted paths for each request. It is 1 for context phase (only 1 primary token is accepted).
tc::Tensor bestPathLengths; // [forwardBatchSize], gpu
//! Indices of the accepted paths for each request. It is 0 for context phase.
tc::Tensor bestPathIndices; // [forwardBatchSize], gpu
//! Number of the draft tokens for the next iteration.
tc::Tensor specDecodingGenerationLengths; // [forwardBatchSize], gpu
//! Baseline for the position ids.
tc::Tensor positionIdsBase; // [forwardBatchSize], gpu
};

//! \brief Decoding layer for speculative decoding technique, when all tokens are generated, decoded and accepted in the
//! engine.
@ -94,20 +40,23 @@ public:
~ExplicitDraftTokensLayer() override;

void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> setupParams) override;
std::shared_ptr<BaseSetupParams> const& setupParams) override;

void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;

private:
void allocateBuffer();
void freeBuffer();

void convertPackedMask(DynamicDecodeOutputParams const& outputs, ExplicitDraftTokensInputParams const& inputs);
void fillContextBuffers(
SizeType32 batchSize, SizeType32 const* batchSlots, ExplicitDraftTokensSetupParams const& params);

void splitInputDataToBatchSlots(
DynamicDecodeOutputParams const& outputs, ExplicitDraftTokensInputParams const& inputs);
void convertPackedMask(ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs);

void packAcceptedPaths(DynamicDecodeOutputParams const& outputs, ExplicitDraftTokensInputParams const& inputs);
void splitInputDataToBatchSlots(ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs);

void packAcceptedPaths(ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs);

private:
using Base::mStream;
@ -129,9 +78,10 @@ private:
SizeType32* mGenerationLengthInclusiveSum{nullptr};
SizeType32* mMaxGenerationLength{nullptr};
float* mTemperatureDevice{nullptr};
SizeType32* mBestPathIndicesSlots{nullptr};
SizeType32* mLastDraftIndicesSlots{nullptr};

std::vector<float> mTemperature;
};

} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

@ -92,35 +92,32 @@ inline bool allOfBatchSlots(
|
||||
}
|
||||
|
||||
inline DecoderDomain getLocalDecoderDomain(
|
||||
std::shared_ptr<BaseInputParams> baseInputs, DecoderDomain const& globalDecoderDomain)
|
||||
std::shared_ptr<BaseDecodingInputs> baseInputs, DecoderDomain const& globalDecoderDomain)
|
||||
{
|
||||
auto inputs = std::dynamic_pointer_cast<DynamicDecodeInputParams>(baseInputs);
|
||||
runtime::SizeType32 batchSize{0};
|
||||
auto inputs = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);
|
||||
runtime::SizeType32 batchSize{baseInputs->localBatchSize};
|
||||
runtime::SizeType32 beamWidth{0};
|
||||
runtime::SizeType32 vocabSize{0};
|
||||
if (inputs->logits)
|
||||
{
|
||||
auto const& logitsShape = inputs->logits->shape;
|
||||
TLLM_CHECK(logitsShape.size() == 3 || logitsShape.size() == 4);
|
||||
batchSize = logitsShape[0];
|
||||
auto const idxOffset = logitsShape.size() - 3;
|
||||
beamWidth = logitsShape[idxOffset + 1];
|
||||
vocabSize = logitsShape[idxOffset + 2];
|
||||
}
|
||||
else if (inputs->logits_vec)
|
||||
else if (inputs->logitsVec)
|
||||
{
|
||||
TLLM_CHECK(inputs->logits_vec->size());
|
||||
auto const& logitsShape = inputs->logits_vec.value()[0].shape;
|
||||
TLLM_CHECK(inputs->logitsVec->size());
|
||||
auto const& logitsShape = inputs->logitsVec.value()[0].shape;
|
||||
TLLM_CHECK(logitsShape.size() == 3 || logitsShape.size() == 4);
|
||||
auto const idxOffset = logitsShape.size() - 3;
|
||||
batchSize = inputs->logits_vec->size();
|
||||
beamWidth = logitsShape[idxOffset + 1];
|
||||
vocabSize = logitsShape[idxOffset + 2];
|
||||
}
|
||||
else if (inputs->batch_slots)
|
||||
else if (inputs->batchSlots)
|
||||
{
|
||||
auto const& batchSlotsShape = inputs->batch_slots->shape;
|
||||
batchSize = batchSlotsShape[0];
|
||||
auto const& batchSlotsShape = inputs->batchSlots->shape;
|
||||
beamWidth = globalDecoderDomain.getBeamWidth();
|
||||
vocabSize = globalDecoderDomain.getVocabSize();
|
||||
}
|
||||
|
||||
@ -15,7 +15,12 @@
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/layers/lookaheadAlgorithm.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
|
||||
#include "tensorrt_llm/runtime/lookaheadModule.h"
|
||||
#include <tuple>
|
||||
|
||||
namespace tensorrt_llm::layers
|
||||
{
|
||||
@ -31,10 +36,13 @@ void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeT
|
||||
mW = w;
|
||||
mN = n;
|
||||
mG = g;
|
||||
std::tie(std::ignore, std::ignore, mRuntimeMaxDraftLen, std::ignore)
|
||||
= executor::LookaheadDecodingConfig(mW, mN, mG).calculateSpeculativeResource();
|
||||
|
||||
mPoolManager.setup(mG);
|
||||
mPoolManager.accept(prompt, mN);
|
||||
mGoldenTokens = ITensor::slice(mGoldenTokensMax, 0, mN * 2 - 1);
|
||||
mPrefills = ITensor::slice(mPrefillsMax, 0, mN - 2);
|
||||
mPrefills = ITensor::slice(mPrefillsMax, 0, mN <= 1 ? 0 : mN - 2);
|
||||
mKeyTokens = ITensor::slice(mKeyTokensMax, 0, mW);
|
||||
mPastTokens = ITensor::slice(mPastTokensMax, 0, mW * (mN - 1));
|
||||
mPastTokens->reshape(ITensor::makeShape({mW, mN - 1}));
|
||||
@ -48,10 +56,14 @@ void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeT
|
||||
std::for_each(pastRange.begin(), pastRange.end(), [](auto& a) { a = -1; });
|
||||
for (SizeType32 i = 0; i < mW; i++)
|
||||
{
|
||||
randToken(pastRange[i * (mN - 1)]);
|
||||
if (mN - 1 > 0)
|
||||
{
|
||||
randToken(pastRange[i * (mN - 1)]);
|
||||
}
|
||||
}
|
||||
std::copy(std::prev(promptRange.end(), mN - 1), promptRange.end(), goldRange.begin());
|
||||
mFilling = 1;
|
||||
mGuessTokens = ITensor::slice(mGuessTokensMax, 0, 0);
|
||||
mFilling = (mN - 1) > 0 ? 1 : 0;
|
||||
PRINT_TOKENS(prompt);
|
||||
PRINT_TOKENS(mPrefills);
|
||||
PRINT_TOKENS(mPastTokens);
|
||||
@ -75,6 +87,8 @@ void LookaheadAlgorithm::accept(TensorConstPtr const& generatedTokens)
|
||||
runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens, TensorPtr const& positionIds,
|
||||
TensorPtr const& samplingMask, runtime::SizeType32 offset)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
SizeType32 prefill = mN - 2 - mFilling;
|
||||
SizeType32 len = prefill + mFilling * mW;
|
||||
TLLM_CHECK(len <= ITensor::volume(draftTokens->getShape()));
|
||||
@ -132,8 +146,9 @@ runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens,
|
||||
samplingMaskRange[wj * mFilling + mFilling - 1 - 1] = true;
|
||||
}
|
||||
}
|
||||
TLLM_LOG_DEBUG("prefill=%d, offset=%d", prefill, offset);
|
||||
PRINT_VALUES(positionIds);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return len;
|
||||
}
|
||||
|
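For reference, a tiny standalone check of the draft-length bookkeeping in lookahead() above: during prefilling, len = (N - 2 - filling) + filling * W. The W and N values below are made up for illustration; this is not part of the library.

// Illustrative only: evaluates the length formula for hypothetical W = 3, N = 5.
#include <iostream>

int main()
{
    int const W = 3, N = 5; // lookahead window and n-gram length (hypothetical)
    for (int filling = 1; filling < N - 1; ++filling)
    {
        int const prefill = N - 2 - filling;
        int const len = prefill + filling * W;
        std::cout << "filling=" << filling << " len=" << len << '\n';
        // prints: filling=1 len=5, filling=2 len=7, filling=3 len=9
    }
    return 0;
}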
||||
@ -171,10 +186,19 @@ void LookaheadAlgorithm::prepare(TensorPtr const& draftTokens, TensorPtr const&
|
||||
TensorPtr const& samplingMask, TensorPtr const& length, TensorConstPtr const& offsetPtr,
|
||||
TensorConstPtr const& lastTokenPtr)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
if (mRuntimeMaxDraftLen == 0)
|
||||
{
|
||||
(BufferRange<SizeType32>(*length))[0] = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
auto lastToken = BufferRange<TokenIdType const>(*lastTokenPtr)[0];
|
||||
auto offset = BufferRange<SizeType32 const>(*offsetPtr)[0];
|
||||
|
||||
SizeType32 inputLen = ITensor::volume(draftTokens->getShape());
|
||||
TLLM_CHECK(inputLen >= mRuntimeMaxDraftLen);
|
||||
|
||||
BufferRange<TokenIdType> draftRange(*draftTokens);
|
||||
BufferRange<TokenIdType> positionRange(*positionIds);
|
||||
@ -182,33 +206,39 @@ void LookaheadAlgorithm::prepare(TensorPtr const& draftTokens, TensorPtr const&
|
||||
|
||||
SizeType32 filledLen = 0;
|
||||
|
||||
filledLen += lookahead(ITensor::slice(draftTokens, filledLen, inputLen - filledLen),
|
||||
ITensor::slice(positionIds, filledLen, inputLen - filledLen),
|
||||
ITensor::slice(samplingMask, filledLen, inputLen - filledLen), offset);
|
||||
filledLen += lookahead(ITensor::slice(draftTokens, filledLen, mRuntimeMaxDraftLen - filledLen),
|
||||
ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen),
|
||||
ITensor::slice(samplingMask, filledLen, mRuntimeMaxDraftLen - filledLen), offset);
|
||||
|
||||
auto guessStart = filledLen;
|
||||
filledLen += guess(ITensor::slice(draftTokens, filledLen, inputLen - filledLen),
|
||||
ITensor::slice(positionIds, filledLen, inputLen - filledLen),
|
||||
ITensor::slice(samplingMask, filledLen, inputLen - filledLen), offset, lastToken);
|
||||
filledLen += guess(ITensor::slice(draftTokens, filledLen, mRuntimeMaxDraftLen - filledLen),
|
||||
ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen),
|
||||
ITensor::slice(samplingMask, filledLen, mRuntimeMaxDraftLen - filledLen), offset, lastToken);
|
||||
auto guessEnd = filledLen;
|
||||
|
||||
mGuessTokens = ITensor::slice(mGuessTokensMax, 0, guessEnd - guessStart);
|
||||
|
||||
std::copy(draftRange.begin() + guessStart, draftRange.begin() + guessEnd,
|
||||
BufferRange<TokenIdType>(*mGuessTokens).begin());
|
||||
|
||||
(BufferRange<SizeType32>(*length))[0] = filledLen;
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void LookaheadAlgorithm::verify(TensorPtr const& accepted, TensorPtr const& acceptedOffsets,
|
||||
TensorPtr const& acceptedLength, TokenIdType newLastToken, TensorConstPtr const& goldenTokens,
|
||||
TensorConstPtr const& endToken)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
TLLM_CHECK(ITensor::volume(goldenTokens->getShape()) == ITensor::volume(mGuessTokens->getShape()));
|
||||
BufferRange<TokenIdType const> goldRange(*goldenTokens);
|
||||
BufferRange<TokenIdType> guessTokensRange(*mGuessTokens);
|
||||
auto guessSize = ITensor::volume(mGuessTokens->getShape());
|
||||
|
||||
SizeType32 guesses = guessSize / (mN - 1), hit = 0, maxHit = 0, hitIdx = 0;
|
||||
SizeType32 guesses = (mN - 1 > 0) ? (guessSize / (mN - 1)) : 0;
|
||||
SizeType32 hit = 0, maxHit = 0, hitIdx = 0;
|
||||
for (SizeType32 i = 0; i < guesses; i++)
|
||||
{
|
||||
SizeType32 hit = 0;
|
||||
@ -248,6 +278,8 @@ void LookaheadAlgorithm::verify(TensorPtr const& accepted, TensorPtr const& acce
|
||||
}
|
||||
|
||||
*BufferRange<SizeType32>(*acceptedLength).begin() = maxHit + 1;
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
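A quick illustration of the guarded guess count introduced above: each guess contributes N - 1 tokens, so the number of guesses is guessSize / (N - 1), with the new guard covering N <= 1. The numbers below are hypothetical.

// Illustrative only: guess count for a made-up guessSize and n-gram length.
#include <iostream>

int main()
{
    int const N = 4;         // lookahead n-gram length (hypothetical)
    int const guessSize = 9; // total guess tokens appended to the draft (hypothetical)
    int const guesses = (N - 1 > 0) ? guessSize / (N - 1) : 0;
    std::cout << guesses << '\n'; // prints 3
    return 0;
}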
||||
//! lookahead Jacobi matrix has prefilling phase and maintenance phase.
|
||||
@ -293,6 +325,8 @@ void LookaheadAlgorithm::verify(TensorPtr const& accepted, TensorPtr const& acce
|
||||
void LookaheadAlgorithm::update(TensorPtr const& acceptedTokens, TensorPtr const& acceptedOffsets,
|
||||
TensorPtr const& acceptedLength, TensorConstPtr const& sampledTokens, TensorConstPtr const& endToken)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
TLLM_CHECK(ITensor::volume(acceptedTokens->getShape()) >= mN);
|
||||
BufferRange<TokenIdType const> sampledRange(*sampledTokens);
|
||||
BufferRange<TokenIdType> keyRange(*mKeyTokens);
|
||||
@ -312,7 +346,7 @@ void LookaheadAlgorithm::update(TensorPtr const& acceptedTokens, TensorPtr const
|
||||
pastRange[i * (mN - 1) + mFilling] = keyRange[i];
|
||||
}
|
||||
}
|
||||
else
|
||||
else if (mN > 1)
|
||||
{
|
||||
for (SizeType32 i = 0; i < mW; i++)
|
||||
{
|
||||
@ -329,8 +363,9 @@ void LookaheadAlgorithm::update(TensorPtr const& acceptedTokens, TensorPtr const
|
||||
|
||||
auto guessSize = ITensor::volume(mGuessTokens->getShape());
|
||||
auto outputSize = ITensor::volume(sampledTokens->getShape());
|
||||
auto lookSize = 1 + mN - 2 - mFilling + mFilling * mW;
|
||||
auto lookSize = 1 + (mN > 1 ? mN - 2 : 0) - mFilling + mFilling * mW;
|
||||
TLLM_CHECK(guessSize + lookSize == outputSize);
|
||||
|
||||
TensorConstPtr goldenTokens = ITensor::slice(sampledTokens, lookSize, guessSize);
|
||||
|
||||
verify(acceptedTokens, acceptedOffsets, acceptedLength, newLastToken, goldenTokens, endToken);
|
||||
@ -341,6 +376,8 @@ void LookaheadAlgorithm::update(TensorPtr const& acceptedTokens, TensorPtr const
|
||||
{
|
||||
mFilling++;
|
||||
}
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::layers
|
||||
|
||||
@ -44,7 +44,8 @@ public:
|
||||
, mId(id)
|
||||
, mGoldenTokensMax(
|
||||
runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxN * 2 - 1}), nvinfer1::DataType::kINT32))
|
||||
, mPrefillsMax(runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxN - 2}), nvinfer1::DataType::kINT32))
|
||||
, mPrefillsMax(runtime::BufferManager::cpu(
|
||||
runtime::ITensor::makeShape({(maxN <= 1 ? 0 : maxN - 2)}), nvinfer1::DataType::kINT32))
|
||||
, mKeyTokensMax(runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW}), nvinfer1::DataType::kINT32))
|
||||
, mPastTokensMax(
|
||||
runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW * (maxN - 1)}), nvinfer1::DataType::kINT32))
|
||||
@ -125,6 +126,7 @@ private:
|
||||
runtime::SizeType32 mW{0};
|
||||
runtime::SizeType32 mN{0};
|
||||
runtime::SizeType32 mG{0};
|
||||
runtime::SizeType32 mRuntimeMaxDraftLen{0};
|
||||
//! in prefilling mode when mFilling < mN-1.
|
||||
runtime::SizeType32 mFilling;
|
||||
|
||||
|
||||
cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp (new file, 364 lines)

@ -0,0 +1,364 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/layers/lookaheadDecodingLayer.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/kernels/decodingKernels.h"
|
||||
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
|
||||
#include "tensorrt_llm/layers/decodingParams.h"
|
||||
#include "tensorrt_llm/layers/defaultDecodingParams.h"
|
||||
#include "tensorrt_llm/layers/lookaheadAlgorithm.h"
|
||||
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include <cstdint>
|
||||
#include <ios>
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
|
||||
namespace tensorrt_llm::layers
|
||||
{
|
||||
|
||||
using namespace tensorrt_llm::common;
|
||||
using namespace tensorrt_llm::kernels;
|
||||
using namespace tensorrt_llm::runtime;
|
||||
|
||||
template <typename T>
|
||||
LookaheadDecodingLayer<T>::CpuAlgorithmResources::CpuAlgorithmResources(DecoderDomain const& decoderDomain)
|
||||
{
|
||||
auto maxBatchSize = decoderDomain.getBatchSize();
|
||||
auto lookaheadModule
|
||||
= std::dynamic_pointer_cast<LookaheadModule const>(decoderDomain.getSpeculativeDecodingModule());
|
||||
auto const [maxW, maxN, maxG] = lookaheadModule->getExecutionConfig().get();
|
||||
|
||||
for (runtime::SizeType32 id = 0; id < maxBatchSize; id++)
|
||||
{
|
||||
mAlgos.emplace_back(maxW, maxN, maxG, id);
|
||||
}
|
||||
|
||||
SizeType32 maxTokensPerStep, maxNumNewTokens, maxDraftLen;
|
||||
std::tie(maxTokensPerStep, maxNumNewTokens, maxDraftLen, std::ignore)
|
||||
= executor::LookaheadDecodingConfig(maxW, maxN, maxG).calculateSpeculativeResource();
|
||||
|
||||
auto const maxBatchShape1D = ITensor::makeShape({maxBatchSize});
|
||||
mBatchSlots = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
mTargetTokens
|
||||
= BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxTokensPerStep}), nvinfer1::DataType::kINT32);
|
||||
mTokensPerStep = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
mEndIds = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
|
||||
mOutputIds = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxNumNewTokens}), nvinfer1::DataType::kINT32);
|
||||
mPathsOffsets = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxNumNewTokens}), nvinfer1::DataType::kINT32);
|
||||
mNumNewTokens = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
mNumNewTokensCumSum = BufferManager::cpu(ITensor::makeShape({maxBatchSize + 1}), nvinfer1::DataType::kINT32);
|
||||
mNextDraftTokens = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32);
|
||||
mNextDraftPosIds = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32);
|
||||
auto divUp32 = [](SizeType32 x) { return x / 32 + ((x % 32) ? 1 : 0); };
|
||||
mPackedMasks = BufferManager::cpu(
|
||||
ITensor::makeShape({maxBatchSize, maxTokensPerStep, divUp32(maxTokensPerStep)}), nvinfer1::DataType::kINT32);
|
||||
mSamplingMask = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kBOOL);
|
||||
mNextDraftLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
mSequenceLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
LookaheadDecodingLayer<T>::LookaheadDecodingLayer(
|
||||
DecoderDomain const& decoderDomain, std::shared_ptr<runtime::BufferManager> const& bufferManager)
|
||||
: BaseLayer(decoderDomain, bufferManager)
|
||||
, mCpuAlgo(std::make_optional<CpuAlgorithmResources>(decoderDomain))
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const maxBatchSize = mDecoderDomain.getBatchSize();
|
||||
auto const maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens();
|
||||
auto const vocabSizePadded = mDecoderDomain.getVocabSizePadded();
|
||||
auto const maxTopK = 1;
|
||||
auto const maxBatchShape1D = ITensor::makeShape({maxBatchSize});
|
||||
auto const maxBatchShape2D = ITensor::makeShape({maxBatchSize, maxTokensPerStep});
|
||||
|
||||
mWorkspaceSize = getTopKWorkspaceSize<T>(maxBatchSize, maxTokensPerStep, maxTopK, vocabSizePadded);
|
||||
TLLM_LOG_DEBUG("mWorkspaceSize=%d", mWorkspaceSize);
|
||||
|
||||
mSamplingWorkspaceDevice
|
||||
= mBufferManager->gpu(ITensor::makeShape({static_cast<int32_t>(mWorkspaceSize)}), nvinfer1::DataType::kINT8);
|
||||
mTargetTokensDevice = mBufferManager->gpu(maxBatchShape2D, nvinfer1::DataType::kINT32);
|
||||
mRandomSeedsDevice = mBufferManager->gpu(maxBatchShape1D, nvinfer1::DataType::kINT64);
|
||||
mSamplingMaskDevice = mBufferManager->gpu(maxBatchShape2D, nvinfer1::DataType::kBOOL);
|
||||
mCurandStatesDevice = mBufferManager->gpu(maxBatchShape1D, nvinfer1::DataType::kINT8);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LookaheadDecodingLayer<T>::setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth,
|
||||
runtime::SizeType32 const* batchSlots, std::shared_ptr<BaseSetupParams> const& baseSetupParams)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto setupParams = std::dynamic_pointer_cast<LookaheadSetupParams>(baseSetupParams);
|
||||
|
||||
if (mCpuAlgo)
|
||||
{
|
||||
auto& algoConfigs = setupParams->algoConfigs;
|
||||
TLLM_CHECK_WITH_INFO(algoConfigs.size() == 1 || algoConfigs.size() == batchSize,
|
||||
"Lookahead runtime configuration size should be either 1 or batchSize");
|
||||
for (runtime::SizeType32 bi = 0; bi < batchSize; bi++)
|
||||
{
|
||||
SizeType32 gbi = batchSlots[bi];
|
||||
SizeType32 bi1orN = (algoConfigs.size() == 1) ? 0 : bi;
|
||||
TLLM_LOG_DEBUG("CPU ALGO [ %d ] setup", gbi);
|
||||
PRINT_TOKENS(setupParams->prompt[bi]);
|
||||
auto [w, n, g] = algoConfigs[bi1orN].get();
|
||||
SizeType32 runtimeTokensPerStep;
|
||||
std::tie(runtimeTokensPerStep, std::ignore, std::ignore, std::ignore)
|
||||
= executor::LookaheadDecodingConfig(w, n, g).calculateSpeculativeResource();
|
||||
TLLM_CHECK_WITH_INFO(runtimeTokensPerStep <= mDecoderDomain.getMaxDecodingTokens(),
|
||||
"runtime w(%d) n(%d) g(%d) exceeds maxTokensPerStep(%d)", w, n, g,
|
||||
mDecoderDomain.getMaxDecodingTokens());
|
||||
mCpuAlgo->mAlgos[gbi].setup(setupParams->prompt[bi], w, n, g);
|
||||
}
|
||||
}
|
||||
|
||||
auto curandStatesDevicePtr = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStatesDevice));
|
||||
if (setupParams->randomSeed)
|
||||
{
|
||||
auto& randomSeed = setupParams->randomSeed.value();
|
||||
if (randomSeed.size() == 1)
|
||||
{
|
||||
invokeCurandInitialize(curandStatesDevicePtr, batchSlots, batchSize, randomSeed.front(), mStream);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(randomSeed.size() == batchSize, "Random seed vector size mismatch.");
|
||||
cudaAutoCpy(bufferCast<uint64_t>(*mRandomSeedsDevice), randomSeed.data(), batchSize, mStream);
|
||||
invokeCurandBatchInitialize(
|
||||
curandStatesDevicePtr, batchSlots, batchSize, bufferCast<uint64_t>(*mRandomSeedsDevice), mStream);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
invokeCurandInitialize(curandStatesDevicePtr, batchSlots, batchSize, DefaultDecodingParams::getSeed(), mStream);
|
||||
}
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LookaheadDecodingLayer<T>::forwardAsync(
|
||||
std::shared_ptr<BaseDecodingOutputs> const& outputParams, std::shared_ptr<BaseDecodingInputs> const& inputParams)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto inputs = std::dynamic_pointer_cast<LookaheadDecodingInputs>(inputParams);
|
||||
auto outputs = std::dynamic_pointer_cast<SpeculativeDecodingOutputs>(outputParams);
|
||||
auto batchSize = inputs->localBatchSize;
|
||||
|
||||
TLLM_CHECK_WITH_INFO(inputs->batchSlots, "Batch slots must be provided for LookaheadDecoding");
|
||||
TLLM_CHECK_WITH_INFO(inputs->curTokensPerStep, "curTokensPerStep must be provided for LookaheadDecoding");
|
||||
TLLM_CHECK_WITH_INFO(outputs->sequenceLength, "sequenceLength must be provided for LookaheadDecoding");
|
||||
// TODO(liweim) to be confirmed.
|
||||
TLLM_CHECK(inputs->logits);
|
||||
|
||||
mBufferManager->copy(
|
||||
inputs->batchSlots->template getPtr<SizeType32 const>(), *mCpuAlgo->mBatchSlots, runtime::MemoryType::kGPU);
|
||||
mBufferManager->copy(inputs->curTokensPerStep->template getPtr<SizeType32 const>(), *mCpuAlgo->mTokensPerStep,
|
||||
runtime::MemoryType::kGPU);
|
||||
mBufferManager->copy(
|
||||
inputs->endIds.template getPtr<TokenIdType const>(), *mCpuAlgo->mEndIds, runtime::MemoryType::kGPU);
|
||||
mBufferManager->copy(outputs->sequenceLength->template getPtr<SizeType32 const>(), *mCpuAlgo->mSequenceLengths,
|
||||
runtime::MemoryType::kGPU);
|
||||
|
||||
TopKSamplingKernelParams<T> params;
|
||||
params.maxBatchSize = mDecoderDomain.getBatchSize();
|
||||
params.batchSize = batchSize;
|
||||
params.maxTopK = 1;
|
||||
params.returnAllTopK = true;
|
||||
params.maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens();
|
||||
params.maxSeqLen = mDecoderDomain.getMaxDecodingTokens();
|
||||
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
|
||||
params.batchSlots = inputs->batchSlots->template getPtr<SizeType32 const>();
|
||||
TLLM_LOG_DEBUG("batchSize = %d", batchSize);
|
||||
params.logProbs = inputs->logits ? inputs->logits->template getPtr<T>() : nullptr;
|
||||
params.outputIds = bufferCast<TokenIdType>(*mTargetTokensDevice);
|
||||
params.workspace = bufferCast<int8_t>(*mSamplingWorkspaceDevice);
|
||||
params.curandState = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStatesDevice));
|
||||
params.tokensPerStep = inputs->curTokensPerStep->template getPtr<SizeType32 const>();
|
||||
|
||||
TLLM_LOG_DEBUG(
|
||||
"invokeBatchTopKSampling: maxBatchSize=%d, batchSize=%d, maxTopK=%d, maxTokensPerStep=%d, maxSeqLen=%d, "
|
||||
"vocabSizePadded=%d",
|
||||
params.maxBatchSize, params.batchSize, params.maxTopK, params.maxTokensPerStep, params.maxSeqLen,
|
||||
params.vocabSizePadded);
|
||||
|
||||
// Sample multiple tokens per request and store them separately to be accepted/rejected later.
|
||||
// Sequence length is not modified, endIds is not checked, outputLogProbs are not supported.
|
||||
// Finished state is not set.
|
||||
invokeBatchTopKSampling(params, mStream);
|
||||
|
||||
mBufferManager->copy(*mTargetTokensDevice, *mCpuAlgo->mTargetTokens);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
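Since maxTopK is 1, the sampling call above effectively takes an argmax over the vocabulary for every decoding position of every request; the real path runs invokeBatchTopKSampling on the GPU. Below is a plain-CPU sketch of that reduction with made-up logits, for illustration only.

// Illustrative only: per-position argmax, the top-K = 1 special case.
#include <algorithm>
#include <iostream>
#include <vector>

int main()
{
    int const tokensPerStep = 2, vocabSize = 4; // hypothetical sizes
    // logits[pos * vocabSize + v] for one request
    std::vector<float> logits{0.1f, 2.0f, 0.3f, 0.0f,  // pos 0 -> token 1
                              1.5f, 0.2f, 0.4f, 3.0f}; // pos 1 -> token 3
    for (int pos = 0; pos < tokensPerStep; ++pos)
    {
        auto begin = logits.begin() + pos * vocabSize;
        auto best = std::max_element(begin, begin + vocabSize);
        std::cout << static_cast<int>(best - begin) << ' '; // prints "1 3"
    }
    std::cout << '\n';
    return 0;
}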
||||
template <typename T>
|
||||
void LookaheadDecodingLayer<T>::forwardSync(
|
||||
std::shared_ptr<BaseDecodingOutputs> const& outputParams, std::shared_ptr<BaseDecodingInputs> const& inputParams)
|
||||
{
|
||||
if (mCpuAlgo)
|
||||
{
|
||||
forwardSyncCPU(outputParams, inputParams);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LookaheadDecodingLayer<T>::posIdsToMask(TensorPtr mask, TensorConstPtr posIds)
|
||||
{
|
||||
auto len = ITensor::volume(posIds->getShape());
|
||||
TLLM_CHECK(mask->getShape().d[0] > len);
|
||||
TLLM_CHECK(mask->getShape().d[1] * 32 > len);
|
||||
auto posIdsRange = BufferRange<SizeType32 const>(*posIds);
|
||||
auto maskLocation = BufferLocation<SizeType32>(*mask);
|
||||
|
||||
for (auto i = 0; i < maskLocation.size(); i++)
|
||||
{
|
||||
maskLocation[i] = 0;
|
||||
}
|
||||
maskLocation.at(0, 0) = 1;
|
||||
|
||||
auto setBit = [](SizeType32& x, SizeType32 idx) { x |= (1 << idx); };
|
||||
if (len > 0)
|
||||
{
|
||||
std::vector<std::pair<SizeType32, SizeType32>> stack;
|
||||
stack.push_back(std::make_pair(0, posIdsRange[0] - 1));
|
||||
for (auto i = 1; i < len + 1; i++)
|
||||
{
|
||||
auto cur = posIdsRange[i - 1];
|
||||
while (stack.size() > 0 && cur <= stack.back().second)
|
||||
{
|
||||
stack.pop_back();
|
||||
}
|
||||
TLLM_CHECK(stack.size() > 0 ? cur == stack.back().second + 1 : true);
|
||||
stack.push_back(std::make_pair(i, cur));
|
||||
for (auto prev : stack)
|
||||
{
|
||||
setBit(maskLocation.at(i, prev.first / 32), prev.first % 32);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
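posIdsToMask above packs each row of the ancestor/attention mask into 32-bit words, which is why mPackedMasks is sized with divUp32(maxTokensPerStep). A standalone sketch of how a single packed bit would be set and tested, with hypothetical column indices (illustrative only):

// Illustrative only: column j of a packed mask row lives in word j / 32 at bit j % 32.
#include <cstdint>
#include <vector>

bool maskBitSet(std::vector<int32_t> const& packedRow, int32_t j)
{
    return (packedRow[j / 32] >> (j % 32)) & 1;
}

int main()
{
    // One packed row, enough for maxTokensPerStep <= 32 (a single int32 word).
    std::vector<int32_t> row{0};
    row[0] |= (1 << 0); // token attends to the root
    row[0] |= (1 << 3); // ...and to its ancestor at column 3
    return maskBitSet(row, 3) ? 0 : 1; // exits 0: bit 3 is set
}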
||||
template <typename T>
|
||||
void LookaheadDecodingLayer<T>::forwardSyncCPU(
|
||||
std::shared_ptr<BaseDecodingOutputs> const& outputParams, std::shared_ptr<BaseDecodingInputs> const& inputParams)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto inputs = std::dynamic_pointer_cast<LookaheadDecodingInputs>(inputParams);
|
||||
auto outputs = std::dynamic_pointer_cast<SpeculativeDecodingOutputs>(outputParams);
|
||||
auto const batchSize = inputs->localBatchSize;
|
||||
|
||||
TensorPtr outputIds(wrap(outputs->outputIds));
|
||||
BufferRange<SizeType32 const> tokensPerStepRange(*mCpuAlgo->mTokensPerStep);
|
||||
BufferRange<SizeType32> numNewTokensRange(*mCpuAlgo->mNumNewTokens);
|
||||
BufferRange<SizeType32> numNewTokensCumSumRange(*mCpuAlgo->mNumNewTokensCumSum);
|
||||
BufferRange<SizeType32> batchSlotsRange(*mCpuAlgo->mBatchSlots);
|
||||
BufferRange<SizeType32> nextDraftLengthsRange(*mCpuAlgo->mNextDraftLengths);
|
||||
BufferRange<SizeType32> sequenceLengthsRange(*mCpuAlgo->mSequenceLengths);
|
||||
|
||||
for (SizeType32 bi = 0; bi < batchSize; bi++)
|
||||
{
|
||||
SizeType32 gbi = batchSlotsRange[bi];
|
||||
LookaheadAlgorithm& theAlgo(mCpuAlgo->mAlgos[gbi]);
|
||||
|
||||
SizeType32 const tokensPerStep = tokensPerStepRange[gbi];
|
||||
TensorPtr sampledTokens = ITensor::slice(mCpuAlgo->mTargetTokens, {gbi, 0}, tokensPerStep);
|
||||
|
||||
if (tokensPerStep == 1)
|
||||
{ // The first step in generation phase has no draft tokens.
|
||||
theAlgo.accept(sampledTokens);
|
||||
mBufferManager->copy(*sampledTokens, *ITensor::slice(mCpuAlgo->mOutputIds, {gbi, 0}, tokensPerStep));
|
||||
BufferLocation<SizeType32>(*mCpuAlgo->mPathsOffsets).at(gbi, 0) = 0;
|
||||
numNewTokensRange[gbi] = tokensPerStep;
|
||||
BufferLocation<SizeType32>(*mCpuAlgo->mNextDraftLengths).at(gbi) = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
theAlgo.update( //
|
||||
ITensor::at(mCpuAlgo->mOutputIds, {gbi}), //
|
||||
ITensor::at(mCpuAlgo->mPathsOffsets, {gbi}), //
|
||||
ITensor::at(mCpuAlgo->mNumNewTokens, {gbi}), //
|
||||
sampledTokens, //
|
||||
ITensor::at(mCpuAlgo->mEndIds, {gbi}));
|
||||
}
|
||||
|
||||
auto maxNumNewTokens = mCpuAlgo->mOutputIds->getShape().d[1];
|
||||
mBufferManager->copy(*ITensor::at(mCpuAlgo->mOutputIds, {gbi}),
|
||||
*ITensor::slice(outputIds, {gbi, sequenceLengthsRange[gbi]}, maxNumNewTokens));
|
||||
|
||||
sequenceLengthsRange[gbi] += numNewTokensRange[gbi];
|
||||
|
||||
theAlgo.prepare( //
|
||||
ITensor::at(mCpuAlgo->mNextDraftTokens, {gbi}), //
|
||||
ITensor::at(mCpuAlgo->mNextDraftPosIds, {gbi}), //
|
||||
ITensor::at(mCpuAlgo->mSamplingMask, {gbi}), //
|
||||
ITensor::at(mCpuAlgo->mNextDraftLengths, {gbi}), //
|
||||
ITensor::at(mCpuAlgo->mSequenceLengths, {gbi}), //
|
||||
ITensor::at(mCpuAlgo->mOutputIds, {gbi, numNewTokensRange[gbi] - 1}));
|
||||
|
||||
posIdsToMask( //
|
||||
ITensor::at(mCpuAlgo->mPackedMasks, {gbi}), //
|
||||
ITensor::slice(mCpuAlgo->mNextDraftPosIds, {gbi, 0}, nextDraftLengthsRange[gbi]));
|
||||
}
|
||||
|
||||
numNewTokensCumSumRange[0] = 0;
|
||||
for (SizeType32 i = 0; i < numNewTokensRange.size(); i++)
|
||||
{
|
||||
numNewTokensCumSumRange[i + 1] = numNewTokensCumSumRange[i] + numNewTokensRange[i];
|
||||
}
|
||||
|
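The loop above is an exclusive prefix sum over the per-request accepted-token counts, turning counts into offsets into the packed output. A tiny standalone example with made-up counts:

// Illustrative only: counts {2, 1, 3} -> offsets {0, 2, 3, 6}, so request i's
// tokens occupy [offsets[i], offsets[i + 1]).
#include <cassert>
#include <vector>

int main()
{
    std::vector<int> counts{2, 1, 3};
    std::vector<int> offsets(counts.size() + 1, 0);
    for (std::size_t i = 0; i < counts.size(); ++i)
    {
        offsets[i + 1] = offsets[i] + counts[i];
    }
    assert(offsets == (std::vector<int>{0, 2, 3, 6}));
    return 0;
}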
||||
TLLM_CHECK(outputs->numNewTokens);
|
||||
|
||||
mBufferManager->copy(*mCpuAlgo->mSequenceLengths, //
|
||||
const_cast<void*>(outputs->sequenceLength.value().data), runtime::MemoryType::kGPU);
|
||||
mBufferManager->copy(*mCpuAlgo->mPathsOffsets, //
|
||||
const_cast<void*>(outputs->pathsOffsets.data), runtime::MemoryType::kGPU);
|
||||
mBufferManager->copy(*mCpuAlgo->mNumNewTokens, //
|
||||
const_cast<void*>(outputs->numNewTokens->data), runtime::MemoryType::kGPU);
|
||||
mBufferManager->copy(*mCpuAlgo->mNumNewTokensCumSum, //
|
||||
const_cast<void*>(outputs->numNewTokensCumSum.data), runtime::MemoryType::kGPU);
|
||||
mBufferManager->copy(*mCpuAlgo->mNextDraftTokens, //
|
||||
const_cast<void*>(outputs->nextDraftTokens.data), runtime::MemoryType::kGPU);
|
||||
mBufferManager->copy(*mCpuAlgo->mNextDraftPosIds, //
|
||||
const_cast<void*>(outputs->nextDraftPosIds.data), runtime::MemoryType::kGPU);
|
||||
mBufferManager->copy(*mCpuAlgo->mPackedMasks, //
|
||||
const_cast<void*>(outputs->packedMasks.data), runtime::MemoryType::kGPU);
|
||||
mBufferManager->copy(*mCpuAlgo->mNextDraftLengths, //
|
||||
const_cast<void*>(outputs->nextDraftLengths.data), runtime::MemoryType::kGPU);
|
||||
|
||||
// TODO(liweim) do we need this?
|
||||
// mBufferManager->getStream().synchronize();
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
template class LookaheadDecodingLayer<float>;
|
||||
template class LookaheadDecodingLayer<half>;
|
||||
|
||||
} // namespace tensorrt_llm::layers
|
||||
cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h (new file, 97 lines)
@ -0,0 +1,97 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "lookaheadAlgorithm.h"
|
||||
#include "tensorrt_llm/common/cudaAllocator.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/layers/baseLayer.h"
|
||||
#include "tensorrt_llm/layers/decodingParams.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/lookaheadModule.h"
|
||||
|
||||
namespace tensorrt_llm::layers
|
||||
{
|
||||
|
||||
//! \brief LookaheadDecodingLayer
|
||||
template <typename T>
|
||||
class LookaheadDecodingLayer : public BaseLayer
|
||||
{
|
||||
public:
|
||||
using Base = BaseLayer;
|
||||
using TensorPtr = runtime::ITensor::SharedPtr;
|
||||
using TensorConstPtr = runtime::ITensor::SharedConstPtr;
|
||||
using Base::mBufferManager;
|
||||
|
||||
LookaheadDecodingLayer(
|
||||
DecoderDomain const& decoderDomain, std::shared_ptr<runtime::BufferManager> const& bufferManager);
|
||||
|
||||
~LookaheadDecodingLayer() override {}
|
||||
|
||||
void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
|
||||
std::shared_ptr<BaseSetupParams> const& baseSetupParams) override;
|
||||
|
||||
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputParams,
|
||||
std::shared_ptr<BaseDecodingInputs> const& inputParams) override;
|
||||
|
||||
void forwardSync(std::shared_ptr<BaseDecodingOutputs> const& outputParams,
|
||||
std::shared_ptr<BaseDecodingInputs> const& inputParams) override;
|
||||
|
||||
private:
|
||||
void forwardSyncCPU(std::shared_ptr<BaseDecodingOutputs> const& outputParams,
|
||||
std::shared_ptr<BaseDecodingInputs> const& inputParams);
|
||||
void posIdsToMask(TensorPtr mask, TensorConstPtr posIds);
|
||||
|
||||
private:
|
||||
using Base::mStream;
|
||||
using Base::mAllocator;
|
||||
using Base::mWorkspaceSize;
|
||||
using Base::mDecoderDomain;
|
||||
|
||||
TensorPtr mCurandStatesDevice;
|
||||
TensorPtr mSamplingWorkspaceDevice;
|
||||
TensorPtr mTargetTokensDevice;
|
||||
TensorPtr mRandomSeedsDevice;
|
||||
TensorPtr mSamplingMaskDevice;
|
||||
|
||||
struct CpuAlgorithmResources
|
||||
{
|
||||
explicit CpuAlgorithmResources(DecoderDomain const& decoderDomain);
|
||||
|
||||
std::vector<LookaheadAlgorithm> mAlgos;
|
||||
TensorPtr mBatchSlots;
|
||||
TensorPtr mTargetTokens;
|
||||
TensorPtr mTokensPerStep;
|
||||
TensorPtr mEndIds;
|
||||
|
||||
TensorPtr mOutputIds;
|
||||
TensorPtr mPathsOffsets;
|
||||
TensorPtr mNumNewTokens;
|
||||
TensorPtr mNumNewTokensCumSum;
|
||||
|
||||
TensorPtr mNextDraftTokens;
|
||||
TensorPtr mNextDraftPosIds;
|
||||
TensorPtr mPackedMasks;
|
||||
TensorPtr mSamplingMask;
|
||||
TensorPtr mNextDraftLengths;
|
||||
TensorPtr mSequenceLengths;
|
||||
};
|
||||
|
||||
std::optional<CpuAlgorithmResources> mCpuAlgo;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::layers
|
||||
@ -1,74 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
|
||||
|
||||
namespace tensorrt_llm::layers
|
||||
{
|
||||
|
||||
using namespace tensorrt_llm::runtime;
|
||||
using TensorPtr = ITensor::SharedPtr;
|
||||
|
||||
ITensor::UniquePtr slice(
|
||||
ITensor::SharedPtr tensor, std::initializer_list<SizeType32> const& offsetDims, size_t const sizeDim)
|
||||
{
|
||||
auto shape = tensor->getShape();
|
||||
TLLM_CHECK(offsetDims.size() > 0);
|
||||
TLLM_CHECK(shape.nbDims >= offsetDims.size());
|
||||
std::vector<size_t> volumes(shape.nbDims);
|
||||
|
||||
int i;
|
||||
volumes[shape.nbDims - 1] = 1;
|
||||
for (i = shape.nbDims - 2; i >= 0; i--)
|
||||
{
|
||||
volumes[i] = shape.d[i + 1] * volumes[i + 1];
|
||||
}
|
||||
|
||||
size_t offset = 0;
|
||||
i = 0;
|
||||
for (auto itd = offsetDims.begin(); itd != offsetDims.end(); itd++)
|
||||
{
|
||||
TLLM_CHECK(0 <= (*itd) && (*itd) < shape.d[i]);
|
||||
offset += (*itd) * volumes[i++];
|
||||
}
|
||||
|
||||
ITensor::Shape dims;
|
||||
dims.nbDims = shape.nbDims - offsetDims.size() + 1;
|
||||
dims.d[0] = sizeDim;
|
||||
for (i = 1; i < dims.nbDims; i++)
|
||||
{
|
||||
dims.d[i] = shape.d[i - 1 + offsetDims.size()];
|
||||
}
|
||||
|
||||
size_t size = ITensor::volume(dims);
|
||||
|
||||
return std::make_unique<TensorView>(std::move(tensor), offset, size, dims);
|
||||
}
|
||||
|
||||
ITensor::UniquePtr slice(ITensor::SharedPtr tensor, std::initializer_list<SizeType32> const& offsetDims)
|
||||
{
|
||||
auto result = slice(tensor, offsetDims, 1);
|
||||
if (result->getShape().nbDims > 1)
|
||||
{
|
||||
result->squeeze(0);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::layers
|
||||
@ -240,10 +240,18 @@ public:
|
||||
{
|
||||
buf << token;
|
||||
}
|
||||
buf << (i == size - 1 ? ']' : ',');
|
||||
if (i != size - 1)
|
||||
{
|
||||
buf << ',';
|
||||
}
|
||||
}
|
||||
buf << ']';
|
||||
};
|
||||
if (shape.nbDims == 1)
|
||||
if (shape.nbDims == 0)
|
||||
{
|
||||
buf << "[]";
|
||||
}
|
||||
else if (shape.nbDims == 1)
|
||||
{
|
||||
line(tensorRange.begin(), shape.d[0]);
|
||||
}
|
||||
@ -277,10 +285,19 @@ public:
|
||||
buf << '[';
|
||||
for (SizeType32 i = 0; i < size; i++)
|
||||
{
|
||||
buf << array[i] << (i == size - 1 ? ']' : ',');
|
||||
buf << array[i];
|
||||
if (i != size - 1)
|
||||
{
|
||||
buf << ',';
|
||||
}
|
||||
}
|
||||
buf << ']';
|
||||
};
|
||||
if (shape.nbDims == 1)
|
||||
if (shape.nbDims == 0)
|
||||
{
|
||||
buf << "[]";
|
||||
}
|
||||
else if (shape.nbDims == 1)
|
||||
{
|
||||
line(tensorRange.begin(), shape.d[0]);
|
||||
}
|
||||
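Both hunks above restructure the separator logic so the closing ']' is emitted even when a row is empty, and add an explicit branch for zero-dimensional shapes. A standalone sketch of the resulting join behavior, using plain std types (illustrative only):

// Illustrative only: joins values as "[a,b,c]"; an empty input still yields "[]".
#include <iostream>
#include <sstream>
#include <vector>

std::string join(std::vector<int> const& v)
{
    std::ostringstream buf;
    buf << '[';
    for (std::size_t i = 0; i < v.size(); ++i)
    {
        buf << v[i];
        if (i != v.size() - 1)
        {
            buf << ',';
        }
    }
    buf << ']';
    return buf.str();
}

int main()
{
    std::cout << join({1, 2, 3}) << '\n'; // prints [1,2,3]
    std::cout << join({}) << '\n';        // prints []
    return 0;
}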
@ -305,15 +322,24 @@ public:
|
||||
{
|
||||
switch (mTensor.getDataType())
|
||||
{
|
||||
case nvinfer1::DataType::kBOOL: return values<bool>();
|
||||
case nvinfer1::DataType::kFLOAT: return values<float>();
|
||||
case nvinfer1::DataType::kINT8: return values<std::int8_t>();
|
||||
case nvinfer1::DataType::kINT32: return values<std::int32_t>();
|
||||
case nvinfer1::DataType::kINT64: return values<std::int64_t>();
|
||||
case nvinfer1::DataType::kUINT8: return values<std::uint8_t>();
|
||||
default: return std::string("Unsupported data type");
|
||||
default: return std::string(mName + ": Unsupported data type");
|
||||
}
|
||||
}
|
||||
|
||||
std::string shape(void)
|
||||
{
|
||||
using namespace tensorrt_llm::runtime;
|
||||
std::ostringstream buf;
|
||||
buf << mName << ": " << mTensor.getShape();
|
||||
return buf.str();
|
||||
}
|
||||
|
||||
void print_tokens(void)
|
||||
{
|
||||
TLLM_LOG_DEBUG(tokens());
|
||||
@ -324,6 +350,11 @@ public:
|
||||
TLLM_LOG_DEBUG(values());
|
||||
}
|
||||
|
||||
void print_shape(void)
|
||||
{
|
||||
TLLM_LOG_DEBUG(shape());
|
||||
}
|
||||
|
||||
private:
|
||||
runtime::ITensor const& mTensor;
|
||||
std::string mName;
|
||||
@ -332,5 +363,6 @@ private:
|
||||
#define D(x) tensorrt_llm::layers::DebugTensor(x, #x)
|
||||
#define PRINT_TOKENS(x) D(x).print_tokens()
|
||||
#define PRINT_VALUES(x) D(x).print_values()
|
||||
#define PRINT_SHAPE(x) D(x).print_shape()
|
||||
|
||||
} // namespace tensorrt_llm::layers
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/layers/lookaheadPoolManager.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
|
||||
|
||||
namespace tensorrt_llm::layers
|
||||
@ -24,13 +25,18 @@ using namespace tensorrt_llm::runtime;
|
||||
|
||||
void LookaheadPoolManager::setup(SizeType32 guessSetSize)
|
||||
{
|
||||
TLLM_CHECK(guessSetSize > 0 && guessSetSize <= mGuessSetSizeMax);
|
||||
TLLM_CHECK(guessSetSize >= 0 && guessSetSize <= mGuessSetSizeMax);
|
||||
mGuessSetSize = guessSetSize;
|
||||
mTokenMap.clear();
|
||||
}
|
||||
|
||||
void LookaheadPoolManager::insertOne(Key key, TensorConstPtr const& ngram)
|
||||
{
|
||||
if (TLLM_UNLIKELY(ITensor::volume(ngram->getShape()) == 0 || mGuessSetSize == 0))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
auto search = mTokenMap.find(key);
|
||||
if (search != mTokenMap.end())
|
||||
{
|
||||
@ -41,7 +47,7 @@ void LookaheadPoolManager::insertOne(Key key, TensorConstPtr const& ngram)
|
||||
BufferRange<TokenIdType const> itemRange(*item);
|
||||
return std::equal(ngramRange.begin(), ngramRange.end(), itemRange.begin());
|
||||
});
|
||||
if (mGuessSetSize >= 0 && search->second.size() >= mGuessSetSize)
|
||||
if (mGuessSetSize > 0 && search->second.size() >= mGuessSetSize)
|
||||
{
|
||||
search->second.pop_front();
|
||||
}
|
||||
@ -104,7 +110,6 @@ void LookaheadPoolManager::update(TensorConstPtr const& keyTokens, TensorConstPt
|
||||
BufferRange<TokenIdType const> sourceRange(*source);
|
||||
BufferRange<TokenIdType> ngramRange(*ngram);
|
||||
std::copy(sourceRange.begin(), sourceRange.end(), ngramRange.begin());
|
||||
|
||||
insertOne(keyRange[wi], ngram);
|
||||
}
|
||||
}
|
||||
|
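LookaheadPoolManager::insertOne keeps a bounded list of n-grams per key token and evicts the oldest entry when the per-key guess set is full; the change above also relaxes setup() to accept guessSetSize == 0 and skips inserts in that case. A minimal standalone sketch of the bounded-pool behavior, using std::deque and std::unordered_map stand-ins with made-up tokens (illustrative only, not the runtime tensor-based implementation):

// Illustrative only: bounded per-key n-gram pool with oldest-first eviction.
#include <deque>
#include <iostream>
#include <unordered_map>
#include <vector>

int main()
{
    std::size_t const guessSetSize = 2; // at most 2 n-grams per key token (hypothetical)
    std::unordered_map<int, std::deque<std::vector<int>>> pool;

    auto insertOne = [&](int key, std::vector<int> ngram) {
        auto& list = pool[key];
        for (auto const& item : list)
        {
            if (item == ngram) // already known: nothing to do
            {
                return;
            }
        }
        if (list.size() >= guessSetSize) // full: drop the oldest guess
        {
            list.pop_front();
        }
        list.push_back(std::move(ngram));
    };

    insertOne(7, {8, 9});
    insertOne(7, {8, 10});
    insertOne(7, {8, 11}); // evicts {8, 9}
    std::cout << pool[7].size() << '\n'; // prints 2
    return 0;
}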
||||
@ -31,9 +31,7 @@ using namespace tensorrt_llm::kernels;
|
||||
using namespace tensorrt_llm::kernels::speculative_decoding;
|
||||
using namespace tensorrt_llm::runtime;
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace layers
|
||||
namespace tensorrt_llm::layers
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
@ -144,7 +142,7 @@ void MedusaDecodingLayer<T>::freeBuffer()
|
||||
|
||||
template <typename T>
|
||||
void MedusaDecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
|
||||
std::shared_ptr<BaseSetupParams> baseSetupParams)
|
||||
std::shared_ptr<BaseSetupParams> const& baseSetupParams)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
@ -272,12 +270,12 @@ void MedusaDecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, S
|
||||
|
||||
template <typename T>
|
||||
void MedusaDecodingLayer<T>::forwardAsync(
|
||||
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
|
||||
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto inputs = std::dynamic_pointer_cast<MedusaInputParams>(baseInputs);
|
||||
auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
|
||||
auto inputs = std::dynamic_pointer_cast<MedusaDecodingInputs>(baseInputs);
|
||||
auto outputs = std::dynamic_pointer_cast<SpeculativeDecodingOutputs>(baseOutputs);
|
||||
|
||||
samplePrimeHeadTokens(*outputs, *inputs);
|
||||
|
||||
@ -294,16 +292,16 @@ void MedusaDecodingLayer<T>::forwardAsync(
|
||||
|
||||
template <typename T>
|
||||
void MedusaDecodingLayer<T>::samplePrimeHeadTokens(
|
||||
DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs)
|
||||
SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const batchSize = inputs.logits.shape[0];
|
||||
auto const batchSize = inputs.logits->shape[0];
|
||||
|
||||
auto logits = inputs.logits.template getPtr<T>();
|
||||
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType32 const>() : nullptr;
|
||||
auto sequenceLengths = outputs.sequence_length ? outputs.sequence_length->template getPtr<SizeType32>() : nullptr;
|
||||
auto tokensPerStepDevice = inputs.medusaCurTokensPerStep.template getPtr<SizeType32>();
|
||||
auto logits = inputs.logits->template getPtr<T>();
|
||||
auto batchSlots = inputs.batchSlots ? inputs.batchSlots->template getPtr<SizeType32 const>() : nullptr;
|
||||
auto sequenceLengths = outputs.sequenceLength ? outputs.sequenceLength->template getPtr<SizeType32>() : nullptr;
|
||||
auto tokensPerStepDevice = inputs.curTokensPerStep->template getPtr<SizeType32>();
|
||||
|
||||
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
|
||||
TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");
|
||||
@ -333,22 +331,22 @@ void MedusaDecodingLayer<T>::samplePrimeHeadTokens(
|
||||
|
||||
template <typename T>
|
||||
void MedusaDecodingLayer<T>::acceptDraftTokens(
|
||||
DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs)
|
||||
SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const batchSize = inputs.logits.shape[0];
|
||||
auto const maxSeqLen = outputs.output_ids.shape[outputs.output_ids.shape.size() - 1];
|
||||
auto const batchSize = inputs.logits->shape[0];
|
||||
auto const maxSeqLen = outputs.outputIds.shape[outputs.outputIds.shape.size() - 1];
|
||||
|
||||
auto outputIds = outputs.output_ids.template getPtr<TokenIdType>();
|
||||
auto endIds = inputs.end_ids.template getPtr<TokenIdType const>();
|
||||
auto outputIds = outputs.outputIds.template getPtr<TokenIdType>();
|
||||
auto endIds = inputs.endIds.template getPtr<TokenIdType const>();
|
||||
auto paths = inputs.paths.template getPtr<SizeType32 const>();
|
||||
|
||||
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType32 const>() : nullptr;
|
||||
auto sequenceLengths = outputs.sequence_length ? outputs.sequence_length->template getPtr<SizeType32>() : nullptr;
|
||||
auto acceptedLengths = outputs.speculativeDecodingOutputs->acceptedLengths.template getPtr<SizeType32>();
|
||||
auto curTokensPerStepDevice = inputs.medusaCurTokensPerStep.template getPtr<SizeType32>();
|
||||
auto targetTokensPerStepDevice = inputs.medusaTargetTokensPerStep.template getPtr<SizeType32>();
|
||||
auto batchSlots = inputs.batchSlots ? inputs.batchSlots->template getPtr<SizeType32 const>() : nullptr;
|
||||
auto sequenceLengths = outputs.sequenceLength ? outputs.sequenceLength->template getPtr<SizeType32>() : nullptr;
|
||||
auto numNewTokens = outputs.numNewTokens->template getPtr<SizeType32>();
|
||||
auto curTokensPerStepDevice = inputs.curTokensPerStep->template getPtr<SizeType32>();
|
||||
auto targetTokensPerStepDevice = inputs.targetTokensPerStep.template getPtr<SizeType32>();
|
||||
|
||||
auto const maxDraftPathLen = mDecoderDomain.getSpeculativeDecodingModule()->getMaxDraftPathLen();
|
||||
|
||||
@ -362,12 +360,12 @@ void MedusaDecodingLayer<T>::acceptDraftTokens(
|
||||
}
|
||||
}
|
||||
|
||||
auto draftIds = outputs.speculativeDecodingOutputs->nextDraftTokens.template getPtr<TokenIdType>();
|
||||
auto draftIds = outputs.nextDraftTokens.template getPtr<TokenIdType>();
|
||||
|
||||
TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");
|
||||
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
|
||||
TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");
|
||||
TLLM_CHECK_WITH_INFO(acceptedLengths != nullptr, "Accepted lengths must be provided for MedusaDecoding");
|
||||
TLLM_CHECK_WITH_INFO(numNewTokens != nullptr, "Accepted lengths must be provided for MedusaDecoding");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
curTokensPerStepDevice != nullptr, "Current tokens per step must be provided for MedusaDecoding");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
@ -379,7 +377,7 @@ void MedusaDecodingLayer<T>::acceptDraftTokens(
|
||||
// Compare draft tokens from outputIds with sampled target tokens at mTargetTokensDevice using paths.
|
||||
// Select the longest accepted path, modify outputIds in-place, increment sequenceLengths accordingly.
|
||||
// Fill mMedusaSelectedLogitsPtrsDevice with respective Medusa logits
|
||||
acceptDraftTokensByIdsWithPaths(outputIds, draftIds, mTargetTokensDevice, sequenceLengths, acceptedLengths,
|
||||
acceptDraftTokensByIdsWithPaths(outputIds, draftIds, mTargetTokensDevice, sequenceLengths, numNewTokens,
|
||||
finishedStates, batchSlots, paths, endIds,
|
||||
reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaInputLogitsPtrs)),
|
||||
const_cast<T const**>(mMedusaSelectedLogitsPtrsDevice), curTokensPerStepDevice, targetTokensPerStepDevice,
|
||||
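acceptDraftTokensByIdsWithPaths compares the draft tokens against the sampled target tokens along each candidate path and keeps the longest match. Below is a heavily simplified single-request sketch of that selection with made-up token ids; the real kernel additionally handles endIds, finished states, sequence lengths, and the Medusa logits pointers.

// Illustrative only: pick the path with the longest draft/target token match.
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<std::vector<int>> pathDraft{{5, 6, 7}, {5, 8, 9}};
    std::vector<std::vector<int>> pathTarget{{5, 6, 2}, {5, 1, 1}};
    std::size_t bestPath = 0, bestLen = 0;
    for (std::size_t p = 0; p < pathDraft.size(); ++p)
    {
        std::size_t len = 0;
        while (len < pathDraft[p].size() && pathDraft[p][len] == pathTarget[p][len])
        {
            ++len;
        }
        if (len > bestLen)
        {
            bestLen = len;
            bestPath = p;
        }
    }
    std::cout << bestPath << ' ' << bestLen << '\n'; // prints "0 2"
    return 0;
}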
@ -391,13 +389,13 @@ void MedusaDecodingLayer<T>::acceptDraftTokens(
|
||||
|
||||
template <typename T>
|
||||
void MedusaDecodingLayer<T>::sampleNewDraftTokens(
|
||||
DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs)
|
||||
SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const batchSize = inputs.logits.shape[0];
|
||||
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType32 const>() : nullptr;
|
||||
auto sequenceLengths = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<SizeType32>() : nullptr;
|
||||
auto const batchSize = inputs.logits->shape[0];
|
||||
auto batchSlots = inputs.batchSlots ? inputs.batchSlots->template getPtr<SizeType32 const>() : nullptr;
|
||||
auto sequenceLengths = (outputs.sequenceLength) ? outputs.sequenceLength->template getPtr<SizeType32>() : nullptr;
|
||||
|
||||
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
|
||||
TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");
|
||||
@ -449,18 +447,18 @@ void MedusaDecodingLayer<T>::sampleNewDraftTokens(
|
||||
|
||||
template <typename T>
|
||||
void MedusaDecodingLayer<T>::scatterNewDraftTokens(
|
||||
DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs)
|
||||
SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const batchSize = inputs.logits.shape[0];
|
||||
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType32 const>()
|
||||
: static_cast<SizeType32*>(nullptr);
|
||||
auto const batchSize = inputs.logits->shape[0];
|
||||
auto batchSlots = inputs.batchSlots ? inputs.batchSlots->template getPtr<SizeType32 const>()
|
||||
: static_cast<SizeType32*>(nullptr);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
|
||||
|
||||
auto draftIds = outputs.speculativeDecodingOutputs->nextDraftTokens.template getPtr<TokenIdType>();
|
||||
auto tokensPerStepDevice = inputs.medusaCurTokensPerStep.template getPtr<SizeType32>();
|
||||
auto draftIds = outputs.nextDraftTokens.template getPtr<TokenIdType>();
|
||||
auto tokensPerStepDevice = inputs.curTokensPerStep->template getPtr<SizeType32>();
|
||||
auto treeIds = inputs.treeIds.template getPtr<SizeType32>();
|
||||
TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");
|
||||
TLLM_CHECK_WITH_INFO(tokensPerStepDevice != nullptr, "Tokens per step must be provided for MedusaDecoding");
|
||||
@ -474,23 +472,22 @@ void MedusaDecodingLayer<T>::scatterNewDraftTokens(
|
||||
|
||||
template <typename T>
|
||||
void MedusaDecodingLayer<T>::packAcceptedPaths(
|
||||
DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs)
|
||||
SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const batchSize = inputs.logits.shape[0];
|
||||
auto const batchSize = inputs.logits->shape[0];
|
||||
auto paths = inputs.paths.template getPtr<SizeType32 const>();
|
||||
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType32 const>() : nullptr;
|
||||
auto acceptedLengths = outputs.speculativeDecodingOutputs->acceptedLengths.template getPtr<SizeType32>();
|
||||
auto acceptedLengthsCumSum
|
||||
= outputs.speculativeDecodingOutputs->acceptedLengthsCumSum.template getPtr<SizeType32>();
|
||||
auto pathsOffsets = outputs.speculativeDecodingOutputs->pathsOffsets.template getPtr<SizeType32>();
|
||||
auto batchSlots = inputs.batchSlots ? inputs.batchSlots->template getPtr<SizeType32 const>() : nullptr;
|
||||
auto numNewTokens = outputs.numNewTokens->template getPtr<SizeType32>();
|
||||
auto numNewTokensCumSum = outputs.numNewTokensCumSum.template getPtr<SizeType32>();
|
||||
auto pathsOffsets = outputs.pathsOffsets.template getPtr<SizeType32>();
|
||||
|
||||
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
|
||||
TLLM_CHECK_WITH_INFO(acceptedLengths != nullptr, "Accepted lengths must be provided for MedusaDecoding");
|
||||
TLLM_CHECK_WITH_INFO(acceptedLengthsCumSum != nullptr, "acceptedLengthsCumSum must be provided for MedusaDecoding");
|
||||
TLLM_CHECK_WITH_INFO(numNewTokens != nullptr, "Accepted lengths must be provided for MedusaDecoding");
|
||||
TLLM_CHECK_WITH_INFO(numNewTokensCumSum != nullptr, "numNewTokensCumSum must be provided for MedusaDecoding");
|
||||
TLLM_CHECK_WITH_INFO(pathsOffsets != nullptr, "pathsOffsets must be provided for MedusaDecoding");
|
||||
invokePackAcceptedPaths(acceptedLengthsCumSum, pathsOffsets, acceptedLengths, mBestPathIdsDevice, paths, batchSlots,
|
||||
invokePackAcceptedPaths(numNewTokensCumSum, pathsOffsets, numNewTokens, mBestPathIdsDevice, paths, batchSlots,
|
||||
batchSize, mDecoderDomain.getMaxDecodingTokens(),
|
||||
mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen(), false, mStream);
|
||||
|
||||
@ -500,5 +497,4 @@ void MedusaDecodingLayer<T>::packAcceptedPaths(
|
||||
template class MedusaDecodingLayer<float>;
|
||||
template class MedusaDecodingLayer<half>;
|
||||
|
||||
} // namespace layers
|
||||
} // namespace tensorrt_llm
|
||||
} // namespace tensorrt_llm::layers
|
||||
|
||||
@ -19,46 +19,13 @@

#include <curand_kernel.h>

#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"

namespace tc = tensorrt_llm::common;

namespace tensorrt_llm
namespace tensorrt_llm::layers
{
namespace layers
{

class MedusaSetupParams : public BaseSetupParams
{
public:
std::optional<std::vector<int32_t>> runtimeTopK; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<std::vector<int32_t>>> runtimeHeadsTopK; // [setupBatchSize, maxDraftPathLen] on cpu
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [setupBatchSize] on cpu
};

class MedusaInputParams : public BaseInputParams
{
public:
explicit MedusaInputParams(tc::Tensor logits, tc::Tensor endIds)
: BaseInputParams{0, 0, std::move(endIds)}
, logits{std::move(logits)}
{
}

tc::Tensor logits; // [maxBatchSize, beamWidth, vocabSizePadded]

tc::Tensor paths; // [maxBatchSize, maxDecodingTokens, maxPathLen] on gpu
std::vector<std::vector<tc::Tensor>>
medusaLogits; // [maxBatchSize][maxDraftPathLen][maxDecodingTokens, vocabSize] on gpu
tc::Tensor medusaCurTokensPerStep; // [maxBatchSize] on gpu
tc::Tensor medusaTargetTokensPerStep; // [maxBatchSize] on gpu
tc::Tensor treeIds; // [maxBatchSize, maxDecodingTokens] on gpu
};

//! \brief
template <typename T>
@ -74,19 +41,20 @@ public:
~MedusaDecodingLayer() override;

void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> setupParams) override;
std::shared_ptr<BaseSetupParams> const& setupParams) override;

void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;

private:
void allocateBuffer();
void freeBuffer();

void samplePrimeHeadTokens(DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs);
void acceptDraftTokens(DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs);
void sampleNewDraftTokens(DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs);
void scatterNewDraftTokens(DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs);
void packAcceptedPaths(DynamicDecodeOutputParams const& outputs, MedusaInputParams const& inputs);
void samplePrimeHeadTokens(SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs);
void acceptDraftTokens(SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs);
void sampleNewDraftTokens(SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs);
void scatterNewDraftTokens(SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs);
void packAcceptedPaths(SpeculativeDecodingOutputs const& outputs, MedusaDecodingInputs const& inputs);

private:
using Base::mStream;
@ -118,5 +86,4 @@ private:
std::vector<runtime::SizeType32> mCummulativeTopK;
};

} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

@ -17,9 +17,11 @@

#include "tensorrt_llm/layers/penaltyLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/kernels/penaltyKernels.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"

#include <algorithm>

@ -27,9 +29,7 @@ using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{

template <typename T>
@ -182,7 +182,7 @@ void PenaltyLayer<T>::freeBuffer()

template <typename T>
void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> baseSetupParams)
std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

@ -207,14 +207,15 @@ void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType
FillBuffers const fillBuffers{batchSize, mDecoderDomain.getBatchSize(), mStream};

auto const& penaltyParams = setupParams->penaltyParams;
TLLM_CHECK_WITH_INFO(penaltyParams, "penaltyParams for setup is not set");

bool const useTemperature = mDecodingMode.isUseTemperature() && penaltyParams.temperature.has_value();
bool const useTemperature = mDecodingMode.isUseTemperature() && penaltyParams->temperature.has_value();
bool const useRepetitionPenalty
= mDecodingMode.isUseRepetitionPenalty() && penaltyParams.repetitionPenalty.has_value();
bool const usePresencePenalty = mDecodingMode.isUsePresencePenalty() && penaltyParams.presencePenalty.has_value();
= mDecodingMode.isUseRepetitionPenalty() && penaltyParams->repetitionPenalty.has_value();
bool const usePresencePenalty = mDecodingMode.isUsePresencePenalty() && penaltyParams->presencePenalty.has_value();
bool const useFrequencyPenalty
= mDecodingMode.isUseFrequencyPenalty() && penaltyParams.frequencyPenalty.has_value();
bool const useMinLength = mDecodingMode.isUseMinLength() && penaltyParams.minLength.has_value();
= mDecodingMode.isUseFrequencyPenalty() && penaltyParams->frequencyPenalty.has_value();
bool const useMinLength = mDecodingMode.isUseMinLength() && penaltyParams->minLength.has_value();
// FIXME(nkorobov): once one of the requests has some penalty, we will always have to compute it.
// To avoid that we need to scan through all active requests at each iteration.
mUseTemperature |= useTemperature;
@ -225,31 +226,31 @@ void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType

if (mUseTemperature)
{
fillBuffers(penaltyParams.temperature, DefaultDecodingParams::getTemperature(), mTemperature,
fillBuffers(penaltyParams->temperature, DefaultDecodingParams::getTemperature(), mTemperature,
mTemperatureDevice, batchSlotsHost, getLimitsPenalty(DecodingPenaltyType::Temperature),
"temperature penalty");
}
if (mUseRepetitionPenalty)
{
fillBuffers(penaltyParams.repetitionPenalty, DefaultDecodingParams::getRepetitionPenalty(), mRepetitionPenalty,
fillBuffers(penaltyParams->repetitionPenalty, DefaultDecodingParams::getRepetitionPenalty(), mRepetitionPenalty,
mRepetitionPenaltyDevice, batchSlotsHost, getLimitsPenalty(DecodingPenaltyType::Repetition),
"repetition penalty");
}
if (mUsePresencePenalty)
{
fillBuffers(penaltyParams.presencePenalty, DefaultDecodingParams::getPresencePenalty(), mPresencePenalty,
fillBuffers(penaltyParams->presencePenalty, DefaultDecodingParams::getPresencePenalty(), mPresencePenalty,
mPresencePenaltyDevice, batchSlotsHost, getLimitsPenalty(DecodingPenaltyType::Presence),
"presence penalty");
}
if (mUseFrequencyPenalty)
{
fillBuffers(penaltyParams.frequencyPenalty, DefaultDecodingParams::getFrequencyPenalty(), mFrequencyPenalty,
fillBuffers(penaltyParams->frequencyPenalty, DefaultDecodingParams::getFrequencyPenalty(), mFrequencyPenalty,
mFrequencyPenaltyDevice, batchSlotsHost, getLimitsPenalty(DecodingPenaltyType::Frequency),
"frequency penalty");
}
if (mUseMinLength)
{
fillBuffers(penaltyParams.minLength, DefaultDecodingParams::getMinLength(), mMinLength, mMinLengthDevice,
fillBuffers(penaltyParams->minLength, DefaultDecodingParams::getMinLength(), mMinLength, mMinLengthDevice,
batchSlotsHost, getLimitsPenalty(DecodingPenaltyType::MinLength), "min length");
}

@ -258,21 +259,21 @@ void PenaltyLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType

template <typename T>
void PenaltyLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
auto params = std::dynamic_pointer_cast<DynamicDecodeInputParams>(baseInputs);
auto outputs = std::dynamic_pointer_cast<BaseDecodingOutputs>(baseOutputs);
auto params = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);

auto const localDecoderDomain = getLocalDecoderDomain(params, mDecoderDomain);
auto const maxSeqLen = outputs->output_ids.shape[outputs->output_ids.shape.size() - 1];
auto batchSlots = params->batch_slots ? params->batch_slots->template getPtr<SizeType32 const>() : nullptr;
auto const maxSeqLen = outputs->outputIds.shape[outputs->outputIds.shape.size() - 1];
auto batchSlots = params->batchSlots ? params->batchSlots->template getPtr<SizeType32 const>() : nullptr;

std::vector<SizeType32> batchSlotsVec(localDecoderDomain.getBatchSize());
std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0);
auto batchSlotsHost
= params->batch_slots ? params->batch_slots->template getPtr<SizeType32 const>() : batchSlotsVec.data();
= params->batchSlots ? params->batchSlots->template getPtr<SizeType32 const>() : batchSlotsVec.data();

if (!mLogitsPtrsHost->data())
{
@ -288,12 +289,12 @@ void PenaltyLayer<T>::forwardAsync(
auto logitsPtrsHostData = reinterpret_cast<T const**>(runtime::bufferCast<int64_t>(*logitsPtrsHost));
for (SizeType32 bi = 0; bi < localDecoderDomain.getBatchSize(); bi++)
{
if (params->logits_vec)
if (params->logitsVec)
{
TLLM_CHECK_WITH_INFO(params->logits_vec->size() == localDecoderDomain.getBatchSize(),
"Logits vector size (%lu) is not equal to the batchSize (%d)", params->logits_vec->size(),
TLLM_CHECK_WITH_INFO(params->logitsVec->size() == localDecoderDomain.getBatchSize(),
"Logits vector size (%lu) is not equal to the batchSize (%d)", params->logitsVec->size(),
localDecoderDomain.getBatchSize());
logitsPtrsHostData[bi] = params->logits_vec.value()[bi].template getPtr<T>();
logitsPtrsHostData[bi] = params->logitsVec.value()[bi].template getPtr<T>();
}
else
{
@ -303,12 +304,11 @@ void PenaltyLayer<T>::forwardAsync(
}

SizeType32 const* inputLengths = nullptr;
if (params->input_lengths)
if (params->inputLengths)
{
auto& input_lengths = params->input_lengths.value();
inputLengths = input_lengths.template getPtr<SizeType32 const>();
inputLengths = params->inputLengths->template getPtr<SizeType32 const>();
}
auto* embeddingBias = params->embedding_bias ? params->embedding_bias->template getPtr<T const>() : nullptr;
auto* embeddingBias = params->embeddingBias ? params->embeddingBias->template getPtr<T const>() : nullptr;
#define GET_PENALTIES(capital_name, type) \
(mUse##capital_name \
&& !allOfBatchSlots(batchSlotsHost, m##capital_name.data(), localDecoderDomain.getBatchSize(), \
@ -324,9 +324,8 @@ void PenaltyLayer<T>::forwardAsync(

#undef GET_PENALTIES

auto const tokensPerStep = params->medusaInputs
? params->medusaInputs->medusaCurTokensPerStep.template getPtr<SizeType32 const>()
: nullptr;
auto const tokensPerStep
= params->curTokensPerStep ? params->curTokensPerStep->template getPtr<SizeType32 const>() : nullptr;

InvokeBatchApplyPenaltyParams<T> penaltyParams;
penaltyParams.inputLogits = reinterpret_cast<T const* const*>(logitsPtrsHostData);
@ -343,12 +342,12 @@ void PenaltyLayer<T>::forwardAsync(
penaltyParams.maxSeqLen = maxSeqLen;
penaltyParams.vocabSize = mDecoderDomain.getVocabSize();
penaltyParams.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
penaltyParams.outputIdsPtr = outputs->output_ids_ptr.template getPtr<TokenIdType const*>();
penaltyParams.parentIdsPtr = outputs->parent_ids_ptr.template getPtr<SizeType32 const*>();
penaltyParams.outputIdsPtr = outputs->outputIdsPtr.template getPtr<TokenIdType const*>();
penaltyParams.parentIdsPtr = outputs->parentIdsPtr.template getPtr<SizeType32 const*>();
penaltyParams.inputLengths = inputLengths;
penaltyParams.sequenceLengths = outputs->sequence_length->template getPtr<SizeType32 const>();
penaltyParams.sequenceLengths = outputs->sequenceLength->template getPtr<SizeType32 const>();
penaltyParams.minLengths = minLengths;
penaltyParams.endIds = params->end_ids.template getPtr<TokenIdType const>();
penaltyParams.endIds = params->endIds.template getPtr<TokenIdType const>();
penaltyParams.batchSlots = batchSlots;
penaltyParams.maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens();
penaltyParams.tokensPerStep = tokensPerStep;
@ -376,5 +375,4 @@ void PenaltyLayer<T>::forwardAsync(
template class PenaltyLayer<float>;
template class PenaltyLayer<half>;

} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

@ -19,17 +19,12 @@

#include <curand_kernel.h>

#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iTensor.h"

namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{

//! \brief Layer applies penalties to the logits. Supports:
@ -48,10 +43,11 @@ public:
~PenaltyLayer() override;

void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> setupParams) override;
std::shared_ptr<BaseSetupParams> const& setupParams) override;

//! \brief Modifies 'outputs->logits' in-place with -INF for banned words
void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;

T* getRuntimeLogitsDevice()
{
@ -103,5 +99,4 @@ private:
runtime::ITensor::SharedPtr mLogitsPtrsHost;
};

} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

@ -19,7 +19,8 @@
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
#include "tensorrt_llm/layers/topKSamplingLayer.h"
#include "tensorrt_llm/layers/topPSamplingLayer.h"

#include <algorithm>

@ -27,9 +28,7 @@ using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{
template <typename T>
SamplingLayer<T>::SamplingLayer(executor::DecodingMode const& mode, DecoderDomain const& decoderDomain,
@ -111,7 +110,7 @@ void SamplingLayer<T>::freeBuffer()

template <typename T>
void SamplingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> baseSetupParams)
std::shared_ptr<BaseSetupParams> const& baseSetupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

@ -168,20 +167,17 @@ void SamplingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeTyp

template <typename T>
void SamplingLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& outputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto inputs = std::dynamic_pointer_cast<SamplingInputParams>(baseInputs);
auto outputs = std::dynamic_pointer_cast<SamplingOutputParams>(baseOutputs);
auto inputs = std::dynamic_pointer_cast<SamplingInputs>(baseInputs);

auto const batchSize = inputs->logits.shape[0];
auto const batchSize = inputs->logits->shape[0];

auto logits = inputs->logits.template getPtr<T>();
auto endIds = inputs->end_ids.template getPtr<int const>();
auto batchSlots = inputs->batch_slots ? inputs->batch_slots->template getPtr<int const>() : nullptr;
float* cumLogProbs = (outputs->cum_log_probs) ? outputs->cum_log_probs->template getPtr<float>() : nullptr;
float* outputLogProbs = (outputs->output_log_probs) ? outputs->output_log_probs->template getPtr<float>() : nullptr;
auto logits = inputs->logits->template getPtr<T>();
auto endIds = inputs->endIds.template getPtr<int const>();
auto batchSlots = inputs->batchSlots ? inputs->batchSlots->template getPtr<int const>() : nullptr;

FinishedState* finishedInput = (inputs->finished)
? reinterpret_cast<FinishedState*>(inputs->finished->template getPtr<FinishedState::UnderlyingType>())
@ -192,9 +188,9 @@ void SamplingLayer<T>::forwardAsync(
// Compute probabilities either for TopP or if cumLogProbs or outputLogProbs are specified
bool const skipSoftMax = skipTopP && !mOutputLogProbs && !mCumLogProbs;

inputs->curand_states = mCurandStatesDevice;
inputs->sampling_workspace = mSamplingWorkspaceDevice;
inputs->probs_computed = !skipSoftMax;
inputs->curandStates = mCurandStatesDevice;
inputs->samplingWorkspace = mSamplingWorkspaceDevice;
inputs->probsComputed = !skipSoftMax;
if (!skipSoftMax)
{
invokeAddBiasSoftMax(logits, (T**) nullptr, logits, (T*) (nullptr), endIds, finishedInput, batchSlots,
@ -205,7 +201,7 @@ void SamplingLayer<T>::forwardAsync(

for (auto&& layer : mSamplingLayers)
{
layer->forwardAsync(baseOutputs, baseInputs);
layer->forwardAsync(outputs, baseInputs);
}

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@ -214,5 +210,4 @@ void SamplingLayer<T>::forwardAsync(
template class SamplingLayer<float>;
template class SamplingLayer<half>;

} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

@ -17,21 +17,14 @@

#pragma once

#include <curand_kernel.h>

#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/layers/samplingParams.h"
#include "tensorrt_llm/layers/topKSamplingLayer.h"
#include "tensorrt_llm/layers/topPSamplingLayer.h"
#include "tensorrt_llm/runtime/common.h"

namespace tc = tensorrt_llm::common;
#include <curand_kernel.h>

namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{

//! \brief Top class for sampling layers.
@ -48,9 +41,10 @@ public:
~SamplingLayer() override = default;

void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> setupParams) override;
std::shared_ptr<BaseSetupParams> const& setupParams) override;

void forwardAsync(std::shared_ptr<BaseOutputParams> outputs, std::shared_ptr<BaseInputParams> inputs) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs) override;

private:
using Base::mWorkspaceSize;
@ -82,5 +76,4 @@ private:
void freeBuffer();
};

} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

@ -1,77 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/layers/decodingParams.h"
|
||||
#include <tensorrt_llm/common/tensor.h>
|
||||
#include <tensorrt_llm/runtime/common.h>
|
||||
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
namespace tc = tensorrt_llm::common;
|
||||
|
||||
namespace tensorrt_llm::layers
|
||||
{
|
||||
|
||||
class SamplingSetupParams : public BaseSetupParams
|
||||
{
|
||||
public:
|
||||
std::optional<std::vector<runtime::SizeType32>> runtime_top_k; // [1] or [batchSize] on cpu
|
||||
std::optional<std::vector<float>> runtime_top_p; // [1] or [batchSize] on cpu
|
||||
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [batchSize] on cpu
|
||||
std::optional<std::vector<float>> top_p_decay; // [batchSize], must between [0, 1]
|
||||
std::optional<std::vector<float>> top_p_min; // [batchSize], must between [0, 1]
|
||||
std::optional<std::vector<runtime::TokenIdType>> top_p_reset_ids; // [batchSize]
|
||||
std::optional<std::vector<bool>> outputLogProbs; // [batchSize]
|
||||
std::optional<std::vector<bool>> cumLogProbs; // [batchSize]
|
||||
std::optional<bool> normalize_log_probs;
|
||||
};
|
||||
|
||||
class SamplingInputParams : public BaseInputParams
|
||||
{
|
||||
public:
|
||||
explicit SamplingInputParams(runtime::SizeType32 step, runtime::SizeType32 ite, tc::Tensor logits,
|
||||
tc::Tensor end_ids, runtime::SizeType32 max_seq_len)
|
||||
: BaseInputParams{step, ite, std::move(end_ids)}
|
||||
, logits{std::move(logits)}
|
||||
, max_seq_len{max_seq_len}
|
||||
{
|
||||
}
|
||||
|
||||
// mandatory parameters
|
||||
tc::Tensor logits; // [local_batch_size, beam_width, vocab_size_padded]
|
||||
runtime::SizeType32 max_seq_len;
|
||||
|
||||
// optional parameters
|
||||
std::optional<tc::Tensor> input_lengths; // [localBatchSize]
|
||||
curandState_t* curand_states; // [localBatchSize]
|
||||
// Pointer to the workspace for sampling computation
|
||||
void* sampling_workspace;
|
||||
// Flag to mark that logits tensor contains probabilities
|
||||
bool probs_computed;
|
||||
};
|
||||
|
||||
class SamplingOutputParams : public BaseOutputParams
|
||||
{
|
||||
public:
|
||||
explicit SamplingOutputParams(tc::Tensor outputIds)
|
||||
: BaseOutputParams{std::move(outputIds)}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::layers
|
||||
@ -17,19 +17,14 @@

#include "tensorrt_llm/layers/stopCriteriaLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/stopCriteriaKernels.h"
#include "tensorrt_llm/layers/layerUtils.h"

#include <algorithm>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm
{
namespace layers
namespace tensorrt_llm::layers
{

template <typename T>
@ -44,7 +39,7 @@ StopCriteriaLayer<T>::StopCriteriaLayer(executor::DecodingMode const& mode, Deco

template <typename T>
void StopCriteriaLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots,
std::shared_ptr<BaseSetupParams> setupParams)
std::shared_ptr<BaseSetupParams> const& setupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@ -52,16 +47,18 @@ void StopCriteriaLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, Siz

template <typename T>
void StopCriteriaLayer<T>::forwardAsync(
std::shared_ptr<BaseOutputParams> baseOutputs, std::shared_ptr<BaseInputParams> baseInputs)
std::shared_ptr<BaseDecodingOutputs> const& baseOutputs, std::shared_ptr<BaseDecodingInputs> const& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto inputs = std::dynamic_pointer_cast<DynamicDecodeInputParams>(baseInputs);
auto outputs = std::dynamic_pointer_cast<DynamicDecodeOutputParams>(baseOutputs);
auto inputs = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);
auto outputs = std::dynamic_pointer_cast<BaseDecodingOutputs>(baseOutputs);

auto const localDecoderDomain = getLocalDecoderDomain(inputs, mDecoderDomain);
auto const maxSeqLen = outputs->output_ids.shape[outputs->output_ids.shape.size() - 1];
auto batchSlots = inputs->batch_slots ? inputs->batch_slots->template getPtr<SizeType32 const>() : nullptr;
auto const maxSeqLen = outputs->outputIds.shape[outputs->outputIds.shape.size() - 1];
auto batchSlots = inputs->batchSlots ? inputs->batchSlots->template getPtr<SizeType32 const>() : nullptr;

TLLM_CHECK_WITH_INFO(inputs->stopCriteriaInputs, "stopCriteriaInputs for forward is not set");

if (mDecodingMode.isUseStopWords())
{
@ -80,60 +77,61 @@ void StopCriteriaLayer<T>::forwardAsync(
}

template <typename T>
void StopCriteriaLayer<T>::checkStopWordsStopCriteria(std::shared_ptr<DynamicDecodeOutputParams>& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& inputs, SizeType32 const* batchSlots,
DecoderDomain const& decoderDomain, SizeType32 maxSeqLen, cudaStream_t stream)
void StopCriteriaLayer<T>::checkStopWordsStopCriteria(std::shared_ptr<BaseDecodingOutputs>& outputs,
std::shared_ptr<DecodingInputs> const& inputs, SizeType32 const* batchSlots, DecoderDomain const& decoderDomain,
SizeType32 maxSeqLen, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const maxStopWordsLength = inputs->max_stop_words_len;
auto const maxStopWordsLength = inputs->stopCriteriaInputs->maxStopWordsLen;
if (maxStopWordsLength)
{
auto numNewTokens = outputs->speculativeDecodingOutputs
? outputs->speculativeDecodingOutputs->acceptedLengths.template getPtr<SizeType32>()
: nullptr;
invokeStopWordsCriterion(outputs->output_ids_ptr.template getPtr<TokenIdType const*>(),
outputs->parent_ids_ptr.template getPtr<SizeType32 const*>(),
inputs->stop_words_ptr->template getPtr<TokenIdType const*>(),
auto numNewTokens = outputs->numNewTokens ? outputs->numNewTokens->template getPtr<SizeType32>() : nullptr;
invokeStopWordsCriterion(outputs->outputIdsPtr.template getPtr<TokenIdType const*>(),
outputs->parentIdsPtr.template getPtr<SizeType32 const*>(),
inputs->stopCriteriaInputs->stopWordsPtr->template getPtr<TokenIdType const*>(),
reinterpret_cast<FinishedState*>(outputs->finished->template getPtr<FinishedState::UnderlyingType>()),
outputs->sequence_length->template getPtr<SizeType32>(), batchSlots,
inputs->stop_words_lengths->template getPtr<SizeType32 const>(), numNewTokens, maxStopWordsLength,
decoderDomain.getBatchSize(), decoderDomain.getBeamWidth(), maxSeqLen, stream);
outputs->sequenceLength->template getPtr<SizeType32>(), batchSlots,
inputs->stopCriteriaInputs->stopWordsLengths->template getPtr<SizeType32 const>(), numNewTokens,
maxStopWordsLength, decoderDomain.getBatchSize(), decoderDomain.getBeamWidth(), maxSeqLen, stream);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void StopCriteriaLayer<T>::checkMaxLengthStopCriteria(std::shared_ptr<DynamicDecodeOutputParams>& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& inputs, SizeType32 const* batchSlots,
DecoderDomain const& decoderDomain, SizeType32 maxSeqLen, cudaStream_t stream)
void StopCriteriaLayer<T>::checkMaxLengthStopCriteria(std::shared_ptr<BaseDecodingOutputs>& outputs,
std::shared_ptr<DecodingInputs> const& inputs, SizeType32 const* batchSlots, DecoderDomain const& decoderDomain,
SizeType32 maxSeqLen, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (inputs->sequence_limit_length)
if (inputs->stopCriteriaInputs->sequenceLimitLength)
{
auto numNewTokens = outputs->numNewTokens ? outputs->numNewTokens->template getPtr<SizeType32>() : nullptr;

invokeLengthCriterion(
reinterpret_cast<FinishedState*>(outputs->finished->template getPtr<FinishedState::UnderlyingType>()),
outputs->finished_sum ? outputs->finished_sum->template getPtr<SizeType32>() : nullptr,
inputs->sequence_limit_length->template getPtr<SizeType32 const>(),
outputs->sequence_length->template getPtr<SizeType32>(), batchSlots, decoderDomain.getBatchSize(),
decoderDomain.getBeamWidth(), stream);
outputs->finishedSum ? outputs->finishedSum->template getPtr<SizeType32>() : nullptr,
inputs->stopCriteriaInputs->sequenceLimitLength->template getPtr<SizeType32 const>(),
outputs->sequenceLength->template getPtr<SizeType32>(), numNewTokens, batchSlots,
decoderDomain.getBatchSize(), decoderDomain.getBeamWidth(), stream);
sync_check_cuda_error();
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void StopCriteriaLayer<T>::checkEosToken(std::shared_ptr<DynamicDecodeOutputParams>& outputs,
std::shared_ptr<DynamicDecodeInputParams> const& inputs, SizeType32 const* batchSlots,
DecoderDomain const& decoderDomain, SizeType32 maxSeqLen, cudaStream_t stream)
void StopCriteriaLayer<T>::checkEosToken(std::shared_ptr<BaseDecodingOutputs>& outputs,
std::shared_ptr<DecodingInputs> const& inputs, SizeType32 const* batchSlots, DecoderDomain const& decoderDomain,
SizeType32 maxSeqLen, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
invokeExplicitEOSCriterion(outputs->output_ids_ptr.template getPtr<TokenIdType const*>(),
inputs->end_ids.template getPtr<TokenIdType const>(),

auto numNewTokens = outputs->numNewTokens ? outputs->numNewTokens->template getPtr<SizeType32>() : nullptr;

invokeExplicitEOSCriterion(outputs->outputIdsPtr.template getPtr<TokenIdType const*>(),
inputs->endIds.template getPtr<TokenIdType const>(),
reinterpret_cast<FinishedState*>(outputs->finished->template getPtr<FinishedState::UnderlyingType>()),
outputs->sequence_length->template getPtr<SizeType32>(),
// FIXME(nkorobov): add tokens per step tensor when necessary
/* tokensPerStep */ nullptr, batchSlots, decoderDomain.getBatchSize(), decoderDomain.getBeamWidth(),
decoderDomain.getMaxDecodingTokens(), stream);
outputs->sequenceLength->template getPtr<SizeType32>(), numNewTokens, batchSlots, decoderDomain.getBatchSize(),
decoderDomain.getBeamWidth(), decoderDomain.getMaxDecodingTokens(), stream);
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -141,5 +139,4 @@ void StopCriteriaLayer<T>::checkEosToken(std::shared_ptr<DynamicDecodeOutputPara
template class StopCriteriaLayer<float>;
template class StopCriteriaLayer<half>;

} // namespace layers
} // namespace tensorrt_llm
} // namespace tensorrt_llm::layers

Some files were not shown because too many files have changed in this diff.