Add nvfp4 kv cache support (#40177)

Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
This commit is contained in:
sychen52
2026-04-30 21:55:16 -07:00
committed by GitHub
parent 941fb50835
commit 947138b6c2
8 changed files with 503 additions and 96 deletions
+1 -1
View File
@@ -169,7 +169,7 @@ Priority is **1 = highest** (tried first).
| ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ |
| `CPU_ATTN` | | fp16, bf16, fp32 | `auto`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256, 512 | ❌ | ❌ | ❌ | All | N/A |
| `FLASHINFER` | Native† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x |
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2`, `nvfp4` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ✅ | ❌ | ✅ | All | ≥10.0 |
@@ -5,12 +5,17 @@ import pytest
import torch
from tests.kernels.quantization.nvfp4_utils import (
dequant_nvfp4_kv_cache,
dequantize_nvfp4_to_dtype,
get_nvfp4_global_scale,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import set_random_seed
from vllm.utils.torch_utils import (
nvfp4_kv_cache_full_dim,
nvfp4_kv_cache_split_views,
set_random_seed,
)
if not current_platform.is_device_capability_family(100):
pytest.skip(
@@ -33,6 +38,117 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
return x_scl_sat.to(dtype), scale.float().reciprocal()
def build_paged_kv_metadata(
seq_lens: torch.Tensor,
block_tables: torch.Tensor,
block_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Build paged-KV indptr/indices/last_page_lens from seq_lens + block_tables."""
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(len(seq_lens)):
sl = int(seq_lens[i])
assert sl > 0
nb = (sl + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :nb].tolist())
kv_indptr.append(kv_indptr[-1] + nb)
kv_last_page_lens.append(sl % block_size or block_size)
return (
torch.tensor(kv_indptr, dtype=torch.int32),
torch.tensor(kv_indices, dtype=torch.int32),
torch.tensor(kv_last_page_lens, dtype=torch.int32),
)
def make_nvfp4_kv_cache(
kv_bf16_hnd: torch.Tensor, block_size: int, head_size: int
) -> tuple:
"""Quantize bf16 KV cache to nvfp4 via reshape_and_cache_flash.
Returns (k_data, v_data), (k_scales, v_scales), kv_scale, ref_kv_bf16.
"""
num_blocks, _, num_kv_heads, _, _ = kv_bf16_hnd.shape
kv_scale_val = (kv_bf16_hnd.abs().amax() / 448.0).item()
kv_scale_tensor = torch.tensor(
kv_scale_val, dtype=torch.float32, device=kv_bf16_hnd.device
)
# Allocate in HND physical order, permute to NHD logical order.
# hnd_order swaps dims 2↔3; it is its own inverse.
full_dim = nvfp4_kv_cache_full_dim(head_size)
hnd_order = (0, 1, 3, 2, 4)
kv_cache = torch.zeros(
(num_blocks, 2, num_kv_heads, block_size, full_dim),
dtype=torch.uint8,
device=kv_bf16_hnd.device,
).permute(*hnd_order)
# Flatten NHD [N, T, H, D] → token tensors [N*T, H, D] for the kernel.
num_tokens = num_blocks * block_size
k_tokens = (
kv_bf16_hnd[:, 0]
.permute(0, 2, 1, 3)
.reshape(num_tokens, num_kv_heads, head_size)
)
v_tokens = (
kv_bf16_hnd[:, 1]
.permute(0, 2, 1, 3)
.reshape(num_tokens, num_kv_heads, head_size)
)
slot_mapping = torch.arange(num_tokens, dtype=torch.long, device=kv_bf16_hnd.device)
# reshape_and_cache_flash: kernel receives kv_cache[:, 0] and [:, 1]
# (full K/V buffers containing both data and scale).
torch.ops._C_cache_ops.reshape_and_cache_flash(
k_tokens,
v_tokens,
kv_cache[:, 0],
kv_cache[:, 1],
slot_mapping,
"nvfp4",
kv_scale_tensor,
kv_scale_tensor,
)
# Split in HND order for trtllm kernel (expects HND numTokensPerPage).
kv_cache_hnd = kv_cache.permute(*hnd_order)
(k_data, v_data), (k_scales, v_scales) = nvfp4_kv_cache_split_views(kv_cache_hnd)
# Dequantize for the FA2 reference baseline.
ref_k = dequant_nvfp4_kv_cache(
k_data, k_scales, kv_scale_val, head_size, block_size
).to(torch.bfloat16)
ref_v = dequant_nvfp4_kv_cache(
v_data, v_scales, kv_scale_val, head_size, block_size
).to(torch.bfloat16)
ref_kv_bf16 = torch.stack([ref_k, ref_v], dim=1) # [N, 2, H, T, D]
return (k_data, v_data), (k_scales, v_scales), kv_scale_val, ref_kv_bf16
def make_quantized_kv_cache(
kv_cache: torch.Tensor,
kv_quant_dtype: torch.dtype,
block_size: int,
head_size: int,
) -> tuple:
"""Quantize kv_cache based on dtype. Returns (kv_cache, kv_cache_sf,
kv_scale, ref_kv_cache, is_nvfp4_kv)."""
is_nvfp4_kv = kv_quant_dtype == FP4_DTYPE
if is_nvfp4_kv:
data, scales, kv_scale, ref = make_nvfp4_kv_cache(
kv_cache, block_size, head_size
)
return data, scales, kv_scale, ref, True
elif kv_quant_dtype == FP8_DTYPE:
kv_fp8, kv_scale = to_float8(kv_cache)
ref = kv_fp8.to(kv_cache.dtype) * kv_scale
return kv_fp8, None, kv_scale, ref, False
else:
return kv_cache, None, 1.0, kv_cache, False
DTYPE = [torch.bfloat16]
QUANT_DTYPES = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
@@ -41,6 +157,7 @@ QUANT_DTYPES = [
(FP8_DTYPE, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
(FP8_DTYPE, FP4_DTYPE, FP8_DTYPE), # nvfp4 KV cache
]
BATCH_SIZE = [4, 12]
MAX_SEQ_LENS = [(1024, 4096)]
@@ -127,35 +244,19 @@ def test_flashinfer_trtllm_decode_with_baseline(
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
kv_cache, kv_cache_sf, kv_scale, ref_kv_cache, is_nvfp4_kv = (
make_quantized_kv_cache(kv_cache, kv_quant_dtype, block_size, head_size)
)
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
)
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(batch_size):
seq_len = seq_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
kv_indptr, kv_indices, kv_last_page_lens = build_paged_kv_metadata(
seq_lens, block_tables, block_size
)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
# Baseline Decode
@@ -225,6 +326,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
sinks=sinks,
o_sf_scale=o_sf_scale_float,
out=output_trtllm,
kv_cache_sf=kv_cache_sf,
)
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
@@ -237,7 +339,9 @@ def test_flashinfer_trtllm_decode_with_baseline(
)
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
if is_nvfp4_kv:
rtol, atol = 1.0, 1.0 # nvfp4 has higher quantization error
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 7e-2, 9e-2
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 3e-2, 4e-2
@@ -287,7 +391,12 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype
if q_quant_dtype != kv_quant_dtype:
# FP8 Q + nvfp4 KV is the required combination for the nvfp4 KV path.
# All other mixed Q/KV dtype combinations are unsupported.
is_nvfp4_kv = kv_quant_dtype == FP4_DTYPE
if q_quant_dtype != kv_quant_dtype and not (
q_quant_dtype == FP8_DTYPE and is_nvfp4_kv
):
pytest.skip("Skipped mixed QKV dtypes for prefill")
max_q_len, max_kv_len = max_seq_lens
@@ -329,35 +438,19 @@ def test_flashinfer_trtllm_prefill_with_baseline(
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
kv_cache, kv_cache_sf, kv_scale, ref_kv_cache, is_nvfp4_kv = (
make_quantized_kv_cache(kv_cache, kv_quant_dtype, block_size, head_size)
)
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
)
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(batch_size):
seq_len = seq_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
kv_indptr, kv_indices, kv_last_page_lens = build_paged_kv_metadata(
seq_lens, block_tables, block_size
)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
# Baseline Prefill
@@ -431,6 +524,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
sinks=sinks,
o_sf_scale=o_sf_scale_float,
out=output_trtllm,
kv_cache_sf=kv_cache_sf,
)
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
@@ -443,7 +537,9 @@ def test_flashinfer_trtllm_prefill_with_baseline(
)
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
if is_nvfp4_kv:
rtol, atol = 1.0, 1.5 # nvfp4 has higher quantization error
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 3e-1, 4e-1
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 4e-2, 6e-2
@@ -17,13 +17,13 @@ from tests.v1.attention.utils import (
from vllm.config import set_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import set_random_seed
from vllm.utils.torch_utils import nvfp4_kv_cache_full_dim, set_random_seed
from vllm.v1.attention.backends.utils import (
PerLayerParameters,
get_kv_cache_layout,
set_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVQuantMode
if not current_platform.is_device_capability_family(100):
pytest.skip(
@@ -53,6 +53,7 @@ class MockAttentionLayer:
MODEL = "Qwen/Qwen2.5-0.5B"
MODEL_NVFP4 = "Qwen/Qwen3-4B" # nvfp4 needs head_dim >= 128 (or 80)
BLOCK_SIZE = 16
NUM_GPU_BLOCKS = 8192
DEVICE_TYPE = current_platform.device_type
@@ -169,19 +170,129 @@ def _create_hnd_kv_cache(
return kv_cache
def _run_trtllm_integration(batch_spec):
def _create_nvfp4_hnd_kv_cache(
k_contexts,
v_contexts,
block_size,
num_kv_heads,
head_size,
dtype,
device,
num_blocks,
common_attn_metadata,
kv_scale_val,
):
"""Create an nvfp4 KV cache by quantizing bf16 context via
reshape_and_cache_flash, using the same block-table layout as
_create_hnd_kv_cache.
The returned tensor is dtype ``uint8`` with shape
``(num_blocks, 2, block_size, num_kv_heads, full_dim)`` in logical
(NHD) order, but physically permuted to HND layout via stride order
``(0, 1, 3, 2, 4)`` (i.e. ``num_kv_heads`` before ``block_size``).
The last dimension ``full_dim = head_size // 2 + head_size // 16``
packs two regions contiguously:
- **FP4 data** (``head_size // 2`` bytes): pairs of E2M1 values,
two per byte.
- **FP8 block scales** (``head_size // 16`` bytes): one E4M3
scale per 16-element block.
Dimension 1 indexes K (``[:, 0]``) and V (``[:, 1]``).
Args:
k_contexts: List of key context tensors, one per sequence.
v_contexts: List of value context tensors, one per sequence.
block_size: Number of tokens per cache block.
num_kv_heads: Number of key/value heads.
head_size: Head dimension (must be divisible by 16).
dtype: Source data type for the bf16 intermediate cache.
device: Target device.
num_blocks: Total number of blocks to allocate.
common_attn_metadata: Metadata containing block tables and
sequence lengths.
kv_scale_val: Scalar float used as both k_scale and v_scale
during quantization.
Returns:
``torch.Tensor``: The nvfp4 kv_cache tensor (uint8, HND-strided).
"""
# First create a bf16 HND cache so block tables are populated.
bf16_cache = _create_hnd_kv_cache(
k_contexts,
v_contexts,
block_size,
num_kv_heads,
head_size,
dtype,
device,
num_blocks,
common_attn_metadata,
)
# Allocate nvfp4 cache: same shape but with full_dim (data + scale).
full_dim = nvfp4_kv_cache_full_dim(head_size)
hnd_order = (0, 1, 3, 2, 4)
nvfp4_cache = torch.zeros(
(num_blocks, 2, num_kv_heads, block_size, full_dim),
dtype=torch.uint8,
device=device,
).permute(*hnd_order)
# Flatten bf16 context into tokens and quantize via reshape_and_cache_flash.
# bf16_cache is (num_blocks, 2, block_size, num_kv_heads, head_size) logical
# with HND physical strides.
block_table = common_attn_metadata.block_table_tensor
seq_lens = common_attn_metadata.seq_lens.cpu()
query_lens = (
common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1]
)
kv_scale_t = torch.tensor(kv_scale_val, dtype=torch.float32, device=device)
for i in range(len(k_contexts)):
ctx_len = int(seq_lens[i]) - int(query_lens[i])
if ctx_len == 0:
continue
# Gather context tokens from the bf16 cache using block table.
n_ctx_blocks = (ctx_len + block_size - 1) // block_size
blocks = block_table[i, :n_ctx_blocks]
# bf16_cache[:, kv_idx] is (num_blocks, block_size, num_kv_heads, head_size)
k_ctx = bf16_cache[blocks, 0].reshape(-1, num_kv_heads, head_size)[:ctx_len]
v_ctx = bf16_cache[blocks, 1].reshape(-1, num_kv_heads, head_size)[:ctx_len]
# Build slot mapping for these context tokens.
token_offsets = torch.arange(ctx_len, device=device)
block_indices = token_offsets // block_size
intra_offsets = token_offsets % block_size
slots = block_table[i, block_indices] * block_size + intra_offsets
torch.ops._C_cache_ops.reshape_and_cache_flash(
k_ctx,
v_ctx,
nvfp4_cache[:, 0],
nvfp4_cache[:, 1],
slots,
"nvfp4",
kv_scale_t,
kv_scale_t,
)
return nvfp4_cache
def _run_trtllm_integration(batch_spec, kv_cache_dtype="auto", model_name=MODEL):
"""Run TRTLLM attention through the full FlashInfer pipeline
and compare against an SDPA reference."""
set_random_seed(42)
device = torch.device(f"{DEVICE_TYPE}:0")
vllm_config = create_vllm_config(
model_name=MODEL,
model_name=model_name,
max_model_len=max(batch_spec.seq_lens),
block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
)
vllm_config.attention_config.use_trtllm_attention = True
vllm_config.cache_config.cache_dtype = kv_cache_dtype
num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config
@@ -248,28 +359,51 @@ def _run_trtllm_integration(batch_spec):
common_attn_metadata = create_common_attn_metadata(batch_spec, BLOCK_SIZE, device)
# 2. Create HND KV cache
kv_cache = _create_hnd_kv_cache(
k_contexts,
v_contexts,
BLOCK_SIZE,
num_kv_heads,
head_size,
dtype,
device,
NUM_GPU_BLOCKS,
common_attn_metadata,
)
is_nvfp4 = kv_cache_dtype == "nvfp4"
if is_nvfp4:
# Compute a global scale from the context data.
all_ctx = torch.cat(k_contexts + v_contexts, dim=0)
kv_scale_val = (all_ctx.abs().amax() / 448.0).item()
kv_cache = _create_nvfp4_hnd_kv_cache(
k_contexts,
v_contexts,
BLOCK_SIZE,
num_kv_heads,
head_size,
dtype,
device,
NUM_GPU_BLOCKS,
common_attn_metadata,
kv_scale_val,
)
else:
kv_scale_val = 1.0
kv_cache = _create_hnd_kv_cache(
k_contexts,
v_contexts,
BLOCK_SIZE,
num_kv_heads,
head_size,
dtype,
device,
NUM_GPU_BLOCKS,
common_attn_metadata,
)
# 3. Run through FlashInfer with TRTLLM enabled
set_kv_cache_layout("HND")
get_kv_cache_layout.cache_clear()
try:
is_nvfp4 = kv_cache_dtype == "nvfp4"
kv_quant_mode = KVQuantMode.NVFP4 if is_nvfp4 else KVQuantMode.NONE
spec_dtype = torch.uint8 if is_nvfp4 else dtype
kv_cache_spec = FullAttentionSpec(
block_size=BLOCK_SIZE,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
dtype=spec_dtype,
kv_quant_mode=kv_quant_mode,
)
layer_names = ["test_layer_0"]
@@ -312,10 +446,20 @@ def _run_trtllm_integration(batch_spec):
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
kv_cache_dtype=kv_cache_dtype,
)
mock_layer = MockAttentionLayer(device)
if is_nvfp4:
# For nvfp4, k_scale/v_scale are the global quantization
# scales (amax/448) used by reshape_and_cache_flash.
kv_scale_t = torch.tensor(
kv_scale_val, dtype=torch.float32, device=device
)
mock_layer._k_scale = kv_scale_t
mock_layer._v_scale = kv_scale_t
mock_layer._k_scale_float = kv_scale_val
mock_layer._v_scale_float = kv_scale_val
output = torch.empty_like(query_vllm)
impl.do_kv_cache_update(
@@ -326,6 +470,23 @@ def _run_trtllm_integration(batch_spec):
attn_metadata.slot_mapping,
)
# nvfp4 trtllm kernel requires FP8 queries. In the real
# pipeline the attention layer handles this; here we
# quantize manually.
if is_nvfp4:
finfo = torch.finfo(torch.float8_e4m3fn)
q_amax = query_vllm.abs().amax().clamp(min=1e-12)
q_s = (finfo.max / q_amax * 0.1).item()
query_vllm = (
(query_vllm * q_s)
.clamp(finfo.min, finfo.max)
.to(torch.float8_e4m3fn)
)
mock_layer._q_scale = torch.tensor(
1.0 / q_s, dtype=torch.float32, device=device
)
mock_layer._q_scale_float = 1.0 / q_s
output = impl.forward(
mock_layer,
query_vllm,
@@ -337,12 +498,11 @@ def _run_trtllm_integration(batch_spec):
)
# 4. Compare against SDPA reference
torch.testing.assert_close(
output,
sdpa_output,
atol=1e-2,
rtol=1e-2,
)
if is_nvfp4:
atol, rtol = 1.0, 1.0 # nvfp4 has higher quantization error
else:
atol, rtol = 1e-2, 1e-2
torch.testing.assert_close(output, sdpa_output, atol=atol, rtol=rtol)
finally:
set_kv_cache_layout(None)
@@ -359,3 +519,18 @@ def test_trtllm_gen_full_attention_integration(batch_spec_name: str):
MetadataBuilder.build() -> FlashInferImpl.forward() pipeline,
with real TRTLLM kernels on Blackwell."""
_run_trtllm_integration(BATCH_SPECS[batch_spec_name])
@pytest.mark.parametrize(
"batch_spec_name",
list(BATCH_SPECS.keys()),
)
@torch.inference_mode()
def test_trtllm_gen_nvfp4_kv_integration(batch_spec_name: str):
"""Test TRTLLM attention with nvfp4 KV cache through the full
FlashInfer MetadataBuilder.build() -> FlashInferImpl.forward() pipeline."""
_run_trtllm_integration(
BATCH_SPECS[batch_spec_name],
kv_cache_dtype="nvfp4",
model_name=MODEL_NVFP4,
)
@@ -788,6 +788,11 @@ def parse_flashinfer_trtllm_features() -> dict[str, dict[str, Any]]:
if not trtllm_compute_cap:
return {}
# KV cache dtypes that only work with a dedicated kernel (e.g. nvfp4
# requires the SM100 NVFP4 MHA kernel) and should not appear in the
# generic attention-backend feature matrix.
kernel_only_kv_dtypes = ["nvfp4"]
return {
"native": {
# Native FlashInfer: everything except SM100
@@ -798,6 +803,7 @@ def parse_flashinfer_trtllm_features() -> dict[str, dict[str, Any]]:
"compute_capability": trtllm_compute_cap,
"supports_sink": True,
},
"exclude_kv_dtypes": kernel_only_kv_dtypes,
}
@@ -963,6 +969,15 @@ def _expand_flashinfer_variants(
native["supports_sink"] = fi_features["native"]["supports_sink"]
native["compute_capability"] = f"{min_cc}.x-9.x"
# Remove KV dtypes only supported by SM100 kernels (e.g. nvfp4)
exclude = fi_features.get("exclude_kv_dtypes", [])
if exclude:
native["kv_cache_dtypes"] = ", ".join(
d
for d in (d.strip() for d in native["kv_cache_dtypes"].split(","))
if d not in exclude
)
# Create TRTLLM entry
trtllm = backend.copy()
trtllm["version"] = "TRTLLM†"
+12
View File
@@ -1884,6 +1884,18 @@ class VllmConfig:
"in the middle of a mm input"
)
@model_validator(mode="after")
def validate_nvfp4_kv_cache_with_mla(self) -> "VllmConfig":
if self.model_config is None:
return self
if self.cache_config.cache_dtype == "nvfp4" and self.model_config.use_mla:
raise ValueError(
"nvfp4 KV cache is not supported with MLA (Multi-head Latent "
"Attention) backends. Please use a different --kv-cache-dtype "
"(e.g., 'fp8' or 'auto') for MLA models such as DeepSeek."
)
return self
@model_validator(mode="after")
def validate_mamba_block_size(self) -> "VllmConfig":
if self.model_config is None:
@@ -114,12 +114,12 @@ QUANT_ALGOS = [
# MIXED_PRECISION,
"MIXED_PRECISION",
]
KV_CACHE_QUANT_ALGOS = ["FP8"]
KV_CACHE_QUANT_ALGOS = ["FP8", "NVFP4"]
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
class ModelOptKVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
Supports loading kv-cache scaling factors from FP8 or NVFP4 checkpoints.
"""
def __init__(self, quant_config: "ModelOptQuantConfigBase"):
@@ -995,7 +995,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptKVCacheMethod
class ModelOptNvFp4Config(ModelOptQuantConfigBase):
@@ -1488,7 +1488,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptKVCacheMethod
class ModelOptMxFp8Config(ModelOptQuantConfigBase):
@@ -2018,7 +2018,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
# Register the method classes for ModelOptMxFp8Config
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
ModelOptMxFp8Config.FusedMoEMethodCls = ModelOptMxFp8FusedMoE
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptKVCacheMethod
class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
@@ -2166,7 +2166,7 @@ class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
# KV-cache quantization
if isinstance(layer, Attention):
if self.kv_cache_quant_method:
return ModelOptFp8KVCacheMethod(self)
return ModelOptKVCacheMethod(self)
return None
# Excluded layers
+114 -15
View File
@@ -332,6 +332,7 @@ class FlashInferBackend(AttentionBackend):
"fp8",
"fp8_e4m3",
"fp8_e5m2",
"nvfp4",
]
@staticmethod
@@ -388,13 +389,15 @@ class FlashInferBackend(AttentionBackend):
return stride_order
@staticmethod
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
def get_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
return torch.float8_e4m3fn
elif kv_cache_dtype == "fp8_e5m2":
return torch.float8_e5m2
elif kv_cache_dtype == "nvfp4":
return torch.uint8
else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
raise ValueError(f"Unrecognized dtype: {kv_cache_dtype}")
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
@@ -622,9 +625,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# For NVFP4, kv_cache_dtype stays as the string "nvfp4"
# which is passed to FlashInferImpl
self.kv_cache_dtype = self.cache_dtype
raise NotImplementedError("nvfp4 KV cache is not yet supported")
else:
self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype = FlashInferBackend.get_dtype_for_flashinfer(
self.cache_dtype
)
else:
@@ -645,7 +647,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
):
if self.is_kvcache_nvfp4:
# NVFP4 KV cache uses FP8 quantized queries
self.q_data_type = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.q_data_type = FlashInferBackend.get_dtype_for_flashinfer(
"fp8_e4m3"
)
else:
@@ -765,8 +767,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dcp_a2a=self.dcp_a2a,
)
else:
# NVFP4 KV cache requires the trtllm-gen backend inside
# the wrapper; fa2/fa3 do not support nvfp4.
backend = "trtllm-gen" if self.is_kvcache_nvfp4 else "auto"
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), get_kv_cache_layout()
self._get_workspace_buffer(),
get_kv_cache_layout(),
backend=backend,
)
assert self._prefill_wrapper is not None
return self._prefill_wrapper
@@ -786,6 +793,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr = None
paged_kv_indices = None
paged_kv_last_page_len = None
# NVFP4 KV cache requires the trtllm-gen backend inside
# the wrapper; fa2/fa3 do not support nvfp4.
backend = "trtllm-gen" if self.is_kvcache_nvfp4 else "auto"
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
get_kv_cache_layout(),
@@ -797,6 +807,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# at least as good as cuda cores for all attention ops in latest
# gpus.
use_tensor_cores=True,
backend=backend,
)
# save the decode wrapper
@@ -1148,6 +1159,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
prefill_wrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
# NVFP4 trtllm kernel only supports FP8 output;
# use FP8 o_data_type so the wrapper matches the
# FP8 output buffer allocated in forward().
o_dtype = (
FP8_DTYPE if self.is_kvcache_nvfp4 else self.model_config.dtype
)
prefill_wrapper.plan(
qo_indptr=qo_indptr_prefill_cpu,
paged_kv_indptr=paged_kv_indptr_prefill_cpu,
@@ -1163,7 +1180,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
o_data_type=self.model_config.dtype,
o_data_type=o_dtype,
fixed_split_size=self.prefill_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
@@ -1197,6 +1214,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Use the persistent buffer with padding length,
# instead of the same address but chunked version
# in atten_metadata when using cudagraph.
# NVFP4 trtllm kernel only supports FP8 output;
# use FP8 o_data_type so the wrapper matches the
# FP8 output buffer allocated in forward().
o_dtype = (
FP8_DTYPE if self.is_kvcache_nvfp4 else self.model_config.dtype
)
fast_plan_decode(
decode_wrapper,
indptr_cpu=self.paged_kv_indptr.cpu[: num_input_tokens + 1],
@@ -1215,7 +1238,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
o_data_type=self.model_config.dtype,
o_data_type=o_dtype,
fixed_split_size=self.decode_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
@@ -1300,6 +1323,17 @@ class FlashInferImpl(AttentionImpl):
self.bmm2_scale: float | None = None
self.o_sf_scale: float | None = None
# Pre-allocated FP8 output buffer for NVFP4 without fused output quant.
if self.is_kvcache_nvfp4 and vllm_config is not None:
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self._nvfp4_fp8_out = torch.empty(
(max_num_tokens, num_heads, head_size),
dtype=FP8_DTYPE,
device="cuda",
)
else:
self._nvfp4_fp8_out = None
dcp_a2a = (
vllm_config is not None
and vllm_config.parallel_config.decode_context_parallel_size > 1
@@ -1420,7 +1454,7 @@ class FlashInferImpl(AttentionImpl):
if self.kv_sharing_target_layer_name is None and is_quantized_kv_cache(
self.kv_cache_dtype
):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
torch_dtype = FlashInferBackend.get_dtype_for_flashinfer(
self.kv_cache_dtype
)
kv_cache = kv_cache.view(torch_dtype)
@@ -1500,13 +1534,37 @@ class FlashInferImpl(AttentionImpl):
)
assert prefill_wrapper._sm_scale == self.scale
assert prefill_wrapper._causal
if self.is_kvcache_nvfp4:
kv_cache_permute = nvfp4_kv_data
kv_cache_sf = (
nvfp4_kv_block_scales if self.is_kvcache_nvfp4 else None
)
# NVFP4 trtllm kernel only supports FP8 output.
# Use a pre-allocated FP8 buffer and dequantize
# afterwards.
needs_fp8_out_prefill = (
self.is_kvcache_nvfp4 and output.dtype != FP8_DTYPE
)
if needs_fp8_out_prefill:
out_prefill = self._nvfp4_fp8_out[:num_prefill_tokens]
else:
out_prefill = output[num_decode_tokens:]
prefill_wrapper.run(
prefill_query,
kv_cache_permute,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output[num_decode_tokens:],
out=out_prefill,
kv_cache_sf=kv_cache_sf,
)
if needs_fp8_out_prefill:
output[
num_decode_tokens : num_decode_tokens + num_prefill_tokens
].copy_(out_prefill.to(output.dtype))
else:
assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
# prefill_query may be non-contiguous or have degenerate strides
@@ -1537,6 +1595,12 @@ class FlashInferImpl(AttentionImpl):
assert self.o_sf_scale is None
out = output[num_decode_tokens:]
# NVFP4 trtllm kernel only supports FP8 output.
# Use a pre-allocated FP8 buffer and dequantize afterwards.
needs_fp8_out = self.is_kvcache_nvfp4 and output.dtype != FP8_DTYPE
if needs_fp8_out:
out = self._nvfp4_fp8_out[:num_prefill_tokens]
prefill_kv_block_scales = None
if self.is_kvcache_nvfp4:
# NVFP4 trtllm-gen kernel requires FP8 query.
@@ -1547,7 +1611,7 @@ class FlashInferImpl(AttentionImpl):
)
mock_kv_cache = nvfp4_kv_data
mock_block_table = block_tables_prefill
prefill_kv_block_scales = nvfp4_kv_block_scales # noqa: F841
prefill_kv_block_scales = nvfp4_kv_block_scales
elif (
attn_metadata.q_data_type != FP8_DTYPE
and self.kv_cache_dtype.startswith("fp8")
@@ -1598,8 +1662,14 @@ class FlashInferImpl(AttentionImpl):
sinks=self.sinks,
o_sf_scale=self.o_sf_scale,
out=out,
kv_cache_sf=prefill_kv_block_scales,
)
if needs_fp8_out:
output[
num_decode_tokens : num_decode_tokens + num_prefill_tokens
].copy_(out[:num_prefill_tokens].to(output.dtype))
if num_decode_tokens > 0:
decode_query = query[:num_decode_tokens]
assert decode_query.shape[0] == num_decode_tokens
@@ -1612,6 +1682,18 @@ class FlashInferImpl(AttentionImpl):
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
assert decode_wrapper._sm_scale == self.scale
if self.is_kvcache_nvfp4:
kv_cache_permute = nvfp4_kv_data
kv_cache_sf = nvfp4_kv_block_scales if self.is_kvcache_nvfp4 else None
# NVFP4 kernel only supports FP8 output.
# Use a pre-allocated FP8 buffer and dequantize afterwards.
needs_fp8_out = self.is_kvcache_nvfp4 and output.dtype != FP8_DTYPE
if needs_fp8_out:
out_decode = self._nvfp4_fp8_out[:num_decode_tokens]
else:
out_decode = output[:num_decode_tokens]
if use_dcp:
decode_query = get_dcp_group().all_gather(
decode_query.contiguous(), dim=-2
@@ -1630,6 +1712,7 @@ class FlashInferImpl(AttentionImpl):
out=output_tmp,
lse=lse,
return_lse=True,
kv_cache_sf=kv_cache_sf,
)
output[:num_decode_tokens] = self.dcp_combine(
output_tmp,
@@ -1642,8 +1725,12 @@ class FlashInferImpl(AttentionImpl):
kv_cache_permute,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output[:num_decode_tokens],
out=out_decode,
kv_cache_sf=kv_cache_sf,
)
if needs_fp8_out:
output[:num_decode_tokens].copy_(out_decode.to(output.dtype))
else:
# decode_query may be non-contiguous or have degenerate strides
assert isinstance(attn_metadata.decode, TRTLLMDecode)
@@ -1686,6 +1773,12 @@ class FlashInferImpl(AttentionImpl):
assert self.o_sf_scale is None
out = output[:num_decode_tokens]
# NVFP4 trtllm kernel only supports FP8 output.
# Use a pre-allocated FP8 buffer and dequantize afterwards.
needs_fp8_out = self.is_kvcache_nvfp4 and output.dtype != FP8_DTYPE
if needs_fp8_out:
out = self._nvfp4_fp8_out[:num_decode_tokens]
if num_decode_tokens % attn_metadata.num_decodes != 0:
# This gets triggered when the dummy_run forces
# attention to be initialized with q_len = 0
@@ -1695,9 +1788,9 @@ class FlashInferImpl(AttentionImpl):
trtllm_batch_decode_with_kv_cache(
query=decode_query,
kv_cache=nvfp4_kv_data
if self.is_kvcache_nvfp4
else kv_cache_permute,
kv_cache=(
nvfp4_kv_data if self.is_kvcache_nvfp4 else kv_cache_permute
),
workspace_buffer=workspace_buffer,
block_tables=block_tables_decode,
seq_lens=seq_lens_decode,
@@ -1709,7 +1802,13 @@ class FlashInferImpl(AttentionImpl):
o_sf_scale=self.o_sf_scale,
out=out,
q_len_per_req=q_len_per_req,
kv_cache_sf=(
nvfp4_kv_block_scales if self.is_kvcache_nvfp4 else None
),
)
if needs_fp8_out:
output[:num_decode_tokens].copy_(out.to(output.dtype))
return output_padded
def do_kv_cache_update(
+10
View File
@@ -151,6 +151,16 @@ class AttentionSpec(KVCacheSpec):
@property
def real_page_size_bytes(self) -> int:
if self.kv_quant_mode.is_nvfp4:
# Packed layout: fp4 data + fp8 block scales per head.
full_dim = nvfp4_kv_cache_full_dim(self.head_size)
return (
2
* self.block_size
* self.num_kv_heads
* full_dim
* get_dtype_size(self.dtype)
)
return (
2
* self.block_size