[Quant] Support compressed-tensors WNA8O8Int linears and WNInt embeddings (#44340)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-06-04 10:40:33 -04:00
committed by GitHub
parent b5235fca2e
commit 06ee2d8433
14 changed files with 744 additions and 27 deletions
+1 -1
View File
@@ -38,7 +38,7 @@ pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
setuptools>=77.0.3,<81.0.0; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
einops # Required for Qwen2-VL.
compressed-tensors == 0.15.0.1 # required for compressed-tensors
compressed-tensors == 0.17.0 # required for compressed-tensors
depyf==0.20.0 # required for profiling and debugging with compilation config
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
watchfiles # required for http server to monitor the updates of TLS files
+1 -1
View File
@@ -28,4 +28,4 @@ quack-kernels>=0.3.3
tokenspeed-mla==0.1.2
# Humming kernels for quantization gemm
humming-kernels[cu13]==0.1.2
humming-kernels[cu13]==0.1.4
+1 -1
View File
@@ -143,7 +143,7 @@ colorful==0.5.8
# via ray
colorlog==6.10.1
# via optuna
compressed-tensors==0.15.0.1
compressed-tensors==0.17.0
# via
# -c requirements/common.txt
# -r requirements/test/../common.txt
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the Triton dequant-gather kernel used by
``CompressedTensorsEmbeddingWNA16Int`` (quantized embedding lookup)."""
import pytest
import torch
from compressed_tensors.compressors.pack_quantized.helpers import unpack_from_int32
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_embedding import ( # noqa: E501
_dequant_gather_triton,
)
from vllm.platforms import current_platform
def _dequant_gather_torch(
    ids: torch.Tensor,
    weight_packed: torch.Tensor,
    weight_scale: torch.Tensor,
    hidden: int,
    num_bits: int,
) -> torch.Tensor:
    """Pure-torch reference for the dequant-gather kernel.

    Selects the packed rows and scale rows addressed by ``ids``, unpacks the
    int32-packed integers with ``unpack_from_int32`` and multiplies them by
    their scales. Returns a ``(len(ids), hidden)`` tensor in the scale dtype.
    """
    num_rows = ids.shape[0]
    unpacked = unpack_from_int32(
        weight_packed[ids], num_bits, torch.Size([num_rows, hidden])
    )
    gathered_scales = weight_scale[ids]
    values = unpacked.to(gathered_scales.dtype)

    groups_per_row = gathered_scales.shape[1]
    if groups_per_row == 1:
        # Channel-wise: the single scale column broadcasts over the row.
        return values * gathered_scales

    # Group-wise: fold each row into (groups, group_len) so every group is
    # multiplied by its own scale, then flatten back to the row layout.
    grouped = values.view(num_rows, groups_per_row, hidden // groups_per_row)
    return (grouped * gathered_scales.unsqueeze(-1)).view(num_rows, hidden)
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="Triton dequant kernel requires CUDA"
)
@pytest.mark.parametrize("num_bits", [2, 4, 8])
@pytest.mark.parametrize("group_size", [0, 256])  # 0 -> channel
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("num_ids", [1, 17, 4096])
def test_dequant_gather(num_bits, group_size, dtype, num_ids):
    """Check the Triton kernel against the torch reference.

    Covers every parametrized combination of bit width, scale granularity,
    scale dtype and number of looked-up ids.
    """
    torch.manual_seed(0)
    vocab, hidden = 1000, 2048
    device = "cuda"

    # Random full-range int32 packed weights (covers the sign bit -> exercises
    # the arithmetic-shift + mask unpack path).
    values_per_word = 32 // num_bits
    weight_packed = torch.randint(
        -(2**31),
        2**31,
        (vocab, hidden // values_per_word),
        dtype=torch.int32,
        device=device,
    )

    # One scale column for channel quantization, otherwise one per group.
    scale_cols = hidden // group_size if group_size != 0 else 1
    # The +0.01 offset keeps every scale strictly positive.
    weight_scale = torch.rand(vocab, scale_cols, dtype=dtype, device=device) + 0.01
    ids = torch.randint(0, vocab, (num_ids,), dtype=torch.long, device=device)

    expected_shape = (num_ids, hidden)
    out = _dequant_gather_triton(ids, weight_packed, weight_scale, hidden, num_bits)
    ref = _dequant_gather_torch(ids, weight_packed, weight_scale, hidden, num_bits)

    assert out.shape == expected_shape
    assert out.dtype == dtype
    torch.testing.assert_close(out, ref)
@@ -45,6 +45,9 @@ from vllm.model_executor.kernels.linear.mixed_precision.dynamic_4bit import (
from vllm.model_executor.kernels.linear.mixed_precision.exllama import (
ExllamaLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.humming import (
HummingLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.machete import (
MacheteLinearKernel,
)
@@ -345,6 +348,7 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
MacheteLinearKernel,
AllSparkLinearKernel,
MarlinLinearKernel,
HummingLinearKernel,
ConchLinearKernel,
ExllamaLinearKernel,
TritonW4A16LinearKernel,
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Humming GEMM as a mixed-precision WNA16Int linear kernel."""
import torch
from vllm.platforms import current_platform
from vllm.utils.import_utils import _has_module
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class HummingLinearKernel(MPLinearKernel):
    """Mixed-precision linear kernel that runs the GEMM through Humming.

    Usable only on CUDA with the ``humming`` package importable, and only for
    symmetric weights without act-order (``g_idx``).
    """

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability this kernel accepts."""
        return 75

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Report whether this kernel can serve the layer config ``c``.

        Returns:
            ``(True, None)`` when the kernel is usable, otherwise
            ``(False, reason)`` with a human-readable explanation.
        """
        if not current_platform.is_cuda():
            return False, "Humming is only supported on CUDA"
        if not _has_module("humming"):
            return False, "Humming is not installed"
        if c.has_g_idx:
            return False, "Humming does not support act-order (g_idx)"
        if c.zero_points:
            return False, "Humming linear kernel only supports symmetric weights"
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded weight/scale parameters into Humming's layout.

        The layer's quantized weight and scale are first renamed to
        ``weight`` / ``weight_scale`` and then handed to
        ``prepare_humming_layer`` together with a Humming quant config.
        """
        # Imported lazily so this module can be imported without Humming.
        from vllm.model_executor.layers.quantization.utils.humming_utils import (
            convert_linear_layer_to_humming_standard,
            prepare_humming_layer,
        )

        name_map = {"weight": self.w_q_name, "weight_scale": self.w_s_name}
        group_size = self.config.group_size
        quant_config = {
            "quant_method": "humming",
            "dtype": "int" + str(self.config.weight_type.size_bits),
            # A group size of -1 is passed to Humming as 0.
            # NOTE(review): presumably both mean channel-wise -- confirm.
            "group_size": 0 if group_size == -1 else group_size,
        }

        convert_linear_layer_to_humming_standard(layer=layer, name_map=name_map)
        prepare_humming_layer(layer, quant_config)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the Humming GEMM on ``x`` and restore its leading dimensions.

        NOTE(review): ``bias`` is not forwarded to Humming here; presumably
        Humming takes it from ``layer`` -- confirm.
        """
        from humming.layer import HummingMethod

        # reshape() rather than view(): view() raises on inputs whose memory
        # layout cannot be flattened without a copy (e.g. non-contiguous
        # slices), whereas reshape() copies only when it has to.
        flatten_inputs = x.reshape(-1, x.size(-1))
        output = HummingMethod.forward_layer(
            layer=layer,
            inputs=flatten_inputs,
            compute_config=layer.compute_config,
        )
        return output.view(*x.shape[:-1], output.size(-1))
@@ -30,6 +30,9 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_embedding import ( # noqa: E501
CompressedTensorsEmbeddingWNA16Int,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
CompressedTensorsMoEMethod,
)
@@ -44,6 +47,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Int8,
CompressedTensorsW8A8Mxfp8,
CompressedTensorsW8A16Fp8,
CompressedTensorsWNA8O8Int,
CompressedTensorsWNA16,
)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
@@ -56,7 +60,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
should_ignore_layer,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.platforms import current_platform
if TYPE_CHECKING:
@@ -177,6 +184,24 @@ class CompressedTensorsConfig(QuantizationConfig):
layer.scheme = quant_scheme
return CompressedTensorsLinearMethod(self)
# ParallelLMHead subclasses VocabParallelEmbedding but is handled above as
# a linear; only true embedding lookups land here.
if isinstance(layer, VocabParallelEmbedding):
scheme_dict = self.get_scheme_dict(layer, layer_name=prefix)
weight_quant = scheme_dict.get("weights") if scheme_dict else None
if weight_quant is None:
return None # unquantized embedding
if not (
isinstance(weight_quant, QuantizationArgs)
and self._is_wNa16_group_channel(weight_quant, None)
and weight_quant.type == QuantizationType.INT
):
raise ValueError(
"compressed-tensors embeddings only support weight-only INT "
f"group/channel (WNA16) quantization, got: {weight_quant}"
)
return CompressedTensorsEmbeddingWNA16Int(weight_quant)
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, RoutedExperts):
@@ -324,6 +349,15 @@ class CompressedTensorsConfig(QuantizationConfig):
quant_config.get("input_activations")
)
)
# Static output-activation quant is applied as a float fake-quant
# on the layer output; capture it when present.
target_scheme_map[target]["output_activations"] = None
output_activations = quant_config.get("output_activations")
if output_activations:
target_scheme_map[target]["output_activations"] = (
QuantizationArgs.model_validate(output_activations)
)
return target_scheme_map
@classmethod
@@ -604,10 +638,56 @@ class CompressedTensorsConfig(QuantizationConfig):
return is_channel_group and input_quant_none and is_static
@staticmethod
def _is_wNa8o8_int(
    weight_quant: QuantizationArgs,
    input_quant: QuantizationArgs | None,
    output_quant: QuantizationArgs | None,
    format: str | None,
) -> bool:
    """Decide whether a layer should use the WNA8O8Int scheme.

    The weight must be static INT, channel- or group-quantized, stored in the
    pack-quantized or int-quantized format. On top of that, either both the
    input and output activations are static per-tensor INT8, or the weight
    bit width is one the WNA16 scheme does not support.
    """

    def _is_static_per_tensor_int8(act: QuantizationArgs | None) -> bool:
        # Static, per-tensor, 8-bit INT activation quantization.
        return (
            act is not None
            and act.type == QuantizationType.INT
            and act.strategy == QuantizationStrategy.TENSOR.value
            and act.num_bits == 8
            and not act.dynamic
        )

    supported_formats = (
        CompressionFormat.pack_quantized.value,
        CompressionFormat.int_quantized.value,
    )
    supported_strategies = (
        QuantizationStrategy.CHANNEL.value,
        QuantizationStrategy.GROUP.value,
    )
    weight_is_static_int = (
        weight_quant.type == QuantizationType.INT and not weight_quant.dynamic
    )
    weight_qualifies = (
        weight_is_static_int
        and weight_quant.strategy in supported_strategies
        and format in supported_formats
    )
    if not weight_qualifies:
        return False

    # Sub-byte weight layers (e.g. 2-bit lm_head) that marlin-backed WNA16
    # cannot serve are always taken here.
    if weight_quant.num_bits not in WNA16_SUPPORTED_BITS:
        return True

    # Otherwise only static int8-activation layers qualify; standard 4/8-bit
    # weight-only (no activations) falls through to WNA16.
    return _is_static_per_tensor_int8(input_quant) and _is_static_per_tensor_int8(
        output_quant
    )
def _get_scheme_from_parts(
self,
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
output_quant: QuantizationArgs | None = None,
format: str | None = None,
layer_name: str | None = None,
) -> "CompressedTensorsScheme":
@@ -641,6 +721,19 @@ class CompressedTensorsConfig(QuantizationConfig):
actorder=weight_quant.actorder,
)
# Must come before the WNA16 check; standard 4/8-bit weight-only (no
# output-activation scale) still falls through to WNA16.
if self._is_wNa8o8_int(weight_quant, input_quant, output_quant, format):
return CompressedTensorsWNA8O8Int(
num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
group_size=weight_quant.group_size,
has_input_act=input_quant is not None,
has_output_act=output_quant is not None,
layer_name=layer_name,
quant_format=format,
)
if (
self._is_wNa16_group_channel(weight_quant, input_quant)
and (format == CompressionFormat.pack_quantized.value)
@@ -708,7 +801,10 @@ class CompressedTensorsConfig(QuantizationConfig):
input_symmetric=input_quant.symmetric,
)
raise NotImplementedError("No compressed-tensors compatible scheme was found.")
raise NotImplementedError(
f"No compressed-tensors compatible scheme was found for {layer_name=}, "
f"{weight_quant=}, {input_quant=}, {output_quant=}, {format=}"
)
def get_scheme(
self, layer: torch.nn.Module, layer_name: str | None = None
@@ -731,10 +827,12 @@ class CompressedTensorsConfig(QuantizationConfig):
weight_quant = None
input_quant = None
output_quant = None
format = None
if scheme_dict:
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
output_quant = scheme_dict.get("output_activations")
format = scheme_dict.get("format")
if weight_quant is None:
@@ -746,6 +844,7 @@ class CompressedTensorsConfig(QuantizationConfig):
scheme = self._get_scheme_from_parts( # type: ignore
weight_quant=weight_quant,
input_quant=input_quant,
output_quant=output_quant,
format=format,
layer_name=layer_name,
)
@@ -0,0 +1,170 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Quantized embedding method for compressed-tensors.
Adds dequant-on-lookup support for a pack-quantized ``VocabParallelEmbedding``
(2-8 bit INT, channel- or group-quantized). Only the gathered token rows are
unpacked and dequantized, so the packed weight is never densified.
"""
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.parameter import (
BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter,
)
from vllm.triton_utils import tl, triton
__all__ = ["CompressedTensorsEmbeddingWNA16Int"]
@triton.jit
def _dequant_gather_kernel(
    ids_ptr,
    packed_ptr,
    scale_ptr,
    out_ptr,
    hidden,
    packed_cols,
    num_groups,
    NUM_BITS: tl.constexpr,
    PACK_FACTOR: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Gather embedding rows by token id, unpack int32-packed INT weights, and
    dequantize to ``out`` dtype in one pass (no int8 intermediate).

    Grid axis 0 indexes the looked-up id (output row); grid axis 1 tiles the
    ``hidden`` columns in chunks of ``BLOCK``. ``GROUP_SIZE == 0`` selects
    channel-wise scales (one per row), otherwise one scale per
    ``GROUP_SIZE`` columns.
    """
    # Promote the output row index to int64: `row * hidden` below is otherwise
    # evaluated in 32-bit arithmetic and wraps once the output holds 2**31 or
    # more elements.
    row = tl.program_id(0).to(tl.int64)
    col = tl.program_id(1) * BLOCK + tl.arange(0, BLOCK)
    col_mask = col < hidden

    # The token id is widened to int64 before it is used as a row offset into
    # the packed weight and scale tables.
    tid = tl.load(ids_ptr + row).to(tl.int64)

    # Each int32 word stores PACK_FACTOR values of NUM_BITS bits each.
    packed_idx = col // PACK_FACTOR
    shift = (col % PACK_FACTOR) * NUM_BITS
    packed = tl.load(
        packed_ptr + tid * packed_cols + packed_idx, mask=col_mask, other=0
    )
    # Shift the wanted field down, mask off everything above it (including any
    # sign-extended bits), then subtract the half-range offset to recover the
    # signed value.
    q = ((packed >> shift) & ((1 << NUM_BITS) - 1)) - (1 << (NUM_BITS - 1))

    if GROUP_SIZE == 0:  # channel: one scale per row
        scale = tl.load(scale_ptr + tid)
    else:  # group: one scale per (row, group)
        grp = col // GROUP_SIZE
        scale = tl.load(scale_ptr + tid * num_groups + grp, mask=col_mask, other=0.0)

    # Multiply in float32, then cast to the output buffer's element type.
    out = q.to(tl.float32) * scale.to(tl.float32)
    tl.store(
        out_ptr + row * hidden + col, out.to(out_ptr.dtype.element_ty), mask=col_mask
    )
def _dequant_gather_triton(
    ids: torch.Tensor,
    weight_packed: torch.Tensor,
    weight_scale: torch.Tensor,
    hidden: int,
    num_bits: int,
) -> torch.Tensor:
    """Launch the dequant-gather kernel for the given token ids.

    Returns a new ``(ids.numel(), hidden)`` tensor in ``weight_scale``'s dtype
    on ``weight_packed``'s device.
    """
    num_tokens = ids.numel()
    scale_cols = weight_scale.shape[1]

    # A single scale column is treated as channel-wise quantization, which the
    # kernel recognises by GROUP_SIZE == 0.
    if scale_cols == 1:
        kernel_group_size = 0
    else:
        kernel_group_size = hidden // scale_cols

    # One program per (token, column tile); tiles are capped at 1024 columns.
    tile = min(triton.next_power_of_2(hidden), 1024)
    launch_grid = (num_tokens, triton.cdiv(hidden, tile))

    result = torch.empty(
        num_tokens, hidden, dtype=weight_scale.dtype, device=weight_packed.device
    )
    _dequant_gather_kernel[launch_grid](
        ids,
        weight_packed,
        weight_scale,
        result,
        hidden,
        weight_packed.shape[1],
        scale_cols,
        NUM_BITS=num_bits,
        PACK_FACTOR=32 // num_bits,
        GROUP_SIZE=kernel_group_size,
        BLOCK=tile,
    )
    return result
class CompressedTensorsEmbeddingWNA16Int(QuantizeMethodBase):
    """Quantize method for an int32-packed, weight-only INT embedding table.

    Registers the packed weight, its scales and the checkpoint's
    ``weight_shape`` on the layer, and serves lookups by dequantizing only the
    rows that are requested.
    """

    def __init__(self, weight_quant: QuantizationArgs):
        """Capture bit width, strategy and group size from ``weight_quant``."""
        self.num_bits = weight_quant.num_bits
        # Number of quantized values stored in each int32 word.
        self.pack_factor = 32 // self.num_bits
        self.strategy = weight_quant.strategy
        self.group_size = weight_quant.group_size
        # Group scales are used only when the strategy is GROUP and a group
        # size is given; everything else gets one scale per row.
        self.is_group = (
            self.strategy == QuantizationStrategy.GROUP.value
            and self.group_size is not None
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight_packed``, ``weight_scale`` and ``weight_shape``.

        Raises:
            ValueError: if group quantization is used and the hidden size is
                not a multiple of the group size.
        """
        weight_loader = extra_weight_attrs["weight_loader"]
        # Embedding weight is [num_embeddings(vocab), embedding_dim(hidden)];
        # vocab is the output (partitioned) dim, hidden is the input dim.
        vocab_pp = sum(output_partition_sizes)
        hidden = input_size_per_partition
        # Kept on the layer because the packed weight no longer exposes it.
        layer.hidden_size = hidden

        weight_packed = PackedvLLMParameter(
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=self.pack_factor,
            weight_loader=weight_loader,
            data=torch.empty(vocab_pp, hidden // self.pack_factor, dtype=torch.int32),
        )

        if self.is_group:
            # Explicit check instead of `assert`, which is stripped under
            # `python -O` and would let a mis-sized scale tensor through.
            if hidden % self.group_size != 0:
                raise ValueError(
                    f"Embedding hidden size ({hidden}) must be divisible by "
                    f"group_size ({self.group_size})"
                )
            weight_scale = GroupQuantScaleParameter(
                output_dim=0,
                input_dim=1,
                weight_loader=weight_loader,
                data=torch.empty(
                    vocab_pp, hidden // self.group_size, dtype=params_dtype
                ),
            )
        else:
            weight_scale = ChannelQuantScaleParameter(
                output_dim=0,
                weight_loader=weight_loader,
                data=torch.empty(vocab_pp, 1, dtype=params_dtype),
            )

        # Two-element int64 tensor loaded from the checkpoint.
        weight_shape = BasevLLMParameter(
            data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
        )

        layer.register_parameter("weight_packed", weight_packed)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """No post-processing: the lookup kernel reads the packed layout."""
        pass

    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
        """Look up and dequantize the rows for the token ids in ``input_``.

        Returns a tensor shaped ``(*input_.shape, hidden)``.
        """
        # The kernel indexes ids linearly, so hand it a flat contiguous view.
        ids = input_.reshape(-1).contiguous()
        hidden = layer.hidden_size
        deq = _dequant_gather_triton(
            ids, layer.weight_packed, layer.weight_scale, hidden, self.num_bits
        )
        return deq.reshape(*input_.shape, hidden)

    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        """Always raise: this method does not implement a linear projection."""
        raise NotImplementedError(
            "CompressedTensorsEmbeddingWNA16Int supports embedding lookup only"
        )
@@ -10,11 +10,13 @@ from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a8_mxfp8 import CompressedTensorsW8A8Mxfp8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa8o8 import CompressedTensorsWNA8O8Int
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16
__all__ = [
"CompressedTensorsScheme",
"CompressedTensorsWNA16",
"CompressedTensorsWNA8O8Int",
"CompressedTensorsW8A16Fp8",
"CompressedTensorsW8A8Int8",
"CompressedTensorsW8A8Fp8",
@@ -0,0 +1,257 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Weight N-bit INT scheme with static INT8 input/output activation quant.
Handles compressed-tensors INT weight checkpoints that carry static per-tensor
INT8 ``input_activations`` and/or ``output_activations``. The activation quant is
reproduced as a float fake-quant on the layer input and output, around a
weight-only matmul, rather than a fused int8 GEMM.
"""
from collections.abc import Callable
import torch
from compressed_tensors.compressors.pack_quantized.helpers import pack_to_int32
from vllm.model_executor.kernels.linear import (
MPLinearLayerConfig,
choose_mp_linear_kernel,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks,
)
from vllm.model_executor.parameter import (
BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
ModelWeightParameter,
PackedvLLMParameter,
)
from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsWNA8O8Int", "fake_quant_static_int8"]
WNA8O8_SUPPORTED_TYPES_MAP = {
2: scalar_types.uint2b2,
4: scalar_types.uint4b8,
8: scalar_types.uint8b128,
}
def fake_quant_static_int8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Quantize ``x`` to the INT8 grid defined by ``scale`` and dequantize it.

    ``scale`` is cast to ``x``'s dtype first. Values are divided by the scale,
    rounded, clamped to ``[-128, 127]`` and multiplied back by the scale.
    """
    step = scale.to(x.dtype)
    levels = torch.round(x / step).clamp(-128.0, 127.0)
    return levels * step
class CompressedTensorsWNA8O8Int(CompressedTensorsScheme):
    """N-bit INT weight scheme with optional static INT8 activation fake-quant.

    The matmul itself is delegated to a weight-only mixed-precision kernel.
    Static input/output activation scales, when present, are applied as a
    float quantize-dequantize step before and after that matmul.
    """

    def __init__(
        self,
        num_bits: int,
        strategy: str,
        group_size: int | None = None,
        has_input_act: bool = False,
        has_output_act: bool = False,
        layer_name: str | None = None,
        quant_format: str = "pack-quantized",
    ):
        """Store the quantization settings.

        Raises:
            ValueError: if ``num_bits`` is not one of the supported widths.
        """
        # Validate before deriving anything from num_bits: the pack-factor
        # division below would otherwise fail with ZeroDivisionError for
        # num_bits == 0 instead of reporting the unsupported width.
        if num_bits not in WNA8O8_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits} for WNA8O8Int; "
                f"supported = {sorted(WNA8O8_SUPPORTED_TYPES_MAP)}"
            )

        self.num_bits = num_bits
        # Number of quantized values stored in each int32 word.
        self.pack_factor = 32 // num_bits
        self.strategy = strategy
        # -1 stands for "no group size given".
        self.group_size = -1 if group_size is None else group_size
        self.has_input_act = has_input_act
        self.has_output_act = has_output_act
        self.layer_name = layer_name
        # "pack-quantized" (sub-byte, int32-packed) or "int-quantized" (8-bit int8).
        self.quant_format = quant_format
        self.is_int_quantized = quant_format == "int-quantized"

        self.quant_type = WNA8O8_SUPPORTED_TYPES_MAP[num_bits]
        # Filled in by process_weights_after_loading().
        self._input_scale: torch.Tensor | None = None
        self._output_scale: torch.Tensor | None = None

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability this scheme accepts."""
        return 70

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_size: int,
        input_size: int,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Pick the mixed-precision kernel and register the layer parameters."""
        output_size_per_partition = sum(output_partition_sizes)
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        # Set for kernels' weight prep; also covers ParallelLMHead, which does
        # not set these in __init__.
        layer.output_partition_sizes = output_partition_sizes
        layer.params_dtype = params_dtype
        if not hasattr(layer, "has_bias"):
            layer.has_bias = False

        mp_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(
                input_size_per_partition,
                output_size_per_partition,
            ),
            weight_type=self.quant_type,
            act_type=params_dtype,  # activation quant applied externally (SRQ)
            group_size=self.group_size,
            zero_points=False,
            has_g_idx=False,
        )
        self.kernel = choose_mp_linear_kernel(mp_config)(
            mp_config,
            w_q_param_name="weight_packed",
            w_s_param_name="weight_scale",
        )

        self._register_weight(
            layer, input_size, input_size_per_partition, params_dtype, weight_loader
        )

    def _register_weight(
        self, layer, input_size, input_size_per_partition, params_dtype, weight_loader
    ):
        """Register weight, weight scale and activation-scale parameters.

        Raises:
            ValueError: if the scales are partitioned and the per-partition
                input size is not a multiple of the group size.
        """
        out = layer.output_size_per_partition

        if self.is_int_quantized:
            # Plain int8 weight; packed to the canonical int32 layout after load.
            layer.register_parameter(
                "weight",
                ModelWeightParameter(
                    data=torch.empty(out, input_size_per_partition, dtype=torch.int8),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                ),
            )
        else:
            layer.register_parameter(
                "weight_packed",
                PackedvLLMParameter(
                    input_dim=1,
                    output_dim=0,
                    packed_dim=1,
                    packed_factor=self.pack_factor,
                    weight_loader=weight_loader,
                    data=torch.empty(
                        out,
                        input_size_per_partition // self.pack_factor,
                        dtype=torch.int32,
                    ),
                ),
            )
            layer.register_parameter(
                "weight_shape",
                BasevLLMParameter(
                    data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
                ),
            )

        # Scale: per-output-channel, or per group along the input dim under TP.
        group_size = self.group_size if self.group_size != -1 else input_size
        partitioned = not marlin_repeat_scales_on_all_ranks(
            False, self.group_size, input_size != input_size_per_partition
        )
        scales = (input_size_per_partition if partitioned else input_size) // group_size
        scale_data = torch.empty(out, scales, dtype=params_dtype)
        if partitioned:
            # Explicit check instead of `assert`, which is stripped under
            # `python -O`.
            if input_size_per_partition % group_size != 0:
                raise ValueError(
                    f"input_size_per_partition ({input_size_per_partition}) must "
                    f"be divisible by the group size ({group_size})"
                )
            weight_scale = GroupQuantScaleParameter(
                data=scale_data, output_dim=0, input_dim=1, weight_loader=weight_loader
            )
        else:
            weight_scale = ChannelQuantScaleParameter(
                data=scale_data, output_dim=0, weight_loader=weight_loader
            )
        layer.register_parameter("weight_scale", weight_scale)

        for name, present in (
            ("input_scale", self.has_input_act),
            ("output_scale", self.has_output_act),
        ):
            if present:
                layer.register_parameter(
                    name,
                    BasevLLMParameter(
                        # Zero-initialised (not torch.empty): a scale that the
                        # checkpoint never provides then reads as 0.0 and is
                        # dropped as uncalibrated in _take_act_scale(), rather
                        # than being applied with uninitialised memory.
                        data=torch.zeros(1, dtype=torch.float32),
                        weight_loader=weight_loader,
                    ),
                )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Detach activation scales, normalise the weight, prep the kernel."""
        # Lift the static activation scales off the layer (applied externally) so
        # the kernel only sees weight tensors. Drop uncalibrated (zero) scales.
        self._input_scale = self._take_act_scale(layer, "input_scale")
        self._output_scale = self._take_act_scale(layer, "output_scale")
        self.has_input_act = self._input_scale is not None
        self.has_output_act = self._output_scale is not None

        if self.is_int_quantized:
            self._pack_int_quantized_weight(layer)

        self.kernel.process_weights_after_loading(layer)

    def _pack_int_quantized_weight(self, layer: torch.nn.Module) -> None:
        """Normalize an int-quantized (plain int8) weight to the canonical
        ``weight_packed`` int32 + ``weight_shape`` layout the MP kernels expect."""
        weight = layer.weight
        out_features, in_features = weight.shape
        packed = pack_to_int32(weight.data.contiguous(), self.num_bits)
        delattr(layer, "weight")

        # Loading has already finished, so the replacement parameters get a
        # loader that does nothing.
        def _noop_loader(*_, **__):
            return None

        layer.register_parameter(
            "weight_packed",
            PackedvLLMParameter(
                data=packed.contiguous(),
                input_dim=1,
                output_dim=0,
                packed_dim=1,
                packed_factor=self.pack_factor,
                weight_loader=_noop_loader,
            ),
        )
        layer.register_parameter(
            "weight_shape",
            BasevLLMParameter(
                data=torch.tensor([out_features, in_features], dtype=torch.int64),
                weight_loader=_noop_loader,
            ),
        )

    @staticmethod
    def _take_act_scale(layer, name: str) -> torch.Tensor | None:
        """Remove the scale parameter ``name`` from ``layer`` and return a copy.

        Returns ``None`` when the layer has no such parameter or when the
        scale's first element is exactly zero.
        """
        param = getattr(layer, name, None)
        if param is None:
            return None
        scale = param.data.clone()
        delattr(layer, name)
        return None if float(scale.reshape(-1)[0]) == 0.0 else scale

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
    ) -> torch.Tensor:
        """Run the weight-only matmul with optional input/output fake-quant."""
        if self.has_input_act:
            x = fake_quant_static_int8(x, self._input_scale)
        out = self.kernel.apply_weights(layer, x, bias)
        if self.has_output_act:
            out = fake_quant_static_int8(out, self._output_scale)
        return out
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from typing import Any
import regex as re
@@ -42,16 +43,57 @@ def humming_is_layer_skipped(config: dict[str, Any], prefix: str):
return False
def convert_linear_layer_to_humming_standard(
    layer: LinearBase, name_map: dict[str, str]
):
    """Re-register a linear layer's quantized params under humming's names.

    ``name_map`` maps each humming parameter name (``weight``,
    ``weight_scale``, ``zero_point``) to the attribute currently holding it on
    ``layer``. Each source attribute is removed and its tensor, brought into an
    output-dim-first layout, is set back on the layer under the humming name.
    """
    for humming_name, source_name in name_map.items():
        value = getattr(layer, source_name)
        delattr(layer, source_name)

        if humming_name == "weight":
            in_axis = getattr(value, "input_dim", 1)
            out_axis = getattr(value, "output_dim", 0)
            if (in_axis, out_axis) == (0, 1):
                # Stored input-major: flip to output-major.
                value = value.transpose(1, 0).contiguous()
            else:
                assert out_axis == 0 and in_axis == 1
            # Flatten to 2-D and reinterpret the storage as int32 words.
            value = value.view(value.size(0), -1).view(torch.int32)
        elif humming_name in ("weight_scale", "zero_point"):
            if getattr(value, "output_dim", 0) == 1:
                value = value.transpose(0, 1).contiguous()
            if value.ndim == 1:
                # Give 1-D tensors a trailing singleton column.
                value = value.unsqueeze(1)
            if humming_name == "zero_point":
                value = value.view(torch.int32)

        # Tensor ops above may have produced a plain tensor; wrap it again so
        # it is set on the layer as a (non-trainable) parameter.
        if not isinstance(value, torch.nn.Parameter):
            value = torch.nn.Parameter(value, requires_grad=False)
        setattr(layer, humming_name, value)
def prepare_humming_layer(layer: LinearBase, quant_config: dict):
weight_schema = BaseWeightSchema.from_config(quant_config)
input_schema = HummingInputSchema()
shape_k_stacks = [layer.input_size_per_partition]
# ReplicatedLinear has no TP partitioning and so does not set
# input_size_per_partition; for it that is just input_size.
input_size_per_partition = getattr(
layer, "input_size_per_partition", layer.input_size
)
shape_k_stacks = [input_size_per_partition]
shape_n_stacks = layer.output_partition_sizes
# Step 1: convert weight to humming standard format
weight_schema, tensors = weight_schema.convert_humming(
tensors=layer.named_parameters(),
tensors=dict(layer.named_parameters()),
shape_n_stacks=shape_n_stacks,
shape_k_stacks=shape_k_stacks,
param_dtype=layer.params_dtype,
@@ -63,23 +105,37 @@ def prepare_humming_layer(layer: LinearBase, quant_config: dict):
delattr(layer, name)
for name, tensor in tensors.items():
if isinstance(tensor, torch.nn.Parameter):
tensor = tensor.data
param = torch.nn.Parameter(tensor, requires_grad=False)
setattr(layer, name, param)
# Step 2: transform weight (humming standard format) for forwarding
HummingMethod.prepare_layer_meta(
layer=layer,
shape_n=layer.output_partition_sizes_sum,
shape_k=layer.input_size_per_partition,
shape_n=sum(layer.output_partition_sizes),
shape_k=input_size_per_partition,
weight_schema=weight_schema,
input_schema=input_schema,
pad_n_to_multiple=256,
pad_k_to_multiple=128,
has_bias=layer.has_bias,
torch_dtype=layer.param_dtype,
torch_dtype=layer.params_dtype,
)
HummingMethod.transform_humming_layer(layer)
if not hasattr(layer, "locks"):
device = layer.weight.device
locks = torch.zeros(1024, dtype=torch.int32, device=device)
layer.register_buffer("locks", locks)
compute_config = {
"use_batch_invariant": envs.VLLM_BATCH_INVARIANT,
"use_f16_accum": envs.VLLM_HUMMING_USE_F16_ACCUM,
"gemm_type": "dense",
}
layer.compute_config = json.dumps(compute_config)
def prepare_humming_moe_layer(layer: RoutedExperts, quant_config: dict):
+3 -3
View File
@@ -1057,7 +1057,7 @@ class Gemma4Model(nn.Module, EagleModelMixin):
"normalizer",
torch.tensor(
config.hidden_size**0.5,
dtype=self.embed_tokens.weight.dtype,
dtype=vllm_config.model_config.dtype,
),
persistent=False,
)
@@ -1111,7 +1111,7 @@ class Gemma4Model(nn.Module, EagleModelMixin):
)
self.hidden_states = torch.zeros(
(max_num_tokens, config.hidden_size),
dtype=self.embed_tokens.weight.dtype,
dtype=vllm_config.model_config.dtype,
device=device,
)
if (
@@ -1124,7 +1124,7 @@ class Gemma4Model(nn.Module, EagleModelMixin):
config.num_hidden_layers,
self.hidden_size_per_layer_input,
),
dtype=self.embed_tokens.weight.dtype,
dtype=vllm_config.model_config.dtype,
device=device,
)
else:
+12 -12
View File
@@ -1011,14 +1011,17 @@ class Gemma4ForConditionalGeneration(
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.model_dtype = vllm_config.model_config.dtype
# Only quantize towers when the quant method supports their
# dimensions. BNB/torchao handle arbitrary sizes; other methods
# (Marlin, FP8, …) require dimensions divisible by 64, which
# the vision tower (intermediate_size=4304) does not satisfy.
# TODO(mgoin): remove this by fixing kernel padding.
if quant_config and quant_config.get_name() in [
"bitsandbytes",
"torchao",
"compressed-tensors",
]:
tower_quant = quant_config
else:
@@ -1081,12 +1084,13 @@ class Gemma4ForConditionalGeneration(
# Some variants have hidden_size_per_layer_input=None (no PLE).
ple_dim = config.text_config.hidden_size_per_layer_input
if ple_dim is not None and ple_dim > 0:
embed = self.language_model.model.embed_tokens
self.per_layer_embeddings = torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.num_hidden_layers,
ple_dim,
device=self.language_model.model.embed_tokens.weight.device,
dtype=self.language_model.model.embed_tokens.weight.dtype,
device=next(embed.parameters()).device,
dtype=vllm_config.model_config.dtype,
)
else:
self.per_layer_embeddings = None
@@ -1246,7 +1250,6 @@ class Gemma4ForConditionalGeneration(
vt = self.vision_tower
vision_cfg = self.config.vision_config
pooling_k2 = vision_cfg.pooling_kernel_size**2
target_dtype = self.language_model.model.embed_tokens.weight.dtype
# Concurrent requests with different image resolutions may
# arrive as a list of per-image tensors, while same-resolution
@@ -1291,7 +1294,7 @@ class Gemma4ForConditionalGeneration(
pv_tensor,
pp_tensor,
pad_tensor,
).to(target_dtype)
).to(self.model_dtype)
encoder_outputs = vt.encoder(
inputs_embeds=inputs_embeds,
attention_mask=~pad_tensor,
@@ -1328,12 +1331,8 @@ class Gemma4ForConditionalGeneration(
all_valid_states[orig_idx] = valid_states
valid_lens[orig_idx] = valid_states.shape[0]
# Use embed_tokens dtype as compute dtype; embedding_projection.weight
# may be uint8 under BnB 4-bit, which would corrupt the cast.
target_dtype = self.language_model.model.embed_tokens.weight.dtype
# Project all images in a single batched call.
flat_valid_states = torch.cat(all_valid_states, dim=0).to(target_dtype)
flat_valid_states = torch.cat(all_valid_states, dim=0).to(self.model_dtype)
flat_proj_embs = self.embed_vision(
inputs_embeds=flat_valid_states.unsqueeze(0)
).squeeze(0)
@@ -1373,7 +1372,6 @@ class Gemma4ForConditionalGeneration(
vt = self.vision_tower
vision_cfg = self.config.vision_config
pooling_k2 = vision_cfg.pooling_kernel_size**2
target_dtype = self.language_model.model.embed_tokens.weight.dtype
if isinstance(frame_counts, torch.Tensor):
fc_list = frame_counts.tolist()
@@ -1405,7 +1403,7 @@ class Gemma4ForConditionalGeneration(
pv_chunk,
pp_chunk,
pad_chunk,
).to(target_dtype)
).to(self.model_dtype)
encoder_outputs = vt.encoder(
inputs_embeds=inputs_embeds,
attention_mask=~pad_chunk,
@@ -1440,7 +1438,9 @@ class Gemma4ForConditionalGeneration(
frame_valid_lens.append(valid_states.shape[0])
# Project all frames in a single batched call.
flat_valid_states = torch.cat(all_frame_valid_states, dim=0).to(target_dtype)
flat_valid_states = torch.cat(all_frame_valid_states, dim=0).to(
self.model_dtype
)
flat_proj_embs = self.embed_vision(
inputs_embeds=flat_valid_states.unsqueeze(0)
).squeeze(0)
+3 -2
View File
@@ -307,12 +307,13 @@ class Gemma4UnifiedForConditionalGeneration(Gemma4ForConditionalGeneration):
None,
)
if ple_dim is not None and ple_dim > 0:
embed = self.language_model.model.embed_tokens
self.per_layer_embeddings = torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.num_hidden_layers,
ple_dim,
device=self.language_model.model.embed_tokens.weight.device,
dtype=self.language_model.model.embed_tokens.weight.dtype,
device=next(embed.parameters()).device,
dtype=vllm_config.model_config.dtype,
)
else:
self.per_layer_embeddings = None