[Kernel] Batch invariant NVFP4 linear using cutlass (#39912)

Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
Jakub Zakrzewski
2026-05-23 15:41:12 +02:00
committed by GitHub
parent 3f3e862681
commit 5bb8d2767a
5 changed files with 211 additions and 50 deletions
+1
View File
@@ -367,6 +367,7 @@ steps:
- VLLM_TEST_MODEL=deepseek-ai/DeepSeek-V2-Lite-Chat pytest -v -s v1/determinism/test_batch_invariance.py::test_v1_generation_is_deterministic_across_batch_sizes_with_needle[TRITON_MLA]
- VLLM_TEST_MODEL=Qwen/Qwen3-30B-A3B-Thinking-2507-FP8 pytest -v -s v1/determinism/test_batch_invariance.py::test_v1_generation_is_deterministic_across_batch_sizes_with_needle[FLASH_ATTN]
- pytest -v -s v1/determinism/test_nvfp4_batch_invariant.py
- pytest -v -s v1/determinism/test_nvfp4_batch_invariant_scaled_mm.py
- label: Acceptance Length Test (Large Models) # optional
device: h200_35gb
@@ -22,6 +22,8 @@
#include "cutlass/cutlass.h"
#include <type_traits>
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
@@ -30,15 +32,21 @@
#include "cutlass/util/packed_stride.hpp"
#include "core/math.hpp"
#include "core/batch_invariant.hpp"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// Configuration for M in (256, inf)
// Configuration for M in (256, inf), also reused for batch-invariant mode
// to keep a fixed large-M tiling across all batch sizes.
// Do not change the tile K or tile scheduler here unless you are also
// updating the batch-invariant behavior; if batch-invariant mode needs a
// different schedule, add a dedicated batch-invariant config/path instead.
struct sm100_fp4_config_default {
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using TileShape = Shape<_256, _256, _256>;
using ClusterShape = Shape<_2, _1, _1>;
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
@@ -48,6 +56,7 @@ struct sm100_fp4_config_default {
struct sm100_fp4_config_M256 {
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileScheduler = void;
using TileShape = Shape<_256, _128, _256>;
using ClusterShape = Shape<_2, _1, _1>;
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
@@ -57,6 +66,7 @@ struct sm100_fp4_config_M256 {
struct sm100_fp4_config_M16 {
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileScheduler = void;
using TileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_1, _1, _1>;
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
@@ -98,7 +108,7 @@ struct Fp4GemmSm100 {
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD,
LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
typename Config::EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
@@ -107,10 +117,13 @@ struct Fp4GemmSm100 {
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
typename Config::KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
using TileScheduler = typename Config::TileScheduler;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,
CollectiveMainloop,
CollectiveEpilogue, TileScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{}));
@@ -205,6 +218,17 @@ void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int64_t m,
int64_t n, int64_t k, cudaStream_t stream) {
if (vllm::vllm_is_batch_invariant()) {
using BiGemm = Fp4GemmSm100<sm100_fp4_config_default, OutType>;
static_assert(
cute::is_same_v<typename BiGemm::TileScheduler,
cutlass::gemm::PersistentScheduler>,
"batch_invariant requires a persistent tile scheduler; stream-K or "
"split-K would break numerical invariance");
runGemm<BiGemm>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
return;
}
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 16) {
@@ -22,6 +22,8 @@
#include "cutlass/cutlass.h"
#include <type_traits>
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
@@ -30,6 +32,7 @@
#include "cutlass/util/packed_stride.hpp"
#include "core/math.hpp"
#include "core/batch_invariant.hpp"
using namespace cute;
@@ -49,12 +52,22 @@ constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
struct sm120_fp4_config_M256 {
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileScheduler = void;
using ClusterShape = Shape<_1, _1, _1>;
using MmaTileShape = Shape<_128, _128, _128>;
using PerSmTileShape_MNK = Shape<_128, _128, _128>;
};
struct sm120_fp4_config_default {
// Also used for batch-invariant mode.
// Do not change the tile K or tile scheduler here unless you are also
// updating the batch-invariant behavior; if batch-invariant mode needs a
// different schedule, add a dedicated batch-invariant config/path instead.
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using ClusterShape = Shape<_1, _1, _1>;
using MmaTileShape = Shape<_256, _128, _128>;
using PerSmTileShape_MNK = Shape<_256, _128, _128>;
@@ -91,7 +104,7 @@ struct Fp4GemmSm120 {
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD,
LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
typename Config::EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
@@ -100,10 +113,13 @@ struct Fp4GemmSm120 {
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
typename Config::KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
using TileScheduler = typename Config::TileScheduler;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,
CollectiveMainloop,
CollectiveEpilogue, TileScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
@@ -180,39 +196,41 @@ void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
namespace {
// Dispatch function to select appropriate config based on M (file-local;
// internal linkage avoids clashing with SM100's cutlass_fp4_gemm_dispatch in
// nvfp4_scaled_mm_kernels.cu).
template <typename OutType>
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int m, int n,
int k, cudaStream_t stream) {
if (vllm::vllm_is_batch_invariant()) {
using BiGemm = Fp4GemmSm120<sm120_fp4_config_default, OutType>;
static_assert(
cute::is_same_v<typename BiGemm::TileScheduler,
cutlass::gemm::PersistentScheduler>,
"batch_invariant requires a persistent tile scheduler; stream-K or "
"split-K would break numerical invariance");
runGemm<typename BiGemm::Gemm>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
return;
}
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) {
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
runGemm<typename Fp4GemmSm120<sm120_fp4_config_M256, OutType>::Gemm>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
} else {
runGemm<Fp4GemmSm120<sm120_fp4_config_default, cutlass::bfloat16_t>::Gemm>(
runGemm<typename Fp4GemmSm120<sm120_fp4_config_default, OutType>::Gemm>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
}
}
void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) {
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
} else {
runGemm<Fp4GemmSm120<sm120_fp4_config_default, cutlass::half_t>::Gemm>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
}
}
} // namespace
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
@@ -275,11 +293,11 @@ void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
stream);
return cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
} else if (out_dtype == torch::headeronly::ScalarType::Half) {
return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
stream);
return cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf,
alpha, m, n, k, stream);
} else {
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
out_dtype, ")");
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""NVFP4 CUTLASS GEMM tests that require ``VLLM_BATCH_INVARIANT=1``.
Must run in a **fresh** pytest process:
pytest tests/v1/determinism/test_nvfp4_batch_invariant_scaled_mm.py -v
Do not share a session with ``tests/kernels/quantization/test_nvfp4_scaled_mm.py``:
the native code caches whether batch invariance is enabled on the first GEMM, and
if ``VLLM_BATCH_INVARIANT`` was not set at that moment, it stays disabled for the
rest of the process.
"""
import os
import pytest
import torch
from tests.kernels.quantization.nvfp4_utils import get_nvfp4_global_scale
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
if not current_platform.has_device_capability(100):
pytest.skip(
reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True,
)
DTYPES = [torch.float16, torch.bfloat16]
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)
CONSISTENCY_SHAPES = [
(256, 128, 4096),
(512, 256, 4096),
(256, 256, 2048),
(241, 160, 2048),
(401, 352, 1984),
(333, 320, 1008),
(287, 96, 4096),
]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", CONSISTENCY_SHAPES)
@torch.inference_mode()
def test_nvfp4_gemm_batch_invariance(
dtype: torch.dtype,
shape: tuple[int, int, int],
) -> None:
"""Batch invariance: each row of a full-``M`` GEMM matches its ``M=1`` counterpart.
For row ``i``, compares ``cutlass_scaled_fp4_mm`` run once over all ``M``
rows against a separate call with ``A`` sliced to ``a_dtype[i : i+1]``.
Catches kernels whose reduction or scheduling depends on ``M`` or adjacent
rows.
"""
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
set_random_seed(seed)
m, n, packed_k = shape
k = packed_k * 2 # real K (FP4 elements)
a_dtype = torch.randn((m, k), dtype=dtype, device="cuda")
b_dtype = torch.randn((n, k), dtype=dtype, device="cuda")
a_global_scale = get_nvfp4_global_scale(a_dtype)
b_global_scale = get_nvfp4_global_scale(b_dtype)
alpha = 1.0 / (a_global_scale * b_global_scale)
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
a_fp4_full, a_sf_full = ops.scaled_fp4_quant(a_dtype, a_global_scale)
out_full = ops.cutlass_scaled_fp4_mm(
a_fp4_full,
b_fp4,
a_sf_full,
b_scale_interleaved,
alpha,
dtype,
)
for i in range(m):
a_row = a_dtype[i : i + 1]
a_fp4_row, a_sf_row = ops.scaled_fp4_quant(a_row, a_global_scale)
out_row = ops.cutlass_scaled_fp4_mm(
a_fp4_row,
b_fp4,
a_sf_row,
b_scale_interleaved,
alpha,
dtype,
)
assert torch.equal(out_full[i], out_row[0]), (
f"VLLM_BATCH_INVARIANT: row {i} differs between M={m} and M=1: "
f"max_abs_diff={(out_full[i] - out_row[0]).abs().max().item()}"
)
+29 -12
View File
@@ -833,24 +833,41 @@ def init_nvfp4_linear_kernel() -> NvFp4LinearKernel:
current platform."""
config = NvFp4LinearLayerConfig()
# VLLM_BATCH_INVARIANT unconditionally forces emulation for deterministic
# execution. It overrides both --linear-backend and the deprecated env
# vars below.
# VLLM_BATCH_INVARIANT forces deterministic execution. Prefer the
# batch-invariant CUTLASS implementation when available, otherwise fall
# back to emulation. It overrides both --linear-backend and the deprecated
# env vars below.
force_kernel: type[NvFp4LinearKernel] | None = None
linear_backend = _get_linear_backend()
if envs.VLLM_BATCH_INVARIANT:
if linear_backend not in ("auto", "emulation"):
logger.warning_once(
"VLLM_BATCH_INVARIANT overrides --linear-backend=%s; using "
"the emulation backend for deterministic execution.",
linear_backend,
)
bi_supported, reason = CutlassNvFp4LinearKernel.is_supported()
if bi_supported:
if linear_backend not in ("auto", "cutlass"):
logger.warning_once(
"VLLM_BATCH_INVARIANT overrides --linear-backend=%s; "
"using the CUTLASS backend for deterministic execution.",
linear_backend,
)
else:
logger.info_once(
"VLLM_BATCH_INVARIANT forces NVFP4 linear to use the "
"CUTLASS backend for deterministic execution."
)
force_kernel = CutlassNvFp4LinearKernel
else:
if linear_backend not in ("auto", "emulation"):
logger.warning_once(
"VLLM_BATCH_INVARIANT overrides --linear-backend=%s; "
"using the emulation backend for deterministic execution.",
linear_backend,
)
logger.info_once(
"VLLM_BATCH_INVARIANT forces NVFP4 linear to use the "
"emulation backend for deterministic execution."
"VLLM_BATCH_INVARIANT is set but the batch-invariant NVFP4 "
"kernel is not supported on this platform; falling back to "
"emulation for deterministic execution. Reason: %s",
reason,
)
force_kernel = EmulationNvFp4LinearKernel
force_kernel = EmulationNvFp4LinearKernel
elif linear_backend == "auto":
# Deprecated env-var overrides — only honoured when --linear-backend
# is "auto". Deprecation warnings are emitted from vllm/envs.py.