[None] [test] Add MNNVL AlltoAll tests to pre-merge (#8601)

Kaiyu Xie 2025-10-27 21:39:44 +08:00 committed by GitHub
parent 0019d99e6d
commit c9b08790c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 41 additions and 20 deletions

View File

@@ -57,6 +57,7 @@ from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod,
MoEWeightLoadingMode, create_moe)
+from ..modules.fused_moe.fused_moe_wide_ep import WideEPMoE
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -864,6 +865,9 @@ class Deepseekv3MoE(nn.Module):
output_dtype=hidden_states.dtype,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
+**({
+"alltoall_result_do_sum": False
+} if isinstance(self.experts, WideEPMoE) else {}),
)
return routed_output
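The `**({...} if isinstance(...) else {})` splat above forwards the new keyword only when the experts module is a `WideEPMoE`, so other MoE backends keep their existing signature. A minimal standalone sketch of the same idea, using hypothetical stand-in classes rather than the real TensorRT-LLM modules:

```python
# Hypothetical stand-ins for MoE backends; only the wide-EP variant understands
# the extra flag, so the caller splats it in conditionally.
class BasicMoE:

    def forward(self, x, **kwargs):
        return x


class WideEPLikeMoE(BasicMoE):

    def forward(self, x, alltoall_result_do_sum=True, **kwargs):
        # A wide-EP backend may defer the top-k reduction to a later step.
        return (x, alltoall_result_do_sum)


experts = WideEPLikeMoE()
out = experts.forward(
    "tokens",
    **({
        "alltoall_result_do_sum": False
    } if isinstance(experts, WideEPLikeMoE) else {}),
)
print(out)  # ('tokens', False)
```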

View File

@@ -454,14 +454,15 @@ class WideEPMoE(MoE):
return False
def forward_chunk(
-self,
-x: Union[torch.Tensor, Fp4QuantizedTensor],
-router_logits: torch.Tensor,
-use_all_to_all: bool,
-output_dtype: Optional[torch.dtype] = None,
-all_rank_num_tokens: Optional[List[int]] = None,
-use_dp_padding: Optional[bool] = None,
-repeating_info: Tuple = (True, True),
+self,
+x: Union[torch.Tensor, Fp4QuantizedTensor],
+router_logits: torch.Tensor,
+use_all_to_all: bool,
+output_dtype: Optional[torch.dtype] = None,
+all_rank_num_tokens: Optional[List[int]] = None,
+use_dp_padding: Optional[bool] = None,
+repeating_info: Tuple = (True, True),
+alltoall_result_do_sum: bool = True,
) -> torch.Tensor:
all_rank_max_num_tokens = max(all_rank_num_tokens)
if isinstance(x, Fp4QuantizedTensor):
@@ -476,7 +477,7 @@ class WideEPMoE(MoE):
self.layer_load_balancer.start_wait_gpu_stage()
if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
-pass
+alltoall_result_do_sum = True
weight_dtype = self.w3_w1_weight.dtype
@@ -743,7 +744,8 @@ class WideEPMoE(MoE):
if self.enable_dummy_allreduce:
self.dummy_allreduce()
final_hidden_states = self.alltoall_combine(
-final_hidden_states, alltoall_info, token_count)
+final_hidden_states, alltoall_info, token_count,
+alltoall_result_do_sum)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
final_hidden_states = self.unpad_tensors(
padded, final_hidden_states)
@@ -788,6 +790,7 @@ class WideEPMoE(MoE):
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
+alltoall_result_do_sum: bool = True,
**kwargs,
) -> torch.Tensor:
assert all_rank_num_tokens is not None
@@ -815,7 +818,8 @@ class WideEPMoE(MoE):
output_dtype,
all_rank_num_tokens=all_rank_num_tokens_padded,
use_dp_padding=use_dp_padding,
-repeating_info=(is_first_call, is_last_call))
+repeating_info=(is_first_call, is_last_call),
+alltoall_result_do_sum=alltoall_result_do_sum)
outputs = self.reducescatter_or_allreduce(
outputs,
use_all_to_all,
@@ -873,7 +877,8 @@ class WideEPMoE(MoE):
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk],
use_dp_padding=use_dp_padding,
-repeating_info=(is_first_call, is_last_call))
+repeating_info=(is_first_call, is_last_call),
+alltoall_result_do_sum=alltoall_result_do_sum)
if idx_chunk > 0:
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1],
@@ -889,7 +894,8 @@ class WideEPMoE(MoE):
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk],
use_dp_padding=use_dp_padding,
-repeating_info=(is_first_call, is_last_call))
+repeating_info=(is_first_call, is_last_call),
+alltoall_result_do_sum=alltoall_result_do_sum)
with torch.cuda.stream(self.aux_stream):
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1],
@@ -903,7 +909,8 @@ class WideEPMoE(MoE):
router_logits,
use_all_to_all,
all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
-repeating_info=(is_first_call, is_last_call))
+repeating_info=(is_first_call, is_last_call),
+alltoall_result_do_sum=alltoall_result_do_sum)
outputs_list.append(outputs)
if not use_all_to_all:
@@ -959,7 +966,8 @@ class WideEPMoE(MoE):
return x, x_sf, token_selected_slots, token_final_scales
def alltoall_combine(self, final_hidden_states: torch.Tensor,
-alltoall_info: MoEAlltoallInfo, token_count: int):
+alltoall_info: MoEAlltoallInfo, token_count: int,
+alltoall_result_do_sum: bool):
top_k = self.routing_method.experts_per_token
if isinstance(final_hidden_states, list):
final_hidden_states = final_hidden_states[0]
@@ -972,7 +980,7 @@ class WideEPMoE(MoE):
top_k=top_k,
token_count=token_count,
use_low_precision_combine=self.use_low_precision_combine,
-do_reduce=False)
+do_reduce=alltoall_result_do_sum)
return final_hidden_states
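Taken together, the WideEPMoE changes thread one flag from the public entry point down to the combine step: `forward_impl` passes `alltoall_result_do_sum` into every `forward_chunk` call, `forward_chunk` forces it back to `True` whenever the MNNVL alltoall path is not in use, and `alltoall_combine` finally maps it onto `do_reduce`. A simplified sketch of that call chain, with plain functions standing in for the real methods:

```python
# Simplified stand-ins for the WideEPMoE call chain; the real code dispatches
# to the MNNVL combine kernel instead of summing Python lists.
def alltoall_combine(chunks, do_reduce):
    # do_reduce plays the role of the combine kernel's reduction over top-k copies.
    return sum(chunks) if do_reduce else chunks


def forward_chunk(chunks, use_all_to_all, alltoall_result_do_sum=True):
    if not use_all_to_all:
        # Non-MNNVL paths always reduce inside the combine, as before.
        alltoall_result_do_sum = True
    return alltoall_combine(chunks, do_reduce=alltoall_result_do_sum)


def forward_impl(chunks, use_all_to_all, alltoall_result_do_sum=True, **kwargs):
    return forward_chunk(chunks, use_all_to_all,
                         alltoall_result_do_sum=alltoall_result_do_sum)


print(forward_impl([1, 2, 3], use_all_to_all=True, alltoall_result_do_sum=False))
# [1, 2, 3]; the caller can perform the sum later itself
print(forward_impl([1, 2, 3], use_all_to_all=False, alltoall_result_do_sum=False))
# 6; the flag is forced back to True when MNNVL alltoall is not used
```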

View File

@@ -242,6 +242,7 @@ class MoE(nn.Module):
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
+**kwargs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
if self.register_to_config and is_torch_compiling():
hidden_states = x.fp4_tensor if isinstance(
@@ -274,6 +275,7 @@ class MoE(nn.Module):
output_dtype=output_dtype,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
+**kwargs,
)
@property
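The base `MoE.forward` now accepts `**kwargs` and forwards them to `forward_impl`, which is what lets `alltoall_result_do_sum` reach `WideEPMoE` without widening the shared interface. A tiny illustrative sketch with hypothetical class names:

```python
# Hypothetical sketch: subclass-specific keywords flow through the shared
# forward entry point untouched.
class BaseMoE:

    def forward(self, x, **kwargs):
        return self.forward_impl(x, **kwargs)

    def forward_impl(self, x, **kwargs):
        return x


class WideMoE(BaseMoE):

    def forward_impl(self, x, alltoall_result_do_sum=True, **kwargs):
        return (x, alltoall_result_do_sum)


print(WideMoE().forward([1.0], alltoall_result_do_sum=False))  # ([1.0], False)
```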

View File

@@ -17,6 +17,7 @@ l0_dgx_b200:
tests:
- unittest/_torch/multi_gpu_modeling -k "deepseek"
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
+- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
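To reproduce the new pre-merge entry locally, something like the following should select just the added case. This assumes a node with at least 4 GPUs and an MNNVL-capable interconnect, and that the working directory matches the relative paths used in the L0 test lists:

```python
# Hedged example: invoke pytest programmatically on the newly listed test id.
import pytest

pytest.main([
    "unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL]",
    "-v",
])
```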

View File

@@ -116,6 +116,7 @@ l0_dgx_h100:
- unittest/_torch/multi_gpu_modeling/test_deepseek.py::test_deepseek_streaming[tp4-bf16-trtllm-deepseekv3_lite]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency]
+- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[MNNVL]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype0]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype1]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype0]

View File

@@ -214,11 +214,14 @@ def test_fused_moe_alltoall(alltoall_method_type):
weights = {}
for expert_id in range(NUM_EXPERTS):
w1_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
-dtype=dtype)
+dtype=dtype,
+device="cuda")
w2_weight = torch.empty((HIDDEN_SIZE, INTERMEDIATE_SIZE),
-dtype=dtype)
+dtype=dtype,
+device="cuda")
w3_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
-dtype=dtype)
+dtype=dtype,
+device="cuda")
torch.nn.init.xavier_uniform_(w1_weight)
torch.nn.init.xavier_uniform_(w2_weight)
torch.nn.init.xavier_uniform_(w3_weight)
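The test-side change above only moves where the expert weights are materialized: they are now allocated directly on the GPU and initialized in place rather than being created on the host first. A small sketch of that pattern with illustrative sizes (assumes a CUDA device is available):

```python
# Allocate uninitialized expert weights on the GPU and initialize them there,
# avoiding a host allocation followed by a device copy. Sizes are illustrative.
import torch

HIDDEN_SIZE, INTERMEDIATE_SIZE = 256, 512
dtype = torch.bfloat16

w1_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype, device="cuda")
w2_weight = torch.empty((HIDDEN_SIZE, INTERMEDIATE_SIZE), dtype=dtype, device="cuda")
torch.nn.init.xavier_uniform_(w1_weight)
torch.nn.init.xavier_uniform_(w2_weight)
```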
@@ -294,7 +297,6 @@ def test_fused_moe_alltoall(alltoall_method_type):
assert r is None
-@pytest.mark.skip(reason="https://nvbugs/5467531")
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="needs 4 GPUs to run this test")
@pytest.mark.parametrize("alltoall_method_type", [
@@ -304,6 +306,9 @@ def test_fused_moe_alltoall(alltoall_method_type):
ids=lambda s: s.name)
def test_fused_moe_alltoall_fp4(alltoall_method_type):
+if alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+pytest.skip("Skipped due to https://nvbugs/5467531")
world_size = 4
dtype = torch.bfloat16
HIDDEN_SIZE = 2560
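The blanket `@pytest.mark.skip` on `test_fused_moe_alltoall_fp4` is replaced by a conditional skip inside the test body, so only the `DeepEPLowLatency` variant stays disabled while the new MNNVL parametrization runs. A minimal, self-contained sketch of that pattern (the enum here is a stand-in for the real `AlltoallMethodType`):

```python
# Stand-in enum and test showing a per-parameter skip instead of a
# decorator-level skip that would disable every parametrization.
import enum

import pytest


class AlltoallMethodType(enum.Enum):
    MNNVL = enum.auto()
    DeepEP = enum.auto()
    DeepEPLowLatency = enum.auto()


@pytest.mark.parametrize("alltoall_method_type",
                         list(AlltoallMethodType),
                         ids=lambda s: s.name)
def test_fused_moe_alltoall_fp4_sketch(alltoall_method_type):
    if alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
        pytest.skip("Skipped due to https://nvbugs/5467531")
    assert isinstance(alltoall_method_type, AlltoallMethodType)
```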