[PERF]MiniMax-M2 gate kernel (#38445)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: qianlihuang <91178480+qianlihuang@users.noreply.github.com>
Co-authored-by: Yiliu Dong <91178480+qianlihuang@users.noreply.github.com>
This commit is contained in:
Jee Jee Li
2026-05-30 09:28:34 +08:00
committed by GitHub
parent 187457a952
commit 559d6710bf
10 changed files with 716 additions and 23 deletions
+24 -10
View File
@@ -683,6 +683,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"in CUDA target architectures.")
endif()
# FP32 router GEMM (H=3072, E=256, M<=32). Requires SM90+ and CUDA >= 12.0.
# The SM90+ subset of CUDA_ARCHS is computed first; the kernel sources are only
# appended to VLLM_STABLE_EXT_SRC when that subset is non-empty and the CUDA
# compiler is new enough.
cuda_archs_sm90plus(FP32_ROUTER_GEMM_ARCHS "${CUDA_ARCHS}")
# The compiler version is quoted so that an empty/unset value simply compares
# as "too old" instead of producing a malformed if() expression.
if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_GREATER_EQUAL 12.0 AND FP32_ROUTER_GEMM_ARCHS)
  set(SRCS
    "csrc/libtorch_stable/fp32_router_gemm_entry.cu"
    "csrc/libtorch_stable/fp32_router_gemm.cu")
  set_gencode_flags_for_srcs(
    SRCS "${SRCS}"
    CUDA_ARCHS "${FP32_ROUTER_GEMM_ARCHS}")
  list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
  message(STATUS "Building fp32_router_gemm for archs: ${FP32_ROUTER_GEMM_ARCHS}")
else()
  message(STATUS "Not building fp32_router_gemm as no compatible archs found "
                 "(requires SM90+ and CUDA >= 12.0).")
endif()
# Only build AllSpark kernels if we are building for at least some compatible archs.
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
if (ALLSPARK_ARCHS)
@@ -1240,24 +1256,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()
# DeepSeek V3 router GEMM kernel - requires SM90+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_ROUTER_GEMM_ARCHS)
# DeepSeek V3 router GEMM kernel requires SM90+ and CUDA >= 12.0.
# (fp32_router_gemm has been migrated to _C_stable_libtorch above.)
cuda_archs_sm90plus(SM90PLUS_ROUTER_GEMM_ARCHS "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SM90PLUS_ROUTER_GEMM_ARCHS)
set(DSV3_ROUTER_GEMM_SRC
"csrc/moe/dsv3_router_gemm_entry.cu"
"csrc/moe/dsv3_router_gemm_float_out.cu"
"csrc/moe/dsv3_router_gemm_bf16_out.cu")
set_gencode_flags_for_srcs(
SRCS "${DSV3_ROUTER_GEMM_SRC}"
CUDA_ARCHS "${DSV3_ROUTER_GEMM_ARCHS}")
CUDA_ARCHS "${SM90PLUS_ROUTER_GEMM_ARCHS}")
list(APPEND VLLM_MOE_EXT_SRC "${DSV3_ROUTER_GEMM_SRC}")
message(STATUS "Building DSV3 router GEMM kernel for archs: ${DSV3_ROUTER_GEMM_ARCHS}")
message(STATUS "Building DSV3 router GEMM kernels for archs: ${SM90PLUS_ROUTER_GEMM_ARCHS}")
else()
message(STATUS "Not building DSV3 router GEMM kernel as no compatible archs found"
message(STATUS "Not building DSV3 router GEMM kernels as no compatible archs found"
" (requires SM90+ and CUDA >= 12.0)")
endif()
endif()
+154
View File
@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
# Dimensions supported by the DSV3 specialized kernel
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
# Dimensions supported by the gpt-oss specialized kernel
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
# Dimensions supported by the fp32 specialized kernel (MiniMax-M2)
FP32_SUPPORTED_NUM_EXPERTS = [256]
FP32_SUPPORTED_HIDDEN_SIZES = [3072]
FP32_MAX_TOKENS = 32
def get_batch_size_range(max_batch_size):
    """Return the powers of two from 1 to 2**13 that do not exceed ``max_batch_size``."""
    sizes = []
    for exponent in range(14):
        candidate = 2**exponent
        if candidate <= max_batch_size:
            sizes.append(candidate)
    return sizes
def get_model_params(config):
    """Extract ``(num_experts, hidden_size)`` from a model config.

    The first entry of ``config.architectures`` selects which attribute holds
    the expert count: ``n_routed_experts`` for the DeepSeek architectures and
    ``num_local_experts`` for GPT-OSS and MiniMax-M2.

    Raises:
        ValueError: if the architecture is not one of the supported names.
    """
    arch = config.architectures[0]
    routed_expert_archs = (
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM",
        "DeepseekV32ForCausalLM",
    )
    local_expert_archs = ("GptOssForCausalLM", "MiniMaxM2ForCausalLM")

    if arch in routed_expert_archs:
        return config.n_routed_experts, config.hidden_size
    if arch in local_expert_archs:
        return config.num_local_experts, config.hidden_size
    raise ValueError(f"Unsupported architecture: {config.architectures}")
def get_benchmark(model, max_batch_size, trust_remote_code):
    """Build a Triton perf-report benchmark comparing router GEMM providers.

    Args:
        model: model name/path passed to ``get_config``.
        max_batch_size: upper bound for the swept batch sizes (powers of two).
        trust_remote_code: forwarded to ``get_config``.

    Returns:
        The decorated ``benchmark`` callable; call ``.run(...)`` on it to
        execute the sweep. Each measurement reports TFLOPs for the "torch"
        (``F.linear``) and "vllm" (specialized kernel or fallback) providers.
    """
    # Everything in this section is independent of batch size and provider, so
    # it is resolved once rather than on every benchmark invocation.
    config = get_config(model=model, trust_remote_code=trust_remote_code)
    num_experts, hidden_size = get_model_params(config)

    is_hopper_or_blackwell = current_platform.is_device_capability(
        90
    ) or current_platform.is_device_capability_family(100)
    allow_dsv3_router_gemm = (
        is_hopper_or_blackwell
        and num_experts in DSV3_SUPPORTED_NUM_EXPERTS
        and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES
    )
    allow_gpt_oss_router_gemm = (
        is_hopper_or_blackwell
        and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS
        and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES
    )
    is_fp32_router_model = (
        is_hopper_or_blackwell
        and num_experts in FP32_SUPPORTED_NUM_EXPERTS
        and hidden_size in FP32_SUPPORTED_HIDDEN_SIZES
    )
    # Weight dtype: fp32 kernel requires fp32 weights; others use bf16.
    weight_dtype = torch.float32 if is_fp32_router_model else torch.bfloat16
    has_bias = allow_gpt_oss_router_gemm

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size"],
            x_vals=get_batch_size_range(max_batch_size),
            x_log=False,
            line_arg="provider",
            line_vals=[
                "torch",
                "vllm",
            ],
            line_names=["PyTorch", "vLLM"],
            styles=([("blue", "-"), ("red", "-")]),
            ylabel="TFLOPs",
            plot_name=f"{model} router gemm throughput",
            args={},
        )
    )
    def benchmark(batch_size, provider):
        # The fp32 specialized kernel is only used up to FP32_MAX_TOKENS rows.
        allow_fp32_router_gemm = is_fp32_router_model and batch_size <= FP32_MAX_TOKENS

        mat_a = torch.randn(
            (batch_size, hidden_size), dtype=torch.bfloat16, device="cuda"
        ).contiguous()
        mat_b = torch.randn(
            (num_experts, hidden_size), dtype=weight_dtype, device="cuda"
        ).contiguous()
        bias = torch.randn(
            num_experts, dtype=torch.bfloat16, device="cuda"
        ).contiguous()

        quantiles = [0.5, 0.2, 0.8]

        if provider == "torch":

            def runner():
                # Weights are fp32 whenever the model is an fp32-router model,
                # regardless of batch size, so the bf16 activations must be
                # upcast for every batch size (not only when the specialized
                # kernel would be eligible); otherwise F.linear sees mixed
                # bf16/fp32 operands.
                if is_fp32_router_model:
                    F.linear(mat_a.float(), mat_b)
                elif has_bias:
                    F.linear(mat_a, mat_b, bias)
                else:
                    F.linear(mat_a, mat_b)

        elif provider == "vllm":

            def runner():
                if allow_dsv3_router_gemm:
                    ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16)
                elif allow_fp32_router_gemm:
                    ops.fp32_router_gemm(mat_a, mat_b)
                elif allow_gpt_oss_router_gemm:
                    ops.gpt_oss_router_gemm(mat_a, mat_b, bias)
                elif is_fp32_router_model:
                    # batch_size > FP32_MAX_TOKENS: fall back to F.linear
                    F.linear(mat_a.float(), mat_b)
                else:
                    F.linear(mat_a, mat_b)

        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            runner, quantiles=quantiles
        )

        def tflops(t_ms):
            # 2 * M * K * N floating point operations for an MxK by KxN GEMM.
            flops = 2 * batch_size * hidden_size * num_experts
            return flops / (t_ms * 1e-3) / 1e12

        # Larger time means lower throughput, hence max_ms maps to the low
        # end of the reported range and min_ms to the high end.
        return tflops(ms), tflops(max_ms), tflops(min_ms)

    return benchmark
if __name__ == "__main__":
    # Command line: model to inspect, largest batch size to sweep, and whether
    # remote code may be trusted when loading the model config.
    cli = FlexibleArgumentParser()
    cli.add_argument("--model", type=str, default="openai/gpt-oss-20b")
    cli.add_argument("--max-batch-size", default=16, type=int)
    cli.add_argument("--trust-remote-code", action="store_true")
    opts = cli.parse_args()

    # Build the benchmark and run it, printing the collected data.
    router_gemm_bench = get_benchmark(
        opts.model, opts.max_batch_size, opts.trust_remote_code
    )
    router_gemm_bench.run(print_data=True)
+10
View File
@@ -476,6 +476,16 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
endfunction()
#
# Compute the subset of `TGT_CUDA_ARCHS` that matches the SM90+ candidate list
# and store it in the variable named by `OUT_CUDA_ARCHS` (in the caller's
# scope). The candidate list depends on the CUDA compiler version:
#   CUDA >= 13.0 : 9.0a;10.0f;11.0f
#   otherwise    : 9.0a;10.0a;10.1a;10.3a
# Matching is delegated to `cuda_archs_loose_intersection`.
#
# Example:
#   cuda_archs_sm90plus(MY_KERNEL_ARCHS "${CUDA_ARCHS}")
#   if(MY_KERNEL_ARCHS)
#     ...
#   endif()
#
function(cuda_archs_sm90plus OUT_CUDA_ARCHS TGT_CUDA_ARCHS)
  # Quoted so that an empty/unset compiler version (no CUDA compiler
  # configured) falls through to the else() branch instead of producing a
  # malformed if() expression.
  if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_GREATER_EQUAL 13.0)
    cuda_archs_loose_intersection(_archs "9.0a;10.0f;11.0f" "${TGT_CUDA_ARCHS}")
  else()
    cuda_archs_loose_intersection(_archs "9.0a;10.0a;10.1a;10.3a" "${TGT_CUDA_ARCHS}")
  endif()
  set(${OUT_CUDA_ARCHS} ${_archs} PARENT_SCOPE)
endfunction()
#
# Override the GPU architectures detected by cmake/torch and filter them by
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
+223
View File
@@ -0,0 +1,223 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
//
// Router GEMM: activation(T) x weight(fp32) -> fp32, H=3072, E=256, M<=32.
// Supports bf16 or fp32 activation; weight is always fp32.
// Adapted from dsv3_router_gemm_float_out.cu.
#include <cuda_bf16.h>
#include <cuda_runtime.h>
// ---------------------------------------------------------------------------
// Load helpers
// ---------------------------------------------------------------------------
// Copy VPT consecutive fp32 weights from `ptr` into `dst` using float4 loads.
//   VPT=4 : a single float4 load (used with fp32 activations)
//   VPT=8 : two back-to-back float4 loads (used with bf16 activations)
// NOTE(review): the float4 reinterpretation presumably relies on `ptr` being
// 16-byte aligned; confirm against the callers.
template <int VPT>
__device__ __forceinline__ void load_weight(float const* ptr, float* dst);

template <>
__device__ __forceinline__ void load_weight<4>(float const* ptr, float* dst) {
  float4 const packed = reinterpret_cast<float4 const*>(ptr)[0];
  dst[0] = packed.x;
  dst[1] = packed.y;
  dst[2] = packed.z;
  dst[3] = packed.w;
}

template <>
__device__ __forceinline__ void load_weight<8>(float const* ptr, float* dst) {
  float4 const* const vec_ptr = reinterpret_cast<float4 const*>(ptr);
  // Issue both vector loads before unpacking either of them.
  float4 const lo = vec_ptr[0];
  float4 const hi = vec_ptr[1];
  dst[0] = lo.x;
  dst[1] = lo.y;
  dst[2] = lo.z;
  dst[3] = lo.w;
  dst[4] = hi.x;
  dst[5] = hi.y;
  dst[6] = hi.z;
  dst[7] = hi.w;
}
// Read VPT activation values starting at `ptr` and write them to `dst` as
// fp32. Only the two specializations below are defined.
template <typename T, int VPT>
__device__ __forceinline__ void load_activation(T const* ptr, float* dst);

// fp32 activations: a single float4 load; values are copied unchanged.
template <>
__device__ __forceinline__ void load_activation<float, 4>(float const* ptr,
                                                          float* dst) {
  float4 const packed = reinterpret_cast<float4 const*>(ptr)[0];
  dst[0] = packed.x;
  dst[1] = packed.y;
  dst[2] = packed.z;
  dst[3] = packed.w;
}

// bf16 activations: a single uint4 load carries 8 bf16 values, which are then
// converted to fp32 one by one.
template <>
__device__ __forceinline__ void load_activation<__nv_bfloat16, 8>(
    __nv_bfloat16 const* ptr, float* dst) {
  uint4 const packed = *reinterpret_cast<uint4 const*>(ptr);
  __nv_bfloat16 const* const elems =
      reinterpret_cast<__nv_bfloat16 const*>(&packed);
#pragma unroll
  for (int idx = 0; idx < 8; ++idx) {
    dst[idx] = __bfloat162float(elems[idx]);
  }
}
// ---------------------------------------------------------------------------
// Kernel
// ---------------------------------------------------------------------------
// Router GEMM kernel: out[m][n] = sum_k mat_a[m][k] * mat_b[n][k].
//
// Template parameters:
//   InputT      : activation element type (float or __nv_bfloat16)
//   kBlockSize  : threads per block (must be a multiple of the warp size)
//   kNumTokens  : number of activation rows processed (M)
//   kNumExperts : row stride of `out`
//   kHiddenDim  : row stride of `mat_a` and `mat_b` (K)
// Weight is always fp32; output is always fp32.
// VPT = 16 / sizeof(InputT): 4 for fp32, 8 for bf16
//
// Work split: blockIdx.x selects the row of `mat_b` (and the output column);
// every thread accumulates a strided slice of K for all kNumTokens rows, then
// partial sums are reduced with warp shuffles followed by a shared-memory pass
// performed by thread 0.
template <typename InputT, int kBlockSize, int kNumTokens, int kNumExperts,
          int kHiddenDim>
__global__ __launch_bounds__(kBlockSize, 1) void fp32_router_gemm_kernel(
    float* out, InputT const* mat_a, float const* mat_b) {
  constexpr int VPT = 16 / sizeof(InputT);
  constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
  constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration;
  constexpr int kWarpSize = 32;
  constexpr int kNumWarps = kBlockSize / kWarpSize;

  // Guard the integer divisions above: a remainder would silently drop the
  // tail of K (wrong results) or leave threads outside the warp bookkeeping.
  static_assert(kBlockSize % kWarpSize == 0,
                "kBlockSize must be a multiple of the warp size");
  static_assert(kHiddenDim % k_elems_per_k_iteration == 0,
                "kHiddenDim must be a multiple of VPT * kBlockSize");

  int const n_idx = blockIdx.x;
  int const tid = threadIdx.x;
  int const warpId = tid / kWarpSize;
  int const laneId = tid % kWarpSize;

  // Per-thread partial dot products, one per token row.
  float acc[kNumTokens] = {};
  // One slot per (token, warp) for the cross-warp reduction.
  __shared__ float sm_reduction[kNumTokens][kNumWarps];

  float const* b_col = mat_b + n_idx * kHiddenDim;

  // Precompute this thread's starting K offset for each iteration.
  int k_bases[k_iterations];
#pragma unroll
  for (int ki = 0; ki < k_iterations; ki++) {
    k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
  }

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // NOTE(review): PTX grid-dependency control; presumably pairs with the
  // programmatic stream serialization attribute set by the launcher.
  asm volatile("griddepcontrol.wait;");
#endif

  for (int ki = 0; ki < k_iterations; ki++) {
    int const k_base = k_bases[ki];

    // The weight slice is loaded once and reused for every token row.
    float b_float[VPT];
    load_weight<VPT>(b_col + k_base, b_float);

#pragma unroll
    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
      float a_float[VPT];
      load_activation<InputT, VPT>(mat_a + m_idx * kHiddenDim + k_base,
                                   a_float);
#pragma unroll
      for (int k = 0; k < VPT; k++) {
        acc[m_idx] += a_float[k] * b_float[k];
      }
    }
  }

  // Warp-level butterfly reduction
#pragma unroll
  for (int m = 0; m < kNumTokens; m++) {
    float sum = acc[m];
    sum += __shfl_xor_sync(0xffffffff, sum, 16);
    sum += __shfl_xor_sync(0xffffffff, sum, 8);
    sum += __shfl_xor_sync(0xffffffff, sum, 4);
    sum += __shfl_xor_sync(0xffffffff, sum, 2);
    sum += __shfl_xor_sync(0xffffffff, sum, 1);
    if (laneId == 0) sm_reduction[m][warpId] = sum;
  }

  __syncthreads();

  // Thread 0 combines the per-warp sums and writes one output per token row.
  if (tid == 0) {
#pragma unroll
    for (int m = 0; m < kNumTokens; m++) {
      float final_sum = 0.0f;
#pragma unroll
      for (int w = 0; w < kNumWarps; w++) final_sum += sm_reduction[m][w];
      out[m * kNumExperts + n_idx] = final_sum;
    }
  }

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
// ---------------------------------------------------------------------------
// Launcher
// ---------------------------------------------------------------------------
template <typename InputT, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeFp32RouterGemm(float* output, InputT const* mat_a,
float const* mat_b, cudaStream_t stream) {
constexpr int kBlockSize = 128;
cudaLaunchConfig_t config;
config.gridDim = kNumExperts;
config.blockDim = kBlockSize;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config,
fp32_router_gemm_kernel<InputT, kBlockSize, kNumTokens,
kNumExperts, kHiddenDim>,
output, mat_a, mat_b);
}
// ---------------------------------------------------------------------------
// Explicit instantiations: M=1..32, E=256, H=3072, for both input types
// ---------------------------------------------------------------------------
// Emits one explicit instantiation of the launcher for activation type T and
// M tokens, with E=256 and H=3072.
#define FP32_ROUTER_GEMM_INSTANTIATE(T, M)             \
  template void invokeFp32RouterGemm<T, M, 256, 3072>( \
      float*, T const*, float const*, cudaStream_t);

// Emits the instantiations for every token count M = 1..32.
#define FP32_ROUTER_GEMM_INSTANTIATE_1_TO_32(T)                             \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 1) FP32_ROUTER_GEMM_INSTANTIATE(T, 2)     \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 3) FP32_ROUTER_GEMM_INSTANTIATE(T, 4)     \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 5) FP32_ROUTER_GEMM_INSTANTIATE(T, 6)     \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 7) FP32_ROUTER_GEMM_INSTANTIATE(T, 8)     \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 9) FP32_ROUTER_GEMM_INSTANTIATE(T, 10)    \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 11) FP32_ROUTER_GEMM_INSTANTIATE(T, 12)   \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 13) FP32_ROUTER_GEMM_INSTANTIATE(T, 14)   \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 15) FP32_ROUTER_GEMM_INSTANTIATE(T, 16)   \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 17) FP32_ROUTER_GEMM_INSTANTIATE(T, 18)   \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 19) FP32_ROUTER_GEMM_INSTANTIATE(T, 20)   \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 21) FP32_ROUTER_GEMM_INSTANTIATE(T, 22)   \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 23) FP32_ROUTER_GEMM_INSTANTIATE(T, 24)   \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 25) FP32_ROUTER_GEMM_INSTANTIATE(T, 26)   \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 27) FP32_ROUTER_GEMM_INSTANTIATE(T, 28)   \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 29) FP32_ROUTER_GEMM_INSTANTIATE(T, 30)   \
  FP32_ROUTER_GEMM_INSTANTIATE(T, 31) FP32_ROUTER_GEMM_INSTANTIATE(T, 32)

FP32_ROUTER_GEMM_INSTANTIATE_1_TO_32(float)
FP32_ROUTER_GEMM_INSTANTIATE_1_TO_32(__nv_bfloat16)

// The helper macros are local to this file.
#undef FP32_ROUTER_GEMM_INSTANTIATE_1_TO_32
#undef FP32_ROUTER_GEMM_INSTANTIATE
@@ -0,0 +1,127 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "core/registration.h"
#include "libtorch_stable/torch_utils.h"
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <stdexcept>
namespace {
inline int getSMVersion() {
auto* props = get_device_prop();
return props->major * 10 + props->minor;
}
} // namespace
static constexpr int FP32_NUM_EXPERTS = 256;
static constexpr int FP32_HIDDEN_DIM = 3072;
static constexpr int FP32_MAX_TOKENS = 32;
// Forward declarations — 4 template params must match fp32_router_gemm.cu
template <typename InputT, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeFp32RouterGemm(float* output, InputT const* mat_a,
float const* mat_b, cudaStream_t stream);
// Compile-time dispatch from a runtime `num_tokens` to the matching
// invokeFp32RouterGemm<InputT, M, ...> instantiation. The primary template
// tests M == kBegin and otherwise recurses with kBegin + 1; the partial
// specialization for kBegin == kEnd ends the recursion.
template <typename InputT, int kBegin, int kEnd>
struct Fp32LoopUnroller {
  static void unroll(int num_tokens, float* output, InputT const* mat_a,
                     float const* mat_b, cudaStream_t stream) {
    if (num_tokens != kBegin) {
      // Not this token count: try the next candidate.
      Fp32LoopUnroller<InputT, kBegin + 1, kEnd>::unroll(num_tokens, output,
                                                         mat_a, mat_b, stream);
      return;
    }
    invokeFp32RouterGemm<InputT, kBegin, FP32_NUM_EXPERTS, FP32_HIDDEN_DIM>(
        output, mat_a, mat_b, stream);
  }
};

// Terminal case: either launch for kEnd tokens or reject the value.
template <typename InputT, int kEnd>
struct Fp32LoopUnroller<InputT, kEnd, kEnd> {
  static void unroll(int num_tokens, float* output, InputT const* mat_a,
                     float const* mat_b, cudaStream_t stream) {
    if (num_tokens != kEnd) {
      throw std::invalid_argument(
          "fp32_router_gemm: num_tokens must be in [1, 32]");
    }
    invokeFp32RouterGemm<InputT, kEnd, FP32_NUM_EXPERTS, FP32_HIDDEN_DIM>(
        output, mat_a, mat_b, stream);
  }
};
// Host entry point for the fp32 router GEMM.
//
// Validates the arguments and launches the kernel instantiation matching the
// activation dtype and number of tokens on mat_a's CUDA stream.
//
//   output : [num_tokens, num_experts], float32, written in place
//   mat_a  : [num_tokens, hidden_dim], float32 or bfloat16
//   mat_b  : [num_experts, hidden_dim], float32
//
// Requirements (checked): all tensors are contiguous CUDA tensors on the same
// device, hidden_dim == 3072, num_experts == 256, num_tokens <= 32, SM90+.
// A call with num_tokens == 0 returns without launching anything.
void fp32_router_gemm(
    torch::stable::Tensor& output,       // [num_tokens, num_experts]
    torch::stable::Tensor const& mat_a,  // [num_tokens, hidden_dim]
    torch::stable::Tensor const& mat_b   // [num_experts, hidden_dim]
) {
  STD_TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2,
                  "fp32_router_gemm: all tensors must be 2-dimensional");
  STD_TORCH_CHECK(output.is_cuda() && mat_a.is_cuda() && mat_b.is_cuda(),
                  "fp32_router_gemm: all tensors must be CUDA tensors");
  STD_TORCH_CHECK(output.get_device_index() == mat_a.get_device_index() &&
                      output.get_device_index() == mat_b.get_device_index(),
                  "fp32_router_gemm: all tensors must be on the same device");
  STD_TORCH_CHECK(
      output.is_contiguous() && mat_a.is_contiguous() && mat_b.is_contiguous(),
      "fp32_router_gemm: all tensors must be contiguous");

  // Keep the sizes in their native width while validating so that an
  // oversized dimension cannot wrap around and slip past the range checks.
  auto const num_tokens = mat_a.size(0);
  auto const num_experts = mat_b.size(0);
  auto const hidden_dim = mat_a.size(1);

  STD_TORCH_CHECK(output.size(0) == num_tokens && output.size(1) == num_experts,
                  "fp32_router_gemm: output must have shape [num_tokens, "
                  "num_experts]");
  STD_TORCH_CHECK(
      mat_a.size(1) == mat_b.size(1),
      "fp32_router_gemm: mat_a and mat_b must have the same hidden_dim");
  STD_TORCH_CHECK(hidden_dim == FP32_HIDDEN_DIM,
                  "fp32_router_gemm: expected hidden_dim=3072");
  STD_TORCH_CHECK(num_experts == FP32_NUM_EXPERTS,
                  "fp32_router_gemm: expected num_experts=256");
  STD_TORCH_CHECK(num_tokens <= FP32_MAX_TOKENS,
                  "fp32_router_gemm: num_tokens must be in [0, 32]");

  STD_TORCH_CHECK(
      mat_a.scalar_type() == torch::headeronly::ScalarType::Float ||
          mat_a.scalar_type() == torch::headeronly::ScalarType::BFloat16,
      "fp32_router_gemm: mat_a must be float32 or bfloat16");
  STD_TORCH_CHECK(mat_b.scalar_type() == torch::headeronly::ScalarType::Float,
                  "fp32_router_gemm: mat_b (weight) must be float32");
  STD_TORCH_CHECK(output.scalar_type() == torch::headeronly::ScalarType::Float,
                  "fp32_router_gemm: output must be float32");

  // Nothing to compute for an empty batch.
  if (num_tokens == 0) {
    return;
  }

  STD_TORCH_CHECK(getSMVersion() >= 90, "fp32_router_gemm: requires SM90+");

  auto stream = get_current_cuda_stream(mat_a.get_device_index());
  float* out_ptr = reinterpret_cast<float*>(output.mutable_data_ptr());
  float const* mat_b_ptr = reinterpret_cast<float const*>(mat_b.data_ptr());
  // Safe narrowing: num_tokens was verified to be at most FP32_MAX_TOKENS.
  int const num_tokens_i = static_cast<int>(num_tokens);

  if (mat_a.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
    auto const* mat_a_ptr =
        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr());
    Fp32LoopUnroller<__nv_bfloat16, 1, FP32_MAX_TOKENS>::unroll(
        num_tokens_i, out_ptr, mat_a_ptr, mat_b_ptr, stream);
  } else {
    auto const* mat_a_ptr = reinterpret_cast<float const*>(mat_a.data_ptr());
    Fp32LoopUnroller<float, 1, FP32_MAX_TOKENS>::unroll(
        num_tokens_i, out_ptr, mat_a_ptr, mat_b_ptr, stream);
  }
}
// Register fp32_router_gemm (boxed via TORCH_BOX) as the CUDA implementation
// of the `fp32_router_gemm` op in the `_C` library.
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("fp32_router_gemm", TORCH_BOX(&fp32_router_gemm));
}
+4
View File
@@ -247,6 +247,10 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
ops.def(
"dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
// BF16/FP32 x FP32 -> FP32 router GEMM for H=3072, E=256, M<=32 (SM90+).
// Conditionally compiled, so the impl is registered in the kernel's source file.
ops.def("fp32_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops.def(
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
+78
View File
@@ -0,0 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for fp32_router_gemm kernel: activation×weight→fp32, H=3072, E=256.
Correctness baseline: torch.nn.functional.linear computed in float32.
"""
import pytest
import torch
from vllm._custom_ops import fp32_router_gemm
NUM_EXPERTS = 256
HIDDEN_DIM = 3072
# Absolute tolerance for the fp32 kernel vs the float32 F.linear reference
ATOL_FP32 = 2e-4
ATOL_BF16 = 2e-2 # bf16 activation has lower precision
def _requires_sm90():
    """Skip the calling test unless CUDA is available with capability >= 9.0."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    major, minor = torch.cuda.get_device_capability()
    sm_version = major * 10 + minor
    if sm_version < 90:
        pytest.skip(f"fp32_router_gemm requires SM90+, got SM{major}{minor}")
def _ref(mat_a: torch.Tensor, mat_b: torch.Tensor) -> torch.Tensor:
"""Reference: F.linear in float32 on GPU."""
return torch.nn.functional.linear(mat_a.float(), mat_b.float())
@pytest.mark.parametrize("num_tokens", [1, 2, 4, 8, 16, 32])
def test_fp32_activation(num_tokens: int):
    """fp32 activations: kernel output must match the reference within ATOL_FP32."""
    _requires_sm90()
    torch.manual_seed(42)

    cuda = torch.device("cuda")
    activations = torch.randn(
        num_tokens, HIDDEN_DIM, dtype=torch.float32, device=cuda
    )
    weights = torch.randn(NUM_EXPERTS, HIDDEN_DIM, dtype=torch.float32, device=cuda)

    actual = fp32_router_gemm(activations, weights)
    expected = _ref(activations, weights)

    assert actual.shape == (num_tokens, NUM_EXPERTS)
    assert actual.dtype == torch.float32
    torch.testing.assert_close(actual, expected, atol=ATOL_FP32, rtol=0)
@pytest.mark.parametrize("num_tokens", [1, 2, 4, 8, 16, 32])
def test_bf16_activation(num_tokens: int):
    """bf16 activations: kernel output must match the reference within ATOL_BF16."""
    _requires_sm90()
    torch.manual_seed(42)

    cuda = torch.device("cuda")
    activations_bf16 = torch.randn(
        num_tokens, HIDDEN_DIM, dtype=torch.bfloat16, device=cuda
    )
    weights = torch.randn(NUM_EXPERTS, HIDDEN_DIM, dtype=torch.float32, device=cuda)

    actual = fp32_router_gemm(activations_bf16, weights)
    expected = _ref(activations_bf16, weights).to(cuda)

    assert actual.shape == (num_tokens, NUM_EXPERTS)
    assert actual.dtype == torch.float32
    torch.testing.assert_close(actual, expected, atol=ATOL_BF16, rtol=0)
def test_output_shape_and_dtype():
    """The result is a float32 CUDA tensor of shape (num_tokens, NUM_EXPERTS)."""
    _requires_sm90()

    num_tokens = 4
    cuda = torch.device("cuda")
    activations = torch.randn(
        num_tokens, HIDDEN_DIM, dtype=torch.float32, device=cuda
    )
    weights = torch.randn(NUM_EXPERTS, HIDDEN_DIM, dtype=torch.float32, device=cuda)

    result = fp32_router_gemm(activations, weights)

    assert result.shape == (4, NUM_EXPERTS)
    assert result.dtype == torch.float32
    assert result.device.type == "cuda"
+25
View File
@@ -2412,6 +2412,31 @@ def dsv3_router_gemm(
return output
def fp32_router_gemm(
    hidden_states: torch.Tensor,
    router_weight: torch.Tensor,
) -> torch.Tensor:
    """Run the ``_C.fp32_router_gemm`` op and return its float32 output.

    Allocates an uninitialized ``(hidden_states.shape[0],
    router_weight.shape[0])`` float32 tensor on ``hidden_states``' device and
    passes it to the op as the output buffer.
    """
    num_tokens = hidden_states.shape[0]
    num_experts = router_weight.shape[0]
    router_logits = torch.empty(
        (num_tokens, num_experts),
        dtype=torch.float32,
        device=hidden_states.device,
    )
    torch.ops._C.fp32_router_gemm(router_logits, hidden_states, router_weight)
    return router_logits
# Register the fake (meta) implementation only when the op exists in this build.
if hasattr(torch.ops, "_C"):
    if hasattr(torch.ops._C, "fp32_router_gemm"):

        @register_fake("_C::fp32_router_gemm")
        def fp32_router_gemm_fake(
            output: torch.Tensor,
            mat_a: torch.Tensor,
            mat_b: torch.Tensor,
        ) -> None:
            # The op writes into ``output`` in place, so there is nothing to
            # compute or return here.
            return None
def topk_softmax(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
@@ -3,18 +3,22 @@
import torch
from torch.nn.parameter import Parameter
import vllm._custom_ops as ops
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
@PluggableLayer.register("gate_linear")
class GateLinear(ReplicatedLinear):
"""MoE gate linear layer with three-tier GEMM dispatch:
"""MoE gate linear layer with multi-tier GEMM dispatch:
1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
3. F.linear via ReplicatedLinear (ultimate fallback)
1. DSV3 specialized kernel (SM90+, fp32 out, M<=16, H=7168, E=256/384)
2. fp32 specialized kernel (SM90+, bf16/fp32 in, fp32 out,
M<=32, H=3072, E=256)
3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 weight + fp32 out_dtype)
4. F.linear via ReplicatedLinear (ultimate fallback)
The ``out_dtype`` attribute is mutable and can be set after init
(e.g. when the required dtype depends on the expert quantization
@@ -25,6 +29,11 @@ class GateLinear(ReplicatedLinear):
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
# Dimensions supported by the fp32 specialized kernel
FP32_SUPPORTED_NUM_EXPERTS = [256]
FP32_SUPPORTED_HIDDEN_SIZES = [3072]
FP32_MAX_TOKENS = 32
def __init__(
self,
input_size: int,
@@ -43,7 +52,7 @@ class GateLinear(ReplicatedLinear):
)
# If fp32 compute is required and no specialized kernel is available,
# store weights in fp32 so Tier 3 computes in fp32 natively.
# store weights in fp32 so the fallback linear path computes in fp32.
if force_fp32_compute and not can_use_specialized_kernels:
params_dtype = torch.float32
@@ -65,6 +74,16 @@ class GateLinear(ReplicatedLinear):
and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES
)
# fp32 specialized kernel eligibility (SM90+, exact dims, fp32 weight)
self.allow_fp32_router_gemm = (
not bias
and self.weight.dtype == torch.float32
and current_platform.is_cuda()
and is_hopper_or_blackwell
and output_size in self.FP32_SUPPORTED_NUM_EXPERTS
and input_size in self.FP32_SUPPORTED_HIDDEN_SIZES
)
# cuBLAS bf16→fp32 eligibility
self.allow_cublas_router_gemm = (
self.allow_specialized_router_gemm
@@ -92,8 +111,6 @@ class GateLinear(ReplicatedLinear):
def forward(
self, x: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
import vllm._custom_ops as ops
# Tier 1: DSV3 specialized kernel
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
output = ops.dsv3_router_gemm(
@@ -103,15 +120,56 @@ class GateLinear(ReplicatedLinear):
)
return output, None
# Tier 2: cuBLAS bf16→fp32
# Tier 2: fp32 specialized kernel (H=3072, E=256, M<=32)
# Dispatch is wrapped in a custom op so that torch.compile/CUDA-graph
# capture does not freeze the runtime num_tokens branch.
if self.allow_fp32_router_gemm and x.dtype in (
torch.float32,
torch.bfloat16,
):
output = torch.ops.vllm.fp32_router_gemm_dispatch(x, self.weight)
return output, None
# Tier 3: cuBLAS bf16→fp32
if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16:
output = torch.mm(x, self.weight.T, out_dtype=torch.float32)
return output, None
# Tier 3: F.linear (ReplicatedLinear)
# Tier 4: F.linear (ReplicatedLinear)
if self.out_dtype is not None and x.dtype != self.weight.dtype:
x = x.to(self.weight.dtype)
output, output_bias = super().forward(x)
if self.out_dtype is not None and output.dtype != self.out_dtype:
output = output.to(self.out_dtype)
return output, output_bias
_FP32_ROUTER_GEMM_MAX_TOKENS = GateLinear.FP32_MAX_TOKENS
def fp32_router_gemm_dispatch_impl(
    x: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
    """
    Pick the router GEMM implementation from the runtime number of tokens.

    Inputs with at most FP32_MAX_TOKENS rows go to the fp32 specialized gemm;
    larger inputs are upcast to float32 and sent through F.linear.

    The selection lives inside a custom op because our torch.compile
    integration does not support runtime dispatching on num_tokens.
    """
    num_tokens = x.shape[0]
    if num_tokens > _FP32_ROUTER_GEMM_MAX_TOKENS:
        return torch.nn.functional.linear(x.float(), weight)
    return ops.fp32_router_gemm(x, weight)
def fp32_router_gemm_dispatch_fake(
    x: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
    """Fake impl: an uninitialized float32 ``(x.shape[0], weight.shape[0])`` tensor."""
    num_tokens = x.shape[0]
    num_experts = weight.shape[0]
    return x.new_empty((num_tokens, num_experts), dtype=torch.float32)
# Register the dispatch function as the custom op "fp32_router_gemm_dispatch",
# with fp32_router_gemm_dispatch_fake as its fake implementation.
direct_register_custom_op(
op_name="fp32_router_gemm_dispatch",
op_func=fp32_router_gemm_dispatch_impl,
fake_impl=fp32_router_gemm_dispatch_fake,
)
+4 -4
View File
@@ -43,10 +43,10 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE,
fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -113,12 +113,12 @@ class MiniMaxM2MoE(nn.Module):
router_logits_dtype=torch.float32,
)
self.gate = ReplicatedLinear(
self.gate = GateLinear(
config.hidden_size,
config.num_local_experts,
bias=False,
params_dtype=torch.float32,
quant_config=None,
out_dtype=torch.float32,
prefix=f"{prefix}.gate",
)
@@ -132,7 +132,7 @@ class MiniMaxM2MoE(nn.Module):
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states.to(torch.float32))
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)