Update TensorRT-LLM (#1168)

* Update TensorRT-LLM

---------

Co-authored-by: Bhuvanesh Sridharan <bhuvan.sridharan@gmail.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
This commit is contained in:
Kaiyu Xie 2024-02-27 17:37:34 +08:00 committed by GitHub
parent e4e09dafea
commit 655524dd82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
229 changed files with 4810 additions and 4097 deletions

View File

@ -428,7 +428,7 @@ public:
, mStaticEmulatedTimeoutMs(staticEmulatedTimeoutMs) , mStaticEmulatedTimeoutMs(staticEmulatedTimeoutMs)
, mActiveCount(0) , mActiveCount(0)
{ {
ReturnBatchManagerStatsCallback iterationDataCallback = [this, &logIterationData](std::string const& log) ReturnBatchManagerStatsCallback iterationDataCallback = [this, logIterationData](std::string const& log)
{ {
if (logIterationData) if (logIterationData)
{ {
@ -563,16 +563,18 @@ public:
{ {
auto numNewWorkItems = static_cast<int64_t>(rval.size()); auto numNewWorkItems = static_cast<int64_t>(rval.size());
comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0); comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
if (numNewWorkItems > 0)
std::vector<int64_t> packed;
for (auto const& ir : rval)
{ {
auto vpacked = ir->serialize(); std::vector<int64_t> packed;
packed.push_back(static_cast<int64_t>(vpacked.size())); for (auto const& ir : rval)
packed.insert( {
packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end())); auto vpacked = ir->serialize();
packed.push_back(static_cast<int64_t>(vpacked.size()));
packed.insert(
packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
}
comm.bcast(packed, 0);
} }
comm.bcast(packed, 0);
} }
} }
else else
@ -791,7 +793,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
recorder->report(); recorder->report();
recorder->writeOpMetricsToCsv(); recorder->writeOpMetricsToCsv();
// Send terminateReqId to terminate servers on all ranks // Send terminateReqId to terminate servers on all ranks
// Sever on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases // Server on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
gptServer->enqueue(std::make_shared<InferenceRequest>(terminateReqId)); gptServer->enqueue(std::make_shared<InferenceRequest>(terminateReqId));
} }
// Wait until benchmarking is done and batch manager is terminated // Wait until benchmarking is done and batch manager is terminated

View File

@ -114,9 +114,9 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
"benchmark on %d tokens", "benchmark on %d tokens",
maxNumTokens.value(), maxBatchSize * maxInputLength); maxNumTokens.value(), maxBatchSize * maxInputLength);
} }
std::atomic_bool done = false;
try try
{ {
std::atomic_bool done = false;
auto peakMemFuture = std::async(&monitorMemory, std::ref(done)); auto peakMemFuture = std::async(&monitorMemory, std::ref(done));
TLLM_LOG_INFO(memoryCounter.toString()); TLLM_LOG_INFO(memoryCounter.toString());
@ -266,11 +266,14 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
catch (std::runtime_error& e) catch (std::runtime_error& e)
{ {
std::size_t found = std::string(e.what()).find("out of memory"); std::size_t found = std::string(e.what()).find("out of memory");
// We need to kill the memory monitor when OOM.
done = true;
// Unexpected error; rethrow // Unexpected error; rethrow
if (found == std::string::npos) if (found == std::string::npos)
{ {
throw; TLLM_LOG_ERROR(e.what());
throw e;
} }
// We can ignore the OOM exception and continue the rest of the benchmark // We can ignore the OOM exception and continue the rest of the benchmark
@ -283,6 +286,12 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
} }
continue; continue;
} }
catch (...)
{
// We need to kill memory monitor when any other issue occurs
done = true;
throw;
}
} }
TLLM_LOG_INFO(memoryCounter.toString()); TLLM_LOG_INFO(memoryCounter.toString());
} }

View File

@ -85,6 +85,7 @@ class EncDecBuildConfig:
max_decoder_input_len: Optional[int] = None max_decoder_input_len: Optional[int] = None
max_output_len: Optional[int] = None max_output_len: Optional[int] = None
builder_opt: Optional[int] = None builder_opt: Optional[int] = None
n_mels: Optional[int] = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
assert self.head_size is not None assert self.head_size is not None
@ -1179,6 +1180,27 @@ _allowed_configs = {
mamba_d_conv=4, mamba_d_conv=4,
mamba_expand=2, mamba_expand=2,
)), )),
"whisper_large_v3":
ModelConfig(name="whisper_large_v3",
family="whisper",
benchmark_type="enc_dec",
build_config=EncDecBuildConfig(
num_layers=32,
num_decoder_layers=32,
num_heads=20,
head_size=64,
ffn_hidden_size=5120,
hidden_size=1280,
vocab_size=51866,
hidden_act="gelu",
n_positions=448,
n_mels=128,
max_batch_size=8,
max_encoder_input_len=1500,
max_decoder_input_len=1,
max_output_len=200,
builder_opt=None,
)),
} }

View File

@ -795,7 +795,7 @@ def build_gpt(args):
max_beam_width=max_beam_width) max_beam_width=max_beam_width)
if family in [ if family in [
'opt', 'bloom', 'falcon', 'llama', 'internlm', 'gptneox', 'opt', 'bloom', 'falcon', 'llama', 'internlm', 'gptneox',
'gptj', 'mamba', 'baichuan' 'gptj', 'mamba', 'baichuan', 'chatglm', 'chatglm2', 'chatglm3'
]: ]:
tensorrt_llm_model(**inputs) tensorrt_llm_model(**inputs)
else: else:
@ -957,6 +957,8 @@ def enc_dec_build_helper(component, config, args):
torch.cuda.set_device(runtime_rank) torch.cuda.set_device(runtime_rank)
family = get_model_family(args.model) family = get_model_family(args.model)
logits_dtype = 'float32'
n_mels = 0
if family == 'bart': if family == 'bart':
q_scaling = 1.0 q_scaling = 1.0
has_attention_qkvo_bias = True has_attention_qkvo_bias = True
@ -969,6 +971,19 @@ def enc_dec_build_helper(component, config, args):
layernorm_position = LayerNormPositionType.pre_layernorm if config.get( layernorm_position = LayerNormPositionType.pre_layernorm if config.get(
'normalize_before', True) else LayerNormPositionType.post_layernorm 'normalize_before', True) else LayerNormPositionType.post_layernorm
rescale_before_lm_head = False rescale_before_lm_head = False
elif family == 'whisper':
q_scaling = 1.0
has_position_embedding = True
relative_attention = False
has_embedding_layernorm = False
has_attention_qkvo_bias = True
has_mlp_bias = True
has_model_final_layernorm = True
layernorm_position = LayerNormPositionType.pre_layernorm
layernorm_type = LayerNormType.LayerNorm
rescale_before_lm_head = False
logits_dtype = str_dtype_to_trt(args.dtype)
n_mels = config['n_mels']
else: else:
q_scaling = 1 / config['head_size']**.5 q_scaling = 1 / config['head_size']**.5
has_attention_qkvo_bias = False has_attention_qkvo_bias = False
@ -984,6 +999,9 @@ def enc_dec_build_helper(component, config, args):
else: else:
rescale_before_lm_head = False rescale_before_lm_head = False
quant_mode, _, _ = get_quant_mode(args.quantization)
use_weight_only = quant_mode.is_weight_only()
builder = Builder() builder = Builder()
builder_config = builder.create_builder_config( builder_config = builder.create_builder_config(
name=args.model, name=args.model,
@ -1011,6 +1029,10 @@ def enc_dec_build_helper(component, config, args):
has_token_type_embedding=False, # by default has_token_type_embedding=False, # by default
strongly_typed=False, # by default strongly_typed=False, # by default
gather_all_token_logits=False, # by default gather_all_token_logits=False, # by default
int8=(quant_mode.has_act_and_weight_quant()
or quant_mode.is_int8_weight_only()),
quant_mode=quant_mode,
n_mels=n_mels,
) )
# build engine # build engine
@ -1024,34 +1046,45 @@ def enc_dec_build_helper(component, config, args):
fp16_clamping = (args.dtype == 'float16') and ('t5' in family) fp16_clamping = (args.dtype == 'float16') and ('t5' in family)
if component == 'encoder': if component == 'encoder':
tllm_model = tensorrt_llm.models.EncoderModel( if family == 'whisper':
num_layers=config['num_layers'], tllm_model = tensorrt_llm.models.WhisperEncoder(
num_heads=config['num_heads'], n_mels=config['n_mels'],
num_kv_heads=config['num_heads'], n_ctx=1500, # n_audio_ctx
head_size=config['head_size'], n_state=config['hidden_size'],
hidden_size=config['hidden_size'], n_head=config['num_heads'],
ffn_hidden_size=config['ffn_hidden_size'], n_layer=config['num_layers'],
vocab_size=config['vocab_size'], dtype=dtype)
max_position_embeddings=config.get('n_positions', 0), if use_weight_only:
has_position_embedding=has_position_embedding, tllm_model = quantize_model(tllm_model, quant_mode)
relative_attention=relative_attention, else:
max_distance=config.get('max_distance', 0), tllm_model = tensorrt_llm.models.EncoderModel(
num_buckets=config.get('num_buckets', 0), num_layers=config['num_layers'],
has_embedding_layernorm=has_embedding_layernorm, num_heads=config['num_heads'],
has_embedding_scale=config.get('has_embedding_scale', False), num_kv_heads=config['num_heads'],
q_scaling=q_scaling, head_size=config['head_size'],
has_attention_qkvo_bias=has_attention_qkvo_bias, hidden_size=config['hidden_size'],
has_mlp_bias=has_mlp_bias, ffn_hidden_size=config['ffn_hidden_size'],
has_model_final_layernorm=has_model_final_layernorm, vocab_size=config['vocab_size'],
layernorm_eps=1e-6, max_position_embeddings=config.get('n_positions', 0),
layernorm_position=layernorm_position, has_position_embedding=has_position_embedding,
layernorm_type=layernorm_type, relative_attention=relative_attention,
hidden_act=config['hidden_act'], max_distance=config.get('max_distance', 0),
dtype=dtype, num_buckets=config.get('num_buckets', 0),
use_parallel_embedding=False, # by default has_embedding_layernorm=has_embedding_layernorm,
embedding_sharding_dim=0, # by default has_embedding_scale=config.get('has_embedding_scale', False),
mapping=mapping, q_scaling=q_scaling,
fp16_clamping=fp16_clamping) has_attention_qkvo_bias=has_attention_qkvo_bias,
has_mlp_bias=has_mlp_bias,
has_model_final_layernorm=has_model_final_layernorm,
layernorm_eps=1e-6,
layernorm_position=layernorm_position,
layernorm_type=layernorm_type,
hidden_act=config['hidden_act'],
dtype=dtype,
use_parallel_embedding=False, # by default
embedding_sharding_dim=0, # by default
mapping=mapping,
fp16_clamping=fp16_clamping)
elif component == 'decoder': elif component == 'decoder':
tllm_model = tensorrt_llm.models.DecoderModel( tllm_model = tensorrt_llm.models.DecoderModel(
num_layers=config['num_layers'], num_layers=config['num_layers'],
@ -1084,8 +1117,10 @@ def enc_dec_build_helper(component, config, args):
embedding_sharding_dim=0, # by default embedding_sharding_dim=0, # by default
mapping=mapping, mapping=mapping,
rescale_before_lm_head=rescale_before_lm_head, rescale_before_lm_head=rescale_before_lm_head,
logits_dtype='float32', # by default logits_dtype=logits_dtype, # by default
fp16_clamping=fp16_clamping) fp16_clamping=fp16_clamping)
if use_weight_only and family == 'whisper':
tllm_model = quantize_model(tllm_model, quant_mode)
# Module -> Network # Module -> Network
engine_name = get_engine_name(args.model, args.dtype, world_size, engine_name = get_engine_name(args.model, args.dtype, world_size,
@ -1099,6 +1134,12 @@ def enc_dec_build_helper(component, config, args):
network.plugin_config.set_bert_attention_plugin(dtype=args.dtype) network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
network.plugin_config.set_gemm_plugin(dtype=args.dtype) network.plugin_config.set_gemm_plugin(dtype=args.dtype)
network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype) network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
if use_weight_only:
network.plugin_config.set_weight_only_quant_matmul_plugin(
dtype=args.dtype)
elif args.mode == 'ootb-except-mha':
network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
if world_size > 1: if world_size > 1:
network.plugin_config.set_nccl_plugin( network.plugin_config.set_nccl_plugin(
@ -1110,18 +1151,31 @@ def enc_dec_build_helper(component, config, args):
# Forward # Forward
if component == 'encoder': if component == 'encoder':
inputs = tllm_model.prepare_inputs( if family == 'whisper':
max_batch_size=config['max_batch_size'], inputs = tllm_model.prepare_inputs(
max_input_len=config['max_encoder_input_len'], max_batch_size=config['max_batch_size'], )
) else:
inputs = tllm_model.prepare_inputs(
max_batch_size=config['max_batch_size'],
max_input_len=config['max_encoder_input_len'],
)
elif component == 'decoder': elif component == 'decoder':
inputs = tllm_model.prepare_inputs( if family == 'whisper':
max_batch_size=config['max_batch_size'], inputs = tllm_model.prepare_inputs(
max_beam_width=config['max_beam_width'], max_batch_size=config['max_batch_size'],
max_decoder_input_len=config['max_decoder_input_len'], max_beam_width=config['max_beam_width'],
max_new_tokens=config['max_output_len'], max_decoder_input_len=config['max_decoder_input_len'],
max_encoder_input_len=config['max_encoder_input_len'], max_new_tokens=config['max_output_len'],
) max_encoder_input_len=1500, # n_audio_ctx
)
else:
inputs = tllm_model.prepare_inputs(
max_batch_size=config['max_batch_size'],
max_beam_width=config['max_beam_width'],
max_decoder_input_len=config['max_decoder_input_len'],
max_new_tokens=config['max_output_len'],
max_encoder_input_len=config['max_encoder_input_len'],
)
tllm_model(*inputs) tllm_model(*inputs)

View File

@ -23,8 +23,9 @@ from base_benchmark import BaseBenchmark, get_engine_name
from build import build_enc_dec from build import build_enc_dec
import tensorrt_llm import tensorrt_llm
from tensorrt_llm._utils import trt_dtype_to_torch from tensorrt_llm._utils import (trt_dtype_to_torch, str_dtype_to_trt)
from tensorrt_llm.quantization import QuantMode from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime.session import TensorInfo
class EncDecBenchmark(BaseBenchmark): class EncDecBenchmark(BaseBenchmark):
@ -49,6 +50,8 @@ class EncDecBenchmark(BaseBenchmark):
# So we use separate variables for encoder and decoder here. # So we use separate variables for encoder and decoder here.
self.encoder_engine_model_name = args.model self.encoder_engine_model_name = args.model
self.decoder_engine_model_name = args.model self.decoder_engine_model_name = args.model
# only for whisper parameter
self.n_mels = 0
if self.engine_dir is not None: if self.engine_dir is not None:
@ -109,6 +112,8 @@ class EncDecBenchmark(BaseBenchmark):
self.max_input_len = config["builder_config"][ self.max_input_len = config["builder_config"][
"max_encoder_input_len"] "max_encoder_input_len"]
self.max_output_len = config["builder_config"]["max_output_len"] self.max_output_len = config["builder_config"]["max_output_len"]
self.n_mels = config["builder_config"][
'n_mels'] if 'whisper' in self.model_name else 0
for key, value in config["builder_config"].items(): for key, value in config["builder_config"].items():
if key == "name": if key == "name":
@ -173,6 +178,8 @@ class EncDecBenchmark(BaseBenchmark):
if args.max_input_len is None else args.max_input_len if args.max_input_len is None else args.max_input_len
self.max_output_len = build_config['max_output_len'] \ self.max_output_len = build_config['max_output_len'] \
if args.max_output_len is None else args.max_output_len if args.max_output_len is None else args.max_output_len
self.n_mels = build_config[
'n_mels'] if 'whisper' in self.model_name else 0
# Build engine # Build engine
( (
encoder_engine_buffer, encoder_engine_buffer,
@ -198,6 +205,10 @@ class EncDecBenchmark(BaseBenchmark):
) )
def get_config(self): def get_config(self):
if 'whisper' in self.model_name:
print(
f"[WARNING] whisper benchmark is input_len=1500, no text prompt, output_len=arbitrary"
)
for inlen, outlen in self.in_out_lens: for inlen, outlen in self.in_out_lens:
if (inlen > self.max_input_len or outlen > self.max_output_len): if (inlen > self.max_input_len or outlen > self.max_output_len):
print( print(
@ -216,29 +227,95 @@ class EncDecBenchmark(BaseBenchmark):
def prepare_inputs(self, config): def prepare_inputs(self, config):
batch_size, encoder_input_len = config[0], config[1] batch_size, encoder_input_len = config[0], config[1]
encoder_input_ids = (torch.randint( attention_mask = None
100, (batch_size, encoder_input_len)).int().cuda()) whisper_decoder_encoder_input_lengths = None
# For now, just hardcode the decoder_start_token_id to 0 for t5 models. outputs = {}
decoder_start_token_id = 0 if 'whisper' in self.model_name:
decoder_input_ids = torch.IntTensor([[decoder_start_token_id] # feature_len always fixed 3000 now
]).to(self.device) feature_len = 3000
decoder_input_ids = decoder_input_ids.repeat( encoder_input_ids = (torch.randint(
(encoder_input_ids.shape[0], 1)) 1, 100, (batch_size, self.n_mels, feature_len)).int().cuda())
# in padding mode --> keep input, just calculate actual length and max length encoder_input_lengths = torch.tensor([
# Note: 1st token should always count, even if it is pad_token_id (0). e.g., decoder start id in enc-dec models could be a single pad_token_id, we should count encoder_input_ids.shape[2] // 2
encoder_input_lengths = ((1 + (encoder_input_ids[:, 1:] != 0).sum( for _ in range(encoder_input_ids.shape[0])
dim=1).type(torch.IntTensor).to(self.device)).clone().detach().to( ],
dtype=torch.int32, device=self.device)) dtype=torch.int32,
decoder_input_lengths = ((1 + (decoder_input_ids[:, 1:] != 0).sum( device=self.device)
dim=1).type(torch.IntTensor).to(self.device)).clone().detach().to( decoder_input_ids = (torch.randint(1, 100, (1, )).int().cuda())
dtype=torch.int32, device=self.device)) decoder_input_ids = decoder_input_ids.repeat(
# attention mask, always set 1 as if all are valid tokens (encoder_input_ids.shape[0], 1))
attention_mask = torch.ones( output_list = [
(batch_size, encoder_input_len)).int().cuda() TensorInfo('x', str_dtype_to_trt(self.dtype),
# cross attention mask, always set 1 as if all are valid tokens encoder_input_ids.shape),
# [batch_size, query_len, encoder_input_len] currently, use query_len=1 TensorInfo('input_lengths', str_dtype_to_trt('int32'),
cross_attention_mask = torch.ones( encoder_input_lengths.shape)
(batch_size, 1, encoder_input_len)).int().cuda() ]
output_info = (self.encoder_session).infer_shapes(output_list)
outputs = {
t.name: torch.empty(tuple(t.shape),
dtype=trt_dtype_to_torch(t.dtype),
device='cuda')
for t in output_info
}
whisper_decoder_encoder_input_lengths = torch.tensor(
[
outputs['output'].shape[1]
for x in range(outputs['output'].shape[0])
],
dtype=torch.int32,
device='cuda')
decoder_input_lengths = torch.tensor([
decoder_input_ids.shape[-1]
for _ in range(decoder_input_ids.shape[0])
],
dtype=torch.int32,
device='cuda')
cross_attention_mask = torch.ones(
[outputs['output'].shape[0], 1,
outputs['output'].shape[1]]).int().cuda()
else:
encoder_input_ids = (torch.randint(
100, (batch_size, encoder_input_len)).int().cuda())
# For now, just hardcode the decoder_start_token_id to 0 for t5 models.
decoder_start_token_id = 0
decoder_input_ids = torch.IntTensor([[decoder_start_token_id]
]).to(self.device)
decoder_input_ids = decoder_input_ids.repeat(
(encoder_input_ids.shape[0], 1))
# in padding mode --> keep input, just calculate actual length and max length
# Note: 1st token should always count, even if it is pad_token_id (0). e.g., decoder start id in enc-dec models could be a single pad_token_id, we should count
encoder_input_lengths = ((
1 + (encoder_input_ids[:, 1:] != 0).sum(dim=1).type(
torch.IntTensor).to(self.device)).clone().detach().to(
dtype=torch.int32, device=self.device))
decoder_input_lengths = ((
1 + (decoder_input_ids[:, 1:] != 0).sum(dim=1).type(
torch.IntTensor).to(self.device)).clone().detach().to(
dtype=torch.int32, device=self.device))
# attention mask, always set 1 as if all are valid tokens
attention_mask = torch.ones(
(batch_size, encoder_input_len)).int().cuda()
# cross attention mask, always set 1 as if all are valid tokens
# [batch_size, query_len, encoder_input_len] currently, use query_len=1
cross_attention_mask = torch.ones(
(batch_size, 1, encoder_input_len)).int().cuda()
hidden_size = (self.encoder_model_config.hidden_size *
self.world_size) # tp_size
hidden_states_shape = (
encoder_input_ids.shape[0],
encoder_input_ids.shape[1],
hidden_size,
)
hidden_states_dtype = lambda name: trt_dtype_to_torch(
self.encoder_session.engine.get_tensor_dtype(name))
outputs["encoder_output"] = torch.empty(
hidden_states_shape,
dtype=hidden_states_dtype("encoder_output"),
device=self.device,
).contiguous()
stream = torch.cuda.current_stream().cuda_stream stream = torch.cuda.current_stream().cuda_stream
return ( return (
@ -248,6 +325,8 @@ class EncDecBenchmark(BaseBenchmark):
decoder_input_ids, decoder_input_ids,
decoder_input_lengths, decoder_input_lengths,
cross_attention_mask, cross_attention_mask,
whisper_decoder_encoder_input_lengths,
outputs,
stream, stream,
) )
@ -260,47 +339,37 @@ class EncDecBenchmark(BaseBenchmark):
decoder_input_ids, decoder_input_ids,
decoder_input_lengths, decoder_input_lengths,
cross_attention_mask, cross_attention_mask,
whisper_decoder_encoder_input_lengths,
outputs,
stream, stream,
) = inputs ) = inputs
hidden_size = (self.encoder_model_config.hidden_size * self.world_size
) # tp_size
hidden_states_shape = (
encoder_input_ids.shape[0],
encoder_input_ids.shape[1],
hidden_size,
)
hidden_states_dtype = lambda name: trt_dtype_to_torch( hidden_states_dtype = lambda name: trt_dtype_to_torch(
self.encoder_session.engine.get_tensor_dtype(name)) self.encoder_session.engine.get_tensor_dtype(name))
# input tensors # input tensors
inputs = {} inputs = {}
inputs["input_ids"] = encoder_input_ids.contiguous() if 'whisper' in self.model_name:
inputs["input_lengths"] = encoder_input_lengths inputs['x'] = encoder_input_ids.contiguous()
inputs["max_input_length"] = torch.empty( inputs["input_lengths"] = encoder_input_lengths
(self.max_input_len, ), else:
dtype=hidden_states_dtype("max_input_length"), inputs["input_ids"] = encoder_input_ids.contiguous()
device=self.device, inputs["input_lengths"] = encoder_input_lengths
).contiguous() inputs["max_input_length"] = torch.empty(
(self.max_input_len, ),
dtype=hidden_states_dtype("max_input_length"),
device=self.device,
).contiguous()
if not self.encoder_model_config.gpt_attention_plugin: if not self.encoder_model_config.gpt_attention_plugin:
inputs["attention_mask"] = attention_mask.contiguous() inputs["attention_mask"] = attention_mask.contiguous()
if self.encoder_model_config.has_position_embedding: if self.encoder_model_config.has_position_embedding:
bsz, seq_len = encoder_input_ids.shape[:2] bsz, seq_len = encoder_input_ids.shape[:2]
position_ids = torch.arange(seq_len, position_ids = torch.arange(
dtype=torch.int32, seq_len, dtype=torch.int32,
device=encoder_input_ids.device).expand( device=encoder_input_ids.device).expand(bsz, -1)
bsz, -1) inputs['position_ids'] = position_ids.contiguous()
inputs['position_ids'] = position_ids.contiguous()
# output tensors
outputs = {}
outputs["encoder_output"] = torch.empty(
hidden_states_shape,
dtype=hidden_states_dtype("encoder_output"),
device=self.device,
).contiguous()
# run encoder # run encoder
self.encoder_session.set_shapes(inputs) self.encoder_session.set_shapes(inputs)
@ -311,6 +380,12 @@ class EncDecBenchmark(BaseBenchmark):
# run decoder # run decoder
sampling_config = tensorrt_llm.runtime.SamplingConfig( sampling_config = tensorrt_llm.runtime.SamplingConfig(
end_id=1, pad_id=0, num_beams=self.num_beams, min_length=output_len) end_id=1, pad_id=0, num_beams=self.num_beams, min_length=output_len)
encoder_output = outputs[
'output'] if 'whisper' in self.model_name else outputs[
"encoder_output"]
encoder_max_input_length = encoder_output.shape[
1] if 'whisper' in self.model_name else torch.max(
encoder_input_lengths).item()
self.decoder_session.setup( self.decoder_session.setup(
decoder_input_lengths.size(0), decoder_input_lengths.size(0),
@ -318,9 +393,8 @@ class EncDecBenchmark(BaseBenchmark):
output_len, output_len,
beam_width=self.num_beams, beam_width=self.num_beams,
max_attention_window_size=None, max_attention_window_size=None,
encoder_max_input_length=torch.max(encoder_input_lengths).item(), encoder_max_input_length=encoder_max_input_length,
) )
torch.cuda.synchronize()
cross_attention_mask = None if self.decoder_model_config.gpt_attention_plugin else cross_attention_mask cross_attention_mask = None if self.decoder_model_config.gpt_attention_plugin else cross_attention_mask
@ -328,11 +402,11 @@ class EncDecBenchmark(BaseBenchmark):
decoder_input_ids, decoder_input_ids,
decoder_input_lengths, decoder_input_lengths,
sampling_config, sampling_config,
encoder_output=outputs["encoder_output"], encoder_output=encoder_output,
encoder_input_lengths=encoder_input_lengths, encoder_input_lengths=whisper_decoder_encoder_input_lengths
if 'whisper' in self.model_name else encoder_input_lengths,
cross_attention_mask=cross_attention_mask, cross_attention_mask=cross_attention_mask,
) )
torch.cuda.synchronize()
def report(self, def report(self,
config, config,

View File

@ -182,7 +182,7 @@ if(ENABLE_MULTI_DEVICE EQUAL 1)
find_library(NCCL_LIB nccl HINTS ${NCCL_LIB_DIR}) find_library(NCCL_LIB nccl HINTS ${NCCL_LIB_DIR})
endif() endif()
get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_SOURCE_DIR} PATH) get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH)
set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty) set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
include_directories( include_directories(

View File

@ -51,7 +51,7 @@ public:
batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb, batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,
SendResponseCallback sendResponseCb, PollStopSignalCallback pollStopSignalCb = nullptr, SendResponseCallback sendResponseCb, PollStopSignalCallback pollStopSignalCb = nullptr,
ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr, ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
const TrtGptModelOptionalParams& optionalParams = TrtGptModelOptionalParams(), TrtGptModelOptionalParams const& optionalParams = TrtGptModelOptionalParams(),
std::optional<uint64_t> terminateReqId = std::nullopt, std::optional<SizeType> maxDraftTokens = std::nullopt, std::optional<uint64_t> terminateReqId = std::nullopt, std::optional<SizeType> maxDraftTokens = std::nullopt,
bool excludeInputInOutput = false); bool excludeInputInOutput = false);
@ -82,9 +82,9 @@ protected:
virtual BatchManagerErrorCode_t step(RequestList& activeRequests, std::set<uint64_t>& activeRequestsIds); virtual BatchManagerErrorCode_t step(RequestList& activeRequests, std::set<uint64_t>& activeRequestsIds);
private: private:
SizeType getMaxInputLen() const; [[nodiscard]] SizeType getMaxInputLen() const;
SizeType getMaxSequenceLen() const; [[nodiscard]] SizeType getMaxSequenceLen() const;
SizeType getMaxNumSequences() const; [[nodiscard]] SizeType getMaxNumSequences() const;
void validateLlmRequest(LlmRequest& newReq) const; void validateLlmRequest(LlmRequest& newReq) const;
static std::shared_ptr<LlmRequest> fillLlmRequest(std::shared_ptr<InferenceRequest> newReq); static std::shared_ptr<LlmRequest> fillLlmRequest(std::shared_ptr<InferenceRequest> newReq);

View File

@ -1,96 +0,0 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/schedulerPolicy.h"
#include "tensorrt_llm/runtime/common.h"
#include <list>
#include <memory>
namespace tensorrt_llm::batch_manager::batch_scheduler
{
/// @brief This scheduler takes into account the given request capacity and the KV cache capacity.
/// Depending on the SchedulerPolicy it will schedule already started and new requests,
/// or even terminate previously started requests.
class CapacityScheduler
{
public:
virtual ~CapacityScheduler() = default;
using RequestTable = std::map<uint64_t, std::shared_ptr<LlmRequest>>;
using SizeType = tensorrt_llm::runtime::SizeType;
using RequestList = std::list<std::shared_ptr<LlmRequest>>;
/// @brief Takes as input a sorted list of requests and outputs a sorted lists of requests
/// to update for this current iteration, and a map of requests to terminate
virtual std::tuple<RequestList, RequestTable> scheduleRequests(const RequestList& activeRequests) = 0;
};
/// @brief Schedule up to maxNumRequests requests
class MaxRequestsScheduler : public CapacityScheduler
{
public:
explicit MaxRequestsScheduler(SizeType maxNumRequests);
std::tuple<RequestList, RequestTable> scheduleRequests(const RequestList& activeRequests) override;
private:
SizeType mMaxNumRequests;
};
/// @brief Schedule requests using the MAX_UTILIZATION policy
/// @details Try reserving resources to advance requests by one step,
/// may terminate previously started requests.
class MaxUtilizationScheduler : public CapacityScheduler
{
public:
MaxUtilizationScheduler(
SizeType maxNumRequests, kv_cache_manager::KVCacheManager* kvCacheManager, bool manyMicroBatches);
std::tuple<RequestList, RequestTable> scheduleRequests(const RequestList& activeRequests) override;
private:
bool trySchedulingRequestMaxUtilization(
std::shared_ptr<LlmRequest> const& req, RequestList& scheduledRequests, SizeType& numScheduledBlocks);
SizeType mMaxNumRequests;
kv_cache_manager::KVCacheManager* mKvCacheManager{nullptr};
/// @brief Boolean that indicates if multiple micro batches might be in flight
bool mManyMicroBatches;
};
/// @brief Schedule requests using the GUARANTEED_NO_EVICT policy
class GuaranteedNoEvictScheduler : public CapacityScheduler
{
public:
GuaranteedNoEvictScheduler(SizeType maxNumRequests, kv_cache_manager::KVCacheManager* kvCacheManager);
std::tuple<RequestList, RequestTable> scheduleRequests(const RequestList& activeRequests) override;
private:
SizeType mMaxNumRequests;
kv_cache_manager::KVCacheManager* mKvCacheManager{nullptr};
};
std::unique_ptr<CapacityScheduler> makeCapacityScheduler(tensorrt_llm::runtime::SizeType maxNumRequests,
kv_cache_manager::KVCacheManager* kvCacheManager, SchedulerPolicy schedulerPolicy, bool manyMicroBatches = false);
} // namespace tensorrt_llm::batch_manager::batch_scheduler

View File

@ -16,8 +16,8 @@
#pragma once #pragma once
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/namedTensor.h" #include "tensorrt_llm/batch_manager/namedTensor.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/iTensor.h"
#include <algorithm> #include <algorithm>
@ -48,6 +48,7 @@ auto constexpr kTemperatureTensorName = "temperature";
auto constexpr kRuntimeTopKTensorName = "runtime_top_k"; auto constexpr kRuntimeTopKTensorName = "runtime_top_k";
auto constexpr kRuntimeTopPTensorName = "runtime_top_p"; auto constexpr kRuntimeTopPTensorName = "runtime_top_p";
auto constexpr kLengthPenaltyTensorName = "len_penalty"; auto constexpr kLengthPenaltyTensorName = "len_penalty";
auto constexpr kEarlyStoppingTensorName = "early_stopping";
auto constexpr kRepetitionPenaltyTensorName = "repetition_penalty"; auto constexpr kRepetitionPenaltyTensorName = "repetition_penalty";
auto constexpr kMinLengthTensorName = "min_length"; auto constexpr kMinLengthTensorName = "min_length";
auto constexpr kPresencePenaltyTensorName = "presence_penalty"; auto constexpr kPresencePenaltyTensorName = "presence_penalty";
@ -74,6 +75,11 @@ auto constexpr kLoraWeights = "lora_weights";
// "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection // "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
// "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection // "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
// "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate // "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
// "cross_attn_qkv": 8 # for enc-dec adapter for cross attention in decoder
// "cross_attn_q": 9 # for enc-dec adapter for cross attention in decoder
// "cross_attn_k": 10 # for enc-dec adapter for cross attention in decoder
// "cross_attn_v": 11 # for enc-dec adapter for cross attention in decoder
// "cross_attn_dense": 12 # for enc-dec adapter for cross attention in decoder
// //
// last dim holds [ module_id, layer_idx, adapter_size (D / R value) ] // last dim holds [ module_id, layer_idx, adapter_size (D / R value) ]
auto constexpr kLoraConfig = "lora_config"; // [num_lora_modules_layers, 3] auto constexpr kLoraConfig = "lora_config"; // [num_lora_modules_layers, 3]
@ -91,24 +97,29 @@ auto constexpr kGenerationLogitsName = "generation_logits";
} // namespace inference_request } // namespace inference_request
template <typename TTensor, typename TNamedTensor> template <typename TTensor, typename TNamedTensor, typename TStream = runtime::BufferManager::CudaStreamPtr>
class GenericInferenceRequest class GenericInferenceRequest
{ {
public: public:
using TensorPtr = TTensor; using TensorPtr = TTensor;
using NamedTensorType = TNamedTensor; using NamedTensorType = TNamedTensor;
using TensorMap = std::unordered_map<std::string, TTensor>; using TensorMap = std::unordered_map<std::string, TTensor>;
using LogitsPostProcessor = typename GenericLlmRequest<TensorPtr, TStream>::LogitsPostProcessor;
explicit GenericInferenceRequest(uint64_t requestId) explicit GenericInferenceRequest(
uint64_t requestId, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
: mRequestId{requestId} : mRequestId{requestId}
, mIsStreaming{false} , mIsStreaming{false}
, mlogitsPostProcessor(logitsPostProcessor)
{ {
} }
GenericInferenceRequest(uint64_t requestId, TensorMap&& tensorMap) GenericInferenceRequest(uint64_t requestId, TensorMap&& tensorMap,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
: mRequestId{requestId} : mRequestId{requestId}
, mIsStreaming{false} , mIsStreaming{false}
, mInputTensors{std::move(tensorMap)} , mInputTensors{std::move(tensorMap)}
, mlogitsPostProcessor(logitsPostProcessor)
{ {
for (auto const& [name, tensor] : mInputTensors) for (auto const& [name, tensor] : mInputTensors)
{ {
@ -116,8 +127,9 @@ public:
} }
} }
GenericInferenceRequest(uint64_t requestId, TensorMap const& tensorMap) GenericInferenceRequest(uint64_t requestId, TensorMap const& tensorMap,
: GenericInferenceRequest(requestId, TensorMap{tensorMap}) std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
: GenericInferenceRequest(requestId, TensorMap{tensorMap}, logitsPostProcessor)
{ {
} }
@ -141,6 +153,16 @@ public:
return mInputTensors; return mInputTensors;
} }
void setLogitsPostProcessor(std::optional<LogitsPostProcessor> cb)
{
mlogitsPostProcessor = cb;
}
std::optional<LogitsPostProcessor> getLogitsPostProcessor()
{
return mlogitsPostProcessor;
}
static std::array constexpr kTensorNames = { static std::array constexpr kTensorNames = {
inference_request::kInputIdsTensorName, inference_request::kInputIdsTensorName,
inference_request::kDraftInputIdsTensorName, inference_request::kDraftInputIdsTensorName,
@ -156,6 +178,7 @@ public:
inference_request::kRuntimeTopKTensorName, inference_request::kRuntimeTopKTensorName,
inference_request::kRuntimeTopPTensorName, inference_request::kRuntimeTopPTensorName,
inference_request::kLengthPenaltyTensorName, inference_request::kLengthPenaltyTensorName,
inference_request::kEarlyStoppingTensorName,
inference_request::kRepetitionPenaltyTensorName, inference_request::kRepetitionPenaltyTensorName,
inference_request::kMinLengthTensorName, inference_request::kMinLengthTensorName,
inference_request::kPresencePenaltyTensorName, inference_request::kPresencePenaltyTensorName,
@ -200,7 +223,10 @@ public:
\ \
void set##funcName(TensorPtr const& tensor) \ void set##funcName(TensorPtr const& tensor) \
{ \ { \
TLLM_CHECK_WITH_INFO(tensor, "Cannot set nullptr when calling %s", __FUNCTION__); \ if constexpr (std::is_same_v<TensorPtr, tensorrt_llm::runtime::ITensor::SharedPtr>) \
{ \
TLLM_CHECK_WITH_INFO(tensor, "Cannot set nullptr when calling %s", __FUNCTION__); \
} \
mInputTensors[tensorName] = tensor; \ mInputTensors[tensorName] = tensor; \
} }
@ -218,6 +244,7 @@ public:
TENSOR_GETTER_SETTER(RuntimeTopK, inference_request::kRuntimeTopKTensorName) TENSOR_GETTER_SETTER(RuntimeTopK, inference_request::kRuntimeTopKTensorName)
TENSOR_GETTER_SETTER(RuntimeTopP, inference_request::kRuntimeTopPTensorName) TENSOR_GETTER_SETTER(RuntimeTopP, inference_request::kRuntimeTopPTensorName)
TENSOR_GETTER_SETTER(LengthPenalty, inference_request::kLengthPenaltyTensorName) TENSOR_GETTER_SETTER(LengthPenalty, inference_request::kLengthPenaltyTensorName)
TENSOR_GETTER_SETTER(EarlyStopping, inference_request::kEarlyStoppingTensorName)
TENSOR_GETTER_SETTER(RepetitionPenalty, inference_request::kRepetitionPenaltyTensorName) TENSOR_GETTER_SETTER(RepetitionPenalty, inference_request::kRepetitionPenaltyTensorName)
TENSOR_GETTER_SETTER(MinLength, inference_request::kMinLengthTensorName) TENSOR_GETTER_SETTER(MinLength, inference_request::kMinLengthTensorName)
TENSOR_GETTER_SETTER(PresencePenalty, inference_request::kPresencePenaltyTensorName) TENSOR_GETTER_SETTER(PresencePenalty, inference_request::kPresencePenaltyTensorName)
@ -243,6 +270,7 @@ protected:
uint64_t mRequestId; uint64_t mRequestId;
bool mIsStreaming; bool mIsStreaming;
TensorMap mInputTensors; TensorMap mInputTensors;
std::optional<LogitsPostProcessor> mlogitsPostProcessor;
}; };
class InferenceRequest : public GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr, NamedTensor> class InferenceRequest : public GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr, NamedTensor>

View File

@ -326,7 +326,7 @@ private:
//! \param seqSlotIdx Batch slot of sequence to which blocks are assigned. //! \param seqSlotIdx Batch slot of sequence to which blocks are assigned.
//! \return Number of matched tokens from loaded blocks. //! \return Number of matched tokens from loaded blocks.
SizeType loadOrAllocateBlocks( SizeType loadOrAllocateBlocks(
std::list<VecTokens> blockedTokens, GenerationRequest& sequence, SizeType beamIdx, SizeType seqSlotIdx); std::list<VecTokens> const& blockedTokens, GenerationRequest& sequence, SizeType beamIdx, SizeType seqSlotIdx);
//! \brief Find block least likely to be reused, free it if necessary and return. //! \brief Find block least likely to be reused, free it if necessary and return.
[[nodiscard]] BlockPtr getFreeBlock(); [[nodiscard]] BlockPtr getFreeBlock();
@ -362,9 +362,9 @@ public:
using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>; using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
KVCacheManager(SizeType numLayers, SizeType numKvHeads, SizeType sizePerHead, SizeType tokensPerBlock, KVCacheManager(SizeType numLayers, SizeType numKvHeads, SizeType sizePerHead, SizeType tokensPerBlock,
SizeType maxNumBlocks, SizeType maxNumSequences, SizeType maxBeamWidth, SizeType maxBlocksPerSeq, SizeType maxNumBlocks, SizeType maxNumSequences, SizeType maxBeamWidth, SizeType maxAttentionWindow,
SizeType maxAttentionWindow, SizeType sinkTokenLength, bool useOneMoreBlock, nvinfer1::DataType dtype, SizeType sinkTokenLength, bool useOneMoreBlock, nvinfer1::DataType dtype, CudaStreamPtr stream,
CudaStreamPtr stream, bool enableBlockReuse = false, bool useUvm = false); bool enableBlockReuse = false, bool useUvm = false);
void startScheduling(); void startScheduling();
@ -405,6 +405,11 @@ public:
return mBlockSize; return mBlockSize;
} }
[[nodiscard]] SizeType getMaxBlocksPerSeq() const
{
return mMaxBlocksPerSeq;
}
[[nodiscard]] BlockManager const& getBlockManager() const [[nodiscard]] BlockManager const& getBlockManager() const
{ {
return mBlockManager; return mBlockManager;
@ -414,13 +419,13 @@ public:
/// iterations /// iterations
/// @param req The request for which we need to calculate the number of needed KV cache blocks /// @param req The request for which we need to calculate the number of needed KV cache blocks
/// @return The number of blocks /// @return The number of blocks
SizeType getNeededBlocksOneStep(LlmRequest const& req, bool twoStepsLookAhead) const; [[nodiscard]] SizeType getNeededBlocksOneStep(LlmRequest const& req, bool twoStepsLookAhead) const;
/// @brief Function that computes the number of KV cache blocks needed to advance a request to completion (i.e. for /// @brief Function that computes the number of KV cache blocks needed to advance a request to completion (i.e. for
/// maxNewTokens) /// maxNewTokens)
/// @param req The request for which we need to calculate the number of needed KV cache blocks /// @param req The request for which we need to calculate the number of needed KV cache blocks
/// @return The number of blocks /// @return The number of blocks
SizeType getNeededBlocksToCompletion(LlmRequest const& req) const; [[nodiscard]] SizeType getNeededBlocksToCompletion(LlmRequest const& req) const;
[[nodiscard]] std::vector<runtime::ITensor::SharedPtr> const& getMemoryPools() const [[nodiscard]] std::vector<runtime::ITensor::SharedPtr> const& getMemoryPools() const
{ {
@ -458,7 +463,7 @@ public:
* modelConfig.getSizePerHead(); * modelConfig.getSizePerHead();
} }
[[nodiscard]] static SizeType getMaxNumTokens(KvCacheConfig const& config, nvinfer1::DataType dtype, [[nodiscard]] static SizeType calculateMaxNumBlocks(KvCacheConfig const& config, nvinfer1::DataType dtype,
tensorrt_llm::runtime::GptModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig, tensorrt_llm::runtime::GptModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig,
runtime::BufferManager const& bufferManager); runtime::BufferManager const& bufferManager);
@ -491,10 +496,8 @@ private:
// Maximum kv cache length per sequence // Maximum kv cache length per sequence
// Enable cyclic kv cache when it exceeds // Enable cyclic kv cache when it exceeds
SizeType mMaxAttentionWindow; SizeType mMaxAttentionWindow;
// Sink token length in the kv cache per sequence // Number of tokens to fill up the sink tokens to a full block size
SizeType mSinkTokenLength; SizeType mSinkBubbleLength;
// Bubble token length
SizeType mBubbleLength;
// Maximum token length (including bubble) // Maximum token length (including bubble)
SizeType mMaxTokenNum; SizeType mMaxTokenNum;
// Number of tokens in the sink blocks // Number of tokens in the sink blocks

View File

@ -25,6 +25,7 @@
#include <cassert> #include <cassert>
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <utility>
#include <vector> #include <vector>
namespace tensorrt_llm::batch_manager namespace tensorrt_llm::batch_manager
@ -38,7 +39,7 @@ enum LlmRequestState_t
REQUEST_STATE_GENERATION_COMPLETE = 3 REQUEST_STATE_GENERATION_COMPLETE = 3
}; };
template <typename TTensor> template <typename TTensor, typename TStream = runtime::BufferManager::CudaStreamPtr>
class GenericLlmRequest class GenericLlmRequest
{ {
public: public:
@ -49,9 +50,10 @@ public:
using VecLogProbs = std::vector<float>; using VecLogProbs = std::vector<float>;
using BeamTokens = std::vector<VecTokens>; using BeamTokens = std::vector<VecTokens>;
using TensorPtr = TTensor; using TensorPtr = TTensor;
using LogitsPostProcessor = std::function<TensorPtr(RequestIdType, TensorPtr&, BeamTokens const&, TStream)>;
GenericLlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<VecTokens> inputTokens, GenericLlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt, runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt, std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt, std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt, std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
@ -59,7 +61,8 @@ public:
std::optional<TensorPtr> loraConfig = std::nullopt, bool returnLogProbs = false, std::optional<TensorPtr> loraConfig = std::nullopt, bool returnLogProbs = false,
bool returnContextLogits = false, bool returnGenerationLogits = false, bool returnContextLogits = false, bool returnGenerationLogits = false,
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt, std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false) std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
: mRequestId(requestId) : mRequestId(requestId)
, mPromptLen(inputTokens->size()) , mPromptLen(inputTokens->size())
, mMaxNewTokens(maxNewTokens) , mMaxNewTokens(maxNewTokens)
@ -69,15 +72,16 @@ public:
, mEndId(endId) , mEndId(endId)
, mPadId(padId) , mPadId(padId)
, mSeqSlot(-1) , mSeqSlot(-1)
, mLogitsPostProcessor(logitsPostProcessor)
, mOrigPromptLen(mPromptLen) , mOrigPromptLen(mPromptLen)
, mMaxSentTokenPos(mPromptLen - 1) , mMaxSentTokenPos(mPromptLen - 1)
, mEmbeddingBias(embeddingBias) , mEmbeddingBias(std::move(embeddingBias))
, mBadWordsList(badWordsList) , mBadWordsList(std::move(badWordsList))
, mStopWordsList(stopWordsList) , mStopWordsList(std::move(stopWordsList))
, mPromptEmbeddingTable(promptEmbeddingTable) , mPromptEmbeddingTable(std::move(promptEmbeddingTable))
, mPromptVocabSize(promptVocabSize) , mPromptVocabSize(promptVocabSize)
, mLoraWeights(loraWeights) , mLoraWeights(std::move(loraWeights))
, mLoraConfig(loraConfig) , mLoraConfig(std::move(loraConfig))
, mReturnLogProbs(returnLogProbs) , mReturnLogProbs(returnLogProbs)
, mContextChunkSize(std::nullopt) , mContextChunkSize(std::nullopt)
, mContextCurrentPosition(0) , mContextCurrentPosition(0)
@ -198,14 +202,14 @@ public:
/// @brief Get total number of tokens for this req (prompt + generated) /// @brief Get total number of tokens for this req (prompt + generated)
/// @param beam The beam index /// @param beam The beam index
/// @return The number of tokens /// @return The number of tokens
SizeType getNumTokens(SizeType beam) const [[nodiscard]] SizeType getNumTokens(SizeType beam) const
{ {
return mTokens.at(beam).size(); return mTokens.at(beam).size();
} }
/// @brief Get max number of tokens across all beams /// @brief Get max number of tokens across all beams
/// @return The number of tokens /// @return The number of tokens
SizeType getMaxBeamNumTokens() const [[nodiscard]] SizeType getMaxBeamNumTokens() const
{ {
SizeType maxTokens = 0; SizeType maxTokens = 0;
for (SizeType beam = 0; beam < mSamplingConfig.beamWidth; ++beam) for (SizeType beam = 0; beam < mSamplingConfig.beamWidth; ++beam)
@ -219,7 +223,7 @@ public:
/// @param beam The beam index /// @param beam The beam index
/// @param pos The position of the token relative to beginning of the prompt /// @param pos The position of the token relative to beginning of the prompt
/// @return The token index /// @return The token index
TokenIdType getToken(SizeType beam, SizeType pos) const [[nodiscard]] TokenIdType getToken(SizeType beam, SizeType pos) const
{ {
return mTokens.at(beam).at(pos); return mTokens.at(beam).at(pos);
} }
@ -227,42 +231,42 @@ public:
/// @brief Get the tokens at a given beam index /// @brief Get the tokens at a given beam index
/// @param beam The beam index /// @param beam The beam index
/// @return A vector of tokens for this beam index, includes the prompt /// @return A vector of tokens for this beam index, includes the prompt
VecTokens const& getTokens(SizeType beam) const [[nodiscard]] VecTokens const& getTokens(SizeType beam) const
{ {
return mTokens.at(beam); return mTokens.at(beam);
} }
/// @brief Get all tokens (input+output) for all beams /// @brief Get all tokens (input+output) for all beams
/// @return A vector of vector of tokens. /// @return A vector of vector of tokens.
BeamTokens const& getTokens() const [[nodiscard]] BeamTokens const& getTokens() const
{ {
return mTokens; return mTokens;
} }
/// @brief Get the draft tokens /// @brief Get the draft tokens
/// @return shared_ptr to vector of draft tokens /// @return shared_ptr to vector of draft tokens
std::shared_ptr<VecTokens> const& getDraftTokens() const [[nodiscard]] std::shared_ptr<VecTokens> const& getDraftTokens() const
{ {
return mDraftTokens; return mDraftTokens;
} }
/// @brief Get the logits for the draft tokens /// @brief Get the logits for the draft tokens
/// @return Tensor of draft logits /// @return Tensor of draft logits
std::optional<TensorPtr> getDraftLogits() const [[nodiscard]] std::optional<TensorPtr> getDraftLogits() const
{ {
return mDraftLogits; return mDraftLogits;
} }
/// @brief Returns true if request has draft tokens /// @brief Returns true if request has draft tokens
/// @return flag /// @return flag
bool hasDraftTokens() const [[nodiscard]] bool hasDraftTokens() const
{ {
return mDraftTokens && mDraftTokens->size() > 0; return mDraftTokens && !mDraftTokens->empty();
} }
/// @brief Get the maximum number of generated tokens among all rays in beam /// @brief Get the maximum number of generated tokens among all rays in beam
/// @return The number of generated tokens (doesn't include the prompt tokens) /// @return The number of generated tokens (doesn't include the prompt tokens)
SizeType getMaxNumGeneratedTokens() const [[nodiscard]] SizeType getMaxNumGeneratedTokens() const
{ {
return getMaxBeamNumTokens() - mPromptLen; return getMaxBeamNumTokens() - mPromptLen;
} }
@ -283,14 +287,14 @@ public:
assert(static_cast<size_t>(mSamplingConfig.beamWidth) == beamTokens.size()); assert(static_cast<size_t>(mSamplingConfig.beamWidth) == beamTokens.size());
for (std::size_t beam = 0; beam < beamTokens.size(); ++beam) for (std::size_t beam = 0; beam < beamTokens.size(); ++beam)
{ {
const auto outputId = beamTokens[beam]; auto const outputId = beamTokens[beam];
mTokens.at(beam).push_back(outputId); mTokens.at(beam).push_back(outputId);
} }
} }
/// @brief Sets the generated tokens for all beams. Erases all previous generated tokens. /// @brief Sets the generated tokens for all beams. Erases all previous generated tokens.
/// @param generatedBeamTokens The generated tokens for all beams (vector of vector of tokens) /// @param generatedBeamTokens The generated tokens for all beams (vector of vector of tokens)
void setGeneratedTokens(const BeamTokens& generatedBeamTokens) void setGeneratedTokens(BeamTokens const& generatedBeamTokens)
{ {
assert(generatedBeamTokens.size() == static_cast<size_t>(mSamplingConfig.beamWidth)); assert(generatedBeamTokens.size() == static_cast<size_t>(mSamplingConfig.beamWidth));
for (std::size_t beam = 0; beam < generatedBeamTokens.size(); ++beam) for (std::size_t beam = 0; beam < generatedBeamTokens.size(); ++beam)
@ -346,7 +350,7 @@ public:
/// @brief Get the maximum position of the tokens returned to the client. Use to ensure we don't return to /// @brief Get the maximum position of the tokens returned to the client. Use to ensure we don't return to
/// client duplicated token positions. /// client duplicated token positions.
/// @return The maximum position of the tokens sent to the client /// @return The maximum position of the tokens sent to the client
SizeType getMaxSentTokenPos() const [[nodiscard]] SizeType getMaxSentTokenPos() const
{ {
return mMaxSentTokenPos; return mMaxSentTokenPos;
} }
@ -359,17 +363,17 @@ public:
mMaxSentTokenPos = pos; mMaxSentTokenPos = pos;
} }
std::optional<TensorPtr> getPromptEmbeddingTable() const [[nodiscard]] std::optional<TensorPtr> getPromptEmbeddingTable() const
{ {
return mPromptEmbeddingTable; return mPromptEmbeddingTable;
} }
std::optional<SizeType> getPromptVocabSize() const [[nodiscard]] std::optional<SizeType> getPromptVocabSize() const
{ {
return mPromptVocabSize; return mPromptVocabSize;
} }
std::optional<TensorPtr> getLoraWeights() const [[nodiscard]] std::optional<TensorPtr> getLoraWeights() const
{ {
return mLoraWeights; return mLoraWeights;
} }
@ -384,7 +388,7 @@ public:
mLoraWeights = std::nullopt; mLoraWeights = std::nullopt;
} }
std::optional<TensorPtr> getLoraConfig() const [[nodiscard]] std::optional<TensorPtr> getLoraConfig() const
{ {
return mLoraConfig; return mLoraConfig;
} }
@ -399,32 +403,37 @@ public:
mLoraConfig = std::nullopt; mLoraConfig = std::nullopt;
} }
std::optional<TensorPtr> getEmbeddingBias() const [[nodiscard]] std::optional<TensorPtr> getEmbeddingBias() const
{ {
return mEmbeddingBias; return mEmbeddingBias;
} }
std::optional<TensorPtr> getBadWordsList() const [[nodiscard]] std::optional<TensorPtr> getBadWordsList() const
{ {
return mBadWordsList; return mBadWordsList;
} }
std::optional<TensorPtr> getStopWordsList() const [[nodiscard]] std::optional<TensorPtr> getStopWordsList() const
{ {
return mStopWordsList; return mStopWordsList;
} }
bool returnLogProbs() const [[nodiscard]] bool returnLogProbs() const
{ {
return mReturnLogProbs; return mReturnLogProbs;
} }
std::vector<VecLogProbs> const& getLogProbs() const void setReturnLogProbs(bool returnLogProbs)
{
mReturnLogProbs = returnLogProbs;
}
[[nodiscard]] std::vector<VecLogProbs> const& getLogProbs() const
{ {
return mLogProbs; return mLogProbs;
} }
VecLogProbs const& getLogProbs(SizeType beam) const [[nodiscard]] VecLogProbs const& getLogProbs(SizeType beam) const
{ {
return mLogProbs.at(beam); return mLogProbs.at(beam);
} }
@ -435,7 +444,7 @@ public:
mLogProbs.at(beam).insert(mLogProbs.at(beam).end(), logProbs.begin(), logProbs.end()); mLogProbs.at(beam).insert(mLogProbs.at(beam).end(), logProbs.begin(), logProbs.end());
} }
VecLogProbs const& getCumLogProbs() const [[nodiscard]] VecLogProbs const& getCumLogProbs() const
{ {
return mCumLogProbs; return mCumLogProbs;
} }
@ -445,17 +454,17 @@ public:
mCumLogProbs.at(beam) = cumLogProb; mCumLogProbs.at(beam) = cumLogProb;
} }
SizeType getOrigPromptLen() const [[nodiscard]] SizeType getOrigPromptLen() const
{ {
return mOrigPromptLen; return mOrigPromptLen;
} }
void setDraftTokens(const std::shared_ptr<VecTokens>& draftTokens) void setDraftTokens(std::shared_ptr<VecTokens> const& draftTokens)
{ {
mDraftTokens = draftTokens; mDraftTokens = draftTokens;
} }
void setDraftLogits(const std::optional<TensorPtr>& draftLogits) void setDraftLogits(std::optional<TensorPtr> const& draftLogits)
{ {
mDraftLogits = draftLogits; mDraftLogits = draftLogits;
} }
@ -470,7 +479,7 @@ public:
mReturnContextLogits = returnContextLogits; mReturnContextLogits = returnContextLogits;
} }
bool getReturnContextLogits() const [[nodiscard]] bool getReturnContextLogits() const
{ {
return mReturnContextLogits; return mReturnContextLogits;
} }
@ -480,12 +489,12 @@ public:
mReturnGenerationLogits = returnGenerationLogits; mReturnGenerationLogits = returnGenerationLogits;
} }
bool getReturnGenerationLogits() const [[nodiscard]] bool getReturnGenerationLogits() const
{ {
return mReturnGenerationLogits; return mReturnGenerationLogits;
} }
TensorPtr const& getContextLogitsHost() const [[nodiscard]] TensorPtr const& getContextLogitsHost() const
{ {
return mContextLogitsHost; return mContextLogitsHost;
} }
@ -501,7 +510,7 @@ public:
runtime::ITensor::makeShape({mPromptLen, vocabSizePadded}), logitsDataType); runtime::ITensor::makeShape({mPromptLen, vocabSizePadded}), logitsDataType);
} }
TensorPtr const& getGenerationLogitsHost() const [[nodiscard]] TensorPtr const& getGenerationLogitsHost() const
{ {
return mGenerationLogitsHost; return mGenerationLogitsHost;
} }
@ -517,7 +526,7 @@ public:
runtime::ITensor::makeShape({mSamplingConfig.beamWidth, mMaxNewTokens, vocabSizePadded}), logitsDataType); runtime::ITensor::makeShape({mSamplingConfig.beamWidth, mMaxNewTokens, vocabSizePadded}), logitsDataType);
} }
std::vector<TensorPtr> const& getGenerationLogitsFragments() const [[nodiscard]] std::vector<TensorPtr> const& getGenerationLogitsFragments() const
{ {
return mGenerationLogitsFragments; return mGenerationLogitsFragments;
} }
@ -537,38 +546,38 @@ public:
mGenerationLogitsFragments.clear(); mGenerationLogitsFragments.clear();
} }
bool isContextInitState() const noexcept [[nodiscard]] bool isContextInitState() const noexcept
{ {
return mState == REQUEST_STATE_CONTEXT_INIT; return mState == REQUEST_STATE_CONTEXT_INIT;
} }
bool isGenerationInProgessState() const noexcept [[nodiscard]] bool isGenerationInProgessState() const noexcept
{ {
return mState == REQUEST_STATE_GENERATION_IN_PROGRESS; return mState == REQUEST_STATE_GENERATION_IN_PROGRESS;
} }
/// Determines whether the context is unchunked. Even when a context is split into a single chunk, it /// Determines whether the context is unchunked. Even when a context is split into a single chunk, it
/// still differs from the unchunked state, which indicates the initial status. /// still differs from the unchunked state, which indicates the initial status.
bool isFullContextRequest() const noexcept [[nodiscard]] bool isFullContextRequest() const noexcept
{ {
return isContextInitState() && !mContextChunkSize; return isContextInitState() && !mContextChunkSize;
} }
/// When chunked, the position of the current chunk is returned. Otherwise, only the beginning /// When chunked, the position of the current chunk is returned. Otherwise, only the beginning
/// or end of the context is returned. /// or end of the context is returned.
SizeType getContextCurrentPosition() const noexcept [[nodiscard]] SizeType getContextCurrentPosition() const noexcept
{ {
return mContextCurrentPosition; return mContextCurrentPosition;
} }
/// Return the length of the context that has not yet been processed. /// Return the length of the context that has not yet been processed.
SizeType getContextRemainingLength() const noexcept [[nodiscard]] SizeType getContextRemainingLength() const noexcept
{ {
return mPromptLen - getContextCurrentPosition(); return mPromptLen - getContextCurrentPosition();
} }
/// Retrieves the context chunk size; throws an exception when the context is not chunked. /// Retrieves the context chunk size; throws an exception when the context is not chunked.
SizeType getContextChunkSize() const [[nodiscard]] SizeType getContextChunkSize() const
{ {
TLLM_CHECK_WITH_INFO( TLLM_CHECK_WITH_INFO(
isContextInitState() && mContextChunkSize, "The current request is not in context chunking state."); isContextInitState() && mContextChunkSize, "The current request is not in context chunking state.");
@ -587,7 +596,7 @@ public:
/// Determines whether the current position is only one chunk away from the end of the context. /// Determines whether the current position is only one chunk away from the end of the context.
/// It will return true when the context is not chunked. /// It will return true when the context is not chunked.
bool isLastContextChunk() const noexcept [[nodiscard]] bool isLastContextChunk() const noexcept
{ {
return isFullContextRequest() return isFullContextRequest()
|| (isContextInitState() && getContextCurrentPosition() + getContextChunkSize() == mPromptLen); || (isContextInitState() && getContextCurrentPosition() + getContextChunkSize() == mPromptLen);
@ -595,7 +604,7 @@ public:
/// Returns whether the position is at the beginning of the context. It will return true when the /// Returns whether the position is at the beginning of the context. It will return true when the
/// context is not chunked. /// context is not chunked.
bool isFirstContextChunk() const noexcept [[nodiscard]] bool isFirstContextChunk() const noexcept
{ {
return isFullContextRequest() || getContextCurrentPosition() == 0; return isFullContextRequest() || getContextCurrentPosition() == 0;
} }
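A note on the chunked-context accessors above: they are easiest to read as a driver loop. The sketch below only illustrates how the positions relate to one another as described in the comments; it is not library code, and advance() is a hypothetical stand-in for whatever the batch manager actually calls to move to the next chunk.

    // Minimal illustration (not library code) of the chunked-context accessors shown above.
    template <typename RequestT>
    void runContextPhase(RequestT& req)
    {
        if (req.isFullContextRequest())
        {
            // Unchunked: the whole prompt goes through a single context step.
            // ... run one context step over req.getContextRemainingLength() tokens ...
            return;
        }
        while (req.getContextRemainingLength() > 0)
        {
            auto const chunkSize = req.getContextChunkSize(); // throws if the request is not chunked
            bool const last = req.isLastContextChunk();       // currentPosition + chunkSize == promptLen
            // ... run one context step over chunkSize tokens starting at req.getContextCurrentPosition() ...
            req.advance(chunkSize); // hypothetical helper; the real advancing interface is not shown in this hunk
            if (last)
            {
                break;
            }
        }
    }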
@ -704,6 +713,7 @@ public:
std::optional<SizeType> mEndId; std::optional<SizeType> mEndId;
std::optional<SizeType> mPadId; std::optional<SizeType> mPadId;
SizeType mSeqSlot; SizeType mSeqSlot;
std::optional<LogitsPostProcessor> mLogitsPostProcessor;
protected: protected:
SizeType mOrigPromptLen; SizeType mOrigPromptLen;
@ -805,7 +815,7 @@ public:
using VecTokens = Base::VecTokens; using VecTokens = Base::VecTokens;
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<VecTokens> inputTokens, LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt, runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt, std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt, std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt, std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
@ -813,10 +823,13 @@ public:
std::optional<TensorPtr> loraConfig = std::nullopt, bool returnLogProbs = false, std::optional<TensorPtr> loraConfig = std::nullopt, bool returnLogProbs = false,
bool returnContextLogits = false, bool returnGenerationLogits = false, bool returnContextLogits = false, bool returnGenerationLogits = false,
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt, std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false) std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false,
: Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId, embeddingBias, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
badWordsList, stopWordsList, promptEmbeddingTable, promptVocabSize, loraWeights, loraConfig, returnLogProbs, : Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
returnContextLogits, returnGenerationLogits, draftTokens, draftLogits, excludeInputFromOutput) std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList),
std::move(promptEmbeddingTable), promptVocabSize, std::move(loraWeights), std::move(loraConfig),
returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens), std::move(draftLogits),
excludeInputFromOutput, std::move(logitsPostProcessor))
{ {
} }
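Most of this hunk is mechanical: getters gain [[nodiscard]] so callers cannot silently discard the returned tensors, samplingConfig is now taken by const reference, and the forwarding constructor std::moves its by-value shared_ptr/optional parameters into the base class instead of copying them. A standalone sketch of the same idiom, with illustrative names rather than the library's:

    #include <memory>
    #include <optional>
    #include <utility>
    #include <vector>

    struct Base
    {
        Base(std::shared_ptr<std::vector<int>> tokens, std::optional<int> endId)
            : mTokens(std::move(tokens)) // move into the member: no extra refcount bump
            , mEndId(endId)
        {
        }

        // [[nodiscard]] makes it a warning to call the getter and ignore the result.
        [[nodiscard]] std::shared_ptr<std::vector<int>> const& getTokens() const
        {
            return mTokens;
        }

        std::shared_ptr<std::vector<int>> mTokens;
        std::optional<int> mEndId;
    };

    struct Derived : Base
    {
        // Take by value, then move into the base: at most one copy on the caller's side.
        Derived(std::shared_ptr<std::vector<int>> tokens, std::optional<int> endId = std::nullopt)
            : Base(std::move(tokens), endId)
        {
        }
    };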

View File

@ -302,6 +302,8 @@ public:
void allgather(const void* sendbuf, void* recvbuf, int count, MpiType dtype) const; void allgather(const void* sendbuf, void* recvbuf, int count, MpiType dtype) const;
void barrier() const; void barrier() const;
void mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;
bool operator==(MpiComm const& rhs) const bool operator==(MpiComm const& rhs) const
{ {
return mComm == rhs.mComm; return mComm == rhs.mComm;

View File

@ -46,8 +46,8 @@ public:
std::optional<FloatType> beamSearchDiversityRate = std::nullopt, std::optional<FloatType> beamSearchDiversityRate = std::nullopt,
std::optional<FloatType> repetitionPenalty = std::nullopt, std::optional<FloatType> repetitionPenalty = std::nullopt,
std::optional<FloatType> presencePenalty = std::nullopt, std::optional<FloatType> presencePenalty = std::nullopt,
std::optional<FloatType> frequencyPenalty = std::nullopt, std::optional<FloatType> frequencyPenalty = std::nullopt, std::optional<FloatType> lengthPenalty = std::nullopt,
std::optional<FloatType> lengthPenalty = std::nullopt); std::optional<SizeType> earlyStopping = std::nullopt);
~SamplingConfig(); ~SamplingConfig();
@ -65,6 +65,7 @@ public:
[[nodiscard]] std::optional<FloatType> getPresencePenalty() const; [[nodiscard]] std::optional<FloatType> getPresencePenalty() const;
[[nodiscard]] std::optional<FloatType> getFrequencyPenalty() const; [[nodiscard]] std::optional<FloatType> getFrequencyPenalty() const;
[[nodiscard]] std::optional<FloatType> getLengthPenalty() const; [[nodiscard]] std::optional<FloatType> getLengthPenalty() const;
[[nodiscard]] std::optional<SizeType> getEarlyStopping() const;
private: private:
SizeType mBeamWidth; SizeType mBeamWidth;
@ -81,6 +82,7 @@ private:
std::optional<FloatType> mPresencePenalty; std::optional<FloatType> mPresencePenalty;
std::optional<FloatType> mFrequencyPenalty; std::optional<FloatType> mFrequencyPenalty;
std::optional<FloatType> mLengthPenalty; std::optional<FloatType> mLengthPenalty;
std::optional<SizeType> mEarlyStopping;
}; };
/// @brief Configuration that controls the outputs of a Result /// @brief Configuration that controls the outputs of a Result

View File

@ -22,9 +22,7 @@
#include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/iTensor.h"
#include <NvInferRuntime.h> #include <NvInferRuntime.h>
#include <cstdint>
#include <memory> #include <memory>
#include <stdexcept>
#include <string> #include <string>
#include <vector> #include <vector>

View File

@ -16,7 +16,6 @@
#pragma once #pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/iTensor.h"

View File

@ -16,7 +16,6 @@
#pragma once #pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/iTensor.h"

View File

@ -16,7 +16,6 @@
#pragma once #pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/promptTuningParams.h" #include "tensorrt_llm/runtime/promptTuningParams.h"

View File

@ -16,7 +16,6 @@
#pragma once #pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/iTensor.h"

View File

@ -16,7 +16,6 @@
#pragma once #pragma once
#include "tensorrt_llm/common/cudaAllocator.h"
#include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/decodingInput.h" #include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingMode.h" #include "tensorrt_llm/runtime/decodingMode.h"
@ -24,7 +23,6 @@
#include "tensorrt_llm/runtime/samplingConfig.h" #include "tensorrt_llm/runtime/samplingConfig.h"
#include <curand_kernel.h> #include <curand_kernel.h>
#include <cstdint>
#include <memory> #include <memory>
#include <NvInferRuntime.h> #include <NvInferRuntime.h>

View File

@ -16,7 +16,6 @@
#pragma once #pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaEvent.h" #include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h" #include "tensorrt_llm/runtime/cudaStream.h"
@ -25,11 +24,7 @@
#include "tensorrt_llm/runtime/iGptDecoderBatch.h" #include "tensorrt_llm/runtime/iGptDecoderBatch.h"
#include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/iTensor.h"
#include <cstdint>
#include <memory> #include <memory>
#include <optional>
#include <tuple>
#include <utility>
#include <vector> #include <vector>
namespace tensorrt_llm::runtime namespace tensorrt_llm::runtime
@ -58,7 +53,7 @@ public:
TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) override; TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) override;
void forwardSync(decoder_batch::Token const& e) override; void forwardSync(decoder_batch::Token const& token) override;
void forwardAsync(decoder::Output& output, decoder::Input const& input) override; void forwardAsync(decoder::Output& output, decoder::Input const& input) override;
@ -89,7 +84,7 @@ public:
//! @brief Gather final beam search results for request `batchIdx`. //! @brief Gather final beam search results for request `batchIdx`.
//! Result will only be available after event returned. //! Result will only be available after event returned.
[[nodiscard]] CudaEvent finalize(SizeType batchIdx) const; [[nodiscard]] CudaEvent finalize(SizeType batchIdx) const override;
//! @brief Gather final beam search results for all requests. //! @brief Gather final beam search results for all requests.
void finalize() const override; void finalize() const override;
@ -108,7 +103,7 @@ public:
} }
//! @returns [maxBeamWidth], cumulative log probabilities (per beam), on gpu //! @returns [maxBeamWidth], cumulative log probabilities (per beam), on gpu
[[nodiscard]] TensorPtr getCumLogProbs(SizeType batchIdx) const [[nodiscard]] TensorPtr getCumLogProbs(SizeType batchIdx) const override
{ {
auto tensor = ITensor::slice(mJointDecodingOutput->cumLogProbs, batchIdx, 1); auto tensor = ITensor::slice(mJointDecodingOutput->cumLogProbs, batchIdx, 1);
tensor->squeeze(0); tensor->squeeze(0);
@ -122,7 +117,7 @@ public:
} }
//! @returns [maxBeamWidth, maxSequenceLength], log probabilities (per beam), on gpu //! @returns [maxBeamWidth, maxSequenceLength], log probabilities (per beam), on gpu
[[nodiscard]] TensorPtr getLogProbs(SizeType batchIdx) const [[nodiscard]] TensorPtr getLogProbs(SizeType batchIdx) const override
{ {
auto tensor = ITensor::slice(mJointDecodingOutput->logProbs, batchIdx, 1); auto tensor = ITensor::slice(mJointDecodingOutput->logProbs, batchIdx, 1);
tensor->squeeze(0); tensor->squeeze(0);
@ -141,7 +136,7 @@ public:
//! @returns [batchSize, beamWidth], tokens generated in `iter` (per beam), on gpu //! @returns [batchSize, beamWidth], tokens generated in `iter` (per beam), on gpu
[[nodiscard]] TensorPtr getNewTokens(SizeType iter = 0) const override [[nodiscard]] TensorPtr getNewTokens(SizeType iter = 0) const override
{ {
TensorPtr newTokensView = std::move(ITensor::slice(mJointDecodingOutput->newTokensSteps, iter, 1)); TensorPtr newTokensView = ITensor::slice(mJointDecodingOutput->newTokensSteps, iter, 1);
newTokensView->squeeze(0); newTokensView->squeeze(0);
return ITensor::slice(newTokensView, 0, mActualBatchSize); return ITensor::slice(newTokensView, 0, mActualBatchSize);
} }
@ -149,7 +144,7 @@ public:
//! @returns [batchSize], the number of generation steps executed on each request //! @returns [batchSize], the number of generation steps executed on each request
[[nodiscard]] std::vector<SizeType> getNbSteps() const override [[nodiscard]] std::vector<SizeType> getNbSteps() const override
{ {
return std::vector<SizeType>(mNbSteps.begin(), mNbSteps.begin() + mActualBatchSize); return {mNbSteps.begin(), mNbSteps.begin() + mActualBatchSize};
} }
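The two small changes above are general C++ cleanups: wrapping a function's return value (a prvalue) in std::move is redundant, and the braced return in getNbSteps lets the vector be constructed in place from the iterator pair. A standalone illustration of the std::move point, with placeholder types rather than the runtime's:

    #include <memory>
    #include <utility>
    #include <vector>

    using TensorPtr = std::shared_ptr<std::vector<float>>;

    TensorPtr makeTensor()
    {
        return std::make_shared<std::vector<float>>(16);
    }

    void example()
    {
        // The std::move buys nothing here: it turns the prvalue into an xvalue and
        // move-constructs `a` from a materialized temporary...
        TensorPtr a = std::move(makeTensor());
        // ...whereas without it the prvalue initializes `b` directly (guaranteed
        // copy elision since C++17), so this form is both simpler and no more expensive.
        TensorPtr b = makeTensor();
        (void) a;
        (void) b;
    }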
//! @returns [1], number of finished sequences, in pinned host memory //! @returns [1], number of finished sequences, in pinned host memory
@ -160,7 +155,7 @@ public:
private: private:
//! @brief Gather final beam search results for request `batchIdx`. //! @brief Gather final beam search results for request `batchIdx`.
CudaEvent postProcessRequest(SizeType batchIdx) const; [[nodiscard]] CudaEvent postProcessRequest(SizeType batchIdx) const;
//! @brief Initialize the decoder at `batchIdx` with a new `request`. //! @brief Initialize the decoder at `batchIdx` with a new `request`.
void newRequest(SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig); void newRequest(SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);

View File

@ -31,7 +31,6 @@
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <memory> #include <memory>
#include <ostream> #include <ostream>
#include <stdexcept>
#include <type_traits> #include <type_traits>
#include <typeinfo> #include <typeinfo>
#include <vector> #include <vector>

View File

@ -16,14 +16,12 @@
#pragma once #pragma once
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaEvent.h" #include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h" #include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h" #include "tensorrt_llm/runtime/iStatefulGptDecoder.h"
#include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/utils/sessionUtils.h" #include "tensorrt_llm/runtime/utils/sessionUtils.h"
#include <cstdint>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -156,30 +154,30 @@ public:
//! @param batchIdx index of the batch //! @param batchIdx index of the batch
//! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token //! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
//! ids without padding for request `batchIdx`, on gpu //! ids without padding for request `batchIdx`, on gpu
virtual TensorPtr getOutputIds(SizeType batchIdx) const = 0; [[nodiscard]] virtual TensorPtr getOutputIds(SizeType batchIdx) const = 0;
//! @brief Gather final beam search results for request `batchIdx`. //! @brief Gather final beam search results for request `batchIdx`.
//! Result will only be available after event returned //! Result will only be available after event returned
virtual CudaEvent finalize(SizeType batchIdx) const = 0; [[nodiscard]] virtual CudaEvent finalize(SizeType batchIdx) const = 0;
//! @returns [batchSize (actual)], marks finished requests (per batch) //! @returns [batchSize (actual)], marks finished requests (per batch)
virtual std::vector<bool> getFinished() const = 0; [[nodiscard]] virtual std::vector<bool> getFinished() const = 0;
//! @returns [batchSize, beamWidth], cumulative log probabilities (per beam), on gpu //! @returns [batchSize, beamWidth], cumulative log probabilities (per beam), on gpu
virtual TensorPtr getCumLogProbs() const = 0; [[nodiscard]] virtual TensorPtr getCumLogProbs() const = 0;
//! @returns [beamWidth], cumulative log probabilities (per beam) for request batchIdx, on gpu //! @returns [beamWidth], cumulative log probabilities (per beam) for request batchIdx, on gpu
virtual TensorPtr getCumLogProbs(SizeType batchIdx) const = 0; [[nodiscard]] virtual TensorPtr getCumLogProbs(SizeType batchIdx) const = 0;
//! @returns [batchSize, beamWidth, maxSeqLen], log probabilities (per beam), on gpu //! @returns [batchSize, beamWidth, maxSeqLen], log probabilities (per beam), on gpu
virtual TensorPtr getLogProbs() const = 0; [[nodiscard]] virtual TensorPtr getLogProbs() const = 0;
//! @returns [beamWidth, maxSeqLen], cumulative log probabilities (per beam) for request batchIdx, on gpu //! @returns [beamWidth, maxSeqLen], cumulative log probabilities (per beam) for request batchIdx, on gpu
virtual TensorPtr getLogProbs(SizeType batchIdx) const = 0; [[nodiscard]] virtual TensorPtr getLogProbs(SizeType batchIdx) const = 0;
virtual TensorPtr getParentIds() const = 0; [[nodiscard]] virtual TensorPtr getParentIds() const = 0;
virtual std::vector<SizeType> getNbSteps() const = 0; [[nodiscard]] virtual std::vector<SizeType> getNbSteps() const = 0;
//! @brief Initialize the batched decoder at `seqSlots` with new `requests`. //! @brief Initialize the batched decoder at `seqSlots` with new `requests`.
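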
virtual void newRequests(std::vector<SizeType> const& seqSlots, std::vector<decoder_batch::Request> const& requests, virtual void newRequests(std::vector<SizeType> const& seqSlots, std::vector<decoder_batch::Request> const& requests,

View File

@ -23,10 +23,8 @@
#include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/samplingConfig.h" #include "tensorrt_llm/runtime/samplingConfig.h"
#include <cstdint>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector>
#include <NvInferRuntime.h> #include <NvInferRuntime.h>
@ -102,25 +100,25 @@ public:
virtual void finalize() const = 0; virtual void finalize() const = 0;
//! @returns [batchSize, beamWidth, maxSequenceLength], all token ids, on gpu //! @returns [batchSize, beamWidth, maxSequenceLength], all token ids, on gpu
virtual TensorPtr getOutputIds() const = 0; [[nodiscard]] virtual TensorPtr getOutputIds() const = 0;
//! @returns [batchSize, maxBeamWidth], cumulative log probabilities (per beam), on gpu //! @returns [batchSize, maxBeamWidth], cumulative log probabilities (per beam), on gpu
virtual TensorPtr getCumLogProbs() const = 0; [[nodiscard]] virtual TensorPtr getCumLogProbs() const = 0;
//! @returns [batchSize, maxBeamWidth, maxSequenceLength], log probabilities (per beam), on gpu //! @returns [batchSize, maxBeamWidth, maxSequenceLength], log probabilities (per beam), on gpu
virtual TensorPtr getLogProbs() const = 0; [[nodiscard]] virtual TensorPtr getLogProbs() const = 0;
//! @brief Get tokens generated in one step of last forward pass //! @brief Get tokens generated in one step of last forward pass
//! @param iter The iteration within [0; maxTokensPerStep) for which to get the tokens //! @param iter The iteration within [0; maxTokensPerStep) for which to get the tokens
//! @returns [batchSize, beamWidth], tokens generated in `iter` (per beam), on gpu //! @returns [batchSize, beamWidth], tokens generated in `iter` (per beam), on gpu
virtual TensorPtr getNewTokens(SizeType iter = 0) const = 0; [[nodiscard]] virtual TensorPtr getNewTokens(SizeType iter = 0) const = 0;
//! @brief Get maxTokensPerStep tokens generated in the last forward pass //! @brief Get maxTokensPerStep tokens generated in the last forward pass
//! @returns [maxTokensPerStep, batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu //! @returns [maxTokensPerStep, batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
virtual TensorPtr getAllNewTokens() const = 0; [[nodiscard]] virtual TensorPtr getAllNewTokens() const = 0;
//! @returns [1], number of finished sequences, in pinned host memory //! @returns [1], number of finished sequences, in pinned host memory
virtual TensorPtr getNbFinished() const = 0; [[nodiscard]] virtual TensorPtr getNbFinished() const = 0;
virtual ~IStatefulGptDecoder() = default; virtual ~IStatefulGptDecoder() = default;

View File

@ -30,7 +30,6 @@
#include <memory> #include <memory>
#include <numeric> #include <numeric>
#include <ostream> #include <ostream>
#include <stdexcept>
#include <string> #include <string>
#include <type_traits> #include <type_traits>

View File

@ -18,7 +18,6 @@
#pragma once #pragma once
#include "tensorrt_llm/kernels/customAllReduceKernels.h" #include "tensorrt_llm/kernels/customAllReduceKernels.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/worldConfig.h" #include "tensorrt_llm/runtime/worldConfig.h"

View File

@ -19,8 +19,8 @@
#include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/runtime/iBuffer.h" #include "tensorrt_llm/runtime/iBuffer.h"
#include <algorithm> #include <atomic>
#include <cstdint> #include <cstddef>
#include <string> #include <string>
namespace tensorrt_llm::runtime namespace tensorrt_llm::runtime
@ -112,22 +112,22 @@ public:
auto const sizeDiff = -static_cast<DiffType>(size); auto const sizeDiff = -static_cast<DiffType>(size);
if constexpr (T == MemoryType::kGPU) if constexpr (T == MemoryType::kGPU)
{ {
mGpu -= std::min(size, mGpu); mGpu -= size;
mGpuDiff = sizeDiff; mGpuDiff = sizeDiff;
} }
else if constexpr (T == MemoryType::kCPU) else if constexpr (T == MemoryType::kCPU)
{ {
mCpu -= std::min(size, mCpu); mCpu -= size;
mCpuDiff = sizeDiff; mCpuDiff = sizeDiff;
} }
else if constexpr (T == MemoryType::kPINNED) else if constexpr (T == MemoryType::kPINNED)
{ {
mPinned -= std::min(size, mPinned); mPinned -= size;
mPinnedDiff = sizeDiff; mPinnedDiff = sizeDiff;
} }
else if constexpr (T == MemoryType::kUVM) else if constexpr (T == MemoryType::kUVM)
{ {
mUVM -= std::min(size, mUVM); mUVM -= size;
mUVMDiff = sizeDiff; mUVMDiff = sizeDiff;
} }
else else
@ -138,11 +138,7 @@ public:
void deallocate(MemoryType memoryType, SizeType size); void deallocate(MemoryType memoryType, SizeType size);
static MemoryCounters& getInstance() static MemoryCounters& getInstance();
{
static thread_local MemoryCounters mInstance;
return mInstance;
}
static std::string bytesToString(SizeType bytes, int precision = 2); static std::string bytesToString(SizeType bytes, int precision = 2);
@ -151,8 +147,8 @@ public:
[[nodiscard]] std::string toString() const; [[nodiscard]] std::string toString() const;
private: private:
SizeType mGpu{}, mCpu{}, mPinned{}, mUVM{}; std::atomic<SizeType> mGpu{}, mCpu{}, mPinned{}, mUVM{};
DiffType mGpuDiff{}, mCpuDiff{}, mPinnedDiff{}, mUVMDiff{}; std::atomic<DiffType> mGpuDiff{}, mCpuDiff{}, mPinnedDiff{}, mUVMDiff{};
}; };
} // namespace tensorrt_llm::runtime } // namespace tensorrt_llm::runtime
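This hunk moves getInstance() out of the header (dropping the thread_local static, presumably so there is a single process-wide counter object) and makes the counters std::atomic; the std::min clamping is dropped, which makes sense once the members are atomic, since a read-clamp-subtract sequence could not be done as one atomic operation anyway. A reduced sketch of the same pattern, with illustrative names:

    #include <atomic>
    #include <cstddef>

    class Counters
    {
    public:
        static Counters& getInstance()
        {
            // One instance shared by all threads; the atomic members make updates safe.
            static Counters instance;
            return instance;
        }

        void allocateGpu(std::size_t size)
        {
            mGpu += size; // std::atomic::operator+= is a single atomic read-modify-write
        }

        void deallocateGpu(std::size_t size)
        {
            mGpu -= size; // no clamping: paired allocate/deallocate calls are assumed
        }

        [[nodiscard]] std::size_t gpu() const
        {
            return mGpu.load();
        }

    private:
        std::atomic<std::size_t> mGpu{0};
    };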

View File

@ -16,12 +16,10 @@
#pragma once #pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/iTensor.h"
#include <optional>
#include <utility> #include <utility>
namespace tensorrt_llm::runtime namespace tensorrt_llm::runtime

View File

@ -72,6 +72,7 @@ public:
{ {
TLLM_CHECK(configs.size() > 0); TLLM_CHECK(configs.size() > 0);
beamWidth = configs.front().beamWidth; beamWidth = configs.front().beamWidth;
normalizeLogProbs = configs.front().normalizeLogProbs;
temperature = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].temperature; }); temperature = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].temperature; });
minLength = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].minLength; }); minLength = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].minLength; });
repetitionPenalty repetitionPenalty
@ -87,6 +88,7 @@ public:
beamSearchDiversityRate beamSearchDiversityRate
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].beamSearchDiversityRate; }); = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].beamSearchDiversityRate; });
lengthPenalty = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].lengthPenalty; }); lengthPenalty = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].lengthPenalty; });
earlyStopping = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].earlyStopping; });
draftAcceptanceThreshold draftAcceptanceThreshold
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].draftAcceptanceThreshold; }); = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].draftAcceptanceThreshold; });
} }
@ -121,6 +123,7 @@ public:
SET_FROM_OPTIONAL(presencePenalty, PresencePenalty, FloatType) SET_FROM_OPTIONAL(presencePenalty, PresencePenalty, FloatType)
SET_FROM_OPTIONAL(frequencyPenalty, FrequencyPenalty, FloatType) SET_FROM_OPTIONAL(frequencyPenalty, FrequencyPenalty, FloatType)
SET_FROM_OPTIONAL(lengthPenalty, LengthPenalty, FloatType) SET_FROM_OPTIONAL(lengthPenalty, LengthPenalty, FloatType)
SET_FROM_OPTIONAL(earlyStopping, EarlyStopping, SizeType)
#undef SET_FROM_OPTIONAL #undef SET_FROM_OPTIONAL
} }
@ -142,11 +145,12 @@ public:
OptVec<SizeType> topPResetIds; // [batch_size] OptVec<SizeType> topPResetIds; // [batch_size]
// beam search layer // beam search layer
OptVec<FloatType> beamSearchDiversityRate; OptVec<FloatType> beamSearchDiversityRate; // [1] or [batch_size]
OptVec<FloatType> lengthPenalty; OptVec<FloatType> lengthPenalty; // [1] or [batch_size]
OptVec<SizeType> earlyStopping; // [1] or [batch_size]
// speculative decoding // speculative decoding, only the first value is used (in gptDecoderBatch.cpp)
OptVec<FloatType> draftAcceptanceThreshold; OptVec<FloatType> draftAcceptanceThreshold; // [1] or [batch_size]
std::optional<bool> normalizeLogProbs; std::optional<bool> normalizeLogProbs;
}; };
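The new comments spell out the fusing convention: each optional per-request value collapses to a single entry when every request in the batch agrees and stays a full [batch_size] vector otherwise, and earlyStopping now follows the same rule as lengthPenalty. The helper below is a simplified, hypothetical illustration of that "[1] or [batch_size]" convention; it is not the fuseValues implementation used in this header.

    #include <optional>
    #include <vector>

    // Hypothetical illustration of the "[1] or [batch_size]" convention noted above.
    template <typename T>
    std::optional<std::vector<T>> fuseIllustration(std::vector<std::optional<T>> const& perRequest, T defaultValue)
    {
        bool anySet = false;
        std::vector<T> values;
        values.reserve(perRequest.size());
        for (auto const& v : perRequest)
        {
            anySet = anySet || v.has_value();
            values.push_back(v.value_or(defaultValue));
        }
        if (!anySet)
        {
            return std::nullopt; // no request asked for it: leave the fused field empty
        }
        bool allEqual = true;
        for (auto const& v : values)
        {
            allEqual = allEqual && (v == values.front());
        }
        if (allEqual)
        {
            return std::vector<T>{values.front()}; // [1]
        }
        return values; // [batch_size]
    }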

View File

@ -195,7 +195,8 @@ set(TRTLLM_LINK_LIBS
${TRT_LIB} ${TRT_LIB}
common_src common_src
kernels_src kernels_src
cutlass_src cutlass_src_pre_hopper
cutlass_src_hopper
layers_src layers_src
runtime_src) runtime_src)

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:4ba61c04ed7623fc44b5364802c1893fa824467455f4a9fe8245d5d51fef97e6 oid sha256:c9fd644e0a38b1d4d1a54d4b7b834cc6b0110a5771fcfc480e96795b3f9bc892
size 2172266 size 2081046

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:bf4afdfd281029c8e4bf0af548529b94a4a6d0f9bb5148ae10423e5e0275db06 oid sha256:90436c59eb243a0156e3f0aa95412a7caacbefdcde768c158edc4b821044dfd1
size 2195822 size 2102486

View File

@ -1,3 +1,3 @@
4c405d39a0cbb93d44a5758480a1a223 libtensorrt_llm_batch_manager_static.a f53c02e3829b516a6e9221745bcbacbd libtensorrt_llm_batch_manager_static.a
68aea75a2ed5b219eec5a0f77ce33482 libtensorrt_llm_batch_manager_static.pre_cxx11.a 9e92e5dbb104e3e676952ea40c81916f libtensorrt_llm_batch_manager_static.pre_cxx11.a
9b63c754d2a1edc7a17106e83c3e131d312f0a80 commit 25adff90cc350eb9ca9804051a08de80d547c113 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:39835ca321e9c45d3b554ebceb1734966b75f83dbe8c550cc44846fb4fae8f72 oid sha256:c3433d7b52bb6dcac32111172cb6201a9fee56e739f3660895083baebd1b89ee
size 2110728 size 2033616

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:789c2eba349161e84a76b95b23f8294cf3bdcf855871672d76722c4ae858d81b oid sha256:fb3f4145881984de6268c34f7e5d452f78f54952f454f747a1cd52bc3171de62
size 2091842 size 2012002

View File

@ -1,2 +1,2 @@
30a6c963121b3cfda21dc0117b7984e1 libtensorrt_llm_batch_manager_static.a d60b12741e940f56addaf2d92e78b50f libtensorrt_llm_batch_manager_static.a
0d2d2e3157201f6336d749b3e6f994bc libtensorrt_llm_batch_manager_static.pre_cxx11.a c55e606a3430d3a56cee3968a77b46f1 libtensorrt_llm_batch_manager_static.pre_cxx11.a

View File

@ -172,6 +172,11 @@ void MpiComm::allgather(const void* sendbuf, void* recvbuf, int count, MpiType d
MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm)); MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm));
} }
void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
{
MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status));
}
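MpiComm::mprobe is a thin wrapper over MPI_Mprobe, which, unlike MPI_Probe, returns an MPI_Message handle so the subsequent receive is guaranteed to consume exactly the probed message even when several threads receive on the same communicator. The standalone sketch below shows the usual matched probe-then-receive pattern in raw MPI; buffer sizing and the tag are illustrative only.

    #include <mpi.h>
    #include <cstdint>
    #include <vector>

    // Receive a message of unknown size from `source` on `comm` using a matched probe.
    std::vector<int64_t> recvPacked(MPI_Comm comm, int source, int tag)
    {
        MPI_Message msg;
        MPI_Status status;
        MPI_Mprobe(source, tag, comm, &msg, &status);

        int count = 0;
        MPI_Get_count(&status, MPI_INT64_T, &count);

        std::vector<int64_t> buffer(count);
        // MPI_Mrecv consumes the matched message, so no other probe/recv can steal it.
        MPI_Mrecv(buffer.data(), count, MPI_INT64_T, &msg, &status);
        return buffer;
    }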
int MpiComm::getRank() const int MpiComm::getRank() const
{ {
int rank = 0; int rank = 0;

View File

@ -34,6 +34,8 @@ enum class CutlassTileConfig
CtaShape128x128x8_WarpShape64x64x8, CtaShape128x128x8_WarpShape64x64x8,
// TensorCore configs CTA_N = 128, CTA_K = 64 // TensorCore configs CTA_N = 128, CTA_K = 64
// Warp configs for M=16
CtaShape16x128x64_WarpShape16x32x64,
// Warp configs for M=32 // Warp configs for M=32
CtaShape32x128x64_WarpShape32x32x64, CtaShape32x128x64_WarpShape32x32x64,
@ -50,7 +52,10 @@ enum class CutlassTileConfig
CtaShape128x256x64_WarpShape64x64x64, CtaShape128x256x64_WarpShape64x64x64,
// Warp configs for M=256 // Warp configs for M=256
CtaShape256x128x64_WarpShape64x64x64 CtaShape256x128x64_WarpShape64x64x64,
// TensorCore config CTA_N = 256, CTA_K = 64
CtaShape16x256x64_WarpShape16x64x64
}; };
enum class SplitKStyle enum class SplitKStyle

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:49c84a22cee9e6c3a975db08d8d0d8dbe88867e2eb4fc12a4b3ff6c1c90e8c21 oid sha256:13e17e2d9a94d2bc1b131d096a3722a83a67ab115fa8271b57b27f7e2877bdc1
size 586202 size 587334

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:fea93ae7d09e74b073a65d5d0ac34aec9ccc8f8299af1abd6826e97e9c8427f4 oid sha256:45438204eba812694bd30b68cfc9bb2bc54a8a59c6c86e037bbc4ac7e5f8230c
size 589652 size 589438

View File

@ -1,3 +1,3 @@
73999f4c2b3a4328db454b7ab6fe86d3 libtensorrt_llm_executor_static.a 835767a37292ea9786c0d6149ae270f4 libtensorrt_llm_executor_static.a
df53aa83848b5ed75550a7b536ca02a4 libtensorrt_llm_executor_static.pre_cxx11.a 1fe0c9ac7a1a35ce7d80676146867374 libtensorrt_llm_executor_static.pre_cxx11.a
9b63c754d2a1edc7a17106e83c3e131d312f0a80 commit 25adff90cc350eb9ca9804051a08de80d547c113 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:643e546711fd33a85073560e3428c6a2f60525f7592aa3328043dfad61631c30 oid sha256:7969768d3b9a65182ee519c60e11f27b0a088c2c0b732f3780d7c0c563dbb180
size 586532 size 587776

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:28059131a9325c88bd362cb12c57a2b2e47d3e0aac140e5d1cf9a7020a81999e oid sha256:98d9b7c4a586f0be0499a0df487cacba69985ce43ca5fd543c90c6a368c91b67
size 570860 size 571150

View File

@ -1,2 +1,2 @@
fa89714705a1915f052c635a07dc4c73 libtensorrt_llm_executor_static.a 6771f94e0bce39c6cab391cf1f92484c libtensorrt_llm_executor_static.a
83cbfaf10bedd7d8edeab33552dcf3df libtensorrt_llm_executor_static.pre_cxx11.a 84b7550448f8710de17644a5d404178f libtensorrt_llm_executor_static.pre_cxx11.a

View File

@ -28,7 +28,7 @@ if(FAST_BUILD)
"decoderMaskedMultiheadAttention(48|80|96|112|144|160|192|224).*cu$") "decoderMaskedMultiheadAttention(48|80|96|112|144|160|192|224).*cu$")
endif() endif()
add_library(kernels_src OBJECT ${SRC_CPP} ${SRC_CU}) add_library(kernels_src STATIC ${SRC_CPP} ${SRC_CU})
set_property(TARGET kernels_src PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET kernels_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET kernels_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) set_property(TARGET kernels_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

View File

@ -39,7 +39,7 @@ namespace kernels
template <typename T> template <typename T>
__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty) __device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty)
{ {
// score = log(prob) / (length)^length_penalty. // score = log(prob) / (length ^ length_penalty)
if (length_penalty == 0.0f || length == 1) if (length_penalty == 0.0f || length == 1)
{ {
return log_prob; return log_prob;
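The corrected comment states the beam-score normalization: score = log_prob / (length ^ length_penalty), skipped when length_penalty == 0 or length == 1. A tiny host-side check of what that normalization does (the values are made up; the real function runs on the device):

    #include <cmath>
    #include <cstdio>

    float applyLengthPenaltyHost(float logProb, int length, float lengthPenalty)
    {
        if (lengthPenalty == 0.0f || length == 1)
        {
            return logProb;
        }
        return logProb / std::pow(static_cast<float>(length), lengthPenalty);
    }

    int main()
    {
        // With lengthPenalty = 1, the cumulative log-probability is divided by the length,
        // so hypotheses of different lengths become comparable per token.
        std::printf("%f\n", applyLengthPenaltyHost(-10.0f, 10, 1.0f)); // -1.0
        std::printf("%f\n", applyLengthPenaltyHost(-20.0f, 20, 1.0f)); // -1.0
        return 0;
    }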

View File

@ -35,35 +35,42 @@ namespace kernels
// After we collect `beam_width` beams, we will sort them by their norm_scores. // After we collect `beam_width` beams, we will sort them by their norm_scores.
struct BeamHypotheses struct BeamHypotheses
{ {
int* output_ids_tgt = nullptr; // TODO: simplify the pointers
int* sequence_lengths_tgt = nullptr; // Pointers initialized in function prepareOutputs in gptDecoder.cpp
float* cum_log_probs = nullptr; // cum_log bool* is_done{nullptr}; // [batchSize], whether the batch is finished
float* normed_scores = nullptr; // cum_log / (length**length_penalty) const int* input_lengths{nullptr}; // [batchSize]
float* log_probs = nullptr; // log probs of each generated token float* cum_log_probs{nullptr}; // [batchSize, 2 * beamWidth], outputs.cum_log_probs->template getPtr<float>()
float* min_normed_scores = nullptr; // record the min normed scores for each batch float* log_probs{nullptr}; // [batchSize, 2 * beamWidth, maxSeqLen], not used?
int* num_beams = nullptr; // the number of finished beams we collect float* min_normed_scores{nullptr}; // [batchSize], worst normed scores for each batch
bool* is_done = nullptr; float* normed_scores{nullptr}; // [batchSize, 2 * beamWidth], cum_log / (length ^ length_penalty)
int* num_beams{nullptr}; // [batchSize], count of finished beams for each batch
int* output_ids_tgt{nullptr}; // [batchSize, 2 * beamWidth, maxSeqLen],
int* sequence_lengths_tgt{nullptr}; // [batchSize, 2 * beamWidth], different from sequence_lengths_src
// Used to set inputs // Pointers initialized in function invokeSoftMax in onlineBeamSearchLayer.cu
const int* output_ids_src; const int* end_ids{nullptr}; // get from SoftmaxParams
const int** output_ids_src_ptr; const int* output_ids_src{nullptr}; // for gatherTree
const int* parent_ids_src; const int* parent_ids_src{nullptr}; // for gatherTree
const int** parent_ids_src_ptr; const int** output_ids_src_ptr{nullptr}; // get from BeamSearchOutputParams for reading
const int* sequence_lengths_src; const int** parent_ids_src_ptr{nullptr}; // get from BeamSearchOutputParams for reading
const int* end_ids; float* log_probs_src{nullptr}; // get from outputs.output_log_probs
const float* log_probs_src; int* sequence_lengths_src{nullptr}; // get from BeamSearchOutputParams
const int* input_lengths; // For reading in function invokeTopkSoftMax but reading and writing in function invokeUpdate
int** output_ids_tgt_ptr{nullptr}; // get from BeamSearchOutputParams for writing
int** parent_ids_tgt_ptr{nullptr}; // get from BeamSearchOutputParams for writing
// some variables for kernels // Other scalar values and buffers
int step; int batch_size{0};
int ite; int beam_width{0};
int batch_size; int ite{0};
int local_batch_size; int local_batch_size{0};
int max_seq_len; int max_seq_len{0};
float* length_penalties; int step{0}; // useless in online version of beam search
int vocab_size{0};
bool early_stopping = true; float* diversity_rates{nullptr};
bool is_return_normed_score = true; // return normed_cum_log_probs or cum_log_probs float* length_penalties{nullptr};
int* early_stoppings{nullptr};
bool is_return_normed_score{true}; // return normed_cum_log_probs or cum_log_probs
}; };
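The reworked comments describe BeamHypotheses as a per-batch pool of up to 2 * beamWidth finished candidates ranked by normed_scores, with min_normed_scores tracking the current worst entry and num_beams counting how many have been collected. The sketch below is only a rough host-side restatement of that bookkeeping to make the field comments concrete; the actual logic lives in the beam-search CUDA kernels and differs in detail.

    #include <algorithm>
    #include <vector>

    struct HypothesisPool
    {
        int capacity = 0;                // 2 * beamWidth in the comments above
        std::vector<float> normedScores; // normed score of each stored finished beam
        float minNormedScore = 0.0f;     // worst score currently in the pool (valid once full)
        int numBeams = 0;                // count of finished beams collected so far
    };

    // Returns true if the finished hypothesis (already length-normalized) is kept.
    bool tryInsert(HypothesisPool& pool, float normedScore)
    {
        if (pool.numBeams == pool.capacity && normedScore <= pool.minNormedScore)
        {
            return false; // not better than the worst stored candidate
        }
        if (pool.numBeams < pool.capacity)
        {
            pool.normedScores.push_back(normedScore);
            ++pool.numBeams;
        }
        else
        {
            // replace the current worst entry
            auto worst = std::min_element(pool.normedScores.begin(), pool.normedScores.end());
            *worst = normedScore;
        }
        pool.minNormedScore = *std::min_element(pool.normedScores.begin(), pool.normedScores.end());
        return true;
    }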
template <typename T> template <typename T>

View File

@ -57,9 +57,17 @@ endif()
file(GLOB_RECURSE CU_INSTANTIATIONS ${CMAKE_CURRENT_BINARY_DIR}/*.cu) file(GLOB_RECURSE CU_INSTANTIATIONS ${CMAKE_CURRENT_BINARY_DIR}/*.cu)
add_library(cutlass_src OBJECT ${SRC_CPP} ${SRC_CU} ${CU_INSTANTIATIONS}) add_library(cutlass_src_pre_hopper STATIC ${SRC_CPP} ${SRC_CU})
set_property(TARGET cutlass_src PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET cutlass_src_pre_hopper PROPERTY POSITION_INDEPENDENT_CODE
set_property(TARGET cutlass_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) ON)
set_property(TARGET cutlass_src_pre_hopper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS
ON)
add_library(cutlass_src_hopper STATIC ${CU_INSTANTIATIONS})
set_property(TARGET cutlass_src_hopper PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cutlass_src_hopper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_dependencies(cutlass_src_hopper cutlass_src_pre_hopper)
# Note - we deliberately do not include 90a PTX (even when 9.0+PTX is # Note - we deliberately do not include 90a PTX (even when 9.0+PTX is
# specified). This is because sm_90a has arch conditional instructions that are # specified). This is because sm_90a has arch conditional instructions that are
@ -70,15 +78,21 @@ if("9.0" IN_LIST TORCH_CUDA_ARCH_LIST
OR TORCH_CUDA_ARCH_LIST STREQUAL "Auto") OR TORCH_CUDA_ARCH_LIST STREQUAL "Auto")
message(STATUS "MANUALLY APPENDING FLAG TO COMPILE FOR SM_90a.") message(STATUS "MANUALLY APPENDING FLAG TO COMPILE FOR SM_90a.")
target_compile_options( target_compile_options(
cutlass_src cutlass_src_pre_hopper
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_90a,code=sm_90a>)
target_compile_options(
cutlass_src_hopper
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_90a,code=sm_90a>) PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_90a,code=sm_90a>)
# Hopper kernels require cuda lib for TMA APIs # Hopper kernels require cuda lib for TMA APIs
target_link_libraries(cutlass_src PRIVATE CUDA::cuda_driver) target_link_libraries(cutlass_src_pre_hopper PRIVATE CUDA::cuda_driver)
target_link_libraries(cutlass_src_hopper PRIVATE CUDA::cuda_driver)
# No kernels should be parsed, unless hopper is specified. This is a build # No kernels should be parsed, unless hopper is specified. This is a build
# time improvement # time improvement
target_compile_definitions(cutlass_src target_compile_definitions(cutlass_src_pre_hopper
PRIVATE COMPILE_HOPPER_MIXED_INPUT_GEMMS)
target_compile_definitions(cutlass_src_hopper
PRIVATE COMPILE_HOPPER_MIXED_INPUT_GEMMS) PRIVATE COMPILE_HOPPER_MIXED_INPUT_GEMMS)
endif() endif()
@ -87,5 +101,9 @@ endif()
# compilation output. # compilation output.
if(NOT WIN32) if(NOT WIN32)
target_compile_options( target_compile_options(
cutlass_src PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-psabi>) cutlass_src_pre_hopper
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-psabi>)
target_compile_options(
cutlass_src_hopper
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-psabi>)
endif() endif()

View File

@ -52,6 +52,8 @@ TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
{ {
switch (tile_config) switch (tile_config)
{ {
case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: return TileShape{16, 128};
case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: return TileShape{16, 256};
case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: return TileShape{32, 128}; case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: return TileShape{32, 128};
case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: return TileShape{64, 64}; case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: return TileShape{64, 64};
case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
@ -145,7 +147,9 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
case CutlassGemmType::WeightOnly: case CutlassGemmType::WeightOnly:
if (sm >= 75) if (sm >= 75)
{ {
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64,
CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64,
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
} }

View File

@ -127,9 +127,8 @@ LayoutDetails getLayoutDetailsForArch(QuantType quant_type)
return details; return details;
} }
LayoutDetails getLayoutDetailsForTransform(QuantType quant_type) LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
{ {
const int arch = getSMVersion();
if (arch >= 70 && arch < 75) if (arch >= 70 && arch < 75)
{ {
return getLayoutDetailsForArch<cutlass::arch::Sm70>(quant_type); return getLayoutDetailsForArch<cutlass::arch::Sm70>(quant_type);
@ -534,10 +533,15 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const
} }
void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight, void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight,
const std::vector<size_t>& shape, QuantType quant_type) const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave)
{ {
const int arch = getSMVersion(); int arch = getSMVersion();
LayoutDetails details = getLayoutDetailsForTransform(quant_type); if (force_interleave && arch == 90)
{
// Workaround for MOE which doesn't have specialised Hopper kernels yet
arch = 80;
}
LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch);
TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
@ -616,7 +620,8 @@ Outputs
template <typename ComputeType, typename WeightType> template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight,
ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type) ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type,
bool force_interleave)
{ {
TLLM_CHECK_WITH_INFO(processed_quantized_weight, "Processed quantized tensor is NULL"); TLLM_CHECK_WITH_INFO(processed_quantized_weight, "Processed quantized tensor is NULL");
@ -719,48 +724,52 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
} }
} }
preprocess_weights_for_mixed_gemm(processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type); preprocess_weights_for_mixed_gemm(
processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, force_interleave);
} }
template void symmetric_quantize<half, float>( template void symmetric_quantize<half, float>(
int8_t*, int8_t*, half*, const float*, const std::vector<size_t>&, QuantType); int8_t*, int8_t*, half*, const float*, const std::vector<size_t>&, QuantType, bool);
template void symmetric_quantize<half, half>( template void symmetric_quantize<half, half>(
int8_t*, int8_t*, half*, const half*, const std::vector<size_t>&, QuantType); int8_t*, int8_t*, half*, const half*, const std::vector<size_t>&, QuantType, bool);
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(
int8_t*, int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType); int8_t*, int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool);
template void symmetric_quantize<__nv_bfloat16, float>( template void symmetric_quantize<__nv_bfloat16, float>(
int8_t*, int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType); int8_t*, int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType, bool);
#endif #endif
template <typename ComputeType, typename WeightType> template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr, void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr,
const std::vector<size_t>& shape, QuantType quant_type) const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave)
{ {
symmetric_quantize(processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type); symmetric_quantize(
processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type, force_interleave);
} }
template void symmetric_quantize<float, float>(int8_t*, float*, const float*, const std::vector<size_t>&, QuantType); template void symmetric_quantize<float, float>(
int8_t*, float*, const float*, const std::vector<size_t>&, QuantType, bool);
template void symmetric_quantize<half, float>(int8_t*, half*, const float*, const std::vector<size_t>&, QuantType); template void symmetric_quantize<half, float>(
int8_t*, half*, const float*, const std::vector<size_t>&, QuantType, bool);
template void symmetric_quantize<half, half>(int8_t*, half*, const half*, const std::vector<size_t>&, QuantType); template void symmetric_quantize<half, half>(int8_t*, half*, const half*, const std::vector<size_t>&, QuantType, bool);
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(
int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType); int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool);
template void symmetric_quantize<__nv_bfloat16, half>( template void symmetric_quantize<__nv_bfloat16, half>(
int8_t*, __nv_bfloat16*, const half*, const std::vector<size_t>&, QuantType); int8_t*, __nv_bfloat16*, const half*, const std::vector<size_t>&, QuantType, bool);
template void symmetric_quantize<half, __nv_bfloat16>( template void symmetric_quantize<half, __nv_bfloat16>(
int8_t*, half*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType); int8_t*, half*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool);
template void symmetric_quantize<__nv_bfloat16, float>( template void symmetric_quantize<__nv_bfloat16, float>(
int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType); int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType, bool);
#endif #endif
} // namespace cutlass_kernels } // namespace cutlass_kernels

View File

@ -47,17 +47,18 @@ void subbyte_transpose(int8_t* transposed_quantized_tensor, const int8_t* quanti
void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type); void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type);
void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight, void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight,
const std::vector<size_t>& shape, QuantType quant_type); const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave = false);
template <typename ComputeType, typename WeightType> template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr, void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr,
const std::vector<size_t>& shape, QuantType quant_type); const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave);
// This is exposed so that we can write tests that use the processed weights for CUTLASS and the unprocessed // This is exposed so that we can write tests that use the processed weights for CUTLASS and the unprocessed
// weights to implement a simple reference implementation. // weights to implement a simple reference implementation.
template <typename ComputeType, typename WeightType> template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight,
ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type); ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type,
bool force_interleave);
} // namespace cutlass_kernels } // namespace cutlass_kernels
} // namespace kernels } // namespace kernels

View File

@ -84,7 +84,7 @@ public:
protected: protected:
static constexpr int SPLIT_K_LIMIT = 7; static constexpr int SPLIT_K_LIMIT = 7;
static constexpr int MIN_M_TILE = 32; static constexpr int MIN_M_TILE = 16;
static constexpr int MIN_N_TILE = 64; static constexpr int MIN_N_TILE = 64;
}; };

View File

@ -324,6 +324,26 @@ void dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, cons
// best for mixed type gemms. // best for mixed type gemms.
switch (gemm_config.tile_config) switch (gemm_config.tile_config)
{ {
case tkc::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<16, 128, 64>, cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales,
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
stream, occupancy);
}
break;
case tkc::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<16, 256, 64>, cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales,
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
stream, occupancy);
}
break;
case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag, dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<32, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, cutlass::gemm::GemmShape<32, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales,
@ -337,11 +357,8 @@ void dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, cons
stream, occupancy); stream, occupancy);
break; break;
case tkc::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: case tkc::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
if (arch::kMinComputeCapability < 75) TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
{ if constexpr (arch::kMinComputeCapability >= 75)
TLLM_CHECK_WITH_INFO(false, "Invalid config on Volta");
}
else
{ {
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag, dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales,
@ -433,41 +450,6 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
} }
} }
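The new tile cases above follow the pattern this hunk also applies to CtaShape128x128x64: a runtime TLLM_CHECK for a clear error message plus an if constexpr guard, so GemmShape instantiations that only exist on sm75+ are never compiled into the sm70 dispatch path. A stripped-down, standalone illustration of why the constexpr guard matters; the arch tags and kernel body are placeholders, not CUTLASS code.

    #include <cstdio>

    struct Sm70 { static constexpr int kMinComputeCapability = 70; };
    struct Sm80 { static constexpr int kMinComputeCapability = 80; };

    template <int M, int N, int K>
    void launchTile()
    {
        // Stand-in for a kernel instantiation that is only valid on newer architectures.
        std::printf("tile %dx%dx%d\n", M, N, K);
    }

    template <typename Arch>
    void dispatchSmallM()
    {
        if constexpr (Arch::kMinComputeCapability >= 75)
        {
            // Only instantiated for tags with capability >= 75; a plain runtime `if`
            // would still force this template to be instantiated for Sm70 as well.
            launchTile<16, 128, 64>();
        }
        else
        {
            std::printf("config not supported on this arch\n");
        }
    }

    int main()
    {
        dispatchSmallM<Sm70>();
        dispatchSmallM<Sm80>();
        return 0;
    }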
// Disabled since the fused GEMM, activation kernels will not be used in v1.
// template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
// void CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::gemm_bias_act(const T* A, const WeightType* B, const T*
// weight_scales,
// const T* biases, T* C, int m, int n, int k, ActivationType activation_type, char* workspace_ptr,
// const size_t workspace_bytes, cudaStream_t stream)
// {
// TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
// switch (activation_type)
// {
// case ActivationType::Relu:
// run_gemm<tkc::EpilogueOpBiasReLU>(
// A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, stream);
// break;
// case ActivationType::Gelu:
// run_gemm<tkc::EpilogueOpBiasFtGelu>(
// A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, stream);
// break;
// case ActivationType::Silu:
// run_gemm<tkc::EpilogueOpBiasSilu>(
// A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, stream);
// break;
// case ActivationType::Identity:
// run_gemm<tkc::EpilogueOpBias>(A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes,
// stream); break;
// case ActivationType::InvalidType: TLLM_CHECK_WITH_INFO(false, "Activation type for fpA_intB must be
// valid."); break; default:
// {
// TLLM_CHECK_WITH_INFO(false, "Invalid activation type.");
// }
// }
// }
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType, template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType> typename BiasType, typename OutputType>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm( void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(

View File

@ -231,6 +231,24 @@ void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_s
{ {
switch (gemm_config.tile_config) switch (gemm_config.tile_config)
{ {
case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert,
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
}
break;
case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert,
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
}
break;
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>, dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
@ -266,6 +284,24 @@ void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_s
{ {
switch (gemm_config.tile_config) switch (gemm_config.tile_config)
{ {
case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert,
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
}
break;
case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
if constexpr (arch::kMinComputeCapability >= 75)
{
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert,
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
}
break;
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>, dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,

View File

@ -53,6 +53,7 @@ struct XQAParams
// almost copy from GPTAttentionPluginCommon. // almost copy from GPTAttentionPluginCommon.
// maybe use one struct for parameters in GPTAttentionPluginCommon and share the same here. // maybe use one struct for parameters in GPTAttentionPluginCommon and share the same here.
int32_t generation_input_length; int32_t generation_input_length;
int32_t layer_idx = 0;
int32_t num_q_heads = 0; int32_t num_q_heads = 0;
int32_t num_kv_heads = 0; int32_t num_kv_heads = 0;
int32_t head_size = 0; int32_t head_size = 0;

View File

@ -121,7 +121,7 @@ __global__ void gatherTree(gatherTreeParam param)
template <typename T> template <typename T>
__device__ __forceinline__ T applyLengthPenalty(T logProb, int length, float lengthPenalty) __device__ __forceinline__ T applyLengthPenalty(T logProb, int length, float lengthPenalty)
{ {
// score = log(prob) / (length)^lengthPenalty. // score = log(prob) / (length ^ lengthPenalty)
if (lengthPenalty == 0.0f || length == 1) if (lengthPenalty == 0.0f || length == 1)
{ {
return logProb; return logProb;
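The length penalty in the comment above normalizes a hypothesis score by its length: score = log(prob) / length^lengthPenalty. The host-side sketch below (illustrative only, not the kernel) reproduces that behaviour with concrete numbers: for logProb = -6.0 and length = 4, a penalty of 0.0 leaves the score at -6.0, 1.0 gives -1.5, and 2.0 gives -0.375, so larger penalties favour longer sequences.

// Host-side reference of the formula in the comment above (illustrative only).
#include <cmath>
#include <cstdio>

float applyLengthPenaltyRef(float logProb, int length, float lengthPenalty)
{
    if (lengthPenalty == 0.0f || length == 1)
    {
        return logProb;
    }
    return logProb / std::pow(static_cast<float>(length), lengthPenalty);
}

int main()
{
    const float penalties[] = {0.0f, 1.0f, 2.0f};
    for (float penalty : penalties)
    {
        std::printf("lengthPenalty=%.1f -> score=%.3f\n", penalty, applyLengthPenaltyRef(-6.0f, 4, penalty));
    }
    return 0;
}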

View File

@ -46,7 +46,8 @@ struct gatherTreeParam
int32_t* outputIds = nullptr; // the buffer to put finalized ids int32_t* outputIds = nullptr; // the buffer to put finalized ids
cudaStream_t stream; cudaStream_t stream;
float* cumLogProbs = nullptr; // [batchSize, beamWidth] float* cumLogProbs = nullptr; // [batchSize, beamWidth]
float lengthPenalty = 1.0f; // on cpu float lengthPenalty = 1.0f;
int earlyStopping = 1;
}; };
/* /*

View File

@ -62,7 +62,7 @@ void groupedGemm_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vect
std::vector<void*> ptrC, std::vector<void*> ptrD, void* gemmParamsWorkSpace, int64_t gemmParamsWorkSpaceSize, std::vector<void*> ptrC, std::vector<void*> ptrD, void* gemmParamsWorkSpace, int64_t gemmParamsWorkSpaceSize,
void* gemmWorkSpace, int64_t gemmWorkspaceSize, nvinfer1::DataType dataType, cudaStream_t stream) void* gemmWorkSpace, int64_t gemmWorkspaceSize, nvinfer1::DataType dataType, cudaStream_t stream)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
using ElementA = cutlassType; using ElementA = cutlassType;
using ElementB = cutlassType; using ElementB = cutlassType;
using ElementOutput = cutlassType; using ElementOutput = cutlassType;
@ -167,7 +167,7 @@ void groupedGemm_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vect
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to run CUTLASS Grouped GEMM kernel."); TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to run CUTLASS Grouped GEMM kernel.");
std::free(host_workspace); std::free(host_workspace);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
template <int M1, int N1, int K1, int M2, int N2, int K2> template <int M1, int N1, int K1, int M2, int N2, int K2>

View File

@ -25,27 +25,20 @@ namespace kernels
{ {
template <typename T, int MAX_K> template <typename T, int MAX_K>
void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const FinishedState* finished, void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs,
const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, int** output_ids_ptr, void* temp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream);
void* temp_storage, const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size,
const int beam_width, const int vocab_size, const int* end_ids, const float* diversity_rates,
const float* length_penalties, cudaStream_t stream);
#define CASE_K(MAX_K) \ #define CASE_K(MAX_K) \
topK_softMax_kernelLauncher<T, MAX_K>(log_probs, bias, finished, sequence_lengths, cum_log_probs, \ topK_softMax_kernelLauncher<T, MAX_K>( \
output_log_probs, output_ids_ptr, temp_storage, temp_storage_size, beam_hyps, batch_size, beam_width, \ log_probs, bias, finished, cum_log_probs, temp_storage, temp_storage_size, beam_hyps, stream); \
vocab_size, end_ids, diversity_rates, length_penalties, stream); \
break; break;
template <typename T> template <typename T>
void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* finished, const int* sequence_lengths, void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs,
float* cum_log_probs, float* output_log_probs, int** output_ids_ptr, void* temp_storage, void* temp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream)
const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size, const int beam_width,
const int vocab_size, const int* end_ids, const float* diversity_rates, const float* length_penalties,
cudaStream_t stream)
{ {
int log_beam_width(0); int log_beam_width(0);
int recursor(beam_width - 1); int recursor(beam_hyps.beam_width - 1);
while (recursor >>= 1) while (recursor >>= 1)
++log_beam_width; ++log_beam_width;
@ -66,23 +59,19 @@ void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* f
CASE_K(64) CASE_K(64)
#endif // FAST_BUILD #endif // FAST_BUILD
default: default:
throw std::runtime_error( throw std::runtime_error(fmtstr("%s:%d Topk kernel of beam search does not support beam_width=%d", __FILE__,
fmtstr("%s:%d Topk kernel of beam search does not support beam_width=%d", __FILE__, __LINE__, beam_width)); __LINE__, beam_hyps.beam_width));
} }
} }
#undef CASE_K #undef CASE_K
template void invokeTopkSoftMax<float>(const float* log_probs, const float* bias, const FinishedState* finished, template void invokeTopkSoftMax<float>(const float* log_probs, const float* bias, const FinishedState* finished,
const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, int** output_ids_ptr, void* tmp_storage, float* cum_log_probs, void* tmp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps,
const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size, const int beam_width,
const int vocab_size, const int* end_ids, const float* diversity_rates, const float* length_penalties,
cudaStream_t stream); cudaStream_t stream);
template void invokeTopkSoftMax<half>(const half* log_probs, const half* bias, const FinishedState* finished, template void invokeTopkSoftMax<half>(const half* log_probs, const half* bias, const FinishedState* finished,
const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, int** output_ids_ptr, void* tmp_storage, float* cum_log_probs, void* tmp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps,
const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size, const int beam_width,
const int vocab_size, const int* end_ids, const float* diversity_rates, const float* length_penalties,
cudaStream_t stream); cudaStream_t stream);
} // namespace kernels } // namespace kernels
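The loop above that shifts recursor computes the bucket index used by the CASE_K dispatch: beam widths up to 2^(k+1) map to bucket index k, so each runtime beam_width lands on a compile-time MAX_K specialization. A small host-side check of that mapping (illustrative only, assuming power-of-two buckets as the CASE_K(64) case suggests):

// Illustrative check of the beam_width-to-bucket mapping, not library code.
#include <cassert>

int logBeamWidth(int beamWidth)
{
    int log = 0;
    int recursor = beamWidth - 1;
    while (recursor >>= 1)
    {
        ++log;
    }
    return log;
}

int main()
{
    assert(logBeamWidth(2) == 0);  // fits bucket 2
    assert(logBeamWidth(4) == 1);  // fits bucket 4
    assert(logBeamWidth(5) == 2);  // rounds up to bucket 8
    assert(logBeamWidth(8) == 2);  // fits bucket 8
    assert(logBeamWidth(64) == 5); // the CASE_K(64) branch above
    return 0;
}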

View File

@ -24,10 +24,8 @@ namespace kernels
{ {
template <typename T> template <typename T>
void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* finished, const int* sequence_lengths, void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs,
float* cum_log_probs, float* output_log_probs, int** output_ids_ptr, void* tmp_storage, const int temp_storage_size, void* tmp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream);
BeamHypotheses* beam_hyps, const int batch_size, const int beam_width, const int vocab_size, const int* end_ids,
const float* diversity_rates, const float* length_penalties, cudaStream_t stream);
} // namespace kernels } // namespace kernels
} // namespace tensorrt_llm } // namespace tensorrt_llm

View File

@ -44,7 +44,7 @@ static const int SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE = 256;
template <typename T> template <typename T>
__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty) __device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty)
{ {
// score = log(prob) / (length)^length_penalty. // score = log(prob) / (length ^ length_penalty).
if (length_penalty == 0.0f || length == 1) if (length_penalty == 0.0f || length == 1)
{ {
return log_prob; return log_prob;
@ -56,9 +56,10 @@ template <typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ __launch_bounds__(THREADBLOCK_SIZE) __global__
void batch_topK_kernel(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf) void batch_topK_kernel(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf)
{ {
int thread_id = threadIdx.x; const int thread_id = threadIdx.x;
int block_id = blockIdx.x; const int block_id = blockIdx.x;
TopK<T, MAX_K> partial; TopK<T, MAX_K> partial;
if (thread_id == 0) if (thread_id == 0)
{ {
for (int i = 0; i < MAX_K; ++i) for (int i = 0; i < MAX_K; ++i)
@ -70,7 +71,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
int index = block_id * MAX_K * MAX_K; int index = block_id * MAX_K * MAX_K;
for (int i = 0; i < MAX_K * MAX_K; i++) for (int i = 0; i < MAX_K * MAX_K; i++)
{ {
partial.insert((T) topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]); partial.insert(topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]);
} }
index = block_id * MAX_K; index = block_id * MAX_K;
@ -85,9 +86,10 @@ template <typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int* __restrict topk_tmp_id_buf, __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int* __restrict topk_tmp_id_buf,
const T* __restrict topk_tmp_val_buf, int* __restrict id_buf, T* __restrict val_buf) const T* __restrict topk_tmp_val_buf, int* __restrict id_buf, T* __restrict val_buf)
{ {
int thread_id = threadIdx.x; const int thread_id = threadIdx.x;
int block_id = blockIdx.x; const int block_id = blockIdx.x;
TopK<T, MAX_K> partial; TopK<T, MAX_K> partial;
if (thread_id == 0) if (thread_id == 0)
{ {
for (int i = 0; i < MAX_K; ++i) for (int i = 0; i < MAX_K; ++i)
@ -99,7 +101,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int*
int index = block_id * MAX_K * MAX_K; int index = block_id * MAX_K * MAX_K;
for (int i = 0; i < MAX_K * MAX_K; i++) for (int i = 0; i < MAX_K * MAX_K; i++)
{ {
partial.insert((T) topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]); partial.insert(topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]);
} }
index = block_id * MAX_K; index = block_id * MAX_K;
@ -112,34 +114,37 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int*
} }
template <typename T, int MAX_K2, int THREADBLOCK_SIZE> template <typename T, int MAX_K2, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* __restrict topk_tmp_id_buf,
void batch_topk_kernel(const int* __restrict x, const T* __restrict y, int** output_ids_ptr, float* __restrict v, const T* __restrict topk_tmp_val_buf, float* __restrict cum_log_probs, const FinishedState* finished,
float* output_log_probs, const FinishedState* finished, const int* sequence_lengths, BeamHypotheses beam_hyps, BeamHypotheses beam_hyps, const int candidate_size)
const int V, const int K, const int vocab_size, const float* length_penalties, const float* diversity_rates)
{ {
int thread_id = threadIdx.x; const int thread_id = threadIdx.x;
int vector_id = blockIdx.x; const int vector_id = blockIdx.x;
const int K{beam_hyps.beam_width};
const int vocab_size{beam_hyps.vocab_size};
const int global_batch_idx{beam_hyps.ite * beam_hyps.local_batch_size + vector_id}; const int global_batch_idx{beam_hyps.ite * beam_hyps.local_batch_size + vector_id};
const T diversity_rate{diversity_rates[global_batch_idx]}; const T MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
const float length_penalty{length_penalties[global_batch_idx]};
// reposition x, y to data for the current vector const float length_penalty{beam_hyps.length_penalties[global_batch_idx]};
x += vector_id * V; const int early_stopping{beam_hyps.early_stoppings[global_batch_idx]};
y += vector_id * V; const int* sequence_lengths{beam_hyps.sequence_lengths_src};
const T diversity_rate{beam_hyps.diversity_rates[global_batch_idx]};
extern __shared__ char buf_s_[]; // intermediate result float* output_log_probs{beam_hyps.log_probs_src};
T* buf_s = reinterpret_cast<T*>(buf_s_);
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
using cub_kvp = cub::KeyValuePair<int, T>; using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduce = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>; using BlockReduce = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>;
extern __shared__ char buf_s_[]; // intermediate result
T* buf_s = reinterpret_cast<T*>(buf_s_);
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ int selected_beams;
__shared__ float old_cum_log_probs[MAX_K2]; __shared__ float old_cum_log_probs[MAX_K2];
__shared__ char cta_topk_store[MAX_K2 * sizeof(cub_kvp)]; __shared__ cub_kvp cta_topk[MAX_K2];
auto* cta_topk = reinterpret_cast<cub_kvp*>(cta_topk_store); __shared__ int selected_beams;
__shared__ int thread_requiring_update;
// reposition topk_tmp_id_buf, topk_tmp_val_buf to data for the current vector
topk_tmp_id_buf += vector_id * candidate_size;
topk_tmp_val_buf += vector_id * candidate_size;
if (thread_id == 0) if (thread_id == 0)
{ {
@ -147,47 +152,56 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
} }
if (thread_id < K) if (thread_id < K)
{ {
old_cum_log_probs[thread_id] = v[vector_id * K + thread_id]; old_cum_log_probs[thread_id] = cum_log_probs[vector_id * K + thread_id];
} }
__syncthreads(); __syncthreads();
if (beam_hyps.num_beams != nullptr) if (beam_hyps.num_beams != nullptr)
{ {
// initialize worst_score if this batch has no finished beam
if (beam_hyps.num_beams[global_batch_idx] == 0 && thread_id == 0) if (beam_hyps.num_beams[global_batch_idx] == 0 && thread_id == 0)
{ {
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
} }
// return if this batch has enough finished beams
else if (beam_hyps.num_beams[global_batch_idx] == K) else if (beam_hyps.num_beams[global_batch_idx] == K)
{ {
return; return;
} }
} }
        // Get top 2K tokens from candidates
cub::ArgMax arg_max; cub::ArgMax arg_max;
cub_kvp partial_topk{V - 1, -MAX_T_VAL}; cub_kvp partial_topk{candidate_size - 1, -MAX_T_VAL};
for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE) for (int elem_id = thread_id; elem_id < candidate_size; elem_id += THREADBLOCK_SIZE)
{ {
int i = beam_hyps.num_beams == nullptr ? elem_id % K : elem_id / 2 / K; int i = beam_hyps.num_beams == nullptr ? elem_id % K : elem_id / 2 / K;
T elem = length_penalty == 0.0f T elem = topk_tmp_val_buf[elem_id];
? y[elem_id] if (length_penalty > 0.0f)
: apply_length_penalty(y[elem_id], {
finished[vector_id * K + i].isFinished() ? sequence_lengths[vector_id * K + i] int length = sequence_lengths[vector_id * K + i];
: sequence_lengths[vector_id * K + i] + 1, if (early_stopping == 0)
length_penalty); {
// Use generated_length (rather than sequence_length) to compute length_penalty
// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L957
                // But this branch will cause CI errors in
// "C++ Tests (GPT) on A30", "C++ Tests (GPT-J) on H100_PCIe", "H100_PCIe-accuracy-0"
length -= beam_hyps.input_lengths[global_batch_idx];
}
const int pad_if_not_finish = finished[vector_id * K + i].isFinished() ? 0 : 1;
elem = apply_length_penalty(elem, length + pad_if_not_finish, length_penalty);
}
elem += diversity_rate * (T) i; elem += diversity_rate * (T) i;
int elem_idx = elem_id; // x[elem_id]; cub_kvp new_elem{elem_id, elem};
cub_kvp new_elem{elem_idx, elem};
partial_topk = arg_max(partial_topk, new_elem); partial_topk = arg_max(partial_topk, new_elem);
buf_s[elem_id] = elem; buf_s[elem_id] = elem;
} }
__syncthreads(); __syncthreads();
__shared__ int thread_requiring_update;
for (int i = 0; i < 2 * K; ++i) for (int i = 0; i < 2 * K; ++i)
{ {
cub_kvp total_topk = BlockReduce(temp_storage).Reduce(partial_topk, arg_max); cub_kvp total_topk = BlockReduce(temp_storage).Reduce(partial_topk, arg_max);
if (threadIdx.x == 0) if (threadIdx.x == 0)
{ {
cta_topk[i] = total_topk; cta_topk[i] = total_topk;
@ -196,13 +210,13 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
} }
__syncthreads(); __syncthreads();
// Only 1 thread needs to update the old partial before the next block reduce. We don't need to do this update // Only one thread needs to update the old partial before the next block reduce.
// on the last iteration. // No need to do this in the last iteration.
if (thread_id == thread_requiring_update && i < (2 * K - 1)) if (thread_id == thread_requiring_update && i < (2 * K - 1))
{ {
partial_topk.key = V - 1; partial_topk.key = candidate_size - 1;
partial_topk.value = -MAX_T_VAL; partial_topk.value = -MAX_T_VAL;
for (int tid = thread_id; tid < V; tid += THREADBLOCK_SIZE) for (int tid = thread_id; tid < candidate_size; tid += THREADBLOCK_SIZE)
{ {
cub_kvp new_elem{tid, buf_s[tid]}; cub_kvp new_elem{tid, buf_s[tid]};
partial_topk = arg_max(partial_topk, new_elem); partial_topk = arg_max(partial_topk, new_elem);
@ -212,104 +226,101 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
if (thread_id == 0) if (thread_id == 0)
{ {
v += vector_id * K; cum_log_probs += vector_id * K;
for (int i = 0; i < 2 * K; ++i) for (int i = 0; i < 2 * K; ++i)
{ {
const int current_key = cta_topk[i].key; const int current_key = cta_topk[i].key;
const T current_value = cta_topk[i].value; const T current_value = cta_topk[i].value;
if (i < K && beam_hyps.num_beams != nullptr && x[current_key] % vocab_size == beam_hyps.end_ids[vector_id]) if (i < K && beam_hyps.num_beams != nullptr
&& topk_tmp_id_buf[current_key] % vocab_size == beam_hyps.end_ids[vector_id])
{ {
// if beam_token does not belong to top num_beams tokens, it should not // Add beam only if beam_token belongs to top K tokens
// be added. Refer from // https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L272
// https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/generation_beam_search.py#L257 const float normed_score = (float) current_value;
{ const int num_beam = beam_hyps.num_beams[global_batch_idx];
const float normed_score = (float) current_value; int beam_idx = num_beam;
const int num_beam = beam_hyps.num_beams[global_batch_idx];
int beam_idx = num_beam;
// If there are beam_width finished sentences, check that the score of
// selected candidatet is higher than min_normed_score or not. If
// current score is better, replace worst one and update the
// min_normed_score.
if (num_beam == K)
{
if (normed_score < beam_hyps.min_normed_scores[global_batch_idx])
{
// end the tracing and exist this for loop
selected_beams = K;
break;
}
else
{
// find the beam index which's score = min_normed_score, erase it.
for (int j = 0; j < K; j++)
{
if (beam_hyps.normed_scores[global_batch_idx * (K * 2) + j]
== beam_hyps.min_normed_scores[global_batch_idx])
{
beam_idx = j;
beam_hyps.num_beams[global_batch_idx]--;
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; // There are already K beams
beam_hyps.normed_scores[global_batch_idx * (K * 2) + j] = normed_score; if (num_beam == K)
for (int l = 0; l < K; l++) {
{ // The current score is worse than the worst one in beams
beam_hyps.min_normed_scores[global_batch_idx] if (normed_score < beam_hyps.min_normed_scores[global_batch_idx])
= min(beam_hyps.min_normed_scores[global_batch_idx], {
beam_hyps.normed_scores[global_batch_idx * (K * 2) + l]); selected_beams = K;
} break;
break; }
// The current score is better than the worst one in beams
else
{
                    // Find the beam index whose score == min_normed_score and erase it.
for (int j = 0; j < K; j++)
{
if (beam_hyps.normed_scores[global_batch_idx * (K * 2) + j]
== beam_hyps.min_normed_scores[global_batch_idx])
{
beam_idx = j;
beam_hyps.num_beams[global_batch_idx]--;
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
beam_hyps.normed_scores[global_batch_idx * (K * 2) + j] = normed_score;
for (int l = 0; l < K; l++)
{
beam_hyps.min_normed_scores[global_batch_idx]
= min(beam_hyps.min_normed_scores[global_batch_idx],
beam_hyps.normed_scores[global_batch_idx * (K * 2) + l]);
} }
break;
} }
} }
} }
const int tgt_id_offset
= ((vector_id + beam_hyps.ite * beam_hyps.local_batch_size) * (K * 2) + beam_idx)
* (beam_hyps.max_seq_len);
int prev_id = (x[current_key] / vocab_size) % K;
const int current_step{sequence_lengths[vector_id * K + prev_id]};
beam_hyps.output_ids_tgt[tgt_id_offset + current_step] = beam_hyps.end_ids[vector_id];
if (beam_hyps.log_probs != nullptr)
{
beam_hyps.log_probs[tgt_id_offset + current_step]
= (float) y[current_key] - old_cum_log_probs[(x[current_key] / vocab_size) % K];
}
for (int j = current_step - 1; j >= 0; j--)
{
const int src_idx = j * beam_hyps.batch_size * K
+ beam_hyps.ite * beam_hyps.local_batch_size * K + vector_id * K + prev_id;
beam_hyps.output_ids_tgt[tgt_id_offset + j]
= beam_hyps.output_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j];
if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr)
{
beam_hyps.log_probs[tgt_id_offset + j] = beam_hyps.log_probs_src[src_idx];
}
prev_id = beam_hyps.parent_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j];
}
const int tgt_beam_idx = global_batch_idx * (K * 2) + beam_idx;
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = current_step;
beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
beam_hyps.min_normed_scores[global_batch_idx]
= min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]);
beam_hyps.num_beams[global_batch_idx]++;
beam_hyps.cum_log_probs[tgt_beam_idx] = (float) y[current_key];
} }
const int tgt_id_offset
= ((vector_id + beam_hyps.ite * beam_hyps.local_batch_size) * (K * 2) + beam_idx)
* (beam_hyps.max_seq_len);
int prev_id = (topk_tmp_id_buf[current_key] / vocab_size) % K;
const int current_step{sequence_lengths[vector_id * K + prev_id]};
beam_hyps.output_ids_tgt[tgt_id_offset + current_step] = beam_hyps.end_ids[vector_id];
if (beam_hyps.log_probs != nullptr)
{
beam_hyps.log_probs[tgt_id_offset + current_step] = (float) topk_tmp_val_buf[current_key]
- old_cum_log_probs[(topk_tmp_id_buf[current_key] / vocab_size) % K];
}
for (int j = current_step - 1; j >= 0; j--)
{
const int src_idx = j * beam_hyps.batch_size * K + beam_hyps.ite * beam_hyps.local_batch_size * K
+ vector_id * K + prev_id;
beam_hyps.output_ids_tgt[tgt_id_offset + j]
= beam_hyps.output_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j];
if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr)
{
beam_hyps.log_probs[tgt_id_offset + j] = beam_hyps.log_probs_src[src_idx];
}
prev_id = beam_hyps.parent_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j];
}
const int tgt_beam_idx = global_batch_idx * (K * 2) + beam_idx;
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = current_step;
beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
beam_hyps.min_normed_scores[global_batch_idx]
= min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]);
beam_hyps.num_beams[global_batch_idx]++;
cum_log_probs[tgt_beam_idx] = (float) topk_tmp_val_buf[current_key];
} }
else if ((beam_hyps.num_beams != nullptr && i < 2 * K) || (beam_hyps.num_beams == nullptr && i < K)) else if (beam_hyps.num_beams != nullptr || beam_hyps.num_beams == nullptr && i < K)
{ {
const int current_step{sequence_lengths[vector_id * K + selected_beams]}; const int current_step{sequence_lengths[vector_id * K + selected_beams]};
output_ids_ptr[vector_id][selected_beams * beam_hyps.max_seq_len + current_step] = x[current_key]; beam_hyps.output_ids_tgt_ptr[vector_id][selected_beams * beam_hyps.max_seq_len + current_step]
= topk_tmp_id_buf[current_key];
if (output_log_probs != nullptr) if (output_log_probs != nullptr)
{ {
output_log_probs[current_step * beam_hyps.batch_size * K + vector_id * K + selected_beams] output_log_probs[current_step * beam_hyps.batch_size * K + vector_id * K + selected_beams]
= (float) y[current_key] - old_cum_log_probs[(x[current_key] / vocab_size) % K]; = (float) topk_tmp_val_buf[current_key]
- old_cum_log_probs[(topk_tmp_id_buf[current_key] / vocab_size) % K];
} }
v[selected_beams] = (float) y[current_key]; cum_log_probs[selected_beams] = (float) topk_tmp_val_buf[current_key];
selected_beams++; selected_beams++;
} }
__syncthreads(); __syncthreads();
@ -319,15 +330,46 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
} }
} }
} }
// update beam_hyps.is_done for each batch
if (threadIdx.x == 0 && beam_hyps.num_beams != nullptr) if (threadIdx.x == 0 && beam_hyps.num_beams != nullptr)
{ {
        // not enough beams
if (beam_hyps.num_beams[blockIdx.x] < K) if (beam_hyps.num_beams[blockIdx.x] < K)
{ {
beam_hyps.is_done[blockIdx.x] = false; beam_hyps.is_done[blockIdx.x] = false;
return;
} }
else if (beam_hyps.early_stopping) float highest_attainable_score = 0.0f;
switch (early_stopping)
{ {
case 1:
// enough beams with early stopping
beam_hyps.is_done[blockIdx.x] = true; beam_hyps.is_done[blockIdx.x] = true;
return;
case 0:
// enough beams without early stopping
highest_attainable_score = static_cast<float>(apply_length_penalty(cum_log_probs[0],
sequence_lengths[vector_id * K] - beam_hyps.input_lengths[global_batch_idx], length_penalty));
beam_hyps.is_done[blockIdx.x] = beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score;
return;
default:
// early_stopping == "never" in HF, i.e., compute the best possible score depending on `length_penalty`
// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L990
if (length_penalty > 0.0f)
{
highest_attainable_score = static_cast<float>(apply_length_penalty(cum_log_probs[0],
beam_hyps.max_seq_len - beam_hyps.input_lengths[global_batch_idx], length_penalty));
beam_hyps.is_done[blockIdx.x]
= beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score;
}
else
{
highest_attainable_score = static_cast<float>(apply_length_penalty(cum_log_probs[0],
sequence_lengths[vector_id * K] - beam_hyps.input_lengths[global_batch_idx], length_penalty));
beam_hyps.is_done[blockIdx.x]
= beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score;
}
return;
} }
} }
} }
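The switch above distinguishes three early-stopping modes once beam_width hypotheses have finished: 1 stops immediately, 0 compares the worst finished score against the best score attainable at the current generated length, and any other value (HF's "never") uses the maximum possible length when length_penalty > 0, the most optimistic estimate. A host-side sketch of that decision (illustrative, not the device code):

// Host-side reference of the is_done decision, under the same assumptions as above.
#include <cmath>
#include <cstdio>

float normedScore(float cumLogProb, int generatedLength, float lengthPenalty)
{
    if (lengthPenalty == 0.0f || generatedLength == 1)
    {
        return cumLogProb;
    }
    return cumLogProb / std::pow(static_cast<float>(generatedLength), lengthPenalty);
}

bool beamSearchIsDone(int numFinishedBeams, int beamWidth, float minNormedScore, float bestCumLogProb,
    int currentGeneratedLength, int maxGeneratedLength, float lengthPenalty, int earlyStopping)
{
    if (numFinishedBeams < beamWidth)
    {
        return false; // not enough finished beams yet
    }
    switch (earlyStopping)
    {
    case 1: return true; // enough finished beams, stop immediately
    case 0:
        // compare against the best score attainable at the current generated length
        return minNormedScore >= normedScore(bestCumLogProb, currentGeneratedLength, lengthPenalty);
    default:
    {
        // "never": assume generation may continue to the maximum length when the penalty rewards length
        const int length = lengthPenalty > 0.0f ? maxGeneratedLength : currentGeneratedLength;
        return minNormedScore >= normedScore(bestCumLogProb, length, lengthPenalty);
    }
    }
}

int main()
{
    // beam_width 4, all beams finished, length_penalty 1.0, best live cum_log_prob -6.0 at length 10:
    // the best attainable normed score is -0.6, the worst finished score is -1.2, so search continues.
    const bool done = beamSearchIsDone(4, 4, -1.2f, -6.0f, 10, 64, 1.0f, 0);
    std::printf("is_done = %s\n", done ? "true" : "false");
    return 0;
}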
@ -366,24 +408,22 @@ __device__ __forceinline__ TopKMD<T, MAX_K> reduce_topk_md_op(const TopKMD<T, MA
} }
template <typename T, int ITEMS_PER_THREAD, int MAX_K, int THREADBLOCK_SIZE> template <typename T, int ITEMS_PER_THREAD, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_kernel(const T* __restrict x, __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_kernel(const T* __restrict log_probs,
const T* __restrict b, const float* __restrict c, const FinishedState* __restrict finished, int* __restrict z, const T* __restrict bias, const float* __restrict cum_log_probs, const FinishedState* __restrict finished,
T* __restrict v, int V, int K, const int* __restrict end_ids) int* __restrict topk_tmp_id_buf, T* __restrict topk_tmp_val_buf, int vocab_size, int K,
const int* __restrict end_ids)
{ {
int thread_id = threadIdx.x; const int thread_id = threadIdx.x;
int vector_id = blockIdx.x; const int vector_id = blockIdx.x;
const T MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
// reposition y to data for the current vector
x += vector_id * V;
typedef cub::BlockReduce<TopKMD<float, MAX_K>, THREADBLOCK_SIZE> BlockReduce; typedef cub::BlockReduce<TopKMD<float, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
// reposition log_probs to data for the current vector
log_probs += vector_id * vocab_size;
TopKMD<float, MAX_K> partial; TopKMD<float, MAX_K> partial;
bool finish = finished[vector_id].isFinished();
for (int i = 0; i < MAX_K; ++i) for (int i = 0; i < MAX_K; ++i)
{ {
partial.topk.p[i] = -1; partial.topk.p[i] = -1;
@ -392,22 +432,22 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker
partial.md.m = -MAX_T_VAL; partial.md.m = -MAX_T_VAL;
partial.md.d = 0.0F; partial.md.d = 0.0F;
if (finish) if (finished[vector_id].isFinished())
{ {
for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE) for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE)
{ {
float elem = (elem_id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL; float elem = (elem_id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL;
MD new_elem{elem, 1.0F}; MD new_elem{elem, 1.0F};
partial.md = reduce_md_op(partial.md, new_elem); partial.md = reduce_md_op(partial.md, new_elem);
partial.topk.insert(elem, elem_id); partial.topk.insert(elem, elem_id);
// if (elem_id > THREADBLOCK_SIZE * MAX_K && (elem_id == E)) break; // if (elem_id > THREADBLOCK_SIZE * MAX_K && elem_id == E) break;
} }
} }
else else
{ {
for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE) for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE)
{ {
float elem = x[elem_id] + b[elem_id]; float elem = log_probs[elem_id] + bias[elem_id];
MD new_elem{elem, 1.0F}; MD new_elem{elem, 1.0F};
partial.md = reduce_md_op(partial.md, new_elem); partial.md = reduce_md_op(partial.md, new_elem);
partial.topk.insert(elem, elem_id); partial.topk.insert(elem, elem_id);
@ -418,9 +458,9 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker
if (thread_id == 0) if (thread_id == 0)
{ {
z += vector_id * K; topk_tmp_id_buf += vector_id * K;
v += vector_id * K; topk_tmp_val_buf += vector_id * K;
c += vector_id; cum_log_probs += vector_id;
// float d_total_inverse = __fdividef(1.0F, total.md.d); // float d_total_inverse = __fdividef(1.0F, total.md.d);
float d_total_log = logf(total.md.d); float d_total_log = logf(total.md.d);
@ -430,48 +470,43 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker
float val = total.topk.u[i] - total.md.m - d_total_log; float val = total.topk.u[i] - total.md.m - d_total_log;
if (i < K) if (i < K)
{ {
z[i] = total.topk.p[i] + vector_id * V; // trtllm needs absolute id topk_tmp_id_buf[i] = total.topk.p[i] + vector_id * vocab_size; // trtllm needs absolute id
v[i] = val + c[0]; topk_tmp_val_buf[i] = val + cum_log_probs[0];
} }
} }
} }
} }
template <typename T, int ITEMS_PER_THREAD, int MAX_K2, int THREADBLOCK_SIZE> template <typename T, int ITEMS_PER_THREAD, int MAX_K2, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_stage1_kernel_base(
void beam_online_softmax_topk_stage1_kernel_base(const T* __restrict x, const T* __restrict b, const T* __restrict log_probs, const T* __restrict bias, const FinishedState* __restrict finished,
const FinishedState* __restrict finished, float* __restrict t, int V, int K, const int* __restrict end_ids) float* __restrict tmp_buffer, int vocab_size, int K, const int* __restrict end_ids)
{ {
int thread_id = threadIdx.x; const int thread_id = threadIdx.x;
int vector_id = blockIdx.x; // batch beam index. const int vector_id = blockIdx.x;
const T MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
const int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2; const int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2;
    // the vocabulary of each beam is split into gridDim.y sections; each threadblock handles one section
const bool IS_FP16 = std::is_same<T, half>::value; const int v_local = (vocab_size + gridDim.y - 1) / gridDim.y;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
// one will have multiple sections per V
const int v_local = (V + gridDim.y - 1) / gridDim.y;
const int section_start = v_local * blockIdx.y; const int section_start = v_local * blockIdx.y;
int section_end = section_start + v_local; const int section_end = std::min(section_start + v_local, vocab_size);
section_end = (section_end > V) ? V : section_end;
// reposition x to data for the current vector
x += vector_id * V;
#if TOPK_FP16_STORAGE == 1 #if TOPK_FP16_STORAGE == 1
typedef cub::BlockReduce<TopKMD<__half, MAX_K2>, THREADBLOCK_SIZE> BlockReduce; typedef cub::BlockReduce<TopKMD<__half, MAX_K2>, THREADBLOCK_SIZE> BlockReduce;
#else #else
typedef cub::BlockReduce<TopKMD<T, MAX_K2>, THREADBLOCK_SIZE> BlockReduce; typedef cub::BlockReduce<TopKMD<T, MAX_K2>, THREADBLOCK_SIZE> BlockReduce;
#endif #endif
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float buf_s[PACKED_TOP_KMD_SIZE]; // save intermediate result __shared__ float buf_s[PACKED_TOP_KMD_SIZE];
// reposition log_probs to data for the current vector
log_probs += vector_id * vocab_size;
#if TOPK_FP16_STORAGE == 1 #if TOPK_FP16_STORAGE == 1
TopKMD<__half, MAX_K2> partial; TopKMD<__half, MAX_K2> partial;
#else #else
TopKMD<T, MAX_K2> partial; TopKMD<T, MAX_K2> partial;
#endif #endif
bool finish = finished[vector_id].isFinished();
for (int i = 0; i < MAX_K2; ++i) for (int i = 0; i < MAX_K2; ++i)
{ {
partial.topk.p[i] = -1; partial.topk.p[i] = -1;
@ -480,7 +515,7 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__
partial.md.m = -MAX_T_VAL; partial.md.m = -MAX_T_VAL;
partial.md.d = 0.0F; partial.md.d = 0.0F;
if (finish) if (finished[vector_id].isFinished())
{ {
#pragma unroll 1 #pragma unroll 1
for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE)
@ -496,8 +531,8 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__
#pragma unroll 1 #pragma unroll 1
for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE)
{ {
T bias = b == nullptr ? (T) 0.0f : b[elem_id]; // gpt-2 does not use bias T b = bias == nullptr ? (T) 0.0f : bias[elem_id];
T elem = x[elem_id] + bias; T elem = log_probs[elem_id] + b;
MD new_elem{elem, 1.0F}; MD new_elem{elem, 1.0F};
partial.md = reduce_md_op(partial.md, new_elem); partial.md = reduce_md_op(partial.md, new_elem);
partial.topk.insert(elem, elem_id); partial.topk.insert(elem, elem_id);
@ -514,7 +549,7 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__
{ {
for (int i = 0; i < 2 * K; i++) for (int i = 0; i < 2 * K; i++)
{ {
reinterpret_cast<int*>(buf_s)[i] = total.topk.p[i] + vector_id * V; // trtllm needs absolute id reinterpret_cast<int*>(buf_s)[i] = total.topk.p[i] + vector_id * vocab_size; // trtllm needs absolute id
buf_s[MAX_K2 + i] = total.topk.u[i]; buf_s[MAX_K2 + i] = total.topk.u[i];
} }
buf_s[2 * MAX_K2] = total.md.d; buf_s[2 * MAX_K2] = total.md.d;
@ -523,38 +558,25 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__
__syncthreads(); __syncthreads();
for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; elem_id += THREADBLOCK_SIZE) for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; elem_id += THREADBLOCK_SIZE)
{ {
t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] = buf_s[elem_id]; tmp_buffer[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id]
= buf_s[elem_id];
} }
} }
template <typename T, int ITEMS_PER_THREAD, int MAX_K2, int THREADBLOCK_SIZE> template <typename T, int ITEMS_PER_THREAD, int MAX_K2, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_stage1_kernel_fast( __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_stage1_kernel_fast(
const T* __restrict x, const T* __restrict b, const FinishedState* __restrict finished, float* __restrict t, int V, const T* __restrict log_probs, const T* __restrict bias, const FinishedState* __restrict finished,
int K, const int* __restrict end_ids, const int v_local) float* __restrict t, int vocab_size, int K, const int* __restrict end_ids, const int v_local)
{ {
extern __shared__ char buf_smem_logprobs_[]; const int thread_id = threadIdx.x;
T* buf_smem_logprobs = reinterpret_cast<T*>(buf_smem_logprobs_); const int vector_id = blockIdx.x;
const T MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
int thread_id = threadIdx.x;
int vector_id = blockIdx.x; // batch beam index.
const int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2; const int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2;
    // the vocabulary of each beam is split into gridDim.y sections; each threadblock handles one section
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
// reposition x to data for the current vector
x += vector_id * V;
// one will have multiple sections per V
const int section_start = v_local * blockIdx.y; const int section_start = v_local * blockIdx.y;
int section_end = section_start + v_local; const int section_end = std::min(section_start + v_local, vocab_size);
section_end = (section_end > V) ? V : section_end;
const int valid_smem_length = section_end - section_start; const int valid_smem_length = section_end - section_start;
bool finish = finished[vector_id].isFinished();
MD partial_md{-MAX_T_VAL, 0.0f};
#if TOPK_FP16_STORAGE == 1 #if TOPK_FP16_STORAGE == 1
using cub_kvp = cub::KeyValuePair<int, __half>; using cub_kvp = cub::KeyValuePair<int, __half>;
using BlockReduceTopK = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>; using BlockReduceTopK = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>;
@ -565,10 +587,25 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>; using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;
cub::ArgMax arg_max; extern __shared__ char buf_smem_logprobs_[];
cub_kvp partial_topk{V - 1, -MAX_T_VAL}; T* buf_smem_logprobs = reinterpret_cast<T*>(buf_smem_logprobs_);
if (finish) __shared__ union
{
typename BlockReduceMD::TempStorage md_smem;
typename BlockReduceTopK::TempStorage topk_smem;
} temp_storage;
__shared__ float buf_s[PACKED_TOP_KMD_SIZE];
__shared__ int thread_requiring_update;
// reposition log_probs to data for the current vector
log_probs += vector_id * vocab_size;
cub::ArgMax arg_max;
cub_kvp partial_topk{vocab_size - 1, -MAX_T_VAL};
MD partial_md{-MAX_T_VAL, 0.0f};
if (finished[vector_id].isFinished())
{ {
#pragma unroll 1 #pragma unroll 1
for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE)
@ -589,8 +626,8 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
#pragma unroll 1 #pragma unroll 1
for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE)
{ {
T bias = b == nullptr ? (T) 0.0f : b[elem_id]; // gpt-2 does not use bias T b = bias == nullptr ? (T) 0.0f : bias[elem_id];
T elem = x[elem_id] + bias; T elem = log_probs[elem_id] + b;
MD new_elem_md{elem, 1.0F}; MD new_elem_md{elem, 1.0F};
partial_md = reduce_md_op(partial_md, new_elem_md); partial_md = reduce_md_op(partial_md, new_elem_md);
@ -600,18 +637,8 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
buf_smem_logprobs[smem_index] = elem; buf_smem_logprobs[smem_index] = elem;
} }
} }
__syncthreads(); __syncthreads();
__shared__ union
{
typename BlockReduceMD::TempStorage md_smem;
typename BlockReduceTopK::TempStorage topk_smem;
} temp_storage;
__shared__ float buf_s[PACKED_TOP_KMD_SIZE]; // save intermediate result
__shared__ int thread_requiring_update;
for (int i = 0; i < 2 * K; ++i) for (int i = 0; i < 2 * K; ++i)
{ {
cub_kvp total_topk = BlockReduceTopK(temp_storage.topk_smem).Reduce(partial_topk, arg_max); cub_kvp total_topk = BlockReduceTopK(temp_storage.topk_smem).Reduce(partial_topk, arg_max);
@ -619,18 +646,18 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
if (threadIdx.x == 0) if (threadIdx.x == 0)
{ {
reinterpret_cast<int*>(buf_s)[i] reinterpret_cast<int*>(buf_s)[i]
= section_start + total_topk.key + vector_id * V; // trtllm needs absolute id = section_start + total_topk.key + vector_id * vocab_size; // trtllm needs absolute id
buf_s[MAX_K2 + i] = total_topk.value; buf_s[MAX_K2 + i] = total_topk.value;
buf_smem_logprobs[total_topk.key] = -MAX_T_VAL; buf_smem_logprobs[total_topk.key] = -MAX_T_VAL;
thread_requiring_update = total_topk.key % THREADBLOCK_SIZE; thread_requiring_update = total_topk.key % THREADBLOCK_SIZE;
} }
__syncthreads(); __syncthreads();
// Only 1 thread needs to update the old partial before the next block reduce. We don't need to do this update // Only one thread needs to update the old partial before the next block reduce.
// on the last iteration. // No need to do this in the last iteration.
if (thread_id == thread_requiring_update && i < (2 * K - 1)) if (thread_id == thread_requiring_update && i < (2 * K - 1))
{ {
partial_topk.key = V - 1; partial_topk.key = vocab_size - 1;
partial_topk.value = -MAX_T_VAL; partial_topk.value = -MAX_T_VAL;
for (int tid = thread_id; tid < valid_smem_length; tid += THREADBLOCK_SIZE) for (int tid = thread_id; tid < valid_smem_length; tid += THREADBLOCK_SIZE)
{ {
@ -649,7 +676,6 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
buf_s[2 * MAX_K2 + 1] = total_md.m; buf_s[2 * MAX_K2 + 1] = total_md.m;
} }
__syncthreads(); __syncthreads();
for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; elem_id += THREADBLOCK_SIZE) for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; elem_id += THREADBLOCK_SIZE)
{ {
t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] = buf_s[elem_id]; t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] = buf_s[elem_id];
@ -657,58 +683,52 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
} }
template <typename T, int MAX_K2, int THREADBLOCK_SIZE> template <typename T, int MAX_K2, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_stage2_kernel(const float* __restrict x, __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_stage2_kernel(
const float* __restrict c, int* __restrict z, T* __restrict v, int K, int parts_per_beam, const int V) const float* __restrict temp_storage, const float* __restrict cum_log_probs, int* __restrict ids,
T* __restrict vals, int K, int parts_per_beam, const int vocab_size)
{ {
const int vector_id = blockIdx.x; const int vector_id = blockIdx.x;
const int thread_id = threadIdx.x; const int thread_id = threadIdx.x;
const T MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
const int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2; const int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
extern __shared__ char buf_s_[]; // intermediate result
float* buf_s = reinterpret_cast<float*>(buf_s_);
using cub_kvp = cub::KeyValuePair<int, T>; using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduceTopK = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>; using BlockReduceTopK = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>;
using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>; using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;
extern __shared__ char buf_s_[];
float* buf_s = reinterpret_cast<float*>(buf_s_);
__shared__ cub_kvp buf_smem_kv[MAX_K2];
__shared__ union __shared__ union
{ {
typename BlockReduceTopK::TempStorage topk_smem; typename BlockReduceTopK::TempStorage topk_smem;
typename BlockReduceMD::TempStorage md_smem; typename BlockReduceMD::TempStorage md_smem;
} temp_storage; } shared_temp_storage;
temp_storage += vector_id * PACKED_TOP_KMD_SIZE * parts_per_beam;
cub::ArgMax arg_max; cub::ArgMax arg_max;
x += vector_id * PACKED_TOP_KMD_SIZE * parts_per_beam;
MD partial_md{-MAX_T_VAL, 0.0f}; MD partial_md{-MAX_T_VAL, 0.0f};
cub_kvp total_topk{V - 1, -MAX_T_VAL}; cub_kvp total_topk{vocab_size - 1, -MAX_T_VAL};
__shared__ char buf_smem_kv_store[MAX_K2 * sizeof(cub_kvp)]; // Load and unpack into registers through smem
auto* buf_smem_kv = reinterpret_cast<cub_kvp*>(buf_smem_kv_store);
// load and unpack into registers through smem
for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * parts_per_beam; idx += THREADBLOCK_SIZE) for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * parts_per_beam; idx += THREADBLOCK_SIZE)
{ {
buf_s[idx] = x[idx]; buf_s[idx] = temp_storage[idx];
} }
__syncthreads(); __syncthreads();
    // find the argmax within each parts_per_beam,                                               // Find the argmax within each of the parts_per_beam chunks
    // find the topK across all parts_per_beam.                                                  // Find the topK across all chunks
for (int k = 0; k < 2 * K; ++k) for (int k = 0; k < 2 * K; ++k)
{ {
cub_kvp partial_topk{V - 1, -MAX_T_VAL}; cub_kvp partial_topk{vocab_size - 1, -MAX_T_VAL};
// Only threads responsible for a chunk will do the computation // Only threads responsible for a chunk will do the computation
if (threadIdx.x < parts_per_beam) if (threadIdx.x < parts_per_beam)
{ {
float* b_s = buf_s + threadIdx.x * PACKED_TOP_KMD_SIZE; float* b_s = buf_s + threadIdx.x * PACKED_TOP_KMD_SIZE;
for (int i = 0; i < K; ++i) for (int i = 0; i < K; ++i)
{ {
int current_index = threadIdx.x * PACKED_TOP_KMD_SIZE + i; int current_index = threadIdx.x * PACKED_TOP_KMD_SIZE + i;
@ -718,21 +738,20 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta
} }
} }
cub_kvp total_topk = BlockReduceTopK(temp_storage.topk_smem).Reduce(partial_topk, arg_max); cub_kvp total_topk = BlockReduceTopK(shared_temp_storage.topk_smem).Reduce(partial_topk, arg_max);
__syncthreads(); __syncthreads();
if (threadIdx.x == 0) if (threadIdx.x == 0)
{ {
// store kv pairs in shared mem buffer // Store kv pairs in shared mem buffer
int temp_offset = total_topk.key; int temp_offset = total_topk.key;
int global_offset = reinterpret_cast<int*>(buf_s)[temp_offset]; int global_offset = reinterpret_cast<int*>(buf_s)[temp_offset];
total_topk.key = global_offset; total_topk.key = global_offset;
buf_smem_kv[k] = total_topk; buf_smem_kv[k] = total_topk;
// Invalidate the maximum value within the chunk // Invalidate the maximum value within the chunk
            reinterpret_cast<int*>(buf_s)[temp_offset] = V - 1; // id in shared memory               reinterpret_cast<int*>(buf_s)[temp_offset] = vocab_size - 1; // id in shared memory
            buf_s[temp_offset + MAX_K2] = -MAX_T_VAL;           // value in shared memory            buf_s[temp_offset + MAX_K2] = -MAX_T_VAL;                    // value in shared memory
} }
__syncthreads(); __syncthreads();
} }
@ -744,18 +763,16 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta
partial_md.d = b_s[2 * MAX_K2]; partial_md.d = b_s[2 * MAX_K2];
partial_md.m = b_s[2 * MAX_K2 + 1]; partial_md.m = b_s[2 * MAX_K2 + 1];
} }
__syncthreads();
auto reduce_md_func = [](const MD& a, const MD& b) { return reduce_md_op(a, b); }; auto reduce_md_func = [](const MD& a, const MD& b) { return reduce_md_op(a, b); };
MD total_md = BlockReduceMD(temp_storage.md_smem).Reduce(partial_md, reduce_md_func); MD total_md = BlockReduceMD(shared_temp_storage.md_smem).Reduce(partial_md, reduce_md_func);
__syncthreads();
if (thread_id == 0) if (thread_id == 0)
{ {
z += vector_id * 2 * K; ids += vector_id * 2 * K;
v += vector_id * 2 * K; vals += vector_id * 2 * K;
c += vector_id; cum_log_probs += vector_id;
float d_total_log = logf(total_md.d); float d_total_log = logf(total_md.d);
for (int i = 0; i < MAX_K2; ++i) for (int i = 0; i < MAX_K2; ++i)
@ -763,8 +780,8 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta
float val = (float) buf_smem_kv[i].value - total_md.m - d_total_log; float val = (float) buf_smem_kv[i].value - total_md.m - d_total_log;
if (i < 2 * K) if (i < 2 * K)
{ {
z[i] = buf_smem_kv[i].key; ids[i] = buf_smem_kv[i].key;
v[i] = (float) val + (float) c[0]; vals[i] = (float) val + (float) cum_log_probs[0];
} }
} }
} }
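Stage 1 above keeps, per vocabulary section, a running maximum m and denominator d alongside the partial top-K, and stage 2 merges those (m, d) pairs before turning the selected logits into log-softmax values as u - m - log(d). The merge is the standard online-softmax reduction; a host-side sketch (illustrative, not the device code):

// Host-side illustration of the online-softmax merge used by the staged kernels above.
#include <cmath>
#include <cstdio>
#include <vector>

struct MD
{
    float m; // running maximum
    float d; // running sum of exp(x - m)
};

MD reduceMD(MD a, MD b)
{
    const bool aBigger = a.m > b.m;
    const MD bigger = aBigger ? a : b;
    const MD smaller = aBigger ? b : a;
    return MD{bigger.m, bigger.d + smaller.d * std::exp(smaller.m - bigger.m)};
}

int main()
{
    const std::vector<float> logits{1.0f, 3.0f, -2.0f, 0.5f};
    // Two sections reduced independently (as the stage-1 kernels do), then merged (as stage 2 does).
    MD left{-INFINITY, 0.0f};
    MD right{-INFINITY, 0.0f};
    for (int i = 0; i < 2; ++i)
    {
        left = reduceMD(left, MD{logits[i], 1.0f});
        right = reduceMD(right, MD{logits[i + 2], 1.0f});
    }
    const MD total = reduceMD(left, right);
    // log-softmax of the selected logit 3.0 is logit - m - log(d), roughly -0.20 here
    std::printf("logsoftmax(3.0) = %f\n", logits[1] - total.m - std::log(total.d));
    return 0;
}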
@ -774,9 +791,9 @@ template <typename T, int MAX_K2>
void beam_online_softmax_topk_stage2_kernelLauncher(const float* temp_storage, const float* cum_log_probs, int* ids, void beam_online_softmax_topk_stage2_kernelLauncher(const float* temp_storage, const float* cum_log_probs, int* ids,
T* vals, int batch_size, int beam_width, int parts_per_beam, cudaStream_t stream, const int vocab_size) T* vals, int batch_size, int beam_width, int parts_per_beam, cudaStream_t stream, const int vocab_size)
{ {
    // might rewrite beam_online_softmax_topk_stage2_kernel not to depend on                     // TODO: rewrite beam_online_softmax_topk_stage2_kernel to remove the dependence
    // constant block size in order to reduce compilation time                                   // on constant block size in order to reduce compilation time
int smem_stage2_size = parts_per_beam * (2 * MAX_K2 + 2) * sizeof(float); const int smem_stage2_size = parts_per_beam * (2 * MAX_K2 + 2) * sizeof(float);
if (parts_per_beam <= 32) if (parts_per_beam <= 32)
{ {
@ -803,20 +820,22 @@ void beam_online_softmax_topk_stage2_kernelLauncher(const float* temp_storage, c
} }
template <typename T, int MAX_K> template <typename T, int MAX_K>
void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const FinishedState* finished, void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs,
const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, int** output_ids_ptr, void* temp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream)
void* temp_storage, const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size,
const int beam_width, const int vocab_size, const int* end_ids, const float* diversity_rates,
const float* length_penalties, cudaStream_t stream)
{ {
const int batch_size{beam_hyps.local_batch_size};
const int beam_width{beam_hyps.beam_width};
const int vocab_size{beam_hyps.vocab_size};
const int* end_ids{beam_hyps.end_ids};
const int items_per_thread = 1; const int items_per_thread = 1;
const int block_sz = (MAX_K < 16) ? (MAX_K < 8) ? SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE : 128 : 64; const int block_sz = (MAX_K < 16) ? ((MAX_K < 8) ? SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE : 128) : 64;
// const int block_sz = SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE; // const int block_sz = SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE;
assert(temp_storage_size % 2 == 0); assert(temp_storage_size % 2 == 0);
assert(temp_storage_size >= 2 * batch_size * beam_width * beam_width * 2); assert(temp_storage_size >= 2 * batch_size * beam_width * beam_width * 2);
// Beam search needs the sequence lengths of beams to apply length penalty. // Beam search needs the sequence lengths of beams to apply length penalty.
assert(length_penalties == nullptr || sequence_lengths != nullptr); assert(beam_hyps.length_penalties == nullptr || beam_hyps.sequence_lengths_src != nullptr);
const int topk_buf_offset = ceil(batch_size * beam_width * beam_width * 2 / 4.) * 4; const int topk_buf_offset = ceil(batch_size * beam_width * beam_width * 2 / 4.) * 4;
int* topk_tmp_id_buf = reinterpret_cast<int*>(temp_storage); int* topk_tmp_id_buf = reinterpret_cast<int*>(temp_storage);
@ -928,25 +947,22 @@ void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const Finish
// we will not put them into next iteration // we will not put them into next iteration
const int candidates = beam_width * beam_width * 2; const int candidates = beam_width * beam_width * 2;
int smem_size_batch_topk = sizeof(T) * candidates; const int smem_size_batch_topk = sizeof(T) * candidates;
if (smem_size_batch_topk >= (48 << 10)) if (smem_size_batch_topk >= (48 << 10))
{ {
TLLM_CUDA_CHECK(cudaFuncSetAttribute( TLLM_CUDA_CHECK(cudaFuncSetAttribute(
batch_topk_kernel<T, MAX_K * 2, 32>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_batch_topk)); batch_topk_kernel<T, MAX_K * 2, 32>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_batch_topk));
} }
batch_topk_kernel<T, MAX_K * 2, 32><<<batch_size, 32, smem_size_batch_topk, stream>>>(topk_tmp_id_buf, batch_topk_kernel<T, MAX_K * 2, 32><<<batch_size, 32, smem_size_batch_topk, stream>>>(
topk_tmp_val_buf, output_ids_ptr, cum_log_probs, output_log_probs, finished, sequence_lengths, *beam_hyps, topk_tmp_id_buf, topk_tmp_val_buf, cum_log_probs, finished, beam_hyps, candidates);
candidates, beam_width, vocab_size, length_penalties, diversity_rates);
sync_check_cuda_error(); sync_check_cuda_error();
} }
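The 48 KB guard above reflects a general CUDA rule: launching a kernel with more than 48 KB of dynamic shared memory per block requires an explicit opt-in via cudaFuncSetAttribute with cudaFuncAttributeMaxDynamicSharedMemorySize. A standalone sketch of that opt-in (the kernel below is illustrative, not the library's):

// Minimal CUDA sketch of the dynamic shared-memory opt-in, not library code.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void needsLargeSmem(float* out)
{
    extern __shared__ float buf[];
    buf[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    if (threadIdx.x == 0)
    {
        out[blockIdx.x] = buf[0];
    }
}

int main()
{
    const int smemBytes = 64 << 10; // 64 KB exceeds the 48 KB default limit
    float* out = nullptr;
    cudaMalloc(&out, sizeof(float));
    if (smemBytes >= (48 << 10))
    {
        cudaFuncSetAttribute(needsLargeSmem, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes);
    }
    needsLargeSmem<<<1, 32, smemBytes>>>(out);
    std::printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(out);
    return 0;
}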
#define INSTANTIATE_BEAMSEARCH_K(T, MAX_K) \ #define INSTANTIATE_BEAMSEARCH_K(T, MAX_K) \
template void topK_softMax_kernelLauncher<T, MAX_K>(const T* log_probs, const T* bias, \ template void topK_softMax_kernelLauncher<T, MAX_K>(const T* log_probs, const T* bias, \
const FinishedState* finished, const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, \ const FinishedState* finished, float* cum_log_probs, void* temp_storage, const int temp_storage_size, \
int** output_ids_ptr, void* temp_storage, const int temp_storage_size, BeamHypotheses* beam_hyps, \ BeamHypotheses& beam_hyps, cudaStream_t stream);
const int batch_size, const int beam_width, const int vocab_size, const int* end_ids, \
const float* diversity_rates, const float* length_penalties, cudaStream_t stream);
} // namespace kernels } // namespace kernels
} // namespace tensorrt_llm } // namespace tensorrt_llm
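The assertions in topK_softMax_kernelLauncher above encode a two-stage candidate buffer: each batch entry keeps beam_width * beam_width * 2 top-k candidates, and the id and value halves are padded to a multiple of four elements. The standalone sketch below (illustrative only, not part of TensorRT-LLM; the sizes are hypothetical) reproduces that arithmetic for batch_size = 8 and beam_width = 4.

#include <cassert>
#include <cmath>
#include <cstdio>

// Minimal sketch: reproduces the workspace-size arithmetic asserted in
// topK_softMax_kernelLauncher, assuming the buffer holds 2 * beam_width^2
// candidates per batch entry (ids + values), padded to a multiple of 4.
int main()
{
    const int batch_size = 8;   // hypothetical
    const int beam_width = 4;   // hypothetical

    // Candidates kept per batch entry: beam_width parents * beam_width * 2 children.
    const int candidates = beam_width * beam_width * 2;

    // Offset of the value buffer, rounded up to 4 elements as in the launcher.
    const int topk_buf_offset = static_cast<int>(std::ceil(batch_size * candidates / 4.0)) * 4;

    // The launcher asserts at least twice this amount (id buffer + value buffer).
    const int min_temp_storage = 2 * batch_size * candidates;

    assert(min_temp_storage % 2 == 0);
    std::printf("candidates=%d topk_buf_offset=%d min_temp_storage=%d\n",
                candidates, topk_buf_offset, min_temp_storage);
    return 0;
}

For these values the value buffer starts at element 256 and at least 512 workspace elements are required, which is exactly what the asserts above check.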

View File

@ -70,7 +70,7 @@ void splitkGroupedGemm_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std
int64_t gemmParamsWorkSpaceSize, void* gemmWorkSpace, int64_t gemmWorkSpaceSize, int64_t gemmParamsWorkSpaceSize, void* gemmWorkSpace, int64_t gemmWorkSpaceSize,
std::vector<int64_t> splitkBufferOffsets, int splitKSlices, cudaStream_t stream) std::vector<int64_t> splitkBufferOffsets, int splitKSlices, cudaStream_t stream)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
using ElementA = cutlassType; using ElementA = cutlassType;
using ElementB = cutlassType; using ElementB = cutlassType;
using ElementOutput = float; using ElementOutput = float;
@ -194,7 +194,7 @@ void splitkGroupedGemm_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to run CUTLASS Grouped GEMM kernel."); TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to run CUTLASS Grouped GEMM kernel.");
std::free(host_workspace); std::free(host_workspace);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
template <int M1, int N1, int K1, int M2, int N2, int K2> template <int M1, int N1, int K1, int M2, int N2, int K2>

View File

@ -113,7 +113,7 @@ __device__ __forceinline__ void apply_scale(void* act, void* act_scale)
} }
template <int N, int K, bool EnableZero> template <int N, int K, bool EnableZero>
__device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* scales, void* zeros) __device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* scales, void* zeros, half alpha)
{ {
using Converter = ConverterI4ToF16; using Converter = ConverterI4ToF16;
static_assert(K % 2 == 0); static_assert(K % 2 == 0);
@ -123,11 +123,11 @@ __device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* sca
{ {
ConverterI4ToF16::convert<K>( ConverterI4ToF16::convert<K>(
reinterpret_cast<uint8_t*>(quantized_w) + n * K / 2, reinterpret_cast<half*>(w) + n * K); reinterpret_cast<uint8_t*>(quantized_w) + n * K / 2, reinterpret_cast<half*>(w) + n * K);
half2 vec_scale = __half2half2(reinterpret_cast<half*>(scales)[n]); half2 vec_scale = __half2half2(reinterpret_cast<half*>(scales)[n] * alpha);
half2 vec_zero = __half2half2(__float2half_rn(0.f)); half2 vec_zero = __half2half2(__float2half_rn(0.f));
if constexpr (EnableZero) if constexpr (EnableZero)
{ {
vec_zero = __half2half2(reinterpret_cast<half*>(zeros)[n]); vec_zero = __half2half2(reinterpret_cast<half*>(zeros)[n] * alpha);
} }
#pragma unroll #pragma unroll
for (int k = 0; k < VecK; ++k) for (int k = 0; k < VecK; ++k)
@ -186,7 +186,7 @@ __device__ __forceinline__ T warp_reduce_sum(T& val)
} }
template <int CtaM, int CtaN, int Threads, bool EnableBias> template <int CtaM, int CtaN, int Threads, bool EnableBias>
__device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc, void* bias, float alpha) __device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc, void* bias)
{ {
static constexpr int WarpSize = 32; static constexpr int WarpSize = 32;
static constexpr int WarpNum = Threads / WarpSize; static constexpr int WarpNum = Threads / WarpSize;
@ -224,7 +224,7 @@ __device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc,
{ {
val += shmem[jj * AlignShmemSize + ii]; val += shmem[jj * AlignShmemSize + ii];
} }
reinterpret_cast<half*>(out)[m * stride + n] = __float2half_rn(alpha * val + v_bias); reinterpret_cast<half*>(out)[m * stride + n] = __float2half_rn(val + v_bias);
} }
} }
@ -348,15 +348,17 @@ __global__ void kernel(typename Details::ActDataType* act, half* act_scale, uint
load<half, CtaN, Mandatory>(vec_scale, scales + idx_k / GroupSize * n, 1); load<half, CtaN, Mandatory>(vec_scale, scales + idx_k / GroupSize * n, 1);
load<half, CtaN, EnableZero>(vec_zero, zeros + idx_k / GroupSize * n, 1); load<half, CtaN, EnableZero>(vec_zero, zeros + idx_k / GroupSize * n, 1);
// Dequantize Data // Dequantize Data
// W4A8 checkpoints have larger activation and weight values. To prevent the warp-level FP16
// accumulator from overflowing, the multiplication by alpha is moved from the epilogue into dequantize
apply_scale<CtaM, StepK, EnableActScale>(tile_a, vec_act_scale); apply_scale<CtaM, StepK, EnableActScale>(tile_a, vec_act_scale);
dequantize<CtaN, StepK, EnableZero>(tile_w, tile_w_quantized, vec_scale, vec_zero); dequantize<CtaN, StepK, EnableZero>(tile_w, tile_w_quantized, vec_scale, vec_zero, __float2half_rn(alpha));
// Rearrange // Rearrange
pack_to_vec2<CtaN, StepK>(tile_w_pack2, tile_w); pack_to_vec2<CtaN, StepK>(tile_w_pack2, tile_w);
// MMA // MMA
mma<CtaM, CtaN, StepK>(tile_acc, tile_w_pack2, tile_a); mma<CtaM, CtaN, StepK>(tile_acc, tile_w_pack2, tile_a);
} }
// Epilogue // Epilogue
epilogue<CtaM, CtaN, Threads, EnableBias>(out, n, tile_acc, bias, alpha); epilogue<CtaM, CtaN, Threads, EnableBias>(out, n, tile_acc, bias);
} }
template <typename Details, int CtaM, int CtaN, int Threads, int GroupSize, bool EnableActScale, bool EnableZero, template <typename Details, int CtaM, int CtaN, int Threads, int GroupSize, bool EnableActScale, bool EnableZero,
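The comment above about W4A8 checkpoints is worth a concrete illustration: folding alpha into the dequantized weights keeps every partial product small, whereas applying alpha only in the epilogue lets the raw FP16 dot product grow past the representable range (about 65504). The sketch below uses made-up magnitudes purely to show the effect; it is not derived from any real checkpoint.

#include <cstdio>

// Illustrative only: why pre-scaling by alpha keeps an FP16 accumulator in range.
int main()
{
    const float kFp16Max = 65504.f;   // largest finite FP16 value
    const float alpha = 1.f / 64.f;   // hypothetical per-tensor scale
    const int k = 512;                // accumulation depth
    const float a = 12.f, w = 14.f;   // hypothetical activation / dequantized weight magnitudes

    // Epilogue-side alpha: the raw dot product exceeds FP16 range before alpha is applied.
    float acc_raw = 0.f;
    for (int i = 0; i < k; ++i) acc_raw += a * w;               // 512 * 168 = 86016 > 65504
    std::printf("raw accumulator %.1f overflows FP16: %s\n", acc_raw, acc_raw > kFp16Max ? "yes" : "no");

    // Dequantize-side alpha: every product is pre-scaled, so partial sums stay small.
    float acc_scaled = 0.f;
    for (int i = 0; i < k; ++i) acc_scaled += a * (w * alpha);  // 512 * 2.625 = 1344 < 65504
    std::printf("pre-scaled accumulator %.1f overflows FP16: %s\n", acc_scaled, acc_scaled > kFp16Max ? "yes" : "no");
    return 0;
}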

View File

@ -251,6 +251,7 @@ void DynamicDecodeLayer<T>::setupLayers(
beamSearchParams.beam_search_diversity_rate = setupParams.beam_search_diversity_rate; beamSearchParams.beam_search_diversity_rate = setupParams.beam_search_diversity_rate;
beamSearchParams.length_penalty = setupParams.length_penalty; beamSearchParams.length_penalty = setupParams.length_penalty;
beamSearchParams.early_stopping = setupParams.early_stopping;
mHasDiffRuntimeArgs = hasDiffRuntimeArgs(beamSearchParams); mHasDiffRuntimeArgs = hasDiffRuntimeArgs(beamSearchParams);
mOnlineBeamSearchDecode->setup(batchSize, beamSearchParams); mOnlineBeamSearchDecode->setup(batchSize, beamSearchParams);

View File

@ -75,6 +75,7 @@ public:
// onlineBeamSearchLayer // onlineBeamSearchLayer
std::optional<std::vector<float>> beam_search_diversity_rate; std::optional<std::vector<float>> beam_search_diversity_rate;
std::optional<std::vector<float>> length_penalty; std::optional<std::vector<float>> length_penalty;
std::optional<std::vector<int>> early_stopping;
std::optional<bool> normalize_log_probs; std::optional<bool> normalize_log_probs;
}; };

View File

@ -31,12 +31,21 @@ static const int SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS = 128;
static const int MAX_K = 4; static const int MAX_K = 4;
template <typename T> template <typename T>
__global__ void update_kernel(FinishedState* finished, int** parent_ids_ptr, int* sequence_lengths, __global__ void update_kernel(FinishedState* finished, BeamHypotheses beam_hyps)
int** output_ids_ptr, BeamHypotheses beam_hyps, const int vocab_size, const int* end_ids,
const int local_batch_size, const int beam_width, const int max_seq_len)
{ {
const int beam_width{beam_hyps.beam_width};
const int ite{beam_hyps.ite};
const int local_batch_size{beam_hyps.local_batch_size};
const int max_seq_len{beam_hyps.max_seq_len};
const int vocab_size{beam_hyps.vocab_size};
const int end_id{beam_hyps.end_ids[blockIdx.x]};
int* num_beams{beam_hyps.num_beams};
int* sequence_lengths{beam_hyps.sequence_lengths_src};
int** output_ids_ptr{beam_hyps.output_ids_tgt_ptr};
int** parent_ids_ptr{beam_hyps.parent_ids_tgt_ptr};
extern __shared__ char s_buf[]; // intermediate result extern __shared__ char s_buf[]; // intermediate result
int* s_sequence_lengths = (int*) (s_buf); int* s_sequence_lengths = reinterpret_cast<int*>(s_buf);
for (int beam_idx = threadIdx.x; beam_idx < beam_width; beam_idx += blockDim.x) for (int beam_idx = threadIdx.x; beam_idx < beam_width; beam_idx += blockDim.x)
{ {
@ -63,35 +72,27 @@ __global__ void update_kernel(FinishedState* finished, int** parent_ids_ptr, int
new_word_id = new_word_id % vocab_size; new_word_id = new_word_id % vocab_size;
sequence_lengths[batch_beam_idx] = s_sequence_lengths[new_beam_id]; sequence_lengths[batch_beam_idx] = s_sequence_lengths[new_beam_id];
if (new_word_id == end_ids[blockIdx.x]) if (new_word_id == end_id)
{ {
finished[batch_beam_idx].setFinishedEOS(); finished[batch_beam_idx].setFinishedEOS();
} }
parent_ids_ptr[blockIdx.x][beam_idx * max_seq_len + current_step] = new_beam_id; parent_ids_ptr[blockIdx.x][beam_idx * max_seq_len + current_step] = new_beam_id;
output_ids_ptr[blockIdx.x][beam_idx * max_seq_len + current_step] = new_word_id; output_ids_ptr[blockIdx.x][beam_idx * max_seq_len + current_step] = new_word_id;
} }
if (beam_hyps.num_beams != nullptr) if (num_beams != nullptr && num_beams[ite * local_batch_size + blockIdx.x] == beam_width)
{ {
if (beam_hyps.num_beams[beam_hyps.ite * beam_hyps.local_batch_size + blockIdx.x] == beam_width) for (int beam_idx = threadIdx.x; beam_idx < beam_width; beam_idx += blockDim.x)
{ {
for (int beam_idx = threadIdx.x; beam_idx < beam_width; beam_idx += blockDim.x) finished[blockIdx.x * beam_width + beam_idx].setFinished();
{
const auto batch_beam_idx = blockIdx.x * beam_width + beam_idx;
finished[batch_beam_idx].setFinished();
}
} }
} }
} }
void invokeUpdate(FinishedState* finished, int** parent_ids_ptr, int* sequence_lengths, int** output_ids_ptr, void invokeUpdate(FinishedState* finished, BeamHypotheses& beam_hyps, cudaStream_t stream)
BeamHypotheses* beam_hyps, const int local_batch_size, const int beam_width, const int vocab_size_padded,
const int* end_ids, const int max_seq_len, cudaStream_t stream)
{ {
dim3 grid(local_batch_size); dim3 grid(beam_hyps.local_batch_size);
dim3 block(min(beam_width, 1024)); dim3 block(min(beam_hyps.beam_width, 1024));
update_kernel<float><<<grid, block, sizeof(int) * beam_hyps.beam_width, stream>>>(finished, beam_hyps);
update_kernel<float><<<grid, block, sizeof(int) * beam_width, stream>>>(finished, parent_ids_ptr, sequence_lengths,
output_ids_ptr, *beam_hyps, vocab_size_padded, end_ids, local_batch_size, beam_width, max_seq_len);
} }
template <typename T> template <typename T>
@ -103,11 +104,12 @@ void OnlineBeamSearchLayer<T>::setup(size_t batch_size, SetupParams const& setup
mDiversityRate.resize(batch_size); mDiversityRate.resize(batch_size);
mLengthPenalty.resize(batch_size); mLengthPenalty.resize(batch_size);
mEarlyStopping.resize(batch_size);
FillBuffers const fillBuffers{batch_size, batch_size, mStream}; FillBuffers const fillBuffers{batch_size, batch_size, mStream};
fillBuffers(setupParams.beam_search_diversity_rate, 0.0f, mDiversityRate, diversity_rates_buf_, (int*) nullptr); fillBuffers(setupParams.beam_search_diversity_rate, 0.0f, mDiversityRate, diversity_rates_buf_, (int*) nullptr);
fillBuffers(setupParams.length_penalty, 0.0f, mLengthPenalty, length_penalties_buf_, (int*) nullptr); fillBuffers(setupParams.length_penalty, 0.0f, mLengthPenalty, length_penalties_buf_, (int*) nullptr);
fillBuffers(setupParams.early_stopping, 1, mEarlyStopping, early_stoppings_buf_, (int*) nullptr);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
@ -115,44 +117,39 @@ template <typename T>
void OnlineBeamSearchLayer<T>::invokeSoftMax(BeamSearchOutputParams& outputs, SoftmaxParams const& params) void OnlineBeamSearchLayer<T>::invokeSoftMax(BeamSearchOutputParams& outputs, SoftmaxParams const& params)
{ {
TLLM_LOG_TRACE("%s", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s", __PRETTY_FUNCTION__);
Tensor const& output_ids_ptr = outputs.output_ids_ptr;
const auto batch_size = static_cast<std::int32_t>(output_ids_ptr.shape[0]);
const auto beam_width = static_cast<std::int32_t>(output_ids_ptr.shape[1]);
const auto max_seq_len = static_cast<std::int32_t>(output_ids_ptr.shape[2]);
const int ite{params.ite};
Tensor const& logits{params.logits};
const auto local_batch_size = logits.shape[0];
BeamHypotheses beamHypotheses;
auto* const end_ids = params.end_ids.template getPtr<const int>();
float* output_log_probs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr<float>() : nullptr;
auto* finished auto* finished
= reinterpret_cast<FinishedState*>(outputs.finished->template getPtr<FinishedState::UnderlyingType>()); = reinterpret_cast<FinishedState*>(outputs.finished->template getPtr<FinishedState::UnderlyingType>());
auto* sequence_lengths = outputs.sequence_length->template getPtr<int>();
BeamHypotheses beam_hyps;
if (outputs.beamHypotheses) if (outputs.beamHypotheses)
{ {
beamHypotheses = *outputs.beamHypotheses; beam_hyps = *outputs.beamHypotheses;
beamHypotheses.ite = ite; // Some of the beam_hyps members have been initialized before invokeSoftMax is called
beamHypotheses.local_batch_size = local_batch_size; beam_hyps.end_ids = params.end_ids.template getPtr<const int>();
beamHypotheses.batch_size = batch_size; beam_hyps.log_probs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr<float>() : nullptr;
beamHypotheses.max_seq_len = max_seq_len; beam_hyps.output_ids_src_ptr = outputs.output_ids_ptr.template getPtr<const int*>();
beamHypotheses.output_ids_src_ptr = output_ids_ptr.template getPtr<const int*>(); beam_hyps.output_ids_tgt_ptr = outputs.output_ids_ptr.template getPtr<int*>();
beamHypotheses.parent_ids_src_ptr = outputs.parent_ids_ptr.template getPtr<const int*>(); beam_hyps.parent_ids_src_ptr = outputs.parent_ids_ptr.template getPtr<const int*>();
beamHypotheses.sequence_lengths_src = sequence_lengths; beam_hyps.parent_ids_tgt_ptr = outputs.parent_ids_ptr.template getPtr<int*>();
beamHypotheses.log_probs_src = output_log_probs; beam_hyps.sequence_lengths_src = outputs.sequence_length->template getPtr<int>();
beamHypotheses.length_penalties = length_penalties_buf_;
beamHypotheses.end_ids = end_ids; beam_hyps.batch_size = static_cast<std::int32_t>(outputs.output_ids_ptr.shape[0]);
beam_hyps.beam_width = static_cast<std::int32_t>(outputs.output_ids_ptr.shape[1]);
beam_hyps.ite = params.ite;
beam_hyps.local_batch_size = params.logits.shape[0];
beam_hyps.max_seq_len = static_cast<std::int32_t>(outputs.output_ids_ptr.shape[2]);
beam_hyps.vocab_size = vocab_size_padded_;
beam_hyps.diversity_rates = diversity_rates_buf_;
beam_hyps.length_penalties = length_penalties_buf_;
beam_hyps.early_stoppings = early_stoppings_buf_;
} }
invokeTopkSoftMax(logits.template getPtr<T>(), (const T*) (nullptr), finished, sequence_lengths, invokeTopkSoftMax(params.logits.template getPtr<T>(), (const T*) (nullptr), finished,
outputs.cum_log_probs->template getPtr<float>(), output_log_probs, output_ids_ptr.getPtr<int*>(), outputs.cum_log_probs->template getPtr<float>(), topk_softmax_workspace_, topk_softmax_workspace_size_,
topk_softmax_workspace_, topk_softmax_workspace_size_, &beamHypotheses, local_batch_size, beam_width, beam_hyps, mStream);
vocab_size_padded_, end_ids, diversity_rates_buf_, length_penalties_buf_, mStream);
sync_check_cuda_error(); sync_check_cuda_error();
invokeUpdate(finished, outputs.parent_ids_ptr.template getPtr<int*>(), sequence_lengths, invokeUpdate(finished, beam_hyps, mStream);
output_ids_ptr.getPtr<int*>(), &beamHypotheses, local_batch_size, beam_width, vocab_size_padded_, end_ids,
max_seq_len, mStream);
sync_check_cuda_error(); sync_check_cuda_error();
} }
@ -169,6 +166,7 @@ void OnlineBeamSearchLayer<T>::allocateBuffer(size_t batch_size)
mAllocator->reMalloc(topk_softmax_workspace_, sizeof(float) * topk_softmax_workspace_size_, true)); mAllocator->reMalloc(topk_softmax_workspace_, sizeof(float) * topk_softmax_workspace_size_, true));
diversity_rates_buf_ = mAllocator->reMalloc(diversity_rates_buf_, sizeof(float) * batch_size, false); diversity_rates_buf_ = mAllocator->reMalloc(diversity_rates_buf_, sizeof(float) * batch_size, false);
length_penalties_buf_ = mAllocator->reMalloc(length_penalties_buf_, sizeof(float) * batch_size, false); length_penalties_buf_ = mAllocator->reMalloc(length_penalties_buf_, sizeof(float) * batch_size, false);
early_stoppings_buf_ = mAllocator->reMalloc(early_stoppings_buf_, sizeof(int) * batch_size, false);
mIsAllocateBuffer = true; mIsAllocateBuffer = true;
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@ -183,6 +181,7 @@ void OnlineBeamSearchLayer<T>::freeBuffer()
mAllocator->free((void**) (&topk_softmax_workspace_)); mAllocator->free((void**) (&topk_softmax_workspace_));
mAllocator->free((void**) (&diversity_rates_buf_)); mAllocator->free((void**) (&diversity_rates_buf_));
mAllocator->free((void**) (&length_penalties_buf_)); mAllocator->free((void**) (&length_penalties_buf_));
mAllocator->free((void**) (&early_stoppings_buf_));
mIsAllocateBuffer = false; mIsAllocateBuffer = false;
} }
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);

View File

@ -41,6 +41,7 @@ public:
public: public:
std::optional<std::vector<float>> beam_search_diversity_rate; // [1] or [batch_size] on cpu std::optional<std::vector<float>> beam_search_diversity_rate; // [1] or [batch_size] on cpu
std::optional<std::vector<float>> length_penalty; // [1] or [batch_size] on cpu std::optional<std::vector<float>> length_penalty; // [1] or [batch_size] on cpu
std::optional<std::vector<int>> early_stopping; // [1] or [batch_size] on cpu
}; };
OnlineBeamSearchLayer( OnlineBeamSearchLayer(
@ -71,8 +72,10 @@ protected:
std::vector<float> mDiversityRate; std::vector<float> mDiversityRate;
std::vector<float> mLengthPenalty; std::vector<float> mLengthPenalty;
std::vector<int> mEarlyStopping;
float* diversity_rates_buf_; float* diversity_rates_buf_;
float* length_penalties_buf_; float* length_penalties_buf_;
int* early_stoppings_buf_;
private: private:
void allocateBuffer(size_t batch_size); void allocateBuffer(size_t batch_size);

View File

@ -168,6 +168,7 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernel
memset(&xqaParams, 0, sizeof(XQAParams)); memset(&xqaParams, 0, sizeof(XQAParams));
xqaParams.data_type = ConvertMMHAToXQAParamsHelper<T, KVCacheBuffer>::data_type; xqaParams.data_type = ConvertMMHAToXQAParamsHelper<T, KVCacheBuffer>::data_type;
xqaParams.layer_idx = mLayerIdx;
xqaParams.num_q_heads = mNumHeads; xqaParams.num_q_heads = mNumHeads;
xqaParams.num_kv_heads = mNumKVHeads; xqaParams.num_kv_heads = mNumKVHeads;
xqaParams.head_size = mHeadSize; xqaParams.head_size = mHeadSize;
@ -363,8 +364,8 @@ INSTANTIATE_MMHA_DISPATCH(__nv_bfloat16, __nv_bfloat16)
#endif #endif
#undef INSTANTIATE_MMHA_DISPATCH #undef INSTANTIATE_MMHA_DISPATCH
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int head_size, int unidirectional, GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, int num_kv_heads, int head_size,
float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type, int unidirectional, float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type, float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
@ -374,7 +375,8 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_hea
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha, bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha,
bool use_paged_context_fmha, bool use_cache, bool is_medusa_enabled) bool use_paged_context_fmha, bool use_cache, bool is_medusa_enabled)
: mNumHeads(num_heads) : mLayerIdx(layer_idx)
, mNumHeads(num_heads)
, mNumKVHeads(num_kv_heads) , mNumKVHeads(num_kv_heads)
, mHeadSize(head_size) , mHeadSize(head_size)
, mUnidirectional(unidirectional) , mUnidirectional(unidirectional)
@ -477,6 +479,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(const void* data, size_t leng
const char *d = reinterpret_cast<const char*>(data), *a = d; const char *d = reinterpret_cast<const char*>(data), *a = d;
unsigned int kvCacheQuantMode; unsigned int kvCacheQuantMode;
read(d, mLayerIdx);
read(d, mNumHeads); read(d, mNumHeads);
read(d, mNumKVHeads); read(d, mNumKVHeads);
read(d, mHeadSize); read(d, mHeadSize);
@ -1490,11 +1493,12 @@ void GPTAttentionPluginCommon::destroy() noexcept
size_t GPTAttentionPluginCommon::getCommonSerializationSize() noexcept size_t GPTAttentionPluginCommon::getCommonSerializationSize() noexcept
{ {
return sizeof(mNumHeads) + sizeof(mNumKVHeads) + sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling) return sizeof(mLayerIdx) + sizeof(mNumHeads) + sizeof(mNumKVHeads) + sizeof(mHeadSize) + sizeof(mUnidirectional)
+ sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim) + sizeof(mRotaryEmbeddingBase) + sizeof(mQScaling) + sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim)
+ sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale) + sizeof(mRotaryEmbeddingMaxPositions) + sizeof(mRotaryEmbeddingBase) + sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale)
+ sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA) + sizeof(mFMHAForceFP32Acc) + sizeof(mRotaryEmbeddingMaxPositions) + sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA)
+ sizeof(mMultiBlockMode) + sizeof(mEnableXQA) + sizeof(unsigned int) // mKVCacheQuantMode + sizeof(mFMHAForceFP32Acc) + sizeof(mMultiBlockMode) + sizeof(mEnableXQA)
+ sizeof(unsigned int) // mKVCacheQuantMode
+ sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mPagedKVCache) + sizeof(mTokensPerBlock) + sizeof(mType) + sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mPagedKVCache) + sizeof(mTokensPerBlock) + sizeof(mType)
+ sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled) + sizeof(mCrossAttention) + sizeof(mMaxDistance) + sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled) + sizeof(mCrossAttention) + sizeof(mMaxDistance)
+ sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA) + sizeof(mPagedContextFMHA) + sizeof(mUseKVCache) + sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA) + sizeof(mPagedContextFMHA) + sizeof(mUseKVCache)
@ -1504,6 +1508,7 @@ size_t GPTAttentionPluginCommon::getCommonSerializationSize() noexcept
void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
{ {
char *d = static_cast<char*>(buffer), *a = d; char *d = static_cast<char*>(buffer), *a = d;
write(d, mLayerIdx);
write(d, mNumHeads); write(d, mNumHeads);
write(d, mNumKVHeads); write(d, mNumKVHeads);
write(d, mHeadSize); write(d, mHeadSize);
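getCommonSerializationSize and serializeCommon above rely on a simple convention: every POD member is written with write(d, field) in a fixed order (mLayerIdx now first) and read back with read(d, field) in the same order, each advancing the cursor by sizeof(field). The sketch below reimplements that pattern with hypothetical helper names to make the convention explicit; it is not the plugin's actual code.

#include <cstdio>
#include <cstring>

// Illustrative reimplementation of the write/read cursor pattern.
template <typename T>
void writeField(char*& d, T const& value)
{
    std::memcpy(d, &value, sizeof(T));
    d += sizeof(T);
}

template <typename T>
void readField(char const*& d, T& value)
{
    std::memcpy(&value, d, sizeof(T));
    d += sizeof(T);
}

int main()
{
    char buffer[sizeof(int) * 2];

    // Serialize: the layer index is written first, matching the updated size computation.
    char* w = buffer;
    int layerIdx = 7, numHeads = 32;
    writeField(w, layerIdx);
    writeField(w, numHeads);

    // Deserialize in exactly the same order.
    char const* r = buffer;
    int layerIdxOut = 0, numHeadsOut = 0;
    readField(r, layerIdxOut);
    readField(r, numHeadsOut);
    std::printf("layer_idx=%d num_heads=%d\n", layerIdxOut, numHeadsOut);
    return 0;
}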

View File

@ -36,8 +36,8 @@ class GPTAttentionPluginCommon : public BasePlugin
public: public:
GPTAttentionPluginCommon() = delete; GPTAttentionPluginCommon() = delete;
GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int head_size, int unidirectional, float q_scaling, GPTAttentionPluginCommon(int layer_idx, int num_heads, int num_kv_heads, int head_size, int unidirectional,
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type, float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type, float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
@ -213,6 +213,7 @@ protected:
const std::string mLayerName; const std::string mLayerName;
int mLayerIdx;
int mNumHeads; int mNumHeads;
int mNumKVHeads; int mNumKVHeads;
int mHeadSize; int mHeadSize;

View File

@ -37,8 +37,8 @@ using tensorrt_llm::plugins::GPTAttentionPlugin;
static const char* GPT_ATTENTION_PLUGIN_VERSION{"1"}; static const char* GPT_ATTENTION_PLUGIN_VERSION{"1"};
static const char* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"}; static const char* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"};
GPTAttentionPlugin::GPTAttentionPlugin(int num_heads, int num_kv_heads, int head_size, int unidirectional, GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int num_kv_heads, int head_size,
float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type, int unidirectional, float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type, float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
@ -48,12 +48,12 @@ GPTAttentionPlugin::GPTAttentionPlugin(int num_heads, int num_kv_heads, int head
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha, bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha,
bool use_paged_context_fmha, bool use_cache, bool is_medusa_enabled) bool use_paged_context_fmha, bool use_cache, bool is_medusa_enabled)
: GPTAttentionPluginCommon(num_heads, num_kv_heads, head_size, unidirectional, q_scaling, position_embedding_type, : GPTAttentionPluginCommon(layer_idx, num_heads, num_kv_heads, head_size, unidirectional, q_scaling,
rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale, position_embedding_type, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type,
rotary_embedding_max_positions, tp_size, tp_rank, unfuse_qkv_gemm, context_fmha_type, multi_block_mode, rotary_embedding_scale, rotary_embedding_max_positions, tp_size, tp_rank, unfuse_qkv_gemm, context_fmha_type,
enable_xqa, kv_cache_quant_mode, remove_input_padding, mask_type, paged_kv_cache, tokens_per_block, type, multi_block_mode, enable_xqa, kv_cache_quant_mode, remove_input_padding, mask_type, paged_kv_cache,
max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled, dense_context_fmha, tokens_per_block, type, max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled,
use_paged_context_fmha, use_cache, is_medusa_enabled) dense_context_fmha, use_paged_context_fmha, use_cache, is_medusa_enabled)
{ {
initEntryIdx(); initEntryIdx();
} }
@ -485,7 +485,7 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
// Note that this cyclic_attention_window_size might be smaller than the actual kv cache capacity. // Note that this cyclic_attention_window_size might be smaller than the actual kv cache capacity.
const int cyclic_attention_window_size = isCrossAttention() const int cyclic_attention_window_size = isCrossAttention()
? max_encoder_context_len ? max_encoder_context_len
: reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW)])[0]; : reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW)])[mLayerIdx];
const int sink_token_length = reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_SINK_TOKEN_LENGTH)])[0]; const int sink_token_length = reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_SINK_TOKEN_LENGTH)])[0];
const float* kv_scale_orig_quant = nullptr; const float* kv_scale_orig_quant = nullptr;
@ -503,14 +503,18 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
void* host_block_pointers = nullptr; void* host_block_pointers = nullptr;
if (useKVCache() && mPagedKVCache) if (useKVCache() && mPagedKVCache)
{ {
auto& kvCacheBlockPointers = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)]; auto const& kvCacheBlockPointers = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)];
auto& kvCacheBlockPointersShape = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)].dims; auto const& kvCacheBlockPointersShape = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)].dims;
max_blocks_per_sequence = kvCacheBlockPointersShape.d[kvCacheBlockPointersShape.nbDims - 1]; max_blocks_per_sequence = kvCacheBlockPointersShape.d[kvCacheBlockPointersShape.nbDims - 1];
auto offset = getStride(kvCacheBlockPointersShape, 0) * seqIdxBeg; auto const batchSize = kvCacheBlockPointersShape.d[1];
auto const typed_block_pointers auto const seqStride = getStride(kvCacheBlockPointersShape, 1);
auto const layerOffset = mLayerIdx * batchSize * seqStride;
auto const seqOffset = seqIdxBeg * seqStride;
auto const offset = layerOffset + seqOffset;
auto const* const typed_block_pointers
= static_cast<void* const*>(inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)]) + offset; = static_cast<void* const*>(inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)]) + offset;
block_pointers = const_cast<void*>(static_cast<void const*>(typed_block_pointers)); block_pointers = const_cast<void*>(static_cast<void const*>(typed_block_pointers));
auto const typed_host_block_pointers auto const* const typed_host_block_pointers
= static_cast<void* const*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_BLOCK_POINTERS)]) + offset; = static_cast<void* const*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_BLOCK_POINTERS)]) + offset;
host_block_pointers = const_cast<void*>(static_cast<void const*>(typed_host_block_pointers)); host_block_pointers = const_cast<void*>(static_cast<void const*>(typed_host_block_pointers));
} }
@ -762,9 +766,10 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(const char* name, const Plugi
try try
{ {
auto* obj = new GPTAttentionPlugin(p.getScalar<int32_t>("num_heads").value(), auto* obj = new GPTAttentionPlugin(p.getScalar<int32_t>("layer_idx").value(),
p.getScalar<int32_t>("num_kv_heads").value(), p.getScalar<int32_t>("head_size").value(), p.getScalar<int32_t>("num_heads").value(), p.getScalar<int32_t>("num_kv_heads").value(),
p.getScalar<int32_t>("unidirectional").value(), p.getScalar<float>("q_scaling").value(), p.getScalar<int32_t>("head_size").value(), p.getScalar<int32_t>("unidirectional").value(),
p.getScalar<float>("q_scaling").value(),
static_cast<PositionEmbeddingType>(p.getScalar<int8_t>("position_embedding_type").value()), static_cast<PositionEmbeddingType>(p.getScalar<int8_t>("position_embedding_type").value()),
p.getScalar<int32_t>("rotary_embedding_dim").value(), p.getScalar<float>("rotary_embedding_base").value(), p.getScalar<int32_t>("rotary_embedding_dim").value(), p.getScalar<float>("rotary_embedding_base").value(),
static_cast<RotaryScalingType>(p.getScalar<int8_t>("rotary_embedding_scale_type").value()), static_cast<RotaryScalingType>(p.getScalar<int8_t>("rotary_embedding_scale_type").value()),
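The new offset computation in enqueueSome follows from the paged-KV-cache pointer layout documented in gptAttentionPlugin.h, [num_layers, batch_size, 2, max_blocks_per_seq]: each plugin instance first skips to its own layer slice and then to the first sequence it handles. The worked sketch below uses hypothetical sizes to show the arithmetic; it is not the plugin code itself.

#include <cassert>
#include <cstdio>

// Worked example of the block_pointers indexing, assuming the layout
// [num_layers, batch_size, 2, max_blocks_per_seq] from the header comment.
int main()
{
    const int numLayers = 32;        // hypothetical
    const int batchSize = 8;         // hypothetical
    const int maxBlocksPerSeq = 16;  // hypothetical

    // Stride of one sequence entry: the {K, V} pair times the blocks per sequence.
    const int seqStride = 2 * maxBlocksPerSeq;

    const int layerIdx = 5;   // this plugin instance's mLayerIdx
    const int seqIdxBeg = 3;  // first sequence handled by this enqueueSome call

    const int layerOffset = layerIdx * batchSize * seqStride;
    const int seqOffset = seqIdxBeg * seqStride;
    const int offset = layerOffset + seqOffset;  // element offset into the flattened pointer array

    assert(offset < numLayers * batchSize * seqStride);
    std::printf("seqStride=%d layerOffset=%d seqOffset=%d offset=%d\n",
                seqStride, layerOffset, seqOffset, offset);
    return 0;
}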

View File

@ -46,7 +46,7 @@ namespace tensorrt_llm::plugins
// enable_remove_input_padding // enable_remove_input_padding
// 1. sequence_length [batch_size] (optional) // 1. sequence_length [batch_size] (optional)
// 2. host_past_key_value_lengths [batch_size] (int32) (optional) // 2. host_past_key_value_lengths [batch_size] (int32) (optional)
// 3. host_max_attention_window_sizes [1] (int32) // 3. host_max_attention_window_sizes [num_layers] (int32)
// 4. host_sink_token_length [1] (int32) // 4. host_sink_token_length [1] (int32)
// 5. context_lengths [batch_size] // 5. context_lengths [batch_size]
// 6. cache_indir [num_gen_requests, beam_width, memory_max_len] (required in beamsearch) (optional) // 6. cache_indir [num_gen_requests, beam_width, memory_max_len] (required in beamsearch) (optional)
@ -54,8 +54,8 @@ namespace tensorrt_llm::plugins
// mode, // mode,
// all elements must be identical. // all elements must be identical.
// 8. past_key_value_pool [batch_size, 2, local_num_kv_heads, max_seq_len, head_size] or // 8. past_key_value_pool [batch_size, 2, local_num_kv_heads, max_seq_len, head_size] or
// block_pointers [batch_size, 2, max_blocks_per_seq] if paged kv cache (optional) // block_pointers [num_layers, batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
// 8.1 host_block_pointers [batch_size, 2, max_blocks_per_seq] if paged kv cache (optional) // 8.1 host_block_pointers [num_layers, batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
// 9. kv_cache_quantization_scale [1] (optional) // 9. kv_cache_quantization_scale [1] (optional)
// 10. kv_cache_dequantization_scale [1] (optional) // 10. kv_cache_dequantization_scale [1] (optional)
// 11. alibi_slopes [num_heads] (optional for ALiBi position embedding) // 11. alibi_slopes [num_heads] (optional for ALiBi position embedding)
@ -70,8 +70,8 @@ namespace tensorrt_llm::plugins
class GPTAttentionPlugin : public GPTAttentionPluginCommon class GPTAttentionPlugin : public GPTAttentionPluginCommon
{ {
public: public:
GPTAttentionPlugin(int num_heads, int num_kv_heads, int head_size, int unidirectional, float q_scaling, GPTAttentionPlugin(int layer_idx, int num_heads, int num_kv_heads, int head_size, int unidirectional,
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type, float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type, float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi

View File

@ -301,7 +301,7 @@ void LoraPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, in
} }
mGemmId.n = N; mGemmId.n = N;
mGemmId.k = K; mGemmId.k = K;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
int64_t getLowRankWorkSpaceSize( int64_t getLowRankWorkSpaceSize(
@ -344,7 +344,7 @@ size_t LoraPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, in
void runCublasGemmEx(const int M, const int N, const int K, const bool transA, const bool transB, const void* act, void runCublasGemmEx(const int M, const int N, const int K, const bool transA, const bool transB, const void* act,
const void* weight, void* output, cublasHandle_t cublas_handle) const void* weight, void* output, cublasHandle_t cublas_handle)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
float a = 1.0f; float a = 1.0f;
float b = 0.0f; float b = 0.0f;
void* alpha = &a; void* alpha = &a;
@ -356,13 +356,13 @@ void runCublasGemmEx(const int M, const int N, const int K, const bool transA, c
tensorrt_llm::common::check_cuda_error(cublasGemmEx(cublas_handle, transa, transb, m, n, k, alpha, weight, tensorrt_llm::common::check_cuda_error(cublasGemmEx(cublas_handle, transa, transb, m, n, k, alpha, weight,
CUDA_R_16F, lda, act, CUDA_R_16F, ldb, beta, output, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); CUDA_R_16F, lda, act, CUDA_R_16F, ldb, beta, output, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// inputs // inputs
// input [-1, K] (view as 2D) // input [-1, K] (view as 2D)
// host_request_type [batch_size] on cpu // host_request_type [batch_size] on cpu
@ -565,7 +565,7 @@ int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf
} }
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return 0; return 0;
} }

View File

@ -19,16 +19,14 @@
#include "namedTensor.h" #include "namedTensor.h"
#include "tensorrt_llm/batch_manager/GptManager.h" #include "tensorrt_llm/batch_manager/GptManager.h"
#include "tensorrt_llm/batch_manager/callbacks.h" #include "tensorrt_llm/batch_manager/callbacks.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/pybind/utils/pathCaster.h" #include "tensorrt_llm/pybind/utils/pathCaster.h"
#include <pybind11/functional.h> #include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/ops/tensor.h>
#include <memory> #include <memory>
#include <optional> #include <optional>
@ -38,17 +36,19 @@ namespace tensorrt_llm::pybind::batch_manager
{ {
GptManager::GptManager(std::filesystem::path const& trtEnginePath, tb::TrtGptModelType modelType, int32_t maxBeamWidth, GptManager::GptManager(std::filesystem::path const& trtEnginePath, tb::TrtGptModelType modelType, int32_t maxBeamWidth,
tb::batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb, tb::batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback const& getInferenceRequestsCb,
SendResponseCallback sendResponseCb, tb::PollStopSignalCallback pollStopSignalCb, SendResponseCallback const& sendResponseCb, const tb::PollStopSignalCallback& pollStopSignalCb,
tb::ReturnBatchManagerStatsCallback returnBatchManagerStatsCb, const tb::TrtGptModelOptionalParams& optionalParams, tb::ReturnBatchManagerStatsCallback const& returnBatchManagerStatsCb,
std::optional<uint64_t> terminateReqId) tb::TrtGptModelOptionalParams const& optionalParams, std::optional<uint64_t> terminateReqId)
: tb::GptManager(trtEnginePath, modelType, maxBeamWidth, schedulerPolicy, callbackAdapter(getInferenceRequestsCb),
callbackAdapter(sendResponseCb), pollStopSignalCb, returnBatchManagerStatsCb, optionalParams, terminateReqId)
{ {
mManager = std::make_unique<tb::GptManager>(trtEnginePath, modelType, maxBeamWidth, schedulerPolicy,
callbackAdapter(getInferenceRequestsCb), callbackAdapter(sendResponseCb), pollStopSignalCb,
returnBatchManagerStatsCb, optionalParams, terminateReqId);
} }
py::object GptManager::enter() py::object GptManager::enter()
{ {
TLLM_CHECK(static_cast<bool>(mManager));
return py::cast(this); return py::cast(this);
} }
@ -62,11 +62,14 @@ void GptManager::shutdown()
// NOTE: we must release the GIL here. GptManager has spawned a thread for the execution loop. That thread must be // NOTE: we must release the GIL here. GptManager has spawned a thread for the execution loop. That thread must be
// able to make forward progress for the shutdown process to succeed. It takes the GIL during its callbacks, so // able to make forward progress for the shutdown process to succeed. It takes the GIL during its callbacks, so
// we release it now. Note that we shouldn't do anything related to python objects after that. // we release it now. Note that we shouldn't do anything related to python objects after that.
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
py::gil_scoped_release release; py::gil_scoped_release release;
tb::GptManager::shutdown(); mManager->shutdown();
mManager = nullptr;
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
} }
tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback callback) tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback const& callback)
{ {
return [callback](int32_t max_sequences) return [callback](int32_t max_sequences)
{ {
@ -82,14 +85,14 @@ tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback ca
}; };
} }
tb::SendResponseCallback callbackAdapter(SendResponseCallback callback) tb::SendResponseCallback callbackAdapter(SendResponseCallback const& callback)
{ {
return [callback](uint64_t id, std::list<tb::NamedTensor> const& cppTensors, bool isOk, const std::string& errMsg) return [callback](uint64_t id, std::list<tb::NamedTensor> const& cppTensors, bool isOk, const std::string& errMsg)
{ {
std::list<NamedTensor> pythonList{}; std::list<NamedTensor> pythonList{};
for (const auto& cppNamedTensor : cppTensors) for (const auto& cppNamedTensor : cppTensors)
{ {
pythonList.push_back(NamedTensor{cppNamedTensor}); pythonList.emplace_back(cppNamedTensor);
} }
callback(id, pythonList, isOk, errMsg); callback(id, pythonList, isOk, errMsg);
}; };
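callbackAdapter above follows a small, reusable pattern: capture the Python-facing callback by value and convert arguments at the C++/Python boundary inside the wrapping lambda. The self-contained sketch below mirrors that pattern with stand-in tensor types (CppNamedTensor and PyNamedTensor are illustrative, not the real classes).

#include <cstdint>
#include <cstdio>
#include <functional>
#include <list>
#include <string>

// Stand-in types for the two sides of the boundary.
struct CppNamedTensor { std::string name; };
struct PyNamedTensor  { std::string name; explicit PyNamedTensor(CppNamedTensor const& t) : name(t.name) {} };

using PySendResponse  = std::function<void(uint64_t, std::list<PyNamedTensor> const&, bool, std::string const&)>;
using CppSendResponse = std::function<void(uint64_t, std::list<CppNamedTensor> const&, bool, std::string const&)>;

CppSendResponse adapt(PySendResponse const& callback)
{
    // Capture by value so the adapter outlives the caller's copy of the callback.
    return [callback](uint64_t id, std::list<CppNamedTensor> const& cppTensors, bool isOk, std::string const& errMsg)
    {
        std::list<PyNamedTensor> pyTensors;
        for (auto const& t : cppTensors)
        {
            pyTensors.emplace_back(t); // convert each tensor at the boundary
        }
        callback(id, pyTensors, isOk, errMsg);
    };
}

int main()
{
    auto adapted = adapt([](uint64_t id, std::list<PyNamedTensor> const& ts, bool ok, std::string const&)
        { std::printf("request %llu: %zu tensors, ok=%d\n", static_cast<unsigned long long>(id), ts.size(), ok); });
    adapted(42, {CppNamedTensor{"output_ids"}}, true, "");
    return 0;
}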

View File

@ -24,6 +24,7 @@
#include <ATen/ops/tensor.h> #include <ATen/ops/tensor.h>
#include <functional> #include <functional>
#include <memory>
namespace tensorrt_llm::pybind::batch_manager namespace tensorrt_llm::pybind::batch_manager
{ {
@ -31,18 +32,18 @@ namespace tensorrt_llm::pybind::batch_manager
using GetInferenceRequestsCallback = std::function<std::list<InferenceRequest>(int32_t)>; using GetInferenceRequestsCallback = std::function<std::list<InferenceRequest>(int32_t)>;
using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, const std::string&)>; using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, const std::string&)>;
tensorrt_llm::batch_manager::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback callback); tensorrt_llm::batch_manager::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback const& callback);
tensorrt_llm::batch_manager::SendResponseCallback callbackAdapter(SendResponseCallback callback); tensorrt_llm::batch_manager::SendResponseCallback callbackAdapter(SendResponseCallback const& callback);
class GptManager : tensorrt_llm::batch_manager::GptManager class GptManager
{ {
public: public:
GptManager(std::filesystem::path const& trtEnginePath, tensorrt_llm::batch_manager::TrtGptModelType modelType, GptManager(std::filesystem::path const& trtEnginePath, tensorrt_llm::batch_manager::TrtGptModelType modelType,
int32_t maxBeamWidth, tensorrt_llm::batch_manager::batch_scheduler::SchedulerPolicy schedulerPolicy, int32_t maxBeamWidth, tensorrt_llm::batch_manager::batch_scheduler::SchedulerPolicy schedulerPolicy,
GetInferenceRequestsCallback getInferenceRequestsCb, SendResponseCallback sendResponseCb, GetInferenceRequestsCallback const& getInferenceRequestsCb, SendResponseCallback const& sendResponseCb,
tensorrt_llm::batch_manager::PollStopSignalCallback pollStopSignalCb = nullptr, tensorrt_llm::batch_manager::PollStopSignalCallback const& pollStopSignalCb = nullptr,
tensorrt_llm::batch_manager::ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr, tensorrt_llm::batch_manager::ReturnBatchManagerStatsCallback const& returnBatchManagerStatsCb = nullptr,
const tensorrt_llm::batch_manager::TrtGptModelOptionalParams& optionalParams tensorrt_llm::batch_manager::TrtGptModelOptionalParams const& optionalParams
= tensorrt_llm::batch_manager::TrtGptModelOptionalParams(), = tensorrt_llm::batch_manager::TrtGptModelOptionalParams(),
std::optional<uint64_t> terminateReqId = std::nullopt); std::optional<uint64_t> terminateReqId = std::nullopt);
@ -51,6 +52,9 @@ public:
void shutdown(); void shutdown();
static void initBindings(pybind11::module_& m); static void initBindings(pybind11::module_& m);
private:
std::unique_ptr<tensorrt_llm::batch_manager::GptManager> mManager;
}; };
} // namespace tensorrt_llm::pybind::batch_manager } // namespace tensorrt_llm::pybind::batch_manager

View File

@ -17,6 +17,7 @@
#include "inferenceRequest.h" #include "inferenceRequest.h"
#include "tensorrt_llm/batch_manager/inferenceRequest.h" #include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/runtime/torch.h" #include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchView.h" #include "tensorrt_llm/runtime/torchView.h"
#include <memory> #include <memory>
@ -25,6 +26,15 @@
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <torch/extension.h> #include <torch/extension.h>
#ifdef _WIN32
// FIXME: THPStream_Wrap seems not to be present in libtorch_python.so on Windows
PyObject* THPStream_Wrap(const c10::Stream& stream)
{
TLLM_THROW("Stream conversion in not yet supported on Windows.");
return nullptr;
}
#endif
namespace tb = tensorrt_llm::batch_manager; namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime; namespace tr = tensorrt_llm::runtime;
@ -32,6 +42,7 @@ using namespace tensorrt_llm::pybind::batch_manager;
namespace namespace
{ {
std::shared_ptr<InferenceRequest> fromTrtLlm(tb::InferenceRequest const& request) std::shared_ptr<InferenceRequest> fromTrtLlm(tb::InferenceRequest const& request)
{ {
InferenceRequest::TensorMap tensorMap; InferenceRequest::TensorMap tensorMap;
@ -53,18 +64,24 @@ std::shared_ptr<tb::InferenceRequest> InferenceRequest::toTrtLlm() const
tb::InferenceRequest::TensorMap tensorMap; tb::InferenceRequest::TensorMap tensorMap;
for (auto const& [name, tensor] : mInputTensors) for (auto const& [name, tensor] : mInputTensors)
{ {
if (tensor.has_value()) tensorMap[name] = tr::TorchView::of(tensor);
{
tensorMap[name] = tr::TorchView::of(tensor.value());
}
} }
auto inferenceRequest = std::make_shared<tb::InferenceRequest>(std::move(tensorMap), mRequestId); auto inferenceRequest = std::make_shared<tb::InferenceRequest>(std::move(tensorMap), mRequestId);
inferenceRequest->setIsStreaming(isStreaming()); inferenceRequest->setIsStreaming(isStreaming());
if (mlogitsPostProcessor)
{
inferenceRequest->setLogitsPostProcessor(LlmRequest::callbackAdapter(mlogitsPostProcessor));
}
return inferenceRequest; return inferenceRequest;
} }
std::string InferenceRequest::serialize() const std::string InferenceRequest::serialize() const
{ {
TLLM_CHECK_WITH_INFO(mlogitsPostProcessor == std::nullopt,
"Serializing InferenceRequest with logitsPostProcessor set is not supported."
"Please set the callback after de-serialization");
std::vector<std::int64_t> serialized{toTrtLlm()->serialize()}; std::vector<std::int64_t> serialized{toTrtLlm()->serialize()};
static_assert(sizeof(decltype(serialized)::value_type) / sizeof(char) == 8); static_assert(sizeof(decltype(serialized)::value_type) / sizeof(char) == 8);
return {reinterpret_cast<char const*>(serialized.data()), serialized.size() * 8}; return {reinterpret_cast<char const*>(serialized.data()), serialized.size() * 8};
@ -81,8 +98,12 @@ std::shared_ptr<InferenceRequest> InferenceRequest::deserialize(std::string cons
void InferenceRequest::initBindings(py::module_& m) void InferenceRequest::initBindings(py::module_& m)
{ {
py::class_<InferenceRequest>(m, "InferenceRequest") py::class_<InferenceRequest>(m, "InferenceRequest")
.def(py::init<uint64_t>()) .def(py::init<uint64_t, std::optional<LogitsProcessorCallback>>(), py::arg("request_id"),
py::arg("logits_post_processor_callback") = py::none())
.def(py::init<uint64_t, InferenceRequest::TensorMap const&>(), "deprecated: use direct tensor access instead") .def(py::init<uint64_t, InferenceRequest::TensorMap const&>(), "deprecated: use direct tensor access instead")
.def_property("logits_post_processor",
nullptr, // passing the logits processor in the cpp->python direction doesn't work, so the getter is left undefined
&InferenceRequest::setLogitsPostProcessor)
.def_property("input_ids", &InferenceRequest::getInputIdsUnchecked, &InferenceRequest::setInputIds) .def_property("input_ids", &InferenceRequest::getInputIdsUnchecked, &InferenceRequest::setInputIds)
.def_property( .def_property(
"draft_input_ids", &InferenceRequest::getDraftInputIdsUnchecked, &InferenceRequest::setDraftInputIds) "draft_input_ids", &InferenceRequest::getDraftInputIdsUnchecked, &InferenceRequest::setDraftInputIds)
@ -101,6 +122,8 @@ void InferenceRequest::initBindings(py::module_& m)
.def_property("runtime_top_p", &InferenceRequest::getRuntimeTopPUnchecked, &InferenceRequest::setRuntimeTopP) .def_property("runtime_top_p", &InferenceRequest::getRuntimeTopPUnchecked, &InferenceRequest::setRuntimeTopP)
.def_property( .def_property(
"length_penalty", &InferenceRequest::getLengthPenaltyUnchecked, &InferenceRequest::setLengthPenalty) "length_penalty", &InferenceRequest::getLengthPenaltyUnchecked, &InferenceRequest::setLengthPenalty)
.def_property(
"early_stopping", &InferenceRequest::getEarlyStoppingUnchecked, &InferenceRequest::setEarlyStopping)
.def_property("repetition_penalty", &InferenceRequest::getRepetitionPenaltyUnchecked, .def_property("repetition_penalty", &InferenceRequest::getRepetitionPenaltyUnchecked,
&InferenceRequest::setRepetitionPenalty) &InferenceRequest::setRepetitionPenalty)
.def_property("min_length", &InferenceRequest::getMinLengthUnchecked, &InferenceRequest::setMinLength) .def_property("min_length", &InferenceRequest::getMinLengthUnchecked, &InferenceRequest::setMinLength)

View File

@ -18,6 +18,7 @@
#pragma once #pragma once
#include "tensorrt_llm/batch_manager/inferenceRequest.h" #include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/batch_manager/namedTensor.h" #include "tensorrt_llm/pybind/batch_manager/namedTensor.h"
#include <ATen/ATen.h> #include <ATen/ATen.h>
@ -30,25 +31,28 @@ namespace tensorrt_llm::pybind::batch_manager
{ {
class InferenceRequest class InferenceRequest
: public tensorrt_llm::batch_manager::GenericInferenceRequest<std::optional<at::Tensor>, NamedTensor> : public tensorrt_llm::batch_manager::GenericInferenceRequest<at::Tensor, NamedTensor, c10::Stream>
{ {
public: public:
using Base = tensorrt_llm::batch_manager::GenericInferenceRequest<std::optional<at::Tensor>, NamedTensor>; using Base = tensorrt_llm::batch_manager::GenericInferenceRequest<at::Tensor, NamedTensor, c10::Stream>;
using TensorPtr = Base::TensorPtr; using TensorPtr = Base::TensorPtr;
using TensorMap = Base::TensorMap; using TensorMap = Base::TensorMap;
using LogitsProcessorCallback = Base::LogitsPostProcessor;
InferenceRequest(uint64_t requestId) InferenceRequest(uint64_t requestId, std::optional<LogitsProcessorCallback> logitsCb = std::nullopt)
: Base(requestId) : Base(requestId, logitsCb)
{ {
} }
InferenceRequest(uint64_t requestId, TensorMap const& inputTensors) InferenceRequest(uint64_t requestId, TensorMap const& inputTensors,
: Base{requestId, inputTensors} std::optional<LogitsProcessorCallback> logitsCb = std::nullopt)
: Base{requestId, inputTensors, logitsCb}
{ {
} }
InferenceRequest(uint64_t requestId, TensorMap&& inputTensors) InferenceRequest(
: Base{requestId, std::move(inputTensors)} uint64_t requestId, TensorMap&& inputTensors, std::optional<LogitsProcessorCallback> logitsCb = std::nullopt)
: Base{requestId, std::move(inputTensors), logitsCb}
{ {
} }

View File

@ -17,7 +17,10 @@
#include "llmRequest.h" #include "llmRequest.h"
#include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/generationInput.h" #include "tensorrt_llm/runtime/generationInput.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/torchView.h" #include "tensorrt_llm/runtime/torchView.h"
#include <memory> #include <memory>
@ -45,6 +48,25 @@ std::optional<tb::LlmRequest::TensorPtr> from_torch(std::optional<LlmRequest::Te
} // namespace } // namespace
std::optional<tb::LlmRequest::LogitsPostProcessor> LlmRequest::callbackAdapter(
std::optional<LlmRequest::LogitsPostProcessor> callback)
{
if (!callback)
{
return std::nullopt;
}
return [callback](RequestIdType reqId, tensorrt_llm::runtime::ITensor::SharedPtr& tensor,
tensorrt_llm::batch_manager::LlmRequest::BeamTokens const& tokens,
tensorrt_llm::runtime::BufferManager::CudaStreamPtr stream)
{
at::Tensor atTensor = tr::Torch::tensor(tensor);
auto result = callback.value()(reqId, atTensor, tokens, runtime::TorchUtils::stream(*stream).unwrap());
return tr::TorchView::of(result);
};
}
std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
{ {
auto embeddingBias = from_torch(mEmbeddingBias); auto embeddingBias = from_torch(mEmbeddingBias);
@ -58,7 +80,8 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
return std::make_shared<tb::LlmRequest>(mRequestId, mMaxNewTokens, return std::make_shared<tb::LlmRequest>(mRequestId, mMaxNewTokens,
std::make_shared<std::vector<TokenIdType>>(mTokens.at(0)), mSamplingConfig, mIsStreaming, mEndId, mPadId, std::make_shared<std::vector<TokenIdType>>(mTokens.at(0)), mSamplingConfig, mIsStreaming, mEndId, mPadId,
embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, mPromptVocabSize, loraWeights, loraConfig, embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, mPromptVocabSize, loraWeights, loraConfig,
mReturnLogProbs, mReturnContextLogits, mReturnGenerationLogits, mDraftTokens, draftLogits); mReturnLogProbs, mReturnContextLogits, mReturnGenerationLogits, mDraftTokens, draftLogits,
mExcludeInputFromOutput, callbackAdapter(mLogitsPostProcessor));
} }
void LlmRequest::initBindings(py::module_& m) void LlmRequest::initBindings(py::module_& m)
@ -70,7 +93,7 @@ void LlmRequest::initBindings(py::module_& m)
std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
std::optional<LlmRequest::SizeType>, std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::SizeType>, std::optional<LlmRequest::TensorPtr>,
std::optional<LlmRequest::TensorPtr>, bool, bool, bool, std::optional<LlmRequest::VecTokens>, std::optional<LlmRequest::TensorPtr>, bool, bool, bool, std::optional<LlmRequest::VecTokens>,
std::optional<LlmRequest::TensorPtr>>(), std::optional<LlmRequest::TensorPtr>, bool, std::optional<LlmRequest::LogitsPostProcessor>>(),
py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"), py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt, py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt, py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt,
@ -78,7 +101,8 @@ void LlmRequest::initBindings(py::module_& m)
py::arg("prompt_vocab_size") = std::nullopt, py::arg("lora_weights") = std::nullopt, py::arg("prompt_vocab_size") = std::nullopt, py::arg("lora_weights") = std::nullopt,
py::arg("lora_config") = std::nullopt, py::arg("return_log_probs") = false, py::arg("lora_config") = std::nullopt, py::arg("return_log_probs") = false,
py::arg("return_context_logits") = false, py::arg("return_generation_logits") = false, py::arg("return_context_logits") = false, py::arg("return_generation_logits") = false,
py::arg("draft_tokens") = std::nullopt, py::arg("draft_logits") = std::nullopt) py::arg("draft_tokens") = std::nullopt, py::arg("draft_logits") = std::nullopt,
py::arg("exclude_input_from_output") = false, py::arg("logits_post_processor") = std::nullopt)
.def("get_num_tokens", &LlmRequest::getNumTokens, py::arg("beam")) .def("get_num_tokens", &LlmRequest::getNumTokens, py::arg("beam"))
.def_property_readonly("max_beam_num_tokens", &LlmRequest::getMaxBeamNumTokens) .def_property_readonly("max_beam_num_tokens", &LlmRequest::getMaxBeamNumTokens)
.def("get_token", &LlmRequest::getToken, py::arg("beam"), py::arg("pos")) .def("get_token", &LlmRequest::getToken, py::arg("beam"), py::arg("pos"))

View File

@ -28,10 +28,16 @@
namespace tensorrt_llm::pybind::batch_manager namespace tensorrt_llm::pybind::batch_manager
{ {
class LlmRequest : public tensorrt_llm::batch_manager::GenericLlmRequest<at::Tensor> namespace tb = tensorrt_llm::batch_manager;
/* Unfortunately, torch's default pybind bindings don't know about c10::cuda::CUDAStream,
* so we have to pass the more generic c10::Stream, and convert it back to a full-fledged
 * torch.cuda.Stream in Python. See the example in test/bindings/test_gpt_manager.py
*/
class LlmRequest : public tb::GenericLlmRequest<at::Tensor, c10::Stream>
{ {
public: public:
using Base = GenericLlmRequest<at::Tensor>; using Base = GenericLlmRequest<at::Tensor, c10::Stream>;
using TensorPtr = Base::TensorPtr; using TensorPtr = Base::TensorPtr;
using SizeType = Base::SizeType; using SizeType = Base::SizeType;
using TokenIdType = Base::TokenIdType; using TokenIdType = Base::TokenIdType;
@ -39,6 +45,7 @@ public:
using VecLogProbs = Base::VecLogProbs; using VecLogProbs = Base::VecLogProbs;
using BeamTokens = Base::BeamTokens; using BeamTokens = Base::BeamTokens;
using VecTokens = Base::VecTokens; using VecTokens = Base::VecTokens;
using LogitsPostProcessor = Base::LogitsPostProcessor;
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::vector<TokenIdType> inputTokens, LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::vector<TokenIdType> inputTokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt, runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
@ -48,16 +55,20 @@ public:
std::optional<SizeType> promptVocabSize = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt, std::optional<SizeType> promptVocabSize = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
std::optional<TensorPtr> loraConfig = std::nullopt, bool returnLogProbs = false, std::optional<TensorPtr> loraConfig = std::nullopt, bool returnLogProbs = false,
bool returnContextLogits = false, bool returnGenerationLogits = false, bool returnContextLogits = false, bool returnGenerationLogits = false,
std::optional<VecTokens> draftTokens = std::nullopt, std::optional<TensorPtr> draftLogits = std::nullopt) std::optional<VecTokens> draftTokens = std::nullopt, std::optional<TensorPtr> draftLogits = std::nullopt,
bool excludeInputFromOutput = false, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
: Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)), : Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
samplingConfig, isStreaming, endId, padId, embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, samplingConfig, isStreaming, endId, padId, embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable,
promptVocabSize, loraWeights, loraConfig, returnLogProbs, returnContextLogits, returnGenerationLogits, promptVocabSize, loraWeights, loraConfig, returnLogProbs, returnContextLogits, returnGenerationLogits,
draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value())) draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value()))
: std::make_shared<VecTokens>(), : std::make_shared<VecTokens>(),
draftLogits) draftLogits, excludeInputFromOutput, logitsPostProcessor)
{ {
} }
static std::optional<tb::LlmRequest::LogitsPostProcessor> callbackAdapter(
std::optional<LlmRequest::LogitsPostProcessor> callback);
[[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> toTrtLlm() const; [[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> toTrtLlm() const;
static void initBindings(pybind11::module_& m); static void initBindings(pybind11::module_& m);
}; };

View File

@ -37,6 +37,7 @@
#include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/gptJsonConfig.h" #include "tensorrt_llm/runtime/gptJsonConfig.h"
#include "tensorrt_llm/runtime/gptSession.h" #include "tensorrt_llm/runtime/gptSession.h"
#include "tensorrt_llm/runtime/memoryCounters.h"
#include "tensorrt_llm/runtime/samplingConfig.h" #include "tensorrt_llm/runtime/samplingConfig.h"
namespace py = pybind11; namespace py = pybind11;
@ -232,7 +233,8 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def_readwrite("top_p_min", &tr::SamplingConfig::topPMin) .def_readwrite("top_p_min", &tr::SamplingConfig::topPMin)
.def_readwrite("top_p_reset_ids", &tr::SamplingConfig::topPResetIds) .def_readwrite("top_p_reset_ids", &tr::SamplingConfig::topPResetIds)
.def_readwrite("beam_search_diversity_rate", &tr::SamplingConfig::beamSearchDiversityRate) .def_readwrite("beam_search_diversity_rate", &tr::SamplingConfig::beamSearchDiversityRate)
.def_readwrite("length_penalty", &tr::SamplingConfig::lengthPenalty); .def_readwrite("length_penalty", &tr::SamplingConfig::lengthPenalty)
.def_readwrite("early_stopping", &tr::SamplingConfig::earlyStopping);
py::class_<tr::GptJsonConfig>(m, "GptJsonConfig") py::class_<tr::GptJsonConfig>(m, "GptJsonConfig")
.def(py::init<std::string, std::string, std::string, SizeType, SizeType, tr::GptModelConfig>(), py::arg("name"), .def(py::init<std::string, std::string, std::string, SizeType, SizeType, tr::GptModelConfig>(), py::arg("name"),
@ -302,6 +304,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
tensorNames.attr("RUNTIME_TOP_K") = py::str(tb::inference_request::kRuntimeTopKTensorName); tensorNames.attr("RUNTIME_TOP_K") = py::str(tb::inference_request::kRuntimeTopKTensorName);
tensorNames.attr("RUNTIME_TOP_P") = py::str(tb::inference_request::kRuntimeTopPTensorName); tensorNames.attr("RUNTIME_TOP_P") = py::str(tb::inference_request::kRuntimeTopPTensorName);
tensorNames.attr("LENGTH_PENALTY") = py::str(tb::inference_request::kLengthPenaltyTensorName); tensorNames.attr("LENGTH_PENALTY") = py::str(tb::inference_request::kLengthPenaltyTensorName);
tensorNames.attr("EARLY_STOPPING") = py::str(tb::inference_request::kEarlyStoppingTensorName);
tensorNames.attr("REPETITION_PENALTY") = py::str(tb::inference_request::kRepetitionPenaltyTensorName); tensorNames.attr("REPETITION_PENALTY") = py::str(tb::inference_request::kRepetitionPenaltyTensorName);
tensorNames.attr("MIN_LENGTH") = py::str(tb::inference_request::kMinLengthTensorName); tensorNames.attr("MIN_LENGTH") = py::str(tb::inference_request::kMinLengthTensorName);
tensorNames.attr("PRESENCE_PENALTY") = py::str(tb::inference_request::kPresencePenaltyTensorName); tensorNames.attr("PRESENCE_PENALTY") = py::str(tb::inference_request::kPresencePenaltyTensorName);
@ -342,4 +345,11 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def_readwrite("decoding_mode", &tb::TrtGptModelOptionalParams::decodingMode); .def_readwrite("decoding_mode", &tb::TrtGptModelOptionalParams::decodingMode);
tpb::GptManager::initBindings(m); tpb::GptManager::initBindings(m);
py::class_<tr::MemoryCounters>(m, "MemoryCounters")
.def_static("instance", &tr::MemoryCounters::getInstance, py::return_value_policy::reference)
.def_property_readonly("gpu", &tr::MemoryCounters::getGpu)
.def_property_readonly("cpu", &tr::MemoryCounters::getCpu)
.def_property_readonly("pinned", &tr::MemoryCounters::getPinned)
.def_property_readonly("uvm", &tr::MemoryCounters::getUVM);
} }
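As a quick companion to the binding above, the same counters can be read directly from C++ through the singleton; a minimal sketch follows (the include path is assumed, and the byte counts are simply whatever the runtime has recorded):

// --- illustrative sketch, not part of this diff ---
#include "tensorrt_llm/runtime/memoryCounters.h"

#include <cstdio>

int main()
{
    auto const& counters = tensorrt_llm::runtime::MemoryCounters::getInstance();
    // Mirrors the read-only properties bound above: gpu, cpu, pinned, uvm.
    std::printf("gpu=%llu cpu=%llu pinned=%llu uvm=%llu (bytes)\n",
        static_cast<unsigned long long>(counters.getGpu()), static_cast<unsigned long long>(counters.getCpu()),
        static_cast<unsigned long long>(counters.getPinned()), static_cast<unsigned long long>(counters.getUVM()));
    return 0;
}
// --- end of sketch ---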

View File

@ -19,7 +19,6 @@
#include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/runtime/iBuffer.h" #include "tensorrt_llm/runtime/iBuffer.h"
#include <memory>
#include <string> #include <string>
namespace tensorrt_llm::runtime namespace tensorrt_llm::runtime

View File

@ -16,6 +16,7 @@
#include "tensorrt_llm/runtime/gptDecoder.h" #include "tensorrt_llm/runtime/gptDecoder.h"
#include "tensorrt_llm/common/cudaAllocator.h"
#include "tensorrt_llm/common/tensorConversion.h" #include "tensorrt_llm/common/tensorConversion.h"
#include "tensorrt_llm/kernels/decodingKernels.h" #include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/layers/dynamicDecodeLayer.h" #include "tensorrt_llm/layers/dynamicDecodeLayer.h"
@ -81,6 +82,7 @@ void GptDecoder<T>::setup(SamplingConfig const& samplingConfig, size_t batchSize
setupParams.beam_search_diversity_rate = samplingConfig.beamSearchDiversityRate; setupParams.beam_search_diversity_rate = samplingConfig.beamSearchDiversityRate;
setupParams.length_penalty = samplingConfig.lengthPenalty; setupParams.length_penalty = samplingConfig.lengthPenalty;
setupParams.early_stopping = samplingConfig.earlyStopping;
auto const batchSlotsPtr = batchSlots.has_value() ? bufferCast<SizeType>(*(batchSlots.value())) : nullptr; auto const batchSlotsPtr = batchSlots.has_value() ? bufferCast<SizeType>(*(batchSlots.value())) : nullptr;
mDynamicDecodeLayer->setup(batchSize, samplingConfig.beamWidth, batchSlotsPtr, setupParams); mDynamicDecodeLayer->setup(batchSize, samplingConfig.beamWidth, batchSlotsPtr, setupParams);
@ -174,7 +176,7 @@ template <typename T>
typename tl::DynamicDecodeLayer<T>::OutputParams prepareOutputs( typename tl::DynamicDecodeLayer<T>::OutputParams prepareOutputs(
DecodingOutput& output, DecodingInput::TensorPtr const& inputLengths, DecodingOutput::TensorPtr& logProbsTiled) DecodingOutput& output, DecodingInput::TensorPtr const& inputLengths, DecodingOutput::TensorPtr& logProbsTiled)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
typename tl::DynamicDecodeLayer<T>::OutputParams outputParams(tcc::toTllmTensor(*output.ids)); typename tl::DynamicDecodeLayer<T>::OutputParams outputParams(tcc::toTllmTensor(*output.ids));
outputParams.newTokens = tcc::toTllmTensor(*output.newTokens); outputParams.newTokens = tcc::toTllmTensor(*output.newTokens);
@ -261,7 +263,7 @@ typename tl::DynamicDecodeLayer<T>::OutputParams prepareOutputs(
template <typename T> template <typename T>
bool GptDecoder<T>::forward(DecodingOutput& output, DecodingInput const& input) bool GptDecoder<T>::forward(DecodingOutput& output, DecodingInput const& input)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto forwardParams = prepareInputs<T>(input); auto forwardParams = prepareInputs<T>(input);
auto outputParams = prepareOutputs<T>(output, input.lengths, mLogProbsTiled); auto outputParams = prepareOutputs<T>(output, input.lengths, mLogProbsTiled);
auto const maxBatchSize = input.maxBatchSize; auto const maxBatchSize = input.maxBatchSize;
@ -309,7 +311,7 @@ bool GptDecoder<T>::forward(DecodingOutput& output, DecodingInput const& input)
template <typename T> template <typename T>
void GptDecoder<T>::forwardAsync(DecodingOutput& output, DecodingInput const& input) void GptDecoder<T>::forwardAsync(DecodingOutput& output, DecodingInput const& input)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto forwardParams = prepareInputs<T>(input); auto forwardParams = prepareInputs<T>(input);
auto outputParams = prepareOutputs<T>(output, input.lengths, mLogProbsTiled); auto outputParams = prepareOutputs<T>(output, input.lengths, mLogProbsTiled);
@ -321,7 +323,7 @@ template <typename T>
void GptDecoder<T>::gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput, void GptDecoder<T>::gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput,
DecodingInput const& decodingInput, BufferManager const& manager) DecodingInput const& decodingInput, BufferManager const& manager)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const& finalOutputIdsShape = finalOutputIds.getShape(); auto const& finalOutputIdsShape = finalOutputIds.getShape();
auto const& decodingOutputIdsShape = decodingOutput.ids->getShape(); auto const& decodingOutputIdsShape = decodingOutput.ids->getShape();
auto const batchSize = finalOutputIdsShape.d[0]; auto const batchSize = finalOutputIdsShape.d[0];
@ -355,7 +357,7 @@ void GptDecoder<T>::gatherTree(ITensor& finalOutputIds, DecodingOutput const& de
beamHypotheses.length_penalties beamHypotheses.length_penalties
= nullptr; // TODO (bhsueh) should set length penalties, this should be a gpu tensor When it is set as = nullptr; // TODO (bhsueh) should set length penalties, this should be a gpu tensor When it is set as
// nullptr, the kernel will use default value (1.0f) automatically. // nullptr, the kernel will use default value (1.0f) automatically.
beamHypotheses.early_stoppings = nullptr; // TODO (wili): handle similarly to length_penalties
beamHypotheses.output_ids_tgt = bufferCast<TokenIdType>(*decodingOutput.beamHypotheses.outputIdsTgt); beamHypotheses.output_ids_tgt = bufferCast<TokenIdType>(*decodingOutput.beamHypotheses.outputIdsTgt);
beamHypotheses.sequence_lengths_tgt = bufferCast<SizeType>(*decodingOutput.beamHypotheses.sequenceLengthsTgt); beamHypotheses.sequence_lengths_tgt = bufferCast<SizeType>(*decodingOutput.beamHypotheses.sequenceLengthsTgt);
beamHypotheses.cum_log_probs = bufferCast<float>(*decodingOutput.beamHypotheses.cumLogProbs); beamHypotheses.cum_log_probs = bufferCast<float>(*decodingOutput.beamHypotheses.cumLogProbs);
@ -381,7 +383,7 @@ void GptDecoder<T>::gatherTree(ITensor& finalOutputIds, DecodingOutput const& de
batchSize, stream.get()); batchSize, stream.get());
sync_check_cuda_error(); sync_check_cuda_error();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
namespace tensorrt_llm::runtime namespace tensorrt_llm::runtime
@ -394,7 +396,7 @@ void IGptDecoder::acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor
ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths, ITensor const& finishedVec, ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths, ITensor const& finishedVec,
ITensor& finishedFinal, ITensor& finishedSum, ITensor const& batchSlots, BufferManager::CudaStreamPtr const& stream) ITensor& finishedFinal, ITensor& finishedSum, ITensor const& batchSlots, BufferManager::CudaStreamPtr const& stream)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const finishedVecShape = finishedVec.getShape(); auto const finishedVecShape = finishedVec.getShape();
auto const maxBatchSize = finishedVecShape.d[1]; auto const maxBatchSize = finishedVecShape.d[1];
@ -439,7 +441,7 @@ void IGptDecoder::acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor
sync_check_cuda_error(); sync_check_cuda_error();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void IGptDecoder::acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const& targetLogits, ITensor& draftProbs, void IGptDecoder::acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const& targetLogits, ITensor& draftProbs,
@ -447,7 +449,7 @@ void IGptDecoder::acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const&
SizeType vocabSize, SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold, SizeType vocabSize, SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
curandState_t* curandState, BufferManager::CudaStreamPtr const& stream) curandState_t* curandState, BufferManager::CudaStreamPtr const& stream)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const draftLogitsShape = draftLogits.getShape(); auto const draftLogitsShape = draftLogits.getShape();
auto const maxBatchSize = draftLogitsShape.d[0]; auto const maxBatchSize = draftLogitsShape.d[0];
@ -488,5 +490,5 @@ void IGptDecoder::acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const&
sync_check_cuda_error(); sync_check_cuda_error();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }

View File

@ -22,6 +22,7 @@
#include "tensorrt_llm/runtime/runtimeKernels.h" #include "tensorrt_llm/runtime/runtimeKernels.h"
#include <algorithm> #include <algorithm>
#include <cassert>
#include <memory> #include <memory>
using namespace tensorrt_llm::runtime; using namespace tensorrt_llm::runtime;
@ -33,7 +34,7 @@ namespace
{ {
SamplingConfig extractSamplingConfig(SamplingConfig const& batchSamplingConfig, SizeType batchIdx) SamplingConfig extractSamplingConfig(SamplingConfig const& batchSamplingConfig, SizeType batchIdx)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
SamplingConfig samplingConfig{batchSamplingConfig.beamWidth}; SamplingConfig samplingConfig{batchSamplingConfig.beamWidth};
auto extractOptional = [&batchIdx](auto& single, auto const& batch) auto extractOptional = [&batchIdx](auto& single, auto const& batch)
@ -64,8 +65,9 @@ SamplingConfig extractSamplingConfig(SamplingConfig const& batchSamplingConfig,
// beam search layer // beam search layer
samplingConfig.beamSearchDiversityRate = batchSamplingConfig.beamSearchDiversityRate; samplingConfig.beamSearchDiversityRate = batchSamplingConfig.beamSearchDiversityRate;
samplingConfig.lengthPenalty = batchSamplingConfig.lengthPenalty; samplingConfig.lengthPenalty = batchSamplingConfig.lengthPenalty;
samplingConfig.earlyStopping = batchSamplingConfig.earlyStopping;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return samplingConfig; return samplingConfig;
} }
@ -78,7 +80,7 @@ GptDecoderBatch::GptDecoderBatch(
, mStream{std::move(stream)} , mStream{std::move(stream)}
, mBufferManager{mStream} , mBufferManager{mStream}
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value; auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
auto constexpr nvSizeType = TRTDataType<SizeType>::value; auto constexpr nvSizeType = TRTDataType<SizeType>::value;
auto constexpr nvFloatType = TRTDataType<float>::value; auto constexpr nvFloatType = TRTDataType<float>::value;
@ -125,14 +127,14 @@ GptDecoderBatch::GptDecoderBatch(
dInput->badWordsLens = mBufferManager.emptyTensor(MemoryType::kPINNED, TRTDataType<SizeType>::value); dInput->badWordsLens = mBufferManager.emptyTensor(MemoryType::kPINNED, TRTDataType<SizeType>::value);
dInput->embeddingBias = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType); dInput->embeddingBias = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void GptDecoderBatch::setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth, void GptDecoderBatch::setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth,
SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep, SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep,
bool fusedDecoder, nvinfer1::DataType dtype) bool fusedDecoder, nvinfer1::DataType dtype)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(maxBatchSize > 0); TLLM_CHECK(maxBatchSize > 0);
TLLM_CHECK(maxBeamWidth > 0); TLLM_CHECK(maxBeamWidth > 0);
TLLM_CHECK(maxTokensPerStep > 0); TLLM_CHECK(maxTokensPerStep > 0);
@ -253,13 +255,13 @@ void GptDecoderBatch::setup(DecodingMode const& mode, SizeType maxBatchSize, Siz
mBeamWidths[i] = 0; mBeamWidths[i] = 0;
mGeneratedTokensPerStep[i] = 0; mGeneratedTokensPerStep[i] = 0;
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void GptDecoderBatch::newRequest( void GptDecoderBatch::newRequest(
SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig) SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(batchIdx >= 0); TLLM_CHECK(batchIdx >= 0);
auto const& jointOutputIdsShape = mJointDecodingOutput->ids->getShape(); auto const& jointOutputIdsShape = mJointDecodingOutput->ids->getShape();
auto const batchSize = jointOutputIdsShape.d[0]; auto const batchSize = jointOutputIdsShape.d[0];
@ -373,7 +375,7 @@ void GptDecoderBatch::newRequest(
dOutput->newTokensVec.resize(mMaxTokensPerStep); dOutput->newTokensVec.resize(mMaxTokensPerStep);
for (SizeType ti = 0; ti < mMaxTokensPerStep; ++ti) for (SizeType ti = 0; ti < mMaxTokensPerStep; ++ti)
{ {
TensorPtr newTokensStepView = std::move(ITensor::slice(dJointOutput.newTokensSteps, ti, localBatchSize)); TensorPtr newTokensStepView = ITensor::slice(dJointOutput.newTokensSteps, ti, localBatchSize);
newTokensStepView->squeeze(0); newTokensStepView->squeeze(0);
dOutput->newTokensVec[ti] = ITensor::slice(newTokensStepView, batchIdx, localBatchSize); dOutput->newTokensVec[ti] = ITensor::slice(newTokensStepView, batchIdx, localBatchSize);
manager.setZero(*dOutput->newTokensVec[ti]); manager.setZero(*dOutput->newTokensVec[ti]);
@ -469,7 +471,7 @@ void GptDecoderBatch::newRequest(
auto outputIdsView = ITensor::view(outputIds, ITensor::makeShape({beamWidth, mMaxSequenceLength})); auto outputIdsView = ITensor::view(outputIds, ITensor::makeShape({beamWidth, mMaxSequenceLength}));
kernels::invokeFill(*outputIdsView, endId, *stream); kernels::invokeFill(*outputIdsView, endId, *stream);
kernels::tileTensor(*outputIdsView, *inputIdsView, beamWidth, *stream); kernels::tileTensor(*outputIdsView, *inputIdsView, beamWidth, *stream);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void GptDecoderBatch::newRequests(std::vector<SizeType> const& seqSlots, void GptDecoderBatch::newRequests(std::vector<SizeType> const& seqSlots,
@ -500,7 +502,7 @@ void GptDecoderBatch::newRequests(std::vector<SizeType> const& seqSlots,
GptDecoderBatch::TokenPtr GptDecoderBatch::forwardAsync( GptDecoderBatch::TokenPtr GptDecoderBatch::forwardAsync(
decoder_batch::Output& output, decoder_batch::Input const& input) decoder_batch::Output& output, decoder_batch::Input const& input)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& allTargetLogits = input.logits; auto& allTargetLogits = input.logits;
// TODO(nkorobov): check logits shape considering draft tokens // TODO(nkorobov): check logits shape considering draft tokens
@ -737,13 +739,13 @@ GptDecoderBatch::TokenPtr GptDecoderBatch::forwardAsync(
CudaEvent eventStop{}; CudaEvent eventStop{};
mStream->record(eventStop); mStream->record(eventStop);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return std::make_unique<decoder_batch::Token>(std::move(eventStop), input.active); return std::make_unique<decoder_batch::Token>(std::move(eventStop), input.active);
} }
void GptDecoderBatch::forwardSync(decoder_batch::Token const& token) void GptDecoderBatch::forwardSync(decoder_batch::Token const& token)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
token.event.synchronize(); token.event.synchronize();
for (std::int32_t i = 0; i < mActualBatchSize; ++i) for (std::int32_t i = 0; i < mActualBatchSize; ++i)
@ -756,13 +758,13 @@ void GptDecoderBatch::forwardSync(decoder_batch::Token const& token)
|| bufferCast<SizeType>(*dOutput.finishedSum)[0] == static_cast<SizeType>(mBeamWidths[i]); || bufferCast<SizeType>(*dOutput.finishedSum)[0] == static_cast<SizeType>(mBeamWidths[i]);
} }
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
// TODO call this at the end of forward if mFinished[i] changes from false to true? // TODO call this at the end of forward if mFinished[i] changes from false to true?
CudaEvent GptDecoderBatch::postProcessRequest(SizeType batchIdx) const CudaEvent GptDecoderBatch::postProcessRequest(SizeType batchIdx) const
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& stream = mStreams[batchIdx]; auto& stream = mStreams[batchIdx];
auto manager = BufferManager{stream}; auto manager = BufferManager{stream};
auto& decoder = *mDecoders[batchIdx]; auto& decoder = *mDecoders[batchIdx];
@ -779,14 +781,14 @@ CudaEvent GptDecoderBatch::postProcessRequest(SizeType batchIdx) const
CudaEvent event{}; CudaEvent event{};
stream->record(event); stream->record(event);
mStream->wait(event); mStream->wait(event);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return event; return event;
} }
void GptDecoderBatch::newBatch( void GptDecoderBatch::newBatch(
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig) GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// split batch into single requests // split batch into single requests
auto const& inputLengths = inputs.lengths; auto const& inputLengths = inputs.lengths;
mActualBatchSize = inputLengths->getShape().d[0]; mActualBatchSize = inputLengths->getShape().d[0];
@ -855,12 +857,12 @@ void GptDecoderBatch::newBatch(
} }
newRequest(batchIdx, request, extractSamplingConfig(samplingConfig, batchIdx)); newRequest(batchIdx, request, extractSamplingConfig(samplingConfig, batchIdx));
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void GptDecoderBatch::forwardAsync(decoder::Output& output, decoder::Input const& input) void GptDecoderBatch::forwardAsync(decoder::Output& output, decoder::Input const& input)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const& logitsShape = input.logits->getShape(); auto const& logitsShape = input.logits->getShape();
auto const batchSize = logitsShape.d[0]; auto const batchSize = logitsShape.d[0];
@ -886,32 +888,32 @@ void GptDecoderBatch::forwardAsync(decoder::Output& output, decoder::Input const
kernels::reduce(*mFinishedSum, *ITensor::slice(mJointDecodingOutput->finishedSum, 0, mActualBatchSize), *mStream); kernels::reduce(*mFinishedSum, *ITensor::slice(mJointDecodingOutput->finishedSum, 0, mActualBatchSize), *mStream);
mStream->record(mForwardEvent); mStream->record(mForwardEvent);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void GptDecoderBatch::forwardSync() void GptDecoderBatch::forwardSync()
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
forwardSync(*mForwardToken); forwardSync(*mForwardToken);
// wait for mFinishedSum to be updated // wait for mFinishedSum to be updated
mForwardEvent.synchronize(); mForwardEvent.synchronize();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void GptDecoderBatch::finalize() const void GptDecoderBatch::finalize() const
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
for (SizeType batchIdx = 0; batchIdx < mActualBatchSize; ++batchIdx) for (SizeType batchIdx = 0; batchIdx < mActualBatchSize; ++batchIdx)
{ {
postProcessRequest(batchIdx); auto event = postProcessRequest(batchIdx);
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
CudaEvent GptDecoderBatch::finalize(SizeType batchIdx) const CudaEvent GptDecoderBatch::finalize(SizeType batchIdx) const
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto event = postProcessRequest(batchIdx); auto event = postProcessRequest(batchIdx);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return event; return event;
} }

View File

@ -16,6 +16,7 @@
#include "tensorrt_llm/runtime/gptJsonConfig.h" #include "tensorrt_llm/runtime/gptJsonConfig.h"
#include "common.h"
#include "gptModelConfig.h" #include "gptModelConfig.h"
#include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/logger.h"
@ -24,6 +25,7 @@
#include <fstream> #include <fstream>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include <string_view> #include <string_view>
#include <utility>
using namespace tensorrt_llm::runtime; using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common; namespace tc = tensorrt_llm::common;
@ -69,12 +71,122 @@ std::optional<FieldType> parseJsonFieldOptional(Json const& json, std::string_vi
return value; return value;
} }
GptModelConfig createModelConfig(
Json const& json, bool engineVersionNone, SizeType tensorParallelism, nvinfer1::DataType dataType)
{
auto const& config = engineVersionNone ? json.at("builder_config") : json.at("pretrained_config");
auto const* const numLayersField = engineVersionNone ? "num_layers" : "num_hidden_layers";
auto const* const numHeadsField = engineVersionNone ? "num_heads" : "num_attention_heads";
auto const* const numKvHeadsField = engineVersionNone ? "num_kv_heads" : "num_key_value_heads";
auto const* const mlpHiddenSizeField = engineVersionNone ? "mlp_hidden_size" : "intermediate_size";
auto const numLayers = config.at(numLayersField).template get<SizeType>();
auto const numHeads = config.at(numHeadsField).template get<SizeType>() / tensorParallelism;
auto const vocabSize = config.at("vocab_size").template get<SizeType>();
auto const hiddenSize = config.at("hidden_size").template get<SizeType>() / tensorParallelism;
auto const sizePerHead = parseJsonFieldOr(config, "head_size", hiddenSize / numHeads);
// TODO:
// The code crashes when numKvHeads <= 0. Clamping it to a minimum of 1 prevents that; make sure this is the best fix.
auto const numKvHeads
= std::max(parseJsonFieldOr(config, numKvHeadsField, numHeads * tensorParallelism) / tensorParallelism, 1);
auto const mlpHiddenSize = parseJsonFieldOptional<SizeType>(config, mlpHiddenSizeField);
auto modelConfig = GptModelConfig{vocabSize, numLayers, numHeads, hiddenSize, dataType};
modelConfig.setSizePerHead(sizePerHead);
modelConfig.setNbKvHeads(numKvHeads);
if (mlpHiddenSize.has_value())
{
modelConfig.setMlpHiddenSize(mlpHiddenSize.value() / tensorParallelism);
}
return modelConfig;
};
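To make the per-rank arithmetic in createModelConfig concrete, here is a small worked example with assumed values (32 heads, 8 KV heads, hidden size 4096, tensor parallelism 2); the numbers are illustrative only:

// --- illustrative sketch, not part of this diff ---
#include <algorithm>
#include <cstdio>

int main()
{
    int const tensorParallelism = 2;
    int const numHeadsTotal = 32, numKvHeadsTotal = 8, hiddenSizeTotal = 4096;

    int const numHeads = numHeadsTotal / tensorParallelism;                  // 16 heads per rank
    int const hiddenSize = hiddenSizeTotal / tensorParallelism;              // 2048 per rank
    int const sizePerHead = hiddenSize / numHeads;                           // default head_size: 128
    int const numKvHeads = std::max(numKvHeadsTotal / tensorParallelism, 1); // 4, clamped to at least 1

    std::printf("heads/rank=%d kvHeads/rank=%d hidden/rank=%d headSize=%d\n", numHeads, numKvHeads, hiddenSize,
        sizePerHead);
    return 0;
}
// --- end of sketch ---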
void parseBuilderConfig(GptModelConfig& modelConfig, Json const& builderConfig)
{
auto const maxBatchSize = parseJsonFieldOr(builderConfig, "max_batch_size", 0);
auto const maxBeamWidth = parseJsonFieldOr(builderConfig, "max_beam_width", 0);
auto const maxInputLen = parseJsonFieldOr(builderConfig, "max_input_len", 0);
auto const maxSequenceLen = maxInputLen + parseJsonFieldOr(builderConfig, "max_output_len", 0);
auto const maxDraftLen = parseJsonFieldOr(builderConfig, "max_draft_len", 0);
auto const maxNumTokens = parseJsonFieldOptional<SizeType>(builderConfig, "max_num_tokens");
auto const maxPromptEmbeddingTableSize
= parseJsonFieldOr<SizeType>(builderConfig, "max_prompt_embedding_table_size", 0);
auto const computeContextLogits = parseJsonFieldOr(builderConfig, "gather_context_logits", false);
auto const computeGenerationLogits = parseJsonFieldOr(builderConfig, "gather_generation_logits", false);
modelConfig.setMaxBatchSize(maxBatchSize);
modelConfig.setMaxBeamWidth(maxBeamWidth);
modelConfig.setMaxInputLen(maxInputLen);
modelConfig.setMaxSequenceLen(maxSequenceLen);
modelConfig.setMaxNumTokens(maxNumTokens);
modelConfig.setMaxDraftLen(maxDraftLen);
modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);
modelConfig.computeContextLogits(computeContextLogits);
modelConfig.computeGenerationLogits(computeGenerationLogits);
}
void parsePluginConfig(GptModelConfig& modelConfig, Json const& pluginConfig)
{
auto const useGptAttentionPlugin = !pluginConfig.at("gpt_attention_plugin").is_null();
auto const removeInputPadding = pluginConfig.at("remove_input_padding").template get<bool>();
auto const& pagedKvCache = pluginConfig.at("paged_kv_cache");
auto const& tokensPerBlock = pluginConfig.at("tokens_per_block");
auto const useCustomAllReduce = pluginConfig.at("use_custom_all_reduce").template get<bool>();
auto const useContextFMHAForGeneration = pluginConfig.at("use_context_fmha_for_generation").template get<bool>();
auto const pagedContextFMHA = pluginConfig.at("use_paged_context_fmha").template get<bool>();
modelConfig.useGptAttentionPlugin(useGptAttentionPlugin);
modelConfig.usePackedInput(removeInputPadding);
modelConfig.usePagedKvCache(pagedKvCache);
modelConfig.setTokensPerBlock(tokensPerBlock);
modelConfig.useCustomAllReduce(useCustomAllReduce);
modelConfig.setUseContextFMHAForGeneration(useContextFMHAForGeneration);
modelConfig.setPagedContextFMHA(pagedContextFMHA);
}
void parseLora(GptModelConfig& modelConfig, Json const& json, Json const& pluginConfig, bool engineVersionNone,
SizeType tensorParallelism)
{
auto const& config = engineVersionNone ? json.at("builder_config") : json.at("pretrained_config");
auto const loraMaxRank = parseJsonFieldOr(config, "max_lora_rank", SizeType{0});
auto const loraTargetModules = parseJsonFieldOptional<std::vector<std::string>>(config, "lora_target_modules");
if (loraTargetModules.has_value())
{
modelConfig.setLoraModules(LoraModule::createLoraModules(loraTargetModules.value(), modelConfig.getHiddenSize(),
modelConfig.getMlpHiddenSize(), modelConfig.getNbHeads(), modelConfig.getNbKvHeads(),
modelConfig.getSizePerHead(), tensorParallelism));
}
modelConfig.setMaxLoraRank(loraMaxRank);
auto useLoraPlugin = !pluginConfig.at("lora_plugin").is_null();
if (useLoraPlugin)
{
if (modelConfig.getLoraModules().empty() || modelConfig.getMaxLoraRank() == 0)
{
TLLM_LOG_WARNING("lora_plugin enabled, but no lora module enabled: setting useLoraPlugin to false");
useLoraPlugin = false;
}
}
modelConfig.useLoraPlugin(useLoraPlugin);
}
template <typename InputType> template <typename InputType>
GptJsonConfig parseJson(InputType&& i) GptJsonConfig parseJson(InputType&& input)
{ {
auto constexpr allowExceptions = true; auto constexpr allowExceptions = true;
auto constexpr ignoreComments = true; auto constexpr ignoreComments = true;
auto const json = nlohmann::json::parse(i, nullptr, allowExceptions, ignoreComments); auto const json = nlohmann::json::parse(std::forward<InputType>(input), nullptr, allowExceptions, ignoreComments);
auto const engineVersion = parseJsonFieldOr(json, "version", std::string("none")); auto const engineVersion = parseJsonFieldOr(json, "version", std::string("none"));
@ -104,113 +216,26 @@ GptJsonConfig parseJson(InputType&& i)
auto const precision = engineVersionNone ? builderConfig.at("precision").template get<std::string>() auto const precision = engineVersionNone ? builderConfig.at("precision").template get<std::string>()
: json.at("pretrained_config").at("dtype").template get<std::string>(); : json.at("pretrained_config").at("dtype").template get<std::string>();
auto dataType = nvinfer1::DataType::kFLOAT; auto const dataType = [&precision]()
if (!precision.compare("float32"))
dataType = nvinfer1::DataType::kFLOAT;
else if (!precision.compare("float16"))
dataType = nvinfer1::DataType::kHALF;
else if (!precision.compare("bfloat16"))
dataType = nvinfer1::DataType::kBF16;
else
TLLM_CHECK_WITH_INFO(false, tc::fmtstr("Model data type '%s' not supported", precision.c_str()));
auto modelConfig = [&engineVersionNone, &json, &builderConfig, &tensorParallelism, &dataType]()
{ {
auto const& config = engineVersionNone ? builderConfig : json.at("pretrained_config"); if (!precision.compare("float32"))
return nvinfer1::DataType::kFLOAT;
std::string const numLayersField = engineVersionNone ? "num_layers" : "num_hidden_layers"; else if (!precision.compare("float16"))
std::string const numHeadsField = engineVersionNone ? "num_heads" : "num_attention_heads"; return nvinfer1::DataType::kHALF;
std::string const numKvHeadsField = engineVersionNone ? "num_kv_heads" : "num_key_value_heads"; else if (!precision.compare("bfloat16"))
std::string const mlpHiddenSizeField = engineVersionNone ? "mlp_hidden_size" : "intermediate_size"; return nvinfer1::DataType::kBF16;
else
auto const numLayers = config.at(numLayersField).template get<SizeType>(); TLLM_THROW("Model data type '%s' not supported", precision.c_str());
auto const numHeads = config.at(numHeadsField).template get<SizeType>() / tensorParallelism;
auto const vocabSize = config.at("vocab_size").template get<SizeType>();
auto const hiddenSize = config.at("hidden_size").template get<SizeType>() / tensorParallelism;
auto const sizePerHead = parseJsonFieldOr(config, "head_size", hiddenSize / numHeads);
auto const loraMaxRank = parseJsonFieldOr(config, "max_lora_rank", SizeType{0});
auto const loraTargetModules = parseJsonFieldOptional<std::vector<std::string>>(config, "lora_target_modules");
// TODO:
// Code crashes when numKvHeads <= 0. Clamping downwards to 1 prevents that, make sure this is best fix.
auto const numKvHeads
= std::max(parseJsonFieldOr(config, numKvHeadsField, numHeads * tensorParallelism) / tensorParallelism, 1);
auto const mlpHiddenSize = parseJsonFieldOptional<SizeType>(config, mlpHiddenSizeField);
auto modelConfig = GptModelConfig{vocabSize, numLayers, numHeads, hiddenSize, dataType};
modelConfig.setSizePerHead(sizePerHead);
modelConfig.setNbKvHeads(numKvHeads);
if (mlpHiddenSize.has_value())
{
modelConfig.setMlpHiddenSize(mlpHiddenSize.value() / tensorParallelism);
}
if (loraTargetModules.has_value())
{
modelConfig.setLoraModules(LoraModule::createLoraModules(loraTargetModules.value(),
modelConfig.getHiddenSize(), modelConfig.getMlpHiddenSize(), modelConfig.getNbHeads(),
modelConfig.getNbKvHeads(), modelConfig.getSizePerHead(), tensorParallelism));
}
modelConfig.setMaxLoraRank(loraMaxRank);
return modelConfig;
}(); }();
auto const maxBatchSize = parseJsonFieldOr(builderConfig, "max_batch_size", 0); auto modelConfig = createModelConfig(json, engineVersionNone, tensorParallelism, dataType);
auto const maxBeamWidth = parseJsonFieldOr(builderConfig, "max_beam_width", 0);
auto const maxInputLen = parseJsonFieldOr(builderConfig, "max_input_len", 0);
auto const maxSequenceLen = maxInputLen + parseJsonFieldOr(builderConfig, "max_output_len", 0);
auto const maxDraftLen = parseJsonFieldOr(builderConfig, "max_draft_len", 0);
auto const maxNumTokens = parseJsonFieldOptional<SizeType>(builderConfig, "max_num_tokens");
auto const maxPromptEmbeddingTableSize
= parseJsonFieldOr<SizeType>(builderConfig, "max_prompt_embedding_table_size", 0);
auto const computeContextLogits = parseJsonFieldOr(builderConfig, "gather_context_logits", false);
auto const computeGenerationLogits = parseJsonFieldOr(builderConfig, "gather_generation_logits", false);
modelConfig.setMaxBatchSize(maxBatchSize); parseBuilderConfig(modelConfig, builderConfig);
modelConfig.setMaxBeamWidth(maxBeamWidth);
modelConfig.setMaxInputLen(maxInputLen);
modelConfig.setMaxSequenceLen(maxSequenceLen);
modelConfig.setMaxNumTokens(maxNumTokens);
modelConfig.setMaxDraftLen(maxDraftLen);
modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);
modelConfig.computeContextLogits(computeContextLogits);
modelConfig.computeGenerationLogits(computeGenerationLogits);
auto const& pluginConfig = engineVersionNone ? json.at("plugin_config") : builderConfig.at("plugin_config"); auto const& pluginConfig = engineVersionNone ? json.at("plugin_config") : builderConfig.at("plugin_config");
parsePluginConfig(modelConfig, pluginConfig);
auto const useGptAttentionPlugin = !pluginConfig.at("gpt_attention_plugin").is_null(); parseLora(modelConfig, json, pluginConfig, engineVersionNone, tensorParallelism);
auto const removeInputPadding = pluginConfig.at("remove_input_padding").template get<bool>();
auto const pagedKvCache = pluginConfig.at("paged_kv_cache");
auto const tokensPerBlock = pluginConfig.at("tokens_per_block");
auto const useCustomAllReduce = pluginConfig.at("use_custom_all_reduce").template get<bool>();
auto const useContextFMHAForGeneration = pluginConfig.at("use_context_fmha_for_generation").template get<bool>();
auto const pagedContextFMHA = pluginConfig.at("use_paged_context_fmha").template get<bool>();
modelConfig.useGptAttentionPlugin(useGptAttentionPlugin);
modelConfig.usePackedInput(removeInputPadding);
modelConfig.usePagedKvCache(pagedKvCache);
modelConfig.setTokensPerBlock(tokensPerBlock);
modelConfig.useCustomAllReduce(useCustomAllReduce);
modelConfig.setUseContextFMHAForGeneration(useContextFMHAForGeneration);
modelConfig.setPagedContextFMHA(pagedContextFMHA);
auto useLoraPlugin = !pluginConfig.at("lora_plugin").is_null();
if (useLoraPlugin)
{
if (modelConfig.getLoraModules().empty() || modelConfig.getMaxLoraRank() == 0)
{
TLLM_LOG_WARNING("lora_plugin enabled, but no lora module enabled: setting useLoraPlugin to false");
useLoraPlugin = false;
}
}
modelConfig.useLoraPlugin(useLoraPlugin);
if (engineVersionNone) if (engineVersionNone)
{ {

View File

@ -23,7 +23,6 @@
#include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/customAllReduceUtils.h" #include "tensorrt_llm/common/customAllReduceUtils.h"
#include "tensorrt_llm/common/stringUtils.h" #include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/runtime/gptDecoderBatch.h" #include "tensorrt_llm/runtime/gptDecoderBatch.h"
#include "tensorrt_llm/runtime/ipcUtils.h" #include "tensorrt_llm/runtime/ipcUtils.h"
#include "tensorrt_llm/runtime/ncclCommunicator.h" #include "tensorrt_llm/runtime/ncclCommunicator.h"
@ -189,29 +188,20 @@ void GptSession::createKvCacheManager(SizeType batchSize, SizeType beamWidth, Si
kvDtype = mModelConfig.getDataType(); kvDtype = mModelConfig.getDataType();
} }
auto const maxNumTokens auto const maxNumBlocks = bmkv::KVCacheManager::calculateMaxNumBlocks(
= bmkv::KVCacheManager::getMaxNumTokens(kvCacheConfig, kvDtype, mModelConfig, mWorldConfig, getBufferManager()); kvCacheConfig, kvDtype, mModelConfig, mWorldConfig, getBufferManager());
TLLM_LOG_INFO("Using %d tokens in paged KV cache.", maxNumTokens);
auto const maxNumBlocks = tc::ceilDiv(maxNumTokens, tokensPerBlock);
SizeType sinkTokensInLastBlock = sinkTokenLength % tokensPerBlock;
SizeType bubbleLen = sinkTokensInLastBlock != 0 ? tokensPerBlock - sinkTokensInLastBlock : 0;
auto maxBlocksPerSeq = tc::ceilDiv(maxAttentionWindow + bubbleLen, tokensPerBlock);
// If beamWidth > 1, use one more block for each sequence in the paged kv cache to avoid dropping the needed // If beamWidth > 1, use one more block for each sequence in the paged kv cache to avoid dropping the needed
// tokens, when enabling cyclic kv cache. // tokens, when enabling cyclic kv cache.
auto const useOneMoreBlock = beamWidth > 1 && maxSequenceLength > maxAttentionWindow; auto const useOneMoreBlock = beamWidth > 1 && maxSequenceLength > maxAttentionWindow;
if (useOneMoreBlock)
{
maxBlocksPerSeq += 1;
}
auto const localNbLayers = mModelConfig.getNbLayers(mWorldConfig.getPipelineParallelism()); auto const localNbLayers = mModelConfig.getNbLayers(mWorldConfig.getPipelineParallelism());
auto const nbKvHeads = mModelConfig.getNbKvHeads(); auto const nbKvHeads = mModelConfig.getNbKvHeads();
auto const sizePerHead = mModelConfig.getSizePerHead(); auto const sizePerHead = mModelConfig.getSizePerHead();
bool enableBlockReuse{false}; bool constexpr enableBlockReuse{false};
mKvCacheManager = std::make_shared<bmkv::KVCacheManager>(localNbLayers, nbKvHeads, sizePerHead, tokensPerBlock, mKvCacheManager = std::make_shared<bmkv::KVCacheManager>(localNbLayers, nbKvHeads, sizePerHead, tokensPerBlock,
maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, maxAttentionWindow, sinkTokenLength, useOneMoreBlock, maxNumBlocks, batchSize, beamWidth, maxAttentionWindow, sinkTokenLength, useOneMoreBlock, kvDtype,
kvDtype, mRuntime->getStreamPtr(), enableBlockReuse, kvCacheConfig.useUvm); mRuntime->getStreamPtr(), enableBlockReuse, kvCacheConfig.useUvm);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
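For intuition about the block accounting that calculateMaxNumBlocks now encapsulates, the arithmetic previously inlined here is reproduced below with assumed values (the numbers are illustrative): a partially filled sink block adds a bubble, and beam search with a cyclic KV cache reserves one extra block per sequence.

// --- illustrative sketch, not part of this diff ---
#include <cstdio>

static int ceilDiv(int a, int b)
{
    return (a + b - 1) / b;
}

int main()
{
    int const tokensPerBlock = 64;
    int const maxAttentionWindow = 1024;
    int const sinkTokenLength = 4;
    int const beamWidth = 4;
    int const maxSequenceLength = 2048;

    int const sinkTokensInLastBlock = sinkTokenLength % tokensPerBlock;                            // 4
    int const bubbleLen = sinkTokensInLastBlock != 0 ? tokensPerBlock - sinkTokensInLastBlock : 0; // 60
    int maxBlocksPerSeq = ceilDiv(maxAttentionWindow + bubbleLen, tokensPerBlock);                 // ceil(1084/64) = 17
    if (beamWidth > 1 && maxSequenceLength > maxAttentionWindow)
    {
        maxBlocksPerSeq += 1; // one extra block per sequence -> 18
    }
    std::printf("maxBlocksPerSeq=%d\n", maxBlocksPerSeq);
    return 0;
}
// --- end of sketch ---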
@ -335,12 +325,14 @@ void GptSession::setup(Config const& sessionConfig)
createCustomAllReduceWorkspace(mMicroBatchConfig.genBatchSize, maxBeamWidth, maxSequenceLength); createCustomAllReduceWorkspace(mMicroBatchConfig.genBatchSize, maxBeamWidth, maxSequenceLength);
} }
auto* kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManager.get() : nullptr;
for (auto& buffers : mBuffers) for (auto& buffers : mBuffers)
{ {
// we don't know maxInputLength yet and ignore it for pre-allocation // we don't know maxInputLength yet and ignore it for pre-allocation
buffers->generationConfig = RuntimeBuffers::GenerationConfig{ buffers->generationConfig = RuntimeBuffers::GenerationConfig{
mMicroBatchConfig.genBatchSize, maxBeamWidth, 0, maxAttentionWindow, sinkTokenLength, maxSequenceLength}; mMicroBatchConfig.genBatchSize, maxBeamWidth, 0, maxAttentionWindow, sinkTokenLength, maxSequenceLength};
buffers->reshape(mModelConfig, mWorldConfig); buffers->reshape(kvCacheManager, mModelConfig, mWorldConfig);
} }
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
@ -684,6 +676,8 @@ void GptSession::generateBatched(std::vector<GenerationOutput>& microBatchesOutp
TLLM_CHECK(numMicroBatches <= mMicroBatchConfig.numGenBatches); TLLM_CHECK(numMicroBatches <= mMicroBatchConfig.numGenBatches);
SizeType const beamWidth{samplingConfig.beamWidth}; SizeType const beamWidth{samplingConfig.beamWidth};
auto* kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManager.get() : nullptr;
// Initialize and reshape buffers // Initialize and reshape buffers
for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId) for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
{ {
@ -691,7 +685,7 @@ void GptSession::generateBatched(std::vector<GenerationOutput>& microBatchesOutp
auto& buffers = *mBuffers.at(microBatchId); auto& buffers = *mBuffers.at(microBatchId);
buffers.initFromInput(*microBatchInputs.ids, microBatchInputs.lengths, microBatchInputs.packed, beamWidth, buffers.initFromInput(*microBatchInputs.ids, microBatchInputs.lengths, microBatchInputs.packed, beamWidth,
mDecoderMaxAttentionWindow, mDecoderSinkTokenLength, mDecoderMaxSequenceLength, manager); mDecoderMaxAttentionWindow, mDecoderSinkTokenLength, mDecoderMaxSequenceLength, manager);
buffers.reshape(mModelConfig, mWorldConfig); buffers.reshape(kvCacheManager, mModelConfig, mWorldConfig);
buffers.reset(manager); buffers.reset(manager);
} }
@ -746,8 +740,6 @@ void GptSession::generateBatched(std::vector<GenerationOutput>& microBatchesOutp
} }
} }
auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManager.get() : nullptr;
auto const profileContext = !kProfileMbIdxs.empty() && kProfileMbIdxs.count(0) > 0; auto const profileContext = !kProfileMbIdxs.empty() && kProfileMbIdxs.count(0) > 0;
if (profileContext) if (profileContext)
cudaProfilerStart(); cudaProfilerStart();

View File

@ -49,7 +49,7 @@ LoraManager::LoraReqTensors& LoraManager::getTask(TaskIdType reqId)
void LoraManager::create( void LoraManager::create(
GptModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& manager) GptModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& manager)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto modules = modelConfig.getLoraModules(); auto modules = modelConfig.getLoraModules();
SizeType modOff = 0; SizeType modOff = 0;
@ -62,14 +62,14 @@ void LoraManager::create(
// TODO set this size from max adapter size // TODO set this size from max adapter size
mWorkspace = manager.emptyTensor(MemoryType::kGPU, modelConfig.getDataType()); mWorkspace = manager.emptyTensor(MemoryType::kGPU, modelConfig.getDataType());
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, ReqIdsVec const& reqIds, void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, ReqIdsVec const& reqIds,
std::vector<SizeType> const& reqBeamWidth, std::vector<bool> const& loraEnabled, SizeType numContextRequests, std::vector<SizeType> const& reqBeamWidth, std::vector<bool> const& loraEnabled, SizeType numContextRequests,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig) GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism()); auto localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
auto firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers; auto firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;
auto lastLayerId = firstLayerId + localNbLayers; auto lastLayerId = firstLayerId + localNbLayers;
@ -85,13 +85,13 @@ void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes
fillInputTensors( fillInputTensors(
weightsPtrs, adapterSizes, bid, reqIds[bid], reqBeamWidth[bid], firstLayerId, lastLayerId, tpSize, tpRank); weightsPtrs, adapterSizes, bid, reqIds[bid], reqBeamWidth[bid], firstLayerId, lastLayerId, tpSize, tpRank);
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, SizeType batchIdx, TaskIdType taskId, void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, SizeType batchIdx, TaskIdType taskId,
SizeType beamWidth, SizeType firstLayerId, SizeType lastLayerId, SizeType tpSize, SizeType tpRank) SizeType beamWidth, SizeType firstLayerId, SizeType lastLayerId, SizeType tpSize, SizeType tpRank)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto weightsPointersPtr = bufferCast<int64_t>(*weightsPtrs); auto weightsPointersPtr = bufferCast<int64_t>(*weightsPtrs);
auto adapterSizesPtr = bufferCast<int32_t>(*adapterSizes); auto adapterSizesPtr = bufferCast<int32_t>(*adapterSizes);
@ -172,13 +172,13 @@ void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes
} }
std::fill_n(writeAdapterSizesPtr, beamWidth, adapterSize); std::fill_n(writeAdapterSizesPtr, beamWidth, adapterSize);
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void LoraManager::insertInputTensors(TensorMap& inputTensors, TensorPtr weightsPtrs, TensorPtr adapterSizes, void LoraManager::insertInputTensors(TensorMap& inputTensors, TensorPtr weightsPtrs, TensorPtr adapterSizes,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig) const GptModelConfig const& modelConfig, WorldConfig const& worldConfig) const
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism()); auto localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
auto firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers; auto firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;
@ -210,13 +210,13 @@ void LoraManager::insertInputTensors(TensorMap& inputTensors, TensorPtr weightsP
} }
} }
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void LoraManager::formatTaskTensors(LoraWeightsTensorPtr weights, LoraConfigTensorPtr config, void LoraManager::formatTaskTensors(LoraWeightsTensorPtr weights, LoraConfigTensorPtr config,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& manager) GptModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& manager)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
weights->squeeze(0); weights->squeeze(0);
config->squeeze(0); config->squeeze(0);
@ -261,7 +261,7 @@ void LoraManager::formatTaskTensors(LoraWeightsTensorPtr weights, LoraConfigTens
manager.copy(*mWorkspace, *weightsOut); manager.copy(*mWorkspace, *weightsOut);
} }
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void LoraManager::reset() void LoraManager::reset()

View File

@ -38,13 +38,18 @@ std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> c
{ {
case ModuleType::kATTN_QKV: case ModuleType::kATTN_QKV:
case ModuleType::kCROSS_ATTN_QKV:
modules.emplace_back( modules.emplace_back(
LoraModule(t, hidden, (numHeads * attnHeadSize + 2 * numKvHeads * attnHeadSize), false, true, -1, 0)); LoraModule(t, hidden, (numHeads * attnHeadSize + 2 * numKvHeads * attnHeadSize), false, true, -1, 0));
break; break;
case ModuleType::kATTN_Q: case ModuleType::kATTN_Q:
case ModuleType::kATTN_K: case ModuleType::kATTN_K:
case ModuleType::kATTN_V: modules.emplace_back(t, hidden, hidden, false, true, -1, 0); break; case ModuleType::kATTN_V:
case ModuleType::kATTN_DENSE: modules.emplace_back(t, hidden, hidden, false, true, 1, -1); break; case ModuleType::kCROSS_ATTN_Q:
case ModuleType::kCROSS_ATTN_K:
case ModuleType::kCROSS_ATTN_V: modules.emplace_back(t, hidden, hidden, false, true, -1, 0); break;
case ModuleType::kATTN_DENSE:
case ModuleType::kCROSS_ATTN_DENSE: modules.emplace_back(t, hidden, hidden, false, true, 1, -1); break;
case ModuleType::kMLP_H_TO_4H: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break; case ModuleType::kMLP_H_TO_4H: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break;
case ModuleType::kMLP_GATE: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break; case ModuleType::kMLP_GATE: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break;
case ModuleType::kMLP_4H_TO_H: modules.emplace_back(t, mlpHiddenSize, hidden, false, true, 1, -1); break; case ModuleType::kMLP_4H_TO_H: modules.emplace_back(t, mlpHiddenSize, hidden, false, true, 1, -1); break;
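To illustrate the dimensions chosen in the switch above, here is a worked example for the (cross-)attention QKV module with assumed sizes (hidden 4096, 32 heads, 8 KV heads, head size 128); the values are illustrative only:

// --- illustrative sketch, not part of this diff ---
#include <cstdio>

int main()
{
    int const hidden = 4096;
    int const numHeads = 32, numKvHeads = 8, attnHeadSize = 128;

    // inDim is the model hidden size; outDim packs Q plus the (possibly narrower) K and V projections.
    int const inDim = hidden;
    int const outDim = numHeads * attnHeadSize + 2 * numKvHeads * attnHeadSize; // 4096 + 2048 = 6144

    std::printf("attn_qkv LoRA module: inDim=%d outDim=%d\n", inDim, outDim);
    return 0;
}
// --- end of sketch ---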

View File

@ -39,6 +39,11 @@ public:
kMLP_H_TO_4H = 5, kMLP_H_TO_4H = 5,
kMLP_4H_TO_H = 6, kMLP_4H_TO_H = 6,
kMLP_GATE = 7, kMLP_GATE = 7,
kCROSS_ATTN_QKV = 8,
kCROSS_ATTN_Q = 9,
kCROSS_ATTN_K = 10,
kCROSS_ATTN_V = 11,
kCROSS_ATTN_DENSE = 12,
}; };
explicit constexpr LoraModule(ModuleType const& t, SizeType inDim, SizeType outDim, bool inDimFirst, explicit constexpr LoraModule(ModuleType const& t, SizeType inDim, SizeType outDim, bool inDimFirst,
@ -128,6 +133,16 @@ public:
return ModuleType::kMLP_4H_TO_H; return ModuleType::kMLP_4H_TO_H;
else if (name == "mlp_gate") else if (name == "mlp_gate")
return ModuleType::kMLP_GATE; return ModuleType::kMLP_GATE;
else if (name == "cross_attn_qkv")
return ModuleType::kCROSS_ATTN_QKV;
else if (name == "cross_attn_q")
return ModuleType::kCROSS_ATTN_Q;
else if (name == "cross_attn_k")
return ModuleType::kCROSS_ATTN_K;
else if (name == "cross_attn_v")
return ModuleType::kCROSS_ATTN_V;
else if (name == "cross_attn_dense")
return ModuleType::kCROSS_ATTN_DENSE;
else else
return ModuleType::kINVALID; return ModuleType::kINVALID;
} }
@ -144,6 +159,11 @@ public:
case ModuleType::kMLP_H_TO_4H: return "mlp_h_to_4h"; case ModuleType::kMLP_H_TO_4H: return "mlp_h_to_4h";
case ModuleType::kMLP_4H_TO_H: return "mlp_4h_to_h"; case ModuleType::kMLP_4H_TO_H: return "mlp_4h_to_h";
case ModuleType::kMLP_GATE: return "mlp_gate"; case ModuleType::kMLP_GATE: return "mlp_gate";
case ModuleType::kCROSS_ATTN_QKV: return "cross_attn_qkv";
case ModuleType::kCROSS_ATTN_Q: return "cross_attn_q";
case ModuleType::kCROSS_ATTN_K: return "cross_attn_k";
case ModuleType::kCROSS_ATTN_V: return "cross_attn_v";
case ModuleType::kCROSS_ATTN_DENSE: return "cross_attn_dense";
case ModuleType::kINVALID: return "INVALID"; case ModuleType::kINVALID: return "INVALID";
} }
return "INVALID"; return "INVALID";

View File

@ -82,4 +82,10 @@ void MemoryCounters::deallocate(MemoryType memoryType, MemoryCounters::SizeType
default: TLLM_THROW("Unknown memory type"); default: TLLM_THROW("Unknown memory type");
} }
} }
MemoryCounters& MemoryCounters::getInstance()
{
static MemoryCounters mInstance;
return mInstance;
}
} // namespace tensorrt_llm::runtime } // namespace tensorrt_llm::runtime
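The new getInstance uses a function-local static (a Meyers singleton), which C++11 initializes exactly once even when first called from several threads. A minimal standalone sketch of the same pattern, with an illustrative counters class rather than the TensorRT-LLM type:

#include <atomic>
#include <cstddef>

// Illustrative stand-in; not the TensorRT-LLM MemoryCounters class.
class Counters
{
public:
    static Counters& getInstance()
    {
        // Function-local static: constructed on first use, thread-safe since C++11.
        static Counters instance;
        return instance;
    }

    void allocate(std::size_t bytes) { mBytes += bytes; }
    std::size_t bytes() const { return mBytes.load(); }

private:
    Counters() = default;
    std::atomic<std::size_t> mBytes{0};
};

int main()
{
    Counters::getInstance().allocate(1 << 20); // every caller sees the same object
    return 0;
}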

View File

@ -17,13 +17,13 @@
#include "tensorrt_llm/runtime/runtimeBuffers.h" #include "tensorrt_llm/runtime/runtimeBuffers.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/stlUtils.h" #include "tensorrt_llm/common/stlUtils.h"
#include "tensorrt_llm/runtime/runtimeKernels.h" #include "tensorrt_llm/runtime/runtimeKernels.h"
#include "tensorrt_llm/runtime/tllmRuntime.h" #include "tensorrt_llm/runtime/tllmRuntime.h"
#include "tensorrt_llm/runtime/utils/sessionUtils.h" #include "tensorrt_llm/runtime/utils/sessionUtils.h"
#include <algorithm> #include <algorithm>
#include <iostream>
using namespace tensorrt_llm::runtime; using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common; namespace tc = tensorrt_llm::common;
@ -32,7 +32,7 @@ RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITe
ITensor const& inputLengthsHost, bool const inputPacked, SizeType const beamWidth, ITensor const& inputLengthsHost, bool const inputPacked, SizeType const beamWidth,
SizeType const maxAttentionWindow, SizeType const sinkTokenLength, SizeType const maxSequenceLength) SizeType const maxAttentionWindow, SizeType const sinkTokenLength, SizeType const maxSequenceLength)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = static_cast<SizeType>(inputLengthsHost.getSize()); auto const batchSize = static_cast<SizeType>(inputLengthsHost.getSize());
auto const* inputLengthsPtr = bufferCast<SizeType>(inputLengthsHost); auto const* inputLengthsPtr = bufferCast<SizeType>(inputLengthsHost);
@ -71,14 +71,14 @@ RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITe
"Max input length is equal to or larger that maxSequenceLength given in setup. No new tokens can be " "Max input length is equal to or larger that maxSequenceLength given in setup. No new tokens can be "
"generated."); "generated.");
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return GenerationConfig{ return GenerationConfig{
batchSize, beamWidth, maxInputLength, maxAttentionWindow, sinkTokenLength, maxSequenceLength, inputLengthSum}; batchSize, beamWidth, maxInputLength, maxAttentionWindow, sinkTokenLength, maxSequenceLength, inputLengthSum};
} }
void RuntimeBuffers::clear() void RuntimeBuffers::clear()
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
contextLengthsHost = nullptr; contextLengthsHost = nullptr;
contextLengthsDevice = nullptr; contextLengthsDevice = nullptr;
@ -104,22 +104,22 @@ void RuntimeBuffers::clear()
hiddenStates = nullptr; hiddenStates = nullptr;
allocated = false; allocated = false;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void RuntimeBuffers::clearTensorMaps() void RuntimeBuffers::clearTensorMaps()
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
for (auto& buffer : inputBuffers) for (auto& buffer : inputBuffers)
buffer.clear(); buffer.clear();
for (auto& buffer : outputBuffers) for (auto& buffer : outputBuffers)
buffer.clear(); buffer.clear();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelConfig, WorldConfig const& worldConfig) void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& manager = runtime.getBufferManager(); auto& manager = runtime.getBufferManager();
auto& engine = runtime.getEngine(); auto& engine = runtime.getEngine();
@ -170,8 +170,7 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
    if (modelConfig.usePagedKvCache())
    {
-       auto const kvCacheBlockPointersType
-           = engine.getTensorDataType(("kv_cache_block_pointers_" + std::to_string(firstLayerId)).c_str());
+       auto const kvCacheBlockPointersType = engine.getTensorDataType("kv_cache_block_pointers");
        kvCacheBlockPointersHost = manager.emptyTensor(MemoryType::kCPU, kvCacheBlockPointersType);
        kvCacheBlockPointersDevice = manager.emptyTensor(MemoryType::kGPU, kvCacheBlockPointersType);
    }
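With the per-layer suffix gone, create() queries the engine once by the plain tensor name. A hedged sketch of that name-based lookup against the TensorRT engine API (the wrapper function is illustrative):

#include <NvInfer.h>

// Hedged sketch; the wrapper is illustrative, getTensorDataType is the name-based
// TensorRT engine query used above.
nvinfer1::DataType kvCachePointerType(nvinfer1::ICudaEngine const& engine)
{
    // One un-suffixed input now covers all layers, so a single lookup suffices.
    return engine.getTensorDataType("kv_cache_block_pointers");
}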
@ -183,8 +182,7 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
    if (modelConfig.useGptAttentionPlugin())
    {
        pastKeyValueLengths = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
-       maxAttentionWindows
-           = utils::createBufferVector(runtime, localNbLayers, MemoryType::kCPU, nvinfer1::DataType::kINT32);
+       maxAttentionWindows = BufferManager::cpu(ITensor::makeShape({localNbLayers}), nvinfer1::DataType::kINT32);
        sinkTokenLengths = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
    }
else else
@ -207,7 +205,7 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
hiddenStates = manager.emptyTensor(MemoryType::kGPU, modelConfig.getDataType()); hiddenStates = manager.emptyTensor(MemoryType::kGPU, modelConfig.getDataType());
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inputLengths, bool inputPacked, void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inputLengths, bool inputPacked,
@ -223,15 +221,15 @@ void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inp
inputIds, *contextLengthsHost, inputPacked, beamWidth, maxAttentionWindow, sinkTokenLength, maxSequenceLength); inputIds, *contextLengthsHost, inputPacked, beamWidth, maxAttentionWindow, sinkTokenLength, maxSequenceLength);
} }
-void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
+void RuntimeBuffers::reshape(
+    KvCacheManager const* kvCacheManager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
-   TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
+   TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto const batchSize = generationConfig.batchSize;
    auto const beamWidth = generationConfig.beamWidth;
    auto const maxInputLength = generationConfig.maxInputLength;
    auto const maxAttentionWindow = generationConfig.maxAttentionWindow;
-   auto const sinkTokenLen = generationConfig.sinkTokenLength;
    auto const maxSeqLength = generationConfig.maxSeqLength;
    auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
@ -259,13 +257,12 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
    if (modelConfig.computeGenerationLogits())
    {
        allGenerationLogits->reshape(
-           ITensor::makeShape({(generationConfig.maxSeqLength - generationConfig.maxInputLength), batchSize,
-               beamWidth, vocabSizePadded}));
+           ITensor::makeShape({(maxSeqLength - maxInputLength), batchSize, beamWidth, vocabSizePadded}));
        cacheGenerationFragmentPointerDevice->reshape(
-           ITensor::makeShape({batchSize, (generationConfig.maxSeqLength - generationConfig.maxInputLength)}));
+           ITensor::makeShape({batchSize, (maxSeqLength - maxInputLength)}));
        cacheGenerationFragmentPointerHost->reshape(
-           ITensor::makeShape({batchSize, (generationConfig.maxSeqLength - generationConfig.maxInputLength)}));
+           ITensor::makeShape({batchSize, (maxSeqLength - maxInputLength)}));
    }
} }
@ -277,18 +274,11 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
= ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(), maxInputLength, modelConfig.getSizePerHead()}); = ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(), maxInputLength, modelConfig.getSizePerHead()});
    if (modelConfig.usePagedKvCache())
    {
-       auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
-       auto const tokensPerBlock = modelConfig.getTokensPerBlock();
-       SizeType bubbleLen
-           = (sinkTokenLen % tokensPerBlock == 0) ? 0 : tokensPerBlock - (sinkTokenLen % tokensPerBlock);
-       auto maxBlocksPerSeq = tc::ceilDiv(maxAttentionWindow + bubbleLen, tokensPerBlock);
-       // If beamWidth > 1, use one more block for each sequence in the paged kv cache to avoid dropping the needed
-       // tokens, when enabling cyclic kv cache.
-       if (beamWidth > 1 && maxSeqLength > maxAttentionWindow)
-       {
-           maxBlocksPerSeq += 1;
-       }
+       TLLM_CHECK(kvCacheManager);
+       auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
+       auto const maxBlocksPerSeq = kvCacheManager->getMaxBlocksPerSeq();
        // reserve batchSize * beamWidth and resize to batchSize
        auto cacheBlockPointersShape = ITensor::makeShape({localNbLayers, batchSize * beamWidth, 2, maxBlocksPerSeq});
        kvCacheBlockPointersHost->reshape(cacheBlockPointersShape);
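The block-count arithmetic removed above (rounding the sink tokens up to a block boundary, then adding one block when beam search runs with a cyclic KV cache) is presumably computed inside the KV-cache manager now and exposed via getMaxBlocksPerSeq(). A standalone sketch of that calculation with illustrative inputs:

#include <cstdint>
#include <iostream>

// Sketch of the removed computation; the authoritative version now lives behind KVCacheManager.
int64_t maxBlocksPerSeq(int64_t maxAttentionWindow, int64_t sinkTokenLen, int64_t tokensPerBlock,
    int64_t beamWidth, int64_t maxSeqLength)
{
    // Pad the sink tokens up to a block boundary ("bubble").
    int64_t const bubbleLen
        = (sinkTokenLen % tokensPerBlock == 0) ? 0 : tokensPerBlock - (sinkTokenLen % tokensPerBlock);
    int64_t blocks = (maxAttentionWindow + bubbleLen + tokensPerBlock - 1) / tokensPerBlock; // ceilDiv
    // Beam search with a cyclic (windowed) KV cache keeps one extra block per sequence.
    if (beamWidth > 1 && maxSeqLength > maxAttentionWindow)
    {
        blocks += 1;
    }
    return blocks;
}

int main()
{
    std::cout << maxBlocksPerSeq(/*window*/ 4096, /*sink*/ 4, /*tokensPerBlock*/ 64, /*beam*/ 4, /*maxSeqLen*/ 8192)
              << '\n'; // 66: ceil((4096 + 60) / 64) = 65, plus 1 extra block for beam search
    return 0;
}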
@ -302,11 +292,13 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
utils::reshapeBufferVector(presentKeysVals, kvCacheReserve); utils::reshapeBufferVector(presentKeysVals, kvCacheReserve);
} }
+   auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
    if (modelConfig.useGptAttentionPlugin())
    {
        pastKeyValueLengths->reshape(ITensor::makeShape({batchSize}));
        requestTypes->reshape(ITensor::makeShape({batchSize}));
-       utils::reshapeBufferVector(maxAttentionWindows, ITensor::makeShape({1}));
+       maxAttentionWindows->reshape(ITensor::makeShape({localNbLayers}));
        sinkTokenLengths->reshape(ITensor::makeShape({1}));
} }
else else
@ -332,7 +324,7 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
} }
allocated = true; allocated = true;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void RuntimeBuffers::reset(BufferManager& manager) void RuntimeBuffers::reset(BufferManager& manager)
@ -345,7 +337,7 @@ void RuntimeBuffers::reset(BufferManager& manager)
std::vector<RuntimeBuffers> RuntimeBuffers::split( std::vector<RuntimeBuffers> RuntimeBuffers::split(
SizeType contextBatchSize, GptModelConfig const& modelConfig, WorldConfig const& worldConfig) SizeType contextBatchSize, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
std::vector<RuntimeBuffers> bufferSlices; std::vector<RuntimeBuffers> bufferSlices;
auto const generationBatchSize = generationConfig.batchSize; auto const generationBatchSize = generationConfig.batchSize;
@ -432,14 +424,14 @@ std::vector<RuntimeBuffers> RuntimeBuffers::split(
} }
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return bufferSlices; return bufferSlices;
} }
void RuntimeBuffers::gatherLastTokenLogits( void RuntimeBuffers::gatherLastTokenLogits(
BufferManager& manager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig) BufferManager& manager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK_WITH_INFO(modelConfig.computeContextLogits(), TLLM_CHECK_WITH_INFO(modelConfig.computeContextLogits(),
"Gather last token logits is only necessary when context logits are computed"); "Gather last token logits is only necessary when context logits are computed");
@ -459,12 +451,12 @@ void RuntimeBuffers::gatherLastTokenLogits(
} }
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void RuntimeBuffers::copyAttentionMasks(std::vector<RuntimeBuffers> const& contextBatches, BufferManager& manager) void RuntimeBuffers::copyAttentionMasks(std::vector<RuntimeBuffers> const& contextBatches, BufferManager& manager)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = generationConfig.batchSize; auto const batchSize = generationConfig.batchSize;
auto const maxInputLength = generationConfig.maxInputLength; auto const maxInputLength = generationConfig.maxInputLength;
@ -481,12 +473,12 @@ void RuntimeBuffers::copyAttentionMasks(std::vector<RuntimeBuffers> const& conte
manager.copy(*buffers.attentionMask, *attentionMaskSlice); manager.copy(*buffers.attentionMask, *attentionMaskSlice);
offset += contextBatchSize; offset += contextBatchSize;
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void RuntimeBuffers::tile(BufferManager& manager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig) void RuntimeBuffers::tile(BufferManager& manager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const beamWidth = generationConfig.beamWidth; auto const beamWidth = generationConfig.beamWidth;
TLLM_CHECK_WITH_INFO(beamWidth > 1, "Tiling is only necessary for beam search."); TLLM_CHECK_WITH_INFO(beamWidth > 1, "Tiling is only necessary for beam search.");
@ -519,13 +511,13 @@ void RuntimeBuffers::tile(BufferManager& manager, GptModelConfig const& modelCon
for (auto& buffer : presentKeysValsAlt) for (auto& buffer : presentKeysValsAlt)
utils::tileBufferReplace(buffer, beamWidth, manager); utils::tileBufferReplace(buffer, beamWidth, manager);
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void RuntimeBuffers::postContextStep(std::vector<RuntimeBuffers> const& contextBuffers, BufferManager& manager, void RuntimeBuffers::postContextStep(std::vector<RuntimeBuffers> const& contextBuffers, BufferManager& manager,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig) GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = generationConfig.batchSize; auto const batchSize = generationConfig.batchSize;
auto const beamWidth = generationConfig.beamWidth; auto const beamWidth = generationConfig.beamWidth;
@ -580,14 +572,14 @@ void RuntimeBuffers::postContextStep(std::vector<RuntimeBuffers> const& contextB
manager, modelConfig.usePackedInput()); manager, modelConfig.usePackedInput());
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType const padId, BufferManager& manager,
-   KvCacheManager const* kvCacheManager, SizeType firstBatchSlotIdx, GptModelConfig const& modelConfig,
-   WorldConfig const& worldConfig)
+   batch_manager::kv_cache_manager::KVCacheManager const* kvCacheManager, SizeType firstBatchSlotIdx,
+   GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
-   TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
+   TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& stream = manager.getStream(); auto& stream = manager.getStream();
SizeType const batchSize = generationConfig.batchSize; SizeType const batchSize = generationConfig.batchSize;
SizeType const maxInputLength = generationConfig.maxInputLength; SizeType const maxInputLength = generationConfig.maxInputLength;
@ -607,11 +599,9 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
TLLM_CHECK(requestTypes->getSize() == static_cast<std::size_t>(batchSize)); TLLM_CHECK(requestTypes->getSize() == static_cast<std::size_t>(batchSize));
std::fill_n(RequestTypesPtr, batchSize, 0); std::fill_n(RequestTypesPtr, batchSize, 0);
-   // Set maxAttentionWindows buffer and sinkTokenLengths to the same value currently.
-   for (auto layer = 0; layer < localNbLayers; ++layer)
-   {
-       bufferCast<SizeType>(*maxAttentionWindows[layer])[0] = generationConfig.maxAttentionWindow;
-   }
+   auto maxAttentionWindowsPtr = bufferCast<SizeType>(*maxAttentionWindows);
+   std::fill_n(maxAttentionWindowsPtr, localNbLayers, generationConfig.maxAttentionWindow);
    bufferCast<SizeType>(*sinkTokenLengths)[0] = generationConfig.sinkTokenLength;
auto const& inputShape = inputIds->getShape(); auto const& inputShape = inputIds->getShape();
@ -720,14 +710,14 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
manager.copy(*contextLengthsDevice, *lastTokenIds); manager.copy(*contextLengthsDevice, *lastTokenIds);
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
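Because maxAttentionWindows is now a single [localNbLayers] host tensor, prepareContextStep fills it with one std::fill_n instead of looping over per-layer one-element tensors. A tiny self-contained illustration of the same fill, with a plain vector standing in for the host buffer:

#include <algorithm>
#include <cstdint>
#include <vector>

int main()
{
    // Illustrative values; the real buffer is a [localNbLayers] host ITensor.
    int32_t const localNbLayers = 6;
    int32_t const maxAttentionWindow = 2048;

    std::vector<int32_t> hostWindows(localNbLayers);
    // One contiguous write replaces the old per-layer loop over separate 1-element tensors.
    std::fill_n(hostWindows.data(), localNbLayers, maxAttentionWindow);
    return 0;
}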
RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, BufferManager& manager,
-   KvCacheManager* kvCacheManager, SizeType firstBatchSlotIdx, GptModelConfig const& modelConfig,
-   WorldConfig const& worldConfig)
+   batch_manager::kv_cache_manager::KVCacheManager* kvCacheManager, SizeType firstBatchSlotIdx,
+   GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
-   TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
+   TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& stream = manager.getStream(); auto& stream = manager.getStream();
SizeType const batchSize = generationConfig.batchSize; SizeType const batchSize = generationConfig.batchSize;
SizeType const beamWidth = generationConfig.beamWidth; SizeType const beamWidth = generationConfig.beamWidth;
@ -842,7 +832,7 @@ RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, B
{ {
kernels::invokeInclusiveSum(*lastTokenIds, *lastTokenIds, manager, stream); kernels::invokeInclusiveSum(*lastTokenIds, *lastTokenIds, manager, stream);
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return nextInputIds; return nextInputIds;
} }
@ -850,7 +840,7 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
TensorPtr const& inputIds, TensorPtr const& commPtrs, GptModelConfig const& modelConfig, TensorPtr const& inputIds, TensorPtr const& commPtrs, GptModelConfig const& modelConfig,
WorldConfig const& worldConfig) const WorldConfig const& worldConfig) const
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
inputBuffers.clear(); inputBuffers.clear();
outputBuffers.clear(); outputBuffers.clear();
@ -890,7 +880,7 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
inputBuffers.insert_or_assign("host_request_types", requestTypes); inputBuffers.insert_or_assign("host_request_types", requestTypes);
inputBuffers.insert_or_assign("sequence_length", sequenceLengths); inputBuffers.insert_or_assign("sequence_length", sequenceLengths);
inputBuffers.insert_or_assign("host_sink_token_length", sinkTokenLengths); inputBuffers.insert_or_assign("host_sink_token_length", sinkTokenLengths);
-   utils::insertTensorVector(inputBuffers, "host_max_attention_window_size_", maxAttentionWindows, firstLayerId);
+   inputBuffers.insert_or_assign("host_max_attention_window_sizes", maxAttentionWindows);
if (modelConfig.usePackedInput()) if (modelConfig.usePackedInput())
{ {
@ -898,10 +888,8 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
} }
    if (modelConfig.usePagedKvCache())
    {
-       utils::insertTensorSlices(
-           inputBuffers, "kv_cache_block_pointers_", kvCacheBlockPointersDevice, firstLayerId);
-       utils::insertTensorSlices(
-           inputBuffers, "host_kv_cache_block_pointers_", kvCacheBlockPointersHost, firstLayerId);
+       inputBuffers.insert_or_assign("kv_cache_block_pointers", kvCacheBlockPointersDevice);
+       inputBuffers.insert_or_assign("host_kv_cache_block_pointers", kvCacheBlockPointersHost);
    }
else else
{ {
@ -946,7 +934,10 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
inputBuffers.insert_or_assign("tasks", promptTuningParams.tasks); inputBuffers.insert_or_assign("tasks", promptTuningParams.tasks);
inputBuffers.insert_or_assign("prompt_vocab_size", promptTuningParams.vocabSize); inputBuffers.insert_or_assign("prompt_vocab_size", promptTuningParams.vocabSize);
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
// utils::printTensorMap(std::cerr, inputBuffers);
// utils::printTensorMap(std::cerr, outputBuffers);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
std::vector<SizeType> RuntimeBuffers::getPositionIdsContextPhaseGlm(const SizeType& batchSize, std::vector<SizeType> RuntimeBuffers::getPositionIdsContextPhaseGlm(const SizeType& batchSize,

View File

@ -99,11 +99,11 @@ public:
// still point to `allGenerationLogits` and bring overwrite conflict. // still point to `allGenerationLogits` and bring overwrite conflict.
std::vector<TensorPtr> presentKeysVals; std::vector<TensorPtr> presentKeysVals;
std::vector<TensorPtr> presentKeysValsAlt; // without attention plugin std::vector<TensorPtr> presentKeysValsAlt; // without attention plugin
-   std::vector<TensorPtr> maxAttentionWindows; // with attention plugin, host tensor
+   TensorPtr maxAttentionWindows;        // with attention plugin, host tensor
    TensorPtr sinkTokenLengths;           // with attention plugin, host tensor
    TensorPtr kvCacheBlockPointersHost;   // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
    TensorPtr kvCacheBlockPointersDevice; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
// References to tmp buffers // References to tmp buffers
TensorPtr newTokens; TensorPtr newTokens;
@ -147,7 +147,8 @@ public:
SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, BufferManager& manager); SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, BufferManager& manager);
    //! \brief Reshape buffers based on current GenerationConfig
-   void reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
+   void reshape(
+       KvCacheManager const* kvCacheManager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
void reset(BufferManager& manager); void reset(BufferManager& manager);

View File

@ -417,6 +417,21 @@ void invokeInclusiveSum(IBuffer& output, IBuffer const& input, BufferManager con
cub::DeviceScan::InclusiveSum(tempStorageData, tempStorageBytes, inputData, outputData, size, stream.get()); cub::DeviceScan::InclusiveSum(tempStorageData, tempStorageBytes, inputData, outputData, size, stream.get());
} }
void invokeInclusiveSum(IBuffer& output, IBuffer& tmpBuffer, IBuffer const& input, CudaStream const& stream)
{
TLLM_CHECK_WITH_INFO(nvinfer1::DataType::kUINT8 == tmpBuffer.getDataType(), "tmpBuffer has wrong data type");
auto const size = input.getSize();
auto const* inputData = bufferCast<SizeType>(input);
auto* outputData = bufferCast<SizeType>(output);
std::size_t tempStorageBytes{0};
cub::DeviceScan::InclusiveSum(nullptr, tempStorageBytes, inputData, outputData, size, stream.get());
tmpBuffer.resize(tempStorageBytes);
auto* tmpBufferPtr = bufferCast<std::uint8_t>(tmpBuffer);
cub::DeviceScan::InclusiveSum(tmpBufferPtr, tempStorageBytes, inputData, outputData, size, stream.get());
}
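The new overload follows CUB's usual two-pass protocol: the first InclusiveSum call with a null scratch pointer only reports the required byte count, the caller resizes tmpBuffer, and the second call performs the scan. A hedged sketch of the same protocol written directly against CUB, with raw device pointers standing in for IBuffer:

#include <cstddef>
#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Sketch of the two-pass cub::DeviceScan::InclusiveSum protocol (error checks omitted for brevity).
void inclusiveSumSketch(int const* dIn, int* dOut, int numItems, cudaStream_t stream)
{
    void* dTemp = nullptr;
    std::size_t tempBytes = 0;
    // Pass 1: with a null temp pointer, CUB only reports the scratch size it needs.
    cub::DeviceScan::InclusiveSum(dTemp, tempBytes, dIn, dOut, numItems, stream);
    cudaMallocAsync(&dTemp, tempBytes, stream);
    // Pass 2: the same call with real scratch space performs the scan.
    cub::DeviceScan::InclusiveSum(dTemp, tempBytes, dIn, dOut, numItems, stream);
    cudaFreeAsync(dTemp, stream);
}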
namespace namespace
{ {
__global__ void buildTokenMask(SizeType* tokenMask, SizeType const* inputLengths, SizeType const batchSize, __global__ void buildTokenMask(SizeType* tokenMask, SizeType const* inputLengths, SizeType const batchSize,
@ -752,7 +767,7 @@ void initOutputIds(ITensor& outputIds, ITensor const& inputIds, ITensor const& i
ITensor const& inputOffsets, TokenIdType const padId, TokenIdType const endId, SizeType const maxInputLength, ITensor const& inputOffsets, TokenIdType const padId, TokenIdType const endId, SizeType const maxInputLength,
bool const inputPacked, CudaStream const& stream) bool const inputPacked, CudaStream const& stream)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
kernels::invokeFill(outputIds, endId, stream); kernels::invokeFill(outputIds, endId, stream);
if (inputPacked) if (inputPacked)
@ -763,7 +778,7 @@ void initOutputIds(ITensor& outputIds, ITensor const& inputIds, ITensor const& i
{ {
kernels::invokeCopyInputToOutput(outputIds, inputIds, inputLengths, padId, stream); kernels::invokeCopyInputToOutput(outputIds, inputIds, inputLengths, padId, stream);
} }
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
namespace namespace

View File

@ -51,6 +51,8 @@ void invokeTransposeWithInputOffset(
void invokeInclusiveSum(IBuffer& output, IBuffer const& input, BufferManager const& manager, CudaStream const& stream); void invokeInclusiveSum(IBuffer& output, IBuffer const& input, BufferManager const& manager, CudaStream const& stream);
void invokeInclusiveSum(IBuffer& output, IBuffer& tmpBuffer, IBuffer const& input, CudaStream const& stream);
void invokeBuildTokenMask( void invokeBuildTokenMask(
ITensor& tokenMask, ITensor const& inputLengths, SizeType maxInputLength, CudaStream const& stream); ITensor& tokenMask, ITensor const& inputLengths, SizeType maxInputLength, CudaStream const& stream);

View File

@ -16,13 +16,12 @@
#include "tensorrt_llm/runtime/statefulGptDecoder.h" #include "tensorrt_llm/runtime/statefulGptDecoder.h"
#include <algorithm>
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h" #include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/runtime/runtimeKernels.h" #include "tensorrt_llm/runtime/runtimeKernels.h"
#include <algorithm>
namespace tc = tensorrt_llm::common; namespace tc = tensorrt_llm::common;
namespace tk = tensorrt_llm::kernels; namespace tk = tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime; using namespace tensorrt_llm::runtime;
@ -35,7 +34,7 @@ StatefulGptDecoder::StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabS
, mStream{std::move(stream)} , mStream{std::move(stream)}
, mBufferManager{mStream} , mBufferManager{mStream}
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value; auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
auto constexpr nvSizeType = TRTDataType<SizeType>::value; auto constexpr nvSizeType = TRTDataType<SizeType>::value;
auto constexpr nvFloatType = TRTDataType<float>::value; auto constexpr nvFloatType = TRTDataType<float>::value;
@ -68,26 +67,26 @@ StatefulGptDecoder::StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabS
mFinishedSum = mBufferManager.pinned(ITensor::makeShape({1}), nvSizeType); mFinishedSum = mBufferManager.pinned(ITensor::makeShape({1}), nvSizeType);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void StatefulGptDecoder::setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth, void StatefulGptDecoder::setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth,
SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep, SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep,
bool fusedDecoder, nvinfer1::DataType dtype) bool fusedDecoder, nvinfer1::DataType dtype)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(maxTokensPerStep == 1); TLLM_CHECK(maxTokensPerStep == 1);
mDecoder = IGptDecoder::create( mDecoder = IGptDecoder::create(
mode, dtype, maxBatchSize, maxBeamWidth, mVocabSize, mVocabSizePadded, maxSequenceLength, mStream); mode, dtype, maxBatchSize, maxBeamWidth, mVocabSize, mVocabSizePadded, maxSequenceLength, mStream);
reshapeBuffers(maxBatchSize, maxBeamWidth, maxAttentionWindow, sinkTokenLength, maxSequenceLength); reshapeBuffers(maxBatchSize, maxBeamWidth, maxAttentionWindow, sinkTokenLength, maxSequenceLength);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void StatefulGptDecoder::reshapeBuffers(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow, void StatefulGptDecoder::reshapeBuffers(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow,
SizeType sinkTokenLength, SizeType maxSequenceLength) SizeType sinkTokenLength, SizeType maxSequenceLength)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(batchSize > 0); TLLM_CHECK(batchSize > 0);
TLLM_CHECK(beamWidth > 0); TLLM_CHECK(beamWidth > 0);
TLLM_CHECK(maxSequenceLength > 0); TLLM_CHECK(maxSequenceLength > 0);
@ -139,13 +138,13 @@ void StatefulGptDecoder::reshapeBuffers(SizeType batchSize, SizeType beamWidth,
} }
mNbSteps = 0; mNbSteps = 0;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void StatefulGptDecoder::newBatch( void StatefulGptDecoder::newBatch(
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig) GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& manager = mBufferManager; auto& manager = mBufferManager;
auto& stream = mStream; auto& stream = mStream;
@ -296,12 +295,12 @@ void StatefulGptDecoder::newBatch(
// remaining // remaining
mNbSteps = 0; mNbSteps = 0;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void StatefulGptDecoder::forwardAsync(decoder::Output& output, decoder::Input const& input) void StatefulGptDecoder::forwardAsync(decoder::Output& output, decoder::Input const& input)
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& logits = input.logits; auto& logits = input.logits;
auto const& logitsShape = logits->getShape(); auto const& logitsShape = logits->getShape();
@ -335,24 +334,24 @@ void StatefulGptDecoder::forwardAsync(decoder::Output& output, decoder::Input co
dInput.step += 1; dInput.step += 1;
mNbSteps += 1; mNbSteps += 1;
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void StatefulGptDecoder::forwardSync() void StatefulGptDecoder::forwardSync()
{ {
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
mDecodedEvent.synchronize(); mDecodedEvent.synchronize();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
} }
void StatefulGptDecoder::finalize() const void StatefulGptDecoder::finalize() const
{ {
// TODO (rkobus) can we do this inplace? // TODO (rkobus) can we do this inplace?
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto& outputIds = mDecodingOutput->ids; auto& outputIds = mDecodingOutput->ids;
auto finalOutputIds = mBufferManager.gpu(outputIds->getShape(), outputIds->getDataType()); auto finalOutputIds = mBufferManager.gpu(outputIds->getShape(), outputIds->getDataType());
mDecoder->gatherTree(*finalOutputIds, *mDecodingOutput, *mDecodingInput, mBufferManager); mDecoder->gatherTree(*finalOutputIds, *mDecodingOutput, *mDecodingInput, mBufferManager);
mBufferManager.copy(*finalOutputIds, *outputIds); mBufferManager.copy(*finalOutputIds, *outputIds);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return; return;
} }

View File

@ -18,17 +18,12 @@
#include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaEvent.h" #include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/decodingMode.h" #include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/gptDecoder.h" #include "tensorrt_llm/runtime/gptDecoder.h"
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h" #include "tensorrt_llm/runtime/iStatefulGptDecoder.h"
#include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/iTensor.h"
#include <cstdint>
#include <memory> #include <memory>
#include <optional>
#include <utility>
#include <vector>
namespace tensorrt_llm::runtime namespace tensorrt_llm::runtime
{ {
@ -88,7 +83,7 @@ public:
//! @returns [batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu //! @returns [batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
[[nodiscard]] TensorPtr getAllNewTokens() const override [[nodiscard]] TensorPtr getAllNewTokens() const override
{ {
TensorPtr newTokens = std::move(ITensor::view(mDecodingOutput->newTokensSteps)); TensorPtr newTokens = ITensor::view(mDecodingOutput->newTokensSteps);
newTokens->unsqueeze(0); newTokens->unsqueeze(0);
return newTokens; return newTokens;
} }

View File

@ -19,7 +19,6 @@
#include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaStream.h" #include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/iBuffer.h" #include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/iTensor.h"
@ -30,7 +29,6 @@
#include <algorithm> #include <algorithm>
#include <cstdlib> #include <cstdlib>
#include <limits>
#include <list> #include <list>
#include <memory> #include <memory>
#include <mutex> #include <mutex>

View File

@ -14,9 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "tllmRuntime.h" #include "tllmRuntime.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/nvtxUtils.h" #include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/common/stringUtils.h"
#include "tllmLogger.h" #include "tllmLogger.h"
#include <limits> #include <limits>
@ -24,8 +22,6 @@
using namespace tensorrt_llm::runtime; using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
namespace namespace
{ {
using DimType = std::remove_reference_t<decltype(std::declval<nvinfer1::Dims>().d[0])>; using DimType = std::remove_reference_t<decltype(std::declval<nvinfer1::Dims>().d[0])>;
@ -108,14 +104,12 @@ bool TllmRuntime::executeContext(SizeType contextIndex) const
void TllmRuntime::setInputTensors(SizeType contextIndex, TensorMap const& tensorMap) void TllmRuntime::setInputTensors(SizeType contextIndex, TensorMap const& tensorMap)
{ {
NVTX3_FUNC_RANGE(); NVTX3_FUNC_RANGE();
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto& context = getContext(contextIndex); auto& context = getContext(contextIndex);
for (std::int32_t i = 0; i < mEngine->getNbIOTensors(); ++i) for (std::int32_t i = 0; i < mEngine->getNbIOTensors(); ++i)
{ {
auto const name = mEngine->getIOTensorName(i); auto const name = mEngine->getIOTensorName(i);
if (mEngine->getTensorIOMode(name) == nvinfer1::TensorIOMode::kINPUT) if (mEngine->getTensorIOMode(name) == nvinfer1::TensorIOMode::kINPUT)
{ {
NVTX3_SCOPED_RANGE(input_tensor);
auto pos = tensorMap.find(name); auto pos = tensorMap.find(name);
if (pos == tensorMap.end()) if (pos == tensorMap.end())
{ {
@ -197,7 +191,6 @@ void TllmRuntime::setOutputTensors(SizeType contextIndex, TensorMap& tensorMap)
auto const name = mEngine->getIOTensorName(i); auto const name = mEngine->getIOTensorName(i);
if (mEngine->getTensorIOMode(name) == nvinfer1::TensorIOMode::kOUTPUT) if (mEngine->getTensorIOMode(name) == nvinfer1::TensorIOMode::kOUTPUT)
{ {
NVTX3_SCOPED_RANGE(output_tensor);
auto const dims = context.getTensorShape(name); auto const dims = context.getTensorShape(name);
auto const engineDtype = mEngine->getTensorDataType(name); auto const engineDtype = mEngine->getTensorDataType(name);
auto pos = tensorMap.find(name); auto pos = tensorMap.find(name);

View File

@ -23,8 +23,6 @@
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <torch/types.h> #include <torch/types.h>
#include <memory>
namespace tensorrt_llm::runtime namespace tensorrt_llm::runtime
{ {
class TorchView : virtual public ITensor class TorchView : virtual public ITensor

View File

@ -115,6 +115,16 @@ void insertTensorSlices(
} }
} }
void printTensorMap(std::ostream& stream, StringPtrMap<ITensor> const& map)
{
for (auto const& [name, tensor] : map)
{
stream << "Tensor name: " << name << '\n';
stream << "Shape" << tensor->getShape() << '\n';
stream << *tensor << '\n';
}
}
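printTensorMap is a debugging aid; the commented-out calls in RuntimeBuffers::getRuntimeBuffers above show the intended use. A hedged usage sketch (header path and utils namespace assumed from the surrounding diff):

// Hedged sketch; the include path and namespace are assumed from the diff context.
#include "tensorrt_llm/runtime/utils/sessionUtils.h"

#include <iostream>

namespace tr = tensorrt_llm::runtime;

void dumpInputs(tr::StringPtrMap<tr::ITensor> const& inputBuffers)
{
    // Prints each tensor's name, shape, and contents; useful right before binding engine inputs.
    tr::utils::printTensorMap(std::cerr, inputBuffers);
}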
void setRawPointers(ITensor& pointers, ITensor::SharedPtr const& input, int32_t pointersSlot, int32_t inputSlot) void setRawPointers(ITensor& pointers, ITensor::SharedPtr const& input, int32_t pointersSlot, int32_t inputSlot)
{ {
auto const pointersLength = static_cast<int32_t>(pointers.getSizeInBytes() / sizeof(void**)); auto const pointersLength = static_cast<int32_t>(pointers.getSizeInBytes() / sizeof(void**));

View File

@ -64,6 +64,8 @@ void insertTensorVector(StringPtrMap<ITensor>& map, std::string const& key, std:
void insertTensorSlices( void insertTensorSlices(
StringPtrMap<ITensor>& map, std::string const& key, ITensor::SharedPtr const& tensor, SizeType indexOffset); StringPtrMap<ITensor>& map, std::string const& key, ITensor::SharedPtr const& tensor, SizeType indexOffset);
void printTensorMap(std::ostream& stream, StringPtrMap<ITensor> const& map);
void setRawPointers(ITensor& pointers, ITensor::SharedPtr const& input, int32_t pointersSlot, int32_t inputSlot); void setRawPointers(ITensor& pointers, ITensor::SharedPtr const& input, int32_t pointersSlot, int32_t inputSlot);
void setRawPointers(ITensor& pointers, ITensor::SharedPtr const& input); void setRawPointers(ITensor& pointers, ITensor::SharedPtr const& input);

View File

@ -85,8 +85,9 @@ WorldConfig WorldConfig::mpi(SizeType gpusPerNode, std::optional<SizeType> tenso
auto const mpiSize = comm.getSize(); auto const mpiSize = comm.getSize();
auto const mpiRank = comm.getRank(); auto const mpiRank = comm.getRank();
TLLM_LOG_INFO("MPI size: %d, rank: %d", mpiSize, mpiRank); TLLM_LOG_INFO("MPI size: %d, rank: %d", mpiSize, mpiRank);
-   auto pp = pipelineParallelism.value_or(1);
-   auto tp = tensorParallelism.value_or(mpiSize / pp);
+   auto const pp = pipelineParallelism.value_or(1);
+   auto const tp = tensorParallelism.value_or(mpiSize / pp);
+   TLLM_LOG_DEBUG("TP: %d, PP: %d", tp, pp);
TLLM_CHECK(mpiSize == tp * pp); TLLM_CHECK(mpiSize == tp * pp);
return WorldConfig{tp, pp, mpiRank, gpusPerNode, deviceIds}; return WorldConfig{tp, pp, mpiRank, gpusPerNode, deviceIds};
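WorldConfig::mpi derives the layout from the launch: pipeline parallelism defaults to 1, tensor parallelism defaults to mpiSize / pp, and their product must equal the MPI world size. A small standalone check of that invariant with illustrative rank counts:

#include <cassert>
#include <optional>

// Illustrative stand-in for the WorldConfig::mpi defaulting logic.
void checkLayout(int mpiSize, std::optional<int> tensorParallelism, std::optional<int> pipelineParallelism)
{
    auto const pp = pipelineParallelism.value_or(1);
    auto const tp = tensorParallelism.value_or(mpiSize / pp);
    // Every rank must be accounted for by the TP x PP grid.
    assert(mpiSize == tp * pp);
}

int main()
{
    checkLayout(8, std::nullopt, std::nullopt); // tp = 8, pp = 1
    checkLayout(8, std::nullopt, 2);            // tp = 4, pp = 2
    return 0;
}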

View File

@ -103,11 +103,12 @@ void FtDynamicDecode<T>::setup(size_t batch_size, size_t beam_width, th::optiona
th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt, th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt, th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt, th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
-   th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> beam_search_diversity_rate_opt,
-   th::optional<th::Tensor> random_seed_opt, th::optional<th::Tensor> top_p_decay_opt,
-   th::optional<th::Tensor> top_p_min_opt, th::optional<th::Tensor> top_p_reset_ids_opt)
+   th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
+   th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
+   th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
+   th::optional<th::Tensor> top_p_reset_ids_opt)
{
-   // unused: length_penalty_opt, beam_search_diversity_rate_opt
+   // unused: length_penalty_opt, beam_search_diversity_rate_opt, early_stopping_opt
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
dynamic_decode_layer_->setStream(stream); dynamic_decode_layer_->setStream(stream);
@ -127,6 +128,7 @@ void FtDynamicDecode<T>::setup(size_t batch_size, size_t beam_width, th::optiona
safeInsert(top_p_reset_ids_opt, setupParams.top_p_reset_ids); safeInsert(top_p_reset_ids_opt, setupParams.top_p_reset_ids);
safeInsert(beam_search_diversity_rate_opt, setupParams.beam_search_diversity_rate); safeInsert(beam_search_diversity_rate_opt, setupParams.beam_search_diversity_rate);
safeInsert(length_penalty_opt, setupParams.length_penalty); safeInsert(length_penalty_opt, setupParams.length_penalty);
safeInsert(early_stopping_opt, setupParams.early_stopping);
dynamic_decode_layer_->setup(batch_size, beam_width, nullptr, setupParams); dynamic_decode_layer_->setup(batch_size, beam_width, nullptr, setupParams);
} }
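The early_stopping option is now plumbed from the Torch op into the decode layer's setup parameters. A hedged sketch of building that optional input on the caller side; the per-request int32 shape and the stop semantics are assumptions that mirror the early_stopping_opt check added in DynamicDecodeOp::setup below:

// Hedged sketch; shapes and semantics are assumed, not confirmed by this diff.
#include <torch/torch.h>

c10::optional<torch::Tensor> makeEarlyStopping(int64_t batchSize, bool enable)
{
    if (!enable)
    {
        return c10::nullopt; // omitted -> the decode layer keeps its default behaviour
    }
    // Nonzero is assumed to stop a beam-search request once enough hypotheses have finished.
    return torch::ones({batchSize}, torch::dtype(torch::kInt32).device(torch::kCPU));
}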
@ -211,13 +213,13 @@ void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width,
dynamic_decode_layer_->forward(outputParams, forwardParams); dynamic_decode_layer_->forward(outputParams, forwardParams);
    if (finished_sum_host)
    {
+       TLLM_CUDA_CHECK(::cudaStreamSynchronize(dynamic_decode_layer_->getStream()));
        int32_t finished_sum = 0;
        for (int32_t bi = 0; bi < local_batch_size; ++bi)
        {
            finished_sum += finished_sum_host[bi];
        }
        auto const numToFinish = outputParams.finished->size();
-       TLLM_CUDA_CHECK(::cudaStreamSynchronize(dynamic_decode_layer_->getStream()));
auto should_stop_accessor = should_stop.accessor<bool, 1>(); auto should_stop_accessor = should_stop.accessor<bool, 1>();
should_stop_accessor[0] = numToFinish == finished_sum; should_stop_accessor[0] = numToFinish == finished_sum;
} }
@ -259,9 +261,10 @@ void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional
th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt, th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt, th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt, th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
-   th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> beam_search_diversity_rate_opt,
-   th::optional<th::Tensor> random_seed_opt, th::optional<th::Tensor> top_p_decay_opt,
-   th::optional<th::Tensor> top_p_min_opt, th::optional<th::Tensor> top_p_reset_ids_opt)
+   th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
+   th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
+   th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
+   th::optional<th::Tensor> top_p_reset_ids_opt)
{ {
// TODO: Revise DynamicDecodeLayer and make the decode arguments consistent. // TODO: Revise DynamicDecodeLayer and make the decode arguments consistent.
CHECK_OPTIONAL_CPU_INPUT(runtime_top_k_opt, torch::kInt32); CHECK_OPTIONAL_CPU_INPUT(runtime_top_k_opt, torch::kInt32);
@ -273,6 +276,7 @@ void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional
CHECK_OPTIONAL_CPU_INPUT(frequency_penalty_opt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(frequency_penalty_opt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(min_length_opt, torch::kInt32); CHECK_OPTIONAL_CPU_INPUT(min_length_opt, torch::kInt32);
CHECK_OPTIONAL_CPU_INPUT(length_penalty_opt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(length_penalty_opt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(early_stopping_opt, torch::kInt32);
CHECK_OPTIONAL_CPU_INPUT(beam_search_diversity_rate_opt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(beam_search_diversity_rate_opt, torch::kFloat);
CHECK_OPTIONAL_CPU_INPUT(random_seed_opt, torch::kInt64); CHECK_OPTIONAL_CPU_INPUT(random_seed_opt, torch::kInt64);
CHECK_OPTIONAL_INPUT(top_p_decay_opt, torch::kFloat); CHECK_OPTIONAL_INPUT(top_p_decay_opt, torch::kFloat);
@ -281,8 +285,8 @@ void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional
dynamic_decode_->setup(static_cast<size_t>(batch_size), static_cast<size_t>(beam_width), runtime_top_k_opt, dynamic_decode_->setup(static_cast<size_t>(batch_size), static_cast<size_t>(beam_width), runtime_top_k_opt,
runtime_top_p_opt, temperature_opt, repetition_penalty_opt, presence_penalty_opt, frequency_penalty_opt, runtime_top_p_opt, temperature_opt, repetition_penalty_opt, presence_penalty_opt, frequency_penalty_opt,
-       min_length_opt, length_penalty_opt, beam_search_diversity_rate_opt, random_seed_opt, top_p_decay_opt,
-       top_p_min_opt, top_p_reset_ids_opt);
+       min_length_opt, length_penalty_opt, early_stopping_opt, beam_search_diversity_rate_opt, random_seed_opt,
+       top_p_decay_opt, top_p_min_opt, top_p_reset_ids_opt);
} }
th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max_input_length, th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max_input_length,

Some files were not shown because too many files have changed in this diff.