mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Add nvfp4 kv cache support (#40177)
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
This commit is contained in:
@@ -169,7 +169,7 @@ Priority is **1 = highest** (tried first).
|
||||
| ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ |
|
||||
| `CPU_ATTN` | | fp16, bf16, fp32 | `auto`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256, 512 | ❌ | ❌ | ❌ | All | N/A |
|
||||
| `FLASHINFER` | Native† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x |
|
||||
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
|
||||
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2`, `nvfp4` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
|
||||
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
|
||||
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
|
||||
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ✅ | ❌ | ✅ | All | ≥10.0 |
|
||||
|
||||
@@ -5,12 +5,17 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quantization.nvfp4_utils import (
|
||||
dequant_nvfp4_kv_cache,
|
||||
dequantize_nvfp4_to_dtype,
|
||||
get_nvfp4_global_scale,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.utils.torch_utils import (
|
||||
nvfp4_kv_cache_full_dim,
|
||||
nvfp4_kv_cache_split_views,
|
||||
set_random_seed,
|
||||
)
|
||||
|
||||
if not current_platform.is_device_capability_family(100):
|
||||
pytest.skip(
|
||||
@@ -33,6 +38,117 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
||||
|
||||
|
||||
def build_paged_kv_metadata(
|
||||
seq_lens: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
block_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Build paged-KV indptr/indices/last_page_lens from seq_lens + block_tables."""
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(len(seq_lens)):
|
||||
sl = int(seq_lens[i])
|
||||
assert sl > 0
|
||||
nb = (sl + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :nb].tolist())
|
||||
kv_indptr.append(kv_indptr[-1] + nb)
|
||||
kv_last_page_lens.append(sl % block_size or block_size)
|
||||
return (
|
||||
torch.tensor(kv_indptr, dtype=torch.int32),
|
||||
torch.tensor(kv_indices, dtype=torch.int32),
|
||||
torch.tensor(kv_last_page_lens, dtype=torch.int32),
|
||||
)
|
||||
|
||||
|
||||
def make_nvfp4_kv_cache(
|
||||
kv_bf16_hnd: torch.Tensor, block_size: int, head_size: int
|
||||
) -> tuple:
|
||||
"""Quantize bf16 KV cache to nvfp4 via reshape_and_cache_flash.
|
||||
|
||||
Returns (k_data, v_data), (k_scales, v_scales), kv_scale, ref_kv_bf16.
|
||||
"""
|
||||
num_blocks, _, num_kv_heads, _, _ = kv_bf16_hnd.shape
|
||||
kv_scale_val = (kv_bf16_hnd.abs().amax() / 448.0).item()
|
||||
kv_scale_tensor = torch.tensor(
|
||||
kv_scale_val, dtype=torch.float32, device=kv_bf16_hnd.device
|
||||
)
|
||||
|
||||
# Allocate in HND physical order, permute to NHD logical order.
|
||||
# hnd_order swaps dims 2↔3; it is its own inverse.
|
||||
full_dim = nvfp4_kv_cache_full_dim(head_size)
|
||||
hnd_order = (0, 1, 3, 2, 4)
|
||||
kv_cache = torch.zeros(
|
||||
(num_blocks, 2, num_kv_heads, block_size, full_dim),
|
||||
dtype=torch.uint8,
|
||||
device=kv_bf16_hnd.device,
|
||||
).permute(*hnd_order)
|
||||
|
||||
# Flatten NHD [N, T, H, D] → token tensors [N*T, H, D] for the kernel.
|
||||
num_tokens = num_blocks * block_size
|
||||
k_tokens = (
|
||||
kv_bf16_hnd[:, 0]
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(num_tokens, num_kv_heads, head_size)
|
||||
)
|
||||
v_tokens = (
|
||||
kv_bf16_hnd[:, 1]
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(num_tokens, num_kv_heads, head_size)
|
||||
)
|
||||
slot_mapping = torch.arange(num_tokens, dtype=torch.long, device=kv_bf16_hnd.device)
|
||||
|
||||
# reshape_and_cache_flash: kernel receives kv_cache[:, 0] and [:, 1]
|
||||
# (full K/V buffers containing both data and scale).
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
k_tokens,
|
||||
v_tokens,
|
||||
kv_cache[:, 0],
|
||||
kv_cache[:, 1],
|
||||
slot_mapping,
|
||||
"nvfp4",
|
||||
kv_scale_tensor,
|
||||
kv_scale_tensor,
|
||||
)
|
||||
|
||||
# Split in HND order for trtllm kernel (expects HND numTokensPerPage).
|
||||
kv_cache_hnd = kv_cache.permute(*hnd_order)
|
||||
(k_data, v_data), (k_scales, v_scales) = nvfp4_kv_cache_split_views(kv_cache_hnd)
|
||||
|
||||
# Dequantize for the FA2 reference baseline.
|
||||
ref_k = dequant_nvfp4_kv_cache(
|
||||
k_data, k_scales, kv_scale_val, head_size, block_size
|
||||
).to(torch.bfloat16)
|
||||
ref_v = dequant_nvfp4_kv_cache(
|
||||
v_data, v_scales, kv_scale_val, head_size, block_size
|
||||
).to(torch.bfloat16)
|
||||
ref_kv_bf16 = torch.stack([ref_k, ref_v], dim=1) # [N, 2, H, T, D]
|
||||
|
||||
return (k_data, v_data), (k_scales, v_scales), kv_scale_val, ref_kv_bf16
|
||||
|
||||
|
||||
def make_quantized_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
kv_quant_dtype: torch.dtype,
|
||||
block_size: int,
|
||||
head_size: int,
|
||||
) -> tuple:
|
||||
"""Quantize kv_cache based on dtype. Returns (kv_cache, kv_cache_sf,
|
||||
kv_scale, ref_kv_cache, is_nvfp4_kv)."""
|
||||
is_nvfp4_kv = kv_quant_dtype == FP4_DTYPE
|
||||
if is_nvfp4_kv:
|
||||
data, scales, kv_scale, ref = make_nvfp4_kv_cache(
|
||||
kv_cache, block_size, head_size
|
||||
)
|
||||
return data, scales, kv_scale, ref, True
|
||||
elif kv_quant_dtype == FP8_DTYPE:
|
||||
kv_fp8, kv_scale = to_float8(kv_cache)
|
||||
ref = kv_fp8.to(kv_cache.dtype) * kv_scale
|
||||
return kv_fp8, None, kv_scale, ref, False
|
||||
else:
|
||||
return kv_cache, None, 1.0, kv_cache, False
|
||||
|
||||
|
||||
DTYPE = [torch.bfloat16]
|
||||
QUANT_DTYPES = [
|
||||
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||
@@ -41,6 +157,7 @@ QUANT_DTYPES = [
|
||||
(FP8_DTYPE, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||
(FP8_DTYPE, FP4_DTYPE, FP8_DTYPE), # nvfp4 KV cache
|
||||
]
|
||||
BATCH_SIZE = [4, 12]
|
||||
MAX_SEQ_LENS = [(1024, 4096)]
|
||||
@@ -127,35 +244,19 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
max_seq_len = torch.max(seq_lens).item()
|
||||
|
||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
if kv_quant_dtype == FP8_DTYPE:
|
||||
kv_cache, kv_scale = to_float8(kv_cache)
|
||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
||||
else:
|
||||
kv_scale = 1.0
|
||||
ref_kv_cache = kv_cache
|
||||
kv_cache, kv_cache_sf, kv_scale, ref_kv_cache, is_nvfp4_kv = (
|
||||
make_quantized_kv_cache(kv_cache, kv_quant_dtype, block_size, head_size)
|
||||
)
|
||||
|
||||
k_scale = v_scale = kv_scale
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_len = seq_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
kv_indptr, kv_indices, kv_last_page_lens = build_paged_kv_metadata(
|
||||
seq_lens, block_tables, block_size
|
||||
)
|
||||
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
|
||||
|
||||
# Baseline Decode
|
||||
@@ -225,6 +326,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
sinks=sinks,
|
||||
o_sf_scale=o_sf_scale_float,
|
||||
out=output_trtllm,
|
||||
kv_cache_sf=kv_cache_sf,
|
||||
)
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
output_trtllm = output_trtllm.to(dtype) * o_scale
|
||||
@@ -237,7 +339,9 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
)
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
||||
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||
if is_nvfp4_kv:
|
||||
rtol, atol = 1.0, 1.0 # nvfp4 has higher quantization error
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||
rtol, atol = 7e-2, 9e-2
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||
rtol, atol = 3e-2, 4e-2
|
||||
@@ -287,7 +391,12 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
kv_quant_dtype = kv_quant_dtype or dtype
|
||||
o_quant_dtype = o_quant_dtype or dtype
|
||||
|
||||
if q_quant_dtype != kv_quant_dtype:
|
||||
# FP8 Q + nvfp4 KV is the required combination for the nvfp4 KV path.
|
||||
# All other mixed Q/KV dtype combinations are unsupported.
|
||||
is_nvfp4_kv = kv_quant_dtype == FP4_DTYPE
|
||||
if q_quant_dtype != kv_quant_dtype and not (
|
||||
q_quant_dtype == FP8_DTYPE and is_nvfp4_kv
|
||||
):
|
||||
pytest.skip("Skipped mixed QKV dtypes for prefill")
|
||||
|
||||
max_q_len, max_kv_len = max_seq_lens
|
||||
@@ -329,35 +438,19 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
max_seq_len = torch.max(seq_lens).item()
|
||||
|
||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
if kv_quant_dtype == FP8_DTYPE:
|
||||
kv_cache, kv_scale = to_float8(kv_cache)
|
||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
||||
else:
|
||||
kv_scale = 1.0
|
||||
ref_kv_cache = kv_cache
|
||||
kv_cache, kv_cache_sf, kv_scale, ref_kv_cache, is_nvfp4_kv = (
|
||||
make_quantized_kv_cache(kv_cache, kv_quant_dtype, block_size, head_size)
|
||||
)
|
||||
|
||||
k_scale = v_scale = kv_scale
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_len = seq_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
kv_indptr, kv_indices, kv_last_page_lens = build_paged_kv_metadata(
|
||||
seq_lens, block_tables, block_size
|
||||
)
|
||||
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
|
||||
|
||||
# Baseline Prefill
|
||||
@@ -431,6 +524,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
sinks=sinks,
|
||||
o_sf_scale=o_sf_scale_float,
|
||||
out=output_trtllm,
|
||||
kv_cache_sf=kv_cache_sf,
|
||||
)
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
output_trtllm = output_trtllm.to(dtype) * o_scale
|
||||
@@ -443,7 +537,9 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
)
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
||||
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||
if is_nvfp4_kv:
|
||||
rtol, atol = 1.0, 1.5 # nvfp4 has higher quantization error
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||
rtol, atol = 3e-1, 4e-1
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||
rtol, atol = 4e-2, 6e-2
|
||||
|
||||
@@ -17,13 +17,13 @@ from tests.v1.attention.utils import (
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.utils.torch_utils import nvfp4_kv_cache_full_dim, set_random_seed
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PerLayerParameters,
|
||||
get_kv_cache_layout,
|
||||
set_kv_cache_layout,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVQuantMode
|
||||
|
||||
if not current_platform.is_device_capability_family(100):
|
||||
pytest.skip(
|
||||
@@ -53,6 +53,7 @@ class MockAttentionLayer:
|
||||
|
||||
|
||||
MODEL = "Qwen/Qwen2.5-0.5B"
|
||||
MODEL_NVFP4 = "Qwen/Qwen3-4B" # nvfp4 needs head_dim >= 128 (or 80)
|
||||
BLOCK_SIZE = 16
|
||||
NUM_GPU_BLOCKS = 8192
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
@@ -169,19 +170,129 @@ def _create_hnd_kv_cache(
|
||||
return kv_cache
|
||||
|
||||
|
||||
def _run_trtllm_integration(batch_spec):
|
||||
def _create_nvfp4_hnd_kv_cache(
|
||||
k_contexts,
|
||||
v_contexts,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype,
|
||||
device,
|
||||
num_blocks,
|
||||
common_attn_metadata,
|
||||
kv_scale_val,
|
||||
):
|
||||
"""Create an nvfp4 KV cache by quantizing bf16 context via
|
||||
reshape_and_cache_flash, using the same block-table layout as
|
||||
_create_hnd_kv_cache.
|
||||
|
||||
The returned tensor is dtype ``uint8`` with shape
|
||||
``(num_blocks, 2, block_size, num_kv_heads, full_dim)`` in logical
|
||||
(NHD) order, but physically permuted to HND layout via stride order
|
||||
``(0, 1, 3, 2, 4)`` (i.e. ``num_kv_heads`` before ``block_size``).
|
||||
|
||||
The last dimension ``full_dim = head_size // 2 + head_size // 16``
|
||||
packs two regions contiguously:
|
||||
- **FP4 data** (``head_size // 2`` bytes): pairs of E2M1 values,
|
||||
two per byte.
|
||||
- **FP8 block scales** (``head_size // 16`` bytes): one E4M3
|
||||
scale per 16-element block.
|
||||
|
||||
Dimension 1 indexes K (``[:, 0]``) and V (``[:, 1]``).
|
||||
|
||||
Args:
|
||||
k_contexts: List of key context tensors, one per sequence.
|
||||
v_contexts: List of value context tensors, one per sequence.
|
||||
block_size: Number of tokens per cache block.
|
||||
num_kv_heads: Number of key/value heads.
|
||||
head_size: Head dimension (must be divisible by 16).
|
||||
dtype: Source data type for the bf16 intermediate cache.
|
||||
device: Target device.
|
||||
num_blocks: Total number of blocks to allocate.
|
||||
common_attn_metadata: Metadata containing block tables and
|
||||
sequence lengths.
|
||||
kv_scale_val: Scalar float used as both k_scale and v_scale
|
||||
during quantization.
|
||||
|
||||
Returns:
|
||||
``torch.Tensor``: The nvfp4 kv_cache tensor (uint8, HND-strided).
|
||||
"""
|
||||
# First create a bf16 HND cache so block tables are populated.
|
||||
bf16_cache = _create_hnd_kv_cache(
|
||||
k_contexts,
|
||||
v_contexts,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype,
|
||||
device,
|
||||
num_blocks,
|
||||
common_attn_metadata,
|
||||
)
|
||||
|
||||
# Allocate nvfp4 cache: same shape but with full_dim (data + scale).
|
||||
full_dim = nvfp4_kv_cache_full_dim(head_size)
|
||||
hnd_order = (0, 1, 3, 2, 4)
|
||||
nvfp4_cache = torch.zeros(
|
||||
(num_blocks, 2, num_kv_heads, block_size, full_dim),
|
||||
dtype=torch.uint8,
|
||||
device=device,
|
||||
).permute(*hnd_order)
|
||||
|
||||
# Flatten bf16 context into tokens and quantize via reshape_and_cache_flash.
|
||||
# bf16_cache is (num_blocks, 2, block_size, num_kv_heads, head_size) logical
|
||||
# with HND physical strides.
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
seq_lens = common_attn_metadata.seq_lens.cpu()
|
||||
query_lens = (
|
||||
common_attn_metadata.query_start_loc_cpu[1:]
|
||||
- common_attn_metadata.query_start_loc_cpu[:-1]
|
||||
)
|
||||
kv_scale_t = torch.tensor(kv_scale_val, dtype=torch.float32, device=device)
|
||||
|
||||
for i in range(len(k_contexts)):
|
||||
ctx_len = int(seq_lens[i]) - int(query_lens[i])
|
||||
if ctx_len == 0:
|
||||
continue
|
||||
# Gather context tokens from the bf16 cache using block table.
|
||||
n_ctx_blocks = (ctx_len + block_size - 1) // block_size
|
||||
blocks = block_table[i, :n_ctx_blocks]
|
||||
# bf16_cache[:, kv_idx] is (num_blocks, block_size, num_kv_heads, head_size)
|
||||
k_ctx = bf16_cache[blocks, 0].reshape(-1, num_kv_heads, head_size)[:ctx_len]
|
||||
v_ctx = bf16_cache[blocks, 1].reshape(-1, num_kv_heads, head_size)[:ctx_len]
|
||||
# Build slot mapping for these context tokens.
|
||||
token_offsets = torch.arange(ctx_len, device=device)
|
||||
block_indices = token_offsets // block_size
|
||||
intra_offsets = token_offsets % block_size
|
||||
slots = block_table[i, block_indices] * block_size + intra_offsets
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
k_ctx,
|
||||
v_ctx,
|
||||
nvfp4_cache[:, 0],
|
||||
nvfp4_cache[:, 1],
|
||||
slots,
|
||||
"nvfp4",
|
||||
kv_scale_t,
|
||||
kv_scale_t,
|
||||
)
|
||||
|
||||
return nvfp4_cache
|
||||
|
||||
|
||||
def _run_trtllm_integration(batch_spec, kv_cache_dtype="auto", model_name=MODEL):
|
||||
"""Run TRTLLM attention through the full FlashInfer pipeline
|
||||
and compare against an SDPA reference."""
|
||||
set_random_seed(42)
|
||||
device = torch.device(f"{DEVICE_TYPE}:0")
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
model_name=MODEL,
|
||||
model_name=model_name,
|
||||
max_model_len=max(batch_spec.seq_lens),
|
||||
block_size=BLOCK_SIZE,
|
||||
num_gpu_blocks=NUM_GPU_BLOCKS,
|
||||
)
|
||||
vllm_config.attention_config.use_trtllm_attention = True
|
||||
vllm_config.cache_config.cache_dtype = kv_cache_dtype
|
||||
|
||||
num_q_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config
|
||||
@@ -248,28 +359,51 @@ def _run_trtllm_integration(batch_spec):
|
||||
common_attn_metadata = create_common_attn_metadata(batch_spec, BLOCK_SIZE, device)
|
||||
|
||||
# 2. Create HND KV cache
|
||||
kv_cache = _create_hnd_kv_cache(
|
||||
k_contexts,
|
||||
v_contexts,
|
||||
BLOCK_SIZE,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype,
|
||||
device,
|
||||
NUM_GPU_BLOCKS,
|
||||
common_attn_metadata,
|
||||
)
|
||||
is_nvfp4 = kv_cache_dtype == "nvfp4"
|
||||
if is_nvfp4:
|
||||
# Compute a global scale from the context data.
|
||||
all_ctx = torch.cat(k_contexts + v_contexts, dim=0)
|
||||
kv_scale_val = (all_ctx.abs().amax() / 448.0).item()
|
||||
kv_cache = _create_nvfp4_hnd_kv_cache(
|
||||
k_contexts,
|
||||
v_contexts,
|
||||
BLOCK_SIZE,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype,
|
||||
device,
|
||||
NUM_GPU_BLOCKS,
|
||||
common_attn_metadata,
|
||||
kv_scale_val,
|
||||
)
|
||||
else:
|
||||
kv_scale_val = 1.0
|
||||
kv_cache = _create_hnd_kv_cache(
|
||||
k_contexts,
|
||||
v_contexts,
|
||||
BLOCK_SIZE,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype,
|
||||
device,
|
||||
NUM_GPU_BLOCKS,
|
||||
common_attn_metadata,
|
||||
)
|
||||
|
||||
# 3. Run through FlashInfer with TRTLLM enabled
|
||||
set_kv_cache_layout("HND")
|
||||
get_kv_cache_layout.cache_clear()
|
||||
|
||||
try:
|
||||
is_nvfp4 = kv_cache_dtype == "nvfp4"
|
||||
kv_quant_mode = KVQuantMode.NVFP4 if is_nvfp4 else KVQuantMode.NONE
|
||||
spec_dtype = torch.uint8 if is_nvfp4 else dtype
|
||||
kv_cache_spec = FullAttentionSpec(
|
||||
block_size=BLOCK_SIZE,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
dtype=spec_dtype,
|
||||
kv_quant_mode=kv_quant_mode,
|
||||
)
|
||||
layer_names = ["test_layer_0"]
|
||||
|
||||
@@ -312,10 +446,20 @@ def _run_trtllm_integration(batch_spec):
|
||||
num_kv_heads=num_kv_heads,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype="auto",
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
)
|
||||
|
||||
mock_layer = MockAttentionLayer(device)
|
||||
if is_nvfp4:
|
||||
# For nvfp4, k_scale/v_scale are the global quantization
|
||||
# scales (amax/448) used by reshape_and_cache_flash.
|
||||
kv_scale_t = torch.tensor(
|
||||
kv_scale_val, dtype=torch.float32, device=device
|
||||
)
|
||||
mock_layer._k_scale = kv_scale_t
|
||||
mock_layer._v_scale = kv_scale_t
|
||||
mock_layer._k_scale_float = kv_scale_val
|
||||
mock_layer._v_scale_float = kv_scale_val
|
||||
output = torch.empty_like(query_vllm)
|
||||
|
||||
impl.do_kv_cache_update(
|
||||
@@ -326,6 +470,23 @@ def _run_trtllm_integration(batch_spec):
|
||||
attn_metadata.slot_mapping,
|
||||
)
|
||||
|
||||
# nvfp4 trtllm kernel requires FP8 queries. In the real
|
||||
# pipeline the attention layer handles this; here we
|
||||
# quantize manually.
|
||||
if is_nvfp4:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
q_amax = query_vllm.abs().amax().clamp(min=1e-12)
|
||||
q_s = (finfo.max / q_amax * 0.1).item()
|
||||
query_vllm = (
|
||||
(query_vllm * q_s)
|
||||
.clamp(finfo.min, finfo.max)
|
||||
.to(torch.float8_e4m3fn)
|
||||
)
|
||||
mock_layer._q_scale = torch.tensor(
|
||||
1.0 / q_s, dtype=torch.float32, device=device
|
||||
)
|
||||
mock_layer._q_scale_float = 1.0 / q_s
|
||||
|
||||
output = impl.forward(
|
||||
mock_layer,
|
||||
query_vllm,
|
||||
@@ -337,12 +498,11 @@ def _run_trtllm_integration(batch_spec):
|
||||
)
|
||||
|
||||
# 4. Compare against SDPA reference
|
||||
torch.testing.assert_close(
|
||||
output,
|
||||
sdpa_output,
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
)
|
||||
if is_nvfp4:
|
||||
atol, rtol = 1.0, 1.0 # nvfp4 has higher quantization error
|
||||
else:
|
||||
atol, rtol = 1e-2, 1e-2
|
||||
torch.testing.assert_close(output, sdpa_output, atol=atol, rtol=rtol)
|
||||
|
||||
finally:
|
||||
set_kv_cache_layout(None)
|
||||
@@ -359,3 +519,18 @@ def test_trtllm_gen_full_attention_integration(batch_spec_name: str):
|
||||
MetadataBuilder.build() -> FlashInferImpl.forward() pipeline,
|
||||
with real TRTLLM kernels on Blackwell."""
|
||||
_run_trtllm_integration(BATCH_SPECS[batch_spec_name])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_spec_name",
|
||||
list(BATCH_SPECS.keys()),
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_trtllm_gen_nvfp4_kv_integration(batch_spec_name: str):
|
||||
"""Test TRTLLM attention with nvfp4 KV cache through the full
|
||||
FlashInfer MetadataBuilder.build() -> FlashInferImpl.forward() pipeline."""
|
||||
_run_trtllm_integration(
|
||||
BATCH_SPECS[batch_spec_name],
|
||||
kv_cache_dtype="nvfp4",
|
||||
model_name=MODEL_NVFP4,
|
||||
)
|
||||
|
||||
@@ -788,6 +788,11 @@ def parse_flashinfer_trtllm_features() -> dict[str, dict[str, Any]]:
|
||||
if not trtllm_compute_cap:
|
||||
return {}
|
||||
|
||||
# KV cache dtypes that only work with a dedicated kernel (e.g. nvfp4
|
||||
# requires the SM100 NVFP4 MHA kernel) and should not appear in the
|
||||
# generic attention-backend feature matrix.
|
||||
kernel_only_kv_dtypes = ["nvfp4"]
|
||||
|
||||
return {
|
||||
"native": {
|
||||
# Native FlashInfer: everything except SM100
|
||||
@@ -798,6 +803,7 @@ def parse_flashinfer_trtllm_features() -> dict[str, dict[str, Any]]:
|
||||
"compute_capability": trtllm_compute_cap,
|
||||
"supports_sink": True,
|
||||
},
|
||||
"exclude_kv_dtypes": kernel_only_kv_dtypes,
|
||||
}
|
||||
|
||||
|
||||
@@ -963,6 +969,15 @@ def _expand_flashinfer_variants(
|
||||
native["supports_sink"] = fi_features["native"]["supports_sink"]
|
||||
native["compute_capability"] = f"{min_cc}.x-9.x"
|
||||
|
||||
# Remove KV dtypes only supported by SM100 kernels (e.g. nvfp4)
|
||||
exclude = fi_features.get("exclude_kv_dtypes", [])
|
||||
if exclude:
|
||||
native["kv_cache_dtypes"] = ", ".join(
|
||||
d
|
||||
for d in (d.strip() for d in native["kv_cache_dtypes"].split(","))
|
||||
if d not in exclude
|
||||
)
|
||||
|
||||
# Create TRTLLM entry
|
||||
trtllm = backend.copy()
|
||||
trtllm["version"] = "TRTLLM†"
|
||||
|
||||
@@ -1884,6 +1884,18 @@ class VllmConfig:
|
||||
"in the middle of a mm input"
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_nvfp4_kv_cache_with_mla(self) -> "VllmConfig":
|
||||
if self.model_config is None:
|
||||
return self
|
||||
if self.cache_config.cache_dtype == "nvfp4" and self.model_config.use_mla:
|
||||
raise ValueError(
|
||||
"nvfp4 KV cache is not supported with MLA (Multi-head Latent "
|
||||
"Attention) backends. Please use a different --kv-cache-dtype "
|
||||
"(e.g., 'fp8' or 'auto') for MLA models such as DeepSeek."
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_mamba_block_size(self) -> "VllmConfig":
|
||||
if self.model_config is None:
|
||||
|
||||
@@ -114,12 +114,12 @@ QUANT_ALGOS = [
|
||||
# MIXED_PRECISION,
|
||||
"MIXED_PRECISION",
|
||||
]
|
||||
KV_CACHE_QUANT_ALGOS = ["FP8"]
|
||||
KV_CACHE_QUANT_ALGOS = ["FP8", "NVFP4"]
|
||||
|
||||
|
||||
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
|
||||
class ModelOptKVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
||||
Supports loading kv-cache scaling factors from FP8 or NVFP4 checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: "ModelOptQuantConfigBase"):
|
||||
@@ -995,7 +995,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
|
||||
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
|
||||
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
|
||||
ModelOptFp8Config.KVCacheMethodCls = ModelOptKVCacheMethod
|
||||
|
||||
|
||||
class ModelOptNvFp4Config(ModelOptQuantConfigBase):
|
||||
@@ -1488,7 +1488,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
|
||||
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
|
||||
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
|
||||
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
|
||||
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptKVCacheMethod
|
||||
|
||||
|
||||
class ModelOptMxFp8Config(ModelOptQuantConfigBase):
|
||||
@@ -2018,7 +2018,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
|
||||
# Register the method classes for ModelOptMxFp8Config
|
||||
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
|
||||
ModelOptMxFp8Config.FusedMoEMethodCls = ModelOptMxFp8FusedMoE
|
||||
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
|
||||
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptKVCacheMethod
|
||||
|
||||
|
||||
class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
|
||||
@@ -2166,7 +2166,7 @@ class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
|
||||
# KV-cache quantization
|
||||
if isinstance(layer, Attention):
|
||||
if self.kv_cache_quant_method:
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
return ModelOptKVCacheMethod(self)
|
||||
return None
|
||||
|
||||
# Excluded layers
|
||||
|
||||
@@ -332,6 +332,7 @@ class FlashInferBackend(AttentionBackend):
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
"fp8_e5m2",
|
||||
"nvfp4",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -388,13 +389,15 @@ class FlashInferBackend(AttentionBackend):
|
||||
return stride_order
|
||||
|
||||
@staticmethod
|
||||
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
|
||||
def get_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
|
||||
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
|
||||
return torch.float8_e4m3fn
|
||||
elif kv_cache_dtype == "fp8_e5m2":
|
||||
return torch.float8_e5m2
|
||||
elif kv_cache_dtype == "nvfp4":
|
||||
return torch.uint8
|
||||
else:
|
||||
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
|
||||
raise ValueError(f"Unrecognized dtype: {kv_cache_dtype}")
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
@@ -622,9 +625,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# For NVFP4, kv_cache_dtype stays as the string "nvfp4"
|
||||
# which is passed to FlashInferImpl
|
||||
self.kv_cache_dtype = self.cache_dtype
|
||||
raise NotImplementedError("nvfp4 KV cache is not yet supported")
|
||||
else:
|
||||
self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||
self.kv_cache_dtype = FlashInferBackend.get_dtype_for_flashinfer(
|
||||
self.cache_dtype
|
||||
)
|
||||
else:
|
||||
@@ -645,7 +647,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
):
|
||||
if self.is_kvcache_nvfp4:
|
||||
# NVFP4 KV cache uses FP8 quantized queries
|
||||
self.q_data_type = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||
self.q_data_type = FlashInferBackend.get_dtype_for_flashinfer(
|
||||
"fp8_e4m3"
|
||||
)
|
||||
else:
|
||||
@@ -765,8 +767,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
dcp_a2a=self.dcp_a2a,
|
||||
)
|
||||
else:
|
||||
# NVFP4 KV cache requires the trtllm-gen backend inside
|
||||
# the wrapper; fa2/fa3 do not support nvfp4.
|
||||
backend = "trtllm-gen" if self.is_kvcache_nvfp4 else "auto"
|
||||
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
|
||||
self._get_workspace_buffer(), get_kv_cache_layout()
|
||||
self._get_workspace_buffer(),
|
||||
get_kv_cache_layout(),
|
||||
backend=backend,
|
||||
)
|
||||
assert self._prefill_wrapper is not None
|
||||
return self._prefill_wrapper
|
||||
@@ -786,6 +793,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
paged_kv_indptr = None
|
||||
paged_kv_indices = None
|
||||
paged_kv_last_page_len = None
|
||||
# NVFP4 KV cache requires the trtllm-gen backend inside
|
||||
# the wrapper; fa2/fa3 do not support nvfp4.
|
||||
backend = "trtllm-gen" if self.is_kvcache_nvfp4 else "auto"
|
||||
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||
self._get_workspace_buffer(),
|
||||
get_kv_cache_layout(),
|
||||
@@ -797,6 +807,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# at least as good as cuda cores for all attention ops in latest
|
||||
# gpus.
|
||||
use_tensor_cores=True,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
# save the decode wrapper
|
||||
@@ -1148,6 +1159,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
prefill_wrapper,
|
||||
BatchPrefillWithPagedKVCacheWrapper,
|
||||
)
|
||||
# NVFP4 trtllm kernel only supports FP8 output;
|
||||
# use FP8 o_data_type so the wrapper matches the
|
||||
# FP8 output buffer allocated in forward().
|
||||
o_dtype = (
|
||||
FP8_DTYPE if self.is_kvcache_nvfp4 else self.model_config.dtype
|
||||
)
|
||||
prefill_wrapper.plan(
|
||||
qo_indptr=qo_indptr_prefill_cpu,
|
||||
paged_kv_indptr=paged_kv_indptr_prefill_cpu,
|
||||
@@ -1163,7 +1180,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
q_data_type=self.q_data_type,
|
||||
kv_data_type=self.kv_cache_dtype,
|
||||
o_data_type=self.model_config.dtype,
|
||||
o_data_type=o_dtype,
|
||||
fixed_split_size=self.prefill_fixed_split_size,
|
||||
disable_split_kv=self.disable_split_kv,
|
||||
)
|
||||
@@ -1197,6 +1214,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# Use the persistent buffer with padding length,
|
||||
# instead of the same address but chunked version
|
||||
# in atten_metadata when using cudagraph.
|
||||
# NVFP4 trtllm kernel only supports FP8 output;
|
||||
# use FP8 o_data_type so the wrapper matches the
|
||||
# FP8 output buffer allocated in forward().
|
||||
o_dtype = (
|
||||
FP8_DTYPE if self.is_kvcache_nvfp4 else self.model_config.dtype
|
||||
)
|
||||
fast_plan_decode(
|
||||
decode_wrapper,
|
||||
indptr_cpu=self.paged_kv_indptr.cpu[: num_input_tokens + 1],
|
||||
@@ -1215,7 +1238,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
q_data_type=self.q_data_type,
|
||||
kv_data_type=self.kv_cache_dtype,
|
||||
o_data_type=self.model_config.dtype,
|
||||
o_data_type=o_dtype,
|
||||
fixed_split_size=self.decode_fixed_split_size,
|
||||
disable_split_kv=self.disable_split_kv,
|
||||
)
|
||||
@@ -1300,6 +1323,17 @@ class FlashInferImpl(AttentionImpl):
|
||||
self.bmm2_scale: float | None = None
|
||||
self.o_sf_scale: float | None = None
|
||||
|
||||
# Pre-allocated FP8 output buffer for NVFP4 without fused output quant.
|
||||
if self.is_kvcache_nvfp4 and vllm_config is not None:
|
||||
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
self._nvfp4_fp8_out = torch.empty(
|
||||
(max_num_tokens, num_heads, head_size),
|
||||
dtype=FP8_DTYPE,
|
||||
device="cuda",
|
||||
)
|
||||
else:
|
||||
self._nvfp4_fp8_out = None
|
||||
|
||||
dcp_a2a = (
|
||||
vllm_config is not None
|
||||
and vllm_config.parallel_config.decode_context_parallel_size > 1
|
||||
@@ -1420,7 +1454,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
if self.kv_sharing_target_layer_name is None and is_quantized_kv_cache(
|
||||
self.kv_cache_dtype
|
||||
):
|
||||
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||
torch_dtype = FlashInferBackend.get_dtype_for_flashinfer(
|
||||
self.kv_cache_dtype
|
||||
)
|
||||
kv_cache = kv_cache.view(torch_dtype)
|
||||
@@ -1500,13 +1534,37 @@ class FlashInferImpl(AttentionImpl):
|
||||
)
|
||||
assert prefill_wrapper._sm_scale == self.scale
|
||||
assert prefill_wrapper._causal
|
||||
|
||||
if self.is_kvcache_nvfp4:
|
||||
kv_cache_permute = nvfp4_kv_data
|
||||
kv_cache_sf = (
|
||||
nvfp4_kv_block_scales if self.is_kvcache_nvfp4 else None
|
||||
)
|
||||
|
||||
# NVFP4 trtllm kernel only supports FP8 output.
|
||||
# Use a pre-allocated FP8 buffer and dequantize
|
||||
# afterwards.
|
||||
needs_fp8_out_prefill = (
|
||||
self.is_kvcache_nvfp4 and output.dtype != FP8_DTYPE
|
||||
)
|
||||
if needs_fp8_out_prefill:
|
||||
out_prefill = self._nvfp4_fp8_out[:num_prefill_tokens]
|
||||
else:
|
||||
out_prefill = output[num_decode_tokens:]
|
||||
|
||||
prefill_wrapper.run(
|
||||
prefill_query,
|
||||
kv_cache_permute,
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
out=output[num_decode_tokens:],
|
||||
out=out_prefill,
|
||||
kv_cache_sf=kv_cache_sf,
|
||||
)
|
||||
|
||||
if needs_fp8_out_prefill:
|
||||
output[
|
||||
num_decode_tokens : num_decode_tokens + num_prefill_tokens
|
||||
].copy_(out_prefill.to(output.dtype))
|
||||
else:
|
||||
assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
|
||||
# prefill_query may be non-contiguous or have degenerate strides
|
||||
@@ -1537,6 +1595,12 @@ class FlashInferImpl(AttentionImpl):
|
||||
assert self.o_sf_scale is None
|
||||
out = output[num_decode_tokens:]
|
||||
|
||||
# NVFP4 trtllm kernel only supports FP8 output.
|
||||
# Use a pre-allocated FP8 buffer and dequantize afterwards.
|
||||
needs_fp8_out = self.is_kvcache_nvfp4 and output.dtype != FP8_DTYPE
|
||||
if needs_fp8_out:
|
||||
out = self._nvfp4_fp8_out[:num_prefill_tokens]
|
||||
|
||||
prefill_kv_block_scales = None
|
||||
if self.is_kvcache_nvfp4:
|
||||
# NVFP4 trtllm-gen kernel requires FP8 query.
|
||||
@@ -1547,7 +1611,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
)
|
||||
mock_kv_cache = nvfp4_kv_data
|
||||
mock_block_table = block_tables_prefill
|
||||
prefill_kv_block_scales = nvfp4_kv_block_scales # noqa: F841
|
||||
prefill_kv_block_scales = nvfp4_kv_block_scales
|
||||
elif (
|
||||
attn_metadata.q_data_type != FP8_DTYPE
|
||||
and self.kv_cache_dtype.startswith("fp8")
|
||||
@@ -1598,8 +1662,14 @@ class FlashInferImpl(AttentionImpl):
|
||||
sinks=self.sinks,
|
||||
o_sf_scale=self.o_sf_scale,
|
||||
out=out,
|
||||
kv_cache_sf=prefill_kv_block_scales,
|
||||
)
|
||||
|
||||
if needs_fp8_out:
|
||||
output[
|
||||
num_decode_tokens : num_decode_tokens + num_prefill_tokens
|
||||
].copy_(out[:num_prefill_tokens].to(output.dtype))
|
||||
|
||||
if num_decode_tokens > 0:
|
||||
decode_query = query[:num_decode_tokens]
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
@@ -1612,6 +1682,18 @@ class FlashInferImpl(AttentionImpl):
|
||||
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
|
||||
assert decode_wrapper._sm_scale == self.scale
|
||||
|
||||
if self.is_kvcache_nvfp4:
|
||||
kv_cache_permute = nvfp4_kv_data
|
||||
kv_cache_sf = nvfp4_kv_block_scales if self.is_kvcache_nvfp4 else None
|
||||
|
||||
# NVFP4 kernel only supports FP8 output.
|
||||
# Use a pre-allocated FP8 buffer and dequantize afterwards.
|
||||
needs_fp8_out = self.is_kvcache_nvfp4 and output.dtype != FP8_DTYPE
|
||||
if needs_fp8_out:
|
||||
out_decode = self._nvfp4_fp8_out[:num_decode_tokens]
|
||||
else:
|
||||
out_decode = output[:num_decode_tokens]
|
||||
|
||||
if use_dcp:
|
||||
decode_query = get_dcp_group().all_gather(
|
||||
decode_query.contiguous(), dim=-2
|
||||
@@ -1630,6 +1712,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
out=output_tmp,
|
||||
lse=lse,
|
||||
return_lse=True,
|
||||
kv_cache_sf=kv_cache_sf,
|
||||
)
|
||||
output[:num_decode_tokens] = self.dcp_combine(
|
||||
output_tmp,
|
||||
@@ -1642,8 +1725,12 @@ class FlashInferImpl(AttentionImpl):
|
||||
kv_cache_permute,
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
out=output[:num_decode_tokens],
|
||||
out=out_decode,
|
||||
kv_cache_sf=kv_cache_sf,
|
||||
)
|
||||
|
||||
if needs_fp8_out:
|
||||
output[:num_decode_tokens].copy_(out_decode.to(output.dtype))
|
||||
else:
|
||||
# decode_query may be non-contiguous or have degenerate strides
|
||||
assert isinstance(attn_metadata.decode, TRTLLMDecode)
|
||||
@@ -1686,6 +1773,12 @@ class FlashInferImpl(AttentionImpl):
|
||||
assert self.o_sf_scale is None
|
||||
out = output[:num_decode_tokens]
|
||||
|
||||
# NVFP4 trtllm kernel only supports FP8 output.
|
||||
# Use a pre-allocated FP8 buffer and dequantize afterwards.
|
||||
needs_fp8_out = self.is_kvcache_nvfp4 and output.dtype != FP8_DTYPE
|
||||
if needs_fp8_out:
|
||||
out = self._nvfp4_fp8_out[:num_decode_tokens]
|
||||
|
||||
if num_decode_tokens % attn_metadata.num_decodes != 0:
|
||||
# This gets triggered when the dummy_run forces
|
||||
# attention to be initialized with q_len = 0
|
||||
@@ -1695,9 +1788,9 @@ class FlashInferImpl(AttentionImpl):
|
||||
|
||||
trtllm_batch_decode_with_kv_cache(
|
||||
query=decode_query,
|
||||
kv_cache=nvfp4_kv_data
|
||||
if self.is_kvcache_nvfp4
|
||||
else kv_cache_permute,
|
||||
kv_cache=(
|
||||
nvfp4_kv_data if self.is_kvcache_nvfp4 else kv_cache_permute
|
||||
),
|
||||
workspace_buffer=workspace_buffer,
|
||||
block_tables=block_tables_decode,
|
||||
seq_lens=seq_lens_decode,
|
||||
@@ -1709,7 +1802,13 @@ class FlashInferImpl(AttentionImpl):
|
||||
o_sf_scale=self.o_sf_scale,
|
||||
out=out,
|
||||
q_len_per_req=q_len_per_req,
|
||||
kv_cache_sf=(
|
||||
nvfp4_kv_block_scales if self.is_kvcache_nvfp4 else None
|
||||
),
|
||||
)
|
||||
|
||||
if needs_fp8_out:
|
||||
output[:num_decode_tokens].copy_(out.to(output.dtype))
|
||||
return output_padded
|
||||
|
||||
def do_kv_cache_update(
|
||||
|
||||
@@ -151,6 +151,16 @@ class AttentionSpec(KVCacheSpec):
|
||||
|
||||
@property
|
||||
def real_page_size_bytes(self) -> int:
|
||||
if self.kv_quant_mode.is_nvfp4:
|
||||
# Packed layout: fp4 data + fp8 block scales per head.
|
||||
full_dim = nvfp4_kv_cache_full_dim(self.head_size)
|
||||
return (
|
||||
2
|
||||
* self.block_size
|
||||
* self.num_kv_heads
|
||||
* full_dim
|
||||
* get_dtype_size(self.dtype)
|
||||
)
|
||||
return (
|
||||
2
|
||||
* self.block_size
|
||||
|
||||
Reference in New Issue
Block a user