diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
index 948e2f8139..5957bd4409 100644
--- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml
+++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -134,9 +134,11 @@ transforms:
     stage: post_load_fusion
     expect_mem_change: true
     backend: trtllm
+    allow_different_input_scales: false
   fuse_nvfp4_moe:
     stage: post_load_fusion
     expect_mem_change: true
+    allow_different_input_scales: false
   fuse_allreduce_residual_rmsnorm:
     stage: post_load_fusion
   fuse_rmsnorm:
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
index 713d7dba03..deb1668ebe 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
@@ -670,8 +670,8 @@ def triton_quant_fp8_moe(
     w1_weight: torch.Tensor,  # [E, I, H] stacked FP8 weights
     w2_weight: torch.Tensor,  # [E, H, I] stacked FP8 weights
     w3_weight: torch.Tensor,  # unused for mlp style
-    w1_input_scale: torch.Tensor,  # [E] stacked input scales
-    w2_input_scale: torch.Tensor,  # [E] stacked input scales
+    w1_input_scale: torch.Tensor,  # [1] max input scale (precomputed)
+    w2_input_scale: torch.Tensor,  # [1] max input scale (precomputed)
     w3_input_scale: torch.Tensor,  # unused
     w1_weight_scale: torch.Tensor,  # [E] stacked weight scales
     w2_weight_scale: torch.Tensor,  # [E] stacked weight scales
@@ -698,10 +698,11 @@ def triton_quant_fp8_moe(
     topk_weights = routing_weights.to(torch.float32).contiguous()

     # Weights are already stacked [E, ...] - just ensure contiguous and extract scales
+    # Input scales are precomputed max values (consistent with trtllm backend)
     w1_q = w1_weight.contiguous()
     w2_q = w2_weight.contiguous()
-    a1_scale = w1_input_scale[0].to(torch.float32).reshape(1).contiguous()
-    a2_scale = w2_input_scale[0].to(torch.float32).reshape(1).contiguous()
+    a1_scale = w1_input_scale.to(torch.float32).reshape(1).contiguous()
+    a2_scale = w2_input_scale.to(torch.float32).reshape(1).contiguous()
     b1_scale = w1_weight_scale.to(torch.float32).contiguous()
     b2_scale = w2_weight_scale.to(torch.float32).contiguous()

@@ -762,7 +763,7 @@


 @triton_quant_fp8_moe.register_fake
-def triton_quant_fp8_moe(
+def triton_quant_fp8_moe_fake(
     x: torch.Tensor,
     selected_experts: torch.Tensor,
     routing_weights: torch.Tensor,
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
index 57f3392a26..4e2fbefa72 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
@@ -136,9 +136,9 @@ def trtllm_quant_fp8_moe_fused(
         routing_weights: Routing weights (B*S, TOP_K)
         fc1_expert_weights: FC1 weights [E, 2*I, H] for gated_mlp, [E, I, H] for mlp
         fc2_expert_weights: FC2 weights [E, H, I]
-        fc1_act_scale: FC1 activation scale [E]
+        fc1_act_scale: FC1 activation scale (scalar)
         fc1_dequant_scale: FC1 dequant scale [E]
-        fc2_act_scale_reciprocal: FC2 activation scale reciprocal [E]
+        fc2_act_scale_reciprocal: FC2 activation scale reciprocal (scalar)
         fc2_dequant_scale: FC2 dequant scale [E]
         is_gated_mlp: True for gated_mlp, False for mlp
         act_fn: ActivationType.Silu for gated_mlp, ActivationType.Relu2 for mlp
@@ -153,24 +153,26 @@ def trtllm_quant_fp8_moe_fused(
     # Store original shape and flatten to 2D
x_shape = x.shape x2d = x.view(-1, x_shape[-1]) - # Quantize the input - x_q_fp8 = _quantize_fp8(x2d, fc1_act_scale[0]) - # Scales are stored in float32 - w1_input_scale = fc1_act_scale[0] + # Quantize the input using precomputed max scale + x_q_fp8 = _quantize_fp8(x2d, fc1_act_scale) # Prepare quant_scales for TensorRT-LLM (Cutlass) FP8 format: # [fc1_dequant_scale, fc2_act_scale_reciprocal, fc2_dequant_scale, gemm1_input_dequant_scale] - # For gated MLP: - # These are precomputed in `fused_moe` transform - # - fc1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3) - # - fc2_act_scale_reciprocal: 1 / w2_input_scale - # - fc1_dequant_scale: w2_weight_scale * w2_input_scale - # - fc1_act_scale: w1_input_scale + # These are precomputed in `fused_moe` transform: + # - fc1_dequant_scale: w1_weight_scale * max(w1_input_scale) [E] + # - fc2_act_scale_reciprocal: 1 / max(w2_input_scale) (scalar) + # - fc2_dequant_scale: w2_weight_scale * max(w2_input_scale) [E] + # - fc1_act_scale: max(w1_input_scale) (scalar) assert fc1_dequant_scale.ndim == 1, "fc1_dequant_scale must be 1D" assert fc2_dequant_scale.ndim == 1, "fc2_dequant_scale must be 1D" - quant_scales = [fc1_dequant_scale, fc2_act_scale_reciprocal, fc2_dequant_scale, w1_input_scale] + quant_scales = [ + fc1_dequant_scale, + fc2_act_scale_reciprocal, + fc2_dequant_scale, + fc1_act_scale, + ] # Ensure contiguous tensors selected_experts = selected_experts.int().contiguous() diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index 94dc64892c..711b7cedb0 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -17,6 +17,7 @@ from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code, get_attr_by_name from ...utils.cuda_mem_tracker import cuda_memory_tracker +from ...utils.logger import ad_logger from ...utils.module import get_submodule_of_param from ...utils.node_utils import bfs, extract_op_args, identify_regions_between_residuals, is_op from ..interface import ( @@ -1323,11 +1324,21 @@ def remove_original_experts(gm: GraphModule, weight_lists: List[List[Node]]) -> continue -def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int: +def _stack_fp8_moe_weights( + gm: GraphModule, + backend: Literal["auto", "trtllm", "triton"], + allow_different_input_scales: bool = False, +) -> int: """ Stack per-expert FP8 weights and scales by materializing stacked tensors as parameters. This is fast because we directly stack the tensor values (not graph nodes). Similar to _insert_fused_moe_ops but for quantized MoE. + + Args: + gm: The GraphModule to transform. + backend: Backend to use ('auto', 'trtllm', or 'triton'). + allow_different_input_scales: If False (default), assert that all experts have identical + input scales and fail if not. If True, allow different scales (use max for quantization). """ def _register_parameter(gm: GraphModule, target, value): @@ -1387,23 +1398,26 @@ def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", " # For optimization reasons, we precompute a few additional arguments to the trtllm_quant_fp8_moe_fused op # to avoid computing them at runtime. 
- fc1_dequant = (w1_weight_scale_stacked * w1_input_scale_stacked[0]).squeeze() - fc2_act_scale_recip = (1.0 / w2_input_scale_stacked[0]).to(torch.float32) - fc2_dequant = (w2_weight_scale_stacked * w2_input_scale_stacked[0]).squeeze() + # We use max scale to handle different input scales per expert (if enabled). + fc1_act_scale = fc1_act_scale.max() + fc2_act_scale = w2_input_scale_stacked.max() + fc1_dequant = (w1_weight_scale_stacked * w1_input_scale_stacked.max()).squeeze() + fc2_act_scale_recip = (1.0 / fc2_act_scale).to(torch.float32) + fc2_dequant = (w2_weight_scale_stacked * fc2_act_scale).squeeze() + new_key_fc1_expert_weights = f"quant_moe_w3_w1_stacked_{fused_key_counter}" + new_key_fc2_expert_weights = f"quant_moe_w2_stacked_{fused_key_counter}" + new_key_fc1_act_scale = f"quant_moe_fc1_act_scale_{fused_key_counter}" new_key_fc1_dequant = f"quant_moe_fc1_dequant_stacked_{fused_key_counter}" new_key_fc2_act_scale_recip = f"quant_moe_fc2_act_scale_recip_stacked_{fused_key_counter}" new_key_fc2_dequant = f"quant_moe_fc2_dequant_stacked_{fused_key_counter}" - new_key_fc1_expert_weights = f"quant_moe_w3_w1_stacked_{fused_key_counter}" - new_key_fc2_expert_weights = f"quant_moe_w2_stacked_{fused_key_counter}" - new_key_fc1_act_scale = f"quant_moe_w3_w1_input_scale_stacked_{fused_key_counter}" - _register_parameter(gm, new_key_fc1_dequant, fc1_dequant) - _register_parameter(gm, new_key_fc2_act_scale_recip, fc2_act_scale_recip) - _register_parameter(gm, new_key_fc2_dequant, fc2_dequant) _register_parameter(gm, new_key_fc1_expert_weights, fc1_expert_weights) _register_parameter(gm, new_key_fc2_expert_weights, fc2_expert_weights) _register_parameter(gm, new_key_fc1_act_scale, fc1_act_scale) + _register_parameter(gm, new_key_fc1_dequant, fc1_dequant) + _register_parameter(gm, new_key_fc2_act_scale_recip, fc2_act_scale_recip) + _register_parameter(gm, new_key_fc2_dequant, fc2_dequant) with graph.inserting_before(node): args = ( @@ -1424,19 +1438,27 @@ def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", " new_key_w1 = f"quant_moe_w1_stacked_{fused_key_counter}" new_key_w2 = f"quant_moe_w2_stacked_{fused_key_counter}" new_key_w3 = f"quant_moe_w3_stacked_{fused_key_counter}" - new_key_w1_input_scale = f"quant_moe_w1_input_scale_stacked_{fused_key_counter}" - new_key_w2_input_scale = f"quant_moe_w2_input_scale_stacked_{fused_key_counter}" - new_key_w3_input_scale = f"quant_moe_w3_input_scale_stacked_{fused_key_counter}" new_key_w1_weight_scale = f"quant_moe_w1_weight_scale_stacked_{fused_key_counter}" new_key_w2_weight_scale = f"quant_moe_w2_weight_scale_stacked_{fused_key_counter}" new_key_w3_weight_scale = f"quant_moe_w3_weight_scale_stacked_{fused_key_counter}" + w1_input_scale = w1_input_scale_stacked.max().reshape(1) + w2_input_scale = w2_input_scale_stacked.max().reshape(1) + # w3_input_scale: use max of w3 scales if present, else use empty tensor + w3_input_scale = ( + w3_input_scale_stacked.max().reshape(1) + if w3_input_scale_stacked.numel() > 0 + else torch.empty(1, device=w1_input_scale.device, dtype=w1_input_scale.dtype) + ) + new_key_w1_input_scale = f"quant_moe_w1_input_scale_{fused_key_counter}" + new_key_w2_input_scale = f"quant_moe_w2_input_scale_{fused_key_counter}" + new_key_w3_input_scale = f"quant_moe_w3_input_scale_{fused_key_counter}" _register_parameter(gm, new_key_w1, w1_stacked) _register_parameter(gm, new_key_w2, w2_stacked) _register_parameter(gm, new_key_w3, w3_stacked) - _register_parameter(gm, new_key_w1_input_scale, 
w1_input_scale_stacked) - _register_parameter(gm, new_key_w2_input_scale, w2_input_scale_stacked) - _register_parameter(gm, new_key_w3_input_scale, w3_input_scale_stacked) + _register_parameter(gm, new_key_w1_input_scale, w1_input_scale) + _register_parameter(gm, new_key_w2_input_scale, w2_input_scale) + _register_parameter(gm, new_key_w3_input_scale, w3_input_scale) _register_parameter(gm, new_key_w1_weight_scale, w1_weight_scale_stacked) _register_parameter(gm, new_key_w2_weight_scale, w2_weight_scale_stacked) _register_parameter(gm, new_key_w3_weight_scale, w3_weight_scale_stacked) @@ -1509,12 +1531,31 @@ def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", " 0, device=w1_input_scale_stacked.device, dtype=w1_input_scale_stacked.dtype ) ) - assert torch.all(w1_input_scale_stacked[0] == w1_input_scale_stacked), ( - "All w1 scales should have the same value." - ) - assert torch.all(w2_input_scale_stacked[0] == w2_input_scale_stacked), ( - "All w2 scales should have the same value." - ) + # Check if input scales are identical across experts + w1_input_scales_identical = torch.all( + w1_input_scale_stacked[0] == w1_input_scale_stacked + ).item() + w2_input_scales_identical = torch.all( + w2_input_scale_stacked[0] == w2_input_scale_stacked + ).item() + + if not w1_input_scales_identical or not w2_input_scales_identical: + if not allow_different_input_scales: + # Fail with assertion + assert w1_input_scales_identical, ( + "All w1 input scales should have the same value. " + "Set allow_different_input_scales=True to allow different scales (uses max)." + ) + assert w2_input_scales_identical, ( + "All w2 input scales should have the same value. " + "Set allow_different_input_scales=True to allow different scales (uses max)." + ) + # Issue warning once and continue - max() will be used + ad_logger.warning_once( + "FP8 MoE: Input scales differ across experts. Using max(input_scale) for quantization. " + "This may impact accuracy if scales differ significantly.", + key="fp8_moe_different_input_scales", + ) w1_weight_scale_stacked = _stack(w1_weight_scale, dim=0).to(torch.float32) w2_weight_scale_stacked = _stack(w2_weight_scale, dim=0).to(torch.float32) @@ -1602,6 +1643,14 @@ class FuseFP8MoeConfig(TransformConfig): default="auto", description="Backend to use for FP8 MoE computation ('auto', 'trtllm' or 'triton'. default: 'auto').", ) + allow_different_input_scales: bool = Field( + default=False, + description=( + "If False (default), assert that all experts have identical input scales and fail if not. " + "If True, allow different per-expert input scales by using max(input_scale) for quantization. " + "This matches TRT-LLM manual backend behavior but may impact accuracy if scales differ significantly." + ), + ) @TransformRegistry.register("fuse_fp8_moe") @@ -1611,6 +1660,10 @@ class FuseFP8Moe(BaseTransform): This runs after weights are loaded, similar to FuseMoe for unquantized MoE. 
""" + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return FuseFP8MoeConfig + def _apply( self, gm: GraphModule, @@ -1619,7 +1672,11 @@ class FuseFP8Moe(BaseTransform): shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: with cuda_memory_tracker(): - fused_key_counter = _stack_fp8_moe_weights(gm, backend=self.config.backend) + fused_key_counter = _stack_fp8_moe_weights( + gm, + backend=self.config.backend, + allow_different_input_scales=self.config.allow_different_input_scales, + ) info = TransformInfo( skipped=(fused_key_counter == 0), @@ -1630,7 +1687,10 @@ class FuseFP8Moe(BaseTransform): return gm, info -def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: +def _stack_nvfp4_moe_weights( + gm: GraphModule, + allow_different_input_scales: bool = False, +) -> int: def _register_parameter(gm: GraphModule, target, value): gm.register_parameter(target, torch.nn.Parameter(value, requires_grad=False)) @@ -1659,6 +1719,7 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: "w3_weight", "w1_input_scale", "w2_input_scale", + "w3_input_scale", "w1_weight_scale", "w2_weight_scale", "w3_weight_scale", @@ -1680,20 +1741,79 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: if is_gated_mlp: # For gated MLP, concatenate w1 and w3 as [w3, w1] fc1_expert_weights = torch.cat([w3_stacked, w1_stacked], dim=1).contiguous() - # Expect w3 input scale and alpha to be the same as w1 - fc1_act_scale = w1_input_scale_stacked - fc1_alpha_stacked = w1_alpha_stacked fc1_weight_blockscale_fp8_stacked = torch.cat( [w3_weight_blockscale_fp8_stacked, w1_weight_blockscale_fp8_stacked], dim=1 ).contiguous() + + # Check if all w1 and w3 input scales are identical across experts + all_scales_equal = ( + torch.all(w1_input_scale_stacked == w1_input_scale_stacked[0]) + and torch.all(w3_input_scale_stacked == w3_input_scale_stacked[0]) + and torch.all(w1_input_scale_stacked == w3_input_scale_stacked) + ) + + if all_scales_equal: + # All scales are identical, no need for min() or alpha recomputation + fc1_act_scale = w1_input_scale_stacked[0] + fc1_alpha_stacked = w1_alpha_stacked + else: + if not allow_different_input_scales: + assert False, ( + "FC1 input scales differ across experts (w1 and/or w3). " + "Set allow_different_input_scales=True to allow different scales (uses min)." + ) + # Issue warning once and continue - min() will be used + ad_logger.warning_once( + "NVFP4 MoE: Input scales differ across experts. Using min(input_scale) for " + "FC1 quantization and recomputing alpha. This may impact accuracy if scales " + "differ significantly.", + key="nvfp4_moe_different_input_scales", + ) + # Scales differ across experts - use global min scale and recompute alpha. + # Use min() because NVFP4 scales are in kernel format (2688/amax): + # smaller scale = larger amax = larger dynamic range. + fc1_act_scale = torch.minimum( + w1_input_scale_stacked.min(), w3_input_scale_stacked.min() + ) + # Recompute alpha using global input scale instead of per-expert input scale. + # Formula: new_alpha = old_alpha * per_expert_input_scale / global_input_scale + # This ensures alpha is consistent with the global fc1_act_scale used by the kernel. 
+ fc1_alpha_stacked = w1_alpha_stacked * w1_input_scale_stacked / fc1_act_scale else: fc1_expert_weights = w1_stacked - fc1_act_scale = w1_input_scale_stacked - fc1_alpha_stacked = w1_alpha_stacked fc1_weight_blockscale_fp8_stacked = w1_weight_blockscale_fp8_stacked + # Check if all w1 input scales are identical across experts + all_scales_equal = torch.all(w1_input_scale_stacked == w1_input_scale_stacked[0]) + + if all_scales_equal: + # All scales are identical, no need for min() or alpha recomputation + fc1_act_scale = w1_input_scale_stacked[0] + fc1_alpha_stacked = w1_alpha_stacked + else: + if not allow_different_input_scales: + assert False, ( + "FC1 input scales differ across experts (w1). " + "Set allow_different_input_scales=True to allow different scales (uses min)." + ) + # Issue warning once and continue - min() will be used + ad_logger.warning_once( + "NVFP4 MoE: Input scales differ across experts. Using min(input_scale) for " + "FC1 quantization and recomputing alpha. This may impact accuracy if scales " + "differ significantly.", + key="nvfp4_moe_different_input_scales", + ) + # Scales differ across experts - use global min scale and recompute alpha + fc1_act_scale = w1_input_scale_stacked.min() + fc1_alpha_stacked = w1_alpha_stacked * w1_input_scale_stacked / fc1_act_scale + fc2_expert_weights = w2_stacked + # Keep fc2_act_scale per-expert (no global scale aggregation for fc2). + # The kernel supports per-expert scales for fc2, and intermediate activations + # naturally have different dynamic ranges per expert. fc2_act_scale = w2_input_scale_stacked + # No alpha recomputation needed since fc2 uses per-expert input scales. + fc2_alpha_stacked = w2_alpha_stacked fc2_weight_blockscale_fp8_stacked = w2_weight_blockscale_fp8_stacked new_key_fc1_expert_weights = f"nvfp4_moe_w3_w1_stacked_{fused_key_counter}" @@ -1764,7 +1884,7 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: _register_parameter(gm, new_key_fc1_act_scale, fc1_act_scale) _register_parameter(gm, new_key_fc2_act_scale, fc2_act_scale) _register_parameter(gm, new_key_fc1_alpha, fc1_alpha_stacked) - _register_parameter(gm, new_key_fc2_alpha, w2_alpha_stacked) + _register_parameter(gm, new_key_fc2_alpha, fc2_alpha_stacked) with graph.inserting_before(node): args = ( @@ -1804,6 +1924,7 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: w3_list, w1_input_scale, w2_input_scale, + w3_input_scale, w1_weight_scale, w2_weight_scale, w3_weight_scale, @@ -1822,6 +1943,12 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: # Scales are buffers, not parameters w1_input_scale_stacked = _stack(w1_input_scale, dim=0) w2_input_scale_stacked = _stack(w2_input_scale, dim=0) + w3_input_scale_stacked = _stack( + w3_input_scale, + dim=0, + device=w1_input_scale_stacked.device, + dtype=w1_input_scale_stacked.dtype, + ) # Use .view() not .to() to reinterpret bytes as float8, not value conversion w1_weight_blockscale_fp8_stacked = _stack(w1_weight_scale, dim=0).view(torch.float8_e4m3fn) @@ -1856,6 +1983,21 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: return fused_key_counter +class FuseNVFP4MoeConfig(TransformConfig): + """Configuration for NVFP4 MoE fusion transform.""" + + allow_different_input_scales: bool = Field( + default=False, + description=( + "If False (default), assert that all experts have identical input scales and fail if not. " + "If True, allow different per-expert input scales by using min(input_scale) for quantization. 
" + "Note: NVFP4 uses min() (not max like FP8) because scales are in kernel format (2688/amax): " + "smaller scale = larger amax = larger dynamic range. " + "This may impact accuracy if scales differ significantly." + ), + ) + + @TransformRegistry.register("fuse_nvfp4_moe") class FuseNVFP4Moe(BaseTransform): """ @@ -1863,6 +2005,10 @@ class FuseNVFP4Moe(BaseTransform): This runs after weights are loaded, similar to FuseFP8Moe. """ + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return FuseNVFP4MoeConfig + def _apply( self, gm: GraphModule, @@ -1871,7 +2017,10 @@ class FuseNVFP4Moe(BaseTransform): shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: with cuda_memory_tracker(): - fused_key_counter = _stack_nvfp4_moe_weights(gm) + fused_key_counter = _stack_nvfp4_moe_weights( + gm, + allow_different_input_scales=self.config.allow_different_input_scales, + ) info = TransformInfo( skipped=(fused_key_counter == 0), diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/moe/test_triton_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/moe/test_triton_moe.py index 490eb1d742..ab29e14209 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/moe/test_triton_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/moe/test_triton_moe.py @@ -299,15 +299,16 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit): # Input scales (tensor-wise, replicated per expert for interface compatibility) x_scale = x.abs().max().item() / FP8_MAX - w1_input_scale_tensor = torch.full((E,), x_scale, device=device, dtype=torch.float32) + # Input scales: precomputed max values with shape [1] (new API) + w1_input_scale_max = torch.tensor([x_scale], device=device, dtype=torch.float32) # Compute intermediate activation scale by simulating first GEMM + ReLU^2 # This ensures w2_input_scale matches the actual activation magnitude with torch.no_grad(): # Simulate the first GEMM: quantize input, do FP8 matmul, apply ReLU^2 - x_q = (x / w1_input_scale_tensor[0]).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn) + x_q = (x / w1_input_scale_max.item()).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn) # Dequantize and compute output for a sample - x_dq = x_q[:8].to(torch.float32) * w1_input_scale_tensor[0].item() + x_dq = x_q[:8].to(torch.float32) * w1_input_scale_max.item() w1_dq = w1_fp8_stacked[0].to(torch.float32) * w1_weight_scale[0].item() sample_out = torch.nn.functional.linear(x_dq.to(dtype), w1_dq.to(dtype)) sample_act = torch.square(torch.nn.functional.relu(sample_out)) @@ -315,11 +316,11 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit): # Ensure scale is not too small intermediate_scale = max(intermediate_scale, 1e-6) - w2_input_scale_tensor = torch.full((E,), intermediate_scale, device=device, dtype=torch.float32) + w2_input_scale_max = torch.tensor([intermediate_scale], device=device, dtype=torch.float32) # Convert scales to lists for torch_quant_fp8_moe reference - w1_input_scale_list = [w1_input_scale_tensor[0].clone() for _ in range(E)] - w2_input_scale_list = [w2_input_scale_tensor[0].clone() for _ in range(E)] + w1_input_scale_list = [w1_input_scale_max[0].clone() for _ in range(E)] + w2_input_scale_list = [w2_input_scale_max[0].clone() for _ in range(E)] w1_weight_scale_list = [w1_weight_scale[e].clone() for e in range(E)] w2_weight_scale_list = [w2_weight_scale[e].clone() for e in range(E)] @@ -327,7 +328,7 @@ def 
test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit): w3_fp8_list = [torch.empty((1, 1), device=device, dtype=torch.float8_e4m3fn) for _ in range(E)] w3_fp8_stacked = torch.stack(w3_fp8_list).contiguous() w3_input_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)] - w3_input_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32) + w3_input_scale_max = torch.ones((1,), device=device, dtype=torch.float32) w3_weight_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)] w3_weight_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32) @@ -352,7 +353,7 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit): # Create equal routing weights routing_weights = torch.ones((M, top_k), device=device, dtype=torch.float32) / top_k - # Triton FP8 quantized MoE (uses stacked tensors) + # Triton FP8 quantized MoE (uses stacked tensors with precomputed max input scales) out_triton = torch.ops.auto_deploy.triton_quant_fp8_moe( x, selected_experts.to(torch.int32), @@ -360,9 +361,9 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit): w1_fp8_stacked, w2_fp8_stacked, w3_fp8_stacked, - w1_input_scale_tensor, - w2_input_scale_tensor, - w3_input_scale_tensor, + w1_input_scale_max, # [1] precomputed max + w2_input_scale_max, # [1] precomputed max + w3_input_scale_max, # [1] precomputed max (unused) w1_weight_scale, w2_weight_scale, w3_weight_scale_tensor, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/moe/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/moe/test_trtllm_moe.py index d85166d9c9..6b87c11501 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/moe/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/moe/test_trtllm_moe.py @@ -414,7 +414,7 @@ def test_trtllm_fused_moe_fp8( routing_weights, fc1_expert_weights=w31_weight.contiguous() if is_gated_mlp else w1_weight.contiguous(), fc2_expert_weights=w2_weight.contiguous(), - fc1_act_scale=hidden_states_scale.unsqueeze(0), + fc1_act_scale=hidden_states_scale, fc1_dequant_scale=gemm1_dequant, fc2_act_scale_reciprocal=gemm2_act_quant, fc2_dequant_scale=gemm2_dequant, @@ -433,7 +433,7 @@ def test_trtllm_fused_moe_fp8( w1_weight, w2_weight, w3_weight, - hidden_states_scale.unsqueeze(0), + hidden_states_scale.reshape(1), w2_input_scale, w3_input_scale, w1_scales, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py index c94108168a..a8c18aee24 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py @@ -1,11 +1,14 @@ import pytest import torch +import torch.fx as fx import torch.nn as nn import torch.nn.functional as F from _graph_test_helpers import run_test_transformed_gm from _model_test_utils import MoEOpModel from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available +from utils.util import skip_pre_hopper +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op @@ -650,3 +653,424 @@ 
def test_nvfp4_moe_fusion(is_gated_mlp, hidden_size, intermediate_size):

     assert not torch.isnan(fused_output).any(), "Fused output contains NaN"
     assert not torch.isinf(fused_output).any(), "Fused output contains Inf"
+
+
+class FP8MoEModuleForInputScaleTest(nn.Module):
+    """Module wrapping torch_quant_fp8_moe for testing FP8 MoE input scale handling."""
+
+    def __init__(
+        self,
+        num_experts,
+        w1_weight,
+        w2_weight,
+        w1_input_scale,
+        w2_input_scale,
+        w1_weight_scale,
+        w2_weight_scale,
+        is_gated_mlp,
+        act_fn,
+    ):
+        super().__init__()
+        self.num_experts = num_experts
+        self.is_gated_mlp = is_gated_mlp
+        self.act_fn = act_fn
+
+        for i in range(num_experts):
+            self.register_buffer(f"w1_{i}", w1_weight[i])
+            self.register_buffer(f"w2_{i}", w2_weight[i])
+            self.register_buffer(f"w1_iscale_{i}", w1_input_scale[i])
+            self.register_buffer(f"w2_iscale_{i}", w2_input_scale[i])
+            self.register_buffer(f"w1_wscale_{i}", w1_weight_scale[i])
+            self.register_buffer(f"w2_wscale_{i}", w2_weight_scale[i])
+
+    def forward(self, x, selected_experts, routing_weights):
+        return torch.ops.auto_deploy.torch_quant_fp8_moe(
+            x,
+            selected_experts,
+            routing_weights,
+            [getattr(self, f"w1_{i}") for i in range(self.num_experts)],
+            [getattr(self, f"w2_{i}") for i in range(self.num_experts)],
+            [],  # w3 is empty for non-gated MLP
+            [getattr(self, f"w1_iscale_{i}") for i in range(self.num_experts)],
+            [getattr(self, f"w2_iscale_{i}") for i in range(self.num_experts)],
+            [],  # w3 input scale is empty for non-gated MLP
+            [getattr(self, f"w1_wscale_{i}") for i in range(self.num_experts)],
+            [getattr(self, f"w2_wscale_{i}") for i in range(self.num_experts)],
+            [],  # w3 weight scale is empty for non-gated MLP
+            is_gated_mlp=self.is_gated_mlp,
+            act_fn=self.act_fn,
+        )
+
+
+@skip_pre_hopper
+@pytest.mark.parametrize("backend", ["trtllm", "triton"])
+@pytest.mark.parametrize("allow_different_input_scales", [False, True])
+@pytest.mark.parametrize("scales_identical", [True, False])
+@pytest.mark.skipif(
+    not fp8_compatible() or not trtllm_ops_available(),
+    reason="Requires fp8 and trtllm support",
+)
+def test_fp8_moe_different_input_scales(backend, allow_different_input_scales, scales_identical):
+    """
+    Test FP8 MoE behavior with different/identical input scales via _stack_fp8_moe_weights.
+ + Tests the allow_different_input_scales config option for both trtllm and triton backends: + - When scales_identical=True: should always work + - When scales_identical=False and allow_different_input_scales=False: should fail with assertion + - When scales_identical=False and allow_different_input_scales=True: should work (uses max) + """ + from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe import _stack_fp8_moe_weights + + torch.manual_seed(0) + + batch_size, num_experts, top_k = 4, 2, 2 + hidden_size, intermediate_size = 128, 128 + # Use non-gated MLP (Relu2) because triton backend only supports non-gated MLP + is_gated_mlp = False + act_fn = ActivationType.Relu2 + + # Generate test data + x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda") * 0.5 + + # Simple routing: distribute tokens across experts + selected_experts = torch.zeros((batch_size, top_k), dtype=torch.int64, device="cuda") + for i in range(batch_size): + selected_experts[i, 0] = i % num_experts + selected_experts[i, 1] = (i + 1) % num_experts + routing_weights = torch.ones((batch_size, top_k), device="cuda", dtype=torch.float32) / top_k + + # Create per-expert weights and scales (non-gated MLP, no w3) + w1_weight, w2_weight = [], [] + w1_input_scale, w2_input_scale = [], [] + w1_weight_scale, w2_weight_scale = [], [] + + for expert_id in range(num_experts): + # Random FP8 weights + w1_fp8 = torch.randn(intermediate_size, hidden_size, device="cuda").to(torch.float8_e4m3fn) + w2_fp8 = torch.randn(hidden_size, intermediate_size, device="cuda").to(torch.float8_e4m3fn) + w1_weight.append(w1_fp8) + w2_weight.append(w2_fp8) + + # Random weight scales (shape [1]) + w1_weight_scale.append(torch.tensor([0.1], dtype=torch.float32, device="cuda")) + w2_weight_scale.append(torch.tensor([0.1], dtype=torch.float32, device="cuda")) + + # Input scales: either identical or different per expert (shape [1]) + if scales_identical: + inp_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda") + else: + # Different input scales per expert - big variance to test max() behavior + inp_scale = torch.tensor([0.5 + 0.5 * expert_id], dtype=torch.float32, device="cuda") + + w1_input_scale.append(inp_scale) + w2_input_scale.append(inp_scale) + + # Create a module with the FP8 MoE op + module = FP8MoEModuleForInputScaleTest( + num_experts, + w1_weight, + w2_weight, + w1_input_scale, + w2_input_scale, + w1_weight_scale, + w2_weight_scale, + is_gated_mlp, + act_fn, + ).cuda() + gm = fx.symbolic_trace(module) + + # Compute reference output from original graph before transformation + with torch.inference_mode(): + ref_output = gm(x, selected_experts, routing_weights) + + # Expected behavior: + # - scales_identical=True: always works + # - scales_identical=False, allow_different_input_scales=False: assertion error + # - scales_identical=False, allow_different_input_scales=True: works with max() + + if not scales_identical and not allow_different_input_scales: + # Should fail with assertion error + with pytest.raises(AssertionError, match="input scales should have the same value"): + _stack_fp8_moe_weights( + gm, backend=backend, allow_different_input_scales=allow_different_input_scales + ) + else: + # Should succeed + num_transformed = _stack_fp8_moe_weights( + gm, backend=backend, allow_different_input_scales=allow_different_input_scales + ) + gm.recompile() + + assert num_transformed == 1, f"Expected 1 transform, got {num_transformed}" + + # Verify that max() is used when scales differ + if not 
scales_identical: + expected_max_w1_input_scale = torch.stack(w1_input_scale).max() + # Attribute name differs between backends + if backend == "trtllm": + actual_w1_input = getattr(gm, "quant_moe_fc1_act_scale_0") + else: # triton + actual_w1_input = getattr(gm, "quant_moe_w1_input_scale_0").squeeze() + + assert torch.allclose(actual_w1_input, expected_max_w1_input_scale), ( + f"w1 input scale max mismatch. Got {actual_w1_input}, expected {expected_max_w1_input_scale}" + ) + + # Run the transformed graph and compare to reference output + with torch.inference_mode(): + output = gm(x, selected_experts, routing_weights) + assert output.shape == ref_output.shape, ( + f"Output shape mismatch: {output.shape} vs {ref_output.shape}" + ) + + assert torch.allclose(output, ref_output, rtol=0.05, atol=0.05), ( + f"Output mismatch. rtol=0.05, atol=0.05. Max diff: {(output - ref_output).abs().max()}" + ) + + +class NVFP4MoEModuleForInputScaleTest(nn.Module): + """Module wrapping torch_quant_nvfp4_moe for testing NVFP4 MoE input scale handling.""" + + def __init__( + self, + num_experts, + w1_weight, + w2_weight, + w3_weight, + w1_input_scale, + w2_input_scale, + w3_input_scale, + w1_weight_scale, + w2_weight_scale, + w3_weight_scale, + w1_alpha, + w2_alpha, + w3_alpha, + is_gated_mlp, + act_fn, + ): + super().__init__() + self.num_experts = num_experts + self.is_gated_mlp = is_gated_mlp + self.act_fn = act_fn + + for i in range(num_experts): + self.register_buffer(f"w1_{i}", w1_weight[i]) + self.register_buffer(f"w2_{i}", w2_weight[i]) + self.register_buffer(f"w1_iscale_{i}", w1_input_scale[i]) + self.register_buffer(f"w2_iscale_{i}", w2_input_scale[i]) + self.register_buffer(f"w1_wscale_{i}", w1_weight_scale[i]) + self.register_buffer(f"w2_wscale_{i}", w2_weight_scale[i]) + self.register_buffer(f"w1_alpha_{i}", w1_alpha[i]) + self.register_buffer(f"w2_alpha_{i}", w2_alpha[i]) + if is_gated_mlp: + self.register_buffer(f"w3_{i}", w3_weight[i]) + self.register_buffer(f"w3_iscale_{i}", w3_input_scale[i]) + self.register_buffer(f"w3_wscale_{i}", w3_weight_scale[i]) + self.register_buffer(f"w3_alpha_{i}", w3_alpha[i]) + + def forward(self, x, selected_experts, routing_weights): + return torch.ops.auto_deploy.torch_quant_nvfp4_moe( + x, + selected_experts, + routing_weights, + [getattr(self, f"w1_{i}") for i in range(self.num_experts)], + [getattr(self, f"w2_{i}") for i in range(self.num_experts)], + [getattr(self, f"w3_{i}") for i in range(self.num_experts)] + if self.is_gated_mlp + else [], + [getattr(self, f"w1_iscale_{i}") for i in range(self.num_experts)], + [getattr(self, f"w2_iscale_{i}") for i in range(self.num_experts)], + [getattr(self, f"w3_iscale_{i}") for i in range(self.num_experts)] + if self.is_gated_mlp + else [], + [getattr(self, f"w1_wscale_{i}") for i in range(self.num_experts)], + [getattr(self, f"w2_wscale_{i}") for i in range(self.num_experts)], + [getattr(self, f"w3_wscale_{i}") for i in range(self.num_experts)] + if self.is_gated_mlp + else [], + [getattr(self, f"w1_alpha_{i}") for i in range(self.num_experts)], + [getattr(self, f"w2_alpha_{i}") for i in range(self.num_experts)], + [getattr(self, f"w3_alpha_{i}") for i in range(self.num_experts)] + if self.is_gated_mlp + else [], + is_gated_mlp=self.is_gated_mlp, + act_fn=self.act_fn, + ) + + +@skip_pre_hopper +@pytest.mark.parametrize("allow_different_input_scales", [False, True]) +@pytest.mark.parametrize("scales_identical", [True, False]) +@pytest.mark.parametrize("is_gated_mlp", [False, True]) +@pytest.mark.skipif( + not 
trtllm_ops_available(), + reason="Requires trtllm ops", +) +def test_nvfp4_moe_different_input_scales( + allow_different_input_scales, scales_identical, is_gated_mlp +): + """ + Test NVFP4 MoE behavior with different/identical input scales via _stack_nvfp4_moe_weights. + + Tests the allow_different_input_scales config option for both gated and non-gated MLP: + - When scales_identical=True: should always work + - When scales_identical=False and allow_different_input_scales=False: should fail with assertion + - When scales_identical=False and allow_different_input_scales=True: should work (uses min) + + Note: NVFP4 uses min() (not max() like FP8) because scales are in kernel format (2688/amax): + smaller scale = larger amax = larger dynamic range. + + This test uses mock tensors to test the transform logic without running the actual NVFP4 kernel. + """ + from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe import _stack_nvfp4_moe_weights + + torch.manual_seed(0) + + num_experts = 2 + hidden_size, intermediate_size = 128, 128 + act_fn = ActivationType.Silu if is_gated_mlp else ActivationType.Relu2 + + # NVFP4 constants + FP4_GLOBAL_SCALE_MAX = 448 * 6 # 2688 + NVFP4_BLOCK_SIZE = 16 + + # Create per-expert mock weights and scales + # We use mock tensors with correct shapes to test the transform logic + # without needing actual FP4 quantization (which requires SM>=100) + w1_weight, w2_weight, w3_weight = [], [], [] + w1_input_scale, w2_input_scale, w3_input_scale = [], [], [] + w1_weight_scale, w2_weight_scale, w3_weight_scale = [], [], [] + w1_alpha, w2_alpha, w3_alpha = [], [], [] + + for expert_id in range(num_experts): + # Mock FP4 weights (uint8 packed, half the size in last dim) + w1_fp4 = torch.randint( + 0, 255, (intermediate_size, hidden_size // 2), dtype=torch.uint8, device="cuda" + ) + w2_fp4 = torch.randint( + 0, 255, (hidden_size, intermediate_size // 2), dtype=torch.uint8, device="cuda" + ) + + # Mock block scales (2D): shape (m, n // NVFP4_BLOCK_SIZE) + # With 128x128 dims, no padding needed (already multiples of 128 and 8) + w1_block_scale = torch.randn( + intermediate_size, hidden_size // NVFP4_BLOCK_SIZE, dtype=torch.float32, device="cuda" + ).to(torch.float8_e4m3fn) + w2_block_scale = torch.randn( + hidden_size, intermediate_size // NVFP4_BLOCK_SIZE, dtype=torch.float32, device="cuda" + ).to(torch.float8_e4m3fn) + + w1_weight.append(w1_fp4) + w2_weight.append(w2_fp4) + w1_weight_scale.append(w1_block_scale) + w2_weight_scale.append(w2_block_scale) + + # Input scales: either identical or different per expert + # For NVFP4, scale = FP4_GLOBAL_SCALE_MAX / amax + if scales_identical: + # Same amax for all experts -> same input scale + inp_scale = torch.tensor(FP4_GLOBAL_SCALE_MAX / 1.0, dtype=torch.float32, device="cuda") + else: + # Different amax per expert -> different input scales + # Expert 0: amax=1.0, scale=2688/1.0=2688 + # Expert 1: amax=2.0, scale=2688/2.0=1344 + amax = 1.0 + expert_id + inp_scale = torch.tensor( + FP4_GLOBAL_SCALE_MAX / amax, dtype=torch.float32, device="cuda" + ) + + w1_input_scale.append(inp_scale) + w2_input_scale.append(inp_scale) + + # Mock weight_scale_2 (global scale for this expert's weights) + w1_scale_2 = torch.tensor(100.0, dtype=torch.float32, device="cuda") + w2_scale_2 = torch.tensor(100.0, dtype=torch.float32, device="cuda") + + # Alpha = 1 / (input_scale * weight_scale_2) + w1_alpha.append((1.0 / (inp_scale * w1_scale_2)).to(torch.float32)) + w2_alpha.append((1.0 / (inp_scale * w2_scale_2)).to(torch.float32)) + + # For gated 
MLP, create w3 weights/scales/alpha (same shape as w1) + if is_gated_mlp: + w3_fp4 = torch.randint( + 0, 255, (intermediate_size, hidden_size // 2), dtype=torch.uint8, device="cuda" + ) + w3_block_scale = torch.randn( + intermediate_size, + hidden_size // NVFP4_BLOCK_SIZE, + dtype=torch.float32, + device="cuda", + ).to(torch.float8_e4m3fn) + w3_weight.append(w3_fp4) + w3_weight_scale.append(w3_block_scale) + # w3 uses the same input scale as w1 (they process the same input) + w3_input_scale.append(inp_scale) + w3_scale_2 = torch.tensor(100.0, dtype=torch.float32, device="cuda") + w3_alpha.append((1.0 / (inp_scale * w3_scale_2)).to(torch.float32)) + + # Create a module with the NVFP4 MoE op + module = NVFP4MoEModuleForInputScaleTest( + num_experts, + w1_weight, + w2_weight, + w3_weight, + w1_input_scale, + w2_input_scale, + w3_input_scale, + w1_weight_scale, + w2_weight_scale, + w3_weight_scale, + w1_alpha, + w2_alpha, + w3_alpha, + is_gated_mlp, + act_fn, + ).cuda() + gm = fx.symbolic_trace(module) + + # Expected behavior: + # - scales_identical=True: always works + # - scales_identical=False, allow_different_input_scales=False: assertion error + # - scales_identical=False, allow_different_input_scales=True: works with min() + + if not scales_identical and not allow_different_input_scales: + # Should fail with assertion error + with pytest.raises(AssertionError, match="FC1 input scales differ"): + _stack_nvfp4_moe_weights(gm, allow_different_input_scales=allow_different_input_scales) + else: + # Should succeed + num_transformed = _stack_nvfp4_moe_weights( + gm, allow_different_input_scales=allow_different_input_scales + ) + gm.recompile() + + assert num_transformed == 1, f"Expected 1 transform, got {num_transformed}" + + # Verify that min() is used when scales differ + if not scales_identical: + # For gated MLP, global scale = min(w1_scales.min(), w3_scales.min()) + # For non-gated MLP, global scale = w1_scales.min() + if is_gated_mlp: + expected_min_input_scale = torch.minimum( + torch.stack(w1_input_scale).min(), + torch.stack(w3_input_scale).min(), + ) + else: + expected_min_input_scale = torch.stack(w1_input_scale).min() + + actual_input_scale = getattr(gm, "nvfp4_moe_w3_w1_input_scale_stacked_0") + + assert torch.allclose(actual_input_scale, expected_min_input_scale), ( + f"FC1 input scale min mismatch. Got {actual_input_scale}, expected {expected_min_input_scale}" + ) + + # Verify alpha was recomputed correctly + # new_alpha = old_alpha * per_expert_input_scale / global_input_scale + expected_alpha = ( + torch.stack(w1_alpha) * torch.stack(w1_input_scale) / expected_min_input_scale + ) + actual_alpha = getattr(gm, "nvfp4_moe_w1_alpha_stacked_0") + assert torch.allclose(actual_alpha, expected_alpha, rtol=1e-5, atol=1e-5), ( + f"Alpha recomputation mismatch. Got {actual_alpha}, expected {expected_alpha}" + )
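
Note (editor, not part of the patch): below is a minimal numeric sketch of the scale aggregation this change introduces, using made-up values; the variable names are placeholders, not the library API. For FP8 the transform collapses per-expert input scales with max() and folds that max into the per-expert dequant scales; for NVFP4 it takes min() (scales are stored as 2688 / amax, so the smallest scale covers the largest dynamic range) and rescales alpha so it stays consistent with the shared activation scale.

import torch

# FP8 path: quantize x with the largest (most conservative) per-expert input
# scale, and fold that same max into the per-expert dequant scales.
w1_input_scale = torch.tensor([0.5, 1.0, 1.5])      # per-expert input scales [E]
w1_weight_scale = torch.tensor([0.10, 0.20, 0.30])  # per-expert weight scales [E]
fc1_act_scale = w1_input_scale.max()                 # scalar used to quantize x
fc1_dequant = w1_weight_scale * fc1_act_scale        # [E] dequant scales

# NVFP4 path: scales are in kernel format (2688 / amax), so min() is the
# conservative aggregate, and alpha is rescaled to match it.
FP4_GLOBAL_SCALE_MAX = 448 * 6  # 2688
amax = torch.tensor([1.0, 2.0])                      # per-expert activation amax
per_expert_scale = FP4_GLOBAL_SCALE_MAX / amax       # [2688., 1344.]
weight_scale_2 = torch.tensor([100.0, 100.0])
old_alpha = 1.0 / (per_expert_scale * weight_scale_2)

global_scale = per_expert_scale.min()                # 1344. (largest amax wins)
new_alpha = old_alpha * per_expert_scale / global_scale

# After rescaling, alpha = 1 / (global_scale * weight_scale_2) for every expert,
# which is what the alpha recomputation in the transform and the unit test expect.
assert torch.allclose(new_alpha, 1.0 / (global_scale * weight_scale_2))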