[None][feat] AutoDeploy: Reduce memory consumption in the MoE fusion transform (#10772)

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Taylor Yeonbok Lee 2026-01-23 15:22:54 -08:00 committed by GitHub
parent b560598c79
commit 1fbbb1f3cd
2 changed files with 35 additions and 8 deletions


@@ -12,8 +12,9 @@ from tensorrt_llm._torch.utils import ActivationType
from ...custom_ops.quant import TRTLLM_NVFP4_PACKING_FACTOR, TRTLLM_NVFP4_SCALING_VECTOR_SIZE
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code
from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code, get_attr_by_name
from ...utils.cuda_mem_tracker import cuda_memory_tracker
from ...utils.module import get_submodule_of_param
from ...utils.node_utils import bfs, extract_op_args, identify_regions_between_residuals, is_op
from ..interface import (
BaseTransform,
@@ -1292,6 +1293,33 @@ class MatchBmmMoePattern(BaseTransform):
return gm, info

def remove_original_experts(gm: GraphModule, weight_lists: List[List[Node]]) -> None:
    """Remove original expert submodules after their weights have been stacked.

    This function attempts to free GPU memory by deleting the original expert
    submodules whose weights have been replaced by fused/stacked versions.

    Args:
        gm: The GraphModule containing the expert submodules.
        weight_lists: List of weight node lists (e.g., [w1_list, w2_list, w3_list]).
    """
    # Flatten all weight lists
    weight_lists_flat = [w for weights in weight_lists for w in weights]
    for w in weight_lists_flat:
        w_param = get_attr_by_name(gm, w.target)
        if w_param is not None:
            owner_module, owner_module_path, param_name = get_submodule_of_param(gm, w.target)
            owner_param = get_attr_by_name(owner_module, param_name)
            if owner_param is w_param:
                gm.delete_submodule(owner_module_path)
            else:
                # param w is not owned by owner_module, skip
                continue
        else:
            continue
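
For context, the sketch below spells out the contract that the two repo-internal helpers used in remove_original_experts are assumed to follow. The names come from the imports in this diff; the bodies are illustrative stand-ins, not the actual implementations in ...utils._graph and ...utils.module.

# Illustrative stand-ins (assumed contract) for the helpers used by remove_original_experts.
# Not the real TensorRT-LLM utilities; shown only to clarify what the transform relies on.
import torch.nn as nn


def get_attr_by_name(root: nn.Module, qualified_name: str):
    """Resolve a dotted get_attr target (e.g. 'experts.0.w1') to the attribute, or None if missing."""
    obj = root
    for atom in qualified_name.split("."):
        if not hasattr(obj, atom):
            return None
        obj = getattr(obj, atom)
    return obj


def get_submodule_of_param(root: nn.Module, qualified_name: str):
    """Split 'experts.0.w1' into (owning submodule, 'experts.0', 'w1')."""
    module_path, _, param_name = qualified_name.rpartition(".")
    owner = root.get_submodule(module_path) if module_path else root
    return owner, module_path, param_name

Under this assumed contract, gm.delete_submodule(owner_module_path) drops the module-side reference to the original per-expert weight, so the allocation can be reclaimed once no other references remain.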
def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int:
"""
Stack per-expert FP8 weights and scales by materializing stacked tensors as parameters.
@@ -1515,13 +1543,12 @@ def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "
)
node.replace_all_uses_with(new_node)
input_nodes = node.all_input_nodes
graph.erase_node(node)
# Clean up after processing all nodes
# eliminate_dead_code will remove unused get_attr nodes, then delete_all_unused_submodules
# will remove the parameters/buffers that are no longer referenced
eliminate_dead_code(gm)
delete_all_unused_submodules(gm)
for input_node in input_nodes:
if input_node.op == "get_attr" and len(input_node.users) == 0:
graph.erase_node(input_node)
remove_original_experts(gm, [w1_list, w2_list, w3_list])
return fused_key_counter
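
Erasing graph nodes only removes references inside the FX graph; the original per-expert parameters stay alive on their owning submodules, which is why the explicit cleanup above (eliminate_dead_code, delete_all_unused_submodules, remove_original_experts) is needed before the stacked copies stop doubling memory. A minimal sketch of that behaviour, assuming a CUDA device and a toy traced module rather than the actual MoE graph:

# Minimal sketch (assumption: CUDA available); not taken from this diff, only illustrating
# the memory behaviour the new cleanup relies on.
import gc

import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class TinyExpert(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(1024, 1024, device="cuda"))  # ~4 MiB

    def forward(self, x):
        return x @ self.w1


gm = symbolic_trace(TinyExpert())
before = torch.cuda.memory_allocated()

# Erasing the graph nodes that reference w1 would not free anything by itself:
# the GraphModule still owns the parameter. Dropping the attribute (or the whole
# submodule, as remove_original_experts does) is what lets the allocator reclaim it.
del gm.w1
gc.collect()

after = torch.cuda.memory_allocated()
print(f"freed {before - after} bytes")  # roughly the 4 MiB held by w1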


@@ -299,7 +299,7 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
@pytest.mark.skip("Skipping FP8 test until it is supported")
@pytest.mark.skip_less_device_memory(180000)
@pytest.mark.parametrize("world_size", [4, 8])
@pytest.mark.parametrize("world_size", [1, 4, 8])
def test_fp8(self, world_size):
if get_device_count() < world_size:
pytest.skip("Not enough devices for world size, skipping test")