[#10642][feat] AutoDeploy: optimized canonicalize_graph utilities [1/2] (#10675)

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Lucas Liebenwein 2026-01-18 13:42:30 -05:00 committed by GitHub
parent 4d2916d683
commit 9879400479
12 changed files with 592 additions and 47 deletions

View File

@ -110,10 +110,6 @@ class EliminateRedundantTransposes(BaseTransform):
original_input.replace_all_uses_with(new_contiguous_node)
new_contiguous_node.replace_input_with(new_contiguous_node, original_input)
# Clean up the graph
if nodes_to_eliminate:
gm.graph.eliminate_dead_code()
info = TransformInfo(
skipped=False,
num_matches=len(nodes_to_eliminate),

View File

@ -127,12 +127,10 @@ class FuseCausalConvActivation(BaseTransform):
graph.erase_node(activation_node)
graph.erase_node(conv_node)
gm.recompile()
info = TransformInfo(
skipped=False,
num_matches=len(matches),
is_clean=False,
has_valid_shapes=False,
is_clean=len(matches) == 0,
has_valid_shapes=len(matches) == 0,
)
return gm, info
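
This hunk (and the analogous ones in the transforms below) changes TransformInfo so that is_clean and has_valid_shapes are only reported as True when the transform found no matches, deferring graph cleanup to the shared canonicalization step instead of each transform eliminating dead code itself. A minimal sketch of how a downstream stage could consume these flags, using the helpers from utils/_graph.py introduced later in this commit (the dispatch function itself is illustrative, not the actual optimizer code):

from tensorrt_llm._torch.auto_deploy.utils._graph import canonicalize_graph, run_shape_prop

def finalize_transform(gm, info):
    # Only pay for cleanup and shape propagation when the transform actually changed the graph.
    if not info.is_clean:
        canonicalize_graph(gm)  # dead-code elimination + recompile + submodule GC + lint
    if not info.has_valid_shapes:
        run_shape_prop(gm)  # illustrative call; the exact signature is not shown in this hunk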

View File

@ -213,5 +213,5 @@ class FuseMambaALog(BaseTransform):
skipped=False,
num_matches=num_matches,
is_clean=num_matches == 0,
has_valid_shapes=True,
has_valid_shapes=num_matches == 0,
)

View File

@ -10,6 +10,7 @@ from tensorrt_llm._torch.utils import ActivationType
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code
from ...utils.cuda_mem_tracker import cuda_memory_tracker
from ...utils.node_utils import bfs, extract_op_args, identify_regions_between_residuals, is_op
from ..interface import (
@ -112,8 +113,8 @@ def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "t
# Delete the unstacked weights immediately to save GPU memory
# This will happen automatically after the graph is canonicalized,
# but for large models we'll run out of memory during the transformation itself.
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
eliminate_dead_code(gm)
delete_all_unused_submodules(gm)
return fused_key_counter
@ -635,7 +636,7 @@ class MatchMoePattern(BaseTransform):
graph.erase_node(final_hidden_state_node)
while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary):
gm.graph.eliminate_dead_code()
eliminate_dead_code(gm)
num_moe_patterns += 1
@ -1272,14 +1273,14 @@ class MatchBmmMoePattern(BaseTransform):
graph.erase_node(output_node)
# Clean up dead nodes
gm.graph.eliminate_dead_code()
eliminate_dead_code(gm)
# Clean up dead inplace nodes in the region
while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary):
gm.graph.eliminate_dead_code()
eliminate_dead_code(gm)
# Delete unused submodules/parameters
gm.delete_all_unused_submodules()
delete_all_unused_submodules(gm)
num_moe_patterns += 1
@ -1517,8 +1518,8 @@ def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "
# Clean up after processing all nodes
# eliminate_dead_code will remove unused get_attr nodes, then delete_all_unused_submodules
# will remove the parameters/buffers that are no longer referenced
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
eliminate_dead_code(gm)
delete_all_unused_submodules(gm)
return fused_key_counter
@ -1776,8 +1777,8 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
# Clean up after processing all nodes
# eliminate_dead_code will remove unused get_attr nodes, then delete_all_unused_submodules
# will remove the parameters/buffers that are no longer referenced
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
eliminate_dead_code(gm)
delete_all_unused_submodules(gm)
return fused_key_counter
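
Across this file, the bound-method calls on the GraphModule are swapped for the free functions imported at the top of the hunk, so all cleanup funnels through utils/_graph.py (including the optimized submodule GC). A small self-contained sketch of the substituted idiom; the helper name _free_unstacked_weights is illustrative only:

from torch.fx import GraphModule

from tensorrt_llm._torch.auto_deploy.utils._graph import (
    delete_all_unused_submodules,
    eliminate_dead_code,
)

def _free_unstacked_weights(gm: GraphModule) -> None:
    # Mirror the call order used above: first drop the now-unused get_attr nodes,
    # then garbage-collect the submodules/parameters they referenced.
    eliminate_dead_code(gm)
    delete_all_unused_submodules(gm)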

View File

@ -11,6 +11,7 @@ from torch.fx import GraphModule, Node
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code
from ...utils.cuda_mem_tracker import cuda_memory_tracker
from ...utils.logger import ad_logger
from ...utils.node_utils import extract_weight_name, is_linear_op, is_op
@ -75,8 +76,8 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node
n.replace_all_uses_with(get_split_node)
# Clean up deleted modules to save GPU memory
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
eliminate_dead_code(gm)
delete_all_unused_submodules(gm)
def check_same_children(parent_node: Node, is_desired_child: Callable[[Node], bool]) -> bool:
@ -185,8 +186,8 @@ class QuantizationFusionMixin(ABC):
n.replace_all_uses_with(get_split_node)
# Clean up deleted modules to save GPU memory
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
eliminate_dead_code(gm)
delete_all_unused_submodules(gm)
def _apply_fusion_pass(
self,

View File

@ -291,10 +291,6 @@ def _execute_op_in_aux_stream(
n.replace_all_uses_with(new_node)
graph.erase_node(n)
num_replaced += 1
if num_replaced:
graph.eliminate_dead_code()
graph.lint()
gm.recompile()
return gm, num_replaced
@ -322,8 +318,8 @@ class MultiStreamMOE(BaseTransform):
info = TransformInfo(
skipped=False,
num_matches=num_matches,
is_clean=False,
has_valid_shapes=False,
is_clean=num_matches == 0,
has_valid_shapes=num_matches == 0,
)
return gm, info

View File

@ -33,7 +33,7 @@ from .....functional import AllReduceStrategy
from ...custom_ops.trtllm_dist import is_trtllm_op_available
from ...models.factory import ModelFactory, ShardingConfigSource
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import del_attr_by_name
from ...utils._graph import del_attr_by_name, eliminate_dead_code
from ...utils.logger import ad_logger
from ...utils.node_utils import (
LayerSubgraph,
@ -1444,7 +1444,7 @@ def _insert_sharded_moe(
node.replace_all_uses_with(dist_node)
dist_node.replace_input_with(dist_node, node)
gm.graph.eliminate_dead_code()
eliminate_dead_code(gm)
# Expert weights registered via gm.register_parameter() are top-level attributes.
# Unlike submodules, these aren't cleaned up by eliminate_dead_code() or
# delete_all_unused_submodules() - must delete manually after removing their get_attr nodes.
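
The comment above points out that parameters registered directly on the GraphModule via gm.register_parameter() are not reclaimed by eliminate_dead_code() or delete_all_unused_submodules(). A hedged sketch of that manual step, where pruned_expert_param_names and the (module, dotted name) calling convention of del_attr_by_name are assumptions for illustration:

eliminate_dead_code(gm)  # removes the get_attr nodes of the sharded-away experts
for name in pruned_expert_param_names:  # hypothetical list of top-level parameter names
    del_attr_by_name(gm, name)  # top-level attributes must be deleted explicitly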

View File

@ -1,7 +1,8 @@
"""Graph-related utilities for transformations."""
import itertools
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterator, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
@ -129,8 +130,8 @@ def _move_single_gm_to_device(gm: GraphModule, device: torch.device) -> None:
)
if recompile_graph:
# recompile graph to update self generated codes in subgraph
gm.graph.lint()
gm.recompile()
lint(gm)
recompile(gm)
def move_to_device(mod: nn.Module, device: DeviceLikeType) -> None:
@ -161,18 +162,91 @@ def _is_impure_node(node: Node) -> bool:
node.target._nondeterministic_seeded = True
def _canonicalize_single_gm(gm: GraphModule) -> None:
# clean up graph (needs to be done repeatedly until no more dead code)
gm.graph.eliminate_dead_code(is_impure_node=_is_impure_node)
# recompile to propagate all graph changes to the graph module
def delete_all_unused_submodules(gm: GraphModule) -> None:
"""Optimized version of delete_all_unused_submodules with O(n+m) complexity.
The original implementation uses a list for tracking used modules, making membership
checks O(n). This version uses a set for O(1) lookups.
Original implementation is at GraphModule.delete_all_unused_submodules
Args:
gm: The GraphModule to clean up.
"""
used: Set[str] = set()
for node in itertools.chain(
gm.graph.find_nodes(op="call_module", sort=False),
gm.graph.find_nodes(op="get_attr", sort=False),
):
# If the target is already marked used and this is not a call_module node, we can skip it.
# We cannot skip call_module nodes because we still need to mark all of their recursive
# submodules as used.
if node.target in used and node.op != "call_module":
continue
# A list of strings representing the different parts
# of the path. For example, `foo.bar.baz` gives us
# ["foo", "bar", "baz"]
fullpath = node.target.split(".")
# Progressively collect all the names of intermediate
# modules. For example, if we have the target
# `foo.bar.baz`, we'll add `foo`, `foo.bar`, and
# `foo.bar.baz` to the used set.
used.update(".".join(fullpath[:i]) for i in range(1, len(fullpath) + 1))
# For call_module, also mark all recursive submodules as used
if node.op == "call_module":
try:
submod = gm.get_submodule(node.target)
for submod_name, _ in submod.named_modules():
if submod_name != "":
used.add(f"{node.target}.{submod_name}")
except AttributeError:
# Node referenced nonexistent submodule, don't need to
# worry about GCing anything
pass
# also add the root module to the used set
used.add("")
# Go over all modules and delete those that are not in the used set. Since we use named_modules,
# parents will be deleted first and children will be automatically skipped inside delete_submodule.
to_delete = [name for name, _ in gm.named_modules() if name not in used]
for name in to_delete:
gm.delete_submodule(name)
def eliminate_dead_code(
gm: GraphModule, is_impure_node: Optional[Callable[[Node], bool]] = None
) -> None:
"""Eliminate dead code from the graph of the given GraphModule."""
gm.graph.eliminate_dead_code(is_impure_node=is_impure_node)
def recompile(gm: GraphModule) -> None:
"""Recompile the graph of the given GraphModule."""
gm.recompile()
def lint(gm: GraphModule) -> None:
"""Lint the graph of the given GraphModule."""
gm.graph.lint()
def _canonicalize_single_gm(gm: GraphModule) -> None:
# clean up graph (needs to be done repeatedly until no more dead code)
eliminate_dead_code(gm, is_impure_node=_is_impure_node)
# recompile to propagate all graph changes to the graph module
recompile(gm)
# clean up graph module
gm.delete_all_unused_submodules()
delete_all_unused_submodules(gm)
# lint the graph
gm.graph.lint()
lint(gm)
def canonicalize_graph(mod: nn.Module) -> None:
@ -217,7 +291,7 @@ def _run_shape_prop_single_gm(
ad_logger.warning("No fake tensors and no args available for shape propagation")
# lint the graph
gm.graph.lint()
lint(gm)
def run_shape_prop(
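
The optimized delete_all_unused_submodules above walks the call_module and get_attr nodes once and records every dotted prefix of each used target in a set, so both the membership checks and the final sweep over named_modules() are linear. A standalone worked example of the prefix bookkeeping (illustration only, not library code):

target = "foo.bar.baz"
parts = target.split(".")
used = {".".join(parts[:i]) for i in range(1, len(parts) + 1)}
assert used == {"foo", "foo.bar", "foo.bar.baz"}  # parents of a used leaf stay alive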

View File

@ -19,6 +19,7 @@ from tensorrt_llm._torch.auto_deploy.transform.library.sharding import (
SplitDimension,
WeightShardingInfo,
)
from tensorrt_llm._torch.auto_deploy.utils._graph import recompile
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
from tensorrt_llm.commands.bench import main
from tensorrt_llm.functional import AllReduceStrategy
@ -378,7 +379,7 @@ def test_allreduce_strategy_propagation(strategy):
if node:
transform.check_and_apply(gm, node)
gm.recompile()
recompile(gm)
# Verify the graph contains torch_dist_all_reduce nodes with correct strategy
allreduce_nodes = [

View File

@ -18,6 +18,7 @@ from tensorrt_llm._torch.auto_deploy.transform.library.sharding import (
ShardingTransformConfig,
)
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils._graph import lint, recompile
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
from tensorrt_llm.functional import AllReduceStrategy
@ -188,8 +189,8 @@ def test_llama4_stacked_moe_pattern_detection():
)
graph.output(moe_node)
graph.lint()
gm.recompile()
lint(gm)
recompile(gm)
# Run pattern detection for EP
optimizer = InferenceOptimizer(

View File

@ -9,6 +9,7 @@ from tensorrt_llm._torch.auto_deploy.transform.library.multi_stream_moe import (
cuda_stream_manager,
record_event_wrapper,
)
from tensorrt_llm._torch.auto_deploy.utils._graph import canonicalize_graph
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
@ -75,9 +76,7 @@ def replace_multi_stream_linear_with_aux_stream_wrapper(gm: GraphModule) -> Tupl
num_replaced += 1
if num_replaced:
graph.eliminate_dead_code()
graph.lint()
gm.recompile()
canonicalize_graph(gm)
return gm, num_replaced
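
Based on the new _canonicalize_single_gm shown in the utils/_graph.py hunk above, the single canonicalize_graph(gm) call is roughly equivalent to the manual cleanup it replaces here, plus the submodule GC (approximate expansion for a single GraphModule; canonicalize_graph itself accepts any nn.Module):

eliminate_dead_code(gm, is_impure_node=_is_impure_node)  # repeated DCE with the impurity override
recompile(gm)  # propagate graph changes into the module's generated code
delete_all_unused_submodules(gm)  # optimized submodule garbage collection
lint(gm)  # consistency checks on the rewritten graph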

View File

@ -0,0 +1,478 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for delete_all_unused_submodules.
This module tests the optimized implementation against PyTorch's original
GraphModule.delete_all_unused_submodules to ensure functional equivalence.
Since torch_export_to_gm already eliminates unused submodules during export,
we must modify the graph after export to create unused submodules for testing.
"""
import copy
from typing import List, Set
import pytest
import torch
import torch.nn as nn
from torch.fx import GraphModule
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.utils._graph import (
canonicalize_graph,
delete_all_unused_submodules,
)
# =============================================================================
# Test Models - All submodules are used in forward pass
# =============================================================================
class FullyUsedModel(nn.Module):
"""Model where all submodules are used in forward - none will be pruned during export."""
def __init__(self, hidden_dim: int = 32):
super().__init__()
self.linear1 = nn.Linear(hidden_dim, hidden_dim)
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
self.linear3 = nn.Linear(hidden_dim, hidden_dim)
self.nested = nn.ModuleDict(
{
"layer_a": nn.Linear(hidden_dim, hidden_dim),
"layer_b": nn.Linear(hidden_dim, hidden_dim),
}
)
def forward(self, x):
# Use ALL submodules so they appear in the exported graph
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = self.nested["layer_a"](x)
x = self.nested["layer_b"](x)
return x
class DeeplyNestedModel(nn.Module):
"""Model with deeply nested submodule hierarchy for testing depth-first deletion."""
def __init__(self, hidden_dim: int = 32):
super().__init__()
self.level1 = nn.ModuleDict(
{
"level2": nn.ModuleDict(
{
"level3": nn.ModuleDict(
{
"leaf_a": nn.Linear(hidden_dim, hidden_dim),
"leaf_b": nn.Linear(hidden_dim, hidden_dim),
}
)
}
)
}
)
self.other = nn.Linear(hidden_dim, hidden_dim)
def forward(self, x):
# Use all nested modules
x = self.level1["level2"]["level3"]["leaf_a"](x)
x = self.level1["level2"]["level3"]["leaf_b"](x)
x = self.other(x)
return x
class ModelWithSequential(nn.Module):
"""Model with Sequential submodules for testing call_module behavior."""
def __init__(self, hidden_dim: int = 32):
super().__init__()
self.pre = nn.Linear(hidden_dim, hidden_dim)
self.sequential = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
)
self.post = nn.Linear(hidden_dim, hidden_dim)
def forward(self, x):
x = self.pre(x)
x = self.sequential(x)
x = self.post(x)
return x
# =============================================================================
# Helper Functions
# =============================================================================
def get_submodule_names(gm: nn.Module) -> Set[str]:
"""Get all submodule names (excluding root empty string)."""
return {name for name, _ in gm.named_modules() if name}
def _target_matches_any(target: str, targets_to_remove: List[str]) -> bool:
"""Check if a target matches any of the targets to remove.
Handles both exact matches and prefix matches (for nested attributes).
Target names in the graph use underscores, so we check both formats.
"""
# Each removal target is also normalized to its underscore-separated form for comparison,
# e.g., the graph target "linear2_weight" should match the removal target "linear2"
for t in targets_to_remove:
# Exact match
if target == t:
return True
# Prefix match (e.g., "linear2.weight" starts with "linear2.")
if target.startswith(t + "."):
return True
# Handle underscore-separated names from export
# e.g., target="linear2_weight" should match t="linear2"
t_underscore = t.replace(".", "_")
if target == t_underscore or target.startswith(t_underscore + "_"):
return True
return False
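
For reference, exported graphs flatten parameter paths into underscore-separated attribute names, which is why the helper also checks the underscore form. A few illustrative cases (not part of the original file) that hold under the logic above:

assert _target_matches_any("linear2_weight", ["linear2"])
assert _target_matches_any("nested_layer_a_weight", ["nested.layer_a"])
assert not _target_matches_any("linear1_weight", ["linear2"])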
def make_submodules_unused(gm: GraphModule, targets_to_remove: List[str]) -> None:
"""Remove nodes referencing specific submodules to make them unused.
This function finds all operations that use parameters from the specified
submodules and removes them from the graph, rewiring the data flow to bypass
those operations. After removal, eliminate_dead_code() cleans up unused
get_attr nodes.
Args:
gm: The GraphModule to modify.
targets_to_remove: List of submodule target names to make unused.
"""
graph = gm.graph
# First pass: find call_function nodes that use parameters from target modules
# These are the actual operations (e.g., linear, matmul) that we need to remove
nodes_to_bypass = []
for node in graph.nodes:
if node.op == "call_function":
# Check if any of the node's inputs come from a target module's get_attr
uses_target_module = False
for arg in node.args:
if isinstance(arg, torch.fx.Node) and arg.op == "get_attr":
if _target_matches_any(arg.target, targets_to_remove):
uses_target_module = True
break
if uses_target_module:
nodes_to_bypass.append(node)
elif node.op == "call_module":
# Direct call_module nodes (if any)
if _target_matches_any(node.target, targets_to_remove):
nodes_to_bypass.append(node)
# Bypass each node by replacing its uses with its first tensor input
for node in nodes_to_bypass:
# Find the first tensor input (usually the activation, not weights)
replacement = None
for arg in node.args:
if isinstance(arg, torch.fx.Node):
# Skip get_attr nodes (weights/biases) - we want the activation input
if arg.op != "get_attr":
replacement = arg
break
if replacement is not None:
node.replace_all_uses_with(replacement)
# Remove the bypassed nodes (in reverse topological order)
for node in reversed(nodes_to_bypass):
graph.erase_node(node)
canonicalize_graph(gm)
# =============================================================================
# Test Class
# =============================================================================
class TestDeleteAllUnusedSubmodulesOptimized:
"""Tests for delete_all_unused_submodules function."""
def test_functional_equivalence_basic(self):
"""Test that optimized version produces identical results to original."""
model = FullyUsedModel().to("cuda")
x = torch.randn(2, 32, device="cuda")
gm = torch_export_to_gm(model, args=(x,))
# Verify submodules are present after export
submodules_after_export = get_submodule_names(gm)
assert len(submodules_after_export) > 0, "Expected submodules after export"
# Make some modules unused by modifying the graph
gm_original = copy.deepcopy(gm)
gm_optimized = copy.deepcopy(gm)
make_submodules_unused(gm_original, ["linear2"])
make_submodules_unused(gm_optimized, ["linear2"])
# Apply both deletion implementations
gm_original.delete_all_unused_submodules()
delete_all_unused_submodules(gm_optimized)
# Verify identical results
original_submodules = get_submodule_names(gm_original)
optimized_submodules = get_submodule_names(gm_optimized)
assert original_submodules == optimized_submodules, (
f"Mismatch in submodules:\n"
f"Original: {original_submodules}\n"
f"Optimized: {optimized_submodules}"
)
def test_functional_equivalence_multiple_removals(self):
"""Test equivalence when multiple submodules are made unused."""
model = FullyUsedModel().to("cuda")
x = torch.randn(2, 32, device="cuda")
gm = torch_export_to_gm(model, args=(x,))
gm_original = copy.deepcopy(gm)
gm_optimized = copy.deepcopy(gm)
# Remove multiple submodules
targets = ["linear2", "linear3"]
make_submodules_unused(gm_original, targets)
make_submodules_unused(gm_optimized, targets)
gm_original.delete_all_unused_submodules()
delete_all_unused_submodules(gm_optimized)
assert get_submodule_names(gm_original) == get_submodule_names(gm_optimized)
def test_no_unused_modules(self):
"""Test that nothing is deleted when all modules are used."""
model = FullyUsedModel().to("cuda")
x = torch.randn(2, 32, device="cuda")
gm = torch_export_to_gm(model, args=(x,))
submodules_before = get_submodule_names(gm)
gm_original = copy.deepcopy(gm)
gm_optimized = copy.deepcopy(gm)
# Don't modify the graph - all modules should remain used
gm_original.delete_all_unused_submodules()
delete_all_unused_submodules(gm_optimized)
# Verify nothing was deleted
assert get_submodule_names(gm_original) == submodules_before
assert get_submodule_names(gm_optimized) == submodules_before
def test_all_modules_made_unused(self):
"""Test deletion when all call_module/get_attr nodes are removed."""
model = FullyUsedModel().to("cuda")
x = torch.randn(2, 32, device="cuda")
gm = torch_export_to_gm(model, args=(x,))
gm_original = copy.deepcopy(gm)
gm_optimized = copy.deepcopy(gm)
# Find all submodule targets and remove them
all_targets = list(get_submodule_names(gm))
# Remove references to all submodules
make_submodules_unused(gm_original, all_targets)
make_submodules_unused(gm_optimized, all_targets)
gm_original.delete_all_unused_submodules()
delete_all_unused_submodules(gm_optimized)
# Both should have the same (empty or minimal) set of submodules
assert get_submodule_names(gm_original) == get_submodule_names(gm_optimized)
def test_nested_module_partial_removal(self):
"""Test that parent module stays when only one child is removed."""
model = FullyUsedModel().to("cuda")
x = torch.randn(2, 32, device="cuda")
gm = torch_export_to_gm(model, args=(x,))
gm_original = copy.deepcopy(gm)
gm_optimized = copy.deepcopy(gm)
# Remove only one nested layer, keep the other
make_submodules_unused(gm_original, ["nested.layer_a"])
make_submodules_unused(gm_optimized, ["nested.layer_a"])
gm_original.delete_all_unused_submodules()
delete_all_unused_submodules(gm_optimized)
original_submodules = get_submodule_names(gm_original)
optimized_submodules = get_submodule_names(gm_optimized)
assert original_submodules == optimized_submodules
# Verify nested parent still exists (because layer_b is still used)
# Note: The exact behavior depends on how the graph represents nested modules
def test_deeply_nested_hierarchy(self):
"""Test deletion with deeply nested module hierarchy."""
model = DeeplyNestedModel().to("cuda")
x = torch.randn(2, 32, device="cuda")
gm = torch_export_to_gm(model, args=(x,))
gm_original = copy.deepcopy(gm)
gm_optimized = copy.deepcopy(gm)
# Remove one of the deep leaves
make_submodules_unused(gm_original, ["level1.level2.level3.leaf_a"])
make_submodules_unused(gm_optimized, ["level1.level2.level3.leaf_a"])
gm_original.delete_all_unused_submodules()
delete_all_unused_submodules(gm_optimized)
original_submodules = get_submodule_names(gm_original)
optimized_submodules = get_submodule_names(gm_optimized)
assert original_submodules == optimized_submodules
def test_deeply_nested_full_branch_removal(self):
"""Test removal of entire deeply nested branch."""
model = DeeplyNestedModel().to("cuda")
x = torch.randn(2, 32, device="cuda")
gm = torch_export_to_gm(model, args=(x,))
gm_original = copy.deepcopy(gm)
gm_optimized = copy.deepcopy(gm)
# Remove entire nested branch
make_submodules_unused(
gm_original, ["level1.level2.level3.leaf_a", "level1.level2.level3.leaf_b"]
)
make_submodules_unused(
gm_optimized, ["level1.level2.level3.leaf_a", "level1.level2.level3.leaf_b"]
)
gm_original.delete_all_unused_submodules()
delete_all_unused_submodules(gm_optimized)
original_submodules = get_submodule_names(gm_original)
optimized_submodules = get_submodule_names(gm_optimized)
assert original_submodules == optimized_submodules
def test_sequential_module_handling(self):
"""Test handling of Sequential modules (call_module marks children used)."""
model = ModelWithSequential().to("cuda")
x = torch.randn(2, 32, device="cuda")
gm = torch_export_to_gm(model, args=(x,))
gm_original = copy.deepcopy(gm)
gm_optimized = copy.deepcopy(gm)
# Remove the sequential module
make_submodules_unused(gm_original, ["sequential"])
make_submodules_unused(gm_optimized, ["sequential"])
gm_original.delete_all_unused_submodules()
delete_all_unused_submodules(gm_optimized)
original_submodules = get_submodule_names(gm_original)
optimized_submodules = get_submodule_names(gm_optimized)
assert original_submodules == optimized_submodules
def test_idempotent_deletion(self):
"""Test that running deletion multiple times is idempotent."""
model = FullyUsedModel().to("cuda")
x = torch.randn(2, 32, device="cuda")
gm = torch_export_to_gm(model, args=(x,))
gm_optimized = copy.deepcopy(gm)
make_submodules_unused(gm_optimized, ["linear2"])
# Run deletion twice
delete_all_unused_submodules(gm_optimized)
submodules_after_first = get_submodule_names(gm_optimized)
delete_all_unused_submodules(gm_optimized)
submodules_after_second = get_submodule_names(gm_optimized)
assert submodules_after_first == submodules_after_second
def test_empty_graph_module(self):
"""Test handling of a minimal GraphModule."""
class MinimalModel(nn.Module):
def forward(self, x):
return x
model = MinimalModel().to("cuda")
x = torch.randn(2, 32, device="cuda")
gm = torch_export_to_gm(model, args=(x,))
gm_original = copy.deepcopy(gm)
gm_optimized = copy.deepcopy(gm)
# Should not raise any errors
gm_original.delete_all_unused_submodules()
delete_all_unused_submodules(gm_optimized)
assert get_submodule_names(gm_original) == get_submodule_names(gm_optimized)
# =============================================================================
# Parametrized Tests
# =============================================================================
@pytest.mark.parametrize(
"targets_to_remove",
[
["linear1"],
["linear2"],
["linear3"],
["linear1", "linear2"],
["linear2", "linear3"],
["linear1", "linear3"],
["linear1", "linear2", "linear3"],
],
)
def test_various_removal_combinations(targets_to_remove):
"""Test various combinations of submodule removals."""
model = FullyUsedModel().to("cuda")
x = torch.randn(2, 32, device="cuda")
gm = torch_export_to_gm(model, args=(x,))
gm_original = copy.deepcopy(gm)
gm_optimized = copy.deepcopy(gm)
make_submodules_unused(gm_original, targets_to_remove)
make_submodules_unused(gm_optimized, targets_to_remove)
gm_original.delete_all_unused_submodules()
delete_all_unused_submodules(gm_optimized)
assert get_submodule_names(gm_original) == get_submodule_names(gm_optimized)