From 06ee2d8433831f69d5de3a6d9fa3d7d042dd394f Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 4 Jun 2026 10:40:33 -0400 Subject: [PATCH] [Quant] Support compressed-tensors WNA8O8Int linears and WNInt embeddings (#44340) Signed-off-by: mgoin --- requirements/common.txt | 2 +- requirements/cuda.txt | 2 +- requirements/test/rocm.txt | 2 +- .../quantization/test_quantized_embedding.py | 67 +++++ .../model_executor/kernels/linear/__init__.py | 4 + .../kernels/linear/mixed_precision/humming.py | 61 +++++ .../compressed_tensors/compressed_tensors.py | 103 ++++++- .../compressed_tensors_embedding.py | 170 ++++++++++++ .../compressed_tensors/schemes/__init__.py | 2 + .../schemes/compressed_tensors_wNa8o8.py | 257 ++++++++++++++++++ .../quantization/utils/humming_utils.py | 66 ++++- vllm/model_executor/models/gemma4.py | 6 +- vllm/model_executor/models/gemma4_mm.py | 24 +- vllm/model_executor/models/gemma4_unified.py | 5 +- 14 files changed, 744 insertions(+), 27 deletions(-) create mode 100644 tests/kernels/quantization/test_quantized_embedding.py create mode 100644 vllm/model_executor/kernels/linear/mixed_precision/humming.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_embedding.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa8o8.py diff --git a/requirements/common.txt b/requirements/common.txt index d37ef1f1fed..8141dc8ea6b 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -38,7 +38,7 @@ pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 setuptools>=77.0.3,<81.0.0; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. 
def _dequant_gather_torch(
    ids: torch.Tensor,
    weight_packed: torch.Tensor,
    weight_scale: torch.Tensor,
    hidden: int,
    num_bits: int,
) -> torch.Tensor:
    """Pure-PyTorch reference for the dequant-gather kernel.

    Gathers the packed rows and scale rows selected by ``ids``, unpacks the
    gathered rows with ``unpack_from_int32`` and multiplies by the scales.
    The result has the dtype of ``weight_scale`` and shape
    ``(ids.shape[0], hidden)``.
    """
    num_rows = ids.shape[0]
    unpacked = unpack_from_int32(
        weight_packed[ids], num_bits, torch.Size([num_rows, hidden])
    )
    gathered_scales = weight_scale[ids]
    values = unpacked.to(gathered_scales.dtype)

    num_groups = gathered_scales.shape[1]
    if num_groups == 1:
        # A single scale column broadcasts across the whole row.
        return values * gathered_scales

    # One scale per contiguous run of ``hidden // num_groups`` columns.
    grouped = values.view(num_rows, num_groups, hidden // num_groups)
    return (grouped * gathered_scales.unsqueeze(-1)).view(num_rows, hidden)
+ weight_packed = torch.randint( + -(2**31), + 2**31, + (vocab, hidden // pack_factor), + dtype=torch.int32, + device=device, + ) + + num_groups = 1 if group_size == 0 else hidden // group_size + weight_scale = torch.rand(vocab, num_groups, dtype=dtype, device=device) + 0.01 + + ids = torch.randint(0, vocab, (num_ids,), dtype=torch.long, device=device) + + out = _dequant_gather_triton(ids, weight_packed, weight_scale, hidden, num_bits) + ref = _dequant_gather_torch(ids, weight_packed, weight_scale, hidden, num_bits) + + assert out.shape == (num_ids, hidden) + assert out.dtype == dtype + torch.testing.assert_close(out, ref) diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py index 32cf8a51565..cd2c9eb01be 100644 --- a/vllm/model_executor/kernels/linear/__init__.py +++ b/vllm/model_executor/kernels/linear/__init__.py @@ -45,6 +45,9 @@ from vllm.model_executor.kernels.linear.mixed_precision.dynamic_4bit import ( from vllm.model_executor.kernels.linear.mixed_precision.exllama import ( ExllamaLinearKernel, ) +from vllm.model_executor.kernels.linear.mixed_precision.humming import ( + HummingLinearKernel, +) from vllm.model_executor.kernels.linear.mixed_precision.machete import ( MacheteLinearKernel, ) @@ -345,6 +348,7 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = { MacheteLinearKernel, AllSparkLinearKernel, MarlinLinearKernel, + HummingLinearKernel, ConchLinearKernel, ExllamaLinearKernel, TritonW4A16LinearKernel, diff --git a/vllm/model_executor/kernels/linear/mixed_precision/humming.py b/vllm/model_executor/kernels/linear/mixed_precision/humming.py new file mode 100644 index 00000000000..cb02d661294 --- /dev/null +++ b/vllm/model_executor/kernels/linear/mixed_precision/humming.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Humming GEMM as a mixed-precision WNA16Int linear kernel.""" + +import torch + 
class HummingLinearKernel(MPLinearKernel):
    """Mixed-precision linear kernel that delegates the GEMM to Humming.

    ``can_implement`` accepts a layer only on CUDA, when the ``humming``
    module is importable, and when the layer uses neither act-order
    (``g_idx``) nor zero points.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability value this kernel advertises."""
        return 75

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Report whether this kernel can serve the layer described by ``c``.

        Returns:
            ``(True, None)`` when supported, otherwise ``(False, reason)``.
        """
        if not current_platform.is_cuda():
            return False, "Humming is only supported on CUDA"
        if not _has_module("humming"):
            return False, "Humming is not installed"
        if c.has_g_idx:
            return False, "Humming does not support act-order (g_idx)"
        if c.zero_points:
            return False, "Humming linear kernel only supports symmetric weights"
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded weight and scale into Humming's layout, in place.

        The layer's quantized weight and scale parameters are first renamed to
        ``weight`` / ``weight_scale`` and then handed to
        ``prepare_humming_layer`` together with a Humming quant config built
        from this kernel's ``MPLinearLayerConfig``.
        """
        # Imported lazily so this module can be imported without humming.
        from vllm.model_executor.layers.quantization.utils.humming_utils import (
            convert_linear_layer_to_humming_standard,
            prepare_humming_layer,
        )

        name_map = {"weight": self.w_q_name, "weight_scale": self.w_s_name}
        group_size = self.config.group_size
        quant_config = {
            "quant_method": "humming",
            "dtype": "int" + str(self.config.weight_type.size_bits),
            # This config uses -1 for "no grouping"; the Humming config gets 0.
            "group_size": 0 if group_size == -1 else group_size,
        }

        convert_linear_layer_to_humming_standard(layer=layer, name_map=name_map)
        prepare_humming_layer(layer, quant_config)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the Humming GEMM on ``x`` and restore its leading dimensions.

        Args:
            layer: Layer already prepared by ``process_weights_after_loading``.
            x: Activations; every dimension but the last is flattened.
            bias: Not used by this method.
                NOTE(review): presumably Humming applies the layer's own bias
                when the layer was prepared with ``has_bias`` - confirm.

        Returns:
            Tensor shaped ``(*x.shape[:-1], N)`` where ``N`` is the last
            dimension of the Humming output.
        """
        from humming.layer import HummingMethod

        # reshape() matches view() for contiguous inputs but also accepts
        # non-contiguous activations (copying only when it has to) instead of
        # raising.
        flatten_inputs = x.reshape(-1, x.size(-1))
        output = HummingMethod.forward_layer(
            layer=layer,
            inputs=flatten_inputs,
            compute_config=layer.compute_config,
        )
        return output.view(*x.shape[:-1], output.size(-1))
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -30,6 +30,9 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_embedding import ( # noqa: E501 + CompressedTensorsEmbeddingWNA16Int, +) from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501 CompressedTensorsMoEMethod, ) @@ -44,6 +47,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsW8A8Int8, CompressedTensorsW8A8Mxfp8, CompressedTensorsW8A16Fp8, + CompressedTensorsWNA8O8Int, CompressedTensorsWNA16, ) from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501 @@ -56,7 +60,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( should_ignore_layer, ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.platforms import current_platform if TYPE_CHECKING: @@ -177,6 +184,24 @@ class CompressedTensorsConfig(QuantizationConfig): layer.scheme = quant_scheme return CompressedTensorsLinearMethod(self) + # ParallelLMHead subclasses VocabParallelEmbedding but is handled above as + # a linear; only true embedding lookups land here. 
+ if isinstance(layer, VocabParallelEmbedding): + scheme_dict = self.get_scheme_dict(layer, layer_name=prefix) + weight_quant = scheme_dict.get("weights") if scheme_dict else None + if weight_quant is None: + return None # unquantized embedding + if not ( + isinstance(weight_quant, QuantizationArgs) + and self._is_wNa16_group_channel(weight_quant, None) + and weight_quant.type == QuantizationType.INT + ): + raise ValueError( + "compressed-tensors embeddings only support weight-only INT " + f"group/channel (WNA16) quantization, got: {weight_quant}" + ) + return CompressedTensorsEmbeddingWNA16Int(weight_quant) + if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) if isinstance(layer, RoutedExperts): @@ -324,6 +349,15 @@ class CompressedTensorsConfig(QuantizationConfig): quant_config.get("input_activations") ) ) + + # Static output-activation quant is applied as a float fake-quant + # on the layer output; capture it when present. + target_scheme_map[target]["output_activations"] = None + output_activations = quant_config.get("output_activations") + if output_activations: + target_scheme_map[target]["output_activations"] = ( + QuantizationArgs.model_validate(output_activations) + ) return target_scheme_map @classmethod @@ -604,10 +638,56 @@ class CompressedTensorsConfig(QuantizationConfig): return is_channel_group and input_quant_none and is_static + @staticmethod + def _is_wNa8o8_int( + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs | None, + output_quant: QuantizationArgs | None, + format: str | None, + ) -> bool: + """Weight N-bit INT (pack-quantized for sub-byte, int-quantized for 8-bit) + with static per-tensor INT8 input/output activation quant, applied as a float + fake-quant around a weight-only matmul.""" + is_int_pack_format = format in ( + CompressionFormat.pack_quantized.value, + CompressionFormat.int_quantized.value, + ) + is_channel_group = weight_quant.strategy in ( + QuantizationStrategy.CHANNEL.value, + 
QuantizationStrategy.GROUP.value, + ) + is_static_int = ( + weight_quant.type == QuantizationType.INT and not weight_quant.dynamic + ) + is_intN_weight = is_static_int and is_channel_group and is_int_pack_format + is_static_int8_in = ( + input_quant is not None + and input_quant.type == QuantizationType.INT + and input_quant.strategy == QuantizationStrategy.TENSOR.value + and input_quant.num_bits == 8 + and not input_quant.dynamic + ) + is_static_int8_out = ( + output_quant is not None + and output_quant.type == QuantizationType.INT + and output_quant.strategy == QuantizationStrategy.TENSOR.value + and output_quant.num_bits == 8 + and not output_quant.dynamic + ) + # Static int8-activation layers, plus sub-byte weight-only layers (e.g. + # 2-bit lm_head) that marlin-backed WNA16 cannot serve. Standard 4/8-bit + # weight-only (no activations) falls through to WNA16. + is_subbyte_weight_only = weight_quant.num_bits not in WNA16_SUPPORTED_BITS + needs_wNa8o8 = is_intN_weight and ( + (is_static_int8_in and is_static_int8_out) or is_subbyte_weight_only + ) + return needs_wNa8o8 + def _get_scheme_from_parts( self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs, + output_quant: QuantizationArgs | None = None, format: str | None = None, layer_name: str | None = None, ) -> "CompressedTensorsScheme": @@ -641,6 +721,19 @@ class CompressedTensorsConfig(QuantizationConfig): actorder=weight_quant.actorder, ) + # Must come before the WNA16 check; standard 4/8-bit weight-only (no + # output-activation scale) still falls through to WNA16. 
+ if self._is_wNa8o8_int(weight_quant, input_quant, output_quant, format): + return CompressedTensorsWNA8O8Int( + num_bits=weight_quant.num_bits, + strategy=weight_quant.strategy, + group_size=weight_quant.group_size, + has_input_act=input_quant is not None, + has_output_act=output_quant is not None, + layer_name=layer_name, + quant_format=format, + ) + if ( self._is_wNa16_group_channel(weight_quant, input_quant) and (format == CompressionFormat.pack_quantized.value) @@ -708,7 +801,10 @@ class CompressedTensorsConfig(QuantizationConfig): input_symmetric=input_quant.symmetric, ) - raise NotImplementedError("No compressed-tensors compatible scheme was found.") + raise NotImplementedError( + f"No compressed-tensors compatible scheme was found for {layer_name=}, " + f"{weight_quant=}, {input_quant=}, {output_quant=}, {format=}" + ) def get_scheme( self, layer: torch.nn.Module, layer_name: str | None = None @@ -731,10 +827,12 @@ class CompressedTensorsConfig(QuantizationConfig): weight_quant = None input_quant = None + output_quant = None format = None if scheme_dict: weight_quant = scheme_dict.get("weights") input_quant = scheme_dict.get("input_activations") + output_quant = scheme_dict.get("output_activations") format = scheme_dict.get("format") if weight_quant is None: @@ -746,6 +844,7 @@ class CompressedTensorsConfig(QuantizationConfig): scheme = self._get_scheme_from_parts( # type: ignore weight_quant=weight_quant, input_quant=input_quant, + output_quant=output_quant, format=format, layer_name=layer_name, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_embedding.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_embedding.py new file mode 100644 index 00000000000..23d25261301 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_embedding.py @@ -0,0 +1,170 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors 
@triton.jit
def _dequant_gather_kernel(
    ids_ptr,
    packed_ptr,
    scale_ptr,
    out_ptr,
    hidden,
    packed_cols,
    num_groups,
    NUM_BITS: tl.constexpr,
    PACK_FACTOR: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Gather embedding rows by token id, unpack int32-packed INT weights, and
    dequantize to ``out`` dtype in one pass (no int8 intermediate).

    Program axis 0 selects the output row (one per id); axis 1 selects a
    ``BLOCK``-wide tile of columns. ``GROUP_SIZE == 0`` means a single scale
    per embedding row; otherwise one scale per ``GROUP_SIZE`` columns.
    """
    # Widen the row index so ``row * hidden`` is computed in 64-bit and cannot
    # wrap for large outputs, mirroring the int64 cast applied to the token id.
    row = tl.program_id(0).to(tl.int64)
    col = tl.program_id(1) * BLOCK + tl.arange(0, BLOCK)
    col_mask = col < hidden
    tid = tl.load(ids_ptr + row).to(tl.int64)

    # Each int32 word holds PACK_FACTOR values of NUM_BITS bits each.
    packed_idx = col // PACK_FACTOR
    shift = (col % PACK_FACTOR) * NUM_BITS
    packed = tl.load(
        packed_ptr + tid * packed_cols + packed_idx, mask=col_mask, other=0
    )
    # Mask out the field, then subtract the mid-point to recover a signed value.
    q = ((packed >> shift) & ((1 << NUM_BITS) - 1)) - (1 << (NUM_BITS - 1))

    if GROUP_SIZE == 0:  # channel: one scale per row
        scale = tl.load(scale_ptr + tid)
    else:  # group: one scale per (row, group)
        grp = col // GROUP_SIZE
        scale = tl.load(scale_ptr + tid * num_groups + grp, mask=col_mask, other=0.0)

    # Multiply in float32, then cast to the output element type on store.
    out = q.to(tl.float32) * scale.to(tl.float32)
    tl.store(
        out_ptr + row * hidden + col, out.to(out_ptr.dtype.element_ty), mask=col_mask
    )
class CompressedTensorsEmbeddingWNA16Int(QuantizeMethodBase):
    """Dequantize-on-lookup method for an int32-packed INT embedding table.

    Registers ``weight_packed``, ``weight_scale`` and ``weight_shape`` on the
    layer; ``embedding`` dequantizes only the rows selected by the input ids.
    """

    def __init__(self, weight_quant: QuantizationArgs):
        """Capture bit-width, strategy and group size from ``weight_quant``."""
        self.num_bits = weight_quant.num_bits
        self.pack_factor = 32 // self.num_bits
        self.strategy = weight_quant.strategy
        self.group_size = weight_quant.group_size
        self.is_group = (
            self.strategy == QuantizationStrategy.GROUP.value
            and self.group_size is not None
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed weight, its scales and the shape placeholder.

        Raises:
            KeyError: if ``extra_weight_attrs`` has no ``"weight_loader"``.
            ValueError: for group quantization when the hidden size is not a
                multiple of the group size.
        """
        weight_loader = extra_weight_attrs["weight_loader"]
        # Embedding weight is [num_embeddings(vocab), embedding_dim(hidden)];
        # vocab is the output (partitioned) dim, hidden is the input dim.
        vocab_pp = sum(output_partition_sizes)
        hidden = input_size_per_partition
        layer.hidden_size = hidden

        # Round up so a hidden size that is not a multiple of the pack factor
        # still gets a word for its trailing values. Identical to floor
        # division when hidden divides evenly.
        # NOTE(review): assumes the checkpoint pads the packed dim to a whole
        # int32 word - confirm against compressed-tensors' pack_to_int32.
        packed_cols = (hidden + self.pack_factor - 1) // self.pack_factor

        weight_packed = PackedvLLMParameter(
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=self.pack_factor,
            weight_loader=weight_loader,
            data=torch.empty(vocab_pp, packed_cols, dtype=torch.int32),
        )

        if self.is_group:
            # Raise instead of assert so the check survives ``python -O``.
            if hidden % self.group_size != 0:
                raise ValueError(
                    f"hidden size {hidden} is not divisible by group_size "
                    f"{self.group_size}"
                )
            weight_scale = GroupQuantScaleParameter(
                output_dim=0,
                input_dim=1,
                weight_loader=weight_loader,
                data=torch.empty(
                    vocab_pp, hidden // self.group_size, dtype=params_dtype
                ),
            )
        else:
            weight_scale = ChannelQuantScaleParameter(
                output_dim=0,
                weight_loader=weight_loader,
                data=torch.empty(vocab_pp, 1, dtype=params_dtype),
            )

        weight_shape = BasevLLMParameter(
            data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
        )

        layer.register_parameter("weight_packed", weight_packed)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """No post-processing: the packed weight is used as loaded."""
        pass

    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
        """Look up and dequantize the rows selected by ``input_``.

        Returns a tensor of shape ``(*input_.shape, hidden)`` in the dtype of
        ``layer.weight_scale``.
        """
        ids = input_.reshape(-1).contiguous()
        hidden = layer.hidden_size
        deq = _dequant_gather_triton(
            ids, layer.weight_packed, layer.weight_scale, hidden, self.num_bits
        )
        return deq.reshape(*input_.shape, hidden)

    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        """Always raises: this method supports embedding lookup only."""
        raise NotImplementedError(
            "CompressedTensorsEmbeddingWNA16Int supports embedding lookup only"
        )
+10,13 @@ from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 from .compressed_tensors_w8a8_mxfp8 import CompressedTensorsW8A8Mxfp8 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 +from .compressed_tensors_wNa8o8 import CompressedTensorsWNA8O8Int from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16 __all__ = [ "CompressedTensorsScheme", "CompressedTensorsWNA16", + "CompressedTensorsWNA8O8Int", "CompressedTensorsW8A16Fp8", "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa8o8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa8o8.py new file mode 100644 index 00000000000..52d9cfeb05b --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa8o8.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Weight N-bit INT scheme with static INT8 input/output activation quant. + +Handles compressed-tensors INT weight checkpoints that carry static per-tensor +INT8 ``input_activations`` and/or ``output_activations``. The activation quant is +reproduced as a float fake-quant on the layer input and output, around a +weight-only matmul, rather than a fused int8 GEMM. 
def fake_quant_static_int8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Quantize-dequantize ``x`` with a fixed scale, staying in ``x``'s dtype.

    ``x`` is divided by ``scale`` (cast to ``x.dtype``), rounded, clamped to
    ``[-128, 127]`` and multiplied back by the same scale.
    """
    typed_scale = scale.to(x.dtype)
    levels = (x / typed_scale).round().clamp(min=-128.0, max=127.0)
    return levels * typed_scale
+ self.quant_format = quant_format + self.is_int_quantized = quant_format == "int-quantized" + if num_bits not in WNA8O8_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits} for WNA8O8Int; " + f"supported = {sorted(WNA8O8_SUPPORTED_TYPES_MAP)}" + ) + self.quant_type = WNA8O8_SUPPORTED_TYPES_MAP[num_bits] + self._input_scale: torch.Tensor | None = None + self._output_scale: torch.Tensor | None = None + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def create_weights( + self, + layer: torch.nn.Module, + output_size: int, + input_size: int, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): + output_size_per_partition = sum(output_partition_sizes) + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + # Set for kernels' weight prep; also covers ParallelLMHead, which does + # not set these in __init__. 
+ layer.output_partition_sizes = output_partition_sizes + layer.params_dtype = params_dtype + if not hasattr(layer, "has_bias"): + layer.has_bias = False + + mp_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=( + input_size_per_partition, + output_size_per_partition, + ), + weight_type=self.quant_type, + act_type=params_dtype, # activation quant applied externally (SRQ) + group_size=self.group_size, + zero_points=False, + has_g_idx=False, + ) + self.kernel = choose_mp_linear_kernel(mp_config)( + mp_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + ) + + self._register_weight( + layer, input_size, input_size_per_partition, params_dtype, weight_loader + ) + + def _register_weight( + self, layer, input_size, input_size_per_partition, params_dtype, weight_loader + ): + out = layer.output_size_per_partition + if self.is_int_quantized: + # Plain int8 weight; packed to the canonical int32 layout after load. + layer.register_parameter( + "weight", + ModelWeightParameter( + data=torch.empty(out, input_size_per_partition, dtype=torch.int8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ), + ) + else: + layer.register_parameter( + "weight_packed", + PackedvLLMParameter( + input_dim=1, + output_dim=0, + packed_dim=1, + packed_factor=self.pack_factor, + weight_loader=weight_loader, + data=torch.empty( + out, + input_size_per_partition // self.pack_factor, + dtype=torch.int32, + ), + ), + ) + layer.register_parameter( + "weight_shape", + BasevLLMParameter( + data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader + ), + ) + + # Scale: per-output-channel, or per group along the input dim under TP. 
+ group_size = self.group_size if self.group_size != -1 else input_size + partitioned = not marlin_repeat_scales_on_all_ranks( + False, self.group_size, input_size != input_size_per_partition + ) + scales = (input_size_per_partition if partitioned else input_size) // group_size + scale_data = torch.empty(out, scales, dtype=params_dtype) + if partitioned: + assert input_size_per_partition % group_size == 0 + weight_scale = GroupQuantScaleParameter( + data=scale_data, output_dim=0, input_dim=1, weight_loader=weight_loader + ) + else: + weight_scale = ChannelQuantScaleParameter( + data=scale_data, output_dim=0, weight_loader=weight_loader + ) + layer.register_parameter("weight_scale", weight_scale) + + for name, present in ( + ("input_scale", self.has_input_act), + ("output_scale", self.has_output_act), + ): + if present: + layer.register_parameter( + name, + BasevLLMParameter( + data=torch.empty(1, dtype=torch.float32), + weight_loader=weight_loader, + ), + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Lift the static activation scales off the layer (applied externally) so + # the kernel only sees weight tensors. Drop uncalibrated (zero) scales. 
+ self._input_scale = self._take_act_scale(layer, "input_scale") + self._output_scale = self._take_act_scale(layer, "output_scale") + self.has_input_act = self._input_scale is not None + self.has_output_act = self._output_scale is not None + + if self.is_int_quantized: + self._pack_int_quantized_weight(layer) + + self.kernel.process_weights_after_loading(layer) + + def _pack_int_quantized_weight(self, layer: torch.nn.Module) -> None: + """Normalize an int-quantized (plain int8) weight to the canonical + ``weight_packed`` int32 + ``weight_shape`` layout the MP kernels expect.""" + weight = layer.weight + out_features, in_features = weight.shape + packed = pack_to_int32(weight.data.contiguous(), self.num_bits) + delattr(layer, "weight") + + def _noop_loader(*_, **__): + return None + + layer.register_parameter( + "weight_packed", + PackedvLLMParameter( + data=packed.contiguous(), + input_dim=1, + output_dim=0, + packed_dim=1, + packed_factor=self.pack_factor, + weight_loader=_noop_loader, + ), + ) + layer.register_parameter( + "weight_shape", + BasevLLMParameter( + data=torch.tensor([out_features, in_features], dtype=torch.int64), + weight_loader=_noop_loader, + ), + ) + + @staticmethod + def _take_act_scale(layer, name: str) -> torch.Tensor | None: + param = getattr(layer, name, None) + if param is None: + return None + scale = param.data.clone() + delattr(layer, name) + return None if float(scale.reshape(-1)[0]) == 0.0 else scale + + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None + ) -> torch.Tensor: + if self.has_input_act: + x = fake_quant_static_int8(x, self._input_scale) + out = self.kernel.apply_weights(layer, x, bias) + if self.has_output_act: + out = fake_quant_static_int8(out, self._output_scale) + return out diff --git a/vllm/model_executor/layers/quantization/utils/humming_utils.py b/vllm/model_executor/layers/quantization/utils/humming_utils.py index 3c01977f3b5..d9e02542c6d 100644 --- 
def convert_linear_layer_to_humming_standard(
    layer: LinearBase, name_map: dict[str, str]
):
    """Move a linear layer's quantized tensors onto Humming's attribute names.

    ``name_map`` maps each Humming name (``weight``, ``weight_scale``,
    ``zero_point``) to the attribute currently holding that tensor on
    ``layer``. Each source attribute is removed and re-attached under the
    Humming name after normalization:

    * ``weight`` is transposed when its ``input_dim``/``output_dim`` are
      ``0``/``1``, then flattened to 2-D and viewed as int32.
    * ``weight_scale`` / ``zero_point`` are transposed when ``output_dim`` is
      1, a 1-D tensor gains a trailing unit dimension, and ``zero_point`` is
      viewed as int32.

    Tensors that are no longer ``torch.nn.Parameter`` instances are wrapped in
    a new parameter with ``requires_grad=False``.
    """
    for humming_name, source_name in name_map.items():
        value = getattr(layer, source_name)
        delattr(layer, source_name)

        if humming_name == "weight":
            in_dim = getattr(value, "input_dim", 1)
            out_dim = getattr(value, "output_dim", 0)
            if (in_dim, out_dim) == (0, 1):
                value = value.transpose(1, 0).contiguous()
            else:
                assert out_dim == 0 and in_dim == 1
            value = value.view(value.size(0), -1).view(torch.int32)
        elif humming_name in ("weight_scale", "zero_point"):
            if getattr(value, "output_dim", 0) == 1:
                value = value.transpose(0, 1).contiguous()
            if value.ndim == 1:
                value = value.unsqueeze(1)
            if humming_name == "zero_point":
                value = value.view(torch.int32)

        # Untouched parameters are re-attached as-is; derived tensors are
        # wrapped so the layer still exposes them as parameters.
        if not isinstance(value, torch.nn.Parameter):
            value = torch.nn.Parameter(value, requires_grad=False)
        setattr(layer, humming_name, value)
+ input_size_per_partition = getattr( + layer, "input_size_per_partition", layer.input_size + ) + shape_k_stacks = [input_size_per_partition] shape_n_stacks = layer.output_partition_sizes # Step 1: convert weight to humming standard format weight_schema, tensors = weight_schema.convert_humming( - tensors=layer.named_parameters(), + tensors=dict(layer.named_parameters()), shape_n_stacks=shape_n_stacks, shape_k_stacks=shape_k_stacks, param_dtype=layer.params_dtype, @@ -63,23 +105,37 @@ def prepare_humming_layer(layer: LinearBase, quant_config: dict): delattr(layer, name) for name, tensor in tensors.items(): + if isinstance(tensor, torch.nn.Parameter): + tensor = tensor.data param = torch.nn.Parameter(tensor, requires_grad=False) setattr(layer, name, param) # Step 2: transform weight (humming standard format) for forwarding HummingMethod.prepare_layer_meta( layer=layer, - shape_n=layer.output_partition_sizes_sum, - shape_k=layer.input_size_per_partition, + shape_n=sum(layer.output_partition_sizes), + shape_k=input_size_per_partition, weight_schema=weight_schema, input_schema=input_schema, pad_n_to_multiple=256, pad_k_to_multiple=128, has_bias=layer.has_bias, - torch_dtype=layer.param_dtype, + torch_dtype=layer.params_dtype, ) HummingMethod.transform_humming_layer(layer) + if not hasattr(layer, "locks"): + device = layer.weight.device + locks = torch.zeros(1024, dtype=torch.int32, device=device) + layer.register_buffer("locks", locks) + + compute_config = { + "use_batch_invariant": envs.VLLM_BATCH_INVARIANT, + "use_f16_accum": envs.VLLM_HUMMING_USE_F16_ACCUM, + "gemm_type": "dense", + } + + layer.compute_config = json.dumps(compute_config) def prepare_humming_moe_layer(layer: RoutedExperts, quant_config: dict): diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index 2355f61ac51..5d0e3efe2e1 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -1057,7 +1057,7 @@ class 
Gemma4Model(nn.Module, EagleModelMixin): "normalizer", torch.tensor( config.hidden_size**0.5, - dtype=self.embed_tokens.weight.dtype, + dtype=vllm_config.model_config.dtype, ), persistent=False, ) @@ -1111,7 +1111,7 @@ class Gemma4Model(nn.Module, EagleModelMixin): ) self.hidden_states = torch.zeros( (max_num_tokens, config.hidden_size), - dtype=self.embed_tokens.weight.dtype, + dtype=vllm_config.model_config.dtype, device=device, ) if ( @@ -1124,7 +1124,7 @@ class Gemma4Model(nn.Module, EagleModelMixin): config.num_hidden_layers, self.hidden_size_per_layer_input, ), - dtype=self.embed_tokens.weight.dtype, + dtype=vllm_config.model_config.dtype, device=device, ) else: diff --git a/vllm/model_executor/models/gemma4_mm.py b/vllm/model_executor/models/gemma4_mm.py index f21dde96af5..2f7def54151 100644 --- a/vllm/model_executor/models/gemma4_mm.py +++ b/vllm/model_executor/models/gemma4_mm.py @@ -1011,14 +1011,17 @@ class Gemma4ForConditionalGeneration( self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config + self.model_dtype = vllm_config.model_config.dtype # Only quantize towers when the quant method supports their # dimensions. BNB/torchao handle arbitrary sizes; other methods # (Marlin, FP8, …) require dimensions divisible by 64, which # the vision tower (intermediate_size=4304) does not satisfy. + # TODO(mgoin): remove this by fixing kernel padding. if quant_config and quant_config.get_name() in [ "bitsandbytes", "torchao", + "compressed-tensors", ]: tower_quant = quant_config else: @@ -1081,12 +1084,13 @@ class Gemma4ForConditionalGeneration( # Some variants have hidden_size_per_layer_input=None (no PLE). 
ple_dim = config.text_config.hidden_size_per_layer_input if ple_dim is not None and ple_dim > 0: + embed = self.language_model.model.embed_tokens self.per_layer_embeddings = torch.zeros( vllm_config.scheduler_config.max_num_batched_tokens, config.text_config.num_hidden_layers, ple_dim, - device=self.language_model.model.embed_tokens.weight.device, - dtype=self.language_model.model.embed_tokens.weight.dtype, + device=next(embed.parameters()).device, + dtype=vllm_config.model_config.dtype, ) else: self.per_layer_embeddings = None @@ -1246,7 +1250,6 @@ class Gemma4ForConditionalGeneration( vt = self.vision_tower vision_cfg = self.config.vision_config pooling_k2 = vision_cfg.pooling_kernel_size**2 - target_dtype = self.language_model.model.embed_tokens.weight.dtype # Concurrent requests with different image resolutions may # arrive as a list of per-image tensors, while same-resolution @@ -1291,7 +1294,7 @@ class Gemma4ForConditionalGeneration( pv_tensor, pp_tensor, pad_tensor, - ).to(target_dtype) + ).to(self.model_dtype) encoder_outputs = vt.encoder( inputs_embeds=inputs_embeds, attention_mask=~pad_tensor, @@ -1328,12 +1331,8 @@ class Gemma4ForConditionalGeneration( all_valid_states[orig_idx] = valid_states valid_lens[orig_idx] = valid_states.shape[0] - # Use embed_tokens dtype as compute dtype; embedding_projection.weight - # may be uint8 under BnB 4-bit, which would corrupt the cast. - target_dtype = self.language_model.model.embed_tokens.weight.dtype - # Project all images in a single batched call. 
- flat_valid_states = torch.cat(all_valid_states, dim=0).to(target_dtype) + flat_valid_states = torch.cat(all_valid_states, dim=0).to(self.model_dtype) flat_proj_embs = self.embed_vision( inputs_embeds=flat_valid_states.unsqueeze(0) ).squeeze(0) @@ -1373,7 +1372,6 @@ class Gemma4ForConditionalGeneration( vt = self.vision_tower vision_cfg = self.config.vision_config pooling_k2 = vision_cfg.pooling_kernel_size**2 - target_dtype = self.language_model.model.embed_tokens.weight.dtype if isinstance(frame_counts, torch.Tensor): fc_list = frame_counts.tolist() @@ -1405,7 +1403,7 @@ class Gemma4ForConditionalGeneration( pv_chunk, pp_chunk, pad_chunk, - ).to(target_dtype) + ).to(self.model_dtype) encoder_outputs = vt.encoder( inputs_embeds=inputs_embeds, attention_mask=~pad_chunk, @@ -1440,7 +1438,9 @@ class Gemma4ForConditionalGeneration( frame_valid_lens.append(valid_states.shape[0]) # Project all frames in a single batched call. - flat_valid_states = torch.cat(all_frame_valid_states, dim=0).to(target_dtype) + flat_valid_states = torch.cat(all_frame_valid_states, dim=0).to( + self.model_dtype + ) flat_proj_embs = self.embed_vision( inputs_embeds=flat_valid_states.unsqueeze(0) ).squeeze(0) diff --git a/vllm/model_executor/models/gemma4_unified.py b/vllm/model_executor/models/gemma4_unified.py index e5f3784ffe2..66fc914dc75 100644 --- a/vllm/model_executor/models/gemma4_unified.py +++ b/vllm/model_executor/models/gemma4_unified.py @@ -307,12 +307,13 @@ class Gemma4UnifiedForConditionalGeneration(Gemma4ForConditionalGeneration): None, ) if ple_dim is not None and ple_dim > 0: + embed = self.language_model.model.embed_tokens self.per_layer_embeddings = torch.zeros( vllm_config.scheduler_config.max_num_batched_tokens, config.text_config.num_hidden_layers, ple_dim, - device=self.language_model.model.embed_tokens.weight.device, - dtype=self.language_model.model.embed_tokens.weight.dtype, + device=next(embed.parameters()).device, + dtype=vllm_config.model_config.dtype, ) else: 
self.per_layer_embeddings = None