Add nvfp4 kv cache support (#40177)

Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
This commit is contained in:
sychen52
2026-04-30 21:55:16 -07:00
committed by GitHub
parent 941fb50835
commit 947138b6c2
8 changed files with 503 additions and 96 deletions
@@ -788,6 +788,11 @@ def parse_flashinfer_trtllm_features() -> dict[str, dict[str, Any]]:
if not trtllm_compute_cap:
return {}
# KV cache dtypes that only work with a dedicated kernel (e.g. nvfp4
# requires the SM100 NVFP4 MHA kernel) and should not appear in the
# generic attention-backend feature matrix.
kernel_only_kv_dtypes = ["nvfp4"]
return {
"native": {
# Native FlashInfer: everything except SM100
@@ -798,6 +803,7 @@ def parse_flashinfer_trtllm_features() -> dict[str, dict[str, Any]]:
"compute_capability": trtllm_compute_cap,
"supports_sink": True,
},
"exclude_kv_dtypes": kernel_only_kv_dtypes,
}
@@ -963,6 +969,15 @@ def _expand_flashinfer_variants(
native["supports_sink"] = fi_features["native"]["supports_sink"]
native["compute_capability"] = f"{min_cc}.x-9.x"
# Remove KV dtypes only supported by SM100 kernels (e.g. nvfp4)
exclude = fi_features.get("exclude_kv_dtypes", [])
if exclude:
native["kv_cache_dtypes"] = ", ".join(
d
for d in (d.strip() for d in native["kv_cache_dtypes"].split(","))
if d not in exclude
)
# Create TRTLLM entry
trtllm = backend.copy()
trtllm["version"] = "TRTLLM†"