[https://nvbugs/5815025][fix] Fix spec-dec mode flag and related cpp requirements (#10996)
Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
commit 2b4ef3a014
parent ebd859cf61
@@ -1271,14 +1271,6 @@ int AttentionOp::mlaGeneration(
             mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
             return 0;
         }
-        else if (mIsSpecDecodingEnabled && mUseSpecDecoding)
-        {
-            TLLM_CHECK_WITH_INFO(false, "No available XQA kernels are found for speculative decoding mode.");
-        }
-        else if (mFuseFp4Quant)
-        {
-            TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 output.");
-        }
     }
 
     // Use FMHA otherwise.
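With the two failing branches removed, an MLA generation step whose XQA dispatch finds no matching kernel no longer aborts for spec-dec or FP4-fused configs; it falls through to the FMHA path noted in the trailing comment. A minimal Python-style sketch of the resulting dispatch order, where dispatch_mla_generation, xqa, and run_fmha are hypothetical stand-ins for the C++ members, not the real API:

# Illustrative sketch only: `xqa` and `run_fmha` are hypothetical stand-ins
# for AttentionOp's XQA dispatcher and FMHA fallback, not the real C++ API.
def dispatch_mla_generation(params, xqa, run_fmha) -> int:
    if xqa.should_use(params):
        # An XQA kernel matches: run it and return early.
        xqa.run(params)
        return 0
    # Before this commit, spec-dec and FP4-fused configs raised here; now
    # every config without a matching XQA kernel falls back to FMHA.
    return run_fmha(params)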
@@ -3445,13 +3445,9 @@ class PyTorchModelEngine(ModelEngine):
                 no_cache=kv_cache_manager
                 is None)
             # attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors
-            enable_mla = is_mla(self.model.model_config.pretrained_config)
             is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(
-                spec_resource_manager,
-                self.is_draft_model,
-                self.attn_backend,
-                self.model_is_wrapped,
-                is_mla=enable_mla)
+                spec_resource_manager, self.is_draft_model, self.attn_backend,
+                self.model_is_wrapped)
             attn_metadata.update_spec_dec_param(
                 batch_size=scheduled_requests.batch_size,
                 is_spec_decoding_enabled=is_spec_dec_mode,
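The call site drops `enable_mla` and the `is_mla` keyword because the MLA/SM-version gating no longer lives in the mode check (see the next hunk); the boolean it returns is forwarded as `is_spec_decoding_enabled`. A toy, self-contained sketch of how that flag gates the metadata update; FakeAttnMetadata is an illustration, not the TRT-LLM class:

from dataclasses import dataclass

@dataclass
class FakeAttnMetadata:
    # Toy stand-in for the attention metadata object (illustration only).
    batch_size: int = 0
    is_spec_decoding_enabled: bool = False

    def update_spec_dec_param(self, batch_size: int,
                              is_spec_decoding_enabled: bool) -> None:
        # When the flag is False the backend stays in single-token query
        # mode and spec-dec parameter tensors are not populated.
        self.batch_size = batch_size
        self.is_spec_decoding_enabled = is_spec_decoding_enabled

meta = FakeAttnMetadata()
meta.update_spec_dec_param(batch_size=8, is_spec_decoding_enabled=True)
assert meta.is_spec_decoding_enabled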
@@ -164,12 +164,11 @@ class SpeculativeDecodingMode(IntEnum):
             TrtllmAttention) or not xqa_supported
 
     def attention_need_spec_dec_mode(
-        self,
-        spec_resource_manager: Optional[BaseResourceManager],
-        is_draft_model: bool,
-        attention_backend: Type[AttentionBackend],
-        use_chain_drafter: bool,  # CDL
-        is_mla: bool,
+        self,
+        spec_resource_manager: Optional[BaseResourceManager],
+        is_draft_model: bool,
+        attention_backend: Type[AttentionBackend],
+        use_chain_drafter: bool,  # CDL
     ):
         """
         If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
@@ -182,8 +181,7 @@ class SpeculativeDecodingMode(IntEnum):
         is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)
 
         # Always use the multi-token query mode for 1-model if the kernels are available.
-        xqa_supported = not is_mla or get_sm_version() < 120
-        use_case_1 = self.use_one_engine() and xqa_supported
+        use_case_1 = self.use_one_engine()
         # For 2-model, we need to enable it when we process multiple tokens at once. This occurs with
         # the target model (verification) or on the first draft for CDL based speculation.
         use_case_2 = not self.use_one_engine() and (
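After dropping `is_mla`, the one-engine case requests multi-token query mode unconditionally, while the two-model case still depends on whether multiple tokens are processed at once. The condition after `and (` is truncated in the hunk, so the continuation below is an assumption based on the surrounding comment (verification on the target model, or the first draft step under CDL); the enum is a toy stand-in for SpeculativeDecodingMode:

from enum import IntEnum

class SpecMode(IntEnum):
    # Toy stand-in for SpeculativeDecodingMode (illustration only).
    EAGLE3_ONE_MODEL = 0
    EAGLE3_TWO_MODEL = 1

    def use_one_engine(self) -> bool:
        return self is SpecMode.EAGLE3_ONE_MODEL

    def attention_need_spec_dec_mode(self, is_draft_model: bool,
                                     use_chain_drafter: bool) -> bool:
        # 1-model: always multi-token query mode (no more is_mla / SM gating).
        use_case_1 = self.use_one_engine()
        # 2-model: assumed continuation of the truncated condition --
        # the target model (verification), or the first CDL draft.
        use_case_2 = not self.use_one_engine() and (
            not is_draft_model or use_chain_drafter)
        return use_case_1 or use_case_2

assert SpecMode.EAGLE3_ONE_MODEL.attention_need_spec_dec_mode(
    is_draft_model=False, use_chain_drafter=False)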
@@ -246,7 +246,6 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_vswa_reuse_4gpus[one_m
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_vswa_reuse_4gpus[two_model] SKIP (https://nvbugs/5756028)
 accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_prequantized SKIP (https://nvbugs/5785465)
 accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_fp8 SKIP (https://nvbugs/5785485)
-accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=True-torch_compile=False] SKIP (https://nvbugs/5787892)
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=False-torch_compile=False] SKIP (https://nvbugs/5787892)
 accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_chunked_prefill[tp8ep8-cuda_graph=False] SKIP (https://nvbugs/5795918)
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True] SKIP (https://nvbugs/5800591)