import argparse
import itertools
import json
import os

import numpy as np
import nvtx
import torch
import yaml

from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
from tensorrt_llm.logger import logger
from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls, mark_ranges


def comma_separated_ints(s):
    return [int(x) for x in s.split(",")]


def comma_separated_floats(s):
    return [float(x) for x in s.split(",")]


# Parse cmdline
parser = argparse.ArgumentParser()
parser.add_argument("config_path", type=str)
parser.add_argument("--model", type=str, help="Pretrained model name or path")
parser.add_argument(
    "--layer-indices",
    type=comma_separated_ints,
    help="Comma separated indices of layers, should be a contiguous range",
)
parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"])
parser.add_argument("--scaled-from", type=int)
# KV cache related args
parser.add_argument("--max-batch-size", type=int)
parser.add_argument("--tokens-per-block", type=int)
parser.add_argument("--max-seq-len", type=int)
group = parser.add_mutually_exclusive_group()
group.add_argument("--enable-attention-dp", action="store_true", dest="enable_attention_dp")
group.add_argument("--no-enable-attention-dp", action="store_false", dest="enable_attention_dp")
parser.set_defaults(enable_attention_dp=None)
# Model init args
parser.add_argument("--max-num-tokens", type=int)
parser.add_argument("--moe-backend", type=str)
parser.add_argument("--moe-max-num-tokens", type=int)
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--use-low-precision-moe-combine", action="store_true", dest="use_low_precision_moe_combine"
)
group.add_argument(
    "--no-use-low-precision-moe-combine",
    action="store_false",
    dest="use_low_precision_moe_combine",
)
parser.set_defaults(use_low_precision_moe_combine=None)
group = parser.add_mutually_exclusive_group()
group.add_argument("--enable-autotuner", action="store_true", dest="enable_autotuner")
group.add_argument("--no-enable-autotuner", action="store_false", dest="enable_autotuner")
parser.set_defaults(enable_autotuner=None)
group = parser.add_mutually_exclusive_group()
group.add_argument("--use-cuda-graph", action="store_true", dest="use_cuda_graph")
group.add_argument("--no-use-cuda-graph", action="store_false", dest="use_cuda_graph")
parser.set_defaults(use_cuda_graph=None)
# Per iteration args
parser.add_argument("--batch-size", type=comma_separated_ints, dest="batch_size_list")
parser.add_argument("--seq-len-q", type=comma_separated_ints, dest="seq_len_q_list")
parser.add_argument("--seq-len-kv-cache", type=comma_separated_ints, dest="seq_len_kv_cache_list")
parser.add_argument("--balance-method", type=str)
parser.add_argument("--balance-ratio", type=comma_separated_floats, dest="balance_ratio_list")
# Schedule
parser.add_argument("--warmup-times", type=int, default=20)
parser.add_argument("--run-times", type=int, default=100)
args = parser.parse_args()

# Load YAML file
with open(args.config_path) as f:
    config = yaml.safe_load(f)
del args.config_path
for k, v in vars(args).items():
    if k.endswith("_list"):
        config_key = k[: -len("_list")]
        if v is None and config_key in config:
            v = config[config_key]
        if isinstance(v, list):
            pass
        elif v is None or isinstance(v, (int, float)):
            v = [v]
        else:
            raise ValueError(f'Config "{config_key}" in YAML should be a value or a list')
        setattr(args, k, v)
    else:
        config_key = k
        if v is None and config_key in config:
            v = config[config_key]
        setattr(args, k, v)
    if config_key in config:
        del config[config_key]
if config:
    raise ValueError(f"Config {','.join(config.keys())} from file are not options")
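# Illustrative config.yaml (a sketch, not a file shipped with this script): any option
# defined above can appear as a key, named after its argparse dest with any trailing
# "_list" stripped; command-line values take precedence, and YAML fills in options left
# unset. List-valued options (batch_size, seq_len_q, seq_len_kv_cache, balance_ratio)
# accept either a scalar or a list. All values below are placeholders.
#
#   model: <pretrained model name or path>
#   layer_indices: [0, 1, 2]
#   run_type: GEN
#   batch_size: [1, 8, 32]
#   seq_len_q: 1
#   seq_len_kv_cache: [1024, 4096]
#   balance_method: <BalanceMethod member name>
#   balance_ratio: 1.0
#   use_cuda_graph: true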
# Set default values
if args.max_batch_size is None:
    args.max_batch_size = max(args.batch_size_list)
if args.max_seq_len is None:
    args.max_seq_len = max(args.seq_len_q_list) + max(args.seq_len_kv_cache_list)
if args.enable_attention_dp is None:
    args.enable_attention_dp = False
if args.max_num_tokens is None:
    args.max_num_tokens = args.max_batch_size * max(args.seq_len_q_list)
    if args.run_type == "GEN":
        ctx_batch_size = max(1, max(20480, args.max_num_tokens) // max(args.seq_len_kv_cache_list))
        args.max_num_tokens = max(
            args.max_num_tokens, ctx_batch_size * max(args.seq_len_kv_cache_list)
        )
else:
    if args.run_type == "GEN":
        ctx_batch_size = max(1, args.max_num_tokens // max(args.seq_len_kv_cache_list))
        assert args.max_num_tokens >= ctx_batch_size * max(args.seq_len_kv_cache_list), (
            "Max_num_tokens is too small to prefill KV cache"
        )
if args.use_low_precision_moe_combine is None:
    args.use_low_precision_moe_combine = False
if args.enable_autotuner is None:
    args.enable_autotuner = True
if args.use_cuda_graph is None:
    args.use_cuda_graph = False
print(args)

# MPI args
rank = mpi_rank()
world_size = mpi_world_size()
local_rank = local_mpi_rank()
torch.cuda.set_device(local_rank)

# Create KV cache manager
logger.info("Layer-wise benchmarks: Create KV cache manager")
Runner = get_runner_cls(args.model)
mapping = Runner.create_mapping(enable_attention_dp=args.enable_attention_dp)
kv_cache_manager = Runner.create_kv_cache_manager(
    args.model,
    mapping,
    tokens_per_block=args.tokens_per_block,
    max_batch_size=args.max_batch_size,
    max_seq_len=args.max_seq_len,
    layer_indices=args.layer_indices,
)
attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)
logger.info("Layer-wise benchmarks: Create KV cache manager ... Done")

# Create other global objects
AutoTuner.get().clear_cache()
capture_stream = torch.cuda.Stream()
mark_ranges()

# Create runner
logger.info("Layer-wise benchmarks: Create runner")
runner = Runner(
    args.model,
    mapping,
    moe_backend=args.moe_backend,
    layer_indices=args.layer_indices,
    scaled_from=args.scaled_from,
    max_seq_len=args.max_seq_len,
    max_num_tokens=args.max_num_tokens,
    moe_max_num_tokens=args.moe_max_num_tokens,
    use_low_precision_moe_combine=args.use_low_precision_moe_combine,
    use_cuda_graph=args.use_cuda_graph,
)
logger.info("Layer-wise benchmarks: Create runner ... Done")
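# Warm-up plan (notes on the loop below): the first tuple uses the largest batch size and
# query length; if the autotuner is enabled it runs once under autotune(), using the cache
# path from TLLM_AUTOTUNER_CACHE_PATH when set, and for GEN runs it also prefills the KV
# cache with CTX-phase chunks. The remaining tuples from itertools.product warm up every
# benchmarked combination with result checking enabled.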
Done") # Warm up for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [ [ True, max(args.batch_size_list), max(args.seq_len_q_list), args.seq_len_kv_cache_list[0], args.balance_ratio_list[0], ], *itertools.product( [False], args.batch_size_list, args.seq_len_q_list, args.seq_len_kv_cache_list, args.balance_ratio_list, ), ]: assert batch_size <= args.max_batch_size assert seq_len_q + seq_len_kv_cache <= args.max_seq_len assert batch_size * seq_len_q <= args.max_num_tokens run_pack = runner.create_run_pack( args.run_type, batch_size=batch_size, request_id_begin=0, seq_len_q=seq_len_q, seq_len_kv_cache=seq_len_kv_cache, kv_cache_manager=kv_cache_manager, attn_workspace=attn_workspace, ) with runner.replace_routing_method_ctx( balance_method=BalanceMethod[args.balance_method], balance_ratio=balance_ratio ): capture_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(capture_stream): if autotune_flag: if args.enable_autotuner: cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None with autotune(cache_path=cache_path): run_pack() if args.run_type == "GEN": logger.info("Layer-wise benchmarks: Prefill KV cache") ctx_seq_len_q = max(args.seq_len_kv_cache_list) assert ctx_batch_size <= args.max_batch_size assert ctx_seq_len_q + 0 <= args.max_seq_len assert ctx_batch_size * ctx_seq_len_q <= args.max_num_tokens max_batch_size = max(args.batch_size_list) for request_id_begin in range(0, max_batch_size, ctx_batch_size): ctx_run_pack = runner.create_run_pack( "CTX", batch_size=min(ctx_batch_size, max_batch_size - request_id_begin), request_id_begin=request_id_begin, seq_len_q=ctx_seq_len_q, seq_len_kv_cache=0, kv_cache_manager=kv_cache_manager, attn_workspace=attn_workspace, ) ctx_run_pack(check=True) logger.info("Layer-wise benchmarks: Prefill KV cache ... 
Done") else: run_pack(check=True) torch.cuda.current_stream().wait_stream(capture_stream) torch.cuda.synchronize() events = [ torch.cuda.Event(enable_timing=True) for _ in range(args.warmup_times + args.run_times + 1) ] [e.record() for e in events] # Explicitly warmup events because torch is lazy torch.cuda.cudart().cudaProfilerStart() with nvtx.annotate(f"layer_wise_benchmarks args {json.dumps(args.__dict__)}"): pass # Use `annotate` instead of `mark` to avoid addition lines on the Nsight Systems UI for batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in itertools.product( args.batch_size_list, args.seq_len_q_list, args.seq_len_kv_cache_list, args.balance_ratio_list ): # Profile: capture graph and replay it problem_spec = { "batch_size": batch_size, "seq_len_q": seq_len_q, "seq_len_kv_cache": seq_len_kv_cache, "balance_ratio": balance_ratio, } with nvtx.annotate(f"layer_wise_benchmarks problem_spec {json.dumps(problem_spec)}"): pass run_pack = runner.create_run_pack( args.run_type, batch_size=batch_size, request_id_begin=0, seq_len_q=seq_len_q, seq_len_kv_cache=seq_len_kv_cache, kv_cache_manager=kv_cache_manager, attn_workspace=attn_workspace, ) with runner.replace_routing_method_ctx( balance_method=BalanceMethod[args.balance_method], balance_ratio=balance_ratio ): if args.use_cuda_graph: with with_multi_stream(True): g = torch.cuda.CUDAGraph() with torch.cuda.graph(g, stream=capture_stream, capture_error_mode="global"): run_pack() balance_ratio_str = "" if balance_ratio is None else f" balance={balance_ratio:.2g}" nvtx_message = f"b={batch_size} s={seq_len_q} past={seq_len_kv_cache}{balance_ratio_str} NP{world_size}" for i in range(args.warmup_times + args.run_times): events[i].record() with nvtx.annotate(nvtx_message): if args.use_cuda_graph: g.replay() else: run_pack() events[-1].record() torch.cuda.synchronize() # Print statistics # Print before `cudaProfilerStop` to ensure messages are included in the profile time_list = [start.elapsed_time(stop) for start, stop in zip(events, events[1:])] time_list = time_list[args.warmup_times :] print( f"[RANK {rank}]" f" batch_size {batch_size}" f" seq_len_q {seq_len_q}" f" seq_len_kv_cache {seq_len_kv_cache}" + ("" if balance_ratio is None else f" balance_ratio {balance_ratio:.2g}") + f" mean {np.mean(time_list) * 1000:.1f}" f" median {np.median(time_list) * 1000:.1f}" f" min {np.min(time_list) * 1000:.1f}" f" max {np.max(time_list) * 1000:.1f}" f" P90 {np.percentile(time_list, 90) * 1000:.1f}" f" (us)" ) torch.cuda.cudart().cudaProfilerStop()