diff --git a/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md b/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md index b50171ddf7..8f5c1dfec0 100644 --- a/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md +++ b/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md @@ -1,30 +1,29 @@ # Running a High Performance GPT-OSS-120B Inference Server with TensorRT-LLM -In the guide below, we will walk you through how to launch your own +NVIDIA has [announced](https://developer.nvidia.com/blog/delivering-1-5-m-tps-inference-on-nvidia-gb200-nvl72-nvidia-accelerates-openai-gpt-oss-models-from-cloud-to-edge/) day-0 support for OpenAI's new open-source model series, [gpt-oss](https://openai.com/index/introducing-gpt-oss/). In the guide below, we will walk you through how to launch your own high-performance TensorRT-LLM server for **gpt-oss-120b** for inference. -This guide covers both low-latency and max-throughput cases. -The typical use case for **low-latency**, is when we try to maximize the number of tokens per second per user with a limited concurrency (4, 8 or 16 users). - -For **maximum throughput**, the goal is to maximize the amount of tokens produced per GPU per second. The former is an indication of how fast a system can produce tokens, the latter measures how many tokens a "chip" can generate per unit of time. +**Low-latency** use cases aim to maximize the number of tokens per second per user (tps/user) with limited concurrency. +For **max-throughput**, the goal is to maximize the tokens produced per GPU per second (tps/gpu). While tps/user indicates user experience quality, tps/gpu measures the economic efficiency of the system. ## Prerequisites -- 1x NVIDIA B200/GB200/H200 GPU (8x NVIDIA B200/H200 GPUs or 4x GB200 GPUs in a single node recommended for higher performance) -- CUDA Toolkit 12.8 or later -- Docker with NVIDIA Container Toolkit installed +- 1x NVIDIA B200/GB200/H200 GPU (more GPUs could be used for lower latency and higher throughput) - Fast SSD storage for model weights - Access to the gpt-oss-120b model checkpoint -We have a forthcoming guide for getting great performance on H100, however this guide focuses on the above GPUs. +We have a forthcoming guide for achieving great performance on H100; however, this guide focuses on the GPUs listed above. +## Install TensorRT-LLM -## Launching the TensorRT-LLM docker container +In this section, we introduce several ways to install TensorRT-LLM. -The container image that you will use will be pulled from NVIDIA's NGC. This container is multi-platform and will run on both x64 and arm64 architectures: `nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev` +### NGC Docker Image of dev branch -Run the follow docker command to start the TensorRT-LLM container in interactive mode: +Day-0 support for gpt-oss is provided via the NGC container image `nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev`. This image was built on top of the pre-day-0 **dev branch**. This container is multi-platform and will run on both x64 and arm64 architectures. 
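+
+If you want to pre-fetch the image before launching it (optional, since `docker run` pulls a missing image automatically), a plain `docker pull` is enough:
+
+```bash
+docker pull nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev
+```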
+ +Run the following docker command to start the TensorRT-LLM container in interactive mode: ```bash docker run --rm --ipc=host -it \ @@ -33,117 +32,134 @@ docker run --rm --ipc=host -it \ --gpus all \ -p 8000:8000 \ -e TRTLLM_ENABLE_PDL=1 \ - -e TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL=True \ -v ~/.cache:/root/.cache:rw \ nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev \ /bin/bash ``` -This command: +Explanation of the command: - Automatically removes the container when stopped (`--rm`) - Allows container to interact with the host's IPC resources and shared memory for optimal performance (`--ipc=host`) - Runs the container in interactive mode (`-it`) - Sets up shared memory and stack limits for optimal performance -- Maps port 8000 from the container to your host -- enables PDL for low-latency perf optimization -- disables parallel weight loading +- Maps port 8000 from the container to the host +- Enables PDL for performance optimization -Lastly the container mounts your user `.cache` directory to save the downloaded model checkpoints which are saved to `~/.cache/huggingface/hub/` by default. This prevents having to redownload the weights each time you rerun the container. +Additionally, the container mounts your user `.cache` directory to save the downloaded model checkpoints, which are stored in `~/.cache/huggingface/hub/` by default. This prevents having to redownload the weights each time you rerun the container. You can also download the weights to a custom location (we assume `${local_model_path}` is the path to the local model weights). + +### Build from source + +Support for gpt-oss has been [merged](https://github.com/NVIDIA/TensorRT-LLM/pull/6645) into the **main branch** of TensorRT-LLM. As we continue to optimize gpt-oss performance, you can build TensorRT-LLM from source to get the latest features and support. Please refer to the [doc](https://nvidia.github.io/TensorRT-LLM/latest/installation/build-from-source-linux.html) if you want to build from source yourself. -## Running the TensorRT-LLM Server +### Regular Release of TensorRT-LLM -As pointed out in the introduction, this guide covers low-latency and max-throughput cases. Each requires a different configurations and commands to run. We will first cover the Low-Latency use-case, followed by the max throughput use-case. +Since gpt-oss has been supported on the main branch, you can get TensorRT-LLM out of the box through its regular release in the future. Please check the latest [release notes](https://github.com/NVIDIA/TensorRT-LLM/releases) to keep track of the support status. The release is provided as [NGC Container Image](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release/tags) or [pip Python wheel](https://pypi.org/project/tensorrt-llm/#history). You can find instructions on pip install [here](https://nvidia.github.io/TensorRT-LLM/installation/linux.html). + + +## Performance Benchmarking and Model Serving + +This guide covers how to configure for both low-latency and max-throughput cases, as well as how to benchmark end-to-end performance. 
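+
+If you plan to pass `${local_model_path}` to the commands below instead of relying on the Hugging Face cache, one way to populate it is with the Hugging Face CLI. The directory used here is only an example path:
+
+```bash
+# Example: download the gpt-oss-120b checkpoint to a local directory of your choice.
+local_model_path=/path/to/gpt-oss-120b
+huggingface-cli download openai/gpt-oss-120b --local-dir ${local_model_path}
+```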
+
+### Prepare the dataset
+Before getting started, we need to prepare a dataset of randomized tokens for benchmarking:
+
+```bash
+python benchmarks/cpp/prepare_dataset.py \
+    --stdout \
+    --tokenizer openai/gpt-oss-120b \
+    token-norm-dist \
+    --input-mean 1024 \
+    --output-mean 2048 \
+    --input-stdev 0 \
+    --output-stdev 0 \
+    --num-requests 20000 > gpt-oss-120b-1k2k.txt
+```
+
+### Low-latency Use Case
+
+The low-latency configuration maximizes tps/user under limited concurrency (e.g., 1, 4, 8, or 16 users). Please set the number of GPUs and the concurrency according to your specific workload.
+
+```bash
+num_gpus=8
+max_batch_size=1
+```

-### Low-latency Use-Case

 #### Creating the Extra Options Configuration

-To run a server for low-latency workloads, create a YAML configuration file, `low_latency.yaml`, as follows:
+Create a YAML configuration file, `low_latency.yaml`, as follows:

-```yaml
+```bash
 cat <<EOF > low_latency.yaml
 enable_attention_dp: false
-enable_mixed_sampler: true
 cuda_graph_config:
-    max_batch_size: 8
+    max_batch_size: ${max_batch_size}
     enable_padding: true
 moe_config:
     backend: TRTLLM
 EOF
 ```

-> Note: If you are using NVIDIA H200 GPUs it is highly recommended to set the `moe_config.backend` to TRITON to use the OpenAI Triton MoE kernel. See the section [(H200 Only) Using OpenAI Triton Kernels for MoE](#h200-only-using-openai-triton-kernels-for-moe) for more details.
+Key takeaways:
+- `enable_attention_dp` is set to `false` to use TP instead of DP for attention.
+- `cuda_graph_config.max_batch_size` is the maximum batch size for CUDA graphs.
+- `cuda_graph_config.enable_padding` is set to `true` to enable CUDA graph padding.
+- `moe_config.backend` is set to `TRTLLM` to use the `trtllm-gen` MoE kernels, which are optimized for low concurrency.

-#### Launching TensorRT-LLM Serve
+> Note: If you are using NVIDIA H200 GPUs, please set `moe_config.backend` to `TRITON` to use the OpenAI Triton MoE kernel regardless of use case. See the section [(H200/H100 Only) Using OpenAI Triton Kernels for MoE](#h200h100-only-using-openai-triton-kernels-for-moe) for more details.

-To launch the TensorRT-LLM Server to serve the model with the **low latency** config, run the following command. Commands for different GPU configurations are provided (1xGPU, 8xGPU, 4xGPU):
-
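+
+For example, on H200 the same file with the MoE backend swapped to Triton (as the note above recommends) would be:
+
+```bash
+cat <<EOF > low_latency.yaml
+enable_attention_dp: false
+cuda_graph_config:
+    max_batch_size: ${max_batch_size}
+    enable_padding: true
+moe_config:
+    backend: TRITON
+EOF
+```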
1x B200/GB200/H200 +#### Run the benchmark +Use `trtllm-bench` to benchmark the performance of your system: ```bash -mpirun -n 1 --oversubscribe --allow-run-as-root \ -trtllm-serve openai/gpt-oss-120b \ - --host 0.0.0.0 \ - --port 8000 \ - --backend pytorch \ - --tp_size 1 \ - --ep_size 1 \ - --trust_remote_code \ - --extra_llm_api_options low_latency.yaml \ - --kv_cache_free_gpu_memory_fraction 0.75 +trtllm-bench \ + --model openai/gpt-oss-120b \ + --model_path ${local_model_path} \ + throughput \ + --backend pytorch \ + --tp ${num_gpus} \ + --ep 1 \ + --extra_llm_api_options low_latency.yaml \ + --dataset gpt-oss-120b-1k2k.txt \ + --max_batch_size ${max_batch_size} \ + --concurrency ${max_batch_size} \ + --num_requests $((max_batch_size * 10)) \ + --kv_cache_free_gpu_mem_fraction 0.9 \ + --streaming \ + --warmup 0 \ + --report_json low_latency_benchmark.json ``` -
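+
+Once the run finishes, the metrics are written to the file passed via `--report_json`. The exact field names depend on your TensorRT-LLM version, so the simplest sanity check is to pretty-print whatever the report contains:
+
+```bash
+# Pretty-print the benchmark report produced by trtllm-bench.
+python3 -m json.tool low_latency_benchmark.json
+```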
-
8x B200/H200
+`--max_batch_size` controls the maximum batch size that the inference engine can serve, while `--concurrency` is the number of concurrent requests that the benchmarking client sends. `--num_requests` is set to 10 times `--concurrency` so that enough requests are run.
+
+Note that you can set `--ep` to a value larger than 1, which will enable mixed TP/EP for MoE. In minimum-latency scenarios, we recommend a small EP size to avoid load imbalance in MoE.
+
+For reference, we achieve **420 tps/user** with 8x B200 GPUs and batch size 1.
+
+
+### Max-Throughput Use Case
+
+The max-throughput configuration maximizes tps/gpu at high concurrency levels. With increasing concurrency, we trade per-user latency for higher throughput that saturates the system's GPUs. Using an input sequence length (ISL) of 1k and an output sequence length (OSL) of 2k, we can currently achieve a batch size of 640 with 8x B200 GPUs.

 ```bash
-mpirun -n 1 --oversubscribe --allow-run-as-root \
-trtllm-serve openai/gpt-oss-120b \
-    --host 0.0.0.0 \
-    --port 8000 \
-    --backend pytorch \
-    --tp_size 8 \
-    --ep_size 8 \
-    --trust_remote_code \
-    --extra_llm_api_options low_latency.yaml \
-    --kv_cache_free_gpu_memory_fraction 0.75
+num_gpus=8
+max_batch_size=640
 ```
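+
+As a quick sanity check of the request math used by the benchmark command later in this section (with attention DP, each rank serves up to `max_batch_size` requests, so the client-side concurrency scales with the GPU count):
+
+```bash
+# With num_gpus=8 and max_batch_size=640 this prints concurrency=5120 num_requests=15360.
+concurrency=$((max_batch_size * num_gpus))
+num_requests=$((concurrency * 3))
+echo "concurrency=${concurrency} num_requests=${num_requests}"
+```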
-
4x GB200/B200/H200 - -```bash -mpirun -n 1 --oversubscribe --allow-run-as-root \ -trtllm-serve openai/gpt-oss-120b \ - --host 0.0.0.0 \ - --port 8000 \ - --backend pytorch \ - --tp_size 4 \ - --ep_size 4 \ - --trust_remote_code \ - --extra_llm_api_options low_latency.yaml \ - --kv_cache_free_gpu_memory_fraction 0.75 -``` -
-
-
-
-
-### Max-Throughput Use-Case

 #### Creating the Extra Options Configuration

-To run a server for max-throughput workloads, create a YAML configuration file,
-`max_throughput.yaml`, as follows:
+Like before, create a YAML configuration file, `max_throughput.yaml`, as follows:

-```yaml
+```bash
 cat <<EOF > max_throughput.yaml
 enable_attention_dp: true
 cuda_graph_config:
-    max_batch_size: 640
+    max_batch_size: ${max_batch_size}
     enable_padding: true
 stream_interval: 10
 moe_config:
@@ -151,97 +167,85 @@ moe_config:
     backend: CUTLASS
 EOF
 ```

-> Note: If you are using NVIDIA H200 GPUs it is highly recommended to set the `moe_config.backend` to TRITON to use the OpenAI Triton MoE kernel. See the section [(H200 Only) Using OpenAI Triton Kernels for MoE](#h200-only-using-openai-triton-kernels-for-moe) for more details.
+Compared to the low-latency configuration, we:
+- set `enable_attention_dp` to `true` to use attention DP, which is better for high throughput.
+- set `stream_interval` to 10 to stream results to the client every 10 tokens. At high concurrency, the detokenization overhead of streaming mode cannot be hidden under GPU execution time, so `stream_interval` serves as a workaround to reduce this overhead.
+- set `moe_config.backend` to `CUTLASS` to use the `CUTLASS` MoE kernels, which are optimized for high throughput.

-#### Launching TensorRT-LLM Serve
+#### Run the benchmark

-To launch the TensorRT-LLM Server to serve the model with the **max throughput** config, run the following command. Commands for different GPU configurations are provided (1xGPU, 8xGPU, 4xGPU):
-
1x B200/GB200/H200
+Run the following command to benchmark the throughput of your system:

 ```bash
-mpirun -n 1 --oversubscribe --allow-run-as-root \
-trtllm-serve openai/gpt-oss-120b \
+trtllm-bench \
+    --model openai/gpt-oss-120b \
+    --model_path ${local_model_path} \
+    throughput \
+    --backend pytorch \
+    --tp ${num_gpus} \
+    --ep ${num_gpus} \
+    --extra_llm_api_options max_throughput.yaml \
+    --dataset gpt-oss-120b-1k2k.txt \
+    --max_batch_size ${max_batch_size} \
+    --concurrency $((max_batch_size * num_gpus)) \
+    --num_requests $((max_batch_size * num_gpus * 3)) \
+    --kv_cache_free_gpu_mem_fraction 0.9 \
+    --streaming \
+    --warmup 0 \
+    --report_json max_throughput_benchmark.json
+```
+
+Note:
+- The `CUTLASS` MoE backend only supports pure EP for MoE, so we set `--ep` to `num_gpus`.
+- When using `enable_attention_dp`, `max_batch_size` describes the maximum batch size for each local rank, so to saturate the system, we need to multiply `max_batch_size` by `num_gpus` for `--concurrency`.
+- `--num_requests` is set to 3 times `--concurrency` so that enough requests are run.
+
+Currently, the best throughput of **19.5k tps/gpu** is achieved with DP4EP4 on 4x B200 GPUs, and over **20k tps/gpu** on GB200 thanks to its slightly higher performance, which translates to over **1.5M tps** on a GB200 NVL72 system. In theory, even better tps/gpu could be achieved at larger world sizes because of the larger allowable batch size and the smaller per-GPU MoE weights, but the communication implementation for more than 4 GPUs is suboptimal and we are actively working on improving it.
+
+
+
+## Launch the TensorRT-LLM Server
+
+We can use `trtllm-serve` to serve the model with the same configurations used in the benchmark commands above. For the low-latency configuration, run:
+
+```bash
+# Pass ${local_model_path} instead of gpt-oss-120b to serve a local checkpoint;
+# for the low-latency configuration, ${max_batch_size} is small (e.g., 1).
+trtllm-serve \
+    gpt-oss-120b \
     --host 0.0.0.0 \
     --port 8000 \
     --backend pytorch \
-    --tp_size 1 \
-    --ep_size 1 \
-    --max_batch_size 640 \
-    --trust_remote_code \
-    --extra_llm_api_options max_throughput.yaml \
-    --kv_cache_free_gpu_memory_fraction 0.9
+    --tp_size ${num_gpus} \
+    --ep_size 1 \
+    --extra_llm_api_options low_latency.yaml \
+    --kv_cache_free_gpu_memory_fraction 0.9 \
+    --max_batch_size ${max_batch_size} \
+    --trust_remote_code
 ```
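+
+The server needs a few minutes to load the weights and capture CUDA graphs before it accepts traffic. A simple way to block until it is ready is to poll the `/health` endpoint described in the testing section below (this assumes the default port 8000):
+
+```bash
+# Wait until the server reports healthy before sending requests.
+until curl -sf http://localhost:8000/health > /dev/null; do
+    sleep 10
+done
+echo "Server is ready"
+```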
- -
8x B200/H200 - -```bash -mpirun -n 1 --oversubscribe --allow-run-as-root \ -trtllm-serve openai/gpt-oss-120b \ - --host 0.0.0.0 \ - --port 8000 \ - --backend pytorch \ - --tp_size 8 \ - --ep_size 8 \ - --max_batch_size 640 \ - --trust_remote_code \ - --extra_llm_api_options max_throughput.yaml \ - --kv_cache_free_gpu_memory_fraction 0.9 -``` -
- -
4x GB200/B200/H200 - -```bash -mpirun -n 1 --oversubscribe --allow-run-as-root \ -trtllm-serve openai/gpt-oss-120b \ - --host 0.0.0.0 \ - --port 8000 \ - --backend pytorch \ - --tp_size 4 \ - --ep_size 4 \ - --max_batch_size 640 \ - --trust_remote_code \ - --extra_llm_api_options max_throughput.yaml \ - --kv_cache_free_gpu_memory_fraction 0.9 -``` -
-
-
-This command:
-- Maps port 8000 from the container to your host
-- Uses the PyTorch backend and specifies the tensor and expert parallel sizes
-- References the low latency or max throughput configuration file for extra options
-- Configures memory settings for optimal performance
-- Enables all GPUs with attention data parallelism for the max throughput scenario

 The initialization may take several minutes as it loads and optimizes the models.

-## (H200 Only) Using OpenAI Triton Kernels for MoE

-OpenAI ships a set of Triton kernels optimized for its MoE models. TensorRT-LLM can leverage these kernels for Hopper based GPUs like NVIDIA's H200 for best performance. The NGC TensorRT-LLM container image mentioned above already includes the required kernels so you do not need to build or install them. It is highly recommended to enable them with the steps below:

-### Selecting Triton as the MoE backend

-To use the Triton MoE backend with **trtllm-serve** (or other similar commands) add this snippet to the YAML file passed via `--extra_llm_api_options`:

-```yaml
-moe_config:
-    backend: TRITON
-```

-Alternatively the TRITON backend can be enabled by passing the CLI flag to the trtllm-server command at runtime:
+For the max-throughput configuration, run:

 ```bash
---moe_backend TRITON
+# Pass ${local_model_path} instead of gpt-oss-120b to serve a local checkpoint;
+# for the max-throughput configuration, ${max_batch_size} is large (e.g., 640).
+trtllm-serve \
+    gpt-oss-120b \
+    --host 0.0.0.0 \
+    --port 8000 \
+    --backend pytorch \
+    --tp_size ${num_gpus} \
+    --ep_size ${num_gpus} \
+    --extra_llm_api_options max_throughput.yaml \
+    --kv_cache_free_gpu_memory_fraction 0.9 \
+    --max_batch_size ${max_batch_size} \
+    --trust_remote_code
 ```

-## Test the Server with a Sample Request

-You can query the health/readiness of the server using
+### Test the Server with a Sample Request
+
+
+To check the server's health and readiness:

 ```bash
 curl -s -o /dev/null -w "Status: %{http_code}\n" "http://localhost:8000/health"
 ```

@@ -252,14 +256,13 @@ very first query may take longer due to initialization and compilation.

 Once the server is running, you can test it with a simple curl request:
 
-
 ```bash
 curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
     "model": "openai/gpt-oss-120b",
     "messages": [
         {
             "role": "user",
-            "content": "What is NVIDIAs advantage for inference?"
+            "content": "What is NVIDIA's advantage for inference?"
         }
     ],
     "max_tokens": 1024,
@@ -343,20 +346,29 @@ requests.

 You can adjust parameters like `max_tokens`, `temperature`, and others according to your
 needs.

+## (H200/H100 Only) Using OpenAI Triton Kernels for MoE
+
+OpenAI ships a set of Triton kernels optimized for its MoE models. TensorRT-LLM can leverage these kernels on Hopper-based GPUs like NVIDIA's H200 for optimal performance. The `TRTLLM` MoE backend is not supported on Hopper, and `CUTLASS` backend support is still in progress, so please enable the `TRITON` backend with the steps below if you are running on Hopper GPUs.
+
+### Installing OpenAI Triton
+
+The `nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev` container already ships with the required Triton installation (`echo $TRITON_ROOT` reveals its path). In other environments, you will need to build and install a specific version of Triton; please follow the instructions in this [link](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/gpt_oss#using-openai-triton-kernels-for-moe).
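+
+A quick way to confirm that the container you are working in already has the kernels available is to check the environment variable mentioned above:
+
+```bash
+# If TRITON_ROOT is set, the bundled Triton installation is present.
+if [ -n "${TRITON_ROOT}" ]; then
+    echo "Triton found at ${TRITON_ROOT}"
+else
+    echo "TRITON_ROOT not set; build and install Triton per the linked instructions"
+fi
+```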
+
+
+### Selecting Triton as the MoE backend
+
+To use the Triton MoE backend with **trtllm-serve** (or other commands), add this snippet to the YAML file passed via `--extra_llm_api_options`:
+
+```yaml
+moe_config:
+    backend: TRITON
+```
+
+
 ## Troubleshooting Tips

-- If you encounter CUDA out-of-memory errors, try reducing `max_batch_size`, `max_seq_len`, or `--kv_cache_free_gpu_memory_fraction`
-- Ensure your model checkpoints are compatible with the expected format
-- For performance issues, check GPU utilization with `nvidia-smi` while the server is running
+- If you encounter CUDA out-of-memory errors, try reducing `--max_batch_size`, `--max_num_tokens`, or `--kv_cache_free_gpu_memory_fraction`. See the [doc](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/performance/performance-tuning-guide/tuning-max-batch-size-and-max-num-tokens.md) for an explanation of these parameters.
+- Add `print_iter_log: true` to the extra LLM API options YAML file to inspect per-iteration logs.
+- Check GPU utilization with `nvidia-smi` while the server is running to inspect GPU status and memory usage.
 - If the container fails to start, verify that the NVIDIA Container Toolkit is properly installed
 - For connection issues, make sure port 8000 is not being used by another application
-
-
-## Performance Tuning
-
-The configuration provided is optimized for 8xB200 GPUs, but you can adjust
-several parameters for your specific workload:
-
-- `max_batch_size`: Controls how many requests can be batched together
-- `max_draft_len`: The number of tokens Eagle can speculate ahead
-- `kv_cache_free_gpu_memory_fraction`: Controls memory allocation for the KV cache