[https://nvbugs/5772414][fix] Fix draft token tree depth=1 corner case (#10385)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Author: Mike Iovine
Date: 2026-01-05 11:20:14 -05:00 (committed by GitHub)
parent bedfff4f00
commit db2614ef10
4 changed files with 43 additions and 1 deletion


@@ -2139,7 +2139,6 @@ class PyTorchModelEngine(ModelEngine):
# For the target model + tree decoding
if not self.is_draft_model and not spec_config.is_linear_tree:
assert spec_tree_manager is not None
assert num_draft_tokens == spec_tree_manager.max_total_draft_tokens
position_ids.extend(
past_seen_token_num +
spec_tree_manager.spec_dec_position_offsets[
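For context, a minimal sketch (not from the TRT-LLM source; the concrete values and layout are assumptions) of how depth-based position offsets could produce the position ids extended above for a depth-1 tree:

```python
# Hypothetical illustration: assume spec_dec_position_offsets stores each tree
# node's depth, with the root at offset 0 and its children at offset 1.
past_seen_token_num = 10                  # assumed prior sequence length
spec_dec_position_offsets = [0, 1, 1, 1]  # assumed: root + three depth-1 drafts

# Mirrors the position_ids.extend(past_seen_token_num + offsets) pattern above.
position_ids = [past_seen_token_num + off for off in spec_dec_position_offsets]
print(position_ids)  # [10, 11, 11, 11]
```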


@@ -294,6 +294,13 @@ class TreeDraftingLoopWrapper(BaseDraftingLoopWrapper):
self.draft_tokens_buffer[:batch_size, :-1], 0, 1)
# return_draft_logits: [batch_size, max_total_draft_tokens + 1, vocab_size] -> [max_total_draft_tokens, batch_size, vocab_size]
if return_draft_logits is None:
# When max_draft_len == 1, the loop doesn't execute.
# Expand the initial logits to match the expected shape.
return_draft_logits = logits.unsqueeze(1).expand(
batch_size, self.max_total_draft_tokens + 1,
vocab_size).reshape(-1, vocab_size)
return_draft_logits = return_draft_logits.reshape(
batch_size, self.max_total_draft_tokens + 1, vocab_size)
        return_draft_logits = torch.transpose(return_draft_logits[:, :-1, :], 0,
                                              1)
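To make the corner case concrete, here is a self-contained sketch of the shape arithmetic (the sizes are made up; only the unsqueeze/expand/reshape/transpose chain mirrors the diff):

```python
import torch

batch_size, max_total_draft_tokens, vocab_size = 2, 3, 5  # assumed sizes
logits = torch.randn(batch_size, vocab_size)  # initial draft-model logits

# Depth-1 case: the drafting loop never runs, so broadcast the initial logits
# to the flattened [batch_size * (max_total_draft_tokens + 1), vocab_size]
# shape the post-processing below expects.
return_draft_logits = logits.unsqueeze(1).expand(
    batch_size, max_total_draft_tokens + 1, vocab_size).reshape(-1, vocab_size)

# Same post-processing as the diff: restore the 3-D shape, drop the final
# position, and swap to [max_total_draft_tokens, batch_size, vocab_size].
return_draft_logits = return_draft_logits.reshape(
    batch_size, max_total_draft_tokens + 1, vocab_size)
out = torch.transpose(return_draft_logits[:, :-1, :], 0, 1)
assert out.shape == (max_total_draft_tokens, batch_size, vocab_size)
```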


@@ -2115,6 +2115,40 @@ def test_draft_token_tree_quickstart_advanced_eagle3(llm_root, llm_venv,
        _check_mem_usage(running_log, [27, 0, 0, 0])


@pytest.mark.parametrize("model_name,model_path,eagle_model_path", [
("Llama-3.1-8b-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct",
"EAGLE3-LLaMA3.1-Instruct-8B"),
])
def test_draft_token_tree_quickstart_advanced_eagle3_depth_1_tree(
llm_root, llm_venv, model_name, model_path, eagle_model_path):
print(f"Testing {model_name}.")
example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
with tempfile.NamedTemporaryFile(mode='w+t',
suffix=f".{model_name}.log",
dir="./",
delete=True,
delete_on_close=True) as running_log:
llm_venv.run_cmd([
str(example_root / "quickstart_advanced.py"),
"--prompt",
"You are a good assistant. Please tell me the capital of France is",
"--spec_decode_max_draft_len",
"3",
"--spec_decode_algo",
"eagle3",
"--model_dir",
f"{llm_models_root()}/{model_path}",
"--draft_model_dir",
f"{llm_models_root()}/{eagle_model_path}",
"--disable_kv_cache_reuse",
"--disable_overlap_scheduler",
"--eagle_choices",
"[[0], [1], [2]]",
],
stdout=running_log)
        _check_mem_usage(running_log, [27, 0, 0, 0])


@pytest.mark.parametrize("model_name,model_path", [
("Llama-3.1-8B-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct"),
])
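As a hedged reading of the tree shape exercised by the new test: under the Medusa/Eagle-style choices encoding (each inner list is a root-to-node path of child indices), `--eagle_choices "[[0], [1], [2]]"` describes three sibling draft tokens one level below the root, i.e. the depth=1 corner case named in the commit title:

```python
# Sketch only: interprets eagle_choices under the Medusa/Eagle path encoding.
choices = [[0], [1], [2]]

tree_depth = max(len(path) for path in choices)  # 1 -> a depth-1 tree
num_draft_tokens = len(choices)                  # 3 draft tokens besides the root
print(tree_depth, num_draft_tokens)              # 1 3
```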


@@ -288,6 +288,8 @@ l0_h100:
- examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1]
- examples/test_phi.py::test_phi_4_mini_instruct_with_bf16_lora_torch[Phi-4-mini-instruct]
- examples/test_llama.py::test_llama_3_x_with_bf16_lora_torch[llama-3.2-1b-instruct]
- test_e2e.py::test_draft_token_tree_quickstart_advanced_eagle3[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B]
- test_e2e.py::test_draft_token_tree_quickstart_advanced_eagle3_depth_1_tree[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B]
# https://nvbugs/5563469: Disable Nemotron-Nano-8B-v1 test due to non-deterministic failures, revisit as part of TRTLLM-7885
# - examples/test_nemotron_nas.py::test_nemotron_nano_8b_lora_torch[Llama-3.1-Nemotron-Nano-8B-v1]
- condition: