Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[None][fix] Remove overlap scheduler adjustment for max sequence length in create_py_executor function (#9229)
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
commit 7a103035be
parent c47ff4da43
@@ -273,10 +273,8 @@ class KvCacheCreator:
         num_cache_blocks = 0
         num_extra_tokens_per_seq = 1  # account for generated tokens
         spec_cfg = self._speculative_config
-        if not self._llm_args.disable_overlap_scheduler:
-            num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
-            if spec_cfg is not None:
-                num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
+        if not self._llm_args.disable_overlap_scheduler and spec_cfg is not None:
+            num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
 
         if spec_cfg is not None:
             num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
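The net effect of this hunk is that the overlap scheduler no longer reserves an extra token slot per sequence. A minimal sketch of the post-commit budget calculation, using made-up numbers and a free-standing helper that is not part of the patch:

    def extra_tokens_per_seq(disable_overlap_scheduler, max_total_draft_tokens=None):
        # Mirrors the post-commit KvCacheCreator logic; values are illustrative.
        extra = 1  # account for generated tokens
        if not disable_overlap_scheduler and max_total_draft_tokens is not None:
            extra += max_total_draft_tokens
        if max_total_draft_tokens is not None:
            extra += max_total_draft_tokens
        return extra

    # With 3 draft tokens and the overlap scheduler enabled:
    # before this commit the budget was 1 + 1 + 3 + 3 = 8; now it is 1 + 3 + 3 = 7.
    print(extra_tokens_per_seq(False, 3))  # 7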
@@ -457,8 +457,7 @@ class PyTorchModelEngine(ModelEngine):
         # This way it can also be used for CUDA graphs.
         if self.use_beam_search:
             self.cache_indirection_attention = torch.zeros(
-                (self.batch_size, self.max_beam_width, self.max_seq_len +
-                 (0 if self._disable_overlap_scheduler else 1)),
+                (self.batch_size, self.max_beam_width, self.max_seq_len),
                 device="cuda",
                 dtype=torch.int32)
         else:
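For context, a sketch of the cache indirection allocation after this hunk, with placeholder sizes (in PyTorchModelEngine these come from the engine configuration):

    import torch

    # Illustrative sizes only.
    batch_size, max_beam_width, max_seq_len = 8, 4, 2048
    device = "cuda" if torch.cuda.is_available() else "cpu"  # the real code always uses "cuda"

    # The buffer no longer reserves an extra position when the overlap scheduler is enabled.
    cache_indirection_attention = torch.zeros(
        (batch_size, max_beam_width, max_seq_len),
        device=device,
        dtype=torch.int32)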
@@ -436,10 +436,8 @@ def create_py_executor(
     # PyTorchModelEngine modifies these fields, update them
     model_engine_max_seq_len = model_engine.max_seq_len
     net_max_seq_len = model_engine_max_seq_len
-    if not llm_args.disable_overlap_scheduler:
-        model_engine_max_seq_len = model_engine.max_seq_len + 1
-        if spec_config is not None:
-            model_engine_max_seq_len += spec_config.max_total_draft_tokens
+    if not llm_args.disable_overlap_scheduler and spec_config is not None:
+        model_engine_max_seq_len += spec_config.max_total_draft_tokens
 
     if spec_config is not None:
         model_engine_max_seq_len += get_num_extra_kv_tokens(spec_config)
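A worked example of the create_py_executor change under assumed values (max_seq_len=4096, three draft tokens, get_num_extra_kv_tokens(spec_config) treated as 0, overlap scheduler enabled; all numbers are illustrative):

    max_seq_len = 4096
    max_total_draft_tokens = 3
    num_extra_kv_tokens = 0  # stand-in for get_num_extra_kv_tokens(spec_config)

    # Before: +1 for the overlap scheduler, plus draft tokens, plus extra KV tokens.
    old_model_engine_max_seq_len = max_seq_len + 1 + max_total_draft_tokens + num_extra_kv_tokens  # 4100
    # After: the +1 is gone; only draft tokens and extra KV tokens remain.
    new_model_engine_max_seq_len = max_seq_len + max_total_draft_tokens + num_extra_kv_tokens      # 4099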
@@ -1047,7 +1047,7 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         self.CACHE_INDIRECTION_SHAPE = (
             self.max_num_sequences,
             self.max_beam_width,
-            self.max_seq_len + (0 if args.disable_overlap_scheduler else 1),
+            self.max_seq_len,
         )
         self.LOGPROBS_SHAPE = (self.max_num_sequences, self.max_beam_width, self.max_tokens)
         self.TOPK_LOGPROBS_SHAPE = (self.max_num_sequences, self.max_tokens, self.max_topk_logprobs)
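For reference, a sketch of the sampler buffer shapes after this hunk, with made-up sizes (TorchSampler derives the real values from its constructor arguments):

    max_num_sequences, max_beam_width, max_seq_len = 16, 4, 2048
    max_tokens, max_topk_logprobs = 64, 8

    # The cache indirection shape no longer grows by one when the overlap scheduler is on.
    CACHE_INDIRECTION_SHAPE = (max_num_sequences, max_beam_width, max_seq_len)
    LOGPROBS_SHAPE = (max_num_sequences, max_beam_width, max_tokens)
    TOPK_LOGPROBS_SHAPE = (max_num_sequences, max_tokens, max_topk_logprobs)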