TensorRT-LLM/examples/layer_wise_benchmarks/run.py
Yukun He 9e7182b603
[TRTLLM-9615][feat] Implement a distributed tuning system (#9621)
Four distinct strategies are implemented to accommodate different distributed tuning scenarios: BROADCAST, INDEPENDENT, MERGE, and PARALLEL.

* Distributed tuning is disabled by default, with the INDEPENDENT strategy as the fallback. This conservative approach prevents unexpected behavior in standard use cases.
* Only operations with significant tuning time overhead have been assigned the PARALLEL strategy, which distributes tactic tuning for the same operation across tensor parallelism (TP) ranks so they tune concurrently. This targeted approach balances performance gains with stability.
* Operations with nested tuning structures, such as NVFP4GemmUnifiedRunner, currently support only the INDEPENDENT strategy. This restriction exists because the synchronization mechanism is optimized only for leaf operations and doesn't yet handle nested hierarchies.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
2025-12-15 21:08:53 +08:00


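# Layer-wise benchmark driver: builds a runner for a slice of model layers, optionally autotunes
# and captures CUDA graphs, then times every (batch_size, seq_len_q, seq_len_kv_cache,
# balance_ratio) combination with CUDA events. One process drives one GPU, so the script is
# normally launched under MPI. Illustrative command (config path and flag values are examples
# only):
#   mpirun -n 4 python run.py config.yaml --run-type GEN --batch-size 1,8 --seq-len-q 1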
import argparse
import itertools
import json
import os
import numpy as np
import nvtx
import torch
import yaml
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
from tensorrt_llm.logger import logger
from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls, mark_ranges
def comma_separated_ints(s):
    return [int(x) for x in s.split(",")]


def comma_separated_floats(s):
    return [float(x) for x in s.split(",")]

# Parse cmdline
parser = argparse.ArgumentParser()
parser.add_argument("config_path", type=str)
parser.add_argument("--model", type=str, help="Pretrained model name or path")
parser.add_argument(
"--layer-indices",
type=comma_separated_ints,
help="Comma separated indices of layers, should be a contiguous range",
)
parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"])
parser.add_argument("--scaled-from", type=int)
# KV cache related args
parser.add_argument("--max-batch-size", type=int)
parser.add_argument("--tokens-per-block", type=int)
parser.add_argument("--max-seq-len", type=int)
group = parser.add_mutually_exclusive_group()
group.add_argument("--enable-attention-dp", action="store_true", dest="enable_attention_dp")
group.add_argument("--no-enable-attention-dp", action="store_false", dest="enable_attention_dp")
parser.set_defaults(enable_attention_dp=None)
# Model init args
parser.add_argument("--max-num-tokens", type=int)
parser.add_argument("--moe-backend", type=str)
parser.add_argument("--moe-max-num-tokens", type=int)
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--use-low-precision-moe-combine", action="store_true", dest="use_low_precision_moe_combine"
)
group.add_argument(
    "--no-use-low-precision-moe-combine",
    action="store_false",
    dest="use_low_precision_moe_combine",
)
parser.set_defaults(use_low_precision_moe_combine=None)
group = parser.add_mutually_exclusive_group()
group.add_argument("--enable-autotuner", action="store_true", dest="enable_autotuner")
group.add_argument("--no-enable-autotuner", action="store_false", dest="enable_autotuner")
parser.set_defaults(enable_autotuner=None)
group = parser.add_mutually_exclusive_group()
group.add_argument("--use-cuda-graph", action="store_true", dest="use_cuda_graph")
group.add_argument("--no-use-cuda-graph", action="store_false", dest="use_cuda_graph")
parser.set_defaults(use_cuda_graph=None)
# Per iteration args
parser.add_argument("--batch-size", type=comma_separated_ints, dest="batch_size_list")
parser.add_argument("--seq-len-q", type=comma_separated_ints, dest="seq_len_q_list")
parser.add_argument("--seq-len-kv-cache", type=comma_separated_ints, dest="seq_len_kv_cache_list")
parser.add_argument("--balance-method", type=str)
parser.add_argument("--balance-ratio", type=comma_separated_floats, dest="balance_ratio_list")
# Schedule
parser.add_argument("--warmup-times", type=int, default=20)
parser.add_argument("--run-times", type=int, default=100)
args = parser.parse_args()
# Load YAML file
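# Command-line values take precedence; any option left unset falls back to the YAML config.
# "*_list" options accept either a scalar or a list in YAML (scalars and None are wrapped into a
# one-element list), and YAML keys that do not correspond to a known option are rejected.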
with open(args.config_path) as f:
    config = yaml.safe_load(f)
del args.config_path
for k, v in vars(args).items():
    if k.endswith("_list"):
        config_key = k[: -len("_list")]
        if v is None and config_key in config:
            v = config[config_key]
        if isinstance(v, list):
            pass
        elif v is None or isinstance(v, (int, float)):
            v = [v]
        else:
            raise ValueError(f'Config "{config_key}" in YAML should be a value or a list')
        setattr(args, k, v)
    else:
        config_key = k
        if v is None and config_key in config:
            v = config[config_key]
        setattr(args, k, v)
    if config_key in config:
        del config[config_key]
if config:
    raise ValueError(f"Config {','.join(config.keys())} from file are not options")
# Set default values
if args.max_batch_size is None:
    args.max_batch_size = max(args.batch_size_list)
if args.max_seq_len is None:
    args.max_seq_len = max(args.seq_len_q_list) + max(args.seq_len_kv_cache_list)
if args.enable_attention_dp is None:
    args.enable_attention_dp = False
if args.max_num_tokens is None:
    args.max_num_tokens = args.max_batch_size * max(args.seq_len_q_list)
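    # GEN runs prefill the KV cache with large CTX batches first (see the warm-up loop below),
    # so grow the default max_num_tokens to fit at least one prefill batch.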
    if args.run_type == "GEN":
        ctx_batch_size = max(1, max(20480, args.max_num_tokens) // max(args.seq_len_kv_cache_list))
        args.max_num_tokens = max(
            args.max_num_tokens, ctx_batch_size * max(args.seq_len_kv_cache_list)
        )
else:
    if args.run_type == "GEN":
        ctx_batch_size = max(1, args.max_num_tokens // max(args.seq_len_kv_cache_list))
        assert args.max_num_tokens >= ctx_batch_size * max(args.seq_len_kv_cache_list), (
            "Max_num_tokens is too small to prefill KV cache"
        )
if args.use_low_precision_moe_combine is None:
    args.use_low_precision_moe_combine = False
if args.enable_autotuner is None:
    args.enable_autotuner = True
if args.use_cuda_graph is None:
    args.use_cuda_graph = False
print(args)
# MPI args
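# One MPI process drives one GPU; the node-local rank selects the CUDA device.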
rank = mpi_rank()
world_size = mpi_world_size()
local_rank = local_mpi_rank()
torch.cuda.set_device(local_rank)
# Create KV cache manager
logger.info("Layer-wise benchmarks: Create KV cache manager")
Runner = get_runner_cls(args.model)
mapping = Runner.create_mapping(enable_attention_dp=args.enable_attention_dp)
kv_cache_manager = Runner.create_kv_cache_manager(
    args.model,
    mapping,
    tokens_per_block=args.tokens_per_block,
    max_batch_size=args.max_batch_size,
    max_seq_len=args.max_seq_len,
    layer_indices=args.layer_indices,
)
attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)
logger.info("Layer-wise benchmarks: Create KV cache manager ... Done")
# Create other global objects
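# Start from a clean autotuner cache and create a dedicated side stream that is reused for both
# warm-up execution and CUDA graph capture.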
AutoTuner.get().clear_cache()
capture_stream = torch.cuda.Stream()
mark_ranges()
# Create runner
logger.info("Layer-wise benchmarks: Create runner")
runner = Runner(
    args.model,
    mapping,
    moe_backend=args.moe_backend,
    layer_indices=args.layer_indices,
    scaled_from=args.scaled_from,
    max_seq_len=args.max_seq_len,
    max_num_tokens=args.max_num_tokens,
    moe_max_num_tokens=args.moe_max_num_tokens,
    use_low_precision_moe_combine=args.use_low_precision_moe_combine,
    use_cuda_graph=args.use_cuda_graph,
)
logger.info("Layer-wise benchmarks: Create runner ... Done")
# Warm up
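# The first warm-up entry runs with autotuning enabled at the largest batch size and query length
# (and, for GEN runs, prefills the KV cache); the remaining entries run each benchmark combination
# once with correctness checks before any timing starts.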
for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
    [
        True,
        max(args.batch_size_list),
        max(args.seq_len_q_list),
        args.seq_len_kv_cache_list[0],
        args.balance_ratio_list[0],
    ],
    *itertools.product(
        [False],
        args.batch_size_list,
        args.seq_len_q_list,
        args.seq_len_kv_cache_list,
        args.balance_ratio_list,
    ),
]:
    assert batch_size <= args.max_batch_size
    assert seq_len_q + seq_len_kv_cache <= args.max_seq_len
    assert batch_size * seq_len_q <= args.max_num_tokens
    run_pack = runner.create_run_pack(
        args.run_type,
        batch_size=batch_size,
        request_id_begin=0,
        seq_len_q=seq_len_q,
        seq_len_kv_cache=seq_len_kv_cache,
        kv_cache_manager=kv_cache_manager,
        attn_workspace=attn_workspace,
    )
    with runner.replace_routing_method_ctx(
        balance_method=BalanceMethod[args.balance_method], balance_ratio=balance_ratio
    ):
        capture_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(capture_stream):
            if autotune_flag:
                if args.enable_autotuner:
                    cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
                    with autotune(cache_path=cache_path):
                        run_pack()
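                # GEN benchmarks need a populated KV cache: run CTX (prefill) passes over all
                # requests in chunks of ctx_batch_size before any generation steps are timed.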
                if args.run_type == "GEN":
                    logger.info("Layer-wise benchmarks: Prefill KV cache")
                    ctx_seq_len_q = max(args.seq_len_kv_cache_list)
                    assert ctx_batch_size <= args.max_batch_size
                    assert ctx_seq_len_q + 0 <= args.max_seq_len
                    assert ctx_batch_size * ctx_seq_len_q <= args.max_num_tokens
                    max_batch_size = max(args.batch_size_list)
                    for request_id_begin in range(0, max_batch_size, ctx_batch_size):
                        ctx_run_pack = runner.create_run_pack(
                            "CTX",
                            batch_size=min(ctx_batch_size, max_batch_size - request_id_begin),
                            request_id_begin=request_id_begin,
                            seq_len_q=ctx_seq_len_q,
                            seq_len_kv_cache=0,
                            kv_cache_manager=kv_cache_manager,
                            attn_workspace=attn_workspace,
                        )
                        ctx_run_pack(check=True)
                    logger.info("Layer-wise benchmarks: Prefill KV cache ... Done")
            else:
                run_pack(check=True)
        torch.cuda.current_stream().wait_stream(capture_stream)
torch.cuda.synchronize()
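# One CUDA event per warm-up/measured iteration plus one terminal event; the elapsed time between
# consecutive events gives each iteration's latency.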
events = [
    torch.cuda.Event(enable_timing=True) for _ in range(args.warmup_times + args.run_times + 1)
]
[e.record() for e in events]  # Explicitly warm up events because torch creates them lazily
torch.cuda.cudart().cudaProfilerStart()
with nvtx.annotate(f"layer_wise_benchmarks args {json.dumps(args.__dict__)}"):
    pass  # Use `annotate` instead of `mark` to avoid additional lines on the Nsight Systems UI
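# Measured runs: time every combination of batch size, query length, KV cache length, and
# balance ratio.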
for batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in itertools.product(
    args.batch_size_list, args.seq_len_q_list, args.seq_len_kv_cache_list, args.balance_ratio_list
):
    # Profile: capture graph and replay it
    problem_spec = {
        "batch_size": batch_size,
        "seq_len_q": seq_len_q,
        "seq_len_kv_cache": seq_len_kv_cache,
        "balance_ratio": balance_ratio,
    }
    with nvtx.annotate(f"layer_wise_benchmarks problem_spec {json.dumps(problem_spec)}"):
        pass
    run_pack = runner.create_run_pack(
        args.run_type,
        batch_size=batch_size,
        request_id_begin=0,
        seq_len_q=seq_len_q,
        seq_len_kv_cache=seq_len_kv_cache,
        kv_cache_manager=kv_cache_manager,
        attn_workspace=attn_workspace,
    )
    with runner.replace_routing_method_ctx(
        balance_method=BalanceMethod[args.balance_method], balance_ratio=balance_ratio
    ):
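        # With CUDA graphs enabled, capture one full iteration on the side stream and replay the
        # graph during timing; otherwise each timed iteration runs eagerly.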
        if args.use_cuda_graph:
            with with_multi_stream(True):
                g = torch.cuda.CUDAGraph()
                with torch.cuda.graph(g, stream=capture_stream, capture_error_mode="global"):
                    run_pack()
        balance_ratio_str = "" if balance_ratio is None else f" balance={balance_ratio:.2g}"
        nvtx_message = f"b={batch_size} s={seq_len_q} past={seq_len_kv_cache}{balance_ratio_str} NP{world_size}"
        for i in range(args.warmup_times + args.run_times):
            events[i].record()
            with nvtx.annotate(nvtx_message):
                if args.use_cuda_graph:
                    g.replay()
                else:
                    run_pack()
        events[-1].record()
        torch.cuda.synchronize()
        # Print statistics
        # Print before `cudaProfilerStop` to ensure messages are included in the profile
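        # Event.elapsed_time() returns milliseconds; multiply by 1000 to report microseconds.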
        time_list = [start.elapsed_time(stop) for start, stop in zip(events, events[1:])]
        time_list = time_list[args.warmup_times :]
        print(
            f"[RANK {rank}]"
            f" batch_size {batch_size}"
            f" seq_len_q {seq_len_q}"
            f" seq_len_kv_cache {seq_len_kv_cache}"
            + ("" if balance_ratio is None else f" balance_ratio {balance_ratio:.2g}")
            + f" mean {np.mean(time_list) * 1000:.1f}"
            f" median {np.median(time_list) * 1000:.1f}"
            f" min {np.min(time_list) * 1000:.1f}"
            f" max {np.max(time_list) * 1000:.1f}"
            f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
            f" (us)"
        )
torch.cuda.cudart().cudaProfilerStop()