Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-16 07:53:55 +08:00

[TRTLLM-9111][feat] provide the uniform test framework to test all MoE backends (#11128)
Signed-off-by: xxi <xxi@nvidia.com>

Parent: de6931bbfd
Commit: 02b80bfd58
@@ -95,6 +95,35 @@ class ConfigurableMoE(MoE):
    - Communication: Auto-selected based on hardware (NVLINK > DeepEP > AllGather)
    """

    @classmethod
    def can_implement(
        cls,
        quant_algo,
        dtype_activation: torch.dtype = torch.bfloat16,
        gptoss_style: bool = False,
    ):
        """
        ConfigurableMoE is a wrapper class that delegates to specific backends.

        To check capability, query the specific backend class directly:
        - CutlassFusedMoE.can_implement(quant_algo, dtype_activation, gptoss_style)
        - TRTLLMGenFusedMoE.can_implement(quant_algo, dtype_activation, gptoss_style)
        - etc.

        Args:
            quant_algo: The quantization algorithm to check (None for unquantized)
            dtype_activation: The activation data type
            gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled

        Returns:
            Tuple[bool, Optional[str]]: Always returns (False, reason)
        """
        del quant_algo, dtype_activation, gptoss_style  # Unused - wrapper class
        return False, (
            "ConfigurableMoE is a wrapper class. "
            "Query the specific backend (CutlassFusedMoE, TRTLLMGenFusedMoE, etc.) directly."
        )

    def __init__(
        self,
        *,
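For orientation, a caller that wants a concrete backend probes the specific classes rather than the wrapper. A minimal sketch, assuming only the backend classes and import paths that appear elsewhere in this diff:

# Sketch: pick the first backend class that can implement a configuration.
import torch

from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_trtllm_gen import TRTLLMGenFusedMoE
from tensorrt_llm.models.modeling_utils import QuantAlgo


def pick_backend(quant_algo, dtype=torch.bfloat16, gptoss_style=False):
    for backend_cls in (CutlassFusedMoE, TRTLLMGenFusedMoE):
        ok, _reason = backend_cls.can_implement(
            quant_algo, dtype_activation=dtype, gptoss_style=gptoss_style)
        if ok:
            return backend_cls
    return None  # no capable backend on this platform/quantization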
@@ -4,7 +4,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F

-from tensorrt_llm._utils import is_sm_100f
+from tensorrt_llm._utils import get_sm_version, is_sm_100f
from tensorrt_llm.models.modeling_utils import QuantAlgo

from ...autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
                          OptimizationProfile, TunableRunner, TuningConfig)
@@ -312,6 +313,69 @@ class CuteDslFusedMoE(CutlassFusedMoE):
        model_config (ModelConfig): Configuration object for the model.
    """

    @classmethod
    def can_implement(
        cls,
        quant_algo: Optional[QuantAlgo],
        dtype_activation: torch.dtype = torch.bfloat16,
        gptoss_style: bool = False,
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if CuteDslFusedMoE can implement the given quantization algorithm.

        CuteDslFusedMoE supports:
        - NVFP4: SM in {100, 103}

        Does NOT support unquantized mode. Output dtype is hardcoded to bfloat16.
        Does NOT support gptoss_style (bias/swiglu with custom alpha/beta/limit).

        Args:
            quant_algo: The quantization algorithm to check (None for unquantized)
            dtype_activation: The activation input data type. Only bfloat16 is supported
                because output dtype is hardcoded to bfloat16 (input/output dtype must match).
            gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled.
                CuteDslFusedMoE does NOT support gptoss_style.

        Returns:
            Tuple[bool, Optional[str]]: (can_implement, skip_reason)
        """
        from .interface import _warn_and_return

        sm_version = get_sm_version()

        # CuteDslFusedMoE requires at least SM90
        if sm_version < 90:
            return _warn_and_return(
                f"CuteDslFusedMoE requires SM >= 90, got SM{sm_version}")

        # Check dtype_activation: output is hardcoded to bfloat16, so input must also be bfloat16
        # to maintain input/output dtype consistency
        if dtype_activation != torch.bfloat16:
            return _warn_and_return(
                f"CuteDslFusedMoE only supports bfloat16 activation (output is hardcoded to bfloat16), "
                f"got {dtype_activation}")

        # CuteDslFusedMoE does NOT support unquantized mode
        if quant_algo is None:
            return _warn_and_return(
                "CuteDslFusedMoE does not support unquantized mode")

        # CuteDslFusedMoE does NOT support gptoss_style
        if gptoss_style:
            return _warn_and_return(
                "CuteDslFusedMoE does not support gptoss_style (bias/swiglu with custom alpha/beta/limit)"
            )

        # NVFP4 - SM in {100, 103}
        if quant_algo == QuantAlgo.NVFP4:
            if sm_version not in {100, 103}:
                return _warn_and_return(
                    f"NVFP4 requires SM100 or SM103, got SM{sm_version}")
            return True, None

        return _warn_and_return(
            f"CuteDslFusedMoE does not support quant_algo={quant_algo}")

    def __init__(
        self,
        *,
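To make the contract concrete, on an SM100 machine with default bfloat16 activations the checks above behave as follows (a sketch of expected results, derived directly from the branches in this hunk):

# On SM100, NVFP4 is accepted; unquantized mode is rejected with a reason.
ok, reason = CuteDslFusedMoE.can_implement(QuantAlgo.NVFP4)
assert ok and reason is None
ok, reason = CuteDslFusedMoE.can_implement(None)
assert not ok and "unquantized" in reason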
@@ -7,7 +7,9 @@ import torch

from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
from tensorrt_llm._torch.distributed.moe_alltoall import MoeAlltoAll
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantAlgo
from tensorrt_llm.tools.layer_wise_benchmarks import get_calibrator

from ...distributed import allgather
@@ -57,6 +59,151 @@ class CutlassFusedMoE(MoE):
        equals to: dynamic quant + routing(topK, etc.) [+ fp4_allgather] + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute [no allreduce] + reducescatter
    """

    # Quantization algorithm support table for can_implement()
    # Format: quant_algo -> {sm_constraint, dtypes}
    # sm_constraint types:
    #   - ("min", N): SM >= N
    #   - ("exact", N): SM == N
    #   - ("in", {N1, N2, ...}): SM in set
    _QUANT_SUPPORT_TABLE = {
        # Unquantized (FP16/BF16): SM >= 80
        None: {
            "sm_constraint": ("min", 80),
            "dtypes": {torch.float16, torch.bfloat16},
        },
        # FP8 per-tensor (QDQ): SM >= 89
        QuantAlgo.FP8: {
            "sm_constraint": ("min", 89),
            "dtypes": {torch.float16, torch.bfloat16, torch.float32},
        },
        # FP8_BLOCK_SCALES: SM == 90 only
        QuantAlgo.FP8_BLOCK_SCALES: {
            "sm_constraint": ("exact", 90),
            "dtypes": {torch.float16, torch.bfloat16, torch.float32},
        },
        # NVFP4: SM in {100, 103}
        QuantAlgo.NVFP4: {
            "sm_constraint": ("in", {100, 103}),
            "dtypes": {torch.float16, torch.bfloat16, torch.float8_e4m3fn},
        },
        # W4A8_AWQ: SM in {89, 90} only
        QuantAlgo.W4A8_AWQ: {
            "sm_constraint": ("in", {89, 90}),
            "dtypes": {torch.float16, torch.bfloat16},
        },
        # W8A16: SM >= 80
        QuantAlgo.W8A16: {
            "sm_constraint": ("min", 80),
            "dtypes": {torch.float16, torch.bfloat16},
        },
        # W4A16_MXFP4: SM == 90 only
        QuantAlgo.W4A16_MXFP4: {
            "sm_constraint": ("exact", 90),
            "dtypes": {torch.float16, torch.bfloat16},
        },
        # W4A8_MXFP4_FP8: SM in {100, 103}
        QuantAlgo.W4A8_MXFP4_FP8: {
            "sm_constraint": ("in", {100, 103}),
            "dtypes": {torch.float16, torch.bfloat16, torch.float32},
        },
        # W4A8_MXFP4_MXFP8: SM in {100, 103}
        QuantAlgo.W4A8_MXFP4_MXFP8: {
            "sm_constraint": ("in", {100, 103}),
            "dtypes": {torch.float16, torch.bfloat16},
        },
    }

    # Quantization algorithms that support gptoss_style
    _GPTOSS_SUPPORTED_ALGOS = {QuantAlgo.W4A8_MXFP4_MXFP8}

    @classmethod
    def can_implement(
        cls,
        quant_algo: Optional[QuantAlgo],
        dtype_activation: torch.dtype = torch.bfloat16,
        gptoss_style: bool = False,
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if CutlassFusedMoE can implement the given quantization algorithm.

        CutlassFusedMoE supports:
        - Unquantized (FP16/BF16): SM >= 80
        - FP8 per-tensor (QDQ): SM >= 89
        - FP8_BLOCK_SCALES: SM == 90 only
        - NVFP4: SM in {100, 103}
        - W4A8_AWQ: SM in {89, 90} only
        - W8A16: SM >= 80
        - W4A16_MXFP4: SM == 90 only
        - W4A8_MXFP4_FP8: SM in {100, 103}
        - W4A8_MXFP4_MXFP8: SM in {100, 103}

        Args:
            quant_algo: The quantization algorithm to check (None for unquantized)
            dtype_activation: The activation input data type (before quantization).
                Supported dtypes vary by quantization mode:
                - Unquantized: float16, bfloat16
                - FP8/FP8_BLOCK_SCALES/W4A8_MXFP4_FP8: float16, bfloat16, float32
                - NVFP4: float16, bfloat16, float8_e4m3fn
                - W4A16_MXFP4/W4A8_AWQ/W8A16/W4A8_MXFP4_MXFP8: float16, bfloat16
            gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled.
                CutlassFusedMoE only supports gptoss_style for W4A8_MXFP4_MXFP8 quantization.

        Returns:
            Tuple[bool, Optional[str]]: (can_implement, skip_reason)
        """
        from .interface import _warn_and_return

        sm_version = get_sm_version()

        # Check minimum SM version for Cutlass backend
        if sm_version < 80:
            return _warn_and_return(
                f"CutlassFusedMoE requires SM >= 80, got SM{sm_version}")

        # Check gptoss_style support
        if gptoss_style and quant_algo not in cls._GPTOSS_SUPPORTED_ALGOS:
            return _warn_and_return(
                f"CutlassFusedMoE gptoss_style only supports W4A8_MXFP4_MXFP8 "
                f"(got quant_algo={quant_algo})")

        # Check if quant_algo is supported
        if quant_algo not in cls._QUANT_SUPPORT_TABLE:
            return _warn_and_return(
                f"CutlassFusedMoE does not support quant_algo={quant_algo}")

        support_info = cls._QUANT_SUPPORT_TABLE[quant_algo]

        # Check SM version constraint
        constraint_type, constraint_value = support_info["sm_constraint"]
        algo_name = "unquantized" if quant_algo is None else quant_algo.name

        if constraint_type == "min":
            if sm_version < constraint_value:
                return _warn_and_return(
                    f"CutlassFusedMoE {algo_name} requires SM >= {constraint_value}, "
                    f"got SM{sm_version}")
        elif constraint_type == "exact":
            if sm_version != constraint_value:
                return _warn_and_return(
                    f"CutlassFusedMoE {algo_name} only supports SM{constraint_value}, "
                    f"got SM{sm_version}")
        elif constraint_type == "in":
            if sm_version not in constraint_value:
                sm_list = "/".join(f"SM{v}" for v in sorted(constraint_value))
                return _warn_and_return(
                    f"CutlassFusedMoE {algo_name} only supports {sm_list}, "
                    f"got SM{sm_version}")

        # Check dtype_activation
        supported_dtypes = support_info["dtypes"]
        if dtype_activation not in supported_dtypes:
            dtype_list = ", ".join(str(d) for d in supported_dtypes)
            return _warn_and_return(
                f"CutlassFusedMoE {algo_name} requires {dtype_list}, "
                f"got {dtype_activation}")

        return True, None

    def __init__(
        self,
        *,
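The three sm_constraint forms can be evaluated by a single dispatch; this standalone sketch mirrors the branches in can_implement() above (the helper name is local to the example):

# Hypothetical helper mirroring the three sm_constraint branches.
def sm_ok(constraint, sm_version: int) -> bool:
    kind, value = constraint
    if kind == "min":
        return sm_version >= value
    if kind == "exact":
        return sm_version == value
    if kind == "in":
        return sm_version in value
    raise ValueError(f"unknown constraint kind: {kind}")

assert sm_ok(("min", 89), 90)          # FP8 on SM90
assert not sm_ok(("exact", 90), 100)   # FP8_BLOCK_SCALES rejected on SM100
assert sm_ok(("in", {100, 103}), 103)  # NVFP4 on SM103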
@@ -1,4 +1,19 @@
-from typing import Dict, List, Optional, Union
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

+from typing import Dict, List, Optional, Tuple, Union

import torch
import triton
@@ -6,7 +21,8 @@ import triton.language as tl

import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm import deep_gemm
-from tensorrt_llm._utils import nvtx_range
+from tensorrt_llm._utils import get_sm_version, nvtx_range
from tensorrt_llm.models.modeling_utils import QuantAlgo

from ...distributed import allgather
from ...memory_buffer_utils import get_memory_buffers
@@ -361,6 +377,67 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
        model_config (ModelConfig): Configuration object for the model.
    """

    @classmethod
    def can_implement(
        cls,
        quant_algo: Optional[QuantAlgo],
        dtype_activation: torch.dtype = torch.bfloat16,
        gptoss_style: bool = False,
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if DeepGemmFusedMoE can implement the given quantization algorithm.

        DeepGemmFusedMoE supports:
        - FP8_BLOCK_SCALES: SM in {100, 103}

        Does NOT support unquantized mode. Output dtype is hardcoded to bfloat16.
        Does NOT support gptoss_style (bias/swiglu with custom alpha/beta/limit).

        Args:
            quant_algo: The quantization algorithm to check (None for unquantized)
            dtype_activation: The activation input data type. Supported types are
                float32, bfloat16, and float16 (required by moe_permute_op kernel).
                Note: Output dtype is always bfloat16 regardless of input dtype.
            gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled.
                DeepGemmFusedMoE does NOT support gptoss_style.

        Returns:
            Tuple[bool, Optional[str]]: (can_implement, skip_reason)
        """
        from .interface import _warn_and_return

        sm_version = get_sm_version()

        if sm_version not in {100, 103}:
            return _warn_and_return(
                f"DeepGemmFusedMoE requires SM100 or SM103, got SM{sm_version}")

        # Check dtype_activation: moe_permute_op only supports float32, bfloat16, float16
        if dtype_activation not in {
                torch.float32, torch.bfloat16, torch.float16
        }:
            return _warn_and_return(
                f"DeepGemmFusedMoE requires float32, bfloat16, or float16 activation, "
                f"got {dtype_activation}")

        # DeepGemmFusedMoE does NOT support unquantized mode
        if quant_algo is None:
            return _warn_and_return(
                "DeepGemmFusedMoE does not support unquantized mode")

        # DeepGemmFusedMoE does NOT support gptoss_style
        if gptoss_style:
            return _warn_and_return(
                "DeepGemmFusedMoE does not support gptoss_style (bias/swiglu with custom alpha/beta/limit)"
            )

        # Only FP8_BLOCK_SCALES is supported
        if quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
            return True, None

        return _warn_and_return(
            f"DeepGemmFusedMoE does not support quant_algo={quant_algo}")

    # To reuse pytorch memory segments allocated during graph capture.
    buffers = get_memory_buffers()
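In the unified test framework this gate is consulted before a case is built; a minimal sketch of that probe (pytest usage assumed, skip-reason text taken from the checks above):

# Sketch: probe DeepGemmFusedMoE before constructing a test case.
ok, reason = DeepGemmFusedMoE.can_implement(QuantAlgo.FP8_BLOCK_SCALES)
if not ok:
    pytest.skip(reason)  # e.g. "DeepGemmFusedMoE requires SM100 or SM103, got SM90"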
@@ -1,7 +1,22 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import os
-from typing import Dict, List, NamedTuple, Optional
+from typing import Dict, List, NamedTuple, Optional, Tuple

import torch
import torch.nn as nn
@@ -1263,6 +1278,73 @@ class TritonMXFP4FusedMoEMethod(TritonUnquantizedFusedMoEMethod):

class TritonFusedMoE(MoE):

    @classmethod
    def can_implement(
        cls,
        quant_algo: Optional["QuantAlgo"],
        dtype_activation: torch.dtype = torch.bfloat16,
        gptoss_style: bool = False,
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if TritonFusedMoE can implement the given quantization algorithm.

        TritonFusedMoE supports (SM90 only, gptoss_style=True only):
        - Unquantized (BF16 only)
        - FP8 per-tensor (QDQ)
        - W4A8_MXFP4_FP8
        - W4A16_MXFP4

        Args:
            quant_algo: The quantization algorithm to check (None for unquantized)
            dtype_activation: The activation data type. In unquantized mode, activation,
                weight, and output dtypes must all match (only bfloat16 supported).
            gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled.
                TritonFusedMoE ONLY supports gptoss_style=True.

        Returns:
            Tuple[bool, Optional[str]]: (can_implement, skip_reason)
        """
        from tensorrt_llm._utils import get_sm_version
        from tensorrt_llm.models.modeling_utils import QuantAlgo

        from .interface import _warn_and_return

        sm_version = get_sm_version()

        # TritonFusedMoE only supports SM90
        if sm_version != 90:
            return _warn_and_return(
                f"TritonFusedMoE only supports SM90, got SM{sm_version}")

        # TritonFusedMoE ONLY supports gptoss_style=True
        if not gptoss_style:
            return _warn_and_return(
                "TritonFusedMoE only supports gptoss_style=True")

        # Unquantized mode - only bfloat16 is supported
        if quant_algo is None:
            if dtype_activation != torch.bfloat16:
                return _warn_and_return(
                    f"TritonFusedMoE unquantized mode only supports bfloat16, got {dtype_activation}"
                )
            return True, None

        # FP8 per-tensor (QDQ) and W4A8_MXFP4_FP8 - no dtype_activation restriction
        if quant_algo in {QuantAlgo.FP8, QuantAlgo.W4A8_MXFP4_FP8}:
            return True, None

        # W4A16_MXFP4 - only bfloat16 and float16 are supported
        if quant_algo == QuantAlgo.W4A16_MXFP4:
            if dtype_activation not in {torch.bfloat16, torch.float16}:
                return _warn_and_return(
                    f"TritonFusedMoE W4A16_MXFP4 only supports bfloat16 or float16, "
                    f"got {dtype_activation}")
            return True, None

        # Unsupported quantization algorithm
        return _warn_and_return(
            f"TritonFusedMoE does not support quant_algo={quant_algo}")

    def __init__(
        self,
        *,
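For reference, "gptoss_style" SwiGLU is the gated activation with custom alpha/beta/limit. A hedged sketch of one common formulation (assumption: this follows the GPT-OSS reference; the fused kernels may order the clamps differently):

import torch

def gptoss_swiglu(gate: torch.Tensor, up: torch.Tensor,
                  alpha: float = 1.0, beta: float = 0.0,
                  limit: float = float("inf")) -> torch.Tensor:
    # Clamp activations by `limit`, apply alpha-scaled SiLU gate, shift by beta.
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    return (gate * torch.sigmoid(alpha * gate)) * (up + beta)

# With alpha=1, beta=0, limit=inf this reduces to the plain SiLU-gated MLP,
# which matches how the new tests derive gptoss_style from non-default values.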
@@ -1,7 +1,22 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from functools import cached_property
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn
@@ -10,6 +25,7 @@ from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
from tensorrt_llm._torch.distributed.moe_alltoall import MoeAlltoAll
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantAlgo

from ...custom_ops.trtllm_gen_custom_ops import \
    fp4_block_scale_fake_output_without_finalize
@@ -61,6 +77,88 @@ class TRTLLMGenFusedMoE(MoE):
        There should be at least `num_experts` slots in the model engine. More than that is OK, in that case, some experts may have multiple replicas.
    """

    # Supported quantization algorithms for TRTLLMGenFusedMoE
    _SUPPORTED_QUANT_ALGOS = {
        QuantAlgo.NVFP4,
        QuantAlgo.FP8_BLOCK_SCALES,
        QuantAlgo.W4A8_NVFP4_FP8,
        QuantAlgo.W4A16_MXFP4,
        QuantAlgo.W4A8_MXFP4_FP8,
        QuantAlgo.W4A8_MXFP4_MXFP8,
    }

    # Quantization algorithms that support gptoss_style
    _GPTOSS_SUPPORTED_ALGOS = {
        QuantAlgo.NVFP4,
        QuantAlgo.W4A16_MXFP4,
        QuantAlgo.W4A8_MXFP4_FP8,
        QuantAlgo.W4A8_MXFP4_MXFP8,
    }

    @classmethod
    def can_implement(
        cls,
        quant_algo: Optional[QuantAlgo],
        dtype_activation: torch.dtype = torch.bfloat16,
        gptoss_style: bool = False,
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if TRTLLMGenFusedMoE can implement the given quantization algorithm.

        TRTLLMGenFusedMoE only supports SM in {100, 103} and the following quantizations:
        - NVFP4
        - FP8_BLOCK_SCALES
        - W4A8_NVFP4_FP8
        - W4A16_MXFP4
        - W4A8_MXFP4_FP8
        - W4A8_MXFP4_MXFP8

        Does NOT support unquantized mode. Output dtype is hardcoded to bfloat16.

        Args:
            quant_algo: The quantization algorithm to check (None for unquantized)
            dtype_activation: The activation input data type. Only bfloat16 is supported.
                See: forward_impl() assert x.dtype == torch.bfloat16 (line 722).
            gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled.
                Only supported for nvfp4 and mxfp4 variants.

        Returns:
            Tuple[bool, Optional[str]]: (can_implement, skip_reason)
        """
        from .interface import _warn_and_return

        sm_version = get_sm_version()

        # TRTLLMGenFusedMoE requires SM in {100, 103}
        if sm_version not in {100, 103}:
            return _warn_and_return(
                f"TRTLLMGenFusedMoE requires SM100 or SM103, got SM{sm_version}"
            )

        # Check dtype_activation: only bfloat16 is supported
        if dtype_activation != torch.bfloat16:
            return _warn_and_return(
                f"TRTLLMGenFusedMoE only supports bfloat16 activation, got {dtype_activation}"
            )

        # TRTLLMGenFusedMoE does NOT support unquantized mode
        if quant_algo is None:
            return _warn_and_return(
                "TRTLLMGenFusedMoE does not support unquantized mode")

        # Check if quant_algo is supported
        if quant_algo not in cls._SUPPORTED_QUANT_ALGOS:
            return _warn_and_return(
                f"TRTLLMGenFusedMoE does not support quant_algo={quant_algo}")

        # Check gptoss_style support: only supported for nvfp4 and mxfp4 variants
        if gptoss_style and quant_algo not in cls._GPTOSS_SUPPORTED_ALGOS:
            return _warn_and_return(
                f"TRTLLMGenFusedMoE supports gptoss_style (bias/swiglu) only for nvfp4 and mxfp4 variants, "
                f"got quant_algo={quant_algo}")

        return True, None

    def __init__(
        self,
        *,
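Combining the two sets above: with gptoss_style=True, only the nvfp4/mxfp4 members of _SUPPORTED_QUANT_ALGOS remain accepted. A sketch of the expected results on an SM100 machine:

ok, _ = TRTLLMGenFusedMoE.can_implement(QuantAlgo.NVFP4, gptoss_style=True)
assert ok  # NVFP4 is in _GPTOSS_SUPPORTED_ALGOS
ok, reason = TRTLLMGenFusedMoE.can_implement(QuantAlgo.FP8_BLOCK_SCALES,
                                             gptoss_style=True)
assert not ok  # FP8_BLOCK_SCALES is supported, but not with gptoss_style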
@@ -7,7 +7,30 @@ from typing import Dict, List, Optional, Tuple, Union, final
import torch
from torch import nn

from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantAlgo

from ...distributed.ops import reducescatter


def _warn_and_return(reason: str) -> Tuple[bool, Optional[str]]:
    """
    Log a warning and return (False, reason) for can_implement() checks.

    This is a common utility function used by all MoE backend implementations
    to provide consistent logging and return values when a configuration
    is not supported.

    Args:
        reason: The reason why the configuration is not supported.

    Returns:
        Tuple[bool, Optional[str]]: Always returns (False, reason)
    """
    logger.warning(reason)
    return False, reason


from ...model_config import ModelConfig
from ...utils import (ActivationType, AuxStreamType, Fp4QuantizedTensor,
                      get_model_extra_attrs, is_gated_activation,
@@ -129,6 +152,40 @@ class MoE(nn.Module):
        aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping.
    """

    @classmethod
    @abstractmethod
    def can_implement(
        cls,
        quant_algo: Optional[QuantAlgo],
        dtype_activation: torch.dtype = torch.bfloat16,
        gptoss_style: bool = False,
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if this MoE backend can implement the given quantization algorithm.

        NOTE: This is a TRANSITIONAL interface. In the future, this method will be moved
        to the MoEBackend interface as part of the backend abstraction layer. During this
        transition period, it remains in the MoE base class to maintain compatibility.

        This method checks both:
        1. Whether the backend supports the specified quantization algorithm
        2. Whether the current platform (SM version) supports the backend and quantization

        Each backend MUST override this method to provide accurate capability information.

        Args:
            quant_algo: The quantization algorithm to check (None for unquantized)
            dtype_activation: The activation data type.
            gptoss_style: Whether gptoss_style (bias/swiglu with custom alpha/beta/limit) is enabled.

        Returns:
            Tuple[bool, Optional[str]]: (can_implement, skip_reason)
            - can_implement: True if the backend can implement this configuration
            - skip_reason: None if can_implement is True, otherwise a string explaining why not
        """
        raise NotImplementedError(
            f"{cls.__name__} must implement can_implement method")

    def __init__(
        self,
        *,
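To illustrate the contract, a new backend would override the classmethod along these lines (a hypothetical minimal backend, not part of this change):

class MyFusedMoE(MoE):

    @classmethod
    def can_implement(
        cls,
        quant_algo: Optional[QuantAlgo],
        dtype_activation: torch.dtype = torch.bfloat16,
        gptoss_style: bool = False,
    ) -> Tuple[bool, Optional[str]]:
        if gptoss_style:
            return _warn_and_return("MyFusedMoE does not support gptoss_style")
        if quant_algo not in (None, QuantAlgo.FP8):
            return _warn_and_return(
                f"MyFusedMoE does not support quant_algo={quant_algo}")
        return True, None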
@@ -22,8 +22,6 @@ l0_dgx_b300:
  - unittest/_torch/modeling -k "modeling_mixtral"
  - unittest/_torch/modeling -k "modeling_gpt_oss"
  - unittest/_torch/multi_gpu_modeling -k "deepseek"
- - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP]
- - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False]

@@ -23,8 +23,6 @@ l0_gb300_multi_gpus:
  - unittest/_torch/modeling -k "modeling_mixtral"
  - unittest/_torch/modeling -k "modeling_gpt_oss"
  - unittest/_torch/multi_gpu_modeling -k "deepseek"
- - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP]
- - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
@@ -1,119 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TEMPORARY FILE - Will be removed after MoE refactor is complete.
#
# Background:
# The `enable_configurable_moe` parameter is a temporary measure during the MoE
# refactor. The old and new MoE flows will coexist for a period of time. To avoid
# large-scale changes to the existing test lists, we handle the test ID cleanup
# here. Once the refactor is complete and all tests use ConfigurableMoE by default,
# this file will no longer be needed and should be deleted.
#
# Two-phase approach:
# 1. pytest_sessionstart: Convert clean test names in CLI args back to original
#    format so pytest can find tests during collection.
# 2. pytest_collection_modifyitems: Clean up the collected test IDs for display
#    and waive matching.
import re

# Test functions that use enable_configurable_moe parameter and need ID conversion
TESTS_WITH_CONFIGURABLE_MOE = [
    "test_fused_moe_nvfp4[",
    "test_fused_moe_mxfp4_mxfp8[",
    "test_fused_moe_w4a8_nvfp4_fp8[",
    "test_fused_moe_wfp4a16[",
    "test_fused_moe_fp8_blockwise_deepgemm[",
]


def _convert_clean_to_original_moe_test_id(test_id):
    """Convert clean MoE test ID back to original format for pytest collection.

    Example: "test_fused_moe.py::test_foo[TRTLLM-dtype0]" -> "test_fused_moe.py::test_foo[-TRTLLM-dtype0]"

    This is needed because the `enable_configurable_moe` parameter uses empty string
    as ID when value is 0, resulting in test IDs like "test_foo[-TRTLLM-dtype0]".
    We clean these up in pytest_collection_modifyitems, but pytest filters tests
    during collection using the original IDs. So when the user runs with a clean test name,
    we need to convert it back to match the original.
    """
    if "test_fused_moe.py" not in test_id:
        return test_id

    # Match pattern like "test_name[params]" and add leading dash after "["
    # But only if params don't already start with "-" or "enable_configurable_moe"
    match = re.search(r"\[([^\]]+)\]", test_id)
    if match:
        params = match.group(1)
        # Skip if already has leading dash or starts with enable_configurable_moe
        if not params.startswith("-") and not params.startswith("enable_configurable_moe"):
            # Add leading dash to params
            new_params = "-" + params
            test_id = test_id.replace(f"[{params}]", f"[{new_params}]")

    return test_id


def pytest_sessionstart(session):
    """Convert clean MoE test IDs in config.args to original format for collection.

    This is needed because pytest filters tests during collection using original IDs.
    When the user runs with a clean test name, we convert it back to match the original.
    """
    args = session.config.args
    for i, arg in enumerate(args):
        if "test_fused_moe.py" in arg and "[" in arg:
            # Only apply conversion to specific tests that use enable_configurable_moe
            should_convert = any(test_name in arg for test_name in TESTS_WITH_CONFIGURABLE_MOE)
            if should_convert:
                args[i] = _convert_clean_to_original_moe_test_id(arg)


def pytest_collection_modifyitems(items):
    """Clean up test IDs by removing leading/trailing dashes from parameter IDs.

    This is needed because the `enable_configurable_moe` parameter can be empty,
    resulting in ugly test IDs like "test_foo[-True]" or "test_foo[--abc]".
    We clean these up to "test_foo[True]" or "test_foo[abc]" so that:
    1. Test names in waive files and test lists remain unchanged
    2. Test reports look cleaner

    This runs BEFORE the global conftest applies waives (due to hookwrapper).
    """
    for item in items:
        if "test_fused_moe.py" in item.nodeid and "[" in item.nodeid:
            # Only apply cleanup to specific tests that use enable_configurable_moe
            should_cleanup = any(
                test_name in item.nodeid for test_name in TESTS_WITH_CONFIGURABLE_MOE
            )
            if should_cleanup:
                original_nodeid = item.nodeid
                original_name = item.name
                nodeid = item.nodeid
                name = item.name

                # Clean up leading/trailing dashes in nodeid
                nodeid = nodeid.replace("[-", "[")
                nodeid = nodeid.replace("-]", "]")

                # Clean up leading/trailing dashes in name
                name = name.replace("[-", "[")
                name = name.replace("-]", "]")

                if nodeid != original_nodeid:
                    item._nodeid = nodeid
                if name != original_name:
                    item.name = name
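For reference, the round trip this conftest performed looks like this (a sketch built from the docstring example above):

clean = "test_fused_moe.py::test_fused_moe_nvfp4[TRTLLM-dtype0]"
original = _convert_clean_to_original_moe_test_id(clean)
assert original == "test_fused_moe.py::test_fused_moe_nvfp4[-TRTLLM-dtype0]"
# pytest_collection_modifyitems later strips the "[-" back to "[" for display.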
tests/unittest/_torch/modules/moe/test_moe_backend.py (new file, 927 lines)
@@ -0,0 +1,927 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MoE Backend Unit Tests

This module provides a unified test framework for testing different MoE backends
through the backend-level interfaces (quantize_input + run_moe), rather than
the high-level forward() interface.

Design Goals:
1. Test backend interfaces directly: routing_method.apply -> quantize_input -> run_moe
2. Cover all quantization + backend combinations
3. Use can_implement() interface to determine test skip logic
4. Support autotune and tactic capture testing
"""

import itertools
import logging
import time
from dataclasses import dataclass
from enum import Enum
from typing import Callable, List, Optional, Type

import pytest
import torch
from _torch.modules.moe.quantize_utils import get_test_quant_params
from transformers.configuration_utils import PretrainedConfig

from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.modules.fused_moe import RenormalizeMoeRoutingMethod
from tensorrt_llm._torch.modules.fused_moe.create_moe import create_moe_backend
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import CuteDslFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import DeepGemmFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_trtllm_gen import TRTLLMGenFusedMoE
from tensorrt_llm._torch.modules.fused_moe.interface import MoE
from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantAlgo

logger = logging.getLogger(__name__)


class MoeBackendType(str, Enum):
    """Enum for MoE backend types."""

    CUTLASS = "CUTLASS"
    TRTLLM = "TRTLLM"
    CUTEDSL = "CUTEDSL"
    DEEPGEMM = "DEEPGEMM"


def get_backend_class(backend_type: MoeBackendType) -> Type[MoE]:
    """Get the MoE backend class for a given backend type."""
    backend_class_map = {
        MoeBackendType.CUTLASS: CutlassFusedMoE,
        MoeBackendType.TRTLLM: TRTLLMGenFusedMoE,
        MoeBackendType.CUTEDSL: CuteDslFusedMoE,
        MoeBackendType.DEEPGEMM: DeepGemmFusedMoE,
    }
    return backend_class_map[backend_type]


def should_skip_TRTLLM(
    backend_type: MoeBackendType,
    quant_algo: Optional[QuantAlgo],
    model_config: "MoeModelConfig",
) -> Optional[str]:
    """
    Check TRTLLM Gen backend specific constraints.

    The TRTLLM Gen MoE kernels have hardware-level constraints that must be satisfied.
    These constraints are enforced in the C++ layer.

    Constraints:
    1. num_experts must be divisible by 4 (routing kernel vectorization requirement)
    2. num_experts must be greater than top_k (routing logic requirement)

    Args:
        backend_type: The MoE backend type
        quant_algo: The quantization algorithm
        model_config: The MoE model configuration

    Returns:
        Skip reason string if test should be skipped, None otherwise
    """
    if backend_type != MoeBackendType.TRTLLM:
        return None

    if model_config is None:
        return None

    # These quantization algorithms use TRTLLM Gen kernels with the constraints
    trtllm_gen_quant_algos = {
        QuantAlgo.NVFP4,
        QuantAlgo.FP8_BLOCK_SCALES,
        QuantAlgo.W4A8_NVFP4_FP8,
        QuantAlgo.W4A16_MXFP4,
        QuantAlgo.W4A8_MXFP4_MXFP8,
    }

    if quant_algo not in trtllm_gen_quant_algos:
        return None

    num_experts = model_config.num_experts
    top_k = model_config.top_k
    intermediate_size = model_config.intermediate_size

    # Check: num_experts must be divisible by 4
    # Routing kernel uses vectorized operations that require this alignment
    if num_experts % 4 != 0:
        return (
            f"TRTLLMGenFusedMoE routing kernel requires num_experts divisible by 4 "
            f"(got num_experts={num_experts})"
        )

    # Check: num_experts must be greater than top_k
    # Routing logic cannot handle the case where all experts are selected
    if num_experts <= top_k:
        return (
            f"TRTLLMGenFusedMoE requires num_experts > top_k "
            f"(got num_experts={num_experts}, top_k={top_k})"
        )

    # -----------------Potential issues------------------
    # These are known issues that need investigation. Skipping to avoid test failures
    # and CUDA errors that can cascade to subsequent tests.

    # Issue 1: W4A8_NVFP4_FP8 with top_k=1 causes CUDA illegal memory access
    # This triggers GPU state corruption that affects all subsequent tests.
    # Affected config: e8_k1_h512_i512
    if quant_algo == QuantAlgo.W4A8_NVFP4_FP8 and top_k == 1:
        return (
            "[Potential Bug] TRTLLMGenFusedMoE W4A8_NVFP4_FP8 with top_k=1 "
            "causes CUDA illegal memory access. Needs kernel investigation."
        )

    # Issue 2: NVFP4 with large intermediate_size has known accuracy issues
    # Observed mismatch: 18%~25% vs expected <7.5% (per test_moe.py baseline)
    # Affected configs: e8_k2_h4096_i14336, e8_k2_h6144_i32768
    if quant_algo == QuantAlgo.NVFP4 and intermediate_size >= 14336:
        return (
            f"[Potential Bug] TRTLLMGenFusedMoE NVFP4 with large intermediate_size "
            f"has known accuracy issues (intermediate_size={intermediate_size} >= 14336). "
            f"Observed mismatch 18%~25% exceeds expected threshold."
        )

    # Issue 3: W4A8_MXFP4_MXFP8 has accuracy issues on certain model configs
    # Observed mismatch: 14%~18% vs expected <15% (percent=0.85)
    # Affected configs: large intermediate_size or many experts
    # e8_k2_h4096_i14336, e64_k6_h2048_i1408, e60_k4_h2048_i1408,
    # e256_k8_h7168_i2048, e8_k2_h6144_i32768, e128_k4_h2880_i2880
    if quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8:
        # Large intermediate_size (>= 14336) has precision issues
        if intermediate_size >= 14336:
            return (
                f"[Potential Bug] TRTLLMGenFusedMoE W4A8_MXFP4_MXFP8 with large "
                f"intermediate_size has accuracy issues (intermediate_size={intermediate_size} >= 14336). "
                f"Observed mismatch 14%~18% exceeds 15% threshold."
            )
        # Many experts (>= 60) with moderate intermediate_size has precision issues
        if num_experts >= 60 and intermediate_size >= 1408:
            return (
                f"[Potential Bug] TRTLLMGenFusedMoE W4A8_MXFP4_MXFP8 with many experts "
                f"has accuracy issues (num_experts={num_experts} >= 60, intermediate_size={intermediate_size}). "
                f"Observed mismatch 14%~18% exceeds 15% threshold."
            )

    return None


def should_skip_CUTEDSL(
    backend_type: MoeBackendType,
    quant_algo: Optional[QuantAlgo],
    model_config: "MoeModelConfig" = None,
) -> Optional[str]:
    """
    Check CuteDSL backend specific constraints.

    The CuteDSL MoE kernels have known accuracy issues with certain configurations.

    Args:
        backend_type: The MoE backend type
        quant_algo: The quantization algorithm
        model_config: The MoE model configuration

    Returns:
        Skip reason string if test should be skipped, None otherwise
    """
    if backend_type != MoeBackendType.CUTEDSL:
        return None

    if model_config is None:
        return None

    intermediate_size = model_config.intermediate_size

    # -----------------Potential issues------------------
    # NVFP4 with large intermediate_size has known accuracy issues (same as TRTLLM)
    # Observed mismatch: 8%~26% vs expected <2%
    # Affected configs: e8_k2_h4096_i14336, e8_k2_h6144_i32768
    if quant_algo == QuantAlgo.NVFP4 and intermediate_size >= 14336:
        return (
            f"[Potential Bug] CuteDslFusedMoE NVFP4 with large intermediate_size "
            f"has known accuracy issues (intermediate_size={intermediate_size} >= 14336). "
            f"Observed mismatch 8%~26% exceeds 2% threshold."
        )

    # NVFP4 with prime num_experts (7, 13) causes CUDA_ERROR_ILLEGAL_ADDRESS
    # Root cause: Autotuner cache bucket mapping issue
    # - When tests run in batch, previous tests cache tactics to buckets
    # - Prime num_experts shapes map to same bucket as other configs
    # - The cached tactic (e.g., ((128, 256), (1, 2), False)) works for other configs
    #   but causes illegal memory access for prime num_experts' actual shape
    # - Single test run passes because fallback tactic ((128, 128), (1, 1), False) is used
    # Affected configs: e7_k2_h256_i512, e13_k3_h256_i512
    num_experts = model_config.num_experts
    prime_experts_with_issues = {7, 13}
    if quant_algo == QuantAlgo.NVFP4 and num_experts in prime_experts_with_issues:
        return (
            f"[Potential Bug] CuteDslFusedMoE NVFP4 with prime num_experts={num_experts} "
            f"causes CUDA_ERROR_ILLEGAL_ADDRESS due to autotuner cache bucket mapping. "
            f"Cached tactic from other configs is incompatible with this shape."
        )

    return None


def should_skip_gptoss(
    backend_type: MoeBackendType,
    quant_algo: Optional[QuantAlgo],
    gptoss_style: bool,
) -> Optional[str]:
    """
    Check if gptoss_style test should be skipped for this backend.

    Only CUTLASS and TRTLLM backends support gptoss_style (SwiGlu with custom
    alpha/beta/limit parameters and bias).

    Args:
        backend_type: The MoE backend type
        quant_algo: The quantization algorithm
        gptoss_style: Whether gptoss_style is enabled

    Returns:
        Skip reason string if test should be skipped, None otherwise
    """
    if not gptoss_style:
        return None

    # Only CUTLASS and TRTLLM backends support gptoss_style
    supported_backends = {MoeBackendType.CUTLASS, MoeBackendType.TRTLLM}
    if backend_type not in supported_backends:
        return (
            f"gptoss_style is only supported by CUTLASS and TRTLLM backends "
            f"(got backend_type={backend_type.value})"
        )

    return None


def supports_autotuner_capture(
    backend_type: MoeBackendType,
    quant_algo: Optional[QuantAlgo],
) -> bool:
    """
    Determine if a backend+quant_algo combination supports AutoTuner capture/replay.

    AutoTuner capture/replay requires AutoTuner.choose_one() to be called during
    run_moe execution.

    Args:
        backend_type: The MoE backend type
        quant_algo: The quantization algorithm (None for unquantized)

    Returns:
        True if autotuner capture/replay is supported, False otherwise
    """
    # DEEPGEMM does not support autotuner capture
    # Evidence: fused_moe_deepgemm.py has no AutoTuner/choose_one references
    if backend_type == MoeBackendType.DEEPGEMM:
        return False

    return True


def create_test_backend(
    backend_type: MoeBackendType,
    routing_method: RenormalizeMoeRoutingMethod,
    num_experts: int,
    hidden_size: int,
    intermediate_size: int,
    dtype: torch.dtype,
    quant_config,
    mapping: Mapping,
    bias: bool = False,
    swiglu_alpha: Optional[torch.Tensor] = None,
    swiglu_beta: Optional[torch.Tensor] = None,
    swiglu_limit: Optional[torch.Tensor] = None,
) -> MoE:
    """Create a MoE backend for testing."""
    backend_cls = get_backend_class(backend_type)

    pretrained_config = PretrainedConfig()
    pretrained_config.num_experts = num_experts
    pretrained_config.hidden_size = hidden_size
    pretrained_config.intermediate_size = intermediate_size
    pretrained_config.torch_dtype = dtype

    model_config = ModelConfig(
        pretrained_config=pretrained_config,
        quant_config=quant_config,
        mapping=mapping,
        moe_backend=backend_type.value,
    )

    return create_moe_backend(
        moe_cls=backend_cls,
        routing_method=routing_method,
        num_experts=num_experts,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        dtype=dtype,
        reduce_results=True,
        model_config=model_config,
        init_load_balancer=False,
        bias=bias,
        swiglu_alpha=swiglu_alpha,
        swiglu_beta=swiglu_beta,
        swiglu_limit=swiglu_limit,
    )


def run_backend_moe(
    backend: MoE,
    backend_type: MoeBackendType,
    x_quantized: torch.Tensor,
    x_sf: torch.Tensor,
    token_selected_experts: torch.Tensor,
    token_final_scales: torch.Tensor,
    dtype: torch.dtype,
    router_logits: torch.Tensor = None,
    trtllm_use_router_logits: bool = True,
) -> torch.Tensor:
    """
    Run MoE computation with backend-specific parameters.

    Each backend has different requirements:
    - CUTLASS: output_dtype, token_final_scales=float32
    - TRTLLM: token_final_scales=bfloat16, optionally router_logits
    - CUTEDSL: token_final_scales=float32
    - DEEPGEMM: workspace, token_final_scales=float32

    Args:
        trtllm_use_router_logits: If True, TRTLLM backend uses router_logits for routing.
            If False, uses token_selected_experts and token_final_scales.
            Note: When both are provided, TRTLLM only uses topk_ids and topk_weights.
    """
    # Common args for all backends (default: token_final_scales=float32)
    args = dict(
        x=x_quantized,
        token_selected_experts=token_selected_experts.to(torch.int32),
        token_final_scales=token_final_scales.to(torch.float32),
        x_sf=x_sf,
    )

    # Backend-specific overrides
    if backend_type == MoeBackendType.CUTLASS:
        args["output_dtype"] = dtype
    elif backend_type == MoeBackendType.TRTLLM:
        args["token_final_scales"] = token_final_scales.to(torch.bfloat16)
        if trtllm_use_router_logits:
            # Use router_logits for routing (TRTLLM will compute topk internally)
            args["router_logits"] = router_logits
            args["token_selected_experts"] = None
            args["token_final_scales"] = None
        # else: use token_selected_experts and token_final_scales (already set)
    elif backend_type == MoeBackendType.DEEPGEMM:
        import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils

        m_max = fp8_utils.align(x_quantized.shape[0], 128)
        args["workspace"] = backend.get_workspace(m_max, 128)

    return backend.run_moe(**args)


def replay_tactics_and_check(
    all_tactics,
    run_moe_fn: Callable[[], torch.Tensor],
    check_accuracy_fn: Callable[[torch.Tensor, torch.Tensor], None],
    ref_output: torch.Tensor,
    backend_type: MoeBackendType,
    quant_algo: Optional[QuantAlgo],
    fail_fast: bool = False,
) -> None:
    """
    Replay all tactics and check accuracy.

    Args:
        all_tactics: TacticsCapture object from AutoTuner.capture()
        run_moe_fn: Function to run MoE computation
        check_accuracy_fn: Function to check accuracy (output, ref_output) -> None
        ref_output: Reference output tensor
        backend_type: Backend type for error reporting
        quant_algo: Quantization algorithm for error reporting
        fail_fast: If True, fail on first error. If False, run all and report summary.
    """
    tactics_list = list(all_tactics)
    passed_tactics = []
    failed_tactics = []
    logger.info(f"Replaying {len(tactics_list)} tactics and checking accuracy")
    for idx, tactic in enumerate(tactics_list):
        with AutoTuner.get().replay(tactic), torch.inference_mode():
            output = run_moe_fn()
        try:
            check_accuracy_fn(output, ref_output)
            passed_tactics.append((idx, tactic))
        except Exception as e:
            if fail_fast:
                pytest.fail(
                    f"Accuracy check failed for tactic[{idx}/{len(tactics_list)}]={tactic}, "
                    f"backend={backend_type}, quant_algo={quant_algo}: {e}"
                )
            failed_tactics.append((idx, tactic, str(e)))

    # Report results (only when fail_fast=False)
    total = len(tactics_list)
    num_passed = len(passed_tactics)
    num_failed = len(failed_tactics)
    if failed_tactics:
        fail_details = "\n".join(
            f"  tactic[{idx}]={tactic}: {err}" for idx, tactic, err in failed_tactics
        )
        pytest.fail(
            f"backend={backend_type}, quant_algo={quant_algo}: "
            f"{num_passed}/{total} passed, {num_failed}/{total} failed\n"
            f"Failed tactics:\n{fail_details}"
        )
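# For orientation, the capture-then-replay flow these helpers assume looks
# roughly like the sketch below (assumption: AutoTuner.capture() is the
# recording counterpart of the replay() used above; only replay() appears
# verbatim in this file):
#
#     with autotune(), AutoTuner.get().capture() as all_tactics:
#         ref_output = run_moe_fn()  # tuning pass records candidate tactics
#     replay_tactics_and_check(all_tactics, run_moe_fn, check_accuracy_fn,
#                              ref_output, backend_type, quant_algo)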
# ============================================================================
|
||||
# Test Parameters
|
||||
# ============================================================================
|
||||
|
||||
# Quantization algorithms to test
|
||||
QUANT_ALGOS_TO_TEST = [
|
||||
None, # Unquantized
|
||||
QuantAlgo.FP8,
|
||||
QuantAlgo.NVFP4,
|
||||
QuantAlgo.FP8_BLOCK_SCALES,
|
||||
QuantAlgo.W4A8_NVFP4_FP8,
|
||||
QuantAlgo.W4A16_MXFP4,
|
||||
QuantAlgo.W4A8_MXFP4_MXFP8,
|
||||
QuantAlgo.W8A16,
|
||||
QuantAlgo.W4A8_AWQ,
|
||||
]
|
||||
|
||||
# Backend types to test
|
||||
BACKEND_TYPES_TO_TEST = [
|
||||
MoeBackendType.CUTLASS,
|
||||
MoeBackendType.TRTLLM,
|
||||
MoeBackendType.CUTEDSL,
|
||||
MoeBackendType.DEEPGEMM,
|
||||
]
|
||||
|
||||
# Data types to test
|
||||
DTYPES_TO_TEST = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Model MoE Configurations
|
||||
# ============================================================================
|
||||
@dataclass
|
||||
class MoeModelConfig:
|
||||
"""MoE model configuration: (num_experts, top_k, hidden_size, intermediate_size)."""
|
||||
|
||||
num_experts: int
|
||||
top_k: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"e{self.num_experts}_k{self.top_k}_h{self.hidden_size}_i{self.intermediate_size}"
|
||||
|
||||
|
||||
# Format: (num_experts, top_k, hidden_size, intermediate_size)
|
||||
MOE_MODEL_CONFIGS = [
|
||||
# === Real Model Configs ===
|
||||
MoeModelConfig(8, 2, 4096, 14336), # Mixtral-8x7B
|
||||
MoeModelConfig(64, 6, 2048, 1408), # DeepSeek-MoE-16B / DeepSeek-V2-Lite
|
||||
MoeModelConfig(60, 4, 2048, 1408), # Qwen1.5-MoE-A2.7B
|
||||
MoeModelConfig(256, 8, 7168, 2048), # DeepSeek-V3
|
||||
MoeModelConfig(8, 2, 6144, 32768), # Grok-1
|
||||
MoeModelConfig(128, 4, 2880, 2880), # GPT-OSS-120B
|
||||
# === Boundary Tests: num_experts / top_k ===
|
||||
MoeModelConfig(8, 1, 512, 512), # top_k=1, single expert activated
|
||||
MoeModelConfig(4, 4, 512, 512), # top_k=num_experts, all experts activated
|
||||
MoeModelConfig(7, 2, 256, 512), # prime num_experts
|
||||
MoeModelConfig(13, 3, 256, 512), # prime num_experts, odd top_k
|
||||
# === Boundary Tests: small sizes ===
|
||||
MoeModelConfig(4, 2, 64, 128), # very small hidden_size
|
||||
MoeModelConfig(4, 2, 128, 64), # intermediate < hidden
|
||||
]
|
||||
|
||||
# Sequence lengths to test
|
||||
SEQ_LENS_TO_TEST = [1, 8]
|
||||
|
||||
# SwiGLU parameters for gptoss_style testing
|
||||
SWIGLU_ALPHAS = [1, 0.1]
|
||||
SWIGLU_BETAS = [0, 1]
|
||||
SWIGLU_LIMITS = [float("inf"), 1]
|
||||
|
||||
|
||||


# ============================================================================
# Fast Skip Check (for parametrize-level skip, avoids entering the test function)
# ============================================================================
def get_quick_skip_reason(
    backend_type: MoeBackendType,
    quant_algo: Optional[QuantAlgo],
    dtype: torch.dtype,
    model_config: "MoeModelConfig",
    gptoss_style: bool,
) -> Optional[str]:
    """
    Fast skip check built on the backend's can_implement() classmethod.

    can_implement() covers dtype/quant_algo/gptoss_style support; the
    should_skip_* helpers then add model_config-specific checks on top.

    Note: Logging is temporarily suppressed to avoid excessive warning output
    during test parameter generation.

    Returns:
        A skip-reason string if the test should be skipped, None otherwise.
    """
    import logging as _logging

    # Suppress logger warnings during parameter generation to avoid excessive output
    trtllm_logger = _logging.getLogger("tensorrt_llm")
    original_level = trtllm_logger.level
    trtllm_logger.setLevel(_logging.ERROR)

    try:
        # ===== Backend capability check: dtype/quant_algo/gptoss_style =====
        backend_cls = get_backend_class(backend_type)
        can_impl, skip_reason = backend_cls.can_implement(
            quant_algo, dtype_activation=dtype, gptoss_style=gptoss_style
        )
        if not can_impl:
            return skip_reason

        # ===== Additional model_config-specific checks =====

        # TRTLLM: num_experts constraints and accuracy issues
        skip_reason = should_skip_TRTLLM(backend_type, quant_algo, model_config)
        if skip_reason:
            return skip_reason

        # CUTEDSL: accuracy issues with specific configs
        skip_reason = should_skip_CUTEDSL(backend_type, quant_algo, model_config)
        if skip_reason:
            return skip_reason

        # DEEPGEMM: float16 reference module constraint
        if backend_type == MoeBackendType.DEEPGEMM and dtype == torch.float16:
            return "DeepGemmFusedMoE reference module (FP8BlockScalesLinearMethod) requires bfloat16 input"

        # 128-alignment requirement for quantization
        if quant_algo is not None:
            hidden_size = model_config.hidden_size
            intermediate_size = model_config.intermediate_size
            is_hidden_128_aligned = hidden_size % 128 == 0
            is_intermediate_128_aligned = intermediate_size % 128 == 0

            if not is_hidden_128_aligned or not is_intermediate_128_aligned:
                # TRTLLM with MXFP4 variants automatically pads to 128 alignment
                is_mxfp4_variant = quant_algo in {QuantAlgo.W4A16_MXFP4, QuantAlgo.W4A8_MXFP4_MXFP8}
                is_trtllm_backend = backend_type == MoeBackendType.TRTLLM
                if not (is_trtllm_backend and is_mxfp4_variant):
                    return (
                        f"Non-128-aligned sizes (h={hidden_size}, i={intermediate_size}) "
                        f"require the TRTLLM backend with MXFP4 quantization"
                    )

        return None

    finally:
        # Restore the original logger level
        trtllm_logger.setLevel(original_level)


def generate_test_params() -> List:
    """
    Generate all test parameter combinations, attaching skip marks to invalid ones.

    Skip decisions are pre-computed at collection time via get_quick_skip_reason(),
    so invalid combinations never enter the test function at runtime. This
    significantly speeds up test collection and skip execution.

    Returns:
        List of pytest.param objects with appropriate skip marks.
    """
    params: List = []

    # Generate all combinations
    swiglu_combos = list(itertools.product(SWIGLU_ALPHAS, SWIGLU_BETAS, SWIGLU_LIMITS))

    for swiglu_alpha, swiglu_beta, swiglu_limit in swiglu_combos:
        for model_config in MOE_MODEL_CONFIGS:
            for seq_len in SEQ_LENS_TO_TEST:
                for dtype in DTYPES_TO_TEST:
                    for backend_type in BACKEND_TYPES_TO_TEST:
                        for quant_algo in QUANT_ALGOS_TO_TEST:
                            # gptoss_style is enabled when any swiglu parameter
                            # deviates from its default (alpha=1, beta=0, limit=inf)
                            gptoss_style = (
                                swiglu_alpha != 1
                                or swiglu_beta != 0
                                or swiglu_limit != float("inf")
                            )

                            # Generate the test ID
                            test_id = (
                                f"alpha={swiglu_alpha}_beta={swiglu_beta}_limit={swiglu_limit}-"
                                f"{model_config}-seq={seq_len}-dtype={dtype}-"
                                f"backend={backend_type.value}-quant_algo={quant_algo}"
                            )

                            # Check whether this combination should be skipped
                            skip_reason = get_quick_skip_reason(
                                backend_type, quant_algo, dtype, model_config, gptoss_style
                            )

                            param_values = (
                                dtype,
                                backend_type,
                                quant_algo,
                                seq_len,
                                model_config,
                                swiglu_alpha,
                                swiglu_beta,
                                swiglu_limit,
                            )

                            if skip_reason:
                                params.append(
                                    pytest.param(
                                        *param_values,
                                        id=test_id,
                                        marks=pytest.mark.skip(reason=skip_reason),
                                    )
                                )
                            else:
                                params.append(pytest.param(*param_values, id=test_id))

    return params


# Pre-generate test parameters at module load time
TEST_PARAMS = generate_test_params()


# ============================================================================
# Timing Fixtures
# ============================================================================
@pytest.fixture(scope="module", autouse=True)
def module_timer(request):
    """Fixture to measure and log total module execution time."""
    start = time.perf_counter()
    yield
    elapsed = time.perf_counter() - start
    logger.info(
        "[TIMING] Total %s: %.3fs (%.2f min)",
        request.module.__name__,
        elapsed,
        elapsed / 60,
    )


# ============================================================================
# Test Implementation
# ============================================================================
#
# This file provides a UNIFIED TEST FRAMEWORK for testing all MoE backend
# implementations through their backend-level interfaces.
#
# ============================================================================
# Purpose & Scope
# ============================================================================
# - Test MoE backends via: routing_method.apply -> quantize_input -> run_moe
# - Single-GPU execution (no multi-GPU/distributed testing)
# - Accuracy validation against reference implementations
#
# ============================================================================
# Test Coverage Matrix
# ============================================================================
# 1. BACKENDS: CUTLASS, TRTLLM, CUTEDSL, DEEPGEMM
#
# 2. QUANTIZATION ALGORITHMS:
#    - Unquantized (None)
#    - FP8, FP8_BLOCK_SCALES
#    - NVFP4, W4A8_NVFP4_FP8
#    - W4A16_MXFP4, W4A8_MXFP4_MXFP8
#    - W8A16, W4A8_AWQ
#
# 3. ACTIVATION DTYPES: float16, bfloat16
#
# 4. AUTOTUNER TACTICS (see the sketch after this block):
#    - Autotune phase: find optimal tactics via the AutoTuner
#    - Capture phase: record all tactics used
#    - Replay phase: verify each tactic produces correct results
#
# 5. GPTOSS_STYLE (SwiGLU with custom parameters):
#    - swiglu_alpha: scaling factor (default=1)
#    - swiglu_beta: bias term (default=0)
#    - swiglu_limit: clipping limit (default=inf)
#    - Supported by: CUTLASS (W4A8_MXFP4_MXFP8), TRTLLM (W4A8_MXFP4_MXFP8)
#
# 6. MODEL CONFIGURATIONS:
#    - Real models: Mixtral, DeepSeek, Qwen, Grok, GPT-OSS
#    - Boundary cases: prime num_experts, small sizes, top_k=1, top_k=num_experts
#
# ============================================================================
# Skip Logic
# ============================================================================
# Tests are automatically skipped for unsupported configurations using:
# - backend.can_implement(): check dtype/quant_algo/gptoss_style support
# - should_skip_TRTLLM(): TRTLLM-specific constraints (num_experts % 4, etc.)
# - should_skip_CUTEDSL(): CuteDSL-specific accuracy issues
# - 128-alignment requirements for quantization
#
# ============================================================================
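# A simplified sketch (for orientation; same calls as made by test_moe_backend
# below, details elided) of the three autotuner phases listed above:
#
#     with torch.inference_mode(), autotune():                 # 1. autotune
#         run_moe()
#     with AutoTuner.get().capture() as tactics, torch.inference_mode():
#         run_moe()                                            # 2. capture
#     replay_tactics_and_check(all_tactics=tactics, ...)       # 3. replay + check
#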
@pytest.mark.skip(reason="Temporarily skipped due to the long time to run the test")
@pytest.mark.parametrize(
    "dtype_activation,backend_type,quant_algo,seq_len,model_config,swiglu_alpha,swiglu_beta,swiglu_limit",
    TEST_PARAMS,
)
def test_moe_backend(
    dtype_activation: torch.dtype,
    backend_type: MoeBackendType,
    quant_algo: Optional[QuantAlgo],
    seq_len: int,
    model_config: MoeModelConfig,
    swiglu_alpha: float,
    swiglu_beta: float,
    swiglu_limit: float,
):
    """
    Test a MoE backend with autotune, capturing and replaying all tactics.

    This test verifies:
    1. Autotune works correctly with the backend
    2. All tactics are captured properly
    3. Different sequence lengths use appropriate tactics
    4. gptoss_style (SwiGLU with custom parameters) works correctly
    """
    # gptoss_style is True when any swiglu parameter deviates from its default
    # (alpha=1, beta=0, limit=inf)
    gptoss_style = swiglu_alpha != 1 or swiglu_beta != 0 or swiglu_limit != float("inf")

    # Note: Skip logic is handled at parametrize level via get_quick_skip_reason(),
    # which calls the backend's can_implement() and the should_skip_* functions.
    # This avoids entering the test function for invalid combinations and
    # significantly reduces collection time (from ~17 min to ~5 sec for 3400+
    # skipped tests).

    # Extract model parameters
    num_experts = model_config.num_experts
    top_k = model_config.top_k
    hidden_size = model_config.hidden_size
    intermediate_size = model_config.intermediate_size

    # Create mapping
    mapping = Mapping()
    mapping.rank = mpi_rank()

    with torch.device(f"cuda:{mapping.rank}"):
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        # Set up the autotuner's distributed state
        AutoTuner.get().setup_distributed_state(mapping)

        # Create routing method
        routing_method = RenormalizeMoeRoutingMethod(top_k=top_k)

        # Create test inputs
        x = torch.randn((seq_len, hidden_size), dtype=dtype_activation, device="cuda")
        router_logits = torch.randn((seq_len, num_experts), dtype=dtype_activation, device="cuda")

        # Get quantization parameters.
        # Pass backend_type to determine the scale format (DEEPGEMM/TRTLLM need E8M0 scales).
        quantize_util_cls, quant_config, quant_kwargs = get_test_quant_params(
            quant_algo, x, backend_type
        )

        # Create the quantize utility with gptoss_style parameters
        quantize_util = quantize_util_cls(
            num_experts=num_experts,
            dtype=dtype_activation,
            intermediate_size=intermediate_size,
            hidden_size=hidden_size,
            quant_config=quant_config,
            bias=gptoss_style,
            gptoss_style=gptoss_style,
            swiglu_alpha=swiglu_alpha if gptoss_style else None,
            swiglu_beta=swiglu_beta if gptoss_style else None,
            swiglu_limit=swiglu_limit if gptoss_style else None,
        )

        # Get swiglu tensors if gptoss_style is enabled
        swiglu_tensors = quantize_util.get_swiglu_tensors()

        # Create the backend first (needed for MXFP4_MXFP8 to get shapes)
        backend = create_test_backend(
            backend_type=backend_type,
            routing_method=routing_method,
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            dtype=dtype_activation,
            quant_config=quant_config,
            mapping=mapping,
            bias=gptoss_style,
            swiglu_alpha=swiglu_tensors["swiglu_alpha"] if swiglu_tensors else None,
            swiglu_beta=swiglu_tensors["swiglu_beta"] if swiglu_tensors else None,
            swiglu_limit=swiglu_tensors["swiglu_limit"] if swiglu_tensors else None,
        )

        # W4A8_MXFP4_MXFP8 requires different weights for the backend and the
        # reference due to different padding/alignment requirements
        ref_cls = quant_kwargs.pop("ref_cls", None)
        ref_module_kwargs = {}
        if quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8:
            weights, ref_weights, ref_module_kwargs = quantize_util.prepare_weights_from_backend(
                backend, **quant_kwargs
            )
        else:
            weights = quantize_util.create_weights(**quant_kwargs)
            ref_weights = weights

        backend.load_weights([weights])
        backend.post_load_weights()
        backend.cuda()

        # Create the reference module
        if ref_cls is not None:
            ref_fused_moe = quantize_util.create_ref_module(
                routing_method, ref_cls=ref_cls, **ref_module_kwargs
            )
        else:
            ref_fused_moe = quantize_util.create_ref_module(routing_method, **ref_module_kwargs)
        ref_fused_moe.load_weights([ref_weights])
        ref_fused_moe.cuda()

        # Clear the autotuner cache before the autotune phase
        AutoTuner.get().clear_cache()

        # Get the reference output first
        with torch.inference_mode():
            ref_output = ref_fused_moe.forward(x, router_logits)

        # Helper to run the MoE computation
        def run_moe():
            token_selected_experts, token_final_scales = routing_method.apply(router_logits)
            x_quantized, x_sf = backend.quantize_input(x, post_quant_comm=False)
            return run_backend_moe(
                backend,
                backend_type,
                x_quantized,
                x_sf,
                token_selected_experts,
                token_final_scales,
                dtype_activation,
                router_logits,
            )

        # Configure the AutoTuner for faster profiling (reduced warmup/repeat for unit tests)
        autotuner = AutoTuner.get()
        autotuner.warmup = 0  # default: 2
        autotuner.repeat = 1  # default: 10
        autotuner.stream_delay_micro_secs = 10  # default: 1000

        # Autotune phase: tune kernels to find the best tactics.
        # Use cache_path to speed up subsequent runs by reusing tuning results.
        with torch.inference_mode(), autotune(cache_path="/tmp/moe_autotuner_cache.json"):
            _ = run_moe()

        # Check whether this backend+quant_algo combination supports autotuner capture/replay
        if supports_autotuner_capture(backend_type, quant_algo):
            # Capture phase: record which tactics are used (requires actual execution)
            with AutoTuner.get().capture() as all_tactics, torch.inference_mode():
                _ = run_moe()

            # Replay phase: test each tactic for correctness.
            # Set fail_fast=True to stop on the first failure, False to run all
            # tactics and report a summary.
            replay_tactics_and_check(
                all_tactics=all_tactics,
                run_moe_fn=run_moe,
                check_accuracy_fn=ref_fused_moe.check_accuracy,
                ref_output=ref_output,
                backend_type=backend_type,
                quant_algo=quant_algo,
                fail_fast=False,  # Change to True to fail on the first error
            )
        else:
            # For backends that don't support autotuner capture/replay,
            # just run a simple accuracy check
            with torch.inference_mode():
                output = run_moe()
            ref_fused_moe.check_accuracy(output, ref_output)

@ -864,23 +864,13 @@ def test_fused_moe_fp8_blockwise_wide_ep(alltoall_method_type):
        [DefaultMoeRoutingMethod],
    ),
)
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
                         ids=lambda x: ""
                         if x == 0 else "enable_configurable_moe")
def test_fused_moe_fp8_blockwise_deepgemm(dtype,
                                          num_experts,
                                          seq_len,
                                          hidden_size,
                                          RoutingMethodCls,
                                          enable_configurable_moe,
                                          mocker,
                                          mapping=None):

    mocker.patch.dict(os.environ, {
        "ENABLE_CONFIGURABLE_MOE":
        "1" if enable_configurable_moe == 1 else "0"
    })

    SEQ_LEN = seq_len
    HIDDEN_SIZE = hidden_size
    INTERMEDIATE_SIZE = 256

@ -1388,25 +1378,7 @@ def test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu(ep_size, routing_method,
@pytest.mark.parametrize(
    "finalize_fusion", [True, False],
    ids=["enable_finalize_fusion", "disable_finalize_fusion"])
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
                         ids=lambda x: ""
                         if x == 0 else "enable_configurable_moe")
def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
                         enable_configurable_moe, mocker):

    if enable_configurable_moe == 1 and moe_backend not in [
            "TRTLLM", "CUTLASS"
    ]:
        pytest.skip(
            "ENABLE_CONFIGURABLE_MOE=1, only TRTLLM and CUTLASS backend are enabled"
        )

    mocker.patch.dict(
        os.environ, {
            "ENABLE_CONFIGURABLE_MOE":
            "1" if enable_configurable_moe == 1
            and moe_backend in ["TRTLLM", "CUTLASS"] else "0"
        })
def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion):

    run_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion)

@ -1417,17 +1389,8 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
@pytest.mark.parametrize("swiglu_beta", [0, 1], ids=lambda v: f"beta{v}")
@pytest.mark.parametrize("swiglu_limit", [float("inf"), 1],
                         ids=lambda v: f"limit{v}")
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
                         ids=lambda x: ""
                         if x == 0 else "enable_configurable_moe")
def test_fused_moe_nvfp4_gptoss_style(hidden_size, intermediate_size,
                                      swiglu_alpha, swiglu_beta, swiglu_limit,
                                      enable_configurable_moe, mocker):
    mocker.patch.dict(os.environ, {
        "ENABLE_CONFIGURABLE_MOE":
        "1" if enable_configurable_moe == 1 else "0"
    })

                                      swiglu_alpha, swiglu_beta, swiglu_limit):
    run_fused_moe_nvfp4(dtype=torch.bfloat16,
                        moe_backend="TRTLLM",
                        finalize_fusion=False,

@ -1686,15 +1649,7 @@ def run_fused_moe_nvfp4(dtype,
@pytest.mark.parametrize(
    "moe_backend",
    [pytest.param("TRTLLM", marks=skip_blackwell_geforce), "CUTLASS"])
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
                         ids=lambda x: ""
                         if x == 0 else "enable_configurable_moe")
def test_fused_moe_w4a8_nvfp4_fp8(moe_backend, enable_configurable_moe, mocker):
    mocker.patch.dict(os.environ, {
        "ENABLE_CONFIGURABLE_MOE":
        "1" if enable_configurable_moe == 1 else "0"
    })

def test_fused_moe_w4a8_nvfp4_fp8(moe_backend):
    dtype = torch.bfloat16
    mapping = Mapping()
    mapping.rank = mpi_rank()

@ -2109,20 +2064,7 @@ def test_fused_moe_w4afp8(dtype, weight_loading_mode):
@pytest.mark.parametrize("hidden_unpadded", [64, 192, 256])
@pytest.mark.parametrize("seq_len", [8, 128])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
                         ids=lambda x: ""
                         if x == 0 else "enable_configurable_moe")
def test_fused_moe_mxfp4_mxfp8(moe_backend, hidden_unpadded, seq_len, bias,
                               enable_configurable_moe, mocker):

    mocker.patch.dict(os.environ, {
        "ENABLE_CONFIGURABLE_MOE":
        "1" if enable_configurable_moe == 1 else "0"
    })

    if moe_backend == "CUTLASS" and hidden_unpadded % 128 != 0:
        pytest.skip()

def test_fused_moe_mxfp4_mxfp8(moe_backend, hidden_unpadded, seq_len, bias):
    SCALING_VECTOR_SIZE = 32
    dtype = torch.bfloat16
    SEQ_LEN = seq_len

@ -2379,17 +2321,7 @@ def test_fused_moe_mxfp4_mxfp8(moe_backend, hidden_unpadded, seq_len, bias,
        marks=[skip_pre_hopper, skip_blackwell, skip_blackwell_geforce]),
    ],
)
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
                         ids=lambda x: ""
                         if x == 0 else "enable_configurable_moe")
def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend,
                           enable_configurable_moe, mocker):

    mocker.patch.dict(os.environ, {
        "ENABLE_CONFIGURABLE_MOE":
        "1" if enable_configurable_moe == 1 else "0"
    })

def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend):
    mapping = Mapping()
    mapping.rank = mpi_rank()