mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
4d51588e23
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai> Signed-off-by: Woosuk Kwon <woosuk@inferact.ai> Signed-off-by: qizixi <zixi@inferact.ai> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Yongye Zhu <yongye@inferact.ai> Co-authored-by: Simon Mo <simon@inferact.ai> Co-authored-by: Bugen Zhao <i@bugenzhao.com> Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Roy Wang <yasong.wang@inferact.ai> Co-authored-by: Woosuk Kwon <woosuk@inferact.ai> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Zhewen Li <jerven.vllm@gmail.com> Co-authored-by: Zijing Liu <liuzijing2014@gmail.com> Co-authored-by: khluu <khluu000@gmail.com> Co-authored-by: qizixi <zixi@inferact.ai> Co-authored-by: Zhewen Li <zhewenli@inferact.ai>
312 lines
12 KiB
Python
312 lines
12 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""
|
|
Round-trip tests for compressor → FP8 quant + KV cache insert → gather + dequant.
|
|
|
|
Two paths tested:
|
|
A) DeepseekV4 Attention: head_dim=512 (448 FP8 nope + 64 bf16 rope), quant_block=64
|
|
B) Indexer: head_dim=128 (all FP8), quant_block=128
|
|
|
|
These serve as golden references for validating the future fused
|
|
compressor+quant+cache kernel.
|
|
"""
|
|
|
|
import math
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from vllm import _custom_ops as ops
|
|
from vllm.v1.attention.ops.deepseek_v4_ops import (
|
|
dequantize_and_gather_k_cache,
|
|
quantize_and_insert_k_cache,
|
|
)
|
|
|
|
|
|
def _ue8m0_reference(x: torch.Tensor, block_size: int, fp8_max: float):
|
|
"""PyTorch reference for UE8M0 FP8 quantization (per-block, power-of-2 scale).
|
|
|
|
Returns (x_fp8, scales) where x_fp8 is float8_e4m3fn and scales are float32.
|
|
"""
|
|
assert x.dim() == 1
|
|
n = x.numel()
|
|
n_blocks = math.ceil(n / block_size)
|
|
x_fp8 = torch.zeros(n, dtype=torch.float8_e4m3fn, device=x.device)
|
|
scales = torch.zeros(n_blocks, dtype=torch.float32, device=x.device)
|
|
|
|
for i in range(n_blocks):
|
|
start = i * block_size
|
|
end = min(start + block_size, n)
|
|
block = x[start:end].float()
|
|
amax = block.abs().max().clamp(min=1e-4)
|
|
raw_scale = amax / fp8_max
|
|
exponent = math.ceil(math.log2(raw_scale.item()))
|
|
scale = 2.0**exponent
|
|
scales[i] = scale
|
|
quantized = (block / scale).clamp(-fp8_max, fp8_max)
|
|
x_fp8[start:end] = quantized.to(torch.float8_e4m3fn)
|
|
|
|
return x_fp8, scales
|
|
|
|
|
|
# ── Test A: DeepseekV4 Attention path ──────────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 17])
@pytest.mark.parametrize("block_size", [16, 64])
def test_deepseek_v4_attention_quant_cache_roundtrip(num_tokens: int, block_size: int):
    """Round-trip compressed_kv through the DeepseekV4 attention K cache.

    Writes random bf16 input with quantize_and_insert_k_cache, reads it back
    with dequantize_and_gather_k_cache, and compares against the original:
    the first 448 dims within an FP8 error bound, the last 64 dims exactly.
    """
    head_dim = 512
    nope_dim = 448
    head_bytes = 584  # 448 fp8 + 128 bf16 + 8 uint8 scale
    fp8_max = 448.0
    quant_block = 64
    device = "cuda"

    # Blocks needed to hold every token, plus one extra.
    num_blocks = (num_tokens + block_size - 1) // block_size + 1

    # Random input standing in for the compressor output.
    compressed_kv = torch.randn(
        num_tokens, head_dim, dtype=torch.bfloat16, device=device
    )

    # -- Quantize and insert ------------------------------------------------
    k_cache = torch.zeros(
        num_blocks, block_size, head_bytes, dtype=torch.uint8, device=device
    )
    # Token i is written to slot i.
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
    quantize_and_insert_k_cache(
        compressed_kv,
        k_cache.view(num_blocks, -1),
        slot_mapping,
        block_size=block_size,
    )

    # -- Gather and dequantize ----------------------------------------------
    # Single request whose block table lists physical blocks 0, 1, ...
    block_table = torch.arange(num_blocks, dtype=torch.int32, device=device)
    block_table = block_table.unsqueeze(0)
    seq_lens = torch.tensor([num_tokens], dtype=torch.int32, device=device)
    out = torch.zeros(1, num_tokens, head_dim, dtype=torch.bfloat16, device=device)
    dequantize_and_gather_k_cache(
        out, k_cache, seq_lens, None, block_table, block_size, offset=0
    )
    recovered = out[0, :num_tokens]

    # -- NoPE portion (first 448 dims): FP8 quantized -----------------------
    nope_diff = (
        recovered[:, :nope_dim].float() - compressed_kv[:, :nope_dim].float()
    ).abs()

    # FP8 e4m3 has a 3-bit mantissa; near the maximum value 448 the ULP is
    # 2^(8-3) = 32, so the half-ULP rounding error is at most 16 * scale.
    for t in range(num_tokens):
        _, scales = _ue8m0_reference(
            compressed_kv[t, :nope_dim].float(), quant_block, fp8_max
        )
        max_allowed = 16.0 * scales.max().item()
        token_diff = nope_diff[t].max().item()
        assert token_diff <= max_allowed, (
            f"Token {t} nope diff {token_diff} exceeds max_allowed "
            f"{max_allowed} (scale={scales.max().item()})"
        )

    # -- RoPE portion (last 64 dims): kept as bf16, must match exactly ------
    rope_diff = (recovered[:, nope_dim:] - compressed_kv[:, nope_dim:]).abs()
    assert rope_diff.max().item() == 0.0, (
        f"RoPE portion should be exact but got max diff {rope_diff.max().item()}"
    )
|
|
|
|
|
|
# ── Test B: Indexer path ────────────────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 17])
@pytest.mark.parametrize("block_size", [16, 64])
def test_indexer_quant_cache_roundtrip(num_tokens: int, block_size: int):
    """Round-trip K through the indexer quantized cache.

    Writes random bf16 K with ops.indexer_k_quant_and_cache, reads the raw
    bytes back with ops.cp_gather_indexer_k_quant_cache, dequantizes them by
    hand, and checks every token against an FP8 error bound.
    """
    head_dim = 128
    quant_block_size = 128
    # Scale bytes stored per token: head_dim * 4 // quant_block_size = 4.
    num_scale_bytes = head_dim * 4 // quant_block_size
    # cache_stride = head_dim + scale bytes = 128 + 4 = 132
    cache_stride = head_dim + num_scale_bytes
    device = "cuda"

    # Blocks needed to hold every token, plus one extra.
    num_blocks = (num_tokens + block_size - 1) // block_size + 1

    # Random K standing in for the compressor output fed to the indexer.
    k = torch.randn(num_tokens, head_dim, dtype=torch.bfloat16, device=device)

    # -- Quantize and insert ------------------------------------------------
    kv_cache = torch.zeros(
        num_blocks, block_size, cache_stride, dtype=torch.uint8, device=device
    )
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
    ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping, quant_block_size, "ue8m0")

    # -- Gather -------------------------------------------------------------
    block_table = torch.arange(num_blocks, dtype=torch.int32, device=device)
    block_table = block_table.unsqueeze(0)
    cu_seq_lens = torch.tensor([0, num_tokens], dtype=torch.int32, device=device)

    # Destination buffers are raw bytes: FP8 payload and float32 scale bytes.
    dst_k = torch.zeros(num_tokens, head_dim, dtype=torch.uint8, device=device)
    dst_scale = torch.zeros(
        num_tokens, num_scale_bytes, dtype=torch.uint8, device=device
    )
    ops.cp_gather_indexer_k_quant_cache(
        kv_cache, dst_k, dst_scale, block_table, cu_seq_lens
    )

    # -- Manual dequant -----------------------------------------------------
    k_fp8 = dst_k.view(torch.float8_e4m3fn).float()  # [num_tokens, 128]
    scale = dst_scale.view(torch.float32)  # [num_tokens, 1]
    k_recovered = k_fp8 * scale  # [num_tokens, 128]

    # -- Compare ------------------------------------------------------------
    k_f32 = k.float()
    diff = (k_recovered - k_f32).abs()
    k_abs = k_f32.abs()

    for t in range(num_tokens):
        amax = k_abs[t].max().clamp(min=1e-4).item()
        # UE8M0: scale = 2^ceil(log2(amax / 448))
        ue8m0_scale = 2.0 ** math.ceil(math.log2(amax / 448.0))
        # FP8 e4m3 (3-bit mantissa): worst-case error = 16 * scale
        max_allowed = 16.0 * ue8m0_scale
        token_diff = diff[t].max().item()
        assert token_diff <= max_allowed, (
            f"Token {t} diff {token_diff} exceeds max_allowed "
            f"{max_allowed} (scale={ue8m0_scale})"
        )
|
|
|
|
|
|
def test_indexer_gather_accepts_upper_bound_output():
    """Gather only exact cu_seq_lens even when dst is over-allocated.

    The destination buffers hold more rows than cu_seq_lens covers and are
    pre-filled with a sentinel byte; rows past the valid tokens must still
    hold the sentinel after the gather.
    """
    head_dim = 128
    quant_block_size = 128
    num_scale_bytes = head_dim * 4 // quant_block_size
    cache_stride = head_dim + num_scale_bytes
    valid_tokens = 9
    upper_bound_tokens = 13
    block_size = 16
    num_blocks = 2
    sentinel = 123
    device = "cuda"

    # Fill the cache with the valid tokens, token i going to slot i.
    k = torch.randn(valid_tokens, head_dim, dtype=torch.bfloat16, device=device)
    kv_cache = torch.zeros(
        num_blocks, block_size, cache_stride, dtype=torch.uint8, device=device
    )
    slot_mapping = torch.arange(valid_tokens, dtype=torch.int64, device=device)
    ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping, quant_block_size, "ue8m0")

    block_table = torch.arange(num_blocks, dtype=torch.int32, device=device)
    block_table = block_table.unsqueeze(0)
    cu_seq_lens = torch.tensor([0, valid_tokens], dtype=torch.int32, device=device)

    # Over-allocated outputs, pre-filled so untouched rows can be detected.
    dst_k = torch.full(
        (upper_bound_tokens, head_dim), sentinel, dtype=torch.uint8, device=device
    )
    dst_scale = torch.full(
        (upper_bound_tokens, num_scale_bytes),
        sentinel,
        dtype=torch.uint8,
        device=device,
    )

    ops.cp_gather_indexer_k_quant_cache(
        kv_cache, dst_k, dst_scale, block_table, cu_seq_lens
    )
    torch.accelerator.synchronize()

    # Dequantize only the rows covered by cu_seq_lens.
    valid_fp8 = dst_k[:valid_tokens].view(torch.float8_e4m3fn).float()
    valid_scale = dst_scale[:valid_tokens].view(torch.float32)
    k_recovered = valid_fp8 * valid_scale
    diff = (k_recovered - k.float()).abs()
    max_allowed = (16.0 * valid_scale.max()).item()
    assert diff.max().item() <= max_allowed

    # Rows beyond the valid tokens must be untouched.
    assert torch.all(dst_k[valid_tokens:] == sentinel)
    assert torch.all(dst_scale[valid_tokens:] == sentinel)
|
|
|
|
|
|
# ── Test C: DeepseekV4 attention with values at different magnitudes ───────────
|
|
|
|
|
|
def test_deepseek_v4_quant_magnitude_range():
    """Test that quantization handles a range of magnitudes correctly.

    Four tokens are filled with 0.001, 1.0, 100.0 and random values, pushed
    through quantize_and_insert_k_cache / dequantize_and_gather_k_cache, and
    compared with the input. The relative-error check is applied only to
    tokens whose peak NoPE magnitude exceeds 0.01.
    """
    head_dim = 512
    nope_dim = 448
    head_bytes = 584
    block_size = 16
    num_tokens = 4
    num_blocks = 2
    device = "cuda"

    # One token per magnitude regime: small, medium, large, then random.
    compressed_kv = torch.zeros(
        num_tokens, head_dim, dtype=torch.bfloat16, device=device
    )
    compressed_kv[0] = 0.001  # very small
    compressed_kv[1] = 1.0  # unit scale
    compressed_kv[2] = 100.0  # large
    compressed_kv[3] = torch.randn(head_dim, dtype=torch.bfloat16, device=device)

    k_cache = torch.zeros(
        num_blocks, block_size, head_bytes, dtype=torch.uint8, device=device
    )
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
    quantize_and_insert_k_cache(
        compressed_kv, k_cache.view(num_blocks, -1), slot_mapping, block_size
    )

    block_table = torch.arange(num_blocks, dtype=torch.int32, device=device)
    block_table = block_table.unsqueeze(0)
    seq_lens = torch.tensor([num_tokens], dtype=torch.int32, device=device)
    out = torch.zeros(1, num_tokens, head_dim, dtype=torch.bfloat16, device=device)
    dequantize_and_gather_k_cache(
        out, k_cache, seq_lens, None, block_table, block_size, offset=0
    )
    recovered = out[0, :num_tokens]

    # RoPE portion must be exact
    rope_diff = (recovered[:, nope_dim:] - compressed_kv[:, nope_dim:]).abs().max()
    assert rope_diff.item() == 0.0, f"RoPE diff {rope_diff.item()}"

    # NoPE: relative error should be reasonable
    for t in range(num_tokens):
        orig = compressed_kv[t, :nope_dim].float()
        recv = recovered[t, :nope_dim].float()
        abs_diff = (recv - orig).abs().max().item()
        magnitude = orig.abs().max().item()
        # Skip tokens whose peak magnitude does not exceed 0.01.
        if not magnitude > 0.01:
            continue
        rel_err = abs_diff / magnitude
        assert rel_err < 0.15, (
            f"Token {t}: rel_err={rel_err:.4f}, abs_diff={abs_diff:.6f}, "
            f"magnitude={magnitude:.4f}"
        )
|