[None][chore] Add test configurable moe module (#10575)

Signed-off-by: leslie-fang25 <leslief@nvidia.com>
Leslie Fang 2026-01-14 07:25:57 +08:00 committed by GitHub
parent ccdfa43a6e
commit bc119f5644
2 changed files with 555 additions and 0 deletions


@@ -0,0 +1,424 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.util import check_accuracy
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.modules.fused_moe import BaseMoeRoutingMethod
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
def get_test_quant_params(quant_algo, x):
"""
    Create the quantization utility class, QuantConfig, and per-algorithm kwargs for testing.
    Returns a (quantize_util_cls, quant_config, quant_kwargs) tuple; quant_config is None
    when quant_algo is None.
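
    Example (a minimal sketch; the shapes and sizes below are illustrative only):

        x = torch.randn((4, 512), dtype=torch.bfloat16, device="cuda")
        cls, cfg, kwargs = get_test_quant_params(QuantAlgo.FP8, x)
        util = cls(num_experts=8, dtype=torch.bfloat16, intermediate_size=512,
                   hidden_size=512, quant_config=cfg)
        weights = util.create_weights(**kwargs)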
"""
quantize_util_cls = None
quant_config = None
quant_kwargs = {}
if quant_algo is None:
quantize_util_cls = BaseQuantizeUtil
elif quant_algo == QuantAlgo.FP8:
quantize_util_cls = FP8QuantizeUtil
quant_config = QuantConfig(quant_algo=QuantAlgo.FP8)
_, x_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(x)
x_scale = x_scale.float().squeeze()
quant_kwargs["x_scale"] = x_scale
elif quant_algo == QuantAlgo.NVFP4:
quantize_util_cls = NVFP4QuantizeUtil
quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4)
x_sf_global = (448 * 6) / x.abs().max().float()
scaling_vector_size = 16
quant_kwargs["scaling_vector_size"] = scaling_vector_size
quant_kwargs["x_sf_global"] = x_sf_global
else:
        raise ValueError(f"Unsupported quant_algo: {quant_algo}")
return quantize_util_cls, quant_config, quant_kwargs
class RefGatedMLPFusedMoE(nn.Module):
"""
    RefGatedMLPFusedMoE is a reference MoE implementation built from one GatedMLP per expert,
    intended for correctness testing. Derived classes extend it with support for various
    quantization algorithms.
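
    A minimal usage sketch (routing_method, weights, x, and router_logits are assumed to be
    created elsewhere, e.g. as in the accompanying test):

        ref_moe = RefGatedMLPFusedMoE(
            num_experts=8,
            routing_method=routing_method,
            hidden_size=512,
            intermediate_size=512,
            dtype=torch.bfloat16,
        ).cuda()
        ref_moe.load_weights([weights])
        ref_output = ref_moe.forward(x, router_logits)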
"""
def __init__(
self,
num_experts: int,
routing_method: BaseMoeRoutingMethod,
hidden_size: int,
intermediate_size: int,
dtype: Optional[torch.dtype] = None,
model_config: Optional[ModelConfig] = None,
bias=False,
):
super().__init__()
self.num_experts = num_experts
self.routing_method = routing_method
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.bias = bias
self.dtype = dtype
if model_config is None:
model_config = ModelConfig()
self.quant_config = model_config.quant_config
self.experts = nn.ModuleList(
[
GatedMLP(
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
bias=bias,
dtype=self.dtype,
config=model_config,
use_cute_dsl_blockscaling_mm=False,
activation=F.silu,
)
for _ in range(self.num_experts)
]
)
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
assert hidden_states.shape[-1] == self.hidden_size
hidden_states = hidden_states.view(-1, self.hidden_size)
selected_experts, routing_weights = self.routing_method.apply(router_logits)
        # Accumulate in float32; cast back to the activation dtype at the end.
        final_hidden_states = torch.zeros(
            hidden_states.shape, dtype=torch.float32, device=hidden_states.device
        )
for expert_id in range(self.num_experts):
if not torch.any(selected_experts == expert_id):
continue
batch_idx, nth_expert = torch.where(selected_experts == expert_id)
expert_inputs = hidden_states[batch_idx]
output = self.experts[expert_id](expert_inputs)
final_hidden_states[batch_idx] += (
routing_weights[batch_idx, nth_expert, None] * output.float()
)
        final_hidden_states = final_hidden_states.reshape(hidden_states.shape)
        return final_hidden_states.to(hidden_states.dtype)
def load_weights(self, weights: List[Dict]):
assert len(weights) == 1
weights = weights[0]
for expert in range(self.num_experts):
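            # Pack the flat per-expert dict into the list-of-dicts format expected by the
            # gate_up_proj / down_proj load_weights: w1/w3 feed gate_up_proj, w2 feeds down_proj.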
gate_up_proj_weights = [{}, {}]
down_proj_weights = [{}]
gate_up_proj_weights[0]["weight"] = weights[f"{expert}.w1.weight"]
gate_up_proj_weights[1]["weight"] = weights[f"{expert}.w3.weight"]
down_proj_weights[0]["weight"] = weights[f"{expert}.w2.weight"]
if self.bias:
gate_up_proj_weights[0]["bias"] = weights[f"{expert}.w1.bias"]
gate_up_proj_weights[1]["bias"] = weights[f"{expert}.w3.bias"]
down_proj_weights[0]["bias"] = weights[f"{expert}.w2.bias"]
self.experts[expert].gate_up_proj.load_weights(gate_up_proj_weights)
self.experts[expert].down_proj.load_weights(down_proj_weights)
def check_accuracy(self, output, ref_output):
        # Use the same rtol, atol, and match percentage as test_fused_moe
check_accuracy(output, ref_output, rtol=2e-1, atol=2e-1, percent=0.984)
class FP8RefGatedMLPFusedMoE(RefGatedMLPFusedMoE):
"""
    A subclass of RefGatedMLPFusedMoE that serves as the FP8-quantized reference implementation
    for correctness testing.
"""
def load_weights(self, weights: List[Dict]):
assert len(weights) == 1
weights = weights[0]
for expert in range(self.num_experts):
gate_up_proj_weights = [{}, {}]
down_proj_weights = [{}]
gate_up_proj_weights[0]["weight"] = weights[f"{expert}.w1.weight"]
gate_up_proj_weights[1]["weight"] = weights[f"{expert}.w3.weight"]
down_proj_weights[0]["weight"] = weights[f"{expert}.w2.weight"]
if self.bias:
gate_up_proj_weights[0]["bias"] = weights[f"{expert}.w1.bias"]
gate_up_proj_weights[1]["bias"] = weights[f"{expert}.w3.bias"]
down_proj_weights[0]["bias"] = weights[f"{expert}.w2.bias"]
assert self.quant_config and self.quant_config.quant_algo == QuantAlgo.FP8
gate_up_proj_weights[0]["weight_scale"] = weights[f"{expert}.w1.weight_scale"]
gate_up_proj_weights[1]["weight_scale"] = weights[f"{expert}.w3.weight_scale"]
down_proj_weights[0]["weight_scale"] = weights[f"{expert}.w2.weight_scale"]
gate_up_proj_weights[0]["input_scale"] = weights[f"{expert}.w1.input_scale"]
gate_up_proj_weights[1]["input_scale"] = weights[f"{expert}.w3.input_scale"]
down_proj_weights[0]["input_scale"] = weights[f"{expert}.w2.input_scale"]
self.experts[expert].gate_up_proj.load_weights(gate_up_proj_weights)
self.experts[expert].down_proj.load_weights(down_proj_weights)
def check_accuracy(self, output, ref_output):
        # Use the same rtol, atol, and match percentage as test_fused_moe
check_accuracy(output, ref_output, rtol=4e-2, atol=1e-1, percent=0.99)
class NVFP4RefGatedMLPFusedMoE(RefGatedMLPFusedMoE):
"""
    A subclass of RefGatedMLPFusedMoE that serves as the NVFP4-quantized reference implementation
    for correctness testing.
"""
def load_weights(self, weights: List[Dict]):
assert len(weights) == 1
weights = weights[0]
for expert in range(self.num_experts):
gate_up_proj_weights = [{}, {}]
down_proj_weights = [{}]
gate_up_proj_weights[0]["weight"] = weights[f"{expert}.w1.weight"]
gate_up_proj_weights[1]["weight"] = weights[f"{expert}.w3.weight"]
down_proj_weights[0]["weight"] = weights[f"{expert}.w2.weight"]
if self.bias:
gate_up_proj_weights[0]["bias"] = weights[f"{expert}.w1.bias"]
gate_up_proj_weights[1]["bias"] = weights[f"{expert}.w3.bias"]
down_proj_weights[0]["bias"] = weights[f"{expert}.w2.bias"]
assert self.quant_config and self.quant_config.quant_algo == QuantAlgo.NVFP4, (
"expect quant_algo to be NVFP4 in load weights"
)
gate_up_proj_weights[0]["weight_scale"] = weights[f"{expert}.w1.weight_scale"]
gate_up_proj_weights[1]["weight_scale"] = weights[f"{expert}.w3.weight_scale"]
down_proj_weights[0]["weight_scale"] = weights[f"{expert}.w2.weight_scale"]
gate_up_proj_weights[0]["input_scale"] = weights[f"{expert}.w1.input_scale"]
gate_up_proj_weights[1]["input_scale"] = weights[f"{expert}.w3.input_scale"]
down_proj_weights[0]["input_scale"] = weights[f"{expert}.w2.input_scale"]
gate_up_proj_weights[0]["weight_scale_2"] = weights[f"{expert}.w1.weight_scale_2"]
gate_up_proj_weights[1]["weight_scale_2"] = weights[f"{expert}.w3.weight_scale_2"]
down_proj_weights[0]["weight_scale_2"] = weights[f"{expert}.w2.weight_scale_2"]
self.experts[expert].gate_up_proj.load_weights(gate_up_proj_weights)
self.experts[expert].down_proj.load_weights(down_proj_weights)
def check_accuracy(self, output, ref_output):
        # Use the same rtol and atol as test_fused_moe
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.15)
class BaseQuantizeUtil(ABC):
"""
    BaseQuantizeUtil is the base class for MoE correctness testing; it provides the interface
    for creating expert weights and reference modules, and can be extended for different
    quantization algorithms.
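
    A minimal flow sketch (routing_method is assumed; sizes are illustrative):

        util = BaseQuantizeUtil(num_experts=8, dtype=torch.bfloat16,
                                intermediate_size=512, hidden_size=512, quant_config=None)
        weights = util.create_weights()
        ref_moe = util.create_ref_module(routing_method)
        ref_moe.load_weights([weights])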
"""
def __init__(
self,
num_experts: int,
dtype: torch.dtype,
intermediate_size: int,
hidden_size: int,
quant_config: QuantConfig,
):
self.num_experts = num_experts
self.dtype = dtype
self.intermediate_size = intermediate_size
self.hidden_size = hidden_size
self.quant_config = quant_config
def create_weights(self, **quant_kwargs) -> Dict[str, torch.Tensor]:
"""
Create quantized weights for MoE experts.
"""
assert self.quant_config is None, "quant_config should be None for BaseQuantizeUtil"
weights = {}
for expert_id in range(self.num_experts):
w1_weight = torch.randn(
(self.intermediate_size, self.hidden_size), dtype=self.dtype, device="cuda"
)
w2_weight = torch.randn(
(self.hidden_size, self.intermediate_size), dtype=self.dtype, device="cuda"
)
w3_weight = torch.randn(
(self.intermediate_size, self.hidden_size), dtype=self.dtype, device="cuda"
)
weights[f"{expert_id}.w1.weight"] = w1_weight
weights[f"{expert_id}.w2.weight"] = w2_weight
weights[f"{expert_id}.w3.weight"] = w3_weight
return weights
def create_ref_module(self, routing_method, ref_cls=RefGatedMLPFusedMoE) -> torch.nn.Module:
"""
Create a reference module for correctness testing.
"""
ref_fused_moe = ref_cls(
num_experts=self.num_experts,
routing_method=routing_method,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
dtype=self.dtype,
model_config=ModelConfig(quant_config=self.quant_config),
)
return ref_fused_moe
class FP8QuantizeUtil(BaseQuantizeUtil):
"""
FP8QuantizeUtil inherits from BaseQuantizeUtil to support correctness testing for FP8 quantized MoE modules.
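
    For every expert, create_weights emits the keys consumed by
    FP8RefGatedMLPFusedMoE.load_weights (shown here for expert 0; w2 and w3 follow the same
    pattern):

        0.w1.weight, 0.w1.weight_scale, 0.w1.input_scale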
"""
def create_weights(self, **quant_kwargs) -> Dict[str, torch.Tensor]:
"""
Create quantized weights for MoE experts.
"""
assert self.quant_config is not None and self.quant_config.quant_algo == QuantAlgo.FP8, (
"expect quant_algo to be fp8"
)
weights = {}
for expert_id in range(self.num_experts):
w1_weight = torch.randn(
(self.intermediate_size, self.hidden_size), dtype=self.dtype, device="cuda"
)
w2_weight = torch.randn(
(self.hidden_size, self.intermediate_size), dtype=self.dtype, device="cuda"
)
w3_weight = torch.randn(
(self.intermediate_size, self.hidden_size), dtype=self.dtype, device="cuda"
)
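            # quantize_e4m3_per_tensor returns the quantized FP8 tensor together with its
            # per-tensor scale, which is stored below as the weight_scale entry.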
w1_weight_fp8, w1_weight_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(
w1_weight
)
w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda()
w2_weight_fp8, w2_weight_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(
w2_weight
)
w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda()
w3_weight_fp8, w3_weight_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(
w3_weight
)
w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda()
assert "x_scale" in quant_kwargs, "x_scale is required for FP8 quant"
x_scale = quant_kwargs["x_scale"]
w1_input_scale = x_scale.cuda()
w2_input_scale = x_scale.cuda()
w3_input_scale = x_scale.cuda()
weights[f"{expert_id}.w1.weight"] = w1_weight_fp8
weights[f"{expert_id}.w2.weight"] = w2_weight_fp8
weights[f"{expert_id}.w3.weight"] = w3_weight_fp8
weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale.float()
weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale.float()
weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale.float()
weights[f"{expert_id}.w1.input_scale"] = w1_input_scale
weights[f"{expert_id}.w2.input_scale"] = w2_input_scale
weights[f"{expert_id}.w3.input_scale"] = w3_input_scale
return weights
def create_ref_module(self, routing_method, ref_cls=FP8RefGatedMLPFusedMoE) -> torch.nn.Module:
"""
Create a reference module for correctness testing.
"""
return super().create_ref_module(routing_method, ref_cls)
class NVFP4QuantizeUtil(BaseQuantizeUtil):
"""
NVFP4QuantizeUtil inherits from BaseQuantizeUtil to support correctness testing for NVFP4 quantized MoE modules.
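
    For every expert, create_weights emits the packed NVFP4 weights, weight_scale entries
    holding the unswizzled per-block scales (viewed as float8_e4m3fn), input_scale entries
    equal to 1 / x_sf_global, and weight_scale_2 entries equal to 1 / the per-tensor global
    weight scale (w1 and w3 share a single global scale).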
"""
def create_weights(self, **quant_kwargs) -> Dict[str, torch.Tensor]:
"""
Create quantized weights for MoE experts.
"""
assert self.quant_config is not None and self.quant_config.quant_algo == QuantAlgo.NVFP4, (
"expect quant_algo to be NVFP4"
)
weights = {}
for expert_id in range(self.num_experts):
w1_weight = (
torch.randn(
(self.intermediate_size, self.hidden_size), dtype=self.dtype, device="cuda"
)
* 0.05
)
w2_weight = (
torch.randn(
(self.hidden_size, self.intermediate_size), dtype=self.dtype, device="cuda"
)
* 0.05
)
w3_weight = (
torch.randn(
(self.intermediate_size, self.hidden_size), dtype=self.dtype, device="cuda"
)
* 0.05
)
assert "scaling_vector_size" in quant_kwargs, (
"scaling_vector_size is required for NVFP4 quant"
)
assert "x_sf_global" in quant_kwargs, "x_sf_global is required for NVFP4 quant"
scaling_vector_size = quant_kwargs["scaling_vector_size"]
x_sf_global = quant_kwargs["x_sf_global"]
w1_sf_global = (448 * 6) / w1_weight.abs().max().float()
w2_sf_global = (448 * 6) / w2_weight.abs().max().float()
w3_sf_global = (448 * 6) / w3_weight.abs().max().float()
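            # 448 * 6 = max(e4m3) * max(e2m1): the per-tensor global scale is derived from
            # the absolute maximum of each weight.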
w3_w1_global = min(
w1_sf_global, w3_sf_global
) # w3 global and w1 global must be the same
# start to quantize
w1_weight_nvfp4, w1_sf_block_unswizzled = torch.ops.trtllm.fp4_quantize(
w1_weight, w3_w1_global, scaling_vector_size, False, False
)
w1_sf_block_unswizzled = w1_sf_block_unswizzled.view(self.intermediate_size, -1)
w2_weight_nvfp4, w2_sf_block_unswizzled = torch.ops.trtllm.fp4_quantize(
w2_weight, w2_sf_global, scaling_vector_size, False, False
)
w2_sf_block_unswizzled = w2_sf_block_unswizzled.view(self.hidden_size, -1)
w3_weight_nvfp4, w3_sf_block_unswizzled = torch.ops.trtllm.fp4_quantize(
w3_weight, w3_w1_global, scaling_vector_size, False, False
)
w3_sf_block_unswizzled = w3_sf_block_unswizzled.view(self.intermediate_size, -1)
w1_input_scale = x_sf_global.cuda()
w2_input_scale = x_sf_global.cuda()
w3_input_scale = x_sf_global.cuda()
weights[f"{expert_id}.w1.weight"] = w1_weight_nvfp4
weights[f"{expert_id}.w2.weight"] = w2_weight_nvfp4
weights[f"{expert_id}.w3.weight"] = w3_weight_nvfp4
weights[f"{expert_id}.w1.weight_scale"] = w1_sf_block_unswizzled.view(
torch.float8_e4m3fn
).cuda()
weights[f"{expert_id}.w2.weight_scale"] = w2_sf_block_unswizzled.view(
torch.float8_e4m3fn
).cuda()
weights[f"{expert_id}.w3.weight_scale"] = w3_sf_block_unswizzled.view(
torch.float8_e4m3fn
).cuda()
weights[f"{expert_id}.w1.input_scale"] = 1.0 / w1_input_scale
weights[f"{expert_id}.w2.input_scale"] = 1.0 / w2_input_scale
weights[f"{expert_id}.w3.input_scale"] = 1.0 / w3_input_scale
weights[f"{expert_id}.w1.weight_scale_2"] = 1.0 / w3_w1_global
weights[f"{expert_id}.w2.weight_scale_2"] = 1.0 / w2_sf_global
weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global
return weights
def create_ref_module(
self, routing_method, ref_cls=NVFP4RefGatedMLPFusedMoE
) -> torch.nn.Module:
"""
Create a reference module for correctness testing.
"""
return super().create_ref_module(routing_method, ref_cls)


@@ -0,0 +1,131 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
from _torch.modules.moe.quantize_utils import get_test_quant_params
from transformers.configuration_utils import PretrainedConfig
from utils.util import getSMVersion
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.modules.fused_moe import RenormalizeMoeRoutingMethod, create_moe
from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantAlgo
@pytest.mark.parametrize(
"quant_algo",
[
None,
QuantAlgo.FP8,
QuantAlgo.NVFP4,
],
ids=lambda val: f"quant_algo={val}",
)
@pytest.mark.parametrize(
"moe_backend",
[
"CUTLASS",
"TRTLLM",
],
ids=lambda val: f"moe_backend={val}",
)
@pytest.mark.parametrize(
"dtype",
[
torch.float16,
torch.bfloat16,
],
ids=lambda val: f"dtype={val}",
)
def test_moe(dtype, moe_backend, quant_algo, mocker):
    # Enable the configurable MoE path for this test
mocker.patch.dict(os.environ, {"ENABLE_CONFIGURABLE_MOE": "1"})
if moe_backend == "TRTLLM":
if dtype == torch.float16 and quant_algo == QuantAlgo.NVFP4:
pytest.skip("TRTLLM NVFP4 MoE backend does not support float16 yet")
if quant_algo == QuantAlgo.NVFP4 and getSMVersion() < 100:
pytest.skip("This test is not supported in pre-Blackwell architecture")
# Hardcode some parameters for testing
    # Activation and weight shapes
seq_len = 4
top_k = 2
num_experts = 8
hidden_size = 512
intermediate_size = 512
# Other parameters
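    # finalize_fusion maps to moe_disable_finalize_fusion in the ModelConfig below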
finalize_fusion = True
# Create mapping for current rank
mapping = Mapping()
mapping.rank = mpi_rank()
with torch.device(f"cuda:{mapping.rank}"):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
        # Create routing method
routing_method = RenormalizeMoeRoutingMethod(top_k=top_k)
        # Create activations, router logits, and quantized weights
x = torch.randn((seq_len, hidden_size), dtype=dtype, device="cuda")
router_logits = torch.randn((seq_len, num_experts), dtype=dtype, device="cuda")
quantize_util_cls, quant_config, quant_kwargs = get_test_quant_params(quant_algo, x)
quantize_util = quantize_util_cls(
num_experts=num_experts,
dtype=dtype,
intermediate_size=intermediate_size,
hidden_size=hidden_size,
quant_config=quant_config,
)
weights = quantize_util.create_weights(**quant_kwargs)
# Create pretrained config
pretrained_config = PretrainedConfig()
pretrained_config.num_experts = num_experts
pretrained_config.hidden_size = hidden_size
pretrained_config.intermediate_size = intermediate_size
pretrained_config.torch_dtype = dtype
# Create fused MoE module
fused_moe = create_moe(
routing_method=routing_method,
reduce_results=True,
model_config=ModelConfig(
pretrained_config=pretrained_config,
quant_config=quant_config,
moe_backend=moe_backend,
moe_disable_finalize_fusion=not finalize_fusion,
),
)
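        # Load the generated weights, then run backend-specific post-processing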
fused_moe.load_weights([weights])
fused_moe.post_load_weights()
fused_moe.cuda()
ref_fused_moe = quantize_util.create_ref_module(routing_method)
ref_fused_moe.load_weights([weights])
ref_fused_moe.cuda()
# Evaluate the outputs
with torch.inference_mode():
ref_output = ref_fused_moe.forward(x, router_logits)
output = fused_moe.forward(x, router_logits)
ref_fused_moe.check_accuracy(output, ref_output)