[Feature] TurboQuant: support hybrid models and uniform quantization (#39931)

Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Jim Smith <jhsmith0@me.com>
Co-authored-by: Jim Smith <jhsmith0@me.com>
Co-authored-by: Sandermage <sandermage@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
JartX
2026-05-05 02:14:01 +02:00
committed by GitHub
parent 577b9623e6
commit 4f2af1a7c0
4 changed files with 185 additions and 26 deletions
+82 -4
View File
@@ -182,22 +182,100 @@ class TestTurboQuantConfig:
# ---- Boundary skip layers ----
@staticmethod
def _dense_model_config(num_layers):
from types import SimpleNamespace
return SimpleNamespace(
is_hybrid=False,
hf_text_config=SimpleNamespace(num_hidden_layers=num_layers),
)
def test_boundary_skip_layers_basic(self):
layers = TurboQuantConfig.get_boundary_skip_layers(32)
mc = self._dense_model_config(32)
layers = TurboQuantConfig.get_boundary_skip_layers(mc)
assert layers == ["0", "1", "30", "31"]
def test_boundary_skip_layers_zero(self):
assert TurboQuantConfig.get_boundary_skip_layers(32, 0) == []
mc = self._dense_model_config(32)
assert TurboQuantConfig.get_boundary_skip_layers(mc, 0) == []
def test_boundary_skip_layers_small_model(self):
layers = TurboQuantConfig.get_boundary_skip_layers(4)
mc = self._dense_model_config(4)
layers = TurboQuantConfig.get_boundary_skip_layers(mc)
assert layers == ["0", "1", "2", "3"]
def test_boundary_skip_layers_cap_at_half(self):
layers = TurboQuantConfig.get_boundary_skip_layers(8, 10)
mc = self._dense_model_config(8)
layers = TurboQuantConfig.get_boundary_skip_layers(mc, 10)
assert len(layers) == 8
class TestHybridAttentionIndices:
"""Regression tests for boundary protection on hybrid models.
Hybrid models (attention + Mamba / linear-attention) identify KV-carrying
layers via layer_types / layers_block_type / attn_type_list. The helper
must return the *global* layer indices of the full-attention layers so
that kv_cache_dtype_skip_layers matches what extract_layer_index(prefix)
reports on the Attention layers at runtime.
"""
@staticmethod
def _fake_model_config(text_cfg=None, hf_cfg=None):
from types import SimpleNamespace
return SimpleNamespace(
hf_text_config=text_cfg if text_cfg is not None else SimpleNamespace(),
hf_config=hf_cfg if hf_cfg is not None else SimpleNamespace(),
)
def test_layer_types_full_attention(self):
from vllm.model_executor.layers.quantization.turboquant.config import (
_get_full_attention_layer_indices,
)
cfg = type("C", (), {})()
cfg.layer_types = [
"linear_attention",
"linear_attention",
"full_attention",
"linear_attention",
"full_attention",
"full_attention",
]
mc = self._fake_model_config(text_cfg=cfg)
assert _get_full_attention_layer_indices(mc) == [2, 4, 5]
def test_layers_block_type_jamba(self):
from vllm.model_executor.layers.quantization.turboquant.config import (
_get_full_attention_layer_indices,
)
cfg = type("C", (), {})()
cfg.layers_block_type = ["mamba", "attention", "mamba", "attention"]
mc = self._fake_model_config(text_cfg=cfg)
assert _get_full_attention_layer_indices(mc) == [1, 3]
def test_attn_type_list_minimax(self):
from vllm.model_executor.layers.quantization.turboquant.config import (
_get_full_attention_layer_indices,
)
hf = type("C", (), {})()
hf.attn_type_list = [0, 1, 0, 1, 1]
mc = self._fake_model_config(hf_cfg=hf)
assert _get_full_attention_layer_indices(mc) == [1, 3, 4]
def test_no_hybrid_hints_returns_empty(self):
from vllm.model_executor.layers.quantization.turboquant.config import (
_get_full_attention_layer_indices,
)
mc = self._fake_model_config()
assert _get_full_attention_layer_indices(mc) == []
# ============================================================================
# Centroids tests (CPU-only)
# ============================================================================
+3 -17
View File
@@ -1699,29 +1699,15 @@ class EngineArgs:
kv_offloading_backend=self.kv_offloading_backend,
)
# TurboQuant: auto-skip first/last 2 layers (boundary protection).
# These layers are most sensitive to quantization error.
# Users can add extra layers via --kv-cache-dtype-skip-layers.
if resolved_cache_dtype.startswith("turboquant_"):
if model_config.is_hybrid:
raise NotImplementedError(
"TurboQuant KV cache is not supported for hybrid "
"(attention + Mamba) models. Boundary layer protection "
"requires uniform attention layers."
)
from vllm.model_executor.layers.quantization.turboquant.config import (
TurboQuantConfig,
)
num_layers = model_config.hf_text_config.num_hidden_layers
boundary = TurboQuantConfig.get_boundary_skip_layers(num_layers)
boundary = TurboQuantConfig.get_boundary_skip_layers(model_config)
existing = set(cache_config.kv_cache_dtype_skip_layers)
merged = sorted(existing | set(boundary), key=lambda x: int(x))
cache_config.kv_cache_dtype_skip_layers = merged
logger.info(
"TQ: skipping layers %s for boundary protection (num_layers=%d)",
merged,
num_layers,
cache_config.kv_cache_dtype_skip_layers = sorted(
existing | set(boundary), key=int
)
ray_runtime_env = None
@@ -2,8 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant configuration."""
from __future__ import annotations
import logging
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = logging.getLogger(__name__)
# Named TQ presets: each maps to frozen config parameters.
# key_quant_bits: 8 = FP8 keys, 3-4 = MSE (Lloyd-Max) quantized keys.
@@ -159,12 +168,34 @@ class TurboQuantConfig:
return s + (s % 2) # round up to even
@staticmethod
def get_boundary_skip_layers(num_layers: int, n: int = 2) -> list[str]:
"""Get layer indices to skip TQ compression (boundary protection).
def get_boundary_skip_layers(
model_config: ModelConfig,
n: int = 2,
) -> list[str]:
"""Layer indices to skip TQ compression (boundary protection).
Returns first N and last N layer indices as strings, suitable for
kv_cache_dtype_skip_layers.
For hybrid models (attention + Mamba/linear-attention), boundary
protection is disabled — hybrids typically have only 8-12
full-attention layers and a hard n=2 on each side would cover
~40 % of them. The dense GSM8K baselines that motivate n=2
don't apply to hybrids.
For dense models, skips first N and last N attention layers.
Empirically required for aggressive presets (k3v4_nc, 3bit_nc)
— without it GSM8K drops ~30 points on Qwen3-4B.
"""
if model_config.is_hybrid:
attn_indices = _get_full_attention_layer_indices(model_config)
if not attn_indices:
raise NotImplementedError(
"TurboQuant KV cache requires identifiable "
"full-attention layers, but none were found in "
"the hybrid model config."
)
logger.info("TQ hybrid: full-attention layers %s", attn_indices)
return []
num_layers = model_config.hf_text_config.num_hidden_layers
if n <= 0 or num_layers <= 0:
return []
n = min(n, num_layers // 2) # don't skip more than half
@@ -175,7 +206,7 @@ class TurboQuantConfig:
return [str(i) for i in indices]
@staticmethod
def from_cache_dtype(cache_dtype: str, head_dim: int) -> "TurboQuantConfig":
def from_cache_dtype(cache_dtype: str, head_dim: int) -> TurboQuantConfig:
"""Create config from a named preset.
Valid presets: turboquant_k8v4, turboquant_4bit_nc, etc.
@@ -193,3 +224,31 @@ class TurboQuantConfig:
value_quant_bits=preset["value_quant_bits"],
norm_correction=preset["norm_correction"],
)
def _get_full_attention_layer_indices(model_config: ModelConfig) -> list[int]:
"""Global indices of full-attention layers in a hybrid model.
Covers the conventions used across vLLM: ``layer_types`` (Qwen3.5/Next),
``layers_block_type`` (Jamba/Zamba2), ``attn_type_list`` (Minimax).
"""
text_cfg = model_config.hf_text_config
hf_cfg = model_config.hf_config
layer_types = getattr(text_cfg, "layer_types", None)
if layer_types is not None:
return [
i for i, t in enumerate(layer_types) if t in ("full_attention", "attention")
]
layers_block_type = getattr(text_cfg, "layers_block_type", None)
if layers_block_type is not None:
return [
i for i, t in enumerate(layers_block_type) if t in ("attention", "hybrid")
]
attn_type_list = getattr(hf_cfg, "attn_type_list", None)
if attn_type_list is not None:
return [i for i, t in enumerate(attn_type_list) if t == 1]
return []
+36
View File
@@ -545,6 +545,42 @@ class Platform:
dtype=kv_cache_dtype,
kv_quant_mode=kv_quant_mode,
).page_size_bytes
elif cache_config.cache_dtype.startswith("turboquant_"):
# TQ has a packed K|V layout; the standard FullAttentionSpec
# formula over-sizes it and trips unify_kv_cache_spec_page_size
# when all attention layers are TQ. With mixed skip+TQ the skip
# layers still use the standard layout — take max so mamba
# padding covers the largest actual page.
from vllm.model_executor.layers.quantization.turboquant.config import (
TurboQuantConfig,
)
from vllm.v1.kv_cache_interface import TQFullAttentionSpec
tq_cfg = TurboQuantConfig.from_cache_dtype(
cache_config.cache_dtype, model_config.get_head_size()
)
tq_page = TQFullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
head_size_v=model_config.get_head_size(),
dtype=kv_cache_dtype,
kv_quant_mode=kv_quant_mode,
tq_slot_size=tq_cfg.slot_size_aligned,
).page_size_bytes
if cache_config.kv_cache_dtype_skip_layers:
skip_page = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=model_config.dtype,
).page_size_bytes
# lcm, not max: skip_page is often not a multiple of
# tq_page, so max would leave per-layer page sizes
# un-unifiable downstream.
attn_page_size_1_token = lcm(tq_page, skip_page)
else:
attn_page_size_1_token = tq_page
else:
attn_page_size_1_token = FullAttentionSpec(
block_size=1,