[None] [feat] nsys profile output kernel classifier (#7020)

Signed-off-by: Grace Ho <grho@nvidia.com>
This commit is contained in:
Grace Ho 2025-08-22 21:57:37 -07:00 committed by GitHub
parent 81fd468fec
commit 3d54a1a521
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 585 additions and 0 deletions

View File

@ -0,0 +1,174 @@
# gputrc2graph.py
This script processes NVIDIA Nsight Systems (`nsys`) GPU trace files
(`.nsys-rep`) with -t cuda tracing enabled, and generates kernel-level
summaries and visualizations of GPU and non-GPU time. It is useful for
profiling and analyzing nsys profile output.
## Usage
### Command-line Arguments
- `--in_file`
**(required)**
List of input files and their metadata. Each entry should be in the format:
`<nsys-rep>,<engine>,<model>,<elapsed_nonprofiled_sec>`
- `nsys-rep`: Path to the `.nsys-rep` file.
- `engine`: Engine name (e.g., `trtllm`).
- `model`: Model name (e.g., `llama`, `gpt-oss`, `ds`).
- `elapsed_nonprofiled_sec`: Wall-clock runtime (in seconds) without
profiling. Specify `0` to use the elapsed GPU time calculated from the nsys-rep file (this may inflate non-GPU time if actual runtime without profiling is less). Multiple entries can be provided, separated by spaces.
- `--out_dir`
Output directory for the generated CSV and HTML files.
If not specified, results are saved in the current directory.
- `--title`
Title for the HTML chart/visualization.
- `--nsys_cmd`
Path to the `nsys` command.
Default: `nsys` (assumes it is in your PATH).
Use this if `nsys` is not in your system PATH.
## Notes
- Make sure you have pandas and plotly python packages installed.
- Make sure [nsys](https://developer.nvidia.com/nsight-systems/get-started) is
installed, and specify the path to the `nsys` command with `--nsys_cmd` if it
is not in your PATH.
- For more details on available engines and models, see the help string in
the script or run:
```bash
python3 gputrc2graph.py --help
```
## Example 1: analyze a single profile
To analyze the GPU cycles of for example, a llama-3.1-8B model with trtllm:
1. Run the following command to collect nsys profile, for trtllm serve config.
```bash
nsys profile -t cuda -o nsys_res -f true --trace-fork-before-exec=true \
--cuda-graph-trace=node --delay <DELAY> --duration <DURATION> \
python3 -m trtllm-serve meta-llama/Llama-4-Scout-17B-16E-Instruct ...
```
where:
- DELAY: how many seconds to delay nsys from collecting profiles, needed so
that profiles aren't captured till trtllm server has come up and load
generation starts.
- DURATION: how many seconds for nsys profile to run before generating the
profile. This should be > the duration of the run.
2. Run again, this time without collecting the profile, and get the total run
time in seconds. This value will be used by the script to calculate the
CPU(non-GPU) seconds for the analysis.
3. Say the run elapsed time is .35 seconds, from step #2. Run script to
analyze:
```bash
python3 gputrc2graph.py \
--in_file run1.nsys-rep,trtllm,llama,.35
```
The command will produce 2 files for analysis:
- result.html: this categorizes kernel names into different categories in a
stacked bar chart.
- result.csv: shows how the kernel names are mapped to the different
categories.
### HTML visualization with result.html
The html file shows the number of elapsed seconds due to different GPU
Substages or categories, which consist of moe_gemm as the biggest
category, at .14 seconds, followed by "attn" kernels. This lets the user
prioritize the kernels to focus on for performance optimizations.
![Example GPU Trace Visualization](images/html.png)
There's also an appended data table underneath the bar chart for copying out to
other post-processing tools.
![Example GPU Trace Visualization Table](images/html_tbl.png)
### Kernel to category mapping with result.csv
Suppose the user would like to focus on improving decreasing calls to nccl
kernels. The next step is to use the result.csv to dive into what the kernels
are which compose the nccl GPU cycles. The following image shows that
ar_fusion all reduce kernel to be the biggest contributor to GPU cycles for
nccl, followed by AllGather.
![Example GPU Trace csv](images/csv.png)
## Example 2: analyze multiple profiles
Suppose the user has multiple nsys trace files, captured for different models,
say llama and gpt-oss in this case, and wish to compare their GPU/non-GPU
time, something like the following command can be used.
```bash
python3 gputrc2graph.py \
--in_file run1.nsys-rep,trtllm,llama,100 run2.nsys-rep,trtllm,gpt-oss,102 \
--out_dir results
```
The analysis process is similar to example 1 but now there will be multiple
stack bar charts that can be compared. The categories for the different
kernels will remain the same, so that it's easy to compare the GPU cycles for
the same categories.
Once a category is shown to have more cycles for one configuration than
another, the next step would be to use the csv file to see what kernels are
mapped into that category, and which kernels are taking the largest amount of
time which would cause a difference for the overall category.
## Example 3: add new classification for a new model
To create a new engine DEF with model ABC, just add another json file in the
same directory as gputrc2graph.py with the same format as the other json files.
The script will automatically pick up all the json files in the same directory
as engine/model specifications.
Then, for this new model, suppose there are 4 kernels to be classified into
"gemm" and "attn", where the gemm kernelshave names with "*H*" or "*I*" in
them, and attn kernels have names with "*J*" or "*K*" in them, just add another
.json file in the same directory as gputrc2graph.py with the same format as
the other json files, like the following:
```json
{
"DEF": {
"ABC": {
"H|I": "gemm",
"J|K": "attn",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
}
}
}
```
Each entry in the dictionary consists of:
- key: a regex used to classify the kernels
- value: the category to classify the kernels into.
The last 2 entries are common for all engine/models, consisting of CUDA memory
operations and a 'misc' for anything that's leftover and can't be classified.
When invoking gputrc2graph.py, specify a trace file with this new model/engine
like the following:
```bash
--in_file new.nsys-rep,DEF,ABC,<runtime>
```
If the engine_DEF.json file already exists, just add the model as a new node in
the existing engine file, after the other models.

View File

@ -0,0 +1,349 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This generates gpu kernel analysis output from nsys rep. Will call nsys
stats -r cuda_gpu_trace, get non-overlapped gpu cycles, then generate
csv and html output for analysis
"""
import argparse
import logging
import os
import regex as re
logger = logging.getLogger(__name__)
# helper data class for annotating kernels
def load_engine_model():
"""returns engine_model built from all json files in the current dir"""
import glob
import json
engine_model = {}
json_files = glob.glob(
os.path.join(os.path.dirname(__file__) or ".", "*.json"))
for fname in json_files:
with open(fname, encoding="utf-8") as f:
engine_model.update(json.load(f))
return engine_model
class GPUTrace2Graph:
"""
Parses output of nsys report, generates csv and bar chart output
"""
def __init__(self):
import pandas as pd # avoid importing till needed
self.pd = pd
self.pd.options.mode.copy_on_write = True
# helper functions for generating trace->summary csvs
def gen_nonoverlapped_sum_from_gputrace(self, in_file, out_file):
logger.info("loading %s", in_file)
df = self.pd.read_csv(in_file,
usecols=["Start (ns)", "Duration (ns)", "Name"])
if df.empty:
return
df["End (ns)"] = df["Start (ns)"] + df["Duration (ns)"]
df = self.sum_non_overlapping_intervals(df)
# get ready to print table with elapsed times per kernel
df["Instances"] = 1
df_sum = df.groupby("Name", as_index=False).agg({
"Elapsed Time (ns)": "sum",
"Duration (ns)": "sum",
"Instances": "size"
})
# generate csv
df_sum["Total Time (sec)"] = df_sum["Duration (ns)"] / 1e9
df_sum["Elapsed Time (sec)"] = df_sum["Elapsed Time (ns)"] / 1e9
df_sum = df_sum.sort_values(by="Elapsed Time (sec)", ascending=False)
df_sum[["Elapsed Time (sec)", "Total Time (sec)", "Instances",
"Name"]].to_csv(out_file, index=False)
def sum_non_overlapping_intervals(self, df):
"""
returns new sorted df with Elapsed Time (ns) column using
vectorized operations
"""
logger.info("sorting %s trace records by start time", str(df.shape))
assert not df.empty, 'empty nsys records'
# Sort by start time and reset index
df = df.sort_values(by="Start (ns)").reset_index(drop=True)
# Initialize elapsed time as duration
df["Elapsed Time (ns)"] = df["Duration (ns)"]
# Get numpy arrays for faster operations
starts = df["Start (ns)"].values
ends = df["End (ns)"].values
# Keep track of current interval end
current_end = ends[0]
display_units = max(1, int(len(df) / 100))
# Update current_end for overlapping intervals
for i in range(1, len(df)):
if i % display_units == 0:
print(f"processing trace: {int(i/len(df) * 100)} %", end="\r")
if starts[i] <= current_end:
if ends[i] > current_end:
# Partial overlap
df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = (
ends[i] - current_end)
current_end = ends[i]
else:
# Complete overlap
df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = 0
else:
# No overlap
current_end = ends[i]
return df
# functions for generating html files
def make_html(self, df, output_dir, title):
"""make html graph from df"""
import plotly.express as px
if df.empty:
return
output_name = os.path.join(output_dir, "result")
if not title:
title = "Model_Engine"
x = "Model_Engine"
y = "Elapsed Time (sec)"
color = "Category"
""" generate kernel mapping table """
# Sort Model_Engine categories by last field after underscore
df["Model_Engine"] = self.pd.Categorical(
df["Model_Engine"],
sorted(df["Model_Engine"].unique(), key=lambda x: x.split("_")[-1]),
)
df[["Model_Engine", color, "Instances", "Name",
y]].sort_values(by=color).to_csv(f"{output_name}.csv", index=False)
graph = px.histogram(
df.round(2),
x=x,
y=y,
title=(f"{y} for {title}"),
color=color,
text_auto=True,
)
# wrap x axis labels
graph.update_xaxes(automargin=True)
graph.write_html(f"{output_name}.html")
"""
Generate data table with columns per Model_Engine into result.html
"""
pivot_df = df.pivot_table(
values="Elapsed Time (sec)",
index="Category",
columns="Model_Engine",
aggfunc="sum",
observed=False,
).round(2)
# Add sum row at bottom
pivot_df.loc["total_elapsed_sec"] = pivot_df.sum()
pivot_df.fillna("").to_html("temp.html")
with (
open(f"{output_name}.html", "a", encoding="utf-8") as outfile,
open("temp.html", encoding="utf-8") as infile,
):
outfile.write(infile.read())
os.remove("temp.html")
print(f"Finished generating: \n"
f" {output_name}.html for stack bar chart \n"
f" {output_name}.csv for Kernel-Category mapping")
def anno_gpu_kernname(self, df, mapping):
"""add "Category" column"""
def anno_gpu_kernname_helper(name):
for kern_name, val in mapping.items():
if re.search(kern_name, name):
return val
df["Category"] = df["Name"].apply(anno_gpu_kernname_helper)
def make_nongpu_row(self, df, nongpu_sec):
"""this will append non-gpu time entry at end of df"""
nongpu_row = self.pd.DataFrame([df.iloc[-1]])
nongpu_row["Category"] = nongpu_row["Name"] = "CPU(non-GPU)"
nongpu_row["Instances"] = 1
nongpu_row["Elapsed Time (sec)"] = nongpu_sec
return nongpu_row
def is_valid_file(self, base_file):
"""asserts if base_file is non-existent or is empty"""
assert (os.path.isfile(base_file) and os.path.getsize(base_file)
> 0), f"{base_file} doesn't exist or is empty"
def should_gen_file(self, new_file, base_file):
"""figure out if new file should be generated from base_file"""
self.is_valid_file(base_file)
if (os.path.exists(new_file)
and (os.path.getmtime(new_file) > os.path.getmtime(base_file))
and (os.path.getsize(base_file) > 0)):
logger.info("reusing %s", new_file)
return False
else:
logger.info("generating %s", new_file)
return True
def gen_sum_file(self, file, nsys_cmd):
"""
generates sum file from nsys trace with times per kernel and
returns the name of the sum file
"""
import subprocess # nosec B404
file_dir = os.path.dirname(file)
file_name = os.path.basename(file)
if not file_dir:
file_dir = "."
# Walk through trace and get the total non-overlapped time
nsys_stats_file = os.path.join(file_dir,
f"{file_name}_cuda_gpu_trace.csv")
sum_file = os.path.join(file_dir,
f"{file_name}_cuda_gpu_kernel_tracesum.csv")
if self.should_gen_file(nsys_stats_file, file):
cmd = [
nsys_cmd,
"stats",
"-r",
"cuda_gpu_trace",
file,
"-o",
f"{file_dir}/{file_name}",
]
cmd_str = " ".join(cmd)
logger.info("+ %s", cmd_str)
# estimate time based on calibrated 240M/min
file_size_mb = os.path.getsize(file) / 1e6
logger.info(
"nsys stats for %.2f MB file expected to take %.2f min",
file_size_mb,
file_size_mb / 240,
)
try:
subprocess.run(cmd)
except Exception:
logger.error("%s failed; Use --nsys_cmd to specify nsys path",
cmd_str)
exit(1)
logger.info("generating non-overalapped sum %s", sum_file)
self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file)
self.is_valid_file(sum_file)
logger.info("Finished generating %s", sum_file)
return sum_file
def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model):
"""generates graph and csv file from in_file into out_dir"""
# Initialize an empty DataFrame to store combined data
combined_df = self.pd.DataFrame()
for idx, (file, engine, model, total_sec) in enumerate(in_file):
file_dir = os.path.dirname(file)
file_name = os.path.basename(file)
if not file_dir:
file_dir = "."
sum_file = self.gen_sum_file(file, nsys_cmd)
# read kernel summary file
df = self.pd.read_csv(sum_file)
# annotate kernel to their categories
assert engine_model.get(engine), f"engine {engine} unknown"
assert engine_model[engine].get(model), f"model {model} unknown"
# remove nsys-rep from file_name for shorter x-label
file_name = file_name.replace(".nsys-rep", "")
df["Model_Engine"] = f"{model}_{engine}_{file_name}_{idx}"
self.anno_gpu_kernname(df, engine_model[engine][model])
# patch in non-gpu time
gpu_sec = round(df["Elapsed Time (sec)"].sum(), 1)
total_sec = round(float(total_sec), 1)
if total_sec < gpu_sec:
logger.warning(
"Elapsed sec %.2f < GPU sec %.2f resetting Elapsed sec ",
total_sec,
gpu_sec,
)
total_sec = gpu_sec
nongpu_row = self.make_nongpu_row(df, total_sec - gpu_sec)
df = self.pd.concat([df, nongpu_row], ignore_index=True)
combined_df = self.pd.concat([combined_df, df], ignore_index=True)
if out_dir is None:
out_dir = "."
else:
os.makedirs(out_dir, exist_ok=True)
# generate html file
self.make_html(combined_df, out_dir, title)
def parse_tuple(s):
return tuple(s.split(","))
def main():
logging.basicConfig(format=("%(asctime)s - %(levelname)s - %(message)s"),
level=logging.INFO)
parser = argparse.ArgumentParser(
description=(
"Process nsys rep and generate kernel non-overlapped cycles. \n"
"Example:\n"
"gputrc2graph.py --in_file d1.nsys-rep,trtllm,llama,100 \n"
"d2.nsys-rep,trtllm,gpt-oss,102 "
'--out_dir results/ --title "Model=gpt-oss TRTLLM chart"'),
formatter_class=argparse.RawDescriptionHelpFormatter,
)
# load supported engine_model
engine_model_supported = load_engine_model()
# Get a string representation of supported engine/model combinations
engine_model_supported_str = ", ".join(
f"{engine}:[{', '.join(models.keys())}]"
for engine, models in engine_model_supported.items())
parser.add_argument(
"--in_file",
type=parse_tuple,
nargs="+",
help=("list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) "
"separated by space. Elapsed_nonprofiled_sec is runtime without "
"profiling used to calculate non-gpu time. Specify 0 to use "
"elapsed time from nsys-rep but that might inflate non-gpu time. "
f"Available engine:[model] are: {engine_model_supported_str} "
f"Example: --in_file d1.nsys-rep,sglan,llama,100 "
"d2.nsys-rep,trtllm,gpt-oss,102"),
required=True,
)
parser.add_argument("--out_dir", help=("output dir for result.csv/html"))
parser.add_argument("--title", help=("title for html chart"))
parser.add_argument(
"--nsys_cmd",
help=("nsys cmd, e.g. /usr/bin/nsys, Default: nsys"),
default="nsys",
)
args = parser.parse_args()
gputrace = GPUTrace2Graph()
gputrace.gen_graph(args.in_file, args.out_dir, args.title, args.nsys_cmd,
engine_model_supported)
if __name__ == "__main__":
main()

Binary file not shown.

After

Width:  |  Height:  |  Size: 132 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 140 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 150 KiB

View File

@ -0,0 +1,62 @@
{
"trtllm": {
"llama": {
"Fused_Moe_Kernel|gemm::|fused_moe|bmm_|GemmUniversal": "moe_gemm",
"gemm|nvjet_": "gemm",
"moe|Expert|Moe": "moe",
"CatArrayBatched": "prepare_next",
"ncclDevKernel|AllReduce": "nccl_and_custom_ar",
"RMSNormKernel": "norm",
"topk": "topk",
"act_and_mul_|Activation": "activation",
"Rotary": "rope",
"SoftMax": "softmax",
"flash|splitKreduce|kernel_mha|mmha|fmha": "attn",
"elementwise": "elementwise",
"Quantize|cvt_": "quantize",
"reduce_kernel": "reduce",
"triton": "triton_kernel",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
},
"ds": {
"fp8_blockscale_gemm": "block_fp8_gemm",
"gemm::GroupProblemShape|Fused_Moe_Kernel|bmm_": "moe_gemm",
"gemm|matmul|nvjet|gemvx": "gemm",
"moe|buildExpertMaps|Moe|Expert|Moe": "moe",
"CatArrayBatched": "prepare_next",
"ncclDevKernel|cross_device_reduce|AllReduce": "nccl_and_custom_ar",
"Norm|_norm_": "norm",
"topk": "topk",
"act_and_mul_|Activation": "activation",
"Rope": "rope",
"elementwise": "elementwise",
"fmha|flash_fwd_kernel": "attn",
"Quantize|fp8_quant|quant_fp8|cvt_": "quantize",
"reduce": "reduce",
"SoftMax": "softmax",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
},
"gpt-oss": {
"block_fp8|gemm_fp8_blockwise": "block_fp8_gemm",
"fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_|matmul_ogs_|_topk_forward|_combined_routing|_sum_bitmatrix_rows|_compute_writeback_idx": "moe_gemm",
"gemm|matmul|nvjet": "gemm",
"moe|sigmoid|expert|splitKreduce|Moe": "moe",
"CatArrayBatched": "prepare_next",
"ncclDevKernel|cross_device_reduce|AllReduce": "nccl_and_custom_ar",
"Norm|_norm_": "norm",
"sbtopk": "topk",
"act_and_mul_|Activation": "activation",
"Rope": "rope",
"elementwise": "elementwise",
"fp8_quant|quant_fp8|cvt_": "quantize",
"reduce": "reduce",
"SoftMax": "softmax",
"fmha|mha|flash_fwd_kernel": "attn",
"triton": "triton_kernel",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
}
}
}