[#10780][feat] AutoDeploy: Support per-expert scales in FP8 and NVFP4 MoE (#11322)

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Signed-off-by: Gal Hubara-Agam <96368689+galagam@users.noreply.github.com>
Gal Hubara-Agam 2026-02-09 17:07:37 +02:00 committed by GitHub
parent 540fb0f29e
commit 2b60cc181c
7 changed files with 641 additions and 62 deletions


@ -134,9 +134,11 @@ transforms:
stage: post_load_fusion
expect_mem_change: true
backend: trtllm
allow_different_input_scales: false
fuse_nvfp4_moe:
stage: post_load_fusion
expect_mem_change: true
allow_different_input_scales: false
fuse_allreduce_residual_rmsnorm:
stage: post_load_fusion
fuse_rmsnorm:
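
Both fusion transforms above gain the allow_different_input_scales option (default false). For reference, a minimal sketch of the same settings expressed as a plain Python mapping; the keys come from the YAML above, while the machinery that would consume such a mapping is assumed and not shown here:

transforms_overrides = {
    # keys mirror the YAML above; values are the defaults introduced by this change
    "fuse_fp8_moe": {
        "stage": "post_load_fusion",
        "expect_mem_change": True,
        "backend": "trtllm",
        "allow_different_input_scales": False,
    },
    "fuse_nvfp4_moe": {
        "stage": "post_load_fusion",
        "expect_mem_change": True,
        "allow_different_input_scales": False,
    },
}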


@ -670,8 +670,8 @@ def triton_quant_fp8_moe(
w1_weight: torch.Tensor, # [E, I, H] stacked FP8 weights
w2_weight: torch.Tensor, # [E, H, I] stacked FP8 weights
w3_weight: torch.Tensor, # unused for mlp style
w1_input_scale: torch.Tensor, # [E] stacked input scales
w2_input_scale: torch.Tensor, # [E] stacked input scales
w1_input_scale: torch.Tensor, # [1] max input scale (precomputed)
w2_input_scale: torch.Tensor, # [1] max input scale (precomputed)
w3_input_scale: torch.Tensor, # unused
w1_weight_scale: torch.Tensor, # [E] stacked weight scales
w2_weight_scale: torch.Tensor, # [E] stacked weight scales
@ -698,10 +698,11 @@ def triton_quant_fp8_moe(
topk_weights = routing_weights.to(torch.float32).contiguous()
# Weights are already stacked [E, ...] - just ensure contiguous and extract scales
# Input scales are precomputed max values (consistent with trtllm backend)
w1_q = w1_weight.contiguous()
w2_q = w2_weight.contiguous()
a1_scale = w1_input_scale[0].to(torch.float32).reshape(1).contiguous()
a2_scale = w2_input_scale[0].to(torch.float32).reshape(1).contiguous()
a1_scale = w1_input_scale.to(torch.float32).reshape(1).contiguous()
a2_scale = w2_input_scale.to(torch.float32).reshape(1).contiguous()
b1_scale = w1_weight_scale.to(torch.float32).contiguous()
b2_scale = w2_weight_scale.to(torch.float32).contiguous()
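
The triton op therefore no longer receives [E] per-expert input scales; the caller precomputes a single shared scale. A minimal standalone sketch of that precompute and of quantizing with the shared max scale, using toy values and only names introduced here:

import torch

# Per-expert FP8 activation scales as they appear in the checkpoint (toy values).
w1_input_scale_per_expert = torch.tensor([0.50, 0.75, 1.00, 0.60])

# The fused op now receives one [1] scale: the max over experts. The max keeps
# every expert's activations inside the FP8 range; experts whose own scale is
# smaller simply lose a little quantization resolution.
w1_input_scale = w1_input_scale_per_expert.max().reshape(1)

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
x = torch.randn(8, 16)
x_q = (x / w1_input_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
x_dq = x_q.to(torch.float32) * w1_input_scale  # dequantize with the same shared scale
print(w1_input_scale.item(), (x - x_dq).abs().max().item())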
@ -762,7 +763,7 @@ def triton_quant_fp8_moe(
@triton_quant_fp8_moe.register_fake
def triton_quant_fp8_moe(
def triton_quant_fp8_moe_fake(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,


@ -136,9 +136,9 @@ def trtllm_quant_fp8_moe_fused(
routing_weights: Routing weights (B*S, TOP_K)
fc1_expert_weights: FC1 weights [E, 2*I, H] for gated_mlp, [E, I, H] for mlp
fc2_expert_weights: FC2 weights [E, H, I]
fc1_act_scale: FC1 activation scale [E]
fc1_act_scale: FC1 activation scale (scalar)
fc1_dequant_scale: FC1 dequant scale [E]
fc2_act_scale_reciprocal: FC2 activation scale reciprocal [E]
fc2_act_scale_reciprocal: FC2 activation scale reciprocal (scalar)
fc2_dequant_scale: FC2 dequant scale [E]
is_gated_mlp: True for gated_mlp, False for mlp
act_fn: ActivationType.Silu for gated_mlp, ActivationType.Relu2 for mlp
@ -153,24 +153,26 @@ def trtllm_quant_fp8_moe_fused(
# Store original shape and flatten to 2D
x_shape = x.shape
x2d = x.view(-1, x_shape[-1])
# Quantize the input
x_q_fp8 = _quantize_fp8(x2d, fc1_act_scale[0])
# Scales are stored in float32
w1_input_scale = fc1_act_scale[0]
# Quantize the input using precomputed max scale
x_q_fp8 = _quantize_fp8(x2d, fc1_act_scale)
# Prepare quant_scales for TensorRT-LLM (Cutlass) FP8 format:
# [fc1_dequant_scale, fc2_act_scale_reciprocal, fc2_dequant_scale, gemm1_input_dequant_scale]
# For gated MLP:
# These are precomputed in `fused_moe` transform
# - fc1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3)
# - fc2_act_scale_reciprocal: 1 / w2_input_scale
# - fc1_dequant_scale: w2_weight_scale * w2_input_scale
# - fc1_act_scale: w1_input_scale
# These are precomputed in `fused_moe` transform:
# - fc1_dequant_scale: w1_weight_scale * max(w1_input_scale) [E]
# - fc2_act_scale_reciprocal: 1 / max(w2_input_scale) (scalar)
# - fc2_dequant_scale: w2_weight_scale * max(w2_input_scale) [E]
# - fc1_act_scale: max(w1_input_scale) (scalar)
assert fc1_dequant_scale.ndim == 1, "fc1_dequant_scale must be 1D"
assert fc2_dequant_scale.ndim == 1, "fc2_dequant_scale must be 1D"
quant_scales = [fc1_dequant_scale, fc2_act_scale_reciprocal, fc2_dequant_scale, w1_input_scale]
quant_scales = [
fc1_dequant_scale,
fc2_act_scale_reciprocal,
fc2_dequant_scale,
fc1_act_scale,
]
# Ensure contiguous tensors
selected_experts = selected_experts.int().contiguous()
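
A small worked sketch of the precompute described in the comment above, with toy tensors standing in for the real checkpoint scales; it only mirrors the four quant_scales entries and makes no TensorRT-LLM calls:

import torch

torch.manual_seed(0)
E = 4
w1_input_scale = torch.rand(E) + 0.5    # per-expert FC1 activation scales
w2_input_scale = torch.rand(E) + 0.5    # per-expert FC2 activation scales
w1_weight_scale = torch.rand(E) + 0.1   # per-expert FC1 weight scales
w2_weight_scale = torch.rand(E) + 0.1   # per-expert FC2 weight scales

fc1_act_scale = w1_input_scale.max()                         # scalar
fc1_dequant_scale = w1_weight_scale * fc1_act_scale          # [E]
fc2_act_scale_reciprocal = 1.0 / w2_input_scale.max()        # scalar
fc2_dequant_scale = w2_weight_scale * w2_input_scale.max()   # [E]

quant_scales = [fc1_dequant_scale, fc2_act_scale_reciprocal, fc2_dequant_scale, fc1_act_scale]
print([tuple(t.shape) for t in quant_scales])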


@ -17,6 +17,7 @@ from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code, get_attr_by_name
from ...utils.cuda_mem_tracker import cuda_memory_tracker
from ...utils.logger import ad_logger
from ...utils.module import get_submodule_of_param
from ...utils.node_utils import bfs, extract_op_args, identify_regions_between_residuals, is_op
from ..interface import (
@ -1323,11 +1324,21 @@ def remove_original_experts(gm: GraphModule, weight_lists: List[List[Node]]) ->
continue
def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int:
def _stack_fp8_moe_weights(
gm: GraphModule,
backend: Literal["auto", "trtllm", "triton"],
allow_different_input_scales: bool = False,
) -> int:
"""
Stack per-expert FP8 weights and scales by materializing stacked tensors as parameters.
This is fast because we directly stack the tensor values (not graph nodes).
Similar to _insert_fused_moe_ops but for quantized MoE.
Args:
gm: The GraphModule to transform.
backend: Backend to use ('auto', 'trtllm', or 'triton').
allow_different_input_scales: If False (default), assert that all experts have identical
input scales and fail if not. If True, allow different scales (use max for quantization).
"""
def _register_parameter(gm: GraphModule, target, value):
@ -1387,23 +1398,26 @@ def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "
# For optimization reasons, we precompute a few additional arguments to the trtllm_quant_fp8_moe_fused op
# to avoid computing them at runtime.
fc1_dequant = (w1_weight_scale_stacked * w1_input_scale_stacked[0]).squeeze()
fc2_act_scale_recip = (1.0 / w2_input_scale_stacked[0]).to(torch.float32)
fc2_dequant = (w2_weight_scale_stacked * w2_input_scale_stacked[0]).squeeze()
# We use max scale to handle different input scales per expert (if enabled).
fc1_act_scale = fc1_act_scale.max()
fc2_act_scale = w2_input_scale_stacked.max()
fc1_dequant = (w1_weight_scale_stacked * w1_input_scale_stacked.max()).squeeze()
fc2_act_scale_recip = (1.0 / fc2_act_scale).to(torch.float32)
fc2_dequant = (w2_weight_scale_stacked * fc2_act_scale).squeeze()
new_key_fc1_expert_weights = f"quant_moe_w3_w1_stacked_{fused_key_counter}"
new_key_fc2_expert_weights = f"quant_moe_w2_stacked_{fused_key_counter}"
new_key_fc1_act_scale = f"quant_moe_fc1_act_scale_{fused_key_counter}"
new_key_fc1_dequant = f"quant_moe_fc1_dequant_stacked_{fused_key_counter}"
new_key_fc2_act_scale_recip = f"quant_moe_fc2_act_scale_recip_stacked_{fused_key_counter}"
new_key_fc2_dequant = f"quant_moe_fc2_dequant_stacked_{fused_key_counter}"
new_key_fc1_expert_weights = f"quant_moe_w3_w1_stacked_{fused_key_counter}"
new_key_fc2_expert_weights = f"quant_moe_w2_stacked_{fused_key_counter}"
new_key_fc1_act_scale = f"quant_moe_w3_w1_input_scale_stacked_{fused_key_counter}"
_register_parameter(gm, new_key_fc1_dequant, fc1_dequant)
_register_parameter(gm, new_key_fc2_act_scale_recip, fc2_act_scale_recip)
_register_parameter(gm, new_key_fc2_dequant, fc2_dequant)
_register_parameter(gm, new_key_fc1_expert_weights, fc1_expert_weights)
_register_parameter(gm, new_key_fc2_expert_weights, fc2_expert_weights)
_register_parameter(gm, new_key_fc1_act_scale, fc1_act_scale)
_register_parameter(gm, new_key_fc1_dequant, fc1_dequant)
_register_parameter(gm, new_key_fc2_act_scale_recip, fc2_act_scale_recip)
_register_parameter(gm, new_key_fc2_dequant, fc2_dequant)
with graph.inserting_before(node):
args = (
@ -1424,19 +1438,27 @@ def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "
new_key_w1 = f"quant_moe_w1_stacked_{fused_key_counter}"
new_key_w2 = f"quant_moe_w2_stacked_{fused_key_counter}"
new_key_w3 = f"quant_moe_w3_stacked_{fused_key_counter}"
new_key_w1_input_scale = f"quant_moe_w1_input_scale_stacked_{fused_key_counter}"
new_key_w2_input_scale = f"quant_moe_w2_input_scale_stacked_{fused_key_counter}"
new_key_w3_input_scale = f"quant_moe_w3_input_scale_stacked_{fused_key_counter}"
new_key_w1_weight_scale = f"quant_moe_w1_weight_scale_stacked_{fused_key_counter}"
new_key_w2_weight_scale = f"quant_moe_w2_weight_scale_stacked_{fused_key_counter}"
new_key_w3_weight_scale = f"quant_moe_w3_weight_scale_stacked_{fused_key_counter}"
w1_input_scale = w1_input_scale_stacked.max().reshape(1)
w2_input_scale = w2_input_scale_stacked.max().reshape(1)
# w3_input_scale: use max of w3 scales if present, else use empty tensor
w3_input_scale = (
w3_input_scale_stacked.max().reshape(1)
if w3_input_scale_stacked.numel() > 0
else torch.empty(1, device=w1_input_scale.device, dtype=w1_input_scale.dtype)
)
new_key_w1_input_scale = f"quant_moe_w1_input_scale_{fused_key_counter}"
new_key_w2_input_scale = f"quant_moe_w2_input_scale_{fused_key_counter}"
new_key_w3_input_scale = f"quant_moe_w3_input_scale_{fused_key_counter}"
_register_parameter(gm, new_key_w1, w1_stacked)
_register_parameter(gm, new_key_w2, w2_stacked)
_register_parameter(gm, new_key_w3, w3_stacked)
_register_parameter(gm, new_key_w1_input_scale, w1_input_scale_stacked)
_register_parameter(gm, new_key_w2_input_scale, w2_input_scale_stacked)
_register_parameter(gm, new_key_w3_input_scale, w3_input_scale_stacked)
_register_parameter(gm, new_key_w1_input_scale, w1_input_scale)
_register_parameter(gm, new_key_w2_input_scale, w2_input_scale)
_register_parameter(gm, new_key_w3_input_scale, w3_input_scale)
_register_parameter(gm, new_key_w1_weight_scale, w1_weight_scale_stacked)
_register_parameter(gm, new_key_w2_weight_scale, w2_weight_scale_stacked)
_register_parameter(gm, new_key_w3_weight_scale, w3_weight_scale_stacked)
@ -1509,12 +1531,31 @@ def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "
0, device=w1_input_scale_stacked.device, dtype=w1_input_scale_stacked.dtype
)
)
assert torch.all(w1_input_scale_stacked[0] == w1_input_scale_stacked), (
"All w1 scales should have the same value."
)
assert torch.all(w2_input_scale_stacked[0] == w2_input_scale_stacked), (
"All w2 scales should have the same value."
)
# Check if input scales are identical across experts
w1_input_scales_identical = torch.all(
w1_input_scale_stacked[0] == w1_input_scale_stacked
).item()
w2_input_scales_identical = torch.all(
w2_input_scale_stacked[0] == w2_input_scale_stacked
).item()
if not w1_input_scales_identical or not w2_input_scales_identical:
if not allow_different_input_scales:
# Fail with assertion
assert w1_input_scales_identical, (
"All w1 input scales should have the same value. "
"Set allow_different_input_scales=True to allow different scales (uses max)."
)
assert w2_input_scales_identical, (
"All w2 input scales should have the same value. "
"Set allow_different_input_scales=True to allow different scales (uses max)."
)
# Issue warning once and continue - max() will be used
ad_logger.warning_once(
"FP8 MoE: Input scales differ across experts. Using max(input_scale) for quantization. "
"This may impact accuracy if scales differ significantly.",
key="fp8_moe_different_input_scales",
)
w1_weight_scale_stacked = _stack(w1_weight_scale, dim=0).to(torch.float32)
w2_weight_scale_stacked = _stack(w2_weight_scale, dim=0).to(torch.float32)
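
The check-and-fallback logic above boils down to: require identical per-expert scales unless the flag is set, and otherwise fall back to the max. A compact sketch with a hypothetical helper name, independent of the GraphModule plumbing:

import torch

def pick_shared_fp8_input_scale(scales: torch.Tensor, allow_different: bool) -> torch.Tensor:
    """Hypothetical helper: collapse per-expert FP8 input scales to one [1] scale."""
    identical = bool(torch.all(scales[0] == scales))
    if not identical and not allow_different:
        raise AssertionError(
            "All input scales should have the same value. "
            "Set allow_different_input_scales=True to allow different scales (uses max)."
        )
    # When scales differ, max keeps every expert's activations within the FP8 range.
    return scales.max().reshape(1)

print(pick_shared_fp8_input_scale(torch.tensor([1.0, 1.0]), allow_different=False))
print(pick_shared_fp8_input_scale(torch.tensor([0.5, 1.0]), allow_different=True))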
@ -1602,6 +1643,14 @@ class FuseFP8MoeConfig(TransformConfig):
default="auto",
description="Backend to use for FP8 MoE computation ('auto', 'trtllm' or 'triton'. default: 'auto').",
)
allow_different_input_scales: bool = Field(
default=False,
description=(
"If False (default), assert that all experts have identical input scales and fail if not. "
"If True, allow different per-expert input scales by using max(input_scale) for quantization. "
"This matches TRT-LLM manual backend behavior but may impact accuracy if scales differ significantly."
),
)
@TransformRegistry.register("fuse_fp8_moe")
@ -1611,6 +1660,10 @@ class FuseFP8Moe(BaseTransform):
This runs after weights are loaded, similar to FuseMoe for unquantized MoE.
"""
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return FuseFP8MoeConfig
def _apply(
self,
gm: GraphModule,
@ -1619,7 +1672,11 @@ class FuseFP8Moe(BaseTransform):
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
with cuda_memory_tracker():
fused_key_counter = _stack_fp8_moe_weights(gm, backend=self.config.backend)
fused_key_counter = _stack_fp8_moe_weights(
gm,
backend=self.config.backend,
allow_different_input_scales=self.config.allow_different_input_scales,
)
info = TransformInfo(
skipped=(fused_key_counter == 0),
@ -1630,7 +1687,10 @@ class FuseFP8Moe(BaseTransform):
return gm, info
def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
def _stack_nvfp4_moe_weights(
gm: GraphModule,
allow_different_input_scales: bool = False,
) -> int:
def _register_parameter(gm: GraphModule, target, value):
gm.register_parameter(target, torch.nn.Parameter(value, requires_grad=False))
@ -1659,6 +1719,7 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
"w3_weight",
"w1_input_scale",
"w2_input_scale",
"w3_input_scale",
"w1_weight_scale",
"w2_weight_scale",
"w3_weight_scale",
@ -1680,20 +1741,79 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
if is_gated_mlp:
# For gated MLP, concatenate w1 and w3 as [w3, w1]
fc1_expert_weights = torch.cat([w3_stacked, w1_stacked], dim=1).contiguous()
# Expect w3 input scale and alpha to be the same as w1
fc1_act_scale = w1_input_scale_stacked
fc1_alpha_stacked = w1_alpha_stacked
fc1_weight_blockscale_fp8_stacked = torch.cat(
[w3_weight_blockscale_fp8_stacked, w1_weight_blockscale_fp8_stacked], dim=1
).contiguous()
# Check if all w1 and w3 input scales are identical across experts
all_scales_equal = (
torch.all(w1_input_scale_stacked == w1_input_scale_stacked[0])
and torch.all(w3_input_scale_stacked == w3_input_scale_stacked[0])
and torch.all(w1_input_scale_stacked == w3_input_scale_stacked)
)
if all_scales_equal:
# All scales are identical, no need for min() or alpha recomputation
fc1_act_scale = w1_input_scale_stacked[0]
fc1_alpha_stacked = w1_alpha_stacked
else:
if not allow_different_input_scales:
assert False, (
"FC1 input scales differ across experts (w1 and/or w3). "
"Set allow_different_input_scales=True to allow different scales (uses min)."
)
# Issue warning once and continue - min() will be used
ad_logger.warning_once(
"NVFP4 MoE: Input scales differ across experts. Using min(input_scale) for "
"FC1 quantization and recomputing alpha. This may impact accuracy if scales "
"differ significantly.",
key="nvfp4_moe_different_input_scales",
)
# Scales differ across experts - use global min scale and recompute alpha.
# Use min() because NVFP4 scales are in kernel format (2688/amax):
# smaller scale = larger amax = larger dynamic range.
fc1_act_scale = torch.minimum(
w1_input_scale_stacked.min(), w3_input_scale_stacked.min()
)
# Recompute alpha using global input scale instead of per-expert input scale.
# Formula: new_alpha = old_alpha * per_expert_input_scale / global_input_scale
# This ensures alpha is consistent with the global fc1_act_scale used by the kernel.
fc1_alpha_stacked = w1_alpha_stacked * w1_input_scale_stacked / fc1_act_scale
else:
fc1_expert_weights = w1_stacked
fc1_act_scale = w1_input_scale_stacked
fc1_alpha_stacked = w1_alpha_stacked
fc1_weight_blockscale_fp8_stacked = w1_weight_blockscale_fp8_stacked
# Check if all w1 input scales are identical across experts
all_scales_equal = torch.all(w1_input_scale_stacked == w1_input_scale_stacked[0])
if all_scales_equal:
# All scales are identical, no need for min() or alpha recomputation
fc1_act_scale = w1_input_scale_stacked[0]
fc1_alpha_stacked = w1_alpha_stacked
else:
if not allow_different_input_scales:
assert False, (
"FC1 input scales differ across experts (w1). "
"Set allow_different_input_scales=True to allow different scales (uses min)."
)
# Issue warning once and continue - min() will be used
ad_logger.warning_once(
"NVFP4 MoE: Input scales differ across experts. Using min(input_scale) for "
"FC1 quantization and recomputing alpha. This may impact accuracy if scales "
"differ significantly.",
key="nvfp4_moe_different_input_scales",
)
# Scales differ across experts - use global min scale and recompute alpha
fc1_act_scale = w1_input_scale_stacked.min()
fc1_alpha_stacked = w1_alpha_stacked * w1_input_scale_stacked / fc1_act_scale
fc2_expert_weights = w2_stacked
# Keep fc2_act_scale per-expert (no global scale aggregation for fc2).
# The kernel supports per-expert scales for fc2, and intermediate activations
# naturally have different dynamic ranges per expert.
fc2_act_scale = w2_input_scale_stacked
# No alpha recomputation needed since fc2 uses per-expert input scales.
fc2_alpha_stacked = w2_alpha_stacked
fc2_weight_blockscale_fp8_stacked = w2_weight_blockscale_fp8_stacked
new_key_fc1_expert_weights = f"nvfp4_moe_w3_w1_stacked_{fused_key_counter}"
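
The min()-plus-alpha-recompute path above can be illustrated with a few toy numbers; the constant and the recompute formula come from this diff, the tensor values are made up:

import torch

FP4_GLOBAL_SCALE_MAX = 448 * 6  # 2688; NVFP4 kernel-format scale is roughly 2688 / amax
amax_per_expert = torch.tensor([1.0, 2.0, 0.5])
input_scale = FP4_GLOBAL_SCALE_MAX / amax_per_expert   # per-expert kernel-format scales
weight_scale_2 = torch.full((3,), 100.0)               # per-expert global weight scale (toy)
alpha = 1.0 / (input_scale * weight_scale_2)           # per-expert alpha from the checkpoint

# The expert with the largest amax has the smallest scale, so min() covers all experts.
fc1_act_scale = input_scale.min()
# new_alpha = old_alpha * per_expert_input_scale / global_input_scale
fc1_alpha = alpha * input_scale / fc1_act_scale
print(fc1_act_scale.item(), fc1_alpha)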
@ -1764,7 +1884,7 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
_register_parameter(gm, new_key_fc1_act_scale, fc1_act_scale)
_register_parameter(gm, new_key_fc2_act_scale, fc2_act_scale)
_register_parameter(gm, new_key_fc1_alpha, fc1_alpha_stacked)
_register_parameter(gm, new_key_fc2_alpha, w2_alpha_stacked)
_register_parameter(gm, new_key_fc2_alpha, fc2_alpha_stacked)
with graph.inserting_before(node):
args = (
@ -1804,6 +1924,7 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
w3_list,
w1_input_scale,
w2_input_scale,
w3_input_scale,
w1_weight_scale,
w2_weight_scale,
w3_weight_scale,
@ -1822,6 +1943,12 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
# Scales are buffers, not parameters
w1_input_scale_stacked = _stack(w1_input_scale, dim=0)
w2_input_scale_stacked = _stack(w2_input_scale, dim=0)
w3_input_scale_stacked = _stack(
w3_input_scale,
dim=0,
device=w1_input_scale_stacked.device,
dtype=w1_input_scale_stacked.dtype,
)
# Use .view() not .to() to reinterpret bytes as float8, not value conversion
w1_weight_blockscale_fp8_stacked = _stack(w1_weight_scale, dim=0).view(torch.float8_e4m3fn)
@ -1856,6 +1983,21 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
return fused_key_counter
class FuseNVFP4MoeConfig(TransformConfig):
"""Configuration for NVFP4 MoE fusion transform."""
allow_different_input_scales: bool = Field(
default=False,
description=(
"If False (default), assert that all experts have identical input scales and fail if not. "
"If True, allow different per-expert input scales by using min(input_scale) for quantization. "
"Note: NVFP4 uses min() (not max like FP8) because scales are in kernel format (2688/amax): "
"smaller scale = larger amax = larger dynamic range. "
"This may impact accuracy if scales differ significantly."
),
)
@TransformRegistry.register("fuse_nvfp4_moe")
class FuseNVFP4Moe(BaseTransform):
"""
@ -1863,6 +2005,10 @@ class FuseNVFP4Moe(BaseTransform):
This runs after weights are loaded, similar to FuseFP8Moe.
"""
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return FuseNVFP4MoeConfig
def _apply(
self,
gm: GraphModule,
@ -1871,7 +2017,10 @@ class FuseNVFP4Moe(BaseTransform):
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
with cuda_memory_tracker():
fused_key_counter = _stack_nvfp4_moe_weights(gm)
fused_key_counter = _stack_nvfp4_moe_weights(
gm,
allow_different_input_scales=self.config.allow_different_input_scales,
)
info = TransformInfo(
skipped=(fused_key_counter == 0),


@ -299,15 +299,16 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
# Input scales (tensor-wise, replicated per expert for interface compatibility)
x_scale = x.abs().max().item() / FP8_MAX
w1_input_scale_tensor = torch.full((E,), x_scale, device=device, dtype=torch.float32)
# Input scales: precomputed max values with shape [1] (new API)
w1_input_scale_max = torch.tensor([x_scale], device=device, dtype=torch.float32)
# Compute intermediate activation scale by simulating first GEMM + ReLU^2
# This ensures w2_input_scale matches the actual activation magnitude
with torch.no_grad():
# Simulate the first GEMM: quantize input, do FP8 matmul, apply ReLU^2
x_q = (x / w1_input_scale_tensor[0]).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
x_q = (x / w1_input_scale_max.item()).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
# Dequantize and compute output for a sample
x_dq = x_q[:8].to(torch.float32) * w1_input_scale_tensor[0].item()
x_dq = x_q[:8].to(torch.float32) * w1_input_scale_max.item()
w1_dq = w1_fp8_stacked[0].to(torch.float32) * w1_weight_scale[0].item()
sample_out = torch.nn.functional.linear(x_dq.to(dtype), w1_dq.to(dtype))
sample_act = torch.square(torch.nn.functional.relu(sample_out))
@ -315,11 +316,11 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
# Ensure scale is not too small
intermediate_scale = max(intermediate_scale, 1e-6)
w2_input_scale_tensor = torch.full((E,), intermediate_scale, device=device, dtype=torch.float32)
w2_input_scale_max = torch.tensor([intermediate_scale], device=device, dtype=torch.float32)
# Convert scales to lists for torch_quant_fp8_moe reference
w1_input_scale_list = [w1_input_scale_tensor[0].clone() for _ in range(E)]
w2_input_scale_list = [w2_input_scale_tensor[0].clone() for _ in range(E)]
w1_input_scale_list = [w1_input_scale_max[0].clone() for _ in range(E)]
w2_input_scale_list = [w2_input_scale_max[0].clone() for _ in range(E)]
w1_weight_scale_list = [w1_weight_scale[e].clone() for e in range(E)]
w2_weight_scale_list = [w2_weight_scale[e].clone() for e in range(E)]
@ -327,7 +328,7 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
w3_fp8_list = [torch.empty((1, 1), device=device, dtype=torch.float8_e4m3fn) for _ in range(E)]
w3_fp8_stacked = torch.stack(w3_fp8_list).contiguous()
w3_input_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)]
w3_input_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32)
w3_input_scale_max = torch.ones((1,), device=device, dtype=torch.float32)
w3_weight_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)]
w3_weight_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32)
@ -352,7 +353,7 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
# Create equal routing weights
routing_weights = torch.ones((M, top_k), device=device, dtype=torch.float32) / top_k
# Triton FP8 quantized MoE (uses stacked tensors)
# Triton FP8 quantized MoE (uses stacked tensors with precomputed max input scales)
out_triton = torch.ops.auto_deploy.triton_quant_fp8_moe(
x,
selected_experts.to(torch.int32),
@ -360,9 +361,9 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
w1_fp8_stacked,
w2_fp8_stacked,
w3_fp8_stacked,
w1_input_scale_tensor,
w2_input_scale_tensor,
w3_input_scale_tensor,
w1_input_scale_max, # [1] precomputed max
w2_input_scale_max, # [1] precomputed max
w3_input_scale_max, # [1] precomputed max (unused)
w1_weight_scale,
w2_weight_scale,
w3_weight_scale_tensor,


@ -414,7 +414,7 @@ def test_trtllm_fused_moe_fp8(
routing_weights,
fc1_expert_weights=w31_weight.contiguous() if is_gated_mlp else w1_weight.contiguous(),
fc2_expert_weights=w2_weight.contiguous(),
fc1_act_scale=hidden_states_scale.unsqueeze(0),
fc1_act_scale=hidden_states_scale,
fc1_dequant_scale=gemm1_dequant,
fc2_act_scale_reciprocal=gemm2_act_quant,
fc2_dequant_scale=gemm2_dequant,
@ -433,7 +433,7 @@ def test_trtllm_fused_moe_fp8(
w1_weight,
w2_weight,
w3_weight,
hidden_states_scale.unsqueeze(0),
hidden_states_scale.reshape(1),
w2_input_scale,
w3_input_scale,
w1_scales,


@ -1,11 +1,14 @@
import pytest
import torch
import torch.fx as fx
import torch.nn as nn
import torch.nn.functional as F
from _graph_test_helpers import run_test_transformed_gm
from _model_test_utils import MoEOpModel
from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available
from utils.util import skip_pre_hopper
import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
@ -650,3 +653,424 @@ def test_nvfp4_moe_fusion(is_gated_mlp, hidden_size, intermediate_size):
assert not torch.isnan(fused_output).any(), "Fused output contains NaN"
assert not torch.isinf(fused_output).any(), "Fused output contains Inf"
class FP8MoEModuleForInputScaleTest(nn.Module):
"""Module wrapping torch_quant_fp8_moe for testing FP8 MoE input scale handling."""
def __init__(
self,
num_experts,
w1_weight,
w2_weight,
w1_input_scale,
w2_input_scale,
w1_weight_scale,
w2_weight_scale,
is_gated_mlp,
act_fn,
):
super().__init__()
self.num_experts = num_experts
self.is_gated_mlp = is_gated_mlp
self.act_fn = act_fn
for i in range(num_experts):
self.register_buffer(f"w1_{i}", w1_weight[i])
self.register_buffer(f"w2_{i}", w2_weight[i])
self.register_buffer(f"w1_iscale_{i}", w1_input_scale[i])
self.register_buffer(f"w2_iscale_{i}", w2_input_scale[i])
self.register_buffer(f"w1_wscale_{i}", w1_weight_scale[i])
self.register_buffer(f"w2_wscale_{i}", w2_weight_scale[i])
def forward(self, x, selected_experts, routing_weights):
return torch.ops.auto_deploy.torch_quant_fp8_moe(
x,
selected_experts,
routing_weights,
[getattr(self, f"w1_{i}") for i in range(self.num_experts)],
[getattr(self, f"w2_{i}") for i in range(self.num_experts)],
[], # w3 is empty for non-gated MLP
[getattr(self, f"w1_iscale_{i}") for i in range(self.num_experts)],
[getattr(self, f"w2_iscale_{i}") for i in range(self.num_experts)],
[], # w3 input scale is empty for non-gated MLP
[getattr(self, f"w1_wscale_{i}") for i in range(self.num_experts)],
[getattr(self, f"w2_wscale_{i}") for i in range(self.num_experts)],
[], # w3 weight scale is empty for non-gated MLP
is_gated_mlp=self.is_gated_mlp,
act_fn=self.act_fn,
)
@skip_pre_hopper
@pytest.mark.parametrize("backend", ["trtllm", "triton"])
@pytest.mark.parametrize("allow_different_input_scales", [False, True])
@pytest.mark.parametrize("scales_identical", [True, False])
@pytest.mark.skipif(
not fp8_compatible() or not trtllm_ops_available(),
reason="Requires fp8 and trtllm support",
)
def test_fp8_moe_different_input_scales(backend, allow_different_input_scales, scales_identical):
"""
Test FP8 MoE behavior with different/identical input scales via InferenceOptimizer.
Tests the allow_different_input_scales config option for both trtllm and triton backends:
- When scales_identical=True: should always work
- When scales_identical=False and allow_different_input_scales=False: should fail with assertion
- When scales_identical=False and allow_different_input_scales=True: should work (uses max)
"""
from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe import _stack_fp8_moe_weights
torch.manual_seed(0)
batch_size, num_experts, top_k = 4, 2, 2
hidden_size, intermediate_size = 128, 128
# Use non-gated MLP (Relu2) because triton backend only supports non-gated MLP
is_gated_mlp = False
act_fn = ActivationType.Relu2
# Generate test data
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda") * 0.5
# Simple routing: distribute tokens across experts
selected_experts = torch.zeros((batch_size, top_k), dtype=torch.int64, device="cuda")
for i in range(batch_size):
selected_experts[i, 0] = i % num_experts
selected_experts[i, 1] = (i + 1) % num_experts
routing_weights = torch.ones((batch_size, top_k), device="cuda", dtype=torch.float32) / top_k
# Create per-expert weights and scales (non-gated MLP, no w3)
w1_weight, w2_weight = [], []
w1_input_scale, w2_input_scale = [], []
w1_weight_scale, w2_weight_scale = [], []
for expert_id in range(num_experts):
# Random FP8 weights
w1_fp8 = torch.randn(intermediate_size, hidden_size, device="cuda").to(torch.float8_e4m3fn)
w2_fp8 = torch.randn(hidden_size, intermediate_size, device="cuda").to(torch.float8_e4m3fn)
w1_weight.append(w1_fp8)
w2_weight.append(w2_fp8)
# Random weight scales (shape [1])
w1_weight_scale.append(torch.tensor([0.1], dtype=torch.float32, device="cuda"))
w2_weight_scale.append(torch.tensor([0.1], dtype=torch.float32, device="cuda"))
# Input scales: either identical or different per expert (shape [1])
if scales_identical:
inp_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
else:
# Different input scales per expert - big variance to test max() behavior
inp_scale = torch.tensor([0.5 + 0.5 * expert_id], dtype=torch.float32, device="cuda")
w1_input_scale.append(inp_scale)
w2_input_scale.append(inp_scale)
# Create a module with the FP8 MoE op
module = FP8MoEModuleForInputScaleTest(
num_experts,
w1_weight,
w2_weight,
w1_input_scale,
w2_input_scale,
w1_weight_scale,
w2_weight_scale,
is_gated_mlp,
act_fn,
).cuda()
gm = fx.symbolic_trace(module)
# Compute reference output from original graph before transformation
with torch.inference_mode():
ref_output = gm(x, selected_experts, routing_weights)
# Expected behavior:
# - scales_identical=True: always works
# - scales_identical=False, allow_different_input_scales=False: assertion error
# - scales_identical=False, allow_different_input_scales=True: works with max()
if not scales_identical and not allow_different_input_scales:
# Should fail with assertion error
with pytest.raises(AssertionError, match="input scales should have the same value"):
_stack_fp8_moe_weights(
gm, backend=backend, allow_different_input_scales=allow_different_input_scales
)
else:
# Should succeed
num_transformed = _stack_fp8_moe_weights(
gm, backend=backend, allow_different_input_scales=allow_different_input_scales
)
gm.recompile()
assert num_transformed == 1, f"Expected 1 transform, got {num_transformed}"
# Verify that max() is used when scales differ
if not scales_identical:
expected_max_w1_input_scale = torch.stack(w1_input_scale).max()
# Attribute name differs between backends
if backend == "trtllm":
actual_w1_input = getattr(gm, "quant_moe_fc1_act_scale_0")
else: # triton
actual_w1_input = getattr(gm, "quant_moe_w1_input_scale_0").squeeze()
assert torch.allclose(actual_w1_input, expected_max_w1_input_scale), (
f"w1 input scale max mismatch. Got {actual_w1_input}, expected {expected_max_w1_input_scale}"
)
# Run the transformed graph and compare to reference output
with torch.inference_mode():
output = gm(x, selected_experts, routing_weights)
assert output.shape == ref_output.shape, (
f"Output shape mismatch: {output.shape} vs {ref_output.shape}"
)
assert torch.allclose(output, ref_output, rtol=0.05, atol=0.05), (
f"Output mismatch. rtol=0.05, atol=0.05. Max diff: {(output - ref_output).abs().max()}"
)
class NVFP4MoEModuleForInputScaleTest(nn.Module):
"""Module wrapping torch_quant_nvfp4_moe for testing NVFP4 MoE input scale handling."""
def __init__(
self,
num_experts,
w1_weight,
w2_weight,
w3_weight,
w1_input_scale,
w2_input_scale,
w3_input_scale,
w1_weight_scale,
w2_weight_scale,
w3_weight_scale,
w1_alpha,
w2_alpha,
w3_alpha,
is_gated_mlp,
act_fn,
):
super().__init__()
self.num_experts = num_experts
self.is_gated_mlp = is_gated_mlp
self.act_fn = act_fn
for i in range(num_experts):
self.register_buffer(f"w1_{i}", w1_weight[i])
self.register_buffer(f"w2_{i}", w2_weight[i])
self.register_buffer(f"w1_iscale_{i}", w1_input_scale[i])
self.register_buffer(f"w2_iscale_{i}", w2_input_scale[i])
self.register_buffer(f"w1_wscale_{i}", w1_weight_scale[i])
self.register_buffer(f"w2_wscale_{i}", w2_weight_scale[i])
self.register_buffer(f"w1_alpha_{i}", w1_alpha[i])
self.register_buffer(f"w2_alpha_{i}", w2_alpha[i])
if is_gated_mlp:
self.register_buffer(f"w3_{i}", w3_weight[i])
self.register_buffer(f"w3_iscale_{i}", w3_input_scale[i])
self.register_buffer(f"w3_wscale_{i}", w3_weight_scale[i])
self.register_buffer(f"w3_alpha_{i}", w3_alpha[i])
def forward(self, x, selected_experts, routing_weights):
return torch.ops.auto_deploy.torch_quant_nvfp4_moe(
x,
selected_experts,
routing_weights,
[getattr(self, f"w1_{i}") for i in range(self.num_experts)],
[getattr(self, f"w2_{i}") for i in range(self.num_experts)],
[getattr(self, f"w3_{i}") for i in range(self.num_experts)]
if self.is_gated_mlp
else [],
[getattr(self, f"w1_iscale_{i}") for i in range(self.num_experts)],
[getattr(self, f"w2_iscale_{i}") for i in range(self.num_experts)],
[getattr(self, f"w3_iscale_{i}") for i in range(self.num_experts)]
if self.is_gated_mlp
else [],
[getattr(self, f"w1_wscale_{i}") for i in range(self.num_experts)],
[getattr(self, f"w2_wscale_{i}") for i in range(self.num_experts)],
[getattr(self, f"w3_wscale_{i}") for i in range(self.num_experts)]
if self.is_gated_mlp
else [],
[getattr(self, f"w1_alpha_{i}") for i in range(self.num_experts)],
[getattr(self, f"w2_alpha_{i}") for i in range(self.num_experts)],
[getattr(self, f"w3_alpha_{i}") for i in range(self.num_experts)]
if self.is_gated_mlp
else [],
is_gated_mlp=self.is_gated_mlp,
act_fn=self.act_fn,
)
@skip_pre_hopper
@pytest.mark.parametrize("allow_different_input_scales", [False, True])
@pytest.mark.parametrize("scales_identical", [True, False])
@pytest.mark.parametrize("is_gated_mlp", [False, True])
@pytest.mark.skipif(
not trtllm_ops_available(),
reason="Requires trtllm ops",
)
def test_nvfp4_moe_different_input_scales(
allow_different_input_scales, scales_identical, is_gated_mlp
):
"""
Test NVFP4 MoE behavior with different/identical input scales via _stack_nvfp4_moe_weights.
Tests the allow_different_input_scales config option for both gated and non-gated MLP:
- When scales_identical=True: should always work
- When scales_identical=False and allow_different_input_scales=False: should fail with assertion
- When scales_identical=False and allow_different_input_scales=True: should work (uses min)
Note: NVFP4 uses min() (not max() like FP8) because scales are in kernel format (2688/amax):
smaller scale = larger amax = larger dynamic range.
This test uses mock tensors to test the transform logic without running the actual NVFP4 kernel.
"""
from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe import _stack_nvfp4_moe_weights
torch.manual_seed(0)
num_experts = 2
hidden_size, intermediate_size = 128, 128
act_fn = ActivationType.Silu if is_gated_mlp else ActivationType.Relu2
# NVFP4 constants
FP4_GLOBAL_SCALE_MAX = 448 * 6 # 2688
NVFP4_BLOCK_SIZE = 16
# Create per-expert mock weights and scales
# We use mock tensors with correct shapes to test the transform logic
# without needing actual FP4 quantization (which requires SM>=100)
w1_weight, w2_weight, w3_weight = [], [], []
w1_input_scale, w2_input_scale, w3_input_scale = [], [], []
w1_weight_scale, w2_weight_scale, w3_weight_scale = [], [], []
w1_alpha, w2_alpha, w3_alpha = [], [], []
for expert_id in range(num_experts):
# Mock FP4 weights (uint8 packed, half the size in last dim)
w1_fp4 = torch.randint(
0, 255, (intermediate_size, hidden_size // 2), dtype=torch.uint8, device="cuda"
)
w2_fp4 = torch.randint(
0, 255, (hidden_size, intermediate_size // 2), dtype=torch.uint8, device="cuda"
)
# Mock block scales (2D): shape (m, n // NVFP4_BLOCK_SIZE)
# With 128x128 dims, no padding needed (already multiples of 128 and 8)
w1_block_scale = torch.randn(
intermediate_size, hidden_size // NVFP4_BLOCK_SIZE, dtype=torch.float32, device="cuda"
).to(torch.float8_e4m3fn)
w2_block_scale = torch.randn(
hidden_size, intermediate_size // NVFP4_BLOCK_SIZE, dtype=torch.float32, device="cuda"
).to(torch.float8_e4m3fn)
w1_weight.append(w1_fp4)
w2_weight.append(w2_fp4)
w1_weight_scale.append(w1_block_scale)
w2_weight_scale.append(w2_block_scale)
# Input scales: either identical or different per expert
# For NVFP4, scale = FP4_GLOBAL_SCALE_MAX / amax
if scales_identical:
# Same amax for all experts -> same input scale
inp_scale = torch.tensor(FP4_GLOBAL_SCALE_MAX / 1.0, dtype=torch.float32, device="cuda")
else:
# Different amax per expert -> different input scales
# Expert 0: amax=1.0, scale=2688/1.0=2688
# Expert 1: amax=2.0, scale=2688/2.0=1344
amax = 1.0 + expert_id
inp_scale = torch.tensor(
FP4_GLOBAL_SCALE_MAX / amax, dtype=torch.float32, device="cuda"
)
w1_input_scale.append(inp_scale)
w2_input_scale.append(inp_scale)
# Mock weight_scale_2 (global scale for this expert's weights)
w1_scale_2 = torch.tensor(100.0, dtype=torch.float32, device="cuda")
w2_scale_2 = torch.tensor(100.0, dtype=torch.float32, device="cuda")
# Alpha = 1 / (input_scale * weight_scale_2)
w1_alpha.append((1.0 / (inp_scale * w1_scale_2)).to(torch.float32))
w2_alpha.append((1.0 / (inp_scale * w2_scale_2)).to(torch.float32))
# For gated MLP, create w3 weights/scales/alpha (same shape as w1)
if is_gated_mlp:
w3_fp4 = torch.randint(
0, 255, (intermediate_size, hidden_size // 2), dtype=torch.uint8, device="cuda"
)
w3_block_scale = torch.randn(
intermediate_size,
hidden_size // NVFP4_BLOCK_SIZE,
dtype=torch.float32,
device="cuda",
).to(torch.float8_e4m3fn)
w3_weight.append(w3_fp4)
w3_weight_scale.append(w3_block_scale)
# w3 uses the same input scale as w1 (they process the same input)
w3_input_scale.append(inp_scale)
w3_scale_2 = torch.tensor(100.0, dtype=torch.float32, device="cuda")
w3_alpha.append((1.0 / (inp_scale * w3_scale_2)).to(torch.float32))
# Create a module with the NVFP4 MoE op
module = NVFP4MoEModuleForInputScaleTest(
num_experts,
w1_weight,
w2_weight,
w3_weight,
w1_input_scale,
w2_input_scale,
w3_input_scale,
w1_weight_scale,
w2_weight_scale,
w3_weight_scale,
w1_alpha,
w2_alpha,
w3_alpha,
is_gated_mlp,
act_fn,
).cuda()
gm = fx.symbolic_trace(module)
# Expected behavior:
# - scales_identical=True: always works
# - scales_identical=False, allow_different_input_scales=False: assertion error
# - scales_identical=False, allow_different_input_scales=True: works with min()
if not scales_identical and not allow_different_input_scales:
# Should fail with assertion error
with pytest.raises(AssertionError, match="FC1 input scales differ"):
_stack_nvfp4_moe_weights(gm, allow_different_input_scales=allow_different_input_scales)
else:
# Should succeed
num_transformed = _stack_nvfp4_moe_weights(
gm, allow_different_input_scales=allow_different_input_scales
)
gm.recompile()
assert num_transformed == 1, f"Expected 1 transform, got {num_transformed}"
# Verify that min() is used when scales differ
if not scales_identical:
# For gated MLP, global scale = min(w1_scales.min(), w3_scales.min())
# For non-gated MLP, global scale = w1_scales.min()
if is_gated_mlp:
expected_min_input_scale = torch.minimum(
torch.stack(w1_input_scale).min(),
torch.stack(w3_input_scale).min(),
)
else:
expected_min_input_scale = torch.stack(w1_input_scale).min()
actual_input_scale = getattr(gm, "nvfp4_moe_w3_w1_input_scale_stacked_0")
assert torch.allclose(actual_input_scale, expected_min_input_scale), (
f"FC1 input scale min mismatch. Got {actual_input_scale}, expected {expected_min_input_scale}"
)
# Verify alpha was recomputed correctly
# new_alpha = old_alpha * per_expert_input_scale / global_input_scale
expected_alpha = (
torch.stack(w1_alpha) * torch.stack(w1_input_scale) / expected_min_input_scale
)
actual_alpha = getattr(gm, "nvfp4_moe_w1_alpha_stacked_0")
assert torch.allclose(actual_alpha, expected_alpha, rtol=1e-5, atol=1e-5), (
f"Alpha recomputation mismatch. Got {actual_alpha}, expected {expected_alpha}"
)