diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py
index db2a5621db..998f4b28cb 100644
--- a/tensorrt_llm/_torch/speculative/interface.py
+++ b/tensorrt_llm/_torch/speculative/interface.py
@@ -110,13 +110,18 @@ class SpeculativeDecodingMode(IntEnum):
     def extend_ctx(self, attention_backend: Type[AttentionBackend]):
         """
         If true, treat generation requests with draft tokens as
-        chunked context requests at the kernel level. Required for
-        any spec dec mode that uses the SpecExecutor.
+        chunked context requests at the kernel level.
         """
         if self.use_one_engine():
             # 1-model has separate logic for handling draft tokens
             return False
+
+        if issubclass(attention_backend,
+                      TrtllmAttention) and self.is_mtp_eagle():
+            # TRTLLM MLA does not work with the chunked context mode.
+            return False
+
         return not issubclass(attention_backend,
                               TrtllmAttention) or get_sm_version() != 100
diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py
index 0a1a58d857..78c0938a99 100644
--- a/tensorrt_llm/_torch/speculative/model_drafter.py
+++ b/tensorrt_llm/_torch/speculative/model_drafter.py
@@ -165,6 +165,9 @@ class ModelDrafter(Drafter):
         input_tokens = get_draft_model_prompt(self.spec_config.spec_dec_mode,
                                               request.get_tokens(0))
 
+        is_eagle_style = self.spec_config.spec_dec_mode.is_eagle3(
+        ) or self.spec_config.spec_dec_mode.is_mtp_eagle()
+
         # First time seeing this request - context request
         if request.max_beam_num_tokens - 1 == request.py_prompt_len:
             # This is the first time the draft model is seeing this request.
@@ -174,10 +177,8 @@ class ModelDrafter(Drafter):
             return self._create_context_request(request, input_tokens)
 
         # For TRTLLM attention backend, we need to create a generation request for both no tokens accepted and tokens accepted
-        elif issubclass(
-                self.draft_model_engine.attn_backend, TrtllmAttention
-        ) and self.use_static_draft_loop and self.spec_config.spec_dec_mode.is_eagle3(
-        ):
+        elif issubclass(self.draft_model_engine.attn_backend, TrtllmAttention
+                        ) and self.use_static_draft_loop and is_eagle_style:
             return self._create_accepted_tokens_request_for_trtllm_attn(
                 request, input_tokens, num_accepted_tokens)
diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py
index 0cfe1449ac..d2d45b0c85 100644
--- a/tests/integration/defs/test_e2e.py
+++ b/tests/integration/defs/test_e2e.py
@@ -1953,18 +1953,17 @@ def test_ptp_quickstart_advanced_mtp_eagle(llm_root, llm_venv, model_name,
                                      dir="./",
                                      delete=True,
                                      delete_on_close=True) as running_log:
-        llm_venv.run_cmd(
-            [
-                str(example_root / "quickstart_advanced.py"),
-                "--use_cuda_graph",
-                "--spec_decode_max_draft_len",
-                "1",  # test 1 MTP module
-                "--spec_decode_algo",
-                "MTP",
-                "--model_dir",
-                f"{llm_models_root()}/{model_path}",
-            ],
-            stdout=running_log)
+        llm_venv.run_cmd([
+            str(example_root / "quickstart_advanced.py"),
+            "--use_cuda_graph",
+            "--spec_decode_max_draft_len",
+            "3",
+            "--spec_decode_algo",
+            "MTP",
+            "--model_dir",
+            f"{llm_models_root()}/{model_path}",
+        ],
+                         stdout=running_log)
         # 74.60 is the memory usage for DeepSeek-V3-Lite-BF16 with MTP Eagle 2 two model style as one extra kv cache is needed for draft model.
         _check_mem_usage(running_log, [74.60, 0, 0, 0])
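
Note (not part of the patch): the core behavioral change is the extra early-return in extend_ctx. The sketch below restates that decision table as a standalone function with plain boolean/int inputs, so it can be exercised without TensorRT-LLM installed; the function name and parameters are illustrative only and are not part of the library API.

# Minimal standalone sketch of the extend_ctx decision after this patch.
# Inputs are hypothetical stand-ins for self.use_one_engine(),
# self.is_mtp_eagle(), issubclass(attention_backend, TrtllmAttention),
# and get_sm_version().
def extend_ctx_decision(uses_one_engine: bool, is_mtp_eagle: bool,
                        is_trtllm_attention: bool, sm_version: int) -> bool:
    """Return True if generation requests with draft tokens should be
    treated as chunked context requests at the kernel level."""
    if uses_one_engine:
        # 1-model flow has its own handling of draft tokens.
        return False
    if is_trtllm_attention and is_mtp_eagle:
        # New in this patch: TRTLLM MLA does not work with chunked context.
        return False
    # Otherwise chunked context is used, except for TRTLLM attention on SM 100.
    return not is_trtllm_attention or sm_version != 100


if __name__ == "__main__":
    # MTP-Eagle two-model with TRTLLM attention now opts out of chunked context.
    assert extend_ctx_decision(False, True, True, 90) is False
    # Non-TRTLLM backends keep using chunked context.
    assert extend_ctx_decision(False, False, False, 90) is True
    # TRTLLM attention on SM 100 already opted out before this patch.
    assert extend_ctx_decision(False, False, True, 100) is False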