mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Quark] Support loading Quark NVFP4 checkpoints in vLLM (#35859)
Signed-off-by: Felix Marty <Felix.Marty@amd.com> Signed-off-by: fxmarty-amd <felmarty@amd.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
@@ -240,8 +240,13 @@ WIKITEXT_ACCURACY_CONFIGS = [
|
||||
not QUARK_MXFP4_AVAILABLE,
|
||||
reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
|
||||
)
|
||||
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
|
||||
@pytest.mark.parametrize("tp_size", [1, 2])
|
||||
@pytest.mark.parametrize(
|
||||
"config",
|
||||
[pytest.param(val, id=f"config:{val}") for val in WIKITEXT_ACCURACY_CONFIGS],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"tp_size", [pytest.param(val, id=f"tp_size:{val}") for val in [1, 2]]
|
||||
)
|
||||
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
|
||||
device_count = torch.accelerator.device_count()
|
||||
if device_count < tp_size:
|
||||
@@ -268,6 +273,53 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE,
    reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize(
    "tp_size", [pytest.param(val, id=f"tp_size:{val}") for val in [1, 2]]
)
def test_nvfp4_wikitext_correctness(tp_size: int):
    """Check wikitext word perplexity of a Quark NVFP4 checkpoint.

    The checkpoint ``amd-quark/Qwen3-30B-A3B-nvfp4-quark`` is evaluated
    through ``lm_eval.simple_evaluate`` with the ``vllm`` backend, and the
    measured ``word_perplexity,none`` metric must lie strictly within an
    absolute window around the reference value.

    Args:
        tp_size: Tensor-parallel size forwarded to ``get_model_args``. The
            test is skipped when fewer accelerators than this are visible.
    """
    device_count = torch.accelerator.device_count()
    if device_count < tp_size:
        pytest.skip(f"This test requires >={tp_size} gpus, got only {device_count}")

    # NOTE: expected_value from nvidia/Qwen3-30B-A3B-NVFP4
    expected_value = 11.2391

    model_name = "amd-quark/Qwen3-30B-A3B-nvfp4-quark"
    task = "wikitext"

    # Absolute (not relative) tolerance on the perplexity: the assertion
    # below adds/subtracts it from the expected value directly.
    atol = 0.25

    # `excepted_value` (sic) is the keyword AccuracyTestConfig is built with
    # and the attribute read back below; the spelling must be kept.
    config = AccuracyTestConfig(
        model_name=model_name,
        excepted_value=expected_value,
    )

    # Smaller cudagraph_capture_sizes to speed up the test.
    model_args = config.get_model_args(
        tp_size=tp_size,
        kwargs={
            "cudagraph_capture_sizes": [16],
        },
    )
    # This test evaluates without the `add_bos_token` model argument.
    model_args.pop("add_bos_token")

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks=task,
        batch_size=64,
    )

    EXPECTED_VALUE = config.excepted_value
    measured_value = results["results"][task]["word_perplexity,none"]
    assert (
        measured_value < EXPECTED_VALUE + atol
        and measured_value > EXPECTED_VALUE - atol
    ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
|
||||
@pytest.mark.skipif(
|
||||
not QUARK_MXFP4_AVAILABLE,
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E
|
||||
QuarkMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import (
|
||||
QuarkNVFP4,
|
||||
QuarkOCP_MX,
|
||||
QuarkScheme,
|
||||
QuarkW4A8_MXFP4_FP8,
|
||||
@@ -395,6 +396,54 @@ class QuarkConfig(QuantizationConfig):
|
||||
and is_weight_symmetric
|
||||
)
|
||||
|
||||
def _is_nvfp4(
|
||||
self,
|
||||
weight_quant: dict[str, Any] | list[dict[str, Any]] | None,
|
||||
input_quant: dict[str, Any] | list[dict[str, Any]] | None,
|
||||
) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
# Confirm both weight_quant and input_quant are lists with 2 elements
|
||||
if not isinstance(weight_quant, list) or len(weight_quant) != 2:
|
||||
return False
|
||||
if not isinstance(input_quant, list) or len(input_quant) != 2:
|
||||
return False
|
||||
|
||||
# First element should be fp4 with per_group quantization
|
||||
is_fp4_per_group_weight = (
|
||||
weight_quant[0].get("dtype") == "fp4"
|
||||
and weight_quant[0].get("qscheme") == "per_group"
|
||||
and weight_quant[0].get("group_size") == 16
|
||||
and not weight_quant[0].get("is_dynamic")
|
||||
)
|
||||
is_fp4_per_group_input = (
|
||||
input_quant[0].get("dtype") == "fp4"
|
||||
and input_quant[0].get("qscheme") == "per_group"
|
||||
and input_quant[0].get("group_size") == 16
|
||||
and input_quant[0].get("is_dynamic")
|
||||
)
|
||||
|
||||
# Second element should be fp8_e4m3 with per_tensor quantization
|
||||
is_fp8_per_tensor_weight = (
|
||||
weight_quant[1].get("dtype") == "fp8_e4m3"
|
||||
and weight_quant[1].get("qscheme") == "per_tensor"
|
||||
and not weight_quant[1].get("is_dynamic")
|
||||
)
|
||||
is_fp8_per_tensor_input = (
|
||||
input_quant[1].get("dtype") == "fp8_e4m3"
|
||||
and input_quant[1].get("qscheme") == "per_tensor"
|
||||
and not input_quant[1].get("is_dynamic")
|
||||
)
|
||||
|
||||
return (
|
||||
is_fp4_per_group_weight # type: ignore[return-value]
|
||||
and is_fp4_per_group_input
|
||||
and is_fp8_per_tensor_weight
|
||||
and is_fp8_per_tensor_input
|
||||
)
|
||||
|
||||
def _is_w_ocp_mx_a_x(
|
||||
self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None
|
||||
) -> bool:
|
||||
@@ -543,7 +592,9 @@ class QuarkConfig(QuantizationConfig):
|
||||
weight_config = cast(dict[str, Any], config.get("weight"))
|
||||
input_config = cast(dict[str, Any], config.get("input_tensors"))
|
||||
|
||||
if self._is_fp8_w8a8(weight_config, input_config):
|
||||
if self._is_nvfp4(weight_config, input_config):
|
||||
return QuarkNVFP4()
|
||||
elif self._is_fp8_w8a8(weight_config, input_config):
|
||||
is_fp8_w8a8_supported = self._check_scheme_supported(
|
||||
QuarkW8A8Fp8.get_min_capability(), error=False
|
||||
)
|
||||
|
||||
@@ -39,6 +39,12 @@ from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
|
||||
mxfp4_round_up_hidden_size_and_intermediate_size,
|
||||
select_mxfp4_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
|
||||
convert_to_nvfp4_moe_kernel_format,
|
||||
make_nvfp4_moe_kernel,
|
||||
make_nvfp4_moe_quant_config,
|
||||
select_nvfp4_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_fp8_moe_layer_for_marlin,
|
||||
)
|
||||
@@ -49,6 +55,8 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Dynamic,
|
||||
kNvfp4Static,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d,
|
||||
@@ -63,7 +71,9 @@ logger = init_logger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"QuarkMoEMethod",
|
||||
"QuarkW8A8Fp8MoEMethod",
|
||||
"QuarkOCP_MX_MoEMethod",
|
||||
"QuarkNvfp4MoEMethod",
|
||||
]
|
||||
|
||||
|
||||
@@ -92,6 +102,10 @@ class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
if quant_config._is_fp8_w4a8(weight_config, input_config):
|
||||
return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
|
||||
elif quant_config._is_nvfp4(weight_config, input_config):
|
||||
return QuarkNvfp4MoEMethod(
|
||||
weight_config, input_config, module.moe_config, quant_config
|
||||
)
|
||||
elif quant_config._is_fp8_w8a8(weight_config, input_config):
|
||||
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
|
||||
elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config):
|
||||
@@ -1501,3 +1515,228 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
class QuarkNvfp4MoEMethod(QuarkMoEMethod):
    """Fused-MoE method for Quark NVFP4 checkpoints.

    Expert weights are stored as ``uint8`` with two FP4 values packed per
    byte along the input dimension, with one ``float8_e4m3fn`` scale per
    group of 16 input elements plus ``float32`` global scales for weights
    and inputs. After loading, the tensors are converted to the layout of
    the backend picked by ``select_nvfp4_moe_backend`` and a modular kernel
    is built for ``apply``.
    """

    def __init__(
        self,
        weight_config: dict[str, Any],
        input_config: dict[str, Any],
        moe: FusedMoEConfig,
        quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
    ):
        """Store the quantization specs and select the experts backend.

        Args:
            weight_config: Weight quantization spec from the Quark config.
            input_config: Input quantization spec from the Quark config.
            moe: Fused-MoE configuration, forwarded to the base class.
            quant_config: The owning Quark quantization config; attached to
                the layer in ``create_weights``.
        """
        super().__init__(moe)
        self.weight_quant = weight_config
        self.input_quant = input_config
        self.quant_config = quant_config
        # Number of input elements sharing one per-group weight scale.
        self.group_size = 16

        # Select experts implementation.
        self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
            config=self.moe,
            weight_key=kNvfp4Static,
            activation_key=kNvfp4Dynamic,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the NVFP4 expert parameters on ``layer``.

        Registers ``w13_weight``/``w2_weight`` (packed FP4),
        ``w13_weight_scale``/``w2_weight_scale`` (per-group FP8 scales),
        ``w13_weight_scale_2``/``w2_weight_scale_2`` (global weight scales)
        and ``w13_input_scale_2``/``w2_input_scale_2`` (global input scales).

        Args:
            layer: Module that receives the parameters.
            num_experts: Number of experts held by this layer.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Expert intermediate size on
                this partition.
            params_dtype: Recorded on the layer as ``params_dtype``; the
                quantized tensors use fixed dtypes.
            **extra_weight_attrs: Attributes attached to every parameter via
                ``set_weight_attrs``. ``quant_method`` is overridden for the
                scale parameters.
        """
        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        # w13 holds two stacked shards when the activation is "act and mul",
        # a single one otherwise.
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        # GEMM 1 - w13 weight
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # GEMM 2 - w2 weight
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # Weight scales (per-group FP8 scales)
        w13_weight_scale = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                hidden_size // self.group_size,
                dtype=weight_scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        # From here on the scale parameters are tagged as group-quantized.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        w2_weight_scale = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // self.group_size,
                dtype=weight_scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # Global weight scales (per-tensor FP32 scales)
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        w13_weight_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
        set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)

        w2_weight_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
        set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)

        # Input global scales (per-tensor FP32 scales)
        w13_input_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_input_scale_2", w13_input_scale_2)
        set_weight_attrs(w13_input_scale_2, extra_weight_attrs)

        w2_input_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_input_scale_2", w2_input_scale_2)
        set_weight_attrs(w2_input_scale_2, extra_weight_attrs)

    def process_weights_after_loading(self, layer: FusedMoE) -> None:
        """
        Convert NVFP4 MoE weights into kernel format and setup the kernel.

        Raises:
            ValueError: If the two w13 shards carry different global weight
                scales (only checked when ``is_act_and_mul`` is set, i.e.
                when two shards exist).
        """
        per_shard_scale_2 = layer.w13_weight_scale_2
        if self.moe.is_act_and_mul:
            # Two shards were allocated in `create_weights`; they must share
            # one global scale because a single value is handed downstream.
            w1_scale_2 = per_shard_scale_2[:, 0]
            w3_scale_2 = per_shard_scale_2[:, 1]
            if not torch.allclose(w1_scale_2, w3_scale_2):
                raise ValueError(
                    "Different global scales for w1 and w3 is not supported."
                )

            # Use a single gscale for w13
            w13_weight_scale_2 = torch.maximum(w1_scale_2, w3_scale_2).contiguous()
        else:
            # Only one shard (column) exists; indexing column 1 would raise
            # IndexError, so take the single global scale directly.
            w13_weight_scale_2 = per_shard_scale_2[:, 0].contiguous()

        w2_weight_scale_2 = layer.w2_weight_scale_2

        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = convert_to_nvfp4_moe_kernel_format(
            nvfp4_backend=self.nvfp4_backend,
            layer=layer,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w13_scale_2=w13_weight_scale_2,
            a13_scale=layer.w13_input_scale_2,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            w2_scale_2=w2_weight_scale_2,
            a2_scale=layer.w2_input_scale_2,
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        # Swap the loaded parameters for their kernel-format counterparts.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w13_weight_scale_2", w13_scale_2)
        replace_parameter(layer, "w13_input_scale_2", a13_scale)

        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w2_weight_scale", w2_scale)
        replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
        replace_parameter(layer, "w2_input_scale_2", a2_scale)

        # Setup modular kernel.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk = make_nvfp4_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                experts_cls=self.experts_cls,
                shared_experts=layer.shared_experts,
                routing_tables=layer._maybe_init_expert_routing_tables(),
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the scales stored on ``layer``.

        Reads the (already converted) weight and input scale parameters and
        forwards them, with the selected backend, to
        ``make_nvfp4_moe_quant_config``.
        """
        return make_nvfp4_moe_quant_config(
            backend=self.nvfp4_backend,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w13_scale_2=layer.w13_weight_scale_2,
            w2_scale_2=layer.w2_weight_scale_2,
            a13_scale=layer.w13_input_scale_2,
            a2_scale=layer.w2_input_scale_2,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: Any | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the experts through the modular kernel built after loading.

        Args:
            layer: MoE layer holding the converted weights and routing state.
            x: Input activations.
            topk_weights: Routing weights of the selected experts.
            topk_ids: Indices of the selected experts.
            shared_experts_input: Forwarded unchanged to the kernel.

        Returns:
            Whatever ``self.moe_mk.apply`` returns.
        """
        # The kernel only exists once `process_weights_after_loading` has run
        # and produced a quant config.
        assert self.moe_mk is not None
        return self.moe_mk.apply(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .quark_nvfp4 import QuarkNVFP4
|
||||
from .quark_ocp_mx import QuarkOCP_MX
|
||||
from .quark_scheme import QuarkScheme
|
||||
from .quark_w4a8_mxfp4_fp8 import QuarkW4A8_MXFP4_FP8
|
||||
@@ -13,4 +14,5 @@ __all__ = [
|
||||
"QuarkW8A8Int8",
|
||||
"QuarkOCP_MX",
|
||||
"QuarkW4A8_MXFP4_FP8",
|
||||
"QuarkNVFP4",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,154 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import init_nvfp4_linear_kernel
|
||||
from vllm.model_executor.kernels.linear.nvfp4.emulation import (
|
||||
EmulationNvFp4LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes.quark_scheme import (
|
||||
QuarkScheme,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
__all__ = ["QuarkNVFP4"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class QuarkNVFP4(QuarkScheme):
    """
    Linear-layer scheme for Quark NVFP4 checkpoints.

    Checkpoint tensors consumed per layer:
    - weight: uint8, shape [out_features, in_features // 2] (packed FP4)
    - weight_scale: float8_e4m3fn, shape [out_features, in_features // group_size]
    - weight_scale_2: bfloat16/float32, scalar (global weight scale)
    - input_scale_2: bfloat16/float32, scalar (global input scale)
    """

    def __init__(
        self,
    ):
        # Each per-group weight scale covers this many input elements.
        self.group_size = 16
        self.kernel = init_nvfp4_linear_kernel()

        if isinstance(self.kernel, EmulationNvFp4LinearKernel):
            return

        # Any other kernel is accepted, but flagged once as unvalidated.
        logger.warning_once(
            "Only EmulationNvFp4LinearKernel NVFP4 dense implementation is "
            "tested with QuarkNVFP4, got kernel=%s. Correctness is not validated.",
            type(self.kernel).__name__,
        )

    @classmethod
    def get_min_capability(cls) -> int:
        # FP4 requires Turing (75) or newer
        return 75

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register the packed weight and its three scale parameters on `layer`.

        Raises ValueError when `input_size_per_partition` is not a multiple
        of the group size.
        """
        out_rows = sum(output_partition_sizes)
        num_partitions = len(output_partition_sizes)

        # Bookkeeping attributes are recorded before validation.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_rows

        if input_size_per_partition % self.group_size != 0:
            raise ValueError(
                f"Input size per partition ({input_size_per_partition}) must be "
                f"divisible by group size ({self.group_size})"
            )

        # Weight: FP4 packed as uint8 (2 FP4 values per uint8)
        packed_weight = torch.empty(
            out_rows, input_size_per_partition // 2, dtype=torch.uint8
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=packed_weight,
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            ),
        )

        # Per-group weight scale (FP8 E4M3)
        group_scales = torch.empty(
            out_rows,
            input_size_per_partition // self.group_size,
            dtype=torch.float8_e4m3fn,
        )
        layer.register_parameter(
            "weight_scale",
            GroupQuantScaleParameter(
                data=group_scales,
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            ),
        )

        # Global weight scale, then global input scale: one float32 entry per
        # output partition each.
        for scale_name in ("weight_scale_2", "input_scale_2"):
            global_scale = PerTensorScaleParameter(
                data=torch.empty(num_partitions, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter(scale_name, global_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse per-partition global scales and hand the layer to the kernel."""
        fp32 = torch.float32

        # Input side: keep the largest per-partition global scale.
        layer.input_global_scale = Parameter(
            layer.input_scale_2.max().to(fp32), requires_grad=False
        )
        del layer.input_scale_2

        # Weight side: warn once if the partitions disagree, then keep the max.
        per_partition_scale = layer.weight_scale_2.to(fp32)
        scales_differ = torch.unique(per_partition_scale).numel() != 1
        if scales_differ:
            logger.warning_once(
                "In NVFP4 linear, the global scale for weight are different"
                " for parallel layers (e.g. q_proj, k_proj, v_proj). This"
                " will likely result in reduced accuracy. Please verify the"
                " model accuracy. Consider using a checkpoint with a shared"
                " global NVFP4 scale for fused layers."
            )

        layer.weight_global_scale = Parameter(
            per_partition_scale.max(), requires_grad=False
        )
        del layer.weight_scale_2

        # Derived quantities: product of both global scales and the
        # reciprocal of the input global scale.
        combined_scale = layer.input_global_scale * layer.weight_global_scale
        layer.alpha = Parameter(combined_scale, requires_grad=False)

        inverse_input_scale = (1.0 / layer.input_global_scale).to(fp32)
        layer.input_global_scale_inv = Parameter(
            inverse_input_scale, requires_grad=False
        )

        # Convert layer to NVFP4 linear kernel format
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Delegate the quantized matmul to the selected NVFP4 kernel."""
        return self.kernel.apply_weights(layer=layer, x=x, bias=bias)
|
||||
@@ -17,7 +17,8 @@ def deep_compare(dict1: Any, dict2: Any) -> bool:
|
||||
return False
|
||||
return all(deep_compare(dict1[k], dict2[k]) for k in dict1)
|
||||
elif isinstance(dict1, list):
|
||||
return set(dict1) == set(dict2)
|
||||
# `dict1` may be a list of dict.
|
||||
return all(deep_compare(dict1[i], dict2[i]) for i in range(len(dict1)))
|
||||
else:
|
||||
return dict1 == dict2
|
||||
|
||||
|
||||
Reference in New Issue
Block a user