[nvbug/5319281][fix] Stop drafting when we hit the draft model's max seq len (#4879)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Author: Mike Iovine, 2025-06-13 11:06:36 -04:00 (committed by GitHub)
Parent: 3d87770e15
Commit: 25aa3881d7
4 changed files with 30 additions and 9 deletions


@@ -1080,13 +1080,22 @@ class PyTorchModelEngine(ModelEngine):
         return model
 
     def _init_max_seq_len(self):
+        inferred_max_seq_len = self.model.infer_max_seq_len()
         if self.max_seq_len is None:
-            inferred_max_seq_len = self.model.infer_max_seq_len()
             logger.info(
                 f"max_seq_len is not specified, using inferred value {inferred_max_seq_len}"
             )
             self.max_seq_len = inferred_max_seq_len
+        elif inferred_max_seq_len < self.max_seq_len:
+            # NOTE: py_executor_creator makes sure that the executor uses this
+            # smaller value as its max_seq_len too.
+            logger.warning(
+                f"Specified {self.max_seq_len=} is larger than what the model can support "
+                f"({inferred_max_seq_len}). Setting max_seq_len to {inferred_max_seq_len}. "
+            )
+            self.max_seq_len = inferred_max_seq_len
 
     def _init_max_num_tokens(self):
         # Modified from tensorrt_llm/_common.py check_max_num_tokens
         if self.max_num_tokens is None:
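
Outside the diff, the clamping behavior can be illustrated with a small standalone sketch. EngineStub and _FakeModel below are hypothetical stand-ins, not TensorRT-LLM classes; only the decision logic mirrors _init_max_seq_len above.

    import logging

    logger = logging.getLogger("model_engine_sketch")


    class _FakeModel:
        """Hypothetical model whose checkpoint supports at most 2048 positions."""

        def infer_max_seq_len(self) -> int:
            return 2048


    class EngineStub:
        """Minimal stand-in showing how an engine clamps a user-provided max_seq_len."""

        def __init__(self, model, max_seq_len=None):
            self.model = model
            self.max_seq_len = max_seq_len
            self._init_max_seq_len()

        def _init_max_seq_len(self):
            # Always infer the model's limit so an over-large user value can be clamped.
            inferred = self.model.infer_max_seq_len()
            if self.max_seq_len is None:
                logger.info("max_seq_len not specified, using inferred value %d", inferred)
                self.max_seq_len = inferred
            elif inferred < self.max_seq_len:
                logger.warning("Requested max_seq_len=%d exceeds model limit %d; clamping.",
                               self.max_seq_len, inferred)
                self.max_seq_len = inferred


    engine = EngineStub(_FakeModel(), max_seq_len=8192)
    assert engine.max_seq_len == 2048  # clamped to what the model can support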


@@ -1757,6 +1757,16 @@ class PyExecutor:
                 # No space for draft tokens.
                 continue
 
+            # Stop drafting when we hit the max seqlen. We still need dummy draft
+            # tokens attached to the requests to make sure everything works properly
+            # with CUDA graph. These dummy tokens are already added by
+            # _prepare_draft_requests to make the KV cache/scheduler aware of the fact
+            # that we want to do spec decoding, so no need to do anything else here.
+            # This makes the perf for this case suboptimal, but that's OK - this is
+            # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation.
+            if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len:
+                continue
+
             num_draft_tokens = len(
                 request.py_last_draft_tokens
             ) if request.py_last_draft_tokens is not None else 0
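
A hedged sketch of the new early exit, pulled out of the executor loop for clarity. DraftRequest is a hypothetical container; only max_beam_num_tokens and the >= comparison mirror the check above.

    from dataclasses import dataclass


    @dataclass
    class DraftRequest:
        """Hypothetical request record: total tokens across beams so far."""
        max_beam_num_tokens: int


    def requests_to_draft(requests, draft_max_seq_len):
        """Return only the requests the draft model can still extend.

        Requests at or beyond the draft model's max_seq_len keep their dummy draft
        tokens (for CUDA graph / KV cache bookkeeping) but are skipped here.
        """
        selected = []
        for request in requests:
            if request.max_beam_num_tokens - 1 >= draft_max_seq_len:
                continue  # sequence already at the draft model's limit
            selected.append(request)
        return selected


    reqs = [DraftRequest(max_beam_num_tokens=100), DraftRequest(max_beam_num_tokens=2049)]
    assert len(requests_to_draft(reqs, draft_max_seq_len=2048)) == 1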


@@ -231,7 +231,9 @@ def create_py_executor(executor_config: ExecutorConfig,
         pytorch_backend_config,
         batch_size=executor_config.max_batch_size,
         max_num_tokens=executor_config.max_num_tokens,
-        max_seq_len=model_engine.max_seq_len,
+        # Note: The draft model engine will infer its own max_seq_len.
+        # We'll stop drafting when we hit the max.
+        max_seq_len=executor_config.max_seq_len,
         mapping=mapping,
         attn_runtime_features=attn_runtime_features,
         dist=dist,
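
The intent of passing executor_config.max_seq_len here can be summarized with assumed numbers (8192 requested by the user, 2048 as the draft model's inferred limit); this is a sketch of the resulting behavior, not the create_py_executor code itself.

    # Assumed values for illustration only.
    requested_max_seq_len = 8192      # executor_config.max_seq_len
    draft_inferred_limit = 2048       # what the draft model reports it can support

    # The draft engine clamps itself (see _init_max_seq_len above), while the
    # executor keeps the user-facing limit for the target model.
    draft_engine_max_seq_len = min(requested_max_seq_len, draft_inferred_limit)

    assert draft_engine_max_seq_len == 2048   # drafting stops once a request reaches 2048 tokens
    assert requested_max_seq_len == 8192      # target generation can continue up to 8192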


@@ -46,11 +46,13 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str):
     llm_spec = LLM(
         model=target_model_dir,
         **pytorch_config,
+        # This max_seq_len is larger than the one specified
+        # in the llama 3 8B eagle's config. We want to make sure
+        # that the draft model won't go above its max in warmup
+        # in this test.
+        max_seq_len=8192,
         kv_cache_config=kv_cache_config,
-        speculative_config=spec_config,
-        # TODO: https://nvbugspro.nvidia.com/bug/5319281
-        max_num_tokens=2048,
-        max_seq_len=2048)
+        speculative_config=spec_config)
 
     sampling_params = SamplingParams(
         max_tokens=32,
@@ -88,9 +90,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str):
     llm_ref = LLM(model=target_model_dir,
                   **pytorch_config,
-                  kv_cache_config=kv_cache_config,
-                  max_num_tokens=2048,
-                  max_seq_len=2048)
+                  kv_cache_config=kv_cache_config)
     results_ref = llm_ref.generate(prompts, sampling_params)
     generated_text_ref = [result.outputs[0].text for result in results_ref]
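
Not shown in the hunk, but a test like this usually closes with a lossless comparison between the speculative and reference outputs. The sketch below is an assumption about how that check looks, not the literal assertions in test_llama_eagle3.

    def check_spec_decode_is_lossless(spec_texts, ref_texts):
        """With greedy sampling, speculative decoding must reproduce the reference text."""
        assert len(spec_texts) == len(ref_texts)
        for spec, ref in zip(spec_texts, ref_texts):
            assert spec == ref

    # Hypothetical usage with the lists built above:
    # check_spec_decode_is_lossless(generated_text_spec, generated_text_ref)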