[#9717][chore] Standardize MoE weights interface (#10295)

Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
tcherckez-nvidia 2025-12-31 14:37:18 +02:00 committed by GitHub
parent ef1d4a40b5
commit 464847c6be
6 changed files with 561 additions and 619 deletions

View File

@ -102,10 +102,6 @@ def torch_moe(
Unified Mixture-of-Experts (MoE) operator that uses a Mixtral-style dispatch
(token routing + index_add_ accumulation) and a selectable per-expert MLP.
Supports both:
- Standard MoE with per-expert weight lists (apply_routing_on_input=False)
- Llama4 MoE with stacked weight tensors (apply_routing_on_input=True)
Parameters:
x (torch.Tensor): Input tensor of shape (B, H) or (B, S, H), where B is the batch size,
S is the sequence length, and H is the hidden size.
@ -113,87 +109,31 @@ def torch_moe(
of the selected experts for each token. Only experts within the range [0, num_experts) are processed.
routing_weights (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the normalized
routing weights for the selected experts.
- Standard MoE: softmax normalized weights
- Llama4 MoE: sigmoid activated weights
w1_weight:
For per-expert lists:
is_gated_mlp==True: List of W1 with shape (I, H) "gate" projection.
is_gated_mlp==False: List of W_up with shape (I, H) up projection.
For stacked tensors (Llama4):
Single-element list containing stacked w3_w1 tensor with shape (E, 2*I, H) in TRT-LLM format
w2_weight:
For per-expert lists:
List of W2/W_down with shape (H, I) down projection.
For stacked tensors (Llama4):
Single-element list containing stacked w2 tensor with shape (E, H, I) in TRT-LLM format
w3_weight:
For per-expert lists with gated_mlp:
List of W3 with shape (I, H) "up" (second) projection in gated MLP.
For is_gated_mlp==False or stacked tensors:
pass an empty list []; ignored.
is_gated_mlp:
Selects the per-expert MLP computation:
is_gated_mlp==True (default, Mixtral/DeepSeek/Llama4-style):
y = W2( act(W1 x) * (W3 x) )
is_gated_mlp==False (NemotronH-style 2-layer MLP):
y = W_down( act(W_up x) )
act_fn:
Elementwise activation applied inside the expert MLP.
w1_weight: List of per-expert weight tensors for the gate projection, shape (I, H) (acts as the up projection when is_gated_mlp=False).
w2_weight: List of per-expert weight tensors for the down projection, shape (H, I).
w3_weight: List of per-expert weight tensors for the up projection, shape (I, H); pass an empty list when is_gated_mlp=False.
is_gated_mlp: If True, use a gated MLP: y = W2(act(W1 x) * (W3 x)). If False, use a simple two-layer MLP: y = W2(act(W1 x)).
act_fn: Activation function applied inside the expert MLP.
Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square).
apply_routing_on_input:
If True (Llama4 pattern): multiply routing weights with INPUT before MLP
Result: act(input * routing_weight) - routing affects activation
If False (standard pattern): multiply routing weights with OUTPUT after MLP
Result: act(input) * routing_weight - routing scales output
apply_routing_on_input: If True, multiply routing weights with INPUT before MLP
This means: silu(input * routing_weight)
If False, multiply routing weights with OUTPUT after MLP
This means: silu(input) * routing_weight
Returns:
torch.Tensor: Output tensor with the same shape as the input x.
"""
torch_act_fn = _resolve_torch_fn(act_fn)
# Detect if using stacked tensor format (Llama4) vs per-expert lists (standard)
is_stacked = len(w1_weight) == 1 and w1_weight[0].ndim == 3
# Todo: either change torch_moe to use a single condition, or refactor this code.
# it should be :
# is_gated_mlp:
# stacked:
# ...
# not stacked:
# .
# else:
# assert (not stacked)
# ...
# .
if is_stacked:
# Llama4 stacked tensor format - only supports gated_mlp
if not is_gated_mlp:
raise ValueError("Stacked tensor format only supports gated MLP style")
w3_w1_stacked = w1_weight[0] # (E, 2*I, H)
intermediate_size = w3_w1_stacked.shape[1] // 2
w2_stacked = w2_weight[0] # (E, H, I)
def make_mlp(idx: int):
gate_up = w3_w1_stacked[idx] # (2*I, H)
W3 = gate_up[:intermediate_size, :] # (I, H)
W1 = gate_up[intermediate_size:, :] # (I, H)
W2 = w2_stacked[idx] # (H, I)
weight_dtype = W1.dtype
return lambda inp: F.linear(
torch_act_fn(F.linear(inp.to(weight_dtype), W1))
* F.linear(inp.to(weight_dtype), W3),
W2,
)
mlps = [make_mlp(idx) for idx in range(w3_w1_stacked.shape[0])]
elif is_gated_mlp:
mlps = []
if is_gated_mlp:
# Standard per-expert list format with gated MLP
def make_mlp(i: int):
W1 = w1_weight[i] # (I, H)
W2 = w2_weight[i] # (H, I)
W3 = w3_weight[i] # (I, H)
return lambda inp: F.linear(torch_act_fn(F.linear(inp, W1)) * F.linear(inp, W3), W2)
return lambda inp: F.linear(
torch_act_fn(F.linear(inp.to(W1.dtype), W1)) * F.linear(inp.to(W3.dtype), W3), W2
)
mlps = [make_mlp(i) for i in range(len(w1_weight))]
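For reference, a minimal pure-PyTorch sketch of the two routing conventions described in the docstring above, using a single gated expert and made-up toy sizes (this is not the registered custom op itself):

```python
import torch
import torch.nn.functional as F

H, I = 8, 16                                   # toy hidden / intermediate sizes
x = torch.randn(4, H)                          # 4 tokens routed to one expert
r = torch.rand(4, 1)                           # routing weight of this expert per token
W1, W3, W2 = torch.randn(I, H), torch.randn(I, H), torch.randn(H, I)

def gated_expert(inp):
    # y = W2( act(W1 x) * (W3 x) ) with act = SiLU
    return F.linear(F.silu(F.linear(inp, W1)) * F.linear(inp, W3), W2)

y_output_routing = gated_expert(x) * r         # apply_routing_on_input=False (standard)
y_input_routing = gated_expert(x * r)          # apply_routing_on_input=True (Llama4)
# The two differ because SiLU is nonlinear, hence the explicit flag.
assert not torch.allclose(y_output_routing, y_input_routing)
```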

View File

@ -1,4 +1,5 @@
from collections import defaultdict
from functools import partial
from typing import Dict, List, Literal, Optional, Tuple, Type
import torch
@ -20,6 +21,60 @@ from ..interface import (
)
def _bmm_moe_gate_up_split_hook(
state_dict,
prefix,
*args,
source_key: str,
intermediate_size: int,
w1_keys: List[str],
w3_keys: List[str],
):
"""Hook to split gate_up_weight into all per-expert w1 and w3 weights.
Args:
source_key: Original stacked weight key (e.g., "gate_up_weight")
intermediate_size: Intermediate dimension size
w1_keys: List of target parameter keys for w1 weights
w3_keys: List of target parameter keys for w3 weights
"""
source_full_key = prefix + source_key
if source_full_key in state_dict:
stacked_tensor = state_dict[source_full_key]
# Split on last dim: (E, H, 2I) -> 2x (E, H, I)
w1_stacked, w3_stacked = stacked_tensor.split(intermediate_size, dim=2)
# Transpose and contiguous in batch, then unbind into views
w1_experts = w1_stacked.transpose(1, 2).contiguous().unbind(0)
w3_experts = w3_stacked.transpose(1, 2).contiguous().unbind(0)
for w1_key, w3_key, w1, w3 in zip(w1_keys, w3_keys, w1_experts, w3_experts):
state_dict[prefix + w1_key] = w1
state_dict[prefix + w3_key] = w3
def _bmm_moe_down_split_hook(
state_dict,
prefix,
*args,
source_key: str,
w2_keys: List[str],
):
"""Hook to split down_weight into all per-expert w2 weights.
Args:
source_key: Original stacked weight key (e.g., "down_weight")
w2_keys: List of target parameter keys for w2 weights
"""
source_full_key = prefix + source_key
if source_full_key in state_dict:
stacked_tensor = state_dict[source_full_key]
# Transpose and contiguous in batch, then unbind into views
w2_experts = stacked_tensor.transpose(1, 2).contiguous().unbind(0)
for w2_key, w2 in zip(w2_keys, w2_experts):
state_dict[prefix + w2_key] = w2
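As a quick shape sanity check, the two hooks can be exercised standalone on a fabricated state dict (toy sizes; the import path matches the new unit tests added in this PR):

```python
import torch
from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe import (
    _bmm_moe_down_split_hook,
    _bmm_moe_gate_up_split_hook,
)

E, H, I = 4, 64, 32                                     # experts, hidden, intermediate
state_dict = {
    "gate_up_weight": torch.randn(E, H, 2 * I),         # Llama4 layout (E, H, 2*I)
    "down_weight": torch.randn(E, I, H),                # Llama4 layout (E, I, H)
}
w1_keys = [f"w1_{i}" for i in range(E)]
w3_keys = [f"w3_{i}" for i in range(E)]
w2_keys = [f"w2_{i}" for i in range(E)]

_bmm_moe_gate_up_split_hook(
    state_dict, "", source_key="gate_up_weight",
    intermediate_size=I, w1_keys=w1_keys, w3_keys=w3_keys,
)
_bmm_moe_down_split_hook(state_dict, "", source_key="down_weight", w2_keys=w2_keys)

assert state_dict["w1_0"].shape == (I, H)   # per-expert w1, transposed to (I, H)
assert state_dict["w3_0"].shape == (I, H)   # per-expert w3, transposed to (I, H)
assert state_dict["w2_0"].shape == (H, I)   # per-expert w2, transposed to (H, I)
```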
def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int:
"""Replace torch MoE ops with fused backend-specific implementations.
@ -41,65 +96,29 @@ def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "t
}[backend]
for node in graph.nodes:
if not is_op(node, torch.ops.auto_deploy.torch_moe):
continue
if is_op(node, torch.ops.auto_deploy.torch_moe):
(is_gated_mlp, act_fn) = extract_op_args(node, "is_gated_mlp", "act_fn")
# Detect if this is a stacked MoE (Llama4 pattern) or per-expert list (standard pattern)
(apply_routing_val, w1_weight_list) = extract_op_args(
node, "apply_routing_on_input", "w1_weight"
)
# Check if it's stacked format: single-element list with 3D tensor
is_stacked_moe = False
if apply_routing_val:
# In FX graphs, w1_weight_list might be a Node representing a list() call
list_content = None
if isinstance(w1_weight_list, Node) and w1_weight_list.target is list:
# Extract from list() call node
if w1_weight_list.args:
list_content = w1_weight_list.args[0]
elif isinstance(w1_weight_list, (list, tuple)):
# Direct Python list
list_content = w1_weight_list
# Check if it's a single-element list with a 3D tensor
if list_content is not None and len(list_content) == 1:
w1_node = list_content[0]
if isinstance(w1_node, Node) and w1_node.op == "get_attr":
try:
w1_tensor = gm.get_parameter(w1_node.target)
is_stacked_moe = w1_tensor.ndim == 3
except (AttributeError, KeyError):
pass
(is_gated_mlp, act_fn) = extract_op_args(node, "is_gated_mlp", "act_fn")
if is_stacked_moe:
# Stacked MoE (Llama4 pattern): only supports gated MLP
_process_llama4_stacked_moe_node(
gm, graph, node, replacement_op, act_fn, fused_key_counter
)
else:
# Standard MoE with per-expert weight lists
assert backend != "triton" or not is_gated_mlp, (
"Triton backend only supports mlp style."
)
_process_regular_moe_node(
_process_moe_node(
gm, graph, node, replacement_op, is_gated_mlp, act_fn, fused_key_counter
)
fused_key_counter += 1
fused_key_counter += 1
# Delete the unstacked weights immediately to save GPU memory
# This will happen automatically after the graph is canonicalized, but for large models we'll run out of memory
# during the transformation itself.
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
# Delete the unstacked weights immediately to save GPU memory
# This will happen automatically after the graph is canonicalized,
# but for large models we'll run out of memory during the transformation itself.
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
return fused_key_counter
def _process_regular_moe_node(
def _process_moe_node(
gm: GraphModule,
graph: torch.fx.Graph,
node: Node,
@ -113,7 +132,15 @@ def _process_regular_moe_node(
Stacks weight parameters and creates a fused MoE node.
The kernel applies routing weights to the output.
"""
hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = extract_op_args(
(
hidden_states,
selected_experts,
routing_weights,
w1_list,
w2_list,
w3_list,
apply_routing_on_input,
) = extract_op_args(
node,
"x",
"selected_experts",
@ -121,6 +148,7 @@ def _process_regular_moe_node(
"w1_weight",
"w2_weight",
"w3_weight",
"apply_routing_on_input",
)
# Stack weights based on MLP style
@ -157,6 +185,27 @@ def _process_regular_moe_node(
with graph.inserting_before(node):
w_up_arg = graph.get_attr(new_key_w_up)
w_down_arg = graph.get_attr(new_key_w_down)
# Get weight dtype for casting - fused kernel requires activation dtype to match weight dtype
weight_dtype = fused_w_up_experts.dtype
if apply_routing_on_input:
# Scale input: hidden_states = hidden_states * routing_weights
hidden_states = graph.call_function(
torch.ops.aten.mul.Tensor,
args=(hidden_states, routing_weights),
)
# Pass ones to kernel to prevent it from multiplying routing again (already applied)
routing_weights = graph.call_function(
torch.ops.aten.ones_like.default,
args=(routing_weights,),
)
# Kernel requires activation dtype to match weight dtype
hidden_states = graph.call_function(
torch.ops.aten.to,
args=(hidden_states, weight_dtype),
)
new_node = graph.call_function(
replacement_op,
@ -171,145 +220,6 @@ def _process_regular_moe_node(
graph.erase_node(node)
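The ones_like trick above rests on the assumption that the fused kernel always scales its output by the routing weights it receives; a tiny sketch of that reasoning with a placeholder expert (not the real kernel):

```python
import torch

def fused_kernel_reference(x, routing_weights):
    # stand-in for the fused op's routing handling: output is scaled by the
    # routing weights that are passed in
    expert_out = torch.tanh(x)                 # placeholder for the expert MLP
    return expert_out * routing_weights

x, r = torch.randn(4, 8), torch.rand(4, 1)
# Routing was already folded into the input, so pass all-ones weights;
# passing r here would apply the routing weights a second time.
y = fused_kernel_reference(x * r, torch.ones_like(r))
assert torch.allclose(y, torch.tanh(x * r))
```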
def _process_llama4_stacked_moe_node(
gm: GraphModule,
graph: torch.fx.Graph,
node: Node,
replacement_op,
act_fn: ActivationType,
fused_key_counter: int,
) -> None:
"""Process a single Llama4 MoE node with pre-stacked weight tensors.
Only supports gated MLP (SwiGLU-style) architecture.
Converts Llama4 format weights to TRT-LLM format to standardize all downstream ops.
Applies routing weights to INPUT before the fused kernel to prevent double multiplication.
This is the Llama4 pattern where weights are already stacked across experts.
Result: silu(input * routing_weight) - routing affects activation.
"""
# torch_moe with stacked format: weights are in single-element lists
hidden_states, selected_experts, routing_weights, w1_list, w2_list = extract_op_args(
node,
"x",
"selected_experts",
"routing_weights",
"w1_weight",
"w2_weight",
)
# Extract the single stacked tensor from each list
# Handle both FX graph Nodes (list() calls) and direct Python lists
def extract_from_list_arg(list_arg):
if isinstance(list_arg, Node) and list_arg.target is list:
# Extract from list() call node
return list_arg.args[0][0] if list_arg.args else None
elif isinstance(list_arg, (list, tuple)):
# Direct Python list
return list_arg[0]
else:
raise ValueError(f"Unexpected list format: {type(list_arg)}")
w3_w1_stacked = extract_from_list_arg(w1_list)
w2_stacked = extract_from_list_arg(w2_list)
# Convert Llama4 format to TRT-LLM format if needed
# This standardizes all downstream ops to only handle TRT-LLM format
if w3_w1_stacked.op == "get_attr" and w2_stacked.op == "get_attr":
gate_up_weight = gm.get_parameter(w3_w1_stacked.target)
down_weight = gm.get_parameter(w2_stacked.target)
# Detect format:
# - Llama4: gate_up is (E, H, 2*I) and down is (E, I, H)
# - TRT-LLM: gate_up is (E, 2*I, H) and down is (E, H, I)
# If both have H in middle dimension, they're Llama4 format
is_llama4 = gate_up_weight.shape[1] == down_weight.shape[2]
if is_llama4:
# Convert Llama4 (E, H, 2*I) -> TRT-LLM (E, 2*I, H)
gate_up_trtllm = gate_up_weight.transpose(1, 2).contiguous()
# Convert Llama4 (E, I, H) -> TRT-LLM (E, H, I)
down_trtllm = down_weight.transpose(1, 2).contiguous()
# Register converted weights
new_key_w_up = f"llama4_to_trtllm_w3_w1_{fused_key_counter}"
new_key_w_down = f"llama4_to_trtllm_w2_{fused_key_counter}"
gm.register_parameter(new_key_w_up, torch.nn.Parameter(gate_up_trtllm))
gm.register_parameter(new_key_w_down, torch.nn.Parameter(down_trtllm))
# Store keys to create get_attr nodes later in insertion context
needs_get_attr = True
w_up_key = new_key_w_up
w_down_key = new_key_w_down
else:
# Already TRT-LLM format, use directly
needs_get_attr = False
w_up_arg = w3_w1_stacked
w_down_arg = w2_stacked
else:
# Not get_attr nodes (might be intermediate ops), use directly
needs_get_attr = False
w_up_arg = w3_w1_stacked
w_down_arg = w2_stacked
# Llama4 INPUT-SIDE routing: apply routing to INPUT before kernel
# Cast BOTH input and routing_weights to weight dtype if needed
# Critical: BFloat16 * Float32 → Float32 (type promotion) so we cast both to same dtype
with graph.inserting_before(node):
# Create get_attr nodes INSIDE insertion context for proper topological ordering
if needs_get_attr:
w_up_arg = graph.get_attr(w_up_key)
w_down_arg = graph.get_attr(w_down_key)
# Get weight dtype to ensure dtype consistency for Llama4 stacked tensors
# The fused kernel requires input and weights to have matching dtypes
weight_dtype = None
if w_up_arg.op == "get_attr":
try:
weight_tensor = gm.get_parameter(w_up_arg.target)
weight_dtype = weight_tensor.dtype
except (AttributeError, KeyError):
pass
input_to_scale = hidden_states
routing_to_scale = routing_weights
if weight_dtype is not None and weight_dtype != torch.float32:
input_to_scale = graph.call_function(
torch.ops.aten._to_copy.default,
args=(hidden_states,),
kwargs={"dtype": weight_dtype},
)
routing_to_scale = graph.call_function(
torch.ops.aten._to_copy.default,
args=(routing_weights,),
kwargs={"dtype": weight_dtype},
)
# Scale input: hidden_states = hidden_states * routing_weights (both same dtype now)
scaled_input = graph.call_function(
torch.ops.aten.mul.Tensor,
args=(input_to_scale, routing_to_scale),
)
# Pass ones to kernel to prevent it from multiplying routing again (already applied)
ones_node = graph.call_function(
torch.ops.aten.ones_like.default,
args=(routing_weights,),
)
new_node = graph.call_function(
replacement_op,
args=(scaled_input, selected_experts, ones_node, w_up_arg, w_down_arg),
kwargs={
"act_fn": act_fn,
"is_gated_mlp": True,
},
)
node.replace_all_uses_with(new_node)
graph.erase_node(node)
def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]:
"""
Find the lowest common ancestor for a list of nodes in a torch.fx Graph by following
@ -1068,6 +978,31 @@ class MatchBmmMoePattern(BaseTransform):
continue
first_bmm, gate_up_weight = result
# Get shapes from node metadata
if not hasattr(gate_up_weight, "meta") or "val" not in gate_up_weight.meta:
continue
if not hasattr(down_weight, "meta") or "val" not in down_weight.meta:
continue
gate_up_shape = gate_up_weight.meta["val"].shape
down_shape = down_weight.meta["val"].shape
# Only support llama4 shaped weights for now
if len(gate_up_shape) != len(down_shape) or len(gate_up_shape) != 3:
continue
# Llama4 expectation:
# num_experts = gate_up_shape[0] == down_shape[0]
# hidden_size = gate_up_shape[1] == down_shape[2]
# gate_up_shape[2] == 2 * down_shape[1] (intermediate_size)
if gate_up_shape[0] != down_shape[0]:
continue
if gate_up_shape[2] != 2 * down_shape[1]:
continue
if gate_up_shape[1] != down_shape[2]:
continue
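For concreteness, toy shapes that satisfy the Llama4 layout checks above (numbers are illustrative only):

```python
gate_up_shape = (8, 512, 2048)   # (E, H, 2*I)
down_shape = (8, 1024, 512)      # (E, I, H)
assert gate_up_shape[0] == down_shape[0]        # num_experts
assert gate_up_shape[2] == 2 * down_shape[1]    # 2 * intermediate_size
assert gate_up_shape[1] == down_shape[2]        # hidden_size
```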
# Step 3: Get batched input and trace back to original input and routing
batched_input = first_bmm.args[0]
if not isinstance(batched_input, Node) or not is_op(batched_input, torch.ops.aten.view):
@ -1212,31 +1147,119 @@ class MatchBmmMoePattern(BaseTransform):
# If input_routing is False: kernel applies routing to output
apply_routing_on_input = input_routing
# Wrap stacked tensors in single-element lists for torch_moe unified interface
with graph.inserting_before(output_node):
# Create list nodes for stacked weights
w1_list_node = graph.call_function(
list,
args=([gate_up_weight],),
)
w2_list_node = graph.call_function(
list,
args=([down_weight],),
)
w3_list_node = graph.call_function(
list,
args=([],), # Empty list for stacked gated MLP
# Materialize stacked tensors into per-expert parameters for torch_moe
# Get the actual tensors from the graph nodes
if gate_up_weight.op != "get_attr" or down_weight.op != "get_attr":
raise RuntimeError(
f"Expected get_attr nodes for BMM MoE weights, got {gate_up_weight.op} and {down_weight.op}"
)
gate_up_tensor = gm.get_parameter(gate_up_weight.target)
down_tensor = gm.get_parameter(down_weight.target)
# Support only llama4 shaped weights for now
if gate_up_tensor.shape[2] != 2 * down_tensor.shape[1]:
raise RuntimeError(
f"Expected gate_up_tensor.shape[2] == 2 * down_tensor.shape[1],"
f"got {gate_up_tensor.shape[2]} and {down_tensor.shape[1]}"
)
# Get dimensions
assert len(gate_up_tensor.shape) == 3, (
f"Expected gate_up_tensor.shape to have 3 dimensions, got {len(gate_up_tensor.shape)}"
)
assert len(down_tensor.shape) == 3, (
f"Expected down_tensor.shape to have 3 dimensions, got {len(down_tensor.shape)}"
)
num_experts = gate_up_tensor.shape[0]
assert num_experts == down_tensor.shape[0], (
f"Expected num_experts == down_tensor.shape[0],"
f"got {num_experts} and {down_tensor.shape[0]}"
)
hidden_size = gate_up_tensor.shape[1]
assert hidden_size == down_tensor.shape[2], (
f"Expected hidden_size == down_tensor.shape[2],"
f"got {hidden_size} and {down_tensor.shape[2]}"
)
intermediate_size = gate_up_tensor.shape[2] // 2
assert intermediate_size == down_tensor.shape[1], (
f"Expected intermediate_size == down_tensor.shape[1],"
f"got {intermediate_size} and {down_tensor.shape[1]}"
)
# Store checkpoint keys for hooks
gate_up_checkpoint_key = str(gate_up_weight.target)
down_checkpoint_key = str(down_weight.target)
# Split each stacked tensor into per-expert tensors and register as parameters
# This creates get_attr nodes that sharding expects
w1_keys = []
w2_keys = []
w3_keys = []
for expert_idx in range(num_experts):
# Register each expert's weight as a separate parameter
w1_key = f"bmm_moe_w1_expert_{num_moe_patterns}_{expert_idx}"
w2_key = f"bmm_moe_w2_expert_{num_moe_patterns}_{expert_idx}"
w3_key = f"bmm_moe_w3_expert_{num_moe_patterns}_{expert_idx}"
w1_keys.append(w1_key)
w2_keys.append(w2_key)
w3_keys.append(w3_key)
w1_param = torch.nn.Parameter(
gate_up_tensor[expert_idx, :, :intermediate_size].transpose(0, 1)
)
w2_param = torch.nn.Parameter(down_tensor[expert_idx].transpose(0, 1))
w3_param = torch.nn.Parameter(
gate_up_tensor[expert_idx, :, intermediate_size:].transpose(0, 1)
)
gm.register_parameter(w1_key, w1_param)
gm.register_parameter(w2_key, w2_param)
gm.register_parameter(w3_key, w3_param)
# Register checkpoint loading hooks - ONE per stacked weight
# Hook for gate_up_weight: splits into all w1 and w3 expert weights
gm._register_load_state_dict_pre_hook(
partial(
_bmm_moe_gate_up_split_hook,
source_key=gate_up_checkpoint_key,
intermediate_size=intermediate_size,
w1_keys=w1_keys,
w3_keys=w3_keys,
)
)
# Hook for down_weight: splits into all w2 expert weights
gm._register_load_state_dict_pre_hook(
partial(
_bmm_moe_down_split_hook,
source_key=down_checkpoint_key,
w2_keys=w2_keys,
)
)
# Now create get_attr nodes for each expert weight
# These must be created within the insertion context for proper graph ordering
insertion_point = graph.find_nodes(op="get_attr")[0]
with graph.inserting_before(insertion_point):
w1_nodes = [graph.get_attr(key) for key in w1_keys]
w2_nodes = [graph.get_attr(key) for key in w2_keys]
w3_nodes = [graph.get_attr(key) for key in w3_keys]
with graph.inserting_before(output_node):
fused_moe_node = graph.call_function(
torch.ops.auto_deploy.torch_moe,
args=(
input_hidden_states,
selected_experts,
routing_weights_node,
w1_list_node,
w2_list_node,
w3_list_node,
w1_nodes,
w2_nodes,
w3_nodes,
),
kwargs={
"is_gated_mlp": True,

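Taken together, the per-expert parameters plus the two pre-hooks let a stacked checkpoint load transparently. A self-contained sketch with a plain nn.Module standing in for the GraphModule (names and sizes are invented for illustration):

```python
from functools import partial

import torch
import torch.nn as nn

from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe import (
    _bmm_moe_down_split_hook,
    _bmm_moe_gate_up_split_hook,
)

E, H, I = 2, 16, 8                              # experts, hidden, intermediate (toy)
mod = nn.Module()
w1_keys = [f"bmm_moe_w1_expert_0_{e}" for e in range(E)]
w2_keys = [f"bmm_moe_w2_expert_0_{e}" for e in range(E)]
w3_keys = [f"bmm_moe_w3_expert_0_{e}" for e in range(E)]
for k in w1_keys + w3_keys:
    mod.register_parameter(k, nn.Parameter(torch.empty(I, H)))
for k in w2_keys:
    mod.register_parameter(k, nn.Parameter(torch.empty(H, I)))

# One pre-hook per stacked checkpoint tensor, mirroring the registration above.
mod._register_load_state_dict_pre_hook(
    partial(_bmm_moe_gate_up_split_hook, source_key="gate_up_weight",
            intermediate_size=I, w1_keys=w1_keys, w3_keys=w3_keys)
)
mod._register_load_state_dict_pre_hook(
    partial(_bmm_moe_down_split_hook, source_key="down_weight", w2_keys=w2_keys)
)

# The checkpoint still holds the stacked Llama4-layout tensors.
ckpt = {"gate_up_weight": torch.randn(E, H, 2 * I), "down_weight": torch.randn(E, I, H)}
# strict=False: the stacked source keys remain in the state dict after the split.
mod.load_state_dict(ckpt, strict=False)
```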
View File

@ -266,6 +266,14 @@ def _execute_op_in_aux_stream(
if input_node.target == torch.ops.aten.view.default:
target_input_node = input_node
break
# Look through dtype cast nodes (aten.to) to find the view node
if input_node.target == torch.ops.aten.to:
for nested_input in input_node.all_input_nodes:
if nested_input.target == torch.ops.aten.view.default:
target_input_node = nested_input
break
if target_input_node is not None:
break
assert target_input_node is not None, f"Target input node not found for node {n}"
with graph.inserting_before(target_input_node):

View File

@ -594,8 +594,6 @@ class BMMShardingInfo(ShardingTransformInfo):
class EPShardingInfo(ShardingTransformInfo):
"""Configuration for EP sharding transformations."""
mlp_type: MLPType
@classmethod
def from_node(cls, node: Node, **kwargs) -> "EPShardingInfo":
"""
@ -613,7 +611,7 @@ class EPShardingInfo(ShardingTransformInfo):
def apply(self, gm: GraphModule, node: Node) -> None:
"""Apply EP sharding transformation to the graph module."""
_insert_sharded_moe(gm, node, self.config, mlp_type=self.mlp_type)
_insert_sharded_moe(gm, node, self.config)
class MXFP4EPShardingInfo(EPShardingInfo):
@ -1065,71 +1063,6 @@ def _resolve_tp_cls_from_node(node: Node):
return WeightShardingInfo
def _transform_bmm_moe_weight_param(
gm: GraphModule,
param_node: Node,
lo: int,
hi: int,
swap_gate_up: bool = False,
) -> None:
"""Transform a parameter for BMM MoE: slice experts, optionally swap gate/up, transpose.
This modifies the parameter in-place and registers a load hook.
Does NOT create graph nodes - those should be created separately by the caller.
Args:
gm: Graph module
param_node: The get_attr node for the parameter
lo: Start index for expert slicing
hi: End index for expert slicing
swap_gate_up: If True, swap W1 and W3 (Llama4 -> TRT-LLM format)
"""
if param_node.op != "get_attr":
return # Only works on parameters
param_key = str(param_node.target)
modname, _, param_name = param_key.rpartition(".")
submod = gm.get_submodule(modname) if modname else gm
full_param = getattr(submod, param_name)
# Slice the parameter along expert dimension (dim 0)
sliced_param = full_param[lo:hi].detach().clone()
# Swap W1 and W3 if needed (for gate_up weights)
# Llama4: (E, H, 2*I) with [W1, W3], TRT-LLM wants [W3, W1]
if swap_gate_up and sliced_param.ndim == 3:
intermediate_size = sliced_param.shape[2] // 2
w1 = sliced_param[:, :, :intermediate_size]
w3 = sliced_param[:, :, intermediate_size:]
sliced_param = torch.cat([w3, w1], dim=2)
# Transpose: Llama4 (E, H, X) -> TRT-LLM (E, X, H)
transposed_param = sliced_param.transpose(1, 2)
transposed_shape = transposed_param.shape
# Define transformation function for load hook
def transform_tensor(t: torch.Tensor) -> torch.Tensor:
t_sliced = t[lo:hi]
if swap_gate_up and t_sliced.ndim == 3:
intermediate_size = t_sliced.shape[2] // 2
w1 = t_sliced[:, :, :intermediate_size]
w3 = t_sliced[:, :, intermediate_size:]
t_sliced = torch.cat([w3, w1], dim=2)
return t_sliced.transpose(1, 2).contiguous()
# Register load hook
gm._register_load_state_dict_pre_hook(
partial(
_load_hook,
f_split=transform_tensor,
param_key=param_key,
param_shape=transposed_shape,
)
)
# Replace the parameter with the transformed version
new_param = nn.Parameter(transposed_param, requires_grad=False)
setattr(submod, param_name, new_param)
def _get_dim0_from_arg(gm: GraphModule, arg: Union[Node, torch.Tensor]) -> int:
"""Helper to get the first dimension size of an argument (Node or Tensor)."""
if isinstance(arg, torch.Tensor):
@ -1195,26 +1128,6 @@ def init_process_grid_from_config(
return process_grid
def _canonicalize_node_args(node: Node) -> list:
"""
Canonicalize the node's arguments.
Actions performed:
- Flatten list arguments
"""
new_args = list(node.args)
for i in range(len(new_args)):
# In FX graphs, the list might be a Node representing a list() call
if isinstance(new_args[i], Node):
# Check if this is a list() call node
if new_args[i].target is list and len(new_args[i].args) == 1:
new_args[i] = new_args[i].args[0]
if isinstance(new_args[i], (list, tuple)):
if len(new_args[i]) == 1:
new_args[i] = new_args[i][0]
return new_args
########################################################
# Sharding transform functions
########################################################
@ -1422,121 +1335,16 @@ def _update_node_args(node: Node, args: tuple) -> None:
)
def _insert_sharded_moe_stacked(
gm: GraphModule,
node: Node,
rank: int,
world_size: int,
allreduce_strategy: AllReduceStrategy,
scale_names: Sequence[str] = (),
):
"""Update the torch_moe node with sliced stacked weight tensors,
sharded `selected_experts` and `final_scales(router_logics)`.
Add an all_reduce node after the moe node.
For torch_moe with stacked tensor format (single-element lists containing 3D tensors).
NOTE: allreduce_strategy is MANDATORY and must be explicitly provided.
"""
if allreduce_strategy is None:
raise ValueError(f"allreduce_strategy must be set for MoE sharding on node {node.name}")
# Extract the stacked tensors from single-element lists
# args[3] = w1_weight (Node representing list with one 3D tensor, or direct list)
# args[4] = w2_weight (Node representing list with one 3D tensor, or direct list)
# Helper to extract tensor node from list (handles both Node and direct list)
def extract_tensor_from_list_arg(list_arg):
if isinstance(list_arg, Node) and list_arg.target is list:
# It's a list() call node - extract from its args
return list_arg.args[0][0] # args[0] is the list content, [0] is first element
elif isinstance(list_arg, (list, tuple)):
# Direct list
return list_arg[0]
else:
raise ValueError(f"Unexpected list format: {type(list_arg)}")
w3_w1_tensor_node = extract_tensor_from_list_arg(node.args[3])
w2_tensor_node = extract_tensor_from_list_arg(node.args[4])
num_experts = _get_dim0_from_arg(gm, w3_w1_tensor_node)
args = list(node.args)
# -- Handle selected_experts and final_scales sharding --
selected_experts = args[1]
final_scales = args[2]
experts_per_rank = num_experts // world_size
with gm.graph.inserting_before(node):
lower = experts_per_rank * rank
# selected_experts_local = selected_experts - low
selected_experts_local = gm.graph.create_node(
"call_function", operator.sub, args=(selected_experts, lower), kwargs={}
)
# For num_experts % world_size != 0 case,
# assign the last (num_experts % world_size) experts to the last rank
div_node = gm.graph.create_node(
"call_function", operator.floordiv, args=(selected_experts, experts_per_rank), kwargs={}
)
comp_op = torch.ge if rank == world_size - 1 else torch.eq
rank_mask = gm.graph.create_node("call_function", comp_op, args=(div_node, rank), kwargs={})
# final_scales_local = final_scales * rank_mask
final_scales_local = gm.graph.create_node(
"call_function", operator.mul, args=(final_scales, rank_mask), kwargs={}
)
# -- Transform expert weight parameters --
local_lo, local_hi = _split_range_last_remainder(num_experts, world_size, rank)
# Transform w3_w1_stacked: slice experts, swap [W1,W3]->[W3,W1], transpose (E,H,2I)->(E,2I,H)
if isinstance(w3_w1_tensor_node, Node):
_transform_bmm_moe_weight_param(
gm, w3_w1_tensor_node, local_lo, local_hi, swap_gate_up=True
)
# Transform w2_stacked: slice experts, transpose (E,I,H)->(E,H,I)
if isinstance(w2_tensor_node, Node):
_transform_bmm_moe_weight_param(gm, w2_tensor_node, local_lo, local_hi, swap_gate_up=False)
# -- Update args (keep same lists/nodes, just with transformed parameters) --
args[1] = selected_experts_local
args[2] = final_scales_local
# args[3] and args[4] stay the same - we modified the parameters in-place
ad_logger.debug(
f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}."
)
node.args = tuple(args)
# -- add an all_reduce node --
with gm.graph.inserting_after(node):
dist_node = gm.graph.call_function(
torch.ops.auto_deploy.torch_dist_all_reduce.default,
args=(node, allreduce_strategy),
)
node.replace_all_uses_with(dist_node)
dist_node.replace_input_with(dist_node, node)
def _insert_sharded_moe(
gm: GraphModule,
node: Node,
config: ShardingTransformConfig,
mlp_type: MLPType,
scale_names: Sequence[str] = (),
):
"""Update the torch_moe node with sharded weight lists or stacked tensors,
"""Update the torch_moe node with sharded weight lists,
sharded `selected_experts` and `final_scales` (router logits).
Add an all_reduce node after the moe node.
Handles both:
- Standard format: per-expert weight lists
- Stacked format: single-element lists containing stacked 3D tensors (Llama4 pattern)
NOTE: allreduce_strategy is MANDATORY.
"""
@ -1551,23 +1359,10 @@ def _insert_sharded_moe(
raise ValueError(f"allreduce_strategy must be set for MoE sharding on node {node.name}")
scale_names = list(scale_names)
flat_args = _canonicalize_node_args(node)
# we have two variants of MoE: stacked and listed:
# - stacked: w1, w2, w3 weight args are order-3 tensors, where the 1st dimension corresponds
# to the stacked expert weights.
# - listed: w1, w2, w3 weight args are lists of order-2 tensors, where each expert weight
# is a separate entry in the list.
if isinstance(flat_args[3], Node):
is_stacked = True
num_experts = shape(flat_args[3])[0]
else:
is_stacked = False
num_experts = len(flat_args[3])
args = list(node.args)
# -- Handle selected_experts and final_scales sharding --
selected_experts = args[1]
final_scales = args[2]
num_experts = len(args[3])
experts_per_rank = num_experts // ep_size
@ -1600,72 +1395,61 @@ def _insert_sharded_moe(
args[1] = selected_experts_local
args[2] = final_scales_local
if is_stacked:
# bmm-style stacked MoE: sharding is done by slicing the 1st dimension of the stacked weight tensor
# if mlp_type == MLPType.FUSED_GATED_MLP:
w_gate_up_stacked = flat_args[3]
w_down_stacked = flat_args[4]
local_lo, local_hi = _split_range_last_remainder(num_experts, ep_size, ep_rank)
_transform_bmm_moe_weight_param(
gm, w_gate_up_stacked, local_lo, local_hi, swap_gate_up=True
# -- Shard expert weights --
def get_partition(lst, world_size, rank):
num_experts = len(lst)
expert_size_per_partition = num_experts // world_size
expert_start = rank * expert_size_per_partition
# For num_experts % world_size != 0 case,
# assign the last (num_experts % world_size) experts to the last rank
expert_end = (
num_experts if (rank == world_size - 1) else expert_start + expert_size_per_partition
)
_transform_bmm_moe_weight_param(gm, w_down_stacked, local_lo, local_hi, swap_gate_up=False)
else:
# listed MoE: sharding is done by taking a range of the listed weight tensors
# -- Shard expert weights --
def get_partition(lst, world_size, rank):
num_experts = len(lst)
expert_size_per_partition = num_experts // world_size
expert_start = rank * expert_size_per_partition
# For num_experts % world_size != 0 case,
# assign the last (num_experts % world_size) experts to the last rank
expert_end = (
num_experts
if (rank == world_size - 1)
else expert_start + expert_size_per_partition
)
return lst[expert_start:expert_end]
return lst[expert_start:expert_end], lst[:expert_start] + lst[expert_end:]
w_up_list_sharded = get_partition(args[3], ep_size, ep_rank)
w_down_list_sharded = get_partition(args[4], ep_size, ep_rank)
w_gate_list_sharded = get_partition(args[5], ep_size, ep_rank)
w_up_list_sharded, w_up_list_to_remove = get_partition(args[3], ep_size, ep_rank)
w_down_list_sharded, w_down_list_to_remove = get_partition(args[4], ep_size, ep_rank)
w_gate_list_sharded, w_gate_list_to_remove = get_partition(args[5], ep_size, ep_rank)
# if tp_size > 1, we do 2D EP+TP sharding.
# we add TP sharding of all expert weights.
for w_up in w_up_list_sharded + w_gate_list_sharded:
shard_weight_tensor(
gm=gm,
weight_tensor=gm.get_parameter(w_up.target),
param_key=w_up.target,
dim=SplitDimension.COLUMN,
rank=tp_rank,
world_size=tp_size,
)
# here we don't need to add all-reduce: it's enough to have
# just one all-reduce after the whole EP+TP sharded MoE node.
for w_down in w_down_list_sharded:
shard_weight_tensor(
gm=gm,
weight_tensor=gm.get_parameter(w_down.target),
param_key=w_down.target,
dim=SplitDimension.ROW,
rank=tp_rank,
world_size=tp_size,
)
# -- Update args --
args[3] = w_up_list_sharded
args[4] = w_down_list_sharded
args[5] = w_gate_list_sharded
# Shard scales for quantized ops
for i in range(len(scale_names) * 3): # 3 layers (w1, w2, w3) × #scale_names per layer
args[6 + i] = get_partition(args[6 + i], ep_size, ep_rank)
ad_logger.debug(
f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}."
# if tp_size > 1, we do 2D EP+TP sharding.
# we add TP sharding of all expert weights.
for w_up in w_up_list_sharded + w_gate_list_sharded:
shard_weight_tensor(
gm=gm,
weight_tensor=gm.get_parameter(w_up.target),
param_key=w_up.target,
dim=SplitDimension.COLUMN,
rank=tp_rank,
world_size=tp_size,
)
# here we don't need to add all-reduce: it's enough to have
# just one all-reduce after the whole EP+TP sharded MoE node.
for w_down in w_down_list_sharded:
shard_weight_tensor(
gm=gm,
weight_tensor=gm.get_parameter(w_down.target),
param_key=w_down.target,
dim=SplitDimension.ROW,
rank=tp_rank,
world_size=tp_size,
)
# -- Update args --
args[3] = w_up_list_sharded
args[4] = w_down_list_sharded
args[5] = w_gate_list_sharded
# Shard scales for quantized ops
scales_to_remove = []
for i in range(len(scale_names) * 3): # 3 layers (w1, w2, w3) × #scale_names per layer
sharded, to_remove = get_partition(args[6 + i], ep_size, ep_rank)
args[6 + i] = sharded
scales_to_remove.extend(to_remove)
ad_logger.debug(
f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}."
)
node.args = tuple(args)
# -- add an all_reduce node --
@ -1676,6 +1460,15 @@ def _insert_sharded_moe(
node.replace_all_uses_with(dist_node)
dist_node.replace_input_with(dist_node, node)
gm.graph.eliminate_dead_code()
# Expert weights registered via gm.register_parameter() are top-level attributes.
# Unlike submodules, these aren't cleaned up by eliminate_dead_code() or
# delete_all_unused_submodules() - must delete manually after removing their get_attr nodes.
for expert in (
w_up_list_to_remove + w_down_list_to_remove + w_gate_list_to_remove + scales_to_remove
):
delattr(gm, expert.target)
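To make the EP bookkeeping concrete, a standalone sketch of the expert partitioning and the routing remap that accompanies it (toy numbers; `get_partition` is restated locally so the snippet runs on its own):

```python
import torch

num_experts, ep_size = 10, 4                    # 10 experts over 4 EP ranks (toy)
experts_per_rank = num_experts // ep_size       # 2; the last rank also absorbs the remainder

def get_partition(lst, world_size, rank):
    # same policy as the helper above: contiguous slice per rank, remainder to the last rank
    per = len(lst) // world_size
    start = rank * per
    end = len(lst) if rank == world_size - 1 else start + per
    return lst[start:end], lst[:start] + lst[end:]

experts = list(range(num_experts))
assert get_partition(experts, ep_size, 0)[0] == [0, 1]
assert get_partition(experts, ep_size, 3)[0] == [6, 7, 8, 9]

# Routing remap inserted before the MoE node: shift expert ids into the local range and
# zero the scales of tokens whose expert lives on another rank (out-of-range local ids
# are then ignored because their scale is zero).
rank = 1
selected_experts = torch.tensor([[0, 3], [2, 9]])
final_scales = torch.rand(2, 2)
selected_experts_local = selected_experts - experts_per_rank * rank
comp = torch.ge if rank == ep_size - 1 else torch.eq
rank_mask = comp(selected_experts // experts_per_rank, rank)
final_scales_local = final_scales * rank_mask
```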
def _slice_expert_dim(gm: GraphModule, tensor_node: Node, lo: int, hi: int) -> Node:
"""Return tensor_node[lo:hi, ...] via aten.slice along dim 0."""
@ -2652,21 +2445,7 @@ def detect_ep_shard(
for node in list(gm.graph.nodes):
if not is_any_moe_op(node):
continue
args = _canonicalize_node_args(node)
if isinstance(args[3], Node):
mlp_type = MLPType.FUSED_GATED_MLP
else:
if len(args[5]) > 0:
mlp_type = MLPType.GATED_MLP
else:
mlp_type = MLPType.MLP
if transform_container.add(
EPShardingInfo.from_node(
node,
config=config,
mlp_type=mlp_type,
)
):
if transform_container.add(EPShardingInfo.from_node(node, config=config)):
num_moe_patterns += 1
ad_logger.info(f"Found {num_moe_patterns} MoE patterns")

View File

@ -1,13 +1,12 @@
import pytest
import torch
import torch.nn.functional as F
from _torch.helpers import reference_bmm_moe_torch, reference_moe_torch
from _torch.helpers import reference_moe_torch
from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available
import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp4_global_scale
from tensorrt_llm._torch.modules.fused_moe import MoE # noqa: F401
from tensorrt_llm._torch.utils import ActivationType
def setup_moe_test(dtype, num_experts):
@ -152,62 +151,6 @@ def test_moe_op_run(dtype):
torch.testing.assert_close(output_torch_moe, ref_output, rtol=1e-5, atol=1e-5)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_bmm_based_moe_op_run(dtype):
num_experts = 3
(
x,
selected_experts,
final_scales,
fused_w3_w1_stacked_weight,
fused_w2_weight,
) = setup_bmm_moe_test(dtype, num_experts)
with torch.inference_mode():
x = final_scales * x
selected_experts = torch.ones_like(selected_experts)
# Use torch_moe with stacked tensor format (single-element lists)
output_torch_moe = torch.ops.auto_deploy.torch_moe(
x,
selected_experts,
final_scales,
[fused_w3_w1_stacked_weight], # Wrap in list for unified interface
[fused_w2_weight], # Wrap in list for unified interface
[], # Empty w3_weight list for stacked gated MLP
is_gated_mlp=True,
act_fn=ActivationType.Silu,
apply_routing_on_input=True,
)
output_torch_fused_moe = torch.ops.auto_deploy.torch_moe_fused(
x,
selected_experts,
final_scales,
fused_w3_w1_stacked_weight,
fused_w2_weight,
)
output_trt_fused_moe = torch.ops.auto_deploy.trtllm_moe_fused(
x,
selected_experts,
final_scales,
fused_w3_w1_stacked_weight,
fused_w2_weight,
)
ref_output = reference_bmm_moe_torch(
x,
selected_experts,
final_scales,
fused_w3_w1_stacked_weight,
fused_w2_weight,
apply_routing_on_input=True,
)
torch.cuda.synchronize()
torch.testing.assert_close(output_trt_fused_moe, output_torch_fused_moe, rtol=5e-2, atol=5e-2)
torch.testing.assert_close(output_trt_fused_moe, ref_output, rtol=5e-2, atol=5e-2)
torch.testing.assert_close(output_torch_fused_moe, ref_output, rtol=1e-5, atol=1e-5)
torch.testing.assert_close(output_torch_moe, ref_output, rtol=1e-5, atol=1e-5)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(not fp8_compatible(), reason="Requires fp8 support")
def test_fp8_moe_op_run(dtype):

View File

@ -0,0 +1,249 @@
"""Tests for BMM MoE checkpoint loading hooks."""
import pytest
import torch
from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe import (
_bmm_moe_down_split_hook,
_bmm_moe_gate_up_split_hook,
)
@pytest.fixture
def gate_up_stacked_weight():
"""Fixture for stacked gate_up weight in Llama4 format (E, H, 2*I)."""
num_experts = 4
hidden_size = 64
intermediate_size = 32
return torch.randn(num_experts, hidden_size, intermediate_size * 2)
@pytest.fixture
def down_stacked_weight():
"""Fixture for stacked down weight in Llama4 format (E, I, H)."""
num_experts = 4
hidden_size = 64
intermediate_size = 32
return torch.randn(num_experts, intermediate_size, hidden_size)
class TestBmmMoeGateUpSplitHook:
"""Tests for _bmm_moe_gate_up_split_hook."""
@pytest.mark.parametrize(
"num_experts,hidden_size,intermediate_size",
[
(4, 64, 32),
(8, 128, 64),
(2, 32, 16),
],
)
def test_splits_stacked_weights_into_per_expert_w1_w3(
self, num_experts, hidden_size, intermediate_size
):
"""Verify gate_up hook splits stacked weights into w1/w3 per expert."""
# Llama4 format: (E, H, 2*I)
stacked = torch.randn(num_experts, hidden_size, intermediate_size * 2)
state_dict = {"gate_up_weight": stacked}
w1_keys = [f"w1_{i}" for i in range(num_experts)]
w3_keys = [f"w3_{i}" for i in range(num_experts)]
_bmm_moe_gate_up_split_hook(
state_dict,
"",
source_key="gate_up_weight",
intermediate_size=intermediate_size,
w1_keys=w1_keys,
w3_keys=w3_keys,
)
for i in range(num_experts):
assert w1_keys[i] in state_dict
assert w3_keys[i] in state_dict
# After transpose: (I, H)
assert state_dict[w1_keys[i]].shape == (intermediate_size, hidden_size)
assert state_dict[w3_keys[i]].shape == (intermediate_size, hidden_size)
def test_w1_w3_content_matches_original_stacked(self):
"""Verify split w1/w3 tensors match the original stacked content."""
num_experts = 2
hidden_size = 32
intermediate_size = 16
stacked = torch.randn(num_experts, hidden_size, intermediate_size * 2)
state_dict = {"gate_up_weight": stacked.clone()}
w1_keys = [f"w1_{i}" for i in range(num_experts)]
w3_keys = [f"w3_{i}" for i in range(num_experts)]
_bmm_moe_gate_up_split_hook(
state_dict,
"",
source_key="gate_up_weight",
intermediate_size=intermediate_size,
w1_keys=w1_keys,
w3_keys=w3_keys,
)
for i in range(num_experts):
# w1 is first half: stacked[i, :, :intermediate_size].T
expected_w1 = stacked[i, :, :intermediate_size].transpose(0, 1).contiguous()
# w3 is second half: stacked[i, :, intermediate_size:].T
expected_w3 = stacked[i, :, intermediate_size:].transpose(0, 1).contiguous()
torch.testing.assert_close(state_dict[w1_keys[i]], expected_w1)
torch.testing.assert_close(state_dict[w3_keys[i]], expected_w3)
def test_handles_missing_source_key(self):
"""Verify hook does nothing when source key is missing."""
state_dict = {}
# Should not raise
_bmm_moe_gate_up_split_hook(
state_dict,
"",
source_key="missing_key",
intermediate_size=32,
w1_keys=["w1"],
w3_keys=["w3"],
)
assert len(state_dict) == 0
@pytest.mark.parametrize("prefix", ["", "model.layers.0.moe."])
def test_works_with_module_prefix(self, prefix):
"""Verify hook works correctly with module path prefix."""
num_experts = 2
hidden_size = 32
intermediate_size = 16
stacked = torch.randn(num_experts, hidden_size, intermediate_size * 2)
state_dict = {f"{prefix}gate_up_weight": stacked}
w1_keys = [f"w1_{i}" for i in range(num_experts)]
w3_keys = [f"w3_{i}" for i in range(num_experts)]
_bmm_moe_gate_up_split_hook(
state_dict,
prefix,
source_key="gate_up_weight",
intermediate_size=intermediate_size,
w1_keys=w1_keys,
w3_keys=w3_keys,
)
for i in range(num_experts):
assert f"{prefix}{w1_keys[i]}" in state_dict
assert f"{prefix}{w3_keys[i]}" in state_dict
class TestBmmMoeDownSplitHook:
"""Tests for _bmm_moe_down_split_hook."""
@pytest.mark.parametrize(
"num_experts,hidden_size,intermediate_size",
[
(4, 64, 32),
(8, 128, 64),
(2, 32, 16),
],
)
def test_splits_stacked_weights_into_per_expert_w2(
self, num_experts, hidden_size, intermediate_size
):
"""Verify down hook splits stacked weights into w2 per expert."""
# Llama4 format: (E, I, H)
stacked = torch.randn(num_experts, intermediate_size, hidden_size)
state_dict = {"down_weight": stacked}
w2_keys = [f"w2_{i}" for i in range(num_experts)]
_bmm_moe_down_split_hook(
state_dict,
"",
source_key="down_weight",
w2_keys=w2_keys,
)
for i in range(num_experts):
assert w2_keys[i] in state_dict
# After transpose: (H, I)
assert state_dict[w2_keys[i]].shape == (hidden_size, intermediate_size)
def test_w2_content_matches_original_stacked(self):
"""Verify split w2 tensors match the original stacked content."""
num_experts = 2
hidden_size = 32
intermediate_size = 16
stacked = torch.randn(num_experts, intermediate_size, hidden_size)
state_dict = {"down_weight": stacked.clone()}
w2_keys = [f"w2_{i}" for i in range(num_experts)]
_bmm_moe_down_split_hook(
state_dict,
"",
source_key="down_weight",
w2_keys=w2_keys,
)
for i in range(num_experts):
expected_w2 = stacked[i].transpose(0, 1).contiguous()
torch.testing.assert_close(state_dict[w2_keys[i]], expected_w2)
def test_handles_missing_source_key(self):
"""Verify hook does nothing when source key is missing."""
state_dict = {}
_bmm_moe_down_split_hook(
state_dict,
"",
source_key="missing_key",
w2_keys=["w2"],
)
assert len(state_dict) == 0
class TestBmmMoeHooksIntegration:
"""Integration tests for BMM MoE hooks working together."""
def test_full_checkpoint_loading_flow(self):
"""Test the full flow: split gate_up and down into per-expert weights."""
num_experts = 4
hidden_size = 64
intermediate_size = 32
# Simulate a checkpoint with stacked weights
gate_up_stacked = torch.randn(num_experts, hidden_size, intermediate_size * 2)
down_stacked = torch.randn(num_experts, intermediate_size, hidden_size)
state_dict = {
"gate_up_weight": gate_up_stacked.clone(),
"down_weight": down_stacked.clone(),
}
w1_keys = [f"w1_{i}" for i in range(num_experts)]
w2_keys = [f"w2_{i}" for i in range(num_experts)]
w3_keys = [f"w3_{i}" for i in range(num_experts)]
# Step 1: Split gate_up into w1 and w3
_bmm_moe_gate_up_split_hook(
state_dict,
"",
source_key="gate_up_weight",
intermediate_size=intermediate_size,
w1_keys=w1_keys,
w3_keys=w3_keys,
)
# Step 2: Split down into w2
_bmm_moe_down_split_hook(
state_dict,
"",
source_key="down_weight",
w2_keys=w2_keys,
)
# Verify: all per-expert weights present with correct shapes
for i in range(num_experts):
assert state_dict[w1_keys[i]].shape == (intermediate_size, hidden_size)
assert state_dict[w2_keys[i]].shape == (hidden_size, intermediate_size)
assert state_dict[w3_keys[i]].shape == (intermediate_size, hidden_size)