diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 692c4f4039..f6a3d1e420 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1080,13 +1080,22 @@ class PyTorchModelEngine(ModelEngine):
         return model
 
     def _init_max_seq_len(self):
+        inferred_max_seq_len = self.model.infer_max_seq_len()
         if self.max_seq_len is None:
-            inferred_max_seq_len = self.model.infer_max_seq_len()
             logger.info(
                 f"max_seq_len is not specified, using inferred value {inferred_max_seq_len}"
             )
             self.max_seq_len = inferred_max_seq_len
 
+        elif inferred_max_seq_len < self.max_seq_len:
+            # NOTE: py_executor_creator makes sure that the executor uses this
+            # smaller value as its max_seq_len too.
+            logger.warning(
+                f"Specified {self.max_seq_len=} is larger than what the model can support "
+                f"({inferred_max_seq_len}). Setting max_seq_len to {inferred_max_seq_len}. "
+            )
+            self.max_seq_len = inferred_max_seq_len
+
     def _init_max_num_tokens(self):
         # Modified from tensorrt_llm/_common.py check_max_num_tokens
         if self.max_num_tokens is None:
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 7c8c1ccc32..f8437607d6 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1757,6 +1757,16 @@ class PyExecutor:
                     # No space for draft tokens.
                     continue
 
+                # Stop drafting when we hit the max seqlen. We still need dummy draft
+                # tokens attached to the requests to make sure everything works properly
+                # with CUDA graph. These dummy tokens are already added by
+                # _prepare_draft_requests to make the KV cache/scheduler aware of the fact
+                # that we want to do spec decoding, so no need to do anything else here.
+                # This makes the perf for this case suboptimal, but that's OK - this is
+                # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation.
+                if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len:
+                    continue
+
                 num_draft_tokens = len(
                     request.py_last_draft_tokens
                 ) if request.py_last_draft_tokens is not None else 0
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index 2345d5c779..c3844ebd7f 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -231,7 +231,9 @@ def create_py_executor(executor_config: ExecutorConfig,
            pytorch_backend_config,
            batch_size=executor_config.max_batch_size,
            max_num_tokens=executor_config.max_num_tokens,
-           max_seq_len=model_engine.max_seq_len,
+           # Note: The draft model engine will infer its own max_seq_len.
+           # We'll stop drafting when we hit the max.
+           max_seq_len=executor_config.max_seq_len,
            mapping=mapping,
            attn_runtime_features=attn_runtime_features,
            dist=dist,
diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py
index 6c3527da7b..b51c785db2 100644
--- a/tests/unittest/_torch/speculative/test_eagle3.py
+++ b/tests/unittest/_torch/speculative/test_eagle3.py
@@ -46,11 +46,13 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str):
     llm_spec = LLM(
         model=target_model_dir,
         **pytorch_config,
+        # This max_seq_len is larger than the one specified
+        # in the llama 3 8B eagle's config. We want to make sure
+        # that the draft model won't go above its max in warmup
+        # in this test.
+        max_seq_len=8192,
         kv_cache_config=kv_cache_config,
-        speculative_config=spec_config,
-        # TODO: https://nvbugspro.nvidia.com/bug/5319281
-        max_num_tokens=2048,
-        max_seq_len=2048)
+        speculative_config=spec_config)
 
     sampling_params = SamplingParams(
         max_tokens=32,
@@ -88,9 +90,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str):
 
     llm_ref = LLM(model=target_model_dir,
                   **pytorch_config,
-                  kv_cache_config=kv_cache_config,
-                  max_num_tokens=2048,
-                  max_seq_len=2048)
+                  kv_cache_config=kv_cache_config)
 
     results_ref = llm_ref.generate(prompts, sampling_params)
     generated_text_ref = [result.outputs[0].text for result in results_ref]
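
For reviewers, here is a minimal, self-contained sketch of the two decisions this patch wires together: an engine clamps a user-specified max_seq_len to the model's inferred limit, and the drafting loop skips speculation once a request reaches the draft engine's (possibly smaller) limit while the dummy draft tokens stay attached for CUDA graph compatibility. The names `clamp_max_seq_len`, `DraftEngine`, and `should_draft` are hypothetical illustrations, not TensorRT-LLM APIs.

```python
from dataclasses import dataclass
from typing import Optional


def clamp_max_seq_len(requested: Optional[int], inferred: int) -> int:
    """Return the max_seq_len an engine should actually use.

    Mirrors _init_max_seq_len: never exceed what the model reports it can support.
    """
    if requested is None or requested > inferred:
        return inferred
    return requested


@dataclass
class DraftEngine:
    # Hypothetical stand-in for the draft model engine; only the field
    # relevant to this patch is modeled.
    max_seq_len: int


def should_draft(request_num_tokens: int, draft_engine: DraftEngine) -> bool:
    """Mirrors the new check in the drafting loop: once a request has grown to the
    draft engine's max_seq_len, stop proposing draft tokens (the dummy draft tokens
    added earlier stay attached so CUDA graph shapes remain valid)."""
    return request_num_tokens - 1 < draft_engine.max_seq_len


if __name__ == "__main__":
    # Target model allows 8192 tokens, but the draft model only supports 2048.
    draft = DraftEngine(max_seq_len=clamp_max_seq_len(requested=8192, inferred=2048))
    assert draft.max_seq_len == 2048
    assert should_draft(1500, draft)      # still below the draft limit: keep drafting
    assert not should_draft(2049, draft)  # at the draft limit: skip drafting this step
```

Assuming the executor keeps the target model's larger max_seq_len (the py_executor_creator change above), the draft engine can end up with a smaller clamped limit, which is exactly the case the new `continue` in py_executor.py guards against.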