chore: apply code review feedback

Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com>
Mingyang Jiang 2026-01-12 16:26:29 +08:00
parent 1e6e0582b3
commit 5704619137
2 changed files with 4 additions and 19 deletions

View File

@@ -71,6 +71,7 @@ class MiniMaxM2MoE(nn.Module):
                 num_experts=self.num_experts,
                 callable_e_score_correction_bias=lambda: self.e_score_correction_bias,
             ),
+            num_experts=self.num_experts,
             aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream},
             reduce_results=reduce_results,
             model_config=model_config,
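For context: the added num_experts kwarg means the fused MoE module now receives the expert count directly, not only via the routing method, and the lambda defers reading e_score_correction_bias until it is actually needed (e.g. after weight loading). A minimal sketch of that pattern, using hypothetical MockRoutingMethod/MockFusedMoE stand-ins rather than the real TensorRT-LLM classes:

    from typing import Callable, List, Optional

    class MockRoutingMethod:
        # Stand-in for the routing method: keeps the bias as a callable so
        # the tensor can be created or loaded later, after construction.
        def __init__(self, num_experts: int,
                     callable_e_score_correction_bias: Callable[[], Optional[List[float]]]):
            self.num_experts = num_experts
            self.get_bias = callable_e_score_correction_bias

    class MockFusedMoE:
        # Stand-in for the MoE module: needs num_experts itself (e.g. to
        # size per-expert weights), independent of the routing method's copy.
        def __init__(self, routing_method: MockRoutingMethod, num_experts: int):
            assert num_experts == routing_method.num_experts
            self.num_experts = num_experts
            self.routing_method = routing_method

    state = {"bias": None}  # stands in for self.e_score_correction_bias
    moe = MockFusedMoE(
        routing_method=MockRoutingMethod(8, lambda: state["bias"]),
        num_experts=8,
    )
    state["bias"] = [0.0] * 8             # "weights loaded"
    print(moe.routing_method.get_bias())  # bias is only read now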
@@ -115,24 +116,13 @@ class MiniMaxM2Attention(Attention):
         self,
         *,
         model_config: ModelConfig[PretrainedConfig],
-        skip_rope: bool = False,
-        fuse_qk_norm_rope: bool = False,
         layer_idx: Optional[int] = None,
         is_qk_norm: bool = True,
     ):
         config = model_config.pretrained_config
         self.pretrained_config = config

-        self.fuse_qk_norm_rope = fuse_qk_norm_rope
-        self.skip_rope = skip_rope
-
-        # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb
-        # will be skipped in the overridden apply_rope.
-        rope_fusion = not self.fuse_qk_norm_rope and not skip_rope
         self.is_qk_norm = is_qk_norm
-        assert not (fuse_qk_norm_rope and skip_rope), (
-            "Fusing qk norm and skipping rope is not supported"
-        )

         super().__init__(
             hidden_size=config.hidden_size,
@@ -144,7 +134,7 @@ class MiniMaxM2Attention(Attention):
                 type=PositionEmbeddingType.rope_gpt_neox,
                 rope=RopeParams.from_config(config),
             ),
-            rope_fusion=rope_fusion,
+            rope_fusion=True,
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
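Both removed flags defaulted to False, so rope_fusion was already True on the default path; hardcoding it just drops the unused configurability. With rope_fusion=True, GPT-NeoX-style RoPE is applied inside the fused attention op rather than in an overridden apply_rope. A minimal unfused sketch of what that rotation computes (illustrative shapes and naming, not the TensorRT-LLM kernel interface):

    import torch

    def rope_gpt_neox(x: torch.Tensor, position_ids: torch.Tensor,
                      theta: float = 10000.0) -> torch.Tensor:
        # x: [batch, seq, heads, head_dim]; NeoX style rotates the two
        # halves of head_dim (not interleaved even/odd pairs).
        half = x.shape[-1] // 2
        inv_freq = theta ** (-torch.arange(0, half, dtype=torch.float32) / half)
        ang = position_ids[..., None].float() * inv_freq       # [batch, seq, half]
        cos, sin = ang.cos()[..., None, :], ang.sin()[..., None, :]
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat([x1 * cos - x2 * sin,
                          x2 * cos + x1 * sin], dim=-1).to(x.dtype)

    q = torch.randn(1, 4, 2, 8)
    pos = torch.arange(4)[None, :]
    print(rope_gpt_neox(q, pos).shape)  # torch.Size([1, 4, 2, 8])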
@@ -160,8 +150,6 @@ class MiniMaxM2Attention(Attention):
             eps=config.rms_norm_eps,
             dtype=config.torch_dtype,
         )
-        self.aux_stream = torch.cuda.Stream()
-        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

     def apply_qk_norm(self, q, k):
         if self.qkv_proj.mapping.tp_size > 1:
@@ -201,10 +189,7 @@
         )
         q, k, v = self.split_qkv(q, k, v)
         q, k = self.apply_qk_norm(q, k)
-        if not self.skip_rope:
-            return super().apply_rope(q, k, v, position_ids)
-        else:
-            return q, k, v
+        return super().apply_rope(q, k, v, position_ids)


 class MiniMaxM2DecoderLayer(DecoderLayer):
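The removed aux_stream and ln_events suggest the q and k norms were previously overlapped on a side CUDA stream; after this change they run sequentially on the default stream. Conceptually, apply_qk_norm is an RMSNorm applied per attention head over head_dim, and the tp_size > 1 branch implies each rank normalizes only its local shard of heads. A hedged, self-contained sketch of that per-head norm, with illustrative names:

    import torch

    def per_head_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                          head_dim: int, eps: float = 1e-6) -> torch.Tensor:
        # x: [tokens, local_heads * head_dim]; normalize each head slice
        # independently by reshaping to one row per head instance.
        orig_shape = x.shape
        x = x.reshape(-1, head_dim)
        var = x.float().pow(2).mean(dim=-1, keepdim=True)
        x = (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
        return x.reshape(orig_shape)

    q = torch.randn(3, 4 * 64)  # 3 tokens, 4 local heads of dim 64
    w = torch.ones(64)
    print(per_head_rms_norm(q, w, head_dim=64).shape)  # torch.Size([3, 256])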

View File

@@ -72,7 +72,7 @@ l0_gb200_multi_gpus:
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
-  - accuracy/test_llm_api_pytorch.py::TestMiniMaxM2::test_4gpus[attention_dp=False-cuda_graph=True-overlap_scheduler=True-tp_size=4-ep_size=4]
+  - accuracy/test_llm_api_pytorch.py::TestMiniMaxM2::test_4gpus[attention_dp=False-cuda_graph=True-overlap_scheduler=True-tp_size=4-ep_size=4] TIMEOUT (90)
   - accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] TIMEOUT (90)
 - condition:
     ranges: