Medusa Decoding
This document shows how to build and run a model using Medusa decoding (GitHub, blog) in TensorRT-LLM on a single GPU or on a single node with multiple GPUs.
Overview
Unlike other models, Medusa decoding needs both a base model and Medusa heads.
The TensorRT-LLM Medusa decoding implementation can be found in tensorrt_llm/models/medusa/model.py, which adds MedusaHeads to a base model.
For more information about Medusa, see the speculative decoding documentation.
Support Matrix
- GPU Compute Capability >= 8.0 (Ampere or newer)
- FP16
- BF16
- PAGED_KV_CACHE
- Tensor Parallel
Usage
The TensorRT-LLM Medusa example code is located in examples/medusa. The convert_checkpoint.py script there converts the weights into a TensorRT-LLM checkpoint, from which trtllm-build builds the TensorRT engine(s) needed to run models with Medusa decoding support.
In our example, we use the model FasterDecoding/medusa-vicuna-7b-v1.3 from Hugging Face, which is a LLaMA-based model.
Build TensorRT engine(s)
Get the weights by downloading the base model vicuna-7b-v1.3 and the Medusa heads medusa-vicuna-7b-v1.3 from Hugging Face.
pip install -r requirements.txt
git lfs install
git clone https://huggingface.co/lmsys/vicuna-7b-v1.3
git clone https://huggingface.co/FasterDecoding/medusa-vicuna-7b-v1.3
We use the convert_checkpoint.py script to convert the model for Medusa decoding into the TensorRT-LLM checkpoint format.
We also pass --fixed_num_medusa_heads 4 because config.json of medusa-vicuna-7b-v1.3 sets medusa_num_heads to 2, while the checkpoint actually contains 4 heads.
Here is the example:
# Convert and Build Medusa decoding support for vicuna-7b-v1.3
python convert_checkpoint.py --model_dir ./vicuna-7b-v1.3 \
--medusa_model_dir medusa-vicuna-7b-v1.3 \
--output_dir ./tllm_checkpoint_1gpu_medusa \
--dtype float16 \
--fixed_num_medusa_heads 4
trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_medusa \
--output_dir ./tmp/medusa/7B/trt_engines/fp16/1-gpu/ \
--gemm_plugin float16 \
--speculative_decoding_mode medusa \
--max_batch_size 8
# Convert and Build Medusa decoding support for vicuna-13b-v1.3 with 4-way tensor parallelism.
# Assumes vicuna-13b-v1.3 and medusa-vicuna-13b-v1.3 were downloaded the same way as the 7B weights above.
python convert_checkpoint.py --model_dir ./vicuna-13b-v1.3 \
--medusa_model_dir medusa-vicuna-13b-v1.3 \
--output_dir ./tllm_checkpoint_4gpu_medusa \
--dtype float16 \
--fixed_num_medusa_heads 4 \
--tp_size 4 \
--workers 4
trtllm-build --checkpoint_dir ./tllm_checkpoint_4gpu_medusa \
--output_dir ./tmp/medusa/13B/trt_engines/fp16/4-gpu/ \
--gemm_plugin float16 \
--speculative_decoding_mode medusa \
--max_batch_size 8
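Because the head count recorded in config.json may not match the checkpoint, it can help to inspect the file before deciding whether --fixed_num_medusa_heads is needed. The following is a minimal sketch, not part of the example scripts: the helper names are hypothetical, and the real head count must be supplied by you (4 for medusa-vicuna-7b-v1.3, as stated above).

```python
import json
from pathlib import Path


def read_declared_medusa_heads(medusa_model_dir):
    """Return the medusa_num_heads value stored in config.json (None if the key is absent)."""
    config_path = Path(medusa_model_dir) / "config.json"
    with config_path.open(encoding="utf-8") as f:
        config = json.load(f)
    return config.get("medusa_num_heads")


def head_override_args(declared, actual):
    """Hypothetical helper: extra convert_checkpoint.py arguments when the counts disagree."""
    if declared == actual:
        return []
    return ["--fixed_num_medusa_heads", str(actual)]
```

For the checkpoint used here, `head_override_args(read_declared_medusa_heads("medusa-vicuna-7b-v1.3"), 4)` would be expected to yield the override, given that config.json declares 2 heads.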
Run
To run a TensorRT-LLM model with Medusa decoding support, we can use the ../run.py script with an additional argument, --medusa_choices.
--medusa_choices is of type list[list[int]]. The script also needs the engine built with Medusa decoding support, passed through --engine_dir.
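Since --medusa_choices is passed as one long string, a typo is easy to miss. The sketch below parses such a string and reports a few basic properties. It rests on assumptions this document does not state: that each inner list is a path in the Medusa candidate tree (one entry per successive head, so the longest path should not exceed the number of heads), and that every path's parent prefix should also be listed. The function names are hypothetical and are not part of run.py.

```python
import ast


def parse_medusa_choices(text):
    """Parse a --medusa_choices string into list[list[int]], rejecting malformed input."""
    choices = ast.literal_eval(text)
    if not isinstance(choices, list) or not choices:
        raise ValueError("expected a non-empty list of lists")
    for path in choices:
        if not isinstance(path, list) or not path:
            raise ValueError(f"each choice must be a non-empty list, got {path!r}")
        if not all(isinstance(i, int) and not isinstance(i, bool) and i >= 0 for i in path):
            raise ValueError(f"entries must be non-negative ints, got {path!r}")
    return choices


def summarize_choices(choices):
    """Return (number of paths, longest path length, paths whose parent prefix is missing)."""
    present = {tuple(p) for p in choices}
    # Assumption: a path such as [0, 1] only makes sense if its parent [0] is listed too.
    orphans = [p for p in choices if len(p) > 1 and tuple(p[:-1]) not in present]
    return len(choices), max(len(p) for p in choices), orphans


example = "[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0]]"
count, depth, orphans = summarize_choices(parse_medusa_choices(example))
print(count, depth, orphans)  # 6 3 []
```

Under the same assumptions, the longest path in the list used in this document has length 4, which matches the 4 heads passed to --fixed_num_medusa_heads.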
Medusa decoding is supported by the Python runtime and by the C++ runtime with inflight batching. The C++ runtime is recommended for performance.
To use the Python runtime, pass the --use_py_session flag to run.py.
Note: Medusa decoding currently supports only greedy decoding with temperature=1.0, so --temperature 1.0 is also required.
# Medusa decoding using vicuna-7b-v1.3 model with 1 GPU
python ../run.py --engine_dir ./tmp/medusa/7B/trt_engines/fp16/1-gpu/ \
--tokenizer_dir ./vicuna-7b-v1.3/ \
--max_output_len=100 \
--medusa_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]" \
--temperature 1.0 \
--input_text "Once upon"
# Medusa decoding using vicuna-13b-v1.3 with 4 GPUs
mpirun -np 4 --allow-run-as-root --oversubscribe \
python ../run.py --engine_dir ./tmp/medusa/13B/trt_engines/fp16/4-gpu/ \
--tokenizer_dir ./vicuna-13b-v1.3/ \
--max_output_len=100 \
--medusa_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]" \
--temperature 1.0 \
--input_text "Once upon"
If the run succeeds, you will see output like this:
......
Input [Text 0]: "<s> Once upon"
Output [Text 0 Beam 0]: "a time, there was a young girl who loved to read. She would spend hours in the library, devouring books of all genres. She had a special love for fairy tales, and would often dream of living in a magical world where she could meet princes and princesses, and have adventures with talking animals.
One day, while she was reading a book, she came across a passage that spoke to her heart. It said, "You are the author of"
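The run.py commands above repeat the same long --medusa_choices literal, which is awkward to quote by hand. As a rough sketch, the argument list can be assembled from a Python list instead. It uses only flags that appear in this document. The helper name and its parameters are hypothetical, and the choices list in the usage lines is shortened for illustration.

```python
import json
import shlex


def build_run_command(engine_dir, tokenizer_dir, choices, input_text,
                      max_output_len=100, use_py_session=False):
    """Assemble the argv for ../run.py using the flags shown in this README."""
    argv = [
        "python", "../run.py",
        "--engine_dir", engine_dir,
        "--tokenizer_dir", tokenizer_dir,
        f"--max_output_len={max_output_len}",
        # json.dumps renders the nested list in the same bracketed form used above.
        f"--medusa_choices={json.dumps(choices)}",
        "--temperature", "1.0",
        "--input_text", input_text,
    ]
    if use_py_session:
        argv.append("--use_py_session")
    return argv


cmd = build_run_command(
    "./tmp/medusa/7B/trt_engines/fp16/1-gpu/",
    "./vicuna-7b-v1.3/",
    [[0], [0, 0], [1], [0, 1]],  # shortened list, for illustration only
    "Once upon",
)
print(shlex.join(cmd))  # shell-quoted command line, ready to paste
```

Passing the list to `subprocess.run(cmd)` avoids shell quoting altogether. For the multi-GPU case, the mpirun prefix shown above would have to be prepended to the list.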
Summarization using Medusa decoding
# Medusa decoding using vicuna-7b-v1.3 model with 1 GPU
python ../summarize.py --engine_dir ./tmp/medusa/7B/trt_engines/fp16/1-gpu/ \
--hf_model_dir ./vicuna-7b-v1.3/ \
--tokenizer_dir ./vicuna-7b-v1.3/ \
--test_trt_llm \
--data_type fp16 \
--medusa_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]" \
--use_py_session \
--temperature 1.0 \
--batch_size 1
# Medusa decoding using vicuna-13b-v1.3 with 4 GPUs
mpirun -np 4 --allow-run-as-root --oversubscribe \
python ../summarize.py --engine_dir ./tmp/medusa/13B/trt_engines/fp16/4-gpu/ \
--hf_model_dir ./vicuna-13b-v1.3/ \
--tokenizer_dir ./vicuna-13b-v1.3/ \
--test_trt_llm \
--data_type fp16 \
--medusa_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]" \
--use_py_session \
--temperature 1.0 \
--batch_size 1