
# PyTorch Backend

> **Note:** This feature is currently experimental, and the related API is subject to change in future versions.

To enhance the usability of the system and improve developer efficiency, TensorRT-LLM introduces a new experimental backend based on PyTorch.

The PyTorch backend of TensorRT-LLM is available in version 0.17 and later. You can try it by importing `tensorrt_llm._torch`.

## Quick Start

Here is a simple example showing how to use the `tensorrt_llm._torch.LLM` API with a Llama model.

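A minimal sketch of such a script is shown below. The model name is an assumption; any supported Llama checkpoint from the Hugging Face hub, or a local path, can be substituted:

```python
from tensorrt_llm._torch import LLM


def main():
    # Assumed checkpoint; substitute any supported Llama model card or local path.
    llm = LLM(model='meta-llama/Llama-3.1-8B-Instruct')

    prompts = ["Hello, my name is"]
    outputs = llm.generate(prompts)

    for output in outputs:
        # Each result carries the prompt and the generated completions.
        print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")


if __name__ == '__main__':
    main()
```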

## Quantization

The PyTorch backend supports FP8 and NVFP4 quantization. You can pass quantized models from the Hugging Face model hub that were generated by TensorRT Model Optimizer:

```python
from tensorrt_llm._torch import LLM

llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8')
llm.generate("Hello, my name is")
```

Alternatively, you can quantize a model yourself with the following commands:

```bash
git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git
cd TensorRT-Model-Optimizer/examples/llm_ptq
scripts/huggingface_example.sh --model <huggingface_model_card> --quant fp8 --export_fmt hf
```
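
The exported checkpoint directory can then be passed to the `LLM` class in place of a model card name. A minimal sketch, assuming a hypothetical export location of `./llama-3.1-8b-fp8` (use the path actually produced by the script above):

```python
from tensorrt_llm._torch import LLM

# Hypothetical local path; point this at the directory exported by huggingface_example.sh.
llm = LLM(model='./llama-3.1-8b-fp8')
llm.generate("Hello, my name is")
```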

## Sampling

The PyTorch backend supports most of the sampling features available in the C++ backend, such as temperature, top-k and top-p sampling, stop words, bad words, penalties, context and generation logits, and log probabilities.

To use this feature, enable the `enable_trtllm_sampler` option in the `LLM` class and pass a `SamplingParams` object with the desired options. The following example prepares two identical prompts that will give different results due to the chosen sampling parameters:

```python
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM

llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
          enable_trtllm_sampler=True)
sampling_params = SamplingParams(
    temperature=1.0,
    top_k=8,
    top_p=0.5,
)
llm.generate(["Hello, my name is",
              "Hello, my name is"], sampling_params)
```

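As in the Quick Start example, `generate` returns one result per prompt, so the two sampled completions can be compared directly; a short sketch:

```python
outputs = llm.generate(["Hello, my name is",
                        "Hello, my name is"], sampling_params)

for output in outputs:
    # With top-k/top-p sampling enabled, the two completions will usually differ.
    print(output.outputs[0].text)
```
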
When using speculative decoders such as MTP or Eagle-3, the `enable_trtllm_sampler` option is not yet supported, so the set of available sampling options is more restricted.

## Developer Guide

## Key Components

## Known Issues

- The PyTorch workflow on SBSA is incompatible with bare metal environments like Ubuntu 24.04. Please use the PyTorch NGC Container for optimal support on SBSA platforms.