[None][infra] AutoDeploy: Dump graph IR after every transform (#11045)
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
parent: e719721a60
commit: 4a743338c3
@@ -1,5 +1,6 @@
"""Main export functionality with utilities for torch.export."""

import re
from collections import defaultdict
from contextlib import nullcontext
from functools import partial
@@ -188,6 +189,49 @@ def _add_load_hook_for_aliased_params(gm: fx.GraphModule, model: nn.Module) -> None:
    gm._register_load_state_dict_pre_hook(aliasing_load_pre_hook)


def _rename_nodes_with_module_hierarchy(gm: fx.GraphModule) -> None:
    """Rename call_function nodes to reflect their module hierarchy.

    Uses nn_module_stack metadata to build hierarchical names like:
    'layers_0_self_attn_linear' instead of 'linear_2'
    """
    graph = gm.graph

    for node in graph.nodes:
        if node.op != "call_function":
            continue

        meta = getattr(node, "meta", None)
        if not isinstance(meta, dict):
            continue

        nn_stack = meta.get("nn_module_stack")
        if not nn_stack or not isinstance(nn_stack, dict):
            continue

        # Get innermost module path from the stack
        # nn_module_stack is OrderedDict: {path: (qualified_name, module_class), ...}
        module_path = list(nn_stack.keys())[-1] if nn_stack else ""
        # Strip the "L__self__" prefix that torch.export adds (internal representation)
        module_path = re.sub(r"^L__self__[._]?", "", module_path)

        # Get op name from target
        target = node.target
        if hasattr(target, "__name__"):
            op_name = target.__name__
        elif hasattr(target, "_name"):
            op_name = target._name
        else:
            op_name = str(target).split(".")[-1]

        unique_name = graph._graph_namespace.create_name(op_name, node)
        # Build new name: module_path + op_name (dots -> underscores)
        if module_path:
            node.name = f"{module_path}_{unique_name}".replace(".", "_")

    gm.recompile()


def _clean_up_assertions_and_guards(gm: fx.GraphModule):
    """This transformation removes shape checks and assertions from the graph."""
    check_ops = {
@@ -341,6 +385,9 @@ def torch_export_to_gm(
    # clean up checks --> generally the sanity checks are overly conservative and we can remove them
    _clean_up_assertions_and_guards(egm)

    # Rename nodes to reflect module hierarchy for better debuggability
    _rename_nodes_with_module_hierarchy(egm)

    # show exported graph
    ad_logger.debug("exported graph: " + str(egm))
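For illustration, a hedged sketch of what the renaming does, assuming `_rename_nodes_with_module_hierarchy` is in scope (the toy model and the exact resulting names are hypothetical and vary with the torch version's nn_module_stack format):

import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([Block(), Block()])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

gm = torch.export.export(Model(), (torch.randn(2, 8),)).module()
# Before: generic names such as 'linear' and 'linear_1'.
_rename_nodes_with_module_hierarchy(gm)  # helper added in this diff
# After: hierarchical names along the lines of 'layers_0_linear_linear' and
# 'layers_1_linear_linear_1', derived from each node's nn_module_stack metadata.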
@@ -28,6 +28,7 @@ from ..utils._graph import (
    run_shape_prop,
)
from ..utils.cuda_mem_tracker import get_mem_info
from ..utils.graph_writer import graph_writer
from ..utils.logger import ad_logger
from .graph_module_visualizer import to_dot

@@ -487,6 +488,9 @@ class BaseTransform(ABC):
        self._set_autodeploy_meta(mod, autodeploy_meta)
        self._visualize_graph(mod, idx)

        # Dump graph after transform for debugging (controlled by AD_DUMP_GRAPHS_DIR env var)
        graph_writer.dump_graph(mod, t_name, self.config.stage.value)

        # return the graph module
        return mod
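A short usage note: dumping is opt-in via the environment variable referenced above. A minimal sketch (the directory path is an example, not mandated by the code):

# Sketch: enable per-transform graph dumps before the AutoDeploy pipeline runs.
# AD_DUMP_GRAPHS_DIR is the env var used by graph_writer.py (added below);
# the path itself is hypothetical.
import os

os.environ["AD_DUMP_GRAPHS_DIR"] = "/tmp/ad_graph_dumps"
# After the transforms run, the directory holds one file per transform,
# named like 001_<stage>_<transform_name>.txt, each with the SSA-style IR.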
tensorrt_llm/_torch/auto_deploy/utils/graph_writer.py (new file, 146 lines)
@@ -0,0 +1,146 @@
import os
import shutil
from pathlib import Path
from typing import TextIO

import torch.nn as nn
from torch.fx import GraphModule

from ....logger import Singleton
from .logger import ADLogger


def _get_dtype_or_type(val):
    """Get dtype if tensor-like, otherwise return type name for SymInt/SymFloat etc."""
    if hasattr(val, "dtype"):
        return val.dtype
    else:
        # For SymInt, SymFloat, etc. - return the type name
        return type(val).__name__


def _get_shape_str(val):
    """Get shape as 'dim0xdim1x...' string, or '?' if not available."""
    if hasattr(val, "shape"):
        # Handle symbolic dimensions (SymInt) by converting to str
        dims = [str(int(d)) if str(d).isdigit() else str(d) for d in val.shape]
        return "x".join(dims) if dims else "scalar"
    return "?"


def _get_shape_dtype_str(val):
    """Return 'shape : dtype' string for a value."""
    shape = _get_shape_str(val)
    dtype = _get_dtype_or_type(val)
    return f"{shape} : {dtype}"
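
# Illustrative, assumed helper outputs (values hypothetical; the actual strings
# depend on the fake-tensor metadata stored in node.meta["val"]):
#   _get_shape_dtype_str(<FakeTensor shape (2, 8, 16), dtype float16>) -> "2x8x16 : torch.float16"
#   _get_shape_dtype_str(<SymInt>)                                     -> "? : SymInt"
#   _get_shape_str(<0-dim tensor>)                                     -> "scalar"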

def dump_ssa_with_meta(f: TextIO, mod: GraphModule) -> None:
    for node in mod.graph.nodes:
        # Write out IR in traditional SSA style
        if node.op == "placeholder":
            if "val" in node.meta:
                shape_dtype = _get_shape_dtype_str(node.meta["val"])
            else:
                shape_dtype = "? : unknown"
            f.write(f"%{node.name} : {shape_dtype}\n")
        elif node.op in ("call_function", "call_method", "call_module"):
            # Build inputs list in SSA format with shape:dtype info
            input_vars = []
            for arg in node.args:
                if hasattr(arg, "name"):
                    # Look up the arg node's metadata for shape/dtype
                    if hasattr(arg, "meta") and "val" in arg.meta:
                        arg_shape_dtype = _get_shape_dtype_str(arg.meta["val"])
                        input_vars.append(f"%{arg.name} : {arg_shape_dtype}")
                    else:
                        input_vars.append(f"%{arg.name} : ? : unknown")
                else:
                    input_vars.append(str(arg))

            # Handle output shape/dtype (including multi-output)
            if "val" in node.meta:
                out_val = node.meta["val"]
                if isinstance(out_val, (tuple, list)):
                    # Multi-output: (shape1, shape2) : (dtype1, dtype2)
                    shapes = []
                    dtypes = []
                    for v in out_val:
                        if v is not None:
                            shapes.append(_get_shape_str(v))
                            dtypes.append(str(_get_dtype_or_type(v)))
                        else:
                            shapes.append("?")
                            dtypes.append("None")
                    out_info = f"({', '.join(shapes)}) : ({', '.join(dtypes)})"
                else:
                    out_info = _get_shape_dtype_str(out_val)
            else:
                out_info = "? : N/A"
            # Standard SSA notation: %out = op(args) : shape : dtype
            f.write(f"%{node.name} = {node.target}({', '.join(input_vars)}) : {out_info}\n")
        elif node.op == "output":
            # Output assignment in SSA IR
            outputs = node.args[0] if isinstance(node.args[0], (tuple, list)) else [node.args[0]]
            output_vars = []
            for o in outputs:
                if hasattr(o, "name"):
                    output_vars.append(f"%{o.name}")
                else:
                    output_vars.append(str(o))
            f.write(f"output {', '.join(output_vars)}\n")
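
# Example of the IR emitted by the f.write calls above (names, shapes, and
# dtypes are hypothetical, for illustration only):
#   %x : 2x8 : torch.float16
#   %weight : 8x8 : torch.float16
#   %linear = aten.linear.default(%x : 2x8 : torch.float16, %weight : 8x8 : torch.float16) : 2x8 : torch.float16
#   output %linear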

class GraphWriter(metaclass=Singleton):
    DUMP_GRAPHS_ENV = "AD_DUMP_GRAPHS_DIR"

    def __init__(self):
        self._dump_dir = os.environ.get(self.DUMP_GRAPHS_ENV)
        self._logger = ADLogger()
        self._transform_counter = 0
        self._dump_dir_initialized = False

    def dump_graph(self, mod: nn.Module, transform_name: str, stage: str) -> None:
        """Dump the FX graph (SSA-style) to a file after a transform."""
        if not self._dump_dir:
            return

        # Only dump from main process (rank 0) or single-process mode (rank is None)
        if self._logger.rank is not None and self._logger.rank != 0:
            return

        # Lazy directory initialization (only on rank 0 / main process)
        if not self._dump_dir_initialized:
            dump_dir_path = Path(self._dump_dir)
            if dump_dir_path.exists():
                shutil.rmtree(dump_dir_path)
            dump_dir_path.mkdir(parents=True, exist_ok=True)
            self._logger.info(f"Graph dumping enabled to: {self._dump_dir}")
            self._dump_dir_initialized = True

        # Collect all GraphModules (including from submodules)
        graph_modules = []
        for name, submod in mod.named_modules():
            if isinstance(submod, GraphModule):
                graph_modules.append((name if name else "(root)", submod))

        if not graph_modules:
            return  # No GraphModules found

        self._transform_counter += 1
        filename = f"{self._transform_counter:03d}_{stage}_{transform_name}.txt"
        filepath = Path(self._dump_dir) / filename

        with open(filepath, "w") as f:
            f.write(f"# Transform: {transform_name}\n")
            f.write(f"# Stage: {stage}\n")
            f.write(f"# GraphModules found: {len(graph_modules)}\n\n")

            for module_name, gm in graph_modules:
                f.write(f"\n{'=' * 80}\n")
                f.write(f"# GraphModule: {module_name}\n")
                f.write(f"{'=' * 80}\n\n")
                dump_ssa_with_meta(f, gm)


graph_writer = GraphWriter()
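Taken together, a hedged end-to-end sketch of exercising the writer directly (toy module and path are hypothetical; note that `graph_writer` reads `AD_DUMP_GRAPHS_DIR` when the singleton is constructed at import time, so the variable must be set before the import):

import os

os.environ["AD_DUMP_GRAPHS_DIR"] = "/tmp/ad_graph_dumps"  # must precede the import

import torch
from tensorrt_llm._torch.auto_deploy.utils.graph_writer import graph_writer

# Any module tree containing fx.GraphModules works; an exported toy module here.
gm = torch.export.export(torch.nn.Linear(8, 8), (torch.randn(2, 8),)).module()
graph_writer.dump_graph(gm, transform_name="example", stage="post_export")
# -> writes /tmp/ad_graph_dumps/001_post_export_example.txt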