[None][fix] Fix MTP illegal memory access (#8161)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Mike Iovine 2025-10-07 14:02:55 -04:00 committed by GitHub
parent ca9da1f1c2
commit 7facac077b
3 changed files with 23 additions and 18 deletions

View File

@@ -110,13 +110,18 @@ class SpeculativeDecodingMode(IntEnum):
     def extend_ctx(self, attention_backend: Type[AttentionBackend]):
         """
         If true, treat generation requests with draft tokens as
-        chunked context requests at the kernel level. Required for
-        any spec dec mode that uses the SpecExecutor.
+        chunked context requests at the kernel level.
         """
         if self.use_one_engine():
             # 1-model has separate logic for handling draft tokens
             return False
+        if issubclass(attention_backend,
+                      TrtllmAttention) and self.is_mtp_eagle():
+            # TRTLLM MLA does not work with the chunked context mode.
+            return False
         return not issubclass(attention_backend,
                               TrtllmAttention) or get_sm_version() != 100
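
For reference, a minimal standalone sketch of the decision logic after this change. The stub classes, the get_sm_version() stub, and the mode flags below are stand-ins rather than the real TensorRT-LLM types; only the branch structure mirrors extend_ctx above.

class AttentionBackend:                   # stand-in for the real base class
    pass

class TrtllmAttention(AttentionBackend):  # stand-in for the TRTLLM backend
    pass

def get_sm_version() -> int:
    return 100                            # assume a Blackwell-class (SM 100) GPU

class SpecModeSketch:
    def __init__(self, one_engine: bool, mtp_eagle: bool):
        self._one_engine = one_engine
        self._mtp_eagle = mtp_eagle

    def use_one_engine(self) -> bool:
        return self._one_engine

    def is_mtp_eagle(self) -> bool:
        return self._mtp_eagle

    def extend_ctx(self, attention_backend) -> bool:
        if self.use_one_engine():
            # 1-model has separate logic for handling draft tokens.
            return False
        if issubclass(attention_backend, TrtllmAttention) and self.is_mtp_eagle():
            # The new early return: TRTLLM MLA does not work with the chunked
            # context mode, so MTP-Eagle must not extend context here.
            return False
        return not issubclass(attention_backend,
                              TrtllmAttention) or get_sm_version() != 100

# MTP-Eagle two-model mode on the TRTLLM backend no longer extends context,
# while the existing SM 100 behavior for other modes is unchanged:
print(SpecModeSketch(one_engine=False, mtp_eagle=True).extend_ctx(TrtllmAttention))   # False
print(SpecModeSketch(one_engine=False, mtp_eagle=False).extend_ctx(TrtllmAttention))  # False on SM 100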

View File

@@ -165,6 +165,9 @@ class ModelDrafter(Drafter):
         input_tokens = get_draft_model_prompt(self.spec_config.spec_dec_mode,
                                               request.get_tokens(0))
+        is_eagle_style = self.spec_config.spec_dec_mode.is_eagle3(
+        ) or self.spec_config.spec_dec_mode.is_mtp_eagle()
         # First time seeing this request - context request
         if request.max_beam_num_tokens - 1 == request.py_prompt_len:
             # This is the first time the draft model is seeing this request.
@@ -174,10 +177,8 @@ class ModelDrafter(Drafter):
             return self._create_context_request(request, input_tokens)
         # For TRTLLM attention backend, we need to create a generation request for both no tokens accepted and tokens accepted
-        elif issubclass(
-                self.draft_model_engine.attn_backend, TrtllmAttention
-        ) and self.use_static_draft_loop and self.spec_config.spec_dec_mode.is_eagle3(
-        ):
+        elif issubclass(self.draft_model_engine.attn_backend, TrtllmAttention
+                        ) and self.use_static_draft_loop and is_eagle_style:
             return self._create_accepted_tokens_request_for_trtllm_attn(
                 request, input_tokens, num_accepted_tokens)
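
A condensed sketch of the dispatch above, with the engine, request, and helper objects reduced to plain booleans; the return strings are illustrative labels, and only the branch structure mirrors the drafter code.

def pick_draft_request_kind(is_first_pass: bool, uses_trtllm_attention: bool,
                            use_static_draft_loop: bool, is_eagle3: bool,
                            is_mtp_eagle: bool) -> str:
    # MTP-Eagle now takes the same TRTLLM-attention path as Eagle3.
    is_eagle_style = is_eagle3 or is_mtp_eagle
    if is_first_pass:
        # First time the draft model sees this request: plain context request.
        return "context"
    if uses_trtllm_attention and use_static_draft_loop and is_eagle_style:
        # Generation request whether or not any draft tokens were accepted.
        return "accepted_tokens_generation"
    return "other"

# Before this change is_mtp_eagle did not satisfy the second branch, so MTP
# requests fell through to the fallback path on the TRTLLM backend.
print(pick_draft_request_kind(False, True, True, False, True))  # accepted_tokens_generation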

View File

@@ -1953,18 +1953,17 @@ def test_ptp_quickstart_advanced_mtp_eagle(llm_root, llm_venv, model_name,
                                      dir="./",
                                      delete=True,
                                      delete_on_close=True) as running_log:
-        llm_venv.run_cmd(
-            [
-                str(example_root / "quickstart_advanced.py"),
-                "--use_cuda_graph",
-                "--spec_decode_max_draft_len",
-                "1",  # test 1 MTP module
-                "--spec_decode_algo",
-                "MTP",
-                "--model_dir",
-                f"{llm_models_root()}/{model_path}",
-            ],
-            stdout=running_log)
+        llm_venv.run_cmd([
+            str(example_root / "quickstart_advanced.py"),
+            "--use_cuda_graph",
+            "--spec_decode_max_draft_len",
+            "3",
+            "--spec_decode_algo",
+            "MTP",
+            "--model_dir",
+            f"{llm_models_root()}/{model_path}",
+        ],
+                         stdout=running_log)
         # 74.60 is the memory usage for DeepSeek-V3-Lite-BF16 with MTP Eagle 2 two model style as one extra kv cache is needed for draft model.
         _check_mem_usage(running_log, [74.60, 0, 0, 0])
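
Outside the test harness, the updated invocation is roughly the following; the script path and model directory are examples, not the exact values the fixtures resolve.

import subprocess

subprocess.run([
    "python", "examples/llm-api/quickstart_advanced.py",  # path assumed
    "--use_cuda_graph",
    "--spec_decode_max_draft_len", "3",                    # was "1" before this change
    "--spec_decode_algo", "MTP",
    "--model_dir", "/models/DeepSeek-V3-Lite/bf16",        # example model directory
], check=True)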