TensorRT-LLMs/docs/source/scripts/disaggregated/README.md
Xianjie Qiao 857108aeca
Add disagg slurm scripts (#5243)
Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
2025-06-18 23:17:55 +08:00


# Disaggregated Inference Benchmark Scripts
This directory contains scripts to run disaggregated inference benchmarks using TensorRT-LLM and SLURM.
## Overview
The benchmarking process is orchestrated through a set of shell scripts and a Python script that work together:
1. `submit.sh`: The main entry point for submitting benchmark jobs to SLURM. It runs a parameter sweep by calling `sbatch` with different configurations.
2. `disaggr_torch.slurm`: The SLURM script that sets up and runs a single benchmark experiment. It launches a container, generates a configuration file, starts the server and workers, and runs the benchmark client.
3. `gen_yaml.py`: A Python script that generates the `config.yaml` file needed by `trtllm-serve`. It determines the server and worker configuration based on SLURM environment variables and script arguments.
4. `start_worker.sh`: A shell script responsible for starting a `trtllm-serve disaggregated_mpi_worker` on each allocated machine.
5. `run_benchmark.sh`: A shell script that waits for the server to be healthy and then runs the actual benchmark client (`run_benchmark.py`, not included in this directory).
## File Descriptions
### `submit.sh`
This script is used to submit multiple SLURM jobs for running benchmarks with different parameters. It iterates through various configurations and uses `sbatch` to submit `disaggr_torch.slurm` for each one.
**Usage:**
```bash
./submit.sh
```
You can modify the loops in this script to change the parameter space for the benchmark sweep.
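As a rough illustration of what such a sweep could look like, the sketch below loops over two generation-side parameters and submits one job per combination. Every value in it (the TP sizes, batch sizes, memory fraction, and log directory names) is a hypothetical placeholder, not the configuration shipped in `submit.sh`. The `SBATCH` variable is also an assumption added here so that the loop can be dry-run without a SLURM cluster.

```bash
#!/usr/bin/env bash
# Hypothetical sweep: all values are placeholders, not the shipped defaults.
# Set SBATCH="sbatch" to submit for real; the default only prints each command.
SBATCH="${SBATCH:-echo sbatch}"

submit_sweep() {
    local concurrency_list="1 2 4 8"
    local gen_tp_size gen_batch_size
    for gen_tp_size in 4 8; do
        for gen_batch_size in 64 128; do
            # Positional arguments follow the order documented for
            # disaggr_torch.slurm: five context settings, six generation
            # settings, the concurrency list, then the log subdirectory.
            # Reusing the batch size as gen_max_num_tokens is a placeholder choice.
            $SBATCH disaggr_torch.slurm \
                1 4 4 4096 true \
                1 "$gen_tp_size" "$gen_batch_size" "$gen_batch_size" true 0.8 \
                "$concurrency_list" \
                "gen_tp${gen_tp_size}_bs${gen_batch_size}"
        done
    done
}

submit_sweep
```

Run as-is, this prints four `sbatch` command lines. Giving each combination its own log subdirectory name keeps the results of different runs apart.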
### `disaggr_torch.slurm`
This is the core SLURM script for a single benchmark run. It is not meant to be run directly, but rather submitted via `sbatch` (e.g., by `submit.sh`).
It takes the following arguments in order:
1. `num_ctx_servers`: Number of context servers.
2. `ctx_tp_size`: Tensor parallel size for context servers.
3. `ctx_batch_size`: Max batch size for context servers.
4. `ctx_max_num_tokens`: Max number of tokens for context servers.
5. `ctx_enable_attention_dp`: `true` or `false` to enable attention DP for context servers.
6. `num_gen_servers`: Number of generation servers.
7. `gen_tp_size`: Tensor parallel size for generation servers.
8. `gen_batch_size`: Max batch size for generation servers.
9. `gen_max_num_tokens`: Max number of tokens for generation servers.
10. `gen_enable_attention_dp`: `true` or `false` to enable attention DP for generation servers.
11. `gen_gpu_memory_fraction`: GPU memory fraction for generation servers.
12. `concurrency_list`: A space-separated list of concurrency levels to test, passed as a single quoted argument (e.g., `"1 2 4 8"`).
13. `sub_file`: A subdirectory name for logs.
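Because the arguments are purely positional, a mistake in their order or count is easy to make. The sketch below shows one way the 13 arguments could be bound to named variables with a count check. The function name `parse_args` is hypothetical, and the actual script may read its arguments differently.

```bash
#!/usr/bin/env bash
# Sketch: bind the 13 positional arguments to named variables.
# Variable names mirror the list above; the real script may differ.
parse_args() {
    if [ "$#" -ne 13 ]; then
        echo "expected 13 arguments, got $#" >&2
        return 1
    fi
    num_ctx_servers=$1
    ctx_tp_size=$2
    ctx_batch_size=$3
    ctx_max_num_tokens=$4
    ctx_enable_attention_dp=$5
    num_gen_servers=$6
    gen_tp_size=$7
    gen_batch_size=$8
    gen_max_num_tokens=$9
    # Positions above 9 need braces in bash.
    gen_enable_attention_dp=${10}
    gen_gpu_memory_fraction=${11}
    concurrency_list=${12}
    sub_file=${13}
}
```

Inside a script this would be called as `parse_args "$@"`. The concurrency list must be quoted by the caller so that it arrives as one argument.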
### `gen_yaml.py`
This Python script generates the `config.yaml` file that configures the `trtllm-serve` application. It reads SLURM environment variables (`SLURM_JOB_NODELIST`, `SLURM_TASKS_PER_NODE`) to distribute workers across nodes.
**Usage:**
The script is called from within `disaggr_torch.slurm`. It takes numerous arguments to define the model, parallelism, and server configurations.
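SLURM reports `SLURM_TASKS_PER_NODE` in a compressed form such as `4(x2),2`, meaning two nodes with four tasks each followed by one node with two. The sketch below expands that form into one count per node. It is written in bash only to match the other examples here; `gen_yaml.py` is a Python script, its own parsing logic is not shown in this README, and the function name `expand_tasks_per_node` is hypothetical.

```bash
#!/usr/bin/env bash
# Expand SLURM's compressed task counts, e.g. "4(x2),2" -> "4 4 2".
expand_tasks_per_node() {
    local spec=$1 entry count repeat i
    local -a out=()
    local re='^([0-9]+)\(x([0-9]+)\)$'
    local IFS=','
    for entry in $spec; do
        if [[ $entry =~ $re ]]; then
            # "N(xR)" means R consecutive nodes with N tasks each.
            count=${BASH_REMATCH[1]}
            repeat=${BASH_REMATCH[2]}
        else
            # A bare number describes a single node.
            count=$entry
            repeat=1
        fi
        for ((i = 0; i < repeat; i++)); do
            out+=("$count")
        done
    done
    IFS=' '
    echo "${out[*]}"
}
```

For example, `expand_tasks_per_node "4(x2),2"` prints `4 4 2`. The matching host names can be obtained with `scontrol show hostnames "$SLURM_JOB_NODELIST"`, which expands the compressed node list to one name per line.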
### `start_worker.sh`
This script starts a `trtllm-serve disaggregated_mpi_worker`. It is launched by `srun` from the `disaggr_torch.slurm` script on all allocated nodes.
**Arguments:**
1. `config_file`: Path to the `config.yaml` file.
2. `enable_pdl`: `true` or `false` to enable PDL (Programmatic Dependent Launch).
3. `ctx_gpus`: Number of GPUs used for the context phase.
4. `work_dir`: (Optional) Directory to store nsys profiling output.
### `run_benchmark.sh`
This script orchestrates the execution of the benchmark client. It waits for the `config.yaml` to be created and for the server's `/health` endpoint to respond, then it runs the benchmark.
**Arguments:**
1. `isl`: Input sequence length.
2. `osl`: Output sequence length.
3. `multi_round`: Number of rounds for the benchmark.
4. `model_name`: Name of the model being benchmarked.
5. `concurrency_list`: Space-separated list of concurrency levels to test.
6. `streaming`: `true` or `false` to enable streaming responses.
7. `log_path`: Path to the log directory.
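The two waiting steps described above (for `config.yaml` to appear, then for `/health` to respond) can be sketched as simple polling loops with a timeout. The function names, polling intervals, and timeout defaults below are assumptions chosen for illustration, and the actual script may wait differently.

```bash
#!/usr/bin/env bash
# Poll until a file exists, or give up after a timeout (in seconds).
wait_for_file() {
    local path=$1 timeout=${2:-300} waited=0
    until [ -f "$path" ]; do
        if [ "$waited" -ge "$timeout" ]; then
            echo "timed out waiting for $path" >&2
            return 1
        fi
        sleep 1
        waited=$((waited + 1))
    done
}

# Poll a health URL until it returns a successful HTTP status.
# curl -f makes HTTP errors count as failures; -s silences progress output.
wait_for_health() {
    local url=$1 timeout=${2:-1800} interval=10 waited=0
    until curl -sf -o /dev/null "$url"; do
        if [ "$waited" -ge "$timeout" ]; then
            echo "timed out waiting for $url" >&2
            return 1
        fi
        sleep "$interval"
        waited=$((waited + interval))
    done
}
```

A caller would chain the two, for example `wait_for_file config.yaml && wait_for_health "http://${host}:${port}/health"`, where `host` and `port` are hypothetical variables holding the server address. Bounding both waits means a server that fails to start ends the job instead of holding the allocation until the SLURM time limit.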
## Workflow
1. The user runs `./submit.sh`.
2. `submit.sh` submits one or more jobs to SLURM by calling `sbatch disaggr_torch.slurm` with different parameters.
3. For each job, SLURM allocates resources and runs `disaggr_torch.slurm`.
4. `disaggr_torch.slurm` runs `gen_yaml.py` to create a `config.yaml`.
5. `disaggr_torch.slurm` uses `srun` to launch `start_worker.sh` on all nodes, starting the MPI workers.
6. `disaggr_torch.slurm` starts the main `trtllm-serve` process.
7. `disaggr_torch.slurm` runs `run_benchmark.sh` which waits for the server to be ready.
8. `run_benchmark.sh` executes the benchmark for each concurrency level specified.
9. After the benchmark, `run_benchmark.sh` and `disaggr_torch.slurm` attempt to kill the server and worker processes.
10. Logs for each run are stored in a subdirectory specified by the `sub_file` parameter.