Mirror of https://github.com/NVIDIA/TensorRT-LLM.git

commit 2b60cc181c (parent 540fb0f29e)
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Signed-off-by: Gal Hubara-Agam <96368689+galagam@users.noreply.github.com>
@@ -134,9 +134,11 @@ transforms:
     stage: post_load_fusion
     expect_mem_change: true
     backend: trtllm
+    allow_different_input_scales: false
   fuse_nvfp4_moe:
     stage: post_load_fusion
     expect_mem_change: true
+    allow_different_input_scales: false
   fuse_allreduce_residual_rmsnorm:
     stage: post_load_fusion
   fuse_rmsnorm:

@@ -670,8 +670,8 @@ def triton_quant_fp8_moe(
     w1_weight: torch.Tensor,  # [E, I, H] stacked FP8 weights
     w2_weight: torch.Tensor,  # [E, H, I] stacked FP8 weights
     w3_weight: torch.Tensor,  # unused for mlp style
-    w1_input_scale: torch.Tensor,  # [E] stacked input scales
-    w2_input_scale: torch.Tensor,  # [E] stacked input scales
+    w1_input_scale: torch.Tensor,  # [1] max input scale (precomputed)
+    w2_input_scale: torch.Tensor,  # [1] max input scale (precomputed)
     w3_input_scale: torch.Tensor,  # unused
     w1_weight_scale: torch.Tensor,  # [E] stacked weight scales
     w2_weight_scale: torch.Tensor,  # [E] stacked weight scales

@@ -698,10 +698,11 @@ def triton_quant_fp8_moe(
     topk_weights = routing_weights.to(torch.float32).contiguous()

     # Weights are already stacked [E, ...] - just ensure contiguous and extract scales
+    # Input scales are precomputed max values (consistent with trtllm backend)
     w1_q = w1_weight.contiguous()
     w2_q = w2_weight.contiguous()
-    a1_scale = w1_input_scale[0].to(torch.float32).reshape(1).contiguous()
-    a2_scale = w2_input_scale[0].to(torch.float32).reshape(1).contiguous()
+    a1_scale = w1_input_scale.to(torch.float32).reshape(1).contiguous()
+    a2_scale = w2_input_scale.to(torch.float32).reshape(1).contiguous()
     b1_scale = w1_weight_scale.to(torch.float32).contiguous()
     b2_scale = w2_weight_scale.to(torch.float32).contiguous()

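The replaced argument comments above change the contract from per-expert `[E]` input-scale tensors to a single precomputed `[1]` max scale. A minimal, self-contained sketch of what that means for activation quantization; the helper below is hypothetical and only illustrative, it is not the repository's `_quantize_fp8`:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3


def quantize_fp8_with_shared_scale(x: torch.Tensor, input_scale: torch.Tensor) -> torch.Tensor:
    # A single [1] scale is shared by all experts instead of a per-expert [E] tensor.
    return (x / input_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)


# The fusion transform collapses per-expert scales into the conservative maximum up front:
per_expert_scales = torch.tensor([0.8, 1.0, 1.2])   # hypothetical [E] input scales
shared_scale = per_expert_scales.max().reshape(1)   # the [1] tensor the op now receives
x_fp8 = quantize_fp8_with_shared_scale(torch.randn(4, 16), shared_scale)
```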
@@ -762,7 +763,7 @@ def triton_quant_fp8_moe(


 @triton_quant_fp8_moe.register_fake
-def triton_quant_fp8_moe(
+def triton_quant_fp8_moe_fake(
     x: torch.Tensor,
     selected_experts: torch.Tensor,
     routing_weights: torch.Tensor,

@@ -136,9 +136,9 @@ def trtllm_quant_fp8_moe_fused(
         routing_weights: Routing weights (B*S, TOP_K)
         fc1_expert_weights: FC1 weights [E, 2*I, H] for gated_mlp, [E, I, H] for mlp
         fc2_expert_weights: FC2 weights [E, H, I]
-        fc1_act_scale: FC1 activation scale [E]
+        fc1_act_scale: FC1 activation scale (scalar)
         fc1_dequant_scale: FC1 dequant scale [E]
-        fc2_act_scale_reciprocal: FC2 activation scale reciprocal [E]
+        fc2_act_scale_reciprocal: FC2 activation scale reciprocal (scalar)
         fc2_dequant_scale: FC2 dequant scale [E]
         is_gated_mlp: True for gated_mlp, False for mlp
         act_fn: ActivationType.Silu for gated_mlp, ActivationType.Relu2 for mlp

@@ -153,24 +153,26 @@ def trtllm_quant_fp8_moe_fused(
     # Store original shape and flatten to 2D
     x_shape = x.shape
     x2d = x.view(-1, x_shape[-1])
-    # Quantize the input
-    x_q_fp8 = _quantize_fp8(x2d, fc1_act_scale[0])
-
-    # Scales are stored in float32
-    w1_input_scale = fc1_act_scale[0]
+    # Quantize the input using precomputed max scale
+    x_q_fp8 = _quantize_fp8(x2d, fc1_act_scale)

     # Prepare quant_scales for TensorRT-LLM (Cutlass) FP8 format:
     # [fc1_dequant_scale, fc2_act_scale_reciprocal, fc2_dequant_scale, gemm1_input_dequant_scale]
     # For gated MLP:
-    # These are precomputed in `fused_moe` transform
-    # - fc1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3)
-    # - fc2_act_scale_reciprocal: 1 / w2_input_scale
-    # - fc1_dequant_scale: w2_weight_scale * w2_input_scale
-    # - fc1_act_scale: w1_input_scale
-    quant_scales = [fc1_dequant_scale, fc2_act_scale_reciprocal, fc2_dequant_scale, w1_input_scale]
+    # These are precomputed in `fused_moe` transform:
+    # - fc1_dequant_scale: w1_weight_scale * max(w1_input_scale) [E]
+    # - fc2_act_scale_reciprocal: 1 / max(w2_input_scale) (scalar)
+    # - fc2_dequant_scale: w2_weight_scale * max(w2_input_scale) [E]
+    # - fc1_act_scale: max(w1_input_scale) (scalar)
+
+    assert fc1_dequant_scale.ndim == 1, "fc1_dequant_scale must be 1D"
+    assert fc2_dequant_scale.ndim == 1, "fc2_dequant_scale must be 1D"
+    quant_scales = [
+        fc1_dequant_scale,
+        fc2_act_scale_reciprocal,
+        fc2_dequant_scale,
+        fc1_act_scale,
+    ]

     # Ensure contiguous tensors
     selected_experts = selected_experts.int().contiguous()

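To make the quant_scales layout above concrete, here is a hedged sketch with made-up calibration values; it only mirrors the comment block, it is not the transform's actual code:

```python
import torch

E = 4                                    # hypothetical number of experts
w1_weight_scale = torch.rand(E) + 0.1    # per-expert FC1 weight scales
w2_weight_scale = torch.rand(E) + 0.1    # per-expert FC2 weight scales
w1_input_scale = torch.rand(E) + 0.5     # per-expert FC1 activation scales
w2_input_scale = torch.rand(E) + 0.5     # per-expert FC2 activation scales

# Mirrors the comment block above: everything is reduced with max() once, offline.
fc1_act_scale = w1_input_scale.max()                          # scalar
fc1_dequant_scale = w1_weight_scale * w1_input_scale.max()    # [E]
fc2_act_scale_reciprocal = 1.0 / w2_input_scale.max()         # scalar
fc2_dequant_scale = w2_weight_scale * w2_input_scale.max()    # [E]

quant_scales = [fc1_dequant_scale, fc2_act_scale_reciprocal, fc2_dequant_scale, fc1_act_scale]
```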
@@ -17,6 +17,7 @@ from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code, get_attr_by_name
 from ...utils.cuda_mem_tracker import cuda_memory_tracker
+from ...utils.logger import ad_logger
 from ...utils.module import get_submodule_of_param
 from ...utils.node_utils import bfs, extract_op_args, identify_regions_between_residuals, is_op
 from ..interface import (

@@ -1323,11 +1324,21 @@ def remove_original_experts(gm: GraphModule, weight_lists: List[List[Node]]) ->
             continue


-def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int:
+def _stack_fp8_moe_weights(
+    gm: GraphModule,
+    backend: Literal["auto", "trtllm", "triton"],
+    allow_different_input_scales: bool = False,
+) -> int:
     """
     Stack per-expert FP8 weights and scales by materializing stacked tensors as parameters.
     This is fast because we directly stack the tensor values (not graph nodes).
     Similar to _insert_fused_moe_ops but for quantized MoE.
+
+    Args:
+        gm: The GraphModule to transform.
+        backend: Backend to use ('auto', 'trtllm', or 'triton').
+        allow_different_input_scales: If False (default), assert that all experts have identical
+            input scales and fail if not. If True, allow different scales (use max for quantization).
     """

     def _register_parameter(gm: GraphModule, target, value):

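The docstring above fully determines the flag's behavior. A small standalone sketch of the same decision logic, under the stated assumption that a single max scale replaces the per-expert ones; the helper name is hypothetical:

```python
import torch


def pick_shared_fp8_input_scale(scales: torch.Tensor, allow_different: bool) -> torch.Tensor:
    """Hypothetical condensation of the documented behavior (not the transform itself)."""
    identical = bool(torch.all(scales == scales[0]).item())
    if not identical and not allow_different:
        raise AssertionError(
            "All input scales should have the same value. "
            "Set allow_different_input_scales=True to allow different scales (uses max)."
        )
    # With identical scales max() is a no-op; otherwise it is the conservative choice.
    return scales.max().reshape(1)


print(pick_shared_fp8_input_scale(torch.tensor([1.0, 1.0]), allow_different=False))  # tensor([1.])
print(pick_shared_fp8_input_scale(torch.tensor([0.5, 1.0]), allow_different=True))   # tensor([1.])
```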
@@ -1387,23 +1398,26 @@ def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "

         # For optimization reasons, we precompute a few additional arguments to the trtllm_quant_fp8_moe_fused op
         # to avoid computing them at runtime.
-        fc1_dequant = (w1_weight_scale_stacked * w1_input_scale_stacked[0]).squeeze()
-        fc2_act_scale_recip = (1.0 / w2_input_scale_stacked[0]).to(torch.float32)
-        fc2_dequant = (w2_weight_scale_stacked * w2_input_scale_stacked[0]).squeeze()
+        # We use max scale to handle different input scales per expert (if enabled).
+        fc1_act_scale = fc1_act_scale.max()
+        fc2_act_scale = w2_input_scale_stacked.max()
+        fc1_dequant = (w1_weight_scale_stacked * w1_input_scale_stacked.max()).squeeze()
+        fc2_act_scale_recip = (1.0 / fc2_act_scale).to(torch.float32)
+        fc2_dequant = (w2_weight_scale_stacked * fc2_act_scale).squeeze()

+        new_key_fc1_expert_weights = f"quant_moe_w3_w1_stacked_{fused_key_counter}"
+        new_key_fc2_expert_weights = f"quant_moe_w2_stacked_{fused_key_counter}"
+        new_key_fc1_act_scale = f"quant_moe_fc1_act_scale_{fused_key_counter}"
         new_key_fc1_dequant = f"quant_moe_fc1_dequant_stacked_{fused_key_counter}"
         new_key_fc2_act_scale_recip = f"quant_moe_fc2_act_scale_recip_stacked_{fused_key_counter}"
         new_key_fc2_dequant = f"quant_moe_fc2_dequant_stacked_{fused_key_counter}"
-        new_key_fc1_expert_weights = f"quant_moe_w3_w1_stacked_{fused_key_counter}"
-        new_key_fc2_expert_weights = f"quant_moe_w2_stacked_{fused_key_counter}"
-        new_key_fc1_act_scale = f"quant_moe_w3_w1_input_scale_stacked_{fused_key_counter}"

-        _register_parameter(gm, new_key_fc1_dequant, fc1_dequant)
-        _register_parameter(gm, new_key_fc2_act_scale_recip, fc2_act_scale_recip)
-        _register_parameter(gm, new_key_fc2_dequant, fc2_dequant)
         _register_parameter(gm, new_key_fc1_expert_weights, fc1_expert_weights)
         _register_parameter(gm, new_key_fc2_expert_weights, fc2_expert_weights)
         _register_parameter(gm, new_key_fc1_act_scale, fc1_act_scale)
+        _register_parameter(gm, new_key_fc1_dequant, fc1_dequant)
+        _register_parameter(gm, new_key_fc2_act_scale_recip, fc2_act_scale_recip)
+        _register_parameter(gm, new_key_fc2_dequant, fc2_dequant)

         with graph.inserting_before(node):
             args = (

@@ -1424,19 +1438,27 @@ def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "
         new_key_w1 = f"quant_moe_w1_stacked_{fused_key_counter}"
         new_key_w2 = f"quant_moe_w2_stacked_{fused_key_counter}"
         new_key_w3 = f"quant_moe_w3_stacked_{fused_key_counter}"
-        new_key_w1_input_scale = f"quant_moe_w1_input_scale_stacked_{fused_key_counter}"
-        new_key_w2_input_scale = f"quant_moe_w2_input_scale_stacked_{fused_key_counter}"
-        new_key_w3_input_scale = f"quant_moe_w3_input_scale_stacked_{fused_key_counter}"
         new_key_w1_weight_scale = f"quant_moe_w1_weight_scale_stacked_{fused_key_counter}"
         new_key_w2_weight_scale = f"quant_moe_w2_weight_scale_stacked_{fused_key_counter}"
         new_key_w3_weight_scale = f"quant_moe_w3_weight_scale_stacked_{fused_key_counter}"
+        w1_input_scale = w1_input_scale_stacked.max().reshape(1)
+        w2_input_scale = w2_input_scale_stacked.max().reshape(1)
+        # w3_input_scale: use max of w3 scales if present, else use empty tensor
+        w3_input_scale = (
+            w3_input_scale_stacked.max().reshape(1)
+            if w3_input_scale_stacked.numel() > 0
+            else torch.empty(1, device=w1_input_scale.device, dtype=w1_input_scale.dtype)
+        )
+        new_key_w1_input_scale = f"quant_moe_w1_input_scale_{fused_key_counter}"
+        new_key_w2_input_scale = f"quant_moe_w2_input_scale_{fused_key_counter}"
+        new_key_w3_input_scale = f"quant_moe_w3_input_scale_{fused_key_counter}"

         _register_parameter(gm, new_key_w1, w1_stacked)
         _register_parameter(gm, new_key_w2, w2_stacked)
         _register_parameter(gm, new_key_w3, w3_stacked)
-        _register_parameter(gm, new_key_w1_input_scale, w1_input_scale_stacked)
-        _register_parameter(gm, new_key_w2_input_scale, w2_input_scale_stacked)
-        _register_parameter(gm, new_key_w3_input_scale, w3_input_scale_stacked)
+        _register_parameter(gm, new_key_w1_input_scale, w1_input_scale)
+        _register_parameter(gm, new_key_w2_input_scale, w2_input_scale)
+        _register_parameter(gm, new_key_w3_input_scale, w3_input_scale)
         _register_parameter(gm, new_key_w1_weight_scale, w1_weight_scale_stacked)
         _register_parameter(gm, new_key_w2_weight_scale, w2_weight_scale_stacked)
         _register_parameter(gm, new_key_w3_weight_scale, w3_weight_scale_stacked)

@@ -1509,12 +1531,31 @@ def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "
                0, device=w1_input_scale_stacked.device, dtype=w1_input_scale_stacked.dtype
            )
        )
-        assert torch.all(w1_input_scale_stacked[0] == w1_input_scale_stacked), (
-            "All w1 scales should have the same value."
-        )
-        assert torch.all(w2_input_scale_stacked[0] == w2_input_scale_stacked), (
-            "All w2 scales should have the same value."
-        )
+        # Check if input scales are identical across experts
+        w1_input_scales_identical = torch.all(
+            w1_input_scale_stacked[0] == w1_input_scale_stacked
+        ).item()
+        w2_input_scales_identical = torch.all(
+            w2_input_scale_stacked[0] == w2_input_scale_stacked
+        ).item()
+
+        if not w1_input_scales_identical or not w2_input_scales_identical:
+            if not allow_different_input_scales:
+                # Fail with assertion
+                assert w1_input_scales_identical, (
+                    "All w1 input scales should have the same value. "
+                    "Set allow_different_input_scales=True to allow different scales (uses max)."
+                )
+                assert w2_input_scales_identical, (
+                    "All w2 input scales should have the same value. "
+                    "Set allow_different_input_scales=True to allow different scales (uses max)."
+                )
+            # Issue warning once and continue - max() will be used
+            ad_logger.warning_once(
+                "FP8 MoE: Input scales differ across experts. Using max(input_scale) for quantization. "
+                "This may impact accuracy if scales differ significantly.",
+                key="fp8_moe_different_input_scales",
+            )

         w1_weight_scale_stacked = _stack(w1_weight_scale, dim=0).to(torch.float32)
         w2_weight_scale_stacked = _stack(w2_weight_scale, dim=0).to(torch.float32)

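To quantify the accuracy caveat in the warning above, a small hypothetical example: with a shared max scale, the quantization step for an expert calibrated on a smaller amax grows proportionally.

```python
import torch

FP8_MAX = 448.0                      # e4m3 max representable magnitude
amax = torch.tensor([1.0, 4.0])      # hypothetical per-expert activation amax
scales = amax / FP8_MAX              # per-expert FP8 input scales
shared = scales.max()                # the single scale used when scales differ

# Quantization step size is proportional to the scale actually applied, so expert 0's
# grid becomes 4x coarser (~2 bits of lost resolution) while expert 1 is unaffected.
print(shared / scales)               # tensor([4., 1.])
```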
@@ -1602,6 +1643,14 @@ class FuseFP8MoeConfig(TransformConfig):
         default="auto",
         description="Backend to use for FP8 MoE computation ('auto', 'trtllm' or 'triton'. default: 'auto').",
     )
+    allow_different_input_scales: bool = Field(
+        default=False,
+        description=(
+            "If False (default), assert that all experts have identical input scales and fail if not. "
+            "If True, allow different per-expert input scales by using max(input_scale) for quantization. "
+            "This matches TRT-LLM manual backend behavior but may impact accuracy if scales differ significantly."
+        ),
+    )


 @TransformRegistry.register("fuse_fp8_moe")

@@ -1611,6 +1660,10 @@ class FuseFP8Moe(BaseTransform):
     This runs after weights are loaded, similar to FuseMoe for unquantized MoE.
     """

+    @classmethod
+    def get_config_class(cls) -> Type[TransformConfig]:
+        return FuseFP8MoeConfig
+
     def _apply(
         self,
         gm: GraphModule,

@@ -1619,7 +1672,11 @@ class FuseFP8Moe(BaseTransform):
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
         with cuda_memory_tracker():
-            fused_key_counter = _stack_fp8_moe_weights(gm, backend=self.config.backend)
+            fused_key_counter = _stack_fp8_moe_weights(
+                gm,
+                backend=self.config.backend,
+                allow_different_input_scales=self.config.allow_different_input_scales,
+            )

         info = TransformInfo(
             skipped=(fused_key_counter == 0),

@@ -1630,7 +1687,10 @@ class FuseFP8Moe(BaseTransform):
         return gm, info


-def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
+def _stack_nvfp4_moe_weights(
+    gm: GraphModule,
+    allow_different_input_scales: bool = False,
+) -> int:
     def _register_parameter(gm: GraphModule, target, value):
         gm.register_parameter(target, torch.nn.Parameter(value, requires_grad=False))

@@ -1659,6 +1719,7 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
             "w3_weight",
             "w1_input_scale",
             "w2_input_scale",
+            "w3_input_scale",
             "w1_weight_scale",
             "w2_weight_scale",
             "w3_weight_scale",

@@ -1680,20 +1741,79 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
         if is_gated_mlp:
             # For gated MLP, concatenate w1 and w3 as [w3, w1]
             fc1_expert_weights = torch.cat([w3_stacked, w1_stacked], dim=1).contiguous()
-            # Expect w3 input scale and alpha to be the same as w1
-            fc1_act_scale = w1_input_scale_stacked
-            fc1_alpha_stacked = w1_alpha_stacked
             fc1_weight_blockscale_fp8_stacked = torch.cat(
                 [w3_weight_blockscale_fp8_stacked, w1_weight_blockscale_fp8_stacked], dim=1
             ).contiguous()
+
+            # Check if all w1 and w3 input scales are identical across experts
+            all_scales_equal = (
+                torch.all(w1_input_scale_stacked == w1_input_scale_stacked[0])
+                and torch.all(w3_input_scale_stacked == w3_input_scale_stacked[0])
+                and torch.all(w1_input_scale_stacked == w3_input_scale_stacked)
+            )
+
+            if all_scales_equal:
+                # All scales are identical, no need for min() or alpha recomputation
+                fc1_act_scale = w1_input_scale_stacked[0]
+                fc1_alpha_stacked = w1_alpha_stacked
+            else:
+                if not allow_different_input_scales:
+                    assert False, (
+                        "FC1 input scales differ across experts (w1 and/or w3). "
+                        "Set allow_different_input_scales=True to allow different scales (uses min)."
+                    )
+                # Issue warning once and continue - min() will be used
+                ad_logger.warning_once(
+                    "NVFP4 MoE: Input scales differ across experts. Using min(input_scale) for "
+                    "FC1 quantization and recomputing alpha. This may impact accuracy if scales "
+                    "differ significantly.",
+                    key="nvfp4_moe_different_input_scales",
+                )
+                # Scales differ across experts - use global min scale and recompute alpha.
+                # Use min() because NVFP4 scales are in kernel format (2688/amax):
+                # smaller scale = larger amax = larger dynamic range.
+                fc1_act_scale = torch.minimum(
+                    w1_input_scale_stacked.min(), w3_input_scale_stacked.min()
+                )
+                # Recompute alpha using global input scale instead of per-expert input scale.
+                # Formula: new_alpha = old_alpha * per_expert_input_scale / global_input_scale
+                # This ensures alpha is consistent with the global fc1_act_scale used by the kernel.
+                fc1_alpha_stacked = w1_alpha_stacked * w1_input_scale_stacked / fc1_act_scale
         else:
             fc1_expert_weights = w1_stacked
-            fc1_act_scale = w1_input_scale_stacked
-            fc1_alpha_stacked = w1_alpha_stacked
             fc1_weight_blockscale_fp8_stacked = w1_weight_blockscale_fp8_stacked
+
+            # Check if all w1 input scales are identical across experts
+            all_scales_equal = torch.all(w1_input_scale_stacked == w1_input_scale_stacked[0])
+
+            if all_scales_equal:
+                # All scales are identical, no need for min() or alpha recomputation
+                fc1_act_scale = w1_input_scale_stacked[0]
+                fc1_alpha_stacked = w1_alpha_stacked
+            else:
+                if not allow_different_input_scales:
+                    assert False, (
+                        "FC1 input scales differ across experts (w1). "
+                        "Set allow_different_input_scales=True to allow different scales (uses min)."
+                    )
+                # Issue warning once and continue - min() will be used
+                ad_logger.warning_once(
+                    "NVFP4 MoE: Input scales differ across experts. Using min(input_scale) for "
+                    "FC1 quantization and recomputing alpha. This may impact accuracy if scales "
+                    "differ significantly.",
+                    key="nvfp4_moe_different_input_scales",
+                )
+                # Scales differ across experts - use global min scale and recompute alpha
+                fc1_act_scale = w1_input_scale_stacked.min()
+                fc1_alpha_stacked = w1_alpha_stacked * w1_input_scale_stacked / fc1_act_scale

         fc2_expert_weights = w2_stacked
+        # Keep fc2_act_scale per-expert (no global scale aggregation for fc2).
+        # The kernel supports per-expert scales for fc2, and intermediate activations
+        # naturally have different dynamic ranges per expert.
         fc2_act_scale = w2_input_scale_stacked
+        # No alpha recomputation needed since fc2 uses per-expert input scales.
+        fc2_alpha_stacked = w2_alpha_stacked
         fc2_weight_blockscale_fp8_stacked = w2_weight_blockscale_fp8_stacked

         new_key_fc1_expert_weights = f"nvfp4_moe_w3_w1_stacked_{fused_key_counter}"

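A hedged numeric sketch of the min()/alpha logic above, using the same 448 * 6 = 2688 kernel-format constant that the unit test at the end of this diff uses; the amax and weight_scale_2 values are made up:

```python
import torch

FP4_GLOBAL_SCALE_MAX = 448 * 6                 # 2688, same constant as the unit test below
amax = torch.tensor([1.0, 2.0])                # hypothetical per-expert activation amax
input_scale = FP4_GLOBAL_SCALE_MAX / amax      # kernel format: 2688 / amax -> [2688., 1344.]
global_scale = input_scale.min()               # min scale == largest amax == widest range

# Alpha was calibrated per expert as 1 / (input_scale * weight_scale_2); rescaling keeps
# alpha * input_scale invariant when the kernel switches to the shared global scale.
weight_scale_2 = torch.tensor([100.0, 100.0])  # hypothetical per-expert weight global scales
alpha = 1.0 / (input_scale * weight_scale_2)
new_alpha = alpha * input_scale / global_scale  # matches the fc1_alpha_stacked recomputation above
assert torch.allclose(new_alpha * global_scale, alpha * input_scale)
```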
@@ -1764,7 +1884,7 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
         _register_parameter(gm, new_key_fc1_act_scale, fc1_act_scale)
         _register_parameter(gm, new_key_fc2_act_scale, fc2_act_scale)
         _register_parameter(gm, new_key_fc1_alpha, fc1_alpha_stacked)
-        _register_parameter(gm, new_key_fc2_alpha, w2_alpha_stacked)
+        _register_parameter(gm, new_key_fc2_alpha, fc2_alpha_stacked)

         with graph.inserting_before(node):
             args = (

@@ -1804,6 +1924,7 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
             w3_list,
             w1_input_scale,
             w2_input_scale,
+            w3_input_scale,
             w1_weight_scale,
             w2_weight_scale,
             w3_weight_scale,

@@ -1822,6 +1943,12 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
         # Scales are buffers, not parameters
         w1_input_scale_stacked = _stack(w1_input_scale, dim=0)
         w2_input_scale_stacked = _stack(w2_input_scale, dim=0)
+        w3_input_scale_stacked = _stack(
+            w3_input_scale,
+            dim=0,
+            device=w1_input_scale_stacked.device,
+            dtype=w1_input_scale_stacked.dtype,
+        )

         # Use .view() not .to() to reinterpret bytes as float8, not value conversion
         w1_weight_blockscale_fp8_stacked = _stack(w1_weight_scale, dim=0).view(torch.float8_e4m3fn)

@@ -1856,6 +1983,21 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
     return fused_key_counter


+class FuseNVFP4MoeConfig(TransformConfig):
+    """Configuration for NVFP4 MoE fusion transform."""
+
+    allow_different_input_scales: bool = Field(
+        default=False,
+        description=(
+            "If False (default), assert that all experts have identical input scales and fail if not. "
+            "If True, allow different per-expert input scales by using min(input_scale) for quantization. "
+            "Note: NVFP4 uses min() (not max like FP8) because scales are in kernel format (2688/amax): "
+            "smaller scale = larger amax = larger dynamic range. "
+            "This may impact accuracy if scales differ significantly."
+        ),
+    )
+
+
 @TransformRegistry.register("fuse_nvfp4_moe")
 class FuseNVFP4Moe(BaseTransform):
     """

@@ -1863,6 +2005,10 @@ class FuseNVFP4Moe(BaseTransform):
     This runs after weights are loaded, similar to FuseFP8Moe.
     """

+    @classmethod
+    def get_config_class(cls) -> Type[TransformConfig]:
+        return FuseNVFP4MoeConfig
+
     def _apply(
         self,
         gm: GraphModule,

@@ -1871,7 +2017,10 @@ class FuseNVFP4Moe(BaseTransform):
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
         with cuda_memory_tracker():
-            fused_key_counter = _stack_nvfp4_moe_weights(gm)
+            fused_key_counter = _stack_nvfp4_moe_weights(
+                gm,
+                allow_different_input_scales=self.config.allow_different_input_scales,
+            )

         info = TransformInfo(
             skipped=(fused_key_counter == 0),

@@ -299,15 +299,16 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):

     # Input scales (tensor-wise, replicated per expert for interface compatibility)
     x_scale = x.abs().max().item() / FP8_MAX
-    w1_input_scale_tensor = torch.full((E,), x_scale, device=device, dtype=torch.float32)
+    # Input scales: precomputed max values with shape [1] (new API)
+    w1_input_scale_max = torch.tensor([x_scale], device=device, dtype=torch.float32)

     # Compute intermediate activation scale by simulating first GEMM + ReLU^2
     # This ensures w2_input_scale matches the actual activation magnitude
     with torch.no_grad():
         # Simulate the first GEMM: quantize input, do FP8 matmul, apply ReLU^2
-        x_q = (x / w1_input_scale_tensor[0]).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
+        x_q = (x / w1_input_scale_max.item()).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
         # Dequantize and compute output for a sample
-        x_dq = x_q[:8].to(torch.float32) * w1_input_scale_tensor[0].item()
+        x_dq = x_q[:8].to(torch.float32) * w1_input_scale_max.item()
         w1_dq = w1_fp8_stacked[0].to(torch.float32) * w1_weight_scale[0].item()
         sample_out = torch.nn.functional.linear(x_dq.to(dtype), w1_dq.to(dtype))
         sample_act = torch.square(torch.nn.functional.relu(sample_out))

@@ -315,11 +316,11 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
     # Ensure scale is not too small
     intermediate_scale = max(intermediate_scale, 1e-6)

-    w2_input_scale_tensor = torch.full((E,), intermediate_scale, device=device, dtype=torch.float32)
+    w2_input_scale_max = torch.tensor([intermediate_scale], device=device, dtype=torch.float32)

     # Convert scales to lists for torch_quant_fp8_moe reference
-    w1_input_scale_list = [w1_input_scale_tensor[0].clone() for _ in range(E)]
-    w2_input_scale_list = [w2_input_scale_tensor[0].clone() for _ in range(E)]
+    w1_input_scale_list = [w1_input_scale_max[0].clone() for _ in range(E)]
+    w2_input_scale_list = [w2_input_scale_max[0].clone() for _ in range(E)]
     w1_weight_scale_list = [w1_weight_scale[e].clone() for e in range(E)]
     w2_weight_scale_list = [w2_weight_scale[e].clone() for e in range(E)]

@@ -327,7 +328,7 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
     w3_fp8_list = [torch.empty((1, 1), device=device, dtype=torch.float8_e4m3fn) for _ in range(E)]
     w3_fp8_stacked = torch.stack(w3_fp8_list).contiguous()
     w3_input_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)]
-    w3_input_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32)
+    w3_input_scale_max = torch.ones((1,), device=device, dtype=torch.float32)
     w3_weight_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)]
     w3_weight_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32)

@@ -352,7 +353,7 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
     # Create equal routing weights
     routing_weights = torch.ones((M, top_k), device=device, dtype=torch.float32) / top_k

-    # Triton FP8 quantized MoE (uses stacked tensors)
+    # Triton FP8 quantized MoE (uses stacked tensors with precomputed max input scales)
     out_triton = torch.ops.auto_deploy.triton_quant_fp8_moe(
         x,
         selected_experts.to(torch.int32),

@@ -360,9 +361,9 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
         w1_fp8_stacked,
         w2_fp8_stacked,
         w3_fp8_stacked,
-        w1_input_scale_tensor,
-        w2_input_scale_tensor,
-        w3_input_scale_tensor,
+        w1_input_scale_max,  # [1] precomputed max
+        w2_input_scale_max,  # [1] precomputed max
+        w3_input_scale_max,  # [1] precomputed max (unused)
         w1_weight_scale,
         w2_weight_scale,
         w3_weight_scale_tensor,

@@ -414,7 +414,7 @@ def test_trtllm_fused_moe_fp8(
         routing_weights,
         fc1_expert_weights=w31_weight.contiguous() if is_gated_mlp else w1_weight.contiguous(),
         fc2_expert_weights=w2_weight.contiguous(),
-        fc1_act_scale=hidden_states_scale.unsqueeze(0),
+        fc1_act_scale=hidden_states_scale,
         fc1_dequant_scale=gemm1_dequant,
         fc2_act_scale_reciprocal=gemm2_act_quant,
         fc2_dequant_scale=gemm2_dequant,

@@ -433,7 +433,7 @@ def test_trtllm_fused_moe_fp8(
         w1_weight,
         w2_weight,
         w3_weight,
-        hidden_states_scale.unsqueeze(0),
+        hidden_states_scale.reshape(1),
         w2_input_scale,
         w3_input_scale,
         w1_scales,

@@ -1,11 +1,14 @@
 import pytest
 import torch
+import torch.fx as fx
+import torch.nn as nn
 import torch.nn.functional as F
 from _graph_test_helpers import run_test_transformed_gm
 from _model_test_utils import MoEOpModel
 from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available
+from utils.util import skip_pre_hopper

 import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
 from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
 from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op

@@ -650,3 +653,424 @@ def test_nvfp4_moe_fusion(is_gated_mlp, hidden_size, intermediate_size):

     assert not torch.isnan(fused_output).any(), "Fused output contains NaN"
     assert not torch.isinf(fused_output).any(), "Fused output contains Inf"
+
+
+class FP8MoEModuleForInputScaleTest(nn.Module):
+    """Module wrapping torch_quant_fp8_moe for testing FP8 MoE input scale handling."""
+
+    def __init__(
+        self,
+        num_experts,
+        w1_weight,
+        w2_weight,
+        w1_input_scale,
+        w2_input_scale,
+        w1_weight_scale,
+        w2_weight_scale,
+        is_gated_mlp,
+        act_fn,
+    ):
+        super().__init__()
+        self.num_experts = num_experts
+        self.is_gated_mlp = is_gated_mlp
+        self.act_fn = act_fn
+
+        for i in range(num_experts):
+            self.register_buffer(f"w1_{i}", w1_weight[i])
+            self.register_buffer(f"w2_{i}", w2_weight[i])
+            self.register_buffer(f"w1_iscale_{i}", w1_input_scale[i])
+            self.register_buffer(f"w2_iscale_{i}", w2_input_scale[i])
+            self.register_buffer(f"w1_wscale_{i}", w1_weight_scale[i])
+            self.register_buffer(f"w2_wscale_{i}", w2_weight_scale[i])
+
+    def forward(self, x, selected_experts, routing_weights):
+        return torch.ops.auto_deploy.torch_quant_fp8_moe(
+            x,
+            selected_experts,
+            routing_weights,
+            [getattr(self, f"w1_{i}") for i in range(self.num_experts)],
+            [getattr(self, f"w2_{i}") for i in range(self.num_experts)],
+            [],  # w3 is empty for non-gated MLP
+            [getattr(self, f"w1_iscale_{i}") for i in range(self.num_experts)],
+            [getattr(self, f"w2_iscale_{i}") for i in range(self.num_experts)],
+            [],  # w3 input scale is empty for non-gated MLP
+            [getattr(self, f"w1_wscale_{i}") for i in range(self.num_experts)],
+            [getattr(self, f"w2_wscale_{i}") for i in range(self.num_experts)],
+            [],  # w3 weight scale is empty for non-gated MLP
+            is_gated_mlp=self.is_gated_mlp,
+            act_fn=self.act_fn,
+        )
+
+
+@skip_pre_hopper
+@pytest.mark.parametrize("backend", ["trtllm", "triton"])
+@pytest.mark.parametrize("allow_different_input_scales", [False, True])
+@pytest.mark.parametrize("scales_identical", [True, False])
+@pytest.mark.skipif(
+    not fp8_compatible() or not trtllm_ops_available(),
+    reason="Requires fp8 and trtllm support",
+)
+def test_fp8_moe_different_input_scales(backend, allow_different_input_scales, scales_identical):
+    """
+    Test FP8 MoE behavior with different/identical input scales via InferenceOptimizer.
+
+    Tests the allow_different_input_scales config option for both trtllm and triton backends:
+    - When scales_identical=True: should always work
+    - When scales_identical=False and allow_different_input_scales=False: should fail with assertion
+    - When scales_identical=False and allow_different_input_scales=True: should work (uses max)
+    """
+    from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe import _stack_fp8_moe_weights
+
+    torch.manual_seed(0)
+
+    batch_size, num_experts, top_k = 4, 2, 2
+    hidden_size, intermediate_size = 128, 128
+    # Use non-gated MLP (Relu2) because triton backend only supports non-gated MLP
+    is_gated_mlp = False
+    act_fn = ActivationType.Relu2
+
+    # Generate test data
+    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda") * 0.5
+
+    # Simple routing: distribute tokens across experts
+    selected_experts = torch.zeros((batch_size, top_k), dtype=torch.int64, device="cuda")
+    for i in range(batch_size):
+        selected_experts[i, 0] = i % num_experts
+        selected_experts[i, 1] = (i + 1) % num_experts
+    routing_weights = torch.ones((batch_size, top_k), device="cuda", dtype=torch.float32) / top_k
+
+    # Create per-expert weights and scales (non-gated MLP, no w3)
+    w1_weight, w2_weight = [], []
+    w1_input_scale, w2_input_scale = [], []
+    w1_weight_scale, w2_weight_scale = [], []
+
+    for expert_id in range(num_experts):
+        # Random FP8 weights
+        w1_fp8 = torch.randn(intermediate_size, hidden_size, device="cuda").to(torch.float8_e4m3fn)
+        w2_fp8 = torch.randn(hidden_size, intermediate_size, device="cuda").to(torch.float8_e4m3fn)
+        w1_weight.append(w1_fp8)
+        w2_weight.append(w2_fp8)
+
+        # Random weight scales (shape [1])
+        w1_weight_scale.append(torch.tensor([0.1], dtype=torch.float32, device="cuda"))
+        w2_weight_scale.append(torch.tensor([0.1], dtype=torch.float32, device="cuda"))
+
+        # Input scales: either identical or different per expert (shape [1])
+        if scales_identical:
+            inp_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
+        else:
+            # Different input scales per expert - big variance to test max() behavior
+            inp_scale = torch.tensor([0.5 + 0.5 * expert_id], dtype=torch.float32, device="cuda")
+
+        w1_input_scale.append(inp_scale)
+        w2_input_scale.append(inp_scale)
+
+    # Create a module with the FP8 MoE op
+    module = FP8MoEModuleForInputScaleTest(
+        num_experts,
+        w1_weight,
+        w2_weight,
+        w1_input_scale,
+        w2_input_scale,
+        w1_weight_scale,
+        w2_weight_scale,
+        is_gated_mlp,
+        act_fn,
+    ).cuda()
+    gm = fx.symbolic_trace(module)
+
+    # Compute reference output from original graph before transformation
+    with torch.inference_mode():
+        ref_output = gm(x, selected_experts, routing_weights)
+
+    # Expected behavior:
+    # - scales_identical=True: always works
+    # - scales_identical=False, allow_different_input_scales=False: assertion error
+    # - scales_identical=False, allow_different_input_scales=True: works with max()
+
+    if not scales_identical and not allow_different_input_scales:
+        # Should fail with assertion error
+        with pytest.raises(AssertionError, match="input scales should have the same value"):
+            _stack_fp8_moe_weights(
+                gm, backend=backend, allow_different_input_scales=allow_different_input_scales
+            )
+    else:
+        # Should succeed
+        num_transformed = _stack_fp8_moe_weights(
+            gm, backend=backend, allow_different_input_scales=allow_different_input_scales
+        )
+        gm.recompile()
+
+        assert num_transformed == 1, f"Expected 1 transform, got {num_transformed}"
+
+        # Verify that max() is used when scales differ
+        if not scales_identical:
+            expected_max_w1_input_scale = torch.stack(w1_input_scale).max()
+            # Attribute name differs between backends
+            if backend == "trtllm":
+                actual_w1_input = getattr(gm, "quant_moe_fc1_act_scale_0")
+            else:  # triton
+                actual_w1_input = getattr(gm, "quant_moe_w1_input_scale_0").squeeze()
+
+            assert torch.allclose(actual_w1_input, expected_max_w1_input_scale), (
+                f"w1 input scale max mismatch. Got {actual_w1_input}, expected {expected_max_w1_input_scale}"
+            )
+
+        # Run the transformed graph and compare to reference output
+        with torch.inference_mode():
+            output = gm(x, selected_experts, routing_weights)
+        assert output.shape == ref_output.shape, (
+            f"Output shape mismatch: {output.shape} vs {ref_output.shape}"
+        )
+
+        assert torch.allclose(output, ref_output, rtol=0.05, atol=0.05), (
+            f"Output mismatch. rtol=0.05, atol=0.05. Max diff: {(output - ref_output).abs().max()}"
+        )
+
+
+class NVFP4MoEModuleForInputScaleTest(nn.Module):
+    """Module wrapping torch_quant_nvfp4_moe for testing NVFP4 MoE input scale handling."""
+
+    def __init__(
+        self,
+        num_experts,
+        w1_weight,
+        w2_weight,
+        w3_weight,
+        w1_input_scale,
+        w2_input_scale,
+        w3_input_scale,
+        w1_weight_scale,
+        w2_weight_scale,
+        w3_weight_scale,
+        w1_alpha,
+        w2_alpha,
+        w3_alpha,
+        is_gated_mlp,
+        act_fn,
+    ):
+        super().__init__()
+        self.num_experts = num_experts
+        self.is_gated_mlp = is_gated_mlp
+        self.act_fn = act_fn
+
+        for i in range(num_experts):
+            self.register_buffer(f"w1_{i}", w1_weight[i])
+            self.register_buffer(f"w2_{i}", w2_weight[i])
+            self.register_buffer(f"w1_iscale_{i}", w1_input_scale[i])
+            self.register_buffer(f"w2_iscale_{i}", w2_input_scale[i])
+            self.register_buffer(f"w1_wscale_{i}", w1_weight_scale[i])
+            self.register_buffer(f"w2_wscale_{i}", w2_weight_scale[i])
+            self.register_buffer(f"w1_alpha_{i}", w1_alpha[i])
+            self.register_buffer(f"w2_alpha_{i}", w2_alpha[i])
+            if is_gated_mlp:
+                self.register_buffer(f"w3_{i}", w3_weight[i])
+                self.register_buffer(f"w3_iscale_{i}", w3_input_scale[i])
+                self.register_buffer(f"w3_wscale_{i}", w3_weight_scale[i])
+                self.register_buffer(f"w3_alpha_{i}", w3_alpha[i])
+
+    def forward(self, x, selected_experts, routing_weights):
+        return torch.ops.auto_deploy.torch_quant_nvfp4_moe(
+            x,
+            selected_experts,
+            routing_weights,
+            [getattr(self, f"w1_{i}") for i in range(self.num_experts)],
+            [getattr(self, f"w2_{i}") for i in range(self.num_experts)],
+            [getattr(self, f"w3_{i}") for i in range(self.num_experts)]
+            if self.is_gated_mlp
+            else [],
+            [getattr(self, f"w1_iscale_{i}") for i in range(self.num_experts)],
+            [getattr(self, f"w2_iscale_{i}") for i in range(self.num_experts)],
+            [getattr(self, f"w3_iscale_{i}") for i in range(self.num_experts)]
+            if self.is_gated_mlp
+            else [],
+            [getattr(self, f"w1_wscale_{i}") for i in range(self.num_experts)],
+            [getattr(self, f"w2_wscale_{i}") for i in range(self.num_experts)],
+            [getattr(self, f"w3_wscale_{i}") for i in range(self.num_experts)]
+            if self.is_gated_mlp
+            else [],
+            [getattr(self, f"w1_alpha_{i}") for i in range(self.num_experts)],
+            [getattr(self, f"w2_alpha_{i}") for i in range(self.num_experts)],
+            [getattr(self, f"w3_alpha_{i}") for i in range(self.num_experts)]
+            if self.is_gated_mlp
+            else [],
+            is_gated_mlp=self.is_gated_mlp,
+            act_fn=self.act_fn,
+        )
+
+
+@skip_pre_hopper
+@pytest.mark.parametrize("allow_different_input_scales", [False, True])
+@pytest.mark.parametrize("scales_identical", [True, False])
+@pytest.mark.parametrize("is_gated_mlp", [False, True])
+@pytest.mark.skipif(
+    not trtllm_ops_available(),
+    reason="Requires trtllm ops",
+)
+def test_nvfp4_moe_different_input_scales(
+    allow_different_input_scales, scales_identical, is_gated_mlp
+):
+    """
+    Test NVFP4 MoE behavior with different/identical input scales via _stack_nvfp4_moe_weights.
+
+    Tests the allow_different_input_scales config option for both gated and non-gated MLP:
+    - When scales_identical=True: should always work
+    - When scales_identical=False and allow_different_input_scales=False: should fail with assertion
+    - When scales_identical=False and allow_different_input_scales=True: should work (uses min)
+
+    Note: NVFP4 uses min() (not max() like FP8) because scales are in kernel format (2688/amax):
+    smaller scale = larger amax = larger dynamic range.
+
+    This test uses mock tensors to test the transform logic without running the actual NVFP4 kernel.
+    """
+    from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe import _stack_nvfp4_moe_weights
+
+    torch.manual_seed(0)
+
+    num_experts = 2
+    hidden_size, intermediate_size = 128, 128
+    act_fn = ActivationType.Silu if is_gated_mlp else ActivationType.Relu2
+
+    # NVFP4 constants
+    FP4_GLOBAL_SCALE_MAX = 448 * 6  # 2688
+    NVFP4_BLOCK_SIZE = 16
+
+    # Create per-expert mock weights and scales
+    # We use mock tensors with correct shapes to test the transform logic
+    # without needing actual FP4 quantization (which requires SM>=100)
+    w1_weight, w2_weight, w3_weight = [], [], []
+    w1_input_scale, w2_input_scale, w3_input_scale = [], [], []
+    w1_weight_scale, w2_weight_scale, w3_weight_scale = [], [], []
+    w1_alpha, w2_alpha, w3_alpha = [], [], []
+
+    for expert_id in range(num_experts):
+        # Mock FP4 weights (uint8 packed, half the size in last dim)
+        w1_fp4 = torch.randint(
+            0, 255, (intermediate_size, hidden_size // 2), dtype=torch.uint8, device="cuda"
+        )
+        w2_fp4 = torch.randint(
+            0, 255, (hidden_size, intermediate_size // 2), dtype=torch.uint8, device="cuda"
+        )
+
+        # Mock block scales (2D): shape (m, n // NVFP4_BLOCK_SIZE)
+        # With 128x128 dims, no padding needed (already multiples of 128 and 8)
+        w1_block_scale = torch.randn(
+            intermediate_size, hidden_size // NVFP4_BLOCK_SIZE, dtype=torch.float32, device="cuda"
+        ).to(torch.float8_e4m3fn)
+        w2_block_scale = torch.randn(
+            hidden_size, intermediate_size // NVFP4_BLOCK_SIZE, dtype=torch.float32, device="cuda"
+        ).to(torch.float8_e4m3fn)
+
+        w1_weight.append(w1_fp4)
+        w2_weight.append(w2_fp4)
+        w1_weight_scale.append(w1_block_scale)
+        w2_weight_scale.append(w2_block_scale)
+
+        # Input scales: either identical or different per expert
+        # For NVFP4, scale = FP4_GLOBAL_SCALE_MAX / amax
+        if scales_identical:
+            # Same amax for all experts -> same input scale
+            inp_scale = torch.tensor(FP4_GLOBAL_SCALE_MAX / 1.0, dtype=torch.float32, device="cuda")
+        else:
+            # Different amax per expert -> different input scales
+            # Expert 0: amax=1.0, scale=2688/1.0=2688
+            # Expert 1: amax=2.0, scale=2688/2.0=1344
+            amax = 1.0 + expert_id
+            inp_scale = torch.tensor(
+                FP4_GLOBAL_SCALE_MAX / amax, dtype=torch.float32, device="cuda"
+            )

+        w1_input_scale.append(inp_scale)
+        w2_input_scale.append(inp_scale)
+
+        # Mock weight_scale_2 (global scale for this expert's weights)
+        w1_scale_2 = torch.tensor(100.0, dtype=torch.float32, device="cuda")
+        w2_scale_2 = torch.tensor(100.0, dtype=torch.float32, device="cuda")
+
+        # Alpha = 1 / (input_scale * weight_scale_2)
+        w1_alpha.append((1.0 / (inp_scale * w1_scale_2)).to(torch.float32))
+        w2_alpha.append((1.0 / (inp_scale * w2_scale_2)).to(torch.float32))
+
+        # For gated MLP, create w3 weights/scales/alpha (same shape as w1)
+        if is_gated_mlp:
+            w3_fp4 = torch.randint(
+                0, 255, (intermediate_size, hidden_size // 2), dtype=torch.uint8, device="cuda"
+            )
+            w3_block_scale = torch.randn(
+                intermediate_size,
+                hidden_size // NVFP4_BLOCK_SIZE,
+                dtype=torch.float32,
+                device="cuda",
+            ).to(torch.float8_e4m3fn)
+            w3_weight.append(w3_fp4)
+            w3_weight_scale.append(w3_block_scale)
+            # w3 uses the same input scale as w1 (they process the same input)
+            w3_input_scale.append(inp_scale)
+            w3_scale_2 = torch.tensor(100.0, dtype=torch.float32, device="cuda")
+            w3_alpha.append((1.0 / (inp_scale * w3_scale_2)).to(torch.float32))
+
+    # Create a module with the NVFP4 MoE op
+    module = NVFP4MoEModuleForInputScaleTest(
+        num_experts,
+        w1_weight,
+        w2_weight,
+        w3_weight,
+        w1_input_scale,
+        w2_input_scale,
+        w3_input_scale,
+        w1_weight_scale,
+        w2_weight_scale,
+        w3_weight_scale,
+        w1_alpha,
+        w2_alpha,
+        w3_alpha,
+        is_gated_mlp,
+        act_fn,
+    ).cuda()
+    gm = fx.symbolic_trace(module)
+
+    # Expected behavior:
+    # - scales_identical=True: always works
+    # - scales_identical=False, allow_different_input_scales=False: assertion error
+    # - scales_identical=False, allow_different_input_scales=True: works with min()
+
+    if not scales_identical and not allow_different_input_scales:
+        # Should fail with assertion error
+        with pytest.raises(AssertionError, match="FC1 input scales differ"):
+            _stack_nvfp4_moe_weights(gm, allow_different_input_scales=allow_different_input_scales)
+    else:
+        # Should succeed
+        num_transformed = _stack_nvfp4_moe_weights(
+            gm, allow_different_input_scales=allow_different_input_scales
+        )
+        gm.recompile()
+
+        assert num_transformed == 1, f"Expected 1 transform, got {num_transformed}"
+
+        # Verify that min() is used when scales differ
+        if not scales_identical:
+            # For gated MLP, global scale = min(w1_scales.min(), w3_scales.min())
+            # For non-gated MLP, global scale = w1_scales.min()
+            if is_gated_mlp:
+                expected_min_input_scale = torch.minimum(
+                    torch.stack(w1_input_scale).min(),
+                    torch.stack(w3_input_scale).min(),
+                )
+            else:
+                expected_min_input_scale = torch.stack(w1_input_scale).min()
+
+            actual_input_scale = getattr(gm, "nvfp4_moe_w3_w1_input_scale_stacked_0")
+
+            assert torch.allclose(actual_input_scale, expected_min_input_scale), (
+                f"FC1 input scale min mismatch. Got {actual_input_scale}, expected {expected_min_input_scale}"
+            )
+
+            # Verify alpha was recomputed correctly
+            # new_alpha = old_alpha * per_expert_input_scale / global_input_scale
+            expected_alpha = (
+                torch.stack(w1_alpha) * torch.stack(w1_input_scale) / expected_min_input_scale
+            )
+            actual_alpha = getattr(gm, "nvfp4_moe_w1_alpha_stacked_0")
+            assert torch.allclose(actual_alpha, expected_alpha, rtol=1e-5, atol=1e-5), (
+                f"Alpha recomputation mismatch. Got {actual_alpha}, expected {expected_alpha}"
+            )