diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 110b05b4f6..9a2bdcf738 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -3380,9 +3380,13 @@ class PyTorchModelEngine(ModelEngine):
             no_cache=kv_cache_manager is None)
 
         # attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors
+        enable_mla = is_mla(self.model.model_config.pretrained_config)
         is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(
-            spec_resource_manager, self.is_draft_model, self.attn_backend,
-            self.model_is_wrapped)
+            spec_resource_manager,
+            self.is_draft_model,
+            self.attn_backend,
+            self.model_is_wrapped,
+            is_mla=enable_mla)
         attn_metadata.update_spec_dec_param(
             batch_size=scheduled_requests.batch_size,
             is_spec_decoding_enabled=is_spec_dec_mode,
diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py
index 6bc0846650..a155bac5a0 100644
--- a/tensorrt_llm/_torch/speculative/interface.py
+++ b/tensorrt_llm/_torch/speculative/interface.py
@@ -151,11 +151,12 @@ class SpeculativeDecodingMode(IntEnum):
                               TrtllmAttention) or not xqa_supported
 
     def attention_need_spec_dec_mode(
-        self,
-        spec_resource_manager: Optional[BaseResourceManager],
-        is_draft_model: bool,
-        attention_backend: Type[AttentionBackend],
-        use_chain_drafter: bool,  # CDL
+            self,
+            spec_resource_manager: Optional[BaseResourceManager],
+            is_draft_model: bool,
+            attention_backend: Type[AttentionBackend],
+            use_chain_drafter: bool,  # CDL
+            is_mla: bool,
     ):
         """
         If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
@@ -168,7 +169,7 @@ class SpeculativeDecodingMode(IntEnum):
         is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)
 
         # Always use the multi-token query mode for 1-model if the kernels are available.
-        xqa_supported = get_sm_version() < 120
+        xqa_supported = not is_mla or get_sm_version() < 120
         use_case_1 = self.use_one_engine() and xqa_supported
         # For 2-model, we need to enable it when we process multiple tokens at once. This occurs with
         # the target model (verification) or on the first draft for CDL based speculation.
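
Illustrative sketch (not part of the patch; the helper name below is a placeholder): the reworked xqa_supported condition keeps the SM < 120 requirement only for MLA models, while non-MLA models are treated as always having the multi-token query kernels available for the 1-model path.

    def spec_dec_kernels_available(is_mla: bool, sm_version: int) -> bool:
        # Mirrors the new condition in attention_need_spec_dec_mode:
        # MLA keeps the SM < 120 requirement; non-MLA is always eligible.
        return not is_mla or sm_version < 120

    # Example: a non-MLA model on SM 120 is now eligible, an MLA model is not.
    assert spec_dec_kernels_available(is_mla=False, sm_version=120)
    assert not spec_dec_kernels_available(is_mla=True, sm_version=120)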