[Quark] Support loading Quark NVFP4 checkpoints in vLLM (#35859)

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
fxmarty-amd
2026-05-13 20:17:36 +02:00
committed by GitHub
parent ab1ad0d7a9
commit 40330967ab
6 changed files with 503 additions and 4 deletions
+54 -2
View File
@@ -240,8 +240,13 @@ WIKITEXT_ACCURACY_CONFIGS = [
not QUARK_MXFP4_AVAILABLE,
reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize(
"config",
[pytest.param(val, id=f"config:{val}") for val in WIKITEXT_ACCURACY_CONFIGS],
)
@pytest.mark.parametrize(
"tp_size", [pytest.param(val, id=f"tp_size:{val}") for val in [1, 2]]
)
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
device_count = torch.accelerator.device_count()
if device_count < tp_size:
@@ -268,6 +273,53 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE,
    reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("tp_size", [1, 2])
def test_nvfp4_wikitext_correctness(tp_size: int):
    """Check wikitext word perplexity of the Quark NVFP4 Qwen3-30B-A3B model.

    The model is evaluated through ``lm_eval`` with the vLLM backend, and the
    measured word perplexity must fall strictly inside an absolute band of
    +/- 0.25 around the reference value.

    Args:
        tp_size: Tensor-parallel size forwarded to ``get_model_args``; the
            test is skipped when fewer accelerators are available.
    """
    available_gpus = torch.accelerator.device_count()
    if available_gpus < tp_size:
        pytest.skip(f"This test requires >={tp_size} gpus, got only {available_gpus}")

    eval_task = "wikitext"
    # Applied as an absolute (not relative) tolerance on the perplexity.
    tolerance = 0.25

    config = AccuracyTestConfig(
        model_name="amd-quark/Qwen3-30B-A3B-nvfp4-quark",
        # NOTE: expected_value from nvidia/Qwen3-30B-A3B-NVFP4
        excepted_value=11.2391,
    )

    # Smaller cudagraph_capture_sizes to speed up the test.
    model_args = config.get_model_args(
        tp_size=tp_size,
        kwargs={"cudagraph_capture_sizes": [16]},
    )
    # The "add_bos_token" entry is dropped before handing the args to lm_eval.
    model_args.pop("add_bos_token")

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks=eval_task,
        batch_size=64,
    )

    reference = config.excepted_value
    measured = results["results"][eval_task]["word_perplexity,none"]
    lower_bound = reference - tolerance
    upper_bound = reference + tolerance
    assert lower_bound < measured < upper_bound, (
        f"Expected: {reference} | Measured: {measured}"
    )
@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
@pytest.mark.skipif(
not QUARK_MXFP4_AVAILABLE,
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E
QuarkMoEMethod,
)
from vllm.model_executor.layers.quantization.quark.schemes import (
QuarkNVFP4,
QuarkOCP_MX,
QuarkScheme,
QuarkW4A8_MXFP4_FP8,
@@ -395,6 +396,54 @@ class QuarkConfig(QuantizationConfig):
and is_weight_symmetric
)
def _is_nvfp4(
self,
weight_quant: dict[str, Any] | list[dict[str, Any]] | None,
input_quant: dict[str, Any] | list[dict[str, Any]] | None,
) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
return False
# Confirm both weight_quant and input_quant are lists with 2 elements
if not isinstance(weight_quant, list) or len(weight_quant) != 2:
return False
if not isinstance(input_quant, list) or len(input_quant) != 2:
return False
# First element should be fp4 with per_group quantization
is_fp4_per_group_weight = (
weight_quant[0].get("dtype") == "fp4"
and weight_quant[0].get("qscheme") == "per_group"
and weight_quant[0].get("group_size") == 16
and not weight_quant[0].get("is_dynamic")
)
is_fp4_per_group_input = (
input_quant[0].get("dtype") == "fp4"
and input_quant[0].get("qscheme") == "per_group"
and input_quant[0].get("group_size") == 16
and input_quant[0].get("is_dynamic")
)
# Second element should be fp8_e4m3 with per_tensor quantization
is_fp8_per_tensor_weight = (
weight_quant[1].get("dtype") == "fp8_e4m3"
and weight_quant[1].get("qscheme") == "per_tensor"
and not weight_quant[1].get("is_dynamic")
)
is_fp8_per_tensor_input = (
input_quant[1].get("dtype") == "fp8_e4m3"
and input_quant[1].get("qscheme") == "per_tensor"
and not input_quant[1].get("is_dynamic")
)
return (
is_fp4_per_group_weight # type: ignore[return-value]
and is_fp4_per_group_input
and is_fp8_per_tensor_weight
and is_fp8_per_tensor_input
)
def _is_w_ocp_mx_a_x(
self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None
) -> bool:
@@ -543,7 +592,9 @@ class QuarkConfig(QuantizationConfig):
weight_config = cast(dict[str, Any], config.get("weight"))
input_config = cast(dict[str, Any], config.get("input_tensors"))
if self._is_fp8_w8a8(weight_config, input_config):
if self._is_nvfp4(weight_config, input_config):
return QuarkNVFP4()
elif self._is_fp8_w8a8(weight_config, input_config):
is_fp8_w8a8_supported = self._check_scheme_supported(
QuarkW8A8Fp8.get_min_capability(), error=False
)
@@ -39,6 +39,12 @@ from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
mxfp4_round_up_hidden_size_and_intermediate_size,
select_mxfp4_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
convert_to_nvfp4_moe_kernel_format,
make_nvfp4_moe_kernel,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
)
@@ -49,6 +55,8 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8StaticTensorSym,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d,
@@ -63,7 +71,9 @@ logger = init_logger(__name__)
__all__ = [
"QuarkMoEMethod",
"QuarkW8A8Fp8MoEMethod",
"QuarkOCP_MX_MoEMethod",
"QuarkNvfp4MoEMethod",
]
@@ -92,6 +102,10 @@ class QuarkMoEMethod(FusedMoEMethodBase):
if quant_config._is_fp8_w4a8(weight_config, input_config):
return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_nvfp4(weight_config, input_config):
return QuarkNvfp4MoEMethod(
weight_config, input_config, module.moe_config, quant_config
)
elif quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config):
@@ -1501,3 +1515,228 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
class QuarkNvfp4MoEMethod(QuarkMoEMethod):
    """Fused-MoE quantization method for Quark NVFP4 checkpoints.

    Expert weights are registered as packed FP4 (two values per ``uint8``)
    with FP8 E4M3 per-group scales (group size 16) plus FP32 per-tensor
    global scales for both weights and inputs. After loading, the tensors are
    converted to the format required by the selected NVFP4 MoE backend and a
    modular kernel is built to run the experts.
    """

    def __init__(
        self,
        weight_config: dict[str, Any],
        input_config: dict[str, Any],
        moe: FusedMoEConfig,
        quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
    ):
        """Store the quantization specs and pick the NVFP4 MoE backend.

        Args:
            weight_config: Weight quantization spec from the Quark config.
            input_config: Input-tensor quantization spec from the Quark config.
            moe: Fused-MoE configuration forwarded to the base class and to
                backend selection.
            quant_config: Owning Quark quantization config; later attached
                to the layer in ``create_weights``.
        """
        super().__init__(moe)
        self.weight_quant = weight_config
        self.input_quant = input_config
        self.quant_config = quant_config
        self.group_size = 16

        # Select experts implementation.
        self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
            config=self.moe,
            weight_key=kNvfp4Static,
            activation_key=kNvfp4Dynamic,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the (uninitialized) NVFP4 expert parameters on ``layer``.

        Registers ``w13_weight``/``w2_weight`` (packed FP4 as ``uint8``),
        their per-group FP8 scales, and FP32 per-tensor global scales for
        weights (``*_weight_scale_2``) and inputs (``*_input_scale_2``).

        Args:
            layer: Module receiving the parameters.
            num_experts: Number of experts (leading dimension of every tensor).
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Expert intermediate size on this
                partition.
            params_dtype: Stored on the layer; the parameters themselves use
                fixed ``uint8``/``float8_e4m3fn``/``float32`` dtypes.
            **extra_weight_attrs: Attributes attached to every parameter via
                ``set_weight_attrs``; ``quant_method`` is overridden for the
                scale parameters.
        """
        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config

        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        # Gated activations fuse two projections (w1 and w3) into w13.
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        # GEMM 1 - w13 weight
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # GEMM 2 - w2 weight
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # Weight scales (per-group FP8 scales)
        w13_weight_scale = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                hidden_size // self.group_size,
                dtype=weight_scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        # From here on the attrs mark the parameters as group-wise scales.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        w2_weight_scale = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // self.group_size,
                dtype=weight_scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # Global weight scales (per-tensor FP32 scales)
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        w13_weight_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
        set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)

        w2_weight_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
        set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)

        # Input global scales (per-tensor FP32 scales)
        w13_input_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_input_scale_2", w13_input_scale_2)
        set_weight_attrs(w13_input_scale_2, extra_weight_attrs)

        w2_input_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_input_scale_2", w2_input_scale_2)
        set_weight_attrs(w2_input_scale_2, extra_weight_attrs)

    def process_weights_after_loading(self, layer: FusedMoE) -> None:
        """
        Convert NVFP4 MoE weights into kernel format and setup the kernel.

        Raises:
            ValueError: For gated (act-and-mul) layers whose w1 and w3 global
                weight scales are not close to each other.
        """
        if self.moe.is_act_and_mul:
            # `w13_weight_scale_2` has one column per fused shard (w1, w3).
            w1_scale_2 = layer.w13_weight_scale_2[:, 0]
            w3_scale_2 = layer.w13_weight_scale_2[:, 1]
            if not torch.allclose(w1_scale_2, w3_scale_2):
                raise ValueError(
                    "Different global scales for w1 and w3 is not supported."
                )
            # Use a single gscale for w13
            w13_weight_scale_2 = torch.maximum(w1_scale_2, w3_scale_2).contiguous()
        else:
            # Non-gated MoE allocates a single shard column in
            # `create_weights`, so there is no second column to compare to.
            w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()

        w2_weight_scale_2 = layer.w2_weight_scale_2

        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = convert_to_nvfp4_moe_kernel_format(
            nvfp4_backend=self.nvfp4_backend,
            layer=layer,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w13_scale_2=w13_weight_scale_2,
            a13_scale=layer.w13_input_scale_2,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            w2_scale_2=w2_weight_scale_2,
            a2_scale=layer.w2_input_scale_2,
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        # Swap the loaded tensors for their kernel-format counterparts.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w13_weight_scale_2", w13_scale_2)
        replace_parameter(layer, "w13_input_scale_2", a13_scale)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w2_weight_scale", w2_scale)
        replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
        replace_parameter(layer, "w2_input_scale_2", a2_scale)

        # Setup modular kernel.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk = make_nvfp4_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                experts_cls=self.experts_cls,
                shared_experts=layer.shared_experts,
                routing_tables=layer._maybe_init_expert_routing_tables(),
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the layer's NVFP4 scales.

        Args:
            layer: Module holding the weight, global-weight and global-input
                scale parameters.

        Returns:
            Whatever ``make_nvfp4_moe_quant_config`` produces for the selected
            backend (possibly ``None``).
        """
        return make_nvfp4_moe_quant_config(
            backend=self.nvfp4_backend,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w13_scale_2=layer.w13_weight_scale_2,
            w2_scale_2=layer.w2_weight_scale_2,
            a13_scale=layer.w13_input_scale_2,
            a2_scale=layer.w2_input_scale_2,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: Any | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the experts through the modular kernel.

        Args:
            layer: MoE layer providing weights and routing settings.
            x: Input activations.
            topk_weights: Router weights of the selected experts.
            topk_ids: Indices of the selected experts.
            shared_experts_input: Optional input forwarded for shared experts.

        Returns:
            The output of the modular kernel's ``apply``.
        """
        # The kernel is created in `process_weights_after_loading`.
        assert self.moe_mk is not None
        return self.moe_mk.apply(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .quark_nvfp4 import QuarkNVFP4
from .quark_ocp_mx import QuarkOCP_MX
from .quark_scheme import QuarkScheme
from .quark_w4a8_mxfp4_fp8 import QuarkW4A8_MXFP4_FP8
@@ -13,4 +14,5 @@ __all__ = [
"QuarkW8A8Int8",
"QuarkOCP_MX",
"QuarkW4A8_MXFP4_FP8",
"QuarkNVFP4",
]
@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import init_nvfp4_linear_kernel
from vllm.model_executor.kernels.linear.nvfp4.emulation import (
EmulationNvFp4LinearKernel,
)
from vllm.model_executor.layers.quantization.quark.schemes.quark_scheme import (
QuarkScheme,
)
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
__all__ = ["QuarkNVFP4"]
logger = init_logger(__name__)
class QuarkNVFP4(QuarkScheme):
    """
    Linear-layer scheme for Quark NVFP4 checkpoints.

    Expected checkpoint tensors:
    - weight: uint8, shape [out_features, in_features // 2] (packed FP4)
    - weight_scale: float8_e4m3fn, shape [out_features, in_features // group_size]
    - weight_scale_2: bfloat16/float32, scalar (global weight scale)
    - input_scale_2: bfloat16/float32, scalar (global input scale)
    """

    def __init__(
        self,
    ):
        """Pick the NVFP4 linear kernel and warn when it is not the emulation one."""
        self.kernel = init_nvfp4_linear_kernel()
        self.group_size = 16

        if isinstance(self.kernel, EmulationNvFp4LinearKernel):
            return

        # Any other kernel has not been validated with this scheme.
        logger.warning_once(
            "Only EmulationNvFp4LinearKernel NVFP4 dense implementation is "
            "tested with QuarkNVFP4, got kernel=%s. Correctness is not validated.",
            type(self.kernel).__name__,
        )

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability (75) accepted by this scheme."""
        # 75 is described as Turing in the original implementation.
        return 75

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register the packed FP4 weight and its scale parameters on ``layer``.

        Args:
            layer: Module receiving the parameters.
            output_partition_sizes: Output width of each fused logical layer.
            input_size_per_partition: Input width on this partition; must be
                a multiple of the group size.
            params_dtype: Unused here; parameters have fixed dtypes.
            weight_loader: Loader callback attached to every parameter.
            **kwargs: Ignored.

        Raises:
            ValueError: If ``input_size_per_partition`` is not divisible by
                the group size.
        """
        total_out = sum(output_partition_sizes)
        num_partitions = len(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_out

        if input_size_per_partition % self.group_size != 0:
            raise ValueError(
                f"Input size per partition ({input_size_per_partition}) must be "
                f"divisible by group size ({self.group_size})"
            )

        # Two FP4 values share one uint8, hence half the input width.
        packed_weight = torch.empty(
            total_out,
            input_size_per_partition // 2,
            dtype=torch.uint8,
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=packed_weight,
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            ),
        )

        # One FP8 E4M3 scale per group of `group_size` input elements.
        group_scales = torch.empty(
            total_out,
            input_size_per_partition // self.group_size,
            dtype=torch.float8_e4m3fn,
        )
        layer.register_parameter(
            "weight_scale",
            GroupQuantScaleParameter(
                data=group_scales,
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            ),
        )

        # One FP32 global weight scale per fused logical layer.
        layer.register_parameter(
            "weight_scale_2",
            PerTensorScaleParameter(
                data=torch.empty(num_partitions, dtype=torch.float32),
                weight_loader=weight_loader,
            ),
        )

        # One FP32 global input scale per fused logical layer.
        layer.register_parameter(
            "input_scale_2",
            PerTensorScaleParameter(
                data=torch.empty(num_partitions, dtype=torch.float32),
                weight_loader=weight_loader,
            ),
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse per-partition global scales and prepare the kernel format.

        Replaces ``input_scale_2``/``weight_scale_2`` with scalar
        ``input_global_scale``/``weight_global_scale`` (their maxima), derives
        ``alpha`` and ``input_global_scale_inv`` from them, then lets the
        kernel post-process the layer.
        """
        collapsed_input_scale = layer.input_scale_2.max().to(torch.float32)
        layer.input_global_scale = Parameter(
            collapsed_input_scale, requires_grad=False
        )
        del layer.input_scale_2

        per_partition_weight_scale = layer.weight_scale_2.to(torch.float32)
        scales_differ = torch.unique(per_partition_weight_scale).numel() != 1
        if scales_differ:
            logger.warning_once(
                "In NVFP4 linear, the global scale for weight are different"
                " for parallel layers (e.g. q_proj, k_proj, v_proj). This"
                " will likely result in reduced accuracy. Please verify the"
                " model accuracy. Consider using a checkpoint with a shared"
                " global NVFP4 scale for fused layers."
            )

        layer.weight_global_scale = Parameter(
            per_partition_weight_scale.max(), requires_grad=False
        )
        del layer.weight_scale_2

        combined_scale = layer.input_global_scale * layer.weight_global_scale
        layer.alpha = Parameter(combined_scale, requires_grad=False)

        inverse_input_scale = (1.0 / layer.input_global_scale).to(torch.float32)
        layer.input_global_scale_inv = Parameter(
            inverse_input_scale, requires_grad=False
        )

        # Hand over to the kernel for its own weight layout conversion.
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the linear layer on ``x`` by delegating to the selected kernel."""
        output = self.kernel.apply_weights(layer=layer, x=x, bias=bias)
        return output
@@ -17,7 +17,8 @@ def deep_compare(dict1: Any, dict2: Any) -> bool:
return False
return all(deep_compare(dict1[k], dict2[k]) for k in dict1)
elif isinstance(dict1, list):
return set(dict1) == set(dict2)
# `dict1` may be a list of dict.
return all(deep_compare(dict1[i], dict2[i]) for i in range(len(dict1)))
else:
return dict1 == dict2