diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
index 8c32dac0d2..f7dc25e0f0 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
@@ -12,8 +12,9 @@ from tensorrt_llm._torch.utils import ActivationType
 from ...custom_ops.quant import TRTLLM_NVFP4_PACKING_FACTOR, TRTLLM_NVFP4_SCALING_VECTOR_SIZE
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
-from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code
+from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code, get_attr_by_name
 from ...utils.cuda_mem_tracker import cuda_memory_tracker
+from ...utils.module import get_submodule_of_param
 from ...utils.node_utils import bfs, extract_op_args, identify_regions_between_residuals, is_op
 from ..interface import (
     BaseTransform,
@@ -1292,6 +1293,33 @@ class MatchBmmMoePattern(BaseTransform):
         return gm, info
 
 
+def remove_original_experts(gm: GraphModule, weight_lists: List[List[Node]]) -> None:
+    """Remove original expert submodules after weights have been stacked.
+
+    This function attempts to free GPU memory by deleting the original expert
+    submodules whose weights have been replaced by fused/stacked versions.
+
+    Args:
+        gm: The GraphModule containing the expert submodules.
+        weight_lists: List of weight node lists (e.g., [w1_list, w2_list, w3_list]).
+    """
+    # Flatten all weight lists.
+    weight_lists_flat = [w for weights in weight_lists for w in weights]
+
+    for w in weight_lists_flat:
+        w_param = get_attr_by_name(gm, w.target)
+        if w_param is not None:
+            owner_module, owner_module_path, param_name = get_submodule_of_param(gm, w.target)
+            owner_param = get_attr_by_name(owner_module, param_name)
+            if owner_param is w_param:
+                gm.delete_submodule(owner_module_path)
+            else:
+                # param w is not owned by owner_module; skip it
+                continue
+        else:
+            continue
+
+
 def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int:
     """
     Stack per-expert FP8 weights and scales by materializing stacked tensors as parameters.
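A minimal sketch of the ownership check that remove_original_experts performs, using plain torch.nn modules. The helpers resolve_attr and split_param_target below are hypothetical stand-ins for get_attr_by_name and get_submodule_of_param, whose exact semantics are assumed here: resolve a dotted parameter target and locate the submodule that owns it. The ToyMoe/Expert modules and the "experts.0.w1" target are made up for illustration.

import torch
import torch.nn as nn
from functools import reduce


def resolve_attr(root: nn.Module, target: str):
    """Resolve a dotted target like 'experts.0.w1' to the object it names, or None."""
    try:
        return reduce(getattr, target.split("."), root)
    except AttributeError:
        return None


def split_param_target(root: nn.Module, target: str):
    """Return (owner module, owner path, parameter name) for a dotted parameter target."""
    owner_path, _, param_name = target.rpartition(".")
    owner = root.get_submodule(owner_path) if owner_path else root
    return owner, owner_path, param_name


class Expert(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(dim, dim))


class ToyMoe(nn.Module):
    def __init__(self, num_experts: int = 2):
        super().__init__()
        self.experts = nn.ModuleList([Expert() for _ in range(num_experts)])


root = ToyMoe()
target = "experts.0.w1"  # what a get_attr node's .target would look like
param = resolve_attr(root, target)
owner, owner_path, name = split_param_target(root, target)

# Delete the owning expert only if it really holds the parameter the graph refers to;
# if the target aliases a tensor owned by some other module, leave the module alone.
if param is not None and getattr(owner, name, None) is param:
    print(f"would delete submodule {owner_path!r}")  # i.e. gm.delete_submodule(owner_path)

The guard is what keeps the deletion safe: only submodules that actually own a stacked-away weight are removed, so shared or aliased parameters are never dropped by accident.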
@@ -1515,13 +1543,12 @@ def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "
         )
 
         node.replace_all_uses_with(new_node)
+        input_nodes = node.all_input_nodes
         graph.erase_node(node)
-
-    # Clean up after processing all nodes
-    # eliminate_dead_code will remove unused get_attr nodes, then delete_all_unused_submodules
-    # will remove the parameters/buffers that are no longer referenced
-    eliminate_dead_code(gm)
-    delete_all_unused_submodules(gm)
+        for input_node in input_nodes:
+            if input_node.op == "get_attr" and len(input_node.users) == 0:
+                graph.erase_node(input_node)
+    remove_original_experts(gm, [w1_list, w2_list, w3_list])
 
     return fused_key_counter
 
diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
index 4a3e277cb9..7b293abf4c 100644
--- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
+++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
@@ -299,7 +299,7 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
 
     @pytest.mark.skip("Skipping FP8 test until it is supported")
     @pytest.mark.skip_less_device_memory(180000)
-    @pytest.mark.parametrize("world_size", [4, 8])
+    @pytest.mark.parametrize("world_size", [1, 4, 8])
     def test_fp8(self, world_size):
         if get_device_count() < world_size:
             pytest.skip("Not enough devices for world size, skipping test")
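A toy torch.fx example of the cleanup pattern the fused_moe hunk switches to: capture a node's inputs before erasing it, then eagerly erase any get_attr inputs left without users, instead of running a graph-wide eliminate_dead_code / delete_all_unused_submodules pass. The Toy module and the reroute-to-input replacement are assumptions for illustration only; the real pass replaces the per-expert MoE op with a fused op and then calls remove_original_experts to drop the now-unreferenced expert modules.

import torch
import torch.fx as fx


class Toy(torch.nn.Module):
    """Hypothetical stand-in for an expert: one weight consumed by one op."""

    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return x @ self.w


gm = fx.symbolic_trace(Toy())
graph = gm.graph

for node in list(graph.nodes):
    if node.op != "call_function":
        continue
    # Stand-in for swapping the per-expert op with the fused MoE op: here we
    # simply reroute users to the activation input.
    node.replace_all_uses_with(node.all_input_nodes[0])
    input_nodes = node.all_input_nodes  # capture before the node is erased
    graph.erase_node(node)
    # Eagerly drop get_attr inputs that no longer have users, so the stacked
    # copies become the only live graph references to the original weights.
    for input_node in input_nodes:
        if input_node.op == "get_attr" and len(input_node.users) == 0:
            graph.erase_node(input_node)

graph.lint()
gm.recompile()
print(gm.code)  # forward now returns x; the get_attr on 'w' is gone

Erasing the dangling get_attr nodes only removes the graph references; the module itself still holds the parameter until the owning submodule is deleted, which is exactly the job remove_original_experts takes over from delete_all_unused_submodules.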