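# Layer-wise benchmark driver for DeepSeek-V3 layers: settings come from a
# YAML config with optional command-line overrides, and the script is intended
# to run under MPI with one rank per GPU.
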
import argparse

import numpy as np
import nvtx
import torch
import yaml

from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
from tensorrt_llm.tools.layer_wise_benchmarks.deepseekv3_runner import (
    BalanceMethod, DeepSeekV3Runner)


def comma_separated_ints(s):
    return [int(x) for x in s.split(",")]


# Parse cmdline
parser = argparse.ArgumentParser()
parser.add_argument("config_path", type=str)
parser.add_argument("--model", type=str, help="Pretrained model name or path")
parser.add_argument(
    "--layer-indices",
    type=comma_separated_ints,
    help="Comma separated indices of layers, should be a contiguous range")
parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"])
parser.add_argument("--scaled-from", type=int)
# KV cache related args
parser.add_argument("--tokens-per-block", type=int)
parser.add_argument("--max-seq-len", type=int)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--enable-attention-dp",
                   action="store_true",
                   dest="enable_attention_dp")
group.add_argument("--no-enable-attention-dp",
                   action="store_false",
                   dest="enable_attention_dp")
parser.set_defaults(enable_attention_dp=None)
# Model init args
parser.add_argument("--max-num-tokens", type=int)
parser.add_argument("--moe-backend", type=str)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--use-cuda-graph",
                   action="store_true",
                   dest="use_cuda_graph")
group.add_argument("--no-use-cuda-graph",
                   action="store_false",
                   dest="use_cuda_graph")
parser.set_defaults(use_cuda_graph=None)
# Per iteration args
parser.add_argument("--batch-size", type=int)
parser.add_argument("--seq-len-q", type=int)
parser.add_argument("--seq-len-kv-cache", type=int)
parser.add_argument("--balance-method", type=str)
parser.add_argument("--balance-ratio", type=float)
args = parser.parse_args()
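# Any option left unset on the command line falls back to its value in the YAML config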
with open(args.config_path) as f:
    config = yaml.safe_load(f)
del args.config_path
for k, v in vars(args).items():
    if v is None:
        setattr(args, k, config[k])
print(args)

# MPI args
rank = mpi_rank()
world_size = mpi_world_size()
local_rank = local_mpi_rank()
torch.cuda.set_device(local_rank)

# Create KV cache manager
mapping = DeepSeekV3Runner.create_mapping(
    enable_attention_dp=args.enable_attention_dp)
max_batch_size = 2048
kv_cache_manager = DeepSeekV3Runner.create_kv_cache_manager(
    args.model,
    mapping,
    tokens_per_block=args.tokens_per_block,
    max_batch_size=max_batch_size,
    max_seq_len=args.max_seq_len,
    layer_indices=args.layer_indices)
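# Scratch workspace handed to the attention backend (starts as an empty GPU tensor)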
attn_workspace = torch.empty((0, ), device="cuda", dtype=torch.int8)

# Create other global objects
AutoTuner.get().clear_cache()
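# Dedicated stream used for warm-up and, later, CUDA graph capture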
capture_stream = torch.cuda.Stream()

# Create Runner
runner = DeepSeekV3Runner(args.model,
                          mapping,
                          moe_backend=args.moe_backend,
                          layer_indices=args.layer_indices,
                          scaled_from=args.scaled_from,
                          max_seq_len=args.max_seq_len,
                          max_num_tokens=args.max_num_tokens,
                          use_cuda_graph=args.use_cuda_graph)

# Warm up
assert args.batch_size <= max_batch_size
assert args.seq_len_q + args.seq_len_kv_cache <= args.max_seq_len
run_pack = runner.create_run_pack(args.run_type,
                                  batch_size=args.batch_size,
                                  seq_len_q=args.seq_len_q,
                                  seq_len_kv_cache=args.seq_len_kv_cache,
                                  kv_cache_manager=kv_cache_manager,
                                  attn_workspace=attn_workspace)
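# Replace the MoE routing method so expert load follows the requested
# balance method and ratio (see BalanceMethod)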
runner.replace_routing_method(balance_method=BalanceMethod[args.balance_method],
                              balance_ratio=args.balance_ratio)
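# Warm up on the capture stream: one plain run, then one under autotune() so
# the AutoTuner can profile and cache kernel tactics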
capture_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(capture_stream):
    run_pack()
    with autotune():
        run_pack()
torch.cuda.current_stream().wait_stream(capture_stream)
torch.cuda.synchronize()

# Profile: capture graph and replay it
torch.cuda.cudart().cudaProfilerStart()
if args.use_cuda_graph:
    with with_multi_stream(True):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g,
                              stream=capture_stream,
                              capture_error_mode="global"):
            run_pack()

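# Time every iteration with CUDA events: events[i] marks the start of
# iteration i, and one extra event marks the end of the final iteration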
warmup_times = 20
run_times = 100
events = [
    torch.cuda.Event(enable_timing=True)
    for _ in range(warmup_times + run_times + 1)
]
for i in range(warmup_times + run_times):
    events[i].record()
    with nvtx.annotate(
            f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"):
        if args.use_cuda_graph:
            g.replay()
        else:
            run_pack()
events[-1].record()
torch.cuda.synchronize()

# Print statistics
# Print before `cudaProfilerStop` to ensure messages are included in the profile
time_list = [
    start.elapsed_time(stop) for start, stop in zip(events, events[1:])
]
time_list = time_list[warmup_times:]
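# Event.elapsed_time() returns milliseconds; multiply by 1000 to report microseconds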
print(f"[RANK {rank}]"
      f" min {np.min(time_list) * 1000:.1f}"
      f" max {np.max(time_list) * 1000:.1f}"
      f" mean {np.mean(time_list) * 1000:.1f}"
      f" median {np.median(time_list) * 1000:.1f}"
      f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
      f" (us)")

torch.cuda.cudart().cudaProfilerStop()