Add wide-ep benchmarking scripts (#5760)
Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
Signed-off-by: Xianjie Qiao <5410381+qiaoxj07@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent 089fd55eda
commit b1976c2add

examples/wide_ep/slurm_scripts/README.md (new file, 129 lines)
@@ -0,0 +1,129 @@
# TensorRT-LLM Wide-EP Benchmark Scripts

This directory contains scripts for benchmarking TensorRT-LLM wide-EP performance using the SLURM job scheduler.

## ⚠️ DISCLAIMER

**These scripts are currently not QA'ed and are provided for demonstration purposes only.**

Please note that:

- These scripts have not undergone formal quality assurance testing
- They are intended for demonstration and educational purposes
- Use them at your own risk in production environments
- Always review and test scripts thoroughly before running them in your specific environment

## Scripts Overview

### Core Scripts

1. **`submit.sh`** - Main entry point for submitting benchmark jobs
2. **`disaggr_torch.slurm`** - SLURM job script orchestrating the entire benchmark
3. **`gen_yaml.py`** - Generates configuration files for the serving setup
4. **`start_server.sh`** - Starts the inference server
5. **`start_worker.sh`** - Starts the worker processes
6. **`run_benchmark.sh`** - Executes the benchmark workload
7. **`process_gen_iterlog.py`** - Processes benchmark results and generates reports

## Usage

### Prerequisites

Before running the scripts, ensure you have:

- Access to a SLURM cluster
- A container image with TensorRT-LLM installed
- Model files accessible on the cluster
- The required environment variables set

### Configuration

Edit the following variables in `submit.sh` and `disaggr_torch.slurm`:

```bash
# In disaggr_torch.slurm
container_image=${container_image}  # Your container image
mount_dir=${mount_dir}              # Mount directory path
model_dir=${model_dir}              # Model directory path
```

### Running Benchmarks

1. **Submit benchmark jobs**:
   ```bash
   ./submit.sh
   ```

2. **Monitor job progress**:
   ```bash
   squeue -u $USER
   ```

3. **View results**:
   Results are saved in the `bm_20250703_deepseek-r1-{isl}-{osl}/` directory; an example layout follows.
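For example, with `isl=osl=1024`, a DEP16 run at concurrency 1024 with 256 EPLB slots produces a layout like the sketch below (file names are taken from `disaggr_torch.slurm` and `run_benchmark.sh`; the `dep` prefix becomes `tep` when attention DP is disabled):

```
bm_20250703_deepseek-r1-1024-1024/
└── dep16_concurrency1024_eplb256_mtp0/
    ├── config.yaml             # generated by gen_yaml.py
    ├── moe_load_balancer.yaml  # only when eplb_num_slots > 0
    ├── output_server.log
    ├── output_workers.log
    ├── benchmark.log
    └── concurrency_1024/
        ├── workers_start.log
        ├── gen_only.txt
        └── ctx_only.txt
```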
## Script Details

### `submit.sh`

Main entry script that submits multiple SLURM jobs with different configurations:

- **DEP8**: 8-way parallelism for decode servers
- **DEP16**: 16-way parallelism with different EPLB slot configurations
- **DEP32**: 32-way parallelism for high-throughput scenarios

Parameters tested (an example of the generated `sbatch` command follows):

- Concurrency levels: 1x, 64x, and 1024x multipliers of the parallelism size
- EPLB slots: 0, 256, 288
- Different parallelism sizes
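For reference, this is one of the DEP16 submissions from `submit.sh`; the number of context servers is derived as `ctx_num = ceil(concurrency / 5500)`:

```bash
concurrency=$((64 * 16))
ctx_num=$(((concurrency + 5499) / 5500))
total_gpu_num=$((ctx_num + 4))
total_tasks=$((total_gpu_num * 4))
sbatch --nodes=${total_gpu_num} --ntasks=${total_tasks} --ntasks-per-node=4 \
    --segment=${total_gpu_num} disaggr_torch.slurm \
    ${ctx_num} 4 4 4480 true 1 16 1024 1024 true "0.7" 256 "0" "$concurrency"
```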
### `disaggr_torch.slurm`

SLURM job script that:

1. Sets up the container environment
2. Generates configuration files
3. Starts the server and workers
4. Executes benchmarks
5. Cleans up processes

**Key parameters** (passed positionally; the full argument order is shown below):

- `num_ctx_servers`: Number of context servers
- `ctx_tp_size`: Tensor parallel size for context servers
- `num_gen_servers`: Number of generation servers
- `gen_tp_size`: Tensor parallel size for generation servers
- `concurrency`: Number of concurrent requests
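`disaggr_torch.slurm` reads 14 positional arguments in the following order (in practice `submit.sh` also overrides `--nodes` and `--ntasks`):

```bash
# $1 num_ctx_servers            $8  gen_batch_size
# $2 ctx_tp_size                $9  gen_max_num_tokens
# $3 ctx_batch_size             $10 gen_enable_attention_dp
# $4 ctx_max_num_tokens         $11 gen_gpu_memory_fraction
# $5 ctx_enable_attention_dp    $12 eplb_num_slots
# $6 num_gen_servers            $13 mtp_size
# $7 gen_tp_size                $14 concurrency
sbatch disaggr_torch.slurm 1 4 4 4480 true 1 16 1024 1024 true "0.7" 256 0 1024
```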
### `gen_yaml.py`

Generates YAML configuration files with:

- Server topology and resource allocation
- Network configuration (hostnames, ports)
- Memory and batch size settings
- Optimization parameters (CUDA graphs, KV cache)

**Key features** (an example invocation follows):

- Automatic node and task allocation
- Support for attention data parallelism
- MoE load balancing configuration
- Speculative decoding (MTP) support
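A sketch of a standalone invocation; the script reads `SLURM_JOB_NODELIST` and `SLURM_TASKS_PER_NODE`, so it must run inside a SLURM allocation, and the paths here are placeholders:

```bash
python3 gen_yaml.py --config ./config.yaml --model /path/to/model \
    --num_ctx_servers 1 --ctx_tp_size 4 --ctx_batch_size 4 \
    --ctx_max_num_tokens 4480 --ctx_enable_attention_dp \
    --num_gen_servers 1 --gen_tp_size 16 --gen_batch_size 1024 \
    --gen_max_num_tokens 1024 --gen_enable_attention_dp \
    --gen_gpu_memory_fraction 0.7 --eplb_num_slots 256 --mtp_size 0
```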
### `start_server.sh` & `start_worker.sh`

- **Server**: Starts the main inference server with the API endpoint
- **Workers**: Start the MPI workers for distributed processing
- Support profiling with NSight Systems (see below)
- Configure environment variables for optimizations
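Profiling is toggled from `disaggr_torch.slurm`: when a work directory is passed to `start_worker.sh` as its fifth argument, workers with ranks at or above the context-GPU count wrap `trtllm-serve` in `nsys profile`:

```bash
# In disaggr_torch.slurm
nsys_on=""                # profiling disabled (default)
# nsys_on=${full_logdir}  # uncomment to write nsys_worker_proc_<rank> reports
```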
### `run_benchmark.sh`

Executes benchmarking using TensorRT-LLM's benchmark_serving tool:

- Downloads the ShareGPT dataset for realistic workloads
- Waits for server health checks
- Runs load testing with the specified concurrency
- Collects performance metrics
- Gracefully shuts down services

**Metrics collected** (see the output files listed below):

- Throughput (tokens/second)
- Latency (request completion time)
- Context-only vs. generation-only statistics
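The context/generation split is produced by the `do_get_logs` helper in `run_benchmark.sh`, which slices the worker iteration log per concurrency level:

```
<log_path>/concurrency_<N>/workers_start.log  # worker log snapshot at benchmark start
<log_path>/concurrency_<N>/gen_only.txt       # iterations with no context-phase tokens
<log_path>/concurrency_<N>/ctx_only.txt       # iterations with no generation tokens
```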
### `process_gen_iterlog.py`

Post-processes benchmark results (example invocation below):

- Parses iteration logs from the workers
- Calculates throughput metrics
- Generates CSV reports
- Supports MTP (Multi-Token Prediction) analysis
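It is pointed at a log-directory prefix and writes `<dir_prefix>_iterlog.csv`, plus a per-run CSV next to each `gen_only.txt` it finds:

```bash
python3 process_gen_iterlog.py --dir_prefix bm_20250703_deepseek-r1-1024-1024/dep16
```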
examples/wide_ep/slurm_scripts/disaggr_torch.slurm (new file, 119 lines)

@@ -0,0 +1,119 @@
#!/bin/bash
#SBATCH --nodes=2
#SBATCH --ntasks=8
#SBATCH --ntasks-per-node=4
#SBATCH --partition=${partition}  # add your partition here
#SBATCH --account=${account}      # add your account here
#SBATCH --time=01:00:00
#SBATCH --job-name=${job_name}    # add your job name here

isl=1024
osl=1024
multi_round=1
gen_yaml_file=gen_yaml.py
container_image=${container_image}  # add your container image here
mount_dir=${mount_dir}              # add your mount directory here
workdir=${mount_dir}/bench-large-ep/slurm_scripts/
model_dir=${model_dir}              # add your model directory here
logdir=${workdir}/bm_20250703_deepseek-r1-${isl}-${osl}/
streaming=false
mkdir -p ${logdir}

container_name=disaggr-test

num_ctx_servers=$1
ctx_tp_size=$2
ctx_batch_size=$3
ctx_max_num_tokens=$4
ctx_enable_attention_dp=$5
num_gen_servers=$6
gen_tp_size=$7
gen_batch_size=$8
gen_max_num_tokens=$9
gen_enable_attention_dp=${10}
gen_gpu_memory_fraction=${11}
eplb_num_slots=${12}
mtp_size=${13}
concurrency=${14}

sub_dir=${logdir}/dep${gen_tp_size}_concurrency${concurrency}_eplb${eplb_num_slots}_mtp${mtp_size}

ctx_gpus=$((num_ctx_servers * ctx_tp_size))
gen_gpus=$((num_gen_servers * gen_tp_size))

echo "enable_attention_dp: ${ctx_enable_attention_dp}, ${gen_enable_attention_dp}, gpu_memory_fraction: ${gen_gpu_memory_fraction}"

enable_pdl=false
if [ "${gen_enable_attention_dp}" = "false" ]; then
    enable_pdl=true
    echo "enable_pdl: ${enable_pdl}"
    # TEP runs (no attention DP) get a "tep" prefix so process_gen_iterlog.py can tell them apart
    sub_dir=${logdir}/tep${gen_tp_size}_concurrency${concurrency}_eplb${eplb_num_slots}_mtp${mtp_size}
fi

full_logdir=${sub_dir}
mkdir -p ${full_logdir}

# start the container
srun -l --container-image=${container_image} \
    --container-name=${container_name} \
    --container-mounts=${mount_dir}:${mount_dir} \
    --mpi=pmix \
    echo "Container up."

# generate the yaml file
srun -l --container-name=${container_name} \
    --container-mounts=${mount_dir}:${mount_dir} \
    --mpi=pmix --overlap \
    python3 ${workdir}/${gen_yaml_file} --config ${full_logdir}/config.yaml \
        --model ${model_dir} \
        --num_ctx_servers ${num_ctx_servers} \
        --ctx_tp_size ${ctx_tp_size} \
        --ctx_batch_size ${ctx_batch_size} \
        --ctx_max_num_tokens ${ctx_max_num_tokens} \
        --num_gen_servers ${num_gen_servers} \
        --gen_tp_size ${gen_tp_size} \
        --gen_batch_size ${gen_batch_size} \
        --gen_max_num_tokens ${gen_max_num_tokens} \
        --gen_gpu_memory_fraction ${gen_gpu_memory_fraction} \
        --eplb_num_slots ${eplb_num_slots} \
        $(if [ "${gen_enable_attention_dp}" = "true" ]; then echo "--gen_enable_attention_dp"; fi) \
        $(if [ "${ctx_enable_attention_dp}" = "true" ]; then echo "--ctx_enable_attention_dp"; fi) \
        $(if [ "${mtp_size}" -gt 0 ]; then echo "--mtp_size ${mtp_size}"; fi)

echo "YAML file generated."

hostname_value=$(grep '^hostname:' ${full_logdir}/config.yaml | awk -F': ' '{print $2}')
echo "server host name: $hostname_value"

# try to kill any leftover server and workers
srun -l --container-name=${container_name} \
    --container-mounts=${mount_dir}:${mount_dir} \
    --mpi=pmix --overlap \
    pkill -f "trtllm-serve" || true

nsys_on=""
# nsys_on=${full_logdir}  # uncomment to profile workers with NSight Systems

# start the workers
srun -l --container-name=${container_name} \
    --container-mounts=${mount_dir}:${mount_dir} \
    --mpi=pmix --overlap \
    bash ${workdir}/start_worker.sh ${full_logdir}/config.yaml "${concurrency}" "${enable_pdl}" ${ctx_gpus} ${nsys_on} &> ${full_logdir}/output_workers.log &

# start the server
srun -l --container-name=${container_name} \
    --container-mounts=${mount_dir}:${mount_dir} \
    --mpi=pmix --overlap -N 1 -n 1 \
    -w ${hostname_value} \
    bash ${workdir}/start_server.sh ${full_logdir}/config.yaml &> ${full_logdir}/output_server.log &

# start benchmarking
srun -l --container-name=${container_name} \
    --container-mounts=${mount_dir}:${mount_dir} \
    --mpi=pmix --overlap -N 1 -n 1 \
    bash ${workdir}/run_benchmark.sh ${isl} ${osl} ${multi_round} ${model_dir} "${concurrency}" ${streaming} ${full_logdir}/ > ${full_logdir}/benchmark.log 2>&1

# kill the server and workers
srun -l --container-name=${container_name} \
    --container-mounts=${mount_dir}:${mount_dir} \
    --mpi=pmix --overlap \
    kill -9 $(ps aux | grep '[t]rtllm-serve' | awk '{print $2}') >/dev/null 2>&1 || true
wait
examples/wide_ep/slurm_scripts/gen_yaml.py (new file, 338 lines)

@@ -0,0 +1,338 @@
import argparse
import os
import re
from typing import Dict, List

import yaml


def process_node_and_task() -> tuple[int, List[str], List[str]]:
    """
    Process SLURM node and task environment variables.

    Returns:
        tuple: (max_tasks_per_node, nodes, task_nodes)
    """
    slurm_job_nodelist = os.getenv('SLURM_JOB_NODELIST', '')
    print(f"SLURM_JOB_NODELIST: {slurm_job_nodelist}")
    if not slurm_job_nodelist:
        raise ValueError("Environment variable SLURM_JOB_NODELIST not found.")

    slurm_tasks_per_node = os.getenv('SLURM_TASKS_PER_NODE', '')
    print(f"SLURM_TASKS_PER_NODE: {slurm_tasks_per_node}")
    if not slurm_tasks_per_node:
        raise ValueError(
            "Environment variable SLURM_TASKS_PER_NODE not found.")

    # Generate list of nodes
    if '[' in slurm_job_nodelist:
        # Handle nodelist with range format
        node_prefix = slurm_job_nodelist.split('[')[
            0]  # Extract everything before '['
        node_range = re.search(r'\[(.*?)\]', slurm_job_nodelist).group(1)
        nodes = []
        for part in node_range.split(','):
            if '-' in part:
                start, end = part.split('-')
                # Get the width of the number format from the first number
                width = len(start)
                # Convert to integers after getting the width
                start, end = int(start), int(end)
                # Format numbers with leading zeros
                nodes.extend([
                    f"{node_prefix}{str(i).zfill(width)}"
                    for i in range(start, end + 1)
                ])
            else:
                # Preserve the original format for single numbers
                nodes.append(f"{node_prefix}{part}")
    else:
        # Handle single node format
        nodes = [slurm_job_nodelist]
    print(f"Nodes: {nodes}")

    # Generate tasks per node, expanding the "N(xM)" repetition syntax
    tasks_per_node = []
    for part in slurm_tasks_per_node.split(','):
        if '(x' in part:
            count, repeat = map(int, re.findall(r'\d+', part))
            tasks_per_node.extend([count] * repeat)
        else:
            tasks_per_node.append(int(part))
    print(f"Tasks per node: {tasks_per_node}")

    if len(tasks_per_node) != len(nodes):
        raise ValueError(
            f"Number of nodes and tasks per node do not match. Number of nodes: {len(nodes)}, Number of tasks per node: {len(tasks_per_node)}"
        )

    max_tasks_per_node = max(tasks_per_node)
    task_nodes = []
    for node, tasks in zip(nodes, tasks_per_node):
        task_nodes.extend([node] * tasks)

    return max_tasks_per_node, nodes, task_nodes

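# Hypothetical example of the parsing above: SLURM_JOB_NODELIST="node[01-02]"
# with SLURM_TASKS_PER_NODE="4(x2)" yields nodes=['node01', 'node02'],
# max_tasks_per_node=4, and task_nodes=['node01']*4 + ['node02']*4.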
def generate_urls(ctx_or_gen: str,
                  num_instances: int,
                  tensor_parallel_size: int,
                  pipeline_parallel_size: int,
                  max_tasks_per_node: int,
                  nodes: List[str],
                  task_nodes: List[str],
                  node_to_port: Dict[str, int],
                  task_nodes_offset: int = 0) -> tuple[List[str], int]:
    """
    Generate URLs for context or generation servers.

    Returns:
        tuple: (urls, updated_task_nodes_offset)
    """
    urls = []

    for instance in range(num_instances):
        tasks_needed = tensor_parallel_size * pipeline_parallel_size

        if (task_nodes_offset + tasks_needed) > len(task_nodes):
            print(f"{ctx_or_gen} urls so far: {urls}")
            raise ValueError(
                f"For {ctx_or_gen} instance {instance}, there are not enough tasks available. task_nodes_offset: {task_nodes_offset}, tasks_needed: {tasks_needed}, len(task_nodes): {len(task_nodes)}"
            )

        min_node = (tasks_needed + max_tasks_per_node - 1) // max_tasks_per_node
        instance_nodes = set(task_nodes[task_nodes_offset:task_nodes_offset +
                                        tasks_needed])
        if len(instance_nodes) > min_node:
            raise ValueError(
                f"Tasks for instance {instance} of {ctx_or_gen} use more nodes than expected. Nodes used: {instance_nodes}, number of nodes expected: {min_node}, max_tasks_per_node: {max_tasks_per_node}"
            )

        node = task_nodes[task_nodes_offset]
        port = node_to_port[node]
        node_to_port[node] += 1
        task_nodes_offset += tasks_needed

        urls.append(f"{node}:{port}")

    print(f"{ctx_or_gen} urls: {urls}")
    return urls, task_nodes_offset


def gen_config_file(config_path: str,
                    model_path: str,
                    num_ctx_servers: int,
                    ctx_tp_size: int,
                    ctx_batch_size: int,
                    ctx_max_num_tokens: int,
                    ctx_enable_attention_dp: bool,
                    num_gen_servers: int,
                    gen_tp_size: int,
                    gen_batch_size: int,
                    gen_max_num_tokens: int,
                    gen_enable_attention_dp: bool,
                    gen_gpu_memory_fraction: float,
                    eplb_num_slots: int,
                    mtp_size: int = 0,
                    worker_start_port: int = 8001,
                    server_port: int = 8000) -> None:
    """
    Generate configuration YAML file for disaggregated inference.

    Args:
        config_path: Path to save the config file
        model_path: Path to the model
        num_ctx_servers: Number of context servers
        ctx_tp_size: Tensor parallel size for context servers
        ctx_batch_size: Batch size for context servers
        ctx_max_num_tokens: Max number of tokens for context servers
        ctx_enable_attention_dp: Enable attention DP for context servers
        num_gen_servers: Number of generation servers
        gen_tp_size: Tensor parallel size for generation servers
        gen_batch_size: Batch size for generation servers
        gen_max_num_tokens: Max number of tokens for generation servers
        gen_enable_attention_dp: Enable attention DP for generation servers
        gen_gpu_memory_fraction: GPU memory fraction for generation servers
        eplb_num_slots: Number of slots for EPLB
        mtp_size: Number of next-n predict layers for MTP
        worker_start_port: Start port for workers
        server_port: Server port
    """
    gen_cuda_graph_batch_sizes = [
        1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024, 2048, gen_batch_size
    ]

    gen_moe_backend = "WIDEEP"

    config = {
        'model': model_path,
        'hostname': 'localhost',
        'port': server_port,
        'backend': 'pytorch',
        'context_servers': {
            'num_instances': num_ctx_servers,
            'max_batch_size': ctx_batch_size,
            'max_num_tokens': ctx_max_num_tokens,
            'max_seq_len': 1152,
            'free_gpu_memory_fraction': 0.85,
            'tensor_parallel_size': ctx_tp_size,
            'moe_expert_parallel_size': ctx_tp_size,
            'enable_attention_dp': ctx_enable_attention_dp,
            'pipeline_parallel_size': 1,
            'print_iter_log': True,
            'disable_overlap_scheduler': True,
            'kv_cache_dtype': 'fp8',
            'cache_transceiver_config': {
                'max_num_tokens': 4608,
            },
        },
        'generation_servers': {
            'num_instances': num_gen_servers,
            'tensor_parallel_size': gen_tp_size,
            'moe_expert_parallel_size': gen_tp_size,
            'enable_attention_dp': gen_enable_attention_dp,
            'pipeline_parallel_size': 1,
            'max_batch_size': gen_batch_size,
            'max_num_tokens': gen_max_num_tokens,
            'max_seq_len': 2176,
            'free_gpu_memory_fraction': gen_gpu_memory_fraction,
            'cuda_graph_config': {
                'padding_enabled': True,
                'batch_sizes': gen_cuda_graph_batch_sizes,
            },
            'print_iter_log': True,
            'kv_cache_dtype': 'fp8',
            'moe_backend': gen_moe_backend,
            'cache_transceiver_config': {
                'max_num_tokens': 4608,
            },
        }
    }

    # Process nodes and generate URLs
    max_tasks_per_node, nodes, task_nodes = process_node_and_task()
    node_ports = {node: worker_start_port for node in nodes}

    # Generate URLs for context and generation servers
    ctx_urls, task_nodes_offset = generate_urls("ctx", num_ctx_servers,
                                                ctx_tp_size, 1,
                                                max_tasks_per_node, nodes,
                                                task_nodes, node_ports)
    if num_ctx_servers > 0:
        config['context_servers']['urls'] = ctx_urls

    gen_urls, _ = generate_urls("gen", num_gen_servers, gen_tp_size, 1,
                                max_tasks_per_node, nodes, task_nodes,
                                node_ports, task_nodes_offset)
    config['generation_servers']['urls'] = gen_urls

    # Set the hostname to the first node
    config['hostname'] = nodes[0]

    if eplb_num_slots > 0:
        moe_load_balancer_file = os.path.join(os.path.dirname(config_path),
                                              "moe_load_balancer.yaml")
        moe_load_balancer_config = {
            'num_slots': eplb_num_slots,
            'layer_updates_per_iter': 1
        }
        with open(moe_load_balancer_file, "w") as f:
            yaml.dump(moe_load_balancer_config,
                      f,
                      default_flow_style=False,
                      sort_keys=False)
        config['generation_servers'][
            'moe_load_balancer'] = moe_load_balancer_file

    if mtp_size > 0:
        config['context_servers']['speculative_config'] = {
            'decoding_type': 'MTP',
            'num_nextn_predict_layers': mtp_size
        }
        config['generation_servers']['speculative_config'] = {
            'decoding_type': 'MTP',
            'num_nextn_predict_layers': mtp_size
        }

    # Write config to file
    with open(config_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False, sort_keys=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="/tmp/config.yaml")
    parser.add_argument("--model",
                        type=str,
                        required=True,
                        help="Path to the model")
    parser.add_argument("--num_ctx_servers",
                        type=int,
                        required=True,
                        help="Number of context servers")
    parser.add_argument("--ctx_tp_size",
                        type=int,
                        required=True,
                        help="Tensor parallel size for context servers")
    parser.add_argument("--ctx_batch_size",
                        type=int,
                        required=True,
                        help="Batch size for context servers")
    parser.add_argument("--ctx_max_num_tokens",
                        type=int,
                        required=True,
                        help="Max number of tokens for context servers")
    parser.add_argument("--ctx_enable_attention_dp",
                        dest='ctx_enable_attention_dp',
                        action='store_true',
                        help="Enable attention DP for context servers")
    parser.add_argument("--num_gen_servers",
                        type=int,
                        required=True,
                        help="Number of generation servers")
    parser.add_argument("--gen_tp_size",
                        type=int,
                        required=True,
                        help="Tensor parallel size for generation servers")
    parser.add_argument("--gen_batch_size",
                        type=int,
                        required=True,
                        help="Batch size for generation servers")
    parser.add_argument("--gen_max_num_tokens",
                        type=int,
                        required=True,
                        help="Max number of tokens for generation servers")
    parser.add_argument("--gen_enable_attention_dp",
                        dest='gen_enable_attention_dp',
                        action='store_true',
                        help="Enable attention DP for generation servers")
    parser.add_argument("--gen_gpu_memory_fraction",
                        type=float,
                        required=True,
                        help="GPU memory fraction for generation servers")
    parser.add_argument("--eplb_num_slots",
                        type=int,
                        default=0,
                        help="Number of slots for EPLB")
    parser.add_argument("--mtp_size",
                        type=int,
                        default=0,
                        help="Number of next-n layers for MTP")
    parser.add_argument("--worker_start_port",
                        type=int,
                        default=8336,
                        help="Start port for workers")
    parser.add_argument("--server_port",
                        type=int,
                        default=8333,
                        help="Server port")

    args = parser.parse_args()

    gen_config_file(args.config, args.model, args.num_ctx_servers,
                    args.ctx_tp_size, args.ctx_batch_size,
                    args.ctx_max_num_tokens, args.ctx_enable_attention_dp,
                    args.num_gen_servers, args.gen_tp_size, args.gen_batch_size,
                    args.gen_max_num_tokens, args.gen_enable_attention_dp,
                    args.gen_gpu_memory_fraction, args.eplb_num_slots,
                    args.mtp_size, args.worker_start_port, args.server_port)
examples/wide_ep/slurm_scripts/process_gen_iterlog.py (new file, 154 lines)

@@ -0,0 +1,154 @@
import argparse
import glob
import re

import pandas as pd

mtp_accept_rate = {1: 1.86, 2: 2.42, 3: 2.68}

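# mtp_accept_rate above: assumed average accepted tokens per decoding step for
# MTP depths 1-3; process_files() scales per-user throughput by this factor
# for MTP runs.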
def process_files(dir_prefix):
    summary_data = []
    pattern = f"{dir_prefix}*/concurrency_*/gen_only.txt"
    files = glob.glob(pattern)
    print(f"Found {len(files)} files matching pattern {pattern}")

    for file in files:
        data = []
        # Extract parameter information from the file path:
        # (tep|dep)(number)_concurrency(number)_eplb(number)[_mtp(number)]
        match = re.search(
            r'(tep|dep)(\d+)_concurrency(\d+)_eplb(\d+)(?:_mtp(\d+))?', file)
        if not match:
            continue

        attn_type = match.group(1)
        rank_num = int(match.group(2))
        concurrency = int(match.group(3))
        eplb_num = int(match.group(4))
        mtp_num = int(match.group(5)) if match.group(5) else 0

        # Both "tep" and "dep" folder names encode the EP rank count
        ep_rank = rank_num
        name = f"{attn_type}_{rank_num}_eplb{eplb_num}_mtp{mtp_num}"

        # Read and parse the log file
        try:
            with open(file, 'r', encoding='utf-8', errors='ignore') as f:
                content = f.read()

            # Match the iteration log lines printed by the workers
            log_pattern = r'iter = (\d+), global_rank = (\d+), rank = (\d+), currank_total_requests = (\d+)/(\d+), elapsed_time = ([\d.]+)s, timestamp = ([^,]+), num_scheduled_requests: (\d+), states = \{\'num_ctx_requests\': (\d+), \'num_ctx_tokens\': (\d+), \'num_generation_tokens\': (\d+)\}'

            matches = re.findall(log_pattern, content)

            if matches:
                # Process each matched log line
                for match in matches:
                    iter_num = int(match[0])
                    global_rank = int(match[1])
                    rank = int(match[2])
                    current_requests = int(match[3])
                    total_requests = int(match[4])
                    elapsed_time = float(match[5])
                    timestamp = match[6]
                    num_scheduled_requests = int(match[7])
                    num_ctx_requests = int(match[8])
                    num_ctx_tokens = int(match[9])
                    num_generation_tokens = int(match[10])

                    # Per-iteration generation throughput
                    throughput_per_user = num_generation_tokens / elapsed_time if elapsed_time > 0 else 0

                    data.append({
                        'concurrency': concurrency,
                        'iter': iter_num,
                        'global_rank': global_rank,
                        'rank': rank,
                        'current_requests': current_requests,
                        'total_requests': total_requests,
                        'elapsed_time': elapsed_time,
                        'timestamp': timestamp,
                        'num_scheduled_requests': num_scheduled_requests,
                        'num_ctx_requests': num_ctx_requests,
                        'num_ctx_tokens': num_ctx_tokens,
                        'num_generation_tokens': num_generation_tokens,
                        'throughput_per_user': throughput_per_user
                    })
        except Exception as e:
            print(f"Error processing file {file}: {e}")
            continue

        # If data is not empty, save it to a CSV next to the log file
        if data:
            df = pd.DataFrame(data)
            df = df.sort_values(['concurrency', 'iter'])
            # CSV name shares the log file's prefix
            output_file = file.split('.')[0] + '.csv'

            # Keep only generation-only iterations (no context tokens)
            df = df[df['num_ctx_tokens'] == 0]

            # Drop warmup and tail iterations, then keep steady-state rows only
            df = df.iloc[50:-10]
            if attn_type == 'tep':
                df = df[df['num_scheduled_requests'] == int(concurrency)]
                df = df[df['num_generation_tokens'] == int(concurrency *
                                                           (mtp_num + 1))]
            elif attn_type == 'dep':
                df = df[df['num_scheduled_requests'] == int(concurrency /
                                                            ep_rank)]
                df = df[df['num_generation_tokens'] == int(concurrency /
                                                           ep_rank *
                                                           (mtp_num + 1))]

            df.to_csv(output_file, index=False)
            print(f"Data saved to {output_file}")
            print(f"Total records processed: {len(data)}")

            if df.empty:
                print(f"No valid data found for {file}")
            else:
                # Average iteration time over the steady-state window
                elapsed_time_avg = df['elapsed_time'].mean()
                throughput_per_user = 1 / elapsed_time_avg if elapsed_time_avg > 0 else 0
                throughput_per_user = throughput_per_user * mtp_accept_rate[
                    mtp_num] if mtp_num > 0 else throughput_per_user
                output_throughput = throughput_per_user * concurrency
                throughput_per_gpu = output_throughput / ep_rank
                summary_data.append({
                    'name': name,
                    'concurrency': concurrency,
                    'throughput_per_user': throughput_per_user,
                    'throughput_per_gpu': throughput_per_gpu,
                    'output_throughput': output_throughput,
                    'elapsed_time_avg': elapsed_time_avg,
                    'number_iters': len(df)
                })

    if summary_data:
        # Create a summary DataFrame and sort it
        df = pd.DataFrame(summary_data)
        df = df.sort_values(['name', 'concurrency'])

        # Save as CSV
        output_file = f"{dir_prefix}_iterlog.csv"
        df.to_csv(output_file, index=False)
        print(f"Data saved to {output_file}")
        print(f"Total summary rows: {len(df)}")
    else:
        print("No valid data found to save")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Process benchmark files and aggregate data.')
    parser.add_argument('--dir_prefix',
                        help='Directory prefix to search for benchmark files')
    args = parser.parse_args()
    process_files(args.dir_prefix)
examples/wide_ep/slurm_scripts/run_benchmark.sh (new file, 133 lines)

@@ -0,0 +1,133 @@
#!/bin/bash

# Error handling
set -e
set -u
trap 'echo "Error occurred at line $LINENO"; exit 1' ERR

# Parameter validation
if [ "$#" -lt 7 ]; then
    echo "Error: Missing required arguments"
    echo "Usage: $0 isl osl multi_round model_name concurrency streaming log_path"
    exit 1
fi

isl=$1
osl=$2
multi_round=$3
model_name=$4
concurrency=$5
streaming=$6
log_path=$7

# Only rank 0 runs the load generator
if [[ ${SLURM_PROCID} != "0" ]]; then
    echo "Process id is ${SLURM_PROCID} for loadgen, exiting"
    exit 0
fi

echo "TRT_LLM_GIT_COMMIT: ${TRT_LLM_GIT_COMMIT}"

set -x
config_file=${log_path}/config.yaml

# Wait for the config file, checking every 10 seconds with a 1800-second timeout
timeout=1800
start_time=$(date +%s)
while [ ! -f ${config_file} ]; do
    current_time=$(date +%s)
    elapsed=$((current_time - start_time))
    if [ $elapsed -ge $timeout ]; then
        echo "Error: Config file ${config_file} not found within ${timeout} seconds"
        exit 1
    fi
    if [ $((elapsed % 30)) -eq 0 ]; then
        echo "Waiting for config file... (${elapsed}s elapsed)"
    fi
    sleep 10
done

# Grep the host and port from the config file
hostname=$(grep -i "hostname:" ${config_file} | awk '{print $2}')
port=$(grep -i "port:" ${config_file} | awk '{print $2}')
if [ -z "$hostname" ] || [ -z "$port" ]; then
    echo "Error: Failed to extract hostname or port from config file"
    exit 1
fi
echo "Hostname: ${hostname}, Port: ${port}"

# Download the ShareGPT dataset for benchmarking
shared_gpt_path=/tmp/ShareGPT_V3_unfiltered_cleaned_split.json
if [ ! -f ${shared_gpt_path} ]; then
    echo "Downloading ShareGPT dataset..."
    wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -O ${shared_gpt_path}
fi

# Wait until the server is healthy, checking every 10 seconds with a 1800-second timeout
timeout=1800
start_time=$(date +%s)
while true; do
    status_code=$(curl -s -o /dev/null -w "%{http_code}" http://${hostname}:${port}/health)
    if [ "$status_code" -eq 200 ]; then
        break
    fi
    current_time=$(date +%s)
    elapsed=$((current_time - start_time))
    if [ $elapsed -ge $timeout ]; then
        echo "Error: Server is not healthy after ${timeout} seconds"
        exit 1
    fi
    if [ $((elapsed % 30)) -eq 0 ]; then
        echo "Waiting for server to be healthy... (${elapsed}s elapsed)"
    fi
    sleep 10
done

# Slice the worker log into generation-only and context-only iteration records
do_get_logs(){
    worker_log_path=$1
    output_folder=$2
    grep -a "'num_ctx_requests': 0, 'num_ctx_tokens': 0" ${worker_log_path} > ${output_folder}/gen_only.txt || true
    grep -a "'num_generation_tokens': 0" ${worker_log_path} > ${output_folder}/ctx_only.txt || true
}

# Run the load generator
mkdir -p ${log_path}/concurrency_${concurrency}
cp ${log_path}/output_workers.log ${log_path}/concurrency_${concurrency}/workers_start.log
max_count=$((${concurrency} * ${multi_round}))
echo "Running loadgen with concurrency: ${concurrency}, max_count: ${max_count}"

python -m tensorrt_llm.serve.scripts.benchmark_serving \
    --model ${model_name} \
    --tokenizer ${model_name} \
    --dataset-name random \
    --dataset-path ${shared_gpt_path} \
    --random-input-len ${isl} \
    --random-output-len ${osl} \
    --random-prefix-len 0 \
    --num-prompts ${max_count} \
    --max-concurrency ${concurrency} \
    --host ${hostname} \
    --port ${port} \
    --ignore-eos \
    --no-test-input \
    $(if [ "${streaming}" = "false" ]; then echo "--non-streaming"; fi)

do_get_logs ${log_path}/output_workers.log ${log_path}/concurrency_${concurrency}
# echo "" > ${log_path}/output_workers.log
echo "done for ${concurrency} in folder ${log_path}/concurrency_${concurrency}"

echo "Benchmark done, gracefully shutting down server and workers..."
kill -9 $(ps aux | grep '[s]tart_server.sh' | awk '{print $2}') >/dev/null 2>&1 || true
kill -9 $(ps aux | grep '[s]tart_worker.sh' | awk '{print $2}') >/dev/null 2>&1 || true
kill -9 $(ps aux | grep '[t]rtllm-serve' | awk '{print $2}') >/dev/null 2>&1 || true
sleep 20  # Give processes some time to clean up

# Check whether any processes remain
if pgrep -f "trtllm-serve"; then
    echo "Warning: Some processes may still be running"
else
    echo "All processes successfully terminated"
fi
examples/wide_ep/slurm_scripts/start_server.sh (new file, 34 lines)

@@ -0,0 +1,34 @@
#!/bin/bash

echo "commit id: $TRT_LLM_GIT_COMMIT"
echo "ucx info: $(ucx_info -v)"
echo "hostname: $(hostname)"

hostname=$(hostname)
short_hostname=$(echo "$hostname" | awk -F'.' '{print $1}')
echo "short_hostname: ${short_hostname}"

config_file=$1

# Check and replace the hostname setting in config_file
if [ -f "$config_file" ]; then
    if grep -q "^hostname:" "$config_file"; then
        # Extract the current hostname value from the config
        current_hostname=$(grep "^hostname:" "$config_file" | sed 's/^hostname:[ ]*//' | awk '{print $1}')

        if [ "$current_hostname" != "$short_hostname" ]; then
            echo "Replacing hostname '$current_hostname' with '$short_hostname' in $config_file"
            # Use sed to replace the hostname value
            sed -i "s/hostname:[ ]*[^ ]*/hostname: $short_hostname/" "$config_file"
        else
            echo "Hostname '$current_hostname' already matches '$short_hostname', no change needed"
        fi
    else
        echo "No hostname setting found in $config_file"
    fi
else
    echo "Config file $config_file not found"
fi

trtllm-serve disaggregated -c ${config_file} -t 1800 -r 1800
examples/wide_ep/slurm_scripts/start_worker.sh (new file, 38 lines)

@@ -0,0 +1,38 @@
#!/bin/bash

config_file=$1
concurrency=$2
enable_pdl=$3
ctx_gpus=$4
work_dir=$5
unset UCX_TLS
echo "config_file: ${config_file}, concurrency: ${concurrency}, enable_pdl: ${enable_pdl}, ctx_gpus: ${ctx_gpus}, work_dir: ${work_dir}"

export TLLM_LOG_LEVEL=INFO
export TRTLLM_USE_UCX_KVCACHE=1
export TLLM_BENCHMARK_REQ_QUEUES_SIZE=${concurrency}
export TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1
export TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER=1

if [ "${enable_pdl}" = "true" ]; then
    export TRTLLM_ENABLE_PDL=1
fi

# Check if work_dir is provided; passing one enables nsys profiling
if [ -z "${work_dir}" ]; then
    echo "nsys is not enabled, start normal flow"
    trtllm-serve disaggregated_mpi_worker -c ${config_file}
else
    nsys_prefix=""
    nsys_file=${work_dir}/nsys_worker_proc_${SLURM_PROCID}
    export TLLM_PROFILE_RECORD_GC=1
    export TLLM_NVTX_DEBUG=1
    # Only profile ranks at or above the context GPU count (generation workers)
    if [ "${SLURM_PROCID}" -ge "${ctx_gpus}" ]; then
        export TLLM_PROFILE_START_STOP=200-250
        nsys_prefix="nsys profile -e \"NSYS_MPI_STORE_TEAMS_PER_RANK=1\" -o ${nsys_file} -f true -t cuda,nvtx,python-gil -c cudaProfilerApi --cuda-graph-trace node --capture-range-end=stop --gpu-metrics-devices=none"
        echo "nsys_prefix: ${nsys_prefix}"
    else
        echo "nsys is not enabled on ctx_gpus"
    fi
    ${nsys_prefix} trtllm-serve disaggregated_mpi_worker -c ${config_file}
fi
examples/wide_ep/slurm_scripts/submit.sh (new file, 31 lines)

@@ -0,0 +1,31 @@
#!/bin/bash
mtp_size=0

# dep8
for b in 1 64 1024; do
    concurrency=$((b * 8))
    # One context server per 5500 concurrent requests, rounded up
    ctx_num=$(((concurrency + 5499) / 5500))
    total_gpu_num=$((ctx_num + 2))
    total_tasks=$((total_gpu_num * 4))
    sbatch --nodes=${total_gpu_num} --ntasks=${total_tasks} --ntasks-per-node=4 --segment=${total_gpu_num} disaggr_torch.slurm ${ctx_num} 4 4 4480 true 1 8 1024 1024 true "0.8" 0 "$mtp_size" "$concurrency"
done

# dep16, eplb 0, 256, 288
for b in 1 64 1024; do
    concurrency=$((b * 16))
    ctx_num=$(((concurrency + 5499) / 5500))
    total_gpu_num=$((ctx_num + 4))
    total_tasks=$((total_gpu_num * 4))
    sbatch --nodes=${total_gpu_num} --ntasks=${total_tasks} --ntasks-per-node=4 --segment=${total_gpu_num} disaggr_torch.slurm ${ctx_num} 4 4 4480 true 1 16 1024 1024 true "0.7" 0 "$mtp_size" "$concurrency"
    sbatch --nodes=${total_gpu_num} --ntasks=${total_tasks} --ntasks-per-node=4 --segment=${total_gpu_num} disaggr_torch.slurm ${ctx_num} 4 4 4480 true 1 16 1024 1024 true "0.7" 256 "$mtp_size" "$concurrency"
    sbatch --nodes=${total_gpu_num} --ntasks=${total_tasks} --ntasks-per-node=4 --segment=${total_gpu_num} disaggr_torch.slurm ${ctx_num} 4 4 4480 true 1 16 1024 1024 true "0.7" 288 "$mtp_size" "$concurrency"
done

# dep32, eplb 288
for b in 512; do
    concurrency=$((b * 32))
    ctx_num=$(((concurrency + 5499) / 5500))
    total_gpu_num=$((ctx_num + 8))
    total_tasks=$((total_gpu_num * 4))
    sbatch --nodes=${total_gpu_num} --ntasks=${total_tasks} --ntasks-per-node=4 --segment=${total_gpu_num} disaggr_torch.slurm ${ctx_num} 4 4 4480 true 1 32 1024 1024 true "0.7" 288 "$mtp_size" "$concurrency"
done
tensorrt_llm/serve/scripts/benchmark_serving.py (modified)

@@ -250,37 +250,41 @@ async def benchmark(
     lora_modules: Optional[Iterable[str]],
     extra_body: Optional[dict],
     streaming: bool,
+    no_test_input: bool = False,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
     else:
         raise ValueError(f"Unknown backend: {backend}")
 
-    print("Starting initial single prompt test run...")
-    test_prompt, test_prompt_len, test_output_len = \
-        input_requests[0].prompt, input_requests[0].prompt_len, \
-        input_requests[0].expected_output_len
+    if not no_test_input:
+        print("Starting initial single prompt test run...")
+        test_prompt, test_prompt_len, test_output_len = \
+            input_requests[0].prompt, input_requests[0].prompt_len, \
+            input_requests[0].expected_output_len
 
-    test_input = RequestFuncInput(
-        model=model_id,
-        model_name=model_name,
-        prompt=test_prompt,
-        api_url=api_url,
-        prompt_len=test_prompt_len,
-        output_len=test_output_len,
-        logprobs=logprobs,
-        ignore_eos=ignore_eos,
-        extra_body=extra_body,
-    )
+        test_input = RequestFuncInput(
+            model=model_id,
+            model_name=model_name,
+            prompt=test_prompt,
+            api_url=api_url,
+            prompt_len=test_prompt_len,
+            output_len=test_output_len,
+            logprobs=logprobs,
+            ignore_eos=ignore_eos,
+            extra_body=extra_body,
+        )
 
-    test_output = await request_func(request_func_input=test_input,
-                                     streaming=streaming)
-    if not test_output.success:
-        raise ValueError(
-            "Initial test run failed - Please make sure benchmark arguments "
-            f"are correctly specified. Error: {test_output.error}")
-    else:
-        print("Initial test run completed. Starting main benchmark run...")
+        test_output = await request_func(request_func_input=test_input,
+                                         streaming=streaming)
+        if not test_output.success:
+            raise ValueError(
+                "Initial test run failed - Please make sure benchmark arguments "
+                f"are correctly specified. Error: {test_output.error}")
+        else:
+            print("Initial test run completed. Starting main benchmark run...")
+    else:
+        print("Skipping initial test run. Starting main benchmark run...")
 
     if lora_modules:
         # For each input request, choose a LoRA module at random.

@@ -706,6 +710,7 @@ def main(args: argparse.Namespace):
         lora_modules=args.lora_modules,
         extra_body=sampling_params,
         streaming=not args.non_streaming,
+        no_test_input=args.no_test_input,
     ))
 
     # Save config and results to json

@@ -1111,6 +1116,12 @@ if __name__ == "__main__":
         "launching the server. For each request, the "
         "script chooses a LoRA module at random.")
 
+    parser.add_argument(
+        "--no-test-input",
+        action="store_true",
+        help="Skip initial test run with a single prompt.",
+    )
+
     args = parser.parse_args()
 
     main(args)