From 128d4ac5bea203419e26924a3cb2e3935b205605 Mon Sep 17 00:00:00 2001 From: tcherckez-nvidia <127761168+tcherckez-nvidia@users.noreply.github.com> Date: Thu, 22 Jan 2026 13:08:05 +0200 Subject: [PATCH] =?UTF-8?q?[None][chore]=20NVFP4=20MoE=20-=20Move=20weight?= =?UTF-8?q?s=20transformation=20to=20fusion=20phase=E2=80=A6=20(#10803)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tal Cherckez Signed-off-by: Tal Cherckez Signed-off-by: Tal Cherckez Signed-off-by: tcherckez-nvidia <127761168+tcherckez-nvidia@users.noreply.github.com> Co-authored-by: Tal Cherckez Co-authored-by: Tal Cherckez Co-authored-by: Tal Cherckez --- .../custom_ops/fused_moe/trtllm_moe.py | 79 +----- .../_torch/auto_deploy/custom_ops/quant.py | 1 + .../transform/library/fused_moe.py | 78 ++++-- .../defs/accuracy/references/gsm8k.yaml | 3 + .../defs/accuracy/references/mmlu.yaml | 3 + .../defs/accuracy/test_llm_api_autodeploy.py | 9 +- .../test_lists/test-db/l0_dgx_b200.yml | 3 + .../library/test_moe_fusion.py | 234 ++++++++++++++++++ 8 files changed, 313 insertions(+), 97 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index bcd903d26e..7a7a53d960 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -13,16 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import math - import torch -from tensorrt_llm._torch.auto_deploy.custom_ops.quant import ( - TRTLLM_NVFP4_COLUMN_SIZE, - TRTLLM_NVFP4_ROW_SIZE, - TRTLLM_NVFP4_SCALING_VECTOR_SIZE, -) +from tensorrt_llm._torch.auto_deploy.custom_ops.quant import TRTLLM_NVFP4_SCALING_VECTOR_SIZE from tensorrt_llm._torch.utils import ActivationType @@ -262,14 +255,6 @@ def trtllm_quant_nvfp4_moe_fused( mlp_style: "gated_mlp" or "mlp" act_fn: "silu" for gated_mlp, "relu2" for mlp """ - NVFP4_BLOCK_SIZE = TRTLLM_NVFP4_SCALING_VECTOR_SIZE - FP4_PER_UINT8 = 2 - - _, fc1_inter_size, _ = fc1_expert_weights_fp4.shape - n_experts, hidden_size, inter_size = fc2_expert_weights_fp4.shape - - # Convert the inter_size from number of uint8 elements to number of FP4 elements. 
- inter_size *= FP4_PER_UINT8 # Validate block scale tensors are 3D (padding requirements handled below) assert fc1_weight_blockscale_fp8.ndim == 3, "fc1_weight_blockscale_fp8 must be 3D" @@ -280,7 +265,7 @@ def trtllm_quant_nvfp4_moe_fused( if x.dtype in (torch.float16, torch.bfloat16): x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize( - x, fc1_act_global_scale, NVFP4_BLOCK_SIZE + x, fc1_act_global_scale, TRTLLM_NVFP4_SCALING_VECTOR_SIZE ) output_dtype = x.dtype else: @@ -288,66 +273,6 @@ def trtllm_quant_nvfp4_moe_fused( input_blockscale = None output_dtype = x.dtype - # Pad inter_size to be divisible by TRTLLM_NVFP4_ROW_SIZE - inter_size_padded = math.ceil(inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE - fc1_inter_size_padded = ( - math.ceil(fc1_inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE - ) - hidden_size_padded = ( - math.ceil(hidden_size / TRTLLM_NVFP4_COLUMN_SIZE) * TRTLLM_NVFP4_COLUMN_SIZE - ) - - inter_size_needs_padding = (is_gated_mlp and fc1_inter_size_padded != fc1_inter_size) or ( - not is_gated_mlp and inter_size_padded != inter_size - ) - hidden_size_needs_padding = hidden_size % TRTLLM_NVFP4_COLUMN_SIZE != 0 - - hidden_blocks_padded = hidden_size_padded // NVFP4_BLOCK_SIZE - inter_blocks_padded = inter_size_padded // NVFP4_BLOCK_SIZE - - if inter_size_needs_padding or hidden_size_needs_padding: - # Pad fc1_expert_weights_fp4: [E, I, H/2] or [E, 2*I, H/2] - fc1_padded = fc1_expert_weights_fp4.new_zeros( - fc1_expert_weights_fp4.size(0), - fc1_inter_size_padded, - hidden_size_padded // FP4_PER_UINT8, - ) - fc1_padded[:, :fc1_inter_size, : hidden_size // FP4_PER_UINT8] = fc1_expert_weights_fp4 - fc1_expert_weights_fp4 = fc1_padded - - # Block scales may already be padded, so check actual size - fc1_bs_size1 = fc1_weight_blockscale_fp8.size(1) - fc1_bs_size2 = fc1_weight_blockscale_fp8.size(2) - if fc1_bs_size1 < fc1_inter_size_padded or fc1_bs_size2 < hidden_blocks_padded: - fc1_blockscale_fp8_padded = fc1_weight_blockscale_fp8.new_zeros( - n_experts, fc1_inter_size_padded, hidden_blocks_padded - ) - fc1_blockscale_fp8_padded[:, :fc1_bs_size1, :fc1_bs_size2] = fc1_weight_blockscale_fp8 - fc1_weight_blockscale_fp8 = fc1_blockscale_fp8_padded - - fc2_padded = fc2_expert_weights_fp4.new_zeros( - n_experts, hidden_size_padded, inter_size_padded // FP4_PER_UINT8 - ) - - assert inter_size % NVFP4_BLOCK_SIZE == 0, ( - f"inter_size {inter_size} must be divisible by {NVFP4_BLOCK_SIZE}" - ) - - fc2_padded[:, :hidden_size, : inter_size // FP4_PER_UINT8] = fc2_expert_weights_fp4 - fc2_expert_weights_fp4 = fc2_padded - - # Pad fc2_weight_blockscale_fp8: [E, H, I/16] - # Block scales may already be padded, so check actual size - fc2_bs_size1 = fc2_weight_blockscale_fp8.size(1) - fc2_bs_size2 = fc2_weight_blockscale_fp8.size(2) - if fc2_bs_size1 < hidden_size_padded or fc2_bs_size2 < inter_blocks_padded: - fc2_blockscale_fp8_padded = fc2_weight_blockscale_fp8.new_zeros( - n_experts, hidden_size_padded, inter_blocks_padded - ) - fc2_blockscale_fp8_padded[:, :fc2_bs_size1, :fc2_bs_size2] = fc2_weight_blockscale_fp8 - fc2_weight_blockscale_fp8 = fc2_blockscale_fp8_padded - # else: block scales already have correct padded size, use as-is - # quant_scales is described by this code: # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015 quant_scales = [ diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py index 
5c5bcf6e3c..cfb4049923 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py @@ -14,6 +14,7 @@ TRTLLM_FP4_OP_AVAILABLE = True TRTLLM_NVFP4_SCALING_VECTOR_SIZE = 16 TRTLLM_NVFP4_ROW_SIZE = 128 TRTLLM_NVFP4_COLUMN_SIZE = 4 +TRTLLM_NVFP4_PACKING_FACTOR = 2 @torch.library.custom_op("auto_deploy::torch_quant_fn", mutates_args=()) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index c7d949f20e..8c32dac0d2 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -1,3 +1,4 @@ +import math from collections import defaultdict from functools import partial from typing import Dict, List, Literal, Optional, Tuple, Type @@ -8,6 +9,7 @@ from torch.fx import GraphModule, Node from tensorrt_llm._torch.utils import ActivationType +from ...custom_ops.quant import TRTLLM_NVFP4_PACKING_FACTOR, TRTLLM_NVFP4_SCALING_VECTOR_SIZE from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code @@ -1627,14 +1629,13 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: "w3_weight", "w1_input_scale", "w2_input_scale", - "w3_input_scale", "w1_weight_scale", "w2_weight_scale", "w3_weight_scale", "w1_alpha", "w2_alpha", - "w3_alpha", "is_gated_mlp", + "act_fn", ) def _stack(param_list, dim=0, device=None, dtype=None): @@ -1648,13 +1649,10 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: def _prepare_args_cutlass_format_nvfp4(): if is_gated_mlp: # For gated MLP, concatenate w1 and w3 as [w3, w1] - fc1_expert_weights = torch.cat( - [w3_stacked, w1_stacked], dim=1 - ).contiguous() # [E, 2*I, H] - fc1_act_scale = torch.cat( - [w3_input_scale_stacked, w1_input_scale_stacked], dim=1 - ).contiguous() - fc1_alpha_stacked = torch.cat([w3_alpha_stacked, w1_alpha_stacked], dim=1).contiguous() + fc1_expert_weights = torch.cat([w3_stacked, w1_stacked], dim=1).contiguous() + # Expect w3 input scale and alpha to be the same as w1 + fc1_act_scale = w1_input_scale_stacked + fc1_alpha_stacked = w1_alpha_stacked fc1_weight_blockscale_fp8_stacked = torch.cat( [w3_weight_blockscale_fp8_stacked, w1_weight_blockscale_fp8_stacked], dim=1 ).contiguous() @@ -1666,6 +1664,7 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: fc2_expert_weights = w2_stacked fc2_act_scale = w2_input_scale_stacked + fc2_weight_blockscale_fp8_stacked = w2_weight_blockscale_fp8_stacked new_key_fc1_expert_weights = f"nvfp4_moe_w3_w1_stacked_{fused_key_counter}" new_key_fc2_expert_weights = f"nvfp4_moe_w2_stacked_{fused_key_counter}" @@ -1681,13 +1680,57 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: new_key_fc1_alpha = f"nvfp4_moe_w1_alpha_stacked_{fused_key_counter}" new_key_fc2_alpha = f"nvfp4_moe_w2_alpha_stacked_{fused_key_counter}" + # Pad fc1_expert_weights to match the already padded scales + fc1_pad_size = fc1_weight_blockscale_fp8_stacked.shape[1] - fc1_expert_weights.shape[1] + if fc1_pad_size > 0: + fc1_expert_weights = torch.nn.functional.pad( + fc1_expert_weights, (0, 0, 0, fc1_pad_size), mode="constant", value=0 + ) + # Need to update fc2 scales and weights to match the padded size of fc1, + # as they share the same intermediate dimension. 
+ target_intermediate = fc1_weight_blockscale_fp8_stacked.shape[1] + TRTLLM_NVFP4_SCALING_VECTOR_NUM_ELEMENTS = TRTLLM_NVFP4_SCALING_VECTOR_SIZE + TRTLLM_NVFP4_SCALING_BYTES_SIZE = ( + TRTLLM_NVFP4_SCALING_VECTOR_NUM_ELEMENTS // TRTLLM_NVFP4_PACKING_FACTOR + ) + target_n_blocks = target_intermediate // TRTLLM_NVFP4_SCALING_VECTOR_NUM_ELEMENTS + padded_target_n_blocks = ( + math.ceil(target_n_blocks / TRTLLM_NVFP4_SCALING_BYTES_SIZE) + * TRTLLM_NVFP4_SCALING_BYTES_SIZE + ) + fc2_blocks_pad = padded_target_n_blocks - fc2_weight_blockscale_fp8_stacked.shape[2] + + if fc2_blocks_pad > 0: + # unswizzle fc2 scales + fc2_blockscale_shape = list(fc2_weight_blockscale_fp8_stacked.shape) + fc2_blockscale_shape[2] = padded_target_n_blocks + fc2_weight_blockscale_fp8_stacked = torch.ops.trtllm.block_scale_interleave_reverse( + fc2_weight_blockscale_fp8_stacked.view(torch.uint8) + ) + fc2_weight_blockscale_fp8_stacked = torch.nn.functional.pad( + fc2_weight_blockscale_fp8_stacked, (0, fc2_blocks_pad), mode="constant", value=0 + ) + fc2_weight_blockscale_fp8_stacked = ( + torch.ops.trtllm.block_scale_interleave(fc2_weight_blockscale_fp8_stacked) + .view(torch.float8_e4m3fn) + .reshape(fc2_blockscale_shape) + ) + fc2_expert_weights = torch.nn.functional.pad( + fc2_expert_weights, + (0, fc1_pad_size // TRTLLM_NVFP4_PACKING_FACTOR, 0, 0), + mode="constant", + value=0, + ).view(torch.uint8) + # FP4 weights are already packed as uint8, don't convert dtype _register_parameter(gm, new_key_fc1_expert_weights, fc1_expert_weights) _register_parameter(gm, new_key_fc2_expert_weights, fc2_expert_weights) _register_parameter( gm, new_key_fc1_weight_blockscale_fp8, fc1_weight_blockscale_fp8_stacked ) - _register_parameter(gm, new_key_fc2_weight_blockscale_fp8, w2_weight_blockscale_fp8_stacked) + _register_parameter( + gm, new_key_fc2_weight_blockscale_fp8, fc2_weight_blockscale_fp8_stacked + ) _register_parameter(gm, new_key_fc1_act_scale, fc1_act_scale) _register_parameter(gm, new_key_fc2_act_scale, fc2_act_scale) _register_parameter(gm, new_key_fc1_alpha, fc1_alpha_stacked) @@ -1707,7 +1750,11 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: graph.get_attr(new_key_fc1_alpha), graph.get_attr(new_key_fc2_alpha), ) - return args + kwargs = { + "is_gated_mlp": is_gated_mlp, + "act_fn": act_fn, + } + return args, kwargs fused_key_counter = 0 graph = gm.graph @@ -1727,14 +1774,13 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: w3_list, w1_input_scale, w2_input_scale, - w3_input_scale, w1_weight_scale, w2_weight_scale, w3_weight_scale, w1_alpha, w2_alpha, - w3_alpha, is_gated_mlp, + act_fn, ) = _extract_op_args(node) # Stack the actual tensor values (fast, like in quantize_moe.py) @@ -1746,7 +1792,6 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: # Scales are buffers, not parameters w1_input_scale_stacked = _stack(w1_input_scale, dim=0) w2_input_scale_stacked = _stack(w2_input_scale, dim=0) - w3_input_scale_stacked = _stack(w3_input_scale, dim=0, device=device, dtype=dtype) # Use .view() not .to() to reinterpret bytes as float8, not value conversion w1_weight_blockscale_fp8_stacked = _stack(w1_weight_scale, dim=0).view(torch.float8_e4m3fn) @@ -1757,9 +1802,8 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: w1_alpha_stacked = _stack(w1_alpha, dim=0) w2_alpha_stacked = _stack(w2_alpha, dim=0) - w3_alpha_stacked = _stack(w3_alpha, dim=0, device=device, dtype=dtype) - args = _prepare_args_cutlass_format_nvfp4() + args, kwargs = _prepare_args_cutlass_format_nvfp4() fused_key_counter += 1 @@ 
-1768,7 +1812,7 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: new_node = graph.call_function( replacement_op, args, - kwargs=node.kwargs, + kwargs=kwargs, ) node.replace_all_uses_with(new_node) diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml index aea77ba411..94a64dea1a 100644 --- a/tests/integration/defs/accuracy/references/gsm8k.yaml +++ b/tests/integration/defs/accuracy/references/gsm8k.yaml @@ -180,6 +180,9 @@ nvidia/Nemotron-MOE: - quant_algo: FP8 kv_cache_quant_algo: FP8 accuracy: 86.884 + - quant_algo: NVFP4 + kv_cache_quant_algo: FP8 + accuracy: 63.268 nvidia/Llama-3.1-Nemotron-Nano-8B-v1: - accuracy: 37.15 - quant_algo: FP8 diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml index 8e6a1e9b1f..3646793942 100644 --- a/tests/integration/defs/accuracy/references/mmlu.yaml +++ b/tests/integration/defs/accuracy/references/mmlu.yaml @@ -284,6 +284,9 @@ nvidia/Nemotron-MOE: - quant_algo: FP8 kv_cache_quant_algo: FP8 accuracy: 73.879 + - quant_algo: NVFP4 + kv_cache_quant_algo: FP8 + accuracy: 70.879 microsoft/Phi-4-mini-instruct: - accuracy: 68.98 - quant_algo: FP8 diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index 92730768fe..4a3e277cb9 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -14,6 +14,7 @@ # limitations under the License. import pytest +from defs.conftest import skip_pre_blackwell from test_common.llm_data import hf_id_to_local_model_dir, llm_models_root from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM @@ -142,7 +143,7 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness): MODEL_NAME = "nvidia/Nemotron-MOE" MODEL_PATH_BF16 = f"{llm_models_root()}/Nemotron-Nano-3-30B-A3.5B-dev-1024" MODEL_PATH_FP8 = f"{llm_models_root()}/Nemotron-Nano-3-30B-A3.5B-FP8-KVFP8-dev" - MODEL_PATH_NVFP4 = f"{llm_models_root()}/Nemotron-3-Nano-30B-A3B-NVFP4" + MODEL_PATH_NVFP4 = f"{llm_models_root()}/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4" def get_default_kwargs(self): return { @@ -220,11 +221,13 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) - @pytest.mark.skip(reason="NVFP4 model is not in the CI yet") - def test_nvfp4(self): + @skip_pre_blackwell + @pytest.mark.parametrize("world_size", [1, 2, 4]) + def test_nvfp4(self, world_size): kwargs = self.get_default_kwargs() with AutoDeployLLM(model=self.MODEL_PATH_NVFP4, tokenizer=self.MODEL_PATH_NVFP4, + world_size=world_size, **kwargs) as llm: # Manually set quant_config for NVFP4 model to get the accuracy threshold llm.args.quant_config.quant_algo = QuantAlgo.NVFP4 diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index 75833f4f10..4f7a1200cd 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -218,3 +218,6 @@ l0_dgx_b200: - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_bf16 - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_fp8[4] - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_fp8[8] + - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_nvfp4[1] + - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_nvfp4[2] + - 
accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_nvfp4[4] diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py index 3b55e24f98..c94108168a 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py @@ -10,6 +10,7 @@ from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp4_global_scale +from tensorrt_llm._torch.utils import ActivationType class BlockSparseTop2MLP(nn.Module): @@ -416,3 +417,236 @@ def test_fuse_moe_cleanup(): assert mem_after <= mem_before, ( f"CUDA memory increased after fusion: before={mem_before} after={mem_after}" ) + + +class MoEOpModelNVFP4(nn.Module): + """MoE model using torch_quant_nvfp4_moe op for testing fusion to trtllm_quant_nvfp4_moe_fused. + + This model creates weights with 3D block scales that are compatible with + the trtllm fused MoE kernel. + """ + + def __init__( + self, + hidden_size=512, # Already aligned to all requirements (16, 128, etc.) + intermediate_size=256, # Already aligned - no padding needed + num_experts=3, + top_k=2, + dtype=torch.bfloat16, + is_gated_mlp=True, + ): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_experts = num_experts + self.top_k = top_k + self.dtype = dtype + self.is_gated_mlp = is_gated_mlp + + # Constants for NVFP4 layout + NVFP4_BLOCK_SIZE = 16 + NVFP4_PACK_FACTOR = 2 + FLOAT8_E4M3_MAX = 448.0 + FLOAT4_E2M1_MAX = 6.0 + + self.gate = nn.Linear(hidden_size, num_experts, dtype=dtype) + + # Create sample input for scale computation + sample_input = torch.randn(2, hidden_size, dtype=dtype, device="cuda") * 0.01 + inp_scale = fp4_global_scale(sample_input) + + # Per-expert quantized weights and scales + self.w1_weight = nn.ParameterList() + self.w2_weight = nn.ParameterList() + self.w3_weight = nn.ParameterList() if is_gated_mlp else None + + for i in range(num_experts): + w1_fp32 = torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype) * 0.01 + w2_fp32 = torch.randn(hidden_size, intermediate_size, device="cuda", dtype=dtype) * 0.01 + + # Compute global scales + w1_amax = torch.abs(w1_fp32).max().to(torch.float32) + w2_amax = torch.abs(w2_fp32).max().to(torch.float32) + + scale_1 = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax + scale_2 = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + + # Quantize weights (non-swizzled layout) + w1_fp4, w1_bs = torch.ops.trtllm.fp4_quantize(w1_fp32, scale_1, NVFP4_BLOCK_SIZE, False) + w2_fp4, w2_bs = torch.ops.trtllm.fp4_quantize(w2_fp32, scale_2, NVFP4_BLOCK_SIZE, False) + + # fp4_quantize pads block scales but not weights - infer padded dims from block scale size + _, w1_k_packed = w1_fp4.shape + _, w2_k_packed = w2_fp4.shape + w1_k_padded = w1_k_packed * NVFP4_PACK_FACTOR # Convert from uint8 to FP4 element count + w2_k_padded = w2_k_packed * NVFP4_PACK_FACTOR + + # Calculate padded N dimension from block scale tensor size + w1_n_padded = w1_bs.numel() // (w1_k_padded // NVFP4_BLOCK_SIZE) + w2_n_padded = w2_bs.numel() // (w2_k_padded // NVFP4_BLOCK_SIZE) + + # Reshape block scales to 3D 
format [N_padded, K/block] + w1_bs_3d = w1_bs.view(w1_n_padded, w1_k_padded // NVFP4_BLOCK_SIZE) + w2_bs_3d = w2_bs.view(w2_n_padded, w2_k_padded // NVFP4_BLOCK_SIZE) + + self.w1_weight.append(nn.Parameter(w1_fp4, requires_grad=False)) + self.w2_weight.append(nn.Parameter(w2_fp4, requires_grad=False)) + + self.register_buffer(f"w1_input_scale_{i}", inp_scale) + self.register_buffer(f"w2_input_scale_{i}", inp_scale) + self.register_buffer(f"w1_weight_scale_{i}", w1_bs_3d.contiguous()) + self.register_buffer(f"w2_weight_scale_{i}", w2_bs_3d.contiguous()) + self.register_buffer(f"w1_alpha_{i}", 1.0 / (inp_scale * scale_1)) + self.register_buffer(f"w2_alpha_{i}", 1.0 / (inp_scale * scale_2)) + + if is_gated_mlp: + w3_fp32 = ( + torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype) * 0.01 + ) + w3_amax = torch.abs(w3_fp32).max().to(torch.float32) + scale_3 = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w3_amax + w3_fp4, w3_bs = torch.ops.trtllm.fp4_quantize( + w3_fp32, scale_3, NVFP4_BLOCK_SIZE, False + ) + + # Infer padded dimensions for w3 + _, w3_k_packed = w3_fp4.shape + w3_k_padded = w3_k_packed * NVFP4_PACK_FACTOR + w3_n_padded = w3_bs.numel() // (w3_k_padded // NVFP4_BLOCK_SIZE) + w3_bs_3d = w3_bs.view(w3_n_padded, w3_k_padded // NVFP4_BLOCK_SIZE) + + self.w3_weight.append(nn.Parameter(w3_fp4, requires_grad=False)) + self.register_buffer(f"w3_input_scale_{i}", inp_scale) + self.register_buffer(f"w3_weight_scale_{i}", w3_bs_3d.contiguous()) + self.register_buffer(f"w3_alpha_{i}", 1.0 / (inp_scale * scale_3)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + router_logits = self.gate(x) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(x.dtype) + + w1_list = list(self.w1_weight) + w2_list = list(self.w2_weight) + w3_list = list(self.w3_weight) if self.is_gated_mlp else [] + + w1_input_scale = [getattr(self, f"w1_input_scale_{i}") for i in range(self.num_experts)] + w2_input_scale = [getattr(self, f"w2_input_scale_{i}") for i in range(self.num_experts)] + w3_input_scale = ( + [getattr(self, f"w3_input_scale_{i}") for i in range(self.num_experts)] + if self.is_gated_mlp + else [] + ) + w1_weight_scale = [getattr(self, f"w1_weight_scale_{i}") for i in range(self.num_experts)] + w2_weight_scale = [getattr(self, f"w2_weight_scale_{i}") for i in range(self.num_experts)] + w3_weight_scale = ( + [getattr(self, f"w3_weight_scale_{i}") for i in range(self.num_experts)] + if self.is_gated_mlp + else [] + ) + w1_alpha = [getattr(self, f"w1_alpha_{i}") for i in range(self.num_experts)] + w2_alpha = [getattr(self, f"w2_alpha_{i}") for i in range(self.num_experts)] + w3_alpha = ( + [getattr(self, f"w3_alpha_{i}") for i in range(self.num_experts)] + if self.is_gated_mlp + else [] + ) + + out = torch.ops.auto_deploy.torch_quant_nvfp4_moe( + x, + selected_experts, + routing_weights, + w1_list, + w2_list, + w3_list, + w1_input_scale, + w2_input_scale, + w3_input_scale, + w1_weight_scale, + w2_weight_scale, + w3_weight_scale, + w1_alpha, + w2_alpha, + w3_alpha, + is_gated_mlp=self.is_gated_mlp, + act_fn=ActivationType.Silu if self.is_gated_mlp else ActivationType.Relu2, + ) + return out + + def get_input(self, device, dtype=torch.bfloat16): + return torch.randn(2, self.hidden_size, device=device, dtype=dtype) * 0.01 + + +@pytest.mark.skipif( + not fp4_compatible() or not 
trtllm_ops_available(), + reason="Requires FP4 + TRTLLM support", +) +@pytest.mark.parametrize("is_gated_mlp", [True, False], ids=["gated_mlp", "mlp"]) +@pytest.mark.parametrize( + "hidden_size,intermediate_size", + [ + (512, 256), # Standard aligned dimensions + (1024, 512), # Larger aligned dimensions + (768, 384), # Common transformer dimensions (divisible by 16) + (512, 128), # Smaller intermediate + (256, 256), # Equal dimensions + ], + ids=["512x256", "1024x512", "768x384", "512x128", "256x256"], +) +def test_nvfp4_moe_fusion(is_gated_mlp, hidden_size, intermediate_size): + """Test that torch_quant_nvfp4_moe fuses to trtllm_quant_nvfp4_moe_fused. + + Note: This test uses swizzled block scales that are compatible with the fused trtllm kernel. + The non-fused op (torch_quant_nvfp4_moe) uses a different internal path that expects + non-swizzled scales, so we don't compare outputs between non-fused and fused. + Instead, we verify the fusion transformation works correctly and produces valid output. + + Tests both gated MLP (with w3) and non-gated MLP (without w3) variants with various + hidden_size and intermediate_size configurations. + """ + device = "cuda" + dtype = torch.bfloat16 + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + + model = MoEOpModelNVFP4( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + is_gated_mlp=is_gated_mlp, + ).to(device=device) + x = model.get_input(device=device, dtype=dtype) + + # Export to GraphModule + gm = torch_export_to_gm(model, args=(x,), clone=True) + + # Verify non-fused op is present before fusion + has_nonfused = any( + is_op(n, torch.ops.auto_deploy.torch_quant_nvfp4_moe) for n in gm.graph.nodes + ) + assert has_nonfused, "Expected torch_quant_nvfp4_moe op before fusion" + + # Apply NVFP4 MoE fusion + gm_transformed = InferenceOptimizer( + None, + { + "fuse_nvfp4_moe": { + "stage": "post_load_fusion", + }, + }, + )(None, gm) + + # Verify fused op is present after fusion + has_fused = any( + is_op(n, torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused) + for n in gm_transformed.graph.nodes + ) + assert has_fused, "Expected trtllm_quant_nvfp4_moe_fused op after fusion" + + # Run fused graph to verify it produces valid output (not NaN/Inf) + with torch.inference_mode(): + fused_output = gm_transformed(x) + + assert not torch.isnan(fused_output).any(), "Fused output contains NaN" + assert not torch.isinf(fused_output).any(), "Fused output contains Inf"
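
For context, a minimal standalone sketch of the shape bookkeeping behind the weight padding that this patch moves into the fusion transform (_stack_nvfp4_moe_weights). It is illustrative only: the tensor names and example sizes below are hypothetical, the 128-row alignment is an assumption taken from TRTLLM_NVFP4_ROW_SIZE, and the real transform additionally re-interleaves the FP8 block scales via torch.ops.trtllm.block_scale_interleave_reverse / block_scale_interleave before padding the fc2 side.

# Illustrative sketch, not part of the patch: NVFP4 shape bookkeeping.
# For a bf16 weight of shape [N, K], the packed FP4 tensor is [N, K // 2]
# uint8 and the (unswizzled) block-scale tensor is [N, K // 16] fp8.
import math

import torch
import torch.nn.functional as F

TRTLLM_NVFP4_SCALING_VECTOR_SIZE = 16  # FP4 elements covered by one FP8 block scale
TRTLLM_NVFP4_PACKING_FACTOR = 2        # FP4 elements packed into one uint8 byte

num_experts, inter_size, hidden_size = 4, 200, 512  # inter_size deliberately unaligned

# Stacked, packed fc1 weights for a gated MLP: [E, 2*I, H // 2] uint8.
fc1_weights_fp4 = torch.zeros(
    num_experts, 2 * inter_size, hidden_size // TRTLLM_NVFP4_PACKING_FACTOR, dtype=torch.uint8
)

# Block scales come back from quantization already padded along the row
# dimension; here we assume padding up to a multiple of 128 rows.
padded_rows = math.ceil(2 * inter_size / 128) * 128
fc1_blockscales = torch.zeros(
    num_experts,
    padded_rows,
    hidden_size // TRTLLM_NVFP4_SCALING_VECTOR_SIZE,
    dtype=torch.float8_e4m3fn,
)

# The fusion transform pads the packed weights once at transform time (rather
# than the fused op doing it on every call) so the weight row count matches
# the already padded block scales.
fc1_pad_size = fc1_blockscales.shape[1] - fc1_weights_fp4.shape[1]
if fc1_pad_size > 0:
    fc1_weights_fp4 = F.pad(fc1_weights_fp4, (0, 0, 0, fc1_pad_size))

assert fc1_weights_fp4.shape[1] == fc1_blockscales.shape[1]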