[#8460][feat] Revive and simplify Model Explorer visualization integration (#10150)

Signed-off-by: Karthik Vetrivel <kvetrivel@nvidia.com>
This commit is contained in:
Karthik 2026-01-05 22:15:25 -05:00 committed by GitHub
parent aa1fe931de
commit 617f728903
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 82 deletions

View File

@ -149,7 +149,7 @@ transforms:
############################################################################################
visualize_namespace:
stage: visualize
enabled: false # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/8460
enabled: false
############################################################################################
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
############################################################################################

View File

@ -1,92 +1,30 @@
"""Transformation to the graph to render nicely in model_explorer."""
import json
from typing import Tuple
import torch
import torch.export as te
from torch.fx import GraphModule
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
try:
import model_explorer
from model_explorer.graph_builder import GraphNode, KeyValue, MetadataItem
from model_explorer.pytorch_exported_program_adater_impl import (
PytorchExportedProgramAdapterImpl,
)
except ImportError:
model_explorer = None
GraphNode = KeyValue = MetadataItem = PytorchExportedProgramAdapterImpl = None
# Optionally, you can log a warning or handle this gracefully elsewhere
def print_tensor(self, tensor: torch.Tensor, size_limit: int = 16):
shape = tensor.shape
total_size = 1
for dim in shape:
total_size *= dim
if size_limit < 0 or size_limit >= total_size:
return json.dumps(tensor.cpu().detach().to(torch.float32).numpy().tolist())
return json.dumps(
(tensor.cpu().detach().to(torch.float32).numpy().flatten())[:size_limit].tolist()
)
def _get_shape(val):
return json.dumps(
list(
map(
lambda x: int(x) if str(x).isdigit() else str(x),
val.shape,
)
)
)
def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):
out_vals = fx_node.meta.get("val")
if out_vals is None:
return
if isinstance(out_vals, (tuple, list)):
for idx, val in enumerate(out_vals):
metadata = MetadataItem(id=str(idx), attrs=[])
if val is None:
continue
dtype = str(val.dtype)
shape = _get_shape(val)
metadata.attrs.append(KeyValue(key="tensor_shape", value=dtype + shape))
node.outputsMetadata.append(metadata)
elif isinstance(out_vals, torch.Tensor):
dtype = str(out_vals.dtype)
shape = _get_shape(out_vals)
metadata = MetadataItem(id="0", attrs=[KeyValue(key="tensor_shape", value=dtype + shape)])
node.outputsMetadata.append(metadata)
elif isinstance(out_vals, bool):
metadata = MetadataItem(id="0", attrs=[KeyValue(key="tensor_shape", value="bool[1]")])
node.outputsMetadata.append(metadata)
else:
raise ValueError(f"Unsupported output type: {type(out_vals)}")
# TODO(yudong): make custom_ops configurable
CUSTOM_OPS = (
torch.ops.auto_deploy.torch_dist_all_reduce.default,
torch.ops.auto_deploy.trtllm_dist_all_reduce.default,
torch.ops.aten.slice.Tensor,
torch.ops.auto_deploy.triton_attention_fused_mha_with_cache.default,
torch.ops.auto_deploy.torch_linear_simple.default,
torch.ops.aten.split_with_sizes.default,
)
@TransformRegistry.register("visualize_namespace")
class VisualizeNamespace(BaseTransform):
"""Transform to visualize the graph using Model Explorer.
This transform exports the graph module to an ExportedProgram and launches
Model Explorer for interactive visualization. The visualization helps debug
and understand the graph structure after AutoDeploy transformations.
"""
def _apply(
self,
gm: GraphModule,
@ -94,17 +32,37 @@ class VisualizeNamespace(BaseTransform):
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
PytorchExportedProgramAdapterImpl.print_tensor = print_tensor
PytorchExportedProgramAdapterImpl.add_outputs_metadata = add_outputs_metadata
"""Export the graph and launch Model Explorer for visualization.
# TODO(yudong): make viz as non-block call.
ep = te.export(gm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
graph = ep.graph
# Ensure the ops land up in the right module for better viz
for n in graph.nodes:
if n.target in CUSTOM_OPS:
n.meta["nn_module_stack"] = n.args[0].meta["nn_module_stack"]
Args:
gm: The graph module to visualize.
cm: The cached sequence interface with input arguments.
factory: The model factory (unused).
shared_config: Shared configuration across transforms (unused).
model_explorer.visualize_pytorch("model-viz", ep)
Returns:
A tuple of the unchanged graph module and transform info indicating
whether visualization was successful or skipped.
"""
if model_explorer is None:
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
return gm, TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
try:
# Export graph module to ExportedProgram for visualization
exported_program = te.export(gm, args=(), kwargs=cm.named_args, dynamic_shapes=None)
ad_logger.info("Launching Model Explorer visualization...")
model_explorer.visualize_pytorch("model-viz", exported_program)
return gm, TransformInfo(
skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True
)
except Exception as e:
ad_logger.error(f"Failed to visualize graph with Model Explorer: {e}")
# Don't fail the pipeline if visualization fails
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)