Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-04 18:21:52 +08:00)
[https://nvbugs/5814914][fix] Fix llama sm120 spec dec (#10765)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
parent fa5c3ead05
commit d9aef94431
@@ -3380,9 +3380,13 @@ class PyTorchModelEngine(ModelEngine):
             no_cache=kv_cache_manager
             is None)
         # attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors
+        enable_mla = is_mla(self.model.model_config.pretrained_config)
         is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(
-            spec_resource_manager, self.is_draft_model, self.attn_backend,
-            self.model_is_wrapped)
+            spec_resource_manager,
+            self.is_draft_model,
+            self.attn_backend,
+            self.model_is_wrapped,
+            is_mla=enable_mla)
         attn_metadata.update_spec_dec_param(
             batch_size=scheduled_requests.batch_size,
             is_spec_decoding_enabled=is_spec_dec_mode,
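The new `enable_mla` flag is computed from the pretrained config before the spec-dec mode check. Below is only a hedged sketch of what an `is_mla`-style check could look like; the `kv_lora_rank` attribute is an assumption based on how MLA models such as DeepSeek-V3 are typically configured, not something shown in this commit.

```python
# Hypothetical sketch only; not the repository's actual is_mla helper.
# Assumption: an MLA-style pretrained config exposes a non-None kv_lora_rank.
def is_mla_sketch(pretrained_config) -> bool:
    """Heuristically detect a Multi-head Latent Attention (MLA) config."""
    return getattr(pretrained_config, "kv_lora_rank", None) is not None
```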
@@ -151,11 +151,12 @@ class SpeculativeDecodingMode(IntEnum):
                 TrtllmAttention) or not xqa_supported

     def attention_need_spec_dec_mode(
         self,
         spec_resource_manager: Optional[BaseResourceManager],
         is_draft_model: bool,
         attention_backend: Type[AttentionBackend],
         use_chain_drafter: bool,  # CDL
+        is_mla: bool,
     ):
         """
         If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
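With the added `is_mla` parameter, callers now have to pass the flag explicitly. A rough usage sketch of the updated signature; the enum member and the argument values are placeholders for illustration, not taken from this commit.

```python
# Illustrative call against the new signature; values are placeholders.
needs_multi_token_query = SpeculativeDecodingMode.EAGLE3.attention_need_spec_dec_mode(
    spec_resource_manager=None,
    is_draft_model=False,
    attention_backend=TrtllmAttention,
    use_chain_drafter=False,
    is_mla=False,  # non-MLA target model, e.g. a Llama checkpoint
)
```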
@@ -168,7 +169,7 @@ class SpeculativeDecodingMode(IntEnum):
         is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)

         # Always use the multi-token query mode for 1-model if the kernels are available.
-        xqa_supported = get_sm_version() < 120
+        xqa_supported = not is_mla or get_sm_version() < 120
         use_case_1 = self.use_one_engine() and xqa_supported
         # For 2-model, we need to enable it when we process multiple tokens at once. This occurs with
         # the target model (verification) or on the first draft for CDL based speculation.
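The net effect of the one-line change: SM 120 (the Blackwell RTX compute capability) now only rules out the multi-token query (XQA) path for MLA models, so non-MLA models such as Llama regain spec-dec mode on those GPUs. A minimal standalone sketch of the gate, with illustrative assertions; the helper name is not from the repository.

```python
# Standalone sketch of the new gate; mirrors the expression in the diff.
def xqa_spec_dec_supported(is_mla: bool, sm_version: int) -> bool:
    # Old behavior: supported only when sm_version < 120, regardless of attention type.
    # New behavior: SM >= 120 only rules out MLA models.
    return not is_mla or sm_version < 120

assert xqa_spec_dec_supported(is_mla=False, sm_version=120)      # Llama-style model on SM120: enabled again
assert not xqa_spec_dec_supported(is_mla=True, sm_version=120)   # MLA model on SM120: still disabled
assert xqa_spec_dec_supported(is_mla=True, sm_version=90)        # MLA model on Hopper (SM90): enabled
```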