diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h
index 987b953ee3..31eba5bb8d 100644
--- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h
+++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h
@@ -78,8 +78,10 @@ enum class RoutingMethodType : int64_t
     Llama4 = 3,
     // RenormalizeNaive: Softmax -> TopK -> Renormalize
    RenormalizeNaive = 4,
+    // MiniMaxM2: Sigmoid -> RoutingBiasAdd -> TopK -> Renormalize (without bias)
+    MiniMax2 = 5,
     // Unspecified
-    Unspecified = 5,
+    Unspecified = 6,
 };

 inline int32_t maybeGetMinTokenCount(int32_t numPaddedTokens, int32_t hiddenSize, int32_t dtypeSizeBits)
@@ -98,6 +100,7 @@ inline std::string serializeMoeRoutingMethodType(RoutingMethodType routingMethod
     case RoutingMethodType::DeepSeekV3: return "DeepSeekV3";
     case RoutingMethodType::Llama4: return "Llama4";
     case RoutingMethodType::RenormalizeNaive: return "RenormalizeNaive";
+    case RoutingMethodType::MiniMax2: return "MiniMax2";
     default: TLLM_CHECK_WITH_INFO(false, "Invalid routing method"); return "";
     };
 }
diff --git a/tensorrt_llm/_torch/models/__init__.py b/tensorrt_llm/_torch/models/__init__.py
index 7b8718c10b..c56bf86faf 100644
--- a/tensorrt_llm/_torch/models/__init__.py
+++ b/tensorrt_llm/_torch/models/__init__.py
@@ -15,6 +15,7 @@ from .modeling_hunyuan_moe import HunYuanMoEV1ForCausalLM
 from .modeling_hyperclovax import HCXVisionForCausalLM
 from .modeling_llama import LlamaForCausalLM
 from .modeling_llava_next import LlavaNextModel
+from .modeling_minimaxm2 import MiniMaxM2ForCausalLM
 from .modeling_mistral import Mistral3VLM, MistralForCausalLM
 from .modeling_mixtral import MixtralForCausalLM
 from .modeling_nemotron import NemotronForCausalLM
@@ -80,6 +81,7 @@ __all__ = [
     "SeedOssForCausalLM",
     "Glm4MoeForCausalLM",
     "Qwen3VLModel",
+    "MiniMaxM2ForCausalLM",
 ]

 if transformers.__version__ >= "4.45.1":
diff --git a/tensorrt_llm/_torch/models/modeling_minimaxm2.py b/tensorrt_llm/_torch/models/modeling_minimaxm2.py
new file mode 100644
index 0000000000..73cd480ee7
--- /dev/null
+++ b/tensorrt_llm/_torch/models/modeling_minimaxm2.py
@@ -0,0 +1,314 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Optional
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from tensorrt_llm.functional import PositionEmbeddingType
+
+from ..attention_backend import AttentionMetadata
+from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
+from ..models.modeling_utils import ModelConfig
+from ..modules.attention import Attention
+from ..modules.decoder_layer import DecoderLayer
+from ..modules.embedding import Embedding
+from ..modules.fused_moe import MiniMaxM2MoeRoutingMethod, create_moe
+from ..modules.linear import Linear
+from ..modules.rms_norm import RMSNorm
+from ..utils import AuxStreamType
+from .modeling_utils import DecoderModel, DecoderModelForCausalLM, register_auto_model
+
+
+# MiniMax M2/M2.1 requires two additional components:
+# 1. MoE routing method: TRT-LLM currently does not support the
+#    sigmoid -> add bias -> topk -> renorm routing scheme.
+# 2. QK layer normalization performed across the head_num * head_size dimension,
+#    which conflicts with the current TP-mode attention logic.
+# For better performance, we suggest enabling attention DP when using the MiniMax M2/M2.1 models.
+class MiniMaxM2MoE(nn.Module):
+    def __init__(
+        self,
+        model_config: ModelConfig[PretrainedConfig],
+        aux_stream: torch.cuda.Stream,
+        layer_idx: Optional[int] = None,
+    ):
+        super().__init__()
+        config = model_config.pretrained_config
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+        self.enable_attention_dp = model_config.mapping.enable_attention_dp
+
+        # The MoE gate (linear layer) only runs in half/full precision for now.
+        self.gate = Linear(
+            self.hidden_dim, self.num_experts, bias=False, dtype=torch.float32, quant_config=None
+        )
+
+        self.e_score_correction_bias = nn.Parameter(
+            torch.empty((self.num_experts), dtype=torch.float32), requires_grad=False
+        )
+
+        reduce_results = True
+        self.experts = create_moe(
+            routing_method=MiniMaxM2MoeRoutingMethod(
+                top_k=self.top_k,
+                num_experts=self.num_experts,
+                callable_e_score_correction_bias=lambda: self.e_score_correction_bias,
+            ),
+            num_experts=self.num_experts,
+            aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream},
+            reduce_results=reduce_results,
+            model_config=model_config,
+            layer_idx=layer_idx,
+        )
+
+    def load_weights(self, weights: List[Dict]):
+        assert len(weights) == 1
+
+        self.e_score_correction_bias.copy_(
+            weights[0]["e_score_correction_bias"][:].to(self.e_score_correction_bias.dtype)
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        all_rank_num_tokens = attn_metadata.all_rank_num_tokens
+        hidden_states_f32 = hidden_states.to(torch.float32)
+        router_logits = self.gate(hidden_states_f32)
+        final_hidden_states = self.experts(
+            hidden_states,
+            router_logits,
+            all_rank_num_tokens=all_rank_num_tokens,
+            use_dp_padding=False,
+        )
+        return final_hidden_states
+
+
+# Implementing the special QK norm is a little tricky: the RMSNorm dimension is
+# head_num * head_size rather than head_size, while after the column-parallel QKV linear
+# each rank only holds head_num * head_size / tp_size elements.
+# There are two strategies to implement QK-norm attention:
+# 1. Make the first linear layer non-column-parallel so the regular RMSNorm can be used;
+#    every rank then works on the full QKV.
+# 2. Keep the column-parallel linear layer, use allgather to collect Q and K from all ranks,
+#    apply RMSNorm on the gathered tensors, and then split them back so each rank continues
+#    with its own slice.
+# For better performance, we choose the second strategy here.
+# Most adaptations are taken from QKNormRoPEAttention.
+class MiniMaxM2Attention(Attention):
+    def __init__(
+        self,
+        *,
+        model_config: ModelConfig[PretrainedConfig],
+        layer_idx: Optional[int] = None,
+    ):
+        config = model_config.pretrained_config
+        self.pretrained_config = config
+
+        super().__init__(
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            num_key_value_heads=config.num_key_value_heads,
+            max_position_embeddings=config.max_position_embeddings,
+            bias=False,
+            pos_embd_params=PositionalEmbeddingParams(
+                type=PositionEmbeddingType.rope_gpt_neox,
+                rope=RopeParams.from_config(config),
+            ),
+            rope_fusion=True,
+            layer_idx=layer_idx,
+            dtype=config.torch_dtype,
+            config=model_config,
+        )
+
+        self.q_norm = RMSNorm(
+            hidden_size=self.q_size * self.tp_size,
+            eps=config.rms_norm_eps,
+            dtype=config.torch_dtype,
+        )
+        self.k_norm = RMSNorm(
+            hidden_size=self.kv_size * self.tp_size,
+            eps=config.rms_norm_eps,
+            dtype=config.torch_dtype,
+        )
+
+    def apply_qk_norm(self, q, k):
+        if self.qkv_proj.mapping.tp_size > 1:
+            # Collect q and k from all TP ranks.
+            from ..distributed import allgather
+
+            temp_q = allgather(q, self.qkv_proj.mapping)
+            temp_k = allgather(k, self.qkv_proj.mapping)
+            temp_q = self.q_norm(temp_q)
+            temp_k = self.k_norm(temp_k)
+            q = temp_q.reshape(-1, self.tp_size, self.q_size)[:, self.tp_rank, :].reshape(
+                -1, self.q_size
+            )
+            k = temp_k.reshape(-1, self.tp_size, self.kv_size)[:, self.tp_rank, :].reshape(
+                -1, self.kv_size
+            )
+        else:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+
+        return q, k
+
+    def apply_rope(
+        self,
+        q: torch.Tensor,
+        k: Optional[torch.Tensor],
+        v: Optional[torch.Tensor],
+        position_ids: torch.Tensor,
+    ):
+        """
+        apply_rope is called from the forward method of the Attention class.
+        It is overridden here so that QK norm is applied before RoPE.
+        """
+        # Apply QK norm before RoPE.
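+        # Split q, k, v (the QKV projection output may still be fused at this point) before the norms.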
+        q, k, v = self.split_qkv(q, k, v)
+        q, k = self.apply_qk_norm(q, k)
+        return super().apply_rope(q, k, v, position_ids)
+
+
+class MiniMaxM2DecoderLayer(DecoderLayer):
+    def __init__(
+        self,
+        model_config: ModelConfig[PretrainedConfig],
+        layer_idx: int,
+        aux_stream: torch.cuda.Stream,
+    ):
+        super().__init__()
+        config = model_config.pretrained_config
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = MiniMaxM2Attention(model_config=model_config, layer_idx=layer_idx)
+
+        self.block_sparse_moe = MiniMaxM2MoE(
+            model_config=model_config, aux_stream=aux_stream, layer_idx=layer_idx
+        )
+
+        self.input_layernorm = RMSNorm(
+            hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
+        )
+
+        self.post_attention_layernorm = RMSNorm(
+            hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
+        )
+        self.mapping = model_config.mapping
+        self.layer_idx = layer_idx
+
+    def forward(
+        self,
+        position_ids: torch.IntTensor,
+        hidden_states: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+        **kwargs,
+    ) -> torch.Tensor:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        # Self Attention
+        hidden_states = self.self_attn(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            attn_metadata=attn_metadata,
+            **kwargs,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.block_sparse_moe(hidden_states, attn_metadata)
+        return hidden_states, residual
+
+
+class MiniMaxM2Model(DecoderModel):
+    def __init__(self, model_config: ModelConfig[PretrainedConfig]):
+        super().__init__(model_config)
+        # Default to bf16 so that the KV cache is initialized in bf16 when no FP8/FP4
+        # KV-cache quantization is configured.
+        quant_config = model_config.quant_config
+        if quant_config is None or (
+            (not quant_config.quant_mode.has_fp8_kv_cache())
+            and (not quant_config.quant_mode.has_fp4_kv_cache())
+        ):
+            model_config.pretrained_config.torch_dtype = torch.bfloat16
+        config = model_config.pretrained_config
+        self.vocab_size = config.vocab_size
+        self.aux_stream = torch.cuda.Stream()
+
+        self.embed_tokens = Embedding(
+            config.vocab_size,
+            config.hidden_size,
+            dtype=config.torch_dtype,
+            enable_torch_compile_for_embedding=model_config.enable_torch_compile_for_embedding,
+        )
+
+        self.layers = nn.ModuleList(
+            [
+                MiniMaxM2DecoderLayer(model_config, layer_idx, self.aux_stream)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(
+            hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
+        )
+
+    def forward(
+        self,
+        attn_metadata: AttentionMetadata,
+        input_ids: Optional[torch.IntTensor] = None,
+        position_ids: Optional[torch.IntTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        hidden_states = inputs_embeds
+
+        residual = None
+        for decoder_layer in self.layers:
+            hidden_states, residual = decoder_layer(
+                position_ids=position_ids,
+                hidden_states=hidden_states,
+                attn_metadata=attn_metadata,
+                residual=residual,
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
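+# register_auto_model maps the HF architecture name "MiniMaxM2ForCausalLM" to this class.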
+@register_auto_model("MiniMaxM2ForCausalLM")
+class MiniMaxM2ForCausalLM(DecoderModelForCausalLM[MiniMaxM2Model, PretrainedConfig]):
+    def __init__(self, model_config: ModelConfig[PretrainedConfig]):
+        super().__init__(
+            MiniMaxM2Model(model_config),
+            config=model_config,
+            hidden_size=model_config.pretrained_config.hidden_size,
+            vocab_size=model_config.pretrained_config.vocab_size,
+        )
diff --git a/tensorrt_llm/_torch/modules/fused_moe/__init__.py b/tensorrt_llm/_torch/modules/fused_moe/__init__.py
index 053ecaa25f..51d6ba5f94 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/__init__.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/__init__.py
@@ -12,7 +12,8 @@ from .quantization import FusedMoEQuantScalesFP8
 from .routing import (BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod,
                       DefaultMoeRoutingMethod,
                       Llama4RenormalizeMoeRoutingMethod,
-                      LoadBalancedMoeRoutingMethod, RenormalizeMoeRoutingMethod,
+                      LoadBalancedMoeRoutingMethod, MiniMaxM2MoeRoutingMethod,
+                      RenormalizeMoeRoutingMethod,
                       RenormalizeNaiveMoeRoutingMethod, RoutingMethodType,
                       SparseMixerMoeRoutingMethod, StaticMoeRoutingMethod,
                       create_renormalize_expert_load_balanced_logits)
@@ -33,6 +34,7 @@ __all__ = [
     "MoE",
     "MoeLoadBalancer",
     "MoEWeightLoadingMode",
+    "MiniMaxM2MoeRoutingMethod",
     "RenormalizeMoeRoutingMethod",
     "RenormalizeNaiveMoeRoutingMethod",
     "RoutingMethodType",
diff --git a/tensorrt_llm/_torch/modules/fused_moe/routing.py b/tensorrt_llm/_torch/modules/fused_moe/routing.py
index 3927e9bd6b..be21b5716a 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/routing.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/routing.py
@@ -155,8 +155,10 @@ class RoutingMethodType(IntEnum):
     Llama4 = 3,
     # Qwen3: Softmax -> TopK -> Renormalize
     RenormalizeNaive = 4,
+    # MiniMaxM2: Sigmoid -> RoutingBiasAdd -> TopK -> Renormalize (without bias)
+    MiniMax2 = 5,
     # Unspecified
-    Unspecified = 5,
+    Unspecified = 6,


 class BaseMoeRoutingMethod(nn.Module):
@@ -379,6 +381,57 @@ class DeepSeekV3MoeRoutingMethod(BaseMoeRoutingMethod):
         return RoutingMethodType.DeepSeekV3


+class MiniMaxM2MoeRoutingMethod(BaseMoeRoutingMethod):
+
+    def __init__(
+        self,
+        top_k: int,
+        num_experts: int,
+        callable_e_score_correction_bias: Callable[[], torch.Tensor],
+        output_dtype: torch.dtype = torch.float32,
+    ):
+        super().__init__()
+        self.top_k = top_k
+        self.num_experts = num_experts
+        assert callable(callable_e_score_correction_bias)
+        self.callable_e_score_correction_bias = callable_e_score_correction_bias
+        self.output_dtype = output_dtype
+
+    @staticmethod
+    @torch.compile(options={"max-autotune": True})
+    def get_scores(logits, e_score_correction_bias):
+        scores = F.sigmoid(logits)
+        scores_with_bias = scores + e_score_correction_bias
+        if enable_llm_debug():
+            has_nan = torch.isnan(scores_with_bias).any()
+            if has_nan:
+                warnings.warn(
+                    "Detected NAN in the tensor scores_with_bias. Please check if it matches the expectation."
+                )
+
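+        # `scores` (plain sigmoid) provide the routing weights; `scores_with_bias`
+        # is used only to select the top-k experts (see apply below).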
+        return scores, scores_with_bias
+
+    def apply(self,
+              router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        scores, scores_with_bias = self.get_scores(router_logits,
+                                                   self.e_score_correction_bias)
+        _, topk_idx = torch.topk(scores_with_bias,
+                                 k=self.top_k,
+                                 dim=-1,
+                                 sorted=False)
+        top_k_weights = scores.gather(1, topk_idx)
+        top_k_weights /= (top_k_weights.sum(dim=-1, keepdim=True) + 1e-20)
+        return topk_idx.to(torch.int32), top_k_weights.to(self.output_dtype)
+
+    @property
+    def e_score_correction_bias(self) -> torch.Tensor:
+        return self.callable_e_score_correction_bias()
+
+    @property
+    def routing_method_type(self):
+        return RoutingMethodType.MiniMax2
+
+
 class RenormalizeMoeRoutingMethod(BaseMoeRoutingMethod):

     def __init__(
@@ -587,6 +640,8 @@ ROUTING_METHOD_TYPE_TO_CLASS: Dict[RoutingMethodType,
     RenormalizeNaiveMoeRoutingMethod,
     RoutingMethodType.Unspecified:
     BaseMoeRoutingMethod,
+    RoutingMethodType.MiniMax2:
+    MiniMaxM2MoeRoutingMethod,
 }
diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml
index 76cf65bcb3..96a9ef6b94 100644
--- a/tests/integration/defs/accuracy/references/gsm8k.yaml
+++ b/tests/integration/defs/accuracy/references/gsm8k.yaml
@@ -313,3 +313,7 @@ nvidia/Nemotron-3-Nano:
   - quant_algo: FP8
     kv_cache_quant_algo: FP8
     accuracy: 68.73
+MiniMaxAI/MiniMax-M2:
+  - accuracy: 85
+  - quant_algo: FP8_BLOCK_SCALES
+    accuracy: 85
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 0ca0842b00..c3d6c87435 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -5350,3 +5350,35 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):

         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm, extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
+
+
+@skip_pre_hopper
+class TestMiniMaxM2(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "MiniMaxAI/MiniMax-M2"
+    MODEL_PATH = f"{llm_models_root()}/MiniMax-M2"
+
+    @parametrize_with_ids("tp_size,ep_size", [(4, 4)])
+    @pytest.mark.skip_less_device(4)
+    @parametrize_with_ids("attention_dp,cuda_graph,overlap_scheduler",
+                          [(False, True, True), (True, True, True)])
+    def test_4gpus(self, tp_size, ep_size, attention_dp, cuda_graph,
+                   overlap_scheduler):
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
+
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"))
+
+        with LLM(self.MODEL_PATH,
+                 tensor_parallel_size=tp_size,
+                 pipeline_parallel_size=1,
+                 moe_expert_parallel_size=ep_size,
+                 kv_cache_config=kv_cache_config,
+                 max_seq_len=4096,
+                 **pytorch_config,
+                 enable_attention_dp=attention_dp) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
index d5ea5d9e17..84af0aae2b 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -31,6 +31,8 @@ l0_dgx_b200:
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp_tp4]
+  - accuracy/test_llm_api_pytorch.py::TestMiniMaxM2::test_4gpus[attention_dp=False-cuda_graph=True-overlap_scheduler=True-tp_size=4-ep_size=4] TIMEOUT (60)
+  # ------------- AutoDeploy tests ---------------
   - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
 - condition: