Four distinct strategies are implemented to accommodate different distributed tuning scenarios: BROADCAST, INDEPENDENT, MERGE, and PARALLEL.

* Distributed tuning is disabled by default, with the INDEPENDENT strategy as the fallback. This conservative approach prevents unexpected behavior in standard use cases.
* Only operations with significant tuning time overhead have been assigned the PARALLEL strategy, which allows the same tensor parallelism (TP) rank to tune tactics concurrently across different ranks. This targeted approach balances performance gains with stability.
* Operations with nested tuning structures, such as NVFP4GemmUnifiedRunner, currently support only the INDEPENDENT strategy. This restriction exists because the synchronization mechanism is optimized only for leaf operations and does not yet handle nested hierarchies.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
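
The policy above can be pictured with a small sketch. This is illustrative only and does not use the actual TensorRT-LLM autotuner API; names such as DistributedTuningStrategy and pick_strategy are made up for this example.

# Illustrative sketch only; hypothetical names, not the TensorRT-LLM implementation.
from enum import Enum, auto


class DistributedTuningStrategy(Enum):
    BROADCAST = auto()    # (illustrative) one rank tunes and broadcasts the result
    INDEPENDENT = auto()  # every rank tunes on its own; the default fallback
    MERGE = auto()        # (illustrative) ranks merge their tuning results
    PARALLEL = auto()     # ranks tune tactics concurrently and share the outcome


def pick_strategy(
    op_name: str, distributed_tuning: bool, nested: bool
) -> DistributedTuningStrategy:
    """Mirror of the policy described in the commit message (hypothetical helper)."""
    if not distributed_tuning or nested:
        # Disabled by default; nested runners (e.g. NVFP4GemmUnifiedRunner) fall back too.
        return DistributedTuningStrategy.INDEPENDENT
    expensive_ops = {"fused_moe", "nvfp4_gemm"}  # placeholder names for costly-to-tune ops
    if op_name in expensive_ops:
        return DistributedTuningStrategy.PARALLEL
    return DistributedTuningStrategy.INDEPENDENT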
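# Layer-wise benchmark driver: builds a runner for a contiguous slice of model layers, warms it
# up (optionally autotuning kernel tactics and, for GEN runs, prefilling the KV cache), then
# times repeated executions of each problem shape, either eagerly or via CUDA graph replay, and
# prints per-rank latency statistics.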
import argparse
import itertools
import json
import os

import numpy as np
import nvtx
import torch
import yaml

from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
from tensorrt_llm.logger import logger
from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls, mark_ranges


def comma_separated_ints(s):
    return [int(x) for x in s.split(",")]


def comma_separated_floats(s):
    return [float(x) for x in s.split(",")]


# Parse cmdline
parser = argparse.ArgumentParser()
parser.add_argument("config_path", type=str)
parser.add_argument("--model", type=str, help="Pretrained model name or path")
parser.add_argument(
    "--layer-indices",
    type=comma_separated_ints,
    help="Comma separated indices of layers, should be a contiguous range",
)
parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"])
parser.add_argument("--scaled-from", type=int)
# KV cache related args
parser.add_argument("--max-batch-size", type=int)
parser.add_argument("--tokens-per-block", type=int)
parser.add_argument("--max-seq-len", type=int)
group = parser.add_mutually_exclusive_group()
group.add_argument("--enable-attention-dp", action="store_true", dest="enable_attention_dp")
group.add_argument("--no-enable-attention-dp", action="store_false", dest="enable_attention_dp")
parser.set_defaults(enable_attention_dp=None)
# Model init args
parser.add_argument("--max-num-tokens", type=int)
parser.add_argument("--moe-backend", type=str)
parser.add_argument("--moe-max-num-tokens", type=int)
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--use-low-precision-moe-combine", action="store_true", dest="use_low_precision_moe_combine"
)
group.add_argument(
    "--no-use-low-precision-moe-combine",
    action="store_false",
    dest="use_low_precision_moe_combine",
)
parser.set_defaults(use_low_precision_moe_combine=None)
group = parser.add_mutually_exclusive_group()
group.add_argument("--enable-autotuner", action="store_true", dest="enable_autotuner")
group.add_argument("--no-enable-autotuner", action="store_false", dest="enable_autotuner")
parser.set_defaults(enable_autotuner=None)
group = parser.add_mutually_exclusive_group()
group.add_argument("--use-cuda-graph", action="store_true", dest="use_cuda_graph")
group.add_argument("--no-use-cuda-graph", action="store_false", dest="use_cuda_graph")
parser.set_defaults(use_cuda_graph=None)
# Per iteration args
parser.add_argument("--batch-size", type=comma_separated_ints, dest="batch_size_list")
parser.add_argument("--seq-len-q", type=comma_separated_ints, dest="seq_len_q_list")
parser.add_argument("--seq-len-kv-cache", type=comma_separated_ints, dest="seq_len_kv_cache_list")
parser.add_argument("--balance-method", type=str)
parser.add_argument("--balance-ratio", type=comma_separated_floats, dest="balance_ratio_list")
# Schedule
parser.add_argument("--warmup-times", type=int, default=20)
parser.add_argument("--run-times", type=int, default=100)
args = parser.parse_args()
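# The YAML config supplies defaults for any option not given on the command line: "*_list"
# options accept either a scalar or a list in YAML (scalars are wrapped into one-element
# lists), and any YAML key that does not correspond to an option is rejected below.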
# Load YAML file
with open(args.config_path) as f:
    config = yaml.safe_load(f)
del args.config_path
for k, v in vars(args).items():
    if k.endswith("_list"):
        config_key = k[: -len("_list")]
        if v is None and config_key in config:
            v = config[config_key]
        if isinstance(v, list):
            pass
        elif v is None or isinstance(v, (int, float)):
            v = [v]
        else:
            raise ValueError(f'Config "{config_key}" in YAML should be a value or a list')
        setattr(args, k, v)
    else:
        config_key = k
        if v is None and config_key in config:
            v = config[config_key]
            setattr(args, k, v)
    if config_key in config:
        del config[config_key]
if config:
    raise ValueError(f"Config {','.join(config.keys())} from file are not options")
# Set default values
if args.max_batch_size is None:
    args.max_batch_size = max(args.batch_size_list)
if args.max_seq_len is None:
    args.max_seq_len = max(args.seq_len_q_list) + max(args.seq_len_kv_cache_list)
if args.enable_attention_dp is None:
    args.enable_attention_dp = False
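# GEN runs replay decode steps on top of an existing KV cache, so the cache must first be
# prefilled by CTX passes of ctx_batch_size requests; ensure max_num_tokens can hold one such
# prefill chunk (a floor of 20480 tokens is applied when max_num_tokens is derived).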
if args.max_num_tokens is None:
    args.max_num_tokens = args.max_batch_size * max(args.seq_len_q_list)
    if args.run_type == "GEN":
        ctx_batch_size = max(1, max(20480, args.max_num_tokens) // max(args.seq_len_kv_cache_list))
        args.max_num_tokens = max(
            args.max_num_tokens, ctx_batch_size * max(args.seq_len_kv_cache_list)
        )
else:
    if args.run_type == "GEN":
        ctx_batch_size = max(1, args.max_num_tokens // max(args.seq_len_kv_cache_list))
        assert args.max_num_tokens >= ctx_batch_size * max(args.seq_len_kv_cache_list), (
            "Max_num_tokens is too small to prefill KV cache"
        )
if args.use_low_precision_moe_combine is None:
    args.use_low_precision_moe_combine = False
if args.enable_autotuner is None:
    args.enable_autotuner = True
if args.use_cuda_graph is None:
    args.use_cuda_graph = False
print(args)

# MPI args
rank = mpi_rank()
world_size = mpi_world_size()
local_rank = local_mpi_rank()
torch.cuda.set_device(local_rank)
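
# Build a KV cache manager that covers only the benchmarked layers; attn_workspace is an empty
# CUDA buffer shared by all run packs created below.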
# Create KV cache manager
logger.info("Layer-wise benchmarks: Create KV cache manager")
Runner = get_runner_cls(args.model)
mapping = Runner.create_mapping(enable_attention_dp=args.enable_attention_dp)
kv_cache_manager = Runner.create_kv_cache_manager(
    args.model,
    mapping,
    tokens_per_block=args.tokens_per_block,
    max_batch_size=args.max_batch_size,
    max_seq_len=args.max_seq_len,
    layer_indices=args.layer_indices,
)
attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)
logger.info("Layer-wise benchmarks: Create KV cache manager ... Done")
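
# Clear any previously cached autotuner tactics and create the side stream used for warm-up and
# CUDA graph capture; mark_ranges presumably registers NVTX ranges for Nsight Systems profiling.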
# Create other global objects
AutoTuner.get().clear_cache()
capture_stream = torch.cuda.Stream()
mark_ranges()

# Create runner
logger.info("Layer-wise benchmarks: Create runner")
runner = Runner(
    args.model,
    mapping,
    moe_backend=args.moe_backend,
    layer_indices=args.layer_indices,
    scaled_from=args.scaled_from,
    max_seq_len=args.max_seq_len,
    max_num_tokens=args.max_num_tokens,
    moe_max_num_tokens=args.moe_max_num_tokens,
    use_low_precision_moe_combine=args.use_low_precision_moe_combine,
    use_cuda_graph=args.use_cuda_graph,
)
logger.info("Layer-wise benchmarks: Create runner ... Done")

# Warm up
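# The first entry runs with autotuning enabled at the largest batch/seq_len_q shape (and, for
# GEN runs, prefills the KV cache with CTX passes); the remaining entries run every benchmark
# shape once without autotuning. All warm-up work is issued on capture_stream.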
for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
    [
        True,
        max(args.batch_size_list),
        max(args.seq_len_q_list),
        args.seq_len_kv_cache_list[0],
        args.balance_ratio_list[0],
    ],
    *itertools.product(
        [False],
        args.batch_size_list,
        args.seq_len_q_list,
        args.seq_len_kv_cache_list,
        args.balance_ratio_list,
    ),
]:
    assert batch_size <= args.max_batch_size
    assert seq_len_q + seq_len_kv_cache <= args.max_seq_len
    assert batch_size * seq_len_q <= args.max_num_tokens
    run_pack = runner.create_run_pack(
        args.run_type,
        batch_size=batch_size,
        request_id_begin=0,
        seq_len_q=seq_len_q,
        seq_len_kv_cache=seq_len_kv_cache,
        kv_cache_manager=kv_cache_manager,
        attn_workspace=attn_workspace,
    )
    with runner.replace_routing_method_ctx(
        balance_method=BalanceMethod[args.balance_method], balance_ratio=balance_ratio
    ):
        capture_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(capture_stream):
            if autotune_flag:
                if args.enable_autotuner:
                    cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
                    with autotune(cache_path=cache_path):
                        run_pack()
                if args.run_type == "GEN":
                    logger.info("Layer-wise benchmarks: Prefill KV cache")
                    ctx_seq_len_q = max(args.seq_len_kv_cache_list)
                    assert ctx_batch_size <= args.max_batch_size
                    assert ctx_seq_len_q + 0 <= args.max_seq_len
                    assert ctx_batch_size * ctx_seq_len_q <= args.max_num_tokens
                    max_batch_size = max(args.batch_size_list)
                    for request_id_begin in range(0, max_batch_size, ctx_batch_size):
                        ctx_run_pack = runner.create_run_pack(
                            "CTX",
                            batch_size=min(ctx_batch_size, max_batch_size - request_id_begin),
                            request_id_begin=request_id_begin,
                            seq_len_q=ctx_seq_len_q,
                            seq_len_kv_cache=0,
                            kv_cache_manager=kv_cache_manager,
                            attn_workspace=attn_workspace,
                        )
                        ctx_run_pack(check=True)
                    logger.info("Layer-wise benchmarks: Prefill KV cache ... Done")
            else:
                run_pack(check=True)
        torch.cuda.current_stream().wait_stream(capture_stream)
torch.cuda.synchronize()
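
# CUDA events bracket every timed iteration: event i is recorded before iteration i and one
# extra event closes the last iteration, so elapsed_time between consecutive events gives the
# per-iteration latency. The first warmup_times intervals are discarded from the statistics.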
events = [
    torch.cuda.Event(enable_timing=True) for _ in range(args.warmup_times + args.run_times + 1)
]
[e.record() for e in events]  # Explicitly warm up events because torch is lazy
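
# Everything between cudaProfilerStart/Stop can be captured by Nsight Systems when profiling
# with a cudaProfilerApi capture range; the args annotation embeds the full benchmark
# configuration in the timeline.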
torch.cuda.cudart().cudaProfilerStart()
with nvtx.annotate(f"layer_wise_benchmarks args {json.dumps(args.__dict__)}"):
    pass  # Use `annotate` instead of `mark` to avoid additional lines on the Nsight Systems UI
for batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in itertools.product(
    args.batch_size_list, args.seq_len_q_list, args.seq_len_kv_cache_list, args.balance_ratio_list
):
    # Profile: capture graph and replay it
    problem_spec = {
        "batch_size": batch_size,
        "seq_len_q": seq_len_q,
        "seq_len_kv_cache": seq_len_kv_cache,
        "balance_ratio": balance_ratio,
    }
    with nvtx.annotate(f"layer_wise_benchmarks problem_spec {json.dumps(problem_spec)}"):
        pass
    run_pack = runner.create_run_pack(
        args.run_type,
        batch_size=batch_size,
        request_id_begin=0,
        seq_len_q=seq_len_q,
        seq_len_kv_cache=seq_len_kv_cache,
        kv_cache_manager=kv_cache_manager,
        attn_workspace=attn_workspace,
    )
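    # With --use-cuda-graph the run pack is captured once into a CUDA graph and replayed in the
    # timing loop below; otherwise each timed iteration launches the run pack eagerly. The
    # routing method is replaced for the whole measurement so the configured balance
    # method/ratio stays in effect.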
    with runner.replace_routing_method_ctx(
        balance_method=BalanceMethod[args.balance_method], balance_ratio=balance_ratio
    ):
        if args.use_cuda_graph:
            with with_multi_stream(True):
                g = torch.cuda.CUDAGraph()
                with torch.cuda.graph(g, stream=capture_stream, capture_error_mode="global"):
                    run_pack()

        balance_ratio_str = "" if balance_ratio is None else f" balance={balance_ratio:.2g}"
        nvtx_message = f"b={batch_size} s={seq_len_q} past={seq_len_kv_cache}{balance_ratio_str} NP{world_size}"
        for i in range(args.warmup_times + args.run_times):
            events[i].record()
            with nvtx.annotate(nvtx_message):
                if args.use_cuda_graph:
                    g.replay()
                else:
                    run_pack()
        events[-1].record()
        torch.cuda.synchronize()

        # Print statistics
        # Print before `cudaProfilerStop` to ensure messages are included in the profile
        time_list = [start.elapsed_time(stop) for start, stop in zip(events, events[1:])]
        time_list = time_list[args.warmup_times :]
        print(
            f"[RANK {rank}]"
            f" batch_size {batch_size}"
            f" seq_len_q {seq_len_q}"
            f" seq_len_kv_cache {seq_len_kv_cache}"
            + ("" if balance_ratio is None else f" balance_ratio {balance_ratio:.2g}")
            + f" mean {np.mean(time_list) * 1000:.1f}"
            f" median {np.median(time_list) * 1000:.1f}"
            f" min {np.min(time_list) * 1000:.1f}"
            f" max {np.max(time_list) * 1000:.1f}"
            f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
            f" (us)"
        )
torch.cuda.cudart().cudaProfilerStop()