Files
Jee Jee Li 559d6710bf [PERF]MiniMax-M2 gate kernel (#38445)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: qianlihuang <91178480+qianlihuang@users.noreply.github.com>
Co-authored-by: Yiliu Dong <91178480+qianlihuang@users.noreply.github.com>
2026-05-29 18:28:34 -07:00

224 lines
7.2 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
//
// Router GEMM: activation(T) x weight(fp32) -> fp32, H=3072, E=256, M<=32.
// Supports bf16 or fp32 activation; weight is always fp32.
// Adapted from dsv3_router_gemm_float_out.cu.
#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <stdexcept>
#include <string>
// ---------------------------------------------------------------------------
// Load helpers
// ---------------------------------------------------------------------------
// Copies VPT consecutive fp32 weights starting at `ptr` into `dst[0..VPT)`.
// The weight matrix is always fp32, so VPT only selects how many 16-byte
// vector loads are issued:
//   VPT=4 (fp32 activation) -> one float4 load
//   VPT=8 (bf16 activation) -> two float4 loads
// `ptr` is read through a float4 pointer and therefore must be 16-byte
// aligned.
template <int VPT>
__device__ __forceinline__ void load_weight(float const* ptr, float* dst);

// VPT=4: a single 16-byte vector load, then scatter the four lanes.
template <>
__device__ __forceinline__ void load_weight<4>(float const* ptr, float* dst) {
  float4 const quad = *reinterpret_cast<float4 const*>(ptr);
  float const* lanes = reinterpret_cast<float const*>(&quad);
#pragma unroll
  for (int lane = 0; lane < 4; ++lane) {
    dst[lane] = lanes[lane];
  }
}
// VPT=8: two back-to-back 16-byte vector loads covering ptr[0..8).
// `ptr` must be 16-byte aligned because it is read through a float4 pointer.
template <>
__device__ __forceinline__ void load_weight<8>(float const* ptr, float* dst) {
  float4 const* src = reinterpret_cast<float4 const*>(ptr);
#pragma unroll
  for (int chunk = 0; chunk < 2; ++chunk) {
    float4 const quad = src[chunk];
    float* out = dst + 4 * chunk;
    out[0] = quad.x;
    out[1] = quad.y;
    out[2] = quad.z;
    out[3] = quad.w;
  }
}
// Reads VPT consecutive activation values of type T starting at `ptr` and
// stores them, widened to fp32, in `dst[0..VPT)`. Only the specializations
// defined in this file exist; each performs exactly one 16-byte vector load,
// so `ptr` must be 16-byte aligned.
template <typename T, int VPT>
__device__ __forceinline__ void load_activation(T const* ptr, float* dst);

// fp32 activation: the four values are already fp32, so the float4 lanes are
// copied through unchanged.
template <>
__device__ __forceinline__ void load_activation<float, 4>(float const* ptr,
                                                          float* dst) {
  float4 const quad = *reinterpret_cast<float4 const*>(ptr);
  float const* lanes = reinterpret_cast<float const*>(&quad);
#pragma unroll
  for (int lane = 0; lane < 4; ++lane) {
    dst[lane] = lanes[lane];
  }
}
// bf16 activation: one 16-byte uint4 load fetches 8 x bf16, then every
// element is widened to fp32 individually. `ptr` must be 16-byte aligned
// because it is read through a uint4 pointer.
template <>
__device__ __forceinline__ void load_activation<__nv_bfloat16, 8>(
    __nv_bfloat16 const* ptr, float* dst) {
  uint4 const packed = *reinterpret_cast<uint4 const*>(ptr);
  // Reinterpret the 16 loaded bytes as eight bf16 values.
  __nv_bfloat16 const* elems =
      reinterpret_cast<__nv_bfloat16 const*>(&packed);
  dst[0] = __bfloat162float(elems[0]);
  dst[1] = __bfloat162float(elems[1]);
  dst[2] = __bfloat162float(elems[2]);
  dst[3] = __bfloat162float(elems[3]);
  dst[4] = __bfloat162float(elems[4]);
  dst[5] = __bfloat162float(elems[5]);
  dst[6] = __bfloat162float(elems[6]);
  dst[7] = __bfloat162float(elems[7]);
}
// ---------------------------------------------------------------------------
// Kernel
// ---------------------------------------------------------------------------
// Router GEMM kernel. For every token row m in [0, kNumTokens) the block
// computes
//   out[m * kNumExperts + blockIdx.x] =
//       sum_k mat_a[m * kHiddenDim + k] * mat_b[blockIdx.x * kHiddenDim + k]
// with all arithmetic done in fp32.
//
// Template parameters:
//   InputT      : type of activation (float or __nv_bfloat16)
//   kBlockSize  : threads per block (1-D); must be a multiple of 32
//   kNumTokens  : number of activation rows processed by every block
//   kNumExperts : row stride of `out`
//   kHiddenDim  : reduction length; must be a multiple of VPT * kBlockSize
// Weight is always fp32; output is always fp32.
// VPT = 16 / sizeof(InputT): 4 for fp32, 8 for bf16
//
// Launch layout: one block per expert. blockIdx.x is used unchecked as the
// expert index, so gridDim.x is expected to equal kNumExperts (this is what
// invokeFp32RouterGemm launches). Static shared memory:
// kNumTokens * (kBlockSize / 32) floats.
//
// Preconditions: `mat_a` and `mat_b` are read with 16-byte vector loads and
// therefore must be 16-byte aligned.
//
// On SM90+ the kernel takes part in programmatic dependent launch: it waits
// on its grid dependencies before the first global read and signals its
// dependents after the result has been written.
template <typename InputT, int kBlockSize, int kNumTokens, int kNumExperts,
          int kHiddenDim>
__global__ __launch_bounds__(kBlockSize, 1) void fp32_router_gemm_kernel(
    float* out, InputT const* mat_a, float const* mat_b) {
  constexpr int VPT = 16 / sizeof(InputT);
  constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
  constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration;
  constexpr int kWarpSize = 32;
  constexpr int kNumWarps = kBlockSize / kWarpSize;

  // The full-mask shuffles below assume only complete warps exist.
  static_assert(kBlockSize > 0 && kBlockSize % kWarpSize == 0,
                "kBlockSize must be a positive multiple of the warp size");
  // Without this, the truncating division above would silently drop the
  // tail of the hidden dimension from the dot product.
  static_assert(kHiddenDim > 0 && kHiddenDim % k_elems_per_k_iteration == 0,
                "kHiddenDim must be a positive multiple of VPT * kBlockSize");

  int const n_idx = blockIdx.x;
  int const tid = threadIdx.x;
  int const warpId = tid / kWarpSize;
  int const laneId = tid % kWarpSize;

  // One fp32 partial dot product per token, private to this thread.
  float acc[kNumTokens] = {};
  __shared__ float sm_reduction[kNumTokens][kNumWarps];

  // Weight row owned by this block.
  float const* b_col = mat_b + n_idx * kHiddenDim;

  // Offsets are computed before the dependency wait so that this pure
  // register work can overlap with the tail of the preceding kernel.
  int k_bases[k_iterations];
#pragma unroll
  for (int ki = 0; ki < k_iterations; ki++) {
    k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
  }

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // Wait for prerequisite grids. The "memory" clobber keeps the compiler
  // from moving the global loads below above this instruction.
  asm volatile("griddepcontrol.wait;" ::: "memory");
#endif

  for (int ki = 0; ki < k_iterations; ki++) {
    int const k_base = k_bases[ki];

    // Each weight slice is loaded once and reused for every token.
    float b_float[VPT];
    load_weight<VPT>(b_col + k_base, b_float);

#pragma unroll
    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
      float a_float[VPT];
      load_activation<InputT, VPT>(mat_a + m_idx * kHiddenDim + k_base,
                                   a_float);
#pragma unroll
      for (int k = 0; k < VPT; k++) {
        acc[m_idx] += a_float[k] * b_float[k];
      }
    }
  }

  // Warp-level butterfly reduction (xor offsets 16, 8, 4, 2, 1); lane 0 of
  // each warp publishes the warp total to shared memory.
#pragma unroll
  for (int m = 0; m < kNumTokens; m++) {
    float sum = acc[m];
#pragma unroll
    for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
      sum += __shfl_xor_sync(0xffffffff, sum, offset);
    }
    if (laneId == 0) sm_reduction[m][warpId] = sum;
  }

  // Every thread reaches this barrier; it orders the shared-memory writes
  // above before the reads by thread 0 below.
  __syncthreads();

  // Thread 0 folds the per-warp totals and writes one logit per token.
  if (tid == 0) {
#pragma unroll
    for (int m = 0; m < kNumTokens; m++) {
      float final_sum = 0.0f;
#pragma unroll
      for (int w = 0; w < kNumWarps; w++) final_sum += sm_reduction[m][w];
      out[m * kNumExperts + n_idx] = final_sum;
    }
  }

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
// ---------------------------------------------------------------------------
// Launcher
// ---------------------------------------------------------------------------
// Launches fp32_router_gemm_kernel on `stream`.
//
// Template parameters:
//   InputT      : activation element type forwarded to the kernel
//   kNumTokens  : number of activation rows
//   kNumExperts : number of blocks launched (one per expert)
//   kHiddenDim  : reduction length
// Parameters:
//   output : device pointer the kernel writes fp32 results to
//   mat_a  : device pointer to the activations
//   mat_b  : device pointer to the fp32 weights
//   stream : stream the kernel is enqueued on
//
// The launch uses kNumExperts blocks of 128 threads, no dynamic shared
// memory, and sets the programmatic-stream-serialization launch attribute.
// The call is asynchronous: it only enqueues the kernel.
//
// Throws std::runtime_error if cudaLaunchKernelEx reports an error, so that a
// rejected launch is not mistaken for a successful one.
template <typename InputT, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeFp32RouterGemm(float* output, InputT const* mat_a,
                          float const* mat_b, cudaStream_t stream) {
  constexpr int kBlockSize = 128;

  // Zero-initialize so that no field of the attribute or the config is ever
  // passed to the runtime holding an indeterminate value.
  cudaLaunchAttribute attrs[1] = {};
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = 1;

  cudaLaunchConfig_t config = {};
  config.gridDim = dim3(kNumExperts);
  config.blockDim = dim3(kBlockSize);
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  config.attrs = attrs;
  config.numAttrs = 1;

  cudaError_t const status = cudaLaunchKernelEx(
      &config,
      fp32_router_gemm_kernel<InputT, kBlockSize, kNumTokens, kNumExperts,
                              kHiddenDim>,
      output, mat_a, mat_b);
  if (status != cudaSuccess) {
    throw std::runtime_error(
        std::string("invokeFp32RouterGemm: kernel launch failed: ") +
        cudaGetErrorString(status));
  }
}
// ---------------------------------------------------------------------------
// Explicit instantiations: M=1..32, E=256, H=3072, for both input types
// ---------------------------------------------------------------------------
// Emits the explicit instantiation of the launcher for activation type T and
// token count M, with E=256 experts and hidden size H=3072.
#define INSTANTIATE(T, M)                                  \
  template void invokeFp32RouterGemm<T, (M), 256, 3072>(   \
      float*, T const*, float const*, cudaStream_t);

// Instantiates the eight token counts BASE+1 .. BASE+8 for type T.
#define INSTANTIATE_OCTET(T, BASE) \
  INSTANTIATE(T, (BASE) + 1)       \
  INSTANTIATE(T, (BASE) + 2)       \
  INSTANTIATE(T, (BASE) + 3)       \
  INSTANTIATE(T, (BASE) + 4)       \
  INSTANTIATE(T, (BASE) + 5)       \
  INSTANTIATE(T, (BASE) + 6)       \
  INSTANTIATE(T, (BASE) + 7)       \
  INSTANTIATE(T, (BASE) + 8)

// Instantiates every supported token count M = 1..32 for type T.
#define INSTANTIATE_ALL(T)  \
  INSTANTIATE_OCTET(T, 0)   \
  INSTANTIATE_OCTET(T, 8)   \
  INSTANTIATE_OCTET(T, 16)  \
  INSTANTIATE_OCTET(T, 24)

INSTANTIATE_ALL(float)
INSTANTIATE_ALL(__nv_bfloat16)

// The helper macros are local to this file; do not leak them.
#undef INSTANTIATE_ALL
#undef INSTANTIATE_OCTET
#undef INSTANTIATE