mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-16 07:53:55 +08:00
[None][feat] move some disagg script's env configs from bash to submit.py (#10223)
Signed-off-by: Zhenhuan Chen <zhenhuanc@nvidia.com>
This commit is contained in:
parent
f0ca62b175
commit
3d8c1a51bd
@ -4,12 +4,5 @@ set -e
|
||||
set -x
|
||||
|
||||
config_file=$1
|
||||
server_env_var=$2
|
||||
|
||||
# Export server environment variables from config
|
||||
for env_var in ${server_env_var}; do
|
||||
export "${env_var}"
|
||||
echo "Exported: ${env_var}"
|
||||
done
|
||||
|
||||
trtllm-serve disaggregated -c ${config_file} -t 7200 -r 7200
|
||||
|
||||
@ -7,23 +7,16 @@ role=${1}
|
||||
instance_id=${2}
|
||||
model_path=${3}
|
||||
port=${4}
|
||||
benchmark_mode=${5}
|
||||
concurrency=${6}
|
||||
numa_bind=${7}
|
||||
log_dir=${8}
|
||||
enable_nsys=${9}
|
||||
profile_range=${10}
|
||||
config_file=${11}
|
||||
worker_env_var=${12}
|
||||
numa_bind=${5}
|
||||
log_dir=${6}
|
||||
enable_nsys=${7}
|
||||
config_file=${8}
|
||||
|
||||
# Clear UCX_TLS for specific clusters
|
||||
unset UCX_TLS
|
||||
echo "SLURM_PROCID: ${SLURM_PROCID}, hostname: $(hostname), instance_id: ${instance_id}"
|
||||
|
||||
# Export worker environment variables from config
|
||||
for env_var in ${worker_env_var}; do
|
||||
export "${env_var}"
|
||||
echo "Exported: ${env_var}"
|
||||
done
|
||||
echo "SLURM_PROCID: ${SLURM_PROCID}, hostname: $(hostname), instance_id: ${instance_id}"
|
||||
echo "CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES}"
|
||||
|
||||
if [ "${numa_bind}" = "true" ]; then
|
||||
numa_bind_cmd="numactl -m 0,1"
|
||||
@ -33,10 +26,6 @@ else
|
||||
echo "Not binding memory. If on GB200/GB300 NVL72, use \"numactl -m 0,1\" to only allocate memory from nodes."
|
||||
fi
|
||||
|
||||
if [ "${benchmark_mode}" = "gen_only" ]; then
|
||||
export TLLM_BENCHMARK_REQ_QUEUES_SIZE=${concurrency}
|
||||
fi
|
||||
|
||||
echo "config_file: ${config_file}"
|
||||
|
||||
nsys_prefix=""
|
||||
@ -44,11 +33,7 @@ if [ "${enable_nsys}" != "true" ]; then
|
||||
echo "nsys is not enabled, start normal flow"
|
||||
else
|
||||
nsys_file=${log_dir}/nsys_worker_proc_${role}_${instance_id}_${SLURM_PROCID}
|
||||
export TLLM_PROFILE_RECORD_GC=1
|
||||
export TLLM_NVTX_DEBUG=1
|
||||
export NSYS_MPI_STORE_TEAMS_PER_RANK=1
|
||||
export TLLM_PROFILE_START_STOP=${profile_range}
|
||||
echo "nsys is enabled on ${role} GPUs, TLLM_PROFILE_START_STOP=${profile_range}"
|
||||
echo "nsys is enabled on ${role} GPUs, TLLM_PROFILE_START_STOP=${TLLM_PROFILE_START_STOP}"
|
||||
nsys_prefix="nsys profile -o ${nsys_file} -f true -t cuda,nvtx,python-gil -c cudaProfilerApi --cuda-graph-trace node --capture-range-end=stop --gpu-metrics-devices=none"
|
||||
fi
|
||||
|
||||
|
||||
@ -113,6 +113,7 @@ def convert_allocations_to_server_config(allocations, server_port=8333):
|
||||
generation_servers = {}
|
||||
context_servers = {}
|
||||
server_hostname = None
|
||||
|
||||
for server_type in allocations.keys():
|
||||
num_servers = len(allocations[server_type])
|
||||
urls = []
|
||||
@ -120,13 +121,16 @@ def convert_allocations_to_server_config(allocations, server_port=8333):
|
||||
instance = allocations[server_type][server_id]
|
||||
urls.append(
|
||||
f"{list(instance['nodes'].keys())[0]}:{instance['port']}")
|
||||
|
||||
server_config_entry = {'num_instances': num_servers, 'urls': urls}
|
||||
|
||||
if server_type == "GEN":
|
||||
generation_servers = {'num_instances': num_servers, 'urls': urls}
|
||||
generation_servers = server_config_entry
|
||||
server_hostname = urls[0].split(':')[0]
|
||||
if allocations[server_type][server_id]['port'] == server_port:
|
||||
server_port += 1 # Avoid port conflict
|
||||
elif server_type == "CTX":
|
||||
context_servers = {'num_instances': num_servers, 'urls': urls}
|
||||
context_servers = server_config_entry
|
||||
|
||||
server_config = {
|
||||
'backend': 'pytorch',
|
||||
@ -168,6 +172,121 @@ def replace_env_in_file(log_dir, file_path, env_var):
|
||||
return tmp_dir
|
||||
|
||||
|
||||
def build_worker_environment(worker_config, env_config, role, benchmark_mode,
|
||||
nsys_on, profile_range, concurrency, gpu_ids):
|
||||
"""Build complete environment dictionary for worker processes.
|
||||
|
||||
Args:
|
||||
worker_config: Worker configuration dict
|
||||
env_config: Environment configuration dict
|
||||
role: Server role ("CTX" or "GEN")
|
||||
benchmark_mode: Benchmark mode string
|
||||
nsys_on: Whether nsys profiling is enabled
|
||||
profile_range: Profile range string (e.g., "10-30")
|
||||
concurrency: Concurrency level
|
||||
gpu_ids: List of GPU IDs assigned to this worker
|
||||
|
||||
Returns:
|
||||
Dictionary of environment variables
|
||||
|
||||
Note:
|
||||
CUDA_VISIBLE_DEVICES is NOT set here. It is passed as an argument to
|
||||
start_worker.sh and set per-rank based on SLURM_LOCALID.
|
||||
"""
|
||||
env = {}
|
||||
|
||||
# 1. Use gpu_ids to set CUDA_VISIBLE_DEVICES
|
||||
cuda_devices = ','.join(map(str, gpu_ids))
|
||||
env["CUDA_VISIBLE_DEVICES"] = cuda_devices
|
||||
|
||||
# 2. Parse user-defined worker env vars from config
|
||||
worker_env_var = env_config.get('worker_env_var', '')
|
||||
for var_string in worker_env_var.split():
|
||||
if '=' in var_string:
|
||||
key, val = var_string.split('=', 1)
|
||||
env[key] = val
|
||||
|
||||
# 3. Add role-specific env vars (CTX or GEN)
|
||||
role_env_vars = {
|
||||
"CTX": env_config.get('ctx_worker_env_var', ''),
|
||||
"GEN": env_config.get('gen_worker_env_var', '')
|
||||
}
|
||||
role_specific_env_var = role_env_vars.get(role, '')
|
||||
for var_string in role_specific_env_var.split():
|
||||
if '=' in var_string:
|
||||
key, val = var_string.split('=', 1)
|
||||
env[key] = val
|
||||
|
||||
# 4. Add mode-based env vars
|
||||
if benchmark_mode == "gen_only_no_context":
|
||||
env["TRTLLM_DISAGG_BENCHMARK_GEN_ONLY"] = "1"
|
||||
if benchmark_mode == "gen_only":
|
||||
env["TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP"] = "1"
|
||||
if role == "GEN":
|
||||
env["TLLM_BENCHMARK_REQ_QUEUES_SIZE"] = str(concurrency)
|
||||
|
||||
# 5. Add profiling env vars (conditional)
|
||||
if nsys_on:
|
||||
env["TLLM_PROFILE_RECORD_GC"] = "1"
|
||||
env["TLLM_NVTX_DEBUG"] = "1"
|
||||
env["NSYS_MPI_STORE_TEAMS_PER_RANK"] = "1"
|
||||
env["TLLM_PROFILE_START_STOP"] = profile_range
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def build_server_environment(env_config, benchmark_mode):
|
||||
"""Build complete environment dictionary for server process.
|
||||
|
||||
Args:
|
||||
env_config: Environment configuration dict
|
||||
benchmark_mode: Benchmark mode string
|
||||
|
||||
Returns:
|
||||
Dictionary of environment variables
|
||||
"""
|
||||
env = {}
|
||||
|
||||
# Parse user-defined server env vars
|
||||
server_env_var = env_config.get('server_env_var', '')
|
||||
for var_string in server_env_var.split():
|
||||
if '=' in var_string:
|
||||
key, val = var_string.split('=', 1)
|
||||
env[key] = val
|
||||
|
||||
# Add mode-based env vars
|
||||
if benchmark_mode == "gen_only_no_context":
|
||||
env["TRTLLM_DISAGG_BENCHMARK_GEN_ONLY"] = "1"
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def format_export_string(env_dict):
|
||||
"""Convert environment dictionary to srun --export format.
|
||||
|
||||
Args:
|
||||
env_dict: Dictionary of environment variables
|
||||
|
||||
Returns:
|
||||
String formatted for srun --export flag (e.g., "KEY1=val1,KEY2=val2")
|
||||
Returns "NONE" if no variables specified.
|
||||
|
||||
Note:
|
||||
Values containing commas are quoted to avoid conflicts with srun's delimiter.
|
||||
"""
|
||||
if not env_dict:
|
||||
return "NONE"
|
||||
|
||||
export_list = []
|
||||
for k, v in env_dict.items():
|
||||
# srun cannot handle values that contain commas
|
||||
if ',' in v:
|
||||
export_list.append(f"'{k}={v}'")
|
||||
else:
|
||||
export_list.append(f"{k}={v}")
|
||||
return ",".join(export_list)
|
||||
|
||||
|
||||
def save_env_file(env_file, server_env_var, worker_env_var, ctx_worker_env_var,
|
||||
gen_worker_env_var):
|
||||
|
||||
@ -226,16 +345,6 @@ def submit_job(config, log_dir, dry_run):
|
||||
env_config.setdefault('worker_env_var', '')
|
||||
env_config.setdefault('server_env_var', '')
|
||||
|
||||
worker_env_var = env_config.get('worker_env_var', '')
|
||||
ctx_worker_env_var = env_config.get('ctx_worker_env_var', '')
|
||||
gen_worker_env_var = env_config.get('gen_worker_env_var', '')
|
||||
server_env_var = env_config.get('server_env_var', '')
|
||||
if benchmark_config['mode'] == "gen_only_no_context":
|
||||
worker_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1"
|
||||
server_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1"
|
||||
if benchmark_config['mode'] == "gen_only":
|
||||
worker_env_var += " TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1"
|
||||
|
||||
profiling_config = config.get('profiling', {})
|
||||
profiling_config.setdefault('nsys_on', False)
|
||||
profiling_config.setdefault('ctx_profile_range', '10-30')
|
||||
@ -320,7 +429,11 @@ def submit_job(config, log_dir, dry_run):
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
print(f"Log will be saved to: {log_dir}")
|
||||
|
||||
# Save environment variables
|
||||
# Save environment variables (for record-keeping only)
|
||||
worker_env_var = env_config.get('worker_env_var', '')
|
||||
ctx_worker_env_var = env_config.get('ctx_worker_env_var', '')
|
||||
gen_worker_env_var = env_config.get('gen_worker_env_var', '')
|
||||
server_env_var = env_config.get('server_env_var', '')
|
||||
save_env_file(os.path.join(log_dir, "env_vars.json"), server_env_var,
|
||||
worker_env_var, ctx_worker_env_var, gen_worker_env_var)
|
||||
|
||||
@ -353,24 +466,52 @@ def submit_job(config, log_dir, dry_run):
|
||||
start_server_cmds = []
|
||||
container_mount_str = env_config['container_mount']
|
||||
container_mount_str += f",{script_dir}:{script_dir}"
|
||||
|
||||
# Pre-define server-type-specific configurations
|
||||
server_configs = {
|
||||
"GEN": {
|
||||
"world_size": gen_world_size,
|
||||
"profile_range": profiling_config['gen_profile_range'],
|
||||
"config_path": gen_config_path
|
||||
},
|
||||
"CTX": {
|
||||
"world_size": ctx_world_size,
|
||||
"profile_range": profiling_config['ctx_profile_range'],
|
||||
"config_path": ctx_config_path
|
||||
}
|
||||
}
|
||||
|
||||
# Generate start worker commands with placeholder hostnames
|
||||
for server_type in allocations.keys():
|
||||
server_cfg = server_configs[server_type]
|
||||
|
||||
for server_id in allocations[server_type].keys():
|
||||
allocation = allocations[server_type][server_id]
|
||||
cuda_devices = ",".join([
|
||||
str(device) for device in list(allocation["nodes"].values())[0]
|
||||
])
|
||||
cur_worker_env_var = worker_env_var + \
|
||||
f" CUDA_VISIBLE_DEVICES={cuda_devices}" + \
|
||||
(f" {ctx_worker_env_var}" if server_type == "CTX" else "") + \
|
||||
(f" {gen_worker_env_var}" if server_type == "GEN" else "")
|
||||
# Get GPU IDs for this server from allocation
|
||||
# When multi-node, all nodes have same device list, so use first node [0]
|
||||
gpu_ids = list(allocation["nodes"].values())[0]
|
||||
|
||||
# Build environment for this worker
|
||||
worker_env = build_worker_environment(
|
||||
worker_config=worker_config,
|
||||
env_config=env_config,
|
||||
role=server_type,
|
||||
benchmark_mode=benchmark_config['mode'],
|
||||
nsys_on=profiling_config['nsys_on'],
|
||||
profile_range=server_cfg['profile_range'],
|
||||
concurrency=benchmark_config['concurrency_list'].split(',')[0],
|
||||
gpu_ids=gpu_ids,
|
||||
)
|
||||
export_str = format_export_string(worker_env)
|
||||
|
||||
# Use script_dir for start_worker.sh
|
||||
cmd = [
|
||||
"srun -l",
|
||||
f"--nodelist {','.join(allocation['nodes'].keys())}",
|
||||
f"-N {len(allocation['nodes'])}",
|
||||
f"--ntasks {gen_world_size if server_type == 'GEN' else ctx_world_size}",
|
||||
f"--ntasks {server_cfg['world_size']}",
|
||||
f"--ntasks-per-node {gpus_per_node}",
|
||||
f"--export=\"{export_str}\"",
|
||||
f"--container-image {env_config['container_image']}",
|
||||
f"--container-name {container_name}",
|
||||
f"--container-mounts {container_mount_str}",
|
||||
@ -380,28 +521,27 @@ def submit_job(config, log_dir, dry_run):
|
||||
str(server_id),
|
||||
env_config['model_path'],
|
||||
str(allocation["port"]),
|
||||
benchmark_config['mode'],
|
||||
f"'{benchmark_config['concurrency_list']}'",
|
||||
str(slurm_config['numa_bind']).lower(),
|
||||
log_dir,
|
||||
str(profiling_config['nsys_on']).lower(),
|
||||
f"'{profiling_config['gen_profile_range']}'" if server_type
|
||||
== "GEN" else f"'{profiling_config['ctx_profile_range']}'",
|
||||
gen_config_path if server_type == "GEN" else ctx_config_path,
|
||||
f"'{cur_worker_env_var}'",
|
||||
server_cfg['config_path'],
|
||||
f"&> {log_dir}/3_output_{server_type}_{server_id}.log &",
|
||||
]
|
||||
start_server_cmds.append(" ".join(cmd))
|
||||
|
||||
# Generate start server commands (use script_dir for start_server.sh)
|
||||
server_env = build_server_environment(env_config, benchmark_config['mode'])
|
||||
export_str = format_export_string(server_env)
|
||||
|
||||
cmd = [
|
||||
"srun -l",
|
||||
f"--nodelist {disagg_server_hostname}",
|
||||
f"--container-name={container_name}",
|
||||
f"--export=\"{export_str}\"",
|
||||
f"--container-image={env_config['container_image']}",
|
||||
f"--container-mounts={container_mount_str}",
|
||||
f"--no-container-mount-home --mpi=pmix --overlap -N 1 -n 1",
|
||||
f"bash {os.path.join(script_dir, 'start_server.sh')} {os.path.join(log_dir, 'server_config.yaml')} \"{server_env_var}\"",
|
||||
f"bash {os.path.join(script_dir, 'start_server.sh')} {os.path.join(log_dir, 'server_config.yaml')}",
|
||||
f"&> {log_dir}/4_output_server.log &",
|
||||
]
|
||||
start_server_cmds.append(" ".join(cmd))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user