TensorRT-LLM/examples/layer_wise_benchmarks/run_single.py

import argparse
import numpy as np
import nvtx
import torch
import yaml
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls
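# Example invocation (hypothetical paths/values; the script is typically launched with one
# MPI rank per GPU, e.g. via mpirun):
#   mpirun -n 4 python run_single.py config.yaml --run-type GEN \
#       --batch-size 32 --seq-len-q 1 --seq-len-kv-cache 4096 --use-cuda-graph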


def comma_separated_ints(s):
    return [int(x) for x in s.split(",")]


# Parse cmdline
parser = argparse.ArgumentParser()
parser.add_argument("config_path", type=str)
parser.add_argument("--model", type=str, help="Pretrained model name or path")
parser.add_argument(
    "--layer-indices",
    type=comma_separated_ints,
    help="Comma separated indices of layers, should be a contiguous range",
)
parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"])
parser.add_argument("--scaled-from", type=int)
# KV cache related args
parser.add_argument("--max-batch-size", type=int)
parser.add_argument("--tokens-per-block", type=int)
parser.add_argument("--max-seq-len", type=int)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--enable-attention-dp", action="store_true", dest="enable_attention_dp")
group.add_argument("--no-enable-attention-dp", action="store_false", dest="enable_attention_dp")
parser.set_defaults(enable_attention_dp=None)
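# Note: defaulting the flag pair to None (rather than False) lets the YAML config supply the
# value when neither --enable-attention-dp nor --no-enable-attention-dp is given; the same
# tri-state pattern is used for --use-cuda-graph below.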
# Model init args
parser.add_argument("--max-num-tokens", type=int)
parser.add_argument("--moe-backend", type=str)
parser.add_argument("--moe-max-num-tokens", type=int)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--use-cuda-graph", action="store_true", dest="use_cuda_graph")
group.add_argument("--no-use-cuda-graph", action="store_false", dest="use_cuda_graph")
parser.set_defaults(use_cuda_graph=None)
# Per iteration args
parser.add_argument("--batch-size", type=int)
parser.add_argument("--seq-len-q", type=int)
parser.add_argument("--seq-len-kv-cache", type=int)
parser.add_argument("--balance-method", type=str)
parser.add_argument("--balance-ratio", type=float)
args = parser.parse_args()
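# Merge with the YAML config: any argument still None after parsing falls back to the value
# in config_path, then max_batch_size / max_num_tokens are derived if still unset.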
with open(args.config_path) as f:
    config = yaml.safe_load(f)
del args.config_path
for k, v in vars(args).items():
    if v is None and k in config:
        setattr(args, k, config[k])
if args.max_batch_size is None:
    args.max_batch_size = args.batch_size
if args.max_num_tokens is None:
    args.max_num_tokens = args.max_batch_size * args.seq_len_q
print(args)
# MPI args
rank = mpi_rank()
world_size = mpi_world_size()
local_rank = local_mpi_rank()
torch.cuda.set_device(local_rank)
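# Each process binds to the GPU that matches its local MPI rank; world_size appears as "EP"
# in the NVTX range name below.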
# Create KV cache manager
Runner = get_runner_cls(args.model)
mapping = Runner.create_mapping(enable_attention_dp=args.enable_attention_dp)
kv_cache_manager = Runner.create_kv_cache_manager(
    args.model,
    mapping,
    tokens_per_block=args.tokens_per_block,
    max_batch_size=args.max_batch_size,
    max_seq_len=args.max_seq_len,
    layer_indices=args.layer_indices,
)
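# The KV cache manager only covers the benchmarked layers (layer_indices) and is sized from
# max_batch_size, max_seq_len and tokens_per_block.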
attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)
# Create other global objects
AutoTuner.get().clear_cache()
capture_stream = torch.cuda.Stream()
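# capture_stream is a side stream used both for warm-up and (optionally) for CUDA graph
# capture; the autotuner cache is cleared so kernel tuning happens during the warm-up below.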
# Create Runner
runner = Runner(
    args.model,
    mapping,
    moe_backend=args.moe_backend,
    layer_indices=args.layer_indices,
    scaled_from=args.scaled_from,
    max_seq_len=args.max_seq_len,
    max_num_tokens=args.max_num_tokens,
    moe_max_num_tokens=args.moe_max_num_tokens,
    use_cuda_graph=args.use_cuda_graph,
)
# Warm up
assert args.batch_size <= args.max_batch_size
assert args.seq_len_q + args.seq_len_kv_cache <= args.max_seq_len
run_pack = runner.create_run_pack(
    args.run_type,
    batch_size=args.batch_size,
    seq_len_q=args.seq_len_q,
    seq_len_kv_cache=args.seq_len_kv_cache,
    kv_cache_manager=kv_cache_manager,
    attn_workspace=attn_workspace,
)
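# run_pack is a callable that executes one forward pass over the selected layers with the
# requested batch shape; the same pack is reused for warm-up, graph capture, and timing.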
runner.replace_routing_method(
    balance_method=BalanceMethod[args.balance_method], balance_ratio=args.balance_ratio
)
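# Presumably this swaps the model's MoE routing for a synthetic token distribution controlled
# by balance_method / balance_ratio, so expert load (im)balance can be benchmarked
# deterministically rather than depending on real router outputs.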
capture_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(capture_stream):
    run_pack()
    with autotune():
        run_pack()
torch.cuda.current_stream().wait_stream(capture_stream)
torch.cuda.synchronize()
# Profile: capture graph and replay it
torch.cuda.cudart().cudaProfilerStart()
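# Capture a single iteration into a CUDA graph on the already warmed-up capture_stream;
# with_multi_stream(True) presumably enables TRT-LLM's multi-stream execution so auxiliary
# streams are recorded as part of the graph.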
if args.use_cuda_graph:
    with with_multi_stream(True):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, stream=capture_stream, capture_error_mode="global"):
            run_pack()
warmup_times = 20
run_times = 100
events = [torch.cuda.Event(enable_timing=True) for _ in range(warmup_times + run_times + 1)]
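# Record warmup_times + run_times + 1 events back to back: the elapsed time between events
# i and i + 1 is the latency of iteration i, and the first warmup_times intervals are
# discarded when computing statistics.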
for i in range(warmup_times + run_times):
    events[i].record()
    with nvtx.annotate(f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"):
        if args.use_cuda_graph:
            g.replay()
        else:
            run_pack()
events[-1].record()
torch.cuda.synchronize()
# Print statistics
# Print before `cudaProfilerStop` to ensure messages are included in the profile
time_list = [start.elapsed_time(stop) for start, stop in zip(events, events[1:])]
time_list = time_list[warmup_times:]
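# Event.elapsed_time returns milliseconds, so the factor of 1000 below reports microseconds.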
print(
    f"[RANK {rank}]"
    f" min {np.min(time_list) * 1000:.1f}"
    f" max {np.max(time_list) * 1000:.1f}"
    f" mean {np.mean(time_list) * 1000:.1f}"
    f" median {np.median(time_list) * 1000:.1f}"
    f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
    f" (us)"
)
torch.cuda.cudart().cudaProfilerStop()