From 4a743338c33574662c134b171e5eee42b1bc4ff7 Mon Sep 17 00:00:00 2001
From: Bala Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Date: Mon, 9 Feb 2026 10:43:44 -0800
Subject: [PATCH] [None][infra] AutoDeploy: Dump graph IR after every transform (#11045)

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
---
 .../_torch/auto_deploy/export/export.py       |  47 ++++++
 .../_torch/auto_deploy/transform/interface.py |   4 +
 .../_torch/auto_deploy/utils/graph_writer.py  | 146 ++++++++++++++++++
 3 files changed, 197 insertions(+)
 create mode 100644 tensorrt_llm/_torch/auto_deploy/utils/graph_writer.py

diff --git a/tensorrt_llm/_torch/auto_deploy/export/export.py b/tensorrt_llm/_torch/auto_deploy/export/export.py
index f21275b3ee..4265fea9b9 100644
--- a/tensorrt_llm/_torch/auto_deploy/export/export.py
+++ b/tensorrt_llm/_torch/auto_deploy/export/export.py
@@ -1,5 +1,6 @@
 """Main export functionality with utilities for torch.export."""
 
+import re
 from collections import defaultdict
 from contextlib import nullcontext
 from functools import partial
@@ -188,6 +189,49 @@ def _add_load_hook_for_aliased_params(gm: fx.GraphModule, model: nn.Module) -> N
     gm._register_load_state_dict_pre_hook(aliasing_load_pre_hook)
 
 
+def _rename_nodes_with_module_hierarchy(gm: fx.GraphModule) -> None:
+    """Rename call_function nodes to reflect their module hierarchy.
+
+    Uses nn_module_stack metadata to build hierarchical names such as
+    'layers_0_self_attn_linear' instead of 'linear_2'.
+    """
+    graph = gm.graph
+
+    for node in graph.nodes:
+        if node.op != "call_function":
+            continue
+
+        meta = getattr(node, "meta", None)
+        if not isinstance(meta, dict):
+            continue
+
+        nn_stack = meta.get("nn_module_stack")
+        if not nn_stack or not isinstance(nn_stack, dict):
+            continue
+
+        # Get the innermost module path from the stack.
+        # nn_module_stack is an OrderedDict: {path: (qualified_name, module_class), ...}
+        module_path = list(nn_stack.keys())[-1] if nn_stack else ""
+        # Strip the "L__self__" prefix that torch.export adds (internal representation)
+        module_path = re.sub(r"^L__self__[._]?", "", module_path)
+
+        # Get the op name from the target
+        target = node.target
+        if hasattr(target, "__name__"):
+            op_name = target.__name__
+        elif hasattr(target, "_name"):
+            op_name = target._name
+        else:
+            op_name = str(target).split(".")[-1]
+
+        unique_name = graph._graph_namespace.create_name(op_name, node)
+        # Build the new name: module_path + op_name (dots -> underscores)
+        if module_path:
+            node.name = f"{module_path}_{unique_name}".replace(".", "_")
+
+    gm.recompile()
+
+
 def _clean_up_assertions_and_guards(gm: fx.GraphModule):
     """This transformation removes shape checks and assertions from the graph."""
     check_ops = {
@@ -341,6 +385,9 @@ def torch_export_to_gm(
     # clean up checks --> generally the sanity checks are overly conservative and we can remove them
     _clean_up_assertions_and_guards(egm)
 
+    # Rename nodes to reflect module hierarchy for better debuggability
+    _rename_nodes_with_module_hierarchy(egm)
+
     # show exported graph
     ad_logger.debug("exported graph: " + str(egm))
 
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/interface.py b/tensorrt_llm/_torch/auto_deploy/transform/interface.py
index 668f15b7ea..5869e4b014 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/interface.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/interface.py
@@ -28,6 +28,7 @@ from ..utils._graph import (
     run_shape_prop,
 )
 from ..utils.cuda_mem_tracker import get_mem_info
+from ..utils.graph_writer import graph_writer
 from ..utils.logger import ad_logger
 from .graph_module_visualizer import to_dot
@@ -487,6 +488,9 @@ class BaseTransform(ABC):
         self._set_autodeploy_meta(mod, autodeploy_meta)
         self._visualize_graph(mod, idx)
 
+        # Dump graph after transform for debugging (controlled by AD_DUMP_GRAPHS_DIR env var)
+        graph_writer.dump_graph(mod, t_name, self.config.stage.value)
+
         # return the graph module
         return mod
 
diff --git a/tensorrt_llm/_torch/auto_deploy/utils/graph_writer.py b/tensorrt_llm/_torch/auto_deploy/utils/graph_writer.py
new file mode 100644
index 0000000000..e350f5bc21
--- /dev/null
+++ b/tensorrt_llm/_torch/auto_deploy/utils/graph_writer.py
@@ -0,0 +1,146 @@
+import os
+import shutil
+from pathlib import Path
+from typing import TextIO
+
+import torch.nn as nn
+from torch.fx import GraphModule
+
+from ....logger import Singleton
+from .logger import ADLogger
+
+
+def _get_dtype_or_type(val):
+    """Get dtype if tensor-like, otherwise return type name for SymInt/SymFloat etc."""
+    if hasattr(val, "dtype"):
+        return val.dtype
+    else:
+        # For SymInt, SymFloat, etc. - return the type name
+        return type(val).__name__
+
+
+def _get_shape_str(val):
+    """Get shape as 'dim0xdim1x...' string, or '?' if not available."""
+    if hasattr(val, "shape"):
+        # Handle symbolic dimensions (SymInt) by converting to str
+        dims = [str(int(d)) if str(d).isdigit() else str(d) for d in val.shape]
+        return "x".join(dims) if dims else "scalar"
+    return "?"
+
+
+def _get_shape_dtype_str(val):
+    """Return 'shape : dtype' string for a value."""
+    shape = _get_shape_str(val)
+    dtype = _get_dtype_or_type(val)
+    return f"{shape} : {dtype}"
+
+
+def dump_ssa_with_meta(f: TextIO, mod: GraphModule) -> None:
+    for node in mod.graph.nodes:
+        # Write out IR in traditional SSA style
+        if node.op == "placeholder":
+            if "val" in node.meta:
+                shape_dtype = _get_shape_dtype_str(node.meta["val"])
+            else:
+                shape_dtype = "? : unknown"
+            f.write(f"%{node.name} : {shape_dtype}\n")
+        elif node.op in ("call_function", "call_method", "call_module"):
+            # Build inputs list in SSA format with shape:dtype info
+            input_vars = []
+            for arg in node.args:
+                if hasattr(arg, "name"):
+                    # Look up the arg node's metadata for shape/dtype
+                    if hasattr(arg, "meta") and "val" in arg.meta:
+                        arg_shape_dtype = _get_shape_dtype_str(arg.meta["val"])
+                        input_vars.append(f"%{arg.name} : {arg_shape_dtype}")
+                    else:
+                        input_vars.append(f"%{arg.name} : ? : unknown")
+                else:
+                    input_vars.append(str(arg))
+
+            # Handle output shape/dtype (including multi-output)
+            if "val" in node.meta:
+                out_val = node.meta["val"]
+                if isinstance(out_val, (tuple, list)):
+                    # Multi-output: (shape1, shape2) : (dtype1, dtype2)
+                    shapes = []
+                    dtypes = []
+                    for v in out_val:
+                        if v is not None:
+                            shapes.append(_get_shape_str(v))
+                            dtypes.append(str(_get_dtype_or_type(v)))
+                        else:
+                            shapes.append("?")
+                            dtypes.append("None")
+                    out_info = f"({', '.join(shapes)}) : ({', '.join(dtypes)})"
+                else:
+                    out_info = _get_shape_dtype_str(out_val)
+            else:
+                out_info = "? : N/A"
+            # Standard SSA notation: %out = op(args) : shape : dtype
+            f.write(f"%{node.name} = {node.target}({', '.join(input_vars)}) : {out_info}\n")
+        elif node.op == "output":
+            # Output assignment in SSA IR
+            outputs = node.args[0] if isinstance(node.args[0], (tuple, list)) else [node.args[0]]
+            output_vars = []
+            for o in outputs:
+                if hasattr(o, "name"):
+                    output_vars.append(f"%{o.name}")
+                else:
+                    output_vars.append(str(o))
+            f.write(f"output {', '.join(output_vars)}\n")
+
+
+class GraphWriter(metaclass=Singleton):
+    DUMP_GRAPHS_ENV = "AD_DUMP_GRAPHS_DIR"
+
+    def __init__(self):
+        self._dump_dir = os.environ.get(self.DUMP_GRAPHS_ENV)
+        self._logger = ADLogger()
+        self._transform_counter = 0
+        self._dump_dir_initialized = False
+
+    def dump_graph(self, mod: nn.Module, transform_name: str, stage: str) -> None:
+        """Dump the FX graph (SSA-style) to a file after a transform."""
+        if not self._dump_dir:
+            return
+
+        # Only dump from main process (rank 0) or single-process mode (rank is None)
+        if self._logger.rank is not None and self._logger.rank != 0:
+            return
+
+        # Lazy directory initialization (only on rank 0 / main process)
+        if not self._dump_dir_initialized:
+            dump_dir_path = Path(self._dump_dir)
+            if dump_dir_path.exists():
+                shutil.rmtree(dump_dir_path)
+            dump_dir_path.mkdir(parents=True, exist_ok=True)
+            self._logger.info(f"Graph dumping enabled to: {self._dump_dir}")
+            self._dump_dir_initialized = True
+
+        # Collect all GraphModules (including from submodules)
+        graph_modules = []
+        for name, submod in mod.named_modules():
+            if isinstance(submod, GraphModule):
+                graph_modules.append((name if name else "(root)", submod))
+
+        if not graph_modules:
+            return  # No GraphModules found
+
+        self._transform_counter += 1
+        filename = f"{self._transform_counter:03d}_{stage}_{transform_name}.txt"
+        filepath = Path(self._dump_dir) / filename
+
+        with open(filepath, "w") as f:
+            f.write(f"# Transform: {transform_name}\n")
+            f.write(f"# Stage: {stage}\n")
+            f.write(f"# GraphModules found: {len(graph_modules)}\n\n")
+
+            for module_name, gm in graph_modules:
+                f.write(f"\n{'=' * 80}\n")
+                f.write(f"# GraphModule: {module_name}\n")
+                f.write(f"{'=' * 80}\n\n")
+                dump_ssa_with_meta(f, gm)
+
+
+graph_writer = GraphWriter()
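
A minimal usage sketch of the new dump hook, for reference. Everything below is illustrative rather than taken from the patch: the output directory, the toy model, and the "my_transform" / "post_export" labels are made up, and AD_DUMP_GRAPHS_DIR must be set before graph_writer is imported because the singleton reads the variable at construction time.

import os

os.environ["AD_DUMP_GRAPHS_DIR"] = "/tmp/ad_graph_dumps"  # hypothetical path

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

from tensorrt_llm._torch.auto_deploy.utils.graph_writer import graph_writer


class ToyModel(nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


# dump_graph walks named_modules(), collects every GraphModule, and writes one
# SSA-style file per call -- here 001_post_export_my_transform.txt. Shapes and
# dtypes print as "?" / "unknown" because symbolic_trace does not populate
# node.meta["val"]; graphs coming out of torch.export carry that metadata.
gm = symbolic_trace(ToyModel())
graph_writer.dump_graph(gm, "my_transform", "post_export")

In the pipeline itself no manual call is needed: BaseTransform invokes graph_writer.dump_graph after every transform, so exporting the environment variable before a run is enough to get one numbered dump file per transform.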