Refactor CT NVFP4 linear to use a single class (#42443)

This commit is contained in:
Dipika Sikka
2026-06-04 08:25:08 -04:00
committed by GitHub
parent 4b87b3e845
commit e68988a248
6 changed files with 55 additions and 162 deletions
+5 -17
View File
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsW4A4Fp4,
CompressedTensorsW4A4Mxfp4,
CompressedTensorsW4A8Fp8,
CompressedTensorsW4A16Fp4,
CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8,
CompressedTensorsW8A8Mxfp8,
@@ -37,9 +36,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target,
)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
cutlass_fp4_supported,
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
@@ -376,13 +372,12 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
@pytest.mark.parametrize(
"args",
[
# TODO: Enable once model is available again
# ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", CompressedTensorsW4A16Fp4),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", True),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", False),
],
)
def test_compressed_tensors_nvfp4(vllm_runner, args):
model, scheme = args
model, use_a16 = args
with vllm_runner(model, enforce_eager=True) as llm:
def check_model(model):
@@ -390,15 +385,8 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
if (
isinstance(qkv_proj.scheme, scheme)
or isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
and not cutlass_fp4_supported()
):
assert True
else:
raise AssertionError("FP4 Scheme Mismatch")
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A4Fp4)
assert qkv_proj.scheme.use_a16 == use_a16
assert qkv_proj.scheme.group_size == 16
llm.apply_model(check_model)
@@ -842,7 +842,7 @@ _NVFP4_BACKEND_TO_KERNEL: dict[str, type[NvFp4LinearKernel]] = {
}
def init_nvfp4_linear_kernel() -> NvFp4LinearKernel:
def init_nvfp4_linear_kernel(use_a16: bool = False) -> NvFp4LinearKernel:
"""Select and instantiate the best NVFP4 linear kernel for the
current platform."""
config = NvFp4LinearLayerConfig()
@@ -885,7 +885,9 @@ def init_nvfp4_linear_kernel() -> NvFp4LinearKernel:
elif linear_backend == "auto":
# Deprecated env-var overrides — only honoured when --linear-backend
# is "auto". Deprecation warnings are emitted from vllm/envs.py.
if envs.VLLM_USE_FBGEMM:
if use_a16: # force a16 if running weight-only quantization
force_kernel = MarlinNvFp4LinearKernel
elif envs.VLLM_USE_FBGEMM:
force_kernel = FbgemmNvFp4LinearKernel
elif envs.VLLM_USE_NVFP4_CT_EMULATIONS:
force_kernel = EmulationNvFp4LinearKernel
@@ -40,7 +40,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW4A4Mxfp4,
CompressedTensorsW4A8Fp8,
CompressedTensorsW4A8Int,
CompressedTensorsW4A16Fp4,
CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8,
CompressedTensorsW8A8Mxfp8,
@@ -616,8 +615,16 @@ class CompressedTensorsConfig(QuantizationConfig):
format = format if format is not None else self.quant_format
# Detect If Mixed Precision
if self._is_nvfp4_format(weight_quant) and input_quant is None:
return CompressedTensorsW4A16Fp4()
if self._is_nvfp4_format(weight_quant):
if input_quant is None:
return CompressedTensorsW4A4Fp4(use_a16=True)
if not self._is_nvfp4_format(input_quant):
raise ValueError(
"For NVFP4 weights, input quantization must also be NVFP4 format, ",
"None for NVFP4A16",
)
return CompressedTensorsW4A4Fp4()
if self._is_mxfp4(weight_quant):
return CompressedTensorsW4A4Mxfp4()
@@ -650,11 +657,6 @@ class CompressedTensorsConfig(QuantizationConfig):
act_quant_format = is_activation_quantization_format(format)
if act_quant_format:
if self._is_nvfp4_format(weight_quant) and self._is_nvfp4_format(
input_quant
):
return CompressedTensorsW4A4Fp4()
if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(
CompressedTensorsW8A8Fp8.get_min_capability(), error=False
@@ -6,7 +6,6 @@ from .compressed_tensors_w4a4_mxfp4 import CompressedTensorsW4A4Mxfp4
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a8_mxfp8 import CompressedTensorsW8A8Mxfp8
@@ -20,7 +19,6 @@ __all__ = [
"CompressedTensorsW8A8Int8",
"CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS",
"CompressedTensorsW4A16Fp4",
"CompressedTensorsW4A4Mxfp4",
"CompressedTensorsW4A4Fp4",
"CompressedTensorsW4A8Int",
@@ -1,109 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear,
prepare_fp4_layer_for_marlin,
)
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
__all__ = ["CompressedTensorsW4A16Fp4"]
class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
def __init__(self):
self.group_size = 16
@classmethod
def get_min_capability(cls) -> int:
# don't restrict as emulations
return 75
def create_weights(
self,
layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs,
):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
# Weight
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 2,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_packed", weight)
# Global Weight Scale
weight_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("weight_global_scale", weight_global_scale)
# Per Group Weight Scale
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // self.group_size,
dtype=torch.float8_e4m3fn,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Process parameters for marlin repacking
# Rename weight_packed to weight that marlin expects
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
del layer.weight_packed
# ct stores the inverse of what is expected by the marlin kernel
layer.weight_global_scale = Parameter(
1.0 / layer.weight_global_scale.max().to(torch.float32), requires_grad=False
)
prepare_fp4_layer_for_marlin(layer)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return apply_fp4_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_global_scale=layer.weight_global_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
@@ -23,8 +23,9 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
def __init__(self):
self.kernel = init_nvfp4_linear_kernel()
def __init__(self, use_a16: bool = False):
self.use_a16 = use_a16
self.kernel = init_nvfp4_linear_kernel(use_a16=use_a16)
self.group_size = 16
@classmethod
@@ -79,46 +80,57 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
layer.register_parameter("weight_scale", weight_scale)
input_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("input_global_scale", input_global_scale)
if not self.use_a16:
input_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("input_global_scale", input_global_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Rename CT checkpoint names to standardized names
layer.weight = layer.weight_packed
del layer.weight_packed
if (
torch.unique(layer.input_global_scale).numel() != 1
or torch.unique(layer.weight_global_scale).numel() != 1
):
# Check for mismatched weight global scales
if torch.unique(layer.weight_global_scale).numel() != 1:
logger.warning_once(
"In NVFP4 linear, the global scale for input or weight are different"
"In NVFP4 linear, the weight global scale is different"
" for parallel layers (e.g. q_proj, k_proj, v_proj). This "
" will likely result in reduced accuracy. Please verify the model"
" accuracy. Consider using a checkpoint with a shared global NVFP4"
" scale for fused layers."
)
# Process global scales (CT stores as divisors, i.e. 1/scale)
input_global_scale_inv = layer.input_global_scale.max().to(torch.float32)
layer.input_global_scale = Parameter(
(1.0 / input_global_scale_inv).to(torch.float32), requires_grad=False
)
# Process weight global scale (CT stores as divisors, i.e. 1/scale)
weight_global_scale = layer.weight_global_scale.max().to(torch.float32)
layer.weight_global_scale = Parameter(
1.0 / weight_global_scale, requires_grad=False
)
# Pre-compute alpha and inverse for runtime quantization
layer.input_global_scale_inv = Parameter(
input_global_scale_inv, requires_grad=False
)
layer.alpha = Parameter(
layer.input_global_scale * layer.weight_global_scale, requires_grad=False
)
if not self.use_a16:
if torch.unique(layer.input_global_scale).numel() != 1:
logger.warning_once(
"In NVFP4 linear, the input global scale is different"
" for parallel layers (e.g. q_proj, k_proj, v_proj). This "
" will likely result in reduced accuracy. Please verify the model"
" accuracy. Consider using a checkpoint with a shared global NVFP4"
" scale for fused layers."
)
# Process input global scale and pre-compute alpha for W4A4 mode
input_global_scale_inv = layer.input_global_scale.max().to(torch.float32)
layer.input_global_scale = Parameter(
(1.0 / input_global_scale_inv).to(torch.float32), requires_grad=False
)
# Pre-compute alpha and inverse for runtime quantization
layer.input_global_scale_inv = Parameter(
input_global_scale_inv, requires_grad=False
)
layer.alpha = Parameter(
layer.input_global_scale * layer.weight_global_scale,
requires_grad=False,
)
# Convert layer to NVFP4 linear kernel format
self.kernel.process_weights_after_loading(layer)