From d9fba85396a1d9a368a2f2f621b49f5232e0d92a Mon Sep 17 00:00:00 2001 From: Wei-Ming Chen <17592131+meenchen@users.noreply.github.com> Date: Wed, 3 Dec 2025 09:47:13 -0800 Subject: [PATCH] [OMNIML-2932] [feat] nvfp4 awq support (#8698) Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- tensorrt_llm/_torch/modules/attention.py | 19 +- .../modules/fused_moe/fused_moe_cutlass.py | 8 + .../modules/fused_moe/fused_moe_trtllm_gen.py | 7 + .../_torch/modules/fused_moe/quantization.py | 82 ++++++++ tensorrt_llm/_torch/modules/linear.py | 59 ++++++ tensorrt_llm/quantization/mode.py | 4 + tensorrt_llm/quantization/quantize.py | 8 + .../_torch/modules/test_awq_quantization.py | 181 ++++++++++++++++++ 8 files changed, 365 insertions(+), 3 deletions(-) create mode 100644 tests/unittest/_torch/modules/test_awq_quantization.py diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 62f577e41f..ed23eb7aab 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -382,8 +382,11 @@ class Attention(nn.Module): out_dtype = q.dtype if self.attn_backend == "TRTLLM": - if self.has_quant_scale and (self.attn.has_fp8_kv_cache - or self.attn.has_fp4_kv_cache): + # Don't use FP8 output if o_proj has pre_quant_scale - keep BF16 for better precision + has_pre_quant_scale = getattr(self.o_proj, 'pre_quant_scale', + None) is not None + if self.has_quant_scale and not has_pre_quant_scale and ( + self.attn.has_fp8_kv_cache or self.attn.has_fp4_kv_cache): out_dtype = torch.float8_e4m3fn output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype) return output @@ -414,8 +417,18 @@ class Attention(nn.Module): out_scale = None out_scale_sf = None - if self.has_quant_scale and not self.attn_output_gate: + has_awq_pre_quant_scale = hasattr( + self.o_proj, + 'pre_quant_scale') and self.o_proj.pre_quant_scale is not None + # Don't set out_scale if o_proj has pre_quant_scale - this prevents FP8/FP4 output + # and keeps attention output in BF16 for better precision when applying pre_quant_scale + if self.has_quant_scale and not self.attn_output_gate and not has_awq_pre_quant_scale: out_scale = self.o_proj.inv_input_scale + if has_awq_pre_quant_scale and enable_attn_nvfp4_output: + logger.warning_once( + "Disable attn nvfp4 output because o_proj has pre_quant_scale for AWQ.", + key="disable_attn_nvfp4_output_for_awq") + enable_attn_nvfp4_output = False if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output and not self.attn_output_gate: out_scale_sf = self.o_proj.input_scale diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index 06119c1536..bb883ffc95 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -437,6 +437,14 @@ class CutlassFusedMoE(MoE): elif self.has_int8_woq_per_channel: use_int8_woq_per_channel = True elif self.has_nvfp4: + # Apply pre_quant_scale if it exists (for NVFP4_AWQ) + if hasattr( + self, + 'fc31_act_scale') and self.fc31_act_scale is not None: + assert not isinstance( + x, Fp4QuantizedTensor + ), "Fp4QuantizedTensor is not expected for AWQ quantization." 
+                x = x * self.fc31_act_scale
             if run_post_quant_allgather or self.enable_alltoall:
                 if isinstance(x, Fp4QuantizedTensor):
                     assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
index f0393d2c20..f2bcc4397a 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
@@ -316,6 +316,13 @@ class TRTLLMGenFusedMoE(MoE):
             x_row = x.shape[0]
             x, x_sf = x.fp4_tensor, x.scaling_factor
         else:
+            # Apply pre_quant_scale if it exists (for NVFP4_AWQ)
+            # fc31_act_scale shape: (1, hidden_size)
+            # x shape: (num_tokens, hidden_size)
+            if hasattr(
+                    self,
+                    'fc31_act_scale') and self.fc31_act_scale is not None:
+                x = x * self.fc31_act_scale
             x_row = x.shape[0]
             x, x_sf = torch.ops.trtllm.fp4_quantize(
                 x, self.fc31_input_scale, self.scaling_vector_size, False,
diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
index 5e80d4840c..cb61b7867f 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -1716,6 +1716,10 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
                                  requires_grad=False)
         module.register_parameter("fc2_alpha", fc2_alpha)
 
+        # Optional per-channel act scale for NVFP4_AWQ (pre_quant_scale support)
+        # This will be initialized in load_quant_scales if pre_quant_scale exists
+        module.register_parameter("fc31_act_scale", None)
+
         super().create_weights(module, weight_dtype, w3_w1_weight_shape,
                                w2_weight_shape)
 
@@ -1834,12 +1838,30 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
                             dst_fc2_alpha[expert_idx])
 
     def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
+        # Check if pre_quant_scale exists in the checkpoint (for NVFP4_AWQ)
+        has_pre_quant_scale = False
+        if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
+            # Probe expert 0: all experts share the same input, so either every expert has pre_quant_scale or none does
+            has_pre_quant_scale = "0.w1.pre_quant_scale" in weights
+
         # Step1: Load input scales.
tmp_fc31_input_scale = torch.empty(module.num_experts, dtype=torch.float32) tmp_fc2_input_scale = torch.empty(module.num_experts, dtype=torch.float32) + # If pre_quant_scale exists, we need a per-channel act scale for fc31 + # All experts share the same input, so pre_quant_scale should be identical across experts + if has_pre_quant_scale: + # Create fc31_act_scale parameter (for gate_up_proj / w3_w1) + # Shape: (1, hidden_size) - single vector for all experts (they share the same input) + fc31_act_scale = nn.Parameter(torch.empty(1, + module.hidden_size, + dtype=module.dtype, + device='cuda'), + requires_grad=False) + module.register_parameter("fc31_act_scale", fc31_act_scale) + for expert_id in range(module.num_experts): if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: w1_input_scale = weights[f"{expert_id}.w1.input_scale"] @@ -1866,6 +1888,66 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase): module.fc2_input_scale.data.copy_( tmp_fc2_input_scale.max().reciprocal()) + # Load pre_quant_scale if it exists (for NVFP4_AWQ) + if has_pre_quant_scale: + from ..linear import TensorParallelMode, load_weight_shard + + device = module.fc31_act_scale.device + # Load fc31 (w3/w1) pre_quant_scales + # All experts should have identical pre_quant_scale since they share the same input + all_w3_pre_quant_scales = [] + all_w1_pre_quant_scales = [] + for expert_id in module.initial_local_expert_ids: + w3_pre_quant_scale = load_weight_shard( + weights[f"{expert_id}.w3.pre_quant_scale"], + module.tp_size, + module.tp_rank, + TensorParallelMode.ROW, + device=device) + w1_pre_quant_scale = load_weight_shard( + weights[f"{expert_id}.w1.pre_quant_scale"], + module.tp_size, + module.tp_rank, + TensorParallelMode.ROW, + device=device) + all_w3_pre_quant_scales.append(w3_pre_quant_scale) + all_w1_pre_quant_scales.append(w1_pre_quant_scale) + + # Verify that all experts have identical pre_quant_scale + # (they should be the same since all experts share the same input) + w3_reference = all_w3_pre_quant_scales[0] + w1_reference = all_w1_pre_quant_scales[0] + + def check_consistency(scale, ref_scale, scale_name, expert_id): + if not torch.allclose(scale, ref_scale, rtol=1e-5, atol=1e-8): + max_diff = (scale - ref_scale).abs().max() + msg = ( + f"MoE pre_quant_scale: expert {expert_id} {scale_name} " + f"differs from expert {module.initial_local_expert_ids[0]}! Max diff: {max_diff:.6e}. " + f"All experts should have identical pre_quant_scale since they share the same input." + ) + logger.error(msg) + raise ValueError(msg) + + for i, (w3_scale, w1_scale) in enumerate( + zip(all_w3_pre_quant_scales[1:], + all_w1_pre_quant_scales[1:]), 1): + check_consistency(w3_scale, w3_reference, "w3.pre_quant_scale", + module.initial_local_expert_ids[i]) + check_consistency(w1_scale, w1_reference, "w1.pre_quant_scale", + module.initial_local_expert_ids[i]) + + # Take the maximum pre_quant_scale between w3 and w1 from the first expert + # (all experts should have the same values) + # Shape: (hidden_size,) + # Keep on CUDA device (w3_reference and w1_reference are already on CUDA) + fc31_pre_quant_scale = torch.max(w3_reference, w1_reference).to( + dtype=module.dtype, device='cuda') + + # Store as a single vector since all experts share the same pre_quant_scale + # This will be broadcasted to all tokens in the forward pass + module.fc31_act_scale.data.copy_(fc31_pre_quant_scale.unsqueeze(0)) + # Step2: Load weight block scales and alphas. 
        self.load_all_fp4_weight_scales_and_alphas(
            module, weights, module.initial_local_expert_ids,
diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index f56d6431b4..613665207c 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -898,6 +898,10 @@ class NVFP4LinearMethod(LinearMethodBase):
         module.inv_kv_scales = Parameter(torch.ones(3, dtype=torch.float32),
                                          requires_grad=False)
 
+        # NOTE: Not all Linear layers have this tensor - pre_quant_scale is averaged and folded into the
+        # preceding LayerNorm for QKV and Gate/Up projection layers when possible, so it is only present for o_proj and down_proj
+        module.pre_quant_scale = None
+
         if bias:
             module.bias = Parameter(torch.empty((out_features), dtype=dtype),
                                     requires_grad=False)
@@ -907,10 +911,28 @@ class NVFP4LinearMethod(LinearMethodBase):
     def apply(self, module: Linear, input: torch.Tensor,
               bias: Optional[torch.Tensor]):
         if isinstance(input, Fp4QuantizedTensor):
+            # Input is already quantized - this should not happen if pre_quant_scale exists
+            # because we disable FP4 output for attention output when pre_quant_scale is present
+            if module.pre_quant_scale is not None:
+                raise RuntimeError(
+                    "Received FP4 quantized input but pre_quant_scale exists. "
+                    "This indicates FP4 output was not properly disabled for the previous layer."
+                )
             act_fp4, act_sf = input.fp4_tensor, input.scaling_factor
         elif isinstance(input, tuple):
+            # Input is a tuple of (fp4_tensor, scaling_factor)
+            if module.pre_quant_scale is not None:
+                raise RuntimeError(
+                    "Received FP4 quantized tuple input but pre_quant_scale exists. "
+                    "This indicates FP4 output was not properly disabled for the previous layer."
+                )
             act_fp4, act_sf = input
         else:
+            # Input is a regular (unquantized) tensor - apply pre_quant_scale if it exists (for NVFP4_AWQ)
+            if module.pre_quant_scale is not None:
+                assert input.dtype == module.pre_quant_scale.dtype, "Input dtype and pre_quant_scale dtype must match"
+                input = input * module.pre_quant_scale
+
             act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
                 input, module.input_scale, module.scaling_vector_size, False)
 
@@ -1003,6 +1025,24 @@ class NVFP4LinearMethod(LinearMethodBase):
         copy_weight(module.alpha, alpha)
         module.scalar_alpha = alpha.item()
 
+        # Load pre_quant_scale if it exists (for NVFP4_AWQ)
+        if "pre_quant_scale" in weights[0]:
+            device = module.weight.device
+            pre_quant_scale = load_weight_shard(
+                weights[0]["pre_quant_scale"],
+                module.tp_size,
+                module.tp_rank,
+                # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around
+                TensorParallelMode.flip(module.tp_mode),
+                device,
+            )
+
+            module.pre_quant_scale = Parameter(
+                torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype),
+                requires_grad=False).to(device=device)
+
+            copy_weight(module.pre_quant_scale, pre_quant_scale)
+
     def load_weights_fused_qkv_linear(self, module: Linear,
                                       weights: List[Dict]) -> None:
         q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
@@ -1059,6 +1099,25 @@ class NVFP4LinearMethod(LinearMethodBase):
         copy_weight(module.alpha, alpha)
         module.scalar_alpha = alpha.item()
 
+        # Load pre_quant_scale if it exists (for NVFP4_AWQ)
+        # NOTE: pre_quant_scale is the same for gate and up since ModelOpt checks which layers share the same input
+        if "pre_quant_scale" in weights[0]:
+            device = module.weight.device
+            pre_quant_scale = load_weight_shard(
+                weights[0]["pre_quant_scale"],
+                module.tp_size,
+                module.tp_rank,
+                # pre_quant_scale applies to
activation as opposed to weight, so flip tp_mode the other way around + TensorParallelMode.flip(module.tp_mode), + device, + ) + + module.pre_quant_scale = Parameter( + torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype), + requires_grad=False).to(device=device) + + copy_weight(module.pre_quant_scale, pre_quant_scale) + def post_load_weights(self, module: Linear): super().post_load_weights(module) """ diff --git a/tensorrt_llm/quantization/mode.py b/tensorrt_llm/quantization/mode.py index 4615bc1376..6f035eb3d8 100644 --- a/tensorrt_llm/quantization/mode.py +++ b/tensorrt_llm/quantization/mode.py @@ -44,6 +44,7 @@ class QuantAlgo(StrEnum, metaclass=BaseEnumMeta): W4A8_MXFP4_FP8 = auto() W4A8_MXFP4_MXFP8 = auto() W4A16_MXFP4 = auto() + NVFP4_AWQ = auto() NO_QUANT = auto() @@ -410,6 +411,9 @@ class QuantMode(IntFlag): quant_mode = QuantMode.from_description(use_fp8_block_scales=True) elif quant_algo == QuantAlgo.NVFP4: quant_mode = QuantMode.from_description(use_nvfp4=True) + elif quant_algo == QuantAlgo.NVFP4_AWQ: + # NVFP4_AWQ uses the same QuantMode as NVFP4, distinction is at QuantAlgo level + quant_mode = QuantMode.from_description(use_nvfp4=True) elif quant_algo == QuantAlgo.W4A8_NVFP4_FP8: quant_mode = QuantMode.from_description(use_w4a8_nvfp4_fp8=True) elif quant_algo == QuantAlgo.W4A8_MXFP4_FP8: diff --git a/tensorrt_llm/quantization/quantize.py b/tensorrt_llm/quantization/quantize.py index b05d2a22e1..a04a223914 100644 --- a/tensorrt_llm/quantization/quantize.py +++ b/tensorrt_llm/quantization/quantize.py @@ -72,6 +72,14 @@ def quantize_layers( else: quant_mode = quant_config.quant_mode init_params["quant_mode"] = quant_mode + + # Auto-detect pre_quant_scale based on quant_algo + # For AWQ-based quantization methods that use pre_quant_scale + if quant_config.quant_algo in [ + QuantAlgo.W4A16_AWQ, QuantAlgo.NVFP4_AWQ, + QuantAlgo.W4A8_AWQ + ]: + init_params["pre_quant_scale"] = True if "bias" in init_params and not isinstance(module, MixtureOfExperts): init_params["bias"] = init_params["bias"] is not None diff --git a/tests/unittest/_torch/modules/test_awq_quantization.py b/tests/unittest/_torch/modules/test_awq_quantization.py new file mode 100644 index 0000000000..2e3ea3a6d1 --- /dev/null +++ b/tests/unittest/_torch/modules/test_awq_quantization.py @@ -0,0 +1,181 @@ +from unittest.mock import patch + +import pytest +import torch +from utils.util import skip_pre_blackwell + +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.modules.fused_moe import DefaultMoeRoutingMethod, create_moe +from tensorrt_llm._torch.modules.linear import Linear +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig + + +@skip_pre_blackwell # NVFP4 AWQ features require Blackwell (SM100) or later +@pytest.mark.parametrize("has_scale", [True, False]) +def test_linear_nvfp4_awq_pre_quant_scale(has_scale): + """ + Test that Linear (NVFP4 mode) applies pre_quant_scale to input before quantization. 
+ + This tests the logic in NVFP4LinearMethod.apply (around line 824-827): + if module.pre_quant_scale is not None: + assert input.dtype == module.pre_quant_scale.dtype + input = input * module.pre_quant_scale + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # Create a Linear module with NVFP4 quantization using actual initialization + mapping = Mapping(world_size=1, rank=0, tp_size=1) + quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) + + in_features = 128 + out_features = 256 + + # Create actual Linear module + linear = Linear( + in_features=in_features, + out_features=out_features, + dtype=torch.bfloat16, + mapping=mapping, + quant_config=quant_config, + ).cuda() + + # Set pre_quant_scale based on test parameter (skip weight init as it's quantized) + if has_scale: + scale = torch.full((in_features,), 0.5, dtype=torch.bfloat16, device="cuda") + linear.pre_quant_scale = torch.nn.Parameter(scale, requires_grad=False) + + # Prepare input + x = torch.ones(2, in_features, dtype=torch.bfloat16, device="cuda") + + # Mock torch.ops.trtllm.fp4_quantize to capture the input after scaling + captured_input = None + + def mock_fp4_quantize(input_tensor, *args, **kwargs): + nonlocal captured_input + captured_input = input_tensor + # Return dummy quantized output + return ( + torch.zeros( + input_tensor.shape[0], input_tensor.shape[1] // 2, dtype=torch.uint8, device="cuda" + ), + torch.ones( + input_tensor.shape[0], + input_tensor.shape[1] // 16, + dtype=torch.float32, + device="cuda", + ), + ) + + # Also mock the GEMM to avoid execution errors + # just return a dummy output since we are capturing the input before input quantization + def mock_gemm(act_fp4, *args, **kwargs): + batch_size = act_fp4.shape[0] + return torch.zeros(batch_size, out_features, dtype=torch.bfloat16, device="cuda") + + with patch("torch.ops.trtllm.fp4_quantize", side_effect=mock_fp4_quantize, create=True): + with patch("torch.ops.trtllm.nvfp4_gemm", side_effect=mock_gemm, create=True): + linear(x) + + assert captured_input is not None, "fp4_quantize was not called" + + if has_scale: + # Should be scaled + expected = x * scale + assert torch.allclose(captured_input, expected, rtol=1e-5, atol=1e-5), ( + "Expected scaled input" + ) + else: + # Should be original + assert torch.equal(captured_input, x), "Expected original input" + + +@skip_pre_blackwell # TRTLLMGenFusedMoE requires Blackwell (SM100) or later +@pytest.mark.parametrize("has_scale", [True, False]) +def test_fused_moe_trtllm_gen_input_scaling(has_scale): + """ + Test that TRTLLMGenFusedMoE applies fc31_act_scale to input x if present. 
+ """ + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # Setup + mapping = Mapping(world_size=1, rank=0, tp_size=1) + quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) + model_config = ModelConfig(mapping=mapping, quant_config=quant_config, moe_backend="TRTLLM") + + num_experts = 8 + hidden_size = 128 + intermediate_size = 256 + top_k = 2 + seq_len = 4 + + routing_method = DefaultMoeRoutingMethod(top_k=top_k) + + torch.manual_seed(0) + torch.cuda.manual_seed(0) + + # Create actual MoE module + moe = create_moe( + num_experts=num_experts, + routing_method=routing_method, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=torch.bfloat16, + reduce_results=False, + model_config=model_config, + ).cuda() + + # Set fc31_act_scale directly (simulating AWQ pre_quant_scale) + if has_scale: + scale = torch.full((hidden_size,), 0.5, dtype=torch.bfloat16, device="cuda") + moe.fc31_act_scale = torch.nn.Parameter(scale, requires_grad=False) + + # Prepare input + x = torch.ones(seq_len, hidden_size, dtype=torch.bfloat16, device="cuda") + router_logits = torch.randn(seq_len, num_experts, dtype=torch.bfloat16, device="cuda") + + # Mock torch.ops.trtllm.fp4_quantize to capture the input after scaling + captured_input = None + + def mock_fp4_quantize(input_tensor, *args, **kwargs): + nonlocal captured_input + captured_input = input_tensor + # Return dummy quantized output + return ( + torch.zeros( + input_tensor.shape[0], input_tensor.shape[1] // 2, dtype=torch.uint8, device="cuda" + ), + torch.ones( + input_tensor.shape[0], + input_tensor.shape[1] // 16, + dtype=torch.float32, + device="cuda", + ), + ) + + # Also mock the MoE runner to avoid execution errors + # just return a dummy output since we are capturing the input before input quantization + def mock_moe_runner(*args, **kwargs): + return [torch.zeros(seq_len, hidden_size, dtype=torch.bfloat16, device="cuda")] + + with patch("torch.ops.trtllm.fp4_quantize", side_effect=mock_fp4_quantize, create=True): + with patch( + "torch.ops.trtllm.fp4_block_scale_moe_runner", side_effect=mock_moe_runner, create=True + ): + with torch.inference_mode(): + moe.forward(x, router_logits) + + assert captured_input is not None, "fp4_quantize was not called" + + if has_scale: + # Should be scaled by fc31_act_scale (which is loaded from pre_quant_scale) + # The scale is 0.5, so x_passed should be x * 0.5 + expected = x * 0.5 + assert torch.allclose(captured_input, expected, rtol=1e-5, atol=1e-5), ( + "Expected scaled input" + ) + else: + # Should be original + assert torch.equal(captured_input, x), "Expected original input"
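
Usage note (illustrative sketch, not part of the patch): with this change, QuantAlgo.NVFP4_AWQ resolves to the same QuantMode bits as NVFP4, and the AWQ-specific behavior comes from quantize_layers() setting pre_quant_scale=True plus the pre_quant_scale tensors loaded in linear.py and quantization.py. The sketch below shows that mapping, assuming the elif chain patched in tensorrt_llm/quantization/mode.py above lives in QuantMode.from_quant_algo (as the existing NVFP4 branch suggests).

# Illustrative sketch only - not part of this patch.
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
from tensorrt_llm.quantization.mode import QuantMode

# Build a quant config for an AWQ-calibrated NVFP4 checkpoint.
quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4_AWQ)

# NVFP4_AWQ reuses the NVFP4 kernels and QuantMode; the distinction only exists at the
# QuantAlgo level, where AWQ checkpoints additionally provide pre_quant_scale tensors
# (loaded into Linear.pre_quant_scale and MoE fc31_act_scale by this patch).
mode = QuantMode.from_quant_algo(quant_config.quant_algo)
assert mode == QuantMode.from_quant_algo(QuantAlgo.NVFP4)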