mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Kernel] Batch invariant NVFP4 linear using cutlass (#39912)
Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
@@ -367,6 +367,7 @@ steps:
|
||||
- VLLM_TEST_MODEL=deepseek-ai/DeepSeek-V2-Lite-Chat pytest -v -s v1/determinism/test_batch_invariance.py::test_v1_generation_is_deterministic_across_batch_sizes_with_needle[TRITON_MLA]
|
||||
- VLLM_TEST_MODEL=Qwen/Qwen3-30B-A3B-Thinking-2507-FP8 pytest -v -s v1/determinism/test_batch_invariance.py::test_v1_generation_is_deterministic_across_batch_sizes_with_needle[FLASH_ATTN]
|
||||
- pytest -v -s v1/determinism/test_nvfp4_batch_invariant.py
|
||||
- pytest -v -s v1/determinism/test_nvfp4_batch_invariant_scaled_mm.py
|
||||
|
||||
- label: Acceptance Length Test (Large Models) # optional
|
||||
device: h200_35gb
|
||||
|
||||
@@ -22,6 +22,8 @@
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
@@ -30,15 +32,21 @@
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "core/batch_invariant.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// Configuration for M in (256, inf)
|
||||
// Configuration for M in (256, inf), also reused for batch-invariant mode
|
||||
// to keep a fixed large-M tiling across all batch sizes.
|
||||
// Do not change the tile K or tile scheduler here unless you are also
|
||||
// updating the batch-invariant behavior; if batch-invariant mode needs a
|
||||
// different schedule, add a dedicated batch-invariant config/path instead.
|
||||
struct sm100_fp4_config_default {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
using TileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||
@@ -48,6 +56,7 @@ struct sm100_fp4_config_default {
|
||||
struct sm100_fp4_config_M256 {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileScheduler = void;
|
||||
using TileShape = Shape<_256, _128, _256>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
|
||||
@@ -57,6 +66,7 @@ struct sm100_fp4_config_M256 {
|
||||
struct sm100_fp4_config_M16 {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileScheduler = void;
|
||||
using TileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
|
||||
@@ -98,7 +108,7 @@ struct Fp4GemmSm100 {
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
||||
ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD,
|
||||
LayoutDTag, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
|
||||
typename Config::EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
@@ -107,10 +117,13 @@ struct Fp4GemmSm100 {
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
|
||||
typename Config::KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
using TileScheduler = typename Config::TileScheduler;
|
||||
using GemmKernel =
|
||||
cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue, TileScheduler>;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{}));
|
||||
@@ -205,6 +218,17 @@ void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int64_t m,
|
||||
int64_t n, int64_t k, cudaStream_t stream) {
|
||||
if (vllm::vllm_is_batch_invariant()) {
|
||||
using BiGemm = Fp4GemmSm100<sm100_fp4_config_default, OutType>;
|
||||
static_assert(
|
||||
cute::is_same_v<typename BiGemm::TileScheduler,
|
||||
cutlass::gemm::PersistentScheduler>,
|
||||
"batch_invariant requires a persistent tile scheduler; stream-K or "
|
||||
"split-K would break numerical invariance");
|
||||
runGemm<BiGemm>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||
|
||||
if (mp2 <= 16) {
|
||||
|
||||
@@ -22,6 +22,8 @@
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
@@ -30,6 +32,7 @@
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "core/batch_invariant.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@@ -49,12 +52,22 @@ constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
|
||||
|
||||
struct sm120_fp4_config_M256 {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileScheduler = void;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using MmaTileShape = Shape<_128, _128, _128>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _128, _128>;
|
||||
};
|
||||
|
||||
struct sm120_fp4_config_default {
|
||||
// Also used for batch-invariant mode.
|
||||
// Do not change the tile K or tile scheduler here unless you are also
|
||||
// updating the batch-invariant behavior; if batch-invariant mode needs a
|
||||
// different schedule, add a dedicated batch-invariant config/path instead.
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using MmaTileShape = Shape<_256, _128, _128>;
|
||||
using PerSmTileShape_MNK = Shape<_256, _128, _128>;
|
||||
@@ -91,7 +104,7 @@ struct Fp4GemmSm120 {
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
||||
ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD,
|
||||
LayoutDTag, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
|
||||
typename Config::EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
@@ -100,10 +113,13 @@ struct Fp4GemmSm120 {
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
|
||||
typename Config::KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
using TileScheduler = typename Config::TileScheduler;
|
||||
using GemmKernel =
|
||||
cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue, TileScheduler>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
};
|
||||
@@ -180,39 +196,41 @@ void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
|
||||
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int m,
|
||||
int n, int k, cudaStream_t stream) {
|
||||
namespace {
|
||||
|
||||
// Dispatch function to select appropriate config based on M (file-local;
|
||||
// internal linkage avoids clashing with SM100's cutlass_fp4_gemm_dispatch in
|
||||
// nvfp4_scaled_mm_kernels.cu).
|
||||
template <typename OutType>
|
||||
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int m, int n,
|
||||
int k, cudaStream_t stream) {
|
||||
if (vllm::vllm_is_batch_invariant()) {
|
||||
using BiGemm = Fp4GemmSm120<sm120_fp4_config_default, OutType>;
|
||||
static_assert(
|
||||
cute::is_same_v<typename BiGemm::TileScheduler,
|
||||
cutlass::gemm::PersistentScheduler>,
|
||||
"batch_invariant requires a persistent tile scheduler; stream-K or "
|
||||
"split-K would break numerical invariance");
|
||||
runGemm<typename BiGemm::Gemm>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||
if (mp2 <= 256) {
|
||||
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
|
||||
runGemm<typename Fp4GemmSm120<sm120_fp4_config_M256, OutType>::Gemm>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else {
|
||||
runGemm<Fp4GemmSm120<sm120_fp4_config_default, cutlass::bfloat16_t>::Gemm>(
|
||||
runGemm<typename Fp4GemmSm120<sm120_fp4_config_default, OutType>::Gemm>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int m,
|
||||
int n, int k, cudaStream_t stream) {
|
||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||
if (mp2 <= 256) {
|
||||
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else {
|
||||
runGemm<Fp4GemmSm120<sm120_fp4_config_default, cutlass::half_t>::Gemm>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
@@ -275,11 +293,11 @@ void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
|
||||
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
|
||||
|
||||
if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
|
||||
stream);
|
||||
return cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else if (out_dtype == torch::headeronly::ScalarType::Half) {
|
||||
return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
|
||||
stream);
|
||||
return cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf,
|
||||
alpha, m, n, k, stream);
|
||||
} else {
|
||||
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
|
||||
out_dtype, ")");
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""NVFP4 CUTLASS GEMM tests that require ``VLLM_BATCH_INVARIANT=1``.
|
||||
|
||||
Must run in a **fresh** pytest process:
|
||||
|
||||
pytest tests/v1/determinism/test_nvfp4_batch_invariant_scaled_mm.py -v
|
||||
|
||||
Do not share a session with ``tests/kernels/quantization/test_nvfp4_scaled_mm.py``:
|
||||
the native code caches whether batch invariance is enabled on the first GEMM, and
|
||||
if ``VLLM_BATCH_INVARIANT`` was not set at that moment, it stays disabled for the
|
||||
rest of the process.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quantization.nvfp4_utils import get_nvfp4_global_scale
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
|
||||
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
|
||||
SHAPES.extend(PAD_SHAPES)
|
||||
|
||||
|
||||
CONSISTENCY_SHAPES = [
|
||||
(256, 128, 4096),
|
||||
(512, 256, 4096),
|
||||
(256, 256, 2048),
|
||||
(241, 160, 2048),
|
||||
(401, 352, 1984),
|
||||
(333, 320, 1008),
|
||||
(287, 96, 4096),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", CONSISTENCY_SHAPES)
@torch.inference_mode()
def test_nvfp4_gemm_batch_invariance(
    dtype: torch.dtype,
    shape: tuple[int, int, int],
) -> None:
    """Check that every output row is bit-identical whether computed batched or alone.

    ``cutlass_scaled_fp4_mm`` is run once on the whole ``(M, K)`` activation
    matrix and then once per row on the one-row slice ``[i : i + 1]``. Row
    ``i`` of the batched result must equal the single-row result exactly
    (``torch.equal``), so any dependence of the kernel's reduction or
    scheduling on ``M`` or on neighbouring rows shows up as a failure.
    """
    set_random_seed(int(os.getenv("VLLM_TEST_SEED", "12345")))

    m, n, packed_k = shape
    full_k = 2 * packed_k  # real K (FP4 elements)

    activations = torch.randn((m, full_k), dtype=dtype, device="cuda")
    weights = torch.randn((n, full_k), dtype=dtype, device="cuda")

    act_global_scale = get_nvfp4_global_scale(activations)
    weight_global_scale = get_nvfp4_global_scale(weights)
    alpha = 1.0 / (act_global_scale * weight_global_scale)

    weight_fp4, weight_sf = ops.scaled_fp4_quant(weights, weight_global_scale)

    def quantize_and_matmul(rows: torch.Tensor) -> torch.Tensor:
        # The global scale derived from the full activation matrix is reused
        # for every call, including the one-row slices.
        rows_fp4, rows_sf = ops.scaled_fp4_quant(rows, act_global_scale)
        return ops.cutlass_scaled_fp4_mm(
            rows_fp4,
            weight_fp4,
            rows_sf,
            weight_sf,
            alpha,
            dtype,
        )

    out_full = quantize_and_matmul(activations)

    for i in range(m):
        out_row = quantize_and_matmul(activations[i : i + 1])

        assert torch.equal(out_full[i], out_row[0]), (
            f"VLLM_BATCH_INVARIANT: row {i} differs between M={m} and M=1: "
            f"max_abs_diff={(out_full[i] - out_row[0]).abs().max().item()}"
        )
|
||||
@@ -833,24 +833,41 @@ def init_nvfp4_linear_kernel() -> NvFp4LinearKernel:
|
||||
current platform."""
|
||||
config = NvFp4LinearLayerConfig()
|
||||
|
||||
# VLLM_BATCH_INVARIANT unconditionally forces emulation for deterministic
|
||||
# execution. It overrides both --linear-backend and the deprecated env
|
||||
# vars below.
|
||||
# VLLM_BATCH_INVARIANT forces deterministic execution. Prefer the
|
||||
# batch-invariant CUTLASS implementation when available, otherwise fall
|
||||
# back to emulation. It overrides both --linear-backend and the deprecated
|
||||
# env vars below.
|
||||
force_kernel: type[NvFp4LinearKernel] | None = None
|
||||
linear_backend = _get_linear_backend()
|
||||
if envs.VLLM_BATCH_INVARIANT:
|
||||
if linear_backend not in ("auto", "emulation"):
|
||||
logger.warning_once(
|
||||
"VLLM_BATCH_INVARIANT overrides --linear-backend=%s; using "
|
||||
"the emulation backend for deterministic execution.",
|
||||
linear_backend,
|
||||
)
|
||||
bi_supported, reason = CutlassNvFp4LinearKernel.is_supported()
|
||||
if bi_supported:
|
||||
if linear_backend not in ("auto", "cutlass"):
|
||||
logger.warning_once(
|
||||
"VLLM_BATCH_INVARIANT overrides --linear-backend=%s; "
|
||||
"using the CUTLASS backend for deterministic execution.",
|
||||
linear_backend,
|
||||
)
|
||||
else:
|
||||
logger.info_once(
|
||||
"VLLM_BATCH_INVARIANT forces NVFP4 linear to use the "
|
||||
"CUTLASS backend for deterministic execution."
|
||||
)
|
||||
force_kernel = CutlassNvFp4LinearKernel
|
||||
else:
|
||||
if linear_backend not in ("auto", "emulation"):
|
||||
logger.warning_once(
|
||||
"VLLM_BATCH_INVARIANT overrides --linear-backend=%s; "
|
||||
"using the emulation backend for deterministic execution.",
|
||||
linear_backend,
|
||||
)
|
||||
logger.info_once(
|
||||
"VLLM_BATCH_INVARIANT forces NVFP4 linear to use the "
|
||||
"emulation backend for deterministic execution."
|
||||
"VLLM_BATCH_INVARIANT is set but the batch-invariant NVFP4 "
|
||||
"kernel is not supported on this platform; falling back to "
|
||||
"emulation for deterministic execution. Reason: %s",
|
||||
reason,
|
||||
)
|
||||
force_kernel = EmulationNvFp4LinearKernel
|
||||
force_kernel = EmulationNvFp4LinearKernel
|
||||
elif linear_backend == "auto":
|
||||
# Deprecated env-var overrides — only honoured when --linear-backend
|
||||
# is "auto". Deprecation warnings are emitted from vllm/envs.py.
|
||||
|
||||
Reference in New Issue
Block a user