Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[None][feat] AutoDeploy: Reduce memory consumption for MoE fusion transform (#10772)
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Commit 1fbbb1f3cd (parent b560598c79)
@@ -12,8 +12,9 @@ from tensorrt_llm._torch.utils import ActivationType
 from ...custom_ops.quant import TRTLLM_NVFP4_PACKING_FACTOR, TRTLLM_NVFP4_SCALING_VECTOR_SIZE
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
-from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code
+from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code, get_attr_by_name
 from ...utils.cuda_mem_tracker import cuda_memory_tracker
+from ...utils.module import get_submodule_of_param
 from ...utils.node_utils import bfs, extract_op_args, identify_regions_between_residuals, is_op
 from ..interface import (
     BaseTransform,
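The cleanup below leans on the two newly imported helpers, get_attr_by_name and get_submodule_of_param. Their real implementations live in ...utils._graph and ...utils.module; the following is only a minimal sketch of the contracts this diff appears to assume (a dotted-path attribute lookup that returns None on a miss, and resolution of the submodule that directly owns a parameter):

# Hypothetical sketch of the helper contracts assumed here; the actual
# implementations in ...utils._graph and ...utils.module may differ.
from typing import Any, Optional, Tuple

import torch.nn as nn


def get_attr_by_name(module: nn.Module, target: str) -> Optional[Any]:
    # Resolve a dotted path like "experts.3.w1"; return None if any hop is missing.
    obj: Any = module
    for name in target.split("."):
        if not hasattr(obj, name):
            return None
        obj = getattr(obj, name)
    return obj


def get_submodule_of_param(module: nn.Module, target: str) -> Tuple[nn.Module, str, str]:
    # Split "experts.3.w1" into (owning submodule, "experts.3", "w1").
    path, _, param_name = target.rpartition(".")
    owner = module.get_submodule(path) if path else module
    return owner, path, param_name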
@@ -1292,6 +1293,33 @@ class MatchBmmMoePattern(BaseTransform):
         return gm, info
 
 
+def remove_original_experts(gm: GraphModule, weight_lists: List[List[Node]]) -> None:
+    """Remove original expert submodules after weights have been stacked.
+
+    This function attempts to free GPU memory by deleting the original expert
+    submodules whose weights have been replaced by fused/stacked versions.
+
+    Args:
+        gm: The GraphModule containing the expert submodules
+        weight_lists: List of weight node lists (e.g., [w1_list, w2_list, w3_list])
+    """
+    # Flatten all weight lists
+    weight_lists_flat = [w for weights in weight_lists for w in weights]
+
+    for w in weight_lists_flat:
+        w_param = get_attr_by_name(gm, w.target)
+        if w_param is not None:
+            owner_module, owner_module_path, param_name = get_submodule_of_param(gm, w.target)
+            owner_param = get_attr_by_name(owner_module, param_name)
+            if owner_param is w_param:
+                gm.delete_submodule(owner_module_path)
+            else:
+                # param w is not owned by owner_module, skip
+                continue
+        else:
+            continue
+
+
 def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int:
     """
     Stack per-expert FP8 weights and scales by materializing stacked tensors as parameters.
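To see why deleting the expert submodules matters, here is a minimal toy of the memory effect (my own illustration, not the repo's code): stacking copies every per-expert weight into one fused parameter, so the originals are only reclaimed by the CUDA caching allocator once their last live reference, the expert submodule itself, is dropped.

# Toy illustration of the memory effect; assumes a CUDA device is available.
import torch
import torch.nn as nn

m = nn.Module()
m.experts = nn.ModuleList(nn.Linear(1024, 1024, bias=False) for _ in range(8))
if torch.cuda.is_available():
    m.cuda()
    print("before stacking:", torch.cuda.memory_allocated())
    # Stacking copies the data, briefly doubling the expert-weight footprint.
    m.w_stacked = nn.Parameter(torch.stack([e.weight for e in m.experts]))
    # Deleting the submodules drops the last references to the originals,
    # letting the caching allocator reclaim their storage.
    del m.experts
    torch.cuda.empty_cache()
    print("after cleanup:", torch.cuda.memory_allocated())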
@@ -1515,13 +1543,12 @@ def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int:
         )
 
         node.replace_all_uses_with(new_node)
-        input_nodes = node.all_input_nodes
         graph.erase_node(node)
-
-        for input_node in input_nodes:
-            if input_node.op == "get_attr" and len(input_node.users) == 0:
-                graph.erase_node(input_node)
+    # Clean up after processing all nodes
+    # eliminate_dead_code will remove unused get_attr nodes, then delete_all_unused_submodules
+    # will remove the parameters/buffers that are no longer referenced
     eliminate_dead_code(gm)
     delete_all_unused_submodules(gm)
+    remove_original_experts(gm, [w1_list, w2_list, w3_list])
 
     return fused_key_counter
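The per-node get_attr cleanup loop is thus replaced by one whole-graph pass. The repo routes this through its own wrappers in ...utils._graph, but torch.fx ships equivalent primitives; the following is a self-contained sketch of the same two-step cleanup (the module M and submodule name "lin" are illustrative only):

# Self-contained sketch using torch.fx's built-in cleanup primitives.
import torch.nn as nn
from torch import fx


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)


gm = fx.symbolic_trace(M())
for node in list(gm.graph.nodes):
    if node.op == "call_module" and node.target == "lin":
        # Bypass the module call, leaving "lin" unreferenced in the graph.
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
gm.graph.eliminate_dead_code()     # drops graph nodes with no remaining users
gm.delete_all_unused_submodules()  # then drops submodules the graph no longer touches
gm.recompile()
print([name for name, _ in gm.named_modules()])  # "lin" is gone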
@@ -299,7 +299,7 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
 
     @pytest.mark.skip("Skipping FP8 test until it is supported")
     @pytest.mark.skip_less_device_memory(180000)
-    @pytest.mark.parametrize("world_size", [4, 8])
+    @pytest.mark.parametrize("world_size", [1, 4, 8])
     def test_fp8(self, world_size):
         if get_device_count() < world_size:
             pytest.skip("Not enough devices for world size, skipping test")
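The widened parametrization adds a single-GPU case, while the runtime guard makes under-provisioned machines skip rather than fail. The same pattern in generic form (torch.cuda.device_count stands in for the repo's get_device_count helper, and the assertion is a placeholder):

# Generic form of the gating pattern above; the repo's get_device_count and
# skip_less_device_memory marker are replaced with stock pytest/torch calls.
import pytest
import torch


@pytest.mark.parametrize("world_size", [1, 4, 8])
def test_runs_at_each_world_size(world_size):
    if torch.cuda.device_count() < world_size:
        pytest.skip("Not enough devices for world size, skipping test")
    assert world_size >= 1  # placeholder for the real accuracy check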