mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[PERF]MiniMax-M2 gate kernel (#38445)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: qianlihuang <91178480+qianlihuang@users.noreply.github.com> Co-authored-by: Yiliu Dong <91178480+qianlihuang@users.noreply.github.com>
This commit is contained in:
+24
-10
@@ -683,6 +683,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
|
||||
# FP32 router GEMM (H=3072, E=256, M<=32). Requires SM90+ and CUDA >= 12.0.
# The two sources are appended to VLLM_STABLE_EXT_SRC only when at least one
# of the requested CUDA_ARCHS is matched by cuda_archs_sm90plus().
cuda_archs_sm90plus(FP32_ROUTER_GEMM_ARCHS "${CUDA_ARCHS}")
# The compiler version is quoted so the condition stays well-formed when the
# variable is empty; an unquoted empty expansion would leave if() without a
# left-hand operand and abort configuration.
if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_GREATER_EQUAL 12.0 AND FP32_ROUTER_GEMM_ARCHS)
  set(SRCS
    "csrc/libtorch_stable/fp32_router_gemm_entry.cu"
    "csrc/libtorch_stable/fp32_router_gemm.cu")
  set_gencode_flags_for_srcs(
    SRCS "${SRCS}"
    CUDA_ARCHS "${FP32_ROUTER_GEMM_ARCHS}")
  list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
  message(STATUS "Building fp32_router_gemm for archs: ${FP32_ROUTER_GEMM_ARCHS}")
else()
  message(STATUS "Not building fp32_router_gemm as no compatible archs found "
                 "(requires SM90+ and CUDA >= 12.0).")
endif()
|
||||
|
||||
# Only build AllSpark kernels if we are building for at least some compatible archs.
|
||||
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
|
||||
if (ALLSPARK_ARCHS)
|
||||
@@ -1240,24 +1256,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
" in CUDA target architectures")
|
||||
endif()
|
||||
|
||||
# DeepSeek V3 router GEMM kernel - requires SM90+
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_ROUTER_GEMM_ARCHS)
|
||||
# DeepSeek V3 router GEMM kernel requires SM90+ and CUDA >= 12.0.
|
||||
# (fp32_router_gemm has been migrated to _C_stable_libtorch above.)
|
||||
cuda_archs_sm90plus(SM90PLUS_ROUTER_GEMM_ARCHS "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SM90PLUS_ROUTER_GEMM_ARCHS)
|
||||
set(DSV3_ROUTER_GEMM_SRC
|
||||
"csrc/moe/dsv3_router_gemm_entry.cu"
|
||||
"csrc/moe/dsv3_router_gemm_float_out.cu"
|
||||
"csrc/moe/dsv3_router_gemm_bf16_out.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${DSV3_ROUTER_GEMM_SRC}"
|
||||
CUDA_ARCHS "${DSV3_ROUTER_GEMM_ARCHS}")
|
||||
CUDA_ARCHS "${SM90PLUS_ROUTER_GEMM_ARCHS}")
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${DSV3_ROUTER_GEMM_SRC}")
|
||||
message(STATUS "Building DSV3 router GEMM kernel for archs: ${DSV3_ROUTER_GEMM_ARCHS}")
|
||||
|
||||
message(STATUS "Building DSV3 router GEMM kernels for archs: ${SM90PLUS_ROUTER_GEMM_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building DSV3 router GEMM kernel as no compatible archs found"
|
||||
message(STATUS "Not building DSV3 router GEMM kernels as no compatible archs found"
|
||||
" (requires SM90+ and CUDA >= 12.0)")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,154 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_config
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
# Dimensions supported by the DSV3 specialized kernel
|
||||
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
|
||||
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
|
||||
|
||||
# Dimensions supported by the gpt-oss specialized kernel
|
||||
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
|
||||
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
|
||||
|
||||
# Dimensions supported by the fp32 specialized kernel (MiniMax-M2)
|
||||
FP32_SUPPORTED_NUM_EXPERTS = [256]
|
||||
FP32_SUPPORTED_HIDDEN_SIZES = [3072]
|
||||
FP32_MAX_TOKENS = 32
|
||||
|
||||
|
||||
def get_batch_size_range(max_batch_size):
    """Return the powers of two in [1, 2**13] that do not exceed ``max_batch_size``."""
    sizes = []
    for exponent in range(14):
        candidate = 2**exponent
        if candidate <= max_batch_size:
            sizes.append(candidate)
    return sizes
|
||||
|
||||
|
||||
def get_model_params(config):
    """Extract ``(num_experts, hidden_size)`` from a model config.

    Only the first entry of ``config.architectures`` is inspected. DeepSeek
    architectures read the expert count from ``n_routed_experts``; gpt-oss and
    MiniMax-M2 read it from ``num_local_experts``.

    Raises:
        ValueError: if the architecture is not one of the supported names.
    """
    arch = config.architectures[0]

    deepseek_archs = (
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM",
        "DeepseekV32ForCausalLM",
    )
    local_expert_archs = ("GptOssForCausalLM", "MiniMaxM2ForCausalLM")

    if arch in deepseek_archs:
        return config.n_routed_experts, config.hidden_size
    if arch in local_expert_archs:
        return config.num_local_experts, config.hidden_size
    raise ValueError(f"Unsupported architecture: {config.architectures}")
|
||||
|
||||
|
||||
def get_benchmark(model, max_batch_size, trust_remote_code):
    """Create the router-GEMM throughput benchmark for ``model``.

    Args:
        model: Model name or path forwarded to ``get_config``.
        max_batch_size: Upper bound of the batch-size sweep; every power of
            two up to this value (see ``get_batch_size_range``) is measured.
        trust_remote_code: Forwarded to ``get_config``.

    Returns:
        The ``triton.testing.perf_report``-decorated benchmark function; call
        its ``run`` method to execute the sweep for the "torch" and "vllm"
        providers.
    """
    # The model dimensions are identical for every (batch_size, provider)
    # point, so resolve the config once instead of once per measurement.
    config = get_config(model=model, trust_remote_code=trust_remote_code)
    num_experts, hidden_size = get_model_params(config)

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size"],
            x_vals=get_batch_size_range(max_batch_size),
            x_log=False,
            line_arg="provider",
            line_vals=[
                "torch",
                "vllm",
            ],
            line_names=["PyTorch", "vLLM"],
            styles=([("blue", "-"), ("red", "-")]),
            ylabel="TFLOPs",
            plot_name=f"{model} router gemm throughput",
            args={},
        )
    )
    def benchmark(batch_size, provider):
        is_hopper_or_blackwell = current_platform.is_device_capability(
            90
        ) or current_platform.is_device_capability_family(100)
        allow_dsv3_router_gemm = (
            is_hopper_or_blackwell
            and num_experts in DSV3_SUPPORTED_NUM_EXPERTS
            and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES
        )
        allow_gpt_oss_router_gemm = (
            is_hopper_or_blackwell
            and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS
            and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES
        )
        is_fp32_router_model = (
            is_hopper_or_blackwell
            and num_experts in FP32_SUPPORTED_NUM_EXPERTS
            and hidden_size in FP32_SUPPORTED_HIDDEN_SIZES
        )
        # The specialized fp32 kernel is only used up to FP32_MAX_TOKENS rows.
        allow_fp32_router_gemm = is_fp32_router_model and batch_size <= FP32_MAX_TOKENS

        # Weight dtype: fp32 kernel requires fp32 weights; others use bf16.
        weight_dtype = torch.float32 if is_fp32_router_model else torch.bfloat16
        mat_a = torch.randn(
            (batch_size, hidden_size), dtype=torch.bfloat16, device="cuda"
        ).contiguous()
        mat_b = torch.randn(
            (num_experts, hidden_size), dtype=weight_dtype, device="cuda"
        ).contiguous()
        bias = torch.randn(
            num_experts, dtype=torch.bfloat16, device="cuda"
        ).contiguous()

        has_bias = allow_gpt_oss_router_gemm

        quantiles = [0.5, 0.2, 0.8]

        if provider == "torch":

            def runner():
                if is_fp32_router_model:
                    # mat_b is fp32 for this model at EVERY batch size, so the
                    # bf16 activation must always be upcast. Keying this on
                    # allow_fp32_router_gemm (which is False above
                    # FP32_MAX_TOKENS) would call F.linear with mismatched
                    # bf16/fp32 operands.
                    F.linear(mat_a.float(), mat_b)
                elif has_bias:
                    F.linear(mat_a, mat_b, bias)
                else:
                    F.linear(mat_a, mat_b)

        elif provider == "vllm":

            def runner():
                if allow_dsv3_router_gemm:
                    ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16)
                elif allow_fp32_router_gemm:
                    ops.fp32_router_gemm(mat_a, mat_b)
                elif allow_gpt_oss_router_gemm:
                    ops.gpt_oss_router_gemm(mat_a, mat_b, bias)
                elif is_fp32_router_model:
                    # batch_size > FP32_MAX_TOKENS: fall back to F.linear
                    F.linear(mat_a.float(), mat_b)
                else:
                    F.linear(mat_a, mat_b)

        else:
            # Fail with a clear message instead of a NameError on `runner`.
            raise ValueError(f"Unknown provider: {provider}")

        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            runner, quantiles=quantiles
        )

        def tflops(t_ms):
            # One multiply and one add per (token, hidden, expert) triple.
            flops = 2 * batch_size * hidden_size * num_experts
            return flops / (t_ms * 1e-3) / 1e12

        # A longer time means lower throughput, so the time bounds swap when
        # converted: (median, lower bound, upper bound) in TFLOPs.
        return tflops(ms), tflops(max_ms), tflops(min_ms)

    return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse the options, build the benchmark for
    # the requested model and print the measured throughput table.
    cli = FlexibleArgumentParser()
    cli.add_argument("--model", type=str, default="openai/gpt-oss-20b")
    cli.add_argument("--max-batch-size", default=16, type=int)
    cli.add_argument("--trust-remote-code", action="store_true")
    options = cli.parse_args()

    router_gemm_benchmark = get_benchmark(
        options.model, options.max_batch_size, options.trust_remote_code
    )
    router_gemm_benchmark.run(print_data=True)
|
||||
@@ -476,6 +476,16 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
|
||||
#
# Select the entries of `TGT_CUDA_ARCHS` that match the SM90+ candidate list
# and store them in `OUT_CUDA_ARCHS` (in the caller's scope). The candidate
# list depends on the CUDA compiler version:
#   CUDA >= 13.0 : 9.0a;10.0f;11.0f
#   otherwise    : 9.0a;10.0a;10.1a;10.3a
# Matching is delegated to `cuda_archs_loose_intersection`.
#
# Example:
#   cuda_archs_sm90plus(MY_ARCHS "${CUDA_ARCHS}")
#
function(cuda_archs_sm90plus OUT_CUDA_ARCHS TGT_CUDA_ARCHS)
  # Quoted so the comparison is still a valid if() expression when the
  # compiler version is empty.
  if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_GREATER_EQUAL 13.0)
    set(_sm90plus_candidates "9.0a;10.0f;11.0f")
  else()
    set(_sm90plus_candidates "9.0a;10.0a;10.1a;10.3a")
  endif()
  cuda_archs_loose_intersection(_archs "${_sm90plus_candidates}" "${TGT_CUDA_ARCHS}")
  set(${OUT_CUDA_ARCHS} ${_archs} PARENT_SCOPE)
endfunction()
|
||||
|
||||
#
|
||||
# Override the GPU architectures detected by cmake/torch and filter them by
|
||||
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
//
|
||||
// Router GEMM: activation(T) x weight(fp32) -> fp32, H=3072, E=256, M<=32.
|
||||
// Supports bf16 or fp32 activation; weight is always fp32.
|
||||
// Adapted from dsv3_router_gemm_float_out.cu.
|
||||
|
||||
#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <stdexcept>
#include <string>
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Load helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Load VPT fp32 values from the weight matrix (always fp32).
|
||||
// VPT=4 when activation is fp32 (one float4 load)
|
||||
// VPT=8 when activation is bf16 (two float4 loads)
|
||||
template <int VPT>
|
||||
__device__ __forceinline__ void load_weight(float const* ptr, float* dst);
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void load_weight<4>(float const* ptr, float* dst) {
|
||||
float4 v = *reinterpret_cast<float4 const*>(ptr);
|
||||
dst[0] = v.x;
|
||||
dst[1] = v.y;
|
||||
dst[2] = v.z;
|
||||
dst[3] = v.w;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void load_weight<8>(float const* ptr, float* dst) {
|
||||
float4 v0 = *reinterpret_cast<float4 const*>(ptr);
|
||||
float4 v1 = *reinterpret_cast<float4 const*>(ptr + 4);
|
||||
dst[0] = v0.x;
|
||||
dst[1] = v0.y;
|
||||
dst[2] = v0.z;
|
||||
dst[3] = v0.w;
|
||||
dst[4] = v1.x;
|
||||
dst[5] = v1.y;
|
||||
dst[6] = v1.z;
|
||||
dst[7] = v1.w;
|
||||
}
|
||||
|
||||
// Load VPT activation values and convert to fp32.
|
||||
template <typename T, int VPT>
|
||||
__device__ __forceinline__ void load_activation(T const* ptr, float* dst);
|
||||
|
||||
// fp32 activation: one float4 load, no conversion needed.
|
||||
template <>
|
||||
__device__ __forceinline__ void load_activation<float, 4>(float const* ptr,
|
||||
float* dst) {
|
||||
float4 v = *reinterpret_cast<float4 const*>(ptr);
|
||||
dst[0] = v.x;
|
||||
dst[1] = v.y;
|
||||
dst[2] = v.z;
|
||||
dst[3] = v.w;
|
||||
}
|
||||
|
||||
// bf16 activation: one uint4 load (8 × bf16) + element-wise conversion.
|
||||
template <>
|
||||
__device__ __forceinline__ void load_activation<__nv_bfloat16, 8>(
|
||||
__nv_bfloat16 const* ptr, float* dst) {
|
||||
uint4 v = *reinterpret_cast<uint4 const*>(ptr);
|
||||
__nv_bfloat16 const* bf16_ptr = reinterpret_cast<__nv_bfloat16 const*>(&v);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) dst[i] = __bfloat162float(bf16_ptr[i]);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Kernel
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// InputT : type of activation (float or __nv_bfloat16)
|
||||
// Weight is always fp32; output is always fp32.
|
||||
// VPT = 16 / sizeof(InputT): 4 for fp32, 8 for bf16
|
||||
template <typename InputT, int kBlockSize, int kNumTokens, int kNumExperts,
|
||||
int kHiddenDim>
|
||||
__global__ __launch_bounds__(128, 1) void fp32_router_gemm_kernel(
|
||||
float* out, InputT const* mat_a, float const* mat_b) {
|
||||
constexpr int VPT = 16 / sizeof(InputT);
|
||||
constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
|
||||
constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration;
|
||||
constexpr int kWarpSize = 32;
|
||||
constexpr int kNumWarps = kBlockSize / kWarpSize;
|
||||
|
||||
int const n_idx = blockIdx.x;
|
||||
int const tid = threadIdx.x;
|
||||
int const warpId = tid / kWarpSize;
|
||||
int const laneId = tid % kWarpSize;
|
||||
|
||||
float acc[kNumTokens] = {};
|
||||
__shared__ float sm_reduction[kNumTokens][kNumWarps];
|
||||
|
||||
float const* b_col = mat_b + n_idx * kHiddenDim;
|
||||
|
||||
int k_bases[k_iterations];
|
||||
#pragma unroll
|
||||
for (int ki = 0; ki < k_iterations; ki++) {
|
||||
k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
#endif
|
||||
|
||||
for (int ki = 0; ki < k_iterations; ki++) {
|
||||
int const k_base = k_bases[ki];
|
||||
|
||||
float b_float[VPT];
|
||||
load_weight<VPT>(b_col + k_base, b_float);
|
||||
|
||||
#pragma unroll
|
||||
for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
|
||||
float a_float[VPT];
|
||||
load_activation<InputT, VPT>(mat_a + m_idx * kHiddenDim + k_base,
|
||||
a_float);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < VPT; k++) {
|
||||
acc[m_idx] += a_float[k] * b_float[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Warp-level butterfly reduction
|
||||
#pragma unroll
|
||||
for (int m = 0; m < kNumTokens; m++) {
|
||||
float sum = acc[m];
|
||||
sum += __shfl_xor_sync(0xffffffff, sum, 16);
|
||||
sum += __shfl_xor_sync(0xffffffff, sum, 8);
|
||||
sum += __shfl_xor_sync(0xffffffff, sum, 4);
|
||||
sum += __shfl_xor_sync(0xffffffff, sum, 2);
|
||||
sum += __shfl_xor_sync(0xffffffff, sum, 1);
|
||||
if (laneId == 0) sm_reduction[m][warpId] = sum;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (tid == 0) {
|
||||
#pragma unroll
|
||||
for (int m = 0; m < kNumTokens; m++) {
|
||||
float final_sum = 0.0f;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kNumWarps; w++) final_sum += sm_reduction[m][w];
|
||||
out[m * kNumExperts + n_idx] = final_sum;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Launcher
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
template <typename InputT, int kNumTokens, int kNumExperts, int kHiddenDim>
|
||||
void invokeFp32RouterGemm(float* output, InputT const* mat_a,
|
||||
float const* mat_b, cudaStream_t stream) {
|
||||
constexpr int kBlockSize = 128;
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = kNumExperts;
|
||||
config.blockDim = kBlockSize;
|
||||
config.dynamicSmemBytes = 0;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = 1;
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(&config,
|
||||
fp32_router_gemm_kernel<InputT, kBlockSize, kNumTokens,
|
||||
kNumExperts, kHiddenDim>,
|
||||
output, mat_a, mat_b);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Explicit instantiations: M=1..32, E=256, H=3072, for both input types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#define INSTANTIATE(T, M) \
|
||||
template void invokeFp32RouterGemm<T, M, 256, 3072>( \
|
||||
float*, T const*, float const*, cudaStream_t);
|
||||
|
||||
#define INSTANTIATE_ALL(T) \
|
||||
INSTANTIATE(T, 1) \
|
||||
INSTANTIATE(T, 2) \
|
||||
INSTANTIATE(T, 3) \
|
||||
INSTANTIATE(T, 4) \
|
||||
INSTANTIATE(T, 5) \
|
||||
INSTANTIATE(T, 6) \
|
||||
INSTANTIATE(T, 7) \
|
||||
INSTANTIATE(T, 8) \
|
||||
INSTANTIATE(T, 9) \
|
||||
INSTANTIATE(T, 10) \
|
||||
INSTANTIATE(T, 11) \
|
||||
INSTANTIATE(T, 12) \
|
||||
INSTANTIATE(T, 13) \
|
||||
INSTANTIATE(T, 14) \
|
||||
INSTANTIATE(T, 15) \
|
||||
INSTANTIATE(T, 16) \
|
||||
INSTANTIATE(T, 17) \
|
||||
INSTANTIATE(T, 18) \
|
||||
INSTANTIATE(T, 19) \
|
||||
INSTANTIATE(T, 20) \
|
||||
INSTANTIATE(T, 21) \
|
||||
INSTANTIATE(T, 22) \
|
||||
INSTANTIATE(T, 23) \
|
||||
INSTANTIATE(T, 24) \
|
||||
INSTANTIATE(T, 25) \
|
||||
INSTANTIATE(T, 26) \
|
||||
INSTANTIATE(T, 27) \
|
||||
INSTANTIATE(T, 28) \
|
||||
INSTANTIATE(T, 29) \
|
||||
INSTANTIATE(T, 30) \
|
||||
INSTANTIATE(T, 31) \
|
||||
INSTANTIATE(T, 32)
|
||||
|
||||
INSTANTIATE_ALL(float)
|
||||
INSTANTIATE_ALL(__nv_bfloat16)
|
||||
|
||||
#undef INSTANTIATE_ALL
|
||||
#undef INSTANTIATE
|
||||
@@ -0,0 +1,127 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
#include "core/registration.h"
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
namespace {
|
||||
|
||||
inline int getSMVersion() {
|
||||
auto* props = get_device_prop();
|
||||
return props->major * 10 + props->minor;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
static constexpr int FP32_NUM_EXPERTS = 256;
|
||||
static constexpr int FP32_HIDDEN_DIM = 3072;
|
||||
static constexpr int FP32_MAX_TOKENS = 32;
|
||||
|
||||
// Forward declarations — 4 template params must match fp32_router_gemm.cu
|
||||
template <typename InputT, int kNumTokens, int kNumExperts, int kHiddenDim>
|
||||
void invokeFp32RouterGemm(float* output, InputT const* mat_a,
|
||||
float const* mat_b, cudaStream_t stream);
|
||||
|
||||
// LoopUnroller templated on InputT
|
||||
template <typename InputT, int kBegin, int kEnd>
|
||||
struct Fp32LoopUnroller {
|
||||
static void unroll(int num_tokens, float* output, InputT const* mat_a,
|
||||
float const* mat_b, cudaStream_t stream) {
|
||||
if (num_tokens == kBegin) {
|
||||
invokeFp32RouterGemm<InputT, kBegin, FP32_NUM_EXPERTS, FP32_HIDDEN_DIM>(
|
||||
output, mat_a, mat_b, stream);
|
||||
} else {
|
||||
Fp32LoopUnroller<InputT, kBegin + 1, kEnd>::unroll(num_tokens, output,
|
||||
mat_a, mat_b, stream);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename InputT, int kEnd>
|
||||
struct Fp32LoopUnroller<InputT, kEnd, kEnd> {
|
||||
static void unroll(int num_tokens, float* output, InputT const* mat_a,
|
||||
float const* mat_b, cudaStream_t stream) {
|
||||
if (num_tokens == kEnd) {
|
||||
invokeFp32RouterGemm<InputT, kEnd, FP32_NUM_EXPERTS, FP32_HIDDEN_DIM>(
|
||||
output, mat_a, mat_b, stream);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"fp32_router_gemm: num_tokens must be in [1, 32]");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Op entry point: validates the tensors and dispatches to the kernel
// instantiation matching the runtime token count and activation dtype.
//
// Requirements enforced below: all tensors are 2-D, CUDA, on the same device
// and contiguous; hidden_dim == 3072; num_experts == 256; num_tokens <= 32;
// mat_a is float32 or bfloat16; mat_b and output are float32. An empty batch
// (num_tokens == 0) returns without launching anything. A device with
// compute capability below 9.0 is rejected.
void fp32_router_gemm(
    torch::stable::Tensor& output,       // [num_tokens, num_experts]
    torch::stable::Tensor const& mat_a,  // [num_tokens, hidden_dim]
    torch::stable::Tensor const& mat_b   // [num_experts, hidden_dim]
) {
  STD_TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2,
                  "fp32_router_gemm: all tensors must be 2-D");
  STD_TORCH_CHECK(output.is_cuda() && mat_a.is_cuda() && mat_b.is_cuda(),
                  "fp32_router_gemm: all tensors must be CUDA tensors");
  STD_TORCH_CHECK(output.get_device_index() == mat_a.get_device_index() &&
                      output.get_device_index() == mat_b.get_device_index(),
                  "fp32_router_gemm: all tensors must be on the same device");
  STD_TORCH_CHECK(
      output.is_contiguous() && mat_a.is_contiguous() && mat_b.is_contiguous(),
      "fp32_router_gemm: all tensors must be contiguous");

  // Keep the sizes in the width returned by size() until they have been
  // validated; narrowing to int first could wrap an oversized dimension into
  // the accepted range.
  auto const num_tokens = mat_a.size(0);
  auto const num_experts = mat_b.size(0);
  auto const hidden_dim = mat_a.size(1);

  STD_TORCH_CHECK(output.size(0) == num_tokens && output.size(1) == num_experts,
                  "fp32_router_gemm: output must have shape [num_tokens, "
                  "num_experts]");
  STD_TORCH_CHECK(
      mat_a.size(1) == mat_b.size(1),
      "fp32_router_gemm: mat_a and mat_b must have the same hidden_dim");
  STD_TORCH_CHECK(hidden_dim == FP32_HIDDEN_DIM,
                  "fp32_router_gemm: expected hidden_dim=3072");
  STD_TORCH_CHECK(num_experts == FP32_NUM_EXPERTS,
                  "fp32_router_gemm: expected num_experts=256");
  STD_TORCH_CHECK(num_tokens <= FP32_MAX_TOKENS,
                  "fp32_router_gemm: num_tokens must be in [0, 32]");
  STD_TORCH_CHECK(
      mat_a.scalar_type() == torch::headeronly::ScalarType::Float ||
          mat_a.scalar_type() == torch::headeronly::ScalarType::BFloat16,
      "fp32_router_gemm: mat_a must be float32 or bfloat16");
  STD_TORCH_CHECK(mat_b.scalar_type() == torch::headeronly::ScalarType::Float,
                  "fp32_router_gemm: mat_b (weight) must be float32");
  STD_TORCH_CHECK(output.scalar_type() == torch::headeronly::ScalarType::Float,
                  "fp32_router_gemm: output must be float32");

  // Nothing to compute for an empty batch.
  if (num_tokens == 0) {
    return;
  }

  STD_TORCH_CHECK(getSMVersion() >= 90, "fp32_router_gemm: requires SM90+");

  // Safe after the range check above: 1 <= num_tokens <= FP32_MAX_TOKENS.
  int const token_count = static_cast<int>(num_tokens);

  auto stream = get_current_cuda_stream(mat_a.get_device_index());
  float* out_ptr = reinterpret_cast<float*>(output.mutable_data_ptr());
  float const* mat_b_ptr = reinterpret_cast<float const*>(mat_b.data_ptr());

  if (mat_a.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
    auto const* mat_a_ptr =
        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr());
    Fp32LoopUnroller<__nv_bfloat16, 1, FP32_MAX_TOKENS>::unroll(
        token_count, out_ptr, mat_a_ptr, mat_b_ptr, stream);
  } else {
    auto const* mat_a_ptr = reinterpret_cast<float const*>(mat_a.data_ptr());
    Fp32LoopUnroller<float, 1, FP32_MAX_TOKENS>::unroll(
        token_count, out_ptr, mat_a_ptr, mat_b_ptr, stream);
  }
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
|
||||
m.impl("fp32_router_gemm", TORCH_BOX(&fp32_router_gemm));
|
||||
}
|
||||
@@ -247,6 +247,10 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
ops.def(
|
||||
"dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
|
||||
|
||||
// BF16/FP32 x FP32 -> FP32 router GEMM for H=3072, E=256, M<=32 (SM90+).
|
||||
// conditionally compiled so impl registration is in source file
|
||||
ops.def("fp32_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
|
||||
|
||||
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
||||
ops.def(
|
||||
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for fp32_router_gemm kernel: activation×weight→fp32, H=3072, E=256.
|
||||
|
||||
Correctness baseline: torch.matmul in float64.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import fp32_router_gemm
|
||||
|
||||
NUM_EXPERTS = 256
|
||||
HIDDEN_DIM = 3072
|
||||
# Absolute tolerance for fp32 kernel vs float64 reference
|
||||
ATOL_FP32 = 2e-4
|
||||
ATOL_BF16 = 2e-2 # bf16 activation has lower precision
|
||||
|
||||
|
||||
def _requires_sm90():
    """Skip the calling test unless a CUDA device with capability >= 9.0 is present."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    major, minor = torch.cuda.get_device_capability()
    sm_version = major * 10 + minor
    if sm_version < 90:
        pytest.skip(f"fp32_router_gemm requires SM90+, got SM{major}{minor}")
|
||||
|
||||
|
||||
def _ref(mat_a: torch.Tensor, mat_b: torch.Tensor) -> torch.Tensor:
|
||||
"""Reference: F.linear in float32 on GPU."""
|
||||
return torch.nn.functional.linear(mat_a.float(), mat_b.float())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 2, 4, 8, 16, 32])
def test_fp32_activation(num_tokens: int):
    """The kernel with fp32 activations must agree with the reference within ATOL_FP32."""
    _requires_sm90()
    torch.manual_seed(42)
    cuda = torch.device("cuda")
    activations = torch.randn(
        num_tokens, HIDDEN_DIM, dtype=torch.float32, device=cuda
    )
    weights = torch.randn(NUM_EXPERTS, HIDDEN_DIM, dtype=torch.float32, device=cuda)

    actual = fp32_router_gemm(activations, weights)
    expected = _ref(activations, weights)

    assert actual.shape == (num_tokens, NUM_EXPERTS)
    assert actual.dtype == torch.float32
    torch.testing.assert_close(actual, expected, atol=ATOL_FP32, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 2, 4, 8, 16, 32])
def test_bf16_activation(num_tokens: int):
    """The kernel with bf16 activations must agree with the reference within ATOL_BF16."""
    _requires_sm90()
    torch.manual_seed(42)
    cuda = torch.device("cuda")
    activations = torch.randn(
        num_tokens, HIDDEN_DIM, dtype=torch.bfloat16, device=cuda
    )
    weights = torch.randn(NUM_EXPERTS, HIDDEN_DIM, dtype=torch.float32, device=cuda)

    actual = fp32_router_gemm(activations, weights)
    # The reference is computed from CUDA inputs, so it is already on `cuda`.
    expected = _ref(activations, weights)

    assert actual.shape == (num_tokens, NUM_EXPERTS)
    assert actual.dtype == torch.float32
    torch.testing.assert_close(actual, expected, atol=ATOL_BF16, rtol=0)
|
||||
|
||||
|
||||
def test_output_shape_and_dtype():
    """The op returns a float32 CUDA tensor of shape [num_tokens, NUM_EXPERTS]."""
    _requires_sm90()
    cuda = torch.device("cuda")
    num_tokens = 4
    activations = torch.randn(
        num_tokens, HIDDEN_DIM, dtype=torch.float32, device=cuda
    )
    weights = torch.randn(NUM_EXPERTS, HIDDEN_DIM, dtype=torch.float32, device=cuda)

    result = fp32_router_gemm(activations, weights)

    assert result.shape == (4, NUM_EXPERTS)
    assert result.dtype == torch.float32
    assert result.device.type == "cuda"
|
||||
@@ -2412,6 +2412,31 @@ def dsv3_router_gemm(
|
||||
return output
|
||||
|
||||
|
||||
def fp32_router_gemm(
    hidden_states: torch.Tensor,
    router_weight: torch.Tensor,
) -> torch.Tensor:
    """Run the ``_C.fp32_router_gemm`` op and return its float32 output.

    Allocates an uninitialized float32 tensor of shape
    ``[hidden_states.shape[0], router_weight.shape[0]]`` on the device of
    ``hidden_states`` and passes it to the op as the output buffer.
    """
    num_tokens = hidden_states.shape[0]
    num_experts = router_weight.shape[0]
    router_logits = torch.empty(
        num_tokens,
        num_experts,
        dtype=torch.float32,
        device=hidden_states.device,
    )
    torch.ops._C.fp32_router_gemm(router_logits, hidden_states, router_weight)
    return router_logits
|
||||
|
||||
|
||||
if hasattr(torch.ops, "_C") and hasattr(torch.ops._C, "fp32_router_gemm"):
|
||||
|
||||
@register_fake("_C::fp32_router_gemm")
|
||||
def fp32_router_gemm_fake(
|
||||
output: torch.Tensor,
|
||||
mat_a: torch.Tensor,
|
||||
mat_b: torch.Tensor,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
def topk_softmax(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
|
||||
@@ -3,18 +3,22 @@
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.model_executor.custom_op import PluggableLayer
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
@PluggableLayer.register("gate_linear")
|
||||
class GateLinear(ReplicatedLinear):
|
||||
"""MoE gate linear layer with three-tier GEMM dispatch:
|
||||
"""MoE gate linear layer with multi-tier GEMM dispatch:
|
||||
|
||||
1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
|
||||
2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
|
||||
3. F.linear via ReplicatedLinear (ultimate fallback)
|
||||
1. DSV3 specialized kernel (SM90+, fp32 out, M<=16, H=7168, E=256/384)
|
||||
2. fp32 specialized kernel (SM90+, bf16/fp32 in, fp32 out,
|
||||
M<=32, H=3072, E=256)
|
||||
3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 weight + fp32 out_dtype)
|
||||
4. F.linear via ReplicatedLinear (ultimate fallback)
|
||||
|
||||
The ``out_dtype`` attribute is mutable and can be set after init
|
||||
(e.g. when the required dtype depends on the expert quantization
|
||||
@@ -25,6 +29,11 @@ class GateLinear(ReplicatedLinear):
|
||||
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
|
||||
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
|
||||
|
||||
# Dimensions supported by the fp32 specialized kernel
|
||||
FP32_SUPPORTED_NUM_EXPERTS = [256]
|
||||
FP32_SUPPORTED_HIDDEN_SIZES = [3072]
|
||||
FP32_MAX_TOKENS = 32
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
@@ -43,7 +52,7 @@ class GateLinear(ReplicatedLinear):
|
||||
)
|
||||
|
||||
# If fp32 compute is required and no specialized kernel is available,
|
||||
# store weights in fp32 so Tier 3 computes in fp32 natively.
|
||||
# store weights in fp32 so the fallback linear path computes in fp32.
|
||||
if force_fp32_compute and not can_use_specialized_kernels:
|
||||
params_dtype = torch.float32
|
||||
|
||||
@@ -65,6 +74,16 @@ class GateLinear(ReplicatedLinear):
|
||||
and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES
|
||||
)
|
||||
|
||||
# fp32 specialized kernel eligibility (SM90+, exact dims, fp32 weight)
|
||||
self.allow_fp32_router_gemm = (
|
||||
not bias
|
||||
and self.weight.dtype == torch.float32
|
||||
and current_platform.is_cuda()
|
||||
and is_hopper_or_blackwell
|
||||
and output_size in self.FP32_SUPPORTED_NUM_EXPERTS
|
||||
and input_size in self.FP32_SUPPORTED_HIDDEN_SIZES
|
||||
)
|
||||
|
||||
# cuBLAS bf16→fp32 eligibility
|
||||
self.allow_cublas_router_gemm = (
|
||||
self.allow_specialized_router_gemm
|
||||
@@ -92,8 +111,6 @@ class GateLinear(ReplicatedLinear):
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||
import vllm._custom_ops as ops
|
||||
|
||||
# Tier 1: DSV3 specialized kernel
|
||||
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
|
||||
output = ops.dsv3_router_gemm(
|
||||
@@ -103,15 +120,56 @@ class GateLinear(ReplicatedLinear):
|
||||
)
|
||||
return output, None
|
||||
|
||||
# Tier 2: cuBLAS bf16→fp32
|
||||
# Tier 2: fp32 specialized kernel (H=3072, E=256, M<=32)
|
||||
# Dispatch is wrapped in a custom op so that torch.compile/CUDA-graph
|
||||
# capture does not freeze the runtime num_tokens branch.
|
||||
if self.allow_fp32_router_gemm and x.dtype in (
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
):
|
||||
output = torch.ops.vllm.fp32_router_gemm_dispatch(x, self.weight)
|
||||
return output, None
|
||||
|
||||
# Tier 3: cuBLAS bf16→fp32
|
||||
if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16:
|
||||
output = torch.mm(x, self.weight.T, out_dtype=torch.float32)
|
||||
return output, None
|
||||
|
||||
# Tier 3: F.linear (ReplicatedLinear)
|
||||
# Tier 4: F.linear (ReplicatedLinear)
|
||||
if self.out_dtype is not None and x.dtype != self.weight.dtype:
|
||||
x = x.to(self.weight.dtype)
|
||||
output, output_bias = super().forward(x)
|
||||
if self.out_dtype is not None and output.dtype != self.out_dtype:
|
||||
output = output.to(self.out_dtype)
|
||||
return output, output_bias
|
||||
|
||||
|
||||
_FP32_ROUTER_GEMM_MAX_TOKENS = GateLinear.FP32_MAX_TOKENS
|
||||
|
||||
|
||||
def fp32_router_gemm_dispatch_impl(
    x: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
    """
    Choose the router GEMM implementation from the runtime token count.

    Inputs with more than FP32_MAX_TOKENS rows go through F.linear on the
    fp32-upcast activation; everything else uses the specialized fp32 kernel.
    The selection lives inside a custom op because our torch.compile
    integration does not support runtime dispatching on num_tokens.
    """
    num_tokens = x.shape[0]
    if num_tokens > _FP32_ROUTER_GEMM_MAX_TOKENS:
        return torch.nn.functional.linear(x.float(), weight)
    return ops.fp32_router_gemm(x, weight)
|
||||
|
||||
|
||||
def fp32_router_gemm_dispatch_fake(
    x: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
    """Fake implementation: an uninitialized float32 ``[x.shape[0], weight.shape[0]]`` tensor."""
    out_shape = (x.shape[0], weight.shape[0])
    return x.new_empty(out_shape, dtype=torch.float32)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="fp32_router_gemm_dispatch",
|
||||
op_func=fp32_router_gemm_dispatch_impl,
|
||||
fake_impl=fp32_router_gemm_dispatch_fake,
|
||||
)
|
||||
|
||||
@@ -43,10 +43,10 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
fused_moe_make_expert_params_mapping,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@@ -113,12 +113,12 @@ class MiniMaxM2MoE(nn.Module):
|
||||
router_logits_dtype=torch.float32,
|
||||
)
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
self.gate = GateLinear(
|
||||
config.hidden_size,
|
||||
config.num_local_experts,
|
||||
bias=False,
|
||||
params_dtype=torch.float32,
|
||||
quant_config=None,
|
||||
out_dtype=torch.float32,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
@@ -132,7 +132,7 @@ class MiniMaxM2MoE(nn.Module):
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states.to(torch.float32))
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states, router_logits=router_logits
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user