From 393c3d259e4afa77ed0f66118053226c67003376 Mon Sep 17 00:00:00 2001 From: Bala Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 14:22:32 -0800 Subject: [PATCH] [#10245][feat] AutoDeploy: Add Minimax M2 support (#10525) Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> --- .../custom_ops/flashinfer_attention.py | 6 +- .../_torch/auto_deploy/custom_ops/rms_norm.py | 63 +++ .../auto_deploy/models/patches/minimax_m2.py | 78 ++++ .../auto_deploy/transform/library/sharding.py | 375 ++++++++++++++++-- .../_torch/auto_deploy/utils/node_utils.py | 4 +- .../custom_ops/test_sharded_rmsnorm.py | 171 ++++++++ .../library/test_rmsnorm_sharding.py | 325 +++++++++++++++ .../models/test_minimax_m2_patches.py | 138 +++++++ 8 files changed, 1116 insertions(+), 44 deletions(-) create mode 100644 tensorrt_llm/_torch/auto_deploy/models/patches/minimax_m2.py create mode 100644 tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_sharded_rmsnorm.py create mode 100644 tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_rmsnorm_sharding.py create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_minimax_m2_patches.py diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index 7d43dc6cea..eb6ed39fef 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -333,9 +333,9 @@ def flashinfer_mha_with_cache( q_shape_og = q.shape b, s = q_shape_og[:2] - q = q.reshape(b * s, -1, head_dim) - k = k.reshape(b * s, -1, head_dim) - v = v.reshape(b * s, -1, head_dim) + q = q.reshape(b * s, -1, head_dim).contiguous() + k = k.reshape(b * s, -1, head_dim).contiguous() + v = v.reshape(b * s, -1, head_dim).contiguous() # convert to flashinfer-style metadata num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py index 39792a14fa..2f85d87518 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py @@ -2,6 +2,7 @@ import flashinfer import torch +import torch.distributed as dist import torch.nn.functional as F from einops import rearrange @@ -252,3 +253,65 @@ def gated_rms_norm_ref( if z is not None and norm_before_gate: out *= F.silu(z) return out.to(dtype) + + +# ============================================================================= +# Sharded RMSNorm (for sharded activations) +# ============================================================================= + + +@torch.library.custom_op("auto_deploy::sharded_rmsnorm", mutates_args=()) +def sharded_rmsnorm( + input: torch.Tensor, weight: torch.Tensor, eps: float, world_size: int +) -> torch.Tensor: + """RMSNorm for sharded activations that need global reduction. + + When activations are sharded (split along the last dimension across devices), + standard RMSNorm computes an incorrect local mean. This op uses all_reduce to compute + the global mean of squared values across all shards, ensuring correct normalization. + + The computation is: + 1. Compute local sum of squares: sum(input^2) over local features + 2. All-reduce to get global sum of squares across all shards + 3. Compute global mean: global_sum / (local_dim * world_size) + 4. 
Normalize: input * rsqrt(global_mean + eps) + 5. Scale with local weight (weight is also column-sharded) + + Args: + input: Input tensor, shape [..., local_hidden_size] where local_hidden_size + is the shard of the full hidden dimension on this device. + weight: Scaling weights, shape [local_hidden_size] (column-sharded). + eps: Small constant for numerical stability. + world_size: Number of devices across which the activation is sharded. + + Returns: + Normalized and scaled tensor with same shape as input. + """ + local_dim = input.shape[-1] + + # Cast to float32 for precision + input_fp32 = input.to(torch.float32) + + # Compute local sum of squares (NOT mean - we need sum for all_reduce) + local_sum_sq = input_fp32.pow(2).sum(-1, keepdim=True) + + # All-reduce to get global sum of squares + global_sum_sq = local_sum_sq.clone() + dist.all_reduce(global_sum_sq, op=dist.ReduceOp.SUM) + + # Compute global mean: global_sum / total_elements + global_count = local_dim * world_size + global_mean_sq = global_sum_sq / global_count + + # Normalize + input_normalized = input_fp32 * torch.rsqrt(global_mean_sq + eps) + + # Apply weight (local weight since it's also column-sharded) + out = weight * input_normalized.to(input.dtype) + return out + + +@sharded_rmsnorm.register_fake +def _(input: torch.Tensor, weight: torch.Tensor, eps: float, world_size: int) -> torch.Tensor: + """Fake implementation for tracing.""" + return torch.empty_like(input) diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/minimax_m2.py b/tensorrt_llm/_torch/auto_deploy/models/patches/minimax_m2.py new file mode 100644 index 0000000000..85ae69a5c1 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/minimax_m2.py @@ -0,0 +1,78 @@ +"""A patch for MiniMax-M2 MoE to make it compatible with torch.export. + +MiniMax-M2 is loaded from HuggingFace hub (trust_remote_code), so we cannot +import MiniMaxM2SparseMoeBlock directly. Instead, we use the same pattern as +DeepSeek: patching AutoModelForCausalLM.from_config to iterate over modules +and patch by class name. +""" + +import types +from typing import Dict + +import torch +from transformers import AutoModelForCausalLM + + +def minimax_m2_moe(self, hidden_states: torch.Tensor): + """MiniMaxM2SparseMoeBlock forward function rewritten to enable torch.export. 
+ + Key differences from Mixtral: + - Uses sigmoid activation for routing (not softmax) + - Has e_score_correction_bias added for expert selection only + - Gathers original sigmoid weights after topk selection + """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + if self.training and self.jitter_noise > 0: + hidden_states *= torch.empty_like(hidden_states).uniform_( + 1.0 - self.jitter_noise, 1.0 + self.jitter_noise + ) + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + # MiniMax-M2 specific routing: + # Step 1: Sigmoid activation (not softmax like Mixtral) + routing_weights = torch.sigmoid(router_logits.float()) + + # Step 2: Add bias for expert selection only + scores_for_choice = routing_weights + self.e_score_correction_bias + + # Step 3: Select top-k experts based on biased scores + _, selected_experts = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False) + + # Step 4: Gather ORIGINAL sigmoid weights (not biased scores) + top_k_weights = routing_weights.gather(1, selected_experts) + + # Step 5: Normalize so weights sum to 1 + top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True) + top_k_weights = top_k_weights.to(hidden_states.dtype) + + final_hidden_states = torch.ops.auto_deploy.torch_moe( + hidden_states, + selected_experts, + top_k_weights, + w1_weight=[expert.w1.weight for expert in self.experts], # gate projection + w2_weight=[expert.w2.weight for expert in self.experts], # down projection + w3_weight=[expert.w3.weight for expert in self.experts], # up projection + ) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +_from_config_previous = AutoModelForCausalLM.from_config + +CUSTOM_MODULE_PATCHES: Dict[str, callable] = {"MiniMaxM2SparseMoeBlock": minimax_m2_moe} + + +def get_model_from_config_patched(config, **kwargs): + model = _from_config_previous(config, **kwargs) + # Patch modules by class name + for _, module in model.named_modules(): + if type(module).__name__ in CUSTOM_MODULE_PATCHES: + module.forward = types.MethodType(CUSTOM_MODULE_PATCHES[type(module).__name__], module) + + return model + + +# Patch AutoModelForCausalLM.from_config +AutoModelForCausalLM.from_config = get_model_from_config_patched diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index e14e71f218..1ecf85aeda 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -691,6 +691,77 @@ def _resolve_ep_cls_from_node(node: Node) -> type[EPShardingInfo]: return EPShardingInfo +class RMSNormShardingInfo(ShardingTransformInfo): + """Configuration for replacing RMSNorm with sharded version. + + When RMSNorm (torch_rmsnorm) operates on sharded activations with + weight shape [num_heads * head_dim] (full hidden size), it needs to be + replaced with sharded_rmsnorm which uses all_reduce to compute + the correct global mean across shards. + + The detection of whether an RMSNorm needs this treatment is done in + _shard_qk_norm based on weight shape matching q/k projection + output dimensions. 
+ """ + + world_size: int + + @classmethod + def from_node(cls, node: Node, **kwargs) -> "RMSNormShardingInfo": + """Create a RMSNormShardingInfo from a node.""" + return cls(target_node=node.name, **kwargs) + + def validate(self, gm: GraphModule = None, node: Node = None) -> bool: + """Validate that the node is a torch_rmsnorm op.""" + if not is_op(node, torch.ops.auto_deploy.torch_rmsnorm): + ad_logger.debug( + f"RMSNormShardingInfo only applies to torch_rmsnorm ops. " + f"Got {node.target}. Skipping." + ) + return False + return True + + def apply(self, gm: GraphModule, node: Node) -> None: + """Replace torch_rmsnorm with sharded_rmsnorm and shard the weight. + + This handles both: + 1. Weight sharding: Shard the RMSNorm weight across ranks + 2. Op replacement: Replace with sharded_rmsnorm which uses all_reduce + for computing the correct global mean across shards. + """ + # Get original arguments: (input, weight, eps) + input_node = node.args[0] + weight_node = node.args[1] + eps = node.args[2] + + # Shard the weight parameter (column-wise for 1D RMSNorm weight) + weight_key = weight_node.target + weight_tensor = gm.get_parameter(weight_key) + shard_weight_tensor( + gm=gm, + weight_tensor=weight_tensor, + param_key=weight_key, + dim=0, # Column shard for 1D weight + rank=self.config.rank, + world_size=self.world_size, + ) + ad_logger.debug(f"Sharded RMSNorm weight: {weight_key}") + + # Insert the new node with world_size parameter + with gm.graph.inserting_after(node): + new_node = gm.graph.call_function( + torch.ops.auto_deploy.sharded_rmsnorm.default, + args=(input_node, weight_node, eps, self.world_size), + ) + + # Replace all uses with the new node and remove the old node + node.replace_all_uses_with(new_node) + gm.graph.erase_node(node) + ad_logger.debug( + f"Replaced torch_rmsnorm with sharded_rmsnorm (world_size={self.world_size})" + ) + + ######################################################## # Transform API classes ######################################################## @@ -828,6 +899,9 @@ class ShardingTransformExecutor(BaseTransform): for ep_transform in transforms.ep_transforms: if check_and_apply(ep_transform): num_matches += 1 + for rmsnorm_transform in transforms.rmsnorm_transforms: + if check_and_apply(rmsnorm_transform): + num_matches += 1 # post-sharding cleanup transformations for update_transform in transforms.parameter_update_transforms: @@ -851,6 +925,7 @@ class ShardingTransformContainer(BaseModel): parameter_update_transforms: List[ParameterUpdateInfo] = Field(default_factory=list) bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list) ep_transforms: List[EPShardingInfo] = Field(default_factory=list) + rmsnorm_transforms: List[RMSNormShardingInfo] = Field(default_factory=list) def __init__(self, **kwargs): super().__init__(**kwargs) @@ -859,6 +934,7 @@ class ShardingTransformContainer(BaseModel): BMMShardingInfo: self.bmm_transforms, EPShardingInfo: self.ep_transforms, ParameterUpdateInfo: self.parameter_update_transforms, + RMSNormShardingInfo: self.rmsnorm_transforms, } def add(self, transform: ShardingTransformInfo) -> bool: @@ -1299,14 +1375,30 @@ def _shard_parameter_node( def _update_node_args(node: Node, args: tuple) -> None: - """Update the node's arguments with the new sharded arguments.""" + """Update the node's arguments with the new sharded arguments. + + For Node args: preserve the current value (may have been updated by other transforms). + For non-Node args (shapes, sizes, indices): use the stored sharded value. 
+ + This prevents ParameterUpdateInfo from reverting Node references that were + intentionally updated by other transforms. + """ if "sharded" in node.meta and node.meta["sharded"]: return - node.args = args + + # Build new args: preserve current Node refs, apply stored non-Node values + new_args = [] + for i, stored_arg in enumerate(args): + if isinstance(stored_arg, Node): + # Node args: preserve current value (may have been updated by other transforms) + new_args.append(node.args[i]) + else: + # Non-Node args (shapes, sizes, indices): use stored sharded value + new_args.append(stored_arg) + + node.args = tuple(new_args) node.meta["sharded"] = True - ad_logger.debug( - f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}." - ) + ad_logger.debug(f"Updated node {node}: sharded arguments are now {node.args}.") def _insert_sharded_moe( @@ -1838,43 +1930,242 @@ def _determine_fused_weight_dims( return None linear_node = linear_nodes[0] fused_weight_dims = None - # check if there are split nodes in the subgraph. They may indicate fused weights (e.g., QKV) - split_nodes = list(filtered_nodes(linear_node.users, ops=[torch.ops.aten.split_with_sizes])) - if len(split_nodes) > 0: - assert len(linear_nodes) == 1 + if len(linear_nodes) == 1: linear_node = linear_nodes[0] - assert len(split_nodes) == 1, "Expecting exactly one split node for fused weights" - fused_weight_dims = split_nodes[0].args[1] + # check if there are split nodes in the subgraph. They may indicate fused weights (e.g., QKV) + linear_split_users = list( + filtered_nodes(linear_node.users, ops=torch.ops.aten.split_with_sizes) + ) + linear_slice_users = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.slice)) + linear_chunk_users = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.chunk)) + if len(linear_split_users) > 0: + assert len(linear_split_users) == 1, ( + "Expecting exactly one split node for fused weights" + ) + fused_weight_dims = linear_split_users[0].args[1] - slice_nodes = list(filtered_nodes(linear_node.users, ops=[torch.ops.aten.slice])) - if len(slice_nodes) > 0: - # we are probably in fused QKV case with single linear node and 3 slice nodes - assert len(linear_nodes) == 1 - linear_node = linear_nodes[0] - assert all( - s.args[1] == 2 for s in filtered_nodes(linear_node.users, ops=torch.ops.aten.slice) - ), "Expecting slice nodes to slice tensor over dim=2" - fused_weight_dims = [s.args[3] - s.args[2] for s in linear_node.users] - weight_dim = shape(linear_node)[2] - if sum(fused_weight_dims) != weight_dim: - if fused_weight_dims[-1] > weight_dim: - fused_weight_dims[-1] = weight_dim - sum(fused_weight_dims[:-1]) - else: - ad_logger.warning( - f"Fused weight dims {fused_weight_dims} do not sum to weight dim {weight_dim}. Skipping." 
- ) - return None - chunk_nodes = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.chunk)) - if len(chunk_nodes) > 0: - assert len(linear_nodes) == 1 - linear_node = linear_nodes[0] - assert len(chunk_nodes) == 1, "Expecting exactly one chunk node for fused weights" - num_chunks = chunk_nodes[0].args[1] - weight_dim = shape(linear_node)[2] - fused_weight_dims = [weight_dim // num_chunks] * num_chunks - if fused_weight_dims is not None: - fused_weight_dims = tuple(fused_weight_dims) - return fused_weight_dims + elif len(linear_slice_users) > 0: + # we are probably in fused QKV case with single linear node and 3 slice nodes + assert all(s.args[1] == 2 for s in linear_slice_users), ( + "Expecting slice nodes to slice tensor over dim=2" + ) + fused_weight_dims = [s.args[3] - s.args[2] for s in linear_slice_users] + assert fused_weight_dims, "fused weight dims cannot be empty" + weight_dim = linear_node.meta["val"].shape[2] + if sum(fused_weight_dims) != weight_dim: + if fused_weight_dims[-1] > weight_dim: + fused_weight_dims[-1] = weight_dim - sum(fused_weight_dims[:-1]) + else: + ad_logger.warning( + f"Fused weight dims {fused_weight_dims} do not sum to weight dim {weight_dim}. Skipping." + ) + return + + elif len(linear_chunk_users) > 0: + assert len(linear_chunk_users) == 1, ( + "Expecting exactly one chunk node for fused weights" + ) + num_chunks = linear_chunk_users[0].args[1] + weight_dim = linear_node.meta["val"].shape[2] + fused_weight_dims = [weight_dim // num_chunks] * num_chunks + + +def _find_upstream_qk_proj(node: Node, gm: GraphModule) -> Optional[str]: + """ + Find the upstream q/k projection linear node from an RMSNorm input. + + Traverses backwards through pass-through tensor operations (view, reshape, + type conversions, quantize/dequantize, etc.) to find the immediate producer. + If that producer is a q_proj or k_proj linear, returns the weight name. + + Args: + node: The input node to start traversing from (typically RMSNorm's activation input) + gm: The graph module containing the nodes + + Returns: + The weight name if a q_proj or k_proj is found as immediate upstream, None otherwise. + + TODO: Is there a more efficient way to do this? 
+ """ + # Pass-through ops that we traverse through (these don't change the semantic meaning) + passthrough_ops = [ + torch.ops.aten.view, + torch.ops.aten.reshape, + torch.ops.aten.contiguous, + torch.ops.aten.clone, + torch.ops.aten.to, + torch.ops.aten._to_copy, + torch.ops.aten.slice, + torch.ops.aten.transpose, + torch.ops.aten.permute, + ] + + visited = set() + current = node + + # Traverse through pass-through ops to find the actual producer + while current is not None and current not in visited: + visited.add(current) + + # Check if this is a linear operation (the producer we're looking for) + if is_any_lin_op(current): + try: + weight_name = extract_weight_name(current) + if weight_name and ("q_proj" in weight_name or "k_proj" in weight_name): + return weight_name + except (AttributeError, AssertionError): + pass + # Found a linear but not q/k proj - this is not a QK norm + return None + + # If this is a pass-through op, continue to its first input + if is_op(current, passthrough_ops): + # Get the first tensor input (skip non-tensor args like dims) + tensor_inputs = [arg for arg in current.all_input_nodes if isinstance(arg, Node)] + if tensor_inputs: + current = tensor_inputs[0] + continue + + # Hit a non-passthrough, non-linear op - stop searching + break + + return None + + +def _shard_qk_norm( + layer_subgraph: LayerSubgraph, + linear_nodes: List[Node], + transform_container: ShardingTransformContainer, +) -> int: + """ + Shard RMSNorm ops in attention layers that operate on full hidden size. + + This function detects torch_rmsnorm ops that are true QK norms - i.e., norms that + operate on the output of q_proj or k_proj. These global QK norms operate on + flattened Q/K output [batch, seq, hidden_size] before reshape and need special + handling for tensor parallelism: + 1. Weight sharding: The norm weight is column-sharded across ranks + 2. Global mean: Replaced with sharded_rmsnorm which uses all_reduce + + Detection criteria: + 1. The RMSNorm's input must trace back to a q_proj or k_proj linear operation + 2. The weight shape must match q/k projection output dimensions + + Example1: - Global Norm on all heads directly on flattened Q/K output (e.g. MiniMax): + self.q_norm = MiniMaxM2RMSNorm(self.head_dim * config.num_attention_heads, eps=config.rms_norm_eps) + weight shape: [num_heads * head_dim] -> matches valid_sizes -> will be sharded + Status: Needs sharding (weight matches q projection output dim) + + Example2: - Norm per head after reshaping to [batch, seq, num_heads, head_dim] (e.g. 
GLM 4.7): + self.q_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + weight shape: [head_dim] -> does NOT match valid_sizes -> skipped + Status: No sharding needed (weight doesn't match q/k output dims) + + Args: + layer_subgraph: The attention layer subgraph + linear_nodes: The linear nodes (q/k/v projections) + transform_container: Container for sharding transformations + + Returns: + Number of nodes added for sharding + """ + if layer_subgraph.layer_type != LayerType.ATTENTION or layer_subgraph.terminating_node is None: + return 0 + + config = transform_container.config + world_size = config.world_size + gm = linear_nodes[0].graph.owning_module + added_nodes = 0 + + # Collect valid output dimensions from q/k/v projections + # These are the only valid sizes for intermediate weights (e.g., q_norm, k_norm) + valid_sizes = set() + for lin_node in linear_nodes: + try: + wkey = extract_weight_name(lin_node) + w = gm.get_parameter(wkey) + valid_sizes.add(w.shape[0]) # q: num_heads*head_dim, k/v: num_kv_heads*head_dim + except (AttributeError, AssertionError): + pass + + # Find all intermediate weight nodes between q/k/v projections and o_proj. + intermediate_weight_nodes = subgraph( + sources=linear_nodes, + sinks=[layer_subgraph.terminating_node], + include=lambda n: n.op == "get_attr", + ) + + for weight_node in intermediate_weight_nodes: + weight_key = weight_node.target + + # First check: is this weight consumed by a torch_rmsnorm op + if len(list(weight_node.users)) == 0: + continue + + user_node = list(weight_node.users)[0] + if not is_op(user_node, torch.ops.auto_deploy.torch_rmsnorm): + continue + + # Verify this is a true QK norm by checking its input traces back to q_proj or k_proj + # This filters out input_layernorm which feeds INTO q/k/v projections (not after them) + rmsnorm_input = user_node.args[0] # activation input (not the weight) + upstream_proj = _find_upstream_qk_proj(rmsnorm_input, gm) + if upstream_proj is None: + ad_logger.debug( + f"Skipping {user_node.name} - input does not trace back to q_proj or k_proj " + f"(likely input_layernorm or other non-QK norm)" + ) + continue + + ad_logger.debug(f"Found QK norm {user_node.name} with upstream projection: {upstream_proj}") + + # Try to get the parameter + try: + param = gm.get_parameter(weight_key) + except AttributeError: + ad_logger.debug(f"Could not get parameter for {weight_key}, skipping") + continue + + # Only shard 1D weights (RMSNorm weights are always 1D) + if param.dim() != 1: + ad_logger.debug( + f"Skipping intermediate weight {weight_key} with dim={param.dim()} (not 1D)" + ) + continue + + # Check if the weight size is divisible by world_size + if param.shape[0] % world_size != 0: + ad_logger.debug( + f"Skipping intermediate weight {weight_key} with shape {param.shape} " + f"(not divisible by world_size={world_size})" + ) + continue + + # Check if weight size matches one of the q/k/v projection output dimensions + # This filters out per-head norms and unrelated weights like inv_freq + if valid_sizes and param.shape[0] not in valid_sizes: + ad_logger.debug( + f"Skipping intermediate weight {weight_key} with shape {param.shape} " + f"(not in valid_sizes={valid_sizes}, likely per-head norm)" + ) + continue + + # Add RMSNormShardingInfo to replace with sharded_rmsnorm + # This handles both weight sharding and op replacement in one transform + if transform_container.add( + RMSNormShardingInfo.from_node( + user_node, + config=config, + world_size=world_size, + ) + ): + ad_logger.debug( + f"Added 
RMSNormShardingInfo for {user_node.name} " + f"(will replace with sharded_rmsnorm for global mean)" + ) + added_nodes += 1 + + return added_nodes def _process_column_sharding( @@ -1955,6 +2246,10 @@ def _process_column_sharding( ) ) # chunk nodes do not need to be updated + + # Shard intermediate weights (e.g. q/k/v -> q_norm, k_norm ... -> o_proj) for attention layers + added_nodes += _shard_qk_norm(layer_subgraph, linear_nodes, transform_container) + return added_nodes diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 83cc46309b..78755d353d 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -967,7 +967,9 @@ def get_layer_after_linear_node( min_local_shape=min_local_shape, ) assert linear_nodes[start_lin_index] in opening_linear_nodes, ( - "Linear node not found in opening linear nodes" + f"Linear node not found in opening linear nodes - " + f"terminating_linear_node:{terminating_linear_node.name}, " + f"opening_linear_nodes: {[n.name for n in opening_linear_nodes]}" ) # return the index of the terminating linear node diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_sharded_rmsnorm.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_sharded_rmsnorm.py new file mode 100644 index 0000000000..3ce97f88f3 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_sharded_rmsnorm.py @@ -0,0 +1,171 @@ +"""Functional tests for sharded_rmsnorm custom op. + +This module tests that the sharded_rmsnorm op produces correct +numerical results when run across multiple GPUs with column-sharded inputs. + +The op computes RMSNorm on column-sharded activations by using all_reduce to +compute the global mean of squared values across all shards, ensuring correct +normalization equivalent to non-sharded global RMSNorm. +""" + +import pickle +import sys +import traceback + +import cloudpickle +import pytest +import torch +import torch.distributed as dist +from mpi4py import MPI + +from tensorrt_llm._torch.auto_deploy.distributed.common import initialize, is_initialized + +# Register this module for cloudpickle serialization for MPI workers +cloudpickle.register_pickle_by_value(sys.modules[__name__]) +MPI.pickle.__init__( + cloudpickle.dumps, + cloudpickle.loads, + pickle.HIGHEST_PROTOCOL, +) + +# needed since we reuse the mpi executor pool, first test running will leak a thread +pytestmark = pytest.mark.threadleak(enabled=False) + + +def _reference_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: + """Reference RMSNorm implementation for comparison.""" + input_fp32 = input.to(torch.float32) + variance = input_fp32.pow(2).mean(-1, keepdim=True) + input_normalized = input_fp32 * torch.rsqrt(variance + eps) + return weight * input_normalized.to(input.dtype) + + +def _run_sharded_rmsnorm_test( + tensor_parallel_size: int, + batch_size: int = 2, + seq_len: int = 8, + hidden_size: int = 64, + eps: float = 1e-6, + dtype_str: str = "float16", +): + """Test that sharded_rmsnorm matches non-sharded global RMSNorm. + + Each rank: + 1. Creates identical full input and weight tensors (same seed) + 2. Computes reference result using non-sharded RMSNorm + 3. Column-shards input and weight for this rank + 4. Calls sharded_rmsnorm on local shard + 5. 
All-gathers results and compares with reference + """ + # Import inside worker to avoid cloudpickle serialization issues with torch.ops + import tensorrt_llm + import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 + + rank = tensorrt_llm.mpi_rank() + world_size = tensor_parallel_size + torch.cuda.set_device(rank) + + if not is_initialized(): + initialize(rank, port=29500) + + # Map string to torch.dtype inside worker to avoid cloudpickle serialization issues + dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16} + dtype = dtype_map[dtype_str] + + try: + torch.manual_seed(42) # Same seed for reproducibility across ranks + + # Full tensors (same on all ranks due to seed) + full_input = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=dtype) + full_weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # Reference: non-sharded global RMSNorm + reference_output = _reference_rmsnorm(full_input, full_weight, eps) + + # Column shard: split hidden_size across ranks + local_hidden_size = hidden_size // world_size + start_idx = rank * local_hidden_size + end_idx = start_idx + local_hidden_size + + local_input = full_input[..., start_idx:end_idx].contiguous() + local_weight = full_weight[start_idx:end_idx].contiguous() + + # Call sharded_rmsnorm using dynamic getattr to avoid cloudpickle capturing torch.ops + sharded_rmsnorm_op = getattr(getattr(torch, "ops"), "auto_deploy").sharded_rmsnorm + local_output = sharded_rmsnorm_op(local_input, local_weight, eps, world_size) + + # All-gather to reconstruct full output + gathered_outputs = [torch.zeros_like(local_output) for _ in range(world_size)] + dist.all_gather(gathered_outputs, local_output) + reconstructed_output = torch.cat(gathered_outputs, dim=-1) + + # Compare with reference + torch.testing.assert_close( + reconstructed_output, + reference_output, + atol=1e-2, + rtol=1e-2, + msg=f"sharded_rmsnorm result doesn't match reference (rank={rank})", + ) + except Exception: + traceback.print_exc() + raise + + return True + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs for this test") +@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True) +def test_sharded_rmsnorm_functional(mpi_pool_executor): + """Functional test: verify sharded_rmsnorm produces correct numerical results.""" + torch.manual_seed(0) + tensor_parallel_size = mpi_pool_executor.num_workers + + results = mpi_pool_executor.map( + _run_sharded_rmsnorm_test, + *zip(*[(tensor_parallel_size,)] * tensor_parallel_size), + ) + for r in results: + assert r is True + + +def _run_sharded_rmsnorm_hidden_size_test(tensor_parallel_size: int, hidden_size: int): + """Worker function for hidden size parametrized test.""" + return _run_sharded_rmsnorm_test(tensor_parallel_size, hidden_size=hidden_size) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs for this test") +@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True) +@pytest.mark.parametrize("hidden_size", [64, 128, 256]) +def test_sharded_rmsnorm_different_hidden_sizes(mpi_pool_executor, hidden_size): + """Test sharded_rmsnorm with different hidden sizes.""" + torch.manual_seed(0) + tensor_parallel_size = mpi_pool_executor.num_workers + + results = mpi_pool_executor.map( + _run_sharded_rmsnorm_hidden_size_test, + *zip(*[(tensor_parallel_size, hidden_size)] * tensor_parallel_size), + ) + for r in results: + assert r is True + + +def _run_sharded_rmsnorm_dtype_test(tensor_parallel_size: int, dtype_str: 
str): + """Worker function for dtype parametrized test.""" + return _run_sharded_rmsnorm_test(tensor_parallel_size, dtype_str=dtype_str) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs for this test") +@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True) +@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) +def test_sharded_rmsnorm_different_dtypes(mpi_pool_executor, dtype_str): + """Test sharded_rmsnorm with different dtypes.""" + torch.manual_seed(0) + tensor_parallel_size = mpi_pool_executor.num_workers + + results = mpi_pool_executor.map( + _run_sharded_rmsnorm_dtype_test, + *zip(*[(tensor_parallel_size, dtype_str)] * tensor_parallel_size), + ) + for r in results: + assert r is True diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_rmsnorm_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_rmsnorm_sharding.py new file mode 100644 index 0000000000..6248174516 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_rmsnorm_sharding.py @@ -0,0 +1,325 @@ +"""Tests for RMSNorm sharding detection and transformation. + +This module tests that the sharding transform correctly detects and transforms +torch_rmsnorm ops based on weight shape and position in the graph: + +1. Full hidden norm (weight shape = [num_heads * head_dim], AFTER q/k projection): + - Detected as QK norm needing sharding → replaced with sharded_rmsnorm + - Weight is sharded + +2. Per-head norm (weight shape = [head_dim], like GLM): + - NOT detected as needing sharding → stays as local torch_rmsnorm + - No transformation needed + +3. Input layernorm (feeds INTO q/k/v projections, not after): + - NOT detected as QK norm → stays as local torch_rmsnorm + - Even though weight shape matches, it's not a QK norm + +These are graph-level unit tests that verify the transform logic. +""" + +import torch +import torch.nn as nn + +# Ensure custom ops are registered +from tensorrt_llm._torch.auto_deploy.custom_ops import rms_norm # noqa: F401 +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op + + +class AttentionWithInputLayernorm(nn.Module): + """Attention module with input_layernorm BEFORE q/k/v projections. + + This simulates the standard Llama pattern where input_layernorm feeds + INTO the attention projections. This should NOT be detected as a QK norm. 
+ """ + + def __init__(self, hidden_size: int = 64, num_heads: int = 4): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + # Input layernorm - BEFORE q/k/v projections (like Llama) + self.input_layernorm_weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = 1e-6 + + # Q/K/V/O projections + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, s, _ = x.shape + + # Input layernorm BEFORE projections (Llama pattern) + x_normed = torch.ops.auto_deploy.torch_rmsnorm(x, self.input_layernorm_weight, self.eps) + + q = self.q_proj(x_normed) + k = self.k_proj(x_normed) + v = self.v_proj(x_normed) + + # Reshape for attention + q = q.view(b, s, self.num_heads, self.head_dim) + k = k.view(b, s, self.num_heads, self.head_dim) + v = v.view(b, s, self.num_heads, self.head_dim) + + y = torch.ops.auto_deploy.torch_attention(q, k, v, is_causal=True, layout="bsnd") + y = y.contiguous().view(b, s, -1) + + return self.o_proj(y) + + +class SimpleAttentionWithQKNorm(nn.Module): + """Attention module with configurable QK normalization. + + Args: + hidden_size: Total hidden dimension + num_heads: Number of attention heads + use_full_hidden_norm: If True, use full hidden norm (like MiniMax) + - Weight shape = [hidden_size], norm before reshape + - Should be transformed to sharded_rmsnorm + If False, use per-head norm (like GLM) + - Weight shape = [head_dim], norm after reshape + - Should NOT be transformed + """ + + def __init__( + self, hidden_size: int = 64, num_heads: int = 4, use_full_hidden_norm: bool = True + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.use_full_hidden_norm = use_full_hidden_norm + + # Q/K/V/O projections + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False) + + # QK norm weights: shape depends on norm type + norm_dim = hidden_size if use_full_hidden_norm else self.head_dim + self.q_norm_weight = nn.Parameter(torch.ones(norm_dim)) + self.k_norm_weight = nn.Parameter(torch.ones(norm_dim)) + self.eps = 1e-6 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, s, _ = x.shape + + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + if self.use_full_hidden_norm: + # Full hidden norm: apply before reshape + q = torch.ops.auto_deploy.torch_rmsnorm(q, self.q_norm_weight, self.eps) + k = torch.ops.auto_deploy.torch_rmsnorm(k, self.k_norm_weight, self.eps) + # Reshape for attention + q = q.view(b, s, self.num_heads, self.head_dim) + k = k.view(b, s, self.num_heads, self.head_dim) + v = v.view(b, s, self.num_heads, self.head_dim) + else: + # Reshape first for per-head norm + q = q.view(b, s, self.num_heads, self.head_dim) + k = k.view(b, s, self.num_heads, self.head_dim) + v = v.view(b, s, self.num_heads, self.head_dim) + # Per-head norm: apply after reshape (broadcasts over heads) + q = torch.ops.auto_deploy.torch_rmsnorm(q, self.q_norm_weight, self.eps) + k = torch.ops.auto_deploy.torch_rmsnorm(k, self.k_norm_weight, self.eps) + + y = 
torch.ops.auto_deploy.torch_attention(q, k, v, is_causal=True, layout="bsnd") + y = y.contiguous().view(b, s, -1) + + return self.o_proj(y) + + +def count_ops(gm, op) -> int: + """Count the number of nodes with a specific op in the graph.""" + count = 0 + for node in gm.graph.nodes: + if node.op == "call_function" and is_op(node, op): + count += 1 + return count + + +# ============================================================================= +# Tests +# ============================================================================= + + +class TestRMSNormShardingTransform: + """Tests for the sharding transform on RMSNorm ops.""" + + def test_full_hidden_norm_transformed_to_sharded(self): + """Test that full hidden norm RMSNorm ops are replaced with sharded_rmsnorm. + + When weight shape = [hidden_size] (matches q_proj output dim): + - torch_rmsnorm should be replaced with sharded_rmsnorm + - Weight should be sharded + """ + model = SimpleAttentionWithQKNorm(use_full_hidden_norm=True).to("cuda", dtype=torch.float16) + x = torch.randn(1, 8, 64, device="cuda", dtype=torch.float16) + + # Export to graph + gm = torch_export_to_gm(model, args=(x,), clone=True) + + # Check before transform + before_rmsnorm = count_ops(gm, torch.ops.auto_deploy.torch_rmsnorm.default) + before_sharded = count_ops(gm, torch.ops.auto_deploy.sharded_rmsnorm.default) + assert before_rmsnorm == 2, f"Expected 2 torch_rmsnorm before, got {before_rmsnorm}" + assert before_sharded == 0, f"Expected 0 sharded_rmsnorm before, got {before_sharded}" + + # Apply sharding transform with world_size=2 + optimizer = InferenceOptimizer( + None, + { + "detect_sharding": { + "stage": "sharding", + "simple_shard_only": False, + "sharding_source": ["manual", "factory", "heuristic"], + "support_partial_config": True, + "sharding_dims": ["tp", "ep", "bmm"], + "shard_all_unprocessed": True, + "allreduce_strategy": "NCCL", + "dist_backend": "auto", + "requires_shape_prop": True, + }, + "sharding_transform_executor": { + "stage": "sharding", + "run_shape_prop": True, + }, + }, + ) + optimizer.shared_config.local_rank = 0 + optimizer.shared_config.world_size = 2 + gm_transformed = optimizer(None, gm) + + # Check after transform + after_rmsnorm = count_ops(gm_transformed, torch.ops.auto_deploy.torch_rmsnorm.default) + after_sharded = count_ops(gm_transformed, torch.ops.auto_deploy.sharded_rmsnorm.default) + + # The QK norms (weight matching q/k output dims) should be transformed + assert after_sharded == 2, ( + f"Expected 2 sharded_rmsnorm after transform, got {after_sharded}. " + f"Remaining torch_rmsnorm: {after_rmsnorm}" + ) + + def test_per_head_norm_not_transformed(self): + """Test that per-head norm RMSNorm ops are NOT replaced. 
+ + When weight shape = [head_dim] (doesn't match q_proj output dim): + - torch_rmsnorm should stay as torch_rmsnorm + - No sharded_rmsnorm should be added + """ + model = SimpleAttentionWithQKNorm(use_full_hidden_norm=False).to( + "cuda", dtype=torch.float16 + ) + x = torch.randn(1, 8, 64, device="cuda", dtype=torch.float16) + + # Export to graph + gm = torch_export_to_gm(model, args=(x,), clone=True) + + # Check before transform + before_rmsnorm = count_ops(gm, torch.ops.auto_deploy.torch_rmsnorm.default) + assert before_rmsnorm == 2, f"Expected 2 torch_rmsnorm before, got {before_rmsnorm}" + + # Apply sharding transform with world_size=2 + # Using the same config as default.yaml for detect_sharding and sharding_transform_executor + optimizer = InferenceOptimizer( + None, + { + "detect_sharding": { + "stage": "sharding", + "simple_shard_only": False, + "sharding_source": ["manual", "factory", "heuristic"], + "support_partial_config": True, + "sharding_dims": ["tp", "ep", "bmm"], + "shard_all_unprocessed": True, + "allreduce_strategy": "NCCL", + "dist_backend": "auto", + "requires_shape_prop": True, + }, + "sharding_transform_executor": { + "stage": "sharding", + "run_shape_prop": True, + }, + }, + ) + # Set world_size and rank on shared_config + optimizer.shared_config.local_rank = 0 + optimizer.shared_config.world_size = 2 + gm_transformed = optimizer(None, gm) + + # Check after transform + after_rmsnorm = count_ops(gm_transformed, torch.ops.auto_deploy.torch_rmsnorm.default) + after_sharded = count_ops(gm_transformed, torch.ops.auto_deploy.sharded_rmsnorm.default) + + # Per-head norms should NOT be transformed to sharded + assert after_sharded == 0, ( + f"Expected 0 sharded_rmsnorm for per-head norm, got {after_sharded}" + ) + # The original rmsnorm ops should still be present (or fewer if some were removed) + # Note: Some rmsnorm ops may be removed/transformed for other reasons, but none should become sharded + print(f"After transform: {after_rmsnorm} torch_rmsnorm, {after_sharded} sharded_rmsnorm") + + def test_input_layernorm_not_transformed(self): + """Test that input_layernorm (before q/k/v projections) is NOT replaced. + + When RMSNorm feeds INTO q/k/v projections (like Llama's input_layernorm): + - torch_rmsnorm should stay as torch_rmsnorm + - No sharded_rmsnorm should be added + - Even though weight shape matches [hidden_size], it's not a QK norm + + This tests the fix for the bug where input_layernorm was incorrectly + detected as a QK norm because its weight shape matched q_proj output dim. 
+ """ + model = AttentionWithInputLayernorm().to("cuda", dtype=torch.float16) + x = torch.randn(1, 8, 64, device="cuda", dtype=torch.float16) + + # Export to graph + gm = torch_export_to_gm(model, args=(x,), clone=True) + + # Check before transform + before_rmsnorm = count_ops(gm, torch.ops.auto_deploy.torch_rmsnorm.default) + assert before_rmsnorm == 1, f"Expected 1 torch_rmsnorm before, got {before_rmsnorm}" + + # Apply sharding transform with world_size=4 (like the failing Llama test) + optimizer = InferenceOptimizer( + None, + { + "detect_sharding": { + "stage": "sharding", + "simple_shard_only": False, + "sharding_source": ["manual", "factory", "heuristic"], + "support_partial_config": True, + "sharding_dims": ["tp", "ep", "bmm"], + "shard_all_unprocessed": True, + "allreduce_strategy": "NCCL", + "dist_backend": "auto", + "requires_shape_prop": True, + }, + "sharding_transform_executor": { + "stage": "sharding", + "run_shape_prop": True, + }, + }, + ) + optimizer.shared_config.local_rank = 0 + optimizer.shared_config.world_size = 4 + gm_transformed = optimizer(None, gm) + + # Check after transform + after_rmsnorm = count_ops(gm_transformed, torch.ops.auto_deploy.torch_rmsnorm.default) + after_sharded = count_ops(gm_transformed, torch.ops.auto_deploy.sharded_rmsnorm.default) + + # Input layernorm should NOT be transformed to sharded_rmsnorm + # because it feeds INTO q/k/v projections, not after them + assert after_sharded == 0, ( + f"Expected 0 sharded_rmsnorm for input_layernorm, got {after_sharded}. " + f"input_layernorm should not be detected as a QK norm." + ) + print(f"After transform: {after_rmsnorm} torch_rmsnorm, {after_sharded} sharded_rmsnorm") diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_minimax_m2_patches.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_minimax_m2_patches.py new file mode 100644 index 0000000000..240e5a7589 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_minimax_m2_patches.py @@ -0,0 +1,138 @@ +"""Testing module patches that enable export of MiniMax-M2 model. + +This test verifies that the patched MiniMaxM2SparseMoeBlock forward function +produces identical outputs to the original HuggingFace implementation. +""" + +import types + +import pytest +import torch +from test_common.llm_data import hf_id_to_local_model_dir +from transformers import AutoConfig, AutoModelForCausalLM + +# Import custom_ops to register torch.ops.auto_deploy.torch_moe +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 +from tensorrt_llm._torch.auto_deploy.models.patches.minimax_m2 import minimax_m2_moe + + +def _load_minimax_m2_moe_layer(model_name_or_path): + """ + Loads the MoE layer from MiniMax-M2 model with a minimal configuration. + + We create a small model to keep tests fast while still exercising the + MoE routing and computation logic. + + Parameters: + model_name_or_path (str): Path or name of the pretrained model. + + Returns: + module: The MiniMaxM2SparseMoeBlock layer. 
+ """ + try: + # Load only the model configuration + config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) + + # Configure minimal model for fast testing + config.num_hidden_layers = 1 + config.use_cache = False + config.hidden_size = 16 # Small hidden size + config.intermediate_size = 8 # For MLP within experts + config.mlp_intermediate_size = 32 + config.num_local_experts = 4 # Small number of experts + config.num_experts_per_tok = 2 # Top-k experts + config.num_attention_heads = 2 + config.num_key_value_heads = 2 + config.router_jitter_noise = 0.0 # Disable jitter for deterministic tests + + # Build the model architecture (no weights loaded) + # Note: Importing minimax_m2 module auto-patches from_config, so the + # instance's forward is already patched. But the CLASS method is still + # the original HF implementation, which we use as reference. + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + model.eval() + + # Access the MoE layer + layer_name = "model.layers.0.block_sparse_moe" + module = dict(model.named_modules()).get(layer_name) + if module is None: + print(f"Layer '{layer_name}' not found in the model.") + else: + print(f"Successfully extracted layer '{layer_name}'.") + return module + except Exception as e: + print(f"Error extracting layer: {e}") + return None + + +@pytest.mark.parametrize( + "model_name", + [ + pytest.param( + hf_id_to_local_model_dir("MiniMaxAI/MiniMax-M2"), + ), + ], +) +def test_minimax_m2_moe_patch(model_name): + """ + Test that the patched MiniMaxM2SparseMoeBlock forward produces the same + output as the original HuggingFace implementation. + + The patch rewrites the forward to use torch.ops.auto_deploy.torch_moe + for torch.export compatibility while maintaining numerical equivalence. + + Since importing minimax_m2.py auto-patches module instances, we use the + CLASS method (type(module).forward) as the original HF reference. 
+ """ + # Set seed for reproducibility + torch.manual_seed(42) + + # Get MoE module (instance is already patched by import side-effect) + module = _load_minimax_m2_moe_layer(model_name) + assert module is not None, "Failed to load MiniMax-M2 MoE layer" + + # Convert module to bfloat16 to match input dtype + module = module.to(torch.bfloat16) + + # Create test input - same input will be used for both original and patched + # hidden_size=16 matches the config in _load_minimax_m2_moe_layer + hidden_size = 16 + inputs = torch.randn(2, 6, hidden_size, dtype=torch.bfloat16) + + # The CLASS method is still the original HuggingFace implementation + # (the auto-patch only patches instance methods, not the class) + original_class_forward = type(module).forward + + # Generate reference output using original HF class method + # Uses: same module weights, same input tensor + with torch.no_grad(): + ref_output, ref_router_logits = original_class_forward(module, inputs) + + # The instance forward is already patched by the import side-effect, + # but let's be explicit and apply our patch function directly + module.forward = types.MethodType(minimax_m2_moe, module) + + # Generate test output using patched implementation + # Uses: same module weights, same input tensor + with torch.no_grad(): + test_output, test_router_logits = module(inputs) + + # Verify outputs match + # Router logits should be identical (same computation path) + torch.testing.assert_close( + ref_router_logits, + test_router_logits, + atol=1e-5, + rtol=1e-5, + msg="Router logits mismatch between original and patched MoE", + ) + + # Final hidden states should be very close + # (small tolerance for different computation order in torch_moe) + torch.testing.assert_close( + ref_output, + test_output, + atol=1e-3, + rtol=1e-3, + msg="Output mismatch between original and patched MoE", + )