mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Signed-off-by: Karthik Vetrivel <kvetrivel@nvidia.com>
This commit is contained in:
parent
aa1fe931de
commit
617f728903
@ -149,7 +149,7 @@ transforms:
|
||||
############################################################################################
|
||||
visualize_namespace:
|
||||
stage: visualize
|
||||
enabled: false # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/8460
|
||||
enabled: false
|
||||
############################################################################################
|
||||
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
|
||||
############################################################################################
|
||||
|
||||
@ -1,92 +1,30 @@
|
||||
"""Transformation to the graph to render nicely in model_explorer."""
|
||||
|
||||
import json
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.export as te
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils.logger import ad_logger
|
||||
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
|
||||
|
||||
try:
|
||||
import model_explorer
|
||||
from model_explorer.graph_builder import GraphNode, KeyValue, MetadataItem
|
||||
from model_explorer.pytorch_exported_program_adater_impl import (
|
||||
PytorchExportedProgramAdapterImpl,
|
||||
)
|
||||
except ImportError:
|
||||
model_explorer = None
|
||||
GraphNode = KeyValue = MetadataItem = PytorchExportedProgramAdapterImpl = None
|
||||
# Optionally, you can log a warning or handle this gracefully elsewhere
|
||||
|
||||
|
||||
def print_tensor(self, tensor: torch.Tensor, size_limit: int = 16):
|
||||
shape = tensor.shape
|
||||
total_size = 1
|
||||
for dim in shape:
|
||||
total_size *= dim
|
||||
|
||||
if size_limit < 0 or size_limit >= total_size:
|
||||
return json.dumps(tensor.cpu().detach().to(torch.float32).numpy().tolist())
|
||||
|
||||
return json.dumps(
|
||||
(tensor.cpu().detach().to(torch.float32).numpy().flatten())[:size_limit].tolist()
|
||||
)
|
||||
|
||||
|
||||
def _get_shape(val):
|
||||
return json.dumps(
|
||||
list(
|
||||
map(
|
||||
lambda x: int(x) if str(x).isdigit() else str(x),
|
||||
val.shape,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):
|
||||
out_vals = fx_node.meta.get("val")
|
||||
if out_vals is None:
|
||||
return
|
||||
|
||||
if isinstance(out_vals, (tuple, list)):
|
||||
for idx, val in enumerate(out_vals):
|
||||
metadata = MetadataItem(id=str(idx), attrs=[])
|
||||
if val is None:
|
||||
continue
|
||||
dtype = str(val.dtype)
|
||||
shape = _get_shape(val)
|
||||
metadata.attrs.append(KeyValue(key="tensor_shape", value=dtype + shape))
|
||||
node.outputsMetadata.append(metadata)
|
||||
elif isinstance(out_vals, torch.Tensor):
|
||||
dtype = str(out_vals.dtype)
|
||||
shape = _get_shape(out_vals)
|
||||
metadata = MetadataItem(id="0", attrs=[KeyValue(key="tensor_shape", value=dtype + shape)])
|
||||
node.outputsMetadata.append(metadata)
|
||||
elif isinstance(out_vals, bool):
|
||||
metadata = MetadataItem(id="0", attrs=[KeyValue(key="tensor_shape", value="bool[1]")])
|
||||
node.outputsMetadata.append(metadata)
|
||||
else:
|
||||
raise ValueError(f"Unsupported output type: {type(out_vals)}")
|
||||
|
||||
|
||||
# TODO(yudong): make custom_ops configurable
|
||||
CUSTOM_OPS = (
|
||||
torch.ops.auto_deploy.torch_dist_all_reduce.default,
|
||||
torch.ops.auto_deploy.trtllm_dist_all_reduce.default,
|
||||
torch.ops.aten.slice.Tensor,
|
||||
torch.ops.auto_deploy.triton_attention_fused_mha_with_cache.default,
|
||||
torch.ops.auto_deploy.torch_linear_simple.default,
|
||||
torch.ops.aten.split_with_sizes.default,
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("visualize_namespace")
|
||||
class VisualizeNamespace(BaseTransform):
|
||||
"""Transform to visualize the graph using Model Explorer.
|
||||
|
||||
This transform exports the graph module to an ExportedProgram and launches
|
||||
Model Explorer for interactive visualization. The visualization helps debug
|
||||
and understand the graph structure after AutoDeploy transformations.
|
||||
"""
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
@ -94,17 +32,37 @@ class VisualizeNamespace(BaseTransform):
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
PytorchExportedProgramAdapterImpl.print_tensor = print_tensor
|
||||
PytorchExportedProgramAdapterImpl.add_outputs_metadata = add_outputs_metadata
|
||||
"""Export the graph and launch Model Explorer for visualization.
|
||||
|
||||
# TODO(yudong): make viz as non-block call.
|
||||
ep = te.export(gm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
|
||||
graph = ep.graph
|
||||
# Ensure the ops land up in the right module for better viz
|
||||
for n in graph.nodes:
|
||||
if n.target in CUSTOM_OPS:
|
||||
n.meta["nn_module_stack"] = n.args[0].meta["nn_module_stack"]
|
||||
Args:
|
||||
gm: The graph module to visualize.
|
||||
cm: The cached sequence interface with input arguments.
|
||||
factory: The model factory (unused).
|
||||
shared_config: Shared configuration across transforms (unused).
|
||||
|
||||
model_explorer.visualize_pytorch("model-viz", ep)
|
||||
Returns:
|
||||
A tuple of the unchanged graph module and transform info indicating
|
||||
whether visualization was successful or skipped.
|
||||
"""
|
||||
if model_explorer is None:
|
||||
return gm, TransformInfo(
|
||||
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
|
||||
return gm, TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
|
||||
try:
|
||||
# Export graph module to ExportedProgram for visualization
|
||||
exported_program = te.export(gm, args=(), kwargs=cm.named_args, dynamic_shapes=None)
|
||||
|
||||
ad_logger.info("Launching Model Explorer visualization...")
|
||||
model_explorer.visualize_pytorch("model-viz", exported_program)
|
||||
|
||||
return gm, TransformInfo(
|
||||
skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
ad_logger.error(f"Failed to visualize graph with Model Explorer: {e}")
|
||||
# Don't fail the pipeline if visualization fails
|
||||
return gm, TransformInfo(
|
||||
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user