diff --git a/docs/source/features/sampling.md b/docs/source/features/sampling.md index 5d6db8626b..43c8ed4d43 100644 --- a/docs/source/features/sampling.md +++ b/docs/source/features/sampling.md @@ -1,19 +1,50 @@ # Sampling -The PyTorch backend supports most of the sampling features that are supported on the C++ backend, such as temperature, top-k and top-p sampling, beam search, stop words, bad words, penalty, context and generation logits, log probability and logits processors + +The Pytorch backend supports a wide variety of features, listed below: + +| Forward Pass | Sampling Strategies | Sampling Features | +|--------------------|----------------------------------|--------------------------------| +| No drafting | Greedy | Guided Decoding | +| Draft target model | TopP | Plugging Logits Post-Processor | +| Eagle 3 | TopK | Temperature | +| Ngram | TopK + TopP | MinP | +| | Beam Search | Embedding / Logits Bias | +| | Best of / n (composable) | Stop criteria | +| | Rejection sampling (composable) | Return Logits | +| | | Return LogProbs | +| | | TopK LogProbs | ## General usage -To use the feature: +There are two sampling backends available. -1. Enable the `enable_trtllm_sampler` option in the `LLM` class -2. Pass a [`SamplingParams`](source:tensorrt_llm/sampling_params.py#L125) object with the desired options to the `generate()` function +* Torch Sampler +* TRTLLM Sampler -The following example prepares two identical prompts which will give different results due to the sampling parameters chosen: +Torch Sampler currently supports a superset of features of TRTLLM Sampler, and is intended as the long term solution. One can specify which sampler to use explicitly with: + +```python +from tensorrt_llm import LLM + +# Chooses TorchSampler explicitly +llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8', + sampler_type="TorchSampler") + +# Chooses TRTLLMSampler explicitly +llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8', + sampler_type="TRTLLMSampler") +``` + +By default, the sampling backend is chosen to be `auto`. This will use: + +* TRTLLM Sampler when using Beam Search. +* Torch Sampler otherwise. + +Here is an example to run a model with basic usage of sampling parameters. This example prepares two identical prompts which will give different results due to the sampling parameters chosen: ```python from tensorrt_llm import LLM, SamplingParams -llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8', - enable_trtllm_sampler=True) +llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8') sampling_params = SamplingParams( temperature=1.0, top_k=8, @@ -23,7 +54,24 @@ llm.generate(["Hello, my name is", "Hello, my name is"], sampling_params) ``` -Note: The `enable_trtllm_sampler` option is not currently supported when using speculative decoders, such as MTP or Eagle-3, so there is a smaller subset of sampling options available. +It is also possible to specify different sampling parameters on a per-prompt basis: + +```python +from tensorrt_llm import LLM, SamplingParams +llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8') +sampling_params_0 = SamplingParams( + temperature=1.0, + top_k=8, + top_p=0.5, + ) +sampling_params_1 = SamplingParams( + top_k=4, + ) +llm.generate(["Hello, my name is", + "Hello, my name is"], + [sampling_params_0, + sampling_params_1]) +``` ### LLM API sampling behavior when using Torch Sampler @@ -79,8 +127,6 @@ To enable beam search, you must: 1. Enable the `use_beam_search` option in the `SamplingParams` object 2. Set the `max_beam_width` parameter in the `LLM` class to match the `best_of` parameter in `SamplingParams` -3. Disable overlap scheduling using the `disable_overlap_scheduler` parameter of the `LLM` class -4. Disable the usage of CUDA Graphs by passing `None` to the `cuda_graph_config` parameter of the `LLM` class Parameter Configuration: - `best_of`: Controls the number of beams processed during generation (beam width) @@ -93,10 +139,8 @@ The following example demonstrates beam search with a beam width of 4, returning ```python from tensorrt_llm import LLM, SamplingParams llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8', - enable_trtllm_sampler=True, max_beam_width=4, # must equal SamplingParams.best_of - disable_overlap_scheduler=True, - cuda_graph_config=None) + ) sampling_params = SamplingParams( best_of=4, # must equal LLM.max_beam_width use_beam_search=True,