# PyTorch Backend

```{note}
This feature is currently experimental, and the related API is subject to change in future versions.
```

To enhance usability and improve developer efficiency, TensorRT-LLM provides a new experimental backend based on PyTorch.

The PyTorch backend of TensorRT-LLM is available in version 0.17 and later. You can try it by importing `tensorrt_llm._torch`.

## Quick Start

Here is a simple example that shows how to use the `tensorrt_llm._torch.LLM` API with a Llama model.

```{literalinclude} ../../examples/pytorch/quickstart.py
:language: python
:linenos:
```
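
For reference, here is a minimal sketch along the lines of the included example (this is not the literal contents of `quickstart.py`; the model id below is only an illustrative placeholder):

```python
from tensorrt_llm._torch import LLM

# Build an LLM instance from a Hugging Face model id or a local checkpoint path
# ("meta-llama/Llama-3.1-8B-Instruct" is a placeholder used for illustration).
llm = LLM(model='meta-llama/Llama-3.1-8B-Instruct')

# Run generation on a single prompt.
llm.generate("Hello, my name is")
```
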
## Quantization

The PyTorch backend supports FP8 and NVFP4 quantization. You can pass quantized models hosted on the HF model hub, which are generated by [TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer).

```python
from tensorrt_llm._torch import LLM

# Load a pre-quantized FP8 checkpoint from the HF model hub.
llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8')
llm.generate("Hello, my name is")
```

Alternatively, you can use the following commands to quantize a model yourself:

```bash
git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git
cd TensorRT-Model-Optimizer/examples/llm_ptq
scripts/huggingface_example.sh --model <huggingface_model_card> --quant fp8 --export_fmt hf
```
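
The exported checkpoint directory can then be passed to the `LLM` constructor in place of a HF model hub id. A minimal sketch, assuming the script exported the checkpoint to a local `./quantized-fp8-ckpt` directory (the actual output path depends on the script's arguments):

```python
from tensorrt_llm._torch import LLM

# "./quantized-fp8-ckpt" is a placeholder for the directory produced by the
# export step above; substitute the real output path.
llm = LLM(model='./quantized-fp8-ckpt')
llm.generate("Hello, my name is")
```
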
## Sampling

The PyTorch backend supports most of the sampling features available in the C++ backend, such as temperature, top-k and top-p sampling, stop words, bad words, penalties, context and generation logits, and log probs.

To use these features, enable the `enable_trtllm_sampler` option in the `LLM` class and pass a `SamplingParams` object with the desired options. The following example prepares two identical prompts that will give different results due to the sampling parameters chosen:

```python
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM

llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
          enable_trtllm_sampler=True)
sampling_params = SamplingParams(
    temperature=1.0,
    top_k=8,
    top_p=0.5,
)
llm.generate(["Hello, my name is",
              "Hello, my name is"], sampling_params)
```
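
`generate` returns one result per prompt. As a short sketch of reading the generated text, assuming the results expose it via `outputs[0].text` as in other LLM API examples:

```python
# Print the generated continuation for each of the two prompts.
for output in llm.generate(["Hello, my name is",
                            "Hello, my name is"], sampling_params):
    print(output.outputs[0].text)
```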

When using speculative decoders such as MTP or Eagle-3, the `enable_trtllm_sampler` option is not yet supported, so the set of available sampling options is more restricted.

## Developer Guide

- [Architecture Overview](./torch/arch_overview.md)
- [Adding a New Model](./torch/adding_new_model.md)
- [Examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/pytorch/README.md)

## Key Components

- [Attention](./torch/attention.md)
- [KV Cache Manager](./torch/kv_cache_manager.md)
- [Scheduler](./torch/scheduler.md)

## Known Issues

- The PyTorch workflow on SBSA is incompatible with bare metal environments like Ubuntu 24.04. Please use the [PyTorch NGC Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) for optimal support on SBSA platforms.