Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[https://nvbugs/5772414][fix] Fix draft token tree depth=1 corner case (#10385)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
This commit is contained in:
parent bedfff4f00
commit db2614ef10
@@ -2139,7 +2139,6 @@ class PyTorchModelEngine(ModelEngine):

        # For the target model + tree decoding
        if not self.is_draft_model and not spec_config.is_linear_tree:
            assert spec_tree_manager is not None
            assert num_draft_tokens == spec_tree_manager.max_total_draft_tokens
            position_ids.extend(
                past_seen_token_num +
                spec_tree_manager.spec_dec_position_offsets[
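To make the position bookkeeping concrete, here is a minimal sketch of how per-node depth offsets turn a draft token tree into position ids. It is illustrative only, not the repo's code: the values of past_seen_token_num and position_offsets are made up, and it assumes spec_dec_position_offsets stores each tree node's depth with the root at offset 0.

    # Sketch only: assumes the offsets table holds each tree node's depth,
    # with the root (the last accepted token) at offset 0.
    past_seen_token_num = 10           # tokens already in the KV cache
    position_offsets = [0, 1, 1, 1]    # root plus three depth-1 siblings

    position_ids = []
    position_ids.extend(past_seen_token_num + off for off in position_offsets)
    print(position_ids)  # [10, 11, 11, 11] -- sibling drafts share a position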
@@ -294,6 +294,13 @@ class TreeDraftingLoopWrapper(BaseDraftingLoopWrapper):

            self.draft_tokens_buffer[:batch_size, :-1], 0, 1)

        # return_draft_logits: [batch_size, max_total_draft_tokens + 1, vocab_size] -> [max_total_draft_tokens, batch_size, vocab_size]
        if return_draft_logits is None:
            # When max_draft_len == 1, the loop doesn't execute.
            # Expand the initial logits to match the expected shape.
            return_draft_logits = logits.unsqueeze(1).expand(
                batch_size, self.max_total_draft_tokens + 1,
                vocab_size).reshape(-1, vocab_size)

        return_draft_logits = return_draft_logits.reshape(
            batch_size, self.max_total_draft_tokens + 1, vocab_size)
        return_draft_logits = torch.transpose(return_draft_logits[:, :-1, :], 0,
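The corner case is easiest to see with concrete shapes. Below is a minimal, self-contained sketch of the same expand/reshape/transpose dance under assumed toy dimensions; logits stands in for the first drafting step's output and none of the values come from the repo:

    # Sketch only (hypothetical shapes, not the repo's code): when the
    # drafting loop never runs (depth-1 tree), return_draft_logits is still
    # None, so the initial logits [batch, vocab] must be expanded to the
    # shape the caller expects before the reshape/transpose.
    import torch

    batch_size, max_total_draft_tokens, vocab_size = 2, 3, 8
    logits = torch.randn(batch_size, vocab_size)

    return_draft_logits = None  # loop body skipped when max_draft_len == 1
    if return_draft_logits is None:
        return_draft_logits = logits.unsqueeze(1).expand(
            batch_size, max_total_draft_tokens + 1,
            vocab_size).reshape(-1, vocab_size)

    return_draft_logits = return_draft_logits.reshape(
        batch_size, max_total_draft_tokens + 1, vocab_size)
    # Drop the last node and swap batch/draft dims:
    # [batch, total+1, vocab] -> [total, batch, vocab]
    out = torch.transpose(return_draft_logits[:, :-1, :], 0, 1)
    print(out.shape)  # torch.Size([3, 2, 8])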
@@ -2115,6 +2115,40 @@ def test_draft_token_tree_quickstart_advanced_eagle3(llm_root, llm_venv,

        _check_mem_usage(running_log, [27, 0, 0, 0])


@pytest.mark.parametrize("model_name,model_path,eagle_model_path", [
    ("Llama-3.1-8b-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct",
     "EAGLE3-LLaMA3.1-Instruct-8B"),
])
def test_draft_token_tree_quickstart_advanced_eagle3_depth_1_tree(
        llm_root, llm_venv, model_name, model_path, eagle_model_path):
    print(f"Testing {model_name}.")
    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
    with tempfile.NamedTemporaryFile(mode='w+t',
                                     suffix=f".{model_name}.log",
                                     dir="./",
                                     delete=True,
                                     delete_on_close=True) as running_log:
        llm_venv.run_cmd([
            str(example_root / "quickstart_advanced.py"),
            "--prompt",
            "You are a good assistant. Please tell me the capital of France is",
            "--spec_decode_max_draft_len",
            "3",
            "--spec_decode_algo",
            "eagle3",
            "--model_dir",
            f"{llm_models_root()}/{model_path}",
            "--draft_model_dir",
            f"{llm_models_root()}/{eagle_model_path}",
            "--disable_kv_cache_reuse",
            "--disable_overlap_scheduler",
            "--eagle_choices",
            "[[0], [1], [2]]",
        ],
                         stdout=running_log)
        _check_mem_usage(running_log, [27, 0, 0, 0])


@pytest.mark.parametrize("model_name,model_path", [
    ("Llama-3.1-8B-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct"),
])
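For context on the new test's tree shape: in the EAGLE choices format each inner list is a path from the root, so "[[0], [1], [2]]" describes three depth-1 siblings, which is exactly the depth=1 corner case this commit fixes. A small sketch (not a repo helper) that derives this:

    # Sketch only: derive tree depth and node count from an eagle_choices
    # string. Each inner list is a root-to-node path, so its length is the
    # node's depth.
    import ast

    choices = ast.literal_eval("[[0], [1], [2]]")
    depth = max(len(path) for path in choices)   # 1
    total_draft_tokens = len(choices)            # 3
    print(f"depth={depth}, draft tokens={total_draft_tokens}")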
@@ -288,6 +288,8 @@ l0_h100:

      - examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1]
      - examples/test_phi.py::test_phi_4_mini_instruct_with_bf16_lora_torch[Phi-4-mini-instruct]
      - examples/test_llama.py::test_llama_3_x_with_bf16_lora_torch[llama-3.2-1b-instruct]
      - test_e2e.py::test_draft_token_tree_quickstart_advanced_eagle3[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B]
      - test_e2e.py::test_draft_token_tree_quickstart_advanced_eagle3_depth_1_tree[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B]
      # https://nvbugs/5563469: Disable Nemotron-Nano-8B-v1 test due to non-deterministic failures, revisit as part of TRTLLM-7885
      # - examples/test_nemotron_nas.py::test_nemotron_nano_8b_lora_torch[Llama-3.1-Nemotron-Nano-8B-v1]
      - condition:
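The bracketed suffix in these entries is pytest's parametrize id, which by default joins the parameter values with "-". A quick sketch (assuming pytest's default id scheme, not a repo helper) that reproduces the new entry:

    # Sketch only: rebuild the test-db entry string from the parametrize
    # values used in test_draft_token_tree_quickstart_advanced_eagle3_depth_1_tree.
    params = ("Llama-3.1-8b-Instruct",
              "llama-3.1-model/Llama-3.1-8B-Instruct",
              "EAGLE3-LLaMA3.1-Instruct-8B")
    entry = ("test_e2e.py::test_draft_token_tree_quickstart_advanced_eagle3"
             f"_depth_1_tree[{'-'.join(params)}]")
    print(entry)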