Latest commit 5d438be59a by Thor Johnsen (co-authored by wili-65535), 2025-05-21 10:40:00 +08:00: [TRTLLM-5000][feat] Pytorch implementation of ngram drafter (#3936)
  • out_of_tree_example: feat: Add pp support for hybrid attn/mamba model (#4358), 2025-05-19 14:47:45 +08:00
  • quickstart_advanced.py: [TRTLLM-5000][feat] Pytorch implementation of ngram drafter (#3936), 2025-05-21 10:40:00 +08:00
  • quickstart_multimodal.py: fix: [nvbug/5252057] Fix kv cache reuse on PyTorch multimodal (#4025), 2025-05-02 10:53:06 -07:00
  • quickstart.py: chore: Simplify quickstart of PyTorch flow (#3000), 2025-03-24 14:32:17 +08:00
  • README.md: [TRTLLM-5000][feat] Pytorch implementation of ngram drafter (#3936), 2025-05-21 10:40:00 +08:00
  • star_attention.py: [TRTLLM-3925, https://nvbugs/5245262] [fix] Normalize LLM.generate API (#3985), 2025-05-07 11:06:23 +08:00

TRT-LLM with PyTorch

Run the quick start script:

python3 quickstart.py

Run the advanced usage example script:

# BF16
python3 quickstart_advanced.py --model_dir meta-llama/Llama-3.1-8B-Instruct

# FP8
python3 quickstart_advanced.py --model_dir nvidia/Llama-3.1-8B-Instruct-FP8

# BF16 + TP=2
python3 quickstart_advanced.py --model_dir meta-llama/Llama-3.1-8B-Instruct --tp_size 2

# FP8 + TP=2
python3 quickstart_advanced.py --model_dir nvidia/Llama-3.1-8B-Instruct-FP8 --tp_size 2

# FP8 (e4m3) KV cache
python3 quickstart_advanced.py --model_dir nvidia/Llama-3.1-8B-Instruct-FP8 --kv_cache_dtype fp8

# BF16 + TP=8
python3 quickstart_advanced.py --model_dir nvidia/Llama-3_1-Nemotron-Ultra-253B-v1 --tp_size 8

# Nemotron-H requires disabling KV cache reuse
python3 quickstart_advanced.py --model_dir nvidia/Nemotron-H-8B-Base-8K --disable_kv_cache_reuse --max_batch_size 8
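The `--kv_cache_dtype fp8` and `--tp_size` options above both shrink the KV cache footprint on each GPU. The back-of-the-envelope sketch below shows by how much, assuming a standard decoder with grouped-query attention, one byte per FP8 element, and tensor parallelism that splits KV heads evenly across ranks. It ignores quantization scales, paging, and allocator overhead. The helper `kv_cache_bytes_per_token` is hypothetical and not part of TensorRT-LLM; the Llama-3.1-8B shape (32 layers, 8 KV heads, head dimension 128) is taken from that model's public config.

```python
# Hypothetical sizing helper, not part of TensorRT-LLM.
# Assumptions: K and V are cached for every layer, FP8 uses 1 byte per element,
# and tensor parallelism splits the KV heads evenly across ranks.

BYTES_PER_ELEMENT = {"bf16": 2, "fp8": 1}


def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                             dtype: str, tp_size: int = 1) -> int:
    """Per-GPU bytes needed to cache keys and values for a single token."""
    if num_kv_heads % tp_size != 0:
        raise ValueError("num_kv_heads must be divisible by tp_size")
    kv_heads_per_rank = num_kv_heads // tp_size
    # Factor 2: one cached tensor for keys and one for values.
    return 2 * num_layers * kv_heads_per_rank * head_dim * BYTES_PER_ELEMENT[dtype]


# Llama-3.1-8B attention shape, from its Hugging Face config.json.
LLAMA_3_1_8B = dict(num_layers=32, num_kv_heads=8, head_dim=128)

if __name__ == "__main__":
    for dtype in ("bf16", "fp8"):
        for tp in (1, 2):
            size = kv_cache_bytes_per_token(**LLAMA_3_1_8B, dtype=dtype, tp_size=tp)
            print(f"{dtype} TP={tp}: {size // 1024} KiB per token per GPU")
```

For Llama-3.1-8B this prints 128 KiB per token per GPU for BF16 at TP=1, 64 KiB for BF16 at TP=2 or FP8 at TP=1, and 32 KiB for FP8 at TP=2, so under these assumptions either option roughly doubles the number of tokens that fit in a fixed KV cache budget.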

Run the multimodal example script:

# default inputs
python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --modality image [--use_cuda_graph]

# user inputs
# supported modes:
# (1) N prompts, N media (the N requests are processed with in-flight batching)
# (2) 1 prompt, N media
# Note: all media in one run must be of the same type, either images or videos. Mixing images and videos is not supported.
python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --modality video --prompt "Tell me what you see in the video briefly." "Describe the scene in the video briefly." --media "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4" "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4" --max_tokens 128 [--use_cuda_graph]
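Before launching a multimodal run, it can help to check that the prompts and media follow the rules listed in the comments above. The sketch below is a hypothetical pre-flight check, not the validation logic in quickstart_multimodal.py. It assumes the media type can be inferred from the file extension of each path or URL; the extension sets are illustrative and may not match what the script accepts.

```python
# Hypothetical pre-flight check, not the validation code in quickstart_multimodal.py.
# Assumption: the media type is inferred from the file extension of each path/URL.
from pathlib import PurePosixPath
from typing import List
from urllib.parse import urlparse

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm"}


def media_kind(path_or_url: str) -> str:
    """Classify a local path or URL as "image" or "video" by its extension."""
    suffix = PurePosixPath(urlparse(path_or_url).path).suffix.lower()
    if suffix in IMAGE_EXTS:
        return "image"
    if suffix in VIDEO_EXTS:
        return "video"
    raise ValueError(f"cannot infer media type of {path_or_url!r}")


def check_inputs(prompts: List[str], media: List[str], modality: str) -> str:
    """Return which supported mode the inputs use, or raise ValueError."""
    kinds = {media_kind(m) for m in media}
    if len(kinds) > 1:
        raise ValueError("mixing image and video is not supported")
    if kinds and kinds != {modality}:
        raise ValueError(f"--modality {modality} does not match the media type")
    if prompts and len(prompts) == len(media):
        return "N prompts, N media"
    if len(prompts) == 1 and media:
        return "1 prompt, N media"
    raise ValueError("pass one prompt per media item, or a single prompt")


if __name__ == "__main__":
    base = ("https://huggingface.co/datasets/Efficient-Large-Model/"
            "VILA-inference-demos/resolve/main/")
    # Same inputs as the video example command above.
    print(check_inputs(
        ["Tell me what you see in the video briefly.",
         "Describe the scene in the video briefly."],
        [base + "OAI-sora-tokyo-walk.mp4", base + "world.mp4"],
        modality="video"))
```

With the inputs from the video example command, the check reports mode (1), "N prompts, N media".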

Supported Models

| Architecture | Model | HuggingFace Example | Modality |
|---|---|---|---|
| BertForSequenceClassification | BERT-based | textattack/bert-base-uncased-yelp-polarity | L |
| DeepseekV3ForCausalLM | DeepSeek-V3 | deepseek-ai/DeepSeek-V3 | L |
| Gemma3ForCausalLM | Gemma3 | google/gemma-3-1b-it | L |
| LlavaLlamaModel | VILA | Efficient-Large-Model/NVILA-8B | L + V |
| LlavaNextForConditionalGeneration | LLaVA-NeXT | llava-hf/llava-v1.6-mistral-7b-hf | L + V |
| LlamaForCausalLM | Llama 3, Llama 3.1, Llama 2, LLaMA | meta-llama/Meta-Llama-3.1-70B | L |
| Llama4ForConditionalGeneration | Llama 4 Scout, Llama 4 Maverick | meta-llama/Llama-4-Scout-17B-16E-Instruct, meta-llama/Llama-4-Maverick-17B-128E-Instruct | L + V |
| MistralForCausalLM | Mistral | mistralai/Mistral-7B-v0.1 | L |
| MixtralForCausalLM | Mixtral | mistralai/Mixtral-8x7B-v0.1 | L |
| MllamaForConditionalGeneration | Llama 3.2 | meta-llama/Llama-3.2-11B-Vision | L |
| NemotronForCausalLM | Nemotron-3, Nemotron-4, Minitron | nvidia/Minitron-8B-Base | L |
| NemotronHForCausalLM | Nemotron-H | nvidia/Nemotron-H-8B-Base-8K, nvidia/Nemotron-H-47B-Base-8K, nvidia/Nemotron-H-56B-Base-8K | L |
| NemotronNASForCausalLM | LlamaNemotron, LlamaNemotron Super, LlamaNemotron Ultra | nvidia/Llama-3_1-Nemotron-51B-Instruct, nvidia/Llama-3_3-Nemotron-Super-49B-v1, nvidia/Llama-3_1-Nemotron-Ultra-253B-v1 | L |
| Qwen2ForCausalLM | QwQ, Qwen2 | Qwen/Qwen2-7B-Instruct | L |
| Qwen2ForProcessRewardModel | Qwen2-based | Qwen/Qwen2.5-Math-PRM-7B | L |
| Qwen2ForRewardModel | Qwen2-based | Qwen/Qwen2.5-Math-RM-72B | L |
| Qwen2VLForConditionalGeneration | Qwen2-VL | Qwen/Qwen2-VL-7B-Instruct | L + V |
| Qwen2_5_VLForConditionalGeneration | Qwen2.5-VL | Qwen/Qwen2.5-VL-7B-Instruct | L + V |

Note:

  • L: Language only
  • L + V: Language and Vision multimodal support
  • Llama 3.2 accepts vision input, but our support is currently limited to text only.
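The Architecture column in the Supported Models table corresponds to the `architectures` field of a Hugging Face checkpoint's `config.json`. As a rough pre-check, a hypothetical helper like the one below can look that field up against the table. It is not part of TensorRT-LLM, the table is a snapshot that may lag behind the code, and for a Hub model ID the checkpoint must be downloaded locally before `supported_modality` can read its `config.json`.

```python
# Hypothetical helper, not part of TensorRT-LLM: compares the "architectures"
# field of a Hugging Face config.json against the Supported Models table.
import json
from pathlib import Path
from typing import Dict

# Architecture -> modality ("L" or "L + V"), copied from the table.
SUPPORTED: Dict[str, str] = {
    "BertForSequenceClassification": "L",
    "DeepseekV3ForCausalLM": "L",
    "Gemma3ForCausalLM": "L",
    "LlavaLlamaModel": "L + V",
    "LlavaNextForConditionalGeneration": "L + V",
    "LlamaForCausalLM": "L",
    "Llama4ForConditionalGeneration": "L + V",
    "MistralForCausalLM": "L",
    "MixtralForCausalLM": "L",
    "MllamaForConditionalGeneration": "L",  # vision input not yet supported
    "NemotronForCausalLM": "L",
    "NemotronHForCausalLM": "L",
    "NemotronNASForCausalLM": "L",
    "Qwen2ForCausalLM": "L",
    "Qwen2ForProcessRewardModel": "L",
    "Qwen2ForRewardModel": "L",
    "Qwen2VLForConditionalGeneration": "L + V",
    "Qwen2_5_VLForConditionalGeneration": "L + V",
}


def modality_for_config(config: dict) -> str:
    """Return the supported modality for a parsed config.json, or raise."""
    archs = config.get("architectures") or []
    for arch in archs:
        if arch in SUPPORTED:
            return SUPPORTED[arch]
    raise ValueError(f"none of {archs} is listed as supported")


def supported_modality(model_dir: str) -> str:
    """Read <model_dir>/config.json from a local checkpoint and look it up."""
    config = json.loads((Path(model_dir) / "config.json").read_text())
    return modality_for_config(config)
```

For example, a checkpoint whose `config.json` lists `LlamaForCausalLM` maps to "L", while an architecture missing from the table raises `ValueError`.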

Run speculative decoding with the advanced example script (the command below is run from the repository root):

# NGram drafter
python3 examples/pytorch/quickstart_advanced.py \
    --max_matching_ngram_size=2 \
    --spec_decode_nextn=4
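An n-gram drafter proposes draft tokens without a separate draft model: when the most recent tokens also appear earlier in the context, the tokens that followed that earlier occurrence are guessed as the continuation, and the target model then verifies the guesses. The sketch below shows the generic prompt-lookup form of this idea with greedy verification. It is not TensorRT-LLM's NGram drafter, which may differ in how it stores and selects matches. Mapping `--max_matching_ngram_size` to the longest suffix that is matched and `--spec_decode_nextn` to the number of draft tokens is an assumption based on the flag names; `ngram_draft` and `accept_greedy` are hypothetical names.

```python
# Generic prompt-lookup (n-gram) drafting, sketched under assumptions; this is
# not TensorRT-LLM's NGram drafter implementation.
from typing import List, Sequence


def ngram_draft(tokens: Sequence[int], max_matching_ngram_size: int,
                num_draft_tokens: int) -> List[int]:
    """Propose draft tokens by finding the sequence suffix earlier in context.

    Tries the longest suffix first (up to max_matching_ngram_size tokens) and
    falls back to shorter ones; returns [] if no suffix reappears.
    """
    n_max = min(max_matching_ngram_size, len(tokens) - 1)
    for n in range(n_max, 0, -1):
        suffix = list(tokens[-n:])
        # Scan right to left so the most recent earlier occurrence wins.
        for start in range(len(tokens) - n - 1, -1, -1):
            if list(tokens[start:start + n]) == suffix:
                # Tokens that followed the match become the draft.
                return list(tokens[start + n:start + n + num_draft_tokens])
    return []


def accept_greedy(draft: Sequence[int], target_argmax: Sequence[int]) -> List[int]:
    """Greedy verification of a draft against the target model.

    target_argmax[i] is the target model's greedy token after the prefix
    extended by draft[:i], so it has len(draft) + 1 entries. Matching draft
    tokens are accepted, then the target's own next token is appended.
    """
    if len(target_argmax) != len(draft) + 1:
        raise ValueError("target_argmax must have len(draft) + 1 entries")
    accepted: List[int] = []
    for d, t in zip(draft, target_argmax):
        if d != t:
            break
        accepted.append(d)
    accepted.append(target_argmax[len(accepted)])
    return accepted


if __name__ == "__main__":
    context = [5, 6, 7, 8, 5, 6]
    draft = ngram_draft(context, max_matching_ngram_size=2, num_draft_tokens=4)
    print(draft)  # suffix [5, 6] matched at position 0
    print(accept_greedy(draft, [7, 8, 9, 1]))
```

Here the suffix `[5, 6]` reappears at the start of the context, so the draft is `[7, 8, 5, 6]` (truncated to the four tokens requested). If the target model would produce 7, 8, 9, the first two draft tokens are accepted and 9 replaces the rejected one, yielding three tokens from a single verification step. Because a draft costs only a scan of the existing tokens, this approach tends to pay off most when outputs repeat spans of the prompt.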