[TRTLLM-9111][feat] provide the uniform test framework to test all MoE backends (#11128)

Signed-off-by: xxi <xxi@nvidia.com>
Author: xxi, 2026-02-04 15:57:56 +08:00 (committed by GitHub)
Parent: de6931bbfd
Commit: 02b80bfd58
13 changed files with 2196 additions and 255 deletions

View File

@ -95,6 +95,35 @@ class ConfigurableMoE(MoE):
- Communication: Auto-selected based on hardware (NVLINK > DeepEP > AllGather)
"""
@classmethod
def can_implement(
cls,
quant_algo,
dtype_activation: torch.dtype = torch.bfloat16,
gptoss_style: bool = False,
):
"""
ConfigurableMoE is a wrapper class that delegates to specific backends.
To check capability, query the specific backend class directly:
- CutlassFusedMoE.can_implement(quant_algo, dtype_activation, gptoss_style)
- TRTLLMGenFusedMoE.can_implement(quant_algo, dtype_activation, gptoss_style)
- etc.
Args:
quant_algo: The quantization algorithm to check (None for unquantized)
dtype_activation: The activation data type
gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled
Returns:
Tuple[bool, Optional[str]]: Always returns (False, reason)
"""
del quant_algo, dtype_activation, gptoss_style # Unused - wrapper class
return False, (
"ConfigurableMoE is a wrapper class. "
"Query the specific backend (CutlassFusedMoE, TRTLLMGenFusedMoE, etc.) directly."
)
def __init__(
self,
*,

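A minimal usage sketch of the delegation contract above, assuming the import paths used by the new unit test in this PR (adjust them if your layout differs): ConfigurableMoE.can_implement() always declines, so capability checks go straight to the concrete backend.

import torch

from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE
from tensorrt_llm.models.modeling_utils import QuantAlgo

# ConfigurableMoE.can_implement() always returns (False, reason), so ask the
# concrete backend whether the current GPU / dtype / quant_algo combination works.
ok, reason = CutlassFusedMoE.can_implement(
    QuantAlgo.FP8, dtype_activation=torch.bfloat16, gptoss_style=False)
if not ok:
    print(f"Skipping CutlassFusedMoE: {reason}")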
View File

@ -4,7 +4,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from tensorrt_llm._utils import is_sm_100f
from tensorrt_llm._utils import get_sm_version, is_sm_100f
from tensorrt_llm.models.modeling_utils import QuantAlgo
from ...autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
OptimizationProfile, TunableRunner, TuningConfig)
@ -312,6 +313,69 @@ class CuteDslFusedMoE(CutlassFusedMoE):
model_config (ModelConfig): Configuration object for the model.
"""
@classmethod
def can_implement(
cls,
quant_algo: Optional[QuantAlgo],
dtype_activation: torch.dtype = torch.bfloat16,
gptoss_style: bool = False,
) -> Tuple[bool, Optional[str]]:
"""
Check if CuteDslFusedMoE can implement the given quantization algorithm.
CuteDslFusedMoE supports:
- NVFP4: SM in {100, 103}
Does NOT support unquantized mode. Output dtype is hardcoded to bfloat16.
Does NOT support gptoss_style (bias/swiglu with custom alpha/beta/limit).
Args:
quant_algo: The quantization algorithm to check (None for unquantized)
dtype_activation: The activation input data type. Only bfloat16 is supported
because output dtype is hardcoded to bfloat16 (input/output dtype must match).
gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled.
CuteDslFusedMoE does NOT support gptoss_style.
Returns:
Tuple[bool, Optional[str]]: (can_implement, skip_reason)
"""
from .interface import _warn_and_return
sm_version = get_sm_version()
# CuteDslFusedMoE requires at least SM90
if sm_version < 90:
return _warn_and_return(
f"CuteDslFusedMoE requires SM >= 90, got SM{sm_version}")
# Check dtype_activation: output is hardcoded to bfloat16, so input must also be bfloat16
# to maintain input/output dtype consistency
if dtype_activation != torch.bfloat16:
return _warn_and_return(
f"CuteDslFusedMoE only supports bfloat16 activation (output is hardcoded to bfloat16), "
f"got {dtype_activation}")
# CuteDslFusedMoE does NOT support unquantized mode
if quant_algo is None:
return _warn_and_return(
"CuteDslFusedMoE does not support unquantized mode")
# CuteDslFusedMoE does NOT support gptoss_style
if gptoss_style:
return _warn_and_return(
"CuteDslFusedMoE does not support gptoss_style (bias/swiglu with custom alpha/beta/limit)"
)
# NVFP4 - SM in {100, 103}
if quant_algo == QuantAlgo.NVFP4:
if sm_version not in {100, 103}:
return _warn_and_return(
f"NVFP4 requires SM100 or SM103, got SM{sm_version}")
return True, None
return _warn_and_return(
f"CuteDslFusedMoE does not support quant_algo={quant_algo}")
def __init__(
self,
*,

View File

@ -7,7 +7,9 @@ import torch
from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
from tensorrt_llm._torch.distributed.moe_alltoall import MoeAlltoAll
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantAlgo
from tensorrt_llm.tools.layer_wise_benchmarks import get_calibrator
from ...distributed import allgather
@ -57,6 +59,151 @@ class CutlassFusedMoE(MoE):
equivalent to: dynamic quant + routing(topK, etc.) [+ fp4_allgather] + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute [no allreduce] + reducescatter
"""
# Quantization algorithm support table for can_implement()
# Format: quant_algo -> {sm_constraint, dtypes}
# sm_constraint types:
# - ("min", N): SM >= N
# - ("exact", N): SM == N
# - ("in", {N1, N2, ...}): SM in set
_QUANT_SUPPORT_TABLE = {
# Unquantized (FP16/BF16): SM >= 80
None: {
"sm_constraint": ("min", 80),
"dtypes": {torch.float16, torch.bfloat16},
},
# FP8 per-tensor (QDQ): SM >= 89
QuantAlgo.FP8: {
"sm_constraint": ("min", 89),
"dtypes": {torch.float16, torch.bfloat16, torch.float32},
},
# FP8_BLOCK_SCALES: SM == 90 only
QuantAlgo.FP8_BLOCK_SCALES: {
"sm_constraint": ("exact", 90),
"dtypes": {torch.float16, torch.bfloat16, torch.float32},
},
# NVFP4: SM in {100, 103}
QuantAlgo.NVFP4: {
"sm_constraint": ("in", {100, 103}),
"dtypes": {torch.float16, torch.bfloat16, torch.float8_e4m3fn},
},
# W4A8_AWQ: SM in {89, 90} only
QuantAlgo.W4A8_AWQ: {
"sm_constraint": ("in", {89, 90}),
"dtypes": {torch.float16, torch.bfloat16},
},
# W8A16: SM >= 80
QuantAlgo.W8A16: {
"sm_constraint": ("min", 80),
"dtypes": {torch.float16, torch.bfloat16},
},
# W4A16_MXFP4: SM == 90 only
QuantAlgo.W4A16_MXFP4: {
"sm_constraint": ("exact", 90),
"dtypes": {torch.float16, torch.bfloat16},
},
# W4A8_MXFP4_FP8: SM in {100, 103}
QuantAlgo.W4A8_MXFP4_FP8: {
"sm_constraint": ("in", {100, 103}),
"dtypes": {torch.float16, torch.bfloat16, torch.float32},
},
# W4A8_MXFP4_MXFP8: SM in {100, 103}
QuantAlgo.W4A8_MXFP4_MXFP8: {
"sm_constraint": ("in", {100, 103}),
"dtypes": {torch.float16, torch.bfloat16},
},
}
# Quantization algorithms that support gptoss_style
_GPTOSS_SUPPORTED_ALGOS = {QuantAlgo.W4A8_MXFP4_MXFP8}
@classmethod
def can_implement(
cls,
quant_algo: Optional[QuantAlgo],
dtype_activation: torch.dtype = torch.bfloat16,
gptoss_style: bool = False,
) -> Tuple[bool, Optional[str]]:
"""
Check if CutlassFusedMoE can implement the given quantization algorithm.
CutlassFusedMoE supports:
- Unquantized (FP16/BF16): SM >= 80
- FP8 per-tensor (QDQ): SM >= 89
- FP8_BLOCK_SCALES: SM == 90 only
- NVFP4: SM in {100, 103}
- W4A8_AWQ: SM in {89, 90} only
- W8A16: SM >= 80
- W4A16_MXFP4: SM == 90 only
- W4A8_MXFP4_FP8: SM in {100, 103}
- W4A8_MXFP4_MXFP8: SM in {100, 103}
Args:
quant_algo: The quantization algorithm to check (None for unquantized)
dtype_activation: The activation input data type (before quantization).
Supported dtypes vary by quantization mode:
- Unquantized: float16, bfloat16
- FP8/FP8_BLOCK_SCALES/W4A8_MXFP4_FP8: float16, bfloat16, float32
- NVFP4: float16, bfloat16, float8_e4m3fn
- W4A16_MXFP4/W4A8_AWQ/W8A16/W4A8_MXFP4_MXFP8: float16, bfloat16
gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled.
CutlassFusedMoE only supports gptoss_style for W4A8_MXFP4_MXFP8 quantization.
Returns:
Tuple[bool, Optional[str]]: (can_implement, skip_reason)
"""
from .interface import _warn_and_return
sm_version = get_sm_version()
# Check minimum SM version for Cutlass backend
if sm_version < 80:
return _warn_and_return(
f"CutlassFusedMoE requires SM >= 80, got SM{sm_version}")
# Check gptoss_style support
if gptoss_style and quant_algo not in cls._GPTOSS_SUPPORTED_ALGOS:
return _warn_and_return(
f"CutlassFusedMoE gptoss_style only supports W4A8_MXFP4_MXFP8 "
f"(got quant_algo={quant_algo})")
# Check if quant_algo is supported
if quant_algo not in cls._QUANT_SUPPORT_TABLE:
return _warn_and_return(
f"CutlassFusedMoE does not support quant_algo={quant_algo}")
support_info = cls._QUANT_SUPPORT_TABLE[quant_algo]
# Check SM version constraint
constraint_type, constraint_value = support_info["sm_constraint"]
algo_name = "unquantized" if quant_algo is None else quant_algo.name
if constraint_type == "min":
if sm_version < constraint_value:
return _warn_and_return(
f"CutlassFusedMoE {algo_name} requires SM >= {constraint_value}, "
f"got SM{sm_version}")
elif constraint_type == "exact":
if sm_version != constraint_value:
return _warn_and_return(
f"CutlassFusedMoE {algo_name} only supports SM{constraint_value}, "
f"got SM{sm_version}")
elif constraint_type == "in":
if sm_version not in constraint_value:
sm_list = "/".join(f"SM{v}" for v in sorted(constraint_value))
return _warn_and_return(
f"CutlassFusedMoE {algo_name} only supports {sm_list}, "
f"got SM{sm_version}")
# Check dtype_activation
supported_dtypes = support_info["dtypes"]
if dtype_activation not in supported_dtypes:
dtype_list = ", ".join(str(d) for d in supported_dtypes)
return _warn_and_return(
f"CutlassFusedMoE {algo_name} requires {dtype_list}, "
f"got {dtype_activation}")
return True, None
def __init__(
self,
*,

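For readers of the _QUANT_SUPPORT_TABLE above, the sm_constraint encoding can be evaluated on its own; the sketch below merely restates the table semantics (sm_ok is a hypothetical helper, not part of the PR).

def sm_ok(sm_constraint, sm_version: int) -> bool:
    # ("min", N) -> SM >= N, ("exact", N) -> SM == N, ("in", {N1, N2}) -> membership
    kind, value = sm_constraint
    if kind == "min":
        return sm_version >= value
    if kind == "exact":
        return sm_version == value
    if kind == "in":
        return sm_version in value
    raise ValueError(f"unknown constraint kind: {kind}")

# Example: NVFP4 is gated to SM100/SM103 in the table above.
assert sm_ok(("in", {100, 103}), 100)
assert not sm_ok(("in", {100, 103}), 90)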
View File

@ -1,4 +1,19 @@
from typing import Dict, List, Optional, Union
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union
import torch
import triton
@ -6,7 +21,8 @@ import triton.language as tl
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm import deep_gemm
from tensorrt_llm._utils import nvtx_range
from tensorrt_llm._utils import get_sm_version, nvtx_range
from tensorrt_llm.models.modeling_utils import QuantAlgo
from ...distributed import allgather
from ...memory_buffer_utils import get_memory_buffers
@ -361,6 +377,67 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
model_config (ModelConfig): Configuration object for the model.
"""
@classmethod
def can_implement(
cls,
quant_algo: Optional[QuantAlgo],
dtype_activation: torch.dtype = torch.bfloat16,
gptoss_style: bool = False,
) -> Tuple[bool, Optional[str]]:
"""
Check if DeepGemmFusedMoE can implement the given quantization algorithm.
DeepGemmFusedMoE supports:
- FP8_BLOCK_SCALES: SM in {100, 103}
Does NOT support unquantized mode. Output dtype is hardcoded to bfloat16.
Does NOT support gptoss_style (bias/swiglu with custom alpha/beta/limit).
Args:
quant_algo: The quantization algorithm to check (None for unquantized)
dtype_activation: The activation input data type. Supported types are
float32, bfloat16, and float16 (required by moe_permute_op kernel).
Note: Output dtype is always bfloat16 regardless of input dtype.
gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled.
DeepGemmFusedMoE does NOT support gptoss_style.
Returns:
Tuple[bool, Optional[str]]: (can_implement, skip_reason)
"""
from .interface import _warn_and_return
sm_version = get_sm_version()
if sm_version not in {100, 103}:
return _warn_and_return(
f"DeepGemmFusedMoE requires SM100 or SM103, got SM{sm_version}")
# Check dtype_activation: moe_permute_op only supports float32, bfloat16, float16
if dtype_activation not in {
torch.float32, torch.bfloat16, torch.float16
}:
return _warn_and_return(
f"DeepGemmFusedMoE requires float32, bfloat16, or float16 activation, "
f"got {dtype_activation}")
# DeepGemmFusedMoE does NOT support unquantized mode
if quant_algo is None:
return _warn_and_return(
"DeepGemmFusedMoE does not support unquantized mode")
# DeepGemmFusedMoE does NOT support gptoss_style
if gptoss_style:
return _warn_and_return(
"DeepGemmFusedMoE does not support gptoss_style (bias/swiglu with custom alpha/beta/limit)"
)
# Only FP8_BLOCK_SCALES is supported
if quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
return True, None
return _warn_and_return(
f"DeepGemmFusedMoE does not support quant_algo={quant_algo}")
# To reuse pytorch memory segments allocated during graph capture.
buffers = get_memory_buffers()

View File

@ -1,7 +1,22 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from typing import Dict, List, NamedTuple, Optional
from typing import Dict, List, NamedTuple, Optional, Tuple
import torch
import torch.nn as nn
@ -1263,6 +1278,73 @@ class TritonMXFP4FusedMoEMethod(TritonUnquantizedFusedMoEMethod):
class TritonFusedMoE(MoE):
@classmethod
def can_implement(
cls,
quant_algo: Optional["QuantAlgo"],
dtype_activation: torch.dtype = torch.bfloat16,
gptoss_style: bool = False,
) -> Tuple[bool, Optional[str]]:
"""
Check if TritonFusedMoE can implement the given quantization algorithm.
TritonFusedMoE supports (SM90 only, gptoss_style=True only):
- Unquantized (BF16 only)
- FP8 per-tensor (QDQ)
- W4A8_MXFP4_FP8
- W4A16_MXFP4
Args:
quant_algo: The quantization algorithm to check (None for unquantized)
dtype_activation: The activation data type. In unquantized mode, activation,
weight, and output dtypes must all match (only bfloat16 supported).
gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled.
TritonFusedMoE ONLY supports gptoss_style=True.
Returns:
Tuple[bool, Optional[str]]: (can_implement, skip_reason)
"""
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.models.modeling_utils import QuantAlgo
from .interface import _warn_and_return
sm_version = get_sm_version()
# TritonFusedMoE only supports SM90
if sm_version != 90:
return _warn_and_return(
f"TritonFusedMoE only supports SM90, got SM{sm_version}")
# TritonFusedMoE ONLY supports gptoss_style=True
if not gptoss_style:
return _warn_and_return(
"TritonFusedMoE only supports gptoss_style=True")
# Unquantized mode - only bfloat16 is supported
if quant_algo is None:
if dtype_activation != torch.bfloat16:
return _warn_and_return(
f"TritonFusedMoE unquantized mode only supports bfloat16, got {dtype_activation}"
)
return True, None
# FP8 per-tensor (QDQ) and W4A8_MXFP4_FP8 - no dtype_activation restriction
if quant_algo in {QuantAlgo.FP8, QuantAlgo.W4A8_MXFP4_FP8}:
return True, None
# W4A16_MXFP4 - only bfloat16 and float16 are supported
if quant_algo == QuantAlgo.W4A16_MXFP4:
if dtype_activation not in {torch.bfloat16, torch.float16}:
return _warn_and_return(
f"TritonFusedMoE W4A16_MXFP4 only supports bfloat16 or float16, "
f"got {dtype_activation}")
return True, None
# Unsupported quantization algorithm
return _warn_and_return(
f"TritonFusedMoE does not support quant_algo={quant_algo}")
def __init__(
self,
*,

View File

@ -1,7 +1,22 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from functools import cached_property
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import nn
@ -10,6 +25,7 @@ from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
from tensorrt_llm._torch.distributed.moe_alltoall import MoeAlltoAll
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantAlgo
from ...custom_ops.trtllm_gen_custom_ops import \
fp4_block_scale_fake_output_without_finalize
@ -61,6 +77,88 @@ class TRTLLMGenFusedMoE(MoE):
There should be at least `num_experts` slots in the model engine. More than that is OK; in that case, some experts may have multiple replicas.
"""
# Supported quantization algorithms for TRTLLMGenFusedMoE
_SUPPORTED_QUANT_ALGOS = {
QuantAlgo.NVFP4,
QuantAlgo.FP8_BLOCK_SCALES,
QuantAlgo.W4A8_NVFP4_FP8,
QuantAlgo.W4A16_MXFP4,
QuantAlgo.W4A8_MXFP4_FP8,
QuantAlgo.W4A8_MXFP4_MXFP8,
}
# Quantization algorithms that support gptoss_style
_GPTOSS_SUPPORTED_ALGOS = {
QuantAlgo.NVFP4,
QuantAlgo.W4A16_MXFP4,
QuantAlgo.W4A8_MXFP4_FP8,
QuantAlgo.W4A8_MXFP4_MXFP8,
}
@classmethod
def can_implement(
cls,
quant_algo: Optional[QuantAlgo],
dtype_activation: torch.dtype = torch.bfloat16,
gptoss_style: bool = False,
) -> Tuple[bool, Optional[str]]:
"""
Check if TRTLLMGenFusedMoE can implement the given quantization algorithm.
TRTLLMGenFusedMoE only supports SM in {100, 103} and the following quantizations:
- NVFP4
- FP8_BLOCK_SCALES
- W4A8_NVFP4_FP8
- W4A16_MXFP4
- W4A8_MXFP4_FP8
- W4A8_MXFP4_MXFP8
Does NOT support unquantized mode. Output dtype is hardcoded to bfloat16.
Args:
quant_algo: The quantization algorithm to check (None for unquantized)
dtype_activation: The activation input data type. Only bfloat16 is supported.
See: forward_impl() assert x.dtype == torch.bfloat16 (line 722).
gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled.
Only supported for nvfp4 and mxfp4 variants.
Returns:
Tuple[bool, Optional[str]]: (can_implement, skip_reason)
"""
from .interface import _warn_and_return
sm_version = get_sm_version()
# TRTLLMGenFusedMoE requires SM in {100, 103}
if sm_version not in {100, 103}:
return _warn_and_return(
f"TRTLLMGenFusedMoE requires SM100 or SM103, got SM{sm_version}"
)
# Check dtype_activation: only bfloat16 is supported
if dtype_activation != torch.bfloat16:
return _warn_and_return(
f"TRTLLMGenFusedMoE only supports bfloat16 activation, got {dtype_activation}"
)
# TRTLLMGenFusedMoE does NOT support unquantized mode
if quant_algo is None:
return _warn_and_return(
"TRTLLMGenFusedMoE does not support unquantized mode")
# Check if quant_algo is supported
if quant_algo not in cls._SUPPORTED_QUANT_ALGOS:
return _warn_and_return(
f"TRTLLMGenFusedMoE does not support quant_algo={quant_algo}")
# Check gptoss_style support: only supported for nvfp4 and mxfp4 variants
if gptoss_style and quant_algo not in cls._GPTOSS_SUPPORTED_ALGOS:
return _warn_and_return(
f"TRTLLMGenFusedMoE supports gptoss_style (bias/swiglu) only for nvfp4 and mxfp4 variants, "
f"got quant_algo={quant_algo}")
return True, None
def __init__(
self,
*,

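Since every backend now answers the same can_implement() question, a short sweep can print a support matrix for the current GPU. This is a hedged sketch whose imports mirror the new unit test; the reported results depend on the SM version of the machine it runs on.

import torch

from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import CuteDslFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import DeepGemmFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_trtllm_gen import TRTLLMGenFusedMoE
from tensorrt_llm.models.modeling_utils import QuantAlgo

backends = [CutlassFusedMoE, TRTLLMGenFusedMoE, CuteDslFusedMoE, DeepGemmFusedMoE]
quant_algos = [None, QuantAlgo.FP8, QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.NVFP4]

# Report which (backend, quant_algo) pairs the current GPU can run with bfloat16 activations.
for backend in backends:
    for algo in quant_algos:
        ok, reason = backend.can_implement(algo, dtype_activation=torch.bfloat16)
        label = "unquantized" if algo is None else algo.name
        print(f"{backend.__name__:20s} {label:18s} {'OK' if ok else reason}")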
View File

@ -7,7 +7,30 @@ from typing import Dict, List, Optional, Tuple, Union, final
import torch
from torch import nn
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantAlgo
from ...distributed.ops import reducescatter
def _warn_and_return(reason: str) -> Tuple[bool, Optional[str]]:
"""
Log a warning and return (False, reason) for can_implement() checks.
This is a common utility function used by all MoE backend implementations
to provide consistent logging and return values when a configuration
is not supported.
Args:
reason: The reason why the configuration is not supported.
Returns:
Tuple[bool, Optional[str]]: Always returns (False, reason)
"""
logger.warning(reason)
return False, reason
from ...model_config import ModelConfig
from ...utils import (ActivationType, AuxStreamType, Fp4QuantizedTensor,
get_model_extra_attrs, is_gated_activation,
@ -129,6 +152,40 @@ class MoE(nn.Module):
aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping.
"""
@classmethod
@abstractmethod
def can_implement(
cls,
quant_algo: Optional[QuantAlgo],
dtype_activation: torch.dtype = torch.bfloat16,
gptoss_style: bool = False,
) -> Tuple[bool, Optional[str]]:
"""
Check if this MoE backend can implement the given quantization algorithm.
NOTE: This is a TRANSITIONAL interface. In the future, this method will be moved
to the MoEBackend interface as part of the backend abstraction layer. During this
transition period, it remains in the MoE base class to maintain compatibility.
This method checks both:
1. Whether the backend supports the specified quantization algorithm
2. Whether the current platform (SM version) supports the backend and quantization
Each backend MUST override this method to provide accurate capability information.
Args:
quant_algo: The quantization algorithm to check (None for unquantized)
dtype_activation: The activation data type.
gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled.
Returns:
Tuple[bool, Optional[str]]: (can_implement, skip_reason)
- can_implement: True if the backend can implement this configuration
- skip_reason: None if can_implement is True, otherwise a string explaining why not
"""
raise NotImplementedError(
f"{cls.__name__} must implement can_implement method")
def __init__(
self,
*,

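For backend authors, the intended override pattern pairs the abstract method above with _warn_and_return(). The sketch below uses a hypothetical subclass (MyFusedMoE is illustrative only) assumed to live in the same module as the MoE base class, so MoE and _warn_and_return are referenced directly.

from typing import Optional, Tuple

import torch

from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.models.modeling_utils import QuantAlgo


class MyFusedMoE(MoE):
    """Hypothetical backend used only to illustrate the can_implement contract."""

    @classmethod
    def can_implement(
        cls,
        quant_algo: Optional[QuantAlgo],
        dtype_activation: torch.dtype = torch.bfloat16,
        gptoss_style: bool = False,
    ) -> Tuple[bool, Optional[str]]:
        if get_sm_version() < 90:
            return _warn_and_return("MyFusedMoE requires SM >= 90")
        if gptoss_style:
            return _warn_and_return("MyFusedMoE does not support gptoss_style")
        if quant_algo not in {None, QuantAlgo.FP8}:
            return _warn_and_return(
                f"MyFusedMoE does not support quant_algo={quant_algo}")
        return True, None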
View File

@ -22,8 +22,6 @@ l0_dgx_b300:
- unittest/_torch/modeling -k "modeling_mixtral"
- unittest/_torch/modeling -k "modeling_gpt_oss"
- unittest/_torch/multi_gpu_modeling -k "deepseek"
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False]

View File

@ -23,8 +23,6 @@ l0_gb300_multi_gpus:
- unittest/_torch/modeling -k "modeling_mixtral"
- unittest/_torch/modeling -k "modeling_gpt_oss"
- unittest/_torch/multi_gpu_modeling -k "deepseek"
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]

View File

@ -1,119 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TEMPORARY FILE - Will be removed after MoE refactor is complete.
#
# Background:
# The `enable_configurable_moe` parameter is a temporary measure during the MoE
# refactor. The old and new MoE flows will coexist for a period of time. To avoid
# large-scale changes to the existing test lists, we handle the test ID cleanup
# here. Once the refactor is complete and all tests use ConfigurableMoE by default,
# this file will no longer be needed and should be deleted.
#
# Two-phase approach:
# 1. pytest_sessionstart: Convert clean test names in CLI args back to original
# format so pytest can find tests during collection.
# 2. pytest_collection_modifyitems: Clean up the collected test IDs for display
# and waive matching.
import re
# Test functions that use enable_configurable_moe parameter and need ID conversion
TESTS_WITH_CONFIGURABLE_MOE = [
"test_fused_moe_nvfp4[",
"test_fused_moe_mxfp4_mxfp8[",
"test_fused_moe_w4a8_nvfp4_fp8[",
"test_fused_moe_wfp4a16[",
"test_fused_moe_fp8_blockwise_deepgemm[",
]
def _convert_clean_to_original_moe_test_id(test_id):
"""Convert clean MoE test ID back to original format for pytest collection.
Example: "test_fused_moe.py::test_foo[TRTLLM-dtype0]" -> "test_fused_moe.py::test_foo[-TRTLLM-dtype0]"
This is needed because the `enable_configurable_moe` parameter uses empty string
as ID when value is 0, resulting in test IDs like "test_foo[-TRTLLM-dtype0]".
We clean these up in pytest_collection_modifyitems, but pytest filters tests
during collection using the original IDs. So when the user runs with a clean test name,
we need to convert it back to match the original.
"""
if "test_fused_moe.py" not in test_id:
return test_id
# Match pattern like "test_name[params]" and add leading dash after "["
# But only if params don't already start with "-" or "enable_configurable_moe"
match = re.search(r"\[([^\]]+)\]", test_id)
if match:
params = match.group(1)
# Skip if already has leading dash or starts with enable_configurable_moe
if not params.startswith("-") and not params.startswith("enable_configurable_moe"):
# Add leading dash to params
new_params = "-" + params
test_id = test_id.replace(f"[{params}]", f"[{new_params}]")
return test_id
def pytest_sessionstart(session):
"""Convert clean MoE test IDs in config.args to original format for collection.
This is needed because pytest filters tests during collection using original IDs.
When the user runs with a clean test name, we convert it back to match the original.
"""
args = session.config.args
for i, arg in enumerate(args):
if "test_fused_moe.py" in arg and "[" in arg:
# Only apply conversion to specific tests that use enable_configurable_moe
should_convert = any(test_name in arg for test_name in TESTS_WITH_CONFIGURABLE_MOE)
if should_convert:
args[i] = _convert_clean_to_original_moe_test_id(arg)
def pytest_collection_modifyitems(items):
"""Clean up test IDs by removing leading/trailing dashes from parameter IDs.
This is needed because `enable_configurable_moe` parameter can be empty,
resulting in ugly test IDs like "test_foo[-True]" or "test_foo[--abc]".
We clean these up to "test_foo[True]" or "test_foo[abc]" so that:
1. Test names in waive files and test lists remain unchanged
2. Test reports look cleaner
This runs BEFORE the global conftest applies waives (due to hookwrapper).
"""
for item in items:
if "test_fused_moe.py" in item.nodeid and "[" in item.nodeid:
# Only apply cleanup to specific tests that use enable_configurable_moe
should_cleanup = any(
test_name in item.nodeid for test_name in TESTS_WITH_CONFIGURABLE_MOE
)
if should_cleanup:
original_nodeid = item.nodeid
original_name = item.name
nodeid = item.nodeid
name = item.name
# Clean up leading/trailing dashes in nodeid
nodeid = nodeid.replace("[-", "[")
nodeid = nodeid.replace("-]", "]")
# Clean up leading/trailing dashes in name
name = name.replace("[-", "[")
name = name.replace("-]", "]")
if nodeid != original_nodeid:
item._nodeid = nodeid
if name != original_name:
item.name = name

File diff suppressed because it is too large

View File

@ -0,0 +1,927 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MoE Backend Unit Tests
This module provides a unified test framework for testing different MoE backends
through the backend-level interfaces (quantize_input + run_moe), rather than
the high-level forward() interface.
Design Goals:
1. Test backend interfaces directly: routing_method.apply -> quantize_input -> run_moe
2. Cover all quantization + backend combinations
3. Use can_implement() interface to determine test skip logic
4. Support autotune and tactic capture testing
"""
import itertools
import logging
import time
from dataclasses import dataclass
from enum import Enum
from typing import Callable, List, Optional, Type
import pytest
import torch
from _torch.modules.moe.quantize_utils import get_test_quant_params
from transformers.configuration_utils import PretrainedConfig
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.modules.fused_moe import RenormalizeMoeRoutingMethod
from tensorrt_llm._torch.modules.fused_moe.create_moe import create_moe_backend
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import CuteDslFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import DeepGemmFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_trtllm_gen import TRTLLMGenFusedMoE
from tensorrt_llm._torch.modules.fused_moe.interface import MoE
from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantAlgo
logger = logging.getLogger(__name__)
class MoeBackendType(str, Enum):
"""Enum for MoE backend types."""
CUTLASS = "CUTLASS"
TRTLLM = "TRTLLM"
CUTEDSL = "CUTEDSL"
DEEPGEMM = "DEEPGEMM"
def get_backend_class(backend_type: MoeBackendType) -> Type[MoE]:
"""Get the MoE backend class for a given backend type."""
backend_class_map = {
MoeBackendType.CUTLASS: CutlassFusedMoE,
MoeBackendType.TRTLLM: TRTLLMGenFusedMoE,
MoeBackendType.CUTEDSL: CuteDslFusedMoE,
MoeBackendType.DEEPGEMM: DeepGemmFusedMoE,
}
return backend_class_map[backend_type]
def should_skip_TRTLLM(
backend_type: MoeBackendType,
quant_algo: Optional[QuantAlgo],
model_config: "MoeModelConfig",
) -> Optional[str]:
"""
Check TRTLLM Gen backend specific constraints.
The TRTLLM Gen MoE kernels have hardware-level constraints that must be satisfied.
These constraints are enforced in the C++ layer.
Constraints:
1. num_experts must be divisible by 4 (routing kernel vectorization requirement)
2. num_experts must be greater than top_k (routing logic requirement)
Args:
backend_type: The MoE backend type
quant_algo: The quantization algorithm
model_config: The MoE model configuration
Returns:
Skip reason string if test should be skipped, None otherwise
"""
if backend_type != MoeBackendType.TRTLLM:
return None
if model_config is None:
return None
# These quantization algorithms use TRTLLM Gen kernels with the constraints
trtllm_gen_quant_algos = {
QuantAlgo.NVFP4,
QuantAlgo.FP8_BLOCK_SCALES,
QuantAlgo.W4A8_NVFP4_FP8,
QuantAlgo.W4A16_MXFP4,
QuantAlgo.W4A8_MXFP4_MXFP8,
}
if quant_algo not in trtllm_gen_quant_algos:
return None
num_experts = model_config.num_experts
top_k = model_config.top_k
intermediate_size = model_config.intermediate_size
# Check: num_experts must be divisible by 4
# Routing kernel uses vectorized operations that require this alignment
if num_experts % 4 != 0:
return (
f"TRTLLMGenFusedMoE routing kernel requires num_experts divisible by 4 "
f"(got num_experts={num_experts})"
)
# Check: num_experts must be greater than top_k
# Routing logic cannot handle the case where all experts are selected
if num_experts <= top_k:
return (
f"TRTLLMGenFusedMoE requires num_experts > top_k "
f"(got num_experts={num_experts}, top_k={top_k})"
)
# -----------------Potential issues------------------
# These are known issues that need investigation. Skipping to avoid test failures
# and CUDA errors that can cascade to subsequent tests.
# Issue 1: W4A8_NVFP4_FP8 with top_k=1 causes CUDA illegal memory access
# This triggers GPU state corruption that affects all subsequent tests.
# Affected config: e8_k1_h512_i512
if quant_algo == QuantAlgo.W4A8_NVFP4_FP8 and top_k == 1:
return (
"[Potential Bug] TRTLLMGenFusedMoE W4A8_NVFP4_FP8 with top_k=1 "
"causes CUDA illegal memory access. Needs kernel investigation."
)
# Issue 2: NVFP4 with large intermediate_size has known accuracy issues
# Observed mismatch: 18%~25% vs expected <7.5% (per test_moe.py baseline)
# Affected configs: e8_k2_h4096_i14336, e8_k2_h6144_i32768
if quant_algo == QuantAlgo.NVFP4 and intermediate_size >= 14336:
return (
f"[Potential Bug] TRTLLMGenFusedMoE NVFP4 with large intermediate_size "
f"has known accuracy issues (intermediate_size={intermediate_size} >= 14336). "
f"Observed mismatch 18%~25% exceeds expected threshold."
)
# Issue 3: W4A8_MXFP4_MXFP8 has accuracy issues on certain model configs
# Observed mismatch: 14%~18% vs expected <15% (percent=0.85)
# Affected configs: large intermediate_size or many experts
# e8_k2_h4096_i14336, e64_k6_h2048_i1408, e60_k4_h2048_i1408,
# e256_k8_h7168_i2048, e8_k2_h6144_i32768, e128_k4_h2880_i2880
if quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8:
# Large intermediate_size (>= 14336) has precision issues
if intermediate_size >= 14336:
return (
f"[Potential Bug] TRTLLMGenFusedMoE W4A8_MXFP4_MXFP8 with large "
f"intermediate_size has accuracy issues (intermediate_size={intermediate_size} >= 14336). "
f"Observed mismatch 14%~18% exceeds 15% threshold."
)
# Many experts (>= 60) with moderate intermediate_size has precision issues
if num_experts >= 60 and intermediate_size >= 1408:
return (
f"[Potential Bug] TRTLLMGenFusedMoE W4A8_MXFP4_MXFP8 with many experts "
f"has accuracy issues (num_experts={num_experts} >= 60, intermediate_size={intermediate_size}). "
f"Observed mismatch 14%~18% exceeds 15% threshold."
)
return None
def should_skip_CUTEDSL(
backend_type: MoeBackendType,
quant_algo: Optional[QuantAlgo],
model_config: "MoeModelConfig" = None,
) -> Optional[str]:
"""
Check CuteDSL backend specific constraints.
The CuteDSL MoE kernels have known accuracy issues with certain configurations.
Args:
backend_type: The MoE backend type
quant_algo: The quantization algorithm
model_config: The MoE model configuration
Returns:
Skip reason string if test should be skipped, None otherwise
"""
if backend_type != MoeBackendType.CUTEDSL:
return None
if model_config is None:
return None
intermediate_size = model_config.intermediate_size
# -----------------Potential issues------------------
# NVFP4 with large intermediate_size has known accuracy issues (same as TRTLLM)
# Observed mismatch: 8%~26% vs expected <2%
# Affected configs: e8_k2_h4096_i14336, e8_k2_h6144_i32768
if quant_algo == QuantAlgo.NVFP4 and intermediate_size >= 14336:
return (
f"[Potential Bug] CuteDslFusedMoE NVFP4 with large intermediate_size "
f"has known accuracy issues (intermediate_size={intermediate_size} >= 14336). "
f"Observed mismatch 8%~26% exceeds 2% threshold."
)
# NVFP4 with prime num_experts (7, 13) causes CUDA_ERROR_ILLEGAL_ADDRESS
# Root cause: Autotuner cache bucket mapping issue
# - When tests run in batch, previous tests cache tactics to buckets
# - Prime num_experts shapes map to same bucket as other configs
# - The cached tactic (e.g., ((128, 256), (1, 2), False)) works for other configs
# but causes illegal memory access for prime num_experts' actual shape
# - Single test run passes because fallback tactic ((128, 128), (1, 1), False) is used
# Affected configs: e7_k2_h256_i512, e13_k3_h256_i512
num_experts = model_config.num_experts
prime_experts_with_issues = {7, 13}
if quant_algo == QuantAlgo.NVFP4 and num_experts in prime_experts_with_issues:
return (
f"[Potential Bug] CuteDslFusedMoE NVFP4 with prime num_experts={num_experts} "
f"causes CUDA_ERROR_ILLEGAL_ADDRESS due to autotuner cache bucket mapping. "
f"Cached tactic from other configs is incompatible with this shape."
)
return None
def should_skip_gptoss(
backend_type: MoeBackendType,
quant_algo: Optional[QuantAlgo],
gptoss_style: bool,
) -> Optional[str]:
"""
Check if gptoss_style test should be skipped for this backend.
Only CUTLASS and TRTLLM backends support gptoss_style (SwiGlu with custom
alpha/beta/limit parameters and bias).
Args:
backend_type: The MoE backend type
quant_algo: The quantization algorithm
gptoss_style: Whether gptoss_style is enabled
Returns:
Skip reason string if test should be skipped, None otherwise
"""
if not gptoss_style:
return None
# Only CUTLASS and TRTLLM backends support gptoss_style
supported_backends = {MoeBackendType.CUTLASS, MoeBackendType.TRTLLM}
if backend_type not in supported_backends:
return (
f"gptoss_style is only supported by CUTLASS and TRTLLM backends "
f"(got backend_type={backend_type.value})"
)
return None
def supports_autotuner_capture(
backend_type: MoeBackendType,
quant_algo: Optional[QuantAlgo],
) -> bool:
"""
Determine if a backend+quant_algo combination supports AutoTuner capture/replay.
AutoTuner capture/replay requires AutoTuner.choose_one() to be called during
run_moe execution.
Args:
backend_type: The MoE backend type
quant_algo: The quantization algorithm (None for unquantized)
Returns:
True if autotuner capture/replay is supported, False otherwise
"""
# DEEPGEMM does not support autotuner capture
# Evidence: fused_moe_deepgemm.py has no AutoTuner/choose_one references
if backend_type == MoeBackendType.DEEPGEMM:
return False
return True
def create_test_backend(
backend_type: MoeBackendType,
routing_method: RenormalizeMoeRoutingMethod,
num_experts: int,
hidden_size: int,
intermediate_size: int,
dtype: torch.dtype,
quant_config,
mapping: Mapping,
bias: bool = False,
swiglu_alpha: Optional[torch.Tensor] = None,
swiglu_beta: Optional[torch.Tensor] = None,
swiglu_limit: Optional[torch.Tensor] = None,
) -> MoE:
"""Create a MoE backend for testing."""
backend_cls = get_backend_class(backend_type)
pretrained_config = PretrainedConfig()
pretrained_config.num_experts = num_experts
pretrained_config.hidden_size = hidden_size
pretrained_config.intermediate_size = intermediate_size
pretrained_config.torch_dtype = dtype
model_config = ModelConfig(
pretrained_config=pretrained_config,
quant_config=quant_config,
mapping=mapping,
moe_backend=backend_type.value,
)
return create_moe_backend(
moe_cls=backend_cls,
routing_method=routing_method,
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
dtype=dtype,
reduce_results=True,
model_config=model_config,
init_load_balancer=False,
bias=bias,
swiglu_alpha=swiglu_alpha,
swiglu_beta=swiglu_beta,
swiglu_limit=swiglu_limit,
)
def run_backend_moe(
backend: MoE,
backend_type: MoeBackendType,
x_quantized: torch.Tensor,
x_sf: torch.Tensor,
token_selected_experts: torch.Tensor,
token_final_scales: torch.Tensor,
dtype: torch.dtype,
router_logits: torch.Tensor = None,
trtllm_use_router_logits: bool = True,
) -> torch.Tensor:
"""
Run MoE computation with backend-specific parameters.
Each backend has different requirements:
- CUTLASS: output_dtype, token_final_scales=float32
- TRTLLM: token_final_scales=bfloat16, optionally router_logits
- CUTEDSL: token_final_scales=float32
- DEEPGEMM: workspace, token_final_scales=float32
Args:
trtllm_use_router_logits: If True, TRTLLM backend uses router_logits for routing.
If False, uses token_selected_experts and token_final_scales.
Note: When both are provided, TRTLLM uses only the topk ids and weights (token_selected_experts and token_final_scales).
"""
# Common args for all backends (default: token_final_scales=float32)
args = dict(
x=x_quantized,
token_selected_experts=token_selected_experts.to(torch.int32),
token_final_scales=token_final_scales.to(torch.float32),
x_sf=x_sf,
)
# Backend-specific overrides
if backend_type == MoeBackendType.CUTLASS:
args["output_dtype"] = dtype
elif backend_type == MoeBackendType.TRTLLM:
args["token_final_scales"] = token_final_scales.to(torch.bfloat16)
if trtllm_use_router_logits:
# Use router_logits for routing (TRTLLM will compute topk internally)
args["router_logits"] = router_logits
args["token_selected_experts"] = None
args["token_final_scales"] = None
# else: use token_selected_experts and token_final_scales (already set)
elif backend_type == MoeBackendType.DEEPGEMM:
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
m_max = fp8_utils.align(x_quantized.shape[0], 128)
args["workspace"] = backend.get_workspace(m_max, 128)
return backend.run_moe(**args)
def replay_tactics_and_check(
all_tactics,
run_moe_fn: Callable[[], torch.Tensor],
check_accuracy_fn: Callable[[torch.Tensor, torch.Tensor], None],
ref_output: torch.Tensor,
backend_type: MoeBackendType,
quant_algo: Optional[QuantAlgo],
fail_fast: bool = False,
) -> None:
"""
Replay all tactics and check accuracy.
Args:
all_tactics: TacticsCapture object from AutoTuner.capture()
run_moe_fn: Function to run MoE computation
check_accuracy_fn: Function to check accuracy (output, ref_output) -> None
ref_output: Reference output tensor
backend_type: Backend type for error reporting
quant_algo: Quantization algorithm for error reporting
fail_fast: If True, fail on first error. If False, run all and report summary.
"""
tactics_list = list(all_tactics)
passed_tactics = []
failed_tactics = []
logger.info(f"Replay tactics : {len(tactics_list)} and check accuracy")
for idx, tactic in enumerate(tactics_list):
with AutoTuner.get().replay(tactic), torch.inference_mode():
output = run_moe_fn()
try:
check_accuracy_fn(output, ref_output)
passed_tactics.append((idx, tactic))
except Exception as e:
if fail_fast:
pytest.fail(
f"Accuracy check failed for tactic[{idx}/{len(tactics_list)}]={tactic}, "
f"backend={backend_type}, quant_algo={quant_algo}: {e}"
)
failed_tactics.append((idx, tactic, str(e)))
# Report results (only when fail_fast=False)
total = len(tactics_list)
num_passed = len(passed_tactics)
num_failed = len(failed_tactics)
if failed_tactics:
fail_details = "\n".join(
f" tactic[{idx}]={tactic}: {err}" for idx, tactic, err in failed_tactics
)
pytest.fail(
f"backend={backend_type}, quant_algo={quant_algo}: "
f"{num_passed}/{total} passed, {num_failed}/{total} failed\n"
f"Failed tactics:\n{fail_details}"
)
# ============================================================================
# Test Parameters
# ============================================================================
# Quantization algorithms to test
QUANT_ALGOS_TO_TEST = [
None, # Unquantized
QuantAlgo.FP8,
QuantAlgo.NVFP4,
QuantAlgo.FP8_BLOCK_SCALES,
QuantAlgo.W4A8_NVFP4_FP8,
QuantAlgo.W4A16_MXFP4,
QuantAlgo.W4A8_MXFP4_MXFP8,
QuantAlgo.W8A16,
QuantAlgo.W4A8_AWQ,
]
# Backend types to test
BACKEND_TYPES_TO_TEST = [
MoeBackendType.CUTLASS,
MoeBackendType.TRTLLM,
MoeBackendType.CUTEDSL,
MoeBackendType.DEEPGEMM,
]
# Data types to test
DTYPES_TO_TEST = [
torch.float16,
torch.bfloat16,
]
# ============================================================================
# Model MoE Configurations
# ============================================================================
@dataclass
class MoeModelConfig:
"""MoE model configuration: (num_experts, top_k, hidden_size, intermediate_size)."""
num_experts: int
top_k: int
hidden_size: int
intermediate_size: int
def __str__(self) -> str:
return f"e{self.num_experts}_k{self.top_k}_h{self.hidden_size}_i{self.intermediate_size}"
# Format: (num_experts, top_k, hidden_size, intermediate_size)
MOE_MODEL_CONFIGS = [
# === Real Model Configs ===
MoeModelConfig(8, 2, 4096, 14336), # Mixtral-8x7B
MoeModelConfig(64, 6, 2048, 1408), # DeepSeek-MoE-16B / DeepSeek-V2-Lite
MoeModelConfig(60, 4, 2048, 1408), # Qwen1.5-MoE-A2.7B
MoeModelConfig(256, 8, 7168, 2048), # DeepSeek-V3
MoeModelConfig(8, 2, 6144, 32768), # Grok-1
MoeModelConfig(128, 4, 2880, 2880), # GPT-OSS-120B
# === Boundary Tests: num_experts / top_k ===
MoeModelConfig(8, 1, 512, 512), # top_k=1, single expert activated
MoeModelConfig(4, 4, 512, 512), # top_k=num_experts, all experts activated
MoeModelConfig(7, 2, 256, 512), # prime num_experts
MoeModelConfig(13, 3, 256, 512), # prime num_experts, odd top_k
# === Boundary Tests: small sizes ===
MoeModelConfig(4, 2, 64, 128), # very small hidden_size
MoeModelConfig(4, 2, 128, 64), # intermediate < hidden
]
# Sequence lengths to test
SEQ_LENS_TO_TEST = [1, 8]
# SwiGLU parameters for gptoss_style testing
SWIGLU_ALPHAS = [1, 0.1]
SWIGLU_BETAS = [0, 1]
SWIGLU_LIMITS = [float("inf"), 1]
# ============================================================================
# Fast Skip Check (for parametrize-level skip, avoids entering test function)
# ============================================================================
def get_quick_skip_reason(
backend_type: MoeBackendType,
quant_algo: Optional[QuantAlgo],
dtype: torch.dtype,
model_config: "MoeModelConfig",
gptoss_style: bool,
) -> Optional[str]:
"""
Fast skip check that calls backend's can_implement() method.
This function calls the backend's can_implement() classmethod to check
dtype/quant_algo/gptoss_style support, then uses should_skip_* functions
for additional model_config specific checks.
Note: Logging is temporarily suppressed to avoid excessive warning output
during test parameter generation.
Returns:
Skip reason string if test should be skipped, None otherwise
"""
import logging as _logging
# Suppress logger warnings during parameter generation to avoid excessive output
trtllm_logger = _logging.getLogger("tensorrt_llm")
original_level = trtllm_logger.level
trtllm_logger.setLevel(_logging.ERROR)
try:
# ===== Call backend's can_implement for dtype/quant_algo/gptoss_style checks =====
backend_cls = get_backend_class(backend_type)
can_impl, skip_reason = backend_cls.can_implement(
quant_algo, dtype_activation=dtype, gptoss_style=gptoss_style
)
if not can_impl:
return skip_reason
# ===== Additional model_config specific checks =====
# TRTLLM: num_experts constraints and accuracy issues
skip_reason = should_skip_TRTLLM(backend_type, quant_algo, model_config)
if skip_reason:
return skip_reason
# CUTEDSL: accuracy issues with specific configs
skip_reason = should_skip_CUTEDSL(backend_type, quant_algo, model_config)
if skip_reason:
return skip_reason
# DEEPGEMM: float16 reference module constraint
if backend_type == MoeBackendType.DEEPGEMM and dtype == torch.float16:
return "DeepGemmFusedMoE reference module (FP8BlockScalesLinearMethod) requires bfloat16 input"
# 128-alignment requirement for quantization
if quant_algo is not None:
hidden_size = model_config.hidden_size
intermediate_size = model_config.intermediate_size
is_hidden_128_aligned = hidden_size % 128 == 0
is_intermediate_128_aligned = intermediate_size % 128 == 0
if not is_hidden_128_aligned or not is_intermediate_128_aligned:
# TRTLLM with MXFP4 variants automatically pads to 128 alignment
is_mxfp4_variant = quant_algo in {QuantAlgo.W4A16_MXFP4, QuantAlgo.W4A8_MXFP4_MXFP8}
is_trtllm_backend = backend_type == MoeBackendType.TRTLLM
if not (is_trtllm_backend and is_mxfp4_variant):
return (
f"Non-128-aligned sizes (h={hidden_size}, i={intermediate_size}) "
f"require TRTLLM backend with MXFP4 quantization"
)
return None
finally:
# Restore logger level
trtllm_logger.setLevel(original_level)
def generate_test_params() -> List:
"""
Generate all test parameter combinations with skip marks for invalid combinations.
This function pre-computes skip decisions at collection time using static rules,
avoiding the overhead of entering test functions and calling can_implement().
This significantly speeds up test collection and skip execution.
Returns:
List of pytest.param objects with appropriate skip marks
"""
params: List = []
# Generate all combinations
swiglu_combos = list(itertools.product(SWIGLU_ALPHAS, SWIGLU_BETAS, SWIGLU_LIMITS))
for swiglu_alpha, swiglu_beta, swiglu_limit in swiglu_combos:
for model_config in MOE_MODEL_CONFIGS:
for seq_len in SEQ_LENS_TO_TEST:
for dtype in DTYPES_TO_TEST:
for backend_type in BACKEND_TYPES_TO_TEST:
for quant_algo in QUANT_ALGOS_TO_TEST:
# Determine gptoss_style
gptoss_style = (
swiglu_alpha != 1
or swiglu_beta != 0
or swiglu_limit != float("inf")
)
# Generate test ID
test_id = (
f"alpha={swiglu_alpha}_beta={swiglu_beta}_limit={swiglu_limit}-"
f"{model_config}-seq={seq_len}-dtype={dtype}-"
f"backend={backend_type.value}-quant_algo={quant_algo}"
)
# Check if should skip
skip_reason = get_quick_skip_reason(
backend_type, quant_algo, dtype, model_config, gptoss_style
)
param_values = (
dtype,
backend_type,
quant_algo,
seq_len,
model_config,
swiglu_alpha,
swiglu_beta,
swiglu_limit,
)
if skip_reason:
params.append(
pytest.param(
*param_values,
id=test_id,
marks=pytest.mark.skip(reason=skip_reason),
)
)
else:
params.append(pytest.param(*param_values, id=test_id))
return params
# Pre-generate test parameters at module load time
TEST_PARAMS = generate_test_params()
# ============================================================================
# Timing Fixtures
# ============================================================================
@pytest.fixture(scope="module", autouse=True)
def module_timer(request):
"""Fixture to measure and log total module execution time."""
start = time.perf_counter()
yield
elapsed = time.perf_counter() - start
logger.info(
"[TIMING] Total %s: %.3fs (%.2f min)",
request.module.__name__,
elapsed,
elapsed / 60,
)
# ============================================================================
# Test Implementation
# ============================================================================
#
# This file provides a UNIFIED TEST FRAMEWORK for testing all MoE backend
# implementations through their backend-level interfaces.
#
# =============================================================================
# Purpose & Scope
# =============================================================================
# - Test MoE backends via: routing_method.apply -> quantize_input -> run_moe
# - Single GPU execution (no multi-GPU/distributed testing)
# - Accuracy validation against reference implementations
#
# =============================================================================
# Test Coverage Matrix
# =============================================================================
# 1. BACKENDS: CUTLASS, TRTLLM, CUTEDSL, DEEPGEMM
#
# 2. QUANTIZATION ALGORITHMS:
# - Unquantized (None)
# - FP8, FP8_BLOCK_SCALES
# - NVFP4, W4A8_NVFP4_FP8
# - W4A16_MXFP4, W4A8_MXFP4_MXFP8
# - W8A16, W4A8_AWQ
#
# 3. ACTIVATION DTYPES: float16, bfloat16
#
# 4. AUTOTUNER TACTICS:
# - Autotune phase: find optimal tactics via AutoTuner
# - Capture phase: record all tactics used
# - Replay phase: verify each tactic produces correct results
#
# 5. GPTOSS_STYLE (SwiGLU with custom parameters):
# - swiglu_alpha: scaling factor (default=1)
# - swiglu_beta: bias term (default=0)
# - swiglu_limit: clipping limit (default=inf)
# - Supported by: CUTLASS (W4A8_MXFP4_MXFP8), TRTLLM (W4A8_MXFP4_MXFP8)
#
# 6. MODEL CONFIGURATIONS:
# - Real models: Mixtral, DeepSeek, Qwen, Grok, GPT-OSS
# - Boundary cases: prime num_experts, small sizes, top_k=1, top_k=num_experts
#
# =============================================================================
# Skip Logic
# =============================================================================
# Tests are automatically skipped for unsupported configurations using:
# - backend.can_implement(): Check dtype/quant_algo/gptoss_style support
# - should_skip_TRTLLM(): TRTLLM-specific constraints (num_experts % 4, etc.)
# - should_skip_CUTEDSL(): CuteDSL-specific accuracy issues
# - 128-alignment requirements for quantization
#
# =============================================================================
@pytest.mark.skip(reason="Temporarily skipped due to long runtime")
@pytest.mark.parametrize(
"dtype_activation,backend_type,quant_algo,seq_len,model_config,swiglu_alpha,swiglu_beta,swiglu_limit",
TEST_PARAMS,
)
def test_moe_backend(
dtype_activation: torch.dtype,
backend_type: MoeBackendType,
quant_algo: Optional[QuantAlgo],
seq_len: int,
model_config: MoeModelConfig,
swiglu_alpha: float,
swiglu_beta: float,
swiglu_limit: float,
):
"""
Test MoE backend with autotune to capture all tactics.
This test verifies:
1. Autotune works correctly with the backend
2. All tactics are captured properly
3. Different sequence lengths use appropriate tactics
4. gptoss_style (SwiGlu with custom parameters) works correctly
"""
# Determine gptoss_style based on swiglu parameters
# gptoss_style is True when any swiglu parameter deviates from default
# Default values: alpha=1, beta=0, limit=inf
gptoss_style = swiglu_alpha != 1 or swiglu_beta != 0 or swiglu_limit != float("inf")
# Note: Skip logic is now handled at parametrize level via get_quick_skip_reason()
# which calls backend's can_implement() and should_skip_* functions.
# This avoids entering the test function for invalid combinations, significantly
# reducing test collection time (from ~17 min to ~5 sec for 3400+ skipped tests).
# Extract model parameters
num_experts = model_config.num_experts
top_k = model_config.top_k
hidden_size = model_config.hidden_size
intermediate_size = model_config.intermediate_size
# Create mapping
mapping = Mapping()
mapping.rank = mpi_rank()
with torch.device(f"cuda:{mapping.rank}"):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# Setup autotuner distributed state
AutoTuner.get().setup_distributed_state(mapping)
# Create routing method
routing_method = RenormalizeMoeRoutingMethod(top_k=top_k)
# Create test inputs
x = torch.randn((seq_len, hidden_size), dtype=dtype_activation, device="cuda")
router_logits = torch.randn((seq_len, num_experts), dtype=dtype_activation, device="cuda")
# Get quantization parameters
# Pass backend_type to determine scale format (DEEPGEMM/TRTLLM need E8M0 scale)
quantize_util_cls, quant_config, quant_kwargs = get_test_quant_params(
quant_algo, x, backend_type
)
# Create quantize utility with gptoss_style parameters
quantize_util = quantize_util_cls(
num_experts=num_experts,
dtype=dtype_activation,
intermediate_size=intermediate_size,
hidden_size=hidden_size,
quant_config=quant_config,
bias=gptoss_style,
gptoss_style=gptoss_style,
swiglu_alpha=swiglu_alpha if gptoss_style else None,
swiglu_beta=swiglu_beta if gptoss_style else None,
swiglu_limit=swiglu_limit if gptoss_style else None,
)
# Get swiglu tensors if gptoss_style is enabled
swiglu_tensors = quantize_util.get_swiglu_tensors()
# Create backend first (needed for MXFP4_MXFP8 to get shapes)
backend = create_test_backend(
backend_type=backend_type,
routing_method=routing_method,
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
dtype=dtype_activation,
quant_config=quant_config,
mapping=mapping,
bias=gptoss_style,
swiglu_alpha=swiglu_tensors["swiglu_alpha"] if swiglu_tensors else None,
swiglu_beta=swiglu_tensors["swiglu_beta"] if swiglu_tensors else None,
swiglu_limit=swiglu_tensors["swiglu_limit"] if swiglu_tensors else None,
)
# W4A8_MXFP4_MXFP8 requires different weights for backend and reference
# due to different padding/alignment requirements
ref_cls = quant_kwargs.pop("ref_cls", None)
ref_module_kwargs = {}
if quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8:
weights, ref_weights, ref_module_kwargs = quantize_util.prepare_weights_from_backend(
backend, **quant_kwargs
)
else:
weights = quantize_util.create_weights(**quant_kwargs)
ref_weights = weights
backend.load_weights([weights])
backend.post_load_weights()
backend.cuda()
# Create reference
if ref_cls is not None:
ref_fused_moe = quantize_util.create_ref_module(
routing_method, ref_cls=ref_cls, **ref_module_kwargs
)
else:
ref_fused_moe = quantize_util.create_ref_module(routing_method, **ref_module_kwargs)
ref_fused_moe.load_weights([ref_weights])
ref_fused_moe.cuda()
# Clear autotuner cache before autotune phase
AutoTuner.get().clear_cache()
# Get reference output first
with torch.inference_mode():
ref_output = ref_fused_moe.forward(x, router_logits)
# Helper to run MoE computation
def run_moe():
token_selected_experts, token_final_scales = routing_method.apply(router_logits)
x_quantized, x_sf = backend.quantize_input(x, post_quant_comm=False)
return run_backend_moe(
backend,
backend_type,
x_quantized,
x_sf,
token_selected_experts,
token_final_scales,
dtype_activation,
router_logits,
)
# Configure AutoTuner for faster profiling (reduce warmup/repeat for unit tests)
autotuner = AutoTuner.get()
autotuner.warmup = 0 # default: 2
autotuner.repeat = 1 # default: 10
autotuner.stream_delay_micro_secs = 10 # default: 1000
# Autotune phase: tune kernels to find best tactics
# Use cache_path to speed up subsequent runs by reusing tuning results
with torch.inference_mode(), autotune(cache_path="/tmp/moe_autotuner_cache.json"):
_ = run_moe()
# Check if this backend+quant_algo combination supports autotuner capture/replay
if supports_autotuner_capture(backend_type, quant_algo):
# Capture phase: record which tactics are used (requires actual execution)
with AutoTuner.get().capture() as all_tactics, torch.inference_mode():
_ = run_moe()
# Replay phase: test each tactic for correctness
# Set fail_fast=True to stop on first failure, False to run all and report summary
replay_tactics_and_check(
all_tactics=all_tactics,
run_moe_fn=run_moe,
check_accuracy_fn=ref_fused_moe.check_accuracy,
ref_output=ref_output,
backend_type=backend_type,
quant_algo=quant_algo,
fail_fast=False, # Change to True to fail on first error
)
else:
# For backends that don't support autotuner capture/replay,
# just run a simple accuracy check
with torch.inference_mode():
output = run_moe()
ref_fused_moe.check_accuracy(output, ref_output)

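The collection-time skip logic above can also be exercised interactively. A hedged sketch follows, where test_moe_backend is the assumed module name of this new test file (adjust the import to its actual location in the tests tree):

import torch

from test_moe_backend import (MoeBackendType, MoeModelConfig,
                              get_quick_skip_reason)

from tensorrt_llm.models.modeling_utils import QuantAlgo

# Mixtral-8x7B-like shape from MOE_MODEL_CONFIGS.
cfg = MoeModelConfig(num_experts=8, top_k=2, hidden_size=4096,
                     intermediate_size=14336)
reason = get_quick_skip_reason(
    backend_type=MoeBackendType.TRTLLM,
    quant_algo=QuantAlgo.NVFP4,
    dtype=torch.bfloat16,
    model_config=cfg,
    gptoss_style=False,
)
print(reason or "would run")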
View File

@ -864,23 +864,13 @@ def test_fused_moe_fp8_blockwise_wide_ep(alltoall_method_type):
[DefaultMoeRoutingMethod],
),
)
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
ids=lambda x: ""
if x == 0 else "enable_configurable_moe")
def test_fused_moe_fp8_blockwise_deepgemm(dtype,
num_experts,
seq_len,
hidden_size,
RoutingMethodCls,
enable_configurable_moe,
mocker,
mapping=None):
mocker.patch.dict(os.environ, {
"ENABLE_CONFIGURABLE_MOE":
"1" if enable_configurable_moe == 1 else "0"
})
SEQ_LEN = seq_len
HIDDEN_SIZE = hidden_size
INTERMEDIATE_SIZE = 256
@ -1388,25 +1378,7 @@ def test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu(ep_size, routing_method,
@pytest.mark.parametrize(
"finalize_fusion", [True, False],
ids=["enable_finalize_fusion", "disable_finalize_fusion"])
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
ids=lambda x: ""
if x == 0 else "enable_configurable_moe")
def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
enable_configurable_moe, mocker):
if enable_configurable_moe == 1 and moe_backend not in [
"TRTLLM", "CUTLASS"
]:
pytest.skip(
"ENABLE_CONFIGURABLE_MOE=1, only TRTLLM and CUTLASS backend are enabled"
)
mocker.patch.dict(
os.environ, {
"ENABLE_CONFIGURABLE_MOE":
"1" if enable_configurable_moe == 1
and moe_backend in ["TRTLLM", "CUTLASS"] else "0"
})
def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion):
run_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion)
@ -1417,17 +1389,8 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
@pytest.mark.parametrize("swiglu_beta", [0, 1], ids=lambda v: f"beta{v}")
@pytest.mark.parametrize("swiglu_limit", [float("inf"), 1],
ids=lambda v: f"limit{v}")
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
ids=lambda x: ""
if x == 0 else "enable_configurable_moe")
def test_fused_moe_nvfp4_gptoss_style(hidden_size, intermediate_size,
swiglu_alpha, swiglu_beta, swiglu_limit,
enable_configurable_moe, mocker):
mocker.patch.dict(os.environ, {
"ENABLE_CONFIGURABLE_MOE":
"1" if enable_configurable_moe == 1 else "0"
})
swiglu_alpha, swiglu_beta, swiglu_limit):
run_fused_moe_nvfp4(dtype=torch.bfloat16,
moe_backend="TRTLLM",
finalize_fusion=False,
@ -1686,15 +1649,7 @@ def run_fused_moe_nvfp4(dtype,
@pytest.mark.parametrize(
"moe_backend",
[pytest.param("TRTLLM", marks=skip_blackwell_geforce), "CUTLASS"])
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
ids=lambda x: ""
if x == 0 else "enable_configurable_moe")
def test_fused_moe_w4a8_nvfp4_fp8(moe_backend, enable_configurable_moe, mocker):
mocker.patch.dict(os.environ, {
"ENABLE_CONFIGURABLE_MOE":
"1" if enable_configurable_moe == 1 else "0"
})
def test_fused_moe_w4a8_nvfp4_fp8(moe_backend):
dtype = torch.bfloat16
mapping = Mapping()
mapping.rank = mpi_rank()
@ -2109,20 +2064,7 @@ def test_fused_moe_w4afp8(dtype, weight_loading_mode):
@pytest.mark.parametrize("hidden_unpadded", [64, 192, 256])
@pytest.mark.parametrize("seq_len", [8, 128])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
ids=lambda x: ""
if x == 0 else "enable_configurable_moe")
def test_fused_moe_mxfp4_mxfp8(moe_backend, hidden_unpadded, seq_len, bias,
enable_configurable_moe, mocker):
mocker.patch.dict(os.environ, {
"ENABLE_CONFIGURABLE_MOE":
"1" if enable_configurable_moe == 1 else "0"
})
if moe_backend == "CUTLASS" and hidden_unpadded % 128 != 0:
pytest.skip()
def test_fused_moe_mxfp4_mxfp8(moe_backend, hidden_unpadded, seq_len, bias):
SCALING_VECTOR_SIZE = 32
dtype = torch.bfloat16
SEQ_LEN = seq_len
@ -2379,17 +2321,7 @@ def test_fused_moe_mxfp4_mxfp8(moe_backend, hidden_unpadded, seq_len, bias,
marks=[skip_pre_hopper, skip_blackwell, skip_blackwell_geforce]),
],
)
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
ids=lambda x: ""
if x == 0 else "enable_configurable_moe")
def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend,
enable_configurable_moe, mocker):
mocker.patch.dict(os.environ, {
"ENABLE_CONFIGURABLE_MOE":
"1" if enable_configurable_moe == 1 else "0"
})
def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend):
mapping = Mapping()
mapping.rank = mpi_rank()