[Feat] DeepSeek V4 Rebased (#40860)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: qizixi <zixi@inferact.ai>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <yongye@inferact.ai>
Co-authored-by: Simon Mo <simon@inferact.ai>
Co-authored-by: Bugen Zhao <i@bugenzhao.com>
Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roy Wang <yasong.wang@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Zhewen Li <jerven.vllm@gmail.com>
Co-authored-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: khluu <khluu000@gmail.com>
Co-authored-by: qizixi <zixi@inferact.ai>
Co-authored-by: Zhewen Li <zhewenli@inferact.ai>
This commit is contained in:
Yifan Qiao
2026-04-26 18:31:08 -07:00
committed by GitHub
parent 32e45636e3
commit 4d51588e23
150 changed files with 16312 additions and 716 deletions
+5 -2
View File
@@ -310,7 +310,9 @@ set(VLLM_EXT_SRC
"csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC "csrc/minimax_reduce_rms_kernel.cu")
list(APPEND VLLM_EXT_SRC
"csrc/minimax_reduce_rms_kernel.cu"
"csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
@@ -1051,7 +1053,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
"csrc/moe/grouped_topk_kernels.cu"
"csrc/moe/router_gemm.cu")
"csrc/moe/router_gemm.cu"
"csrc/moe/topk_softplus_sqrt_kernels.cu")
endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
+6 -1
View File
@@ -20,7 +20,7 @@ else()
FetchContent_Declare(
deepgemm
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM.git
GIT_TAG 477618cd51baffca09c4b0b87e97c03fe827ef03
GIT_TAG 891d57b4db1071624b5c8fa0d1e51cb317fa709f
GIT_SUBMODULES "third-party/cutlass" "third-party/fmt"
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
@@ -120,6 +120,11 @@ if(DEEPGEMM_ARCHS)
COMPONENT _deep_gemm_C
FILES_MATCHING PATTERN "*.py")
install(DIRECTORY "${deepgemm_SOURCE_DIR}/deep_gemm/mega/"
DESTINATION vllm/third_party/deep_gemm/mega
COMPONENT _deep_gemm_C
FILES_MATCHING PATTERN "*.py")
# Generate envs.py (normally generated by DeepGEMM's setup.py build step)
file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/deep_gemm_envs.py"
"# Pre-installed environment variables\npersistent_envs = dict()\n")
+1 -1
View File
@@ -19,7 +19,7 @@ else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_TAG 692917b1cda61b93ac9ee2d846ec54e75afe87b1
GIT_TAG a6ec2ba7bd0a7dff98b3f4d3e6b52b159c48d78b
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
+6 -1
View File
@@ -178,7 +178,12 @@ void rotary_embedding_gptj_impl(
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox) {
torch::Tensor& cos_sin_cache, bool is_neox,
int64_t rope_dim_offset, bool inverse) {
TORCH_CHECK(rope_dim_offset == 0,
"rope_dim_offset != 0 is not supported on CPU");
TORCH_CHECK(!inverse, "inverse rotary embedding is not supported on CPU");
int num_tokens = positions.numel();
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
+2 -1
View File
@@ -263,7 +263,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
" Tensor cos_sin_cache, bool is_neox, int "
"rope_dim_offset=0, bool inverse=False) -> ()");
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
// Quantization
@@ -0,0 +1,477 @@
/*
* SPDX-License-Identifier: Apache-2.0
* SPDX-FileCopyrightText: Copyright contributors to the vLLM project
*
* Horizontally-fused DeepseekV4-MLA kernel:
* - Q side: per-head RMSNorm (no weight) + GPT-J RoPE on last ROPE_DIM
* - KV side: GPT-J RoPE on last ROPE_DIM + UE8M0 FP8 quant on NoPE + paged
* cache insert
*
* Structured after `applyMLARopeAndAssignQKVKernelGeneration` in
* TensorRT-LLM's mlaKernels.cu: one kernel, one grid, with head-slot
* dispatch choosing Q vs KV work per warp. The per-warp RMSNorm/RoPE
* skeleton is adapted from vllm-deepseek_v4's existing
* `fusedQKNormRopeKernel` (csrc/fused_qknorm_rope_kernel.cu).
*
* Assumptions (hard-coded for DeepseekV4 attention):
* HEAD_DIM = 512
* ROPE_DIM = 64 (RoPE applied to dims [NOPE_DIM, HEAD_DIM))
* NOPE_DIM = 448
* QUANT_BLOCK = 64 (UE8M0 FP8 quant block)
* FP8_MAX = 448.0f
* is_neox=false (GPT-J interleaved pairs)
* cos_sin_cache layout [max_pos, rope_dim] = cos || sin (cos first, sin
* second along last dim; each half is rope_dim/2 = 32 values)
*
* Cache layout per paged-cache block (block_size tokens):
* [0, bs*576): token data, 448 fp8 + 128 bf16 each
* [bs*576, bs*576 + bs*8): UE8M0 scales, 7 real + 1 pad per token
*/
#include <cmath>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <type_traits>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/cuda.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "type_convert.cuh"
#ifndef FINAL_MASK
#define FINAL_MASK 0xffffffffu
#endif
namespace vllm {
namespace deepseek_v4_fused_ops {
namespace {
inline int getSMVersion() {
auto* props = at::cuda::getCurrentDeviceProperties();
return props->major * 10 + props->minor;
}
} // namespace
// ────────────────────────────────────────────────────────────────────────────
// Constants
// ────────────────────────────────────────────────────────────────────────────
constexpr int kHeadDim = 512;
constexpr int kRopeDim = 64;
constexpr int kNopeDim = kHeadDim - kRopeDim; // 448
constexpr int kQuantBlock = 64;
constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7
constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad)
constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576
constexpr float kFp8Max = 448.0f;
// Per-warp layout: 32 lanes × 16 elems/lane = 512 elems = HEAD_DIM.
constexpr int kNumLanes = 32;
constexpr int kElemsPerLane = kHeadDim / kNumLanes; // 16
// ────────────────────────────────────────────────────────────────────────────
// Small inline helpers
// ────────────────────────────────────────────────────────────────────────────
__device__ __forceinline__ float warp4MaxAbs(float val) {
// Reduce absolute max across 4 consecutive lanes (lane id & 3 group).
float peer = __shfl_xor_sync(FINAL_MASK, val, 1);
val = fmaxf(val, peer);
peer = __shfl_xor_sync(FINAL_MASK, val, 2);
val = fmaxf(val, peer);
return val;
}
template <typename T>
__device__ __forceinline__ float warpSum(float val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
}
return val;
}
// ────────────────────────────────────────────────────────────────────────────
// Kernel
// ────────────────────────────────────────────────────────────────────────────
//
// Grid: 1D, gridDim.x = ceil(num_tokens_full * (num_heads_q + 1) /
// warps_per_block) Block: blockDim.x = 256 threads (8 warps per block) Each
// warp handles one (token, head_slot) pair. head_slot < num_heads_q →
// Q branch (RMSNorm + RoPE, in place) head_slot == num_heads_q → KV
// branch (RoPE + UE8M0 quant + insert)
//
// With DP padding, q/kv/position_ids can have more rows than slot_mapping.
// The Q branch covers all `num_tokens_full` rows (downstream attention uses
// them). The KV branch only inserts the first `num_tokens_insert` tokens
// (= slot_mapping length) into the paged cache.
//
template <typename scalar_t_in>
__global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel(
scalar_t_in* __restrict__ q_inout, // [N, H, 512] bf16, in place
scalar_t_in const* __restrict__ kv_in, // [N, 512] bf16
uint8_t* __restrict__ k_cache, // [num_blocks, block_stride]
int64_t const* __restrict__ slot_mapping, // [num_tokens_insert] i64
int64_t const* __restrict__ position_ids, // [N] i64
float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32
float const eps,
int const num_tokens_full, // = q.size(0) = kv.size(0)
int const num_tokens_insert, // = slot_mapping.size(0), ≤ num_tokens_full
int const num_heads_q, // H
int const cache_block_size, // tokens per paged-cache block
int const kv_block_stride) { // bytes per paged-cache block
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
// BF16 _typeConvert specialization is unavailable on pre-Ampere. The
// DeepseekV4 kernel only runs with bf16 inputs in practice, so compile a
// no-op stub for sm_70/sm_75 to keep multi-arch builds happy.
if constexpr (std::is_same_v<scalar_t_in, c10::BFloat16>) {
return;
} else {
#endif
using Converter = vllm::_typeConvert<scalar_t_in>;
int const warpsPerBlock = blockDim.x / 32;
int const warpId = threadIdx.x / 32;
int const laneId = threadIdx.x % 32;
int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;
int const total_slots_per_token = num_heads_q + 1;
int const tokenIdx = globalWarpIdx / total_slots_per_token;
int const slotIdx = globalWarpIdx % total_slots_per_token;
if (tokenIdx >= num_tokens_full) return;
bool const isKV = (slotIdx == num_heads_q);
// KV branch: skip DP-padded tokens (no slot reserved for them).
if (isKV && tokenIdx >= num_tokens_insert) return;
// PDL: wait for predecessor kernel (upstream q/kv producer) to signal
// before touching any global memory. No-op when PDL is not enabled on
// the launch. The CUDA runtime wrapper emits the griddepcontrol.wait
// PTX with the required memory clobber internally.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaGridDependencySynchronize();
#endif
// Dim range this lane owns within the 512-wide head.
int const dim_base = laneId * kElemsPerLane; // in [0, 512) step 16
// ── Load 16 bf16 → 16 fp32 registers (one 16-byte + one 16-byte LDG) ────
float elements[kElemsPerLane];
float sumOfSquares = 0.0f;
scalar_t_in const* src_ptr;
if (isKV) {
src_ptr = kv_in + static_cast<int64_t>(tokenIdx) * kHeadDim + dim_base;
} else {
int64_t const q_row_offset =
(static_cast<int64_t>(tokenIdx) * num_heads_q + slotIdx) * kHeadDim +
dim_base;
src_ptr = q_inout + q_row_offset;
}
// Two 16-byte loads per thread (8 bf16 each). Use uint4 as the vector
// type and bitcast to scalar_t_in packed pairs for conversion.
uint4 v0 = *reinterpret_cast<uint4 const*>(src_ptr);
uint4 v1 = *reinterpret_cast<uint4 const*>(src_ptr + 8);
{
typename Converter::packed_hip_type const* p0 =
reinterpret_cast<typename Converter::packed_hip_type const*>(&v0);
typename Converter::packed_hip_type const* p1 =
reinterpret_cast<typename Converter::packed_hip_type const*>(&v1);
// Each packed_hip_type holds 2 bf16 → 4 packed = 8 elems per uint4.
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 f2 = Converter::convert(p0[i]);
elements[2 * i] = f2.x;
elements[2 * i + 1] = f2.y;
}
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 f2 = Converter::convert(p1[i]);
elements[8 + 2 * i] = f2.x;
elements[8 + 2 * i + 1] = f2.y;
}
}
// ── Q branch: RMSNorm with no weight (has_weight=False) ─────────────────
// Variance + rsqrt + multiply all in fp32, no intermediate bf16 round.
// The downstream bf16 round only happens at the final store.
if (!isKV) {
#pragma unroll
for (int i = 0; i < kElemsPerLane; i++) {
sumOfSquares += elements[i] * elements[i];
}
sumOfSquares = warpSum<float>(sumOfSquares);
float const rms_rcp =
rsqrtf(sumOfSquares / static_cast<float>(kHeadDim) + eps);
#pragma unroll
for (int i = 0; i < kElemsPerLane; i++) {
elements[i] = elements[i] * rms_rcp;
}
}
// ── GPT-J RoPE on dims [NOPE_DIM, HEAD_DIM) ─────────────────────────────
// All math in fp32. cos_sin_cache is loaded as fp32 (its native storage).
bool const is_rope_lane = dim_base >= kNopeDim;
if (is_rope_lane) {
int64_t const pos = position_ids[tokenIdx];
constexpr int kHalfRope = kRopeDim / 2; // 32
float const* cos_ptr = cos_sin_cache + pos * kRopeDim;
float const* sin_ptr = cos_ptr + kHalfRope;
int const rope_local_base = dim_base - kNopeDim; // in [0, 64) step 16
#pragma unroll
for (int p = 0; p < kElemsPerLane / 2; p++) {
int const pair_dim = rope_local_base + 2 * p;
int const half_idx = pair_dim / 2;
float const cos_v = VLLM_LDG(cos_ptr + half_idx);
float const sin_v = VLLM_LDG(sin_ptr + half_idx);
float const x_even = elements[2 * p];
float const x_odd = elements[2 * p + 1];
elements[2 * p] = x_even * cos_v - x_odd * sin_v;
elements[2 * p + 1] = x_even * sin_v + x_odd * cos_v;
}
}
// ═══════════════════════════════════════════════════════════════════════
// Q branch: cast to bf16 and store back in place.
// ═══════════════════════════════════════════════════════════════════════
if (!isKV) {
uint4 out0, out1;
typename Converter::packed_hip_type* po0 =
reinterpret_cast<typename Converter::packed_hip_type*>(&out0);
typename Converter::packed_hip_type* po1 =
reinterpret_cast<typename Converter::packed_hip_type*>(&out1);
#pragma unroll
for (int i = 0; i < 4; i++) {
po0[i] = Converter::convert(
make_float2(elements[2 * i], elements[2 * i + 1]));
}
#pragma unroll
for (int i = 0; i < 4; i++) {
po1[i] = Converter::convert(
make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1]));
}
scalar_t_in* dst =
q_inout +
(static_cast<int64_t>(tokenIdx) * num_heads_q + slotIdx) * kHeadDim +
dim_base;
*reinterpret_cast<uint4*>(dst) = out0;
*reinterpret_cast<uint4*>(dst + 8) = out1;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaTriggerProgrammaticLaunchCompletion();
#endif
return;
}
// ═══════════════════════════════════════════════════════════════════════
// KV branch.
// ═══════════════════════════════════════════════════════════════════════
int64_t const slot_id = slot_mapping[tokenIdx];
if (slot_id < 0) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaTriggerProgrammaticLaunchCompletion();
#endif
return;
}
int64_t const block_idx = slot_id / cache_block_size;
int64_t const pos_in_block = slot_id % cache_block_size;
uint8_t* block_base =
k_cache + block_idx * static_cast<int64_t>(kv_block_stride);
uint8_t* token_fp8_ptr = block_base + pos_in_block * kTokenDataBytes;
uint8_t* token_bf16_ptr = token_fp8_ptr + kNopeDim;
uint8_t* token_scale_ptr =
block_base + static_cast<int64_t>(cache_block_size) * kTokenDataBytes +
pos_in_block * kScaleBytesPerToken;
// Round K to bf16 first, matching the unfused reference path where K is
// materialized as bf16 before K quantization. absmax, clamp, and FP8
// quant below all run on these bf16-rounded values.
#pragma unroll
for (int i = 0; i < kElemsPerLane; i++) {
elements[i] = Converter::convert(Converter::convert(elements[i]));
}
// Per-quant-block absmax must be computed by ALL 32 lanes (warp-collective
// shuffle requires full participation). RoPE lanes contribute garbage,
// but their values are gated out below via `!is_rope_lane`.
float local_absmax = 0.0f;
#pragma unroll
for (int i = 0; i < kElemsPerLane; i++) {
local_absmax = fmaxf(local_absmax, fabsf(elements[i]));
}
float const absmax = fmaxf(warp4MaxAbs(local_absmax), 1e-4f);
float const exponent = ceilf(log2f(absmax / kFp8Max));
float const inv_scale = exp2f(-exponent);
if (!is_rope_lane) {
// ── NoPE lane: UE8M0 FP8 quant ───────────────────────────────────────
uint8_t out_bytes[kElemsPerLane];
#pragma unroll
for (int i = 0; i < kElemsPerLane; i++) {
float scaled = elements[i] * inv_scale;
scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max);
__nv_fp8_storage_t s =
__nv_cvt_float_to_fp8(scaled, __NV_SATFINITE, __NV_E4M3);
out_bytes[i] = static_cast<uint8_t>(s);
}
// One 16-byte STG per lane.
*reinterpret_cast<uint4*>(token_fp8_ptr + dim_base) =
*reinterpret_cast<uint4 const*>(out_bytes);
// Lane (4k) of each 4-lane group writes the scale byte for block k<7.
if ((laneId & 3) == 0) {
int const q_block_idx = laneId >> 2; // 0..6 for NoPE lanes
float encoded = fmaxf(fminf(exponent + 127.0f, 255.0f), 0.0f);
token_scale_ptr[q_block_idx] = static_cast<uint8_t>(encoded);
}
// Lane 0 also writes the padding byte at index 7.
if (laneId == 0) {
token_scale_ptr[kNumQuantBlocks] = 0; // pad
}
} else {
// ── RoPE lane: cast back to bf16 and store to cache bf16 tail ────────
uint4 out0, out1;
typename Converter::packed_hip_type* po0 =
reinterpret_cast<typename Converter::packed_hip_type*>(&out0);
typename Converter::packed_hip_type* po1 =
reinterpret_cast<typename Converter::packed_hip_type*>(&out1);
#pragma unroll
for (int i = 0; i < 4; i++) {
po0[i] = Converter::convert(
make_float2(elements[2 * i], elements[2 * i + 1]));
}
#pragma unroll
for (int i = 0; i < 4; i++) {
po1[i] = Converter::convert(
make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1]));
}
int const rope_local_base = dim_base - kNopeDim; // in [0, 64)
scalar_t_in* bf16_dst =
reinterpret_cast<scalar_t_in*>(token_bf16_ptr) + rope_local_base;
*reinterpret_cast<uint4*>(bf16_dst) = out0;
*reinterpret_cast<uint4*>(bf16_dst + 8) = out1;
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaTriggerProgrammaticLaunchCompletion();
#endif
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
}
#endif
}
// ────────────────────────────────────────────────────────────────────────────
// Launch wrapper
// ────────────────────────────────────────────────────────────────────────────
template <typename scalar_t_in>
void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
scalar_t_in* q_inout, scalar_t_in const* kv_in, uint8_t* k_cache,
int64_t const* slot_mapping, int64_t const* position_ids,
float const* cos_sin_cache, float const eps, int const num_tokens_full,
int const num_tokens_insert, int const num_heads_q,
int const cache_block_size, int const kv_block_stride,
cudaStream_t stream) {
constexpr int kBlockSize = 256;
constexpr int kWarpsPerBlock = kBlockSize / 32;
int64_t const total_warps =
static_cast<int64_t>(num_tokens_full) * (num_heads_q + 1);
int const grid =
static_cast<int>((total_warps + kWarpsPerBlock - 1) / kWarpsPerBlock);
// PDL: enable programmatic stream serialization whenever the hardware
// supports it (SM90+). On pre-Hopper GPUs the attribute is unavailable,
// so leave numAttrs = 0 and launch as a regular kernel.
static int const sm_version = getSMVersion();
// Host-side guard: the device kernel body is compiled as a no-op for
// bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert<BFloat16> is
// unavailable there. Refuse the launch loudly instead of silently
// skipping the work.
TORCH_CHECK(
sm_version >= 80,
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert requires sm_80+ "
"(Ampere or newer); got sm_",
sm_version);
cudaLaunchConfig_t config;
config.gridDim = dim3(grid);
config.blockDim = dim3(kBlockSize);
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;
config.attrs = attrs;
config.numAttrs = (sm_version >= 90) ? 1 : 0;
cudaLaunchKernelEx(
&config, fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel<scalar_t_in>,
q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps,
num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size,
kv_block_stride);
}
} // namespace deepseek_v4_fused_ops
} // namespace vllm
// ────────────────────────────────────────────────────────────────────────────
// Torch op wrapper
// ────────────────────────────────────────────────────────────────────────────
void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::Tensor& q, // [N, H, 512] bf16, in place
torch::Tensor const& kv, // [N, 512] bf16 (read-only)
torch::Tensor& k_cache, // [num_blocks, block_bytes] uint8
torch::Tensor const& slot_mapping, // [N] int64
torch::Tensor const& position_ids, // [N] int64
torch::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16
double eps, int64_t cache_block_size) {
TORCH_CHECK(q.is_cuda() && q.is_contiguous(), "q must be contiguous CUDA");
TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA");
TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA");
TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64,
"slot_mapping must be int64 CUDA");
TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64,
"position_ids must be int64 CUDA");
TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA");
TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]");
TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
TORCH_CHECK(q.dtype() == kv.dtype(), "q and kv dtype must match");
TORCH_CHECK(k_cache.dtype() == torch::kUInt8, "k_cache must be uint8");
TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
"cos_sin_cache shape [max_pos, 64]");
TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32,
"cos_sin_cache must be float32");
// With DP padding, slot_mapping can be shorter than q/kv/positions.
// Q-norm+RoPE runs on all q.size(0) rows (downstream attention uses them);
// KV quant+insert runs only on the first slot_mapping.size(0) rows.
int const num_tokens_full = static_cast<int>(q.size(0));
int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
TORCH_CHECK(static_cast<int>(kv.size(0)) == num_tokens_full &&
static_cast<int>(position_ids.size(0)) == num_tokens_full,
"q/kv/position_ids row counts must match");
TORCH_CHECK(num_tokens_insert <= num_tokens_full,
"slot_mapping must not exceed q row count");
int const num_heads_q = static_cast<int>(q.size(1));
int const cache_block_size_i = static_cast<int>(cache_block_size);
int const kv_block_stride = static_cast<int>(k_cache.stride(0));
at::cuda::OptionalCUDAGuard device_guard(device_of(q));
auto stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_HALF_TYPES(
q.scalar_type(), "fused_deepseek_v4_qnorm_rope_kv_insert", [&] {
using qkv_scalar_t = scalar_t;
vllm::deepseek_v4_fused_ops::
launchFusedDeepseekV4QNormRopeKVRopeQuantInsert<qkv_scalar_t>(
reinterpret_cast<qkv_scalar_t*>(q.data_ptr()),
reinterpret_cast<qkv_scalar_t const*>(kv.data_ptr()),
reinterpret_cast<uint8_t*>(k_cache.data_ptr()),
reinterpret_cast<int64_t const*>(slot_mapping.data_ptr()),
reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
cos_sin_cache.data_ptr<float>(), static_cast<float>(eps),
num_tokens_full, num_tokens_insert, num_heads_q,
cache_block_size_i, kv_block_stride, stream);
});
}
+15 -7
View File
@@ -77,7 +77,8 @@ __global__ void rms_norm_kernel(
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
float x = static_cast<float>(src1.val[j]);
dst.val[j] = ((scalar_t)(x * s_variance)) * src2.val[j];
float w = static_cast<float>(src2.val[j]);
dst.val[j] = static_cast<scalar_t>(x * s_variance * w);
}
v_out[i] = dst;
}
@@ -134,10 +135,17 @@ fused_add_rms_norm_kernel(
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
int64_t strided_id = blockIdx.x * vec_input_stride + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
input_v[strided_id] = temp;
_f16Vec<scalar_t, width> res = residual_v[id];
_f16Vec<scalar_t, width> w = weight_v[idx];
_f16Vec<scalar_t, width> out;
using Converter = _typeConvert<scalar_t>;
#pragma unroll
for (int j = 0; j < width; ++j) {
float x = Converter::convert(res.data[j]);
float wf = Converter::convert(w.data[j]);
out.data[j] = Converter::convert(x * s_variance * wf);
}
input_v[strided_id] = out;
}
}
@@ -174,8 +182,8 @@ fused_add_rms_norm_kernel(
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * input_stride + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
float w = (float)weight[idx];
input[blockIdx.x * input_stride + idx] = (scalar_t)(x * s_variance * w);
}
}
+28 -10
View File
@@ -65,9 +65,16 @@ __global__ void rms_norm_static_fp8_quant_kernel(
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
float x = static_cast<float>(src1.val[j]);
float const out_norm = ((scalar_t)(x * s_variance)) * src2.val[j];
float w = static_cast<float>(src2.val[j]);
// Round normalized result through scalar_t to match the precision of the
// unfused composite (rms_norm writes scalar_t, then
// static_scaled_fp8_quant re-loads it as float before FP8 conversion).
// Without this round, the fused path is strictly more accurate and
// disagrees with the composite at exact E4M3 quantization tie boundaries.
scalar_t out_norm = static_cast<scalar_t>(x * s_variance * w);
out[blockIdx.x * hidden_size + idx * VEC_SIZE + j] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
scaled_fp8_conversion<true, fp8_type>(static_cast<float>(out_norm),
scale_inv);
}
}
}
@@ -127,13 +134,21 @@ fused_add_rms_norm_static_fp8_quant_kernel(
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
_f16Vec<scalar_t, width> res = residual_v[id];
_f16Vec<scalar_t, width> w = weight_v[idx];
using Converter = _typeConvert<scalar_t>;
using HipT = typename Converter::hip_type;
#pragma unroll
for (int i = 0; i < width; ++i) {
out[id * width + i] =
scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv);
float x = Converter::convert(res.data[i]);
float wf = Converter::convert(w.data[i]);
// See note in rms_norm_static_fp8_quant_kernel: round through scalar_t
// to match the unfused composite path at FP8 boundaries. We use the
// backend's hip_type for the intermediate since c10::Half/BFloat16 has
// ambiguous conversions on CUDA and no implicit conversion on ROCm.
HipT out_norm_h = Converter::convert(x * s_variance * wf);
out[id * width + i] = scaled_fp8_conversion<true, fp8_type>(
Converter::convert(out_norm_h), scale_inv);
}
}
}
@@ -176,9 +191,12 @@ fused_add_rms_norm_static_fp8_quant_kernel(
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
float w = (float)weight[idx];
// See note in rms_norm_static_fp8_quant_kernel: round through scalar_t
// to match the unfused composite path at FP8 boundaries.
scalar_t out_norm = static_cast<scalar_t>(x * s_variance * w);
out[blockIdx.x * hidden_size + idx] = scaled_fp8_conversion<true, fp8_type>(
static_cast<float>(out_norm), scale_inv);
}
}
+9
View File
@@ -12,6 +12,15 @@ void topk_sigmoid(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& gating_output, bool renormalize,
std::optional<torch::Tensor> bias);
void topk_softplus_sqrt(torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output, bool renormalize,
double routed_scaling_factor,
const c10::optional<torch::Tensor>& correction_bias,
const c10::optional<torch::Tensor>& input_ids,
const c10::optional<torch::Tensor>& tid2eid);
void moe_sum(torch::Tensor& input, torch::Tensor& output);
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
+715
View File
@@ -0,0 +1,715 @@
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
* Copyright (c) 2024, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <type_traits>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"
#include "../cub_helpers.h"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
typedef __hip_bfloat16 __nv_bfloat16;
typedef __hip_bfloat162 __nv_bfloat162;
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
namespace vllm {
namespace moe {
/// Aligned array type
template <typename T,
/// Number of elements in the array
int N,
/// Alignment requirement in bytes
int Alignment = sizeof(T) * N>
struct alignas(Alignment) AlignedArray {
T data[N];
};
template <typename T>
__device__ __forceinline__ float toFloat(T value) {
if constexpr (std::is_same_v<T, float>) {
return value;
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
return __bfloat162float(value);
} else if constexpr (std::is_same_v<T, __half>) {
return __half2float(value);
}
}
#define FINAL_MASK 0xffffffff
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
return val;
}
// ====================== TopK softplus_sqrt things
// ===============================
/*
A Top-K gating softplus_sqrt written to exploit when the number of experts in
the MoE layers are a small power of 2. This allows us to cleanly share the
rows among the threads in a single warp and eliminate communication between
warps (so no need to use shared mem).
It fuses the sigmoid, max and argmax into a single kernel.
Limitations:
1) This implementation is optimized for when the number of experts is a small
power of 2. Additionally it also supports when number of experts is multiple
of 64 which is still faster than the computing sigmoid and topK separately
(only tested on CUDA yet). 2) This implementation assumes k is small, but will
work for any k.
*/
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG,
int WARP_SIZE_PARAM, bool USE_HASH, typename IndType,
typename InputType = float>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
void topkGatingSoftplusSqrt(
const InputType* input, const bool* finished, float* output,
const int num_rows, IndType* indices, int* source_rows, const int k,
const int start_expert, const int end_expert, const bool renormalize,
double routed_scaling_factor, const float* correction_bias,
const IndType* input_ids, const IndType* tid2eid) {
static_assert(std::is_same_v<InputType, float> ||
std::is_same_v<InputType, __nv_bfloat16> ||
std::is_same_v<InputType, __half>,
"InputType must be float, __nv_bfloat16, or __half");
// We begin by enforcing compile time assertions and setting up compile time
// constants.
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG),
"BYTES_PER_LDG must be power of 2");
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
// Number of bytes each thread pulls in per load
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
if constexpr (std::is_same_v<InputType, __nv_bfloat16> ||
std::is_same_v<InputType, __half>) {
static_assert(ELTS_PER_LDG == 1 || ELTS_PER_LDG % 2 == 0,
"ELTS_PER_LDG must be 1 or even for 16-bit conversion");
}
// Restrictions based on previous section.
static_assert(
VPT % ELTS_PER_LDG == 0,
"The elements per thread must be a multiple of the elements per ldg");
static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0,
"The threads per row must cleanly divide the threads per warp");
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW),
"THREADS_PER_ROW must be power of 2");
static_assert(THREADS_PER_ROW <= WARP_SIZE_PARAM,
"THREADS_PER_ROW can be at most warp size");
// We have NUM_EXPERTS elements per row. We specialize for small #experts
static constexpr int ELTS_PER_WARP = WARP_SIZE_PARAM * VPT;
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
// Restrictions for previous section.
static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0,
"The elts per row must cleanly divide the total elt per warp");
// ===================== From this point, we finally start computing run-time
// variables. ========================
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a
// block contains WARPS_PER_CTA warps. This, each block processes a chunk of
// rows. We start by computing the start row for each block.
const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
// Now, using the base row per thread block, we compute the base row per warp.
const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
// The threads in a warp are split into sub-groups that will work on a row.
// We compute row offset for each thread sub-group
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
const int thread_row = warp_base_row + thread_row_in_warp;
// Threads with indices out of bounds should early exit here.
if (thread_row >= num_rows) {
return;
}
const bool row_is_active = finished ? !finished[thread_row] : true;
// We finally start setting up the read pointers for each thread. First, each
// thread jumps to the start of the row it will read.
const InputType* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
// Now, we compute the group each thread belong to in order to determine the
// first column to start loads.
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const InputType* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
// Finally, we pull in the data from global mem
float row_chunk[VPT];
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
// NOTE(zhuhaoran): dispatch different input types loading, BF16/FP16 convert
// to float
if constexpr (std::is_same_v<InputType, float>) {
using VecType = AlignedArray<float, ELTS_PER_LDG>;
VecType* row_chunk_vec_ptr = reinterpret_cast<VecType*>(&row_chunk);
const VecType* vec_thread_read_ptr =
reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
}
} else if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
if constexpr (ELTS_PER_LDG >= 2) {
using VecType = AlignedArray<__nv_bfloat16, ELTS_PER_LDG>;
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
const VecType* vec_thread_read_ptr =
reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
row_chunk_f2[base_idx_f2 + jj] = __bfloat1622float2(
*reinterpret_cast<const __nv_bfloat162*>(vec.data + jj * 2));
}
}
} else { // ELTS_PER_LDG == 1
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
const __nv_bfloat16* scalar_ptr =
thread_read_ptr + ii * THREADS_PER_ROW;
row_chunk[ii] = __bfloat162float(*scalar_ptr);
}
}
} else if constexpr (std::is_same_v<InputType, __half>) {
if constexpr (ELTS_PER_LDG >= 2) {
using VecType = AlignedArray<__half, ELTS_PER_LDG>;
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
const VecType* vec_thread_read_ptr =
reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
row_chunk_f2[base_idx_f2 + jj] = __half22float2(
*reinterpret_cast<const __half2*>(vec.data + jj * 2));
}
}
} else { // ELTS_PER_LDG == 1
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
const __half* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
row_chunk[ii] = __half2float(*scalar_ptr);
}
}
}
constexpr float threshold = 20.0f;
constexpr float beta = 1.0f;
// Hash MoE path: indices are predetermined from lookup table
if constexpr (USE_HASH) {
const IndType token_id = input_ids[thread_row];
const IndType* expert_indices_for_token = tid2eid + token_id * k;
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
float val = row_chunk[ii];
float val_b = val * beta;
val = (val_b > threshold) ? val : (__logf(1.0f + __expf(val_b))) / beta;
row_chunk[ii] = sqrtf(val);
}
float selected_sum = 0.f;
#pragma unroll
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int expert = expert_indices_for_token[k_idx];
const int idx = k * thread_row + k_idx;
for (int ii = 0; ii < VPT; ++ii) {
const int group_id = ii / ELTS_PER_LDG;
const int local_id = ii % ELTS_PER_LDG;
const int expert_idx = first_elt_read_by_thread +
group_id * THREADS_PER_ROW * ELTS_PER_LDG +
local_id;
if (expert == expert_idx) {
indices[idx] = expert;
selected_sum += row_chunk[ii];
break;
}
}
}
// Compute per-thread scale (using warp reduction when renormalizing).
if (renormalize) {
selected_sum = warpReduceSum(selected_sum);
}
float scale = static_cast<float>(routed_scaling_factor);
if (renormalize) {
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
scale /= denom;
}
#pragma unroll
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int expert = expert_indices_for_token[k_idx];
const int idx = k * thread_row + k_idx;
for (int ii = 0; ii < VPT; ++ii) {
const int group_id = ii / ELTS_PER_LDG;
const int local_id = ii % ELTS_PER_LDG;
const int expert_idx = first_elt_read_by_thread +
group_id * THREADS_PER_ROW * ELTS_PER_LDG +
local_id;
if (expert == expert_idx) {
output[idx] = row_chunk[ii] * scale;
break;
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
return;
}
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
float val = row_chunk[ii];
float val_b = val * beta;
// Compute softplus: log(1 + exp(val)) with numerical stability
// When val > threshold, softplus(x) ≈ x to avoid exp overflow
val = (val_b > threshold) ? val : (__logf(1.0f + __expf(val_b))) / beta;
val = sqrtf(val);
if (correction_bias) {
const int group_id = ii / ELTS_PER_LDG;
const int local_id = ii % ELTS_PER_LDG;
const int expert_idx = first_elt_read_by_thread +
group_id * THREADS_PER_ROW * ELTS_PER_LDG +
local_id;
val = val + correction_bias[expert_idx];
}
row_chunk[ii] = val;
}
// Original TopK path: find top-k experts by score
// Now, sigmoid_res contains the sigmoid of the row chunk. Now, I want to find
// the topk elements in each row, along with the max index.
int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
float selected_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
// First, each thread does the local argmax
float max_val = row_chunk[0];
int expert = start_col;
#pragma unroll
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD;
++ldg, col += COLS_PER_GROUP_LDG) {
#pragma unroll
for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
// No check on the experts here since columns with the smallest index
// are processed first and only updated if > (not >=)
if (val > max_val) {
max_val = val;
expert = col + ii;
}
}
}
// Now, we perform the argmax reduce. We use the butterfly pattern so threads
// reach consensus about the max. This will be useful for K > 1 so that the
// threads can agree on "who" had the max value. That thread can then blank out
// their max with -inf and the warp can run more iterations...
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
float other_max =
VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
int other_expert =
VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
// We want lower indices to "win" in every thread so we break ties this
// way
if (other_max > max_val ||
(other_max == max_val && other_expert < expert)) {
max_val = other_max;
expert = other_expert;
}
}
// Write the max for this k iteration to global memory.
if (thread_group_idx == 0) {
// Add a guard to ignore experts not included by this node
const bool node_uses_expert =
expert >= start_expert && expert < end_expert;
const bool should_process_row = row_is_active && node_uses_expert;
// The lead thread from each sub-group will write out the final results to
// global memory. (This will be a single) thread per row of the
// input/output matrices.
const int idx = k * thread_row + k_idx;
if (correction_bias != nullptr) {
max_val -= correction_bias[expert];
}
output[idx] = max_val;
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
if (renormalize) {
selected_sum += max_val;
}
}
// Finally, we clear the value in the thread with the current max if there
// is another iteration to run.
if (k_idx + 1 < k) {
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
const int thread_to_clear_in_group =
(expert / ELTS_PER_LDG) % THREADS_PER_ROW;
// Only the thread in the group which produced the max will reset the
// "winning" value to -inf.
if (thread_group_idx == thread_to_clear_in_group) {
const int offset_for_expert = expert % ELTS_PER_LDG;
// Safe to set to any negative value since row_chunk values must be
// between 0 and 1.
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] =
-10000.f;
}
}
}
// Apply renormalization and routed scaling factor to final weights.
if (thread_group_idx == 0) {
float scale = static_cast<float>(routed_scaling_factor);
if (renormalize) {
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
scale /= denom;
}
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int idx = k * thread_row + k_idx;
output[idx] = output[idx] * scale;
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
namespace detail {
// Constructs some constants needed to partition the work across threads at
// compile time.
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM,
typename InputType>
struct TopkConstants {
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 ||
EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0,
"");
static constexpr int VECs_PER_THREAD =
MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
static const int ROWS_PER_WARP = WARP_SIZE_PARAM / THREADS_PER_ROW;
};
} // namespace detail
#define DISPATCH_HASH(use_hash, USE_HASH, ...) \
if (use_hash) { \
const bool USE_HASH = true; \
static_assert(USE_HASH == true, "USE_HASH must be compile-time constant"); \
__VA_ARGS__ \
} else { \
const bool USE_HASH = false; \
static_assert(USE_HASH == false, \
"USE_HASH must be compile-time constant"); \
__VA_ARGS__ \
}
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM,
int MAX_BYTES_PER_LDG, typename IndType, typename InputType>
void topkGatingSoftplusSqrtLauncherHelper(
const InputType* input, const bool* finished, float* output,
IndType* indices, int* source_row, const int num_rows, const int k,
const int start_expert, const int end_expert, const bool renormalize,
double routed_scaling_factor, const float* correction_bias,
const bool use_hash, const IndType* input_ids, const IndType* tid2eid,
cudaStream_t stream) {
static constexpr int BYTES_PER_LDG =
MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS);
using Constants =
detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM, InputType>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
DISPATCH_HASH(use_hash, USE_HASH, {
auto* kernel =
&topkGatingSoftplusSqrt<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG,
WARP_SIZE_PARAM, USE_HASH, IndType, InputType>;
#ifndef USE_ROCM
cudaLaunchConfig_t config = {};
config.gridDim = num_blocks;
config.blockDim = block_dim;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel, input, finished, output, num_rows,
indices, source_row, k, start_expert, end_expert,
renormalize, routed_scaling_factor, correction_bias,
input_ids, tid2eid);
#else
kernel<<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, source_row, k, start_expert,
end_expert, renormalize, routed_scaling_factor, correction_bias,
input_ids, tid2eid);
#endif
})
}
#ifndef USE_ROCM
#define LAUNCH_SOFTPLUS_SQRT(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
static_assert(WARP_SIZE == 32, \
"Unsupported warp size. Only 32 is supported for CUDA"); \
topkGatingSoftplusSqrtLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, \
MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \
routed_scaling_factor, correction_bias, use_hash, input_ids, tid2eid, \
stream);
#else
#define LAUNCH_SOFTPLUS_SQRT(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
if (WARP_SIZE == 64) { \
topkGatingSoftplusSqrtLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, \
MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \
routed_scaling_factor, correction_bias, use_hash, input_ids, \
tid2eid, stream); \
} else if (WARP_SIZE == 32) { \
topkGatingSoftplusSqrtLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, \
MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \
routed_scaling_factor, correction_bias, use_hash, input_ids, \
tid2eid, stream); \
} else { \
assert(false && \
"Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
}
#endif
template <typename IndType, typename InputType>
void topkGatingSoftplusSqrtKernelLauncher(
const InputType* gating_output, float* topk_weights, IndType* topk_indices,
int* token_expert_indices, const int num_tokens, const int num_experts,
const int topk, const bool renormalize, double routed_scaling_factor,
const float* correction_bias, const bool use_hash, const IndType* input_ids,
const IndType* tid2eid, cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
#ifndef USE_ROCM
// for bfloat16 dtype, we need 4 bytes loading to make sure num_experts
// elements can be loaded by a warp
static constexpr int BYTES_PER_LDG_MULTIPLE_64 =
(std::is_same_v<InputType, __nv_bfloat16> ||
std::is_same_v<InputType, __half>)
? 4
: 8;
#endif
switch (num_experts) {
case 1:
LAUNCH_SOFTPLUS_SQRT(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 2:
LAUNCH_SOFTPLUS_SQRT(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 4:
LAUNCH_SOFTPLUS_SQRT(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 8:
LAUNCH_SOFTPLUS_SQRT(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 16:
LAUNCH_SOFTPLUS_SQRT(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 32:
LAUNCH_SOFTPLUS_SQRT(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 64:
LAUNCH_SOFTPLUS_SQRT(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 128:
LAUNCH_SOFTPLUS_SQRT(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 256:
LAUNCH_SOFTPLUS_SQRT(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 512:
LAUNCH_SOFTPLUS_SQRT(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
// (CUDA only) support multiples of 64 when num_experts is not power of 2.
// ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of
// num_experts, alternatively we can test 4 bytes loading and enable it in
// future.
#ifndef USE_ROCM
case 192:
LAUNCH_SOFTPLUS_SQRT(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
case 320:
LAUNCH_SOFTPLUS_SQRT(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
case 384:
LAUNCH_SOFTPLUS_SQRT(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
case 448:
LAUNCH_SOFTPLUS_SQRT(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
case 576:
LAUNCH_SOFTPLUS_SQRT(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
#endif
default: {
TORCH_CHECK(false, "Unsupported expert number: ", num_experts);
}
}
}
} // namespace moe
} // namespace vllm
template <typename ComputeType>
void dispatch_topk_softplus_sqrt_launch(
const ComputeType* gating_output, torch::Tensor& topk_weights,
torch::Tensor& topk_indices, torch::Tensor& token_expert_indices,
int num_tokens, int num_experts, int topk, bool renormalize,
double routed_scaling_factor,
const c10::optional<torch::Tensor>& correction_bias,
const c10::optional<torch::Tensor>& input_ids,
const c10::optional<torch::Tensor>& tid2eid, cudaStream_t stream) {
const float* bias_ptr = nullptr;
if (correction_bias.has_value()) {
bias_ptr = correction_bias.value().data_ptr<float>();
}
bool use_hash = false;
if (tid2eid.has_value()) {
TORCH_CHECK(input_ids.has_value(), "input_ids is required for hash MoE");
use_hash = true;
}
if (topk_indices.scalar_type() == at::ScalarType::Int) {
const int* input_ids_ptr = nullptr;
const int* tid2eid_ptr = nullptr;
if (tid2eid.has_value()) {
input_ids_ptr = input_ids.value().data_ptr<int>();
tid2eid_ptr = tid2eid.value().data_ptr<int>();
}
vllm::moe::topkGatingSoftplusSqrtKernelLauncher<int, ComputeType>(
gating_output, topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(), token_expert_indices.data_ptr<int>(),
num_tokens, num_experts, topk, renormalize, routed_scaling_factor,
bias_ptr, use_hash, input_ids_ptr, tid2eid_ptr, stream);
} else if (topk_indices.scalar_type() == at::ScalarType::UInt32) {
const uint32_t* input_ids_ptr = nullptr;
const uint32_t* tid2eid_ptr = nullptr;
if (tid2eid.has_value()) {
input_ids_ptr = input_ids.value().data_ptr<uint32_t>();
tid2eid_ptr = tid2eid.value().data_ptr<uint32_t>();
}
vllm::moe::topkGatingSoftplusSqrtKernelLauncher<uint32_t, ComputeType>(
gating_output, topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(), token_expert_indices.data_ptr<int>(),
num_tokens, num_experts, topk, renormalize, routed_scaling_factor,
bias_ptr, use_hash, input_ids_ptr, tid2eid_ptr, stream);
} else {
TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long);
const int64_t* input_ids_ptr = nullptr;
const int64_t* tid2eid_ptr = nullptr;
if (tid2eid.has_value()) {
input_ids_ptr = input_ids.value().data_ptr<int64_t>();
tid2eid_ptr = tid2eid.value().data_ptr<int64_t>();
}
vllm::moe::topkGatingSoftplusSqrtKernelLauncher<int64_t, ComputeType>(
gating_output, topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int64_t>(), token_expert_indices.data_ptr<int>(),
num_tokens, num_experts, topk, renormalize, routed_scaling_factor,
bias_ptr, use_hash, input_ids_ptr, tid2eid_ptr, stream);
}
}
void topk_softplus_sqrt(
torch::Tensor& topk_weights, // [num_tokens, topk]
torch::Tensor& topk_indices, // [num_tokens, topk]
torch::Tensor& token_expert_indices, // [num_tokens, topk]
torch::Tensor& gating_output, // [num_tokens, num_experts]
bool renormalize, double routed_scaling_factor,
const c10::optional<torch::Tensor>& correction_bias,
const c10::optional<torch::Tensor>& input_ids,
const c10::optional<torch::Tensor>& tid2eid) {
const int num_experts = gating_output.size(-1);
const auto num_tokens = gating_output.numel() / num_experts;
const int topk = topk_weights.size(-1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (gating_output.scalar_type() == at::ScalarType::Float) {
dispatch_topk_softplus_sqrt_launch<float>(
gating_output.data_ptr<float>(), topk_weights, topk_indices,
token_expert_indices, num_tokens, num_experts, topk, renormalize,
routed_scaling_factor, correction_bias, input_ids, tid2eid, stream);
} else if (gating_output.scalar_type() == at::ScalarType::Half) {
dispatch_topk_softplus_sqrt_launch<__half>(
reinterpret_cast<const __half*>(gating_output.data_ptr<at::Half>()),
topk_weights, topk_indices, token_expert_indices, num_tokens,
num_experts, topk, renormalize, routed_scaling_factor, correction_bias,
input_ids, tid2eid, stream);
} else if (gating_output.scalar_type() == at::ScalarType::BFloat16) {
dispatch_topk_softplus_sqrt_launch<__nv_bfloat16>(
reinterpret_cast<const __nv_bfloat16*>(
gating_output.data_ptr<at::BFloat16>()),
topk_weights, topk_indices, token_expert_indices, num_tokens,
num_experts, topk, renormalize, routed_scaling_factor, correction_bias,
input_ids, tid2eid, stream);
} else {
TORCH_CHECK(false, "Unsupported gating_output data type: ",
gating_output.scalar_type());
}
}
+8
View File
@@ -16,6 +16,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"bias) -> ()");
m.impl("topk_sigmoid", torch::kCUDA, &topk_sigmoid);
#ifndef USE_ROCM
m.def(
"topk_softplus_sqrt(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output, bool renormalize, float "
"routed_scaling_factor, Tensor? "
"bias, Tensor? input_ids, Tensor? tid2eid) -> ()");
m.impl("topk_softplus_sqrt", torch::kCUDA, &topk_softplus_sqrt);
#endif
// Calculate the result of moe by summing up the partial results
// from all selected experts.
m.def("moe_sum(Tensor input, Tensor! output) -> ()");
+7 -1
View File
@@ -100,6 +100,11 @@ void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q,
bool is_neox, torch::Tensor& position_ids,
int64_t forced_token_heads_per_warp);
void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::Tensor& q, torch::Tensor const& kv, torch::Tensor& k_cache,
torch::Tensor const& slot_mapping, torch::Tensor const& position_ids,
torch::Tensor const& cos_sin_cache, double eps, int64_t cache_block_size);
void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask,
const torch::Tensor& output_mask,
@@ -153,7 +158,8 @@ void silu_and_mul_per_block_quant(torch::Tensor& out,
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
torch::Tensor& cos_sin_cache, bool is_neox,
int64_t rope_dim_offset, bool inverse);
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
+17 -16
View File
@@ -18,7 +18,6 @@ namespace persistent {
// Constants
// ============================================================================
constexpr int TopK = 2048;
constexpr int kThreadsPerBlock = 1024;
constexpr int RADIX = 256;
@@ -128,11 +127,12 @@ struct RadixRowState {
struct PersistentTopKParams {
const float* __restrict__ input; // [num_rows, stride]
int32_t* __restrict__ output; // [num_rows, TopK]
int32_t* __restrict__ output; // [num_rows, top_k]
int32_t* __restrict__ lengths; // [num_rows]
RadixRowState* row_states; // large path: per-group state
uint32_t num_rows;
uint32_t stride;
uint32_t top_k; // actual k value for output stride
uint32_t chunk_size; // large path: elements per CTA
uint32_t ctas_per_group; // 1=medium, >1=large
uint32_t max_seq_len; // max seq_len across all rows (for early CTA exit)
@@ -154,6 +154,7 @@ __device__ __forceinline__ uint32_t decode_bin(float x) {
return key >> 5;
}
template <int TopK>
__device__ __noinline__ void histogram_2048_topk(
const float* __restrict__ logits, int32_t* __restrict__ output_indices,
int32_t seq_len) {
@@ -418,6 +419,7 @@ __device__ __noinline__ void histogram_2048_topk(
// by: DarkSharpness
// which at the same time is an optimized topk kernel copied from tilelang
// kernel
template <int TopK>
__device__ __noinline__ void histogram_256_topk(
const float* __restrict__ logits, int* __restrict__ output_indices,
int logits_offset, int seq_len) {
@@ -649,7 +651,7 @@ __device__ __forceinline__ void wait_ge(int* ptr, int target_val,
// Adapted from https://github.com/flashinfer-ai/flashinfer/pull/2215
// ============================================================================
template <uint32_t VEC_SIZE>
template <int TopK, uint32_t VEC_SIZE>
__device__ void radix_topk(const float* __restrict__ row_input,
int32_t* __restrict__ row_output, uint32_t seq_len,
uint32_t my_chunk_start, uint32_t chunk_size,
@@ -857,7 +859,7 @@ __device__ void radix_topk(const float* __restrict__ row_input,
// see filtered_topk.cuh)
// ============================================================================
template <uint32_t VEC_SIZE = 1>
template <int TopK = 2048, uint32_t VEC_SIZE = 1>
__global__ void __launch_bounds__(kThreadsPerBlock, 2)
persistent_topk_kernel(PersistentTopKParams params) {
const uint32_t tx = threadIdx.x;
@@ -915,7 +917,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2)
if (row_idx >= params.num_rows) break;
const uint32_t seq_len = params.lengths[row_idx];
int32_t* row_output = params.output + row_idx * TopK;
int32_t* row_output = params.output + row_idx * params.top_k;
const float* row_input = params.input + row_idx * params.stride;
if (seq_len <= RADIX_THRESHOLD) {
@@ -927,19 +929,19 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2)
row_output[i] = (i < seq_len) ? static_cast<int32_t>(i) : -1;
}
} else if (seq_len <= static_cast<uint32_t>(HIST2048_THRESHOLD)) {
histogram_2048_topk(row_input, row_output, seq_len);
histogram_2048_topk<TopK>(row_input, row_output, seq_len);
} else {
histogram_256_topk(row_input, row_output, 0, seq_len);
histogram_256_topk<TopK>(row_input, row_output, 0, seq_len);
}
}
continue;
}
const uint32_t my_chunk_start = cta_in_group * chunk_size;
radix_topk<VEC_SIZE>(row_input, row_output, seq_len, my_chunk_start,
chunk_size, local_histogram, suffix_sum,
shared_scalars, shared_ordered, state, cta_in_group,
ctas_per_group, barrier_phase, iter, tx);
radix_topk<TopK, VEC_SIZE>(
row_input, row_output, seq_len, my_chunk_start, chunk_size,
local_histogram, suffix_sum, shared_scalars, shared_ordered, state,
cta_in_group, ctas_per_group, barrier_phase, iter, tx);
}
}
@@ -1011,7 +1013,6 @@ struct FilteredTopKTraits<float> {
}
};
constexpr uint32_t FILTERED_TOPK_MAX_K = 2048;
constexpr uint32_t FILTERED_TOPK_BLOCK_THREADS = 1024;
constexpr uint32_t FILTERED_TOPK_SMEM_INPUT_SIZE =
16 * 1024; // 16K indices per buffer
@@ -1025,7 +1026,7 @@ constexpr size_t FILTERED_TOPK_SMEM_DYNAMIC =
* \tparam IdType Index type (int32_t)
* \tparam VEC_SIZE Vector size for input loads (1, 2, 4, or 8)
*/
template <typename DType, typename IdType, int VEC_SIZE>
template <typename DType, typename IdType, int VEC_SIZE, uint32_t MAX_K = 2048>
__global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
FilteredTopKUnifiedKernel(const DType* __restrict__ input,
IdType* __restrict__ output,
@@ -1059,7 +1060,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
alignas(128) __shared__ int s_counter;
alignas(128) __shared__ int s_threshold_bin_id;
alignas(128) __shared__ int s_num_input[2];
alignas(128) __shared__ int s_indices[FILTERED_TOPK_MAX_K];
alignas(128) __shared__ int s_indices[MAX_K];
auto& s_histogram = s_histogram_buf[0];
@@ -1280,7 +1281,7 @@ constexpr int ComputeFilteredTopKVecSize(uint32_t max_len) {
return static_cast<int>(g);
}
template <typename DType, typename IdType>
template <typename DType, typename IdType, uint32_t MAX_K = 2048>
cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices,
IdType* lengths, uint32_t num_rows,
uint32_t top_k_val, uint32_t max_len,
@@ -1297,7 +1298,7 @@ cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices,
#define DISPATCH_VEC_SIZE(VS) \
if (vec_size == VS) { \
auto kernel = FilteredTopKUnifiedKernel<DType, IdType, VS>; \
auto kernel = FilteredTopKUnifiedKernel<DType, IdType, VS, MAX_K>; \
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( \
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, \
+39 -33
View File
@@ -9,28 +9,29 @@ namespace vllm {
template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
scalar_t* __restrict__ arr, const float* __restrict__ cos_ptr,
const float* __restrict__ sin_ptr, int rot_offset, int embed_dim,
const bool inverse) {
int x_index, y_index;
scalar_t cos, sin;
float cos_f, sin_f;
if (IS_NEOX) {
// GPT-NeoX style rotary embedding.
x_index = rot_offset;
y_index = embed_dim + rot_offset;
cos = VLLM_LDG(cos_ptr + x_index);
sin = VLLM_LDG(sin_ptr + x_index);
cos_f = VLLM_LDG(cos_ptr + x_index);
sin_f = VLLM_LDG(sin_ptr + x_index);
} else {
// GPT-J style rotary embedding.
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos = VLLM_LDG(cos_ptr + x_index / 2);
sin = VLLM_LDG(sin_ptr + x_index / 2);
cos_f = VLLM_LDG(cos_ptr + x_index / 2);
sin_f = VLLM_LDG(sin_ptr + x_index / 2);
}
const scalar_t x = arr[x_index];
const scalar_t y = arr[y_index];
arr[x_index] = x * cos - y * sin;
arr[y_index] = y * cos + x * sin;
if (inverse) {
sin_f = -sin_f;
}
const float x_f = static_cast<float>(arr[x_index]);
const float y_f = static_cast<float>(arr[y_index]);
arr[x_index] = static_cast<scalar_t>(x_f * cos_f - y_f * sin_f);
arr[y_index] = static_cast<scalar_t>(y_f * cos_f + x_f * sin_f);
}
template <typename scalar_t, bool IS_NEOX>
@@ -42,22 +43,23 @@ inline __device__ void apply_rotary_embedding(
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* cache_ptr, const int head_size, const int num_heads,
const float* cache_ptr, const int head_size, const int num_heads,
const int num_kv_heads, const int rot_dim, const int token_idx,
const int64_t query_stride, const int64_t key_stride,
const int64_t head_stride) {
const int64_t head_stride, const int64_t rope_dim_offset,
const bool inverse) {
const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim;
const float* cos_ptr = cache_ptr;
const float* sin_ptr = cache_ptr + embed_dim;
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head =
token_idx * query_stride + head_idx * head_stride;
token_idx * query_stride + head_idx * head_stride + rope_dim_offset;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim, inverse);
}
if (key != nullptr) {
@@ -65,10 +67,10 @@ inline __device__ void apply_rotary_embedding(
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head =
token_idx * key_stride + head_idx * head_stride;
token_idx * key_stride + head_idx * head_stride + rope_dim_offset;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim, inverse);
}
}
}
@@ -84,19 +86,18 @@ __global__ void rotary_embedding_kernel(
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const float* __restrict__ cos_sin_cache, // [max_position, rot_dim] fp32
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int64_t head_stride, const int num_heads, const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int head_size, const int64_t rope_dim_offset, const bool inverse) {
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const float* cache_ptr = cos_sin_cache + pos * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride, head_stride);
token_idx, query_stride, key_stride, head_stride, rope_dim_offset,
inverse);
}
} // namespace vllm
@@ -115,7 +116,7 @@ void rotary_embedding(
// [num_tokens, num_heads, head_size]
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
bool is_neox, int64_t rope_dim_offset, bool inverse) {
// num_tokens = batch_size * seq_len
int64_t num_tokens = positions.numel();
int positions_ndim = positions.dim();
@@ -154,6 +155,8 @@ void rotary_embedding(
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
TORCH_CHECK((rot_dim + rope_dim_offset) <= head_size);
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
@@ -165,20 +168,23 @@ void rotary_embedding(
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto cache_f32 = cos_sin_cache.to(torch::kFloat32);
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
if (is_neox) {
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
head_stride, num_heads, num_kv_heads, head_size);
cache_f32.data_ptr<float>(), rot_dim, query_stride, key_stride,
head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset,
inverse);
} else {
vllm::rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
key_stride, head_stride, num_heads, num_kv_heads, head_size);
cache_f32.data_ptr<float>(), rot_dim, query_stride, key_stride,
head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset,
inverse);
}
});
}
+7 -1
View File
@@ -258,7 +258,13 @@ __device__ bool processHistogramStep(
auto processBins = [&](float logit, int idx) {
if (isPartialMatch<patternShift>(logit, logitPattern)) {
uint32_t binIdx = extractBinIdx<step>(logit);
if (binIdx < thresholdBinIdx) {
// Only write elements with binIdx < thresholdBinIdx when:
// 1. This is step 0 and the threshold bin is small enough (no step 1)
// 2. This is step >= 1 (where pattern matching filters correctly)
// This prevents duplicates when step 0 and step 1 both run.
bool shouldWriteDirectly =
(step == 0 && smemFinalBinSize[0] <= kNumFinalItems) || (step >= 1);
if (binIdx < thresholdBinIdx && shouldWriteDirectly) {
// The element is part of the top-k selection
int dstIdx = atomicAdd(&smemFoundTopKValues[0], 1);
+59 -35
View File
@@ -10,33 +10,17 @@
#include "persistent_topk.cuh"
#endif
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
int64_t max_seq_len) {
namespace {
#ifndef USE_ROCM
TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported");
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
"lengths must be 1D or 2D");
TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
TORCH_CHECK(output.dim() == 2, "output must be 2D");
template <int TopK>
void launch_persistent_topk(const torch::Tensor& logits,
const torch::Tensor& lengths, torch::Tensor& output,
torch::Tensor& workspace, int64_t max_seq_len) {
namespace P = vllm::persistent;
const int64_t num_rows = logits.size(0);
const int64_t stride = logits.size(1);
TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
"output size mismatch");
namespace P = vllm::persistent;
TORCH_CHECK(k == P::TopK, "k must be 2048");
TORCH_CHECK(k <= stride, "k out of range");
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
static int num_sms = 0;
@@ -50,18 +34,17 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
}
if (num_rows > 32 && max_smem_per_block >= 128 * 1024) {
cudaError_t status = vllm::FilteredTopKRaggedTransform<float, int32_t>(
logits.data_ptr<float>(), output.data_ptr<int32_t>(),
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
static_cast<uint32_t>(k), static_cast<uint32_t>(stride), stream);
cudaError_t status =
vllm::FilteredTopKRaggedTransform<float, int32_t, TopK>(
logits.data_ptr<float>(), output.data_ptr<int32_t>(),
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
static_cast<uint32_t>(TopK), static_cast<uint32_t>(stride), stream);
TORCH_CHECK(status == cudaSuccess,
"FilteredTopK failed: ", cudaGetErrorString(status));
} else {
TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor");
TORCH_CHECK(workspace.dtype() == torch::kUInt8, "workspace must be uint8");
// Smem cap: smaller smem → more CTAs/group → more per-row parallelism for
// large path. Empirically tuned.
int effective_max_smem;
if (num_rows <= 4) {
effective_max_smem =
@@ -101,7 +84,7 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
int occupancy = 1;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, P::persistent_topk_kernel<4>, P::kThreadsPerBlock,
&occupancy, P::persistent_topk_kernel<TopK, 4>, P::kThreadsPerBlock,
smem_size);
if (occupancy < 1) occupancy = 1;
@@ -121,15 +104,16 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
params.lengths = lengths.data_ptr<int32_t>();
params.num_rows = static_cast<uint32_t>(num_rows);
params.stride = static_cast<uint32_t>(stride);
params.top_k = static_cast<uint32_t>(TopK);
params.chunk_size = chunk_size;
params.row_states =
reinterpret_cast<P::RadixRowState*>(workspace.data_ptr<uint8_t>());
params.ctas_per_group = ctas_per_group;
params.max_seq_len = static_cast<uint32_t>(max_seq_len);
#define LAUNCH_PERSISTENT(VS) \
#define LAUNCH_PERSISTENT(TOPK_VAL, VS) \
do { \
auto kernel = &P::persistent_topk_kernel<VS>; \
auto kernel = &P::persistent_topk_kernel<TOPK_VAL, VS>; \
cudaError_t err = cudaFuncSetAttribute( \
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \
TORCH_CHECK(err == cudaSuccess, \
@@ -138,11 +122,11 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
} while (0)
if (vec_size == 4) {
LAUNCH_PERSISTENT(4);
LAUNCH_PERSISTENT(TopK, 4);
} else if (vec_size == 2) {
LAUNCH_PERSISTENT(2);
LAUNCH_PERSISTENT(TopK, 2);
} else {
LAUNCH_PERSISTENT(1);
LAUNCH_PERSISTENT(TopK, 1);
}
#undef LAUNCH_PERSISTENT
}
@@ -150,6 +134,46 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
cudaError_t err = cudaGetLastError();
TORCH_CHECK(err == cudaSuccess,
"persistent_topk failed: ", cudaGetErrorString(err));
}
#endif
} // anonymous namespace
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
int64_t max_seq_len) {
#ifndef USE_ROCM
TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported");
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
"lengths must be 1D or 2D");
TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
TORCH_CHECK(output.dim() == 2, "output must be 2D");
const int64_t num_rows = logits.size(0);
const int64_t stride = logits.size(1);
TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
"output size mismatch");
TORCH_CHECK(k == 512 || k == 1024 || k == 2048,
"persistent_topk supports k=512, k=1024, or k=2048, got k=", k);
if (k == 512) {
launch_persistent_topk<512>(logits, lengths, output, workspace,
max_seq_len);
} else if (k == 1024) {
launch_persistent_topk<1024>(logits, lengths, output, workspace,
max_seq_len);
} else {
launch_persistent_topk<2048>(logits, lengths, output, workspace,
max_seq_len);
}
#else
TORCH_CHECK(false, "persistent_topk is not supported on ROCm");
#endif
+15 -1
View File
@@ -177,6 +177,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int forced_token_heads_per_warp=-1) -> ()");
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
#ifndef USE_ROCM
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
// kernel launch.
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
"Tensor! q, Tensor kv, Tensor! k_cache, "
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
"float eps, int cache_block_size) -> ()");
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);
#endif
// Apply repetition penalties to logits in-place
ops.def(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
@@ -240,7 +253,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
" Tensor cos_sin_cache, bool is_neox, int "
"rope_dim_offset=0, bool inverse=False) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// Quantization ops
+1 -1
View File
@@ -213,7 +213,7 @@ configuration.
| `FLASHINFER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
| `FLASHMLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 512, 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
+4 -3
View File
@@ -384,6 +384,7 @@ th {
| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | ✅︎ | ✅︎ |
| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ |
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ |
| `DeepseekV4ForCausalLM` | DeepSeek-V4 | `deepseek-ai/DeepSeek-V4-Flash`, `deepseek-ai/DeepSeek-V4-Pro`, etc. | | |
| `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ |
| `DotsOCRForCausalLM` | dots_ocr | `rednote-hilab/dots.ocr` | ✅︎ | ✅︎ |
| `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ |
@@ -643,10 +644,10 @@ Some models are supported only via the [Transformers modeling backend](#transfor
!!! note
`Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its
MobileNet-v5 vision backbone.
Performance is not yet fully optimized mainly due to:
- Both audio and vision MM encoders use `transformers.AutoModel` implementation.
- Both audio and vision MM encoders use `transformers.AutoModel` implementation.
- There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups.
!!! note
+2
View File
@@ -11,6 +11,8 @@ torchvision==0.26.0 # Required for phi3v processor. See https://github.com/pytor
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.6.8.post1
flashinfer-cubin==0.6.8.post1
apache-tvm-ffi==0.1.9
tilelang==0.1.9
# Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to
# breaking changes in 1.19.0
nvidia-cudnn-frontend>=1.13.0,<1.19.0
+5
View File
@@ -116,6 +116,11 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
model_kwargs["attention_config"] = {"backend": attn_backend.backend.name}
model_kwargs["tensor_parallel_size"] = tp_size
# Cap warmup memory: tests use small max_model_len (1024) but the
# engine default max_num_batched_tokens is 16384. Warming up large
# models (e.g. Llama-4-Scout-FP8) at 16384 tokens may trigger OOM.
model_kwargs.setdefault("max_num_batched_tokens", 8192)
# Sparse MLA models (DSv3.2) hit an over-strict inductor assertion in
# decompose_auto_functionalized when +rotary_embedding is forced into
# the compile graph. Disable qk_norm+rope fusion (which auto-enables
@@ -9,8 +9,8 @@ from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
_ceil_to_ue8m0,
calc_diff,
fp8_mqa_logits,
fp8_paged_mqa_logits,
fp8_fp4_mqa_logits,
fp8_fp4_paged_mqa_logits,
get_num_sms,
get_paged_mqa_logits_metadata,
)
@@ -127,8 +127,8 @@ def test_deepgemm_fp8_mqa_logits(clean_logits: bool):
q_fp8 = q.to(torch.float8_e4m3fn)
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False)
logits = fp8_mqa_logits(
q_fp8, kv_fp8, weights, ks, ke, clean_logits=clean_logits
logits = fp8_fp4_mqa_logits(
(q_fp8, None), kv_fp8, weights, ks, ke, clean_logits=clean_logits
)
ref_logits = _ref_fp8_mqa_logits(
@@ -150,7 +150,7 @@ def test_deepgemm_fp8_mqa_logits(clean_logits: bool):
assert diff < 1e-3, f"{diff=}"
def _ref_fp8_paged_mqa_logits(
def _ref_fp8_fp4_paged_mqa_logits(
q: torch.Tensor,
kv_cache: torch.Tensor,
weights: torch.Tensor,
@@ -205,8 +205,10 @@ def _ref_fp8_paged_mqa_logits(
@pytest.mark.skipif(
not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
)
@pytest.mark.parametrize("clean_logits", [True, False])
def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool):
def test_deepgemm_fp8_fp4_paged_mqa_logits():
# NOTE: clean_logits=True is incompatible with the 2D context_lens
# required by csrc/apis/attention.hpp; only the False path is exercised.
clean_logits = False
torch.manual_seed(0)
random.seed(0)
@@ -258,21 +260,29 @@ def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool):
q_fp8 = q.to(torch.float8_e4m3fn)
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
# deep_gemm paged MQA logits requires 2D context_lens of
# shape (B, next_n) (csrc/apis/attention.hpp:332-335);
# see indexer.py:607-608. For each batch/next_n token, the
# effective context length is context_lens[b] - next_n + j + 1.
next_n_arange = torch.arange(next_n, device="cuda", dtype=torch.int32)
context_lens_2d = (
context_lens.unsqueeze(-1) - next_n + 1 + next_n_arange
).contiguous()
schedule_metadata = get_paged_mqa_logits_metadata(
context_lens, blocksize, get_num_sms()
context_lens_2d, blocksize, get_num_sms()
)
logits = fp8_paged_mqa_logits(
q_fp8,
logits = fp8_fp4_paged_mqa_logits(
(q_fp8, None),
kv_cache_fp8,
weights,
context_lens,
context_lens_2d,
block_tables,
schedule_metadata,
max_model_len,
clean_logits=clean_logits,
)
ref_logits = _ref_fp8_paged_mqa_logits(
ref_logits = _ref_fp8_fp4_paged_mqa_logits(
q,
kv_cache,
weights,
@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Correctness + large-token-count launch tests for fused_q_kv_rmsnorm.
Before the grid-dim fix the kernel used grid ``(2, num_tokens)``, which hit
CUDA's 65535 grid-y cap for ``num_tokens >= 65536`` and failed with
``Triton Error [CUDA]: invalid argument`` at every large chunked-prefill
profile run. These tests pin the new grid layout.
"""
from __future__ import annotations
import pytest
import torch
from vllm.platforms import current_platform
from vllm.v1.attention.ops.deepseek_v4_ops import fused_q_kv_rmsnorm
pytestmark = pytest.mark.skipif(
not current_platform.is_cuda_alike(),
reason="fused_q_kv_rmsnorm requires a CUDA/ROCm device",
)
def _ref_rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor:
x_f32 = x.to(torch.float32)
variance = x_f32.pow(2).mean(dim=-1, keepdim=True)
y = x_f32 * torch.rsqrt(variance + eps) * w.to(torch.float32)
return y.to(x.dtype)
@pytest.mark.parametrize("num_tokens", [1, 17, 1024, 8192])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_fused_q_kv_rmsnorm_correctness(num_tokens: int, dtype: torch.dtype):
torch.manual_seed(0)
device = "cuda"
q_size, kv_size = 192, 576
qr = torch.randn(num_tokens, q_size, dtype=dtype, device=device)
kv = torch.randn(num_tokens, kv_size, dtype=dtype, device=device)
qw = torch.randn(q_size, dtype=dtype, device=device)
kvw = torch.randn(kv_size, dtype=dtype, device=device)
eps = 1e-6
qr_out, kv_out = fused_q_kv_rmsnorm(qr, kv, qw, kvw, eps)
qr_ref = _ref_rmsnorm(qr, qw, eps)
kv_ref = _ref_rmsnorm(kv, kvw, eps)
tol = dict(rtol=1e-2, atol=1e-2)
torch.testing.assert_close(qr_out, qr_ref, **tol)
torch.testing.assert_close(kv_out, kv_ref, **tol)
@pytest.mark.parametrize("num_tokens", [65535, 65536, 131072])
def test_fused_q_kv_rmsnorm_launches_past_grid_y_cap(num_tokens: int):
"""Regression guard: grid used to be (2, num_tokens), hitting CUDA's
65535 grid-y cap at num_tokens >= 65536. The new grid (num_tokens, 2)
lifts that bound to 2**31-1."""
device = "cuda"
dtype = torch.bfloat16
q_size, kv_size = 192, 576
qr = torch.randn(num_tokens, q_size, dtype=dtype, device=device)
kv = torch.randn(num_tokens, kv_size, dtype=dtype, device=device)
qw = torch.randn(q_size, dtype=dtype, device=device)
kvw = torch.randn(kv_size, dtype=dtype, device=device)
qr_out, kv_out = fused_q_kv_rmsnorm(qr, kv, qw, kvw, 1e-6)
# spot-check a couple of rows against the torch reference
for row in (0, num_tokens // 2, num_tokens - 1):
torch.testing.assert_close(
qr_out[row],
_ref_rmsnorm(qr[row : row + 1], qw, 1e-6)[0],
rtol=1e-2,
atol=1e-2,
)
torch.testing.assert_close(
kv_out[row],
_ref_rmsnorm(kv[row : row + 1], kvw, 1e-6)[0],
rtol=1e-2,
atol=1e-2,
)
+195 -1
View File
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit-test DeepGEMM FP8 kernels (no DeepEP).
Unit-test DeepGEMM FP8 and FP4 kernels (no DeepEP).
Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts.
"""
@@ -21,6 +21,8 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
FusedMoEQuantDesc,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
@@ -204,3 +206,195 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_i
f"DeepGEMM path was not executed during the test. "
f"Call counter: {call_counter['cnt']}"
)
# ---------------------------------------------------------------------------
# FP4 weight tests (DeepGEMM m_grouped_fp8_fp4_gemm_nt_contiguous)
# ---------------------------------------------------------------------------
def make_mxfp4_weights(
e: int,
n: int,
k: int,
):
"""
Generate (w1, w2) expert weights in MXFP4 packed format with float32 scales,
plus BF16 reference weights for validation.
w1 shape: (E, 2N, K//2) uint8 — packed FP4
w2 shape: (E, K, N//2) uint8 — packed FP4
w1_s shape: (E, 2N, K//32) float32 — per-row block-32 scales
w2_s shape: (E, K, N//32) float32 — per-row block-32 scales
w1_bf16: (E, 2N, K) — original BF16 for reference
w2_bf16: (E, K, N) — original BF16 for reference
"""
from deep_gemm.utils.math import per_token_cast_to_fp4
dtype = torch.bfloat16
gran_k = 32 # MXFP4 block size
# bf16 reference weights — scale by 1/sqrt(dim) for numerical stability
w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) * (k**-0.5)
w2_bf16 = torch.randn(e, k, n, device="cuda", dtype=dtype) * (n**-0.5)
# Quantize per-expert to FP4
w1 = torch.empty(e, 2 * n, k // 2, device="cuda", dtype=torch.uint8)
w2 = torch.empty(e, k, n // 2, device="cuda", dtype=torch.uint8)
w1_s = torch.empty(
e, 2 * n, math.ceil(k / gran_k), device="cuda", dtype=torch.float32
)
w2_s = torch.empty(e, k, math.ceil(n / gran_k), device="cuda", dtype=torch.float32)
for i in range(e):
w1[i], w1_s[i] = per_token_cast_to_fp4(
w1_bf16[i].float(), use_ue8m0=True, gran_k=gran_k
)
w2[i], w2_s[i] = per_token_cast_to_fp4(
w2_bf16[i].float(), use_ue8m0=True, gran_k=gran_k
)
return w1, w2, w1_s, w2_s, w1_bf16, w2_bf16
def _bf16_moe_reference(x, w1, w2, topk_weights, topk_ids):
"""BF16 token-loop MoE reference for correctness testing."""
import torch.nn.functional as F
num_tokens, hidden_size = x.shape
intermediate = w1.shape[1] // 2
top_k = topk_ids.shape[1]
output = torch.zeros(num_tokens, hidden_size, dtype=torch.float32, device=x.device)
for t in range(num_tokens):
for kk in range(top_k):
e = topk_ids[t, kk].item()
w = topk_weights[t, kk].item()
fc1 = x[t : t + 1].float() @ w1[e].float().T
linear = fc1[:, :intermediate]
gate = fc1[:, intermediate:]
act = F.silu(gate) * linear
fc2 = act @ w2[e].float().T
output[t] += w * fc2[0]
return output.to(torch.bfloat16)
def run_single_fp4_case(m, n, k, topk, num_experts):
"""
Run one (M,N,K) configuration with FP4 weights on DeepGEMM and assert
DeepGEMM FP4 == BF16 reference within tolerance.
"""
tokens_bf16 = torch.randn(m, k, device="cuda", dtype=torch.bfloat16) * (k**-0.5)
# FP4 expert weight tensors + BF16 originals for reference
w1, w2, w1_s, w2_s, w1_bf16, w2_bf16 = make_mxfp4_weights(num_experts, n, k)
router_logits = torch.randn(m, num_experts, device="cuda", dtype=torch.float32)
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.platforms import current_platform
_fp8_dtype = current_platform.fp8_dtype()
_block_shape = GroupShape(128, 128)
quant_config = FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None),
_a2=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_s, None, None, None),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_s, None, None, None),
)
moe_config = make_dummy_moe_config()
from vllm.model_executor.layers.fused_moe.experts.deep_gemm_moe import (
DeepGemmFP4Experts,
)
deep_gemm_fp4_experts = mk.FusedMoEKernel(
prepare_finalize=maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
fused_experts=DeepGemmFP4Experts(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=False,
)
# DeepGEMM FP4 path
out_deepgemm_fp4 = deep_gemm_fp4_experts.apply(
hidden_states=tokens_bf16,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_num_experts=num_experts,
activation=MoEActivation.SILU,
apply_router_weight_on_input=False,
expert_map=None,
)
# BF16 reference using the same original weights
out_ref = _bf16_moe_reference(tokens_bf16, w1_bf16, w2_bf16, topk_weights, topk_ids)
# FP4 vs BF16 reference: quantization error from FP4 weights + FP8 activations
diff = calc_diff(out_deepgemm_fp4, out_ref)
assert diff < 0.05, f"FP4 diff exceeded 5%: {diff}"
# DeepSeek V4 dims: H=4096, I=2048, so N=2*I=4096, K=H=4096.
# FP4 quantization with block_k=32 needs large K for good accuracy.
FP4_MNKs = [
(128, 4096, 4096), # DeepSeek V4 shape
(256, 2048, 2048), # Half-size variant
]
FP4_TOPKS = [2]
FP4_NUM_EXPERTS = [8]
@pytest.mark.parametrize(("m", "n", "k"), FP4_MNKs)
@pytest.mark.parametrize("topk", FP4_TOPKS)
@pytest.mark.parametrize("num_experts", FP4_NUM_EXPERTS)
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
def test_deepgemm_fp4_vs_triton(
m, n, k, topk, num_experts, monkeypatch, workspace_init
):
pytest.importorskip("deep_gemm.utils.math")
with monkeypatch.context() as mp:
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
_DeepGemmFP4Experts = importlib.import_module(
"vllm.model_executor.layers.fused_moe.experts.deep_gemm_moe"
).DeepGemmFP4Experts
call_counter = {"cnt": 0}
orig_fn = _DeepGemmFP4Experts.apply
def _spy_apply(*args, **kwargs):
call_counter["cnt"] += 1
return orig_fn(*args, **kwargs)
monkeypatch.setattr(_DeepGemmFP4Experts, "apply", _spy_apply)
if topk > num_experts:
pytest.skip(f"topk={topk} > num_experts={num_experts}")
run_single_fp4_case(
m=m,
n=n,
k=k,
topk=topk,
num_experts=num_experts,
)
# ensure that the DeepGEMM FP4 path was indeed taken.
assert call_counter["cnt"] == 1, (
f"DeepGEMM FP4 path was not executed during the test. "
f"Call counter: {call_counter['cnt']}"
)
@@ -0,0 +1,186 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType,
get_routing_method_type,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
fused_topk_bias,
)
from vllm.platforms import current_platform
def _torch_topk_softplus_sqrt(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
routed_scaling_factor: float,
e_score_correction_bias: torch.Tensor | None = None,
input_ids: torch.Tensor | None = None,
hash_indices_table: torch.Tensor | None = None,
):
scores = F.softplus(gating_output.float()).sqrt()
original_scores = scores
if e_score_correction_bias is not None:
scores_for_choice = scores + e_score_correction_bias.unsqueeze(0)
else:
scores_for_choice = scores
if hash_indices_table is not None:
assert input_ids is not None
topk_ids = hash_indices_table[input_ids.long()]
else:
topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=True)[1]
topk_weights = original_scores.gather(1, topk_ids.long())
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if routed_scaling_factor != 1.0:
topk_weights = topk_weights * routed_scaling_factor
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def test_sqrtsoftplus_bias_uses_deepseek_v4_routing_method():
assert (
get_routing_method_type(
scoring_func="sqrtsoftplus",
top_k=8,
renormalize=True,
num_expert_group=None,
has_e_score_bias=True,
)
== RoutingMethodType.DeepseekV4
)
assert (
get_routing_method_type(
scoring_func="sqrtsoftplus",
top_k=8,
renormalize=False,
num_expert_group=None,
has_e_score_bias=True,
)
== RoutingMethodType.Unspecified
)
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("num_tokens", [1, 33, 128])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [128, 256, 384, 512])
@pytest.mark.parametrize("topk", [6, 8, 16])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 1.5])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_fused_topk_softplus_sqrt(
num_tokens: int,
hidden_size: int,
num_experts: int,
topk: int,
renormalize: bool,
routed_scaling_factor: float,
dtype: torch.dtype,
):
torch.manual_seed(0)
hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda")
gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")
e_score_correction_bias = torch.randn(
(num_experts,), dtype=torch.float32, device="cuda"
)
topk_weights_ref, topk_ids_ref = _torch_topk_softplus_sqrt(
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
)
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=gating_output,
scoring_func="sqrtsoftplus",
e_score_correction_bias=e_score_correction_bias,
topk=topk,
renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor,
)
# Different kernels may return the topk experts in different orders when
# scores tie; sort by expert id before comparing.
sorted_ref_ids, idx_ref = topk_ids_ref.sort(dim=-1)
sorted_ids, idx_ops = topk_ids.sort(dim=-1)
torch.testing.assert_close(sorted_ref_ids, sorted_ids, atol=0, rtol=0)
sorted_w_ref = topk_weights_ref.gather(1, idx_ref)
sorted_w = topk_weights.gather(1, idx_ops)
torch.testing.assert_close(sorted_w_ref, sorted_w, atol=2e-2, rtol=1e-2)
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("num_tokens", [1, 33, 128])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [256, 384, 512])
@pytest.mark.parametrize("topk", [6, 8, 16])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_fused_topk_softplus_sqrt_hash(
num_tokens: int,
hidden_size: int,
num_experts: int,
topk: int,
renormalize: bool,
routed_scaling_factor: float,
dtype: torch.dtype,
):
torch.manual_seed(0)
vocab_size = 1024
hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda")
gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")
# Per-token fixed expert selection: for each vocab id pick `topk` distinct
# experts.
hash_indices_table = torch.stack(
[torch.randperm(num_experts)[:topk] for _ in range(vocab_size)]
).to(device="cuda", dtype=torch.int32)
input_ids = torch.randint(
0, vocab_size, (num_tokens,), dtype=torch.int32, device="cuda"
)
topk_weights_ref, topk_ids_ref = _torch_topk_softplus_sqrt(
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor,
input_ids=input_ids,
hash_indices_table=hash_indices_table,
)
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=gating_output,
scoring_func="sqrtsoftplus",
e_score_correction_bias=None,
topk=topk,
renormalize=renormalize,
input_tokens=input_ids,
hash_indices_table=hash_indices_table,
routed_scaling_factor=routed_scaling_factor,
)
sorted_ref_ids, idx_ref = topk_ids_ref.sort(dim=-1)
sorted_ids, idx_ops = topk_ids.sort(dim=-1)
torch.testing.assert_close(sorted_ref_ids, sorted_ids, atol=0, rtol=0)
sorted_w_ref = topk_weights_ref.gather(1, idx_ref)
sorted_w = topk_weights.gather(1, idx_ops)
torch.testing.assert_close(sorted_w_ref, sorted_w, atol=2e-2, rtol=1e-2)
+311
View File
@@ -0,0 +1,311 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Round-trip tests for compressor → FP8 quant + KV cache insert → gather + dequant.
Two paths tested:
A) DeepseekV4 Attention: head_dim=512 (448 FP8 nope + 64 bf16 rope), quant_block=64
B) Indexer: head_dim=128 (all FP8), quant_block=128
These serve as golden references for validating the future fused
compressor+quant+cache kernel.
"""
import math
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.v1.attention.ops.deepseek_v4_ops import (
dequantize_and_gather_k_cache,
quantize_and_insert_k_cache,
)
def _ue8m0_reference(x: torch.Tensor, block_size: int, fp8_max: float):
"""PyTorch reference for UE8M0 FP8 quantization (per-block, power-of-2 scale).
Returns (x_fp8, scales) where x_fp8 is float8_e4m3fn and scales are float32.
"""
assert x.dim() == 1
n = x.numel()
n_blocks = math.ceil(n / block_size)
x_fp8 = torch.zeros(n, dtype=torch.float8_e4m3fn, device=x.device)
scales = torch.zeros(n_blocks, dtype=torch.float32, device=x.device)
for i in range(n_blocks):
start = i * block_size
end = min(start + block_size, n)
block = x[start:end].float()
amax = block.abs().max().clamp(min=1e-4)
raw_scale = amax / fp8_max
exponent = math.ceil(math.log2(raw_scale.item()))
scale = 2.0**exponent
scales[i] = scale
quantized = (block / scale).clamp(-fp8_max, fp8_max)
x_fp8[start:end] = quantized.to(torch.float8_e4m3fn)
return x_fp8, scales
# ── Test A: DeepseekV4 Attention path ──────────────────────────────────────────────
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 17])
@pytest.mark.parametrize("block_size", [16, 64])
def test_deepseek_v4_attention_quant_cache_roundtrip(num_tokens: int, block_size: int):
"""compressed_kv → quantize_and_insert_k_cache → dequantize_and_gather_k_cache
→ compare against original."""
HEAD_DIM = 512
NOPE_DIM = 448
HEAD_BYTES = 584 # 448 fp8 + 128 bf16 + 8 uint8 scale
FP8_MAX = 448.0
QUANT_BLOCK = 64
num_blocks = (num_tokens + block_size - 1) // block_size + 1
device = "cuda"
# Random compressed_kv (simulates compressor output)
compressed_kv = torch.randn(
num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device
)
# ── Quant + insert ──────────────────────────────────────────────────
k_cache = torch.zeros(
num_blocks, block_size, HEAD_BYTES, dtype=torch.uint8, device=device
)
k_cache_2d = k_cache.view(num_blocks, -1)
# Sequential slot mapping: token i → slot i
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
quantize_and_insert_k_cache(
compressed_kv, k_cache_2d, slot_mapping, block_size=block_size
)
# ── Gather + dequant ────────────────────────────────────────────────
num_reqs = 1
max_blocks_per_seq = num_blocks
out = torch.zeros(
num_reqs, num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device
)
seq_lens = torch.tensor([num_tokens], dtype=torch.int32, device=device)
# block_table: request 0 uses physical blocks 0, 1, ...
block_table = torch.arange(
max_blocks_per_seq, dtype=torch.int32, device=device
).unsqueeze(0)
dequantize_and_gather_k_cache(
out, k_cache, seq_lens, None, block_table, block_size, offset=0
)
recovered = out[0, :num_tokens]
# ── NoPE portion (first 448): FP8 quantized, expect UE8M0 error ──
nope_orig = compressed_kv[:, :NOPE_DIM].float()
nope_recv = recovered[:, :NOPE_DIM].float()
nope_diff = (nope_recv - nope_orig).abs()
# Per-token check: FP8 e4m3 (3-bit mantissa) worst-case error is
# half-ULP at the largest representable value. At y ≈ 448 (max),
# ULP = 2^(8-3) = 32, so error ≤ 16 * scale.
for t in range(num_tokens):
_, scales = _ue8m0_reference(
compressed_kv[t, :NOPE_DIM].float(), QUANT_BLOCK, FP8_MAX
)
max_allowed = 16.0 * scales.max().item()
token_diff = nope_diff[t].max().item()
assert token_diff <= max_allowed, (
f"Token {t} nope diff {token_diff} exceeds max_allowed "
f"{max_allowed} (scale={scales.max().item()})"
)
# ── RoPE portion (last 64): stored as bf16, should be exact ─────
rope_diff = (recovered[:, NOPE_DIM:] - compressed_kv[:, NOPE_DIM:]).abs()
assert rope_diff.max().item() == 0.0, (
f"RoPE portion should be exact but got max diff {rope_diff.max().item()}"
)
# ── Test B: Indexer path ────────────────────────────────────────────────────
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 17])
@pytest.mark.parametrize("block_size", [16, 64])
def test_indexer_quant_cache_roundtrip(num_tokens: int, block_size: int):
"""k → indexer_k_quant_and_cache → cp_gather_indexer_k_quant_cache
→ manual dequant → compare against original."""
HEAD_DIM = 128
QUANT_BLOCK_SIZE = 128
# cache_stride = head_dim + (head_dim * 4 / quant_block_size) = 128 + 4 = 132
CACHE_STRIDE = HEAD_DIM + HEAD_DIM * 4 // QUANT_BLOCK_SIZE
num_blocks = (num_tokens + block_size - 1) // block_size + 1
device = "cuda"
# Random K (simulates compressor output for indexer)
k = torch.randn(num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device)
# ── Quant + insert ──────────────────────────────────────────────────
kv_cache = torch.zeros(
num_blocks, block_size, CACHE_STRIDE, dtype=torch.uint8, device=device
)
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping, QUANT_BLOCK_SIZE, "ue8m0")
# ── Gather ──────────────────────────────────────────────────────────
max_blocks_per_seq = num_blocks
block_table = torch.arange(
max_blocks_per_seq, dtype=torch.int32, device=device
).unsqueeze(0)
cu_seq_lens = torch.tensor([0, num_tokens], dtype=torch.int32, device=device)
# dst_k: [total_seq_len, head_dim] as uint8 (raw FP8 bytes)
dst_k = torch.zeros(num_tokens, HEAD_DIM, dtype=torch.uint8, device=device)
# dst_scale: [total_seq_len, head_dim/quant_block*4] as uint8 (raw float32 bytes)
num_scale_bytes = HEAD_DIM * 4 // QUANT_BLOCK_SIZE # 4
dst_scale = torch.zeros(
num_tokens, num_scale_bytes, dtype=torch.uint8, device=device
)
ops.cp_gather_indexer_k_quant_cache(
kv_cache, dst_k, dst_scale, block_table, cu_seq_lens
)
# ── Manual dequant ──────────────────────────────────────────────────
k_fp8 = dst_k.view(torch.float8_e4m3fn).float() # [num_tokens, 128]
scale = dst_scale.view(torch.float32) # [num_tokens, 1]
k_recovered = k_fp8 * scale # [num_tokens, 128]
# ── Compare ─────────────────────────────────────────────────────────
diff = (k_recovered - k.float()).abs()
k_abs = k.float().abs()
for t in range(num_tokens):
amax = k_abs[t].max().clamp(min=1e-4).item()
# UE8M0: scale = 2^ceil(log2(amax / 448))
exponent = math.ceil(math.log2(amax / 448.0))
ue8m0_scale = 2.0**exponent
# FP8 e4m3 (3-bit mantissa): worst-case error = 16 * scale
max_allowed = 16.0 * ue8m0_scale
token_diff = diff[t].max().item()
assert token_diff <= max_allowed, (
f"Token {t} diff {token_diff} exceeds max_allowed "
f"{max_allowed} (scale={ue8m0_scale})"
)
def test_indexer_gather_accepts_upper_bound_output():
"""Gather only exact cu_seq_lens even when dst is over-allocated."""
head_dim = 128
quant_block_size = 128
cache_stride = head_dim + head_dim * 4 // quant_block_size
valid_tokens = 9
upper_bound_tokens = 13
block_size = 16
num_blocks = 2
sentinel = 123
device = "cuda"
k = torch.randn(valid_tokens, head_dim, dtype=torch.bfloat16, device=device)
kv_cache = torch.zeros(
num_blocks, block_size, cache_stride, dtype=torch.uint8, device=device
)
slot_mapping = torch.arange(valid_tokens, dtype=torch.int64, device=device)
ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping, quant_block_size, "ue8m0")
block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).unsqueeze(
0
)
cu_seq_lens = torch.tensor([0, valid_tokens], dtype=torch.int32, device=device)
dst_k = torch.full(
(upper_bound_tokens, head_dim), sentinel, dtype=torch.uint8, device=device
)
num_scale_bytes = head_dim * 4 // quant_block_size
dst_scale = torch.full(
(upper_bound_tokens, num_scale_bytes),
sentinel,
dtype=torch.uint8,
device=device,
)
ops.cp_gather_indexer_k_quant_cache(
kv_cache, dst_k, dst_scale, block_table, cu_seq_lens
)
torch.accelerator.synchronize()
k_recovered = dst_k[:valid_tokens].view(torch.float8_e4m3fn).float() * dst_scale[
:valid_tokens
].view(torch.float32)
diff = (k_recovered - k.float()).abs()
max_allowed = (16.0 * dst_scale[:valid_tokens].view(torch.float32).max()).item()
assert diff.max().item() <= max_allowed
assert torch.all(dst_k[valid_tokens:] == sentinel)
assert torch.all(dst_scale[valid_tokens:] == sentinel)
# ── Test C: DeepseekV4 attention with values at different magnitudes ───────────
def test_deepseek_v4_quant_magnitude_range():
"""Test that quantization handles a range of magnitudes correctly."""
HEAD_DIM = 512
NOPE_DIM = 448
HEAD_BYTES = 584
block_size = 16
num_tokens = 4
num_blocks = 2
device = "cuda"
# Create inputs with varying magnitudes: small, medium, large
compressed_kv = torch.zeros(
num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device
)
compressed_kv[0] = 0.001 # very small
compressed_kv[1] = 1.0 # unit scale
compressed_kv[2] = 100.0 # large
compressed_kv[3] = torch.randn(HEAD_DIM, dtype=torch.bfloat16, device=device)
k_cache = torch.zeros(
num_blocks, block_size, HEAD_BYTES, dtype=torch.uint8, device=device
)
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
quantize_and_insert_k_cache(
compressed_kv, k_cache.view(num_blocks, -1), slot_mapping, block_size
)
out = torch.zeros(1, num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device)
seq_lens = torch.tensor([num_tokens], dtype=torch.int32, device=device)
block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).unsqueeze(
0
)
dequantize_and_gather_k_cache(
out, k_cache, seq_lens, None, block_table, block_size, offset=0
)
recovered = out[0, :num_tokens]
# RoPE portion must be exact
rope_diff = (recovered[:, NOPE_DIM:] - compressed_kv[:, NOPE_DIM:]).abs().max()
assert rope_diff.item() == 0.0, f"RoPE diff {rope_diff.item()}"
# NoPE: relative error should be reasonable
for t in range(num_tokens):
orig = compressed_kv[t, :NOPE_DIM].float()
recv = recovered[t, :NOPE_DIM].float()
abs_diff = (recv - orig).abs().max().item()
magnitude = orig.abs().max().item()
if magnitude > 0.01:
rel_err = abs_diff / magnitude
assert rel_err < 0.15, (
f"Token {t}: rel_err={rel_err:.4f}, abs_diff={abs_diff:.6f}, "
f"magnitude={magnitude:.4f}"
)
@@ -0,0 +1,359 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Standalone unit test for the horizontally-fused DeepseekV4-MLA kernel:
fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert
- Q side: per-head RMSNorm (no weight) + GPT-J RoPE on last 64 dims
- KV side: GPT-J RoPE on last 64 + UE8M0 FP8 quant + paged cache insert
We compare against:
- PyTorch reference for RMSNorm + GPT-J RoPE on Q
- Existing Triton `quantize_and_insert_k_cache` + round-trip via
`dequantize_and_gather_k_cache` for KV
The kernel is imported via
`torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert`.
"""
import pytest
import torch
from vllm.v1.attention.ops.deepseek_v4_ops import (
dequantize_and_gather_k_cache,
quantize_and_insert_k_cache,
)
# ── Constants matching the kernel ────────────────────────────────────────────
HEAD_DIM = 512
ROPE_DIM = 64
NOPE_DIM = HEAD_DIM - ROPE_DIM # 448
QUANT_BLOCK = 64
FP8_MAX = 448.0
HEAD_BYTES = NOPE_DIM + ROPE_DIM * 2 + 8 # 448 + 128 + 8 = 584
# ── PyTorch reference implementations ────────────────────────────────────────
def make_cos_sin_cache(max_pos: int, rope_dim: int, dtype, device):
"""Build a cos||sin cache matching DeepseekV4ScalingRotaryEmbedding layout.
cos_sin_cache[pos, :rope_dim/2] = cos(theta), [rope_dim/2:] = sin(theta).
"""
base = 10000.0
inv_freq = 1.0 / (
base
** (torch.arange(0, rope_dim, 2, dtype=torch.float32, device=device) / rope_dim)
)
t = torch.arange(max_pos, dtype=torch.float32, device=device)
freqs = torch.einsum("i,j -> ij", t, inv_freq) # [max_pos, rope_dim/2]
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1) # [max_pos, rope_dim]
return cache.to(dtype)
def apply_rope_gptj_last_k(
x: torch.Tensor, positions: torch.Tensor, cos_sin_cache: torch.Tensor
) -> torch.Tensor:
"""GPT-J-style (interleaved-pair) RoPE on the LAST rope_dim elements.
x: [..., head_dim] float32
positions: [num_tokens] int64 (positions[i] corresponds to x[i, ...])
cos_sin_cache: [max_pos, rope_dim] float (cos|sin layout)
Returns rotated x (same shape/dtype).
"""
rope_dim = cos_sin_cache.shape[-1]
half = rope_dim // 2
head_dim = x.shape[-1]
nope_dim = head_dim - rope_dim
# Gather cos/sin for each token position: [num_tokens, rope_dim]
cs = cos_sin_cache[positions].to(torch.float32) # [N, rope_dim]
cos = cs[..., :half] # [N, half]
sin = cs[..., half:] # [N, half]
# Reshape leading dims so we can broadcast: x shape [..., head_dim].
# Bring token dim to front; assume x is [num_tokens, ..., head_dim].
# We rely on positions being per-token and all other dims sharing the same pos.
rope = x[..., nope_dim:].float() # [..., rope_dim]
# Make rope pairs: reshape last dim to [half, 2]
shape = rope.shape
rope = rope.reshape(*shape[:-1], half, 2)
even = rope[..., 0] # [..., half]
odd = rope[..., 1]
# Broadcast cos/sin over any heads dim in between. cos/sin are [N, half].
# Add singleton dims for intermediate axes.
for _ in range(rope.ndim - 3):
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
new_even = even * cos - odd * sin
new_odd = even * sin + odd * cos
rope_rotated = torch.stack((new_even, new_odd), dim=-1).reshape(shape)
out = x.clone().float()
out[..., nope_dim:] = rope_rotated
return out.to(x.dtype)
def rmsnorm_no_weight(x: torch.Tensor, eps: float) -> torch.Tensor:
"""RMSNorm with no learnable weight, matching
`RMSNorm(head_dim, has_weight=False)`."""
orig_dtype = x.dtype
xf = x.float()
variance = xf.pow(2).mean(dim=-1, keepdim=True)
return (xf * torch.rsqrt(variance + eps)).to(orig_dtype)
# ── Dispatch to the CUDA op (skip test cleanly if it isn't built in) ─────────
def _op_available() -> bool:
return hasattr(torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert")
pytestmark = pytest.mark.skipif(
not torch.cuda.is_available() or not _op_available(),
reason="CUDA not available or fused DeepseekV4 op not built in",
)
def _call_fused(q, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs):
torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
q, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs
)
# ── Test 1: Q path numerical parity ──────────────────────────────────────────
@pytest.mark.parametrize("num_tokens", [1, 4, 17, 64])
@pytest.mark.parametrize("n_heads", [8, 64])
def test_q_path_matches_reference(num_tokens: int, n_heads: int):
torch.manual_seed(0)
device = "cuda"
dtype = torch.bfloat16
eps = 1e-6
max_pos = 4096
q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device)
positions = torch.arange(num_tokens, dtype=torch.int64, device=device)
cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)
# Reference: RMSNorm (no weight) per head, then GPT-J RoPE on last 64.
q_ref = rmsnorm_no_weight(q, eps)
q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache)
# Fused call with dummy KV tensors (KV branch will write slot_mapping=-1 → noop).
num_blocks = 2
bs = 16
kv = torch.zeros(num_tokens, HEAD_DIM, dtype=dtype, device=device)
k_cache = torch.zeros(
num_blocks, bs, HEAD_BYTES, dtype=torch.uint8, device=device
).view(num_blocks, -1)
slot_mapping = torch.full((num_tokens,), -1, dtype=torch.int64, device=device)
q_fused = q.clone()
_call_fused(q_fused, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs)
torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2)
# ── Test 2: KV path round-trip byte/value parity ─────────────────────────────
def _ue8m0_per_block_scales(kv_roped_nope_f32: torch.Tensor, qblock: int):
"""Return per-token per-block max scale (used to bound FP8 error)."""
n_tok, nope = kv_roped_nope_f32.shape
n_blocks = nope // qblock
blocks = kv_roped_nope_f32.view(n_tok, n_blocks, qblock)
absmax = blocks.abs().amax(dim=-1).clamp(min=1e-4)
raw = absmax / FP8_MAX
exponent = torch.ceil(torch.log2(raw))
return torch.pow(2.0, exponent) # [n_tok, n_blocks]
@pytest.mark.parametrize("num_tokens", [1, 4, 17, 64])
@pytest.mark.parametrize("block_size", [16, 64])
def test_kv_path_matches_reference(num_tokens: int, block_size: int):
torch.manual_seed(1)
device = "cuda"
dtype = torch.bfloat16
eps = 1e-6
max_pos = 4096
kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device)
positions = torch.arange(num_tokens, dtype=torch.int64, device=device)
cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)
num_blocks = (num_tokens + block_size - 1) // block_size + 1
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
# ── Reference path: RoPE on kv, then existing Triton quant+insert ──────
kv_ref = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
k_cache_ref = torch.zeros(
num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device
)
quantize_and_insert_k_cache(
kv_ref, k_cache_ref, slot_mapping, block_size=block_size
)
# ── Fused path (dummy q, single head) ──────────────────────────────────
k_cache_fused = torch.zeros_like(k_cache_ref)
q_dummy = torch.zeros(num_tokens, 1, HEAD_DIM, dtype=dtype, device=device)
_call_fused(
q_dummy,
kv,
k_cache_fused,
slot_mapping,
positions,
cos_sin_cache,
eps,
block_size,
)
# ── Round-trip compare via dequant+gather ──────────────────────────────
def _dequant(k_cache_2d):
num_reqs = 1
max_blocks = num_blocks
out = torch.zeros(
num_reqs, num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device
)
seq_lens = torch.tensor([num_tokens], dtype=torch.int32, device=device)
block_table = torch.arange(
max_blocks, dtype=torch.int32, device=device
).unsqueeze(0)
# gather_lens arg is None (use seq_lens)
k_cache_3d = k_cache_2d.view(num_blocks, block_size, HEAD_BYTES)
dequantize_and_gather_k_cache(
out, k_cache_3d, seq_lens, None, block_table, block_size, offset=0
)
return out[0, :num_tokens]
recovered_ref = _dequant(k_cache_ref)
recovered_fused = _dequant(k_cache_fused)
# NoPE: per-block UE8M0 FP8 error bound (half-ULP at max = 16 * scale).
scales = _ue8m0_per_block_scales(kv_ref[:, :NOPE_DIM].float(), QUANT_BLOCK)
for t in range(num_tokens):
max_allowed = 16.0 * scales[t].max().item()
diff_ref = (
(recovered_ref[t, :NOPE_DIM] - kv_ref[t, :NOPE_DIM]).abs().max().item()
)
diff_fused = (
(recovered_fused[t, :NOPE_DIM] - kv_ref[t, :NOPE_DIM]).abs().max().item()
)
assert diff_ref <= max_allowed, (
f"ref NoPE token {t} diff {diff_ref} > {max_allowed}"
)
assert diff_fused <= max_allowed, (
f"fused NoPE token {t} diff {diff_fused} > {max_allowed}"
)
# RoPE region: bf16 stored exactly → zero diff.
rope_diff = (recovered_fused[:, NOPE_DIM:] - kv_ref[:, NOPE_DIM:]).abs().max()
assert rope_diff.item() == 0.0, f"RoPE portion not exact: {rope_diff.item()}"
# Exact byte equality of the two cache buffers — strong parity.
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
# ── Test 2b: DP padding (slot_mapping shorter than q/kv) ─────────────────────
@pytest.mark.parametrize("num_tokens", [4, 17])
@pytest.mark.parametrize("pad", [1, 5])
@pytest.mark.parametrize("block_size", [16, 64])
def test_kv_path_with_dp_padding(num_tokens: int, pad: int, block_size: int):
"""slot_mapping.size(0) < q.size(0): the kernel must skip padded
tokens in the KV branch while still running Q-norm+RoPE on all rows."""
torch.manual_seed(3)
device = "cuda"
dtype = torch.bfloat16
eps = 1e-6
max_pos = 4096
total = num_tokens + pad
kv = torch.randn(total, HEAD_DIM, dtype=dtype, device=device)
positions = torch.arange(total, dtype=torch.int64, device=device)
cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)
num_blocks = (num_tokens + block_size - 1) // block_size + 1
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
# Reference: only the first num_tokens kv rows get inserted.
kv_ref = apply_rope_gptj_last_k(
kv[:num_tokens], positions[:num_tokens], cos_sin_cache
)
k_cache_ref = torch.zeros(
num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device
)
quantize_and_insert_k_cache(
kv_ref, k_cache_ref, slot_mapping, block_size=block_size
)
# Fused: pass full-sized q/kv/positions, shorter slot_mapping.
q_dummy = torch.zeros(total, 1, HEAD_DIM, dtype=dtype, device=device)
k_cache_fused = torch.zeros_like(k_cache_ref)
_call_fused(
q_dummy,
kv,
k_cache_fused,
slot_mapping,
positions,
cos_sin_cache,
eps,
block_size,
)
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
# ── Test 3: combined single-call Q + KV parity ───────────────────────────────
@pytest.mark.parametrize("num_tokens", [1, 4, 17])
@pytest.mark.parametrize("n_heads", [8, 64])
@pytest.mark.parametrize("block_size", [16, 64])
def test_combined_q_and_kv(num_tokens: int, n_heads: int, block_size: int):
torch.manual_seed(2)
device = "cuda"
dtype = torch.bfloat16
eps = 1e-6
max_pos = 4096
q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device)
kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device)
positions = torch.arange(num_tokens, dtype=torch.int64, device=device)
cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)
num_blocks = (num_tokens + block_size - 1) // block_size + 1
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
# Reference.
q_ref = rmsnorm_no_weight(q, eps)
q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache)
kv_ref = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
k_cache_ref = torch.zeros(
num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device
)
quantize_and_insert_k_cache(
kv_ref, k_cache_ref, slot_mapping, block_size=block_size
)
# Fused single call.
q_fused = q.clone()
k_cache_fused = torch.zeros_like(k_cache_ref)
_call_fused(
q_fused,
kv,
k_cache_fused,
slot_mapping,
positions,
cos_sin_cache,
eps,
block_size,
)
torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit test for fused_indexer_q_rope_quant.
Compares the fused Triton kernel against the unfused reference flow used by
the DeepseekV4 indexer in model_tracking:
q_rot = ops.rotary_embedding(positions, q, None, head_dim, cos_sin_cache,
is_neox_style=False,
rope_dim_offset=head_dim - rope_dim)
q_fp8, q_scale = per_token_group_quant_fp8(q_rot, head_dim, use_ue8m0=True)
weights_out = weights * q_scale * softmax_scale * head_scale
Expects bit-exact equality on both q_fp8 and weights_out.
"""
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import (
fused_indexer_q_rope_quant,
)
HEAD_DIM = 128
ROPE_DIM = 64
N_HEAD = 64
MAX_POS = 4096
def _reference(
positions: torch.Tensor,
q: torch.Tensor,
cos_sin_cache: torch.Tensor,
weights: torch.Tensor,
softmax_scale: float,
head_scale: float,
) -> tuple[torch.Tensor, torch.Tensor]:
q_rot = q.clone()
ops.rotary_embedding(
positions,
q_rot,
None,
HEAD_DIM,
cos_sin_cache,
False, # is_neox_style=False → GPT-J interleaved
HEAD_DIM - ROPE_DIM, # rope_dim_offset → rotate the tail
False,
)
q_fp8, q_scale = per_token_group_quant_fp8(
q_rot.view(-1, HEAD_DIM).contiguous(),
HEAD_DIM,
use_ue8m0=True,
)
q_fp8 = q_fp8.view(-1, N_HEAD, HEAD_DIM)
q_scale = q_scale.view(-1, N_HEAD)
weights_out = weights.to(torch.float32) * q_scale * softmax_scale * head_scale
return q_fp8, weights_out
@pytest.mark.parametrize("num_tokens", [1, 7, 32, 257])
@pytest.mark.parametrize("cache_dtype", [torch.float32, torch.bfloat16])
@torch.inference_mode()
def test_fused_indexer_q_rope_quant_matches_unfused(num_tokens, cache_dtype):
device = "cuda"
torch.manual_seed(0)
q = torch.randn(num_tokens, N_HEAD, HEAD_DIM, dtype=torch.bfloat16, device=device)
positions = torch.randint(
0, MAX_POS, (num_tokens,), dtype=torch.int64, device=device
)
cos_sin_cache = torch.randn(MAX_POS, ROPE_DIM, dtype=cache_dtype, device=device)
weights = torch.randn(num_tokens, N_HEAD, dtype=torch.bfloat16, device=device)
softmax_scale = HEAD_DIM**-0.5
head_scale = N_HEAD**-0.5
q_fp8_ref, weights_ref = _reference(
positions, q, cos_sin_cache, weights, softmax_scale, head_scale
)
q_fp8_fused, weights_fused = fused_indexer_q_rope_quant(
positions, q.clone(), cos_sin_cache, weights, softmax_scale, head_scale
)
# fp8 tensors aren't directly comparable via torch.equal — reinterpret as int8.
ref_bits = q_fp8_ref.view(torch.int8)
fused_bits = q_fp8_fused.view(torch.int8)
assert torch.equal(ref_bits, fused_bits), (
f"q_fp8 mismatch: "
f"{(ref_bits != fused_bits).sum().item()} / {ref_bits.numel()} bytes differ"
)
assert torch.equal(weights_ref, weights_fused), (
f"weights mismatch: max abs diff "
f"{(weights_ref - weights_fused).abs().max().item()}"
)
@@ -0,0 +1,908 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for the fused inverse RoPE + block-scaled FP8 quantization kernel.
Tests compare the fused kernel against a reference implementation built from
the existing separate operations (inverse RoPE via rotate_neox + FP8 quant
via per_token_group_quant_fp8).
The reference faithfully reproduces the exact flow in deepseek_v4_attention.py:295-310:
1. Apply inverse RoPE (NeoX style, last rope_dim=64 dims of each head)
2. Reshape [T, H, head_dim] -> [T, G, D]
3. Transpose+flatten to [G*T, D], quantize, reshape back
4. Return o_fp8 and o_scale with strides (D, T*D, 1) and (S, T*S, 1)
(non-contiguous [T, G, ...] view backed by contiguous [G, T, ...] memory)
Usage:
pytest tests/kernels/test_fused_inv_rope_fp8_quant.py -v
"""
import pytest
import torch
from vllm.v1.attention.ops.deepseek_v4_ops import fused_inv_rope_fp8_quant
# -- Default dimensions matching DeepSeek V3/V4 --------------------------
HEAD_DIM = 512
NOPE_DIM = 448
ROPE_DIM = 64
QUANT_GROUP_SIZE = 128
FP8_MAX = 448.0 # torch.finfo(torch.float8_e4m3fn).max
FP8_DTYPE = torch.float8_e4m3fn
EPS = 1e-10
# =========================================================================
# Helpers
# =========================================================================
def assert_dequant_close(
fp8_a: torch.Tensor,
scale_a: torch.Tensor,
fp8_b: torch.Tensor,
scale_b: torch.Tensor,
msg: str = "",
):
"""Compare two FP8-quantized tensors via their dequantized values.
Uses cosine-similarity-based diff (same as deep_gemm calc_diff).
Both fused and reference paths rotate in fp32 using an fp32
cos_sin_cache, so differences are only fp32 ordering ULPs that can
occasionally shift FP8 values at quantization boundaries.
"""
S = scale_a.shape[-1]
shape = fp8_a.shape
dq_a = fp8_a.float() * scale_a.unsqueeze(-1).expand(
*shape[:-1], S, QUANT_GROUP_SIZE
).reshape(shape)
dq_b = fp8_b.float() * scale_b.unsqueeze(-1).expand(
*shape[:-1], S, QUANT_GROUP_SIZE
).reshape(shape)
# Cosine diff: 1 - cos_sim (0 = identical, higher = worse)
dq_a_flat = dq_a.flatten().float()
dq_b_flat = dq_b.flatten().float()
cos_sim = torch.nn.functional.cosine_similarity(
dq_a_flat.unsqueeze(0), dq_b_flat.unsqueeze(0)
).item()
diff = 1.0 - cos_sim
assert diff < 1e-4, f"Dequant diff too large: {diff:.8f} (expected < 1e-4). {msg}"
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
"""GPT-J style rotation: interleaved pairs, negate-swap.
Matches vllm/model_executor/layers/rotary_embedding/common.py:23-27.
DeepseekV4 uses is_neox_style=False, so this is the correct rotation.
"""
x1 = x[..., ::2]
x2 = x[..., 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)
def make_cos_sin_cache(
max_pos: int,
rope_dim: int = ROPE_DIM,
dtype: torch.dtype = torch.float32,
device: str = "cuda",
) -> torch.Tensor:
"""Create a synthetic cos_sin_cache matching the layout used by
DeepseekV4ScalingRotaryEmbedding._compute_cos_sin_cache.
Shape: [max_pos, rope_dim] where first half is cos, second half is sin.
The fused kernel requires fp32; callers can override dtype if passing
the cache into the bf16-only paths.
"""
half = rope_dim // 2
# Use random but bounded frequencies so cos/sin are well-behaved
inv_freq = 1.0 / (
10000.0 ** (torch.arange(0, half, device=device, dtype=torch.float32) / half)
)
t = torch.arange(max_pos, device=device, dtype=torch.float32)
freqs = torch.outer(t, inv_freq) # [max_pos, half]
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1) # [max_pos, rope_dim]
return cache.to(dtype)
def reference_inv_rope(
o: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
nope_dim: int = NOPE_DIM,
rope_dim: int = ROPE_DIM,
) -> torch.Tensor:
"""Apply inverse RoPE to the last rope_dim dimensions of each head.
Matches the GPT-J inverse rotation in pos_encoding_kernels.cu, which
promotes the cache to fp32 and performs the rotation in fp32. The
result is cast back to the input dtype.
Args:
o: [T, H, head_dim] bf16
positions: [T] int64
cos_sin_cache: [max_pos, rope_dim] fp32
Returns:
o with inverse RoPE applied on the rope portion (bf16).
"""
assert cos_sin_cache.dtype == torch.float32
cos_sin = cos_sin_cache[positions] # [T, rope_dim] fp32
half = rope_dim // 2
cos = cos_sin[:, :half]
sin = cos_sin[:, half:]
# GPT-J style: repeat_interleave (not repeat) to match interleaved pairs
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(1)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(1)
sin = -sin # inverse
o_pass = o[..., :nope_dim]
o_rot_f32 = o[..., nope_dim:].float()
o_rot_f32 = o_rot_f32 * cos + rotate_gptj(o_rot_f32) * sin
o_rot = o_rot_f32.to(o.dtype)
return torch.cat([o_pass, o_rot], dim=-1)
def _ref_ue8m0_quant_block(x_f32: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Per-block UE8M0 FP8 quantization in pure float32.
Matches the Triton kernel logic exactly:
absmax -> 2^ceil(log2(absmax / fp8_max)) -> clamp(x / scale) -> fp8
Args:
x_f32: [..., quant_group_size] float32 — one or more 128-element blocks.
Returns:
x_fp8: same shape, float8_e4m3fn
scales: [...] float32, one scale per block
"""
assert x_f32.shape[-1] == QUANT_GROUP_SIZE
absmax = x_f32.abs().amax(dim=-1, keepdim=True).clamp(min=EPS)
scale_raw = absmax * (1.0 / FP8_MAX)
scale = torch.exp2(torch.ceil(torch.log2(scale_raw)))
x_scaled = (x_f32 / scale).clamp(-FP8_MAX, FP8_MAX)
x_fp8 = x_scaled.to(FP8_DTYPE)
return x_fp8, scale.squeeze(-1)
def reference_inv_rope_fp8_quant(
o: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
n_groups: int,
heads_per_group: int,
nope_dim: int = NOPE_DIM,
rope_dim: int = ROPE_DIM,
quant_group_size: int = QUANT_GROUP_SIZE,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Full reference: inverse RoPE in fp32 + UE8M0 FP8 quant in fp32.
Mimics the Triton kernel's precision path exactly:
Load bf16 -> cast to fp32 -> apply inverse RoPE with fp32 cos/sin ->
UE8M0 quant in fp32 -> write fp8 + scale
Returns:
o_fp8: [T, G, D] FP8 with strides (D, T*D, 1)
o_scale: [T, G, S] FP32 with strides (S, T*S, 1)
"""
assert cos_sin_cache.dtype == torch.float32
T, _H, head_dim = o.shape
d = heads_per_group * head_dim
S = d // quant_group_size
half_rope = rope_dim // 2
chunks_per_head = head_dim // quant_group_size
# Reshape [T, H, head_dim] -> [T, G, heads_per_group, head_dim]
o_4d = o.view(T, n_groups, heads_per_group, head_dim)
# Lookup cos/sin directly in fp32
cos_sin = cos_sin_cache[positions] # [T, rope_dim] fp32
cos = cos_sin[:, :half_rope] # [T, half_rope] fp32
sin = cos_sin[:, half_rope:] # [T, half_rope] fp32
# Allocate outputs in [G, T, ...] contiguous layout
fp8_buf = torch.empty(n_groups, T, d, dtype=FP8_DTYPE, device=o.device)
scale_buf = torch.empty(n_groups, T, S, dtype=torch.float32, device=o.device)
# Process each quant block, matching the Triton kernel's per-program logic
for g in range(n_groups):
for qb in range(S):
head_in_group = qb // chunks_per_head
chunk_in_head = qb % chunks_per_head
offset = chunk_in_head * quant_group_size
# Load 128 bf16 elements and promote to fp32 for rotation+quant
block = o_4d[:, g, head_in_group, offset : offset + quant_group_size]
x = block.float()
# Apply inverse RoPE in fp32 if this is the last chunk
# GPT-J style: interleaved pairs (even=x, odd=y)
if chunk_in_head == chunks_per_head - 1:
rope_start = nope_dim % quant_group_size # 64
rope_region = x[:, rope_start:].clone()
x_vals = rope_region[:, ::2]
y_vals = rope_region[:, 1::2]
x_new = x_vals * cos + y_vals * sin
y_new = y_vals * cos - x_vals * sin
x = x.clone()
x[:, rope_start::2] = x_new
x[:, rope_start + 1 :: 2] = y_new
# UE8M0 quant in fp32
x_fp8, scale = _ref_ue8m0_quant_block(x)
# Write to [G, T, D] contiguous memory
fp8_buf[g, :, qb * quant_group_size : (qb + 1) * quant_group_size] = x_fp8
scale_buf[g, :, qb] = scale
# Return transposed views
return fp8_buf.transpose(0, 1), scale_buf.transpose(0, 1)
# =========================================================================
# Tests
# =========================================================================
@pytest.mark.parametrize("num_tokens", [1, 7, 32, 128])
@pytest.mark.parametrize(
"num_heads,n_groups",
[(64, 8), (32, 4), (128, 8)],
ids=["H64_G8", "H32_G4", "H128_G8"],
)
@pytest.mark.parametrize("seed", [0, 42])
@torch.inference_mode()
def test_correctness(num_tokens, num_heads, n_groups, seed):
"""Compare fused kernel against reference for FP8 values and scales."""
torch.manual_seed(seed)
heads_per_group = num_heads // n_groups
max_pos = 4096
device = "cuda"
# Create inputs
o = torch.randn(
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
)
positions = torch.randint(
0, max_pos, (num_tokens,), device=device, dtype=torch.long
)
cos_sin_cache = make_cos_sin_cache(
max_pos, ROPE_DIM, dtype=torch.float32, device=device
)
# Reference
ref_fp8, ref_scale = reference_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
# Fused kernel
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
# Check shapes
d = heads_per_group * HEAD_DIM
S = d // QUANT_GROUP_SIZE
assert ref_fp8.shape == (num_tokens, n_groups, d)
assert fused_fp8.shape == (num_tokens, n_groups, d)
assert ref_scale.shape == (num_tokens, n_groups, S)
assert fused_scale.shape == (num_tokens, n_groups, S)
# Scales: exact match (both use identical UE8M0 algorithm)
# Scales may differ by one UE8M0 step (factor of 2) if fp32 rotation
# ordering shifts absmax across a power-of-2 boundary. Check ratio is
# close to 1.
scale_ratio = fused_scale / ref_scale.clamp(min=1e-30)
assert scale_ratio.max() <= 2.0 and scale_ratio.min() >= 0.5, (
f"Scale ratio out of [0.5, 2]: min={scale_ratio.min():.4f} "
f"max={scale_ratio.max():.4f}"
)
# Compare via dequant (Triton vs PyTorch fp32 may differ by ULPs)
assert_dequant_close(ref_fp8, ref_scale, fused_fp8, fused_scale)
@pytest.mark.parametrize("num_tokens", [1, 7, 32, 128])
@pytest.mark.parametrize(
"num_heads,n_groups",
[(64, 8), (128, 8)],
ids=["H64_G8", "H128_G8"],
)
@torch.inference_mode()
def test_output_strides(num_tokens, num_heads, n_groups):
"""Verify fused output layout:
- FP8: logical [T, G, D] backed by contiguous [G, T, D].
- Scale: MN-major TMA-aligned (column-major: T-stride=1).
"""
heads_per_group = num_heads // n_groups
max_pos = 4096
device = "cuda"
o = torch.randn(
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
)
positions = torch.randint(
0, max_pos, (num_tokens,), device=device, dtype=torch.long
)
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
# FP8: logical [T, G, D] backed by [G, T, D] row-major
d = heads_per_group * HEAD_DIM
expected_fp8_stride = (d, num_tokens * d, 1)
assert fused_fp8.stride() == expected_fp8_stride, (
f"FP8 stride mismatch: got {fused_fp8.stride()}, expected {expected_fp8_stride}"
)
# Scale: MN-major TMA-aligned layout. After fp8_einsum permutes
# [T,G,S] -> [G,T,S], T-dim should have stride 1.
# Our output is [T,G,S] = transpose of [G,T,S].
# So fused_scale.permute(1,0,2) should have T-stride=1.
perm = fused_scale.permute(1, 0, 2) # [G, T, S]
assert perm.stride(1) == 1 or num_tokens == 1, (
f"Scale T-stride (after permute to [G,T,S]) should be 1, got {perm.stride(1)}"
)
@pytest.mark.parametrize("num_tokens", [1, 7, 32, 128])
@torch.inference_mode()
def test_per_group_contiguity(num_tokens):
"""FP8 per-group slices must be contiguous. Scale per-group slices
are column-major (T-stride=1) — not row-major contiguous, which is
correct for TMA loads."""
num_heads, n_groups = 64, 8
heads_per_group = num_heads // n_groups
max_pos = 4096
device = "cuda"
o = torch.randn(
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
)
positions = torch.randint(
0, max_pos, (num_tokens,), device=device, dtype=torch.long
)
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
for g in range(n_groups):
fp8_slice = fused_fp8[:, g, :]
assert fp8_slice.is_contiguous(), (
f"o_fp8[:, {g}, :] is not contiguous: "
f"shape={list(fp8_slice.shape)}, stride={list(fp8_slice.stride())}"
)
@torch.inference_mode()
def test_scales_are_power_of_two():
"""Verify all scales are exact powers of 2 (UE8M0 property)."""
num_tokens, num_heads, n_groups = 32, 64, 8
heads_per_group = num_heads // n_groups
max_pos = 4096
device = "cuda"
o = torch.randn(
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
)
positions = torch.randint(
0, max_pos, (num_tokens,), device=device, dtype=torch.long
)
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
_, fused_scale = fused_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
# log2 of a power-of-two is an exact integer
log2_scales = torch.log2(fused_scale)
residual = (log2_scales - log2_scales.round()).abs()
assert residual.max() < 1e-5, (
f"Not all scales are powers of 2: max log2 residual = {residual.max().item()}"
)
@torch.inference_mode()
def test_nope_dims_unchanged():
"""Nope dimensions (first 448 per head) should only be quantized,
not rotated. Verify by dequantizing and comparing against
quantize-only reference (no RoPE)."""
num_tokens, num_heads, n_groups = 16, 64, 8
heads_per_group = num_heads // n_groups
max_pos = 4096
device = "cuda"
torch.manual_seed(0)
o = torch.randn(
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
)
positions = torch.randint(
0, max_pos, (num_tokens,), device=device, dtype=torch.long
)
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
# Fused kernel result
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
# Reference: quantize without RoPE (identity rotation)
# Create a zero-sin cache so RoPE is identity
zero_cache = torch.zeros_like(cos_sin_cache)
half = ROPE_DIM // 2
zero_cache[:, :half] = 1.0 # cos = 1
# sin = 0 (already zero)
norope_fp8, norope_scale = fused_inv_rope_fp8_quant(
o.clone(),
positions,
zero_cache,
n_groups,
heads_per_group,
)
# Extract nope quant blocks only (first 3 of every 4 blocks per head)
chunks_per_head = HEAD_DIM // QUANT_GROUP_SIZE # 4
for h in range(heads_per_group):
for c in range(chunks_per_head - 1): # skip last chunk (has rope)
qb = h * chunks_per_head + c
start = qb * QUANT_GROUP_SIZE
end = start + QUANT_GROUP_SIZE
fused_nope = fused_fp8[:, :, start:end].view(torch.uint8)
norope_nope = norope_fp8[:, :, start:end].view(torch.uint8)
assert torch.equal(fused_nope, norope_nope), (
f"Nope block (head={h}, chunk={c}) differs between "
f"fused and no-rope reference"
)
fused_s = fused_scale[:, :, qb]
norope_s = norope_scale[:, :, qb]
assert torch.equal(fused_s, norope_s), (
f"Nope scale (head={h}, chunk={c}) differs"
)
@torch.inference_mode()
def test_single_token():
"""Edge case: single token."""
num_tokens, num_heads, n_groups = 1, 64, 8
heads_per_group = num_heads // n_groups
max_pos = 4096
device = "cuda"
o = torch.randn(
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
)
positions = torch.tensor([42], device=device, dtype=torch.long)
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
ref_fp8, ref_scale = reference_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
assert_dequant_close(ref_fp8, ref_scale, fused_fp8, fused_scale)
@torch.inference_mode()
def test_zero_positions():
"""Edge case: all positions are 0."""
num_tokens, num_heads, n_groups = 16, 64, 8
heads_per_group = num_heads // n_groups
max_pos = 4096
device = "cuda"
o = torch.randn(
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
)
positions = torch.zeros(num_tokens, device=device, dtype=torch.long)
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
ref_fp8, ref_scale = reference_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
assert_dequant_close(ref_fp8, ref_scale, fused_fp8, fused_scale)
@torch.inference_mode()
def test_large_values():
"""Edge case: values near FP8 saturation to test clamping."""
num_tokens, num_heads, n_groups = 8, 64, 8
heads_per_group = num_heads // n_groups
max_pos = 4096
device = "cuda"
# Create inputs with large values that will saturate FP8
o = torch.randn(
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
)
o = o * 1000.0 # scale up to force saturation
positions = torch.randint(
0, max_pos, (num_tokens,), device=device, dtype=torch.long
)
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
ref_fp8, ref_scale = reference_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
assert_dequant_close(ref_fp8, ref_scale, fused_fp8, fused_scale)
@torch.inference_mode()
def test_dequant_numerical_accuracy():
"""Verify dequantized values are close to the original (after inv RoPE)."""
num_tokens, num_heads, n_groups = 32, 64, 8
heads_per_group = num_heads // n_groups
max_pos = 4096
device = "cuda"
torch.manual_seed(0)
o = torch.randn(
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
)
positions = torch.randint(
0, max_pos, (num_tokens,), device=device, dtype=torch.long
)
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
# Get the post-inv-RoPE values (ground truth before quantization)
o_after_rope = reference_inv_rope(o.clone(), positions, cos_sin_cache)
d = heads_per_group * HEAD_DIM
o_after_rope = o_after_rope.view(num_tokens, n_groups, d)
# Get fused quantized output
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
# Dequantize: broadcast scale [T, G, S] to [T, G, D] via repeat
S = d // QUANT_GROUP_SIZE
scale_expanded = (
fused_scale.unsqueeze(-1)
.expand(num_tokens, n_groups, S, QUANT_GROUP_SIZE)
.reshape(num_tokens, n_groups, d)
)
dequant = fused_fp8.float() * scale_expanded
# Check relative error.
# FP8 e4m3 with UE8M0 (power-of-two scales that round UP) quantizes more
# coarsely than optimal scaling. Both paths rotate in fp32, so the bulk
# of the error comes from UE8M0 quantization itself (~10-12% typical).
o_gt = o_after_rope.transpose(0, 1).contiguous().transpose(0, 1)
dequant_contig = dequant.transpose(0, 1).contiguous().transpose(0, 1)
abs_err = (dequant_contig.float() - o_gt.float()).abs()
rel_err = abs_err / (o_gt.float().abs().clamp(min=1e-6))
mean_rel_err = rel_err.mean().item()
assert mean_rel_err < 0.15, (
f"Mean relative error too high: {mean_rel_err:.4f} (expected < 0.15)"
)
def _unfused_inv_rope_fp8_quant(
o: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
n_groups: int,
heads_per_group: int,
nope_dim: int = NOPE_DIM,
rope_dim: int = ROPE_DIM,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Unfused path matching deepseek_v4_attention.py:295-310.
Uses the production CUDA RoPE kernel + per_token_group_quant_fp8.
"""
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
head_dim = o.shape[-1]
rope_dim_offset = head_dim - rope_dim
# Step 1: In-place CUDA RoPE (same as production)
ops.rotary_embedding(
positions,
o,
None,
head_dim,
cos_sin_cache,
False, # is_neox=False for DeepseekV4 (GPT-J style)
rope_dim_offset=rope_dim_offset,
inverse=True,
)
# Step 2: Reshape + quant + reshape (same as production)
T = o.shape[0]
d = heads_per_group * head_dim
o = o.view(T, n_groups, -1)
o_flat = o.transpose(0, 1).contiguous().reshape(-1, d)
o_fp8, o_scale = per_token_group_quant_fp8(
o_flat,
group_size=QUANT_GROUP_SIZE,
use_ue8m0=True,
)
o_fp8 = o_fp8.view(n_groups, T, d).transpose(0, 1)
o_scale = o_scale.view(n_groups, T, -1).transpose(0, 1)
return o_fp8, o_scale
# =========================================================================
# End-to-end test including fp8_einsum
# =========================================================================
@pytest.mark.parametrize("num_tokens", [1, 7, 32, 128, 1024])
@pytest.mark.parametrize(
"num_heads,n_groups",
[(64, 8)],
ids=["H64_G8"],
)
@torch.inference_mode()
def test_einsum_end_to_end(num_tokens, num_heads, n_groups):
"""End-to-end: fused inv_rope+quant → fp8_einsum must match
unfused CUDA_rope+quant → fp8_einsum bitwise.
This catches stride/layout bugs that only manifest when the einsum
kernel actually consumes the quantized activations.
"""
from deep_gemm.utils.math import ceil_div
from vllm.utils.deep_gemm import (
fp8_einsum,
per_block_cast_to_fp8,
transform_sf_into_required_layout,
)
heads_per_group = num_heads // n_groups
d = heads_per_group * HEAD_DIM
o_lora_rank = 1024
max_pos = 4096
device = "cuda"
torch.manual_seed(0)
o = torch.randn(
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
)
positions = torch.randint(
0, max_pos, (num_tokens,), device=device, dtype=torch.long
)
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
# -- Weight quantization (shared between both paths) --
w = torch.randn(n_groups, o_lora_rank, d, device=device, dtype=torch.bfloat16)
w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fn)
w_scale = torch.empty(
n_groups,
ceil_div(o_lora_rank, 128),
ceil_div(d, 128),
device=device,
dtype=torch.float32,
)
for g in range(n_groups):
w_fp8[g], w_scale[g] = per_block_cast_to_fp8(w[g], use_ue8m0=True)
recipe = (1, 1, 128)
w_scale_t = transform_sf_into_required_layout(
sf=w_scale,
mn=o_lora_rank,
k=d,
recipe=(1, 128, 128),
num_groups=n_groups,
is_sfa=False,
)
# -- UNFUSED path --
ref_fp8, ref_scale = _unfused_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
z_ref = torch.empty(
num_tokens, n_groups, o_lora_rank, device=device, dtype=torch.bfloat16
)
fp8_einsum(
"bhr,hdr->bhd", (ref_fp8, ref_scale), (w_fp8, w_scale_t), z_ref, recipe=recipe
)
# -- FUSED path --
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
z_fused = torch.empty(
num_tokens, n_groups, o_lora_rank, device=device, dtype=torch.bfloat16
)
fp8_einsum(
"bhr,hdr->bhd",
(fused_fp8, fused_scale),
(w_fp8, w_scale_t),
z_fused,
recipe=recipe,
)
# -- Checks --
# Einsum output: Triton and CUDA both rotate in fp32 now, so diffs
# come from fp32 ordering and UE8M0 boundary shifts only.
# Use relative diff (same metric as test_fp8_einsum.py).
from deep_gemm.testing import calc_diff
z_diff = calc_diff(z_fused, z_ref)
assert z_diff < 0.01, (
f"Einsum output diff too large: {z_diff:.6f} (expected < 0.01)"
)
@pytest.mark.parametrize("num_tokens", [1, 32, 256])
@torch.inference_mode()
def test_with_real_deepseek_v4_rope(num_tokens, default_vllm_config):
"""Test with real DeepseekV4ScalingRotaryEmbedding (GPT-J style,
mscale=0, YaRN scaling) matching the production config."""
num_heads = 64
n_groups = 8
heads_per_group = num_heads // n_groups
device = "cuda"
torch.manual_seed(0)
# Build YaRN-scaled cos_sin_cache matching real DeepSeek V3/V4 config
# (mscale=0 → mscale=1.0, so no magnitude scaling)
from vllm.model_executor.layers.rotary_embedding.common import (
yarn_find_correction_range,
yarn_linear_ramp_mask,
)
scaling_factor = 16
base = 10000.0
max_pos = 65536
beta_fast, beta_slow = 32, 1
pos_freqs = base ** (
torch.arange(0, ROPE_DIM, 2, dtype=torch.float32, device=device) / ROPE_DIM
)
inv_freq_extra = 1.0 / pos_freqs
inv_freq_interp = 1.0 / (scaling_factor * pos_freqs)
low, high = yarn_find_correction_range(
beta_fast, beta_slow, ROPE_DIM, base, max_pos
)
mask = 1 - yarn_linear_ramp_mask(low, high, ROPE_DIM // 2, dtype=torch.float32).to(
device
)
inv_freq = inv_freq_interp * (1 - mask) + inv_freq_extra * mask
t = torch.arange(max_pos * scaling_factor, device=device, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
# mscale=0 → yarn_get_mscale returns 1.0
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) # fp32
o = torch.randn(
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
)
positions = torch.randint(0, 4096, (num_tokens,), device=device, dtype=torch.long)
# UNFUSED: CUDA RoPE with is_neox=False (GPT-J)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
o_unfused = o.clone()
ops.rotary_embedding(
positions,
o_unfused,
None,
HEAD_DIM,
cos_sin_cache,
False, # is_neox=False (GPT-J style)
rope_dim_offset=NOPE_DIM,
inverse=True,
)
d = heads_per_group * HEAD_DIM
T = num_tokens
o_unfused = o_unfused.view(T, n_groups, d)
o_flat = o_unfused.transpose(0, 1).contiguous().reshape(-1, d)
ref_fp8, ref_scale = per_token_group_quant_fp8(
o_flat,
group_size=QUANT_GROUP_SIZE,
use_ue8m0=True,
)
ref_fp8 = ref_fp8.view(n_groups, T, d).transpose(0, 1)
ref_scale = ref_scale.view(n_groups, T, -1).transpose(0, 1)
# FUSED: use the real YaRN-scaled cos_sin_cache
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
o.clone(),
positions,
cos_sin_cache,
n_groups,
heads_per_group,
)
# Scales must match exactly (same UE8M0 algorithm)
# Compare via dequant (Triton bf16 rotation may differ from CUDA by 1 ULP)
assert_dequant_close(
ref_fp8, ref_scale, fused_fp8, fused_scale, msg="Real DeepSeek V4 rope"
)
+5 -10
View File
@@ -718,7 +718,6 @@ def test_persistent_topk_stress() -> None:
pytest.param(
{
"seq_lens": [2000, 6000, 30000, 80000],
"top_k": 2048,
"data_type": "random",
},
id="mixed_all_paths",
@@ -727,7 +726,6 @@ def test_persistent_topk_stress() -> None:
pytest.param(
{
"seq_lens": [2048, 4096, 8192, 16000],
"top_k": 2048,
"data_type": "random",
},
id="all_decode_medium",
@@ -736,7 +734,6 @@ def test_persistent_topk_stress() -> None:
pytest.param(
{
"seq_lens": [70000, 100000, 163840],
"top_k": 2048,
"data_type": "random",
},
id="all_large",
@@ -745,7 +742,6 @@ def test_persistent_topk_stress() -> None:
pytest.param(
{
"seq_lens": [32767, 32768, 32769, 32772],
"top_k": 2048,
"data_type": "random",
},
id="large_threshold_boundary",
@@ -754,7 +750,6 @@ def test_persistent_topk_stress() -> None:
pytest.param(
{
"seq_lens": [5000],
"top_k": 2048,
"data_type": "random",
},
id="single_row_medium",
@@ -772,15 +767,15 @@ def test_persistent_topk_stress() -> None:
pytest.param(
{
"seq_lens": [100, 2048, 10000, 80000],
"top_k": 2048,
"data_type": "random",
},
id="trivial_medium_large_mix",
),
],
)
@pytest.mark.parametrize("top_k", [512, 2048])
@torch.inference_mode()
def test_persistent_topk(test_config: dict) -> None:
def test_persistent_topk(test_config: dict, top_k: int) -> None:
"""
Tests specific to the persistent_topk kernel:
- Mixed medium/large rows in the same batch (dynamic per-row dispatch)
@@ -790,14 +785,15 @@ def test_persistent_topk(test_config: dict) -> None:
run_large_context_topk_test(
batch_size=len(test_config["seq_lens"]),
seq_lens=test_config["seq_lens"],
top_k=test_config["top_k"],
top_k=top_k,
data_type=test_config.get("data_type", "random"),
)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@pytest.mark.parametrize("top_k", [512, 2048])
@torch.inference_mode()
def test_persistent_topk_padded_stride() -> None:
def test_persistent_topk_padded_stride(top_k: int) -> None:
"""
Test persistent_topk with padded logits (large stride, small seq_len)
to simulate the e2e CUDAGraph scenario where fp8_paged_mqa_logits
@@ -806,7 +802,6 @@ def test_persistent_topk_padded_stride() -> None:
set_random_seed(42)
torch.set_default_device("cuda:0")
top_k = 2048
batch_size = 4
padded_stride = 163840 # DeepSeek-V3.2 max_model_len
actual_seq_lens = [3000, 5000, 8000, 12000]
@@ -41,7 +41,9 @@ class DummyRouter(BaseRouter):
def routing_method_type(self) -> RoutingMethodType:
return RoutingMethodType.FUSED_TOPK
def _compute_routing(self, hidden_states, router_logits, indices_type):
def _compute_routing(
self, hidden_states, router_logits, indices_type, *, input_ids=None
):
topk_ids = torch.tensor([[1, 2], [3, 4]], dtype=torch.int64)
topk_weights = torch.ones_like(topk_ids, dtype=torch.float32)
return topk_weights, topk_ids
+9
View File
@@ -260,6 +260,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True,
),
"DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"),
"DeepseekV4ForCausalLM": _HfExamplesInfo(
"deepseek-ai/DeepSeek-V4-Flash", is_available_online=False
),
"Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT"),
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT"),
"ExaoneForCausalLM": _HfExamplesInfo(
@@ -1482,6 +1485,12 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
speculative_model="luccafong/deepseek_mtp_draft_random",
trust_remote_code=True,
),
"DeepSeekV4MTPModel": _HfExamplesInfo(
"deepseek-ai/DeepSeek-V4-Flash",
speculative_model="deepseek-ai/DeepSeek-V4-Flash",
trust_remote_code=True,
is_available_online=False,
),
"ErnieMTPModel": _HfExamplesInfo(
"baidu/ERNIE-4.5-21B-A3B-PT",
trust_remote_code=True,
+184
View File
@@ -0,0 +1,184 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import SimpleNamespace
import pytest
import torch
from vllm.model_executor.models.deepseek_v4 import (
DeepseekV4MegaMoEExperts,
_stage_deepseek_v4_mega_moe_inputs,
make_deepseek_v4_expert_params_mapping,
)
from vllm.platforms import current_platform
pytestmark = pytest.mark.skipif(
not current_platform.is_cuda(),
reason="DeepSeek V4 MegaMoE requires CUDA",
)
def test_deepseek_v4_mega_moe_expert_mapping():
mapping = make_deepseek_v4_expert_params_mapping(2)
assert mapping == [
("experts.w13_", "experts.0.w1.", 0, "w1"),
("experts.w2_", "experts.0.w2.", 0, "w2"),
("experts.w13_", "experts.0.w3.", 0, "w3"),
("experts.w13_", "experts.1.w1.", 1, "w1"),
("experts.w2_", "experts.1.w2.", 1, "w2"),
("experts.w13_", "experts.1.w3.", 1, "w3"),
]
def test_deepseek_v4_mega_moe_ue8m0_uint8_to_float():
raw = torch.tensor([0, 126, 127, 128], dtype=torch.uint8)
decoded = DeepseekV4MegaMoEExperts._ue8m0_uint8_to_float(raw)
assert torch.equal(decoded.view(torch.int32), raw.to(torch.int32) << 23)
assert decoded[0].item() == 0.0
assert decoded[1].item() == 0.5
assert decoded[2].item() == 1.0
assert decoded[3].item() == 2.0
def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership():
vllm_config = SimpleNamespace(
scheduler_config=SimpleNamespace(max_num_batched_tokens=4)
)
experts = DeepseekV4MegaMoEExperts(
vllm_config,
num_experts=4,
num_local_experts=2,
experts_start_idx=2,
top_k=2,
hidden_size=128,
intermediate_size=128,
)
nonlocal_weight = torch.ones(128, 64, dtype=torch.uint8)
assert (
experts.weight_loader(
experts.w13_weight,
nonlocal_weight,
"experts.w13_weight",
shard_id="w1",
expert_id=1,
return_success=True,
)
is False
)
w1 = torch.full((128, 64), 3, dtype=torch.uint8)
w3 = torch.full((128, 64), 7, dtype=torch.uint8)
w2 = torch.full((128, 64), 11, dtype=torch.uint8)
assert experts.weight_loader(
experts.w13_weight,
w1,
"experts.w13_weight",
shard_id="w1",
expert_id=2,
return_success=True,
)
assert experts.weight_loader(
experts.w13_weight,
w3,
"experts.w13_weight",
shard_id="w3",
expert_id=2,
return_success=True,
)
assert experts.weight_loader(
experts.w2_weight,
w2,
"experts.w2_weight",
shard_id="w2",
expert_id=2,
return_success=True,
)
assert torch.equal(experts.w13_weight[0, :128], w1)
assert torch.equal(experts.w13_weight[0, 128:], w3)
assert torch.equal(experts.w2_weight[0], w2)
assert torch.count_nonzero(experts.w13_weight[1]) == 0
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="DeepSeek V4 MegaMoE fused input staging requires CUDA.",
)
def test_deepseek_v4_mega_moe_fused_input_staging_is_bitwise_exact():
from vllm.third_party.deep_gemm.utils import per_token_cast_to_fp8
device = torch.device("cuda")
num_tokens = 7
hidden_size = 256
top_k = 8
generator = torch.Generator(device=device)
generator.manual_seed(0)
hidden_states = (
torch.randn(
num_tokens,
hidden_size,
device=device,
dtype=torch.float32,
generator=generator,
)
* 17.0
).to(torch.bfloat16)
hidden_states[0, :32] = 0
hidden_states[1, 32:64] = 1.0e-6
hidden_states[2, 64:96] = -1.0e-6
topk_ids = torch.randint(
0,
256,
(num_tokens, top_k),
device=device,
dtype=torch.int32,
generator=generator,
)
topk_weights = torch.randn(
num_tokens,
top_k,
device=device,
dtype=torch.float32,
generator=generator,
)
ref_x, ref_x_sf = per_token_cast_to_fp8(
hidden_states,
use_ue8m0=True,
gran_k=32,
use_packed_ue8m0=True,
)
ref_topk_idx = topk_ids.to(torch.int64)
ref_topk_weights = topk_weights.clone()
fused_x = torch.empty_like(ref_x)
fused_x_sf = torch.empty_like(ref_x_sf)
fused_topk_idx = torch.empty_like(ref_topk_idx)
fused_topk_weights = torch.empty_like(ref_topk_weights)
_stage_deepseek_v4_mega_moe_inputs(
hidden_states,
topk_weights,
topk_ids,
fused_x,
fused_x_sf,
fused_topk_idx,
fused_topk_weights,
)
torch.accelerator.synchronize()
assert torch.equal(fused_x.view(torch.uint8), ref_x.view(torch.uint8))
assert torch.equal(fused_x_sf, ref_x_sf)
assert torch.equal(fused_topk_idx, ref_topk_idx)
assert torch.equal(
fused_topk_weights.view(torch.uint8),
ref_topk_weights.view(torch.uint8),
)
@@ -6,6 +6,7 @@ from transformers import AutoTokenizer
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning import ReasoningParserManager
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from vllm.reasoning.deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
@@ -33,6 +34,12 @@ def test_parser_selection(tokenizer, thinking, expected_parser_type):
assert isinstance(parser._parser, expected_parser_type)
def test_deepseek_v4_reasoning_parser_alias():
parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_v4")
assert parser_cls is DeepSeekV3ReasoningParser
def test_identity_reasoning_parser_basic(tokenizer):
parser = IdentityReasoningParser(tokenizer)
@@ -0,0 +1,81 @@
{
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather for a specific location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "Temperature unit"
}
},
"required": ["location"]
}
}
},
{
"type": "function",
"function": {
"name": "search",
"description": "Search the web for information",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query"
},
"num_results": {
"type": "integer",
"description": "Number of results to return"
}
},
"required": ["query"]
}
}
}
],
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What's the weather in Beijing?"
},
{
"role": "assistant",
"reasoning": "The user wants to know the weather in Beijing. I should use the get_weather tool.",
"tool_calls": [
{
"id": "call_001",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"location\": \"Beijing\", \"unit\": \"celsius\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_001",
"content": "{\"temperature\": 22, \"condition\": \"sunny\", \"humidity\": 45}"
},
{
"role": "assistant",
"reasoning": "Got the weather data. Let me format a nice response.",
"content": "The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity."
}
]
}
@@ -0,0 +1,24 @@
[
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Hello"
},
{
"role": "assistant",
"reasoning": "The user said hello, I should greet back.",
"content": "Hi there! How can I help you?"
},
{
"role": "user",
"content": "What is the capital of France?"
},
{
"role": "assistant",
"reasoning": "The user asks about the capital of France. It is Paris.",
"content": "The capital of France is Paris."
}
]
@@ -0,0 +1,159 @@
[
{
"role": "system",
"content": "该助手为DeepSeek,由深度求索公司创造。"
},
{
"role": "latest_reminder",
"content": "2026-02-21,星期六,广州,App,中文"
},
{
"role": "developer",
"content": "小柴胡冲剂和布洛芬能一起吃吗?\n\nCITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】",
"tools": [
{
"type": "function",
"function": {
"name": "search",
"description": "Web search. Split multiple queries with '||'.",
"parameters": {
"type": "object",
"properties": {
"queries": {
"type": "string",
"description": "query1||query2"
}
},
"required": [
"queries"
],
"additionalProperties": false,
"$schema": "http://json-schema.org/draft-07/schema#"
}
}
},
{
"type": "function",
"function": {
"name": "open",
"description": "Batch open IDs (format 【{id}†...】) or URLs.",
"parameters": {
"type": "object",
"properties": {
"open_list": {
"type": "array",
"items": {
"type": "object",
"properties": {
"id": {
"description": "ID or URL",
"anyOf": [
{
"type": "integer"
},
{
"type": "string"
}
],
"default": -1
},
"cursor": {
"type": "integer",
"description": "",
"default": -1
},
"loc": {
"type": "integer",
"description": "Start line",
"default": -1
},
"num_lines": {
"type": "integer",
"description": "",
"default": -1
},
"view_source": {
"type": "boolean",
"description": "",
"default": false
}
},
"additionalProperties": false
},
"description": ""
}
},
"required": [
"open_list"
],
"additionalProperties": false,
"$schema": "http://json-schema.org/draft-07/schema#"
}
}
},
{
"type": "function",
"function": {
"name": "find",
"description": "Find exact text pattern in pages.",
"parameters": {
"type": "object",
"properties": {
"find_list": {
"type": "array",
"items": {
"type": "object",
"properties": {
"pattern": {
"type": "string",
"description": ""
},
"cursor": {
"type": "integer",
"description": "",
"default": -1
}
},
"required": [
"pattern"
],
"additionalProperties": false
},
"description": ""
}
},
"required": [
"find_list"
],
"additionalProperties": false,
"$schema": "http://json-schema.org/draft-07/schema#"
}
}
}
]
},
{
"role": "assistant",
"content": "",
"reasoning": "用户想知道小柴胡冲剂和布洛芬能否一起服用。",
"tool_calls": [
{
"type": "function",
"function": {
"name": "search",
"arguments": "{\"queries\": \"小柴胡冲剂 布洛芬 相互作用 一起吃\"}"
}
}
]
},
{
"role": "tool",
"content": "[0]"
},
{
"role": "assistant",
"content": "请及时就医。",
"reasoning": "现在开始组织回答。",
"tool_calls": []
}
]
@@ -0,0 +1,28 @@
[
{
"role": "system",
"content": "该助手为DeepSeek-V3,由深度求索公司创造。\n今天是2025年10月17日,星期五。"
},
{
"role": "latest_reminder",
"content": "2024-11-15,上海市,App,中文"
},
{
"role": "user",
"content": "热海大滚锅是世界著名温泉吗"
},
{
"role": "assistant",
"content": "关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。\n\n为了让你快速了解,我整理了一个简要的表格:\n\n| 维度 | 热海大滚锅的地位与特点 |\n| :--- | :--- |\n| **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |\n| **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |\n| **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |\n| **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |\n\n### 💡 游玩攻略与温馨提示\n\n如果你计划前往热海大滚锅,这里有一些实用信息供你参考:\n\n- **门票与开放时间**:\n - **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。\n - **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。\n\n- **特色体验**:\n - **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。\n - **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。\n\n- **注意事项**:\n - **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。\n - **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。\n\n希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。",
"mask": 1
},
{
"role": "user",
"content": "世界著名温泉有哪些",
"task": "action"
},
{
"role": "assistant",
"content": "Search"
}
]
@@ -0,0 +1,36 @@
<begin▁of▁sentence>
## Tools
You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<DSMLtool_calls>" block like the following:
<DSMLtool_calls>
<DSMLinvoke name="$TOOL_NAME">
<DSMLparameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</DSMLparameter>
...
</DSMLinvoke>
<DSMLinvoke name="$TOOL_NAME2">
...
</DSMLinvoke>
</DSMLtool_calls>
String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
Otherwise, output directly after </think> with tool calls or final response.
### Available Tool Schemas
{"name": "get_weather", "description": "Get the weather for a specific location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city name"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "Temperature unit"}}, "required": ["location"]}}
{"name": "search", "description": "Search the web for information", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}, "num_results": {"type": "integer", "description": "Number of results to return"}}, "required": ["query"]}}
You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
You are a helpful assistant.<User>What's the weather in Beijing?<Assistant><think>The user wants to know the weather in Beijing. I should use the get_weather tool.</think>
<DSMLtool_calls>
<DSMLinvoke name="get_weather">
<DSMLparameter name="location" string="true">Beijing</DSMLparameter>
<DSMLparameter name="unit" string="true">celsius</DSMLparameter>
</DSMLinvoke>
</DSMLtool_calls><end▁of▁sentence><User><tool_result>{"temperature": 22, "condition": "sunny", "humidity": 45}</tool_result><Assistant><think>Got the weather data. Let me format a nice response.</think>The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity.<end▁of▁sentence>
@@ -0,0 +1 @@
<begin▁of▁sentence>You are a helpful assistant.<User>Hello<Assistant></think>Hi there! How can I help you?<end▁of▁sentence><User>What is the capital of France?<Assistant><think>The user asks about the capital of France. It is Paris.</think>The capital of France is Paris.<end▁of▁sentence>
@@ -0,0 +1,38 @@
<begin▁of▁sentence>该助手为DeepSeek,由深度求索公司创造。<latest_reminder>2026-02-21,星期六,广州,App,中文<User|>小柴胡冲剂和布洛芬能一起吃吗?
CITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】
## Tools
You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<DSMLtool_calls>" block like the following:
<DSMLtool_calls>
<DSMLinvoke name="$TOOL_NAME">
<DSMLparameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</DSMLparameter>
...
</DSMLinvoke>
<DSMLinvoke name="$TOOL_NAME2">
...
</DSMLinvoke>
</DSMLtool_calls>
String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
Otherwise, output directly after </think> with tool calls or final response.
### Available Tool Schemas
{"name": "search", "description": "Web search. Split multiple queries with '||'.", "parameters": {"type": "object", "properties": {"queries": {"type": "string", "description": "query1||query2"}}, "required": ["queries"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
{"name": "open", "description": "Batch open IDs (format 【{id}†...】) or URLs.", "parameters": {"type": "object", "properties": {"open_list": {"type": "array", "items": {"type": "object", "properties": {"id": {"description": "ID or URL", "anyOf": [{"type": "integer"}, {"type": "string"}], "default": -1}, "cursor": {"type": "integer", "description": "", "default": -1}, "loc": {"type": "integer", "description": "Start line", "default": -1}, "num_lines": {"type": "integer", "description": "", "default": -1}, "view_source": {"type": "boolean", "description": "", "default": false}}, "additionalProperties": false}, "description": ""}}, "required": ["open_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
{"name": "find", "description": "Find exact text pattern in pages.", "parameters": {"type": "object", "properties": {"find_list": {"type": "array", "items": {"type": "object", "properties": {"pattern": {"type": "string", "description": ""}, "cursor": {"type": "integer", "description": "", "default": -1}}, "required": ["pattern"], "additionalProperties": false}, "description": ""}}, "required": ["find_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
<Assistant><think>用户想知道小柴胡冲剂和布洛芬能否一起服用。</think>
<DSMLtool_calls>
<DSMLinvoke name="search">
<DSMLparameter name="queries" string="true">小柴胡冲剂 布洛芬 相互作用 一起吃</DSMLparameter>
</DSMLinvoke>
</DSMLtool_calls><end▁of▁sentence><User><tool_result>[0]</tool_result><Assistant><think>现在开始组织回答。</think>请及时就医。<end▁of▁sentence>
@@ -0,0 +1,29 @@
<begin▁of▁sentence>该助手为DeepSeek-V3,由深度求索公司创造。
今天是2025年10月17日,星期五。<latest_reminder>2024-11-15,上海市,App,中文<User|>热海大滚锅是世界著名温泉吗<Assistant></think>关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。
为了让你快速了解,我整理了一个简要的表格:
| 维度 | 热海大滚锅的地位与特点 |
| :--- | :--- |
| **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |
| **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |
| **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |
| **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |
### 💡 游玩攻略与温馨提示
如果你计划前往热海大滚锅,这里有一些实用信息供你参考:
- **门票与开放时间**
- **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。
- **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。
- **特色体验**
- **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。
- **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。
- **注意事项**
- **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。
- **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。
希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。<|end▁of▁sentence><User>世界著名温泉有哪些<Assistant></think><action>Search<end▁of▁sentence>
+224
View File
@@ -0,0 +1,224 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from pathlib import Path
from types import SimpleNamespace
import pytest
from vllm.entrypoints.chat_utils import parse_chat_messages
from vllm.renderers.registry import RENDERER_REGISTRY
from vllm.tokenizers.deepseek_v4 import get_deepseek_v4_tokenizer
from vllm.tokenizers.registry import TokenizerRegistry
FIXTURES_DIR = Path(__file__).parent / "fixtures" / "deepseek_v4"
class FakeHfTokenizer:
vocab_size = 100
def get_added_vocab(self) -> dict[str, int]:
return {"</think>": 100}
def encode(
self,
text: str,
add_special_tokens: bool = False,
**kwargs,
) -> list[int]:
self.last_encode = (text, add_special_tokens, kwargs)
return [len(text)]
def _tokenizer():
return get_deepseek_v4_tokenizer(FakeHfTokenizer())
def _model_config():
return SimpleNamespace(
multimodal_config=None,
allowed_local_media_path="",
allowed_media_domains=None,
)
def _load_reference_case(case_id: int):
data = json.loads((FIXTURES_DIR / f"test_input_{case_id}.json").read_text())
if isinstance(data, dict):
return data["messages"], data.get("tools")
return data, None
def _render_reference_case(case_id: int, **kwargs):
messages, tools = _load_reference_case(case_id)
conversation, _, _ = parse_chat_messages(
messages,
_model_config(),
content_format="string",
)
return _tokenizer().apply_chat_template(
conversation=conversation,
messages=messages,
tools=tools,
tokenize=False,
**kwargs,
)
def test_deepseek_v4_tokenizer_registered():
assert TokenizerRegistry.load_tokenizer_cls("deepseek_v4").__name__ == (
"DeepseekV4Tokenizer"
)
assert RENDERER_REGISTRY.load_renderer_cls("deepseek_v4").__name__ == (
"DeepseekV4Renderer"
)
def test_deepseek_v4_defaults_to_chat_mode():
prompt = _tokenizer().apply_chat_template(
[{"role": "user", "content": "Hello"}],
tokenize=False,
)
assert prompt == ("<begin▁of▁sentence><User>Hello<Assistant></think>")
@pytest.mark.parametrize("kwargs", [{"thinking": True}, {"enable_thinking": True}])
def test_deepseek_v4_enables_thinking_with_compatible_kwargs(kwargs):
prompt = _tokenizer().apply_chat_template(
[{"role": "user", "content": "Hello"}],
tokenize=False,
**kwargs,
)
assert prompt == ("<begin▁of▁sentence><User>Hello<Assistant><think>")
def test_deepseek_v4_uses_v4_tool_prompt_from_request_tools():
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather for a city",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
},
}
]
prompt = _tokenizer().apply_chat_template(
[{"role": "user", "content": "Weather?"}],
tools=tools,
tokenize=False,
)
assert "## Tools" in prompt
assert "<DSMLtool_calls>" in prompt
assert "</DSMLtool_calls>" in prompt
assert "function_calls" not in prompt
assert '"name": "get_weather"' in prompt
assert prompt.endswith("<User>Weather?<Assistant></think>")
def test_deepseek_v4_renders_parsed_history_tool_arguments():
messages = [
{"role": "user", "content": "List the repo"},
{
"role": "assistant",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "str_replace_editor",
"arguments": '{"command": "view", "path": "/testbed"}',
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": "file list",
},
]
tools = [
{
"type": "function",
"function": {
"name": "str_replace_editor",
"description": "Edit files",
"parameters": {
"type": "object",
"properties": {
"command": {"type": "string"},
"path": {"type": "string"},
},
"required": ["command", "path"],
},
},
}
]
conversation, _, _ = parse_chat_messages(
messages,
_model_config(),
content_format="string",
)
prompt = _tokenizer().apply_chat_template(
conversation=conversation,
messages=messages,
tools=tools,
tokenize=False,
)
assert '<DSMLparameter name="command" string="true">view' in prompt
assert '<DSMLparameter name="path" string="true">/testbed' in prompt
assert 'parameter name="arguments"' not in prompt
@pytest.mark.parametrize("reasoning_effort", ["none", "low", "medium", "high"])
def test_deepseek_v4_accepts_openai_reasoning_effort_values(reasoning_effort):
prompt = _tokenizer().apply_chat_template(
[{"role": "user", "content": "Hello"}],
tokenize=False,
enable_thinking=True,
reasoning_effort=reasoning_effort,
)
assert prompt.endswith("<Assistant><think>")
assert "Reasoning Effort: Absolute maximum" not in prompt
def test_deepseek_v4_preserves_reference_max_reasoning_effort():
prompt = _tokenizer().apply_chat_template(
[{"role": "user", "content": "Hello"}],
tokenize=False,
enable_thinking=True,
reasoning_effort="max",
)
assert prompt.startswith(
"<begin▁of▁sentence>Reasoning Effort: Absolute maximum"
)
@pytest.mark.parametrize(
("case_id", "kwargs"),
[
(1, {"thinking": True}),
(2, {"thinking": True}),
(3, {"thinking": True}),
(4, {}),
],
)
def test_deepseek_v4_matches_reference_golden_fixtures(case_id, kwargs):
prompt = _render_reference_case(case_id, **kwargs)
expected = (FIXTURES_DIR / f"test_output_{case_id}.txt").read_text()
assert prompt == expected
@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for DeepSeekV4ToolParser."""
import json
from unittest.mock import MagicMock
from vllm.tool_parsers import ToolParserManager
from vllm.tool_parsers.deepseekv4_tool_parser import DeepSeekV4ToolParser
MOCK_TOKENIZER = MagicMock()
MOCK_TOKENIZER.get_vocab.return_value = {}
TC_START = "<DSMLtool_calls>"
TC_END = "</DSMLtool_calls>"
INV_START = '<DSMLinvoke name="'
INV_END = "</DSMLinvoke>"
PARAM_START = '<DSMLparameter name="'
PARAM_END = "</DSMLparameter>"
def make_parser(tools=None) -> DeepSeekV4ToolParser:
return DeepSeekV4ToolParser(MOCK_TOKENIZER, tools=tools)
def make_request(tools=None) -> MagicMock:
req = MagicMock()
req.tools = tools
return req
def build_tool_call(func_name: str, params: dict[str, str]) -> str:
param_strs = "".join(
f'{PARAM_START}{k}" string="true">{v}{PARAM_END}\n' for k, v in params.items()
)
return f'{TC_START}\n{INV_START}{func_name}">\n{param_strs}{INV_END}\n{TC_END}'
def stream(parser: DeepSeekV4ToolParser, full_text: str, chunk_size: int = 7):
deltas = []
previous_text = ""
for start in range(0, len(full_text), chunk_size):
delta_text = full_text[start : start + chunk_size]
current_text = previous_text + delta_text
delta = parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[1],
request=make_request(),
)
previous_text = current_text
if delta is not None:
deltas.append(delta)
return deltas
def reconstruct_args(deltas, tool_index: int = 0) -> str:
fragments = []
for delta in deltas:
if delta.tool_calls:
for tool_call in delta.tool_calls:
if (
tool_call.index == tool_index
and tool_call.function
and tool_call.function.arguments
):
fragments.append(tool_call.function.arguments)
return "".join(fragments)
def test_registered():
assert ToolParserManager.get_tool_parser("deepseek_v4") is DeepSeekV4ToolParser
def test_extract_tool_calls():
parser = make_parser()
model_output = "Let me check. " + build_tool_call(
"get_weather", {"location": "Beijing", "unit": "celsius"}
)
result = parser.extract_tool_calls(model_output, make_request())
assert result.tools_called
assert result.content == "Let me check. "
assert len(result.tool_calls) == 1
tool_call = result.tool_calls[0]
assert tool_call.function.name == "get_weather"
assert json.loads(tool_call.function.arguments) == {
"location": "Beijing",
"unit": "celsius",
}
def test_function_calls_block_is_not_accepted():
parser = make_parser()
model_output = build_tool_call("search", {"query": "vllm"}).replace(
"tool_calls", "function_calls"
)
result = parser.extract_tool_calls(model_output, make_request())
assert not result.tools_called
assert result.content == model_output
def test_streaming_extracts_complete_invokes():
parser = make_parser()
full_text = build_tool_call("search", {"query": "deepseek v4"})
deltas = stream(parser, full_text, chunk_size=5)
names = [
tool_call.function.name
for delta in deltas
if delta.tool_calls
for tool_call in delta.tool_calls
]
assert names == ["search"]
assert json.loads(reconstruct_args(deltas)) == {"query": "deepseek v4"}
@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.v1.attention.utils import create_vllm_config
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadataBuilder
from vllm.v1.kv_cache_interface import MLAAttentionSpec
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_indexer_builder_deepseek_v4_compressed_slot_mapping_uses_storage_block_size():
"""Regression test: DeepseekV4 compression path must compute slot_mapping from
compressed positions, not reuse the uncompressed common metadata mapping.
"""
device = torch.device("cuda")
# storage_block_size = block_size // compress_ratio = 256 // 4 = 64
kv_cache_spec = MLAAttentionSpec(
block_size=256,
num_kv_heads=1,
head_size=128,
dtype=torch.bfloat16,
compress_ratio=4,
)
vllm_config = create_vllm_config(max_model_len=1024)
builder = DeepseekV32IndexerMetadataBuilder(
kv_cache_spec=kv_cache_spec,
layer_names=["dummy"],
vllm_config=vllm_config,
device=device,
)
# Construct a single request where:
# - num_computed = 240 (=> compressed_pos_start = 60)
# - query_len = 40 (=> num_groups = 10)
# => compressed positions are 60..69 which cross the storage block boundary at 64.
query_start_loc = torch.tensor([0, 40], dtype=torch.int32, device=device)
query_start_loc_cpu = query_start_loc.cpu()
seq_lens = torch.tensor([280], dtype=torch.int32, device=device) # 240 + 40
# Two blocks: compressed positions 0..63 map to block 5, 64..127 map to block 7.
block_table_tensor = torch.tensor([[5, 7]], dtype=torch.int32, device=device)
# Dummy uncompressed slot mapping (length == uncompressed num_actual_tokens).
slot_mapping = torch.full((40,), -123, dtype=torch.int64, device=device)
common = CommonAttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu_upper_bound=seq_lens.cpu(),
num_reqs=1,
num_actual_tokens=40,
max_query_len=40,
max_seq_len=280,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
causal=True,
)
md = builder.build(common_prefix_len=0, common_attn_metadata=common)
# The compressed slot_mapping retains the original uncompressed size (40).
# Only every compress_ratio-th position gets a valid slot; the rest are -1.
assert md.slot_mapping.numel() == 40
valid_slots = md.slot_mapping[md.slot_mapping >= 0]
assert valid_slots.numel() == 10 # 40 tokens / compress_ratio 4
storage_bs = kv_cache_spec.storage_block_size # 64
# Compressed positions 60..63 land in block 5, positions 64..69 in block 7.
expected = torch.tensor(
[
5 * storage_bs + 60,
5 * storage_bs + 61,
5 * storage_bs + 62,
5 * storage_bs + 63,
]
+ [
7 * storage_bs + 0,
7 * storage_bs + 1,
7 * storage_bs + 2,
7 * storage_bs + 3,
7 * storage_bs + 4,
7 * storage_bs + 5,
],
dtype=torch.int64,
device=device,
)
torch.testing.assert_close(valid_slots, expected)
+3 -2
View File
@@ -1855,10 +1855,11 @@ def test_generate_scheduler_kv_cache_config():
def new_mla_spec(cache_dtype_str=None):
# head_size = kv_lora_rank(512) + qk_rope_head_dim(64) = 576
return MLAAttentionSpec(
block_size=16,
num_kv_heads=16,
head_size=64,
num_kv_heads=1,
head_size=576,
dtype=torch.float32,
cache_dtype_str=cache_dtype_str,
)
+19 -20
View File
@@ -557,19 +557,19 @@ def test_prefill_hybrid_model_eagle():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(req1.block_hashes) == num_full_blocks
assert computed_blocks.get_block_ids() == (
[1, 2, 3, 4],
[0, 9, 10, 11],
[0, 16, 17, 18],
[1, 2, 3, 4, 5],
[0, 0, 10, 11, 12],
[0, 0, 17, 18, 19],
)
assert num_computed_tokens == 4 * block_size
assert num_computed_tokens == 5 * block_size
num_new_tokens = len(all_token_ids) - num_computed_tokens
blocks = manager.allocate_slots(
req1, num_new_tokens, num_computed_tokens, computed_blocks
)
assert blocks is not None and blocks.get_block_ids() == (
[22, 23, 24],
[25, 26, 27],
[28, 29, 30],
[22, 23],
[24, 25],
[26, 27],
)
for block_per_group in computed_blocks.blocks:
for block in block_per_group:
@@ -591,7 +591,7 @@ def test_prefill_hybrid_model_eagle():
make_block_hash_with_group_id(block_hashes[0], 1),
make_block_hash_with_group_id(block_hashes[0], 2),
],
4,
5,
)
# Evict the first block of full attention, makes total cache miss.
@@ -605,7 +605,7 @@ def test_prefill_hybrid_model_eagle():
0,
)
# Evict the last block of all layers, reduces the hit length to 3.
# Evict the last block of all layers, reduces the hit length to 4.
_test_partial_request_hit(
manager,
block_size,
@@ -617,10 +617,10 @@ def test_prefill_hybrid_model_eagle():
make_block_hash_with_group_id(block_hashes[-1], 1),
make_block_hash_with_group_id(block_hashes[-1], 2),
],
3,
4,
)
# Evict the last block of full attention, reduces the hit length to 3.
# Evict the last block of full attention, reduces the hit length to 4.
_test_partial_request_hit(
manager,
block_size,
@@ -628,7 +628,7 @@ def test_prefill_hybrid_model_eagle():
"5",
all_token_ids,
[make_block_hash_with_group_id(block_hashes[-1], 0)],
3,
4,
)
# Since the last block of full attention is dropped for eagle, evict
@@ -655,12 +655,11 @@ def test_prefill_hybrid_model_eagle():
3,
)
# Evict different set of blocks for full attention and sliding window makes
# total cache miss.
# The cache hit length of full attention is 4 * block_size.
# The cache hit length of sliding window is 3 * block_size.
# Then it is cache miss as the two type of layers
# have different hit length.
# Evict different set of blocks for full attention and sliding window.
# Full loses its last block so it drops to 4 full blocks after the eagle
# pop; SWA lost block 0 (outside the sliding window of the final hit),
# which is not required for the K+1 anchor at position 4. Coordinated
# single-drop aligns both groups at hit=4.
_test_partial_request_hit(
manager,
block_size,
@@ -672,7 +671,7 @@ def test_prefill_hybrid_model_eagle():
make_block_hash_with_group_id(block_hashes[0], 1),
make_block_hash_with_group_id(block_hashes[0], 2),
],
0,
4,
)
@@ -893,7 +892,7 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]):
# - 2 groups: 1 full + 1 other
_EAGLE_HYBRID_MODEL_TEST_CASES = [
# 2 groups: 1 full + 1 other
pytest.param(["full", "sliding_window"], 2, id="2g-full+sw"),
pytest.param(["full", "sliding_window"], 3, id="2g-full+sw"),
]
+2
View File
@@ -1892,6 +1892,7 @@ def create_scheduler_with_priority(
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
block_size=block_size,
hash_block_size=block_size,
)
@@ -4008,6 +4009,7 @@ def _create_encoder_decoder_scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
block_size=block_size,
hash_block_size=block_size,
structured_output_manager=StructuredOutputManager(vllm_config),
)
@@ -91,8 +91,10 @@ def test_basic_interface():
assert request_id in kv_connector_metadata.reqs_to_recv["my-engine-id"]
req_meta = kv_connector_metadata.reqs_to_recv["my-engine-id"][request_id]
# local_block_ids is list[list[int]] (per-group); flatten for comparison.
all_block_ids = [bid for group in req_meta.local_block_ids for bid in group]
for block_id, block in zip(
req_meta.local_block_ids,
all_block_ids,
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
request_id
],
@@ -228,15 +230,15 @@ def test_scheduler_request_finished():
# Case: Capped length (Successful prefill, need to send to decoder)
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
delay_free, _ = scheduler_connector.request_finished(request, block_ids=[10, 11])
delay_free, _ = scheduler_connector.request_finished(request, block_ids=([10, 11],))
assert delay_free is True
assert "id-1" in scheduler_connector._reqs_need_send
assert scheduler_connector._reqs_need_send["id-1"][1] == [10, 11]
assert scheduler_connector._reqs_need_send["id-1"][1] == [[10, 11]]
# Case: Aborted (No need to transfer, free blocks immediately)
scheduler_connector._reqs_need_send.clear()
request.status = RequestStatus.FINISHED_ABORTED
delay_free, _ = scheduler_connector.request_finished(request, block_ids=[12])
delay_free, _ = scheduler_connector.request_finished(request, block_ids=([12],))
assert delay_free is False
assert len(scheduler_connector._reqs_need_send) == 0
assert "id-1" in scheduler_connector._reqs_not_processed
@@ -334,7 +336,7 @@ async def test_kv_producer(monkeypatch):
send_meta = SendBlockMeta(
p_req_id="p-req-1",
transfer_id=transfer_id,
local_block_ids=[10, 11],
local_block_ids=[[10, 11]],
ready=asyncio.Event(),
)
prefill_worker.reqs_need_send[transfer_id] = send_meta
@@ -346,7 +348,7 @@ async def test_kv_producer(monkeypatch):
remote_port=54321,
remote_tp_size=1,
remote_tp_rank=0,
req_blocks={"d-req-1": (transfer_id, [20, 21])},
req_blocks={"d-req-1": (transfer_id, [[20, 21]])},
kv_caches_base_addr=[0x2000],
block_lens=[block_len],
)
@@ -389,7 +391,7 @@ async def test_kv_producer(monkeypatch):
prefill_worker.reqs_need_send[transfer_id] = send_meta
send_meta.sent = 0
send_meta.ready.set()
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20])
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20]])
# Worker processes the consumer's request
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
# Verify transfer parameters are correct: 11 to 20
@@ -407,7 +409,7 @@ async def test_kv_producer(monkeypatch):
prefill_worker.reqs_need_send[transfer_id] = send_meta
send_meta.sent = 0
send_meta.ready.set()
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21, 22])
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20, 21, 22]])
# Worker processes the consumer's request
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
# This should not be called because error.
@@ -424,7 +426,7 @@ async def test_kv_producer(monkeypatch):
prefill_worker.reqs_need_send[transfer_id] = send_meta
send_meta.sent = 0
send_meta.ready.clear()
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21])
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20, 21]])
# Worker processes the consumer's request
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
# This should not be called because timeout.
@@ -443,7 +445,7 @@ async def test_kv_producer(monkeypatch):
prefill_worker.reqs_need_send[transfer_id] = send_meta
send_meta.sent = 0
send_meta.ready.set()
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21])
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20, 21]])
# Worker processes the consumer's request
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
mock_send_blocks.assert_called_once()
@@ -481,7 +483,7 @@ async def test_kv_consumuer(monkeypatch):
"d-req-1": PullReqMeta(
d_req_id="d-req-1",
transfer_id="xfer-req-1",
local_block_ids=[100, 101],
local_block_ids=[[100, 101]],
remote_engine_id="p-engine",
remote_bootstrap_addr="http://bootstrap:33333",
pull_tasks_count=1,
@@ -514,7 +516,7 @@ async def test_kv_consumuer(monkeypatch):
assert sent_meta.remote_hostname == "127.0.0.1"
assert sent_meta.remote_port == 54321
assert sent_meta.req_blocks["d-req-1"] == ("xfer-req-1", [100, 101])
assert sent_meta.req_blocks["d-req-1"] == ("xfer-req-1", [[100, 101]])
# Verify internal state is updated correctly.
assert "d-req-1" in decode_worker.finished_recving_reqs
@@ -538,7 +540,7 @@ async def test_worker_get_finished_timeout(monkeypatch):
prefill_worker.reqs_need_send["tx-expired"] = SendBlockMeta(
p_req_id="p-req-expired",
transfer_id="tx-expired",
local_block_ids=[1, 2],
local_block_ids=[[1, 2]],
ready=MagicMock(),
expire_time=time.perf_counter() - 100,
)
@@ -547,7 +549,7 @@ async def test_worker_get_finished_timeout(monkeypatch):
prefill_worker.reqs_need_send["tx-active"] = SendBlockMeta(
p_req_id="p-req-active",
transfer_id="tx-active",
local_block_ids=[3, 4],
local_block_ids=[[3, 4]],
ready=MagicMock(),
expire_time=time.perf_counter() + 100,
)
@@ -703,7 +705,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
prefill_worker.sender_loop = asyncio.get_event_loop()
transfer_id = "xfer-hetero-1"
local_block_ids = [10, 11]
local_block_ids = [[10, 11]]
send_meta = SendBlockMeta(
p_req_id="p-req-h1",
transfer_id=transfer_id,
@@ -720,9 +722,9 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
mock_socket.send_multipart = AsyncMock()
identity = b"consumer-hetero"
# Assign different remote block IDs per D rank
# Assign different remote block IDs per D rank (nested per-group)
d_rank_remote_blocks = {
rank: [20 + i * 10, 21 + i * 10] for i, rank in enumerate(target_d_ranks)
rank: [[20 + i * 10, 21 + i * 10]] for i, rank in enumerate(target_d_ranks)
}
with patch.object(
@@ -757,11 +759,15 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
dst_ptrs = call_args[2]
lengths = call_args[3]
# Flatten nested per-group block IDs for assertions
flat_local = [b for g in local_block_ids for b in g]
flat_remote = [b for g in remote_block_ids for b in g]
# Heterogeneous TP: blocks cannot be coalesced because
# local and remote block_lens differ
assert len(src_ptrs) == len(local_block_ids)
assert len(dst_ptrs) == len(local_block_ids)
assert len(lengths) == len(local_block_ids)
assert len(src_ptrs) == len(flat_local)
assert len(dst_ptrs) == len(flat_local)
assert len(lengths) == len(flat_local)
# Compute expected offsets based on TP ratio
if d_tp_size <= P_TP_SIZE:
@@ -775,9 +781,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
expected_dst_off = 0
expected_xfer_len = remote_block_len
for idx, (lblk, rblk) in enumerate(
zip(local_block_ids, remote_block_ids)
):
for idx, (lblk, rblk) in enumerate(zip(flat_local, flat_remote)):
assert src_ptrs[idx] == (
0x1000 + lblk * local_block_len + expected_src_off
)
@@ -0,0 +1,410 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for MooncakeConnector HMA (Hybrid Memory Architecture) support.
Covers sliding-window clipping, multi-group metadata shape, multi-group
send trimming, and group-count invariant checking in _build_transfer_params.
"""
import asyncio
from unittest.mock import patch
import pytest
from vllm.config import set_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector import (
KVConnectorRole,
MooncakeConnector,
MooncakeConnectorMetadata,
MooncakeConnectorScheduler,
MooncakeXferMetadata,
SendBlockMeta,
TransferRegion,
)
from .test_mooncake_connector import FakeMooncakeWrapper, patch_worker_dependencies
from .utils import create_request, create_vllm_config, make_kv_cache_config
# ---------------------------------------------------------------------------
# test_sw_sizes: blocks_per_sw computed from KVCacheConfig
# ---------------------------------------------------------------------------
@pytest.mark.cpu_test
@pytest.mark.parametrize(
"swa_enabled,expected_blocks_per_sw",
[
# SWA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128+1)
(True, [0, 128 + 1]),
# SWA disabled: only FullAttentionSpec (0)
(False, [0]),
],
)
def test_sw_sizes(swa_enabled, expected_blocks_per_sw):
"""blocks_per_sw is correctly computed based on SWA enabled/disabled."""
block_size = 16
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector",
kv_role="kv_both",
block_size=block_size,
)
# Override so HMA detection works
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
kv_cache_config = make_kv_cache_config(
block_size=block_size, swa_enabled=swa_enabled, sw_size=2048
)
scheduler = MooncakeConnectorScheduler(
vllm_config=vllm_config,
engine_id="test-engine",
kv_cache_config=kv_cache_config,
)
assert scheduler.blocks_per_sw == expected_blocks_per_sw
# ---------------------------------------------------------------------------
# test_is_hma_required: derived from kv_cache_config groups
# ---------------------------------------------------------------------------
@pytest.mark.cpu_test
@pytest.mark.parametrize(
"swa_enabled,disable_hma,expected_is_hma",
[
(True, False, True), # SWA group present, HMA enabled
(True, True, False), # SWA group present, but HMA disabled
(False, False, False), # FA only, HMA not needed
],
)
def test_is_hma_required(swa_enabled, disable_hma, expected_is_hma):
"""_is_hma_required is correctly derived from kv_cache_config."""
block_size = 16
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector",
kv_role="kv_both",
block_size=block_size,
)
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = disable_hma
kv_cache_config = make_kv_cache_config(
block_size=block_size, swa_enabled=swa_enabled
)
scheduler = MooncakeConnectorScheduler(
vllm_config=vllm_config,
engine_id="test-engine",
kv_cache_config=kv_cache_config,
)
assert scheduler._is_hma_required is expected_is_hma
# ---------------------------------------------------------------------------
# test_get_sw_clipped_blocks: sliding-window clipping logic
# ---------------------------------------------------------------------------
@pytest.mark.cpu_test
def test_get_sw_clipped_blocks():
"""get_sw_clipped_blocks clips SWA group but keeps FA group intact."""
block_size = 16
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector",
kv_role="kv_both",
block_size=block_size,
)
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
# SW=128 tokens → 128/16 = 8 blocks + 1 = 9 blocks_per_sw
kv_cache_config = make_kv_cache_config(
block_size=block_size, swa_enabled=True, sw_size=128
)
scheduler = MooncakeConnectorScheduler(
vllm_config=vllm_config,
engine_id="test-engine",
kv_cache_config=kv_cache_config,
)
assert scheduler.blocks_per_sw == [0, 9]
# FA group: 20 blocks, SW group: 20 blocks (exceeds window)
fa_blocks = list(range(20))
sw_blocks = list(range(100, 120))
block_ids = (fa_blocks, sw_blocks)
clipped = scheduler.get_sw_clipped_blocks(block_ids)
# FA: untouched (blocks_per_sw[0] = 0)
assert clipped[0] == fa_blocks
# SW: clipped to last 9 blocks
assert clipped[1] == sw_blocks[-9:]
assert len(clipped[1]) == 9
@pytest.mark.cpu_test
def test_get_sw_clipped_blocks_noop_no_hma():
"""get_sw_clipped_blocks is a no-op when HMA is not required."""
block_size = 16
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector",
kv_role="kv_both",
block_size=block_size,
)
# FA only → _is_hma_required = False
kv_cache_config = make_kv_cache_config(block_size=block_size, swa_enabled=False)
scheduler = MooncakeConnectorScheduler(
vllm_config=vllm_config,
engine_id="test-engine",
kv_cache_config=kv_cache_config,
)
assert scheduler._is_hma_required is False
block_ids = ([1, 2, 3],)
clipped = scheduler.get_sw_clipped_blocks(block_ids)
assert clipped == [[1, 2, 3]]
# ---------------------------------------------------------------------------
# test_metadata_hma_block_ids: MooncakeConnectorMetadata stores per-group IDs
# ---------------------------------------------------------------------------
@pytest.mark.cpu_test
def test_metadata_hma_block_ids():
"""MooncakeConnectorMetadata.add_new_req stores per-group block IDs."""
metadata = MooncakeConnectorMetadata()
# FA group: 6 blocks, SW group: 3 blocks (clipped)
fa_blocks = [0, 1, 2, 3, 4, 5]
sw_blocks = [10, 11, 12]
# Test recv path
metadata.add_new_req(
request_id="recv-req",
local_block_ids=[fa_blocks, sw_blocks],
kv_transfer_params={
"transfer_id": "recv-req",
"remote_engine_id": "remote-engine",
"remote_bootstrap_addr": "http://bootstrap:33333",
},
load_remote_cache=True,
)
assert "recv-req" in metadata.reqs_to_recv["remote-engine"]
req_meta = metadata.reqs_to_recv["remote-engine"]["recv-req"]
assert len(req_meta.local_block_ids) == 2
assert req_meta.local_block_ids[0] == fa_blocks
assert req_meta.local_block_ids[1] == sw_blocks
# Test send path
metadata.add_new_req(
request_id="send-req",
local_block_ids=[fa_blocks, sw_blocks],
kv_transfer_params={
"transfer_id": "send-req",
},
load_remote_cache=False,
)
assert "send-req" in metadata.reqs_to_send
transfer_id, stored_blocks = metadata.reqs_to_send["send-req"]
assert transfer_id == "send-req"
assert len(stored_blocks) == 2
assert stored_blocks[0] == fa_blocks
assert stored_blocks[1] == sw_blocks
# ---------------------------------------------------------------------------
# test_build_transfer_params_multi_group_trimming
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake"
".mooncake_connector.TransferEngine",
FakeMooncakeWrapper,
)
async def test_build_transfer_params_multi_group_trimming(monkeypatch):
"""_build_transfer_params trims per-group blocks when local > remote."""
monkeypatch.setenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "5")
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector", kv_role="kv_producer"
)
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
worker = connector.connector_worker
block_len = 4096
# Call _build_transfer_params directly (avoids send_kv_to_decode
# async event loop complexity).
transfer_id = "xfer-hma-trim"
send_meta = SendBlockMeta(
p_req_id="p-trim",
transfer_id=transfer_id,
# FA: 4 blocks, SW: 3 blocks (producer has more)
local_block_ids=[[10, 11, 12, 13], [20, 21, 22]],
ready=asyncio.Event(),
)
xfer_meta = MooncakeXferMetadata(
remote_hostname="consumer-host",
remote_port=54321,
remote_tp_size=1,
remote_tp_rank=0,
req_blocks={
"d-trim": (
transfer_id,
# FA: 2 blocks, SW: 2 blocks (consumer needs fewer)
[[30, 31], [40, 41]],
)
},
kv_caches_base_addr=[0x2000],
block_lens=[block_len],
)
local_regions = [
TransferRegion(
base_addr=0x1000, block_len=block_len, kv_block_len=block_len
),
]
remote_regions = [
TransferRegion(
base_addr=0x2000, block_len=block_len, kv_block_len=block_len
),
]
ready_reqs = [("d-trim", send_meta)]
(
src_ptrs,
dst_ptrs,
lengths,
err_reqs,
err_msg,
) = await worker._build_transfer_params(
ready_reqs, xfer_meta, local_regions, remote_regions
)
# No errors
assert err_reqs == []
assert err_msg is None
# After trimming: FA [10..13] → last 2 → [12,13]; SW [20..22] → last 2 → [21,22]
# Flattened: [12,13,21,22] = 4 blocks → coalesced into some transfers
assert len(src_ptrs) > 0
assert len(dst_ptrs) == len(src_ptrs)
assert len(lengths) == len(src_ptrs)
worker.shutdown()
# ---------------------------------------------------------------------------
# test_build_transfer_params_group_count_mismatch
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake"
".mooncake_connector.TransferEngine",
FakeMooncakeWrapper,
)
async def test_build_transfer_params_group_count_mismatch(monkeypatch):
"""_build_transfer_params reports an error when group counts differ."""
monkeypatch.setenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "5")
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector", kv_role="kv_producer"
)
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
worker = connector.connector_worker
block_len = 4096
transfer_id = "xfer-mismatch"
send_meta = SendBlockMeta(
p_req_id="p-mismatch",
transfer_id=transfer_id,
# Producer has 2 groups
local_block_ids=[[10, 11], [20, 21]],
ready=asyncio.Event(),
)
# Consumer has only 1 group — group count mismatch
xfer_meta = MooncakeXferMetadata(
remote_hostname="consumer-host",
remote_port=54321,
remote_tp_size=1,
remote_tp_rank=0,
req_blocks={
"d-mismatch": (transfer_id, [[30, 31]]),
},
kv_caches_base_addr=[0x2000],
block_lens=[block_len],
)
local_regions = [
TransferRegion(
base_addr=0x1000, block_len=block_len, kv_block_len=block_len
),
]
remote_regions = [
TransferRegion(
base_addr=0x2000, block_len=block_len, kv_block_len=block_len
),
]
ready_reqs = [("d-mismatch", send_meta)]
(
src_ptrs,
dst_ptrs,
lengths,
err_reqs,
err_msg,
) = await worker._build_transfer_params(
ready_reqs, xfer_meta, local_regions, remote_regions
)
# Mismatched req is reported via err_reqs/err_msg with no transfers built.
assert err_reqs == ["d-mismatch"]
assert err_msg == "KV group count mismatch"
assert src_ptrs == []
assert dst_ptrs == []
assert lengths == []
worker.shutdown()
# ---------------------------------------------------------------------------
# test_request_finished_with_hma_groups
# ---------------------------------------------------------------------------
@pytest.mark.cpu_test
def test_request_finished_with_hma_groups():
"""request_finished correctly handles per-group block_ids."""
block_size = 16
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector",
kv_role="kv_producer",
block_size=block_size,
)
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
kv_cache_config = make_kv_cache_config(
block_size=block_size, swa_enabled=True, sw_size=128
)
scheduler = MooncakeConnectorScheduler(
vllm_config=vllm_config,
engine_id="test-engine",
kv_cache_config=kv_cache_config,
)
request = create_request(request_id=1, do_remote_decode=True)
request.kv_transfer_params["transfer_id"] = request.request_id
from vllm.v1.request import RequestStatus
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
# 2 groups: FA with 10 blocks, SW with 20 blocks (will be clipped)
fa_blocks = list(range(10))
sw_blocks = list(range(100, 120))
block_ids = (fa_blocks, sw_blocks)
delay_free, _ = scheduler.request_finished(request, block_ids)
assert delay_free is True
assert request.request_id in scheduler._reqs_need_send
_, stored_blocks = scheduler._reqs_need_send[request.request_id]
# FA: untouched
assert stored_blocks[0] == fa_blocks
# SW: clipped to last 9 blocks (sw_size=128, block_size=16 → 8+1=9)
assert stored_blocks[1] == sw_blocks[-9:]
@@ -76,6 +76,7 @@ def create_scheduler() -> Scheduler:
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
block_size=16,
hash_block_size=16,
)
+1 -1
View File
@@ -7,7 +7,7 @@ set -e
# Default values
# Keep DEEPGEMM_GIT_REF in sync with cmake/external_projects/deepgemm.cmake
DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
DEEPGEMM_GIT_REF="477618cd51baffca09c4b0b87e97c03fe827ef03"
DEEPGEMM_GIT_REF="891d57b4db1071624b5c8fa0d1e51cb317fa709f"
WHEEL_DIR=""
# Parse command line arguments
+41 -3
View File
@@ -404,10 +404,24 @@ def rotary_embedding(
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
rope_dim_offset: int = 0,
inverse: bool = False,
) -> None:
torch.ops._C.rotary_embedding(
positions, query, key, head_size, cos_sin_cache, is_neox
)
if rope_dim_offset == 0 and not inverse:
torch.ops._C.rotary_embedding(
positions, query, key, head_size, cos_sin_cache, is_neox
)
else:
torch.ops._C.rotary_embedding(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox,
rope_dim_offset,
inverse,
)
# layer norm ops
@@ -2503,6 +2517,30 @@ def topk_sigmoid(
)
def topk_hash_softplus_sqrt(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool = False,
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
input_tokens: torch.Tensor | None = None,
hash_indices_table: torch.Tensor | None = None,
) -> None:
torch.ops._moe_C.topk_softplus_sqrt(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
routed_scaling_factor,
e_score_correction_bias,
input_tokens,
hash_indices_table,
)
def grouped_topk(
scores: torch.Tensor,
num_expert_group: int,
+3
View File
@@ -51,6 +51,9 @@ class AttentionConfig:
use_prefill_query_quantization: bool = False
"""If set, quantize query for attention in prefill."""
use_fp4_indexer_cache: bool = False
"""If set, use fp4 indexer cache for dsv32 family model (not support yet)"""
def compute_hash(self) -> str:
"""
Provide a hash that uniquely identifies all the configs
+14
View File
@@ -51,6 +51,18 @@ class CacheConfig:
"""Whether block_size was explicitly provided. Derived automatically."""
user_specified_mamba_block_size: bool = field(default=False, init=False)
"""Whether mamba_block_size was explicitly provided. Derived automatically."""
hash_block_size: SkipValidation[int] | None = None # type: ignore
"""Block size (in tokens) used for computing Request's block_hashes.
This can be set to a finer granularity than the physical KV cache block
sizes (e.g. 8) as long as every KV cache group's `block_size` is divisible
by it. This enables prefix-caching keys to be computed at the finest common
granularity and then merged for larger physical block sizes.
This config is not static default. If left unspecified, vLLM will choose a
default based on the resolved KV cache groups (typically the smallest KV
cache block size when there are multiple groups).
"""
gpu_memory_utilization: float = Field(default=0.92, gt=0, le=1)
"""The fraction of GPU memory to be used for the model executor, which can
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
@@ -182,6 +194,8 @@ class CacheConfig:
"num_gpu_blocks_override",
"enable_prefix_caching",
"prefix_caching_hash_algo",
# Prefix-caching implementation detail (doesn't affect compiled graph).
"hash_block_size",
"mamba_page_size_padded",
"user_specified_block_size",
"user_specified_mamba_block_size",
+1
View File
@@ -749,6 +749,7 @@ class CompilationConfig:
"vllm::kda_attention",
"vllm::sparse_attn_indexer",
"vllm::rocm_aiter_sparse_attn_indexer",
"vllm::deepseek_v4_attention",
]
def compute_hash(self) -> str:
+6 -4
View File
@@ -50,7 +50,7 @@ class IrOpPriorityConfig:
name: {
provider: IrOp.registry[name].impls[provider].uuid() for provider in p
}
for name, p in asdict(self).items()
for name, p in asdict(self).items() # type: ignore[call-overload]
}
return hash_factors(factors)
@@ -77,7 +77,7 @@ class IrOpPriorityConfig:
current_platform.import_ir_kernels()
with contextlib.ExitStack() as stack:
for field in fields(self):
for field in fields(self): # type: ignore[arg-type]
op_priority = getattr(self, field.name)
assert op_priority is not None, (
f"IR op priority for {field.name} must be set"
@@ -98,7 +98,7 @@ class IrOpPriorityConfig:
A helper to create an IrOpPriorityConfig where fields not specified in kwargs
use the given default list.
"""
for field in fields(cls):
for field in fields(cls): # type: ignore[arg-type]
if field.name not in kwargs:
kwargs[field.name] = list(default)
@@ -109,6 +109,7 @@ MoEBackend = Literal[
"auto",
"triton",
"deep_gemm",
"deep_gemm_mega_moe",
"cutlass",
"flashinfer_trtllm",
"flashinfer_cutlass",
@@ -136,8 +137,9 @@ class KernelConfig:
"""Backend for MoE expert computation kernels. Available options:
- "auto": Automatically select the best backend based on model and hardware
- "triton": Use Triton-based fused MoE kernels
- "triton": Use Triton-based fused MoE kernels
- "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only)
- "deep_gemm_mega_moe": Use DeepGEMM mega MoE kernels
- "cutlass": Use vLLM CUTLASS kernels
- "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels
- "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels
+5 -1
View File
@@ -83,7 +83,7 @@ logger = init_logger(__name__)
RunnerOption = Literal["auto", RunnerType]
ConvertType = Literal["none", "embed", "classify"]
ConvertOption = Literal["auto", ConvertType]
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32", "deepseek_v4"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal[
"raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs"
@@ -134,6 +134,7 @@ class ModelConfig:
- "slow" will always use the slow tokenizer.
- "mistral" will always use the tokenizer from `mistral_common`.
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.
- "deepseek_v4" will always use the tokenizer from `deepseek_v4`.
- "qwen_vl" will always use the tokenizer from `qwen_vl`.
- Other custom values can be supported via plugins."""
trust_remote_code: bool = False
@@ -565,6 +566,8 @@ class ModelConfig:
self.tokenizer_mode = "qwen_vl"
elif arch == "DeepseekV32ForCausalLM":
self.tokenizer_mode = "deepseek_v32"
elif arch == "DeepseekV4ForCausalLM":
self.tokenizer_mode = "deepseek_v4"
if self.tokenizer_mode != "auto":
logger.info(
@@ -952,6 +955,7 @@ class ModelConfig:
# imports during override detection (e.g., MXFP4 imports Triton)
"mxfp4",
"gpt_oss_mxfp4",
"deepseek_v4_fp8",
"cpu_awq",
"humming",
"gguf",
+11 -1
View File
@@ -287,13 +287,23 @@ class SpeculativeConfig:
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
initial_architecture = hf_config.architectures[0]
if hf_config.model_type in ("deepseek_v3", "deepseek_v32", "glm_moe_dsa"):
if hf_config.model_type in (
"deepseek_v3",
"deepseek_v32",
"glm_moe_dsa",
):
hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
)
if hf_config.model_type == "deepseek_v4":
hf_config.model_type = "deepseek_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{"n_predict": n_predict, "architectures": ["DeepSeekV4MTPModel"]}
)
if hf_config.model_type in ("pangu_ultra_moe"):
hf_config.model_type = "pangu_ultra_moe_mtp"
if hf_config.model_type == "pangu_ultra_moe_mtp":
@@ -29,6 +29,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
SupportsHMA,
)
from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_utils import (
MooncakeBootstrapServer,
@@ -43,10 +44,12 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import FullAttentionSpec, SlidingWindowSpec
from vllm.v1.request import RequestStatus
from vllm.v1.worker.utils import select_common_block_size
@@ -252,7 +255,7 @@ class MooncakeXferMetadata(
remote_port: int
remote_tp_size: int
remote_tp_rank: int
req_blocks: dict[ReqId, tuple[TransferId, list[int]]]
req_blocks: dict[ReqId, tuple[TransferId, list[list[int]]]]
kv_caches_base_addr: list[int]
block_lens: list[int]
@@ -280,7 +283,7 @@ class MooncakeXferResponse(
class PullReqMeta:
d_req_id: ReqId
transfer_id: TransferId
local_block_ids: list[int]
local_block_ids: list[list[int]]
remote_engine_id: EngineId
remote_bootstrap_addr: str
# Set expire time to avoid infinitely sending requests.
@@ -293,7 +296,7 @@ class PullReqMeta:
class SendBlockMeta:
p_req_id: ReqId
transfer_id: TransferId
local_block_ids: list[int]
local_block_ids: list[list[int]]
ready: asyncio.Event
expire_time: float = float("inf")
need_send: int = 0
@@ -306,13 +309,13 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
# Use (engine_id, dp_rank) to group reqs with same dp.
# See comments in MooncakeBootstrapServer.
self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
self.reqs_to_send: dict[ReqId, tuple[TransferId, list[int]]] = {}
self.reqs_to_send: dict[ReqId, tuple[TransferId, list[list[int]]]] = {}
self.reqs_not_processed: set[TransferId] = set()
def add_new_req(
self,
request_id: ReqId,
local_block_ids: list[int],
local_block_ids: list[list[int]],
kv_transfer_params: dict[str, Any],
load_remote_cache: bool = True,
):
@@ -330,7 +333,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
self.reqs_to_send[request_id] = (transfer_id, local_block_ids)
class MooncakeConnector(KVConnectorBase_V1):
class MooncakeConnector(KVConnectorBase_V1, SupportsHMA):
def __init__(
self,
vllm_config: VllmConfig,
@@ -344,13 +347,18 @@ class MooncakeConnector(KVConnectorBase_V1):
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
if role == KVConnectorRole.SCHEDULER:
assert kv_cache_config is not None, (
"kv_cache_config is required for SCHEDULER role"
)
self.connector_scheduler: MooncakeConnectorScheduler | None = (
MooncakeConnectorScheduler(vllm_config, self.engine_id)
MooncakeConnectorScheduler(vllm_config, self.engine_id, kv_cache_config)
)
self.connector_worker: MooncakeConnectorWorker | None = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)
self.connector_worker = MooncakeConnectorWorker(
vllm_config, self.engine_id, kv_cache_config
)
@classmethod
def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
@@ -401,6 +409,14 @@ class MooncakeConnector(KVConnectorBase_V1):
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, (block_ids,))
def request_finished_all_groups(
self,
request: "Request",
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
@@ -445,8 +461,14 @@ class MooncakeConnector(KVConnectorBase_V1):
class MooncakeConnectorScheduler:
"""Implementation of Scheduler side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str):
def __init__(
self,
vllm_config: VllmConfig,
engine_id: str,
kv_cache_config: "KVCacheConfig",
):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
assert vllm_config.kv_transfer_config
self.is_kv_producer: bool = (
@@ -457,15 +479,49 @@ class MooncakeConnectorScheduler:
)
logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id)
self._is_hma_required = (
not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
and any(
not isinstance(g.kv_cache_spec, FullAttentionSpec)
for g in kv_cache_config.kv_cache_groups
)
)
# Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_recv: dict[ReqId, tuple[Request, list[list[int]]]] = {}
self._reqs_need_send: dict[ReqId, tuple[Request, list[list[int]]]] = {}
# Reqs to remove from processed set because they're not to send after
# remote prefill or aborted.
self._reqs_not_processed: set[TransferId] = set()
# Compute sliding window block counts per KV cache group.
sw_sizes_tokens: list[tuple[int, int]] = [
(g.kv_cache_spec.sliding_window, g.kv_cache_spec.block_size)
if isinstance(g.kv_cache_spec, SlidingWindowSpec)
else (0, self.block_size)
for g in kv_cache_config.kv_cache_groups
]
# cdiv(n_tokens, block_size) gives blocks/window; add 1 to
# conservatively account for boundary overlap.
self.blocks_per_sw = [
cdiv(n_tokens, block_size) + 1 if n_tokens else 0
for n_tokens, block_size in sw_sizes_tokens
]
def get_sw_clipped_blocks(
self,
block_ids: tuple[list[int], ...] | list[list[int]],
) -> list[list[int]]:
"""Clip per-group block IDs to sliding window size."""
if len(block_ids) == 0 or not self._is_hma_required:
return list(block_ids)
return [
blocks[-self.blocks_per_sw[i] :] if self.blocks_per_sw[i] > 0 else blocks
for i, blocks in enumerate(block_ids)
]
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
@@ -530,9 +586,12 @@ class MooncakeConnectorScheduler:
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
local_block_ids = (
blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []
unhashed_block_ids = (
blocks.get_unhashed_block_ids_all_groups()
if num_external_tokens > 0
else ()
)
local_block_ids = self.get_sw_clipped_blocks(unhashed_block_ids)
# Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
else:
@@ -587,7 +646,7 @@ class MooncakeConnectorScheduler:
def request_finished(
self,
request: "Request",
block_ids: list[int],
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
"""
Once a request is finished, determine whether request blocks
@@ -630,10 +689,13 @@ class MooncakeConnectorScheduler:
# TODO: check whether block_ids actually ever be 0. If not we could
# remove the conditional below
delay_free_blocks = len(block_ids) > 0
delay_free_blocks = any(len(group) > 0 for group in block_ids)
if delay_free_blocks:
self._reqs_need_send[request.request_id] = (request, block_ids)
self._reqs_need_send[request.request_id] = (
request,
self.get_sw_clipped_blocks(block_ids),
)
return delay_free_blocks, None
@@ -641,7 +703,12 @@ class MooncakeConnectorScheduler:
class MooncakeConnectorWorker:
"""Implementation of Worker side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str):
def __init__(
self,
vllm_config: VllmConfig,
engine_id: str,
kv_cache_config: "KVCacheConfig | None" = None,
):
if TransferEngine is None:
logger.error("Mooncake is not available")
raise RuntimeError("Mooncake is not available")
@@ -752,6 +819,7 @@ class MooncakeConnectorWorker:
self.block_size = vllm_config.cache_config.block_size
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.kv_cache_config = kv_cache_config
self.use_mla = self.model_config.use_mla
self._sync_block_size_with_kernel()
@@ -1103,27 +1171,61 @@ class MooncakeConnectorWorker:
remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}"
for d_req_id, send_meta in ready_reqs:
_, remote_block_ids = agent_meta.req_blocks[d_req_id]
num_remote_blocks = len(remote_block_ids)
if num_remote_blocks == 0:
_, remote_block_ids_per_group = agent_meta.req_blocks[d_req_id]
if not remote_block_ids_per_group or all(
len(g) == 0 for g in remote_block_ids_per_group
):
continue
local_block_ids = send_meta.local_block_ids
# Partial prefix cache hit: just read uncomputed blocks.
num_local_blocks = len(local_block_ids)
if num_local_blocks < num_remote_blocks:
# Per-group partial hit trimming, then flatten.
# With HMA, groups share the same KV tensor but use different
# block ranges. We trim and concatenate so the coalescer and
# address math see one flat block list — same as non-HMA, but
# now including blocks from every group.
local_block_ids: list[int] = []
remote_block_ids: list[int] = []
has_block_error = False
if len(send_meta.local_block_ids) != len(remote_block_ids_per_group):
logger.error(
"req %s: local blocks(%d) less than remote blocks(%d)!",
"req %s: KV group count mismatch: local=%d, remote=%d",
d_req_id,
num_local_blocks,
num_remote_blocks,
len(send_meta.local_block_ids),
len(remote_block_ids_per_group),
)
err_reqs.append(d_req_id)
if err_msg is None:
err_msg = "KV group count mismatch"
continue
for local_group, remote_group in zip(
send_meta.local_block_ids, remote_block_ids_per_group
):
n_local = len(local_group)
n_remote = len(remote_group)
if n_local < n_remote:
logger.error(
"req %s: local blocks(%d) < remote blocks(%d) "
"in a KV cache group",
d_req_id,
n_local,
n_remote,
)
has_block_error = True
break
if n_local > n_remote:
# Partial prefix cache hit: just read uncomputed blocks.
local_group = local_group[-n_remote:]
local_block_ids.extend(local_group)
remote_block_ids.extend(remote_group)
if has_block_error:
err_reqs.append(d_req_id)
if err_msg is None:
err_msg = "P num blocks less than D"
continue
if num_local_blocks > num_remote_blocks:
local_block_ids = local_block_ids[-num_remote_blocks:]
if not local_block_ids:
continue
# Group by indices
group_local_block_ids, group_remote_block_ids = group_concurrent_contiguous(
@@ -1215,7 +1317,7 @@ class MooncakeConnectorWorker:
logger.debug(
"Sending kv_caches for request %s (%d blocks) to %s",
d_req_id,
num_remote_blocks,
len(local_block_ids),
remote_session,
)
@@ -1273,23 +1375,24 @@ class MooncakeConnectorWorker:
continue
seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.nbytes
if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes
tensor_size_bytes = cache.nbytes
self.num_blocks = cache.shape[0]
assert cache.shape[0] == self.num_blocks, (
"All kv cache tensors must have the same number of blocks"
)
assert curr_tensor_size_bytes % self.num_blocks == 0, (
"Mooncake expects each kv cache tensor size to be "
"divisible by the number of blocks."
)
self.block_len_per_layer.append(
curr_tensor_size_bytes // self.num_blocks
)
# Use stride-based block length so RDMA reaches the last
# block's padding (e.g. DeepseekV4 MLA alignment). stride(0)
# reflects the actual byte distance between consecutive
# blocks in GPU memory, which matches or exceeds the
# shape-based size.
block_len = cache.stride(0) * cache.element_size()
self.block_len_per_layer.append(block_len)
kv_data_ptrs.append(base_addr)
kv_data_lens.append(curr_tensor_size_bytes)
kv_data_lens.append(self.num_blocks * block_len)
self.kv_caches_base_addr = seen_base_addresses
self.seen_base_addresses = seen_base_addresses
+9
View File
@@ -299,6 +299,9 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
tools: list[ChatCompletionFunctionToolParam] | None
"""The tools for developer role."""
task: str | None
"""Model-specific task marker. Currently passed through for DeepSeek V4."""
ChatCompletionMessageParam: TypeAlias = (
OpenAIChatCompletionMessageParam
@@ -333,6 +336,9 @@ class ConversationMessage(TypedDict, total=False):
tools: list[ChatCompletionFunctionToolParam] | None
"""The tools for developer role."""
task: str | None
"""Model-specific task marker. Currently passed through for DeepSeek V4."""
# Passed in by user
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
@@ -1566,6 +1572,9 @@ def _parse_chat_message_content(
if "name" in message and isinstance(message["name"], str):
result_msg["name"] = message["name"]
if "task" in message and isinstance(message["task"], str):
result_msg["task"] = message["task"]
if role == "developer":
result_msg["tools"] = message.get("tools", None)
return result
@@ -100,6 +100,8 @@ class DeepGemmFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
else params.weight_scale,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=self.use_deep_gemm_e8m0,
is_bmm=getattr(layer, "is_bmm", False),
bmm_batch_size=getattr(layer, "bmm_batch_size", 0),
)
replace_parameter(layer, params.WEIGHT, dg_weight)
replace_parameter(layer, scale_attr, dg_weight_scale)
@@ -1422,6 +1422,20 @@ class MLADims:
def get_mla_dims(model_config: ModelConfig) -> MLADims:
hf_text_config = model_config.hf_text_config
# Check if this is a DeepseekV4 config (uses unified head_dim + rope_head_dim)
if hasattr(hf_text_config, "compress_ratios"):
# DeepseekV4 style config: unified head_dim with rope_head_dim
head_dim = hf_text_config.head_dim
rope_head_dim = hf_text_config.qk_rope_head_dim
return MLADims(
q_lora_rank=hf_text_config.q_lora_rank,
kv_lora_rank=head_dim,
qk_nope_head_dim=head_dim - rope_head_dim,
qk_rope_head_dim=rope_head_dim,
v_head_dim=head_dim,
)
# DeepseekV2/V3 style config
return MLADims(
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
kv_lora_rank=hf_text_config.kv_lora_rank,
@@ -2191,6 +2205,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: ColumnParallelLinear,
# DSV3.2 MLA Specific Arguments
indexer: object | None = None,
q_pad_num_heads: int | None = None,
) -> None:
@@ -2213,6 +2228,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads
self.supports_quant_query_input = True
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
# Use flashinfer's optimized concat_mla_k kernel when available.
# The kernel is optimized for DeepSeek V3 dimensions:
@@ -0,0 +1,438 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, ClassVar, cast
import torch
from torch import nn
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
)
from vllm.model_executor.layers.utils import cublas_gemm_bf16_bf16_fp32
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.ops.deepseek_v4_ops.fused_compress_quant_cache import (
_fused_kv_compress_norm_rope_insert_indexer_attn,
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn,
_fused_kv_compress_norm_rope_insert_sparse_attn,
)
from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import (
MXFP4_BLOCK_SIZE,
)
from vllm.v1.kv_cache_interface import (
KVCacheSpec,
MLAAttentionSpec,
SlidingWindowMLASpec,
)
class CompressorBackend(AttentionBackend):
def __init__(self):
super().__init__()
@staticmethod
def get_name() -> str:
return "CompressorBackend"
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(1)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [512, 1024]
@staticmethod
def get_builder_cls() -> type["CompressorMetadataBuilder"]:
return CompressorMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
assert num_kv_heads == 1
return (num_blocks, block_size, head_size)
@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
if include_num_layers_dimension:
return (0, 1, 2, 3)
return (0, 1, 2)
@dataclass
class CompressorMetadata:
block_table: torch.Tensor
slot_mapping: torch.Tensor
block_size: int
token_to_req_indices: torch.Tensor | None = None # [num_tokens]
class CompressorMetadataBuilder(AttentionMetadataBuilder):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec)
mla_spec = cast(SlidingWindowMLASpec | MLAAttentionSpec, self.kv_cache_spec)
self.block_size = mla_spec.block_size
self.token_to_req_indices = torch.zeros(
self.vllm_config.scheduler_config.max_num_batched_tokens,
dtype=torch.int32,
device=self.device,
)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> CompressorMetadata:
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_reqs = common_attn_metadata.num_reqs
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
x = torch.repeat_interleave(torch.arange(num_reqs), query_lens).pin_memory()
token_to_req_indices = self.token_to_req_indices[: x.shape[0]]
token_to_req_indices.copy_(x, non_blocking=True)
return CompressorMetadata(
block_table=common_attn_metadata.block_table_tensor.clamp_(min=0),
slot_mapping=common_attn_metadata.slot_mapping,
block_size=self.block_size,
token_to_req_indices=token_to_req_indices,
)
class CompressorStateCache(torch.nn.Module, AttentionLayerBase):
def __init__(
self,
state_dim: int,
dtype: torch.dtype,
compress_ratio: int,
prefix: str,
):
super().__init__()
self.state_dim = state_dim
self.dtype = dtype
self.prefix = prefix
self.kv_cache = torch.tensor([])
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
assert self.dtype == torch.float32
assert compress_ratio in [4, 128]
coff = 1 + (compress_ratio == 4)
self.sliding_window = coff * compress_ratio
# Block size is constrained by tensor sharing between compressor states
# and KV blocks. Since compressor states share the same physical tensor
# as KV blocks, they must use the same page size.
# The KV block shape [256//4, head_dim] = [64, 584] determines:
# - C4 compressor block shape [4, 2*512*2*4] -> block_size = 4
# - C128 compressor block shape [8, 512*2*4] -> block_size = 8
# TODO(yifan): make block size automatically determined and configurable.
if compress_ratio == 4:
self.block_size = 4
elif compress_ratio == 128:
self.block_size = 8
else:
raise ValueError(f"Invalid compress ratio: {compress_ratio}")
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
return SlidingWindowMLASpec( # only has one vector instead of K + V
block_size=self.block_size,
num_kv_heads=1,
head_size=self.state_dim,
dtype=self.dtype,
sliding_window=self.sliding_window,
alignment=576, # NOTE: FlashMLA requires 576B alignment
)
def forward(self): ...
def get_attn_backend(self) -> type[AttentionBackend]:
return CompressorBackend
class DeepseekCompressor(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
compress_ratio: int,
hidden_size: int,
head_dim: int,
rotate: bool = False,
prefix: str = "",
k_cache_prefix="",
use_fp4_cache: bool = False,
):
super().__init__()
self.compress_ratio = compress_ratio
self.hidden_size = hidden_size
self.head_dim = head_dim
self.rotate = rotate
self.prefix = prefix
self.k_cache_prefix = k_cache_prefix
self.use_fp4_cache = use_fp4_cache
config = vllm_config.model_config.hf_config
self.rope_head_dim = config.qk_rope_head_dim
self.nope_head_dim = self.head_dim - self.rope_head_dim
self.rms_norm_eps = config.rms_norm_eps
self.device = current_platform.device_type
self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
self.max_model_len = vllm_config.model_config.max_model_len
self.overlap = compress_ratio == 4
self.coff = 1 + self.overlap
state_dtype = torch.float32
self.ape = nn.Parameter(
torch.empty(
(compress_ratio, self.coff * self.head_dim),
dtype=state_dtype,
device=self.device,
),
requires_grad=False,
)
self.fused_wkv_wgate = MergedColumnParallelLinear(
self.hidden_size,
[self.coff * self.head_dim, self.coff * self.head_dim],
bias=False,
return_bias=False,
quant_config=None,
disable_tp=True,
prefix=f"{prefix}.fused_wkv_wgate",
)
self.norm = RMSNorm(self.head_dim, self.rms_norm_eps)
self.state_cache = CompressorStateCache(
state_dim=2 * self.coff * self.head_dim, # kv_state + score_state
dtype=state_dtype,
compress_ratio=compress_ratio,
prefix=f"{prefix}.state_cache",
)
# Save reference to static_forward_context for forward-time KV cache lookup.
# get_current_vllm_config() is only available during __init__, not forward.
self._static_forward_context = (
vllm_config.compilation_config.static_forward_context
)
if self.head_dim == 512:
assert not use_fp4_cache, (
"MXFP4 cache is only supported for indexer (head=128)"
)
self._fused_kernel = _fused_kv_compress_norm_rope_insert_sparse_attn
self._quant_block = 64
self._token_stride = self.nope_head_dim + self.rope_head_dim * 2
self._scale_dim = self.nope_head_dim // 64 + 1 # 7 real + 1 pad
self._num_warps = 4
elif self.head_dim == 128:
if use_fp4_cache:
self._fused_kernel = (
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn
)
self._quant_block = MXFP4_BLOCK_SIZE
self._token_stride = self.head_dim // 2
self._scale_dim = self.head_dim // MXFP4_BLOCK_SIZE
else:
self._fused_kernel = _fused_kv_compress_norm_rope_insert_indexer_attn
self._quant_block = 128
self._token_stride = self.head_dim
self._scale_dim = 4 # single float32 scale
self._num_warps = 1
else:
raise ValueError(
f"Unsupported head_dim for fused quant+cache: {self.head_dim}"
)
def forward(
self,
# [num_tokens, hidden_size]
x: torch.Tensor,
# [num_tokens]
positions: torch.Tensor,
rotary_emb,
) -> None:
num_tokens, _ = x.shape
# bf16 weights/activations but fp32 output for numerical stability of
# the downstream compressor math.
kv_score = cublas_gemm_bf16_bf16_fp32(x, self.fused_wkv_wgate.weight)
# Each of shape [num_tokens, coff * self.head_dim]
# input bf16, output are fp32
kv, score = kv_score.split(
[self.coff * self.head_dim, self.coff * self.head_dim], dim=-1
)
# Get the metadata and handle dummy profiling run.
attn_metadata = get_forward_context().attn_metadata
if not isinstance(attn_metadata, dict):
return
state_metadata = cast(
CompressorMetadata, attn_metadata[self.state_cache.prefix]
)
token_to_req_indices = state_metadata.token_to_req_indices
slot_mapping = state_metadata.slot_mapping
num_actual = slot_mapping.shape[0]
block_table = state_metadata.block_table
block_size = state_metadata.block_size
# [num_blocks, block_size, kv_dim+score_dim], where kv_dim == score_dim
state_cache = self.state_cache.kv_cache
# kv_state stored in first half, score_state stored in second half
state_width = state_cache.shape[-1] // 2
# Store the KV and score (with fused APE addition) in the state.
# NOTE: PDL is disabled — both this kernel and _fused_kernel below
# depend on preceding kernel outputs (kv/score from the cublas GEMM;
# state_cache from this kernel) but neither emits/waits on PDL grid
# dependency primitives, so launch_pdl=True caused a read-after-write
# race and non-deterministic output.
_save_partial_states_kernel[(num_actual,)](
kv,
kv.stride(0),
score,
score.stride(0),
self.ape,
self.ape.stride(0),
positions,
state_cache,
state_cache.stride(0),
state_cache.stride(1),
slot_mapping,
block_size,
HEAD_SIZE=kv.shape[-1],
TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]),
STATE_WIDTH=state_width,
COMPRESS_RATIO=self.compress_ratio,
launch_pdl=False,
)
# Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write.
# RoPE requirements (kernel applies forward GPT-J style rotation):
# - is_neox_style=False (interleaved pairs, NOT split-half)
# - cos_sin_cache layout: [max_pos, rope_head_dim] with first half cos,
# second half sin (per-pair, length rope_head_dim // 2 each)
# - applied to LAST rope_head_dim elements of head_dim
# - position used: (positions // compress_ratio) * compress_ratio
cos_sin_cache = rotary_emb.cos_sin_cache
k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix])
kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache
self._fused_kernel[(num_actual,)](
# state cache
state_cache,
state_cache.stride(0),
state_cache.stride(1),
# metadata
token_to_req_indices,
positions,
slot_mapping,
block_table,
block_table.stride(0),
block_size,
# RMSNorm
self.norm.weight,
self.rms_norm_eps,
# RoPE
cos_sin_cache,
cos_sin_cache.stride(0),
# KV cache
kv_cache,
k_cache_metadata.slot_mapping,
kv_cache.shape[1], # paged KV cache block size (tokens per block)
# constexprs
HEAD_SIZE=self.head_dim,
TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim),
STATE_WIDTH=state_width,
COMPRESS_RATIO=self.compress_ratio,
OVERLAP=self.overlap,
ROPE_HEAD_DIM=self.rope_head_dim,
FP8_MAX=448.0,
QUANT_BLOCK=self._quant_block,
TOKEN_STRIDE=self._token_stride,
SCALE_DIM=self._scale_dim,
KV_BLOCK_STRIDE=kv_cache.stride(0),
num_warps=self._num_warps,
launch_pdl=False,
)
@triton.jit
def _save_partial_states_kernel(
kv_ptr,
kv_stride,
score_ptr,
score_stride,
ape_ptr,
ape_stride,
positions_ptr,
state_cache_ptr,
state_cache_stride0,
state_cache_stride1,
slot_mapping_ptr,
block_size,
HEAD_SIZE: tl.constexpr,
TRITON_BLOCK_SIZE: tl.constexpr,
# state_cache last dim packs [kv_state, score_state], each STATE_WIDTH wide.
STATE_WIDTH: tl.constexpr,
COMPRESS_RATIO: tl.constexpr,
):
token_idx = tl.program_id(0)
slot_id = tl.load(slot_mapping_ptr + token_idx)
# Skip padded / invalid tokens (slot_id == -1 is the PAD sentinel used
# by vLLM). During CUDA graph replay the batch may contain padding
# tokens whose slot_mapping is -1; writing to kv_state[-1] would be an
# illegal memory access.
if slot_id < 0:
return
block_idx = slot_id // block_size
pos_in_block = slot_id % block_size
base_ptr = (
state_cache_ptr
+ block_idx * state_cache_stride0
+ pos_in_block * state_cache_stride1
)
block = tl.arange(0, TRITON_BLOCK_SIZE)
mask = block < HEAD_SIZE
kv = tl.load(kv_ptr + token_idx * kv_stride + block, mask=mask)
tl.store(base_ptr + block, kv, mask=mask)
# Fused: score += ape[position % compress_ratio]
position = tl.load(positions_ptr + token_idx)
ape_row = position % COMPRESS_RATIO
ape = tl.load(ape_ptr + ape_row * ape_stride + block, mask=mask)
score = tl.load(score_ptr + token_idx * score_stride + block, mask=mask)
tl.store(
base_ptr + STATE_WIDTH + block,
score + ape,
mask=mask,
)
File diff suppressed because it is too large Load Diff
+45 -1
View File
@@ -117,8 +117,10 @@ class RoutingMethodType(IntEnum):
Custom = (6,)
# Simulated
Simulated = (7,)
# Deepseek V4 -> sqrtsoftplus + Bias + Normalize
DeepseekV4 = (8,)
# Unspecified
Unspecified = 8.0
Unspecified = 9.0
def get_routing_method_type(
@@ -128,6 +130,14 @@ def get_routing_method_type(
num_expert_group: int | None,
has_e_score_bias: bool,
) -> RoutingMethodType:
if scoring_func == "sqrtsoftplus":
# DeepSeek V4 uses sqrtsoftplus routing with optional routing bias
# and top-k renormalization.
if renormalize:
return RoutingMethodType.DeepseekV4
else:
return RoutingMethodType.Unspecified
if has_e_score_bias:
if (num_expert_group or 0) > 0 and scoring_func == "sigmoid":
return RoutingMethodType.DeepSeekV3
@@ -230,6 +240,13 @@ class FusedMoEQuantConfig:
_w2: FusedMoEQuantDesc
is_nvfp4_scale_swizzled: bool = True
# MXFP4-specific TRTLLM parameters for SwiGLU activation clamping.
# These correspond to gemm1_alpha, gemm1_beta, gemm1_clamp_limit
# in TrtLlmMxfp4ExpertsBase.
gemm1_alpha: float | None = None
gemm1_beta: float | None = None
gemm1_clamp_limit: float | None = None
def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, (
"illegal quantization"
@@ -477,6 +494,9 @@ class FusedMoEQuantConfig:
w2_zp: torch.Tensor | None = None,
weight_dtype: torch.dtype | str | None = None,
is_nvfp4_scale_swizzled: bool = True,
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
gemm1_clamp_limit: float | None = None,
) -> "FusedMoEQuantConfig":
"""
General builder function for a FusedMoEQuantConfig.
@@ -507,6 +527,9 @@ class FusedMoEQuantConfig:
- w1_zp: Optional w1 zero points for int4/int8 quantization.
- w2_zp: Optional w2 zero points for int4/int8 quantization.
- is_nvfp4_scale_swizzled: Whether to swizzle the nvfp4 scale swizzling.
- gemm1_alpha: Optional MXFP4 TRTLLM SwiGLU alpha parameter.
- gemm1_beta: Optional MXFP4 TRTLLM SwiGLU beta parameter.
- gemm1_clamp_limit: Optional MXFP4 TRTLLM SwiGLU clamp limit.
"""
assert not isinstance(quant_dtype, str) or quant_dtype in {
"nvfp4",
@@ -540,6 +563,9 @@ class FusedMoEQuantConfig:
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
),
is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=gemm1_clamp_limit,
)
assert quant_config.per_act_token_quant == per_act_token_quant
assert quant_config.per_out_ch_quant == per_out_ch_quant
@@ -650,6 +676,9 @@ def mxfp4_w4a16_moe_quant_config(
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
gemm1_clamp_limit: float | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for unquantized activations and mxfp4 weights.
@@ -659,6 +688,9 @@ def mxfp4_w4a16_moe_quant_config(
_a2=FusedMoEQuantDesc(),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=gemm1_clamp_limit,
)
@@ -670,6 +702,9 @@ def mxfp4_mxfp8_moe_quant_config(
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
block_shape: list[int] | None = None,
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
gemm1_clamp_limit: float | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for mxfp4 activations and mxfp4 weights.
@@ -679,6 +714,9 @@ def mxfp4_mxfp8_moe_quant_config(
_a2=FusedMoEQuantDesc("mxfp8"),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=gemm1_clamp_limit,
)
@@ -712,6 +750,9 @@ def ocp_mx_moe_quant_config(
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
block_shape: list[int] | None = None,
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
gemm1_clamp_limit: float | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for mxfp4 activations and mxfp4 weights.
@@ -729,6 +770,9 @@ def ocp_mx_moe_quant_config(
per_act_token_quant=False,
per_out_ch_quant=False,
block_shape=block_shape,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=gemm1_clamp_limit,
)
@@ -25,15 +25,20 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8_packed_for_deepgemm,
silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
silu_mul_quant_fp8_packed_triton as fused_silu_mul_fp8_quant_packed,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
kMxfp4Static,
)
from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_supported,
m_grouped_fp8_fp4_gemm_nt_contiguous,
m_grouped_fp8_gemm_nt_contiguous,
)
from vllm.utils.import_utils import has_deep_gemm
@@ -197,8 +202,14 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular):
M_sum, N = input.size()
activation_out_dim = self.adjust_N_for_activation(N, activation)
# 1. DeepGemm UE8M0: use packed per-token-group quant
# 1. DeepGemm UE8M0: fused SiLU+mul+clamp+quant+pack
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
if activation == MoEActivation.SILU:
return fused_silu_mul_fp8_quant_packed(
input=input,
output_q=output,
group_size=block_k,
)
act_out = torch.empty(
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device
)
@@ -312,3 +323,225 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular):
expert_map=expert_map,
output=output,
)
class DeepGemmFP4Experts(mk.FusedMoEExpertsModular):
"""DeepGemm-based fused MoE expert implementation for FP4 weights.
Uses m_grouped_fp8_fp4_gemm_nt_contiguous with FP8 activations and
MXFP4 (FP4 E2M1 packed as uint8) weights. Requires SM100+ (Blackwell).
"""
# FP8 activation block size (hardcoded since mxfp4_w4a8 quant config
# does not set a block_shape on the activation descriptor).
_ACT_BLOCK_K = 128
# FP4 weight block size
_WEIGHT_BLOCK_K = 32
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
super().__init__(moe_config=moe_config, quant_config=quant_config)
assert quant_config.weight_quant_dtype == "mxfp4"
assert not quant_config.per_act_token_quant
assert not quant_config.per_out_ch_quant
self.gemm1_clamp_limit = quant_config.gemm1_clamp_limit
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
from vllm.platforms import current_platform
return (
is_deep_gemm_supported()
and current_platform.is_device_capability_family(100)
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W_A = [
(kMxfp4Static, kFp8Dynamic128Sym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.SWIGLUSTEP]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return not (
moe_parallel_config.use_fi_nvl_two_sided_kernels
or moe_parallel_config.use_fi_nvl_one_sided_kernels
)
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
block_m = get_mk_alignment_for_contiguous_layout()[0]
M_sum = compute_aligned_M(
M, topk, local_num_experts, block_m, expert_tokens_meta
)
assert M_sum % block_m == 0
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M_sum, max(activation_out_dim, K))
workspace2 = (M_sum, max(N, K))
output = (M, K)
return (workspace1, workspace2, output)
def _act_mul_quant(
self, input: torch.Tensor, output: torch.Tensor, activation: MoEActivation
) -> tuple[torch.Tensor, torch.Tensor]:
block_k = self._ACT_BLOCK_K
scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
M_sum, N = input.size()
activation_out_dim = self.adjust_N_for_activation(N, activation)
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
assert activation == MoEActivation.SILU
return fused_silu_mul_fp8_quant_packed(
input=input,
output_q=output,
group_size=block_k,
clamp_limit=self.gemm1_clamp_limit,
)
if activation == MoEActivation.SILU:
use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
return silu_mul_per_token_group_quant_fp8_colmajor(
input=input,
output=output,
use_ue8m0=use_ue8m0,
)
act_out = torch.empty(
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device
)
self.activation(activation, act_out, input)
return per_token_group_quant_fp8(
act_out, block_k, column_major_scales=True, out_q=output
)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert a1q_scale is not None
assert a2_scale is None
assert self.w1_scale is not None
assert self.w2_scale is not None
a1q = hidden_states
_, N, _ = w1.size()
# K comes from activations (full hidden dim), not from w1 which is
# packed FP4 (E, N, K//2).
K = a1q.size(1)
local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = local_num_experts
M_sum = compute_aligned_M(
M=topk_ids.size(0),
num_topk=topk_ids.size(1),
local_num_experts=local_num_experts,
alignment=get_mk_alignment_for_contiguous_layout()[0],
expert_tokens_meta=expert_tokens_meta,
)
a1q_perm = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K)
)
a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
aq=a1q,
aq_scale=a1q_scale,
topk_ids=topk_ids,
local_num_experts=local_num_experts,
expert_map=expert_map,
expert_tokens_meta=expert_tokens_meta,
aq_out=a1q_perm,
)
assert a1q.size(0) == M_sum
# FC1: FP8 activations x FP4 weights
# DeepGEMM 2.4.2 requires FP4-packed weights as int8 (kPackedFP4).
mm1_out = _resize_cache(workspace2, (M_sum, N))
m_grouped_fp8_fp4_gemm_nt_contiguous(
(a1q, a1q_scale),
(w1.view(torch.int8), self.w1_scale),
mm1_out,
expert_ids,
recipe_a=(1, self._ACT_BLOCK_K),
recipe_b=(1, self._WEIGHT_BLOCK_K),
)
# SwiGLU activation + FP8 requant
activation_out_dim = self.adjust_N_for_activation(N, activation)
quant_out = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim)
)
a2q, a2q_scale = self._act_mul_quant(
input=mm1_out.view(-1, N), output=quant_out, activation=activation
)
# FC2: FP8 activations x FP4 weights
mm2_out = _resize_cache(workspace2, (M_sum, K))
m_grouped_fp8_fp4_gemm_nt_contiguous(
(a2q, a2q_scale),
(w2.view(torch.int8), self.w2_scale),
mm2_out,
expert_ids,
recipe_a=(1, self._ACT_BLOCK_K),
recipe_b=(1, self._WEIGHT_BLOCK_K),
)
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
deepgemm_unpermute_and_reduce(
a=mm2_out,
topk_ids=topk_ids,
topk_weights=topk_weights,
inv_perm=inv_perm,
expert_map=expert_map,
output=output,
)
@@ -28,8 +28,162 @@ from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.import_utils import has_triton_kernels
from ..utils import swiglu_limit_func
logger = init_logger(__name__)
def _patch_make_bitmatrix_metadata() -> None:
"""Monkey-patch make_bitmatrix_metadata to support non-power-of-2 top_k.
triton's tl.arange requires a power-of-2 range. The original kernel
computes BLOCK_SIZE = BLOCK_PER_TOK * TOKS_PER_ROW (= 32 * top_k). For
DeepSeek-V4 with top_k=6 this gives 192, which is not a power of 2 and
causes a compile error at the first forward pass.
Fix: define a drop-in replacement kernel that accepts an extra constexpr
BLOCK_SIZE_PADDED (next power of 2 >= BLOCK_SIZE) and uses it for the
tl.arange call while keeping the actual BLOCK_SIZE as the stride between
thread-blocks so that all flat indices into NonzeroIndx stay correct.
Elements beyond BLOCK_SIZE are masked out (col_indx = 0xffff) and ignored.
This function is called once at module load time and patches the function
inside the triton_kernels tensor module so that SparseMatrix.__post_init__
picks up the fixed version transparently.
"""
import torch
import triton
import triton.language as tl
try:
from vllm.third_party.triton_kernels.tensor_details import (
bitmatrix as _bm,
)
from vllm.third_party.triton_kernels.tensor_details.bitmatrix import (
BitmatrixMetadata,
_keyed_add,
cdiv,
)
from vllm.third_party.triton_kernels.tensor_details.bitmatrix_details.sum_bitmatrix_rows import ( # noqa: E501
sum_bitmatrix_rows,
)
except ImportError:
return
@triton.jit
def _stage2_pow2(
ColSortedIndx,
RowSortedIndx,
NonzeroIndx,
n_tokens,
ColPartialSum,
stride_pm,
stride_pn,
ColOffs,
TOKS_PER_ROW: tl.constexpr,
BLOCK_PER_TOK: tl.constexpr,
BLOCK_SIZE_PADDED: tl.constexpr,
):
# Actual number of elements per block (may not be a power of 2).
BLOCK_SIZE: tl.constexpr = BLOCK_PER_TOK * TOKS_PER_ROW
tl.static_assert(BLOCK_SIZE_PADDED <= 32768)
if isinstance(n_tokens, tl.tensor) and n_tokens.dtype.is_ptr():
n_tokens = tl.load(n_tokens)
nonzero_indx_size = n_tokens * TOKS_PER_ROW
pid_m = tl.program_id(0)
# Use BLOCK_SIZE_PADDED (a power of 2) for tl.arange, but stride by
# the actual BLOCK_SIZE so flat positions in NonzeroIndx are correct.
# Elements with offs_local >= BLOCK_SIZE have offs_global beyond the
# valid range, get col_indx = 0xffff, and are filtered by the mask
# below without producing any output.
offs_local = tl.arange(0, BLOCK_SIZE_PADDED)
offs_global = pid_m * BLOCK_SIZE + offs_local
mask = offs_global < nonzero_indx_size
col_indx = tl.load(NonzeroIndx + offs_global, mask=mask, other=-1).to(tl.uint32)
kv_pairs = ((col_indx << 16) | offs_local).to(tl.uint32)
kv_pairs = tl.sort(kv_pairs, 0)
col_indx = kv_pairs >> 16
offs_global = pid_m * BLOCK_SIZE + (kv_pairs & 0xFFFF)
mask = col_indx != 0xFFFF
x = kv_pairs & 0xFFFF0000 | 0x00000001
cols_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
exclusive_run_lengths = (cols_and_inclusive_run_lengths - 1) & 0xFFFF
row_sorted_indx = tl.load(
ColPartialSum + pid_m * stride_pm + col_indx * stride_pn, mask=mask
)
row_sorted_indx += tl.load(ColOffs + col_indx, mask=mask)
row_sorted_indx += exclusive_run_lengths
tl.store(RowSortedIndx + offs_global, row_sorted_indx, mask=mask)
tl.store(ColSortedIndx + row_sorted_indx, offs_global, mask=mask)
def _make_bitmatrix_metadata_pow2_safe(nonzero_indx, bitmatrix):
assert nonzero_indx.ndim == 2
PARTIAL_BLOCK_M = 32
col_sum, col_partial_sum = sum_bitmatrix_rows(
bitmatrix, partials_block_size=PARTIAL_BLOCK_M
)
device = bitmatrix.device
n_indx = nonzero_indx.numel()
n_cols = bitmatrix.shape[1]
col_offs = torch.empty(n_cols, dtype=torch.int32, device=device)
combined_indx = torch.empty(n_indx * 2, dtype=torch.int32, device=device)
col_sorted_indx = combined_indx[:n_indx]
row_sorted_indx = combined_indx[n_indx:]
MEMSET_BLOCK = 1024
memset_grid = (cdiv(n_indx * 2, MEMSET_BLOCK) + n_cols + 1,)
_bm._bitmatrix_metadata_compute_stage1[memset_grid](
combined_indx,
n_indx * 2,
-1,
MEMSET_BLOCK,
col_sum,
col_offs,
col_sum.shape[0],
col_partial_sum,
col_partial_sum.shape[0],
col_partial_sum.stride(0),
col_partial_sum.stride(1),
BLOCK_M=512,
BLOCK_N=512,
)
toks_per_row = nonzero_indx.shape[-1]
block_size = PARTIAL_BLOCK_M * toks_per_row
# Next power of 2 >= block_size (required by tl.arange).
block_size_padded = 1 << (max(block_size, 1) - 1).bit_length()
compute_grid = (cdiv(bitmatrix.shape_max[0], PARTIAL_BLOCK_M),)
_stage2_pow2[compute_grid](
col_sorted_indx,
row_sorted_indx,
nonzero_indx,
bitmatrix.shape[0],
col_partial_sum,
col_partial_sum.stride(0),
col_partial_sum.stride(1),
col_offs,
TOKS_PER_ROW=toks_per_row,
BLOCK_PER_TOK=PARTIAL_BLOCK_M,
BLOCK_SIZE_PADDED=block_size_padded,
)
return BitmatrixMetadata(
col_sum=col_sum,
col_sorted_indx=col_sorted_indx,
row_sorted_indx=row_sorted_indx,
)
# The most reliable patch point: SparseMatrix.__post_init__ looks up
# make_bitmatrix_metadata via its own __globals__ dict (the tensor.py
# module dict). Patching through __globals__ works regardless of how
# sys.modules maps "triton_kernels.tensor" vs
# "vllm.third_party.triton_kernels.tensor".
from triton_kernels.tensor import SparseMatrix as _SparseMatrix
_SparseMatrix.__post_init__.__globals__["make_bitmatrix_metadata"] = (
_make_bitmatrix_metadata_pow2_safe
)
# Also patch the bitmatrix module itself in case it is imported directly.
_bm.make_bitmatrix_metadata = _make_bitmatrix_metadata_pow2_safe
use_legacy_triton_kernels = False
if has_triton_kernels():
@@ -59,6 +213,8 @@ if has_triton_kernels():
use_legacy_triton_kernels = True
else:
raise
if not use_legacy_triton_kernels:
_patch_make_bitmatrix_metadata()
except (AttributeError, ImportError) as e:
logger.error(
"Failed to import Triton kernels. Please make sure your triton "
@@ -497,6 +653,8 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
return False
# (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
# and ROCm gfx942/gfx950 (which map to 9.4/9.5).
if not has_triton_kernels():
return False
return (9, 0) <= (cap.major, cap.minor) < (11, 0)
@staticmethod
@@ -698,6 +856,37 @@ class UnfusedOAITritonExperts(LoRAExpertsMixin, BaseOAITritonExperts):
def moe_sum(self, input: torch.Tensor, output: torch.Tensor):
ops.moe_sum(input, output)
def activation(
self,
activation: MoEActivation,
output: torch.Tensor,
input: torch.Tensor,
) -> None:
quant_config = self.quant_config or FUSED_MOE_UNQUANTIZED_CONFIG
if activation == MoEActivation.SWIGLUOAI:
alpha = (
quant_config.gemm1_alpha
if quant_config.gemm1_alpha is not None
else 1.702
)
limit = (
quant_config.gemm1_clamp_limit
if quant_config.gemm1_clamp_limit is not None
else 7.0
)
torch.ops._C.swigluoai_and_mul(output, input, alpha, limit)
elif (
activation == MoEActivation.SILU
and quant_config.gemm1_clamp_limit is not None
):
swiglu_limit_func(
output,
input,
quant_config.gemm1_clamp_limit,
)
else:
super().activation(activation, output, input)
def apply(
self,
output: torch.Tensor,
@@ -812,9 +1001,9 @@ class UnfusedOAITritonExperts(LoRAExpertsMixin, BaseOAITritonExperts):
act_input,
)
# matmul_ogs grouped reduction fuse sum across multiple experts:
# matmul_ogs grouped reduction fuses sum across multiple experts:
# y[dst_indx // n_expts_act, :] += x
# Need to set n_expts_act to 1 to unfuse moe_sum
# Set n_expts_act to 1 to unfuse the sum so we can do it manually via moe_sum.
routing_data.n_expts_act = 1
matmul_ogs(
@@ -878,6 +1067,8 @@ class OAITritonMxfp4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
return False
# (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
# and ROCm gfx942/gfx950 (which map to 9.4/9.5).
if not has_triton_kernels():
return False
return (9, 0) <= (cap.major, cap.minor) < (11, 0)
@staticmethod
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import trtllm_moe_pack_topk_ids_weights
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kMxfp4Static,
@@ -32,10 +33,8 @@ class TrtLlmMxfp4ExpertsBase:
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
**kwargs,
):
# NOTE: FusedMoEExperts.__init__ is called by the concrete subclass
# (Monolithic/Modular) via MRO, not here, to avoid mypy issues with
# multiple inheritance. This matches the NvFP4 expert pattern.
self.moe_config = moe_config
self.quant_config = quant_config
@@ -48,23 +47,34 @@ class TrtLlmMxfp4ExpertsBase:
self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank
# MXFP4-specific TRTLLM parameters
# MXFP4-specific TRTLLM parameters from quant_config
device = torch.accelerator.current_device_index()
self.gemm1_alpha = torch.tensor(
[1.702] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
self.gemm1_beta = torch.tensor(
[1.0] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
self.gemm1_clamp_limit = torch.tensor(
[7.0] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
if quant_config.gemm1_alpha is not None:
self.gemm1_alpha = torch.tensor(
[quant_config.gemm1_alpha] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
else:
self.gemm1_alpha = None
if quant_config.gemm1_beta is not None:
self.gemm1_beta = torch.tensor(
[quant_config.gemm1_beta] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
else:
self.gemm1_beta = None
if quant_config.gemm1_clamp_limit is not None:
self.gemm1_clamp_limit = torch.tensor(
[quant_config.gemm1_clamp_limit] * self.local_num_experts,
dtype=torch.float32,
device=device,
)
else:
self.gemm1_clamp_limit = None
from vllm.config import get_current_vllm_config
@@ -97,7 +107,7 @@ class TrtLlmMxfp4ExpertsBase:
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation == MoEActivation.SWIGLUOAI
return activation in (MoEActivation.SWIGLUOAI, MoEActivation.SILU)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
@@ -190,36 +200,41 @@ class TrtLlmMxfp4ExpertsMonolithic(
output = torch.empty_like(hidden_states)
return trtllm_fp4_block_scale_moe(
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=None,
hidden_states=x_quant,
hidden_states_scale=x_scale,
gemm1_weights=w1,
gemm1_weights_scale=self.w1_scale,
gemm1_bias=self.w1_bias,
gemm1_alpha=self.gemm1_alpha,
gemm1_beta=self.gemm1_beta,
gemm1_clamp_limit=self.gemm1_clamp_limit,
gemm2_weights=w2,
gemm2_weights_scale=self.w2_scale,
gemm2_bias=self.w2_bias,
output1_scale_scalar=None,
output1_scale_gate_scalar=None,
output2_scale_scalar=None,
num_experts=global_num_experts,
top_k=self.topk,
n_group=None,
topk_group=None,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=None,
routing_method_type=self.routing_method_type,
do_finalize=True,
tune_max_num_tokens=max(self.max_capture_size, 1),
output=output,
)[0]
from vllm.utils.flashinfer import _is_fi_autotuning, autotune
with autotune(_is_fi_autotuning):
trtllm_fp4_block_scale_moe(
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=None,
hidden_states=x_quant,
hidden_states_scale=x_scale,
gemm1_weights=w1,
gemm1_weights_scale=self.w1_scale,
gemm1_bias=self.w1_bias,
gemm1_alpha=self.gemm1_alpha,
gemm1_beta=self.gemm1_beta,
gemm1_clamp_limit=self.gemm1_clamp_limit,
gemm2_weights=w2,
gemm2_weights_scale=self.w2_scale,
gemm2_bias=self.w2_bias,
output1_scale_scalar=None,
output1_scale_gate_scalar=None,
output2_scale_scalar=None,
num_experts=global_num_experts,
top_k=self.topk,
n_group=None,
topk_group=None,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=None,
routing_method_type=self.routing_method_type,
do_finalize=True,
tune_max_num_tokens=max(self.max_capture_size, 1),
output=output,
)
return output
class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModular):
@@ -239,6 +254,16 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
) -> bool:
return True
@staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# Modular kernel handles only the expert computation;
# routing is done externally, so accept any routing method.
return True
def supports_expert_map(self) -> bool:
return True
@@ -282,7 +307,7 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
):
topk = topk_ids.size(-1)
local_num_experts = w1.size(0)
intermediate_size = w2.size(1)
intermediate_size = self.intermediate_size_per_partition
local_expert_offset = self.moe_config.ep_rank * local_num_experts
# Handle input quantization
@@ -302,9 +327,8 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
x_quant = hidden_states
x_scale = None
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16
).view(torch.int16)
# Pack topk ids and weights into format expected by the kernel.
packed_tensor = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
assert self.w1_scale is not None
assert self.w2_scale is not None
@@ -333,7 +357,10 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
"local_expert_offset": local_expert_offset,
"local_num_experts": local_num_experts,
"routed_scaling_factor": None,
"routing_method_type": self.routing_method_type,
# Modular kernel receives pre-routed tokens, so routing
# is already done. Use Renormalize as a safe default that
# the TRTLLM C++ kernel supports.
"routing_method_type": RoutingMethodType.Renormalize,
"do_finalize": True,
"output": output,
"tune_max_num_tokens": max(self.max_capture_size, 1),
@@ -341,12 +368,9 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
from flashinfer import trtllm_fp4_block_scale_routed_moe
from vllm.utils.flashinfer import autotune
from vllm.utils.flashinfer import _is_fi_autotuning, autotune
with autotune(False):
# Enable autotune when,
# https://github.com/flashinfer-ai/flashinfer/issues/2023 is
# resolved.
with autotune(_is_fi_autotuning):
trtllm_fp4_block_scale_routed_moe(**kwargs)
return output
@@ -50,6 +50,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
from .utils import swiglu_limit_func
def _fused_marlin_moe(
hidden_states: torch.Tensor,
@@ -88,6 +90,7 @@ def _fused_marlin_moe(
output: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
is_k_full: bool = True,
clamp_limit: float | None = None,
) -> torch.Tensor:
assert hidden_states.ndim == 2
M, K = hidden_states.size()
@@ -155,11 +158,18 @@ def _fused_marlin_moe(
use_fp32_reduce=True,
is_zp_float=False,
)
activation_func(
activation,
intermediate_cache2,
intermediate_cache1.view(-1, w13_num_shards * N),
)
if clamp_limit is not None and activation == MoEActivation.SILU:
swiglu_limit_func(
intermediate_cache2,
intermediate_cache1.view(-1, w13_num_shards * N),
clamp_limit,
)
else:
activation_func(
activation,
intermediate_cache2,
intermediate_cache1.view(-1, w13_num_shards * N),
)
if output is None:
output = intermediate_cache3
@@ -247,6 +257,7 @@ def fused_marlin_moe(
output: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
inplace: bool = False,
clamp_limit: float | None = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -363,6 +374,7 @@ def fused_marlin_moe(
output=None,
input_dtype=input_dtype,
is_k_full=is_k_full,
clamp_limit=clamp_limit,
).view(-1, topk, K)
if output is None:
@@ -557,6 +569,7 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
self.is_k_full = is_k_full
self.input_dtype = get_marlin_input_dtype()
self.gemm1_clamp_limit = quant_config.gemm1_clamp_limit
super().__init__(
moe_config=moe_config,
@@ -850,6 +863,7 @@ class MarlinExperts(LoRAExpertsMixin, MarlinExpertsBase):
sort_indices2=self.w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
clamp_limit=self.gemm1_clamp_limit,
)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
@@ -169,5 +169,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError
@@ -268,6 +268,7 @@ class FusedMoE(PluggableLayer):
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
swiglu_limit: float | None = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
@@ -285,6 +286,7 @@ class FusedMoE(PluggableLayer):
routed_output_transform: torch.nn.Module | None = None,
apply_routed_scale_to_output: bool = False,
zero_expert_type: str | None = None,
hash_indices_table: torch.Tensor | None = None,
):
super().__init__()
@@ -294,6 +296,7 @@ class FusedMoE(PluggableLayer):
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config
self.swiglu_limit = swiglu_limit
# FIXME (varun): We should have a better way of inferring the activation
# datatype. This works for now as the tensor datatype entering the MoE
@@ -455,6 +458,7 @@ class FusedMoE(PluggableLayer):
self.e_score_correction_bias = e_score_correction_bias
# TODO(bnell): end attributes
self.hash_indices_table = hash_indices_table
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = MoEActivation.from_str(activation)
@@ -479,6 +483,7 @@ class FusedMoE(PluggableLayer):
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
zero_expert_type=zero_expert_type,
num_logical_experts=self.logical_num_experts,
hash_indices_table=self.hash_indices_table,
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
@@ -1541,10 +1546,12 @@ class FusedMoE(PluggableLayer):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
return self.runner.forward(
hidden_states,
router_logits,
input_ids,
)
@property
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
FusedMoEQuantDesc,
mxfp4_mxfp8_moe_quant_config,
mxfp4_w4a16_moe_quant_config,
ocp_mx_moe_quant_config,
@@ -24,6 +25,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kMxfp4Static,
kMxfp8Dynamic,
)
@@ -46,6 +48,8 @@ if has_triton_kernels():
class Mxfp4MoeBackend(Enum):
NONE = "None"
# DeepGEMM FP8xFP4 backend (SM100+)
DEEPGEMM_MXFP4 = "DEEPGEMM_MXFP4"
# FlashInfer TRTLLM backends
FLASHINFER_TRTLLM_MXFP4_MXFP8 = "FLASHINFER_TRTLLM_MXFP4_MXFP8"
FLASHINFER_TRTLLM_MXFP4_BF16 = "FLASHINFER_TRTLLM_MXFP4_BF16"
@@ -81,7 +85,14 @@ TRITON_BACKENDS = (
def backend_to_kernel_cls(
backend: Mxfp4MoeBackend,
) -> list[type[mk.FusedMoEExperts]]:
if backend in (
if backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
from vllm.model_executor.layers.fused_moe.experts.deep_gemm_moe import (
DeepGemmFP4Experts,
)
return [DeepGemmFP4Experts]
elif backend in (
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
):
@@ -159,11 +170,13 @@ def backend_to_kernel_cls(
def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend:
"""Map user's moe_backend string to Mxfp4MoeBackend."""
mapping: dict[str, Mxfp4MoeBackend] = {
"deep_gemm": Mxfp4MoeBackend.DEEPGEMM_MXFP4,
"flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
"flashinfer_trtllm_afp8": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
"flashinfer_cutlass": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
"flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
"triton": Mxfp4MoeBackend.TRITON,
"triton_unfused": Mxfp4MoeBackend.TRITON_UNFUSED,
"marlin": Mxfp4MoeBackend.MARLIN,
"aiter": Mxfp4MoeBackend.AITER,
"xpu": Mxfp4MoeBackend.XPU,
@@ -177,7 +190,7 @@ def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend:
)
def _get_priority_backends() -> list[Mxfp4MoeBackend]:
def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]:
"""
Get available backends in priority order based on platform and config.
Only includes BF16 backends. MXFP8 backends are selected via env vars.
@@ -187,7 +200,9 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
Mxfp4MoeBackend.AITER,
Mxfp4MoeBackend.TRITON,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.TRITON_UNFUSED,
# TRITON_UNFUSED has bug with MTP support
# TODO re-enable after kernel is fixed
# TRITON_UNFUSED
Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN,
Mxfp4MoeBackend.XPU,
@@ -196,8 +211,28 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
return _AVAILABLE_BACKENDS
def _get_priority_backends() -> list[Mxfp4MoeBackend]:
"""
Get available backends in priority order. SM100+ prefers DeepGEMM FP4 /
TRTLLM MXFP8; SM90 falls through to Triton_unfused or Marlin (the
backend-level ``is_supported_config`` check filters by device capability).
"""
_AVAILABLE_BACKENDS = [
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
Mxfp4MoeBackend.DEEPGEMM_MXFP4,
# TRITON_UNFUSED has bug with MTP support
# TODO re-enable after kernel is fixed
# TRITON_UNFUSED
Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN,
]
return _AVAILABLE_BACKENDS
def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None:
"""Map backend to its activation key (MXFP8 or None for BF16)."""
"""Map backend to its activation key (FP8, MXFP8, or None for BF16)."""
if backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
return kFp8Dynamic128Sym
if backend in (
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
@@ -290,7 +325,7 @@ def select_gpt_oss_mxfp4_moe_backend(
)
# Select kernels in order of backend.
AVAILABLE_BACKENDS = _get_priority_backends()
AVAILABLE_BACKENDS = _get_priority_backends_for_gpt_oss()
# Handle explicit FlashInfer MXFP4 BF16 configuration.
if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
@@ -387,11 +422,95 @@ def select_gpt_oss_mxfp4_moe_backend(
return Mxfp4MoeBackend.NONE, None
def select_mxfp4_moe_backend(
config: FusedMoEConfig,
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]:
"""
Select the MXFP4 MoE backend with MXFP8 activation as top priority.
Falls back through BF16 and other backends.
"""
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if config.moe_parallel_config.use_batched_activation_format
else mk.FusedMoEActivationFormat.Standard
)
def _make_log_backend(backend: Mxfp4MoeBackend):
return f"Using '{backend.value}' Mxfp4 MoE backend."
def _make_log_unsupported(backend: Mxfp4MoeBackend, reason: str | None) -> str:
if reason:
return (
f"Mxfp4 MoE backend '{backend.value}' does not support the "
f"deployment configuration since {reason}."
)
return (
f"Mxfp4 MoE backend '{backend.value}' does not support the "
"deployment configuration."
)
def _return_or_raise(
backend: Mxfp4MoeBackend,
config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts]]:
reason: str | None = None
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason))
# Honor explicit moe_backend (e.g. "marlin", "triton_unfused") before
# falling back to the auto priority list.
runner_backend = config.moe_backend
if runner_backend != "auto":
requested_backend = map_mxfp4_backend(runner_backend)
if (
activation_format == mk.FusedMoEActivationFormat.BatchedExperts
and requested_backend == Mxfp4MoeBackend.MARLIN
):
requested_backend = Mxfp4MoeBackend.BATCHED_MARLIN
return _return_or_raise(
requested_backend,
config,
kMxfp4Static,
_backend_activation_key(requested_backend),
activation_format,
)
# Iterate priority backends: TRTLLM MXFP8, then Triton.
for backend in _get_priority_backends():
activation_key = _backend_activation_key(backend)
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls, config, kMxfp4Static, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
raise NotImplementedError(
"No MXFP4 MoE backend supports the deployment configuration."
)
def mxfp4_round_up_hidden_size_and_intermediate_size(
backend: Mxfp4MoeBackend, hidden_size: int, intermediate_size: int
) -> tuple[int, int]:
"""Round up hidden_size and intermediate_size based on backend requirements."""
if backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
if backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
# DeepGEMM requires M/N/K alignment
intermediate_size = round_up(intermediate_size, 128)
hidden_size = round_up(hidden_size, 128)
elif backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
intermediate_size = round_up(intermediate_size, 128)
if current_platform.is_xpu():
hidden_size = round_up(hidden_size, 128)
@@ -434,6 +553,20 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
]:
"""Convert loaded weights into backend-specific kernel format."""
if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_upcast_e8m0_to_fp32,
)
return (
w13_weight.data,
w2_weight.data,
_upcast_e8m0_to_fp32(w13_weight_scale.data),
_upcast_e8m0_to_fp32(w2_weight_scale.data),
w13_bias,
w2_bias,
)
num_experts = w13_weight.shape[0]
intermediate_size = w13_weight.shape[1] // 2
hidden_size = w13_weight.shape[2] * 2
@@ -738,9 +871,10 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
elif mxfp4_backend in TRITON_BACKENDS:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
assert w13_bias is not None and w2_bias is not None
w13_bias = w13_bias.to(torch.float32)
w2_bias = w2_bias.to(torch.float32)
if w13_bias is not None:
w13_bias = w13_bias.to(torch.float32)
if w2_bias is not None:
w2_bias = w2_bias.to(torch.float32)
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
w13_weight,
@@ -797,15 +931,271 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
)
def convert_weight_to_mxfp4_moe_kernel_format(
mxfp4_backend: Mxfp4MoeBackend,
layer: torch.nn.Module,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w13_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
_cache_permute_indices: dict[torch.Size, torch.Tensor] | None = None,
) -> tuple[
torch.Tensor,
torch.Tensor,
Union[torch.Tensor, "PrecisionConfig"],
Union[torch.Tensor, "PrecisionConfig"],
torch.Tensor | None,
torch.Tensor | None,
]:
"""Convert loaded weights into backend-specific kernel format.
Supports DeepGEMM, TRTLLM MXFP8, Triton and Marlin backends.
"""
if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_upcast_e8m0_to_fp32,
)
# Weights stay as uint8 packed FP4 — no layout change needed.
# Convert E8M0 uint8 scales to float32.
return (
w13_weight.data,
w2_weight.data,
_upcast_e8m0_to_fp32(w13_weight_scale.data),
_upcast_e8m0_to_fp32(w2_weight_scale.data),
w13_bias,
w2_bias,
)
if mxfp4_backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_mxfp4_layer_for_marlin,
)
return prepare_moe_mxfp4_layer_for_marlin(
layer,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
w13_bias,
w2_bias,
)
num_experts = w13_weight.shape[0]
intermediate_size = w13_weight.shape[1] // 2
hidden_size = w13_weight.shape[2] * 2
sf_block_size = 32 # mxfp4 block size
if mxfp4_backend in TRTLLM_BACKENDS:
assert _cache_permute_indices is not None
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
w13_weight = w13_weight.data
w2_weight = w2_weight.data
w13_weight_scale = w13_weight_scale.data
w2_weight_scale = w2_weight_scale.data
if w13_bias is not None:
w13_bias = w13_bias.data.to(torch.float32)
if w2_bias is not None:
w2_bias = w2_bias.data.to(torch.float32)
# Swap w1/w3 and interleave to match TRTLLM SwiGLU convention.
# Standard loading gives contiguous [w1/gate, w3/up].
# TRTLLM kernel expects interleaved [w3_0, w1_0, w3_1, w1_1, ...].
w1_weight = w13_weight[:, :intermediate_size, :]
w3_weight = w13_weight[:, intermediate_size:, :]
w13_weight = torch.stack([w3_weight, w1_weight], dim=2).reshape(
w13_weight.shape
)
w1_scale = w13_weight_scale[:, :intermediate_size, :]
w3_scale = w13_weight_scale[:, intermediate_size:, :]
w13_weight_scale = torch.stack([w3_scale, w1_scale], dim=2).reshape(
w13_weight_scale.shape
)
if w13_bias is not None:
b1 = w13_bias[:, :intermediate_size]
b3 = w13_bias[:, intermediate_size:]
w13_bias = torch.stack([b3, b1], dim=2).reshape(w13_bias.shape)
# Shuffle weights and scaling factors for transposed mma output.
# Permute indices depend only on shape (cached by torch.Size),
# so compute once and apply to all experts via batched indexing.
epilogue_tile_m = 128
# w13 weight permute
w13_perm = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w13_weight[0].view(torch.uint8),
epilogue_tile_m,
).to(w13_weight.device)
w13_weight = w13_weight.view(torch.uint8)[:, w13_perm].contiguous()
# w13 scale permute + interleave
w13_sf_perm = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w13_weight_scale[0].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
).to(w13_weight_scale.device)
w13_s = w13_weight_scale.view(torch.uint8)[:, w13_sf_perm].contiguous()
E, N_s, K_s = w13_s.shape
w13_weight_scale = (
nvfp4_block_scale_interleave(w13_s.reshape(E * N_s, K_s))
.reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
.view(torch.float8_e4m3fn)
)
# w2 weight permute
w2_perm = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w2_weight[0].view(torch.uint8),
epilogue_tile_m,
).to(w2_weight.device)
w2_weight = w2_weight.view(torch.uint8)[:, w2_perm].contiguous()
# w2 scale permute + interleave
w2_sf_perm = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w2_weight_scale[0].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
).to(w2_weight_scale.device)
w2_s = w2_weight_scale.view(torch.uint8)[:, w2_sf_perm].contiguous()
E2, N2_s, K2_s = w2_s.shape
w2_weight_scale = (
nvfp4_block_scale_interleave(w2_s.reshape(E2 * N2_s, K2_s))
.reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
.view(torch.float8_e4m3fn)
)
# w13 bias permute
if w13_bias is not None:
w13_b_perm = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w13_bias[0].reshape(-1, 1),
epilogue_tile_m,
).to(w13_bias.device)
w13_bias = w13_bias.reshape(num_experts, -1, 1)[:, w13_b_perm].reshape(
num_experts, -1
)
# w2 bias permute
if w2_bias is not None:
w2_b_perm = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w2_bias[0].reshape(-1, 1),
epilogue_tile_m,
).to(w2_bias.device)
w2_bias = w2_bias.reshape(num_experts, -1, 1)[:, w2_b_perm].reshape(
num_experts, -1
)
return (
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
w13_bias,
w2_bias,
)
elif mxfp4_backend in TRITON_BACKENDS:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
if mxfp4_backend == Mxfp4MoeBackend.TRITON:
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
shape = w.shape
n = shape[-1]
first = w[..., : n // 2]
second = w[..., n // 2 :]
stacked = torch.stack((first, second), dim=-1)
return stacked.reshape(shape)
w13_weight = shuffle_weight(w13_weight)
w13_weight_scale = shuffle_weight(w13_weight_scale)
if w13_bias is not None:
w13_bias = shuffle_weight(w13_bias.to(torch.float32))
else:
if w13_bias is not None:
w13_bias = w13_bias.to(torch.float32)
if w2_bias is not None:
w2_bias = w2_bias.to(torch.float32)
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
w13_weight,
w13_weight_scale,
)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
w2_weight,
w2_weight_scale,
)
w13_precision_config = PrecisionConfig(
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
)
w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
)
del layer.w13_weight
del layer.w2_weight
return (
w13_weight,
w2_weight,
w13_precision_config,
w2_precision_config,
w13_bias,
w2_bias,
)
else:
raise ValueError(
f"Unsupported mxfp4_backend for Mxfp4MoEMethod: {mxfp4_backend}. "
f"Expected TRTLLM or Triton backend."
)
def make_mxfp4_moe_quant_config(
mxfp4_backend: Mxfp4MoeBackend,
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
swiglu_limit: float | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig | None:
"""Create a FusedMoEQuantConfig for the given MXFP4 backend."""
if mxfp4_backend in (
if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
# DeepGEMM FP4 uses FP8 per-token-group activation quantization
# with block 128, matching the FP8 DeepGEMM path.
_fp8_dtype = current_platform.fp8_dtype()
_block_shape = GroupShape(128, 128)
return FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None),
_a2=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
)
elif mxfp4_backend in (
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
):
@@ -814,6 +1204,9 @@ def make_mxfp4_moe_quant_config(
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
)
elif mxfp4_backend in (
Mxfp4MoeBackend.MARLIN,
@@ -829,6 +1222,9 @@ def make_mxfp4_moe_quant_config(
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
)
else:
return ocp_mx_moe_quant_config(
@@ -837,6 +1233,9 @@ def make_mxfp4_moe_quant_config(
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
)
@@ -228,6 +228,8 @@ class BaseRouter(FusedMoERouter):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute the actual routing logic.
@@ -249,6 +251,8 @@ class BaseRouter(FusedMoERouter):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
@@ -278,7 +282,7 @@ class BaseRouter(FusedMoERouter):
# Step 3: Compute routing (delegated to subclass)
topk_weights, topk_ids = self._compute_routing(
hidden_states, router_logits, indices_type
hidden_states, router_logits, indices_type, input_ids=input_ids
)
# Capture logical ids before EPLB mapping.
@@ -46,6 +46,8 @@ class CustomRoutingRouter(BaseRouter):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using the custom routing function."""
topk_weights, topk_ids = self.custom_routing_function(
@@ -31,6 +31,8 @@ class FusedMoERouter(ABC):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
@@ -4,6 +4,7 @@ import functools
from collections.abc import Callable
import torch
import torch.nn.functional as F
import vllm._custom_ops as ops
import vllm.envs as envs
@@ -56,6 +57,32 @@ def vllm_topk_sigmoid(
return topk_weights, topk_indices
def vllm_topk_softplus_sqrt(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool = False,
e_score_correction_bias: torch.Tensor | None = None,
input_tokens: torch.Tensor | None = None,
hash_indices_table: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, ...]:
ops.topk_hash_softplus_sqrt(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
routed_scaling_factor,
e_score_correction_bias,
input_tokens,
hash_indices_table,
)
return topk_weights, topk_indices
@functools.lru_cache(maxsize=8)
def _aiter_get_num_expert_group(num_experts: int) -> int:
_AITER_MAX_EXPERTS_PER_GROUP = 32
@@ -72,11 +99,14 @@ def _aiter_get_num_expert_group(num_experts: int) -> int:
def fused_topk_bias(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
scoring_func: str,
e_score_correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
scoring_func: str = "softmax",
indices_type: torch.dtype | None = None,
input_tokens: torch.Tensor | None = None,
hash_indices_table: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
):
if not rocm_aiter_ops.is_fused_moe_enabled():
assert hidden_states.size(0) == gating_output.size(0), (
@@ -107,6 +137,8 @@ def fused_topk_bias(
renormalize,
e_score_correction_bias,
)
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
return topk_weights, topk_ids
elif scoring_func == "sigmoid":
topk_weights, topk_ids = vllm_topk_sigmoid(
@@ -117,9 +149,24 @@ def fused_topk_bias(
renormalize,
e_score_correction_bias,
)
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
return topk_weights, topk_ids
elif scoring_func == "sqrtsoftplus":
return vllm_topk_softplus_sqrt(
topk_weights,
topk_ids,
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
input_tokens,
hash_indices_table,
routed_scaling_factor,
)
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
elif rocm_aiter_ops.is_fused_moe_enabled() and scoring_func == "sigmoid":
M = hidden_states.size(0)
num_experts = gating_output.shape[-1]
@@ -143,6 +190,8 @@ def fused_topk_bias(
topk_group=num_expert_group,
need_renorm=renormalize,
)
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
return topk_weights, topk_ids
n_routed_experts = gating_output.shape[-1]
@@ -150,20 +199,31 @@ def fused_topk_bias(
scores = gating_output.softmax(dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
elif scoring_func == "sqrtsoftplus":
scores = F.softplus(gating_output).sqrt()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
scores_for_choice = scores.view(
-1, n_routed_experts
) + e_score_correction_bias.unsqueeze(0)
if e_score_correction_bias is not None:
scores_for_choice = scores.view(
-1, n_routed_experts
) + e_score_correction_bias.unsqueeze(0)
else:
scores_for_choice = scores.view(-1, n_routed_experts)
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = envs.VLLM_BATCH_INVARIANT
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
if hash_indices_table is not None:
topk_indices = hash_indices_table[input_tokens]
else:
use_sorted = envs.VLLM_BATCH_INVARIANT
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[
1
]
topk_weights = scores.gather(1, topk_indices)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_indices.to(
topk_weights = topk_weights.to(torch.float32)
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
return topk_weights, topk_indices.to(
torch.int32 if indices_type is None else indices_type
)
@@ -176,12 +236,14 @@ class FusedTopKBiasRouter(BaseRouter):
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
e_score_correction_bias: torch.Tensor,
scoring_func: str,
e_score_correction_bias: torch.Tensor | None = None,
renormalize: bool = True,
routed_scaling_factor: float = 1.0,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
*,
scoring_func: str = "sigmoid",
hash_indices_table: torch.Tensor | None = None,
):
super().__init__(
top_k=top_k,
@@ -194,6 +256,8 @@ class FusedTopKBiasRouter(BaseRouter):
self.renormalize = renormalize
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
self.scoring_func = scoring_func
self._hash_indices_table = hash_indices_table
@property
def routing_method_type(self) -> RoutingMethodType:
@@ -210,19 +274,23 @@ class FusedTopKBiasRouter(BaseRouter):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using fused top-k with bias."""
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=self.e_score_correction_bias.data,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias.data
if self.e_score_correction_bias is not None
else None,
topk=self.top_k,
renormalize=self.renormalize,
scoring_func=self.scoring_func,
indices_type=indices_type,
input_tokens=input_ids,
hash_indices_table=self._hash_indices_table,
routed_scaling_factor=self.routed_scaling_factor,
)
if self.routed_scaling_factor != 1.0:
topk_weights *= self.routed_scaling_factor
return topk_weights, topk_ids
@@ -151,6 +151,8 @@ class FusedTopKRouter(BaseRouter):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using standard fused top-k."""
topk_weights, topk_ids, token_expert_indices = fused_topk(
@@ -292,6 +292,8 @@ class GroupedTopKRouter(BaseRouter):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing using grouped top-k."""
@@ -308,6 +310,7 @@ class GroupedTopKRouter(BaseRouter):
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias.data,
topk=self.top_k,
renormalize=self.renormalize,
@@ -55,6 +55,7 @@ def create_fused_moe_router(
# zero expert parameters
zero_expert_type: str | None = None,
num_logical_experts: int | None = None,
hash_indices_table: torch.Tensor | None = None,
) -> FusedMoERouter:
"""
Factory function to create the appropriate FusedMoERouter subclass based on
@@ -99,6 +100,9 @@ def create_fused_moe_router(
num_logical_experts: Number of real (non-zero) experts. Required when
zero_expert_type is not None.
Hash Indices Table:
Used to map input_ids to experts, need for Deepseek V4
Returns:
An instance of the appropriate FusedMoERouter subclass
"""
@@ -179,17 +183,20 @@ def create_fused_moe_router(
indices_type_getter=indices_type_getter,
)
if e_score_correction_bias is not None:
assert scoring_func in ["sigmoid", "softmax", "sqrtsoftplus"]
if e_score_correction_bias is not None or hash_indices_table is not None:
return FusedTopKBiasRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
e_score_correction_bias=e_score_correction_bias,
scoring_func=scoring_func,
renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
scoring_func=scoring_func,
hash_indices_table=hash_indices_table,
)
return FusedTopKRouter(
@@ -334,6 +334,8 @@ class RoutingSimulatorRouter(BaseRouter):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Use routing simulator to compute routing."""
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
@@ -72,6 +72,8 @@ class ZeroExpertRouter(BaseRouter):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute routing with full bias, compute zero expert output,
mask zero expert IDs."""
@@ -91,6 +91,7 @@ def _moe_forward(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None,
layer_name: _layer_name_type,
) -> torch.Tensor:
layer = get_layer_from_name(_resolve_layer_name(layer_name))
@@ -99,6 +100,7 @@ def _moe_forward(
hidden_states,
router_logits,
shared_experts_input,
input_ids,
)
@@ -106,6 +108,7 @@ def _moe_forward_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None,
layer_name: _layer_name_type,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -115,6 +118,7 @@ def _moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None,
layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]:
layer = get_layer_from_name(_resolve_layer_name(layer_name))
@@ -123,6 +127,7 @@ def _moe_forward_shared(
hidden_states,
router_logits,
shared_experts_input,
input_ids,
)
@@ -130,6 +135,7 @@ def _moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None,
layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]:
# Output shapes:
@@ -433,6 +439,7 @@ class MoERunner(MoERunnerInterface):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor]:
"""Run expert routing and the fused MoE kernel via the quant method.
@@ -449,11 +456,13 @@ class MoERunner(MoERunnerInterface):
layer=layer,
x=hidden_states,
router_logits=router_logits,
input_ids=input_ids,
)
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
input_ids=input_ids,
)
# Passing shared_experts_input in case SharedExpertsOrder is
@@ -523,6 +532,7 @@ class MoERunner(MoERunnerInterface):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
"""Invoke the fused moe layer.
@@ -565,6 +575,7 @@ class MoERunner(MoERunnerInterface):
hidden_states,
router_logits,
shared_experts_input,
input_ids,
self._encode_layer_name(),
)
@@ -672,6 +683,7 @@ class MoERunner(MoERunnerInterface):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Entry point called by the custom op to run the MoE computation.
@@ -712,6 +724,7 @@ class MoERunner(MoERunnerInterface):
hidden_states=hidden_states,
router_logits=router_logits,
shared_experts_input=shared_experts_input,
input_ids=input_ids,
)
return self._maybe_combine(
@@ -26,6 +26,7 @@ class MoERunnerInterface(ABC):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError
@@ -309,6 +309,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.is_monolithic
if self.unquantized_backend == UnquantizedMoeBackend.CPU:
@@ -4,6 +4,7 @@ import functools
from math import prod
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -384,3 +385,20 @@ def trtllm_moe_pack_topk_ids_weights(
return (topk_ids.to(torch.int32) << 16) | topk_weights.to(torch.bfloat16).view(
torch.int16
)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def swiglu_limit_func(
output: torch.Tensor,
input: torch.Tensor, # first half is gate, second half is up
swiglu_limit: float = 0.0,
) -> None:
d = input.shape[1] // 2
gate = input[:, :d]
up = input[:, d:]
if swiglu_limit > 0:
gate = torch.clamp(gate, max=swiglu_limit)
up = torch.clamp(up, min=-swiglu_limit, max=swiglu_limit)
output.copy_(F.silu(gate) * up)
+450
View File
@@ -0,0 +1,450 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from functools import cache
from typing import TYPE_CHECKING
import torch
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_tilelang
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import direct_register_custom_op
# tilelang is only available on CUDA platforms
if TYPE_CHECKING or current_platform.is_cuda_alike():
if not has_tilelang():
raise ImportError(
"tilelang is required for mhc but is not installed. Install it with "
"`pip install tilelang`."
)
import tilelang
import tilelang.language as T
else:
tilelang = None # type: ignore[assignment]
T = None # type: ignore[assignment]
@cache
def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int:
device_props = torch.cuda.get_device_properties(0)
n_sms = device_props.multi_processor_count
split_k = n_sms // grid_size
if k is not None:
# avoid split_k for small k
num_block_k = cdiv(k, block_k)
split_k = min(split_k, num_block_k // 4)
split_k = max(split_k, 1)
return split_k
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
},
)
def mhc_pre_big_fuse_tilelang(
gemm_out_mul,
gemm_out_sqrsum,
hc_scale,
hc_base,
residual,
post_mix,
comb_mix,
layer_input,
hidden_size: int,
rms_eps: float,
hc_pre_eps: float,
hc_sinkhorn_eps: float,
hc_post_mult_value: float,
sinkhorn_repeat: int,
n_splits: int = 16,
hc_mult: int = 4,
):
"""Deeply fused kernels, everything other than gemm & sqrsum in mHC pre block."""
num_tokens = T.dynamic("num_tokens")
hc_mult3 = hc_mult * (2 + hc_mult)
hidden_block = math.gcd(512, hidden_size)
gemm_out_mul: T.Tensor[[n_splits, num_tokens, hc_mult3], T.float32] # type: ignore[no-redef, valid-type]
gemm_out_sqrsum: T.Tensor[[n_splits, num_tokens], T.float32] # type: ignore[no-redef, valid-type]
hc_scale: T.Tensor[[3], T.float32] # type: ignore[no-redef, valid-type]
hc_base: T.Tensor[[hc_mult3], T.float32] # type: ignore[no-redef, valid-type]
residual: T.Tensor[[num_tokens, hc_mult, hidden_size], T.bfloat16] # type: ignore[no-redef, valid-type]
# outputs
post_mix: T.Tensor[[num_tokens, hc_mult], T.float32] # type: ignore[no-redef, valid-type]
comb_mix: T.Tensor[[num_tokens, hc_mult * hc_mult], T.float32] # type: ignore[no-redef, valid-type]
layer_input: T.Tensor[[num_tokens, hidden_size], T.bfloat16] # type: ignore[no-redef, valid-type]
with T.Kernel(num_tokens, threads=96) as i:
T.pdl_sync()
##################################################################
# _pre_norm_fn_fwd_norm
rms = T.alloc_fragment(1, T.float32)
mixes = T.alloc_fragment(hc_mult3, T.float32)
T.clear(mixes)
rms[0] = 0
for i_split in T.serial(n_splits):
rms[0] += gemm_out_sqrsum[i_split, i]
rms[0] = T.rsqrt(rms[0] / (hc_mult * hidden_size) + rms_eps)
for j in T.Parallel(hc_mult3):
mixes[j] = 0
for i_split in T.serial(n_splits):
mixes[j] += gemm_out_mul[i_split, i, j]
mixes[j] *= rms[0]
mixes_shared = T.alloc_shared(hc_mult3, T.float32)
T.copy(mixes, mixes_shared)
if T.get_thread_binding() < 32:
##################################################################
# _pre_split_mixes_fwd (post & comb)
cm = T.alloc_fragment((hc_mult, hc_mult), T.float32)
for j in T.Parallel(hc_mult):
post_mix[i, j] = (
T.sigmoid(
mixes_shared[j + hc_mult] * hc_scale[1] + hc_base[j + hc_mult]
)
* hc_post_mult_value
)
for j, k in T.Parallel(hc_mult, hc_mult):
cm[j, k] = (
mixes_shared[j * hc_mult + k + hc_mult * 2] * hc_scale[2]
+ hc_base[j * hc_mult + k + hc_mult * 2]
)
##################################################################
# _sinkhorn_fwd
row_sum = T.alloc_fragment(hc_mult, T.float32)
col_sum = T.alloc_fragment(hc_mult, T.float32)
# comb = comb.softmax(-1) + eps
row_max = T.alloc_fragment(hc_mult, T.float32)
T.reduce_max(cm, row_max, dim=1)
for j, k in T.Parallel(hc_mult, hc_mult):
cm[j, k] = T.exp(cm[j, k] - row_max[j])
T.reduce_sum(cm, row_sum, dim=1)
for j, k in T.Parallel(hc_mult, hc_mult):
cm[j, k] = cm[j, k] / row_sum[j] + hc_sinkhorn_eps
# comb = comb / (comb.sum(-2) + eps)
T.reduce_sum(cm, col_sum, dim=0)
for j, k in T.Parallel(hc_mult, hc_mult):
cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps)
for _ in T.serial(sinkhorn_repeat - 1):
# comb = comb / (comb.sum(-1) + eps)
T.reduce_sum(cm, row_sum, dim=1)
for j, k in T.Parallel(hc_mult, hc_mult):
cm[j, k] = cm[j, k] / (row_sum[j] + hc_sinkhorn_eps)
# comb = comb / (comb.sum(-2) + eps)
T.reduce_sum(cm, col_sum, dim=0)
for j, k in T.Parallel(hc_mult, hc_mult):
cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps)
# save comb_mix to global memory
for j, k in T.Parallel(hc_mult, hc_mult):
comb_mix[i, j * hc_mult + k] = cm[j, k]
else:
##################################################################
# _pre_split_mixes_fwd (pre)
pre_mix_shared = T.alloc_shared(hc_mult, T.float32)
for j in T.Parallel(hc_mult):
pre_mix_shared[j] = (
T.sigmoid(
mixes_shared[j] * hc_scale[0] + hc_base[j],
)
+ hc_pre_eps
)
###################################################################
# _pre_apply_mix_fwd
for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=2):
xs = T.alloc_shared((hc_mult, hidden_block), T.float32)
xl = T.alloc_fragment((hc_mult, hidden_block), T.float32)
T.copy(residual[i, 0, i0_h * hidden_block], xs)
T.copy(xs, xl)
ol = T.alloc_fragment(hidden_block, T.float32)
T.clear(ol)
for i_hc in T.serial(hc_mult):
pre = pre_mix_shared[i_hc]
for i1_h in T.Parallel(hidden_block):
ol[i1_h] += pre * xl[i_hc, i1_h]
T.copy(ol, layer_input[i, i0_h * hidden_block])
T.pdl_trigger()
def mhc_pre(
residual: torch.Tensor,
fn: torch.Tensor,
hc_scale: torch.Tensor,
hc_base: torch.Tensor,
rms_eps: float,
hc_pre_eps: float,
hc_sinkhorn_eps: float,
hc_post_mult_value: float,
sinkhorn_repeat: int,
n_splits: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass for mHC pre block.
Args:
residual: shape (..., hc_mult, hidden_size), dtype torch.bfloat16
fn: shape (hc_mult3, hc_mult * hidden_size), dtype torch.float32
hc_scale: shape (3,), dtype torch.float32
hc_base: shape (hc_mult3,), dtype torch.float32
rms_eps: RMS normalization epsilon
hc_pre_eps: pre-mix epsilon
hc_sinkhorn_eps: sinkhorn epsilon
hc_post_mult_value: post-mix multiplier value
sinkhorn_repeat: number of sinkhorn iterations
n_splits: split-k factor;
Returns:
post_mix: shape (..., hc_mult), dtype torch.float32
comb_mix: shape (..., hc_mult, hc_mult), dtype torch.float32
layer_input: shape (..., hidden_size), dtype torch.bfloat16
"""
# Validate shapes
assert residual.dtype == torch.bfloat16
assert fn.dtype == torch.float32
assert hc_scale.dtype == torch.float32
assert hc_base.dtype == torch.float32
hc_mult = residual.shape[-2]
hidden_size = residual.shape[-1]
hc_mult2 = hc_mult * hc_mult
hc_mult3 = hc_mult * 2 + hc_mult2
hc_hidden_size = hc_mult * hidden_size
assert fn.shape[0] == hc_mult3
assert fn.shape[1] == hc_hidden_size
assert hc_scale.shape == (3,)
assert hc_base.shape == (hc_mult3,)
outer_shape = residual.shape[:-2]
residual_flat = residual.view(-1, hc_mult, hidden_size)
num_tokens = residual_flat.shape[0]
fn_flat = fn
# these number are from deepgemm kernel impl
block_k = 64
block_m = 64
n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m))
post_mix = torch.empty(
num_tokens,
hc_mult,
dtype=torch.float32,
device=residual.device,
)
comb_mix = torch.empty(
num_tokens,
hc_mult2,
dtype=torch.float32,
device=residual.device,
)
layer_input = torch.empty(
num_tokens,
hidden_size,
dtype=torch.bfloat16,
device=residual.device,
)
gemm_out_mul = torch.empty(
n_splits,
num_tokens,
hc_mult3,
dtype=torch.float32,
device=residual.device,
)
gemm_out_sqrsum = torch.empty(
n_splits,
num_tokens,
dtype=torch.float32,
device=residual.device,
)
from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm
tf32_hc_prenorm_gemm(
residual_flat.view(num_tokens, hc_mult * hidden_size),
fn_flat,
gemm_out_mul,
gemm_out_sqrsum,
n_splits,
)
mhc_pre_big_fuse_tilelang(
gemm_out_mul,
gemm_out_sqrsum,
hc_scale,
hc_base,
residual_flat,
post_mix,
comb_mix,
layer_input,
hidden_size,
rms_eps,
hc_pre_eps,
hc_sinkhorn_eps,
hc_post_mult_value,
sinkhorn_repeat,
n_splits,
hc_mult,
)
post_mix = post_mix.view(*outer_shape, hc_mult, 1)
comb_mix = comb_mix.view(*outer_shape, hc_mult, hc_mult)
layer_input = layer_input.view(*outer_shape, hidden_size)
return post_mix, comb_mix, layer_input
def _mhc_pre_fake(
residual: torch.Tensor,
fn: torch.Tensor,
hc_scale: torch.Tensor,
hc_base: torch.Tensor,
rms_eps: float,
hc_pre_eps: float,
hc_sinkhorn_eps: float,
hc_post_mult_value: float,
sinkhorn_repeat: int,
n_splits: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hc_mult = residual.shape[-2]
hidden_size = residual.shape[-1]
outer_shape = residual.shape[:-2]
# Create empty tensors with correct shapes for meta device / shape inference
post_mix = torch.empty(
*outer_shape,
hc_mult,
1,
dtype=torch.float32,
device=residual.device,
)
comb_mix = torch.empty(
*outer_shape,
hc_mult,
hc_mult,
dtype=torch.float32,
device=residual.device,
)
layer_input = torch.empty(
*outer_shape,
hidden_size,
dtype=torch.bfloat16,
device=residual.device,
)
return post_mix, comb_mix, layer_input
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
},
)
def mhc_post_tilelang(
a,
b,
c,
d,
x,
hc: int,
hidden: int,
n_thr: int = 128,
h_blk: int = 1024,
) -> tilelang.JITKernel:
# rename for shorter code
n = T.dynamic("num_tokens")
h = hidden
h_blk = math.gcd(hidden, h_blk)
a: T.Tensor((n, hc, hc), T.float32) # type: ignore[no-redef, valid-type]
b: T.Tensor((n, hc, h), T.bfloat16) # type: ignore[no-redef, valid-type]
c: T.Tensor((n, hc), T.float32) # type: ignore[no-redef, valid-type]
d: T.Tensor((n, h), T.bfloat16) # type: ignore[no-redef, valid-type]
x: T.Tensor((n, hc, h), T.bfloat16) # type: ignore[no-redef, valid-type]
with T.Kernel(n, threads=n_thr) as i_n:
x_shared = T.alloc_shared((hc, h_blk), T.bfloat16)
b_shared = T.alloc_shared((hc, h_blk), T.bfloat16)
d_shared = T.alloc_shared(h_blk, T.bfloat16)
x_local = T.alloc_fragment((hc, h_blk), T.float32)
b_local = T.alloc_fragment((hc, h_blk), T.float32)
d_local = T.alloc_fragment(h_blk, T.float32)
a_local = T.alloc_fragment((hc, hc), T.float32)
c_local = T.alloc_fragment(hc, T.float32)
T.pdl_sync()
T.copy(a[i_n, 0, 0], a_local)
T.copy(c[i_n, 0], c_local)
for i0_h in T.Pipelined(T.ceildiv(h, h_blk), num_stages=2):
T.copy(b[i_n, 0, i0_h * h_blk], b_shared)
T.copy(d[i_n, i0_h * h_blk], d_shared)
T.copy(b_shared, b_local)
T.copy(d_shared, d_local)
for i_hco, i1_h in T.Parallel(hc, h_blk):
x_local[i_hco, i1_h] = c_local[i_hco] * d_local[i1_h]
for i_hci in T.serial(hc):
x_local[i_hco, i1_h] += a_local[i_hci, i_hco] * b_local[i_hci, i1_h]
T.copy(x_local, x_shared)
T.copy(x_shared, x[i_n, 0, i0_h * h_blk])
T.pdl_trigger()
def mhc_post(
x: torch.Tensor,
residual: torch.Tensor,
post_layer_mix: torch.Tensor,
comb_res_mix: torch.Tensor,
) -> torch.Tensor:
out = torch.empty_like(residual)
mhc_post_tilelang(
comb_res_mix,
residual,
post_layer_mix.squeeze(-1),
x,
out,
residual.shape[-2],
residual.shape[-1],
)
return out
def _mhc_post_fake(
x: torch.Tensor,
residual: torch.Tensor,
post_layer_mix: torch.Tensor,
comb_res_mix: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(residual)
direct_register_custom_op(
op_name="mhc_pre",
op_func=mhc_pre,
mutates_args=[],
fake_impl=_mhc_pre_fake,
)
direct_register_custom_op(
op_name="mhc_post",
op_func=mhc_post,
mutates_args=[],
fake_impl=_mhc_post_fake,
)
@@ -32,6 +32,7 @@ QuantizationMethods = Literal[
"inc",
"mxfp4",
"gpt_oss_mxfp4",
"deepseek_v4_fp8",
"cpu_awq",
"online",
# Below are values of the OnlineQuantScheme enum, specified as strings to
@@ -112,6 +113,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
# lazy import to avoid triggering `torch.compile` too early
from vllm.config.quantization import OnlineQuantScheme
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
from vllm.model_executor.models.deepseek_v4 import DeepseekV4FP8Config
from .awq import AWQConfig
from .awq_marlin import AWQMarlinConfig
@@ -163,6 +165,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"inc": INCConfig,
"mxfp4": Mxfp4Config,
"gpt_oss_mxfp4": GptOssMxfp4Config,
"deepseek_v4_fp8": DeepseekV4FP8Config,
"cpu_awq": CPUAWQConfig,
"humming": HummingConfig,
"online": OnlineQuantizationConfig,
@@ -265,6 +265,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.is_monolithic
assert self.moe_kernel is not None
@@ -305,6 +305,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet."
assert layer.activation in (
@@ -367,6 +367,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
@@ -168,6 +168,7 @@ class CompressedTensorsW8A8Mxfp8MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
@@ -517,6 +517,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.kernel_backend == "Flashinfer"
return flashinfer_trtllm_mxint4_moe(
@@ -269,6 +269,7 @@ class Fp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.is_scale_e8m0 = getattr(quant_config, "is_scale_e8m0", False)
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype
@@ -362,6 +363,7 @@ class Fp8LinearMethod(LinearMethodBase):
input_size_per_partition,
self.weight_block_size,
weight_loader,
scale_dtype=(torch.float8_e8m0fnu if self.is_scale_e8m0 else None),
)
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)
@@ -866,6 +868,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.is_monolithic
assert self.moe_kernel is not None
@@ -950,6 +950,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.is_monolithic
assert self.moe_kernel is not None
@@ -1442,6 +1443,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.is_monolithic
assert self.moe_kernel is not None
@@ -1920,6 +1922,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from flashinfer.fused_moe.core import (
ActivationType,
@@ -20,10 +20,12 @@ from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
TRITON_BACKENDS,
Mxfp4MoeBackend,
convert_gpt_oss_weight_to_mxfp4_moe_kernel_format,
convert_weight_to_mxfp4_moe_kernel_format,
make_mxfp4_moe_kernel,
make_mxfp4_moe_quant_config,
mxfp4_round_up_hidden_size_and_intermediate_size,
select_gpt_oss_mxfp4_moe_backend,
select_mxfp4_moe_backend,
)
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -217,6 +219,7 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase):
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w13_weight_scale.quant_method = "block"
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
@@ -242,6 +245,7 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase):
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w2_weight_scale.quant_method = "block"
if self.moe.has_bias:
w13_bias = torch.nn.Parameter(
@@ -397,6 +401,9 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase):
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
gemm1_alpha=1.702,
gemm1_beta=1.0,
swiglu_limit=7.0,
)
def select_gemm_impl(
@@ -437,6 +444,332 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
router_logits=router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
class Mxfp4MoEMethod(FusedMoEMethodBase):
"""MXFP4 MoE quantization method."""
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.weight_dtype = "mxfp4"
self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe)
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
)
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
self.moe_kernel: mk.FusedMoEKernel | None = None
# Used for triton kernel precision configs
self.w13_precision_config = None
self.w2_precision_config = None
@property
def skip_forward_padding(self) -> bool:
# SM100_FI_MXFP4_MXFP8_TRTLLM supports padding with mxfp8 quant
# so can skip the padding in the forward before applying the moe method
return self.mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8
def maybe_roundup_sizes(
self,
hidden_size: int,
intermediate_size_per_partition: int,
act_dtype: torch.dtype,
moe_parallel_config: FusedMoEParallelConfig,
) -> tuple[int, int]:
hidden_size, intermediate_size_per_partition = super().maybe_roundup_sizes(
hidden_size=hidden_size,
intermediate_size_per_partition=intermediate_size_per_partition,
act_dtype=act_dtype,
moe_parallel_config=moe_parallel_config,
)
return mxfp4_round_up_hidden_size_and_intermediate_size(
self.mxfp4_backend, hidden_size, intermediate_size_per_partition
)
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
self.num_experts = num_experts
weight_dtype = torch.uint8
scale_dtype = torch.uint8
mxfp4_block = 32
layer.params_dtype = params_dtype
layer.num_experts = num_experts
self.intermediate_size = intermediate_size_per_partition
self.hidden_size = hidden_size
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // 2,
dtype=weight_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // mxfp4_block,
dtype=scale_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w13_weight_scale.quant_method = "block"
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition // 2,
dtype=weight_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition // mxfp4_block,
dtype=scale_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w2_weight_scale.quant_method = "block"
if self.moe.has_bias:
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.bfloat16,
),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
w2_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
dtype=torch.bfloat16,
),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
def _setup_kernel(
self,
layer: FusedMoE,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
w13_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> None:
num_experts = self.num_experts
intermediate_size = self.intermediate_size
hidden_size = self.hidden_size
sf_block_size = 32
# Shape assertions
assert (
w13.dim() == 3
and w13.shape[0] == num_experts
and w13.shape[1] == intermediate_size * 2
and w13.shape[2] == hidden_size // 2
)
assert (
w13_scale.dim() == 3
and w13_scale.shape[0] == num_experts
and w13_scale.shape[1] == intermediate_size * 2
and w13_scale.shape[2] == hidden_size // sf_block_size
)
assert (
w2.dim() == 3
and w2.shape[0] == num_experts
and w2.shape[1] == hidden_size
and w2.shape[2] == intermediate_size // 2
)
assert (
w2_scale.dim() == 3
and w2_scale.shape[1] == hidden_size
and w2_scale.shape[2] == intermediate_size // sf_block_size
)
if w13_bias is not None:
assert (
w13_bias.dim() == 2
and w13_bias.shape[0] == num_experts
and w13_bias.shape[1] == intermediate_size * 2
)
if w2_bias is not None:
assert (
w2_bias.dim() == 2
and w2_bias.shape[0] == num_experts
and w2_bias.shape[1] == hidden_size
)
# Convert weights to kernel format
w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = (
convert_weight_to_mxfp4_moe_kernel_format(
mxfp4_backend=self.mxfp4_backend,
layer=layer,
w13_weight=w13,
w2_weight=w2,
w13_weight_scale=w13_scale,
w2_weight_scale=w2_scale,
w13_bias=w13_bias,
w2_bias=w2_bias,
_cache_permute_indices=self._cache_permute_indices,
)
)
# For TRITON backends, weights are wrapped tensors from triton_kernels
# that don't support .detach(). Manually assign parameters.
if self.mxfp4_backend not in TRITON_BACKENDS:
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w2_weight_scale", w2_scale)
else:
layer.w13_weight = w13
layer.w2_weight = w2
self.w13_precision_config = w13_scale
self.w2_precision_config = w2_scale
if w13_bias is not None and w2_bias is not None:
replace_parameter(layer, "w13_bias", w13_bias)
replace_parameter(layer, "w2_bias", w2_bias)
# Build quant config
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
# Build kernel (modular or monolithic)
if self.moe_quant_config is not None and self.experts_cls is not None:
self.moe_kernel = make_mxfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
mxfp4_backend=self.mxfp4_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def process_weights_after_loading(self, layer):
w13 = layer.w13_weight
w2 = layer.w2_weight
w13_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
w13_bias = getattr(layer, "w13_bias", None)
w2_bias = getattr(layer, "w2_bias", None)
if self.mxfp4_backend == Mxfp4MoeBackend.NONE:
return
self._setup_kernel(layer, w13, w2, w13_scale, w2_scale, w13_bias, w2_bias)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
w1_bias = getattr(layer, "w13_bias", None)
w2_bias = getattr(layer, "w2_bias", None)
swiglu_limit = getattr(layer, "swiglu_limit", None)
if self.mxfp4_backend in TRITON_BACKENDS:
assert self.w13_precision_config is not None
assert self.w2_precision_config is not None
w1_scale = self.w13_precision_config
w2_scale = self.w2_precision_config
return make_mxfp4_moe_quant_config(
mxfp4_backend=self.mxfp4_backend,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
swiglu_limit=swiglu_limit,
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEExpertsModular:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel "
"initialization logic. This function should not be called."
)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
assert not self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
shared_experts_input=shared_experts_input,
)
def apply_monolithic(
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.is_monolithic
assert self.moe_kernel is not None
@@ -130,6 +130,7 @@ class OnlineMoEMethodBase(FusedMoEMethodBase):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert self.moe_kernel is not None
@@ -1457,6 +1457,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.is_monolithic
assert self.moe_kernel is not None
@@ -149,6 +149,148 @@ def _per_token_group_quant_fp8(
tl.store(y_s_ptr, y_s)
@triton.jit
def _silu_mul_quant_fp8_packed_kernel(
input_ptr,
output_q_ptr,
output_scale_ptr,
M,
input_stride_m,
output_q_stride_m,
output_scale_stride_k,
clamp_limit,
N: tl.constexpr,
NUM_GROUPS: tl.constexpr,
fp8_min: tl.constexpr,
fp8_max: tl.constexpr,
GROUP_SIZE: tl.constexpr,
BLOCK_M: tl.constexpr,
HAS_CLAMP: tl.constexpr,
):
N_2: tl.constexpr = N // 2
pid_pack = tl.program_id(0)
pid_m = tl.program_id(1)
m_offset = pid_m * BLOCK_M
if m_offset >= M:
return
offs_m = tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, GROUP_SIZE)
row_mask = (m_offset + offs_m) < M
base_row_offset = (m_offset + offs_m[:, None]) * input_stride_m
base_out_offset = (m_offset + offs_m[:, None]) * output_q_stride_m
packed_scale = tl.zeros((BLOCK_M,), dtype=tl.int32)
for pack_idx in tl.static_range(4):
group_id = pid_pack * 4 + pack_idx
if group_id < NUM_GROUPS:
n_offset = group_id * GROUP_SIZE
act_ptrs = input_ptr + base_row_offset + n_offset + offs_n[None, :]
act_in = tl.load(act_ptrs, mask=row_mask[:, None], other=0.0)
mul_ptrs = act_ptrs + N_2
mul_in = tl.load(mul_ptrs, mask=row_mask[:, None], other=0.0)
act_f32 = act_in.to(tl.float32)
mul_f32 = mul_in.to(tl.float32)
if HAS_CLAMP:
act_f32 = tl.minimum(act_f32, clamp_limit)
mul_f32 = tl.clamp(mul_f32, -clamp_limit, clamp_limit)
y = (act_f32 / (1.0 + tl.exp(-act_f32))) * mul_f32
# Round through bf16 to match unfused precision path
y = y.to(tl.bfloat16).to(tl.float32)
absmax = tl.max(tl.abs(y), axis=1)
scale_raw = tl.maximum(absmax / fp8_max, 1e-10)
exponent = tl.ceil(tl.log2(scale_raw))
scale = tl.math.exp2(exponent)
y_q = tl.clamp(y / scale[:, None], fp8_min, fp8_max)
out_q_ptrs = output_q_ptr + base_out_offset + n_offset + offs_n[None, :]
tl.store(
out_q_ptrs,
y_q.to(output_q_ptr.dtype.element_ty),
mask=row_mask[:, None],
)
exponent_biased = tl.clamp(exponent + 127.0, 0.0, 255.0).to(tl.int32)
packed_scale = packed_scale | (exponent_biased << (pack_idx * 8))
scale_ptrs = output_scale_ptr + pid_pack * output_scale_stride_k + m_offset + offs_m
tl.store(scale_ptrs, packed_scale, mask=row_mask)
def silu_mul_quant_fp8_packed_triton(
input: torch.Tensor,
group_size: int = 128,
output_q: torch.Tensor | None = None,
clamp_limit: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert input.dim() == 2
assert input.is_contiguous()
M, N = input.shape
N_2 = N // 2
assert N_2 % group_size == 0
fp8_dtype = torch.float8_e4m3fn
finfo = torch.finfo(fp8_dtype)
fp8_min, fp8_max = finfo.min, finfo.max
num_groups_per_row = N_2 // group_size
num_packed_groups = (num_groups_per_row + 3) // 4
tma_aligned_M = ((M + 3) // 4) * 4
if output_q is None:
output_q = torch.empty((M, N_2), dtype=fp8_dtype, device=input.device)
output_scale_packed = torch.zeros(
(num_packed_groups, tma_aligned_M),
dtype=torch.int32,
device=input.device,
).T[:M, :]
BLOCK_M = 8
grid = (num_packed_groups, (M + BLOCK_M - 1) // BLOCK_M)
num_warps = max(4, group_size // 32)
num_stages = 2
has_clamp = clamp_limit is not None
_silu_mul_quant_fp8_packed_kernel[grid](
input,
output_q,
output_scale_packed,
M,
input.stride(0),
output_q.stride(0),
output_scale_packed.stride(1),
clamp_limit if has_clamp else 0.0,
N=N,
NUM_GROUPS=num_groups_per_row,
fp8_min=fp8_min,
fp8_max=fp8_max,
GROUP_SIZE=group_size,
BLOCK_M=BLOCK_M,
HAS_CLAMP=has_clamp,
num_warps=num_warps,
num_stages=num_stages,
)
return output_q, output_scale_packed
@triton.jit
def _silu_mul_per_token_group_quant_fp8_colmajor(
y_ptr, # [M, N]
@@ -823,19 +965,65 @@ def requant_weight_ue8m0_inplace(
s_old.copy_(s_requant)
def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor:
"""Upcast E8M0 (exponent-only) scale to float32.
E8M0 stores only the 8-bit biased exponent (bias=127). To convert
to float32 we place those 8 bits into the exponent field of an
IEEE-754 float32 (bits 23-30) with sign=0 and mantissa=0.
"""
exp_bits = scale.view(torch.uint8).to(torch.int32)
fp32_bits = exp_bits << 23
return fp32_bits.view(torch.float32)
def deepgemm_post_process_fp8_weight_block(
wq: torch.Tensor, ws: torch.Tensor, quant_block_shape: tuple[int], use_e8m0: bool
wq: torch.Tensor,
ws: torch.Tensor,
quant_block_shape: tuple[int, ...],
use_e8m0: bool,
is_bmm: bool = False,
bmm_batch_size: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
assert wq.dtype == torch.float8_e4m3fn, (
"Expected quantized tensor dtype "
f"to be torch.float8_e4m3fn, got {wq.dtype} instead."
)
assert ws.dtype == torch.float32, (
f"Expected tensor scales dtype to be torch.float32, got {ws.dtype} instead"
)
if use_e8m0:
requant_weight_ue8m0_inplace(wq, ws, block_size=quant_block_shape)
if ws.dtype == torch.float8_e8m0fnu:
# Scales already in E8M0 from checkpoint — upcast to fp32
# and skip requantization (weights already have power-of-two scales).
ws = _upcast_e8m0_to_fp32(ws)
else:
assert ws.dtype == torch.float32, (
f"Expected tensor scales dtype to be torch.float32 or "
f"torch.float8_e8m0fnu, got {ws.dtype} instead"
)
if use_e8m0:
requant_weight_ue8m0_inplace(wq, ws, block_size=quant_block_shape)
if is_bmm:
# Reshape 2D weight/scale to 3D for grouped BMM (einsum):
# wq: (g*r, d) -> (g, r, d)
# ws: (g*r/128, d/128) -> (g, r/128, d/128)
g = bmm_batch_size
assert wq.ndim == 2 and ws.ndim == 2
d = wq.size(1)
r = wq.size(0) // g
wq = wq.view(g, r, d)
ws = ws.view(g, r // quant_block_shape[0], d // quant_block_shape[1])
# Pre-transform scale with recipe=(1, 128, 128) to broadcast + pack
# into TMA-aligned UE8M0 (INT32) layout. At runtime fp8_einsum uses
# recipe=(1, 1, 128) which sees INT dtype and skips re-transform.
dg_ws = transform_sf_into_required_layout(
sf=ws,
mn=r,
k=d,
recipe=(1, quant_block_shape[0], quant_block_shape[1]),
num_groups=g,
is_sfa=False,
)
return wq, dg_ws
original_ndim = wq.ndim
if wq.ndim == 2:
@@ -984,11 +1172,13 @@ def create_fp8_scale_parameter(
input_size_per_partition: int,
block_size: list[int] | None,
weight_loader: Callable | None,
scale_dtype: torch.dtype | None = None,
) -> torch.nn.Parameter:
"""Create scale parameter based on quantization strategy."""
dtype = scale_dtype if scale_dtype is not None else torch.float32
if parameter_type == ChannelQuantScaleParameter:
scale = parameter_type(
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
data=torch.empty((sum(output_partition_sizes), 1), dtype=dtype),
output_dim=0,
weight_loader=weight_loader,
)
@@ -1000,7 +1190,7 @@ def create_fp8_scale_parameter(
data=torch.empty(
(output_size_per_partition + block_n - 1) // block_n,
(input_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
dtype=dtype,
),
input_dim=1,
output_dim=0,
@@ -1008,13 +1198,14 @@ def create_fp8_scale_parameter(
)
elif parameter_type == PerTensorScaleParameter:
scale = parameter_type(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
data=torch.empty(len(output_partition_sizes), dtype=dtype),
weight_loader=weight_loader,
)
else:
raise ValueError(f"Unknown parameter type: {parameter_type}")
scale[:] = torch.finfo(torch.float32).min
if dtype == torch.float32:
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"})
return scale
@@ -7,7 +7,10 @@ from typing import Any
import torch
from .base import RotaryEmbedding
from .deepseek_scaling_rope import DeepseekScalingRotaryEmbedding
from .deepseek_scaling_rope import (
DeepseekScalingRotaryEmbedding,
DeepseekV4ScalingRotaryEmbedding,
)
from .dual_chunk_rope import DualChunkRotaryEmbedding
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding
@@ -60,11 +63,13 @@ def get_rope(
rope_parameters = rope_parameters or {}
base = rope_parameters.get("rope_theta", 10000)
scaling_type = rope_parameters.get("rope_type", "default")
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
if partial_rotary_factor <= 0.0 or partial_rotary_factor > 1.0:
raise ValueError(f"{partial_rotary_factor=} must be between 0.0 and 1.0")
rotary_dim = int(head_size * partial_rotary_factor)
if rotary_dim := rope_parameters.get("rope_dim", None):
pass
else:
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
if partial_rotary_factor <= 0.0 or partial_rotary_factor > 1.0:
raise ValueError(f"{partial_rotary_factor=} must be between 0.0 and 1.0")
rotary_dim = int(head_size * partial_rotary_factor)
key = (
head_size,
@@ -289,7 +294,11 @@ def get_rope(
"mscale_all_dim",
)
}
rotary_emb = DeepseekScalingRotaryEmbedding(
if rope_parameters.get("is_deepseek_v4", False):
cls = DeepseekV4ScalingRotaryEmbedding
else:
cls = DeepseekScalingRotaryEmbedding
rotary_emb = cls(
head_size,
rotary_dim,
original_max_position,

Some files were not shown because too many files have changed in this diff Show More