# Sampling
The PyTorch backend supports a wide variety of sampling features, listed below:
| Forward Pass | Sampling Strategies | Sampling Features |
|--------------------|----------------------------------|--------------------------------|
| No drafting | Greedy | Guided Decoding |
| Draft target model | TopP | Pluggable Logits Post-Processor |
| Eagle 3 | TopK | Temperature |
| Ngram | TopK + TopP | MinP |
| | Beam Search | Embedding / Logits Bias |
| | Best of / n (composable) | Stop criteria |
| | Rejection sampling (composable) | Return Logits |
| | | Return LogProbs |
| | | TopK LogProbs |
## General usage
There are two sampling backends available:
* Torch Sampler
* TRTLLM Sampler
The Torch Sampler currently supports a superset of the TRTLLM Sampler's features and is intended as the long-term solution. You can specify which sampler to use explicitly with:
```python
from tensorrt_llm import LLM

# Chooses TorchSampler explicitly
llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
          sampler_type="TorchSampler")

# Chooses TRTLLMSampler explicitly
llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
          sampler_type="TRTLLMSampler")
```
By default, the sampling backend is `auto`, which selects:
* TRTLLM Sampler when using Beam Search.
* Torch Sampler otherwise.
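
For reference, the `auto` rule can be summarized in plain Python. This is only an illustration of the documented behavior; `resolve_sampler` is a made-up name, not a TensorRT-LLM function:

```python
def resolve_sampler(sampler_type: str = "auto", use_beam_search: bool = False) -> str:
    """Illustrative helper (not part of the TensorRT-LLM API): mirrors the
    documented 'auto' rule for picking a sampling backend."""
    if sampler_type != "auto":
        return sampler_type  # an explicit choice always wins
    return "TRTLLMSampler" if use_beam_search else "TorchSampler"

assert resolve_sampler() == "TorchSampler"
assert resolve_sampler(use_beam_search=True) == "TRTLLMSampler"
```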
The following example demonstrates basic usage of sampling parameters. It prepares two identical prompts which will give different results due to the sampling parameters chosen:
```python
from tensorrt_llm import LLM, SamplingParams
llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8')
sampling_params = SamplingParams(
    temperature=1.0,
    top_k=8,
    top_p=0.5,
)
llm.generate(["Hello, my name is",
              "Hello, my name is"], sampling_params)
```
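
To see the effect of the sampling parameters, the returned results can be printed. A short sketch continuing the example above, assuming the LLM API's usual `RequestOutput`-style objects that expose `prompt` and `outputs[0].text`:

```python
# Continuing the example above: capture the results instead of discarding them.
outputs = llm.generate(["Hello, my name is",
                        "Hello, my name is"], sampling_params)

# With temperature/top-k/top-p sampling, the two completions will typically
# differ even though the prompts are identical.
for output in outputs:
    print(f"{output.prompt!r} -> {output.outputs[0].text!r}")
```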
It is also possible to specify different sampling parameters on a per-prompt basis:
```python
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8')

sampling_params_0 = SamplingParams(
    temperature=1.0,
    top_k=8,
    top_p=0.5,
)
sampling_params_1 = SamplingParams(
    top_k=4,
)

llm.generate(["Hello, my name is",
              "Hello, my name is"],
             [sampling_params_0,
              sampling_params_1])
```
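
Several of the sampling features listed in the table above (for example MinP, stop criteria and log probabilities) are also configured through `SamplingParams`. The sketch below is only illustrative: the field names `min_p`, `stop` and `logprobs` are assumptions based on common LLM-API naming and should be checked against the `SamplingParams` definition in `tensorrt_llm/sampling_params.py`:

```python
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8')

# Field names below are assumptions for illustration; verify them against
# the SamplingParams definition in your TensorRT-LLM version.
sampling_params = SamplingParams(
    max_tokens=64,
    temperature=0.8,
    min_p=0.05,     # MinP filtering
    stop=["\n\n"],  # stop criterion: end generation at a blank line
    logprobs=1,     # return log probabilities for the sampled tokens
)
llm.generate(["Hello, my name is"], sampling_params)
```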
### LLM API sampling behavior when using Torch Sampler
To enable beam search, you must:
1. Enable the `use_beam_search` option in the `SamplingParams` object
2. Set the `max_beam_width` parameter in the `LLM` class to match the `best_of` parameter in `SamplingParams`
3. Disable overlap scheduling using the `disable_overlap_scheduler` parameter of the `LLM` class
Parameter Configuration:
- `best_of`: Controls the number of beams processed during generation (beam width)
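
For illustration, `best_of` can also be combined with `n` so that more beams are explored than are returned. The sketch below assumes the usual semantics where `n` selects how many of the `best_of` beams come back:

```python
from tensorrt_llm import SamplingParams

# Explore 4 beams during generation but return only the best 2.
# best_of must still match LLM.max_beam_width, as described above.
sampling_params = SamplingParams(
    use_beam_search=True,
    best_of=4,
    n=2,
)
```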
The following example demonstrates beam search with a beam width of 4:
```python
from tensorrt_llm import LLM, SamplingParams
llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
          max_beam_width=4,  # must equal SamplingParams.best_of
          disable_overlap_scheduler=True,
)
sampling_params = SamplingParams(
    best_of=4,  # must equal LLM.max_beam_width
    use_beam_search=True,