[None][fix] Remove overlap scheduler adjustment for max sequence length in create_py_executor function (#9229)

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Authored by Robin Kobus on 2026-02-11 17:46:25 +01:00, committed by GitHub
parent c47ff4da43
commit 7a103035be
4 changed files with 6 additions and 11 deletions

@@ -273,10 +273,8 @@ class KvCacheCreator:
         num_cache_blocks = 0
         num_extra_tokens_per_seq = 1  # account for generated tokens
         spec_cfg = self._speculative_config
-        if not self._llm_args.disable_overlap_scheduler:
-            num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
-            if spec_cfg is not None:
-                num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
+        if not self._llm_args.disable_overlap_scheduler and spec_cfg is not None:
+            num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
         if spec_cfg is not None:
             num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
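
As a reading aid, here is a minimal standalone sketch of the per-sequence extra-token budget as it looks after this change. Only disable_overlap_scheduler and max_total_draft_tokens come from the diff; the function itself and the example values are illustrative, not TensorRT-LLM code.

from typing import Optional

# Per-sequence extra-token budget after this commit: the flat "+1" the overlap
# scheduler used to add is gone; only draft-token head-room remains.
def extra_tokens_per_seq(disable_overlap_scheduler: bool,
                         max_total_draft_tokens: Optional[int]) -> int:
    extra = 1  # account for generated tokens
    if not disable_overlap_scheduler and max_total_draft_tokens is not None:
        extra += max_total_draft_tokens  # overlap scheduling with speculation
    if max_total_draft_tokens is not None:
        extra += max_total_draft_tokens  # speculative decoding head-room
    return extra

# Overlap scheduler enabled, 3 draft tokens: 1 + 3 + 3 = 7 extra tokens.
assert extra_tokens_per_seq(disable_overlap_scheduler=False, max_total_draft_tokens=3) == 7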

@@ -457,8 +457,7 @@ class PyTorchModelEngine(ModelEngine):
         # This way it can also be used for CUDA graphs.
         if self.use_beam_search:
             self.cache_indirection_attention = torch.zeros(
-                (self.batch_size, self.max_beam_width, self.max_seq_len +
-                 (0 if self._disable_overlap_scheduler else 1)),
+                (self.batch_size, self.max_beam_width, self.max_seq_len),
                 device="cuda",
                 dtype=torch.int32)
         else:
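
The effect on the buffer itself, as a hedged allocation sketch. The sizes are illustrative, and a CPU fallback is added so the snippet runs without a GPU; the engine allocates on CUDA as shown in the diff.

import torch

# The sequence dimension is now plain max_seq_len, with no overlap-scheduler "+1".
batch_size, max_beam_width, max_seq_len = 8, 4, 2048  # illustrative values
device = "cuda" if torch.cuda.is_available() else "cpu"
cache_indirection_attention = torch.zeros(
    (batch_size, max_beam_width, max_seq_len),
    device=device,
    dtype=torch.int32)
assert cache_indirection_attention.shape == (batch_size, max_beam_width, max_seq_len)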

@@ -436,10 +436,8 @@ def create_py_executor(
    # PyTorchModelEngine modifies these fields, update them
    model_engine_max_seq_len = model_engine.max_seq_len
    net_max_seq_len = model_engine_max_seq_len
-    if not llm_args.disable_overlap_scheduler:
-        model_engine_max_seq_len = model_engine.max_seq_len + 1
-        if spec_config is not None:
-            model_engine_max_seq_len += spec_config.max_total_draft_tokens
+    if not llm_args.disable_overlap_scheduler and spec_config is not None:
+        model_engine_max_seq_len += spec_config.max_total_draft_tokens
    if spec_config is not None:
        model_engine_max_seq_len += get_num_extra_kv_tokens(spec_config)
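
The same simplification in create_py_executor, as a standalone sketch. get_num_extra_kv_tokens is the helper referenced in the diff; here it is stood in for by a plain integer argument, and every other name is an illustrative assumption.

from typing import Optional

# Max-sequence-length bookkeeping after this commit: no flat "+1" for the overlap
# scheduler; draft tokens are added when overlap scheduling is on and a speculative
# config exists, and extra KV tokens are added whenever a speculative config exists.
def adjusted_max_seq_len(base_max_seq_len: int,
                         disable_overlap_scheduler: bool,
                         max_total_draft_tokens: Optional[int],
                         num_extra_kv_tokens: int) -> int:
    max_seq_len = base_max_seq_len
    has_spec_config = max_total_draft_tokens is not None
    if not disable_overlap_scheduler and has_spec_config:
        max_seq_len += max_total_draft_tokens
    if has_spec_config:
        max_seq_len += num_extra_kv_tokens  # stands in for get_num_extra_kv_tokens(spec_config)
    return max_seq_len

# Base 4096, overlap scheduler enabled, 3 draft tokens, 2 extra KV tokens: 4101.
assert adjusted_max_seq_len(4096, False, 3, 2) == 4101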

@@ -1047,7 +1047,7 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         self.CACHE_INDIRECTION_SHAPE = (
             self.max_num_sequences,
             self.max_beam_width,
-            self.max_seq_len + (0 if args.disable_overlap_scheduler else 1),
+            self.max_seq_len,
         )
         self.LOGPROBS_SHAPE = (self.max_num_sequences, self.max_beam_width, self.max_tokens)
         self.TOPK_LOGPROBS_SHAPE = (self.max_num_sequences, self.max_tokens, self.max_topk_logprobs)
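
Taken together, the model engine's cache-indirection buffer and the sampler's CACHE_INDIRECTION_SHAPE now both end in plain max_seq_len, so no overlap-scheduler offset has to be kept in sync between them. A hedged sketch with illustrative values:

# Both cache-indirection shapes end in max_seq_len with no conditional "+1",
# so they stay aligned by construction. Values are illustrative.
batch_size = max_num_sequences = 8
max_beam_width, max_seq_len = 4, 2048

engine_shape = (batch_size, max_beam_width, max_seq_len)          # PyTorchModelEngine buffer
sampler_shape = (max_num_sequences, max_beam_width, max_seq_len)  # TorchSampler.CACHE_INDIRECTION_SHAPE
assert engine_shape[-1] == sampler_shape[-1] == max_seq_len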