[None][feat] MiniMax M2 support (#10532)

Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com>
jmydurant 2026-01-14 17:38:58 +08:00 committed by GitHub
parent 052c36ddd2
commit e7882d5c74
8 changed files with 417 additions and 3 deletions


@@ -78,8 +78,10 @@ enum class RoutingMethodType : int64_t
Llama4 = 3,
// RenormalizeNaive: Softmax -> TopK -> Renormalize
RenormalizeNaive = 4,
// MiniMaxM2: Sigmoid -> RoutingBiasAdd -> TopK -> Renormalize(without bias)
MiniMax2 = 5,
// Unspecified
Unspecified = 5,
Unspecified = 6,
};
inline int32_t maybeGetMinTokenCount(int32_t numPaddedTokens, int32_t hiddenSize, int32_t dtypeSizeBits)
@@ -98,6 +100,7 @@ inline std::string serializeMoeRoutingMethodType(RoutingMethodType routingMethod
case RoutingMethodType::DeepSeekV3: return "DeepSeekV3";
case RoutingMethodType::Llama4: return "Llama4";
case RoutingMethodType::RenormalizeNaive: return "RenormalizeNaive";
case RoutingMethodType::MiniMax2: return "MiniMax2";
default: TLLM_CHECK_WITH_INFO(false, "Invalid routing method"); return "";
};
}


@@ -15,6 +15,7 @@ from .modeling_hunyuan_moe import HunYuanMoEV1ForCausalLM
from .modeling_hyperclovax import HCXVisionForCausalLM
from .modeling_llama import LlamaForCausalLM
from .modeling_llava_next import LlavaNextModel
from .modeling_minimaxm2 import MiniMaxM2ForCausalLM
from .modeling_mistral import Mistral3VLM, MistralForCausalLM
from .modeling_mixtral import MixtralForCausalLM
from .modeling_nemotron import NemotronForCausalLM
@@ -80,6 +81,7 @@ __all__ = [
"SeedOssForCausalLM",
"Glm4MoeForCausalLM",
"Qwen3VLModel",
"MiniMaxM2ForCausalLM",
]
if transformers.__version__ >= "4.45.1":


@@ -0,0 +1,314 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional
import torch
from torch import nn
from transformers import PretrainedConfig
from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..models.modeling_utils import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import MiniMaxM2MoeRoutingMethod, create_moe
from ..modules.linear import Linear
from ..modules.rms_norm import RMSNorm
from ..utils import AuxStreamType
from .modeling_utils import DecoderModel, DecoderModelForCausalLM, register_auto_model
# MiniMax M2/M2.1 requires two additional components:
# 1. MoE routing method: TRT-LLM does not otherwise support the
#    sigmoid -> add bias -> topk -> renormalize routing pipeline.
# 2. QK layer normalization needs to be performed across the head_num * head_size dimension,
#    which conflicts with the current TP-mode attention logic.
# For better performance, we suggest enabling attention DP when using the MiniMax M2/M2.1 models.
class MiniMaxM2MoE(nn.Module):
def __init__(
self,
model_config: ModelConfig[PretrainedConfig],
aux_stream: torch.cuda.Stream,
layer_idx: Optional[int] = None,
):
super().__init__()
config = model_config.pretrained_config
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
self.enable_attention_dp = model_config.mapping.enable_attention_dp
# The MoE gate (linear layer) only runs in half/full precision for now.
self.gate = Linear(
self.hidden_dim, self.num_experts, bias=False, dtype=torch.float32, quant_config=None
)
self.e_score_correction_bias = nn.Parameter(
torch.empty((self.num_experts), dtype=torch.float32), requires_grad=False
)
reduce_results = True
self.experts = create_moe(
routing_method=MiniMaxM2MoeRoutingMethod(
top_k=self.top_k,
num_experts=self.num_experts,
callable_e_score_correction_bias=lambda: self.e_score_correction_bias,
),
num_experts=self.num_experts,
aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream},
reduce_results=reduce_results,
model_config=model_config,
layer_idx=layer_idx,
)
def load_weights(self, weights: List[Dict]):
assert len(weights) == 1
self.e_score_correction_bias.copy_(
weights[0]["e_score_correction_bias"][:].to(self.e_score_correction_bias.dtype)
)
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
hidden_states_f32 = hidden_states.to(torch.float32)
router_logits = self.gate(hidden_states_f32)
final_hidden_states = self.experts(
hidden_states,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=False,
)
return final_hidden_states
# Implementing this special QK norm is a bit tricky: the RMSNorm dimension is
# head_num * head_size rather than head_size, while after the column-parallel QKV
# linear each rank only holds head_num * head_size / tp_size elements.
# There are two strategies to implement QK-norm attention:
# 1. Make the first linear layer non-column-parallel, so every rank holds the full QKV
#    and the normal RMSNorm can be applied directly.
# 2. Keep the column-parallel linear layer, allgather Q and K from all ranks, apply
#    RMSNorm on the full tensors, then split them back to each rank and continue.
# For better performance, we choose the second strategy here.
# Most adaptations are from QKNormRoPEAttention.
class MiniMaxM2Attention(Attention):
def __init__(
self,
*,
model_config: ModelConfig[PretrainedConfig],
layer_idx: Optional[int] = None,
):
config = model_config.pretrained_config
self.pretrained_config = config
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=False,
pos_embd_params=PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gpt_neox,
rope=RopeParams.from_config(config),
),
rope_fusion=True,
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
)
self.q_norm = RMSNorm(
hidden_size=self.q_size * self.tp_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)
self.k_norm = RMSNorm(
hidden_size=self.kv_size * self.tp_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)
def apply_qk_norm(self, q, k):
if self.qkv_proj.mapping.tp_size > 1:
# collect q and k from all gpus
from ..distributed import allgather
temp_q = allgather(q, self.qkv_proj.mapping)
temp_k = allgather(k, self.qkv_proj.mapping)
temp_q = self.q_norm(temp_q)
temp_k = self.k_norm(temp_k)
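# Slice this rank's Q/K shard back out of the gathered, normalized tensors.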
q = temp_q.reshape(-1, self.tp_size, self.q_size)[:, self.tp_rank, :].reshape(
-1, self.q_size
)
k = temp_k.reshape(-1, self.tp_size, self.kv_size)[:, self.tp_rank, :].reshape(
-1, self.kv_size
)
else:
q = self.q_norm(q)
k = self.k_norm(k)
return q, k
def apply_rope(
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
v: Optional[torch.Tensor],
position_ids: torch.Tensor,
):
"""
This override applies QK norm before RoPE. It is called from the forward
method of the base Attention class.
"""
# Apply QK norm before RoPE.
q, k, v = self.split_qkv(q, k, v)
q, k = self.apply_qk_norm(q, k)
return super().apply_rope(q, k, v, position_ids)
class MiniMaxM2DecoderLayer(DecoderLayer):
def __init__(
self,
model_config: ModelConfig[PretrainedConfig],
layer_idx: int,
aux_stream: torch.cuda.Stream,
):
super().__init__()
config = model_config.pretrained_config
self.hidden_size = config.hidden_size
self.self_attn = MiniMaxM2Attention(model_config=model_config, layer_idx=layer_idx)
self.block_sparse_moe = MiniMaxM2MoE(
model_config=model_config, aux_stream=aux_stream, layer_idx=layer_idx
)
self.input_layernorm = RMSNorm(
hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
)
self.post_attention_layernorm = RMSNorm(
hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
)
self.mapping = model_config.mapping
self.layer_idx = layer_idx
def forward(
self,
position_ids: torch.IntTensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
**kwargs,
) -> torch.Tensor:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
**kwargs,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.block_sparse_moe(hidden_states, attn_metadata)
return hidden_states, residual
class MiniMaxM2Model(DecoderModel):
def __init__(self, model_config: ModelConfig[PretrainedConfig]):
super().__init__(model_config)
# Needed for KV cache initialization (when bf16 is used for the KV cache).
quant_config = model_config.quant_config
if quant_config is None or (
(not quant_config.quant_mode.has_fp8_kv_cache())
and (not quant_config.quant_mode.has_fp4_kv_cache())
):
model_config.pretrained_config.torch_dtype = torch.bfloat16
config = model_config.pretrained_config
self.vocab_size = config.vocab_size
self.aux_stream = torch.cuda.Stream()
self.embed_tokens = Embedding(
config.vocab_size,
config.hidden_size,
dtype=config.torch_dtype,
enable_torch_compile_for_embedding=model_config.enable_torch_compile_for_embedding,
)
self.layers = nn.ModuleList(
[
MiniMaxM2DecoderLayer(model_config, layer_idx, self.aux_stream)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(
hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
)
def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.IntTensor] = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
residual = None
for decoder_layer in self.layers:
hidden_states, residual = decoder_layer(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@register_auto_model("MiniMaxM2ForCausalLM")
class MiniMaxM2ForCausalLM(DecoderModelForCausalLM[MiniMaxM2Model, PretrainedConfig]):
def __init__(self, model_config: ModelConfig[PretrainedConfig]):
super().__init__(
MiniMaxM2Model(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size,
)

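A standalone sketch of strategy 2 described in the comments above MiniMaxM2Attention: allgather the column-parallel Q shards, apply RMSNorm over the full head_num * head_size dimension, then slice each rank's shard back out. The sketch simulates the TP ranks in a single process; tp_size, the toy shapes, and the unweighted rms_norm helper are illustrative assumptions rather than values taken from the model.

# Single-process illustration of allgather -> full-dim RMSNorm -> per-rank slice.
# All sizes are toy values; no distributed runtime is involved.
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm over the last dimension (no learned weight, for brevity).
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

num_tokens, num_heads, head_size, tp_size = 4, 8, 16, 2
q_size = num_heads * head_size // tp_size            # per-rank shard width

# Full (unsharded) Q activations and the column-parallel shards each rank would hold.
q_full = torch.randn(num_tokens, num_heads * head_size)
q_shards = list(q_full.split(q_size, dim=-1))

# Reference: normalize the full tensor, then take rank 0's shard.
expected = rms_norm(q_full)[:, :q_size]

# Strategy 2: "allgather" the shards (simulated by concatenation), normalize over the
# full dimension, then slice the local shard back out, mirroring apply_qk_norm.
gathered = torch.cat(q_shards, dim=-1)
normalized = rms_norm(gathered)
tp_rank = 0
local = normalized.reshape(-1, tp_size, q_size)[:, tp_rank, :].reshape(-1, q_size)

torch.testing.assert_close(local, expected)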

@@ -12,7 +12,8 @@ from .quantization import FusedMoEQuantScalesFP8
from .routing import (BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod,
DefaultMoeRoutingMethod,
Llama4RenormalizeMoeRoutingMethod,
LoadBalancedMoeRoutingMethod, RenormalizeMoeRoutingMethod,
LoadBalancedMoeRoutingMethod, MiniMaxM2MoeRoutingMethod,
RenormalizeMoeRoutingMethod,
RenormalizeNaiveMoeRoutingMethod, RoutingMethodType,
SparseMixerMoeRoutingMethod, StaticMoeRoutingMethod,
create_renormalize_expert_load_balanced_logits)
@@ -33,6 +34,7 @@ __all__ = [
"MoE",
"MoeLoadBalancer",
"MoEWeightLoadingMode",
"MiniMaxM2MoeRoutingMethod",
"RenormalizeMoeRoutingMethod",
"RenormalizeNaiveMoeRoutingMethod",
"RoutingMethodType",


@@ -155,8 +155,10 @@ class RoutingMethodType(IntEnum):
Llama4 = 3,
# Qwen3: Softmax -> TopK -> Renormalize
RenormalizeNaive = 4,
# MiniMaxM2: Sigmoid -> RoutingBiasAdd -> TopK -> Renormalize(without bias)
MiniMax2 = 5,
# Unspecified
Unspecified = 5,
Unspecified = 6,
class BaseMoeRoutingMethod(nn.Module):
@@ -379,6 +381,57 @@ class DeepSeekV3MoeRoutingMethod(BaseMoeRoutingMethod):
return RoutingMethodType.DeepSeekV3
class MiniMaxM2MoeRoutingMethod(BaseMoeRoutingMethod):
def __init__(
self,
top_k: int,
num_experts: int,
callable_e_score_correction_bias: Callable[[], torch.Tensor],
output_dtype: torch.dtype = torch.float32,
):
super().__init__()
self.top_k = top_k
self.num_experts = num_experts
assert callable(callable_e_score_correction_bias)
self.callable_e_score_correction_bias = callable_e_score_correction_bias
self.output_dtype = output_dtype
@staticmethod
@torch.compile(options={"max-autotune": True})
def get_scores(logits, e_score_correction_bias):
scores = F.sigmoid(logits)
scores_with_bias = scores + e_score_correction_bias
if enable_llm_debug():
has_nan = torch.isnan(scores_with_bias).any()
if has_nan:
warnings.warn(
"Detected NAN in the tensor scores_with_bias. Please check if it matches the expectation."
)
return scores, scores_with_bias
def apply(self,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
scores, scores_with_bias = self.get_scores(router_logits,
self.e_score_correction_bias)
_, topk_idx = torch.topk(scores_with_bias,
k=self.top_k,
dim=-1,
sorted=False)
top_k_weights = scores.gather(1, topk_idx)
top_k_weights /= (top_k_weights.sum(dim=-1, keepdim=True) + 1e-20)
return topk_idx.to(torch.int32), top_k_weights.to(self.output_dtype)
@property
def e_score_correction_bias(self) -> torch.Tensor:
return self.callable_e_score_correction_bias()
@property
def routing_method_type(self):
return RoutingMethodType.MiniMax2
class RenormalizeMoeRoutingMethod(BaseMoeRoutingMethod):
def __init__(
@@ -587,6 +640,8 @@ ROUTING_METHOD_TYPE_TO_CLASS: Dict[RoutingMethodType,
RenormalizeNaiveMoeRoutingMethod,
RoutingMethodType.Unspecified:
BaseMoeRoutingMethod,
RoutingMethodType.MiniMax2:
MiniMaxM2MoeRoutingMethod,
}

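For reference, the MiniMax2 routing pipeline added above (Sigmoid -> RoutingBiasAdd -> TopK -> Renormalize without bias) can be restated as a standalone PyTorch sketch. The toy sizes and random inputs are assumptions; the code mirrors MiniMaxM2MoeRoutingMethod.apply rather than calling into TRT-LLM.

# Standalone sketch of the MiniMax M2 routing pipeline:
# sigmoid -> add correction bias -> top-k on the biased scores ->
# renormalize the unbiased scores of the selected experts.
import torch

num_tokens, num_experts, top_k = 3, 8, 2
router_logits = torch.randn(num_tokens, num_experts)
e_score_correction_bias = torch.randn(num_experts)

scores = torch.sigmoid(router_logits)
scores_with_bias = scores + e_score_correction_bias

# Experts are selected on the biased scores...
_, topk_idx = torch.topk(scores_with_bias, k=top_k, dim=-1, sorted=False)

# ...but the routing weights come from the unbiased scores, renormalized to sum to 1.
top_k_weights = scores.gather(1, topk_idx)
top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-20)

print(topk_idx.to(torch.int32))   # expert indices, shape [num_tokens, top_k]
print(top_k_weights)              # routing weights, each row sums to ~1.0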

@@ -313,3 +313,7 @@ nvidia/Nemotron-3-Nano:
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 68.73
MiniMaxAI/MiniMax-M2:
- accuracy: 85
- quant_algo: FP8_BLOCK_SCALES
accuracy: 85


@@ -5350,3 +5350,35 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm,
extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
@skip_pre_hopper
class TestMiniMaxM2(LlmapiAccuracyTestHarness):
MODEL_NAME = "MiniMaxAI/MiniMax-M2"
MODEL_PATH = f"{llm_models_root()}/MiniMax-M2"
@parametrize_with_ids("tp_size,ep_size", [(4, 4)])
@pytest.mark.skip_less_device(4)
@parametrize_with_ids("attention_dp,cuda_graph,overlap_scheduler",
[(False, True, True), (True, True, True)])
def test_4gpus(self, tp_size, ep_size, attention_dp, cuda_graph,
overlap_scheduler):
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
moe_config=MoeConfig(
backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"))
with LLM(self.MODEL_PATH,
tensor_parallel_size=tp_size,
pipeline_parallel_size=1,
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
max_seq_len=4096,
**pytorch_config,
enable_attention_dp=attention_dp) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

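A minimal LLM-API invocation in the spirit of the accuracy test above might look like the sketch below. The checkpoint path is a placeholder, the llmapi import locations are assumed, and attention DP is enabled as the modeling code recommends.

# Hedged sketch: serve MiniMax-M2 through the LLM API with a configuration similar
# to TestMiniMaxM2.test_4gpus. Path and import locations are assumptions.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, MoeConfig

with LLM("/path/to/MiniMax-M2",                    # placeholder checkpoint path
         tensor_parallel_size=4,
         moe_expert_parallel_size=4,
         enable_attention_dp=True,                 # recommended for MiniMax M2/M2.1
         kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.6),
         cuda_graph_config=CudaGraphConfig(),
         moe_config=MoeConfig(backend="CUTLASS"),  # DEEPGEMM on SM100+ per the test
         max_seq_len=4096) as llm:
    outputs = llm.generate(["Hello from MiniMax M2!"])
    print(outputs[0].outputs[0].text)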

@@ -31,6 +31,8 @@ l0_dgx_b200:
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp_tp4]
- accuracy/test_llm_api_pytorch.py::TestMiniMaxM2::test_4gpus[attention_dp=False-cuda_graph=True-overlap_scheduler=True-tp_size=4-ep_size=4] TIMEOUT (60)
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
- condition: