[https://nvbugs/5814914][fix] Fix llama sm120 spec dec (#10765)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Authored by Mike Iovine on 2026-01-23 10:50:17 -05:00; committed by Yanchao Lu
parent fa5c3ead05
commit d9aef94431
2 changed files with 13 additions and 8 deletions

@@ -3380,9 +3380,13 @@ class PyTorchModelEngine(ModelEngine):
             no_cache=kv_cache_manager
             is None)
         # attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors
+        enable_mla = is_mla(self.model.model_config.pretrained_config)
         is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(
-            spec_resource_manager, self.is_draft_model, self.attn_backend,
-            self.model_is_wrapped)
+            spec_resource_manager,
+            self.is_draft_model,
+            self.attn_backend,
+            self.model_is_wrapped,
+            is_mla=enable_mla)
         attn_metadata.update_spec_dec_param(
             batch_size=scheduled_requests.batch_size,
             is_spec_decoding_enabled=is_spec_dec_mode,
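
The new `enable_mla` flag is computed once from the pretrained config and passed through as a keyword argument, so the spec-dec mode decision can distinguish MLA models (e.g. DeepSeek) from standard-attention models (e.g. Llama). Below is a minimal sketch of what an MLA check like `is_mla` could look like; the helper name comes from the diff, but this body is an assumption, since MLA configs are typically identified by latent-attention fields such as `kv_lora_rank`:

```python
def is_mla(pretrained_config) -> bool:
    # Assumption: MLA (multi-head latent attention) configs carry a
    # kv_lora_rank field; standard-attention configs (e.g. Llama) do not.
    return getattr(pretrained_config, "kv_lora_rank", None) is not None
```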

@@ -151,11 +151,12 @@ class SpeculativeDecodingMode(IntEnum):
         TrtllmAttention) or not xqa_supported

     def attention_need_spec_dec_mode(
-        self,
-        spec_resource_manager: Optional[BaseResourceManager],
-        is_draft_model: bool,
-        attention_backend: Type[AttentionBackend],
-        use_chain_drafter: bool,  # CDL
+        self,
+        spec_resource_manager: Optional[BaseResourceManager],
+        is_draft_model: bool,
+        attention_backend: Type[AttentionBackend],
+        use_chain_drafter: bool,  # CDL
+        is_mla: bool,
     ):
         """
         If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
@@ -168,7 +169,7 @@ class SpeculativeDecodingMode(IntEnum):
         is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)
         # Always use the multi-token query mode for 1-model if the kernels are available.
-        xqa_supported = get_sm_version() < 120
+        xqa_supported = not is_mla or get_sm_version() < 120
         use_case_1 = self.use_one_engine() and xqa_supported
         # For 2-model, we need to enable it when we process multiple tokens at once. This occurs with
         # the target model (verification) or on the first draft for CDL based speculation.
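
The one-line predicate change is the heart of the fix: before, `xqa_supported` was unconditionally false on SM 120, which disabled the multi-token query path for every model, including Llama. Now the SM 120 restriction applies only to MLA models. A standalone restatement of the new gating logic with the cases spelled out — the function name and explicit `sm_version` parameter are illustrative; the real code reads `get_sm_version()` and the `is_mla` argument shown above:

```python
def xqa_spec_dec_supported(is_mla: bool, sm_version: int) -> bool:
    # MLA models lack XQA spec-dec kernels on SM >= 120;
    # non-MLA models support them on every SM version.
    return not is_mla or sm_version < 120

assert xqa_spec_dec_supported(is_mla=False, sm_version=120)      # Llama on SM120: now enabled (the fix)
assert not xqa_spec_dec_supported(is_mla=True, sm_version=120)   # MLA on SM120: still disabled
assert xqa_spec_dec_supported(is_mla=True, sm_version=90)        # MLA on Hopper: enabled as before
```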