Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-04 10:11:47 +08:00
[None][fix] AutoDeploy: Fix the nvfp4 fused_moe (#10727)
Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>
This commit is contained in: parent 0cfd08745c, commit b6acd96616
@@ -130,7 +130,7 @@ transforms:
     backend: trtllm
   fuse_nvfp4_moe:
     stage: post_load_fusion
-    enabled: false
+    enabled: true
   fuse_allreduce_residual_rmsnorm:
     stage: post_load_fusion
     # TODO (lucaslie): add backend selection as part of configurable inference optimizers
@@ -271,13 +271,9 @@ def trtllm_quant_nvfp4_moe_fused(
     # Convert the inter_size from number of uint8 elements to number of FP4 elements.
     inter_size *= FP4_PER_UINT8

-    # Validate shapes and padding requirements as defined by the cutlass kernel.
+    # Validate block scale tensors are 3D (padding requirements handled below)
     assert fc1_weight_blockscale_fp8.ndim == 3, "fc1_weight_blockscale_fp8 must be 3D"
     assert fc2_weight_blockscale_fp8.ndim == 3, "fc2_weight_blockscale_fp8 must be 3D"
-    assert fc1_weight_blockscale_fp8.size(1) % TRTLLM_NVFP4_ROW_SIZE == 0
-    assert fc2_weight_blockscale_fp8.size(1) % TRTLLM_NVFP4_ROW_SIZE == 0
-    assert fc1_weight_blockscale_fp8.size(2) % TRTLLM_NVFP4_COLUMN_SIZE == 0
-    assert fc2_weight_blockscale_fp8.size(2) % TRTLLM_NVFP4_COLUMN_SIZE == 0

     _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
     act_fn = ActivationType.Swiglu if act_fn == ActivationType.Silu else act_fn
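A note on the `inter_size *= FP4_PER_UINT8` line kept above: NVFP4 weights are stored two 4-bit (E2M1) values per uint8 byte, so a packed tensor of width W holds 2*W FP4 elements. A minimal sketch of that bookkeeping (the packing factor of 2 is standard; the helper name below is made up for illustration):

# Hypothetical helper illustrating the packed-uint8 -> FP4 element-count conversion.
FP4_PER_UINT8 = 2  # two 4-bit E2M1 values fit in each uint8 byte

def unpacked_fp4_width(packed_uint8_width: int) -> int:
    """Number of FP4 elements represented by one row of packed uint8 bytes."""
    return packed_uint8_width * FP4_PER_UINT8

assert unpacked_fp4_width(64) == 128  # a [*, 64] uint8 tensor holds 128 FP4 values per row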
@@ -292,7 +288,7 @@ def trtllm_quant_nvfp4_moe_fused(
         input_blockscale = None
     output_dtype = x.dtype

-    # Pad inter_size to be divisible by 128
+    # Pad inter_size to be divisible by TRTLLM_NVFP4_ROW_SIZE
     inter_size_padded = math.ceil(inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE
     fc1_inter_size_padded = (
         math.ceil(fc1_inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE
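The `math.ceil(x / N) * N` pattern above rounds a dimension up to the next kernel-aligned multiple. A standalone sketch, assuming TRTLLM_NVFP4_ROW_SIZE is 128 (the value implied by the old comment; the real constant comes from the module):

import math

TRTLLM_NVFP4_ROW_SIZE = 128  # assumed here purely for illustration

def pad_to_multiple(size: int, multiple: int) -> int:
    """Round size up to the next multiple; already-aligned sizes are unchanged."""
    return math.ceil(size / multiple) * multiple

assert pad_to_multiple(1856, TRTLLM_NVFP4_ROW_SIZE) == 1920  # Nemotron intermediate size
assert pad_to_multiple(128, TRTLLM_NVFP4_ROW_SIZE) == 128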
@@ -305,18 +301,30 @@ def trtllm_quant_nvfp4_moe_fused(
         not is_gated_mlp and inter_size_padded != inter_size
     )
     hidden_size_needs_padding = hidden_size % TRTLLM_NVFP4_COLUMN_SIZE != 0

+    hidden_blocks_padded = hidden_size_padded // NVFP4_BLOCK_SIZE
+    inter_blocks_padded = inter_size_padded // NVFP4_BLOCK_SIZE
+
     if inter_size_needs_padding or hidden_size_needs_padding:
-        assert False, "See https://github.com/NVIDIA/TensorRT-LLM/issues/10331"
-        # fc1_expert_weights_fp4: [E, I, H] or [E, 2*I, H]
+        # Pad fc1_expert_weights_fp4: [E, I, H/2] or [E, 2*I, H/2]
         fc1_padded = fc1_expert_weights_fp4.new_zeros(
             fc1_expert_weights_fp4.size(0),
             fc1_inter_size_padded,
             hidden_size_padded // FP4_PER_UINT8,
         )
-        fc1_padded[:, :fc1_inter_size, :] = fc1_expert_weights_fp4
+        fc1_padded[:, :fc1_inter_size, : hidden_size // FP4_PER_UINT8] = fc1_expert_weights_fp4
         fc1_expert_weights_fp4 = fc1_padded

-        # fc2_expert_weights_fp4: [E, H, I]
+        # Block scales may already be padded, so check actual size
+        fc1_bs_size1 = fc1_weight_blockscale_fp8.size(1)
+        fc1_bs_size2 = fc1_weight_blockscale_fp8.size(2)
+        if fc1_bs_size1 < fc1_inter_size_padded or fc1_bs_size2 < hidden_blocks_padded:
+            fc1_blockscale_fp8_padded = fc1_weight_blockscale_fp8.new_zeros(
+                n_experts, fc1_inter_size_padded, hidden_blocks_padded
+            )
+            fc1_blockscale_fp8_padded[:, :fc1_bs_size1, :fc1_bs_size2] = fc1_weight_blockscale_fp8
+            fc1_weight_blockscale_fp8 = fc1_blockscale_fp8_padded
+
         fc2_padded = fc2_expert_weights_fp4.new_zeros(
             n_experts, hidden_size_padded, inter_size_padded // FP4_PER_UINT8
         )
@@ -325,16 +333,20 @@ def trtllm_quant_nvfp4_moe_fused(
                 f"inter_size {inter_size} must be divisible by {NVFP4_BLOCK_SIZE}"
             )

-        fc2_padded[:, :, : inter_size // FP4_PER_UINT8] = fc2_expert_weights_fp4
+        fc2_padded[:, :hidden_size, : inter_size // FP4_PER_UINT8] = fc2_expert_weights_fp4
         fc2_expert_weights_fp4 = fc2_padded

-        fc2_blockscale_fp8_padded = fc2_weight_blockscale_fp8.new_zeros(
-            n_experts, hidden_size_padded, inter_size_padded // NVFP4_BLOCK_SIZE
-        )
-        fc2_blockscale_fp8_padded[:, :, : inter_size // NVFP4_BLOCK_SIZE] = (
-            fc2_weight_blockscale_fp8
-        )
-        fc2_weight_blockscale_fp8 = fc2_blockscale_fp8_padded
+        # Pad fc2_weight_blockscale_fp8: [E, H, I/16]
+        # Block scales may already be padded, so check actual size
+        fc2_bs_size1 = fc2_weight_blockscale_fp8.size(1)
+        fc2_bs_size2 = fc2_weight_blockscale_fp8.size(2)
+        if fc2_bs_size1 < hidden_size_padded or fc2_bs_size2 < inter_blocks_padded:
+            fc2_blockscale_fp8_padded = fc2_weight_blockscale_fp8.new_zeros(
+                n_experts, hidden_size_padded, inter_blocks_padded
+            )
+            fc2_blockscale_fp8_padded[:, :fc2_bs_size1, :fc2_bs_size2] = fc2_weight_blockscale_fp8
+            fc2_weight_blockscale_fp8 = fc2_blockscale_fp8_padded
+        # else: block scales already have correct padded size, use as-is

     # quant_scales is described by this code:
     # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015
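The padding blocks above all follow the same pattern: allocate a zero tensor with the kernel-aligned shape via new_zeros, then copy the real data into the leading slice. A minimal, self-contained sketch of that pattern (shapes are illustrative, not the real model dimensions):

import torch

E, I, H_packed = 2, 160, 64   # experts, intermediate size, packed hidden bytes (illustrative)
I_padded = 192                # e.g. rounded up to the kernel row size

w = torch.randint(0, 255, (E, I, H_packed), dtype=torch.uint8)
w_padded = w.new_zeros(E, I_padded, H_packed)  # same dtype/device as w, zero-filled
w_padded[:, :I, :] = w                         # copy real data; extra rows stay zero

assert torch.equal(w_padded[:, :I, :], w)
assert int(w_padded[:, I:, :].sum()) == 0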
@@ -83,7 +83,7 @@ class QuantConfigReaderRegistry:

 @QuantConfigReaderRegistry.register("modelopt")
 class ModelOPTQuantConfigReader(QuantConfigReader):
-    _ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens")
+    _ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens", "*.mixer.gate*")

     def read_config(self, config: Dict) -> Dict:
         producer = config.get("producer", {}).get("name")
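The new "*.mixer.gate*" entry is a wildcard pattern, unlike the two plain module names before it. A rough sketch of how such patterns could be matched against module names, assuming fnmatch-style semantics (the actual matching logic lives in the quant config reader and may differ):

from fnmatch import fnmatch

_ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens", "*.mixer.gate*")

def is_excluded(module_name: str) -> bool:
    # Assumption: exclude entries are shell-style wildcards over full module names.
    return any(fnmatch(module_name, pattern) for pattern in _ALWAYS_EXCLUDE)

assert is_excluded("model.layers.3.mixer.gate")              # MoE router gate stays unquantized
assert not is_excluded("model.layers.3.mixer.experts.0.w1")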
@@ -1680,9 +1680,9 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
     new_key_fc1_alpha = f"nvfp4_moe_w1_alpha_stacked_{fused_key_counter}"
     new_key_fc2_alpha = f"nvfp4_moe_w2_alpha_stacked_{fused_key_counter}"

-    weight_dtype = torch.float8_e4m3fn
-    _register_parameter(gm, new_key_fc1_expert_weights, fc1_expert_weights.to(weight_dtype))
-    _register_parameter(gm, new_key_fc2_expert_weights, fc2_expert_weights.to(weight_dtype))
+    # FP4 weights are already packed as uint8, don't convert dtype
+    _register_parameter(gm, new_key_fc1_expert_weights, fc1_expert_weights)
+    _register_parameter(gm, new_key_fc2_expert_weights, fc2_expert_weights)
     _register_parameter(
         gm, new_key_fc1_weight_blockscale_fp8, fc1_weight_blockscale_fp8_stacked
     )
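The surrounding function, _stack_nvfp4_moe_weights, collapses per-expert weights into one stacked parameter; the fix above keeps the packed FP4 bytes as uint8 instead of converting them. The core stacking idea, in isolation (the library's _stack helper may also handle device and dtype; this sketch does not):

import torch

per_expert = [torch.randint(0, 255, (4, 8), dtype=torch.uint8) for _ in range(3)]
stacked = torch.stack(per_expert, dim=0)  # [E, I, H/2]-style layout

assert stacked.shape == (3, 4, 8)
assert stacked.dtype == torch.uint8       # packed FP4 bytes keep their dtype
assert torch.equal(stacked[1], per_expert[1])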
@@ -1747,11 +1747,12 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
     w2_input_scale_stacked = _stack(w2_input_scale, dim=0)
     w3_input_scale_stacked = _stack(w3_input_scale, dim=0, device=device, dtype=dtype)

-    w1_weight_blockscale_fp8_stacked = _stack(w1_weight_scale, dim=0).to(torch.float8_e4m3fn)
-    w2_weight_blockscale_fp8_stacked = _stack(w2_weight_scale, dim=0).to(torch.float8_e4m3fn)
+    # Use .view() not .to() to reinterpret bytes as float8, not value conversion
+    w1_weight_blockscale_fp8_stacked = _stack(w1_weight_scale, dim=0).view(torch.float8_e4m3fn)
+    w2_weight_blockscale_fp8_stacked = _stack(w2_weight_scale, dim=0).view(torch.float8_e4m3fn)
     w3_weight_blockscale_fp8_stacked = _stack(
         w3_weight_scale, dim=0, device=device, dtype=dtype
-    ).to(torch.float8_e4m3fn)
+    ).view(torch.float8_e4m3fn)

     w1_alpha_stacked = _stack(w1_alpha, dim=0)
     w2_alpha_stacked = _stack(w2_alpha, dim=0)
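The switch from .to() to .view() above is the crux of the fix: the stacked block scales are already float8_e4m3fn bit patterns sitting in a uint8 buffer, so the bytes must be reinterpreted rather than numerically converted. A small standalone sketch of the difference (requires a PyTorch build with float8 support):

import torch

scales_fp8 = torch.tensor([0.5, 1.0, 2.0]).to(torch.float8_e4m3fn)  # "real" block scales
raw_uint8 = scales_fp8.view(torch.uint8)                            # how they are stored

reinterpreted = raw_uint8.view(torch.float8_e4m3fn)  # same bytes, fp8 semantics (correct)
converted = raw_uint8.to(torch.float8_e4m3fn)        # treats bytes as integers 0..255 (wrong)

assert torch.equal(reinterpreted.view(torch.uint8), scales_fp8.view(torch.uint8))
assert not torch.equal(converted.float(), scales_fp8.float())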
@@ -141,6 +141,7 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
     MODEL_NAME = "nvidia/Nemotron-MOE"
     MODEL_PATH_BF16 = f"{llm_models_root()}/Nemotron-Nano-3-30B-A3.5B-dev-1024"
     MODEL_PATH_FP8 = f"{llm_models_root()}/Nemotron-Nano-3-30B-A3.5B-FP8-KVFP8-dev"
+    MODEL_PATH_NVFP4 = f"{llm_models_root()}/Nemotron-3-Nano-30B-A3B-NVFP4"

     def get_default_kwargs(self):
         return {
@@ -217,6 +218,22 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)

+    @pytest.mark.skip(reason="NVFP4 model is not in the CI yet")
+    def test_nvfp4(self):
+        kwargs = self.get_default_kwargs()
+        with AutoDeployLLM(model=self.MODEL_PATH_NVFP4,
+                           tokenizer=self.MODEL_PATH_NVFP4,
+                           **kwargs) as llm:
+            # Manually set quant_config for NVFP4 model to get the accuracy threshold
+            llm.args.quant_config.quant_algo = QuantAlgo.NVFP4
+            llm.args.quant_config.kv_cache_quant_algo = QuantAlgo.FP8
+            sampling_params = self.get_default_sampling_params()
+
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=sampling_params)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+

 class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
     """Accuracy regression tests for Nemotron Super V3.
@@ -792,3 +792,213 @@ def test_trtllm_fused_moe_nvfp4(
     print(f"{trtllm_output=}")
     # print(f"{diff.abs()>=2e-1=}")
     torch.testing.assert_close(ref_output, trtllm_output, rtol=2e-1, atol=2e-1)
+
+
+@pytest.mark.parametrize(
+    "hidden_size, intermediate_size",
+    [
+        (128, 128),  # No padding needed
+        (128, 160),  # Small padded case (inter_size needs padding)
+        (2688, 1856),  # Nemotron-Nano-3-30B-A3 dimensions (padding required)
+    ],
+)
+@pytest.mark.skipif(
+    not fp4_compatible() or not trtllm_ops_available(),
+    reason="Requires fp4 and trtllm support",
+)
+def test_stack_nvfp4_moe_weights_transform_relu2(hidden_size, intermediate_size):
+    """
+    Test for _stack_nvfp4_moe_weights transform with non-gated MLP (Relu2).
+
+    Tests both:
+    - 128x128: No padding needed
+    - 2688x1856: Padding required for intermediate_size
+
+    Compares torch_quant_nvfp4_moe (before transform) vs
+    trtllm_quant_nvfp4_moe_fused (after transform).
+    """
+    import torch.fx as fx
+
+    from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe import _stack_nvfp4_moe_weights
+
+    torch.manual_seed(42)
+
+    batch_size, num_experts, top_k = 1, 2, 2
+    otype = torch.float16
+    is_gated_mlp = False  # Non-gated MLP (Relu2)
+    act_fn = ActivationType.Relu2
+
+    # Generate test data
+    x = gen_tensor((batch_size, hidden_size), otype) * 0.5
+    router_logits = torch.randn(batch_size, num_experts, dtype=otype).cuda()
+    routing_weights, selected_experts = compute_routing(router_logits, top_k)
+
+    # Quantize weights for each expert
+    def quantize_weight(weight):
+        """Quantize weight and return FP4 values + 2D block scales for fused kernel.
+
+        Key insight: fp4_quantize returns block scales in an internal interleaved
+        layout that works with nvfp4_gemm but NOT with the fused MoE kernel.
+
+        For the fused kernel, we must:
+        1. Get FP4 values from fp4_quantize
+        2. Compute block scales manually in row-major format
+        3. Pad to kernel's expected dimensions
+        4. Apply block_scale_interleave for swizzling
+        """
+        m, n = weight.shape
+        amax = weight.abs().max().to(torch.float32)
+        gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax
+
+        # Get FP4 values only (ignore block scales - they're in wrong format)
+        fp4, _ = torch.ops.trtllm.fp4_quantize(weight, gs, NVFP4_BLOCK_SIZE, False)
+
+        # Compute block scales manually in row-major format
+        # Each block is NVFP4_BLOCK_SIZE (16) elements along the N dimension
+        n_blocks = n // NVFP4_BLOCK_SIZE
+        # Reshape to (M, n_blocks, 16) and compute max per block
+        weight_blocks = weight.view(m, n_blocks, NVFP4_BLOCK_SIZE)
+        block_maxes = weight_blocks.abs().amax(dim=2)  # (M, n_blocks)
+
+        # Convert to FP8 E4M3 block scales
+        # block_scale = block_max * gs / FLOAT8_E4M3_MAX
+        block_scales_fp8 = (
+            (block_maxes * gs / FLOAT8_E4M3_MAX).clamp(max=FLOAT8_E4M3_MAX).to(torch.float8_e4m3fn)
+        )
+
+        # Pad to kernel's expected dimensions (multiples of 128 for rows, 8 for block cols)
+        padded_m = math.ceil(m / 128) * 128
+        padded_n_blocks = math.ceil(n_blocks / 8) * 8
+
+        # Create padded buffer (zeros are neutral for unused blocks)
+        block_scales_padded = torch.zeros(
+            (padded_m, padded_n_blocks), dtype=torch.uint8, device="cuda"
+        )
+        block_scales_padded[:m, :n_blocks] = block_scales_fp8.view(torch.uint8)
+
+        # Apply block_scale_interleave for fused kernel's swizzled format
+        block_scales_swizzled = (
+            torch.ops.trtllm.block_scale_interleave(block_scales_padded.cpu().contiguous())
+            .reshape(padded_m, padded_n_blocks)
+            .cuda()
+        )
+
+        input_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+        alpha = torch.tensor(
+            1.0 / (input_scale.item() * gs.item()), device="cuda", dtype=torch.float32
+        )
+        return fp4, block_scales_swizzled, input_scale, alpha
+
+    # Create per-expert weights and scales
+    w1_weight, w2_weight, w3_weight = [], [], []
+    w1_input_scale, w2_input_scale, w3_input_scale = [], [], []
+    w1_weight_scale, w2_weight_scale, w3_weight_scale = [], [], []
+    w1_alpha, w2_alpha, w3_alpha = [], [], []
+
+    for _ in range(num_experts):
+        # w1 (gate), w2 (down), w3 (up)
+        w1 = gen_tensor((intermediate_size, hidden_size), otype, scale=0.1)
+        w2 = gen_tensor((hidden_size, intermediate_size), otype, scale=0.1)
+
+        fp4, bs, iscale, alpha = quantize_weight(w1)
+        w1_weight.append(fp4)
+        w1_weight_scale.append(bs)
+        w1_input_scale.append(iscale)  # Keep as scalar, not [1]
+        w1_alpha.append(alpha)  # Keep as scalar, not [1]
+
+        fp4, bs, iscale, alpha = quantize_weight(w2)
+        w2_weight.append(fp4)
+        w2_weight_scale.append(bs)
+        w2_input_scale.append(iscale)  # Keep as scalar, not [1]
+        w2_alpha.append(alpha)  # Keep as scalar, not [1]
+
+    # Call torch_quant_nvfp4_moe directly (reference)
+    ref_output = torch.ops.auto_deploy.torch_quant_nvfp4_moe(
+        x,
+        selected_experts,
+        routing_weights,
+        w1_weight,
+        w2_weight,
+        w3_weight,
+        w1_input_scale,
+        w2_input_scale,
+        w3_input_scale,
+        w1_weight_scale,
+        w2_weight_scale,
+        w3_weight_scale,
+        w1_alpha,
+        w2_alpha,
+        w3_alpha,
+        is_gated_mlp=is_gated_mlp,
+        act_fn=act_fn,
+    )
+
+    # Now create a GraphModule and apply the transform
+    class MoEModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            for i in range(num_experts):
+                self.register_buffer(f"w1_{i}", w1_weight[i])
+                self.register_buffer(f"w2_{i}", w2_weight[i])
+                self.register_buffer(f"w1_scale_{i}", w1_weight_scale[i])
+                self.register_buffer(f"w2_scale_{i}", w2_weight_scale[i])
+                self.register_buffer(f"w1_iscale_{i}", w1_input_scale[i])
+                self.register_buffer(f"w2_iscale_{i}", w2_input_scale[i])
+                self.register_buffer(f"w1_alpha_{i}", w1_alpha[i])
+                self.register_buffer(f"w2_alpha_{i}", w2_alpha[i])
+                if is_gated_mlp:
+                    self.register_buffer(f"w3_{i}", w3_weight[i])
+                    self.register_buffer(f"w3_scale_{i}", w3_weight_scale[i])
+                    self.register_buffer(f"w3_iscale_{i}", w3_input_scale[i])
+                    self.register_buffer(f"w3_alpha_{i}", w3_alpha[i])
+
+        def forward(self, x, selected_experts, routing_weights):
+            return torch.ops.auto_deploy.torch_quant_nvfp4_moe(
+                x,
+                selected_experts,
+                routing_weights,
+                [getattr(self, f"w1_{i}") for i in range(num_experts)],
+                [getattr(self, f"w2_{i}") for i in range(num_experts)],
+                [getattr(self, f"w3_{i}") for i in range(num_experts)] if is_gated_mlp else [],
+                [getattr(self, f"w1_iscale_{i}") for i in range(num_experts)],
+                [getattr(self, f"w2_iscale_{i}") for i in range(num_experts)],
+                [getattr(self, f"w3_iscale_{i}") for i in range(num_experts)]
+                if is_gated_mlp
+                else [],
+                [getattr(self, f"w1_scale_{i}") for i in range(num_experts)],
+                [getattr(self, f"w2_scale_{i}") for i in range(num_experts)],
+                [getattr(self, f"w3_scale_{i}") for i in range(num_experts)]
+                if is_gated_mlp
+                else [],
+                [getattr(self, f"w1_alpha_{i}") for i in range(num_experts)],
+                [getattr(self, f"w2_alpha_{i}") for i in range(num_experts)],
+                [getattr(self, f"w3_alpha_{i}") for i in range(num_experts)]
+                if is_gated_mlp
+                else [],
+                is_gated_mlp=is_gated_mlp,
+                act_fn=act_fn,
+            )
+
+    module = MoEModule().cuda()
+    gm = fx.symbolic_trace(module)
+
+    # Apply the transform
+    num_transformed = _stack_nvfp4_moe_weights(gm)
+    gm.recompile()
+
+    assert num_transformed == 1, f"Expected 1 transform, got {num_transformed}"
+
+    # Run the transformed graph
+    transformed_output = gm(x, selected_experts, routing_weights)
+
+    # Get the registered parameters after transform
+    fc1_act_scale = getattr(gm, "nvfp4_moe_w3_w1_input_scale_stacked_0", None)
+    fc1_alpha = getattr(gm, "nvfp4_moe_w1_alpha_stacked_0", None)
+    if fc1_act_scale is not None:
+        print(f"fc1_act_scale (after transform): {fc1_act_scale}, shape: {fc1_act_scale.shape}")
+    if fc1_alpha is not None:
+        print(f"fc1_alpha (after transform): {fc1_alpha}, shape: {fc1_alpha.shape}")
+
+    # Should be close for FP4 quantization (gated MLP may have slightly larger diff due to alpha handling)
+    tol = 1e-3
+    torch.testing.assert_close(ref_output, transformed_output, rtol=tol, atol=tol)
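For the block-scale math in quantize_weight above, a quick numeric check with the constants written out (FLOAT8_E4M3_MAX = 448.0 and FLOAT4_E2M1_MAX = 6.0 are the usual dtype maxima; treat them as assumptions of this sketch):

FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0

amax = 0.75                                     # global abs-max of a weight tensor
gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax   # global scale
assert gs == 3584.0

block_max = 0.5                                 # abs-max of one 16-element block
block_scale = block_max * gs / FLOAT8_E4M3_MAX  # per-block scale stored as FP8
assert block_scale == 4.0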