[#10245][feat] AutoDeploy: Add Minimax M2 support (#10525)

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Bala Marimuthu 2026-01-28 14:22:32 -08:00 committed by GitHub
parent 744a955cbb
commit 393c3d259e
8 changed files with 1116 additions and 44 deletions

View File

@@ -333,9 +333,9 @@ def flashinfer_mha_with_cache(
q_shape_og = q.shape
b, s = q_shape_og[:2]
q = q.reshape(b * s, -1, head_dim)
k = k.reshape(b * s, -1, head_dim)
v = v.reshape(b * s, -1, head_dim)
q = q.reshape(b * s, -1, head_dim).contiguous()
k = k.reshape(b * s, -1, head_dim).contiguous()
v = v.reshape(b * s, -1, head_dim).contiguous()
# convert to flashinfer-style metadata
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()

View File

@@ -2,6 +2,7 @@
import flashinfer
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
@@ -252,3 +253,65 @@ def gated_rms_norm_ref(
if z is not None and norm_before_gate:
out *= F.silu(z)
return out.to(dtype)
# =============================================================================
# Sharded RMSNorm (for sharded activations)
# =============================================================================
@torch.library.custom_op("auto_deploy::sharded_rmsnorm", mutates_args=())
def sharded_rmsnorm(
input: torch.Tensor, weight: torch.Tensor, eps: float, world_size: int
) -> torch.Tensor:
"""RMSNorm for sharded activations that need global reduction.
When activations are sharded (split along the last dimension across devices),
standard RMSNorm computes an incorrect local mean. This op uses all_reduce to compute
the global mean of squared values across all shards, ensuring correct normalization.
The computation is:
1. Compute local sum of squares: sum(input^2) over local features
2. All-reduce to get global sum of squares across all shards
3. Compute global mean: global_sum / (local_dim * world_size)
4. Normalize: input * rsqrt(global_mean + eps)
5. Scale with local weight (weight is also column-sharded)
Args:
input: Input tensor, shape [..., local_hidden_size] where local_hidden_size
is the shard of the full hidden dimension on this device.
weight: Scaling weights, shape [local_hidden_size] (column-sharded).
eps: Small constant for numerical stability.
world_size: Number of devices across which the activation is sharded.
Returns:
Normalized and scaled tensor with same shape as input.
"""
local_dim = input.shape[-1]
# Cast to float32 for precision
input_fp32 = input.to(torch.float32)
# Compute local sum of squares (NOT mean - we need sum for all_reduce)
local_sum_sq = input_fp32.pow(2).sum(-1, keepdim=True)
# All-reduce to get global sum of squares
global_sum_sq = local_sum_sq.clone()
dist.all_reduce(global_sum_sq, op=dist.ReduceOp.SUM)
# Compute global mean: global_sum / total_elements
global_count = local_dim * world_size
global_mean_sq = global_sum_sq / global_count
# Normalize
input_normalized = input_fp32 * torch.rsqrt(global_mean_sq + eps)
# Apply weight (local weight since it's also column-sharded)
out = weight * input_normalized.to(input.dtype)
return out
@sharded_rmsnorm.register_fake
def _(input: torch.Tensor, weight: torch.Tensor, eps: float, world_size: int) -> torch.Tensor:
"""Fake implementation for tracing."""
return torch.empty_like(input)
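The arithmetic behind steps 1-4 can be checked on a single device without any distributed setup: summing per-shard sums of squares reproduces the global mean of squares, so the sharded normalization matches the unsharded one. A minimal sanity sketch (toy shapes, shards simulated with `chunk` instead of `all_reduce`):

```python
import torch

torch.manual_seed(0)
world_size, hidden, eps = 4, 64, 1e-6
x = torch.randn(2, 8, hidden)

# Unsharded reference: standard RMSNorm statistics over the full hidden dim.
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

# Simulated shards: local sums of squares, "all-reduced" here by a plain Python sum.
shards = x.chunk(world_size, dim=-1)
global_sum_sq = sum(s.pow(2).sum(-1, keepdim=True) for s in shards)
global_mean_sq = global_sum_sq / ((hidden // world_size) * world_size)
out = torch.cat([s * torch.rsqrt(global_mean_sq + eps) for s in shards], dim=-1)

torch.testing.assert_close(out, ref)  # identical normalization, shard by shard
```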

View File

@@ -0,0 +1,78 @@
"""A patch for MiniMax-M2 MoE to make it compatible with torch.export.
MiniMax-M2 is loaded from HuggingFace hub (trust_remote_code), so we cannot
import MiniMaxM2SparseMoeBlock directly. Instead, we use the same pattern as
DeepSeek: patching AutoModelForCausalLM.from_config to iterate over modules
and patch by class name.
"""
import types
from typing import Dict
import torch
from transformers import AutoModelForCausalLM
def minimax_m2_moe(self, hidden_states: torch.Tensor):
"""MiniMaxM2SparseMoeBlock forward function rewritten to enable torch.export.
Key differences from Mixtral:
- Uses sigmoid activation for routing (not softmax)
- Has e_score_correction_bias added for expert selection only
- Gathers original sigmoid weights after topk selection
"""
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(
1.0 - self.jitter_noise, 1.0 + self.jitter_noise
)
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
# MiniMax-M2 specific routing:
# Step 1: Sigmoid activation (not softmax like Mixtral)
routing_weights = torch.sigmoid(router_logits.float())
# Step 2: Add bias for expert selection only
scores_for_choice = routing_weights + self.e_score_correction_bias
# Step 3: Select top-k experts based on biased scores
_, selected_experts = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False)
# Step 4: Gather ORIGINAL sigmoid weights (not biased scores)
top_k_weights = routing_weights.gather(1, selected_experts)
# Step 5: Normalize so weights sum to 1
top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
top_k_weights = top_k_weights.to(hidden_states.dtype)
final_hidden_states = torch.ops.auto_deploy.torch_moe(
hidden_states,
selected_experts,
top_k_weights,
w1_weight=[expert.w1.weight for expert in self.experts], # gate projection
w2_weight=[expert.w2.weight for expert in self.experts], # down projection
w3_weight=[expert.w3.weight for expert in self.experts], # up projection
)
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
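The routing described above (sigmoid scores, bias used only for selection, original scores gathered and renormalized) can be exercised on toy tensors without loading the model. A small standalone sketch of just the routing math, with made-up shapes:

```python
import torch

torch.manual_seed(0)
n_tokens, n_experts, top_k = 4, 8, 2
router_logits = torch.randn(n_tokens, n_experts)
e_score_correction_bias = torch.randn(n_experts)

# Step 1: sigmoid routing scores (not softmax as in Mixtral)
routing_weights = torch.sigmoid(router_logits.float())
# Steps 2-3: the bias only influences WHICH experts are picked ...
scores_for_choice = routing_weights + e_score_correction_bias
_, selected_experts = torch.topk(scores_for_choice, top_k, dim=-1, sorted=False)
# Step 4: ... while the mixing weights are the original sigmoid scores
top_k_weights = routing_weights.gather(1, selected_experts)
# Step 5: renormalize so each token's expert weights sum to 1
top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

print(selected_experts.shape)     # torch.Size([4, 2])
print(top_k_weights.sum(dim=-1))  # tensor([1., 1., 1., 1.])
```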
_from_config_previous = AutoModelForCausalLM.from_config
CUSTOM_MODULE_PATCHES: Dict[str, callable] = {"MiniMaxM2SparseMoeBlock": minimax_m2_moe}
def get_model_from_config_patched(config, **kwargs):
model = _from_config_previous(config, **kwargs)
# Patch modules by class name
for _, module in model.named_modules():
if type(module).__name__ in CUSTOM_MODULE_PATCHES:
module.forward = types.MethodType(CUSTOM_MODULE_PATCHES[type(module).__name__], module)
return model
# Patch AutoModelForCausalLM.from_config
AutoModelForCausalLM.from_config = get_model_from_config_patched

View File

@@ -691,6 +691,77 @@ def _resolve_ep_cls_from_node(node: Node) -> type[EPShardingInfo]:
return EPShardingInfo
class RMSNormShardingInfo(ShardingTransformInfo):
"""Configuration for replacing RMSNorm with sharded version.
When RMSNorm (torch_rmsnorm) operates on sharded activations with
weight shape [num_heads * head_dim] (full hidden size), it needs to be
replaced with sharded_rmsnorm which uses all_reduce to compute
the correct global mean across shards.
The detection of whether an RMSNorm needs this treatment is done in
_shard_qk_norm based on weight shape matching q/k projection
output dimensions.
"""
world_size: int
@classmethod
def from_node(cls, node: Node, **kwargs) -> "RMSNormShardingInfo":
"""Create a RMSNormShardingInfo from a node."""
return cls(target_node=node.name, **kwargs)
def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
"""Validate that the node is a torch_rmsnorm op."""
if not is_op(node, torch.ops.auto_deploy.torch_rmsnorm):
ad_logger.debug(
f"RMSNormShardingInfo only applies to torch_rmsnorm ops. "
f"Got {node.target}. Skipping."
)
return False
return True
def apply(self, gm: GraphModule, node: Node) -> None:
"""Replace torch_rmsnorm with sharded_rmsnorm and shard the weight.
This handles both:
1. Weight sharding: Shard the RMSNorm weight across ranks
2. Op replacement: Replace with sharded_rmsnorm which uses all_reduce
for computing the correct global mean across shards.
"""
# Get original arguments: (input, weight, eps)
input_node = node.args[0]
weight_node = node.args[1]
eps = node.args[2]
# Shard the weight parameter (column-wise for 1D RMSNorm weight)
weight_key = weight_node.target
weight_tensor = gm.get_parameter(weight_key)
shard_weight_tensor(
gm=gm,
weight_tensor=weight_tensor,
param_key=weight_key,
dim=0, # Column shard for 1D weight
rank=self.config.rank,
world_size=self.world_size,
)
ad_logger.debug(f"Sharded RMSNorm weight: {weight_key}")
# Insert the new node with world_size parameter
with gm.graph.inserting_after(node):
new_node = gm.graph.call_function(
torch.ops.auto_deploy.sharded_rmsnorm.default,
args=(input_node, weight_node, eps, self.world_size),
)
# Replace all uses with the new node and remove the old node
node.replace_all_uses_with(new_node)
gm.graph.erase_node(node)
ad_logger.debug(
f"Replaced torch_rmsnorm with sharded_rmsnorm (world_size={self.world_size})"
)
########################################################
# Transform API classes
########################################################
@@ -828,6 +899,9 @@ class ShardingTransformExecutor(BaseTransform):
for ep_transform in transforms.ep_transforms:
if check_and_apply(ep_transform):
num_matches += 1
for rmsnorm_transform in transforms.rmsnorm_transforms:
if check_and_apply(rmsnorm_transform):
num_matches += 1
# post-sharding cleanup transformations
for update_transform in transforms.parameter_update_transforms:
@@ -851,6 +925,7 @@ class ShardingTransformContainer(BaseModel):
parameter_update_transforms: List[ParameterUpdateInfo] = Field(default_factory=list)
bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
ep_transforms: List[EPShardingInfo] = Field(default_factory=list)
rmsnorm_transforms: List[RMSNormShardingInfo] = Field(default_factory=list)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -859,6 +934,7 @@ class ShardingTransformContainer(BaseModel):
BMMShardingInfo: self.bmm_transforms,
EPShardingInfo: self.ep_transforms,
ParameterUpdateInfo: self.parameter_update_transforms,
RMSNormShardingInfo: self.rmsnorm_transforms,
}
def add(self, transform: ShardingTransformInfo) -> bool:
@@ -1299,14 +1375,30 @@ def _shard_parameter_node(
def _update_node_args(node: Node, args: tuple) -> None:
"""Update the node's arguments with the new sharded arguments."""
"""Update the node's arguments with the new sharded arguments.
For Node args: preserve the current value (may have been updated by other transforms).
For non-Node args (shapes, sizes, indices): use the stored sharded value.
This prevents ParameterUpdateInfo from reverting Node references that were
intentionally updated by other transforms.
"""
if "sharded" in node.meta and node.meta["sharded"]:
return
node.args = args
# Build new args: preserve current Node refs, apply stored non-Node values
new_args = []
for i, stored_arg in enumerate(args):
if isinstance(stored_arg, Node):
# Node args: preserve current value (may have been updated by other transforms)
new_args.append(node.args[i])
else:
# Non-Node args (shapes, sizes, indices): use stored sharded value
new_args.append(stored_arg)
node.args = tuple(new_args)
node.meta["sharded"] = True
ad_logger.debug(
f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}."
)
ad_logger.debug(f"Updated node {node}: sharded arguments are now {node.args}.")
def _insert_sharded_moe(
@@ -1838,43 +1930,242 @@ def _determine_fused_weight_dims(
return None
linear_node = linear_nodes[0]
fused_weight_dims = None
# check if there are split nodes in the subgraph. They may indicate fused weights (e.g., QKV)
split_nodes = list(filtered_nodes(linear_node.users, ops=[torch.ops.aten.split_with_sizes]))
if len(split_nodes) > 0:
assert len(linear_nodes) == 1
if len(linear_nodes) == 1:
linear_node = linear_nodes[0]
assert len(split_nodes) == 1, "Expecting exactly one split node for fused weights"
fused_weight_dims = split_nodes[0].args[1]
# check if there are split nodes in the subgraph. They may indicate fused weights (e.g., QKV)
linear_split_users = list(
filtered_nodes(linear_node.users, ops=torch.ops.aten.split_with_sizes)
)
linear_slice_users = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.slice))
linear_chunk_users = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.chunk))
if len(linear_split_users) > 0:
assert len(linear_split_users) == 1, (
"Expecting exactly one split node for fused weights"
)
fused_weight_dims = linear_split_users[0].args[1]
slice_nodes = list(filtered_nodes(linear_node.users, ops=[torch.ops.aten.slice]))
if len(slice_nodes) > 0:
# we are probably in fused QKV case with single linear node and 3 slice nodes
assert len(linear_nodes) == 1
linear_node = linear_nodes[0]
assert all(
s.args[1] == 2 for s in filtered_nodes(linear_node.users, ops=torch.ops.aten.slice)
), "Expecting slice nodes to slice tensor over dim=2"
fused_weight_dims = [s.args[3] - s.args[2] for s in linear_node.users]
weight_dim = shape(linear_node)[2]
if sum(fused_weight_dims) != weight_dim:
if fused_weight_dims[-1] > weight_dim:
fused_weight_dims[-1] = weight_dim - sum(fused_weight_dims[:-1])
else:
ad_logger.warning(
f"Fused weight dims {fused_weight_dims} do not sum to weight dim {weight_dim}. Skipping."
)
return None
chunk_nodes = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.chunk))
if len(chunk_nodes) > 0:
assert len(linear_nodes) == 1
linear_node = linear_nodes[0]
assert len(chunk_nodes) == 1, "Expecting exactly one chunk node for fused weights"
num_chunks = chunk_nodes[0].args[1]
weight_dim = shape(linear_node)[2]
fused_weight_dims = [weight_dim // num_chunks] * num_chunks
if fused_weight_dims is not None:
fused_weight_dims = tuple(fused_weight_dims)
return fused_weight_dims
elif len(linear_slice_users) > 0:
# we are probably in fused QKV case with single linear node and 3 slice nodes
assert all(s.args[1] == 2 for s in linear_slice_users), (
"Expecting slice nodes to slice tensor over dim=2"
)
fused_weight_dims = [s.args[3] - s.args[2] for s in linear_slice_users]
assert fused_weight_dims, "fused weight dims cannot be empty"
weight_dim = linear_node.meta["val"].shape[2]
if sum(fused_weight_dims) != weight_dim:
if fused_weight_dims[-1] > weight_dim:
fused_weight_dims[-1] = weight_dim - sum(fused_weight_dims[:-1])
else:
ad_logger.warning(
f"Fused weight dims {fused_weight_dims} do not sum to weight dim {weight_dim}. Skipping."
)
return
elif len(linear_chunk_users) > 0:
assert len(linear_chunk_users) == 1, (
"Expecting exactly one chunk node for fused weights"
)
num_chunks = linear_chunk_users[0].args[1]
weight_dim = linear_node.meta["val"].shape[2]
fused_weight_dims = [weight_dim // num_chunks] * num_chunks
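For reference, the fused-weight case the split branch above handles can be reproduced with a toy module: `torch.split` with explicit sizes typically lowers to an aten `split_with_sizes` node whose second argument carries exactly the per-output dims read here. A hedged sketch (toy module; the exact aten overload name may vary across PyTorch versions):

```python
import torch
import torch.nn as nn

class FusedQKV(nn.Module):
    def __init__(self, hidden: int = 64):
        super().__init__()
        self.qkv = nn.Linear(hidden, 3 * hidden, bias=False)
        self.sizes = [hidden, hidden, hidden]   # fused q/k/v output dims

    def forward(self, x):
        q, k, v = torch.split(self.qkv(x), self.sizes, dim=-1)
        return q, k, v

ep = torch.export.export(FusedQKV(), (torch.randn(2, 8, 64),))
for node in ep.graph_module.graph.nodes:
    if node.op == "call_function" and "split_with_sizes" in str(node.target):
        print(node.args[1])  # -> [64, 64, 64], the fused weight dims
```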
def _find_upstream_qk_proj(node: Node, gm: GraphModule) -> Optional[str]:
"""
Find the upstream q/k projection linear node from an RMSNorm input.
Traverses backwards through pass-through tensor operations (view, reshape,
type conversions, quantize/dequantize, etc.) to find the immediate producer.
If that producer is a q_proj or k_proj linear, returns the weight name.
Args:
node: The input node to start traversing from (typically RMSNorm's activation input)
gm: The graph module containing the nodes
Returns:
The weight name if a q_proj or k_proj is found as immediate upstream, None otherwise.
TODO: Is there a more efficient way to do this?
"""
# Pass-through ops that we traverse through (these don't change the semantic meaning)
passthrough_ops = [
torch.ops.aten.view,
torch.ops.aten.reshape,
torch.ops.aten.contiguous,
torch.ops.aten.clone,
torch.ops.aten.to,
torch.ops.aten._to_copy,
torch.ops.aten.slice,
torch.ops.aten.transpose,
torch.ops.aten.permute,
]
visited = set()
current = node
# Traverse through pass-through ops to find the actual producer
while current is not None and current not in visited:
visited.add(current)
# Check if this is a linear operation (the producer we're looking for)
if is_any_lin_op(current):
try:
weight_name = extract_weight_name(current)
if weight_name and ("q_proj" in weight_name or "k_proj" in weight_name):
return weight_name
except (AttributeError, AssertionError):
pass
# Found a linear but not q/k proj - this is not a QK norm
return None
# If this is a pass-through op, continue to its first input
if is_op(current, passthrough_ops):
# Get the first tensor input (skip non-tensor args like dims)
tensor_inputs = [arg for arg in current.all_input_nodes if isinstance(arg, Node)]
if tensor_inputs:
current = tensor_inputs[0]
continue
# Hit a non-passthrough, non-linear op - stop searching
break
return None
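The walk above operates on the exported aten-level graph; the same idea in a simplified torch.fx form, with an elementwise scale standing in for the norm and reshape/contiguous as the pass-through ops (hypothetical toy module, not TRT-LLM code):

```python
import torch
import torch.nn as nn
import torch.fx as fx

class Toy(nn.Module):
    def __init__(self, hidden: int = 64):
        super().__init__()
        self.q_proj = nn.Linear(hidden, hidden, bias=False)
        self.q_norm_weight = nn.Parameter(torch.ones(hidden))

    def forward(self, x):
        q = self.q_proj(x)
        q = q.reshape(-1, 64).contiguous()   # pass-through ops between proj and "norm"
        return q * self.q_norm_weight        # elementwise scale stands in for the norm

gm = fx.symbolic_trace(Toy())
norm_like = list(gm.graph.nodes)[-1].args[0]  # node feeding the graph output
current = norm_like.all_input_nodes[0]        # its activation input
while current.op == "call_method" and current.target in ("reshape", "view", "contiguous"):
    current = current.all_input_nodes[0]      # hop over pass-through ops
print(current.op, current.target)             # -> call_module q_proj
```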
def _shard_qk_norm(
layer_subgraph: LayerSubgraph,
linear_nodes: List[Node],
transform_container: ShardingTransformContainer,
) -> int:
"""
Shard RMSNorm ops in attention layers that operate on full hidden size.
This function detects torch_rmsnorm ops that are true QK norms - i.e., norms that
operate on the output of q_proj or k_proj. These global QK norms operate on
flattened Q/K output [batch, seq, hidden_size] before reshape and need special
handling for tensor parallelism:
1. Weight sharding: The norm weight is column-sharded across ranks
2. Global mean: Replaced with sharded_rmsnorm which uses all_reduce
Detection criteria:
1. The RMSNorm's input must trace back to a q_proj or k_proj linear operation
2. The weight shape must match q/k projection output dimensions
Example 1 - Global norm on all heads directly on flattened Q/K output (e.g. MiniMax):
self.q_norm = MiniMaxM2RMSNorm(self.head_dim * config.num_attention_heads, eps=config.rms_norm_eps)
weight shape: [num_heads * head_dim] -> matches valid_sizes -> will be sharded
Status: Needs sharding (weight matches q projection output dim)
Example 2 - Norm per head after reshaping to [batch, seq, num_heads, head_dim] (e.g. GLM 4.7):
self.q_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
weight shape: [head_dim] -> does NOT match valid_sizes -> skipped
Status: No sharding needed (weight doesn't match q/k output dims)
Args:
layer_subgraph: The attention layer subgraph
linear_nodes: The linear nodes (q/k/v projections)
transform_container: Container for sharding transformations
Returns:
Number of nodes added for sharding
"""
if layer_subgraph.layer_type != LayerType.ATTENTION or layer_subgraph.terminating_node is None:
return 0
config = transform_container.config
world_size = config.world_size
gm = linear_nodes[0].graph.owning_module
added_nodes = 0
# Collect valid output dimensions from q/k/v projections
# These are the only valid sizes for intermediate weights (e.g., q_norm, k_norm)
valid_sizes = set()
for lin_node in linear_nodes:
try:
wkey = extract_weight_name(lin_node)
w = gm.get_parameter(wkey)
valid_sizes.add(w.shape[0]) # q: num_heads*head_dim, k/v: num_kv_heads*head_dim
except (AttributeError, AssertionError):
pass
# Find all intermediate weight nodes between q/k/v projections and o_proj.
intermediate_weight_nodes = subgraph(
sources=linear_nodes,
sinks=[layer_subgraph.terminating_node],
include=lambda n: n.op == "get_attr",
)
for weight_node in intermediate_weight_nodes:
weight_key = weight_node.target
# First check: is this weight consumed by a torch_rmsnorm op
if len(list(weight_node.users)) == 0:
continue
user_node = list(weight_node.users)[0]
if not is_op(user_node, torch.ops.auto_deploy.torch_rmsnorm):
continue
# Verify this is a true QK norm by checking its input traces back to q_proj or k_proj
# This filters out input_layernorm which feeds INTO q/k/v projections (not after them)
rmsnorm_input = user_node.args[0] # activation input (not the weight)
upstream_proj = _find_upstream_qk_proj(rmsnorm_input, gm)
if upstream_proj is None:
ad_logger.debug(
f"Skipping {user_node.name} - input does not trace back to q_proj or k_proj "
f"(likely input_layernorm or other non-QK norm)"
)
continue
ad_logger.debug(f"Found QK norm {user_node.name} with upstream projection: {upstream_proj}")
# Try to get the parameter
try:
param = gm.get_parameter(weight_key)
except AttributeError:
ad_logger.debug(f"Could not get parameter for {weight_key}, skipping")
continue
# Only shard 1D weights (RMSNorm weights are always 1D)
if param.dim() != 1:
ad_logger.debug(
f"Skipping intermediate weight {weight_key} with dim={param.dim()} (not 1D)"
)
continue
# Check if the weight size is divisible by world_size
if param.shape[0] % world_size != 0:
ad_logger.debug(
f"Skipping intermediate weight {weight_key} with shape {param.shape} "
f"(not divisible by world_size={world_size})"
)
continue
# Check if weight size matches one of the q/k/v projection output dimensions
# This filters out per-head norms and unrelated weights like inv_freq
if valid_sizes and param.shape[0] not in valid_sizes:
ad_logger.debug(
f"Skipping intermediate weight {weight_key} with shape {param.shape} "
f"(not in valid_sizes={valid_sizes}, likely per-head norm)"
)
continue
# Add RMSNormShardingInfo to replace with sharded_rmsnorm
# This handles both weight sharding and op replacement in one transform
if transform_container.add(
RMSNormShardingInfo.from_node(
user_node,
config=config,
world_size=world_size,
)
):
ad_logger.debug(
f"Added RMSNormShardingInfo for {user_node.name} "
f"(will replace with sharded_rmsnorm for global mean)"
)
added_nodes += 1
return added_nodes
def _process_column_sharding(
@@ -1955,6 +2246,10 @@ def _process_column_sharding(
)
)
# chunk nodes do not need to be updated
# Shard intermediate weights (e.g. q/k/v -> q_norm, k_norm ... -> o_proj) for attention layers
added_nodes += _shard_qk_norm(layer_subgraph, linear_nodes, transform_container)
return added_nodes

View File

@@ -967,7 +967,9 @@ def get_layer_after_linear_node(
min_local_shape=min_local_shape,
)
assert linear_nodes[start_lin_index] in opening_linear_nodes, (
"Linear node not found in opening linear nodes"
f"Linear node not found in opening linear nodes - "
f"terminating_linear_node:{terminating_linear_node.name}, "
f"opening_linear_nodes: {[n.name for n in opening_linear_nodes]}"
)
# return the index of the terminating linear node

View File

@@ -0,0 +1,171 @@
"""Functional tests for sharded_rmsnorm custom op.
This module tests that the sharded_rmsnorm op produces correct
numerical results when run across multiple GPUs with column-sharded inputs.
The op computes RMSNorm on column-sharded activations by using all_reduce to
compute the global mean of squared values across all shards, ensuring correct
normalization equivalent to non-sharded global RMSNorm.
"""
import pickle
import sys
import traceback
import cloudpickle
import pytest
import torch
import torch.distributed as dist
from mpi4py import MPI
from tensorrt_llm._torch.auto_deploy.distributed.common import initialize, is_initialized
# Register this module for cloudpickle serialization for MPI workers
cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
cloudpickle.dumps,
cloudpickle.loads,
pickle.HIGHEST_PROTOCOL,
)
# needed since we reuse the mpi executor pool, first test running will leak a thread
pytestmark = pytest.mark.threadleak(enabled=False)
def _reference_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
"""Reference RMSNorm implementation for comparison."""
input_fp32 = input.to(torch.float32)
variance = input_fp32.pow(2).mean(-1, keepdim=True)
input_normalized = input_fp32 * torch.rsqrt(variance + eps)
return weight * input_normalized.to(input.dtype)
def _run_sharded_rmsnorm_test(
tensor_parallel_size: int,
batch_size: int = 2,
seq_len: int = 8,
hidden_size: int = 64,
eps: float = 1e-6,
dtype_str: str = "float16",
):
"""Test that sharded_rmsnorm matches non-sharded global RMSNorm.
Each rank:
1. Creates identical full input and weight tensors (same seed)
2. Computes reference result using non-sharded RMSNorm
3. Column-shards input and weight for this rank
4. Calls sharded_rmsnorm on local shard
5. All-gathers results and compares with reference
"""
# Import inside worker to avoid cloudpickle serialization issues with torch.ops
import tensorrt_llm
import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
rank = tensorrt_llm.mpi_rank()
world_size = tensor_parallel_size
torch.cuda.set_device(rank)
if not is_initialized():
initialize(rank, port=29500)
# Map string to torch.dtype inside worker to avoid cloudpickle serialization issues
dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16}
dtype = dtype_map[dtype_str]
try:
torch.manual_seed(42) # Same seed for reproducibility across ranks
# Full tensors (same on all ranks due to seed)
full_input = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=dtype)
full_weight = torch.randn(hidden_size, device="cuda", dtype=dtype)
# Reference: non-sharded global RMSNorm
reference_output = _reference_rmsnorm(full_input, full_weight, eps)
# Column shard: split hidden_size across ranks
local_hidden_size = hidden_size // world_size
start_idx = rank * local_hidden_size
end_idx = start_idx + local_hidden_size
local_input = full_input[..., start_idx:end_idx].contiguous()
local_weight = full_weight[start_idx:end_idx].contiguous()
# Call sharded_rmsnorm using dynamic getattr to avoid cloudpickle capturing torch.ops
sharded_rmsnorm_op = getattr(getattr(torch, "ops"), "auto_deploy").sharded_rmsnorm
local_output = sharded_rmsnorm_op(local_input, local_weight, eps, world_size)
# All-gather to reconstruct full output
gathered_outputs = [torch.zeros_like(local_output) for _ in range(world_size)]
dist.all_gather(gathered_outputs, local_output)
reconstructed_output = torch.cat(gathered_outputs, dim=-1)
# Compare with reference
torch.testing.assert_close(
reconstructed_output,
reference_output,
atol=1e-2,
rtol=1e-2,
msg=f"sharded_rmsnorm result doesn't match reference (rank={rank})",
)
except Exception:
traceback.print_exc()
raise
return True
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_sharded_rmsnorm_functional(mpi_pool_executor):
"""Functional test: verify sharded_rmsnorm produces correct numerical results."""
torch.manual_seed(0)
tensor_parallel_size = mpi_pool_executor.num_workers
results = mpi_pool_executor.map(
_run_sharded_rmsnorm_test,
*zip(*[(tensor_parallel_size,)] * tensor_parallel_size),
)
for r in results:
assert r is True
def _run_sharded_rmsnorm_hidden_size_test(tensor_parallel_size: int, hidden_size: int):
"""Worker function for hidden size parametrized test."""
return _run_sharded_rmsnorm_test(tensor_parallel_size, hidden_size=hidden_size)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
@pytest.mark.parametrize("hidden_size", [64, 128, 256])
def test_sharded_rmsnorm_different_hidden_sizes(mpi_pool_executor, hidden_size):
"""Test sharded_rmsnorm with different hidden sizes."""
torch.manual_seed(0)
tensor_parallel_size = mpi_pool_executor.num_workers
results = mpi_pool_executor.map(
_run_sharded_rmsnorm_hidden_size_test,
*zip(*[(tensor_parallel_size, hidden_size)] * tensor_parallel_size),
)
for r in results:
assert r is True
def _run_sharded_rmsnorm_dtype_test(tensor_parallel_size: int, dtype_str: str):
"""Worker function for dtype parametrized test."""
return _run_sharded_rmsnorm_test(tensor_parallel_size, dtype_str=dtype_str)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
def test_sharded_rmsnorm_different_dtypes(mpi_pool_executor, dtype_str):
"""Test sharded_rmsnorm with different dtypes."""
torch.manual_seed(0)
tensor_parallel_size = mpi_pool_executor.num_workers
results = mpi_pool_executor.map(
_run_sharded_rmsnorm_dtype_test,
*zip(*[(tensor_parallel_size, dtype_str)] * tensor_parallel_size),
)
for r in results:
assert r is True

View File

@@ -0,0 +1,325 @@
"""Tests for RMSNorm sharding detection and transformation.
This module tests that the sharding transform correctly detects and transforms
torch_rmsnorm ops based on weight shape and position in the graph:
1. Full hidden norm (weight shape = [num_heads * head_dim], AFTER q/k projection):
- Detected as QK norm needing sharding -> replaced with sharded_rmsnorm
- Weight is sharded
2. Per-head norm (weight shape = [head_dim], like GLM):
- NOT detected as needing sharding -> stays as local torch_rmsnorm
- No transformation needed
3. Input layernorm (feeds INTO q/k/v projections, not after):
- NOT detected as QK norm -> stays as local torch_rmsnorm
- Even though weight shape matches, it's not a QK norm
These are graph-level unit tests that verify the transform logic.
"""
import torch
import torch.nn as nn
# Ensure custom ops are registered
from tensorrt_llm._torch.auto_deploy.custom_ops import rms_norm # noqa: F401
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
class AttentionWithInputLayernorm(nn.Module):
"""Attention module with input_layernorm BEFORE q/k/v projections.
This simulates the standard Llama pattern where input_layernorm feeds
INTO the attention projections. This should NOT be detected as a QK norm.
"""
def __init__(self, hidden_size: int = 64, num_heads: int = 4):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
# Input layernorm - BEFORE q/k/v projections (like Llama)
self.input_layernorm_weight = nn.Parameter(torch.ones(hidden_size))
self.eps = 1e-6
# Q/K/V/O projections
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, s, _ = x.shape
# Input layernorm BEFORE projections (Llama pattern)
x_normed = torch.ops.auto_deploy.torch_rmsnorm(x, self.input_layernorm_weight, self.eps)
q = self.q_proj(x_normed)
k = self.k_proj(x_normed)
v = self.v_proj(x_normed)
# Reshape for attention
q = q.view(b, s, self.num_heads, self.head_dim)
k = k.view(b, s, self.num_heads, self.head_dim)
v = v.view(b, s, self.num_heads, self.head_dim)
y = torch.ops.auto_deploy.torch_attention(q, k, v, is_causal=True, layout="bsnd")
y = y.contiguous().view(b, s, -1)
return self.o_proj(y)
class SimpleAttentionWithQKNorm(nn.Module):
"""Attention module with configurable QK normalization.
Args:
hidden_size: Total hidden dimension
num_heads: Number of attention heads
use_full_hidden_norm: If True, use full hidden norm (like MiniMax)
- Weight shape = [hidden_size], norm before reshape
- Should be transformed to sharded_rmsnorm
If False, use per-head norm (like GLM)
- Weight shape = [head_dim], norm after reshape
- Should NOT be transformed
"""
def __init__(
self, hidden_size: int = 64, num_heads: int = 4, use_full_hidden_norm: bool = True
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.use_full_hidden_norm = use_full_hidden_norm
# Q/K/V/O projections
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
# QK norm weights: shape depends on norm type
norm_dim = hidden_size if use_full_hidden_norm else self.head_dim
self.q_norm_weight = nn.Parameter(torch.ones(norm_dim))
self.k_norm_weight = nn.Parameter(torch.ones(norm_dim))
self.eps = 1e-6
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, s, _ = x.shape
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
if self.use_full_hidden_norm:
# Full hidden norm: apply before reshape
q = torch.ops.auto_deploy.torch_rmsnorm(q, self.q_norm_weight, self.eps)
k = torch.ops.auto_deploy.torch_rmsnorm(k, self.k_norm_weight, self.eps)
# Reshape for attention
q = q.view(b, s, self.num_heads, self.head_dim)
k = k.view(b, s, self.num_heads, self.head_dim)
v = v.view(b, s, self.num_heads, self.head_dim)
else:
# Reshape first for per-head norm
q = q.view(b, s, self.num_heads, self.head_dim)
k = k.view(b, s, self.num_heads, self.head_dim)
v = v.view(b, s, self.num_heads, self.head_dim)
# Per-head norm: apply after reshape (broadcasts over heads)
q = torch.ops.auto_deploy.torch_rmsnorm(q, self.q_norm_weight, self.eps)
k = torch.ops.auto_deploy.torch_rmsnorm(k, self.k_norm_weight, self.eps)
y = torch.ops.auto_deploy.torch_attention(q, k, v, is_causal=True, layout="bsnd")
y = y.contiguous().view(b, s, -1)
return self.o_proj(y)
def count_ops(gm, op) -> int:
"""Count the number of nodes with a specific op in the graph."""
count = 0
for node in gm.graph.nodes:
if node.op == "call_function" and is_op(node, op):
count += 1
return count
# =============================================================================
# Tests
# =============================================================================
class TestRMSNormShardingTransform:
"""Tests for the sharding transform on RMSNorm ops."""
def test_full_hidden_norm_transformed_to_sharded(self):
"""Test that full hidden norm RMSNorm ops are replaced with sharded_rmsnorm.
When weight shape = [hidden_size] (matches q_proj output dim):
- torch_rmsnorm should be replaced with sharded_rmsnorm
- Weight should be sharded
"""
model = SimpleAttentionWithQKNorm(use_full_hidden_norm=True).to("cuda", dtype=torch.float16)
x = torch.randn(1, 8, 64, device="cuda", dtype=torch.float16)
# Export to graph
gm = torch_export_to_gm(model, args=(x,), clone=True)
# Check before transform
before_rmsnorm = count_ops(gm, torch.ops.auto_deploy.torch_rmsnorm.default)
before_sharded = count_ops(gm, torch.ops.auto_deploy.sharded_rmsnorm.default)
assert before_rmsnorm == 2, f"Expected 2 torch_rmsnorm before, got {before_rmsnorm}"
assert before_sharded == 0, f"Expected 0 sharded_rmsnorm before, got {before_sharded}"
# Apply sharding transform with world_size=2
optimizer = InferenceOptimizer(
None,
{
"detect_sharding": {
"stage": "sharding",
"simple_shard_only": False,
"sharding_source": ["manual", "factory", "heuristic"],
"support_partial_config": True,
"sharding_dims": ["tp", "ep", "bmm"],
"shard_all_unprocessed": True,
"allreduce_strategy": "NCCL",
"dist_backend": "auto",
"requires_shape_prop": True,
},
"sharding_transform_executor": {
"stage": "sharding",
"run_shape_prop": True,
},
},
)
optimizer.shared_config.local_rank = 0
optimizer.shared_config.world_size = 2
gm_transformed = optimizer(None, gm)
# Check after transform
after_rmsnorm = count_ops(gm_transformed, torch.ops.auto_deploy.torch_rmsnorm.default)
after_sharded = count_ops(gm_transformed, torch.ops.auto_deploy.sharded_rmsnorm.default)
# The QK norms (weight matching q/k output dims) should be transformed
assert after_sharded == 2, (
f"Expected 2 sharded_rmsnorm after transform, got {after_sharded}. "
f"Remaining torch_rmsnorm: {after_rmsnorm}"
)
def test_per_head_norm_not_transformed(self):
"""Test that per-head norm RMSNorm ops are NOT replaced.
When weight shape = [head_dim] (doesn't match q_proj output dim):
- torch_rmsnorm should stay as torch_rmsnorm
- No sharded_rmsnorm should be added
"""
model = SimpleAttentionWithQKNorm(use_full_hidden_norm=False).to(
"cuda", dtype=torch.float16
)
x = torch.randn(1, 8, 64, device="cuda", dtype=torch.float16)
# Export to graph
gm = torch_export_to_gm(model, args=(x,), clone=True)
# Check before transform
before_rmsnorm = count_ops(gm, torch.ops.auto_deploy.torch_rmsnorm.default)
assert before_rmsnorm == 2, f"Expected 2 torch_rmsnorm before, got {before_rmsnorm}"
# Apply sharding transform with world_size=2
# Using the same config as default.yaml for detect_sharding and sharding_transform_executor
optimizer = InferenceOptimizer(
None,
{
"detect_sharding": {
"stage": "sharding",
"simple_shard_only": False,
"sharding_source": ["manual", "factory", "heuristic"],
"support_partial_config": True,
"sharding_dims": ["tp", "ep", "bmm"],
"shard_all_unprocessed": True,
"allreduce_strategy": "NCCL",
"dist_backend": "auto",
"requires_shape_prop": True,
},
"sharding_transform_executor": {
"stage": "sharding",
"run_shape_prop": True,
},
},
)
# Set world_size and rank on shared_config
optimizer.shared_config.local_rank = 0
optimizer.shared_config.world_size = 2
gm_transformed = optimizer(None, gm)
# Check after transform
after_rmsnorm = count_ops(gm_transformed, torch.ops.auto_deploy.torch_rmsnorm.default)
after_sharded = count_ops(gm_transformed, torch.ops.auto_deploy.sharded_rmsnorm.default)
# Per-head norms should NOT be transformed to sharded
assert after_sharded == 0, (
f"Expected 0 sharded_rmsnorm for per-head norm, got {after_sharded}"
)
# The original rmsnorm ops should still be present (or fewer if some were removed)
# Note: Some rmsnorm ops may be removed/transformed for other reasons, but none should become sharded
print(f"After transform: {after_rmsnorm} torch_rmsnorm, {after_sharded} sharded_rmsnorm")
def test_input_layernorm_not_transformed(self):
"""Test that input_layernorm (before q/k/v projections) is NOT replaced.
When RMSNorm feeds INTO q/k/v projections (like Llama's input_layernorm):
- torch_rmsnorm should stay as torch_rmsnorm
- No sharded_rmsnorm should be added
- Even though weight shape matches [hidden_size], it's not a QK norm
This tests the fix for the bug where input_layernorm was incorrectly
detected as a QK norm because its weight shape matched q_proj output dim.
"""
model = AttentionWithInputLayernorm().to("cuda", dtype=torch.float16)
x = torch.randn(1, 8, 64, device="cuda", dtype=torch.float16)
# Export to graph
gm = torch_export_to_gm(model, args=(x,), clone=True)
# Check before transform
before_rmsnorm = count_ops(gm, torch.ops.auto_deploy.torch_rmsnorm.default)
assert before_rmsnorm == 1, f"Expected 1 torch_rmsnorm before, got {before_rmsnorm}"
# Apply sharding transform with world_size=4 (like the failing Llama test)
optimizer = InferenceOptimizer(
None,
{
"detect_sharding": {
"stage": "sharding",
"simple_shard_only": False,
"sharding_source": ["manual", "factory", "heuristic"],
"support_partial_config": True,
"sharding_dims": ["tp", "ep", "bmm"],
"shard_all_unprocessed": True,
"allreduce_strategy": "NCCL",
"dist_backend": "auto",
"requires_shape_prop": True,
},
"sharding_transform_executor": {
"stage": "sharding",
"run_shape_prop": True,
},
},
)
optimizer.shared_config.local_rank = 0
optimizer.shared_config.world_size = 4
gm_transformed = optimizer(None, gm)
# Check after transform
after_rmsnorm = count_ops(gm_transformed, torch.ops.auto_deploy.torch_rmsnorm.default)
after_sharded = count_ops(gm_transformed, torch.ops.auto_deploy.sharded_rmsnorm.default)
# Input layernorm should NOT be transformed to sharded_rmsnorm
# because it feeds INTO q/k/v projections, not after them
assert after_sharded == 0, (
f"Expected 0 sharded_rmsnorm for input_layernorm, got {after_sharded}. "
f"input_layernorm should not be detected as a QK norm."
)
print(f"After transform: {after_rmsnorm} torch_rmsnorm, {after_sharded} sharded_rmsnorm")

View File

@@ -0,0 +1,138 @@
"""Testing module patches that enable export of MiniMax-M2 model.
This test verifies that the patched MiniMaxM2SparseMoeBlock forward function
produces identical outputs to the original HuggingFace implementation.
"""
import types
import pytest
import torch
from test_common.llm_data import hf_id_to_local_model_dir
from transformers import AutoConfig, AutoModelForCausalLM
# Import custom_ops to register torch.ops.auto_deploy.torch_moe
import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
from tensorrt_llm._torch.auto_deploy.models.patches.minimax_m2 import minimax_m2_moe
def _load_minimax_m2_moe_layer(model_name_or_path):
"""
Loads the MoE layer from MiniMax-M2 model with a minimal configuration.
We create a small model to keep tests fast while still exercising the
MoE routing and computation logic.
Parameters:
model_name_or_path (str): Path or name of the pretrained model.
Returns:
module: The MiniMaxM2SparseMoeBlock layer.
"""
try:
# Load only the model configuration
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
# Configure minimal model for fast testing
config.num_hidden_layers = 1
config.use_cache = False
config.hidden_size = 16 # Small hidden size
config.intermediate_size = 8 # For MLP within experts
config.mlp_intermediate_size = 32
config.num_local_experts = 4 # Small number of experts
config.num_experts_per_tok = 2 # Top-k experts
config.num_attention_heads = 2
config.num_key_value_heads = 2
config.router_jitter_noise = 0.0 # Disable jitter for deterministic tests
# Build the model architecture (no weights loaded)
# Note: Importing minimax_m2 module auto-patches from_config, so the
# instance's forward is already patched. But the CLASS method is still
# the original HF implementation, which we use as reference.
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
model.eval()
# Access the MoE layer
layer_name = "model.layers.0.block_sparse_moe"
module = dict(model.named_modules()).get(layer_name)
if module is None:
print(f"Layer '{layer_name}' not found in the model.")
else:
print(f"Successfully extracted layer '{layer_name}'.")
return module
except Exception as e:
print(f"Error extracting layer: {e}")
return None
@pytest.mark.parametrize(
"model_name",
[
pytest.param(
hf_id_to_local_model_dir("MiniMaxAI/MiniMax-M2"),
),
],
)
def test_minimax_m2_moe_patch(model_name):
"""
Test that the patched MiniMaxM2SparseMoeBlock forward produces the same
output as the original HuggingFace implementation.
The patch rewrites the forward to use torch.ops.auto_deploy.torch_moe
for torch.export compatibility while maintaining numerical equivalence.
Since importing minimax_m2.py auto-patches module instances, we use the
CLASS method (type(module).forward) as the original HF reference.
"""
# Set seed for reproducibility
torch.manual_seed(42)
# Get MoE module (instance is already patched by import side-effect)
module = _load_minimax_m2_moe_layer(model_name)
assert module is not None, "Failed to load MiniMax-M2 MoE layer"
# Convert module to bfloat16 to match input dtype
module = module.to(torch.bfloat16)
# Create test input - same input will be used for both original and patched
# hidden_size=16 matches the config in _load_minimax_m2_moe_layer
hidden_size = 16
inputs = torch.randn(2, 6, hidden_size, dtype=torch.bfloat16)
# The CLASS method is still the original HuggingFace implementation
# (the auto-patch only patches instance methods, not the class)
original_class_forward = type(module).forward
# Generate reference output using original HF class method
# Uses: same module weights, same input tensor
with torch.no_grad():
ref_output, ref_router_logits = original_class_forward(module, inputs)
# The instance forward is already patched by the import side-effect,
# but let's be explicit and apply our patch function directly
module.forward = types.MethodType(minimax_m2_moe, module)
# Generate test output using patched implementation
# Uses: same module weights, same input tensor
with torch.no_grad():
test_output, test_router_logits = module(inputs)
# Verify outputs match
# Router logits should be identical (same computation path)
torch.testing.assert_close(
ref_router_logits,
test_router_logits,
atol=1e-5,
rtol=1e-5,
msg="Router logits mismatch between original and patched MoE",
)
# Final hidden states should be very close
# (small tolerance for different computation order in torch_moe)
torch.testing.assert_close(
ref_output,
test_output,
atol=1e-3,
rtol=1e-3,
msg="Output mismatch between original and patched MoE",
)