[TRTLLM-5673] Doc: ensure the disagg doc is up to date (#5938)

(disaggregated-service)=
# Disaggregated-Service (Experimental)
```{note}
This feature is currently experimental, and the related API is subject to change in future versions.
```
TRT-LLM currently supports `disaggregated-service`, where the context and generation phases of a request can run on different executors. TRT-LLM's disaggregated service relies on the executor API; please make sure to read the [executor page](executor.md) before reading this document.
For more information on disaggregated service in LLM inference, refer to papers such as [DistServe](https://arxiv.org/abs/2401.09670) and [SplitWise](https://arxiv.org/abs/2311.18677).
## Usage
```cpp
enum class RequestType
{
    REQUEST_TYPE_CONTEXT_AND_GENERATION = 0,
    REQUEST_TYPE_CONTEXT_ONLY = 1,
    REQUEST_TYPE_GENERATION_ONLY = 2
};
```
The TRT-LLM executor can execute three types of requests: `REQUEST_TYPE_CONTEXT_AND_GENERATION`, `REQUEST_TYPE_CONTEXT_ONLY`, and `REQUEST_TYPE_GENERATION_ONLY`. An executor instance can execute the context phase of a context-only request or the generation phase of a generation-only request. When the executor completes the context phase of a context-only request, it keeps the corresponding KV cache, which will be requested by the executor handling the subsequent generation-only request.
Note that the environment variable `TRTLLM_USE_MPI_KVCACHE=1` should be set for `disaggregated-service`.
Here are some key APIs to use disaggregated service:
```cpp
// Enqueue a context-only request on the context executor and wait for its responses.
Request request{...};
request.setRequestType(tensorrt_llm::executor::RequestType::REQUEST_TYPE_CONTEXT_ONLY);
auto contextRequestId = contextExecutor.enqueueRequest(request);
auto contextResponses = contextExecutor.awaitResponses(contextRequestId);

// Extract the context phase parameters from the last response and attach them to the
// same request, then re-enqueue it as a generation-only request on the generation executor.
auto contextPhaseParams = contextResponses.back().getResult().contextPhaseParams.value();
request.setContextPhaseParams(contextPhaseParams);
request.setRequestType(tensorrt_llm::executor::RequestType::REQUEST_TYPE_GENERATION_ONLY);
auto generationRequestId = generationExecutor.enqueueRequest(request);
auto genResponses = generationExecutor.awaitResponses(generationRequestId);
```
The generationExecutor requires data such as the KV cache from the corresponding contextExecutor, based on the `contextPhaseParams` attached to the request, so please make sure that the corresponding contextExecutor is not shut down before the generationExecutor's responses have been received.
In the code example above, the `contextRequestId` assigned by the contextExecutor and the `generationRequestId` assigned by the generationExecutor are independent; it is the user's responsibility to manage the mapping from the `requestId` of a context-only request to the `requestId` of the corresponding generation-only request. The `contextResponses` contain the first output token generated by the context phase, and the `genResponses` also include that first token generated by the contextExecutor, so all output tokens can be obtained from the generationExecutor's responses.
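For illustration, a minimal sketch of one way to keep that mapping is shown below; this is not part of the TRT-LLM API, and the `IdType` alias is an assumption standing in for the executor's 64-bit request-id type.
```cpp
#include <cstdint>
#include <unordered_map>

// Assumed stand-in for the executor's request-id type.
using IdType = std::uint64_t;

// Maps the requestId of a generation-only request back to the requestId of the
// context-only request it originated from.
std::unordered_map<IdType, IdType> genToCtxRequestId;

// After enqueueing the generation-only request:
//   genToCtxRequestId[generationRequestId] = contextRequestId;
// When a generation response arrives, recover the originating context request:
//   IdType ctxRequestId = genToCtxRequestId.at(response.getRequestId());
```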
![disaggregated-service usage](images/disaggregated-service_usage.png)
An `orchestrator` is required in `disaggregated-service` to manage multiple executor instances and route requests to different executors. TRT-LLM provides the class `DisaggExecutorOrchestrator` in `cpp/include/tensorrt_llm/executor/disaggServerUtil.h` to launch multiple executor instances; however, `DisaggExecutorOrchestrator` only routes requests to executors with a simple round-robin policy, so users need to implement their own orchestrator for `disaggregated-service` based on their usage scenario.
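As a rough illustration of the routing logic an orchestrator needs (this is a sketch, not the actual `DisaggExecutorOrchestrator` implementation), a simple round-robin router could look like the following; a production orchestrator would typically also account for load, health, and KV cache placement.
```cpp
#include <atomic>
#include <cstddef>
#include <utility>
#include <vector>

// Minimal round-robin router over a fixed set of executors.
template <typename Executor>
class RoundRobinRouter
{
public:
    explicit RoundRobinRouter(std::vector<Executor*> executors)
        : mExecutors(std::move(executors))
    {
    }

    // Returns the next executor in round-robin order; thread-safe.
    Executor& next()
    {
        auto idx = mNext.fetch_add(1, std::memory_order_relaxed) % mExecutors.size();
        return *mExecutors[idx];
    }

private:
    std::vector<Executor*> mExecutors;
    std::atomic<std::size_t> mNext{0};
};
```
A context-only request would then be enqueued on the context-side router's `next()` executor, and the matching generation-only request on the generation-side router once its context phase completes.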
## Example
Please refer to `examples/cpp/executor/executorExampleDisaggregated.cpp`
## Benchmarks
Please refer to `benchmarks/cpp/disaggServerBenchmark.cpp` and `benchmarks/cpp/README.md`
An [architectural and performance overview](../../../docs/source/blogs/tech_blog/blog5_Disaggregated_Serving_in_TensorRT-LLM.md) and [usage examples](../../../examples/disaggregated/README.md) are also provided.
## Environment Variables
TRT-LLM uses the following environment variables to control the behavior of the disaggregated service.
* `TRTLLM_USE_MPI_KVCACHE`: Whether to use MPI to transfer KV cache. Currently, the default value is `0`.
* `TRTLLM_USE_UCX_KVCACHE`: Specifies whether to use UCX for KV cache transfer. The default value is `0`. This must be enabled when using a disaggregated service.
* `TRTLLM_PARALLEL_CACHE_SEND`: If set to `1`, contextExecutor will attempt to send KV cache for multiple requests in parallel. The default value is `0`.
* `TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM`: The maximum number of concurrent KV cache sends. The default value is `4`. This environment variable only takes effect when `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE` is greater than 0.
## Troubleshooting and FAQ
### General FAQs
A. Yes, it's recommended that different executors use different GPUs. Context-only executors and generation-only executors can run on the same node or on different nodes. The `participantIds` and `deviceIds` used by each executor need to be explicitly set by the user, and the `participantIds` of the executors must not intersect.
*Q. What are the requirements for disaggregated-service in TRT-LLM?*
A. TRT-LLM currently requires a `UCX`-backend [`CUDA-aware MPI`](https://docs.open-mpi.org/en/v5.0.x/tuning-apps/networking/cuda.html#how-do-i-build-open-mpi-with-cuda-aware-support). TRT-LLM implements KV cache transfer with `CUDA-aware MPI` and will support more communication components for KV cache transfer in future versions.
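For example, you can verify these prerequisites as follows (the `ompi_info` query follows the Open MPI documentation linked above):
```
# Check whether Open MPI was built with CUDA-aware support.
ompi_info --parsable --all | grep mpi_built_with_cuda_support:value
# Check the UCX version and build configuration.
ucx_info -v
```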
### Debugging FAQs
*Q. How to handle the error `Disaggregated serving is not enabled, please check the configuration`?*
A. Please set the environment variable
```
export TRTLLM_USE_UCX_KVCACHE=1
```
When the environment variable `TRTLLM_USE_MPI_KVCACHE=1` is set, TRT-LLM will transfer the KV cache using `CUDA-aware MPI`. All executor processes involved must share the same MPI world communicator. Consequently, with `TRTLLM_USE_MPI_KVCACHE=1`, TRT-LLM only supports launching multiple executors via `MPI`. Additionally, the `CommunicationMode` for the executors must be set to `kLEADER` or `kORCHESTRATOR` with `SpawnProcesses=false` for the `disaggregated-service`. These restrictions do not apply when `TRTLLM_USE_UCX_KVCACHE=1` is set.
*Q. Why do some profiling tools show that TRT-LLM's KV cache transfer does not utilize NVLink even on devices equipped with NVLink?*
A. Please check the version of `UCX` with `ucx_info -v`.
If the UCX version is <= 1.17, set the environment variables `UCX_RNDV_FRAG_MEM_TYPE=cuda` and `UCX_MEMTYPE_CACHE=n` to enable NVLink. For Blackwell architecture GPUs, UCX version >= 1.19 is required to enable NVLink.
If the version of UCX >=1.18, there are several ways to enable NVLink:
1. Set the environment variables `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=0B`, `UCX_CUDA_COPY_ASYNC_MEM_TYPE=cuda`, `UCX_CUDA_COPY_DMABUF=no`, `UCX_MEMTYPE_CACHE=n` and `UCX_RNDV_PIPELINE_ERROR_HANDLING=y`.
A. Yes, TRT-LLM supports using GPU direct RDMA for inter-node KV cache transfer, but it is not enabled by default. There are several ways to enable GPU direct RDMA:
1. Set the environment variables `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=0B`, `UCX_RNDV_FRAG_MEM_TYPE=cuda`, `UCX_MEMTYPE_CACHE=n` and `UCX_RNDV_PIPELINE_ERROR_HANDLING=y`.
2. Set the environment variables `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=$Size`, `UCX_MEMTYPE_CACHE=n` and `UCX_RNDV_PIPELINE_ERROR_HANDLING=y`, where `$Size` is the size of the buffer used for KV cache transfer; it is recommended to be larger than the KV cache size of the longest request.
To achieve optimal performance when using GPU direct RDMA, it is advisable to create the CUDA context before MPI initialization when `TRTLLM_USE_MPI_KVCACHE=1` is set. One possible approach is to rely on MPI environment variables to select the correct device before MPI initialization.
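A minimal sketch of that approach is shown below, assuming an Open MPI launcher; the `OMPI_COMM_WORLD_LOCAL_RANK` variable and the overall structure are assumptions for illustration, not TRT-LLM code.
```cpp
#include <cstdlib>
#include <cuda_runtime.h>
#include <mpi.h>

int main(int argc, char* argv[])
{
    // Select the GPU from the launcher's local-rank variable *before* MPI_Init,
    // so the CUDA context is created on the intended device.
    int localRank = 0;
    if (char const* env = std::getenv("OMPI_COMM_WORLD_LOCAL_RANK"))
    {
        localRank = std::atoi(env);
    }
    int deviceCount = 0;
    cudaGetDeviceCount(&deviceCount);
    if (deviceCount > 0)
    {
        cudaSetDevice(localRank % deviceCount);
    }
    cudaFree(nullptr); // forces CUDA context creation on the selected device

    MPI_Init(&argc, &argv); // CUDA-aware MPI now sees the already-created context
    // ... create executors, enqueue requests, etc. ...
    MPI_Finalize();
    return 0;
}
```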
*Q. Are there any guidelines for performance tuning of KV cache transfer?*

# TRT-LLM Disaggregated Serving
To run TRT-LLM in disaggregated mode, you must first launch context (prefill) and generation (decode) servers using `trtllm-serve`.
Depending on your deployment environment, this can be done in different ways.
## Launching context and generation servers using multiple independent `trtllm-serve` commands
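After starting the context and generation servers with separate `trtllm-serve` commands, the disaggregated server is launched with `trtllm-serve disaggregated -c disagg_config.yaml`, where the configuration file lists the individual servers. For reference, a minimal sketch of such a configuration is shown below; the hostnames, ports, and instance counts are placeholder assumptions, so consult the files in `examples/disaggregated` for the exact schema.
```
hostname: localhost
port: 8000
backend: pytorch
context_servers:
  num_instances: 2
  urls:
    - "localhost:8001"
    - "localhost:8002"
generation_servers:
  num_instances: 1
  urls:
    - "localhost:8003"
```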
Clients can then send requests to the disaggregated server at `localhost:8000`, which is an OpenAI-compatible endpoint.
## Sending requests to the disaggregated server
Once the context, generation and disaggregated servers are launched, you can send requests to the disaggregated server using curl:
```
curl http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"prompt": "NVIDIA is a great company because",
"max_tokens": 16,
"temperature": 0
}' -w "\n"
```
Alternatively, you can use the provided client, which parses prompts from a file and sends requests to the `chat` endpoint of the disaggregated server specified in the `disagg_config.yaml` file:
```
python3 ./clients/disagg_client.py -c disagg_config.yaml -p ./clients/prompts.json -e chat
```
## Dynamic scaling (Experimental)
TRT-LLM currently supports dynamic addition and removal of servers by leveraging etcd. To enable this feature, start the context and generation servers with the additional flags `--metadata_server_config_file` and `--server_role`.
Before launching the context and generation servers, you should first start the etcd server. By default, the etcd server listens for client requests at `localhost:2379`.
```
etcd
```
After this, you can enable the dynamic scaling feature for the use case above as follows:
```
export TRTLLM_USE_UCX_KVCACHE=1
#Context servers
CUDA_VISIBLE_DEVICES=0 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8001 --backend pytorch --server_role CONTEXT --extra_llm_api_options ./context_extra-llm-api-config.yml --metadata_server_config_file ./metadata_config.yml &> log_ctx_0 &
CUDA_VISIBLE_DEVICES=1 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8002 --backend pytorch --server_role CONTEXT --extra_llm_api_options ./context_extra-llm-api-config.yml --metadata_server_config_file ./metadata_config.yml &> log_ctx_1 &
#Generation servers
CUDA_VISIBLE_DEVICES=2 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8003 --backend pytorch --server_role GENERATION --extra_llm_api_options ./gen_extra-llm-api-config.yml --metadata_server_config_file ./metadata_config.yml &> log_gen_0 &
```
For the disaggregated server, you should also specify `--metadata_server_config_file`, as follows:
```
trtllm-serve disaggregated -c disagg_config.yaml -m ./metadata_config.yml
```
The `metadata_config.yml` file looks like this:
```
hostname: "localhost"
port: 2379
health_check_timeout: 5.0
refresh_interval: 10.0
```
The `hostname` and `port` must match those used when starting the etcd server. The `health_check_timeout` parameter specifies how long to wait without a healthy response before a server is considered dead. By default, TRT-LLM performs two checks before marking a server as dead. The `refresh_interval` parameter determines how often the latest server list is fetched from the etcd server.
### Dynamically adding servers
Users can add servers by directly launching them with `trtllm-serve`. For example, you can start an additional generation server as follows:
```
CUDA_VISIBLE_DEVICES=3 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8004 --backend pytorch --server_role GENERATION --extra_llm_api_options ./gen_extra-llm-api-config.yml --metadata_server_config_file ./metadata_config.yml &> log_gen_1 &
```
TRT-LLM will automatically register any newly launched server with the etcd server, allowing the router to send new requests to the added server.
### Dynamically removing servers
When removing servers, special attention is required in the current version. You must first remove the corresponding key from the etcd server. After you see the log message "Server xxxx is removed", you can safely shut down the server. This workflow will be improved soon.
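For example, with the standard `etcdctl` client (the exact key layout is version-dependent, so list the keys first to find the one that corresponds to the server):
```
# List the keys currently registered in etcd.
etcdctl get --prefix --keys-only ""
# Delete the key of the server to be removed, then wait for the
# "Server xxxx is removed" log message before shutting the server down.
etcdctl del <key_of_the_server_to_remove>
```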
## Launching context and generation servers using MPI (Deprecated)
One can also launch all context and generation servers using MPI.
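A sketch of such a launch is shown below; the `disaggregated_mpi_worker` subcommand name and the rank count are assumptions here, so check `trtllm-serve --help` and the full example README for the exact command in your version.
```
export TRTLLM_USE_MPI_KVCACHE=1
mpirun -n <total_num_ranks> trtllm-serve disaggregated_mpi_worker -c disagg_config.yaml
```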
Once the context and generation servers are launched, you can again launch the disaggregated server with:
```
trtllm-serve disaggregated -c disagg_config.yaml
```