[None] [test] Add MNNVL AlltoAll tests to pre-merge (#8601)
parent 0019d99e6d
commit c9b08790c2
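
This change enables the MNNVL AlltoAll variants of the fused-MoE unit tests in the pre-merge test lists for DGX B200 and DGX H100. To support them, it threads a new alltoall_result_do_sum flag through WideEPMoE: the flag defaults to True (the MNNVL combine reduces its result), Deepseekv3MoE opts out by passing False when its experts run on WideEPMoE, and every non-MNNVL path forces the flag back to True.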
@@ -57,6 +57,7 @@ from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod,
                                  MoEWeightLoadingMode, create_moe)
+from ..modules.fused_moe.fused_moe_wide_ep import WideEPMoE
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -864,6 +865,9 @@ class Deepseekv3MoE(nn.Module):
             output_dtype=hidden_states.dtype,
             all_rank_num_tokens=all_rank_num_tokens,
             use_dp_padding=use_dp_padding,
+            **({
+                "alltoall_result_do_sum": False
+            } if isinstance(self.experts, WideEPMoE) else {}),
         )

         return routed_output
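The **({...} if ... else {}) construction forwards the new keyword only when the experts module is a WideEPMoE, so other MoE backends never see an argument they do not accept. A minimal sketch of the idiom (callee names here are illustrative, not the TensorRT-LLM classes):

def narrow_api(x, *, scale=1.0):
    # Stand-in for a MoE backend whose forward() lacks the new flag.
    return x * scale

def wide_api(x, *, scale=1.0, do_sum=True):
    # Stand-in for the backend that understands the extra flag.
    return x * scale if do_sum else -(x * scale)

def call(fn, x, supports_flag):
    # Forward the kwarg only when the callee accepts it, mirroring
    # isinstance(self.experts, WideEPMoE) in the hunk above.
    return fn(x, scale=2.0, **({"do_sum": False} if supports_flag else {}))

print(call(narrow_api, 3, False))  # 6.0 -> no extra kwarg passed
print(call(wide_api, 3, True))     # -6.0 -> do_sum=False reached the callee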
@@ -454,14 +454,15 @@ class WideEPMoE(MoE):
         return False

     def forward_chunk(
             self,
             x: Union[torch.Tensor, Fp4QuantizedTensor],
             router_logits: torch.Tensor,
             use_all_to_all: bool,
             output_dtype: Optional[torch.dtype] = None,
             all_rank_num_tokens: Optional[List[int]] = None,
             use_dp_padding: Optional[bool] = None,
             repeating_info: Tuple = (True, True),
+            alltoall_result_do_sum: bool = True,
     ) -> torch.Tensor:
         all_rank_max_num_tokens = max(all_rank_num_tokens)
         if isinstance(x, Fp4QuantizedTensor):
@@ -476,7 +477,7 @@ class WideEPMoE(MoE):
             self.layer_load_balancer.start_wait_gpu_stage()

         if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
-            pass
+            alltoall_result_do_sum = True

         weight_dtype = self.w3_w1_weight.dtype
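On any path other than MNNVL AlltoAll, the flag is forced back to True: only the MNNVL alltoall_combine consumes it (as do_reduce, further down), so callers cannot accidentally request a skipped reduction on backends that never defer it.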
@@ -743,7 +744,8 @@ class WideEPMoE(MoE):
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             final_hidden_states = self.alltoall_combine(
-                final_hidden_states, alltoall_info, token_count)
+                final_hidden_states, alltoall_info, token_count,
+                alltoall_result_do_sum)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
             final_hidden_states = self.unpad_tensors(
                 padded, final_hidden_states)
@@ -788,6 +790,7 @@ class WideEPMoE(MoE):
             output_dtype: Optional[torch.dtype] = None,
             all_rank_num_tokens: Optional[List[int]] = None,
             use_dp_padding: Optional[bool] = None,
+            alltoall_result_do_sum: bool = True,
             **kwargs,
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
@@ -815,7 +818,8 @@ class WideEPMoE(MoE):
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding,
-                repeating_info=(is_first_call, is_last_call))
+                repeating_info=(is_first_call, is_last_call),
+                alltoall_result_do_sum=alltoall_result_do_sum)
             outputs = self.reducescatter_or_allreduce(
                 outputs,
                 use_all_to_all,
@@ -873,7 +877,8 @@ class WideEPMoE(MoE):
                         all_rank_num_tokens=all_rank_num_tokens_list[
                             idx_chunk],
                         use_dp_padding=use_dp_padding,
-                        repeating_info=(is_first_call, is_last_call))
+                        repeating_info=(is_first_call, is_last_call),
+                        alltoall_result_do_sum=alltoall_result_do_sum)
                     if idx_chunk > 0:
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
@@ -889,7 +894,8 @@ class WideEPMoE(MoE):
                         all_rank_num_tokens=all_rank_num_tokens_list[
                             idx_chunk],
                         use_dp_padding=use_dp_padding,
-                        repeating_info=(is_first_call, is_last_call))
+                        repeating_info=(is_first_call, is_last_call),
+                        alltoall_result_do_sum=alltoall_result_do_sum)
                     with torch.cuda.stream(self.aux_stream):
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
@@ -903,7 +909,8 @@ class WideEPMoE(MoE):
                     router_logits,
                     use_all_to_all,
                     all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
-                    repeating_info=(is_first_call, is_last_call))
+                    repeating_info=(is_first_call, is_last_call),
+                    alltoall_result_do_sum=alltoall_result_do_sum)

                 outputs_list.append(outputs)
             if not use_all_to_all:
@@ -959,7 +966,8 @@ class WideEPMoE(MoE):
         return x, x_sf, token_selected_slots, token_final_scales

     def alltoall_combine(self, final_hidden_states: torch.Tensor,
-                         alltoall_info: MoEAlltoallInfo, token_count: int):
+                         alltoall_info: MoEAlltoallInfo, token_count: int,
+                         alltoall_result_do_sum: bool):
         top_k = self.routing_method.experts_per_token
         if isinstance(final_hidden_states, list):
             final_hidden_states = final_hidden_states[0]
@@ -972,7 +980,7 @@ class WideEPMoE(MoE):
             top_k=top_k,
             token_count=token_count,
             use_low_precision_combine=self.use_low_precision_combine,
-            do_reduce=False)
+            do_reduce=alltoall_result_do_sum)

         return final_hidden_states
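Before this change the MNNVL combine was always called with do_reduce=False; it now reduces by default and skips the reduction only when a caller such as Deepseekv3MoE passes alltoall_result_do_sum=False. Conceptually, that reduction is the per-token sum over the top-k expert partials; a plain-PyTorch sketch of the step (the real kernel fuses it with the communication, and these names are illustrative):

import torch

def combine_reference(expert_out: torch.Tensor, do_reduce: bool) -> torch.Tensor:
    # expert_out: [token_count, top_k, hidden] -- one partial result per
    # (token, routed expert) pair after the all-to-all brings them home.
    if do_reduce:
        return expert_out.sum(dim=1)  # [token_count, hidden]
    # do_reduce=False: hand back the unsummed partials so a later stage
    # can fold the reduction into another op.
    return expert_out

x = torch.randn(8, 2, 16)
assert combine_reference(x, True).shape == (8, 16)
assert combine_reference(x, False).shape == (8, 2, 16)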
@@ -242,6 +242,7 @@ class MoE(nn.Module):
             output_dtype: Optional[torch.dtype] = None,
             all_rank_num_tokens: Optional[List[int]] = None,
             use_dp_padding: Optional[bool] = None,
+            **kwargs,
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
         if self.register_to_config and is_torch_compiling():
             hidden_states = x.fp4_tensor if isinstance(
@@ -274,6 +275,7 @@ class MoE(nn.Module):
             output_dtype=output_dtype,
             all_rank_num_tokens=all_rank_num_tokens,
             use_dp_padding=use_dp_padding,
+            **kwargs,
         )

     @property
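Accepting and re-forwarding **kwargs in the base MoE.forward lets individual backends grow extra keyword arguments (like alltoall_result_do_sum) without widening the shared interface each time. A minimal sketch of the pattern, with illustrative names rather than the exact TensorRT-LLM internals:

from typing import Optional

class MoEBase:
    def forward(self, x, output_dtype: Optional[str] = None, **kwargs):
        # Flags the base class does not know about flow through untouched.
        return self.forward_impl(x, output_dtype=output_dtype, **kwargs)

class WideEPLike(MoEBase):
    def forward_impl(self, x, output_dtype=None, alltoall_result_do_sum=True):
        return ("summed" if alltoall_result_do_sum else "partial", x)

print(WideEPLike().forward(1, alltoall_result_do_sum=False))  # ('partial', 1)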
@@ -17,6 +17,7 @@ l0_dgx_b200:
   tests:
   - unittest/_torch/multi_gpu_modeling -k "deepseek"
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
+  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
@@ -116,6 +116,7 @@ l0_dgx_h100:
   - unittest/_torch/multi_gpu_modeling/test_deepseek.py::test_deepseek_streaming[tp4-bf16-trtllm-deepseekv3_lite]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency]
+  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[MNNVL]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype0]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype1]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype0]
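With these list entries in place, the MNNVL variants run in the pre-merge pipelines for both platforms. For a local repro, the pytest node IDs above can be invoked directly; a hedged example (the working directory must match how the test lists resolve paths, which may differ in your checkout):

import pytest

# Run only the newly enabled MNNVL cases; -q keeps the output short.
pytest.main([
    "unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[MNNVL]",
    "unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL]",
    "-q",
])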
@@ -214,11 +214,14 @@ def test_fused_moe_alltoall(alltoall_method_type):
     weights = {}
     for expert_id in range(NUM_EXPERTS):
         w1_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
-                                dtype=dtype)
+                                dtype=dtype,
+                                device="cuda")
         w2_weight = torch.empty((HIDDEN_SIZE, INTERMEDIATE_SIZE),
-                                dtype=dtype)
+                                dtype=dtype,
+                                device="cuda")
         w3_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
-                                dtype=dtype)
+                                dtype=dtype,
+                                device="cuda")
         torch.nn.init.xavier_uniform_(w1_weight)
         torch.nn.init.xavier_uniform_(w2_weight)
         torch.nn.init.xavier_uniform_(w3_weight)
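Allocating the test weights directly on the GPU means xavier_uniform_ also runs there, and it avoids a host allocation plus a host-to-device copy per tensor. A small illustration of the difference (requires a CUDA device; shapes are arbitrary):

import torch

# Two-step: host allocation, host-side init, then an H2D copy.
w_host = torch.empty((1024, 512), dtype=torch.bfloat16)
torch.nn.init.xavier_uniform_(w_host)
w_on_gpu = w_host.cuda()

# One-step, as the test now does: allocate and initialize on the device.
w_direct = torch.empty((1024, 512), dtype=torch.bfloat16, device="cuda")
torch.nn.init.xavier_uniform_(w_direct)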
@@ -294,7 +297,6 @@ def test_fused_moe_alltoall(alltoall_method_type):
         assert r is None


-@pytest.mark.skip(reason="https://nvbugs/5467531")
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="needs 4 GPUs to run this test")
 @pytest.mark.parametrize("alltoall_method_type", [
@@ -304,6 +306,9 @@ def test_fused_moe_alltoall(alltoall_method_type):
                          ids=lambda s: s.name)
 def test_fused_moe_alltoall_fp4(alltoall_method_type):
+
+    if alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+        pytest.skip("Skipped due to https://nvbugs/5467531")

     world_size = 4
     dtype = torch.bfloat16
     HIDDEN_SIZE = 2560
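Replacing the blanket @pytest.mark.skip with an in-body skip is what actually unblocks the new list entries: a decorator skip disables every parametrization of test_fused_moe_alltoall_fp4, while the in-body check skips only the DeepEPLowLatency case still tracked by the linked bug. The pattern in isolation:

import pytest

@pytest.mark.parametrize("method", ["MNNVL", "DeepEP", "DeepEPLowLatency"])
def test_alltoall_variant(method):
    if method == "DeepEPLowLatency":
        # Skips just this parametrization; the others keep running.
        pytest.skip("tracked by https://nvbugs/5467531")
    assert method in ("MNNVL", "DeepEP")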