mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Update TensorRT-LLM (#1168)
* Update TensorRT-LLM --------- Co-authored-by: Bhuvanesh Sridharan <bhuvan.sridharan@gmail.com> Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
This commit is contained in:
parent
e4e09dafea
commit
655524dd82
@ -428,7 +428,7 @@ public:
|
||||
, mStaticEmulatedTimeoutMs(staticEmulatedTimeoutMs)
|
||||
, mActiveCount(0)
|
||||
{
|
||||
ReturnBatchManagerStatsCallback iterationDataCallback = [this, &logIterationData](std::string const& log)
|
||||
ReturnBatchManagerStatsCallback iterationDataCallback = [this, logIterationData](std::string const& log)
|
||||
{
|
||||
if (logIterationData)
|
||||
{
|
||||
@ -563,16 +563,18 @@ public:
|
||||
{
|
||||
auto numNewWorkItems = static_cast<int64_t>(rval.size());
|
||||
comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
|
||||
|
||||
std::vector<int64_t> packed;
|
||||
for (auto const& ir : rval)
|
||||
if (numNewWorkItems > 0)
|
||||
{
|
||||
auto vpacked = ir->serialize();
|
||||
packed.push_back(static_cast<int64_t>(vpacked.size()));
|
||||
packed.insert(
|
||||
packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
|
||||
std::vector<int64_t> packed;
|
||||
for (auto const& ir : rval)
|
||||
{
|
||||
auto vpacked = ir->serialize();
|
||||
packed.push_back(static_cast<int64_t>(vpacked.size()));
|
||||
packed.insert(
|
||||
packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
|
||||
}
|
||||
comm.bcast(packed, 0);
|
||||
}
|
||||
comm.bcast(packed, 0);
|
||||
}
|
||||
}
|
||||
else
|
||||
@ -791,7 +793,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
recorder->report();
|
||||
recorder->writeOpMetricsToCsv();
|
||||
// Send terminateReqId to terminate servers on all ranks
|
||||
// Sever on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
|
||||
// Server on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
|
||||
gptServer->enqueue(std::make_shared<InferenceRequest>(terminateReqId));
|
||||
}
|
||||
// Wait until benchmarking is done and batch manager is terminated
|
||||
|
||||
@ -114,9 +114,9 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
"benchmark on %d tokens",
|
||||
maxNumTokens.value(), maxBatchSize * maxInputLength);
|
||||
}
|
||||
std::atomic_bool done = false;
|
||||
try
|
||||
{
|
||||
std::atomic_bool done = false;
|
||||
auto peakMemFuture = std::async(&monitorMemory, std::ref(done));
|
||||
TLLM_LOG_INFO(memoryCounter.toString());
|
||||
|
||||
@ -266,11 +266,14 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
catch (std::runtime_error& e)
|
||||
{
|
||||
std::size_t found = std::string(e.what()).find("out of memory");
|
||||
// We need to kill the memory monitor when OOM.
|
||||
done = true;
|
||||
|
||||
// Unexpected error; rethrow
|
||||
if (found == std::string::npos)
|
||||
{
|
||||
throw;
|
||||
TLLM_LOG_ERROR(e.what());
|
||||
throw e;
|
||||
}
|
||||
|
||||
// We can ignore the OOM exception and continue the rest of the benchmark
|
||||
@ -283,6 +286,12 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
}
|
||||
continue;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
// We need to kill memory monitor when any other issue occurs
|
||||
done = true;
|
||||
throw;
|
||||
}
|
||||
}
|
||||
TLLM_LOG_INFO(memoryCounter.toString());
|
||||
}
|
||||
|
||||
@ -85,6 +85,7 @@ class EncDecBuildConfig:
|
||||
max_decoder_input_len: Optional[int] = None
|
||||
max_output_len: Optional[int] = None
|
||||
builder_opt: Optional[int] = None
|
||||
n_mels: Optional[int] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self.head_size is not None
|
||||
@ -1179,6 +1180,27 @@ _allowed_configs = {
|
||||
mamba_d_conv=4,
|
||||
mamba_expand=2,
|
||||
)),
|
||||
"whisper_large_v3":
|
||||
ModelConfig(name="whisper_large_v3",
|
||||
family="whisper",
|
||||
benchmark_type="enc_dec",
|
||||
build_config=EncDecBuildConfig(
|
||||
num_layers=32,
|
||||
num_decoder_layers=32,
|
||||
num_heads=20,
|
||||
head_size=64,
|
||||
ffn_hidden_size=5120,
|
||||
hidden_size=1280,
|
||||
vocab_size=51866,
|
||||
hidden_act="gelu",
|
||||
n_positions=448,
|
||||
n_mels=128,
|
||||
max_batch_size=8,
|
||||
max_encoder_input_len=1500,
|
||||
max_decoder_input_len=1,
|
||||
max_output_len=200,
|
||||
builder_opt=None,
|
||||
)),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -795,7 +795,7 @@ def build_gpt(args):
|
||||
max_beam_width=max_beam_width)
|
||||
if family in [
|
||||
'opt', 'bloom', 'falcon', 'llama', 'internlm', 'gptneox',
|
||||
'gptj', 'mamba', 'baichuan'
|
||||
'gptj', 'mamba', 'baichuan', 'chatglm', 'chatglm2', 'chatglm3'
|
||||
]:
|
||||
tensorrt_llm_model(**inputs)
|
||||
else:
|
||||
@ -957,6 +957,8 @@ def enc_dec_build_helper(component, config, args):
|
||||
torch.cuda.set_device(runtime_rank)
|
||||
|
||||
family = get_model_family(args.model)
|
||||
logits_dtype = 'float32'
|
||||
n_mels = 0
|
||||
if family == 'bart':
|
||||
q_scaling = 1.0
|
||||
has_attention_qkvo_bias = True
|
||||
@ -969,6 +971,19 @@ def enc_dec_build_helper(component, config, args):
|
||||
layernorm_position = LayerNormPositionType.pre_layernorm if config.get(
|
||||
'normalize_before', True) else LayerNormPositionType.post_layernorm
|
||||
rescale_before_lm_head = False
|
||||
elif family == 'whisper':
|
||||
q_scaling = 1.0
|
||||
has_position_embedding = True
|
||||
relative_attention = False
|
||||
has_embedding_layernorm = False
|
||||
has_attention_qkvo_bias = True
|
||||
has_mlp_bias = True
|
||||
has_model_final_layernorm = True
|
||||
layernorm_position = LayerNormPositionType.pre_layernorm
|
||||
layernorm_type = LayerNormType.LayerNorm
|
||||
rescale_before_lm_head = False
|
||||
logits_dtype = str_dtype_to_trt(args.dtype)
|
||||
n_mels = config['n_mels']
|
||||
else:
|
||||
q_scaling = 1 / config['head_size']**.5
|
||||
has_attention_qkvo_bias = False
|
||||
@ -984,6 +999,9 @@ def enc_dec_build_helper(component, config, args):
|
||||
else:
|
||||
rescale_before_lm_head = False
|
||||
|
||||
quant_mode, _, _ = get_quant_mode(args.quantization)
|
||||
use_weight_only = quant_mode.is_weight_only()
|
||||
|
||||
builder = Builder()
|
||||
builder_config = builder.create_builder_config(
|
||||
name=args.model,
|
||||
@ -1011,6 +1029,10 @@ def enc_dec_build_helper(component, config, args):
|
||||
has_token_type_embedding=False, # by default
|
||||
strongly_typed=False, # by default
|
||||
gather_all_token_logits=False, # by default
|
||||
int8=(quant_mode.has_act_and_weight_quant()
|
||||
or quant_mode.is_int8_weight_only()),
|
||||
quant_mode=quant_mode,
|
||||
n_mels=n_mels,
|
||||
)
|
||||
|
||||
# build engine
|
||||
@ -1024,34 +1046,45 @@ def enc_dec_build_helper(component, config, args):
|
||||
fp16_clamping = (args.dtype == 'float16') and ('t5' in family)
|
||||
|
||||
if component == 'encoder':
|
||||
tllm_model = tensorrt_llm.models.EncoderModel(
|
||||
num_layers=config['num_layers'],
|
||||
num_heads=config['num_heads'],
|
||||
num_kv_heads=config['num_heads'],
|
||||
head_size=config['head_size'],
|
||||
hidden_size=config['hidden_size'],
|
||||
ffn_hidden_size=config['ffn_hidden_size'],
|
||||
vocab_size=config['vocab_size'],
|
||||
max_position_embeddings=config.get('n_positions', 0),
|
||||
has_position_embedding=has_position_embedding,
|
||||
relative_attention=relative_attention,
|
||||
max_distance=config.get('max_distance', 0),
|
||||
num_buckets=config.get('num_buckets', 0),
|
||||
has_embedding_layernorm=has_embedding_layernorm,
|
||||
has_embedding_scale=config.get('has_embedding_scale', False),
|
||||
q_scaling=q_scaling,
|
||||
has_attention_qkvo_bias=has_attention_qkvo_bias,
|
||||
has_mlp_bias=has_mlp_bias,
|
||||
has_model_final_layernorm=has_model_final_layernorm,
|
||||
layernorm_eps=1e-6,
|
||||
layernorm_position=layernorm_position,
|
||||
layernorm_type=layernorm_type,
|
||||
hidden_act=config['hidden_act'],
|
||||
dtype=dtype,
|
||||
use_parallel_embedding=False, # by default
|
||||
embedding_sharding_dim=0, # by default
|
||||
mapping=mapping,
|
||||
fp16_clamping=fp16_clamping)
|
||||
if family == 'whisper':
|
||||
tllm_model = tensorrt_llm.models.WhisperEncoder(
|
||||
n_mels=config['n_mels'],
|
||||
n_ctx=1500, # n_audio_ctx
|
||||
n_state=config['hidden_size'],
|
||||
n_head=config['num_heads'],
|
||||
n_layer=config['num_layers'],
|
||||
dtype=dtype)
|
||||
if use_weight_only:
|
||||
tllm_model = quantize_model(tllm_model, quant_mode)
|
||||
else:
|
||||
tllm_model = tensorrt_llm.models.EncoderModel(
|
||||
num_layers=config['num_layers'],
|
||||
num_heads=config['num_heads'],
|
||||
num_kv_heads=config['num_heads'],
|
||||
head_size=config['head_size'],
|
||||
hidden_size=config['hidden_size'],
|
||||
ffn_hidden_size=config['ffn_hidden_size'],
|
||||
vocab_size=config['vocab_size'],
|
||||
max_position_embeddings=config.get('n_positions', 0),
|
||||
has_position_embedding=has_position_embedding,
|
||||
relative_attention=relative_attention,
|
||||
max_distance=config.get('max_distance', 0),
|
||||
num_buckets=config.get('num_buckets', 0),
|
||||
has_embedding_layernorm=has_embedding_layernorm,
|
||||
has_embedding_scale=config.get('has_embedding_scale', False),
|
||||
q_scaling=q_scaling,
|
||||
has_attention_qkvo_bias=has_attention_qkvo_bias,
|
||||
has_mlp_bias=has_mlp_bias,
|
||||
has_model_final_layernorm=has_model_final_layernorm,
|
||||
layernorm_eps=1e-6,
|
||||
layernorm_position=layernorm_position,
|
||||
layernorm_type=layernorm_type,
|
||||
hidden_act=config['hidden_act'],
|
||||
dtype=dtype,
|
||||
use_parallel_embedding=False, # by default
|
||||
embedding_sharding_dim=0, # by default
|
||||
mapping=mapping,
|
||||
fp16_clamping=fp16_clamping)
|
||||
elif component == 'decoder':
|
||||
tllm_model = tensorrt_llm.models.DecoderModel(
|
||||
num_layers=config['num_layers'],
|
||||
@ -1084,8 +1117,10 @@ def enc_dec_build_helper(component, config, args):
|
||||
embedding_sharding_dim=0, # by default
|
||||
mapping=mapping,
|
||||
rescale_before_lm_head=rescale_before_lm_head,
|
||||
logits_dtype='float32', # by default
|
||||
logits_dtype=logits_dtype, # by default
|
||||
fp16_clamping=fp16_clamping)
|
||||
if use_weight_only and family == 'whisper':
|
||||
tllm_model = quantize_model(tllm_model, quant_mode)
|
||||
|
||||
# Module -> Network
|
||||
engine_name = get_engine_name(args.model, args.dtype, world_size,
|
||||
@ -1099,6 +1134,12 @@ def enc_dec_build_helper(component, config, args):
|
||||
network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_gemm_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
|
||||
if use_weight_only:
|
||||
network.plugin_config.set_weight_only_quant_matmul_plugin(
|
||||
dtype=args.dtype)
|
||||
elif args.mode == 'ootb-except-mha':
|
||||
network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
|
||||
|
||||
if world_size > 1:
|
||||
network.plugin_config.set_nccl_plugin(
|
||||
@ -1110,18 +1151,31 @@ def enc_dec_build_helper(component, config, args):
|
||||
|
||||
# Forward
|
||||
if component == 'encoder':
|
||||
inputs = tllm_model.prepare_inputs(
|
||||
max_batch_size=config['max_batch_size'],
|
||||
max_input_len=config['max_encoder_input_len'],
|
||||
)
|
||||
if family == 'whisper':
|
||||
inputs = tllm_model.prepare_inputs(
|
||||
max_batch_size=config['max_batch_size'], )
|
||||
else:
|
||||
inputs = tllm_model.prepare_inputs(
|
||||
max_batch_size=config['max_batch_size'],
|
||||
max_input_len=config['max_encoder_input_len'],
|
||||
)
|
||||
elif component == 'decoder':
|
||||
inputs = tllm_model.prepare_inputs(
|
||||
max_batch_size=config['max_batch_size'],
|
||||
max_beam_width=config['max_beam_width'],
|
||||
max_decoder_input_len=config['max_decoder_input_len'],
|
||||
max_new_tokens=config['max_output_len'],
|
||||
max_encoder_input_len=config['max_encoder_input_len'],
|
||||
)
|
||||
if family == 'whisper':
|
||||
inputs = tllm_model.prepare_inputs(
|
||||
max_batch_size=config['max_batch_size'],
|
||||
max_beam_width=config['max_beam_width'],
|
||||
max_decoder_input_len=config['max_decoder_input_len'],
|
||||
max_new_tokens=config['max_output_len'],
|
||||
max_encoder_input_len=1500, # n_audio_ctx
|
||||
)
|
||||
else:
|
||||
inputs = tllm_model.prepare_inputs(
|
||||
max_batch_size=config['max_batch_size'],
|
||||
max_beam_width=config['max_beam_width'],
|
||||
max_decoder_input_len=config['max_decoder_input_len'],
|
||||
max_new_tokens=config['max_output_len'],
|
||||
max_encoder_input_len=config['max_encoder_input_len'],
|
||||
)
|
||||
|
||||
tllm_model(*inputs)
|
||||
|
||||
|
||||
@ -23,8 +23,9 @@ from base_benchmark import BaseBenchmark, get_engine_name
|
||||
from build import build_enc_dec
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._utils import trt_dtype_to_torch
|
||||
from tensorrt_llm._utils import (trt_dtype_to_torch, str_dtype_to_trt)
|
||||
from tensorrt_llm.quantization import QuantMode
|
||||
from tensorrt_llm.runtime.session import TensorInfo
|
||||
|
||||
|
||||
class EncDecBenchmark(BaseBenchmark):
|
||||
@ -49,6 +50,8 @@ class EncDecBenchmark(BaseBenchmark):
|
||||
# So we use separate variables for encoder and decoder here.
|
||||
self.encoder_engine_model_name = args.model
|
||||
self.decoder_engine_model_name = args.model
|
||||
# only for whisper parameter
|
||||
self.n_mels = 0
|
||||
|
||||
if self.engine_dir is not None:
|
||||
|
||||
@ -109,6 +112,8 @@ class EncDecBenchmark(BaseBenchmark):
|
||||
self.max_input_len = config["builder_config"][
|
||||
"max_encoder_input_len"]
|
||||
self.max_output_len = config["builder_config"]["max_output_len"]
|
||||
self.n_mels = config["builder_config"][
|
||||
'n_mels'] if 'whisper' in self.model_name else 0
|
||||
|
||||
for key, value in config["builder_config"].items():
|
||||
if key == "name":
|
||||
@ -173,6 +178,8 @@ class EncDecBenchmark(BaseBenchmark):
|
||||
if args.max_input_len is None else args.max_input_len
|
||||
self.max_output_len = build_config['max_output_len'] \
|
||||
if args.max_output_len is None else args.max_output_len
|
||||
self.n_mels = build_config[
|
||||
'n_mels'] if 'whisper' in self.model_name else 0
|
||||
# Build engine
|
||||
(
|
||||
encoder_engine_buffer,
|
||||
@ -198,6 +205,10 @@ class EncDecBenchmark(BaseBenchmark):
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
if 'whisper' in self.model_name:
|
||||
print(
|
||||
f"[WARNING] whisper benchmark is input_len=1500, no text prompt, output_len=arbitrary"
|
||||
)
|
||||
for inlen, outlen in self.in_out_lens:
|
||||
if (inlen > self.max_input_len or outlen > self.max_output_len):
|
||||
print(
|
||||
@ -216,29 +227,95 @@ class EncDecBenchmark(BaseBenchmark):
|
||||
|
||||
def prepare_inputs(self, config):
|
||||
batch_size, encoder_input_len = config[0], config[1]
|
||||
encoder_input_ids = (torch.randint(
|
||||
100, (batch_size, encoder_input_len)).int().cuda())
|
||||
# For now, just hardcode the decoder_start_token_id to 0 for t5 models.
|
||||
decoder_start_token_id = 0
|
||||
decoder_input_ids = torch.IntTensor([[decoder_start_token_id]
|
||||
]).to(self.device)
|
||||
decoder_input_ids = decoder_input_ids.repeat(
|
||||
(encoder_input_ids.shape[0], 1))
|
||||
# in padding mode --> keep input, just calculate actual length and max length
|
||||
# Note: 1st token should always count, even if it is pad_token_id (0). e.g., decoder start id in enc-dec models could be a single pad_token_id, we should count
|
||||
encoder_input_lengths = ((1 + (encoder_input_ids[:, 1:] != 0).sum(
|
||||
dim=1).type(torch.IntTensor).to(self.device)).clone().detach().to(
|
||||
dtype=torch.int32, device=self.device))
|
||||
decoder_input_lengths = ((1 + (decoder_input_ids[:, 1:] != 0).sum(
|
||||
dim=1).type(torch.IntTensor).to(self.device)).clone().detach().to(
|
||||
dtype=torch.int32, device=self.device))
|
||||
# attention mask, always set 1 as if all are valid tokens
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, encoder_input_len)).int().cuda()
|
||||
# cross attention mask, always set 1 as if all are valid tokens
|
||||
# [batch_size, query_len, encoder_input_len] currently, use query_len=1
|
||||
cross_attention_mask = torch.ones(
|
||||
(batch_size, 1, encoder_input_len)).int().cuda()
|
||||
attention_mask = None
|
||||
whisper_decoder_encoder_input_lengths = None
|
||||
outputs = {}
|
||||
if 'whisper' in self.model_name:
|
||||
# feature_len always fixed 3000 now
|
||||
feature_len = 3000
|
||||
encoder_input_ids = (torch.randint(
|
||||
1, 100, (batch_size, self.n_mels, feature_len)).int().cuda())
|
||||
encoder_input_lengths = torch.tensor([
|
||||
encoder_input_ids.shape[2] // 2
|
||||
for _ in range(encoder_input_ids.shape[0])
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
decoder_input_ids = (torch.randint(1, 100, (1, )).int().cuda())
|
||||
decoder_input_ids = decoder_input_ids.repeat(
|
||||
(encoder_input_ids.shape[0], 1))
|
||||
output_list = [
|
||||
TensorInfo('x', str_dtype_to_trt(self.dtype),
|
||||
encoder_input_ids.shape),
|
||||
TensorInfo('input_lengths', str_dtype_to_trt('int32'),
|
||||
encoder_input_lengths.shape)
|
||||
]
|
||||
output_info = (self.encoder_session).infer_shapes(output_list)
|
||||
outputs = {
|
||||
t.name: torch.empty(tuple(t.shape),
|
||||
dtype=trt_dtype_to_torch(t.dtype),
|
||||
device='cuda')
|
||||
for t in output_info
|
||||
}
|
||||
whisper_decoder_encoder_input_lengths = torch.tensor(
|
||||
[
|
||||
outputs['output'].shape[1]
|
||||
for x in range(outputs['output'].shape[0])
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device='cuda')
|
||||
|
||||
decoder_input_lengths = torch.tensor([
|
||||
decoder_input_ids.shape[-1]
|
||||
for _ in range(decoder_input_ids.shape[0])
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device='cuda')
|
||||
cross_attention_mask = torch.ones(
|
||||
[outputs['output'].shape[0], 1,
|
||||
outputs['output'].shape[1]]).int().cuda()
|
||||
else:
|
||||
encoder_input_ids = (torch.randint(
|
||||
100, (batch_size, encoder_input_len)).int().cuda())
|
||||
# For now, just hardcode the decoder_start_token_id to 0 for t5 models.
|
||||
decoder_start_token_id = 0
|
||||
decoder_input_ids = torch.IntTensor([[decoder_start_token_id]
|
||||
]).to(self.device)
|
||||
decoder_input_ids = decoder_input_ids.repeat(
|
||||
(encoder_input_ids.shape[0], 1))
|
||||
# in padding mode --> keep input, just calculate actual length and max length
|
||||
# Note: 1st token should always count, even if it is pad_token_id (0). e.g., decoder start id in enc-dec models could be a single pad_token_id, we should count
|
||||
encoder_input_lengths = ((
|
||||
1 + (encoder_input_ids[:, 1:] != 0).sum(dim=1).type(
|
||||
torch.IntTensor).to(self.device)).clone().detach().to(
|
||||
dtype=torch.int32, device=self.device))
|
||||
decoder_input_lengths = ((
|
||||
1 + (decoder_input_ids[:, 1:] != 0).sum(dim=1).type(
|
||||
torch.IntTensor).to(self.device)).clone().detach().to(
|
||||
dtype=torch.int32, device=self.device))
|
||||
# attention mask, always set 1 as if all are valid tokens
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, encoder_input_len)).int().cuda()
|
||||
# cross attention mask, always set 1 as if all are valid tokens
|
||||
# [batch_size, query_len, encoder_input_len] currently, use query_len=1
|
||||
cross_attention_mask = torch.ones(
|
||||
(batch_size, 1, encoder_input_len)).int().cuda()
|
||||
|
||||
hidden_size = (self.encoder_model_config.hidden_size *
|
||||
self.world_size) # tp_size
|
||||
hidden_states_shape = (
|
||||
encoder_input_ids.shape[0],
|
||||
encoder_input_ids.shape[1],
|
||||
hidden_size,
|
||||
)
|
||||
hidden_states_dtype = lambda name: trt_dtype_to_torch(
|
||||
self.encoder_session.engine.get_tensor_dtype(name))
|
||||
|
||||
outputs["encoder_output"] = torch.empty(
|
||||
hidden_states_shape,
|
||||
dtype=hidden_states_dtype("encoder_output"),
|
||||
device=self.device,
|
||||
).contiguous()
|
||||
|
||||
stream = torch.cuda.current_stream().cuda_stream
|
||||
return (
|
||||
@ -248,6 +325,8 @@ class EncDecBenchmark(BaseBenchmark):
|
||||
decoder_input_ids,
|
||||
decoder_input_lengths,
|
||||
cross_attention_mask,
|
||||
whisper_decoder_encoder_input_lengths,
|
||||
outputs,
|
||||
stream,
|
||||
)
|
||||
|
||||
@ -260,47 +339,37 @@ class EncDecBenchmark(BaseBenchmark):
|
||||
decoder_input_ids,
|
||||
decoder_input_lengths,
|
||||
cross_attention_mask,
|
||||
whisper_decoder_encoder_input_lengths,
|
||||
outputs,
|
||||
stream,
|
||||
) = inputs
|
||||
|
||||
hidden_size = (self.encoder_model_config.hidden_size * self.world_size
|
||||
) # tp_size
|
||||
hidden_states_shape = (
|
||||
encoder_input_ids.shape[0],
|
||||
encoder_input_ids.shape[1],
|
||||
hidden_size,
|
||||
)
|
||||
hidden_states_dtype = lambda name: trt_dtype_to_torch(
|
||||
self.encoder_session.engine.get_tensor_dtype(name))
|
||||
|
||||
# input tensors
|
||||
inputs = {}
|
||||
inputs["input_ids"] = encoder_input_ids.contiguous()
|
||||
inputs["input_lengths"] = encoder_input_lengths
|
||||
inputs["max_input_length"] = torch.empty(
|
||||
(self.max_input_len, ),
|
||||
dtype=hidden_states_dtype("max_input_length"),
|
||||
device=self.device,
|
||||
).contiguous()
|
||||
if 'whisper' in self.model_name:
|
||||
inputs['x'] = encoder_input_ids.contiguous()
|
||||
inputs["input_lengths"] = encoder_input_lengths
|
||||
else:
|
||||
inputs["input_ids"] = encoder_input_ids.contiguous()
|
||||
inputs["input_lengths"] = encoder_input_lengths
|
||||
inputs["max_input_length"] = torch.empty(
|
||||
(self.max_input_len, ),
|
||||
dtype=hidden_states_dtype("max_input_length"),
|
||||
device=self.device,
|
||||
).contiguous()
|
||||
|
||||
if not self.encoder_model_config.gpt_attention_plugin:
|
||||
inputs["attention_mask"] = attention_mask.contiguous()
|
||||
if not self.encoder_model_config.gpt_attention_plugin:
|
||||
inputs["attention_mask"] = attention_mask.contiguous()
|
||||
|
||||
if self.encoder_model_config.has_position_embedding:
|
||||
bsz, seq_len = encoder_input_ids.shape[:2]
|
||||
position_ids = torch.arange(seq_len,
|
||||
dtype=torch.int32,
|
||||
device=encoder_input_ids.device).expand(
|
||||
bsz, -1)
|
||||
inputs['position_ids'] = position_ids.contiguous()
|
||||
|
||||
# output tensors
|
||||
outputs = {}
|
||||
outputs["encoder_output"] = torch.empty(
|
||||
hidden_states_shape,
|
||||
dtype=hidden_states_dtype("encoder_output"),
|
||||
device=self.device,
|
||||
).contiguous()
|
||||
if self.encoder_model_config.has_position_embedding:
|
||||
bsz, seq_len = encoder_input_ids.shape[:2]
|
||||
position_ids = torch.arange(
|
||||
seq_len, dtype=torch.int32,
|
||||
device=encoder_input_ids.device).expand(bsz, -1)
|
||||
inputs['position_ids'] = position_ids.contiguous()
|
||||
|
||||
# run encoder
|
||||
self.encoder_session.set_shapes(inputs)
|
||||
@ -311,6 +380,12 @@ class EncDecBenchmark(BaseBenchmark):
|
||||
# run decoder
|
||||
sampling_config = tensorrt_llm.runtime.SamplingConfig(
|
||||
end_id=1, pad_id=0, num_beams=self.num_beams, min_length=output_len)
|
||||
encoder_output = outputs[
|
||||
'output'] if 'whisper' in self.model_name else outputs[
|
||||
"encoder_output"]
|
||||
encoder_max_input_length = encoder_output.shape[
|
||||
1] if 'whisper' in self.model_name else torch.max(
|
||||
encoder_input_lengths).item()
|
||||
|
||||
self.decoder_session.setup(
|
||||
decoder_input_lengths.size(0),
|
||||
@ -318,9 +393,8 @@ class EncDecBenchmark(BaseBenchmark):
|
||||
output_len,
|
||||
beam_width=self.num_beams,
|
||||
max_attention_window_size=None,
|
||||
encoder_max_input_length=torch.max(encoder_input_lengths).item(),
|
||||
encoder_max_input_length=encoder_max_input_length,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
cross_attention_mask = None if self.decoder_model_config.gpt_attention_plugin else cross_attention_mask
|
||||
|
||||
@ -328,11 +402,11 @@ class EncDecBenchmark(BaseBenchmark):
|
||||
decoder_input_ids,
|
||||
decoder_input_lengths,
|
||||
sampling_config,
|
||||
encoder_output=outputs["encoder_output"],
|
||||
encoder_input_lengths=encoder_input_lengths,
|
||||
encoder_output=encoder_output,
|
||||
encoder_input_lengths=whisper_decoder_encoder_input_lengths
|
||||
if 'whisper' in self.model_name else encoder_input_lengths,
|
||||
cross_attention_mask=cross_attention_mask,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def report(self,
|
||||
config,
|
||||
|
||||
@ -182,7 +182,7 @@ if(ENABLE_MULTI_DEVICE EQUAL 1)
|
||||
find_library(NCCL_LIB nccl HINTS ${NCCL_LIB_DIR})
|
||||
endif()
|
||||
|
||||
get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_SOURCE_DIR} PATH)
|
||||
get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH)
|
||||
|
||||
set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
|
||||
include_directories(
|
||||
|
||||
@ -51,7 +51,7 @@ public:
|
||||
batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,
|
||||
SendResponseCallback sendResponseCb, PollStopSignalCallback pollStopSignalCb = nullptr,
|
||||
ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
|
||||
const TrtGptModelOptionalParams& optionalParams = TrtGptModelOptionalParams(),
|
||||
TrtGptModelOptionalParams const& optionalParams = TrtGptModelOptionalParams(),
|
||||
std::optional<uint64_t> terminateReqId = std::nullopt, std::optional<SizeType> maxDraftTokens = std::nullopt,
|
||||
bool excludeInputInOutput = false);
|
||||
|
||||
@ -82,9 +82,9 @@ protected:
|
||||
virtual BatchManagerErrorCode_t step(RequestList& activeRequests, std::set<uint64_t>& activeRequestsIds);
|
||||
|
||||
private:
|
||||
SizeType getMaxInputLen() const;
|
||||
SizeType getMaxSequenceLen() const;
|
||||
SizeType getMaxNumSequences() const;
|
||||
[[nodiscard]] SizeType getMaxInputLen() const;
|
||||
[[nodiscard]] SizeType getMaxSequenceLen() const;
|
||||
[[nodiscard]] SizeType getMaxNumSequences() const;
|
||||
|
||||
void validateLlmRequest(LlmRequest& newReq) const;
|
||||
static std::shared_ptr<LlmRequest> fillLlmRequest(std::shared_ptr<InferenceRequest> newReq);
|
||||
|
||||
@ -1,96 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/batch_manager/schedulerPolicy.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
|
||||
#include <list>
|
||||
#include <memory>
|
||||
|
||||
namespace tensorrt_llm::batch_manager::batch_scheduler
|
||||
{
|
||||
|
||||
/// @brief This scheduler takes into account the given request capacity and the KV cache capacity.
|
||||
/// Depending on the SchedulerPolicy it will schedule already started and new requests,
|
||||
/// or even terminate previously started requests.
|
||||
class CapacityScheduler
|
||||
{
|
||||
public:
|
||||
virtual ~CapacityScheduler() = default;
|
||||
|
||||
using RequestTable = std::map<uint64_t, std::shared_ptr<LlmRequest>>;
|
||||
using SizeType = tensorrt_llm::runtime::SizeType;
|
||||
using RequestList = std::list<std::shared_ptr<LlmRequest>>;
|
||||
|
||||
/// @brief Takes as input a sorted list of requests and outputs a sorted lists of requests
|
||||
/// to update for this current iteration, and a map of requests to terminate
|
||||
virtual std::tuple<RequestList, RequestTable> scheduleRequests(const RequestList& activeRequests) = 0;
|
||||
};
|
||||
|
||||
/// @brief Schedule up to maxNumRequests requests
|
||||
class MaxRequestsScheduler : public CapacityScheduler
|
||||
{
|
||||
public:
|
||||
explicit MaxRequestsScheduler(SizeType maxNumRequests);
|
||||
|
||||
std::tuple<RequestList, RequestTable> scheduleRequests(const RequestList& activeRequests) override;
|
||||
|
||||
private:
|
||||
SizeType mMaxNumRequests;
|
||||
};
|
||||
|
||||
/// @brief Schedule requests using the MAX_UTILIZATION policy
|
||||
/// @details Try reserving resources to advance requests by one step,
|
||||
/// may terminate previously started requests.
|
||||
class MaxUtilizationScheduler : public CapacityScheduler
|
||||
{
|
||||
public:
|
||||
MaxUtilizationScheduler(
|
||||
SizeType maxNumRequests, kv_cache_manager::KVCacheManager* kvCacheManager, bool manyMicroBatches);
|
||||
|
||||
std::tuple<RequestList, RequestTable> scheduleRequests(const RequestList& activeRequests) override;
|
||||
|
||||
private:
|
||||
bool trySchedulingRequestMaxUtilization(
|
||||
std::shared_ptr<LlmRequest> const& req, RequestList& scheduledRequests, SizeType& numScheduledBlocks);
|
||||
|
||||
SizeType mMaxNumRequests;
|
||||
kv_cache_manager::KVCacheManager* mKvCacheManager{nullptr};
|
||||
/// @brief Boolean that indicates if multiple micro batches might be in flight
|
||||
bool mManyMicroBatches;
|
||||
};
|
||||
|
||||
/// @brief Schedule requests using the GUARANTEED_NO_EVICT policy
|
||||
class GuaranteedNoEvictScheduler : public CapacityScheduler
|
||||
{
|
||||
public:
|
||||
GuaranteedNoEvictScheduler(SizeType maxNumRequests, kv_cache_manager::KVCacheManager* kvCacheManager);
|
||||
|
||||
std::tuple<RequestList, RequestTable> scheduleRequests(const RequestList& activeRequests) override;
|
||||
|
||||
private:
|
||||
SizeType mMaxNumRequests;
|
||||
kv_cache_manager::KVCacheManager* mKvCacheManager{nullptr};
|
||||
};
|
||||
|
||||
std::unique_ptr<CapacityScheduler> makeCapacityScheduler(tensorrt_llm::runtime::SizeType maxNumRequests,
|
||||
kv_cache_manager::KVCacheManager* kvCacheManager, SchedulerPolicy schedulerPolicy, bool manyMicroBatches = false);
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager::batch_scheduler
|
||||
@ -16,8 +16,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/batch_manager/namedTensor.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
|
||||
#include <algorithm>
|
||||
@ -48,6 +48,7 @@ auto constexpr kTemperatureTensorName = "temperature";
|
||||
auto constexpr kRuntimeTopKTensorName = "runtime_top_k";
|
||||
auto constexpr kRuntimeTopPTensorName = "runtime_top_p";
|
||||
auto constexpr kLengthPenaltyTensorName = "len_penalty";
|
||||
auto constexpr kEarlyStoppingTensorName = "early_stopping";
|
||||
auto constexpr kRepetitionPenaltyTensorName = "repetition_penalty";
|
||||
auto constexpr kMinLengthTensorName = "min_length";
|
||||
auto constexpr kPresencePenaltyTensorName = "presence_penalty";
|
||||
@ -74,6 +75,11 @@ auto constexpr kLoraWeights = "lora_weights";
|
||||
// "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
|
||||
// "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
|
||||
// "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
|
||||
// "cross_attn_qkv": 8 # for enc-dec adapter for cross attention in decoder
|
||||
// "cross_attn_q": 9 # for enc-dec adapter for cross attention in decoder
|
||||
// "cross_attn_k": 10 # for enc-dec adapter for cross attention in decoder
|
||||
// "cross_attn_v": 11 # for enc-dec adapter for cross attention in decoder
|
||||
// "cross_attn_dense": 12 # for enc-dec adapter for cross attention in decoder
|
||||
//
|
||||
// last dim holds [ module_id, layer_idx, adapter_size (D / R value) ]
|
||||
auto constexpr kLoraConfig = "lora_config"; // [num_lora_modules_layers, 3]
|
||||
@ -91,24 +97,29 @@ auto constexpr kGenerationLogitsName = "generation_logits";
|
||||
|
||||
} // namespace inference_request
|
||||
|
||||
template <typename TTensor, typename TNamedTensor>
|
||||
template <typename TTensor, typename TNamedTensor, typename TStream = runtime::BufferManager::CudaStreamPtr>
|
||||
class GenericInferenceRequest
|
||||
{
|
||||
public:
|
||||
using TensorPtr = TTensor;
|
||||
using NamedTensorType = TNamedTensor;
|
||||
using TensorMap = std::unordered_map<std::string, TTensor>;
|
||||
using LogitsPostProcessor = typename GenericLlmRequest<TensorPtr, TStream>::LogitsPostProcessor;
|
||||
|
||||
explicit GenericInferenceRequest(uint64_t requestId)
|
||||
explicit GenericInferenceRequest(
|
||||
uint64_t requestId, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
|
||||
: mRequestId{requestId}
|
||||
, mIsStreaming{false}
|
||||
, mlogitsPostProcessor(logitsPostProcessor)
|
||||
{
|
||||
}
|
||||
|
||||
GenericInferenceRequest(uint64_t requestId, TensorMap&& tensorMap)
|
||||
GenericInferenceRequest(uint64_t requestId, TensorMap&& tensorMap,
|
||||
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
|
||||
: mRequestId{requestId}
|
||||
, mIsStreaming{false}
|
||||
, mInputTensors{std::move(tensorMap)}
|
||||
, mlogitsPostProcessor(logitsPostProcessor)
|
||||
{
|
||||
for (auto const& [name, tensor] : mInputTensors)
|
||||
{
|
||||
@ -116,8 +127,9 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
GenericInferenceRequest(uint64_t requestId, TensorMap const& tensorMap)
|
||||
: GenericInferenceRequest(requestId, TensorMap{tensorMap})
|
||||
GenericInferenceRequest(uint64_t requestId, TensorMap const& tensorMap,
|
||||
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
|
||||
: GenericInferenceRequest(requestId, TensorMap{tensorMap}, logitsPostProcessor)
|
||||
{
|
||||
}
|
||||
|
||||
@ -141,6 +153,16 @@ public:
|
||||
return mInputTensors;
|
||||
}
|
||||
|
||||
void setLogitsPostProcessor(std::optional<LogitsPostProcessor> cb)
|
||||
{
|
||||
mlogitsPostProcessor = cb;
|
||||
}
|
||||
|
||||
std::optional<LogitsPostProcessor> getLogitsPostProcessor()
|
||||
{
|
||||
return mlogitsPostProcessor;
|
||||
}
|
||||
|
||||
static std::array constexpr kTensorNames = {
|
||||
inference_request::kInputIdsTensorName,
|
||||
inference_request::kDraftInputIdsTensorName,
|
||||
@ -156,6 +178,7 @@ public:
|
||||
inference_request::kRuntimeTopKTensorName,
|
||||
inference_request::kRuntimeTopPTensorName,
|
||||
inference_request::kLengthPenaltyTensorName,
|
||||
inference_request::kEarlyStoppingTensorName,
|
||||
inference_request::kRepetitionPenaltyTensorName,
|
||||
inference_request::kMinLengthTensorName,
|
||||
inference_request::kPresencePenaltyTensorName,
|
||||
@ -200,7 +223,10 @@ public:
|
||||
\
|
||||
void set##funcName(TensorPtr const& tensor) \
|
||||
{ \
|
||||
TLLM_CHECK_WITH_INFO(tensor, "Cannot set nullptr when calling %s", __FUNCTION__); \
|
||||
if constexpr (std::is_same_v<TensorPtr, tensorrt_llm::runtime::ITensor::SharedPtr>) \
|
||||
{ \
|
||||
TLLM_CHECK_WITH_INFO(tensor, "Cannot set nullptr when calling %s", __FUNCTION__); \
|
||||
} \
|
||||
mInputTensors[tensorName] = tensor; \
|
||||
}
|
||||
|
||||
@ -218,6 +244,7 @@ public:
|
||||
TENSOR_GETTER_SETTER(RuntimeTopK, inference_request::kRuntimeTopKTensorName)
|
||||
TENSOR_GETTER_SETTER(RuntimeTopP, inference_request::kRuntimeTopPTensorName)
|
||||
TENSOR_GETTER_SETTER(LengthPenalty, inference_request::kLengthPenaltyTensorName)
|
||||
TENSOR_GETTER_SETTER(EarlyStopping, inference_request::kEarlyStoppingTensorName)
|
||||
TENSOR_GETTER_SETTER(RepetitionPenalty, inference_request::kRepetitionPenaltyTensorName)
|
||||
TENSOR_GETTER_SETTER(MinLength, inference_request::kMinLengthTensorName)
|
||||
TENSOR_GETTER_SETTER(PresencePenalty, inference_request::kPresencePenaltyTensorName)
|
||||
@ -243,6 +270,7 @@ protected:
|
||||
uint64_t mRequestId;
|
||||
bool mIsStreaming;
|
||||
TensorMap mInputTensors;
|
||||
std::optional<LogitsPostProcessor> mlogitsPostProcessor;
|
||||
};
|
||||
|
||||
class InferenceRequest : public GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr, NamedTensor>
|
||||
|
||||
@ -326,7 +326,7 @@ private:
|
||||
//! \param seqSlotIdx Batch slot of sequence to which blocks are assigned.
|
||||
//! \return Number of matched tokens from loaded blocks.
|
||||
SizeType loadOrAllocateBlocks(
|
||||
std::list<VecTokens> blockedTokens, GenerationRequest& sequence, SizeType beamIdx, SizeType seqSlotIdx);
|
||||
std::list<VecTokens> const& blockedTokens, GenerationRequest& sequence, SizeType beamIdx, SizeType seqSlotIdx);
|
||||
|
||||
//! \brief Find block least likely to be reused, free it if necessary and return.
|
||||
[[nodiscard]] BlockPtr getFreeBlock();
|
||||
@ -362,9 +362,9 @@ public:
|
||||
using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
|
||||
|
||||
KVCacheManager(SizeType numLayers, SizeType numKvHeads, SizeType sizePerHead, SizeType tokensPerBlock,
|
||||
SizeType maxNumBlocks, SizeType maxNumSequences, SizeType maxBeamWidth, SizeType maxBlocksPerSeq,
|
||||
SizeType maxAttentionWindow, SizeType sinkTokenLength, bool useOneMoreBlock, nvinfer1::DataType dtype,
|
||||
CudaStreamPtr stream, bool enableBlockReuse = false, bool useUvm = false);
|
||||
SizeType maxNumBlocks, SizeType maxNumSequences, SizeType maxBeamWidth, SizeType maxAttentionWindow,
|
||||
SizeType sinkTokenLength, bool useOneMoreBlock, nvinfer1::DataType dtype, CudaStreamPtr stream,
|
||||
bool enableBlockReuse = false, bool useUvm = false);
|
||||
|
||||
void startScheduling();
|
||||
|
||||
@ -405,6 +405,11 @@ public:
|
||||
return mBlockSize;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType getMaxBlocksPerSeq() const
|
||||
{
|
||||
return mMaxBlocksPerSeq;
|
||||
}
|
||||
|
||||
[[nodiscard]] BlockManager const& getBlockManager() const
|
||||
{
|
||||
return mBlockManager;
|
||||
@ -414,13 +419,13 @@ public:
|
||||
/// iterations
|
||||
/// @param req The request for which we need to calculate the number of needed KV cache blocks
|
||||
/// @return The number of blocks
|
||||
SizeType getNeededBlocksOneStep(LlmRequest const& req, bool twoStepsLookAhead) const;
|
||||
[[nodiscard]] SizeType getNeededBlocksOneStep(LlmRequest const& req, bool twoStepsLookAhead) const;
|
||||
|
||||
/// @brief Function that computes the number of KV cache blocks needed to advance a request to completion (i.e. for
|
||||
/// maxNewTokens)
|
||||
/// @param req The request for which we need to calculate the number of needed KV cache blocks
|
||||
/// @return The number of blocks
|
||||
SizeType getNeededBlocksToCompletion(LlmRequest const& req) const;
|
||||
[[nodiscard]] SizeType getNeededBlocksToCompletion(LlmRequest const& req) const;
|
||||
|
||||
[[nodiscard]] std::vector<runtime::ITensor::SharedPtr> const& getMemoryPools() const
|
||||
{
|
||||
@ -458,7 +463,7 @@ public:
|
||||
* modelConfig.getSizePerHead();
|
||||
}
|
||||
|
||||
[[nodiscard]] static SizeType getMaxNumTokens(KvCacheConfig const& config, nvinfer1::DataType dtype,
|
||||
[[nodiscard]] static SizeType calculateMaxNumBlocks(KvCacheConfig const& config, nvinfer1::DataType dtype,
|
||||
tensorrt_llm::runtime::GptModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig,
|
||||
runtime::BufferManager const& bufferManager);
|
||||
|
||||
@ -491,10 +496,8 @@ private:
|
||||
// Maximum kv cache length per sequence
|
||||
// Enable cyclic kv cache when it exceeds
|
||||
SizeType mMaxAttentionWindow;
|
||||
// Sink token length in the kv cache per sequence
|
||||
SizeType mSinkTokenLength;
|
||||
// Bubble token length
|
||||
SizeType mBubbleLength;
|
||||
// Number of tokens to fill up the sink tokens to a full block size
|
||||
SizeType mSinkBubbleLength;
|
||||
// Maximum token length (including bubble)
|
||||
SizeType mMaxTokenNum;
|
||||
// Number of tokens in the sink blocks
|
||||
|
||||
@ -25,6 +25,7 @@
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
@ -38,7 +39,7 @@ enum LlmRequestState_t
|
||||
REQUEST_STATE_GENERATION_COMPLETE = 3
|
||||
};
|
||||
|
||||
template <typename TTensor>
|
||||
template <typename TTensor, typename TStream = runtime::BufferManager::CudaStreamPtr>
|
||||
class GenericLlmRequest
|
||||
{
|
||||
public:
|
||||
@ -49,9 +50,10 @@ public:
|
||||
using VecLogProbs = std::vector<float>;
|
||||
using BeamTokens = std::vector<VecTokens>;
|
||||
using TensorPtr = TTensor;
|
||||
using LogitsPostProcessor = std::function<TensorPtr(RequestIdType, TensorPtr&, BeamTokens const&, TStream)>;
|
||||
|
||||
GenericLlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
|
||||
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
|
||||
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
|
||||
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
|
||||
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
|
||||
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
|
||||
@ -59,7 +61,8 @@ public:
|
||||
std::optional<TensorPtr> loraConfig = std::nullopt, bool returnLogProbs = false,
|
||||
bool returnContextLogits = false, bool returnGenerationLogits = false,
|
||||
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
|
||||
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false)
|
||||
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false,
|
||||
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
|
||||
: mRequestId(requestId)
|
||||
, mPromptLen(inputTokens->size())
|
||||
, mMaxNewTokens(maxNewTokens)
|
||||
@ -69,15 +72,16 @@ public:
|
||||
, mEndId(endId)
|
||||
, mPadId(padId)
|
||||
, mSeqSlot(-1)
|
||||
, mLogitsPostProcessor(logitsPostProcessor)
|
||||
, mOrigPromptLen(mPromptLen)
|
||||
, mMaxSentTokenPos(mPromptLen - 1)
|
||||
, mEmbeddingBias(embeddingBias)
|
||||
, mBadWordsList(badWordsList)
|
||||
, mStopWordsList(stopWordsList)
|
||||
, mPromptEmbeddingTable(promptEmbeddingTable)
|
||||
, mEmbeddingBias(std::move(embeddingBias))
|
||||
, mBadWordsList(std::move(badWordsList))
|
||||
, mStopWordsList(std::move(stopWordsList))
|
||||
, mPromptEmbeddingTable(std::move(promptEmbeddingTable))
|
||||
, mPromptVocabSize(promptVocabSize)
|
||||
, mLoraWeights(loraWeights)
|
||||
, mLoraConfig(loraConfig)
|
||||
, mLoraWeights(std::move(loraWeights))
|
||||
, mLoraConfig(std::move(loraConfig))
|
||||
, mReturnLogProbs(returnLogProbs)
|
||||
, mContextChunkSize(std::nullopt)
|
||||
, mContextCurrentPosition(0)
|
||||
@ -198,14 +202,14 @@ public:
|
||||
/// @brief Get total number of tokens for this req (prompt + generated)
|
||||
/// @param beam The beam index
|
||||
/// @return The number of tokens
|
||||
SizeType getNumTokens(SizeType beam) const
|
||||
[[nodiscard]] SizeType getNumTokens(SizeType beam) const
|
||||
{
|
||||
return mTokens.at(beam).size();
|
||||
}
|
||||
|
||||
/// @brief Get max number of tokens across all beams
|
||||
/// @return The number of tokens
|
||||
SizeType getMaxBeamNumTokens() const
|
||||
[[nodiscard]] SizeType getMaxBeamNumTokens() const
|
||||
{
|
||||
SizeType maxTokens = 0;
|
||||
for (SizeType beam = 0; beam < mSamplingConfig.beamWidth; ++beam)
|
||||
@ -219,7 +223,7 @@ public:
|
||||
/// @param beam The beam index
|
||||
/// @param pos The position of the token relative to beginning of the prompt
|
||||
/// @return The token index
|
||||
TokenIdType getToken(SizeType beam, SizeType pos) const
|
||||
[[nodiscard]] TokenIdType getToken(SizeType beam, SizeType pos) const
|
||||
{
|
||||
return mTokens.at(beam).at(pos);
|
||||
}
|
||||
@ -227,42 +231,42 @@ public:
|
||||
/// @brief Get the tokens at a given beam index
|
||||
/// @param beam The beam index
|
||||
/// @return A vector of tokens for this beam index, includes the prompt
|
||||
VecTokens const& getTokens(SizeType beam) const
|
||||
[[nodiscard]] VecTokens const& getTokens(SizeType beam) const
|
||||
{
|
||||
return mTokens.at(beam);
|
||||
}
|
||||
|
||||
/// @brief Get all tokens (input+output) for all beams
|
||||
/// @return A vector of vector of tokens.
|
||||
BeamTokens const& getTokens() const
|
||||
[[nodiscard]] BeamTokens const& getTokens() const
|
||||
{
|
||||
return mTokens;
|
||||
}
|
||||
|
||||
/// @brief Get the draft tokens
|
||||
/// @return shared_ptr to vector of draft tokens
|
||||
std::shared_ptr<VecTokens> const& getDraftTokens() const
|
||||
[[nodiscard]] std::shared_ptr<VecTokens> const& getDraftTokens() const
|
||||
{
|
||||
return mDraftTokens;
|
||||
}
|
||||
|
||||
/// @brief Get the logits for the draft tokens
|
||||
/// @return Tensor of draft logits
|
||||
std::optional<TensorPtr> getDraftLogits() const
|
||||
[[nodiscard]] std::optional<TensorPtr> getDraftLogits() const
|
||||
{
|
||||
return mDraftLogits;
|
||||
}
|
||||
|
||||
/// @brief Returns true if request has draft tokens
|
||||
/// @return flag
|
||||
bool hasDraftTokens() const
|
||||
[[nodiscard]] bool hasDraftTokens() const
|
||||
{
|
||||
return mDraftTokens && mDraftTokens->size() > 0;
|
||||
return mDraftTokens && !mDraftTokens->empty();
|
||||
}
|
||||
|
||||
/// @brief Get the maximum number of generated tokens among all rays in beam
|
||||
/// @return The number of generated tokens (doesn't include the prompt tokens)
|
||||
SizeType getMaxNumGeneratedTokens() const
|
||||
[[nodiscard]] SizeType getMaxNumGeneratedTokens() const
|
||||
{
|
||||
return getMaxBeamNumTokens() - mPromptLen;
|
||||
}
|
||||
@ -283,14 +287,14 @@ public:
|
||||
assert(static_cast<size_t>(mSamplingConfig.beamWidth) == beamTokens.size());
|
||||
for (std::size_t beam = 0; beam < beamTokens.size(); ++beam)
|
||||
{
|
||||
const auto outputId = beamTokens[beam];
|
||||
auto const outputId = beamTokens[beam];
|
||||
mTokens.at(beam).push_back(outputId);
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Sets the generated tokens for all beams. Erases all previous generated tokens.
|
||||
/// @param generatedBeamTokens The generated tokens for all beams (vector of vector of tokens)
|
||||
void setGeneratedTokens(const BeamTokens& generatedBeamTokens)
|
||||
void setGeneratedTokens(BeamTokens const& generatedBeamTokens)
|
||||
{
|
||||
assert(generatedBeamTokens.size() == static_cast<size_t>(mSamplingConfig.beamWidth));
|
||||
for (std::size_t beam = 0; beam < generatedBeamTokens.size(); ++beam)
|
||||
@ -346,7 +350,7 @@ public:
|
||||
/// @brief Get the maximum position of the tokens returned to the client. Use to ensure we don't return to
|
||||
/// client duplicated token positions.
|
||||
/// @return The maximum position of the tokens sent to the client
|
||||
SizeType getMaxSentTokenPos() const
|
||||
[[nodiscard]] SizeType getMaxSentTokenPos() const
|
||||
{
|
||||
return mMaxSentTokenPos;
|
||||
}
|
||||
@ -359,17 +363,17 @@ public:
|
||||
mMaxSentTokenPos = pos;
|
||||
}
|
||||
|
||||
std::optional<TensorPtr> getPromptEmbeddingTable() const
|
||||
[[nodiscard]] std::optional<TensorPtr> getPromptEmbeddingTable() const
|
||||
{
|
||||
return mPromptEmbeddingTable;
|
||||
}
|
||||
|
||||
std::optional<SizeType> getPromptVocabSize() const
|
||||
[[nodiscard]] std::optional<SizeType> getPromptVocabSize() const
|
||||
{
|
||||
return mPromptVocabSize;
|
||||
}
|
||||
|
||||
std::optional<TensorPtr> getLoraWeights() const
|
||||
[[nodiscard]] std::optional<TensorPtr> getLoraWeights() const
|
||||
{
|
||||
return mLoraWeights;
|
||||
}
|
||||
@ -384,7 +388,7 @@ public:
|
||||
mLoraWeights = std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<TensorPtr> getLoraConfig() const
|
||||
[[nodiscard]] std::optional<TensorPtr> getLoraConfig() const
|
||||
{
|
||||
return mLoraConfig;
|
||||
}
|
||||
@ -399,32 +403,37 @@ public:
|
||||
mLoraConfig = std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<TensorPtr> getEmbeddingBias() const
|
||||
[[nodiscard]] std::optional<TensorPtr> getEmbeddingBias() const
|
||||
{
|
||||
return mEmbeddingBias;
|
||||
}
|
||||
|
||||
std::optional<TensorPtr> getBadWordsList() const
|
||||
[[nodiscard]] std::optional<TensorPtr> getBadWordsList() const
|
||||
{
|
||||
return mBadWordsList;
|
||||
}
|
||||
|
||||
std::optional<TensorPtr> getStopWordsList() const
|
||||
[[nodiscard]] std::optional<TensorPtr> getStopWordsList() const
|
||||
{
|
||||
return mStopWordsList;
|
||||
}
|
||||
|
||||
bool returnLogProbs() const
|
||||
[[nodiscard]] bool returnLogProbs() const
|
||||
{
|
||||
return mReturnLogProbs;
|
||||
}
|
||||
|
||||
std::vector<VecLogProbs> const& getLogProbs() const
|
||||
void setReturnLogProbs(bool returnLogProbs)
|
||||
{
|
||||
mReturnLogProbs = returnLogProbs;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<VecLogProbs> const& getLogProbs() const
|
||||
{
|
||||
return mLogProbs;
|
||||
}
|
||||
|
||||
VecLogProbs const& getLogProbs(SizeType beam) const
|
||||
[[nodiscard]] VecLogProbs const& getLogProbs(SizeType beam) const
|
||||
{
|
||||
return mLogProbs.at(beam);
|
||||
}
|
||||
@ -435,7 +444,7 @@ public:
|
||||
mLogProbs.at(beam).insert(mLogProbs.at(beam).end(), logProbs.begin(), logProbs.end());
|
||||
}
|
||||
|
||||
VecLogProbs const& getCumLogProbs() const
|
||||
[[nodiscard]] VecLogProbs const& getCumLogProbs() const
|
||||
{
|
||||
return mCumLogProbs;
|
||||
}
|
||||
@ -445,17 +454,17 @@ public:
|
||||
mCumLogProbs.at(beam) = cumLogProb;
|
||||
}
|
||||
|
||||
SizeType getOrigPromptLen() const
|
||||
[[nodiscard]] SizeType getOrigPromptLen() const
|
||||
{
|
||||
return mOrigPromptLen;
|
||||
}
|
||||
|
||||
void setDraftTokens(const std::shared_ptr<VecTokens>& draftTokens)
|
||||
void setDraftTokens(std::shared_ptr<VecTokens> const& draftTokens)
|
||||
{
|
||||
mDraftTokens = draftTokens;
|
||||
}
|
||||
|
||||
void setDraftLogits(const std::optional<TensorPtr>& draftLogits)
|
||||
void setDraftLogits(std::optional<TensorPtr> const& draftLogits)
|
||||
{
|
||||
mDraftLogits = draftLogits;
|
||||
}
|
||||
@ -470,7 +479,7 @@ public:
|
||||
mReturnContextLogits = returnContextLogits;
|
||||
}
|
||||
|
||||
bool getReturnContextLogits() const
|
||||
[[nodiscard]] bool getReturnContextLogits() const
|
||||
{
|
||||
return mReturnContextLogits;
|
||||
}
|
||||
@ -480,12 +489,12 @@ public:
|
||||
mReturnGenerationLogits = returnGenerationLogits;
|
||||
}
|
||||
|
||||
bool getReturnGenerationLogits() const
|
||||
[[nodiscard]] bool getReturnGenerationLogits() const
|
||||
{
|
||||
return mReturnGenerationLogits;
|
||||
}
|
||||
|
||||
TensorPtr const& getContextLogitsHost() const
|
||||
[[nodiscard]] TensorPtr const& getContextLogitsHost() const
|
||||
{
|
||||
return mContextLogitsHost;
|
||||
}
|
||||
@ -501,7 +510,7 @@ public:
|
||||
runtime::ITensor::makeShape({mPromptLen, vocabSizePadded}), logitsDataType);
|
||||
}
|
||||
|
||||
TensorPtr const& getGenerationLogitsHost() const
|
||||
[[nodiscard]] TensorPtr const& getGenerationLogitsHost() const
|
||||
{
|
||||
return mGenerationLogitsHost;
|
||||
}
|
||||
@ -517,7 +526,7 @@ public:
|
||||
runtime::ITensor::makeShape({mSamplingConfig.beamWidth, mMaxNewTokens, vocabSizePadded}), logitsDataType);
|
||||
}
|
||||
|
||||
std::vector<TensorPtr> const& getGenerationLogitsFragments() const
|
||||
[[nodiscard]] std::vector<TensorPtr> const& getGenerationLogitsFragments() const
|
||||
{
|
||||
return mGenerationLogitsFragments;
|
||||
}
|
||||
@ -537,38 +546,38 @@ public:
|
||||
mGenerationLogitsFragments.clear();
|
||||
}
|
||||
|
||||
bool isContextInitState() const noexcept
|
||||
[[nodiscard]] bool isContextInitState() const noexcept
|
||||
{
|
||||
return mState == REQUEST_STATE_CONTEXT_INIT;
|
||||
}
|
||||
|
||||
bool isGenerationInProgessState() const noexcept
|
||||
[[nodiscard]] bool isGenerationInProgessState() const noexcept
|
||||
{
|
||||
return mState == REQUEST_STATE_GENERATION_IN_PROGRESS;
|
||||
}
|
||||
|
||||
/// To determine whether the context is unchunked. When a context is chunked into only a part, it
|
||||
/// is still different from the unchunked state, which indicates the initial status.
|
||||
bool isFullContextRequest() const noexcept
|
||||
[[nodiscard]] bool isFullContextRequest() const noexcept
|
||||
{
|
||||
return isContextInitState() && !mContextChunkSize;
|
||||
}
|
||||
|
||||
/// When chunked, the position of the current chunk is returned. Otherwise, only the beginning
|
||||
/// or end of the context is returned.
|
||||
SizeType getContextCurrentPosition() const noexcept
|
||||
[[nodiscard]] SizeType getContextCurrentPosition() const noexcept
|
||||
{
|
||||
return mContextCurrentPosition;
|
||||
}
|
||||
|
||||
/// Return the length of the context that has not yet been processed.
|
||||
SizeType getContextRemainingLength() const noexcept
|
||||
[[nodiscard]] SizeType getContextRemainingLength() const noexcept
|
||||
{
|
||||
return mPromptLen - getContextCurrentPosition();
|
||||
}
|
||||
|
||||
/// To retrieve the context chunk size, throw an exception when the context is not chunked.
|
||||
SizeType getContextChunkSize() const
|
||||
[[nodiscard]] SizeType getContextChunkSize() const
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
isContextInitState() && mContextChunkSize, "The current request is not in context chunking state.");
|
||||
@ -587,7 +596,7 @@ public:
|
||||
|
||||
/// Determines whether the current position is only one chunk away from the end of the context.
|
||||
/// It will return true when the context is not chunked.
|
||||
bool isLastContextChunk() const noexcept
|
||||
[[nodiscard]] bool isLastContextChunk() const noexcept
|
||||
{
|
||||
return isFullContextRequest()
|
||||
|| (isContextInitState() && getContextCurrentPosition() + getContextChunkSize() == mPromptLen);
|
||||
@ -595,7 +604,7 @@ public:
|
||||
|
||||
/// Returns whether the position is at the beginning of the context. It will return true when the
|
||||
/// context is not chunked.
|
||||
bool isFirstContextChunk() const noexcept
|
||||
[[nodiscard]] bool isFirstContextChunk() const noexcept
|
||||
{
|
||||
return isFullContextRequest() || getContextCurrentPosition() == 0;
|
||||
}
|
||||
@ -704,6 +713,7 @@ public:
|
||||
std::optional<SizeType> mEndId;
|
||||
std::optional<SizeType> mPadId;
|
||||
SizeType mSeqSlot;
|
||||
std::optional<LogitsPostProcessor> mLogitsPostProcessor;
|
||||
|
||||
protected:
|
||||
SizeType mOrigPromptLen;
|
||||
@ -805,7 +815,7 @@ public:
|
||||
using VecTokens = Base::VecTokens;
|
||||
|
||||
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
|
||||
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
|
||||
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
|
||||
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
|
||||
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
|
||||
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
|
||||
@ -813,10 +823,13 @@ public:
|
||||
std::optional<TensorPtr> loraConfig = std::nullopt, bool returnLogProbs = false,
|
||||
bool returnContextLogits = false, bool returnGenerationLogits = false,
|
||||
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
|
||||
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false)
|
||||
: Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId, embeddingBias,
|
||||
badWordsList, stopWordsList, promptEmbeddingTable, promptVocabSize, loraWeights, loraConfig, returnLogProbs,
|
||||
returnContextLogits, returnGenerationLogits, draftTokens, draftLogits, excludeInputFromOutput)
|
||||
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false,
|
||||
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
|
||||
: Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
|
||||
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList),
|
||||
std::move(promptEmbeddingTable), promptVocabSize, std::move(loraWeights), std::move(loraConfig),
|
||||
returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens), std::move(draftLogits),
|
||||
excludeInputFromOutput, std::move(logitsPostProcessor))
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@ -302,6 +302,8 @@ public:
|
||||
void allgather(const void* sendbuf, void* recvbuf, int count, MpiType dtype) const;
|
||||
void barrier() const;
|
||||
|
||||
void mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;
|
||||
|
||||
bool operator==(MpiComm const& rhs) const
|
||||
{
|
||||
return mComm == rhs.mComm;
|
||||
|
||||
@ -46,8 +46,8 @@ public:
|
||||
std::optional<FloatType> beamSearchDiversityRate = std::nullopt,
|
||||
std::optional<FloatType> repetitionPenalty = std::nullopt,
|
||||
std::optional<FloatType> presencePenalty = std::nullopt,
|
||||
std::optional<FloatType> frequencyPenalty = std::nullopt,
|
||||
std::optional<FloatType> lengthPenalty = std::nullopt);
|
||||
std::optional<FloatType> frequencyPenalty = std::nullopt, std::optional<FloatType> lengthPenalty = std::nullopt,
|
||||
std::optional<SizeType> earlyStopping = std::nullopt);
|
||||
|
||||
~SamplingConfig();
|
||||
|
||||
@ -65,6 +65,7 @@ public:
|
||||
[[nodiscard]] std::optional<FloatType> getPresencePenalty() const;
|
||||
[[nodiscard]] std::optional<FloatType> getFrequencyPenalty() const;
|
||||
[[nodiscard]] std::optional<FloatType> getLengthPenalty() const;
|
||||
[[nodiscard]] std::optional<SizeType> getEarlyStopping() const;
|
||||
|
||||
private:
|
||||
SizeType mBeamWidth;
|
||||
@ -81,6 +82,7 @@ private:
|
||||
std::optional<FloatType> mPresencePenalty;
|
||||
std::optional<FloatType> mFrequencyPenalty;
|
||||
std::optional<FloatType> mLengthPenalty;
|
||||
std::optional<SizeType> mEarlyStopping;
|
||||
};
|
||||
|
||||
/// @brief Configuration that controls the outputs of a Result
|
||||
|
||||
@ -22,9 +22,7 @@
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/promptTuningParams.h"
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaAllocator.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/decodingInput.h"
|
||||
#include "tensorrt_llm/runtime/decodingMode.h"
|
||||
@ -24,7 +23,6 @@
|
||||
#include "tensorrt_llm/runtime/samplingConfig.h"
|
||||
#include <curand_kernel.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/cudaEvent.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
@ -25,11 +24,7 @@
|
||||
#include "tensorrt_llm/runtime/iGptDecoderBatch.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
@ -58,7 +53,7 @@ public:
|
||||
|
||||
TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) override;
|
||||
|
||||
void forwardSync(decoder_batch::Token const& e) override;
|
||||
void forwardSync(decoder_batch::Token const& token) override;
|
||||
|
||||
void forwardAsync(decoder::Output& output, decoder::Input const& input) override;
|
||||
|
||||
@ -89,7 +84,7 @@ public:
|
||||
|
||||
//! @brief Gather final beam search results for request `batchIdx`.
|
||||
//! Result will only be available after event returned.
|
||||
[[nodiscard]] CudaEvent finalize(SizeType batchIdx) const;
|
||||
[[nodiscard]] CudaEvent finalize(SizeType batchIdx) const override;
|
||||
|
||||
//! @brief Gather final beam search results for all requests.
|
||||
void finalize() const override;
|
||||
@ -108,7 +103,7 @@ public:
|
||||
}
|
||||
|
||||
//! @returns [maxBeamWidth], cumulative log probabilities (per beam), on gpu
|
||||
[[nodiscard]] TensorPtr getCumLogProbs(SizeType batchIdx) const
|
||||
[[nodiscard]] TensorPtr getCumLogProbs(SizeType batchIdx) const override
|
||||
{
|
||||
auto tensor = ITensor::slice(mJointDecodingOutput->cumLogProbs, batchIdx, 1);
|
||||
tensor->squeeze(0);
|
||||
@ -122,7 +117,7 @@ public:
|
||||
}
|
||||
|
||||
//! @returns [maxBeamWidth, maxSequenceLength], log probabilities (per beam), on gpu
|
||||
[[nodiscard]] TensorPtr getLogProbs(SizeType batchIdx) const
|
||||
[[nodiscard]] TensorPtr getLogProbs(SizeType batchIdx) const override
|
||||
{
|
||||
auto tensor = ITensor::slice(mJointDecodingOutput->logProbs, batchIdx, 1);
|
||||
tensor->squeeze(0);
|
||||
@ -141,7 +136,7 @@ public:
|
||||
//! @returns [batchSize, beamWidth], tokens generated in `iter` (per beam), on gpu
|
||||
[[nodiscard]] TensorPtr getNewTokens(SizeType iter = 0) const override
|
||||
{
|
||||
TensorPtr newTokensView = std::move(ITensor::slice(mJointDecodingOutput->newTokensSteps, iter, 1));
|
||||
TensorPtr newTokensView = ITensor::slice(mJointDecodingOutput->newTokensSteps, iter, 1);
|
||||
newTokensView->squeeze(0);
|
||||
return ITensor::slice(newTokensView, 0, mActualBatchSize);
|
||||
}
|
||||
@ -149,7 +144,7 @@ public:
|
||||
//! @returns [batchSize], the number of generation steps executed on each request
|
||||
[[nodiscard]] std::vector<SizeType> getNbSteps() const override
|
||||
{
|
||||
return std::vector<SizeType>(mNbSteps.begin(), mNbSteps.begin() + mActualBatchSize);
|
||||
return {mNbSteps.begin(), mNbSteps.begin() + mActualBatchSize};
|
||||
}
|
||||
|
||||
//! @returns [1], number of finished sequences, in pinned host memory
|
||||
@ -160,7 +155,7 @@ public:
|
||||
|
||||
private:
|
||||
//! @brief Gather final beam search results for request `batchIdx`.
|
||||
CudaEvent postProcessRequest(SizeType batchIdx) const;
|
||||
[[nodiscard]] CudaEvent postProcessRequest(SizeType batchIdx) const;
|
||||
|
||||
//! @brief Initialize the decoder at `batchIdx` with a new `request`.
|
||||
void newRequest(SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);
|
||||
|
||||
@ -31,7 +31,6 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <memory>
|
||||
#include <ostream>
|
||||
#include <stdexcept>
|
||||
#include <type_traits>
|
||||
#include <typeinfo>
|
||||
#include <vector>
|
||||
|
||||
@ -16,14 +16,12 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/cudaEvent.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@ -156,30 +154,30 @@ public:
|
||||
//! @param batchIdx index of the batch
|
||||
//! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
|
||||
//! ids without padding for request `batchIdx`, on gpu
|
||||
virtual TensorPtr getOutputIds(SizeType batchIdx) const = 0;
|
||||
[[nodiscard]] virtual TensorPtr getOutputIds(SizeType batchIdx) const = 0;
|
||||
|
||||
//! @brief Gather final beam search results for request `batchIdx`.
|
||||
//! Result will only be available after event returned
|
||||
virtual CudaEvent finalize(SizeType batchIdx) const = 0;
|
||||
[[nodiscard]] virtual CudaEvent finalize(SizeType batchIdx) const = 0;
|
||||
|
||||
//! @returns [batchSize (actual)], marks finished requests (per batch)
|
||||
virtual std::vector<bool> getFinished() const = 0;
|
||||
[[nodiscard]] virtual std::vector<bool> getFinished() const = 0;
|
||||
|
||||
//! @returns [batchSize, beamWidth], cumulative log probabilities (per beam), on gpu
|
||||
virtual TensorPtr getCumLogProbs() const = 0;
|
||||
[[nodiscard]] virtual TensorPtr getCumLogProbs() const = 0;
|
||||
|
||||
//! @returns [beamWidth], cumulative log probabilities (per beam) for request batchIdx, on gpu
|
||||
virtual TensorPtr getCumLogProbs(SizeType batchIdx) const = 0;
|
||||
[[nodiscard]] virtual TensorPtr getCumLogProbs(SizeType batchIdx) const = 0;
|
||||
|
||||
//! @returns [batchSize, beamWidth, maxSeqLen], log probabilities (per beam), on gpu
|
||||
virtual TensorPtr getLogProbs() const = 0;
|
||||
[[nodiscard]] virtual TensorPtr getLogProbs() const = 0;
|
||||
|
||||
//! @returns [beamWidth, maxSeqLen], cumulative log probabilities (per beam) for request batchIdx, on gpu
|
||||
virtual TensorPtr getLogProbs(SizeType batchIdx) const = 0;
|
||||
[[nodiscard]] virtual TensorPtr getLogProbs(SizeType batchIdx) const = 0;
|
||||
|
||||
virtual TensorPtr getParentIds() const = 0;
|
||||
[[nodiscard]] virtual TensorPtr getParentIds() const = 0;
|
||||
|
||||
virtual std::vector<SizeType> getNbSteps() const = 0;
|
||||
[[nodiscard]] virtual std::vector<SizeType> getNbSteps() const = 0;
|
||||
|
||||
//! @brief Initialize batched decoder at seqSlots with a new `requests`.
|
||||
virtual void newRequests(std::vector<SizeType> const& seqSlots, std::vector<decoder_batch::Request> const& requests,
|
||||
|
||||
@ -23,10 +23,8 @@
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/samplingConfig.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
@ -102,25 +100,25 @@ public:
|
||||
virtual void finalize() const = 0;
|
||||
|
||||
//! @returns [batchSize, beamWidth, maxSequenceLength], all token ids, on gpu
|
||||
virtual TensorPtr getOutputIds() const = 0;
|
||||
[[nodiscard]] virtual TensorPtr getOutputIds() const = 0;
|
||||
|
||||
//! @returns [batchSize, maxBeamWidth], cumulative log probabilities (per beam), on gpu
|
||||
virtual TensorPtr getCumLogProbs() const = 0;
|
||||
[[nodiscard]] virtual TensorPtr getCumLogProbs() const = 0;
|
||||
|
||||
//! @returns [batchSize, maxBeamWidth, maxSequenceLength], log probabilities (per beam), on gpu
|
||||
virtual TensorPtr getLogProbs() const = 0;
|
||||
[[nodiscard]] virtual TensorPtr getLogProbs() const = 0;
|
||||
|
||||
//! @brief Get tokens generated in one step of last forward pass
|
||||
//! @param iter The iteration within [0; maxTokensPerStep) for which to get the tokens
|
||||
//! @returns [batchSize, beamWidth], tokens generated in `iter` (per beam), on gpu
|
||||
virtual TensorPtr getNewTokens(SizeType iter = 0) const = 0;
|
||||
[[nodiscard]] virtual TensorPtr getNewTokens(SizeType iter = 0) const = 0;
|
||||
|
||||
//! @brief Get maxTokensPerStep tokens generated in the last forward pass
|
||||
//! @returns [maxTokensPerStep, batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
|
||||
virtual TensorPtr getAllNewTokens() const = 0;
|
||||
[[nodiscard]] virtual TensorPtr getAllNewTokens() const = 0;
|
||||
|
||||
//! @returns [1], number of finished sequences, in pinned host memory
|
||||
virtual TensorPtr getNbFinished() const = 0;
|
||||
[[nodiscard]] virtual TensorPtr getNbFinished() const = 0;
|
||||
|
||||
virtual ~IStatefulGptDecoder() = default;
|
||||
|
||||
|
||||
@ -30,7 +30,6 @@
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <ostream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
|
||||
@ -18,7 +18,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
|
||||
|
||||
@ -19,8 +19,8 @@
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <atomic>
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
@ -112,22 +112,22 @@ public:
|
||||
auto const sizeDiff = -static_cast<DiffType>(size);
|
||||
if constexpr (T == MemoryType::kGPU)
|
||||
{
|
||||
mGpu -= std::min(size, mGpu);
|
||||
mGpu -= size;
|
||||
mGpuDiff = sizeDiff;
|
||||
}
|
||||
else if constexpr (T == MemoryType::kCPU)
|
||||
{
|
||||
mCpu -= std::min(size, mCpu);
|
||||
mCpu -= size;
|
||||
mCpuDiff = sizeDiff;
|
||||
}
|
||||
else if constexpr (T == MemoryType::kPINNED)
|
||||
{
|
||||
mPinned -= std::min(size, mPinned);
|
||||
mPinned -= size;
|
||||
mPinnedDiff = sizeDiff;
|
||||
}
|
||||
else if constexpr (T == MemoryType::kUVM)
|
||||
{
|
||||
mUVM -= std::min(size, mUVM);
|
||||
mUVM -= size;
|
||||
mUVMDiff = sizeDiff;
|
||||
}
|
||||
else
|
||||
@ -138,11 +138,7 @@ public:

    void deallocate(MemoryType memoryType, SizeType size);

    static MemoryCounters& getInstance()
    {
        static thread_local MemoryCounters mInstance;
        return mInstance;
    }
    static MemoryCounters& getInstance();

    static std::string bytesToString(SizeType bytes, int precision = 2);

@ -151,8 +147,8 @@ public:
    [[nodiscard]] std::string toString() const;

private:
    SizeType mGpu{}, mCpu{}, mPinned{}, mUVM{};
    DiffType mGpuDiff{}, mCpuDiff{}, mPinnedDiff{}, mUVMDiff{};
    std::atomic<SizeType> mGpu{}, mCpu{}, mPinned{}, mUVM{};
    std::atomic<DiffType> mGpuDiff{}, mCpuDiff{}, mPinnedDiff{}, mUVMDiff{};
};

} // namespace tensorrt_llm::runtime
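The hunks above replace the thread_local singleton with a single shared instance, which is why the counters become std::atomic. A standalone sketch of the data race the atomics avoid, using a toy counter rather than the real MemoryCounters class:

// Illustrative only: two threads updating one shared byte counter.
// With a plain integer, concurrent += / -= can lose updates; std::atomic keeps the total consistent.
#include <atomic>
#include <cstdio>
#include <thread>

int main()
{
    std::atomic<long> gpuBytes{0};
    auto worker = [&gpuBytes]()
    {
        for (int i = 0; i < 100000; ++i)
        {
            gpuBytes += 1 << 10; // models the allocate() path
            gpuBytes -= 1 << 10; // models the deallocate() path
        }
    };
    std::thread t1{worker};
    std::thread t2{worker};
    t1.join();
    t2.join();
    // With std::atomic this always prints 0; with a plain long it may not.
    std::printf("gpuBytes = %ld\n", gpuBytes.load());
    return 0;
}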
|
||||
|
||||
@ -16,12 +16,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
|
||||
@ -72,6 +72,7 @@ public:
|
||||
{
|
||||
TLLM_CHECK(configs.size() > 0);
|
||||
beamWidth = configs.front().beamWidth;
|
||||
normalizeLogProbs = configs.front().normalizeLogProbs;
|
||||
temperature = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].temperature; });
|
||||
minLength = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].minLength; });
|
||||
repetitionPenalty
|
||||
@ -87,6 +88,7 @@ public:
        beamSearchDiversityRate
            = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].beamSearchDiversityRate; });
        lengthPenalty = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].lengthPenalty; });
        earlyStopping = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].earlyStopping; });
        draftAcceptanceThreshold
            = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].draftAcceptanceThreshold; });
    }
@ -121,6 +123,7 @@ public:
        SET_FROM_OPTIONAL(presencePenalty, PresencePenalty, FloatType)
        SET_FROM_OPTIONAL(frequencyPenalty, FrequencyPenalty, FloatType)
        SET_FROM_OPTIONAL(lengthPenalty, LengthPenalty, FloatType)
        SET_FROM_OPTIONAL(earlyStopping, EarlyStopping, SizeType)
#undef SET_FROM_OPTIONAL
    }
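The fused constructor above collects one optional value per request into a single batched member. A self-contained sketch of that fusing pattern; the function name and signature below are assumptions written for illustration, not the exact fuseValues declared in this header:

// Illustrative only: fuse per-request optionals into one optional batched vector.
// If no request sets the value the fused result stays empty; otherwise missing entries
// fall back to a default, mirroring how earlyStopping joins the fused config above.
#include <functional>
#include <optional>
#include <vector>

template <typename T, typename Config>
std::optional<std::vector<T>> fuseOptionalValues(
    std::vector<Config> const& configs, std::function<std::optional<T>(size_t)> accessor, T defaultValue)
{
    bool anySet = false;
    std::vector<T> fused;
    fused.reserve(configs.size());
    for (size_t ci = 0; ci < configs.size(); ++ci)
    {
        auto const value = accessor(ci);
        anySet = anySet || value.has_value();
        fused.push_back(value.value_or(defaultValue));
    }
    if (!anySet)
    {
        return std::nullopt; // no request configured this knob
    }
    return fused; // one entry per request, defaults filled in
}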
|
||||
|
||||
@ -142,11 +145,12 @@ public:
|
||||
OptVec<SizeType> topPResetIds; // [batch_size]
|
||||
|
||||
// beam search layer
|
||||
OptVec<FloatType> beamSearchDiversityRate;
|
||||
OptVec<FloatType> lengthPenalty;
|
||||
OptVec<FloatType> beamSearchDiversityRate; // [1] or [batch_size]
|
||||
OptVec<FloatType> lengthPenalty; // [1] or [batch_size]
|
||||
OptVec<SizeType> earlyStopping; // [1] or [batch_size]
|
||||
|
||||
// speculative decoding
|
||||
OptVec<FloatType> draftAcceptanceThreshold;
|
||||
// speculative decoding, only the first value is used (in gptDecoderBatch.cpp)
|
||||
OptVec<FloatType> draftAcceptanceThreshold; // [1] or [batch_size]
|
||||
|
||||
std::optional<bool> normalizeLogProbs;
|
||||
};
|
||||
|
||||
@ -195,7 +195,8 @@ set(TRTLLM_LINK_LIBS
|
||||
${TRT_LIB}
|
||||
common_src
|
||||
kernels_src
|
||||
cutlass_src
|
||||
cutlass_src_pre_hopper
|
||||
cutlass_src_hopper
|
||||
layers_src
|
||||
runtime_src)
|
||||
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4ba61c04ed7623fc44b5364802c1893fa824467455f4a9fe8245d5d51fef97e6
|
||||
size 2172266
|
||||
oid sha256:c9fd644e0a38b1d4d1a54d4b7b834cc6b0110a5771fcfc480e96795b3f9bc892
|
||||
size 2081046
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bf4afdfd281029c8e4bf0af548529b94a4a6d0f9bb5148ae10423e5e0275db06
|
||||
size 2195822
|
||||
oid sha256:90436c59eb243a0156e3f0aa95412a7caacbefdcde768c158edc4b821044dfd1
|
||||
size 2102486
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
4c405d39a0cbb93d44a5758480a1a223 libtensorrt_llm_batch_manager_static.a
|
||||
68aea75a2ed5b219eec5a0f77ce33482 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
9b63c754d2a1edc7a17106e83c3e131d312f0a80 commit
|
||||
f53c02e3829b516a6e9221745bcbacbd libtensorrt_llm_batch_manager_static.a
|
||||
9e92e5dbb104e3e676952ea40c81916f libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
25adff90cc350eb9ca9804051a08de80d547c113 commit
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:39835ca321e9c45d3b554ebceb1734966b75f83dbe8c550cc44846fb4fae8f72
|
||||
size 2110728
|
||||
oid sha256:c3433d7b52bb6dcac32111172cb6201a9fee56e739f3660895083baebd1b89ee
|
||||
size 2033616
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:789c2eba349161e84a76b95b23f8294cf3bdcf855871672d76722c4ae858d81b
|
||||
size 2091842
|
||||
oid sha256:fb3f4145881984de6268c34f7e5d452f78f54952f454f747a1cd52bc3171de62
|
||||
size 2012002
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
30a6c963121b3cfda21dc0117b7984e1 libtensorrt_llm_batch_manager_static.a
|
||||
0d2d2e3157201f6336d749b3e6f994bc libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
d60b12741e940f56addaf2d92e78b50f libtensorrt_llm_batch_manager_static.a
|
||||
c55e606a3430d3a56cee3968a77b46f1 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
|
||||
@ -172,6 +172,11 @@ void MpiComm::allgather(const void* sendbuf, void* recvbuf, int count, MpiType d
|
||||
MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm));
|
||||
}
|
||||
|
||||
void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
|
||||
{
|
||||
MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status));
|
||||
}
|
||||
|
||||
int MpiComm::getRank() const
|
||||
{
|
||||
int rank = 0;
|
||||
|
||||
@ -34,6 +34,8 @@ enum class CutlassTileConfig
|
||||
CtaShape128x128x8_WarpShape64x64x8,
|
||||
|
||||
// TensorCore configs CTA_N = 128, CTA_K = 64
|
||||
// Warp configs for M=16
|
||||
CtaShape16x128x64_WarpShape16x32x64,
|
||||
// Warp configs for M=32
|
||||
CtaShape32x128x64_WarpShape32x32x64,
|
||||
|
||||
@ -50,7 +52,10 @@ enum class CutlassTileConfig
|
||||
CtaShape128x256x64_WarpShape64x64x64,
|
||||
|
||||
// Warp configs for M=256
|
||||
CtaShape256x128x64_WarpShape64x64x64
|
||||
CtaShape256x128x64_WarpShape64x64x64,
|
||||
|
||||
// TensorCore config CTA_N = 256, CTA_K = 64
|
||||
CtaShape16x256x64_WarpShape16x64x64
|
||||
};
|
||||
|
||||
enum class SplitKStyle
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:49c84a22cee9e6c3a975db08d8d0d8dbe88867e2eb4fc12a4b3ff6c1c90e8c21
|
||||
size 586202
|
||||
oid sha256:13e17e2d9a94d2bc1b131d096a3722a83a67ab115fa8271b57b27f7e2877bdc1
|
||||
size 587334
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fea93ae7d09e74b073a65d5d0ac34aec9ccc8f8299af1abd6826e97e9c8427f4
|
||||
size 589652
|
||||
oid sha256:45438204eba812694bd30b68cfc9bb2bc54a8a59c6c86e037bbc4ac7e5f8230c
|
||||
size 589438
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
73999f4c2b3a4328db454b7ab6fe86d3 libtensorrt_llm_executor_static.a
|
||||
df53aa83848b5ed75550a7b536ca02a4 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
9b63c754d2a1edc7a17106e83c3e131d312f0a80 commit
|
||||
835767a37292ea9786c0d6149ae270f4 libtensorrt_llm_executor_static.a
|
||||
1fe0c9ac7a1a35ce7d80676146867374 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
25adff90cc350eb9ca9804051a08de80d547c113 commit
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:643e546711fd33a85073560e3428c6a2f60525f7592aa3328043dfad61631c30
|
||||
size 586532
|
||||
oid sha256:7969768d3b9a65182ee519c60e11f27b0a088c2c0b732f3780d7c0c563dbb180
|
||||
size 587776
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:28059131a9325c88bd362cb12c57a2b2e47d3e0aac140e5d1cf9a7020a81999e
|
||||
size 570860
|
||||
oid sha256:98d9b7c4a586f0be0499a0df487cacba69985ce43ca5fd543c90c6a368c91b67
|
||||
size 571150
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
fa89714705a1915f052c635a07dc4c73 libtensorrt_llm_executor_static.a
|
||||
83cbfaf10bedd7d8edeab33552dcf3df libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
6771f94e0bce39c6cab391cf1f92484c libtensorrt_llm_executor_static.a
|
||||
84b7550448f8710de17644a5d404178f libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
|
||||
@ -28,7 +28,7 @@ if(FAST_BUILD)
|
||||
"decoderMaskedMultiheadAttention(48|80|96|112|144|160|192|224).*cu$")
|
||||
endif()
|
||||
|
||||
add_library(kernels_src OBJECT ${SRC_CPP} ${SRC_CU})
|
||||
add_library(kernels_src STATIC ${SRC_CPP} ${SRC_CU})
|
||||
set_property(TARGET kernels_src PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET kernels_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ namespace kernels
template <typename T>
__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty)
{
    // score = log(prob) / (length)^length_penalty.
    // score = log(prob) / (length ^ length_penalty)
    if (length_penalty == 0.0f || length == 1)
    {
        return log_prob;
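A host-side worked example of the corrected formula in the comment above, score = log_prob / (length ^ length_penalty); this reference function is written for illustration and is not the device kernel:

// Illustrative only: reference implementation of the length-penalty score.
#include <cmath>
#include <cstdio>

float applyLengthPenaltyRef(float logProb, int length, float lengthPenalty)
{
    if (lengthPenalty == 0.0f || length == 1)
    {
        return logProb; // no normalization needed
    }
    return logProb / std::pow(static_cast<float>(length), lengthPenalty);
}

int main()
{
    // A 4-token beam with cumulative log-prob -2.0 and penalty 1.0 scores -2.0 / 4^1 = -0.5.
    std::printf("%f\n", applyLengthPenaltyRef(-2.0f, 4, 1.0f));
    return 0;
}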
|
||||
|
||||
@ -35,35 +35,42 @@ namespace kernels
|
||||
// After we collect `beam_width` beams, we will sort them by their norm_scores.
|
||||
struct BeamHypotheses
|
||||
{
|
||||
int* output_ids_tgt = nullptr;
|
||||
int* sequence_lengths_tgt = nullptr;
|
||||
float* cum_log_probs = nullptr; // cum_log
|
||||
float* normed_scores = nullptr; // cum_log / (length**length_penalty)
|
||||
float* log_probs = nullptr; // log probs of each generated token
|
||||
float* min_normed_scores = nullptr; // record the min normed scores for each batch
|
||||
int* num_beams = nullptr; // the number of finished beams we collect
|
||||
bool* is_done = nullptr;
|
||||
// TODO: simplify the pointers
|
||||
// Pointers initialized in function prepareOutputs in gptDecoder.cpp
|
||||
bool* is_done{nullptr}; // [batchSize], whether the batch is finished
|
||||
const int* input_lengths{nullptr}; // [batchSize]
|
||||
float* cum_log_probs{nullptr}; // [batchSize, 2 * beamWidth], outputs.cum_log_probs->template getPtr<float>()
|
||||
float* log_probs{nullptr}; // [batchSize, 2 * beamWidth, maxSeqLen], not used?
|
||||
float* min_normed_scores{nullptr}; // [batchSize], worst normed scores for each batch
|
||||
float* normed_scores{nullptr}; // [batchSize, 2 * beamWidth], cum_log / (length ^ length_penalty)
|
||||
int* num_beams{nullptr}; // [batchSize], count of finished beams for each batch
|
||||
int* output_ids_tgt{nullptr}; // [batchSize, 2 * beamWidth, maxSeqLen],
|
||||
int* sequence_lengths_tgt{nullptr}; // [batchSize, 2 * beamWidth], different from sequence_lengths_src
|
||||
|
||||
// Used to set inputs
|
||||
const int* output_ids_src;
|
||||
const int** output_ids_src_ptr;
|
||||
const int* parent_ids_src;
|
||||
const int** parent_ids_src_ptr;
|
||||
const int* sequence_lengths_src;
|
||||
const int* end_ids;
|
||||
const float* log_probs_src;
|
||||
const int* input_lengths;
|
||||
// Pointers initialized in function invokeSoftMax in onlineBeamSearchLayer.cu
|
||||
const int* end_ids{nullptr}; // get from SoftmaxParams
|
||||
const int* output_ids_src{nullptr}; // for gatherTree
|
||||
const int* parent_ids_src{nullptr}; // for gatherTree
|
||||
const int** output_ids_src_ptr{nullptr}; // get from BeamSearchOutputParams for reading
|
||||
const int** parent_ids_src_ptr{nullptr}; // get from BeamSearchOutputParams for reading
|
||||
float* log_probs_src{nullptr}; // get from outputs.output_log_probs
|
||||
int* sequence_lengths_src{nullptr}; // get from BeamSearchOutputParams
|
||||
// For reading in function invokeTopkSoftMax but reading and writing in function invokeUpdate
|
||||
int** output_ids_tgt_ptr{nullptr}; // get from BeamSearchOutputParams for writing
|
||||
int** parent_ids_tgt_ptr{nullptr}; // get from BeamSearchOutputParams for writing
|
||||
|
||||
// some variables for kernels
|
||||
int step;
|
||||
int ite;
|
||||
int batch_size;
|
||||
int local_batch_size;
|
||||
int max_seq_len;
|
||||
float* length_penalties;
|
||||
|
||||
bool early_stopping = true;
|
||||
bool is_return_normed_score = true; // return normed_cum_log_probs or cum_log_probs
|
||||
// Other scalar values and buffers
|
||||
int batch_size{0};
|
||||
int beam_width{0};
|
||||
int ite{0};
|
||||
int local_batch_size{0};
|
||||
int max_seq_len{0};
|
||||
int step{0}; // useless in online version of beam search
|
||||
int vocab_size{0};
|
||||
float* diversity_rates{nullptr};
|
||||
float* length_penalties{nullptr};
|
||||
int* early_stoppings{nullptr};
|
||||
bool is_return_normed_score{true}; // return normed_cum_log_probs or cum_log_probs
|
||||
};
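A simplified host-side sketch of the bookkeeping implied by num_beams, min_normed_scores, and normed_scores when a finished beam is recorded; the struct and function below are invented for illustration and only mirror the policy described in the comments above, not the actual device kernels.

// Illustrative only: per-batch view over the beam-hypotheses buffers.
#include <algorithm>
#include <cmath>

struct BatchHypothesesView
{
    float* normed_scores;    // slice of [batchSize, 2 * beamWidth] for one batch entry
    float* min_normed_score; // worst normed score collected so far
    int* num_beams;          // finished beams collected so far
    int beam_width;
};

// Record a finished beam: collect up to beam_width hypotheses, then only accept
// candidates that beat the current worst normed score.
bool tryInsertFinishedBeam(BatchHypothesesView view, float cumLogProb, int length, float lengthPenalty)
{
    float const normed = cumLogProb / std::pow(static_cast<float>(length), lengthPenalty);
    if (*view.num_beams < view.beam_width)
    {
        view.normed_scores[(*view.num_beams)++] = normed;
        *view.min_normed_score = std::min(*view.min_normed_score, normed);
        return true;
    }
    if (normed > *view.min_normed_score)
    {
        // A real implementation would overwrite the worst slot and recompute the minimum here.
        return true;
    }
    return false;
}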
|
||||
|
||||
template <typename T>
|
||||
|
||||
@ -57,9 +57,17 @@ endif()
|
||||
|
||||
file(GLOB_RECURSE CU_INSTANTIATIONS ${CMAKE_CURRENT_BINARY_DIR}/*.cu)
|
||||
|
||||
add_library(cutlass_src OBJECT ${SRC_CPP} ${SRC_CU} ${CU_INSTANTIATIONS})
|
||||
set_property(TARGET cutlass_src PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET cutlass_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
|
||||
add_library(cutlass_src_pre_hopper STATIC ${SRC_CPP} ${SRC_CU})
|
||||
set_property(TARGET cutlass_src_pre_hopper PROPERTY POSITION_INDEPENDENT_CODE
|
||||
ON)
|
||||
set_property(TARGET cutlass_src_pre_hopper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS
|
||||
ON)
|
||||
|
||||
add_library(cutlass_src_hopper STATIC ${CU_INSTANTIATIONS})
|
||||
set_property(TARGET cutlass_src_hopper PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET cutlass_src_hopper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
|
||||
|
||||
add_dependencies(cutlass_src_hopper cutlass_src_pre_hopper)
|
||||
|
||||
# Note - we deliberately do not include 90a PTX (even when 9.0+PTX is
|
||||
# specified). This is because sm_90a has arch conditional instructions that are
|
||||
@ -70,15 +78,21 @@ if("9.0" IN_LIST TORCH_CUDA_ARCH_LIST
|
||||
OR TORCH_CUDA_ARCH_LIST STREQUAL "Auto")
|
||||
message(STATUS "MANUALLY APPENDING FLAG TO COMPILE FOR SM_90a.")
|
||||
target_compile_options(
|
||||
cutlass_src
|
||||
cutlass_src_pre_hopper
|
||||
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_90a,code=sm_90a>)
|
||||
target_compile_options(
|
||||
cutlass_src_hopper
|
||||
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_90a,code=sm_90a>)
|
||||
|
||||
# Hopper kernels require cuda lib for TMA APIs
|
||||
target_link_libraries(cutlass_src PRIVATE CUDA::cuda_driver)
|
||||
target_link_libraries(cutlass_src_pre_hopper PRIVATE CUDA::cuda_driver)
|
||||
target_link_libraries(cutlass_src_hopper PRIVATE CUDA::cuda_driver)
|
||||
|
||||
# No kernels should be parsed, unless hopper is specified. This is a build
|
||||
# time improvement
|
||||
target_compile_definitions(cutlass_src
|
||||
target_compile_definitions(cutlass_src_pre_hopper
|
||||
PRIVATE COMPILE_HOPPER_MIXED_INPUT_GEMMS)
|
||||
target_compile_definitions(cutlass_src_hopper
|
||||
PRIVATE COMPILE_HOPPER_MIXED_INPUT_GEMMS)
|
||||
endif()
|
||||
|
||||
@ -87,5 +101,9 @@ endif()
|
||||
# compilation output.
|
||||
if(NOT WIN32)
|
||||
target_compile_options(
|
||||
cutlass_src PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-psabi>)
|
||||
cutlass_src_pre_hopper
|
||||
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-psabi>)
|
||||
target_compile_options(
|
||||
cutlass_src_hopper
|
||||
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-psabi>)
|
||||
endif()
|
||||
|
||||
@ -52,6 +52,8 @@ TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
|
||||
{
|
||||
switch (tile_config)
|
||||
{
|
||||
case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: return TileShape{16, 128};
|
||||
case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: return TileShape{16, 256};
|
||||
case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: return TileShape{32, 128};
|
||||
case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: return TileShape{64, 64};
|
||||
case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
|
||||
@ -145,7 +147,9 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
|
||||
case CutlassGemmType::WeightOnly:
|
||||
if (sm >= 75)
|
||||
{
|
||||
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
|
||||
return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64,
|
||||
CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64,
|
||||
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
|
||||
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
|
||||
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
|
||||
}
|
||||
|
||||
@ -127,9 +127,8 @@ LayoutDetails getLayoutDetailsForArch(QuantType quant_type)
|
||||
return details;
|
||||
}
|
||||
|
||||
LayoutDetails getLayoutDetailsForTransform(QuantType quant_type)
|
||||
LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
|
||||
{
|
||||
const int arch = getSMVersion();
|
||||
if (arch >= 70 && arch < 75)
|
||||
{
|
||||
return getLayoutDetailsForArch<cutlass::arch::Sm70>(quant_type);
|
||||
@ -534,10 +533,15 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const
|
||||
}
|
||||
|
||||
void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight,
|
||||
const std::vector<size_t>& shape, QuantType quant_type)
|
||||
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave)
|
||||
{
|
||||
const int arch = getSMVersion();
|
||||
LayoutDetails details = getLayoutDetailsForTransform(quant_type);
|
||||
int arch = getSMVersion();
|
||||
if (force_interleave && arch == 90)
|
||||
{
|
||||
// Workaround for MOE which doesn't have specialised Hopper kernels yet
|
||||
arch = 80;
|
||||
}
|
||||
LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
|
||||
|
||||
@ -616,7 +620,8 @@ Outputs
|
||||
|
||||
template <typename ComputeType, typename WeightType>
|
||||
void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight,
|
||||
ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type)
|
||||
ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type,
|
||||
bool force_interleave)
|
||||
{
|
||||
|
||||
TLLM_CHECK_WITH_INFO(processed_quantized_weight, "Processed quantized tensor is NULL");
|
||||
@ -719,48 +724,52 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
|
||||
}
|
||||
}
|
||||
|
||||
preprocess_weights_for_mixed_gemm(processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type);
|
||||
preprocess_weights_for_mixed_gemm(
|
||||
processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, force_interleave);
|
||||
}
|
||||
|
||||
template void symmetric_quantize<half, float>(
|
||||
int8_t*, int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);
|
||||
int8_t*, int8_t*, half*, const float*, const std::vector<size_t>&, QuantType, bool);
|
||||
|
||||
template void symmetric_quantize<half, half>(
|
||||
int8_t*, int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);
|
||||
int8_t*, int8_t*, half*, const half*, const std::vector<size_t>&, QuantType, bool);
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(
|
||||
int8_t*, int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType);
|
||||
int8_t*, int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool);
|
||||
|
||||
template void symmetric_quantize<__nv_bfloat16, float>(
|
||||
int8_t*, int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType);
|
||||
int8_t*, int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType, bool);
|
||||
#endif
|
||||
|
||||
template <typename ComputeType, typename WeightType>
|
||||
void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr,
|
||||
const std::vector<size_t>& shape, QuantType quant_type)
|
||||
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave)
|
||||
{
|
||||
symmetric_quantize(processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type);
|
||||
symmetric_quantize(
|
||||
processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type, force_interleave);
|
||||
}
|
||||
|
||||
template void symmetric_quantize<float, float>(int8_t*, float*, const float*, const std::vector<size_t>&, QuantType);
|
||||
template void symmetric_quantize<float, float>(
|
||||
int8_t*, float*, const float*, const std::vector<size_t>&, QuantType, bool);
|
||||
|
||||
template void symmetric_quantize<half, float>(int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);
|
||||
template void symmetric_quantize<half, float>(
|
||||
int8_t*, half*, const float*, const std::vector<size_t>&, QuantType, bool);
|
||||
|
||||
template void symmetric_quantize<half, half>(int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);
|
||||
template void symmetric_quantize<half, half>(int8_t*, half*, const half*, const std::vector<size_t>&, QuantType, bool);
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(
|
||||
int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType);
|
||||
int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool);
|
||||
|
||||
template void symmetric_quantize<__nv_bfloat16, half>(
|
||||
int8_t*, __nv_bfloat16*, const half*, const std::vector<size_t>&, QuantType);
|
||||
int8_t*, __nv_bfloat16*, const half*, const std::vector<size_t>&, QuantType, bool);
|
||||
|
||||
template void symmetric_quantize<half, __nv_bfloat16>(
|
||||
int8_t*, half*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType);
|
||||
int8_t*, half*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool);
|
||||
|
||||
template void symmetric_quantize<__nv_bfloat16, float>(
|
||||
int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType);
|
||||
int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType, bool);
|
||||
#endif
|
||||
|
||||
} // namespace cutlass_kernels
|
||||
|
||||
@ -47,17 +47,18 @@ void subbyte_transpose(int8_t* transposed_quantized_tensor, const int8_t* quanti
|
||||
void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type);
|
||||
|
||||
void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight,
|
||||
const std::vector<size_t>& shape, QuantType quant_type);
|
||||
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave = false);
|
||||
|
||||
template <typename ComputeType, typename WeightType>
|
||||
void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr,
|
||||
const std::vector<size_t>& shape, QuantType quant_type);
|
||||
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave);
|
||||
|
||||
// This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight
|
||||
// to implement a simple reference implementation.
|
||||
template <typename ComputeType, typename WeightType>
|
||||
void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight,
|
||||
ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type);
|
||||
ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type,
|
||||
bool force_interleave);
|
||||
|
||||
} // namespace cutlass_kernels
|
||||
} // namespace kernels
|
||||
|
||||
@ -84,7 +84,7 @@ public:
|
||||
|
||||
protected:
|
||||
static constexpr int SPLIT_K_LIMIT = 7;
|
||||
static constexpr int MIN_M_TILE = 32;
|
||||
static constexpr int MIN_M_TILE = 16;
|
||||
static constexpr int MIN_N_TILE = 64;
|
||||
};
|
||||
|
||||
|
||||
@ -324,6 +324,26 @@ void dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, cons
|
||||
// best for mixed type gemms.
|
||||
switch (gemm_config.tile_config)
|
||||
{
|
||||
case tkc::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
|
||||
cutlass::gemm::GemmShape<16, 128, 64>, cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales,
|
||||
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
|
||||
stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case tkc::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
|
||||
cutlass::gemm::GemmShape<16, 256, 64>, cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales,
|
||||
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
|
||||
stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
||||
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
|
||||
cutlass::gemm::GemmShape<32, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales,
|
||||
@ -337,11 +357,8 @@ void dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, cons
|
||||
stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
|
||||
if (arch::kMinComputeCapability < 75)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false, "Invalid config on Volta");
|
||||
}
|
||||
else
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales,
|
||||
@ -433,41 +450,6 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
|
||||
}
|
||||
}
|
||||
|
||||
// Disabled since the fused GEMM + activation kernels will not be used in v1.
|
||||
|
||||
// template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
|
||||
// void CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::gemm_bias_act(const T* A, const WeightType* B, const T*
|
||||
// weight_scales,
|
||||
// const T* biases, T* C, int m, int n, int k, ActivationType activation_type, char* workspace_ptr,
|
||||
// const size_t workspace_bytes, cudaStream_t stream)
|
||||
// {
|
||||
// TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
|
||||
// switch (activation_type)
|
||||
// {
|
||||
// case ActivationType::Relu:
|
||||
// run_gemm<tkc::EpilogueOpBiasReLU>(
|
||||
// A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, stream);
|
||||
// break;
|
||||
// case ActivationType::Gelu:
|
||||
// run_gemm<tkc::EpilogueOpBiasFtGelu>(
|
||||
// A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, stream);
|
||||
// break;
|
||||
// case ActivationType::Silu:
|
||||
// run_gemm<tkc::EpilogueOpBiasSilu>(
|
||||
// A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, stream);
|
||||
// break;
|
||||
// case ActivationType::Identity:
|
||||
// run_gemm<tkc::EpilogueOpBias>(A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes,
|
||||
// stream); break;
|
||||
// case ActivationType::InvalidType: TLLM_CHECK_WITH_INFO(false, "Activation type for fpA_intB must be
|
||||
// valid."); break; default:
|
||||
// {
|
||||
// TLLM_CHECK_WITH_INFO(false, "Invalid activation type.");
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
|
||||
typename BiasType, typename OutputType>
|
||||
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
|
||||
|
||||
@ -231,6 +231,24 @@ void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_s
|
||||
{
|
||||
switch (gemm_config.tile_config)
|
||||
{
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
|
||||
cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert,
|
||||
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert,
|
||||
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
||||
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
|
||||
@ -266,6 +284,24 @@ void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_s
|
||||
{
|
||||
switch (gemm_config.tile_config)
|
||||
{
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
|
||||
cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert,
|
||||
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert,
|
||||
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
||||
dispatchGemmConfig<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
|
||||
|
||||
@ -53,6 +53,7 @@ struct XQAParams
|
||||
// almost copy from GPTAttentionPluginCommon.
|
||||
// maybe use one struct for parameters in GPTAttentionPluginCommon and share the same here.
|
||||
int32_t generation_input_length;
|
||||
int32_t layer_idx = 0;
|
||||
int32_t num_q_heads = 0;
|
||||
int32_t num_kv_heads = 0;
|
||||
int32_t head_size = 0;
|
||||
|
||||
@ -121,7 +121,7 @@ __global__ void gatherTree(gatherTreeParam param)
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T applyLengthPenalty(T logProb, int length, float lengthPenalty)
|
||||
{
|
||||
// score = log(prob) / (length)^lengthPenalty.
|
||||
// score = log(prob) / (length ^ lengthPenalty)
|
||||
if (lengthPenalty == 0.0f || length == 1)
|
||||
{
|
||||
return logProb;
|
||||
|
||||
@ -46,7 +46,8 @@ struct gatherTreeParam
|
||||
int32_t* outputIds = nullptr; // the buffer to put finalized ids
|
||||
cudaStream_t stream;
|
||||
float* cumLogProbs = nullptr; // [batchSize, beamWidth]
|
||||
float lengthPenalty = 1.0f; // on cpu
|
||||
float lengthPenalty = 1.0f;
|
||||
int earlyStopping = 1;
|
||||
};
|
||||
|
||||
/*
|
||||
|
||||
@ -62,7 +62,7 @@ void groupedGemm_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vect
|
||||
std::vector<void*> ptrC, std::vector<void*> ptrD, void* gemmParamsWorkSpace, int64_t gemmParamsWorkSpaceSize,
|
||||
void* gemmWorkSpace, int64_t gemmWorkspaceSize, nvinfer1::DataType dataType, cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
using ElementA = cutlassType;
|
||||
using ElementB = cutlassType;
|
||||
using ElementOutput = cutlassType;
|
||||
@ -167,7 +167,7 @@ void groupedGemm_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vect
|
||||
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to run CUTLASS Grouped GEMM kernel.");
|
||||
|
||||
std::free(host_workspace);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <int M1, int N1, int K1, int M2, int N2, int K2>
|
||||
|
||||
@ -25,27 +25,20 @@ namespace kernels
|
||||
{
|
||||
|
||||
template <typename T, int MAX_K>
|
||||
void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const FinishedState* finished,
|
||||
const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, int** output_ids_ptr,
|
||||
void* temp_storage, const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size,
|
||||
const int beam_width, const int vocab_size, const int* end_ids, const float* diversity_rates,
|
||||
const float* length_penalties, cudaStream_t stream);
|
||||
void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs,
|
||||
void* temp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream);
|
||||
|
||||
#define CASE_K(MAX_K) \
|
||||
topK_softMax_kernelLauncher<T, MAX_K>(log_probs, bias, finished, sequence_lengths, cum_log_probs, \
|
||||
output_log_probs, output_ids_ptr, temp_storage, temp_storage_size, beam_hyps, batch_size, beam_width, \
|
||||
vocab_size, end_ids, diversity_rates, length_penalties, stream); \
|
||||
topK_softMax_kernelLauncher<T, MAX_K>( \
|
||||
log_probs, bias, finished, cum_log_probs, temp_storage, temp_storage_size, beam_hyps, stream); \
|
||||
break;
|
||||
|
||||
template <typename T>
|
||||
void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* finished, const int* sequence_lengths,
|
||||
float* cum_log_probs, float* output_log_probs, int** output_ids_ptr, void* temp_storage,
|
||||
const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size, const int beam_width,
|
||||
const int vocab_size, const int* end_ids, const float* diversity_rates, const float* length_penalties,
|
||||
cudaStream_t stream)
|
||||
void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs,
|
||||
void* temp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream)
|
||||
{
|
||||
int log_beam_width(0);
|
||||
int recursor(beam_width - 1);
|
||||
int recursor(beam_hyps.beam_width - 1);
|
||||
while (recursor >>= 1)
|
||||
++log_beam_width;
|
||||
|
||||
@ -66,23 +59,19 @@ void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* f
|
||||
CASE_K(64)
|
||||
#endif // FAST_BUILD
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
fmtstr("%s:%d Topk kernel of beam search does not support beam_width=%d", __FILE__, __LINE__, beam_width));
|
||||
throw std::runtime_error(fmtstr("%s:%d Topk kernel of beam search does not support beam_width=%d", __FILE__,
|
||||
__LINE__, beam_hyps.beam_width));
|
||||
}
|
||||
}
|
||||
|
||||
#undef CASE_K
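The recursor loop in invokeTopkSoftMax above effectively computes floor(log2(beam_width - 1)) for beam_width >= 2, bucketing beam widths into the power-of-two ranges that the CASE_K specializations cover. A standalone illustration of that bucketing; the printed mapping is derived from the loop itself, not from the full switch, which is not shown here:

// Illustrative only: reproduces the recursor loop used to pick a MAX_K bucket.
#include <cstdio>
#include <initializer_list>

int logBeamWidth(int beamWidth)
{
    int log = 0;
    int recursor = beamWidth - 1;
    while (recursor >>= 1)
    {
        ++log;
    }
    return log;
}

int main()
{
    // beam_width 2 -> 0, 3-4 -> 1, 5-8 -> 2, 9-16 -> 3, 17-32 -> 4, 33-64 -> 5
    for (int bw : {2, 4, 8, 16, 32, 64})
    {
        std::printf("beam_width=%d -> bucket %d\n", bw, logBeamWidth(bw));
    }
    return 0;
}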
|
||||
|
||||
template void invokeTopkSoftMax<float>(const float* log_probs, const float* bias, const FinishedState* finished,
|
||||
const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, int** output_ids_ptr, void* tmp_storage,
|
||||
const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size, const int beam_width,
|
||||
const int vocab_size, const int* end_ids, const float* diversity_rates, const float* length_penalties,
|
||||
float* cum_log_probs, void* tmp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps,
|
||||
cudaStream_t stream);
|
||||
|
||||
template void invokeTopkSoftMax<half>(const half* log_probs, const half* bias, const FinishedState* finished,
|
||||
const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, int** output_ids_ptr, void* tmp_storage,
|
||||
const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size, const int beam_width,
|
||||
const int vocab_size, const int* end_ids, const float* diversity_rates, const float* length_penalties,
|
||||
float* cum_log_probs, void* tmp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps,
|
||||
cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
@ -24,10 +24,8 @@ namespace kernels
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* finished, const int* sequence_lengths,
|
||||
float* cum_log_probs, float* output_log_probs, int** output_ids_ptr, void* tmp_storage, const int temp_storage_size,
|
||||
BeamHypotheses* beam_hyps, const int batch_size, const int beam_width, const int vocab_size, const int* end_ids,
|
||||
const float* diversity_rates, const float* length_penalties, cudaStream_t stream);
|
||||
void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs,
|
||||
void* tmp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -44,7 +44,7 @@ static const int SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE = 256;
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty)
|
||||
{
|
||||
// score = log(prob) / (length)^length_penalty.
|
||||
// score = log(prob) / (length ^ length_penalty).
|
||||
if (length_penalty == 0.0f || length == 1)
|
||||
{
|
||||
return log_prob;
|
||||
@ -56,9 +56,10 @@ template <typename T, int MAX_K, int THREADBLOCK_SIZE>
|
||||
__launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
void batch_topK_kernel(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf)
|
||||
{
|
||||
int thread_id = threadIdx.x;
|
||||
int block_id = blockIdx.x;
|
||||
const int thread_id = threadIdx.x;
|
||||
const int block_id = blockIdx.x;
|
||||
TopK<T, MAX_K> partial;
|
||||
|
||||
if (thread_id == 0)
|
||||
{
|
||||
for (int i = 0; i < MAX_K; ++i)
|
||||
@ -70,7 +71,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
int index = block_id * MAX_K * MAX_K;
|
||||
for (int i = 0; i < MAX_K * MAX_K; i++)
|
||||
{
|
||||
partial.insert((T) topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]);
|
||||
partial.insert(topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]);
|
||||
}
|
||||
|
||||
index = block_id * MAX_K;
|
||||
@ -85,9 +86,10 @@ template <typename T, int MAX_K, int THREADBLOCK_SIZE>
|
||||
__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int* __restrict topk_tmp_id_buf,
|
||||
const T* __restrict topk_tmp_val_buf, int* __restrict id_buf, T* __restrict val_buf)
|
||||
{
|
||||
int thread_id = threadIdx.x;
|
||||
int block_id = blockIdx.x;
|
||||
const int thread_id = threadIdx.x;
|
||||
const int block_id = blockIdx.x;
|
||||
TopK<T, MAX_K> partial;
|
||||
|
||||
if (thread_id == 0)
|
||||
{
|
||||
for (int i = 0; i < MAX_K; ++i)
|
||||
@ -99,7 +101,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int*
|
||||
int index = block_id * MAX_K * MAX_K;
|
||||
for (int i = 0; i < MAX_K * MAX_K; i++)
|
||||
{
|
||||
partial.insert((T) topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]);
|
||||
partial.insert(topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]);
|
||||
}
|
||||
|
||||
index = block_id * MAX_K;
|
||||
@ -112,34 +114,37 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int*
|
||||
}
|
||||
|
||||
template <typename T, int MAX_K2, int THREADBLOCK_SIZE>
|
||||
__launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
void batch_topk_kernel(const int* __restrict x, const T* __restrict y, int** output_ids_ptr, float* __restrict v,
|
||||
float* output_log_probs, const FinishedState* finished, const int* sequence_lengths, BeamHypotheses beam_hyps,
|
||||
const int V, const int K, const int vocab_size, const float* length_penalties, const float* diversity_rates)
|
||||
__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* __restrict topk_tmp_id_buf,
|
||||
const T* __restrict topk_tmp_val_buf, float* __restrict cum_log_probs, const FinishedState* finished,
|
||||
BeamHypotheses beam_hyps, const int candidate_size)
|
||||
{
|
||||
int thread_id = threadIdx.x;
|
||||
int vector_id = blockIdx.x;
|
||||
const int thread_id = threadIdx.x;
|
||||
const int vector_id = blockIdx.x;
|
||||
const int K{beam_hyps.beam_width};
|
||||
const int vocab_size{beam_hyps.vocab_size};
|
||||
const int global_batch_idx{beam_hyps.ite * beam_hyps.local_batch_size + vector_id};
|
||||
const T diversity_rate{diversity_rates[global_batch_idx]};
|
||||
const float length_penalty{length_penalties[global_batch_idx]};
|
||||
const T MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
// reposition x, y to data for the current vector
|
||||
x += vector_id * V;
|
||||
y += vector_id * V;
|
||||
|
||||
extern __shared__ char buf_s_[]; // intermediate result
|
||||
T* buf_s = reinterpret_cast<T*>(buf_s_);
|
||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
const float length_penalty{beam_hyps.length_penalties[global_batch_idx]};
|
||||
const int early_stopping{beam_hyps.early_stoppings[global_batch_idx]};
|
||||
const int* sequence_lengths{beam_hyps.sequence_lengths_src};
|
||||
const T diversity_rate{beam_hyps.diversity_rates[global_batch_idx]};
|
||||
float* output_log_probs{beam_hyps.log_probs_src};
|
||||
|
||||
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||
using BlockReduce = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>;
|
||||
|
||||
extern __shared__ char buf_s_[]; // intermediate result
|
||||
T* buf_s = reinterpret_cast<T*>(buf_s_);
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ int selected_beams;
|
||||
__shared__ float old_cum_log_probs[MAX_K2];
|
||||
__shared__ char cta_topk_store[MAX_K2 * sizeof(cub_kvp)];
|
||||
auto* cta_topk = reinterpret_cast<cub_kvp*>(cta_topk_store);
|
||||
__shared__ cub_kvp cta_topk[MAX_K2];
|
||||
__shared__ int selected_beams;
|
||||
__shared__ int thread_requiring_update;
|
||||
|
||||
// reposition topk_tmp_id_buf, topk_tmp_val_buf to data for the current vector
|
||||
topk_tmp_id_buf += vector_id * candidate_size;
|
||||
topk_tmp_val_buf += vector_id * candidate_size;
|
||||
|
||||
if (thread_id == 0)
|
||||
{
|
||||
@ -147,47 +152,56 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
}
|
||||
if (thread_id < K)
|
||||
{
|
||||
old_cum_log_probs[thread_id] = v[vector_id * K + thread_id];
|
||||
old_cum_log_probs[thread_id] = cum_log_probs[vector_id * K + thread_id];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (beam_hyps.num_beams != nullptr)
|
||||
{
|
||||
// initialize worst_score if this batch has no finished beam
|
||||
if (beam_hyps.num_beams[global_batch_idx] == 0 && thread_id == 0)
|
||||
{
|
||||
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
|
||||
}
|
||||
// return if this batch has enough finished beams
|
||||
else if (beam_hyps.num_beams[global_batch_idx] == K)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Get top 2K tokens from candidates
cub::ArgMax arg_max;
cub_kvp partial_topk{V - 1, -MAX_T_VAL};
cub_kvp partial_topk{candidate_size - 1, -MAX_T_VAL};

for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE)
for (int elem_id = thread_id; elem_id < candidate_size; elem_id += THREADBLOCK_SIZE)
{
int i = beam_hyps.num_beams == nullptr ? elem_id % K : elem_id / 2 / K;
T elem = length_penalty == 0.0f
? y[elem_id]
: apply_length_penalty(y[elem_id],
finished[vector_id * K + i].isFinished() ? sequence_lengths[vector_id * K + i]
: sequence_lengths[vector_id * K + i] + 1,
length_penalty);
T elem = topk_tmp_val_buf[elem_id];
if (length_penalty > 0.0f)
{
int length = sequence_lengths[vector_id * K + i];
if (early_stopping == 0)
{
// Use generated_length (rather than sequence_length) to compute length_penalty
// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L957
// But this branch will cause CI error in
// "C++ Tests (GPT) on A30", "C++ Tests (GPT-J) on H100_PCIe", "H100_PCIe-accuracy-0"
length -= beam_hyps.input_lengths[global_batch_idx];
}
const int pad_if_not_finish = finished[vector_id * K + i].isFinished() ? 0 : 1;
elem = apply_length_penalty(elem, length + pad_if_not_finish, length_penalty);
}
elem += diversity_rate * (T) i;
int elem_idx = elem_id; // x[elem_id];
cub_kvp new_elem{elem_idx, elem};
cub_kvp new_elem{elem_id, elem};
partial_topk = arg_max(partial_topk, new_elem);
buf_s[elem_id] = elem;
}
__syncthreads();
|
||||
__shared__ int thread_requiring_update;
|
||||
|
||||
for (int i = 0; i < 2 * K; ++i)
|
||||
{
|
||||
cub_kvp total_topk = BlockReduce(temp_storage).Reduce(partial_topk, arg_max);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
cta_topk[i] = total_topk;
|
||||
@ -196,13 +210,13 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Only 1 thread needs to update the old partial before the next block reduce. We don't need to do this update
|
||||
// on the last iteration.
|
||||
// Only one thread needs to update the old partial before the next block reduce.
|
||||
// No need to do this in the last iteration.
|
||||
if (thread_id == thread_requiring_update && i < (2 * K - 1))
|
||||
{
|
||||
partial_topk.key = V - 1;
|
||||
partial_topk.key = candidate_size - 1;
|
||||
partial_topk.value = -MAX_T_VAL;
|
||||
for (int tid = thread_id; tid < V; tid += THREADBLOCK_SIZE)
|
||||
for (int tid = thread_id; tid < candidate_size; tid += THREADBLOCK_SIZE)
|
||||
{
|
||||
cub_kvp new_elem{tid, buf_s[tid]};
|
||||
partial_topk = arg_max(partial_topk, new_elem);
|
||||
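The selection loop above, and the fast stage-1 kernel later in this diff, both rely on the same iterative pattern: one block-wide arg-max per selected candidate, after which only the thread that owned the winner rescans its strided slice so the next round finds the runner-up. A standalone sketch of that pattern follows; the function and variable names are illustrative and it is not part of the commit.

#include <cfloat>
#include <cub/cub.cuh>

// `smem_vals` must live in shared memory and be fully written (and synced) by the caller;
// `out` is only written and read by thread 0.
template <int THREADBLOCK_SIZE>
__device__ void block_top_k(float* smem_vals, int n, int k, cub::KeyValuePair<int, float>* out)
{
    using cub_kvp = cub::KeyValuePair<int, float>;
    using BlockReduce = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ int thread_requiring_update;

    cub::ArgMax arg_max;
    cub_kvp partial{n - 1, -FLT_MAX};
    for (int i = threadIdx.x; i < n; i += THREADBLOCK_SIZE)
    {
        partial = arg_max(partial, cub_kvp{i, smem_vals[i]});
    }
    for (int round = 0; round < k; ++round)
    {
        cub_kvp best = BlockReduce(temp_storage).Reduce(partial, arg_max); // valid on thread 0
        if (threadIdx.x == 0)
        {
            out[round] = best;
            smem_vals[best.key] = -FLT_MAX;                         // invalidate the winner
            thread_requiring_update = best.key % THREADBLOCK_SIZE;  // only its owner must rescan
        }
        __syncthreads();
        if (threadIdx.x == thread_requiring_update && round < k - 1)
        {
            partial = cub_kvp{n - 1, -FLT_MAX};
            for (int i = threadIdx.x; i < n; i += THREADBLOCK_SIZE)
            {
                partial = arg_max(partial, cub_kvp{i, smem_vals[i]});
            }
        }
        __syncthreads();
    }
}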
@ -212,104 +226,101 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
|
||||
if (thread_id == 0)
|
||||
{
|
||||
v += vector_id * K;
|
||||
cum_log_probs += vector_id * K;
|
||||
|
||||
for (int i = 0; i < 2 * K; ++i)
|
||||
{
|
||||
const int current_key = cta_topk[i].key;
|
||||
const T current_value = cta_topk[i].value;
|
||||
if (i < K && beam_hyps.num_beams != nullptr && x[current_key] % vocab_size == beam_hyps.end_ids[vector_id])
|
||||
if (i < K && beam_hyps.num_beams != nullptr
|
||||
&& topk_tmp_id_buf[current_key] % vocab_size == beam_hyps.end_ids[vector_id])
|
||||
{
|
||||
// if beam_token does not belong to top num_beams tokens, it should not
|
||||
// be added. Refer to
|
||||
// https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/generation_beam_search.py#L257
|
||||
{
|
||||
const float normed_score = (float) current_value;
|
||||
const int num_beam = beam_hyps.num_beams[global_batch_idx];
|
||||
int beam_idx = num_beam;
|
||||
// If there are beam_width finished sentences, check that the score of
|
||||
// selected candidate is higher than min_normed_score or not. If
|
||||
// current score is better, replace worst one and update the
|
||||
// min_normed_score.
|
||||
if (num_beam == K)
|
||||
{
|
||||
if (normed_score < beam_hyps.min_normed_scores[global_batch_idx])
|
||||
{
|
||||
// end the tracing and exit this for loop
|
||||
selected_beams = K;
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
// find the beam index whose score == min_normed_score and erase it.
|
||||
for (int j = 0; j < K; j++)
|
||||
{
|
||||
if (beam_hyps.normed_scores[global_batch_idx * (K * 2) + j]
|
||||
== beam_hyps.min_normed_scores[global_batch_idx])
|
||||
{
|
||||
beam_idx = j;
|
||||
beam_hyps.num_beams[global_batch_idx]--;
|
||||
// Add beam only if beam_token belongs to top K tokens
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L272
|
||||
const float normed_score = (float) current_value;
|
||||
const int num_beam = beam_hyps.num_beams[global_batch_idx];
|
||||
int beam_idx = num_beam;
|
||||
|
||||
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
|
||||
beam_hyps.normed_scores[global_batch_idx * (K * 2) + j] = normed_score;
|
||||
for (int l = 0; l < K; l++)
|
||||
{
|
||||
beam_hyps.min_normed_scores[global_batch_idx]
|
||||
= min(beam_hyps.min_normed_scores[global_batch_idx],
|
||||
beam_hyps.normed_scores[global_batch_idx * (K * 2) + l]);
|
||||
}
|
||||
break;
|
||||
// There are already K beams
|
||||
if (num_beam == K)
|
||||
{
|
||||
// The current score is worse than the worst one in beams
|
||||
if (normed_score < beam_hyps.min_normed_scores[global_batch_idx])
|
||||
{
|
||||
selected_beams = K;
|
||||
break;
|
||||
}
|
||||
// The current score is better than the worst one in beams
|
||||
else
|
||||
{
|
||||
// Find the beam index whose score == min_normed_score and erase it.
|
||||
for (int j = 0; j < K; j++)
|
||||
{
|
||||
if (beam_hyps.normed_scores[global_batch_idx * (K * 2) + j]
|
||||
== beam_hyps.min_normed_scores[global_batch_idx])
|
||||
{
|
||||
beam_idx = j;
|
||||
beam_hyps.num_beams[global_batch_idx]--;
|
||||
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
|
||||
beam_hyps.normed_scores[global_batch_idx * (K * 2) + j] = normed_score;
|
||||
for (int l = 0; l < K; l++)
|
||||
{
|
||||
beam_hyps.min_normed_scores[global_batch_idx]
|
||||
= min(beam_hyps.min_normed_scores[global_batch_idx],
|
||||
beam_hyps.normed_scores[global_batch_idx * (K * 2) + l]);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
const int tgt_id_offset
|
||||
= ((vector_id + beam_hyps.ite * beam_hyps.local_batch_size) * (K * 2) + beam_idx)
|
||||
* (beam_hyps.max_seq_len);
|
||||
|
||||
int prev_id = (x[current_key] / vocab_size) % K;
|
||||
const int current_step{sequence_lengths[vector_id * K + prev_id]};
|
||||
beam_hyps.output_ids_tgt[tgt_id_offset + current_step] = beam_hyps.end_ids[vector_id];
|
||||
|
||||
if (beam_hyps.log_probs != nullptr)
|
||||
{
|
||||
beam_hyps.log_probs[tgt_id_offset + current_step]
|
||||
= (float) y[current_key] - old_cum_log_probs[(x[current_key] / vocab_size) % K];
|
||||
}
|
||||
|
||||
for (int j = current_step - 1; j >= 0; j--)
|
||||
{
|
||||
const int src_idx = j * beam_hyps.batch_size * K
|
||||
+ beam_hyps.ite * beam_hyps.local_batch_size * K + vector_id * K + prev_id;
|
||||
|
||||
beam_hyps.output_ids_tgt[tgt_id_offset + j]
|
||||
= beam_hyps.output_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j];
|
||||
if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr)
|
||||
{
|
||||
beam_hyps.log_probs[tgt_id_offset + j] = beam_hyps.log_probs_src[src_idx];
|
||||
}
|
||||
prev_id = beam_hyps.parent_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j];
|
||||
}
|
||||
const int tgt_beam_idx = global_batch_idx * (K * 2) + beam_idx;
|
||||
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = current_step;
|
||||
beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
|
||||
beam_hyps.min_normed_scores[global_batch_idx]
|
||||
= min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]);
|
||||
|
||||
beam_hyps.num_beams[global_batch_idx]++;
|
||||
beam_hyps.cum_log_probs[tgt_beam_idx] = (float) y[current_key];
|
||||
}
|
||||
const int tgt_id_offset
|
||||
= ((vector_id + beam_hyps.ite * beam_hyps.local_batch_size) * (K * 2) + beam_idx)
|
||||
* (beam_hyps.max_seq_len);
|
||||
int prev_id = (topk_tmp_id_buf[current_key] / vocab_size) % K;
|
||||
const int current_step{sequence_lengths[vector_id * K + prev_id]};
|
||||
beam_hyps.output_ids_tgt[tgt_id_offset + current_step] = beam_hyps.end_ids[vector_id];
|
||||
|
||||
if (beam_hyps.log_probs != nullptr)
|
||||
{
|
||||
beam_hyps.log_probs[tgt_id_offset + current_step] = (float) topk_tmp_val_buf[current_key]
|
||||
- old_cum_log_probs[(topk_tmp_id_buf[current_key] / vocab_size) % K];
|
||||
}
|
||||
|
||||
for (int j = current_step - 1; j >= 0; j--)
|
||||
{
|
||||
const int src_idx = j * beam_hyps.batch_size * K + beam_hyps.ite * beam_hyps.local_batch_size * K
|
||||
+ vector_id * K + prev_id;
|
||||
|
||||
beam_hyps.output_ids_tgt[tgt_id_offset + j]
|
||||
= beam_hyps.output_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j];
|
||||
if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr)
|
||||
{
|
||||
beam_hyps.log_probs[tgt_id_offset + j] = beam_hyps.log_probs_src[src_idx];
|
||||
}
|
||||
prev_id = beam_hyps.parent_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j];
|
||||
}
|
||||
const int tgt_beam_idx = global_batch_idx * (K * 2) + beam_idx;
|
||||
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = current_step;
|
||||
beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
|
||||
beam_hyps.min_normed_scores[global_batch_idx]
|
||||
= min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]);
|
||||
|
||||
beam_hyps.num_beams[global_batch_idx]++;
|
||||
cum_log_probs[tgt_beam_idx] = (float) topk_tmp_val_buf[current_key];
|
||||
}
|
||||
else if ((beam_hyps.num_beams != nullptr && i < 2 * K) || (beam_hyps.num_beams == nullptr && i < K))
|
||||
else if (beam_hyps.num_beams != nullptr || beam_hyps.num_beams == nullptr && i < K)
|
||||
{
|
||||
const int current_step{sequence_lengths[vector_id * K + selected_beams]};
|
||||
output_ids_ptr[vector_id][selected_beams * beam_hyps.max_seq_len + current_step] = x[current_key];
|
||||
beam_hyps.output_ids_tgt_ptr[vector_id][selected_beams * beam_hyps.max_seq_len + current_step]
|
||||
= topk_tmp_id_buf[current_key];
|
||||
if (output_log_probs != nullptr)
|
||||
{
|
||||
output_log_probs[current_step * beam_hyps.batch_size * K + vector_id * K + selected_beams]
|
||||
= (float) y[current_key] - old_cum_log_probs[(x[current_key] / vocab_size) % K];
|
||||
= (float) topk_tmp_val_buf[current_key]
|
||||
- old_cum_log_probs[(topk_tmp_id_buf[current_key] / vocab_size) % K];
|
||||
}
|
||||
v[selected_beams] = (float) y[current_key];
|
||||
cum_log_probs[selected_beams] = (float) topk_tmp_val_buf[current_key];
|
||||
selected_beams++;
|
||||
}
|
||||
__syncthreads();
|
||||
@ -319,15 +330,46 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
}
|
||||
}
|
||||
}
|
||||
// update beam_hyps.is_done for each batch
|
||||
if (threadIdx.x == 0 && beam_hyps.num_beams != nullptr)
|
||||
{
|
||||
// not enough beams
|
||||
if (beam_hyps.num_beams[blockIdx.x] < K)
|
||||
{
|
||||
beam_hyps.is_done[blockIdx.x] = false;
|
||||
return;
|
||||
}
|
||||
else if (beam_hyps.early_stopping)
|
||||
float highest_attainable_score = 0.0f;
|
||||
switch (early_stopping)
|
||||
{
|
||||
case 1:
|
||||
// enough beams with early stopping
|
||||
beam_hyps.is_done[blockIdx.x] = true;
|
||||
return;
|
||||
case 0:
|
||||
// enough beams without early stopping
|
||||
highest_attainable_score = static_cast<float>(apply_length_penalty(cum_log_probs[0],
|
||||
sequence_lengths[vector_id * K] - beam_hyps.input_lengths[global_batch_idx], length_penalty));
|
||||
beam_hyps.is_done[blockIdx.x] = beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score;
|
||||
return;
|
||||
default:
|
||||
// early_stopping == "never" in HF, i.e., compute the best possible score depending on `length_penalty`
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L990
|
||||
if (length_penalty > 0.0f)
|
||||
{
|
||||
highest_attainable_score = static_cast<float>(apply_length_penalty(cum_log_probs[0],
|
||||
beam_hyps.max_seq_len - beam_hyps.input_lengths[global_batch_idx], length_penalty));
|
||||
beam_hyps.is_done[blockIdx.x]
|
||||
= beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score;
|
||||
}
|
||||
else
|
||||
{
|
||||
highest_attainable_score = static_cast<float>(apply_length_penalty(cum_log_probs[0],
|
||||
sequence_lengths[vector_id * K] - beam_hyps.input_lengths[global_batch_idx], length_penalty));
|
||||
beam_hyps.is_done[blockIdx.x]
|
||||
= beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score;
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
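Condensing the branches above into one place may help: a batch is done only once it has beam_width finished beams, and then either immediately (early_stopping == 1) or when the worst finished score already beats the best score a live beam can still reach. The helper below is an illustrative restatement with assumed names, not code from the commit.

// early_stopping == 1     -> done as soon as K beams have finished.
// early_stopping == 0     -> done once the worst finished score is no worse than the best live
//                            beam scored at its current generated length.
// anything else ("never") -> same comparison, but the live beam is bounded by the best score it
//                            could still reach (max_seq_len when length_penalty > 0).
bool beam_search_batch_done(int num_finished_beams, int beam_width, int early_stopping,
    float worst_finished_score, float best_live_score_bound)
{
    if (num_finished_beams < beam_width)
    {
        return false;
    }
    if (early_stopping == 1)
    {
        return true;
    }
    return worst_finished_score >= best_live_score_bound;
}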
|
||||
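The kernels in this file carry their softmax statistics as an (m, d) pair, where m is the running maximum and d the running sum of exp(x - m), merged with the usual online-softmax rule. The actual MD and reduce_md_op definitions live outside the hunks shown here, so the following is only a sketch of that merge under assumed names.

struct MDPair
{
    float m; // running maximum
    float d; // running sum of exp(x - m)
};

__device__ __forceinline__ MDPair merge_md(MDPair a, MDPair b)
{
    MDPair const& big = (a.m > b.m) ? a : b;
    MDPair const& small = (a.m > b.m) ? b : a;
    // Rescale the smaller partial sum onto the larger maximum before adding.
    return MDPair{big.m, big.d + small.d * __expf(small.m - big.m)};
}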
@ -366,24 +408,22 @@ __device__ __forceinline__ TopKMD<T, MAX_K> reduce_topk_md_op(const TopKMD<T, MA
|
||||
}
|
||||
|
||||
template <typename T, int ITEMS_PER_THREAD, int MAX_K, int THREADBLOCK_SIZE>
|
||||
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_kernel(const T* __restrict x,
|
||||
const T* __restrict b, const float* __restrict c, const FinishedState* __restrict finished, int* __restrict z,
|
||||
T* __restrict v, int V, int K, const int* __restrict end_ids)
|
||||
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_kernel(const T* __restrict log_probs,
|
||||
const T* __restrict bias, const float* __restrict cum_log_probs, const FinishedState* __restrict finished,
|
||||
int* __restrict topk_tmp_id_buf, T* __restrict topk_tmp_val_buf, int vocab_size, int K,
|
||||
const int* __restrict end_ids)
|
||||
{
|
||||
int thread_id = threadIdx.x;
|
||||
int vector_id = blockIdx.x;
|
||||
|
||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
// reposition y to data for the current vector
|
||||
x += vector_id * V;
|
||||
const int thread_id = threadIdx.x;
|
||||
const int vector_id = blockIdx.x;
|
||||
const T MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
typedef cub::BlockReduce<TopKMD<float, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
// reposition log_probs to data for the current vector
|
||||
log_probs += vector_id * vocab_size;
|
||||
|
||||
TopKMD<float, MAX_K> partial;
|
||||
bool finish = finished[vector_id].isFinished();
|
||||
for (int i = 0; i < MAX_K; ++i)
|
||||
{
|
||||
partial.topk.p[i] = -1;
|
||||
@ -392,22 +432,22 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker
|
||||
partial.md.m = -MAX_T_VAL;
|
||||
partial.md.d = 0.0F;
|
||||
|
||||
if (finish)
|
||||
if (finished[vector_id].isFinished())
|
||||
{
|
||||
for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE)
|
||||
for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE)
|
||||
{
|
||||
float elem = (elem_id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL;
|
||||
MD new_elem{elem, 1.0F};
|
||||
partial.md = reduce_md_op(partial.md, new_elem);
|
||||
partial.topk.insert(elem, elem_id);
|
||||
// if (elem_id > THREADBLOCK_SIZE * MAX_K && (elem_id == E)) break;
|
||||
// if (elem_id > THREADBLOCK_SIZE * MAX_K && elem_id == E) break;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE)
|
||||
for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE)
|
||||
{
|
||||
float elem = x[elem_id] + b[elem_id];
|
||||
float elem = log_probs[elem_id] + bias[elem_id];
|
||||
MD new_elem{elem, 1.0F};
|
||||
partial.md = reduce_md_op(partial.md, new_elem);
|
||||
partial.topk.insert(elem, elem_id);
|
||||
@ -418,9 +458,9 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker
|
||||
|
||||
if (thread_id == 0)
|
||||
{
|
||||
z += vector_id * K;
|
||||
v += vector_id * K;
|
||||
c += vector_id;
|
||||
topk_tmp_id_buf += vector_id * K;
|
||||
topk_tmp_val_buf += vector_id * K;
|
||||
cum_log_probs += vector_id;
|
||||
|
||||
// float d_total_inverse = __fdividef(1.0F, total.md.d);
|
||||
float d_total_log = logf(total.md.d);
|
||||
@ -430,48 +470,43 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker
|
||||
float val = total.topk.u[i] - total.md.m - d_total_log;
|
||||
if (i < K)
|
||||
{
|
||||
z[i] = total.topk.p[i] + vector_id * V; // trtllm needs absolute id
|
||||
v[i] = val + c[0];
|
||||
topk_tmp_id_buf[i] = total.topk.p[i] + vector_id * vocab_size; // trtllm needs absolute id
|
||||
topk_tmp_val_buf[i] = val + cum_log_probs[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int ITEMS_PER_THREAD, int MAX_K2, int THREADBLOCK_SIZE>
|
||||
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__
|
||||
void beam_online_softmax_topk_stage1_kernel_base(const T* __restrict x, const T* __restrict b,
|
||||
const FinishedState* __restrict finished, float* __restrict t, int V, int K, const int* __restrict end_ids)
|
||||
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_stage1_kernel_base(
|
||||
const T* __restrict log_probs, const T* __restrict bias, const FinishedState* __restrict finished,
|
||||
float* __restrict tmp_buffer, int vocab_size, int K, const int* __restrict end_ids)
|
||||
{
|
||||
int thread_id = threadIdx.x;
|
||||
int vector_id = blockIdx.x; // batch beam index.
|
||||
|
||||
const int thread_id = threadIdx.x;
|
||||
const int vector_id = blockIdx.x;
|
||||
const T MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
|
||||
const int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2;
|
||||
|
||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
// one will have multiple sections per V
|
||||
const int v_local = (V + gridDim.y - 1) / gridDim.y;
|
||||
// one threadblock has multiple sections per vocab_size
|
||||
const int v_local = (vocab_size + gridDim.y - 1) / gridDim.y;
|
||||
const int section_start = v_local * blockIdx.y;
|
||||
int section_end = section_start + v_local;
|
||||
section_end = (section_end > V) ? V : section_end;
|
||||
const int section_end = std::min(section_start + v_local, vocab_size);
|
||||
|
||||
// reposition x to data for the current vector
|
||||
x += vector_id * V;
|
||||
#if TOPK_FP16_STORAGE == 1
|
||||
typedef cub::BlockReduce<TopKMD<__half, MAX_K2>, THREADBLOCK_SIZE> BlockReduce;
|
||||
#else
|
||||
typedef cub::BlockReduce<TopKMD<T, MAX_K2>, THREADBLOCK_SIZE> BlockReduce;
|
||||
#endif
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ float buf_s[PACKED_TOP_KMD_SIZE]; // save intermediate result
|
||||
__shared__ float buf_s[PACKED_TOP_KMD_SIZE];
|
||||
|
||||
// reposition log_probs to data for the current vector
|
||||
log_probs += vector_id * vocab_size;
|
||||
|
||||
#if TOPK_FP16_STORAGE == 1
|
||||
TopKMD<__half, MAX_K2> partial;
|
||||
#else
|
||||
TopKMD<T, MAX_K2> partial;
|
||||
#endif
|
||||
bool finish = finished[vector_id].isFinished();
|
||||
for (int i = 0; i < MAX_K2; ++i)
|
||||
{
|
||||
partial.topk.p[i] = -1;
|
||||
@ -480,7 +515,7 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__
|
||||
partial.md.m = -MAX_T_VAL;
|
||||
partial.md.d = 0.0F;
|
||||
|
||||
if (finish)
|
||||
if (finished[vector_id].isFinished())
|
||||
{
|
||||
#pragma unroll 1
|
||||
for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE)
|
||||
@ -496,8 +531,8 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__
|
||||
#pragma unroll 1
|
||||
for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE)
|
||||
{
|
||||
T bias = b == nullptr ? (T) 0.0f : b[elem_id]; // gpt-2 does not use bias
|
||||
T elem = x[elem_id] + bias;
|
||||
T b = bias == nullptr ? (T) 0.0f : bias[elem_id];
|
||||
T elem = log_probs[elem_id] + b;
|
||||
MD new_elem{elem, 1.0F};
|
||||
partial.md = reduce_md_op(partial.md, new_elem);
|
||||
partial.topk.insert(elem, elem_id);
|
||||
@ -514,7 +549,7 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__
|
||||
{
|
||||
for (int i = 0; i < 2 * K; i++)
|
||||
{
|
||||
reinterpret_cast<int*>(buf_s)[i] = total.topk.p[i] + vector_id * V; // trtllm needs absolute id
|
||||
reinterpret_cast<int*>(buf_s)[i] = total.topk.p[i] + vector_id * vocab_size; // trtllm needs absolute id
|
||||
buf_s[MAX_K2 + i] = total.topk.u[i];
|
||||
}
|
||||
buf_s[2 * MAX_K2] = total.md.d;
|
||||
@ -523,38 +558,25 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__
|
||||
__syncthreads();
|
||||
for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; elem_id += THREADBLOCK_SIZE)
|
||||
{
|
||||
t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] = buf_s[elem_id];
|
||||
tmp_buffer[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id]
|
||||
= buf_s[elem_id];
|
||||
}
|
||||
}
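Each (batch-beam, vocab-chunk) block above flushes one fixed-size record of PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2 floats, with candidate ids written through an int view of the same storage. Conceptually the record is laid out like the struct below; the struct and its names are illustrative only and assume 4-byte int and float.

template <int MAX_K2>
struct PackedTopKMDRecord
{
    int ids[MAX_K2];    // absolute token ids of the chunk-local top candidates
    float vals[MAX_K2]; // their corresponding logits
    float d;            // partial softmax denominator for this vocab chunk
    float m;            // partial running maximum for this vocab chunk
};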
|
||||
|
||||
template <typename T, int ITEMS_PER_THREAD, int MAX_K2, int THREADBLOCK_SIZE>
|
||||
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_stage1_kernel_fast(
|
||||
const T* __restrict x, const T* __restrict b, const FinishedState* __restrict finished, float* __restrict t, int V,
|
||||
int K, const int* __restrict end_ids, const int v_local)
|
||||
const T* __restrict log_probs, const T* __restrict bias, const FinishedState* __restrict finished,
|
||||
float* __restrict t, int vocab_size, int K, const int* __restrict end_ids, const int v_local)
|
||||
{
|
||||
extern __shared__ char buf_smem_logprobs_[];
|
||||
T* buf_smem_logprobs = reinterpret_cast<T*>(buf_smem_logprobs_);
|
||||
|
||||
int thread_id = threadIdx.x;
|
||||
int vector_id = blockIdx.x; // batch beam index.
|
||||
|
||||
const int thread_id = threadIdx.x;
|
||||
const int vector_id = blockIdx.x;
|
||||
const T MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
|
||||
const int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2;
|
||||
|
||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
// reposition x to data for the current vector
|
||||
x += vector_id * V;
|
||||
|
||||
// one will have multiple sections per V
|
||||
// one threadblock has multiple sections per vocab_size
|
||||
const int section_start = v_local * blockIdx.y;
|
||||
int section_end = section_start + v_local;
|
||||
section_end = (section_end > V) ? V : section_end;
|
||||
const int section_end = std::min(section_start + v_local, vocab_size);
|
||||
const int valid_smem_length = section_end - section_start;
|
||||
|
||||
bool finish = finished[vector_id].isFinished();
|
||||
MD partial_md{-MAX_T_VAL, 0.0f};
|
||||
|
||||
#if TOPK_FP16_STORAGE == 1
|
||||
using cub_kvp = cub::KeyValuePair<int, __half>;
|
||||
using BlockReduceTopK = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>;
|
||||
@ -565,10 +587,25 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
|
||||
|
||||
using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;
|
||||
|
||||
cub::ArgMax arg_max;
|
||||
cub_kvp partial_topk{V - 1, -MAX_T_VAL};
|
||||
extern __shared__ char buf_smem_logprobs_[];
|
||||
T* buf_smem_logprobs = reinterpret_cast<T*>(buf_smem_logprobs_);
|
||||
|
||||
if (finish)
|
||||
__shared__ union
|
||||
{
|
||||
typename BlockReduceMD::TempStorage md_smem;
|
||||
typename BlockReduceTopK::TempStorage topk_smem;
|
||||
} temp_storage;
|
||||
|
||||
__shared__ float buf_s[PACKED_TOP_KMD_SIZE];
|
||||
__shared__ int thread_requiring_update;
|
||||
|
||||
// reposition log_probs to data for the current vector
|
||||
log_probs += vector_id * vocab_size;
|
||||
|
||||
cub::ArgMax arg_max;
|
||||
cub_kvp partial_topk{vocab_size - 1, -MAX_T_VAL};
|
||||
MD partial_md{-MAX_T_VAL, 0.0f};
|
||||
if (finished[vector_id].isFinished())
|
||||
{
|
||||
#pragma unroll 1
|
||||
for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE)
|
||||
@ -589,8 +626,8 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
|
||||
#pragma unroll 1
|
||||
for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE)
|
||||
{
|
||||
T bias = b == nullptr ? (T) 0.0f : b[elem_id]; // gpt-2 does not use bias
|
||||
T elem = x[elem_id] + bias;
|
||||
T b = bias == nullptr ? (T) 0.0f : bias[elem_id];
|
||||
T elem = log_probs[elem_id] + b;
|
||||
MD new_elem_md{elem, 1.0F};
|
||||
partial_md = reduce_md_op(partial_md, new_elem_md);
|
||||
|
||||
@ -600,18 +637,8 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
|
||||
buf_smem_logprobs[smem_index] = elem;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
__shared__ union
|
||||
{
|
||||
typename BlockReduceMD::TempStorage md_smem;
|
||||
typename BlockReduceTopK::TempStorage topk_smem;
|
||||
} temp_storage;
|
||||
|
||||
__shared__ float buf_s[PACKED_TOP_KMD_SIZE]; // save intermediate result
|
||||
__shared__ int thread_requiring_update;
|
||||
|
||||
for (int i = 0; i < 2 * K; ++i)
|
||||
{
|
||||
cub_kvp total_topk = BlockReduceTopK(temp_storage.topk_smem).Reduce(partial_topk, arg_max);
|
||||
@ -619,18 +646,18 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
reinterpret_cast<int*>(buf_s)[i]
|
||||
= section_start + total_topk.key + vector_id * V; // trtllm needs absolute id
|
||||
= section_start + total_topk.key + vector_id * vocab_size; // trtllm needs absolute id
|
||||
buf_s[MAX_K2 + i] = total_topk.value;
|
||||
buf_smem_logprobs[total_topk.key] = -MAX_T_VAL;
|
||||
thread_requiring_update = total_topk.key % THREADBLOCK_SIZE;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Only 1 thread needs to update the old partial before the next block reduce. We don't need to do this update
|
||||
// on the last iteration.
|
||||
// Only one thread needs to update the old partial before the next block reduce.
|
||||
// No need to do this in the last iteration.
|
||||
if (thread_id == thread_requiring_update && i < (2 * K - 1))
|
||||
{
|
||||
partial_topk.key = V - 1;
|
||||
partial_topk.key = vocab_size - 1;
|
||||
partial_topk.value = -MAX_T_VAL;
|
||||
for (int tid = thread_id; tid < valid_smem_length; tid += THREADBLOCK_SIZE)
|
||||
{
|
||||
@ -649,7 +676,6 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
|
||||
buf_s[2 * MAX_K2 + 1] = total_md.m;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; elem_id += THREADBLOCK_SIZE)
|
||||
{
|
||||
t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] = buf_s[elem_id];
|
||||
@ -657,58 +683,52 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_
|
||||
}
|
||||
|
||||
template <typename T, int MAX_K2, int THREADBLOCK_SIZE>
|
||||
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_stage2_kernel(const float* __restrict x,
|
||||
const float* __restrict c, int* __restrict z, T* __restrict v, int K, int parts_per_beam, const int V)
|
||||
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_stage2_kernel(
|
||||
const float* __restrict temp_storage, const float* __restrict cum_log_probs, int* __restrict ids,
|
||||
T* __restrict vals, int K, int parts_per_beam, const int vocab_size)
|
||||
{
|
||||
const int vector_id = blockIdx.x;
|
||||
const int thread_id = threadIdx.x;
|
||||
const T MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
|
||||
const int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2;
|
||||
|
||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
extern __shared__ char buf_s_[]; // intermediate result
|
||||
float* buf_s = reinterpret_cast<float*>(buf_s_);
|
||||
|
||||
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||
using BlockReduceTopK = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>;
|
||||
using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;
|
||||
|
||||
extern __shared__ char buf_s_[];
|
||||
float* buf_s = reinterpret_cast<float*>(buf_s_);
|
||||
__shared__ cub_kvp buf_smem_kv[MAX_K2];
|
||||
|
||||
__shared__ union
|
||||
{
|
||||
typename BlockReduceTopK::TempStorage topk_smem;
|
||||
typename BlockReduceMD::TempStorage md_smem;
|
||||
|
||||
} temp_storage;
|
||||
} shared_temp_storage;
|
||||
|
||||
temp_storage += vector_id * PACKED_TOP_KMD_SIZE * parts_per_beam;
|
||||
|
||||
cub::ArgMax arg_max;
|
||||
|
||||
x += vector_id * PACKED_TOP_KMD_SIZE * parts_per_beam;
|
||||
|
||||
MD partial_md{-MAX_T_VAL, 0.0f};
|
||||
cub_kvp total_topk{V - 1, -MAX_T_VAL};
|
||||
cub_kvp total_topk{vocab_size - 1, -MAX_T_VAL};
|
||||
|
||||
__shared__ char buf_smem_kv_store[MAX_K2 * sizeof(cub_kvp)];
|
||||
auto* buf_smem_kv = reinterpret_cast<cub_kvp*>(buf_smem_kv_store);
|
||||
|
||||
// load and unpack into registers through smem
|
||||
// Load and unpack into registers through smem
|
||||
for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * parts_per_beam; idx += THREADBLOCK_SIZE)
|
||||
{
|
||||
buf_s[idx] = x[idx];
|
||||
buf_s[idx] = temp_storage[idx];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// find the argmax within each parts_per_beam,
|
||||
// find the topK across all parts_per_beam.
|
||||
|
||||
// Find the argmax within each parts_per_beam
|
||||
// Find the topK across all parts_per_beam
|
||||
for (int k = 0; k < 2 * K; ++k)
|
||||
{
|
||||
cub_kvp partial_topk{V - 1, -MAX_T_VAL};
|
||||
cub_kvp partial_topk{vocab_size - 1, -MAX_T_VAL};
|
||||
// Only threads responsible for a chunk will do the computation
|
||||
if (threadIdx.x < parts_per_beam)
|
||||
{
|
||||
float* b_s = buf_s + threadIdx.x * PACKED_TOP_KMD_SIZE;
|
||||
|
||||
for (int i = 0; i < K; ++i)
|
||||
{
|
||||
int current_index = threadIdx.x * PACKED_TOP_KMD_SIZE + i;
|
||||
@ -718,21 +738,20 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta
|
||||
}
|
||||
}
|
||||
|
||||
cub_kvp total_topk = BlockReduceTopK(temp_storage.topk_smem).Reduce(partial_topk, arg_max);
|
||||
|
||||
cub_kvp total_topk = BlockReduceTopK(shared_temp_storage.topk_smem).Reduce(partial_topk, arg_max);
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
// store kv pairs in shared mem buffer
|
||||
// Store kv pairs in shared mem buffer
|
||||
int temp_offset = total_topk.key;
|
||||
int global_offset = reinterpret_cast<int*>(buf_s)[temp_offset];
|
||||
total_topk.key = global_offset;
|
||||
buf_smem_kv[k] = total_topk;
|
||||
|
||||
// Invalidate the maximum value within the chunk
|
||||
reinterpret_cast<int*>(buf_s)[temp_offset] = V - 1; // id in share memory
|
||||
buf_s[temp_offset + MAX_K2] = -MAX_T_VAL; // value in share memory
|
||||
reinterpret_cast<int*>(buf_s)[temp_offset] = vocab_size - 1; // id in shared memory
buf_s[temp_offset + MAX_K2] = -MAX_T_VAL; // value in shared memory
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
@ -744,18 +763,16 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta
|
||||
partial_md.d = b_s[2 * MAX_K2];
|
||||
partial_md.m = b_s[2 * MAX_K2 + 1];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto reduce_md_func = [](const MD& a, const MD& b) { return reduce_md_op(a, b); };
|
||||
MD total_md = BlockReduceMD(temp_storage.md_smem).Reduce(partial_md, reduce_md_func);
|
||||
|
||||
__syncthreads();
|
||||
MD total_md = BlockReduceMD(shared_temp_storage.md_smem).Reduce(partial_md, reduce_md_func);
|
||||
|
||||
if (thread_id == 0)
|
||||
{
|
||||
z += vector_id * 2 * K;
|
||||
v += vector_id * 2 * K;
|
||||
c += vector_id;
|
||||
|
||||
ids += vector_id * 2 * K;
|
||||
vals += vector_id * 2 * K;
|
||||
cum_log_probs += vector_id;
|
||||
float d_total_log = logf(total_md.d);
|
||||
|
||||
for (int i = 0; i < MAX_K2; ++i)
|
||||
@ -763,8 +780,8 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta
|
||||
float val = (float) buf_smem_kv[i].value - total_md.m - d_total_log;
|
||||
if (i < 2 * K)
|
||||
{
|
||||
z[i] = buf_smem_kv[i].key;
|
||||
v[i] = (float) val + (float) c[0];
|
||||
ids[i] = buf_smem_kv[i].key;
|
||||
vals[i] = (float) val + (float) cum_log_probs[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
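Both stage-2 paths finish the same way: the merged (m, d) statistics convert a raw logit into a log-probability, and the beam's accumulated score is added on top before the final top-2K selection. As an illustrative helper (assumed name, not from the commit):

__device__ __forceinline__ float candidate_score(float logit, float m, float d, float beam_cum_log_prob)
{
    // log softmax(logit) = logit - m - log(sum exp(x - m)) = logit - m - log(d)
    return logit - m - logf(d) + beam_cum_log_prob;
}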
@ -774,9 +791,9 @@ template <typename T, int MAX_K2>
|
||||
void beam_online_softmax_topk_stage2_kernelLauncher(const float* temp_storage, const float* cum_log_probs, int* ids,
|
||||
T* vals, int batch_size, int beam_width, int parts_per_beam, cudaStream_t stream, const int vocab_size)
|
||||
{
// might rewrite beam_online_softmax_topk_stage2_kernel not to depend on
// constant block size in order to reduce compilation time
int smem_stage2_size = parts_per_beam * (2 * MAX_K2 + 2) * sizeof(float);
// TODO: rewrite beam_online_softmax_topk_stage2_kernel to remove dependence
// on constant block size in order to reduce compilation time
const int smem_stage2_size = parts_per_beam * (2 * MAX_K2 + 2) * sizeof(float);
|
||||
if (parts_per_beam <= 32)
|
||||
{
|
||||
@ -803,20 +820,22 @@ void beam_online_softmax_topk_stage2_kernelLauncher(const float* temp_storage, c
|
||||
}
|
||||
|
||||
template <typename T, int MAX_K>
|
||||
void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const FinishedState* finished,
|
||||
const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, int** output_ids_ptr,
|
||||
void* temp_storage, const int temp_storage_size, BeamHypotheses* beam_hyps, const int batch_size,
|
||||
const int beam_width, const int vocab_size, const int* end_ids, const float* diversity_rates,
|
||||
const float* length_penalties, cudaStream_t stream)
|
||||
void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs,
|
||||
void* temp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream)
|
||||
{
|
||||
const int batch_size{beam_hyps.local_batch_size};
|
||||
const int beam_width{beam_hyps.beam_width};
|
||||
const int vocab_size{beam_hyps.vocab_size};
|
||||
const int* end_ids{beam_hyps.end_ids};
|
||||
|
||||
const int items_per_thread = 1;
|
||||
const int block_sz = (MAX_K < 16) ? (MAX_K < 8) ? SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE : 128 : 64;
|
||||
const int block_sz = (MAX_K < 16) ? ((MAX_K < 8) ? SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE : 128) : 64;
|
||||
// const int block_sz = SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE;
|
||||
|
||||
assert(temp_storage_size % 2 == 0);
|
||||
assert(temp_storage_size >= 2 * batch_size * beam_width * beam_width * 2);
|
||||
// Beam search needs the sequence lengths of beams to apply length penalty.
|
||||
assert(length_penalties == nullptr || sequence_lengths != nullptr);
|
||||
assert(beam_hyps.length_penalties == nullptr || beam_hyps.sequence_lengths_src != nullptr);
|
||||
|
||||
const int topk_buf_offset = ceil(batch_size * beam_width * beam_width * 2 / 4.) * 4;
|
||||
int* topk_tmp_id_buf = reinterpret_cast<int*>(temp_storage);
|
||||
@ -928,25 +947,22 @@ void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const Finish
|
||||
// we will not put them into next iteration
|
||||
|
||||
const int candidates = beam_width * beam_width * 2;
|
||||
int smem_size_batch_topk = sizeof(T) * candidates;
|
||||
const int smem_size_batch_topk = sizeof(T) * candidates;
|
||||
if (smem_size_batch_topk >= (48 << 10))
|
||||
{
|
||||
TLLM_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
batch_topk_kernel<T, MAX_K * 2, 32>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_batch_topk));
|
||||
}
|
||||
|
||||
batch_topk_kernel<T, MAX_K * 2, 32><<<batch_size, 32, smem_size_batch_topk, stream>>>(topk_tmp_id_buf,
|
||||
topk_tmp_val_buf, output_ids_ptr, cum_log_probs, output_log_probs, finished, sequence_lengths, *beam_hyps,
|
||||
candidates, beam_width, vocab_size, length_penalties, diversity_rates);
|
||||
batch_topk_kernel<T, MAX_K * 2, 32><<<batch_size, 32, smem_size_batch_topk, stream>>>(
|
||||
topk_tmp_id_buf, topk_tmp_val_buf, cum_log_probs, finished, beam_hyps, candidates);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
#define INSTANTIATE_BEAMSEARCH_K(T, MAX_K) \
|
||||
template void topK_softMax_kernelLauncher<T, MAX_K>(const T* log_probs, const T* bias, \
|
||||
const FinishedState* finished, const int* sequence_lengths, float* cum_log_probs, float* output_log_probs, \
|
||||
int** output_ids_ptr, void* temp_storage, const int temp_storage_size, BeamHypotheses* beam_hyps, \
|
||||
const int batch_size, const int beam_width, const int vocab_size, const int* end_ids, \
|
||||
const float* diversity_rates, const float* length_penalties, cudaStream_t stream);
|
||||
const FinishedState* finished, float* cum_log_probs, void* temp_storage, const int temp_storage_size, \
|
||||
BeamHypotheses& beam_hyps, cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -70,7 +70,7 @@ void splitkGroupedGemm_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std
|
||||
int64_t gemmParamsWorkSpaceSize, void* gemmWorkSpace, int64_t gemmWorkSpaceSize,
|
||||
std::vector<int64_t> splitkBufferOffsets, int splitKSlices, cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
using ElementA = cutlassType;
|
||||
using ElementB = cutlassType;
|
||||
using ElementOutput = float;
|
||||
@ -194,7 +194,7 @@ void splitkGroupedGemm_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std
|
||||
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to run CUTLASS Grouped GEMM kernel.");
|
||||
|
||||
std::free(host_workspace);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <int M1, int N1, int K1, int M2, int N2, int K2>
|
||||
|
||||
@ -113,7 +113,7 @@ __device__ __forceinline__ void apply_scale(void* act, void* act_scale)
|
||||
}
|
||||
|
||||
template <int N, int K, bool EnableZero>
|
||||
__device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* scales, void* zeros)
|
||||
__device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* scales, void* zeros, half alpha)
|
||||
{
|
||||
using Converter = ConverterI4ToF16;
|
||||
static_assert(K % 2 == 0);
|
||||
@ -123,11 +123,11 @@ __device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* sca
|
||||
{
|
||||
ConverterI4ToF16::convert<K>(
|
||||
reinterpret_cast<uint8_t*>(quantized_w) + n * K / 2, reinterpret_cast<half*>(w) + n * K);
|
||||
half2 vec_scale = __half2half2(reinterpret_cast<half*>(scales)[n]);
|
||||
half2 vec_scale = __half2half2(reinterpret_cast<half*>(scales)[n] * alpha);
|
||||
half2 vec_zero = __half2half2(__float2half_rn(0.f));
|
||||
if constexpr (EnableZero)
|
||||
{
|
||||
vec_zero = __half2half2(reinterpret_cast<half*>(zeros)[n]);
|
||||
vec_zero = __half2half2(reinterpret_cast<half*>(zeros)[n] * alpha);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k = 0; k < VecK; ++k)
|
||||
@ -186,7 +186,7 @@ __device__ __forceinline__ T warp_reduce_sum(T& val)
|
||||
}
|
||||
|
||||
template <int CtaM, int CtaN, int Threads, bool EnableBias>
|
||||
__device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc, void* bias, float alpha)
|
||||
__device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc, void* bias)
|
||||
{
|
||||
static constexpr int WarpSize = 32;
|
||||
static constexpr int WarpNum = Threads / WarpSize;
|
||||
@ -224,7 +224,7 @@ __device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc,
|
||||
{
|
||||
val += shmem[jj * AlignShmemSize + ii];
|
||||
}
|
||||
reinterpret_cast<half*>(out)[m * stride + n] = __float2half_rn(alpha * val + v_bias);
|
||||
reinterpret_cast<half*>(out)[m * stride + n] = __float2half_rn(val + v_bias);
|
||||
}
|
||||
}
|
||||
|
||||
@ -348,15 +348,17 @@ __global__ void kernel(typename Details::ActDataType* act, half* act_scale, uint
|
||||
load<half, CtaN, Mandatory>(vec_scale, scales + idx_k / GroupSize * n, 1);
|
||||
load<half, CtaN, EnableZero>(vec_zero, zeros + idx_k / GroupSize * n, 1);
|
||||
// Dequantize Data
|
||||
// W4A8 checkpoints have larger activation and weight values. In order to prevent the warp-level FP16
|
||||
// accumulator from overflow, the multiplication of alpha is moved from epilogue to dequantize
|
||||
apply_scale<CtaM, StepK, EnableActScale>(tile_a, vec_act_scale);
|
||||
dequantize<CtaN, StepK, EnableZero>(tile_w, tile_w_quantized, vec_scale, vec_zero);
|
||||
dequantize<CtaN, StepK, EnableZero>(tile_w, tile_w_quantized, vec_scale, vec_zero, __float2half_rn(alpha));
|
||||
// Rearrange
|
||||
pack_to_vec2<CtaN, StepK>(tile_w_pack2, tile_w);
|
||||
// MMA
|
||||
mma<CtaM, CtaN, StepK>(tile_acc, tile_w_pack2, tile_a);
|
||||
}
|
||||
// Epilogue
|
||||
epilogue<CtaM, CtaN, Threads, EnableBias>(out, n, tile_acc, bias, alpha);
|
||||
epilogue<CtaM, CtaN, Threads, EnableBias>(out, n, tile_acc, bias);
|
||||
}
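The comment above explains the motivation: with a warp-level FP16 accumulator, it is the per-element products that must stay small, so alpha is folded into the dequantization scale instead of being applied once in the epilogue. The sketch below, with assumed names, shows why the two are mathematically equivalent (up to rounding).

#include <cuda_fp16.h>

// alpha * sum_k(a_k * (w_k * s)) == sum_k(a_k * (w_k * (s * alpha)))
__device__ __forceinline__ half2 fold_alpha_into_scale(half scale, half alpha)
{
    return __half2half2(__hmul(scale, alpha));
}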
|
||||
|
||||
template <typename Details, int CtaM, int CtaN, int Threads, int GroupSize, bool EnableActScale, bool EnableZero,
|
||||
|
||||
@ -251,6 +251,7 @@ void DynamicDecodeLayer<T>::setupLayers(
|
||||
|
||||
beamSearchParams.beam_search_diversity_rate = setupParams.beam_search_diversity_rate;
|
||||
beamSearchParams.length_penalty = setupParams.length_penalty;
|
||||
beamSearchParams.early_stopping = setupParams.early_stopping;
|
||||
|
||||
mHasDiffRuntimeArgs = hasDiffRuntimeArgs(beamSearchParams);
|
||||
mOnlineBeamSearchDecode->setup(batchSize, beamSearchParams);
|
||||
|
||||
@ -75,6 +75,7 @@ public:
|
||||
// OnlineBeamSearchLayer
|
||||
std::optional<std::vector<float>> beam_search_diversity_rate;
|
||||
std::optional<std::vector<float>> length_penalty;
|
||||
std::optional<std::vector<int>> early_stopping;
|
||||
|
||||
std::optional<bool> normalize_log_probs;
|
||||
};
|
||||
|
||||
@ -31,12 +31,21 @@ static const int SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS = 128;
|
||||
static const int MAX_K = 4;
|
||||
|
||||
template <typename T>
|
||||
__global__ void update_kernel(FinishedState* finished, int** parent_ids_ptr, int* sequence_lengths,
|
||||
int** output_ids_ptr, BeamHypotheses beam_hyps, const int vocab_size, const int* end_ids,
|
||||
const int local_batch_size, const int beam_width, const int max_seq_len)
|
||||
__global__ void update_kernel(FinishedState* finished, BeamHypotheses beam_hyps)
|
||||
{
|
||||
const int beam_width{beam_hyps.beam_width};
|
||||
const int ite{beam_hyps.ite};
|
||||
const int local_batch_size{beam_hyps.local_batch_size};
|
||||
const int max_seq_len{beam_hyps.max_seq_len};
|
||||
const int vocab_size{beam_hyps.vocab_size};
|
||||
const int end_id{beam_hyps.end_ids[blockIdx.x]};
|
||||
int* num_beams{beam_hyps.num_beams};
|
||||
int* sequence_lengths{beam_hyps.sequence_lengths_src};
|
||||
int** output_ids_ptr{beam_hyps.output_ids_tgt_ptr};
|
||||
int** parent_ids_ptr{beam_hyps.parent_ids_tgt_ptr};
|
||||
|
||||
extern __shared__ char s_buf[]; // intermediate result
|
||||
int* s_sequence_lengths = (int*) (s_buf);
|
||||
int* s_sequence_lengths = reinterpret_cast<int*>(s_buf);
|
||||
|
||||
for (int beam_idx = threadIdx.x; beam_idx < beam_width; beam_idx += blockDim.x)
|
||||
{
|
||||
@ -63,35 +72,27 @@ __global__ void update_kernel(FinishedState* finished, int** parent_ids_ptr, int
|
||||
new_word_id = new_word_id % vocab_size;
|
||||
|
||||
sequence_lengths[batch_beam_idx] = s_sequence_lengths[new_beam_id];
|
||||
if (new_word_id == end_ids[blockIdx.x])
|
||||
if (new_word_id == end_id)
|
||||
{
|
||||
finished[batch_beam_idx].setFinishedEOS();
|
||||
}
|
||||
parent_ids_ptr[blockIdx.x][beam_idx * max_seq_len + current_step] = new_beam_id;
|
||||
output_ids_ptr[blockIdx.x][beam_idx * max_seq_len + current_step] = new_word_id;
|
||||
}
|
||||
if (beam_hyps.num_beams != nullptr)
|
||||
if (num_beams != nullptr && num_beams[ite * local_batch_size + blockIdx.x] == beam_width)
|
||||
{
|
||||
if (beam_hyps.num_beams[beam_hyps.ite * beam_hyps.local_batch_size + blockIdx.x] == beam_width)
|
||||
for (int beam_idx = threadIdx.x; beam_idx < beam_width; beam_idx += blockDim.x)
|
||||
{
|
||||
for (int beam_idx = threadIdx.x; beam_idx < beam_width; beam_idx += blockDim.x)
|
||||
{
|
||||
const auto batch_beam_idx = blockIdx.x * beam_width + beam_idx;
|
||||
finished[batch_beam_idx].setFinished();
|
||||
}
|
||||
finished[blockIdx.x * beam_width + beam_idx].setFinished();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void invokeUpdate(FinishedState* finished, int** parent_ids_ptr, int* sequence_lengths, int** output_ids_ptr,
|
||||
BeamHypotheses* beam_hyps, const int local_batch_size, const int beam_width, const int vocab_size_padded,
|
||||
const int* end_ids, const int max_seq_len, cudaStream_t stream)
|
||||
void invokeUpdate(FinishedState* finished, BeamHypotheses& beam_hyps, cudaStream_t stream)
|
||||
{
|
||||
dim3 grid(local_batch_size);
|
||||
dim3 block(min(beam_width, 1024));
|
||||
|
||||
update_kernel<float><<<grid, block, sizeof(int) * beam_width, stream>>>(finished, parent_ids_ptr, sequence_lengths,
|
||||
output_ids_ptr, *beam_hyps, vocab_size_padded, end_ids, local_batch_size, beam_width, max_seq_len);
|
||||
dim3 grid(beam_hyps.local_batch_size);
|
||||
dim3 block(min(beam_hyps.beam_width, 1024));
|
||||
update_kernel<float><<<grid, block, sizeof(int) * beam_hyps.beam_width, stream>>>(finished, beam_hyps);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -103,11 +104,12 @@ void OnlineBeamSearchLayer<T>::setup(size_t batch_size, SetupParams const& setup
|
||||
|
||||
mDiversityRate.resize(batch_size);
|
||||
mLengthPenalty.resize(batch_size);
|
||||
|
||||
mEarlyStopping.resize(batch_size);
|
||||
FillBuffers const fillBuffers{batch_size, batch_size, mStream};
|
||||
|
||||
fillBuffers(setupParams.beam_search_diversity_rate, 0.0f, mDiversityRate, diversity_rates_buf_, (int*) nullptr);
|
||||
fillBuffers(setupParams.length_penalty, 0.0f, mLengthPenalty, length_penalties_buf_, (int*) nullptr);
|
||||
fillBuffers(setupParams.early_stopping, 1, mEarlyStopping, early_stoppings_buf_, (int*) nullptr);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
@ -115,44 +117,39 @@ template <typename T>
|
||||
void OnlineBeamSearchLayer<T>::invokeSoftMax(BeamSearchOutputParams& outputs, SoftmaxParams const& params)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s", __PRETTY_FUNCTION__);
|
||||
Tensor const& output_ids_ptr = outputs.output_ids_ptr;
|
||||
const auto batch_size = static_cast<std::int32_t>(output_ids_ptr.shape[0]);
|
||||
const auto beam_width = static_cast<std::int32_t>(output_ids_ptr.shape[1]);
|
||||
const auto max_seq_len = static_cast<std::int32_t>(output_ids_ptr.shape[2]);
|
||||
const int ite{params.ite};
|
||||
Tensor const& logits{params.logits};
|
||||
const auto local_batch_size = logits.shape[0];
|
||||
|
||||
BeamHypotheses beamHypotheses;
|
||||
auto* const end_ids = params.end_ids.template getPtr<const int>();
|
||||
float* output_log_probs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr<float>() : nullptr;
|
||||
auto* finished
|
||||
= reinterpret_cast<FinishedState*>(outputs.finished->template getPtr<FinishedState::UnderlyingType>());
|
||||
auto* sequence_lengths = outputs.sequence_length->template getPtr<int>();
|
||||
|
||||
BeamHypotheses beam_hyps;
|
||||
if (outputs.beamHypotheses)
|
||||
{
|
||||
beamHypotheses = *outputs.beamHypotheses;
|
||||
beamHypotheses.ite = ite;
|
||||
beamHypotheses.local_batch_size = local_batch_size;
|
||||
beamHypotheses.batch_size = batch_size;
|
||||
beamHypotheses.max_seq_len = max_seq_len;
|
||||
beamHypotheses.output_ids_src_ptr = output_ids_ptr.template getPtr<const int*>();
|
||||
beamHypotheses.parent_ids_src_ptr = outputs.parent_ids_ptr.template getPtr<const int*>();
|
||||
beamHypotheses.sequence_lengths_src = sequence_lengths;
|
||||
beamHypotheses.log_probs_src = output_log_probs;
|
||||
beamHypotheses.length_penalties = length_penalties_buf_;
|
||||
beamHypotheses.end_ids = end_ids;
|
||||
beam_hyps = *outputs.beamHypotheses;
|
||||
// Some of the beam_hyps members have been initialized before invokeSoftMax is called
|
||||
beam_hyps.end_ids = params.end_ids.template getPtr<const int>();
|
||||
beam_hyps.log_probs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr<float>() : nullptr;
|
||||
beam_hyps.output_ids_src_ptr = outputs.output_ids_ptr.template getPtr<const int*>();
|
||||
beam_hyps.output_ids_tgt_ptr = outputs.output_ids_ptr.template getPtr<int*>();
|
||||
beam_hyps.parent_ids_src_ptr = outputs.parent_ids_ptr.template getPtr<const int*>();
|
||||
beam_hyps.parent_ids_tgt_ptr = outputs.parent_ids_ptr.template getPtr<int*>();
|
||||
beam_hyps.sequence_lengths_src = outputs.sequence_length->template getPtr<int>();
|
||||
|
||||
beam_hyps.batch_size = static_cast<std::int32_t>(outputs.output_ids_ptr.shape[0]);
|
||||
beam_hyps.beam_width = static_cast<std::int32_t>(outputs.output_ids_ptr.shape[1]);
|
||||
beam_hyps.ite = params.ite;
|
||||
beam_hyps.local_batch_size = params.logits.shape[0];
|
||||
beam_hyps.max_seq_len = static_cast<std::int32_t>(outputs.output_ids_ptr.shape[2]);
|
||||
beam_hyps.vocab_size = vocab_size_padded_;
|
||||
beam_hyps.diversity_rates = diversity_rates_buf_;
|
||||
beam_hyps.length_penalties = length_penalties_buf_;
|
||||
beam_hyps.early_stoppings = early_stoppings_buf_;
|
||||
}
|
||||
|
||||
invokeTopkSoftMax(logits.template getPtr<T>(), (const T*) (nullptr), finished, sequence_lengths,
|
||||
outputs.cum_log_probs->template getPtr<float>(), output_log_probs, output_ids_ptr.getPtr<int*>(),
|
||||
topk_softmax_workspace_, topk_softmax_workspace_size_, &beamHypotheses, local_batch_size, beam_width,
|
||||
vocab_size_padded_, end_ids, diversity_rates_buf_, length_penalties_buf_, mStream);
|
||||
invokeTopkSoftMax(params.logits.template getPtr<T>(), (const T*) (nullptr), finished,
|
||||
outputs.cum_log_probs->template getPtr<float>(), topk_softmax_workspace_, topk_softmax_workspace_size_,
|
||||
beam_hyps, mStream);
|
||||
sync_check_cuda_error();
|
||||
|
||||
invokeUpdate(finished, outputs.parent_ids_ptr.template getPtr<int*>(), sequence_lengths,
|
||||
output_ids_ptr.getPtr<int*>(), &beamHypotheses, local_batch_size, beam_width, vocab_size_padded_, end_ids,
|
||||
max_seq_len, mStream);
|
||||
invokeUpdate(finished, beam_hyps, mStream);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
@ -169,6 +166,7 @@ void OnlineBeamSearchLayer<T>::allocateBuffer(size_t batch_size)
|
||||
mAllocator->reMalloc(topk_softmax_workspace_, sizeof(float) * topk_softmax_workspace_size_, true));
|
||||
diversity_rates_buf_ = mAllocator->reMalloc(diversity_rates_buf_, sizeof(float) * batch_size, false);
|
||||
length_penalties_buf_ = mAllocator->reMalloc(length_penalties_buf_, sizeof(float) * batch_size, false);
|
||||
early_stoppings_buf_ = mAllocator->reMalloc(early_stoppings_buf_, sizeof(int) * batch_size, false);
|
||||
|
||||
mIsAllocateBuffer = true;
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
@ -183,6 +181,7 @@ void OnlineBeamSearchLayer<T>::freeBuffer()
|
||||
mAllocator->free((void**) (&topk_softmax_workspace_));
|
||||
mAllocator->free((void**) (&diversity_rates_buf_));
|
||||
mAllocator->free((void**) (&length_penalties_buf_));
|
||||
mAllocator->free((void**) (&early_stoppings_buf_));
|
||||
mIsAllocateBuffer = false;
|
||||
}
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
|
||||
@ -41,6 +41,7 @@ public:
|
||||
public:
|
||||
std::optional<std::vector<float>> beam_search_diversity_rate; // [1] or [batch_size] on cpu
|
||||
std::optional<std::vector<float>> length_penalty; // [1] or [batch_size] on cpu
|
||||
std::optional<std::vector<int>> early_stopping; // [1] or [batch_size] on cpu
|
||||
};
|
||||
|
||||
OnlineBeamSearchLayer(
|
||||
@ -71,8 +72,10 @@ protected:
|
||||
|
||||
std::vector<float> mDiversityRate;
|
||||
std::vector<float> mLengthPenalty;
|
||||
std::vector<int> mEarlyStopping;
|
||||
float* diversity_rates_buf_;
|
||||
float* length_penalties_buf_;
|
||||
int* early_stoppings_buf_;
|
||||
|
||||
private:
|
||||
void allocateBuffer(size_t batch_size);
|
||||
|
||||
@ -168,6 +168,7 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernel
memset(&xqaParams, 0, sizeof(XQAParams));
xqaParams.data_type = ConvertMMHAToXQAParamsHelper<T, KVCacheBuffer>::data_type;

xqaParams.layer_idx = mLayerIdx;
xqaParams.num_q_heads = mNumHeads;
xqaParams.num_kv_heads = mNumKVHeads;
xqaParams.head_size = mHeadSize;
@ -363,8 +364,8 @@ INSTANTIATE_MMHA_DISPATCH(__nv_bfloat16, __nv_bfloat16)
#endif
#undef INSTANTIATE_MMHA_DISPATCH

GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int head_size, int unidirectional,
float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, int num_kv_heads, int head_size,
int unidirectional, float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
@ -374,7 +375,8 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_hea
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha,
bool use_paged_context_fmha, bool use_cache, bool is_medusa_enabled)
: mNumHeads(num_heads)
: mLayerIdx(layer_idx)
, mNumHeads(num_heads)
, mNumKVHeads(num_kv_heads)
, mHeadSize(head_size)
, mUnidirectional(unidirectional)
@ -477,6 +479,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(const void* data, size_t leng
const char *d = reinterpret_cast<const char*>(data), *a = d;
unsigned int kvCacheQuantMode;

read(d, mLayerIdx);
read(d, mNumHeads);
read(d, mNumKVHeads);
read(d, mHeadSize);
@ -1490,11 +1493,12 @@ void GPTAttentionPluginCommon::destroy() noexcept

size_t GPTAttentionPluginCommon::getCommonSerializationSize() noexcept
{
return sizeof(mNumHeads) + sizeof(mNumKVHeads) + sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling)
+ sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim) + sizeof(mRotaryEmbeddingBase)
+ sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale) + sizeof(mRotaryEmbeddingMaxPositions)
+ sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA) + sizeof(mFMHAForceFP32Acc)
+ sizeof(mMultiBlockMode) + sizeof(mEnableXQA) + sizeof(unsigned int) // mKVCacheQuantMode
return sizeof(mLayerIdx) + sizeof(mNumHeads) + sizeof(mNumKVHeads) + sizeof(mHeadSize) + sizeof(mUnidirectional)
+ sizeof(mQScaling) + sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim)
+ sizeof(mRotaryEmbeddingBase) + sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale)
+ sizeof(mRotaryEmbeddingMaxPositions) + sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA)
+ sizeof(mFMHAForceFP32Acc) + sizeof(mMultiBlockMode) + sizeof(mEnableXQA)
+ sizeof(unsigned int) // mKVCacheQuantMode
+ sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mPagedKVCache) + sizeof(mTokensPerBlock) + sizeof(mType)
+ sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled) + sizeof(mCrossAttention) + sizeof(mMaxDistance)
+ sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA) + sizeof(mPagedContextFMHA) + sizeof(mUseKVCache)
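Because the commit adds mLayerIdx to the plugin state, the new field has to appear in three places that must stay in sync: getCommonSerializationSize(), serializeCommon() and the deserializing constructor. A minimal sketch of that bookkeeping convention, assuming memcpy-based helpers similar in spirit to the plugin's own read/write (all names below are illustrative, not the real API):

    #include <cstddef>
    #include <cstring>

    // Illustrative read/write helpers in the style of plugin serialization code.
    template <typename T>
    void writeField(char*& d, T const& value) { std::memcpy(d, &value, sizeof(T)); d += sizeof(T); }

    template <typename T>
    void readField(char const*& d, T& value) { std::memcpy(&value, d, sizeof(T)); d += sizeof(T); }

    struct ToyParams { int layerIdx; int numHeads; };

    std::size_t toySerializationSize() { return sizeof(int) + sizeof(int); } // count every written field

    void toySerialize(char* buffer, ToyParams const& p)
    {
        char* d = buffer;
        writeField(d, p.layerIdx); // a newly added field must also be counted and read back in the same order
        writeField(d, p.numHeads);
    }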
@ -1504,6 +1508,7 @@ size_t GPTAttentionPluginCommon::getCommonSerializationSize() noexcept
|
||||
void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
|
||||
{
|
||||
char *d = static_cast<char*>(buffer), *a = d;
|
||||
write(d, mLayerIdx);
|
||||
write(d, mNumHeads);
|
||||
write(d, mNumKVHeads);
|
||||
write(d, mHeadSize);
|
||||
|
||||
@ -36,8 +36,8 @@ class GPTAttentionPluginCommon : public BasePlugin
|
||||
public:
|
||||
GPTAttentionPluginCommon() = delete;
|
||||
|
||||
GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int head_size, int unidirectional, float q_scaling,
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
GPTAttentionPluginCommon(int layer_idx, int num_heads, int num_kv_heads, int head_size, int unidirectional,
|
||||
float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
|
||||
@ -213,6 +213,7 @@ protected:
|
||||
|
||||
const std::string mLayerName;
|
||||
|
||||
int mLayerIdx;
|
||||
int mNumHeads;
|
||||
int mNumKVHeads;
|
||||
int mHeadSize;
|
||||
|
||||
@ -37,8 +37,8 @@ using tensorrt_llm::plugins::GPTAttentionPlugin;
|
||||
static const char* GPT_ATTENTION_PLUGIN_VERSION{"1"};
|
||||
static const char* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"};
|
||||
|
||||
GPTAttentionPlugin::GPTAttentionPlugin(int num_heads, int num_kv_heads, int head_size, int unidirectional,
|
||||
float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int num_kv_heads, int head_size,
|
||||
int unidirectional, float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
|
||||
@ -48,12 +48,12 @@ GPTAttentionPlugin::GPTAttentionPlugin(int num_heads, int num_kv_heads, int head
|
||||
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
|
||||
bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha,
|
||||
bool use_paged_context_fmha, bool use_cache, bool is_medusa_enabled)
|
||||
: GPTAttentionPluginCommon(num_heads, num_kv_heads, head_size, unidirectional, q_scaling, position_embedding_type,
|
||||
rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale,
|
||||
rotary_embedding_max_positions, tp_size, tp_rank, unfuse_qkv_gemm, context_fmha_type, multi_block_mode,
|
||||
enable_xqa, kv_cache_quant_mode, remove_input_padding, mask_type, paged_kv_cache, tokens_per_block, type,
|
||||
max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled, dense_context_fmha,
|
||||
use_paged_context_fmha, use_cache, is_medusa_enabled)
|
||||
: GPTAttentionPluginCommon(layer_idx, num_heads, num_kv_heads, head_size, unidirectional, q_scaling,
|
||||
position_embedding_type, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type,
|
||||
rotary_embedding_scale, rotary_embedding_max_positions, tp_size, tp_rank, unfuse_qkv_gemm, context_fmha_type,
|
||||
multi_block_mode, enable_xqa, kv_cache_quant_mode, remove_input_padding, mask_type, paged_kv_cache,
|
||||
tokens_per_block, type, max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled,
|
||||
dense_context_fmha, use_paged_context_fmha, use_cache, is_medusa_enabled)
|
||||
{
|
||||
initEntryIdx();
|
||||
}
|
||||
@ -485,7 +485,7 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
// Note that this cyclic_attention_window_size might be smaller than the actual kv cache capactity.
const int cyclic_attention_window_size = isCrossAttention()
? max_encoder_context_len
: reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW)])[0];
: reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW)])[mLayerIdx];
const int sink_token_length = reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_SINK_TOKEN_LENGTH)])[0];

const float* kv_scale_orig_quant = nullptr;
@ -503,14 +503,18 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
void* host_block_pointers = nullptr;
if (useKVCache() && mPagedKVCache)
{
auto& kvCacheBlockPointers = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)];
auto& kvCacheBlockPointersShape = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)].dims;
auto const& kvCacheBlockPointers = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)];
auto const& kvCacheBlockPointersShape = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)].dims;
max_blocks_per_sequence = kvCacheBlockPointersShape.d[kvCacheBlockPointersShape.nbDims - 1];
auto offset = getStride(kvCacheBlockPointersShape, 0) * seqIdxBeg;
auto const typed_block_pointers
auto const batchSize = kvCacheBlockPointersShape.d[1];
auto const seqStride = getStride(kvCacheBlockPointersShape, 1);
auto const layerOffset = mLayerIdx * batchSize * seqStride;
auto const seqOffset = seqIdxBeg * seqStride;
auto const offset = layerOffset + seqOffset;
auto const* const typed_block_pointers
= static_cast<void* const*>(inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS)]) + offset;
block_pointers = const_cast<void*>(static_cast<void const*>(typed_block_pointers));
auto const typed_host_block_pointers
auto const* const typed_host_block_pointers
= static_cast<void* const*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_BLOCK_POINTERS)]) + offset;
host_block_pointers = const_cast<void*>(static_cast<void const*>(typed_host_block_pointers));
}
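With block_pointers reshaped to hold one table per layer, each plugin instance first skips the tables of the layers before it and then the sequences before seqIdxBeg, which is exactly what the layerOffset + seqOffset sum above computes. A standalone sketch of that arithmetic under the assumed [num_layers, batch_size, 2, max_blocks_per_seq] layout (not the plugin code itself):

    #include <cstdint>

    // Per-layer, per-sequence offset into a flat [num_layers, batch_size, 2, max_blocks_per_seq] pointer table.
    std::int64_t blockPointerOffset(
        std::int64_t layerIdx, std::int64_t batchSize, std::int64_t maxBlocksPerSeq, std::int64_t seqIdxBeg)
    {
        auto const seqStride = 2 * maxBlocksPerSeq;                // K and V pointer rows per sequence
        auto const layerOffset = layerIdx * batchSize * seqStride; // skip the tables of earlier layers
        auto const seqOffset = seqIdxBeg * seqStride;              // skip earlier sequences within this layer
        return layerOffset + seqOffset;
    }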
@ -762,9 +766,10 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(const char* name, const Plugi

try
{
auto* obj = new GPTAttentionPlugin(p.getScalar<int32_t>("num_heads").value(),
p.getScalar<int32_t>("num_kv_heads").value(), p.getScalar<int32_t>("head_size").value(),
p.getScalar<int32_t>("unidirectional").value(), p.getScalar<float>("q_scaling").value(),
auto* obj = new GPTAttentionPlugin(p.getScalar<int32_t>("layer_idx").value(),
p.getScalar<int32_t>("num_heads").value(), p.getScalar<int32_t>("num_kv_heads").value(),
p.getScalar<int32_t>("head_size").value(), p.getScalar<int32_t>("unidirectional").value(),
p.getScalar<float>("q_scaling").value(),
static_cast<PositionEmbeddingType>(p.getScalar<int8_t>("position_embedding_type").value()),
p.getScalar<int32_t>("rotary_embedding_dim").value(), p.getScalar<float>("rotary_embedding_base").value(),
static_cast<RotaryScalingType>(p.getScalar<int8_t>("rotary_embedding_scale_type").value()),

@ -46,7 +46,7 @@ namespace tensorrt_llm::plugins
// enable_remove_input_padding
// 1. sequence_length [batch_size] (optional)
// 2. host_past_key_value_lengths [batch_size] (int32) (optional)
// 3. host_max_attention_window_sizes [1] (int32)
// 3. host_max_attention_window_sizes [num_layers] (int32)
// 4. host_sink_token_length [1] (int32)
// 5. context_lengths [batch_size]
// 6. cache_indir [num_gen_requests, beam_width, memory_max_len] (required in beamsearch) (optional)
@ -54,8 +54,8 @@ namespace tensorrt_llm::plugins
// mode,
// all elements must be identical.
// 8. past_key_value_pool [batch_size, 2, local_num_kv_heads, max_seq_len, head_size] or
// block_pointers [batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
// 8.1 host_block_pointers [batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
// block_pointers [num_layers, batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
// 8.1 host_block_pointers [num_layers, batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
// 9. kv_cache_quantization_scale [1] (optional)
// 10. kv_cache_dequantization_scale [1] (optional)
// 11. alibi_slopes [num_heads] (optional for ALiBi position embedding)
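Since host_max_attention_window_sizes is now a [num_layers] tensor rather than a single value, each attention layer reads its own entry and can therefore run with a different sliding window. A toy illustration with hypothetical window values (nothing here comes from the real engine inputs):

    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<int> hostMaxAttentionWindowSizes{4096, 4096, 2048, 2048}; // hypothetical per-layer windows
        int const layerIdx = 2;
        int const cyclicAttentionWindowSize = hostMaxAttentionWindowSizes[layerIdx];
        std::printf("layer %d uses attention window %d\n", layerIdx, cyclicAttentionWindowSize);
        return 0;
    }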
@ -70,8 +70,8 @@ namespace tensorrt_llm::plugins
|
||||
class GPTAttentionPlugin : public GPTAttentionPluginCommon
|
||||
{
|
||||
public:
|
||||
GPTAttentionPlugin(int num_heads, int num_kv_heads, int head_size, int unidirectional, float q_scaling,
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
GPTAttentionPlugin(int layer_idx, int num_heads, int num_kv_heads, int head_size, int unidirectional,
|
||||
float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
|
||||
|
||||
@ -301,7 +301,7 @@ void LoraPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, in
|
||||
}
|
||||
mGemmId.n = N;
|
||||
mGemmId.k = K;
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
int64_t getLowRankWorkSpaceSize(
|
||||
@ -344,7 +344,7 @@ size_t LoraPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, in
|
||||
void runCublasGemmEx(const int M, const int N, const int K, const bool transA, const bool transB, const void* act,
|
||||
const void* weight, void* output, cublasHandle_t cublas_handle)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
float a = 1.0f;
|
||||
float b = 0.0f;
|
||||
void* alpha = &a;
|
||||
@ -356,13 +356,13 @@ void runCublasGemmEx(const int M, const int N, const int K, const bool transA, c
|
||||
|
||||
tensorrt_llm::common::check_cuda_error(cublasGemmEx(cublas_handle, transa, transb, m, n, k, alpha, weight,
|
||||
CUDA_R_16F, lda, act, CUDA_R_16F, ldb, beta, output, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
|
||||
const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
// inputs
|
||||
// input [-1, K] (view as 2D)
|
||||
// host_request_type [batch_size] on cpu
|
||||
@ -565,7 +565,7 @@ int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf
|
||||
}
|
||||
}
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -19,16 +19,14 @@
#include "namedTensor.h"
#include "tensorrt_llm/batch_manager/GptManager.h"
#include "tensorrt_llm/batch_manager/callbacks.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/pybind/utils/pathCaster.h"

#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/extension.h>

#include <ATen/ATen.h>

#include <ATen/ops/tensor.h>
#include <memory>
#include <optional>

@ -38,17 +36,19 @@ namespace tensorrt_llm::pybind::batch_manager
{

GptManager::GptManager(std::filesystem::path const& trtEnginePath, tb::TrtGptModelType modelType, int32_t maxBeamWidth,
tb::batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,
SendResponseCallback sendResponseCb, tb::PollStopSignalCallback pollStopSignalCb,
tb::ReturnBatchManagerStatsCallback returnBatchManagerStatsCb, const tb::TrtGptModelOptionalParams& optionalParams,
std::optional<uint64_t> terminateReqId)
: tb::GptManager(trtEnginePath, modelType, maxBeamWidth, schedulerPolicy, callbackAdapter(getInferenceRequestsCb),
callbackAdapter(sendResponseCb), pollStopSignalCb, returnBatchManagerStatsCb, optionalParams, terminateReqId)
tb::batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback const& getInferenceRequestsCb,
SendResponseCallback const& sendResponseCb, const tb::PollStopSignalCallback& pollStopSignalCb,
tb::ReturnBatchManagerStatsCallback const& returnBatchManagerStatsCb,
tb::TrtGptModelOptionalParams const& optionalParams, std::optional<uint64_t> terminateReqId)
{
mManager = std::make_unique<tb::GptManager>(trtEnginePath, modelType, maxBeamWidth, schedulerPolicy,
callbackAdapter(getInferenceRequestsCb), callbackAdapter(sendResponseCb), pollStopSignalCb,
returnBatchManagerStatsCb, optionalParams, terminateReqId);
}

py::object GptManager::enter()
{
TLLM_CHECK(static_cast<bool>(mManager));
return py::cast(this);
}

@ -62,11 +62,14 @@ void GptManager::shutdown()
// NOTE: we must release the GIL here. GptManager has spawned a thread for the execution loop. That thread must be
// able to do forward progress for the shutdown process to succeed. It takes the GIL during its callbacks, so
// we release it now. Note that we shouldn't do anything related to python objects after that.
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
py::gil_scoped_release release;
tb::GptManager::shutdown();
mManager->shutdown();
mManager = nullptr;
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
}

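The binding class no longer inherits from tb::GptManager; it owns one through a unique_ptr, so shutdown() can tear the manager down deterministically while the GIL is released. A hedged sketch of that composition-over-inheritance pattern (RealManager and ManagerBinding are stand-ins, not the real classes):

    #include <memory>

    struct RealManager
    {
        void shutdown() {} // stand-in for stopping worker threads and flushing requests
    };

    class ManagerBinding
    {
    public:
        ManagerBinding()
            : mManager(std::make_unique<RealManager>())
        {
        }

        void shutdown()
        {
            mManager->shutdown();
            mManager = nullptr; // destroy eagerly instead of waiting for the Python object to be collected
        }

    private:
        std::unique_ptr<RealManager> mManager;
    };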
tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback callback)
tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback const& callback)
{
return [callback](int32_t max_sequences)
{
@ -82,14 +85,14 @@ tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback ca
};
}

tb::SendResponseCallback callbackAdapter(SendResponseCallback callback)
tb::SendResponseCallback callbackAdapter(SendResponseCallback const& callback)
{
return [callback](uint64_t id, std::list<tb::NamedTensor> const& cppTensors, bool isOk, const std::string& errMsg)
{
std::list<NamedTensor> pythonList{};
for (const auto& cppNamedTensor : cppTensors)
{
pythonList.push_back(NamedTensor{cppNamedTensor});
pythonList.emplace_back(cppNamedTensor);
}
callback(id, pythonList, isOk, errMsg);
};

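Both callbackAdapter overloads follow the same shape: capture the Python-facing std::function by value, convert the batch manager's types into their pybind counterparts, and forward the call. A generic, self-contained sketch of that adapter pattern (the tensor structs and type aliases below are made up for illustration):

    #include <cstdint>
    #include <functional>
    #include <list>
    #include <string>

    struct CppTensor { std::string name; };
    struct PyTensor  { explicit PyTensor(CppTensor const& t) : name(t.name) {} std::string name; };

    using PySendResponse  = std::function<void(std::uint64_t, std::list<PyTensor> const&)>;
    using CppSendResponse = std::function<void(std::uint64_t, std::list<CppTensor> const&)>;

    CppSendResponse adapt(PySendResponse const& callback)
    {
        // Capture by value so the returned adapter owns its own copy of the callback.
        return [callback](std::uint64_t id, std::list<CppTensor> const& cppTensors)
        {
            std::list<PyTensor> pyTensors;
            for (auto const& t : cppTensors)
            {
                pyTensors.emplace_back(t); // convert each C++-side tensor to its Python-side wrapper
            }
            callback(id, pyTensors);
        };
    }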
@ -24,6 +24,7 @@
|
||||
|
||||
#include <ATen/ops/tensor.h>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
namespace tensorrt_llm::pybind::batch_manager
|
||||
{
|
||||
@ -31,18 +32,18 @@ namespace tensorrt_llm::pybind::batch_manager
|
||||
using GetInferenceRequestsCallback = std::function<std::list<InferenceRequest>(int32_t)>;
|
||||
using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, const std::string&)>;
|
||||
|
||||
tensorrt_llm::batch_manager::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback callback);
|
||||
tensorrt_llm::batch_manager::SendResponseCallback callbackAdapter(SendResponseCallback callback);
|
||||
tensorrt_llm::batch_manager::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback const& callback);
|
||||
tensorrt_llm::batch_manager::SendResponseCallback callbackAdapter(SendResponseCallback const& callback);
|
||||
|
||||
class GptManager : tensorrt_llm::batch_manager::GptManager
|
||||
class GptManager
|
||||
{
|
||||
public:
|
||||
GptManager(std::filesystem::path const& trtEnginePath, tensorrt_llm::batch_manager::TrtGptModelType modelType,
|
||||
int32_t maxBeamWidth, tensorrt_llm::batch_manager::batch_scheduler::SchedulerPolicy schedulerPolicy,
|
||||
GetInferenceRequestsCallback getInferenceRequestsCb, SendResponseCallback sendResponseCb,
|
||||
tensorrt_llm::batch_manager::PollStopSignalCallback pollStopSignalCb = nullptr,
|
||||
tensorrt_llm::batch_manager::ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
|
||||
const tensorrt_llm::batch_manager::TrtGptModelOptionalParams& optionalParams
|
||||
GetInferenceRequestsCallback const& getInferenceRequestsCb, SendResponseCallback const& sendResponseCb,
|
||||
tensorrt_llm::batch_manager::PollStopSignalCallback const& pollStopSignalCb = nullptr,
|
||||
tensorrt_llm::batch_manager::ReturnBatchManagerStatsCallback const& returnBatchManagerStatsCb = nullptr,
|
||||
tensorrt_llm::batch_manager::TrtGptModelOptionalParams const& optionalParams
|
||||
= tensorrt_llm::batch_manager::TrtGptModelOptionalParams(),
|
||||
std::optional<uint64_t> terminateReqId = std::nullopt);
|
||||
|
||||
@ -51,6 +52,9 @@ public:
|
||||
void shutdown();
|
||||
|
||||
static void initBindings(pybind11::module_& m);
|
||||
|
||||
private:
|
||||
std::unique_ptr<tensorrt_llm::batch_manager::GptManager> mManager;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::pybind::batch_manager
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include "inferenceRequest.h"
|
||||
|
||||
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/runtime/torch.h"
|
||||
#include "tensorrt_llm/runtime/torchView.h"
|
||||
#include <memory>
|
||||
@ -25,6 +26,15 @@
|
||||
#include <pybind11/stl.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
// FIXME: THPStream_Wrap seems not to be present in libtorch_python.so on Windows
|
||||
PyObject* THPStream_Wrap(const c10::Stream& stream)
|
||||
{
|
||||
TLLM_THROW("Stream conversion in not yet supported on Windows.");
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace tb = tensorrt_llm::batch_manager;
|
||||
namespace tr = tensorrt_llm::runtime;
|
||||
|
||||
@ -32,6 +42,7 @@ using namespace tensorrt_llm::pybind::batch_manager;
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
std::shared_ptr<InferenceRequest> fromTrtLlm(tb::InferenceRequest const& request)
|
||||
{
|
||||
InferenceRequest::TensorMap tensorMap;
|
||||
@ -53,18 +64,24 @@ std::shared_ptr<tb::InferenceRequest> InferenceRequest::toTrtLlm() const
tb::InferenceRequest::TensorMap tensorMap;
for (auto const& [name, tensor] : mInputTensors)
{
if (tensor.has_value())
{
tensorMap[name] = tr::TorchView::of(tensor.value());
}
tensorMap[name] = tr::TorchView::of(tensor);
}
auto inferenceRequest = std::make_shared<tb::InferenceRequest>(std::move(tensorMap), mRequestId);
inferenceRequest->setIsStreaming(isStreaming());

if (mlogitsPostProcessor)
{
inferenceRequest->setLogitsPostProcessor(LlmRequest::callbackAdapter(mlogitsPostProcessor));
}

return inferenceRequest;
}

std::string InferenceRequest::serialize() const
{
TLLM_CHECK_WITH_INFO(mlogitsPostProcessor == std::nullopt,
"Serializing InferenceRequest with logitsPostProcessor set is not supported."
"Please set the callback after de-serialization");
std::vector<std::int64_t> serialized{toTrtLlm()->serialize()};
static_assert(sizeof(decltype(serialized)::value_type) / sizeof(char) == 8);
return {reinterpret_cast<char const*>(serialized.data()), serialized.size() * 8};
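serialize() relies on every element of the int64_t vector being exactly eight bytes, which is what the static_assert above checks before the vector is reinterpreted as a byte string. A simplified round-trip sketch of that packing idea (pack and unpack are hypothetical helpers, not the real serialization API):

    #include <cstdint>
    #include <cstring>
    #include <string>
    #include <vector>

    // Hypothetical helpers showing the int64 <-> byte-string packing idea.
    std::string pack(std::vector<std::int64_t> const& v)
    {
        return {reinterpret_cast<char const*>(v.data()), v.size() * sizeof(std::int64_t)};
    }

    std::vector<std::int64_t> unpack(std::string const& s)
    {
        std::vector<std::int64_t> v(s.size() / sizeof(std::int64_t)); // assumes s.size() is a multiple of 8
        std::memcpy(v.data(), s.data(), v.size() * sizeof(std::int64_t));
        return v;
    }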
@ -81,8 +98,12 @@ std::shared_ptr<InferenceRequest> InferenceRequest::deserialize(std::string cons
|
||||
void InferenceRequest::initBindings(py::module_& m)
|
||||
{
|
||||
py::class_<InferenceRequest>(m, "InferenceRequest")
|
||||
.def(py::init<uint64_t>())
|
||||
.def(py::init<uint64_t, std::optional<LogitsProcessorCallback>>(), py::arg("request_id"),
|
||||
py::arg("logits_post_processor_callback") = py::none())
|
||||
.def(py::init<uint64_t, InferenceRequest::TensorMap const&>(), "deprecated: use direct tensor access instead")
|
||||
.def_property("logits_post_processor",
|
||||
nullptr, // passing logits processor in the cpp->python direction doesn't work. getter is then undefined
|
||||
&InferenceRequest::setLogitsPostProcessor)
|
||||
.def_property("input_ids", &InferenceRequest::getInputIdsUnchecked, &InferenceRequest::setInputIds)
|
||||
.def_property(
|
||||
"draft_input_ids", &InferenceRequest::getDraftInputIdsUnchecked, &InferenceRequest::setDraftInputIds)
|
||||
@ -101,6 +122,8 @@ void InferenceRequest::initBindings(py::module_& m)
|
||||
.def_property("runtime_top_p", &InferenceRequest::getRuntimeTopPUnchecked, &InferenceRequest::setRuntimeTopP)
|
||||
.def_property(
|
||||
"length_penalty", &InferenceRequest::getLengthPenaltyUnchecked, &InferenceRequest::setLengthPenalty)
|
||||
.def_property(
|
||||
"early_stopping", &InferenceRequest::getEarlyStoppingUnchecked, &InferenceRequest::setEarlyStopping)
|
||||
.def_property("repetition_penalty", &InferenceRequest::getRepetitionPenaltyUnchecked,
|
||||
&InferenceRequest::setRepetitionPenalty)
|
||||
.def_property("min_length", &InferenceRequest::getMinLengthUnchecked, &InferenceRequest::setMinLength)
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
|
||||
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/pybind/batch_manager/namedTensor.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
@ -30,25 +31,28 @@ namespace tensorrt_llm::pybind::batch_manager
|
||||
{
|
||||
|
||||
class InferenceRequest
|
||||
: public tensorrt_llm::batch_manager::GenericInferenceRequest<std::optional<at::Tensor>, NamedTensor>
|
||||
: public tensorrt_llm::batch_manager::GenericInferenceRequest<at::Tensor, NamedTensor, c10::Stream>
|
||||
{
|
||||
public:
|
||||
using Base = tensorrt_llm::batch_manager::GenericInferenceRequest<std::optional<at::Tensor>, NamedTensor>;
|
||||
using Base = tensorrt_llm::batch_manager::GenericInferenceRequest<at::Tensor, NamedTensor, c10::Stream>;
|
||||
using TensorPtr = Base::TensorPtr;
|
||||
using TensorMap = Base::TensorMap;
|
||||
using LogitsProcessorCallback = Base::LogitsPostProcessor;
|
||||
|
||||
InferenceRequest(uint64_t requestId)
|
||||
: Base(requestId)
|
||||
InferenceRequest(uint64_t requestId, std::optional<LogitsProcessorCallback> logitsCb = std::nullopt)
|
||||
: Base(requestId, logitsCb)
|
||||
{
|
||||
}
|
||||
|
||||
InferenceRequest(uint64_t requestId, TensorMap const& inputTensors)
|
||||
: Base{requestId, inputTensors}
|
||||
InferenceRequest(uint64_t requestId, TensorMap const& inputTensors,
|
||||
std::optional<LogitsProcessorCallback> logitsCb = std::nullopt)
|
||||
: Base{requestId, inputTensors, logitsCb}
|
||||
{
|
||||
}
|
||||
|
||||
InferenceRequest(uint64_t requestId, TensorMap&& inputTensors)
|
||||
: Base{requestId, std::move(inputTensors)}
|
||||
InferenceRequest(
|
||||
uint64_t requestId, TensorMap&& inputTensors, std::optional<LogitsProcessorCallback> logitsCb = std::nullopt)
|
||||
: Base{requestId, std::move(inputTensors), logitsCb}
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@ -17,7 +17,10 @@
|
||||
#include "llmRequest.h"
|
||||
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/generationInput.h"
|
||||
#include "tensorrt_llm/runtime/torch.h"
|
||||
#include "tensorrt_llm/runtime/torchUtils.h"
|
||||
#include "tensorrt_llm/runtime/torchView.h"
|
||||
#include <memory>
|
||||
|
||||
@ -45,6 +48,25 @@ std::optional<tb::LlmRequest::TensorPtr> from_torch(std::optional<LlmRequest::Te
|
||||
|
||||
} // namespace
|
||||
|
||||
std::optional<tb::LlmRequest::LogitsPostProcessor> LlmRequest::callbackAdapter(
|
||||
std::optional<LlmRequest::LogitsPostProcessor> callback)
|
||||
{
|
||||
if (!callback)
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return [callback](RequestIdType reqId, tensorrt_llm::runtime::ITensor::SharedPtr& tensor,
|
||||
tensorrt_llm::batch_manager::LlmRequest::BeamTokens const& tokens,
|
||||
tensorrt_llm::runtime::BufferManager::CudaStreamPtr stream)
|
||||
{
|
||||
at::Tensor atTensor = tr::Torch::tensor(tensor);
|
||||
|
||||
auto result = callback.value()(reqId, atTensor, tokens, runtime::TorchUtils::stream(*stream).unwrap());
|
||||
return tr::TorchView::of(result);
|
||||
};
|
||||
}
|
||||
|
||||
std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
|
||||
{
|
||||
auto embeddingBias = from_torch(mEmbeddingBias);
|
||||
@ -58,7 +80,8 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
|
||||
return std::make_shared<tb::LlmRequest>(mRequestId, mMaxNewTokens,
|
||||
std::make_shared<std::vector<TokenIdType>>(mTokens.at(0)), mSamplingConfig, mIsStreaming, mEndId, mPadId,
|
||||
embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, mPromptVocabSize, loraWeights, loraConfig,
|
||||
mReturnLogProbs, mReturnContextLogits, mReturnGenerationLogits, mDraftTokens, draftLogits);
|
||||
mReturnLogProbs, mReturnContextLogits, mReturnGenerationLogits, mDraftTokens, draftLogits,
|
||||
mExcludeInputFromOutput, callbackAdapter(mLogitsPostProcessor));
|
||||
}
|
||||
|
||||
void LlmRequest::initBindings(py::module_& m)
|
||||
@ -70,7 +93,7 @@ void LlmRequest::initBindings(py::module_& m)
|
||||
std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
|
||||
std::optional<LlmRequest::SizeType>, std::optional<LlmRequest::TensorPtr>,
|
||||
std::optional<LlmRequest::TensorPtr>, bool, bool, bool, std::optional<LlmRequest::VecTokens>,
|
||||
std::optional<LlmRequest::TensorPtr>>(),
|
||||
std::optional<LlmRequest::TensorPtr>, bool, std::optional<LlmRequest::LogitsPostProcessor>>(),
|
||||
py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
|
||||
py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
|
||||
py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt,
|
||||
@ -78,7 +101,8 @@ void LlmRequest::initBindings(py::module_& m)
|
||||
py::arg("prompt_vocab_size") = std::nullopt, py::arg("lora_weights") = std::nullopt,
|
||||
py::arg("lora_config") = std::nullopt, py::arg("return_log_probs") = false,
|
||||
py::arg("return_context_logits") = false, py::arg("return_generation_logits") = false,
|
||||
py::arg("draft_tokens") = std::nullopt, py::arg("draft_logits") = std::nullopt)
|
||||
py::arg("draft_tokens") = std::nullopt, py::arg("draft_logits") = std::nullopt,
|
||||
py::arg("exclude_input_from_output") = false, py::arg("logits_post_processor") = std::nullopt)
|
||||
.def("get_num_tokens", &LlmRequest::getNumTokens, py::arg("beam"))
|
||||
.def_property_readonly("max_beam_num_tokens", &LlmRequest::getMaxBeamNumTokens)
|
||||
.def("get_token", &LlmRequest::getToken, py::arg("beam"), py::arg("pos"))
|
||||
|
||||
@ -28,10 +28,16 @@
|
||||
namespace tensorrt_llm::pybind::batch_manager
|
||||
{
|
||||
|
||||
class LlmRequest : public tensorrt_llm::batch_manager::GenericLlmRequest<at::Tensor>
|
||||
namespace tb = tensorrt_llm::batch_manager;
|
||||
|
||||
/* Unfortunately, torch's default pybind bindings don't know about c10::cuda::CUDAStream,
|
||||
* so we have to pass the more generic c10::Stream, and convert it back to a full-fledged
|
||||
* torch.cuda.Stream in python. See example in test/bindings/test_gpt_manager.py
|
||||
*/
|
||||
class LlmRequest : public tb::GenericLlmRequest<at::Tensor, c10::Stream>
|
||||
{
|
||||
public:
|
||||
using Base = GenericLlmRequest<at::Tensor>;
|
||||
using Base = GenericLlmRequest<at::Tensor, c10::Stream>;
|
||||
using TensorPtr = Base::TensorPtr;
|
||||
using SizeType = Base::SizeType;
|
||||
using TokenIdType = Base::TokenIdType;
|
||||
@ -39,6 +45,7 @@ public:
|
||||
using VecLogProbs = Base::VecLogProbs;
|
||||
using BeamTokens = Base::BeamTokens;
|
||||
using VecTokens = Base::VecTokens;
|
||||
using LogitsPostProcessor = Base::LogitsPostProcessor;
|
||||
|
||||
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::vector<TokenIdType> inputTokens,
|
||||
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
|
||||
@ -48,16 +55,20 @@ public:
|
||||
std::optional<SizeType> promptVocabSize = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
|
||||
std::optional<TensorPtr> loraConfig = std::nullopt, bool returnLogProbs = false,
|
||||
bool returnContextLogits = false, bool returnGenerationLogits = false,
|
||||
std::optional<VecTokens> draftTokens = std::nullopt, std::optional<TensorPtr> draftLogits = std::nullopt)
|
||||
std::optional<VecTokens> draftTokens = std::nullopt, std::optional<TensorPtr> draftLogits = std::nullopt,
|
||||
bool excludeInputFromOutput = false, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt)
|
||||
: Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
|
||||
samplingConfig, isStreaming, endId, padId, embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable,
|
||||
promptVocabSize, loraWeights, loraConfig, returnLogProbs, returnContextLogits, returnGenerationLogits,
|
||||
draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value()))
|
||||
: std::make_shared<VecTokens>(),
|
||||
draftLogits)
|
||||
draftLogits, excludeInputFromOutput, logitsPostProcessor)
|
||||
{
|
||||
}
|
||||
|
||||
static std::optional<tb::LlmRequest::LogitsPostProcessor> callbackAdapter(
|
||||
std::optional<LlmRequest::LogitsPostProcessor> callback);
|
||||
|
||||
[[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> toTrtLlm() const;
|
||||
static void initBindings(pybind11::module_& m);
|
||||
};
|
||||
|
||||
@ -37,6 +37,7 @@
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/gptJsonConfig.h"
#include "tensorrt_llm/runtime/gptSession.h"
#include "tensorrt_llm/runtime/memoryCounters.h"
#include "tensorrt_llm/runtime/samplingConfig.h"

namespace py = pybind11;
@ -232,7 +233,8 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def_readwrite("top_p_min", &tr::SamplingConfig::topPMin)
.def_readwrite("top_p_reset_ids", &tr::SamplingConfig::topPResetIds)
.def_readwrite("beam_search_diversity_rate", &tr::SamplingConfig::beamSearchDiversityRate)
.def_readwrite("length_penalty", &tr::SamplingConfig::lengthPenalty);
.def_readwrite("length_penalty", &tr::SamplingConfig::lengthPenalty)
.def_readwrite("early_stopping", &tr::SamplingConfig::earlyStopping);

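The new early_stopping attribute is exposed with the same def_readwrite pattern as the existing sampling fields. A self-contained pybind11 sketch of that pattern, using a made-up module and struct so it can compile on its own:

    #include <optional>
    #include <vector>

    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>

    namespace py = pybind11;

    struct ToySamplingConfig
    {
        std::optional<std::vector<int>> earlyStopping; // mirrors the optional, per-request style of the real config
    };

    PYBIND11_MODULE(toy_bindings, m)
    {
        py::class_<ToySamplingConfig>(m, "ToySamplingConfig")
            .def(py::init<>())
            .def_readwrite("early_stopping", &ToySamplingConfig::earlyStopping);
    }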
py::class_<tr::GptJsonConfig>(m, "GptJsonConfig")
|
||||
.def(py::init<std::string, std::string, std::string, SizeType, SizeType, tr::GptModelConfig>(), py::arg("name"),
|
||||
@ -302,6 +304,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
tensorNames.attr("RUNTIME_TOP_K") = py::str(tb::inference_request::kRuntimeTopKTensorName);
|
||||
tensorNames.attr("RUNTIME_TOP_P") = py::str(tb::inference_request::kRuntimeTopPTensorName);
|
||||
tensorNames.attr("LENGTH_PENALTY") = py::str(tb::inference_request::kLengthPenaltyTensorName);
|
||||
tensorNames.attr("EARLY_STOPPING") = py::str(tb::inference_request::kEarlyStoppingTensorName);
|
||||
tensorNames.attr("REPETITION_PENALTY") = py::str(tb::inference_request::kRepetitionPenaltyTensorName);
|
||||
tensorNames.attr("MIN_LENGTH") = py::str(tb::inference_request::kMinLengthTensorName);
|
||||
tensorNames.attr("PRESENCE_PENALTY") = py::str(tb::inference_request::kPresencePenaltyTensorName);
|
||||
@ -342,4 +345,11 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def_readwrite("decoding_mode", &tb::TrtGptModelOptionalParams::decodingMode);
|
||||
|
||||
tpb::GptManager::initBindings(m);
|
||||
|
||||
py::class_<tr::MemoryCounters>(m, "MemoryCounters")
|
||||
.def_static("instance", &tr::MemoryCounters::getInstance, py::return_value_policy::reference)
|
||||
.def_property_readonly("gpu", &tr::MemoryCounters::getGpu)
|
||||
.def_property_readonly("cpu", &tr::MemoryCounters::getCpu)
|
||||
.def_property_readonly("pinned", &tr::MemoryCounters::getPinned)
|
||||
.def_property_readonly("uvm", &tr::MemoryCounters::getUVM);
|
||||
}
|
||||
|
||||
@ -19,7 +19,6 @@
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
#include "tensorrt_llm/runtime/gptDecoder.h"
|
||||
|
||||
#include "tensorrt_llm/common/cudaAllocator.h"
|
||||
#include "tensorrt_llm/common/tensorConversion.h"
|
||||
#include "tensorrt_llm/kernels/decodingKernels.h"
|
||||
#include "tensorrt_llm/layers/dynamicDecodeLayer.h"
|
||||
@ -81,6 +82,7 @@ void GptDecoder<T>::setup(SamplingConfig const& samplingConfig, size_t batchSize
|
||||
|
||||
setupParams.beam_search_diversity_rate = samplingConfig.beamSearchDiversityRate;
|
||||
setupParams.length_penalty = samplingConfig.lengthPenalty;
|
||||
setupParams.early_stopping = samplingConfig.earlyStopping;
|
||||
|
||||
auto const batchSlotsPtr = batchSlots.has_value() ? bufferCast<SizeType>(*(batchSlots.value())) : nullptr;
|
||||
mDynamicDecodeLayer->setup(batchSize, samplingConfig.beamWidth, batchSlotsPtr, setupParams);
|
||||
@ -174,7 +176,7 @@ template <typename T>
|
||||
typename tl::DynamicDecodeLayer<T>::OutputParams prepareOutputs(
|
||||
DecodingOutput& output, DecodingInput::TensorPtr const& inputLengths, DecodingOutput::TensorPtr& logProbsTiled)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
typename tl::DynamicDecodeLayer<T>::OutputParams outputParams(tcc::toTllmTensor(*output.ids));
|
||||
|
||||
outputParams.newTokens = tcc::toTllmTensor(*output.newTokens);
|
||||
@ -261,7 +263,7 @@ typename tl::DynamicDecodeLayer<T>::OutputParams prepareOutputs(
|
||||
template <typename T>
|
||||
bool GptDecoder<T>::forward(DecodingOutput& output, DecodingInput const& input)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto forwardParams = prepareInputs<T>(input);
|
||||
auto outputParams = prepareOutputs<T>(output, input.lengths, mLogProbsTiled);
|
||||
auto const maxBatchSize = input.maxBatchSize;
|
||||
@ -309,7 +311,7 @@ bool GptDecoder<T>::forward(DecodingOutput& output, DecodingInput const& input)
|
||||
template <typename T>
|
||||
void GptDecoder<T>::forwardAsync(DecodingOutput& output, DecodingInput const& input)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto forwardParams = prepareInputs<T>(input);
|
||||
auto outputParams = prepareOutputs<T>(output, input.lengths, mLogProbsTiled);
|
||||
|
||||
@ -321,7 +323,7 @@ template <typename T>
|
||||
void GptDecoder<T>::gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput,
|
||||
DecodingInput const& decodingInput, BufferManager const& manager)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto const& finalOutputIdsShape = finalOutputIds.getShape();
|
||||
auto const& decodingOutputIdsShape = decodingOutput.ids->getShape();
|
||||
auto const batchSize = finalOutputIdsShape.d[0];
|
||||
@ -355,7 +357,7 @@ void GptDecoder<T>::gatherTree(ITensor& finalOutputIds, DecodingOutput const& de
|
||||
beamHypotheses.length_penalties
|
||||
= nullptr; // TODO (bhsueh) should set length penalties, this should be a gpu tensor When it is set as
|
||||
// nullptr, the kernel will use default value (1.0f) automatically.
|
||||
|
||||
beamHypotheses.early_stoppings = nullptr; // TODO (wili), similar as length_penalties
|
||||
beamHypotheses.output_ids_tgt = bufferCast<TokenIdType>(*decodingOutput.beamHypotheses.outputIdsTgt);
|
||||
beamHypotheses.sequence_lengths_tgt = bufferCast<SizeType>(*decodingOutput.beamHypotheses.sequenceLengthsTgt);
|
||||
beamHypotheses.cum_log_probs = bufferCast<float>(*decodingOutput.beamHypotheses.cumLogProbs);
|
||||
@ -381,7 +383,7 @@ void GptDecoder<T>::gatherTree(ITensor& finalOutputIds, DecodingOutput const& de
|
||||
batchSize, stream.get());
|
||||
sync_check_cuda_error();
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
@ -394,7 +396,7 @@ void IGptDecoder::acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor
|
||||
ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths, ITensor const& finishedVec,
|
||||
ITensor& finishedFinal, ITensor& finishedSum, ITensor const& batchSlots, BufferManager::CudaStreamPtr const& stream)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const finishedVecShape = finishedVec.getShape();
|
||||
auto const maxBatchSize = finishedVecShape.d[1];
|
||||
@ -439,7 +441,7 @@ void IGptDecoder::acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void IGptDecoder::acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const& targetLogits, ITensor& draftProbs,
|
||||
@ -447,7 +449,7 @@ void IGptDecoder::acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const&
|
||||
SizeType vocabSize, SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
|
||||
curandState_t* curandState, BufferManager::CudaStreamPtr const& stream)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const draftLogitsShape = draftLogits.getShape();
|
||||
auto const maxBatchSize = draftLogitsShape.d[0];
|
||||
@ -488,5 +490,5 @@ void IGptDecoder::acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const&
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
@ -22,6 +22,7 @@
|
||||
#include "tensorrt_llm/runtime/runtimeKernels.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <memory>
|
||||
|
||||
using namespace tensorrt_llm::runtime;
|
||||
@ -33,7 +34,7 @@ namespace
|
||||
{
|
||||
SamplingConfig extractSamplingConfig(SamplingConfig const& batchSamplingConfig, SizeType batchIdx)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
SamplingConfig samplingConfig{batchSamplingConfig.beamWidth};
|
||||
|
||||
auto extractOptional = [&batchIdx](auto& single, auto const& batch)
|
||||
@ -64,8 +65,9 @@ SamplingConfig extractSamplingConfig(SamplingConfig const& batchSamplingConfig,
|
||||
// beam search layer
|
||||
samplingConfig.beamSearchDiversityRate = batchSamplingConfig.beamSearchDiversityRate;
|
||||
samplingConfig.lengthPenalty = batchSamplingConfig.lengthPenalty;
|
||||
samplingConfig.earlyStopping = batchSamplingConfig.earlyStopping;
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return samplingConfig;
|
||||
}
|
||||
|
||||
@ -78,7 +80,7 @@ GptDecoderBatch::GptDecoderBatch(
|
||||
, mStream{std::move(stream)}
|
||||
, mBufferManager{mStream}
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
|
||||
auto constexpr nvSizeType = TRTDataType<SizeType>::value;
|
||||
auto constexpr nvFloatType = TRTDataType<float>::value;
|
||||
@ -125,14 +127,14 @@ GptDecoderBatch::GptDecoderBatch(
|
||||
dInput->badWordsLens = mBufferManager.emptyTensor(MemoryType::kPINNED, TRTDataType<SizeType>::value);
|
||||
dInput->embeddingBias = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptDecoderBatch::setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth,
|
||||
SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep,
|
||||
bool fusedDecoder, nvinfer1::DataType dtype)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK(maxBatchSize > 0);
|
||||
TLLM_CHECK(maxBeamWidth > 0);
|
||||
TLLM_CHECK(maxTokensPerStep > 0);
|
||||
@ -253,13 +255,13 @@ void GptDecoderBatch::setup(DecodingMode const& mode, SizeType maxBatchSize, Siz
|
||||
mBeamWidths[i] = 0;
|
||||
mGeneratedTokensPerStep[i] = 0;
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptDecoderBatch::newRequest(
|
||||
SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK(batchIdx >= 0);
|
||||
auto const& jointOutputIdsShape = mJointDecodingOutput->ids->getShape();
|
||||
auto const batchSize = jointOutputIdsShape.d[0];
|
||||
@ -373,7 +375,7 @@ void GptDecoderBatch::newRequest(
|
||||
dOutput->newTokensVec.resize(mMaxTokensPerStep);
|
||||
for (SizeType ti = 0; ti < mMaxTokensPerStep; ++ti)
|
||||
{
|
||||
TensorPtr newTokensStepView = std::move(ITensor::slice(dJointOutput.newTokensSteps, ti, localBatchSize));
|
||||
TensorPtr newTokensStepView = ITensor::slice(dJointOutput.newTokensSteps, ti, localBatchSize);
|
||||
newTokensStepView->squeeze(0);
|
||||
dOutput->newTokensVec[ti] = ITensor::slice(newTokensStepView, batchIdx, localBatchSize);
|
||||
manager.setZero(*dOutput->newTokensVec[ti]);
|
||||
@ -469,7 +471,7 @@ void GptDecoderBatch::newRequest(
|
||||
auto outputIdsView = ITensor::view(outputIds, ITensor::makeShape({beamWidth, mMaxSequenceLength}));
|
||||
kernels::invokeFill(*outputIdsView, endId, *stream);
|
||||
kernels::tileTensor(*outputIdsView, *inputIdsView, beamWidth, *stream);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptDecoderBatch::newRequests(std::vector<SizeType> const& seqSlots,
|
||||
@ -500,7 +502,7 @@ void GptDecoderBatch::newRequests(std::vector<SizeType> const& seqSlots,
|
||||
GptDecoderBatch::TokenPtr GptDecoderBatch::forwardAsync(
|
||||
decoder_batch::Output& output, decoder_batch::Input const& input)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto& allTargetLogits = input.logits;
|
||||
|
||||
// TODO(nkorobov): check logits shape considering draft tokens
|
||||
@ -737,13 +739,13 @@ GptDecoderBatch::TokenPtr GptDecoderBatch::forwardAsync(
|
||||
|
||||
CudaEvent eventStop{};
|
||||
mStream->record(eventStop);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return std::make_unique<decoder_batch::Token>(std::move(eventStop), input.active);
|
||||
}
|
||||
|
||||
void GptDecoderBatch::forwardSync(decoder_batch::Token const& token)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
token.event.synchronize();
|
||||
|
||||
for (std::int32_t i = 0; i < mActualBatchSize; ++i)
|
||||
@ -756,13 +758,13 @@ void GptDecoderBatch::forwardSync(decoder_batch::Token const& token)
|
||||
|| bufferCast<SizeType>(*dOutput.finishedSum)[0] == static_cast<SizeType>(mBeamWidths[i]);
|
||||
}
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
// TODO call this at the end of forward if mFinished[i] changes from false to true?
|
||||
CudaEvent GptDecoderBatch::postProcessRequest(SizeType batchIdx) const
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto& stream = mStreams[batchIdx];
|
||||
auto manager = BufferManager{stream};
|
||||
auto& decoder = *mDecoders[batchIdx];
|
||||
@ -779,14 +781,14 @@ CudaEvent GptDecoderBatch::postProcessRequest(SizeType batchIdx) const
|
||||
CudaEvent event{};
|
||||
stream->record(event);
|
||||
mStream->wait(event);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return event;
|
||||
}
|
||||
|
||||
void GptDecoderBatch::newBatch(
|
||||
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
// split batch into single requests
|
||||
auto const& inputLengths = inputs.lengths;
|
||||
mActualBatchSize = inputLengths->getShape().d[0];
|
||||
@ -855,12 +857,12 @@ void GptDecoderBatch::newBatch(
|
||||
}
|
||||
newRequest(batchIdx, request, extractSamplingConfig(samplingConfig, batchIdx));
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptDecoderBatch::forwardAsync(decoder::Output& output, decoder::Input const& input)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const& logitsShape = input.logits->getShape();
|
||||
auto const batchSize = logitsShape.d[0];
|
||||
@ -886,32 +888,32 @@ void GptDecoderBatch::forwardAsync(decoder::Output& output, decoder::Input const
|
||||
kernels::reduce(*mFinishedSum, *ITensor::slice(mJointDecodingOutput->finishedSum, 0, mActualBatchSize), *mStream);
|
||||
mStream->record(mForwardEvent);
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptDecoderBatch::forwardSync()
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
forwardSync(*mForwardToken);
|
||||
// wait for mFinishedSum to be updated
|
||||
mForwardEvent.synchronize();
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptDecoderBatch::finalize() const
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
for (SizeType batchIdx = 0; batchIdx < mActualBatchSize; ++batchIdx)
|
||||
{
|
||||
postProcessRequest(batchIdx);
|
||||
auto event = postProcessRequest(batchIdx);
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
CudaEvent GptDecoderBatch::finalize(SizeType batchIdx) const
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto event = postProcessRequest(batchIdx);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return event;
|
||||
}
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
#include "tensorrt_llm/runtime/gptJsonConfig.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "gptModelConfig.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
@ -24,6 +25,7 @@
|
||||
#include <fstream>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
|
||||
using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
@ -69,12 +71,122 @@ std::optional<FieldType> parseJsonFieldOptional(Json const& json, std::string_vi
return value;
}

GptModelConfig createModelConfig(
Json const& json, bool engineVersionNone, SizeType tensorParallelism, nvinfer1::DataType dataType)
{
auto const& config = engineVersionNone ? json.at("builder_config") : json.at("pretrained_config");

auto const* const numLayersField = engineVersionNone ? "num_layers" : "num_hidden_layers";
auto const* const numHeadsField = engineVersionNone ? "num_heads" : "num_attention_heads";
auto const* const numKvHeadsField = engineVersionNone ? "num_kv_heads" : "num_key_value_heads";
auto const* const mlpHiddenSizeField = engineVersionNone ? "mlp_hidden_size" : "intermediate_size";

auto const numLayers = config.at(numLayersField).template get<SizeType>();
auto const numHeads = config.at(numHeadsField).template get<SizeType>() / tensorParallelism;

auto const vocabSize = config.at("vocab_size").template get<SizeType>();
auto const hiddenSize = config.at("hidden_size").template get<SizeType>() / tensorParallelism;
auto const sizePerHead = parseJsonFieldOr(config, "head_size", hiddenSize / numHeads);

// TODO:
// Code crashes when numKvHeads <= 0. Clamping downwards to 1 prevents that, make sure this is best fix.
auto const numKvHeads
= std::max(parseJsonFieldOr(config, numKvHeadsField, numHeads * tensorParallelism) / tensorParallelism, 1);

auto const mlpHiddenSize = parseJsonFieldOptional<SizeType>(config, mlpHiddenSizeField);

auto modelConfig = GptModelConfig{vocabSize, numLayers, numHeads, hiddenSize, dataType};
modelConfig.setSizePerHead(sizePerHead);
modelConfig.setNbKvHeads(numKvHeads);

if (mlpHiddenSize.has_value())
{
modelConfig.setMlpHiddenSize(mlpHiddenSize.value() / tensorParallelism);
}

return modelConfig;
};

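A worked example of the numKvHeads expression above, with hypothetical numbers: for num_key_value_heads = 8 and tensorParallelism = 4, each rank keeps max(8 / 4, 1) = 2 KV heads, and the std::max guard keeps the value at least 1 when heavy tensor parallelism would otherwise drive it to 0:

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        int const tensorParallelism = 4;
        int const numKvHeadsTotal = 8; // hypothetical value read from the JSON config
        int const numKvHeads = std::max(numKvHeadsTotal / tensorParallelism, 1);
        std::printf("per-rank KV heads: %d\n", numKvHeads); // prints 2
        return 0;
    }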
void parseBuilderConfig(GptModelConfig& modelConfig, Json const& builderConfig)
|
||||
{
|
||||
auto const maxBatchSize = parseJsonFieldOr(builderConfig, "max_batch_size", 0);
|
||||
auto const maxBeamWidth = parseJsonFieldOr(builderConfig, "max_beam_width", 0);
|
||||
auto const maxInputLen = parseJsonFieldOr(builderConfig, "max_input_len", 0);
|
||||
auto const maxSequenceLen = maxInputLen + parseJsonFieldOr(builderConfig, "max_output_len", 0);
|
||||
auto const maxDraftLen = parseJsonFieldOr(builderConfig, "max_draft_len", 0);
|
||||
auto const maxNumTokens = parseJsonFieldOptional<SizeType>(builderConfig, "max_num_tokens");
|
||||
auto const maxPromptEmbeddingTableSize
|
||||
= parseJsonFieldOr<SizeType>(builderConfig, "max_prompt_embedding_table_size", 0);
|
||||
auto const computeContextLogits = parseJsonFieldOr(builderConfig, "gather_context_logits", false);
|
||||
auto const computeGenerationLogits = parseJsonFieldOr(builderConfig, "gather_generation_logits", false);
|
||||
|
||||
modelConfig.setMaxBatchSize(maxBatchSize);
|
||||
modelConfig.setMaxBeamWidth(maxBeamWidth);
|
||||
modelConfig.setMaxInputLen(maxInputLen);
|
||||
modelConfig.setMaxSequenceLen(maxSequenceLen);
|
||||
modelConfig.setMaxNumTokens(maxNumTokens);
|
||||
modelConfig.setMaxDraftLen(maxDraftLen);
|
||||
modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);
|
||||
modelConfig.computeContextLogits(computeContextLogits);
|
||||
modelConfig.computeGenerationLogits(computeGenerationLogits);
|
||||
}
|
||||
|
||||
void parsePluginConfig(GptModelConfig& modelConfig, Json const& pluginConfig)
|
||||
{
|
||||
auto const useGptAttentionPlugin = !pluginConfig.at("gpt_attention_plugin").is_null();
|
||||
auto const removeInputPadding = pluginConfig.at("remove_input_padding").template get<bool>();
|
||||
auto const& pagedKvCache = pluginConfig.at("paged_kv_cache");
|
||||
auto const& tokensPerBlock = pluginConfig.at("tokens_per_block");
|
||||
auto const useCustomAllReduce = pluginConfig.at("use_custom_all_reduce").template get<bool>();
|
||||
auto const useContextFMHAForGeneration = pluginConfig.at("use_context_fmha_for_generation").template get<bool>();
|
||||
auto const pagedContextFMHA = pluginConfig.at("use_paged_context_fmha").template get<bool>();
|
||||
|
||||
modelConfig.useGptAttentionPlugin(useGptAttentionPlugin);
|
||||
modelConfig.usePackedInput(removeInputPadding);
|
||||
modelConfig.usePagedKvCache(pagedKvCache);
|
||||
modelConfig.setTokensPerBlock(tokensPerBlock);
|
||||
modelConfig.useCustomAllReduce(useCustomAllReduce);
|
||||
modelConfig.setUseContextFMHAForGeneration(useContextFMHAForGeneration);
|
||||
modelConfig.setPagedContextFMHA(pagedContextFMHA);
|
||||
}
|
||||
|
void parseLora(GptModelConfig& modelConfig, Json const& json, Json const& pluginConfig, bool engineVersionNone,
SizeType tensorParallelism)
{
auto const& config = engineVersionNone ? json.at("builder_config") : json.at("pretrained_config");

auto const loraMaxRank = parseJsonFieldOr(config, "max_lora_rank", SizeType{0});
auto const loraTargetModules = parseJsonFieldOptional<std::vector<std::string>>(config, "lora_target_modules");

if (loraTargetModules.has_value())
{
modelConfig.setLoraModules(LoraModule::createLoraModules(loraTargetModules.value(), modelConfig.getHiddenSize(),
modelConfig.getMlpHiddenSize(), modelConfig.getNbHeads(), modelConfig.getNbKvHeads(),
modelConfig.getSizePerHead(), tensorParallelism));
}

modelConfig.setMaxLoraRank(loraMaxRank);

auto useLoraPlugin = !pluginConfig.at("lora_plugin").is_null();
if (useLoraPlugin)
{
if (modelConfig.getLoraModules().empty() || modelConfig.getMaxLoraRank() == 0)
{
TLLM_LOG_WARNING("lora_plugin enabled, but no lora module enabled: setting useLoraPlugin to false");
useLoraPlugin = false;
}
}
modelConfig.useLoraPlugin(useLoraPlugin);
}

template <typename InputType>
GptJsonConfig parseJson(InputType&& i)
GptJsonConfig parseJson(InputType&& input)
{
auto constexpr allowExceptions = true;
auto constexpr ingoreComments = true;
auto const json = nlohmann::json::parse(i, nullptr, allowExceptions, ingoreComments);
auto const json = nlohmann::json::parse(std::forward<InputType>(input), nullptr, allowExceptions, ingoreComments);

auto const engineVersion = parseJsonFieldOr(json, "version", std::string("none"));

@ -104,113 +216,26 @@ GptJsonConfig parseJson(InputType&& i)
auto const precision = engineVersionNone ? builderConfig.at("precision").template get<std::string>()
: json.at("pretrained_config").at("dtype").template get<std::string>();

auto dataType = nvinfer1::DataType::kFLOAT;
if (!precision.compare("float32"))
dataType = nvinfer1::DataType::kFLOAT;
else if (!precision.compare("float16"))
dataType = nvinfer1::DataType::kHALF;
else if (!precision.compare("bfloat16"))
dataType = nvinfer1::DataType::kBF16;
else
TLLM_CHECK_WITH_INFO(false, tc::fmtstr("Model data type '%s' not supported", precision.c_str()));

auto modelConfig = [&engineVersionNone, &json, &builderConfig, &tensorParallelism, &dataType]()
auto const dataType = [&precision]()
{
auto const& config = engineVersionNone ? builderConfig : json.at("pretrained_config");

std::string const numLayersField = engineVersionNone ? "num_layers" : "num_hidden_layers";
std::string const numHeadsField = engineVersionNone ? "num_heads" : "num_attention_heads";
std::string const numKvHeadsField = engineVersionNone ? "num_kv_heads" : "num_key_value_heads";
std::string const mlpHiddenSizeField = engineVersionNone ? "mlp_hidden_size" : "intermediate_size";

auto const numLayers = config.at(numLayersField).template get<SizeType>();
auto const numHeads = config.at(numHeadsField).template get<SizeType>() / tensorParallelism;

auto const vocabSize = config.at("vocab_size").template get<SizeType>();
auto const hiddenSize = config.at("hidden_size").template get<SizeType>() / tensorParallelism;
auto const sizePerHead = parseJsonFieldOr(config, "head_size", hiddenSize / numHeads);
auto const loraMaxRank = parseJsonFieldOr(config, "max_lora_rank", SizeType{0});
auto const loraTargetModules = parseJsonFieldOptional<std::vector<std::string>>(config, "lora_target_modules");

// TODO:
// Code crashes when numKvHeads <= 0. Clamping downwards to 1 prevents that, make sure this is best fix.
auto const numKvHeads
= std::max(parseJsonFieldOr(config, numKvHeadsField, numHeads * tensorParallelism) / tensorParallelism, 1);

auto const mlpHiddenSize = parseJsonFieldOptional<SizeType>(config, mlpHiddenSizeField);

auto modelConfig = GptModelConfig{vocabSize, numLayers, numHeads, hiddenSize, dataType};
modelConfig.setSizePerHead(sizePerHead);
modelConfig.setNbKvHeads(numKvHeads);

if (mlpHiddenSize.has_value())
{
modelConfig.setMlpHiddenSize(mlpHiddenSize.value() / tensorParallelism);
}

if (loraTargetModules.has_value())
{
modelConfig.setLoraModules(LoraModule::createLoraModules(loraTargetModules.value(),
modelConfig.getHiddenSize(), modelConfig.getMlpHiddenSize(), modelConfig.getNbHeads(),
modelConfig.getNbKvHeads(), modelConfig.getSizePerHead(), tensorParallelism));
}

modelConfig.setMaxLoraRank(loraMaxRank);

return modelConfig;
if (!precision.compare("float32"))
return nvinfer1::DataType::kFLOAT;
else if (!precision.compare("float16"))
return nvinfer1::DataType::kHALF;
else if (!precision.compare("bfloat16"))
return nvinfer1::DataType::kBF16;
else
TLLM_THROW("Model data type '%s' not supported", precision.c_str());
}();

auto const maxBatchSize = parseJsonFieldOr(builderConfig, "max_batch_size", 0);
auto const maxBeamWidth = parseJsonFieldOr(builderConfig, "max_beam_width", 0);
auto const maxInputLen = parseJsonFieldOr(builderConfig, "max_input_len", 0);
auto const maxSequenceLen = maxInputLen + parseJsonFieldOr(builderConfig, "max_output_len", 0);
auto const maxDraftLen = parseJsonFieldOr(builderConfig, "max_draft_len", 0);
auto const maxNumTokens = parseJsonFieldOptional<SizeType>(builderConfig, "max_num_tokens");
auto const maxPromptEmbeddingTableSize
= parseJsonFieldOr<SizeType>(builderConfig, "max_prompt_embedding_table_size", 0);
auto const computeContextLogits = parseJsonFieldOr(builderConfig, "gather_context_logits", false);
auto const computeGenerationLogits = parseJsonFieldOr(builderConfig, "gather_generation_logits", false);
auto modelConfig = createModelConfig(json, engineVersionNone, tensorParallelism, dataType);

modelConfig.setMaxBatchSize(maxBatchSize);
modelConfig.setMaxBeamWidth(maxBeamWidth);
modelConfig.setMaxInputLen(maxInputLen);
modelConfig.setMaxSequenceLen(maxSequenceLen);
modelConfig.setMaxNumTokens(maxNumTokens);
modelConfig.setMaxDraftLen(maxDraftLen);
modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);
modelConfig.computeContextLogits(computeContextLogits);
modelConfig.computeGenerationLogits(computeGenerationLogits);
parseBuilderConfig(modelConfig, builderConfig);

auto const& pluginConfig = engineVersionNone ? json.at("plugin_config") : builderConfig.at("plugin_config");
parsePluginConfig(modelConfig, pluginConfig);

auto const useGptAttentionPlugin = !pluginConfig.at("gpt_attention_plugin").is_null();
auto const removeInputPadding = pluginConfig.at("remove_input_padding").template get<bool>();
auto const pagedKvCache = pluginConfig.at("paged_kv_cache");
auto const tokensPerBlock = pluginConfig.at("tokens_per_block");
auto const useCustomAllReduce = pluginConfig.at("use_custom_all_reduce").template get<bool>();
auto const useContextFMHAForGeneration = pluginConfig.at("use_context_fmha_for_generation").template get<bool>();
auto const pagedContextFMHA = pluginConfig.at("use_paged_context_fmha").template get<bool>();

modelConfig.useGptAttentionPlugin(useGptAttentionPlugin);
modelConfig.usePackedInput(removeInputPadding);
modelConfig.usePagedKvCache(pagedKvCache);
modelConfig.setTokensPerBlock(tokensPerBlock);
modelConfig.useCustomAllReduce(useCustomAllReduce);
modelConfig.setUseContextFMHAForGeneration(useContextFMHAForGeneration);
modelConfig.setPagedContextFMHA(pagedContextFMHA);

auto useLoraPlugin = !pluginConfig.at("lora_plugin").is_null();
if (useLoraPlugin)
{
if (modelConfig.getLoraModules().empty() || modelConfig.getMaxLoraRank() == 0)
{
TLLM_LOG_WARNING("lora_plugin enabled, but no lora module enabled: setting useLoraPlugin to false");
useLoraPlugin = false;
}
}
modelConfig.useLoraPlugin(useLoraPlugin);
parseLora(modelConfig, json, pluginConfig, engineVersionNone, tensorParallelism);

if (engineVersionNone)
{

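Both the legacy (builder_config) and the versioned (pretrained_config) layouts now funnel through the same small helpers above. A hedged sketch of how a caller might load such a config and locate the engine file, assuming the public GptJsonConfig::parse overloads, engineFilename, and WorldConfig::mpi keep their usual shape (names and paths here are placeholders, not taken from this commit):

#include "tensorrt_llm/runtime/gptJsonConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"

#include <filesystem>

using namespace tensorrt_llm::runtime;

// Sketch: engine_dir is a placeholder path.
auto const engineDir = std::filesystem::path{"engine_dir"};
auto const jsonConfig = GptJsonConfig::parse(engineDir / "config.json");
auto const worldConfig = WorldConfig::mpi(
    jsonConfig.getGpusPerNode(), jsonConfig.getTensorParallelism(), jsonConfig.getPipelineParallelism());
auto const enginePath = engineDir / jsonConfig.engineFilename(worldConfig);
auto const& modelConfig = jsonConfig.getModelConfig();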
@ -23,7 +23,6 @@
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/customAllReduceUtils.h"
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/runtime/gptDecoderBatch.h"
#include "tensorrt_llm/runtime/ipcUtils.h"
#include "tensorrt_llm/runtime/ncclCommunicator.h"
@ -189,29 +188,20 @@ void GptSession::createKvCacheManager(SizeType batchSize, SizeType beamWidth, Si
kvDtype = mModelConfig.getDataType();
}

auto const maxNumTokens
= bmkv::KVCacheManager::getMaxNumTokens(kvCacheConfig, kvDtype, mModelConfig, mWorldConfig, getBufferManager());
TLLM_LOG_INFO("Using %d tokens in paged KV cache.", maxNumTokens);
auto const maxNumBlocks = tc::ceilDiv(maxNumTokens, tokensPerBlock);
auto const maxNumBlocks = bmkv::KVCacheManager::calculateMaxNumBlocks(
kvCacheConfig, kvDtype, mModelConfig, mWorldConfig, getBufferManager());

SizeType sinkTokensInLastBlock = sinkTokenLength % tokensPerBlock;
SizeType bubbleLen = sinkTokensInLastBlock != 0 ? tokensPerBlock - sinkTokensInLastBlock : 0;
auto maxBlocksPerSeq = tc::ceilDiv(maxAttentionWindow + bubbleLen, tokensPerBlock);
// If beamWidth > 1, use one more block for each sequence in the paged kv cache to avoid dropping the needed
// tokens, when enabling cyclic kv cache.
auto const useOneMoreBlock = beamWidth > 1 && maxSequenceLength > maxAttentionWindow;
if (useOneMoreBlock)
{
maxBlocksPerSeq += 1;
}

auto const localNbLayers = mModelConfig.getNbLayers(mWorldConfig.getPipelineParallelism());
auto const nbKvHeads = mModelConfig.getNbKvHeads();
auto const sizePerHead = mModelConfig.getSizePerHead();
bool enableBlockReuse{false};
bool constexpr enableBlockReuse{false};
mKvCacheManager = std::make_shared<bmkv::KVCacheManager>(localNbLayers, nbKvHeads, sizePerHead, tokensPerBlock,
maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, maxAttentionWindow, sinkTokenLength, useOneMoreBlock,
kvDtype, mRuntime->getStreamPtr(), enableBlockReuse, kvCacheConfig.useUvm);
maxNumBlocks, batchSize, beamWidth, maxAttentionWindow, sinkTokenLength, useOneMoreBlock, kvDtype,
mRuntime->getStreamPtr(), enableBlockReuse, kvCacheConfig.useUvm);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

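The block budgeting in createKvCacheManager is easy to misread, so here is the same arithmetic pulled out on its own: sink tokens that do not fill a whole block leave a "bubble" that must also be covered, and beam search with a cyclic attention window reserves one extra block per sequence. A standalone sketch (not a library function) using the quantities from the code above:

// SizeType stands in for the runtime's tensorrt_llm::runtime::SizeType (an int).
using SizeType = int;

// Mirrors the computation shown in GptSession::createKvCacheManager.
SizeType computeMaxBlocksPerSeq(SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType tokensPerBlock,
    SizeType beamWidth, SizeType maxSequenceLength)
{
    SizeType const sinkTokensInLastBlock = sinkTokenLength % tokensPerBlock;
    SizeType const bubbleLen = sinkTokensInLastBlock != 0 ? tokensPerBlock - sinkTokensInLastBlock : 0;
    // ceilDiv(a, b) == (a + b - 1) / b
    SizeType maxBlocksPerSeq = (maxAttentionWindow + bubbleLen + tokensPerBlock - 1) / tokensPerBlock;
    // With beam search and a cyclic KV cache, keep one more block so needed tokens are not dropped.
    bool const useOneMoreBlock = beamWidth > 1 && maxSequenceLength > maxAttentionWindow;
    if (useOneMoreBlock)
    {
        maxBlocksPerSeq += 1;
    }
    return maxBlocksPerSeq;
}

After this change the session no longer passes maxBlocksPerSeq to the KVCacheManager constructor; the equivalent bookkeeping moves behind calculateMaxNumBlocks and the manager itself.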
@ -335,12 +325,14 @@ void GptSession::setup(Config const& sessionConfig)
createCustomAllReduceWorkspace(mMicroBatchConfig.genBatchSize, maxBeamWidth, maxSequenceLength);
}

auto* kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManager.get() : nullptr;

for (auto& buffers : mBuffers)
{
// we don't know maxInputLength yet and ignore it for pre-allocation
buffers->generationConfig = RuntimeBuffers::GenerationConfig{
mMicroBatchConfig.genBatchSize, maxBeamWidth, 0, maxAttentionWindow, sinkTokenLength, maxSequenceLength};
buffers->reshape(mModelConfig, mWorldConfig);
buffers->reshape(kvCacheManager, mModelConfig, mWorldConfig);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -684,6 +676,8 @@ void GptSession::generateBatched(std::vector<GenerationOutput>& microBatchesOutp
TLLM_CHECK(numMicroBatches <= mMicroBatchConfig.numGenBatches);
SizeType const beamWidth{samplingConfig.beamWidth};

auto* kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManager.get() : nullptr;

// Initialize and reshape buffers
for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
{
@ -691,7 +685,7 @@ void GptSession::generateBatched(std::vector<GenerationOutput>& microBatchesOutp
auto& buffers = *mBuffers.at(microBatchId);
buffers.initFromInput(*microBatchInputs.ids, microBatchInputs.lengths, microBatchInputs.packed, beamWidth,
mDecoderMaxAttentionWindow, mDecoderSinkTokenLength, mDecoderMaxSequenceLength, manager);
buffers.reshape(mModelConfig, mWorldConfig);
buffers.reshape(kvCacheManager, mModelConfig, mWorldConfig);
buffers.reset(manager);
}

@ -746,8 +740,6 @@ void GptSession::generateBatched(std::vector<GenerationOutput>& microBatchesOutp
}
}

auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManager.get() : nullptr;

auto const profileContext = !kProfileMbIdxs.empty() && kProfileMbIdxs.count(0) > 0;
if (profileContext)
cudaProfilerStart();

@ -49,7 +49,7 @@ LoraManager::LoraReqTensors& LoraManager::getTask(TaskIdType reqId)
void LoraManager::create(
GptModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& manager)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto modules = modelConfig.getLoraModules();
SizeType modOff = 0;
@ -62,14 +62,14 @@ void LoraManager::create(
// TODO set this size from max adapter size
mWorkspace = manager.emptyTensor(MemoryType::kGPU, modelConfig.getDataType());

TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, ReqIdsVec const& reqIds,
std::vector<SizeType> const& reqBeamWidth, std::vector<bool> const& loraEnabled, SizeType numContextRequests,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
auto firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;
auto lastLayerId = firstLayerId + localNbLayers;
@ -85,13 +85,13 @@ void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes
fillInputTensors(
weightsPtrs, adapterSizes, bid, reqIds[bid], reqBeamWidth[bid], firstLayerId, lastLayerId, tpSize, tpRank);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, SizeType batchIdx, TaskIdType taskId,
SizeType beamWidth, SizeType firstLayerId, SizeType lastLayerId, SizeType tpSize, SizeType tpRank)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto weightsPointersPtr = bufferCast<int64_t>(*weightsPtrs);
auto adapterSizesPtr = bufferCast<int32_t>(*adapterSizes);
@ -172,13 +172,13 @@ void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes
}
std::fill_n(writeAdapterSizesPtr, beamWidth, adapterSize);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void LoraManager::insertInputTensors(TensorMap& inputTensors, TensorPtr weightsPtrs, TensorPtr adapterSizes,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig) const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
auto firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;

@ -210,13 +210,13 @@ void LoraManager::insertInputTensors(TensorMap& inputTensors, TensorPtr weightsP
}
}
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void LoraManager::formatTaskTensors(LoraWeightsTensorPtr weights, LoraConfigTensorPtr config,
GptModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& manager)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
weights->squeeze(0);
config->squeeze(0);

@ -261,7 +261,7 @@ void LoraManager::formatTaskTensors(LoraWeightsTensorPtr weights, LoraConfigTens
manager.copy(*mWorkspace, *weightsOut);
}
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void LoraManager::reset()

@ -38,13 +38,18 @@ std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> c
{

case ModuleType::kATTN_QKV:
case ModuleType::kCROSS_ATTN_QKV:
modules.emplace_back(
LoraModule(t, hidden, (numHeads * attnHeadSize + 2 * numKvHeads * attnHeadSize), false, true, -1, 0));
break;
case ModuleType::kATTN_Q:
case ModuleType::kATTN_K:
case ModuleType::kATTN_V: modules.emplace_back(t, hidden, hidden, false, true, -1, 0); break;
case ModuleType::kATTN_DENSE: modules.emplace_back(t, hidden, hidden, false, true, 1, -1); break;
case ModuleType::kATTN_V:
case ModuleType::kCROSS_ATTN_Q:
case ModuleType::kCROSS_ATTN_K:
case ModuleType::kCROSS_ATTN_V: modules.emplace_back(t, hidden, hidden, false, true, -1, 0); break;
case ModuleType::kATTN_DENSE:
case ModuleType::kCROSS_ATTN_DENSE: modules.emplace_back(t, hidden, hidden, false, true, 1, -1); break;
case ModuleType::kMLP_H_TO_4H: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break;
case ModuleType::kMLP_GATE: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break;
case ModuleType::kMLP_4H_TO_H: modules.emplace_back(t, mlpHiddenSize, hidden, false, true, 1, -1); break;

@ -39,6 +39,11 @@ public:
kMLP_H_TO_4H = 5,
kMLP_4H_TO_H = 6,
kMLP_GATE = 7,
kCROSS_ATTN_QKV = 8,
kCROSS_ATTN_Q = 9,
kCROSS_ATTN_K = 10,
kCROSS_ATTN_V = 11,
kCROSS_ATTN_DENSE = 12,
};

explicit constexpr LoraModule(ModuleType const& t, SizeType inDim, SizeType outDim, bool inDimFirst,
@ -128,6 +133,16 @@ public:
return ModuleType::kMLP_4H_TO_H;
else if (name == "mlp_gate")
return ModuleType::kMLP_GATE;
else if (name == "cross_attn_qkv")
return ModuleType::kCROSS_ATTN_QKV;
else if (name == "cross_attn_q")
return ModuleType::kCROSS_ATTN_Q;
else if (name == "cross_attn_k")
return ModuleType::kCROSS_ATTN_K;
else if (name == "cross_attn_v")
return ModuleType::kCROSS_ATTN_V;
else if (name == "cross_attn_dense")
return ModuleType::kCROSS_ATTN_DENSE;
else
return ModuleType::kINVALID;
}
@ -144,6 +159,11 @@ public:
case ModuleType::kMLP_H_TO_4H: return "mlp_h_to_4h";
case ModuleType::kMLP_4H_TO_H: return "mlp_4h_to_h";
case ModuleType::kMLP_GATE: return "mlp_gate";
case ModuleType::kCROSS_ATTN_QKV: return "cross_attn_qkv";
case ModuleType::kCROSS_ATTN_Q: return "cross_attn_q";
case ModuleType::kCROSS_ATTN_K: return "cross_attn_k";
case ModuleType::kCROSS_ATTN_V: return "cross_attn_v";
case ModuleType::kCROSS_ATTN_DENSE: return "cross_attn_dense";
case ModuleType::kINVALID: return "INVALID";
}
return "INVALID";

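The two switch blocks above extend the string-to-enum mapping and its inverse with the new cross-attention module kinds. A hedged round-trip sketch, assuming the enclosing functions are LoraModule::toModuleType and LoraModule::toModuleName (the hunk headers do not show their names, so treat these names as an assumption):

#include "tensorrt_llm/runtime/loraModule.h"

#include <cassert>
#include <string>

using tensorrt_llm::runtime::LoraModule;

// Sketch: exercises one of the new cross-attention entries end to end.
void checkCrossAttnMapping()
{
    auto const type = LoraModule::toModuleType("cross_attn_qkv");
    assert(type == LoraModule::ModuleType::kCROSS_ATTN_QKV);
    assert(std::string(LoraModule::toModuleName(type)) == "cross_attn_qkv");
}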
@ -82,4 +82,10 @@ void MemoryCounters::deallocate(MemoryType memoryType, MemoryCounters::SizeType
default: TLLM_THROW("Unknown memory type");
}
}

MemoryCounters& MemoryCounters::getInstance()
{
static MemoryCounters mInstance;
return mInstance;
}
} // namespace tensorrt_llm::runtime

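MemoryCounters::getInstance now returns a function-local static (a Meyers singleton), so the instance is constructed lazily on first use and its initialization is thread-safe under C++11 rules without explicit locking. A minimal sketch of the same pattern in isolation:

class Counters
{
public:
    static Counters& getInstance()
    {
        static Counters instance; // constructed once, on first use, thread-safely
        return instance;
    }

private:
    Counters() = default;
};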
||||
@ -17,13 +17,13 @@
|
||||
#include "tensorrt_llm/runtime/runtimeBuffers.h"
|
||||
|
||||
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/stlUtils.h"
|
||||
#include "tensorrt_llm/runtime/runtimeKernels.h"
|
||||
#include "tensorrt_llm/runtime/tllmRuntime.h"
|
||||
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
|
||||
using namespace tensorrt_llm::runtime;
|
||||
namespace tc = tensorrt_llm::common;
|
||||
@ -32,7 +32,7 @@ RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITe
|
||||
ITensor const& inputLengthsHost, bool const inputPacked, SizeType const beamWidth,
|
||||
SizeType const maxAttentionWindow, SizeType const sinkTokenLength, SizeType const maxSequenceLength)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto const batchSize = static_cast<SizeType>(inputLengthsHost.getSize());
|
||||
|
||||
auto const* inputLengthsPtr = bufferCast<SizeType>(inputLengthsHost);
|
||||
@ -71,14 +71,14 @@ RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITe
|
||||
"Max input length is equal to or larger that maxSequenceLength given in setup. No new tokens can be "
|
||||
"generated.");
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return GenerationConfig{
|
||||
batchSize, beamWidth, maxInputLength, maxAttentionWindow, sinkTokenLength, maxSequenceLength, inputLengthSum};
|
||||
}
|
||||
|
||||
void RuntimeBuffers::clear()
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
contextLengthsHost = nullptr;
|
||||
contextLengthsDevice = nullptr;
|
||||
|
||||
@ -104,22 +104,22 @@ void RuntimeBuffers::clear()
|
||||
hiddenStates = nullptr;
|
||||
|
||||
allocated = false;
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void RuntimeBuffers::clearTensorMaps()
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
for (auto& buffer : inputBuffers)
|
||||
buffer.clear();
|
||||
for (auto& buffer : outputBuffers)
|
||||
buffer.clear();
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto& manager = runtime.getBufferManager();
|
||||
auto& engine = runtime.getEngine();
|
||||
|
||||
@ -170,8 +170,7 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
|
||||
|
||||
if (modelConfig.usePagedKvCache())
|
||||
{
|
||||
auto const kvCacheBlockPointersType
|
||||
= engine.getTensorDataType(("kv_cache_block_pointers_" + std::to_string(firstLayerId)).c_str());
|
||||
auto const kvCacheBlockPointersType = engine.getTensorDataType("kv_cache_block_pointers");
|
||||
kvCacheBlockPointersHost = manager.emptyTensor(MemoryType::kCPU, kvCacheBlockPointersType);
|
||||
kvCacheBlockPointersDevice = manager.emptyTensor(MemoryType::kGPU, kvCacheBlockPointersType);
|
||||
}
|
||||
@ -183,8 +182,7 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
|
||||
if (modelConfig.useGptAttentionPlugin())
|
||||
{
|
||||
pastKeyValueLengths = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
|
||||
maxAttentionWindows
|
||||
= utils::createBufferVector(runtime, localNbLayers, MemoryType::kCPU, nvinfer1::DataType::kINT32);
|
||||
maxAttentionWindows = BufferManager::cpu(ITensor::makeShape({localNbLayers}), nvinfer1::DataType::kINT32);
|
||||
sinkTokenLengths = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
|
||||
}
|
||||
else
|
||||
@ -207,7 +205,7 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
|
||||
hiddenStates = manager.emptyTensor(MemoryType::kGPU, modelConfig.getDataType());
|
||||
}
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inputLengths, bool inputPacked,
|
||||
@ -223,15 +221,15 @@ void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inp
|
||||
inputIds, *contextLengthsHost, inputPacked, beamWidth, maxAttentionWindow, sinkTokenLength, maxSequenceLength);
|
||||
}
|
||||
|
||||
void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
|
||||
void RuntimeBuffers::reshape(
|
||||
KvCacheManager const* kvCacheManager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const batchSize = generationConfig.batchSize;
|
||||
auto const beamWidth = generationConfig.beamWidth;
|
||||
auto const maxInputLength = generationConfig.maxInputLength;
|
||||
auto const maxAttentionWindow = generationConfig.maxAttentionWindow;
|
||||
auto const sinkTokenLen = generationConfig.sinkTokenLength;
|
||||
auto const maxSeqLength = generationConfig.maxSeqLength;
|
||||
auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
|
||||
|
||||
@ -259,13 +257,12 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
|
||||
if (modelConfig.computeGenerationLogits())
|
||||
{
|
||||
allGenerationLogits->reshape(
|
||||
ITensor::makeShape({(generationConfig.maxSeqLength - generationConfig.maxInputLength), batchSize,
|
||||
beamWidth, vocabSizePadded}));
|
||||
ITensor::makeShape({(maxSeqLength - maxInputLength), batchSize, beamWidth, vocabSizePadded}));
|
||||
|
||||
cacheGenerationFragmentPointerDevice->reshape(
|
||||
ITensor::makeShape({batchSize, (generationConfig.maxSeqLength - generationConfig.maxInputLength)}));
|
||||
ITensor::makeShape({batchSize, (maxSeqLength - maxInputLength)}));
|
||||
cacheGenerationFragmentPointerHost->reshape(
|
||||
ITensor::makeShape({batchSize, (generationConfig.maxSeqLength - generationConfig.maxInputLength)}));
|
||||
ITensor::makeShape({batchSize, (maxSeqLength - maxInputLength)}));
|
||||
}
|
||||
}
|
||||
|
||||
@ -277,18 +274,11 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
|
||||
= ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(), maxInputLength, modelConfig.getSizePerHead()});
|
||||
if (modelConfig.usePagedKvCache())
|
||||
{
|
||||
auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
|
||||
auto const tokensPerBlock = modelConfig.getTokensPerBlock();
|
||||
SizeType bubbleLen
|
||||
= (sinkTokenLen % tokensPerBlock == 0) ? 0 : tokensPerBlock - (sinkTokenLen % tokensPerBlock);
|
||||
auto maxBlocksPerSeq = tc::ceilDiv(maxAttentionWindow + bubbleLen, tokensPerBlock);
|
||||
// If beamWidth > 1, use one more block for each sequence in the paged kv cache to avoid dropping the needed
|
||||
// tokens, when enabling cyclic kv cache.
|
||||
if (beamWidth > 1 && maxSeqLength > maxAttentionWindow)
|
||||
{
|
||||
maxBlocksPerSeq += 1;
|
||||
}
|
||||
TLLM_CHECK(kvCacheManager);
|
||||
|
||||
auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
|
||||
|
||||
auto const maxBlocksPerSeq = kvCacheManager->getMaxBlocksPerSeq();
|
||||
// reserve batchSize * beamWidth and resize to batchSize
|
||||
auto cacheBlockPointersShape = ITensor::makeShape({localNbLayers, batchSize * beamWidth, 2, maxBlocksPerSeq});
|
||||
kvCacheBlockPointersHost->reshape(cacheBlockPointersShape);
|
||||
@ -302,11 +292,13 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
|
||||
utils::reshapeBufferVector(presentKeysVals, kvCacheReserve);
|
||||
}
|
||||
|
||||
auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
|
||||
|
||||
if (modelConfig.useGptAttentionPlugin())
|
||||
{
|
||||
pastKeyValueLengths->reshape(ITensor::makeShape({batchSize}));
|
||||
requestTypes->reshape(ITensor::makeShape({batchSize}));
|
||||
utils::reshapeBufferVector(maxAttentionWindows, ITensor::makeShape({1}));
|
||||
maxAttentionWindows->reshape(ITensor::makeShape({localNbLayers}));
|
||||
sinkTokenLengths->reshape(ITensor::makeShape({1}));
|
||||
}
|
||||
else
|
||||
@ -332,7 +324,7 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
|
||||
}
|
||||
|
||||
allocated = true;
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void RuntimeBuffers::reset(BufferManager& manager)
|
||||
@ -345,7 +337,7 @@ void RuntimeBuffers::reset(BufferManager& manager)
|
||||
std::vector<RuntimeBuffers> RuntimeBuffers::split(
|
||||
SizeType contextBatchSize, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
std::vector<RuntimeBuffers> bufferSlices;
|
||||
auto const generationBatchSize = generationConfig.batchSize;
|
||||
@ -432,14 +424,14 @@ std::vector<RuntimeBuffers> RuntimeBuffers::split(
|
||||
}
|
||||
}
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return bufferSlices;
|
||||
}
|
||||
|
||||
void RuntimeBuffers::gatherLastTokenLogits(
|
||||
BufferManager& manager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK_WITH_INFO(modelConfig.computeContextLogits(),
|
||||
"Gather last token logits is only necessary when context logits are computed");
|
||||
|
||||
@ -459,12 +451,12 @@ void RuntimeBuffers::gatherLastTokenLogits(
|
||||
}
|
||||
}
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void RuntimeBuffers::copyAttentionMasks(std::vector<RuntimeBuffers> const& contextBatches, BufferManager& manager)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto const batchSize = generationConfig.batchSize;
|
||||
auto const maxInputLength = generationConfig.maxInputLength;
|
||||
|
||||
@ -481,12 +473,12 @@ void RuntimeBuffers::copyAttentionMasks(std::vector<RuntimeBuffers> const& conte
|
||||
manager.copy(*buffers.attentionMask, *attentionMaskSlice);
|
||||
offset += contextBatchSize;
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void RuntimeBuffers::tile(BufferManager& manager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto const beamWidth = generationConfig.beamWidth;
|
||||
TLLM_CHECK_WITH_INFO(beamWidth > 1, "Tiling is only necessary for beam search.");
|
||||
|
||||
@ -519,13 +511,13 @@ void RuntimeBuffers::tile(BufferManager& manager, GptModelConfig const& modelCon
|
||||
for (auto& buffer : presentKeysValsAlt)
|
||||
utils::tileBufferReplace(buffer, beamWidth, manager);
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void RuntimeBuffers::postContextStep(std::vector<RuntimeBuffers> const& contextBuffers, BufferManager& manager,
|
||||
GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto const batchSize = generationConfig.batchSize;
|
||||
auto const beamWidth = generationConfig.beamWidth;
|
||||
|
||||
@ -580,14 +572,14 @@ void RuntimeBuffers::postContextStep(std::vector<RuntimeBuffers> const& contextB
|
||||
manager, modelConfig.usePackedInput());
|
||||
}
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType const padId, BufferManager& manager,
|
||||
KvCacheManager const* kvCacheManager, SizeType firstBatchSlotIdx, GptModelConfig const& modelConfig,
|
||||
WorldConfig const& worldConfig)
|
||||
batch_manager::kv_cache_manager::KVCacheManager const* kvCacheManager, SizeType firstBatchSlotIdx,
|
||||
GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto& stream = manager.getStream();
|
||||
SizeType const batchSize = generationConfig.batchSize;
|
||||
SizeType const maxInputLength = generationConfig.maxInputLength;
|
||||
@ -607,11 +599,9 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
|
||||
TLLM_CHECK(requestTypes->getSize() == static_cast<std::size_t>(batchSize));
|
||||
std::fill_n(RequestTypesPtr, batchSize, 0);
|
||||
|
||||
// Set maxAttentionWindows buffer and sinkTokenLengths to the same value currently.
|
||||
for (auto layer = 0; layer < localNbLayers; ++layer)
|
||||
{
|
||||
bufferCast<SizeType>(*maxAttentionWindows[layer])[0] = generationConfig.maxAttentionWindow;
|
||||
}
|
||||
auto maxAttentionWindowsPtr = bufferCast<SizeType>(*maxAttentionWindows);
|
||||
std::fill_n(maxAttentionWindowsPtr, localNbLayers, generationConfig.maxAttentionWindow);
|
||||
|
||||
bufferCast<SizeType>(*sinkTokenLengths)[0] = generationConfig.sinkTokenLength;
|
||||
|
||||
auto const& inputShape = inputIds->getShape();
|
||||
@ -720,14 +710,14 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
|
||||
manager.copy(*contextLengthsDevice, *lastTokenIds);
|
||||
}
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, BufferManager& manager,
|
||||
KvCacheManager* kvCacheManager, SizeType firstBatchSlotIdx, GptModelConfig const& modelConfig,
|
||||
WorldConfig const& worldConfig)
|
||||
batch_manager::kv_cache_manager::KVCacheManager* kvCacheManager, SizeType firstBatchSlotIdx,
|
||||
GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto& stream = manager.getStream();
|
||||
SizeType const batchSize = generationConfig.batchSize;
|
||||
SizeType const beamWidth = generationConfig.beamWidth;
|
||||
@ -842,7 +832,7 @@ RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, B
|
||||
{
|
||||
kernels::invokeInclusiveSum(*lastTokenIds, *lastTokenIds, manager, stream);
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return nextInputIds;
|
||||
}
|
||||
|
||||
@ -850,7 +840,7 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
|
||||
TensorPtr const& inputIds, TensorPtr const& commPtrs, GptModelConfig const& modelConfig,
|
||||
WorldConfig const& worldConfig) const
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
inputBuffers.clear();
|
||||
outputBuffers.clear();
|
||||
|
||||
@ -890,7 +880,7 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
|
||||
inputBuffers.insert_or_assign("host_request_types", requestTypes);
|
||||
inputBuffers.insert_or_assign("sequence_length", sequenceLengths);
|
||||
inputBuffers.insert_or_assign("host_sink_token_length", sinkTokenLengths);
|
||||
utils::insertTensorVector(inputBuffers, "host_max_attention_window_size_", maxAttentionWindows, firstLayerId);
|
||||
inputBuffers.insert_or_assign("host_max_attention_window_sizes", maxAttentionWindows);
|
||||
|
||||
if (modelConfig.usePackedInput())
|
||||
{
|
||||
@ -898,10 +888,8 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
|
||||
}
|
||||
if (modelConfig.usePagedKvCache())
|
||||
{
|
||||
utils::insertTensorSlices(
|
||||
inputBuffers, "kv_cache_block_pointers_", kvCacheBlockPointersDevice, firstLayerId);
|
||||
utils::insertTensorSlices(
|
||||
inputBuffers, "host_kv_cache_block_pointers_", kvCacheBlockPointersHost, firstLayerId);
|
||||
inputBuffers.insert_or_assign("kv_cache_block_pointers", kvCacheBlockPointersDevice);
|
||||
inputBuffers.insert_or_assign("host_kv_cache_block_pointers", kvCacheBlockPointersHost);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -946,7 +934,10 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
|
||||
inputBuffers.insert_or_assign("tasks", promptTuningParams.tasks);
|
||||
inputBuffers.insert_or_assign("prompt_vocab_size", promptTuningParams.vocabSize);
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
|
||||
// utils::printTensorMap(std::cerr, inputBuffers);
|
||||
// utils::printTensorMap(std::cerr, outputBuffers);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
std::vector<SizeType> RuntimeBuffers::getPositionIdsContextPhaseGlm(const SizeType& batchSize,
|
||||
|
||||
@ -99,11 +99,11 @@ public:
|
||||
// still point to `allGenerationLogits` and bring overwrite conflict.
|
||||
|
||||
std::vector<TensorPtr> presentKeysVals;
|
||||
std::vector<TensorPtr> presentKeysValsAlt; // without attention plugin
|
||||
std::vector<TensorPtr> maxAttentionWindows; // with attention plugin, host tensor
|
||||
TensorPtr sinkTokenLengths; // with attention plugin, host tensor
|
||||
TensorPtr kvCacheBlockPointersHost; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
|
||||
TensorPtr kvCacheBlockPointersDevice; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
|
||||
std::vector<TensorPtr> presentKeysValsAlt; // without attention plugin
|
||||
TensorPtr maxAttentionWindows; // with attention plugin, host tensor
|
||||
TensorPtr sinkTokenLengths; // with attention plugin, host tensor
|
||||
TensorPtr kvCacheBlockPointersHost; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
|
||||
TensorPtr kvCacheBlockPointersDevice; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
|
||||
|
||||
// References to tmp buffers
|
||||
TensorPtr newTokens;
|
||||
@ -147,7 +147,8 @@ public:
|
||||
SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, BufferManager& manager);
|
||||
|
||||
//! \brief Reshape buffers based on current GenerationConfig
|
||||
void reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
void reshape(
|
||||
KvCacheManager const* kvCacheManager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
|
||||
void reset(BufferManager& manager);
|
||||
|
||||
|
||||
@ -417,6 +417,21 @@ void invokeInclusiveSum(IBuffer& output, IBuffer const& input, BufferManager con
cub::DeviceScan::InclusiveSum(tempStorageData, tempStorageBytes, inputData, outputData, size, stream.get());
}

void invokeInclusiveSum(IBuffer& output, IBuffer& tmpBuffer, IBuffer const& input, CudaStream const& stream)
{
TLLM_CHECK_WITH_INFO(nvinfer1::DataType::kUINT8 == tmpBuffer.getDataType(), "tmpBuffer has wrong data type");

auto const size = input.getSize();
auto const* inputData = bufferCast<SizeType>(input);
auto* outputData = bufferCast<SizeType>(output);

std::size_t tempStorageBytes{0};
cub::DeviceScan::InclusiveSum(nullptr, tempStorageBytes, inputData, outputData, size, stream.get());
tmpBuffer.resize(tempStorageBytes);
auto* tmpBufferPtr = bufferCast<std::uint8_t>(tmpBuffer);
cub::DeviceScan::InclusiveSum(tmpBufferPtr, tempStorageBytes, inputData, outputData, size, stream.get());
}

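The new invokeInclusiveSum overload follows the standard two-call cub protocol: a first call with a null temporary-storage pointer only reports the required workspace size, the workspace is (re)allocated, and a second call performs the scan. A condensed sketch of that protocol on raw device pointers, assuming the usual cub::DeviceScan::InclusiveSum signature; buffer ownership and error checking are left to the caller:

#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdint>

// Sketch of the two-pass pattern used above; dTemp/tempBytes form a reusable workspace.
void inclusiveSumTwoPass(
    int const* dIn, int* dOut, int numItems, void*& dTemp, std::size_t& tempBytes, cudaStream_t stream)
{
    std::size_t requiredBytes = 0;
    // Pass 1: null temp storage -> cub only computes the required workspace size.
    cub::DeviceScan::InclusiveSum(nullptr, requiredBytes, dIn, dOut, numItems, stream);
    if (requiredBytes > tempBytes)
    {
        cudaFree(dTemp); // freeing nullptr is a no-op
        cudaMalloc(&dTemp, requiredBytes);
        tempBytes = requiredBytes;
    }
    // Pass 2: run the scan with the workspace in place.
    cub::DeviceScan::InclusiveSum(dTemp, requiredBytes, dIn, dOut, numItems, stream);
}

Growing the caller-owned tmpBuffer instead of allocating inside the helper is what lets the runtime reuse one workspace across steps.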
||||
namespace
|
||||
{
|
||||
__global__ void buildTokenMask(SizeType* tokenMask, SizeType const* inputLengths, SizeType const batchSize,
|
||||
@ -752,7 +767,7 @@ void initOutputIds(ITensor& outputIds, ITensor const& inputIds, ITensor const& i
|
||||
ITensor const& inputOffsets, TokenIdType const padId, TokenIdType const endId, SizeType const maxInputLength,
|
||||
bool const inputPacked, CudaStream const& stream)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
kernels::invokeFill(outputIds, endId, stream);
|
||||
|
||||
if (inputPacked)
|
||||
@ -763,7 +778,7 @@ void initOutputIds(ITensor& outputIds, ITensor const& inputIds, ITensor const& i
|
||||
{
|
||||
kernels::invokeCopyInputToOutput(outputIds, inputIds, inputLengths, padId, stream);
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
namespace
|
||||
|
||||
@ -51,6 +51,8 @@ void invokeTransposeWithInputOffset(
|
||||
|
||||
void invokeInclusiveSum(IBuffer& output, IBuffer const& input, BufferManager const& manager, CudaStream const& stream);
|
||||
|
||||
void invokeInclusiveSum(IBuffer& output, IBuffer& tmpBuffer, IBuffer const& input, CudaStream const& stream);
|
||||
|
||||
void invokeBuildTokenMask(
|
||||
ITensor& tokenMask, ITensor const& inputLengths, SizeType maxInputLength, CudaStream const& stream);
|
||||
|
||||
|
||||
@ -16,13 +16,12 @@
|
||||
|
||||
#include "tensorrt_llm/runtime/statefulGptDecoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
#include "tensorrt_llm/kernels/decodingCommon.h"
|
||||
#include "tensorrt_llm/runtime/runtimeKernels.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
namespace tc = tensorrt_llm::common;
|
||||
namespace tk = tensorrt_llm::kernels;
|
||||
using namespace tensorrt_llm::runtime;
|
||||
@ -35,7 +34,7 @@ StatefulGptDecoder::StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabS
|
||||
, mStream{std::move(stream)}
|
||||
, mBufferManager{mStream}
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
|
||||
auto constexpr nvSizeType = TRTDataType<SizeType>::value;
|
||||
auto constexpr nvFloatType = TRTDataType<float>::value;
|
||||
@ -68,26 +67,26 @@ StatefulGptDecoder::StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabS
|
||||
|
||||
mFinishedSum = mBufferManager.pinned(ITensor::makeShape({1}), nvSizeType);
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void StatefulGptDecoder::setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth,
|
||||
SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep,
|
||||
bool fusedDecoder, nvinfer1::DataType dtype)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK(maxTokensPerStep == 1);
|
||||
mDecoder = IGptDecoder::create(
|
||||
mode, dtype, maxBatchSize, maxBeamWidth, mVocabSize, mVocabSizePadded, maxSequenceLength, mStream);
|
||||
|
||||
reshapeBuffers(maxBatchSize, maxBeamWidth, maxAttentionWindow, sinkTokenLength, maxSequenceLength);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void StatefulGptDecoder::reshapeBuffers(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow,
|
||||
SizeType sinkTokenLength, SizeType maxSequenceLength)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK(batchSize > 0);
|
||||
TLLM_CHECK(beamWidth > 0);
|
||||
TLLM_CHECK(maxSequenceLength > 0);
|
||||
@ -139,13 +138,13 @@ void StatefulGptDecoder::reshapeBuffers(SizeType batchSize, SizeType beamWidth,
|
||||
}
|
||||
|
||||
mNbSteps = 0;
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void StatefulGptDecoder::newBatch(
|
||||
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto& manager = mBufferManager;
|
||||
auto& stream = mStream;
|
||||
|
||||
@ -296,12 +295,12 @@ void StatefulGptDecoder::newBatch(
|
||||
|
||||
// remaining
|
||||
mNbSteps = 0;
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void StatefulGptDecoder::forwardAsync(decoder::Output& output, decoder::Input const& input)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto& logits = input.logits;
|
||||
auto const& logitsShape = logits->getShape();
|
||||
|
||||
@ -335,24 +334,24 @@ void StatefulGptDecoder::forwardAsync(decoder::Output& output, decoder::Input co
|
||||
|
||||
dInput.step += 1;
|
||||
mNbSteps += 1;
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void StatefulGptDecoder::forwardSync()
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
mDecodedEvent.synchronize();
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void StatefulGptDecoder::finalize() const
|
||||
{
|
||||
// TODO (rkobus) can we do this inplace?
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto& outputIds = mDecodingOutput->ids;
|
||||
auto finalOutputIds = mBufferManager.gpu(outputIds->getShape(), outputIds->getDataType());
|
||||
mDecoder->gatherTree(*finalOutputIds, *mDecodingOutput, *mDecodingInput, mBufferManager);
|
||||
mBufferManager.copy(*finalOutputIds, *outputIds);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -18,17 +18,12 @@
|
||||
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/cudaEvent.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/decodingMode.h"
|
||||
#include "tensorrt_llm/runtime/gptDecoder.h"
|
||||
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
@ -88,7 +83,7 @@ public:
|
||||
//! @returns [batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
|
||||
[[nodiscard]] TensorPtr getAllNewTokens() const override
|
||||
{
|
||||
TensorPtr newTokens = std::move(ITensor::view(mDecodingOutput->newTokensSteps));
|
||||
TensorPtr newTokens = ITensor::view(mDecodingOutput->newTokensSteps);
|
||||
newTokens->unsqueeze(0);
|
||||
return newTokens;
|
||||
}
|
||||
|
||||
@ -19,7 +19,6 @@
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
@ -30,7 +29,6 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <limits>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
|
||||
@ -14,9 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tllmRuntime.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/nvtxUtils.h"
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
#include "tllmLogger.h"
|
||||
|
||||
#include <limits>
|
||||
@ -24,8 +22,6 @@
|
||||
|
||||
using namespace tensorrt_llm::runtime;
|
||||
|
||||
namespace tc = tensorrt_llm::common;
|
||||
|
||||
namespace
|
||||
{
|
||||
using DimType = std::remove_reference_t<decltype(std::declval<nvinfer1::Dims>().d[0])>;
|
||||
@ -108,14 +104,12 @@ bool TllmRuntime::executeContext(SizeType contextIndex) const
|
||||
void TllmRuntime::setInputTensors(SizeType contextIndex, TensorMap const& tensorMap)
|
||||
{
|
||||
NVTX3_FUNC_RANGE();
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto& context = getContext(contextIndex);
|
||||
for (std::int32_t i = 0; i < mEngine->getNbIOTensors(); ++i)
|
||||
{
|
||||
auto const name = mEngine->getIOTensorName(i);
|
||||
if (mEngine->getTensorIOMode(name) == nvinfer1::TensorIOMode::kINPUT)
|
||||
{
|
||||
NVTX3_SCOPED_RANGE(input_tensor);
|
||||
auto pos = tensorMap.find(name);
|
||||
if (pos == tensorMap.end())
|
||||
{
|
||||
@ -197,7 +191,6 @@ void TllmRuntime::setOutputTensors(SizeType contextIndex, TensorMap& tensorMap)
|
||||
auto const name = mEngine->getIOTensorName(i);
|
||||
if (mEngine->getTensorIOMode(name) == nvinfer1::TensorIOMode::kOUTPUT)
|
||||
{
|
||||
NVTX3_SCOPED_RANGE(output_tensor);
|
||||
auto const dims = context.getTensorShape(name);
|
||||
auto const engineDtype = mEngine->getTensorDataType(name);
|
||||
auto pos = tensorMap.find(name);
|
||||
|
||||
@ -23,8 +23,6 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
class TorchView : virtual public ITensor
|
||||
|
||||
@ -115,6 +115,16 @@ void insertTensorSlices(
|
||||
}
|
||||
}
|
||||
|
||||
void printTensorMap(std::ostream& stream, StringPtrMap<ITensor> const& map)
|
||||
{
|
||||
for (auto const& [name, tensor] : map)
|
||||
{
|
||||
stream << "Tensor name: " << name << '\n';
|
||||
stream << "Shape" << tensor->getShape() << '\n';
|
||||
stream << *tensor << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
void setRawPointers(ITensor& pointers, ITensor::SharedPtr const& input, int32_t pointersSlot, int32_t inputSlot)
|
||||
{
|
||||
auto const pointersLength = static_cast<int32_t>(pointers.getSizeInBytes() / sizeof(void**));
|
||||
|
||||
@ -64,6 +64,8 @@ void insertTensorVector(StringPtrMap<ITensor>& map, std::string const& key, std:
|
||||
void insertTensorSlices(
|
||||
StringPtrMap<ITensor>& map, std::string const& key, ITensor::SharedPtr const& tensor, SizeType indexOffset);
|
||||
|
||||
void printTensorMap(std::ostream& stream, StringPtrMap<ITensor> const& map);
|
||||
|
||||
void setRawPointers(ITensor& pointers, ITensor::SharedPtr const& input, int32_t pointersSlot, int32_t inputSlot);
|
||||
|
||||
void setRawPointers(ITensor& pointers, ITensor::SharedPtr const& input);
|
||||
|
||||
@ -85,8 +85,9 @@ WorldConfig WorldConfig::mpi(SizeType gpusPerNode, std::optional<SizeType> tenso
|
||||
auto const mpiSize = comm.getSize();
|
||||
auto const mpiRank = comm.getRank();
|
||||
TLLM_LOG_INFO("MPI size: %d, rank: %d", mpiSize, mpiRank);
|
||||
auto pp = pipelineParallelism.value_or(1);
|
||||
auto tp = tensorParallelism.value_or(mpiSize / pp);
|
||||
auto const pp = pipelineParallelism.value_or(1);
|
||||
auto const tp = tensorParallelism.value_or(mpiSize / pp);
|
||||
TLLM_LOG_DEBUG("TP: %d, PP: %d", tp, pp);
|
||||
TLLM_CHECK(mpiSize == tp * pp);
|
||||
|
||||
return WorldConfig{tp, pp, mpiRank, gpusPerNode, deviceIds};
|
||||
|
||||
@ -103,11 +103,12 @@ void FtDynamicDecode<T>::setup(size_t batch_size, size_t beam_width, th::optiona
|
||||
th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
|
||||
th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
|
||||
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
|
||||
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> beam_search_diversity_rate_opt,
|
||||
th::optional<th::Tensor> random_seed_opt, th::optional<th::Tensor> top_p_decay_opt,
|
||||
th::optional<th::Tensor> top_p_min_opt, th::optional<th::Tensor> top_p_reset_ids_opt)
|
||||
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
|
||||
th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
|
||||
th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
|
||||
th::optional<th::Tensor> top_p_reset_ids_opt)
|
||||
{
|
||||
// unused: length_penalty_opt, beam_search_diversity_rate_opt
|
||||
// unused: length_penalty_opt, beam_search_diversity_rate_opt, early_stopping_opt
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
dynamic_decode_layer_->setStream(stream);
|
||||
@ -127,6 +128,7 @@ void FtDynamicDecode<T>::setup(size_t batch_size, size_t beam_width, th::optiona
|
||||
safeInsert(top_p_reset_ids_opt, setupParams.top_p_reset_ids);
|
||||
safeInsert(beam_search_diversity_rate_opt, setupParams.beam_search_diversity_rate);
|
||||
safeInsert(length_penalty_opt, setupParams.length_penalty);
|
||||
safeInsert(early_stopping_opt, setupParams.early_stopping);
|
||||
|
||||
dynamic_decode_layer_->setup(batch_size, beam_width, nullptr, setupParams);
|
||||
}
|
||||
@ -211,13 +213,13 @@ void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width,
|
||||
dynamic_decode_layer_->forward(outputParams, forwardParams);
|
||||
if (finished_sum_host)
|
||||
{
|
||||
TLLM_CUDA_CHECK(::cudaStreamSynchronize(dynamic_decode_layer_->getStream()));
|
||||
int32_t finished_sum = 0;
|
||||
for (int32_t bi = 0; bi < local_batch_size; ++bi)
|
||||
{
|
||||
finished_sum += finished_sum_host[bi];
|
||||
}
|
||||
auto const numToFinish = outputParams.finished->size();
|
||||
TLLM_CUDA_CHECK(::cudaStreamSynchronize(dynamic_decode_layer_->getStream()));
|
||||
auto should_stop_accessor = should_stop.accessor<bool, 1>();
|
||||
should_stop_accessor[0] = numToFinish == finished_sum;
|
||||
}
|
||||
@ -259,9 +261,10 @@ void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional
|
||||
th::optional<th::Tensor> runtime_top_p_opt, th::optional<th::Tensor> temperature_opt,
|
||||
th::optional<th::Tensor> repetition_penalty_opt, th::optional<th::Tensor> presence_penalty_opt,
|
||||
th::optional<th::Tensor> frequency_penalty_opt, th::optional<th::Tensor> min_length_opt,
|
||||
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> beam_search_diversity_rate_opt,
|
||||
th::optional<th::Tensor> random_seed_opt, th::optional<th::Tensor> top_p_decay_opt,
|
||||
th::optional<th::Tensor> top_p_min_opt, th::optional<th::Tensor> top_p_reset_ids_opt)
|
||||
th::optional<th::Tensor> length_penalty_opt, th::optional<th::Tensor> early_stopping_opt,
|
||||
th::optional<th::Tensor> beam_search_diversity_rate_opt, th::optional<th::Tensor> random_seed_opt,
|
||||
th::optional<th::Tensor> top_p_decay_opt, th::optional<th::Tensor> top_p_min_opt,
|
||||
th::optional<th::Tensor> top_p_reset_ids_opt)
|
||||
{
|
||||
// TODO: Revise DynamicDecodeLayer and make the decode arguments consistent.
|
||||
CHECK_OPTIONAL_CPU_INPUT(runtime_top_k_opt, torch::kInt32);
|
||||
@ -273,6 +276,7 @@ void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional
|
||||
CHECK_OPTIONAL_CPU_INPUT(frequency_penalty_opt, torch::kFloat);
|
||||
CHECK_OPTIONAL_CPU_INPUT(min_length_opt, torch::kInt32);
|
||||
CHECK_OPTIONAL_CPU_INPUT(length_penalty_opt, torch::kFloat);
|
||||
CHECK_OPTIONAL_CPU_INPUT(early_stopping_opt, torch::kInt32);
|
||||
CHECK_OPTIONAL_CPU_INPUT(beam_search_diversity_rate_opt, torch::kFloat);
|
||||
CHECK_OPTIONAL_CPU_INPUT(random_seed_opt, torch::kInt64);
|
||||
CHECK_OPTIONAL_INPUT(top_p_decay_opt, torch::kFloat);
|
||||
@ -281,8 +285,8 @@ void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional
|
||||
|
||||
dynamic_decode_->setup(static_cast<size_t>(batch_size), static_cast<size_t>(beam_width), runtime_top_k_opt,
|
||||
runtime_top_p_opt, temperature_opt, repetition_penalty_opt, presence_penalty_opt, frequency_penalty_opt,
|
||||
min_length_opt, length_penalty_opt, beam_search_diversity_rate_opt, random_seed_opt, top_p_decay_opt,
|
||||
top_p_min_opt, top_p_reset_ids_opt);
|
||||
min_length_opt, length_penalty_opt, early_stopping_opt, beam_search_diversity_rate_opt, random_seed_opt,
|
||||
top_p_decay_opt, top_p_min_opt, top_p_reset_ids_opt);
|
||||
}
|
||||
|
||||
th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max_input_length,
|