# Disaggregated Serving

Disaggregated serving is driven by the `trtllm-serve` command. Unlike the standard usage of `trtllm-serve`, it requires running the command multiple times to start the serving components separately: the router and the workers (context and generation servers). This document focuses on that approach and provides a detailed guide on how to use it. Please note that disaggregated serving is currently an experimental feature, so the usage described in this document may change in the future.

## Startup Procedure

### Configuration File

The `trtllm-serve` command accepts an extra LLM configuration file (for example, `extra-llm-config.yaml`). In this file, the `cache_transceiver_config` field is dedicated to disaggregated serving: it specifies the additional parameters required for the KV cache transmission process.

```yaml
cache_transceiver_config:
  # KV cache transmission backend. Valid options include `DEFAULT` (i.e., NIXL), `UCX`, `NIXL`.
  backend:
  # KV cache buffer size. Set it ≥ the maximum ISL (Input Sequence Length) for best performance.
  max_tokens_in_buffer:
  # KV cache transfer timeout in milliseconds.
  # Requests that do not send/receive the KV cache in time are cancelled and cleaned up.
  kv_transfer_timeout_ms:
  # Timeout in milliseconds to wait for the sender future to be ready when the scheduled batch size is 0.
  # This allows the request to eventually be cancelled by the user or because of kv_transfer_timeout_ms.
  kv_transfer_sender_future_timeout_ms:
```

The following example shows the `ctx_config.yaml` and `gen_config.yaml` files used in the sections below.

```yaml
# ctx_config.yaml
# The overlap scheduler for context servers is currently disabled, as it is
# not yet supported in disaggregated context server architectures.
disable_overlap_scheduler: True
cache_transceiver_config:
  backend: UCX
  max_tokens_in_buffer: 2048
```

```yaml
# gen_config.yaml
cache_transceiver_config:
  backend: UCX
  max_tokens_in_buffer: 2048
```

## NIXL Backend Configuration

NIXL supports multiple underlying communication backends for KV cache exchange. The backend can be configured using the `TRTLLM_NIXL_KVCACHE_BACKEND` environment variable.

**Supported NIXL backends:**

- **UCX** (default)
- **LIBFABRIC** (available from v0.16.0)

If an unsupported backend is specified, NIXL automatically falls back to UCX.

### LIBFABRIC Backend Setup

**Important Note:** The TensorRT LLM container does not include libfabric or the NIXL-LIBFABRIC plugin by default. You must either rebuild NIXL with libfabric support or provide a pre-compiled plugin.

#### Prerequisites

##### For LIBFABRIC Backend

**Required Dependencies:**

**Libfabric**
- A custom libfabric installation is available via [https://ofiwg.github.io/libfabric/](https://ofiwg.github.io/libfabric/)
- **Minimum required version:** v1.21.0
- For EFA-enabled AWS instances, install through the [AWS EFA installer](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/efa-start.html) (using the latest version is recommended)

**hwloc**
- hwloc is used to understand the underlying architecture and optimize application performance
- **Suggested version:** 2.10.0 or newer

**Network Hardware Requirements:**
- Validated compatibility with AWS EFA (Elastic Fabric Adapter)

##### For UCX Backend

UCX is typically pre-installed in NVIDIA GPU containers. No additional installation is usually required.
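Before choosing one of the installation options below, it can help to check which of these prerequisites are already present in your environment. The following is a minimal sketch, assuming the `fi_info`, `hwloc-info`, and `ucx_info` utilities that ship with each package are on the `PATH`:

```bash
# Libfabric: the LIBFABRIC backend requires v1.21.0 or newer
fi_info --version

# hwloc: 2.10.0 or newer is suggested
hwloc-info --version

# UCX: usually pre-installed in NVIDIA GPU containers
ucx_info -v
```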
#### Installation Options

##### Option 1: Rebuild NIXL with LIBFABRIC Support (Recommended)

1. **Install libfabric dependencies:**
   - Follow the installation instructions from the links above based on your system
2. **Install hwloc:**
   - Use your package manager or build from source
3. **Reinstall NIXL after installing libfabric:**
   - After installing libfabric and hwloc, you must rebuild NIXL to generate the LIBFABRIC plugin
   - You can base your installation on the TensorRT LLM NIXL installation script located at `docker/common/install_nixl.sh`
   - Modify the meson setup command in the script to include the libfabric path:

   ```bash
   meson setup builddir \
       ... \
       -Dlibfabric_path=/path/to/libfabric \  # Add this line
       --buildtype=release
   ```

   - For more details, see the [NIXL LIBFABRIC Plugin documentation](https://github.com/ai-dynamo/nixl/tree/6ee64753605b3110f8ef96c7cfc2f1315675c9c7/src/plugins/libfabric#nixl-libfabric-plugin)

##### Option 2: Use Pre-compiled LIBFABRIC Plugin

If you have a pre-compiled `libplugin_LIBFABRIC.so` that matches your NIXL version:

1. Place the plugin file in a directory of your choice
2. Set the environment variables to point to the plugin directory and select the backend:

   ```bash
   export NIXL_PLUGINS_DIR=/path/to/plugin/directory
   export TRTLLM_NIXL_KVCACHE_BACKEND=LIBFABRIC
   ```

3. Ensure the plugin was built with the same NIXL version as in your container
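If you take the pre-compiled route, one possible sanity check is to confirm that the shared object is present in the configured directory and that its libfabric dependency resolves. The path below is a placeholder:

```bash
# Placeholder path; point this at the directory that contains the plugin
export NIXL_PLUGINS_DIR=/path/to/plugin/directory

# The plugin should be visible as libplugin_LIBFABRIC.so
ls "$NIXL_PLUGINS_DIR"/libplugin_LIBFABRIC.so

# Its libfabric dependency should resolve against your libfabric install
ldd "$NIXL_PLUGINS_DIR"/libplugin_LIBFABRIC.so | grep -i libfabric
```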
### NIXL Configuration Examples

To use NIXL for KV cache exchange, configure the `cache_transceiver_config` with `backend: NIXL`. The underlying NIXL backend (UCX or LIBFABRIC) is selected via the `TRTLLM_NIXL_KVCACHE_BACKEND` environment variable.

**Context server configuration:**

```yaml
# context_config_nixl.yml
disable_overlap_scheduler: True
cache_transceiver_config:
  backend: NIXL
  max_tokens_in_buffer: 2048
```

**Generation server configuration:**

```yaml
# gen_config_nixl.yml
cache_transceiver_config:
  backend: NIXL
  max_tokens_in_buffer: 2048
```

#### Example 1: Using NIXL with UCX backend (default)

```bash
# UCX is the default, but can be explicitly set
export TRTLLM_NIXL_KVCACHE_BACKEND=UCX  # Optional, UCX is the default

# Start context servers with NIXL using UCX
CUDA_VISIBLE_DEVICES=0 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8001 --backend pytorch \
    --config ./context_config_nixl.yml &> log_ctx_0 &

CUDA_VISIBLE_DEVICES=1 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8002 --backend pytorch \
    --config ./context_config_nixl.yml &> log_ctx_1 &

# Start generation server with NIXL using UCX
CUDA_VISIBLE_DEVICES=2 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8003 --backend pytorch \
    --config ./gen_config_nixl.yml &> log_gen_0 &
```

#### Example 2: Using NIXL with LIBFABRIC backend

```bash
# Configure NIXL to use the LIBFABRIC backend
export TRTLLM_NIXL_KVCACHE_BACKEND=LIBFABRIC

# If using a pre-compiled plugin:
# export NIXL_PLUGINS_DIR=/path/to/plugin/directory

# For AWS EFA (optional):
# export FI_PROVIDER=efa
# export FI_EFA_USE_DEVICE_RDMA=1
# export FI_LOG_LEVEL=warn

# Start context servers with NIXL using LIBFABRIC
CUDA_VISIBLE_DEVICES=0 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8001 --backend pytorch \
    --config ./context_config_nixl.yml &> log_ctx_0 &

CUDA_VISIBLE_DEVICES=1 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8002 --backend pytorch \
    --config ./context_config_nixl.yml &> log_ctx_1 &

# Start generation server with NIXL using LIBFABRIC
CUDA_VISIBLE_DEVICES=2 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8003 --backend pytorch \
    --config ./gen_config_nixl.yml &> log_gen_0 &
```

### Environment Variables for NIXL Backends

**NIXL Backend Selection:**

- `TRTLLM_NIXL_KVCACHE_BACKEND`: Selects the underlying backend for NIXL. Valid options:
  - `UCX` (default)
  - `LIBFABRIC` (available from v0.16.0)
- If an unsupported value is provided, NIXL automatically falls back to UCX

**Additional Environment Variables by Backend:**

**For the UCX backend:**
- `UCX_MAX_RNDV_RAILS`: Maximum number of InfiniBand NIC devices per GPU. Setting it to 1 can reduce contention in multi-GPU scenarios
- Standard UCX environment variables apply

**For the LIBFABRIC backend:**
- `NIXL_PLUGINS_DIR`: Directory containing the NIXL LIBFABRIC plugin (`libplugin_LIBFABRIC.so`) when using a pre-compiled plugin
- `FI_PROVIDER`: Specifies the libfabric provider to use (e.g., `efa` for AWS EFA)
- `FI_EFA_USE_DEVICE_RDMA`: Set to `1` to enable GPU Direct RDMA on AWS EFA (if supported)
- `FI_LOG_LEVEL`: Controls libfabric logging verbosity (e.g., `warn`, `info`, `debug`)

**Example configuration for AWS EFA with LIBFABRIC:**

```bash
export TRTLLM_NIXL_KVCACHE_BACKEND=LIBFABRIC
export FI_PROVIDER=efa
export FI_EFA_USE_DEVICE_RDMA=1
export FI_LOG_LEVEL=warn
```

### Basic Usage

For non-SLURM clusters, particularly single-node, multi-GPU setups, the standard mode described in this section is recommended. In such environments, the system does not enforce limits on process creation or termination.

Suppose we have three CUDA devices on the same machine. The first two devices each launch one context model, and the third device launches one generation model. In this case, execute the following commands:

```bash
# Start context servers
CUDA_VISIBLE_DEVICES=0 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8001 \
    --config ./ctx_config.yaml &> log_ctx_0 &

CUDA_VISIBLE_DEVICES=1 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8002 \
    --config ./ctx_config.yaml &> log_ctx_1 &

# Start generation server
CUDA_VISIBLE_DEVICES=2 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8003 \
    --config ./gen_config.yaml &> log_gen_0 &
```
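Model loading can take some time. Before starting the disaggregated server below, it can be useful to confirm that each worker is ready. This is a minimal sketch that assumes the workers expose a `/health` endpoint; if your version does not, check the `log_ctx_*` and `log_gen_0` files instead:

```bash
# Wait until every worker responds to its health check (HTTP 200)
for port in 8001 8002 8003; do
  until curl -sf "http://localhost:${port}/health" > /dev/null; do
    sleep 5
  done
  echo "Worker on port ${port} is ready"
done
```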
Once the context and generation servers are launched, you can launch the disaggregated server, which accepts requests from clients and orchestrates them between the context and generation servers. The disaggregated server can be launched with:

```bash
# Start proxy
trtllm-serve disaggregated -c disagg_config.yaml
```

where `disagg_config.yaml` contains information about the context and generation servers. For the current example, it would look like:

```yaml
# disagg_config.yaml
hostname: localhost
port: 8000
backend: pytorch
context_servers:
  num_instances: 2
  urls:
    - "localhost:8001"
    - "localhost:8002"
generation_servers:
  num_instances: 1
  urls:
    - "localhost:8003"
```

Clients can then send requests to the disaggregated server at `localhost:8000`, which is an OpenAI API compatible endpoint.

#### Sending requests to the disaggregated server

Once the context, generation, and disaggregated servers are launched, you can send requests to the disaggregated server using curl:

```bash
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "prompt": "NVIDIA is a great company because",
        "max_tokens": 16,
        "temperature": 0
    }' -w "\n"
```

Or use the provided client, which parses prompts from a file and sends requests to the disaggregated server specified in the `disagg_config.yaml` file at the `chat` endpoint:

```bash
python3 ./clients/disagg_client.py -c disagg_config.yaml -p ./clients/prompts.json -e chat
```

### Launching disaggregated servers on SLURM clusters

To simplify usage, TensorRT-LLM internally relies on MPI to spawn processes. However, some clusters do not offer this kind of process flexibility. For these cases, we provide the `trtllm-llmapi-launch` tool to launch all processes at once. When using TensorRT-LLM on a Slurm cluster, please refer to the following method.

#### Single-Node Execution

After starting the node and entering interactive mode, you can run the following commands; prefixing each command with `trtllm-llmapi-launch` prevents process spawning.

```bash
# Start context servers
CUDA_VISIBLE_DEVICES=0 trtllm-llmapi-launch trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8001 \
    --config ./ctx_config.yaml &> log_ctx_0 &

CUDA_VISIBLE_DEVICES=1 trtllm-llmapi-launch trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8002 \
    --config ./ctx_config.yaml &> log_ctx_1 &

# Start generation server
CUDA_VISIBLE_DEVICES=2 trtllm-llmapi-launch trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host localhost --port 8003 \
    --config ./gen_config.yaml &> log_gen_0 &

# Start proxy
trtllm-llmapi-launch trtllm-serve disaggregated -c disagg_config.yaml
```

#### Multi-Node Execution

If the model you are running cannot fit within a single node and requires multiple nodes, use [srun](https://slurm.schedmd.com/srun.html) to run parallel jobs.

```bash
srun -A -p -t