mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
fix: fix bugs when enabling attention DP
Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com>
This commit is contained in:
parent
6f946b8e81
commit
49cddcb27a
@ -169,28 +169,31 @@ class MiniMaxM2Attention(Attention):
|
||||
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
|
||||
|
||||
def apply_qk_norm(self, q, k):
    """Apply q/k normalization, accounting for tensor parallelism.

    When ``tp_size > 1`` the per-rank q/k shards are first gathered across
    all TP ranks so the norm layers see the full tensors, and this rank's
    shard is then sliced back out.  Without TP the norms are applied
    directly to the local tensors.

    Args:
        q: query projection shard; last dimension is ``self.q_size`` per
            rank (inferred from the reshape below — confirm against caller).
        k: key projection shard; last dimension is ``self.kv_size`` per rank.

    Returns:
        Tuple ``(q, k)`` of normalized shards with the input shapes.
    """
    if self.qkv_proj.mapping.tp_size > 1:
        # Collect q and k from all GPUs so normalization sees full tensors.
        from ..distributed import allgather

        temp_q = allgather(q, self.qkv_proj.mapping)
        temp_k = allgather(k, self.qkv_proj.mapping)
        temp_q = self.q_norm(temp_q)
        temp_k = self.k_norm(temp_k)
        # Split q and k back out to this GPU's shard.
        # FIXME: tp_size may not be equal to the world size of current mapping.
        q = temp_q.reshape(-1, self.tp_size, self.q_size)[:, self.tp_rank, :].reshape(
            -1, self.q_size
        )
        k = temp_k.reshape(-1, self.tp_size, self.kv_size)[:, self.tp_rank, :].reshape(
            -1, self.kv_size
        )
    else:
        # Single-rank path: no gather/scatter round-trip needed.
        q = self.q_norm(q)
        k = self.k_norm(k)

    return q, k
|
||||
|
||||
|
||||
@ -314,6 +314,6 @@ nvidia/Nemotron-3-Nano:
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 68.73
|
||||
MiniMaxAI/MiniMax-M2:
|
||||
- accuracy: 20
|
||||
- accuracy: 85
|
||||
- quant_algo: FP8_BLOCK_SCALES
|
||||
accuracy: 20
|
||||
accuracy: 85
|
||||
|
||||
Loading…
Reference in New Issue
Block a user