chore [TRTLLM-6161]: add LLM speculative decoding example (#5706)

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
Yan Chunwei authored 2025-07-09 07:33:11 +08:00 · committed by GitHub
parent da8c7372d4
commit e50d95c40d
3 changed files with 127 additions and 0 deletions

@@ -0,0 +1,92 @@
### :title Speculative Decoding
### :order 5
### :section Customization
from typing import Optional

import click

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (EagleDecodingConfig, MTPDecodingConfig,
                                 NGramDecodingConfig)

prompts = [
    "What is the capital of France?",
    "What is the future of AI?",
]


def run_MTP(model: Optional[str] = None):
    spec_config = MTPDecodingConfig(num_nextn_predict_layers=1,
                                    use_relaxed_acceptance_for_thinking=True,
                                    relaxed_topk=10,
                                    relaxed_delta=0.01)
    llm = LLM(
        # You can change this to a local model path if you have the model downloaded
        model=model or "nvidia/DeepSeek-R1-FP4",
        speculative_config=spec_config,
    )
    for prompt in prompts:
        response = llm.generate(prompt, SamplingParams(max_tokens=10))
        print(response.outputs[0].text)


def run_Eagle3():
    spec_config = EagleDecodingConfig(
        max_draft_len=3,
        pytorch_weights_path="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
        eagle3_one_model=True)
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        speculative_config=spec_config,
    )
    for prompt in prompts:
        response = llm.generate(prompt, SamplingParams(max_tokens=10))
        print(response.outputs[0].text)


def run_ngram():
    spec_config = NGramDecodingConfig(
        prompt_lookup_num_tokens=3,
        max_matching_ngram_size=3,
        is_keep_all=True,
        is_use_oldest=True,
        is_public_pool=True,
    )
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        speculative_config=spec_config,
        # ngram doesn't work with overlap_scheduler
        disable_overlap_scheduler=True,
    )
    for prompt in prompts:
        response = llm.generate(prompt, SamplingParams(max_tokens=10))
        print(response.outputs[0].text)


@click.command()
@click.argument("algo",
                type=click.Choice(["MTP", "EAGLE3", "DRAFT_TARGET", "NGRAM"]))
@click.option("--model",
              type=str,
              default=None,
              help="Path to the model or model name.")
def main(algo: str, model: Optional[str] = None):
    algo = algo.upper()
    if algo == "MTP":
        run_MTP(model)
    elif algo == "EAGLE3":
        run_Eagle3()
    elif algo == "NGRAM":
        run_ngram()
    else:
        raise ValueError(f"Invalid algorithm: {algo}")


if __name__ == "__main__":
    main()
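Note that the CLI above also accepts DRAFT_TARGET, but no runner is wired up for it, so that choice currently falls through to the ValueError branch. For illustration only, a draft-target runner in the same style could look like the sketch below; it assumes a DraftTargetDecodingConfig class in tensorrt_llm.llmapi with fields analogous to the other configs, and the draft model shown is only a placeholder, so verify the installed version's API before relying on the exact names.

# Hypothetical sketch (not part of this commit): a DRAFT_TARGET runner in the
# same style as run_MTP/run_Eagle3/run_ngram. DraftTargetDecodingConfig and its
# field names are assumed here and may differ between TensorRT-LLM versions.
from typing import Optional

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import DraftTargetDecodingConfig  # assumed import path


def run_draft_target(model: Optional[str] = None):
    spec_config = DraftTargetDecodingConfig(
        # Number of draft tokens proposed per step (assumed field name).
        max_draft_len=3,
        # A small draft model paired with the larger target model (placeholder).
        speculative_model_dir="meta-llama/Llama-3.2-1B-Instruct",
    )
    llm = LLM(
        model=model or "meta-llama/Llama-3.1-8B-Instruct",
        speculative_config=spec_config,
    )
    for prompt in prompts:  # reuses the prompts list defined in the example above
        response = llm.generate(prompt, SamplingParams(max_tokens=10))
        print(response.outputs[0].text)

For this path to be reachable, main() would also need an elif branch mapping "DRAFT_TARGET" to run_draft_target(model).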

@@ -52,14 +52,27 @@ def _run_llmapi_example(llm_root, engine_dir, llm_venv, script_name: str,
    # Create llm models softlink to avoid duplicated downloading for llm api example
    src_dst_dict = {
        # TinyLlama-1.1B-Chat-v1.0
        f"{llm_models_root()}/llama-models-v2/TinyLlama-1.1B-Chat-v1.0":
        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        # vicuna-7b-v1.3
        f"{llm_models_root()}/vicuna-7b-v1.3":
        f"{llm_venv.get_working_directory()}/lmsys/vicuna-7b-v1.3",
        # medusa-vicuna-7b-v1.3
        f"{llm_models_root()}/medusa-vicuna-7b-v1.3":
        f"{llm_venv.get_working_directory()}/FasterDecoding/medusa-vicuna-7b-v1.3",
        # llama3.1-medusa-8b-hf_v0.1
        f"{llm_models_root()}/llama3.1-medusa-8b-hf_v0.1":
        f"{llm_venv.get_working_directory()}/nvidia/Llama-3.1-8B-Medusa-FP8",
        # Llama-3.1-8B-Instruct
        f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct":
        f"{llm_venv.get_working_directory()}/meta-llama/Llama-3.1-8B-Instruct",
        # DeepSeek-V3-Lite/bf16
        f"{llm_models_root()}/DeepSeek-V3-Lite/bf16":
        f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16",
        # EAGLE3-LLaMA3.1-Instruct-8B
        f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B":
        f"{llm_venv.get_working_directory()}/yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
    }
    for src, dst in src_dst_dict.items():
@@ -129,3 +142,22 @@ def test_llmapi_quickstart_atexit(llm_root, engine_dir, llm_venv):
        llm_root
    ) / "tests/integration/defs/examples/run_llm_quickstart_atexit.py"
    llm_venv.run_cmd([str(script_path)])


@pytest.mark.skip(reason="https://nvbugs/5365825")
def test_llmapi_speculative_decoding_mtp(llm_root, engine_dir, llm_venv):
    _run_llmapi_example(llm_root, engine_dir, llm_venv,
                        "llm_speculative_decoding.py", "MTP", "--model",
                        f"{llm_models_root()}/DeepSeek-V3-Lite/bf16")


@pytest.mark.skip(reason="https://nvbugs/5365825")
def test_llmapi_speculative_decoding_eagle3(llm_root, engine_dir, llm_venv):
    _run_llmapi_example(llm_root, engine_dir, llm_venv,
                        "llm_speculative_decoding.py", "EAGLE3")


@pytest.mark.skip(reason="https://nvbugs/5365825")
def test_llmapi_speculative_decoding_ngram(llm_root, engine_dir, llm_venv):
    _run_llmapi_example(llm_root, engine_dir, llm_venv,
                        "llm_speculative_decoding.py", "NGRAM")

@@ -25,4 +25,7 @@ l0_sanity_check:
- llmapi/test_llm_examples.py::test_llmapi_example_multilora
- llmapi/test_llm_examples.py::test_llmapi_example_guided_decoding
- llmapi/test_llm_examples.py::test_llmapi_example_logits_processor
- llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_mtp
- llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_eagle3
- llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_ngram
- examples/test_llm_api_with_mpi.py::test_llm_api_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]