[TRTLLM-5990][doc] trtllm-serve doc improvement. (#5220)

Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
This commit is contained in:
Guoming Zhang 2025-08-05 13:04:01 +08:00 committed by GitHub
parent d53cc2374b
commit db51ab11a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 258 additions and 97 deletions

View File

@ -0,0 +1,9 @@
trtllm-serve
=======================

.. toctree::
   :maxdepth: 1

   trtllm-serve
   run-benchmark-with-trtllm-serve

View File

@ -0,0 +1,222 @@
# Run benchmarking with `trtllm-serve`
TensorRT-LLM provides an OpenAI-compatible API via the `trtllm-serve` command.
A complete reference for the API is available in the [OpenAI API Reference](https://platform.openai.com/docs/api-reference).
This step-by-step tutorial covers the following topics for benchmarking online serving with Llama 3.1 70B:
* Methodology Introduction
* Launch the OpenAI-compatible server with the NGC container
* Run the performance benchmark
* Using `extra_llm_api_options`
## Methodology Introduction
The overall performance benchmarking involves:
1. Launch the OpenAI-compatible service with `trtllm-serve`
2. Run the benchmark with [benchmark_serving.py](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/serve/scripts/benchmark_serving.py)
## Launch the NGC container
TensorRT-LLM distributes pre-built containers on the [NGC Catalog](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release/tags).
You can launch the container using the following command:
```bash
docker run --rm --ipc host -p 8000:8000 --gpus all -it nvcr.io/nvidia/tensorrt-llm/release
```
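If the model weights are stored on the host, you will likely also want to mount them into the container and publish the service port. Below is a minimal sketch, assuming the Llama 3.1 70B checkpoint lives under `/path/to/llama3.1_70B` on the host (a placeholder path):
```bash
# Mount the host checkpoint directory into the container so that
# trtllm-serve can load it from the same path used in start.sh below,
# and publish port 8000 for the OpenAI-compatible API.
docker run --rm --ipc host -p 8000:8000 --gpus all \
    -v /path/to/llama3.1_70B:/path/to/llama3.1_70B \
    -it nvcr.io/nvidia/tensorrt-llm/release
```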
## Start the trtllm-serve service
> [!WARNING]
> The commands and configurations presented in this document are for illustrative purposes only.
> They serve as examples and may not deliver the optimal performance for your specific use case.
> Users are encouraged to tune the parameters based on their hardware and workload.
For benchmarking purposes, first create a bash script using the following code and name it `start.sh`.
```bash
#! /bin/bash
model_path=/path/to/llama3.1_70B
extra_llm_api_file=/tmp/extra-llm-api-config.yml

cat << EOF > ${extra_llm_api_file}
enable_attention_dp: false
print_iter_log: true
cuda_graph_config:
  enable_padding: true
  max_batch_size: 1024
kv_cache_config:
  dtype: fp8
EOF

trtllm-serve ${model_path} \
    --max_batch_size 1024 \
    --max_num_tokens 2048 \
    --max_seq_len 1024 \
    --kv_cache_free_gpu_memory_fraction 0.9 \
    --tp_size 1 \
    --ep_size 1 \
    --trust_remote_code \
    --extra_llm_api_options ${extra_llm_api_file}
```
> [!NOTE]
> `trtllm-llmapi-launch` is a script that launches the LLM API code on
> Slurm-like systems and supports multi-node and multi-GPU setups,
> for example: `trtllm-llmapi-launch trtllm-serve ...`.
Run the `start.sh` script in the **background** with the following command:
```bash
bash -x start.sh &
```
Once the server is up, it prints log output similar to the following:
```bash
INFO: Started server process [80833]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://localhost:8000 (Press CTRL+C to quit)
```
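Before starting the benchmark, you can optionally confirm that the endpoint is responding. Below is a minimal check that assumes the OpenAI-compatible `/v1/models` route is served on port 8000; adjust the host and port if you changed them.
```bash
# Poll the OpenAI-compatible model listing until the server is ready,
# then print the served model entry.
until curl -sf http://localhost:8000/v1/models > /dev/null; do
    echo "waiting for trtllm-serve to become ready..."
    sleep 10
done
curl -s http://localhost:8000/v1/models
```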
## Run the benchmark
Similar to starting `trtllm-serve`, create a script that runs the benchmark using the following code and name it `bench.sh`.
```bash
concurrency_list="1 2 4 8 16 32 64 128 256"
multi_round=5
isl=1024
osl=1024
result_dir=/tmp/llama3.1_output
model_path=/path/to/llama3.1_70B

for concurrency in ${concurrency_list}; do
    num_prompts=$((concurrency * multi_round))
    python -m tensorrt_llm.serve.scripts.benchmark_serving \
        --model ${model_path} \
        --backend openai \
        --dataset-name "random" \
        --random-input-len ${isl} \
        --random-output-len ${osl} \
        --random-prefix-len 0 \
        --num-prompts ${num_prompts} \
        --max-concurrency ${concurrency} \
        --ignore-eos \
        --save-result \
        --result-dir "${result_dir}" \
        --result-filename "concurrency_${concurrency}.json" \
        --percentile-metrics "ttft,tpot,itl,e2el"
done
```
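Before sweeping the full concurrency list, it can be useful to send a short smoke-test run against the server. The sketch below simply reuses the flags from `bench.sh` with a placeholder model path.
```bash
# Quick sanity check: 8 prompts at concurrency 1 with short sequences.
python -m tensorrt_llm.serve.scripts.benchmark_serving \
    --model /path/to/llama3.1_70B \
    --backend openai \
    --dataset-name "random" \
    --random-input-len 128 \
    --random-output-len 128 \
    --random-prefix-len 0 \
    --num-prompts 8 \
    --max-concurrency 1 \
    --ignore-eos
```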
Then run the full benchmark sweep with the following command:
```bash
bash -x bench.sh &> output_bench.log
```
Below is some example TensorRT-LLM serving benchmark output. Your actual results may vary.
```
============ Serving Benchmark Result ============
Successful requests: 1
Benchmark duration (s): 1.64
Total input tokens: 1024
Total generated tokens: 1024
Request throughput (req/s): 0.61
Output token throughput (tok/s): 622.56
Total Token throughput (tok/s): 1245.12
User throughput (tok/s): 623.08
Mean Request AR: 0.9980
Median Request AR: 0.9980
---------------Time to First Token----------------
Mean TTFT (ms): 12.83
Median TTFT (ms): 12.83
P99 TTFT (ms): 12.83
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 1.59
Median TPOT (ms): 1.59
P99 TPOT (ms): 1.59
---------------Inter-token Latency----------------
Mean ITL (ms): 1.59
Median ITL (ms): 1.59
P99 ITL (ms): 1.77
----------------End-to-end Latency----------------
Mean E2EL (ms): 1643.44
Median E2EL (ms): 1643.44
P99 E2EL (ms): 1643.44
==================================================
```
### Key Metrics
* Median Time to First Token (TTFT)
  * The typical time elapsed from when a request is sent until the first output token is generated.
* Median Time Per Output Token (TPOT)
  * The typical time required to generate each token *after* the first one.
* Median Inter-Token Latency (ITL)
  * The typical time delay between the completion of one token and the completion of the next.
* Median End-to-End Latency (E2EL)
  * The typical total time from when a request is submitted until the final token of the response is received.
* Total Token Throughput
  * The combined rate at which the system processes both input (prompt) tokens and output (generated) tokens.
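Each concurrency level also writes a JSON result file into the directory passed via `--result-dir`. The sketch below lists and pretty-prints those files; the exact fields inside each file depend on the version of the benchmark script.
```bash
# Pretty-print the saved per-concurrency results for a quick comparison.
for f in /tmp/llama3.1_output/concurrency_*.json; do
    echo "=== ${f} ==="
    python3 -m json.tool "${f}"
done
```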
## About `extra_llm_api_options`
`trtllm-serve` provides the `extra_llm_api_options` knob to **override** the parameters specified on the `trtllm-serve` command line.
Generally, you create a YAML file that contains various performance switches, for example:
```yaml
cuda_graph_config:
  enable_padding: true
print_iter_log: true
kv_cache_config:
  dtype: fp8
enable_attention_dp: true
```
The following is a list of common performance switches.
#### `kv_cache_config`
&emsp;**Description**: A section for configuring the Key-Value (KV) cache.
&emsp;**Options**:
&emsp;&emsp;dtype: Sets the data type for the KV cache.
&emsp;&emsp;**Default**: auto (uses the data type specified in the model checkpoint).
#### `cuda_graph_config`
&emsp;**Description**: A section for configuring CUDA graphs to optimize performance.
&emsp;**Options**:
&emsp;&emsp;enable\_padding: If true, input batches are padded to the nearest cuda\_graph\_batch\_size. This can significantly improve performance.
&emsp;&emsp;**Default**: false
&emsp;&emsp;max\_batch\_size: Sets the maximum batch size for which a CUDA graph will be created.
&emsp;&emsp;**Default**: 0
&emsp;&emsp;**Recommendation**: Set this to the same value as the \--max\_batch\_size command-line option.
&emsp;&emsp;batch\_sizes: A specific list of batch sizes to create CUDA graphs for.
&emsp;&emsp;**Default**: None
#### `moe_config`
&emsp;**Description**: Configuration for Mixture-of-Experts (MoE) models.
&emsp;**Options**:
&emsp;&emsp;backend: The backend to use for MoE operations.
&emsp;&emsp;**Default**: CUTLASS
#### `attention_backend`
&emsp;**Description**: The backend to use for attention calculations.
&emsp;**Default**: TRTLLM
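Putting several of these switches together, an extended configuration file can be generated the same way `start.sh` does. The following is a sketch only; the backend values simply restate the defaults listed above, and `moe_config` only takes effect for Mixture-of-Experts models.
```bash
# Write an extended extra-LLM-API config combining the switches documented above.
cat << EOF > /tmp/extra-llm-api-config.yml
cuda_graph_config:
  enable_padding: true
  max_batch_size: 1024
kv_cache_config:
  dtype: fp8
moe_config:
  backend: CUTLASS    # only relevant for MoE models
attention_backend: TRTLLM
print_iter_log: true
EOF
```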
See the [TorchLlmArgs class](https://nvidia.github.io/TensorRT-LLM/llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs) for the full list of options that can be used in `extra_llm_api_options`.

View File

@ -175,26 +175,6 @@ TRT-LLM multimodal supports the following modalities and data types (depending o
]}
Benchmark
---------
You can use any benchmark client compatible with the OpenAI API to test the serving performance of ``trtllm-serve``; we recommend ``genai-perf``, and a benchmarking recipe follows.
First, install ``genai-perf`` with ``pip``:

.. code-block:: bash

   pip install genai-perf

Then, :ref:`start a server<Starting a Server>` with ``trtllm-serve`` and ``TinyLlama-1.1B-Chat-v1.0``.
Finally, test performance with the following command:

.. literalinclude:: ../../../examples/serve/genai_perf_client.sh
   :language: bash
   :linenos:

Refer to the `README <https://github.com/triton-inference-server/perf_analyzer/blob/main/genai-perf/README.md>`_ of ``genai-perf`` for more guidance.
Multi-node Serving with Slurm
-----------------------------
@ -278,3 +258,6 @@ Syntax
.. click:: tensorrt_llm.commands.serve:main
:prog: trtllm-serve
:nested: full
Besides the above examples, ``trtllm-serve`` is also used as an entry point for performance benchmarking.
Please refer to `Run benchmarking with trtllm-serve <https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/commands/trtllm-serve/trtllm-serve-bench.md>`_ for more details.

View File

@ -75,11 +75,11 @@ Welcome to TensorRT-LLM's Documentation!
.. toctree::
   :maxdepth: 2
   :caption: Command-Line Reference
   :hidden:
   :name: Command-Line Reference

   commands/trtllm-bench
   commands/trtllm-build
   commands/trtllm-serve
   commands/trtllm-serve/index
.. toctree::

View File

@ -19,8 +19,12 @@ Note: **This project will download and install additional third-party open sourc
The following examples can most easily be executed using the prebuilt [Docker release container available on NGC](https://registry.ngc.nvidia.com/orgs/nvstaging/teams/tensorrt-llm/containers/release) (see also [release.md](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docker/release.md) on GitHub). Be sure to run these commands as a user with appropriate permissions, preferably `root`, to streamline the setup process.
## LLM API
## Launch Docker on a node with NVIDIA GPUs
```bash
docker run --ipc host --gpus all -it nvcr.io/nvidia/tensorrt-llm/release
```
## Run offline inference with the LLM API
The LLM API is a Python API designed to facilitate setup and inference with TensorRT-LLM directly within Python. It enables model optimization by simply specifying a HuggingFace repository name or a model checkpoint. The LLM API streamlines the process by managing checkpoint conversion, engine building, engine loading, and model inference, all through a single Python object.
Here is a simple example to show how to use the LLM API with TinyLlama.
@ -34,17 +38,29 @@ You can also directly load TensorRT Model Optimizer's [quantized checkpoints on
To learn more about the LLM API, check out the [](llm-api/index) and [](examples/llm_api_examples).
(deploy-with-trtllm-serve)=
## Deploy with trtllm-serve
## Deploy online serving with trtllm-serve
You can use the `trtllm-serve` command to start an OpenAI-compatible server to interact with a model.
To start the server, you can run a command like the following example:
To start the server, you can run a command like the following example inside a Docker container:
```bash
trtllm-serve "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
```
After the server starts, you can access familiar OpenAI endpoints such as `v1/chat/completions`.
You can run inference such as the following example from another terminal:
> [!NOTE]
> If you are running `trtllm-serve` inside a Docker container, you have two options for sending API requests:
> 1. Expose port `8000` to access the server from outside the container, as shown in the example after this note.
> 2. Open a new terminal and use the following command to directly attach to the running container:
> ```bash
> docker exec -it <container_id> bash
> ```
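For the first option, the port can be published when the container is launched; a minimal sketch based on the NGC release image used above:
```bash
# Publish port 8000 so the OpenAI-compatible API is reachable from the host.
docker run --rm --ipc host -p 8000:8000 --gpus all -it nvcr.io/nvidia/tensorrt-llm/release
```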
After the server has started, you can access familiar OpenAI endpoints such as `v1/chat/completions`.
You can then run inference from a separate terminal, for example:
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
@ -88,7 +104,7 @@ _Example Output_
}
```
For detailed examples and command syntax, refer to the [trtllm-serve](commands/trtllm-serve.rst) section. If you are running `trtllm-server` inside a Docker container, you have two options for sending API requests:
For detailed examples and command syntax, refer to the [trtllm-serve](commands/trtllm-serve.rst) section.
1. Expose port `8000` to access the server from outside the container.
@ -98,81 +114,12 @@ For detailed examples and command syntax, refer to the [trtllm-serve](commands/t
docker exec -it <container_id> bash
```
## Model Definition API
### Prerequisites
- This quick start uses the Meta Llama 3.1 model. This model is subject to a particular [license](https://llama.meta.com/llama-downloads/). To download the model files, agree to the terms and [authenticate with Hugging Face](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct?clone=true).
- Complete the [installation](./installation/linux.md) steps.
- Pull the weights and tokenizer files for the chat-tuned variant of the Llama 3.1 8B model from the [Hugging Face Hub](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct).
```console
git clone https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct
```
(quick-start-guide-compile)=
### Compile the Model into a TensorRT Engine
Use the [Llama model definition](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/llama) from the `examples/models/core/llama` directory of the GitHub repository.
The model definition is a minimal example that shows some of the optimizations available in TensorRT-LLM.
```console
# From the root of the cloned repository, start the TensorRT-LLM container
make -C docker ngc-release_run LOCAL_USER=1 IMAGE_TAG=x.y.z
# Log in to huggingface-cli
# You can get your token from huggingface.co/settings/token
huggingface-cli login --token *****
# Convert the model into TensorRT-LLM checkpoint format
cd examples/models/core/llama
pip install -r requirements.txt
pip install --upgrade transformers  # Llama 3.1 requires transformers 4.43.0 or later.
python3 convert_checkpoint.py --model_dir Meta-Llama-3.1-8B-Instruct --output_dir llama-3.1-8b-ckpt
# Compile model
trtllm-build --checkpoint_dir llama-3.1-8b-ckpt \
--gemm_plugin float16 \
--output_dir ./llama-3.1-8b-engine
```
{{container_tag_admonition}}
When you create a model definition with the TensorRT-LLM API, you build a graph of operations from [NVIDIA TensorRT](https://developer.nvidia.com/tensorrt) primitives that form the layers of your neural network. These operations map to specific kernels: prewritten programs for the GPU.
In this example, we included the `gpt_attention` plugin, which implements a FlashAttention-like fused attention kernel, and the `gemm` plugin, which performs matrix multiplication with FP32 accumulation. We also specified FP16 as the desired precision for the full model, matching the default precision of the weights that you downloaded from Hugging Face. For more information about plugins and quantizations, refer to the [Llama example](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/llama) and {ref}`precision` section.
### Run the Model
Now that you have the model engine, run the engine and perform inference.
```console
python3 ../run.py --engine_dir ./llama-3.1-8b-engine --max_output_len 100 --tokenizer_dir Meta-Llama-3.1-8B-Instruct --input_text "How do I count to nine in French?"
```
### Deploy with Triton Inference Server
To create a production-ready deployment of your LLM, use the [Triton Inference Server backend for TensorRT-LLM](https://github.com/triton-inference-server/tensorrtllm_backend) to leverage the TensorRT-LLM C++ runtime for rapid inference execution and include optimizations like in-flight batching and paged KV caching. Triton Inference Server with the TensorRT-LLM backend is available as a [pre-built container through NVIDIA NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver/tags).
1. Clone the TensorRT-LLM backend repository:
```console
cd ..
git clone https://github.com/triton-inference-server/tensorrtllm_backend.git
cd tensorrtllm_backend
```
2. Refer to [End to end workflow to run llama 7b](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/docs/llama.md) in the TensorRT-LLM backend repository to deploy the model with Triton Inference Server.
## Next Steps
In this Quick Start Guide, you:
- Saw an example of the LLM API
- Learned about deploying a model with `trtllm-serve`
- Learned about the Model Definition API
For more examples, refer to: