mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Feature] TurboQuant: support hybrid models and uniform quantization (#39931)
Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Jim Smith <jhsmith0@me.com> Co-authored-by: Jim Smith <jhsmith0@me.com> Co-authored-by: Sandermage <sandermage@users.noreply.github.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -182,22 +182,100 @@ class TestTurboQuantConfig:
|
||||
|
||||
# ---- Boundary skip layers ----
|
||||
|
||||
@staticmethod
def _dense_model_config(num_layers):
    """Build a minimal stand-in for a dense (non-hybrid) model config.

    Only the attributes set here are available on the result:
    ``is_hybrid`` (always ``False``) and
    ``hf_text_config.num_hidden_layers`` (set to ``num_layers``).
    """
    from types import SimpleNamespace

    return SimpleNamespace(
        is_hybrid=False,
        hf_text_config=SimpleNamespace(num_hidden_layers=num_layers),
    )
|
||||
|
||||
def test_boundary_skip_layers_basic(self):
    """Dense 32-layer model with default n: layers 0, 1, 30 and 31 are skipped."""
    # The helper takes a model config; passing a bare layer count (the
    # pre-change call form) is no longer valid.
    mc = self._dense_model_config(32)
    layers = TurboQuantConfig.get_boundary_skip_layers(mc)
    assert layers == ["0", "1", "30", "31"]
|
||||
|
||||
def test_boundary_skip_layers_zero(self):
    """With n=0 no layers are skipped, even on a 32-layer dense model."""
    # The helper takes a model config; passing a bare layer count (the
    # pre-change call form) is no longer valid.
    mc = self._dense_model_config(32)
    assert TurboQuantConfig.get_boundary_skip_layers(mc, 0) == []
|
||||
|
||||
def test_boundary_skip_layers_small_model(self):
    """Dense 4-layer model with default n: all four layers are skipped."""
    # The helper takes a model config; passing a bare layer count (the
    # pre-change call form) is no longer valid.
    mc = self._dense_model_config(4)
    layers = TurboQuantConfig.get_boundary_skip_layers(mc)
    assert layers == ["0", "1", "2", "3"]
|
||||
|
||||
def test_boundary_skip_layers_cap_at_half(self):
    """Requesting n=10 on an 8-layer dense model yields exactly 8 entries."""
    # The helper takes a model config; passing a bare layer count (the
    # pre-change call form) is no longer valid.
    mc = self._dense_model_config(8)
    layers = TurboQuantConfig.get_boundary_skip_layers(mc, 10)
    assert len(layers) == 8
|
||||
|
||||
|
||||
class TestHybridAttentionIndices:
    """Regression tests for boundary protection on hybrid models.

    Hybrid models (attention + Mamba / linear-attention) identify KV-carrying
    layers via layer_types / layers_block_type / attn_type_list. The helper
    must return the *global* layer indices of the full-attention layers so
    that kv_cache_dtype_skip_layers matches what extract_layer_index(prefix)
    reports on the Attention layers at runtime.
    """

    @staticmethod
    def _fake_model_config(text_cfg=None, hf_cfg=None):
        """Build a stand-in model config exposing hf_text_config and hf_config.

        Either sub-config defaults to an empty namespace when not supplied,
        so attribute probes on it fall back to their defaults.
        """
        from types import SimpleNamespace

        return SimpleNamespace(
            hf_text_config=text_cfg if text_cfg is not None else SimpleNamespace(),
            hf_config=hf_cfg if hf_cfg is not None else SimpleNamespace(),
        )

    def test_layer_types_full_attention(self):
        """``layer_types`` on the text config: "full_attention" entries are picked."""
        from vllm.model_executor.layers.quantization.turboquant.config import (
            _get_full_attention_layer_indices,
        )

        cfg = type("C", (), {})()
        cfg.layer_types = [
            "linear_attention",
            "linear_attention",
            "full_attention",
            "linear_attention",
            "full_attention",
            "full_attention",
        ]
        mc = self._fake_model_config(text_cfg=cfg)
        assert _get_full_attention_layer_indices(mc) == [2, 4, 5]

    def test_layers_block_type_jamba(self):
        """``layers_block_type`` on the text config: "attention" entries are picked."""
        from vllm.model_executor.layers.quantization.turboquant.config import (
            _get_full_attention_layer_indices,
        )

        cfg = type("C", (), {})()
        cfg.layers_block_type = ["mamba", "attention", "mamba", "attention"]
        mc = self._fake_model_config(text_cfg=cfg)
        assert _get_full_attention_layer_indices(mc) == [1, 3]

    def test_attn_type_list_minimax(self):
        """``attn_type_list`` on the HF config: entries equal to 1 are picked."""
        from vllm.model_executor.layers.quantization.turboquant.config import (
            _get_full_attention_layer_indices,
        )

        hf = type("C", (), {})()
        hf.attn_type_list = [0, 1, 0, 1, 1]
        mc = self._fake_model_config(hf_cfg=hf)
        assert _get_full_attention_layer_indices(mc) == [1, 3, 4]

    def test_no_hybrid_hints_returns_empty(self):
        """With none of the three attributes present the helper returns []."""
        from vllm.model_executor.layers.quantization.turboquant.config import (
            _get_full_attention_layer_indices,
        )

        mc = self._fake_model_config()
        assert _get_full_attention_layer_indices(mc) == []
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Centroids tests (CPU-only)
|
||||
# ============================================================================
|
||||
|
||||
@@ -1699,29 +1699,15 @@ class EngineArgs:
|
||||
kv_offloading_backend=self.kv_offloading_backend,
|
||||
)
|
||||
|
||||
# TurboQuant: auto-skip first/last 2 layers (boundary protection).
|
||||
# These layers are most sensitive to quantization error.
|
||||
# Users can add extra layers via --kv-cache-dtype-skip-layers.
|
||||
if resolved_cache_dtype.startswith("turboquant_"):
|
||||
if model_config.is_hybrid:
|
||||
raise NotImplementedError(
|
||||
"TurboQuant KV cache is not supported for hybrid "
|
||||
"(attention + Mamba) models. Boundary layer protection "
|
||||
"requires uniform attention layers."
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.turboquant.config import (
|
||||
TurboQuantConfig,
|
||||
)
|
||||
|
||||
num_layers = model_config.hf_text_config.num_hidden_layers
|
||||
boundary = TurboQuantConfig.get_boundary_skip_layers(num_layers)
|
||||
boundary = TurboQuantConfig.get_boundary_skip_layers(model_config)
|
||||
existing = set(cache_config.kv_cache_dtype_skip_layers)
|
||||
merged = sorted(existing | set(boundary), key=lambda x: int(x))
|
||||
cache_config.kv_cache_dtype_skip_layers = merged
|
||||
logger.info(
|
||||
"TQ: skipping layers %s for boundary protection (num_layers=%d)",
|
||||
merged,
|
||||
num_layers,
|
||||
cache_config.kv_cache_dtype_skip_layers = sorted(
|
||||
existing | set(boundary), key=int
|
||||
)
|
||||
|
||||
ray_runtime_env = None
|
||||
|
||||
@@ -2,8 +2,17 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""TurboQuant configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Named TQ presets: each maps to frozen config parameters.
|
||||
# key_quant_bits: 8 = FP8 keys, 3-4 = MSE (Lloyd-Max) quantized keys.
|
||||
@@ -159,12 +168,34 @@ class TurboQuantConfig:
|
||||
return s + (s % 2) # round up to even
|
||||
|
||||
@staticmethod
|
||||
def get_boundary_skip_layers(num_layers: int, n: int = 2) -> list[str]:
|
||||
"""Get layer indices to skip TQ compression (boundary protection).
|
||||
def get_boundary_skip_layers(
|
||||
model_config: ModelConfig,
|
||||
n: int = 2,
|
||||
) -> list[str]:
|
||||
"""Layer indices to skip TQ compression (boundary protection).
|
||||
|
||||
Returns first N and last N layer indices as strings, suitable for
|
||||
kv_cache_dtype_skip_layers.
|
||||
For hybrid models (attention + Mamba/linear-attention), boundary
|
||||
protection is disabled — hybrids typically have only 8-12
|
||||
full-attention layers and a hard n=2 on each side would cover
|
||||
~40 % of them. The dense GSM8K baselines that motivate n=2
|
||||
don't apply to hybrids.
|
||||
|
||||
For dense models, skips first N and last N attention layers.
|
||||
Empirically required for aggressive presets (k3v4_nc, 3bit_nc)
|
||||
— without it GSM8K drops ~30 points on Qwen3-4B.
|
||||
"""
|
||||
if model_config.is_hybrid:
|
||||
attn_indices = _get_full_attention_layer_indices(model_config)
|
||||
if not attn_indices:
|
||||
raise NotImplementedError(
|
||||
"TurboQuant KV cache requires identifiable "
|
||||
"full-attention layers, but none were found in "
|
||||
"the hybrid model config."
|
||||
)
|
||||
logger.info("TQ hybrid: full-attention layers %s", attn_indices)
|
||||
return []
|
||||
|
||||
num_layers = model_config.hf_text_config.num_hidden_layers
|
||||
if n <= 0 or num_layers <= 0:
|
||||
return []
|
||||
n = min(n, num_layers // 2) # don't skip more than half
|
||||
@@ -175,7 +206,7 @@ class TurboQuantConfig:
|
||||
return [str(i) for i in indices]
|
||||
|
||||
@staticmethod
|
||||
def from_cache_dtype(cache_dtype: str, head_dim: int) -> "TurboQuantConfig":
|
||||
def from_cache_dtype(cache_dtype: str, head_dim: int) -> TurboQuantConfig:
|
||||
"""Create config from a named preset.
|
||||
|
||||
Valid presets: turboquant_k8v4, turboquant_4bit_nc, etc.
|
||||
@@ -193,3 +224,31 @@ class TurboQuantConfig:
|
||||
value_quant_bits=preset["value_quant_bits"],
|
||||
norm_correction=preset["norm_correction"],
|
||||
)
|
||||
|
||||
|
||||
def _get_full_attention_layer_indices(model_config: ModelConfig) -> list[int]:
    """Global indices of full-attention layers in a hybrid model.

    Covers the conventions used across vLLM: ``layer_types`` (Qwen3.5/Next),
    ``layers_block_type`` (Jamba/Zamba2), ``attn_type_list`` (Minimax).

    The attributes are probed in that order and the first one that is not
    ``None`` decides the result, even if it yields no matching layers.

    Args:
        model_config: Object exposing ``hf_text_config`` and ``hf_config``.

    Returns:
        Positions (0-based, in list order) of the entries recognised as
        full attention, or ``[]`` when none of the attributes is present.
    """
    text_cfg = model_config.hf_text_config
    hf_cfg = model_config.hf_config

    # Convention 1: per-layer type names on the text config.
    layer_types = getattr(text_cfg, "layer_types", None)
    if layer_types is not None:
        return [
            i for i, t in enumerate(layer_types) if t in ("full_attention", "attention")
        ]

    # Convention 2: per-layer block types on the text config.
    layers_block_type = getattr(text_cfg, "layers_block_type", None)
    if layers_block_type is not None:
        return [
            i for i, t in enumerate(layers_block_type) if t in ("attention", "hybrid")
        ]

    # Convention 3: integer flags on the top-level HF config; 1 is selected.
    attn_type_list = getattr(hf_cfg, "attn_type_list", None)
    if attn_type_list is not None:
        return [i for i, t in enumerate(attn_type_list) if t == 1]

    return []
|
||||
|
||||
@@ -545,6 +545,42 @@ class Platform:
|
||||
dtype=kv_cache_dtype,
|
||||
kv_quant_mode=kv_quant_mode,
|
||||
).page_size_bytes
|
||||
elif cache_config.cache_dtype.startswith("turboquant_"):
|
||||
# TQ has a packed K|V layout; the standard FullAttentionSpec
|
||||
# formula over-sizes it and trips unify_kv_cache_spec_page_size
|
||||
# when all attention layers are TQ. With mixed skip+TQ the skip
|
||||
# layers still use the standard layout — take the lcm of both page
|
||||
# sizes so mamba padding is a multiple of every actual page.
|
||||
from vllm.model_executor.layers.quantization.turboquant.config import (
|
||||
TurboQuantConfig,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import TQFullAttentionSpec
|
||||
|
||||
tq_cfg = TurboQuantConfig.from_cache_dtype(
|
||||
cache_config.cache_dtype, model_config.get_head_size()
|
||||
)
|
||||
tq_page = TQFullAttentionSpec(
|
||||
block_size=1,
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
head_size_v=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype,
|
||||
kv_quant_mode=kv_quant_mode,
|
||||
tq_slot_size=tq_cfg.slot_size_aligned,
|
||||
).page_size_bytes
|
||||
if cache_config.kv_cache_dtype_skip_layers:
|
||||
skip_page = FullAttentionSpec(
|
||||
block_size=1,
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=model_config.dtype,
|
||||
).page_size_bytes
|
||||
# lcm, not max: skip_page is often not a multiple of
|
||||
# tq_page, so max would leave per-layer page sizes
|
||||
# un-unifiable downstream.
|
||||
attn_page_size_1_token = lcm(tq_page, skip_page)
|
||||
else:
|
||||
attn_page_size_1_token = tq_page
|
||||
else:
|
||||
attn_page_size_1_token = FullAttentionSpec(
|
||||
block_size=1,
|
||||
|
||||
Reference in New Issue
Block a user