Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
chore [TRTLLM-6161]: add LLM speculative decoding example (#5706)
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
parent: da8c7372d4
commit: e50d95c40d
examples/llm-api/llm_speculative_decoding.py (new file, 92 lines added)
@@ -0,0 +1,92 @@
### :title Speculative Decoding
### :order 5
### :section Customization
from typing import Optional

import click

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (EagleDecodingConfig, MTPDecodingConfig,
                                 NGramDecodingConfig)

prompts = [
    "What is the capital of France?",
    "What is the future of AI?",
]


def run_MTP(model: Optional[str] = None):
    # MTP (Multi-Token Prediction): drafts tokens with the model's built-in
    # MTP layers, as used by DeepSeek-V3/R1 checkpoints.
    spec_config = MTPDecodingConfig(num_nextn_predict_layers=1,
                                    use_relaxed_acceptance_for_thinking=True,
                                    relaxed_topk=10,
                                    relaxed_delta=0.01)

    llm = LLM(
        # You can change this to a local model path if you have the model downloaded
        model=model or "nvidia/DeepSeek-R1-FP4",
        speculative_config=spec_config,
    )

    for prompt in prompts:
        response = llm.generate(prompt, SamplingParams(max_tokens=10))
        print(response.outputs[0].text)


def run_Eagle3():
    # EAGLE-3: drafts tokens with a separately trained draft head loaded
    # from pytorch_weights_path.
    spec_config = EagleDecodingConfig(
        max_draft_len=3,
        pytorch_weights_path="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
        eagle3_one_model=True)

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        speculative_config=spec_config,
    )

    for prompt in prompts:
        response = llm.generate(prompt, SamplingParams(max_tokens=10))
        print(response.outputs[0].text)


def run_ngram():
    # NGram (prompt-lookup) decoding: draft tokens are matched from
    # previously seen n-grams, so no draft model is required.
    spec_config = NGramDecodingConfig(
        prompt_lookup_num_tokens=3,
        max_matching_ngram_size=3,
        is_keep_all=True,
        is_use_oldest=True,
        is_public_pool=True,
    )

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        speculative_config=spec_config,
        # ngram doesn't work with the overlap scheduler
        disable_overlap_scheduler=True,
    )

    for prompt in prompts:
        response = llm.generate(prompt, SamplingParams(max_tokens=10))
        print(response.outputs[0].text)


@click.command()
@click.argument("algo",
                type=click.Choice(["MTP", "EAGLE3", "DRAFT_TARGET", "NGRAM"]))
@click.option("--model",
              type=str,
              default=None,
              help="Path to the model or model name.")
def main(algo: str, model: Optional[str] = None):
    algo = algo.upper()
    if algo == "MTP":
        run_MTP(model)
    elif algo == "EAGLE3":
        run_Eagle3()
    elif algo == "NGRAM":
        run_ngram()
    else:
        # DRAFT_TARGET is accepted by the CLI but not implemented in this example.
        raise ValueError(f"Invalid algorithm: {algo}")


if __name__ == "__main__":
    main()
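To get a feel for the end-to-end effect, the short sketch below times the same prompts with and without the n-gram configuration from the example. It is a rough illustration rather than part of the commit: the model name, token budget, and the assumption that the two engines can be built one after the other on the same GPU are all illustrative.

# Rough timing sketch (illustrative, not part of this commit): compare
# wall-clock generation time with and without n-gram speculative decoding.
# Model name and max_tokens are assumptions; swap in a local path if the
# checkpoint is already downloaded.
import time

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import NGramDecodingConfig

prompts = [
    "What is the capital of France?",
    "What is the future of AI?",
]
sampling = SamplingParams(max_tokens=64)


def timed_generate(llm):
    start = time.perf_counter()
    outputs = llm.generate(prompts, sampling)
    return outputs, time.perf_counter() - start


baseline = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
_, baseline_s = timed_generate(baseline)
del baseline  # drop the reference so the first engine can be torn down (memory permitting)

spec = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config=NGramDecodingConfig(
        prompt_lookup_num_tokens=3,
        max_matching_ngram_size=3,
        is_keep_all=True,
        is_use_oldest=True,
        is_public_pool=True,
    ),
    # mirrors the example above: ngram requires the overlap scheduler off
    disable_overlap_scheduler=True,
)
_, spec_s = timed_generate(spec)

print(f"baseline: {baseline_s:.2f}s  ngram speculative: {spec_s:.2f}s")

Whether n-gram drafting helps depends heavily on how repetitive the generated text is; short factual answers like these may show little difference.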
@@ -52,14 +52,27 @@ def _run_llmapi_example(llm_root, engine_dir, llm_venv, script_name: str,

    # Create softlinks to the cached llm models to avoid duplicated downloads for the llm api examples
    src_dst_dict = {
        # TinyLlama-1.1B-Chat-v1.0
        f"{llm_models_root()}/llama-models-v2/TinyLlama-1.1B-Chat-v1.0":
        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        # vicuna-7b-v1.3
        f"{llm_models_root()}/vicuna-7b-v1.3":
        f"{llm_venv.get_working_directory()}/lmsys/vicuna-7b-v1.3",
        # medusa-vicuna-7b-v1.3
        f"{llm_models_root()}/medusa-vicuna-7b-v1.3":
        f"{llm_venv.get_working_directory()}/FasterDecoding/medusa-vicuna-7b-v1.3",
        # llama3.1-medusa-8b-hf_v0.1
        f"{llm_models_root()}/llama3.1-medusa-8b-hf_v0.1":
        f"{llm_venv.get_working_directory()}/nvidia/Llama-3.1-8B-Medusa-FP8",
        # Llama-3.1-8B-Instruct
        f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct":
        f"{llm_venv.get_working_directory()}/meta-llama/Llama-3.1-8B-Instruct",
        # DeepSeek-V3-Lite/bf16
        f"{llm_models_root()}/DeepSeek-V3-Lite/bf16":
        f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16",
        # EAGLE3-LLaMA3.1-Instruct-8B
        f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B":
        f"{llm_venv.get_working_directory()}/yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
    }

    for src, dst in src_dst_dict.items():
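The body of this loop lies outside the hunk shown here; presumably it links each cached model directory into the venv working directory so that Hugging Face-style IDs resolve locally. A hypothetical sketch of what such a loop body could look like (the helper actually used by the test fixture may differ):

# Hypothetical sketch only: the real loop body is not part of this hunk.
# Links each cached model directory into the working directory so that
# IDs like "meta-llama/Llama-3.1-8B-Instruct" resolve to local paths
# instead of triggering a download.
import os

for src, dst in src_dst_dict.items():
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    if not os.path.exists(dst):
        os.symlink(src, dst)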
@@ -129,3 +142,22 @@ def test_llmapi_quickstart_atexit(llm_root, engine_dir, llm_venv):
        llm_root
    ) / "tests/integration/defs/examples/run_llm_quickstart_atexit.py"
    llm_venv.run_cmd([str(script_path)])


@pytest.mark.skip(reason="https://nvbugs/5365825")
def test_llmapi_speculative_decoding_mtp(llm_root, engine_dir, llm_venv):
    _run_llmapi_example(llm_root, engine_dir, llm_venv,
                        "llm_speculative_decoding.py", "MTP", "--model",
                        f"{llm_models_root()}/DeepSeek-V3-Lite/bf16")


@pytest.mark.skip(reason="https://nvbugs/5365825")
def test_llmapi_speculative_decoding_eagle3(llm_root, engine_dir, llm_venv):
    _run_llmapi_example(llm_root, engine_dir, llm_venv,
                        "llm_speculative_decoding.py", "EAGLE3")


@pytest.mark.skip(reason="https://nvbugs/5365825")
def test_llmapi_speculative_decoding_ngram(llm_root, engine_dir, llm_venv):
    _run_llmapi_example(llm_root, engine_dir, llm_venv,
                        "llm_speculative_decoding.py", "NGRAM")
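Because the three tests above are skipped pending the linked nvbug, the example can be exercised locally by driving its click entry point directly. A minimal sketch, assuming a GPU and the referenced checkpoints are available and the script is run from examples/llm-api:

# Local smoke test (sketch, not part of the commit): invoke the example's
# click CLI directly instead of going through _run_llmapi_example.
# Run from examples/llm-api so the module import resolves.
from click.testing import CliRunner

from llm_speculative_decoding import main

result = CliRunner().invoke(main, ["NGRAM"])
print(result.output)
if result.exception:
    raise result.exception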
@@ -25,4 +25,7 @@ l0_sanity_check:
- llmapi/test_llm_examples.py::test_llmapi_example_multilora
- llmapi/test_llm_examples.py::test_llmapi_example_guided_decoding
- llmapi/test_llm_examples.py::test_llmapi_example_logits_processor
- llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_mtp
- llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_eagle3
- llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_ngram
- examples/test_llm_api_with_mpi.py::test_llm_api_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]