Update TensorRT-LLM (#1168)

* Update TensorRT-LLM

---------

Co-authored-by: Bhuvanesh Sridharan <bhuvan.sridharan@gmail.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Authored by Kaiyu Xie on 2024-02-27 17:37:34 +08:00; committed by GitHub
parent e4e09dafea
commit 655524dd82
229 changed files with 4810 additions and 4097 deletions


@ -428,7 +428,7 @@ public:
, mStaticEmulatedTimeoutMs(staticEmulatedTimeoutMs)
, mActiveCount(0)
{
ReturnBatchManagerStatsCallback iterationDataCallback = [this, &logIterationData](std::string const& log)
ReturnBatchManagerStatsCallback iterationDataCallback = [this, logIterationData](std::string const& log)
{
if (logIterationData)
{
@ -563,16 +563,18 @@ public:
{
auto numNewWorkItems = static_cast<int64_t>(rval.size());
comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
std::vector<int64_t> packed;
for (auto const& ir : rval)
if (numNewWorkItems > 0)
{
auto vpacked = ir->serialize();
packed.push_back(static_cast<int64_t>(vpacked.size()));
packed.insert(
packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
std::vector<int64_t> packed;
for (auto const& ir : rval)
{
auto vpacked = ir->serialize();
packed.push_back(static_cast<int64_t>(vpacked.size()));
packed.insert(
packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
}
comm.bcast(packed, 0);
}
comm.bcast(packed, 0);
}
}
else
@ -791,7 +793,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
recorder->report();
recorder->writeOpMetricsToCsv();
// Send terminateReqId to terminate servers on all ranks
// Sever on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
// Server on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
gptServer->enqueue(std::make_shared<InferenceRequest>(terminateReqId));
}
// Wait until benchmarking is done and batch manager is terminated
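A note on the iterationDataCallback change above: capturing logIterationData by value instead of by reference keeps the stored callback valid after the enclosing constructor returns. A minimal standalone C++ sketch of the hazard and the fix; the class and member names here are illustrative, not the benchmark's:

#include <functional>
#include <string>

class StatsLogger
{
public:
    explicit StatsLogger(bool logIterationData)
    {
        // Capturing with [&logIterationData] would leave the stored callback holding a
        // dangling reference to a constructor parameter once this constructor returns.
        // Capturing by value copies the flag into the closure, which stays valid.
        mCallback = [logIterationData](std::string const& /*iterationLog*/)
        {
            if (logIterationData)
            {
                // e.g. parse and record the iteration statistics here
            }
        };
    }

    std::function<void(std::string const&)> mCallback;
};

int main()
{
    StatsLogger logger{true};
    logger.mCallback("{\"iteration\": 1}");
    return 0;
}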


@ -114,9 +114,9 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
"benchmark on %d tokens",
maxNumTokens.value(), maxBatchSize * maxInputLength);
}
std::atomic_bool done = false;
try
{
std::atomic_bool done = false;
auto peakMemFuture = std::async(&monitorMemory, std::ref(done));
TLLM_LOG_INFO(memoryCounter.toString());
@ -266,11 +266,14 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
catch (std::runtime_error& e)
{
std::size_t found = std::string(e.what()).find("out of memory");
// We need to kill the memory monitor when OOM.
done = true;
// Unexpected error; rethrow
if (found == std::string::npos)
{
throw;
TLLM_LOG_ERROR(e.what());
throw e;
}
// We can ignore the OOM exception and continue the rest of the benchmark
@ -283,6 +286,12 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
}
continue;
}
catch (...)
{
// We need to kill memory monitor when any other issue occurs
done = true;
throw;
}
}
TLLM_LOG_INFO(memoryCounter.toString());
}
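The benchmarkGptSession hunks above hoist the atomic `done` flag out of the try block so that every error path can stop the monitorMemory thread before rethrowing, rather than leaving the std::async future running forever. A simplified standalone sketch of that pattern (function and variable names are illustrative, and the OOM-specific handling is omitted):

#include <atomic>
#include <chrono>
#include <future>
#include <thread>

void monitorMemory(std::atomic_bool& done)
{
    while (!done)
    {
        // poll peak memory usage here
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    }
}

void runBenchmark()
{
    std::atomic_bool done = false; // lives outside the try so catch blocks can see it
    auto peakMemFuture = std::async(std::launch::async, &monitorMemory, std::ref(done));
    try
    {
        // run benchmark iterations; may throw (e.g. out of memory)
    }
    catch (...)
    {
        done = true;          // stop the monitor on any failure
        peakMemFuture.get();  // join the monitor thread
        throw;                // then propagate the original error
    }
    done = true;
    peakMemFuture.get();
}

int main()
{
    runBenchmark();
    return 0;
}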


@ -85,6 +85,7 @@ class EncDecBuildConfig:
max_decoder_input_len: Optional[int] = None
max_output_len: Optional[int] = None
builder_opt: Optional[int] = None
n_mels: Optional[int] = None
def __post_init__(self) -> None:
assert self.head_size is not None
@ -1179,6 +1180,27 @@ _allowed_configs = {
mamba_d_conv=4,
mamba_expand=2,
)),
"whisper_large_v3":
ModelConfig(name="whisper_large_v3",
family="whisper",
benchmark_type="enc_dec",
build_config=EncDecBuildConfig(
num_layers=32,
num_decoder_layers=32,
num_heads=20,
head_size=64,
ffn_hidden_size=5120,
hidden_size=1280,
vocab_size=51866,
hidden_act="gelu",
n_positions=448,
n_mels=128,
max_batch_size=8,
max_encoder_input_len=1500,
max_decoder_input_len=1,
max_output_len=200,
builder_opt=None,
)),
}


@ -795,7 +795,7 @@ def build_gpt(args):
max_beam_width=max_beam_width)
if family in [
'opt', 'bloom', 'falcon', 'llama', 'internlm', 'gptneox',
'gptj', 'mamba', 'baichuan'
'gptj', 'mamba', 'baichuan', 'chatglm', 'chatglm2', 'chatglm3'
]:
tensorrt_llm_model(**inputs)
else:
@ -957,6 +957,8 @@ def enc_dec_build_helper(component, config, args):
torch.cuda.set_device(runtime_rank)
family = get_model_family(args.model)
logits_dtype = 'float32'
n_mels = 0
if family == 'bart':
q_scaling = 1.0
has_attention_qkvo_bias = True
@ -969,6 +971,19 @@ def enc_dec_build_helper(component, config, args):
layernorm_position = LayerNormPositionType.pre_layernorm if config.get(
'normalize_before', True) else LayerNormPositionType.post_layernorm
rescale_before_lm_head = False
elif family == 'whisper':
q_scaling = 1.0
has_position_embedding = True
relative_attention = False
has_embedding_layernorm = False
has_attention_qkvo_bias = True
has_mlp_bias = True
has_model_final_layernorm = True
layernorm_position = LayerNormPositionType.pre_layernorm
layernorm_type = LayerNormType.LayerNorm
rescale_before_lm_head = False
logits_dtype = str_dtype_to_trt(args.dtype)
n_mels = config['n_mels']
else:
q_scaling = 1 / config['head_size']**.5
has_attention_qkvo_bias = False
@ -984,6 +999,9 @@ def enc_dec_build_helper(component, config, args):
else:
rescale_before_lm_head = False
quant_mode, _, _ = get_quant_mode(args.quantization)
use_weight_only = quant_mode.is_weight_only()
builder = Builder()
builder_config = builder.create_builder_config(
name=args.model,
@ -1011,6 +1029,10 @@ def enc_dec_build_helper(component, config, args):
has_token_type_embedding=False, # by default
strongly_typed=False, # by default
gather_all_token_logits=False, # by default
int8=(quant_mode.has_act_and_weight_quant()
or quant_mode.is_int8_weight_only()),
quant_mode=quant_mode,
n_mels=n_mels,
)
# build engine
@ -1024,34 +1046,45 @@ def enc_dec_build_helper(component, config, args):
fp16_clamping = (args.dtype == 'float16') and ('t5' in family)
if component == 'encoder':
tllm_model = tensorrt_llm.models.EncoderModel(
num_layers=config['num_layers'],
num_heads=config['num_heads'],
num_kv_heads=config['num_heads'],
head_size=config['head_size'],
hidden_size=config['hidden_size'],
ffn_hidden_size=config['ffn_hidden_size'],
vocab_size=config['vocab_size'],
max_position_embeddings=config.get('n_positions', 0),
has_position_embedding=has_position_embedding,
relative_attention=relative_attention,
max_distance=config.get('max_distance', 0),
num_buckets=config.get('num_buckets', 0),
has_embedding_layernorm=has_embedding_layernorm,
has_embedding_scale=config.get('has_embedding_scale', False),
q_scaling=q_scaling,
has_attention_qkvo_bias=has_attention_qkvo_bias,
has_mlp_bias=has_mlp_bias,
has_model_final_layernorm=has_model_final_layernorm,
layernorm_eps=1e-6,
layernorm_position=layernorm_position,
layernorm_type=layernorm_type,
hidden_act=config['hidden_act'],
dtype=dtype,
use_parallel_embedding=False, # by default
embedding_sharding_dim=0, # by default
mapping=mapping,
fp16_clamping=fp16_clamping)
if family == 'whisper':
tllm_model = tensorrt_llm.models.WhisperEncoder(
n_mels=config['n_mels'],
n_ctx=1500, # n_audio_ctx
n_state=config['hidden_size'],
n_head=config['num_heads'],
n_layer=config['num_layers'],
dtype=dtype)
if use_weight_only:
tllm_model = quantize_model(tllm_model, quant_mode)
else:
tllm_model = tensorrt_llm.models.EncoderModel(
num_layers=config['num_layers'],
num_heads=config['num_heads'],
num_kv_heads=config['num_heads'],
head_size=config['head_size'],
hidden_size=config['hidden_size'],
ffn_hidden_size=config['ffn_hidden_size'],
vocab_size=config['vocab_size'],
max_position_embeddings=config.get('n_positions', 0),
has_position_embedding=has_position_embedding,
relative_attention=relative_attention,
max_distance=config.get('max_distance', 0),
num_buckets=config.get('num_buckets', 0),
has_embedding_layernorm=has_embedding_layernorm,
has_embedding_scale=config.get('has_embedding_scale', False),
q_scaling=q_scaling,
has_attention_qkvo_bias=has_attention_qkvo_bias,
has_mlp_bias=has_mlp_bias,
has_model_final_layernorm=has_model_final_layernorm,
layernorm_eps=1e-6,
layernorm_position=layernorm_position,
layernorm_type=layernorm_type,
hidden_act=config['hidden_act'],
dtype=dtype,
use_parallel_embedding=False, # by default
embedding_sharding_dim=0, # by default
mapping=mapping,
fp16_clamping=fp16_clamping)
elif component == 'decoder':
tllm_model = tensorrt_llm.models.DecoderModel(
num_layers=config['num_layers'],
@ -1084,8 +1117,10 @@ def enc_dec_build_helper(component, config, args):
embedding_sharding_dim=0, # by default
mapping=mapping,
rescale_before_lm_head=rescale_before_lm_head,
logits_dtype='float32', # by default
logits_dtype=logits_dtype, # by default
fp16_clamping=fp16_clamping)
if use_weight_only and family == 'whisper':
tllm_model = quantize_model(tllm_model, quant_mode)
# Module -> Network
engine_name = get_engine_name(args.model, args.dtype, world_size,
@ -1099,6 +1134,12 @@ def enc_dec_build_helper(component, config, args):
network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
network.plugin_config.set_gemm_plugin(dtype=args.dtype)
network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
if use_weight_only:
network.plugin_config.set_weight_only_quant_matmul_plugin(
dtype=args.dtype)
elif args.mode == 'ootb-except-mha':
network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
if world_size > 1:
network.plugin_config.set_nccl_plugin(
@ -1110,18 +1151,31 @@ def enc_dec_build_helper(component, config, args):
# Forward
if component == 'encoder':
inputs = tllm_model.prepare_inputs(
max_batch_size=config['max_batch_size'],
max_input_len=config['max_encoder_input_len'],
)
if family == 'whisper':
inputs = tllm_model.prepare_inputs(
max_batch_size=config['max_batch_size'], )
else:
inputs = tllm_model.prepare_inputs(
max_batch_size=config['max_batch_size'],
max_input_len=config['max_encoder_input_len'],
)
elif component == 'decoder':
inputs = tllm_model.prepare_inputs(
max_batch_size=config['max_batch_size'],
max_beam_width=config['max_beam_width'],
max_decoder_input_len=config['max_decoder_input_len'],
max_new_tokens=config['max_output_len'],
max_encoder_input_len=config['max_encoder_input_len'],
)
if family == 'whisper':
inputs = tllm_model.prepare_inputs(
max_batch_size=config['max_batch_size'],
max_beam_width=config['max_beam_width'],
max_decoder_input_len=config['max_decoder_input_len'],
max_new_tokens=config['max_output_len'],
max_encoder_input_len=1500, # n_audio_ctx
)
else:
inputs = tllm_model.prepare_inputs(
max_batch_size=config['max_batch_size'],
max_beam_width=config['max_beam_width'],
max_decoder_input_len=config['max_decoder_input_len'],
max_new_tokens=config['max_output_len'],
max_encoder_input_len=config['max_encoder_input_len'],
)
tllm_model(*inputs)


@ -23,8 +23,9 @@ from base_benchmark import BaseBenchmark, get_engine_name
from build import build_enc_dec
import tensorrt_llm
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm._utils import (trt_dtype_to_torch, str_dtype_to_trt)
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime.session import TensorInfo
class EncDecBenchmark(BaseBenchmark):
@ -49,6 +50,8 @@ class EncDecBenchmark(BaseBenchmark):
# So we use separate variables for encoder and decoder here.
self.encoder_engine_model_name = args.model
self.decoder_engine_model_name = args.model
# only for whisper parameter
self.n_mels = 0
if self.engine_dir is not None:
@ -109,6 +112,8 @@ class EncDecBenchmark(BaseBenchmark):
self.max_input_len = config["builder_config"][
"max_encoder_input_len"]
self.max_output_len = config["builder_config"]["max_output_len"]
self.n_mels = config["builder_config"][
'n_mels'] if 'whisper' in self.model_name else 0
for key, value in config["builder_config"].items():
if key == "name":
@ -173,6 +178,8 @@ class EncDecBenchmark(BaseBenchmark):
if args.max_input_len is None else args.max_input_len
self.max_output_len = build_config['max_output_len'] \
if args.max_output_len is None else args.max_output_len
self.n_mels = build_config[
'n_mels'] if 'whisper' in self.model_name else 0
# Build engine
(
encoder_engine_buffer,
@ -198,6 +205,10 @@ class EncDecBenchmark(BaseBenchmark):
)
def get_config(self):
if 'whisper' in self.model_name:
print(
f"[WARNING] whisper benchmark is input_len=1500, no text prompt, output_len=arbitrary"
)
for inlen, outlen in self.in_out_lens:
if (inlen > self.max_input_len or outlen > self.max_output_len):
print(
@ -216,29 +227,95 @@ class EncDecBenchmark(BaseBenchmark):
def prepare_inputs(self, config):
batch_size, encoder_input_len = config[0], config[1]
encoder_input_ids = (torch.randint(
100, (batch_size, encoder_input_len)).int().cuda())
# For now, just hardcode the decoder_start_token_id to 0 for t5 models.
decoder_start_token_id = 0
decoder_input_ids = torch.IntTensor([[decoder_start_token_id]
]).to(self.device)
decoder_input_ids = decoder_input_ids.repeat(
(encoder_input_ids.shape[0], 1))
# in padding mode --> keep input, just calculate actual length and max length
# Note: 1st token should always count, even if it is pad_token_id (0). e.g., decoder start id in enc-dec models could be a single pad_token_id, we should count
encoder_input_lengths = ((1 + (encoder_input_ids[:, 1:] != 0).sum(
dim=1).type(torch.IntTensor).to(self.device)).clone().detach().to(
dtype=torch.int32, device=self.device))
decoder_input_lengths = ((1 + (decoder_input_ids[:, 1:] != 0).sum(
dim=1).type(torch.IntTensor).to(self.device)).clone().detach().to(
dtype=torch.int32, device=self.device))
# attention mask, always set 1 as if all are valid tokens
attention_mask = torch.ones(
(batch_size, encoder_input_len)).int().cuda()
# cross attention mask, always set 1 as if all are valid tokens
# [batch_size, query_len, encoder_input_len] currently, use query_len=1
cross_attention_mask = torch.ones(
(batch_size, 1, encoder_input_len)).int().cuda()
attention_mask = None
whisper_decoder_encoder_input_lengths = None
outputs = {}
if 'whisper' in self.model_name:
# feature_len always fixed 3000 now
feature_len = 3000
encoder_input_ids = (torch.randint(
1, 100, (batch_size, self.n_mels, feature_len)).int().cuda())
encoder_input_lengths = torch.tensor([
encoder_input_ids.shape[2] // 2
for _ in range(encoder_input_ids.shape[0])
],
dtype=torch.int32,
device=self.device)
decoder_input_ids = (torch.randint(1, 100, (1, )).int().cuda())
decoder_input_ids = decoder_input_ids.repeat(
(encoder_input_ids.shape[0], 1))
output_list = [
TensorInfo('x', str_dtype_to_trt(self.dtype),
encoder_input_ids.shape),
TensorInfo('input_lengths', str_dtype_to_trt('int32'),
encoder_input_lengths.shape)
]
output_info = (self.encoder_session).infer_shapes(output_list)
outputs = {
t.name: torch.empty(tuple(t.shape),
dtype=trt_dtype_to_torch(t.dtype),
device='cuda')
for t in output_info
}
whisper_decoder_encoder_input_lengths = torch.tensor(
[
outputs['output'].shape[1]
for x in range(outputs['output'].shape[0])
],
dtype=torch.int32,
device='cuda')
decoder_input_lengths = torch.tensor([
decoder_input_ids.shape[-1]
for _ in range(decoder_input_ids.shape[0])
],
dtype=torch.int32,
device='cuda')
cross_attention_mask = torch.ones(
[outputs['output'].shape[0], 1,
outputs['output'].shape[1]]).int().cuda()
else:
encoder_input_ids = (torch.randint(
100, (batch_size, encoder_input_len)).int().cuda())
# For now, just hardcode the decoder_start_token_id to 0 for t5 models.
decoder_start_token_id = 0
decoder_input_ids = torch.IntTensor([[decoder_start_token_id]
]).to(self.device)
decoder_input_ids = decoder_input_ids.repeat(
(encoder_input_ids.shape[0], 1))
# in padding mode --> keep input, just calculate actual length and max length
# Note: 1st token should always count, even if it is pad_token_id (0). e.g., decoder start id in enc-dec models could be a single pad_token_id, we should count
encoder_input_lengths = ((
1 + (encoder_input_ids[:, 1:] != 0).sum(dim=1).type(
torch.IntTensor).to(self.device)).clone().detach().to(
dtype=torch.int32, device=self.device))
decoder_input_lengths = ((
1 + (decoder_input_ids[:, 1:] != 0).sum(dim=1).type(
torch.IntTensor).to(self.device)).clone().detach().to(
dtype=torch.int32, device=self.device))
# attention mask, always set 1 as if all are valid tokens
attention_mask = torch.ones(
(batch_size, encoder_input_len)).int().cuda()
# cross attention mask, always set 1 as if all are valid tokens
# [batch_size, query_len, encoder_input_len] currently, use query_len=1
cross_attention_mask = torch.ones(
(batch_size, 1, encoder_input_len)).int().cuda()
hidden_size = (self.encoder_model_config.hidden_size *
self.world_size) # tp_size
hidden_states_shape = (
encoder_input_ids.shape[0],
encoder_input_ids.shape[1],
hidden_size,
)
hidden_states_dtype = lambda name: trt_dtype_to_torch(
self.encoder_session.engine.get_tensor_dtype(name))
outputs["encoder_output"] = torch.empty(
hidden_states_shape,
dtype=hidden_states_dtype("encoder_output"),
device=self.device,
).contiguous()
stream = torch.cuda.current_stream().cuda_stream
return (
@ -248,6 +325,8 @@ class EncDecBenchmark(BaseBenchmark):
decoder_input_ids,
decoder_input_lengths,
cross_attention_mask,
whisper_decoder_encoder_input_lengths,
outputs,
stream,
)
@ -260,47 +339,37 @@ class EncDecBenchmark(BaseBenchmark):
decoder_input_ids,
decoder_input_lengths,
cross_attention_mask,
whisper_decoder_encoder_input_lengths,
outputs,
stream,
) = inputs
hidden_size = (self.encoder_model_config.hidden_size * self.world_size
) # tp_size
hidden_states_shape = (
encoder_input_ids.shape[0],
encoder_input_ids.shape[1],
hidden_size,
)
hidden_states_dtype = lambda name: trt_dtype_to_torch(
self.encoder_session.engine.get_tensor_dtype(name))
# input tensors
inputs = {}
inputs["input_ids"] = encoder_input_ids.contiguous()
inputs["input_lengths"] = encoder_input_lengths
inputs["max_input_length"] = torch.empty(
(self.max_input_len, ),
dtype=hidden_states_dtype("max_input_length"),
device=self.device,
).contiguous()
if 'whisper' in self.model_name:
inputs['x'] = encoder_input_ids.contiguous()
inputs["input_lengths"] = encoder_input_lengths
else:
inputs["input_ids"] = encoder_input_ids.contiguous()
inputs["input_lengths"] = encoder_input_lengths
inputs["max_input_length"] = torch.empty(
(self.max_input_len, ),
dtype=hidden_states_dtype("max_input_length"),
device=self.device,
).contiguous()
if not self.encoder_model_config.gpt_attention_plugin:
inputs["attention_mask"] = attention_mask.contiguous()
if not self.encoder_model_config.gpt_attention_plugin:
inputs["attention_mask"] = attention_mask.contiguous()
if self.encoder_model_config.has_position_embedding:
bsz, seq_len = encoder_input_ids.shape[:2]
position_ids = torch.arange(seq_len,
dtype=torch.int32,
device=encoder_input_ids.device).expand(
bsz, -1)
inputs['position_ids'] = position_ids.contiguous()
# output tensors
outputs = {}
outputs["encoder_output"] = torch.empty(
hidden_states_shape,
dtype=hidden_states_dtype("encoder_output"),
device=self.device,
).contiguous()
if self.encoder_model_config.has_position_embedding:
bsz, seq_len = encoder_input_ids.shape[:2]
position_ids = torch.arange(
seq_len, dtype=torch.int32,
device=encoder_input_ids.device).expand(bsz, -1)
inputs['position_ids'] = position_ids.contiguous()
# run encoder
self.encoder_session.set_shapes(inputs)
@ -311,6 +380,12 @@ class EncDecBenchmark(BaseBenchmark):
# run decoder
sampling_config = tensorrt_llm.runtime.SamplingConfig(
end_id=1, pad_id=0, num_beams=self.num_beams, min_length=output_len)
encoder_output = outputs[
'output'] if 'whisper' in self.model_name else outputs[
"encoder_output"]
encoder_max_input_length = encoder_output.shape[
1] if 'whisper' in self.model_name else torch.max(
encoder_input_lengths).item()
self.decoder_session.setup(
decoder_input_lengths.size(0),
@ -318,9 +393,8 @@ class EncDecBenchmark(BaseBenchmark):
output_len,
beam_width=self.num_beams,
max_attention_window_size=None,
encoder_max_input_length=torch.max(encoder_input_lengths).item(),
encoder_max_input_length=encoder_max_input_length,
)
torch.cuda.synchronize()
cross_attention_mask = None if self.decoder_model_config.gpt_attention_plugin else cross_attention_mask
@ -328,11 +402,11 @@ class EncDecBenchmark(BaseBenchmark):
decoder_input_ids,
decoder_input_lengths,
sampling_config,
encoder_output=outputs["encoder_output"],
encoder_input_lengths=encoder_input_lengths,
encoder_output=encoder_output,
encoder_input_lengths=whisper_decoder_encoder_input_lengths
if 'whisper' in self.model_name else encoder_input_lengths,
cross_attention_mask=cross_attention_mask,
)
torch.cuda.synchronize()
def report(self,
config,


@ -182,7 +182,7 @@ if(ENABLE_MULTI_DEVICE EQUAL 1)
find_library(NCCL_LIB nccl HINTS ${NCCL_LIB_DIR})
endif()
get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_SOURCE_DIR} PATH)
get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH)
set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
include_directories(


@ -51,7 +51,7 @@ public:
batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,
SendResponseCallback sendResponseCb, PollStopSignalCallback pollStopSignalCb = nullptr,
ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
const TrtGptModelOptionalParams& optionalParams = TrtGptModelOptionalParams(),
TrtGptModelOptionalParams const& optionalParams = TrtGptModelOptionalParams(),
std::optional<uint64_t> terminateReqId = std::nullopt, std::optional<SizeType> maxDraftTokens = std::nullopt,
bool excludeInputInOutput = false);
@ -82,9 +82,9 @@ protected:
virtual BatchManagerErrorCode_t step(RequestList& activeRequests, std::set<uint64_t>& activeRequestsIds);
private:
SizeType getMaxInputLen() const;
SizeType getMaxSequenceLen() const;
SizeType getMaxNumSequences() const;
[[nodiscard]] SizeType getMaxInputLen() const;
[[nodiscard]] SizeType getMaxSequenceLen() const;
[[nodiscard]] SizeType getMaxNumSequences() const;
void validateLlmRequest(LlmRequest& newReq) const;
static std::shared_ptr<LlmRequest> fillLlmRequest(std::shared_ptr<InferenceRequest> newReq);
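Many getters in this update gain [[nodiscard]], as in the hunk above. A short standalone illustration of what the attribute buys; the free function and its return value are hypothetical, not the batch manager's implementation:

[[nodiscard]] int getMaxInputLen()
{
    return 2048; // illustrative value only
}

int main()
{
    getMaxInputLen();                 // compiler warning: discarded [[nodiscard]] value
    int const len = getMaxInputLen(); // fine: the value is used
    return len > 0 ? 0 : 1;
}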


@ -1,96 +0,0 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/schedulerPolicy.h"
#include "tensorrt_llm/runtime/common.h"
#include <list>
#include <memory>
namespace tensorrt_llm::batch_manager::batch_scheduler
{
/// @brief This scheduler takes into account the given request capacity and the KV cache capacity.
/// Depending on the SchedulerPolicy it will schedule already started and new requests,
/// or even terminate previously started requests.
class CapacityScheduler
{
public:
virtual ~CapacityScheduler() = default;
using RequestTable = std::map<uint64_t, std::shared_ptr<LlmRequest>>;
using SizeType = tensorrt_llm::runtime::SizeType;
using RequestList = std::list<std::shared_ptr<LlmRequest>>;
/// @brief Takes as input a sorted list of requests and outputs a sorted lists of requests
/// to update for this current iteration, and a map of requests to terminate
virtual std::tuple<RequestList, RequestTable> scheduleRequests(const RequestList& activeRequests) = 0;
};
/// @brief Schedule up to maxNumRequests requests
class MaxRequestsScheduler : public CapacityScheduler
{
public:
explicit MaxRequestsScheduler(SizeType maxNumRequests);
std::tuple<RequestList, RequestTable> scheduleRequests(const RequestList& activeRequests) override;
private:
SizeType mMaxNumRequests;
};
/// @brief Schedule requests using the MAX_UTILIZATION policy
/// @details Try reserving resources to advance requests by one step,
/// may terminate previously started requests.
class MaxUtilizationScheduler : public CapacityScheduler
{
public:
MaxUtilizationScheduler(
SizeType maxNumRequests, kv_cache_manager::KVCacheManager* kvCacheManager, bool manyMicroBatches);
std::tuple<RequestList, RequestTable> scheduleRequests(const RequestList& activeRequests) override;
private:
bool trySchedulingRequestMaxUtilization(
std::shared_ptr<LlmRequest> const& req, RequestList& scheduledRequests, SizeType& numScheduledBlocks);
SizeType mMaxNumRequests;
kv_cache_manager::KVCacheManager* mKvCacheManager{nullptr};
/// @brief Boolean that indicates if multiple micro batches might be in flight
bool mManyMicroBatches;
};
/// @brief Schedule requests using the GUARANTEED_NO_EVICT policy
class GuaranteedNoEvictScheduler : public CapacityScheduler
{
public:
GuaranteedNoEvictScheduler(SizeType maxNumRequests, kv_cache_manager::KVCacheManager* kvCacheManager);
std::tuple<RequestList, RequestTable> scheduleRequests(const RequestList& activeRequests) override;
private:
SizeType mMaxNumRequests;
kv_cache_manager::KVCacheManager* mKvCacheManager{nullptr};
};
std::unique_ptr<CapacityScheduler> makeCapacityScheduler(tensorrt_llm::runtime::SizeType maxNumRequests,
kv_cache_manager::KVCacheManager* kvCacheManager, SchedulerPolicy schedulerPolicy, bool manyMicroBatches = false);
} // namespace tensorrt_llm::batch_manager::batch_scheduler


@ -16,8 +16,8 @@
#pragma once
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/namedTensor.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <algorithm>
@ -48,6 +48,7 @@ auto constexpr kTemperatureTensorName = "temperature";
auto constexpr kRuntimeTopKTensorName = "runtime_top_k";
auto constexpr kRuntimeTopPTensorName = "runtime_top_p";
auto constexpr kLengthPenaltyTensorName = "len_penalty";
auto constexpr kEarlyStoppingTensorName = "early_stopping";
auto constexpr kRepetitionPenaltyTensorName = "repetition_penalty";
auto constexpr kMinLengthTensorName = "min_length";
auto constexpr kPresencePenaltyTensorName = "presence_penalty";
@ -74,6 +75,11 @@ auto constexpr kLoraWeights = "lora_weights";
// "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
// "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
// "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
// "cross_attn_qkv": 8 # for enc-dec adapter for cross attention in decoder
// "cross_attn_q": 9 # for enc-dec adapter for cross attention in decoder
// "cross_attn_k": 10 # for enc-dec adapter for cross attention in decoder
// "cross_attn_v": 11 # for enc-dec adapter for cross attention in decoder
// "cross_attn_dense": 12 # for enc-dec adapter for cross attention in decoder
//
// last dim holds [ module_id, layer_idx, adapter_size (D / R value) ]
auto constexpr kLoraConfig = "lora_config"; // [num_lora_modules_layers, 3]
@ -91,24 +97,29 @@ auto constexpr kGenerationLogitsName = "generation_logits";
} // namespace inference_request
template <typename TTensor, typename TNamedTensor>
template <typename TTensor, typename TNamedTensor, typename TStream = runtime::BufferManager::CudaStreamPtr>
class GenericInferenceRequest
{
public:
using TensorPtr = TTensor;
using NamedTensorType = TNamedTensor;
using TensorMap = std::unordered_map<std::string, TTensor>;
using LogitsPostProcessor = typename GenericLlmRequest<TensorPtr, TStream>::LogitsPostProcessor;
explicit GenericInferenceRequest(uint64_t requestId)
explicit GenericInferenceRequest(
uint64_t requestId, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
: mRequestId{requestId}
, mIsStreaming{false}
, mlogitsPostProcessor(logitsPostProcessor)
{
}
GenericInferenceRequest(uint64_t requestId, TensorMap&& tensorMap)
GenericInferenceRequest(uint64_t requestId, TensorMap&& tensorMap,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
: mRequestId{requestId}
, mIsStreaming{false}
, mInputTensors{std::move(tensorMap)}
, mlogitsPostProcessor(logitsPostProcessor)
{
for (auto const& [name, tensor] : mInputTensors)
{
@ -116,8 +127,9 @@ public:
}
}
GenericInferenceRequest(uint64_t requestId, TensorMap const& tensorMap)
: GenericInferenceRequest(requestId, TensorMap{tensorMap})
GenericInferenceRequest(uint64_t requestId, TensorMap const& tensorMap,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
: GenericInferenceRequest(requestId, TensorMap{tensorMap}, logitsPostProcessor)
{
}
@ -141,6 +153,16 @@ public:
return mInputTensors;
}
void setLogitsPostProcessor(std::optional<LogitsPostProcessor> cb)
{
mlogitsPostProcessor = cb;
}
std::optional<LogitsPostProcessor> getLogitsPostProcessor()
{
return mlogitsPostProcessor;
}
static std::array constexpr kTensorNames = {
inference_request::kInputIdsTensorName,
inference_request::kDraftInputIdsTensorName,
@ -156,6 +178,7 @@ public:
inference_request::kRuntimeTopKTensorName,
inference_request::kRuntimeTopPTensorName,
inference_request::kLengthPenaltyTensorName,
inference_request::kEarlyStoppingTensorName,
inference_request::kRepetitionPenaltyTensorName,
inference_request::kMinLengthTensorName,
inference_request::kPresencePenaltyTensorName,
@ -200,7 +223,10 @@ public:
\
void set##funcName(TensorPtr const& tensor) \
{ \
TLLM_CHECK_WITH_INFO(tensor, "Cannot set nullptr when calling %s", __FUNCTION__); \
if constexpr (std::is_same_v<TensorPtr, tensorrt_llm::runtime::ITensor::SharedPtr>) \
{ \
TLLM_CHECK_WITH_INFO(tensor, "Cannot set nullptr when calling %s", __FUNCTION__); \
} \
mInputTensors[tensorName] = tensor; \
}
@ -218,6 +244,7 @@ public:
TENSOR_GETTER_SETTER(RuntimeTopK, inference_request::kRuntimeTopKTensorName)
TENSOR_GETTER_SETTER(RuntimeTopP, inference_request::kRuntimeTopPTensorName)
TENSOR_GETTER_SETTER(LengthPenalty, inference_request::kLengthPenaltyTensorName)
TENSOR_GETTER_SETTER(EarlyStopping, inference_request::kEarlyStoppingTensorName)
TENSOR_GETTER_SETTER(RepetitionPenalty, inference_request::kRepetitionPenaltyTensorName)
TENSOR_GETTER_SETTER(MinLength, inference_request::kMinLengthTensorName)
TENSOR_GETTER_SETTER(PresencePenalty, inference_request::kPresencePenaltyTensorName)
@ -243,6 +270,7 @@ protected:
uint64_t mRequestId;
bool mIsStreaming;
TensorMap mInputTensors;
std::optional<LogitsPostProcessor> mlogitsPostProcessor;
};
class InferenceRequest : public GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr, NamedTensor>
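The InferenceRequest changes above thread an optional LogitsPostProcessor callback through GenericInferenceRequest (constructor argument plus setLogitsPostProcessor/getLogitsPostProcessor). A standalone sketch of the stored-optional-callback pattern with deliberately simplified types; in the real header the callback also receives the beam tokens and a CUDA stream and returns a TensorPtr:

#include <cstdint>
#include <functional>
#include <optional>
#include <utility>
#include <vector>

// Simplified stand-ins for the library's id and tensor types.
using RequestId = std::uint64_t;
using Logits = std::vector<float>;
using LogitsPostProcessor = std::function<void(RequestId, Logits&)>;

class Request
{
public:
    explicit Request(RequestId id, std::optional<LogitsPostProcessor> cb = std::nullopt)
        : mRequestId{id}
        , mLogitsPostProcessor{std::move(cb)}
    {
    }

    void setLogitsPostProcessor(std::optional<LogitsPostProcessor> cb)
    {
        mLogitsPostProcessor = std::move(cb);
    }

    void onLogits(Logits& logits) const
    {
        if (mLogitsPostProcessor) // only invoke when a processor was registered
        {
            (*mLogitsPostProcessor)(mRequestId, logits);
        }
    }

private:
    RequestId mRequestId;
    std::optional<LogitsPostProcessor> mLogitsPostProcessor;
};

int main()
{
    Request req{42};
    req.setLogitsPostProcessor([](RequestId /*id*/, Logits& logits)
        {
            if (!logits.empty())
            {
                logits[0] = -1e9f; // e.g. suppress a banned token id before sampling
            }
        });

    Logits logits{0.1f, 0.7f, 0.2f};
    req.onLogits(logits);
    return 0;
}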


@ -326,7 +326,7 @@ private:
//! \param seqSlotIdx Batch slot of sequence to which blocks are assigned.
//! \return Number of matched tokens from loaded blocks.
SizeType loadOrAllocateBlocks(
std::list<VecTokens> blockedTokens, GenerationRequest& sequence, SizeType beamIdx, SizeType seqSlotIdx);
std::list<VecTokens> const& blockedTokens, GenerationRequest& sequence, SizeType beamIdx, SizeType seqSlotIdx);
//! \brief Find block least likely to be reused, free it if necessary and return.
[[nodiscard]] BlockPtr getFreeBlock();
@ -362,9 +362,9 @@ public:
using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
KVCacheManager(SizeType numLayers, SizeType numKvHeads, SizeType sizePerHead, SizeType tokensPerBlock,
SizeType maxNumBlocks, SizeType maxNumSequences, SizeType maxBeamWidth, SizeType maxBlocksPerSeq,
SizeType maxAttentionWindow, SizeType sinkTokenLength, bool useOneMoreBlock, nvinfer1::DataType dtype,
CudaStreamPtr stream, bool enableBlockReuse = false, bool useUvm = false);
SizeType maxNumBlocks, SizeType maxNumSequences, SizeType maxBeamWidth, SizeType maxAttentionWindow,
SizeType sinkTokenLength, bool useOneMoreBlock, nvinfer1::DataType dtype, CudaStreamPtr stream,
bool enableBlockReuse = false, bool useUvm = false);
void startScheduling();
@ -405,6 +405,11 @@ public:
return mBlockSize;
}
[[nodiscard]] SizeType getMaxBlocksPerSeq() const
{
return mMaxBlocksPerSeq;
}
[[nodiscard]] BlockManager const& getBlockManager() const
{
return mBlockManager;
@ -414,13 +419,13 @@ public:
/// iterations
/// @param req The request for which we need to calculate the number of needed KV cache blocks
/// @return The number of blocks
SizeType getNeededBlocksOneStep(LlmRequest const& req, bool twoStepsLookAhead) const;
[[nodiscard]] SizeType getNeededBlocksOneStep(LlmRequest const& req, bool twoStepsLookAhead) const;
/// @brief Function that computes the number of KV cache blocks needed to advance a request to completion (i.e. for
/// maxNewTokens)
/// @param req The request for which we need to calculate the number of needed KV cache blocks
/// @return The number of blocks
SizeType getNeededBlocksToCompletion(LlmRequest const& req) const;
[[nodiscard]] SizeType getNeededBlocksToCompletion(LlmRequest const& req) const;
[[nodiscard]] std::vector<runtime::ITensor::SharedPtr> const& getMemoryPools() const
{
@ -458,7 +463,7 @@ public:
* modelConfig.getSizePerHead();
}
[[nodiscard]] static SizeType getMaxNumTokens(KvCacheConfig const& config, nvinfer1::DataType dtype,
[[nodiscard]] static SizeType calculateMaxNumBlocks(KvCacheConfig const& config, nvinfer1::DataType dtype,
tensorrt_llm::runtime::GptModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig,
runtime::BufferManager const& bufferManager);
@ -491,10 +496,8 @@ private:
// Maximum kv cache length per sequence
// Enable cyclic kv cache when it exceeds
SizeType mMaxAttentionWindow;
// Sink token length in the kv cache per sequence
SizeType mSinkTokenLength;
// Bubble token length
SizeType mBubbleLength;
// Number of tokens to fill up the sink tokens to a full block size
SizeType mSinkBubbleLength;
// Maximum token length (including bubble)
SizeType mMaxTokenNum;
// Number of tokens in the sink blocks
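For the block accounting referenced above (getNeededBlocksOneStep / getNeededBlocksToCompletion), the underlying idea is that a paged KV cache stores tokens in fixed-size blocks, so the count is a ceiling division over the total token budget. A rough standalone sketch under that assumption; it is not KVCacheManager's exact formula, which also accounts for beams, sink tokens, and block reuse:

#include <cstdio>

int blocksToCompletion(int promptLen, int maxNewTokens, int tokensPerBlock)
{
    int const totalTokens = promptLen + maxNewTokens;
    return (totalTokens + tokensPerBlock - 1) / tokensPerBlock; // ceiling division
}

int main()
{
    // prints 2: 100 prompt + 28 new tokens fit in two 64-token blocks
    std::printf("%d\n", blocksToCompletion(100, 28, 64));
    return 0;
}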


@ -25,6 +25,7 @@
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>
namespace tensorrt_llm::batch_manager
@ -38,7 +39,7 @@ enum LlmRequestState_t
REQUEST_STATE_GENERATION_COMPLETE = 3
};
template <typename TTensor>
template <typename TTensor, typename TStream = runtime::BufferManager::CudaStreamPtr>
class GenericLlmRequest
{
public:
@ -49,9 +50,10 @@ public:
using VecLogProbs = std::vector<float>;
using BeamTokens = std::vector<VecTokens>;
using TensorPtr = TTensor;
using LogitsPostProcessor = std::function<TensorPtr(RequestIdType, TensorPtr&, BeamTokens const&, TStream)>;
GenericLlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
@ -59,7 +61,8 @@ public:
std::optional<TensorPtr> loraConfig = std::nullopt, bool returnLogProbs = false,
bool returnContextLogits = false, bool returnGenerationLogits = false,
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false)
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
: mRequestId(requestId)
, mPromptLen(inputTokens->size())
, mMaxNewTokens(maxNewTokens)
@ -69,15 +72,16 @@ public:
, mEndId(endId)
, mPadId(padId)
, mSeqSlot(-1)
, mLogitsPostProcessor(logitsPostProcessor)
, mOrigPromptLen(mPromptLen)
, mMaxSentTokenPos(mPromptLen - 1)
, mEmbeddingBias(embeddingBias)
, mBadWordsList(badWordsList)
, mStopWordsList(stopWordsList)
, mPromptEmbeddingTable(promptEmbeddingTable)
, mEmbeddingBias(std::move(embeddingBias))
, mBadWordsList(std::move(badWordsList))
, mStopWordsList(std::move(stopWordsList))
, mPromptEmbeddingTable(std::move(promptEmbeddingTable))
, mPromptVocabSize(promptVocabSize)
, mLoraWeights(loraWeights)
, mLoraConfig(loraConfig)
, mLoraWeights(std::move(loraWeights))
, mLoraConfig(std::move(loraConfig))
, mReturnLogProbs(returnLogProbs)
, mContextChunkSize(std::nullopt)
, mContextCurrentPosition(0)
@ -198,14 +202,14 @@ public:
/// @brief Get total number of tokens for this req (prompt + generated)
/// @param beam The beam index
/// @return The number of tokens
SizeType getNumTokens(SizeType beam) const
[[nodiscard]] SizeType getNumTokens(SizeType beam) const
{
return mTokens.at(beam).size();
}
/// @brief Get max number of tokens across all beams
/// @return The number of tokens
SizeType getMaxBeamNumTokens() const
[[nodiscard]] SizeType getMaxBeamNumTokens() const
{
SizeType maxTokens = 0;
for (SizeType beam = 0; beam < mSamplingConfig.beamWidth; ++beam)
@ -219,7 +223,7 @@ public:
/// @param beam The beam index
/// @param pos The position of the token relative to beginning of the prompt
/// @return The token index
TokenIdType getToken(SizeType beam, SizeType pos) const
[[nodiscard]] TokenIdType getToken(SizeType beam, SizeType pos) const
{
return mTokens.at(beam).at(pos);
}
@ -227,42 +231,42 @@ public:
/// @brief Get the tokens at a given beam index
/// @param beam The beam index
/// @return A vector of tokens for this beam index, includes the prompt
VecTokens const& getTokens(SizeType beam) const
[[nodiscard]] VecTokens const& getTokens(SizeType beam) const
{
return mTokens.at(beam);
}
/// @brief Get all tokens (input+output) for all beams
/// @return A vector of vector of tokens.
BeamTokens const& getTokens() const
[[nodiscard]] BeamTokens const& getTokens() const
{
return mTokens;
}
/// @brief Get the draft tokens
/// @return shared_ptr to vector of draft tokens
std::shared_ptr<VecTokens> const& getDraftTokens() const
[[nodiscard]] std::shared_ptr<VecTokens> const& getDraftTokens() const
{
return mDraftTokens;
}
/// @brief Get the logits for the draft tokens
/// @return Tensor of draft logits
std::optional<TensorPtr> getDraftLogits() const
[[nodiscard]] std::optional<TensorPtr> getDraftLogits() const
{
return mDraftLogits;
}
/// @brief Returns true if request has draft tokens
/// @return flag
bool hasDraftTokens() const
[[nodiscard]] bool hasDraftTokens() const
{
return mDraftTokens && mDraftTokens->size() > 0;
return mDraftTokens && !mDraftTokens->empty();
}
/// @brief Get the maximum number of generated tokens among all rays in beam
/// @return The number of generated tokens (doesn't include the prompt tokens)
SizeType getMaxNumGeneratedTokens() const
[[nodiscard]] SizeType getMaxNumGeneratedTokens() const
{
return getMaxBeamNumTokens() - mPromptLen;
}
@ -283,14 +287,14 @@ public:
assert(static_cast<size_t>(mSamplingConfig.beamWidth) == beamTokens.size());
for (std::size_t beam = 0; beam < beamTokens.size(); ++beam)
{
const auto outputId = beamTokens[beam];
auto const outputId = beamTokens[beam];
mTokens.at(beam).push_back(outputId);
}
}
/// @brief Sets the generated tokens for all beams. Erases all previous generated tokens.
/// @param generatedBeamTokens The generated tokens for all beams (vector of vector of tokens)
void setGeneratedTokens(const BeamTokens& generatedBeamTokens)
void setGeneratedTokens(BeamTokens const& generatedBeamTokens)
{
assert(generatedBeamTokens.size() == static_cast<size_t>(mSamplingConfig.beamWidth));
for (std::size_t beam = 0; beam < generatedBeamTokens.size(); ++beam)
@ -346,7 +350,7 @@ public:
/// @brief Get the maximum position of the tokens returned to the client. Use to ensure we don't return to
/// client duplicated token positions.
/// @return The maximum position of the tokens sent to the client
SizeType getMaxSentTokenPos() const
[[nodiscard]] SizeType getMaxSentTokenPos() const
{
return mMaxSentTokenPos;
}
@ -359,17 +363,17 @@ public:
mMaxSentTokenPos = pos;
}
std::optional<TensorPtr> getPromptEmbeddingTable() const
[[nodiscard]] std::optional<TensorPtr> getPromptEmbeddingTable() const
{
return mPromptEmbeddingTable;
}
std::optional<SizeType> getPromptVocabSize() const
[[nodiscard]] std::optional<SizeType> getPromptVocabSize() const
{
return mPromptVocabSize;
}
std::optional<TensorPtr> getLoraWeights() const
[[nodiscard]] std::optional<TensorPtr> getLoraWeights() const
{
return mLoraWeights;
}
@ -384,7 +388,7 @@ public:
mLoraWeights = std::nullopt;
}
std::optional<TensorPtr> getLoraConfig() const
[[nodiscard]] std::optional<TensorPtr> getLoraConfig() const
{
return mLoraConfig;
}
@ -399,32 +403,37 @@ public:
mLoraConfig = std::nullopt;
}
std::optional<TensorPtr> getEmbeddingBias() const
[[nodiscard]] std::optional<TensorPtr> getEmbeddingBias() const
{
return mEmbeddingBias;
}
std::optional<TensorPtr> getBadWordsList() const
[[nodiscard]] std::optional<TensorPtr> getBadWordsList() const
{
return mBadWordsList;
}
std::optional<TensorPtr> getStopWordsList() const
[[nodiscard]] std::optional<TensorPtr> getStopWordsList() const
{
return mStopWordsList;
}
bool returnLogProbs() const
[[nodiscard]] bool returnLogProbs() const
{
return mReturnLogProbs;
}
std::vector<VecLogProbs> const& getLogProbs() const
void setReturnLogProbs(bool returnLogProbs)
{
mReturnLogProbs = returnLogProbs;
}
[[nodiscard]] std::vector<VecLogProbs> const& getLogProbs() const
{
return mLogProbs;
}
VecLogProbs const& getLogProbs(SizeType beam) const
[[nodiscard]] VecLogProbs const& getLogProbs(SizeType beam) const
{
return mLogProbs.at(beam);
}
@ -435,7 +444,7 @@ public:
mLogProbs.at(beam).insert(mLogProbs.at(beam).end(), logProbs.begin(), logProbs.end());
}
VecLogProbs const& getCumLogProbs() const
[[nodiscard]] VecLogProbs const& getCumLogProbs() const
{
return mCumLogProbs;
}
@ -445,17 +454,17 @@ public:
mCumLogProbs.at(beam) = cumLogProb;
}
SizeType getOrigPromptLen() const
[[nodiscard]] SizeType getOrigPromptLen() const
{
return mOrigPromptLen;
}
void setDraftTokens(const std::shared_ptr<VecTokens>& draftTokens)
void setDraftTokens(std::shared_ptr<VecTokens> const& draftTokens)
{
mDraftTokens = draftTokens;
}
void setDraftLogits(const std::optional<TensorPtr>& draftLogits)
void setDraftLogits(std::optional<TensorPtr> const& draftLogits)
{
mDraftLogits = draftLogits;
}
@ -470,7 +479,7 @@ public:
mReturnContextLogits = returnContextLogits;
}
bool getReturnContextLogits() const
[[nodiscard]] bool getReturnContextLogits() const
{
return mReturnContextLogits;
}
@ -480,12 +489,12 @@ public:
mReturnGenerationLogits = returnGenerationLogits;
}
bool getReturnGenerationLogits() const
[[nodiscard]] bool getReturnGenerationLogits() const
{
return mReturnGenerationLogits;
}
TensorPtr const& getContextLogitsHost() const
[[nodiscard]] TensorPtr const& getContextLogitsHost() const
{
return mContextLogitsHost;
}
@ -501,7 +510,7 @@ public:
runtime::ITensor::makeShape({mPromptLen, vocabSizePadded}), logitsDataType);
}
TensorPtr const& getGenerationLogitsHost() const
[[nodiscard]] TensorPtr const& getGenerationLogitsHost() const
{
return mGenerationLogitsHost;
}
@ -517,7 +526,7 @@ public:
runtime::ITensor::makeShape({mSamplingConfig.beamWidth, mMaxNewTokens, vocabSizePadded}), logitsDataType);
}
std::vector<TensorPtr> const& getGenerationLogitsFragments() const
[[nodiscard]] std::vector<TensorPtr> const& getGenerationLogitsFragments() const
{
return mGenerationLogitsFragments;
}
@ -537,38 +546,38 @@ public:
mGenerationLogitsFragments.clear();
}
bool isContextInitState() const noexcept
[[nodiscard]] bool isContextInitState() const noexcept
{
return mState == REQUEST_STATE_CONTEXT_INIT;
}
bool isGenerationInProgessState() const noexcept
[[nodiscard]] bool isGenerationInProgessState() const noexcept
{
return mState == REQUEST_STATE_GENERATION_IN_PROGRESS;
}
/// To determine whether the context is unchunked. When a context is chunked into only a part, it
/// is still different from the unchunked state, which indicates the initial status.
bool isFullContextRequest() const noexcept
[[nodiscard]] bool isFullContextRequest() const noexcept
{
return isContextInitState() && !mContextChunkSize;
}
/// When chunked, the position of the current chunk is returned. Otherwise, only the beginning
/// or end of the context is returned.
SizeType getContextCurrentPosition() const noexcept
[[nodiscard]] SizeType getContextCurrentPosition() const noexcept
{
return mContextCurrentPosition;
}
/// Return the length of the context that has not yet been processed.
SizeType getContextRemainingLength() const noexcept
[[nodiscard]] SizeType getContextRemainingLength() const noexcept
{
return mPromptLen - getContextCurrentPosition();
}
/// To retrieve the context chunk size, throw an exception when the context is not chunked.
SizeType getContextChunkSize() const
[[nodiscard]] SizeType getContextChunkSize() const
{
TLLM_CHECK_WITH_INFO(
isContextInitState() && mContextChunkSize, "The current request is not in context chunking state.");
@ -587,7 +596,7 @@ public:
/// Determines whether the current position is only one chunk away from the end of the context.
/// It will return true when the context is not chunked.
bool isLastContextChunk() const noexcept
[[nodiscard]] bool isLastContextChunk() const noexcept
{
return isFullContextRequest()
|| (isContextInitState() && getContextCurrentPosition() + getContextChunkSize() == mPromptLen);
@ -595,7 +604,7 @@ public:
/// Returns whether the position is at the beginning of the context. It will return true when the
/// context is not chunked.
bool isFirstContextChunk() const noexcept
[[nodiscard]] bool isFirstContextChunk() const noexcept
{
return isFullContextRequest() || getContextCurrentPosition() == 0;
}
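The chunking helpers above (getContextCurrentPosition, getContextChunkSize, isFirstContextChunk, isLastContextChunk) walk the prompt in fixed-size pieces. A small standalone sketch of that arithmetic with assumed example numbers, not the request class itself:

#include <algorithm>
#include <cstdio>

int main()
{
    int const promptLen = 10; // assumed example: 10 prompt tokens
    int const chunkSize = 4;  // processed in context chunks of 4
    int position = 0;         // plays the role of getContextCurrentPosition()

    while (position < promptLen)
    {
        int const thisChunk = std::min(chunkSize, promptLen - position);
        bool const firstChunk = (position == 0);
        bool const lastChunk = (position + thisChunk == promptLen);
        std::printf("chunk [%d, %d) first=%d last=%d\n", position, position + thisChunk,
            firstChunk, lastChunk);
        position += thisChunk;
    }
    return 0;
}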
@ -704,6 +713,7 @@ public:
std::optional<SizeType> mEndId;
std::optional<SizeType> mPadId;
SizeType mSeqSlot;
std::optional<LogitsPostProcessor> mLogitsPostProcessor;
protected:
SizeType mOrigPromptLen;
@ -805,7 +815,7 @@ public:
using VecTokens = Base::VecTokens;
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
@ -813,10 +823,13 @@ public:
std::optional<TensorPtr> loraConfig = std::nullopt, bool returnLogProbs = false,
bool returnContextLogits = false, bool returnGenerationLogits = false,
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false)
: Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId, embeddingBias,
badWordsList, stopWordsList, promptEmbeddingTable, promptVocabSize, loraWeights, loraConfig, returnLogProbs,
returnContextLogits, returnGenerationLogits, draftTokens, draftLogits, excludeInputFromOutput)
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
: Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList),
std::move(promptEmbeddingTable), promptVocabSize, std::move(loraWeights), std::move(loraConfig),
returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens), std::move(draftLogits),
excludeInputFromOutput, std::move(logitsPostProcessor))
{
}


@ -302,6 +302,8 @@ public:
void allgather(const void* sendbuf, void* recvbuf, int count, MpiType dtype) const;
void barrier() const;
void mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;
bool operator==(MpiComm const& rhs) const
{
return mComm == rhs.mComm;


@ -46,8 +46,8 @@ public:
std::optional<FloatType> beamSearchDiversityRate = std::nullopt,
std::optional<FloatType> repetitionPenalty = std::nullopt,
std::optional<FloatType> presencePenalty = std::nullopt,
std::optional<FloatType> frequencyPenalty = std::nullopt,
std::optional<FloatType> lengthPenalty = std::nullopt);
std::optional<FloatType> frequencyPenalty = std::nullopt, std::optional<FloatType> lengthPenalty = std::nullopt,
std::optional<SizeType> earlyStopping = std::nullopt);
~SamplingConfig();
@ -65,6 +65,7 @@ public:
[[nodiscard]] std::optional<FloatType> getPresencePenalty() const;
[[nodiscard]] std::optional<FloatType> getFrequencyPenalty() const;
[[nodiscard]] std::optional<FloatType> getLengthPenalty() const;
[[nodiscard]] std::optional<SizeType> getEarlyStopping() const;
private:
SizeType mBeamWidth;
@ -81,6 +82,7 @@ private:
std::optional<FloatType> mPresencePenalty;
std::optional<FloatType> mFrequencyPenalty;
std::optional<FloatType> mLengthPenalty;
std::optional<SizeType> mEarlyStopping;
};
/// @brief Configuration that controls the outputs of a Result


@ -22,9 +22,7 @@
#include "tensorrt_llm/runtime/iTensor.h"
#include <NvInferRuntime.h>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>


@ -16,7 +16,6 @@
#pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"


@ -16,7 +16,6 @@
#pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"


@ -16,7 +16,6 @@
#pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/promptTuningParams.h"


@ -16,7 +16,6 @@
#pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"


@ -16,7 +16,6 @@
#pragma once
#include "tensorrt_llm/common/cudaAllocator.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingMode.h"
@ -24,7 +23,6 @@
#include "tensorrt_llm/runtime/samplingConfig.h"
#include <curand_kernel.h>
#include <cstdint>
#include <memory>
#include <NvInferRuntime.h>


@ -16,7 +16,6 @@
#pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"
@ -25,11 +24,7 @@
#include "tensorrt_llm/runtime/iGptDecoderBatch.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <cstdint>
#include <memory>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
namespace tensorrt_llm::runtime
@ -58,7 +53,7 @@ public:
TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) override;
void forwardSync(decoder_batch::Token const& e) override;
void forwardSync(decoder_batch::Token const& token) override;
void forwardAsync(decoder::Output& output, decoder::Input const& input) override;
@ -89,7 +84,7 @@ public:
//! @brief Gather final beam search results for request `batchIdx`.
//! Result will only be available after event returned.
[[nodiscard]] CudaEvent finalize(SizeType batchIdx) const;
[[nodiscard]] CudaEvent finalize(SizeType batchIdx) const override;
//! @brief Gather final beam search results for all requests.
void finalize() const override;
@ -108,7 +103,7 @@ public:
}
//! @returns [maxBeamWidth], cumulative log probabilities (per beam), on gpu
[[nodiscard]] TensorPtr getCumLogProbs(SizeType batchIdx) const
[[nodiscard]] TensorPtr getCumLogProbs(SizeType batchIdx) const override
{
auto tensor = ITensor::slice(mJointDecodingOutput->cumLogProbs, batchIdx, 1);
tensor->squeeze(0);
@ -122,7 +117,7 @@ public:
}
//! @returns [maxBeamWidth, maxSequenceLength], log probabilities (per beam), on gpu
[[nodiscard]] TensorPtr getLogProbs(SizeType batchIdx) const
[[nodiscard]] TensorPtr getLogProbs(SizeType batchIdx) const override
{
auto tensor = ITensor::slice(mJointDecodingOutput->logProbs, batchIdx, 1);
tensor->squeeze(0);
@ -141,7 +136,7 @@ public:
//! @returns [batchSize, beamWidth], tokens generated in `iter` (per beam), on gpu
[[nodiscard]] TensorPtr getNewTokens(SizeType iter = 0) const override
{
TensorPtr newTokensView = std::move(ITensor::slice(mJointDecodingOutput->newTokensSteps, iter, 1));
TensorPtr newTokensView = ITensor::slice(mJointDecodingOutput->newTokensSteps, iter, 1);
newTokensView->squeeze(0);
return ITensor::slice(newTokensView, 0, mActualBatchSize);
}
@ -149,7 +144,7 @@ public:
//! @returns [batchSize], the number of generation steps executed on each request
[[nodiscard]] std::vector<SizeType> getNbSteps() const override
{
return std::vector<SizeType>(mNbSteps.begin(), mNbSteps.begin() + mActualBatchSize);
return {mNbSteps.begin(), mNbSteps.begin() + mActualBatchSize};
}
//! @returns [1], number of finished sequences, in pinned host memory
@ -160,7 +155,7 @@ public:
private:
//! @brief Gather final beam search results for request `batchIdx`.
CudaEvent postProcessRequest(SizeType batchIdx) const;
[[nodiscard]] CudaEvent postProcessRequest(SizeType batchIdx) const;
//! @brief Initialize the decoder at `batchIdx` with a new `request`.
void newRequest(SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);


@ -31,7 +31,6 @@
#include <cuda_fp16.h>
#include <memory>
#include <ostream>
#include <stdexcept>
#include <type_traits>
#include <typeinfo>
#include <vector>


@ -16,14 +16,12 @@
#pragma once
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>
@ -156,30 +154,30 @@ public:
//! @param batchIdx index of the batch
//! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
//! ids without padding for request `batchIdx`, on gpu
virtual TensorPtr getOutputIds(SizeType batchIdx) const = 0;
[[nodiscard]] virtual TensorPtr getOutputIds(SizeType batchIdx) const = 0;
//! @brief Gather final beam search results for request `batchIdx`.
//! Result will only be available after event returned
virtual CudaEvent finalize(SizeType batchIdx) const = 0;
[[nodiscard]] virtual CudaEvent finalize(SizeType batchIdx) const = 0;
//! @returns [batchSize (actual)], marks finished requests (per batch)
virtual std::vector<bool> getFinished() const = 0;
[[nodiscard]] virtual std::vector<bool> getFinished() const = 0;
//! @returns [batchSize, beamWidth], cumulative log probabilities (per beam), on gpu
virtual TensorPtr getCumLogProbs() const = 0;
[[nodiscard]] virtual TensorPtr getCumLogProbs() const = 0;
//! @returns [beamWidth], cumulative log probabilities (per beam) for request batchIdx, on gpu
virtual TensorPtr getCumLogProbs(SizeType batchIdx) const = 0;
[[nodiscard]] virtual TensorPtr getCumLogProbs(SizeType batchIdx) const = 0;
//! @returns [batchSize, beamWidth, maxSeqLen], log probabilities (per beam), on gpu
virtual TensorPtr getLogProbs() const = 0;
[[nodiscard]] virtual TensorPtr getLogProbs() const = 0;
//! @returns [beamWidth, maxSeqLen], cumulative log probabilities (per beam) for request batchIdx, on gpu
virtual TensorPtr getLogProbs(SizeType batchIdx) const = 0;
[[nodiscard]] virtual TensorPtr getLogProbs(SizeType batchIdx) const = 0;
virtual TensorPtr getParentIds() const = 0;
[[nodiscard]] virtual TensorPtr getParentIds() const = 0;
virtual std::vector<SizeType> getNbSteps() const = 0;
[[nodiscard]] virtual std::vector<SizeType> getNbSteps() const = 0;
//! @brief Initialize batched decoder at seqSlots with a new `requests`.
virtual void newRequests(std::vector<SizeType> const& seqSlots, std::vector<decoder_batch::Request> const& requests,


@ -23,10 +23,8 @@
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>
#include <NvInferRuntime.h>
@ -102,25 +100,25 @@ public:
virtual void finalize() const = 0;
//! @returns [batchSize, beamWidth, maxSequenceLength], all token ids, on gpu
virtual TensorPtr getOutputIds() const = 0;
[[nodiscard]] virtual TensorPtr getOutputIds() const = 0;
//! @returns [batchSize, maxBeamWidth], cumulative log probabilities (per beam), on gpu
virtual TensorPtr getCumLogProbs() const = 0;
[[nodiscard]] virtual TensorPtr getCumLogProbs() const = 0;
//! @returns [batchSize, maxBeamWidth, maxSequenceLength], log probabilities (per beam), on gpu
virtual TensorPtr getLogProbs() const = 0;
[[nodiscard]] virtual TensorPtr getLogProbs() const = 0;
//! @brief Get tokens generated in one step of last forward pass
//! @param iter The iteration within [0; maxTokensPerStep) for which to get the tokens
//! @returns [batchSize, beamWidth], tokens generated in `iter` (per beam), on gpu
virtual TensorPtr getNewTokens(SizeType iter = 0) const = 0;
[[nodiscard]] virtual TensorPtr getNewTokens(SizeType iter = 0) const = 0;
//! @brief Get maxTokensPerStep tokens generated in the last forward pass
//! @returns [maxTokensPerStep, batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
virtual TensorPtr getAllNewTokens() const = 0;
[[nodiscard]] virtual TensorPtr getAllNewTokens() const = 0;
//! @returns [1], number of finished sequences, in pinned host memory
virtual TensorPtr getNbFinished() const = 0;
[[nodiscard]] virtual TensorPtr getNbFinished() const = 0;
virtual ~IStatefulGptDecoder() = default;
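The [[nodiscard]] annotations added throughout these decoder interfaces make the compiler warn whenever a returned tensor, event, or vector is silently dropped. A minimal illustration of the effect, using a hypothetical free function rather than the real interface:

struct Handle
{
};

[[nodiscard]] Handle finalize(); // stand-in for the accessors marked above

void caller()
{
    finalize();            // warning: the [[nodiscard]] result is discarded
    Handle h = finalize(); // fine: the result is kept (e.g. for later synchronization)
    (void) h;
}

Handle finalize()
{
    return {};
}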

View File

@ -30,7 +30,6 @@
#include <memory>
#include <numeric>
#include <ostream>
#include <stdexcept>
#include <string>
#include <type_traits>

View File

@ -18,7 +18,6 @@
#pragma once
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/worldConfig.h"

View File

@ -19,8 +19,8 @@
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include <algorithm>
#include <cstdint>
#include <atomic>
#include <cstddef>
#include <string>
namespace tensorrt_llm::runtime
@ -112,22 +112,22 @@ public:
auto const sizeDiff = -static_cast<DiffType>(size);
if constexpr (T == MemoryType::kGPU)
{
mGpu -= std::min(size, mGpu);
mGpu -= size;
mGpuDiff = sizeDiff;
}
else if constexpr (T == MemoryType::kCPU)
{
mCpu -= std::min(size, mCpu);
mCpu -= size;
mCpuDiff = sizeDiff;
}
else if constexpr (T == MemoryType::kPINNED)
{
mPinned -= std::min(size, mPinned);
mPinned -= size;
mPinnedDiff = sizeDiff;
}
else if constexpr (T == MemoryType::kUVM)
{
mUVM -= std::min(size, mUVM);
mUVM -= size;
mUVMDiff = sizeDiff;
}
else
@ -138,11 +138,7 @@ public:
void deallocate(MemoryType memoryType, SizeType size);
static MemoryCounters& getInstance()
{
static thread_local MemoryCounters mInstance;
return mInstance;
}
static MemoryCounters& getInstance();
static std::string bytesToString(SizeType bytes, int precision = 2);
@ -151,8 +147,8 @@ public:
[[nodiscard]] std::string toString() const;
private:
SizeType mGpu{}, mCpu{}, mPinned{}, mUVM{};
DiffType mGpuDiff{}, mCpuDiff{}, mPinnedDiff{}, mUVMDiff{};
std::atomic<SizeType> mGpu{}, mCpu{}, mPinned{}, mUVM{};
std::atomic<DiffType> mGpuDiff{}, mCpuDiff{}, mPinnedDiff{}, mUVMDiff{};
};
} // namespace tensorrt_llm::runtime
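With getInstance() moved out of the header (no longer thread_local) and the counters made std::atomic, MemoryCounters now acts as a single process-wide tally that several threads may update concurrently. A minimal sketch of why atomics are needed here, using a simplified stand-in class rather than the real MemoryCounters:

#include <atomic>
#include <cstdint>
#include <thread>

// Simplified stand-in tracking only one memory type; the real class also
// tracks CPU, pinned and UVM usage behind MemoryCounters::getInstance().
class GpuByteCounter
{
public:
    void allocate(std::int64_t size)
    {
        mBytes += size; // atomic fetch_add, safe under concurrent callers
    }

    void deallocate(std::int64_t size)
    {
        mBytes -= size; // atomic fetch_sub; no clamping, mirroring the new code
    }

    [[nodiscard]] std::int64_t current() const
    {
        return mBytes.load();
    }

private:
    std::atomic<std::int64_t> mBytes{0};
};

int main()
{
    GpuByteCounter counter;
    std::thread producer([&] { for (int i = 0; i < 1000; ++i) counter.allocate(16); });
    std::thread consumer([&] { for (int i = 0; i < 1000; ++i) counter.deallocate(16); });
    producer.join();
    consumer.join();
    return counter.current() == 0 ? 0 : 1; // totals stay consistent without a mutex
}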

View File

@ -16,12 +16,10 @@
#pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <optional>
#include <utility>
namespace tensorrt_llm::runtime

View File

@ -72,6 +72,7 @@ public:
{
TLLM_CHECK(configs.size() > 0);
beamWidth = configs.front().beamWidth;
normalizeLogProbs = configs.front().normalizeLogProbs;
temperature = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].temperature; });
minLength = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].minLength; });
repetitionPenalty
@ -87,6 +88,7 @@ public:
beamSearchDiversityRate
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].beamSearchDiversityRate; });
lengthPenalty = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].lengthPenalty; });
earlyStopping = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].earlyStopping; });
draftAcceptanceThreshold
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].draftAcceptanceThreshold; });
}
@ -121,6 +123,7 @@ public:
SET_FROM_OPTIONAL(presencePenalty, PresencePenalty, FloatType)
SET_FROM_OPTIONAL(frequencyPenalty, FrequencyPenalty, FloatType)
SET_FROM_OPTIONAL(lengthPenalty, LengthPenalty, FloatType)
SET_FROM_OPTIONAL(earlyStopping, EarlyStopping, SizeType)
#undef SET_FROM_OPTIONAL
}
@ -142,11 +145,12 @@ public:
OptVec<SizeType> topPResetIds; // [batch_size]
// beam search layer
OptVec<FloatType> beamSearchDiversityRate;
OptVec<FloatType> lengthPenalty;
OptVec<FloatType> beamSearchDiversityRate; // [1] or [batch_size]
OptVec<FloatType> lengthPenalty; // [1] or [batch_size]
OptVec<SizeType> earlyStopping; // [1] or [batch_size]
// speculative decoding
OptVec<FloatType> draftAcceptanceThreshold;
// speculative decoding, only the first value is used (in gptDecoderBatch.cpp)
OptVec<FloatType> draftAcceptanceThreshold; // [1] or [batch_size]
std::optional<bool> normalizeLogProbs;
};
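The new shape comments mark each beam-search option as [1] or [batch_size]: a single value broadcast to every request, or one value per request. A small sketch of reading such a field under that convention; the helper below is illustrative and not part of the library:

#include <cstddef>
#include <optional>
#include <vector>

template <typename T>
using OptVec = std::optional<std::vector<T>>;

// Hypothetical helper resolving a "[1] or [batch_size]" field for one request.
template <typename T>
T valueForRequest(OptVec<T> const& field, std::size_t batchIdx, T fallback)
{
    if (!field)
    {
        return fallback; // option not set at all
    }
    auto const& values = *field;
    return values.size() == 1 ? values.front() : values.at(batchIdx); // broadcast vs. per-request
}

// Example (illustrative): earlyStopping resolves per request, falling back to 1,
// i.e. stop as soon as beamWidth hypotheses are collected, matching the default
// used by gatherTreeParam further below.
// int earlyStopping = valueForRequest(samplingConfig.earlyStopping, 3, 1);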

View File

@ -195,7 +195,8 @@ set(TRTLLM_LINK_LIBS
${TRT_LIB}
common_src
kernels_src
cutlass_src
cutlass_src_pre_hopper
cutlass_src_hopper
layers_src
runtime_src)

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4ba61c04ed7623fc44b5364802c1893fa824467455f4a9fe8245d5d51fef97e6
size 2172266
oid sha256:c9fd644e0a38b1d4d1a54d4b7b834cc6b0110a5771fcfc480e96795b3f9bc892
size 2081046

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bf4afdfd281029c8e4bf0af548529b94a4a6d0f9bb5148ae10423e5e0275db06
size 2195822
oid sha256:90436c59eb243a0156e3f0aa95412a7caacbefdcde768c158edc4b821044dfd1
size 2102486

View File

@ -1,3 +1,3 @@
4c405d39a0cbb93d44a5758480a1a223 libtensorrt_llm_batch_manager_static.a
68aea75a2ed5b219eec5a0f77ce33482 libtensorrt_llm_batch_manager_static.pre_cxx11.a
9b63c754d2a1edc7a17106e83c3e131d312f0a80 commit
f53c02e3829b516a6e9221745bcbacbd libtensorrt_llm_batch_manager_static.a
9e92e5dbb104e3e676952ea40c81916f libtensorrt_llm_batch_manager_static.pre_cxx11.a
25adff90cc350eb9ca9804051a08de80d547c113 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:39835ca321e9c45d3b554ebceb1734966b75f83dbe8c550cc44846fb4fae8f72
size 2110728
oid sha256:c3433d7b52bb6dcac32111172cb6201a9fee56e739f3660895083baebd1b89ee
size 2033616

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:789c2eba349161e84a76b95b23f8294cf3bdcf855871672d76722c4ae858d81b
size 2091842
oid sha256:fb3f4145881984de6268c34f7e5d452f78f54952f454f747a1cd52bc3171de62
size 2012002

View File

@ -1,2 +1,2 @@
30a6c963121b3cfda21dc0117b7984e1 libtensorrt_llm_batch_manager_static.a
0d2d2e3157201f6336d749b3e6f994bc libtensorrt_llm_batch_manager_static.pre_cxx11.a
d60b12741e940f56addaf2d92e78b50f libtensorrt_llm_batch_manager_static.a
c55e606a3430d3a56cee3968a77b46f1 libtensorrt_llm_batch_manager_static.pre_cxx11.a

View File

@ -172,6 +172,11 @@ void MpiComm::allgather(const void* sendbuf, void* recvbuf, int count, MpiType d
MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm));
}
void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
{
MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status));
}
int MpiComm::getRank() const
{
int rank = 0;
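The new mprobe wraps MPI's matched probe, which is typically paired with MPI_Get_count and MPI_Mrecv so the receiver can size its buffer before pulling the message. A minimal sketch of that standard pattern using the raw MPI API (not the MpiComm wrapper itself):

#include <mpi.h>
#include <cstdint>
#include <vector>

// Receive an int64 packet of unknown length from `source` on `comm`.
// Standard MPI matched-probe pattern; error checking omitted for brevity.
std::vector<std::int64_t> recvPacked(MPI_Comm comm, int source, int tag)
{
    MPI_Message msg;
    MPI_Status status;
    MPI_Mprobe(source, tag, comm, &msg, &status); // what MpiComm::mprobe wraps

    int count = 0;
    MPI_Get_count(&status, MPI_INT64_T, &count);  // size the buffer first

    std::vector<std::int64_t> data(count);
    MPI_Mrecv(data.data(), count, MPI_INT64_T, &msg, &status);
    return data;
}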

View File

@ -34,6 +34,8 @@ enum class CutlassTileConfig
CtaShape128x128x8_WarpShape64x64x8,
// TensorCore configs CTA_N = 128, CTA_K = 64
// Warp configs for M=16
CtaShape16x128x64_WarpShape16x32x64,
// Warp configs for M=32
CtaShape32x128x64_WarpShape32x32x64,
@ -50,7 +52,10 @@ enum class CutlassTileConfig
CtaShape128x256x64_WarpShape64x64x64,
// Warp configs for M=256
CtaShape256x128x64_WarpShape64x64x64
CtaShape256x128x64_WarpShape64x64x64,
// TensorCore config CTA_N = 256, CTA_K = 64
CtaShape16x256x64_WarpShape16x64x64
};
enum class SplitKStyle

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:49c84a22cee9e6c3a975db08d8d0d8dbe88867e2eb4fc12a4b3ff6c1c90e8c21
size 586202
oid sha256:13e17e2d9a94d2bc1b131d096a3722a83a67ab115fa8271b57b27f7e2877bdc1
size 587334

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fea93ae7d09e74b073a65d5d0ac34aec9ccc8f8299af1abd6826e97e9c8427f4
size 589652
oid sha256:45438204eba812694bd30b68cfc9bb2bc54a8a59c6c86e037bbc4ac7e5f8230c
size 589438

View File

@ -1,3 +1,3 @@
73999f4c2b3a4328db454b7ab6fe86d3 libtensorrt_llm_executor_static.a
df53aa83848b5ed75550a7b536ca02a4 libtensorrt_llm_executor_static.pre_cxx11.a
9b63c754d2a1edc7a17106e83c3e131d312f0a80 commit
835767a37292ea9786c0d6149ae270f4 libtensorrt_llm_executor_static.a
1fe0c9ac7a1a35ce7d80676146867374 libtensorrt_llm_executor_static.pre_cxx11.a
25adff90cc350eb9ca9804051a08de80d547c113 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:643e546711fd33a85073560e3428c6a2f60525f7592aa3328043dfad61631c30
size 586532
oid sha256:7969768d3b9a65182ee519c60e11f27b0a088c2c0b732f3780d7c0c563dbb180
size 587776

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:28059131a9325c88bd362cb12c57a2b2e47d3e0aac140e5d1cf9a7020a81999e
size 570860
oid sha256:98d9b7c4a586f0be0499a0df487cacba69985ce43ca5fd543c90c6a368c91b67
size 571150

View File

@ -1,2 +1,2 @@
fa89714705a1915f052c635a07dc4c73 libtensorrt_llm_executor_static.a
83cbfaf10bedd7d8edeab33552dcf3df libtensorrt_llm_executor_static.pre_cxx11.a
6771f94e0bce39c6cab391cf1f92484c libtensorrt_llm_executor_static.a
84b7550448f8710de17644a5d404178f libtensorrt_llm_executor_static.pre_cxx11.a

View File

@ -28,7 +28,7 @@ if(FAST_BUILD)
"decoderMaskedMultiheadAttention(48|80|96|112|144|160|192|224).*cu$")
endif()
add_library(kernels_src OBJECT ${SRC_CPP} ${SRC_CU})
add_library(kernels_src STATIC ${SRC_CPP} ${SRC_CU})
set_property(TARGET kernels_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET kernels_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

View File

@ -39,7 +39,7 @@ namespace kernels
template <typename T>
__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty)
{
// score = log(prob) / (length)^length_penalty.
// score = log(prob) / (length ^ length_penalty)
if (length_penalty == 0.0f || length == 1)
{
return log_prob;
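The corrected comment spells out the scoring rule: score = log(prob) / length^length_penalty, with the division skipped when the exponent is 0 or the length is 1, since both leave the score unchanged. A host-side restatement of the same arithmetic, for illustration only:

#include <cmath>

// Host-side reference for the device helper above.
float applyLengthPenaltyRef(float logProb, int length, float lengthPenalty)
{
    if (lengthPenalty == 0.0f || length == 1)
    {
        return logProb; // length^lengthPenalty == 1, nothing to rescale
    }
    return logProb / std::pow(static_cast<float>(length), lengthPenalty);
}

// applyLengthPenaltyRef(-6.0f, 4, 1.0f) == -1.5f: longer hypotheses are
// penalised less per token, the usual beam-search length bonus.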

View File

@ -35,35 +35,42 @@ namespace kernels
// After we collect `beam_width` beams, we will sort them by their norm_scores.
struct BeamHypotheses
{
int* output_ids_tgt = nullptr;
int* sequence_lengths_tgt = nullptr;
float* cum_log_probs = nullptr; // cum_log
float* normed_scores = nullptr; // cum_log / (length**length_penalty)
float* log_probs = nullptr; // log probs of each generated token
float* min_normed_scores = nullptr; // record the min normed scores for each batch
int* num_beams = nullptr; // the number of finished beams we collect
bool* is_done = nullptr;
// TODO: simplify the pointers
// Pointers initialized in function prepareOutputs in gptDecoder.cpp
bool* is_done{nullptr}; // [batchSize], whether the batch is finished
const int* input_lengths{nullptr}; // [batchSize]
float* cum_log_probs{nullptr}; // [batchSize, 2 * beamWidth], outputs.cum_log_probs->template getPtr<float>()
float* log_probs{nullptr}; // [batchSize, 2 * beamWidth, maxSeqLen], not used?
float* min_normed_scores{nullptr}; // [batchSize], worst normed scores for each batch
float* normed_scores{nullptr}; // [batchSize, 2 * beamWidth], cum_log / (length ^ length_penalty)
int* num_beams{nullptr}; // [batchSize], count of finished beams for each batch
int* output_ids_tgt{nullptr}; // [batchSize, 2 * beamWidth, maxSeqLen],
int* sequence_lengths_tgt{nullptr}; // [batchSize, 2 * beamWidth], different from sequence_lengths_src
// Used to set inputs
const int* output_ids_src;
const int** output_ids_src_ptr;
const int* parent_ids_src;
const int** parent_ids_src_ptr;
const int* sequence_lengths_src;
const int* end_ids;
const float* log_probs_src;
const int* input_lengths;
// Pointers initialized in function invokeSoftMax in onlineBeamSearchLayer.cu
const int* end_ids{nullptr}; // get from SoftmaxParams
const int* output_ids_src{nullptr}; // for gatherTree
const int* parent_ids_src{nullptr}; // for gatherTree
const int** output_ids_src_ptr{nullptr}; // get from BeamSearchOutputParams for reading
const int** parent_ids_src_ptr{nullptr}; // get from BeamSearchOutputParams for reading
float* log_probs_src{nullptr}; // get from outputs.output_log_probs
int* sequence_lengths_src{nullptr}; // get from BeamSearchOutputParams
// For reading in function invokeTopkSoftMax but reading and writing in function invokeUpdate
int** output_ids_tgt_ptr{nullptr}; // get from BeamSearchOutputParams for writing
int** parent_ids_tgt_ptr{nullptr}; // get from BeamSearchOutputParams for writing
// some variables for kernels
int step;
int ite;
int batch_size;
int local_batch_size;
int max_seq_len;
float* length_penalties;
bool early_stopping = true;
bool is_return_normed_score = true; // return normed_cum_log_probs or cum_log_probs
// Other scalar values and buffers
int batch_size{0};
int beam_width{0};
int ite{0};
int local_batch_size{0};
int max_seq_len{0};
int step{0}; // useless in online version of beam search
int vocab_size{0};
float* diversity_rates{nullptr};
float* length_penalties{nullptr};
int* early_stoppings{nullptr};
bool is_return_normed_score{true}; // return normed_cum_log_probs or cum_log_probs
};
template <typename T>
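The shape annotations imply a row-major layout for the per-batch hypothesis buffers, so entry (batchIdx, slot) of a [batchSize, 2 * beamWidth] buffer sits at batchIdx * 2 * beamWidth + slot. A host-side sketch of recording one finished hypothesis under those shapes (illustrative, not the kernel code):

#include <algorithm>
#include <cmath>

// Names mirror the struct fields above; this is a simplified host-side sketch.
void recordFinishedBeam(float* normedScores, float* minNormedScores, int* numBeams,
    int batchIdx, int slot, int beamWidth, float cumLogProb, int length, float lengthPenalty)
{
    float const normed = cumLogProb / std::pow(static_cast<float>(length), lengthPenalty);
    int const idx = batchIdx * 2 * beamWidth + slot;                          // [batchSize, 2 * beamWidth]
    normedScores[idx] = normed;                                               // cum_log / (length ^ length_penalty)
    minNormedScores[batchIdx] = std::min(minNormedScores[batchIdx], normed);  // worst normed score per batch
    ++numBeams[batchIdx];                                                     // count of finished beams per batch
}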

View File

@ -57,9 +57,17 @@ endif()
file(GLOB_RECURSE CU_INSTANTIATIONS ${CMAKE_CURRENT_BINARY_DIR}/*.cu)
add_library(cutlass_src OBJECT ${SRC_CPP} ${SRC_CU} ${CU_INSTANTIATIONS})
set_property(TARGET cutlass_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cutlass_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_library(cutlass_src_pre_hopper STATIC ${SRC_CPP} ${SRC_CU})
set_property(TARGET cutlass_src_pre_hopper PROPERTY POSITION_INDEPENDENT_CODE
ON)
set_property(TARGET cutlass_src_pre_hopper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS
ON)
add_library(cutlass_src_hopper STATIC ${CU_INSTANTIATIONS})
set_property(TARGET cutlass_src_hopper PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cutlass_src_hopper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_dependencies(cutlass_src_hopper cutlass_src_pre_hopper)
# Note - we deliberately do not include 90a PTX (even when 9.0+PTX is
# specified). This is because sm_90a has arch conditional instructions that are
@ -70,15 +78,21 @@ if("9.0" IN_LIST TORCH_CUDA_ARCH_LIST
OR TORCH_CUDA_ARCH_LIST STREQUAL "Auto")
message(STATUS "MANUALLY APPENDING FLAG TO COMPILE FOR SM_90a.")
target_compile_options(
cutlass_src
cutlass_src_pre_hopper
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_90a,code=sm_90a>)
target_compile_options(
cutlass_src_hopper
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_90a,code=sm_90a>)
# Hopper kernels require cuda lib for TMA APIs
target_link_libraries(cutlass_src PRIVATE CUDA::cuda_driver)
target_link_libraries(cutlass_src_pre_hopper PRIVATE CUDA::cuda_driver)
target_link_libraries(cutlass_src_hopper PRIVATE CUDA::cuda_driver)
# No kernels should be parsed, unless hopper is specified. This is a build
# time improvement
target_compile_definitions(cutlass_src
target_compile_definitions(cutlass_src_pre_hopper
PRIVATE COMPILE_HOPPER_MIXED_INPUT_GEMMS)
target_compile_definitions(cutlass_src_hopper
PRIVATE COMPILE_HOPPER_MIXED_INPUT_GEMMS)
endif()
@ -87,5 +101,9 @@ endif()
# compilation output.
if(NOT WIN32)
target_compile_options(
cutlass_src PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-psabi>)
cutlass_src_pre_hopper
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-psabi>)
target_compile_options(
cutlass_src_hopper
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-psabi>)
endif()

View File

@ -52,6 +52,8 @@ TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
{
switch (tile_config)
{
case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: return TileShape{16, 128};
case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: return TileShape{16, 256};
case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: return TileShape{32, 128};
case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: return TileShape{64, 64};
case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
@ -145,7 +147,9 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
case CutlassGemmType::WeightOnly:
if (sm >= 75)
{
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64,
CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64,
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
}

View File

@ -127,9 +127,8 @@ LayoutDetails getLayoutDetailsForArch(QuantType quant_type)
return details;
}
LayoutDetails getLayoutDetailsForTransform(QuantType quant_type)
LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
{
const int arch = getSMVersion();
if (arch >= 70 && arch < 75)
{
return getLayoutDetailsForArch<cutlass::arch::Sm70>(quant_type);
@ -534,10 +533,15 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const
}
void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight,
const std::vector<size_t>& shape, QuantType quant_type)
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave)
{
const int arch = getSMVersion();
LayoutDetails details = getLayoutDetailsForTransform(quant_type);
int arch = getSMVersion();
if (force_interleave && arch == 90)
{
// Workaround for MOE which doesn't have specialised Hopper kernels yet
arch = 80;
}
LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch);
TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
@ -616,7 +620,8 @@ Outputs
template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight,
ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type)
ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type,
bool force_interleave)
{
TLLM_CHECK_WITH_INFO(processed_quantized_weight, "Processed quantized tensor is NULL");
@ -719,48 +724,52 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
}
}
preprocess_weights_for_mixed_gemm(processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type);
preprocess_weights_for_mixed_gemm(
processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, force_interleave);
}
template void symmetric_quantize<half, float>(
int8_t*, int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);
int8_t*, int8_t*, half*, const float*, const std::vector<size_t>&, QuantType, bool);
template void symmetric_quantize<half, half>(
int8_t*, int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);
int8_t*, int8_t*, half*, const half*, const std::vector<size_t>&, QuantType, bool);
#ifdef ENABLE_BF16
template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(
int8_t*, int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType);
int8_t*, int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool);
template void symmetric_quantize<__nv_bfloat16, float>(
int8_t*, int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType);
int8_t*, int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType, bool);
#endif
template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr,
const std::vector<size_t>& shape, QuantType quant_type)
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave)
{
symmetric_quantize(processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type);
symmetric_quantize(
processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type, force_interleave);
}
template void symmetric_quantize<float, float>(int8_t*, float*, const float*, const std::vector<size_t>&, QuantType);
template void symmetric_quantize<float, float>(
int8_t*, float*, const float*, const std::vector<size_t>&, QuantType, bool);
template void symmetric_quantize<half, float>(int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);
template void symmetric_quantize<half, float>(
int8_t*, half*, const float*, const std::vector<size_t>&, QuantType, bool);
template void symmetric_quantize<half, half>(int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);
template void symmetric_quantize<half, half>(int8_t*, half*, const half*, const std::vector<size_t>&, QuantType, bool);
#ifdef ENABLE_BF16
template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(
int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType);
int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool);
template void symmetric_quantize<__nv_bfloat16, half>(
int8_t*, __nv_bfloat16*, const half*, const std::vector<size_t>&, QuantType);
int8_t*, __nv_bfloat16*, const half*, const std::vector<size_t>&, QuantType, bool);
template void symmetric_quantize<half, __nv_bfloat16>(
int8_t*, half*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType);
int8_t*, half*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool);
template void symmetric_quantize<__nv_bfloat16, float>(
int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType);
int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType, bool);
#endif
} // namespace cutlass_kernels

View File

@ -47,17 +47,18 @@ void subbyte_transpose(int8_t* transposed_quantized_tensor, const int8_t* quanti
void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type);
void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight,
const std::vector<size_t>& shape, QuantType quant_type);
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave = false);
template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr,
const std::vector<size_t>& shape, QuantType quant_type);
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave);
// This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight
// to implement a simple reference implementation.
template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight,
ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type);
ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type,
bool force_interleave);
} // namespace cutlass_kernels
} // namespace kernels
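Both symmetric_quantize overloads now forward the new force_interleave flag into preprocess_weights_for_mixed_gemm; per the .cpp hunk above, its only effect today is an arch override so MoE weights on Hopper reuse the sm80 layout. A host-side restatement of that override, for illustration:

// Restatement of the workaround introduced in preprocess_weights_for_mixed_gemm:
// with force_interleave set (the MoE path) on SM 90, the weight layout is computed
// as if the target were SM 80, since no specialised Hopper kernels exist yet.
int effectiveArchForLayout(int smVersion, bool forceInterleave)
{
    if (forceInterleave && smVersion == 90)
    {
        return 80; // treat Hopper like Ampere for the interleaved layout
    }
    return smVersion;
}

// effectiveArchForLayout(90, /*forceInterleave=*/true)  -> 80
// effectiveArchForLayout(90, /*forceInterleave=*/false) -> 90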

View File

@ -84,7 +84,7 @@ public:
protected:
static constexpr int SPLIT_K_LIMIT = 7;
static constexpr int MIN_M_TILE = 32;
static constexpr int MIN_M_TILE = 16;
static constexpr int MIN_N_TILE = 64;
};

View File

@ -324,6 +324,26 @@ void dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, cons
// best for mixed type gemms.
switch (gemm_config.tile_config)
{
case tkc::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<16, 128, 64>, cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales,
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
stream, occupancy);
}
break;
case tkc::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<16, 256, 64>, cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales,
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
stream, occupancy);
}
break;
case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<32, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales,
@ -337,11 +357,8 @@ void dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, cons
stream, occupancy);
break;
case tkc::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
if (arch::kMinComputeCapability < 75)
{
TLLM_CHECK_WITH_INFO(false, "Invalid config on Volta");
}
else
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales,
@ -433,41 +450,6 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
}
}
// Disabled since the fused GEMM, activation kernels will not be used in v1.
// template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
// void CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::gemm_bias_act(const T* A, const WeightType* B, const T*
// weight_scales,
// const T* biases, T* C, int m, int n, int k, ActivationType activation_type, char* workspace_ptr,
// const size_t workspace_bytes, cudaStream_t stream)
// {
// TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
// switch (activation_type)
// {
// case ActivationType::Relu:
// run_gemm<tkc::EpilogueOpBiasReLU>(
// A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, stream);
// break;
// case ActivationType::Gelu:
// run_gemm<tkc::EpilogueOpBiasFtGelu>(
// A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, stream);
// break;
// case ActivationType::Silu:
// run_gemm<tkc::EpilogueOpBiasSilu>(
// A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, stream);
// break;
// case ActivationType::Identity:
// run_gemm<tkc::EpilogueOpBias>(A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes,
// stream); break;
// case ActivationType::InvalidType: TLLM_CHECK_WITH_INFO(false, "Activation type for fpA_intB must be
// valid."); break; default:
// {
// TLLM_CHECK_WITH_INFO(false, "Invalid activation type.");
// }
// }
// }
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(

View File

@ -231,6 +231,24 @@ void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_s
{
switch (gemm_config.tile_config)
{
case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert,
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
}
break;
case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert,
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
}
break;
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
@ -266,6 +284,24 @@ void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_s
{
switch (gemm_config.tile_config)
{
case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert,
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
}
break;
case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert,
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
}
break;
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,

View File

@ -53,6 +53,7 @@ struct XQAParams
// almost copy from GPTAttentionPluginCommon.
// maybe use one struct for parameters in GPTAttentionPluginCommon and share the same here.
int32_t generation_input_length;
int32_t layer_idx = 0;
int32_t num_q_heads = 0;
int32_t num_kv_heads = 0;
int32_t head_size = 0;

View File

@ -121,7 +121,7 @@ __global__ void gatherTree(gatherTreeParam param)
template <typename T>
__device__ __forceinline__ T applyLengthPenalty(T logProb, int length, float lengthPenalty)
{
// score = log(prob) / (length)^lengthPenalty.
// score = log(prob) / (length ^ lengthPenalty)
if (lengthPenalty == 0.0f || length == 1)
{
return logProb;

View File

@ -46,7 +46,8 @@ struct gatherTreeParam
int32_t* outputIds = nullptr; // the buffer to put finalized ids
cudaStream_t stream;
float* cumLogProbs = nullptr; // [batchSize, beamWidth]
float lengthPenalty = 1.0f; // on cpu
float lengthPenalty = 1.0f;
int earlyStopping = 1;
};
/*

View File

@ -62,7 +62,7 @@ void groupedGemm_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vect
std::vector<void*> ptrC, std::vector<void*> ptrD, void* gemmParamsWorkSpace, int64_t gemmParamsWorkSpaceSize,
void* gemmWorkSpace, int64_t gemmWorkspaceSize, nvinfer1::DataType dataType, cudaStream_t stream)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
using ElementA = cutlassType;
using ElementB = cutlassType;
using ElementOutput = cutlassType;
@ -167,7 +167,7 @@ void groupedGemm_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vect
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to run CUTLASS Grouped GEMM kernel.");
std::free(host_workspace);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <int M1, int N1, int K1, int M2, int N2, int K2>

View File

@ -25,27 +25,20 @@ namespace kernels
{
template <typename T, int MAX_K>
void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const FinishedState* finished,
const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, int** output_ids_ptr,
void* temp_storage, const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size,
const int beam_width, const int vocab_size, const int* end_ids, const float* diversity_rates,
const float* length_penalties, cudaStream_t stream);
void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs,
void* temp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream);
#define CASE_K(MAX_K) \
topK_softMax_kernelLauncher<T, MAX_K>(log_probs, bias, finished, sequence_lengths, cum_log_probs, \
output_log_probs, output_ids_ptr, temp_storage, temp_storage_size, beam_hyps, batch_size, beam_width, \
vocab_size, end_ids, diversity_rates, length_penalties, stream); \
topK_softMax_kernelLauncher<T, MAX_K>( \
log_probs, bias, finished, cum_log_probs, temp_storage, temp_storage_size, beam_hyps, stream); \
break;
template <typename T>
void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* finished, const int* sequence_lengths,
float* cum_log_probs, float* output_log_probs, int** output_ids_ptr, void* temp_storage,
const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size, const int beam_width,
const int vocab_size, const int* end_ids, const float* diversity_rates, const float* length_penalties,
cudaStream_t stream)
void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs,
void* temp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream)
{
int log_beam_width(0);
int recursor(beam_width - 1);
int recursor(beam_hyps.beam_width - 1);
while (recursor >>= 1)
++log_beam_width;
@ -66,23 +59,19 @@ void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* f
CASE_K(64)
#endif // FAST_BUILD
default:
throw std::runtime_error(
fmtstr("%s:%d Topk kernel of beam search does not support beam_width=%d", __FILE__, __LINE__, beam_width));
throw std::runtime_error(fmtstr("%s:%d Topk kernel of beam search does not support beam_width=%d", __FILE__,
__LINE__, beam_hyps.beam_width));
}
}
#undef CASE_K
template void invokeTopkSoftMax<float>(const float* log_probs, const float* bias, const FinishedState* finished,
const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, int** output_ids_ptr, void* tmp_storage,
const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size, const int beam_width,
const int vocab_size, const int* end_ids, const float* diversity_rates, const float* length_penalties,
float* cum_log_probs, void* tmp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps,
cudaStream_t stream);
template void invokeTopkSoftMax<half>(const half* log_probs, const half* bias, const FinishedState* finished,
const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, int** output_ids_ptr, void* tmp_storage,
const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size, const int beam_width,
const int vocab_size, const int* end_ids, const float* diversity_rates, const float* length_penalties,
float* cum_log_probs, void* tmp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps,
cudaStream_t stream);
} // namespace kernels

View File

@ -24,10 +24,8 @@ namespace kernels
{
template <typename T>
void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* finished, const int* sequence_lengths,
float* cum_log_probs, float* output_log_probs, int** output_ids_ptr, void* tmp_storage, const int temp_storage_size,
BeamHypotheses* beam_hyps, const int batch_size, const int beam_width, const int vocab_size, const int* end_ids,
const float* diversity_rates, const float* length_penalties, cudaStream_t stream);
void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs,
void* tmp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -44,7 +44,7 @@ static const int SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE = 256;
template <typename T>
__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty)
{
// score = log(prob) / (length)^length_penalty.
// score = log(prob) / (length ^ length_penalty).
if (length_penalty == 0.0f || length == 1)
{
return log_prob;
@ -56,9 +56,10 @@ template <typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__
void batch_topK_kernel(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf)
{
int thread_id = threadIdx.x;
int block_id = blockIdx.x;
const int thread_id = threadIdx.x;
const int block_id = blockIdx.x;
TopK<T, MAX_K> partial;
if (thread_id == 0)
{
for (int i = 0; i < MAX_K; ++i)
@ -70,7 +71,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
int index = block_id * MAX_K * MAX_K;
for (int i = 0; i < MAX_K * MAX_K; i++)
{
partial.insert((T) topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]);
partial.insert(topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]);
}
index = block_id * MAX_K;
@ -85,9 +86,10 @@ template <typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int* __restrict topk_tmp_id_buf,
const T* __restrict topk_tmp_val_buf, int* __restrict id_buf, T* __restrict val_buf)
{
int thread_id = threadIdx.x;
int block_id = blockIdx.x;
const int thread_id = threadIdx.x;
const int block_id = blockIdx.x;
TopK<T, MAX_K> partial;
if (thread_id == 0)
{
for (int i = 0; i < MAX_K; ++i)
@ -99,7 +101,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int*
int index = block_id * MAX_K * MAX_K;
for (int i = 0; i < MAX_K * MAX_K; i++)
{
partial.insert((T) topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]);
partial.insert(topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]);
}
index = block_id * MAX_K;
@ -112,34 +114,37 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int*
}
template <typename T, int MAX_K2, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__
void batch_topk_kernel(const int* __restrict x, const T* __restrict y, int** output_ids_ptr, float* __restrict v,
float* output_log_probs, const FinishedState* finished, const int* sequence_lengths, BeamHypotheses beam_hyps,
const int V, const int K, const int vocab_size, const float* length_penalties, const float* diversity_rates)
__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* __restrict topk_tmp_id_buf,
const T* __restrict topk_tmp_val_buf, float* __restrict cum_log_probs, const FinishedState* finished,
BeamHypotheses beam_hyps, const int candidate_size)
{
int thread_id = threadIdx.x;
int vector_id = blockIdx.x;
const int thread_id = threadIdx.x;
const int vector_id = blockIdx.x;
const int K{beam_hyps.beam_width};
const int vocab_size{beam_hyps.vocab_size};
const int global_batch_idx{beam_hyps.ite * beam_hyps.local_batch_size + vector_id};
const T diversity_rate{diversity_rates[global_batch_idx]};
const float length_penalty{length_penalties[global_batch_idx]};
const T MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
// reposition x, y to data for the current vector
x += vector_id * V;
y += vector_id * V;
extern __shared__ char buf_s_[]; // intermediate result
T* buf_s = reinterpret_cast<T*>(buf_s_);
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
const float length_penalty{beam_hyps.length_penalties[global_batch_idx]};
const int early_stopping{beam_hyps.early_stoppings[global_batch_idx]};
const int* sequence_lengths{beam_hyps.sequence_lengths_src};
const T diversity_rate{beam_hyps.diversity_rates[global_batch_idx]};
float* output_log_probs{beam_hyps.log_probs_src};
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduce = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>;
extern __shared__ char buf_s_[]; // intermediate result
T* buf_s = reinterpret_cast<T*>(buf_s_);
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ int selected_beams;
__shared__ float old_cum_log_probs[MAX_K2];
__shared__ char cta_topk_store[MAX_K2 * sizeof(cub_kvp)];
auto* cta_topk = reinterpret_cast<cub_kvp*>(cta_topk_store);
__shared__ cub_kvp cta_topk[MAX_K2];
__shared__ int selected_beams;
__shared__ int thread_requiring_update;
// reposition topk_tmp_id_buf, topk_tmp_val_buf to data for the current vector
topk_tmp_id_buf += vector_id * candidate_size;
topk_tmp_val_buf += vector_id * candidate_size;
if (thread_id == 0)
{
@ -147,47 +152,56 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
}
if (thread_id < K)
{
old_cum_log_probs[thread_id] = v[vector_id * K + thread_id];
old_cum_log_probs[thread_id] = cum_log_probs[vector_id * K + thread_id];
}
__syncthreads();
if (beam_hyps.num_beams != nullptr)
{
// initialize worst_score if this batch has no finished beam
if (beam_hyps.num_beams[global_batch_idx] == 0 && thread_id == 0)
{
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
}
// return if this batch has enough finished beams
else if (beam_hyps.num_beams[global_batch_idx] == K)
{
return;
}
}
// Get top 2K tokens from candidates
cub::ArgMax arg_max;
cub_kvp partial_topk{V - 1, -MAX_T_VAL};
cub_kvp partial_topk{candidate_size - 1, -MAX_T_VAL};
for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE)
for (int elem_id = thread_id; elem_id < candidate_size; elem_id += THREADBLOCK_SIZE)
{
int i = beam_hyps.num_beams == nullptr ? elem_id % K : elem_id / 2 / K;
T elem = length_penalty == 0.0f
? y[elem_id]
: apply_length_penalty(y[elem_id],
finished[vector_id * K + i].isFinished() ? sequence_lengths[vector_id * K + i]
: sequence_lengths[vector_id * K + i] + 1,
length_penalty);
T elem = topk_tmp_val_buf[elem_id];
if (length_penalty > 0.0f)
{
int length = sequence_lengths[vector_id * K + i];
if (early_stopping == 0)
{
// Use generated_length (rather than sequence_length) to compute length_penalty
// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L957
// But this branch will cause CI errors in
// "C++ Tests (GPT) on A30", "C++ Tests (GPT-J) on H100_PCIe", "H100_PCIe-accuracy-0"
length -= beam_hyps.input_lengths[global_batch_idx];
}
const int pad_if_not_finish = finished[vector_id * K + i].isFinished() ? 0 : 1;
elem = apply_length_penalty(elem, length + pad_if_not_finish, length_penalty);
}
elem += diversity_rate * (T) i;
int elem_idx = elem_id; // x[elem_id];
cub_kvp new_elem{elem_idx, elem};
cub_kvp new_elem{elem_id, elem};
partial_topk = arg_max(partial_topk, new_elem);
buf_s[elem_id] = elem;
}
__syncthreads();
__shared__ int thread_requiring_update;
for (int i = 0; i < 2 * K; ++i)
{
cub_kvp total_topk = BlockReduce(temp_storage).Reduce(partial_topk, arg_max);
if (threadIdx.x == 0)
{
cta_topk[i] = total_topk;
@ -196,13 +210,13 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
}
__syncthreads();
// Only 1 thread needs to update the old partial before the next block reduce. We don't need to do this update
// on the last iteration.
// Only one thread needs to update the old partial before the next block reduce.
// No need to do this in the last iteration.
if (thread_id == thread_requiring_update && i < (2 * K - 1))
{
partial_topk.key = V - 1;
partial_topk.key = candidate_size - 1;
partial_topk.value = -MAX_T_VAL;
for (int tid = thread_id; tid < V; tid += THREADBLOCK_SIZE)
for (int tid = thread_id; tid < candidate_size; tid += THREADBLOCK_SIZE)
{
cub_kvp new_elem{tid, buf_s[tid]};
partial_topk = arg_max(partial_topk, new_elem);
@ -212,104 +226,101 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
if (thread_id == 0)
{
v += vector_id * K;
cum_log_probs += vector_id * K;
for (int i = 0; i < 2 * K; ++i)
{
const int current_key = cta_topk[i].key;
const T current_value = cta_topk[i].value;
if (i < K && beam_hyps.num_beams != nullptr && x[current_key] % vocab_size == beam_hyps.end_ids[vector_id])
if (i < K && beam_hyps.num_beams != nullptr
&& topk_tmp_id_buf[current_key] % vocab_size == beam_hyps.end_ids[vector_id])
{
// if beam_token does not belong to top num_beams tokens, it should not
// be added. Refer from
// https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/generation_beam_search.py#L257
{
const float normed_score = (float) current_value;
const int num_beam = beam_hyps.num_beams[global_batch_idx];
int beam_idx = num_beam;
// If there are beam_width finished sentences, check that the score of
// selected candidatet is higher than min_normed_score or not. If
// current score is better, replace worst one and update the
// min_normed_score.
if (num_beam == K)
{
if (normed_score < beam_hyps.min_normed_scores[global_batch_idx])
{
// end the tracing and exist this for loop
selected_beams = K;
break;
}
else
{
// find the beam index which's score = min_normed_score, erase it.
for (int j = 0; j < K; j++)
{
if (beam_hyps.normed_scores[global_batch_idx * (K * 2) + j]
== beam_hyps.min_normed_scores[global_batch_idx])
{
beam_idx = j;
beam_hyps.num_beams[global_batch_idx]--;
// Add beam only if beam_token belongs to top K tokens
// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L272
const float normed_score = (float) current_value;
const int num_beam = beam_hyps.num_beams[global_batch_idx];
int beam_idx = num_beam;
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
beam_hyps.normed_scores[global_batch_idx * (K * 2) + j] = normed_score;
for (int l = 0; l < K; l++)
{
beam_hyps.min_normed_scores[global_batch_idx]
= min(beam_hyps.min_normed_scores[global_batch_idx],
beam_hyps.normed_scores[global_batch_idx * (K * 2) + l]);
}
break;
// There are already K beams
if (num_beam == K)
{
// The current score is worse than the worst one in beams
if (normed_score < beam_hyps.min_normed_scores[global_batch_idx])
{
selected_beams = K;
break;
}
// The current score is better than the worst one in beams
else
{
// Find the beam index whose score == min_normed_score and erase it.
for (int j = 0; j < K; j++)
{
if (beam_hyps.normed_scores[global_batch_idx * (K * 2) + j]
== beam_hyps.min_normed_scores[global_batch_idx])
{
beam_idx = j;
beam_hyps.num_beams[global_batch_idx]--;
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
beam_hyps.normed_scores[global_batch_idx * (K * 2) + j] = normed_score;
for (int l = 0; l < K; l++)
{
beam_hyps.min_normed_scores[global_batch_idx]
= min(beam_hyps.min_normed_scores[global_batch_idx],
beam_hyps.normed_scores[global_batch_idx * (K * 2) + l]);
}
break;
}
}
}
const int tgt_id_offset
= ((vector_id + beam_hyps.ite * beam_hyps.local_batch_size) * (K * 2) + beam_idx)
* (beam_hyps.max_seq_len);
int prev_id = (x[current_key] / vocab_size) % K;
const int current_step{sequence_lengths[vector_id * K + prev_id]};
beam_hyps.output_ids_tgt[tgt_id_offset + current_step] = beam_hyps.end_ids[vector_id];
if (beam_hyps.log_probs != nullptr)
{
beam_hyps.log_probs[tgt_id_offset + current_step]
= (float) y[current_key] - old_cum_log_probs[(x[current_key] / vocab_size) % K];
}
for (int j = current_step - 1; j >= 0; j--)
{
const int src_idx = j * beam_hyps.batch_size * K
+ beam_hyps.ite * beam_hyps.local_batch_size * K + vector_id * K + prev_id;
beam_hyps.output_ids_tgt[tgt_id_offset + j]
= beam_hyps.output_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j];
if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr)
{
beam_hyps.log_probs[tgt_id_offset + j] = beam_hyps.log_probs_src[src_idx];
}
prev_id = beam_hyps.parent_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j];
}
const int tgt_beam_idx = global_batch_idx * (K * 2) + beam_idx;
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = current_step;
beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
beam_hyps.min_normed_scores[global_batch_idx]
= min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]);
beam_hyps.num_beams[global_batch_idx]++;
beam_hyps.cum_log_probs[tgt_beam_idx] = (float) y[current_key];
}
const int tgt_id_offset
= ((vector_id + beam_hyps.ite * beam_hyps.local_batch_size) * (K * 2) + beam_idx)
* (beam_hyps.max_seq_len);
int prev_id = (topk_tmp_id_buf[current_key] / vocab_size) % K;
const int current_step{sequence_lengths[vector_id * K + prev_id]};
beam_hyps.output_ids_tgt[tgt_id_offset + current_step] = beam_hyps.end_ids[vector_id];
if (beam_hyps.log_probs != nullptr)
{
beam_hyps.log_probs[tgt_id_offset + current_step] = (float) topk_tmp_val_buf[current_key]
- old_cum_log_probs[(topk_tmp_id_buf[current_key] / vocab_size) % K];
}
for (int j = current_step - 1; j >= 0; j--)
{
const int src_idx = j * beam_hyps.batch_size * K + beam_hyps.ite * beam_hyps.local_batch_size * K
+ vector_id * K + prev_id;
beam_hyps.output_ids_tgt[tgt_id_offset + j]
= beam_hyps.output_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j];
if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr)
{
beam_hyps.log_probs[tgt_id_offset + j] = beam_hyps.log_probs_src[src_idx];
}
prev_id = beam_hyps.parent_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j];
}
const int tgt_beam_idx = global_batch_idx * (K * 2) + beam_idx;
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = current_step;
beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
beam_hyps.min_normed_scores[global_batch_idx]
= min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]);
beam_hyps.num_beams[global_batch_idx]++;
cum_log_probs[tgt_beam_idx] = (float) topk_tmp_val_buf[current_key];
}
else if ((beam_hyps.num_beams != nullptr && i < 2 * K) || (beam_hyps.num_beams == nullptr && i < K))
else if (beam_hyps.num_beams != nullptr || beam_hyps.num_beams == nullptr && i < K)
{
const int current_step{sequence_lengths[vector_id * K + selected_beams]};
output_ids_ptr[vector_id][selected_beams * beam_hyps.max_seq_len + current_step] = x[current_key];
beam_hyps.output_ids_tgt_ptr[vector_id][selected_beams * beam_hyps.max_seq_len + current_step]
= topk_tmp_id_buf[current_key];
if (output_log_probs != nullptr)
{
output_log_probs[current_step * beam_hyps.batch_size * K + vector_id * K + selected_beams]
= (float) y[current_key] - old_cum_log_probs[(x[current_key] / vocab_size) % K];
= (float) topk_tmp_val_buf[current_key]
- old_cum_log_probs[(topk_tmp_id_buf[current_key] / vocab_size) % K];
}
v[selected_beams] = (float) y[current_key];
cum_log_probs[selected_beams] = (float) topk_tmp_val_buf[current_key];
selected_beams++;
}
__syncthreads();
@ -319,15 +330,46 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
}
}
}
// update beam_hyps.is_done for each batch
if (threadIdx.x == 0 && beam_hyps.num_beams != nullptr)
{
// not enough finished beams yet
if (beam_hyps.num_beams[blockIdx.x] < K)
{
beam_hyps.is_done[blockIdx.x] = false;
return;
}
else if (beam_hyps.early_stopping)
float highest_attainable_score = 0.0f;
switch (early_stopping)
{
case 1:
// enough beams with early stopping
beam_hyps.is_done[blockIdx.x] = true;
return;
case 0:
// enough beams without early stopping
highest_attainable_score = static_cast<float>(apply_length_penalty(cum_log_probs[0],
sequence_lengths[vector_id * K] - beam_hyps.input_lengths[global_batch_idx], length_penalty));
beam_hyps.is_done[blockIdx.x] = beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score;
return;
default:
// early_stopping == "never" in HF, i.e., compute the best possible score depending on `length_penalty`
// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L990
if (length_penalty > 0.0f)
{
highest_attainable_score = static_cast<float>(apply_length_penalty(cum_log_probs[0],
beam_hyps.max_seq_len - beam_hyps.input_lengths[global_batch_idx], length_penalty));
beam_hyps.is_done[blockIdx.x]
= beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score;
}
else
{
highest_attainable_score = static_cast<float>(apply_length_penalty(cum_log_probs[0],
sequence_lengths[vector_id * K] - beam_hyps.input_lengths[global_batch_idx], length_penalty));
beam_hyps.is_done[blockIdx.x]
= beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score;
}
return;
}
}
}
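The new switch mirrors the three early-stopping modes of Hugging Face beam search: 1 stops as soon as beam_width hypotheses are collected, 0 compares the worst collected score against the best score attainable at the current generated length, and any other value ("never") bounds the attainable score by the maximum sequence length when length_penalty > 0. A host-side restatement of the per-batch decision, for illustration (lengths here exclude the prompt, as in the kernel):

#include <cmath>

static float normedScoreRef(float logProb, int length, float lengthPenalty)
{
    return (lengthPenalty == 0.0f || length == 1)
        ? logProb
        : logProb / std::pow(static_cast<float>(length), lengthPenalty);
}

// bestCumLogProb     : cum_log_probs of the first (best) beam for this batch entry
// worstFinishedScore : min_normed_scores for this batch entry
// generatedLength    : sequence_length - input_length of that beam
// maxGeneratedLength : max_seq_len - input_length
bool beamSearchDone(int numFinishedBeams, int beamWidth, int earlyStopping,
    float bestCumLogProb, float worstFinishedScore,
    int generatedLength, int maxGeneratedLength, float lengthPenalty)
{
    if (numFinishedBeams < beamWidth)
    {
        return false; // not enough finished hypotheses yet
    }
    if (earlyStopping == 1)
    {
        return true; // stop as soon as beamWidth hypotheses are collected
    }
    // earlyStopping == 0, or "never" with a non-positive length penalty, uses the
    // current length; "never" with a positive penalty uses the longest length the
    // hypothesis could still reach.
    int const lengthBound = (earlyStopping != 0 && lengthPenalty > 0.0f)
        ? maxGeneratedLength
        : generatedLength;
    float const highestAttainable = normedScoreRef(bestCumLogProb, lengthBound, lengthPenalty);
    return worstFinishedScore >= highestAttainable;
}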
@ -366,24 +408,22 @@ __device__ __forceinline__ TopKMD<T, MAX_K> reduce_topk_md_op(const TopKMD<T, MA
}
template <typename T, int ITEMS_PER_THREAD, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_kernel(const T* __restrict x,
const T* __restrict b, const float* __restrict c, const FinishedState* __restrict finished, int* __restrict z,
T* __restrict v, int V, int K, const int* __restrict end_ids)
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_kernel(const T* __restrict log_probs,
const T* __restrict bias, const float* __restrict cum_log_probs, const FinishedState* __restrict finished,
int* __restrict topk_tmp_id_buf, T* __restrict topk_tmp_val_buf, int vocab_size, int K,
const int* __restrict end_ids)
{
int thread_id = threadIdx.x;
int vector_id = blockIdx.x;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
// reposition y to data for the current vector
x += vector_id * V;
const int thread_id = threadIdx.x;
const int vector_id = blockIdx.x;
const T MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
typedef cub::BlockReduce<TopKMD<float, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// reposition log_probs to data for the current vector
log_probs += vector_id * vocab_size;
TopKMD<float, MAX_K> partial;
bool finish = finished[vector_id].isFinished();
for (int i = 0; i < MAX_K; ++i)
{
partial.topk.p[i] = -1;
@ -392,22 +432,22 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker
partial.md.m = -MAX_T_VAL;
partial.md.d = 0.0F;
if (finish)
if (finished[vector_id].isFinished())
{
for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE)
for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE)
{
float elem = (elem_id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL;
MD new_elem{elem, 1.0F};
partial.md = reduce_md_op(partial.md, new_elem);
partial.topk.insert(elem, elem_id);
// if (elem_id > THREADBLOCK_SIZE * MAX_K && (elem_id == E)) break;
// if (elem_id > THREADBLOCK_SIZE * MAX_K && elem_id == E) break;
}
}
else
{
for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE)
for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE)
{
float elem = x[elem_id] + b[elem_id];
float elem = log_probs[elem_id] + bias[elem_id];
MD new_elem{elem, 1.0F};
partial.md = reduce_md_op(partial.md, new_elem);
partial.topk.insert(elem, elem_id);
@ -418,9 +458,9 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker
if (thread_id == 0)
{
z += vector_id * K;
v += vector_id * K;
c += vector_id;
topk_tmp_id_buf += vector_id * K;
topk_tmp_val_buf += vector_id * K;
cum_log_probs += vector_id;
// float d_total_inverse = __fdividef(1.0F, total.md.d);
float d_total_log = logf(total.md.d);
@ -430,48 +470,43 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker
float val = total.topk.u[i] - total.md.m - d_total_log;
if (i < K)
{
z[i] = total.topk.p[i] + vector_id * V; // trtllm needs absolute id
v[i] = val + c[0];
topk_tmp_id_buf[i] = total.topk.p[i] + vector_id * vocab_size; // trtllm needs absolute id
topk_tmp_val_buf[i] = val + cum_log_probs[0];
}
}
}
}
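The (m, d) pair carried through the block reduction is the running maximum and the running sum of exp(x - m), so the value written out is the numerically stable log-softmax, log p_i = x_i - m - log(sum_j exp(x_j - m)), added to the beam's accumulated log-probability. A scalar restatement, for illustration only:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar version of the value stored into topk_tmp_val_buf above.
float candidateScore(std::vector<float> const& logits, std::size_t i, float cumLogProb)
{
    float const m = *std::max_element(logits.begin(), logits.end());
    float d = 0.0f;
    for (float x : logits)
    {
        d += std::exp(x - m); // what the MD part of the reduction accumulates
    }
    return (logits[i] - m - std::log(d)) + cumLogProb; // log-softmax + running beam score
}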
template <typename T, int ITEMS_PER_THREAD, int MAX_K2, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__
void beam_online_softmax_topk_stage1_kernel_base(const T* __restrict x, const T* __restrict b,
const FinishedState* __restrict finished, float* __restrict t, int V, int K, const int* __restrict end_ids)
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_stage1_kernel_base(
const T* __restrict log_probs, const T* __restrict bias, const FinishedState* __restrict finished,
float* __restrict tmp_buffer, int vocab_size, int K, const int* __restrict end_ids)
{
int thread_id = threadIdx.x;
int vector_id = blockIdx.x; // batch beam index.
const int thread_id = threadIdx.x;
const int vector_id = blockIdx.x;
const T MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
const int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
// one will have multiple sections per V
const int v_local = (V + gridDim.y - 1) / gridDim.y;
// the vocabulary is split into gridDim.y sections; each threadblock handles one section
const int v_local = (vocab_size + gridDim.y - 1) / gridDim.y;
const int section_start = v_local * blockIdx.y;
int section_end = section_start + v_local;
section_end = (section_end > V) ? V : section_end;
const int section_end = std::min(section_start + v_local, vocab_size);
// reposition x to data for the current vector
x += vector_id * V;
#if TOPK_FP16_STORAGE == 1
typedef cub::BlockReduce<TopKMD<__half, MAX_K2>, THREADBLOCK_SIZE> BlockReduce;
#else
typedef cub::BlockReduce<TopKMD<T, MAX_K2>, THREADBLOCK_SIZE> BlockReduce;
#endif
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float buf_s[PACKED_TOP_KMD_SIZE]; // save intermediate result
__shared__ float buf_s[PACKED_TOP_KMD_SIZE];
// reposition log_probs to data for the current vector
log_probs += vector_id * vocab_size;
#if TOPK_FP16_STORAGE == 1
TopKMD<__half, MAX_K2> partial;
#else
TopKMD<T, MAX_K2> partial;
#endif
bool finish = finished[vector_id].isFinished();
for (int i = 0; i < MAX_K2; ++i)
{
partial.topk.p[i] = -1;
@ -480,7 +515,7 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__
partial.md.m = -MAX_T_VAL;
partial.md.d = 0.0F;
if (finish)
if (finished[vector_id].isFinished())
{
#pragma unroll 1
for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE)
@ -496,8 +531,8 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__
#pragma unroll 1
for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE)
{
T bias = b == nullptr ? (T) 0.0f : b[elem_id]; // gpt-2 does not use bias
T elem = x[elem_id] + bias;
T b = bias == nullptr ? (T) 0.0f : bias[elem_id];
T elem = log_probs[elem_id] + b;
MD new_elem{elem, 1.0F};
partial.md = reduce_md_op(partial.md, new_elem);
partial.topk.insert(elem, elem_id);
@ -514,7 +549,7 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__
{
for (int i = 0; i < 2 * K; i++)
{
reinterpret_cast<int*>(buf_s)[i] = total.topk.p[i] + vector_id * V; // trtllm needs absolute id
reinterpret_cast<int*>(buf_s)[i] = total.topk.p[i] + vector_id * vocab_size; // trtllm needs absolute id
buf_s[MAX_K2 + i] = total.topk.u[i];
}
buf_s[2 * MAX_K2] = total.md.d;
@ -523,38 +558,25 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__
__syncthreads();
for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; elem_id += THREADBLOCK_SIZE)
{
t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] = buf_s[elem_id];
tmp_buffer[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id]
= buf_s[elem_id];
}
}
template <typename T, int ITEMS_PER_THREAD, int MAX_K2, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_stage1_kernel_fast(
const T* __restrict x, const T* __restrict b, const FinishedState* __restrict finished, float* __restrict t, int V,
int K, const int* __restrict end_ids, const int v_local)
const T* __restrict log_probs, const T* __restrict bias, const FinishedState* __restrict finished,
float* __restrict t, int vocab_size, int K, const int* __restrict end_ids, const int v_local)
{
extern __shared__ char buf_smem_logprobs_[];
T* buf_smem_logprobs = reinterpret_cast<T*>(buf_smem_logprobs_);
int thread_id = threadIdx.x;
int vector_id = blockIdx.x; // batch beam index.
const int thread_id = threadIdx.x;
const int vector_id = blockIdx.x;
const T MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
const int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
// reposition x to data for the current vector
x += vector_id * V;
// one will have multiple sections per V
// the vocabulary is split into gridDim.y sections; each threadblock handles one section
const int section_start = v_local * blockIdx.y;
int section_end = section_start + v_local;
section_end = (section_end > V) ? V : section_end;
const int section_end = std::min(section_start + v_local, vocab_size);
const int valid_smem_length = section_end - section_start;
bool finish = finished[vector_id].isFinished();
MD partial_md{-MAX_T_VAL, 0.0f};
#if TOPK_FP16_STORAGE == 1
using cub_kvp = cub::KeyValuePair<int, __half>;
using BlockReduceTopK = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>;
@ -565,10 +587,25 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;
cub::ArgMax arg_max;
cub_kvp partial_topk{V - 1, -MAX_T_VAL};
extern __shared__ char buf_smem_logprobs_[];
T* buf_smem_logprobs = reinterpret_cast<T*>(buf_smem_logprobs_);
if (finish)
__shared__ union
{
typename BlockReduceMD::TempStorage md_smem;
typename BlockReduceTopK::TempStorage topk_smem;
} temp_storage;
__shared__ float buf_s[PACKED_TOP_KMD_SIZE];
__shared__ int thread_requiring_update;
// reposition log_probs to data for the current vector
log_probs += vector_id * vocab_size;
cub::ArgMax arg_max;
cub_kvp partial_topk{vocab_size - 1, -MAX_T_VAL};
MD partial_md{-MAX_T_VAL, 0.0f};
if (finished[vector_id].isFinished())
{
#pragma unroll 1
for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE)
@ -589,8 +626,8 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
#pragma unroll 1
for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE)
{
T bias = b == nullptr ? (T) 0.0f : b[elem_id]; // gpt-2 does not use bias
T elem = x[elem_id] + bias;
T b = bias == nullptr ? (T) 0.0f : bias[elem_id];
T elem = log_probs[elem_id] + b;
MD new_elem_md{elem, 1.0F};
partial_md = reduce_md_op(partial_md, new_elem_md);
@ -600,18 +637,8 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
buf_smem_logprobs[smem_index] = elem;
}
}
__syncthreads();
__shared__ union
{
typename BlockReduceMD::TempStorage md_smem;
typename BlockReduceTopK::TempStorage topk_smem;
} temp_storage;
__shared__ float buf_s[PACKED_TOP_KMD_SIZE]; // save intermediate result
__shared__ int thread_requiring_update;
for (int i = 0; i < 2 * K; ++i)
{
cub_kvp total_topk = BlockReduceTopK(temp_storage.topk_smem).Reduce(partial_topk, arg_max);
@ -619,18 +646,18 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
if (threadIdx.x == 0)
{
reinterpret_cast<int*>(buf_s)[i]
= section_start + total_topk.key + vector_id * V; // trtllm needs absolute id
= section_start + total_topk.key + vector_id * vocab_size; // trtllm needs absolute id
buf_s[MAX_K2 + i] = total_topk.value;
buf_smem_logprobs[total_topk.key] = -MAX_T_VAL;
thread_requiring_update = total_topk.key % THREADBLOCK_SIZE;
}
__syncthreads();
// Only 1 thread needs to update the old partial before the next block reduce. We don't need to do this update
// on the last iteration.
// Only one thread needs to update the old partial before the next block reduce.
// No need to do this in the last iteration.
if (thread_id == thread_requiring_update && i < (2 * K - 1))
{
partial_topk.key = V - 1;
partial_topk.key = vocab_size - 1;
partial_topk.value = -MAX_T_VAL;
for (int tid = thread_id; tid < valid_smem_length; tid += THREADBLOCK_SIZE)
{
@ -649,7 +676,6 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
buf_s[2 * MAX_K2 + 1] = total_md.m;
}
__syncthreads();
for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; elem_id += THREADBLOCK_SIZE)
{
t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] = buf_s[elem_id];
@ -657,58 +683,52 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
}
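The extraction loop above pulls out 2*K winners by repeated argmax with shared-memory invalidation: the winning slot in buf_smem_logprobs is overwritten with -MAX_T_VAL so it cannot win again, and only the thread that owned that slot (thread_requiring_update) rebuilds its thread-local candidate before the next block reduce. A serial C++ sketch of the same extract-and-invalidate idea follows; names and values are hypothetical and the block reduce is replaced by a plain loop.
#include <cstdio>
#include <limits>
#include <utility>
#include <vector>
// Serial analogue of the kernel's per-iteration block reduce: extract the k largest
// (value, index) pairs by repeated argmax, invalidating each winner after extraction.
std::vector<std::pair<float, int>> top_k_by_invalidation(std::vector<float> v, int k)
{
    std::vector<std::pair<float, int>> out;
    const float kNegMax = -std::numeric_limits<float>::max();
    for (int i = 0; i < k && i < static_cast<int>(v.size()); ++i)
    {
        int best = 0;
        for (int j = 1; j < static_cast<int>(v.size()); ++j)
        {
            if (v[j] > v[best])
            {
                best = j;
            }
        }
        out.emplace_back(v[best], best);
        v[best] = kNegMax; // invalidate so the next pass finds the next-largest element
    }
    return out;
}
int main()
{
    for (auto const& [val, idx] : top_k_by_invalidation({0.1f, 2.5f, -1.0f, 2.2f, 0.7f}, 3)) // made-up values
    {
        std::printf("idx=%d val=%f\n", idx, val);
    }
    return 0;
}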
template <typename T, int MAX_K2, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_stage2_kernel(const float* __restrict x,
const float* __restrict c, int* __restrict z, T* __restrict v, int K, int parts_per_beam, const int V)
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_stage2_kernel(
const float* __restrict temp_storage, const float* __restrict cum_log_probs, int* __restrict ids,
T* __restrict vals, int K, int parts_per_beam, const int vocab_size)
{
const int vector_id = blockIdx.x;
const int thread_id = threadIdx.x;
const T MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
const int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
extern __shared__ char buf_s_[]; // intermediate result
float* buf_s = reinterpret_cast<float*>(buf_s_);
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduceTopK = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>;
using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;
extern __shared__ char buf_s_[];
float* buf_s = reinterpret_cast<float*>(buf_s_);
__shared__ cub_kvp buf_smem_kv[MAX_K2];
__shared__ union
{
typename BlockReduceTopK::TempStorage topk_smem;
typename BlockReduceMD::TempStorage md_smem;
} temp_storage;
} shared_temp_storage;
temp_storage += vector_id * PACKED_TOP_KMD_SIZE * parts_per_beam;
cub::ArgMax arg_max;
x += vector_id * PACKED_TOP_KMD_SIZE * parts_per_beam;
MD partial_md{-MAX_T_VAL, 0.0f};
cub_kvp total_topk{V - 1, -MAX_T_VAL};
cub_kvp total_topk{vocab_size - 1, -MAX_T_VAL};
__shared__ char buf_smem_kv_store[MAX_K2 * sizeof(cub_kvp)];
auto* buf_smem_kv = reinterpret_cast<cub_kvp*>(buf_smem_kv_store);
// load and unpack into registers through smem
// Load and unpack into registers through smem
for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * parts_per_beam; idx += THREADBLOCK_SIZE)
{
buf_s[idx] = x[idx];
buf_s[idx] = temp_storage[idx];
}
__syncthreads();
// find the argmax within each parts_per_beam,
// find the topK across all parts_per_beam.
// Find the argmax within each of the parts_per_beam chunks,
// then the topK across all chunks
for (int k = 0; k < 2 * K; ++k)
{
cub_kvp partial_topk{V - 1, -MAX_T_VAL};
cub_kvp partial_topk{vocab_size - 1, -MAX_T_VAL};
// Only threads responsible for a chunk will do the computation
if (threadIdx.x < parts_per_beam)
{
float* b_s = buf_s + threadIdx.x * PACKED_TOP_KMD_SIZE;
for (int i = 0; i < K; ++i)
{
int current_index = threadIdx.x * PACKED_TOP_KMD_SIZE + i;
@ -718,21 +738,20 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta
}
}
cub_kvp total_topk = BlockReduceTopK(temp_storage.topk_smem).Reduce(partial_topk, arg_max);
cub_kvp total_topk = BlockReduceTopK(shared_temp_storage.topk_smem).Reduce(partial_topk, arg_max);
__syncthreads();
if (threadIdx.x == 0)
{
// store kv pairs in shared mem buffer
// Store kv pairs in shared mem buffer
int temp_offset = total_topk.key;
int global_offset = reinterpret_cast<int*>(buf_s)[temp_offset];
total_topk.key = global_offset;
buf_smem_kv[k] = total_topk;
// Invalidate the maximum value within the chunk
reinterpret_cast<int*>(buf_s)[temp_offset] = V - 1; // id in share memory
buf_s[temp_offset + MAX_K2] = -MAX_T_VAL; // value in share memory
reinterpret_cast<int*>(buf_s)[temp_offset] = vocab_size - 1; // id in shared memory
buf_s[temp_offset + MAX_K2] = -MAX_T_VAL; // value in shared memory
}
__syncthreads();
}
@ -744,18 +763,16 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta
partial_md.d = b_s[2 * MAX_K2];
partial_md.m = b_s[2 * MAX_K2 + 1];
}
__syncthreads();
auto reduce_md_func = [](const MD& a, const MD& b) { return reduce_md_op(a, b); };
MD total_md = BlockReduceMD(temp_storage.md_smem).Reduce(partial_md, reduce_md_func);
__syncthreads();
MD total_md = BlockReduceMD(shared_temp_storage.md_smem).Reduce(partial_md, reduce_md_func);
if (thread_id == 0)
{
z += vector_id * 2 * K;
v += vector_id * 2 * K;
c += vector_id;
ids += vector_id * 2 * K;
vals += vector_id * 2 * K;
cum_log_probs += vector_id;
float d_total_log = logf(total_md.d);
for (int i = 0; i < MAX_K2; ++i)
@ -763,8 +780,8 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta
float val = (float) buf_smem_kv[i].value - total_md.m - d_total_log;
if (i < 2 * K)
{
z[i] = buf_smem_kv[i].key;
v[i] = (float) val + (float) c[0];
ids[i] = buf_smem_kv[i].key;
vals[i] = (float) val + (float) cum_log_probs[0];
}
}
}
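As the stores above show, each vocabulary chunk packs 2*MAX_K2 + 2 floats: candidate ids (bit-stored as int) in the first slots, their values starting at offset MAX_K2, then the chunk-local softmax denominator d and maximum m. Stage 2 re-argmaxes across the parts_per_beam chunks and merges the (m, d) pairs before scoring candidates as val - m - log(d) + cum_log_prob. The host-side C++ sketch below shows only the (m, d) merge and the final normalization, with made-up sizes and values; the cross-chunk top-2K selection is omitted.
#include <cmath>
#include <cstdio>
#include <vector>
// Hypothetical host-side mirror of one packed chunk result: candidate ids and values
// plus the chunk-local softmax state (d, m). Sizes and values below are made up.
struct ChunkTopK
{
    std::vector<int> ids;
    std::vector<float> vals;
    float d;
    float m;
};
int main()
{
    std::vector<ChunkTopK> chunks = {
        {{3, 7}, {2.0f, 1.5f}, 4.0f, 2.0f},
        {{12, 9}, {2.5f, 0.5f}, 3.0f, 2.5f},
    };
    const float cum_log_prob = -1.0f;
    // Merge the per-chunk (m, d) pairs, the same rule the kernels apply via reduce_md_op.
    float m = -INFINITY, d = 0.0f;
    for (auto const& c : chunks)
    {
        const float new_m = std::fmax(m, c.m);
        d = d * std::exp(m - new_m) + c.d * std::exp(c.m - new_m);
        m = new_m;
    }
    // Each surviving candidate is then normalized against the merged (m, d).
    for (auto const& c : chunks)
    {
        for (int i = 0; i < static_cast<int>(c.ids.size()); ++i)
        {
            std::printf("id=%d score=%f\n", c.ids[i], c.vals[i] - m - std::log(d) + cum_log_prob);
        }
    }
    return 0;
}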
@ -774,9 +791,9 @@ template <typename T, int MAX_K2>
void beam_online_softmax_topk_stage2_kernelLauncher(const float* temp_storage, const float* cum_log_probs, int* ids,
T* vals, int batch_size, int beam_width, int parts_per_beam, cudaStream_t stream, const int vocab_size)
{
// might rewrite beam_online_softmax_topk_stage2_kernel no to depend on
// constant block size in oreder to reduce compilation time
int smem_stage2_size = parts_per_beam * (2 * MAX_K2 + 2) * sizeof(float);
// TODO: rewrite beam_online_softmax_topk_stage2_kernel to remove the dependence
// on a constant block size in order to reduce compilation time
const int smem_stage2_size = parts_per_beam * (2 * MAX_K2 + 2) * sizeof(float);
if (parts_per_beam <= 32)
{
@ -803,20 +820,22 @@ void beam_online_softmax_topk_stage2_kernelLauncher(const float* temp_storage, c
}
template <typename T, int MAX_K>
void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const FinishedState* finished,
const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, int** output_ids_ptr,
void* temp_storage, const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size,
const int beam_width, const int vocab_size, const int* end_ids, const float* diversity_rates,
const float* length_penalties, cudaStream_t stream)
void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs,
void* temp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream)
{
const int batch_size{beam_hyps.local_batch_size};
const int beam_width{beam_hyps.beam_width};
const int vocab_size{beam_hyps.vocab_size};
const int* end_ids{beam_hyps.end_ids};
const int items_per_thread = 1;
const int block_sz = (MAX_K < 16) ? (MAX_K < 8) ? SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE : 128 : 64;
const int block_sz = (MAX_K < 16) ? ((MAX_K < 8) ? SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE : 128) : 64;
// const int block_sz = SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE;
assert(temp_storage_size % 2 == 0);
assert(temp_storage_size >= 2 * batch_size * beam_width * beam_width * 2);
// Beam search needs the sequence lengths of beams to apply length penalty.
assert(length_penalties == nullptr || sequence_lengths != nullptr);
assert(beam_hyps.length_penalties == nullptr || beam_hyps.sequence_lengths_src != nullptr);
const int topk_buf_offset = ceil(batch_size * beam_width * beam_width * 2 / 4.) * 4;
int* topk_tmp_id_buf = reinterpret_cast<int*>(temp_storage);
@ -928,25 +947,22 @@ void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const Finish
// we will not put them into the next iteration
const int candidates = beam_width * beam_width * 2;
int smem_size_batch_topk = sizeof(T) * candidates;
const int smem_size_batch_topk = sizeof(T) * candidates;
if (smem_size_batch_topk >= (48 << 10))
{
TLLM_CUDA_CHECK(cudaFuncSetAttribute(
batch_topk_kernel<T, MAX_K * 2, 32>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_batch_topk));
}
batch_topk_kernel<T, MAX_K * 2, 32><<<batch_size, 32, smem_size_batch_topk, stream>>>(topk_tmp_id_buf,
topk_tmp_val_buf, output_ids_ptr, cum_log_probs, output_log_probs, finished, sequence_lengths, *beam_hyps,
candidates, beam_width, vocab_size, length_penalties, diversity_rates);
batch_topk_kernel<T, MAX_K * 2, 32><<<batch_size, 32, smem_size_batch_topk, stream>>>(
topk_tmp_id_buf, topk_tmp_val_buf, cum_log_probs, finished, beam_hyps, candidates);
sync_check_cuda_error();
}
#define INSTANTIATE_BEAMSEARCH_K(T, MAX_K) \
template void topK_softMax_kernelLauncher<T, MAX_K>(const T* log_probs, const T* bias, \
const FinishedState* finished, const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, \
int** output_ids_ptr, void* temp_storage, const int temp_storage_size, BeamHypotheses* beam_hyps, \
const int batch_size, const int beam_width, const int vocab_size, const int* end_ids, \
const float* diversity_rates, const float* length_penalties, cudaStream_t stream);
const FinishedState* finished, float* cum_log_probs, void* temp_storage, const int temp_storage_size, \
BeamHypotheses& beam_hyps, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -70,7 +70,7 @@ void splitkGroupedGemm_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std
int64_t gemmParamsWorkSpaceSize, void* gemmWorkSpace, int64_t gemmWorkSpaceSize,
std::vector<int64_t> splitkBufferOffsets, int splitKSlices, cudaStream_t stream)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
using ElementA = cutlassType;
using ElementB = cutlassType;
using ElementOutput = float;
@ -194,7 +194,7 @@ void splitkGroupedGemm_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to run CUTLASS Grouped GEMM kernel.");
std::free(host_workspace);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <int M1, int N1, int K1, int M2, int N2, int K2>

View File

@ -113,7 +113,7 @@ __device__ __forceinline__ void apply_scale(void* act, void* act_scale)
}
template <int N, int K, bool EnableZero>
__device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* scales, void* zeros)
__device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* scales, void* zeros, half alpha)
{
using Converter = ConverterI4ToF16;
static_assert(K % 2 == 0);
@ -123,11 +123,11 @@ __device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* sca
{
ConverterI4ToF16::convert<K>(
reinterpret_cast<uint8_t*>(quantized_w) + n * K / 2, reinterpret_cast<half*>(w) + n * K);
half2 vec_scale = __half2half2(reinterpret_cast<half*>(scales)[n]);
half2 vec_scale = __half2half2(reinterpret_cast<half*>(scales)[n] * alpha);
half2 vec_zero = __half2half2(__float2half_rn(0.f));
if constexpr (EnableZero)
{
vec_zero = __half2half2(reinterpret_cast<half*>(zeros)[n]);
vec_zero = __half2half2(reinterpret_cast<half*>(zeros)[n] * alpha);
}
#pragma unroll
for (int k = 0; k < VecK; ++k)
@ -186,7 +186,7 @@ __device__ __forceinline__ T warp_reduce_sum(T& val)
}
template <int CtaM, int CtaN, int Threads, bool EnableBias>
__device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc, void* bias, float alpha)
__device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc, void* bias)
{
static constexpr int WarpSize = 32;
static constexpr int WarpNum = Threads / WarpSize;
@ -224,7 +224,7 @@ __device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc,
{
val += shmem[jj * AlignShmemSize + ii];
}
reinterpret_cast<half*>(out)[m * stride + n] = __float2half_rn(alpha * val + v_bias);
reinterpret_cast<half*>(out)[m * stride + n] = __float2half_rn(val + v_bias);
}
}
@ -348,15 +348,17 @@ __global__ void kernel(typename Details::ActDataType* act, half* act_scale, uint
load<half, CtaN, Mandatory>(vec_scale, scales + idx_k / GroupSize * n, 1);
load<half, CtaN, EnableZero>(vec_zero, zeros + idx_k / GroupSize * n, 1);
// Dequantize Data
// W4A8 checkpoints have larger activation and weight values. To prevent the warp-level FP16
// accumulator from overflowing, the multiplication by alpha is moved from the epilogue to dequantize
apply_scale<CtaM, StepK, EnableActScale>(tile_a, vec_act_scale);
dequantize<CtaN, StepK, EnableZero>(tile_w, tile_w_quantized, vec_scale, vec_zero);
dequantize<CtaN, StepK, EnableZero>(tile_w, tile_w_quantized, vec_scale, vec_zero, __float2half_rn(alpha));
// Rearrange
pack_to_vec2<CtaN, StepK>(tile_w_pack2, tile_w);
// MMA
mma<CtaM, CtaN, StepK>(tile_acc, tile_w_pack2, tile_a);
}
// Epilogue
epilogue<CtaM, CtaN, Threads, EnableBias>(out, n, tile_acc, bias, alpha);
epilogue<CtaM, CtaN, Threads, EnableBias>(out, n, tile_acc, bias);
}
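A small numeric illustration of the overflow argument in the comment above: with large W4A8-style magnitudes, scaling only in the epilogue lets the running sum exceed the largest finite FP16 value (65504), while folding alpha into the dequantized weight keeps every partial sum in range. This is a standalone C++ sketch with made-up numbers; FP16 is only emulated by checking against 65504.
#include <cstdio>
int main()
{
    const float kHalfMax = 65504.f; // largest finite FP16 value
    const float alpha = 1.f / 64.f; // final output scale (made-up value)
    const int n = 256;
    const float act = 32.f, weight = 16.f; // large W4A8-style magnitudes (made-up values)
    // Variant 1: multiply by alpha only in the epilogue -> the running sum would overflow FP16.
    float acc_late = 0.f;
    bool overflow_late = false;
    for (int i = 0; i < n; ++i)
    {
        acc_late += act * weight;
        overflow_late = overflow_late || (acc_late > kHalfMax);
    }
    acc_late *= alpha;
    // Variant 2: fold alpha into the dequantized weight -> every partial sum stays in range.
    float acc_early = 0.f;
    bool overflow_early = false;
    for (int i = 0; i < n; ++i)
    {
        acc_early += act * (weight * alpha);
        overflow_early = overflow_early || (acc_early > kHalfMax);
    }
    std::printf("late alpha : result %.1f, exceeded FP16 range: %s\n", acc_late, overflow_late ? "yes" : "no");
    std::printf("early alpha: result %.1f, exceeded FP16 range: %s\n", acc_early, overflow_early ? "yes" : "no");
    return 0;
}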
template <typename Details, int CtaM, int CtaN, int Threads, int GroupSize, bool EnableActScale, bool EnableZero,

View File

@ -251,6 +251,7 @@ void DynamicDecodeLayer<T>::setupLayers(
beamSearchParams.beam_search_diversity_rate = setupParams.beam_search_diversity_rate;
beamSearchParams.length_penalty = setupParams.length_penalty;
beamSearchParams.early_stopping = setupParams.early_stopping;
mHasDiffRuntimeArgs = hasDiffRuntimeArgs(beamSearchParams);
mOnlineBeamSearchDecode->setup(batchSize, beamSearchParams);

View File

@ -75,6 +75,7 @@ public:
// OnlineBeamSearchLayer
std::optional<std::vector<float>> beam_search_diversity_rate;
std::optional<std::vector<float>> length_penalty;
std::optional<std::vector<int>> early_stopping;
std::optional<bool> normalize_log_probs;
};

View File

@ -31,12 +31,21 @@ static const int SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS = 128;
static const int MAX_K = 4;
template <typename T>
__global__ void update_kernel(FinishedState* finished, int** parent_ids_ptr, int* sequence_lengths,
int** output_ids_ptr, BeamHypotheses beam_hyps, const int vocab_size, const int* end_ids,
const int local_batch_size, const int beam_width, const int max_seq_len)
__global__ void update_kernel(FinishedState* finished, BeamHypotheses beam_hyps)
{
const int beam_width{beam_hyps.beam_width};
const int ite{beam_hyps.ite};
const int local_batch_size{beam_hyps.local_batch_size};
const int max_seq_len{beam_hyps.max_seq_len};
const int vocab_size{beam_hyps.vocab_size};
const int end_id{beam_hyps.end_ids[blockIdx.x]};
int* num_beams{beam_hyps.num_beams};
int* sequence_lengths{beam_hyps.sequence_lengths_src};
int** output_ids_ptr{beam_hyps.output_ids_tgt_ptr};
int** parent_ids_ptr{beam_hyps.parent_ids_tgt_ptr};
extern __shared__ char s_buf[]; // intermediate result
int* s_sequence_lengths = (int*) (s_buf);
int* s_sequence_lengths = reinterpret_cast<int*>(s_buf);
for (int beam_idx = threadIdx.x; beam_idx < beam_width; beam_idx += blockDim.x)
{
@ -63,35 +72,27 @@ __global__ void update_kernel(FinishedState* finished, int** parent_ids_ptr, int
new_word_id = new_word_id % vocab_size;
sequence_lengths[batch_beam_idx] = s_sequence_lengths[new_beam_id];
if (new_word_id == end_ids[blockIdx.x])
if (new_word_id == end_id)
{
finished[batch_beam_idx].setFinishedEOS();
}
parent_ids_ptr[blockIdx.x][beam_idx * max_seq_len + current_step] = new_beam_id;
output_ids_ptr[blockIdx.x][beam_idx * max_seq_len + current_step] = new_word_id;
}
if (beam_hyps.num_beams != nullptr)
if (num_beams != nullptr && num_beams[ite * local_batch_size + blockIdx.x] == beam_width)
{
if (beam_hyps.num_beams[beam_hyps.ite * beam_hyps.local_batch_size + blockIdx.x] == beam_width)
for (int beam_idx = threadIdx.x; beam_idx < beam_width; beam_idx += blockDim.x)
{
for (int beam_idx = threadIdx.x; beam_idx < beam_width; beam_idx += blockDim.x)
{
const auto batch_beam_idx = blockIdx.x * beam_width + beam_idx;
finished[batch_beam_idx].setFinished();
}
finished[blockIdx.x * beam_width + beam_idx].setFinished();
}
}
}
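The candidates consumed here are the absolute ids produced by the stage kernels ("trtllm needs absolute id"), i.e. beam_id * vocab_size + word_id, which is why the token is recovered with a modulo above (the matching beam index comes from the corresponding integer division, outside this hunk). A tiny C++ sketch of the encode/decode pair with hypothetical values:
#include <cassert>
#include <cstdio>
int main()
{
    const int vocab_size = 32000; // made-up size
    const int beam_id = 3;
    const int word_id = 1234;
    // Encoding used when the stage kernels write a candidate out.
    const int absolute_id = beam_id * vocab_size + word_id;
    // Decoding when the selected candidate is consumed again.
    const int decoded_beam = absolute_id / vocab_size;
    const int decoded_word = absolute_id % vocab_size;
    assert(decoded_beam == beam_id && decoded_word == word_id);
    std::printf("absolute=%d -> beam=%d word=%d\n", absolute_id, decoded_beam, decoded_word);
    return 0;
}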
void invokeUpdate(FinishedState* finished, int** parent_ids_ptr, int* sequence_lengths, int** output_ids_ptr,
BeamHypotheses* beam_hyps, const int local_batch_size, const int beam_width, const int vocab_size_padded,
const int* end_ids, const int max_seq_len, cudaStream_t stream)
void invokeUpdate(FinishedState* finished, BeamHypotheses& beam_hyps, cudaStream_t stream)
{
dim3 grid(local_batch_size);
dim3 block(min(beam_width, 1024));
update_kernel<float><<<grid, block, sizeof(int) * beam_width, stream>>>(finished, parent_ids_ptr, sequence_lengths,
output_ids_ptr, *beam_hyps, vocab_size_padded, end_ids, local_batch_size, beam_width, max_seq_len);
dim3 grid(beam_hyps.local_batch_size);
dim3 block(min(beam_hyps.beam_width, 1024));
update_kernel<float><<<grid, block, sizeof(int) * beam_hyps.beam_width, stream>>>(finished, beam_hyps);
}
template <typename T>
@ -103,11 +104,12 @@ void OnlineBeamSearchLayer<T>::setup(size_t batch_size, SetupParams const& setup
mDiversityRate.resize(batch_size);
mLengthPenalty.resize(batch_size);
mEarlyStopping.resize(batch_size);
FillBuffers const fillBuffers{batch_size, batch_size, mStream};
fillBuffers(setupParams.beam_search_diversity_rate, 0.0f, mDiversityRate, diversity_rates_buf_, (int*) nullptr);
fillBuffers(setupParams.length_penalty, 0.0f, mLengthPenalty, length_penalties_buf_, (int*) nullptr);
fillBuffers(setupParams.early_stopping, 1, mEarlyStopping, early_stoppings_buf_, (int*) nullptr);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -115,44 +117,39 @@ template <typename T>
void OnlineBeamSearchLayer<T>::invokeSoftMax(BeamSearchOutputParams& outputs, SoftmaxParams const& params)
{
TLLM_LOG_TRACE("%s", __PRETTY_FUNCTION__);
Tensor const& output_ids_ptr = outputs.output_ids_ptr;
const auto batch_size = static_cast<std::int32_t>(output_ids_ptr.shape[0]);
const auto beam_width = static_cast<std::int32_t>(output_ids_ptr.shape[1]);
const auto max_seq_len = static_cast<std::int32_t>(output_ids_ptr.shape[2]);
const int ite{params.ite};
Tensor const& logits{params.logits};
const auto local_batch_size = logits.shape[0];
BeamHypotheses beamHypotheses;
auto* const end_ids = params.end_ids.template getPtr<const int>();
float* output_log_probs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr<float>() : nullptr;
auto* finished
= reinterpret_cast<FinishedState*>(outputs.finished->template getPtr<FinishedState::UnderlyingType>());
auto* sequence_lengths = outputs.sequence_length->template getPtr<int>();
BeamHypotheses beam_hyps;
if (outputs.beamHypotheses)
{
beamHypotheses = *outputs.beamHypotheses;
beamHypotheses.ite = ite;
beamHypotheses.local_batch_size = local_batch_size;
beamHypotheses.batch_size = batch_size;
beamHypotheses.max_seq_len = max_seq_len;
beamHypotheses.output_ids_src_ptr = output_ids_ptr.template getPtr<const int*>();
beamHypotheses.parent_ids_src_ptr = outputs.parent_ids_ptr.template getPtr<const int*>();
beamHypotheses.sequence_lengths_src = sequence_lengths;
beamHypotheses.log_probs_src = output_log_probs;
beamHypotheses.length_penalties = length_penalties_buf_;
beamHypotheses.end_ids = end_ids;
beam_hyps = *outputs.beamHypotheses;
// Some members of beam_hyps have been initialized before invokeSoftMax is called
beam_hyps.end_ids = params.end_ids.template getPtr<const int>();
beam_hyps.log_probs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr<float>() : nullptr;
beam_hyps.output_ids_src_ptr = outputs.output_ids_ptr.template getPtr<const int*>();
beam_hyps.output_ids_tgt_ptr = outputs.output_ids_ptr.template getPtr<int*>();
beam_hyps.parent_ids_src_ptr = outputs.parent_ids_ptr.template getPtr<const int*>();
beam_hyps.parent_ids_tgt_ptr = outputs.parent_ids_ptr.template getPtr<int*>();
beam_hyps.sequence_lengths_src = outputs.sequence_length->template getPtr<int>();
beam_hyps.batch_size = static_cast<std::int32_t>(outputs.output_ids_ptr.shape[0]);
beam_hyps.beam_width = static_cast<std::int32_t>(outputs.output_ids_ptr.shape[1]);
beam_hyps.ite = params.ite;
beam_hyps.local_batch_size = params.logits.shape[0];
beam_hyps.max_seq_len = static_cast<std::int32_t>(outputs.output_ids_ptr.shape[2]);
beam_hyps.vocab_size = vocab_size_padded_;
beam_hyps.diversity_rates = diversity_rates_buf_;
beam_hyps.length_penalties = length_penalties_buf_;
beam_hyps.early_stoppings = early_stoppings_buf_;
}
invokeTopkSoftMax(logits.template getPtr<T>(), (const T*) (nullptr), finished, sequence_lengths,
outputs.cum_log_probs->template getPtr<float>(), output_log_probs, output_ids_ptr.getPtr<int*>(),
topk_softmax_workspace_, topk_softmax_workspace_size_, &beamHypotheses, local_batch_size, beam_width,
vocab_size_padded_, end_ids, diversity_rates_buf_, length_penalties_buf_, mStream);
invokeTopkSoftMax(params.logits.template getPtr<T>(), (const T*) (nullptr), finished,
outputs.cum_log_probs->template getPtr<float>(), topk_softmax_workspace_, topk_softmax_workspace_size_,
beam_hyps, mStream);
sync_check_cuda_error();
invokeUpdate(finished, outputs.parent_ids_ptr.template getPtr<int*>(), sequence_lengths,
output_ids_ptr.getPtr<int*>(), &beamHypotheses, local_batch_size, beam_width, vocab_size_padded_, end_ids,
max_seq_len, mStream);
invokeUpdate(finished, beam_hyps, mStream);
sync_check_cuda_error();
}
@ -169,6 +166,7 @@ void OnlineBeamSearchLayer<T>::allocateBuffer(size_t batch_size)
mAllocator->reMalloc(topk_softmax_workspace_, sizeof(float) * topk_softmax_workspace_size_, true));
diversity_rates_buf_ = mAllocator->reMalloc(diversity_rates_buf_, sizeof(float) * batch_size, false);
length_penalties_buf_ = mAllocator->reMalloc(length_penalties_buf_, sizeof(float) * batch_size, false);
early_stoppings_buf_ = mAllocator->reMalloc(early_stoppings_buf_, sizeof(int) * batch_size, false);
mIsAllocateBuffer = true;
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@ -183,6 +181,7 @@ void OnlineBeamSearchLayer<T>::freeBuffer()
mAllocator->free((void**) (&topk_softmax_workspace_));
mAllocator->free((void**) (&diversity_rates_buf_));
mAllocator->free((void**) (&length_penalties_buf_));
mAllocator->free((void**) (&early_stoppings_buf_));
mIsAllocateBuffer = false;
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);

View File

@ -41,6 +41,7 @@ public:
public:
std::optional<std::vector<float>> beam_search_diversity_rate; // [1] or [batch_size] on cpu
std::optional<std::vector<float>> length_penalty; // [1] or [batch_size] on cpu
std::optional<std::vector<int>> early_stopping; // [1] or [batch_size] on cpu
};
OnlineBeamSearchLayer(
@ -71,8 +72,10 @@ protected:
std::vector<float> mDiversityRate;
std::vector<float> mLengthPenalty;
std::vector<int> mEarlyStopping;
float* diversity_rates_buf_;
float* length_penalties_buf_;
int* early_stoppings_buf_;
private:
void allocateBuffer(size_t batch_size);

View File

@ -168,6 +168,7 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernel
memset(&xqaParams, 0, sizeof(XQAParams));
xqaParams.data_type = ConvertMMHAToXQAParamsHelper<T, KVCacheBuffer>::data_type;
xqaParams.layer_idx = mLayerIdx;
xqaParams.num_q_heads = mNumHeads;
xqaParams.num_kv_heads = mNumKVHeads;
xqaParams.head_size = mHeadSize;
@ -363,8 +364,8 @@ INSTANTIATE_MMHA_DISPATCH(__nv_bfloat16, __nv_bfloat16)
#endif
#undef INSTANTIATE_MMHA_DISPATCH
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int head_size, int unidirectional,
float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, int num_kv_heads, int head_size,
int unidirectional, float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
@ -374,7 +375,8 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_hea
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha,
bool use_paged_context_fmha, bool use_cache, bool is_medusa_enabled)
: mNumHeads(num_heads)
: mLayerIdx(layer_idx)
, mNumHeads(num_heads)
, mNumKVHeads(num_kv_heads)
, mHeadSize(head_size)
, mUnidirectional(unidirectional)
@ -477,6 +479,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(const void* data, size_t leng
const char *d = reinterpret_cast<const char*>(data), *a = d;
unsigned int kvCacheQuantMode;
read(d, mLayerIdx);
read(d, mNumHeads);
read(d, mNumKVHeads);
read(d, mHeadSize);
@ -1490,11 +1493,12 @@ void GPTAttentionPluginCommon::destroy() noexcept
size_t GPTAttentionPluginCommon::getCommonSerializationSize() noexcept
{
return sizeof(mNumHeads) + sizeof(mNumKVHeads) + sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling)
+ sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim) + sizeof(mRotaryEmbeddingBase)
+ sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale) + sizeof(mRotaryEmbeddingMaxPositions)
+ sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA) + sizeof(mFMHAForceFP32Acc)
+ sizeof(mMultiBlockMode) + sizeof(mEnableXQA) + sizeof(unsigned int) // mKVCacheQuantMode
return sizeof(mLayerIdx) + sizeof(mNumHeads) + sizeof(mNumKVHeads) + sizeof(mHeadSize) + sizeof(mUnidirectional)
+ sizeof(mQScaling) + sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim)
+ sizeof(mRotaryEmbeddingBase) + sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale)
+ sizeof(mRotaryEmbeddingMaxPositions) + sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA)
+ sizeof(mFMHAForceFP32Acc) + sizeof(mMultiBlockMode) + sizeof(mEnableXQA)
+ sizeof(unsigned int) // mKVCacheQuantMode
+ sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mPagedKVCache) + sizeof(mTokensPerBlock) + sizeof(mType)
+ sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled) + sizeof(mCrossAttention) + sizeof(mMaxDistance)
+ sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA) + sizeof(mPagedContextFMHA) + sizeof(mUseKVCache)
@ -1504,6 +1508,7 @@ size_t GPTAttentionPluginCommon::getCommonSerializationSize() noexcept
void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
{
char *d = static_cast<char*>(buffer), *a = d;
write(d, mLayerIdx);
write(d, mNumHeads);
write(d, mNumKVHeads);
write(d, mHeadSize);

View File

@ -36,8 +36,8 @@ class GPTAttentionPluginCommon : public BasePlugin
public:
GPTAttentionPluginCommon() = delete;
GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int head_size, int unidirectional, float q_scaling,
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
GPTAttentionPluginCommon(int layer_idx, int num_heads, int num_kv_heads, int head_size, int unidirectional,
float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
@ -213,6 +213,7 @@ protected:
const std::string mLayerName;
int mLayerIdx;
int mNumHeads;
int mNumKVHeads;
int mHeadSize;

View File

@ -37,8 +37,8 @@ using tensorrt_llm::plugins::GPTAttentionPlugin;
static const char* GPT_ATTENTION_PLUGIN_VERSION{"1"};
static const char* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"};
GPTAttentionPlugin::GPTAttentionPlugin(int num_heads, int num_kv_heads, int head_size, int unidirectional,
float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int num_kv_heads, int head_size,
int unidirectional, float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
@ -48,12 +48,12 @@ GPTAttentionPlugin::GPTAttentionPlugin(int num_heads, int num_kv_heads, int head
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha,
bool use_paged_context_fmha, bool use_cache, bool is_medusa_enabled)
: GPTAttentionPluginCommon(num_heads, num_kv_heads, head_size, unidirectional, q_scaling, position_embedding_type,
rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale,
rotary_embedding_max_positions, tp_size, tp_rank, unfuse_qkv_gemm, context_fmha_type, multi_block_mode,
enable_xqa, kv_cache_quant_mode, remove_input_padding, mask_type, paged_kv_cache, tokens_per_block, type,
max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled, dense_context_fmha,
use_paged_context_fmha, use_cache, is_medusa_enabled)
: GPTAttentionPluginCommon(layer_idx, num_heads, num_kv_heads, head_size, unidirectional, q_scaling,
position_embedding_type, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type,
rotary_embedding_scale, rotary_embedding_max_positions, tp_size, tp_rank, unfuse_qkv_gemm, context_fmha_type,
multi_block_mode, enable_xqa, kv_cache_quant_mode, remove_input_padding, mask_type, paged_kv_cache,
tokens_per_block, type, max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled,
dense_context_fmha, use_paged_context_fmha, use_cache, is_medusa_enabled)
{
initEntryIdx();
}
@ -485,7 +485,7 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
// Note that this cyclic_attention_window_size might be smaller than the actual kv cache capacity.
const int cyclic_attention_window_size = isCrossAttention()
? max_encoder_context_len
: reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW)])[0];
: reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW)])[mLayerIdx];
const int sink_token_length = reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_SINK_TOKEN_LENGTH)])[0];
const float* kv_scale_orig_quant = nullptr;
@ -503,14 +503,18 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
void* host_block_pointers = nullptr;
if (useKVCache() && mPagedKVCache)
{
auto& kvCacheBlockPointers = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)];
auto& kvCacheBlockPointersShape = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)].dims;
auto const& kvCacheBlockPointers = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)];
auto const& kvCacheBlockPointersShape = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)].dims;
max_blocks_per_sequence = kvCacheBlockPointersShape.d[kvCacheBlockPointersShape.nbDims - 1];
auto offset = getStride(kvCacheBlockPointersShape, 0) * seqIdxBeg;
auto const typed_block_pointers
auto const batchSize = kvCacheBlockPointersShape.d[1];
auto const seqStride = getStride(kvCacheBlockPointersShape, 1);
auto const layerOffset = mLayerIdx * batchSize * seqStride;
auto const seqOffset = seqIdxBeg * seqStride;
auto const offset = layerOffset + seqOffset;
auto const* const typed_block_pointers
= static_cast<void* const*>(inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)]) + offset;
block_pointers = const_cast<void*>(static_cast<void const*>(typed_block_pointers));
auto const typed_host_block_pointers
auto const* const typed_host_block_pointers
= static_cast<void* const*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_BLOCK_POINTERS)]) + offset;
host_block_pointers = const_cast<void*>(static_cast<void const*>(typed_host_block_pointers));
}
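With block_pointers now shaped [num_layers, batch_size, 2, max_blocks_per_seq] (see the plugin header below), each plugin instance first skips mLayerIdx * batchSize * seqStride entries and then applies the usual per-sequence offset, exactly as computed above. A host-side C++ sketch of the same flattened indexing, with hypothetical sizes:
#include <cstdio>
// Flattened offset into a [num_layers, batch_size, 2, max_blocks_per_seq] pointer tensor,
// mirroring layerOffset + seqOffset in enqueueSome (sizes below are made up).
long long blockPointerOffset(int layerIdx, int seqIdxBeg, int batchSize, int maxBlocksPerSeq)
{
    const long long seqStride = 2LL * maxBlocksPerSeq; // stride of one sequence: 2 (K and V) * blocks
    const long long layerOffset = static_cast<long long>(layerIdx) * batchSize * seqStride;
    const long long seqOffset = static_cast<long long>(seqIdxBeg) * seqStride;
    return layerOffset + seqOffset;
}
int main()
{
    // e.g. layer 5, starting at sequence 2, batch of 8 sequences, 64 blocks per sequence
    std::printf("offset = %lld\n", blockPointerOffset(5, 2, 8, 64));
    return 0;
}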
@ -762,9 +766,10 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(const char* name, const Plugi
try
{
auto* obj = new GPTAttentionPlugin(p.getScalar<int32_t>("num_heads").value(),
p.getScalar<int32_t>("num_kv_heads").value(), p.getScalar<int32_t>("head_size").value(),
p.getScalar<int32_t>("unidirectional").value(), p.getScalar<float>("q_scaling").value(),
auto* obj = new GPTAttentionPlugin(p.getScalar<int32_t>("layer_idx").value(),
p.getScalar<int32_t>("num_heads").value(), p.getScalar<int32_t>("num_kv_heads").value(),
p.getScalar<int32_t>("head_size").value(), p.getScalar<int32_t>("unidirectional").value(),
p.getScalar<float>("q_scaling").value(),
static_cast<PositionEmbeddingType>(p.getScalar<int8_t>("position_embedding_type").value()),
p.getScalar<int32_t>("rotary_embedding_dim").value(), p.getScalar<float>("rotary_embedding_base").value(),
static_cast<RotaryScalingType>(p.getScalar<int8_t>("rotary_embedding_scale_type").value()),

View File

@ -46,7 +46,7 @@ namespace tensorrt_llm::plugins
// enable_remove_input_padding
// 1. sequence_length [batch_size] (optional)
// 2. host_past_key_value_lengths [batch_size] (int32) (optional)
// 3. host_max_attention_window_sizes [1] (int32)
// 3. host_max_attention_window_sizes [num_layers] (int32)
// 4. host_sink_token_length [1] (int32)
// 5. context_lengths [batch_size]
// 6. cache_indir [num_gen_requests, beam_width, memory_max_len] (required in beamsearch) (optional)
@ -54,8 +54,8 @@ namespace tensorrt_llm::plugins
// mode,
// all elements must be identical.
// 8. past_key_value_pool [batch_size, 2, local_num_kv_heads, max_seq_len, head_size] or
// block_pointers [batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
// 8.1 host_block_pointers [batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
// block_pointers [num_layers, batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
// 8.1 host_block_pointers [num_layers, batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
// 9. kv_cache_quantization_scale [1] (optional)
// 10. kv_cache_dequantization_scale [1] (optional)
// 11. alibi_slopes [num_heads] (optional for ALiBi position embedding)
@ -70,8 +70,8 @@ namespace tensorrt_llm::plugins
class GPTAttentionPlugin : public GPTAttentionPluginCommon
{
public:
GPTAttentionPlugin(int num_heads, int num_kv_heads, int head_size, int unidirectional, float q_scaling,
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
GPTAttentionPlugin(int layer_idx, int num_heads, int num_kv_heads, int head_size, int unidirectional,
float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi

View File

@ -301,7 +301,7 @@ void LoraPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, in
}
mGemmId.n = N;
mGemmId.k = K;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
int64_t getLowRankWorkSpaceSize(
@ -344,7 +344,7 @@ size_t LoraPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, in
void runCublasGemmEx(const int M, const int N, const int K, const bool transA, const bool transB, const void* act,
const void* weight, void* output, cublasHandle_t cublas_handle)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
float a = 1.0f;
float b = 0.0f;
void* alpha = &a;
@ -356,13 +356,13 @@ void runCublasGemmEx(const int M, const int N, const int K, const bool transA, c
tensorrt_llm::common::check_cuda_error(cublasGemmEx(cublas_handle, transa, transb, m, n, k, alpha, weight,
CUDA_R_16F, lda, act, CUDA_R_16F, ldb, beta, output, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// inputs
// input [-1, K] (view as 2D)
// host_request_type [batch_size] on cpu
@ -565,7 +565,7 @@ int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf
}
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return 0;
}

View File

@ -19,16 +19,14 @@
#include "namedTensor.h"
#include "tensorrt_llm/batch_manager/GptManager.h"
#include "tensorrt_llm/batch_manager/callbacks.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/pybind/utils/pathCaster.h"
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/ops/tensor.h>
#include <memory>
#include <optional>
@ -38,17 +36,19 @@ namespace tensorrt_llm::pybind::batch_manager
{
GptManager::GptManager(std::filesystem::path const& trtEnginePath, tb::TrtGptModelType modelType, int32_t maxBeamWidth,
tb::batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,
SendResponseCallback sendResponseCb, tb::PollStopSignalCallback pollStopSignalCb,
tb::ReturnBatchManagerStatsCallback returnBatchManagerStatsCb, const tb::TrtGptModelOptionalParams& optionalParams,
std::optional<uint64_t> terminateReqId)
: tb::GptManager(trtEnginePath, modelType, maxBeamWidth, schedulerPolicy, callbackAdapter(getInferenceRequestsCb),
callbackAdapter(sendResponseCb), pollStopSignalCb, returnBatchManagerStatsCb, optionalParams, terminateReqId)
tb::batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback const& getInferenceRequestsCb,
SendResponseCallback const& sendResponseCb, const tb::PollStopSignalCallback& pollStopSignalCb,
tb::ReturnBatchManagerStatsCallback const& returnBatchManagerStatsCb,
tb::TrtGptModelOptionalParams const& optionalParams, std::optional<uint64_t> terminateReqId)
{
mManager = std::make_unique<tb::GptManager>(trtEnginePath, modelType, maxBeamWidth, schedulerPolicy,
callbackAdapter(getInferenceRequestsCb), callbackAdapter(sendResponseCb), pollStopSignalCb,
returnBatchManagerStatsCb, optionalParams, terminateReqId);
}
py::object GptManager::enter()
{
TLLM_CHECK(static_cast<bool>(mManager));
return py::cast(this);
}
@ -62,11 +62,14 @@ void GptManager::shutdown()
// NOTE: we must release the GIL here. GptManager has spawned a thread for the execution loop. That thread must be
// able to make forward progress for the shutdown process to succeed. It takes the GIL during its callbacks, so
// we release it now. Note that we shouldn't do anything related to python objects after that.
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
py::gil_scoped_release release;
tb::GptManager::shutdown();
mManager->shutdown();
mManager = nullptr;
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
}
tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback callback)
tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback const& callback)
{
return [callback](int32_t max_sequences)
{
@ -82,14 +85,14 @@ tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback ca
};
}
tb::SendResponseCallback callbackAdapter(SendResponseCallback callback)
tb::SendResponseCallback callbackAdapter(SendResponseCallback const& callback)
{
return [callback](uint64_t id, std::list<tb::NamedTensor> const& cppTensors, bool isOk, const std::string& errMsg)
{
std::list<NamedTensor> pythonList{};
for (const auto& cppNamedTensor : cppTensors)
{
pythonList.push_back(NamedTensor{cppNamedTensor});
pythonList.emplace_back(cppNamedTensor);
}
callback(id, pythonList, isOk, errMsg);
};

View File

@ -24,6 +24,7 @@
#include <ATen/ops/tensor.h>
#include <functional>
#include <memory>
namespace tensorrt_llm::pybind::batch_manager
{
@ -31,18 +32,18 @@ namespace tensorrt_llm::pybind::batch_manager
using GetInferenceRequestsCallback = std::function<std::list<InferenceRequest>(int32_t)>;
using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, const std::string&)>;
tensorrt_llm::batch_manager::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback callback);
tensorrt_llm::batch_manager::SendResponseCallback callbackAdapter(SendResponseCallback callback);
tensorrt_llm::batch_manager::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback const& callback);
tensorrt_llm::batch_manager::SendResponseCallback callbackAdapter(SendResponseCallback const& callback);
class GptManager : tensorrt_llm::batch_manager::GptManager
class GptManager
{
public:
GptManager(std::filesystem::path const& trtEnginePath, tensorrt_llm::batch_manager::TrtGptModelType modelType,
int32_t maxBeamWidth, tensorrt_llm::batch_manager::batch_scheduler::SchedulerPolicy schedulerPolicy,
GetInferenceRequestsCallback getInferenceRequestsCb, SendResponseCallback sendResponseCb,
tensorrt_llm::batch_manager::PollStopSignalCallback pollStopSignalCb = nullptr,
tensorrt_llm::batch_manager::ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
const tensorrt_llm::batch_manager::TrtGptModelOptionalParams& optionalParams
GetInferenceRequestsCallback const& getInferenceRequestsCb, SendResponseCallback const& sendResponseCb,
tensorrt_llm::batch_manager::PollStopSignalCallback const& pollStopSignalCb = nullptr,
tensorrt_llm::batch_manager::ReturnBatchManagerStatsCallback const& returnBatchManagerStatsCb = nullptr,
tensorrt_llm::batch_manager::TrtGptModelOptionalParams const& optionalParams
= tensorrt_llm::batch_manager::TrtGptModelOptionalParams(),
std::optional<uint64_t> terminateReqId = std::nullopt);
@ -51,6 +52,9 @@ public:
void shutdown();
static void initBindings(pybind11::module_& m);
private:
std::unique_ptr<tensorrt_llm::batch_manager::GptManager> mManager;
};
} // namespace tensorrt_llm::pybind::batch_manager

View File

@ -17,6 +17,7 @@
#include "inferenceRequest.h"
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchView.h"
#include <memory>
@ -25,6 +26,15 @@
#include <pybind11/stl.h>
#include <torch/extension.h>
#ifdef _WIN32
// FIXME: THPStream_Wrap seems not to be present in libtorch_python.so on Windows
PyObject* THPStream_Wrap(const c10::Stream& stream)
{
TLLM_THROW("Stream conversion in not yet supported on Windows.");
return nullptr;
}
#endif
namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime;
@ -32,6 +42,7 @@ using namespace tensorrt_llm::pybind::batch_manager;
namespace
{
std::shared_ptr<InferenceRequest> fromTrtLlm(tb::InferenceRequest const& request)
{
InferenceRequest::TensorMap tensorMap;
@ -53,18 +64,24 @@ std::shared_ptr<tb::InferenceRequest> InferenceRequest::toTrtLlm() const
tb::InferenceRequest::TensorMap tensorMap;
for (auto const& [name, tensor] : mInputTensors)
{
if (tensor.has_value())
{
tensorMap[name] = tr::TorchView::of(tensor.value());
}
tensorMap[name] = tr::TorchView::of(tensor);
}
auto inferenceRequest = std::make_shared<tb::InferenceRequest>(std::move(tensorMap), mRequestId);
inferenceRequest->setIsStreaming(isStreaming());
if (mlogitsPostProcessor)
{
inferenceRequest->setLogitsPostProcessor(LlmRequest::callbackAdapter(mlogitsPostProcessor));
}
return inferenceRequest;
}
std::string InferenceRequest::serialize() const
{
TLLM_CHECK_WITH_INFO(mlogitsPostProcessor == std::nullopt,
"Serializing InferenceRequest with logitsPostProcessor set is not supported."
"Please set the callback after de-serialization");
std::vector<std::int64_t> serialized{toTrtLlm()->serialize()};
static_assert(sizeof(decltype(serialized)::value_type) / sizeof(char) == 8);
return {reinterpret_cast<char const*>(serialized.data()), serialized.size() * 8};
@ -81,8 +98,12 @@ std::shared_ptr<InferenceRequest> InferenceRequest::deserialize(std::string cons
void InferenceRequest::initBindings(py::module_& m)
{
py::class_<InferenceRequest>(m, "InferenceRequest")
.def(py::init<uint64_t>())
.def(py::init<uint64_t, std::optional<LogitsProcessorCallback>>(), py::arg("request_id"),
py::arg("logits_post_processor_callback") = py::none())
.def(py::init<uint64_t, InferenceRequest::TensorMap const&>(), "deprecated: use direct tensor access instead")
.def_property("logits_post_processor",
nullptr, // passing the logits processor in the cpp->python direction doesn't work, so the getter is left undefined
&InferenceRequest::setLogitsPostProcessor)
.def_property("input_ids", &InferenceRequest::getInputIdsUnchecked, &InferenceRequest::setInputIds)
.def_property(
"draft_input_ids", &InferenceRequest::getDraftInputIdsUnchecked, &InferenceRequest::setDraftInputIds)
@ -101,6 +122,8 @@ void InferenceRequest::initBindings(py::module_& m)
.def_property("runtime_top_p", &InferenceRequest::getRuntimeTopPUnchecked, &InferenceRequest::setRuntimeTopP)
.def_property(
"length_penalty", &InferenceRequest::getLengthPenaltyUnchecked, &InferenceRequest::setLengthPenalty)
.def_property(
"early_stopping", &InferenceRequest::getEarlyStoppingUnchecked, &InferenceRequest::setEarlyStopping)
.def_property("repetition_penalty", &InferenceRequest::getRepetitionPenaltyUnchecked,
&InferenceRequest::setRepetitionPenalty)
.def_property("min_length", &InferenceRequest::getMinLengthUnchecked, &InferenceRequest::setMinLength)

View File

@ -18,6 +18,7 @@
#pragma once
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/batch_manager/namedTensor.h"
#include <ATen/ATen.h>
@ -30,25 +31,28 @@ namespace tensorrt_llm::pybind::batch_manager
{
class InferenceRequest
: public tensorrt_llm::batch_manager::GenericInferenceRequest<std::optional<at::Tensor>, NamedTensor>
: public tensorrt_llm::batch_manager::GenericInferenceRequest<at::Tensor, NamedTensor, c10::Stream>
{
public:
using Base = tensorrt_llm::batch_manager::GenericInferenceRequest<std::optional<at::Tensor>, NamedTensor>;
using Base = tensorrt_llm::batch_manager::GenericInferenceRequest<at::Tensor, NamedTensor, c10::Stream>;
using TensorPtr = Base::TensorPtr;
using TensorMap = Base::TensorMap;
using LogitsProcessorCallback = Base::LogitsPostProcessor;
InferenceRequest(uint64_t requestId)
: Base(requestId)
InferenceRequest(uint64_t requestId, std::optional<LogitsProcessorCallback> logitsCb = std::nullopt)
: Base(requestId, logitsCb)
{
}
InferenceRequest(uint64_t requestId, TensorMap const& inputTensors)
: Base{requestId, inputTensors}
InferenceRequest(uint64_t requestId, TensorMap const& inputTensors,
std::optional<LogitsProcessorCallback> logitsCb = std::nullopt)
: Base{requestId, inputTensors, logitsCb}
{
}
InferenceRequest(uint64_t requestId, TensorMap&& inputTensors)
: Base{requestId, std::move(inputTensors)}
InferenceRequest(
uint64_t requestId, TensorMap&& inputTensors, std::optional<LogitsProcessorCallback> logitsCb = std::nullopt)
: Base{requestId, std::move(inputTensors), logitsCb}
{
}

View File

@ -17,7 +17,10 @@
#include "llmRequest.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/generationInput.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/torchView.h"
#include <memory>
@ -45,6 +48,25 @@ std::optional<tb::LlmRequest::TensorPtr> from_torch(std::optional<LlmRequest::Te
} // namespace
std::optional<tb::LlmRequest::LogitsPostProcessor> LlmRequest::callbackAdapter(
std::optional<LlmRequest::LogitsPostProcessor> callback)
{
if (!callback)
{
return std::nullopt;
}
return [callback](RequestIdType reqId, tensorrt_llm::runtime::ITensor::SharedPtr& tensor,
tensorrt_llm::batch_manager::LlmRequest::BeamTokens const& tokens,
tensorrt_llm::runtime::BufferManager::CudaStreamPtr stream)
{
at::Tensor atTensor = tr::Torch::tensor(tensor);
auto result = callback.value()(reqId, atTensor, tokens, runtime::TorchUtils::stream(*stream).unwrap());
return tr::TorchView::of(result);
};
}
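callbackAdapter above wraps a user-facing callback that works on at::Tensor and a raw c10::Stream into the batch manager's native callback type, converting the tensor view on every invocation. The standalone C++ sketch below shows the same wrapping pattern with simplified, hypothetical stand-in types and no torch dependency:
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <optional>
#include <vector>
// Hypothetical stand-ins for the framework types, used only to show the wrapping pattern.
struct NativeTensor { std::vector<float> data; };
struct UserTensor   { float* ptr; std::size_t size; };
using NativeCallback = std::function<void(std::uint64_t, NativeTensor&)>;
using UserCallback   = std::function<void(std::uint64_t, UserTensor)>;
// Wrap an optional user callback into the callback type the core library expects,
// converting the tensor representation on every call (the direction handled by callbackAdapter above).
std::optional<NativeCallback> adapt(std::optional<UserCallback> cb)
{
    if (!cb)
    {
        return std::nullopt;
    }
    return [cb](std::uint64_t reqId, NativeTensor& t)
    {
        UserTensor view{t.data.data(), t.data.size()};
        (*cb)(reqId, view); // user code edits the tensor in place, like a logits post-processor
    };
}
int main()
{
    auto wrapped = adapt(UserCallback([](std::uint64_t id, UserTensor t)
        { if (t.size > 0) { t.ptr[0] += static_cast<float>(id); } }));
    NativeTensor logits{{0.f, 1.f, 2.f}};
    if (wrapped)
    {
        (*wrapped)(7, logits);
    }
    std::printf("logits[0] = %f\n", logits.data[0]); // 7.0
    return 0;
}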
std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
{
auto embeddingBias = from_torch(mEmbeddingBias);
@ -58,7 +80,8 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
return std::make_shared<tb::LlmRequest>(mRequestId, mMaxNewTokens,
std::make_shared<std::vector<TokenIdType>>(mTokens.at(0)), mSamplingConfig, mIsStreaming, mEndId, mPadId,
embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, mPromptVocabSize, loraWeights, loraConfig,
mReturnLogProbs, mReturnContextLogits, mReturnGenerationLogits, mDraftTokens, draftLogits);
mReturnLogProbs, mReturnContextLogits, mReturnGenerationLogits, mDraftTokens, draftLogits,
mExcludeInputFromOutput, callbackAdapter(mLogitsPostProcessor));
}
void LlmRequest::initBindings(py::module_& m)
@ -70,7 +93,7 @@ void LlmRequest::initBindings(py::module_& m)
std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
std::optional<LlmRequest::SizeType>, std::optional<LlmRequest::TensorPtr>,
std::optional<LlmRequest::TensorPtr>, bool, bool, bool, std::optional<LlmRequest::VecTokens>,
std::optional<LlmRequest::TensorPtr>>(),
std::optional<LlmRequest::TensorPtr>, bool, std::optional<LlmRequest::LogitsPostProcessor>>(),
py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt,
@ -78,7 +101,8 @@ void LlmRequest::initBindings(py::module_& m)
py::arg("prompt_vocab_size") = std::nullopt, py::arg("lora_weights") = std::nullopt,
py::arg("lora_config") = std::nullopt, py::arg("return_log_probs") = false,
py::arg("return_context_logits") = false, py::arg("return_generation_logits") = false,
py::arg("draft_tokens") = std::nullopt, py::arg("draft_logits") = std::nullopt)
py::arg("draft_tokens") = std::nullopt, py::arg("draft_logits") = std::nullopt,
py::arg("exclude_input_from_output") = false, py::arg("logits_post_processor") = std::nullopt)
.def("get_num_tokens", &LlmRequest::getNumTokens, py::arg("beam"))
.def_property_readonly("max_beam_num_tokens", &LlmRequest::getMaxBeamNumTokens)
.def("get_token", &LlmRequest::getToken, py::arg("beam"), py::arg("pos"))

View File

@ -28,10 +28,16 @@
namespace tensorrt_llm::pybind::batch_manager
{
class LlmRequest : public tensorrt_llm::batch_manager::GenericLlmRequest<at::Tensor>
namespace tb = tensorrt_llm::batch_manager;
/* Unfortunately, torch's default pybind bindings don't know about c10::cuda::CUDAStream,
* so we have to pass the more generic c10::Stream, and convert it back to a full-fledged
* torch.cuda.Stream in python. See example in test/bindings/test_gpt_manager.py
*/
class LlmRequest : public tb::GenericLlmRequest<at::Tensor, c10::Stream>
{
public:
using Base = GenericLlmRequest<at::Tensor>;
using Base = GenericLlmRequest<at::Tensor, c10::Stream>;
using TensorPtr = Base::TensorPtr;
using SizeType = Base::SizeType;
using TokenIdType = Base::TokenIdType;
@ -39,6 +45,7 @@ public:
using VecLogProbs = Base::VecLogProbs;
using BeamTokens = Base::BeamTokens;
using VecTokens = Base::VecTokens;
using LogitsPostProcessor = Base::LogitsPostProcessor;
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::vector<TokenIdType> inputTokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
@ -48,16 +55,20 @@ public:
std::optional<SizeType> promptVocabSize = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
std::optional<TensorPtr> loraConfig = std::nullopt, bool returnLogProbs = false,
bool returnContextLogits = false, bool returnGenerationLogits = false,
std::optional<VecTokens> draftTokens = std::nullopt, std::optional<TensorPtr> draftLogits = std::nullopt)
std::optional<VecTokens> draftTokens = std::nullopt, std::optional<TensorPtr> draftLogits = std::nullopt,
bool excludeInputFromOutput = false, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
: Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
samplingConfig, isStreaming, endId, padId, embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable,
promptVocabSize, loraWeights, loraConfig, returnLogProbs, returnContextLogits, returnGenerationLogits,
draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value()))
: std::make_shared<VecTokens>(),
draftLogits)
draftLogits, excludeInputFromOutput, logitsPostProcessor)
{
}
static std::optional<tb::LlmRequest::LogitsPostProcessor> callbackAdapter(
std::optional<LlmRequest::LogitsPostProcessor> callback);
[[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> toTrtLlm() const;
static void initBindings(pybind11::module_& m);
};
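The comment above explains why the generic c10::Stream crosses the binding boundary. As a minimal sketch of the round trip on the C++ side, assuming the explicit c10::cuda::CUDAStream(c10::Stream) converting constructor available in recent libtorch (it checks that the stream really is a CUDA stream):
#include <c10/core/Stream.h>
#include <c10/cuda/CUDAStream.h>

// Generic stream -> full CUDA stream (throws if the stream is not a CUDA one).
c10::cuda::CUDAStream toCudaStream(c10::Stream const& stream)
{
    return c10::cuda::CUDAStream{stream};
}

// Full CUDA stream -> generic stream, as used in callbackAdapter via unwrap().
c10::Stream toGenericStream(c10::cuda::CUDAStream const& stream)
{
    return stream.unwrap();
}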

View File

@ -37,6 +37,7 @@
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/gptJsonConfig.h"
#include "tensorrt_llm/runtime/gptSession.h"
#include "tensorrt_llm/runtime/memoryCounters.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
namespace py = pybind11;
@ -232,7 +233,8 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def_readwrite("top_p_min", &tr::SamplingConfig::topPMin)
.def_readwrite("top_p_reset_ids", &tr::SamplingConfig::topPResetIds)
.def_readwrite("beam_search_diversity_rate", &tr::SamplingConfig::beamSearchDiversityRate)
.def_readwrite("length_penalty", &tr::SamplingConfig::lengthPenalty);
.def_readwrite("length_penalty", &tr::SamplingConfig::lengthPenalty)
.def_readwrite("early_stopping", &tr::SamplingConfig::earlyStopping);
py::class_<tr::GptJsonConfig>(m, "GptJsonConfig")
.def(py::init<std::string, std::string, std::string, SizeType, SizeType, tr::GptModelConfig>(), py::arg("name"),
@ -302,6 +304,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
tensorNames.attr("RUNTIME_TOP_K") = py::str(tb::inference_request::kRuntimeTopKTensorName);
tensorNames.attr("RUNTIME_TOP_P") = py::str(tb::inference_request::kRuntimeTopPTensorName);
tensorNames.attr("LENGTH_PENALTY") = py::str(tb::inference_request::kLengthPenaltyTensorName);
tensorNames.attr("EARLY_STOPPING") = py::str(tb::inference_request::kEarlyStoppingTensorName);
tensorNames.attr("REPETITION_PENALTY") = py::str(tb::inference_request::kRepetitionPenaltyTensorName);
tensorNames.attr("MIN_LENGTH") = py::str(tb::inference_request::kMinLengthTensorName);
tensorNames.attr("PRESENCE_PENALTY") = py::str(tb::inference_request::kPresencePenaltyTensorName);
@ -342,4 +345,11 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def_readwrite("decoding_mode", &tb::TrtGptModelOptionalParams::decodingMode);
tpb::GptManager::initBindings(m);
py::class_<tr::MemoryCounters>(m, "MemoryCounters")
.def_static("instance", &tr::MemoryCounters::getInstance, py::return_value_policy::reference)
.def_property_readonly("gpu", &tr::MemoryCounters::getGpu)
.def_property_readonly("cpu", &tr::MemoryCounters::getCpu)
.def_property_readonly("pinned", &tr::MemoryCounters::getPinned)
.def_property_readonly("uvm", &tr::MemoryCounters::getUVM);
}
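The MemoryCounters binding exposes a C++ singleton to Python without transferring ownership. A self-contained pybind11 sketch of that technique, with a hypothetical Counters class in place of tr::MemoryCounters:
#include <cstddef>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical counter type standing in for tr::MemoryCounters.
class Counters
{
public:
    static Counters& getInstance()
    {
        static Counters instance;
        return instance;
    }

    std::size_t getGpu() const { return mGpu; }

private:
    std::size_t mGpu{0};
};

PYBIND11_MODULE(example, m)
{
    // return_value_policy::reference keeps Python from taking ownership of
    // the singleton, matching the MemoryCounters binding above.
    py::class_<Counters>(m, "Counters")
        .def_static("instance", &Counters::getInstance, py::return_value_policy::reference)
        .def_property_readonly("gpu", &Counters::getGpu);
}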

View File

@ -19,7 +19,6 @@
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include <memory>
#include <string>
namespace tensorrt_llm::runtime

View File

@ -16,6 +16,7 @@
#include "tensorrt_llm/runtime/gptDecoder.h"
#include "tensorrt_llm/common/cudaAllocator.h"
#include "tensorrt_llm/common/tensorConversion.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/layers/dynamicDecodeLayer.h"
@ -81,6 +82,7 @@ void GptDecoder<T>::setup(SamplingConfig const& samplingConfig, size_t batchSize
setupParams.beam_search_diversity_rate = samplingConfig.beamSearchDiversityRate;
setupParams.length_penalty = samplingConfig.lengthPenalty;
setupParams.early_stopping = samplingConfig.earlyStopping;
auto const batchSlotsPtr = batchSlots.has_value() ? bufferCast<SizeType>(*(batchSlots.value())) : nullptr;
mDynamicDecodeLayer->setup(batchSize, samplingConfig.beamWidth, batchSlotsPtr, setupParams);
@ -174,7 +176,7 @@ template <typename T>
typename tl::DynamicDecodeLayer<T>::OutputParams prepareOutputs(
DecodingOutput& output, DecodingInput::TensorPtr const& inputLengths, DecodingOutput::TensorPtr& logProbsTiled)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
typename tl::DynamicDecodeLayer<T>::OutputParams outputParams(tcc::toTllmTensor(*output.ids));
outputParams.newTokens = tcc::toTllmTensor(*output.newTokens);
@ -261,7 +263,7 @@ typename tl::DynamicDecodeLayer<T>::OutputParams prepareOutputs(
template <typename T>
bool GptDecoder<T>::forward(DecodingOutput& output, DecodingInput const& input)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto forwardParams = prepareInputs<T>(input);
auto outputParams = prepareOutputs<T>(output, input.lengths, mLogProbsTiled);
auto const maxBatchSize = input.maxBatchSize;
@ -309,7 +311,7 @@ bool GptDecoder<T>::forward(DecodingOutput& output, DecodingInput const& input)
template <typename T>
void GptDecoder<T>::forwardAsync(DecodingOutput& output, DecodingInput const& input)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto forwardParams = prepareInputs<T>(input);
auto outputParams = prepareOutputs<T>(output, input.lengths, mLogProbsTiled);
@ -321,7 +323,7 @@ template <typename T>
void GptDecoder<T>::gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput,
DecodingInput const& decodingInput, BufferManager const& manager)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const& finalOutputIdsShape = finalOutputIds.getShape();
auto const& decodingOutputIdsShape = decodingOutput.ids->getShape();
auto const batchSize = finalOutputIdsShape.d[0];
@ -355,7 +357,7 @@ void GptDecoder<T>::gatherTree(ITensor& finalOutputIds, DecodingOutput const& de
beamHypotheses.length_penalties
= nullptr; // TODO (bhsueh): length penalties should be set here as a GPU tensor. When it is set as
// nullptr, the kernel automatically uses the default value (1.0f).
beamHypotheses.early_stoppings = nullptr; // TODO (wili): similar to length_penalties
beamHypotheses.output_ids_tgt = bufferCast<TokenIdType>(*decodingOutput.beamHypotheses.outputIdsTgt);
beamHypotheses.sequence_lengths_tgt = bufferCast<SizeType>(*decodingOutput.beamHypotheses.sequenceLengthsTgt);
beamHypotheses.cum_log_probs = bufferCast<float>(*decodingOutput.beamHypotheses.cumLogProbs);
@ -381,7 +383,7 @@ void GptDecoder<T>::gatherTree(ITensor& finalOutputIds, DecodingOutput const& de
batchSize, stream.get());
sync_check_cuda_error();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
namespace tensorrt_llm::runtime
@ -394,7 +396,7 @@ void IGptDecoder::acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor
ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths, ITensor const& finishedVec,
ITensor& finishedFinal, ITensor& finishedSum, ITensor const& batchSlots, BufferManager::CudaStreamPtr const& stream)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const finishedVecShape = finishedVec.getShape();
auto const maxBatchSize = finishedVecShape.d[1];
@ -439,7 +441,7 @@ void IGptDecoder::acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor
sync_check_cuda_error();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void IGptDecoder::acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const& targetLogits, ITensor& draftProbs,
@ -447,7 +449,7 @@ void IGptDecoder::acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const&
SizeType vocabSize, SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
curandState_t* curandState, BufferManager::CudaStreamPtr const& stream)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const draftLogitsShape = draftLogits.getShape();
auto const maxBatchSize = draftLogitsShape.d[0];
@ -488,5 +490,5 @@ void IGptDecoder::acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const&
sync_check_cuda_error();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

View File

@ -22,6 +22,7 @@
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include <algorithm>
#include <cassert>
#include <memory>
using namespace tensorrt_llm::runtime;
@ -33,7 +34,7 @@ namespace
{
SamplingConfig extractSamplingConfig(SamplingConfig const& batchSamplingConfig, SizeType batchIdx)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
SamplingConfig samplingConfig{batchSamplingConfig.beamWidth};
auto extractOptional = [&batchIdx](auto& single, auto const& batch)
@ -64,8 +65,9 @@ SamplingConfig extractSamplingConfig(SamplingConfig const& batchSamplingConfig,
// beam search layer
samplingConfig.beamSearchDiversityRate = batchSamplingConfig.beamSearchDiversityRate;
samplingConfig.lengthPenalty = batchSamplingConfig.lengthPenalty;
samplingConfig.earlyStopping = batchSamplingConfig.earlyStopping;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return samplingConfig;
}
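extractSamplingConfig pulls the per-request entries out of a batched SamplingConfig. A simplified sketch of the extraction step, assuming per-request fields are stored as optional vectors that hold either one shared value or one value per request (extractAt is an illustrative helper, not the real lambda, which writes into the single-request config in place):
#include <cstddef>
#include <optional>
#include <vector>

template <typename T>
std::optional<std::vector<T>> extractAt(std::optional<std::vector<T>> const& batch, std::size_t batchIdx)
{
    if (!batch.has_value())
    {
        return std::nullopt;
    }
    // A single shared value applies to every request; otherwise pick the
    // entry belonging to this batch index.
    auto const& values = batch.value();
    return std::vector<T>{values.size() > 1 ? values.at(batchIdx) : values.at(0)};
}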
@ -78,7 +80,7 @@ GptDecoderBatch::GptDecoderBatch(
, mStream{std::move(stream)}
, mBufferManager{mStream}
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
auto constexpr nvSizeType = TRTDataType<SizeType>::value;
auto constexpr nvFloatType = TRTDataType<float>::value;
@ -125,14 +127,14 @@ GptDecoderBatch::GptDecoderBatch(
dInput->badWordsLens = mBufferManager.emptyTensor(MemoryType::kPINNED, TRTDataType<SizeType>::value);
dInput->embeddingBias = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void GptDecoderBatch::setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth,
SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep,
bool fusedDecoder, nvinfer1::DataType dtype)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(maxBatchSize > 0);
TLLM_CHECK(maxBeamWidth > 0);
TLLM_CHECK(maxTokensPerStep > 0);
@ -253,13 +255,13 @@ void GptDecoderBatch::setup(DecodingMode const& mode, SizeType maxBatchSize, Siz
mBeamWidths[i] = 0;
mGeneratedTokensPerStep[i] = 0;
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void GptDecoderBatch::newRequest(
SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(batchIdx >= 0);
auto const& jointOutputIdsShape = mJointDecodingOutput->ids->getShape();
auto const batchSize = jointOutputIdsShape.d[0];
@ -373,7 +375,7 @@ void GptDecoderBatch::newRequest(
dOutput->newTokensVec.resize(mMaxTokensPerStep);
for (SizeType ti = 0; ti < mMaxTokensPerStep; ++ti)
{
TensorPtr newTokensStepView = std::move(ITensor::slice(dJointOutput.newTokensSteps, ti, localBatchSize));
TensorPtr newTokensStepView = ITensor::slice(dJointOutput.newTokensSteps, ti, localBatchSize);
newTokensStepView->squeeze(0);
dOutput->newTokensVec[ti] = ITensor::slice(newTokensStepView, batchIdx, localBatchSize);
manager.setZero(*dOutput->newTokensVec[ti]);
@ -469,7 +471,7 @@ void GptDecoderBatch::newRequest(
auto outputIdsView = ITensor::view(outputIds, ITensor::makeShape({beamWidth, mMaxSequenceLength}));
kernels::invokeFill(*outputIdsView, endId, *stream);
kernels::tileTensor(*outputIdsView, *inputIdsView, beamWidth, *stream);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void GptDecoderBatch::newRequests(std::vector<SizeType> const& seqSlots,
@ -500,7 +502,7 @@ void GptDecoderBatch::newRequests(std::vector<SizeType> const& seqSlots,
GptDecoderBatch::TokenPtr GptDecoderBatch::forwardAsync(
decoder_batch::Output& output, decoder_batch::Input const& input)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& allTargetLogits = input.logits;
// TODO(nkorobov): check logits shape considering draft tokens
@ -737,13 +739,13 @@ GptDecoderBatch::TokenPtr GptDecoderBatch::forwardAsync(
CudaEvent eventStop{};
mStream->record(eventStop);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return std::make_unique<decoder_batch::Token>(std::move(eventStop), input.active);
}
void GptDecoderBatch::forwardSync(decoder_batch::Token const& token)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
token.event.synchronize();
for (std::int32_t i = 0; i < mActualBatchSize; ++i)
@ -756,13 +758,13 @@ void GptDecoderBatch::forwardSync(decoder_batch::Token const& token)
|| bufferCast<SizeType>(*dOutput.finishedSum)[0] == static_cast<SizeType>(mBeamWidths[i]);
}
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
// TODO call this at the end of forward if mFinished[i] changes from false to true?
CudaEvent GptDecoderBatch::postProcessRequest(SizeType batchIdx) const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& stream = mStreams[batchIdx];
auto manager = BufferManager{stream};
auto& decoder = *mDecoders[batchIdx];
@ -779,14 +781,14 @@ CudaEvent GptDecoderBatch::postProcessRequest(SizeType batchIdx) const
CudaEvent event{};
stream->record(event);
mStream->wait(event);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return event;
}
void GptDecoderBatch::newBatch(
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// split batch into single requests
auto const& inputLengths = inputs.lengths;
mActualBatchSize = inputLengths->getShape().d[0];
@ -855,12 +857,12 @@ void GptDecoderBatch::newBatch(
}
newRequest(batchIdx, request, extractSamplingConfig(samplingConfig, batchIdx));
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void GptDecoderBatch::forwardAsync(decoder::Output& output, decoder::Input const& input)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const& logitsShape = input.logits->getShape();
auto const batchSize = logitsShape.d[0];
@ -886,32 +888,32 @@ void GptDecoderBatch::forwardAsync(decoder::Output& output, decoder::Input const
kernels::reduce(*mFinishedSum, *ITensor::slice(mJointDecodingOutput->finishedSum, 0, mActualBatchSize), *mStream);
mStream->record(mForwardEvent);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void GptDecoderBatch::forwardSync()
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
forwardSync(*mForwardToken);
// wait for mFinishedSum to be updated
mForwardEvent.synchronize();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void GptDecoderBatch::finalize() const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
for (SizeType batchIdx = 0; batchIdx < mActualBatchSize; ++batchIdx)
{
postProcessRequest(batchIdx);
auto event = postProcessRequest(batchIdx);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
CudaEvent GptDecoderBatch::finalize(SizeType batchIdx) const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto event = postProcessRequest(batchIdx);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return event;
}

View File

@ -16,6 +16,7 @@
#include "tensorrt_llm/runtime/gptJsonConfig.h"
#include "common.h"
#include "gptModelConfig.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
@ -24,6 +25,7 @@
#include <fstream>
#include <nlohmann/json.hpp>
#include <string_view>
#include <utility>
using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
@ -69,12 +71,122 @@ std::optional<FieldType> parseJsonFieldOptional(Json const& json, std::string_vi
return value;
}
GptModelConfig createModelConfig(
Json const& json, bool engineVersionNone, SizeType tensorParallelism, nvinfer1::DataType dataType)
{
auto const& config = engineVersionNone ? json.at("builder_config") : json.at("pretrained_config");
auto const* const numLayersField = engineVersionNone ? "num_layers" : "num_hidden_layers";
auto const* const numHeadsField = engineVersionNone ? "num_heads" : "num_attention_heads";
auto const* const numKvHeadsField = engineVersionNone ? "num_kv_heads" : "num_key_value_heads";
auto const* const mlpHiddenSizeField = engineVersionNone ? "mlp_hidden_size" : "intermediate_size";
auto const numLayers = config.at(numLayersField).template get<SizeType>();
auto const numHeads = config.at(numHeadsField).template get<SizeType>() / tensorParallelism;
auto const vocabSize = config.at("vocab_size").template get<SizeType>();
auto const hiddenSize = config.at("hidden_size").template get<SizeType>() / tensorParallelism;
auto const sizePerHead = parseJsonFieldOr(config, "head_size", hiddenSize / numHeads);
// TODO:
// The code crashes when numKvHeads <= 0. Clamping it to a minimum of 1 prevents that; make sure this is the best fix.
auto const numKvHeads
= std::max(parseJsonFieldOr(config, numKvHeadsField, numHeads * tensorParallelism) / tensorParallelism, 1);
auto const mlpHiddenSize = parseJsonFieldOptional<SizeType>(config, mlpHiddenSizeField);
auto modelConfig = GptModelConfig{vocabSize, numLayers, numHeads, hiddenSize, dataType};
modelConfig.setSizePerHead(sizePerHead);
modelConfig.setNbKvHeads(numKvHeads);
if (mlpHiddenSize.has_value())
{
modelConfig.setMlpHiddenSize(mlpHiddenSize.value() / tensorParallelism);
}
return modelConfig;
};
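createModelConfig leans on parseJsonFieldOr and parseJsonFieldOptional to read fields that may be absent. Their definitions are not part of this hunk; a minimal sketch of what such helpers could look like on top of nlohmann::json (names and exact signatures assumed):
#include <nlohmann/json.hpp>
#include <optional>

template <typename T>
T parseOr(nlohmann::json const& json, char const* field, T defaultValue)
{
    // Fall back to the provided default when the field is missing.
    return json.contains(field) ? json.at(field).get<T>() : defaultValue;
}

template <typename T>
std::optional<T> parseOptional(nlohmann::json const& json, char const* field)
{
    // Report absence explicitly instead of inventing a default.
    return json.contains(field) ? std::optional<T>{json.at(field).get<T>()} : std::nullopt;
}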
void parseBuilderConfig(GptModelConfig& modelConfig, Json const& builderConfig)
{
auto const maxBatchSize = parseJsonFieldOr(builderConfig, "max_batch_size", 0);
auto const maxBeamWidth = parseJsonFieldOr(builderConfig, "max_beam_width", 0);
auto const maxInputLen = parseJsonFieldOr(builderConfig, "max_input_len", 0);
auto const maxSequenceLen = maxInputLen + parseJsonFieldOr(builderConfig, "max_output_len", 0);
auto const maxDraftLen = parseJsonFieldOr(builderConfig, "max_draft_len", 0);
auto const maxNumTokens = parseJsonFieldOptional<SizeType>(builderConfig, "max_num_tokens");
auto const maxPromptEmbeddingTableSize
= parseJsonFieldOr<SizeType>(builderConfig, "max_prompt_embedding_table_size", 0);
auto const computeContextLogits = parseJsonFieldOr(builderConfig, "gather_context_logits", false);
auto const computeGenerationLogits = parseJsonFieldOr(builderConfig, "gather_generation_logits", false);
modelConfig.setMaxBatchSize(maxBatchSize);
modelConfig.setMaxBeamWidth(maxBeamWidth);
modelConfig.setMaxInputLen(maxInputLen);
modelConfig.setMaxSequenceLen(maxSequenceLen);
modelConfig.setMaxNumTokens(maxNumTokens);
modelConfig.setMaxDraftLen(maxDraftLen);
modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);
modelConfig.computeContextLogits(computeContextLogits);
modelConfig.computeGenerationLogits(computeGenerationLogits);
}
void parsePluginConfig(GptModelConfig& modelConfig, Json const& pluginConfig)
{
auto const useGptAttentionPlugin = !pluginConfig.at("gpt_attention_plugin").is_null();
auto const removeInputPadding = pluginConfig.at("remove_input_padding").template get<bool>();
auto const& pagedKvCache = pluginConfig.at("paged_kv_cache");
auto const& tokensPerBlock = pluginConfig.at("tokens_per_block");
auto const useCustomAllReduce = pluginConfig.at("use_custom_all_reduce").template get<bool>();
auto const useContextFMHAForGeneration = pluginConfig.at("use_context_fmha_for_generation").template get<bool>();
auto const pagedContextFMHA = pluginConfig.at("use_paged_context_fmha").template get<bool>();
modelConfig.useGptAttentionPlugin(useGptAttentionPlugin);
modelConfig.usePackedInput(removeInputPadding);
modelConfig.usePagedKvCache(pagedKvCache);
modelConfig.setTokensPerBlock(tokensPerBlock);
modelConfig.useCustomAllReduce(useCustomAllReduce);
modelConfig.setUseContextFMHAForGeneration(useContextFMHAForGeneration);
modelConfig.setPagedContextFMHA(pagedContextFMHA);
}
void parseLora(GptModelConfig& modelConfig, Json const& json, Json const& pluginConfig, bool engineVersionNone,
SizeType tensorParallelism)
{
auto const& config = engineVersionNone ? json.at("builder_config") : json.at("pretrained_config");
auto const loraMaxRank = parseJsonFieldOr(config, "max_lora_rank", SizeType{0});
auto const loraTargetModules = parseJsonFieldOptional<std::vector<std::string>>(config, "lora_target_modules");
if (loraTargetModules.has_value())
{
modelConfig.setLoraModules(LoraModule::createLoraModules(loraTargetModules.value(), modelConfig.getHiddenSize(),
modelConfig.getMlpHiddenSize(), modelConfig.getNbHeads(), modelConfig.getNbKvHeads(),
modelConfig.getSizePerHead(), tensorParallelism));
}
modelConfig.setMaxLoraRank(loraMaxRank);
auto useLoraPlugin = !pluginConfig.at("lora_plugin").is_null();
if (useLoraPlugin)
{
if (modelConfig.getLoraModules().empty() || modelConfig.getMaxLoraRank() == 0)
{
TLLM_LOG_WARNING("lora_plugin enabled, but no lora module enabled: setting useLoraPlugin to false");
useLoraPlugin = false;
}
}
modelConfig.useLoraPlugin(useLoraPlugin);
}
template <typename InputType>
GptJsonConfig parseJson(InputType&& i)
GptJsonConfig parseJson(InputType&& input)
{
auto constexpr allowExceptions = true;
auto constexpr ignoreComments = true;
auto const json = nlohmann::json::parse(i, nullptr, allowExceptions, ignoreComments);
auto const json = nlohmann::json::parse(std::forward<InputType>(input), nullptr, allowExceptions, ignoreComments);
auto const engineVersion = parseJsonFieldOr(json, "version", std::string("none"));
@ -104,113 +216,26 @@ GptJsonConfig parseJson(InputType&& i)
auto const precision = engineVersionNone ? builderConfig.at("precision").template get<std::string>()
: json.at("pretrained_config").at("dtype").template get<std::string>();
auto dataType = nvinfer1::DataType::kFLOAT;
if (!precision.compare("float32"))
dataType = nvinfer1::DataType::kFLOAT;
else if (!precision.compare("float16"))
dataType = nvinfer1::DataType::kHALF;
else if (!precision.compare("bfloat16"))
dataType = nvinfer1::DataType::kBF16;
else
TLLM_CHECK_WITH_INFO(false, tc::fmtstr("Model data type '%s' not supported", precision.c_str()));
auto modelConfig = [&engineVersionNone, &json, &builderConfig, &tensorParallelism, &dataType]()
auto const dataType = [&precision]()
{
auto const& config = engineVersionNone ? builderConfig : json.at("pretrained_config");
std::string const numLayersField = engineVersionNone ? "num_layers" : "num_hidden_layers";
std::string const numHeadsField = engineVersionNone ? "num_heads" : "num_attention_heads";
std::string const numKvHeadsField = engineVersionNone ? "num_kv_heads" : "num_key_value_heads";
std::string const mlpHiddenSizeField = engineVersionNone ? "mlp_hidden_size" : "intermediate_size";
auto const numLayers = config.at(numLayersField).template get<SizeType>();
auto const numHeads = config.at(numHeadsField).template get<SizeType>() / tensorParallelism;
auto const vocabSize = config.at("vocab_size").template get<SizeType>();
auto const hiddenSize = config.at("hidden_size").template get<SizeType>() / tensorParallelism;
auto const sizePerHead = parseJsonFieldOr(config, "head_size", hiddenSize / numHeads);
auto const loraMaxRank = parseJsonFieldOr(config, "max_lora_rank", SizeType{0});
auto const loraTargetModules = parseJsonFieldOptional<std::vector<std::string>>(config, "lora_target_modules");
// TODO:
// Code crashes when numKvHeads <= 0. Clamping downwards to 1 prevents that, make sure this is best fix.
auto const numKvHeads
= std::max(parseJsonFieldOr(config, numKvHeadsField, numHeads * tensorParallelism) / tensorParallelism, 1);
auto const mlpHiddenSize = parseJsonFieldOptional<SizeType>(config, mlpHiddenSizeField);
auto modelConfig = GptModelConfig{vocabSize, numLayers, numHeads, hiddenSize, dataType};
modelConfig.setSizePerHead(sizePerHead);
modelConfig.setNbKvHeads(numKvHeads);
if (mlpHiddenSize.has_value())
{
modelConfig.setMlpHiddenSize(mlpHiddenSize.value() / tensorParallelism);
}
if (loraTargetModules.has_value())
{
modelConfig.setLoraModules(LoraModule::createLoraModules(loraTargetModules.value(),
modelConfig.getHiddenSize(), modelConfig.getMlpHiddenSize(), modelConfig.getNbHeads(),
modelConfig.getNbKvHeads(), modelConfig.getSizePerHead(), tensorParallelism));
}
modelConfig.setMaxLoraRank(loraMaxRank);
return modelConfig;
if (!precision.compare("float32"))
return nvinfer1::DataType::kFLOAT;
else if (!precision.compare("float16"))
return nvinfer1::DataType::kHALF;
else if (!precision.compare("bfloat16"))
return nvinfer1::DataType::kBF16;
else
TLLM_THROW("Model data type '%s' not supported", precision.c_str());
}();
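The refactor above initializes the const dataType through an immediately-invoked lambda instead of a mutable variable plus if/else chain. The same idiom in isolation, with illustrative types (Precision and parsePrecision are not TensorRT-LLM names):
#include <stdexcept>
#include <string>

enum class Precision { kFloat32, kFloat16, kBFloat16 };

Precision parsePrecision(std::string const& name)
{
    // The lambda runs once and its result initializes a const value, so every
    // branch must either return or throw.
    auto const precision = [&name]()
    {
        if (name == "float32")
            return Precision::kFloat32;
        if (name == "float16")
            return Precision::kFloat16;
        if (name == "bfloat16")
            return Precision::kBFloat16;
        throw std::invalid_argument("precision '" + name + "' not supported");
    }();
    return precision;
}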
auto const maxBatchSize = parseJsonFieldOr(builderConfig, "max_batch_size", 0);
auto const maxBeamWidth = parseJsonFieldOr(builderConfig, "max_beam_width", 0);
auto const maxInputLen = parseJsonFieldOr(builderConfig, "max_input_len", 0);
auto const maxSequenceLen = maxInputLen + parseJsonFieldOr(builderConfig, "max_output_len", 0);
auto const maxDraftLen = parseJsonFieldOr(builderConfig, "max_draft_len", 0);
auto const maxNumTokens = parseJsonFieldOptional<SizeType>(builderConfig, "max_num_tokens");
auto const maxPromptEmbeddingTableSize
= parseJsonFieldOr<SizeType>(builderConfig, "max_prompt_embedding_table_size", 0);
auto const computeContextLogits = parseJsonFieldOr(builderConfig, "gather_context_logits", false);
auto const computeGenerationLogits = parseJsonFieldOr(builderConfig, "gather_generation_logits", false);
auto modelConfig = createModelConfig(json, engineVersionNone, tensorParallelism, dataType);
modelConfig.setMaxBatchSize(maxBatchSize);
modelConfig.setMaxBeamWidth(maxBeamWidth);
modelConfig.setMaxInputLen(maxInputLen);
modelConfig.setMaxSequenceLen(maxSequenceLen);
modelConfig.setMaxNumTokens(maxNumTokens);
modelConfig.setMaxDraftLen(maxDraftLen);
modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);
modelConfig.computeContextLogits(computeContextLogits);
modelConfig.computeGenerationLogits(computeGenerationLogits);
parseBuilderConfig(modelConfig, builderConfig);
auto const& pluginConfig = engineVersionNone ? json.at("plugin_config") : builderConfig.at("plugin_config");
parsePluginConfig(modelConfig, pluginConfig);
auto const useGptAttentionPlugin = !pluginConfig.at("gpt_attention_plugin").is_null();
auto const removeInputPadding = pluginConfig.at("remove_input_padding").template get<bool>();
auto const pagedKvCache = pluginConfig.at("paged_kv_cache");
auto const tokensPerBlock = pluginConfig.at("tokens_per_block");
auto const useCustomAllReduce = pluginConfig.at("use_custom_all_reduce").template get<bool>();
auto const useContextFMHAForGeneration = pluginConfig.at("use_context_fmha_for_generation").template get<bool>();
auto const pagedContextFMHA = pluginConfig.at("use_paged_context_fmha").template get<bool>();
modelConfig.useGptAttentionPlugin(useGptAttentionPlugin);
modelConfig.usePackedInput(removeInputPadding);
modelConfig.usePagedKvCache(pagedKvCache);
modelConfig.setTokensPerBlock(tokensPerBlock);
modelConfig.useCustomAllReduce(useCustomAllReduce);
modelConfig.setUseContextFMHAForGeneration(useContextFMHAForGeneration);
modelConfig.setPagedContextFMHA(pagedContextFMHA);
auto useLoraPlugin = !pluginConfig.at("lora_plugin").is_null();
if (useLoraPlugin)
{
if (modelConfig.getLoraModules().empty() || modelConfig.getMaxLoraRank() == 0)
{
TLLM_LOG_WARNING("lora_plugin enabled, but no lora module enabled: setting useLoraPlugin to false");
useLoraPlugin = false;
}
}
modelConfig.useLoraPlugin(useLoraPlugin);
parseLora(modelConfig, json, pluginConfig, engineVersionNone, tensorParallelism);
if (engineVersionNone)
{

View File

@ -23,7 +23,6 @@
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/customAllReduceUtils.h"
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/runtime/gptDecoderBatch.h"
#include "tensorrt_llm/runtime/ipcUtils.h"
#include "tensorrt_llm/runtime/ncclCommunicator.h"
@ -189,29 +188,20 @@ void GptSession::createKvCacheManager(SizeType batchSize, SizeType beamWidth, Si
kvDtype = mModelConfig.getDataType();
}
auto const maxNumTokens
= bmkv::KVCacheManager::getMaxNumTokens(kvCacheConfig, kvDtype, mModelConfig, mWorldConfig, getBufferManager());
TLLM_LOG_INFO("Using %d tokens in paged KV cache.", maxNumTokens);
auto const maxNumBlocks = tc::ceilDiv(maxNumTokens, tokensPerBlock);
auto const maxNumBlocks = bmkv::KVCacheManager::calculateMaxNumBlocks(
kvCacheConfig, kvDtype, mModelConfig, mWorldConfig, getBufferManager());
SizeType sinkTokensInLastBlock = sinkTokenLength % tokensPerBlock;
SizeType bubbleLen = sinkTokensInLastBlock != 0 ? tokensPerBlock - sinkTokensInLastBlock : 0;
auto maxBlocksPerSeq = tc::ceilDiv(maxAttentionWindow + bubbleLen, tokensPerBlock);
// If beamWidth > 1, use one more block for each sequence in the paged KV cache to avoid dropping needed
// tokens when the cyclic KV cache is enabled.
auto const useOneMoreBlock = beamWidth > 1 && maxSequenceLength > maxAttentionWindow;
if (useOneMoreBlock)
{
maxBlocksPerSeq += 1;
}
auto const localNbLayers = mModelConfig.getNbLayers(mWorldConfig.getPipelineParallelism());
auto const nbKvHeads = mModelConfig.getNbKvHeads();
auto const sizePerHead = mModelConfig.getSizePerHead();
bool enableBlockReuse{false};
bool constexpr enableBlockReuse{false};
mKvCacheManager = std::make_shared<bmkv::KVCacheManager>(localNbLayers, nbKvHeads, sizePerHead, tokensPerBlock,
maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, maxAttentionWindow, sinkTokenLength, useOneMoreBlock,
kvDtype, mRuntime->getStreamPtr(), enableBlockReuse, kvCacheConfig.useUvm);
maxNumBlocks, batchSize, beamWidth, maxAttentionWindow, sinkTokenLength, useOneMoreBlock, kvDtype,
mRuntime->getStreamPtr(), enableBlockReuse, kvCacheConfig.useUvm);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
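The block-count arithmetic being removed here (and now handled inside the KV cache manager) is worth spelling out. A sketch that reproduces it, with illustrative helper names; the authoritative computation lives in KVCacheManager::calculateMaxNumBlocks and getMaxBlocksPerSeq:
#include <cstdint>

using SizeType = std::int32_t;

constexpr SizeType ceilDiv(SizeType a, SizeType b)
{
    return (a + b - 1) / b;
}

// Sink tokens that do not fill a whole block leave a "bubble", and beam search
// with a cyclic KV cache reserves one extra block per sequence.
constexpr SizeType maxBlocksPerSeq(SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType tokensPerBlock,
    SizeType beamWidth, SizeType maxSequenceLength)
{
    SizeType const sinkTokensInLastBlock = sinkTokenLength % tokensPerBlock;
    SizeType const bubbleLen = sinkTokensInLastBlock != 0 ? tokensPerBlock - sinkTokensInLastBlock : 0;
    SizeType blocks = ceilDiv(maxAttentionWindow + bubbleLen, tokensPerBlock);
    if (beamWidth > 1 && maxSequenceLength > maxAttentionWindow)
    {
        blocks += 1;
    }
    return blocks;
}

// Example: a 4096-token attention window, no sink tokens, 64 tokens per block
// and beam width 1 needs 4096 / 64 = 64 blocks per sequence.
static_assert(maxBlocksPerSeq(4096, 0, 64, 1, 4096) == 64);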
@ -335,12 +325,14 @@ void GptSession::setup(Config const& sessionConfig)
createCustomAllReduceWorkspace(mMicroBatchConfig.genBatchSize, maxBeamWidth, maxSequenceLength);
}
auto* kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManager.get() : nullptr;
for (auto& buffers : mBuffers)
{
// we don't know maxInputLength yet and ignore it for pre-allocation
buffers->generationConfig = RuntimeBuffers::GenerationConfig{
mMicroBatchConfig.genBatchSize, maxBeamWidth, 0, maxAttentionWindow, sinkTokenLength, maxSequenceLength};
buffers->reshape(mModelConfig, mWorldConfig);
buffers->reshape(kvCacheManager, mModelConfig, mWorldConfig);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -684,6 +676,8 @@ void GptSession::generateBatched(std::vector<GenerationOutput>& microBatchesOutp
TLLM_CHECK(numMicroBatches <= mMicroBatchConfig.numGenBatches);
SizeType const beamWidth{samplingConfig.beamWidth};
auto* kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManager.get() : nullptr;
// Initialize and reshape buffers
for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
{
@ -691,7 +685,7 @@ void GptSession::generateBatched(std::vector<GenerationOutput>& microBatchesOutp
auto& buffers = *mBuffers.at(microBatchId);
buffers.initFromInput(*microBatchInputs.ids, microBatchInputs.lengths, microBatchInputs.packed, beamWidth,
mDecoderMaxAttentionWindow, mDecoderSinkTokenLength, mDecoderMaxSequenceLength, manager);
buffers.reshape(mModelConfig, mWorldConfig);
buffers.reshape(kvCacheManager, mModelConfig, mWorldConfig);
buffers.reset(manager);
}
@ -746,8 +740,6 @@ void GptSession::generateBatched(std::vector<GenerationOutput>& microBatchesOutp
}
}
auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManager.get() : nullptr;
auto const profileContext = !kProfileMbIdxs.empty() && kProfileMbIdxs.count(0) > 0;
if (profileContext)
cudaProfilerStart();

View File

@ -49,7 +49,7 @@ LoraManager::LoraReqTensors& LoraManager::getTask(TaskIdType reqId)
void LoraManager::create(
GptModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& manager)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto modules = modelConfig.getLoraModules();
SizeType modOff = 0;
@ -62,14 +62,14 @@ void LoraManager::create(
// TODO set this size from max adapter size
mWorkspace = manager.emptyTensor(MemoryType::kGPU, modelConfig.getDataType());
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, ReqIdsVec const& reqIds,
std::vector<SizeType> const& reqBeamWidth, std::vector<bool> const& loraEnabled, SizeType numContextRequests,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
auto firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;
auto lastLayerId = firstLayerId + localNbLayers;
@ -85,13 +85,13 @@ void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes
fillInputTensors(
weightsPtrs, adapterSizes, bid, reqIds[bid], reqBeamWidth[bid], firstLayerId, lastLayerId, tpSize, tpRank);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, SizeType batchIdx, TaskIdType taskId,
SizeType beamWidth, SizeType firstLayerId, SizeType lastLayerId, SizeType tpSize, SizeType tpRank)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto weightsPointersPtr = bufferCast<int64_t>(*weightsPtrs);
auto adapterSizesPtr = bufferCast<int32_t>(*adapterSizes);
@ -172,13 +172,13 @@ void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes
}
std::fill_n(writeAdapterSizesPtr, beamWidth, adapterSize);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void LoraManager::insertInputTensors(TensorMap& inputTensors, TensorPtr weightsPtrs, TensorPtr adapterSizes,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig) const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
auto firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;
@ -210,13 +210,13 @@ void LoraManager::insertInputTensors(TensorMap& inputTensors, TensorPtr weightsP
}
}
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void LoraManager::formatTaskTensors(LoraWeightsTensorPtr weights, LoraConfigTensorPtr config,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& manager)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
weights->squeeze(0);
config->squeeze(0);
@ -261,7 +261,7 @@ void LoraManager::formatTaskTensors(LoraWeightsTensorPtr weights, LoraConfigTens
manager.copy(*mWorkspace, *weightsOut);
}
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void LoraManager::reset()

View File

@ -38,13 +38,18 @@ std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> c
{
case ModuleType::kATTN_QKV:
case ModuleType::kCROSS_ATTN_QKV:
modules.emplace_back(
LoraModule(t, hidden, (numHeads * attnHeadSize + 2 * numKvHeads * attnHeadSize), false, true, -1, 0));
break;
case ModuleType::kATTN_Q:
case ModuleType::kATTN_K:
case ModuleType::kATTN_V: modules.emplace_back(t, hidden, hidden, false, true, -1, 0); break;
case ModuleType::kATTN_DENSE: modules.emplace_back(t, hidden, hidden, false, true, 1, -1); break;
case ModuleType::kATTN_V:
case ModuleType::kCROSS_ATTN_Q:
case ModuleType::kCROSS_ATTN_K:
case ModuleType::kCROSS_ATTN_V: modules.emplace_back(t, hidden, hidden, false, true, -1, 0); break;
case ModuleType::kATTN_DENSE:
case ModuleType::kCROSS_ATTN_DENSE: modules.emplace_back(t, hidden, hidden, false, true, 1, -1); break;
case ModuleType::kMLP_H_TO_4H: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break;
case ModuleType::kMLP_GATE: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break;
case ModuleType::kMLP_4H_TO_H: modules.emplace_back(t, mlpHiddenSize, hidden, false, true, 1, -1); break;
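The kATTN_QKV/kCROSS_ATTN_QKV case above sizes the fused QKV output as numHeads * headSize + 2 * numKvHeads * headSize: Q uses the full head count, while K and V each use the (possibly smaller) KV head count under grouped-query attention. A tiny worked sketch of that arithmetic (qkvOutDim is an illustrative name):
#include <cstdint>

using SizeType = std::int32_t;

constexpr SizeType qkvOutDim(SizeType numHeads, SizeType numKvHeads, SizeType headSize)
{
    return numHeads * headSize + 2 * numKvHeads * headSize;
}

// Example: 32 query heads, 8 KV heads, head size 128 -> 32*128 + 2*8*128 = 6144.
static_assert(qkvOutDim(32, 8, 128) == 6144);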

View File

@ -39,6 +39,11 @@ public:
kMLP_H_TO_4H = 5,
kMLP_4H_TO_H = 6,
kMLP_GATE = 7,
kCROSS_ATTN_QKV = 8,
kCROSS_ATTN_Q = 9,
kCROSS_ATTN_K = 10,
kCROSS_ATTN_V = 11,
kCROSS_ATTN_DENSE = 12,
};
explicit constexpr LoraModule(ModuleType const& t, SizeType inDim, SizeType outDim, bool inDimFirst,
@ -128,6 +133,16 @@ public:
return ModuleType::kMLP_4H_TO_H;
else if (name == "mlp_gate")
return ModuleType::kMLP_GATE;
else if (name == "cross_attn_qkv")
return ModuleType::kCROSS_ATTN_QKV;
else if (name == "cross_attn_q")
return ModuleType::kCROSS_ATTN_Q;
else if (name == "cross_attn_k")
return ModuleType::kCROSS_ATTN_K;
else if (name == "cross_attn_v")
return ModuleType::kCROSS_ATTN_V;
else if (name == "cross_attn_dense")
return ModuleType::kCROSS_ATTN_DENSE;
else
return ModuleType::kINVALID;
}
@ -144,6 +159,11 @@ public:
case ModuleType::kMLP_H_TO_4H: return "mlp_h_to_4h";
case ModuleType::kMLP_4H_TO_H: return "mlp_4h_to_h";
case ModuleType::kMLP_GATE: return "mlp_gate";
case ModuleType::kCROSS_ATTN_QKV: return "cross_attn_qkv";
case ModuleType::kCROSS_ATTN_Q: return "cross_attn_q";
case ModuleType::kCROSS_ATTN_K: return "cross_attn_k";
case ModuleType::kCROSS_ATTN_V: return "cross_attn_v";
case ModuleType::kCROSS_ATTN_DENSE: return "cross_attn_dense";
case ModuleType::kINVALID: return "INVALID";
}
return "INVALID";

View File

@ -82,4 +82,10 @@ void MemoryCounters::deallocate(MemoryType memoryType, MemoryCounters::SizeType
default: TLLM_THROW("Unknown memory type");
}
}
MemoryCounters& MemoryCounters::getInstance()
{
static MemoryCounters mInstance;
return mInstance;
}
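getInstance above is a function-local static ("Meyers") singleton; since C++11 its initialization is guaranteed to be thread-safe. The same idiom in isolation, with a hypothetical example class:
class Registry
{
public:
    static Registry& getInstance()
    {
        static Registry instance; // constructed once, on first use
        return instance;
    }

private:
    Registry() = default;
};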
} // namespace tensorrt_llm::runtime

View File

@ -17,13 +17,13 @@
#include "tensorrt_llm/runtime/runtimeBuffers.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/stlUtils.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include "tensorrt_llm/runtime/tllmRuntime.h"
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
#include <algorithm>
#include <iostream>
using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
@ -32,7 +32,7 @@ RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITe
ITensor const& inputLengthsHost, bool const inputPacked, SizeType const beamWidth,
SizeType const maxAttentionWindow, SizeType const sinkTokenLength, SizeType const maxSequenceLength)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = static_cast<SizeType>(inputLengthsHost.getSize());
auto const* inputLengthsPtr = bufferCast<SizeType>(inputLengthsHost);
@ -71,14 +71,14 @@ RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITe
"Max input length is equal to or larger that maxSequenceLength given in setup. No new tokens can be "
"generated.");
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return GenerationConfig{
batchSize, beamWidth, maxInputLength, maxAttentionWindow, sinkTokenLength, maxSequenceLength, inputLengthSum};
}
void RuntimeBuffers::clear()
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
contextLengthsHost = nullptr;
contextLengthsDevice = nullptr;
@ -104,22 +104,22 @@ void RuntimeBuffers::clear()
hiddenStates = nullptr;
allocated = false;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void RuntimeBuffers::clearTensorMaps()
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
for (auto& buffer : inputBuffers)
buffer.clear();
for (auto& buffer : outputBuffers)
buffer.clear();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& manager = runtime.getBufferManager();
auto& engine = runtime.getEngine();
@ -170,8 +170,7 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
if (modelConfig.usePagedKvCache())
{
auto const kvCacheBlockPointersType
= engine.getTensorDataType(("kv_cache_block_pointers_" + std::to_string(firstLayerId)).c_str());
auto const kvCacheBlockPointersType = engine.getTensorDataType("kv_cache_block_pointers");
kvCacheBlockPointersHost = manager.emptyTensor(MemoryType::kCPU, kvCacheBlockPointersType);
kvCacheBlockPointersDevice = manager.emptyTensor(MemoryType::kGPU, kvCacheBlockPointersType);
}
@ -183,8 +182,7 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
if (modelConfig.useGptAttentionPlugin())
{
pastKeyValueLengths = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
maxAttentionWindows
= utils::createBufferVector(runtime, localNbLayers, MemoryType::kCPU, nvinfer1::DataType::kINT32);
maxAttentionWindows = BufferManager::cpu(ITensor::makeShape({localNbLayers}), nvinfer1::DataType::kINT32);
sinkTokenLengths = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
}
else
@ -207,7 +205,7 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
hiddenStates = manager.emptyTensor(MemoryType::kGPU, modelConfig.getDataType());
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inputLengths, bool inputPacked,
@ -223,15 +221,15 @@ void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inp
inputIds, *contextLengthsHost, inputPacked, beamWidth, maxAttentionWindow, sinkTokenLength, maxSequenceLength);
}
void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
void RuntimeBuffers::reshape(
KvCacheManager const* kvCacheManager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = generationConfig.batchSize;
auto const beamWidth = generationConfig.beamWidth;
auto const maxInputLength = generationConfig.maxInputLength;
auto const maxAttentionWindow = generationConfig.maxAttentionWindow;
auto const sinkTokenLen = generationConfig.sinkTokenLength;
auto const maxSeqLength = generationConfig.maxSeqLength;
auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
@ -259,13 +257,12 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
if (modelConfig.computeGenerationLogits())
{
allGenerationLogits->reshape(
ITensor::makeShape({(generationConfig.maxSeqLength - generationConfig.maxInputLength), batchSize,
beamWidth, vocabSizePadded}));
ITensor::makeShape({(maxSeqLength - maxInputLength), batchSize, beamWidth, vocabSizePadded}));
cacheGenerationFragmentPointerDevice->reshape(
ITensor::makeShape({batchSize, (generationConfig.maxSeqLength - generationConfig.maxInputLength)}));
ITensor::makeShape({batchSize, (maxSeqLength - maxInputLength)}));
cacheGenerationFragmentPointerHost->reshape(
ITensor::makeShape({batchSize, (generationConfig.maxSeqLength - generationConfig.maxInputLength)}));
ITensor::makeShape({batchSize, (maxSeqLength - maxInputLength)}));
}
}
@ -277,18 +274,11 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
= ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(), maxInputLength, modelConfig.getSizePerHead()});
if (modelConfig.usePagedKvCache())
{
auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
auto const tokensPerBlock = modelConfig.getTokensPerBlock();
SizeType bubbleLen
= (sinkTokenLen % tokensPerBlock == 0) ? 0 : tokensPerBlock - (sinkTokenLen % tokensPerBlock);
auto maxBlocksPerSeq = tc::ceilDiv(maxAttentionWindow + bubbleLen, tokensPerBlock);
// If beamWidth > 1, use one more block for each sequence in the paged kv cache to avoid dropping the needed
// tokens, when enabling cyclic kv cache.
if (beamWidth > 1 && maxSeqLength > maxAttentionWindow)
{
maxBlocksPerSeq += 1;
}
TLLM_CHECK(kvCacheManager);
auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
auto const maxBlocksPerSeq = kvCacheManager->getMaxBlocksPerSeq();
// reserve batchSize * beamWidth and resize to batchSize
auto cacheBlockPointersShape = ITensor::makeShape({localNbLayers, batchSize * beamWidth, 2, maxBlocksPerSeq});
kvCacheBlockPointersHost->reshape(cacheBlockPointersShape);
@ -302,11 +292,13 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
utils::reshapeBufferVector(presentKeysVals, kvCacheReserve);
}
auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
if (modelConfig.useGptAttentionPlugin())
{
pastKeyValueLengths->reshape(ITensor::makeShape({batchSize}));
requestTypes->reshape(ITensor::makeShape({batchSize}));
utils::reshapeBufferVector(maxAttentionWindows, ITensor::makeShape({1}));
maxAttentionWindows->reshape(ITensor::makeShape({localNbLayers}));
sinkTokenLengths->reshape(ITensor::makeShape({1}));
}
else
@ -332,7 +324,7 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
}
allocated = true;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void RuntimeBuffers::reset(BufferManager& manager)
@ -345,7 +337,7 @@ void RuntimeBuffers::reset(BufferManager& manager)
std::vector<RuntimeBuffers> RuntimeBuffers::split(
SizeType contextBatchSize, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
std::vector<RuntimeBuffers> bufferSlices;
auto const generationBatchSize = generationConfig.batchSize;
@ -432,14 +424,14 @@ std::vector<RuntimeBuffers> RuntimeBuffers::split(
}
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return bufferSlices;
}
void RuntimeBuffers::gatherLastTokenLogits(
BufferManager& manager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK_WITH_INFO(modelConfig.computeContextLogits(),
"Gather last token logits is only necessary when context logits are computed");
@ -459,12 +451,12 @@ void RuntimeBuffers::gatherLastTokenLogits(
}
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void RuntimeBuffers::copyAttentionMasks(std::vector<RuntimeBuffers> const& contextBatches, BufferManager& manager)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = generationConfig.batchSize;
auto const maxInputLength = generationConfig.maxInputLength;
@ -481,12 +473,12 @@ void RuntimeBuffers::copyAttentionMasks(std::vector<RuntimeBuffers> const& conte
manager.copy(*buffers.attentionMask, *attentionMaskSlice);
offset += contextBatchSize;
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void RuntimeBuffers::tile(BufferManager& manager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const beamWidth = generationConfig.beamWidth;
TLLM_CHECK_WITH_INFO(beamWidth > 1, "Tiling is only necessary for beam search.");
@ -519,13 +511,13 @@ void RuntimeBuffers::tile(BufferManager& manager, GptModelConfig const& modelCon
for (auto& buffer : presentKeysValsAlt)
utils::tileBufferReplace(buffer, beamWidth, manager);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void RuntimeBuffers::postContextStep(std::vector<RuntimeBuffers> const& contextBuffers, BufferManager& manager,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = generationConfig.batchSize;
auto const beamWidth = generationConfig.beamWidth;
@ -580,14 +572,14 @@ void RuntimeBuffers::postContextStep(std::vector<RuntimeBuffers> const& contextB
manager, modelConfig.usePackedInput());
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType const padId, BufferManager& manager,
KvCacheManager const* kvCacheManager, SizeType firstBatchSlotIdx, GptModelConfig const& modelConfig,
WorldConfig const& worldConfig)
batch_manager::kv_cache_manager::KVCacheManager const* kvCacheManager, SizeType firstBatchSlotIdx,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& stream = manager.getStream();
SizeType const batchSize = generationConfig.batchSize;
SizeType const maxInputLength = generationConfig.maxInputLength;
@ -607,11 +599,9 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
TLLM_CHECK(requestTypes->getSize() == static_cast<std::size_t>(batchSize));
std::fill_n(RequestTypesPtr, batchSize, 0);
// Currently, the maxAttentionWindows buffer and sinkTokenLengths are each filled with a single uniform value.
for (auto layer = 0; layer < localNbLayers; ++layer)
{
bufferCast<SizeType>(*maxAttentionWindows[layer])[0] = generationConfig.maxAttentionWindow;
}
auto maxAttentionWindowsPtr = bufferCast<SizeType>(*maxAttentionWindows);
std::fill_n(maxAttentionWindowsPtr, localNbLayers, generationConfig.maxAttentionWindow);
bufferCast<SizeType>(*sinkTokenLengths)[0] = generationConfig.sinkTokenLength;
auto const& inputShape = inputIds->getShape();
@ -720,14 +710,14 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
manager.copy(*contextLengthsDevice, *lastTokenIds);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, BufferManager& manager,
KvCacheManager* kvCacheManager, SizeType firstBatchSlotIdx, GptModelConfig const& modelConfig,
WorldConfig const& worldConfig)
batch_manager::kv_cache_manager::KVCacheManager* kvCacheManager, SizeType firstBatchSlotIdx,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& stream = manager.getStream();
SizeType const batchSize = generationConfig.batchSize;
SizeType const beamWidth = generationConfig.beamWidth;
@ -842,7 +832,7 @@ RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, B
{
kernels::invokeInclusiveSum(*lastTokenIds, *lastTokenIds, manager, stream);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return nextInputIds;
}
@ -850,7 +840,7 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
TensorPtr const& inputIds, TensorPtr const& commPtrs, GptModelConfig const& modelConfig,
WorldConfig const& worldConfig) const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
inputBuffers.clear();
outputBuffers.clear();
@ -890,7 +880,7 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
inputBuffers.insert_or_assign("host_request_types", requestTypes);
inputBuffers.insert_or_assign("sequence_length", sequenceLengths);
inputBuffers.insert_or_assign("host_sink_token_length", sinkTokenLengths);
utils::insertTensorVector(inputBuffers, "host_max_attention_window_size_", maxAttentionWindows, firstLayerId);
inputBuffers.insert_or_assign("host_max_attention_window_sizes", maxAttentionWindows);
if (modelConfig.usePackedInput())
{
@ -898,10 +888,8 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
}
if (modelConfig.usePagedKvCache())
{
utils::insertTensorSlices(
inputBuffers, "kv_cache_block_pointers_", kvCacheBlockPointersDevice, firstLayerId);
utils::insertTensorSlices(
inputBuffers, "host_kv_cache_block_pointers_", kvCacheBlockPointersHost, firstLayerId);
inputBuffers.insert_or_assign("kv_cache_block_pointers", kvCacheBlockPointersDevice);
inputBuffers.insert_or_assign("host_kv_cache_block_pointers", kvCacheBlockPointersHost);
}
else
{
@ -946,7 +934,10 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
inputBuffers.insert_or_assign("tasks", promptTuningParams.tasks);
inputBuffers.insert_or_assign("prompt_vocab_size", promptTuningParams.vocabSize);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
// utils::printTensorMap(std::cerr, inputBuffers);
// utils::printTensorMap(std::cerr, outputBuffers);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
std::vector<SizeType> RuntimeBuffers::getPositionIdsContextPhaseGlm(const SizeType& batchSize,

View File

@ -99,11 +99,11 @@ public:
// still point to `allGenerationLogits` and cause an overwrite conflict.
std::vector<TensorPtr> presentKeysVals;
std::vector<TensorPtr> presentKeysValsAlt; // without attention plugin
std::vector<TensorPtr> maxAttentionWindows; // with attention plugin, host tensor
TensorPtr sinkTokenLengths; // with attention plugin, host tensor
TensorPtr kvCacheBlockPointersHost; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
TensorPtr kvCacheBlockPointersDevice; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
std::vector<TensorPtr> presentKeysValsAlt; // without attention plugin
TensorPtr maxAttentionWindows; // with attention plugin, host tensor
TensorPtr sinkTokenLengths; // with attention plugin, host tensor
TensorPtr kvCacheBlockPointersHost; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
TensorPtr kvCacheBlockPointersDevice; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
// References to tmp buffers
TensorPtr newTokens;
@ -147,7 +147,8 @@ public:
SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, BufferManager& manager);
//! \brief Reshape buffers based on current GenerationConfig
void reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
void reshape(
KvCacheManager const* kvCacheManager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
void reset(BufferManager& manager);

View File

@ -417,6 +417,21 @@ void invokeInclusiveSum(IBuffer& output, IBuffer const& input, BufferManager con
cub::DeviceScan::InclusiveSum(tempStorageData, tempStorageBytes, inputData, outputData, size, stream.get());
}
void invokeInclusiveSum(IBuffer& output, IBuffer& tmpBuffer, IBuffer const& input, CudaStream const& stream)
{
TLLM_CHECK_WITH_INFO(nvinfer1::DataType::kUINT8 == tmpBuffer.getDataType(), "tmpBuffer has wrong data type");
auto const size = input.getSize();
auto const* inputData = bufferCast<SizeType>(input);
auto* outputData = bufferCast<SizeType>(output);
std::size_t tempStorageBytes{0};
cub::DeviceScan::InclusiveSum(nullptr, tempStorageBytes, inputData, outputData, size, stream.get());
tmpBuffer.resize(tempStorageBytes);
auto* tmpBufferPtr = bufferCast<std::uint8_t>(tmpBuffer);
cub::DeviceScan::InclusiveSum(tmpBufferPtr, tempStorageBytes, inputData, outputData, size, stream.get());
}
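The new overload follows CUB's standard two-phase pattern: the first call with a null temporary buffer only reports the required scratch size, and the second call performs the scan. A minimal standalone sketch of that pattern (hypothetical names, not part of this diff; plain cudaMallocAsync is used here instead of the IBuffer-based resize above):

#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstddef>
#include <cstdint>

void inclusiveSumSketch(std::int32_t const* dIn, std::int32_t* dOut, int numItems, cudaStream_t stream)
{
    // Phase 1: query the required temporary storage size (no work is done yet).
    void* dTemp = nullptr;
    std::size_t tempBytes = 0;
    cub::DeviceScan::InclusiveSum(dTemp, tempBytes, dIn, dOut, numItems, stream);

    // Phase 2: provide the scratch buffer and run the actual scan.
    cudaMallocAsync(&dTemp, tempBytes, stream);
    cub::DeviceScan::InclusiveSum(dTemp, tempBytes, dIn, dOut, numItems, stream);
    cudaFreeAsync(dTemp, stream);
}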
namespace
{
__global__ void buildTokenMask(SizeType* tokenMask, SizeType const* inputLengths, SizeType const batchSize,
@ -752,7 +767,7 @@ void initOutputIds(ITensor& outputIds, ITensor const& inputIds, ITensor const& i
ITensor const& inputOffsets, TokenIdType const padId, TokenIdType const endId, SizeType const maxInputLength,
bool const inputPacked, CudaStream const& stream)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
kernels::invokeFill(outputIds, endId, stream);
if (inputPacked)
@ -763,7 +778,7 @@ void initOutputIds(ITensor& outputIds, ITensor const& inputIds, ITensor const& i
{
kernels::invokeCopyInputToOutput(outputIds, inputIds, inputLengths, padId, stream);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
namespace

View File

@ -51,6 +51,8 @@ void invokeTransposeWithInputOffset(
void invokeInclusiveSum(IBuffer& output, IBuffer const& input, BufferManager const& manager, CudaStream const& stream);
void invokeInclusiveSum(IBuffer& output, IBuffer& tmpBuffer, IBuffer const& input, CudaStream const& stream);
void invokeBuildTokenMask(
ITensor& tokenMask, ITensor const& inputLengths, SizeType maxInputLength, CudaStream const& stream);

View File

@ -16,13 +16,12 @@
#include "tensorrt_llm/runtime/statefulGptDecoder.h"
#include <algorithm>
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include <algorithm>
namespace tc = tensorrt_llm::common;
namespace tk = tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;
@ -35,7 +34,7 @@ StatefulGptDecoder::StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabS
, mStream{std::move(stream)}
, mBufferManager{mStream}
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
auto constexpr nvSizeType = TRTDataType<SizeType>::value;
auto constexpr nvFloatType = TRTDataType<float>::value;
@ -68,26 +67,26 @@ StatefulGptDecoder::StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabS
mFinishedSum = mBufferManager.pinned(ITensor::makeShape({1}), nvSizeType);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void StatefulGptDecoder::setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth,
SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep,
bool fusedDecoder, nvinfer1::DataType dtype)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(maxTokensPerStep == 1);
mDecoder = IGptDecoder::create(
mode, dtype, maxBatchSize, maxBeamWidth, mVocabSize, mVocabSizePadded, maxSequenceLength, mStream);
reshapeBuffers(maxBatchSize, maxBeamWidth, maxAttentionWindow, sinkTokenLength, maxSequenceLength);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void StatefulGptDecoder::reshapeBuffers(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow,
SizeType sinkTokenLength, SizeType maxSequenceLength)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(batchSize > 0);
TLLM_CHECK(beamWidth > 0);
TLLM_CHECK(maxSequenceLength > 0);
@ -139,13 +138,13 @@ void StatefulGptDecoder::reshapeBuffers(SizeType batchSize, SizeType beamWidth,
}
mNbSteps = 0;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void StatefulGptDecoder::newBatch(
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& manager = mBufferManager;
auto& stream = mStream;
@ -296,12 +295,12 @@ void StatefulGptDecoder::newBatch(
// remaining
mNbSteps = 0;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void StatefulGptDecoder::forwardAsync(decoder::Output& output, decoder::Input const& input)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& logits = input.logits;
auto const& logitsShape = logits->getShape();
@ -335,24 +334,24 @@ void StatefulGptDecoder::forwardAsync(decoder::Output& output, decoder::Input co
dInput.step += 1;
mNbSteps += 1;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void StatefulGptDecoder::forwardSync()
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
mDecodedEvent.synchronize();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void StatefulGptDecoder::finalize() const
{
// TODO (rkobus) can we do this inplace?
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& outputIds = mDecodingOutput->ids;
auto finalOutputIds = mBufferManager.gpu(outputIds->getShape(), outputIds->getDataType());
mDecoder->gatherTree(*finalOutputIds, *mDecodingOutput, *mDecodingInput, mBufferManager);
mBufferManager.copy(*finalOutputIds, *outputIds);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return;
}

View File

@ -18,17 +18,12 @@
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/gptDecoder.h"
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
namespace tensorrt_llm::runtime
{
@ -88,7 +83,7 @@ public:
//! @returns [batchSize, maxBeamWidth], tokens generated in the last forward pass, on GPU
[[nodiscard]] TensorPtr getAllNewTokens() const override
{
TensorPtr newTokens = std::move(ITensor::view(mDecodingOutput->newTokensSteps));
TensorPtr newTokens = ITensor::view(mDecodingOutput->newTokensSteps);
newTokens->unsqueeze(0);
return newTokens;
}
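Dropping the std::move around ITensor::view reflects a general rule: applying std::move to a temporary (prvalue) is redundant and can suppress copy elision. A small standalone illustration, unrelated to the TensorRT-LLM types:

#include <memory>

std::shared_ptr<int> makeValue()
{
    // The temporary from std::make_shared initializes p directly; wrapping it in
    // std::move would add nothing and typically triggers -Wpessimizing-move.
    std::shared_ptr<int> p = std::make_shared<int>(42);
    return p; // NRVO or implicit move applies on return
}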

View File

@ -19,7 +19,6 @@
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
@ -30,7 +29,6 @@
#include <algorithm>
#include <cstdlib>
#include <limits>
#include <list>
#include <memory>
#include <mutex>

View File

@ -14,9 +14,7 @@
* limitations under the License.
*/
#include "tllmRuntime.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/common/stringUtils.h"
#include "tllmLogger.h"
#include <limits>
@ -24,8 +22,6 @@
using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
namespace
{
using DimType = std::remove_reference_t<decltype(std::declval<nvinfer1::Dims>().d[0])>;
@ -108,14 +104,12 @@ bool TllmRuntime::executeContext(SizeType contextIndex) const
void TllmRuntime::setInputTensors(SizeType contextIndex, TensorMap const& tensorMap)
{
NVTX3_FUNC_RANGE();
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto& context = getContext(contextIndex);
for (std::int32_t i = 0; i < mEngine->getNbIOTensors(); ++i)
{
auto const name = mEngine->getIOTensorName(i);
if (mEngine->getTensorIOMode(name) == nvinfer1::TensorIOMode::kINPUT)
{
NVTX3_SCOPED_RANGE(input_tensor);
auto pos = tensorMap.find(name);
if (pos == tensorMap.end())
{
@ -197,7 +191,6 @@ void TllmRuntime::setOutputTensors(SizeType contextIndex, TensorMap& tensorMap)
auto const name = mEngine->getIOTensorName(i);
if (mEngine->getTensorIOMode(name) == nvinfer1::TensorIOMode::kOUTPUT)
{
NVTX3_SCOPED_RANGE(output_tensor);
auto const dims = context.getTensorShape(name);
auto const engineDtype = mEngine->getTensorDataType(name);
auto pos = tensorMap.find(name);

View File

@ -23,8 +23,6 @@
#include <ATen/ATen.h>
#include <torch/types.h>
#include <memory>
namespace tensorrt_llm::runtime
{
class TorchView : virtual public ITensor

View File

@ -115,6 +115,16 @@ void insertTensorSlices(
}
}
void printTensorMap(std::ostream& stream, StringPtrMap<ITensor> const& map)
{
for (auto const& [name, tensor] : map)
{
stream << "Tensor name: " << name << '\n';
stream << "Shape" << tensor->getShape() << '\n';
stream << *tensor << '\n';
}
}
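The commented-out printTensorMap calls in runtimeBuffers above show the intended use as a debug dump. The same name/size/contents layout in plain C++, independent of the ITensor types (purely hypothetical example):

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in: dump each named buffer's size and contents,
// mirroring the per-tensor layout produced by printTensorMap.
void printBufferMap(std::ostream& stream, std::map<std::string, std::vector<int>> const& map)
{
    for (auto const& [name, values] : map)
    {
        stream << "Tensor name: " << name << '\n';
        stream << "Size: " << values.size() << '\n';
        for (int v : values)
        {
            stream << v << ' ';
        }
        stream << '\n';
    }
}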
void setRawPointers(ITensor& pointers, ITensor::SharedPtr const& input, int32_t pointersSlot, int32_t inputSlot)
{
auto const pointersLength = static_cast<int32_t>(pointers.getSizeInBytes() / sizeof(void**));

View File

@ -64,6 +64,8 @@ void insertTensorVector(StringPtrMap<ITensor>& map, std::string const& key, std:
void insertTensorSlices(
StringPtrMap<ITensor>& map, std::string const& key, ITensor::SharedPtr const& tensor, SizeType indexOffset);
void printTensorMap(std::ostream& stream, StringPtrMap<ITensor> const& map);
void setRawPointers(ITensor& pointers, ITensor::SharedPtr const& input, int32_t pointersSlot, int32_t inputSlot);
void setRawPointers(ITensor& pointers, ITensor::SharedPtr const& input);

View File

@ -85,8 +85,9 @@ WorldConfig WorldConfig::mpi(SizeType gpusPerNode, std::optional<SizeType> tenso
auto const mpiSize = comm.getSize();
auto const mpiRank = comm.getRank();
TLLM_LOG_INFO("MPI size: %d, rank: %d", mpiSize, mpiRank);
auto pp = pipelineParallelism.value_or(1);
auto tp = tensorParallelism.value_or(mpiSize / pp);
auto const pp = pipelineParallelism.value_or(1);
auto const tp = tensorParallelism.value_or(mpiSize / pp);
TLLM_LOG_DEBUG("TP: %d, PP: %d", tp, pp);
TLLM_CHECK(mpiSize == tp * pp);
return WorldConfig{tp, pp, mpiRank, gpusPerNode, deviceIds};
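A minimal sketch of how the parallelism defaults resolve, based only on the lines above (hypothetical function name): with 8 MPI ranks and pipelineParallelism = 2, tensorParallelism defaults to 4, and the tp * pp == mpiSize check passes.

#include <cassert>
#include <optional>

int resolveTensorParallelism(int mpiSize, std::optional<int> tensorParallelism, std::optional<int> pipelineParallelism)
{
    auto const pp = pipelineParallelism.value_or(1);
    auto const tp = tensorParallelism.value_or(mpiSize / pp);
    assert(mpiSize == tp * pp); // e.g. mpiSize = 8, pp = 2 -> tp = 4
    return tp;
}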

View File

@ -103,11 +103,12 @@ void FtDynamicDecode<T>::setup(size_t batch_size, size_t beam_width, th::optiona
th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> beam_search_diversity_rate_opt,
th::optional<th::Tensor> random_seed_opt, th::optional<th::Tensor> top_p_decay_opt,
th::optional<th::Tensor> top_p_min_opt, th::optional<th::Tensor> top_p_reset_ids_opt)
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
th::optional<th::Tensor> top_p_reset_ids_opt)
{
// unused: length_penalty_opt, beam_search_diversity_rate_opt
// unused: length_penalty_opt, beam_search_diversity_rate_opt, early_stopping_opt
auto stream = at::cuda::getCurrentCUDAStream().stream();
dynamic_decode_layer_->setStream(stream);
@ -127,6 +128,7 @@ void FtDynamicDecode<T>::setup(size_t batch_size, size_t beam_width, th::optiona
safeInsert(top_p_reset_ids_opt, setupParams.top_p_reset_ids);
safeInsert(beam_search_diversity_rate_opt, setupParams.beam_search_diversity_rate);
safeInsert(length_penalty_opt, setupParams.length_penalty);
safeInsert(early_stopping_opt, setupParams.early_stopping);
dynamic_decode_layer_->setup(batch_size, beam_width, nullptr, setupParams);
}
@ -211,13 +213,13 @@ void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width,
dynamic_decode_layer_->forward(outputParams, forwardParams);
if (finished_sum_host)
{
TLLM_CUDA_CHECK(::cudaStreamSynchronize(dynamic_decode_layer_->getStream()));
int32_t finished_sum = 0;
for (int32_t bi = 0; bi < local_batch_size; ++bi)
{
finished_sum += finished_sum_host[bi];
}
auto const numToFinish = outputParams.finished->size();
TLLM_CUDA_CHECK(::cudaStreamSynchronize(dynamic_decode_layer_->getStream()));
auto should_stop_accessor = should_stop.accessor<bool, 1>();
should_stop_accessor[0] = numToFinish == finished_sum;
}
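Moving the cudaStreamSynchronize ahead of the host-side loop matters because the finished flags are written by the GPU into a pinned host buffer; the host must not read them until the stream has drained. A standalone sketch of that ordering (hypothetical names, not part of this diff):

#include <cuda_runtime.h>
#include <cstdint>

bool allRequestsFinished(cudaStream_t stream, std::int32_t const* finishedSumHost,
    std::int32_t localBatchSize, std::int64_t numToFinish)
{
    // Wait for the decode step that fills finishedSumHost before reading it on the host.
    cudaStreamSynchronize(stream);
    std::int64_t finishedSum = 0;
    for (std::int32_t bi = 0; bi < localBatchSize; ++bi)
    {
        finishedSum += finishedSumHost[bi];
    }
    return finishedSum == numToFinish;
}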
@ -259,9 +261,10 @@ void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional
th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> beam_search_diversity_rate_opt,
th::optional<th::Tensor> random_seed_opt, th::optional<th::Tensor> top_p_decay_opt,
th::optional<th::Tensor> top_p_min_opt, th::optional<th::Tensor> top_p_reset_ids_opt)
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
th::optional<th::Tensor> top_p_reset_ids_opt)
{
// TODO: Revise DynamicDecodeLayer and make the decode arguments consistent.
CHECK_OPTIONAL_CPU_INPUT(runtime_top_k_opt, torch::kInt32);
@ -273,6 +276,7 @@ void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional
CHECK_OPTIONAL_CPU_INPUT(frequency_penalty_opt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(min_length_opt, torch::kInt32);
CHECK_OPTIONAL_CPU_INPUT(length_penalty_opt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(early_stopping_opt, torch::kInt32);
CHECK_OPTIONAL_CPU_INPUT(beam_search_diversity_rate_opt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(random_seed_opt, torch::kInt64);
CHECK_OPTIONAL_INPUT(top_p_decay_opt, torch::kFloat);
@ -281,8 +285,8 @@ void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional
dynamic_decode_->setup(static_cast<size_t>(batch_size), static_cast<size_t>(beam_width), runtime_top_k_opt,
runtime_top_p_opt, temperature_opt, repetition_penalty_opt, presence_penalty_opt, frequency_penalty_opt,
min_length_opt, length_penalty_opt, beam_search_diversity_rate_opt, random_seed_opt, top_p_decay_opt,
top_p_min_opt, top_p_reset_ids_opt);
min_length_opt, length_penalty_opt, early_stopping_opt, beam_search_diversity_rate_opt, random_seed_opt,
top_p_decay_opt, top_p_min_opt, top_p_reset_ids_opt);
}
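CHECK_OPTIONAL_CPU_INPUT(early_stopping_opt, torch::kInt32) implies the new argument is expected as an int32 tensor on the CPU. A hedged libtorch C++ sketch of constructing such a tensor (hypothetical helper and values, not taken from this diff):

#include <torch/torch.h>
#include <cstdint>

torch::Tensor makeEarlyStoppingTensor(std::int64_t batchSize)
{
    // Assumed here to be one value per request; int32 on the CPU, matching the
    // CHECK_OPTIONAL_CPU_INPUT(early_stopping_opt, torch::kInt32) check above.
    return torch::ones({batchSize}, torch::dtype(torch::kInt32));
}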
th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max_input_length,

Some files were not shown because too many files have changed in this diff.