mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Quant] Support compressed-tensors WNA8O8Int linears and WNInt embeddings (#44340)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -38,7 +38,7 @@ pyyaml
|
||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||
setuptools>=77.0.3,<81.0.0; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
||||
einops # Required for Qwen2-VL.
|
||||
compressed-tensors == 0.15.0.1 # required for compressed-tensors
|
||||
compressed-tensors == 0.17.0 # required for compressed-tensors
|
||||
depyf==0.20.0 # required for profiling and debugging with compilation config
|
||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||
watchfiles # required for http server to monitor the updates of TLS files
|
||||
|
||||
@@ -28,4 +28,4 @@ quack-kernels>=0.3.3
|
||||
tokenspeed-mla==0.1.2
|
||||
|
||||
# Humming kernels for quantization gemm
|
||||
humming-kernels[cu13]==0.1.2
|
||||
humming-kernels[cu13]==0.1.4
|
||||
|
||||
@@ -143,7 +143,7 @@ colorful==0.5.8
|
||||
# via ray
|
||||
colorlog==6.10.1
|
||||
# via optuna
|
||||
compressed-tensors==0.15.0.1
|
||||
compressed-tensors==0.17.0
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# -r requirements/test/../common.txt
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the Triton dequant-gather kernel used by
|
||||
``CompressedTensorsEmbeddingWNA16Int`` (quantized embedding lookup)."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from compressed_tensors.compressors.pack_quantized.helpers import unpack_from_int32
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_embedding import ( # noqa: E501
|
||||
_dequant_gather_triton,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def _dequant_gather_torch(
|
||||
ids: torch.Tensor,
|
||||
weight_packed: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
hidden: int,
|
||||
num_bits: int,
|
||||
) -> torch.Tensor:
|
||||
"""Reference: gather packed rows by id, unpack int32-packed INT, dequant."""
|
||||
n = ids.shape[0]
|
||||
int8 = unpack_from_int32(weight_packed[ids], num_bits, torch.Size([n, hidden]))
|
||||
scale_rows = weight_scale[ids]
|
||||
w = int8.to(scale_rows.dtype)
|
||||
if scale_rows.shape[1] == 1:
|
||||
return w * scale_rows
|
||||
ng = scale_rows.shape[1]
|
||||
return (w.view(n, ng, hidden // ng) * scale_rows.unsqueeze(-1)).view(n, hidden)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="Triton dequant kernel requires CUDA"
|
||||
)
|
||||
@pytest.mark.parametrize("num_bits", [2, 4, 8])
|
||||
@pytest.mark.parametrize("group_size", [0, 256]) # 0 -> channel
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("num_ids", [1, 17, 4096])
|
||||
def test_dequant_gather(num_bits, group_size, dtype, num_ids):
|
||||
torch.manual_seed(0)
|
||||
device = "cuda"
|
||||
vocab, hidden = 1000, 2048
|
||||
pack_factor = 32 // num_bits
|
||||
|
||||
# Random full-range int32 packed weights (covers the sign bit -> exercises the
|
||||
# arithmetic-shift + mask unpack path).
|
||||
weight_packed = torch.randint(
|
||||
-(2**31),
|
||||
2**31,
|
||||
(vocab, hidden // pack_factor),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
num_groups = 1 if group_size == 0 else hidden // group_size
|
||||
weight_scale = torch.rand(vocab, num_groups, dtype=dtype, device=device) + 0.01
|
||||
|
||||
ids = torch.randint(0, vocab, (num_ids,), dtype=torch.long, device=device)
|
||||
|
||||
out = _dequant_gather_triton(ids, weight_packed, weight_scale, hidden, num_bits)
|
||||
ref = _dequant_gather_torch(ids, weight_packed, weight_scale, hidden, num_bits)
|
||||
|
||||
assert out.shape == (num_ids, hidden)
|
||||
assert out.dtype == dtype
|
||||
torch.testing.assert_close(out, ref)
|
||||
@@ -45,6 +45,9 @@ from vllm.model_executor.kernels.linear.mixed_precision.dynamic_4bit import (
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.exllama import (
|
||||
ExllamaLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.humming import (
|
||||
HummingLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.machete import (
|
||||
MacheteLinearKernel,
|
||||
)
|
||||
@@ -345,6 +348,7 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
|
||||
MacheteLinearKernel,
|
||||
AllSparkLinearKernel,
|
||||
MarlinLinearKernel,
|
||||
HummingLinearKernel,
|
||||
ConchLinearKernel,
|
||||
ExllamaLinearKernel,
|
||||
TritonW4A16LinearKernel,
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Humming GEMM as a mixed-precision WNA16Int linear kernel."""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import _has_module
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class HummingLinearKernel(MPLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "Humming is only supported on CUDA"
|
||||
if not _has_module("humming"):
|
||||
return False, "Humming is not installed"
|
||||
if c.has_g_idx:
|
||||
return False, "Humming does not support act-order (g_idx)"
|
||||
if c.zero_points:
|
||||
return False, "Humming linear kernel only supports symmetric weights"
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
from vllm.model_executor.layers.quantization.utils.humming_utils import (
|
||||
convert_linear_layer_to_humming_standard,
|
||||
prepare_humming_layer,
|
||||
)
|
||||
|
||||
name_map = {"weight": self.w_q_name, "weight_scale": self.w_s_name}
|
||||
group_size = self.config.group_size
|
||||
quant_config = {
|
||||
"quant_method": "humming",
|
||||
"dtype": "int" + str(self.config.weight_type.size_bits),
|
||||
"group_size": 0 if group_size == -1 else group_size,
|
||||
}
|
||||
|
||||
convert_linear_layer_to_humming_standard(layer=layer, name_map=name_map)
|
||||
prepare_humming_layer(layer, quant_config)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
from humming.layer import HummingMethod
|
||||
|
||||
flatten_inputs = x.view(-1, x.size(-1))
|
||||
output = HummingMethod.forward_layer(
|
||||
layer=layer,
|
||||
inputs=flatten_inputs,
|
||||
compute_config=layer.compute_config,
|
||||
)
|
||||
return output.view(*x.shape[:-1], output.size(-1))
|
||||
@@ -30,6 +30,9 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_embedding import ( # noqa: E501
|
||||
CompressedTensorsEmbeddingWNA16Int,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
|
||||
CompressedTensorsMoEMethod,
|
||||
)
|
||||
@@ -44,6 +47,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A8Mxfp8,
|
||||
CompressedTensorsW8A16Fp8,
|
||||
CompressedTensorsWNA8O8Int,
|
||||
CompressedTensorsWNA16,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
|
||||
@@ -56,7 +60,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
should_ignore_layer,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -177,6 +184,24 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
layer.scheme = quant_scheme
|
||||
return CompressedTensorsLinearMethod(self)
|
||||
|
||||
# ParallelLMHead subclasses VocabParallelEmbedding but is handled above as
|
||||
# a linear; only true embedding lookups land here.
|
||||
if isinstance(layer, VocabParallelEmbedding):
|
||||
scheme_dict = self.get_scheme_dict(layer, layer_name=prefix)
|
||||
weight_quant = scheme_dict.get("weights") if scheme_dict else None
|
||||
if weight_quant is None:
|
||||
return None # unquantized embedding
|
||||
if not (
|
||||
isinstance(weight_quant, QuantizationArgs)
|
||||
and self._is_wNa16_group_channel(weight_quant, None)
|
||||
and weight_quant.type == QuantizationType.INT
|
||||
):
|
||||
raise ValueError(
|
||||
"compressed-tensors embeddings only support weight-only INT "
|
||||
f"group/channel (WNA16) quantization, got: {weight_quant}"
|
||||
)
|
||||
return CompressedTensorsEmbeddingWNA16Int(weight_quant)
|
||||
|
||||
if isinstance(layer, Attention):
|
||||
return CompressedTensorsKVCacheMethod(self)
|
||||
if isinstance(layer, RoutedExperts):
|
||||
@@ -324,6 +349,15 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
quant_config.get("input_activations")
|
||||
)
|
||||
)
|
||||
|
||||
# Static output-activation quant is applied as a float fake-quant
|
||||
# on the layer output; capture it when present.
|
||||
target_scheme_map[target]["output_activations"] = None
|
||||
output_activations = quant_config.get("output_activations")
|
||||
if output_activations:
|
||||
target_scheme_map[target]["output_activations"] = (
|
||||
QuantizationArgs.model_validate(output_activations)
|
||||
)
|
||||
return target_scheme_map
|
||||
|
||||
@classmethod
|
||||
@@ -604,10 +638,56 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
return is_channel_group and input_quant_none and is_static
|
||||
|
||||
@staticmethod
|
||||
def _is_wNa8o8_int(
|
||||
weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs | None,
|
||||
output_quant: QuantizationArgs | None,
|
||||
format: str | None,
|
||||
) -> bool:
|
||||
"""Weight N-bit INT (pack-quantized for sub-byte, int-quantized for 8-bit)
|
||||
with static per-tensor INT8 input/output activation quant, applied as a float
|
||||
fake-quant around a weight-only matmul."""
|
||||
is_int_pack_format = format in (
|
||||
CompressionFormat.pack_quantized.value,
|
||||
CompressionFormat.int_quantized.value,
|
||||
)
|
||||
is_channel_group = weight_quant.strategy in (
|
||||
QuantizationStrategy.CHANNEL.value,
|
||||
QuantizationStrategy.GROUP.value,
|
||||
)
|
||||
is_static_int = (
|
||||
weight_quant.type == QuantizationType.INT and not weight_quant.dynamic
|
||||
)
|
||||
is_intN_weight = is_static_int and is_channel_group and is_int_pack_format
|
||||
is_static_int8_in = (
|
||||
input_quant is not None
|
||||
and input_quant.type == QuantizationType.INT
|
||||
and input_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
and input_quant.num_bits == 8
|
||||
and not input_quant.dynamic
|
||||
)
|
||||
is_static_int8_out = (
|
||||
output_quant is not None
|
||||
and output_quant.type == QuantizationType.INT
|
||||
and output_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
and output_quant.num_bits == 8
|
||||
and not output_quant.dynamic
|
||||
)
|
||||
# Static int8-activation layers, plus sub-byte weight-only layers (e.g.
|
||||
# 2-bit lm_head) that marlin-backed WNA16 cannot serve. Standard 4/8-bit
|
||||
# weight-only (no activations) falls through to WNA16.
|
||||
is_subbyte_weight_only = weight_quant.num_bits not in WNA16_SUPPORTED_BITS
|
||||
needs_wNa8o8 = is_intN_weight and (
|
||||
(is_static_int8_in and is_static_int8_out) or is_subbyte_weight_only
|
||||
)
|
||||
return needs_wNa8o8
|
||||
|
||||
def _get_scheme_from_parts(
|
||||
self,
|
||||
weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs,
|
||||
output_quant: QuantizationArgs | None = None,
|
||||
format: str | None = None,
|
||||
layer_name: str | None = None,
|
||||
) -> "CompressedTensorsScheme":
|
||||
@@ -641,6 +721,19 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
actorder=weight_quant.actorder,
|
||||
)
|
||||
|
||||
# Must come before the WNA16 check; standard 4/8-bit weight-only (no
|
||||
# output-activation scale) still falls through to WNA16.
|
||||
if self._is_wNa8o8_int(weight_quant, input_quant, output_quant, format):
|
||||
return CompressedTensorsWNA8O8Int(
|
||||
num_bits=weight_quant.num_bits,
|
||||
strategy=weight_quant.strategy,
|
||||
group_size=weight_quant.group_size,
|
||||
has_input_act=input_quant is not None,
|
||||
has_output_act=output_quant is not None,
|
||||
layer_name=layer_name,
|
||||
quant_format=format,
|
||||
)
|
||||
|
||||
if (
|
||||
self._is_wNa16_group_channel(weight_quant, input_quant)
|
||||
and (format == CompressionFormat.pack_quantized.value)
|
||||
@@ -708,7 +801,10 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
input_symmetric=input_quant.symmetric,
|
||||
)
|
||||
|
||||
raise NotImplementedError("No compressed-tensors compatible scheme was found.")
|
||||
raise NotImplementedError(
|
||||
f"No compressed-tensors compatible scheme was found for {layer_name=}, "
|
||||
f"{weight_quant=}, {input_quant=}, {output_quant=}, {format=}"
|
||||
)
|
||||
|
||||
def get_scheme(
|
||||
self, layer: torch.nn.Module, layer_name: str | None = None
|
||||
@@ -731,10 +827,12 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
weight_quant = None
|
||||
input_quant = None
|
||||
output_quant = None
|
||||
format = None
|
||||
if scheme_dict:
|
||||
weight_quant = scheme_dict.get("weights")
|
||||
input_quant = scheme_dict.get("input_activations")
|
||||
output_quant = scheme_dict.get("output_activations")
|
||||
format = scheme_dict.get("format")
|
||||
|
||||
if weight_quant is None:
|
||||
@@ -746,6 +844,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
scheme = self._get_scheme_from_parts( # type: ignore
|
||||
weight_quant=weight_quant,
|
||||
input_quant=input_quant,
|
||||
output_quant=output_quant,
|
||||
format=format,
|
||||
layer_name=layer_name,
|
||||
)
|
||||
|
||||
+170
@@ -0,0 +1,170 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Quantized embedding method for compressed-tensors.
|
||||
|
||||
Adds dequant-on-lookup support for a pack-quantized ``VocabParallelEmbedding``
|
||||
(2-8 bit INT, channel- or group-quantized). Only the gathered token rows are
|
||||
unpacked and dequantized, so the packed weight is never densified.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
|
||||
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
__all__ = ["CompressedTensorsEmbeddingWNA16Int"]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _dequant_gather_kernel(
|
||||
ids_ptr,
|
||||
packed_ptr,
|
||||
scale_ptr,
|
||||
out_ptr,
|
||||
hidden,
|
||||
packed_cols,
|
||||
num_groups,
|
||||
NUM_BITS: tl.constexpr,
|
||||
PACK_FACTOR: tl.constexpr,
|
||||
GROUP_SIZE: tl.constexpr,
|
||||
BLOCK: tl.constexpr,
|
||||
):
|
||||
"""Gather embedding rows by token id, unpack int32-packed INT weights, and
|
||||
dequantize to ``out`` dtype in one pass (no int8 intermediate)."""
|
||||
row = tl.program_id(0)
|
||||
col = tl.program_id(1) * BLOCK + tl.arange(0, BLOCK)
|
||||
col_mask = col < hidden
|
||||
tid = tl.load(ids_ptr + row).to(tl.int64)
|
||||
|
||||
packed_idx = col // PACK_FACTOR
|
||||
shift = (col % PACK_FACTOR) * NUM_BITS
|
||||
packed = tl.load(
|
||||
packed_ptr + tid * packed_cols + packed_idx, mask=col_mask, other=0
|
||||
)
|
||||
q = ((packed >> shift) & ((1 << NUM_BITS) - 1)) - (1 << (NUM_BITS - 1))
|
||||
|
||||
if GROUP_SIZE == 0: # channel: one scale per row
|
||||
scale = tl.load(scale_ptr + tid)
|
||||
else: # group: one scale per (row, group)
|
||||
grp = col // GROUP_SIZE
|
||||
scale = tl.load(scale_ptr + tid * num_groups + grp, mask=col_mask, other=0.0)
|
||||
|
||||
out = q.to(tl.float32) * scale.to(tl.float32)
|
||||
tl.store(
|
||||
out_ptr + row * hidden + col, out.to(out_ptr.dtype.element_ty), mask=col_mask
|
||||
)
|
||||
|
||||
|
||||
def _dequant_gather_triton(
|
||||
ids: torch.Tensor,
|
||||
weight_packed: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
hidden: int,
|
||||
num_bits: int,
|
||||
) -> torch.Tensor:
|
||||
n = ids.numel()
|
||||
out = torch.empty(n, hidden, dtype=weight_scale.dtype, device=weight_packed.device)
|
||||
num_groups = weight_scale.shape[1]
|
||||
group_size = 0 if num_groups == 1 else hidden // num_groups
|
||||
block = min(triton.next_power_of_2(hidden), 1024)
|
||||
grid = (n, triton.cdiv(hidden, block))
|
||||
_dequant_gather_kernel[grid](
|
||||
ids,
|
||||
weight_packed,
|
||||
weight_scale,
|
||||
out,
|
||||
hidden,
|
||||
weight_packed.shape[1],
|
||||
num_groups,
|
||||
NUM_BITS=num_bits,
|
||||
PACK_FACTOR=32 // num_bits,
|
||||
GROUP_SIZE=group_size,
|
||||
BLOCK=block,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
class CompressedTensorsEmbeddingWNA16Int(QuantizeMethodBase):
|
||||
def __init__(self, weight_quant: QuantizationArgs):
|
||||
self.num_bits = weight_quant.num_bits
|
||||
self.pack_factor = 32 // self.num_bits
|
||||
self.strategy = weight_quant.strategy
|
||||
self.group_size = weight_quant.group_size
|
||||
self.is_group = (
|
||||
self.strategy == QuantizationStrategy.GROUP.value
|
||||
and self.group_size is not None
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
weight_loader = extra_weight_attrs["weight_loader"]
|
||||
# Embedding weight is [num_embeddings(vocab), embedding_dim(hidden)];
|
||||
# vocab is the output (partitioned) dim, hidden is the input dim.
|
||||
vocab_pp = sum(output_partition_sizes)
|
||||
hidden = input_size_per_partition
|
||||
layer.hidden_size = hidden
|
||||
|
||||
weight_packed = PackedvLLMParameter(
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=1,
|
||||
packed_factor=self.pack_factor,
|
||||
weight_loader=weight_loader,
|
||||
data=torch.empty(vocab_pp, hidden // self.pack_factor, dtype=torch.int32),
|
||||
)
|
||||
|
||||
if self.is_group:
|
||||
assert hidden % self.group_size == 0
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
output_dim=0,
|
||||
input_dim=1,
|
||||
weight_loader=weight_loader,
|
||||
data=torch.empty(
|
||||
vocab_pp, hidden // self.group_size, dtype=params_dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
data=torch.empty(vocab_pp, 1, dtype=params_dtype),
|
||||
)
|
||||
|
||||
weight_shape = BasevLLMParameter(
|
||||
data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_packed", weight_packed)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
layer.register_parameter("weight_shape", weight_shape)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
|
||||
ids = input_.reshape(-1).contiguous()
|
||||
hidden = layer.hidden_size
|
||||
deq = _dequant_gather_triton(
|
||||
ids, layer.weight_packed, layer.weight_scale, hidden, self.num_bits
|
||||
)
|
||||
return deq.reshape(*input_.shape, hidden)
|
||||
|
||||
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError(
|
||||
"CompressedTensorsEmbeddingWNA16Int supports embedding lookup only"
|
||||
)
|
||||
@@ -10,11 +10,13 @@ from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
||||
from .compressed_tensors_w8a8_mxfp8 import CompressedTensorsW8A8Mxfp8
|
||||
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
||||
from .compressed_tensors_wNa8o8 import CompressedTensorsWNA8O8Int
|
||||
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16
|
||||
|
||||
__all__ = [
|
||||
"CompressedTensorsScheme",
|
||||
"CompressedTensorsWNA16",
|
||||
"CompressedTensorsWNA8O8Int",
|
||||
"CompressedTensorsW8A16Fp8",
|
||||
"CompressedTensorsW8A8Int8",
|
||||
"CompressedTensorsW8A8Fp8",
|
||||
|
||||
+257
@@ -0,0 +1,257 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Weight N-bit INT scheme with static INT8 input/output activation quant.
|
||||
|
||||
Handles compressed-tensors INT weight checkpoints that carry static per-tensor
|
||||
INT8 ``input_activations`` and/or ``output_activations``. The activation quant is
|
||||
reproduced as a float fake-quant on the layer input and output, around a
|
||||
weight-only matmul, rather than a fused int8 GEMM.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.compressors.pack_quantized.helpers import pack_to_int32
|
||||
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
MPLinearLayerConfig,
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_repeat_scales_on_all_ranks,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
__all__ = ["CompressedTensorsWNA8O8Int", "fake_quant_static_int8"]
|
||||
|
||||
WNA8O8_SUPPORTED_TYPES_MAP = {
|
||||
2: scalar_types.uint2b2,
|
||||
4: scalar_types.uint4b8,
|
||||
8: scalar_types.uint8b128,
|
||||
}
|
||||
|
||||
|
||||
def fake_quant_static_int8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
"""Static per-tensor symmetric INT8 quantize-dequantize, in x's dtype."""
|
||||
scale = scale.to(x.dtype)
|
||||
q = torch.clamp(torch.round(x / scale), -128.0, 127.0)
|
||||
return q * scale
|
||||
|
||||
|
||||
class CompressedTensorsWNA8O8Int(CompressedTensorsScheme):
|
||||
def __init__(
|
||||
self,
|
||||
num_bits: int,
|
||||
strategy: str,
|
||||
group_size: int | None = None,
|
||||
has_input_act: bool = False,
|
||||
has_output_act: bool = False,
|
||||
layer_name: str | None = None,
|
||||
quant_format: str = "pack-quantized",
|
||||
):
|
||||
self.num_bits = num_bits
|
||||
self.pack_factor = 32 // num_bits
|
||||
self.strategy = strategy
|
||||
self.group_size = -1 if group_size is None else group_size
|
||||
self.has_input_act = has_input_act
|
||||
self.has_output_act = has_output_act
|
||||
self.layer_name = layer_name
|
||||
# "pack-quantized" (sub-byte, int32-packed) or "int-quantized" (8-bit int8).
|
||||
self.quant_format = quant_format
|
||||
self.is_int_quantized = quant_format == "int-quantized"
|
||||
if num_bits not in WNA8O8_SUPPORTED_TYPES_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported num_bits = {num_bits} for WNA8O8Int; "
|
||||
f"supported = {sorted(WNA8O8_SUPPORTED_TYPES_MAP)}"
|
||||
)
|
||||
self.quant_type = WNA8O8_SUPPORTED_TYPES_MAP[num_bits]
|
||||
self._input_scale: torch.Tensor | None = None
|
||||
self._output_scale: torch.Tensor | None = None
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_size: int,
|
||||
input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
# Set for kernels' weight prep; also covers ParallelLMHead, which does
|
||||
# not set these in __init__.
|
||||
layer.output_partition_sizes = output_partition_sizes
|
||||
layer.params_dtype = params_dtype
|
||||
if not hasattr(layer, "has_bias"):
|
||||
layer.has_bias = False
|
||||
|
||||
mp_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
),
|
||||
weight_type=self.quant_type,
|
||||
act_type=params_dtype, # activation quant applied externally (SRQ)
|
||||
group_size=self.group_size,
|
||||
zero_points=False,
|
||||
has_g_idx=False,
|
||||
)
|
||||
self.kernel = choose_mp_linear_kernel(mp_config)(
|
||||
mp_config,
|
||||
w_q_param_name="weight_packed",
|
||||
w_s_param_name="weight_scale",
|
||||
)
|
||||
|
||||
self._register_weight(
|
||||
layer, input_size, input_size_per_partition, params_dtype, weight_loader
|
||||
)
|
||||
|
||||
def _register_weight(
|
||||
self, layer, input_size, input_size_per_partition, params_dtype, weight_loader
|
||||
):
|
||||
out = layer.output_size_per_partition
|
||||
if self.is_int_quantized:
|
||||
# Plain int8 weight; packed to the canonical int32 layout after load.
|
||||
layer.register_parameter(
|
||||
"weight",
|
||||
ModelWeightParameter(
|
||||
data=torch.empty(out, input_size_per_partition, dtype=torch.int8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
),
|
||||
)
|
||||
else:
|
||||
layer.register_parameter(
|
||||
"weight_packed",
|
||||
PackedvLLMParameter(
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=1,
|
||||
packed_factor=self.pack_factor,
|
||||
weight_loader=weight_loader,
|
||||
data=torch.empty(
|
||||
out,
|
||||
input_size_per_partition // self.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
),
|
||||
)
|
||||
layer.register_parameter(
|
||||
"weight_shape",
|
||||
BasevLLMParameter(
|
||||
data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
|
||||
),
|
||||
)
|
||||
|
||||
# Scale: per-output-channel, or per group along the input dim under TP.
|
||||
group_size = self.group_size if self.group_size != -1 else input_size
|
||||
partitioned = not marlin_repeat_scales_on_all_ranks(
|
||||
False, self.group_size, input_size != input_size_per_partition
|
||||
)
|
||||
scales = (input_size_per_partition if partitioned else input_size) // group_size
|
||||
scale_data = torch.empty(out, scales, dtype=params_dtype)
|
||||
if partitioned:
|
||||
assert input_size_per_partition % group_size == 0
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
data=scale_data, output_dim=0, input_dim=1, weight_loader=weight_loader
|
||||
)
|
||||
else:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=scale_data, output_dim=0, weight_loader=weight_loader
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
for name, present in (
|
||||
("input_scale", self.has_input_act),
|
||||
("output_scale", self.has_output_act),
|
||||
):
|
||||
if present:
|
||||
layer.register_parameter(
|
||||
name,
|
||||
BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
),
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Lift the static activation scales off the layer (applied externally) so
|
||||
# the kernel only sees weight tensors. Drop uncalibrated (zero) scales.
|
||||
self._input_scale = self._take_act_scale(layer, "input_scale")
|
||||
self._output_scale = self._take_act_scale(layer, "output_scale")
|
||||
self.has_input_act = self._input_scale is not None
|
||||
self.has_output_act = self._output_scale is not None
|
||||
|
||||
if self.is_int_quantized:
|
||||
self._pack_int_quantized_weight(layer)
|
||||
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def _pack_int_quantized_weight(self, layer: torch.nn.Module) -> None:
|
||||
"""Normalize an int-quantized (plain int8) weight to the canonical
|
||||
``weight_packed`` int32 + ``weight_shape`` layout the MP kernels expect."""
|
||||
weight = layer.weight
|
||||
out_features, in_features = weight.shape
|
||||
packed = pack_to_int32(weight.data.contiguous(), self.num_bits)
|
||||
delattr(layer, "weight")
|
||||
|
||||
def _noop_loader(*_, **__):
|
||||
return None
|
||||
|
||||
layer.register_parameter(
|
||||
"weight_packed",
|
||||
PackedvLLMParameter(
|
||||
data=packed.contiguous(),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=1,
|
||||
packed_factor=self.pack_factor,
|
||||
weight_loader=_noop_loader,
|
||||
),
|
||||
)
|
||||
layer.register_parameter(
|
||||
"weight_shape",
|
||||
BasevLLMParameter(
|
||||
data=torch.tensor([out_features, in_features], dtype=torch.int64),
|
||||
weight_loader=_noop_loader,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _take_act_scale(layer, name: str) -> torch.Tensor | None:
|
||||
param = getattr(layer, name, None)
|
||||
if param is None:
|
||||
return None
|
||||
scale = param.data.clone()
|
||||
delattr(layer, name)
|
||||
return None if float(scale.reshape(-1)[0]) == 0.0 else scale
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
if self.has_input_act:
|
||||
x = fake_quant_static_int8(x, self._input_scale)
|
||||
out = self.kernel.apply_weights(layer, x, bias)
|
||||
if self.has_output_act:
|
||||
out = fake_quant_static_int8(out, self._output_scale)
|
||||
return out
|
||||
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
@@ -42,16 +43,57 @@ def humming_is_layer_skipped(config: dict[str, Any], prefix: str):
|
||||
return False
|
||||
|
||||
|
||||
def convert_linear_layer_to_humming_standard(
|
||||
layer: LinearBase, name_map: dict[str, str]
|
||||
):
|
||||
"""Rename/reshape a linear layer's quantized params (the canonical MPLinear
|
||||
layout: ``weight_packed`` int32 + ``weight_scale``) into the parameter names
|
||||
and layout humming's weight schema expects (``weight`` / ``weight_scale``)."""
|
||||
for name, checkpoint_name in name_map.items():
|
||||
tensor = getattr(layer, checkpoint_name)
|
||||
delattr(layer, checkpoint_name)
|
||||
|
||||
if name == "weight":
|
||||
input_dim = getattr(tensor, "input_dim", 1)
|
||||
output_dim = getattr(tensor, "output_dim", 0)
|
||||
|
||||
if input_dim == 0 and output_dim == 1:
|
||||
tensor = tensor.transpose(1, 0).contiguous()
|
||||
else:
|
||||
assert output_dim == 0 and input_dim == 1
|
||||
|
||||
tensor = tensor.view(tensor.size(0), -1).view(torch.int32)
|
||||
elif name in ["weight_scale", "zero_point"]:
|
||||
if getattr(tensor, "output_dim", 0) == 1:
|
||||
tensor = tensor.transpose(0, 1).contiguous()
|
||||
if tensor.ndim == 1:
|
||||
tensor = tensor.unsqueeze(1)
|
||||
|
||||
tensor = tensor.view(torch.int32) if name == "zero_point" else tensor
|
||||
|
||||
if isinstance(tensor, torch.nn.Parameter):
|
||||
param = tensor
|
||||
else:
|
||||
param = torch.nn.Parameter(tensor, requires_grad=False)
|
||||
|
||||
setattr(layer, name, param)
|
||||
|
||||
|
||||
def prepare_humming_layer(layer: LinearBase, quant_config: dict):
|
||||
weight_schema = BaseWeightSchema.from_config(quant_config)
|
||||
input_schema = HummingInputSchema()
|
||||
|
||||
shape_k_stacks = [layer.input_size_per_partition]
|
||||
# ReplicatedLinear has no TP partitioning and so does not set
|
||||
# input_size_per_partition; for it that is just input_size.
|
||||
input_size_per_partition = getattr(
|
||||
layer, "input_size_per_partition", layer.input_size
|
||||
)
|
||||
shape_k_stacks = [input_size_per_partition]
|
||||
shape_n_stacks = layer.output_partition_sizes
|
||||
|
||||
# Step 1: convert weight to humming standard format
|
||||
weight_schema, tensors = weight_schema.convert_humming(
|
||||
tensors=layer.named_parameters(),
|
||||
tensors=dict(layer.named_parameters()),
|
||||
shape_n_stacks=shape_n_stacks,
|
||||
shape_k_stacks=shape_k_stacks,
|
||||
param_dtype=layer.params_dtype,
|
||||
@@ -63,23 +105,37 @@ def prepare_humming_layer(layer: LinearBase, quant_config: dict):
|
||||
delattr(layer, name)
|
||||
|
||||
for name, tensor in tensors.items():
|
||||
if isinstance(tensor, torch.nn.Parameter):
|
||||
tensor = tensor.data
|
||||
param = torch.nn.Parameter(tensor, requires_grad=False)
|
||||
setattr(layer, name, param)
|
||||
|
||||
# Step 2: transform weight (humming standard format) for forwarding
|
||||
HummingMethod.prepare_layer_meta(
|
||||
layer=layer,
|
||||
shape_n=layer.output_partition_sizes_sum,
|
||||
shape_k=layer.input_size_per_partition,
|
||||
shape_n=sum(layer.output_partition_sizes),
|
||||
shape_k=input_size_per_partition,
|
||||
weight_schema=weight_schema,
|
||||
input_schema=input_schema,
|
||||
pad_n_to_multiple=256,
|
||||
pad_k_to_multiple=128,
|
||||
has_bias=layer.has_bias,
|
||||
torch_dtype=layer.param_dtype,
|
||||
torch_dtype=layer.params_dtype,
|
||||
)
|
||||
|
||||
HummingMethod.transform_humming_layer(layer)
|
||||
if not hasattr(layer, "locks"):
|
||||
device = layer.weight.device
|
||||
locks = torch.zeros(1024, dtype=torch.int32, device=device)
|
||||
layer.register_buffer("locks", locks)
|
||||
|
||||
compute_config = {
|
||||
"use_batch_invariant": envs.VLLM_BATCH_INVARIANT,
|
||||
"use_f16_accum": envs.VLLM_HUMMING_USE_F16_ACCUM,
|
||||
"gemm_type": "dense",
|
||||
}
|
||||
|
||||
layer.compute_config = json.dumps(compute_config)
|
||||
|
||||
|
||||
def prepare_humming_moe_layer(layer: RoutedExperts, quant_config: dict):
|
||||
|
||||
@@ -1057,7 +1057,7 @@ class Gemma4Model(nn.Module, EagleModelMixin):
|
||||
"normalizer",
|
||||
torch.tensor(
|
||||
config.hidden_size**0.5,
|
||||
dtype=self.embed_tokens.weight.dtype,
|
||||
dtype=vllm_config.model_config.dtype,
|
||||
),
|
||||
persistent=False,
|
||||
)
|
||||
@@ -1111,7 +1111,7 @@ class Gemma4Model(nn.Module, EagleModelMixin):
|
||||
)
|
||||
self.hidden_states = torch.zeros(
|
||||
(max_num_tokens, config.hidden_size),
|
||||
dtype=self.embed_tokens.weight.dtype,
|
||||
dtype=vllm_config.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
if (
|
||||
@@ -1124,7 +1124,7 @@ class Gemma4Model(nn.Module, EagleModelMixin):
|
||||
config.num_hidden_layers,
|
||||
self.hidden_size_per_layer_input,
|
||||
),
|
||||
dtype=self.embed_tokens.weight.dtype,
|
||||
dtype=vllm_config.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -1011,14 +1011,17 @@ class Gemma4ForConditionalGeneration(
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.model_dtype = vllm_config.model_config.dtype
|
||||
|
||||
# Only quantize towers when the quant method supports their
|
||||
# dimensions. BNB/torchao handle arbitrary sizes; other methods
|
||||
# (Marlin, FP8, …) require dimensions divisible by 64, which
|
||||
# the vision tower (intermediate_size=4304) does not satisfy.
|
||||
# TODO(mgoin): remove this by fixing kernel padding.
|
||||
if quant_config and quant_config.get_name() in [
|
||||
"bitsandbytes",
|
||||
"torchao",
|
||||
"compressed-tensors",
|
||||
]:
|
||||
tower_quant = quant_config
|
||||
else:
|
||||
@@ -1081,12 +1084,13 @@ class Gemma4ForConditionalGeneration(
|
||||
# Some variants have hidden_size_per_layer_input=None (no PLE).
|
||||
ple_dim = config.text_config.hidden_size_per_layer_input
|
||||
if ple_dim is not None and ple_dim > 0:
|
||||
embed = self.language_model.model.embed_tokens
|
||||
self.per_layer_embeddings = torch.zeros(
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
config.text_config.num_hidden_layers,
|
||||
ple_dim,
|
||||
device=self.language_model.model.embed_tokens.weight.device,
|
||||
dtype=self.language_model.model.embed_tokens.weight.dtype,
|
||||
device=next(embed.parameters()).device,
|
||||
dtype=vllm_config.model_config.dtype,
|
||||
)
|
||||
else:
|
||||
self.per_layer_embeddings = None
|
||||
@@ -1246,7 +1250,6 @@ class Gemma4ForConditionalGeneration(
|
||||
vt = self.vision_tower
|
||||
vision_cfg = self.config.vision_config
|
||||
pooling_k2 = vision_cfg.pooling_kernel_size**2
|
||||
target_dtype = self.language_model.model.embed_tokens.weight.dtype
|
||||
|
||||
# Concurrent requests with different image resolutions may
|
||||
# arrive as a list of per-image tensors, while same-resolution
|
||||
@@ -1291,7 +1294,7 @@ class Gemma4ForConditionalGeneration(
|
||||
pv_tensor,
|
||||
pp_tensor,
|
||||
pad_tensor,
|
||||
).to(target_dtype)
|
||||
).to(self.model_dtype)
|
||||
encoder_outputs = vt.encoder(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=~pad_tensor,
|
||||
@@ -1328,12 +1331,8 @@ class Gemma4ForConditionalGeneration(
|
||||
all_valid_states[orig_idx] = valid_states
|
||||
valid_lens[orig_idx] = valid_states.shape[0]
|
||||
|
||||
# Use embed_tokens dtype as compute dtype; embedding_projection.weight
|
||||
# may be uint8 under BnB 4-bit, which would corrupt the cast.
|
||||
target_dtype = self.language_model.model.embed_tokens.weight.dtype
|
||||
|
||||
# Project all images in a single batched call.
|
||||
flat_valid_states = torch.cat(all_valid_states, dim=0).to(target_dtype)
|
||||
flat_valid_states = torch.cat(all_valid_states, dim=0).to(self.model_dtype)
|
||||
flat_proj_embs = self.embed_vision(
|
||||
inputs_embeds=flat_valid_states.unsqueeze(0)
|
||||
).squeeze(0)
|
||||
@@ -1373,7 +1372,6 @@ class Gemma4ForConditionalGeneration(
|
||||
vt = self.vision_tower
|
||||
vision_cfg = self.config.vision_config
|
||||
pooling_k2 = vision_cfg.pooling_kernel_size**2
|
||||
target_dtype = self.language_model.model.embed_tokens.weight.dtype
|
||||
|
||||
if isinstance(frame_counts, torch.Tensor):
|
||||
fc_list = frame_counts.tolist()
|
||||
@@ -1405,7 +1403,7 @@ class Gemma4ForConditionalGeneration(
|
||||
pv_chunk,
|
||||
pp_chunk,
|
||||
pad_chunk,
|
||||
).to(target_dtype)
|
||||
).to(self.model_dtype)
|
||||
encoder_outputs = vt.encoder(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=~pad_chunk,
|
||||
@@ -1440,7 +1438,9 @@ class Gemma4ForConditionalGeneration(
|
||||
frame_valid_lens.append(valid_states.shape[0])
|
||||
|
||||
# Project all frames in a single batched call.
|
||||
flat_valid_states = torch.cat(all_frame_valid_states, dim=0).to(target_dtype)
|
||||
flat_valid_states = torch.cat(all_frame_valid_states, dim=0).to(
|
||||
self.model_dtype
|
||||
)
|
||||
flat_proj_embs = self.embed_vision(
|
||||
inputs_embeds=flat_valid_states.unsqueeze(0)
|
||||
).squeeze(0)
|
||||
|
||||
@@ -307,12 +307,13 @@ class Gemma4UnifiedForConditionalGeneration(Gemma4ForConditionalGeneration):
|
||||
None,
|
||||
)
|
||||
if ple_dim is not None and ple_dim > 0:
|
||||
embed = self.language_model.model.embed_tokens
|
||||
self.per_layer_embeddings = torch.zeros(
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
config.text_config.num_hidden_layers,
|
||||
ple_dim,
|
||||
device=self.language_model.model.embed_tokens.weight.device,
|
||||
dtype=self.language_model.model.embed_tokens.weight.dtype,
|
||||
device=next(embed.parameters()).device,
|
||||
dtype=vllm_config.model_config.dtype,
|
||||
)
|
||||
else:
|
||||
self.per_layer_embeddings = None
|
||||
|
||||
Reference in New Issue
Block a user