mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Refactor CT NVFP4 linear to use a single class (#42443)
This commit is contained in:
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
||||
CompressedTensorsW4A4Fp4,
|
||||
CompressedTensorsW4A4Mxfp4,
|
||||
CompressedTensorsW4A8Fp8,
|
||||
CompressedTensorsW4A16Fp4,
|
||||
CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A8Mxfp8,
|
||||
@@ -37,9 +36,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
find_matched_target,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
cutlass_fp4_supported,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
||||
@@ -376,13 +372,12 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
|
||||
@pytest.mark.parametrize(
|
||||
"args",
|
||||
[
|
||||
# TODO: Enable once model is available again
|
||||
# ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", CompressedTensorsW4A16Fp4),
|
||||
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4),
|
||||
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", True),
|
||||
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", False),
|
||||
],
|
||||
)
|
||||
def test_compressed_tensors_nvfp4(vllm_runner, args):
|
||||
model, scheme = args
|
||||
model, use_a16 = args
|
||||
with vllm_runner(model, enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
@@ -390,15 +385,8 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
|
||||
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
||||
if (
|
||||
isinstance(qkv_proj.scheme, scheme)
|
||||
or isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
|
||||
and not cutlass_fp4_supported()
|
||||
):
|
||||
assert True
|
||||
else:
|
||||
raise AssertionError("FP4 Scheme Mismatch")
|
||||
|
||||
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A4Fp4)
|
||||
assert qkv_proj.scheme.use_a16 == use_a16
|
||||
assert qkv_proj.scheme.group_size == 16
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
@@ -842,7 +842,7 @@ _NVFP4_BACKEND_TO_KERNEL: dict[str, type[NvFp4LinearKernel]] = {
|
||||
}
|
||||
|
||||
|
||||
def init_nvfp4_linear_kernel() -> NvFp4LinearKernel:
|
||||
def init_nvfp4_linear_kernel(use_a16: bool = False) -> NvFp4LinearKernel:
|
||||
"""Select and instantiate the best NVFP4 linear kernel for the
|
||||
current platform."""
|
||||
config = NvFp4LinearLayerConfig()
|
||||
@@ -885,7 +885,9 @@ def init_nvfp4_linear_kernel() -> NvFp4LinearKernel:
|
||||
elif linear_backend == "auto":
|
||||
# Deprecated env-var overrides — only honoured when --linear-backend
|
||||
# is "auto". Deprecation warnings are emitted from vllm/envs.py.
|
||||
if envs.VLLM_USE_FBGEMM:
|
||||
if use_a16: # force a16 if running weight-only quantization
|
||||
force_kernel = MarlinNvFp4LinearKernel
|
||||
elif envs.VLLM_USE_FBGEMM:
|
||||
force_kernel = FbgemmNvFp4LinearKernel
|
||||
elif envs.VLLM_USE_NVFP4_CT_EMULATIONS:
|
||||
force_kernel = EmulationNvFp4LinearKernel
|
||||
|
||||
@@ -40,7 +40,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsW4A4Mxfp4,
|
||||
CompressedTensorsW4A8Fp8,
|
||||
CompressedTensorsW4A8Int,
|
||||
CompressedTensorsW4A16Fp4,
|
||||
CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A8Mxfp8,
|
||||
@@ -616,8 +615,16 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
format = format if format is not None else self.quant_format
|
||||
|
||||
# Detect If Mixed Precision
|
||||
if self._is_nvfp4_format(weight_quant) and input_quant is None:
|
||||
return CompressedTensorsW4A16Fp4()
|
||||
if self._is_nvfp4_format(weight_quant):
|
||||
if input_quant is None:
|
||||
return CompressedTensorsW4A4Fp4(use_a16=True)
|
||||
|
||||
if not self._is_nvfp4_format(input_quant):
|
||||
raise ValueError(
|
||||
"For NVFP4 weights, input quantization must also be NVFP4 format, ",
|
||||
"None for NVFP4A16",
|
||||
)
|
||||
return CompressedTensorsW4A4Fp4()
|
||||
|
||||
if self._is_mxfp4(weight_quant):
|
||||
return CompressedTensorsW4A4Mxfp4()
|
||||
@@ -650,11 +657,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
act_quant_format = is_activation_quantization_format(format)
|
||||
if act_quant_format:
|
||||
if self._is_nvfp4_format(weight_quant) and self._is_nvfp4_format(
|
||||
input_quant
|
||||
):
|
||||
return CompressedTensorsW4A4Fp4()
|
||||
|
||||
if self._is_fp8_w8a8(weight_quant, input_quant):
|
||||
is_fp8_w8a8_supported = self._check_scheme_supported(
|
||||
CompressedTensorsW8A8Fp8.get_min_capability(), error=False
|
||||
|
||||
@@ -6,7 +6,6 @@ from .compressed_tensors_w4a4_mxfp4 import CompressedTensorsW4A4Mxfp4
|
||||
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
|
||||
from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8
|
||||
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
|
||||
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
|
||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
||||
from .compressed_tensors_w8a8_mxfp8 import CompressedTensorsW8A8Mxfp8
|
||||
@@ -20,7 +19,6 @@ __all__ = [
|
||||
"CompressedTensorsW8A8Int8",
|
||||
"CompressedTensorsW8A8Fp8",
|
||||
"WNA16_SUPPORTED_BITS",
|
||||
"CompressedTensorsW4A16Fp4",
|
||||
"CompressedTensorsW4A4Mxfp4",
|
||||
"CompressedTensorsW4A4Fp4",
|
||||
"CompressedTensorsW4A8Int",
|
||||
|
||||
-109
@@ -1,109 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
apply_fp4_marlin_linear,
|
||||
prepare_fp4_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A16Fp4"]
|
||||
|
||||
|
||||
class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
    """Compressed-tensors scheme that runs NVFP4 weights through FP4 Marlin.

    ``create_weights`` registers the checkpoint tensors under their
    compressed-tensors names, ``process_weights_after_loading`` converts them
    to what the Marlin helpers expect, and ``apply_weights`` performs the
    matmul via ``apply_fp4_marlin_linear``. No input (activation) scale is
    registered by this scheme.
    """

    def __init__(self):
        # Number of input elements that share one ``weight_scale`` entry.
        self.group_size = 16

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value this scheme asks for."""
        # Intentionally not restricted further because of emulations.
        # NOTE(review): confirm which fallback path this refers to.
        return 75

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register the packed weight and its two scale tensors on ``layer``.

        Args:
            layer: Module that receives the parameters and size attributes.
            output_partition_sizes: Output width of each fused partition; the
                sum is the layer's output size on this rank.
            input_size_per_partition: Input width handled by this partition.
            params_dtype: Accepted for interface compatibility; not used here.
            weight_loader: Loader callback attached to every parameter.
            **kwargs: Ignored.
        """
        partition_count = len(output_partition_sizes)
        out_features = sum(output_partition_sizes)

        # Size attributes; the two *_size_per_partition values are read back
        # in ``apply_weights``.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        # Packed weight: uint8 storage with half as many columns as inputs
        # (presumably two 4-bit values per byte — TODO confirm).
        packed_storage = torch.empty(
            out_features,
            input_size_per_partition // 2,
            dtype=torch.uint8,
        )
        layer.register_parameter(
            "weight_packed",
            ModelWeightParameter(
                data=packed_storage,
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            ),
        )

        # Global weight scale: one float32 entry per fused partition.
        global_scale_storage = torch.empty(partition_count, dtype=torch.float32)
        layer.register_parameter(
            "weight_global_scale",
            PerTensorScaleParameter(
                data=global_scale_storage,
                weight_loader=weight_loader,
            ),
        )

        # Per-group weight scale: one float8_e4m3fn entry per ``group_size``
        # input elements of every output row.
        group_scale_storage = torch.empty(
            out_features,
            input_size_per_partition // self.group_size,
            dtype=torch.float8_e4m3fn,
        )
        layer.register_parameter(
            "weight_scale",
            GroupQuantScaleParameter(
                data=group_scale_storage,
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            ),
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded checkpoint tensors for Marlin repacking."""
        # Marlin expects the packed tensor under the name ``weight``.
        packed = layer.weight_packed.data
        layer.weight = Parameter(packed, requires_grad=False)
        del layer.weight_packed

        # compressed-tensors stores the inverse of what the Marlin kernel
        # expects, so reduce the per-partition values to their maximum and
        # invert that single value.
        largest_global_scale = layer.weight_global_scale.max().to(torch.float32)
        layer.weight_global_scale = Parameter(
            1.0 / largest_global_scale, requires_grad=False
        )

        prepare_fp4_layer_for_marlin(layer)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the result of ``apply_fp4_marlin_linear`` for input ``x``.

        ``layer.workspace`` is not created in this class; presumably
        ``prepare_fp4_layer_for_marlin`` attaches it — verify against that
        helper.
        """
        n_out = layer.output_size_per_partition
        k_in = layer.input_size_per_partition
        return apply_fp4_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_global_scale=layer.weight_global_scale,
            workspace=layer.workspace,
            size_n=n_out,
            size_k=k_in,
            bias=bias,
        )
|
||||
+36
-24
@@ -23,8 +23,9 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
|
||||
|
||||
|
||||
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
def __init__(self):
|
||||
self.kernel = init_nvfp4_linear_kernel()
|
||||
def __init__(self, use_a16: bool = False):
|
||||
self.use_a16 = use_a16
|
||||
self.kernel = init_nvfp4_linear_kernel(use_a16=use_a16)
|
||||
self.group_size = 16
|
||||
|
||||
@classmethod
|
||||
@@ -79,46 +80,57 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
input_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("input_global_scale", input_global_scale)
|
||||
if not self.use_a16:
|
||||
input_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("input_global_scale", input_global_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Rename CT checkpoint names to standardized names
|
||||
layer.weight = layer.weight_packed
|
||||
del layer.weight_packed
|
||||
|
||||
if (
|
||||
torch.unique(layer.input_global_scale).numel() != 1
|
||||
or torch.unique(layer.weight_global_scale).numel() != 1
|
||||
):
|
||||
# Check for mismatched weight global scales
|
||||
if torch.unique(layer.weight_global_scale).numel() != 1:
|
||||
logger.warning_once(
|
||||
"In NVFP4 linear, the global scale for input or weight are different"
|
||||
"In NVFP4 linear, the weight global scale is different"
|
||||
" for parallel layers (e.g. q_proj, k_proj, v_proj). This "
|
||||
" will likely result in reduced accuracy. Please verify the model"
|
||||
" accuracy. Consider using a checkpoint with a shared global NVFP4"
|
||||
" scale for fused layers."
|
||||
)
|
||||
|
||||
# Process global scales (CT stores as divisors, i.e. 1/scale)
|
||||
input_global_scale_inv = layer.input_global_scale.max().to(torch.float32)
|
||||
layer.input_global_scale = Parameter(
|
||||
(1.0 / input_global_scale_inv).to(torch.float32), requires_grad=False
|
||||
)
|
||||
# Process weight global scale (CT stores as divisors, i.e. 1/scale)
|
||||
weight_global_scale = layer.weight_global_scale.max().to(torch.float32)
|
||||
layer.weight_global_scale = Parameter(
|
||||
1.0 / weight_global_scale, requires_grad=False
|
||||
)
|
||||
|
||||
# Pre-compute alpha and inverse for runtime quantization
|
||||
layer.input_global_scale_inv = Parameter(
|
||||
input_global_scale_inv, requires_grad=False
|
||||
)
|
||||
layer.alpha = Parameter(
|
||||
layer.input_global_scale * layer.weight_global_scale, requires_grad=False
|
||||
)
|
||||
if not self.use_a16:
|
||||
if torch.unique(layer.input_global_scale).numel() != 1:
|
||||
logger.warning_once(
|
||||
"In NVFP4 linear, the input global scale is different"
|
||||
" for parallel layers (e.g. q_proj, k_proj, v_proj). This "
|
||||
" will likely result in reduced accuracy. Please verify the model"
|
||||
" accuracy. Consider using a checkpoint with a shared global NVFP4"
|
||||
" scale for fused layers."
|
||||
)
|
||||
# Process input global scale and pre-compute alpha for W4A4 mode
|
||||
input_global_scale_inv = layer.input_global_scale.max().to(torch.float32)
|
||||
layer.input_global_scale = Parameter(
|
||||
(1.0 / input_global_scale_inv).to(torch.float32), requires_grad=False
|
||||
)
|
||||
|
||||
# Pre-compute alpha and inverse for runtime quantization
|
||||
layer.input_global_scale_inv = Parameter(
|
||||
input_global_scale_inv, requires_grad=False
|
||||
)
|
||||
layer.alpha = Parameter(
|
||||
layer.input_global_scale * layer.weight_global_scale,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# Convert layer to NVFP4 linear kernel format
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
Reference in New Issue
Block a user