#!/usr/bin/env python3

import argparse
import os

import yaml
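
# Example invocation in CI (the script name, paths, and file names below are
# illustrative only; the flags match the argparse definitions in main()):
#   python3 <this_script>.py \
#       --draft-launch-sh draft-launch.sh --launch-sh launch.sh \
#       --run-sh slurm_run.sh --install-sh slurm_install.sh \
#       --llm-src /path/to/TensorRT-LLM --test-list test_list.txt \
#       --script-prefix script_prefix.sh --srun-args srun_args.txt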


def get_hardware_config(config, benchmark_mode):
    hardware = config.get("hardware", {})
    worker_config = config.get("worker_config", {})

    num_ctx_servers = 0 if "gen_only" in benchmark_mode else hardware.get("num_ctx_servers")
    num_gen_servers = hardware.get("num_gen_servers")
    gpus_per_node = hardware.get("gpus_per_node")

    # Derive gpus_per_ctx_server and gpus_per_gen_server from the worker_config
    # parallelism settings (tensor * pipeline * context parallel sizes).
    ctx_config = worker_config.get("ctx", {})
    gen_config = worker_config.get("gen", {})
    ctx_tp = ctx_config.get("tensor_parallel_size", 1)
    ctx_pp = ctx_config.get("pipeline_parallel_size", 1)
    ctx_cp = ctx_config.get("context_parallel_size", 1)
    gpus_per_ctx_server = ctx_tp * ctx_pp * ctx_cp
    gen_tp = gen_config.get("tensor_parallel_size", 1)
    gen_pp = gen_config.get("pipeline_parallel_size", 1)
    gen_cp = gen_config.get("context_parallel_size", 1)
    gpus_per_gen_server = gen_tp * gen_pp * gen_cp

    if None in [
        num_ctx_servers,
        num_gen_servers,
        gpus_per_node,
        gpus_per_ctx_server,
        gpus_per_gen_server,
    ]:
        raise ValueError("Missing required hardware configuration")

    # Calculate nodes per server (ceiling division over gpus_per_node).
    nodes_per_ctx_server = (gpus_per_ctx_server + gpus_per_node - 1) // gpus_per_node
    nodes_per_gen_server = (gpus_per_gen_server + gpus_per_node - 1) // gpus_per_node

    gpus_per_node_per_ctx_server = min(gpus_per_ctx_server, gpus_per_node)
    gpus_per_node_per_gen_server = min(gpus_per_gen_server, gpus_per_node)

    total_nodes = num_ctx_servers * nodes_per_ctx_server + num_gen_servers * nodes_per_gen_server
    total_gpus = total_nodes * gpus_per_node

    return {
        "num_ctx_servers": num_ctx_servers,
        "num_gen_servers": num_gen_servers,
        "gpus_per_node": gpus_per_node,
        "gpus_per_ctx_server": gpus_per_ctx_server,
        "gpus_per_gen_server": gpus_per_gen_server,
        "nodes_per_ctx_server": nodes_per_ctx_server,
        "nodes_per_gen_server": nodes_per_gen_server,
        "gpus_per_node_per_ctx_server": gpus_per_node_per_ctx_server,
        "gpus_per_node_per_gen_server": gpus_per_node_per_gen_server,
        "total_nodes": total_nodes,
        "total_gpus": total_gpus,
    }
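
# A minimal worked example of the node math above (values are hypothetical, not
# taken from any shipped config): with gpus_per_node=4, a gen worker at
# tensor_parallel_size=8 needs ceil(8 / 4) = 2 nodes and uses 4 GPUs per node,
# while a ctx worker at tensor_parallel_size=4 fits on 1 node. With one server
# of each kind, total_nodes = 1 * 1 + 1 * 2 = 3 and total_gpus = 3 * 4 = 12.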


def get_env_config(config):
    env = config.get("environment", {})

    container = env.get("container_image", "")
    mounts = env.get("container_mount", "")
    workdir = env.get("container_workdir", "")
    llm_models_root = env.get("llm_models_root", "")
    llmsrc = env.get("trtllm_repo", "")
    build_wheel = env.get("build_wheel", False)
    # Use work_dir as job_workspace
    job_workspace = env.get("work_dir", "")
    worker_env_var = env.get("worker_env_var", "")
    server_env_var = env.get("server_env_var", "")
    benchmark_env_var = env.get("benchmark_env_var", "")
    open_search_db_base_url = env.get("open_search_db_base_url", "")

    return {
        "container": container,
        "mounts": mounts,
        "workdir": workdir,
        "llm_models_root": llm_models_root,
        "llmsrc": llmsrc,
        "build_wheel": build_wheel,
        "job_workspace": job_workspace,
        "worker_env_var": worker_env_var,
        "server_env_var": server_env_var,
        "benchmark_env_var": benchmark_env_var,
        "open_search_db_base_url": open_search_db_base_url,
    }
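
# Illustrative "environment" section consumed above (the key names are the ones
# read by this function; all values here are hypothetical):
#   environment:
#     container_image: nvcr.io/nvidia/tensorrt-llm:latest
#     container_mount: /lustre:/lustre
#     work_dir: /lustre/jobs/disagg-perf-sanity
#     trtllm_repo: /lustre/TensorRT-LLM
#     worker_env_var: "TLLM_LOG_LEVEL=INFO"
# Missing keys fall back to empty strings (or False for build_wheel).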


def get_benchmark_config(config):
    benchmark = config.get("benchmark", {})

    mode = benchmark.get("mode", "e2e")
    # concurrency_list is expected to hold a single value here; it may arrive as
    # either a string or an int depending on how the YAML was written.
    concurrency_str = benchmark.get("concurrency_list", "1")
    concurrency = int(concurrency_str) if isinstance(concurrency_str, str) else concurrency_str

    return {
        "mode": mode,
        "concurrency": concurrency,
    }
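
# For example (hypothetical values), a YAML section
#   benchmark:
#     mode: gen_only
#     concurrency_list: "32"
# would yield {"mode": "gen_only", "concurrency": 32}.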


def remove_whitespace_lines(lines):
    return [line.strip() for line in lines if line.strip()]


def get_pytest_commands(script_prefix_lines):
    # Derive the worker, disagg_server, and benchmark pytest commands from the
    # exported pytestCommand line:
    #   - worker: keeps trtllm-llmapi-launch, drops --csv/--cov/--periodic flags.
    #   - disagg_server: drops trtllm-llmapi-launch and --csv/--cov/--periodic flags.
    #   - benchmark: drops trtllm-llmapi-launch, keeps --csv/--cov/--periodic flags.
    pytest_command_line = None
    for line in script_prefix_lines:
        if "export pytestCommand=" in line:
            pytest_command_line = line
            break

    if not pytest_command_line:
        return "", "", ""

    def split_pytest_command_line(command_line):
        # After pytest, there are six types of substrings:
        # Type 1: --xxx=yyy (long option with value, self-contained)
        # Type 2: --xxx=    (long option with empty value, self-contained)
        # Type 3: --xxx     (long option flag, no value)
        # Type 4: --xxx yyy (long option with value as next arg)
        # Type 5: -x yyy    (short single-letter option with value as next arg)
        # Type 6: -x        (short option flag, e.g., -v, -vv)
        parts = command_line.split()
        pytest_index = None
        for idx, part in enumerate(parts):
            if part == "pytest":
                pytest_index = idx
                break
        if pytest_index is None:
            return parts

        grouped_parts = parts[: pytest_index + 1]
        i = pytest_index + 1
        while i < len(parts):
            part = parts[i]
            has_next = i + 1 < len(parts)
            next_is_value = has_next and not parts[i + 1].startswith("-")

            # Type 1 & 2: --xxx=yyy or --xxx= (self-contained, has '=')
            if part.startswith("--") and "=" in part:
                grouped_parts.append(part)
                i += 1
                continue

            # Type 4: --xxx yyy (long option with value as next arg)
            if part.startswith("--") and next_is_value:
                grouped_parts.append(f"{part} {parts[i + 1]}")
                i += 2
                continue

            # Type 3: --xxx (long option flag)
            if part.startswith("--"):
                grouped_parts.append(part)
                i += 1
                continue

            # Type 5: -x yyy (short single-letter option with value as next arg)
            # Only a single letter after the dash, e.g., -o, not -vv
            if part.startswith("-") and len(part) == 2 and next_is_value:
                grouped_parts.append(f"{part} {parts[i + 1]}")
                i += 2
                continue

            # Type 6: -x (short option flag, including combined flags like -vv)
            if part.startswith("-"):
                grouped_parts.append(part)
                i += 1
                continue

            # Other parts (shouldn't appear after pytest, but handle gracefully)
            grouped_parts.append(part)
            i += 1

        return grouped_parts

    def is_llmapi_launch(part):
        return "trtllm-llmapi-launch" in part

    def is_output_file_part(part):
        return any(flag in part for flag in ("--csv", "--cov", "--periodic"))

    worker_line = pytest_command_line.replace("pytestCommand", "partialPytestCommandWorker")
    worker_parts = [
        part for part in split_pytest_command_line(worker_line) if not is_output_file_part(part)
    ]
    worker_pytest_command = " ".join(worker_parts)

    disagg_server_line = pytest_command_line.replace(
        "pytestCommand", "partialPytestCommandDisaggServer"
    )
    disagg_server_parts = [
        part
        for part in split_pytest_command_line(disagg_server_line)
        if not is_llmapi_launch(part) and not is_output_file_part(part)
    ]
    disagg_server_pytest_command = " ".join(disagg_server_parts)

    benchmark_line = pytest_command_line.replace("pytestCommand", "partialPytestCommandBenchmark")
    benchmark_parts = [
        part for part in split_pytest_command_line(benchmark_line) if not is_llmapi_launch(part)
    ]
    benchmark_pytest_command = " ".join(benchmark_parts)

    return (
        worker_pytest_command,
        disagg_server_pytest_command,
        benchmark_pytest_command,
    )
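
# A hedged illustration of the grouping (hypothetical command line, not taken
# from the real script prefix):
#   split_pytest_command_line("pytest -v -o junit_family=xunit1 --timeout=3600 --csv report.csv perf/test_perf.py")
# groups into
#   ['pytest', '-v', '-o junit_family=xunit1', '--timeout=3600', '--csv report.csv', 'perf/test_perf.py']
# so a value-carrying flag such as "--csv report.csv" can be kept or dropped as one unit.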


def get_config_yaml(test_list_path, llm_src):
    with open(test_list_path, "r") as f:
        first_line = f.readline().strip()

    if "[" not in first_line or "]" not in first_line:
        raise ValueError(
            f"Invalid test list format. Expected test name with brackets: {first_line}"
        )
    bracket_content = first_line.split("[")[-1].split("]")[0]
    parts = bracket_content.split("-")
    if len(parts) < 2:
        raise ValueError(
            f"Invalid test name format. Expected format: prefix-config_name, got: {bracket_content}"
        )

    # parts[0] is the prefix, parts[1:] is the config name
    if "disagg" not in parts[0]:
        raise ValueError(
            f"Invalid test name format. Expected format: disagg-config_name, got: {bracket_content}"
        )
    config_base_name = "-".join(parts[1:])
    config_yaml_path = os.path.join(
        llm_src,
        "tests",
        "integration",
        "defs",
        "perf",
        "disagg",
        "test_configs",
        "disagg",
        "perf-sanity",
        f"{config_base_name}.yaml",
    )
    if not os.path.exists(config_yaml_path):
        raise FileNotFoundError(f"Config file not found: {config_yaml_path}")
    return config_yaml_path
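
# Illustrative mapping (hypothetical test id): if the first line of the test list
# were "perf/test_perf.py::test_disagg[disagg-example_model_config]", the bracket
# content "disagg-example_model_config" passes the "disagg" prefix check, giving
# config_base_name "example_model_config", which resolves to
#   <llm_src>/tests/integration/defs/perf/disagg/test_configs/disagg/perf-sanity/example_model_config.yaml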


def main():
    parser = argparse.ArgumentParser(
        description="Generate SLURM launch script for both CI and local modes"
    )
    parser.add_argument(
        "--run-ci",
        action="store_true",
        default=False,
        help="Run in CI mode (true) or local mode (false)",
    )
    parser.add_argument("--draft-launch-sh", required=True, help="Path to draft-launch.sh script")
    parser.add_argument("--launch-sh", required=True, help="Path to output launch.sh script")
    parser.add_argument("--run-sh", required=True, help="Path to slurm_run.sh script")
    parser.add_argument("--install-sh", required=True, help="Path to slurm_install.sh script")

    # Optional arguments for local mode
    parser.add_argument("--config-yaml", default="", help="Path to config YAML file")
    parser.add_argument("--stage-name", default="", help="Stage name (optional, local mode only)")

    # Optional arguments for CI mode
    parser.add_argument("--llm-src", default="", help="Path to LLM source code")
    parser.add_argument("--test-list", default="", help="Path to test list file")
    parser.add_argument(
        "--script-prefix",
        default="",
        help="Launch script prefix file path (optional, CI mode only)",
    )
    parser.add_argument(
        "--srun-args",
        default="",
        help="Path to file containing srun args (optional, CI mode only)",
    )

    args = parser.parse_args()

    config_yaml = get_config_yaml(args.test_list, args.llm_src)

    with open(config_yaml, "r") as f:
        config = yaml.safe_load(f)

    # Determine install script path
    install_script = args.install_sh

    env_config = get_env_config(config)
    print(f"Environment configuration: {env_config}")

    benchmark_config = get_benchmark_config(config)
    print(f"Benchmark configuration: {benchmark_config}")
    benchmark_mode = benchmark_config["mode"]

    hardware_config = get_hardware_config(config, benchmark_mode)
    print(f"Hardware configuration: {hardware_config}")

    script_prefix_lines = []
    srun_args_lines = []

    with open(args.script_prefix, "r") as f:
        script_prefix_content = f.read()
    script_prefix_lines = script_prefix_content.split("\n")

    with open(args.srun_args, "r") as f:
        srun_args_content = f.read()
    srun_args_lines = srun_args_content.split()

    # Extract pytestCommand and generate partial pytest commands
    (
        worker_pytest_command,
        disagg_server_pytest_command,
        benchmark_pytest_command,
    ) = get_pytest_commands(script_prefix_lines)

    # Build worker env vars; add extra env vars for gen_only mode
    worker_env_vars = env_config["worker_env_var"]
    server_env_vars = env_config["server_env_var"]
    if "gen_only" in benchmark_config["mode"]:
        concurrency = benchmark_config["concurrency"]
        worker_env_vars = (
            "TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1 "
            "TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1 "
            f"TLLM_BENCHMARK_REQ_QUEUES_SIZE={concurrency} {worker_env_vars}"
        )
        server_env_vars = f"TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1 {server_env_vars}"
        script_prefix_lines.append("export TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1")
        srun_args_lines.append("--container-env=TRTLLM_DISAGG_BENCHMARK_GEN_ONLY")

    script_prefix_lines.extend(
        [
            worker_pytest_command,
            disagg_server_pytest_command,
            benchmark_pytest_command,
            f'export pytestCommandWorker="unset UCX_TLS && {worker_env_vars} $partialPytestCommandWorker"',
            f'export pytestCommandDisaggServer="{server_env_vars} $partialPytestCommandDisaggServer"',
            f'export pytestCommandBenchmark="{env_config["benchmark_env_var"]} $partialPytestCommandBenchmark"',
            f"export runScript={args.run_sh}",
            f"export installScript={install_script}",
            f"export configYamlPath={config_yaml}",
            f"export numCtxServers={hardware_config['num_ctx_servers']}",
            f"export numGenServers={hardware_config['num_gen_servers']}",
            f"export gpusPerNode={hardware_config['gpus_per_node']}",
            f"export gpusPerCtxServer={hardware_config['gpus_per_ctx_server']}",
            f"export gpusPerGenServer={hardware_config['gpus_per_gen_server']}",
            f"export nodesPerCtxServer={hardware_config['nodes_per_ctx_server']}",
            f"export nodesPerGenServer={hardware_config['nodes_per_gen_server']}",
            f"export gpusPerfNodePerfCtxServer={hardware_config['gpus_per_node_per_ctx_server']}",
            f"export gpusPerfNodePerfGenServer={hardware_config['gpus_per_node_per_gen_server']}",
            f"export totalNodes={hardware_config['total_nodes']}",
            f"export totalGpus={hardware_config['total_gpus']}",
        ]
    )
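
    # For illustration only (hypothetical values): after this extend, the script
    # prefix contains lines such as
    #   export numGenServers=1
    #   export gpusPerGenServer=8
    #   export totalNodes=3
    # which are presumably picked up by the downstream run/install scripts.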

    # remove_whitespace_lines() returns a filtered copy rather than mutating in
    # place, so re-assign its result.
    script_prefix_lines = remove_whitespace_lines(script_prefix_lines)
    script_prefix = "\n".join(script_prefix_lines)

    srun_args_lines = remove_whitespace_lines(srun_args_lines)
    srun_args_lines.extend(
        [
            "--container-env=DISAGG_SERVING_TYPE",
            "--container-env=pytestCommand",
        ]
    )
    srun_args_lines = ["srunArgs=("] + [f' "{line}"' for line in srun_args_lines] + [")"]
    srun_args = "\n".join(srun_args_lines)

    with open(args.draft_launch_sh, "r") as f:
        draft_launch_content = f.read()
    draft_launch_lines = draft_launch_content.split("\n")
    draft_launch_lines = remove_whitespace_lines(draft_launch_lines)
    draft_launch_content = "\n".join(draft_launch_lines)

    with open(args.launch_sh, "w") as f:
        f.write(f"{script_prefix}\n{srun_args}\n{draft_launch_content}")

    print(f"Launch script generated at: {args.launch_sh}")
    print(f"Launch script:\n{script_prefix}\n{srun_args}\n{draft_launch_content}")


if __name__ == "__main__":
    main()