[None][infra] AutoDeploy: Dump graph IR after every transform (#11045)

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Bala Marimuthu 2026-02-09 10:43:44 -08:00 committed by GitHub
parent e719721a60
commit 4a743338c3
3 changed files with 197 additions and 0 deletions

View File

@ -1,5 +1,6 @@
"""Main export functionality with utilities for torch.export."""
import re
from collections import defaultdict
from contextlib import nullcontext
from functools import partial
@ -188,6 +189,49 @@ def _add_load_hook_for_aliased_params(gm: fx.GraphModule, model: nn.Module) -> N
    gm._register_load_state_dict_pre_hook(aliasing_load_pre_hook)


def _rename_nodes_with_module_hierarchy(gm: fx.GraphModule) -> None:
    """Rename call_function nodes to reflect their module hierarchy.

    Uses nn_module_stack metadata to build hierarchical names like:
    'layers_0_self_attn_linear' instead of 'linear_2'
    """
    graph = gm.graph

    for node in graph.nodes:
        if node.op != "call_function":
            continue

        meta = getattr(node, "meta", None)
        if not isinstance(meta, dict):
            continue

        nn_stack = meta.get("nn_module_stack")
        if not nn_stack or not isinstance(nn_stack, dict):
            continue

        # Get innermost module path from the stack.
        # nn_module_stack is an OrderedDict: {path: (qualified_name, module_class), ...}
        module_path = list(nn_stack.keys())[-1] if nn_stack else ""

        # Strip the "L__self__" prefix that torch.export adds (internal representation)
        module_path = re.sub(r"^L__self__[._]?", "", module_path)

        # Get op name from target
        target = node.target
        if hasattr(target, "__name__"):
            op_name = target.__name__
        elif hasattr(target, "_name"):
            op_name = target._name
        else:
            op_name = str(target).split(".")[-1]

        unique_name = graph._graph_namespace.create_name(op_name, node)

        # Build new name: module_path + op_name (dots -> underscores)
        if module_path:
            node.name = f"{module_path}_{unique_name}".replace(".", "_")

    gm.recompile()


def _clean_up_assertions_and_guards(gm: fx.GraphModule):
    """This transformation removes shape checks and assertions from the graph."""
    check_ops = {
@ -341,6 +385,9 @@ def torch_export_to_gm(
    # clean up checks --> generally the sanity checks are overly conservative and we can remove them
    _clean_up_assertions_and_guards(egm)

    # Rename nodes to reflect module hierarchy for better debuggability
    _rename_nodes_with_module_hierarchy(egm)

    # show exported graph
    ad_logger.debug("exported graph: " + str(egm))

View File

@ -28,6 +28,7 @@ from ..utils._graph import (
    run_shape_prop,
)
from ..utils.cuda_mem_tracker import get_mem_info
from ..utils.graph_writer import graph_writer
from ..utils.logger import ad_logger
from .graph_module_visualizer import to_dot
@ -487,6 +488,9 @@ class BaseTransform(ABC):
        self._set_autodeploy_meta(mod, autodeploy_meta)

        self._visualize_graph(mod, idx)

        # Dump graph after transform for debugging (controlled by AD_DUMP_GRAPHS_DIR env var)
        graph_writer.dump_graph(mod, t_name, self.config.stage.value)

        # return the graph module
        return mod
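A hedged usage sketch for the new hook: the AD_DUMP_GRAPHS_DIR name and the file-naming scheme come from the new module shown below, while the dump path is illustrative; GraphWriter reads the variable once in __init__, so it must be set before the graph_writer module is imported.

import os

os.environ["AD_DUMP_GRAPHS_DIR"] = "/tmp/ad_graph_dumps"  # illustrative path; set before importing AutoDeploy

# Run the AutoDeploy pipeline as usual; on rank 0 (or in single-process mode) each
# transform then writes /tmp/ad_graph_dumps/<NNN>_<stage>_<transform_name>.txt
# containing the SSA-style IR of every GraphModule found in the model.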

View File

@ -0,0 +1,146 @@
import os
import shutil
from pathlib import Path
from typing import TextIO

import torch.nn as nn
from torch.fx import GraphModule

from ....logger import Singleton
from .logger import ADLogger


def _get_dtype_or_type(val):
    """Get dtype if tensor-like, otherwise return type name for SymInt/SymFloat etc."""
    if hasattr(val, "dtype"):
        return val.dtype
    else:
        # For SymInt, SymFloat, etc. - return the type name
        return type(val).__name__


def _get_shape_str(val):
    """Get shape as 'dim0xdim1x...' string, or '?' if not available."""
    if hasattr(val, "shape"):
        # Handle symbolic dimensions (SymInt) by converting to str
        dims = [str(int(d)) if str(d).isdigit() else str(d) for d in val.shape]
        return "x".join(dims) if dims else "scalar"
    return "?"


def _get_shape_dtype_str(val):
    """Return 'shape : dtype' string for a value."""
    shape = _get_shape_str(val)
    dtype = _get_dtype_or_type(val)
    return f"{shape} : {dtype}"


def dump_ssa_with_meta(f: TextIO, mod: GraphModule) -> None:
    for node in mod.graph.nodes:
        # Write out IR in traditional SSA style
        if node.op == "placeholder":
            if "val" in node.meta:
                shape_dtype = _get_shape_dtype_str(node.meta["val"])
            else:
                shape_dtype = "? : unknown"
            f.write(f"%{node.name} : {shape_dtype}\n")
        elif node.op in ("call_function", "call_method", "call_module"):
            # Build inputs list in SSA format with shape:dtype info
            input_vars = []
            for arg in node.args:
                if hasattr(arg, "name"):
                    # Look up the arg node's metadata for shape/dtype
                    if hasattr(arg, "meta") and "val" in arg.meta:
                        arg_shape_dtype = _get_shape_dtype_str(arg.meta["val"])
                        input_vars.append(f"%{arg.name} : {arg_shape_dtype}")
                    else:
                        input_vars.append(f"%{arg.name} : ? : unknown")
                else:
                    input_vars.append(str(arg))

            # Handle output shape/dtype (including multi-output)
            if "val" in node.meta:
                out_val = node.meta["val"]
                if isinstance(out_val, (tuple, list)):
                    # Multi-output: (shape1, shape2) : (dtype1, dtype2)
                    shapes = []
                    dtypes = []
                    for v in out_val:
                        if v is not None:
                            shapes.append(_get_shape_str(v))
                            dtypes.append(str(_get_dtype_or_type(v)))
                        else:
                            shapes.append("?")
                            dtypes.append("None")
                    out_info = f"({', '.join(shapes)}) : ({', '.join(dtypes)})"
                else:
                    out_info = _get_shape_dtype_str(out_val)
            else:
                out_info = "? : N/A"

            # Standard SSA notation: %out = op(args) : shape : dtype
            f.write(f"%{node.name} = {node.target}({', '.join(input_vars)}) : {out_info}\n")
        elif node.op == "output":
            # Output assignment in SSA IR
            outputs = node.args[0] if isinstance(node.args[0], (tuple, list)) else [node.args[0]]
            output_vars = []
            for o in outputs:
                if hasattr(o, "name"):
                    output_vars.append(f"%{o.name}")
                else:
                    output_vars.append(str(o))
            f.write(f"output {', '.join(output_vars)}\n")


class GraphWriter(metaclass=Singleton):
    DUMP_GRAPHS_ENV = "AD_DUMP_GRAPHS_DIR"

    def __init__(self):
        self._dump_dir = os.environ.get(self.DUMP_GRAPHS_ENV)
        self._logger = ADLogger()
        self._transform_counter = 0
        self._dump_dir_initialized = False

    def dump_graph(self, mod: nn.Module, transform_name: str, stage: str) -> None:
        """Dump the FX graph (SSA-style) to a file after a transform."""
        if not self._dump_dir:
            return

        # Only dump from main process (rank 0) or single-process mode (rank is None)
        if self._logger.rank is not None and self._logger.rank != 0:
            return

        # Lazy directory initialization (only on rank 0 / main process)
        if not self._dump_dir_initialized:
            dump_dir_path = Path(self._dump_dir)
            if dump_dir_path.exists():
                shutil.rmtree(dump_dir_path)
            dump_dir_path.mkdir(parents=True, exist_ok=True)
            self._logger.info(f"Graph dumping enabled to: {self._dump_dir}")
            self._dump_dir_initialized = True

        # Collect all GraphModules (including from submodules)
        graph_modules = []
        for name, submod in mod.named_modules():
            if isinstance(submod, GraphModule):
                graph_modules.append((name if name else "(root)", submod))

        if not graph_modules:
            return  # No GraphModules found

        self._transform_counter += 1
        filename = f"{self._transform_counter:03d}_{stage}_{transform_name}.txt"
        filepath = Path(self._dump_dir) / filename

        with open(filepath, "w") as f:
            f.write(f"# Transform: {transform_name}\n")
            f.write(f"# Stage: {stage}\n")
            f.write(f"# GraphModules found: {len(graph_modules)}\n\n")
            for module_name, gm in graph_modules:
                f.write(f"\n{'=' * 80}\n")
                f.write(f"# GraphModule: {module_name}\n")
                f.write(f"{'=' * 80}\n\n")
                dump_ssa_with_meta(f, gm)


graph_writer = GraphWriter()
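A minimal sketch of driving dump_ssa_with_meta by hand (not part of the diff); the tiny model and the expected output lines are illustrative and assume the nodes carry the meta["val"] entries that torch.export produces:

import sys

import torch


class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)


ep = torch.export.export(Tiny(), (torch.randn(2, 3),))
dump_ssa_with_meta(sys.stdout, ep.module())
# Roughly (exact node/op names depend on the PyTorch version):
#   %x : 2x3 : torch.float32
#   %add = aten.add.Tensor(%x : 2x3 : torch.float32, 1) : 2x3 : torch.float32
#   %relu = aten.relu.default(%add : 2x3 : torch.float32) : 2x3 : torch.float32
#   output %relu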