@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE,
    reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("tp_size", [1, 2])
def test_nvfp4_wikitext_correctness(tp_size: int):
    """Check wikitext word perplexity of the Quark NVFP4 Qwen3 MoE checkpoint.

    The test is skipped when fewer than ``tp_size`` accelerator devices are
    visible. It passes when the measured ``word_perplexity,none`` metric lies
    strictly inside ``expected +/- 0.25`` (an absolute window, not a relative
    one).
    """
    available_devices = torch.accelerator.device_count()
    if available_devices < tp_size:
        pytest.skip(
            f"This test requires >={tp_size} gpus, got only {available_devices}"
        )

    task = "wikitext"
    # Absolute half-width of the accepted perplexity window.
    tolerance = 0.25

    config = AccuracyTestConfig(
        model_name="amd-quark/Qwen3-30B-A3B-nvfp4-quark",
        # NOTE: expected_value from nvidia/Qwen3-30B-A3B-NVFP4
        excepted_value=11.2391,
    )

    # Smaller cudagraph_capture_sizes to speed up the test.
    model_args = config.get_model_args(
        tp_size=tp_size,
        kwargs={"cudagraph_capture_sizes": [16]},
    )
    # The "add_bos_token" entry is dropped before handing the args to lm_eval.
    model_args.pop("add_bos_token")

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks=task,
        batch_size=64,
    )

    EXPECTED_VALUE = config.excepted_value
    measured_value = results["results"][task]["word_perplexity,none"]
    within_window = (
        EXPECTED_VALUE - tolerance < measured_value < EXPECTED_VALUE + tolerance
    )
    assert within_window, f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
+ if weight_quant is None or input_quant is None: + return False + + # Confirm both weight_quant and input_quant are lists with 2 elements + if not isinstance(weight_quant, list) or len(weight_quant) != 2: + return False + if not isinstance(input_quant, list) or len(input_quant) != 2: + return False + + # First element should be fp4 with per_group quantization + is_fp4_per_group_weight = ( + weight_quant[0].get("dtype") == "fp4" + and weight_quant[0].get("qscheme") == "per_group" + and weight_quant[0].get("group_size") == 16 + and not weight_quant[0].get("is_dynamic") + ) + is_fp4_per_group_input = ( + input_quant[0].get("dtype") == "fp4" + and input_quant[0].get("qscheme") == "per_group" + and input_quant[0].get("group_size") == 16 + and input_quant[0].get("is_dynamic") + ) + + # Second element should be fp8_e4m3 with per_tensor quantization + is_fp8_per_tensor_weight = ( + weight_quant[1].get("dtype") == "fp8_e4m3" + and weight_quant[1].get("qscheme") == "per_tensor" + and not weight_quant[1].get("is_dynamic") + ) + is_fp8_per_tensor_input = ( + input_quant[1].get("dtype") == "fp8_e4m3" + and input_quant[1].get("qscheme") == "per_tensor" + and not input_quant[1].get("is_dynamic") + ) + + return ( + is_fp4_per_group_weight # type: ignore[return-value] + and is_fp4_per_group_input + and is_fp8_per_tensor_weight + and is_fp8_per_tensor_input + ) + def _is_w_ocp_mx_a_x( self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None ) -> bool: @@ -543,7 +592,9 @@ class QuarkConfig(QuantizationConfig): weight_config = cast(dict[str, Any], config.get("weight")) input_config = cast(dict[str, Any], config.get("input_tensors")) - if self._is_fp8_w8a8(weight_config, input_config): + if self._is_nvfp4(weight_config, input_config): + return QuarkNVFP4() + elif self._is_fp8_w8a8(weight_config, input_config): is_fp8_w8a8_supported = self._check_scheme_supported( QuarkW8A8Fp8.get_min_capability(), error=False ) diff --git 
class QuarkNvfp4MoEMethod(QuarkMoEMethod):
    """Fused-MoE quantization method for Quark NVFP4 checkpoints.

    Expert weights are stored as packed FP4 (two values per ``uint8``) with
    per-group FP8 scales (group size 16) plus per-tensor FP32 global scales
    for both weights and inputs. After loading, the tensors are converted to
    the layout of the backend chosen by ``select_nvfp4_moe_backend`` and a
    modular kernel is built from them.
    """

    def __init__(
        self,
        weight_config: dict[str, Any],
        input_config: dict[str, Any],
        moe: FusedMoEConfig,
        quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
    ):
        """Store the quantization specs and pick the NVFP4 experts backend.

        Args:
            weight_config: Weight quantization spec from the Quark config.
            input_config: Input-tensor quantization spec from the Quark config.
            moe: Fused-MoE configuration forwarded to the base class.
            quant_config: Owning Quark config; attached to the layer in
                ``create_weights``.
        """
        super().__init__(moe)
        self.weight_quant = weight_config
        self.input_quant = input_config
        self.quant_config = quant_config
        # Number of elements sharing one per-group scale.
        self.group_size = 16

        # Select experts implementation.
        self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
            config=self.moe,
            weight_key=kNvfp4Static,
            activation_key=kNvfp4Dynamic,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the NVFP4 expert parameters on ``layer``.

        Registered parameters (all ``requires_grad=False``):

        * ``w13_weight`` / ``w2_weight``: ``uint8`` packed FP4 weights, last
          dimension halved because two FP4 values share one byte.
        * ``w13_weight_scale`` / ``w2_weight_scale``: ``float8_e4m3fn``
          per-group scales, last dimension divided by ``self.group_size``.
        * ``w13_weight_scale_2`` / ``w2_weight_scale_2``: ``float32``
          per-tensor weight scales.
        * ``w13_input_scale_2`` / ``w2_input_scale_2``: ``float32``
          per-tensor input scales.

        The ``w13`` tensors carry two shards when ``self.moe.is_act_and_mul``
        is true and one shard otherwise.
        """
        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        # Attribute sets for the three parameter families. The packed weights
        # keep the caller's attributes untouched; scales additionally carry
        # the granularity marker.
        group_scale_attrs = {
            **extra_weight_attrs,
            "quant_method": FusedMoeWeightScaleSupported.GROUP.value,
        }
        tensor_scale_attrs = {
            **extra_weight_attrs,
            "quant_method": FusedMoeWeightScaleSupported.TENSOR.value,
        }

        # GEMM 1 - w13 weight
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # GEMM 2 - w2 weight
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # Weight scales (per-group FP8 scales)
        w13_weight_scale = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                hidden_size // self.group_size,
                dtype=weight_scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        set_weight_attrs(w13_weight_scale, group_scale_attrs)

        w2_weight_scale = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // self.group_size,
                dtype=weight_scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, group_scale_attrs)

        # Global weight scales (per-tensor FP32 scales)
        w13_weight_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
        set_weight_attrs(w13_weight_scale_2, tensor_scale_attrs)

        w2_weight_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
        set_weight_attrs(w2_weight_scale_2, tensor_scale_attrs)

        # Input global scales (per-tensor FP32 scales)
        w13_input_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_input_scale_2", w13_input_scale_2)
        set_weight_attrs(w13_input_scale_2, tensor_scale_attrs)

        w2_input_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_input_scale_2", w2_input_scale_2)
        set_weight_attrs(w2_input_scale_2, tensor_scale_attrs)

    def process_weights_after_loading(self, layer: FusedMoE) -> None:
        """
        Convert NVFP4 MoE weights into kernel format and setup the kernel.

        Raises:
            ValueError: If the layer has separate w1/w3 global weight scales
                and they are not (numerically) equal.
        """
        w13_global_scale = layer.w13_weight_scale_2

        if w13_global_scale.shape[1] > 1:
            # Gated activation: w1 and w3 each loaded their own global scale,
            # but a single value is passed on for the fused w13 tensor.
            w1_scale = w13_global_scale[:, 0]
            w3_scale = w13_global_scale[:, 1]
            if not torch.allclose(w1_scale, w3_scale):
                raise ValueError(
                    "Different global scales for w1 and w3 is not supported."
                )
            # Use a single gscale for w13
            w13_weight_scale_2 = torch.maximum(w1_scale, w3_scale).contiguous()
        else:
            # Non-gated activation: create_weights allocated a single shard,
            # so there is no second column to compare against.
            w13_weight_scale_2 = w13_global_scale[:, 0].contiguous()

        w2_weight_scale_2 = layer.w2_weight_scale_2

        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = convert_to_nvfp4_moe_kernel_format(
            nvfp4_backend=self.nvfp4_backend,
            layer=layer,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w13_scale_2=w13_weight_scale_2,
            a13_scale=layer.w13_input_scale_2,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            w2_scale_2=w2_weight_scale_2,
            a2_scale=layer.w2_input_scale_2,
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        # Swap the loaded tensors for their kernel-format counterparts.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w13_weight_scale_2", w13_scale_2)
        replace_parameter(layer, "w13_input_scale_2", a13_scale)

        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w2_weight_scale", w2_scale)
        replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
        replace_parameter(layer, "w2_input_scale_2", a2_scale)

        # Setup modular kernel.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk = make_nvfp4_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                experts_cls=self.experts_cls,
                shared_experts=layer.shared_experts,
                routing_tables=layer._maybe_init_expert_routing_tables(),
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the layer's current scales."""
        return make_nvfp4_moe_quant_config(
            backend=self.nvfp4_backend,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w13_scale_2=layer.w13_weight_scale_2,
            w2_scale_2=layer.w2_weight_scale_2,
            a13_scale=layer.w13_input_scale_2,
            a2_scale=layer.w2_input_scale_2,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: Any | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the experts through the modular kernel built after loading.

        Requires ``process_weights_after_loading`` to have produced
        ``self.moe_mk``; routing and activation settings are read from
        ``layer``.
        """
        assert self.moe_mk is not None
        return self.moe_mk.apply(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )
class QuarkNVFP4(QuarkScheme):
    """
    Linear-layer scheme for Quark NVFP4 checkpoints.

    Parameters allocated for each layer:
    - weight: uint8, shape [out_features, in_features // 2] (packed FP4)
    - weight_scale: float8_e4m3fn, shape [out_features, in_features // group_size]
    - weight_scale_2: float32, one global weight scale per fused partition
    - input_scale_2: float32, one global input scale per fused partition
    """

    def __init__(
        self,
    ):
        kernel = init_nvfp4_linear_kernel()
        self.kernel = kernel
        self.group_size = 16

        if isinstance(kernel, EmulationNvFp4LinearKernel):
            return

        # Any other kernel is allowed, but flagged once as unvalidated.
        logger.warning_once(
            "Only EmulationNvFp4LinearKernel NVFP4 dense implementation is "
            "tested with QuarkNVFP4, got kernel=%s. Correctness is not validated.",
            type(kernel).__name__,
        )

    @classmethod
    def get_min_capability(cls) -> int:
        # Capability floor for this scheme (original note: Turing (75) or newer).
        return 75

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register packed weight, group scales and global scales on ``layer``.

        Raises ``ValueError`` when ``input_size_per_partition`` is not a
        multiple of the group size.
        """
        out_features = sum(output_partition_sizes)
        num_partitions = len(output_partition_sizes)
        group = self.group_size

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        if input_size_per_partition % group != 0:
            raise ValueError(
                f"Input size per partition ({input_size_per_partition}) must be "
                f"divisible by group size ({self.group_size})"
            )

        # Weight: FP4 packed as uint8 (2 FP4 values per uint8)
        packed_weight = torch.empty(
            out_features, input_size_per_partition // 2, dtype=torch.uint8
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=packed_weight,
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            ),
        )

        # Per-group weight scale (FP8 E4M3)
        group_scales = torch.empty(
            out_features,
            input_size_per_partition // group,
            dtype=torch.float8_e4m3fn,
        )
        layer.register_parameter(
            "weight_scale",
            GroupQuantScaleParameter(
                data=group_scales,
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            ),
        )

        # Global weight scale, then global input scale: one float32 entry per
        # fused partition each, registered in this order.
        for scale_name in ("weight_scale_2", "input_scale_2"):
            layer.register_parameter(
                scale_name,
                PerTensorScaleParameter(
                    data=torch.empty(num_partitions, dtype=torch.float32),
                    weight_loader=weight_loader,
                ),
            )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse the per-partition global scales and prepare the kernel.

        Replaces ``input_scale_2`` / ``weight_scale_2`` with scalar
        ``input_global_scale`` / ``weight_global_scale`` (the maximum over
        partitions), derives ``alpha`` and ``input_global_scale_inv``, then
        lets the kernel convert the layer.
        """
        input_scale = layer.input_scale_2.max().to(torch.float32)
        layer.input_global_scale = Parameter(input_scale, requires_grad=False)
        del layer.input_scale_2

        per_partition_weight_scale = layer.weight_scale_2.to(torch.float32)
        scales_differ = torch.unique(per_partition_weight_scale).numel() != 1
        if scales_differ:
            logger.warning_once(
                "In NVFP4 linear, the global scale for weight are different"
                " for parallel layers (e.g. q_proj, k_proj, v_proj). This"
                " will likely result in reduced accuracy. Please verify the"
                " model accuracy. Consider using a checkpoint with a shared"
                " global NVFP4 scale for fused layers."
            )

        weight_scale = per_partition_weight_scale.max()
        layer.weight_global_scale = Parameter(weight_scale, requires_grad=False)
        del layer.weight_scale_2

        # Derived scalars consumed by the kernel.
        combined_scale = layer.input_global_scale * layer.weight_global_scale
        layer.alpha = Parameter(combined_scale, requires_grad=False)

        inverse_input_scale = (1.0 / layer.input_global_scale).to(torch.float32)
        layer.input_global_scale_inv = Parameter(
            inverse_input_scale, requires_grad=False
        )

        # Convert layer to NVFP4 linear kernel format
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Delegate the quantized matmul (and optional bias) to the kernel."""
        output = self.kernel.apply_weights(layer=layer, x=x, bias=bias)
        return output