From cc4511997a21c13c4cfd6c55f842c3e914a79ad6 Mon Sep 17 00:00:00 2001 From: Yanchao Lu Date: Mon, 16 Feb 2026 21:23:12 +0800 Subject: [PATCH] [None][revert] - Revert "[TRTLLM-9108][feat] refactor MoE unit tests: add unified ConfigurableMoE test framework" (#11532) --- .../communication/communication_factory.py | 2 +- .../fused_moe/communication/deep_ep.py | 17 +- .../communication/deep_ep_low_latency.py | 40 +- .../modules/fused_moe/configurable_moe.py | 10 +- .../modules/fused_moe/fused_moe_cute_dsl.py | 14 +- .../modules/fused_moe/fused_moe_cutlass.py | 14 +- .../modules/fused_moe/fused_moe_deepgemm.py | 14 +- .../modules/fused_moe/fused_moe_triton.py | 14 +- .../modules/fused_moe/fused_moe_trtllm_gen.py | 12 +- .../_torch/modules/fused_moe/interface.py | 19 +- .../_torch/modules/fused_moe/quantization.py | 56 - .../_torch/modules/moe/moe_test_utils.py | 727 --------- .../_torch/modules/moe/quantize_utils.py | 148 +- .../_torch/modules/moe/test_moe_backend.py | 538 ++++++- .../_torch/modules/moe/test_moe_module.py | 1308 +++-------------- 15 files changed, 759 insertions(+), 2174 deletions(-) delete mode 100644 tests/unittest/_torch/modules/moe/moe_test_utils.py diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py b/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py index 651306f765..c7f5b22a4a 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py @@ -154,7 +154,7 @@ class CommunicationFactory: logger.debug(f"NVLinkTwoSided not available: {e}") # Try DeepEP (if enabled and weight dtype is bfloat16) - if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "1") == "1" and act_dtype == torch.bfloat16: + if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1" and act_dtype == torch.bfloat16: try: strategy = DeepEP( mapping, diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py b/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py index 4f50b9d12d..a8d71a1d6b 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py @@ -81,6 +81,8 @@ class DeepEP(Communication): """ Check if DeepEP is supported on the current platform """ + if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") != "1": + return False return deep_ep_installed @staticmethod @@ -143,13 +145,6 @@ class DeepEP(Communication): """ all_rank_max_num_tokens = max(all_rank_num_tokens) - # DeepEP C++ kernel requires topk_weights (token_final_scales) to be float32, - # but downstream backends (e.g. TRTLLM) may require the original dtype. - # Convert to float32 for dispatch, then restore afterward. 
- original_scales_dtype = token_final_scales.dtype if token_final_scales is not None else None - if token_final_scales is not None and token_final_scales.dtype != torch.float32: - token_final_scales = token_final_scales.to(torch.float32) - if not self.supports_post_quant_dispatch(): # Pre-quant dispatch (unquantized data) ( @@ -220,14 +215,6 @@ class DeepEP(Communication): "padded": padded, } - # Restore token_final_scales to original dtype for downstream consumers - if ( - token_final_scales is not None - and original_scales_dtype is not None - and token_final_scales.dtype != original_scales_dtype - ): - token_final_scales = token_final_scales.to(original_scales_dtype) - return hidden_states, hidden_states_sf, token_selected_slots, token_final_scales def combine( diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py b/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py index 8fdcf7c330..656f4957fe 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py @@ -37,15 +37,6 @@ class DeepEPLowLatency(Communication): DeepEP Low Latency strategy supporting both pre-quant and post-quant """ - SUPPORTED_HIDDEN_SIZES = {2048, 2560, 3584, 4096, 5120, 6144, 7168} - """set[int]: Hidden sizes supported by the low-latency DeepEP kernel (SWITCH_HIDDEN in launch.cuh).""" - - SUPPORTED_HIDDEN_SIZES_EXTENSION = {4096, 6144, 7168} - """set[int]: Hidden sizes supported by extension kernels (nvfp4 post-quant/low-precision combine). - - Sourced from SWITCH_HIDDEN_FOR_EXTENSION_KERNELS in extension_kernels.cu. - """ - def __init__( self, mapping: Mapping, @@ -60,13 +51,6 @@ class DeepEPLowLatency(Communication): ): super().__init__(mapping) - # Validate hidden_size against kernel constraints - if hidden_size not in self.SUPPORTED_HIDDEN_SIZES: - raise RuntimeError( - f"DeepEPLowLatency does not support hidden_size={hidden_size}. " - f"Supported hidden sizes: {sorted(self.SUPPORTED_HIDDEN_SIZES)}" - ) - # Store needed parameters self.num_slots = num_slots self.hidden_size = hidden_size @@ -102,6 +86,8 @@ class DeepEPLowLatency(Communication): """ Check if DeepEP Low Latency is supported on the current platform """ + if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") != "1": + return False if not deep_ep_installed: return False return True @@ -109,35 +95,15 @@ class DeepEPLowLatency(Communication): def supports_post_quant_dispatch(self) -> bool: """ DeepEP Low Latency supports post-quant for: fp8_qdq, nvfp4, w4afp8 - - Note: nvfp4 post-quant dispatch uses extension kernels which require - hidden_size in SUPPORTED_HIDDEN_SIZES_EXTENSION. - - Note: fp8_qdq and w4afp8 post-quant dispatch views fp8 (1 byte) as - bf16 (2 bytes) via .view(torch.bfloat16), halving the hidden dimension. - The halved dimension must be in SUPPORTED_HIDDEN_SIZES for the dispatch - kernel (SWITCH_HIDDEN in internode_ll.cu) to work. """ if not self.enable_postquant_alltoall: return False - if self._has_nvfp4(): - # nvfp4 dispatch uses extension kernels with stricter hidden_size requirement - return self.hidden_size in self.SUPPORTED_HIDDEN_SIZES_EXTENSION - if self._has_fp8_qdq() or self._has_w4afp8(): - # fp8/w4afp8 post-quant dispatch views fp8 (1 byte) as bf16 (2 bytes), - # halving the hidden dimension. The kernel must support the halved size. 
- return (self.hidden_size // 2) in self.SUPPORTED_HIDDEN_SIZES - return False + return self._has_nvfp4() or self._has_fp8_qdq() or self._has_w4afp8() def supports_low_precision_combine(self) -> bool: """ DeepEP Low Latency supports low-precision combine for: fp8_qdq, nvfp4, w4afp8 - - Note: low-precision combine uses extension kernels which require - hidden_size in SUPPORTED_HIDDEN_SIZES_EXTENSION. """ - if self.hidden_size not in self.SUPPORTED_HIDDEN_SIZES_EXTENSION: - return False return self._has_nvfp4() or self._has_fp8_qdq() or self._has_w4afp8() def is_workload_feasible(self, all_rank_num_tokens: List[int], num_chunks: int) -> bool: diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py index d1b7024c88..b99de2086d 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py @@ -100,25 +100,25 @@ class ConfigurableMoE(MoE): cls, quant_algo, dtype_activation: torch.dtype = torch.bfloat16, - swiglu_gptoss_style: bool = False, + gptoss_style: bool = False, ): """ ConfigurableMoE is a wrapper class that delegates to specific backends. To check capability, query the specific backend class directly: - - CutlassFusedMoE.can_implement(quant_algo, dtype_activation, swiglu_gptoss_style) - - TRTLLMGenFusedMoE.can_implement(quant_algo, dtype_activation, swiglu_gptoss_style) + - CutlassFusedMoE.can_implement(quant_algo, dtype_activation, gptoss_style) + - TRTLLMGenFusedMoE.can_implement(quant_algo, dtype_activation, gptoss_style) - etc. Args: quant_algo: The quantization algorithm to check (None for unquantized) dtype_activation: The activation data type - swiglu_gptoss_style: Whether swiglu_gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled + gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled Returns: Tuple[bool, Optional[str]]: Always returns (False, reason) """ - del quant_algo, dtype_activation, swiglu_gptoss_style # Unused - wrapper class + del quant_algo, dtype_activation, gptoss_style # Unused - wrapper class return False, ( "ConfigurableMoE is a wrapper class. " "Query the specific backend (CutlassFusedMoE, TRTLLMGenFusedMoE, etc.) directly." diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py index b15abc4ccc..ca3e6c1a20 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py @@ -318,7 +318,7 @@ class CuteDslFusedMoE(CutlassFusedMoE): cls, quant_algo: Optional[QuantAlgo], dtype_activation: torch.dtype = torch.bfloat16, - swiglu_gptoss_style: bool = False, + gptoss_style: bool = False, ) -> Tuple[bool, Optional[str]]: """ Check if CuteDslFusedMoE can implement the given quantization algorithm. @@ -327,14 +327,14 @@ class CuteDslFusedMoE(CutlassFusedMoE): - NVFP4: SM in {100, 103} Does NOT support unquantized mode. Output dtype is hardcoded to bfloat16. - Does NOT support swiglu_gptoss_style (bias/swiglu with custom alpha/beta/limit). + Does NOT support gptoss_style (bias/swiglu with custom alpha/beta/limit). Args: quant_algo: The quantization algorithm to check (None for unquantized) dtype_activation: The activation input data type. Only bfloat16 is supported because output dtype is hardcoded to bfloat16 (input/output dtype must match). 
- swiglu_gptoss_style: Whether swiglu_gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled. - CuteDslFusedMoE does NOT support swiglu_gptoss_style. + gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled. + CuteDslFusedMoE does NOT support gptoss_style. Returns: Tuple[bool, Optional[str]]: (can_implement, skip_reason) @@ -360,10 +360,10 @@ class CuteDslFusedMoE(CutlassFusedMoE): return _warn_and_return( "CuteDslFusedMoE does not support unquantized mode") - # CuteDslFusedMoE does NOT support swiglu_gptoss_style - if swiglu_gptoss_style: + # CuteDslFusedMoE does NOT support gptoss_style + if gptoss_style: return _warn_and_return( - "CuteDslFusedMoE does not support swiglu_gptoss_style (bias/swiglu with custom alpha/beta/limit)" + "CuteDslFusedMoE does not support gptoss_style (bias/swiglu with custom alpha/beta/limit)" ) # NVFP4 - SM in {100, 103} diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index 83aae9a06a..ff23a103bb 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -113,15 +113,15 @@ class CutlassFusedMoE(MoE): }, } + # Quantization algorithms that support gptoss_style _GPTOSS_SUPPORTED_ALGOS = {QuantAlgo.W4A8_MXFP4_MXFP8} - """set[QuantAlgo]: Quantization algorithms that support swiglu_gptoss_style.""" @classmethod def can_implement( cls, quant_algo: Optional[QuantAlgo], dtype_activation: torch.dtype = torch.bfloat16, - swiglu_gptoss_style: bool = False, + gptoss_style: bool = False, ) -> Tuple[bool, Optional[str]]: """ Check if CutlassFusedMoE can implement the given quantization algorithm. @@ -145,8 +145,8 @@ class CutlassFusedMoE(MoE): - FP8/FP8_BLOCK_SCALES/W4A8_MXFP4_FP8: float16, bfloat16, float32 - NVFP4: float16, bfloat16, float8_e4m3fn - W4A16_MXFP4/W4A8_AWQ/W8A16/W4A8_MXFP4_MXFP8: float16, bfloat16 - swiglu_gptoss_style: Whether swiglu_gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled. - CutlassFusedMoE only supports swiglu_gptoss_style for W4A8_MXFP4_MXFP8 quantization. + gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled. + CutlassFusedMoE only supports gptoss_style for W4A8_MXFP4_MXFP8 quantization. 
Returns: Tuple[bool, Optional[str]]: (can_implement, skip_reason) @@ -160,10 +160,10 @@ class CutlassFusedMoE(MoE): return _warn_and_return( f"CutlassFusedMoE requires SM >= 80, got SM{sm_version}") - # Check swiglu_gptoss_style support - if swiglu_gptoss_style and quant_algo not in cls._GPTOSS_SUPPORTED_ALGOS: + # Check gptoss_style support + if gptoss_style and quant_algo not in cls._GPTOSS_SUPPORTED_ALGOS: return _warn_and_return( - f"CutlassFusedMoE swiglu_gptoss_style only supports W4A8_MXFP4_MXFP8 " + f"CutlassFusedMoE gptoss_style only supports W4A8_MXFP4_MXFP8 " f"(got quant_algo={quant_algo})") # Check if quant_algo is supported diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index f1e88cf743..671df5285e 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -382,7 +382,7 @@ class DeepGemmFusedMoE(CutlassFusedMoE): cls, quant_algo: Optional[QuantAlgo], dtype_activation: torch.dtype = torch.bfloat16, - swiglu_gptoss_style: bool = False, + gptoss_style: bool = False, ) -> Tuple[bool, Optional[str]]: """ Check if DeepGemmFusedMoE can implement the given quantization algorithm. @@ -391,15 +391,15 @@ class DeepGemmFusedMoE(CutlassFusedMoE): - FP8_BLOCK_SCALES: SM in {100, 103} Does NOT support unquantized mode. Output dtype is hardcoded to bfloat16. - Does NOT support swiglu_gptoss_style (bias/swiglu with custom alpha/beta/limit). + Does NOT support gptoss_style (bias/swiglu with custom alpha/beta/limit). Args: quant_algo: The quantization algorithm to check (None for unquantized) dtype_activation: The activation input data type. Supported types are float32, bfloat16, and float16 (required by moe_permute_op kernel). Note: Output dtype is always bfloat16 regardless of input dtype. - swiglu_gptoss_style: Whether swiglu_gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled. - DeepGemmFusedMoE does NOT support swiglu_gptoss_style. + gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled. + DeepGemmFusedMoE does NOT support gptoss_style. Returns: Tuple[bool, Optional[str]]: (can_implement, skip_reason) @@ -425,10 +425,10 @@ class DeepGemmFusedMoE(CutlassFusedMoE): return _warn_and_return( "DeepGemmFusedMoE does not support unquantized mode") - # DeepGemmFusedMoE does NOT support swiglu_gptoss_style - if swiglu_gptoss_style: + # DeepGemmFusedMoE does NOT support gptoss_style + if gptoss_style: return _warn_and_return( - "DeepGemmFusedMoE does not support swiglu_gptoss_style (bias/swiglu with custom alpha/beta/limit)" + "DeepGemmFusedMoE does not support gptoss_style (bias/swiglu with custom alpha/beta/limit)" ) # Only FP8_BLOCK_SCALES is supported diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py index fbb9491353..68ec51c18a 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py @@ -1283,12 +1283,12 @@ class TritonFusedMoE(MoE): cls, quant_algo: Optional["QuantAlgo"], dtype_activation: torch.dtype = torch.bfloat16, - swiglu_gptoss_style: bool = False, + gptoss_style: bool = False, ) -> Tuple[bool, Optional[str]]: """ Check if TritonFusedMoE can implement the given quantization algorithm. 
- TritonFusedMoE supports (SM90 only, swiglu_gptoss_style=True only): + TritonFusedMoE supports (SM90 only, gptoss_style=True only): - Unquantized (BF16 only) - FP8 per-tensor (QDQ) - W4A8_MXFP4_FP8 @@ -1298,8 +1298,8 @@ class TritonFusedMoE(MoE): quant_algo: The quantization algorithm to check (None for unquantized) dtype_activation: The activation data type. In unquantized mode, activation, weight, and output dtypes must all match (only bfloat16 supported). - swiglu_gptoss_style: Whether swiglu_gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled. - TritonFusedMoE ONLY supports swiglu_gptoss_style=True. + gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled. + TritonFusedMoE ONLY supports gptoss_style=True. Returns: Tuple[bool, Optional[str]]: (can_implement, skip_reason) @@ -1316,10 +1316,10 @@ class TritonFusedMoE(MoE): return _warn_and_return( f"TritonFusedMoE only supports SM90, got SM{sm_version}") - # TritonFusedMoE ONLY supports swiglu_gptoss_style=True - if not swiglu_gptoss_style: + # TritonFusedMoE ONLY supports gptoss_style=True + if not gptoss_style: return _warn_and_return( - "TritonFusedMoE only supports swiglu_gptoss_style=True") + "TritonFusedMoE only supports gptoss_style=True") # Unquantized mode - only bfloat16 is supported if quant_algo is None: diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index 8026a7799b..ea259cc162 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -87,7 +87,7 @@ class TRTLLMGenFusedMoE(MoE): QuantAlgo.W4A8_MXFP4_MXFP8, } - # Quantization algorithms that support swiglu_gptoss_style + # Quantization algorithms that support gptoss_style _GPTOSS_SUPPORTED_ALGOS = { QuantAlgo.NVFP4, QuantAlgo.W4A16_MXFP4, @@ -100,7 +100,7 @@ class TRTLLMGenFusedMoE(MoE): cls, quant_algo: Optional[QuantAlgo], dtype_activation: torch.dtype = torch.bfloat16, - swiglu_gptoss_style: bool = False, + gptoss_style: bool = False, ) -> Tuple[bool, Optional[str]]: """ Check if TRTLLMGenFusedMoE can implement the given quantization algorithm. @@ -119,7 +119,7 @@ class TRTLLMGenFusedMoE(MoE): quant_algo: The quantization algorithm to check (None for unquantized) dtype_activation: The activation input data type. Only bfloat16 is supported. See: forward_impl() assert x.dtype == torch.bfloat16 (line 722). - swiglu_gptoss_style: Whether swiglu_gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled. + gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled. Only supported for nvfp4 and mxfp4 variants. 
Returns: @@ -151,10 +151,10 @@ class TRTLLMGenFusedMoE(MoE): return _warn_and_return( f"TRTLLMGenFusedMoE does not support quant_algo={quant_algo}") - # Check swiglu_gptoss_style support: only supported for nvfp4 and mxfp4 variants - if swiglu_gptoss_style and quant_algo not in cls._GPTOSS_SUPPORTED_ALGOS: + # Check gptoss_style support: only supported for nvfp4 and mxfp4 variants + if gptoss_style and quant_algo not in cls._GPTOSS_SUPPORTED_ALGOS: return _warn_and_return( - f"TRTLLMGenFusedMoE supports swiglu_gptoss_style (bias/swiglu) only for nvfp4 and mxfp4 variants, " + f"TRTLLMGenFusedMoE supports gptoss_style (bias/swiglu) only for nvfp4 and mxfp4 variants, " f"got quant_algo={quant_algo}") return True, None diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index f5e8e1e6f5..7138cc9cfe 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -1,18 +1,3 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import os import weakref from abc import abstractmethod @@ -173,7 +158,7 @@ class MoE(nn.Module): cls, quant_algo: Optional[QuantAlgo], dtype_activation: torch.dtype = torch.bfloat16, - swiglu_gptoss_style: bool = False, + gptoss_style: bool = False, ) -> Tuple[bool, Optional[str]]: """ Check if this MoE backend can implement the given quantization algorithm. @@ -191,7 +176,7 @@ class MoE(nn.Module): Args: quant_algo: The quantization algorithm to check (None for unquantized) dtype_activation: The activation data type. - swiglu_gptoss_style: Whether swiglu_gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled. + gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled. Returns: Tuple[bool, Optional[str]]: (can_implement, skip_reason) diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 8ed02847dc..1f9b324c66 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -1,22 +1,6 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- import inspect import math from abc import ABC, abstractmethod -from enum import Enum, auto from typing import Dict, List, NamedTuple, Optional, Tuple, Union import torch @@ -190,38 +174,11 @@ def interleave_linear_and_gate(x: torch.Tensor, return x -class EplbSupportStatus(Enum): - """EPLB support status for FusedMoEMethod classes.""" - SUPPORTED = auto() - NOT_SUPPORTED = auto() - NOT_VERIFIED = auto() - - class FusedMoEMethodBase(ABC): """ Base class for all fused MoE methods. """ weight_alignment: int = 1 - """Required byte alignment for MoE weight tensors.""" - - eplb_support_status: EplbSupportStatus = EplbSupportStatus.NOT_SUPPORTED - """Online EPLB support status for this quantization method. - - Defaults to NOT_SUPPORTED for safety so that new subclasses do not - silently claim EPLB compatibility. Subclasses that have been verified - to work with online EPLB should override this to SUPPORTED; those that - have not yet been tested may set it to NOT_VERIFIED. - """ - - @classmethod - def supports_online_eplb(cls) -> bool: - """ - Check if this FusedMoEMethod supports online EPLB. - - Returns: - True if online EPLB is supported, False otherwise. - """ - return cls.eplb_support_status == EplbSupportStatus.SUPPORTED @classmethod def need_load_shared_weights(cls, module): @@ -612,7 +569,6 @@ class FusedMoEMethodBase(ABC): class UnquantizedFusedMoEMethod(FusedMoEMethodBase): - eplb_support_status = EplbSupportStatus.SUPPORTED def create_weights(self, module: torch.nn.Module): weight_dtype = module.dtype @@ -717,7 +673,6 @@ def requantize_expert_w3_w1_weight_fp8_qdq(module: torch.nn.Module, class FP8QDQFusedMoEMethod(FusedMoEMethodBase): - eplb_support_status = EplbSupportStatus.NOT_SUPPORTED def create_weights(self, module: torch.nn.Module): weight_dtype = torch.float8_e4m3fn @@ -894,7 +849,6 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase): class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase): - eplb_support_status = EplbSupportStatus.NOT_VERIFIED def create_weights(self, module: torch.nn.Module): weight_dtype = torch.float8_e4m3fn @@ -1154,7 +1108,6 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm( class INT8WoqPerChannelFusedMoEMethod(FusedMoEMethodBase): - eplb_support_status = EplbSupportStatus.NOT_SUPPORTED def create_weights(self, module: torch.nn.Module): module.sm_version = get_sm_version() @@ -1288,7 +1241,6 @@ class INT8WoqPerChannelFusedMoEMethod(FusedMoEMethodBase): class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase): - eplb_support_status = EplbSupportStatus.NOT_SUPPORTED def create_weights(self, module: torch.nn.Module): module.sm_version = get_sm_version() @@ -1722,7 +1674,6 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase): class WFP4A16FusedMoEMethod(FusedMoEMethodBase): - eplb_support_status = EplbSupportStatus.NOT_SUPPORTED group_size = 32 @@ -1932,7 +1883,6 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase): """ Base class for NVFP4 fused MoE methods for all backends. 
""" - eplb_support_status = EplbSupportStatus.SUPPORTED def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int, block_scales_vec_size: int): @@ -3224,7 +3174,6 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEBaseMethod): class W4A8NVFP4FP8TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEBaseMethod): - eplb_support_status = EplbSupportStatus.NOT_VERIFIED def create_weights(self, module: torch.nn.Module): weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4 @@ -3283,7 +3232,6 @@ def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size, class MXFP4WeightFusedMoEMethod(FusedMoEMethodBase): - eplb_support_status = EplbSupportStatus.SUPPORTED def create_weights(self, module: torch.nn.Module, @@ -3431,7 +3379,6 @@ class MXFP4WeightFusedMoEMethod(FusedMoEMethodBase): class MXFP4WeightCutlassFusedMoEMethod(MXFP4WeightFusedMoEMethod): - eplb_support_status = EplbSupportStatus.NOT_VERIFIED weight_dtype = FUSED_MOE_MXFP4_WEIGHT_DTYPE block_scales_dtype = FUSED_MOE_MXFP4_WEIGHT_BLOCK_SCALE_DTYPE # Cutlass MoE backend requires weight elements to be 128 aligned. @@ -3604,7 +3551,6 @@ class W4A16MXFP4CutlassFusedMoEMethod(MXFP4WeightCutlassFusedMoEMethod): class W4A8MXFP4MXFP8CutlassFusedMoEMethod(MXFP4WeightCutlassFusedMoEMethod): - eplb_support_status = EplbSupportStatus.NOT_VERIFIED def create_weights(self, module: torch.nn.Module): fake_input_scale = nn.Parameter(torch.empty( @@ -3635,7 +3581,6 @@ class W4A8MXFP4MXFP8CutlassFusedMoEMethod(MXFP4WeightCutlassFusedMoEMethod): class W4A8MXFP4FP8CutlassFusedMoEMethod(MXFP4WeightCutlassFusedMoEMethod): - eplb_support_status = EplbSupportStatus.NOT_SUPPORTED def create_weights(self, module: torch.nn.Module): fc31_input_scale = nn.Parameter(torch.tensor(1., dtype=torch.float32), @@ -4042,7 +3987,6 @@ class W4A16MXFP4TRTLLMGenFusedMoEMethod(MXFP4WeightTRTLLMGenFusedMoEMethod): class W4A8MXFP4FP8TRTLLMGenFusedMoEMethod(MXFP4WeightTRTLLMGenFusedMoEMethod): - eplb_support_status = EplbSupportStatus.NOT_SUPPORTED def create_weights(self, module: torch.nn.Module): fc31_input_dequant = nn.Parameter(torch.empty( diff --git a/tests/unittest/_torch/modules/moe/moe_test_utils.py b/tests/unittest/_torch/modules/moe/moe_test_utils.py deleted file mode 100644 index b3987bafae..0000000000 --- a/tests/unittest/_torch/modules/moe/moe_test_utils.py +++ /dev/null @@ -1,727 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Shared utilities for MoE test files (test_moe_backend.py and test_moe_module.py). - -This module contains common code extracted from both test files: -- MoeBackendType enum and get_backend_class() -- MoeModelConfig dataclass -- Skip logic functions (should_skip_trtllm, should_skip_cutedsl, should_skip_routing_method, etc.) 
-- get_quick_skip_reason() - unified version supporting both backend and module tests -- supports_autotuner_capture() -- replay_tactics_and_check() -- module_timer fixture -- create_test_param() helper -- Common test parameter constants -""" - -import logging -import time -from dataclasses import dataclass -from enum import Enum -from itertools import product -from typing import Callable, Optional, Type - -import pytest -import torch - -from tensorrt_llm._torch.autotuner import AutoTuner -from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import CuteDslFusedMoE -from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE -from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import DeepGemmFusedMoE -from tensorrt_llm._torch.modules.fused_moe.fused_moe_trtllm_gen import TRTLLMGenFusedMoE -from tensorrt_llm._torch.modules.fused_moe.interface import MoE -from tensorrt_llm.models.modeling_utils import QuantAlgo - -G_LOGGER = logging.getLogger(__name__) - - -# ============================================================================ -# MoE Backend Types -# ============================================================================ -class MoeBackendType(str, Enum): - """Enum for MoE backend types.""" - - CUTLASS = "CUTLASS" - TRTLLM = "TRTLLM" - CUTEDSL = "CUTEDSL" - DEEPGEMM = "DEEPGEMM" - - -def get_backend_class(backend_type: MoeBackendType) -> Type[MoE]: - """Get the MoE backend class for a given backend type.""" - backend_class_map = { - MoeBackendType.CUTLASS: CutlassFusedMoE, - MoeBackendType.TRTLLM: TRTLLMGenFusedMoE, - MoeBackendType.CUTEDSL: CuteDslFusedMoE, - MoeBackendType.DEEPGEMM: DeepGemmFusedMoE, - } - return backend_class_map[backend_type] - - -# ============================================================================ -# Model Configuration -# ============================================================================ -@dataclass -class MoeModelConfig: - """MoE model configuration: (num_experts, top_k, hidden_size, intermediate_size).""" - - num_experts: int - top_k: int - hidden_size: int - intermediate_size: int - - def __str__(self) -> str: - return f"e{self.num_experts}_k{self.top_k}_h{self.hidden_size}_i{self.intermediate_size}" - - -# ============================================================================ -# Skip Logic Functions -# ============================================================================ -def should_skip_trtllm( - backend_type: MoeBackendType, - quant_algo: Optional[QuantAlgo], - model_config: "MoeModelConfig", - routing_method_cls=None, - swiglu_gptoss_style: bool = False, - comm_method: Optional[str] = None, -) -> Optional[str]: - """ - Check TRTLLM Gen backend specific constraints. - - The TRTLLM Gen MoE kernels have hardware-level constraints that must be satisfied. - These constraints are enforced in C++ layer. - - Args: - backend_type: The MoE backend type - quant_algo: The quantization algorithm - model_config: The MoE model configuration - routing_method_cls: Optional routing method class for compatibility checks - (used by test_moe_module.py) - swiglu_gptoss_style: Whether using swiglu gptoss style - comm_method: Optional communication method (e.g. 
"DEEPEP", "DEEPEPLOWLATENCY") - for multi-GPU EP mode checks - - Returns: - Skip reason string if test should be skipped, None otherwise - """ - if backend_type != MoeBackendType.TRTLLM: - return None - - # Routing method compatibility check (used by test_moe_module.py) - # TRTLLMGen C++ routing kernel (runner.cu) only implements: - # - DeepSeekV3 (requires float32 routing_logits) - # - Llama4 (requires top_k=1) - # - Renormalize - # - RenormalizeNaive - # See: cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu:77-212 - if routing_method_cls is not None: - from tensorrt_llm._torch.modules.fused_moe import ( - DeepSeekV3MoeRoutingMethod, - DefaultMoeRoutingMethod, - Llama4RenormalizeMoeRoutingMethod, - MiniMaxM2MoeRoutingMethod, - ) - - # Routing methods NOT implemented in C++ kernel - trtllm_unimplemented_routing = ( - DefaultMoeRoutingMethod, # runner.cu:210 - "Unimplemented routing method" - MiniMaxM2MoeRoutingMethod, # runner.cu:210 - "Unimplemented routing method" - ) - if routing_method_cls in trtllm_unimplemented_routing: - routing_name = routing_method_cls.__name__ - return ( - f"TRTLLMGen C++ routing kernel does not implement {routing_name}. See runner.cu:210" - ) - - # Llama4 routing only supports top_k=1 - # See: runner.cu:113 - TLLM_CHECK_WITH_INFO(topK == 1, ...) - if routing_method_cls == Llama4RenormalizeMoeRoutingMethod: - if model_config is not None and model_config.top_k != 1: - return ( - f"TRTLLMGen Llama4 routing only supports top_k=1 " - f"(got top_k={model_config.top_k}). See runner.cu:113" - ) - - # DeepSeekV3 routing requires num_experts >= 22 - # See: RoutingDeepSeek.cu:32,664 - MaxSupportedTopExperts = 22 - if routing_method_cls == DeepSeekV3MoeRoutingMethod: - if model_config is not None and model_config.num_experts < 22: - return ( - f"TRTLLMGen DeepSeekV3 routing requires num_experts >= 22 " - f"(got num_experts={model_config.num_experts}). See RoutingDeepSeek.cu:664" - ) - - # DeepSeekV3 routing kernel only supports topk_group <= 4. 
- # topk_group is computed from num_experts in _create_routing_method: - # n_group = max(1, num_experts // 2) - # topk_group = min(n_group, max(1, n_group // 2)) - if model_config is not None: - n_group = max(1, model_config.num_experts // 2) - topk_group = min(n_group, max(1, n_group // 2)) - if topk_group > 4: - return ( - f"TRTLLMGen DeepSeekV3 routing kernel only supports " - f"topk_group <= 4 (got topk_group={topk_group} from " - f"num_experts={model_config.num_experts})" - ) - - if model_config is None: - return None - - # These quantization algorithms use TRTLLM Gen kernels with the constraints - trtllm_gen_quant_algos = { - QuantAlgo.NVFP4, - QuantAlgo.FP8_BLOCK_SCALES, - QuantAlgo.W4A8_NVFP4_FP8, - QuantAlgo.W4A16_MXFP4, - QuantAlgo.W4A8_MXFP4_MXFP8, - } - - if quant_algo not in trtllm_gen_quant_algos: - return None - - num_experts = model_config.num_experts - top_k = model_config.top_k - intermediate_size = model_config.intermediate_size - - # Check: num_experts must be divisible by 4 - # Routing kernel uses vectorized operations that require this alignment - if num_experts % 4 != 0: - return ( - f"TRTLLMGenFusedMoE routing kernel requires num_experts divisible by 4 " - f"(got num_experts={num_experts})" - ) - - # Check: num_experts must be greater than top_k - # Routing logic cannot handle the case where all experts are selected - if num_experts <= top_k: - return ( - f"TRTLLMGenFusedMoE requires num_experts > top_k " - f"(got num_experts={num_experts}, top_k={top_k})" - ) - # W4A8_MXFP4_MXFP8 with non-128-aligned hidden_size or intermediate_size - # causes block_scale_interleave_reverse to fail with - # "rows of Interleaved block scales should be multiple of 128". - if quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8: - hidden_size = model_config.hidden_size - if hidden_size % 128 != 0 or intermediate_size % 128 != 0: - return ( - f"TRTLLMGenFusedMoE W4A8_MXFP4_MXFP8 with non-128-aligned " - f"sizes (h={hidden_size}, i={intermediate_size}) causes " - f"block_scale_interleave_reverse rows must be multiple of 128." - ) - - # -----------------Potential issues------------------ - # These are known issues that need investigation. Skipping to avoid test failures - # and CUDA errors that can cascade to subsequent tests. - - # Issue: W4A8_NVFP4_FP8 with top_k=1 causes CUDA illegal memory access - if quant_algo == QuantAlgo.W4A8_NVFP4_FP8 and top_k == 1: - return ( - "[Potential Bug] TRTLLMGenFusedMoE W4A8_NVFP4_FP8 with top_k=1 " - "causes CUDA illegal memory access." - ) - - # Issue: NVFP4 with large intermediate_size has known accuracy issues - if quant_algo == QuantAlgo.NVFP4 and intermediate_size >= 14336: - return ( - f"[Potential Bug] TRTLLMGenFusedMoE NVFP4 with large intermediate_size " - f"has known accuracy issues (intermediate_size={intermediate_size} >= 14336)." - ) - - # Issue: W4A8_MXFP4_MXFP8 has accuracy issues on certain model configs - if quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8: - if intermediate_size >= 14336: - return ( - f"[Potential Bug] TRTLLMGenFusedMoE W4A8_MXFP4_MXFP8 with large " - f"intermediate_size has accuracy issues (intermediate_size={intermediate_size} >= 14336)." - ) - if num_experts >= 60 and intermediate_size >= 1408: - return ( - f"[Potential Bug] TRTLLMGenFusedMoE W4A8_MXFP4_MXFP8 with many experts " - f"has accuracy issues (num_experts={num_experts} >= 60)." - ) - # Issue: W4A8_MXFP4_MXFP8 with swiglu_gptoss_style and top_k=1 has accuracy - # issues on TRTLLM backend. Observed mismatch ~20-22% exceeds the 20% threshold. 
- # CUTLASS backend with the same configuration passes. - if swiglu_gptoss_style and top_k == 1: - return ( - f"[Potential Bug] TRTLLMGenFusedMoE W4A8_MXFP4_MXFP8 with " - f"swiglu_gptoss_style and top_k={top_k} has accuracy issues " - f"(mismatch ~20-22%). CUTLASS backend with the same config passes." - ) - - # Issue: Certain TRTLLM kernel runners crash with CUDA errors in multi-GPU - # DeepEP mode. the crash is specific to EP with DeepEP. - # Verified on 4 GPUs with DEP + DEEPEP + TRTLLM (e60_k4_h2048_i1408): - # - FP8_BLOCK_SCALES: CRASH (fp8_block_scale_moe_runner -> CUDA_ERROR_INVALID_HANDLE) - # - W4A16_MXFP4: CRASH (bf16_mxe2m1_block_scale_moe_runner -> illegal memory access) - # - W4A8_MXFP4_MXFP8: likely crash (same mxe2m1 kernel family as W4A16_MXFP4) - if comm_method in ("DEEPEP", "DEEPEPLOWLATENCY"): - deepep_crash_quant_algos = { - QuantAlgo.FP8_BLOCK_SCALES, - QuantAlgo.W4A16_MXFP4, - QuantAlgo.W4A8_MXFP4_MXFP8, - } - if quant_algo in deepep_crash_quant_algos: - return ( - f"[Potential Bug] TRTLLMGenFusedMoE {quant_algo} crashes with " - f"CUDA error in multi-GPU DeepEP mode (comm={comm_method}). " - f"Single-GPU tests pass; issue is in the kernel runner under EP." - ) - - return None - - -def should_skip_cutedsl( - backend_type: MoeBackendType, - quant_algo: Optional[QuantAlgo], - model_config: "MoeModelConfig" = None, - comm_method: Optional[str] = None, - routing_method_cls=None, -) -> Optional[str]: - """ - Check CuteDSL backend specific constraints. - - Returns: - Skip reason string if test should be skipped, None otherwise - """ - if backend_type != MoeBackendType.CUTEDSL: - return None - - # DeepEPLowLatency _modify_output_to_adapt_fused_moe converts dispatch output - # to a format where token_selected_slots has shape [num_local_experts, tokens_per_expert] - # instead of [num_tokens, top_k]. CuteDSL moe_sort asserts - # token_selected_experts.size(1) == top_k, which fails with this format. - if comm_method == "DEEPEPLOWLATENCY": - return ( - "[Potential Bug] CuteDslFusedMoE is incompatible with DeepEPLowLatency: " - "DeepEPLowLatency _modify_output_to_adapt_fused_moe reshapes " - "token_selected_slots to [num_local_experts, tokens_per_expert] " - "(effectively top_k=1), but CuteDSL moe_sort requires " - "token_selected_experts.size(1) == top_k." - ) - - if model_config is None: - return None - - intermediate_size = model_config.intermediate_size - num_experts = model_config.num_experts - - # NVFP4 with large intermediate_size has known accuracy issues - if quant_algo == QuantAlgo.NVFP4 and intermediate_size >= 14336: - return ( - f"[Potential Bug] CuteDslFusedMoE NVFP4 with large intermediate_size " - f"has known accuracy issues (intermediate_size={intermediate_size} >= 14336)." - ) - - # NVFP4 with prime num_experts causes CUDA_ERROR_ILLEGAL_ADDRESS - prime_experts_with_issues = {7, 13} - if quant_algo == QuantAlgo.NVFP4 and num_experts in prime_experts_with_issues: - return ( - f"[Potential Bug] CuteDslFusedMoE NVFP4 with prime num_experts={num_experts} " - f"causes CUDA_ERROR_ILLEGAL_ADDRESS due to autotuner cache bucket mapping." - ) - - # NVFP4 with Llama4Renormalize routing has significant accuracy issues on bfloat16. - # Observed mismatch up to 34.6% (threshold 2% at rtol=0.01, percent=0.98). 
- if routing_method_cls is not None: - from tensorrt_llm._torch.modules.fused_moe import Llama4RenormalizeMoeRoutingMethod - - if ( - quant_algo == QuantAlgo.NVFP4 - and routing_method_cls == Llama4RenormalizeMoeRoutingMethod - ): - return ( - "[Potential Bug] CuteDslFusedMoE NVFP4 with Llama4Renormalize " - "routing has significant accuracy issues (mismatch up to 34.6%%)." - ) - - return None - - -def should_skip_deepgemm( - backend_type: MoeBackendType, - comm_method: Optional[str] = None, - quant_algo: Optional[QuantAlgo] = None, - model_config: "MoeModelConfig" = None, -) -> Optional[str]: - """ - Check DeepGemm backend specific constraints. - - Returns: - Skip reason string if test should be skipped, None otherwise - """ - if backend_type != MoeBackendType.DEEPGEMM: - return None - - # DeepGemm workspace allocation in set_strides (fused_moe_deepgemm.py) uses a - # storage size that is 4x too small when combined with DeepEPLowLatency dispatch. - # The workspace is allocated based on assumptions that do not account for the - # DeepEPLowLatency output format ([num_local_experts, ep_size * max_tokens, hidden_size]). - if comm_method == "DEEPEPLOWLATENCY": - return ( - "[Potential Bug] DeepGemmFusedMoE workspace allocation is incompatible " - "with DeepEPLowLatency: set_strides requires storage of " - "[num_local_experts * tokens * hidden_size] bytes but the allocated " - "workspace is ~4x too small, causing setStorage out of bounds." - ) - - # Issue: DEEPGEMM + FP8_BLOCK_SCALES crashes with CUDA illegal memory access - # on large expert counts (e.g. e384_k8_h7168_i2048) during post_load_weights(). - # The crash occurs in get_col_major_tma_aligned_packed_tensor (fp8_utils.py) - # when resmoothing FP8 E8M0 scales on SM100f (Blackwell). - # Small configs (e.g. e60_k4_h2048_i1408) pass fine. - if quant_algo == QuantAlgo.FP8_BLOCK_SCALES and model_config is not None: - if model_config.num_experts > 128: - return ( - f"[Potential Bug] DeepGemmFusedMoE FP8_BLOCK_SCALES crashes with " - f"CUDA illegal memory access on large expert count " - f"(num_experts={model_config.num_experts}). The crash occurs in " - f"get_col_major_tma_aligned_packed_tensor during " - f"post_load_weights() FP8 E8M0 scale resmoothing on SM100f." - ) - - return None - - -def should_skip_multi_gpu( - parallel_mode: str, - model_config: "MoeModelConfig", - world_size: int = 4, -) -> Optional[str]: - """ - Check if a multi-GPU test should be skipped due to EP partitioning constraints. - - In EP modes (DEP, TEP), num_experts must be divisible by ep_size (= world_size) - when EPLB (Expert Load Balancing) is not enabled. Otherwise the assertion - `num_experts % ep_size == 0` in interface.py _init_load_balancer will fail. - - Args: - parallel_mode: Parallelism strategy ("DEP", "TEP", "DTP", "TTP") - model_config: MoE model configuration containing num_experts - world_size: Total number of GPUs (default: 4) - - Returns: - Skip reason string if test should be skipped, None otherwise - """ - # Only EP modes have ep_size = world_size; TP modes have ep_size = 1 - if parallel_mode not in ("DEP", "TEP"): - return None - - ep_size = world_size - num_experts = model_config.num_experts - if num_experts % ep_size != 0: - return ( - f"num_experts={num_experts} is not divisible by ep_size={ep_size} " - f"in {parallel_mode} mode. Requires EPLB to handle non-uniform " - f"expert partitioning (tested separately in test_ConfigurableMoE_multi_gpu_eplb)." 
- ) - - return None - - -def should_skip_routing_method( - routing_method_cls, - model_config: "MoeModelConfig", -) -> Optional[str]: - """ - Check routing method specific constraints that are independent of backend. - - Args: - routing_method_cls: The routing method class - model_config: The MoE model configuration - - Returns: - Skip reason string if test should be skipped, None otherwise - """ - if routing_method_cls is None or model_config is None: - return None - - from tensorrt_llm._torch.modules.fused_moe import DeepSeekV3MoeRoutingMethod - - # DeepSeekV3 routing: num_experts must be divisible by n_group for the - # view operation in noaux_tc (routing.py:298). n_group = max(1, num_experts // 2), - # so odd num_experts (e.g. 7, 13) fail because num_experts % n_group != 0. - if routing_method_cls == DeepSeekV3MoeRoutingMethod: - num_experts = model_config.num_experts - experts_per_group = 2 - n_group = max(1, num_experts // experts_per_group) - if n_group > 1 and num_experts % n_group != 0: - return ( - f"DeepSeekV3 routing requires num_experts divisible by n_group " - f"(num_experts={num_experts}, n_group={n_group}). " - f"noaux_tc view([n_group, num_experts // n_group]) fails." - ) - - return None - - -def supports_autotuner_capture( - backend_type: MoeBackendType, - _quant_algo: Optional[QuantAlgo], -) -> bool: - """ - Determine if a backend+quant_algo combination supports AutoTuner capture/replay. - - Args: - backend_type: The MoE backend type - _quant_algo: The quantization algorithm (None for unquantized). - Reserved for future per-algorithm gating; currently unused. - - Returns: - True if autotuner capture/replay is supported, False otherwise - """ - # DEEPGEMM does not support autotuner capture - if backend_type == MoeBackendType.DEEPGEMM: - return False - - return True - - -def get_quick_skip_reason( - backend_type: MoeBackendType, - quant_algo: Optional[QuantAlgo], - dtype: torch.dtype, - model_config: "MoeModelConfig", - routing_method_cls=None, - swiglu_gptoss_style: bool = False, -) -> Optional[str]: - """ - Fast skip check that calls backend's can_implement() method. 
- - Unified version supporting both backend-level and module-level tests: - - routing_method_cls: Used by test_moe_module.py for routing method compatibility checks - - swiglu_gptoss_style: Used by test_moe_backend.py for SwiGLU parameter checks - - Returns: - Skip reason string if test should be skipped, None otherwise - """ - import logging as _logging - - # Suppress logger warnings during parameter generation - trtllm_logger = _logging.getLogger("tensorrt_llm") - original_level = trtllm_logger.level - trtllm_logger.setLevel(_logging.ERROR) - - try: - # Call backend's can_implement for dtype/quant_algo checks - backend_cls = get_backend_class(backend_type) - can_impl_kwargs = {"dtype_activation": dtype} - if swiglu_gptoss_style: - can_impl_kwargs["swiglu_gptoss_style"] = swiglu_gptoss_style - can_impl, skip_reason = backend_cls.can_implement(quant_algo, **can_impl_kwargs) - if not can_impl: - return skip_reason - - # Chain skip checks: routing method, then per-backend constraints - skip_checks = [ - lambda: should_skip_routing_method(routing_method_cls, model_config), - lambda: should_skip_trtllm( - backend_type, quant_algo, model_config, routing_method_cls, swiglu_gptoss_style - ), - lambda: should_skip_cutedsl( - backend_type, quant_algo, model_config, routing_method_cls=routing_method_cls - ), - lambda: should_skip_deepgemm( - backend_type, quant_algo=quant_algo, model_config=model_config - ), - ] - for check in skip_checks: - skip_reason = check() - if skip_reason: - return skip_reason - - # DEEPGEMM: float16 reference module constraint - if backend_type == MoeBackendType.DEEPGEMM and dtype == torch.float16: - return "DeepGemmFusedMoE reference module requires bfloat16 input" - - # 128-alignment requirement for quantization - if quant_algo is not None: - hidden_size = model_config.hidden_size - intermediate_size = model_config.intermediate_size - is_hidden_128_aligned = hidden_size % 128 == 0 - is_intermediate_128_aligned = intermediate_size % 128 == 0 - - if not is_hidden_128_aligned or not is_intermediate_128_aligned: - # TRTLLM with MXFP4 variants automatically pads to 128 alignment - is_mxfp4_variant = quant_algo in {QuantAlgo.W4A16_MXFP4, QuantAlgo.W4A8_MXFP4_MXFP8} - is_trtllm_backend = backend_type == MoeBackendType.TRTLLM - if not (is_trtllm_backend and is_mxfp4_variant): - return ( - f"Non-128-aligned sizes (h={hidden_size}, i={intermediate_size}) " - f"require TRTLLM backend with MXFP4 quantization" - ) - - return None - - finally: - trtllm_logger.setLevel(original_level) - - -# ============================================================================ -# Autotuner Tactic Replay -# ============================================================================ -def replay_tactics_and_check( - all_tactics, - run_moe_fn: Callable[[], torch.Tensor], - check_accuracy_fn: Callable[[torch.Tensor, torch.Tensor], None], - ref_output: torch.Tensor, - backend_type: MoeBackendType, - quant_algo: Optional[QuantAlgo], - fail_fast: bool = False, -) -> None: - """ - Replay all tactics and check accuracy. - - Args: - all_tactics: TacticsCapture object from AutoTuner.capture() - run_moe_fn: Function to run MoE computation - check_accuracy_fn: Function to check accuracy (output, ref_output) -> None - ref_output: Reference output tensor - backend_type: Backend type for error reporting - quant_algo: Quantization algorithm for error reporting - fail_fast: If True, fail on first error. If False, run all and report summary. 
- """ - tactics_list = list(all_tactics) - passed_tactics = [] - failed_tactics = [] - G_LOGGER.info(f"Replay tactics : {len(tactics_list)} and check accuracy") - for idx, tactic in enumerate(tactics_list): - with AutoTuner.get().replay(tactic), torch.inference_mode(): - output = run_moe_fn() - try: - check_accuracy_fn(output, ref_output) - passed_tactics.append((idx, tactic)) - except Exception as e: - if fail_fast: - pytest.fail( - f"Accuracy check failed for tactic[{idx}/{len(tactics_list)}]={tactic}, " - f"backend={backend_type}, quant_algo={quant_algo}: {e}" - ) - failed_tactics.append((idx, tactic, str(e))) - - # Report results (only when fail_fast=False) - total = len(tactics_list) - num_passed = len(passed_tactics) - num_failed = len(failed_tactics) - if failed_tactics: - fail_details = "\n".join( - f" tactic[{idx}]={tactic}: {err}" for idx, tactic, err in failed_tactics - ) - pytest.fail( - f"backend={backend_type}, quant_algo={quant_algo}: " - f"{num_passed}/{total} passed, {num_failed}/{total} failed\n" - f"Failed tactics:\n{fail_details}" - ) - - -# ============================================================================ -# Test Parameter Helpers -# ============================================================================ -def create_test_param(param_values, test_id, skip_reason=None): - """Create a pytest.param with optional skip mark.""" - if skip_reason: - return pytest.param(*param_values, id=test_id, marks=pytest.mark.skip(reason=skip_reason)) - return pytest.param(*param_values, id=test_id) - - -# ============================================================================ -# Timing Fixture -# ============================================================================ -@pytest.fixture(scope="module", autouse=True) -def module_timer(request): - """Fixture to measure and log total module execution time.""" - start = time.perf_counter() - yield - elapsed = time.perf_counter() - start - G_LOGGER.info( - "[TIMING] Total %s: %.3fs (%.2f min)", - request.module.__name__, - elapsed, - elapsed / 60, - ) - - -# ============================================================================ -# Base Test Config Iterator -# ============================================================================ -def iter_base_test_configs( - swiglu_combos, model_configs, seq_lens, dtypes, backend_types, quant_algos, routing_methods=None -): - """ - Iterate over base test configurations using itertools.product. - - This is shared by test_moe_backend.py and test_moe_module.py. - When routing_methods is None, defaults to [RenormalizeMoeRoutingMethod]. 
- - Args: - swiglu_combos: List of (swiglu_alpha, swiglu_beta, swiglu_limit) tuples - model_configs: List of MoeModelConfig - seq_lens: List of sequence lengths - dtypes: List of data types - backend_types: List of backend types - quant_algos: List of quantization algorithms - routing_methods: List of routing method classes (default: [RenormalizeMoeRoutingMethod]) - - Yields: - Tuple of (swiglu_alpha, swiglu_beta, swiglu_limit, model_config, seq_len, - dtype, backend_type, quant_algo, routing_method_cls, skip_reason, base_test_id) - """ - if routing_methods is None: - from tensorrt_llm._torch.modules.fused_moe import RenormalizeMoeRoutingMethod - - routing_methods = [RenormalizeMoeRoutingMethod] - - for ( - swiglu_alpha, - swiglu_beta, - swiglu_limit, - ), model_config, seq_len, dtype, backend_type, quant_algo, routing_method_cls in product( - swiglu_combos, model_configs, seq_lens, dtypes, backend_types, quant_algos, routing_methods - ): - swiglu_gptoss_style = swiglu_alpha != 1 or swiglu_beta != 0 or swiglu_limit != float("inf") - skip_reason = get_quick_skip_reason( - backend_type, - quant_algo, - dtype, - model_config, - routing_method_cls, - swiglu_gptoss_style=swiglu_gptoss_style, - ) - routing_name = routing_method_cls.__name__.replace("MoeRoutingMethod", "") - swiglu_id = ( - f"alpha={swiglu_alpha}_beta={swiglu_beta}_limit={swiglu_limit}-" - if swiglu_gptoss_style - else "" - ) - base_test_id = ( - f"{swiglu_id}{model_config}-seq={seq_len}-dtype={dtype}-" - f"backend={backend_type.value}-quant={quant_algo}-routing={routing_name}" - ) - yield ( - swiglu_alpha, - swiglu_beta, - swiglu_limit, - model_config, - seq_len, - dtype, - backend_type, - quant_algo, - routing_method_cls, - skip_reason, - base_test_id, - ) diff --git a/tests/unittest/_torch/modules/moe/quantize_utils.py b/tests/unittest/_torch/modules/moe/quantize_utils.py index 24652a1c06..57fbb4f832 100644 --- a/tests/unittest/_torch/modules/moe/quantize_utils.py +++ b/tests/unittest/_torch/modules/moe/quantize_utils.py @@ -217,7 +217,7 @@ class RefGatedMLPFusedMoE(nn.Module): model_config = ModelConfig() self.quant_config = model_config.quant_config - # Custom swiglu activation for swiglu_gptoss_style + # Custom swiglu activation for gptoss_style def custom_swiglu(x): gate, value = x.chunk(2, dim=-1) if swiglu_limit is not None and swiglu_limit != float("inf"): @@ -314,7 +314,7 @@ class BaseQuantizeUtil(ABC): """ BaseQuantizeUtil serves as a base class for MoE correctess testing which provides interface to create quantized weights and reference modules. It can be extended for different quantization algorithms. - Supports swiglu_gptoss_style with custom swiglu parameters. + Supports gptoss_style with custom swiglu parameters. 
""" def __init__( @@ -325,11 +325,10 @@ class BaseQuantizeUtil(ABC): hidden_size: int, quant_config: QuantConfig, bias: bool = False, - swiglu_gptoss_style: bool = False, + gptoss_style: bool = False, swiglu_alpha: Optional[float] = None, swiglu_beta: Optional[float] = None, swiglu_limit: Optional[float] = None, - num_local_experts: Optional[int] = None, ): self.num_experts = num_experts self.dtype = dtype @@ -337,48 +336,38 @@ class BaseQuantizeUtil(ABC): self.hidden_size = hidden_size self.quant_config = quant_config self.bias = bias - self._swiglu_gptoss_style = swiglu_gptoss_style + self._gptoss_style = gptoss_style self.swiglu_alpha = swiglu_alpha self.swiglu_beta = swiglu_beta self.swiglu_limit = swiglu_limit - # In EP mode, swiglu tensors must be sized per local experts - # (see modeling_gpt_oss.py: num_slots // moe_ep_size) - self.num_local_experts = num_local_experts if num_local_experts is not None else num_experts - # Pre-create swiglu tensors if swiglu_gptoss_style is enabled - if self._swiglu_gptoss_style: - self.swiglu_alpha = 1.0 if self.swiglu_alpha is None else self.swiglu_alpha - self.swiglu_beta = 0.0 if self.swiglu_beta is None else self.swiglu_beta - self.swiglu_limit = float("inf") if self.swiglu_limit is None else self.swiglu_limit + # Pre-create swiglu tensors if gptoss_style is enabled + if self._gptoss_style: self._swiglu_tensors = self._create_swiglu_tensors() else: self._swiglu_tensors = None @property - def swiglu_gptoss_style(self) -> bool: - """Check if swiglu_gptoss_style is enabled.""" - return self._swiglu_gptoss_style + def gptoss_style(self) -> bool: + """Check if gptoss_style is enabled.""" + return self._gptoss_style def _create_swiglu_tensors(self) -> Dict[str, torch.Tensor]: """ Internal method to create swiglu tensors for MoE backend. - Uses num_local_experts (= num_experts // ep_size in EP mode) to match - the kernel expectation. See modeling_gpt_oss.py for reference: - swiglu_alpha is created with size (num_slots // moe_ep_size). - Returns: Dict with 'swiglu_alpha', 'swiglu_beta', 'swiglu_limit' tensors. """ return { "swiglu_alpha": torch.full( - (self.num_local_experts,), self.swiglu_alpha, device="cuda", dtype=torch.float + (self.num_experts,), self.swiglu_alpha, device="cuda", dtype=torch.float ), "swiglu_beta": torch.full( - (self.num_local_experts,), self.swiglu_beta, device="cuda", dtype=torch.float + (self.num_experts,), self.swiglu_beta, device="cuda", dtype=torch.float ), "swiglu_limit": torch.full( - (self.num_local_experts,), self.swiglu_limit, device="cuda", dtype=torch.float + (self.num_experts,), self.swiglu_limit, device="cuda", dtype=torch.float ), } @@ -387,7 +376,7 @@ class BaseQuantizeUtil(ABC): Get pre-created swiglu tensors. Returns: - Dict with swiglu tensors if swiglu_gptoss_style is enabled, None otherwise. + Dict with swiglu tensors if gptoss_style is enabled, None otherwise. """ return self._swiglu_tensors @@ -439,12 +428,11 @@ class FP8RefGatedMLPFusedMoE(RefGatedMLPFusedMoE): expected_quant_algo = QuantAlgo.FP8 def check_accuracy(self, output, ref_output): - # Relaxed percent from 0.97 to 0.95 to account for FP8 quantization error accumulation - # in large intermediate dimensions and multi-expert routing computations, - # especially with Llama4Renormalize sigmoid-based routing. + # Relaxed percent from 0.99 to 0.97 to account for FP8 quantization error accumulation + # in large intermediate dimensions and multi-expert routing computations. 
# Theoretical basis: FP8 (E4M3) has ~12.5% unit error, accumulated error grows as sqrt(K) - # where K is GEMM reduction dimension. Max observed mismatch is ~4.8% < 5%. - check_accuracy(output, ref_output, rtol=4e-2, atol=1e-1, percent=0.95) + # where K is GEMM reduction dimension. Max observed mismatch is ~2.1% < 3%. + check_accuracy(output, ref_output, rtol=4e-2, atol=1e-1, percent=0.97) class FP8QuantizeUtil(BaseQuantizeUtil): @@ -459,6 +447,7 @@ class FP8QuantizeUtil(BaseQuantizeUtil): assert self.quant_config is not None and self.quant_config.quant_algo == QuantAlgo.FP8, ( "expect quant_algo to be fp8" ) + bias = quant_kwargs.get("bias", False) weights = {} for expert_id in range(self.num_experts): w1_weight = torch.randn( @@ -500,7 +489,7 @@ class FP8QuantizeUtil(BaseQuantizeUtil): weights[f"{expert_id}.w2.input_scale"] = w2_input_scale weights[f"{expert_id}.w3.input_scale"] = w3_input_scale - if self.bias: + if bias: weights[f"{expert_id}.w1.bias"] = torch.randn( (self.intermediate_size,), dtype=self.dtype, device="cuda" ) @@ -525,25 +514,22 @@ class NVFP4RefGatedMLPFusedMoE(RefGatedMLPFusedMoE): scale_keys = ["weight_scale", "input_scale", "weight_scale_2"] expected_quant_algo = QuantAlgo.NVFP4 - def __init__(self, *args, swiglu_gptoss_style: bool = False, **kwargs): + def __init__(self, *args, gptoss_style: bool = False, **kwargs): super().__init__(*args, **kwargs) - self.swiglu_gptoss_style = swiglu_gptoss_style + self.gptoss_style = gptoss_style def check_accuracy(self, output, ref_output): - if self.swiglu_gptoss_style: - # swiglu_gptoss_style uses relaxed tolerance + if self.gptoss_style: + # gptoss_style uses relaxed tolerance check_accuracy(output, ref_output, rtol=0.1, atol=0.1, percent=0.95) else: - # Relaxed percent from 0.98 to 0.97 to account for NVFP4 quantization - # error accumulation with certain routing methods (e.g. Llama4Renormalize). - # Max observed mismatch in non-skipped cases is ~2.7% < 3%. - check_accuracy(output, ref_output, rtol=1e-2, atol=0.15, percent=0.97) + check_accuracy(output, ref_output, rtol=1e-2, atol=0.15, percent=0.98) class NVFP4QuantizeUtil(BaseQuantizeUtil): """ NVFP4QuantizeUtil inherits from BaseQuantizeUtil to support correctness testing for NVFP4 quantized MoE modules. - Supports swiglu_gptoss_style with custom swiglu parameters (inherited from BaseQuantizeUtil). + Supports gptoss_style with custom swiglu parameters (inherited from BaseQuantizeUtil). """ def create_weights(self, **quant_kwargs) -> Dict[str, torch.Tensor]: @@ -643,7 +629,7 @@ class NVFP4QuantizeUtil(BaseQuantizeUtil): self, routing_method, ref_cls=NVFP4RefGatedMLPFusedMoE ) -> torch.nn.Module: """ - Create a reference module for correctness testing with swiglu_gptoss_style support. + Create a reference module for correctness testing with gptoss_style support. """ ref_fused_moe = ref_cls( num_experts=self.num_experts, @@ -653,7 +639,7 @@ class NVFP4QuantizeUtil(BaseQuantizeUtil): dtype=self.dtype, model_config=ModelConfig(quant_config=self.quant_config), bias=self.bias, - swiglu_gptoss_style=self.swiglu_gptoss_style, + gptoss_style=self.gptoss_style, swiglu_alpha=self.swiglu_alpha, swiglu_beta=self.swiglu_beta, swiglu_limit=self.swiglu_limit, @@ -681,11 +667,6 @@ class FP8BlockScalesQuantizeUtil(BaseQuantizeUtil): for FP8 block-wise quantized MoE modules. 
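# --- Illustrative sketch (not part of the patch) ---------------------------
# FP8_BLOCK_SCALES relies on per-block quantization helpers
# (per_block_cast_to_fp8 / per_block_cast_to_fp8_e8m0). A minimal sketch of
# the idea, assuming 128x128 blocks, the E4M3 dynamic range of ~448, and a
# power-of-two scale for the E8M0 variant; the real helpers also handle
# padding and layout details not shown here.
import torch

def per_block_cast_to_fp8_sketch(w: torch.Tensor, block: int = 128,
                                 e8m0_scale: bool = False):
    rows, cols = w.shape  # assumed to be multiples of the block size
    q = torch.empty_like(w, dtype=torch.float8_e4m3fn)
    scales = torch.empty(rows // block, cols // block, dtype=torch.float32)
    for i in range(0, rows, block):
        for j in range(0, cols, block):
            blk = w[i:i + block, j:j + block].float()
            scale = (blk.abs().amax() / 448.0).clamp(min=1e-12)
            if e8m0_scale:
                # E8M0 stores only an exponent, so round the scale up to a power of two.
                scale = torch.exp2(torch.ceil(torch.log2(scale)))
            q[i:i + block, j:j + block] = (blk / scale).to(torch.float8_e4m3fn)
            scales[i // block, j // block] = scale
    return q, scales
# ---------------------------------------------------------------------------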
""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Will be set by create_weights() if ref_cls is provided in quant_kwargs - self._ref_cls = None - def create_weights(self, **quant_kwargs) -> Dict[str, torch.Tensor]: """ Create quantized weights for MoE experts using FP8 block-wise quantization. @@ -694,18 +675,12 @@ class FP8BlockScalesQuantizeUtil(BaseQuantizeUtil): use_e8m0_scale: If True, use per_block_cast_to_fp8_e8m0 which produces E8M0 format scales required by DEEPGEMM and TRTLLM backends. If False, use per_block_cast_to_fp8 with regular float scales. - ref_cls: Optional reference module class to use for accuracy testing. - If provided, will be stored and used by create_ref_module(). """ assert ( self.quant_config is not None and self.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES ), "expect quant_algo to be FP8_BLOCK_SCALES" - # Store ref_cls if provided for later use in create_ref_module() - if "ref_cls" in quant_kwargs: - self._ref_cls = quant_kwargs.pop("ref_cls") - # Select quantization function based on scale format requirement use_e8m0_scale = quant_kwargs.get("use_e8m0_scale", False) quant_fn = per_block_cast_to_fp8_e8m0 if use_e8m0_scale else per_block_cast_to_fp8 @@ -737,19 +712,12 @@ class FP8BlockScalesQuantizeUtil(BaseQuantizeUtil): weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale return weights - def create_ref_module(self, routing_method, ref_cls=None) -> torch.nn.Module: + def create_ref_module( + self, routing_method, ref_cls=FP8BlockScalesRefGatedMLPFusedMoE + ) -> torch.nn.Module: """ Create a reference module for correctness testing. - - Uses ref_cls in the following priority: - 1. Explicitly passed ref_cls argument - 2. ref_cls stored from create_weights() call (via quant_kwargs) - 3. Default FP8BlockScalesRefGatedMLPFusedMoE """ - if ref_cls is None: - ref_cls = ( - self._ref_cls if self._ref_cls is not None else FP8BlockScalesRefGatedMLPFusedMoE - ) return super().create_ref_module(routing_method, ref_cls) def create_input(self, seq_len: int) -> torch.Tensor: @@ -815,19 +783,14 @@ class DeepGemmFP8BlockScalesRefFusedMoE(FP8BlockScalesRefGatedMLPFusedMoE): self.w2_weights_stacked = torch.stack(w2_list, dim=0) self.w2_scales_stacked = torch.stack(w2_scale_list, dim=0) - def cuda(self, device=None): - """Move all weights to CUDA. - - Args: - device: Optional device specification (e.g., 'cuda:0', 0, or torch.device). - If None, uses the current CUDA device. - """ - super().cuda(device) + def cuda(self): + """Move all weights to CUDA.""" + super().cuda() if self.w3_w1_weights is not None: - self.w3_w1_weights = self.w3_w1_weights.cuda(device) - self.w3_w1_scales = self.w3_w1_scales.cuda(device) - self.w2_weights_stacked = self.w2_weights_stacked.cuda(device) - self.w2_scales_stacked = self.w2_scales_stacked.cuda(device) + self.w3_w1_weights = self.w3_w1_weights.cuda() + self.w3_w1_scales = self.w3_w1_scales.cuda() + self.w2_weights_stacked = self.w2_weights_stacked.cuda() + self.w2_scales_stacked = self.w2_scales_stacked.cuda() return self def _swiglu(self, x): @@ -955,20 +918,6 @@ class DeepGemmFP8BlockScalesRefFusedMoE(FP8BlockScalesRefGatedMLPFusedMoE): return output - def check_accuracy(self, output, ref_output): - """ - Check accuracy with relaxed tolerance for DEEPGEMM FP8 block scale kernel. 
- - DEEPGEMM with FP8 block scaling has specific numerical behavior due to: - - E8M0 scale format quantization - - Manual grouped GEMM computation pattern - - Different routing methods (especially MiniMaxM2 with sigmoid + manual normalization) - - Relaxed from rtol=0.01 to rtol=0.02 to accommodate these numerical differences - while still catching significant errors. - """ - torch.testing.assert_close(output, ref_output, rtol=2e-2, atol=1.5) - class DeepGemmFP8BlockScalesQuantizeUtil(BaseQuantizeUtil): """ @@ -1175,7 +1124,7 @@ class MXFP4MXFP8RefGatedMLPFusedMoE(RefGatedMLPFusedMoE): model_config: Optional[ModelConfig] = None, bias=False, hidden_size_unpadded: Optional[int] = None, - swiglu_gptoss_style: bool = False, + gptoss_style: bool = False, **kwargs, ): super().__init__( @@ -1193,7 +1142,7 @@ class MXFP4MXFP8RefGatedMLPFusedMoE(RefGatedMLPFusedMoE): self.hidden_size_unpadded = ( hidden_size_unpadded if hidden_size_unpadded is not None else hidden_size ) - self.swiglu_gptoss_style = swiglu_gptoss_style + self.gptoss_style = gptoss_style def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor: # Pad input if hidden_size_unpadded < hidden_size @@ -1209,14 +1158,8 @@ class MXFP4MXFP8RefGatedMLPFusedMoE(RefGatedMLPFusedMoE): return output def check_accuracy(self, output, ref_output): - if self.swiglu_gptoss_style: + if self.gptoss_style: check_accuracy(output, ref_output, rtol=0.1, atol=0.2, percent=0.8) - elif self.hidden_size >= 4096: - # Relax tolerance for large hidden_size (e.g., DeepSeek-V3 h=7168). - # MXFP4 (4-bit) weights + MXFP8 (8-bit) activations accumulate more - # quantization error in large GEMM reduction dimensions: error ~ sqrt(K). - # Observed mismatch: ~17-19% for h=7168 vs <15% for h=512. - check_accuracy(output, ref_output, rtol=0.15, atol=0.3, percent=0.85) else: check_accuracy(output, ref_output, rtol=0.10, atol=0.2, percent=0.85) @@ -1301,6 +1244,7 @@ class MXFP4MXFP8QuantizeUtil(BaseQuantizeUtil): intermediate_size_unpadded = quant_kwargs.get( "intermediate_size_unpadded", self.intermediate_size ) + bias = quant_kwargs.get("bias", False) pad_zero_or_val = quant_kwargs.get("pad_zero_or_val", True) weight_alignment = quant_kwargs.get("weight_alignment", 128) input_hidden_alignment = quant_kwargs.get("input_hidden_alignment", 512) @@ -1312,7 +1256,7 @@ class MXFP4MXFP8QuantizeUtil(BaseQuantizeUtil): weights = {} for expert_id in range(self.num_experts): - if self.bias: + if bias: w1_bias = torch.randn((intermediate_size_unpadded,), dtype=self.dtype).cuda() * 0.1 w2_bias = torch.randn((hidden_size_unpadded,), dtype=self.dtype).cuda() * 0.1 w3_bias = torch.randn((intermediate_size_unpadded,), dtype=self.dtype).cuda() * 0.1 @@ -1490,7 +1434,7 @@ class MXFP4MXFP8QuantizeUtil(BaseQuantizeUtil): model_config=ModelConfig(quant_config=self.quant_config), bias=self.bias, hidden_size_unpadded=hs_unpadded, - swiglu_gptoss_style=self.swiglu_gptoss_style, + gptoss_style=self.gptoss_style, swiglu_alpha=self.swiglu_alpha, swiglu_beta=self.swiglu_beta, swiglu_limit=self.swiglu_limit, @@ -1603,7 +1547,7 @@ class WFP4A16QuantizeUtil(BaseQuantizeUtil): """ WFP4A16QuantizeUtil inherits from BaseQuantizeUtil to support correctness testing for W4A16_MXFP4 quantized MoE modules. - Supports swiglu_gptoss_style with custom swiglu parameters (inherited from BaseQuantizeUtil). + Supports gptoss_style with custom swiglu parameters (inherited from BaseQuantizeUtil). 
""" def create_weights(self, **quant_kwargs) -> Dict[str, torch.Tensor]: @@ -1676,7 +1620,7 @@ class WFP4A16QuantizeUtil(BaseQuantizeUtil): weights[f"{expert_id}.w2.weight_scale_inv"] = w2_scale weights[f"{expert_id}.w3.weight_scale_inv"] = w3_scale - # Bias for swiglu_gptoss_style + # Bias for gptoss_style if self.bias: weights[f"{expert_id}.w1.bias"] = torch.randn( self.intermediate_size, device="cuda", dtype=torch.float @@ -1693,7 +1637,7 @@ class WFP4A16QuantizeUtil(BaseQuantizeUtil): self, routing_method, ref_cls=WFP4A16RefGatedMLPFusedMoE ) -> torch.nn.Module: """ - Create a reference module for correctness testing with swiglu_gptoss_style support. + Create a reference module for correctness testing with gptoss_style support. """ return super().create_ref_module(routing_method, ref_cls) diff --git a/tests/unittest/_torch/modules/moe/test_moe_backend.py b/tests/unittest/_torch/modules/moe/test_moe_backend.py index b31c696ab0..f80f26bd47 100644 --- a/tests/unittest/_torch/modules/moe/test_moe_backend.py +++ b/tests/unittest/_torch/modules/moe/test_moe_backend.py @@ -28,20 +28,13 @@ Design Goals: import itertools import logging -from typing import List, Optional +import time +from dataclasses import dataclass +from enum import Enum +from typing import Callable, List, Optional, Type import pytest import torch -from _torch.modules.moe.moe_test_utils import ( - MoeBackendType, - MoeModelConfig, - create_test_param, - get_backend_class, - iter_base_test_configs, - module_timer, # noqa: F401 - imported for pytest fixture registration - replay_tactics_and_check, - supports_autotuner_capture, -) from _torch.modules.moe.quantize_utils import get_test_quant_params from transformers.configuration_utils import PretrainedConfig @@ -49,6 +42,10 @@ from tensorrt_llm._torch.autotuner import AutoTuner, autotune from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.modules.fused_moe import RenormalizeMoeRoutingMethod from tensorrt_llm._torch.modules.fused_moe.create_moe import create_moe_backend +from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import CuteDslFusedMoE +from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE +from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import DeepGemmFusedMoE +from tensorrt_llm._torch.modules.fused_moe.fused_moe_trtllm_gen import TRTLLMGenFusedMoE from tensorrt_llm._torch.modules.fused_moe.interface import MoE from tensorrt_llm._utils import mpi_rank from tensorrt_llm.mapping import Mapping @@ -57,39 +54,249 @@ from tensorrt_llm.models.modeling_utils import QuantAlgo logger = logging.getLogger(__name__) +class MoeBackendType(str, Enum): + """Enum for MoE backend types.""" + + CUTLASS = "CUTLASS" + TRTLLM = "TRTLLM" + CUTEDSL = "CUTEDSL" + DEEPGEMM = "DEEPGEMM" + + +def get_backend_class(backend_type: MoeBackendType) -> Type[MoE]: + """Get the MoE backend class for a given backend type.""" + backend_class_map = { + MoeBackendType.CUTLASS: CutlassFusedMoE, + MoeBackendType.TRTLLM: TRTLLMGenFusedMoE, + MoeBackendType.CUTEDSL: CuteDslFusedMoE, + MoeBackendType.DEEPGEMM: DeepGemmFusedMoE, + } + return backend_class_map[backend_type] + + +def should_skip_TRTLLM( + backend_type: MoeBackendType, + quant_algo: Optional[QuantAlgo], + model_config: "MoeModelConfig", +) -> Optional[str]: + """ + Check TRTLLM Gen backend specific constraints. + + The TRTLLM Gen MoE kernels have hardware-level constraints that must be satisfied. + These constraints are enforced in C++ layer. + + Constraints: + 1. 
num_experts must be divisible by 4 (routing kernel vectorization requirement) + 2. num_experts must be greater than top_k (routing logic requirement) + + Args: + backend_type: The MoE backend type + quant_algo: The quantization algorithm + model_config: The MoE model configuration + + Returns: + Skip reason string if test should be skipped, None otherwise + """ + if backend_type != MoeBackendType.TRTLLM: + return None + + if model_config is None: + return None + + # These quantization algorithms use TRTLLM Gen kernels with the constraints + trtllm_gen_quant_algos = { + QuantAlgo.NVFP4, + QuantAlgo.FP8_BLOCK_SCALES, + QuantAlgo.W4A8_NVFP4_FP8, + QuantAlgo.W4A16_MXFP4, + QuantAlgo.W4A8_MXFP4_MXFP8, + } + + if quant_algo not in trtllm_gen_quant_algos: + return None + + num_experts = model_config.num_experts + top_k = model_config.top_k + intermediate_size = model_config.intermediate_size + + # Check: num_experts must be divisible by 4 + # Routing kernel uses vectorized operations that require this alignment + if num_experts % 4 != 0: + return ( + f"TRTLLMGenFusedMoE routing kernel requires num_experts divisible by 4 " + f"(got num_experts={num_experts})" + ) + + # Check: num_experts must be greater than top_k + # Routing logic cannot handle the case where all experts are selected + if num_experts <= top_k: + return ( + f"TRTLLMGenFusedMoE requires num_experts > top_k " + f"(got num_experts={num_experts}, top_k={top_k})" + ) + + # -----------------Potential issues------------------ + # These are known issues that need investigation. Skipping to avoid test failures + # and CUDA errors that can cascade to subsequent tests. + + # Issue 1: W4A8_NVFP4_FP8 with top_k=1 causes CUDA illegal memory access + # This triggers GPU state corruption that affects all subsequent tests. + # Affected config: e8_k1_h512_i512 + if quant_algo == QuantAlgo.W4A8_NVFP4_FP8 and top_k == 1: + return ( + "[Potential Bug] TRTLLMGenFusedMoE W4A8_NVFP4_FP8 with top_k=1 " + "causes CUDA illegal memory access. Needs kernel investigation." + ) + + # Issue 2: NVFP4 with large intermediate_size has known accuracy issues + # Observed mismatch: 18%~25% vs expected <7.5% (per test_moe.py baseline) + # Affected configs: e8_k2_h4096_i14336, e8_k2_h6144_i32768 + if quant_algo == QuantAlgo.NVFP4 and intermediate_size >= 14336: + return ( + f"[Potential Bug] TRTLLMGenFusedMoE NVFP4 with large intermediate_size " + f"has known accuracy issues (intermediate_size={intermediate_size} >= 14336). " + f"Observed mismatch 18%~25% exceeds expected threshold." + ) + + # Issue 3: W4A8_MXFP4_MXFP8 has accuracy issues on certain model configs + # Observed mismatch: 14%~18% vs expected <15% (percent=0.85) + # Affected configs: large intermediate_size or many experts + # e8_k2_h4096_i14336, e64_k6_h2048_i1408, e60_k4_h2048_i1408, + # e256_k8_h7168_i2048, e8_k2_h6144_i32768, e128_k4_h2880_i2880 + if quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8: + # Large intermediate_size (>= 14336) has precision issues + if intermediate_size >= 14336: + return ( + f"[Potential Bug] TRTLLMGenFusedMoE W4A8_MXFP4_MXFP8 with large " + f"intermediate_size has accuracy issues (intermediate_size={intermediate_size} >= 14336). " + f"Observed mismatch 14%~18% exceeds 15% threshold." 
+ ) + # Many experts (>= 60) with moderate intermediate_size has precision issues + if num_experts >= 60 and intermediate_size >= 1408: + return ( + f"[Potential Bug] TRTLLMGenFusedMoE W4A8_MXFP4_MXFP8 with many experts " + f"has accuracy issues (num_experts={num_experts} >= 60, intermediate_size={intermediate_size}). " + f"Observed mismatch 14%~18% exceeds 15% threshold." + ) + + return None + + +def should_skip_CUTEDSL( + backend_type: MoeBackendType, + quant_algo: Optional[QuantAlgo], + model_config: "MoeModelConfig" = None, +) -> Optional[str]: + """ + Check CuteDSL backend specific constraints. + + The CuteDSL MoE kernels have known accuracy issues with certain configurations. + + Args: + backend_type: The MoE backend type + quant_algo: The quantization algorithm + model_config: The MoE model configuration + + Returns: + Skip reason string if test should be skipped, None otherwise + """ + if backend_type != MoeBackendType.CUTEDSL: + return None + + if model_config is None: + return None + + intermediate_size = model_config.intermediate_size + + # -----------------Potential issues------------------ + # NVFP4 with large intermediate_size has known accuracy issues (same as TRTLLM) + # Observed mismatch: 8%~26% vs expected <2% + # Affected configs: e8_k2_h4096_i14336, e8_k2_h6144_i32768 + if quant_algo == QuantAlgo.NVFP4 and intermediate_size >= 14336: + return ( + f"[Potential Bug] CuteDslFusedMoE NVFP4 with large intermediate_size " + f"has known accuracy issues (intermediate_size={intermediate_size} >= 14336). " + f"Observed mismatch 8%~26% exceeds 2% threshold." + ) + + # NVFP4 with prime num_experts (7, 13) causes CUDA_ERROR_ILLEGAL_ADDRESS + # Root cause: Autotuner cache bucket mapping issue + # - When tests run in batch, previous tests cache tactics to buckets + # - Prime num_experts shapes map to same bucket as other configs + # - The cached tactic (e.g., ((128, 256), (1, 2), False)) works for other configs + # but causes illegal memory access for prime num_experts' actual shape + # - Single test run passes because fallback tactic ((128, 128), (1, 1), False) is used + # Affected configs: e7_k2_h256_i512, e13_k3_h256_i512 + num_experts = model_config.num_experts + prime_experts_with_issues = {7, 13} + if quant_algo == QuantAlgo.NVFP4 and num_experts in prime_experts_with_issues: + return ( + f"[Potential Bug] CuteDslFusedMoE NVFP4 with prime num_experts={num_experts} " + f"causes CUDA_ERROR_ILLEGAL_ADDRESS due to autotuner cache bucket mapping. " + f"Cached tactic from other configs is incompatible with this shape." + ) + + return None + + def should_skip_gptoss( backend_type: MoeBackendType, quant_algo: Optional[QuantAlgo], - swiglu_gptoss_style: bool, + gptoss_style: bool, ) -> Optional[str]: """ - Check if swiglu_gptoss_style test should be skipped for this backend. + Check if gptoss_style test should be skipped for this backend. - Only CUTLASS and TRTLLM backends support swiglu_gptoss_style (SwiGlu with custom + Only CUTLASS and TRTLLM backends support gptoss_style (SwiGlu with custom alpha/beta/limit parameters and bias). 
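# --- Illustrative sketch (not part of the patch) ---------------------------
# How the hard TRTLLM Gen constraints above surface through should_skip_TRTLLM.
# Assumes the module-level names (MoeBackendType, MoeModelConfig, QuantAlgo,
# should_skip_TRTLLM) defined in this test file; the config values are made up
# for the example.
def _demo_trtllm_skip_reasons():
    # num_experts=6 violates the "divisible by 4" routing-kernel requirement.
    bad = MoeModelConfig(num_experts=6, top_k=2, hidden_size=512, intermediate_size=512)
    assert should_skip_TRTLLM(MoeBackendType.TRTLLM, QuantAlgo.NVFP4, bad) is not None

    # 8 experts with top_k=2 satisfies both hard constraints, so no skip reason.
    ok = MoeModelConfig(num_experts=8, top_k=2, hidden_size=512, intermediate_size=512)
    assert should_skip_TRTLLM(MoeBackendType.TRTLLM, QuantAlgo.NVFP4, ok) is None
# ---------------------------------------------------------------------------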
Args: backend_type: The MoE backend type quant_algo: The quantization algorithm - swiglu_gptoss_style: Whether swiglu_gptoss_style is enabled + gptoss_style: Whether gptoss_style is enabled Returns: Skip reason string if test should be skipped, None otherwise """ - if not swiglu_gptoss_style: + if not gptoss_style: return None - # Only CUTLASS and TRTLLM backends support swiglu_gptoss_style + # Only CUTLASS and TRTLLM backends support gptoss_style supported_backends = {MoeBackendType.CUTLASS, MoeBackendType.TRTLLM} if backend_type not in supported_backends: return ( - f"swiglu_gptoss_style is only supported by CUTLASS and TRTLLM backends " + f"gptoss_style is only supported by CUTLASS and TRTLLM backends " f"(got backend_type={backend_type.value})" ) return None +def supports_autotuner_capture( + backend_type: MoeBackendType, + quant_algo: Optional[QuantAlgo], +) -> bool: + """ + Determine if a backend+quant_algo combination supports AutoTuner capture/replay. + + AutoTuner capture/replay requires AutoTuner.choose_one() to be called during + run_moe execution. + + Args: + backend_type: The MoE backend type + quant_algo: The quantization algorithm (None for unquantized) + + Returns: + True if autotuner capture/replay is supported, False otherwise + """ + # DEEPGEMM does not support autotuner capture + # Evidence: fused_moe_deepgemm.py has no AutoTuner/choose_one references + if backend_type == MoeBackendType.DEEPGEMM: + return False + + return True + + def create_test_backend( backend_type: MoeBackendType, routing_method: RenormalizeMoeRoutingMethod, @@ -190,6 +397,60 @@ def run_backend_moe( return backend.run_moe(**args) +def replay_tactics_and_check( + all_tactics, + run_moe_fn: Callable[[], torch.Tensor], + check_accuracy_fn: Callable[[torch.Tensor, torch.Tensor], None], + ref_output: torch.Tensor, + backend_type: MoeBackendType, + quant_algo: Optional[QuantAlgo], + fail_fast: bool = False, +) -> None: + """ + Replay all tactics and check accuracy. + + Args: + all_tactics: TacticsCapture object from AutoTuner.capture() + run_moe_fn: Function to run MoE computation + check_accuracy_fn: Function to check accuracy (output, ref_output) -> None + ref_output: Reference output tensor + backend_type: Backend type for error reporting + quant_algo: Quantization algorithm for error reporting + fail_fast: If True, fail on first error. If False, run all and report summary. 
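# --- Illustrative sketch (not part of the patch) ---------------------------
# The capture/replay flow that the replay_tactics_and_check helper defined
# here builds on, assuming a `run_moe_fn` closure plus the AutoTuner/torch
# imports already used in this file.
def _demo_capture_replay(run_moe_fn, check_accuracy_fn, ref_output):
    # Capture which tactics the autotuner selects for this input shape.
    with AutoTuner.get().capture() as all_tactics, torch.inference_mode():
        run_moe_fn()
    # Re-run the kernel under every captured tactic and verify accuracy.
    for tactic in list(all_tactics):
        with AutoTuner.get().replay(tactic), torch.inference_mode():
            check_accuracy_fn(run_moe_fn(), ref_output)
# ---------------------------------------------------------------------------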
+ """ + tactics_list = list(all_tactics) + passed_tactics = [] + failed_tactics = [] + logger.info(f"Replay tactics : {len(tactics_list)} and check accuracy") + for idx, tactic in enumerate(tactics_list): + with AutoTuner.get().replay(tactic), torch.inference_mode(): + output = run_moe_fn() + try: + check_accuracy_fn(output, ref_output) + passed_tactics.append((idx, tactic)) + except Exception as e: + if fail_fast: + pytest.fail( + f"Accuracy check failed for tactic[{idx}/{len(tactics_list)}]={tactic}, " + f"backend={backend_type}, quant_algo={quant_algo}: {e}" + ) + failed_tactics.append((idx, tactic, str(e))) + + # Report results (only when fail_fast=False) + total = len(tactics_list) + num_passed = len(passed_tactics) + num_failed = len(failed_tactics) + if failed_tactics: + fail_details = "\n".join( + f" tactic[{idx}]={tactic}: {err}" for idx, tactic, err in failed_tactics + ) + pytest.fail( + f"backend={backend_type}, quant_algo={quant_algo}: " + f"{num_passed}/{total} passed, {num_failed}/{total} failed\n" + f"Failed tactics:\n{fail_details}" + ) + + # ============================================================================ # Test Parameters # ============================================================================ @@ -221,6 +482,23 @@ DTYPES_TO_TEST = [ torch.bfloat16, ] + +# ============================================================================ +# Model MoE Configurations +# ============================================================================ +@dataclass +class MoeModelConfig: + """MoE model configuration: (num_experts, top_k, hidden_size, intermediate_size).""" + + num_experts: int + top_k: int + hidden_size: int + intermediate_size: int + + def __str__(self) -> str: + return f"e{self.num_experts}_k{self.top_k}_h{self.hidden_size}_i{self.intermediate_size}" + + # Format: (num_experts, top_k, hidden_size, intermediate_size) MOE_MODEL_CONFIGS = [ # === Real Model Configs === @@ -243,10 +521,89 @@ MOE_MODEL_CONFIGS = [ # Sequence lengths to test SEQ_LENS_TO_TEST = [1, 8] -# SwiGLU parameters for swiglu_gptoss_style testing -SWIGLU_ALPHAS = [1, 1.702] # default, GPT-OSS (modeling_gpt_oss.py) -SWIGLU_BETAS = [0, 1.0] # default, GPT-OSS -SWIGLU_LIMITS = [float("inf"), 7.0] # default, GPT-OSS +# SwiGLU parameters for gptoss_style testing +SWIGLU_ALPHAS = [1, 0.1] +SWIGLU_BETAS = [0, 1] +SWIGLU_LIMITS = [float("inf"), 1] + + +# ============================================================================ +# Fast Skip Check (for parametrize-level skip, avoids entering test function) +# ============================================================================ +def get_quick_skip_reason( + backend_type: MoeBackendType, + quant_algo: Optional[QuantAlgo], + dtype: torch.dtype, + model_config: "MoeModelConfig", + gptoss_style: bool, +) -> Optional[str]: + """ + Fast skip check that calls backend's can_implement() method. + + This function calls the backend's can_implement() classmethod to check + dtype/quant_algo/gptoss_style support, then uses should_skip_* functions + for additional model_config specific checks. + + Note: Logging is temporarily suppressed to avoid excessive warning output + during test parameter generation. 
+ + Returns: + Skip reason string if test should be skipped, None otherwise + """ + import logging as _logging + + # Suppress logger warnings during parameter generation to avoid excessive output + trtllm_logger = _logging.getLogger("tensorrt_llm") + original_level = trtllm_logger.level + trtllm_logger.setLevel(_logging.ERROR) + + try: + # ===== Call backend's can_implement for dtype/quant_algo/gptoss_style checks ===== + backend_cls = get_backend_class(backend_type) + can_impl, skip_reason = backend_cls.can_implement( + quant_algo, dtype_activation=dtype, gptoss_style=gptoss_style + ) + if not can_impl: + return skip_reason + + # ===== Additional model_config specific checks ===== + + # TRTLLM: num_experts constraints and accuracy issues + skip_reason = should_skip_TRTLLM(backend_type, quant_algo, model_config) + if skip_reason: + return skip_reason + + # CUTEDSL: accuracy issues with specific configs + skip_reason = should_skip_CUTEDSL(backend_type, quant_algo, model_config) + if skip_reason: + return skip_reason + + # DEEPGEMM: float16 reference module constraint + if backend_type == MoeBackendType.DEEPGEMM and dtype == torch.float16: + return "DeepGemmFusedMoE reference module (FP8BlockScalesLinearMethod) requires bfloat16 input" + + # 128-alignment requirement for quantization + if quant_algo is not None: + hidden_size = model_config.hidden_size + intermediate_size = model_config.intermediate_size + is_hidden_128_aligned = hidden_size % 128 == 0 + is_intermediate_128_aligned = intermediate_size % 128 == 0 + + if not is_hidden_128_aligned or not is_intermediate_128_aligned: + # TRTLLM with MXFP4 variants automatically pads to 128 alignment + is_mxfp4_variant = quant_algo in {QuantAlgo.W4A16_MXFP4, QuantAlgo.W4A8_MXFP4_MXFP8} + is_trtllm_backend = backend_type == MoeBackendType.TRTLLM + if not (is_trtllm_backend and is_mxfp4_variant): + return ( + f"Non-128-aligned sizes (h={hidden_size}, i={intermediate_size}) " + f"require TRTLLM backend with MXFP4 quantization" + ) + + return None + + finally: + # Restore logger level + trtllm_logger.setLevel(original_level) def generate_test_params() -> List: @@ -260,41 +617,57 @@ def generate_test_params() -> List: Returns: List of pytest.param objects with appropriate skip marks """ + params: List = [] + + # Generate all combinations swiglu_combos = list(itertools.product(SWIGLU_ALPHAS, SWIGLU_BETAS, SWIGLU_LIMITS)) - params: List = [] - for ( - swiglu_alpha, - swiglu_beta, - swiglu_limit, - model_config, - seq_len, - dtype, - backend_type, - quant_algo, - routing_method_cls, - skip_reason, - test_id, - ) in iter_base_test_configs( - swiglu_combos, - MOE_MODEL_CONFIGS, - SEQ_LENS_TO_TEST, - DTYPES_TO_TEST, - BACKEND_TYPES_TO_TEST, - QUANT_ALGOS_TO_TEST, - ): - param_values = ( - dtype, - backend_type, - quant_algo, - seq_len, - model_config, - routing_method_cls, - swiglu_alpha, - swiglu_beta, - swiglu_limit, - ) - params.append(create_test_param(param_values, test_id, skip_reason)) + for swiglu_alpha, swiglu_beta, swiglu_limit in swiglu_combos: + for model_config in MOE_MODEL_CONFIGS: + for seq_len in SEQ_LENS_TO_TEST: + for dtype in DTYPES_TO_TEST: + for backend_type in BACKEND_TYPES_TO_TEST: + for quant_algo in QUANT_ALGOS_TO_TEST: + # Determine gptoss_style + gptoss_style = ( + swiglu_alpha != 1 + or swiglu_beta != 0 + or swiglu_limit != float("inf") + ) + + # Generate test ID + test_id = ( + f"alpha={swiglu_alpha}_beta={swiglu_beta}_limit={swiglu_limit}-" + f"{model_config}-seq={seq_len}-dtype={dtype}-" + 
f"backend={backend_type.value}-quant_algo={quant_algo}" + ) + + # Check if should skip + skip_reason = get_quick_skip_reason( + backend_type, quant_algo, dtype, model_config, gptoss_style + ) + + param_values = ( + dtype, + backend_type, + quant_algo, + seq_len, + model_config, + swiglu_alpha, + swiglu_beta, + swiglu_limit, + ) + + if skip_reason: + params.append( + pytest.param( + *param_values, + id=test_id, + marks=pytest.mark.skip(reason=skip_reason), + ) + ) + else: + params.append(pytest.param(*param_values, id=test_id)) return params @@ -303,6 +676,23 @@ def generate_test_params() -> List: TEST_PARAMS = generate_test_params() +# ============================================================================ +# Timing Fixtures +# ============================================================================ +@pytest.fixture(scope="module", autouse=True) +def module_timer(request): + """Fixture to measure and log total module execution time.""" + start = time.perf_counter() + yield + elapsed = time.perf_counter() - start + logger.info( + "[TIMING] Total %s: %.3fs (%.2f min)", + request.module.__name__, + elapsed, + elapsed / 60, + ) + + # ============================================================================ # Test Implementation # ============================================================================ @@ -350,16 +740,15 @@ TEST_PARAMS = generate_test_params() # Skip Logic # ============================================================================= # Tests are automatically skipped for unsupported configurations using: -# - backend.can_implement(): Check dtype/quant_algo/swiglu_gptoss_style support -# - should_skip_trtllm(): TRTLLM-specific constraints (num_experts % 4, etc.) -# - should_skip_cutedsl(): CuteDSL-specific accuracy issues +# - backend.can_implement(): Check dtype/quant_algo/gptoss_style support +# - should_skip_TRTLLM(): TRTLLM-specific constraints (num_experts % 4, etc.) +# - should_skip_CUTEDSL(): CuteDSL-specific accuracy issues # - 128-alignment requirements for quantization # # ============================================================================= @pytest.mark.skip(reason="Temporarily skipped due to the long time to run the test") @pytest.mark.parametrize( - "dtype_activation,backend_type,quant_algo,seq_len,model_config," - "routing_method_cls,swiglu_alpha,swiglu_beta,swiglu_limit", + "dtype_activation,backend_type,quant_algo,seq_len,model_config,swiglu_alpha,swiglu_beta,swiglu_limit", TEST_PARAMS, ) def test_moe_backend( @@ -368,7 +757,6 @@ def test_moe_backend( quant_algo: Optional[QuantAlgo], seq_len: int, model_config: MoeModelConfig, - routing_method_cls, swiglu_alpha: float, swiglu_beta: float, swiglu_limit: float, @@ -380,12 +768,12 @@ def test_moe_backend( 1. Autotune works correctly with the backend 2. All tactics are captured properly 3. Different sequence lengths use appropriate tactics - 4. swiglu_gptoss_style (SwiGlu with custom parameters) works correctly + 4. 
gptoss_style (SwiGlu with custom parameters) works correctly """ - # Determine swiglu_gptoss_style based on swiglu parameters - # swiglu_gptoss_style is True when any swiglu parameter deviates from default + # Determine gptoss_style based on swiglu parameters + # gptoss_style is True when any swiglu parameter deviates from default # Default values: alpha=1, beta=0, limit=inf - swiglu_gptoss_style = swiglu_alpha != 1 or swiglu_beta != 0 or swiglu_limit != float("inf") + gptoss_style = swiglu_alpha != 1 or swiglu_beta != 0 or swiglu_limit != float("inf") # Note: Skip logic is now handled at parametrize level via get_quick_skip_reason() # which calls backend's can_implement() and should_skip_* functions. @@ -409,8 +797,8 @@ def test_moe_backend( # Setup autotuner distributed state AutoTuner.get().setup_distributed_state(mapping) - # Create routing method from parametrized class - routing_method = routing_method_cls(top_k=top_k) + # Create routing method + routing_method = RenormalizeMoeRoutingMethod(top_k=top_k) # Create test inputs x = torch.randn((seq_len, hidden_size), dtype=dtype_activation, device="cuda") @@ -422,21 +810,21 @@ def test_moe_backend( quant_algo, x, backend_type ) - # Create quantize utility with swiglu_gptoss_style parameters + # Create quantize utility with gptoss_style parameters quantize_util = quantize_util_cls( num_experts=num_experts, dtype=dtype_activation, intermediate_size=intermediate_size, hidden_size=hidden_size, quant_config=quant_config, - bias=swiglu_gptoss_style, - swiglu_gptoss_style=swiglu_gptoss_style, - swiglu_alpha=swiglu_alpha if swiglu_gptoss_style else None, - swiglu_beta=swiglu_beta if swiglu_gptoss_style else None, - swiglu_limit=swiglu_limit if swiglu_gptoss_style else None, + bias=gptoss_style, + gptoss_style=gptoss_style, + swiglu_alpha=swiglu_alpha if gptoss_style else None, + swiglu_beta=swiglu_beta if gptoss_style else None, + swiglu_limit=swiglu_limit if gptoss_style else None, ) - # Get swiglu tensors if swiglu_gptoss_style is enabled + # Get swiglu tensors if gptoss_style is enabled swiglu_tensors = quantize_util.get_swiglu_tensors() # Create backend first (needed for MXFP4_MXFP8 to get shapes) @@ -449,7 +837,7 @@ def test_moe_backend( dtype=dtype_activation, quant_config=quant_config, mapping=mapping, - bias=swiglu_gptoss_style, + bias=gptoss_style, swiglu_alpha=swiglu_tensors["swiglu_alpha"] if swiglu_tensors else None, swiglu_beta=swiglu_tensors["swiglu_beta"] if swiglu_tensors else None, swiglu_limit=swiglu_tensors["swiglu_limit"] if swiglu_tensors else None, diff --git a/tests/unittest/_torch/modules/moe/test_moe_module.py b/tests/unittest/_torch/modules/moe/test_moe_module.py index bc4060b5a6..becc8df849 100644 --- a/tests/unittest/_torch/modules/moe/test_moe_module.py +++ b/tests/unittest/_torch/modules/moe/test_moe_module.py @@ -12,89 +12,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -MoE Module Unit Tests - -This module provides a unified test framework for testing MoE modules through the -high-level create_moe() + forward() interface, rather than the backend-level interfaces. - -Design Goals: -1. Test MoE module via: create_moe -> load_weights -> forward -2. Cover key quantization + backend combinations -3. Support EPLB (Expert Load Balancing) testing -4. 
Support autotune and tactic capture testing -""" - import copy -import logging import os import pickle import sys from contextlib import nullcontext -from itertools import product -from typing import List, Optional import cloudpickle import pytest import torch -from _torch.modules.moe.moe_test_utils import ( - MoeBackendType, - MoeModelConfig, - create_test_param, - get_quick_skip_reason, - iter_base_test_configs, - module_timer, # noqa: F401 - imported for pytest fixture registration - replay_tactics_and_check, - should_skip_cutedsl, - should_skip_deepgemm, - should_skip_multi_gpu, - should_skip_trtllm, - supports_autotuner_capture, -) from _torch.modules.moe.quantize_utils import get_test_quant_params from mpi4py import MPI from mpi4py.futures import MPIPoolExecutor from transformers.configuration_utils import PretrainedConfig +from utils.util import getSMVersion import tensorrt_llm.bindings.internal.runtime as _tbr -from tensorrt_llm._torch.autotuner import AutoTuner, autotune from tensorrt_llm._torch.model_config import ModelConfig -from tensorrt_llm._torch.modules.fused_moe import ( - DeepSeekV3MoeRoutingMethod, - DefaultMoeRoutingMethod, - Llama4RenormalizeMoeRoutingMethod, - MiniMaxM2MoeRoutingMethod, - RenormalizeMoeRoutingMethod, - RenormalizeNaiveMoeRoutingMethod, - create_moe, -) +from tensorrt_llm._torch.modules.fused_moe import RenormalizeMoeRoutingMethod, create_moe from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import ( MoeLoadBalancer, MoeLoadBalancerIterContext, ) -from tensorrt_llm._torch.modules.fused_moe.quantization import ( - DeepSeekFP8BlockScalesFusedMoEMethod, - FP8QDQFusedMoEMethod, - INT8WoqPerChannelFusedMoEMethod, - NVFP4CutlassFusedMoEMethod, - NVFP4TRTLLMGenFusedMoEMethod, - UnquantizedFusedMoEMethod, - W4A8MXFP4FP8CutlassFusedMoEMethod, - W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, - W4A8MXFP4MXFP8CutlassFusedMoEMethod, - W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, - W4A8NVFP4FP8TRTLLMGenFusedMoEMethod, - W4A16MXFP4TRTLLMGenFusedMoEMethod, - WFP4A16FusedMoEMethod, - WInt4AFP8FusedMoEMethod, -) from tensorrt_llm._utils import mpi_rank from tensorrt_llm.llmapi.llm_args import MoeLoadBalancerConfig from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantAlgo -logger = logging.getLogger(__name__) - cloudpickle.register_pickle_by_value(sys.modules[__name__]) MPI.pickle.__init__( cloudpickle.dumps, @@ -103,250 +47,9 @@ MPI.pickle.__init__( ) -def _create_mapping_for_parallel_mode(world_size, parallel_mode): - """Create Mapping for different parallelism strategies. - - Args: - world_size: Total number of GPUs - parallel_mode: One of "DEP", "TEP", "DTP", "TTP" - - DEP: Attention uses DP, MoE uses EP - - TEP: Attention uses TP, MoE uses EP - - DTP: Attention uses DP, MoE uses TP - - TTP: Attention uses TP, MoE uses TP - - Returns: - Mapping object configured for the specified parallel mode - """ - configs = { - "DEP": { # Attention DP, MoE EP - "moe_ep_size": world_size, - "moe_tp_size": 1, - "enable_attention_dp": True, - }, - "TEP": { # Attention TP, MoE EP - "moe_ep_size": world_size, - "moe_tp_size": 1, - "enable_attention_dp": False, - }, - "DTP": { # Attention DP, MoE TP - "moe_ep_size": 1, - "moe_tp_size": world_size, - "enable_attention_dp": True, - }, - "TTP": { # Attention TP, MoE TP - "moe_ep_size": 1, - "moe_tp_size": world_size, - "enable_attention_dp": False, - }, - } - if parallel_mode not in configs: - raise ValueError( - f"Unknown parallel_mode: {parallel_mode}. 
Must be one of {list(configs.keys())}" - ) - - cfg = configs[parallel_mode] - return Mapping( - world_size=world_size, - tp_size=world_size, - moe_ep_size=cfg["moe_ep_size"], - moe_tp_size=cfg["moe_tp_size"], - enable_attention_dp=cfg["enable_attention_dp"], - ) - - -def _create_moe_load_balancer(model_cfg, enable_eplb): - """Create MoeLoadBalancer if EPLB is enabled, otherwise return nullcontext.""" - if not enable_eplb: - return nullcontext() - - ep_rank = model_cfg.mapping.moe_ep_rank - ep_size = model_cfg.mapping.moe_ep_size - model_cfg.moe_load_balancer.setup(ep_rank=ep_rank, ep_size=ep_size) - return MoeLoadBalancer( - ep_rank=ep_rank, - ep_size=ep_size, - layer_updates_per_iter=model_cfg.moe_load_balancer.layer_updates_per_iter, - ) - - -def _setup_autotuner_for_test(mapping): - """Configure AutoTuner for faster unit test profiling.""" - AutoTuner.get().setup_distributed_state(mapping) - AutoTuner.get().clear_cache() - autotuner = AutoTuner.get() - autotuner.warmup = 0 # default: 2 - autotuner.repeat = 1 # default: 10 - autotuner.stream_delay_micro_secs = 10 # default: 1000 - - -def _create_model_config( - num_experts, - hidden_size, - intermediate_size, - dtype, - mapping, - quant_config, - moe_backend, - enable_eplb=False, - num_slots=-1, - layer_updates_per_iter=-1, -): - """Create PretrainedConfig and ModelConfig for MoE testing.""" - pretrained_config = PretrainedConfig() - pretrained_config.num_experts = num_experts - pretrained_config.hidden_size = hidden_size - pretrained_config.intermediate_size = intermediate_size - pretrained_config.torch_dtype = dtype - - moe_load_balancer_config = ( - MoeLoadBalancerConfig( - num_slots=num_slots, - layer_updates_per_iter=layer_updates_per_iter, - ) - if enable_eplb - else None - ) - - return ModelConfig( - pretrained_config=pretrained_config, - mapping=mapping, - quant_config=quant_config, - moe_backend=moe_backend, - moe_disable_finalize_fusion=False, - moe_load_balancer=moe_load_balancer_config, - ) - - -def _run_autotune_test( - run_forward_fn, ref_fused_moe, ref_output, backend_type, quant_algo, run_all_tactics=False -): - """Run autotune phase and tactic replay test. - - Args: - run_forward_fn: Forward function to run - ref_fused_moe: Reference MoE module for accuracy check - ref_output: Reference output for comparison - backend_type: MoE backend type - quant_algo: Quantization algorithm - run_all_tactics: If False, skip full tactic replay and only run simple accuracy check - """ - # Autotune phase - with torch.inference_mode(), autotune(cache_path="/tmp/moe_module_autotuner_cache.json"): - _ = run_forward_fn() - - # Check if we should run full tactic replay - if not run_all_tactics or not supports_autotuner_capture(backend_type, quant_algo): - # Simple accuracy check for unsupported backends or when run_all_tactics is False - with torch.inference_mode(): - output = run_forward_fn() - ref_fused_moe.check_accuracy(output, ref_output) - return - - # Capture phase: record which tactics are used - with AutoTuner.get().capture() as all_tactics, torch.inference_mode(): - _ = run_forward_fn() - - # Replay phase: test each tactic for correctness - replay_tactics_and_check( - all_tactics=all_tactics, - run_moe_fn=run_forward_fn, - check_accuracy_fn=ref_fused_moe.check_accuracy, - ref_output=ref_output, - backend_type=backend_type, - quant_algo=quant_algo, - fail_fast=False, - ) - - -def _run_eplb_test( - run_forward_fn, ref_fused_moe, ref_output, moe_load_balancer, initial_expert_ids -): - """Run EPLB multi-iteration test. 
- - Args: - run_forward_fn: Forward function to run - ref_fused_moe: Reference MoE module for accuracy check - ref_output: Reference output for comparison - moe_load_balancer: MoeLoadBalancer instance - initial_expert_ids: Expert IDs recorded immediately after MoE initialization (before any forward) - """ - assert isinstance(moe_load_balancer, MoeLoadBalancer), ( - "Moe load balancer should be created when eplb is enabled" - ) - assert initial_expert_ids is not None, ( - "initial_expert_ids should be recorded before any forward pass" - ) - - extra_steps = 1 - for _ in range(extra_steps): - output = run_forward_fn() - ref_fused_moe.check_accuracy(output, ref_output) - - current_expert_ids = copy.deepcopy( - moe_load_balancer.single_layer_load_balancers[0].get_old_rank_expert_ids() - ) - - # EPLB should have updated expert_ids from initial state - assert initial_expert_ids != current_expert_ids, ( - f"Expert ids after eplb update should be different from the initial loaded ones. " - f"Initial: {initial_expert_ids}, Current: {current_expert_ids}" - ) - - -def _create_routing_method(routing_method_cls, top_k, num_experts, dtype): - """ - Create a routing method instance with appropriate parameters for each routing method type. - - Args: - routing_method_cls: The routing method class to instantiate - top_k: Number of experts to select per token - num_experts: Total number of experts - dtype: Data type for tensors - - Returns: - An instance of the routing method - """ - # Routing methods with force_enable_pytorch_op support - if routing_method_cls in (RenormalizeMoeRoutingMethod, DefaultMoeRoutingMethod): - return routing_method_cls(top_k=top_k, force_enable_pytorch_op=True) - - # Simple routing methods (only top_k) - if routing_method_cls in (RenormalizeNaiveMoeRoutingMethod, Llama4RenormalizeMoeRoutingMethod): - return routing_method_cls(top_k=top_k) - - # DeepSeekV3 routing method requires special parameters - if routing_method_cls == DeepSeekV3MoeRoutingMethod: - # DeepSeek-V3 routing: groups experts, selects top groups, then selects top_k from those - # The routing logic does topk(k=2) within each group, so each group must have >= 2 experts - # Calculate n_group such that each group has at least 2 experts - experts_per_group = 2 - n_group = max(1, num_experts // experts_per_group) - # topk_group should be <= n_group and reasonable for the selection - topk_group = min(n_group, max(1, n_group // 2)) - routed_scaling_factor = 1.0 - # Create e_score_correction_bias as a zero tensor (no bias correction in test) - e_score_correction_bias = torch.zeros(num_experts, dtype=dtype, device="cuda") - return routing_method_cls( - top_k=top_k, - n_group=n_group, - topk_group=topk_group, - routed_scaling_factor=routed_scaling_factor, - callable_e_score_correction_bias=lambda: e_score_correction_bias, - is_fused=False, # Use PyTorch implementation for testing - ) - - # MiniMaxM2 routing method requires special parameters - if routing_method_cls == MiniMaxM2MoeRoutingMethod: - # Create e_score_correction_bias as a zero tensor (no bias correction in test) - e_score_correction_bias = torch.zeros(num_experts, dtype=dtype, device="cuda") - return routing_method_cls( - top_k=top_k, - num_experts=num_experts, - callable_e_score_correction_bias=lambda: e_score_correction_bias, - ) - - # Fallback: try with just top_k - return routing_method_cls(top_k=top_k) +def _skip_helper(quant_algo): + if quant_algo == QuantAlgo.NVFP4 and getSMVersion() < 100: + pytest.skip("This test is not supported in pre-Blackwell 
architecture") def _test_moe_worker( @@ -357,202 +60,119 @@ def _test_moe_worker( enable_eplb=False, layer_updates_per_iter=-1, num_slots=-1, - model_config: Optional[MoeModelConfig] = None, - seq_len: int = 4, - enable_autotune: bool = False, - routing_method_cls=RenormalizeMoeRoutingMethod, - dtype_routing_logits=None, - swiglu_alpha: float = 1, - swiglu_beta: float = 0, - swiglu_limit: float = float("inf"), ): - """ - Test MoE module worker function. + # Hardcode some parameters for testing + # activation and weight related + seq_len = 4 + top_k = 2 + num_experts = 8 + hidden_size = 512 + intermediate_size = 512 - This test verifies: - 1. MoE module forward pass produces correct results - 2. EPLB (Expert Load Balancing) works correctly when enabled - 3. Autotune works correctly with the module when enabled - 4. All tactics are captured and replayed properly when autotune is enabled + # Other parameters + finalize_fusion = True - Args: - routing_method_cls: Routing method class to use (default: RenormalizeMoeRoutingMethod) - dtype_routing_logits: Data type for routing logits (default: same as dtype). - DeepSeekV3 routing requires torch.float32. - swiglu_alpha: SwiGLU alpha parameter (default=1, non-gptoss) - swiglu_beta: SwiGLU beta parameter (default=0, non-gptoss) - swiglu_limit: SwiGLU limit parameter (default=inf, non-gptoss) - """ - import traceback - - try: - _test_moe_worker_impl( - moe_backend=moe_backend, - dtype=dtype, - quant_algo=quant_algo, - mapping=mapping, - enable_eplb=enable_eplb, - layer_updates_per_iter=layer_updates_per_iter, - num_slots=num_slots, - model_config=model_config, - seq_len=seq_len, - enable_autotune=enable_autotune, - routing_method_cls=routing_method_cls, - dtype_routing_logits=dtype_routing_logits, - swiglu_alpha=swiglu_alpha, - swiglu_beta=swiglu_beta, - swiglu_limit=swiglu_limit, - ) - except Exception: - traceback.print_exc() - raise - - -def _test_moe_worker_impl( - moe_backend, - dtype, - quant_algo, - mapping=None, - enable_eplb=False, - layer_updates_per_iter=-1, - num_slots=-1, - model_config: Optional[MoeModelConfig] = None, - seq_len: int = 4, - enable_autotune: bool = False, - routing_method_cls=RenormalizeMoeRoutingMethod, - dtype_routing_logits=None, - swiglu_alpha: float = 1, - swiglu_beta: float = 0, - swiglu_limit: float = float("inf"), -): - """Actual implementation of _test_moe_worker.""" - # Default routing logits dtype to model dtype if not specified - if dtype_routing_logits is None: - dtype_routing_logits = dtype - # Parse model config - if model_config is not None: - num_experts = model_config.num_experts - top_k = model_config.top_k - hidden_size = model_config.hidden_size - intermediate_size = model_config.intermediate_size - else: - num_experts, top_k, hidden_size, intermediate_size = 8, 2, 512, 512 - - # Setup mapping mapping = mapping or Mapping() mapping.rank = mpi_rank() + all_rank_num_tokens = [seq_len] * mapping.world_size + torch.cuda.set_device(mapping.rank) with torch.device(f"cuda:{mapping.rank}"): torch.manual_seed(0) torch.cuda.manual_seed(0) - # Create routing method and input tensors - routing_method = _create_routing_method( - routing_method_cls, top_k=top_k, num_experts=num_experts, dtype=dtype - ) + # Create route method + routing_method = RenormalizeMoeRoutingMethod(top_k=top_k, force_enable_pytorch_op=True) + + # Create activation and weight x = torch.randn((seq_len, hidden_size), dtype=dtype, device="cuda") if enable_eplb: - # Same router_logits for all tokens to force the eplb update weights - 
router_logits = torch.randn( - (1, num_experts), dtype=dtype_routing_logits, device="cuda" - ).repeat(seq_len, 1) - else: - router_logits = torch.randn( - (seq_len, num_experts), dtype=dtype_routing_logits, device="cuda" + # Here we create same router_logits for all tokens to force the eplb update weights + router_logits = torch.randn((1, num_experts), dtype=dtype, device="cuda").repeat( + seq_len, 1 ) + else: + router_logits = torch.randn((seq_len, num_experts), dtype=dtype, device="cuda") - # Determine swiglu_gptoss_style - swiglu_gptoss_style = swiglu_alpha != 1 or swiglu_beta != 0 or swiglu_limit != float("inf") - - # In EP mode, swiglu tensors must be sized per local experts - # (C++ kernels check: swiglu_alpha.size(0) == num_experts_on_rank) - num_local_experts = num_experts // mapping.moe_ep_size - - # Setup quantization - backend_type = MoeBackendType(moe_backend) - quantize_util_cls, quant_config, quant_kwargs = get_test_quant_params( - quant_algo, x, backend_type - ) + quantize_util_cls, quant_config, quant_kwargs = get_test_quant_params(quant_algo, x) quantize_util = quantize_util_cls( num_experts=num_experts, dtype=dtype, intermediate_size=intermediate_size, hidden_size=hidden_size, quant_config=quant_config, - bias=swiglu_gptoss_style, - swiglu_gptoss_style=swiglu_gptoss_style, - swiglu_alpha=swiglu_alpha if swiglu_gptoss_style else None, - swiglu_beta=swiglu_beta if swiglu_gptoss_style else None, - swiglu_limit=swiglu_limit if swiglu_gptoss_style else None, - num_local_experts=num_local_experts, ) weights = quantize_util.create_weights(**quant_kwargs) - # For EPLB, keep weights on CPU if enable_eplb: + # Keep the tensor on CPU for eplb for key in weights: if isinstance(weights[key], torch.Tensor): weights[key] = weights[key].to("cpu") + + # Deepcopy the CPU weight since when eplb turns on, fused moe may advise_tensor_pageout in post load weight. ref_weights = copy.deepcopy(weights) if enable_eplb else weights - # Create configs - model_cfg = _create_model_config( - num_experts=num_experts, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - dtype=dtype, + # Create pretrained config + pretrained_config = PretrainedConfig() + pretrained_config.num_experts = num_experts + pretrained_config.hidden_size = hidden_size + pretrained_config.intermediate_size = intermediate_size + pretrained_config.torch_dtype = dtype + + if enable_eplb: + moe_load_balancer_config = MoeLoadBalancerConfig( + num_slots=num_slots, + layer_updates_per_iter=layer_updates_per_iter, + ) + else: + moe_load_balancer_config = None + + model_config = ModelConfig( + pretrained_config=pretrained_config, mapping=mapping, quant_config=quant_config, moe_backend=moe_backend, - enable_eplb=enable_eplb, - num_slots=num_slots, - layer_updates_per_iter=layer_updates_per_iter, + moe_disable_finalize_fusion=not finalize_fusion, + moe_load_balancer=moe_load_balancer_config, ) - # Create MoE load balancer - moe_load_balancer = _create_moe_load_balancer(model_cfg, enable_eplb) - - # Get swiglu tensors if swiglu_gptoss_style is enabled - swiglu_tensors = quantize_util.get_swiglu_tensors() + moe_load_balancer = nullcontext() + if enable_eplb: + # A simple implementation of maybe_create_moe_load_balancer for unit test. 
+ ep_rank = model_config.mapping.moe_ep_rank + ep_size = model_config.mapping.moe_ep_size + model_config.moe_load_balancer.setup(ep_rank=ep_rank, ep_size=ep_size) + moe_load_balancer = MoeLoadBalancer( + ep_rank=ep_rank, + ep_size=ep_size, + layer_updates_per_iter=model_config.moe_load_balancer.layer_updates_per_iter, + ) with moe_load_balancer: - # Create and setup fused MoE module + # Create fused MoE module fused_moe = create_moe( - routing_method=routing_method, - reduce_results=True, - model_config=model_cfg, - bias=swiglu_gptoss_style, - swiglu_alpha=swiglu_tensors["swiglu_alpha"] if swiglu_tensors else None, - swiglu_beta=swiglu_tensors["swiglu_beta"] if swiglu_tensors else None, - swiglu_limit=swiglu_tensors["swiglu_limit"] if swiglu_tensors else None, + routing_method=routing_method, reduce_results=True, model_config=model_config ) + fused_moe.load_weights([weights]) fused_moe.post_load_weights() fused_moe.cuda(f"cuda:{mapping.rank}") - # Record initial expert_ids before any forward pass (for EPLB test) - initial_expert_ids = None if isinstance(moe_load_balancer, MoeLoadBalancer): moe_load_balancer.register_weight_slots_after_to_cuda() moe_load_balancer.finalize_model() - moe_load_balancer.set_iter_info(enable_statistic=True, enable_update_weights=True) - # Record initial expert_ids immediately after initialization - # Use deepcopy to avoid reference issues if the list is modified in-place - initial_expert_ids = copy.deepcopy( - moe_load_balancer.single_layer_load_balancers[0].get_old_rank_expert_ids() - ) - logger.info(f"[EPLB Debug] Initial expert_ids (after init): {initial_expert_ids}") - # Create reference module ref_fused_moe = quantize_util.create_ref_module(routing_method) ref_fused_moe.load_weights([ref_weights]) ref_fused_moe.cuda(f"cuda:{mapping.rank}") - # Define forward function - def run_forward(): + # Evaluate the outputs + def _run_forward(x, router_logits, skip_ref=False): with torch.inference_mode(): + ref_output = None if skip_ref else ref_fused_moe.forward(x, router_logits) if isinstance(moe_load_balancer, MoeLoadBalancer): with MoeLoadBalancerIterContext(moe_load_balancer): output = fused_moe.forward( @@ -563,24 +183,72 @@ def _test_moe_worker_impl( x, router_logits, all_rank_num_tokens=all_rank_num_tokens ) torch.cuda.synchronize() - return output + return ref_output, output - # Get reference output - with torch.inference_mode(): - ref_output = ref_fused_moe.forward(x, router_logits) + load_expert_ids = None + if isinstance(moe_load_balancer, MoeLoadBalancer): + moe_load_balancer.set_iter_info(enable_statistic=True, enable_update_weights=True) + load_expert_ids = moe_load_balancer.single_layer_load_balancers[ + 0 + ].get_old_rank_expert_ids() - # Run tests - if enable_autotune: - _setup_autotuner_for_test(mapping) - _run_autotune_test(run_forward, ref_fused_moe, ref_output, backend_type, quant_algo) - else: - output = run_forward() - ref_fused_moe.check_accuracy(output, ref_output) + ref_output, output = _run_forward(x, router_logits) + ref_fused_moe.check_accuracy(output, ref_output) if enable_eplb: - _run_eplb_test( - run_forward, ref_fused_moe, ref_output, moe_load_balancer, initial_expert_ids + # Multi iter run for eplb + assert isinstance(moe_load_balancer, MoeLoadBalancer), ( + "Moe load balancer should be created when eplb is enabled" ) + extra_steps = 3 + for _ in range(extra_steps): + _, output = _run_forward(x, router_logits, skip_ref=True) + ref_fused_moe.check_accuracy(output, ref_output) + assert moe_load_balancer.iter_id == extra_steps + 1, ( + 
"Iter id should be equal to extra steps + 1 after multiple iterations" + ) + + current_expert_ids = moe_load_balancer.single_layer_load_balancers[ + 0 + ].get_old_rank_expert_ids() + assert load_expert_ids != current_expert_ids, ( + "Expert ids after eplb update should be different from the initial loaded ones" + ) + + +@pytest.mark.parametrize( + "quant_algo", + [ + None, + QuantAlgo.FP8, + QuantAlgo.NVFP4, + ], + ids=lambda val: f"quant_algo={val}", +) +@pytest.mark.parametrize( + "moe_backend", + [ + "CUTLASS", + "TRTLLM", + ], + ids=lambda val: f"moe_backend={val}", +) +@pytest.mark.parametrize( + "dtype", + [ + torch.float16, + torch.bfloat16, + ], + ids=lambda val: f"dtype={val}", +) +def test_moe(dtype, moe_backend, quant_algo): + # Enable configurable moe by default + if moe_backend == "TRTLLM": + if dtype == torch.float16 and quant_algo == QuantAlgo.NVFP4: + pytest.skip("TRTLLM NVFP4 MoE backend does not support float16 yet") + _skip_helper(quant_algo) + + _test_moe_worker(moe_backend=moe_backend, dtype=dtype, quant_algo=quant_algo) def _test_moe_multi_gpu( @@ -588,43 +256,12 @@ def _test_moe_multi_gpu( moe_backend, quant_algo, dtype, + ep_size, world_size, - parallel_mode="DEP", enable_eplb=False, layer_updates_per_iter=-1, num_slots=-1, - model_config: Optional[MoeModelConfig] = None, - seq_len: int = 4, - enable_autotune: bool = False, - routing_method_cls=RenormalizeMoeRoutingMethod, - dtype_routing_logits=None, - swiglu_alpha: float = 1, - swiglu_beta: float = 0, - swiglu_limit: float = float("inf"), ): - """ - Test MoE module with multi-GPU support. - - Args: - comm_method_type: Communication method type - moe_backend: Backend type string - quant_algo: Quantization algorithm - dtype: Activation data type - world_size: Total world size - parallel_mode: Parallelism strategy ("DEP", "TEP", "DTP", "TTP") - enable_eplb: Enable Expert Load Balancing - layer_updates_per_iter: EPLB layer updates per iteration - num_slots: EPLB number of slots - model_config: MoE model configuration - seq_len: Sequence length for test input - enable_autotune: Enable autotune and tactic capture/replay testing - routing_method_cls: Routing method class to use - dtype_routing_logits: Data type for routing logits (default: same as dtype) - swiglu_alpha: SwiGLU alpha parameter (default=1, non-gptoss) - swiglu_beta: SwiGLU beta parameter (default=0, non-gptoss) - swiglu_limit: SwiGLU limit parameter (default=inf, non-gptoss) - """ - def init_worker(custom_paths, comm_method_type): # Update the sys.path to align with main process for submodule import for custom_path in custom_paths: @@ -634,8 +271,6 @@ def _test_moe_multi_gpu( # Set comm method os.environ["TRTLLM_FORCE_COMM_METHOD"] = comm_method_type - mapping = _create_mapping_for_parallel_mode(world_size, parallel_mode) - with MPIPoolExecutor( initializer=init_worker, initargs=(sys.path, comm_method_type), max_workers=world_size ) as executor: @@ -647,18 +282,16 @@ def _test_moe_multi_gpu( moe_backend, dtype, quant_algo, - mapping, + Mapping( + world_size=world_size, + tp_size=world_size, + moe_ep_size=ep_size, + moe_tp_size=world_size // ep_size, + enable_attention_dp=True, + ), enable_eplb, layer_updates_per_iter, num_slots, - model_config, - seq_len, - enable_autotune, - routing_method_cls, - dtype_routing_logits, - swiglu_alpha, - swiglu_beta, - swiglu_limit, ) ] * world_size @@ -668,639 +301,104 @@ def _test_moe_multi_gpu( assert r is None -# ============================================================================ -# Test Parameters 
Configuration -# ============================================================================ - -# Quantization algorithms to test -QUANT_ALGOS = [ - None, # Unquantized - QuantAlgo.FP8, - QuantAlgo.NVFP4, - QuantAlgo.FP8_BLOCK_SCALES, - QuantAlgo.W4A8_NVFP4_FP8, - QuantAlgo.W4A16_MXFP4, - QuantAlgo.W4A8_MXFP4_MXFP8, - QuantAlgo.W8A16, - QuantAlgo.W4A8_AWQ, -] - -# Backend types to test -BACKEND_TYPES = [ - MoeBackendType.CUTLASS, - MoeBackendType.TRTLLM, - MoeBackendType.CUTEDSL, - MoeBackendType.DEEPGEMM, -] - -# Data types to test -DTYPES = [ - torch.float16, - torch.bfloat16, -] - -# Model configurations for testing -# (num_experts, top_k, hidden_size, intermediate_size) -# -# Default runs the full local config matrix (TRTLLM_TEST_MOE_CI=0). -# Set TRTLLM_TEST_MOE_CI=1 in CI to run only the smaller subset for speed. -CI_MOE_MODEL_CONFIGS = [ - MoeModelConfig(60, 4, 2048, 1408), # Qwen1.5-MoE-A2.7B - MoeModelConfig(32, 8, 7168, 2048), # DeepSeek-V3 (reduced from 256 experts to accelerate test) - MoeModelConfig(128, 4, 2880, 2880), # GPT-OSS-120B - MoeModelConfig(8, 1, 512, 512), # boundary: top_k=1, single expert activated -] - -LOCAL_MOE_MODEL_CONFIGS = CI_MOE_MODEL_CONFIGS + [ - MoeModelConfig(64, 6, 2048, 1408), # DeepSeek-MoE-16B / DeepSeek-V2-Lite - MoeModelConfig(384, 8, 7168, 2048), # Kimi-K2 - # === Boundary Tests: num_experts / top_k === - MoeModelConfig(4, 4, 512, 512), # top_k=num_experts, all experts activated - MoeModelConfig(7, 2, 256, 512), # prime num_experts - MoeModelConfig(13, 3, 256, 512), # prime num_experts, odd top_k - # === Boundary Tests: small sizes === - MoeModelConfig(4, 2, 64, 128), # very small hidden_size - MoeModelConfig(4, 2, 128, 64), # intermediate < hidden -] - -MOE_MODEL_CONFIGS = ( - CI_MOE_MODEL_CONFIGS - if os.environ.get("TRTLLM_TEST_MOE_CI", "0") == "1" - else LOCAL_MOE_MODEL_CONFIGS -) - -# Sequence lengths to test -SEQ_LENS = [1, 8] - -# Routing methods to test -ROUTING_METHODS = [ - RenormalizeMoeRoutingMethod, # TopK -> Softmax (Mixtral, etc.) - DefaultMoeRoutingMethod, # Softmax -> TopK - RenormalizeNaiveMoeRoutingMethod, # Softmax -> TopK -> Renormalize (Qwen3) - Llama4RenormalizeMoeRoutingMethod, # Top1 -> Sigmoid (Llama4) - DeepSeekV3MoeRoutingMethod, # Sigmoid -> BiasAdd -> Group TopK (DeepSeek-V3) - MiniMaxM2MoeRoutingMethod, # Sigmoid -> BiasAdd -> TopK -> Renormalize (MiniMax-M2) -] - - -MULTI_GPU_ROUTING_METHODS = [ - RenormalizeMoeRoutingMethod, # TopK -> Softmax (Mixtral, etc.) 
- DeepSeekV3MoeRoutingMethod, # Sigmoid -> BiasAdd -> Group TopK (DeepSeek-V3) -] - - -# ============================================================================ -# Multi-GPU Test Configuration -# ============================================================================ -# Parallel modes to test -PARALLEL_MODES = [ - "DEP", # Attention DP, MoE EP - "TEP", # Attention TP, MoE EP - "DTP", # Attention DP, MoE TP - "TTP", # Attention TP, MoE TP -] - -# Communication methods to test -COMM_METHODS = [ - "NVLINK_ONE_SIDED", - "NVLINK_TWO_SIDED", - "DEEPEP", - "DEEPEPLOWLATENCY", -] - -# SwiGLU parameters for swiglu_gptoss_style testing -SWIGLU_ALPHAS = [1, 1.702] # default, GPT-OSS (modeling_gpt_oss.py) -SWIGLU_BETAS = [0, 1.0] # default, GPT-OSS -SWIGLU_LIMITS = [float("inf"), 7.0] # default, GPT-OSS - -# Single-GPU: full product of all SwiGLU combos -SWIGLU_COMBOS = list(product(SWIGLU_ALPHAS, SWIGLU_BETAS, SWIGLU_LIMITS)) - -# Multi-GPU: only non-gptoss (default) and one gptoss combo -MULTI_GPU_SWIGLU_COMBOS = [ - (1, 0, float("inf")), # non-gptoss (default SwiGLU) - (1.702, 1.0, 7.0), # gptoss style (GPT-OSS real values) -] - - -def _get_comm_method_skip_reason( - comm_method: str, - model_config: "MoeModelConfig", -) -> Optional[str]: - """ - Check if a communication method is compatible with the given model config. - - Returns a skip reason string if incompatible, None otherwise. - """ - from tensorrt_llm._torch.modules.fused_moe.communication.deep_ep_low_latency import ( - DeepEPLowLatency, - ) - - if comm_method == "DEEPEPLOWLATENCY": - if model_config.hidden_size not in DeepEPLowLatency.SUPPORTED_HIDDEN_SIZES: - return ( - f"DeepEPLowLatency does not support hidden_size={model_config.hidden_size}, " - f"requires one of {sorted(DeepEPLowLatency.SUPPORTED_HIDDEN_SIZES)}" - ) - return None - - -def generate_multi_gpu_test_params( - parallel_modes, - comm_methods, - swiglu_combos, - model_configs, - seq_lens, - dtypes, - backend_types, - quant_algos, - routing_methods, -) -> List: - """ - Generate test parameter combinations for multi-GPU tests. 
- - Args: - parallel_modes: List of parallel modes - comm_methods: List of communication methods - swiglu_combos: List of (swiglu_alpha, swiglu_beta, swiglu_limit) tuples - model_configs: List of MoeModelConfig - seq_lens: List of sequence lengths - dtypes: List of data types - backend_types: List of backend types - quant_algos: List of quantization algorithms - routing_methods: List of routing method classes - - Returns: - List of pytest.param objects with appropriate skip marks - """ - params: List = [] - for parallel_mode, comm_method in product(parallel_modes, comm_methods): - for ( - swiglu_alpha, - swiglu_beta, - swiglu_limit, - model_config, - seq_len, - dtype, - backend_type, - quant_algo, - routing_method_cls, - skip_reason, - base_test_id, - ) in iter_base_test_configs( - swiglu_combos, - model_configs, - seq_lens, - dtypes, - backend_types, - quant_algos, - routing_methods, - ): - # Check multi-GPU specific skip conditions - if not skip_reason: - skip_reason = _get_comm_method_skip_reason(comm_method, model_config) - if not skip_reason: - skip_reason = should_skip_trtllm( - backend_type, quant_algo, model_config, comm_method=comm_method - ) - if not skip_reason: - skip_reason = should_skip_cutedsl( - backend_type, quant_algo, model_config, comm_method - ) - if not skip_reason: - skip_reason = should_skip_deepgemm( - backend_type, comm_method, quant_algo=quant_algo, model_config=model_config - ) - if not skip_reason: - skip_reason = should_skip_multi_gpu(parallel_mode, model_config, world_size=4) - - test_id = f"parallel={parallel_mode}-comm={comm_method}-{base_test_id}" - param_values = ( - parallel_mode, - comm_method, - dtype, - backend_type.value, - quant_algo, - seq_len, - model_config, - routing_method_cls, - swiglu_alpha, - swiglu_beta, - swiglu_limit, - ) - params.append(create_test_param(param_values, test_id, skip_reason)) - - return params - - -def generate_base_test_params( - swiglu_combos, model_configs, seq_lens, dtypes, backend_types, quant_algos, routing_methods -) -> List: - """ - Generate test parameter combinations for base tests. 
- - Args: - swiglu_combos: List of (swiglu_alpha, swiglu_beta, swiglu_limit) tuples - model_configs: List of MoeModelConfig - seq_lens: List of sequence lengths - dtypes: List of data types - backend_types: List of backend types - quant_algos: List of quantization algorithms - routing_methods: List of routing method classes - - Returns: - List of pytest.param objects with appropriate skip marks - """ - params: List = [] - for ( - swiglu_alpha, - swiglu_beta, - swiglu_limit, - model_config, - seq_len, - dtype, - backend_type, - quant_algo, - routing_method_cls, - skip_reason, - base_test_id, - ) in iter_base_test_configs( - swiglu_combos, model_configs, seq_lens, dtypes, backend_types, quant_algos, routing_methods - ): - param_values = ( - dtype, - backend_type.value, - quant_algo, - seq_len, - model_config, - routing_method_cls, - swiglu_alpha, - swiglu_beta, - swiglu_limit, - ) - params.append(create_test_param(param_values, base_test_id, skip_reason)) - - return params - - -# ============================================================================ -# MoE Single GPU Tests -# ============================================================================ -# Pre-generate test parameters at module load time -BASE_TEST_PARAMS = generate_base_test_params( - swiglu_combos=SWIGLU_COMBOS, - model_configs=MOE_MODEL_CONFIGS, - seq_lens=SEQ_LENS, - dtypes=DTYPES, - backend_types=BACKEND_TYPES, - quant_algos=QUANT_ALGOS, - routing_methods=ROUTING_METHODS, -) - - -@pytest.mark.skip(reason="Temporarily skipped due to the long time to run the test") -@pytest.mark.parametrize( - "dtype,moe_backend,quant_algo,seq_len,model_config,routing_method_cls," - "swiglu_alpha,swiglu_beta,swiglu_limit", - BASE_TEST_PARAMS, -) -def test_ConfigurableMoE_single_gpu( - dtype: torch.dtype, - moe_backend: str, - quant_algo: Optional[QuantAlgo], - seq_len: int, - model_config: MoeModelConfig, - routing_method_cls, - swiglu_alpha: float, - swiglu_beta: float, - swiglu_limit: float, -): - """ - Single-GPU test for ConfigurableMoE module. - - This test verifies: - 1. MoE create_moe -> load_weights -> forward produces correct results - 2. Various backend + quantization combinations work correctly - 3. Autotune captures and replays all tactics properly - 4. 
swiglu_gptoss_style (SwiGLU with custom parameters) works correctly - """ - # DeepSeekV3 routing requires float32 routing_logits for TRTLLM backend - # See: cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp:70-72 - dtype_routing_logits = None - if ( - moe_backend == MoeBackendType.TRTLLM.value - and routing_method_cls == DeepSeekV3MoeRoutingMethod - ): - dtype_routing_logits = torch.float32 - - _test_moe_worker( - moe_backend=moe_backend, - dtype=dtype, - quant_algo=quant_algo, - model_config=model_config, - seq_len=seq_len, - enable_autotune=True, - routing_method_cls=routing_method_cls, - dtype_routing_logits=dtype_routing_logits, - swiglu_alpha=swiglu_alpha, - swiglu_beta=swiglu_beta, - swiglu_limit=swiglu_limit, - ) - - -# ============================================================================ -# MoE Multi-GPU Tests -# ============================================================================ -# Pre-generate multi-GPU test parameters at module load time -MULTI_GPU_TEST_PARAMS = generate_multi_gpu_test_params( - parallel_modes=PARALLEL_MODES, - comm_methods=COMM_METHODS, - swiglu_combos=MULTI_GPU_SWIGLU_COMBOS, - model_configs=MOE_MODEL_CONFIGS, - seq_lens=SEQ_LENS, - dtypes=DTYPES, - backend_types=BACKEND_TYPES, - quant_algos=QUANT_ALGOS, - routing_methods=MULTI_GPU_ROUTING_METHODS, -) - - -@pytest.mark.skip(reason="Temporarily skipped due to the long time to run the test") @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="needs 4 GPUs to run this test") @pytest.mark.parametrize( - "parallel_mode,comm_method_type,dtype,moe_backend,quant_algo,seq_len,model_config," - "routing_method_cls,swiglu_alpha,swiglu_beta,swiglu_limit", - MULTI_GPU_TEST_PARAMS, + "quant_algo", + [ + None, + QuantAlgo.NVFP4, + ], + ids=lambda val: f"quant_algo={val}", ) -def test_ConfigurableMoE_multi_gpu( - parallel_mode, - comm_method_type, - dtype, - moe_backend, - quant_algo, - seq_len, - model_config, - routing_method_cls, - swiglu_alpha, - swiglu_beta, - swiglu_limit, -): - # DeepSeekV3 routing requires float32 routing_logits for TRTLLM backend - # See: cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp:70-72 - dtype_routing_logits = None - if ( - moe_backend == MoeBackendType.TRTLLM.value - and routing_method_cls == DeepSeekV3MoeRoutingMethod - ): - dtype_routing_logits = torch.float32 +@pytest.mark.parametrize( + "moe_backend", + [ + "CUTLASS", + "TRTLLM", + ], + ids=lambda val: f"moe_backend={val}", +) +@pytest.mark.parametrize( + "comm_method_type", + [ + "NVLINK_ONE_SIDED", + "NVLINK_TWO_SIDED", + ], + ids=lambda val: f"comm_method_type={val}", +) +def test_moe_multi_gpu(comm_method_type, moe_backend, quant_algo): + _skip_helper(quant_algo) + dtype = torch.bfloat16 + ep_size = 4 world_size = 4 _test_moe_multi_gpu( comm_method_type, moe_backend, quant_algo, dtype=dtype, + ep_size=ep_size, world_size=world_size, - parallel_mode=parallel_mode, - model_config=model_config, - seq_len=seq_len, - routing_method_cls=routing_method_cls, - dtype_routing_logits=dtype_routing_logits, - swiglu_alpha=swiglu_alpha, - swiglu_beta=swiglu_beta, - swiglu_limit=swiglu_limit, ) -# ============================================================================ -# MoE Multi-GPU EPLB Tests -# ============================================================================ -# EPLB-specific configuration -EPLB_PARALLEL_MODES = ["DEP"] # EPLB only works with DEP mode (use_dp=True) -EPLB_COMM_METHODS = [ - "NVLINK_ONE_SIDED", - "NVLINK_TWO_SIDED", -] # Communication methods for EPLB -EPLB_ROUTING_METHODS = [RenormalizeMoeRoutingMethod] # 
Common routing methods -EPLB_MODEL_CONFIGS = [MoeModelConfig(8, 2, 512, 512)] # Model configs for EPLB -EPLB_NUM_SLOTS_LIST = [16] # Must be > num_experts (8) to be effective - - -def _get_fused_moe_method_class(quant_algo, backend_type): - """ - Get the FusedMoEMethod class based on quant_algo and backend_type. - - This mirrors the logic in each backend's _get_quant_method() method. - - Returns: - FusedMoEMethod class or None if not found - """ - backend_str = backend_type.value if hasattr(backend_type, "value") else str(backend_type) - - if quant_algo is None: - # Unquantized - only CUTLASS supports it - if backend_str == "CUTLASS": - return UnquantizedFusedMoEMethod - return None - - # CUTLASS backend - # Mapping based on CutlassFusedMoE._get_quant_method() logic - if backend_str == "CUTLASS": - method_map = { - QuantAlgo.FP8: FP8QDQFusedMoEMethod, - QuantAlgo.FP8_BLOCK_SCALES: DeepSeekFP8BlockScalesFusedMoEMethod, - QuantAlgo.NVFP4: NVFP4CutlassFusedMoEMethod, - # W4A8_AWQ uses is_int4_weight_only_per_group() -> WInt4AFP8FusedMoEMethod - QuantAlgo.W4A8_AWQ: WInt4AFP8FusedMoEMethod, - QuantAlgo.W8A16: INT8WoqPerChannelFusedMoEMethod, - QuantAlgo.W4A16_MXFP4: WFP4A16FusedMoEMethod, - QuantAlgo.W4A8_MXFP4_FP8: W4A8MXFP4FP8CutlassFusedMoEMethod, - QuantAlgo.W4A8_MXFP4_MXFP8: W4A8MXFP4MXFP8CutlassFusedMoEMethod, - # Note: W4A8_NVFP4_FP8 is NOT supported by CUTLASS backend - } - return method_map.get(quant_algo) - - # TRTLLM backend - if backend_str == "TRTLLM": - method_map = { - QuantAlgo.FP8_BLOCK_SCALES: DeepSeekFP8BlockScalesFusedMoEMethod, - QuantAlgo.NVFP4: NVFP4TRTLLMGenFusedMoEMethod, - QuantAlgo.W4A16_MXFP4: W4A16MXFP4TRTLLMGenFusedMoEMethod, - QuantAlgo.W4A8_NVFP4_FP8: W4A8NVFP4FP8TRTLLMGenFusedMoEMethod, - QuantAlgo.W4A8_MXFP4_FP8: W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, - QuantAlgo.W4A8_MXFP4_MXFP8: W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, - } - return method_map.get(quant_algo) - - # CUTEDSL backend uses same methods as CUTLASS for quantization - if backend_str == "CUTEDSL": - method_map = { - QuantAlgo.NVFP4: NVFP4CutlassFusedMoEMethod, - } - return method_map.get(quant_algo) - - # DEEPGEMM backend - if backend_str == "DEEPGEMM": - method_map = { - QuantAlgo.FP8_BLOCK_SCALES: DeepSeekFP8BlockScalesFusedMoEMethod, - } - return method_map.get(quant_algo) - - return None - - -def _should_skip_EPLB(quant_algo, backend_type, num_slots, num_experts): - """ - Check if EPLB test should be skipped based on quant_algo, backend_type, and slot configuration. - - Returns: - str or None: Skip reason if should skip, None otherwise - """ - # Check num_slots > num_experts requirement - if num_slots <= num_experts: - return f"EPLB requires num_slots ({num_slots}) > num_experts ({num_experts})" - - # Get the FusedMoEMethod class for this quant_algo + backend combination - method_class = _get_fused_moe_method_class(quant_algo, backend_type) - - if method_class is None: - # Cannot determine the method class, skip the test - return ( - f"Cannot determine FusedMoEMethod for quant_algo={quant_algo}, backend={backend_type}" - ) - - # Query the method class directly for EPLB support - if not method_class.supports_online_eplb(): - return f"EPLB not supported for {method_class.__name__} (supports_online_eplb=False)" - - return None - - -def generate_eplb_test_params( - parallel_modes, - comm_methods, - model_configs, - num_slots_list, - dtypes, - backend_types, - quant_algos, - routing_methods, -) -> List: - """ - Generate test parameter combinations for EPLB tests. 
- - EPLB requires num_slots > num_experts to be effective. - - Args: - parallel_modes: List of parallel modes (only EP modes: DEP, TEP) - comm_methods: List of communication methods - model_configs: List of MoeModelConfig - num_slots_list: List of EPLB slots (must be > num_experts) - dtypes: List of data types - backend_types: List of backend types - quant_algos: List of quantization algorithms - routing_methods: List of routing method classes - - Returns: - List of pytest.param objects with appropriate skip marks - """ - params: List = [] - - for ( - parallel_mode, - comm_method, - model_config, - num_slots, - dtype, - backend_type, - quant_algo, - routing_method_cls, - ) in product( - parallel_modes, - comm_methods, - model_configs, - num_slots_list, - dtypes, - backend_types, - quant_algos, - routing_methods, - ): - # Get skip reason using existing logic - skip_reason = get_quick_skip_reason( - backend_type, quant_algo, dtype, model_config, routing_method_cls - ) - - # Check EPLB-specific skip conditions - if not skip_reason: - skip_reason = _should_skip_EPLB( - quant_algo, backend_type, num_slots, model_config.num_experts - ) - - routing_name = routing_method_cls.__name__.replace("MoeRoutingMethod", "") - test_id = ( - f"parallel={parallel_mode}-comm={comm_method}-{model_config}-slots={num_slots}-" - f"dtype={dtype}-backend={backend_type.value}-quant={quant_algo}-routing={routing_name}" - ) - - param_values = ( - parallel_mode, - comm_method, - dtype, - backend_type.value, - quant_algo, - model_config, - num_slots, - routing_method_cls, - ) - params.append(create_test_param(param_values, test_id, skip_reason)) - - return params - - -# Pre-generate EPLB test parameters at module load time -EPLB_TEST_PARAMS = generate_eplb_test_params( - parallel_modes=EPLB_PARALLEL_MODES, - comm_methods=EPLB_COMM_METHODS, - model_configs=EPLB_MODEL_CONFIGS, - num_slots_list=EPLB_NUM_SLOTS_LIST, - dtypes=DTYPES, - backend_types=BACKEND_TYPES, - quant_algos=QUANT_ALGOS, - routing_methods=EPLB_ROUTING_METHODS, -) - - -@pytest.mark.skip(reason="Temporarily skipped due to the long time to run the test") @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="needs 4 GPUs to run this test") @pytest.mark.skipif( not _tbr.is_host_accessible_device_memory_supported(), reason="needs support of host accessible device memory", ) @pytest.mark.parametrize( - "parallel_mode,comm_method_type,dtype,moe_backend,quant_algo,model_config,num_slots,routing_method_cls", - EPLB_TEST_PARAMS, + "quant_algo", + [ + None, + QuantAlgo.NVFP4, + ], + ids=lambda val: f"quant_algo={val}", ) -def test_ConfigurableMoE_multi_gpu_eplb( - parallel_mode, - comm_method_type, - dtype, - moe_backend, - quant_algo, - model_config, - num_slots, - routing_method_cls, +@pytest.mark.parametrize( + "moe_backend", + [ + "CUTLASS", + ], + ids=lambda val: f"moe_backend={val}", +) +@pytest.mark.parametrize( + "comm_method_type", + [ + "NVLINK_ONE_SIDED", + ], + ids=lambda val: f"comm_method_type={val}", +) +@pytest.mark.parametrize( + "num_slots", + [ + 16, + ], + ids=lambda val: f"num_slots={val}", +) +@pytest.mark.parametrize( + "layer_updates_per_iter", + [ + 1, + ], + ids=lambda val: f"layer_updates_per_iter={val}", +) +def test_moe_multi_gpu_eplb( + layer_updates_per_iter, num_slots, comm_method_type, moe_backend, quant_algo ): + _skip_helper(quant_algo) + + dtype = torch.bfloat16 + ep_size = 4 world_size = 4 _test_moe_multi_gpu( comm_method_type, moe_backend, quant_algo, dtype=dtype, + ep_size=ep_size, world_size=world_size, - 
parallel_mode=parallel_mode, enable_eplb=True, - layer_updates_per_iter=1, + layer_updates_per_iter=layer_updates_per_iter, num_slots=num_slots, - model_config=model_config, - routing_method_cls=routing_method_cls, )
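
A minimal sketch of the expert/tensor-parallel split that the restored multi-GPU tests
assume when constructing the Mapping above (the helper name below is hypothetical; only
the moe_ep_size / moe_tp_size arithmetic mirrors the hunk):

    def moe_parallel_sizes(world_size, ep_size):
        # Mirrors Mapping(world_size=world_size, tp_size=world_size,
        #                 moe_ep_size=ep_size, moe_tp_size=world_size // ep_size):
        # experts are partitioned across ep_size ranks, and each expert shard is
        # split across the remaining world_size // ep_size ranks.
        assert world_size % ep_size == 0, "ep_size must divide world_size"
        return ep_size, world_size // ep_size

    # With the values used in these tests (world_size=4, ep_size=4) this yields
    # moe_ep_size=4, moe_tp_size=1, i.e. pure expert parallelism with attention DP
    # (enable_attention_dp=True in the Mapping above).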