From 9879400479ae1a16906bfd6008a9a7966611bfdc Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Sun, 18 Jan 2026 13:42:30 -0500 Subject: [PATCH] [#10642][feat] AutoDeploy: optimized canonicalize_graph utilities [1/2] (#10675) Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- .../library/eliminate_redundant_transposes.py | 4 - .../transform/library/fuse_causal_conv.py | 6 +- .../transform/library/fuse_mamba_a_log.py | 2 +- .../transform/library/fused_moe.py | 21 +- .../auto_deploy/transform/library/fusion.py | 9 +- .../transform/library/multi_stream_moe.py | 8 +- .../auto_deploy/transform/library/sharding.py | 4 +- .../_torch/auto_deploy/utils/_graph.py | 94 +++- .../multigpu/test_ad_allreduce_strategies.py | 3 +- .../library/test_ep_sharding.py | 5 +- .../singlegpu/custom_ops/test_multi_stream.py | 5 +- .../utils/test_delete_unused_submodules.py | 478 ++++++++++++++++++ 12 files changed, 592 insertions(+), 47 deletions(-) create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_delete_unused_submodules.py diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py b/tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py index 4f260b2c48..aff4b4f52c 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py @@ -110,10 +110,6 @@ class EliminateRedundantTransposes(BaseTransform): original_input.replace_all_uses_with(new_contiguous_node) new_contiguous_node.replace_input_with(new_contiguous_node, original_input) - # Clean up the graph - if nodes_to_eliminate: - gm.graph.eliminate_dead_code() - info = TransformInfo( skipped=False, num_matches=len(nodes_to_eliminate), diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py index 30f1959bbe..bce8476e6e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py @@ -127,12 +127,10 @@ class FuseCausalConvActivation(BaseTransform): graph.erase_node(activation_node) graph.erase_node(conv_node) - gm.recompile() - info = TransformInfo( skipped=False, num_matches=len(matches), - is_clean=False, - has_valid_shapes=False, + is_clean=len(matches) == 0, + has_valid_shapes=len(matches) == 0, ) return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py index 977aeeabc5..ea0dc3275e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py @@ -213,5 +213,5 @@ class FuseMambaALog(BaseTransform): skipped=False, num_matches=num_matches, is_clean=num_matches == 0, - has_valid_shapes=True, + has_valid_shapes=num_matches == 0, ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index 819f80ed64..c7d949f20e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -10,6 +10,7 @@ from tensorrt_llm._torch.utils import ActivationType from ...models.factory import ModelFactory from ...shim.interface 
import CachedSequenceInterface +from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code from ...utils.cuda_mem_tracker import cuda_memory_tracker from ...utils.node_utils import bfs, extract_op_args, identify_regions_between_residuals, is_op from ..interface import ( @@ -112,8 +113,8 @@ def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "t # Delete the unstacked weights immediately to save GPU memory # This will happen automatically after the graph is canonicalized, # but for large models we'll run out of memory during the transformation itself. - gm.graph.eliminate_dead_code() - gm.delete_all_unused_submodules() + eliminate_dead_code(gm) + delete_all_unused_submodules(gm) return fused_key_counter @@ -635,7 +636,7 @@ class MatchMoePattern(BaseTransform): graph.erase_node(final_hidden_state_node) while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary): - gm.graph.eliminate_dead_code() + eliminate_dead_code(gm) num_moe_patterns += 1 @@ -1272,14 +1273,14 @@ class MatchBmmMoePattern(BaseTransform): graph.erase_node(output_node) # Clean up dead nodes - gm.graph.eliminate_dead_code() + eliminate_dead_code(gm) # Clean up dead inplace nodes in the region while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary): - gm.graph.eliminate_dead_code() + eliminate_dead_code(gm) # Delete unused submodules/parameters - gm.delete_all_unused_submodules() + delete_all_unused_submodules(gm) num_moe_patterns += 1 @@ -1517,8 +1518,8 @@ def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", " # Clean up after processing all nodes # eliminate_dead_code will remove unused get_attr nodes, then delete_all_unused_submodules # will remove the parameters/buffers that are no longer referenced - gm.graph.eliminate_dead_code() - gm.delete_all_unused_submodules() + eliminate_dead_code(gm) + delete_all_unused_submodules(gm) return fused_key_counter @@ -1776,8 +1777,8 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: # Clean up after processing all nodes # eliminate_dead_code will remove unused get_attr nodes, then delete_all_unused_submodules # will remove the parameters/buffers that are no longer referenced - gm.graph.eliminate_dead_code() - gm.delete_all_unused_submodules() + eliminate_dead_code(gm) + delete_all_unused_submodules(gm) return fused_key_counter diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py index ac129f2d9f..c31bbe7887 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py @@ -11,6 +11,7 @@ from torch.fx import GraphModule, Node from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface +from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code from ...utils.cuda_mem_tracker import cuda_memory_tracker from ...utils.logger import ad_logger from ...utils.node_utils import extract_weight_name, is_linear_op, is_op @@ -75,8 +76,8 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node n.replace_all_uses_with(get_split_node) # Clean up deleted modules to save GPU memory - gm.graph.eliminate_dead_code() - gm.delete_all_unused_submodules() + eliminate_dead_code(gm) + delete_all_unused_submodules(gm) def check_same_children(parent_node: Node, is_desired_child: Callable[[Node], bool]) -> bool: @@ -185,8 +186,8 @@ class 
QuantizationFusionMixin(ABC): n.replace_all_uses_with(get_split_node) # Clean up deleted modules to save GPU memory - gm.graph.eliminate_dead_code() - gm.delete_all_unused_submodules() + eliminate_dead_code(gm) + delete_all_unused_submodules(gm) def _apply_fusion_pass( self, diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py index cd82b04646..5cd02553d9 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py @@ -291,10 +291,6 @@ def _execute_op_in_aux_stream( n.replace_all_uses_with(new_node) graph.erase_node(n) num_replaced += 1 - if num_replaced: - graph.eliminate_dead_code() - graph.lint() - gm.recompile() return gm, num_replaced @@ -322,8 +318,8 @@ class MultiStreamMOE(BaseTransform): info = TransformInfo( skipped=False, num_matches=num_matches, - is_clean=False, - has_valid_shapes=False, + is_clean=num_matches == 0, + has_valid_shapes=num_matches == 0, ) return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 26257cdfee..3fa5c877a2 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -33,7 +33,7 @@ from .....functional import AllReduceStrategy from ...custom_ops.trtllm_dist import is_trtllm_op_available from ...models.factory import ModelFactory, ShardingConfigSource from ...shim.interface import CachedSequenceInterface -from ...utils._graph import del_attr_by_name +from ...utils._graph import del_attr_by_name, eliminate_dead_code from ...utils.logger import ad_logger from ...utils.node_utils import ( LayerSubgraph, @@ -1444,7 +1444,7 @@ def _insert_sharded_moe( node.replace_all_uses_with(dist_node) dist_node.replace_input_with(dist_node, node) - gm.graph.eliminate_dead_code() + eliminate_dead_code(gm) # Expert weights registered via gm.register_parameter() are top-level attributes. # Unlike submodules, these aren't cleaned up by eliminate_dead_code() or # delete_all_unused_submodules() - must delete manually after removing their get_attr nodes. 
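Note on the comment that closes the hunk above: expert weights registered with gm.register_parameter() live on the root module, so neither dead-code elimination nor submodule deletion ever touches them, which is why sharding still calls del_attr_by_name explicitly. A minimal standalone sketch of that behavior follows (illustrative only, not part of the patch; Tiny and w_expert_0 are made-up names, and plain delattr stands in for what del_attr_by_name roughly reduces to for a top-level attribute):

import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)


gm = symbolic_trace(Tiny())

# Register an "expert" weight directly on the root module, as the sharding pass does.
gm.register_parameter("w_expert_0", nn.Parameter(torch.zeros(4, 4)))

# Dead-code elimination removes unused nodes and delete_all_unused_submodules removes
# unused *modules*; a root-level parameter with no remaining get_attr users is neither,
# so it survives both cleanup passes.
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
assert hasattr(gm, "w_expert_0")

# Only an explicit attribute deletion frees it.
delattr(gm, "w_expert_0")
assert not hasattr(gm, "w_expert_0")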
diff --git a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py index e1a922b59f..270e1ea4fe 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py @@ -1,7 +1,8 @@ """Graph-related utilities for transformations.""" +import itertools from contextlib import contextmanager -from typing import Any, Dict, Iterator, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterator, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -129,8 +130,8 @@ def _move_single_gm_to_device(gm: GraphModule, device: torch.device) -> None: ) if recompile_graph: # recompile graph to update self generated codes in subgraph - gm.graph.lint() - gm.recompile() + lint(gm) + recompile(gm) def move_to_device(mod: nn.Module, device: DeviceLikeType) -> None: @@ -161,18 +162,91 @@ def _is_impure_node(node: Node) -> bool: node.target._nondeterministic_seeded = True -def _canonicalize_single_gm(gm: GraphModule) -> None: - # clean up graph (needs to be done repeatedly until no more dead code) - gm.graph.eliminate_dead_code(is_impure_node=_is_impure_node) +def delete_all_unused_submodules(gm: GraphModule) -> None: + """Optimized version of delete_all_unused_submodules with O(n+m) complexity. - # recompile to propagate all graph changes to the graph module + The original implementation uses a list for tracking used modules, making membership + checks O(n). This version uses a set for O(1) lookups. + + Original implementation is at GraphModule.delete_all_unused_submodules + + Args: + gm: The GraphModule to clean up. + """ + used: Set[str] = set() + + for node in itertools.chain( + gm.graph.find_nodes(op="call_module", sort=False), + gm.graph.find_nodes(op="get_attr", sort=False), + ): + # check if it's already used and it's not a call_module node + # in this case we can skip. We cannot skip if it's a call_module node because we need to + # mark all recursive submodules as used. + if node.target in used and node.op != "call_module": + continue + + # A list of strings representing the different parts + # of the path. For example, `foo.bar.baz` gives us + # ["foo", "bar", "baz"] + fullpath = node.target.split(".") + + # Progressively collect all the names of intermediate + # modules. For example, if we have the target + # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and + # `foo.bar.baz` to the list. + used.update(".".join(fullpath[:i]) for i in range(1, len(fullpath) + 1)) + + # For call_module, also mark all recursive submodules as used + if node.op == "call_module": + try: + submod = gm.get_submodule(node.target) + for submod_name, _ in submod.named_modules(): + if submod_name != "": + used.add(f"{node.target}.{submod_name}") + except AttributeError: + # Node referenced nonexistent submodule, don't need to + # worry about GCing anything + pass + + # also add the root module to the used set + used.add("") + + # Go over all modules and delete if on the list. Since we use named_modules, parents will be + # deleted first and children will be automatically skipped inside delete_submodule. 
+ to_delete = [name for name, _ in gm.named_modules() if name not in used] + for name in to_delete: + gm.delete_submodule(name) + + +def eliminate_dead_code( + gm: GraphModule, is_impure_node: Optional[Callable[[Node], bool]] = None +) -> None: + """Eliminate dead code from the graph of the given GraphModule.""" + gm.graph.eliminate_dead_code(is_impure_node=is_impure_node) + + +def recompile(gm: GraphModule) -> None: + """Recompile the graph of the given GraphModule.""" gm.recompile() + +def lint(gm: GraphModule) -> None: + """Lint the graph of the given GraphModule.""" + gm.graph.lint() + + +def _canonicalize_single_gm(gm: GraphModule) -> None: + # clean up graph (needs to be done repeatedly until no more dead code) + eliminate_dead_code(gm, is_impure_node=_is_impure_node) + + # recompile to propagate all graph changes to the graph module + recompile(gm) + # clean up graph module - gm.delete_all_unused_submodules() + delete_all_unused_submodules(gm) # lint the graph - gm.graph.lint() + lint(gm) def canonicalize_graph(mod: nn.Module) -> None: @@ -217,7 +291,7 @@ def _run_shape_prop_single_gm( ad_logger.warning("No fake tensors and no args available for shape propagation") # lint the graph - gm.graph.lint() + lint(gm) def run_shape_prop( diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_allreduce_strategies.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_allreduce_strategies.py index 9d4e444e4d..4b6c9985e0 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_allreduce_strategies.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_allreduce_strategies.py @@ -19,6 +19,7 @@ from tensorrt_llm._torch.auto_deploy.transform.library.sharding import ( SplitDimension, WeightShardingInfo, ) +from tensorrt_llm._torch.auto_deploy.utils._graph import recompile from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op from tensorrt_llm.commands.bench import main from tensorrt_llm.functional import AllReduceStrategy @@ -378,7 +379,7 @@ def test_allreduce_strategy_propagation(strategy): if node: transform.check_and_apply(gm, node) - gm.recompile() + recompile(gm) # Verify the graph contains torch_dist_all_reduce nodes with correct strategy allreduce_nodes = [ diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index a1e642db31..5ee6d0f883 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -18,6 +18,7 @@ from tensorrt_llm._torch.auto_deploy.transform.library.sharding import ( ShardingTransformConfig, ) from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer +from tensorrt_llm._torch.auto_deploy.utils._graph import lint, recompile from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op from tensorrt_llm.functional import AllReduceStrategy @@ -188,8 +189,8 @@ def test_llama4_stacked_moe_pattern_detection(): ) graph.output(moe_node) - graph.lint() - gm.recompile() + lint(gm) + recompile(gm) # Run pattern detection for EP optimizer = InferenceOptimizer( diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py index 779924340b..182c6a0a31 100644 --- 
a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py @@ -9,6 +9,7 @@ from tensorrt_llm._torch.auto_deploy.transform.library.multi_stream_moe import ( cuda_stream_manager, record_event_wrapper, ) +from tensorrt_llm._torch.auto_deploy.utils._graph import canonicalize_graph from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op @@ -75,9 +76,7 @@ def replace_multi_stream_linear_with_aux_stream_wrapper(gm: GraphModule) -> Tupl num_replaced += 1 if num_replaced: - graph.eliminate_dead_code() - graph.lint() - gm.recompile() + canonicalize_graph(gm) return gm, num_replaced diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_delete_unused_submodules.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_delete_unused_submodules.py new file mode 100644 index 0000000000..5539950e41 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_delete_unused_submodules.py @@ -0,0 +1,478 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for delete_all_unused_submodules. + +This module tests the optimized implementation against PyTorch's original +GraphModule.delete_all_unused_submodules to ensure functional equivalence. + +Since torch_export_to_gm already eliminates unused submodules during export, +we must modify the graph after export to create unused submodules for testing. 
+""" + +import copy +from typing import List, Set + +import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule + +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.utils._graph import ( + canonicalize_graph, + delete_all_unused_submodules, +) + +# ============================================================================= +# Test Models - All submodules are used in forward pass +# ============================================================================= + + +class FullyUsedModel(nn.Module): + """Model where all submodules are used in forward - none will be pruned during export.""" + + def __init__(self, hidden_dim: int = 32): + super().__init__() + self.linear1 = nn.Linear(hidden_dim, hidden_dim) + self.linear2 = nn.Linear(hidden_dim, hidden_dim) + self.linear3 = nn.Linear(hidden_dim, hidden_dim) + self.nested = nn.ModuleDict( + { + "layer_a": nn.Linear(hidden_dim, hidden_dim), + "layer_b": nn.Linear(hidden_dim, hidden_dim), + } + ) + + def forward(self, x): + # Use ALL submodules so they appear in the exported graph + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.nested["layer_a"](x) + x = self.nested["layer_b"](x) + return x + + +class DeeplyNestedModel(nn.Module): + """Model with deeply nested submodule hierarchy for testing depth-first deletion.""" + + def __init__(self, hidden_dim: int = 32): + super().__init__() + self.level1 = nn.ModuleDict( + { + "level2": nn.ModuleDict( + { + "level3": nn.ModuleDict( + { + "leaf_a": nn.Linear(hidden_dim, hidden_dim), + "leaf_b": nn.Linear(hidden_dim, hidden_dim), + } + ) + } + ) + } + ) + self.other = nn.Linear(hidden_dim, hidden_dim) + + def forward(self, x): + # Use all nested modules + x = self.level1["level2"]["level3"]["leaf_a"](x) + x = self.level1["level2"]["level3"]["leaf_b"](x) + x = self.other(x) + return x + + +class ModelWithSequential(nn.Module): + """Model with Sequential submodules for testing call_module behavior.""" + + def __init__(self, hidden_dim: int = 32): + super().__init__() + self.pre = nn.Linear(hidden_dim, hidden_dim) + self.sequential = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + ) + self.post = nn.Linear(hidden_dim, hidden_dim) + + def forward(self, x): + x = self.pre(x) + x = self.sequential(x) + x = self.post(x) + return x + + +# ============================================================================= +# Helper Functions +# ============================================================================= + + +def get_submodule_names(gm: nn.Module) -> Set[str]: + """Get all submodule names (excluding root empty string).""" + return {name for name, _ in gm.named_modules() if name} + + +def _target_matches_any(target: str, targets_to_remove: List[str]) -> bool: + """Check if a target matches any of the targets to remove. + + Handles both exact matches and prefix matches (for nested attributes). + Target names in the graph use underscores, so we check both formats. 
+ """ + # Normalize target by replacing underscores with dots for comparison + # e.g., "linear2_weight" -> check against "linear2" + for t in targets_to_remove: + # Exact match + if target == t: + return True + # Prefix match (e.g., "linear2.weight" starts with "linear2.") + if target.startswith(t + "."): + return True + # Handle underscore-separated names from export + # e.g., target="linear2_weight" should match t="linear2" + t_underscore = t.replace(".", "_") + if target == t_underscore or target.startswith(t_underscore + "_"): + return True + return False + + +def make_submodules_unused(gm: GraphModule, targets_to_remove: List[str]) -> None: + """Remove nodes referencing specific submodules to make them unused. + + This function finds all operations that use parameters from the specified + submodules and removes them from the graph, rewiring the data flow to bypass + those operations. After removal, eliminate_dead_code() cleans up unused + get_attr nodes. + + Args: + gm: The GraphModule to modify. + targets_to_remove: List of submodule target names to make unused. + """ + graph = gm.graph + + # First pass: find call_function nodes that use parameters from target modules + # These are the actual operations (e.g., linear, matmul) that we need to remove + nodes_to_bypass = [] + + for node in graph.nodes: + if node.op == "call_function": + # Check if any of the node's inputs come from a target module's get_attr + uses_target_module = False + for arg in node.args: + if isinstance(arg, torch.fx.Node) and arg.op == "get_attr": + if _target_matches_any(arg.target, targets_to_remove): + uses_target_module = True + break + if uses_target_module: + nodes_to_bypass.append(node) + + elif node.op == "call_module": + # Direct call_module nodes (if any) + if _target_matches_any(node.target, targets_to_remove): + nodes_to_bypass.append(node) + + # Bypass each node by replacing its uses with its first tensor input + for node in nodes_to_bypass: + # Find the first tensor input (usually the activation, not weights) + replacement = None + for arg in node.args: + if isinstance(arg, torch.fx.Node): + # Skip get_attr nodes (weights/biases) - we want the activation input + if arg.op != "get_attr": + replacement = arg + break + if replacement is not None: + node.replace_all_uses_with(replacement) + + # Remove the bypassed nodes (in reverse topological order) + for node in reversed(nodes_to_bypass): + graph.erase_node(node) + + canonicalize_graph(gm) + + +# ============================================================================= +# Test Class +# ============================================================================= + + +class TestDeleteAllUnusedSubmodulesOptimized: + """Tests for delete_all_unused_submodules function.""" + + def test_functional_equivalence_basic(self): + """Test that optimized version produces identical results to original.""" + model = FullyUsedModel().to("cuda") + x = torch.randn(2, 32, device="cuda") + + gm = torch_export_to_gm(model, args=(x,)) + + # Verify submodules are present after export + submodules_after_export = get_submodule_names(gm) + assert len(submodules_after_export) > 0, "Expected submodules after export" + + # Make some modules unused by modifying the graph + gm_original = copy.deepcopy(gm) + gm_optimized = copy.deepcopy(gm) + + make_submodules_unused(gm_original, ["linear2"]) + make_submodules_unused(gm_optimized, ["linear2"]) + + # Apply both deletion implementations + gm_original.delete_all_unused_submodules() + delete_all_unused_submodules(gm_optimized) + + # 
Verify identical results + original_submodules = get_submodule_names(gm_original) + optimized_submodules = get_submodule_names(gm_optimized) + + assert original_submodules == optimized_submodules, ( + f"Mismatch in submodules:\n" + f"Original: {original_submodules}\n" + f"Optimized: {optimized_submodules}" + ) + + def test_functional_equivalence_multiple_removals(self): + """Test equivalence when multiple submodules are made unused.""" + model = FullyUsedModel().to("cuda") + x = torch.randn(2, 32, device="cuda") + + gm = torch_export_to_gm(model, args=(x,)) + + gm_original = copy.deepcopy(gm) + gm_optimized = copy.deepcopy(gm) + + # Remove multiple submodules + targets = ["linear2", "linear3"] + make_submodules_unused(gm_original, targets) + make_submodules_unused(gm_optimized, targets) + + gm_original.delete_all_unused_submodules() + delete_all_unused_submodules(gm_optimized) + + assert get_submodule_names(gm_original) == get_submodule_names(gm_optimized) + + def test_no_unused_modules(self): + """Test that nothing is deleted when all modules are used.""" + model = FullyUsedModel().to("cuda") + x = torch.randn(2, 32, device="cuda") + + gm = torch_export_to_gm(model, args=(x,)) + + submodules_before = get_submodule_names(gm) + + gm_original = copy.deepcopy(gm) + gm_optimized = copy.deepcopy(gm) + + # Don't modify the graph - all modules should remain used + gm_original.delete_all_unused_submodules() + delete_all_unused_submodules(gm_optimized) + + # Verify nothing was deleted + assert get_submodule_names(gm_original) == submodules_before + assert get_submodule_names(gm_optimized) == submodules_before + + def test_all_modules_made_unused(self): + """Test deletion when all call_module/get_attr nodes are removed.""" + model = FullyUsedModel().to("cuda") + x = torch.randn(2, 32, device="cuda") + + gm = torch_export_to_gm(model, args=(x,)) + + gm_original = copy.deepcopy(gm) + gm_optimized = copy.deepcopy(gm) + + # Find all submodule targets and remove them + all_targets = list(get_submodule_names(gm)) + + # Remove references to all submodules + make_submodules_unused(gm_original, all_targets) + make_submodules_unused(gm_optimized, all_targets) + + gm_original.delete_all_unused_submodules() + delete_all_unused_submodules(gm_optimized) + + # Both should have the same (empty or minimal) set of submodules + assert get_submodule_names(gm_original) == get_submodule_names(gm_optimized) + + def test_nested_module_partial_removal(self): + """Test that parent module stays when only one child is removed.""" + model = FullyUsedModel().to("cuda") + x = torch.randn(2, 32, device="cuda") + + gm = torch_export_to_gm(model, args=(x,)) + + gm_original = copy.deepcopy(gm) + gm_optimized = copy.deepcopy(gm) + + # Remove only one nested layer, keep the other + make_submodules_unused(gm_original, ["nested.layer_a"]) + make_submodules_unused(gm_optimized, ["nested.layer_a"]) + + gm_original.delete_all_unused_submodules() + delete_all_unused_submodules(gm_optimized) + + original_submodules = get_submodule_names(gm_original) + optimized_submodules = get_submodule_names(gm_optimized) + + assert original_submodules == optimized_submodules + + # Verify nested parent still exists (because layer_b is still used) + # Note: The exact behavior depends on how the graph represents nested modules + + def test_deeply_nested_hierarchy(self): + """Test deletion with deeply nested module hierarchy.""" + model = DeeplyNestedModel().to("cuda") + x = torch.randn(2, 32, device="cuda") + + gm = torch_export_to_gm(model, args=(x,)) + + 
gm_original = copy.deepcopy(gm) + gm_optimized = copy.deepcopy(gm) + + # Remove one of the deep leaves + make_submodules_unused(gm_original, ["level1.level2.level3.leaf_a"]) + make_submodules_unused(gm_optimized, ["level1.level2.level3.leaf_a"]) + + gm_original.delete_all_unused_submodules() + delete_all_unused_submodules(gm_optimized) + + original_submodules = get_submodule_names(gm_original) + optimized_submodules = get_submodule_names(gm_optimized) + + assert original_submodules == optimized_submodules + + def test_deeply_nested_full_branch_removal(self): + """Test removal of entire deeply nested branch.""" + model = DeeplyNestedModel().to("cuda") + x = torch.randn(2, 32, device="cuda") + + gm = torch_export_to_gm(model, args=(x,)) + + gm_original = copy.deepcopy(gm) + gm_optimized = copy.deepcopy(gm) + + # Remove entire nested branch + make_submodules_unused( + gm_original, ["level1.level2.level3.leaf_a", "level1.level2.level3.leaf_b"] + ) + make_submodules_unused( + gm_optimized, ["level1.level2.level3.leaf_a", "level1.level2.level3.leaf_b"] + ) + + gm_original.delete_all_unused_submodules() + delete_all_unused_submodules(gm_optimized) + + original_submodules = get_submodule_names(gm_original) + optimized_submodules = get_submodule_names(gm_optimized) + + assert original_submodules == optimized_submodules + + def test_sequential_module_handling(self): + """Test handling of Sequential modules (call_module marks children used).""" + model = ModelWithSequential().to("cuda") + x = torch.randn(2, 32, device="cuda") + + gm = torch_export_to_gm(model, args=(x,)) + + gm_original = copy.deepcopy(gm) + gm_optimized = copy.deepcopy(gm) + + # Remove the sequential module + make_submodules_unused(gm_original, ["sequential"]) + make_submodules_unused(gm_optimized, ["sequential"]) + + gm_original.delete_all_unused_submodules() + delete_all_unused_submodules(gm_optimized) + + original_submodules = get_submodule_names(gm_original) + optimized_submodules = get_submodule_names(gm_optimized) + + assert original_submodules == optimized_submodules + + def test_idempotent_deletion(self): + """Test that running deletion multiple times is idempotent.""" + model = FullyUsedModel().to("cuda") + x = torch.randn(2, 32, device="cuda") + + gm = torch_export_to_gm(model, args=(x,)) + + gm_optimized = copy.deepcopy(gm) + make_submodules_unused(gm_optimized, ["linear2"]) + + # Run deletion twice + delete_all_unused_submodules(gm_optimized) + submodules_after_first = get_submodule_names(gm_optimized) + + delete_all_unused_submodules(gm_optimized) + submodules_after_second = get_submodule_names(gm_optimized) + + assert submodules_after_first == submodules_after_second + + def test_empty_graph_module(self): + """Test handling of a minimal GraphModule.""" + + class MinimalModel(nn.Module): + def forward(self, x): + return x + + model = MinimalModel().to("cuda") + x = torch.randn(2, 32, device="cuda") + + gm = torch_export_to_gm(model, args=(x,)) + + gm_original = copy.deepcopy(gm) + gm_optimized = copy.deepcopy(gm) + + # Should not raise any errors + gm_original.delete_all_unused_submodules() + delete_all_unused_submodules(gm_optimized) + + assert get_submodule_names(gm_original) == get_submodule_names(gm_optimized) + + +# ============================================================================= +# Parametrized Tests +# ============================================================================= + + +@pytest.mark.parametrize( + "targets_to_remove", + [ + ["linear1"], + ["linear2"], + ["linear3"], + ["linear1", 
"linear2"], + ["linear2", "linear3"], + ["linear1", "linear3"], + ["linear1", "linear2", "linear3"], + ], +) +def test_various_removal_combinations(targets_to_remove): + """Test various combinations of submodule removals.""" + model = FullyUsedModel().to("cuda") + x = torch.randn(2, 32, device="cuda") + + gm = torch_export_to_gm(model, args=(x,)) + + gm_original = copy.deepcopy(gm) + gm_optimized = copy.deepcopy(gm) + + make_submodules_unused(gm_original, targets_to_remove) + make_submodules_unused(gm_optimized, targets_to_remove) + + gm_original.delete_all_unused_submodules() + delete_all_unused_submodules(gm_optimized) + + assert get_submodule_names(gm_original) == get_submodule_names(gm_optimized)