Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
This commit is contained in: parent ef1d4a40b5, commit 464847c6be
@@ -102,10 +102,6 @@ def torch_moe(
Unified Mixture-of-Experts (MoE) operator that uses a Mixtral-style dispatch
(token routing + index_add_ accumulation) and a selectable per-expert MLP.

Supports both:
- Standard MoE with per-expert weight lists (apply_routing_on_input=False)
- Llama4 MoE with stacked weight tensors (apply_routing_on_input=True)

Parameters:
x (torch.Tensor): Input tensor of shape (B, H) or (B, S, H), where B is the batch size,
S is the sequence length, and H is the hidden size.
@@ -113,87 +109,31 @@ def torch_moe(
of the selected experts for each token. Only experts within range [0, num_experts) are processed.
routing_weights (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the normalized
routing weights for the selected experts.
- Standard MoE: softmax normalized weights
- Llama4 MoE: sigmoid activated weights
w1_weight:
For per-expert lists:
• is_gated_mlp==True: List of W1 with shape (I, H) — "gate" projection.
• is_gated_mlp==False: List of W_up with shape (I, H) — up projection.
For stacked tensors (Llama4):
• Single-element list containing stacked w3_w1 tensor with shape (E, 2*I, H) in TRT-LLM format
w2_weight:
For per-expert lists:
• List of W2/W_down with shape (H, I) — down projection.
For stacked tensors (Llama4):
• Single-element list containing stacked w2 tensor with shape (E, H, I) in TRT-LLM format
w3_weight:
For per-expert lists with gated_mlp:
• List of W3 with shape (I, H) — "up" (second) projection in gated MLP.
For is_gated_mlp==False or stacked tensors:
• pass an empty list []; ignored.
is_gated_mlp:
Selects the per-expert MLP computation:
• is_gated_mlp==True (default, Mixtral/DeepSeek/Llama4-style):
y = W2( act(W1 x) * (W3 x) )
• is_gated_mlp==False (NemotronH-style 2-layer MLP):
y = W_down( act(W_up x) )
act_fn:
Elementwise activation applied inside the expert MLP.
w1_weight: List of per-expert weight tensors of up projection.
w2_weight: List of per-expert weight tensors of down projection.
w3_weight: List of per-expert weight tensors of gate projection.
is_gated_mlp: If True, use a gated MLP. If False, use a simple MLP.
act_fn: Activation function applied inside the expert MLP.
Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square).
apply_routing_on_input:
If True (Llama4 pattern): multiply routing weights with INPUT before MLP
Result: act(input * routing_weight) - routing affects activation
If False (standard pattern): multiply routing weights with OUTPUT after MLP
Result: act(input) * routing_weight - routing scales output
apply_routing_on_input: If True, multiply routing weights with INPUT before MLP
This means: silu(input * routing_weight)
If False, multiply routing weights with OUTPUT after MLP
This means: silu(input) * routing_weight
Returns:
torch.Tensor: Output tensor with the same shape as the input x.
"""
torch_act_fn = _resolve_torch_fn(act_fn)

# Detect if using stacked tensor format (Llama4) vs per-expert lists (standard)
is_stacked = len(w1_weight) == 1 and w1_weight[0].ndim == 3

# Todo: either change torch_moe to use a single condition, or refactor this code.
# It should be:
# is_gated_mlp:
#     stacked:
#         ...
#     not stacked:
#         ...
# else:
#     assert (not stacked)
#     ...
if is_stacked:
# Llama4 stacked tensor format - only supports gated_mlp
if not is_gated_mlp:
raise ValueError("Stacked tensor format only supports gated MLP style")

w3_w1_stacked = w1_weight[0]  # (E, 2*I, H)
intermediate_size = w3_w1_stacked.shape[1] // 2
w2_stacked = w2_weight[0]  # (E, H, I)

def make_mlp(idx: int):
gate_up = w3_w1_stacked[idx]  # (2*I, H)
W3 = gate_up[:intermediate_size, :]  # (I, H)
W1 = gate_up[intermediate_size:, :]  # (I, H)
W2 = w2_stacked[idx]  # (H, I)
weight_dtype = W1.dtype
return lambda inp: F.linear(
torch_act_fn(F.linear(inp.to(weight_dtype), W1))
* F.linear(inp.to(weight_dtype), W3),
W2,
)

mlps = [make_mlp(idx) for idx in range(w3_w1_stacked.shape[0])]

elif is_gated_mlp:
mlps = []
if is_gated_mlp:
# Standard per-expert list format with gated MLP
def make_mlp(i: int):
W1 = w1_weight[i]  # (I, H)
W2 = w2_weight[i]  # (H, I)
W3 = w3_weight[i]  # (I, H)
return lambda inp: F.linear(torch_act_fn(F.linear(inp, W1)) * F.linear(inp, W3), W2)
return lambda inp: F.linear(
torch_act_fn(F.linear(inp.to(W1.dtype), W1)) * F.linear(inp.to(W3.dtype), W3), W2
)

mlps = [make_mlp(i) for i in range(len(w1_weight))]
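For reference, a minimal standalone sketch of the two per-expert MLP styles the docstring above describes; sizes and names are illustrative only and not part of the library API:

import torch
import torch.nn.functional as F

H, I = 64, 128  # hidden and intermediate sizes (illustrative)
x = torch.randn(4, H)
W1, W3 = torch.randn(I, H), torch.randn(I, H)  # gate / up projections
W2 = torch.randn(H, I)                         # down projection

# is_gated_mlp=True: y = W2( act(W1 x) * (W3 x) )
y_gated = F.linear(F.silu(F.linear(x, W1)) * F.linear(x, W3), W2)

# is_gated_mlp=False: y = W_down( act(W_up x) ), here with ReLU-then-square (Relu2)
W_up, W_down = torch.randn(I, H), torch.randn(H, I)
y_mlp = F.linear(F.relu(F.linear(x, W_up)) ** 2, W_down)

assert y_gated.shape == y_mlp.shape == (4, H)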
@@ -1,4 +1,5 @@
from collections import defaultdict
from functools import partial
from typing import Dict, List, Literal, Optional, Tuple, Type

import torch
@@ -20,6 +21,60 @@ from ..interface import (
)


def _bmm_moe_gate_up_split_hook(
state_dict,
prefix,
*args,
source_key: str,
intermediate_size: int,
w1_keys: List[str],
w3_keys: List[str],
):
"""Hook to split gate_up_weight into all per-expert w1 and w3 weights.

Args:
source_key: Original stacked weight key (e.g., "gate_up_weight")
intermediate_size: Intermediate dimension size
w1_keys: List of target parameter keys for w1 weights
w3_keys: List of target parameter keys for w3 weights
"""
source_full_key = prefix + source_key

if source_full_key in state_dict:
stacked_tensor = state_dict[source_full_key]
# Split on last dim: (E, H, 2I) -> 2x (E, H, I)
w1_stacked, w3_stacked = stacked_tensor.split(intermediate_size, dim=2)
# Transpose and contiguous in batch, then unbind into views
w1_experts = w1_stacked.transpose(1, 2).contiguous().unbind(0)
w3_experts = w3_stacked.transpose(1, 2).contiguous().unbind(0)
for w1_key, w3_key, w1, w3 in zip(w1_keys, w3_keys, w1_experts, w3_experts):
state_dict[prefix + w1_key] = w1
state_dict[prefix + w3_key] = w3


def _bmm_moe_down_split_hook(
state_dict,
prefix,
*args,
source_key: str,
w2_keys: List[str],
):
"""Hook to split down_weight into all per-expert w2 weights.

Args:
source_key: Original stacked weight key (e.g., "down_weight")
w2_keys: List of target parameter keys for w2 weights
"""
source_full_key = prefix + source_key

if source_full_key in state_dict:
stacked_tensor = state_dict[source_full_key]
# Transpose and contiguous in batch, then unbind into views
w2_experts = stacked_tensor.transpose(1, 2).contiguous().unbind(0)
for w2_key, w2 in zip(w2_keys, w2_experts):
state_dict[prefix + w2_key] = w2
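The tensor manipulation these two hooks perform can be checked in isolation; a small sketch with made-up sizes, not tied to any particular checkpoint:

import torch

E, H, I = 4, 64, 32                  # experts, hidden, intermediate (illustrative)
gate_up = torch.randn(E, H, 2 * I)   # Llama4-style stacked gate_up: (E, H, 2I)
down = torch.randn(E, I, H)          # Llama4-style stacked down: (E, I, H)

# Same split/transpose the hooks apply before writing per-expert keys
w1_stacked, w3_stacked = gate_up.split(I, dim=2)                 # 2x (E, H, I)
w1_experts = w1_stacked.transpose(1, 2).contiguous().unbind(0)   # E tensors of (I, H)
w3_experts = w3_stacked.transpose(1, 2).contiguous().unbind(0)   # E tensors of (I, H)
w2_experts = down.transpose(1, 2).contiguous().unbind(0)         # E tensors of (H, I)

assert w1_experts[0].shape == (I, H) and w2_experts[0].shape == (H, I)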
def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int:
"""Replace torch MoE ops with fused backend-specific implementations.

@@ -41,65 +96,29 @@ def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "t
}[backend]

for node in graph.nodes:
if not is_op(node, torch.ops.auto_deploy.torch_moe):
continue
if is_op(node, torch.ops.auto_deploy.torch_moe):
(is_gated_mlp, act_fn) = extract_op_args(node, "is_gated_mlp", "act_fn")

# Detect if this is a stacked MoE (Llama4 pattern) or per-expert list (standard pattern)
(apply_routing_val, w1_weight_list) = extract_op_args(
node, "apply_routing_on_input", "w1_weight"
)

# Check if it's stacked format: single-element list with 3D tensor
is_stacked_moe = False
if apply_routing_val:
# In FX graphs, w1_weight_list might be a Node representing a list() call
list_content = None
if isinstance(w1_weight_list, Node) and w1_weight_list.target is list:
# Extract from list() call node
if w1_weight_list.args:
list_content = w1_weight_list.args[0]
elif isinstance(w1_weight_list, (list, tuple)):
# Direct Python list
list_content = w1_weight_list

# Check if it's a single-element list with a 3D tensor
if list_content is not None and len(list_content) == 1:
w1_node = list_content[0]
if isinstance(w1_node, Node) and w1_node.op == "get_attr":
try:
w1_tensor = gm.get_parameter(w1_node.target)
is_stacked_moe = w1_tensor.ndim == 3
except (AttributeError, KeyError):
pass

(is_gated_mlp, act_fn) = extract_op_args(node, "is_gated_mlp", "act_fn")

if is_stacked_moe:
# Stacked MoE (Llama4 pattern): only supports gated MLP
_process_llama4_stacked_moe_node(
gm, graph, node, replacement_op, act_fn, fused_key_counter
)
else:
# Standard MoE with per-expert weight lists
assert backend != "triton" or not is_gated_mlp, (
"Triton backend only supports mlp style."
)
_process_regular_moe_node(
_process_moe_node(
gm, graph, node, replacement_op, is_gated_mlp, act_fn, fused_key_counter
)

fused_key_counter += 1
fused_key_counter += 1

# Delete the unstacked weights immediately to save GPU memory
# This will happen automatically after the graph is canonicalized, but for large models we'll run out of memory
# during the transformation itself.
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
# Delete the unstacked weights immediately to save GPU memory
# This will happen automatically after the graph is canonicalized,
# but for large models we'll run out of memory during the transformation itself.
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()

return fused_key_counter


def _process_regular_moe_node(
def _process_moe_node(
gm: GraphModule,
graph: torch.fx.Graph,
node: Node,
@@ -113,7 +132,15 @@ def _process_regular_moe_node(
Stacks weight parameters and creates a fused MoE node.
The kernel applies routing weights to the output.
"""
hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = extract_op_args(
(
hidden_states,
selected_experts,
routing_weights,
w1_list,
w2_list,
w3_list,
apply_routing_on_input,
) = extract_op_args(
node,
"x",
"selected_experts",
@@ -121,6 +148,7 @@ def _process_regular_moe_node(
"w1_weight",
"w2_weight",
"w3_weight",
"apply_routing_on_input",
)

# Stack weights based on MLP style
@@ -157,6 +185,27 @@ def _process_regular_moe_node(
with graph.inserting_before(node):
w_up_arg = graph.get_attr(new_key_w_up)
w_down_arg = graph.get_attr(new_key_w_down)
# Get weight dtype for casting - fused kernel requires activation dtype to match weight dtype
weight_dtype = fused_w_up_experts.dtype

if apply_routing_on_input:
# Scale input: hidden_states = hidden_states * routing_weights
hidden_states = graph.call_function(
torch.ops.aten.mul.Tensor,
args=(hidden_states, routing_weights),
)

# Pass ones to kernel to prevent it from multiplying routing again (already applied)
routing_weights = graph.call_function(
torch.ops.aten.ones_like.default,
args=(routing_weights,),
)

# Kernel requires activation dtype to match weight dtype
hidden_states = graph.call_function(
torch.ops.aten.to,
args=(hidden_states, weight_dtype),
)

new_node = graph.call_function(
replacement_op,
@@ -171,145 +220,6 @@ def _process_regular_moe_node(
graph.erase_node(node)
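Why the transform pre-scales the input and then feeds ones to the kernel: a small numeric sketch of the two routing placements, written with plain tensors rather than FX nodes (shapes and values are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(8, 16)
w = torch.rand(8, 1)          # per-token routing weight (illustrative)
W = torch.randn(16, 16)

out_std = F.silu(x @ W.t()) * w     # standard pattern: scale the OUTPUT
out_l4 = F.silu((x * w) @ W.t())    # Llama4 pattern: scale the INPUT
# The two differ because silu is non-linear,
assert not torch.allclose(out_std, out_l4)
# so for the Llama4 pattern the graph pre-multiplies the input and hands the
# kernel all-ones routing weights, preventing routing from being applied twice:
ones = torch.ones_like(w)
assert torch.allclose(F.silu((x * w) @ W.t()) * ones, out_l4)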

def _process_llama4_stacked_moe_node(
gm: GraphModule,
graph: torch.fx.Graph,
node: Node,
replacement_op,
act_fn: ActivationType,
fused_key_counter: int,
) -> None:
"""Process a single Llama4 MoE node with pre-stacked weight tensors.

Only supports gated MLP (SwiGLU-style) architecture.
Converts Llama4 format weights to TRT-LLM format to standardize all downstream ops.
Applies routing weights to INPUT before the fused kernel to prevent double multiplication.
This is the Llama4 pattern where weights are already stacked across experts.
Result: silu(input * routing_weight) - routing affects activation.
"""
# torch_moe with stacked format: weights are in single-element lists
hidden_states, selected_experts, routing_weights, w1_list, w2_list = extract_op_args(
node,
"x",
"selected_experts",
"routing_weights",
"w1_weight",
"w2_weight",
)

# Extract the single stacked tensor from each list
# Handle both FX graph Nodes (list() calls) and direct Python lists
def extract_from_list_arg(list_arg):
if isinstance(list_arg, Node) and list_arg.target is list:
# Extract from list() call node
return list_arg.args[0][0] if list_arg.args else None
elif isinstance(list_arg, (list, tuple)):
# Direct Python list
return list_arg[0]
else:
raise ValueError(f"Unexpected list format: {type(list_arg)}")

w3_w1_stacked = extract_from_list_arg(w1_list)
w2_stacked = extract_from_list_arg(w2_list)

# Convert Llama4 format to TRT-LLM format if needed
# This standardizes all downstream ops to only handle TRT-LLM format
if w3_w1_stacked.op == "get_attr" and w2_stacked.op == "get_attr":
gate_up_weight = gm.get_parameter(w3_w1_stacked.target)
down_weight = gm.get_parameter(w2_stacked.target)

# Detect format:
# - Llama4: gate_up is (E, H, 2*I) and down is (E, I, H)
# - TRT-LLM: gate_up is (E, 2*I, H) and down is (E, H, I)
# If both have H in middle dimension, they're Llama4 format
is_llama4 = gate_up_weight.shape[1] == down_weight.shape[2]

if is_llama4:
# Convert Llama4 (E, H, 2*I) -> TRT-LLM (E, 2*I, H)
gate_up_trtllm = gate_up_weight.transpose(1, 2).contiguous()
# Convert Llama4 (E, I, H) -> TRT-LLM (E, H, I)
down_trtllm = down_weight.transpose(1, 2).contiguous()

# Register converted weights
new_key_w_up = f"llama4_to_trtllm_w3_w1_{fused_key_counter}"
new_key_w_down = f"llama4_to_trtllm_w2_{fused_key_counter}"

gm.register_parameter(new_key_w_up, torch.nn.Parameter(gate_up_trtllm))
gm.register_parameter(new_key_w_down, torch.nn.Parameter(down_trtllm))

# Store keys to create get_attr nodes later in insertion context
needs_get_attr = True
w_up_key = new_key_w_up
w_down_key = new_key_w_down
else:
# Already TRT-LLM format, use directly
needs_get_attr = False
w_up_arg = w3_w1_stacked
w_down_arg = w2_stacked
else:
# Not get_attr nodes (might be intermediate ops), use directly
needs_get_attr = False
w_up_arg = w3_w1_stacked
w_down_arg = w2_stacked

# Llama4 INPUT-SIDE routing: apply routing to INPUT before kernel
# Cast BOTH input and routing_weights to weight dtype if needed
# Critical: BFloat16 * Float32 → Float32 (type promotion) so we cast both to same dtype
with graph.inserting_before(node):
# Create get_attr nodes INSIDE insertion context for proper topological ordering
if needs_get_attr:
w_up_arg = graph.get_attr(w_up_key)
w_down_arg = graph.get_attr(w_down_key)

# Get weight dtype to ensure dtype consistency for Llama4 stacked tensors
# The fused kernel requires input and weights to have matching dtypes
weight_dtype = None
if w_up_arg.op == "get_attr":
try:
weight_tensor = gm.get_parameter(w_up_arg.target)
weight_dtype = weight_tensor.dtype
except (AttributeError, KeyError):
pass
input_to_scale = hidden_states
routing_to_scale = routing_weights

if weight_dtype is not None and weight_dtype != torch.float32:
input_to_scale = graph.call_function(
torch.ops.aten._to_copy.default,
args=(hidden_states,),
kwargs={"dtype": weight_dtype},
)
routing_to_scale = graph.call_function(
torch.ops.aten._to_copy.default,
args=(routing_weights,),
kwargs={"dtype": weight_dtype},
)

# Scale input: hidden_states = hidden_states * routing_weights (both same dtype now)
scaled_input = graph.call_function(
torch.ops.aten.mul.Tensor,
args=(input_to_scale, routing_to_scale),
)

# Pass ones to kernel to prevent it from multiplying routing again (already applied)
ones_node = graph.call_function(
torch.ops.aten.ones_like.default,
args=(routing_weights,),
)

new_node = graph.call_function(
replacement_op,
args=(scaled_input, selected_experts, ones_node, w_up_arg, w_down_arg),
kwargs={
"act_fn": act_fn,
"is_gated_mlp": True,
},
)

node.replace_all_uses_with(new_node)
graph.erase_node(node)
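The Llama4-to-TRT-LLM layout conversion and the shape-based detection used above can be reproduced on dummy tensors; a minimal sketch with illustrative sizes:

import torch

E, H, I = 4, 64, 32                     # illustrative sizes
gate_up_l4 = torch.randn(E, H, 2 * I)   # Llama4 layout: (E, H, 2I)
down_l4 = torch.randn(E, I, H)          # Llama4 layout: (E, I, H)

# Detection heuristic: both tensors carry H in "opposite" positions,
# so Llama4 layout is flagged when gate_up.shape[1] == down.shape[2].
assert gate_up_l4.shape[1] == down_l4.shape[2]

gate_up_trtllm = gate_up_l4.transpose(1, 2).contiguous()  # -> (E, 2I, H)
down_trtllm = down_l4.transpose(1, 2).contiguous()        # -> (E, H, I)
assert gate_up_trtllm.shape == (E, 2 * I, H) and down_trtllm.shape == (E, H, I)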

def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]:
"""
Find the lowest common ancestor for a list of nodes in a torch.fx Graph by following
@@ -1068,6 +978,31 @@ class MatchBmmMoePattern(BaseTransform):
continue
first_bmm, gate_up_weight = result

# Get shapes from node metadata
if not hasattr(gate_up_weight, "meta") or "val" not in gate_up_weight.meta:
continue
if not hasattr(down_weight, "meta") or "val" not in down_weight.meta:
continue
gate_up_shape = gate_up_weight.meta["val"].shape
down_shape = down_weight.meta["val"].shape

# Only support llama4 shaped weights for now
if len(gate_up_shape) != len(down_shape) or len(gate_up_shape) != 3:
continue

# Llama4 expectation:
# num_experts = gate_up_shape[0] == down_shape[0]
# hidden_size = gate_up_shape[1] == down_shape[2]
# gate_up_shape[2] == 2 * down_shape[1] (intermediate_size)
if gate_up_shape[0] != down_shape[0]:
continue

if gate_up_shape[2] != 2 * down_shape[1]:
continue

if gate_up_shape[1] != down_shape[2]:
continue

# Step 3: Get batched input and trace back to original input and routing
batched_input = first_bmm.args[0]
if not isinstance(batched_input, Node) or not is_op(batched_input, torch.ops.aten.view):
@@ -1212,31 +1147,119 @@ class MatchBmmMoePattern(BaseTransform):
# If input_routing is False: kernel applies routing to output
apply_routing_on_input = input_routing

# Wrap stacked tensors in single-element lists for torch_moe unified interface
with graph.inserting_before(output_node):
# Create list nodes for stacked weights
w1_list_node = graph.call_function(
list,
args=([gate_up_weight],),
)
w2_list_node = graph.call_function(
list,
args=([down_weight],),
)
w3_list_node = graph.call_function(
list,
args=([],),  # Empty list for stacked gated MLP
# Materialize stacked tensors into per-expert parameters for torch_moe

# Get the actual tensors from the graph nodes
if gate_up_weight.op != "get_attr" or down_weight.op != "get_attr":
raise RuntimeError(
f"Expected get_attr nodes for BMM MoE weights, got {gate_up_weight.op} and {down_weight.op}"
)

gate_up_tensor = gm.get_parameter(gate_up_weight.target)
down_tensor = gm.get_parameter(down_weight.target)

# Support only llama4 shaped weights for now
if gate_up_tensor.shape[2] != 2 * down_tensor.shape[1]:
raise RuntimeError(
f"Expected gate_up_tensor.shape[2] == 2 * down_tensor.shape[1],"
f"got {gate_up_tensor.shape[2]} and {down_tensor.shape[1]}"
)

# Get dimensions
assert len(gate_up_tensor.shape) == 3, (
f"Expected gate_up_tensor.shape to have 3 dimensions, got {len(gate_up_tensor.shape)}"
)
assert len(down_tensor.shape) == 3, (
f"Expected down_tensor.shape to have 3 dimensions, got {len(down_tensor.shape)}"
)
num_experts = gate_up_tensor.shape[0]
assert num_experts == down_tensor.shape[0], (
f"Expected num_experts == down_tensor.shape[0],"
f"got {num_experts} and {down_tensor.shape[0]}"
)
hidden_size = gate_up_tensor.shape[1]
assert hidden_size == down_tensor.shape[2], (
f"Expected hidden_size == down_tensor.shape[2],"
f"got {hidden_size} and {down_tensor.shape[2]}"
)
intermediate_size = gate_up_tensor.shape[2] // 2
assert intermediate_size == down_tensor.shape[1], (
f"Expected intermediate_size == down_tensor.shape[1],"
f"got {intermediate_size} and {down_tensor.shape[1]}"
)

# Store checkpoint keys for hooks
gate_up_checkpoint_key = str(gate_up_weight.target)
down_checkpoint_key = str(down_weight.target)

# Split each stacked tensor into per-expert tensors and register as parameters
# This creates get_attr nodes that sharding expects
w1_keys = []
w2_keys = []
w3_keys = []

for expert_idx in range(num_experts):
# Register each expert's weight as a separate parameter
w1_key = f"bmm_moe_w1_expert_{num_moe_patterns}_{expert_idx}"
w2_key = f"bmm_moe_w2_expert_{num_moe_patterns}_{expert_idx}"
w3_key = f"bmm_moe_w3_expert_{num_moe_patterns}_{expert_idx}"

w1_keys.append(w1_key)
w2_keys.append(w2_key)
w3_keys.append(w3_key)

w1_param = torch.nn.Parameter(
gate_up_tensor[expert_idx, :, :intermediate_size].transpose(0, 1)
)
w2_param = torch.nn.Parameter(down_tensor[expert_idx].transpose(0, 1))
w3_param = torch.nn.Parameter(
gate_up_tensor[expert_idx, :, intermediate_size:].transpose(0, 1)
)

gm.register_parameter(w1_key, w1_param)
gm.register_parameter(w2_key, w2_param)
gm.register_parameter(w3_key, w3_param)

# Register checkpoint loading hooks - ONE per stacked weight
# Hook for gate_up_weight: splits into all w1 and w3 expert weights
gm._register_load_state_dict_pre_hook(
partial(
_bmm_moe_gate_up_split_hook,
source_key=gate_up_checkpoint_key,
intermediate_size=intermediate_size,
w1_keys=w1_keys,
w3_keys=w3_keys,
)
)

# Hook for down_weight: splits into all w2 expert weights
gm._register_load_state_dict_pre_hook(
partial(
_bmm_moe_down_split_hook,
source_key=down_checkpoint_key,
w2_keys=w2_keys,
)
)

# Now create get_attr nodes for each expert weight
# These must be created within the insertion context for proper graph ordering
insertion_point = graph.find_nodes(op="get_attr")[0]
with graph.inserting_before(insertion_point):
w1_nodes = [graph.get_attr(key) for key in w1_keys]
w2_nodes = [graph.get_attr(key) for key in w2_keys]
w3_nodes = [graph.get_attr(key) for key in w3_keys]

with graph.inserting_before(output_node):
fused_moe_node = graph.call_function(
torch.ops.auto_deploy.torch_moe,
args=(
input_hidden_states,
selected_experts,
routing_weights_node,
w1_list_node,
w2_list_node,
w3_list_node,
w1_nodes,
w2_nodes,
w3_nodes,
),
kwargs={
"is_gated_mlp": True,
@@ -266,6 +266,14 @@ def _execute_op_in_aux_stream(
if input_node.target == torch.ops.aten.view.default:
target_input_node = input_node
break
# Look through dtype cast nodes (aten.to) to find the view node
if input_node.target == torch.ops.aten.to:
for nested_input in input_node.all_input_nodes:
if nested_input.target == torch.ops.aten.view.default:
target_input_node = nested_input
break
if target_input_node is not None:
break

assert target_input_node is not None, f"Target input node not found for node {n}"
with graph.inserting_before(target_input_node):

@@ -594,8 +594,6 @@ class BMMShardingInfo(ShardingTransformInfo):
class EPShardingInfo(ShardingTransformInfo):
"""Configuration for EP sharding transformations."""

mlp_type: MLPType

@classmethod
def from_node(cls, node: Node, **kwargs) -> "EPShardingInfo":
"""
@@ -613,7 +611,7 @@ class EPShardingInfo(ShardingTransformInfo):

def apply(self, gm: GraphModule, node: Node) -> None:
"""Apply EP sharding transformation to the graph module."""
_insert_sharded_moe(gm, node, self.config, mlp_type=self.mlp_type)
_insert_sharded_moe(gm, node, self.config)


class MXFP4EPShardingInfo(EPShardingInfo):
@@ -1065,71 +1063,6 @@ def _resolve_tp_cls_from_node(node: Node):
return WeightShardingInfo


def _transform_bmm_moe_weight_param(
gm: GraphModule,
param_node: Node,
lo: int,
hi: int,
swap_gate_up: bool = False,
) -> None:
"""Transform a parameter for BMM MoE: slice experts, optionally swap gate/up, transpose.
This modifies the parameter in-place and registers a load hook.
Does NOT create graph nodes - those should be created separately by the caller.
Args:
gm: Graph module
param_node: The get_attr node for the parameter
lo: Start index for expert slicing
hi: End index for expert slicing
swap_gate_up: If True, swap W1 and W3 (Llama4 -> TRT-LLM format)
"""
if param_node.op != "get_attr":
return  # Only works on parameters

param_key = str(param_node.target)
modname, _, param_name = param_key.rpartition(".")
submod = gm.get_submodule(modname) if modname else gm
full_param = getattr(submod, param_name)

# Slice the parameter along expert dimension (dim 0)
sliced_param = full_param[lo:hi].detach().clone()

# Swap W1 and W3 if needed (for gate_up weights)
# Llama4: (E, H, 2*I) with [W1, W3], TRT-LLM wants [W3, W1]
if swap_gate_up and sliced_param.ndim == 3:
intermediate_size = sliced_param.shape[2] // 2
w1 = sliced_param[:, :, :intermediate_size]
w3 = sliced_param[:, :, intermediate_size:]
sliced_param = torch.cat([w3, w1], dim=2)

# Transpose: Llama4 (E, H, X) -> TRT-LLM (E, X, H)
transposed_param = sliced_param.transpose(1, 2)
transposed_shape = transposed_param.shape

# Define transformation function for load hook
def transform_tensor(t: torch.Tensor) -> torch.Tensor:
t_sliced = t[lo:hi]
if swap_gate_up and t_sliced.ndim == 3:
intermediate_size = t_sliced.shape[2] // 2
w1 = t_sliced[:, :, :intermediate_size]
w3 = t_sliced[:, :, intermediate_size:]
t_sliced = torch.cat([w3, w1], dim=2)
return t_sliced.transpose(1, 2).contiguous()

# Register load hook
gm._register_load_state_dict_pre_hook(
partial(
_load_hook,
f_split=transform_tensor,
param_key=param_key,
param_shape=transposed_shape,
)
)

# Replace the parameter with the transformed version
new_param = nn.Parameter(transposed_param, requires_grad=False)
setattr(submod, param_name, new_param)
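The tensor-level effect of the transform above, sketched on a dummy stacked gate_up weight (illustrative sizes; the real code additionally registers a load hook and swaps the parameter in place):

import torch

E, H, I = 8, 64, 32                 # illustrative sizes
full = torch.randn(E, H, 2 * I)     # Llama4 gate_up: [W1 | W3] on the last dim
lo, hi = 2, 6                       # expert slice assigned to this rank

sliced = full[lo:hi].detach().clone()         # (4, H, 2I)
w1, w3 = sliced[:, :, :I], sliced[:, :, I:]
swapped = torch.cat([w3, w1], dim=2)          # TRT-LLM expects [W3 | W1]
local = swapped.transpose(1, 2)               # (4, 2I, H)
assert local.shape == (hi - lo, 2 * I, H)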

def _get_dim0_from_arg(gm: GraphModule, arg: Union[Node, torch.Tensor]) -> int:
"""Helper to get the first dimension size of an argument (Node or Tensor)."""
if isinstance(arg, torch.Tensor):
@@ -1195,26 +1128,6 @@ def init_process_grid_from_config(
return process_grid


def _canonicalize_node_args(node: Node) -> list:
"""
Canonicalize the node's arguments.
Actions performed:
- Flatten list arguments
"""
new_args = list(node.args)
for i in range(len(new_args)):
# In FX graphs, the list might be a Node representing a list() call
if isinstance(new_args[i], Node):
# Check if this is a list() call node
if new_args[i].target is list and len(new_args[i].args) == 1:
new_args[i] = new_args[i].args[0]
if isinstance(new_args[i], (list, tuple)):
if len(new_args[i]) == 1:
new_args[i] = new_args[i][0]

return new_args


########################################################
# Sharding transform functions
########################################################
@@ -1422,121 +1335,16 @@ def _update_node_args(node: Node, args: tuple) -> None:
)


def _insert_sharded_moe_stacked(
gm: GraphModule,
node: Node,
rank: int,
world_size: int,
allreduce_strategy: AllReduceStrategy,
scale_names: Sequence[str] = (),
):
"""Update the torch_moe node with sliced stacked weight tensors,
sharded `selected_experts` and `final_scales(router_logics)`.
Add an all_reduce node after the moe node.

For torch_moe with stacked tensor format (single-element lists containing 3D tensors).

NOTE: allreduce_strategy is MANDATORY and must be explicitly provided.
"""
if allreduce_strategy is None:
raise ValueError(f"allreduce_strategy must be set for MoE sharding on node {node.name}")

# Extract the stacked tensors from single-element lists
# args[3] = w1_weight (Node representing list with one 3D tensor, or direct list)
# args[4] = w2_weight (Node representing list with one 3D tensor, or direct list)

# Helper to extract tensor node from list (handles both Node and direct list)
def extract_tensor_from_list_arg(list_arg):
if isinstance(list_arg, Node) and list_arg.target is list:
# It's a list() call node - extract from its args
return list_arg.args[0][0]  # args[0] is the list content, [0] is first element
elif isinstance(list_arg, (list, tuple)):
# Direct list
return list_arg[0]
else:
raise ValueError(f"Unexpected list format: {type(list_arg)}")

w3_w1_tensor_node = extract_tensor_from_list_arg(node.args[3])
w2_tensor_node = extract_tensor_from_list_arg(node.args[4])
num_experts = _get_dim0_from_arg(gm, w3_w1_tensor_node)

args = list(node.args)

# -- Handle selected_experts and final_scales sharding --
selected_experts = args[1]
final_scales = args[2]

experts_per_rank = num_experts // world_size

with gm.graph.inserting_before(node):
lower = experts_per_rank * rank
# selected_experts_local = selected_experts - low
selected_experts_local = gm.graph.create_node(
"call_function", operator.sub, args=(selected_experts, lower), kwargs={}
)

# For num_experts % world_size != 0 case,
# assign the last (num_experts % world_size) experts to the last rank
div_node = gm.graph.create_node(
"call_function", operator.floordiv, args=(selected_experts, experts_per_rank), kwargs={}
)

comp_op = torch.ge if rank == world_size - 1 else torch.eq
rank_mask = gm.graph.create_node("call_function", comp_op, args=(div_node, rank), kwargs={})

# final_scales_local = final_scales * rank_mask
final_scales_local = gm.graph.create_node(
"call_function", operator.mul, args=(final_scales, rank_mask), kwargs={}
)

# -- Transform expert weight parameters --
local_lo, local_hi = _split_range_last_remainder(num_experts, world_size, rank)

# Transform w3_w1_stacked: slice experts, swap [W1,W3]->[W3,W1], transpose (E,H,2I)->(E,2I,H)
if isinstance(w3_w1_tensor_node, Node):
_transform_bmm_moe_weight_param(
gm, w3_w1_tensor_node, local_lo, local_hi, swap_gate_up=True
)

# Transform w2_stacked: slice experts, transpose (E,I,H)->(E,H,I)
if isinstance(w2_tensor_node, Node):
_transform_bmm_moe_weight_param(gm, w2_tensor_node, local_lo, local_hi, swap_gate_up=False)

# -- Update args (keep same lists/nodes, just with transformed parameters) --
args[1] = selected_experts_local
args[2] = final_scales_local
# args[3] and args[4] stay the same - we modified the parameters in-place

ad_logger.debug(
f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}."
)

node.args = tuple(args)

# -- add an all_reduce node --
with gm.graph.inserting_after(node):
dist_node = gm.graph.call_function(
torch.ops.auto_deploy.torch_dist_all_reduce.default,
args=(node, allreduce_strategy),
)
node.replace_all_uses_with(dist_node)
dist_node.replace_input_with(dist_node, node)
def _insert_sharded_moe(
gm: GraphModule,
node: Node,
config: ShardingTransformConfig,
mlp_type: MLPType,
scale_names: Sequence[str] = (),
):
"""Update the torch_moe node with sharded weight lists or stacked tensors,
"""Update the torch_moe node with sharded weight lists,
sharded `selected_experts` and `final_scales(router_logics)`.
Add an all_reduce node after the moe node.

Handles both:
- Standard format: per-expert weight lists
- Stacked format: single-element lists containing stacked 3D tensors (Llama4 pattern)

NOTE: allreduce_strategy is MANDATORY.
"""
@@ -1551,23 +1359,10 @@ def _insert_sharded_moe(
raise ValueError(f"allreduce_strategy must be set for MoE sharding on node {node.name}")
scale_names = list(scale_names)

flat_args = _canonicalize_node_args(node)
# We have two variants of MoE: stacked and listed:
# - stacked: w1, w2, w3 weight args are order-3 tensors, where the 1st dimension corresponds
#   to the stacked expert weights.
# - listed: w1, w2, w3 weight args are lists of order-2 tensors, where each expert weight
#   is a separate entry in the list.
if isinstance(flat_args[3], Node):
is_stacked = True
num_experts = shape(flat_args[3])[0]
else:
is_stacked = False
num_experts = len(flat_args[3])
args = list(node.args)

# -- Handle selected_experts and final_scales sharding --
selected_experts = args[1]
final_scales = args[2]
num_experts = len(args[3])

experts_per_rank = num_experts // ep_size

@@ -1600,72 +1395,61 @@ def _insert_sharded_moe(
args[1] = selected_experts_local
args[2] = final_scales_local

if is_stacked:
# bmm-style stacked MoE: sharding is done by slicing the 1st dimension of the stacked weight tensor
# if mlp_type == MLPType.FUSED_GATED_MLP:
w_gate_up_stacked = flat_args[3]
w_down_stacked = flat_args[4]
local_lo, local_hi = _split_range_last_remainder(num_experts, ep_size, ep_rank)
_transform_bmm_moe_weight_param(
gm, w_gate_up_stacked, local_lo, local_hi, swap_gate_up=True
# -- Shard expert weights --
def get_partition(lst, world_size, rank):
num_experts = len(lst)
expert_size_per_partition = num_experts // world_size
expert_start = rank * expert_size_per_partition
# For num_experts % world_size != 0 case,
# assign the last (num_experts % world_size) experts to the last rank
expert_end = (
num_experts if (rank == world_size - 1) else expert_start + expert_size_per_partition
)
_transform_bmm_moe_weight_param(gm, w_down_stacked, local_lo, local_hi, swap_gate_up=False)
else:
# listed MoE: sharding is done by taking a range of the listed weight tensors

# -- Shard expert weights --
def get_partition(lst, world_size, rank):
num_experts = len(lst)
expert_size_per_partition = num_experts // world_size
expert_start = rank * expert_size_per_partition
# For num_experts % world_size != 0 case,
# assign the last (num_experts % world_size) experts to the last rank
expert_end = (
num_experts
if (rank == world_size - 1)
else expert_start + expert_size_per_partition
)
return lst[expert_start:expert_end]
return lst[expert_start:expert_end], lst[:expert_start] + lst[expert_end:]

w_up_list_sharded = get_partition(args[3], ep_size, ep_rank)
w_down_list_sharded = get_partition(args[4], ep_size, ep_rank)
w_gate_list_sharded = get_partition(args[5], ep_size, ep_rank)
w_up_list_sharded, w_up_list_to_remove = get_partition(args[3], ep_size, ep_rank)
w_down_list_sharded, w_down_list_to_remove = get_partition(args[4], ep_size, ep_rank)
w_gate_list_sharded, w_gate_list_to_remove = get_partition(args[5], ep_size, ep_rank)

# if tp_size > 1, we do 2D EP+TP sharding.
# we add TP sharding of all expert weights.
for w_up in w_up_list_sharded + w_gate_list_sharded:
shard_weight_tensor(
gm=gm,
weight_tensor=gm.get_parameter(w_up.target),
param_key=w_up.target,
dim=SplitDimension.COLUMN,
rank=tp_rank,
world_size=tp_size,
)
# here we don't need to add all-reduce: it's enough to have
# just one all-reduce after the whole EP+TP sharded MoE node.
for w_down in w_down_list_sharded:
shard_weight_tensor(
gm=gm,
weight_tensor=gm.get_parameter(w_down.target),
param_key=w_down.target,
dim=SplitDimension.ROW,
rank=tp_rank,
world_size=tp_size,
)

# -- Update args --
args[3] = w_up_list_sharded
args[4] = w_down_list_sharded
args[5] = w_gate_list_sharded

# Shard scales for quantized ops
for i in range(len(scale_names) * 3):  # 3 layers (w1, w2, w3) × #scale_names per layer
args[6 + i] = get_partition(args[6 + i], ep_size, ep_rank)

ad_logger.debug(
f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}."
# if tp_size > 1, we do 2D EP+TP sharding.
# we add TP sharding of all expert weights.
for w_up in w_up_list_sharded + w_gate_list_sharded:
shard_weight_tensor(
gm=gm,
weight_tensor=gm.get_parameter(w_up.target),
param_key=w_up.target,
dim=SplitDimension.COLUMN,
rank=tp_rank,
world_size=tp_size,
)
# here we don't need to add all-reduce: it's enough to have
# just one all-reduce after the whole EP+TP sharded MoE node.
for w_down in w_down_list_sharded:
shard_weight_tensor(
gm=gm,
weight_tensor=gm.get_parameter(w_down.target),
param_key=w_down.target,
dim=SplitDimension.ROW,
rank=tp_rank,
world_size=tp_size,
)

# -- Update args --
args[3] = w_up_list_sharded
args[4] = w_down_list_sharded
args[5] = w_gate_list_sharded

# Shard scales for quantized ops
scales_to_remove = []
for i in range(len(scale_names) * 3):  # 3 layers (w1, w2, w3) × #scale_names per layer
sharded, to_remove = get_partition(args[6 + i], ep_size, ep_rank)
args[6 + i] = sharded
scales_to_remove.extend(to_remove)

ad_logger.debug(
f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}."
)
node.args = tuple(args)

# -- add an all_reduce node --
@@ -1676,6 +1460,15 @@ def _insert_sharded_moe(
node.replace_all_uses_with(dist_node)
dist_node.replace_input_with(dist_node, node)

gm.graph.eliminate_dead_code()
# Expert weights registered via gm.register_parameter() are top-level attributes.
# Unlike submodules, these aren't cleaned up by eliminate_dead_code() or
# delete_all_unused_submodules() - must delete manually after removing their get_attr nodes.
for expert in (
w_up_list_to_remove + w_down_list_to_remove + w_gate_list_to_remove + scales_to_remove
):
delattr(gm, expert.target)
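The expert-id remapping and score masking that EP sharding builds into the graph (the sub/floordiv/compare/mul nodes above), written out on plain tensors with illustrative values:

import torch

num_experts, world_size, rank = 8, 2, 1
experts_per_rank = num_experts // world_size        # 4
selected_experts = torch.tensor([[0, 5], [7, 2]])
final_scales = torch.tensor([[0.6, 0.4], [0.9, 0.1]])

lower = experts_per_rank * rank                      # 4
selected_local = selected_experts - lower            # local expert ids on this rank
owner = selected_experts // experts_per_rank
# Last rank also takes the remainder experts, hence >= instead of ==
mask = (owner >= rank) if rank == world_size - 1 else (owner == rank)
scales_local = final_scales * mask                   # zero out experts owned elsewhere
# Out-of-range local ids are harmless: their scale is zero and torch_moe only
# processes experts within [0, num_experts).
print(selected_local, scales_local)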
def _slice_expert_dim(gm: GraphModule, tensor_node: Node, lo: int, hi: int) -> Node:
"""Return tensor_node[lo:hi, ...] via aten.slice along dim 0."""
@@ -2652,21 +2445,7 @@ def detect_ep_shard(
for node in list(gm.graph.nodes):
if not is_any_moe_op(node):
continue
args = _canonicalize_node_args(node)
if isinstance(args[3], Node):
mlp_type = MLPType.FUSED_GATED_MLP
else:
if len(args[5]) > 0:
mlp_type = MLPType.GATED_MLP
else:
mlp_type = MLPType.MLP
if transform_container.add(
EPShardingInfo.from_node(
node,
config=config,
mlp_type=mlp_type,
)
):
if transform_container.add(EPShardingInfo.from_node(node, config=config)):
num_moe_patterns += 1

ad_logger.info(f"Found {num_moe_patterns} MoE patterns")
@@ -1,13 +1,12 @@
import pytest
import torch
import torch.nn.functional as F
from _torch.helpers import reference_bmm_moe_torch, reference_moe_torch
from _torch.helpers import reference_moe_torch
from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available

import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401
from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp4_global_scale
from tensorrt_llm._torch.modules.fused_moe import MoE  # noqa: F401
from tensorrt_llm._torch.utils import ActivationType


def setup_moe_test(dtype, num_experts):
@@ -152,62 +151,6 @@ def test_moe_op_run(dtype):
torch.testing.assert_close(output_torch_moe, ref_output, rtol=1e-5, atol=1e-5)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_bmm_based_moe_op_run(dtype):
num_experts = 3
(
x,
selected_experts,
final_scales,
fused_w3_w1_stacked_weight,
fused_w2_weight,
) = setup_bmm_moe_test(dtype, num_experts)

with torch.inference_mode():
x = final_scales * x
selected_experts = torch.ones_like(selected_experts)
# Use torch_moe with stacked tensor format (single-element lists)
output_torch_moe = torch.ops.auto_deploy.torch_moe(
x,
selected_experts,
final_scales,
[fused_w3_w1_stacked_weight],  # Wrap in list for unified interface
[fused_w2_weight],  # Wrap in list for unified interface
[],  # Empty w3_weight list for stacked gated MLP
is_gated_mlp=True,
act_fn=ActivationType.Silu,
apply_routing_on_input=True,
)
output_torch_fused_moe = torch.ops.auto_deploy.torch_moe_fused(
x,
selected_experts,
final_scales,
fused_w3_w1_stacked_weight,
fused_w2_weight,
)
output_trt_fused_moe = torch.ops.auto_deploy.trtllm_moe_fused(
x,
selected_experts,
final_scales,
fused_w3_w1_stacked_weight,
fused_w2_weight,
)
ref_output = reference_bmm_moe_torch(
x,
selected_experts,
final_scales,
fused_w3_w1_stacked_weight,
fused_w2_weight,
apply_routing_on_input=True,
)

torch.cuda.synchronize()
torch.testing.assert_close(output_trt_fused_moe, output_torch_fused_moe, rtol=5e-2, atol=5e-2)
torch.testing.assert_close(output_trt_fused_moe, ref_output, rtol=5e-2, atol=5e-2)
torch.testing.assert_close(output_torch_fused_moe, ref_output, rtol=1e-5, atol=1e-5)
torch.testing.assert_close(output_torch_moe, ref_output, rtol=1e-5, atol=1e-5)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(not fp8_compatible(), reason="Requires fp8 support")
def test_fp8_moe_op_run(dtype):
@@ -0,0 +1,249 @@
"""Tests for BMM MoE checkpoint loading hooks."""

import pytest
import torch

from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe import (
_bmm_moe_down_split_hook,
_bmm_moe_gate_up_split_hook,
)


@pytest.fixture
def gate_up_stacked_weight():
"""Fixture for stacked gate_up weight in Llama4 format (E, H, 2*I)."""
num_experts = 4
hidden_size = 64
intermediate_size = 32
return torch.randn(num_experts, hidden_size, intermediate_size * 2)


@pytest.fixture
def down_stacked_weight():
"""Fixture for stacked down weight in Llama4 format (E, I, H)."""
num_experts = 4
hidden_size = 64
intermediate_size = 32
return torch.randn(num_experts, intermediate_size, hidden_size)


class TestBmmMoeGateUpSplitHook:
"""Tests for _bmm_moe_gate_up_split_hook."""

@pytest.mark.parametrize(
"num_experts,hidden_size,intermediate_size",
[
(4, 64, 32),
(8, 128, 64),
(2, 32, 16),
],
)
def test_splits_stacked_weights_into_per_expert_w1_w3(
self, num_experts, hidden_size, intermediate_size
):
"""Verify gate_up hook splits stacked weights into w1/w3 per expert."""
# Llama4 format: (E, H, 2*I)
stacked = torch.randn(num_experts, hidden_size, intermediate_size * 2)
state_dict = {"gate_up_weight": stacked}
w1_keys = [f"w1_{i}" for i in range(num_experts)]
w3_keys = [f"w3_{i}" for i in range(num_experts)]

_bmm_moe_gate_up_split_hook(
state_dict,
"",
source_key="gate_up_weight",
intermediate_size=intermediate_size,
w1_keys=w1_keys,
w3_keys=w3_keys,
)

for i in range(num_experts):
assert w1_keys[i] in state_dict
assert w3_keys[i] in state_dict
# After transpose: (I, H)
assert state_dict[w1_keys[i]].shape == (intermediate_size, hidden_size)
assert state_dict[w3_keys[i]].shape == (intermediate_size, hidden_size)

def test_w1_w3_content_matches_original_stacked(self):
"""Verify split w1/w3 tensors match the original stacked content."""
num_experts = 2
hidden_size = 32
intermediate_size = 16

stacked = torch.randn(num_experts, hidden_size, intermediate_size * 2)
state_dict = {"gate_up_weight": stacked.clone()}
w1_keys = [f"w1_{i}" for i in range(num_experts)]
w3_keys = [f"w3_{i}" for i in range(num_experts)]

_bmm_moe_gate_up_split_hook(
state_dict,
"",
source_key="gate_up_weight",
intermediate_size=intermediate_size,
w1_keys=w1_keys,
w3_keys=w3_keys,
)

for i in range(num_experts):
# w1 is first half: stacked[i, :, :intermediate_size].T
expected_w1 = stacked[i, :, :intermediate_size].transpose(0, 1).contiguous()
# w3 is second half: stacked[i, :, intermediate_size:].T
expected_w3 = stacked[i, :, intermediate_size:].transpose(0, 1).contiguous()

torch.testing.assert_close(state_dict[w1_keys[i]], expected_w1)
torch.testing.assert_close(state_dict[w3_keys[i]], expected_w3)

def test_handles_missing_source_key(self):
"""Verify hook does nothing when source key is missing."""
state_dict = {}

# Should not raise
_bmm_moe_gate_up_split_hook(
state_dict,
"",
source_key="missing_key",
intermediate_size=32,
w1_keys=["w1"],
w3_keys=["w3"],
)

assert len(state_dict) == 0

@pytest.mark.parametrize("prefix", ["", "model.layers.0.moe."])
def test_works_with_module_prefix(self, prefix):
"""Verify hook works correctly with module path prefix."""
num_experts = 2
hidden_size = 32
intermediate_size = 16

stacked = torch.randn(num_experts, hidden_size, intermediate_size * 2)
state_dict = {f"{prefix}gate_up_weight": stacked}
w1_keys = [f"w1_{i}" for i in range(num_experts)]
w3_keys = [f"w3_{i}" for i in range(num_experts)]

_bmm_moe_gate_up_split_hook(
state_dict,
prefix,
source_key="gate_up_weight",
intermediate_size=intermediate_size,
w1_keys=w1_keys,
w3_keys=w3_keys,
)

for i in range(num_experts):
assert f"{prefix}{w1_keys[i]}" in state_dict
assert f"{prefix}{w3_keys[i]}" in state_dict

class TestBmmMoeDownSplitHook:
"""Tests for _bmm_moe_down_split_hook."""

@pytest.mark.parametrize(
"num_experts,hidden_size,intermediate_size",
[
(4, 64, 32),
(8, 128, 64),
(2, 32, 16),
],
)
def test_splits_stacked_weights_into_per_expert_w2(
self, num_experts, hidden_size, intermediate_size
):
"""Verify down hook splits stacked weights into w2 per expert."""
# Llama4 format: (E, I, H)
stacked = torch.randn(num_experts, intermediate_size, hidden_size)
state_dict = {"down_weight": stacked}
w2_keys = [f"w2_{i}" for i in range(num_experts)]

_bmm_moe_down_split_hook(
state_dict,
"",
source_key="down_weight",
w2_keys=w2_keys,
)

for i in range(num_experts):
assert w2_keys[i] in state_dict
# After transpose: (H, I)
assert state_dict[w2_keys[i]].shape == (hidden_size, intermediate_size)

def test_w2_content_matches_original_stacked(self):
"""Verify split w2 tensors match the original stacked content."""
num_experts = 2
hidden_size = 32
intermediate_size = 16

stacked = torch.randn(num_experts, intermediate_size, hidden_size)
state_dict = {"down_weight": stacked.clone()}
w2_keys = [f"w2_{i}" for i in range(num_experts)]

_bmm_moe_down_split_hook(
state_dict,
"",
source_key="down_weight",
w2_keys=w2_keys,
)

for i in range(num_experts):
expected_w2 = stacked[i].transpose(0, 1).contiguous()
torch.testing.assert_close(state_dict[w2_keys[i]], expected_w2)

def test_handles_missing_source_key(self):
"""Verify hook does nothing when source key is missing."""
state_dict = {}

_bmm_moe_down_split_hook(
state_dict,
"",
source_key="missing_key",
w2_keys=["w2"],
)

assert len(state_dict) == 0


class TestBmmMoeHooksIntegration:
"""Integration tests for BMM MoE hooks working together."""

def test_full_checkpoint_loading_flow(self):
"""Test the full flow: split gate_up and down into per-expert weights."""
num_experts = 4
hidden_size = 64
intermediate_size = 32

# Simulate a checkpoint with stacked weights
gate_up_stacked = torch.randn(num_experts, hidden_size, intermediate_size * 2)
down_stacked = torch.randn(num_experts, intermediate_size, hidden_size)

state_dict = {
"gate_up_weight": gate_up_stacked.clone(),
"down_weight": down_stacked.clone(),
}

w1_keys = [f"w1_{i}" for i in range(num_experts)]
w2_keys = [f"w2_{i}" for i in range(num_experts)]
w3_keys = [f"w3_{i}" for i in range(num_experts)]

# Step 1: Split gate_up into w1 and w3
_bmm_moe_gate_up_split_hook(
state_dict,
"",
source_key="gate_up_weight",
intermediate_size=intermediate_size,
w1_keys=w1_keys,
w3_keys=w3_keys,
)

# Step 2: Split down into w2
_bmm_moe_down_split_hook(
state_dict,
"",
source_key="down_weight",
w2_keys=w2_keys,
)

# Verify: all per-expert weights present with correct shapes
for i in range(num_experts):
assert state_dict[w1_keys[i]].shape == (intermediate_size, hidden_size)
assert state_dict[w2_keys[i]].shape == (hidden_size, intermediate_size)
assert state_dict[w3_keys[i]].shape == (intermediate_size, hidden_size)