diff --git a/tensorrt_llm/_torch/models/modeling_minimaxm2.py b/tensorrt_llm/_torch/models/modeling_minimaxm2.py
index 88885e5bfd..4715ce888d 100644
--- a/tensorrt_llm/_torch/models/modeling_minimaxm2.py
+++ b/tensorrt_llm/_torch/models/modeling_minimaxm2.py
@@ -169,28 +169,31 @@ class MiniMaxM2Attention(Attention):
         self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
 
     def apply_qk_norm(self, q, k):
-        # collect q and k from all gpus
-        from ..distributed import allgather
+        if self.qkv_proj.mapping.tp_size > 1:
+            # collect q and k from all gpus
+            from ..distributed import allgather
 
-        temp_q = allgather(q, self.qkv_proj.mapping)
-        temp_k = allgather(k, self.qkv_proj.mapping)
-        temp_q = self.q_norm(temp_q)
-        temp_k = self.k_norm(temp_k)
-        # temp_q, temp_k = maybe_execute_in_parallel(
-        #     self.q_norm(temp_q),
-        #     self.k_norm(temp_k),
-        #     self.ln_events[0],
-        #     self.ln_events[1],
-        #     self.aux_stream,
-        # )
-        # split q and k to each gpus
-        # Fixme: tp_size may not be equal to the world size of current mapping
-        q = temp_q.reshape(-1, self.tp_size, self.q_size)[:, self.tp_rank, :].reshape(
-            -1, self.q_size
-        )
-        k = temp_k.reshape(-1, self.tp_size, self.kv_size)[:, self.tp_rank, :].reshape(
-            -1, self.kv_size
-        )
+            temp_q = allgather(q, self.qkv_proj.mapping)
+            temp_k = allgather(k, self.qkv_proj.mapping)
+            temp_q = self.q_norm(temp_q)
+            temp_k = self.k_norm(temp_k)
+            # temp_q, temp_k = maybe_execute_in_parallel(
+            #     self.q_norm(temp_q),
+            #     self.k_norm(temp_k),
+            #     self.ln_events[0],
+            #     self.ln_events[1],
+            # )
+            # split q and k to each gpus
+            # Fixme: tp_size may not be equal to the world size of current mapping
+            q = temp_q.reshape(-1, self.tp_size, self.q_size)[:, self.tp_rank, :].reshape(
+                -1, self.q_size
+            )
+            k = temp_k.reshape(-1, self.tp_size, self.kv_size)[:, self.tp_rank, :].reshape(
+                -1, self.kv_size
+            )
+        else:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
         return q, k
 
 
diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml
index 22be00114b..96a9ef6b94 100644
--- a/tests/integration/defs/accuracy/references/gsm8k.yaml
+++ b/tests/integration/defs/accuracy/references/gsm8k.yaml
@@ -314,6 +314,6 @@ nvidia/Nemotron-3-Nano:
     kv_cache_quant_algo: FP8
     accuracy: 68.73
 MiniMaxAI/MiniMax-M2:
-  - accuracy: 20
+  - accuracy: 85
   - quant_algo: FP8_BLOCK_SCALES
-    accuracy: 20
+    accuracy: 85