[None][chore] NVFP4 MoE - Move weights transformation to fusion phase… (#10803)

Signed-off-by: Tal Cherckez <tcherckez@nvl72070-T11.cm.cluster>
Signed-off-by: Tal Cherckez <tcherckez@nvl72039-T03.cm.cluster>
Signed-off-by: Tal Cherckez <tcherckez@nvl72098-T11.cm.cluster>
Signed-off-by: tcherckez-nvidia <127761168+tcherckez-nvidia@users.noreply.github.com>
Co-authored-by: Tal Cherckez <tcherckez@nvl72070-T11.cm.cluster>
Co-authored-by: Tal Cherckez <tcherckez@nvl72039-T03.cm.cluster>
Co-authored-by: Tal Cherckez <tcherckez@nvl72098-T11.cm.cluster>
tcherckez-nvidia 2026-01-22 13:08:05 +02:00 committed by GitHub
parent 0243abee22
commit 128d4ac5be
8 changed files with 313 additions and 97 deletions

View File

@@ -13,16 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from tensorrt_llm._torch.auto_deploy.custom_ops.quant import (
TRTLLM_NVFP4_COLUMN_SIZE,
TRTLLM_NVFP4_ROW_SIZE,
TRTLLM_NVFP4_SCALING_VECTOR_SIZE,
)
from tensorrt_llm._torch.auto_deploy.custom_ops.quant import TRTLLM_NVFP4_SCALING_VECTOR_SIZE
from tensorrt_llm._torch.utils import ActivationType
@@ -262,14 +255,6 @@ def trtllm_quant_nvfp4_moe_fused(
mlp_style: "gated_mlp" or "mlp"
act_fn: "silu" for gated_mlp, "relu2" for mlp
"""
NVFP4_BLOCK_SIZE = TRTLLM_NVFP4_SCALING_VECTOR_SIZE
FP4_PER_UINT8 = 2
_, fc1_inter_size, _ = fc1_expert_weights_fp4.shape
n_experts, hidden_size, inter_size = fc2_expert_weights_fp4.shape
# Convert the inter_size from number of uint8 elements to number of FP4 elements.
inter_size *= FP4_PER_UINT8
# Validate block scale tensors are 3D (padding requirements handled below)
assert fc1_weight_blockscale_fp8.ndim == 3, "fc1_weight_blockscale_fp8 must be 3D"
@@ -280,7 +265,7 @@ def trtllm_quant_nvfp4_moe_fused(
if x.dtype in (torch.float16, torch.bfloat16):
x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize(
x, fc1_act_global_scale, NVFP4_BLOCK_SIZE
x, fc1_act_global_scale, TRTLLM_NVFP4_SCALING_VECTOR_SIZE
)
output_dtype = x.dtype
else:
@@ -288,66 +273,6 @@ def trtllm_quant_nvfp4_moe_fused(
input_blockscale = None
output_dtype = x.dtype
# Pad inter_size to be divisible by TRTLLM_NVFP4_ROW_SIZE
inter_size_padded = math.ceil(inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE
fc1_inter_size_padded = (
math.ceil(fc1_inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE
)
hidden_size_padded = (
math.ceil(hidden_size / TRTLLM_NVFP4_COLUMN_SIZE) * TRTLLM_NVFP4_COLUMN_SIZE
)
inter_size_needs_padding = (is_gated_mlp and fc1_inter_size_padded != fc1_inter_size) or (
not is_gated_mlp and inter_size_padded != inter_size
)
hidden_size_needs_padding = hidden_size % TRTLLM_NVFP4_COLUMN_SIZE != 0
hidden_blocks_padded = hidden_size_padded // NVFP4_BLOCK_SIZE
inter_blocks_padded = inter_size_padded // NVFP4_BLOCK_SIZE
if inter_size_needs_padding or hidden_size_needs_padding:
# Pad fc1_expert_weights_fp4: [E, I, H/2] or [E, 2*I, H/2]
fc1_padded = fc1_expert_weights_fp4.new_zeros(
fc1_expert_weights_fp4.size(0),
fc1_inter_size_padded,
hidden_size_padded // FP4_PER_UINT8,
)
fc1_padded[:, :fc1_inter_size, : hidden_size // FP4_PER_UINT8] = fc1_expert_weights_fp4
fc1_expert_weights_fp4 = fc1_padded
# Block scales may already be padded, so check actual size
fc1_bs_size1 = fc1_weight_blockscale_fp8.size(1)
fc1_bs_size2 = fc1_weight_blockscale_fp8.size(2)
if fc1_bs_size1 < fc1_inter_size_padded or fc1_bs_size2 < hidden_blocks_padded:
fc1_blockscale_fp8_padded = fc1_weight_blockscale_fp8.new_zeros(
n_experts, fc1_inter_size_padded, hidden_blocks_padded
)
fc1_blockscale_fp8_padded[:, :fc1_bs_size1, :fc1_bs_size2] = fc1_weight_blockscale_fp8
fc1_weight_blockscale_fp8 = fc1_blockscale_fp8_padded
fc2_padded = fc2_expert_weights_fp4.new_zeros(
n_experts, hidden_size_padded, inter_size_padded // FP4_PER_UINT8
)
assert inter_size % NVFP4_BLOCK_SIZE == 0, (
f"inter_size {inter_size} must be divisible by {NVFP4_BLOCK_SIZE}"
)
fc2_padded[:, :hidden_size, : inter_size // FP4_PER_UINT8] = fc2_expert_weights_fp4
fc2_expert_weights_fp4 = fc2_padded
# Pad fc2_weight_blockscale_fp8: [E, H, I/16]
# Block scales may already be padded, so check actual size
fc2_bs_size1 = fc2_weight_blockscale_fp8.size(1)
fc2_bs_size2 = fc2_weight_blockscale_fp8.size(2)
if fc2_bs_size1 < hidden_size_padded or fc2_bs_size2 < inter_blocks_padded:
fc2_blockscale_fp8_padded = fc2_weight_blockscale_fp8.new_zeros(
n_experts, hidden_size_padded, inter_blocks_padded
)
fc2_blockscale_fp8_padded[:, :fc2_bs_size1, :fc2_bs_size2] = fc2_weight_blockscale_fp8
fc2_weight_blockscale_fp8 = fc2_blockscale_fp8_padded
# else: block scales already have correct padded size, use as-is
# quant_scales is described by this code:
# https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015
quant_scales = [
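For context, the padding logic deleted above ran inside the custom op itself, rounding the intermediate and hidden dimensions up to the fused kernel's alignment requirements and zero-padding the FP4 weights and FP8 block scales on the fly; this commit moves that work into the one-time fusion pass shown below. A minimal sketch of the alignment arithmetic, using the constants from custom_ops/quant.py and a hypothetical ceil_to_multiple helper with illustrative sizes:

import math

TRTLLM_NVFP4_SCALING_VECTOR_SIZE = 16  # FP4 elements covered by one FP8 block scale
TRTLLM_NVFP4_ROW_SIZE = 128            # row alignment required by the fused MoE kernel
TRTLLM_NVFP4_COLUMN_SIZE = 4           # column alignment required by the fused MoE kernel

def ceil_to_multiple(x: int, multiple: int) -> int:
    # Hypothetical helper: round x up to the next multiple of `multiple`.
    return math.ceil(x / multiple) * multiple

# Illustrative sizes: an unaligned intermediate dimension gets rounded up before
# the FP4 weights and FP8 block scales are zero-padded to match.
inter_size, hidden_size = 200, 512
inter_size_padded = ceil_to_multiple(inter_size, TRTLLM_NVFP4_ROW_SIZE)        # 256
hidden_size_padded = ceil_to_multiple(hidden_size, TRTLLM_NVFP4_COLUMN_SIZE)   # 512, already aligned
inter_blocks_padded = inter_size_padded // TRTLLM_NVFP4_SCALING_VECTOR_SIZE    # 16 scales per row
hidden_blocks_padded = hidden_size_padded // TRTLLM_NVFP4_SCALING_VECTOR_SIZE  # 32 scales per row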

View File

@@ -14,6 +14,7 @@ TRTLLM_FP4_OP_AVAILABLE = True
TRTLLM_NVFP4_SCALING_VECTOR_SIZE = 16
TRTLLM_NVFP4_ROW_SIZE = 128
TRTLLM_NVFP4_COLUMN_SIZE = 4
TRTLLM_NVFP4_PACKING_FACTOR = 2
@torch.library.custom_op("auto_deploy::torch_quant_fn", mutates_args=())

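The new TRTLLM_NVFP4_PACKING_FACTOR constant makes the byte packing explicit: two FP4 values occupy one uint8, and every 16 FP4 values along the reduction dimension share one FP8 block scale. A tiny sketch of the resulting shape bookkeeping, with sizes that are purely illustrative:

TRTLLM_NVFP4_PACKING_FACTOR = 2        # two FP4 values per uint8 byte
TRTLLM_NVFP4_SCALING_VECTOR_SIZE = 16  # FP4 elements per FP8 block scale

N, K = 256, 512                                            # logical FP4 weight shape [N, K]
packed_cols = K // TRTLLM_NVFP4_PACKING_FACTOR             # 256 uint8 columns in storage
scales_per_row = K // TRTLLM_NVFP4_SCALING_VECTOR_SIZE     # 32 FP8 block scales per row
fp4_elements = packed_cols * TRTLLM_NVFP4_PACKING_FACTOR   # 512, recovering K from packed storage,
                                                           # as the MoE op does for inter_size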
View File

@@ -1,3 +1,4 @@
import math
from collections import defaultdict
from functools import partial
from typing import Dict, List, Literal, Optional, Tuple, Type
@@ -8,6 +9,7 @@ from torch.fx import GraphModule, Node
from tensorrt_llm._torch.utils import ActivationType
from ...custom_ops.quant import TRTLLM_NVFP4_PACKING_FACTOR, TRTLLM_NVFP4_SCALING_VECTOR_SIZE
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code
@@ -1627,14 +1629,13 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
"w3_weight",
"w1_input_scale",
"w2_input_scale",
"w3_input_scale",
"w1_weight_scale",
"w2_weight_scale",
"w3_weight_scale",
"w1_alpha",
"w2_alpha",
"w3_alpha",
"is_gated_mlp",
"act_fn",
)
def _stack(param_list, dim=0, device=None, dtype=None):
@@ -1648,13 +1649,10 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
def _prepare_args_cutlass_format_nvfp4():
if is_gated_mlp:
# For gated MLP, concatenate w1 and w3 as [w3, w1]
fc1_expert_weights = torch.cat(
[w3_stacked, w1_stacked], dim=1
).contiguous() # [E, 2*I, H]
fc1_act_scale = torch.cat(
[w3_input_scale_stacked, w1_input_scale_stacked], dim=1
).contiguous()
fc1_alpha_stacked = torch.cat([w3_alpha_stacked, w1_alpha_stacked], dim=1).contiguous()
fc1_expert_weights = torch.cat([w3_stacked, w1_stacked], dim=1).contiguous()
# w3's input scale and alpha are expected to match w1's, so w1's values are reused here
fc1_act_scale = w1_input_scale_stacked
fc1_alpha_stacked = w1_alpha_stacked
fc1_weight_blockscale_fp8_stacked = torch.cat(
[w3_weight_blockscale_fp8_stacked, w1_weight_blockscale_fp8_stacked], dim=1
).contiguous()
@@ -1666,6 +1664,7 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
fc2_expert_weights = w2_stacked
fc2_act_scale = w2_input_scale_stacked
fc2_weight_blockscale_fp8_stacked = w2_weight_blockscale_fp8_stacked
new_key_fc1_expert_weights = f"nvfp4_moe_w3_w1_stacked_{fused_key_counter}"
new_key_fc2_expert_weights = f"nvfp4_moe_w2_stacked_{fused_key_counter}"
@@ -1681,13 +1680,57 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
new_key_fc1_alpha = f"nvfp4_moe_w1_alpha_stacked_{fused_key_counter}"
new_key_fc2_alpha = f"nvfp4_moe_w2_alpha_stacked_{fused_key_counter}"
# Pad fc1_expert_weights to match the already padded scales
fc1_pad_size = fc1_weight_blockscale_fp8_stacked.shape[1] - fc1_expert_weights.shape[1]
if fc1_pad_size > 0:
fc1_expert_weights = torch.nn.functional.pad(
fc1_expert_weights, (0, 0, 0, fc1_pad_size), mode="constant", value=0
)
# Need to update fc2 scales and weights to match the padded size of fc1,
# as they share the same intermediate dimension.
target_intermediate = fc1_weight_blockscale_fp8_stacked.shape[1]
TRTLLM_NVFP4_SCALING_VECTOR_NUM_ELEMENTS = TRTLLM_NVFP4_SCALING_VECTOR_SIZE
TRTLLM_NVFP4_SCALING_BYTES_SIZE = (
TRTLLM_NVFP4_SCALING_VECTOR_NUM_ELEMENTS // TRTLLM_NVFP4_PACKING_FACTOR
)
target_n_blocks = target_intermediate // TRTLLM_NVFP4_SCALING_VECTOR_NUM_ELEMENTS
padded_target_n_blocks = (
math.ceil(target_n_blocks / TRTLLM_NVFP4_SCALING_BYTES_SIZE)
* TRTLLM_NVFP4_SCALING_BYTES_SIZE
)
fc2_blocks_pad = padded_target_n_blocks - fc2_weight_blockscale_fp8_stacked.shape[2]
if fc2_blocks_pad > 0:
# unswizzle fc2 scales
fc2_blockscale_shape = list(fc2_weight_blockscale_fp8_stacked.shape)
fc2_blockscale_shape[2] = padded_target_n_blocks
fc2_weight_blockscale_fp8_stacked = torch.ops.trtllm.block_scale_interleave_reverse(
fc2_weight_blockscale_fp8_stacked.view(torch.uint8)
)
fc2_weight_blockscale_fp8_stacked = torch.nn.functional.pad(
fc2_weight_blockscale_fp8_stacked, (0, fc2_blocks_pad), mode="constant", value=0
)
fc2_weight_blockscale_fp8_stacked = (
torch.ops.trtllm.block_scale_interleave(fc2_weight_blockscale_fp8_stacked)
.view(torch.float8_e4m3fn)
.reshape(fc2_blockscale_shape)
)
fc2_expert_weights = torch.nn.functional.pad(
fc2_expert_weights,
(0, fc1_pad_size // TRTLLM_NVFP4_PACKING_FACTOR, 0, 0),
mode="constant",
value=0,
).view(torch.uint8)
# FP4 weights are already packed as uint8, don't convert dtype
_register_parameter(gm, new_key_fc1_expert_weights, fc1_expert_weights)
_register_parameter(gm, new_key_fc2_expert_weights, fc2_expert_weights)
_register_parameter(
gm, new_key_fc1_weight_blockscale_fp8, fc1_weight_blockscale_fp8_stacked
)
_register_parameter(gm, new_key_fc2_weight_blockscale_fp8, w2_weight_blockscale_fp8_stacked)
_register_parameter(
gm, new_key_fc2_weight_blockscale_fp8, fc2_weight_blockscale_fp8_stacked
)
_register_parameter(gm, new_key_fc1_act_scale, fc1_act_scale)
_register_parameter(gm, new_key_fc2_act_scale, fc2_act_scale)
_register_parameter(gm, new_key_fc1_alpha, fc1_alpha_stacked)
@@ -1707,7 +1750,11 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
graph.get_attr(new_key_fc1_alpha),
graph.get_attr(new_key_fc2_alpha),
)
return args
kwargs = {
"is_gated_mlp": is_gated_mlp,
"act_fn": act_fn,
}
return args, kwargs
fused_key_counter = 0
graph = gm.graph
@@ -1727,14 +1774,13 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
w3_list,
w1_input_scale,
w2_input_scale,
w3_input_scale,
w1_weight_scale,
w2_weight_scale,
w3_weight_scale,
w1_alpha,
w2_alpha,
w3_alpha,
is_gated_mlp,
act_fn,
) = _extract_op_args(node)
# Stack the actual tensor values (fast, like in quantize_moe.py)
@@ -1746,7 +1792,6 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
# Scales are buffers, not parameters
w1_input_scale_stacked = _stack(w1_input_scale, dim=0)
w2_input_scale_stacked = _stack(w2_input_scale, dim=0)
w3_input_scale_stacked = _stack(w3_input_scale, dim=0, device=device, dtype=dtype)
# Use .view(), not .to(), to reinterpret the bytes as float8 rather than converting values
w1_weight_blockscale_fp8_stacked = _stack(w1_weight_scale, dim=0).view(torch.float8_e4m3fn)
@@ -1757,9 +1802,8 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
w1_alpha_stacked = _stack(w1_alpha, dim=0)
w2_alpha_stacked = _stack(w2_alpha, dim=0)
w3_alpha_stacked = _stack(w3_alpha, dim=0, device=device, dtype=dtype)
args = _prepare_args_cutlass_format_nvfp4()
args, kwargs = _prepare_args_cutlass_format_nvfp4()
fused_key_counter += 1
@@ -1768,7 +1812,7 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
new_node = graph.call_function(
replacement_op,
args,
kwargs=node.kwargs,
kwargs=kwargs,
)
node.replace_all_uses_with(new_node)
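The fc2 block-scale handling above is the subtle part of this pass: the stacked scales are stored in the kernel's interleaved (swizzled) layout, so they cannot be zero-padded along the last dimension directly. A standalone sketch of that round-trip as a hypothetical helper, assuming the trtllm custom ops are loaded and fc2_bs is the swizzled [E, H, I/16] float8_e4m3fn tensor:

import torch
import torch.nn.functional as F

def pad_fc2_blockscales(fc2_bs: torch.Tensor, padded_target_n_blocks: int) -> torch.Tensor:
    # Sketch only: mirrors the unswizzle -> pad -> reswizzle sequence used in the pass above.
    pad = padded_target_n_blocks - fc2_bs.shape[2]
    if pad <= 0:
        return fc2_bs
    out_shape = list(fc2_bs.shape)
    out_shape[2] = padded_target_n_blocks
    # Undo the interleave so the zero padding lands on real block positions.
    plain = torch.ops.trtllm.block_scale_interleave_reverse(fc2_bs.view(torch.uint8))
    plain = F.pad(plain, (0, pad), mode="constant", value=0)
    # Re-apply the interleave and reinterpret the bytes as FP8 block scales.
    return (
        torch.ops.trtllm.block_scale_interleave(plain)
        .view(torch.float8_e4m3fn)
        .reshape(out_shape)
    )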

View File

@@ -180,6 +180,9 @@ nvidia/Nemotron-MOE:
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 86.884
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
accuracy: 63.268
nvidia/Llama-3.1-Nemotron-Nano-8B-v1:
- accuracy: 37.15
- quant_algo: FP8

View File

@@ -284,6 +284,9 @@ nvidia/Nemotron-MOE:
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 73.879
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
accuracy: 70.879
microsoft/Phi-4-mini-instruct:
- accuracy: 68.98
- quant_algo: FP8

View File

@@ -14,6 +14,7 @@
# limitations under the License.
import pytest
from defs.conftest import skip_pre_blackwell
from test_common.llm_data import hf_id_to_local_model_dir, llm_models_root
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
@@ -142,7 +143,7 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
MODEL_NAME = "nvidia/Nemotron-MOE"
MODEL_PATH_BF16 = f"{llm_models_root()}/Nemotron-Nano-3-30B-A3.5B-dev-1024"
MODEL_PATH_FP8 = f"{llm_models_root()}/Nemotron-Nano-3-30B-A3.5B-FP8-KVFP8-dev"
MODEL_PATH_NVFP4 = f"{llm_models_root()}/Nemotron-3-Nano-30B-A3B-NVFP4"
MODEL_PATH_NVFP4 = f"{llm_models_root()}/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4"
def get_default_kwargs(self):
return {
@@ -220,11 +221,13 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@pytest.mark.skip(reason="NVFP4 model is not in the CI yet")
def test_nvfp4(self):
@skip_pre_blackwell
@pytest.mark.parametrize("world_size", [1, 2, 4])
def test_nvfp4(self, world_size):
kwargs = self.get_default_kwargs()
with AutoDeployLLM(model=self.MODEL_PATH_NVFP4,
tokenizer=self.MODEL_PATH_NVFP4,
world_size=world_size,
**kwargs) as llm:
# Manually set quant_config for NVFP4 model to get the accuracy threshold
llm.args.quant_config.quant_algo = QuantAlgo.NVFP4

View File

@@ -218,3 +218,6 @@ l0_dgx_b200:
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_bf16
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_fp8[4]
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_fp8[8]
- accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_nvfp4[1]
- accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_nvfp4[2]
- accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_nvfp4[4]

View File

@@ -10,6 +10,7 @@ from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp4_global_scale
from tensorrt_llm._torch.utils import ActivationType
class BlockSparseTop2MLP(nn.Module):
@@ -416,3 +417,236 @@ def test_fuse_moe_cleanup():
assert mem_after <= mem_before, (
f"CUDA memory increased after fusion: before={mem_before} after={mem_after}"
)
class MoEOpModelNVFP4(nn.Module):
"""MoE model using torch_quant_nvfp4_moe op for testing fusion to trtllm_quant_nvfp4_moe_fused.
This model creates weights with 3D block scales that are compatible with
the trtllm fused MoE kernel.
"""
def __init__(
self,
hidden_size=512, # Already aligned to all requirements (16, 128, etc.)
intermediate_size=256, # Already aligned - no padding needed
num_experts=3,
top_k=2,
dtype=torch.bfloat16,
is_gated_mlp=True,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_experts = num_experts
self.top_k = top_k
self.dtype = dtype
self.is_gated_mlp = is_gated_mlp
# Constants for NVFP4 layout
NVFP4_BLOCK_SIZE = 16
NVFP4_PACK_FACTOR = 2
FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0
self.gate = nn.Linear(hidden_size, num_experts, dtype=dtype)
# Create sample input for scale computation
sample_input = torch.randn(2, hidden_size, dtype=dtype, device="cuda") * 0.01
inp_scale = fp4_global_scale(sample_input)
# Per-expert quantized weights and scales
self.w1_weight = nn.ParameterList()
self.w2_weight = nn.ParameterList()
self.w3_weight = nn.ParameterList() if is_gated_mlp else None
for i in range(num_experts):
w1_fp32 = torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype) * 0.01
w2_fp32 = torch.randn(hidden_size, intermediate_size, device="cuda", dtype=dtype) * 0.01
# Compute global scales
w1_amax = torch.abs(w1_fp32).max().to(torch.float32)
w2_amax = torch.abs(w2_fp32).max().to(torch.float32)
scale_1 = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
scale_2 = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
# Quantize weights (non-swizzled layout)
w1_fp4, w1_bs = torch.ops.trtllm.fp4_quantize(w1_fp32, scale_1, NVFP4_BLOCK_SIZE, False)
w2_fp4, w2_bs = torch.ops.trtllm.fp4_quantize(w2_fp32, scale_2, NVFP4_BLOCK_SIZE, False)
# fp4_quantize pads block scales but not weights - infer padded dims from block scale size
_, w1_k_packed = w1_fp4.shape
_, w2_k_packed = w2_fp4.shape
w1_k_padded = w1_k_packed * NVFP4_PACK_FACTOR # Convert from uint8 to FP4 element count
w2_k_padded = w2_k_packed * NVFP4_PACK_FACTOR
# Calculate padded N dimension from block scale tensor size
w1_n_padded = w1_bs.numel() // (w1_k_padded // NVFP4_BLOCK_SIZE)
w2_n_padded = w2_bs.numel() // (w2_k_padded // NVFP4_BLOCK_SIZE)
# Reshape block scales to the per-expert 2D layout [N_padded, K/block]; they become 3D once stacked across experts
w1_bs_3d = w1_bs.view(w1_n_padded, w1_k_padded // NVFP4_BLOCK_SIZE)
w2_bs_3d = w2_bs.view(w2_n_padded, w2_k_padded // NVFP4_BLOCK_SIZE)
self.w1_weight.append(nn.Parameter(w1_fp4, requires_grad=False))
self.w2_weight.append(nn.Parameter(w2_fp4, requires_grad=False))
self.register_buffer(f"w1_input_scale_{i}", inp_scale)
self.register_buffer(f"w2_input_scale_{i}", inp_scale)
self.register_buffer(f"w1_weight_scale_{i}", w1_bs_3d.contiguous())
self.register_buffer(f"w2_weight_scale_{i}", w2_bs_3d.contiguous())
self.register_buffer(f"w1_alpha_{i}", 1.0 / (inp_scale * scale_1))
self.register_buffer(f"w2_alpha_{i}", 1.0 / (inp_scale * scale_2))
if is_gated_mlp:
w3_fp32 = (
torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype) * 0.01
)
w3_amax = torch.abs(w3_fp32).max().to(torch.float32)
scale_3 = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w3_amax
w3_fp4, w3_bs = torch.ops.trtllm.fp4_quantize(
w3_fp32, scale_3, NVFP4_BLOCK_SIZE, False
)
# Infer padded dimensions for w3
_, w3_k_packed = w3_fp4.shape
w3_k_padded = w3_k_packed * NVFP4_PACK_FACTOR
w3_n_padded = w3_bs.numel() // (w3_k_padded // NVFP4_BLOCK_SIZE)
w3_bs_3d = w3_bs.view(w3_n_padded, w3_k_padded // NVFP4_BLOCK_SIZE)
self.w3_weight.append(nn.Parameter(w3_fp4, requires_grad=False))
self.register_buffer(f"w3_input_scale_{i}", inp_scale)
self.register_buffer(f"w3_weight_scale_{i}", w3_bs_3d.contiguous())
self.register_buffer(f"w3_alpha_{i}", 1.0 / (inp_scale * scale_3))
def forward(self, x: torch.Tensor) -> torch.Tensor:
router_logits = self.gate(x)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(x.dtype)
w1_list = list(self.w1_weight)
w2_list = list(self.w2_weight)
w3_list = list(self.w3_weight) if self.is_gated_mlp else []
w1_input_scale = [getattr(self, f"w1_input_scale_{i}") for i in range(self.num_experts)]
w2_input_scale = [getattr(self, f"w2_input_scale_{i}") for i in range(self.num_experts)]
w3_input_scale = (
[getattr(self, f"w3_input_scale_{i}") for i in range(self.num_experts)]
if self.is_gated_mlp
else []
)
w1_weight_scale = [getattr(self, f"w1_weight_scale_{i}") for i in range(self.num_experts)]
w2_weight_scale = [getattr(self, f"w2_weight_scale_{i}") for i in range(self.num_experts)]
w3_weight_scale = (
[getattr(self, f"w3_weight_scale_{i}") for i in range(self.num_experts)]
if self.is_gated_mlp
else []
)
w1_alpha = [getattr(self, f"w1_alpha_{i}") for i in range(self.num_experts)]
w2_alpha = [getattr(self, f"w2_alpha_{i}") for i in range(self.num_experts)]
w3_alpha = (
[getattr(self, f"w3_alpha_{i}") for i in range(self.num_experts)]
if self.is_gated_mlp
else []
)
out = torch.ops.auto_deploy.torch_quant_nvfp4_moe(
x,
selected_experts,
routing_weights,
w1_list,
w2_list,
w3_list,
w1_input_scale,
w2_input_scale,
w3_input_scale,
w1_weight_scale,
w2_weight_scale,
w3_weight_scale,
w1_alpha,
w2_alpha,
w3_alpha,
is_gated_mlp=self.is_gated_mlp,
act_fn=ActivationType.Silu if self.is_gated_mlp else ActivationType.Relu2,
)
return out
def get_input(self, device, dtype=torch.bfloat16):
return torch.randn(2, self.hidden_size, device=device, dtype=dtype) * 0.01
@pytest.mark.skipif(
not fp4_compatible() or not trtllm_ops_available(),
reason="Requires FP4 + TRTLLM support",
)
@pytest.mark.parametrize("is_gated_mlp", [True, False], ids=["gated_mlp", "mlp"])
@pytest.mark.parametrize(
"hidden_size,intermediate_size",
[
(512, 256), # Standard aligned dimensions
(1024, 512), # Larger aligned dimensions
(768, 384), # Common transformer dimensions (divisible by 16)
(512, 128), # Smaller intermediate
(256, 256), # Equal dimensions
],
ids=["512x256", "1024x512", "768x384", "512x128", "256x256"],
)
def test_nvfp4_moe_fusion(is_gated_mlp, hidden_size, intermediate_size):
"""Test that torch_quant_nvfp4_moe fuses to trtllm_quant_nvfp4_moe_fused.
Note: This test uses swizzled block scales that are compatible with the fused trtllm kernel.
The non-fused op (torch_quant_nvfp4_moe) uses a different internal path that expects
non-swizzled scales, so we don't compare outputs between non-fused and fused.
Instead, we verify the fusion transformation works correctly and produces valid output.
Tests both gated MLP (with w3) and non-gated MLP (without w3) variants with various
hidden_size and intermediate_size configurations.
"""
device = "cuda"
dtype = torch.bfloat16
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
model = MoEOpModelNVFP4(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
dtype=dtype,
is_gated_mlp=is_gated_mlp,
).to(device=device)
x = model.get_input(device=device, dtype=dtype)
# Export to GraphModule
gm = torch_export_to_gm(model, args=(x,), clone=True)
# Verify non-fused op is present before fusion
has_nonfused = any(
is_op(n, torch.ops.auto_deploy.torch_quant_nvfp4_moe) for n in gm.graph.nodes
)
assert has_nonfused, "Expected torch_quant_nvfp4_moe op before fusion"
# Apply NVFP4 MoE fusion
gm_transformed = InferenceOptimizer(
None,
{
"fuse_nvfp4_moe": {
"stage": "post_load_fusion",
},
},
)(None, gm)
# Verify fused op is present after fusion
has_fused = any(
is_op(n, torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused)
for n in gm_transformed.graph.nodes
)
assert has_fused, "Expected trtllm_quant_nvfp4_moe_fused op after fusion"
# Run fused graph to verify it produces valid output (not NaN/Inf)
with torch.inference_mode():
fused_output = gm_transformed(x)
assert not torch.isnan(fused_output).any(), "Fused output contains NaN"
assert not torch.isinf(fused_output).any(), "Fused output contains Inf"
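As a closing note on the test model's scale bookkeeping: each expert's global weight scale is derived from the tensor's absolute maximum, and the per-expert alpha folds the activation and weight global scales back together for the fused kernel. A worked example with made-up numbers, following the formulas used in MoEOpModelNVFP4 above:

FLOAT8_E4M3_MAX = 448.0  # largest magnitude representable by an FP8 E4M3 block scale
FLOAT4_E2M1_MAX = 6.0    # largest magnitude representable by an FP4 E2M1 value

w_amax = 0.04            # illustrative absolute max of a bf16 weight tensor
inp_scale = 800.0        # illustrative activation global scale (e.g. from fp4_global_scale)

weight_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax  # 67200.0
alpha = 1.0 / (inp_scale * weight_scale)                   # ~1.86e-8, registered per expert as w*_alpha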