TensorRT-LLMs/docs/source/features/speculative-decoding.md
Guoming Zhang 7f3f658d5f [None][doc] Rename TensorRT-LLM to TensorRT LLM. (#7554)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
2025-09-09 12:16:03 +08:00


# Speculative Decoding
There are two flavors of speculative decoding currently supported in the PyTorch backend:
- The "one model" implementation -- a variant which inserts a drafter directly into the model code as a submodule.
- The "two model" implementation -- a variant which produces draft tokens in the `PyExecutor`. The draft tokens are attached to requests before they are passed
into the target model's `ModelEngine`.
In general, the one model implementation is faster. It achieves better performance in extremely low-latency
scenarios because it can launch the entire drafting loop as a single CUDA graph. The trade-off is flexibility: the one model implementation
does not support dynamic draft lengths. Additionally, only a subset of models and speculative decoding algorithms support the one model implementation.
The table below enumerates the algorithm/model combinations supported by the one model implementation.
| Speculative Decoding Algorithm | Model |
| ------------------------------ | ------------------------------ |
| EAGLE 3 | Llama 4 Maverick |
| MTP | Deepseek V3/R1 |
| EAGLE-style MTP | Deepseek V3/R1 |
The two model implementation supports the following speculative decoding algorithms:
| Speculative Decoding Algorithm | Model |
| --------------------------------------------- | --------------------------------------------- |
| EAGLE 3 | Llama 4 Maverick, Llama 3.1 8B, Llama 3.3 70B |
| Draft/target | All models |
| NGram | All models |
| User-provided | All models |
## Quick Start
For all speculation algorithms, when speculation is enabled, a single sequence of draft tokens with length `max_draft_len` is created for every request. There is currently no way to dynamically disable speculation, so speedups are only observable at low batch sizes.
### Draft/Target
Draft/target is the simplest form of speculative decoding. In this approach, an arbitrary draft model is used to produce draft tokens. The draft and target models must have been trained with the same tokenizer; otherwise, the acceptance rate will be extremely low and performance will regress.
```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import DraftTargetDecodingConfig

speculative_config = DraftTargetDecodingConfig(
    max_draft_len=3, speculative_model="/path/to/draft_model")
# The two model implementation requires the overlap scheduler to be disabled.
llm = LLM("/path/to/target_model", speculative_config=speculative_config, disable_overlap_scheduler=True)
```
### EAGLE 3
The EAGLE 3 algorithm is described in the paper [EAGLE-3: Scaling up Inference Acceleration of Large Language Models via Training-Time Test](https://arxiv.org/pdf/2503.01840).
TRT-LLM supports a modified version of the algorithm presented in the paper: tree structures for draft sequences are not supported. Instead, each request uses a single sequence of draft tokens with length `max_draft_len`.
The following draft model checkpoints can be used for EAGLE 3:
* Llama 3 variants: [use the checkpoints from the authors of the original EAGLE 3 paper](https://huggingface.co/yuhuili).
* Llama 4 Maverick: [use the checkpoint from the NVIDIA HuggingFace repository](https://huggingface.co/nvidia/Llama-4-Maverick-17B-128E-Eagle3).
```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig

# Set to True to use the faster one-model implementation for Llama 4.
eagle3_one_model = False
speculative_config = EagleDecodingConfig(
    max_draft_len=3, speculative_model="/path/to/draft_model", eagle3_one_model=eagle3_one_model)
# The overlap scheduler only needs to be disabled when eagle3_one_model is False.
llm = LLM("/path/to/target_model", speculative_config=speculative_config, disable_overlap_scheduler=not eagle3_one_model)
```
### NGram
The NGram method is an implementation of [this Prompt Lookup Decoding algorithm](https://github.com/apoorvumang/prompt-lookup-decoding).
When the NGram algorithm is used, TRT-LLM will maintain a map from token prefixes to candidate draft sequences. For example, the 3-gram ["The ", " future ", " is"] could map to the draft sequence [" bright", " because"]. The prefixes are token sequences that are extracted from the prompt and the tokens generated by the target model. The NGram pool and matching procedure can be tuned with the following options:
* `max_draft_len`: Maximum draft candidate length.
* `max_matching_ngram_size`: Maximum prompt suffix length to match with keys in the pool.
* `is_public_pool`: If true, a single ngram pool is shared for all requests. Otherwise, each request has its own ngram pool.
* `is_keep_all`: If true, draft candidates will be retained in the pool forever. Otherwise, only the largest draft candidate is retained.
* `is_use_oldest`: If true, the oldest draft candidate is always proposed for a given match. Otherwise, the newest draft candidate is used. Only applicable if `is_keep_all == True` because `is_keep_all == False` means we'll only ever have a single value for each key.
```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import NGramDecodingConfig

speculative_config = NGramDecodingConfig(
    max_draft_len=3, max_matching_ngram_size=4, is_public_pool=True)
llm = LLM("/path/to/target_model", speculative_config=speculative_config, disable_overlap_scheduler=True)
```
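The pool lookup described above can be pictured with a toy, pure-Python sketch. This is not the TensorRT LLM implementation: `build_pool` and `propose_draft` are hypothetical helpers, tokens are plain strings, and the pool keeps only the longest candidate per key (loosely mirroring `is_keep_all=False`).

```python
def build_pool(tokens, max_matching_ngram_size, max_draft_len):
    """Map each n-gram (as a tuple) to the tokens that followed it."""
    pool = {}
    for n in range(1, max_matching_ngram_size + 1):
        for start in range(len(tokens) - n):
            key = tuple(tokens[start:start + n])
            candidate = tokens[start + n:start + n + max_draft_len]
            # Keep only the longest candidate seen so far for this key.
            if len(candidate) > len(pool.get(key, [])):
                pool[key] = candidate
    return pool


def propose_draft(pool, tokens, max_matching_ngram_size):
    """Return the candidate for the longest suffix of `tokens` found in the pool."""
    for n in range(min(max_matching_ngram_size, len(tokens)), 0, -1):
        draft = pool.get(tuple(tokens[-n:]))
        if draft:
            return draft
    return []


history = ["the", "future", "is", "bright", "because", "the", "future", "is"]
pool = build_pool(history, max_matching_ngram_size=3, max_draft_len=2)
draft = propose_draft(pool, history, max_matching_ngram_size=3)
print(draft)
```

Here the suffix `("the", "future", "is")` matches a key recorded earlier in the sequence, so the tokens that followed it the first time are proposed as the draft.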
### MTP
MTP is currently only supported by Deepseek. MTP can be tuned with the following configuration options:
* `max_draft_len`: Maximum draft candidate length.
* `num_nextn_predict_layers`: Number of MTP modules to use. Currently must match `max_draft_len`.
* `use_relaxed_acceptance_for_thinking`: If true, use relaxed decoding for reasoning models in the thinking phase. In this mode, speculation requirements are relaxed for the thinking phase - a draft token may be accepted if it appears in a candidate set constructed with `relaxed_topk` and `relaxed_delta`.
* `relaxed_topk`: The top K tokens are sampled from the target model's logits to create the initial candidate set for relaxed decoding.
* `relaxed_delta`: Used to further filter the top K candidate set for relaxed decoding. We remove tokens `t` for which `log(P(top 1 token)) - log(P(t)) > relaxed_delta`.
```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import MTPDecodingConfig

# num_nextn_predict_layers must currently match max_draft_len.
speculative_config = MTPDecodingConfig(
    max_draft_len=3, num_nextn_predict_layers=3)
llm = LLM("/path/to/deepseek_model", speculative_config=speculative_config)
```
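The candidate-set construction for relaxed acceptance can be sketched in plain Python. This is only an illustration of the `relaxed_topk` and `relaxed_delta` descriptions above: `relaxed_candidates` is a hypothetical helper, and it works on a list of log-probabilities rather than on the target model's logits tensor.

```python
import math


def relaxed_candidates(log_probs, relaxed_topk, relaxed_delta):
    """Return the token ids that a draft token may match under relaxed acceptance."""
    # Initial candidate set: the top-K token ids by log-probability.
    top_ids = sorted(range(len(log_probs)),
                     key=lambda t: log_probs[t],
                     reverse=True)[:relaxed_topk]
    best = log_probs[top_ids[0]]
    # Drop tokens whose log-probability trails the top-1 token by more than delta.
    return [t for t in top_ids if best - log_probs[t] <= relaxed_delta]


log_probs = [math.log(p) for p in (0.50, 0.30, 0.15, 0.05)]
candidates = relaxed_candidates(log_probs, relaxed_topk=3, relaxed_delta=0.6)
print(candidates)
```

With these numbers, token 1 stays in the set (its gap to the top token is about 0.51) while token 2 is filtered out (gap of about 1.20), so a draft token equal to 0 or 1 would be accepted.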
### User-provided drafting
A completely user-defined drafting method can be supplied with a `UserProvidedDecodingConfig` that includes:
* `max_draft_len`: Maximum draft candidate length.
* `drafter`: An object of type `Drafter` that implements the `prepare_draft_tokens` method (see item 7 of the [Developer Guide](speculative-decoding.md#developer-guide)).
* `resource_manager`: An optional `ResourceManager` object (see item 4 of the [Developer Guide](speculative-decoding.md#developer-guide)).
```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import UserProvidedDecodingConfig

# MyDrafter is a placeholder for your own Drafter subclass.
speculative_config = UserProvidedDecodingConfig(
    max_draft_len=3, drafter=MyDrafter())
llm = LLM("/path/to/target_model", speculative_config=speculative_config)
```
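As a rough idea of what a custom drafter looks like, the sketch below uses local stand-ins so that it runs without TensorRT LLM installed. `Drafter`, `FakeRequest`, and `FakeScheduledRequests` here are hypothetical substitutes for the real classes, and the `generation_requests` attribute is an assumption about the shape of `ScheduledRequests`; only the `prepare_draft_tokens` method name and the `py_draft_tokens` field come from this document.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


class Drafter(ABC):
    """Stand-in for the real Drafter base class."""

    @abstractmethod
    def prepare_draft_tokens(self, scheduled_requests) -> None:
        ...


@dataclass
class FakeRequest:
    """Hypothetical stand-in for a request object."""
    tokens: list
    py_draft_tokens: list = field(default_factory=list)


@dataclass
class FakeScheduledRequests:
    """Hypothetical stand-in for ScheduledRequests."""
    generation_requests: list


class RepeatLastTokenDrafter(Drafter):
    """Toy drafter that guesses the last token will simply repeat."""

    def __init__(self, max_draft_len):
        self.max_draft_len = max_draft_len

    def prepare_draft_tokens(self, scheduled_requests) -> None:
        # Attach the drafts to each request; nothing is returned.
        for req in scheduled_requests.generation_requests:
            req.py_draft_tokens = [req.tokens[-1]] * self.max_draft_len


batch = FakeScheduledRequests([FakeRequest(tokens=[5, 7, 9])])
RepeatLastTokenDrafter(max_draft_len=3).prepare_draft_tokens(batch)
print(batch.generation_requests[0].py_draft_tokens)
```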
## Usage with `trtllm-bench` and `trtllm-serve`
Speculative decoding options must be specified via `--extra_llm_api_options config.yaml` for both `trtllm-bench` and `trtllm-serve`. All speculative decoding options can be specified in this YAML file. An additional `decoding_type` option is used to specify the type of speculation to use. The available options are:
* `MTP`
* `Eagle` (for EAGLE 3)
* `NGram`
* `DraftTarget`
The rest of the argument names/valid values are the same as in their corresponding configuration class described in the Quick Start section. For example, a YAML configuration could look like this:
```yaml
disable_overlap_scheduler: true
speculative_config:
  decoding_type: Eagle
  max_draft_len: 4
  speculative_model: /path/to/draft/model
```
## Developer Guide
This section describes the components of a speculative decoding algorithm. All of the interfaces are defined in [`_torch/speculative/interface.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/speculative/interface.py).
1. `SpeculativeDecodingMode`: a simple `IntEnum` with one member for each supported algorithm. It does have a few
nontrivial methods, however.
- `needs_kv_cache_rewind`. See "KV Cache Rewind" below. In general, this is true for all two model speculative
decoding algorithms.
- `extend_ctx`: If true, requests with `py_draft_tokens` attached to them are dispatched
to the *prefill* version of the attention kernels. This usually needs to be true. The exception is when you're on
Blackwell using the TensorRT LLM attention backend. In that case, use the generation kernels for better performance.
These optimized kernels have one limitation: all draft lengths must be the same (or padding must be used).
> *These may be refactored in the future to reduce the difficulty of adding a new speculative
decoding algorithm. `extend_ctx` in particular is problematic. Ideally, we would move all of the kernel dispatching logic
to a lower level of abstraction.*
2. `SpecMetadata`: Defines all metadata that should be passed to the model during the forward pass to facilitate speculative decoding.
Each speculative decoding algorithm defines a subclass of `SpecMetadata`. Similar to `AttentionMetadata`, each `CUDAGraphRunner` owns
its own `SpecMetadata`, and CUDA-graph compatible `SpecMetadata` objects may be created by invoking `create_cuda_graph_metadata(batch_size)`.
`SpecMetadata` has many fields. Many of them are exclusively used by the one model implementation. For the two model implementation, the
main purpose of `SpecMetadata` is to facilitate the capture of hidden states. In EAGLE 3, we need to capture hidden states from the
target model to use as draft model inputs. The `SpecMetadata` stores a list of layers to capture and the model calls
`maybe_capture_hidden_states(layer_id, hidden_states, residual)` during its forward pass. If the layer ID is in the list of layers to capture,
the hidden states are saved. For CUDA graph compatibility, these may be saved in pre-allocated buffers.
`SpecMetadata` is derived from a `SpecConfig` object in `_torch/speculative/utils.py`. There are a few other optional components created in
this file too:
4. `ResourceManager`: Create a custom resource manager to prepare and free resources before and after target forward passes; see
the section on `ResourceManager` in `arch.md`. This is used by the n-gram method to manage its pool. The one model implementation also uses
`ResourceManager`s to manage hidden states.
5. `Sampler`: Each speculative decoding algorithm can optionally create its own sampler. This is mostly used by the one model implementation.
The default `TorchSampler` is used as a fallback if no custom sampler is provided. EAGLE 3 two model also has a simple custom decoder to handle
differences in the draft/target model vocab sizes.
6. `Worker`: This is exclusive to the one-model implementation. The `Worker` is the object that gets injected into the target model as a
submodule.
7. `Drafter`: All of the logic required to actually produce draft tokens should be implemented in a `Drafter` subclass. There is a single
abstract method, `prepare_draft_tokens`. It takes a set of requests (a `ScheduledRequests` object) and returns nothing. The [`PyExecutor`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/py_executor.py#L162) expects
draft tokens to be attached to the `py_draft_tokens` field of each request for which speculation is to be performed.
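
The hidden-state capture hook described under `SpecMetadata` above can be pictured with a small stand-in class. `ToySpecMetadata` is hypothetical and uses Python lists instead of pre-allocated GPU buffers; how the residual is combined with the hidden states is an assumption made for this sketch, not something this document specifies.

```python
class ToySpecMetadata:
    """Hypothetical stand-in; not the real SpecMetadata class."""

    def __init__(self, layers_to_capture):
        self.layers_to_capture = set(layers_to_capture)
        self.captured = {}

    def maybe_capture_hidden_states(self, layer_id, hidden_states, residual=None):
        # Layers that are not in the capture list are ignored.
        if layer_id not in self.layers_to_capture:
            return
        # Assumption: a residual, when given, is added before saving.
        if residual is not None:
            hidden_states = [h + r for h, r in zip(hidden_states, residual)]
        self.captured[layer_id] = list(hidden_states)


metadata = ToySpecMetadata(layers_to_capture=[1, 3])
for layer_id in range(4):
    # A real model would make this call inside each decoder layer.
    metadata.maybe_capture_hidden_states(
        layer_id, [float(layer_id)] * 2, residual=[0.5, 0.5])
print(sorted(metadata.captured))
```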
## Two Model Speculative Decoding Architecture
Two model speculation implementations do not support the overlap scheduler; it is disabled automatically.
In this approach, two new steps are added to the `PyExecutor`'s `_executor_loop`:
* `_prepare_draft_requests`
* `_prepare_draft_tokens`
### `_prepare_draft_requests`
This stage occurs for all speculative decoding algorithms before scheduling. The purpose
of this stage is to make the KV cache and scheduler aware of the fact that speculative decoding
will occur. Draft tokens take up extra KV cache pages and count towards the executor's
`max_num_tokens` limit. Thus, we need a way to tell the scheduler that drafting will occur
**before we do the scheduling**.
To achieve this, we simply attach the maximum number of draft tokens to each request. The
scheduler and KV cache manager will automatically account for tokens attached to the
`py_draft_tokens` attribute.
```python
# Attach placeholder draft tokens so the scheduler reserves room for them.
for req in self.active_requests:
    req.py_draft_tokens = [0] * max_draft_len
```
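To see why the placeholders matter, here is a toy token-budget scheduler. It is a simplified sketch: `ToyRequest` and `schedule` are hypothetical, each generation request is assumed to cost one token plus its attached draft tokens, and the real scheduler and KV cache manager are considerably more involved.

```python
from dataclasses import dataclass, field


@dataclass
class ToyRequest:
    request_id: int
    py_draft_tokens: list = field(default_factory=list)


def schedule(requests, max_num_tokens):
    """Greedily admit requests until the token budget is exhausted."""
    scheduled, used = [], 0
    for req in requests:
        # One token for the regular decode step plus one per draft token.
        cost = 1 + len(req.py_draft_tokens)
        if used + cost > max_num_tokens:
            break
        scheduled.append(req)
        used += cost
    return scheduled


max_draft_len = 3
requests = [ToyRequest(i) for i in range(4)]
without_drafts = schedule(requests, max_num_tokens=8)

for req in requests:
    req.py_draft_tokens = [0] * max_draft_len
with_drafts = schedule(requests, max_num_tokens=8)
print(len(without_drafts), len(with_drafts))
```

With the same budget of 8 tokens, all four requests fit before the placeholders are attached, but only two fit afterwards.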
### `_prepare_draft_tokens`
This stage occurs after scheduling and KV cache allocation. The purpose of this stage
is to attach draft tokens to the `py_draft_tokens` attribute. This occurs by calling `self.drafter.prepare_draft_tokens`;
each speculative decoding algorithm should have a concrete instance of the `Drafter` class associated with it that defines
the drafting logic.
In addition to producing all "real" draft tokens, `_prepare_draft_tokens` currently must also pad
all `py_draft_tokens` to the maximum draft length. This is a CUDA graph limitation - the target
model captures its CUDA graphs using the maximum number of draft tokens on each request.
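
The padding step can be sketched as follows. `pad_draft_tokens` is a hypothetical helper, and using `0` as the padding token is an assumption borrowed from the placeholder value shown in `_prepare_draft_requests`.

```python
def pad_draft_tokens(draft_tokens, max_draft_len, pad_token_id=0):
    """Truncate or right-pad a draft sequence to exactly max_draft_len tokens."""
    kept = draft_tokens[:max_draft_len]
    return kept + [pad_token_id] * (max_draft_len - len(kept))


padded_short = pad_draft_tokens([11, 12], max_draft_len=4)
padded_empty = pad_draft_tokens([], max_draft_len=4)
padded_long = pad_draft_tokens([1, 2, 3, 4, 5], max_draft_len=4)
print(padded_short, padded_empty, padded_long)
```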
### Verification and Sampling
Once the draft tokens are obtained, the target model runs a forward pass through the usual flow.
Everything is the same, except that the logits for all the draft tokens are returned and passed
to the sampler.
Currently, only greedy sampling is supported for speculative decoding. A draft token is accepted if it
matches the previously decoded token exactly. For example, suppose there is a generation request
`[t, d1, d2, d3]`, where `d1`, `d2`, and `d3` are draft tokens. Suppose the token after `t` is `d1`
(determined with the `argmax` of the logits). `d1` is then accepted. If the token after `d1` is `d2`,
then `d2` can be accepted, and so on until a draft token is rejected.
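
The acceptance rule can be written as a short loop. This is a minimal sketch, not the sampler's actual code: `accept_draft_tokens` is a hypothetical helper, and `target_tokens` is assumed to hold the `argmax` token at each of the positions `t, d1, d2, d3`. Keeping the target's own token at the first mismatch follows common speculative decoding practice and is an assumption here.

```python
def accept_draft_tokens(draft_tokens, target_tokens):
    """Greedy verification of draft tokens against the target model's argmax tokens.

    target_tokens[i] is the target's choice for the position that
    draft_tokens[i] occupies, so it has len(draft_tokens) + 1 entries.
    """
    accepted = []
    for draft, target in zip(draft_tokens, target_tokens):
        if draft != target:
            break
        accepted.append(draft)
    # Assumption: the target's token after the last accepted draft is kept too.
    new_tokens = accepted + [target_tokens[len(accepted)]]
    return accepted, new_tokens


accepted, new_tokens = accept_draft_tokens(
    draft_tokens=[10, 11, 12], target_tokens=[10, 11, 99, 13])
print(accepted, new_tokens)
```

In this example `d1` and `d2` match the target's choices, `d3` does not, so two draft tokens are accepted and the target's token `99` replaces the rejected one.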
### KV Cache Rewind
KV cache space allocated to rejected tokens is freed before the next iteration. This is achieved by setting
the `request.py_rewind_len` attribute to `num_draft_tokens_allocated - num_accepted_tokens`. The pages are
freed as part of the `resource_manager.free_resources` routine.
The purpose of KV cache rewind is to avoid complicated page reuse logic in the KV cache manager's `prepare_resources`
function. In practice, this is very cheap since the blocks are just marked as available; no memory is actually freed.
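
The bookkeeping can be illustrated with a toy request object. `ToyRequest`, `set_rewind_len`, and this `free_resources` function are hypothetical stand-ins that track a token count rather than real KV cache pages; only the `py_rewind_len` formula comes from the text above.

```python
from dataclasses import dataclass


@dataclass
class ToyRequest:
    num_cached_tokens: int  # tokens currently holding KV cache space
    py_rewind_len: int = 0


def set_rewind_len(request, num_draft_tokens_allocated, num_accepted_tokens):
    # Space reserved for rejected draft tokens must be given back.
    request.py_rewind_len = num_draft_tokens_allocated - num_accepted_tokens


def free_resources(request):
    # Mark the rejected tokens' space as available; nothing is deallocated.
    request.num_cached_tokens -= request.py_rewind_len
    request.py_rewind_len = 0


req = ToyRequest(num_cached_tokens=20)  # includes 4 draft-token slots
set_rewind_len(req, num_draft_tokens_allocated=4, num_accepted_tokens=1)
rewind = req.py_rewind_len
free_resources(req)
print(rewind, req.num_cached_tokens)
```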