[None][feat] move some disagg script's env configs from bash to submit.py (#10223)

Signed-off-by: Zhenhuan Chen <zhenhuanc@nvidia.com>
Zhenhuan Chen 2026-02-04 17:32:04 +08:00 committed by GitHub
parent f0ca62b175
commit 3d8c1a51bd
3 changed files with 176 additions and 58 deletions

View File: start_server.sh

@@ -4,12 +4,5 @@ set -e
set -x
config_file=$1
server_env_var=$2
# Export server environment variables from config
for env_var in ${server_env_var}; do
export "${env_var}"
echo "Exported: ${env_var}"
done
trtllm-serve disaggregated -c ${config_file} -t 7200 -r 7200
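With the export loop removed, start_server.sh no longer receives an env-var string; submit.py now injects the variables through srun's --export flag instead. For illustration, a minimal Python sketch of the "KEY=val" parsing that replaces the bash loop (the variable values are placeholders):

    def parse_env_string(env_string):
        """Parse a space-separated 'KEY=val' string into a dict."""
        env = {}
        for var_string in env_string.split():
            if '=' in var_string:
                key, val = var_string.split('=', 1)
                env[key] = val
        return env

    # parse_env_string("TLLM_LOG_LEVEL=INFO NCCL_DEBUG=WARN")
    # -> {'TLLM_LOG_LEVEL': 'INFO', 'NCCL_DEBUG': 'WARN'}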

View File: start_worker.sh

@@ -7,23 +7,16 @@ role=${1}
instance_id=${2}
model_path=${3}
port=${4}
benchmark_mode=${5}
concurrency=${6}
numa_bind=${7}
log_dir=${8}
enable_nsys=${9}
profile_range=${10}
config_file=${11}
worker_env_var=${12}
numa_bind=${5}
log_dir=${6}
enable_nsys=${7}
config_file=${8}
# Clear UCX_TLS for specific clusters
unset UCX_TLS
echo "SLURM_PROCID: ${SLURM_PROCID}, hostname: $(hostname), instance_id: ${instance_id}"
# Export worker environment variables from config
for env_var in ${worker_env_var}; do
export "${env_var}"
echo "Exported: ${env_var}"
done
echo "SLURM_PROCID: ${SLURM_PROCID}, hostname: $(hostname), instance_id: ${instance_id}"
echo "CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES}"
if [ "${numa_bind}" = "true" ]; then
numa_bind_cmd="numactl -m 0,1"
@@ -33,10 +26,6 @@ else
echo "Not binding memory. If on GB200/GB300 NVL72, use \"numactl -m 0,1\" to only allocate memory from nodes."
fi
if [ "${benchmark_mode}" = "gen_only" ]; then
export TLLM_BENCHMARK_REQ_QUEUES_SIZE=${concurrency}
fi
echo "config_file: ${config_file}"
nsys_prefix=""
@@ -44,11 +33,7 @@ if [ "${enable_nsys}" != "true" ]; then
echo "nsys is not enabled, start normal flow"
else
nsys_file=${log_dir}/nsys_worker_proc_${role}_${instance_id}_${SLURM_PROCID}
export TLLM_PROFILE_RECORD_GC=1
export TLLM_NVTX_DEBUG=1
export NSYS_MPI_STORE_TEAMS_PER_RANK=1
export TLLM_PROFILE_START_STOP=${profile_range}
echo "nsys is enabled on ${role} GPUs, TLLM_PROFILE_START_STOP=${profile_range}"
echo "nsys is enabled on ${role} GPUs, TLLM_PROFILE_START_STOP=${TLLM_PROFILE_START_STOP}"
nsys_prefix="nsys profile -o ${nsys_file} -f true -t cuda,nvtx,python-gil -c cudaProfilerApi --cuda-graph-trace node --capture-range-end=stop --gpu-metrics-devices=none"
fi
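start_worker.sh now takes 8 positional arguments instead of 12: benchmark_mode, concurrency, profile_range, and worker_env_var are gone, and their effects arrive pre-resolved through the environment. A hypothetical invocation under the new signature (all paths and values are placeholders):

    worker_args = [
        "GEN",                        # role: "CTX" or "GEN"
        "0",                          # instance_id
        "/models/my-model",           # model_path
        "8333",                       # port
        "true",                       # numa_bind
        "/tmp/logs",                  # log_dir
        "false",                      # enable_nsys
        "/tmp/logs/gen_config.yaml",  # config_file
    ]
    print("bash start_worker.sh " + " ".join(worker_args))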

View File: submit.py

@@ -113,6 +113,7 @@ def convert_allocations_to_server_config(allocations, server_port=8333):
generation_servers = {}
context_servers = {}
server_hostname = None
for server_type in allocations.keys():
num_servers = len(allocations[server_type])
urls = []
@@ -120,13 +121,16 @@ def convert_allocations_to_server_config(allocations, server_port=8333):
instance = allocations[server_type][server_id]
urls.append(
f"{list(instance['nodes'].keys())[0]}:{instance['port']}")
server_config_entry = {'num_instances': num_servers, 'urls': urls}
if server_type == "GEN":
generation_servers = {'num_instances': num_servers, 'urls': urls}
generation_servers = server_config_entry
server_hostname = urls[0].split(':')[0]
if allocations[server_type][server_id]['port'] == server_port:
server_port += 1 # Avoid port conflict
elif server_type == "CTX":
context_servers = {'num_instances': num_servers, 'urls': urls}
context_servers = server_config_entry
server_config = {
'backend': 'pytorch',
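For context, a hypothetical server_config this function could produce for one CTX and one GEN server (hostnames are placeholders; the context_servers/generation_servers layout is assumed from trtllm-serve's disaggregated config format):

    server_config = {
        'backend': 'pytorch',
        'context_servers': {'num_instances': 1, 'urls': ['node-01:8001']},
        'generation_servers': {'num_instances': 1, 'urls': ['node-02:8333']},
    }
    # Because the GEN worker on node-02 already occupies 8333, the
    # port-conflict check above would move the disagg server to 8334.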
@@ -168,6 +172,121 @@ def replace_env_in_file(log_dir, file_path, env_var):
return tmp_dir
def build_worker_environment(worker_config, env_config, role, benchmark_mode,
nsys_on, profile_range, concurrency, gpu_ids):
"""Build complete environment dictionary for worker processes.
Args:
worker_config: Worker configuration dict
env_config: Environment configuration dict
role: Server role ("CTX" or "GEN")
benchmark_mode: Benchmark mode string
nsys_on: Whether nsys profiling is enabled
profile_range: Profile range string (e.g., "10-30")
concurrency: Concurrency level
gpu_ids: List of GPU IDs assigned to this worker
Returns:
Dictionary of environment variables
Note:
CUDA_VISIBLE_DEVICES is set here from gpu_ids; every rank of the
worker receives the same device list via srun --export.
"""
env = {}
# 1. Use gpu_ids to set CUDA_VISIBLE_DEVICES
cuda_devices = ','.join(map(str, gpu_ids))
env["CUDA_VISIBLE_DEVICES"] = cuda_devices
# 2. Parse user-defined worker env vars from config
worker_env_var = env_config.get('worker_env_var', '')
for var_string in worker_env_var.split():
if '=' in var_string:
key, val = var_string.split('=', 1)
env[key] = val
# 3. Add role-specific env vars (CTX or GEN)
role_env_vars = {
"CTX": env_config.get('ctx_worker_env_var', ''),
"GEN": env_config.get('gen_worker_env_var', '')
}
role_specific_env_var = role_env_vars.get(role, '')
for var_string in role_specific_env_var.split():
if '=' in var_string:
key, val = var_string.split('=', 1)
env[key] = val
# 4. Add mode-based env vars
if benchmark_mode == "gen_only_no_context":
env["TRTLLM_DISAGG_BENCHMARK_GEN_ONLY"] = "1"
if benchmark_mode == "gen_only":
env["TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP"] = "1"
if role == "GEN":
env["TLLM_BENCHMARK_REQ_QUEUES_SIZE"] = str(concurrency)
# 5. Add profiling env vars (conditional)
if nsys_on:
env["TLLM_PROFILE_RECORD_GC"] = "1"
env["TLLM_NVTX_DEBUG"] = "1"
env["NSYS_MPI_STORE_TEAMS_PER_RANK"] = "1"
env["TLLM_PROFILE_START_STOP"] = profile_range
return env
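A usage sketch against the function as shown above (the config values are invented, and worker_config is accepted but not consulted by the logic shown here):

    env = build_worker_environment(
        worker_config={},
        env_config={'worker_env_var': 'TLLM_LOG_LEVEL=INFO',
                    'gen_worker_env_var': 'NCCL_DEBUG=WARN'},
        role='GEN',
        benchmark_mode='gen_only',
        nsys_on=False,
        profile_range='10-30',
        concurrency=32,
        gpu_ids=[0, 1, 2, 3],
    )
    # env == {'CUDA_VISIBLE_DEVICES': '0,1,2,3',
    #         'TLLM_LOG_LEVEL': 'INFO',
    #         'NCCL_DEBUG': 'WARN',
    #         'TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP': '1',
    #         'TLLM_BENCHMARK_REQ_QUEUES_SIZE': '32'}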
def build_server_environment(env_config, benchmark_mode):
"""Build complete environment dictionary for server process.
Args:
env_config: Environment configuration dict
benchmark_mode: Benchmark mode string
Returns:
Dictionary of environment variables
"""
env = {}
# Parse user-defined server env vars
server_env_var = env_config.get('server_env_var', '')
for var_string in server_env_var.split():
if '=' in var_string:
key, val = var_string.split('=', 1)
env[key] = val
# Add mode-based env vars
if benchmark_mode == "gen_only_no_context":
env["TRTLLM_DISAGG_BENCHMARK_GEN_ONLY"] = "1"
return env
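The server-side counterpart is smaller; for example (placeholder value):

    env = build_server_environment(
        env_config={'server_env_var': 'TLLM_LOG_LEVEL=INFO'},
        benchmark_mode='gen_only_no_context',
    )
    # env == {'TLLM_LOG_LEVEL': 'INFO',
    #         'TRTLLM_DISAGG_BENCHMARK_GEN_ONLY': '1'}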
def format_export_string(env_dict):
"""Convert environment dictionary to srun --export format.
Args:
env_dict: Dictionary of environment variables
Returns:
String formatted for srun --export flag (e.g., "KEY1=val1,KEY2=val2")
Returns "NONE" if no variables specified.
Note:
Values containing commas are quoted to avoid conflicts with srun's delimiter.
"""
if not env_dict:
return "NONE"
export_list = []
for k, v in env_dict.items():
# srun cannot handle values that contain commas
if ',' in v:
export_list.append(f"'{k}={v}'")
else:
export_list.append(f"{k}={v}")
return ",".join(export_list)
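The comma handling matters because CUDA_VISIBLE_DEVICES almost always contains commas; a quick illustration:

    print(format_export_string({
        'CUDA_VISIBLE_DEVICES': '0,1,2,3',
        'TLLM_LOG_LEVEL': 'INFO',
    }))
    # 'CUDA_VISIBLE_DEVICES=0,1,2,3',TLLM_LOG_LEVEL=INFO
    print(format_export_string({}))
    # NONE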
def save_env_file(env_file, server_env_var, worker_env_var, ctx_worker_env_var,
gen_worker_env_var):
@@ -226,16 +345,6 @@ def submit_job(config, log_dir, dry_run):
env_config.setdefault('worker_env_var', '')
env_config.setdefault('server_env_var', '')
worker_env_var = env_config.get('worker_env_var', '')
ctx_worker_env_var = env_config.get('ctx_worker_env_var', '')
gen_worker_env_var = env_config.get('gen_worker_env_var', '')
server_env_var = env_config.get('server_env_var', '')
if benchmark_config['mode'] == "gen_only_no_context":
worker_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1"
server_env_var += " TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1"
if benchmark_config['mode'] == "gen_only":
worker_env_var += " TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1"
profiling_config = config.get('profiling', {})
profiling_config.setdefault('nsys_on', False)
profiling_config.setdefault('ctx_profile_range', '10-30')
@@ -320,7 +429,11 @@ def submit_job(config, log_dir, dry_run):
os.makedirs(log_dir, exist_ok=True)
print(f"Log will be saved to: {log_dir}")
# Save environment variables
# Save environment variables (for record-keeping only)
worker_env_var = env_config.get('worker_env_var', '')
ctx_worker_env_var = env_config.get('ctx_worker_env_var', '')
gen_worker_env_var = env_config.get('gen_worker_env_var', '')
server_env_var = env_config.get('server_env_var', '')
save_env_file(os.path.join(log_dir, "env_vars.json"), server_env_var,
worker_env_var, ctx_worker_env_var, gen_worker_env_var)
@@ -353,24 +466,52 @@ def submit_job(config, log_dir, dry_run):
start_server_cmds = []
container_mount_str = env_config['container_mount']
container_mount_str += f",{script_dir}:{script_dir}"
# Pre-define server-type-specific configurations
server_configs = {
"GEN": {
"world_size": gen_world_size,
"profile_range": profiling_config['gen_profile_range'],
"config_path": gen_config_path
},
"CTX": {
"world_size": ctx_world_size,
"profile_range": profiling_config['ctx_profile_range'],
"config_path": ctx_config_path
}
}
# Generate start worker commands with placeholder hostnames
for server_type in allocations.keys():
server_cfg = server_configs[server_type]
for server_id in allocations[server_type].keys():
allocation = allocations[server_type][server_id]
cuda_devices = ",".join([
str(device) for device in list(allocation["nodes"].values())[0]
])
cur_worker_env_var = worker_env_var + \
f" CUDA_VISIBLE_DEVICES={cuda_devices}" + \
(f" {ctx_worker_env_var}" if server_type == "CTX" else "") + \
(f" {gen_worker_env_var}" if server_type == "GEN" else "")
# Get GPU IDs for this server from allocation
# In the multi-node case every node has the same device list, so take the first node's entry [0]
gpu_ids = list(allocation["nodes"].values())[0]
# Build environment for this worker
worker_env = build_worker_environment(
worker_config=worker_config,
env_config=env_config,
role=server_type,
benchmark_mode=benchmark_config['mode'],
nsys_on=profiling_config['nsys_on'],
profile_range=server_cfg['profile_range'],
concurrency=benchmark_config['concurrency_list'].split(',')[0],
gpu_ids=gpu_ids,
)
export_str = format_export_string(worker_env)
# Use script_dir for start_worker.sh
cmd = [
"srun -l",
f"--nodelist {','.join(allocation['nodes'].keys())}",
f"-N {len(allocation['nodes'])}",
f"--ntasks {gen_world_size if server_type == 'GEN' else ctx_world_size}",
f"--ntasks {server_cfg['world_size']}",
f"--ntasks-per-node {gpus_per_node}",
f"--export=\"{export_str}\"",
f"--container-image {env_config['container_image']}",
f"--container-name {container_name}",
f"--container-mounts {container_mount_str}",
@@ -380,28 +521,27 @@ def submit_job(config, log_dir, dry_run):
str(server_id),
env_config['model_path'],
str(allocation["port"]),
benchmark_config['mode'],
f"'{benchmark_config['concurrency_list']}'",
str(slurm_config['numa_bind']).lower(),
log_dir,
str(profiling_config['nsys_on']).lower(),
f"'{profiling_config['gen_profile_range']}'" if server_type
== "GEN" else f"'{profiling_config['ctx_profile_range']}'",
gen_config_path if server_type == "GEN" else ctx_config_path,
f"'{cur_worker_env_var}'",
server_cfg['config_path'],
f"&> {log_dir}/3_output_{server_type}_{server_id}.log &",
]
start_server_cmds.append(" ".join(cmd))
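Assembled, one such worker launch line looks roughly like the following (hostname, container image, mounts, and paths are all placeholders):

    example_cmd = " ".join([
        "srun -l",
        "--nodelist node-02",
        "-N 1",
        "--ntasks 4",
        "--ntasks-per-node 4",
        "--export=\"'CUDA_VISIBLE_DEVICES=0,1,2,3',TLLM_LOG_LEVEL=INFO\"",
        "--container-image my_image.sqsh",
        "--container-name disagg_bench",
        "--container-mounts /models:/models,/scripts:/scripts",
        "bash /scripts/start_worker.sh",
        "GEN 0 /models/my-model 8333 true /tmp/logs false",
        "/tmp/logs/gen_config.yaml",
        "&> /tmp/logs/3_output_GEN_0.log &",
    ])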
# Generate start server commands (use script_dir for start_server.sh)
server_env = build_server_environment(env_config, benchmark_config['mode'])
export_str = format_export_string(server_env)
cmd = [
"srun -l",
f"--nodelist {disagg_server_hostname}",
f"--container-name={container_name}",
f"--export=\"{export_str}\"",
f"--container-image={env_config['container_image']}",
f"--container-mounts={container_mount_str}",
f"--no-container-mount-home --mpi=pmix --overlap -N 1 -n 1",
f"bash {os.path.join(script_dir, 'start_server.sh')} {os.path.join(log_dir, 'server_config.yaml')} \"{server_env_var}\"",
f"bash {os.path.join(script_dir, 'start_server.sh')} {os.path.join(log_dir, 'server_config.yaml')}",
f"&> {log_dir}/4_output_server.log &",
]
start_server_cmds.append(" ".join(cmd))
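All four env-var keys read above come from the env_config section of the job config. A hypothetical section, shown as the dict submit.py reads (every value is a placeholder):

    env_config = {
        'container_image': 'my_image.sqsh',
        'container_mount': '/models:/models',
        'model_path': '/models/my-model',
        'worker_env_var': 'TLLM_LOG_LEVEL=INFO',  # all workers
        'ctx_worker_env_var': '',                 # CTX workers only
        'gen_worker_env_var': '',                 # GEN workers only
        'server_env_var': '',                     # disagg server only
    }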