mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Attention Backend] TurboQuant: 2-bit KV cache compression with 4x capacity (#38479)
Signed-off-by: vibhavagarwal5 <vibhavagarwal5@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -91,6 +91,16 @@ steps:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt
|
||||
|
||||
|
||||
- label: LM Eval TurboQuant KV Cache
|
||||
timeout_in_minutes: 75
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/layers/quantization/turboquant/
|
||||
- vllm/v1/attention/backends/turboquant_attn.py
|
||||
- vllm/v1/attention/ops/triton_turboquant_decode.py
|
||||
- vllm/v1/attention/ops/triton_turboquant_store.py
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/models-turboquant.txt
|
||||
|
||||
- label: GPQA Eval (GPT-OSS) (H100)
|
||||
timeout_in_minutes: 120
|
||||
device: h100
|
||||
|
||||
@@ -178,6 +178,7 @@ Priority is **1 = highest** (tried first).
|
||||
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | Decoder, Encoder, Encoder Only | N/A |
|
||||
| `TREE_ATTN` | | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
|
||||
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2`, `int8_per_token_head`, `fp8_per_token_head` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
|
||||
| `TURBOQUANT` | | fp16, bf16 | `turboquant_k8v4`, `turboquant_4bit_nc`, `turboquant_k3v4_nc`, `turboquant_3bit_nc` | 16, 32, 64, 128 | Any | ❌ | ❌ | ❌ | Decoder | Any |
|
||||
|
||||
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
|
||||
>
|
||||
|
||||
@@ -170,6 +170,9 @@ eles = "eles"
|
||||
datas = "datas"
|
||||
ser = "ser"
|
||||
ure = "ure"
|
||||
# Walsh-Hadamard Transform
|
||||
wht = "wht"
|
||||
WHT = "WHT"
|
||||
|
||||
[tool.uv]
|
||||
no-build-isolation-package = ["torch"]
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
model_name: "Qwen/Qwen3-4B"
|
||||
accuracy_threshold: 0.78
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
server_args: "--kv-cache-dtype turboquant_k3v4_nc --enforce-eager --max-model-len 4096"
|
||||
@@ -0,0 +1,5 @@
|
||||
model_name: "Qwen/Qwen3-4B"
|
||||
accuracy_threshold: 0.80
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
server_args: "--kv-cache-dtype turboquant_k8v4 --enforce-eager --max-model-len 4096"
|
||||
@@ -0,0 +1,5 @@
|
||||
model_name: "Qwen/Qwen3-4B"
|
||||
accuracy_threshold: 0.75
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
server_args: "--kv-cache-dtype turboquant_3bit_nc --enforce-eager --max-model-len 4096"
|
||||
@@ -0,0 +1,5 @@
|
||||
model_name: "Qwen/Qwen3-4B"
|
||||
accuracy_threshold: 0.80
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
server_args: "--kv-cache-dtype turboquant_4bit_nc --enforce-eager --max-model-len 4096"
|
||||
@@ -0,0 +1,4 @@
|
||||
Qwen3-4B-TQ-k8v4.yaml
|
||||
Qwen3-4B-TQ-t4nc.yaml
|
||||
Qwen3-4B-TQ-k3v4nc.yaml
|
||||
Qwen3-4B-TQ-t3nc.yaml
|
||||
@@ -0,0 +1,570 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for TurboQuant KV-cache quantization.
|
||||
|
||||
Run: .venv/bin/python -m pytest tests/quantization/test_turboquant.py -v
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.turboquant.centroids import (
|
||||
get_centroids,
|
||||
solve_lloyd_max,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.turboquant.config import (
|
||||
TQ_PRESETS,
|
||||
TurboQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.turboquant.quantizer import (
|
||||
generate_wht_signs,
|
||||
)
|
||||
from vllm.utils.math_utils import next_power_of_2
|
||||
|
||||
# ============================================================================
|
||||
# Helpers
|
||||
# ============================================================================
|
||||
|
||||
ALL_PRESETS = list(TQ_PRESETS.keys())
|
||||
|
||||
|
||||
def _assert_strictly_sorted(seq, name="sequence"):
|
||||
for i in range(len(seq) - 1):
|
||||
assert seq[i] < seq[i + 1], f"{name} not sorted at index {i}"
|
||||
|
||||
|
||||
def _is_power_of_2(n: int) -> bool:
|
||||
return n > 0 and next_power_of_2(n) == n
|
||||
|
||||
|
||||
# Expected concrete values for each preset at head_dim=128.
|
||||
# fmt: off
|
||||
PRESET_EXPECTED = {
|
||||
"turboquant_k8v4": dict(
|
||||
key_fp8=True, key_quant_bits=8,
|
||||
key_mse_bits=0, value_quant_bits=4,
|
||||
mse_bits=4, n_centroids=16, centroid_bits=4,
|
||||
norm_correction=False,
|
||||
key_packed_size=128, value_packed_size=68,
|
||||
slot_size=196, slot_size_aligned=196,
|
||||
),
|
||||
"turboquant_4bit_nc": dict(
|
||||
key_fp8=False, key_quant_bits=4,
|
||||
key_mse_bits=4, value_quant_bits=4,
|
||||
mse_bits=4, n_centroids=16, centroid_bits=4,
|
||||
norm_correction=True,
|
||||
key_packed_size=66, value_packed_size=68,
|
||||
slot_size=134, slot_size_aligned=134,
|
||||
),
|
||||
"turboquant_k3v4_nc": dict(
|
||||
key_fp8=False, key_quant_bits=3,
|
||||
key_mse_bits=3, value_quant_bits=4,
|
||||
mse_bits=3, n_centroids=8, centroid_bits=3,
|
||||
norm_correction=True,
|
||||
key_packed_size=50, value_packed_size=68,
|
||||
slot_size=118, slot_size_aligned=118,
|
||||
),
|
||||
"turboquant_3bit_nc": dict(
|
||||
key_fp8=False, key_quant_bits=3,
|
||||
key_mse_bits=3, value_quant_bits=3,
|
||||
mse_bits=3, n_centroids=8, centroid_bits=3,
|
||||
norm_correction=True,
|
||||
key_packed_size=50, value_packed_size=52,
|
||||
slot_size=102, slot_size_aligned=102,
|
||||
),
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Config tests (CPU-only, no dependencies beyond config.py)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestTurboQuantConfig:
|
||||
@pytest.mark.parametrize("preset", ALL_PRESETS)
|
||||
def test_preset_parses(self, preset):
|
||||
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
|
||||
assert isinstance(cfg, TurboQuantConfig)
|
||||
|
||||
def test_invalid_preset_raises(self):
|
||||
with pytest.raises(ValueError, match="Unknown TurboQuant"):
|
||||
TurboQuantConfig.from_cache_dtype("turboquant_invalid", head_dim=128)
|
||||
|
||||
# ---- Per-preset concrete value checks (table-driven) ----
|
||||
|
||||
@pytest.mark.parametrize("preset", ALL_PRESETS)
|
||||
def test_key_mode(self, preset):
|
||||
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
|
||||
exp = PRESET_EXPECTED[preset]
|
||||
assert cfg.key_fp8 is exp["key_fp8"]
|
||||
assert cfg.key_quant_bits == exp["key_quant_bits"]
|
||||
assert cfg.key_mse_bits == exp["key_mse_bits"]
|
||||
|
||||
@pytest.mark.parametrize("preset", ALL_PRESETS)
|
||||
def test_value_mode(self, preset):
|
||||
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
|
||||
exp = PRESET_EXPECTED[preset]
|
||||
assert cfg.value_quant_bits == exp["value_quant_bits"]
|
||||
|
||||
@pytest.mark.parametrize("preset", ALL_PRESETS)
|
||||
def test_bits_and_centroids(self, preset):
|
||||
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
|
||||
exp = PRESET_EXPECTED[preset]
|
||||
assert cfg.mse_bits == exp["mse_bits"]
|
||||
assert cfg.n_centroids == exp["n_centroids"]
|
||||
assert cfg.centroid_bits == exp["centroid_bits"]
|
||||
|
||||
@pytest.mark.parametrize("preset", ALL_PRESETS)
|
||||
def test_norm_correction(self, preset):
|
||||
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
|
||||
assert cfg.norm_correction is PRESET_EXPECTED[preset]["norm_correction"]
|
||||
|
||||
@pytest.mark.parametrize("preset", ALL_PRESETS)
|
||||
def test_packed_sizes(self, preset):
|
||||
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
|
||||
exp = PRESET_EXPECTED[preset]
|
||||
assert cfg.key_packed_size == exp["key_packed_size"]
|
||||
assert cfg.value_packed_size == exp["value_packed_size"]
|
||||
assert cfg.slot_size == exp["slot_size"]
|
||||
assert cfg.slot_size_aligned == exp["slot_size_aligned"]
|
||||
|
||||
# ---- Cross-preset structural invariants ----
|
||||
|
||||
@pytest.mark.parametrize("preset", ALL_PRESETS)
|
||||
def test_slot_equals_key_plus_value(self, preset):
|
||||
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
|
||||
assert cfg.slot_size == cfg.key_packed_size + cfg.value_packed_size
|
||||
|
||||
@pytest.mark.parametrize("preset", ALL_PRESETS)
|
||||
def test_padded_slot_is_even(self, preset):
|
||||
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
|
||||
assert cfg.slot_size_aligned >= cfg.slot_size
|
||||
assert cfg.slot_size_aligned % 2 == 0, (
|
||||
f"slot_size_aligned={cfg.slot_size_aligned} is not even"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("preset", ALL_PRESETS)
|
||||
def test_key_value_packed_sizes_positive(self, preset):
|
||||
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
|
||||
assert cfg.key_packed_size > 0
|
||||
assert cfg.value_packed_size > 0
|
||||
|
||||
@pytest.mark.parametrize("preset", ALL_PRESETS)
|
||||
def test_n_centroids_is_2_to_mse_bits(self, preset):
|
||||
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
|
||||
assert cfg.n_centroids == 2**cfg.mse_bits
|
||||
|
||||
@pytest.mark.parametrize("preset", ALL_PRESETS)
|
||||
def test_centroid_bits_always_positive(self, preset):
|
||||
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
|
||||
assert cfg.centroid_bits > 0
|
||||
|
||||
@pytest.mark.parametrize("preset", ALL_PRESETS)
|
||||
def test_mse_key_or_fp8_exclusive(self, preset):
|
||||
"""Each preset is either FP8 keys or MSE keys, never both."""
|
||||
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
|
||||
if cfg.key_fp8:
|
||||
assert cfg.key_mse_bits == 0
|
||||
assert cfg.key_quant_bits == 8
|
||||
else:
|
||||
assert cfg.key_mse_bits > 0
|
||||
assert cfg.key_quant_bits in (3, 4)
|
||||
|
||||
@pytest.mark.parametrize("preset", ALL_PRESETS)
|
||||
@pytest.mark.parametrize("head_dim", [64, 96, 128, 256])
|
||||
def test_all_presets_all_head_dims(self, preset, head_dim):
|
||||
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=head_dim)
|
||||
assert cfg.head_dim == head_dim
|
||||
assert cfg.slot_size == cfg.key_packed_size + cfg.value_packed_size
|
||||
assert cfg.slot_size_aligned >= cfg.slot_size
|
||||
assert cfg.slot_size_aligned % 2 == 0
|
||||
|
||||
# ---- Boundary skip layers ----
|
||||
|
||||
def test_boundary_skip_layers_basic(self):
|
||||
layers = TurboQuantConfig.get_boundary_skip_layers(32)
|
||||
assert layers == ["0", "1", "30", "31"]
|
||||
|
||||
def test_boundary_skip_layers_zero(self):
|
||||
assert TurboQuantConfig.get_boundary_skip_layers(32, 0) == []
|
||||
|
||||
def test_boundary_skip_layers_small_model(self):
|
||||
layers = TurboQuantConfig.get_boundary_skip_layers(4)
|
||||
assert layers == ["0", "1", "2", "3"]
|
||||
|
||||
def test_boundary_skip_layers_cap_at_half(self):
|
||||
layers = TurboQuantConfig.get_boundary_skip_layers(8, 10)
|
||||
assert len(layers) == 8
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Centroids tests (CPU-only)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestCentroids:
|
||||
@pytest.mark.parametrize("bits,expected_n", [(2, 4), (3, 8), (4, 16)])
|
||||
def test_centroids_shape(self, bits, expected_n):
|
||||
c = get_centroids(128, bits)
|
||||
assert c.shape == (expected_n,)
|
||||
|
||||
@pytest.mark.parametrize("bits", [2, 3, 4])
|
||||
def test_centroids_sorted(self, bits):
|
||||
_assert_strictly_sorted(get_centroids(128, bits), "centroids")
|
||||
|
||||
def test_centroids_cached(self):
|
||||
c1 = get_centroids(128, 3)
|
||||
c2 = get_centroids(128, 3)
|
||||
assert c1 is c2, "get_centroids should return cached object"
|
||||
|
||||
def test_centroids_different_dims_not_identical(self):
|
||||
c64 = get_centroids(64, 3)
|
||||
c128 = get_centroids(128, 3)
|
||||
assert not torch.equal(c64, c128)
|
||||
|
||||
@pytest.mark.parametrize("bits", [2, 3, 4])
|
||||
def test_centroids_symmetric_around_zero(self, bits):
|
||||
"""N(0, 1/d) is symmetric, so centroids should be ~symmetric."""
|
||||
c = get_centroids(128, bits)
|
||||
assert abs(c.mean().item()) < 0.01, "Centroids not centered near 0"
|
||||
assert abs(c[0].item() + c[-1].item()) < 0.01
|
||||
|
||||
@pytest.mark.parametrize("bits", [2, 3, 4])
|
||||
def test_centroids_within_4sigma(self, bits):
|
||||
"""All centroids should be within ~4 sigma of N(0, 1/d)."""
|
||||
sigma = math.sqrt(1.0 / 128)
|
||||
c = get_centroids(128, bits)
|
||||
for i, val in enumerate(c):
|
||||
assert abs(val.item()) < 4 * sigma, (
|
||||
f"Centroid {i}={val:.6f} outside 4*sigma={4 * sigma:.6f}"
|
||||
)
|
||||
|
||||
|
||||
class TestLloydMax:
|
||||
@pytest.mark.parametrize("bits,expected_n", [(2, 4), (3, 8), (4, 16)])
|
||||
def test_solve_shapes(self, bits, expected_n):
|
||||
centroids, boundaries = solve_lloyd_max(128, bits)
|
||||
assert centroids.shape == (expected_n,)
|
||||
assert boundaries.shape == (expected_n - 1,)
|
||||
|
||||
@pytest.mark.parametrize("bits", [2, 3, 4])
|
||||
def test_centroids_sorted(self, bits):
|
||||
centroids, _ = solve_lloyd_max(128, bits)
|
||||
_assert_strictly_sorted(centroids, "centroids")
|
||||
|
||||
@pytest.mark.parametrize("bits", [2, 3, 4])
|
||||
def test_boundaries_sorted(self, bits):
|
||||
_, boundaries = solve_lloyd_max(128, bits)
|
||||
_assert_strictly_sorted(boundaries, "boundaries")
|
||||
|
||||
@pytest.mark.parametrize("bits", [2, 3, 4])
|
||||
def test_boundaries_between_centroids(self, bits):
|
||||
"""Each boundary must lie between its adjacent centroids."""
|
||||
centroids, boundaries = solve_lloyd_max(128, bits)
|
||||
for i in range(len(boundaries)):
|
||||
assert centroids[i] < boundaries[i] < centroids[i + 1], (
|
||||
f"Boundary {i}={boundaries[i]:.6f} not between "
|
||||
f"c[{i}]={centroids[i]:.6f} and c[{i + 1}]={centroids[i + 1]:.6f}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("bits", [2, 3, 4])
|
||||
def test_boundaries_are_midpoints(self, bits):
|
||||
"""Lloyd-Max boundaries are midpoints of adjacent centroids."""
|
||||
centroids, boundaries = solve_lloyd_max(128, bits)
|
||||
for i in range(len(boundaries)):
|
||||
expected = (centroids[i] + centroids[i + 1]) / 2.0
|
||||
assert abs(boundaries[i].item() - expected.item()) < 1e-6
|
||||
|
||||
def test_solve_deterministic(self):
|
||||
c1, b1 = solve_lloyd_max(128, 3)
|
||||
c2, b2 = solve_lloyd_max(128, 3)
|
||||
assert torch.equal(c1, c2)
|
||||
assert torch.equal(b1, b2)
|
||||
|
||||
def test_solve_dtype_float32(self):
|
||||
centroids, boundaries = solve_lloyd_max(128, 3)
|
||||
assert centroids.dtype == torch.float32
|
||||
assert boundaries.dtype == torch.float32
|
||||
|
||||
@pytest.mark.parametrize("bits", [3, 4])
|
||||
def test_centroids_match_scipy_reference(self, bits):
|
||||
"""Verify _trapz(n=200) centroids match scipy.integrate.quad reference.
|
||||
|
||||
This ensures our scipy-free trapezoid integration doesn't silently
|
||||
drift from the published Lloyd-Max quality.
|
||||
"""
|
||||
pytest.importorskip("scipy")
|
||||
from scipy.integrate import quad
|
||||
|
||||
d = 128
|
||||
sigma2 = 1.0 / d
|
||||
sigma = math.sqrt(sigma2)
|
||||
|
||||
def pdf(x):
|
||||
return (1.0 / math.sqrt(2 * math.pi * sigma2)) * math.exp(
|
||||
-x * x / (2 * sigma2)
|
||||
)
|
||||
|
||||
n_levels = 2**bits
|
||||
lo, hi = -3.5 * sigma, 3.5 * sigma
|
||||
ref_centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]
|
||||
for _ in range(200):
|
||||
boundaries = [
|
||||
(ref_centroids[i] + ref_centroids[i + 1]) / 2.0
|
||||
for i in range(n_levels - 1)
|
||||
]
|
||||
edges = [lo * 3] + boundaries + [hi * 3]
|
||||
new_centroids = []
|
||||
for i in range(n_levels):
|
||||
a, b = edges[i], edges[i + 1]
|
||||
num, _ = quad(lambda x: x * pdf(x), a, b)
|
||||
den, _ = quad(pdf, a, b)
|
||||
new_centroids.append(num / den if den > 1e-15 else ref_centroids[i])
|
||||
if (
|
||||
max(abs(new_centroids[i] - ref_centroids[i]) for i in range(n_levels))
|
||||
< 1e-10
|
||||
):
|
||||
break
|
||||
ref_centroids = new_centroids
|
||||
|
||||
# Compare our _trapz centroids against scipy reference
|
||||
our_centroids, _ = solve_lloyd_max(d, bits)
|
||||
ref_t = torch.tensor(ref_centroids, dtype=torch.float32)
|
||||
max_err = (our_centroids - ref_t).abs().max().item()
|
||||
# _trapz(n=200) has ~O(h^2) error vs adaptive quad; 1e-3 is tight
|
||||
# enough to catch regression while allowing trapezoid approximation.
|
||||
assert max_err < 1e-3, (
|
||||
f"d={d}, bits={bits}: max centroid error vs scipy = {max_err:.2e}"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Rotation matrix tests (GPU required)
|
||||
# ============================================================================
|
||||
|
||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||
|
||||
|
||||
def generate_rotation_matrix(d: int, seed: int, device: str = "cpu") -> torch.Tensor:
|
||||
"""Haar-distributed random orthogonal matrix via QR (test/benchmark only)."""
|
||||
gen = torch.Generator(device="cpu")
|
||||
gen.manual_seed(seed)
|
||||
G = torch.randn(d, d, generator=gen, device="cpu", dtype=torch.float32)
|
||||
Q, R = torch.linalg.qr(G)
|
||||
diag_sign = torch.sign(torch.diag(R))
|
||||
diag_sign[diag_sign == 0] = 1.0
|
||||
Q = Q * diag_sign.unsqueeze(0)
|
||||
return Q.to(device)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available")
|
||||
class TestRotationMatrix:
|
||||
"""Tests for the QR-based rotation (standalone benchmarks only)."""
|
||||
|
||||
@pytest.mark.parametrize("dim", [64, 96, 128, 256])
|
||||
def test_rotation_matrix_shape_and_orthogonal(self, dim):
|
||||
Pi = generate_rotation_matrix(dim, seed=42, device="cuda")
|
||||
assert Pi.shape == (dim, dim)
|
||||
eye = Pi @ Pi.T
|
||||
assert torch.allclose(eye, torch.eye(dim, device="cuda"), atol=1e-5), (
|
||||
f"Pi not orthogonal for dim={dim}"
|
||||
)
|
||||
|
||||
def test_rotation_matrix_deterministic(self):
|
||||
Pi1 = generate_rotation_matrix(128, seed=42)
|
||||
Pi2 = generate_rotation_matrix(128, seed=42)
|
||||
assert torch.equal(Pi1, Pi2)
|
||||
|
||||
def test_rotation_matrix_different_seeds(self):
|
||||
Pi1 = generate_rotation_matrix(128, seed=42)
|
||||
Pi2 = generate_rotation_matrix(128, seed=99)
|
||||
assert not torch.equal(Pi1, Pi2)
|
||||
|
||||
def test_rotation_matrix_det_is_pm1(self):
|
||||
"""Orthogonal matrix determinant must be +1 or -1."""
|
||||
Pi = generate_rotation_matrix(128, seed=42, device="cuda")
|
||||
det = torch.linalg.det(Pi)
|
||||
assert abs(abs(det.item()) - 1.0) < 1e-4
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# WHT rotation tests (serving path: generate_wht_signs + _build_hadamard)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _build_hadamard(d: int, device: str = "cpu") -> torch.Tensor:
|
||||
"""Reproduce the serving-path Hadamard construction."""
|
||||
H = torch.tensor([[1.0]])
|
||||
while H.shape[0] < d:
|
||||
H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
|
||||
return (H / math.sqrt(d)).to(torch.device(device))
|
||||
|
||||
|
||||
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available")
|
||||
class TestWHTRotation:
|
||||
"""Tests for the WHT rotation actually used in serving."""
|
||||
|
||||
@pytest.mark.parametrize("dim", [64, 128, 256])
|
||||
def test_wht_orthonormal(self, dim):
|
||||
"""signs * H must be orthonormal: (signs*H) @ (signs*H)^T = I."""
|
||||
signs = generate_wht_signs(dim, seed=42, device="cuda")
|
||||
H = _build_hadamard(dim, "cuda")
|
||||
PiT = (signs.unsqueeze(1) * H).contiguous()
|
||||
eye = PiT @ PiT.T
|
||||
assert torch.allclose(eye, torch.eye(dim, device="cuda"), atol=1e-5), (
|
||||
f"WHT rotation not orthonormal for dim={dim}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("dim", [64, 128, 256])
|
||||
def test_wht_self_inverse(self, dim):
|
||||
"""PiT should be self-inverse: PiT @ PiT = I (up to sign flip)."""
|
||||
signs = generate_wht_signs(dim, seed=42, device="cuda")
|
||||
H = _build_hadamard(dim, "cuda")
|
||||
PiT = (signs.unsqueeze(1) * H).contiguous()
|
||||
Pi = PiT.T.contiguous()
|
||||
# Pi @ PiT should be identity (rotation then inverse)
|
||||
result = Pi @ PiT
|
||||
assert torch.allclose(result, torch.eye(dim, device="cuda"), atol=1e-5), (
|
||||
f"WHT rotation not self-inverse for dim={dim}"
|
||||
)
|
||||
|
||||
def test_wht_signs_deterministic(self):
|
||||
"""Same seed must produce identical signs."""
|
||||
s1 = generate_wht_signs(128, seed=42)
|
||||
s2 = generate_wht_signs(128, seed=42)
|
||||
assert torch.equal(s1, s2)
|
||||
|
||||
def test_wht_signs_different_seeds(self):
|
||||
"""Different seeds must produce different signs."""
|
||||
s1 = generate_wht_signs(128, seed=42)
|
||||
s2 = generate_wht_signs(128, seed=99)
|
||||
assert not torch.equal(s1, s2)
|
||||
|
||||
def test_wht_signs_are_pm1(self):
|
||||
"""All sign values must be exactly +1 or -1."""
|
||||
signs = generate_wht_signs(128, seed=42)
|
||||
assert torch.all(signs.abs() == 1.0)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Store → Decode round-trip test (GPU + Triton required)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available")
|
||||
class TestStoreDecodeRoundTrip:
|
||||
"""End-to-end: store KV into TQ cache, decode, compare vs fp16 ref."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"preset",
|
||||
["turboquant_k8v4", "turboquant_4bit_nc"],
|
||||
)
|
||||
def test_single_token_roundtrip(self, preset):
|
||||
"""Store 1 token, decode with query=key, check attention output.
|
||||
|
||||
For a single token with query=key, attention output should equal
|
||||
the value (softmax over single key = 1.0). Quantization error
|
||||
means we check cosine similarity rather than exact equality.
|
||||
"""
|
||||
from vllm.model_executor.layers.quantization.turboquant.centroids import (
|
||||
solve_lloyd_max,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_turboquant_decode import (
|
||||
triton_turboquant_decode_attention,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_turboquant_store import (
|
||||
triton_turboquant_store,
|
||||
)
|
||||
|
||||
cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
|
||||
D = 128
|
||||
Hk = 4 # num_kv_heads
|
||||
Hq = 4 # num_q_heads (no GQA for simplicity)
|
||||
B = 1 # single token
|
||||
block_size = 16
|
||||
num_blocks = 1
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Generate rotation
|
||||
signs = generate_wht_signs(D, seed=42, device=device)
|
||||
H = _build_hadamard(D, "cuda")
|
||||
PiT = (signs.unsqueeze(1) * H).contiguous().float()
|
||||
Pi = PiT.T.contiguous()
|
||||
|
||||
# Generate centroids
|
||||
centroids, _ = solve_lloyd_max(D, cfg.centroid_bits)
|
||||
centroids = centroids.float().to(device)
|
||||
c_sorted, _ = centroids.sort()
|
||||
midpoints = ((c_sorted[:-1] + c_sorted[1:]) / 2).to(device)
|
||||
|
||||
# Random K, V
|
||||
torch.manual_seed(123)
|
||||
key = torch.randn(B, Hk, D, device=device, dtype=torch.float16)
|
||||
value = torch.randn(B, Hk, D, device=device, dtype=torch.float16)
|
||||
|
||||
# Allocate KV cache
|
||||
padded_slot = cfg.slot_size_aligned
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks,
|
||||
block_size,
|
||||
Hk,
|
||||
padded_slot,
|
||||
device=device,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
slot_mapping = torch.tensor([0], device=device, dtype=torch.int32)
|
||||
|
||||
# Store
|
||||
triton_turboquant_store(
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
slot_mapping,
|
||||
PiT,
|
||||
midpoints,
|
||||
mse_bits=cfg.key_mse_bits,
|
||||
key_packed_size=cfg.key_packed_size,
|
||||
value_quant_bits=cfg.effective_value_quant_bits,
|
||||
key_fp8=cfg.key_fp8,
|
||||
)
|
||||
|
||||
# Decode: use key as query so attention = softmax([1]) * V = V
|
||||
query = key.expand(B, Hq, D).contiguous().to(torch.float16)
|
||||
block_table = torch.tensor([[0]], device=device, dtype=torch.int32)
|
||||
seq_lens = torch.tensor([1], device=device, dtype=torch.int32)
|
||||
|
||||
output = triton_turboquant_decode_attention(
|
||||
query=query,
|
||||
kv_cache=kv_cache,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
Pi=Pi,
|
||||
centroids=centroids,
|
||||
scale=1.0 / math.sqrt(D),
|
||||
mse_bits=cfg.key_mse_bits,
|
||||
key_packed_size=cfg.key_packed_size,
|
||||
value_quant_bits=cfg.effective_value_quant_bits,
|
||||
key_fp8=cfg.key_fp8,
|
||||
norm_correction=cfg.norm_correction,
|
||||
PiT=PiT,
|
||||
max_num_kv_splits=4,
|
||||
)
|
||||
|
||||
# With single KV, output should approximate the stored value.
|
||||
# Check per-head cosine similarity > threshold.
|
||||
out_fp32 = output.float()
|
||||
val_fp32 = value.expand(B, Hq, D).float()
|
||||
for h in range(Hq):
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
out_fp32[0, h].unsqueeze(0),
|
||||
val_fp32[0, h].unsqueeze(0),
|
||||
).item()
|
||||
# FP8 keys should be very accurate; MSE keys have more error
|
||||
threshold = 0.95 if cfg.key_fp8 else 0.85
|
||||
assert cos_sim > threshold, (
|
||||
f"Preset {preset} head {h}: cosine_sim={cos_sim:.4f} < {threshold}"
|
||||
)
|
||||
@@ -27,6 +27,11 @@ class AttentionConfig:
|
||||
flash_attn_max_num_splits_for_cuda_graph: int = 32
|
||||
"""Flash Attention max number splits for cuda graph decode."""
|
||||
|
||||
tq_max_kv_splits_for_cuda_graph: int = 32
|
||||
"""TurboQuant max NUM_KV_SPLITS for cuda graph decode.
|
||||
Fixes the split count so grid dimensions are constant across captures,
|
||||
and buffers can be pre-allocated to avoid inflating the memory estimate."""
|
||||
|
||||
use_cudnn_prefill: bool = False
|
||||
"""Whether to use cudnn prefill."""
|
||||
|
||||
|
||||
@@ -24,6 +24,10 @@ CacheDType = Literal[
|
||||
"fp8_e5m2",
|
||||
"fp8_inc",
|
||||
"fp8_ds_mla",
|
||||
"turboquant_k8v4",
|
||||
"turboquant_4bit_nc",
|
||||
"turboquant_k3v4_nc",
|
||||
"turboquant_3bit_nc",
|
||||
"int8_per_token_head",
|
||||
"fp8_per_token_head",
|
||||
]
|
||||
|
||||
@@ -1642,6 +1642,31 @@ class EngineArgs:
|
||||
kv_offloading_backend=self.kv_offloading_backend,
|
||||
)
|
||||
|
||||
# TurboQuant: auto-skip first/last 2 layers (boundary protection).
|
||||
# These layers are most sensitive to quantization error.
|
||||
# Users can add extra layers via --kv-cache-dtype-skip-layers.
|
||||
if resolved_cache_dtype.startswith("turboquant_"):
|
||||
if model_config.is_hybrid:
|
||||
raise NotImplementedError(
|
||||
"TurboQuant KV cache is not supported for hybrid "
|
||||
"(attention + Mamba) models. Boundary layer protection "
|
||||
"requires uniform attention layers."
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.turboquant.config import (
|
||||
TurboQuantConfig,
|
||||
)
|
||||
|
||||
num_layers = model_config.hf_text_config.num_hidden_layers
|
||||
boundary = TurboQuantConfig.get_boundary_skip_layers(num_layers)
|
||||
existing = set(cache_config.kv_cache_dtype_skip_layers)
|
||||
merged = sorted(existing | set(boundary), key=lambda x: int(x))
|
||||
cache_config.kv_cache_dtype_skip_layers = merged
|
||||
logger.info(
|
||||
"TQ: skipping layers %s for boundary protection (num_layers=%d)",
|
||||
merged,
|
||||
num_layers,
|
||||
)
|
||||
|
||||
ray_runtime_env = None
|
||||
if is_ray_initialized():
|
||||
# Ray Serve LLM calls `create_engine_config` in the context
|
||||
@@ -1948,6 +1973,19 @@ class EngineArgs:
|
||||
self.attention_backend
|
||||
)
|
||||
|
||||
# TurboQuant requires FlashAttention 2 — FA3 boundary layers assert
|
||||
# FlashAttentionImpl which fails with TurboQuantAttentionImpl.
|
||||
if resolved_cache_dtype.startswith("turboquant_") and (
|
||||
attention_config.flash_attn_version is None
|
||||
or attention_config.flash_attn_version >= 3
|
||||
):
|
||||
logger.warning(
|
||||
"TurboQuant is not yet compatible with FlashAttention >= 3. "
|
||||
"Overriding flash_attn_version to 2. To silence this "
|
||||
"warning, pass --attention-config.flash_attn_version=2"
|
||||
)
|
||||
attention_config.flash_attn_version = 2
|
||||
|
||||
# Mamba config overrides
|
||||
mamba_config = copy.deepcopy(self.mamba_config)
|
||||
# Convert string to enum if needed (CLI parsing returns a string)
|
||||
|
||||
@@ -379,6 +379,10 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
# Initialize KV cache quantization attributes
|
||||
_init_kv_cache_quant(self, quant_config, prefix)
|
||||
|
||||
# Initialize TurboQuant buffers (Pi, S, centroids) if tq cache dtype
|
||||
if kv_cache_dtype.startswith("turboquant_"):
|
||||
self._init_turboquant_buffers(kv_cache_dtype, head_size, prefix)
|
||||
|
||||
# for attn backends supporting query quantization
|
||||
self.query_quant = None
|
||||
if (
|
||||
@@ -397,6 +401,67 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
else GroupShape.PER_TENSOR,
|
||||
)
|
||||
|
||||
def _init_turboquant_buffers(
|
||||
self, cache_dtype: str, head_size: int, prefix: str
|
||||
) -> None:
|
||||
"""Initialize TurboQuant rotation/projection matrices and centroids."""
|
||||
from vllm.model_executor.layers.quantization.turboquant.centroids import (
|
||||
get_centroids,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.turboquant.config import (
|
||||
TurboQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.turboquant.quantizer import (
|
||||
generate_wht_signs,
|
||||
)
|
||||
|
||||
tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype, head_size)
|
||||
|
||||
# Each layer needs a unique rotation matrix so quantization errors
|
||||
# don't correlate across layers. Stride must exceed max head_dim to
|
||||
# ensure non-overlapping RNG streams between adjacent layers.
|
||||
_TQ_LAYER_SEED_STRIDE = 1337
|
||||
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
seed = tq_config.seed + layer_idx * _TQ_LAYER_SEED_STRIDE
|
||||
|
||||
self.register_buffer(
|
||||
"_tq_signs",
|
||||
generate_wht_signs(head_size, seed=seed),
|
||||
)
|
||||
self.register_buffer(
|
||||
"_tq_centroids",
|
||||
get_centroids(head_size, tq_config.centroid_bits),
|
||||
)
|
||||
self._tq_config = tq_config
|
||||
|
||||
# Pre-allocate decode intermediate buffers so model.to(device) moves
|
||||
# them to GPU *before* the memory profiler runs. Without this the
|
||||
# profiler gives all free memory to KV cache blocks and the first
|
||||
# decode OOMs when these buffers are lazily allocated.
|
||||
_vllm_cfg = get_current_vllm_config()
|
||||
B = _vllm_cfg.scheduler_config.max_num_seqs
|
||||
Hq = self.num_heads
|
||||
S = _vllm_cfg.attention_config.tq_max_kv_splits_for_cuda_graph
|
||||
D = head_size
|
||||
self.register_buffer(
|
||||
"_tq_mid_o_buf",
|
||||
torch.empty(B, Hq, S, D + 1, dtype=torch.float32),
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"_tq_output_buf",
|
||||
torch.empty(B, Hq, D, dtype=torch.float32),
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"_tq_lse_buf",
|
||||
torch.empty(B, Hq, dtype=torch.float32),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@@ -544,6 +609,23 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
kv_quant_mode=quant_mode,
|
||||
sliding_window=self.sliding_window,
|
||||
)
|
||||
elif self.kv_cache_dtype.startswith("turboquant_"):
|
||||
from vllm.model_executor.layers.quantization.turboquant.config import (
|
||||
TurboQuantConfig,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import TQFullAttentionSpec
|
||||
|
||||
tq_config = TurboQuantConfig.from_cache_dtype(
|
||||
self.kv_cache_dtype, self.head_size
|
||||
)
|
||||
return TQFullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
head_size_v=self.head_size,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
tq_slot_size=tq_config.slot_size_aligned,
|
||||
)
|
||||
else:
|
||||
return FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""TurboQuant: Near-optimal KV-cache quantization for vLLM.
|
||||
|
||||
PolarQuant compression: random rotation + per-coordinate Lloyd-Max
|
||||
scalar quantization for keys, uniform quantization for values.
|
||||
|
||||
Reference: "TurboQuant: Online Vector Quantization with Near-optimal
|
||||
Distortion Rate" (ICLR 2026), Zandieh et al.
|
||||
"""
|
||||
|
||||
from vllm.model_executor.layers.quantization.turboquant.config import TurboQuantConfig
|
||||
|
||||
__all__ = ["TurboQuantConfig"]
|
||||
@@ -0,0 +1,86 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Lloyd-Max optimal scalar quantizer for TurboQuant.
|
||||
|
||||
After rotating a d-dimensional unit vector by a random orthogonal matrix,
|
||||
each coordinate approximately follows N(0, 1/d) for d >= 64.
|
||||
We solve the Lloyd-Max conditions to find optimal centroids.
|
||||
|
||||
Based on: turboquant-pytorch/lloyd_max.py (Zandieh et al.)
|
||||
"""
|
||||
|
||||
import math
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _gaussian_pdf(x: float, sigma2: float) -> float:
|
||||
return (1.0 / math.sqrt(2 * math.pi * sigma2)) * math.exp(-x * x / (2 * sigma2))
|
||||
|
||||
|
||||
def _trapz(f, a: float, b: float, n: int = 200) -> float:
|
||||
"""Trapezoidal numerical integration (replaces scipy.integrate.quad)."""
|
||||
h = (b - a) / n
|
||||
result = 0.5 * (f(a) + f(b))
|
||||
for i in range(1, n):
|
||||
result += f(a + i * h)
|
||||
return result * h
|
||||
|
||||
|
||||
def solve_lloyd_max(
|
||||
d: int,
|
||||
bits: int,
|
||||
max_iter: int = 200,
|
||||
tol: float = 1e-10,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Solve Lloyd-Max optimal quantizer for N(0, 1/d) distribution.
|
||||
|
||||
Args:
|
||||
d: Vector dimension (determines variance = 1/d).
|
||||
bits: Number of quantization bits.
|
||||
max_iter: Maximum Lloyd-Max iterations.
|
||||
tol: Convergence tolerance.
|
||||
|
||||
Returns:
|
||||
centroids: Sorted tensor of 2^bits optimal centroids.
|
||||
boundaries: Sorted tensor of 2^bits - 1 decision boundaries.
|
||||
"""
|
||||
n_levels = 2**bits
|
||||
sigma2 = 1.0 / d
|
||||
sigma = math.sqrt(sigma2)
|
||||
|
||||
def pdf(x):
|
||||
return _gaussian_pdf(x, sigma2)
|
||||
|
||||
lo, hi = -3.5 * sigma, 3.5 * sigma
|
||||
centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]
|
||||
|
||||
for _ in range(max_iter):
|
||||
boundaries = [
|
||||
(centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
|
||||
]
|
||||
edges = [lo * 3] + boundaries + [hi * 3]
|
||||
new_centroids = []
|
||||
for i in range(n_levels):
|
||||
a, b = edges[i], edges[i + 1]
|
||||
num = _trapz(lambda x: x * pdf(x), a, b)
|
||||
den = _trapz(pdf, a, b)
|
||||
new_centroids.append(num / den if den > 1e-15 else centroids[i])
|
||||
|
||||
if max(abs(new_centroids[i] - centroids[i]) for i in range(n_levels)) < tol:
|
||||
break
|
||||
centroids = new_centroids
|
||||
|
||||
boundaries = [(centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)]
|
||||
return (
|
||||
torch.tensor(centroids, dtype=torch.float32),
|
||||
torch.tensor(boundaries, dtype=torch.float32),
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def get_centroids(d: int, bits: int) -> torch.Tensor:
|
||||
"""Get precomputed Lloyd-Max centroids (cached)."""
|
||||
centroids, _ = solve_lloyd_max(d, bits)
|
||||
return centroids
|
||||
@@ -0,0 +1,185 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""TurboQuant configuration."""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Named TQ presets: each maps to frozen config parameters.
|
||||
# key_quant_bits: 8 = FP8 keys, 3-4 = MSE (Lloyd-Max) quantized keys.
|
||||
# value_quant_bits: 3-4 = uniform quantized values.
|
||||
TQ_PRESETS: dict[str, dict] = {
|
||||
"turboquant_k8v4": {
|
||||
"key_quant_bits": 8,
|
||||
"value_quant_bits": 4,
|
||||
"norm_correction": False,
|
||||
},
|
||||
"turboquant_4bit_nc": {
|
||||
"key_quant_bits": 4,
|
||||
"value_quant_bits": 4,
|
||||
"norm_correction": True,
|
||||
},
|
||||
"turboquant_k3v4_nc": {
|
||||
"key_quant_bits": 3,
|
||||
"value_quant_bits": 4,
|
||||
"norm_correction": True,
|
||||
},
|
||||
"turboquant_3bit_nc": {
|
||||
"key_quant_bits": 3,
|
||||
"value_quant_bits": 3,
|
||||
"norm_correction": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TurboQuantConfig:
|
||||
"""Configuration for TurboQuant KV-cache quantization.
|
||||
|
||||
Uses PolarQuant (WHT rotation + Lloyd-Max scalar quantization) for keys
|
||||
and uniform quantization for values. QJL is intentionally omitted —
|
||||
community consensus (5+ independent groups) found it hurts attention
|
||||
quality by amplifying variance through softmax.
|
||||
|
||||
Named presets (use via --kv-cache-dtype):
|
||||
turboquant_k8v4: FP8 keys + 4-bit values, 2.6x, +1.17% PPL
|
||||
turboquant_4bit_nc: 4-bit MSE keys + 4-bit values + NC, 3.8x, +2.71%
|
||||
turboquant_k3v4_nc: 3-bit MSE keys + 4-bit values + NC, ~3.5x, +10.63%
|
||||
turboquant_3bit_nc: 3-bit MSE keys + 3-bit values + NC, 4.9x, +20.59%
|
||||
|
||||
Args:
|
||||
head_dim: Attention head dimension (e.g. 64, 96, 128).
|
||||
key_quant_bits: Bits for key quantization. 8 = FP8 keys (no
|
||||
rotation/MSE). 3-4 = Lloyd-Max MSE quantized keys.
|
||||
value_quant_bits: Bits per value dimension for uniform quantization.
|
||||
3 = 8 levels, 4 = 16 levels (default).
|
||||
seed: Base seed for deterministic random matrix generation.
|
||||
Actual seed per layer = seed + layer_idx * 1337.
|
||||
norm_correction: Re-normalize centroid vectors to unit norm before
|
||||
inverse rotation during dequant. Fixes quantization-induced norm
|
||||
distortion, improving PPL by ~0.8% at 4-bit.
|
||||
"""
|
||||
|
||||
head_dim: int = 128
|
||||
key_quant_bits: int = 3 # 3-4 = MSE keys, 8 = FP8 keys
|
||||
value_quant_bits: int = 4 # 3-4 = uniform quantized values
|
||||
seed: int = 42
|
||||
norm_correction: bool = False
|
||||
|
||||
@property
|
||||
def key_fp8(self) -> bool:
|
||||
"""Whether keys are stored as FP8 — no rotation/quantization needed."""
|
||||
return self.key_quant_bits == 8
|
||||
|
||||
@property
|
||||
def mse_bits(self) -> int:
|
||||
"""MSE quantizer bit-width (determines centroid count: 2^mse_bits).
|
||||
|
||||
For MSE key modes, equals key_quant_bits.
|
||||
For FP8 key mode, falls back to value_quant_bits (centroids are still
|
||||
needed for continuation-prefill dequant and decode kernel params).
|
||||
"""
|
||||
if self.key_fp8:
|
||||
return self.value_quant_bits
|
||||
return self.key_quant_bits
|
||||
|
||||
@property
|
||||
def key_mse_bits(self) -> int:
|
||||
"""MSE bits actually used for key quantization (0 if FP8 keys)."""
|
||||
if self.key_fp8:
|
||||
return 0
|
||||
return self.key_quant_bits
|
||||
|
||||
@property
|
||||
def centroid_bits(self) -> int:
|
||||
"""Bits for centroid generation — always non-zero."""
|
||||
return self.mse_bits
|
||||
|
||||
@property
|
||||
def n_centroids(self) -> int:
|
||||
return 2**self.mse_bits
|
||||
|
||||
@property
|
||||
def key_packed_size(self) -> int:
|
||||
"""Packed bytes for a single KEY vector.
|
||||
|
||||
FP8 mode (key_quant_bits=8):
|
||||
head_dim bytes (1 byte per element, no overhead).
|
||||
|
||||
TQ mode:
|
||||
- MSE indices: ceil(head_dim * key_mse_bits / 8) bytes
|
||||
- vec_norm: 2 bytes (float16)
|
||||
"""
|
||||
if self.key_fp8:
|
||||
return self.head_dim # 1 byte per element
|
||||
mse_bytes = math.ceil(self.head_dim * self.key_mse_bits / 8)
|
||||
norm_bytes = 2 # vec_norm fp16
|
||||
return mse_bytes + norm_bytes
|
||||
|
||||
@property
|
||||
def effective_value_quant_bits(self) -> int:
|
||||
"""Actual bits used for value storage."""
|
||||
return self.value_quant_bits
|
||||
|
||||
@property
|
||||
def value_packed_size(self) -> int:
|
||||
"""Packed bytes for a single VALUE vector.
|
||||
|
||||
Uniform quantization: ceil(head_dim * bits / 8) + 4 bytes (scale + zero fp16).
|
||||
"""
|
||||
data_bytes = math.ceil(self.head_dim * self.value_quant_bits / 8)
|
||||
return data_bytes + 4 # +2 scale(fp16) +2 zero(fp16)
|
||||
|
||||
@property
|
||||
def slot_size(self) -> int:
|
||||
"""Total packed bytes per head per position (key + value combined).
|
||||
|
||||
Layout: [key_packed | value_packed]
|
||||
"""
|
||||
return self.key_packed_size + self.value_packed_size
|
||||
|
||||
@property
|
||||
def slot_size_aligned(self) -> int:
|
||||
"""Slot size rounded up to next even number.
|
||||
|
||||
Even-number is required so effective_head_size = slot_size_aligned // 2
|
||||
is integral.
|
||||
"""
|
||||
s = self.slot_size
|
||||
return s + (s % 2) # round up to even
|
||||
|
||||
@staticmethod
|
||||
def get_boundary_skip_layers(num_layers: int, n: int = 2) -> list[str]:
|
||||
"""Get layer indices to skip TQ compression (boundary protection).
|
||||
|
||||
Returns first N and last N layer indices as strings, suitable for
|
||||
kv_cache_dtype_skip_layers.
|
||||
"""
|
||||
if n <= 0 or num_layers <= 0:
|
||||
return []
|
||||
n = min(n, num_layers // 2) # don't skip more than half
|
||||
first = list(range(n))
|
||||
last = list(range(num_layers - n, num_layers))
|
||||
# Deduplicate (if num_layers <= 2*n)
|
||||
indices = sorted(set(first + last))
|
||||
return [str(i) for i in indices]
|
||||
|
||||
@staticmethod
|
||||
def from_cache_dtype(cache_dtype: str, head_dim: int) -> "TurboQuantConfig":
|
||||
"""Create config from a named preset.
|
||||
|
||||
Valid presets: turboquant_k8v4, turboquant_4bit_nc, etc.
|
||||
"""
|
||||
if cache_dtype not in TQ_PRESETS:
|
||||
valid = ", ".join(TQ_PRESETS.keys())
|
||||
raise ValueError(
|
||||
f"Unknown TurboQuant cache dtype: {cache_dtype!r}. "
|
||||
f"Valid presets: {valid}"
|
||||
)
|
||||
preset = TQ_PRESETS[cache_dtype]
|
||||
return TurboQuantConfig(
|
||||
head_dim=head_dim,
|
||||
key_quant_bits=preset["key_quant_bits"],
|
||||
value_quant_bits=preset["value_quant_bits"],
|
||||
norm_correction=preset["norm_correction"],
|
||||
)
|
||||
@@ -0,0 +1,24 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""TurboQuant quantizer utilities.
|
||||
|
||||
Serving path uses generate_wht_signs() for WHT rotation sign buffers.
|
||||
Triton kernels handle all quantization, packing, and dequantization on GPU.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
_CPU = torch.device("cpu")
|
||||
|
||||
|
||||
def generate_wht_signs(d: int, seed: int, device: torch.device = _CPU) -> torch.Tensor:
|
||||
"""Generate deterministic random ±1 signs for WHT rotation.
|
||||
|
||||
Used with Walsh-Hadamard Transform for per-layer rotation randomization.
|
||||
Same seed derivation as QR (per-layer via seed + layer_idx * stride).
|
||||
"""
|
||||
gen = torch.Generator(device="cpu")
|
||||
gen.manual_seed(seed)
|
||||
bits = torch.randint(0, 2, (d,), generator=gen, device="cpu")
|
||||
signs = bits.float() * 2 - 1
|
||||
return signs.to(device)
|
||||
@@ -255,6 +255,11 @@ class CudaPlatformBase(Platform):
|
||||
valid_backends_priorities = []
|
||||
invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}
|
||||
|
||||
# TurboQuant KV cache: route directly to TQ backend
|
||||
kv_cache_dtype = attn_selector_config.kv_cache_dtype
|
||||
if kv_cache_dtype is not None and kv_cache_dtype.startswith("turboquant_"):
|
||||
return [(AttentionBackendEnum.TURBOQUANT, 0)], {}
|
||||
|
||||
backend_priorities = _get_backend_priorities(
|
||||
attn_selector_config.use_mla,
|
||||
device_capability,
|
||||
|
||||
@@ -61,6 +61,12 @@ class XPUPlatform(Platform):
|
||||
"only NHD layout is supported by XPU attention kernels."
|
||||
)
|
||||
|
||||
# TurboQuant KV cache: route directly to TQ backend
|
||||
kv_cache_dtype = attn_selector_config.kv_cache_dtype
|
||||
if kv_cache_dtype is not None and kv_cache_dtype.startswith("turboquant_"):
|
||||
logger.info_once("Using TurboQuant attention backend.")
|
||||
return AttentionBackendEnum.TURBOQUANT.get_path()
|
||||
|
||||
dtype = attn_selector_config.dtype
|
||||
if attn_selector_config.use_sparse:
|
||||
logger.info_once("Using XPU MLA Sparse backend.")
|
||||
|
||||
@@ -42,6 +42,10 @@ STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"fp8_per_token_head": torch.uint8,
|
||||
"fp8_inc": torch.float8_e4m3fn,
|
||||
"fp8_ds_mla": torch.uint8,
|
||||
"turboquant_k8v4": torch.uint8,
|
||||
"turboquant_4bit_nc": torch.uint8,
|
||||
"turboquant_k3v4_nc": torch.uint8,
|
||||
"turboquant_3bit_nc": torch.uint8,
|
||||
}
|
||||
|
||||
TORCH_DTYPE_TO_NUMPY_DTYPE = {
|
||||
|
||||
@@ -82,6 +82,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
|
||||
"RocmAiterUnifiedAttentionBackend"
|
||||
)
|
||||
CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
|
||||
TURBOQUANT = "vllm.v1.attention.backends.turboquant_attn.TurboQuantAttentionBackend"
|
||||
# Placeholder for third-party/custom backends - must be registered before use
|
||||
# set to None to avoid alias with other backend, whose value is an empty string
|
||||
CUSTOM = None
|
||||
|
||||
@@ -0,0 +1,812 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""TurboQuant attention backend for vLLM.
|
||||
|
||||
Prefill: Standard scaled dot-product attention on uncompressed K/V,
|
||||
then quantize K and store K+V into combined cache slot.
|
||||
Decode: Compute TQ attention scores from compressed cache,
|
||||
unpack FP16 values, softmax + weighted sum.
|
||||
|
||||
Cache layout (no leading 2 dimension):
|
||||
(num_blocks, block_size, num_kv_heads, slot_size)
|
||||
where slot_size = key_packed_size + value_fp16_size
|
||||
|
||||
Per-head per-position slot layout:
|
||||
[key_packed (kps bytes) | value_fp16 (D*2 bytes)]
|
||||
For turboquant_k3v4_nc head_dim=256: [100 bytes key | 512 bytes value] = 612
|
||||
"""
|
||||
|
||||
import functools
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.fa_utils import (
|
||||
is_flash_attn_varlen_func_available,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
|
||||
from vllm.v1.attention.ops.triton_turboquant_decode import (
|
||||
_tq_full_dequant_kv,
|
||||
_use_fp8_e4b15,
|
||||
triton_turboquant_decode_attention,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_turboquant_store import triton_turboquant_store
|
||||
|
||||
_HAS_FLASH_ATTN = is_flash_attn_varlen_func_available()
|
||||
if _HAS_FLASH_ATTN:
|
||||
from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
|
||||
|
||||
# Continuation prefill: for small continuation chunks (q_len ≤ threshold),
|
||||
# use the TQ decode kernel directly instead of full-dequant + flash_attn.
|
||||
# do_kv_cache_update already stored all tokens to TQ cache, so the decode
|
||||
# kernel can read them efficiently. This avoids O(cached_len) dequant work
|
||||
# per continuation, eliminating the O(N²/chunk_size) collapse at long context.
|
||||
_CONTINUATION_DECODE_THRESHOLD = 128
|
||||
|
||||
|
||||
def _build_hadamard(d: int, device_str: str) -> torch.Tensor:
|
||||
"""Orthonormal Hadamard matrix (Sylvester construction), cached per (d, device).
|
||||
|
||||
Precomputed D×D matrix enables matmul-based WHT — single cuBLAS GEMM
|
||||
instead of log2(D) butterfly kernel launches. 64KB for D=128.
|
||||
"""
|
||||
# Normalize device string so "cuda" and "cuda:0" hit the same cache entry.
|
||||
return _build_hadamard_cached(d, str(torch.device(device_str)))
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _build_hadamard_cached(d: int, device_str: str) -> torch.Tensor:
|
||||
H = torch.tensor([[1.0]])
|
||||
while H.shape[0] < d:
|
||||
H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
|
||||
return (H / math.sqrt(d)).to(torch.device(device_str))
|
||||
|
||||
|
||||
class TurboQuantAttentionBackend(AttentionBackend):
|
||||
"""Attention backend using TurboQuant KV-cache compression."""
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"turboquant_k8v4",
|
||||
"turboquant_4bit_nc",
|
||||
"turboquant_k3v4_nc",
|
||||
"turboquant_3bit_nc",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TURBOQUANT"
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [16, 32, 64, 128]
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
return attn_type == AttentionType.DECODER
|
||||
|
||||
@classmethod
|
||||
def supports_per_head_quant_scales(cls) -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TurboQuantAttentionImpl"]:
|
||||
return TurboQuantAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["TurboQuantMetadataBuilder"]:
|
||||
return TurboQuantMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "turboquant_4bit_nc",
|
||||
) -> tuple[int, ...]:
|
||||
"""Combined K+V cache shape — no leading 2 dimension.
|
||||
|
||||
Standard attention backends use (2, num_blocks, block_size, num_kv_heads,
|
||||
head_dim) with a leading 2 to separate K and V. TurboQuant packs K+V
|
||||
into a single interleaved slot per head per position, so the cache is:
|
||||
|
||||
(num_blocks, block_size, num_kv_heads, slot_size_aligned)
|
||||
|
||||
Each slot = [key_packed | value_packed | padding].
|
||||
This is safe because TQ has its own get_kv_cache_shape override and
|
||||
never shares cache tensors with other backends. Layers that fall back
|
||||
to native dtype via kv_cache_dtype_skip_layers get their own
|
||||
standard-shaped cache allocation.
|
||||
|
||||
head_size is the model's real head_dim. slot_size_aligned is computed
|
||||
from the TQ config to ensure correct cache allocation for all head dims.
|
||||
"""
|
||||
from vllm.model_executor.layers.quantization.turboquant.config import (
|
||||
TurboQuantConfig,
|
||||
)
|
||||
|
||||
tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype_str, head_size)
|
||||
return (num_blocks, block_size, num_kv_heads, tq_config.slot_size_aligned)
|
||||
|
||||
@classmethod
|
||||
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
|
||||
if kv_cache_dtype is None:
|
||||
return False
|
||||
return kv_cache_dtype.startswith("turboquant_")
|
||||
|
||||
@classmethod
|
||||
def supports_head_size(cls, head_size: int) -> bool:
|
||||
# head_size from spec is effective_head_size (padded_slot//2),
|
||||
# not the model's actual head_dim. Accept any positive value.
|
||||
return head_size > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class TurboQuantMetadata(AttentionMetadata):
|
||||
"""Metadata for TurboQuant attention."""
|
||||
|
||||
seq_lens: torch.Tensor # (num_reqs,) — total context length per request
|
||||
slot_mapping: torch.Tensor # (num_tokens,) — cache slot for each token
|
||||
block_table: torch.Tensor # (num_reqs, max_num_blocks)
|
||||
query_start_loc: torch.Tensor # (num_reqs + 1,) — cu_seqlens for queries
|
||||
num_actual_tokens: int = 0 # actual tokens (excluding padding)
|
||||
max_query_len: int = 0 # longest query in batch
|
||||
max_seq_len: int = 0 # longest context in batch
|
||||
is_prefill: bool = False
|
||||
num_decodes: int = 0 # number of decode requests (first in batch)
|
||||
num_decode_tokens: int = 0 # tokens from decode requests
|
||||
|
||||
|
||||
class TurboQuantMetadataBuilder(AttentionMetadataBuilder[TurboQuantMetadata]):
|
||||
"""Builds TurboQuantMetadata from scheduler output."""
|
||||
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
self._init_reorder_batch_threshold(1, supports_spec_as_decode=False)
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
) -> TurboQuantMetadata:
|
||||
attn_metadata = self.build(0, common_attn_metadata)
|
||||
# Set seq_lens to 1 so CUDA graph capture is fast
|
||||
# (real seq_lens are filled at replay time).
|
||||
attn_metadata.seq_lens.fill_(1)
|
||||
return attn_metadata
|
||||
|
||||
def build(self, common_prefix_len, common_attn_metadata, fast_build=False):
|
||||
"""Build TurboQuantMetadata from common attention metadata."""
|
||||
cam = common_attn_metadata
|
||||
|
||||
# With reorder_batch_threshold=1, the model runner guarantees
|
||||
# decodes come first in the batch. split_decodes_and_prefills
|
||||
# finds the boundary (operates on CPU tensors — no GPU sync).
|
||||
assert self.reorder_batch_threshold is not None
|
||||
num_decodes, num_prefills, num_decode_tokens, _ = split_decodes_and_prefills(
|
||||
cam, decode_threshold=self.reorder_batch_threshold
|
||||
)
|
||||
|
||||
return TurboQuantMetadata(
|
||||
seq_lens=cam.seq_lens,
|
||||
slot_mapping=cam.slot_mapping,
|
||||
block_table=cam.block_table_tensor,
|
||||
query_start_loc=cam.query_start_loc,
|
||||
num_actual_tokens=cam.num_actual_tokens,
|
||||
max_query_len=cam.max_query_len,
|
||||
max_seq_len=cam.max_seq_len,
|
||||
is_prefill=(cam.max_query_len > 1),
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
)
|
||||
|
||||
|
||||
class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
|
||||
"""TurboQuant attention implementation.
|
||||
|
||||
Vectorized PyTorch: batch quantize/store, vectorized bit-unpack
|
||||
decode with einsum scores and value gather.
|
||||
"""
|
||||
|
||||
supports_quant_query_input: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int | None = None,
|
||||
alibi_slopes: list[float] | None = None,
|
||||
sliding_window: int | None = None,
|
||||
kv_cache_dtype: str = "auto",
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = scale
|
||||
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
|
||||
self.num_kv_groups = num_heads // self.num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
from vllm.model_executor.layers.quantization.turboquant.config import (
|
||||
TurboQuantConfig,
|
||||
)
|
||||
|
||||
self.tq_config = TurboQuantConfig.from_cache_dtype(kv_cache_dtype, head_size)
|
||||
|
||||
# Pre-compute kernel constants from config (avoid repeated arithmetic)
|
||||
cfg = self.tq_config
|
||||
self._mse_bytes = (
|
||||
math.ceil(head_size * cfg.key_mse_bits / 8)
|
||||
if not cfg.key_fp8
|
||||
else head_size
|
||||
)
|
||||
self._val_data_bytes = math.ceil(head_size * cfg.effective_value_quant_bits / 8)
|
||||
self._n_centroids = cfg.n_centroids if not cfg.key_fp8 else 1
|
||||
|
||||
# Fixed NUM_KV_SPLITS (grid dims must be constant for cudagraph,
|
||||
# and benchmarks show no regression vs dynamic in eager mode).
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.max_num_kv_splits = (
|
||||
vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
|
||||
)
|
||||
|
||||
def _ensure_on_device(self, layer, device):
|
||||
"""One-time derivation of TQ buffers (rotation matrices, midpoints).
|
||||
|
||||
Registered buffers (_tq_signs, _tq_centroids) are already on the
|
||||
correct device via register_buffer + model.to(device).
|
||||
"""
|
||||
if not hasattr(layer, "_tq_cached"):
|
||||
D = layer._tq_signs.shape[0]
|
||||
signs = layer._tq_signs.to(device=device, dtype=torch.float32)
|
||||
|
||||
# WHT rotation: orthonormal + self-inverse, enabling future
|
||||
# in-kernel butterfly fusion and trivial inverse for continuation.
|
||||
H = _build_hadamard(D, str(device))
|
||||
layer._tq_PiT = (signs.unsqueeze(1) * H).contiguous()
|
||||
layer._tq_Pi = layer._tq_PiT.T.contiguous()
|
||||
|
||||
c = layer._tq_centroids.to(device=device, dtype=torch.float32)
|
||||
# Precompute midpoints for threshold-based quantization
|
||||
c_sorted, _ = c.sort()
|
||||
layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2
|
||||
# Decode buffers (_tq_mid_o_buf, _tq_output_buf, _tq_lse_buf)
|
||||
# are pre-allocated via register_buffer in Attention.__init__
|
||||
# and moved to GPU by model.to(device) — no allocation needed
|
||||
# here. The memory profiler sees them before KV cache sizing.
|
||||
layer._tq_cached = True
|
||||
|
||||
def do_kv_cache_update(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
"""Store compressed K/V into the combined TQ cache.
|
||||
|
||||
Called as a separate custom op (unified_kv_cache_update) BEFORE
|
||||
the attention forward, matching FlashAttention's split pattern.
|
||||
slot_mapping is already sliced to num_actual_tokens by the caller.
|
||||
"""
|
||||
N = slot_mapping.shape[0]
|
||||
if N <= 0:
|
||||
return
|
||||
|
||||
device = key.device
|
||||
self._ensure_on_device(layer, device)
|
||||
|
||||
k = key[:N].view(N, self.num_kv_heads, self.head_size)
|
||||
v = value[:N].view(N, self.num_kv_heads, self.head_size)
|
||||
self._store_kv(k, v, kv_cache, slot_mapping, layer)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: "TurboQuantMetadata",
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = query.shape[0]
|
||||
|
||||
if output is None:
|
||||
output = torch.zeros(
|
||||
num_tokens,
|
||||
self.num_heads * self.head_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
return output.fill_(0)
|
||||
|
||||
# Slice to actual tokens
|
||||
N = attn_metadata.num_actual_tokens
|
||||
if N <= 0:
|
||||
return output.fill_(0)
|
||||
|
||||
q = query[:N].view(N, self.num_heads, self.head_size)
|
||||
|
||||
# Get TQ buffers, ensure on device (one-time migration).
|
||||
# Use Any-typed alias for dynamic _tq_* attrs set by _ensure_on_device.
|
||||
tq_layer: Any = layer
|
||||
device = q.device
|
||||
self._ensure_on_device(tq_layer, device)
|
||||
Pi = tq_layer._tq_Pi
|
||||
PiT = tq_layer._tq_PiT
|
||||
centroids = tq_layer._tq_centroids
|
||||
|
||||
# Compute attention (KV cache was already updated by do_kv_cache_update)
|
||||
# With reorder_batch_threshold=1, decodes come first in the batch.
|
||||
# num_decodes/num_decode_tokens from metadata give the split point.
|
||||
num_decodes = attn_metadata.num_decodes
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
|
||||
if not attn_metadata.is_prefill:
|
||||
# Pure decode batch — fast path
|
||||
attn_out = self._decode_attention(
|
||||
q, kv_cache, attn_metadata, Pi, centroids, PiT, layer
|
||||
)
|
||||
elif num_decodes == 0:
|
||||
# Pure prefill batch
|
||||
k = key[:N].view(N, self.num_kv_heads, self.head_size)
|
||||
v = value[:N].view(N, self.num_kv_heads, self.head_size)
|
||||
attn_out = self._prefill_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
Pi,
|
||||
centroids,
|
||||
PiT,
|
||||
layer=layer,
|
||||
)
|
||||
else:
|
||||
# Mixed batch: decodes first (guaranteed by reorder_batch).
|
||||
attn_out = torch.zeros(
|
||||
N, self.num_heads, self.head_size, device=device, dtype=q.dtype
|
||||
)
|
||||
|
||||
# --- Decode portion (first num_decodes requests) ---
|
||||
# Use full-batch max_seq_len as safe upper bound (no GPU sync).
|
||||
decode_meta = TurboQuantMetadata(
|
||||
seq_lens=attn_metadata.seq_lens[:num_decodes],
|
||||
slot_mapping=attn_metadata.slot_mapping[:num_decode_tokens],
|
||||
block_table=attn_metadata.block_table[:num_decodes],
|
||||
query_start_loc=attn_metadata.query_start_loc[: num_decodes + 1],
|
||||
num_actual_tokens=num_decode_tokens,
|
||||
max_query_len=1,
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
is_prefill=False,
|
||||
)
|
||||
attn_out[:num_decode_tokens] = self._decode_attention(
|
||||
q[:num_decode_tokens], kv_cache, decode_meta, Pi, centroids, PiT, layer
|
||||
)
|
||||
|
||||
# --- Prefill portion (remaining requests) ---
|
||||
# CRITICAL: use prefill-specific max_seq_len so flash_attn's
|
||||
# fast path (max_query_len == max_seq_len) triggers for
|
||||
# first-chunk prefills. Using full-batch max_seq_len breaks
|
||||
# this because decode requests inflate max_seq_len.
|
||||
prefill_seq_lens = attn_metadata.seq_lens[num_decodes:]
|
||||
# Use CPU-side max to avoid GPU→CPU sync from .item()
|
||||
prefill_max_seq = max(attn_metadata.seq_lens[num_decodes:].tolist())
|
||||
prefill_qsl = (
|
||||
attn_metadata.query_start_loc[num_decodes:] - num_decode_tokens
|
||||
)
|
||||
prefill_meta = TurboQuantMetadata(
|
||||
seq_lens=prefill_seq_lens,
|
||||
slot_mapping=attn_metadata.slot_mapping[num_decode_tokens:N],
|
||||
block_table=attn_metadata.block_table[num_decodes:],
|
||||
query_start_loc=prefill_qsl,
|
||||
num_actual_tokens=N - num_decode_tokens,
|
||||
max_query_len=attn_metadata.max_query_len,
|
||||
max_seq_len=prefill_max_seq,
|
||||
is_prefill=True,
|
||||
)
|
||||
k = key[:N].view(N, self.num_kv_heads, self.head_size)
|
||||
v = value[:N].view(N, self.num_kv_heads, self.head_size)
|
||||
attn_out[num_decode_tokens:] = self._prefill_attention(
|
||||
q[num_decode_tokens:],
|
||||
k[num_decode_tokens:],
|
||||
v[num_decode_tokens:],
|
||||
kv_cache,
|
||||
prefill_meta,
|
||||
Pi,
|
||||
centroids,
|
||||
PiT,
|
||||
layer=layer,
|
||||
)
|
||||
|
||||
# Write into output buffer: attn_out is (N, Hq, D)
|
||||
# output may be 2D (N, Hq*D) or 3D (N, Hq, D)
|
||||
if output.ndim == 3:
|
||||
output[:N] = attn_out.to(output.dtype)
|
||||
else:
|
||||
output[:N] = attn_out.reshape(N, -1).to(output.dtype)
|
||||
return output
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Store K/V into combined cache (vectorized) #
|
||||
# ------------------------------------------------------------------ #
|
||||
def _store_kv(
|
||||
self,
|
||||
key: torch.Tensor, # (N, Hk, D)
|
||||
value: torch.Tensor, # (N, Hk, D)
|
||||
kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size)
|
||||
slot_mapping: torch.Tensor,
|
||||
layer: Any,
|
||||
):
|
||||
"""Quantize + store via fused Triton kernel."""
|
||||
triton_turboquant_store(
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
slot_mapping,
|
||||
layer._tq_PiT,
|
||||
layer._tq_midpoints,
|
||||
mse_bits=self.tq_config.key_mse_bits,
|
||||
key_packed_size=self.tq_config.key_packed_size,
|
||||
value_quant_bits=self.tq_config.effective_value_quant_bits,
|
||||
key_fp8=self.tq_config.key_fp8,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Prefill: SDPA on raw Q/K/V with causal mask #
|
||||
# ------------------------------------------------------------------ #
|
||||
def _prefill_attention(
|
||||
self,
|
||||
query: torch.Tensor, # (N, Hq, D)
|
||||
key: torch.Tensor, # (N, Hk, D)
|
||||
value: torch.Tensor, # (N, Hk, D)
|
||||
kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size)
|
||||
attn_metadata: TurboQuantMetadata,
|
||||
Pi: torch.Tensor,
|
||||
centroids: torch.Tensor,
|
||||
PiT: torch.Tensor | None = None,
|
||||
layer: Any = None,
|
||||
) -> torch.Tensor:
|
||||
N, Hq, D = query.shape
|
||||
|
||||
# Fast path: use flash_attn for first-chunk prefills (all K/V in batch).
|
||||
# max_query_len == max_seq_len means no request has prior cached KV.
|
||||
# Both are Python ints — no GPU sync.
|
||||
if _HAS_FLASH_ATTN and attn_metadata.max_query_len == attn_metadata.max_seq_len:
|
||||
output = torch.empty(N, Hq, D, device=query.device, dtype=query.dtype)
|
||||
flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=attn_metadata.query_start_loc,
|
||||
cu_seqlens_k=attn_metadata.query_start_loc,
|
||||
max_seqlen_q=attn_metadata.max_query_len,
|
||||
max_seqlen_k=attn_metadata.max_query_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
out=output,
|
||||
)
|
||||
return output
|
||||
|
||||
# Continuation or no flash_attn: per-request attention.
|
||||
# For continuation chunks (seq_len > q_len), we must attend to
|
||||
# previously cached K/V from the TQ cache, not just the current
|
||||
# chunk's raw K/V.
|
||||
Hk = key.shape[1]
|
||||
use_gqa = Hk < Hq
|
||||
query_start_loc = attn_metadata.query_start_loc
|
||||
num_reqs = query_start_loc.shape[0] - 1
|
||||
|
||||
output = torch.zeros(N, Hq, D, device=query.device, dtype=query.dtype)
|
||||
|
||||
# Convert to Python lists once (single CPU-GPU sync) instead of
|
||||
# per-request .item() calls that each force a sync.
|
||||
qsl = query_start_loc.tolist()
|
||||
seq_lens_list = attn_metadata.seq_lens.tolist()
|
||||
|
||||
# Pre-allocate cu_seqlens for single-request flash_attn calls
|
||||
# to avoid per-request host→device tensor creation.
|
||||
_cu_2 = torch.zeros(2, device=query.device, dtype=torch.int32)
|
||||
|
||||
for i in range(num_reqs):
|
||||
q_start = qsl[i]
|
||||
q_end = qsl[i + 1]
|
||||
q_len = q_end - q_start
|
||||
if q_len <= 0:
|
||||
continue
|
||||
|
||||
seq_len = seq_lens_list[i]
|
||||
q_seq = query[q_start:q_end] # (q_len, Hq, D)
|
||||
k_seq = key[q_start:q_end] # (q_len, Hk, D)
|
||||
v_seq = value[q_start:q_end] # (q_len, Hk, D)
|
||||
|
||||
if q_len == seq_len:
|
||||
# First-chunk prefill: all K/V are in the current batch.
|
||||
if _HAS_FLASH_ATTN:
|
||||
out = torch.empty_like(q_seq)
|
||||
_cu_2[1] = q_len
|
||||
cu = _cu_2
|
||||
flash_attn_varlen_func(
|
||||
q=q_seq,
|
||||
k=k_seq,
|
||||
v=v_seq,
|
||||
cu_seqlens_q=cu,
|
||||
cu_seqlens_k=cu,
|
||||
max_seqlen_q=q_len,
|
||||
max_seqlen_k=q_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
out=out,
|
||||
)
|
||||
else:
|
||||
q_t = q_seq.transpose(0, 1).contiguous()
|
||||
k_t = k_seq.transpose(0, 1).contiguous()
|
||||
v_t = v_seq.transpose(0, 1).contiguous()
|
||||
out = F.scaled_dot_product_attention(
|
||||
q_t,
|
||||
k_t,
|
||||
v_t,
|
||||
is_causal=True,
|
||||
scale=self.scale,
|
||||
enable_gqa=use_gqa,
|
||||
).transpose(0, 1)
|
||||
output[q_start:q_end] = out.to(query.dtype)
|
||||
else:
|
||||
# Continuation chunk: tokens already stored to TQ cache
|
||||
# by do_kv_cache_update. Use decode kernel directly to
|
||||
# avoid O(cached_len) full-dequant per continuation.
|
||||
# For large continuations, fall back to _continuation_prefill.
|
||||
cached_len = seq_len - q_len
|
||||
if q_len <= _CONTINUATION_DECODE_THRESHOLD:
|
||||
# Fast path: treat each query as a decode request
|
||||
# with incremental seq_lens for causal masking.
|
||||
synth_seq_lens = torch.arange(
|
||||
cached_len + 1,
|
||||
seq_len + 1,
|
||||
device=query.device,
|
||||
dtype=attn_metadata.seq_lens.dtype,
|
||||
)
|
||||
synth_bt = attn_metadata.block_table[i : i + 1].expand(q_len, -1)
|
||||
out = triton_turboquant_decode_attention(
|
||||
query=q_seq,
|
||||
kv_cache=kv_cache,
|
||||
block_table=synth_bt,
|
||||
seq_lens=synth_seq_lens,
|
||||
Pi=Pi,
|
||||
centroids=centroids,
|
||||
scale=self.scale,
|
||||
mse_bits=self.tq_config.key_mse_bits,
|
||||
key_packed_size=self.tq_config.key_packed_size,
|
||||
value_quant_bits=(self.tq_config.effective_value_quant_bits),
|
||||
key_fp8=self.tq_config.key_fp8,
|
||||
norm_correction=self.tq_config.norm_correction,
|
||||
PiT=PiT,
|
||||
)
|
||||
else:
|
||||
# Large continuation: dequant cached K/V and use
|
||||
# flash_attn for better throughput.
|
||||
out = self._continuation_prefill(
|
||||
layer,
|
||||
q_seq,
|
||||
k_seq,
|
||||
v_seq,
|
||||
kv_cache,
|
||||
attn_metadata.block_table[i : i + 1],
|
||||
cached_len,
|
||||
seq_len,
|
||||
Pi,
|
||||
centroids,
|
||||
)
|
||||
output[q_start:q_end] = out.to(query.dtype)
|
||||
|
||||
return output
|
||||
|
||||
def _continuation_prefill(
|
||||
self,
|
||||
layer: Any,
|
||||
query: torch.Tensor, # (q_len, Hq, D)
|
||||
key_chunk: torch.Tensor, # (q_len, Hk, D)
|
||||
val_chunk: torch.Tensor, # (q_len, Hk, D)
|
||||
kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size)
|
||||
block_table: torch.Tensor, # (1, max_num_blocks)
|
||||
cached_len: int,
|
||||
seq_len: int,
|
||||
Pi: torch.Tensor,
|
||||
centroids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Handle continuation chunk by dequanting cached K/V from TQ cache.
|
||||
|
||||
Dequants previously cached K/V, concatenates with the current
|
||||
chunk's raw K/V, then runs flash_attn with causal masking.
|
||||
"""
|
||||
q_len, Hq, D = query.shape
|
||||
Hk = key_chunk.shape[1]
|
||||
device = query.device
|
||||
block_size = kv_cache.shape[1]
|
||||
BLOCK_D = triton.next_power_of_2(D)
|
||||
|
||||
mse_bytes = self._mse_bytes
|
||||
val_data_bytes = self._val_data_bytes
|
||||
|
||||
# Dequant cached K/V from TQ cache
|
||||
# Allocate slightly over to align to block_size for the grid.
|
||||
# Reuse cached buffers to avoid per-call allocation (~16MB at 8K).
|
||||
alloc_len = math.ceil(cached_len / block_size) * block_size
|
||||
buf_shape = (1, Hk, alloc_len, D)
|
||||
k_buf = getattr(layer, "_tq_k_dequant_buf", None)
|
||||
if k_buf is None or k_buf.shape[2] < alloc_len:
|
||||
k_buf = torch.empty(buf_shape, dtype=torch.float16, device=device)
|
||||
v_buf = torch.empty(buf_shape, dtype=torch.float16, device=device)
|
||||
layer._tq_k_dequant_buf = k_buf
|
||||
layer._tq_v_dequant_buf = v_buf
|
||||
else:
|
||||
v_buf = layer._tq_v_dequant_buf
|
||||
k_cached = k_buf[:, :, :alloc_len, :].zero_()
|
||||
v_cached = v_buf[:, :, :alloc_len, :].zero_()
|
||||
|
||||
grid = (alloc_len, 1 * Hk)
|
||||
_tq_full_dequant_kv[grid](
|
||||
kv_cache,
|
||||
block_table,
|
||||
centroids,
|
||||
k_cached,
|
||||
v_cached,
|
||||
k_cached.stride(0),
|
||||
k_cached.stride(1),
|
||||
k_cached.stride(2),
|
||||
v_cached.stride(0),
|
||||
v_cached.stride(1),
|
||||
v_cached.stride(2),
|
||||
kv_cache.stride(0),
|
||||
kv_cache.stride(1),
|
||||
kv_cache.stride(2),
|
||||
block_table.stride(0),
|
||||
HEAD_DIM=D,
|
||||
BLOCK_SIZE=block_size,
|
||||
NUM_KV_HEADS=Hk,
|
||||
MSE_BYTES=mse_bytes,
|
||||
KPS=self.tq_config.key_packed_size,
|
||||
VQB=self.tq_config.effective_value_quant_bits,
|
||||
VAL_DATA_BYTES=val_data_bytes,
|
||||
MSE_BITS=self.tq_config.key_mse_bits,
|
||||
KEY_FP8=1 if self.tq_config.key_fp8 else 0,
|
||||
BLOCK_D=BLOCK_D,
|
||||
NORM_CORRECTION=1 if self.tq_config.norm_correction else 0,
|
||||
FP8_E4B15=_use_fp8_e4b15(device.index or 0),
|
||||
num_warps=4,
|
||||
)
|
||||
|
||||
# Inverse-rotate MSE keys back to original space
|
||||
if not self.tq_config.key_fp8:
|
||||
k_flat = k_cached[0, :, :cached_len, :].reshape(-1, D).float()
|
||||
k_flat = k_flat @ Pi
|
||||
k_cached_trim = (
|
||||
k_flat.to(torch.float16).reshape(Hk, cached_len, D).transpose(0, 1)
|
||||
) # (cached_len, Hk, D)
|
||||
else:
|
||||
k_cached_trim = (
|
||||
k_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
|
||||
) # (cached_len, Hk, D)
|
||||
|
||||
v_cached_trim = (
|
||||
v_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
|
||||
) # (cached_len, Hk, D)
|
||||
|
||||
# Concatenate cached + current chunk K/V (match query dtype)
|
||||
qdtype = query.dtype
|
||||
k_full = torch.cat([k_cached_trim.to(qdtype), key_chunk], dim=0)
|
||||
v_full = torch.cat([v_cached_trim.to(qdtype), val_chunk], dim=0)
|
||||
|
||||
# Attention: q_len queries attending to seq_len K/V with causal mask
|
||||
if _HAS_FLASH_ATTN:
|
||||
output = torch.empty(q_len, Hq, D, device=device, dtype=query.dtype)
|
||||
cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32)
|
||||
cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32)
|
||||
flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=k_full,
|
||||
v=v_full,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=q_len,
|
||||
max_seqlen_k=seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
out=output,
|
||||
)
|
||||
return output
|
||||
else:
|
||||
# SDPA fallback: expand KV for GQA, build causal mask
|
||||
q_t = query.transpose(0, 1).unsqueeze(0) # (1, Hq, q_len, D)
|
||||
k_t = k_full.transpose(0, 1).unsqueeze(0) # (1, Hk, seq_len, D)
|
||||
v_t = v_full.transpose(0, 1).unsqueeze(0) # (1, Hk, seq_len, D)
|
||||
# Build causal mask: query position p can attend to K position j
|
||||
# where j <= cached_len + p (p is 0-indexed within chunk)
|
||||
q_pos = torch.arange(q_len, device=device).unsqueeze(1) + cached_len
|
||||
k_pos = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
mask = k_pos <= q_pos # (q_len, seq_len)
|
||||
out = F.scaled_dot_product_attention(
|
||||
q_t,
|
||||
k_t,
|
||||
v_t,
|
||||
attn_mask=mask,
|
||||
scale=self.scale,
|
||||
enable_gqa=(Hk < Hq),
|
||||
) # (1, Hq, q_len, D)
|
||||
return out[0].transpose(0, 1) # (q_len, Hq, D)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Decode: Triton TQ decode attention #
|
||||
# ------------------------------------------------------------------ #
|
||||
def _decode_attention(
|
||||
self,
|
||||
query: torch.Tensor, # (B, Hq, D)
|
||||
kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size)
|
||||
attn_metadata: TurboQuantMetadata,
|
||||
Pi: torch.Tensor,
|
||||
centroids: torch.Tensor,
|
||||
PiT: torch.Tensor | None = None,
|
||||
layer: torch.nn.Module | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Grab cached decode buffers from the layer (lazily allocated).
|
||||
mid_o_buf = output_buf = lse_buf = None
|
||||
if layer is not None:
|
||||
mid_o_buf = getattr(layer, "_tq_mid_o_buf", None)
|
||||
output_buf = getattr(layer, "_tq_output_buf", None)
|
||||
lse_buf = getattr(layer, "_tq_lse_buf", None)
|
||||
|
||||
result = triton_turboquant_decode_attention(
|
||||
query=query,
|
||||
kv_cache=kv_cache,
|
||||
block_table=attn_metadata.block_table,
|
||||
seq_lens=attn_metadata.seq_lens,
|
||||
Pi=Pi,
|
||||
centroids=centroids,
|
||||
scale=self.scale,
|
||||
mse_bits=self.tq_config.key_mse_bits,
|
||||
key_packed_size=self.tq_config.key_packed_size,
|
||||
value_quant_bits=self.tq_config.effective_value_quant_bits,
|
||||
key_fp8=self.tq_config.key_fp8,
|
||||
norm_correction=self.tq_config.norm_correction,
|
||||
PiT=PiT,
|
||||
mid_o_buf=mid_o_buf,
|
||||
output_buf=output_buf,
|
||||
lse_buf=lse_buf,
|
||||
buf_holder=layer,
|
||||
max_num_kv_splits=self.max_num_kv_splits,
|
||||
)
|
||||
return result
|
||||
@@ -0,0 +1,617 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Triton fused TurboQuant decode attention.
|
||||
|
||||
Decode path: Triton stage1 (split-KV tiled attention scoring + value
|
||||
accumulation) + stage2 (log-sum-exp reduction across splits).
|
||||
|
||||
Supports FP8 (E4M3) keys, 3-bit and 4-bit uniform quantized values.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.ops.triton_decode_attention import (
|
||||
_fwd_kernel_stage2,
|
||||
)
|
||||
|
||||
_FP8_E4B15: dict[int, int] = {}
|
||||
|
||||
|
||||
def _use_fp8_e4b15(device: int = 0) -> int:
|
||||
"""Return 1 if device needs fp8e4b15 (Ampere/Ada, SM < 8.9), else 0."""
|
||||
if device not in _FP8_E4B15:
|
||||
cap = torch.cuda.get_device_capability(device)
|
||||
_FP8_E4B15[device] = 1 if cap < (8, 9) else 0
|
||||
return _FP8_E4B15[device]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stage 1: Fused TQ score + value accumulation (BLOCK_KV tiled)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _tq_decode_stage1(
|
||||
# Precomputed query projection
|
||||
Q_rot_ptr, # [B, Hq, D] float32
|
||||
# Compressed KV cache (combined K+V)
|
||||
KV_cache_ptr, # [num_blocks, block_size, Hk, padded_slot] uint8
|
||||
# Block table and sequence info
|
||||
Block_table_ptr, # [B, max_num_blocks] int32
|
||||
Seq_lens_ptr, # [B] int32
|
||||
# TQ parameters
|
||||
Centroids_ptr, # [n_centroids] float32
|
||||
# Output (intermediate for stage2)
|
||||
Mid_o_ptr, # [B, Hq, NUM_KV_SPLITS, D+1] float32
|
||||
# Strides
|
||||
stride_qb,
|
||||
stride_qh, # Q strides: [B, Hq, D]
|
||||
stride_cache_block,
|
||||
stride_cache_pos,
|
||||
stride_cache_head, # KV cache
|
||||
stride_bt_b, # block_table stride per batch
|
||||
stride_mid_b,
|
||||
stride_mid_h,
|
||||
stride_mid_s, # mid_o strides
|
||||
# Constexpr dims
|
||||
NUM_KV_HEADS: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr, # KV cache block_size (pages)
|
||||
NUM_KV_SPLITS: tl.constexpr,
|
||||
KV_GROUP_SIZE: tl.constexpr, # Hq // Hk
|
||||
# TQ layout constants
|
||||
MSE_BITS: tl.constexpr, # 3 or 4
|
||||
MSE_BYTES: tl.constexpr, # ceil(D * mse_bits / 8)
|
||||
KPS: tl.constexpr, # key_packed_size
|
||||
VQB: tl.constexpr, # value_quant_bits (4 or 8=FP8)
|
||||
VAL_DATA_BYTES: tl.constexpr, # ceil(D * vqb / 8) or D for FP8
|
||||
# Score constants
|
||||
ATTN_SCALE: tl.constexpr, # 1/sqrt(D)
|
||||
# Block tile sizes
|
||||
BLOCK_D: tl.constexpr, # next_power_of_2(HEAD_DIM)
|
||||
BLOCK_KV: tl.constexpr, # tokens per tile (16)
|
||||
KEY_FP8: tl.constexpr, # 1 if K is stored as FP8
|
||||
NORM_CORRECTION: tl.constexpr = 0, # 1 = re-normalize centroids
|
||||
FP8_E4B15: tl.constexpr = 0, # 1 = use e4b15 (Ampere/Ada), 0 = e4nv (Hopper+)
|
||||
):
|
||||
bid = tl.program_id(0) # batch index
|
||||
hid = tl.program_id(1) # q_head index
|
||||
sid = tl.program_id(2) # kv_split index
|
||||
|
||||
kv_head = hid // KV_GROUP_SIZE
|
||||
|
||||
# Sequence length for this batch
|
||||
seq_len = tl.load(Seq_lens_ptr + bid)
|
||||
|
||||
# KV split range
|
||||
split_len = tl.cdiv(seq_len, NUM_KV_SPLITS)
|
||||
split_start = split_len * sid
|
||||
split_end = tl.minimum(split_start + split_len, seq_len)
|
||||
|
||||
if split_start >= split_end:
|
||||
return
|
||||
|
||||
# Dimension offsets
|
||||
d_offs = tl.arange(0, BLOCK_D)
|
||||
d_mask = d_offs < HEAD_DIM
|
||||
kv_range = tl.arange(0, BLOCK_KV)
|
||||
|
||||
# Load query vector: q_rot — [BLOCK_D] float32
|
||||
q_base = bid * stride_qb + hid * stride_qh
|
||||
q_rot = tl.load(Q_rot_ptr + q_base + d_offs, mask=d_mask, other=0.0).to(tl.float32)
|
||||
|
||||
# Precompute byte/bit index vectors for MSE gather loads
|
||||
if not KEY_FP8:
|
||||
mse_bit_off = d_offs * MSE_BITS
|
||||
mse_byte_idx = mse_bit_off // 8
|
||||
mse_bit_shift = mse_bit_off % 8
|
||||
mse_mask = (1 << MSE_BITS) - 1
|
||||
|
||||
# Precompute value bit/byte index vectors (loop-invariant)
|
||||
if VQB == 3:
|
||||
val_bit_off = d_offs * 3
|
||||
val_byte_idx = val_bit_off // 8
|
||||
val_bit_shift = val_bit_off % 8
|
||||
|
||||
# Online softmax accumulators
|
||||
m_prev = -float("inf")
|
||||
l_prev = 0.0
|
||||
acc = tl.zeros([BLOCK_D], dtype=tl.float32)
|
||||
|
||||
bt_base = bid * stride_bt_b
|
||||
|
||||
# ================================================================
|
||||
# TILED LOOP: process BLOCK_KV tokens per iteration
|
||||
# ================================================================
|
||||
for start_n in range(split_start, split_end, BLOCK_KV):
|
||||
kv_offs = start_n + kv_range
|
||||
kv_mask = kv_offs < split_end
|
||||
|
||||
page_idx = kv_offs // BLOCK_SIZE
|
||||
page_off = kv_offs % BLOCK_SIZE
|
||||
block_nums = tl.load(
|
||||
Block_table_ptr + bt_base + page_idx,
|
||||
mask=kv_mask,
|
||||
other=0,
|
||||
)
|
||||
|
||||
slot_bases = (
|
||||
block_nums * stride_cache_block
|
||||
+ page_off * stride_cache_pos
|
||||
+ kv_head * stride_cache_head
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# COMPUTE ATTENTION SCORES: [BLOCK_KV]
|
||||
# ============================================================
|
||||
if KEY_FP8:
|
||||
k_addrs = slot_bases[:, None] + d_offs[None, :]
|
||||
k_raw = tl.load(
|
||||
KV_cache_ptr + k_addrs,
|
||||
mask=kv_mask[:, None] & d_mask[None, :],
|
||||
other=0,
|
||||
)
|
||||
if FP8_E4B15:
|
||||
k_float = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32)
|
||||
else:
|
||||
k_float = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32)
|
||||
scores = (
|
||||
tl.sum(
|
||||
tl.where(d_mask[None, :], q_rot[None, :] * k_float, 0.0),
|
||||
axis=1,
|
||||
)
|
||||
* ATTN_SCALE
|
||||
)
|
||||
scores = tl.where(kv_mask, scores, -float("inf"))
|
||||
else:
|
||||
# MSE unpack + norms
|
||||
mse_addrs0 = slot_bases[:, None] + mse_byte_idx[None, :]
|
||||
mse_raw0 = tl.load(
|
||||
KV_cache_ptr + mse_addrs0,
|
||||
mask=kv_mask[:, None] & d_mask[None, :],
|
||||
other=0,
|
||||
).to(tl.int32)
|
||||
mse_raw1 = tl.load(
|
||||
KV_cache_ptr + mse_addrs0 + 1,
|
||||
mask=kv_mask[:, None] & d_mask[None, :],
|
||||
other=0,
|
||||
).to(tl.int32)
|
||||
raw16 = mse_raw0 | (mse_raw1 << 8)
|
||||
mse_idx = (raw16 >> mse_bit_shift[None, :]) & mse_mask
|
||||
|
||||
# Centroid gather + dot product
|
||||
c_vals = tl.load(
|
||||
Centroids_ptr + mse_idx,
|
||||
mask=kv_mask[:, None] & d_mask[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
# Norm correction: re-normalize centroid vector to unit norm
|
||||
if NORM_CORRECTION:
|
||||
c_norm_sq = tl.sum(
|
||||
tl.where(d_mask[None, :], c_vals * c_vals, 0.0),
|
||||
axis=1,
|
||||
)
|
||||
c_inv_norm = 1.0 / tl.sqrt(c_norm_sq + 1e-16)
|
||||
c_vals = c_vals * c_inv_norm[:, None]
|
||||
|
||||
term1 = tl.sum(
|
||||
tl.where(d_mask[None, :], q_rot[None, :] * c_vals, 0.0),
|
||||
axis=1,
|
||||
)
|
||||
|
||||
# Load norms (fp16 -> fp32): norms are at MSE_BYTES offset
|
||||
norm_bases = slot_bases + MSE_BYTES
|
||||
n_lo = tl.load(KV_cache_ptr + norm_bases, mask=kv_mask, other=0).to(
|
||||
tl.uint16
|
||||
)
|
||||
n_hi = tl.load(KV_cache_ptr + norm_bases + 1, mask=kv_mask, other=0).to(
|
||||
tl.uint16
|
||||
)
|
||||
vec_norms = (n_lo | (n_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
|
||||
|
||||
scores = vec_norms * term1 * ATTN_SCALE
|
||||
scores = tl.where(kv_mask, scores, -float("inf"))
|
||||
|
||||
# ============================================================
|
||||
# ONLINE SOFTMAX UPDATE (block-level)
|
||||
# ============================================================
|
||||
n_e_max = tl.maximum(tl.max(scores, 0), m_prev)
|
||||
re_scale = tl.exp(m_prev - n_e_max)
|
||||
p = tl.exp(scores - n_e_max)
|
||||
|
||||
# ============================================================
|
||||
# VALUE LOAD + DEQUANTIZE: [BLOCK_KV, BLOCK_D]
|
||||
# ============================================================
|
||||
val_bases = slot_bases + KPS
|
||||
|
||||
if VQB == 3:
|
||||
val_addrs0 = val_bases[:, None] + val_byte_idx[None, :]
|
||||
val_raw0 = tl.load(
|
||||
KV_cache_ptr + val_addrs0,
|
||||
mask=kv_mask[:, None] & d_mask[None, :],
|
||||
other=0,
|
||||
).to(tl.int32)
|
||||
val_raw1 = tl.load(
|
||||
KV_cache_ptr + val_addrs0 + 1,
|
||||
mask=kv_mask[:, None] & d_mask[None, :],
|
||||
other=0,
|
||||
).to(tl.int32)
|
||||
raw16 = val_raw0 | (val_raw1 << 8)
|
||||
v_idx = ((raw16 >> val_bit_shift[None, :]) & 0x7).to(tl.float32)
|
||||
|
||||
sc_bases = val_bases + VAL_DATA_BYTES
|
||||
sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to(
|
||||
tl.uint16
|
||||
)
|
||||
sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to(
|
||||
tl.uint16
|
||||
)
|
||||
v_scales = (
|
||||
(sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
|
||||
)
|
||||
zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to(
|
||||
tl.uint16
|
||||
)
|
||||
zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to(
|
||||
tl.uint16
|
||||
)
|
||||
v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
|
||||
values = v_idx * v_scales[:, None] + v_zeros[:, None]
|
||||
else: # VQB == 4
|
||||
vb_idx = d_offs // 2
|
||||
vb_shift = (d_offs % 2) * 4
|
||||
val_addrs = val_bases[:, None] + vb_idx[None, :]
|
||||
val_raw = tl.load(
|
||||
KV_cache_ptr + val_addrs,
|
||||
mask=kv_mask[:, None] & d_mask[None, :],
|
||||
other=0,
|
||||
).to(tl.int32)
|
||||
v_idx = ((val_raw >> vb_shift[None, :]) & 0xF).to(tl.float32)
|
||||
|
||||
sc_bases = val_bases + VAL_DATA_BYTES
|
||||
sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to(
|
||||
tl.uint16
|
||||
)
|
||||
sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to(
|
||||
tl.uint16
|
||||
)
|
||||
v_scales = (
|
||||
(sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
|
||||
)
|
||||
zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to(
|
||||
tl.uint16
|
||||
)
|
||||
zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to(
|
||||
tl.uint16
|
||||
)
|
||||
v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
|
||||
values = v_idx * v_scales[:, None] + v_zeros[:, None]
|
||||
|
||||
# ============================================================
|
||||
# WEIGHTED VALUE ACCUMULATION
|
||||
# ============================================================
|
||||
acc = acc * re_scale + tl.sum(p[:, None] * values, 0)
|
||||
l_prev = l_prev * re_scale + tl.sum(p, 0)
|
||||
m_prev = n_e_max
|
||||
|
||||
# Store partial result
|
||||
out_base = bid * stride_mid_b + hid * stride_mid_h + sid * stride_mid_s
|
||||
safe_l = tl.where(l_prev > 0.0, l_prev, 1.0)
|
||||
tl.store(Mid_o_ptr + out_base + d_offs, acc / safe_l, mask=d_mask)
|
||||
lse = m_prev + tl.log(safe_l)
|
||||
tl.store(Mid_o_ptr + out_base + HEAD_DIM, lse)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pre-dequant kernel: Bulk dequant K (MSE+norms) and V to fp16
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _tq_full_dequant_kv(
|
||||
KV_cache_ptr,
|
||||
Block_table_ptr,
|
||||
Centroids_ptr,
|
||||
K_out_ptr, # [B, Hk, max_seq, D] float16
|
||||
V_out_ptr, # [B, Hk, max_seq, D] float16
|
||||
stride_ko_b,
|
||||
stride_ko_h,
|
||||
stride_ko_s,
|
||||
stride_vo_b,
|
||||
stride_vo_h,
|
||||
stride_vo_s,
|
||||
stride_cache_block,
|
||||
stride_cache_pos,
|
||||
stride_cache_head,
|
||||
stride_bt_b,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
NUM_KV_HEADS: tl.constexpr,
|
||||
MSE_BYTES: tl.constexpr,
|
||||
KPS: tl.constexpr,
|
||||
VQB: tl.constexpr,
|
||||
VAL_DATA_BYTES: tl.constexpr,
|
||||
MSE_BITS: tl.constexpr,
|
||||
KEY_FP8: tl.constexpr,
|
||||
BLOCK_D: tl.constexpr,
|
||||
NORM_CORRECTION: tl.constexpr = 0,
|
||||
FP8_E4B15: tl.constexpr = 0, # 1 = use e4b15 (Ampere/Ada), 0 = e4nv (Hopper+)
|
||||
):
|
||||
"""Full dequant: reconstruct K (MSE centroids * norm or FP8) and V to fp16."""
|
||||
pos = tl.program_id(0)
|
||||
bh = tl.program_id(1)
|
||||
bid = bh // NUM_KV_HEADS
|
||||
hid = bh % NUM_KV_HEADS
|
||||
|
||||
page_idx = pos // BLOCK_SIZE
|
||||
page_off = pos % BLOCK_SIZE
|
||||
block_num = tl.load(Block_table_ptr + bid * stride_bt_b + page_idx)
|
||||
slot_base = (
|
||||
block_num * stride_cache_block
|
||||
+ page_off * stride_cache_pos
|
||||
+ hid * stride_cache_head
|
||||
)
|
||||
|
||||
d_offs = tl.arange(0, BLOCK_D)
|
||||
d_mask = d_offs < HEAD_DIM
|
||||
|
||||
# === K dequant ===
|
||||
ko_base = bid * stride_ko_b + hid * stride_ko_h + pos * stride_ko_s
|
||||
if KEY_FP8:
|
||||
k_raw = tl.load(KV_cache_ptr + slot_base + d_offs, mask=d_mask, other=0)
|
||||
if FP8_E4B15:
|
||||
k_recon = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32)
|
||||
else:
|
||||
k_recon = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32)
|
||||
tl.store(K_out_ptr + ko_base + d_offs, k_recon.to(tl.float16), mask=d_mask)
|
||||
else:
|
||||
# MSE unpack (3-bit or 4-bit) + norms
|
||||
mse_bit_off = d_offs * MSE_BITS
|
||||
mse_byte_idx = mse_bit_off // 8
|
||||
mse_bit_shift = mse_bit_off % 8
|
||||
mse_umask = (1 << MSE_BITS) - 1
|
||||
|
||||
mse_raw0 = tl.load(
|
||||
KV_cache_ptr + slot_base + mse_byte_idx, mask=d_mask, other=0
|
||||
).to(tl.int32)
|
||||
mse_raw1 = tl.load(
|
||||
KV_cache_ptr + slot_base + mse_byte_idx + 1, mask=d_mask, other=0
|
||||
).to(tl.int32)
|
||||
raw16_key = mse_raw0 | (mse_raw1 << 8)
|
||||
mse_idx = (raw16_key >> mse_bit_shift) & mse_umask
|
||||
|
||||
k_mse = tl.load(Centroids_ptr + mse_idx, mask=d_mask, other=0.0)
|
||||
|
||||
# Norm correction: re-normalize centroid vector to unit norm
|
||||
if NORM_CORRECTION:
|
||||
c_norm_sq = tl.sum(tl.where(d_mask, k_mse * k_mse, 0.0), axis=0)
|
||||
c_inv_norm = 1.0 / tl.sqrt(c_norm_sq + 1e-16)
|
||||
k_mse = k_mse * c_inv_norm
|
||||
|
||||
# Norms at MSE_BYTES offset (no QJL bytes)
|
||||
norm_base = slot_base + MSE_BYTES
|
||||
n_lo = tl.load(KV_cache_ptr + norm_base).to(tl.uint16)
|
||||
n_hi = tl.load(KV_cache_ptr + norm_base + 1).to(tl.uint16)
|
||||
vec_norm = (n_lo | (n_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
|
||||
|
||||
k_recon = vec_norm * k_mse
|
||||
tl.store(K_out_ptr + ko_base + d_offs, k_recon.to(tl.float16), mask=d_mask)
|
||||
|
||||
# === V dequant ===
|
||||
val_base = slot_base + KPS
|
||||
if VQB == 4:
|
||||
vb_idx = d_offs // 2
|
||||
vb_shift = (d_offs % 2) * 4
|
||||
val_raw = tl.load(KV_cache_ptr + val_base + vb_idx, mask=d_mask, other=0).to(
|
||||
tl.int32
|
||||
)
|
||||
v_idx = ((val_raw >> vb_shift) & 0xF).to(tl.float32)
|
||||
|
||||
sc_base = val_base + VAL_DATA_BYTES
|
||||
sc_lo = tl.load(KV_cache_ptr + sc_base).to(tl.uint16)
|
||||
sc_hi = tl.load(KV_cache_ptr + sc_base + 1).to(tl.uint16)
|
||||
v_scale = (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
|
||||
zr_lo = tl.load(KV_cache_ptr + sc_base + 2).to(tl.uint16)
|
||||
zr_hi = tl.load(KV_cache_ptr + sc_base + 3).to(tl.uint16)
|
||||
v_zero = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
|
||||
v_vals = v_idx * v_scale + v_zero
|
||||
elif VQB == 3:
|
||||
# 3-bit value unpack: 8 values per 3 bytes
|
||||
val_bit_off = d_offs * 3
|
||||
val_byte_idx = val_bit_off // 8
|
||||
val_bit_shift = val_bit_off % 8
|
||||
val_raw0 = tl.load(
|
||||
KV_cache_ptr + val_base + val_byte_idx, mask=d_mask, other=0
|
||||
).to(tl.int32)
|
||||
val_raw1 = tl.load(
|
||||
KV_cache_ptr + val_base + val_byte_idx + 1, mask=d_mask, other=0
|
||||
).to(tl.int32)
|
||||
raw16_val = val_raw0 | (val_raw1 << 8)
|
||||
v_idx = ((raw16_val >> val_bit_shift) & 0x7).to(tl.float32)
|
||||
|
||||
sc_base = val_base + VAL_DATA_BYTES
|
||||
sc_lo = tl.load(KV_cache_ptr + sc_base).to(tl.uint16)
|
||||
sc_hi = tl.load(KV_cache_ptr + sc_base + 1).to(tl.uint16)
|
||||
v_scale = (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
|
||||
zr_lo = tl.load(KV_cache_ptr + sc_base + 2).to(tl.uint16)
|
||||
zr_hi = tl.load(KV_cache_ptr + sc_base + 3).to(tl.uint16)
|
||||
v_zero = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
|
||||
v_vals = v_idx * v_scale + v_zero
|
||||
else:
|
||||
v_vals = tl.zeros([BLOCK_D], dtype=tl.float32)
|
||||
|
||||
vo_base = bid * stride_vo_b + hid * stride_vo_h + pos * stride_vo_s
|
||||
tl.store(V_out_ptr + vo_base + d_offs, v_vals.to(tl.float16), mask=d_mask)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stage 2: Reuse from triton_decode_attention.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Launcher — cached constants + fused GEMM
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_layout_cache: dict = {}
|
||||
|
||||
|
||||
def _get_layout(D, mse_bits, value_quant_bits, key_packed_size):
|
||||
"""Get cached layout constants."""
|
||||
key = (D, mse_bits, value_quant_bits, key_packed_size)
|
||||
cfg = _layout_cache.get(key)
|
||||
if cfg is None:
|
||||
val_data_bytes = math.ceil(D * value_quant_bits / 8)
|
||||
cfg = {
|
||||
"mse_bytes": math.ceil(D * mse_bits / 8),
|
||||
"val_data_bytes": val_data_bytes,
|
||||
"mse_bits": mse_bits,
|
||||
"n_centroids": 2**mse_bits,
|
||||
"BLOCK_D": triton.next_power_of_2(D),
|
||||
}
|
||||
_layout_cache[key] = cfg
|
||||
return cfg
|
||||
|
||||
|
||||
def triton_turboquant_decode_attention(
|
||||
query: torch.Tensor, # [B, Hq, D] — original query
|
||||
kv_cache: torch.Tensor, # [num_blocks, block_size, Hk, padded_slot] uint8
|
||||
block_table: torch.Tensor, # [B, max_num_blocks] int32
|
||||
seq_lens: torch.Tensor, # [B] int32
|
||||
Pi: torch.Tensor, # [D, D] float32
|
||||
centroids: torch.Tensor, # [n_centroids] float32
|
||||
scale: float,
|
||||
mse_bits: int,
|
||||
key_packed_size: int,
|
||||
value_quant_bits: int,
|
||||
key_fp8: bool = False,
|
||||
norm_correction: bool = False,
|
||||
PiT: torch.Tensor | None = None, # [D, D] pre-computed Pi.T contiguous
|
||||
# Pre-allocated buffers (optional, avoids per-call allocation)
|
||||
mid_o_buf: torch.Tensor | None = None,
|
||||
output_buf: torch.Tensor | None = None,
|
||||
lse_buf: torch.Tensor | None = None,
|
||||
buf_holder: Any = None,
|
||||
max_num_kv_splits: int = 32, # fixed split count (must be constant for cudagraph)
|
||||
) -> torch.Tensor:
|
||||
"""Launch fused TQ decode attention (Triton stage1 + stage2).
|
||||
|
||||
Returns: output tensor [B, Hq, D] in query's dtype.
|
||||
"""
|
||||
B, Hq, D = query.shape
|
||||
Hk = kv_cache.shape[2]
|
||||
block_size = kv_cache.shape[1]
|
||||
kv_group_size = Hq // Hk
|
||||
device = query.device
|
||||
|
||||
cfg = _get_layout(D, mse_bits, value_quant_bits, key_packed_size)
|
||||
|
||||
# Compute q_rot = q @ Pi.T (rotated query for MSE key scoring)
|
||||
# FP8 path: pass query directly (float16); kernel casts inline.
|
||||
# MSE path: still needs external GEMM (cuBLAS), so q_rot is float32.
|
||||
if key_fp8:
|
||||
q_rot = query.contiguous()
|
||||
else:
|
||||
q_float = query.float()
|
||||
if PiT is None:
|
||||
PiT = Pi.T.contiguous()
|
||||
q_rot = (q_float @ PiT).contiguous()
|
||||
|
||||
NUM_KV_SPLITS = max_num_kv_splits
|
||||
|
||||
if (
|
||||
mid_o_buf is not None
|
||||
and mid_o_buf.shape[0] >= B
|
||||
and mid_o_buf.shape[2] >= NUM_KV_SPLITS
|
||||
):
|
||||
mid_o = mid_o_buf[:B, :Hq, :NUM_KV_SPLITS, :]
|
||||
else:
|
||||
mid_o = torch.empty(
|
||||
B,
|
||||
Hq,
|
||||
NUM_KV_SPLITS,
|
||||
D + 1,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
if buf_holder is not None:
|
||||
buf_holder._tq_mid_o_buf = mid_o
|
||||
|
||||
# Stage 1: split-KV tiled attention scoring + value accumulation
|
||||
fp8_e4b15 = _use_fp8_e4b15(device.index or 0)
|
||||
BLOCK_KV = 4
|
||||
grid = (B, Hq, NUM_KV_SPLITS)
|
||||
_tq_decode_stage1[grid](
|
||||
q_rot,
|
||||
kv_cache,
|
||||
block_table,
|
||||
seq_lens,
|
||||
centroids,
|
||||
mid_o,
|
||||
q_rot.stride(0),
|
||||
q_rot.stride(1),
|
||||
kv_cache.stride(0),
|
||||
kv_cache.stride(1),
|
||||
kv_cache.stride(2),
|
||||
block_table.stride(0),
|
||||
mid_o.stride(0),
|
||||
mid_o.stride(1),
|
||||
mid_o.stride(2),
|
||||
NUM_KV_HEADS=Hk,
|
||||
HEAD_DIM=D,
|
||||
BLOCK_SIZE=block_size,
|
||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
||||
KV_GROUP_SIZE=kv_group_size,
|
||||
MSE_BITS=mse_bits,
|
||||
MSE_BYTES=cfg["mse_bytes"],
|
||||
KPS=key_packed_size,
|
||||
VQB=value_quant_bits,
|
||||
VAL_DATA_BYTES=cfg["val_data_bytes"],
|
||||
ATTN_SCALE=scale,
|
||||
BLOCK_D=cfg["BLOCK_D"],
|
||||
BLOCK_KV=BLOCK_KV,
|
||||
KEY_FP8=1 if key_fp8 else 0,
|
||||
NORM_CORRECTION=1 if norm_correction else 0,
|
||||
FP8_E4B15=fp8_e4b15,
|
||||
num_warps=1,
|
||||
num_stages=1,
|
||||
)
|
||||
|
||||
# Stage 2: Reduce across KV splits
|
||||
if output_buf is not None and output_buf.shape[0] >= B:
|
||||
output = output_buf[:B, :Hq, :D]
|
||||
else:
|
||||
output = torch.empty(B, Hq, D, dtype=torch.float32, device=device)
|
||||
if buf_holder is not None:
|
||||
buf_holder._tq_output_buf = output
|
||||
if lse_buf is not None and lse_buf.shape[0] >= B:
|
||||
lse = lse_buf[:B, :Hq]
|
||||
else:
|
||||
lse = torch.empty(B, Hq, dtype=torch.float32, device=device)
|
||||
if buf_holder is not None:
|
||||
buf_holder._tq_lse_buf = lse
|
||||
|
||||
grid2 = (B, Hq)
|
||||
_fwd_kernel_stage2[grid2](
|
||||
mid_o,
|
||||
output,
|
||||
lse,
|
||||
seq_lens,
|
||||
mid_o.stride(0),
|
||||
mid_o.stride(1),
|
||||
mid_o.stride(2),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
lse.stride(0),
|
||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
||||
BLOCK_DV=cfg["BLOCK_D"],
|
||||
Lv=D,
|
||||
num_warps=4,
|
||||
num_stages=2,
|
||||
)
|
||||
|
||||
return output.to(query.dtype)
|
||||
@@ -0,0 +1,441 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Fused Triton kernels for TurboQuant KV store.
|
||||
|
||||
Two kernels:
|
||||
1. _tq_fused_store_fp8: FP8 key scatter + value uniform quantization.
|
||||
2. _tq_fused_store_mse: Fused binary-search bucketize + MSE index
|
||||
packing + value quantization.
|
||||
|
||||
The launcher `triton_turboquant_store` selects the appropriate kernel.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.ops.triton_turboquant_decode import _use_fp8_e4b15
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Shared: value uniform quantization + pack + scale/zero store
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _store_quantized_value(
|
||||
Value_ptr,
|
||||
KV_cache_ptr,
|
||||
base, # pid * D offset into Value_ptr
|
||||
slot_base, # byte offset into KV_cache_ptr for this slot+head
|
||||
d_offs, # tl.arange(0, BLOCK_D)
|
||||
d_mask, # d_offs < D
|
||||
D: tl.constexpr,
|
||||
KPS: tl.constexpr,
|
||||
VQB: tl.constexpr,
|
||||
VAL_DATA_BYTES: tl.constexpr,
|
||||
BLOCK_D: tl.constexpr,
|
||||
BLOCK_VAL: tl.constexpr,
|
||||
BLOCK_GRP: tl.constexpr,
|
||||
):
|
||||
"""Uniform quantization of values to VQB bits, pack, and store with scale/zero."""
|
||||
val_cache_offset = KPS
|
||||
|
||||
if VQB == 3:
|
||||
val_vec = tl.load(Value_ptr + base + d_offs, mask=d_mask, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
val_min = tl.min(tl.where(d_mask, val_vec, float("inf")), axis=0)
|
||||
val_max = tl.max(tl.where(d_mask, val_vec, -float("inf")), axis=0)
|
||||
v_scale = (val_max - val_min) / 7.0
|
||||
v_scale = tl.where(v_scale > 1e-8, v_scale, 1e-8)
|
||||
|
||||
q_vals = tl.minimum(
|
||||
tl.maximum(((val_vec - val_min) / v_scale + 0.5).to(tl.int32), 0), 7
|
||||
)
|
||||
|
||||
grp_offs = tl.arange(0, BLOCK_GRP)
|
||||
grp_mask = grp_offs < (D // 8)
|
||||
q_grp = tl.reshape(q_vals, [BLOCK_GRP, 8])
|
||||
shifts_3bit = tl.arange(0, 8) * 3
|
||||
packed_24 = tl.sum(q_grp << shifts_3bit[None, :], axis=1)
|
||||
b0 = (packed_24 & 0xFF).to(tl.uint8)
|
||||
b1 = ((packed_24 >> 8) & 0xFF).to(tl.uint8)
|
||||
b2 = ((packed_24 >> 16) & 0xFF).to(tl.uint8)
|
||||
tl.store(
|
||||
KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3,
|
||||
b0,
|
||||
mask=grp_mask,
|
||||
)
|
||||
tl.store(
|
||||
KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3 + 1,
|
||||
b1,
|
||||
mask=grp_mask,
|
||||
)
|
||||
tl.store(
|
||||
KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3 + 2,
|
||||
b2,
|
||||
mask=grp_mask,
|
||||
)
|
||||
|
||||
sc_offset = val_cache_offset + VAL_DATA_BYTES
|
||||
sc_f16 = v_scale.to(tl.float16)
|
||||
sc_u16 = sc_f16.to(tl.uint16, bitcast=True)
|
||||
tl.store(KV_cache_ptr + slot_base + sc_offset, (sc_u16 & 0xFF).to(tl.uint8))
|
||||
tl.store(
|
||||
KV_cache_ptr + slot_base + sc_offset + 1,
|
||||
((sc_u16 >> 8) & 0xFF).to(tl.uint8),
|
||||
)
|
||||
zr_f16 = val_min.to(tl.float16)
|
||||
zr_u16 = zr_f16.to(tl.uint16, bitcast=True)
|
||||
tl.store(KV_cache_ptr + slot_base + sc_offset + 2, (zr_u16 & 0xFF).to(tl.uint8))
|
||||
tl.store(
|
||||
KV_cache_ptr + slot_base + sc_offset + 3,
|
||||
((zr_u16 >> 8) & 0xFF).to(tl.uint8),
|
||||
)
|
||||
|
||||
else: # VQB == 4
|
||||
val_vec = tl.load(Value_ptr + base + d_offs, mask=d_mask, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
val_min = tl.min(tl.where(d_mask, val_vec, float("inf")), axis=0)
|
||||
val_max = tl.max(tl.where(d_mask, val_vec, -float("inf")), axis=0)
|
||||
v_scale = (val_max - val_min) / 15.0
|
||||
v_scale = tl.where(v_scale > 1e-8, v_scale, 1e-8)
|
||||
|
||||
# Quantize all D elements from register (no re-load)
|
||||
q_all = tl.minimum(
|
||||
tl.maximum(((val_vec - val_min) / v_scale + 0.5).to(tl.int32), 0), 15
|
||||
)
|
||||
# Reshape to pairs and pack two 4-bit values per byte
|
||||
q_pairs = tl.reshape(q_all, [BLOCK_D // 2, 2])
|
||||
shifts_4 = tl.arange(0, 2) * 4
|
||||
packed_val = tl.sum((q_pairs & 0xF) << shifts_4[None, :], axis=1).to(tl.uint8)
|
||||
val_offs = tl.arange(0, BLOCK_D // 2)
|
||||
val_mask = val_offs < VAL_DATA_BYTES
|
||||
tl.store(
|
||||
KV_cache_ptr + slot_base + val_cache_offset + val_offs,
|
||||
packed_val,
|
||||
mask=val_mask,
|
||||
)
|
||||
|
||||
sc_offset = val_cache_offset + VAL_DATA_BYTES
|
||||
sc_f16 = v_scale.to(tl.float16)
|
||||
sc_u16 = sc_f16.to(tl.uint16, bitcast=True)
|
||||
tl.store(KV_cache_ptr + slot_base + sc_offset, (sc_u16 & 0xFF).to(tl.uint8))
|
||||
tl.store(
|
||||
KV_cache_ptr + slot_base + sc_offset + 1,
|
||||
((sc_u16 >> 8) & 0xFF).to(tl.uint8),
|
||||
)
|
||||
zr_f16 = val_min.to(tl.float16)
|
||||
zr_u16 = zr_f16.to(tl.uint16, bitcast=True)
|
||||
tl.store(KV_cache_ptr + slot_base + sc_offset + 2, (zr_u16 & 0xFF).to(tl.uint8))
|
||||
tl.store(
|
||||
KV_cache_ptr + slot_base + sc_offset + 3,
|
||||
((zr_u16 >> 8) & 0xFF).to(tl.uint8),
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# FP8 key store + value uniform quantization
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _tq_fused_store_fp8(
|
||||
Key_ptr, # [NH, D] float16/bfloat16 — raw keys
|
||||
Value_ptr, # [NH, D] float16/bfloat16 — raw values
|
||||
KV_cache_ptr, # [total_bytes] uint8 (flattened view)
|
||||
Slot_mapping_ptr, # [N] int32 — per-token slot indices
|
||||
# Cache strides (for computing byte offsets)
|
||||
stride_cache_block: tl.constexpr,
|
||||
stride_cache_pos: tl.constexpr,
|
||||
stride_cache_head: tl.constexpr,
|
||||
# Dimensions
|
||||
D: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
BLOCK_D: tl.constexpr,
|
||||
# TQ layout
|
||||
KPS: tl.constexpr,
|
||||
# Value quantization
|
||||
VQB: tl.constexpr,
|
||||
VAL_DATA_BYTES: tl.constexpr,
|
||||
# Packing block sizes
|
||||
BLOCK_VAL: tl.constexpr,
|
||||
BLOCK_GRP: tl.constexpr = 16,
|
||||
FP8_E4B15: tl.constexpr = 0, # 1 = e4b15 (Ampere/Ada), 0 = e4nv (Hopper+)
|
||||
):
|
||||
"""FP8 key cast+scatter + value uniform quantization."""
|
||||
pid = tl.program_id(0)
|
||||
token_idx = pid // H
|
||||
head_idx = pid % H
|
||||
|
||||
slot = tl.load(Slot_mapping_ptr + token_idx)
|
||||
if slot < 0:
|
||||
return
|
||||
blk = slot // BLOCK_SIZE
|
||||
off = slot % BLOCK_SIZE
|
||||
slot_base = (
|
||||
blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head
|
||||
)
|
||||
|
||||
base = pid * D
|
||||
|
||||
# ── FP8 KEY: cast to FP8 in-kernel and store ─────────────────
|
||||
d_offs = tl.arange(0, BLOCK_D)
|
||||
d_mask = d_offs < D
|
||||
k_vals = tl.load(Key_ptr + base + d_offs, mask=d_mask, other=0.0)
|
||||
k_fp8 = k_vals.to(tl.float8e4b15) if FP8_E4B15 else k_vals.to(tl.float8e4nv)
|
||||
k_bytes = k_fp8.to(tl.uint8, bitcast=True)
|
||||
tl.store(KV_cache_ptr + slot_base + d_offs, k_bytes, mask=d_mask)
|
||||
|
||||
# ── VALUE QUANTIZE + PACK ───────────────────────────────────────
|
||||
_store_quantized_value(
|
||||
Value_ptr,
|
||||
KV_cache_ptr,
|
||||
base,
|
||||
slot_base,
|
||||
d_offs,
|
||||
d_mask,
|
||||
D=D,
|
||||
KPS=KPS,
|
||||
VQB=VQB,
|
||||
VAL_DATA_BYTES=VAL_DATA_BYTES,
|
||||
BLOCK_D=BLOCK_D,
|
||||
BLOCK_VAL=BLOCK_VAL,
|
||||
BLOCK_GRP=BLOCK_GRP,
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Fused MSE store: bucketize + MSE index pack + norm store + value pack
|
||||
# (eliminates 4 PyTorch kernel launches per layer vs pack-only kernel)
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _tq_fused_store_mse(
|
||||
# Post-rotation inputs
|
||||
Y_ptr, # [NH, D] float32 — rotated normalized keys (x_hat @ PiT)
|
||||
Norms_ptr, # [NH] float32 — key vector norms (||k||)
|
||||
Value_ptr, # [NH, D] float32 — raw values
|
||||
# Quantization tables
|
||||
Midpoints_ptr, # [n_centroids-1] float32
|
||||
# Cache and indexing
|
||||
KV_cache_ptr, # [total_bytes] uint8 (flattened view)
|
||||
Slot_mapping_ptr, # [N] int32 — per-token slot indices
|
||||
# Cache strides
|
||||
stride_cache_block: tl.constexpr,
|
||||
stride_cache_pos: tl.constexpr,
|
||||
stride_cache_head: tl.constexpr,
|
||||
# Dimensions
|
||||
D: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
BLOCK_D: tl.constexpr,
|
||||
# TQ layout
|
||||
MSE_BYTES: tl.constexpr,
|
||||
KPS: tl.constexpr,
|
||||
# Value quantization
|
||||
VQB: tl.constexpr,
|
||||
VAL_DATA_BYTES: tl.constexpr,
|
||||
# Packing block sizes
|
||||
BLOCK_VAL: tl.constexpr,
|
||||
# MSE params
|
||||
MSE_BITS: tl.constexpr,
|
||||
N_CENTROIDS: tl.constexpr,
|
||||
BLOCK_GRP: tl.constexpr = 16,
|
||||
):
|
||||
"""Fused MSE quantize + pack + store.
|
||||
|
||||
Performs binary-search bucketize, MSE index packing, norm storage,
|
||||
and value quantization in one kernel.
|
||||
"""
|
||||
pid = tl.program_id(0)
|
||||
token_idx = pid // H
|
||||
head_idx = pid % H
|
||||
|
||||
slot = tl.load(Slot_mapping_ptr + token_idx)
|
||||
if slot < 0:
|
||||
return
|
||||
blk = slot // BLOCK_SIZE
|
||||
off = slot % BLOCK_SIZE
|
||||
slot_base = (
|
||||
blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head
|
||||
)
|
||||
|
||||
base = pid * D
|
||||
d_offs = tl.arange(0, BLOCK_D)
|
||||
d_mask = d_offs < D
|
||||
|
||||
# ── 1. BINARY SEARCH BUCKETIZE ───────────────────────────────────
|
||||
# Midpoints are sorted (N_CENTROIDS-1 values); binary search finds
|
||||
# insertion point in MSE_BITS iterations vs N_CENTROIDS-1 for linear.
|
||||
y_vec = tl.load(Y_ptr + base + d_offs, mask=d_mask, other=0.0)
|
||||
lo = tl.zeros([BLOCK_D], dtype=tl.int32)
|
||||
hi = tl.full([BLOCK_D], N_CENTROIDS - 1, dtype=tl.int32)
|
||||
for _ in range(MSE_BITS):
|
||||
mid = (lo + hi) >> 1
|
||||
# Clamp to valid midpoint index [0, N_CENTROIDS-2] for load safety;
|
||||
# the search result (lo) is still correct since converged lanes
|
||||
# don't change.
|
||||
safe_mid = tl.minimum(mid, N_CENTROIDS - 2)
|
||||
mid_val = tl.load(Midpoints_ptr + safe_mid, mask=d_mask, other=0.0)
|
||||
lo = tl.where(y_vec >= mid_val, mid + 1, lo)
|
||||
hi = tl.where(y_vec >= mid_val, hi, mid)
|
||||
idx = tl.minimum(lo, N_CENTROIDS - 1)
|
||||
|
||||
# ── 2. PACK MSE INDICES from register idx ─────────────────────────
|
||||
if MSE_BITS == 4:
|
||||
idx_pairs = tl.reshape(idx, [BLOCK_D // 2, 2])
|
||||
shifts_4 = tl.arange(0, 2) * 4
|
||||
packed = tl.sum((idx_pairs & 0xF) << shifts_4[None, :], axis=1).to(tl.uint8)
|
||||
mse_offs = tl.arange(0, BLOCK_D // 2)
|
||||
mse_mask = mse_offs < MSE_BYTES
|
||||
tl.store(KV_cache_ptr + slot_base + mse_offs, packed, mask=mse_mask)
|
||||
|
||||
elif MSE_BITS == 3:
|
||||
grp_offs = tl.arange(0, BLOCK_GRP)
|
||||
grp_mask = grp_offs < (D // 8)
|
||||
idx_grp = tl.reshape(idx, [BLOCK_GRP, 8])
|
||||
shifts_3 = tl.arange(0, 8) * 3
|
||||
packed_24 = tl.sum((idx_grp & 0x7) << shifts_3[None, :], axis=1)
|
||||
b0 = (packed_24 & 0xFF).to(tl.uint8)
|
||||
b1 = ((packed_24 >> 8) & 0xFF).to(tl.uint8)
|
||||
b2 = ((packed_24 >> 16) & 0xFF).to(tl.uint8)
|
||||
tl.store(KV_cache_ptr + slot_base + grp_offs * 3, b0, mask=grp_mask)
|
||||
tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 1, b1, mask=grp_mask)
|
||||
tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 2, b2, mask=grp_mask)
|
||||
|
||||
# ── 3. STORE vec_norm (fp16, 2 bytes) ─────────────────────────────
|
||||
norm_offset = MSE_BYTES
|
||||
|
||||
vn_f16 = tl.load(Norms_ptr + pid).to(tl.float16)
|
||||
vn_u16 = vn_f16.to(tl.uint16, bitcast=True)
|
||||
tl.store(KV_cache_ptr + slot_base + norm_offset, (vn_u16 & 0xFF).to(tl.uint8))
|
||||
tl.store(
|
||||
KV_cache_ptr + slot_base + norm_offset + 1, ((vn_u16 >> 8) & 0xFF).to(tl.uint8)
|
||||
)
|
||||
|
||||
# ── 4. VALUE QUANTIZE + PACK ──────────────────────────────────────
|
||||
_store_quantized_value(
|
||||
Value_ptr,
|
||||
KV_cache_ptr,
|
||||
base,
|
||||
slot_base,
|
||||
d_offs,
|
||||
d_mask,
|
||||
D=D,
|
||||
KPS=KPS,
|
||||
VQB=VQB,
|
||||
VAL_DATA_BYTES=VAL_DATA_BYTES,
|
||||
BLOCK_D=BLOCK_D,
|
||||
BLOCK_VAL=BLOCK_VAL,
|
||||
BLOCK_GRP=BLOCK_GRP,
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Launcher
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def triton_turboquant_store(
|
||||
key: torch.Tensor, # [N, H, D] — raw keys (post-RoPE)
|
||||
value: torch.Tensor, # [N, H, D] — raw values
|
||||
kv_cache: torch.Tensor, # [num_blocks, block_size, Hk, padded_slot] uint8
|
||||
slot_mapping: torch.Tensor, # [N] int32
|
||||
PiT: torch.Tensor, # [D, D] float32
|
||||
midpoints: torch.Tensor, # [n_centroids-1] float32
|
||||
mse_bits: int,
|
||||
key_packed_size: int,
|
||||
value_quant_bits: int,
|
||||
key_fp8: bool = False,
|
||||
):
|
||||
"""Launch TQ store kernel (FP8 or MSE path)."""
|
||||
N, H, D = key.shape
|
||||
NH = N * H
|
||||
block_size = kv_cache.shape[1]
|
||||
BLOCK_D = triton.next_power_of_2(D)
|
||||
mse_bytes = math.ceil(D * mse_bits / 8)
|
||||
n_centroids = 2**mse_bits
|
||||
|
||||
val_data_bytes = math.ceil(D * value_quant_bits / 8)
|
||||
|
||||
BLOCK_VAL = triton.next_power_of_2(val_data_bytes)
|
||||
|
||||
# Cache strides (element_size=1 for uint8, so stride in bytes = stride())
|
||||
stride_block = kv_cache.stride(0)
|
||||
stride_pos = kv_cache.stride(1)
|
||||
stride_head = kv_cache.stride(2)
|
||||
|
||||
block_grp = triton.next_power_of_2(D // 8) if D >= 8 else 1
|
||||
|
||||
# ── FP8 PATH: in-kernel FP8 cast + scatter via fp8 kernel ──
|
||||
if key_fp8:
|
||||
k_flat = key.reshape(NH, D).contiguous()
|
||||
v_flat = value.reshape(NH, D).contiguous()
|
||||
|
||||
fp8_e4b15 = _use_fp8_e4b15(key.device.index or 0)
|
||||
|
||||
grid = (NH,)
|
||||
_tq_fused_store_fp8[grid](
|
||||
k_flat,
|
||||
v_flat,
|
||||
kv_cache.view(-1),
|
||||
slot_mapping,
|
||||
stride_cache_block=stride_block,
|
||||
stride_cache_pos=stride_pos,
|
||||
stride_cache_head=stride_head,
|
||||
D=D,
|
||||
H=H,
|
||||
BLOCK_SIZE=block_size,
|
||||
BLOCK_D=BLOCK_D,
|
||||
KPS=key_packed_size,
|
||||
VQB=value_quant_bits,
|
||||
VAL_DATA_BYTES=val_data_bytes,
|
||||
BLOCK_VAL=BLOCK_VAL,
|
||||
BLOCK_GRP=block_grp,
|
||||
FP8_E4B15=fp8_e4b15,
|
||||
num_warps=4,
|
||||
num_stages=1,
|
||||
)
|
||||
return
|
||||
|
||||
# ── MSE PATH: external GEMM + fused bucketize/pack kernel ──
|
||||
# Normalize + rotation GEMM externally (cuBLAS is faster than in-kernel)
|
||||
k_flat = key.float().reshape(NH, D)
|
||||
norms = k_flat.norm(dim=1, keepdim=True)
|
||||
x_hat = k_flat / (norms + 1e-8)
|
||||
y = x_hat @ PiT
|
||||
|
||||
v_flat = value.float().reshape(NH, D)
|
||||
|
||||
# Fused kernel: bucketize + MSE index pack + norm store + value pack
|
||||
grid = (NH,)
|
||||
_tq_fused_store_mse[grid](
|
||||
y,
|
||||
norms.squeeze(1),
|
||||
v_flat,
|
||||
midpoints,
|
||||
kv_cache.view(-1),
|
||||
slot_mapping,
|
||||
stride_cache_block=stride_block,
|
||||
stride_cache_pos=stride_pos,
|
||||
stride_cache_head=stride_head,
|
||||
D=D,
|
||||
H=H,
|
||||
BLOCK_SIZE=block_size,
|
||||
BLOCK_D=BLOCK_D,
|
||||
MSE_BYTES=mse_bytes,
|
||||
KPS=key_packed_size,
|
||||
VQB=value_quant_bits,
|
||||
VAL_DATA_BYTES=val_data_bytes,
|
||||
BLOCK_VAL=BLOCK_VAL,
|
||||
MSE_BITS=mse_bits,
|
||||
N_CENTROIDS=n_centroids,
|
||||
BLOCK_GRP=block_grp,
|
||||
num_warps=4,
|
||||
num_stages=1,
|
||||
)
|
||||
@@ -21,6 +21,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
MLAAttentionSpec,
|
||||
SinkFullAttentionSpec,
|
||||
SlidingWindowSpec,
|
||||
TQFullAttentionSpec,
|
||||
)
|
||||
from vllm.v1.request import Request
|
||||
|
||||
@@ -209,7 +210,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
|
||||
)
|
||||
req_blocks.extend(allocated_blocks)
|
||||
if type(self.kv_cache_spec) is FullAttentionSpec:
|
||||
if type(self.kv_cache_spec) in (FullAttentionSpec, TQFullAttentionSpec):
|
||||
self.new_block_ids.extend(b.block_id for b in allocated_blocks)
|
||||
|
||||
def allocate_new_blocks(
|
||||
@@ -237,7 +238,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
else:
|
||||
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
|
||||
req_blocks.extend(new_blocks)
|
||||
if type(self.kv_cache_spec) is FullAttentionSpec:
|
||||
if type(self.kv_cache_spec) in (FullAttentionSpec, TQFullAttentionSpec):
|
||||
self.new_block_ids.extend(b.block_id for b in new_blocks)
|
||||
return new_blocks
|
||||
|
||||
@@ -1114,6 +1115,7 @@ class SinkFullAttentionManager(FullAttentionManager):
|
||||
|
||||
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
||||
FullAttentionSpec: FullAttentionManager,
|
||||
TQFullAttentionSpec: FullAttentionManager,
|
||||
MLAAttentionSpec: FullAttentionManager,
|
||||
SlidingWindowSpec: SlidingWindowManager,
|
||||
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
|
||||
|
||||
@@ -245,6 +245,32 @@ class FullAttentionSpec(AttentionSpec):
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TQFullAttentionSpec(FullAttentionSpec):
|
||||
"""FullAttentionSpec with TQ-aware page size.
|
||||
|
||||
Python equivalent of the C++ TQ4FullAttentionSpec. Overrides
|
||||
real_page_size_bytes to use TQ slot bytes instead of the raw
|
||||
head_size * dtype formula.
|
||||
"""
|
||||
|
||||
tq_slot_size: int = 0
|
||||
|
||||
@property
|
||||
def real_page_size_bytes(self) -> int:
|
||||
if self.tq_slot_size > 0:
|
||||
return self.block_size * self.num_kv_heads * self.tq_slot_size
|
||||
return super().real_page_size_bytes
|
||||
|
||||
@classmethod
|
||||
def merge(cls, specs: list[Self]) -> Self:
|
||||
merged = super().merge(specs)
|
||||
assert all(s.tq_slot_size == specs[0].tq_slot_size for s in specs), (
|
||||
"All TQ layers in the same KV cache group must use the same tq_slot_size."
|
||||
)
|
||||
return replace(merged, tq_slot_size=specs[0].tq_slot_size)
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MLAAttentionSpec(FullAttentionSpec):
|
||||
# TODO(Lucas/Chen): less hacky way to do this
|
||||
|
||||
@@ -120,7 +120,7 @@ class KVBlockZeroer:
|
||||
|
||||
for group in attn_groups_iter:
|
||||
spec = group.kv_cache_spec
|
||||
if type(spec) is not FullAttentionSpec:
|
||||
if not isinstance(spec, FullAttentionSpec):
|
||||
continue
|
||||
if group.kv_cache_group_id >= len(kernel_block_sizes):
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user