fix: fix bugs when enabling attention DP

Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com>
Mingyang Jiang 2026-01-06 16:09:48 +08:00
parent 6f946b8e81
commit 49cddcb27a
2 changed files with 26 additions and 23 deletions


@@ -169,28 +169,31 @@ class MiniMaxM2Attention(Attention):
         self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
     def apply_qk_norm(self, q, k):
-        # collect q and k from all gpus
-        from ..distributed import allgather
+        if self.qkv_proj.mapping.tp_size > 1:
+            # collect q and k from all gpus
+            from ..distributed import allgather
-        temp_q = allgather(q, self.qkv_proj.mapping)
-        temp_k = allgather(k, self.qkv_proj.mapping)
-        temp_q = self.q_norm(temp_q)
-        temp_k = self.k_norm(temp_k)
-        # temp_q, temp_k = maybe_execute_in_parallel(
-        #     self.q_norm(temp_q),
-        #     self.k_norm(temp_k),
-        #     self.ln_events[0],
-        #     self.ln_events[1],
-        #     self.aux_stream,
-        # )
-        # split q and k to each gpus
-        # Fixme: tp_size may not be equal to the world size of current mapping
-        q = temp_q.reshape(-1, self.tp_size, self.q_size)[:, self.tp_rank, :].reshape(
-            -1, self.q_size
-        )
-        k = temp_k.reshape(-1, self.tp_size, self.kv_size)[:, self.tp_rank, :].reshape(
-            -1, self.kv_size
-        )
+            temp_q = allgather(q, self.qkv_proj.mapping)
+            temp_k = allgather(k, self.qkv_proj.mapping)
+            temp_q = self.q_norm(temp_q)
+            temp_k = self.k_norm(temp_k)
+            # temp_q, temp_k = maybe_execute_in_parallel(
+            #     self.q_norm(temp_q),
+            #     self.k_norm(temp_k),
+            #     self.ln_events[0],
+            #     self.ln_events[1],
+            # )
+            # split q and k to each gpus
+            # Fixme: tp_size may not be equal to the world size of current mapping
+            q = temp_q.reshape(-1, self.tp_size, self.q_size)[:, self.tp_rank, :].reshape(
+                -1, self.q_size
+            )
+            k = temp_k.reshape(-1, self.tp_size, self.kv_size)[:, self.tp_rank, :].reshape(
+                -1, self.kv_size
+            )
+        else:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
         return q, k
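
For context, below is a standalone single-process sketch of what the TP > 1 branch does; it is not the TensorRT-LLM implementation. The distributed allgather is mimicked with torch.cat, the rms_norm helper and the shard widths, token count, and tp_size are toy values introduced here, and it assumes the QK norm spans the full, unsharded width, which is why each rank gathers all shards, normalizes, and then slices its own shard back out with the same reshape indexing used in the diff. Presumably, when attention DP is enabled the attention mapping has tp_size == 1, so the new guard skips the gather entirely and the else branch just normalizes the local tensors.

import torch

def rms_norm(x, eps=1e-6):
    # toy RMSNorm over the last (full, unsharded) dimension
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

tp_size, num_tokens, q_size = 4, 3, 8  # assumed toy sizes; q_size is the per-rank shard width

# per-rank q shards, each of shape [num_tokens, q_size]
q_shards = [torch.randn(num_tokens, q_size) for _ in range(tp_size)]

# "allgather": every rank ends up with the full tensor [num_tokens, tp_size * q_size]
temp_q = rms_norm(torch.cat(q_shards, dim=-1))

# split back: each rank keeps only its own slice, using the same indexing as apply_qk_norm
for tp_rank in range(tp_size):
    q_local = temp_q.reshape(-1, tp_size, q_size)[:, tp_rank, :].reshape(-1, q_size)
    assert q_local.shape == (num_tokens, q_size)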


@@ -314,6 +314,6 @@ nvidia/Nemotron-3-Nano:
     kv_cache_quant_algo: FP8
     accuracy: 68.73
 MiniMaxAI/MiniMax-M2:
-  - accuracy: 20
+  - accuracy: 85
   - quant_algo: FP8_BLOCK_SCALES
-    accuracy: 20
+    accuracy: 85
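
For readability, the MiniMax-M2 entry in the accuracy reference YAML presumably reads as follows after this change, with the expected score raised from 20 to 85 for both the default and the FP8_BLOCK_SCALES configurations (indentation reconstructed from the hunk; only the lines visible above are shown):

MiniMaxAI/MiniMax-M2:
  - accuracy: 85
  - quant_algo: FP8_BLOCK_SCALES
    accuracy: 85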