Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-03 17:52:19 +08:00)

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>

Parent: 4d2916d683
Commit: 9879400479
@@ -110,10 +110,6 @@ class EliminateRedundantTransposes(BaseTransform):
original_input.replace_all_uses_with(new_contiguous_node)
new_contiguous_node.replace_input_with(new_contiguous_node, original_input)

# Clean up the graph
if nodes_to_eliminate:
gm.graph.eliminate_dead_code()

info = TransformInfo(
skipped=False,
num_matches=len(nodes_to_eliminate),

@@ -127,12 +127,10 @@ class FuseCausalConvActivation(BaseTransform):
graph.erase_node(activation_node)
graph.erase_node(conv_node)

gm.recompile()

info = TransformInfo(
skipped=False,
num_matches=len(matches),
is_clean=False,
has_valid_shapes=False,
is_clean=len(matches) == 0,
has_valid_shapes=len(matches) == 0,
)
return gm, info

@@ -213,5 +213,5 @@ class FuseMambaALog(BaseTransform):
skipped=False,
num_matches=num_matches,
is_clean=num_matches == 0,
has_valid_shapes=True,
has_valid_shapes=num_matches == 0,
)
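The hunks above replace hard-coded `is_clean=False` / `has_valid_shapes=False` with expressions tied to the match count. A minimal sketch of the apparent convention, assuming these flags tell the optimizer whether it still needs to re-canonicalize the graph and re-run shape propagation; the helper below is hypothetical and not part of the diff:

def _transform_info_kwargs(num_matches: int) -> dict:
    # If the transform matched nothing, the graph was not touched, so it is
    # still clean and its shape metadata is still valid; otherwise both must
    # be refreshed downstream.
    return {
        "skipped": False,
        "num_matches": num_matches,
        "is_clean": num_matches == 0,
        "has_valid_shapes": num_matches == 0,
    }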
@@ -10,6 +10,7 @@ from tensorrt_llm._torch.utils import ActivationType

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code
from ...utils.cuda_mem_tracker import cuda_memory_tracker
from ...utils.node_utils import bfs, extract_op_args, identify_regions_between_residuals, is_op
from ..interface import (

@@ -112,8 +113,8 @@ def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "t
# Delete the unstacked weights immediately to save GPU memory
# This will happen automatically after the graph is canonicalized,
# but for large models we'll run out of memory during the transformation itself.
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
eliminate_dead_code(gm)
delete_all_unused_submodules(gm)

return fused_key_counter

@@ -635,7 +636,7 @@ class MatchMoePattern(BaseTransform):
graph.erase_node(final_hidden_state_node)

while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary):
gm.graph.eliminate_dead_code()
eliminate_dead_code(gm)

num_moe_patterns += 1

@@ -1272,14 +1273,14 @@ class MatchBmmMoePattern(BaseTransform):
graph.erase_node(output_node)

# Clean up dead nodes
gm.graph.eliminate_dead_code()
eliminate_dead_code(gm)

# Clean up dead inplace nodes in the region
while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary):
gm.graph.eliminate_dead_code()
eliminate_dead_code(gm)

# Delete unused submodules/parameters
gm.delete_all_unused_submodules()
delete_all_unused_submodules(gm)

num_moe_patterns += 1

@@ -1517,8 +1518,8 @@ def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "
# Clean up after processing all nodes
# eliminate_dead_code will remove unused get_attr nodes, then delete_all_unused_submodules
# will remove the parameters/buffers that are no longer referenced
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
eliminate_dead_code(gm)
delete_all_unused_submodules(gm)

return fused_key_counter

@@ -1776,8 +1777,8 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
# Clean up after processing all nodes
# eliminate_dead_code will remove unused get_attr nodes, then delete_all_unused_submodules
# will remove the parameters/buffers that are no longer referenced
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
eliminate_dead_code(gm)
delete_all_unused_submodules(gm)
return fused_key_counter
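The comment in the `_insert_fused_moe_ops` hunk explains why cleanup now runs inside the transform: the unstacked per-expert weights must be released before graph canonicalization, or large models run out of GPU memory while both copies are alive. A minimal sketch of that pattern using the new helpers; `fuse_one_moe_node` is a hypothetical stand-in for the actual weight-stacking logic:

from torch.fx import GraphModule

from tensorrt_llm._torch.auto_deploy.utils._graph import (
    delete_all_unused_submodules,
    eliminate_dead_code,
)


def _fuse_and_release(gm: GraphModule, moe_nodes) -> int:
    fused_key_counter = 0
    for node in moe_nodes:
        fuse_one_moe_node(gm, node)  # hypothetical: stacks per-expert weights into fused tensors
        fused_key_counter += 1
    # Drop the now-unreferenced per-expert weights right away instead of waiting
    # for canonicalization, which would keep fused and unfused copies alive at once.
    eliminate_dead_code(gm)  # removes the dangling get_attr nodes
    delete_all_unused_submodules(gm)  # frees the parameters/buffers they referenced
    return fused_key_counter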
@@ -11,6 +11,7 @@ from torch.fx import GraphModule, Node

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code
from ...utils.cuda_mem_tracker import cuda_memory_tracker
from ...utils.logger import ad_logger
from ...utils.node_utils import extract_weight_name, is_linear_op, is_op

@@ -75,8 +76,8 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node
n.replace_all_uses_with(get_split_node)

# Clean up deleted modules to save GPU memory
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
eliminate_dead_code(gm)
delete_all_unused_submodules(gm)


def check_same_children(parent_node: Node, is_desired_child: Callable[[Node], bool]) -> bool:

@@ -185,8 +186,8 @@ class QuantizationFusionMixin(ABC):
n.replace_all_uses_with(get_split_node)

# Clean up deleted modules to save GPU memory
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
eliminate_dead_code(gm)
delete_all_unused_submodules(gm)

def _apply_fusion_pass(
self,
@@ -291,10 +291,6 @@ def _execute_op_in_aux_stream(
n.replace_all_uses_with(new_node)
graph.erase_node(n)
num_replaced += 1
if num_replaced:
graph.eliminate_dead_code()
graph.lint()
gm.recompile()

return gm, num_replaced

@@ -322,8 +318,8 @@ class MultiStreamMOE(BaseTransform):
info = TransformInfo(
skipped=False,
num_matches=num_matches,
is_clean=False,
has_valid_shapes=False,
is_clean=num_matches == 0,
has_valid_shapes=num_matches == 0,
)

return gm, info
@@ -33,7 +33,7 @@ from .....functional import AllReduceStrategy
from ...custom_ops.trtllm_dist import is_trtllm_op_available
from ...models.factory import ModelFactory, ShardingConfigSource
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import del_attr_by_name
from ...utils._graph import del_attr_by_name, eliminate_dead_code
from ...utils.logger import ad_logger
from ...utils.node_utils import (
LayerSubgraph,

@@ -1444,7 +1444,7 @@ def _insert_sharded_moe(
node.replace_all_uses_with(dist_node)
dist_node.replace_input_with(dist_node, node)

gm.graph.eliminate_dead_code()
eliminate_dead_code(gm)
# Expert weights registered via gm.register_parameter() are top-level attributes.
# Unlike submodules, these aren't cleaned up by eliminate_dead_code() or
# delete_all_unused_submodules() - must delete manually after removing their get_attr nodes.
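The comment above notes that expert weights registered with gm.register_parameter() live as top-level attributes, so neither eliminate_dead_code() nor delete_all_unused_submodules() reclaims them. A rough sketch of the manual cleanup it alludes to; the real code imports del_attr_by_name from utils._graph, whose exact signature is not shown in this diff, so plain delattr stands in here:

from torch.fx import GraphModule


def _drop_toplevel_param(gm: GraphModule, name: str) -> None:
    # Only safe after the get_attr nodes for `name` have been removed
    # (e.g., by eliminate_dead_code above); otherwise the graph would dangle.
    still_used = any(n.op == "get_attr" and n.target == name for n in gm.graph.nodes)
    if not still_used and hasattr(gm, name):
        delattr(gm, name)  # releases the parameter; submodule-based cleanup cannot reach it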
@@ -1,7 +1,8 @@
"""Graph-related utilities for transformations."""

import itertools
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterator, Optional, Set, Tuple, Union

import torch
import torch.nn as nn

@@ -129,8 +130,8 @@ def _move_single_gm_to_device(gm: GraphModule, device: torch.device) -> None:
)
if recompile_graph:
# recompile graph to update self generated codes in subgraph
gm.graph.lint()
gm.recompile()
lint(gm)
recompile(gm)


def move_to_device(mod: nn.Module, device: DeviceLikeType) -> None:

@@ -161,18 +162,91 @@ def _is_impure_node(node: Node) -> bool:
node.target._nondeterministic_seeded = True


def _canonicalize_single_gm(gm: GraphModule) -> None:
# clean up graph (needs to be done repeatedly until no more dead code)
gm.graph.eliminate_dead_code(is_impure_node=_is_impure_node)
def delete_all_unused_submodules(gm: GraphModule) -> None:
"""Optimized version of delete_all_unused_submodules with O(n+m) complexity.

# recompile to propagate all graph changes to the graph module
The original implementation uses a list for tracking used modules, making membership
checks O(n). This version uses a set for O(1) lookups.

Original implementation is at GraphModule.delete_all_unused_submodules

Args:
gm: The GraphModule to clean up.
"""
used: Set[str] = set()

for node in itertools.chain(
gm.graph.find_nodes(op="call_module", sort=False),
gm.graph.find_nodes(op="get_attr", sort=False),
):
# check if it's already used and it's not a call_module node
# in this case we can skip. We cannot skip if it's a call_module node because we need to
# mark all recursive submodules as used.
if node.target in used and node.op != "call_module":
continue

# A list of strings representing the different parts
# of the path. For example, `foo.bar.baz` gives us
# ["foo", "bar", "baz"]
fullpath = node.target.split(".")

# Progressively collect all the names of intermediate
# modules. For example, if we have the target
# `foo.bar.baz`, we'll add `foo`, `foo.bar`, and
# `foo.bar.baz` to the list.
used.update(".".join(fullpath[:i]) for i in range(1, len(fullpath) + 1))

# For call_module, also mark all recursive submodules as used
if node.op == "call_module":
try:
submod = gm.get_submodule(node.target)
for submod_name, _ in submod.named_modules():
if submod_name != "":
used.add(f"{node.target}.{submod_name}")
except AttributeError:
# Node referenced nonexistent submodule, don't need to
# worry about GCing anything
pass

# also add the root module to the used set
used.add("")

# Go over all modules and delete if on the list. Since we use named_modules, parents will be
# deleted first and children will be automatically skipped inside delete_submodule.
to_delete = [name for name, _ in gm.named_modules() if name not in used]
for name in to_delete:
gm.delete_submodule(name)


def eliminate_dead_code(
gm: GraphModule, is_impure_node: Optional[Callable[[Node], bool]] = None
) -> None:
"""Eliminate dead code from the graph of the given GraphModule."""
gm.graph.eliminate_dead_code(is_impure_node=is_impure_node)


def recompile(gm: GraphModule) -> None:
"""Recompile the graph of the given GraphModule."""
gm.recompile()


def lint(gm: GraphModule) -> None:
"""Lint the graph of the given GraphModule."""
gm.graph.lint()


def _canonicalize_single_gm(gm: GraphModule) -> None:
# clean up graph (needs to be done repeatedly until no more dead code)
eliminate_dead_code(gm, is_impure_node=_is_impure_node)

# recompile to propagate all graph changes to the graph module
recompile(gm)

# clean up graph module
gm.delete_all_unused_submodules()
delete_all_unused_submodules(gm)

# lint the graph
gm.graph.lint()
lint(gm)


def canonicalize_graph(mod: nn.Module) -> None:

@@ -217,7 +291,7 @@ def _run_shape_prop_single_gm(
ad_logger.warning("No fake tensors and no args available for shape propagation")

# lint the graph
gm.graph.lint()
lint(gm)


def run_shape_prop(
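The docstring above claims O(n+m) complexity versus the original's list-based tracking of used module names. A standalone sketch of why the membership container matters, mirroring the final `name not in used` filter; sizes are illustrative only and not taken from the diff:

import timeit

# 50k module names, roughly what a large MoE graph module might contain.
all_names = [f"layers.{i}.experts.{j}" for i in range(500) for j in range(100)]
used_list = all_names[: len(all_names) // 2]  # list: `in` is a linear scan
used_set = set(used_list)  # set: `in` is a hash lookup

filter_with_list = lambda: [n for n in all_names if n not in used_list]
filter_with_set = lambda: [n for n in all_names if n not in used_set]

print("list:", timeit.timeit(filter_with_list, number=1))
print("set: ", timeit.timeit(filter_with_set, number=1))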
@@ -19,6 +19,7 @@ from tensorrt_llm._torch.auto_deploy.transform.library.sharding import (
SplitDimension,
WeightShardingInfo,
)
from tensorrt_llm._torch.auto_deploy.utils._graph import recompile
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
from tensorrt_llm.commands.bench import main
from tensorrt_llm.functional import AllReduceStrategy

@@ -378,7 +379,7 @@ def test_allreduce_strategy_propagation(strategy):
if node:
transform.check_and_apply(gm, node)

gm.recompile()
recompile(gm)

# Verify the graph contains torch_dist_all_reduce nodes with correct strategy
allreduce_nodes = [
@@ -18,6 +18,7 @@ from tensorrt_llm._torch.auto_deploy.transform.library.sharding import (
ShardingTransformConfig,
)
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils._graph import lint, recompile
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
from tensorrt_llm.functional import AllReduceStrategy

@@ -188,8 +189,8 @@ def test_llama4_stacked_moe_pattern_detection():
)
graph.output(moe_node)

graph.lint()
gm.recompile()
lint(gm)
recompile(gm)

# Run pattern detection for EP
optimizer = InferenceOptimizer(
@@ -9,6 +9,7 @@ from tensorrt_llm._torch.auto_deploy.transform.library.multi_stream_moe import (
cuda_stream_manager,
record_event_wrapper,
)
from tensorrt_llm._torch.auto_deploy.utils._graph import canonicalize_graph
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op


@@ -75,9 +76,7 @@ def replace_multi_stream_linear_with_aux_stream_wrapper(gm: GraphModule) -> Tupl
num_replaced += 1

if num_replaced:
graph.eliminate_dead_code()
graph.lint()
gm.recompile()
canonicalize_graph(gm)

return gm, num_replaced
@@ -0,0 +1,478 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for delete_all_unused_submodules.

This module tests the optimized implementation against PyTorch's original
GraphModule.delete_all_unused_submodules to ensure functional equivalence.

Since torch_export_to_gm already eliminates unused submodules during export,
we must modify the graph after export to create unused submodules for testing.
"""

import copy
from typing import List, Set

import pytest
import torch
import torch.nn as nn
from torch.fx import GraphModule

from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.utils._graph import (
    canonicalize_graph,
    delete_all_unused_submodules,
)

# =============================================================================
# Test Models - All submodules are used in forward pass
# =============================================================================


class FullyUsedModel(nn.Module):
    """Model where all submodules are used in forward - none will be pruned during export."""

    def __init__(self, hidden_dim: int = 32):
        super().__init__()
        self.linear1 = nn.Linear(hidden_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.nested = nn.ModuleDict(
            {
                "layer_a": nn.Linear(hidden_dim, hidden_dim),
                "layer_b": nn.Linear(hidden_dim, hidden_dim),
            }
        )

    def forward(self, x):
        # Use ALL submodules so they appear in the exported graph
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.nested["layer_a"](x)
        x = self.nested["layer_b"](x)
        return x


class DeeplyNestedModel(nn.Module):
    """Model with deeply nested submodule hierarchy for testing depth-first deletion."""

    def __init__(self, hidden_dim: int = 32):
        super().__init__()
        self.level1 = nn.ModuleDict(
            {
                "level2": nn.ModuleDict(
                    {
                        "level3": nn.ModuleDict(
                            {
                                "leaf_a": nn.Linear(hidden_dim, hidden_dim),
                                "leaf_b": nn.Linear(hidden_dim, hidden_dim),
                            }
                        )
                    }
                )
            }
        )
        self.other = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        # Use all nested modules
        x = self.level1["level2"]["level3"]["leaf_a"](x)
        x = self.level1["level2"]["level3"]["leaf_b"](x)
        x = self.other(x)
        return x


class ModelWithSequential(nn.Module):
    """Model with Sequential submodules for testing call_module behavior."""

    def __init__(self, hidden_dim: int = 32):
        super().__init__()
        self.pre = nn.Linear(hidden_dim, hidden_dim)
        self.sequential = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.post = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        x = self.pre(x)
        x = self.sequential(x)
        x = self.post(x)
        return x


# =============================================================================
# Helper Functions
# =============================================================================


def get_submodule_names(gm: nn.Module) -> Set[str]:
    """Get all submodule names (excluding root empty string)."""
    return {name for name, _ in gm.named_modules() if name}


def _target_matches_any(target: str, targets_to_remove: List[str]) -> bool:
    """Check if a target matches any of the targets to remove.

    Handles both exact matches and prefix matches (for nested attributes).
    Target names in the graph use underscores, so we check both formats.
    """
    # Normalize target by replacing underscores with dots for comparison
    # e.g., "linear2_weight" -> check against "linear2"
    for t in targets_to_remove:
        # Exact match
        if target == t:
            return True
        # Prefix match (e.g., "linear2.weight" starts with "linear2.")
        if target.startswith(t + "."):
            return True
        # Handle underscore-separated names from export
        # e.g., target="linear2_weight" should match t="linear2"
        t_underscore = t.replace(".", "_")
        if target == t_underscore or target.startswith(t_underscore + "_"):
            return True
    return False
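# --- Hedged illustration (not part of the original diff) --------------------
# torch.export flattens dotted parameter paths into underscore-joined get_attr
# targets, which is why the helper above accepts both spellings. A few example
# inputs, assuming the behavior of _target_matches_any exactly as defined above:
def _illustrate_target_matching() -> None:
    assert _target_matches_any("linear2.weight", ["linear2"])  # dotted prefix match
    assert _target_matches_any("linear2_weight", ["linear2"])  # flattened by export
    assert _target_matches_any("nested_layer_a_weight", ["nested.layer_a"])
    assert not _target_matches_any("linear10_weight", ["linear1"])  # no false prefix hits
# -----------------------------------------------------------------------------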
def make_submodules_unused(gm: GraphModule, targets_to_remove: List[str]) -> None:
    """Remove nodes referencing specific submodules to make them unused.

    This function finds all operations that use parameters from the specified
    submodules and removes them from the graph, rewiring the data flow to bypass
    those operations. After removal, eliminate_dead_code() cleans up unused
    get_attr nodes.

    Args:
        gm: The GraphModule to modify.
        targets_to_remove: List of submodule target names to make unused.
    """
    graph = gm.graph

    # First pass: find call_function nodes that use parameters from target modules
    # These are the actual operations (e.g., linear, matmul) that we need to remove
    nodes_to_bypass = []

    for node in graph.nodes:
        if node.op == "call_function":
            # Check if any of the node's inputs come from a target module's get_attr
            uses_target_module = False
            for arg in node.args:
                if isinstance(arg, torch.fx.Node) and arg.op == "get_attr":
                    if _target_matches_any(arg.target, targets_to_remove):
                        uses_target_module = True
                        break
            if uses_target_module:
                nodes_to_bypass.append(node)

        elif node.op == "call_module":
            # Direct call_module nodes (if any)
            if _target_matches_any(node.target, targets_to_remove):
                nodes_to_bypass.append(node)

    # Bypass each node by replacing its uses with its first tensor input
    for node in nodes_to_bypass:
        # Find the first tensor input (usually the activation, not weights)
        replacement = None
        for arg in node.args:
            if isinstance(arg, torch.fx.Node):
                # Skip get_attr nodes (weights/biases) - we want the activation input
                if arg.op != "get_attr":
                    replacement = arg
                    break
        if replacement is not None:
            node.replace_all_uses_with(replacement)

    # Remove the bypassed nodes (in reverse topological order)
    for node in reversed(nodes_to_bypass):
        graph.erase_node(node)

    canonicalize_graph(gm)


# =============================================================================
# Test Class
# =============================================================================


class TestDeleteAllUnusedSubmodulesOptimized:
    """Tests for delete_all_unused_submodules function."""

    def test_functional_equivalence_basic(self):
        """Test that optimized version produces identical results to original."""
        model = FullyUsedModel().to("cuda")
        x = torch.randn(2, 32, device="cuda")

        gm = torch_export_to_gm(model, args=(x,))

        # Verify submodules are present after export
        submodules_after_export = get_submodule_names(gm)
        assert len(submodules_after_export) > 0, "Expected submodules after export"

        # Make some modules unused by modifying the graph
        gm_original = copy.deepcopy(gm)
        gm_optimized = copy.deepcopy(gm)

        make_submodules_unused(gm_original, ["linear2"])
        make_submodules_unused(gm_optimized, ["linear2"])

        # Apply both deletion implementations
        gm_original.delete_all_unused_submodules()
        delete_all_unused_submodules(gm_optimized)

        # Verify identical results
        original_submodules = get_submodule_names(gm_original)
        optimized_submodules = get_submodule_names(gm_optimized)

        assert original_submodules == optimized_submodules, (
            f"Mismatch in submodules:\n"
            f"Original: {original_submodules}\n"
            f"Optimized: {optimized_submodules}"
        )

    def test_functional_equivalence_multiple_removals(self):
        """Test equivalence when multiple submodules are made unused."""
        model = FullyUsedModel().to("cuda")
        x = torch.randn(2, 32, device="cuda")

        gm = torch_export_to_gm(model, args=(x,))

        gm_original = copy.deepcopy(gm)
        gm_optimized = copy.deepcopy(gm)

        # Remove multiple submodules
        targets = ["linear2", "linear3"]
        make_submodules_unused(gm_original, targets)
        make_submodules_unused(gm_optimized, targets)

        gm_original.delete_all_unused_submodules()
        delete_all_unused_submodules(gm_optimized)

        assert get_submodule_names(gm_original) == get_submodule_names(gm_optimized)

    def test_no_unused_modules(self):
        """Test that nothing is deleted when all modules are used."""
        model = FullyUsedModel().to("cuda")
        x = torch.randn(2, 32, device="cuda")

        gm = torch_export_to_gm(model, args=(x,))

        submodules_before = get_submodule_names(gm)

        gm_original = copy.deepcopy(gm)
        gm_optimized = copy.deepcopy(gm)

        # Don't modify the graph - all modules should remain used
        gm_original.delete_all_unused_submodules()
        delete_all_unused_submodules(gm_optimized)

        # Verify nothing was deleted
        assert get_submodule_names(gm_original) == submodules_before
        assert get_submodule_names(gm_optimized) == submodules_before

    def test_all_modules_made_unused(self):
        """Test deletion when all call_module/get_attr nodes are removed."""
        model = FullyUsedModel().to("cuda")
        x = torch.randn(2, 32, device="cuda")

        gm = torch_export_to_gm(model, args=(x,))

        gm_original = copy.deepcopy(gm)
        gm_optimized = copy.deepcopy(gm)

        # Find all submodule targets and remove them
        all_targets = list(get_submodule_names(gm))

        # Remove references to all submodules
        make_submodules_unused(gm_original, all_targets)
        make_submodules_unused(gm_optimized, all_targets)

        gm_original.delete_all_unused_submodules()
        delete_all_unused_submodules(gm_optimized)

        # Both should have the same (empty or minimal) set of submodules
        assert get_submodule_names(gm_original) == get_submodule_names(gm_optimized)

    def test_nested_module_partial_removal(self):
        """Test that parent module stays when only one child is removed."""
        model = FullyUsedModel().to("cuda")
        x = torch.randn(2, 32, device="cuda")

        gm = torch_export_to_gm(model, args=(x,))

        gm_original = copy.deepcopy(gm)
        gm_optimized = copy.deepcopy(gm)

        # Remove only one nested layer, keep the other
        make_submodules_unused(gm_original, ["nested.layer_a"])
        make_submodules_unused(gm_optimized, ["nested.layer_a"])

        gm_original.delete_all_unused_submodules()
        delete_all_unused_submodules(gm_optimized)

        original_submodules = get_submodule_names(gm_original)
        optimized_submodules = get_submodule_names(gm_optimized)

        assert original_submodules == optimized_submodules

        # Verify nested parent still exists (because layer_b is still used)
        # Note: The exact behavior depends on how the graph represents nested modules

    def test_deeply_nested_hierarchy(self):
        """Test deletion with deeply nested module hierarchy."""
        model = DeeplyNestedModel().to("cuda")
        x = torch.randn(2, 32, device="cuda")

        gm = torch_export_to_gm(model, args=(x,))

        gm_original = copy.deepcopy(gm)
        gm_optimized = copy.deepcopy(gm)

        # Remove one of the deep leaves
        make_submodules_unused(gm_original, ["level1.level2.level3.leaf_a"])
        make_submodules_unused(gm_optimized, ["level1.level2.level3.leaf_a"])

        gm_original.delete_all_unused_submodules()
        delete_all_unused_submodules(gm_optimized)

        original_submodules = get_submodule_names(gm_original)
        optimized_submodules = get_submodule_names(gm_optimized)

        assert original_submodules == optimized_submodules

    def test_deeply_nested_full_branch_removal(self):
        """Test removal of entire deeply nested branch."""
        model = DeeplyNestedModel().to("cuda")
        x = torch.randn(2, 32, device="cuda")

        gm = torch_export_to_gm(model, args=(x,))

        gm_original = copy.deepcopy(gm)
        gm_optimized = copy.deepcopy(gm)

        # Remove entire nested branch
        make_submodules_unused(
            gm_original, ["level1.level2.level3.leaf_a", "level1.level2.level3.leaf_b"]
        )
        make_submodules_unused(
            gm_optimized, ["level1.level2.level3.leaf_a", "level1.level2.level3.leaf_b"]
        )

        gm_original.delete_all_unused_submodules()
        delete_all_unused_submodules(gm_optimized)

        original_submodules = get_submodule_names(gm_original)
        optimized_submodules = get_submodule_names(gm_optimized)

        assert original_submodules == optimized_submodules

    def test_sequential_module_handling(self):
        """Test handling of Sequential modules (call_module marks children used)."""
        model = ModelWithSequential().to("cuda")
        x = torch.randn(2, 32, device="cuda")

        gm = torch_export_to_gm(model, args=(x,))

        gm_original = copy.deepcopy(gm)
        gm_optimized = copy.deepcopy(gm)

        # Remove the sequential module
        make_submodules_unused(gm_original, ["sequential"])
        make_submodules_unused(gm_optimized, ["sequential"])

        gm_original.delete_all_unused_submodules()
        delete_all_unused_submodules(gm_optimized)

        original_submodules = get_submodule_names(gm_original)
        optimized_submodules = get_submodule_names(gm_optimized)

        assert original_submodules == optimized_submodules

    def test_idempotent_deletion(self):
        """Test that running deletion multiple times is idempotent."""
        model = FullyUsedModel().to("cuda")
        x = torch.randn(2, 32, device="cuda")

        gm = torch_export_to_gm(model, args=(x,))

        gm_optimized = copy.deepcopy(gm)
        make_submodules_unused(gm_optimized, ["linear2"])

        # Run deletion twice
        delete_all_unused_submodules(gm_optimized)
        submodules_after_first = get_submodule_names(gm_optimized)

        delete_all_unused_submodules(gm_optimized)
        submodules_after_second = get_submodule_names(gm_optimized)

        assert submodules_after_first == submodules_after_second

    def test_empty_graph_module(self):
        """Test handling of a minimal GraphModule."""

        class MinimalModel(nn.Module):
            def forward(self, x):
                return x

        model = MinimalModel().to("cuda")
        x = torch.randn(2, 32, device="cuda")

        gm = torch_export_to_gm(model, args=(x,))

        gm_original = copy.deepcopy(gm)
        gm_optimized = copy.deepcopy(gm)

        # Should not raise any errors
        gm_original.delete_all_unused_submodules()
        delete_all_unused_submodules(gm_optimized)

        assert get_submodule_names(gm_original) == get_submodule_names(gm_optimized)


# =============================================================================
# Parametrized Tests
# =============================================================================


@pytest.mark.parametrize(
    "targets_to_remove",
    [
        ["linear1"],
        ["linear2"],
        ["linear3"],
        ["linear1", "linear2"],
        ["linear2", "linear3"],
        ["linear1", "linear3"],
        ["linear1", "linear2", "linear3"],
    ],
)
def test_various_removal_combinations(targets_to_remove):
    """Test various combinations of submodule removals."""
    model = FullyUsedModel().to("cuda")
    x = torch.randn(2, 32, device="cuda")

    gm = torch_export_to_gm(model, args=(x,))

    gm_original = copy.deepcopy(gm)
    gm_optimized = copy.deepcopy(gm)

    make_submodules_unused(gm_original, targets_to_remove)
    make_submodules_unused(gm_optimized, targets_to_remove)

    gm_original.delete_all_unused_submodules()
    delete_all_unused_submodules(gm_optimized)

    assert get_submodule_names(gm_original) == get_submodule_names(gm_optimized)