Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
fix: alltoall padding for chunked MoE (#4157)
fix alltoall padding for chunked MoE

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
parent 9477661f4c
commit 7147efb2e8
@@ -457,14 +457,14 @@ class Deepseekv3MoE(nn.Module):
     def compute_routed_output(self, hidden_states, hidden_states_fp4,
                               all_rank_num_tokens, cutlass_min_latency_mode):
         # max-throughput
-        if self.use_dp and self.mapping.tp_size > 1 and not self.enable_alltoall:
+        if self.use_dp and self.mapping.tp_size > 1:
             max_num_token = max(all_rank_num_tokens)
             hidden_states = torch.nn.functional.pad(
                 hidden_states,
                 (0, 0, 0, max_num_token - hidden_states.shape[0]))
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
-            if disable_fp4_allgather():
+            if disable_fp4_allgather() and not self.enable_alltoall:
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           gather_dim=0)
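The first hunk makes the token-dimension padding unconditional under attention DP: previously the pad was skipped whenever alltoall was enabled, so ranks could reach the chunked MoE path with mismatched token counts; only the bf16 allgather fallback is now gated on `not self.enable_alltoall`. Below is a minimal sketch of the padding step with made-up shapes (the hidden size and per-rank token counts are illustrative; the pad spec matches the diff):

import torch

# Illustrative values; in the real code these come from attention DP state.
all_rank_num_tokens = [3, 7, 5]        # tokens held by each DP rank
hidden_states = torch.randn(3, 16)     # this rank's [num_tokens, hidden] activations

max_num_token = max(all_rank_num_tokens)
# F.pad reads the spec from the last dim backwards: (0, 0) leaves the
# hidden dim alone, (0, max_num_token - num_tokens) appends zero rows
# on the token dim.
hidden_states = torch.nn.functional.pad(
    hidden_states, (0, 0, 0, max_num_token - hidden_states.shape[0]))

assert hidden_states.shape == (7, 16)  # every rank now sees the same shape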
@@ -903,14 +903,13 @@ class FusedMoE(nn.Module):
         if self.use_dp and self.enable_alltoall:
             all_rank_chunk_size_list = []
             for single_rank_num_tokens in all_rank_num_tokens:
-                single_rank_num_chunks = (single_rank_num_tokens +
-                                          max_chunk_size -
-                                          1) // max_chunk_size
-                assert single_rank_num_chunks == num_chunks,\
-                    "num_chunks should be the same for attention dp and ep"
-                all_rank_chunk_size_list.append(
-                    split_chunk(single_rank_num_tokens,
-                                single_rank_num_chunks))
+                single_rank_num_chunks = num_chunks
+                single_rank_chunk_size_list = split_chunk(
+                    single_rank_num_tokens, single_rank_num_chunks)
+                single_rank_chunk_size_list = [
+                    1 if x == 0 else x for x in single_rank_chunk_size_list
+                ]
+                all_rank_chunk_size_list.append(single_rank_chunk_size_list)

         for chunk_id in range(num_chunks):
             chunk_all_rank_num_tokens = [
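The second hunk removes the per-rank ceil-division, which asserted that every rank independently arrived at the same chunk count; instead each rank now splits its tokens into the shared `num_chunks`, and any empty chunk is bumped to one token so every alltoall round gets a non-empty slice per rank. A hedged sketch of that logic follows; this `split_chunk` is a stand-in re-implementation, assumed to do an even split with the remainder on the leading chunks, not the repo's actual helper:

# Stand-in for the repo's split_chunk helper (an assumption about its behavior).
def split_chunk(num_tokens: int, num_chunks: int) -> list[int]:
    base, rem = divmod(num_tokens, num_chunks)
    return [base + 1 if i < rem else base for i in range(num_chunks)]

num_chunks = 3
all_rank_num_tokens = [8, 0, 5]        # rank 1 holds no tokens this step

all_rank_chunk_size_list = []
for single_rank_num_tokens in all_rank_num_tokens:
    sizes = split_chunk(single_rank_num_tokens, num_chunks)
    # Bump empty chunks to one token so each rank contributes a non-empty
    # slice to every alltoall round (the extra token is presumably padding
    # that gets dropped downstream).
    sizes = [1 if x == 0 else x for x in sizes]
    all_rank_chunk_size_list.append(sizes)

print(all_rank_chunk_size_list)        # [[3, 3, 2], [1, 1, 1], [2, 2, 1]]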