TensorRT-LLMs/examples/layer_wise_benchmarks/parse.py
Tailing Yuan 91528365a9
[None][feat] Add performance alignment to layer-wise benchmarks (#11018)
Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
2026-01-29 14:01:51 +08:00

400 lines
15 KiB
Python

import argparse
import bisect
import csv
import json
import re
import sqlite3
from pathlib import Path
import jinja2
import numpy as np
import pandas as pd
from parser_utils import (
kernel_short_name,
lazy_convert_sqlite,
shortest_common_supersequence,
warned_names,
)
# Parse cmdline
# Plain value options, registered in the same order as before so that the
# resulting Namespace (printed below and dumped into the HTML report) keeps
# the same key order.
_SIMPLE_OPTIONS = (
    (("--file-path",), {"type": str}),
    (("--profile-dir",), {"type": str, "default": "profiles"}),
    (("--world-size", "--np"), {"type": int}),
    (("--rank",), {"type": int, "default": 0}),
    (("--warmup-times",), {"type": int}),
    (("--module",), {"type": str}),
    (("--query",), {"type": str}),
)
parser = argparse.ArgumentParser()
for option_flags, option_kwargs in _SIMPLE_OPTIONS:
    parser.add_argument(*option_flags, **option_kwargs)
# --error-on-unknown-kernel / --no-error-on-unknown-kernel toggle the same dest
# and cannot be combined; the flag defaults to off.
unknown_kernel_toggle = parser.add_mutually_exclusive_group()
unknown_kernel_toggle.add_argument(
    "--error-on-unknown-kernel", action="store_true", dest="error_on_unknown_kernel"
)
unknown_kernel_toggle.add_argument(
    "--no-error-on-unknown-kernel", action="store_false", dest="error_on_unknown_kernel"
)
parser.set_defaults(error_on_unknown_kernel=False)
args = parser.parse_args()
# The report is located either by an explicit path or by world size (+ rank)
# inside the profile directory: exactly one of the two must be given.
num_report_sources = sum(value is not None for value in (args.file_path, args.world_size))
if num_report_sources != 1:
    parser.error("Please specify exactly one of --file-path and --world-size.")
print(args)
# Resolve the .nsys-rep report: either given directly, or derived from
# --profile-dir, --world-size and --rank.
if args.file_path is None:
    profile_dir = Path(args.profile_dir)
    nsys_rep_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.nsys-rep"
else:
    nsys_rep_file_path = Path(args.file_path)
    if not nsys_rep_file_path.name.endswith(".nsys-rep"):
        raise ValueError("Expect a .nsys-rep file")
# Every derived artifact sits next to the report and shares its base name.
report_dir = nsys_rep_file_path.parent
report_stem = nsys_rep_file_path.name[: -len(".nsys-rep")]
sqlite_file_path, csv_file_path, html_file_path, json_file_path = (
    report_dir / (report_stem + suffix) for suffix in (".sqlite", ".csv", ".html", ".json")
)
# Presumably exports the report to SQLite unless already done (parser_utils
# helper); the resulting database is then opened read-only.
lazy_convert_sqlite(nsys_rep_file_path, sqlite_file_path)
conn = sqlite3.connect(f"file:{sqlite_file_path}?mode=ro", uri=True)
# Event kinds are stored as numeric ids; look up the two NVTX kinds used below.
event_types = pd.read_sql_query("SELECT * FROM ENUM_NSYS_EVENT_TYPE", conn)


def _event_type_id(type_name):
    """Return the id of the first ENUM_NSYS_EVENT_TYPE row named *type_name* as a Python scalar."""
    return event_types[event_types["name"] == type_name]["id"].iloc[0].tolist()


event_id_NvtxDomainCreate = _event_type_id("NvtxDomainCreate")
event_id_NvtxPushPopRange = _event_type_id("NvtxPushPopRange")
# Domain id of NCCL's NVTX domain, used later to exclude its ranges; -1 when the
# report has no NCCL domain (presumably never a real domainId -- TODO confirm).
nccl_domain_rows = pd.read_sql_query(
    "SELECT domainId FROM NVTX_EVENTS WHERE eventType = ? AND text = ?",
    conn,
    params=(event_id_NvtxDomainCreate, "NCCL"),
)
if nccl_domain_rows.empty:
    nccl_domain_id = -1
else:
    nccl_domain_id = nccl_domain_rows["domainId"].iloc[0].tolist()
query = """SELECT T1.start, T2.value AS text
FROM NVTX_EVENTS AS T1
JOIN StringIds AS T2 ON T1.textId = T2.id
WHERE eventType = ? AND T2.value LIKE ?"""
df = pd.read_sql_query(query, conn, params=(event_id_NvtxPushPopRange, "layer_wise_benchmarks %"))
problem_start = []
problem_set = []
for start, text in df.itertuples(index=False):
if text.startswith("layer_wise_benchmarks args {"):
run_args = json.loads(text[len("layer_wise_benchmarks args") :])
elif text.startswith("layer_wise_benchmarks problem_spec {"):
problem_start.append(start)
problem_set.append(
{
"spec": json.loads(text[len("layer_wise_benchmarks problem_spec ") :]),
"text": "",
"runs": [],
"runs_end": [],
"ranges": [],
"kernel_count_per_range": [],
}
)
query = """SELECT T1.start, T1.end, T2.value AS text
FROM NVTX_EVENTS AS T1
JOIN StringIds AS T2 ON T1.textId = T2.id
WHERE eventType = ? AND T2.value NOT LIKE ? AND domainId != ?"""
df = pd.read_sql_query(
query,
conn,
params=(event_id_NvtxPushPopRange, "[DG]%", nccl_domain_id),
)
for start, end, text in df.itertuples(index=False):
problem_id = bisect.bisect(problem_start, start) - 1
if text.startswith("layer_wise_benchmarks "):
if text != "layer_wise_benchmarks ignore":
continue
else:
assert problem_id != -1
if re.match(r"b=\d+ s=\d+ ", text):
problem_set[problem_id]["text"] = text
problem_set[problem_id]["runs"].append(start)
problem_set[problem_id]["runs_end"].append(end)
else:
problem_set[problem_id]["ranges"].append((start, end, text))
problem_set[problem_id]["kernel_count_per_range"].append(0)
query = """SELECT name FROM sqlite_master WHERE type = ?"""
df = pd.read_sql_query(query, conn, params=("table",))
tables = df["name"].tolist()
unified_subquery = """SELECT T1.start, T1.end, T1.demangledName, T1.correlationId, T1.graphNodeId
FROM CUPTI_ACTIVITY_KIND_KERNEL AS T1"""
if "CUPTI_ACTIVITY_KIND_MEMCPY" in tables:
unified_subquery += """ UNION ALL
SELECT T2.start, T2.end, -2 AS demangledName, T2.correlationId, T2.graphNodeId
FROM CUPTI_ACTIVITY_KIND_MEMCPY AS T2"""
if "CUPTI_ACTIVITY_KIND_MEMSET" in tables:
unified_subquery += """ UNION ALL
SELECT T3.start, T3.end, -3 AS demangledName, T3.correlationId, T3.graphNodeId
FROM CUPTI_ACTIVITY_KIND_MEMSET AS T3"""
query = f"""SELECT unified.start, unified.end, unified.demangledName,
R.start AS runtime_start, R.end AS runtime_end,
R.start AS capture_start, R.end AS capture_end
FROM ({unified_subquery}) AS unified
JOIN CUPTI_ACTIVITY_KIND_RUNTIME AS R ON unified.correlationId = R.correlationId
WHERE unified.graphNodeId IS NULL"""
if "CUDA_GRAPH_NODE_EVENTS" in tables:
query += f""" UNION ALL
SELECT unified.start, unified.end, unified.demangledName,
R.start AS runtime_start, R.end AS runtime_end,
CGE2.start AS capture_start, CGE2.end AS capture_end
FROM ({unified_subquery}) AS unified
JOIN CUPTI_ACTIVITY_KIND_RUNTIME AS R ON unified.graphNodeId IS NOT NULL AND
unified.correlationId = R.correlationId
LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE1 ON unified.graphNodeId = CGE1.graphNodeId AND
CGE1.originalGraphNodeId IS NOT NULL
LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE2 ON CGE1.originalGraphNodeId = CGE2.graphNodeId"""
df = pd.read_sql_query(query, conn)
kernel_list = []
for (
start,
end,
demangledName,
runtime_start,
runtime_end,
capture_start,
capture_end,
) in df.itertuples(index=False):
problem_id = bisect.bisect(problem_start, start) - 1
problem = problem_set[problem_id]
run_id = bisect.bisect(problem["runs"], runtime_start) - 1
if run_id == -1 or runtime_start >= problem["runs_end"][run_id]:
continue
ranges = [
i
for i, (range_start, range_end, text) in enumerate(problem["ranges"])
if capture_start >= range_start and capture_end <= range_end
]
for range_id in ranges:
problem["kernel_count_per_range"][range_id] += 1
range_names = [problem["ranges"][i][2] for i in ranges]
if (
args.module is None or args.module in range_names
) and "layer_wise_benchmarks ignore" not in range_names:
kernel_list.append(
(
problem_id,
run_id,
range_names,
start,
end,
demangledName,
runtime_start,
runtime_end,
capture_start,
capture_end,
)
)
query = "SELECT * FROM StringIds"
df = pd.read_sql_query(query, conn)
string_ids = dict(zip(df["id"], df["value"]))
string_ids.update({-2: "Memcpy", -3: "Memset"})
conn.close()
# Check ambiguous modules: with --module set, reject the filter when a range of
# that name that contains kernels occurs more than once within one run.
if args.module:
    for problem in problem_set:
        run_starts = problem["runs"]
        # Slot 0 counts ranges starting before the first run; slot k counts ranges
        # starting at/after run k-1 (0-based) and before run k.
        matches_per_slot = [0] * (len(run_starts) + 1)
        for (range_start, _, range_text), kernel_count in zip(
            problem["ranges"], problem["kernel_count_per_range"]
        ):
            if range_text != args.module or kernel_count <= 0:
                continue
            matches_per_slot[bisect.bisect(run_starts, range_start)] += 1
        for run_id_plus_one, num_matches in enumerate(matches_per_slot):
            if num_matches <= 1:
                continue
            raise ValueError(
                f'Module is ambiguous: "{args.module}" appears {num_matches} times'
                f' in "{problem["text"]}"\'s {run_id_plus_one}-th run'
            )
# Order kernels by launch (runtime call start, then capture start) and bucket
# them per problem and per run as (demangledName, start, end, range_names).
kernel_list.sort(key=lambda t: (t[6], t[8]))
kernels = [[[] for _ in problem["runs"]] for problem in problem_set]
for problem_id, run_id, ranges, start, end, demangledName, *_ in kernel_list:
    kernels[problem_id][run_id].append((demangledName, start, end, ranges))
# Every run of a problem must launch the same kernel sequence: per-kernel times
# are averaged position by position across runs further down. Explicit raises
# (not assert) so the check survives `python -O` and reports where it failed.
for problem, runs in zip(problem_set, kernels):
    if not runs:
        raise ValueError(f"No runs recorded for problem {problem['spec']}")
    required_seq = [demangledName for demangledName, _, _, _ in runs[0]]
    for run_id, run in enumerate(runs):
        seq = [demangledName for demangledName, _, _, _ in run]
        if seq != required_seq:
            raise ValueError(
                f'Kernel sequence of run {run_id} of "{problem["text"]}" differs from run 0'
            )
converted_seqs = []
# The first `warmup_times` runs are excluded from every average; --warmup-times
# overrides the value recorded in the benchmark's args marker.
warmup_times = run_args["warmup_times"] if args.warmup_times is None else args.warmup_times


def _overlap_and_space(run):
    """Return ``(overlap_time, space_time)`` for one run's ``(name, start, end, ranges)`` entries.

    Walking the entries in start order, ``space_time`` sums the gaps in which no
    kernel is executing and ``overlap_time`` sums the time each kernel shares with
    kernels that started earlier. An empty run (e.g. --module matched nothing in
    this problem) yields ``(0, 0)`` instead of raising IndexError.
    """
    if not run:
        return 0, 0
    sorted_run = sorted(run, key=lambda op: op[1])
    last_end = sorted_run[0][1]
    overlap_time = 0
    space_time = 0
    for _, start, end, _ in sorted_run:
        if start > last_end:
            space_time += start - last_end
        else:
            overlap_time += min(last_end, end) - start
        last_end = max(last_end, end)
    return overlap_time, space_time


for runs in kernels:
    converted_seq = []
    # Kernel time: mean duration of each kernel position over the measured runs.
    for i, (demangledName, _, _, ranges) in enumerate(runs[0]):
        name = kernel_short_name(string_ids[demangledName])
        category = (*ranges, name)
        time_list = [run[i][2] - run[i][1] for run in runs]
        t = np.mean(time_list[warmup_times:]).tolist()
        converted_seq.append((category, t))
    # Space and Overlap. Overlap is stored negated so that kernel times + Overlap
    # + Space equals the span from the first kernel start to the last kernel end.
    overlap_list = []
    space_list = []
    for run in runs:
        overlap_time, space_time = _overlap_and_space(run)
        overlap_list.append(-overlap_time)
        space_list.append(space_time)
    converted_seq.append((("Overlap",), np.mean(overlap_list[warmup_times:]).tolist()))
    converted_seq.append((("Space",), np.mean(space_list[warmup_times:]).tolist()))
    converted_seq.append((("Total",), sum(t for _, t in converted_seq)))
    converted_seqs.append(converted_seq)
# warned_names is presumably filled by kernel_short_name for unrecognized names.
if args.error_on_unknown_kernel and warned_names:
    raise ValueError("Unknown kernel names encountered")
# Merge the per-problem row sequences into one shared row order using
# shortest_common_supersequence, then lay out a rows x problems table in which
# rows a problem lacks stay 0.0.
merged_title = []
for row_titles in ([category for category, _ in seq] for seq in converted_seqs):
    merged_title = shortest_common_supersequence(merged_title, row_titles)
merged_data = [[0.0] * len(problem_set) for _ in merged_title]
for problem_id, converted_seq in enumerate(converted_seqs):
    search_from = 0
    for category, t in converted_seq:
        # Rows appear in merged_title in the same relative order as in converted_seq.
        row = merged_title.index(category, search_from)
        merged_data[row][problem_id] = t
        search_from = row + 1
print("Run args:")
print(run_args)
print("Problem set:")
for problem in problem_set:
print(
f'- "{problem["text"]}" {len(problem["runs"])} runs'
f" Ranges: [{', '.join(text for _, end, text in problem['ranges'] if end <= problem['runs_end'][0])}]"
)
# Render the merged breakdown as an indented tree, simultaneously to stdout
# (fixed-width columns, times divided by 1000), to csv_data, and to js_data
# (nested {"name", "children"} groups and {"name", "time"} leaves for the HTML
# report). `stack` holds the current path of enclosing range names; `js_stack`
# holds the children list being filled at each level (js_stack[0] is js_data).
stack = []
csv_data = [["", *[problem["text"] for problem in problem_set]]]
js_data = []
js_stack = [js_data]
# Title column width: 3 chars ("|--") per nesting level plus the leaf name
# truncated to 40 chars.
max_title_len = max((len(title) - 1) * 3 + len(title[-1][:40]) for title in merged_title)
print("-" * (max_title_len + 1 + 6 * len(problem_set)))
for title, time_data in zip(merged_title, merged_data):
    # Close levels that are not a prefix of this row's path, folding each
    # finished children list into its parent as a named group.
    while stack != list(title[: len(stack)]):
        level_title = stack[-1]
        stack.pop()
        js_stack[-2].append(
            {
                "name": level_title,
                "children": js_stack[-1],
            }
        )
        js_stack.pop()
    # Open the missing enclosing levels of this row, emitting a header line
    # (no times) for each and starting a fresh children list.
    while len(stack) != len(title) - 1:
        level_title = title[len(stack)]
        stack.append(level_title)
        level = len(stack)
        print("|--" * (level - 1) + level_title)
        csv_data.append(["|--" * (level - 1) + level_title] + [""] * len(problem_set))
        js_stack.append([])
    # Emit the leaf row itself, one time column per problem.
    level = len(stack) + 1
    print(
        "|--" * (level - 1)
        + title[-1][:40]
        + " " * (max_title_len - (level - 1) * 3 - len(title[-1][:40])),
        *[f"{x / 1000:-6.1f}" for x in time_data],
    )
    csv_data.append(["|--" * (level - 1) + title[-1], *[f"{x / 1000:.1f}" for x in time_data]])
    # The "Total" row is printed and written to CSV but not added to the HTML data.
    if title != ("Total",):
        js_stack[-1].append(
            {
                "name": title[-1],
                "time": [x / 1000 for x in time_data],
            }
        )
# TODO: Group repeated modules
with csv_file_path.open("w", newline="") as csv_file:
    # newline="" leaves line-ending handling to the csv module, as its docs require.
    csv.writer(csv_file, quoting=csv.QUOTE_MINIMAL).writerows(csv_data)
js_header_config = [{"name": problem["text"]} for problem in problem_set]
js_header_config = []
for problem in problem_set:
innermost_children = js_header_config
for k, msg_prefix in [
("batch_size", "b="),
("seq_len_q", "q="),
("seq_len_kv_cache", "past="),
]:
if len(run_args[k + "_list"]) > 1:
if len(innermost_children) == 0 or problem["spec"][k] != innermost_children[-1][k]:
innermost_children.append(
{
"name": msg_prefix + str(problem["spec"][k]),
"children": [],
k: problem["spec"][k],
}
)
innermost_children = innermost_children[-1]["children"]
innermost_children.append({"name": problem["text"]})
loader = jinja2.FileSystemLoader(Path(__file__).parent)
template = jinja2.Environment(loader=loader).get_template("breakdown_template.html")
with html_file_path.open("w") as f:
configText = (
"Run:\n"
+ json.dumps(run_args, indent=4)
+ "\n\nParse:\n"
+ json.dumps(args.__dict__, indent=4)
)
f.write(template.render(headerConfig=js_header_config, rawData=js_data, configText=configText))
# Optional --query: comma-separated substrings. For each one, sum per problem
# every breakdown row whose "."-joined title path contains it.
if args.query is not None:
    print("Query:")
    for needle in (part.strip() for part in args.query.split(",")):
        selected_rows = [
            time_data
            for title, time_data in zip(merged_title, merged_data)
            if needle in ".".join(title)
        ]
        # Summed in row order starting from 0.0, column by column.
        query_matched = [
            sum((row[column] for row in selected_rows), 0.0)
            for column in range(len(problem_set))
        ]
        print(
            needle + " " * (max_title_len - len(needle)),
            *[f"{x / 1000:-6.1f}" for x in query_matched],
        )
# Per-problem kernel timeline for the .json report: for each kernel position,
# its full name, mean duration, and mean end offset from the start of the run's
# first-launched kernel, all averaged over the post-warm-up runs.
correlation = []
for problem, runs in zip(problem_set, kernels):
    reference_run = runs[0]
    timeline = []
    for position, (name_id, _, _, _) in enumerate(reference_run):
        kernel_name = string_ids[name_id]
        durations = [run[position][2] - run[position][1] for run in runs]
        end_offsets = [run[position][2] - run[0][1] for run in runs]
        timeline.append(
            {
                "name": kernel_name,
                "duration": np.mean(durations[warmup_times:]).tolist(),
                "end": np.mean(end_offsets[warmup_times:]).tolist(),
            }
        )
    correlation.append({"name": problem["text"], "timeline": timeline})
with json_file_path.open("w") as f:
    json.dump(correlation, f)