mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Feat] DeepSeek V4 Rebased (#40860)
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai> Signed-off-by: Woosuk Kwon <woosuk@inferact.ai> Signed-off-by: qizixi <zixi@inferact.ai> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Yongye Zhu <yongye@inferact.ai> Co-authored-by: Simon Mo <simon@inferact.ai> Co-authored-by: Bugen Zhao <i@bugenzhao.com> Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Roy Wang <yasong.wang@inferact.ai> Co-authored-by: Woosuk Kwon <woosuk@inferact.ai> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Zhewen Li <jerven.vllm@gmail.com> Co-authored-by: Zijing Liu <liuzijing2014@gmail.com> Co-authored-by: khluu <khluu000@gmail.com> Co-authored-by: qizixi <zixi@inferact.ai> Co-authored-by: Zhewen Li <zhewenli@inferact.ai>
This commit is contained in:
+5
-2
@@ -310,7 +310,9 @@ set(VLLM_EXT_SRC
|
||||
"csrc/torch_bindings.cpp")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_EXT_SRC "csrc/minimax_reduce_rms_kernel.cu")
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/minimax_reduce_rms_kernel.cu"
|
||||
"csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")
|
||||
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
|
||||
|
||||
@@ -1051,7 +1053,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC
|
||||
"csrc/moe/moe_wna16.cu"
|
||||
"csrc/moe/grouped_topk_kernels.cu"
|
||||
"csrc/moe/router_gemm.cu")
|
||||
"csrc/moe/router_gemm.cu"
|
||||
"csrc/moe/topk_softplus_sqrt_kernels.cu")
|
||||
endif()
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
@@ -20,7 +20,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
deepgemm
|
||||
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM.git
|
||||
GIT_TAG 477618cd51baffca09c4b0b87e97c03fe827ef03
|
||||
GIT_TAG 891d57b4db1071624b5c8fa0d1e51cb317fa709f
|
||||
GIT_SUBMODULES "third-party/cutlass" "third-party/fmt"
|
||||
GIT_PROGRESS TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
@@ -120,6 +120,11 @@ if(DEEPGEMM_ARCHS)
|
||||
COMPONENT _deep_gemm_C
|
||||
FILES_MATCHING PATTERN "*.py")
|
||||
|
||||
install(DIRECTORY "${deepgemm_SOURCE_DIR}/deep_gemm/mega/"
|
||||
DESTINATION vllm/third_party/deep_gemm/mega
|
||||
COMPONENT _deep_gemm_C
|
||||
FILES_MATCHING PATTERN "*.py")
|
||||
|
||||
# Generate envs.py (normally generated by DeepGEMM's setup.py build step)
|
||||
file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/deep_gemm_envs.py"
|
||||
"# Pre-installed environment variables\npersistent_envs = dict()\n")
|
||||
|
||||
@@ -19,7 +19,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
flashmla
|
||||
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
|
||||
GIT_TAG 692917b1cda61b93ac9ee2d846ec54e75afe87b1
|
||||
GIT_TAG a6ec2ba7bd0a7dff98b3f4d3e6b52b159c48d78b
|
||||
GIT_PROGRESS TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
|
||||
@@ -178,7 +178,12 @@ void rotary_embedding_gptj_impl(
|
||||
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
std::optional<torch::Tensor> key, int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox) {
|
||||
torch::Tensor& cos_sin_cache, bool is_neox,
|
||||
int64_t rope_dim_offset, bool inverse) {
|
||||
TORCH_CHECK(rope_dim_offset == 0,
|
||||
"rope_dim_offset != 0 is not supported on CPU");
|
||||
TORCH_CHECK(!inverse, "inverse rotary embedding is not supported on CPU");
|
||||
|
||||
int num_tokens = positions.numel();
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(-1) / head_size;
|
||||
|
||||
@@ -263,7 +263,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def(
|
||||
"rotary_embedding(Tensor positions, Tensor! query,"
|
||||
" Tensor!? key, int head_size,"
|
||||
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
||||
" Tensor cos_sin_cache, bool is_neox, int "
|
||||
"rope_dim_offset=0, bool inverse=False) -> ()");
|
||||
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
|
||||
|
||||
// Quantization
|
||||
|
||||
@@ -0,0 +1,477 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
* SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
*
|
||||
* Horizontally-fused DeepseekV4-MLA kernel:
|
||||
* - Q side: per-head RMSNorm (no weight) + GPT-J RoPE on last ROPE_DIM
|
||||
* - KV side: GPT-J RoPE on last ROPE_DIM + UE8M0 FP8 quant on NoPE + paged
|
||||
* cache insert
|
||||
*
|
||||
* Structured after `applyMLARopeAndAssignQKVKernelGeneration` in
|
||||
* TensorRT-LLM's mlaKernels.cu: one kernel, one grid, with head-slot
|
||||
* dispatch choosing Q vs KV work per warp. The per-warp RMSNorm/RoPE
|
||||
* skeleton is adapted from vllm-deepseek_v4's existing
|
||||
* `fusedQKNormRopeKernel` (csrc/fused_qknorm_rope_kernel.cu).
|
||||
*
|
||||
* Assumptions (hard-coded for DeepseekV4 attention):
|
||||
* HEAD_DIM = 512
|
||||
* ROPE_DIM = 64 (RoPE applied to dims [NOPE_DIM, HEAD_DIM))
|
||||
* NOPE_DIM = 448
|
||||
* QUANT_BLOCK = 64 (UE8M0 FP8 quant block)
|
||||
* FP8_MAX = 448.0f
|
||||
* is_neox=false (GPT-J interleaved pairs)
|
||||
* cos_sin_cache layout [max_pos, rope_dim] = cos || sin (cos first, sin
|
||||
* second along last dim; each half is rope_dim/2 = 32 values)
|
||||
*
|
||||
* Cache layout per paged-cache block (block_size tokens):
|
||||
* [0, bs*576): token data, 448 fp8 + 128 bf16 each
|
||||
* [bs*576, bs*576 + bs*8): UE8M0 scales, 7 real + 1 pad per token
|
||||
*/
|
||||
|
||||
#include <cmath>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/cuda.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#include "type_convert.cuh"
|
||||
|
||||
#ifndef FINAL_MASK
|
||||
#define FINAL_MASK 0xffffffffu
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
namespace deepseek_v4_fused_ops {
|
||||
|
||||
namespace {
|
||||
inline int getSMVersion() {
|
||||
auto* props = at::cuda::getCurrentDeviceProperties();
|
||||
return props->major * 10 + props->minor;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// Constants
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
constexpr int kHeadDim = 512;
|
||||
constexpr int kRopeDim = 64;
|
||||
constexpr int kNopeDim = kHeadDim - kRopeDim; // 448
|
||||
constexpr int kQuantBlock = 64;
|
||||
constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7
|
||||
constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad)
|
||||
constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576
|
||||
constexpr float kFp8Max = 448.0f;
|
||||
|
||||
// Per-warp layout: 32 lanes × 16 elems/lane = 512 elems = HEAD_DIM.
|
||||
constexpr int kNumLanes = 32;
|
||||
constexpr int kElemsPerLane = kHeadDim / kNumLanes; // 16
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// Small inline helpers
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
__device__ __forceinline__ float warp4MaxAbs(float val) {
|
||||
// Reduce absolute max across 4 consecutive lanes (lane id & 3 group).
|
||||
float peer = __shfl_xor_sync(FINAL_MASK, val, 1);
|
||||
val = fmaxf(val, peer);
|
||||
peer = __shfl_xor_sync(FINAL_MASK, val, 2);
|
||||
val = fmaxf(val, peer);
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ float warpSum(float val) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// Kernel
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
//
|
||||
// Grid: 1D, gridDim.x = ceil(num_tokens_full * (num_heads_q + 1) /
|
||||
// warps_per_block) Block: blockDim.x = 256 threads (8 warps per block) Each
|
||||
// warp handles one (token, head_slot) pair. head_slot < num_heads_q →
|
||||
// Q branch (RMSNorm + RoPE, in place) head_slot == num_heads_q → KV
|
||||
// branch (RoPE + UE8M0 quant + insert)
|
||||
//
|
||||
// With DP padding, q/kv/position_ids can have more rows than slot_mapping.
|
||||
// The Q branch covers all `num_tokens_full` rows (downstream attention uses
|
||||
// them). The KV branch only inserts the first `num_tokens_insert` tokens
|
||||
// (= slot_mapping length) into the paged cache.
|
||||
//
|
||||
template <typename scalar_t_in>
|
||||
__global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel(
|
||||
scalar_t_in* __restrict__ q_inout, // [N, H, 512] bf16, in place
|
||||
scalar_t_in const* __restrict__ kv_in, // [N, 512] bf16
|
||||
uint8_t* __restrict__ k_cache, // [num_blocks, block_stride]
|
||||
int64_t const* __restrict__ slot_mapping, // [num_tokens_insert] i64
|
||||
int64_t const* __restrict__ position_ids, // [N] i64
|
||||
float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32
|
||||
float const eps,
|
||||
int const num_tokens_full, // = q.size(0) = kv.size(0)
|
||||
int const num_tokens_insert, // = slot_mapping.size(0), ≤ num_tokens_full
|
||||
int const num_heads_q, // H
|
||||
int const cache_block_size, // tokens per paged-cache block
|
||||
int const kv_block_stride) { // bytes per paged-cache block
|
||||
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
|
||||
// BF16 _typeConvert specialization is unavailable on pre-Ampere. The
|
||||
// DeepseekV4 kernel only runs with bf16 inputs in practice, so compile a
|
||||
// no-op stub for sm_70/sm_75 to keep multi-arch builds happy.
|
||||
if constexpr (std::is_same_v<scalar_t_in, c10::BFloat16>) {
|
||||
return;
|
||||
} else {
|
||||
#endif
|
||||
using Converter = vllm::_typeConvert<scalar_t_in>;
|
||||
|
||||
int const warpsPerBlock = blockDim.x / 32;
|
||||
int const warpId = threadIdx.x / 32;
|
||||
int const laneId = threadIdx.x % 32;
|
||||
int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;
|
||||
|
||||
int const total_slots_per_token = num_heads_q + 1;
|
||||
int const tokenIdx = globalWarpIdx / total_slots_per_token;
|
||||
int const slotIdx = globalWarpIdx % total_slots_per_token;
|
||||
if (tokenIdx >= num_tokens_full) return;
|
||||
|
||||
bool const isKV = (slotIdx == num_heads_q);
|
||||
// KV branch: skip DP-padded tokens (no slot reserved for them).
|
||||
if (isKV && tokenIdx >= num_tokens_insert) return;
|
||||
|
||||
// PDL: wait for predecessor kernel (upstream q/kv producer) to signal
|
||||
// before touching any global memory. No-op when PDL is not enabled on
|
||||
// the launch. The CUDA runtime wrapper emits the griddepcontrol.wait
|
||||
// PTX with the required memory clobber internally.
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
|
||||
// Dim range this lane owns within the 512-wide head.
|
||||
int const dim_base = laneId * kElemsPerLane; // in [0, 512) step 16
|
||||
|
||||
// ── Load 16 bf16 → 16 fp32 registers (one 16-byte + one 16-byte LDG) ────
|
||||
float elements[kElemsPerLane];
|
||||
float sumOfSquares = 0.0f;
|
||||
|
||||
scalar_t_in const* src_ptr;
|
||||
if (isKV) {
|
||||
src_ptr = kv_in + static_cast<int64_t>(tokenIdx) * kHeadDim + dim_base;
|
||||
} else {
|
||||
int64_t const q_row_offset =
|
||||
(static_cast<int64_t>(tokenIdx) * num_heads_q + slotIdx) * kHeadDim +
|
||||
dim_base;
|
||||
src_ptr = q_inout + q_row_offset;
|
||||
}
|
||||
|
||||
// Two 16-byte loads per thread (8 bf16 each). Use uint4 as the vector
|
||||
// type and bitcast to scalar_t_in packed pairs for conversion.
|
||||
uint4 v0 = *reinterpret_cast<uint4 const*>(src_ptr);
|
||||
uint4 v1 = *reinterpret_cast<uint4 const*>(src_ptr + 8);
|
||||
|
||||
{
|
||||
typename Converter::packed_hip_type const* p0 =
|
||||
reinterpret_cast<typename Converter::packed_hip_type const*>(&v0);
|
||||
typename Converter::packed_hip_type const* p1 =
|
||||
reinterpret_cast<typename Converter::packed_hip_type const*>(&v1);
|
||||
// Each packed_hip_type holds 2 bf16 → 4 packed = 8 elems per uint4.
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
float2 f2 = Converter::convert(p0[i]);
|
||||
elements[2 * i] = f2.x;
|
||||
elements[2 * i + 1] = f2.y;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
float2 f2 = Converter::convert(p1[i]);
|
||||
elements[8 + 2 * i] = f2.x;
|
||||
elements[8 + 2 * i + 1] = f2.y;
|
||||
}
|
||||
}
|
||||
|
||||
// ── Q branch: RMSNorm with no weight (has_weight=False) ─────────────────
|
||||
// Variance + rsqrt + multiply all in fp32, no intermediate bf16 round.
|
||||
// The downstream bf16 round only happens at the final store.
|
||||
if (!isKV) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kElemsPerLane; i++) {
|
||||
sumOfSquares += elements[i] * elements[i];
|
||||
}
|
||||
sumOfSquares = warpSum<float>(sumOfSquares);
|
||||
float const rms_rcp =
|
||||
rsqrtf(sumOfSquares / static_cast<float>(kHeadDim) + eps);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kElemsPerLane; i++) {
|
||||
elements[i] = elements[i] * rms_rcp;
|
||||
}
|
||||
}
|
||||
|
||||
// ── GPT-J RoPE on dims [NOPE_DIM, HEAD_DIM) ─────────────────────────────
|
||||
// All math in fp32. cos_sin_cache is loaded as fp32 (its native storage).
|
||||
bool const is_rope_lane = dim_base >= kNopeDim;
|
||||
if (is_rope_lane) {
|
||||
int64_t const pos = position_ids[tokenIdx];
|
||||
constexpr int kHalfRope = kRopeDim / 2; // 32
|
||||
float const* cos_ptr = cos_sin_cache + pos * kRopeDim;
|
||||
float const* sin_ptr = cos_ptr + kHalfRope;
|
||||
|
||||
int const rope_local_base = dim_base - kNopeDim; // in [0, 64) step 16
|
||||
#pragma unroll
|
||||
for (int p = 0; p < kElemsPerLane / 2; p++) {
|
||||
int const pair_dim = rope_local_base + 2 * p;
|
||||
int const half_idx = pair_dim / 2;
|
||||
float const cos_v = VLLM_LDG(cos_ptr + half_idx);
|
||||
float const sin_v = VLLM_LDG(sin_ptr + half_idx);
|
||||
float const x_even = elements[2 * p];
|
||||
float const x_odd = elements[2 * p + 1];
|
||||
elements[2 * p] = x_even * cos_v - x_odd * sin_v;
|
||||
elements[2 * p + 1] = x_even * sin_v + x_odd * cos_v;
|
||||
}
|
||||
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// Q branch: cast to bf16 and store back in place.
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
if (!isKV) {
|
||||
uint4 out0, out1;
|
||||
typename Converter::packed_hip_type* po0 =
|
||||
reinterpret_cast<typename Converter::packed_hip_type*>(&out0);
|
||||
typename Converter::packed_hip_type* po1 =
|
||||
reinterpret_cast<typename Converter::packed_hip_type*>(&out1);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
po0[i] = Converter::convert(
|
||||
make_float2(elements[2 * i], elements[2 * i + 1]));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
po1[i] = Converter::convert(
|
||||
make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1]));
|
||||
}
|
||||
scalar_t_in* dst =
|
||||
q_inout +
|
||||
(static_cast<int64_t>(tokenIdx) * num_heads_q + slotIdx) * kHeadDim +
|
||||
dim_base;
|
||||
*reinterpret_cast<uint4*>(dst) = out0;
|
||||
*reinterpret_cast<uint4*>(dst + 8) = out1;
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// KV branch.
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
int64_t const slot_id = slot_mapping[tokenIdx];
|
||||
if (slot_id < 0) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t const block_idx = slot_id / cache_block_size;
|
||||
int64_t const pos_in_block = slot_id % cache_block_size;
|
||||
uint8_t* block_base =
|
||||
k_cache + block_idx * static_cast<int64_t>(kv_block_stride);
|
||||
uint8_t* token_fp8_ptr = block_base + pos_in_block * kTokenDataBytes;
|
||||
uint8_t* token_bf16_ptr = token_fp8_ptr + kNopeDim;
|
||||
uint8_t* token_scale_ptr =
|
||||
block_base + static_cast<int64_t>(cache_block_size) * kTokenDataBytes +
|
||||
pos_in_block * kScaleBytesPerToken;
|
||||
|
||||
// Round K to bf16 first, matching the unfused reference path where K is
|
||||
// materialized as bf16 before K quantization. absmax, clamp, and FP8
|
||||
// quant below all run on these bf16-rounded values.
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kElemsPerLane; i++) {
|
||||
elements[i] = Converter::convert(Converter::convert(elements[i]));
|
||||
}
|
||||
|
||||
// Per-quant-block absmax must be computed by ALL 32 lanes (warp-collective
|
||||
// shuffle requires full participation). RoPE lanes contribute garbage,
|
||||
// but their values are gated out below via `!is_rope_lane`.
|
||||
float local_absmax = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kElemsPerLane; i++) {
|
||||
local_absmax = fmaxf(local_absmax, fabsf(elements[i]));
|
||||
}
|
||||
float const absmax = fmaxf(warp4MaxAbs(local_absmax), 1e-4f);
|
||||
float const exponent = ceilf(log2f(absmax / kFp8Max));
|
||||
float const inv_scale = exp2f(-exponent);
|
||||
|
||||
if (!is_rope_lane) {
|
||||
// ── NoPE lane: UE8M0 FP8 quant ───────────────────────────────────────
|
||||
uint8_t out_bytes[kElemsPerLane];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kElemsPerLane; i++) {
|
||||
float scaled = elements[i] * inv_scale;
|
||||
scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max);
|
||||
__nv_fp8_storage_t s =
|
||||
__nv_cvt_float_to_fp8(scaled, __NV_SATFINITE, __NV_E4M3);
|
||||
out_bytes[i] = static_cast<uint8_t>(s);
|
||||
}
|
||||
// One 16-byte STG per lane.
|
||||
*reinterpret_cast<uint4*>(token_fp8_ptr + dim_base) =
|
||||
*reinterpret_cast<uint4 const*>(out_bytes);
|
||||
|
||||
// Lane (4k) of each 4-lane group writes the scale byte for block k<7.
|
||||
if ((laneId & 3) == 0) {
|
||||
int const q_block_idx = laneId >> 2; // 0..6 for NoPE lanes
|
||||
float encoded = fmaxf(fminf(exponent + 127.0f, 255.0f), 0.0f);
|
||||
token_scale_ptr[q_block_idx] = static_cast<uint8_t>(encoded);
|
||||
}
|
||||
// Lane 0 also writes the padding byte at index 7.
|
||||
if (laneId == 0) {
|
||||
token_scale_ptr[kNumQuantBlocks] = 0; // pad
|
||||
}
|
||||
} else {
|
||||
// ── RoPE lane: cast back to bf16 and store to cache bf16 tail ────────
|
||||
uint4 out0, out1;
|
||||
typename Converter::packed_hip_type* po0 =
|
||||
reinterpret_cast<typename Converter::packed_hip_type*>(&out0);
|
||||
typename Converter::packed_hip_type* po1 =
|
||||
reinterpret_cast<typename Converter::packed_hip_type*>(&out1);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
po0[i] = Converter::convert(
|
||||
make_float2(elements[2 * i], elements[2 * i + 1]));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
po1[i] = Converter::convert(
|
||||
make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1]));
|
||||
}
|
||||
int const rope_local_base = dim_base - kNopeDim; // in [0, 64)
|
||||
scalar_t_in* bf16_dst =
|
||||
reinterpret_cast<scalar_t_in*>(token_bf16_ptr) + rope_local_base;
|
||||
*reinterpret_cast<uint4*>(bf16_dst) = out0;
|
||||
*reinterpret_cast<uint4*>(bf16_dst + 8) = out1;
|
||||
}
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// Launch wrapper
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
template <typename scalar_t_in>
|
||||
void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
|
||||
scalar_t_in* q_inout, scalar_t_in const* kv_in, uint8_t* k_cache,
|
||||
int64_t const* slot_mapping, int64_t const* position_ids,
|
||||
float const* cos_sin_cache, float const eps, int const num_tokens_full,
|
||||
int const num_tokens_insert, int const num_heads_q,
|
||||
int const cache_block_size, int const kv_block_stride,
|
||||
cudaStream_t stream) {
|
||||
constexpr int kBlockSize = 256;
|
||||
constexpr int kWarpsPerBlock = kBlockSize / 32;
|
||||
int64_t const total_warps =
|
||||
static_cast<int64_t>(num_tokens_full) * (num_heads_q + 1);
|
||||
int const grid =
|
||||
static_cast<int>((total_warps + kWarpsPerBlock - 1) / kWarpsPerBlock);
|
||||
|
||||
// PDL: enable programmatic stream serialization whenever the hardware
|
||||
// supports it (SM90+). On pre-Hopper GPUs the attribute is unavailable,
|
||||
// so leave numAttrs = 0 and launch as a regular kernel.
|
||||
static int const sm_version = getSMVersion();
|
||||
// Host-side guard: the device kernel body is compiled as a no-op for
|
||||
// bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert<BFloat16> is
|
||||
// unavailable there. Refuse the launch loudly instead of silently
|
||||
// skipping the work.
|
||||
TORCH_CHECK(
|
||||
sm_version >= 80,
|
||||
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert requires sm_80+ "
|
||||
"(Ampere or newer); got sm_",
|
||||
sm_version);
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = dim3(grid);
|
||||
config.blockDim = dim3(kBlockSize);
|
||||
config.dynamicSmemBytes = 0;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = 1;
|
||||
config.attrs = attrs;
|
||||
config.numAttrs = (sm_version >= 90) ? 1 : 0;
|
||||
|
||||
cudaLaunchKernelEx(
|
||||
&config, fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel<scalar_t_in>,
|
||||
q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps,
|
||||
num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size,
|
||||
kv_block_stride);
|
||||
}
|
||||
|
||||
} // namespace deepseek_v4_fused_ops
|
||||
} // namespace vllm
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// Torch op wrapper
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
||||
torch::Tensor& q, // [N, H, 512] bf16, in place
|
||||
torch::Tensor const& kv, // [N, 512] bf16 (read-only)
|
||||
torch::Tensor& k_cache, // [num_blocks, block_bytes] uint8
|
||||
torch::Tensor const& slot_mapping, // [N] int64
|
||||
torch::Tensor const& position_ids, // [N] int64
|
||||
torch::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16
|
||||
double eps, int64_t cache_block_size) {
|
||||
TORCH_CHECK(q.is_cuda() && q.is_contiguous(), "q must be contiguous CUDA");
|
||||
TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA");
|
||||
TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA");
|
||||
TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64,
|
||||
"slot_mapping must be int64 CUDA");
|
||||
TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64,
|
||||
"position_ids must be int64 CUDA");
|
||||
TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA");
|
||||
TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]");
|
||||
TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
|
||||
TORCH_CHECK(q.dtype() == kv.dtype(), "q and kv dtype must match");
|
||||
TORCH_CHECK(k_cache.dtype() == torch::kUInt8, "k_cache must be uint8");
|
||||
TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
|
||||
"cos_sin_cache shape [max_pos, 64]");
|
||||
TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32,
|
||||
"cos_sin_cache must be float32");
|
||||
|
||||
// With DP padding, slot_mapping can be shorter than q/kv/positions.
|
||||
// Q-norm+RoPE runs on all q.size(0) rows (downstream attention uses them);
|
||||
// KV quant+insert runs only on the first slot_mapping.size(0) rows.
|
||||
int const num_tokens_full = static_cast<int>(q.size(0));
|
||||
int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
|
||||
TORCH_CHECK(static_cast<int>(kv.size(0)) == num_tokens_full &&
|
||||
static_cast<int>(position_ids.size(0)) == num_tokens_full,
|
||||
"q/kv/position_ids row counts must match");
|
||||
TORCH_CHECK(num_tokens_insert <= num_tokens_full,
|
||||
"slot_mapping must not exceed q row count");
|
||||
int const num_heads_q = static_cast<int>(q.size(1));
|
||||
int const cache_block_size_i = static_cast<int>(cache_block_size);
|
||||
int const kv_block_stride = static_cast<int>(k_cache.stride(0));
|
||||
|
||||
at::cuda::OptionalCUDAGuard device_guard(device_of(q));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
q.scalar_type(), "fused_deepseek_v4_qnorm_rope_kv_insert", [&] {
|
||||
using qkv_scalar_t = scalar_t;
|
||||
vllm::deepseek_v4_fused_ops::
|
||||
launchFusedDeepseekV4QNormRopeKVRopeQuantInsert<qkv_scalar_t>(
|
||||
reinterpret_cast<qkv_scalar_t*>(q.data_ptr()),
|
||||
reinterpret_cast<qkv_scalar_t const*>(kv.data_ptr()),
|
||||
reinterpret_cast<uint8_t*>(k_cache.data_ptr()),
|
||||
reinterpret_cast<int64_t const*>(slot_mapping.data_ptr()),
|
||||
reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
|
||||
cos_sin_cache.data_ptr<float>(), static_cast<float>(eps),
|
||||
num_tokens_full, num_tokens_insert, num_heads_q,
|
||||
cache_block_size_i, kv_block_stride, stream);
|
||||
});
|
||||
}
|
||||
@@ -77,7 +77,8 @@ __global__ void rms_norm_kernel(
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
float x = static_cast<float>(src1.val[j]);
|
||||
dst.val[j] = ((scalar_t)(x * s_variance)) * src2.val[j];
|
||||
float w = static_cast<float>(src2.val[j]);
|
||||
dst.val[j] = static_cast<scalar_t>(x * s_variance * w);
|
||||
}
|
||||
v_out[i] = dst;
|
||||
}
|
||||
@@ -134,10 +135,17 @@ fused_add_rms_norm_kernel(
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
int64_t strided_id = blockIdx.x * vec_input_stride + idx;
|
||||
_f16Vec<scalar_t, width> temp = residual_v[id];
|
||||
temp *= s_variance;
|
||||
temp *= weight_v[idx];
|
||||
input_v[strided_id] = temp;
|
||||
_f16Vec<scalar_t, width> res = residual_v[id];
|
||||
_f16Vec<scalar_t, width> w = weight_v[idx];
|
||||
_f16Vec<scalar_t, width> out;
|
||||
using Converter = _typeConvert<scalar_t>;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < width; ++j) {
|
||||
float x = Converter::convert(res.data[j]);
|
||||
float wf = Converter::convert(w.data[j]);
|
||||
out.data[j] = Converter::convert(x * s_variance * wf);
|
||||
}
|
||||
input_v[strided_id] = out;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,8 +182,8 @@ fused_add_rms_norm_kernel(
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)residual[blockIdx.x * hidden_size + idx];
|
||||
input[blockIdx.x * input_stride + idx] =
|
||||
((scalar_t)(x * s_variance)) * weight[idx];
|
||||
float w = (float)weight[idx];
|
||||
input[blockIdx.x * input_stride + idx] = (scalar_t)(x * s_variance * w);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -65,9 +65,16 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; j++) {
|
||||
float x = static_cast<float>(src1.val[j]);
|
||||
float const out_norm = ((scalar_t)(x * s_variance)) * src2.val[j];
|
||||
float w = static_cast<float>(src2.val[j]);
|
||||
// Round normalized result through scalar_t to match the precision of the
|
||||
// unfused composite (rms_norm writes scalar_t, then
|
||||
// static_scaled_fp8_quant re-loads it as float before FP8 conversion).
|
||||
// Without this round, the fused path is strictly more accurate and
|
||||
// disagrees with the composite at exact E4M3 quantization tie boundaries.
|
||||
scalar_t out_norm = static_cast<scalar_t>(x * s_variance * w);
|
||||
out[blockIdx.x * hidden_size + idx * VEC_SIZE + j] =
|
||||
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
|
||||
scaled_fp8_conversion<true, fp8_type>(static_cast<float>(out_norm),
|
||||
scale_inv);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -127,13 +134,21 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
_f16Vec<scalar_t, width> temp = residual_v[id];
|
||||
temp *= s_variance;
|
||||
temp *= weight_v[idx];
|
||||
_f16Vec<scalar_t, width> res = residual_v[id];
|
||||
_f16Vec<scalar_t, width> w = weight_v[idx];
|
||||
using Converter = _typeConvert<scalar_t>;
|
||||
using HipT = typename Converter::hip_type;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i) {
|
||||
out[id * width + i] =
|
||||
scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv);
|
||||
float x = Converter::convert(res.data[i]);
|
||||
float wf = Converter::convert(w.data[i]);
|
||||
// See note in rms_norm_static_fp8_quant_kernel: round through scalar_t
|
||||
// to match the unfused composite path at FP8 boundaries. We use the
|
||||
// backend's hip_type for the intermediate since c10::Half/BFloat16 has
|
||||
// ambiguous conversions on CUDA and no implicit conversion on ROCm.
|
||||
HipT out_norm_h = Converter::convert(x * s_variance * wf);
|
||||
out[id * width + i] = scaled_fp8_conversion<true, fp8_type>(
|
||||
Converter::convert(out_norm_h), scale_inv);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -176,9 +191,12 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)residual[blockIdx.x * hidden_size + idx];
|
||||
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
|
||||
float w = (float)weight[idx];
|
||||
// See note in rms_norm_static_fp8_quant_kernel: round through scalar_t
|
||||
// to match the unfused composite path at FP8 boundaries.
|
||||
scalar_t out_norm = static_cast<scalar_t>(x * s_variance * w);
|
||||
out[blockIdx.x * hidden_size + idx] = scaled_fp8_conversion<true, fp8_type>(
|
||||
static_cast<float>(out_norm), scale_inv);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,15 @@ void topk_sigmoid(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
|
||||
torch::Tensor& gating_output, bool renormalize,
|
||||
std::optional<torch::Tensor> bias);
|
||||
|
||||
void topk_softplus_sqrt(torch::Tensor& topk_weights,
|
||||
torch::Tensor& topk_indices,
|
||||
torch::Tensor& token_expert_indices,
|
||||
torch::Tensor& gating_output, bool renormalize,
|
||||
double routed_scaling_factor,
|
||||
const c10::optional<torch::Tensor>& correction_bias,
|
||||
const c10::optional<torch::Tensor>& input_ids,
|
||||
const c10::optional<torch::Tensor>& tid2eid);
|
||||
|
||||
void moe_sum(torch::Tensor& input, torch::Tensor& output);
|
||||
|
||||
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
|
||||
@@ -0,0 +1,715 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
|
||||
* Copyright (c) 2024, The vLLM team.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
|
||||
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <type_traits>
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "../cuda_compat.h"
|
||||
#include "../cub_helpers.h"
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
typedef __hip_bfloat162 __nv_bfloat162;
|
||||
#endif
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
|
||||
namespace vllm {
|
||||
namespace moe {
|
||||
|
||||
/// Aligned array type
|
||||
template <typename T,
|
||||
/// Number of elements in the array
|
||||
int N,
|
||||
/// Alignment requirement in bytes
|
||||
int Alignment = sizeof(T) * N>
|
||||
struct alignas(Alignment) AlignedArray {
|
||||
T data[N];
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ float toFloat(T value) {
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
return value;
|
||||
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
|
||||
return __bfloat162float(value);
|
||||
} else if constexpr (std::is_same_v<T, __half>) {
|
||||
return __half2float(value);
|
||||
}
|
||||
}
|
||||
|
||||
#define FINAL_MASK 0xffffffff
|
||||
template <typename T>
|
||||
__inline__ __device__ T warpReduceSum(T val) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
|
||||
return val;
|
||||
}
|
||||
|
||||
// ====================== TopK softplus_sqrt things
|
||||
// ===============================
|
||||
|
||||
/*
|
||||
A Top-K gating softplus_sqrt written to exploit when the number of experts in
|
||||
the MoE layers are a small power of 2. This allows us to cleanly share the
|
||||
rows among the threads in a single warp and eliminate communication between
|
||||
warps (so no need to use shared mem).
|
||||
|
||||
It fuses the sigmoid, max and argmax into a single kernel.
|
||||
|
||||
Limitations:
|
||||
1) This implementation is optimized for when the number of experts is a small
|
||||
power of 2. Additionally it also supports when number of experts is multiple
|
||||
of 64 which is still faster than the computing sigmoid and topK separately
|
||||
(only tested on CUDA yet). 2) This implementation assumes k is small, but will
|
||||
work for any k.
|
||||
*/
|
||||
|
||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG,
|
||||
int WARP_SIZE_PARAM, bool USE_HASH, typename IndType,
|
||||
typename InputType = float>
|
||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
|
||||
void topkGatingSoftplusSqrt(
|
||||
const InputType* input, const bool* finished, float* output,
|
||||
const int num_rows, IndType* indices, int* source_rows, const int k,
|
||||
const int start_expert, const int end_expert, const bool renormalize,
|
||||
double routed_scaling_factor, const float* correction_bias,
|
||||
const IndType* input_ids, const IndType* tid2eid) {
|
||||
static_assert(std::is_same_v<InputType, float> ||
|
||||
std::is_same_v<InputType, __nv_bfloat16> ||
|
||||
std::is_same_v<InputType, __half>,
|
||||
"InputType must be float, __nv_bfloat16, or __half");
|
||||
|
||||
// We begin by enforcing compile time assertions and setting up compile time
|
||||
// constants.
|
||||
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG),
|
||||
"BYTES_PER_LDG must be power of 2");
|
||||
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
|
||||
|
||||
// Number of bytes each thread pulls in per load
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
|
||||
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
|
||||
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
|
||||
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
|
||||
|
||||
if constexpr (std::is_same_v<InputType, __nv_bfloat16> ||
|
||||
std::is_same_v<InputType, __half>) {
|
||||
static_assert(ELTS_PER_LDG == 1 || ELTS_PER_LDG % 2 == 0,
|
||||
"ELTS_PER_LDG must be 1 or even for 16-bit conversion");
|
||||
}
|
||||
|
||||
// Restrictions based on previous section.
|
||||
static_assert(
|
||||
VPT % ELTS_PER_LDG == 0,
|
||||
"The elements per thread must be a multiple of the elements per ldg");
|
||||
static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0,
|
||||
"The threads per row must cleanly divide the threads per warp");
|
||||
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW),
|
||||
"THREADS_PER_ROW must be power of 2");
|
||||
static_assert(THREADS_PER_ROW <= WARP_SIZE_PARAM,
|
||||
"THREADS_PER_ROW can be at most warp size");
|
||||
|
||||
// We have NUM_EXPERTS elements per row. We specialize for small #experts
|
||||
static constexpr int ELTS_PER_WARP = WARP_SIZE_PARAM * VPT;
|
||||
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
|
||||
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
|
||||
|
||||
// Restrictions for previous section.
|
||||
static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0,
|
||||
"The elts per row must cleanly divide the total elt per warp");
|
||||
|
||||
// ===================== From this point, we finally start computing run-time
|
||||
// variables. ========================
|
||||
|
||||
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a
|
||||
// block contains WARPS_PER_CTA warps. This, each block processes a chunk of
|
||||
// rows. We start by computing the start row for each block.
|
||||
const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
|
||||
|
||||
// Now, using the base row per thread block, we compute the base row per warp.
|
||||
const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
|
||||
|
||||
// The threads in a warp are split into sub-groups that will work on a row.
|
||||
// We compute row offset for each thread sub-group
|
||||
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
|
||||
const int thread_row = warp_base_row + thread_row_in_warp;
|
||||
|
||||
// Threads with indices out of bounds should early exit here.
|
||||
if (thread_row >= num_rows) {
|
||||
return;
|
||||
}
|
||||
const bool row_is_active = finished ? !finished[thread_row] : true;
|
||||
|
||||
// We finally start setting up the read pointers for each thread. First, each
|
||||
// thread jumps to the start of the row it will read.
|
||||
const InputType* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
|
||||
|
||||
// Now, we compute the group each thread belong to in order to determine the
|
||||
// first column to start loads.
|
||||
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
|
||||
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
|
||||
const InputType* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
|
||||
|
||||
// Finally, we pull in the data from global mem
|
||||
float row_chunk[VPT];
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
#endif
|
||||
|
||||
// NOTE(zhuhaoran): dispatch different input types loading, BF16/FP16 convert
|
||||
// to float
|
||||
if constexpr (std::is_same_v<InputType, float>) {
|
||||
using VecType = AlignedArray<float, ELTS_PER_LDG>;
|
||||
VecType* row_chunk_vec_ptr = reinterpret_cast<VecType*>(&row_chunk);
|
||||
const VecType* vec_thread_read_ptr =
|
||||
reinterpret_cast<const VecType*>(thread_read_ptr);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
|
||||
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
|
||||
}
|
||||
} else if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
|
||||
if constexpr (ELTS_PER_LDG >= 2) {
|
||||
using VecType = AlignedArray<__nv_bfloat16, ELTS_PER_LDG>;
|
||||
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
|
||||
const VecType* vec_thread_read_ptr =
|
||||
reinterpret_cast<const VecType*>(thread_read_ptr);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
|
||||
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
|
||||
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
|
||||
row_chunk_f2[base_idx_f2 + jj] = __bfloat1622float2(
|
||||
*reinterpret_cast<const __nv_bfloat162*>(vec.data + jj * 2));
|
||||
}
|
||||
}
|
||||
} else { // ELTS_PER_LDG == 1
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
|
||||
const __nv_bfloat16* scalar_ptr =
|
||||
thread_read_ptr + ii * THREADS_PER_ROW;
|
||||
row_chunk[ii] = __bfloat162float(*scalar_ptr);
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_same_v<InputType, __half>) {
|
||||
if constexpr (ELTS_PER_LDG >= 2) {
|
||||
using VecType = AlignedArray<__half, ELTS_PER_LDG>;
|
||||
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
|
||||
const VecType* vec_thread_read_ptr =
|
||||
reinterpret_cast<const VecType*>(thread_read_ptr);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
|
||||
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
|
||||
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
|
||||
row_chunk_f2[base_idx_f2 + jj] = __half22float2(
|
||||
*reinterpret_cast<const __half2*>(vec.data + jj * 2));
|
||||
}
|
||||
}
|
||||
} else { // ELTS_PER_LDG == 1
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
|
||||
const __half* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
|
||||
row_chunk[ii] = __half2float(*scalar_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
constexpr float threshold = 20.0f;
|
||||
constexpr float beta = 1.0f;
|
||||
|
||||
// Hash MoE path: indices are predetermined from lookup table
|
||||
if constexpr (USE_HASH) {
|
||||
const IndType token_id = input_ids[thread_row];
|
||||
const IndType* expert_indices_for_token = tid2eid + token_id * k;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < VPT; ++ii) {
|
||||
float val = row_chunk[ii];
|
||||
float val_b = val * beta;
|
||||
val = (val_b > threshold) ? val : (__logf(1.0f + __expf(val_b))) / beta;
|
||||
row_chunk[ii] = sqrtf(val);
|
||||
}
|
||||
float selected_sum = 0.f;
|
||||
#pragma unroll
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
const int expert = expert_indices_for_token[k_idx];
|
||||
const int idx = k * thread_row + k_idx;
|
||||
for (int ii = 0; ii < VPT; ++ii) {
|
||||
const int group_id = ii / ELTS_PER_LDG;
|
||||
const int local_id = ii % ELTS_PER_LDG;
|
||||
const int expert_idx = first_elt_read_by_thread +
|
||||
group_id * THREADS_PER_ROW * ELTS_PER_LDG +
|
||||
local_id;
|
||||
if (expert == expert_idx) {
|
||||
indices[idx] = expert;
|
||||
selected_sum += row_chunk[ii];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Compute per-thread scale (using warp reduction when renormalizing).
|
||||
if (renormalize) {
|
||||
selected_sum = warpReduceSum(selected_sum);
|
||||
}
|
||||
float scale = static_cast<float>(routed_scaling_factor);
|
||||
if (renormalize) {
|
||||
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
|
||||
scale /= denom;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
const int expert = expert_indices_for_token[k_idx];
|
||||
const int idx = k * thread_row + k_idx;
|
||||
for (int ii = 0; ii < VPT; ++ii) {
|
||||
const int group_id = ii / ELTS_PER_LDG;
|
||||
const int local_id = ii % ELTS_PER_LDG;
|
||||
const int expert_idx = first_elt_read_by_thread +
|
||||
group_id * THREADS_PER_ROW * ELTS_PER_LDG +
|
||||
local_id;
|
||||
if (expert == expert_idx) {
|
||||
output[idx] = row_chunk[ii] * scale;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < VPT; ++ii) {
|
||||
float val = row_chunk[ii];
|
||||
float val_b = val * beta;
|
||||
// Compute softplus: log(1 + exp(val)) with numerical stability
|
||||
// When val > threshold, softplus(x) ≈ x to avoid exp overflow
|
||||
val = (val_b > threshold) ? val : (__logf(1.0f + __expf(val_b))) / beta;
|
||||
val = sqrtf(val);
|
||||
if (correction_bias) {
|
||||
const int group_id = ii / ELTS_PER_LDG;
|
||||
const int local_id = ii % ELTS_PER_LDG;
|
||||
const int expert_idx = first_elt_read_by_thread +
|
||||
group_id * THREADS_PER_ROW * ELTS_PER_LDG +
|
||||
local_id;
|
||||
val = val + correction_bias[expert_idx];
|
||||
}
|
||||
row_chunk[ii] = val;
|
||||
}
|
||||
|
||||
// Original TopK path: find top-k experts by score
|
||||
// Now, sigmoid_res contains the sigmoid of the row chunk. Now, I want to find
|
||||
// the topk elements in each row, along with the max index.
|
||||
int start_col = first_elt_read_by_thread;
|
||||
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
|
||||
|
||||
float selected_sum = 0.f;
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
// First, each thread does the local argmax
|
||||
float max_val = row_chunk[0];
|
||||
int expert = start_col;
|
||||
#pragma unroll
|
||||
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD;
|
||||
++ldg, col += COLS_PER_GROUP_LDG) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
|
||||
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
|
||||
|
||||
// No check on the experts here since columns with the smallest index
|
||||
// are processed first and only updated if > (not >=)
|
||||
if (val > max_val) {
|
||||
max_val = val;
|
||||
expert = col + ii;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now, we perform the argmax reduce. We use the butterfly pattern so threads
|
||||
// reach consensus about the max. This will be useful for K > 1 so that the
|
||||
// threads can agree on "who" had the max value. That thread can then blank out
|
||||
// their max with -inf and the warp can run more iterations...
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
|
||||
float other_max =
|
||||
VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
|
||||
int other_expert =
|
||||
VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
|
||||
|
||||
// We want lower indices to "win" in every thread so we break ties this
|
||||
// way
|
||||
if (other_max > max_val ||
|
||||
(other_max == max_val && other_expert < expert)) {
|
||||
max_val = other_max;
|
||||
expert = other_expert;
|
||||
}
|
||||
}
|
||||
|
||||
// Write the max for this k iteration to global memory.
|
||||
if (thread_group_idx == 0) {
|
||||
// Add a guard to ignore experts not included by this node
|
||||
const bool node_uses_expert =
|
||||
expert >= start_expert && expert < end_expert;
|
||||
const bool should_process_row = row_is_active && node_uses_expert;
|
||||
|
||||
// The lead thread from each sub-group will write out the final results to
|
||||
// global memory. (This will be a single) thread per row of the
|
||||
// input/output matrices.
|
||||
const int idx = k * thread_row + k_idx;
|
||||
if (correction_bias != nullptr) {
|
||||
max_val -= correction_bias[expert];
|
||||
}
|
||||
output[idx] = max_val;
|
||||
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
|
||||
source_rows[idx] = k_idx * num_rows + thread_row;
|
||||
if (renormalize) {
|
||||
selected_sum += max_val;
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, we clear the value in the thread with the current max if there
|
||||
// is another iteration to run.
|
||||
if (k_idx + 1 < k) {
|
||||
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
|
||||
const int thread_to_clear_in_group =
|
||||
(expert / ELTS_PER_LDG) % THREADS_PER_ROW;
|
||||
|
||||
// Only the thread in the group which produced the max will reset the
|
||||
// "winning" value to -inf.
|
||||
if (thread_group_idx == thread_to_clear_in_group) {
|
||||
const int offset_for_expert = expert % ELTS_PER_LDG;
|
||||
// Safe to set to any negative value since row_chunk values must be
|
||||
// between 0 and 1.
|
||||
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] =
|
||||
-10000.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply renormalization and routed scaling factor to final weights.
|
||||
if (thread_group_idx == 0) {
|
||||
float scale = static_cast<float>(routed_scaling_factor);
|
||||
if (renormalize) {
|
||||
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
|
||||
scale /= denom;
|
||||
}
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
const int idx = k * thread_row + k_idx;
|
||||
output[idx] = output[idx] * scale;
|
||||
}
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
// Constructs some constants needed to partition the work across threads at
|
||||
// compile time.
|
||||
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM,
|
||||
typename InputType>
|
||||
struct TopkConstants {
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
|
||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 ||
|
||||
EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0,
|
||||
"");
|
||||
static constexpr int VECs_PER_THREAD =
|
||||
MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
|
||||
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
||||
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
|
||||
static const int ROWS_PER_WARP = WARP_SIZE_PARAM / THREADS_PER_ROW;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
#define DISPATCH_HASH(use_hash, USE_HASH, ...) \
|
||||
if (use_hash) { \
|
||||
const bool USE_HASH = true; \
|
||||
static_assert(USE_HASH == true, "USE_HASH must be compile-time constant"); \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
const bool USE_HASH = false; \
|
||||
static_assert(USE_HASH == false, \
|
||||
"USE_HASH must be compile-time constant"); \
|
||||
__VA_ARGS__ \
|
||||
}
|
||||
|
||||
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM,
|
||||
int MAX_BYTES_PER_LDG, typename IndType, typename InputType>
|
||||
void topkGatingSoftplusSqrtLauncherHelper(
|
||||
const InputType* input, const bool* finished, float* output,
|
||||
IndType* indices, int* source_row, const int num_rows, const int k,
|
||||
const int start_expert, const int end_expert, const bool renormalize,
|
||||
double routed_scaling_factor, const float* correction_bias,
|
||||
const bool use_hash, const IndType* input_ids, const IndType* tid2eid,
|
||||
cudaStream_t stream) {
|
||||
static constexpr int BYTES_PER_LDG =
|
||||
MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS);
|
||||
using Constants =
|
||||
detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM, InputType>;
|
||||
static constexpr int VPT = Constants::VPT;
|
||||
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
||||
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
||||
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
||||
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
|
||||
DISPATCH_HASH(use_hash, USE_HASH, {
|
||||
auto* kernel =
|
||||
&topkGatingSoftplusSqrt<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG,
|
||||
WARP_SIZE_PARAM, USE_HASH, IndType, InputType>;
|
||||
#ifndef USE_ROCM
|
||||
cudaLaunchConfig_t config = {};
|
||||
config.gridDim = num_blocks;
|
||||
config.blockDim = block_dim;
|
||||
config.dynamicSmemBytes = 0;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = 1;
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(&config, kernel, input, finished, output, num_rows,
|
||||
indices, source_row, k, start_expert, end_expert,
|
||||
renormalize, routed_scaling_factor, correction_bias,
|
||||
input_ids, tid2eid);
|
||||
#else
|
||||
kernel<<<num_blocks, block_dim, 0, stream>>>(
|
||||
input, finished, output, num_rows, indices, source_row, k, start_expert,
|
||||
end_expert, renormalize, routed_scaling_factor, correction_bias,
|
||||
input_ids, tid2eid);
|
||||
#endif
|
||||
})
|
||||
}
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define LAUNCH_SOFTPLUS_SQRT(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
|
||||
static_assert(WARP_SIZE == 32, \
|
||||
"Unsupported warp size. Only 32 is supported for CUDA"); \
|
||||
topkGatingSoftplusSqrtLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, \
|
||||
MAX_BYTES>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \
|
||||
routed_scaling_factor, correction_bias, use_hash, input_ids, tid2eid, \
|
||||
stream);
|
||||
#else
|
||||
#define LAUNCH_SOFTPLUS_SQRT(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
|
||||
if (WARP_SIZE == 64) { \
|
||||
topkGatingSoftplusSqrtLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, \
|
||||
MAX_BYTES>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \
|
||||
routed_scaling_factor, correction_bias, use_hash, input_ids, \
|
||||
tid2eid, stream); \
|
||||
} else if (WARP_SIZE == 32) { \
|
||||
topkGatingSoftplusSqrtLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, \
|
||||
MAX_BYTES>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \
|
||||
routed_scaling_factor, correction_bias, use_hash, input_ids, \
|
||||
tid2eid, stream); \
|
||||
} else { \
|
||||
assert(false && \
|
||||
"Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename IndType, typename InputType>
|
||||
void topkGatingSoftplusSqrtKernelLauncher(
|
||||
const InputType* gating_output, float* topk_weights, IndType* topk_indices,
|
||||
int* token_expert_indices, const int num_tokens, const int num_experts,
|
||||
const int topk, const bool renormalize, double routed_scaling_factor,
|
||||
const float* correction_bias, const bool use_hash, const IndType* input_ids,
|
||||
const IndType* tid2eid, cudaStream_t stream) {
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
|
||||
#ifndef USE_ROCM
|
||||
// for bfloat16 dtype, we need 4 bytes loading to make sure num_experts
|
||||
// elements can be loaded by a warp
|
||||
static constexpr int BYTES_PER_LDG_MULTIPLE_64 =
|
||||
(std::is_same_v<InputType, __nv_bfloat16> ||
|
||||
std::is_same_v<InputType, __half>)
|
||||
? 4
|
||||
: 8;
|
||||
#endif
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_SOFTPLUS_SQRT(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 2:
|
||||
LAUNCH_SOFTPLUS_SQRT(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_SOFTPLUS_SQRT(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_SOFTPLUS_SQRT(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 16:
|
||||
LAUNCH_SOFTPLUS_SQRT(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 32:
|
||||
LAUNCH_SOFTPLUS_SQRT(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_SOFTPLUS_SQRT(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_SOFTPLUS_SQRT(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_SOFTPLUS_SQRT(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 512:
|
||||
LAUNCH_SOFTPLUS_SQRT(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
// (CUDA only) support multiples of 64 when num_experts is not power of 2.
|
||||
// ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of
|
||||
// num_experts, alternatively we can test 4 bytes loading and enable it in
|
||||
// future.
|
||||
#ifndef USE_ROCM
|
||||
case 192:
|
||||
LAUNCH_SOFTPLUS_SQRT(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 320:
|
||||
LAUNCH_SOFTPLUS_SQRT(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 384:
|
||||
LAUNCH_SOFTPLUS_SQRT(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 448:
|
||||
LAUNCH_SOFTPLUS_SQRT(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 576:
|
||||
LAUNCH_SOFTPLUS_SQRT(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
#endif
|
||||
default: {
|
||||
TORCH_CHECK(false, "Unsupported expert number: ", num_experts);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace moe
|
||||
} // namespace vllm
|
||||
|
||||
template <typename ComputeType>
|
||||
void dispatch_topk_softplus_sqrt_launch(
|
||||
const ComputeType* gating_output, torch::Tensor& topk_weights,
|
||||
torch::Tensor& topk_indices, torch::Tensor& token_expert_indices,
|
||||
int num_tokens, int num_experts, int topk, bool renormalize,
|
||||
double routed_scaling_factor,
|
||||
const c10::optional<torch::Tensor>& correction_bias,
|
||||
const c10::optional<torch::Tensor>& input_ids,
|
||||
const c10::optional<torch::Tensor>& tid2eid, cudaStream_t stream) {
|
||||
const float* bias_ptr = nullptr;
|
||||
if (correction_bias.has_value()) {
|
||||
bias_ptr = correction_bias.value().data_ptr<float>();
|
||||
}
|
||||
bool use_hash = false;
|
||||
if (tid2eid.has_value()) {
|
||||
TORCH_CHECK(input_ids.has_value(), "input_ids is required for hash MoE");
|
||||
use_hash = true;
|
||||
}
|
||||
if (topk_indices.scalar_type() == at::ScalarType::Int) {
|
||||
const int* input_ids_ptr = nullptr;
|
||||
const int* tid2eid_ptr = nullptr;
|
||||
if (tid2eid.has_value()) {
|
||||
input_ids_ptr = input_ids.value().data_ptr<int>();
|
||||
tid2eid_ptr = tid2eid.value().data_ptr<int>();
|
||||
}
|
||||
|
||||
vllm::moe::topkGatingSoftplusSqrtKernelLauncher<int, ComputeType>(
|
||||
gating_output, topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<int>(), token_expert_indices.data_ptr<int>(),
|
||||
num_tokens, num_experts, topk, renormalize, routed_scaling_factor,
|
||||
bias_ptr, use_hash, input_ids_ptr, tid2eid_ptr, stream);
|
||||
} else if (topk_indices.scalar_type() == at::ScalarType::UInt32) {
|
||||
const uint32_t* input_ids_ptr = nullptr;
|
||||
const uint32_t* tid2eid_ptr = nullptr;
|
||||
if (tid2eid.has_value()) {
|
||||
input_ids_ptr = input_ids.value().data_ptr<uint32_t>();
|
||||
tid2eid_ptr = tid2eid.value().data_ptr<uint32_t>();
|
||||
}
|
||||
vllm::moe::topkGatingSoftplusSqrtKernelLauncher<uint32_t, ComputeType>(
|
||||
gating_output, topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<uint32_t>(), token_expert_indices.data_ptr<int>(),
|
||||
num_tokens, num_experts, topk, renormalize, routed_scaling_factor,
|
||||
bias_ptr, use_hash, input_ids_ptr, tid2eid_ptr, stream);
|
||||
} else {
|
||||
TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long);
|
||||
|
||||
const int64_t* input_ids_ptr = nullptr;
|
||||
const int64_t* tid2eid_ptr = nullptr;
|
||||
if (tid2eid.has_value()) {
|
||||
input_ids_ptr = input_ids.value().data_ptr<int64_t>();
|
||||
tid2eid_ptr = tid2eid.value().data_ptr<int64_t>();
|
||||
}
|
||||
|
||||
vllm::moe::topkGatingSoftplusSqrtKernelLauncher<int64_t, ComputeType>(
|
||||
gating_output, topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<int64_t>(), token_expert_indices.data_ptr<int>(),
|
||||
num_tokens, num_experts, topk, renormalize, routed_scaling_factor,
|
||||
bias_ptr, use_hash, input_ids_ptr, tid2eid_ptr, stream);
|
||||
}
|
||||
}
|
||||
|
||||
void topk_softplus_sqrt(
|
||||
torch::Tensor& topk_weights, // [num_tokens, topk]
|
||||
torch::Tensor& topk_indices, // [num_tokens, topk]
|
||||
torch::Tensor& token_expert_indices, // [num_tokens, topk]
|
||||
torch::Tensor& gating_output, // [num_tokens, num_experts]
|
||||
bool renormalize, double routed_scaling_factor,
|
||||
const c10::optional<torch::Tensor>& correction_bias,
|
||||
const c10::optional<torch::Tensor>& input_ids,
|
||||
const c10::optional<torch::Tensor>& tid2eid) {
|
||||
const int num_experts = gating_output.size(-1);
|
||||
const auto num_tokens = gating_output.numel() / num_experts;
|
||||
const int topk = topk_weights.size(-1);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (gating_output.scalar_type() == at::ScalarType::Float) {
|
||||
dispatch_topk_softplus_sqrt_launch<float>(
|
||||
gating_output.data_ptr<float>(), topk_weights, topk_indices,
|
||||
token_expert_indices, num_tokens, num_experts, topk, renormalize,
|
||||
routed_scaling_factor, correction_bias, input_ids, tid2eid, stream);
|
||||
} else if (gating_output.scalar_type() == at::ScalarType::Half) {
|
||||
dispatch_topk_softplus_sqrt_launch<__half>(
|
||||
reinterpret_cast<const __half*>(gating_output.data_ptr<at::Half>()),
|
||||
topk_weights, topk_indices, token_expert_indices, num_tokens,
|
||||
num_experts, topk, renormalize, routed_scaling_factor, correction_bias,
|
||||
input_ids, tid2eid, stream);
|
||||
} else if (gating_output.scalar_type() == at::ScalarType::BFloat16) {
|
||||
dispatch_topk_softplus_sqrt_launch<__nv_bfloat16>(
|
||||
reinterpret_cast<const __nv_bfloat16*>(
|
||||
gating_output.data_ptr<at::BFloat16>()),
|
||||
topk_weights, topk_indices, token_expert_indices, num_tokens,
|
||||
num_experts, topk, renormalize, routed_scaling_factor, correction_bias,
|
||||
input_ids, tid2eid, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported gating_output data type: ",
|
||||
gating_output.scalar_type());
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
"bias) -> ()");
|
||||
m.impl("topk_sigmoid", torch::kCUDA, &topk_sigmoid);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
m.def(
|
||||
"topk_softplus_sqrt(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
|
||||
"token_expert_indices, Tensor gating_output, bool renormalize, float "
|
||||
"routed_scaling_factor, Tensor? "
|
||||
"bias, Tensor? input_ids, Tensor? tid2eid) -> ()");
|
||||
m.impl("topk_softplus_sqrt", torch::kCUDA, &topk_softplus_sqrt);
|
||||
#endif
|
||||
// Calculate the result of moe by summing up the partial results
|
||||
// from all selected experts.
|
||||
m.def("moe_sum(Tensor input, Tensor! output) -> ()");
|
||||
|
||||
+7
-1
@@ -100,6 +100,11 @@ void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q,
|
||||
bool is_neox, torch::Tensor& position_ids,
|
||||
int64_t forced_token_heads_per_warp);
|
||||
|
||||
void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
||||
torch::Tensor& q, torch::Tensor const& kv, torch::Tensor& k_cache,
|
||||
torch::Tensor const& slot_mapping, torch::Tensor const& position_ids,
|
||||
torch::Tensor const& cos_sin_cache, double eps, int64_t cache_block_size);
|
||||
|
||||
void apply_repetition_penalties_(torch::Tensor& logits,
|
||||
const torch::Tensor& prompt_mask,
|
||||
const torch::Tensor& output_mask,
|
||||
@@ -153,7 +158,8 @@ void silu_and_mul_per_block_quant(torch::Tensor& out,
|
||||
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
std::optional<torch::Tensor> key, int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||
torch::Tensor& cos_sin_cache, bool is_neox,
|
||||
int64_t rope_dim_offset, bool inverse);
|
||||
|
||||
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
|
||||
+17
-16
@@ -18,7 +18,6 @@ namespace persistent {
|
||||
// Constants
|
||||
// ============================================================================
|
||||
|
||||
constexpr int TopK = 2048;
|
||||
constexpr int kThreadsPerBlock = 1024;
|
||||
constexpr int RADIX = 256;
|
||||
|
||||
@@ -128,11 +127,12 @@ struct RadixRowState {
|
||||
|
||||
struct PersistentTopKParams {
|
||||
const float* __restrict__ input; // [num_rows, stride]
|
||||
int32_t* __restrict__ output; // [num_rows, TopK]
|
||||
int32_t* __restrict__ output; // [num_rows, top_k]
|
||||
int32_t* __restrict__ lengths; // [num_rows]
|
||||
RadixRowState* row_states; // large path: per-group state
|
||||
uint32_t num_rows;
|
||||
uint32_t stride;
|
||||
uint32_t top_k; // actual k value for output stride
|
||||
uint32_t chunk_size; // large path: elements per CTA
|
||||
uint32_t ctas_per_group; // 1=medium, >1=large
|
||||
uint32_t max_seq_len; // max seq_len across all rows (for early CTA exit)
|
||||
@@ -154,6 +154,7 @@ __device__ __forceinline__ uint32_t decode_bin(float x) {
|
||||
return key >> 5;
|
||||
}
|
||||
|
||||
template <int TopK>
|
||||
__device__ __noinline__ void histogram_2048_topk(
|
||||
const float* __restrict__ logits, int32_t* __restrict__ output_indices,
|
||||
int32_t seq_len) {
|
||||
@@ -418,6 +419,7 @@ __device__ __noinline__ void histogram_2048_topk(
|
||||
// by: DarkSharpness
|
||||
// which at the same time is an optimized topk kernel copied from tilelang
|
||||
// kernel
|
||||
template <int TopK>
|
||||
__device__ __noinline__ void histogram_256_topk(
|
||||
const float* __restrict__ logits, int* __restrict__ output_indices,
|
||||
int logits_offset, int seq_len) {
|
||||
@@ -649,7 +651,7 @@ __device__ __forceinline__ void wait_ge(int* ptr, int target_val,
|
||||
// Adapted from https://github.com/flashinfer-ai/flashinfer/pull/2215
|
||||
// ============================================================================
|
||||
|
||||
template <uint32_t VEC_SIZE>
|
||||
template <int TopK, uint32_t VEC_SIZE>
|
||||
__device__ void radix_topk(const float* __restrict__ row_input,
|
||||
int32_t* __restrict__ row_output, uint32_t seq_len,
|
||||
uint32_t my_chunk_start, uint32_t chunk_size,
|
||||
@@ -857,7 +859,7 @@ __device__ void radix_topk(const float* __restrict__ row_input,
|
||||
// see filtered_topk.cuh)
|
||||
// ============================================================================
|
||||
|
||||
template <uint32_t VEC_SIZE = 1>
|
||||
template <int TopK = 2048, uint32_t VEC_SIZE = 1>
|
||||
__global__ void __launch_bounds__(kThreadsPerBlock, 2)
|
||||
persistent_topk_kernel(PersistentTopKParams params) {
|
||||
const uint32_t tx = threadIdx.x;
|
||||
@@ -915,7 +917,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2)
|
||||
if (row_idx >= params.num_rows) break;
|
||||
|
||||
const uint32_t seq_len = params.lengths[row_idx];
|
||||
int32_t* row_output = params.output + row_idx * TopK;
|
||||
int32_t* row_output = params.output + row_idx * params.top_k;
|
||||
const float* row_input = params.input + row_idx * params.stride;
|
||||
|
||||
if (seq_len <= RADIX_THRESHOLD) {
|
||||
@@ -927,19 +929,19 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2)
|
||||
row_output[i] = (i < seq_len) ? static_cast<int32_t>(i) : -1;
|
||||
}
|
||||
} else if (seq_len <= static_cast<uint32_t>(HIST2048_THRESHOLD)) {
|
||||
histogram_2048_topk(row_input, row_output, seq_len);
|
||||
histogram_2048_topk<TopK>(row_input, row_output, seq_len);
|
||||
} else {
|
||||
histogram_256_topk(row_input, row_output, 0, seq_len);
|
||||
histogram_256_topk<TopK>(row_input, row_output, 0, seq_len);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint32_t my_chunk_start = cta_in_group * chunk_size;
|
||||
radix_topk<VEC_SIZE>(row_input, row_output, seq_len, my_chunk_start,
|
||||
chunk_size, local_histogram, suffix_sum,
|
||||
shared_scalars, shared_ordered, state, cta_in_group,
|
||||
ctas_per_group, barrier_phase, iter, tx);
|
||||
radix_topk<TopK, VEC_SIZE>(
|
||||
row_input, row_output, seq_len, my_chunk_start, chunk_size,
|
||||
local_histogram, suffix_sum, shared_scalars, shared_ordered, state,
|
||||
cta_in_group, ctas_per_group, barrier_phase, iter, tx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1011,7 +1013,6 @@ struct FilteredTopKTraits<float> {
|
||||
}
|
||||
};
|
||||
|
||||
constexpr uint32_t FILTERED_TOPK_MAX_K = 2048;
|
||||
constexpr uint32_t FILTERED_TOPK_BLOCK_THREADS = 1024;
|
||||
constexpr uint32_t FILTERED_TOPK_SMEM_INPUT_SIZE =
|
||||
16 * 1024; // 16K indices per buffer
|
||||
@@ -1025,7 +1026,7 @@ constexpr size_t FILTERED_TOPK_SMEM_DYNAMIC =
|
||||
* \tparam IdType Index type (int32_t)
|
||||
* \tparam VEC_SIZE Vector size for input loads (1, 2, 4, or 8)
|
||||
*/
|
||||
template <typename DType, typename IdType, int VEC_SIZE>
|
||||
template <typename DType, typename IdType, int VEC_SIZE, uint32_t MAX_K = 2048>
|
||||
__global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
|
||||
FilteredTopKUnifiedKernel(const DType* __restrict__ input,
|
||||
IdType* __restrict__ output,
|
||||
@@ -1059,7 +1060,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
|
||||
alignas(128) __shared__ int s_counter;
|
||||
alignas(128) __shared__ int s_threshold_bin_id;
|
||||
alignas(128) __shared__ int s_num_input[2];
|
||||
alignas(128) __shared__ int s_indices[FILTERED_TOPK_MAX_K];
|
||||
alignas(128) __shared__ int s_indices[MAX_K];
|
||||
|
||||
auto& s_histogram = s_histogram_buf[0];
|
||||
|
||||
@@ -1280,7 +1281,7 @@ constexpr int ComputeFilteredTopKVecSize(uint32_t max_len) {
|
||||
return static_cast<int>(g);
|
||||
}
|
||||
|
||||
template <typename DType, typename IdType>
|
||||
template <typename DType, typename IdType, uint32_t MAX_K = 2048>
|
||||
cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices,
|
||||
IdType* lengths, uint32_t num_rows,
|
||||
uint32_t top_k_val, uint32_t max_len,
|
||||
@@ -1297,7 +1298,7 @@ cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices,
|
||||
|
||||
#define DISPATCH_VEC_SIZE(VS) \
|
||||
if (vec_size == VS) { \
|
||||
auto kernel = FilteredTopKUnifiedKernel<DType, IdType, VS>; \
|
||||
auto kernel = FilteredTopKUnifiedKernel<DType, IdType, VS, MAX_K>; \
|
||||
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( \
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \
|
||||
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, \
|
||||
|
||||
@@ -9,28 +9,29 @@ namespace vllm {
|
||||
|
||||
template <typename scalar_t, bool IS_NEOX>
|
||||
inline __device__ void apply_token_rotary_embedding(
|
||||
scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
|
||||
const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
|
||||
scalar_t* __restrict__ arr, const float* __restrict__ cos_ptr,
|
||||
const float* __restrict__ sin_ptr, int rot_offset, int embed_dim,
|
||||
const bool inverse) {
|
||||
int x_index, y_index;
|
||||
scalar_t cos, sin;
|
||||
float cos_f, sin_f;
|
||||
if (IS_NEOX) {
|
||||
// GPT-NeoX style rotary embedding.
|
||||
x_index = rot_offset;
|
||||
y_index = embed_dim + rot_offset;
|
||||
cos = VLLM_LDG(cos_ptr + x_index);
|
||||
sin = VLLM_LDG(sin_ptr + x_index);
|
||||
cos_f = VLLM_LDG(cos_ptr + x_index);
|
||||
sin_f = VLLM_LDG(sin_ptr + x_index);
|
||||
} else {
|
||||
// GPT-J style rotary embedding.
|
||||
x_index = 2 * rot_offset;
|
||||
y_index = 2 * rot_offset + 1;
|
||||
cos = VLLM_LDG(cos_ptr + x_index / 2);
|
||||
sin = VLLM_LDG(sin_ptr + x_index / 2);
|
||||
cos_f = VLLM_LDG(cos_ptr + x_index / 2);
|
||||
sin_f = VLLM_LDG(sin_ptr + x_index / 2);
|
||||
}
|
||||
|
||||
const scalar_t x = arr[x_index];
|
||||
const scalar_t y = arr[y_index];
|
||||
arr[x_index] = x * cos - y * sin;
|
||||
arr[y_index] = y * cos + x * sin;
|
||||
if (inverse) {
|
||||
sin_f = -sin_f;
|
||||
}
|
||||
const float x_f = static_cast<float>(arr[x_index]);
|
||||
const float y_f = static_cast<float>(arr[y_index]);
|
||||
arr[x_index] = static_cast<scalar_t>(x_f * cos_f - y_f * sin_f);
|
||||
arr[y_index] = static_cast<scalar_t>(y_f * cos_f + x_f * sin_f);
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool IS_NEOX>
|
||||
@@ -42,22 +43,23 @@ inline __device__ void apply_rotary_embedding(
|
||||
// [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* cache_ptr, const int head_size, const int num_heads,
|
||||
const float* cache_ptr, const int head_size, const int num_heads,
|
||||
const int num_kv_heads, const int rot_dim, const int token_idx,
|
||||
const int64_t query_stride, const int64_t key_stride,
|
||||
const int64_t head_stride) {
|
||||
const int64_t head_stride, const int64_t rope_dim_offset,
|
||||
const bool inverse) {
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const scalar_t* cos_ptr = cache_ptr;
|
||||
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||
const float* cos_ptr = cache_ptr;
|
||||
const float* sin_ptr = cache_ptr + embed_dim;
|
||||
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head =
|
||||
token_idx * query_stride + head_idx * head_stride;
|
||||
token_idx * query_stride + head_idx * head_stride + rope_dim_offset;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim, inverse);
|
||||
}
|
||||
|
||||
if (key != nullptr) {
|
||||
@@ -65,10 +67,10 @@ inline __device__ void apply_rotary_embedding(
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head =
|
||||
token_idx * key_stride + head_idx * head_stride;
|
||||
token_idx * key_stride + head_idx * head_stride + rope_dim_offset;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim, inverse);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -84,19 +86,18 @@ __global__ void rotary_embedding_kernel(
|
||||
// [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
// 2]
|
||||
const float* __restrict__ cos_sin_cache, // [max_position, rot_dim] fp32
|
||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||
const int64_t head_stride, const int num_heads, const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int head_size, const int64_t rope_dim_offset, const bool inverse) {
|
||||
const int token_idx = blockIdx.x;
|
||||
int64_t pos = positions[token_idx];
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
const float* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
||||
token_idx, query_stride, key_stride, head_stride);
|
||||
token_idx, query_stride, key_stride, head_stride, rope_dim_offset,
|
||||
inverse);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -115,7 +116,7 @@ void rotary_embedding(
|
||||
// [num_tokens, num_heads, head_size]
|
||||
int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
bool is_neox) {
|
||||
bool is_neox, int64_t rope_dim_offset, bool inverse) {
|
||||
// num_tokens = batch_size * seq_len
|
||||
int64_t num_tokens = positions.numel();
|
||||
int positions_ndim = positions.dim();
|
||||
@@ -154,6 +155,8 @@ void rotary_embedding(
|
||||
int seq_dim_idx = positions_ndim - 1;
|
||||
int64_t query_stride = query.stride(seq_dim_idx);
|
||||
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
||||
|
||||
TORCH_CHECK((rot_dim + rope_dim_offset) <= head_size);
|
||||
// Determine head stride: for [*, heads, head_size] use stride of last dim;
|
||||
// for flat [*, heads*head_size], heads blocks are contiguous of size
|
||||
// head_size
|
||||
@@ -165,20 +168,23 @@ void rotary_embedding(
|
||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
auto cache_f32 = cos_sin_cache.to(torch::kFloat32);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
|
||||
if (is_neox) {
|
||||
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
|
||||
head_stride, num_heads, num_kv_heads, head_size);
|
||||
cache_f32.data_ptr<float>(), rot_dim, query_stride, key_stride,
|
||||
head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset,
|
||||
inverse);
|
||||
} else {
|
||||
vllm::rotary_embedding_kernel<scalar_t, false>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
|
||||
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
||||
cache_f32.data_ptr<float>(), rot_dim, query_stride, key_stride,
|
||||
head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset,
|
||||
inverse);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
+7
-1
@@ -258,7 +258,13 @@ __device__ bool processHistogramStep(
|
||||
auto processBins = [&](float logit, int idx) {
|
||||
if (isPartialMatch<patternShift>(logit, logitPattern)) {
|
||||
uint32_t binIdx = extractBinIdx<step>(logit);
|
||||
if (binIdx < thresholdBinIdx) {
|
||||
// Only write elements with binIdx < thresholdBinIdx when:
|
||||
// 1. This is step 0 and the threshold bin is small enough (no step 1)
|
||||
// 2. This is step >= 1 (where pattern matching filters correctly)
|
||||
// This prevents duplicates when step 0 and step 1 both run.
|
||||
bool shouldWriteDirectly =
|
||||
(step == 0 && smemFinalBinSize[0] <= kNumFinalItems) || (step >= 1);
|
||||
if (binIdx < thresholdBinIdx && shouldWriteDirectly) {
|
||||
// The element is part of the top-k selection
|
||||
int dstIdx = atomicAdd(&smemFoundTopKValues[0], 1);
|
||||
|
||||
|
||||
+59
-35
@@ -10,33 +10,17 @@
|
||||
#include "persistent_topk.cuh"
|
||||
#endif
|
||||
|
||||
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
||||
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
|
||||
int64_t max_seq_len) {
|
||||
namespace {
|
||||
|
||||
#ifndef USE_ROCM
|
||||
TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
|
||||
TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
|
||||
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
|
||||
TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported");
|
||||
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
|
||||
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
|
||||
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
|
||||
TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
|
||||
"lengths must be 1D or 2D");
|
||||
TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
|
||||
TORCH_CHECK(output.dim() == 2, "output must be 2D");
|
||||
template <int TopK>
|
||||
void launch_persistent_topk(const torch::Tensor& logits,
|
||||
const torch::Tensor& lengths, torch::Tensor& output,
|
||||
torch::Tensor& workspace, int64_t max_seq_len) {
|
||||
namespace P = vllm::persistent;
|
||||
|
||||
const int64_t num_rows = logits.size(0);
|
||||
const int64_t stride = logits.size(1);
|
||||
|
||||
TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
|
||||
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
|
||||
"output size mismatch");
|
||||
namespace P = vllm::persistent;
|
||||
|
||||
TORCH_CHECK(k == P::TopK, "k must be 2048");
|
||||
TORCH_CHECK(k <= stride, "k out of range");
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
static int num_sms = 0;
|
||||
@@ -50,18 +34,17 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
||||
}
|
||||
|
||||
if (num_rows > 32 && max_smem_per_block >= 128 * 1024) {
|
||||
cudaError_t status = vllm::FilteredTopKRaggedTransform<float, int32_t>(
|
||||
logits.data_ptr<float>(), output.data_ptr<int32_t>(),
|
||||
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
|
||||
static_cast<uint32_t>(k), static_cast<uint32_t>(stride), stream);
|
||||
cudaError_t status =
|
||||
vllm::FilteredTopKRaggedTransform<float, int32_t, TopK>(
|
||||
logits.data_ptr<float>(), output.data_ptr<int32_t>(),
|
||||
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
|
||||
static_cast<uint32_t>(TopK), static_cast<uint32_t>(stride), stream);
|
||||
TORCH_CHECK(status == cudaSuccess,
|
||||
"FilteredTopK failed: ", cudaGetErrorString(status));
|
||||
} else {
|
||||
TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor");
|
||||
TORCH_CHECK(workspace.dtype() == torch::kUInt8, "workspace must be uint8");
|
||||
|
||||
// Smem cap: smaller smem → more CTAs/group → more per-row parallelism for
|
||||
// large path. Empirically tuned.
|
||||
int effective_max_smem;
|
||||
if (num_rows <= 4) {
|
||||
effective_max_smem =
|
||||
@@ -101,7 +84,7 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
||||
|
||||
int occupancy = 1;
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&occupancy, P::persistent_topk_kernel<4>, P::kThreadsPerBlock,
|
||||
&occupancy, P::persistent_topk_kernel<TopK, 4>, P::kThreadsPerBlock,
|
||||
smem_size);
|
||||
if (occupancy < 1) occupancy = 1;
|
||||
|
||||
@@ -121,15 +104,16 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
||||
params.lengths = lengths.data_ptr<int32_t>();
|
||||
params.num_rows = static_cast<uint32_t>(num_rows);
|
||||
params.stride = static_cast<uint32_t>(stride);
|
||||
params.top_k = static_cast<uint32_t>(TopK);
|
||||
params.chunk_size = chunk_size;
|
||||
params.row_states =
|
||||
reinterpret_cast<P::RadixRowState*>(workspace.data_ptr<uint8_t>());
|
||||
params.ctas_per_group = ctas_per_group;
|
||||
params.max_seq_len = static_cast<uint32_t>(max_seq_len);
|
||||
|
||||
#define LAUNCH_PERSISTENT(VS) \
|
||||
#define LAUNCH_PERSISTENT(TOPK_VAL, VS) \
|
||||
do { \
|
||||
auto kernel = &P::persistent_topk_kernel<VS>; \
|
||||
auto kernel = &P::persistent_topk_kernel<TOPK_VAL, VS>; \
|
||||
cudaError_t err = cudaFuncSetAttribute( \
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \
|
||||
TORCH_CHECK(err == cudaSuccess, \
|
||||
@@ -138,11 +122,11 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
||||
} while (0)
|
||||
|
||||
if (vec_size == 4) {
|
||||
LAUNCH_PERSISTENT(4);
|
||||
LAUNCH_PERSISTENT(TopK, 4);
|
||||
} else if (vec_size == 2) {
|
||||
LAUNCH_PERSISTENT(2);
|
||||
LAUNCH_PERSISTENT(TopK, 2);
|
||||
} else {
|
||||
LAUNCH_PERSISTENT(1);
|
||||
LAUNCH_PERSISTENT(TopK, 1);
|
||||
}
|
||||
#undef LAUNCH_PERSISTENT
|
||||
}
|
||||
@@ -150,6 +134,46 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
||||
cudaError_t err = cudaGetLastError();
|
||||
TORCH_CHECK(err == cudaSuccess,
|
||||
"persistent_topk failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
#endif
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
||||
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
|
||||
int64_t max_seq_len) {
|
||||
#ifndef USE_ROCM
|
||||
TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
|
||||
TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
|
||||
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
|
||||
TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported");
|
||||
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
|
||||
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
|
||||
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
|
||||
TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
|
||||
"lengths must be 1D or 2D");
|
||||
TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
|
||||
TORCH_CHECK(output.dim() == 2, "output must be 2D");
|
||||
|
||||
const int64_t num_rows = logits.size(0);
|
||||
const int64_t stride = logits.size(1);
|
||||
|
||||
TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
|
||||
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
|
||||
"output size mismatch");
|
||||
TORCH_CHECK(k == 512 || k == 1024 || k == 2048,
|
||||
"persistent_topk supports k=512, k=1024, or k=2048, got k=", k);
|
||||
|
||||
if (k == 512) {
|
||||
launch_persistent_topk<512>(logits, lengths, output, workspace,
|
||||
max_seq_len);
|
||||
} else if (k == 1024) {
|
||||
launch_persistent_topk<1024>(logits, lengths, output, workspace,
|
||||
max_seq_len);
|
||||
} else {
|
||||
launch_persistent_topk<2048>(logits, lengths, output, workspace,
|
||||
max_seq_len);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "persistent_topk is not supported on ROCm");
|
||||
#endif
|
||||
|
||||
+15
-1
@@ -177,6 +177,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"int forced_token_heads_per_warp=-1) -> ()");
|
||||
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
|
||||
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
|
||||
// kernel launch.
|
||||
ops.def(
|
||||
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
|
||||
"Tensor! q, Tensor kv, Tensor! k_cache, "
|
||||
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
|
||||
"float eps, int cache_block_size) -> ()");
|
||||
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
|
||||
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);
|
||||
#endif
|
||||
|
||||
// Apply repetition penalties to logits in-place
|
||||
ops.def(
|
||||
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
|
||||
@@ -240,7 +253,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def(
|
||||
"rotary_embedding(Tensor positions, Tensor! query,"
|
||||
" Tensor!? key, int head_size,"
|
||||
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
||||
" Tensor cos_sin_cache, bool is_neox, int "
|
||||
"rope_dim_offset=0, bool inverse=False) -> ()");
|
||||
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
|
||||
|
||||
// Quantization ops
|
||||
|
||||
@@ -213,7 +213,7 @@ configuration.
|
||||
| `FLASHINFER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
|
||||
| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
|
||||
| `FLASHMLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
|
||||
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
|
||||
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 512, 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
|
||||
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
|
||||
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
||||
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
|
||||
|
||||
@@ -384,6 +384,7 @@ th {
|
||||
| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | ✅︎ | ✅︎ |
|
||||
| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ |
|
||||
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ |
|
||||
| `DeepseekV4ForCausalLM` | DeepSeek-V4 | `deepseek-ai/DeepSeek-V4-Flash`, `deepseek-ai/DeepSeek-V4-Pro`, etc. | | |
|
||||
| `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ |
|
||||
| `DotsOCRForCausalLM` | dots_ocr | `rednote-hilab/dots.ocr` | ✅︎ | ✅︎ |
|
||||
| `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ |
|
||||
@@ -643,10 +644,10 @@ Some models are supported only via the [Transformers modeling backend](#transfor
|
||||
!!! note
|
||||
`Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its
|
||||
MobileNet-v5 vision backbone.
|
||||
|
||||
|
||||
Performance is not yet fully optimized mainly due to:
|
||||
|
||||
- Both audio and vision MM encoders use `transformers.AutoModel` implementation.
|
||||
|
||||
- Both audio and vision MM encoders use `transformers.AutoModel` implementation.
|
||||
- There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups.
|
||||
|
||||
!!! note
|
||||
|
||||
@@ -11,6 +11,8 @@ torchvision==0.26.0 # Required for phi3v processor. See https://github.com/pytor
|
||||
# FlashInfer should be updated together with the Dockerfile
|
||||
flashinfer-python==0.6.8.post1
|
||||
flashinfer-cubin==0.6.8.post1
|
||||
apache-tvm-ffi==0.1.9
|
||||
tilelang==0.1.9
|
||||
# Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to
|
||||
# breaking changes in 1.19.0
|
||||
nvidia-cudnn-frontend>=1.13.0,<1.19.0
|
||||
|
||||
@@ -116,6 +116,11 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
model_kwargs["attention_config"] = {"backend": attn_backend.backend.name}
|
||||
model_kwargs["tensor_parallel_size"] = tp_size
|
||||
|
||||
# Cap warmup memory: tests use small max_model_len (1024) but the
|
||||
# engine default max_num_batched_tokens is 16384. Warming up large
|
||||
# models (e.g. Llama-4-Scout-FP8) at 16384 tokens may trigger OOM.
|
||||
model_kwargs.setdefault("max_num_batched_tokens", 8192)
|
||||
|
||||
# Sparse MLA models (DSv3.2) hit an over-strict inductor assertion in
|
||||
# decompose_auto_functionalized when +rotary_embedding is forced into
|
||||
# the compile graph. Disable qk_norm+rope fusion (which auto-enables
|
||||
|
||||
@@ -9,8 +9,8 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
_ceil_to_ue8m0,
|
||||
calc_diff,
|
||||
fp8_mqa_logits,
|
||||
fp8_paged_mqa_logits,
|
||||
fp8_fp4_mqa_logits,
|
||||
fp8_fp4_paged_mqa_logits,
|
||||
get_num_sms,
|
||||
get_paged_mqa_logits_metadata,
|
||||
)
|
||||
@@ -127,8 +127,8 @@ def test_deepgemm_fp8_mqa_logits(clean_logits: bool):
|
||||
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False)
|
||||
logits = fp8_mqa_logits(
|
||||
q_fp8, kv_fp8, weights, ks, ke, clean_logits=clean_logits
|
||||
logits = fp8_fp4_mqa_logits(
|
||||
(q_fp8, None), kv_fp8, weights, ks, ke, clean_logits=clean_logits
|
||||
)
|
||||
|
||||
ref_logits = _ref_fp8_mqa_logits(
|
||||
@@ -150,7 +150,7 @@ def test_deepgemm_fp8_mqa_logits(clean_logits: bool):
|
||||
assert diff < 1e-3, f"{diff=}"
|
||||
|
||||
|
||||
def _ref_fp8_paged_mqa_logits(
|
||||
def _ref_fp8_fp4_paged_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
@@ -205,8 +205,10 @@ def _ref_fp8_paged_mqa_logits(
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
|
||||
)
|
||||
@pytest.mark.parametrize("clean_logits", [True, False])
|
||||
def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool):
|
||||
def test_deepgemm_fp8_fp4_paged_mqa_logits():
|
||||
# NOTE: clean_logits=True is incompatible with the 2D context_lens
|
||||
# required by csrc/apis/attention.hpp; only the False path is exercised.
|
||||
clean_logits = False
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
@@ -258,21 +260,29 @@ def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool):
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
|
||||
|
||||
# deep_gemm paged MQA logits requires 2D context_lens of
|
||||
# shape (B, next_n) (csrc/apis/attention.hpp:332-335);
|
||||
# see indexer.py:607-608. For each batch/next_n token, the
|
||||
# effective context length is context_lens[b] - next_n + j + 1.
|
||||
next_n_arange = torch.arange(next_n, device="cuda", dtype=torch.int32)
|
||||
context_lens_2d = (
|
||||
context_lens.unsqueeze(-1) - next_n + 1 + next_n_arange
|
||||
).contiguous()
|
||||
schedule_metadata = get_paged_mqa_logits_metadata(
|
||||
context_lens, blocksize, get_num_sms()
|
||||
context_lens_2d, blocksize, get_num_sms()
|
||||
)
|
||||
logits = fp8_paged_mqa_logits(
|
||||
q_fp8,
|
||||
logits = fp8_fp4_paged_mqa_logits(
|
||||
(q_fp8, None),
|
||||
kv_cache_fp8,
|
||||
weights,
|
||||
context_lens,
|
||||
context_lens_2d,
|
||||
block_tables,
|
||||
schedule_metadata,
|
||||
max_model_len,
|
||||
clean_logits=clean_logits,
|
||||
)
|
||||
|
||||
ref_logits = _ref_fp8_paged_mqa_logits(
|
||||
ref_logits = _ref_fp8_fp4_paged_mqa_logits(
|
||||
q,
|
||||
kv_cache,
|
||||
weights,
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Correctness + large-token-count launch tests for fused_q_kv_rmsnorm.
|
||||
|
||||
Before the grid-dim fix the kernel used grid ``(2, num_tokens)``, which hit
|
||||
CUDA's 65535 grid-y cap for ``num_tokens >= 65536`` and failed with
|
||||
``Triton Error [CUDA]: invalid argument`` at every large chunked-prefill
|
||||
profile run. These tests pin the new grid layout.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.ops.deepseek_v4_ops import fused_q_kv_rmsnorm
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="fused_q_kv_rmsnorm requires a CUDA/ROCm device",
|
||||
)
|
||||
|
||||
|
||||
def _ref_rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor:
|
||||
x_f32 = x.to(torch.float32)
|
||||
variance = x_f32.pow(2).mean(dim=-1, keepdim=True)
|
||||
y = x_f32 * torch.rsqrt(variance + eps) * w.to(torch.float32)
|
||||
return y.to(x.dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 17, 1024, 8192])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
def test_fused_q_kv_rmsnorm_correctness(num_tokens: int, dtype: torch.dtype):
|
||||
torch.manual_seed(0)
|
||||
device = "cuda"
|
||||
q_size, kv_size = 192, 576
|
||||
qr = torch.randn(num_tokens, q_size, dtype=dtype, device=device)
|
||||
kv = torch.randn(num_tokens, kv_size, dtype=dtype, device=device)
|
||||
qw = torch.randn(q_size, dtype=dtype, device=device)
|
||||
kvw = torch.randn(kv_size, dtype=dtype, device=device)
|
||||
eps = 1e-6
|
||||
|
||||
qr_out, kv_out = fused_q_kv_rmsnorm(qr, kv, qw, kvw, eps)
|
||||
|
||||
qr_ref = _ref_rmsnorm(qr, qw, eps)
|
||||
kv_ref = _ref_rmsnorm(kv, kvw, eps)
|
||||
|
||||
tol = dict(rtol=1e-2, atol=1e-2)
|
||||
torch.testing.assert_close(qr_out, qr_ref, **tol)
|
||||
torch.testing.assert_close(kv_out, kv_ref, **tol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [65535, 65536, 131072])
|
||||
def test_fused_q_kv_rmsnorm_launches_past_grid_y_cap(num_tokens: int):
|
||||
"""Regression guard: grid used to be (2, num_tokens), hitting CUDA's
|
||||
65535 grid-y cap at num_tokens >= 65536. The new grid (num_tokens, 2)
|
||||
lifts that bound to 2**31-1."""
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
q_size, kv_size = 192, 576
|
||||
qr = torch.randn(num_tokens, q_size, dtype=dtype, device=device)
|
||||
kv = torch.randn(num_tokens, kv_size, dtype=dtype, device=device)
|
||||
qw = torch.randn(q_size, dtype=dtype, device=device)
|
||||
kvw = torch.randn(kv_size, dtype=dtype, device=device)
|
||||
|
||||
qr_out, kv_out = fused_q_kv_rmsnorm(qr, kv, qw, kvw, 1e-6)
|
||||
# spot-check a couple of rows against the torch reference
|
||||
for row in (0, num_tokens // 2, num_tokens - 1):
|
||||
torch.testing.assert_close(
|
||||
qr_out[row],
|
||||
_ref_rmsnorm(qr[row : row + 1], qw, 1e-6)[0],
|
||||
rtol=1e-2,
|
||||
atol=1e-2,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
kv_out[row],
|
||||
_ref_rmsnorm(kv[row : row + 1], kvw, 1e-6)[0],
|
||||
rtol=1e-2,
|
||||
atol=1e-2,
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Unit-test DeepGEMM FP8 kernels (no DeepEP).
|
||||
Unit-test DeepGEMM FP8 and FP4 kernels (no DeepEP).
|
||||
Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts.
|
||||
"""
|
||||
|
||||
@@ -21,6 +21,8 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
FusedMoEQuantDesc,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
@@ -204,3 +206,195 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_i
|
||||
f"DeepGEMM path was not executed during the test. "
|
||||
f"Call counter: {call_counter['cnt']}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FP4 weight tests (DeepGEMM m_grouped_fp8_fp4_gemm_nt_contiguous)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_mxfp4_weights(
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
):
|
||||
"""
|
||||
Generate (w1, w2) expert weights in MXFP4 packed format with float32 scales,
|
||||
plus BF16 reference weights for validation.
|
||||
|
||||
w1 shape: (E, 2N, K//2) uint8 — packed FP4
|
||||
w2 shape: (E, K, N//2) uint8 — packed FP4
|
||||
w1_s shape: (E, 2N, K//32) float32 — per-row block-32 scales
|
||||
w2_s shape: (E, K, N//32) float32 — per-row block-32 scales
|
||||
w1_bf16: (E, 2N, K) — original BF16 for reference
|
||||
w2_bf16: (E, K, N) — original BF16 for reference
|
||||
"""
|
||||
from deep_gemm.utils.math import per_token_cast_to_fp4
|
||||
|
||||
dtype = torch.bfloat16
|
||||
gran_k = 32 # MXFP4 block size
|
||||
|
||||
# bf16 reference weights — scale by 1/sqrt(dim) for numerical stability
|
||||
w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) * (k**-0.5)
|
||||
w2_bf16 = torch.randn(e, k, n, device="cuda", dtype=dtype) * (n**-0.5)
|
||||
|
||||
# Quantize per-expert to FP4
|
||||
w1 = torch.empty(e, 2 * n, k // 2, device="cuda", dtype=torch.uint8)
|
||||
w2 = torch.empty(e, k, n // 2, device="cuda", dtype=torch.uint8)
|
||||
w1_s = torch.empty(
|
||||
e, 2 * n, math.ceil(k / gran_k), device="cuda", dtype=torch.float32
|
||||
)
|
||||
w2_s = torch.empty(e, k, math.ceil(n / gran_k), device="cuda", dtype=torch.float32)
|
||||
|
||||
for i in range(e):
|
||||
w1[i], w1_s[i] = per_token_cast_to_fp4(
|
||||
w1_bf16[i].float(), use_ue8m0=True, gran_k=gran_k
|
||||
)
|
||||
w2[i], w2_s[i] = per_token_cast_to_fp4(
|
||||
w2_bf16[i].float(), use_ue8m0=True, gran_k=gran_k
|
||||
)
|
||||
|
||||
return w1, w2, w1_s, w2_s, w1_bf16, w2_bf16
|
||||
|
||||
|
||||
def _bf16_moe_reference(x, w1, w2, topk_weights, topk_ids):
|
||||
"""BF16 token-loop MoE reference for correctness testing."""
|
||||
import torch.nn.functional as F
|
||||
|
||||
num_tokens, hidden_size = x.shape
|
||||
intermediate = w1.shape[1] // 2
|
||||
top_k = topk_ids.shape[1]
|
||||
|
||||
output = torch.zeros(num_tokens, hidden_size, dtype=torch.float32, device=x.device)
|
||||
for t in range(num_tokens):
|
||||
for kk in range(top_k):
|
||||
e = topk_ids[t, kk].item()
|
||||
w = topk_weights[t, kk].item()
|
||||
fc1 = x[t : t + 1].float() @ w1[e].float().T
|
||||
linear = fc1[:, :intermediate]
|
||||
gate = fc1[:, intermediate:]
|
||||
act = F.silu(gate) * linear
|
||||
fc2 = act @ w2[e].float().T
|
||||
output[t] += w * fc2[0]
|
||||
return output.to(torch.bfloat16)
|
||||
|
||||
|
||||
def run_single_fp4_case(m, n, k, topk, num_experts):
|
||||
"""
|
||||
Run one (M,N,K) configuration with FP4 weights on DeepGEMM and assert
|
||||
DeepGEMM FP4 == BF16 reference within tolerance.
|
||||
"""
|
||||
tokens_bf16 = torch.randn(m, k, device="cuda", dtype=torch.bfloat16) * (k**-0.5)
|
||||
|
||||
# FP4 expert weight tensors + BF16 originals for reference
|
||||
w1, w2, w1_s, w2_s, w1_bf16, w2_bf16 = make_mxfp4_weights(num_experts, n, k)
|
||||
|
||||
router_logits = torch.randn(m, num_experts, device="cuda", dtype=torch.float32)
|
||||
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
|
||||
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
_fp8_dtype = current_platform.fp8_dtype()
|
||||
_block_shape = GroupShape(128, 128)
|
||||
quant_config = FusedMoEQuantConfig(
|
||||
_a1=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None),
|
||||
_a2=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None),
|
||||
_w1=FusedMoEQuantDesc("mxfp4", None, w1_s, None, None, None),
|
||||
_w2=FusedMoEQuantDesc("mxfp4", None, w2_s, None, None, None),
|
||||
)
|
||||
moe_config = make_dummy_moe_config()
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.experts.deep_gemm_moe import (
|
||||
DeepGemmFP4Experts,
|
||||
)
|
||||
|
||||
deep_gemm_fp4_experts = mk.FusedMoEKernel(
|
||||
prepare_finalize=maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
fused_experts=DeepGemmFP4Experts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
# DeepGEMM FP4 path
|
||||
out_deepgemm_fp4 = deep_gemm_fp4_experts.apply(
|
||||
hidden_states=tokens_bf16,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
global_num_experts=num_experts,
|
||||
activation=MoEActivation.SILU,
|
||||
apply_router_weight_on_input=False,
|
||||
expert_map=None,
|
||||
)
|
||||
|
||||
# BF16 reference using the same original weights
|
||||
out_ref = _bf16_moe_reference(tokens_bf16, w1_bf16, w2_bf16, topk_weights, topk_ids)
|
||||
|
||||
# FP4 vs BF16 reference: quantization error from FP4 weights + FP8 activations
|
||||
diff = calc_diff(out_deepgemm_fp4, out_ref)
|
||||
assert diff < 0.05, f"FP4 diff exceeded 5%: {diff}"
|
||||
|
||||
|
||||
# DeepSeek V4 dims: H=4096, I=2048, so N=2*I=4096, K=H=4096.
|
||||
# FP4 quantization with block_k=32 needs large K for good accuracy.
|
||||
FP4_MNKs = [
|
||||
(128, 4096, 4096), # DeepSeek V4 shape
|
||||
(256, 2048, 2048), # Half-size variant
|
||||
]
|
||||
|
||||
FP4_TOPKS = [2]
|
||||
FP4_NUM_EXPERTS = [8]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("m", "n", "k"), FP4_MNKs)
|
||||
@pytest.mark.parametrize("topk", FP4_TOPKS)
|
||||
@pytest.mark.parametrize("num_experts", FP4_NUM_EXPERTS)
|
||||
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
|
||||
def test_deepgemm_fp4_vs_triton(
|
||||
m, n, k, topk, num_experts, monkeypatch, workspace_init
|
||||
):
|
||||
pytest.importorskip("deep_gemm.utils.math")
|
||||
with monkeypatch.context() as mp:
|
||||
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||
|
||||
_DeepGemmFP4Experts = importlib.import_module(
|
||||
"vllm.model_executor.layers.fused_moe.experts.deep_gemm_moe"
|
||||
).DeepGemmFP4Experts
|
||||
|
||||
call_counter = {"cnt": 0}
|
||||
|
||||
orig_fn = _DeepGemmFP4Experts.apply
|
||||
|
||||
def _spy_apply(*args, **kwargs):
|
||||
call_counter["cnt"] += 1
|
||||
return orig_fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(_DeepGemmFP4Experts, "apply", _spy_apply)
|
||||
if topk > num_experts:
|
||||
pytest.skip(f"topk={topk} > num_experts={num_experts}")
|
||||
|
||||
run_single_fp4_case(
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
topk=topk,
|
||||
num_experts=num_experts,
|
||||
)
|
||||
|
||||
# ensure that the DeepGEMM FP4 path was indeed taken.
|
||||
assert call_counter["cnt"] == 1, (
|
||||
f"DeepGEMM FP4 path was not executed during the test. "
|
||||
f"Call counter: {call_counter['cnt']}"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
RoutingMethodType,
|
||||
get_routing_method_type,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
|
||||
fused_topk_bias,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def _torch_topk_softplus_sqrt(
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
routed_scaling_factor: float,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
hash_indices_table: torch.Tensor | None = None,
|
||||
):
|
||||
scores = F.softplus(gating_output.float()).sqrt()
|
||||
original_scores = scores
|
||||
if e_score_correction_bias is not None:
|
||||
scores_for_choice = scores + e_score_correction_bias.unsqueeze(0)
|
||||
else:
|
||||
scores_for_choice = scores
|
||||
|
||||
if hash_indices_table is not None:
|
||||
assert input_ids is not None
|
||||
topk_ids = hash_indices_table[input_ids.long()]
|
||||
else:
|
||||
topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=True)[1]
|
||||
|
||||
topk_weights = original_scores.gather(1, topk_ids.long())
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
if routed_scaling_factor != 1.0:
|
||||
topk_weights = topk_weights * routed_scaling_factor
|
||||
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
|
||||
|
||||
def test_sqrtsoftplus_bias_uses_deepseek_v4_routing_method():
|
||||
assert (
|
||||
get_routing_method_type(
|
||||
scoring_func="sqrtsoftplus",
|
||||
top_k=8,
|
||||
renormalize=True,
|
||||
num_expert_group=None,
|
||||
has_e_score_bias=True,
|
||||
)
|
||||
== RoutingMethodType.DeepseekV4
|
||||
)
|
||||
assert (
|
||||
get_routing_method_type(
|
||||
scoring_func="sqrtsoftplus",
|
||||
top_k=8,
|
||||
renormalize=False,
|
||||
num_expert_group=None,
|
||||
has_e_score_bias=True,
|
||||
)
|
||||
== RoutingMethodType.Unspecified
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
|
||||
)
|
||||
@pytest.mark.parametrize("num_tokens", [1, 33, 128])
|
||||
@pytest.mark.parametrize("hidden_size", [1024, 2048])
|
||||
@pytest.mark.parametrize("num_experts", [128, 256, 384, 512])
|
||||
@pytest.mark.parametrize("topk", [6, 8, 16])
|
||||
@pytest.mark.parametrize("renormalize", [True, False])
|
||||
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 1.5])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
|
||||
def test_fused_topk_softplus_sqrt(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
routed_scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda")
|
||||
gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")
|
||||
e_score_correction_bias = torch.randn(
|
||||
(num_experts,), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
|
||||
topk_weights_ref, topk_ids_ref = _torch_topk_softplus_sqrt(
|
||||
gating_output=gating_output,
|
||||
topk=topk,
|
||||
renormalize=renormalize,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
topk_weights, topk_ids = fused_topk_bias(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=gating_output,
|
||||
scoring_func="sqrtsoftplus",
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
topk=topk,
|
||||
renormalize=renormalize,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
# Different kernels may return the topk experts in different orders when
|
||||
# scores tie; sort by expert id before comparing.
|
||||
sorted_ref_ids, idx_ref = topk_ids_ref.sort(dim=-1)
|
||||
sorted_ids, idx_ops = topk_ids.sort(dim=-1)
|
||||
torch.testing.assert_close(sorted_ref_ids, sorted_ids, atol=0, rtol=0)
|
||||
|
||||
sorted_w_ref = topk_weights_ref.gather(1, idx_ref)
|
||||
sorted_w = topk_weights.gather(1, idx_ops)
|
||||
torch.testing.assert_close(sorted_w_ref, sorted_w, atol=2e-2, rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
|
||||
)
|
||||
@pytest.mark.parametrize("num_tokens", [1, 33, 128])
|
||||
@pytest.mark.parametrize("hidden_size", [1024, 2048])
|
||||
@pytest.mark.parametrize("num_experts", [256, 384, 512])
|
||||
@pytest.mark.parametrize("topk", [6, 8, 16])
|
||||
@pytest.mark.parametrize("renormalize", [True, False])
|
||||
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
|
||||
def test_fused_topk_softplus_sqrt_hash(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
routed_scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
vocab_size = 1024
|
||||
hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda")
|
||||
gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")
|
||||
# Per-token fixed expert selection: for each vocab id pick `topk` distinct
|
||||
# experts.
|
||||
hash_indices_table = torch.stack(
|
||||
[torch.randperm(num_experts)[:topk] for _ in range(vocab_size)]
|
||||
).to(device="cuda", dtype=torch.int32)
|
||||
input_ids = torch.randint(
|
||||
0, vocab_size, (num_tokens,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
topk_weights_ref, topk_ids_ref = _torch_topk_softplus_sqrt(
|
||||
gating_output=gating_output,
|
||||
topk=topk,
|
||||
renormalize=renormalize,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
input_ids=input_ids,
|
||||
hash_indices_table=hash_indices_table,
|
||||
)
|
||||
|
||||
topk_weights, topk_ids = fused_topk_bias(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=gating_output,
|
||||
scoring_func="sqrtsoftplus",
|
||||
e_score_correction_bias=None,
|
||||
topk=topk,
|
||||
renormalize=renormalize,
|
||||
input_tokens=input_ids,
|
||||
hash_indices_table=hash_indices_table,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
sorted_ref_ids, idx_ref = topk_ids_ref.sort(dim=-1)
|
||||
sorted_ids, idx_ops = topk_ids.sort(dim=-1)
|
||||
torch.testing.assert_close(sorted_ref_ids, sorted_ids, atol=0, rtol=0)
|
||||
|
||||
sorted_w_ref = topk_weights_ref.gather(1, idx_ref)
|
||||
sorted_w = topk_weights.gather(1, idx_ops)
|
||||
torch.testing.assert_close(sorted_w_ref, sorted_w, atol=2e-2, rtol=1e-2)
|
||||
@@ -0,0 +1,311 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Round-trip tests for compressor → FP8 quant + KV cache insert → gather + dequant.
|
||||
|
||||
Two paths tested:
|
||||
A) DeepseekV4 Attention: head_dim=512 (448 FP8 nope + 64 bf16 rope), quant_block=64
|
||||
B) Indexer: head_dim=128 (all FP8), quant_block=128
|
||||
|
||||
These serve as golden references for validating the future fused
|
||||
compressor+quant+cache kernel.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.v1.attention.ops.deepseek_v4_ops import (
|
||||
dequantize_and_gather_k_cache,
|
||||
quantize_and_insert_k_cache,
|
||||
)
|
||||
|
||||
|
||||
def _ue8m0_reference(x: torch.Tensor, block_size: int, fp8_max: float):
|
||||
"""PyTorch reference for UE8M0 FP8 quantization (per-block, power-of-2 scale).
|
||||
|
||||
Returns (x_fp8, scales) where x_fp8 is float8_e4m3fn and scales are float32.
|
||||
"""
|
||||
assert x.dim() == 1
|
||||
n = x.numel()
|
||||
n_blocks = math.ceil(n / block_size)
|
||||
x_fp8 = torch.zeros(n, dtype=torch.float8_e4m3fn, device=x.device)
|
||||
scales = torch.zeros(n_blocks, dtype=torch.float32, device=x.device)
|
||||
|
||||
for i in range(n_blocks):
|
||||
start = i * block_size
|
||||
end = min(start + block_size, n)
|
||||
block = x[start:end].float()
|
||||
amax = block.abs().max().clamp(min=1e-4)
|
||||
raw_scale = amax / fp8_max
|
||||
exponent = math.ceil(math.log2(raw_scale.item()))
|
||||
scale = 2.0**exponent
|
||||
scales[i] = scale
|
||||
quantized = (block / scale).clamp(-fp8_max, fp8_max)
|
||||
x_fp8[start:end] = quantized.to(torch.float8_e4m3fn)
|
||||
|
||||
return x_fp8, scales
|
||||
|
||||
|
||||
# ── Test A: DeepseekV4 Attention path ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 17])
|
||||
@pytest.mark.parametrize("block_size", [16, 64])
|
||||
def test_deepseek_v4_attention_quant_cache_roundtrip(num_tokens: int, block_size: int):
|
||||
"""compressed_kv → quantize_and_insert_k_cache → dequantize_and_gather_k_cache
|
||||
→ compare against original."""
|
||||
|
||||
HEAD_DIM = 512
|
||||
NOPE_DIM = 448
|
||||
HEAD_BYTES = 584 # 448 fp8 + 128 bf16 + 8 uint8 scale
|
||||
FP8_MAX = 448.0
|
||||
QUANT_BLOCK = 64
|
||||
|
||||
num_blocks = (num_tokens + block_size - 1) // block_size + 1
|
||||
device = "cuda"
|
||||
|
||||
# Random compressed_kv (simulates compressor output)
|
||||
compressed_kv = torch.randn(
|
||||
num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device
|
||||
)
|
||||
|
||||
# ── Quant + insert ──────────────────────────────────────────────────
|
||||
k_cache = torch.zeros(
|
||||
num_blocks, block_size, HEAD_BYTES, dtype=torch.uint8, device=device
|
||||
)
|
||||
k_cache_2d = k_cache.view(num_blocks, -1)
|
||||
|
||||
# Sequential slot mapping: token i → slot i
|
||||
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
|
||||
|
||||
quantize_and_insert_k_cache(
|
||||
compressed_kv, k_cache_2d, slot_mapping, block_size=block_size
|
||||
)
|
||||
|
||||
# ── Gather + dequant ────────────────────────────────────────────────
|
||||
num_reqs = 1
|
||||
max_blocks_per_seq = num_blocks
|
||||
out = torch.zeros(
|
||||
num_reqs, num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device
|
||||
)
|
||||
seq_lens = torch.tensor([num_tokens], dtype=torch.int32, device=device)
|
||||
# block_table: request 0 uses physical blocks 0, 1, ...
|
||||
block_table = torch.arange(
|
||||
max_blocks_per_seq, dtype=torch.int32, device=device
|
||||
).unsqueeze(0)
|
||||
|
||||
dequantize_and_gather_k_cache(
|
||||
out, k_cache, seq_lens, None, block_table, block_size, offset=0
|
||||
)
|
||||
|
||||
recovered = out[0, :num_tokens]
|
||||
|
||||
# ── NoPE portion (first 448): FP8 quantized, expect UE8M0 error ──
|
||||
nope_orig = compressed_kv[:, :NOPE_DIM].float()
|
||||
nope_recv = recovered[:, :NOPE_DIM].float()
|
||||
nope_diff = (nope_recv - nope_orig).abs()
|
||||
|
||||
# Per-token check: FP8 e4m3 (3-bit mantissa) worst-case error is
|
||||
# half-ULP at the largest representable value. At y ≈ 448 (max),
|
||||
# ULP = 2^(8-3) = 32, so error ≤ 16 * scale.
|
||||
for t in range(num_tokens):
|
||||
_, scales = _ue8m0_reference(
|
||||
compressed_kv[t, :NOPE_DIM].float(), QUANT_BLOCK, FP8_MAX
|
||||
)
|
||||
max_allowed = 16.0 * scales.max().item()
|
||||
token_diff = nope_diff[t].max().item()
|
||||
assert token_diff <= max_allowed, (
|
||||
f"Token {t} nope diff {token_diff} exceeds max_allowed "
|
||||
f"{max_allowed} (scale={scales.max().item()})"
|
||||
)
|
||||
|
||||
# ── RoPE portion (last 64): stored as bf16, should be exact ─────
|
||||
rope_diff = (recovered[:, NOPE_DIM:] - compressed_kv[:, NOPE_DIM:]).abs()
|
||||
assert rope_diff.max().item() == 0.0, (
|
||||
f"RoPE portion should be exact but got max diff {rope_diff.max().item()}"
|
||||
)
|
||||
|
||||
|
||||
# ── Test B: Indexer path ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 17])
|
||||
@pytest.mark.parametrize("block_size", [16, 64])
|
||||
def test_indexer_quant_cache_roundtrip(num_tokens: int, block_size: int):
|
||||
"""k → indexer_k_quant_and_cache → cp_gather_indexer_k_quant_cache
|
||||
→ manual dequant → compare against original."""
|
||||
|
||||
HEAD_DIM = 128
|
||||
QUANT_BLOCK_SIZE = 128
|
||||
# cache_stride = head_dim + (head_dim * 4 / quant_block_size) = 128 + 4 = 132
|
||||
CACHE_STRIDE = HEAD_DIM + HEAD_DIM * 4 // QUANT_BLOCK_SIZE
|
||||
|
||||
num_blocks = (num_tokens + block_size - 1) // block_size + 1
|
||||
device = "cuda"
|
||||
|
||||
# Random K (simulates compressor output for indexer)
|
||||
k = torch.randn(num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device)
|
||||
|
||||
# ── Quant + insert ──────────────────────────────────────────────────
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks, block_size, CACHE_STRIDE, dtype=torch.uint8, device=device
|
||||
)
|
||||
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
|
||||
|
||||
ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping, QUANT_BLOCK_SIZE, "ue8m0")
|
||||
|
||||
# ── Gather ──────────────────────────────────────────────────────────
|
||||
max_blocks_per_seq = num_blocks
|
||||
block_table = torch.arange(
|
||||
max_blocks_per_seq, dtype=torch.int32, device=device
|
||||
).unsqueeze(0)
|
||||
cu_seq_lens = torch.tensor([0, num_tokens], dtype=torch.int32, device=device)
|
||||
|
||||
# dst_k: [total_seq_len, head_dim] as uint8 (raw FP8 bytes)
|
||||
dst_k = torch.zeros(num_tokens, HEAD_DIM, dtype=torch.uint8, device=device)
|
||||
# dst_scale: [total_seq_len, head_dim/quant_block*4] as uint8 (raw float32 bytes)
|
||||
num_scale_bytes = HEAD_DIM * 4 // QUANT_BLOCK_SIZE # 4
|
||||
dst_scale = torch.zeros(
|
||||
num_tokens, num_scale_bytes, dtype=torch.uint8, device=device
|
||||
)
|
||||
|
||||
ops.cp_gather_indexer_k_quant_cache(
|
||||
kv_cache, dst_k, dst_scale, block_table, cu_seq_lens
|
||||
)
|
||||
|
||||
# ── Manual dequant ──────────────────────────────────────────────────
|
||||
k_fp8 = dst_k.view(torch.float8_e4m3fn).float() # [num_tokens, 128]
|
||||
scale = dst_scale.view(torch.float32) # [num_tokens, 1]
|
||||
k_recovered = k_fp8 * scale # [num_tokens, 128]
|
||||
|
||||
# ── Compare ─────────────────────────────────────────────────────────
|
||||
diff = (k_recovered - k.float()).abs()
|
||||
k_abs = k.float().abs()
|
||||
|
||||
for t in range(num_tokens):
|
||||
amax = k_abs[t].max().clamp(min=1e-4).item()
|
||||
# UE8M0: scale = 2^ceil(log2(amax / 448))
|
||||
exponent = math.ceil(math.log2(amax / 448.0))
|
||||
ue8m0_scale = 2.0**exponent
|
||||
# FP8 e4m3 (3-bit mantissa): worst-case error = 16 * scale
|
||||
max_allowed = 16.0 * ue8m0_scale
|
||||
token_diff = diff[t].max().item()
|
||||
assert token_diff <= max_allowed, (
|
||||
f"Token {t} diff {token_diff} exceeds max_allowed "
|
||||
f"{max_allowed} (scale={ue8m0_scale})"
|
||||
)
|
||||
|
||||
|
||||
def test_indexer_gather_accepts_upper_bound_output():
|
||||
"""Gather only exact cu_seq_lens even when dst is over-allocated."""
|
||||
|
||||
head_dim = 128
|
||||
quant_block_size = 128
|
||||
cache_stride = head_dim + head_dim * 4 // quant_block_size
|
||||
valid_tokens = 9
|
||||
upper_bound_tokens = 13
|
||||
block_size = 16
|
||||
num_blocks = 2
|
||||
sentinel = 123
|
||||
device = "cuda"
|
||||
|
||||
k = torch.randn(valid_tokens, head_dim, dtype=torch.bfloat16, device=device)
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks, block_size, cache_stride, dtype=torch.uint8, device=device
|
||||
)
|
||||
slot_mapping = torch.arange(valid_tokens, dtype=torch.int64, device=device)
|
||||
ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping, quant_block_size, "ue8m0")
|
||||
|
||||
block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).unsqueeze(
|
||||
0
|
||||
)
|
||||
cu_seq_lens = torch.tensor([0, valid_tokens], dtype=torch.int32, device=device)
|
||||
dst_k = torch.full(
|
||||
(upper_bound_tokens, head_dim), sentinel, dtype=torch.uint8, device=device
|
||||
)
|
||||
num_scale_bytes = head_dim * 4 // quant_block_size
|
||||
dst_scale = torch.full(
|
||||
(upper_bound_tokens, num_scale_bytes),
|
||||
sentinel,
|
||||
dtype=torch.uint8,
|
||||
device=device,
|
||||
)
|
||||
|
||||
ops.cp_gather_indexer_k_quant_cache(
|
||||
kv_cache, dst_k, dst_scale, block_table, cu_seq_lens
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
k_recovered = dst_k[:valid_tokens].view(torch.float8_e4m3fn).float() * dst_scale[
|
||||
:valid_tokens
|
||||
].view(torch.float32)
|
||||
diff = (k_recovered - k.float()).abs()
|
||||
max_allowed = (16.0 * dst_scale[:valid_tokens].view(torch.float32).max()).item()
|
||||
assert diff.max().item() <= max_allowed
|
||||
assert torch.all(dst_k[valid_tokens:] == sentinel)
|
||||
assert torch.all(dst_scale[valid_tokens:] == sentinel)
|
||||
|
||||
|
||||
# ── Test C: DeepseekV4 attention with values at different magnitudes ───────────
|
||||
|
||||
|
||||
def test_deepseek_v4_quant_magnitude_range():
|
||||
"""Test that quantization handles a range of magnitudes correctly."""
|
||||
|
||||
HEAD_DIM = 512
|
||||
NOPE_DIM = 448
|
||||
HEAD_BYTES = 584
|
||||
block_size = 16
|
||||
num_tokens = 4
|
||||
num_blocks = 2
|
||||
device = "cuda"
|
||||
|
||||
# Create inputs with varying magnitudes: small, medium, large
|
||||
compressed_kv = torch.zeros(
|
||||
num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device
|
||||
)
|
||||
compressed_kv[0] = 0.001 # very small
|
||||
compressed_kv[1] = 1.0 # unit scale
|
||||
compressed_kv[2] = 100.0 # large
|
||||
compressed_kv[3] = torch.randn(HEAD_DIM, dtype=torch.bfloat16, device=device)
|
||||
|
||||
k_cache = torch.zeros(
|
||||
num_blocks, block_size, HEAD_BYTES, dtype=torch.uint8, device=device
|
||||
)
|
||||
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
|
||||
|
||||
quantize_and_insert_k_cache(
|
||||
compressed_kv, k_cache.view(num_blocks, -1), slot_mapping, block_size
|
||||
)
|
||||
|
||||
out = torch.zeros(1, num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device)
|
||||
seq_lens = torch.tensor([num_tokens], dtype=torch.int32, device=device)
|
||||
block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).unsqueeze(
|
||||
0
|
||||
)
|
||||
|
||||
dequantize_and_gather_k_cache(
|
||||
out, k_cache, seq_lens, None, block_table, block_size, offset=0
|
||||
)
|
||||
|
||||
recovered = out[0, :num_tokens]
|
||||
|
||||
# RoPE portion must be exact
|
||||
rope_diff = (recovered[:, NOPE_DIM:] - compressed_kv[:, NOPE_DIM:]).abs().max()
|
||||
assert rope_diff.item() == 0.0, f"RoPE diff {rope_diff.item()}"
|
||||
|
||||
# NoPE: relative error should be reasonable
|
||||
for t in range(num_tokens):
|
||||
orig = compressed_kv[t, :NOPE_DIM].float()
|
||||
recv = recovered[t, :NOPE_DIM].float()
|
||||
abs_diff = (recv - orig).abs().max().item()
|
||||
magnitude = orig.abs().max().item()
|
||||
if magnitude > 0.01:
|
||||
rel_err = abs_diff / magnitude
|
||||
assert rel_err < 0.15, (
|
||||
f"Token {t}: rel_err={rel_err:.4f}, abs_diff={abs_diff:.6f}, "
|
||||
f"magnitude={magnitude:.4f}"
|
||||
)
|
||||
@@ -0,0 +1,359 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Standalone unit test for the horizontally-fused DeepseekV4-MLA kernel:
|
||||
|
||||
fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert
|
||||
- Q side: per-head RMSNorm (no weight) + GPT-J RoPE on last 64 dims
|
||||
- KV side: GPT-J RoPE on last 64 + UE8M0 FP8 quant + paged cache insert
|
||||
|
||||
We compare against:
|
||||
- PyTorch reference for RMSNorm + GPT-J RoPE on Q
|
||||
- Existing Triton `quantize_and_insert_k_cache` + round-trip via
|
||||
`dequantize_and_gather_k_cache` for KV
|
||||
|
||||
The kernel is imported via
|
||||
`torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.v1.attention.ops.deepseek_v4_ops import (
|
||||
dequantize_and_gather_k_cache,
|
||||
quantize_and_insert_k_cache,
|
||||
)
|
||||
|
||||
# ── Constants matching the kernel ────────────────────────────────────────────
|
||||
HEAD_DIM = 512
|
||||
ROPE_DIM = 64
|
||||
NOPE_DIM = HEAD_DIM - ROPE_DIM # 448
|
||||
QUANT_BLOCK = 64
|
||||
FP8_MAX = 448.0
|
||||
HEAD_BYTES = NOPE_DIM + ROPE_DIM * 2 + 8 # 448 + 128 + 8 = 584
|
||||
|
||||
|
||||
# ── PyTorch reference implementations ────────────────────────────────────────
|
||||
|
||||
|
||||
def make_cos_sin_cache(max_pos: int, rope_dim: int, dtype, device):
|
||||
"""Build a cos||sin cache matching DeepseekV4ScalingRotaryEmbedding layout.
|
||||
cos_sin_cache[pos, :rope_dim/2] = cos(theta), [rope_dim/2:] = sin(theta).
|
||||
"""
|
||||
base = 10000.0
|
||||
inv_freq = 1.0 / (
|
||||
base
|
||||
** (torch.arange(0, rope_dim, 2, dtype=torch.float32, device=device) / rope_dim)
|
||||
)
|
||||
t = torch.arange(max_pos, dtype=torch.float32, device=device)
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq) # [max_pos, rope_dim/2]
|
||||
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1) # [max_pos, rope_dim]
|
||||
return cache.to(dtype)
|
||||
|
||||
|
||||
def apply_rope_gptj_last_k(
|
||||
x: torch.Tensor, positions: torch.Tensor, cos_sin_cache: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""GPT-J-style (interleaved-pair) RoPE on the LAST rope_dim elements.
|
||||
|
||||
x: [..., head_dim] float32
|
||||
positions: [num_tokens] int64 (positions[i] corresponds to x[i, ...])
|
||||
cos_sin_cache: [max_pos, rope_dim] float (cos|sin layout)
|
||||
|
||||
Returns rotated x (same shape/dtype).
|
||||
"""
|
||||
rope_dim = cos_sin_cache.shape[-1]
|
||||
half = rope_dim // 2
|
||||
head_dim = x.shape[-1]
|
||||
nope_dim = head_dim - rope_dim
|
||||
|
||||
# Gather cos/sin for each token position: [num_tokens, rope_dim]
|
||||
cs = cos_sin_cache[positions].to(torch.float32) # [N, rope_dim]
|
||||
cos = cs[..., :half] # [N, half]
|
||||
sin = cs[..., half:] # [N, half]
|
||||
|
||||
# Reshape leading dims so we can broadcast: x shape [..., head_dim].
|
||||
# Bring token dim to front; assume x is [num_tokens, ..., head_dim].
|
||||
# We rely on positions being per-token and all other dims sharing the same pos.
|
||||
rope = x[..., nope_dim:].float() # [..., rope_dim]
|
||||
# Make rope pairs: reshape last dim to [half, 2]
|
||||
shape = rope.shape
|
||||
rope = rope.reshape(*shape[:-1], half, 2)
|
||||
even = rope[..., 0] # [..., half]
|
||||
odd = rope[..., 1]
|
||||
|
||||
# Broadcast cos/sin over any heads dim in between. cos/sin are [N, half].
|
||||
# Add singleton dims for intermediate axes.
|
||||
for _ in range(rope.ndim - 3):
|
||||
cos = cos.unsqueeze(1)
|
||||
sin = sin.unsqueeze(1)
|
||||
|
||||
new_even = even * cos - odd * sin
|
||||
new_odd = even * sin + odd * cos
|
||||
rope_rotated = torch.stack((new_even, new_odd), dim=-1).reshape(shape)
|
||||
|
||||
out = x.clone().float()
|
||||
out[..., nope_dim:] = rope_rotated
|
||||
return out.to(x.dtype)
|
||||
|
||||
|
||||
def rmsnorm_no_weight(x: torch.Tensor, eps: float) -> torch.Tensor:
|
||||
"""RMSNorm with no learnable weight, matching
|
||||
`RMSNorm(head_dim, has_weight=False)`."""
|
||||
orig_dtype = x.dtype
|
||||
xf = x.float()
|
||||
variance = xf.pow(2).mean(dim=-1, keepdim=True)
|
||||
return (xf * torch.rsqrt(variance + eps)).to(orig_dtype)
|
||||
|
||||
|
||||
# ── Dispatch to the CUDA op (skip test cleanly if it isn't built in) ─────────
|
||||
|
||||
|
||||
def _op_available() -> bool:
|
||||
return hasattr(torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert")
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not torch.cuda.is_available() or not _op_available(),
|
||||
reason="CUDA not available or fused DeepseekV4 op not built in",
|
||||
)
|
||||
|
||||
|
||||
def _call_fused(q, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs):
|
||||
torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
||||
q, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs
|
||||
)
|
||||
|
||||
|
||||
# ── Test 1: Q path numerical parity ──────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 4, 17, 64])
|
||||
@pytest.mark.parametrize("n_heads", [8, 64])
|
||||
def test_q_path_matches_reference(num_tokens: int, n_heads: int):
|
||||
torch.manual_seed(0)
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
eps = 1e-6
|
||||
max_pos = 4096
|
||||
|
||||
q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device)
|
||||
positions = torch.arange(num_tokens, dtype=torch.int64, device=device)
|
||||
cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)
|
||||
|
||||
# Reference: RMSNorm (no weight) per head, then GPT-J RoPE on last 64.
|
||||
q_ref = rmsnorm_no_weight(q, eps)
|
||||
q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache)
|
||||
|
||||
# Fused call with dummy KV tensors (KV branch will write slot_mapping=-1 → noop).
|
||||
num_blocks = 2
|
||||
bs = 16
|
||||
kv = torch.zeros(num_tokens, HEAD_DIM, dtype=dtype, device=device)
|
||||
k_cache = torch.zeros(
|
||||
num_blocks, bs, HEAD_BYTES, dtype=torch.uint8, device=device
|
||||
).view(num_blocks, -1)
|
||||
slot_mapping = torch.full((num_tokens,), -1, dtype=torch.int64, device=device)
|
||||
q_fused = q.clone()
|
||||
_call_fused(q_fused, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs)
|
||||
|
||||
torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
# ── Test 2: KV path round-trip byte/value parity ─────────────────────────────
|
||||
|
||||
|
||||
def _ue8m0_per_block_scales(kv_roped_nope_f32: torch.Tensor, qblock: int):
|
||||
"""Return per-token per-block max scale (used to bound FP8 error)."""
|
||||
n_tok, nope = kv_roped_nope_f32.shape
|
||||
n_blocks = nope // qblock
|
||||
blocks = kv_roped_nope_f32.view(n_tok, n_blocks, qblock)
|
||||
absmax = blocks.abs().amax(dim=-1).clamp(min=1e-4)
|
||||
raw = absmax / FP8_MAX
|
||||
exponent = torch.ceil(torch.log2(raw))
|
||||
return torch.pow(2.0, exponent) # [n_tok, n_blocks]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 4, 17, 64])
|
||||
@pytest.mark.parametrize("block_size", [16, 64])
|
||||
def test_kv_path_matches_reference(num_tokens: int, block_size: int):
|
||||
torch.manual_seed(1)
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
eps = 1e-6
|
||||
max_pos = 4096
|
||||
|
||||
kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device)
|
||||
positions = torch.arange(num_tokens, dtype=torch.int64, device=device)
|
||||
cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)
|
||||
|
||||
num_blocks = (num_tokens + block_size - 1) // block_size + 1
|
||||
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
|
||||
|
||||
# ── Reference path: RoPE on kv, then existing Triton quant+insert ──────
|
||||
kv_ref = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
|
||||
k_cache_ref = torch.zeros(
|
||||
num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device
|
||||
)
|
||||
quantize_and_insert_k_cache(
|
||||
kv_ref, k_cache_ref, slot_mapping, block_size=block_size
|
||||
)
|
||||
|
||||
# ── Fused path (dummy q, single head) ──────────────────────────────────
|
||||
k_cache_fused = torch.zeros_like(k_cache_ref)
|
||||
q_dummy = torch.zeros(num_tokens, 1, HEAD_DIM, dtype=dtype, device=device)
|
||||
_call_fused(
|
||||
q_dummy,
|
||||
kv,
|
||||
k_cache_fused,
|
||||
slot_mapping,
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
block_size,
|
||||
)
|
||||
|
||||
# ── Round-trip compare via dequant+gather ──────────────────────────────
|
||||
def _dequant(k_cache_2d):
|
||||
num_reqs = 1
|
||||
max_blocks = num_blocks
|
||||
out = torch.zeros(
|
||||
num_reqs, num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device
|
||||
)
|
||||
seq_lens = torch.tensor([num_tokens], dtype=torch.int32, device=device)
|
||||
block_table = torch.arange(
|
||||
max_blocks, dtype=torch.int32, device=device
|
||||
).unsqueeze(0)
|
||||
# gather_lens arg is None (use seq_lens)
|
||||
k_cache_3d = k_cache_2d.view(num_blocks, block_size, HEAD_BYTES)
|
||||
dequantize_and_gather_k_cache(
|
||||
out, k_cache_3d, seq_lens, None, block_table, block_size, offset=0
|
||||
)
|
||||
return out[0, :num_tokens]
|
||||
|
||||
recovered_ref = _dequant(k_cache_ref)
|
||||
recovered_fused = _dequant(k_cache_fused)
|
||||
|
||||
# NoPE: per-block UE8M0 FP8 error bound (half-ULP at max = 16 * scale).
|
||||
scales = _ue8m0_per_block_scales(kv_ref[:, :NOPE_DIM].float(), QUANT_BLOCK)
|
||||
for t in range(num_tokens):
|
||||
max_allowed = 16.0 * scales[t].max().item()
|
||||
diff_ref = (
|
||||
(recovered_ref[t, :NOPE_DIM] - kv_ref[t, :NOPE_DIM]).abs().max().item()
|
||||
)
|
||||
diff_fused = (
|
||||
(recovered_fused[t, :NOPE_DIM] - kv_ref[t, :NOPE_DIM]).abs().max().item()
|
||||
)
|
||||
assert diff_ref <= max_allowed, (
|
||||
f"ref NoPE token {t} diff {diff_ref} > {max_allowed}"
|
||||
)
|
||||
assert diff_fused <= max_allowed, (
|
||||
f"fused NoPE token {t} diff {diff_fused} > {max_allowed}"
|
||||
)
|
||||
|
||||
# RoPE region: bf16 stored exactly → zero diff.
|
||||
rope_diff = (recovered_fused[:, NOPE_DIM:] - kv_ref[:, NOPE_DIM:]).abs().max()
|
||||
assert rope_diff.item() == 0.0, f"RoPE portion not exact: {rope_diff.item()}"
|
||||
|
||||
# Exact byte equality of the two cache buffers — strong parity.
|
||||
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
|
||||
|
||||
|
||||
# ── Test 2b: DP padding (slot_mapping shorter than q/kv) ─────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [4, 17])
|
||||
@pytest.mark.parametrize("pad", [1, 5])
|
||||
@pytest.mark.parametrize("block_size", [16, 64])
|
||||
def test_kv_path_with_dp_padding(num_tokens: int, pad: int, block_size: int):
|
||||
"""slot_mapping.size(0) < q.size(0): the kernel must skip padded
|
||||
tokens in the KV branch while still running Q-norm+RoPE on all rows."""
|
||||
torch.manual_seed(3)
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
eps = 1e-6
|
||||
max_pos = 4096
|
||||
total = num_tokens + pad
|
||||
|
||||
kv = torch.randn(total, HEAD_DIM, dtype=dtype, device=device)
|
||||
positions = torch.arange(total, dtype=torch.int64, device=device)
|
||||
cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)
|
||||
|
||||
num_blocks = (num_tokens + block_size - 1) // block_size + 1
|
||||
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
|
||||
|
||||
# Reference: only the first num_tokens kv rows get inserted.
|
||||
kv_ref = apply_rope_gptj_last_k(
|
||||
kv[:num_tokens], positions[:num_tokens], cos_sin_cache
|
||||
)
|
||||
k_cache_ref = torch.zeros(
|
||||
num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device
|
||||
)
|
||||
quantize_and_insert_k_cache(
|
||||
kv_ref, k_cache_ref, slot_mapping, block_size=block_size
|
||||
)
|
||||
|
||||
# Fused: pass full-sized q/kv/positions, shorter slot_mapping.
|
||||
q_dummy = torch.zeros(total, 1, HEAD_DIM, dtype=dtype, device=device)
|
||||
k_cache_fused = torch.zeros_like(k_cache_ref)
|
||||
_call_fused(
|
||||
q_dummy,
|
||||
kv,
|
||||
k_cache_fused,
|
||||
slot_mapping,
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
block_size,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
|
||||
|
||||
|
||||
# ── Test 3: combined single-call Q + KV parity ───────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 4, 17])
|
||||
@pytest.mark.parametrize("n_heads", [8, 64])
|
||||
@pytest.mark.parametrize("block_size", [16, 64])
|
||||
def test_combined_q_and_kv(num_tokens: int, n_heads: int, block_size: int):
|
||||
torch.manual_seed(2)
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
eps = 1e-6
|
||||
max_pos = 4096
|
||||
|
||||
q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device)
|
||||
kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device)
|
||||
positions = torch.arange(num_tokens, dtype=torch.int64, device=device)
|
||||
cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)
|
||||
|
||||
num_blocks = (num_tokens + block_size - 1) // block_size + 1
|
||||
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
|
||||
|
||||
# Reference.
|
||||
q_ref = rmsnorm_no_weight(q, eps)
|
||||
q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache)
|
||||
kv_ref = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
|
||||
k_cache_ref = torch.zeros(
|
||||
num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device
|
||||
)
|
||||
quantize_and_insert_k_cache(
|
||||
kv_ref, k_cache_ref, slot_mapping, block_size=block_size
|
||||
)
|
||||
|
||||
# Fused single call.
|
||||
q_fused = q.clone()
|
||||
k_cache_fused = torch.zeros_like(k_cache_ref)
|
||||
_call_fused(
|
||||
q_fused,
|
||||
kv,
|
||||
k_cache_fused,
|
||||
slot_mapping,
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
block_size,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2)
|
||||
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
|
||||
@@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit test for fused_indexer_q_rope_quant.
|
||||
|
||||
Compares the fused Triton kernel against the unfused reference flow used by
|
||||
the DeepseekV4 indexer in model_tracking:
|
||||
q_rot = ops.rotary_embedding(positions, q, None, head_dim, cos_sin_cache,
|
||||
is_neox_style=False,
|
||||
rope_dim_offset=head_dim - rope_dim)
|
||||
q_fp8, q_scale = per_token_group_quant_fp8(q_rot, head_dim, use_ue8m0=True)
|
||||
weights_out = weights * q_scale * softmax_scale * head_scale
|
||||
|
||||
Expects bit-exact equality on both q_fp8 and weights_out.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import (
|
||||
fused_indexer_q_rope_quant,
|
||||
)
|
||||
|
||||
HEAD_DIM = 128
|
||||
ROPE_DIM = 64
|
||||
N_HEAD = 64
|
||||
MAX_POS = 4096
|
||||
|
||||
|
||||
def _reference(
|
||||
positions: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
head_scale: float,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
q_rot = q.clone()
|
||||
ops.rotary_embedding(
|
||||
positions,
|
||||
q_rot,
|
||||
None,
|
||||
HEAD_DIM,
|
||||
cos_sin_cache,
|
||||
False, # is_neox_style=False → GPT-J interleaved
|
||||
HEAD_DIM - ROPE_DIM, # rope_dim_offset → rotate the tail
|
||||
False,
|
||||
)
|
||||
q_fp8, q_scale = per_token_group_quant_fp8(
|
||||
q_rot.view(-1, HEAD_DIM).contiguous(),
|
||||
HEAD_DIM,
|
||||
use_ue8m0=True,
|
||||
)
|
||||
q_fp8 = q_fp8.view(-1, N_HEAD, HEAD_DIM)
|
||||
q_scale = q_scale.view(-1, N_HEAD)
|
||||
|
||||
weights_out = weights.to(torch.float32) * q_scale * softmax_scale * head_scale
|
||||
return q_fp8, weights_out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 7, 32, 257])
|
||||
@pytest.mark.parametrize("cache_dtype", [torch.float32, torch.bfloat16])
|
||||
@torch.inference_mode()
|
||||
def test_fused_indexer_q_rope_quant_matches_unfused(num_tokens, cache_dtype):
|
||||
device = "cuda"
|
||||
torch.manual_seed(0)
|
||||
|
||||
q = torch.randn(num_tokens, N_HEAD, HEAD_DIM, dtype=torch.bfloat16, device=device)
|
||||
positions = torch.randint(
|
||||
0, MAX_POS, (num_tokens,), dtype=torch.int64, device=device
|
||||
)
|
||||
cos_sin_cache = torch.randn(MAX_POS, ROPE_DIM, dtype=cache_dtype, device=device)
|
||||
weights = torch.randn(num_tokens, N_HEAD, dtype=torch.bfloat16, device=device)
|
||||
softmax_scale = HEAD_DIM**-0.5
|
||||
head_scale = N_HEAD**-0.5
|
||||
|
||||
q_fp8_ref, weights_ref = _reference(
|
||||
positions, q, cos_sin_cache, weights, softmax_scale, head_scale
|
||||
)
|
||||
q_fp8_fused, weights_fused = fused_indexer_q_rope_quant(
|
||||
positions, q.clone(), cos_sin_cache, weights, softmax_scale, head_scale
|
||||
)
|
||||
|
||||
# fp8 tensors aren't directly comparable via torch.equal — reinterpret as int8.
|
||||
ref_bits = q_fp8_ref.view(torch.int8)
|
||||
fused_bits = q_fp8_fused.view(torch.int8)
|
||||
assert torch.equal(ref_bits, fused_bits), (
|
||||
f"q_fp8 mismatch: "
|
||||
f"{(ref_bits != fused_bits).sum().item()} / {ref_bits.numel()} bytes differ"
|
||||
)
|
||||
|
||||
assert torch.equal(weights_ref, weights_fused), (
|
||||
f"weights mismatch: max abs diff "
|
||||
f"{(weights_ref - weights_fused).abs().max().item()}"
|
||||
)
|
||||
@@ -0,0 +1,908 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Unit tests for the fused inverse RoPE + block-scaled FP8 quantization kernel.
|
||||
|
||||
Tests compare the fused kernel against a reference implementation built from
|
||||
the existing separate operations (inverse RoPE via rotate_neox + FP8 quant
|
||||
via per_token_group_quant_fp8).
|
||||
|
||||
The reference faithfully reproduces the exact flow in deepseek_v4_attention.py:295-310:
|
||||
1. Apply inverse RoPE (NeoX style, last rope_dim=64 dims of each head)
|
||||
2. Reshape [T, H, head_dim] -> [T, G, D]
|
||||
3. Transpose+flatten to [G*T, D], quantize, reshape back
|
||||
4. Return o_fp8 and o_scale with strides (D, T*D, 1) and (S, T*S, 1)
|
||||
(non-contiguous [T, G, ...] view backed by contiguous [G, T, ...] memory)
|
||||
|
||||
Usage:
|
||||
pytest tests/kernels/test_fused_inv_rope_fp8_quant.py -v
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.v1.attention.ops.deepseek_v4_ops import fused_inv_rope_fp8_quant
|
||||
|
||||
# -- Default dimensions matching DeepSeek V3/V4 --------------------------
|
||||
HEAD_DIM = 512
|
||||
NOPE_DIM = 448
|
||||
ROPE_DIM = 64
|
||||
QUANT_GROUP_SIZE = 128
|
||||
FP8_MAX = 448.0 # torch.finfo(torch.float8_e4m3fn).max
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
EPS = 1e-10
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Helpers
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def assert_dequant_close(
|
||||
fp8_a: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
fp8_b: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
msg: str = "",
|
||||
):
|
||||
"""Compare two FP8-quantized tensors via their dequantized values.
|
||||
|
||||
Uses cosine-similarity-based diff (same as deep_gemm calc_diff).
|
||||
Both fused and reference paths rotate in fp32 using an fp32
|
||||
cos_sin_cache, so differences are only fp32 ordering ULPs that can
|
||||
occasionally shift FP8 values at quantization boundaries.
|
||||
"""
|
||||
S = scale_a.shape[-1]
|
||||
shape = fp8_a.shape
|
||||
|
||||
dq_a = fp8_a.float() * scale_a.unsqueeze(-1).expand(
|
||||
*shape[:-1], S, QUANT_GROUP_SIZE
|
||||
).reshape(shape)
|
||||
dq_b = fp8_b.float() * scale_b.unsqueeze(-1).expand(
|
||||
*shape[:-1], S, QUANT_GROUP_SIZE
|
||||
).reshape(shape)
|
||||
|
||||
# Cosine diff: 1 - cos_sim (0 = identical, higher = worse)
|
||||
dq_a_flat = dq_a.flatten().float()
|
||||
dq_b_flat = dq_b.flatten().float()
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
dq_a_flat.unsqueeze(0), dq_b_flat.unsqueeze(0)
|
||||
).item()
|
||||
diff = 1.0 - cos_sim
|
||||
|
||||
assert diff < 1e-4, f"Dequant diff too large: {diff:.8f} (expected < 1e-4). {msg}"
|
||||
|
||||
|
||||
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
|
||||
"""GPT-J style rotation: interleaved pairs, negate-swap.
|
||||
|
||||
Matches vllm/model_executor/layers/rotary_embedding/common.py:23-27.
|
||||
DeepseekV4 uses is_neox_style=False, so this is the correct rotation.
|
||||
"""
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return x.flatten(-2)
|
||||
|
||||
|
||||
def make_cos_sin_cache(
|
||||
max_pos: int,
|
||||
rope_dim: int = ROPE_DIM,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: str = "cuda",
|
||||
) -> torch.Tensor:
|
||||
"""Create a synthetic cos_sin_cache matching the layout used by
|
||||
DeepseekV4ScalingRotaryEmbedding._compute_cos_sin_cache.
|
||||
|
||||
Shape: [max_pos, rope_dim] where first half is cos, second half is sin.
|
||||
The fused kernel requires fp32; callers can override dtype if passing
|
||||
the cache into the bf16-only paths.
|
||||
"""
|
||||
half = rope_dim // 2
|
||||
# Use random but bounded frequencies so cos/sin are well-behaved
|
||||
inv_freq = 1.0 / (
|
||||
10000.0 ** (torch.arange(0, half, device=device, dtype=torch.float32) / half)
|
||||
)
|
||||
t = torch.arange(max_pos, device=device, dtype=torch.float32)
|
||||
freqs = torch.outer(t, inv_freq) # [max_pos, half]
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1) # [max_pos, rope_dim]
|
||||
return cache.to(dtype)
|
||||
|
||||
|
||||
def reference_inv_rope(
|
||||
o: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
nope_dim: int = NOPE_DIM,
|
||||
rope_dim: int = ROPE_DIM,
|
||||
) -> torch.Tensor:
|
||||
"""Apply inverse RoPE to the last rope_dim dimensions of each head.
|
||||
|
||||
Matches the GPT-J inverse rotation in pos_encoding_kernels.cu, which
|
||||
promotes the cache to fp32 and performs the rotation in fp32. The
|
||||
result is cast back to the input dtype.
|
||||
|
||||
Args:
|
||||
o: [T, H, head_dim] bf16
|
||||
positions: [T] int64
|
||||
cos_sin_cache: [max_pos, rope_dim] fp32
|
||||
|
||||
Returns:
|
||||
o with inverse RoPE applied on the rope portion (bf16).
|
||||
"""
|
||||
assert cos_sin_cache.dtype == torch.float32
|
||||
cos_sin = cos_sin_cache[positions] # [T, rope_dim] fp32
|
||||
half = rope_dim // 2
|
||||
cos = cos_sin[:, :half]
|
||||
sin = cos_sin[:, half:]
|
||||
|
||||
# GPT-J style: repeat_interleave (not repeat) to match interleaved pairs
|
||||
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(1)
|
||||
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(1)
|
||||
sin = -sin # inverse
|
||||
|
||||
o_pass = o[..., :nope_dim]
|
||||
o_rot_f32 = o[..., nope_dim:].float()
|
||||
o_rot_f32 = o_rot_f32 * cos + rotate_gptj(o_rot_f32) * sin
|
||||
o_rot = o_rot_f32.to(o.dtype)
|
||||
|
||||
return torch.cat([o_pass, o_rot], dim=-1)
|
||||
|
||||
|
||||
def _ref_ue8m0_quant_block(x_f32: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Per-block UE8M0 FP8 quantization in pure float32.
|
||||
|
||||
Matches the Triton kernel logic exactly:
|
||||
absmax -> 2^ceil(log2(absmax / fp8_max)) -> clamp(x / scale) -> fp8
|
||||
|
||||
Args:
|
||||
x_f32: [..., quant_group_size] float32 — one or more 128-element blocks.
|
||||
|
||||
Returns:
|
||||
x_fp8: same shape, float8_e4m3fn
|
||||
scales: [...] float32, one scale per block
|
||||
"""
|
||||
assert x_f32.shape[-1] == QUANT_GROUP_SIZE
|
||||
absmax = x_f32.abs().amax(dim=-1, keepdim=True).clamp(min=EPS)
|
||||
scale_raw = absmax * (1.0 / FP8_MAX)
|
||||
scale = torch.exp2(torch.ceil(torch.log2(scale_raw)))
|
||||
x_scaled = (x_f32 / scale).clamp(-FP8_MAX, FP8_MAX)
|
||||
x_fp8 = x_scaled.to(FP8_DTYPE)
|
||||
return x_fp8, scale.squeeze(-1)
|
||||
|
||||
|
||||
def reference_inv_rope_fp8_quant(
|
||||
o: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
n_groups: int,
|
||||
heads_per_group: int,
|
||||
nope_dim: int = NOPE_DIM,
|
||||
rope_dim: int = ROPE_DIM,
|
||||
quant_group_size: int = QUANT_GROUP_SIZE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Full reference: inverse RoPE in fp32 + UE8M0 FP8 quant in fp32.
|
||||
|
||||
Mimics the Triton kernel's precision path exactly:
|
||||
Load bf16 -> cast to fp32 -> apply inverse RoPE with fp32 cos/sin ->
|
||||
UE8M0 quant in fp32 -> write fp8 + scale
|
||||
|
||||
Returns:
|
||||
o_fp8: [T, G, D] FP8 with strides (D, T*D, 1)
|
||||
o_scale: [T, G, S] FP32 with strides (S, T*S, 1)
|
||||
"""
|
||||
assert cos_sin_cache.dtype == torch.float32
|
||||
T, _H, head_dim = o.shape
|
||||
d = heads_per_group * head_dim
|
||||
S = d // quant_group_size
|
||||
half_rope = rope_dim // 2
|
||||
chunks_per_head = head_dim // quant_group_size
|
||||
|
||||
# Reshape [T, H, head_dim] -> [T, G, heads_per_group, head_dim]
|
||||
o_4d = o.view(T, n_groups, heads_per_group, head_dim)
|
||||
|
||||
# Lookup cos/sin directly in fp32
|
||||
cos_sin = cos_sin_cache[positions] # [T, rope_dim] fp32
|
||||
cos = cos_sin[:, :half_rope] # [T, half_rope] fp32
|
||||
sin = cos_sin[:, half_rope:] # [T, half_rope] fp32
|
||||
|
||||
# Allocate outputs in [G, T, ...] contiguous layout
|
||||
fp8_buf = torch.empty(n_groups, T, d, dtype=FP8_DTYPE, device=o.device)
|
||||
scale_buf = torch.empty(n_groups, T, S, dtype=torch.float32, device=o.device)
|
||||
|
||||
# Process each quant block, matching the Triton kernel's per-program logic
|
||||
for g in range(n_groups):
|
||||
for qb in range(S):
|
||||
head_in_group = qb // chunks_per_head
|
||||
chunk_in_head = qb % chunks_per_head
|
||||
offset = chunk_in_head * quant_group_size
|
||||
|
||||
# Load 128 bf16 elements and promote to fp32 for rotation+quant
|
||||
block = o_4d[:, g, head_in_group, offset : offset + quant_group_size]
|
||||
x = block.float()
|
||||
|
||||
# Apply inverse RoPE in fp32 if this is the last chunk
|
||||
# GPT-J style: interleaved pairs (even=x, odd=y)
|
||||
if chunk_in_head == chunks_per_head - 1:
|
||||
rope_start = nope_dim % quant_group_size # 64
|
||||
rope_region = x[:, rope_start:].clone()
|
||||
x_vals = rope_region[:, ::2]
|
||||
y_vals = rope_region[:, 1::2]
|
||||
x_new = x_vals * cos + y_vals * sin
|
||||
y_new = y_vals * cos - x_vals * sin
|
||||
x = x.clone()
|
||||
x[:, rope_start::2] = x_new
|
||||
x[:, rope_start + 1 :: 2] = y_new
|
||||
|
||||
# UE8M0 quant in fp32
|
||||
x_fp8, scale = _ref_ue8m0_quant_block(x)
|
||||
|
||||
# Write to [G, T, D] contiguous memory
|
||||
fp8_buf[g, :, qb * quant_group_size : (qb + 1) * quant_group_size] = x_fp8
|
||||
scale_buf[g, :, qb] = scale
|
||||
|
||||
# Return transposed views
|
||||
return fp8_buf.transpose(0, 1), scale_buf.transpose(0, 1)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 7, 32, 128])
|
||||
@pytest.mark.parametrize(
|
||||
"num_heads,n_groups",
|
||||
[(64, 8), (32, 4), (128, 8)],
|
||||
ids=["H64_G8", "H32_G4", "H128_G8"],
|
||||
)
|
||||
@pytest.mark.parametrize("seed", [0, 42])
|
||||
@torch.inference_mode()
|
||||
def test_correctness(num_tokens, num_heads, n_groups, seed):
|
||||
"""Compare fused kernel against reference for FP8 values and scales."""
|
||||
torch.manual_seed(seed)
|
||||
|
||||
heads_per_group = num_heads // n_groups
|
||||
max_pos = 4096
|
||||
device = "cuda"
|
||||
|
||||
# Create inputs
|
||||
o = torch.randn(
|
||||
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
positions = torch.randint(
|
||||
0, max_pos, (num_tokens,), device=device, dtype=torch.long
|
||||
)
|
||||
cos_sin_cache = make_cos_sin_cache(
|
||||
max_pos, ROPE_DIM, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
# Reference
|
||||
ref_fp8, ref_scale = reference_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
|
||||
# Fused kernel
|
||||
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
|
||||
# Check shapes
|
||||
d = heads_per_group * HEAD_DIM
|
||||
S = d // QUANT_GROUP_SIZE
|
||||
assert ref_fp8.shape == (num_tokens, n_groups, d)
|
||||
assert fused_fp8.shape == (num_tokens, n_groups, d)
|
||||
assert ref_scale.shape == (num_tokens, n_groups, S)
|
||||
assert fused_scale.shape == (num_tokens, n_groups, S)
|
||||
|
||||
# Scales: exact match (both use identical UE8M0 algorithm)
|
||||
# Scales may differ by one UE8M0 step (factor of 2) if fp32 rotation
|
||||
# ordering shifts absmax across a power-of-2 boundary. Check ratio is
|
||||
# close to 1.
|
||||
scale_ratio = fused_scale / ref_scale.clamp(min=1e-30)
|
||||
assert scale_ratio.max() <= 2.0 and scale_ratio.min() >= 0.5, (
|
||||
f"Scale ratio out of [0.5, 2]: min={scale_ratio.min():.4f} "
|
||||
f"max={scale_ratio.max():.4f}"
|
||||
)
|
||||
|
||||
# Compare via dequant (Triton vs PyTorch fp32 may differ by ULPs)
|
||||
assert_dequant_close(ref_fp8, ref_scale, fused_fp8, fused_scale)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 7, 32, 128])
|
||||
@pytest.mark.parametrize(
|
||||
"num_heads,n_groups",
|
||||
[(64, 8), (128, 8)],
|
||||
ids=["H64_G8", "H128_G8"],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_output_strides(num_tokens, num_heads, n_groups):
|
||||
"""Verify fused output layout:
|
||||
- FP8: logical [T, G, D] backed by contiguous [G, T, D].
|
||||
- Scale: MN-major TMA-aligned (column-major: T-stride=1).
|
||||
"""
|
||||
|
||||
heads_per_group = num_heads // n_groups
|
||||
max_pos = 4096
|
||||
device = "cuda"
|
||||
|
||||
o = torch.randn(
|
||||
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
positions = torch.randint(
|
||||
0, max_pos, (num_tokens,), device=device, dtype=torch.long
|
||||
)
|
||||
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
|
||||
|
||||
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
|
||||
# FP8: logical [T, G, D] backed by [G, T, D] row-major
|
||||
d = heads_per_group * HEAD_DIM
|
||||
expected_fp8_stride = (d, num_tokens * d, 1)
|
||||
assert fused_fp8.stride() == expected_fp8_stride, (
|
||||
f"FP8 stride mismatch: got {fused_fp8.stride()}, expected {expected_fp8_stride}"
|
||||
)
|
||||
|
||||
# Scale: MN-major TMA-aligned layout. After fp8_einsum permutes
|
||||
# [T,G,S] -> [G,T,S], T-dim should have stride 1.
|
||||
# Our output is [T,G,S] = transpose of [G,T,S].
|
||||
# So fused_scale.permute(1,0,2) should have T-stride=1.
|
||||
perm = fused_scale.permute(1, 0, 2) # [G, T, S]
|
||||
assert perm.stride(1) == 1 or num_tokens == 1, (
|
||||
f"Scale T-stride (after permute to [G,T,S]) should be 1, got {perm.stride(1)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 7, 32, 128])
|
||||
@torch.inference_mode()
|
||||
def test_per_group_contiguity(num_tokens):
|
||||
"""FP8 per-group slices must be contiguous. Scale per-group slices
|
||||
are column-major (T-stride=1) — not row-major contiguous, which is
|
||||
correct for TMA loads."""
|
||||
num_heads, n_groups = 64, 8
|
||||
heads_per_group = num_heads // n_groups
|
||||
max_pos = 4096
|
||||
device = "cuda"
|
||||
|
||||
o = torch.randn(
|
||||
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
positions = torch.randint(
|
||||
0, max_pos, (num_tokens,), device=device, dtype=torch.long
|
||||
)
|
||||
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
|
||||
|
||||
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
|
||||
for g in range(n_groups):
|
||||
fp8_slice = fused_fp8[:, g, :]
|
||||
assert fp8_slice.is_contiguous(), (
|
||||
f"o_fp8[:, {g}, :] is not contiguous: "
|
||||
f"shape={list(fp8_slice.shape)}, stride={list(fp8_slice.stride())}"
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_scales_are_power_of_two():
|
||||
"""Verify all scales are exact powers of 2 (UE8M0 property)."""
|
||||
num_tokens, num_heads, n_groups = 32, 64, 8
|
||||
heads_per_group = num_heads // n_groups
|
||||
max_pos = 4096
|
||||
device = "cuda"
|
||||
|
||||
o = torch.randn(
|
||||
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
positions = torch.randint(
|
||||
0, max_pos, (num_tokens,), device=device, dtype=torch.long
|
||||
)
|
||||
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
|
||||
|
||||
_, fused_scale = fused_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
|
||||
# log2 of a power-of-two is an exact integer
|
||||
log2_scales = torch.log2(fused_scale)
|
||||
residual = (log2_scales - log2_scales.round()).abs()
|
||||
assert residual.max() < 1e-5, (
|
||||
f"Not all scales are powers of 2: max log2 residual = {residual.max().item()}"
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_nope_dims_unchanged():
|
||||
"""Nope dimensions (first 448 per head) should only be quantized,
|
||||
not rotated. Verify by dequantizing and comparing against
|
||||
quantize-only reference (no RoPE)."""
|
||||
num_tokens, num_heads, n_groups = 16, 64, 8
|
||||
heads_per_group = num_heads // n_groups
|
||||
max_pos = 4096
|
||||
device = "cuda"
|
||||
torch.manual_seed(0)
|
||||
|
||||
o = torch.randn(
|
||||
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
positions = torch.randint(
|
||||
0, max_pos, (num_tokens,), device=device, dtype=torch.long
|
||||
)
|
||||
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
|
||||
|
||||
# Fused kernel result
|
||||
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
|
||||
# Reference: quantize without RoPE (identity rotation)
|
||||
# Create a zero-sin cache so RoPE is identity
|
||||
zero_cache = torch.zeros_like(cos_sin_cache)
|
||||
half = ROPE_DIM // 2
|
||||
zero_cache[:, :half] = 1.0 # cos = 1
|
||||
# sin = 0 (already zero)
|
||||
|
||||
norope_fp8, norope_scale = fused_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
zero_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
|
||||
# Extract nope quant blocks only (first 3 of every 4 blocks per head)
|
||||
chunks_per_head = HEAD_DIM // QUANT_GROUP_SIZE # 4
|
||||
|
||||
for h in range(heads_per_group):
|
||||
for c in range(chunks_per_head - 1): # skip last chunk (has rope)
|
||||
qb = h * chunks_per_head + c
|
||||
start = qb * QUANT_GROUP_SIZE
|
||||
end = start + QUANT_GROUP_SIZE
|
||||
|
||||
fused_nope = fused_fp8[:, :, start:end].view(torch.uint8)
|
||||
norope_nope = norope_fp8[:, :, start:end].view(torch.uint8)
|
||||
assert torch.equal(fused_nope, norope_nope), (
|
||||
f"Nope block (head={h}, chunk={c}) differs between "
|
||||
f"fused and no-rope reference"
|
||||
)
|
||||
|
||||
fused_s = fused_scale[:, :, qb]
|
||||
norope_s = norope_scale[:, :, qb]
|
||||
assert torch.equal(fused_s, norope_s), (
|
||||
f"Nope scale (head={h}, chunk={c}) differs"
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_single_token():
|
||||
"""Edge case: single token."""
|
||||
num_tokens, num_heads, n_groups = 1, 64, 8
|
||||
heads_per_group = num_heads // n_groups
|
||||
max_pos = 4096
|
||||
device = "cuda"
|
||||
|
||||
o = torch.randn(
|
||||
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
positions = torch.tensor([42], device=device, dtype=torch.long)
|
||||
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
|
||||
|
||||
ref_fp8, ref_scale = reference_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
|
||||
assert_dequant_close(ref_fp8, ref_scale, fused_fp8, fused_scale)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_zero_positions():
|
||||
"""Edge case: all positions are 0."""
|
||||
num_tokens, num_heads, n_groups = 16, 64, 8
|
||||
heads_per_group = num_heads // n_groups
|
||||
max_pos = 4096
|
||||
device = "cuda"
|
||||
|
||||
o = torch.randn(
|
||||
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
positions = torch.zeros(num_tokens, device=device, dtype=torch.long)
|
||||
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
|
||||
|
||||
ref_fp8, ref_scale = reference_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
|
||||
assert_dequant_close(ref_fp8, ref_scale, fused_fp8, fused_scale)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_large_values():
|
||||
"""Edge case: values near FP8 saturation to test clamping."""
|
||||
num_tokens, num_heads, n_groups = 8, 64, 8
|
||||
heads_per_group = num_heads // n_groups
|
||||
max_pos = 4096
|
||||
device = "cuda"
|
||||
|
||||
# Create inputs with large values that will saturate FP8
|
||||
o = torch.randn(
|
||||
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
o = o * 1000.0 # scale up to force saturation
|
||||
positions = torch.randint(
|
||||
0, max_pos, (num_tokens,), device=device, dtype=torch.long
|
||||
)
|
||||
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
|
||||
|
||||
ref_fp8, ref_scale = reference_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
|
||||
assert_dequant_close(ref_fp8, ref_scale, fused_fp8, fused_scale)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_dequant_numerical_accuracy():
|
||||
"""Verify dequantized values are close to the original (after inv RoPE)."""
|
||||
num_tokens, num_heads, n_groups = 32, 64, 8
|
||||
heads_per_group = num_heads // n_groups
|
||||
max_pos = 4096
|
||||
device = "cuda"
|
||||
torch.manual_seed(0)
|
||||
|
||||
o = torch.randn(
|
||||
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
positions = torch.randint(
|
||||
0, max_pos, (num_tokens,), device=device, dtype=torch.long
|
||||
)
|
||||
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
|
||||
|
||||
# Get the post-inv-RoPE values (ground truth before quantization)
|
||||
o_after_rope = reference_inv_rope(o.clone(), positions, cos_sin_cache)
|
||||
d = heads_per_group * HEAD_DIM
|
||||
o_after_rope = o_after_rope.view(num_tokens, n_groups, d)
|
||||
|
||||
# Get fused quantized output
|
||||
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
|
||||
# Dequantize: broadcast scale [T, G, S] to [T, G, D] via repeat
|
||||
S = d // QUANT_GROUP_SIZE
|
||||
scale_expanded = (
|
||||
fused_scale.unsqueeze(-1)
|
||||
.expand(num_tokens, n_groups, S, QUANT_GROUP_SIZE)
|
||||
.reshape(num_tokens, n_groups, d)
|
||||
)
|
||||
dequant = fused_fp8.float() * scale_expanded
|
||||
|
||||
# Check relative error.
|
||||
# FP8 e4m3 with UE8M0 (power-of-two scales that round UP) quantizes more
|
||||
# coarsely than optimal scaling. Both paths rotate in fp32, so the bulk
|
||||
# of the error comes from UE8M0 quantization itself (~10-12% typical).
|
||||
o_gt = o_after_rope.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
dequant_contig = dequant.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
|
||||
abs_err = (dequant_contig.float() - o_gt.float()).abs()
|
||||
rel_err = abs_err / (o_gt.float().abs().clamp(min=1e-6))
|
||||
mean_rel_err = rel_err.mean().item()
|
||||
|
||||
assert mean_rel_err < 0.15, (
|
||||
f"Mean relative error too high: {mean_rel_err:.4f} (expected < 0.15)"
|
||||
)
|
||||
|
||||
|
||||
def _unfused_inv_rope_fp8_quant(
|
||||
o: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
n_groups: int,
|
||||
heads_per_group: int,
|
||||
nope_dim: int = NOPE_DIM,
|
||||
rope_dim: int = ROPE_DIM,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Unfused path matching deepseek_v4_attention.py:295-310.
|
||||
|
||||
Uses the production CUDA RoPE kernel + per_token_group_quant_fp8.
|
||||
"""
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
|
||||
head_dim = o.shape[-1]
|
||||
rope_dim_offset = head_dim - rope_dim
|
||||
|
||||
# Step 1: In-place CUDA RoPE (same as production)
|
||||
ops.rotary_embedding(
|
||||
positions,
|
||||
o,
|
||||
None,
|
||||
head_dim,
|
||||
cos_sin_cache,
|
||||
False, # is_neox=False for DeepseekV4 (GPT-J style)
|
||||
rope_dim_offset=rope_dim_offset,
|
||||
inverse=True,
|
||||
)
|
||||
|
||||
# Step 2: Reshape + quant + reshape (same as production)
|
||||
T = o.shape[0]
|
||||
d = heads_per_group * head_dim
|
||||
o = o.view(T, n_groups, -1)
|
||||
o_flat = o.transpose(0, 1).contiguous().reshape(-1, d)
|
||||
o_fp8, o_scale = per_token_group_quant_fp8(
|
||||
o_flat,
|
||||
group_size=QUANT_GROUP_SIZE,
|
||||
use_ue8m0=True,
|
||||
)
|
||||
o_fp8 = o_fp8.view(n_groups, T, d).transpose(0, 1)
|
||||
o_scale = o_scale.view(n_groups, T, -1).transpose(0, 1)
|
||||
return o_fp8, o_scale
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# End-to-end test including fp8_einsum
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 7, 32, 128, 1024])
|
||||
@pytest.mark.parametrize(
|
||||
"num_heads,n_groups",
|
||||
[(64, 8)],
|
||||
ids=["H64_G8"],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_einsum_end_to_end(num_tokens, num_heads, n_groups):
|
||||
"""End-to-end: fused inv_rope+quant → fp8_einsum must match
|
||||
unfused CUDA_rope+quant → fp8_einsum bitwise.
|
||||
|
||||
This catches stride/layout bugs that only manifest when the einsum
|
||||
kernel actually consumes the quantized activations.
|
||||
"""
|
||||
from deep_gemm.utils.math import ceil_div
|
||||
|
||||
from vllm.utils.deep_gemm import (
|
||||
fp8_einsum,
|
||||
per_block_cast_to_fp8,
|
||||
transform_sf_into_required_layout,
|
||||
)
|
||||
|
||||
heads_per_group = num_heads // n_groups
|
||||
d = heads_per_group * HEAD_DIM
|
||||
o_lora_rank = 1024
|
||||
max_pos = 4096
|
||||
device = "cuda"
|
||||
torch.manual_seed(0)
|
||||
|
||||
o = torch.randn(
|
||||
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
positions = torch.randint(
|
||||
0, max_pos, (num_tokens,), device=device, dtype=torch.long
|
||||
)
|
||||
cos_sin_cache = make_cos_sin_cache(max_pos, device=device)
|
||||
|
||||
# -- Weight quantization (shared between both paths) --
|
||||
w = torch.randn(n_groups, o_lora_rank, d, device=device, dtype=torch.bfloat16)
|
||||
w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fn)
|
||||
w_scale = torch.empty(
|
||||
n_groups,
|
||||
ceil_div(o_lora_rank, 128),
|
||||
ceil_div(d, 128),
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
for g in range(n_groups):
|
||||
w_fp8[g], w_scale[g] = per_block_cast_to_fp8(w[g], use_ue8m0=True)
|
||||
|
||||
recipe = (1, 1, 128)
|
||||
w_scale_t = transform_sf_into_required_layout(
|
||||
sf=w_scale,
|
||||
mn=o_lora_rank,
|
||||
k=d,
|
||||
recipe=(1, 128, 128),
|
||||
num_groups=n_groups,
|
||||
is_sfa=False,
|
||||
)
|
||||
|
||||
# -- UNFUSED path --
|
||||
ref_fp8, ref_scale = _unfused_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
z_ref = torch.empty(
|
||||
num_tokens, n_groups, o_lora_rank, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
fp8_einsum(
|
||||
"bhr,hdr->bhd", (ref_fp8, ref_scale), (w_fp8, w_scale_t), z_ref, recipe=recipe
|
||||
)
|
||||
|
||||
# -- FUSED path --
|
||||
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
z_fused = torch.empty(
|
||||
num_tokens, n_groups, o_lora_rank, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
fp8_einsum(
|
||||
"bhr,hdr->bhd",
|
||||
(fused_fp8, fused_scale),
|
||||
(w_fp8, w_scale_t),
|
||||
z_fused,
|
||||
recipe=recipe,
|
||||
)
|
||||
|
||||
# -- Checks --
|
||||
# Einsum output: Triton and CUDA both rotate in fp32 now, so diffs
|
||||
# come from fp32 ordering and UE8M0 boundary shifts only.
|
||||
# Use relative diff (same metric as test_fp8_einsum.py).
|
||||
from deep_gemm.testing import calc_diff
|
||||
|
||||
z_diff = calc_diff(z_fused, z_ref)
|
||||
assert z_diff < 0.01, (
|
||||
f"Einsum output diff too large: {z_diff:.6f} (expected < 0.01)"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 32, 256])
|
||||
@torch.inference_mode()
|
||||
def test_with_real_deepseek_v4_rope(num_tokens, default_vllm_config):
|
||||
"""Test with real DeepseekV4ScalingRotaryEmbedding (GPT-J style,
|
||||
mscale=0, YaRN scaling) matching the production config."""
|
||||
|
||||
num_heads = 64
|
||||
n_groups = 8
|
||||
heads_per_group = num_heads // n_groups
|
||||
device = "cuda"
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Build YaRN-scaled cos_sin_cache matching real DeepSeek V3/V4 config
|
||||
# (mscale=0 → mscale=1.0, so no magnitude scaling)
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
yarn_find_correction_range,
|
||||
yarn_linear_ramp_mask,
|
||||
)
|
||||
|
||||
scaling_factor = 16
|
||||
base = 10000.0
|
||||
max_pos = 65536
|
||||
beta_fast, beta_slow = 32, 1
|
||||
|
||||
pos_freqs = base ** (
|
||||
torch.arange(0, ROPE_DIM, 2, dtype=torch.float32, device=device) / ROPE_DIM
|
||||
)
|
||||
inv_freq_extra = 1.0 / pos_freqs
|
||||
inv_freq_interp = 1.0 / (scaling_factor * pos_freqs)
|
||||
low, high = yarn_find_correction_range(
|
||||
beta_fast, beta_slow, ROPE_DIM, base, max_pos
|
||||
)
|
||||
mask = 1 - yarn_linear_ramp_mask(low, high, ROPE_DIM // 2, dtype=torch.float32).to(
|
||||
device
|
||||
)
|
||||
inv_freq = inv_freq_interp * (1 - mask) + inv_freq_extra * mask
|
||||
t = torch.arange(max_pos * scaling_factor, device=device, dtype=torch.float32)
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
# mscale=0 → yarn_get_mscale returns 1.0
|
||||
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) # fp32
|
||||
|
||||
o = torch.randn(
|
||||
num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
positions = torch.randint(0, 4096, (num_tokens,), device=device, dtype=torch.long)
|
||||
|
||||
# UNFUSED: CUDA RoPE with is_neox=False (GPT-J)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
|
||||
o_unfused = o.clone()
|
||||
ops.rotary_embedding(
|
||||
positions,
|
||||
o_unfused,
|
||||
None,
|
||||
HEAD_DIM,
|
||||
cos_sin_cache,
|
||||
False, # is_neox=False (GPT-J style)
|
||||
rope_dim_offset=NOPE_DIM,
|
||||
inverse=True,
|
||||
)
|
||||
d = heads_per_group * HEAD_DIM
|
||||
T = num_tokens
|
||||
o_unfused = o_unfused.view(T, n_groups, d)
|
||||
o_flat = o_unfused.transpose(0, 1).contiguous().reshape(-1, d)
|
||||
ref_fp8, ref_scale = per_token_group_quant_fp8(
|
||||
o_flat,
|
||||
group_size=QUANT_GROUP_SIZE,
|
||||
use_ue8m0=True,
|
||||
)
|
||||
ref_fp8 = ref_fp8.view(n_groups, T, d).transpose(0, 1)
|
||||
ref_scale = ref_scale.view(n_groups, T, -1).transpose(0, 1)
|
||||
|
||||
# FUSED: use the real YaRN-scaled cos_sin_cache
|
||||
fused_fp8, fused_scale = fused_inv_rope_fp8_quant(
|
||||
o.clone(),
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
n_groups,
|
||||
heads_per_group,
|
||||
)
|
||||
|
||||
# Scales must match exactly (same UE8M0 algorithm)
|
||||
# Compare via dequant (Triton bf16 rotation may differ from CUDA by 1 ULP)
|
||||
assert_dequant_close(
|
||||
ref_fp8, ref_scale, fused_fp8, fused_scale, msg="Real DeepSeek V4 rope"
|
||||
)
|
||||
@@ -718,7 +718,6 @@ def test_persistent_topk_stress() -> None:
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [2000, 6000, 30000, 80000],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="mixed_all_paths",
|
||||
@@ -727,7 +726,6 @@ def test_persistent_topk_stress() -> None:
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [2048, 4096, 8192, 16000],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="all_decode_medium",
|
||||
@@ -736,7 +734,6 @@ def test_persistent_topk_stress() -> None:
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [70000, 100000, 163840],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="all_large",
|
||||
@@ -745,7 +742,6 @@ def test_persistent_topk_stress() -> None:
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [32767, 32768, 32769, 32772],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="large_threshold_boundary",
|
||||
@@ -754,7 +750,6 @@ def test_persistent_topk_stress() -> None:
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [5000],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="single_row_medium",
|
||||
@@ -772,15 +767,15 @@ def test_persistent_topk_stress() -> None:
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [100, 2048, 10000, 80000],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="trivial_medium_large_mix",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("top_k", [512, 2048])
|
||||
@torch.inference_mode()
|
||||
def test_persistent_topk(test_config: dict) -> None:
|
||||
def test_persistent_topk(test_config: dict, top_k: int) -> None:
|
||||
"""
|
||||
Tests specific to the persistent_topk kernel:
|
||||
- Mixed medium/large rows in the same batch (dynamic per-row dispatch)
|
||||
@@ -790,14 +785,15 @@ def test_persistent_topk(test_config: dict) -> None:
|
||||
run_large_context_topk_test(
|
||||
batch_size=len(test_config["seq_lens"]),
|
||||
seq_lens=test_config["seq_lens"],
|
||||
top_k=test_config["top_k"],
|
||||
top_k=top_k,
|
||||
data_type=test_config.get("data_type", "random"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
||||
@pytest.mark.parametrize("top_k", [512, 2048])
|
||||
@torch.inference_mode()
|
||||
def test_persistent_topk_padded_stride() -> None:
|
||||
def test_persistent_topk_padded_stride(top_k: int) -> None:
|
||||
"""
|
||||
Test persistent_topk with padded logits (large stride, small seq_len)
|
||||
to simulate the e2e CUDAGraph scenario where fp8_paged_mqa_logits
|
||||
@@ -806,7 +802,6 @@ def test_persistent_topk_padded_stride() -> None:
|
||||
set_random_seed(42)
|
||||
torch.set_default_device("cuda:0")
|
||||
|
||||
top_k = 2048
|
||||
batch_size = 4
|
||||
padded_stride = 163840 # DeepSeek-V3.2 max_model_len
|
||||
actual_seq_lens = [3000, 5000, 8000, 12000]
|
||||
|
||||
@@ -41,7 +41,9 @@ class DummyRouter(BaseRouter):
|
||||
def routing_method_type(self) -> RoutingMethodType:
|
||||
return RoutingMethodType.FUSED_TOPK
|
||||
|
||||
def _compute_routing(self, hidden_states, router_logits, indices_type):
|
||||
def _compute_routing(
|
||||
self, hidden_states, router_logits, indices_type, *, input_ids=None
|
||||
):
|
||||
topk_ids = torch.tensor([[1, 2], [3, 4]], dtype=torch.int64)
|
||||
topk_weights = torch.ones_like(topk_ids, dtype=torch.float32)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
@@ -260,6 +260,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True,
|
||||
),
|
||||
"DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"),
|
||||
"DeepseekV4ForCausalLM": _HfExamplesInfo(
|
||||
"deepseek-ai/DeepSeek-V4-Flash", is_available_online=False
|
||||
),
|
||||
"Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT"),
|
||||
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT"),
|
||||
"ExaoneForCausalLM": _HfExamplesInfo(
|
||||
@@ -1482,6 +1485,12 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
speculative_model="luccafong/deepseek_mtp_draft_random",
|
||||
trust_remote_code=True,
|
||||
),
|
||||
"DeepSeekV4MTPModel": _HfExamplesInfo(
|
||||
"deepseek-ai/DeepSeek-V4-Flash",
|
||||
speculative_model="deepseek-ai/DeepSeek-V4-Flash",
|
||||
trust_remote_code=True,
|
||||
is_available_online=False,
|
||||
),
|
||||
"ErnieMTPModel": _HfExamplesInfo(
|
||||
"baidu/ERNIE-4.5-21B-A3B-PT",
|
||||
trust_remote_code=True,
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.deepseek_v4 import (
|
||||
DeepseekV4MegaMoEExperts,
|
||||
_stage_deepseek_v4_mega_moe_inputs,
|
||||
make_deepseek_v4_expert_params_mapping,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not current_platform.is_cuda(),
|
||||
reason="DeepSeek V4 MegaMoE requires CUDA",
|
||||
)
|
||||
|
||||
|
||||
def test_deepseek_v4_mega_moe_expert_mapping():
|
||||
mapping = make_deepseek_v4_expert_params_mapping(2)
|
||||
|
||||
assert mapping == [
|
||||
("experts.w13_", "experts.0.w1.", 0, "w1"),
|
||||
("experts.w2_", "experts.0.w2.", 0, "w2"),
|
||||
("experts.w13_", "experts.0.w3.", 0, "w3"),
|
||||
("experts.w13_", "experts.1.w1.", 1, "w1"),
|
||||
("experts.w2_", "experts.1.w2.", 1, "w2"),
|
||||
("experts.w13_", "experts.1.w3.", 1, "w3"),
|
||||
]
|
||||
|
||||
|
||||
def test_deepseek_v4_mega_moe_ue8m0_uint8_to_float():
|
||||
raw = torch.tensor([0, 126, 127, 128], dtype=torch.uint8)
|
||||
|
||||
decoded = DeepseekV4MegaMoEExperts._ue8m0_uint8_to_float(raw)
|
||||
|
||||
assert torch.equal(decoded.view(torch.int32), raw.to(torch.int32) << 23)
|
||||
assert decoded[0].item() == 0.0
|
||||
assert decoded[1].item() == 0.5
|
||||
assert decoded[2].item() == 1.0
|
||||
assert decoded[3].item() == 2.0
|
||||
|
||||
|
||||
def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership():
|
||||
vllm_config = SimpleNamespace(
|
||||
scheduler_config=SimpleNamespace(max_num_batched_tokens=4)
|
||||
)
|
||||
experts = DeepseekV4MegaMoEExperts(
|
||||
vllm_config,
|
||||
num_experts=4,
|
||||
num_local_experts=2,
|
||||
experts_start_idx=2,
|
||||
top_k=2,
|
||||
hidden_size=128,
|
||||
intermediate_size=128,
|
||||
)
|
||||
|
||||
nonlocal_weight = torch.ones(128, 64, dtype=torch.uint8)
|
||||
assert (
|
||||
experts.weight_loader(
|
||||
experts.w13_weight,
|
||||
nonlocal_weight,
|
||||
"experts.w13_weight",
|
||||
shard_id="w1",
|
||||
expert_id=1,
|
||||
return_success=True,
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
w1 = torch.full((128, 64), 3, dtype=torch.uint8)
|
||||
w3 = torch.full((128, 64), 7, dtype=torch.uint8)
|
||||
w2 = torch.full((128, 64), 11, dtype=torch.uint8)
|
||||
|
||||
assert experts.weight_loader(
|
||||
experts.w13_weight,
|
||||
w1,
|
||||
"experts.w13_weight",
|
||||
shard_id="w1",
|
||||
expert_id=2,
|
||||
return_success=True,
|
||||
)
|
||||
assert experts.weight_loader(
|
||||
experts.w13_weight,
|
||||
w3,
|
||||
"experts.w13_weight",
|
||||
shard_id="w3",
|
||||
expert_id=2,
|
||||
return_success=True,
|
||||
)
|
||||
assert experts.weight_loader(
|
||||
experts.w2_weight,
|
||||
w2,
|
||||
"experts.w2_weight",
|
||||
shard_id="w2",
|
||||
expert_id=2,
|
||||
return_success=True,
|
||||
)
|
||||
|
||||
assert torch.equal(experts.w13_weight[0, :128], w1)
|
||||
assert torch.equal(experts.w13_weight[0, 128:], w3)
|
||||
assert torch.equal(experts.w2_weight[0], w2)
|
||||
assert torch.count_nonzero(experts.w13_weight[1]) == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="DeepSeek V4 MegaMoE fused input staging requires CUDA.",
|
||||
)
|
||||
def test_deepseek_v4_mega_moe_fused_input_staging_is_bitwise_exact():
|
||||
from vllm.third_party.deep_gemm.utils import per_token_cast_to_fp8
|
||||
|
||||
device = torch.device("cuda")
|
||||
num_tokens = 7
|
||||
hidden_size = 256
|
||||
top_k = 8
|
||||
|
||||
generator = torch.Generator(device=device)
|
||||
generator.manual_seed(0)
|
||||
hidden_states = (
|
||||
torch.randn(
|
||||
num_tokens,
|
||||
hidden_size,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
generator=generator,
|
||||
)
|
||||
* 17.0
|
||||
).to(torch.bfloat16)
|
||||
hidden_states[0, :32] = 0
|
||||
hidden_states[1, 32:64] = 1.0e-6
|
||||
hidden_states[2, 64:96] = -1.0e-6
|
||||
|
||||
topk_ids = torch.randint(
|
||||
0,
|
||||
256,
|
||||
(num_tokens, top_k),
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
generator=generator,
|
||||
)
|
||||
topk_weights = torch.randn(
|
||||
num_tokens,
|
||||
top_k,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
ref_x, ref_x_sf = per_token_cast_to_fp8(
|
||||
hidden_states,
|
||||
use_ue8m0=True,
|
||||
gran_k=32,
|
||||
use_packed_ue8m0=True,
|
||||
)
|
||||
ref_topk_idx = topk_ids.to(torch.int64)
|
||||
ref_topk_weights = topk_weights.clone()
|
||||
|
||||
fused_x = torch.empty_like(ref_x)
|
||||
fused_x_sf = torch.empty_like(ref_x_sf)
|
||||
fused_topk_idx = torch.empty_like(ref_topk_idx)
|
||||
fused_topk_weights = torch.empty_like(ref_topk_weights)
|
||||
|
||||
_stage_deepseek_v4_mega_moe_inputs(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
fused_x,
|
||||
fused_x_sf,
|
||||
fused_topk_idx,
|
||||
fused_topk_weights,
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
assert torch.equal(fused_x.view(torch.uint8), ref_x.view(torch.uint8))
|
||||
assert torch.equal(fused_x_sf, ref_x_sf)
|
||||
assert torch.equal(fused_topk_idx, ref_topk_idx)
|
||||
assert torch.equal(
|
||||
fused_topk_weights.view(torch.uint8),
|
||||
ref_topk_weights.view(torch.uint8),
|
||||
)
|
||||
@@ -6,6 +6,7 @@ from transformers import AutoTokenizer
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
|
||||
from vllm.reasoning.deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser
|
||||
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
|
||||
@@ -33,6 +34,12 @@ def test_parser_selection(tokenizer, thinking, expected_parser_type):
|
||||
assert isinstance(parser._parser, expected_parser_type)
|
||||
|
||||
|
||||
def test_deepseek_v4_reasoning_parser_alias():
|
||||
parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_v4")
|
||||
|
||||
assert parser_cls is DeepSeekV3ReasoningParser
|
||||
|
||||
|
||||
def test_identity_reasoning_parser_basic(tokenizer):
|
||||
parser = IdentityReasoningParser(tokenizer)
|
||||
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather for a specific location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city name"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "Temperature unit"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"description": "Search the web for information",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query"
|
||||
},
|
||||
"num_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather in Beijing?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning": "The user wants to know the weather in Beijing. I should use the get_weather tool.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_001",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"location\": \"Beijing\", \"unit\": \"celsius\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_001",
|
||||
"content": "{\"temperature\": 22, \"condition\": \"sunny\", \"humidity\": 45}"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning": "Got the weather data. Let me format a nice response.",
|
||||
"content": "The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity."
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning": "The user said hello, I should greet back.",
|
||||
"content": "Hi there! How can I help you?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the capital of France?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning": "The user asks about the capital of France. It is Paris.",
|
||||
"content": "The capital of France is Paris."
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,159 @@
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "该助手为DeepSeek,由深度求索公司创造。"
|
||||
},
|
||||
{
|
||||
"role": "latest_reminder",
|
||||
"content": "2026-02-21,星期六,广州,App,中文"
|
||||
},
|
||||
{
|
||||
"role": "developer",
|
||||
"content": "小柴胡冲剂和布洛芬能一起吃吗?\n\nCITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】",
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"description": "Web search. Split multiple queries with '||'.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"queries": {
|
||||
"type": "string",
|
||||
"description": "query1||query2"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"queries"
|
||||
],
|
||||
"additionalProperties": false,
|
||||
"$schema": "http://json-schema.org/draft-07/schema#"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "open",
|
||||
"description": "Batch open IDs (format 【{id}†...】) or URLs.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"open_list": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"description": "ID or URL",
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"default": -1
|
||||
},
|
||||
"cursor": {
|
||||
"type": "integer",
|
||||
"description": "",
|
||||
"default": -1
|
||||
},
|
||||
"loc": {
|
||||
"type": "integer",
|
||||
"description": "Start line",
|
||||
"default": -1
|
||||
},
|
||||
"num_lines": {
|
||||
"type": "integer",
|
||||
"description": "",
|
||||
"default": -1
|
||||
},
|
||||
"view_source": {
|
||||
"type": "boolean",
|
||||
"description": "",
|
||||
"default": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
},
|
||||
"description": ""
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"open_list"
|
||||
],
|
||||
"additionalProperties": false,
|
||||
"$schema": "http://json-schema.org/draft-07/schema#"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "find",
|
||||
"description": "Find exact text pattern in pages.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"find_list": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": ""
|
||||
},
|
||||
"cursor": {
|
||||
"type": "integer",
|
||||
"description": "",
|
||||
"default": -1
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"pattern"
|
||||
],
|
||||
"additionalProperties": false
|
||||
},
|
||||
"description": ""
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"find_list"
|
||||
],
|
||||
"additionalProperties": false,
|
||||
"$schema": "http://json-schema.org/draft-07/schema#"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning": "用户想知道小柴胡冲剂和布洛芬能否一起服用。",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"arguments": "{\"queries\": \"小柴胡冲剂 布洛芬 相互作用 一起吃\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "[0]"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "请及时就医。",
|
||||
"reasoning": "现在开始组织回答。",
|
||||
"tool_calls": []
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,28 @@
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "该助手为DeepSeek-V3,由深度求索公司创造。\n今天是2025年10月17日,星期五。"
|
||||
},
|
||||
{
|
||||
"role": "latest_reminder",
|
||||
"content": "2024-11-15,上海市,App,中文"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "热海大滚锅是世界著名温泉吗"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。\n\n为了让你快速了解,我整理了一个简要的表格:\n\n| 维度 | 热海大滚锅的地位与特点 |\n| :--- | :--- |\n| **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |\n| **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |\n| **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |\n| **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |\n\n### 💡 游玩攻略与温馨提示\n\n如果你计划前往热海大滚锅,这里有一些实用信息供你参考:\n\n- **门票与开放时间**:\n - **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。\n - **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。\n\n- **特色体验**:\n - **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。\n - **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。\n\n- **注意事项**:\n - **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。\n - **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。\n\n希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。",
|
||||
"mask": 1
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "世界著名温泉有哪些",
|
||||
"task": "action"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Search"
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,36 @@
|
||||
<|begin▁of▁sentence|>
|
||||
|
||||
## Tools
|
||||
|
||||
You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
|
||||
|
||||
<|DSML|tool_calls>
|
||||
<|DSML|invoke name="$TOOL_NAME">
|
||||
<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
|
||||
...
|
||||
</|DSML|invoke>
|
||||
<|DSML|invoke name="$TOOL_NAME2">
|
||||
...
|
||||
</|DSML|invoke>
|
||||
</|DSML|tool_calls>
|
||||
|
||||
String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
|
||||
|
||||
If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
|
||||
|
||||
Otherwise, output directly after </think> with tool calls or final response.
|
||||
|
||||
### Available Tool Schemas
|
||||
|
||||
{"name": "get_weather", "description": "Get the weather for a specific location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city name"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "Temperature unit"}}, "required": ["location"]}}
|
||||
{"name": "search", "description": "Search the web for information", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}, "num_results": {"type": "integer", "description": "Number of results to return"}}, "required": ["query"]}}
|
||||
|
||||
You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
|
||||
You are a helpful assistant.<|User|>What's the weather in Beijing?<|Assistant|><think>The user wants to know the weather in Beijing. I should use the get_weather tool.</think>
|
||||
|
||||
<|DSML|tool_calls>
|
||||
<|DSML|invoke name="get_weather">
|
||||
<|DSML|parameter name="location" string="true">Beijing</|DSML|parameter>
|
||||
<|DSML|parameter name="unit" string="true">celsius</|DSML|parameter>
|
||||
</|DSML|invoke>
|
||||
</|DSML|tool_calls><|end▁of▁sentence|><|User|><tool_result>{"temperature": 22, "condition": "sunny", "humidity": 45}</tool_result><|Assistant|><think>Got the weather data. Let me format a nice response.</think>The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity.<|end▁of▁sentence|>
|
||||
@@ -0,0 +1 @@
|
||||
<|begin▁of▁sentence|>You are a helpful assistant.<|User|>Hello<|Assistant|></think>Hi there! How can I help you?<|end▁of▁sentence|><|User|>What is the capital of France?<|Assistant|><think>The user asks about the capital of France. It is Paris.</think>The capital of France is Paris.<|end▁of▁sentence|>
|
||||
@@ -0,0 +1,38 @@
|
||||
<|begin▁of▁sentence|>该助手为DeepSeek,由深度求索公司创造。<|latest_reminder|>2026-02-21,星期六,广州,App,中文<|User|>小柴胡冲剂和布洛芬能一起吃吗?
|
||||
|
||||
CITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】
|
||||
|
||||
## Tools
|
||||
|
||||
You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
|
||||
|
||||
<|DSML|tool_calls>
|
||||
<|DSML|invoke name="$TOOL_NAME">
|
||||
<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
|
||||
...
|
||||
</|DSML|invoke>
|
||||
<|DSML|invoke name="$TOOL_NAME2">
|
||||
...
|
||||
</|DSML|invoke>
|
||||
</|DSML|tool_calls>
|
||||
|
||||
String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
|
||||
|
||||
If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
|
||||
|
||||
Otherwise, output directly after </think> with tool calls or final response.
|
||||
|
||||
### Available Tool Schemas
|
||||
|
||||
{"name": "search", "description": "Web search. Split multiple queries with '||'.", "parameters": {"type": "object", "properties": {"queries": {"type": "string", "description": "query1||query2"}}, "required": ["queries"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
|
||||
{"name": "open", "description": "Batch open IDs (format 【{id}†...】) or URLs.", "parameters": {"type": "object", "properties": {"open_list": {"type": "array", "items": {"type": "object", "properties": {"id": {"description": "ID or URL", "anyOf": [{"type": "integer"}, {"type": "string"}], "default": -1}, "cursor": {"type": "integer", "description": "", "default": -1}, "loc": {"type": "integer", "description": "Start line", "default": -1}, "num_lines": {"type": "integer", "description": "", "default": -1}, "view_source": {"type": "boolean", "description": "", "default": false}}, "additionalProperties": false}, "description": ""}}, "required": ["open_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
|
||||
{"name": "find", "description": "Find exact text pattern in pages.", "parameters": {"type": "object", "properties": {"find_list": {"type": "array", "items": {"type": "object", "properties": {"pattern": {"type": "string", "description": ""}, "cursor": {"type": "integer", "description": "", "default": -1}}, "required": ["pattern"], "additionalProperties": false}, "description": ""}}, "required": ["find_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
|
||||
|
||||
You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
|
||||
<|Assistant|><think>用户想知道小柴胡冲剂和布洛芬能否一起服用。</think>
|
||||
|
||||
<|DSML|tool_calls>
|
||||
<|DSML|invoke name="search">
|
||||
<|DSML|parameter name="queries" string="true">小柴胡冲剂 布洛芬 相互作用 一起吃</|DSML|parameter>
|
||||
</|DSML|invoke>
|
||||
</|DSML|tool_calls><|end▁of▁sentence|><|User|><tool_result>[0]</tool_result><|Assistant|><think>现在开始组织回答。</think>请及时就医。<|end▁of▁sentence|>
|
||||
@@ -0,0 +1,29 @@
|
||||
<|begin▁of▁sentence|>该助手为DeepSeek-V3,由深度求索公司创造。
|
||||
今天是2025年10月17日,星期五。<|latest_reminder|>2024-11-15,上海市,App,中文<|User|>热海大滚锅是世界著名温泉吗<|Assistant|></think>关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。
|
||||
|
||||
为了让你快速了解,我整理了一个简要的表格:
|
||||
|
||||
| 维度 | 热海大滚锅的地位与特点 |
|
||||
| :--- | :--- |
|
||||
| **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |
|
||||
| **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |
|
||||
| **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |
|
||||
| **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |
|
||||
|
||||
### 💡 游玩攻略与温馨提示
|
||||
|
||||
如果你计划前往热海大滚锅,这里有一些实用信息供你参考:
|
||||
|
||||
- **门票与开放时间**:
|
||||
- **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。
|
||||
- **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。
|
||||
|
||||
- **特色体验**:
|
||||
- **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。
|
||||
- **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。
|
||||
|
||||
- **注意事项**:
|
||||
- **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。
|
||||
- **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。
|
||||
|
||||
希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。<|end▁of▁sentence|><|User|>世界著名温泉有哪些<|Assistant|></think><|action|>Search<|end▁of▁sentence|>
|
||||
@@ -0,0 +1,224 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.chat_utils import parse_chat_messages
|
||||
from vllm.renderers.registry import RENDERER_REGISTRY
|
||||
from vllm.tokenizers.deepseek_v4 import get_deepseek_v4_tokenizer
|
||||
from vllm.tokenizers.registry import TokenizerRegistry
|
||||
|
||||
FIXTURES_DIR = Path(__file__).parent / "fixtures" / "deepseek_v4"
|
||||
|
||||
|
||||
class FakeHfTokenizer:
|
||||
vocab_size = 100
|
||||
|
||||
def get_added_vocab(self) -> dict[str, int]:
|
||||
return {"</think>": 100}
|
||||
|
||||
def encode(
|
||||
self,
|
||||
text: str,
|
||||
add_special_tokens: bool = False,
|
||||
**kwargs,
|
||||
) -> list[int]:
|
||||
self.last_encode = (text, add_special_tokens, kwargs)
|
||||
return [len(text)]
|
||||
|
||||
|
||||
def _tokenizer():
|
||||
return get_deepseek_v4_tokenizer(FakeHfTokenizer())
|
||||
|
||||
|
||||
def _model_config():
|
||||
return SimpleNamespace(
|
||||
multimodal_config=None,
|
||||
allowed_local_media_path="",
|
||||
allowed_media_domains=None,
|
||||
)
|
||||
|
||||
|
||||
def _load_reference_case(case_id: int):
|
||||
data = json.loads((FIXTURES_DIR / f"test_input_{case_id}.json").read_text())
|
||||
if isinstance(data, dict):
|
||||
return data["messages"], data.get("tools")
|
||||
return data, None
|
||||
|
||||
|
||||
def _render_reference_case(case_id: int, **kwargs):
|
||||
messages, tools = _load_reference_case(case_id)
|
||||
conversation, _, _ = parse_chat_messages(
|
||||
messages,
|
||||
_model_config(),
|
||||
content_format="string",
|
||||
)
|
||||
return _tokenizer().apply_chat_template(
|
||||
conversation=conversation,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def test_deepseek_v4_tokenizer_registered():
|
||||
assert TokenizerRegistry.load_tokenizer_cls("deepseek_v4").__name__ == (
|
||||
"DeepseekV4Tokenizer"
|
||||
)
|
||||
assert RENDERER_REGISTRY.load_renderer_cls("deepseek_v4").__name__ == (
|
||||
"DeepseekV4Renderer"
|
||||
)
|
||||
|
||||
|
||||
def test_deepseek_v4_defaults_to_chat_mode():
|
||||
prompt = _tokenizer().apply_chat_template(
|
||||
[{"role": "user", "content": "Hello"}],
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|></think>")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kwargs", [{"thinking": True}, {"enable_thinking": True}])
|
||||
def test_deepseek_v4_enables_thinking_with_compatible_kwargs(kwargs):
|
||||
prompt = _tokenizer().apply_chat_template(
|
||||
[{"role": "user", "content": "Hello"}],
|
||||
tokenize=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|><think>")
|
||||
|
||||
|
||||
def test_deepseek_v4_uses_v4_tool_prompt_from_request_tools():
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather for a city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
prompt = _tokenizer().apply_chat_template(
|
||||
[{"role": "user", "content": "Weather?"}],
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
assert "## Tools" in prompt
|
||||
assert "<|DSML|tool_calls>" in prompt
|
||||
assert "</|DSML|tool_calls>" in prompt
|
||||
assert "function_calls" not in prompt
|
||||
assert '"name": "get_weather"' in prompt
|
||||
assert prompt.endswith("<|User|>Weather?<|Assistant|></think>")
|
||||
|
||||
|
||||
def test_deepseek_v4_renders_parsed_history_tool_arguments():
|
||||
messages = [
|
||||
{"role": "user", "content": "List the repo"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "str_replace_editor",
|
||||
"arguments": '{"command": "view", "path": "/testbed"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": "file list",
|
||||
},
|
||||
]
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "str_replace_editor",
|
||||
"description": "Edit files",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {"type": "string"},
|
||||
"path": {"type": "string"},
|
||||
},
|
||||
"required": ["command", "path"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
conversation, _, _ = parse_chat_messages(
|
||||
messages,
|
||||
_model_config(),
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
prompt = _tokenizer().apply_chat_template(
|
||||
conversation=conversation,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
assert '<|DSML|parameter name="command" string="true">view' in prompt
|
||||
assert '<|DSML|parameter name="path" string="true">/testbed' in prompt
|
||||
assert 'parameter name="arguments"' not in prompt
|
||||
|
||||
|
||||
@pytest.mark.parametrize("reasoning_effort", ["none", "low", "medium", "high"])
|
||||
def test_deepseek_v4_accepts_openai_reasoning_effort_values(reasoning_effort):
|
||||
prompt = _tokenizer().apply_chat_template(
|
||||
[{"role": "user", "content": "Hello"}],
|
||||
tokenize=False,
|
||||
enable_thinking=True,
|
||||
reasoning_effort=reasoning_effort,
|
||||
)
|
||||
|
||||
assert prompt.endswith("<|Assistant|><think>")
|
||||
assert "Reasoning Effort: Absolute maximum" not in prompt
|
||||
|
||||
|
||||
def test_deepseek_v4_preserves_reference_max_reasoning_effort():
|
||||
prompt = _tokenizer().apply_chat_template(
|
||||
[{"role": "user", "content": "Hello"}],
|
||||
tokenize=False,
|
||||
enable_thinking=True,
|
||||
reasoning_effort="max",
|
||||
)
|
||||
|
||||
assert prompt.startswith(
|
||||
"<|begin▁of▁sentence|>Reasoning Effort: Absolute maximum"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("case_id", "kwargs"),
|
||||
[
|
||||
(1, {"thinking": True}),
|
||||
(2, {"thinking": True}),
|
||||
(3, {"thinking": True}),
|
||||
(4, {}),
|
||||
],
|
||||
)
|
||||
def test_deepseek_v4_matches_reference_golden_fixtures(case_id, kwargs):
|
||||
prompt = _render_reference_case(case_id, **kwargs)
|
||||
|
||||
expected = (FIXTURES_DIR / f"test_output_{case_id}.txt").read_text()
|
||||
assert prompt == expected
|
||||
@@ -0,0 +1,123 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""Unit tests for DeepSeekV4ToolParser."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from vllm.tool_parsers import ToolParserManager
|
||||
from vllm.tool_parsers.deepseekv4_tool_parser import DeepSeekV4ToolParser
|
||||
|
||||
MOCK_TOKENIZER = MagicMock()
|
||||
MOCK_TOKENIZER.get_vocab.return_value = {}
|
||||
|
||||
TC_START = "<|DSML|tool_calls>"
|
||||
TC_END = "</|DSML|tool_calls>"
|
||||
INV_START = '<|DSML|invoke name="'
|
||||
INV_END = "</|DSML|invoke>"
|
||||
PARAM_START = '<|DSML|parameter name="'
|
||||
PARAM_END = "</|DSML|parameter>"
|
||||
|
||||
|
||||
def make_parser(tools=None) -> DeepSeekV4ToolParser:
|
||||
return DeepSeekV4ToolParser(MOCK_TOKENIZER, tools=tools)
|
||||
|
||||
|
||||
def make_request(tools=None) -> MagicMock:
|
||||
req = MagicMock()
|
||||
req.tools = tools
|
||||
return req
|
||||
|
||||
|
||||
def build_tool_call(func_name: str, params: dict[str, str]) -> str:
|
||||
param_strs = "".join(
|
||||
f'{PARAM_START}{k}" string="true">{v}{PARAM_END}\n' for k, v in params.items()
|
||||
)
|
||||
return f'{TC_START}\n{INV_START}{func_name}">\n{param_strs}{INV_END}\n{TC_END}'
|
||||
|
||||
|
||||
def stream(parser: DeepSeekV4ToolParser, full_text: str, chunk_size: int = 7):
|
||||
deltas = []
|
||||
previous_text = ""
|
||||
for start in range(0, len(full_text), chunk_size):
|
||||
delta_text = full_text[start : start + chunk_size]
|
||||
current_text = previous_text + delta_text
|
||||
delta = parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[1],
|
||||
request=make_request(),
|
||||
)
|
||||
previous_text = current_text
|
||||
if delta is not None:
|
||||
deltas.append(delta)
|
||||
return deltas
|
||||
|
||||
|
||||
def reconstruct_args(deltas, tool_index: int = 0) -> str:
|
||||
fragments = []
|
||||
for delta in deltas:
|
||||
if delta.tool_calls:
|
||||
for tool_call in delta.tool_calls:
|
||||
if (
|
||||
tool_call.index == tool_index
|
||||
and tool_call.function
|
||||
and tool_call.function.arguments
|
||||
):
|
||||
fragments.append(tool_call.function.arguments)
|
||||
return "".join(fragments)
|
||||
|
||||
|
||||
def test_registered():
|
||||
assert ToolParserManager.get_tool_parser("deepseek_v4") is DeepSeekV4ToolParser
|
||||
|
||||
|
||||
def test_extract_tool_calls():
|
||||
parser = make_parser()
|
||||
model_output = "Let me check. " + build_tool_call(
|
||||
"get_weather", {"location": "Beijing", "unit": "celsius"}
|
||||
)
|
||||
|
||||
result = parser.extract_tool_calls(model_output, make_request())
|
||||
|
||||
assert result.tools_called
|
||||
assert result.content == "Let me check. "
|
||||
assert len(result.tool_calls) == 1
|
||||
tool_call = result.tool_calls[0]
|
||||
assert tool_call.function.name == "get_weather"
|
||||
assert json.loads(tool_call.function.arguments) == {
|
||||
"location": "Beijing",
|
||||
"unit": "celsius",
|
||||
}
|
||||
|
||||
|
||||
def test_function_calls_block_is_not_accepted():
|
||||
parser = make_parser()
|
||||
model_output = build_tool_call("search", {"query": "vllm"}).replace(
|
||||
"tool_calls", "function_calls"
|
||||
)
|
||||
|
||||
result = parser.extract_tool_calls(model_output, make_request())
|
||||
|
||||
assert not result.tools_called
|
||||
assert result.content == model_output
|
||||
|
||||
|
||||
def test_streaming_extracts_complete_invokes():
|
||||
parser = make_parser()
|
||||
full_text = build_tool_call("search", {"query": "deepseek v4"})
|
||||
|
||||
deltas = stream(parser, full_text, chunk_size=5)
|
||||
|
||||
names = [
|
||||
tool_call.function.name
|
||||
for delta in deltas
|
||||
if delta.tool_calls
|
||||
for tool_call in delta.tool_calls
|
||||
]
|
||||
assert names == ["search"]
|
||||
assert json.loads(reconstruct_args(deltas)) == {"query": "deepseek v4"}
|
||||
@@ -0,0 +1,92 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.utils import create_vllm_config
|
||||
from vllm.v1.attention.backend import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
|
||||
def test_indexer_builder_deepseek_v4_compressed_slot_mapping_uses_storage_block_size():
|
||||
"""Regression test: DeepseekV4 compression path must compute slot_mapping from
|
||||
compressed positions, not reuse the uncompressed common metadata mapping.
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
|
||||
# storage_block_size = block_size // compress_ratio = 256 // 4 = 64
|
||||
kv_cache_spec = MLAAttentionSpec(
|
||||
block_size=256,
|
||||
num_kv_heads=1,
|
||||
head_size=128,
|
||||
dtype=torch.bfloat16,
|
||||
compress_ratio=4,
|
||||
)
|
||||
vllm_config = create_vllm_config(max_model_len=1024)
|
||||
builder = DeepseekV32IndexerMetadataBuilder(
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
layer_names=["dummy"],
|
||||
vllm_config=vllm_config,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Construct a single request where:
|
||||
# - num_computed = 240 (=> compressed_pos_start = 60)
|
||||
# - query_len = 40 (=> num_groups = 10)
|
||||
# => compressed positions are 60..69 which cross the storage block boundary at 64.
|
||||
query_start_loc = torch.tensor([0, 40], dtype=torch.int32, device=device)
|
||||
query_start_loc_cpu = query_start_loc.cpu()
|
||||
seq_lens = torch.tensor([280], dtype=torch.int32, device=device) # 240 + 40
|
||||
|
||||
# Two blocks: compressed positions 0..63 map to block 5, 64..127 map to block 7.
|
||||
block_table_tensor = torch.tensor([[5, 7]], dtype=torch.int32, device=device)
|
||||
|
||||
# Dummy uncompressed slot mapping (length == uncompressed num_actual_tokens).
|
||||
slot_mapping = torch.full((40,), -123, dtype=torch.int64, device=device)
|
||||
|
||||
common = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_cpu_upper_bound=seq_lens.cpu(),
|
||||
num_reqs=1,
|
||||
num_actual_tokens=40,
|
||||
max_query_len=40,
|
||||
max_seq_len=280,
|
||||
block_table_tensor=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
md = builder.build(common_prefix_len=0, common_attn_metadata=common)
|
||||
|
||||
# The compressed slot_mapping retains the original uncompressed size (40).
|
||||
# Only every compress_ratio-th position gets a valid slot; the rest are -1.
|
||||
assert md.slot_mapping.numel() == 40
|
||||
valid_slots = md.slot_mapping[md.slot_mapping >= 0]
|
||||
assert valid_slots.numel() == 10 # 40 tokens / compress_ratio 4
|
||||
|
||||
storage_bs = kv_cache_spec.storage_block_size # 64
|
||||
# Compressed positions 60..63 land in block 5, positions 64..69 in block 7.
|
||||
expected = torch.tensor(
|
||||
[
|
||||
5 * storage_bs + 60,
|
||||
5 * storage_bs + 61,
|
||||
5 * storage_bs + 62,
|
||||
5 * storage_bs + 63,
|
||||
]
|
||||
+ [
|
||||
7 * storage_bs + 0,
|
||||
7 * storage_bs + 1,
|
||||
7 * storage_bs + 2,
|
||||
7 * storage_bs + 3,
|
||||
7 * storage_bs + 4,
|
||||
7 * storage_bs + 5,
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
torch.testing.assert_close(valid_slots, expected)
|
||||
@@ -1855,10 +1855,11 @@ def test_generate_scheduler_kv_cache_config():
|
||||
|
||||
|
||||
def new_mla_spec(cache_dtype_str=None):
|
||||
# head_size = kv_lora_rank(512) + qk_rope_head_dim(64) = 576
|
||||
return MLAAttentionSpec(
|
||||
block_size=16,
|
||||
num_kv_heads=16,
|
||||
head_size=64,
|
||||
num_kv_heads=1,
|
||||
head_size=576,
|
||||
dtype=torch.float32,
|
||||
cache_dtype_str=cache_dtype_str,
|
||||
)
|
||||
|
||||
@@ -557,19 +557,19 @@ def test_prefill_hybrid_model_eagle():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(req1.block_hashes) == num_full_blocks
|
||||
assert computed_blocks.get_block_ids() == (
|
||||
[1, 2, 3, 4],
|
||||
[0, 9, 10, 11],
|
||||
[0, 16, 17, 18],
|
||||
[1, 2, 3, 4, 5],
|
||||
[0, 0, 10, 11, 12],
|
||||
[0, 0, 17, 18, 19],
|
||||
)
|
||||
assert num_computed_tokens == 4 * block_size
|
||||
assert num_computed_tokens == 5 * block_size
|
||||
num_new_tokens = len(all_token_ids) - num_computed_tokens
|
||||
blocks = manager.allocate_slots(
|
||||
req1, num_new_tokens, num_computed_tokens, computed_blocks
|
||||
)
|
||||
assert blocks is not None and blocks.get_block_ids() == (
|
||||
[22, 23, 24],
|
||||
[25, 26, 27],
|
||||
[28, 29, 30],
|
||||
[22, 23],
|
||||
[24, 25],
|
||||
[26, 27],
|
||||
)
|
||||
for block_per_group in computed_blocks.blocks:
|
||||
for block in block_per_group:
|
||||
@@ -591,7 +591,7 @@ def test_prefill_hybrid_model_eagle():
|
||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||
make_block_hash_with_group_id(block_hashes[0], 2),
|
||||
],
|
||||
4,
|
||||
5,
|
||||
)
|
||||
|
||||
# Evict the first block of full attention, makes total cache miss.
|
||||
@@ -605,7 +605,7 @@ def test_prefill_hybrid_model_eagle():
|
||||
0,
|
||||
)
|
||||
|
||||
# Evict the last block of all layers, reduces the hit length to 3.
|
||||
# Evict the last block of all layers, reduces the hit length to 4.
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
@@ -617,10 +617,10 @@ def test_prefill_hybrid_model_eagle():
|
||||
make_block_hash_with_group_id(block_hashes[-1], 1),
|
||||
make_block_hash_with_group_id(block_hashes[-1], 2),
|
||||
],
|
||||
3,
|
||||
4,
|
||||
)
|
||||
|
||||
# Evict the last block of full attention, reduces the hit length to 3.
|
||||
# Evict the last block of full attention, reduces the hit length to 4.
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
@@ -628,7 +628,7 @@ def test_prefill_hybrid_model_eagle():
|
||||
"5",
|
||||
all_token_ids,
|
||||
[make_block_hash_with_group_id(block_hashes[-1], 0)],
|
||||
3,
|
||||
4,
|
||||
)
|
||||
|
||||
# Since the last block of full attention is dropped for eagle, evict
|
||||
@@ -655,12 +655,11 @@ def test_prefill_hybrid_model_eagle():
|
||||
3,
|
||||
)
|
||||
|
||||
# Evict different set of blocks for full attention and sliding window makes
|
||||
# total cache miss.
|
||||
# The cache hit length of full attention is 4 * block_size.
|
||||
# The cache hit length of sliding window is 3 * block_size.
|
||||
# Then it is cache miss as the two type of layers
|
||||
# have different hit length.
|
||||
# Evict different set of blocks for full attention and sliding window.
|
||||
# Full loses its last block so it drops to 4 full blocks after the eagle
|
||||
# pop; SWA lost block 0 (outside the sliding window of the final hit),
|
||||
# which is not required for the K+1 anchor at position 4. Coordinated
|
||||
# single-drop aligns both groups at hit=4.
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
@@ -672,7 +671,7 @@ def test_prefill_hybrid_model_eagle():
|
||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||
make_block_hash_with_group_id(block_hashes[0], 2),
|
||||
],
|
||||
0,
|
||||
4,
|
||||
)
|
||||
|
||||
|
||||
@@ -893,7 +892,7 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]):
|
||||
# - 2 groups: 1 full + 1 other
|
||||
_EAGLE_HYBRID_MODEL_TEST_CASES = [
|
||||
# 2 groups: 1 full + 1 other
|
||||
pytest.param(["full", "sliding_window"], 2, id="2g-full+sw"),
|
||||
pytest.param(["full", "sliding_window"], 3, id="2g-full+sw"),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1892,6 +1892,7 @@ def create_scheduler_with_priority(
|
||||
log_stats=True,
|
||||
structured_output_manager=StructuredOutputManager(vllm_config),
|
||||
block_size=block_size,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -4008,6 +4009,7 @@ def _create_encoder_decoder_scheduler(
|
||||
vllm_config=vllm_config,
|
||||
kv_cache_config=kv_cache_config,
|
||||
block_size=block_size,
|
||||
hash_block_size=block_size,
|
||||
structured_output_manager=StructuredOutputManager(vllm_config),
|
||||
)
|
||||
|
||||
|
||||
@@ -91,8 +91,10 @@ def test_basic_interface():
|
||||
assert request_id in kv_connector_metadata.reqs_to_recv["my-engine-id"]
|
||||
req_meta = kv_connector_metadata.reqs_to_recv["my-engine-id"][request_id]
|
||||
|
||||
# local_block_ids is list[list[int]] (per-group); flatten for comparison.
|
||||
all_block_ids = [bid for group in req_meta.local_block_ids for bid in group]
|
||||
for block_id, block in zip(
|
||||
req_meta.local_block_ids,
|
||||
all_block_ids,
|
||||
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
|
||||
request_id
|
||||
],
|
||||
@@ -228,15 +230,15 @@ def test_scheduler_request_finished():
|
||||
|
||||
# Case: Capped length (Successful prefill, need to send to decoder)
|
||||
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
delay_free, _ = scheduler_connector.request_finished(request, block_ids=[10, 11])
|
||||
delay_free, _ = scheduler_connector.request_finished(request, block_ids=([10, 11],))
|
||||
assert delay_free is True
|
||||
assert "id-1" in scheduler_connector._reqs_need_send
|
||||
assert scheduler_connector._reqs_need_send["id-1"][1] == [10, 11]
|
||||
assert scheduler_connector._reqs_need_send["id-1"][1] == [[10, 11]]
|
||||
|
||||
# Case: Aborted (No need to transfer, free blocks immediately)
|
||||
scheduler_connector._reqs_need_send.clear()
|
||||
request.status = RequestStatus.FINISHED_ABORTED
|
||||
delay_free, _ = scheduler_connector.request_finished(request, block_ids=[12])
|
||||
delay_free, _ = scheduler_connector.request_finished(request, block_ids=([12],))
|
||||
assert delay_free is False
|
||||
assert len(scheduler_connector._reqs_need_send) == 0
|
||||
assert "id-1" in scheduler_connector._reqs_not_processed
|
||||
@@ -334,7 +336,7 @@ async def test_kv_producer(monkeypatch):
|
||||
send_meta = SendBlockMeta(
|
||||
p_req_id="p-req-1",
|
||||
transfer_id=transfer_id,
|
||||
local_block_ids=[10, 11],
|
||||
local_block_ids=[[10, 11]],
|
||||
ready=asyncio.Event(),
|
||||
)
|
||||
prefill_worker.reqs_need_send[transfer_id] = send_meta
|
||||
@@ -346,7 +348,7 @@ async def test_kv_producer(monkeypatch):
|
||||
remote_port=54321,
|
||||
remote_tp_size=1,
|
||||
remote_tp_rank=0,
|
||||
req_blocks={"d-req-1": (transfer_id, [20, 21])},
|
||||
req_blocks={"d-req-1": (transfer_id, [[20, 21]])},
|
||||
kv_caches_base_addr=[0x2000],
|
||||
block_lens=[block_len],
|
||||
)
|
||||
@@ -389,7 +391,7 @@ async def test_kv_producer(monkeypatch):
|
||||
prefill_worker.reqs_need_send[transfer_id] = send_meta
|
||||
send_meta.sent = 0
|
||||
send_meta.ready.set()
|
||||
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20])
|
||||
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20]])
|
||||
# Worker processes the consumer's request
|
||||
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
|
||||
# Verify transfer parameters are correct: 11 to 20
|
||||
@@ -407,7 +409,7 @@ async def test_kv_producer(monkeypatch):
|
||||
prefill_worker.reqs_need_send[transfer_id] = send_meta
|
||||
send_meta.sent = 0
|
||||
send_meta.ready.set()
|
||||
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21, 22])
|
||||
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20, 21, 22]])
|
||||
# Worker processes the consumer's request
|
||||
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
|
||||
# This should not be called because error.
|
||||
@@ -424,7 +426,7 @@ async def test_kv_producer(monkeypatch):
|
||||
prefill_worker.reqs_need_send[transfer_id] = send_meta
|
||||
send_meta.sent = 0
|
||||
send_meta.ready.clear()
|
||||
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21])
|
||||
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20, 21]])
|
||||
# Worker processes the consumer's request
|
||||
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
|
||||
# This should not be called because timeout.
|
||||
@@ -443,7 +445,7 @@ async def test_kv_producer(monkeypatch):
|
||||
prefill_worker.reqs_need_send[transfer_id] = send_meta
|
||||
send_meta.sent = 0
|
||||
send_meta.ready.set()
|
||||
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21])
|
||||
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20, 21]])
|
||||
# Worker processes the consumer's request
|
||||
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
|
||||
mock_send_blocks.assert_called_once()
|
||||
@@ -481,7 +483,7 @@ async def test_kv_consumuer(monkeypatch):
|
||||
"d-req-1": PullReqMeta(
|
||||
d_req_id="d-req-1",
|
||||
transfer_id="xfer-req-1",
|
||||
local_block_ids=[100, 101],
|
||||
local_block_ids=[[100, 101]],
|
||||
remote_engine_id="p-engine",
|
||||
remote_bootstrap_addr="http://bootstrap:33333",
|
||||
pull_tasks_count=1,
|
||||
@@ -514,7 +516,7 @@ async def test_kv_consumuer(monkeypatch):
|
||||
|
||||
assert sent_meta.remote_hostname == "127.0.0.1"
|
||||
assert sent_meta.remote_port == 54321
|
||||
assert sent_meta.req_blocks["d-req-1"] == ("xfer-req-1", [100, 101])
|
||||
assert sent_meta.req_blocks["d-req-1"] == ("xfer-req-1", [[100, 101]])
|
||||
|
||||
# Verify internal state is updated correctly.
|
||||
assert "d-req-1" in decode_worker.finished_recving_reqs
|
||||
@@ -538,7 +540,7 @@ async def test_worker_get_finished_timeout(monkeypatch):
|
||||
prefill_worker.reqs_need_send["tx-expired"] = SendBlockMeta(
|
||||
p_req_id="p-req-expired",
|
||||
transfer_id="tx-expired",
|
||||
local_block_ids=[1, 2],
|
||||
local_block_ids=[[1, 2]],
|
||||
ready=MagicMock(),
|
||||
expire_time=time.perf_counter() - 100,
|
||||
)
|
||||
@@ -547,7 +549,7 @@ async def test_worker_get_finished_timeout(monkeypatch):
|
||||
prefill_worker.reqs_need_send["tx-active"] = SendBlockMeta(
|
||||
p_req_id="p-req-active",
|
||||
transfer_id="tx-active",
|
||||
local_block_ids=[3, 4],
|
||||
local_block_ids=[[3, 4]],
|
||||
ready=MagicMock(),
|
||||
expire_time=time.perf_counter() + 100,
|
||||
)
|
||||
@@ -703,7 +705,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
|
||||
prefill_worker.sender_loop = asyncio.get_event_loop()
|
||||
|
||||
transfer_id = "xfer-hetero-1"
|
||||
local_block_ids = [10, 11]
|
||||
local_block_ids = [[10, 11]]
|
||||
send_meta = SendBlockMeta(
|
||||
p_req_id="p-req-h1",
|
||||
transfer_id=transfer_id,
|
||||
@@ -720,9 +722,9 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
|
||||
mock_socket.send_multipart = AsyncMock()
|
||||
identity = b"consumer-hetero"
|
||||
|
||||
# Assign different remote block IDs per D rank
|
||||
# Assign different remote block IDs per D rank (nested per-group)
|
||||
d_rank_remote_blocks = {
|
||||
rank: [20 + i * 10, 21 + i * 10] for i, rank in enumerate(target_d_ranks)
|
||||
rank: [[20 + i * 10, 21 + i * 10]] for i, rank in enumerate(target_d_ranks)
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
@@ -757,11 +759,15 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
|
||||
dst_ptrs = call_args[2]
|
||||
lengths = call_args[3]
|
||||
|
||||
# Flatten nested per-group block IDs for assertions
|
||||
flat_local = [b for g in local_block_ids for b in g]
|
||||
flat_remote = [b for g in remote_block_ids for b in g]
|
||||
|
||||
# Heterogeneous TP: blocks cannot be coalesced because
|
||||
# local and remote block_lens differ
|
||||
assert len(src_ptrs) == len(local_block_ids)
|
||||
assert len(dst_ptrs) == len(local_block_ids)
|
||||
assert len(lengths) == len(local_block_ids)
|
||||
assert len(src_ptrs) == len(flat_local)
|
||||
assert len(dst_ptrs) == len(flat_local)
|
||||
assert len(lengths) == len(flat_local)
|
||||
|
||||
# Compute expected offsets based on TP ratio
|
||||
if d_tp_size <= P_TP_SIZE:
|
||||
@@ -775,9 +781,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
|
||||
expected_dst_off = 0
|
||||
expected_xfer_len = remote_block_len
|
||||
|
||||
for idx, (lblk, rblk) in enumerate(
|
||||
zip(local_block_ids, remote_block_ids)
|
||||
):
|
||||
for idx, (lblk, rblk) in enumerate(zip(flat_local, flat_remote)):
|
||||
assert src_ptrs[idx] == (
|
||||
0x1000 + lblk * local_block_len + expected_src_off
|
||||
)
|
||||
|
||||
@@ -0,0 +1,410 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for MooncakeConnector HMA (Hybrid Memory Architecture) support.
|
||||
|
||||
Covers sliding-window clipping, multi-group metadata shape, multi-group
|
||||
send trimming, and group-count invariant checking in _build_transfer_params.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector import (
|
||||
KVConnectorRole,
|
||||
MooncakeConnector,
|
||||
MooncakeConnectorMetadata,
|
||||
MooncakeConnectorScheduler,
|
||||
MooncakeXferMetadata,
|
||||
SendBlockMeta,
|
||||
TransferRegion,
|
||||
)
|
||||
|
||||
from .test_mooncake_connector import FakeMooncakeWrapper, patch_worker_dependencies
|
||||
from .utils import create_request, create_vllm_config, make_kv_cache_config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# test_sw_sizes: blocks_per_sw computed from KVCacheConfig
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.cpu_test
|
||||
@pytest.mark.parametrize(
|
||||
"swa_enabled,expected_blocks_per_sw",
|
||||
[
|
||||
# SWA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128+1)
|
||||
(True, [0, 128 + 1]),
|
||||
# SWA disabled: only FullAttentionSpec (0)
|
||||
(False, [0]),
|
||||
],
|
||||
)
|
||||
def test_sw_sizes(swa_enabled, expected_blocks_per_sw):
|
||||
"""blocks_per_sw is correctly computed based on SWA enabled/disabled."""
|
||||
block_size = 16
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector",
|
||||
kv_role="kv_both",
|
||||
block_size=block_size,
|
||||
)
|
||||
# Override so HMA detection works
|
||||
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
|
||||
kv_cache_config = make_kv_cache_config(
|
||||
block_size=block_size, swa_enabled=swa_enabled, sw_size=2048
|
||||
)
|
||||
|
||||
scheduler = MooncakeConnectorScheduler(
|
||||
vllm_config=vllm_config,
|
||||
engine_id="test-engine",
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
assert scheduler.blocks_per_sw == expected_blocks_per_sw
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# test_is_hma_required: derived from kv_cache_config groups
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.cpu_test
|
||||
@pytest.mark.parametrize(
|
||||
"swa_enabled,disable_hma,expected_is_hma",
|
||||
[
|
||||
(True, False, True), # SWA group present, HMA enabled
|
||||
(True, True, False), # SWA group present, but HMA disabled
|
||||
(False, False, False), # FA only, HMA not needed
|
||||
],
|
||||
)
|
||||
def test_is_hma_required(swa_enabled, disable_hma, expected_is_hma):
|
||||
"""_is_hma_required is correctly derived from kv_cache_config."""
|
||||
block_size = 16
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector",
|
||||
kv_role="kv_both",
|
||||
block_size=block_size,
|
||||
)
|
||||
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = disable_hma
|
||||
kv_cache_config = make_kv_cache_config(
|
||||
block_size=block_size, swa_enabled=swa_enabled
|
||||
)
|
||||
|
||||
scheduler = MooncakeConnectorScheduler(
|
||||
vllm_config=vllm_config,
|
||||
engine_id="test-engine",
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
assert scheduler._is_hma_required is expected_is_hma
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# test_get_sw_clipped_blocks: sliding-window clipping logic
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.cpu_test
|
||||
def test_get_sw_clipped_blocks():
|
||||
"""get_sw_clipped_blocks clips SWA group but keeps FA group intact."""
|
||||
block_size = 16
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector",
|
||||
kv_role="kv_both",
|
||||
block_size=block_size,
|
||||
)
|
||||
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
|
||||
# SW=128 tokens → 128/16 = 8 blocks + 1 = 9 blocks_per_sw
|
||||
kv_cache_config = make_kv_cache_config(
|
||||
block_size=block_size, swa_enabled=True, sw_size=128
|
||||
)
|
||||
|
||||
scheduler = MooncakeConnectorScheduler(
|
||||
vllm_config=vllm_config,
|
||||
engine_id="test-engine",
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
assert scheduler.blocks_per_sw == [0, 9]
|
||||
|
||||
# FA group: 20 blocks, SW group: 20 blocks (exceeds window)
|
||||
fa_blocks = list(range(20))
|
||||
sw_blocks = list(range(100, 120))
|
||||
block_ids = (fa_blocks, sw_blocks)
|
||||
|
||||
clipped = scheduler.get_sw_clipped_blocks(block_ids)
|
||||
|
||||
# FA: untouched (blocks_per_sw[0] = 0)
|
||||
assert clipped[0] == fa_blocks
|
||||
# SW: clipped to last 9 blocks
|
||||
assert clipped[1] == sw_blocks[-9:]
|
||||
assert len(clipped[1]) == 9
|
||||
|
||||
|
||||
@pytest.mark.cpu_test
|
||||
def test_get_sw_clipped_blocks_noop_no_hma():
|
||||
"""get_sw_clipped_blocks is a no-op when HMA is not required."""
|
||||
block_size = 16
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector",
|
||||
kv_role="kv_both",
|
||||
block_size=block_size,
|
||||
)
|
||||
# FA only → _is_hma_required = False
|
||||
kv_cache_config = make_kv_cache_config(block_size=block_size, swa_enabled=False)
|
||||
|
||||
scheduler = MooncakeConnectorScheduler(
|
||||
vllm_config=vllm_config,
|
||||
engine_id="test-engine",
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
assert scheduler._is_hma_required is False
|
||||
|
||||
block_ids = ([1, 2, 3],)
|
||||
clipped = scheduler.get_sw_clipped_blocks(block_ids)
|
||||
assert clipped == [[1, 2, 3]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# test_metadata_hma_block_ids: MooncakeConnectorMetadata stores per-group IDs
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.cpu_test
|
||||
def test_metadata_hma_block_ids():
|
||||
"""MooncakeConnectorMetadata.add_new_req stores per-group block IDs."""
|
||||
metadata = MooncakeConnectorMetadata()
|
||||
|
||||
# FA group: 6 blocks, SW group: 3 blocks (clipped)
|
||||
fa_blocks = [0, 1, 2, 3, 4, 5]
|
||||
sw_blocks = [10, 11, 12]
|
||||
|
||||
# Test recv path
|
||||
metadata.add_new_req(
|
||||
request_id="recv-req",
|
||||
local_block_ids=[fa_blocks, sw_blocks],
|
||||
kv_transfer_params={
|
||||
"transfer_id": "recv-req",
|
||||
"remote_engine_id": "remote-engine",
|
||||
"remote_bootstrap_addr": "http://bootstrap:33333",
|
||||
},
|
||||
load_remote_cache=True,
|
||||
)
|
||||
|
||||
assert "recv-req" in metadata.reqs_to_recv["remote-engine"]
|
||||
req_meta = metadata.reqs_to_recv["remote-engine"]["recv-req"]
|
||||
assert len(req_meta.local_block_ids) == 2
|
||||
assert req_meta.local_block_ids[0] == fa_blocks
|
||||
assert req_meta.local_block_ids[1] == sw_blocks
|
||||
|
||||
# Test send path
|
||||
metadata.add_new_req(
|
||||
request_id="send-req",
|
||||
local_block_ids=[fa_blocks, sw_blocks],
|
||||
kv_transfer_params={
|
||||
"transfer_id": "send-req",
|
||||
},
|
||||
load_remote_cache=False,
|
||||
)
|
||||
|
||||
assert "send-req" in metadata.reqs_to_send
|
||||
transfer_id, stored_blocks = metadata.reqs_to_send["send-req"]
|
||||
assert transfer_id == "send-req"
|
||||
assert len(stored_blocks) == 2
|
||||
assert stored_blocks[0] == fa_blocks
|
||||
assert stored_blocks[1] == sw_blocks
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# test_build_transfer_params_multi_group_trimming
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake"
|
||||
".mooncake_connector.TransferEngine",
|
||||
FakeMooncakeWrapper,
|
||||
)
|
||||
async def test_build_transfer_params_multi_group_trimming(monkeypatch):
|
||||
"""_build_transfer_params trims per-group blocks when local > remote."""
|
||||
|
||||
monkeypatch.setenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "5")
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector", kv_role="kv_producer"
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
|
||||
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
worker = connector.connector_worker
|
||||
|
||||
block_len = 4096
|
||||
# Call _build_transfer_params directly (avoids send_kv_to_decode
|
||||
# async event loop complexity).
|
||||
transfer_id = "xfer-hma-trim"
|
||||
send_meta = SendBlockMeta(
|
||||
p_req_id="p-trim",
|
||||
transfer_id=transfer_id,
|
||||
# FA: 4 blocks, SW: 3 blocks (producer has more)
|
||||
local_block_ids=[[10, 11, 12, 13], [20, 21, 22]],
|
||||
ready=asyncio.Event(),
|
||||
)
|
||||
|
||||
xfer_meta = MooncakeXferMetadata(
|
||||
remote_hostname="consumer-host",
|
||||
remote_port=54321,
|
||||
remote_tp_size=1,
|
||||
remote_tp_rank=0,
|
||||
req_blocks={
|
||||
"d-trim": (
|
||||
transfer_id,
|
||||
# FA: 2 blocks, SW: 2 blocks (consumer needs fewer)
|
||||
[[30, 31], [40, 41]],
|
||||
)
|
||||
},
|
||||
kv_caches_base_addr=[0x2000],
|
||||
block_lens=[block_len],
|
||||
)
|
||||
|
||||
local_regions = [
|
||||
TransferRegion(
|
||||
base_addr=0x1000, block_len=block_len, kv_block_len=block_len
|
||||
),
|
||||
]
|
||||
remote_regions = [
|
||||
TransferRegion(
|
||||
base_addr=0x2000, block_len=block_len, kv_block_len=block_len
|
||||
),
|
||||
]
|
||||
|
||||
ready_reqs = [("d-trim", send_meta)]
|
||||
(
|
||||
src_ptrs,
|
||||
dst_ptrs,
|
||||
lengths,
|
||||
err_reqs,
|
||||
err_msg,
|
||||
) = await worker._build_transfer_params(
|
||||
ready_reqs, xfer_meta, local_regions, remote_regions
|
||||
)
|
||||
|
||||
# No errors
|
||||
assert err_reqs == []
|
||||
assert err_msg is None
|
||||
# After trimming: FA [10..13] → last 2 → [12,13]; SW [20..22] → last 2 → [21,22]
|
||||
# Flattened: [12,13,21,22] = 4 blocks → coalesced into some transfers
|
||||
assert len(src_ptrs) > 0
|
||||
assert len(dst_ptrs) == len(src_ptrs)
|
||||
assert len(lengths) == len(src_ptrs)
|
||||
|
||||
worker.shutdown()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# test_build_transfer_params_group_count_mismatch
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake"
|
||||
".mooncake_connector.TransferEngine",
|
||||
FakeMooncakeWrapper,
|
||||
)
|
||||
async def test_build_transfer_params_group_count_mismatch(monkeypatch):
|
||||
"""_build_transfer_params reports an error when group counts differ."""
|
||||
|
||||
monkeypatch.setenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "5")
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector", kv_role="kv_producer"
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
|
||||
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
worker = connector.connector_worker
|
||||
|
||||
block_len = 4096
|
||||
transfer_id = "xfer-mismatch"
|
||||
send_meta = SendBlockMeta(
|
||||
p_req_id="p-mismatch",
|
||||
transfer_id=transfer_id,
|
||||
# Producer has 2 groups
|
||||
local_block_ids=[[10, 11], [20, 21]],
|
||||
ready=asyncio.Event(),
|
||||
)
|
||||
|
||||
# Consumer has only 1 group — group count mismatch
|
||||
xfer_meta = MooncakeXferMetadata(
|
||||
remote_hostname="consumer-host",
|
||||
remote_port=54321,
|
||||
remote_tp_size=1,
|
||||
remote_tp_rank=0,
|
||||
req_blocks={
|
||||
"d-mismatch": (transfer_id, [[30, 31]]),
|
||||
},
|
||||
kv_caches_base_addr=[0x2000],
|
||||
block_lens=[block_len],
|
||||
)
|
||||
|
||||
local_regions = [
|
||||
TransferRegion(
|
||||
base_addr=0x1000, block_len=block_len, kv_block_len=block_len
|
||||
),
|
||||
]
|
||||
remote_regions = [
|
||||
TransferRegion(
|
||||
base_addr=0x2000, block_len=block_len, kv_block_len=block_len
|
||||
),
|
||||
]
|
||||
|
||||
ready_reqs = [("d-mismatch", send_meta)]
|
||||
(
|
||||
src_ptrs,
|
||||
dst_ptrs,
|
||||
lengths,
|
||||
err_reqs,
|
||||
err_msg,
|
||||
) = await worker._build_transfer_params(
|
||||
ready_reqs, xfer_meta, local_regions, remote_regions
|
||||
)
|
||||
|
||||
# Mismatched req is reported via err_reqs/err_msg with no transfers built.
|
||||
assert err_reqs == ["d-mismatch"]
|
||||
assert err_msg == "KV group count mismatch"
|
||||
assert src_ptrs == []
|
||||
assert dst_ptrs == []
|
||||
assert lengths == []
|
||||
|
||||
worker.shutdown()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# test_request_finished_with_hma_groups
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.cpu_test
|
||||
def test_request_finished_with_hma_groups():
|
||||
"""request_finished correctly handles per-group block_ids."""
|
||||
block_size = 16
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector",
|
||||
kv_role="kv_producer",
|
||||
block_size=block_size,
|
||||
)
|
||||
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
|
||||
kv_cache_config = make_kv_cache_config(
|
||||
block_size=block_size, swa_enabled=True, sw_size=128
|
||||
)
|
||||
|
||||
scheduler = MooncakeConnectorScheduler(
|
||||
vllm_config=vllm_config,
|
||||
engine_id="test-engine",
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
|
||||
request = create_request(request_id=1, do_remote_decode=True)
|
||||
request.kv_transfer_params["transfer_id"] = request.request_id
|
||||
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
|
||||
# 2 groups: FA with 10 blocks, SW with 20 blocks (will be clipped)
|
||||
fa_blocks = list(range(10))
|
||||
sw_blocks = list(range(100, 120))
|
||||
block_ids = (fa_blocks, sw_blocks)
|
||||
|
||||
delay_free, _ = scheduler.request_finished(request, block_ids)
|
||||
assert delay_free is True
|
||||
assert request.request_id in scheduler._reqs_need_send
|
||||
|
||||
_, stored_blocks = scheduler._reqs_need_send[request.request_id]
|
||||
# FA: untouched
|
||||
assert stored_blocks[0] == fa_blocks
|
||||
# SW: clipped to last 9 blocks (sw_size=128, block_size=16 → 8+1=9)
|
||||
assert stored_blocks[1] == sw_blocks[-9:]
|
||||
@@ -76,6 +76,7 @@ def create_scheduler() -> Scheduler:
|
||||
log_stats=True,
|
||||
structured_output_manager=StructuredOutputManager(vllm_config),
|
||||
block_size=16,
|
||||
hash_block_size=16,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ set -e
|
||||
# Default values
|
||||
# Keep DEEPGEMM_GIT_REF in sync with cmake/external_projects/deepgemm.cmake
|
||||
DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
|
||||
DEEPGEMM_GIT_REF="477618cd51baffca09c4b0b87e97c03fe827ef03"
|
||||
DEEPGEMM_GIT_REF="891d57b4db1071624b5c8fa0d1e51cb317fa709f"
|
||||
WHEEL_DIR=""
|
||||
|
||||
# Parse command line arguments
|
||||
|
||||
+41
-3
@@ -404,10 +404,24 @@ def rotary_embedding(
|
||||
head_size: int,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
rope_dim_offset: int = 0,
|
||||
inverse: bool = False,
|
||||
) -> None:
|
||||
torch.ops._C.rotary_embedding(
|
||||
positions, query, key, head_size, cos_sin_cache, is_neox
|
||||
)
|
||||
if rope_dim_offset == 0 and not inverse:
|
||||
torch.ops._C.rotary_embedding(
|
||||
positions, query, key, head_size, cos_sin_cache, is_neox
|
||||
)
|
||||
else:
|
||||
torch.ops._C.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox,
|
||||
rope_dim_offset,
|
||||
inverse,
|
||||
)
|
||||
|
||||
|
||||
# layer norm ops
|
||||
@@ -2503,6 +2517,30 @@ def topk_sigmoid(
|
||||
)
|
||||
|
||||
|
||||
def topk_hash_softplus_sqrt(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
renormalize: bool = False,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
input_tokens: torch.Tensor | None = None,
|
||||
hash_indices_table: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
torch.ops._moe_C.topk_softplus_sqrt(
|
||||
topk_weights,
|
||||
topk_indices,
|
||||
token_expert_indices,
|
||||
gating_output,
|
||||
renormalize,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
input_tokens,
|
||||
hash_indices_table,
|
||||
)
|
||||
|
||||
|
||||
def grouped_topk(
|
||||
scores: torch.Tensor,
|
||||
num_expert_group: int,
|
||||
|
||||
@@ -51,6 +51,9 @@ class AttentionConfig:
|
||||
use_prefill_query_quantization: bool = False
|
||||
"""If set, quantize query for attention in prefill."""
|
||||
|
||||
use_fp4_indexer_cache: bool = False
|
||||
"""If set, use fp4 indexer cache for dsv32 family model (not support yet)"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
|
||||
@@ -51,6 +51,18 @@ class CacheConfig:
|
||||
"""Whether block_size was explicitly provided. Derived automatically."""
|
||||
user_specified_mamba_block_size: bool = field(default=False, init=False)
|
||||
"""Whether mamba_block_size was explicitly provided. Derived automatically."""
|
||||
hash_block_size: SkipValidation[int] | None = None # type: ignore
|
||||
"""Block size (in tokens) used for computing Request's block_hashes.
|
||||
|
||||
This can be set to a finer granularity than the physical KV cache block
|
||||
sizes (e.g. 8) as long as every KV cache group's `block_size` is divisible
|
||||
by it. This enables prefix-caching keys to be computed at the finest common
|
||||
granularity and then merged for larger physical block sizes.
|
||||
|
||||
This config is not static default. If left unspecified, vLLM will choose a
|
||||
default based on the resolved KV cache groups (typically the smallest KV
|
||||
cache block size when there are multiple groups).
|
||||
"""
|
||||
gpu_memory_utilization: float = Field(default=0.92, gt=0, le=1)
|
||||
"""The fraction of GPU memory to be used for the model executor, which can
|
||||
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
|
||||
@@ -182,6 +194,8 @@ class CacheConfig:
|
||||
"num_gpu_blocks_override",
|
||||
"enable_prefix_caching",
|
||||
"prefix_caching_hash_algo",
|
||||
# Prefix-caching implementation detail (doesn't affect compiled graph).
|
||||
"hash_block_size",
|
||||
"mamba_page_size_padded",
|
||||
"user_specified_block_size",
|
||||
"user_specified_mamba_block_size",
|
||||
|
||||
@@ -749,6 +749,7 @@ class CompilationConfig:
|
||||
"vllm::kda_attention",
|
||||
"vllm::sparse_attn_indexer",
|
||||
"vllm::rocm_aiter_sparse_attn_indexer",
|
||||
"vllm::deepseek_v4_attention",
|
||||
]
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
|
||||
@@ -50,7 +50,7 @@ class IrOpPriorityConfig:
|
||||
name: {
|
||||
provider: IrOp.registry[name].impls[provider].uuid() for provider in p
|
||||
}
|
||||
for name, p in asdict(self).items()
|
||||
for name, p in asdict(self).items() # type: ignore[call-overload]
|
||||
}
|
||||
|
||||
return hash_factors(factors)
|
||||
@@ -77,7 +77,7 @@ class IrOpPriorityConfig:
|
||||
current_platform.import_ir_kernels()
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
for field in fields(self):
|
||||
for field in fields(self): # type: ignore[arg-type]
|
||||
op_priority = getattr(self, field.name)
|
||||
assert op_priority is not None, (
|
||||
f"IR op priority for {field.name} must be set"
|
||||
@@ -98,7 +98,7 @@ class IrOpPriorityConfig:
|
||||
A helper to create an IrOpPriorityConfig where fields not specified in kwargs
|
||||
use the given default list.
|
||||
"""
|
||||
for field in fields(cls):
|
||||
for field in fields(cls): # type: ignore[arg-type]
|
||||
if field.name not in kwargs:
|
||||
kwargs[field.name] = list(default)
|
||||
|
||||
@@ -109,6 +109,7 @@ MoEBackend = Literal[
|
||||
"auto",
|
||||
"triton",
|
||||
"deep_gemm",
|
||||
"deep_gemm_mega_moe",
|
||||
"cutlass",
|
||||
"flashinfer_trtllm",
|
||||
"flashinfer_cutlass",
|
||||
@@ -136,8 +137,9 @@ class KernelConfig:
|
||||
"""Backend for MoE expert computation kernels. Available options:
|
||||
|
||||
- "auto": Automatically select the best backend based on model and hardware
|
||||
- "triton": Use Triton-based fused MoE kernels
|
||||
- "triton": Use Triton-based fused MoE kernels
|
||||
- "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only)
|
||||
- "deep_gemm_mega_moe": Use DeepGEMM mega MoE kernels
|
||||
- "cutlass": Use vLLM CUTLASS kernels
|
||||
- "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels
|
||||
- "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels
|
||||
|
||||
@@ -83,7 +83,7 @@ logger = init_logger(__name__)
|
||||
RunnerOption = Literal["auto", RunnerType]
|
||||
ConvertType = Literal["none", "embed", "classify"]
|
||||
ConvertOption = Literal["auto", ConvertType]
|
||||
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
|
||||
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32", "deepseek_v4"]
|
||||
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
|
||||
LogprobsMode = Literal[
|
||||
"raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs"
|
||||
@@ -134,6 +134,7 @@ class ModelConfig:
|
||||
- "slow" will always use the slow tokenizer.
|
||||
- "mistral" will always use the tokenizer from `mistral_common`.
|
||||
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.
|
||||
- "deepseek_v4" will always use the tokenizer from `deepseek_v4`.
|
||||
- "qwen_vl" will always use the tokenizer from `qwen_vl`.
|
||||
- Other custom values can be supported via plugins."""
|
||||
trust_remote_code: bool = False
|
||||
@@ -565,6 +566,8 @@ class ModelConfig:
|
||||
self.tokenizer_mode = "qwen_vl"
|
||||
elif arch == "DeepseekV32ForCausalLM":
|
||||
self.tokenizer_mode = "deepseek_v32"
|
||||
elif arch == "DeepseekV4ForCausalLM":
|
||||
self.tokenizer_mode = "deepseek_v4"
|
||||
|
||||
if self.tokenizer_mode != "auto":
|
||||
logger.info(
|
||||
@@ -952,6 +955,7 @@ class ModelConfig:
|
||||
# imports during override detection (e.g., MXFP4 imports Triton)
|
||||
"mxfp4",
|
||||
"gpt_oss_mxfp4",
|
||||
"deepseek_v4_fp8",
|
||||
"cpu_awq",
|
||||
"humming",
|
||||
"gguf",
|
||||
|
||||
@@ -287,13 +287,23 @@ class SpeculativeConfig:
|
||||
@staticmethod
|
||||
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
|
||||
initial_architecture = hf_config.architectures[0]
|
||||
if hf_config.model_type in ("deepseek_v3", "deepseek_v32", "glm_moe_dsa"):
|
||||
if hf_config.model_type in (
|
||||
"deepseek_v3",
|
||||
"deepseek_v32",
|
||||
"glm_moe_dsa",
|
||||
):
|
||||
hf_config.model_type = "deepseek_mtp"
|
||||
if hf_config.model_type == "deepseek_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
|
||||
)
|
||||
if hf_config.model_type == "deepseek_v4":
|
||||
hf_config.model_type = "deepseek_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["DeepSeekV4MTPModel"]}
|
||||
)
|
||||
if hf_config.model_type in ("pangu_ultra_moe"):
|
||||
hf_config.model_type = "pangu_ultra_moe_mtp"
|
||||
if hf_config.model_type == "pangu_ultra_moe_mtp":
|
||||
|
||||
@@ -29,6 +29,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
SupportsHMA,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_utils import (
|
||||
MooncakeBootstrapServer,
|
||||
@@ -43,10 +44,12 @@ from vllm.distributed.parallel_state import (
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, SlidingWindowSpec
|
||||
from vllm.v1.request import RequestStatus
|
||||
from vllm.v1.worker.utils import select_common_block_size
|
||||
|
||||
@@ -252,7 +255,7 @@ class MooncakeXferMetadata(
|
||||
remote_port: int
|
||||
remote_tp_size: int
|
||||
remote_tp_rank: int
|
||||
req_blocks: dict[ReqId, tuple[TransferId, list[int]]]
|
||||
req_blocks: dict[ReqId, tuple[TransferId, list[list[int]]]]
|
||||
kv_caches_base_addr: list[int]
|
||||
block_lens: list[int]
|
||||
|
||||
@@ -280,7 +283,7 @@ class MooncakeXferResponse(
|
||||
class PullReqMeta:
|
||||
d_req_id: ReqId
|
||||
transfer_id: TransferId
|
||||
local_block_ids: list[int]
|
||||
local_block_ids: list[list[int]]
|
||||
remote_engine_id: EngineId
|
||||
remote_bootstrap_addr: str
|
||||
# Set expire time to avoid infinitely sending requests.
|
||||
@@ -293,7 +296,7 @@ class PullReqMeta:
|
||||
class SendBlockMeta:
|
||||
p_req_id: ReqId
|
||||
transfer_id: TransferId
|
||||
local_block_ids: list[int]
|
||||
local_block_ids: list[list[int]]
|
||||
ready: asyncio.Event
|
||||
expire_time: float = float("inf")
|
||||
need_send: int = 0
|
||||
@@ -306,13 +309,13 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
# Use (engine_id, dp_rank) to group reqs with same dp.
|
||||
# See comments in MooncakeBootstrapServer.
|
||||
self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
|
||||
self.reqs_to_send: dict[ReqId, tuple[TransferId, list[int]]] = {}
|
||||
self.reqs_to_send: dict[ReqId, tuple[TransferId, list[list[int]]]] = {}
|
||||
self.reqs_not_processed: set[TransferId] = set()
|
||||
|
||||
def add_new_req(
|
||||
self,
|
||||
request_id: ReqId,
|
||||
local_block_ids: list[int],
|
||||
local_block_ids: list[list[int]],
|
||||
kv_transfer_params: dict[str, Any],
|
||||
load_remote_cache: bool = True,
|
||||
):
|
||||
@@ -330,7 +333,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
self.reqs_to_send[request_id] = (transfer_id, local_block_ids)
|
||||
|
||||
|
||||
class MooncakeConnector(KVConnectorBase_V1):
|
||||
class MooncakeConnector(KVConnectorBase_V1, SupportsHMA):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
@@ -344,13 +347,18 @@ class MooncakeConnector(KVConnectorBase_V1):
|
||||
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
|
||||
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
assert kv_cache_config is not None, (
|
||||
"kv_cache_config is required for SCHEDULER role"
|
||||
)
|
||||
self.connector_scheduler: MooncakeConnectorScheduler | None = (
|
||||
MooncakeConnectorScheduler(vllm_config, self.engine_id)
|
||||
MooncakeConnectorScheduler(vllm_config, self.engine_id, kv_cache_config)
|
||||
)
|
||||
self.connector_worker: MooncakeConnectorWorker | None = None
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)
|
||||
self.connector_worker = MooncakeConnectorWorker(
|
||||
vllm_config, self.engine_id, kv_cache_config
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
|
||||
@@ -401,6 +409,14 @@ class MooncakeConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, (block_ids,))
|
||||
|
||||
def request_finished_all_groups(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: tuple[list[int], ...],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
@@ -445,8 +461,14 @@ class MooncakeConnector(KVConnectorBase_V1):
|
||||
class MooncakeConnectorScheduler:
|
||||
"""Implementation of Scheduler side methods"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
engine_id: str,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
|
||||
assert vllm_config.kv_transfer_config
|
||||
self.is_kv_producer: bool = (
|
||||
@@ -457,15 +479,49 @@ class MooncakeConnectorScheduler:
|
||||
)
|
||||
logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id)
|
||||
|
||||
self._is_hma_required = (
|
||||
not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
|
||||
and any(
|
||||
not isinstance(g.kv_cache_spec, FullAttentionSpec)
|
||||
for g in kv_cache_config.kv_cache_groups
|
||||
)
|
||||
)
|
||||
|
||||
# Requests that need to start recv/send.
|
||||
# New requests are added by update_state_after_alloc in
|
||||
# the scheduler. Used to make metadata passed to Worker.
|
||||
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
|
||||
self._reqs_need_send: dict[ReqId, tuple[Request, list[int]]] = {}
|
||||
self._reqs_need_recv: dict[ReqId, tuple[Request, list[list[int]]]] = {}
|
||||
self._reqs_need_send: dict[ReqId, tuple[Request, list[list[int]]]] = {}
|
||||
# Reqs to remove from processed set because they're not to send after
|
||||
# remote prefill or aborted.
|
||||
self._reqs_not_processed: set[TransferId] = set()
|
||||
|
||||
# Compute sliding window block counts per KV cache group.
|
||||
sw_sizes_tokens: list[tuple[int, int]] = [
|
||||
(g.kv_cache_spec.sliding_window, g.kv_cache_spec.block_size)
|
||||
if isinstance(g.kv_cache_spec, SlidingWindowSpec)
|
||||
else (0, self.block_size)
|
||||
for g in kv_cache_config.kv_cache_groups
|
||||
]
|
||||
# cdiv(n_tokens, block_size) gives blocks/window; add 1 to
|
||||
# conservatively account for boundary overlap.
|
||||
self.blocks_per_sw = [
|
||||
cdiv(n_tokens, block_size) + 1 if n_tokens else 0
|
||||
for n_tokens, block_size in sw_sizes_tokens
|
||||
]
|
||||
|
||||
def get_sw_clipped_blocks(
|
||||
self,
|
||||
block_ids: tuple[list[int], ...] | list[list[int]],
|
||||
) -> list[list[int]]:
|
||||
"""Clip per-group block IDs to sliding window size."""
|
||||
if len(block_ids) == 0 or not self._is_hma_required:
|
||||
return list(block_ids)
|
||||
return [
|
||||
blocks[-self.blocks_per_sw[i] :] if self.blocks_per_sw[i] > 0 else blocks
|
||||
for i, blocks in enumerate(block_ids)
|
||||
]
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request", num_computed_tokens: int
|
||||
) -> tuple[int, bool]:
|
||||
@@ -530,9 +586,12 @@ class MooncakeConnectorScheduler:
|
||||
# If remote_blocks and num_external_tokens = 0, we have
|
||||
# a full prefix cache hit on the D worker. We need to call
|
||||
# send_notif in _read_blocks to free the memory on the P.
|
||||
local_block_ids = (
|
||||
blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []
|
||||
unhashed_block_ids = (
|
||||
blocks.get_unhashed_block_ids_all_groups()
|
||||
if num_external_tokens > 0
|
||||
else ()
|
||||
)
|
||||
local_block_ids = self.get_sw_clipped_blocks(unhashed_block_ids)
|
||||
# Get unhashed blocks to pull from remote.
|
||||
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
|
||||
else:
|
||||
@@ -587,7 +646,7 @@ class MooncakeConnectorScheduler:
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
block_ids: tuple[list[int], ...],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Once a request is finished, determine whether request blocks
|
||||
@@ -630,10 +689,13 @@ class MooncakeConnectorScheduler:
|
||||
|
||||
# TODO: check whether block_ids actually ever be 0. If not we could
|
||||
# remove the conditional below
|
||||
delay_free_blocks = len(block_ids) > 0
|
||||
delay_free_blocks = any(len(group) > 0 for group in block_ids)
|
||||
|
||||
if delay_free_blocks:
|
||||
self._reqs_need_send[request.request_id] = (request, block_ids)
|
||||
self._reqs_need_send[request.request_id] = (
|
||||
request,
|
||||
self.get_sw_clipped_blocks(block_ids),
|
||||
)
|
||||
|
||||
return delay_free_blocks, None
|
||||
|
||||
@@ -641,7 +703,12 @@ class MooncakeConnectorScheduler:
|
||||
class MooncakeConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
engine_id: str,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
if TransferEngine is None:
|
||||
logger.error("Mooncake is not available")
|
||||
raise RuntimeError("Mooncake is not available")
|
||||
@@ -752,6 +819,7 @@ class MooncakeConnectorWorker:
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.use_mla = self.model_config.use_mla
|
||||
self._sync_block_size_with_kernel()
|
||||
|
||||
@@ -1103,27 +1171,61 @@ class MooncakeConnectorWorker:
|
||||
remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}"
|
||||
|
||||
for d_req_id, send_meta in ready_reqs:
|
||||
_, remote_block_ids = agent_meta.req_blocks[d_req_id]
|
||||
num_remote_blocks = len(remote_block_ids)
|
||||
if num_remote_blocks == 0:
|
||||
_, remote_block_ids_per_group = agent_meta.req_blocks[d_req_id]
|
||||
|
||||
if not remote_block_ids_per_group or all(
|
||||
len(g) == 0 for g in remote_block_ids_per_group
|
||||
):
|
||||
continue
|
||||
|
||||
local_block_ids = send_meta.local_block_ids
|
||||
# Partial prefix cache hit: just read uncomputed blocks.
|
||||
num_local_blocks = len(local_block_ids)
|
||||
if num_local_blocks < num_remote_blocks:
|
||||
# Per-group partial hit trimming, then flatten.
|
||||
# With HMA, groups share the same KV tensor but use different
|
||||
# block ranges. We trim and concatenate so the coalescer and
|
||||
# address math see one flat block list — same as non-HMA, but
|
||||
# now including blocks from every group.
|
||||
local_block_ids: list[int] = []
|
||||
remote_block_ids: list[int] = []
|
||||
has_block_error = False
|
||||
if len(send_meta.local_block_ids) != len(remote_block_ids_per_group):
|
||||
logger.error(
|
||||
"req %s: local blocks(%d) less than remote blocks(%d)!",
|
||||
"req %s: KV group count mismatch: local=%d, remote=%d",
|
||||
d_req_id,
|
||||
num_local_blocks,
|
||||
num_remote_blocks,
|
||||
len(send_meta.local_block_ids),
|
||||
len(remote_block_ids_per_group),
|
||||
)
|
||||
err_reqs.append(d_req_id)
|
||||
if err_msg is None:
|
||||
err_msg = "KV group count mismatch"
|
||||
continue
|
||||
for local_group, remote_group in zip(
|
||||
send_meta.local_block_ids, remote_block_ids_per_group
|
||||
):
|
||||
n_local = len(local_group)
|
||||
n_remote = len(remote_group)
|
||||
if n_local < n_remote:
|
||||
logger.error(
|
||||
"req %s: local blocks(%d) < remote blocks(%d) "
|
||||
"in a KV cache group",
|
||||
d_req_id,
|
||||
n_local,
|
||||
n_remote,
|
||||
)
|
||||
has_block_error = True
|
||||
break
|
||||
if n_local > n_remote:
|
||||
# Partial prefix cache hit: just read uncomputed blocks.
|
||||
local_group = local_group[-n_remote:]
|
||||
local_block_ids.extend(local_group)
|
||||
remote_block_ids.extend(remote_group)
|
||||
|
||||
if has_block_error:
|
||||
err_reqs.append(d_req_id)
|
||||
if err_msg is None:
|
||||
err_msg = "P num blocks less than D"
|
||||
continue
|
||||
if num_local_blocks > num_remote_blocks:
|
||||
local_block_ids = local_block_ids[-num_remote_blocks:]
|
||||
|
||||
if not local_block_ids:
|
||||
continue
|
||||
|
||||
# Group by indices
|
||||
group_local_block_ids, group_remote_block_ids = group_concurrent_contiguous(
|
||||
@@ -1215,7 +1317,7 @@ class MooncakeConnectorWorker:
|
||||
logger.debug(
|
||||
"Sending kv_caches for request %s (%d blocks) to %s",
|
||||
d_req_id,
|
||||
num_remote_blocks,
|
||||
len(local_block_ids),
|
||||
remote_session,
|
||||
)
|
||||
|
||||
@@ -1273,23 +1375,24 @@ class MooncakeConnectorWorker:
|
||||
continue
|
||||
|
||||
seen_base_addresses.append(base_addr)
|
||||
curr_tensor_size_bytes = cache.nbytes
|
||||
|
||||
if tensor_size_bytes is None:
|
||||
tensor_size_bytes = curr_tensor_size_bytes
|
||||
tensor_size_bytes = cache.nbytes
|
||||
self.num_blocks = cache.shape[0]
|
||||
assert cache.shape[0] == self.num_blocks, (
|
||||
"All kv cache tensors must have the same number of blocks"
|
||||
)
|
||||
assert curr_tensor_size_bytes % self.num_blocks == 0, (
|
||||
"Mooncake expects each kv cache tensor size to be "
|
||||
"divisible by the number of blocks."
|
||||
)
|
||||
self.block_len_per_layer.append(
|
||||
curr_tensor_size_bytes // self.num_blocks
|
||||
)
|
||||
|
||||
# Use stride-based block length so RDMA reaches the last
|
||||
# block's padding (e.g. DeepseekV4 MLA alignment). stride(0)
|
||||
# reflects the actual byte distance between consecutive
|
||||
# blocks in GPU memory, which matches or exceeds the
|
||||
# shape-based size.
|
||||
block_len = cache.stride(0) * cache.element_size()
|
||||
|
||||
self.block_len_per_layer.append(block_len)
|
||||
kv_data_ptrs.append(base_addr)
|
||||
kv_data_lens.append(curr_tensor_size_bytes)
|
||||
kv_data_lens.append(self.num_blocks * block_len)
|
||||
|
||||
self.kv_caches_base_addr = seen_base_addresses
|
||||
self.seen_base_addresses = seen_base_addresses
|
||||
|
||||
@@ -299,6 +299,9 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
|
||||
tools: list[ChatCompletionFunctionToolParam] | None
|
||||
"""The tools for developer role."""
|
||||
|
||||
task: str | None
|
||||
"""Model-specific task marker. Currently passed through for DeepSeek V4."""
|
||||
|
||||
|
||||
ChatCompletionMessageParam: TypeAlias = (
|
||||
OpenAIChatCompletionMessageParam
|
||||
@@ -333,6 +336,9 @@ class ConversationMessage(TypedDict, total=False):
|
||||
tools: list[ChatCompletionFunctionToolParam] | None
|
||||
"""The tools for developer role."""
|
||||
|
||||
task: str | None
|
||||
"""Model-specific task marker. Currently passed through for DeepSeek V4."""
|
||||
|
||||
|
||||
# Passed in by user
|
||||
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
|
||||
@@ -1566,6 +1572,9 @@ def _parse_chat_message_content(
|
||||
if "name" in message and isinstance(message["name"], str):
|
||||
result_msg["name"] = message["name"]
|
||||
|
||||
if "task" in message and isinstance(message["task"], str):
|
||||
result_msg["task"] = message["task"]
|
||||
|
||||
if role == "developer":
|
||||
result_msg["tools"] = message.get("tools", None)
|
||||
return result
|
||||
|
||||
@@ -100,6 +100,8 @@ class DeepGemmFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
|
||||
else params.weight_scale,
|
||||
quant_block_shape=tuple(layer.weight_block_size),
|
||||
use_e8m0=self.use_deep_gemm_e8m0,
|
||||
is_bmm=getattr(layer, "is_bmm", False),
|
||||
bmm_batch_size=getattr(layer, "bmm_batch_size", 0),
|
||||
)
|
||||
replace_parameter(layer, params.WEIGHT, dg_weight)
|
||||
replace_parameter(layer, scale_attr, dg_weight_scale)
|
||||
|
||||
@@ -1422,6 +1422,20 @@ class MLADims:
|
||||
def get_mla_dims(model_config: ModelConfig) -> MLADims:
|
||||
hf_text_config = model_config.hf_text_config
|
||||
|
||||
# Check if this is a DeepseekV4 config (uses unified head_dim + rope_head_dim)
|
||||
if hasattr(hf_text_config, "compress_ratios"):
|
||||
# DeepseekV4 style config: unified head_dim with rope_head_dim
|
||||
head_dim = hf_text_config.head_dim
|
||||
rope_head_dim = hf_text_config.qk_rope_head_dim
|
||||
return MLADims(
|
||||
q_lora_rank=hf_text_config.q_lora_rank,
|
||||
kv_lora_rank=head_dim,
|
||||
qk_nope_head_dim=head_dim - rope_head_dim,
|
||||
qk_rope_head_dim=rope_head_dim,
|
||||
v_head_dim=head_dim,
|
||||
)
|
||||
|
||||
# DeepseekV2/V3 style config
|
||||
return MLADims(
|
||||
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
|
||||
kv_lora_rank=hf_text_config.kv_lora_rank,
|
||||
@@ -2191,6 +2205,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
kv_b_proj: ColumnParallelLinear,
|
||||
# DSV3.2 MLA Specific Arguments
|
||||
indexer: object | None = None,
|
||||
q_pad_num_heads: int | None = None,
|
||||
) -> None:
|
||||
@@ -2213,6 +2228,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
self.indexer = indexer
|
||||
self.q_pad_num_heads = q_pad_num_heads
|
||||
self.supports_quant_query_input = True
|
||||
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
|
||||
|
||||
# Use flashinfer's optimized concat_mla_k kernel when available.
|
||||
# The kernel is optimized for DeepSeek V3 dimensions:
|
||||
|
||||
@@ -0,0 +1,438 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, cast
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.utils import cublas_gemm_bf16_bf16_fp32
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.ops.deepseek_v4_ops.fused_compress_quant_cache import (
|
||||
_fused_kv_compress_norm_rope_insert_indexer_attn,
|
||||
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn,
|
||||
_fused_kv_compress_norm_rope_insert_sparse_attn,
|
||||
)
|
||||
from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import (
|
||||
MXFP4_BLOCK_SIZE,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
KVCacheSpec,
|
||||
MLAAttentionSpec,
|
||||
SlidingWindowMLASpec,
|
||||
)
|
||||
|
||||
|
||||
class CompressorBackend(AttentionBackend):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "CompressorBackend"
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [MultipleOf(1)]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [512, 1024]
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["CompressorMetadataBuilder"]:
|
||||
return CompressorMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
assert num_kv_heads == 1
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order(
|
||||
include_num_layers_dimension: bool = False,
|
||||
) -> tuple[int, ...]:
|
||||
if include_num_layers_dimension:
|
||||
return (0, 1, 2, 3)
|
||||
return (0, 1, 2)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompressorMetadata:
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
block_size: int
|
||||
|
||||
token_to_req_indices: torch.Tensor | None = None # [num_tokens]
|
||||
|
||||
|
||||
class CompressorMetadataBuilder(AttentionMetadataBuilder):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec)
|
||||
mla_spec = cast(SlidingWindowMLASpec | MLAAttentionSpec, self.kv_cache_spec)
|
||||
self.block_size = mla_spec.block_size
|
||||
|
||||
self.token_to_req_indices = torch.zeros(
|
||||
self.vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> CompressorMetadata:
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
x = torch.repeat_interleave(torch.arange(num_reqs), query_lens).pin_memory()
|
||||
token_to_req_indices = self.token_to_req_indices[: x.shape[0]]
|
||||
token_to_req_indices.copy_(x, non_blocking=True)
|
||||
return CompressorMetadata(
|
||||
block_table=common_attn_metadata.block_table_tensor.clamp_(min=0),
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
block_size=self.block_size,
|
||||
token_to_req_indices=token_to_req_indices,
|
||||
)
|
||||
|
||||
|
||||
class CompressorStateCache(torch.nn.Module, AttentionLayerBase):
|
||||
def __init__(
|
||||
self,
|
||||
state_dim: int,
|
||||
dtype: torch.dtype,
|
||||
compress_ratio: int,
|
||||
prefix: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.state_dim = state_dim
|
||||
self.dtype = dtype
|
||||
self.prefix = prefix
|
||||
self.kv_cache = torch.tensor([])
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
assert self.dtype == torch.float32
|
||||
assert compress_ratio in [4, 128]
|
||||
coff = 1 + (compress_ratio == 4)
|
||||
self.sliding_window = coff * compress_ratio
|
||||
# Block size is constrained by tensor sharing between compressor states
|
||||
# and KV blocks. Since compressor states share the same physical tensor
|
||||
# as KV blocks, they must use the same page size.
|
||||
# The KV block shape [256//4, head_dim] = [64, 584] determines:
|
||||
# - C4 compressor block shape [4, 2*512*2*4] -> block_size = 4
|
||||
# - C128 compressor block shape [8, 512*2*4] -> block_size = 8
|
||||
# TODO(yifan): make block size automatically determined and configurable.
|
||||
if compress_ratio == 4:
|
||||
self.block_size = 4
|
||||
elif compress_ratio == 128:
|
||||
self.block_size = 8
|
||||
else:
|
||||
raise ValueError(f"Invalid compress ratio: {compress_ratio}")
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
|
||||
return SlidingWindowMLASpec( # only has one vector instead of K + V
|
||||
block_size=self.block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=self.state_dim,
|
||||
dtype=self.dtype,
|
||||
sliding_window=self.sliding_window,
|
||||
alignment=576, # NOTE: FlashMLA requires 576B alignment
|
||||
)
|
||||
|
||||
def forward(self): ...
|
||||
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
return CompressorBackend
|
||||
|
||||
|
||||
class DeepseekCompressor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
compress_ratio: int,
|
||||
hidden_size: int,
|
||||
head_dim: int,
|
||||
rotate: bool = False,
|
||||
prefix: str = "",
|
||||
k_cache_prefix="",
|
||||
use_fp4_cache: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.compress_ratio = compress_ratio
|
||||
self.hidden_size = hidden_size
|
||||
self.head_dim = head_dim
|
||||
self.rotate = rotate
|
||||
self.prefix = prefix
|
||||
self.k_cache_prefix = k_cache_prefix
|
||||
self.use_fp4_cache = use_fp4_cache
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.rope_head_dim = config.qk_rope_head_dim
|
||||
self.nope_head_dim = self.head_dim - self.rope_head_dim
|
||||
self.rms_norm_eps = config.rms_norm_eps
|
||||
self.device = current_platform.device_type
|
||||
self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
|
||||
self.overlap = compress_ratio == 4
|
||||
self.coff = 1 + self.overlap
|
||||
|
||||
state_dtype = torch.float32
|
||||
self.ape = nn.Parameter(
|
||||
torch.empty(
|
||||
(compress_ratio, self.coff * self.head_dim),
|
||||
dtype=state_dtype,
|
||||
device=self.device,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
self.fused_wkv_wgate = MergedColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
[self.coff * self.head_dim, self.coff * self.head_dim],
|
||||
bias=False,
|
||||
return_bias=False,
|
||||
quant_config=None,
|
||||
disable_tp=True,
|
||||
prefix=f"{prefix}.fused_wkv_wgate",
|
||||
)
|
||||
self.norm = RMSNorm(self.head_dim, self.rms_norm_eps)
|
||||
|
||||
self.state_cache = CompressorStateCache(
|
||||
state_dim=2 * self.coff * self.head_dim, # kv_state + score_state
|
||||
dtype=state_dtype,
|
||||
compress_ratio=compress_ratio,
|
||||
prefix=f"{prefix}.state_cache",
|
||||
)
|
||||
|
||||
# Save reference to static_forward_context for forward-time KV cache lookup.
|
||||
# get_current_vllm_config() is only available during __init__, not forward.
|
||||
self._static_forward_context = (
|
||||
vllm_config.compilation_config.static_forward_context
|
||||
)
|
||||
|
||||
if self.head_dim == 512:
|
||||
assert not use_fp4_cache, (
|
||||
"MXFP4 cache is only supported for indexer (head=128)"
|
||||
)
|
||||
self._fused_kernel = _fused_kv_compress_norm_rope_insert_sparse_attn
|
||||
self._quant_block = 64
|
||||
self._token_stride = self.nope_head_dim + self.rope_head_dim * 2
|
||||
self._scale_dim = self.nope_head_dim // 64 + 1 # 7 real + 1 pad
|
||||
self._num_warps = 4
|
||||
elif self.head_dim == 128:
|
||||
if use_fp4_cache:
|
||||
self._fused_kernel = (
|
||||
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn
|
||||
)
|
||||
self._quant_block = MXFP4_BLOCK_SIZE
|
||||
self._token_stride = self.head_dim // 2
|
||||
self._scale_dim = self.head_dim // MXFP4_BLOCK_SIZE
|
||||
else:
|
||||
self._fused_kernel = _fused_kv_compress_norm_rope_insert_indexer_attn
|
||||
self._quant_block = 128
|
||||
self._token_stride = self.head_dim
|
||||
self._scale_dim = 4 # single float32 scale
|
||||
self._num_warps = 1
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported head_dim for fused quant+cache: {self.head_dim}"
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
# [num_tokens, hidden_size]
|
||||
x: torch.Tensor,
|
||||
# [num_tokens]
|
||||
positions: torch.Tensor,
|
||||
rotary_emb,
|
||||
) -> None:
|
||||
num_tokens, _ = x.shape
|
||||
# bf16 weights/activations but fp32 output for numerical stability of
|
||||
# the downstream compressor math.
|
||||
kv_score = cublas_gemm_bf16_bf16_fp32(x, self.fused_wkv_wgate.weight)
|
||||
# Each of shape [num_tokens, coff * self.head_dim]
|
||||
# input bf16, output are fp32
|
||||
kv, score = kv_score.split(
|
||||
[self.coff * self.head_dim, self.coff * self.head_dim], dim=-1
|
||||
)
|
||||
|
||||
# Get the metadata and handle dummy profiling run.
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if not isinstance(attn_metadata, dict):
|
||||
return
|
||||
|
||||
state_metadata = cast(
|
||||
CompressorMetadata, attn_metadata[self.state_cache.prefix]
|
||||
)
|
||||
token_to_req_indices = state_metadata.token_to_req_indices
|
||||
slot_mapping = state_metadata.slot_mapping
|
||||
num_actual = slot_mapping.shape[0]
|
||||
block_table = state_metadata.block_table
|
||||
block_size = state_metadata.block_size
|
||||
|
||||
# [num_blocks, block_size, kv_dim+score_dim], where kv_dim == score_dim
|
||||
state_cache = self.state_cache.kv_cache
|
||||
# kv_state stored in first half, score_state stored in second half
|
||||
state_width = state_cache.shape[-1] // 2
|
||||
|
||||
# Store the KV and score (with fused APE addition) in the state.
|
||||
# NOTE: PDL is disabled — both this kernel and _fused_kernel below
|
||||
# depend on preceding kernel outputs (kv/score from the cublas GEMM;
|
||||
# state_cache from this kernel) but neither emits/waits on PDL grid
|
||||
# dependency primitives, so launch_pdl=True caused a read-after-write
|
||||
# race and non-deterministic output.
|
||||
_save_partial_states_kernel[(num_actual,)](
|
||||
kv,
|
||||
kv.stride(0),
|
||||
score,
|
||||
score.stride(0),
|
||||
self.ape,
|
||||
self.ape.stride(0),
|
||||
positions,
|
||||
state_cache,
|
||||
state_cache.stride(0),
|
||||
state_cache.stride(1),
|
||||
slot_mapping,
|
||||
block_size,
|
||||
HEAD_SIZE=kv.shape[-1],
|
||||
TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]),
|
||||
STATE_WIDTH=state_width,
|
||||
COMPRESS_RATIO=self.compress_ratio,
|
||||
launch_pdl=False,
|
||||
)
|
||||
|
||||
# Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write.
|
||||
# RoPE requirements (kernel applies forward GPT-J style rotation):
|
||||
# - is_neox_style=False (interleaved pairs, NOT split-half)
|
||||
# - cos_sin_cache layout: [max_pos, rope_head_dim] with first half cos,
|
||||
# second half sin (per-pair, length rope_head_dim // 2 each)
|
||||
# - applied to LAST rope_head_dim elements of head_dim
|
||||
# - position used: (positions // compress_ratio) * compress_ratio
|
||||
cos_sin_cache = rotary_emb.cos_sin_cache
|
||||
k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix])
|
||||
kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache
|
||||
|
||||
self._fused_kernel[(num_actual,)](
|
||||
# state cache
|
||||
state_cache,
|
||||
state_cache.stride(0),
|
||||
state_cache.stride(1),
|
||||
# metadata
|
||||
token_to_req_indices,
|
||||
positions,
|
||||
slot_mapping,
|
||||
block_table,
|
||||
block_table.stride(0),
|
||||
block_size,
|
||||
# RMSNorm
|
||||
self.norm.weight,
|
||||
self.rms_norm_eps,
|
||||
# RoPE
|
||||
cos_sin_cache,
|
||||
cos_sin_cache.stride(0),
|
||||
# KV cache
|
||||
kv_cache,
|
||||
k_cache_metadata.slot_mapping,
|
||||
kv_cache.shape[1], # paged KV cache block size (tokens per block)
|
||||
# constexprs
|
||||
HEAD_SIZE=self.head_dim,
|
||||
TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim),
|
||||
STATE_WIDTH=state_width,
|
||||
COMPRESS_RATIO=self.compress_ratio,
|
||||
OVERLAP=self.overlap,
|
||||
ROPE_HEAD_DIM=self.rope_head_dim,
|
||||
FP8_MAX=448.0,
|
||||
QUANT_BLOCK=self._quant_block,
|
||||
TOKEN_STRIDE=self._token_stride,
|
||||
SCALE_DIM=self._scale_dim,
|
||||
KV_BLOCK_STRIDE=kv_cache.stride(0),
|
||||
num_warps=self._num_warps,
|
||||
launch_pdl=False,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _save_partial_states_kernel(
|
||||
kv_ptr,
|
||||
kv_stride,
|
||||
score_ptr,
|
||||
score_stride,
|
||||
ape_ptr,
|
||||
ape_stride,
|
||||
positions_ptr,
|
||||
state_cache_ptr,
|
||||
state_cache_stride0,
|
||||
state_cache_stride1,
|
||||
slot_mapping_ptr,
|
||||
block_size,
|
||||
HEAD_SIZE: tl.constexpr,
|
||||
TRITON_BLOCK_SIZE: tl.constexpr,
|
||||
# state_cache last dim packs [kv_state, score_state], each STATE_WIDTH wide.
|
||||
STATE_WIDTH: tl.constexpr,
|
||||
COMPRESS_RATIO: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(0)
|
||||
slot_id = tl.load(slot_mapping_ptr + token_idx)
|
||||
|
||||
# Skip padded / invalid tokens (slot_id == -1 is the PAD sentinel used
|
||||
# by vLLM). During CUDA graph replay the batch may contain padding
|
||||
# tokens whose slot_mapping is -1; writing to kv_state[-1] would be an
|
||||
# illegal memory access.
|
||||
if slot_id < 0:
|
||||
return
|
||||
|
||||
block_idx = slot_id // block_size
|
||||
pos_in_block = slot_id % block_size
|
||||
base_ptr = (
|
||||
state_cache_ptr
|
||||
+ block_idx * state_cache_stride0
|
||||
+ pos_in_block * state_cache_stride1
|
||||
)
|
||||
|
||||
block = tl.arange(0, TRITON_BLOCK_SIZE)
|
||||
mask = block < HEAD_SIZE
|
||||
|
||||
kv = tl.load(kv_ptr + token_idx * kv_stride + block, mask=mask)
|
||||
tl.store(base_ptr + block, kv, mask=mask)
|
||||
|
||||
# Fused: score += ape[position % compress_ratio]
|
||||
position = tl.load(positions_ptr + token_idx)
|
||||
ape_row = position % COMPRESS_RATIO
|
||||
ape = tl.load(ape_ptr + ape_row * ape_stride + block, mask=mask)
|
||||
score = tl.load(score_ptr + token_idx * score_stride + block, mask=mask)
|
||||
tl.store(
|
||||
base_ptr + STATE_WIDTH + block,
|
||||
score + ape,
|
||||
mask=mask,
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -117,8 +117,10 @@ class RoutingMethodType(IntEnum):
|
||||
Custom = (6,)
|
||||
# Simulated
|
||||
Simulated = (7,)
|
||||
# Deepseek V4 -> sqrtsoftplus + Bias + Normalize
|
||||
DeepseekV4 = (8,)
|
||||
# Unspecified
|
||||
Unspecified = 8.0
|
||||
Unspecified = 9.0
|
||||
|
||||
|
||||
def get_routing_method_type(
|
||||
@@ -128,6 +130,14 @@ def get_routing_method_type(
|
||||
num_expert_group: int | None,
|
||||
has_e_score_bias: bool,
|
||||
) -> RoutingMethodType:
|
||||
if scoring_func == "sqrtsoftplus":
|
||||
# DeepSeek V4 uses sqrtsoftplus routing with optional routing bias
|
||||
# and top-k renormalization.
|
||||
if renormalize:
|
||||
return RoutingMethodType.DeepseekV4
|
||||
else:
|
||||
return RoutingMethodType.Unspecified
|
||||
|
||||
if has_e_score_bias:
|
||||
if (num_expert_group or 0) > 0 and scoring_func == "sigmoid":
|
||||
return RoutingMethodType.DeepSeekV3
|
||||
@@ -230,6 +240,13 @@ class FusedMoEQuantConfig:
|
||||
_w2: FusedMoEQuantDesc
|
||||
is_nvfp4_scale_swizzled: bool = True
|
||||
|
||||
# MXFP4-specific TRTLLM parameters for SwiGLU activation clamping.
|
||||
# These correspond to gemm1_alpha, gemm1_beta, gemm1_clamp_limit
|
||||
# in TrtLlmMxfp4ExpertsBase.
|
||||
gemm1_alpha: float | None = None
|
||||
gemm1_beta: float | None = None
|
||||
gemm1_clamp_limit: float | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
assert not self.per_act_token_quant or self.block_shape is None, (
|
||||
"illegal quantization"
|
||||
@@ -477,6 +494,9 @@ class FusedMoEQuantConfig:
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
weight_dtype: torch.dtype | str | None = None,
|
||||
is_nvfp4_scale_swizzled: bool = True,
|
||||
gemm1_alpha: float | None = None,
|
||||
gemm1_beta: float | None = None,
|
||||
gemm1_clamp_limit: float | None = None,
|
||||
) -> "FusedMoEQuantConfig":
|
||||
"""
|
||||
General builder function for a FusedMoEQuantConfig.
|
||||
@@ -507,6 +527,9 @@ class FusedMoEQuantConfig:
|
||||
- w1_zp: Optional w1 zero points for int4/int8 quantization.
|
||||
- w2_zp: Optional w2 zero points for int4/int8 quantization.
|
||||
- is_nvfp4_scale_swizzled: Whether to swizzle the nvfp4 scale swizzling.
|
||||
- gemm1_alpha: Optional MXFP4 TRTLLM SwiGLU alpha parameter.
|
||||
- gemm1_beta: Optional MXFP4 TRTLLM SwiGLU beta parameter.
|
||||
- gemm1_clamp_limit: Optional MXFP4 TRTLLM SwiGLU clamp limit.
|
||||
"""
|
||||
assert not isinstance(quant_dtype, str) or quant_dtype in {
|
||||
"nvfp4",
|
||||
@@ -540,6 +563,9 @@ class FusedMoEQuantConfig:
|
||||
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
|
||||
),
|
||||
is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_beta=gemm1_beta,
|
||||
gemm1_clamp_limit=gemm1_clamp_limit,
|
||||
)
|
||||
assert quant_config.per_act_token_quant == per_act_token_quant
|
||||
assert quant_config.per_out_ch_quant == per_out_ch_quant
|
||||
@@ -650,6 +676,9 @@ def mxfp4_w4a16_moe_quant_config(
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
gemm1_alpha: float | None = None,
|
||||
gemm1_beta: float | None = None,
|
||||
gemm1_clamp_limit: float | None = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for unquantized activations and mxfp4 weights.
|
||||
@@ -659,6 +688,9 @@ def mxfp4_w4a16_moe_quant_config(
|
||||
_a2=FusedMoEQuantDesc(),
|
||||
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
|
||||
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_beta=gemm1_beta,
|
||||
gemm1_clamp_limit=gemm1_clamp_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -670,6 +702,9 @@ def mxfp4_mxfp8_moe_quant_config(
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
gemm1_alpha: float | None = None,
|
||||
gemm1_beta: float | None = None,
|
||||
gemm1_clamp_limit: float | None = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for mxfp4 activations and mxfp4 weights.
|
||||
@@ -679,6 +714,9 @@ def mxfp4_mxfp8_moe_quant_config(
|
||||
_a2=FusedMoEQuantDesc("mxfp8"),
|
||||
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
|
||||
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_beta=gemm1_beta,
|
||||
gemm1_clamp_limit=gemm1_clamp_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -712,6 +750,9 @@ def ocp_mx_moe_quant_config(
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
gemm1_alpha: float | None = None,
|
||||
gemm1_beta: float | None = None,
|
||||
gemm1_clamp_limit: float | None = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for mxfp4 activations and mxfp4 weights.
|
||||
@@ -729,6 +770,9 @@ def ocp_mx_moe_quant_config(
|
||||
per_act_token_quant=False,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=block_shape,
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_beta=gemm1_beta,
|
||||
gemm1_clamp_limit=gemm1_clamp_limit,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -25,15 +25,20 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8_packed_for_deepgemm,
|
||||
silu_mul_per_token_group_quant_fp8_colmajor,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
silu_mul_quant_fp8_packed_triton as fused_silu_mul_fp8_quant_packed,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8Static128BlockSym,
|
||||
kMxfp4Static,
|
||||
)
|
||||
from vllm.utils.deep_gemm import (
|
||||
DeepGemmQuantScaleFMT,
|
||||
get_mk_alignment_for_contiguous_layout,
|
||||
is_deep_gemm_supported,
|
||||
m_grouped_fp8_fp4_gemm_nt_contiguous,
|
||||
m_grouped_fp8_gemm_nt_contiguous,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
@@ -197,8 +202,14 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular):
|
||||
M_sum, N = input.size()
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
|
||||
# 1. DeepGemm UE8M0: use packed per-token-group quant
|
||||
# 1. DeepGemm UE8M0: fused SiLU+mul+clamp+quant+pack
|
||||
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
|
||||
if activation == MoEActivation.SILU:
|
||||
return fused_silu_mul_fp8_quant_packed(
|
||||
input=input,
|
||||
output_q=output,
|
||||
group_size=block_k,
|
||||
)
|
||||
act_out = torch.empty(
|
||||
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device
|
||||
)
|
||||
@@ -312,3 +323,225 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular):
|
||||
expert_map=expert_map,
|
||||
output=output,
|
||||
)
|
||||
|
||||
|
||||
class DeepGemmFP4Experts(mk.FusedMoEExpertsModular):
|
||||
"""DeepGemm-based fused MoE expert implementation for FP4 weights.
|
||||
|
||||
Uses m_grouped_fp8_fp4_gemm_nt_contiguous with FP8 activations and
|
||||
MXFP4 (FP4 E2M1 packed as uint8) weights. Requires SM100+ (Blackwell).
|
||||
"""
|
||||
|
||||
# FP8 activation block size (hardcoded since mxfp4_w4a8 quant config
|
||||
# does not set a block_shape on the activation descriptor).
|
||||
_ACT_BLOCK_K = 128
|
||||
# FP4 weight block size
|
||||
_WEIGHT_BLOCK_K = 32
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
|
||||
super().__init__(moe_config=moe_config, quant_config=quant_config)
|
||||
assert quant_config.weight_quant_dtype == "mxfp4"
|
||||
assert not quant_config.per_act_token_quant
|
||||
assert not quant_config.per_out_ch_quant
|
||||
|
||||
self.gemm1_clamp_limit = quant_config.gemm1_clamp_limit
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
return (
|
||||
is_deep_gemm_supported()
|
||||
and current_platform.is_device_capability_family(100)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
SUPPORTED_W_A = [
|
||||
(kMxfp4Static, kFp8Dynamic128Sym),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [MoEActivation.SILU, MoEActivation.SWIGLUSTEP]
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
return not (
|
||||
moe_parallel_config.use_fi_nvl_two_sided_kernels
|
||||
or moe_parallel_config.use_fi_nvl_one_sided_kernels
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
block_m = get_mk_alignment_for_contiguous_layout()[0]
|
||||
M_sum = compute_aligned_M(
|
||||
M, topk, local_num_experts, block_m, expert_tokens_meta
|
||||
)
|
||||
assert M_sum % block_m == 0
|
||||
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
workspace1 = (M_sum, max(activation_out_dim, K))
|
||||
workspace2 = (M_sum, max(N, K))
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
def _act_mul_quant(
|
||||
self, input: torch.Tensor, output: torch.Tensor, activation: MoEActivation
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
block_k = self._ACT_BLOCK_K
|
||||
scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
|
||||
|
||||
M_sum, N = input.size()
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
|
||||
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
|
||||
assert activation == MoEActivation.SILU
|
||||
return fused_silu_mul_fp8_quant_packed(
|
||||
input=input,
|
||||
output_q=output,
|
||||
group_size=block_k,
|
||||
clamp_limit=self.gemm1_clamp_limit,
|
||||
)
|
||||
|
||||
if activation == MoEActivation.SILU:
|
||||
use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
|
||||
return silu_mul_per_token_group_quant_fp8_colmajor(
|
||||
input=input,
|
||||
output=output,
|
||||
use_ue8m0=use_ue8m0,
|
||||
)
|
||||
|
||||
act_out = torch.empty(
|
||||
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device
|
||||
)
|
||||
self.activation(activation, act_out, input)
|
||||
return per_token_group_quant_fp8(
|
||||
act_out, block_k, column_major_scales=True, out_q=output
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert a1q_scale is not None
|
||||
assert a2_scale is None
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
|
||||
a1q = hidden_states
|
||||
_, N, _ = w1.size()
|
||||
# K comes from activations (full hidden dim), not from w1 which is
|
||||
# packed FP4 (E, N, K//2).
|
||||
K = a1q.size(1)
|
||||
|
||||
local_num_experts = w1.size(0)
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = local_num_experts
|
||||
|
||||
M_sum = compute_aligned_M(
|
||||
M=topk_ids.size(0),
|
||||
num_topk=topk_ids.size(1),
|
||||
local_num_experts=local_num_experts,
|
||||
alignment=get_mk_alignment_for_contiguous_layout()[0],
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
)
|
||||
|
||||
a1q_perm = _resize_cache(
|
||||
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K)
|
||||
)
|
||||
a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
|
||||
aq=a1q,
|
||||
aq_scale=a1q_scale,
|
||||
topk_ids=topk_ids,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
aq_out=a1q_perm,
|
||||
)
|
||||
assert a1q.size(0) == M_sum
|
||||
|
||||
# FC1: FP8 activations x FP4 weights
|
||||
# DeepGEMM 2.4.2 requires FP4-packed weights as int8 (kPackedFP4).
|
||||
mm1_out = _resize_cache(workspace2, (M_sum, N))
|
||||
m_grouped_fp8_fp4_gemm_nt_contiguous(
|
||||
(a1q, a1q_scale),
|
||||
(w1.view(torch.int8), self.w1_scale),
|
||||
mm1_out,
|
||||
expert_ids,
|
||||
recipe_a=(1, self._ACT_BLOCK_K),
|
||||
recipe_b=(1, self._WEIGHT_BLOCK_K),
|
||||
)
|
||||
|
||||
# SwiGLU activation + FP8 requant
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
quant_out = _resize_cache(
|
||||
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim)
|
||||
)
|
||||
a2q, a2q_scale = self._act_mul_quant(
|
||||
input=mm1_out.view(-1, N), output=quant_out, activation=activation
|
||||
)
|
||||
|
||||
# FC2: FP8 activations x FP4 weights
|
||||
mm2_out = _resize_cache(workspace2, (M_sum, K))
|
||||
m_grouped_fp8_fp4_gemm_nt_contiguous(
|
||||
(a2q, a2q_scale),
|
||||
(w2.view(torch.int8), self.w2_scale),
|
||||
mm2_out,
|
||||
expert_ids,
|
||||
recipe_a=(1, self._ACT_BLOCK_K),
|
||||
recipe_b=(1, self._WEIGHT_BLOCK_K),
|
||||
)
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk_weights = torch.ones_like(topk_weights)
|
||||
|
||||
deepgemm_unpermute_and_reduce(
|
||||
a=mm2_out,
|
||||
topk_ids=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
inv_perm=inv_perm,
|
||||
expert_map=expert_map,
|
||||
output=output,
|
||||
)
|
||||
|
||||
@@ -28,8 +28,162 @@ from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
|
||||
from ..utils import swiglu_limit_func
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _patch_make_bitmatrix_metadata() -> None:
|
||||
"""Monkey-patch make_bitmatrix_metadata to support non-power-of-2 top_k.
|
||||
|
||||
triton's tl.arange requires a power-of-2 range. The original kernel
|
||||
computes BLOCK_SIZE = BLOCK_PER_TOK * TOKS_PER_ROW (= 32 * top_k). For
|
||||
DeepSeek-V4 with top_k=6 this gives 192, which is not a power of 2 and
|
||||
causes a compile error at the first forward pass.
|
||||
|
||||
Fix: define a drop-in replacement kernel that accepts an extra constexpr
|
||||
BLOCK_SIZE_PADDED (next power of 2 >= BLOCK_SIZE) and uses it for the
|
||||
tl.arange call while keeping the actual BLOCK_SIZE as the stride between
|
||||
thread-blocks so that all flat indices into NonzeroIndx stay correct.
|
||||
Elements beyond BLOCK_SIZE are masked out (col_indx = 0xffff) and ignored.
|
||||
|
||||
This function is called once at module load time and patches the function
|
||||
inside the triton_kernels tensor module so that SparseMatrix.__post_init__
|
||||
picks up the fixed version transparently.
|
||||
"""
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
try:
|
||||
from vllm.third_party.triton_kernels.tensor_details import (
|
||||
bitmatrix as _bm,
|
||||
)
|
||||
from vllm.third_party.triton_kernels.tensor_details.bitmatrix import (
|
||||
BitmatrixMetadata,
|
||||
_keyed_add,
|
||||
cdiv,
|
||||
)
|
||||
from vllm.third_party.triton_kernels.tensor_details.bitmatrix_details.sum_bitmatrix_rows import ( # noqa: E501
|
||||
sum_bitmatrix_rows,
|
||||
)
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
@triton.jit
|
||||
def _stage2_pow2(
|
||||
ColSortedIndx,
|
||||
RowSortedIndx,
|
||||
NonzeroIndx,
|
||||
n_tokens,
|
||||
ColPartialSum,
|
||||
stride_pm,
|
||||
stride_pn,
|
||||
ColOffs,
|
||||
TOKS_PER_ROW: tl.constexpr,
|
||||
BLOCK_PER_TOK: tl.constexpr,
|
||||
BLOCK_SIZE_PADDED: tl.constexpr,
|
||||
):
|
||||
# Actual number of elements per block (may not be a power of 2).
|
||||
BLOCK_SIZE: tl.constexpr = BLOCK_PER_TOK * TOKS_PER_ROW
|
||||
tl.static_assert(BLOCK_SIZE_PADDED <= 32768)
|
||||
if isinstance(n_tokens, tl.tensor) and n_tokens.dtype.is_ptr():
|
||||
n_tokens = tl.load(n_tokens)
|
||||
nonzero_indx_size = n_tokens * TOKS_PER_ROW
|
||||
pid_m = tl.program_id(0)
|
||||
# Use BLOCK_SIZE_PADDED (a power of 2) for tl.arange, but stride by
|
||||
# the actual BLOCK_SIZE so flat positions in NonzeroIndx are correct.
|
||||
# Elements with offs_local >= BLOCK_SIZE have offs_global beyond the
|
||||
# valid range, get col_indx = 0xffff, and are filtered by the mask
|
||||
# below without producing any output.
|
||||
offs_local = tl.arange(0, BLOCK_SIZE_PADDED)
|
||||
offs_global = pid_m * BLOCK_SIZE + offs_local
|
||||
mask = offs_global < nonzero_indx_size
|
||||
col_indx = tl.load(NonzeroIndx + offs_global, mask=mask, other=-1).to(tl.uint32)
|
||||
kv_pairs = ((col_indx << 16) | offs_local).to(tl.uint32)
|
||||
kv_pairs = tl.sort(kv_pairs, 0)
|
||||
col_indx = kv_pairs >> 16
|
||||
offs_global = pid_m * BLOCK_SIZE + (kv_pairs & 0xFFFF)
|
||||
mask = col_indx != 0xFFFF
|
||||
x = kv_pairs & 0xFFFF0000 | 0x00000001
|
||||
cols_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
|
||||
exclusive_run_lengths = (cols_and_inclusive_run_lengths - 1) & 0xFFFF
|
||||
row_sorted_indx = tl.load(
|
||||
ColPartialSum + pid_m * stride_pm + col_indx * stride_pn, mask=mask
|
||||
)
|
||||
row_sorted_indx += tl.load(ColOffs + col_indx, mask=mask)
|
||||
row_sorted_indx += exclusive_run_lengths
|
||||
tl.store(RowSortedIndx + offs_global, row_sorted_indx, mask=mask)
|
||||
tl.store(ColSortedIndx + row_sorted_indx, offs_global, mask=mask)
|
||||
|
||||
def _make_bitmatrix_metadata_pow2_safe(nonzero_indx, bitmatrix):
|
||||
assert nonzero_indx.ndim == 2
|
||||
PARTIAL_BLOCK_M = 32
|
||||
col_sum, col_partial_sum = sum_bitmatrix_rows(
|
||||
bitmatrix, partials_block_size=PARTIAL_BLOCK_M
|
||||
)
|
||||
device = bitmatrix.device
|
||||
n_indx = nonzero_indx.numel()
|
||||
n_cols = bitmatrix.shape[1]
|
||||
col_offs = torch.empty(n_cols, dtype=torch.int32, device=device)
|
||||
combined_indx = torch.empty(n_indx * 2, dtype=torch.int32, device=device)
|
||||
col_sorted_indx = combined_indx[:n_indx]
|
||||
row_sorted_indx = combined_indx[n_indx:]
|
||||
MEMSET_BLOCK = 1024
|
||||
memset_grid = (cdiv(n_indx * 2, MEMSET_BLOCK) + n_cols + 1,)
|
||||
_bm._bitmatrix_metadata_compute_stage1[memset_grid](
|
||||
combined_indx,
|
||||
n_indx * 2,
|
||||
-1,
|
||||
MEMSET_BLOCK,
|
||||
col_sum,
|
||||
col_offs,
|
||||
col_sum.shape[0],
|
||||
col_partial_sum,
|
||||
col_partial_sum.shape[0],
|
||||
col_partial_sum.stride(0),
|
||||
col_partial_sum.stride(1),
|
||||
BLOCK_M=512,
|
||||
BLOCK_N=512,
|
||||
)
|
||||
toks_per_row = nonzero_indx.shape[-1]
|
||||
block_size = PARTIAL_BLOCK_M * toks_per_row
|
||||
# Next power of 2 >= block_size (required by tl.arange).
|
||||
block_size_padded = 1 << (max(block_size, 1) - 1).bit_length()
|
||||
compute_grid = (cdiv(bitmatrix.shape_max[0], PARTIAL_BLOCK_M),)
|
||||
_stage2_pow2[compute_grid](
|
||||
col_sorted_indx,
|
||||
row_sorted_indx,
|
||||
nonzero_indx,
|
||||
bitmatrix.shape[0],
|
||||
col_partial_sum,
|
||||
col_partial_sum.stride(0),
|
||||
col_partial_sum.stride(1),
|
||||
col_offs,
|
||||
TOKS_PER_ROW=toks_per_row,
|
||||
BLOCK_PER_TOK=PARTIAL_BLOCK_M,
|
||||
BLOCK_SIZE_PADDED=block_size_padded,
|
||||
)
|
||||
return BitmatrixMetadata(
|
||||
col_sum=col_sum,
|
||||
col_sorted_indx=col_sorted_indx,
|
||||
row_sorted_indx=row_sorted_indx,
|
||||
)
|
||||
|
||||
# The most reliable patch point: SparseMatrix.__post_init__ looks up
|
||||
# make_bitmatrix_metadata via its own __globals__ dict (the tensor.py
|
||||
# module dict). Patching through __globals__ works regardless of how
|
||||
# sys.modules maps "triton_kernels.tensor" vs
|
||||
# "vllm.third_party.triton_kernels.tensor".
|
||||
from triton_kernels.tensor import SparseMatrix as _SparseMatrix
|
||||
|
||||
_SparseMatrix.__post_init__.__globals__["make_bitmatrix_metadata"] = (
|
||||
_make_bitmatrix_metadata_pow2_safe
|
||||
)
|
||||
# Also patch the bitmatrix module itself in case it is imported directly.
|
||||
_bm.make_bitmatrix_metadata = _make_bitmatrix_metadata_pow2_safe
|
||||
|
||||
|
||||
use_legacy_triton_kernels = False
|
||||
|
||||
if has_triton_kernels():
|
||||
@@ -59,6 +213,8 @@ if has_triton_kernels():
|
||||
use_legacy_triton_kernels = True
|
||||
else:
|
||||
raise
|
||||
if not use_legacy_triton_kernels:
|
||||
_patch_make_bitmatrix_metadata()
|
||||
except (AttributeError, ImportError) as e:
|
||||
logger.error(
|
||||
"Failed to import Triton kernels. Please make sure your triton "
|
||||
@@ -497,6 +653,8 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
|
||||
return False
|
||||
# (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
|
||||
# and ROCm gfx942/gfx950 (which map to 9.4/9.5).
|
||||
if not has_triton_kernels():
|
||||
return False
|
||||
return (9, 0) <= (cap.major, cap.minor) < (11, 0)
|
||||
|
||||
@staticmethod
|
||||
@@ -698,6 +856,37 @@ class UnfusedOAITritonExperts(LoRAExpertsMixin, BaseOAITritonExperts):
|
||||
def moe_sum(self, input: torch.Tensor, output: torch.Tensor):
|
||||
ops.moe_sum(input, output)
|
||||
|
||||
def activation(
|
||||
self,
|
||||
activation: MoEActivation,
|
||||
output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
) -> None:
|
||||
quant_config = self.quant_config or FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
if activation == MoEActivation.SWIGLUOAI:
|
||||
alpha = (
|
||||
quant_config.gemm1_alpha
|
||||
if quant_config.gemm1_alpha is not None
|
||||
else 1.702
|
||||
)
|
||||
limit = (
|
||||
quant_config.gemm1_clamp_limit
|
||||
if quant_config.gemm1_clamp_limit is not None
|
||||
else 7.0
|
||||
)
|
||||
torch.ops._C.swigluoai_and_mul(output, input, alpha, limit)
|
||||
elif (
|
||||
activation == MoEActivation.SILU
|
||||
and quant_config.gemm1_clamp_limit is not None
|
||||
):
|
||||
swiglu_limit_func(
|
||||
output,
|
||||
input,
|
||||
quant_config.gemm1_clamp_limit,
|
||||
)
|
||||
else:
|
||||
super().activation(activation, output, input)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
@@ -812,9 +1001,9 @@ class UnfusedOAITritonExperts(LoRAExpertsMixin, BaseOAITritonExperts):
|
||||
act_input,
|
||||
)
|
||||
|
||||
# matmul_ogs grouped reduction fuse sum across multiple experts:
|
||||
# matmul_ogs grouped reduction fuses sum across multiple experts:
|
||||
# y[dst_indx // n_expts_act, :] += x
|
||||
# Need to set n_expts_act to 1 to unfuse moe_sum
|
||||
# Set n_expts_act to 1 to unfuse the sum so we can do it manually via moe_sum.
|
||||
routing_data.n_expts_act = 1
|
||||
|
||||
matmul_ogs(
|
||||
@@ -878,6 +1067,8 @@ class OAITritonMxfp4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
|
||||
return False
|
||||
# (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
|
||||
# and ROCm gfx942/gfx950 (which map to 9.4/9.5).
|
||||
if not has_triton_kernels():
|
||||
return False
|
||||
return (9, 0) <= (cap.major, cap.minor) < (11, 0)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import trtllm_moe_pack_topk_ids_weights
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kMxfp4Static,
|
||||
@@ -32,10 +33,8 @@ class TrtLlmMxfp4ExpertsBase:
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
**kwargs,
|
||||
):
|
||||
# NOTE: FusedMoEExperts.__init__ is called by the concrete subclass
|
||||
# (Monolithic/Modular) via MRO, not here, to avoid mypy issues with
|
||||
# multiple inheritance. This matches the NvFP4 expert pattern.
|
||||
self.moe_config = moe_config
|
||||
self.quant_config = quant_config
|
||||
|
||||
@@ -48,23 +47,34 @@ class TrtLlmMxfp4ExpertsBase:
|
||||
self.local_num_experts = moe_config.num_local_experts
|
||||
self.ep_rank = moe_config.moe_parallel_config.ep_rank
|
||||
|
||||
# MXFP4-specific TRTLLM parameters
|
||||
# MXFP4-specific TRTLLM parameters from quant_config
|
||||
device = torch.accelerator.current_device_index()
|
||||
self.gemm1_alpha = torch.tensor(
|
||||
[1.702] * self.local_num_experts,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
self.gemm1_beta = torch.tensor(
|
||||
[1.0] * self.local_num_experts,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
self.gemm1_clamp_limit = torch.tensor(
|
||||
[7.0] * self.local_num_experts,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
if quant_config.gemm1_alpha is not None:
|
||||
self.gemm1_alpha = torch.tensor(
|
||||
[quant_config.gemm1_alpha] * self.local_num_experts,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.gemm1_alpha = None
|
||||
|
||||
if quant_config.gemm1_beta is not None:
|
||||
self.gemm1_beta = torch.tensor(
|
||||
[quant_config.gemm1_beta] * self.local_num_experts,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.gemm1_beta = None
|
||||
|
||||
if quant_config.gemm1_clamp_limit is not None:
|
||||
self.gemm1_clamp_limit = torch.tensor(
|
||||
[quant_config.gemm1_clamp_limit] * self.local_num_experts,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.gemm1_clamp_limit = None
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
@@ -97,7 +107,7 @@ class TrtLlmMxfp4ExpertsBase:
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation == MoEActivation.SWIGLUOAI
|
||||
return activation in (MoEActivation.SWIGLUOAI, MoEActivation.SILU)
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
@@ -190,36 +200,41 @@ class TrtLlmMxfp4ExpertsMonolithic(
|
||||
|
||||
output = torch.empty_like(hidden_states)
|
||||
|
||||
return trtllm_fp4_block_scale_moe(
|
||||
routing_logits=router_logits.to(torch.bfloat16),
|
||||
routing_bias=None,
|
||||
hidden_states=x_quant,
|
||||
hidden_states_scale=x_scale,
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.w1_scale,
|
||||
gemm1_bias=self.w1_bias,
|
||||
gemm1_alpha=self.gemm1_alpha,
|
||||
gemm1_beta=self.gemm1_beta,
|
||||
gemm1_clamp_limit=self.gemm1_clamp_limit,
|
||||
gemm2_weights=w2,
|
||||
gemm2_weights_scale=self.w2_scale,
|
||||
gemm2_bias=self.w2_bias,
|
||||
output1_scale_scalar=None,
|
||||
output1_scale_gate_scalar=None,
|
||||
output2_scale_scalar=None,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=None,
|
||||
topk_group=None,
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=None,
|
||||
routing_method_type=self.routing_method_type,
|
||||
do_finalize=True,
|
||||
tune_max_num_tokens=max(self.max_capture_size, 1),
|
||||
output=output,
|
||||
)[0]
|
||||
from vllm.utils.flashinfer import _is_fi_autotuning, autotune
|
||||
|
||||
with autotune(_is_fi_autotuning):
|
||||
trtllm_fp4_block_scale_moe(
|
||||
routing_logits=router_logits.to(torch.bfloat16),
|
||||
routing_bias=None,
|
||||
hidden_states=x_quant,
|
||||
hidden_states_scale=x_scale,
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.w1_scale,
|
||||
gemm1_bias=self.w1_bias,
|
||||
gemm1_alpha=self.gemm1_alpha,
|
||||
gemm1_beta=self.gemm1_beta,
|
||||
gemm1_clamp_limit=self.gemm1_clamp_limit,
|
||||
gemm2_weights=w2,
|
||||
gemm2_weights_scale=self.w2_scale,
|
||||
gemm2_bias=self.w2_bias,
|
||||
output1_scale_scalar=None,
|
||||
output1_scale_gate_scalar=None,
|
||||
output2_scale_scalar=None,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=None,
|
||||
topk_group=None,
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=None,
|
||||
routing_method_type=self.routing_method_type,
|
||||
do_finalize=True,
|
||||
tune_max_num_tokens=max(self.max_capture_size, 1),
|
||||
output=output,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
@@ -239,6 +254,16 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
# Modular kernel handles only the expert computation;
|
||||
# routing is done externally, so accept any routing method.
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -282,7 +307,7 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
|
||||
):
|
||||
topk = topk_ids.size(-1)
|
||||
local_num_experts = w1.size(0)
|
||||
intermediate_size = w2.size(1)
|
||||
intermediate_size = self.intermediate_size_per_partition
|
||||
local_expert_offset = self.moe_config.ep_rank * local_num_experts
|
||||
|
||||
# Handle input quantization
|
||||
@@ -302,9 +327,8 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
|
||||
x_quant = hidden_states
|
||||
x_scale = None
|
||||
|
||||
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
|
||||
torch.bfloat16
|
||||
).view(torch.int16)
|
||||
# Pack topk ids and weights into format expected by the kernel.
|
||||
packed_tensor = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
|
||||
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
@@ -333,7 +357,10 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
|
||||
"local_expert_offset": local_expert_offset,
|
||||
"local_num_experts": local_num_experts,
|
||||
"routed_scaling_factor": None,
|
||||
"routing_method_type": self.routing_method_type,
|
||||
# Modular kernel receives pre-routed tokens, so routing
|
||||
# is already done. Use Renormalize as a safe default that
|
||||
# the TRTLLM C++ kernel supports.
|
||||
"routing_method_type": RoutingMethodType.Renormalize,
|
||||
"do_finalize": True,
|
||||
"output": output,
|
||||
"tune_max_num_tokens": max(self.max_capture_size, 1),
|
||||
@@ -341,12 +368,9 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
|
||||
|
||||
from flashinfer import trtllm_fp4_block_scale_routed_moe
|
||||
|
||||
from vllm.utils.flashinfer import autotune
|
||||
from vllm.utils.flashinfer import _is_fi_autotuning, autotune
|
||||
|
||||
with autotune(False):
|
||||
# Enable autotune when,
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2023 is
|
||||
# resolved.
|
||||
with autotune(_is_fi_autotuning):
|
||||
trtllm_fp4_block_scale_routed_moe(**kwargs)
|
||||
|
||||
return output
|
||||
|
||||
@@ -50,6 +50,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
from .utils import swiglu_limit_func
|
||||
|
||||
|
||||
def _fused_marlin_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -88,6 +90,7 @@ def _fused_marlin_moe(
|
||||
output: torch.Tensor | None = None,
|
||||
input_dtype: torch.dtype | None = None,
|
||||
is_k_full: bool = True,
|
||||
clamp_limit: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert hidden_states.ndim == 2
|
||||
M, K = hidden_states.size()
|
||||
@@ -155,11 +158,18 @@ def _fused_marlin_moe(
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
)
|
||||
activation_func(
|
||||
activation,
|
||||
intermediate_cache2,
|
||||
intermediate_cache1.view(-1, w13_num_shards * N),
|
||||
)
|
||||
if clamp_limit is not None and activation == MoEActivation.SILU:
|
||||
swiglu_limit_func(
|
||||
intermediate_cache2,
|
||||
intermediate_cache1.view(-1, w13_num_shards * N),
|
||||
clamp_limit,
|
||||
)
|
||||
else:
|
||||
activation_func(
|
||||
activation,
|
||||
intermediate_cache2,
|
||||
intermediate_cache1.view(-1, w13_num_shards * N),
|
||||
)
|
||||
|
||||
if output is None:
|
||||
output = intermediate_cache3
|
||||
@@ -247,6 +257,7 @@ def fused_marlin_moe(
|
||||
output: torch.Tensor | None = None,
|
||||
input_dtype: torch.dtype | None = None,
|
||||
inplace: bool = False,
|
||||
clamp_limit: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
@@ -363,6 +374,7 @@ def fused_marlin_moe(
|
||||
output=None,
|
||||
input_dtype=input_dtype,
|
||||
is_k_full=is_k_full,
|
||||
clamp_limit=clamp_limit,
|
||||
).view(-1, topk, K)
|
||||
|
||||
if output is None:
|
||||
@@ -557,6 +569,7 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
|
||||
self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
|
||||
self.is_k_full = is_k_full
|
||||
self.input_dtype = get_marlin_input_dtype()
|
||||
self.gemm1_clamp_limit = quant_config.gemm1_clamp_limit
|
||||
|
||||
super().__init__(
|
||||
moe_config=moe_config,
|
||||
@@ -850,6 +863,7 @@ class MarlinExperts(LoRAExpertsMixin, MarlinExpertsBase):
|
||||
sort_indices2=self.w2_g_idx_sort_indices,
|
||||
is_k_full=self.is_k_full,
|
||||
input_dtype=self.input_dtype,
|
||||
clamp_limit=self.gemm1_clamp_limit,
|
||||
)
|
||||
|
||||
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
|
||||
|
||||
@@ -169,5 +169,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -268,6 +268,7 @@ class FusedMoE(PluggableLayer):
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
swiglu_limit: float | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
@@ -285,6 +286,7 @@ class FusedMoE(PluggableLayer):
|
||||
routed_output_transform: torch.nn.Module | None = None,
|
||||
apply_routed_scale_to_output: bool = False,
|
||||
zero_expert_type: str | None = None,
|
||||
hash_indices_table: torch.Tensor | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -294,6 +296,7 @@ class FusedMoE(PluggableLayer):
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.vllm_config = vllm_config
|
||||
self.swiglu_limit = swiglu_limit
|
||||
|
||||
# FIXME (varun): We should have a better way of inferring the activation
|
||||
# datatype. This works for now as the tensor datatype entering the MoE
|
||||
@@ -455,6 +458,7 @@ class FusedMoE(PluggableLayer):
|
||||
self.e_score_correction_bias = e_score_correction_bias
|
||||
# TODO(bnell): end attributes
|
||||
|
||||
self.hash_indices_table = hash_indices_table
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
self.activation = MoEActivation.from_str(activation)
|
||||
|
||||
@@ -479,6 +483,7 @@ class FusedMoE(PluggableLayer):
|
||||
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
|
||||
zero_expert_type=zero_expert_type,
|
||||
num_logical_experts=self.logical_num_experts,
|
||||
hash_indices_table=self.hash_indices_table,
|
||||
)
|
||||
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
|
||||
|
||||
@@ -1541,10 +1546,12 @@ class FusedMoE(PluggableLayer):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.runner.forward(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
input_ids,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
FusedMoEQuantDesc,
|
||||
mxfp4_mxfp8_moe_quant_config,
|
||||
mxfp4_w4a16_moe_quant_config,
|
||||
ocp_mx_moe_quant_config,
|
||||
@@ -24,6 +25,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8Dynamic128Sym,
|
||||
kMxfp4Static,
|
||||
kMxfp8Dynamic,
|
||||
)
|
||||
@@ -46,6 +48,8 @@ if has_triton_kernels():
|
||||
|
||||
class Mxfp4MoeBackend(Enum):
|
||||
NONE = "None"
|
||||
# DeepGEMM FP8xFP4 backend (SM100+)
|
||||
DEEPGEMM_MXFP4 = "DEEPGEMM_MXFP4"
|
||||
# FlashInfer TRTLLM backends
|
||||
FLASHINFER_TRTLLM_MXFP4_MXFP8 = "FLASHINFER_TRTLLM_MXFP4_MXFP8"
|
||||
FLASHINFER_TRTLLM_MXFP4_BF16 = "FLASHINFER_TRTLLM_MXFP4_BF16"
|
||||
@@ -81,7 +85,14 @@ TRITON_BACKENDS = (
|
||||
def backend_to_kernel_cls(
|
||||
backend: Mxfp4MoeBackend,
|
||||
) -> list[type[mk.FusedMoEExperts]]:
|
||||
if backend in (
|
||||
if backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
|
||||
from vllm.model_executor.layers.fused_moe.experts.deep_gemm_moe import (
|
||||
DeepGemmFP4Experts,
|
||||
)
|
||||
|
||||
return [DeepGemmFP4Experts]
|
||||
|
||||
elif backend in (
|
||||
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
|
||||
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
|
||||
):
|
||||
@@ -159,11 +170,13 @@ def backend_to_kernel_cls(
|
||||
def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend:
|
||||
"""Map user's moe_backend string to Mxfp4MoeBackend."""
|
||||
mapping: dict[str, Mxfp4MoeBackend] = {
|
||||
"deep_gemm": Mxfp4MoeBackend.DEEPGEMM_MXFP4,
|
||||
"flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
|
||||
"flashinfer_trtllm_afp8": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
|
||||
"flashinfer_cutlass": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
|
||||
"flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
|
||||
"triton": Mxfp4MoeBackend.TRITON,
|
||||
"triton_unfused": Mxfp4MoeBackend.TRITON_UNFUSED,
|
||||
"marlin": Mxfp4MoeBackend.MARLIN,
|
||||
"aiter": Mxfp4MoeBackend.AITER,
|
||||
"xpu": Mxfp4MoeBackend.XPU,
|
||||
@@ -177,7 +190,7 @@ def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend:
|
||||
)
|
||||
|
||||
|
||||
def _get_priority_backends() -> list[Mxfp4MoeBackend]:
|
||||
def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]:
|
||||
"""
|
||||
Get available backends in priority order based on platform and config.
|
||||
Only includes BF16 backends. MXFP8 backends are selected via env vars.
|
||||
@@ -187,7 +200,9 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
|
||||
Mxfp4MoeBackend.AITER,
|
||||
Mxfp4MoeBackend.TRITON,
|
||||
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
|
||||
Mxfp4MoeBackend.TRITON_UNFUSED,
|
||||
# TRITON_UNFUSED has bug with MTP support
|
||||
# TODO re-enable after kernel is fixed
|
||||
# TRITON_UNFUSED
|
||||
Mxfp4MoeBackend.MARLIN,
|
||||
Mxfp4MoeBackend.BATCHED_MARLIN,
|
||||
Mxfp4MoeBackend.XPU,
|
||||
@@ -196,8 +211,28 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
|
||||
return _AVAILABLE_BACKENDS
|
||||
|
||||
|
||||
def _get_priority_backends() -> list[Mxfp4MoeBackend]:
|
||||
"""
|
||||
Get available backends in priority order. SM100+ prefers DeepGEMM FP4 /
|
||||
TRTLLM MXFP8; SM90 falls through to Triton_unfused or Marlin (the
|
||||
backend-level ``is_supported_config`` check filters by device capability).
|
||||
"""
|
||||
_AVAILABLE_BACKENDS = [
|
||||
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
|
||||
Mxfp4MoeBackend.DEEPGEMM_MXFP4,
|
||||
# TRITON_UNFUSED has bug with MTP support
|
||||
# TODO re-enable after kernel is fixed
|
||||
# TRITON_UNFUSED
|
||||
Mxfp4MoeBackend.MARLIN,
|
||||
Mxfp4MoeBackend.BATCHED_MARLIN,
|
||||
]
|
||||
return _AVAILABLE_BACKENDS
|
||||
|
||||
|
||||
def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None:
|
||||
"""Map backend to its activation key (MXFP8 or None for BF16)."""
|
||||
"""Map backend to its activation key (FP8, MXFP8, or None for BF16)."""
|
||||
if backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
|
||||
return kFp8Dynamic128Sym
|
||||
if backend in (
|
||||
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
|
||||
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
|
||||
@@ -290,7 +325,7 @@ def select_gpt_oss_mxfp4_moe_backend(
|
||||
)
|
||||
|
||||
# Select kernels in order of backend.
|
||||
AVAILABLE_BACKENDS = _get_priority_backends()
|
||||
AVAILABLE_BACKENDS = _get_priority_backends_for_gpt_oss()
|
||||
|
||||
# Handle explicit FlashInfer MXFP4 BF16 configuration.
|
||||
if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
|
||||
@@ -387,11 +422,95 @@ def select_gpt_oss_mxfp4_moe_backend(
|
||||
return Mxfp4MoeBackend.NONE, None
|
||||
|
||||
|
||||
def select_mxfp4_moe_backend(
|
||||
config: FusedMoEConfig,
|
||||
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]:
|
||||
"""
|
||||
Select the MXFP4 MoE backend with MXFP8 activation as top priority.
|
||||
Falls back through BF16 and other backends.
|
||||
"""
|
||||
activation_format = (
|
||||
mk.FusedMoEActivationFormat.BatchedExperts
|
||||
if config.moe_parallel_config.use_batched_activation_format
|
||||
else mk.FusedMoEActivationFormat.Standard
|
||||
)
|
||||
|
||||
def _make_log_backend(backend: Mxfp4MoeBackend):
|
||||
return f"Using '{backend.value}' Mxfp4 MoE backend."
|
||||
|
||||
def _make_log_unsupported(backend: Mxfp4MoeBackend, reason: str | None) -> str:
|
||||
if reason:
|
||||
return (
|
||||
f"Mxfp4 MoE backend '{backend.value}' does not support the "
|
||||
f"deployment configuration since {reason}."
|
||||
)
|
||||
return (
|
||||
f"Mxfp4 MoE backend '{backend.value}' does not support the "
|
||||
"deployment configuration."
|
||||
)
|
||||
|
||||
def _return_or_raise(
|
||||
backend: Mxfp4MoeBackend,
|
||||
config: FusedMoEConfig,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts]]:
|
||||
reason: str | None = None
|
||||
for k_cls in backend_to_kernel_cls(backend):
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
raise ValueError(_make_log_unsupported(backend, reason))
|
||||
|
||||
# Honor explicit moe_backend (e.g. "marlin", "triton_unfused") before
|
||||
# falling back to the auto priority list.
|
||||
runner_backend = config.moe_backend
|
||||
if runner_backend != "auto":
|
||||
requested_backend = map_mxfp4_backend(runner_backend)
|
||||
if (
|
||||
activation_format == mk.FusedMoEActivationFormat.BatchedExperts
|
||||
and requested_backend == Mxfp4MoeBackend.MARLIN
|
||||
):
|
||||
requested_backend = Mxfp4MoeBackend.BATCHED_MARLIN
|
||||
return _return_or_raise(
|
||||
requested_backend,
|
||||
config,
|
||||
kMxfp4Static,
|
||||
_backend_activation_key(requested_backend),
|
||||
activation_format,
|
||||
)
|
||||
|
||||
# Iterate priority backends: TRTLLM MXFP8, then Triton.
|
||||
for backend in _get_priority_backends():
|
||||
activation_key = _backend_activation_key(backend)
|
||||
for k_cls in backend_to_kernel_cls(backend):
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, config, kMxfp4Static, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
else:
|
||||
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
|
||||
|
||||
raise NotImplementedError(
|
||||
"No MXFP4 MoE backend supports the deployment configuration."
|
||||
)
|
||||
|
||||
|
||||
def mxfp4_round_up_hidden_size_and_intermediate_size(
|
||||
backend: Mxfp4MoeBackend, hidden_size: int, intermediate_size: int
|
||||
) -> tuple[int, int]:
|
||||
"""Round up hidden_size and intermediate_size based on backend requirements."""
|
||||
if backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
|
||||
if backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
|
||||
# DeepGEMM requires M/N/K alignment
|
||||
intermediate_size = round_up(intermediate_size, 128)
|
||||
hidden_size = round_up(hidden_size, 128)
|
||||
elif backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
|
||||
intermediate_size = round_up(intermediate_size, 128)
|
||||
if current_platform.is_xpu():
|
||||
hidden_size = round_up(hidden_size, 128)
|
||||
@@ -434,6 +553,20 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
|
||||
]:
|
||||
"""Convert loaded weights into backend-specific kernel format."""
|
||||
|
||||
if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
_upcast_e8m0_to_fp32,
|
||||
)
|
||||
|
||||
return (
|
||||
w13_weight.data,
|
||||
w2_weight.data,
|
||||
_upcast_e8m0_to_fp32(w13_weight_scale.data),
|
||||
_upcast_e8m0_to_fp32(w2_weight_scale.data),
|
||||
w13_bias,
|
||||
w2_bias,
|
||||
)
|
||||
|
||||
num_experts = w13_weight.shape[0]
|
||||
intermediate_size = w13_weight.shape[1] // 2
|
||||
hidden_size = w13_weight.shape[2] * 2
|
||||
@@ -738,9 +871,10 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
|
||||
elif mxfp4_backend in TRITON_BACKENDS:
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
|
||||
assert w13_bias is not None and w2_bias is not None
|
||||
w13_bias = w13_bias.to(torch.float32)
|
||||
w2_bias = w2_bias.to(torch.float32)
|
||||
if w13_bias is not None:
|
||||
w13_bias = w13_bias.to(torch.float32)
|
||||
if w2_bias is not None:
|
||||
w2_bias = w2_bias.to(torch.float32)
|
||||
|
||||
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
|
||||
w13_weight,
|
||||
@@ -797,15 +931,271 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
|
||||
)
|
||||
|
||||
|
||||
def convert_weight_to_mxfp4_moe_kernel_format(
|
||||
mxfp4_backend: Mxfp4MoeBackend,
|
||||
layer: torch.nn.Module,
|
||||
w13_weight: torch.Tensor,
|
||||
w2_weight: torch.Tensor,
|
||||
w13_weight_scale: torch.Tensor,
|
||||
w2_weight_scale: torch.Tensor,
|
||||
w13_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
_cache_permute_indices: dict[torch.Size, torch.Tensor] | None = None,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
Union[torch.Tensor, "PrecisionConfig"],
|
||||
Union[torch.Tensor, "PrecisionConfig"],
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
]:
|
||||
"""Convert loaded weights into backend-specific kernel format.
|
||||
|
||||
Supports DeepGEMM, TRTLLM MXFP8, Triton and Marlin backends.
|
||||
"""
|
||||
|
||||
if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
_upcast_e8m0_to_fp32,
|
||||
)
|
||||
|
||||
# Weights stay as uint8 packed FP4 — no layout change needed.
|
||||
# Convert E8M0 uint8 scales to float32.
|
||||
return (
|
||||
w13_weight.data,
|
||||
w2_weight.data,
|
||||
_upcast_e8m0_to_fp32(w13_weight_scale.data),
|
||||
_upcast_e8m0_to_fp32(w2_weight_scale.data),
|
||||
w13_bias,
|
||||
w2_bias,
|
||||
)
|
||||
|
||||
if mxfp4_backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
prepare_moe_mxfp4_layer_for_marlin,
|
||||
)
|
||||
|
||||
return prepare_moe_mxfp4_layer_for_marlin(
|
||||
layer,
|
||||
w13_weight,
|
||||
w2_weight,
|
||||
w13_weight_scale,
|
||||
w2_weight_scale,
|
||||
w13_bias,
|
||||
w2_bias,
|
||||
)
|
||||
|
||||
num_experts = w13_weight.shape[0]
|
||||
intermediate_size = w13_weight.shape[1] // 2
|
||||
hidden_size = w13_weight.shape[2] * 2
|
||||
|
||||
sf_block_size = 32 # mxfp4 block size
|
||||
|
||||
if mxfp4_backend in TRTLLM_BACKENDS:
|
||||
assert _cache_permute_indices is not None
|
||||
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
|
||||
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
|
||||
|
||||
w13_weight = w13_weight.data
|
||||
w2_weight = w2_weight.data
|
||||
w13_weight_scale = w13_weight_scale.data
|
||||
w2_weight_scale = w2_weight_scale.data
|
||||
if w13_bias is not None:
|
||||
w13_bias = w13_bias.data.to(torch.float32)
|
||||
if w2_bias is not None:
|
||||
w2_bias = w2_bias.data.to(torch.float32)
|
||||
|
||||
# Swap w1/w3 and interleave to match TRTLLM SwiGLU convention.
|
||||
# Standard loading gives contiguous [w1/gate, w3/up].
|
||||
# TRTLLM kernel expects interleaved [w3_0, w1_0, w3_1, w1_1, ...].
|
||||
w1_weight = w13_weight[:, :intermediate_size, :]
|
||||
w3_weight = w13_weight[:, intermediate_size:, :]
|
||||
w13_weight = torch.stack([w3_weight, w1_weight], dim=2).reshape(
|
||||
w13_weight.shape
|
||||
)
|
||||
|
||||
w1_scale = w13_weight_scale[:, :intermediate_size, :]
|
||||
w3_scale = w13_weight_scale[:, intermediate_size:, :]
|
||||
w13_weight_scale = torch.stack([w3_scale, w1_scale], dim=2).reshape(
|
||||
w13_weight_scale.shape
|
||||
)
|
||||
|
||||
if w13_bias is not None:
|
||||
b1 = w13_bias[:, :intermediate_size]
|
||||
b3 = w13_bias[:, intermediate_size:]
|
||||
w13_bias = torch.stack([b3, b1], dim=2).reshape(w13_bias.shape)
|
||||
|
||||
# Shuffle weights and scaling factors for transposed mma output.
|
||||
# Permute indices depend only on shape (cached by torch.Size),
|
||||
# so compute once and apply to all experts via batched indexing.
|
||||
epilogue_tile_m = 128
|
||||
|
||||
# w13 weight permute
|
||||
w13_perm = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w13_weight[0].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
).to(w13_weight.device)
|
||||
w13_weight = w13_weight.view(torch.uint8)[:, w13_perm].contiguous()
|
||||
|
||||
# w13 scale permute + interleave
|
||||
w13_sf_perm = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w13_weight_scale[0].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
num_elts_per_sf=16,
|
||||
).to(w13_weight_scale.device)
|
||||
w13_s = w13_weight_scale.view(torch.uint8)[:, w13_sf_perm].contiguous()
|
||||
E, N_s, K_s = w13_s.shape
|
||||
w13_weight_scale = (
|
||||
nvfp4_block_scale_interleave(w13_s.reshape(E * N_s, K_s))
|
||||
.reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
|
||||
# w2 weight permute
|
||||
w2_perm = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w2_weight[0].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
).to(w2_weight.device)
|
||||
w2_weight = w2_weight.view(torch.uint8)[:, w2_perm].contiguous()
|
||||
|
||||
# w2 scale permute + interleave
|
||||
w2_sf_perm = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w2_weight_scale[0].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
num_elts_per_sf=16,
|
||||
).to(w2_weight_scale.device)
|
||||
w2_s = w2_weight_scale.view(torch.uint8)[:, w2_sf_perm].contiguous()
|
||||
E2, N2_s, K2_s = w2_s.shape
|
||||
w2_weight_scale = (
|
||||
nvfp4_block_scale_interleave(w2_s.reshape(E2 * N2_s, K2_s))
|
||||
.reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
|
||||
# w13 bias permute
|
||||
if w13_bias is not None:
|
||||
w13_b_perm = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w13_bias[0].reshape(-1, 1),
|
||||
epilogue_tile_m,
|
||||
).to(w13_bias.device)
|
||||
w13_bias = w13_bias.reshape(num_experts, -1, 1)[:, w13_b_perm].reshape(
|
||||
num_experts, -1
|
||||
)
|
||||
|
||||
# w2 bias permute
|
||||
if w2_bias is not None:
|
||||
w2_b_perm = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w2_bias[0].reshape(-1, 1),
|
||||
epilogue_tile_m,
|
||||
).to(w2_bias.device)
|
||||
w2_bias = w2_bias.reshape(num_experts, -1, 1)[:, w2_b_perm].reshape(
|
||||
num_experts, -1
|
||||
)
|
||||
|
||||
return (
|
||||
w13_weight,
|
||||
w2_weight,
|
||||
w13_weight_scale,
|
||||
w2_weight_scale,
|
||||
w13_bias,
|
||||
w2_bias,
|
||||
)
|
||||
|
||||
elif mxfp4_backend in TRITON_BACKENDS:
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
|
||||
if mxfp4_backend == Mxfp4MoeBackend.TRITON:
|
||||
|
||||
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
|
||||
shape = w.shape
|
||||
n = shape[-1]
|
||||
first = w[..., : n // 2]
|
||||
second = w[..., n // 2 :]
|
||||
stacked = torch.stack((first, second), dim=-1)
|
||||
return stacked.reshape(shape)
|
||||
|
||||
w13_weight = shuffle_weight(w13_weight)
|
||||
w13_weight_scale = shuffle_weight(w13_weight_scale)
|
||||
|
||||
if w13_bias is not None:
|
||||
w13_bias = shuffle_weight(w13_bias.to(torch.float32))
|
||||
else:
|
||||
if w13_bias is not None:
|
||||
w13_bias = w13_bias.to(torch.float32)
|
||||
|
||||
if w2_bias is not None:
|
||||
w2_bias = w2_bias.to(torch.float32)
|
||||
|
||||
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
|
||||
w13_weight,
|
||||
w13_weight_scale,
|
||||
)
|
||||
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
|
||||
w2_weight,
|
||||
w2_weight_scale,
|
||||
)
|
||||
|
||||
w13_precision_config = PrecisionConfig(
|
||||
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
|
||||
)
|
||||
w2_precision_config = PrecisionConfig(
|
||||
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
|
||||
)
|
||||
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
|
||||
return (
|
||||
w13_weight,
|
||||
w2_weight,
|
||||
w13_precision_config,
|
||||
w2_precision_config,
|
||||
w13_bias,
|
||||
w2_bias,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported mxfp4_backend for Mxfp4MoEMethod: {mxfp4_backend}. "
|
||||
f"Expected TRTLLM or Triton backend."
|
||||
)
|
||||
|
||||
|
||||
def make_mxfp4_moe_quant_config(
|
||||
mxfp4_backend: Mxfp4MoeBackend,
|
||||
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
gemm1_alpha: float | None = None,
|
||||
gemm1_beta: float | None = None,
|
||||
swiglu_limit: float | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
"""Create a FusedMoEQuantConfig for the given MXFP4 backend."""
|
||||
if mxfp4_backend in (
|
||||
if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4:
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
)
|
||||
|
||||
# DeepGEMM FP4 uses FP8 per-token-group activation quantization
|
||||
# with block 128, matching the FP8 DeepGEMM path.
|
||||
_fp8_dtype = current_platform.fp8_dtype()
|
||||
_block_shape = GroupShape(128, 128)
|
||||
return FusedMoEQuantConfig(
|
||||
_a1=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None),
|
||||
_a2=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None),
|
||||
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
|
||||
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_beta=gemm1_beta,
|
||||
gemm1_clamp_limit=swiglu_limit,
|
||||
)
|
||||
elif mxfp4_backend in (
|
||||
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
|
||||
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
|
||||
):
|
||||
@@ -814,6 +1204,9 @@ def make_mxfp4_moe_quant_config(
|
||||
w2_bias=w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_beta=gemm1_beta,
|
||||
gemm1_clamp_limit=swiglu_limit,
|
||||
)
|
||||
elif mxfp4_backend in (
|
||||
Mxfp4MoeBackend.MARLIN,
|
||||
@@ -829,6 +1222,9 @@ def make_mxfp4_moe_quant_config(
|
||||
w2_bias=w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_beta=gemm1_beta,
|
||||
gemm1_clamp_limit=swiglu_limit,
|
||||
)
|
||||
else:
|
||||
return ocp_mx_moe_quant_config(
|
||||
@@ -837,6 +1233,9 @@ def make_mxfp4_moe_quant_config(
|
||||
w2_bias=w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_beta=gemm1_beta,
|
||||
gemm1_clamp_limit=swiglu_limit,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -228,6 +228,8 @@ class BaseRouter(FusedMoERouter):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
indices_type: torch.dtype | None,
|
||||
*,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compute the actual routing logic.
|
||||
@@ -249,6 +251,8 @@ class BaseRouter(FusedMoERouter):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
*,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Route the input hidden states to the top-k experts based on the
|
||||
@@ -278,7 +282,7 @@ class BaseRouter(FusedMoERouter):
|
||||
|
||||
# Step 3: Compute routing (delegated to subclass)
|
||||
topk_weights, topk_ids = self._compute_routing(
|
||||
hidden_states, router_logits, indices_type
|
||||
hidden_states, router_logits, indices_type, input_ids=input_ids
|
||||
)
|
||||
|
||||
# Capture logical ids before EPLB mapping.
|
||||
|
||||
@@ -46,6 +46,8 @@ class CustomRoutingRouter(BaseRouter):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
indices_type: torch.dtype | None,
|
||||
*,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute routing using the custom routing function."""
|
||||
topk_weights, topk_ids = self.custom_routing_function(
|
||||
|
||||
@@ -31,6 +31,8 @@ class FusedMoERouter(ABC):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
*,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Route the input hidden states to the top-k experts based on the
|
||||
|
||||
@@ -4,6 +4,7 @@ import functools
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
import vllm.envs as envs
|
||||
@@ -56,6 +57,32 @@ def vllm_topk_sigmoid(
|
||||
return topk_weights, topk_indices
|
||||
|
||||
|
||||
def vllm_topk_softplus_sqrt(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
renormalize: bool = False,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
input_tokens: torch.Tensor | None = None,
|
||||
hash_indices_table: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
ops.topk_hash_softplus_sqrt(
|
||||
topk_weights,
|
||||
topk_indices,
|
||||
token_expert_indices,
|
||||
gating_output,
|
||||
renormalize,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
input_tokens,
|
||||
hash_indices_table,
|
||||
)
|
||||
|
||||
return topk_weights, topk_indices
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=8)
|
||||
def _aiter_get_num_expert_group(num_experts: int) -> int:
|
||||
_AITER_MAX_EXPERTS_PER_GROUP = 32
|
||||
@@ -72,11 +99,14 @@ def _aiter_get_num_expert_group(num_experts: int) -> int:
|
||||
def fused_topk_bias(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
scoring_func: str,
|
||||
e_score_correction_bias: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
scoring_func: str = "softmax",
|
||||
indices_type: torch.dtype | None = None,
|
||||
input_tokens: torch.Tensor | None = None,
|
||||
hash_indices_table: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
):
|
||||
if not rocm_aiter_ops.is_fused_moe_enabled():
|
||||
assert hidden_states.size(0) == gating_output.size(0), (
|
||||
@@ -107,6 +137,8 @@ def fused_topk_bias(
|
||||
renormalize,
|
||||
e_score_correction_bias,
|
||||
)
|
||||
if routed_scaling_factor != 1.0:
|
||||
topk_weights *= routed_scaling_factor
|
||||
return topk_weights, topk_ids
|
||||
elif scoring_func == "sigmoid":
|
||||
topk_weights, topk_ids = vllm_topk_sigmoid(
|
||||
@@ -117,9 +149,24 @@ def fused_topk_bias(
|
||||
renormalize,
|
||||
e_score_correction_bias,
|
||||
)
|
||||
if routed_scaling_factor != 1.0:
|
||||
topk_weights *= routed_scaling_factor
|
||||
return topk_weights, topk_ids
|
||||
elif scoring_func == "sqrtsoftplus":
|
||||
return vllm_topk_softplus_sqrt(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indices,
|
||||
gating_output,
|
||||
renormalize,
|
||||
e_score_correction_bias,
|
||||
input_tokens,
|
||||
hash_indices_table,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
elif rocm_aiter_ops.is_fused_moe_enabled() and scoring_func == "sigmoid":
|
||||
M = hidden_states.size(0)
|
||||
num_experts = gating_output.shape[-1]
|
||||
@@ -143,6 +190,8 @@ def fused_topk_bias(
|
||||
topk_group=num_expert_group,
|
||||
need_renorm=renormalize,
|
||||
)
|
||||
if routed_scaling_factor != 1.0:
|
||||
topk_weights *= routed_scaling_factor
|
||||
return topk_weights, topk_ids
|
||||
|
||||
n_routed_experts = gating_output.shape[-1]
|
||||
@@ -150,20 +199,31 @@ def fused_topk_bias(
|
||||
scores = gating_output.softmax(dim=-1)
|
||||
elif scoring_func == "sigmoid":
|
||||
scores = gating_output.sigmoid()
|
||||
elif scoring_func == "sqrtsoftplus":
|
||||
scores = F.softplus(gating_output).sqrt()
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
scores_for_choice = scores.view(
|
||||
-1, n_routed_experts
|
||||
) + e_score_correction_bias.unsqueeze(0)
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
scores_for_choice = scores.view(
|
||||
-1, n_routed_experts
|
||||
) + e_score_correction_bias.unsqueeze(0)
|
||||
else:
|
||||
scores_for_choice = scores.view(-1, n_routed_experts)
|
||||
# For batch invariance, use sorted=True to ensure deterministic expert selection
|
||||
use_sorted = envs.VLLM_BATCH_INVARIANT
|
||||
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
|
||||
if hash_indices_table is not None:
|
||||
topk_indices = hash_indices_table[input_tokens]
|
||||
else:
|
||||
use_sorted = envs.VLLM_BATCH_INVARIANT
|
||||
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[
|
||||
1
|
||||
]
|
||||
topk_weights = scores.gather(1, topk_indices)
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
return topk_weights.to(torch.float32), topk_indices.to(
|
||||
topk_weights = topk_weights.to(torch.float32)
|
||||
if routed_scaling_factor != 1.0:
|
||||
topk_weights *= routed_scaling_factor
|
||||
return topk_weights, topk_indices.to(
|
||||
torch.int32 if indices_type is None else indices_type
|
||||
)
|
||||
|
||||
@@ -176,12 +236,14 @@ class FusedTopKBiasRouter(BaseRouter):
|
||||
top_k: int,
|
||||
global_num_experts: int,
|
||||
eplb_state: EplbLayerState,
|
||||
e_score_correction_bias: torch.Tensor,
|
||||
scoring_func: str,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
renormalize: bool = True,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
enable_eplb: bool = False,
|
||||
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
|
||||
*,
|
||||
scoring_func: str = "sigmoid",
|
||||
hash_indices_table: torch.Tensor | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
top_k=top_k,
|
||||
@@ -194,6 +256,8 @@ class FusedTopKBiasRouter(BaseRouter):
|
||||
self.renormalize = renormalize
|
||||
self.scoring_func = scoring_func
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.scoring_func = scoring_func
|
||||
self._hash_indices_table = hash_indices_table
|
||||
|
||||
@property
|
||||
def routing_method_type(self) -> RoutingMethodType:
|
||||
@@ -210,19 +274,23 @@ class FusedTopKBiasRouter(BaseRouter):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
indices_type: torch.dtype | None,
|
||||
*,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute routing using fused top-k with bias."""
|
||||
topk_weights, topk_ids = fused_topk_bias(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
e_score_correction_bias=self.e_score_correction_bias.data,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias.data
|
||||
if self.e_score_correction_bias is not None
|
||||
else None,
|
||||
topk=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
scoring_func=self.scoring_func,
|
||||
indices_type=indices_type,
|
||||
input_tokens=input_ids,
|
||||
hash_indices_table=self._hash_indices_table,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
)
|
||||
|
||||
if self.routed_scaling_factor != 1.0:
|
||||
topk_weights *= self.routed_scaling_factor
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
@@ -151,6 +151,8 @@ class FusedTopKRouter(BaseRouter):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
indices_type: torch.dtype | None,
|
||||
*,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute routing using standard fused top-k."""
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
|
||||
@@ -292,6 +292,8 @@ class GroupedTopKRouter(BaseRouter):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
indices_type: torch.dtype | None,
|
||||
*,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute routing using grouped top-k."""
|
||||
|
||||
@@ -308,6 +310,7 @@ class GroupedTopKRouter(BaseRouter):
|
||||
topk_weights, topk_ids = fused_topk_bias(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias.data,
|
||||
topk=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
|
||||
@@ -55,6 +55,7 @@ def create_fused_moe_router(
|
||||
# zero expert parameters
|
||||
zero_expert_type: str | None = None,
|
||||
num_logical_experts: int | None = None,
|
||||
hash_indices_table: torch.Tensor | None = None,
|
||||
) -> FusedMoERouter:
|
||||
"""
|
||||
Factory function to create the appropriate FusedMoERouter subclass based on
|
||||
@@ -99,6 +100,9 @@ def create_fused_moe_router(
|
||||
num_logical_experts: Number of real (non-zero) experts. Required when
|
||||
zero_expert_type is not None.
|
||||
|
||||
Hash Indices Table:
|
||||
Used to map input_ids to experts, need for Deepseek V4
|
||||
|
||||
Returns:
|
||||
An instance of the appropriate FusedMoERouter subclass
|
||||
"""
|
||||
@@ -179,17 +183,20 @@ def create_fused_moe_router(
|
||||
indices_type_getter=indices_type_getter,
|
||||
)
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
assert scoring_func in ["sigmoid", "softmax", "sqrtsoftplus"]
|
||||
|
||||
if e_score_correction_bias is not None or hash_indices_table is not None:
|
||||
return FusedTopKBiasRouter(
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
eplb_state=eplb_state,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
scoring_func=scoring_func,
|
||||
renormalize=renormalize,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
enable_eplb=enable_eplb,
|
||||
indices_type_getter=indices_type_getter,
|
||||
scoring_func=scoring_func,
|
||||
hash_indices_table=hash_indices_table,
|
||||
)
|
||||
|
||||
return FusedTopKRouter(
|
||||
|
||||
@@ -334,6 +334,8 @@ class RoutingSimulatorRouter(BaseRouter):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
indices_type: torch.dtype | None,
|
||||
*,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Use routing simulator to compute routing."""
|
||||
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
|
||||
|
||||
@@ -72,6 +72,8 @@ class ZeroExpertRouter(BaseRouter):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
indices_type: torch.dtype | None,
|
||||
*,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute routing with full bias, compute zero expert output,
|
||||
mask zero expert IDs."""
|
||||
|
||||
@@ -91,6 +91,7 @@ def _moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
input_ids: torch.Tensor | None,
|
||||
layer_name: _layer_name_type,
|
||||
) -> torch.Tensor:
|
||||
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||
@@ -99,6 +100,7 @@ def _moe_forward(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
shared_experts_input,
|
||||
input_ids,
|
||||
)
|
||||
|
||||
|
||||
@@ -106,6 +108,7 @@ def _moe_forward_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
input_ids: torch.Tensor | None,
|
||||
layer_name: _layer_name_type,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
@@ -115,6 +118,7 @@ def _moe_forward_shared(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
input_ids: torch.Tensor | None,
|
||||
layer_name: _layer_name_type,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||
@@ -123,6 +127,7 @@ def _moe_forward_shared(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
shared_experts_input,
|
||||
input_ids,
|
||||
)
|
||||
|
||||
|
||||
@@ -130,6 +135,7 @@ def _moe_forward_shared_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
input_ids: torch.Tensor | None,
|
||||
layer_name: _layer_name_type,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Output shapes:
|
||||
@@ -433,6 +439,7 @@ class MoERunner(MoERunnerInterface):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor]:
|
||||
"""Run expert routing and the fused MoE kernel via the quant method.
|
||||
|
||||
@@ -449,11 +456,13 @@ class MoERunner(MoERunnerInterface):
|
||||
layer=layer,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
input_ids=input_ids,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = self.router.select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
input_ids=input_ids,
|
||||
)
|
||||
|
||||
# Passing shared_experts_input in case SharedExpertsOrder is
|
||||
@@ -523,6 +532,7 @@ class MoERunner(MoERunnerInterface):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Invoke the fused moe layer.
|
||||
|
||||
@@ -565,6 +575,7 @@ class MoERunner(MoERunnerInterface):
|
||||
hidden_states,
|
||||
router_logits,
|
||||
shared_experts_input,
|
||||
input_ids,
|
||||
self._encode_layer_name(),
|
||||
)
|
||||
|
||||
@@ -672,6 +683,7 @@ class MoERunner(MoERunnerInterface):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Entry point called by the custom op to run the MoE computation.
|
||||
|
||||
@@ -712,6 +724,7 @@ class MoERunner(MoERunnerInterface):
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
shared_experts_input=shared_experts_input,
|
||||
input_ids=input_ids,
|
||||
)
|
||||
|
||||
return self._maybe_combine(
|
||||
|
||||
@@ -26,6 +26,7 @@ class MoERunnerInterface(ABC):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -309,6 +309,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.is_monolithic
|
||||
if self.unquantized_backend == UnquantizedMoeBackend.CPU:
|
||||
|
||||
@@ -4,6 +4,7 @@ import functools
|
||||
from math import prod
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
@@ -384,3 +385,20 @@ def trtllm_moe_pack_topk_ids_weights(
|
||||
return (topk_ids.to(torch.int32) << 16) | topk_weights.to(torch.bfloat16).view(
|
||||
torch.int16
|
||||
)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
|
||||
def swiglu_limit_func(
|
||||
output: torch.Tensor,
|
||||
input: torch.Tensor, # first half is gate, second half is up
|
||||
swiglu_limit: float = 0.0,
|
||||
) -> None:
|
||||
d = input.shape[1] // 2
|
||||
gate = input[:, :d]
|
||||
up = input[:, d:]
|
||||
|
||||
if swiglu_limit > 0:
|
||||
gate = torch.clamp(gate, max=swiglu_limit)
|
||||
up = torch.clamp(up, min=-swiglu_limit, max=swiglu_limit)
|
||||
|
||||
output.copy_(F.silu(gate) * up)
|
||||
|
||||
@@ -0,0 +1,450 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_tilelang
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
# tilelang is only available on CUDA platforms
|
||||
if TYPE_CHECKING or current_platform.is_cuda_alike():
|
||||
if not has_tilelang():
|
||||
raise ImportError(
|
||||
"tilelang is required for mhc but is not installed. Install it with "
|
||||
"`pip install tilelang`."
|
||||
)
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
else:
|
||||
tilelang = None # type: ignore[assignment]
|
||||
T = None # type: ignore[assignment]
|
||||
|
||||
|
||||
@cache
|
||||
def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int:
|
||||
device_props = torch.cuda.get_device_properties(0)
|
||||
n_sms = device_props.multi_processor_count
|
||||
split_k = n_sms // grid_size
|
||||
if k is not None:
|
||||
# avoid split_k for small k
|
||||
num_block_k = cdiv(k, block_k)
|
||||
split_k = min(split_k, num_block_k // 4)
|
||||
split_k = max(split_k, 1)
|
||||
return split_k
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
pass_configs={
|
||||
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
|
||||
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
||||
tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
|
||||
},
|
||||
)
|
||||
def mhc_pre_big_fuse_tilelang(
|
||||
gemm_out_mul,
|
||||
gemm_out_sqrsum,
|
||||
hc_scale,
|
||||
hc_base,
|
||||
residual,
|
||||
post_mix,
|
||||
comb_mix,
|
||||
layer_input,
|
||||
hidden_size: int,
|
||||
rms_eps: float,
|
||||
hc_pre_eps: float,
|
||||
hc_sinkhorn_eps: float,
|
||||
hc_post_mult_value: float,
|
||||
sinkhorn_repeat: int,
|
||||
n_splits: int = 16,
|
||||
hc_mult: int = 4,
|
||||
):
|
||||
"""Deeply fused kernels, everything other than gemm & sqrsum in mHC pre block."""
|
||||
num_tokens = T.dynamic("num_tokens")
|
||||
hc_mult3 = hc_mult * (2 + hc_mult)
|
||||
hidden_block = math.gcd(512, hidden_size)
|
||||
|
||||
gemm_out_mul: T.Tensor[[n_splits, num_tokens, hc_mult3], T.float32] # type: ignore[no-redef, valid-type]
|
||||
gemm_out_sqrsum: T.Tensor[[n_splits, num_tokens], T.float32] # type: ignore[no-redef, valid-type]
|
||||
hc_scale: T.Tensor[[3], T.float32] # type: ignore[no-redef, valid-type]
|
||||
hc_base: T.Tensor[[hc_mult3], T.float32] # type: ignore[no-redef, valid-type]
|
||||
residual: T.Tensor[[num_tokens, hc_mult, hidden_size], T.bfloat16] # type: ignore[no-redef, valid-type]
|
||||
# outputs
|
||||
post_mix: T.Tensor[[num_tokens, hc_mult], T.float32] # type: ignore[no-redef, valid-type]
|
||||
comb_mix: T.Tensor[[num_tokens, hc_mult * hc_mult], T.float32] # type: ignore[no-redef, valid-type]
|
||||
layer_input: T.Tensor[[num_tokens, hidden_size], T.bfloat16] # type: ignore[no-redef, valid-type]
|
||||
|
||||
with T.Kernel(num_tokens, threads=96) as i:
|
||||
T.pdl_sync()
|
||||
##################################################################
|
||||
# _pre_norm_fn_fwd_norm
|
||||
rms = T.alloc_fragment(1, T.float32)
|
||||
mixes = T.alloc_fragment(hc_mult3, T.float32)
|
||||
T.clear(mixes)
|
||||
rms[0] = 0
|
||||
for i_split in T.serial(n_splits):
|
||||
rms[0] += gemm_out_sqrsum[i_split, i]
|
||||
rms[0] = T.rsqrt(rms[0] / (hc_mult * hidden_size) + rms_eps)
|
||||
for j in T.Parallel(hc_mult3):
|
||||
mixes[j] = 0
|
||||
for i_split in T.serial(n_splits):
|
||||
mixes[j] += gemm_out_mul[i_split, i, j]
|
||||
mixes[j] *= rms[0]
|
||||
mixes_shared = T.alloc_shared(hc_mult3, T.float32)
|
||||
T.copy(mixes, mixes_shared)
|
||||
|
||||
if T.get_thread_binding() < 32:
|
||||
##################################################################
|
||||
# _pre_split_mixes_fwd (post & comb)
|
||||
cm = T.alloc_fragment((hc_mult, hc_mult), T.float32)
|
||||
for j in T.Parallel(hc_mult):
|
||||
post_mix[i, j] = (
|
||||
T.sigmoid(
|
||||
mixes_shared[j + hc_mult] * hc_scale[1] + hc_base[j + hc_mult]
|
||||
)
|
||||
* hc_post_mult_value
|
||||
)
|
||||
for j, k in T.Parallel(hc_mult, hc_mult):
|
||||
cm[j, k] = (
|
||||
mixes_shared[j * hc_mult + k + hc_mult * 2] * hc_scale[2]
|
||||
+ hc_base[j * hc_mult + k + hc_mult * 2]
|
||||
)
|
||||
|
||||
##################################################################
|
||||
# _sinkhorn_fwd
|
||||
row_sum = T.alloc_fragment(hc_mult, T.float32)
|
||||
col_sum = T.alloc_fragment(hc_mult, T.float32)
|
||||
|
||||
# comb = comb.softmax(-1) + eps
|
||||
row_max = T.alloc_fragment(hc_mult, T.float32)
|
||||
T.reduce_max(cm, row_max, dim=1)
|
||||
for j, k in T.Parallel(hc_mult, hc_mult):
|
||||
cm[j, k] = T.exp(cm[j, k] - row_max[j])
|
||||
T.reduce_sum(cm, row_sum, dim=1)
|
||||
for j, k in T.Parallel(hc_mult, hc_mult):
|
||||
cm[j, k] = cm[j, k] / row_sum[j] + hc_sinkhorn_eps
|
||||
|
||||
# comb = comb / (comb.sum(-2) + eps)
|
||||
T.reduce_sum(cm, col_sum, dim=0)
|
||||
for j, k in T.Parallel(hc_mult, hc_mult):
|
||||
cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps)
|
||||
|
||||
for _ in T.serial(sinkhorn_repeat - 1):
|
||||
# comb = comb / (comb.sum(-1) + eps)
|
||||
T.reduce_sum(cm, row_sum, dim=1)
|
||||
for j, k in T.Parallel(hc_mult, hc_mult):
|
||||
cm[j, k] = cm[j, k] / (row_sum[j] + hc_sinkhorn_eps)
|
||||
|
||||
# comb = comb / (comb.sum(-2) + eps)
|
||||
T.reduce_sum(cm, col_sum, dim=0)
|
||||
for j, k in T.Parallel(hc_mult, hc_mult):
|
||||
cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps)
|
||||
|
||||
# save comb_mix to global memory
|
||||
for j, k in T.Parallel(hc_mult, hc_mult):
|
||||
comb_mix[i, j * hc_mult + k] = cm[j, k]
|
||||
else:
|
||||
##################################################################
|
||||
# _pre_split_mixes_fwd (pre)
|
||||
pre_mix_shared = T.alloc_shared(hc_mult, T.float32)
|
||||
for j in T.Parallel(hc_mult):
|
||||
pre_mix_shared[j] = (
|
||||
T.sigmoid(
|
||||
mixes_shared[j] * hc_scale[0] + hc_base[j],
|
||||
)
|
||||
+ hc_pre_eps
|
||||
)
|
||||
###################################################################
|
||||
# _pre_apply_mix_fwd
|
||||
for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=2):
|
||||
xs = T.alloc_shared((hc_mult, hidden_block), T.float32)
|
||||
xl = T.alloc_fragment((hc_mult, hidden_block), T.float32)
|
||||
T.copy(residual[i, 0, i0_h * hidden_block], xs)
|
||||
T.copy(xs, xl)
|
||||
|
||||
ol = T.alloc_fragment(hidden_block, T.float32)
|
||||
T.clear(ol)
|
||||
|
||||
for i_hc in T.serial(hc_mult):
|
||||
pre = pre_mix_shared[i_hc]
|
||||
for i1_h in T.Parallel(hidden_block):
|
||||
ol[i1_h] += pre * xl[i_hc, i1_h]
|
||||
|
||||
T.copy(ol, layer_input[i, i0_h * hidden_block])
|
||||
T.pdl_trigger()
|
||||
|
||||
|
||||
def mhc_pre(
|
||||
residual: torch.Tensor,
|
||||
fn: torch.Tensor,
|
||||
hc_scale: torch.Tensor,
|
||||
hc_base: torch.Tensor,
|
||||
rms_eps: float,
|
||||
hc_pre_eps: float,
|
||||
hc_sinkhorn_eps: float,
|
||||
hc_post_mult_value: float,
|
||||
sinkhorn_repeat: int,
|
||||
n_splits: int = 1,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward pass for mHC pre block.
|
||||
|
||||
Args:
|
||||
residual: shape (..., hc_mult, hidden_size), dtype torch.bfloat16
|
||||
fn: shape (hc_mult3, hc_mult * hidden_size), dtype torch.float32
|
||||
hc_scale: shape (3,), dtype torch.float32
|
||||
hc_base: shape (hc_mult3,), dtype torch.float32
|
||||
rms_eps: RMS normalization epsilon
|
||||
hc_pre_eps: pre-mix epsilon
|
||||
hc_sinkhorn_eps: sinkhorn epsilon
|
||||
hc_post_mult_value: post-mix multiplier value
|
||||
sinkhorn_repeat: number of sinkhorn iterations
|
||||
n_splits: split-k factor;
|
||||
|
||||
Returns:
|
||||
post_mix: shape (..., hc_mult), dtype torch.float32
|
||||
comb_mix: shape (..., hc_mult, hc_mult), dtype torch.float32
|
||||
layer_input: shape (..., hidden_size), dtype torch.bfloat16
|
||||
"""
|
||||
|
||||
# Validate shapes
|
||||
assert residual.dtype == torch.bfloat16
|
||||
assert fn.dtype == torch.float32
|
||||
assert hc_scale.dtype == torch.float32
|
||||
assert hc_base.dtype == torch.float32
|
||||
|
||||
hc_mult = residual.shape[-2]
|
||||
hidden_size = residual.shape[-1]
|
||||
hc_mult2 = hc_mult * hc_mult
|
||||
hc_mult3 = hc_mult * 2 + hc_mult2
|
||||
|
||||
hc_hidden_size = hc_mult * hidden_size
|
||||
assert fn.shape[0] == hc_mult3
|
||||
assert fn.shape[1] == hc_hidden_size
|
||||
assert hc_scale.shape == (3,)
|
||||
assert hc_base.shape == (hc_mult3,)
|
||||
|
||||
outer_shape = residual.shape[:-2]
|
||||
|
||||
residual_flat = residual.view(-1, hc_mult, hidden_size)
|
||||
num_tokens = residual_flat.shape[0]
|
||||
fn_flat = fn
|
||||
|
||||
# these number are from deepgemm kernel impl
|
||||
block_k = 64
|
||||
block_m = 64
|
||||
n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m))
|
||||
|
||||
post_mix = torch.empty(
|
||||
num_tokens,
|
||||
hc_mult,
|
||||
dtype=torch.float32,
|
||||
device=residual.device,
|
||||
)
|
||||
comb_mix = torch.empty(
|
||||
num_tokens,
|
||||
hc_mult2,
|
||||
dtype=torch.float32,
|
||||
device=residual.device,
|
||||
)
|
||||
layer_input = torch.empty(
|
||||
num_tokens,
|
||||
hidden_size,
|
||||
dtype=torch.bfloat16,
|
||||
device=residual.device,
|
||||
)
|
||||
|
||||
gemm_out_mul = torch.empty(
|
||||
n_splits,
|
||||
num_tokens,
|
||||
hc_mult3,
|
||||
dtype=torch.float32,
|
||||
device=residual.device,
|
||||
)
|
||||
gemm_out_sqrsum = torch.empty(
|
||||
n_splits,
|
||||
num_tokens,
|
||||
dtype=torch.float32,
|
||||
device=residual.device,
|
||||
)
|
||||
|
||||
from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm
|
||||
|
||||
tf32_hc_prenorm_gemm(
|
||||
residual_flat.view(num_tokens, hc_mult * hidden_size),
|
||||
fn_flat,
|
||||
gemm_out_mul,
|
||||
gemm_out_sqrsum,
|
||||
n_splits,
|
||||
)
|
||||
|
||||
mhc_pre_big_fuse_tilelang(
|
||||
gemm_out_mul,
|
||||
gemm_out_sqrsum,
|
||||
hc_scale,
|
||||
hc_base,
|
||||
residual_flat,
|
||||
post_mix,
|
||||
comb_mix,
|
||||
layer_input,
|
||||
hidden_size,
|
||||
rms_eps,
|
||||
hc_pre_eps,
|
||||
hc_sinkhorn_eps,
|
||||
hc_post_mult_value,
|
||||
sinkhorn_repeat,
|
||||
n_splits,
|
||||
hc_mult,
|
||||
)
|
||||
|
||||
post_mix = post_mix.view(*outer_shape, hc_mult, 1)
|
||||
comb_mix = comb_mix.view(*outer_shape, hc_mult, hc_mult)
|
||||
layer_input = layer_input.view(*outer_shape, hidden_size)
|
||||
|
||||
return post_mix, comb_mix, layer_input
|
||||
|
||||
|
||||
def _mhc_pre_fake(
|
||||
residual: torch.Tensor,
|
||||
fn: torch.Tensor,
|
||||
hc_scale: torch.Tensor,
|
||||
hc_base: torch.Tensor,
|
||||
rms_eps: float,
|
||||
hc_pre_eps: float,
|
||||
hc_sinkhorn_eps: float,
|
||||
hc_post_mult_value: float,
|
||||
sinkhorn_repeat: int,
|
||||
n_splits: int = 1,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hc_mult = residual.shape[-2]
|
||||
hidden_size = residual.shape[-1]
|
||||
outer_shape = residual.shape[:-2]
|
||||
|
||||
# Create empty tensors with correct shapes for meta device / shape inference
|
||||
post_mix = torch.empty(
|
||||
*outer_shape,
|
||||
hc_mult,
|
||||
1,
|
||||
dtype=torch.float32,
|
||||
device=residual.device,
|
||||
)
|
||||
comb_mix = torch.empty(
|
||||
*outer_shape,
|
||||
hc_mult,
|
||||
hc_mult,
|
||||
dtype=torch.float32,
|
||||
device=residual.device,
|
||||
)
|
||||
layer_input = torch.empty(
|
||||
*outer_shape,
|
||||
hidden_size,
|
||||
dtype=torch.bfloat16,
|
||||
device=residual.device,
|
||||
)
|
||||
|
||||
return post_mix, comb_mix, layer_input
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
pass_configs={
|
||||
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
|
||||
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
||||
tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
|
||||
},
|
||||
)
|
||||
def mhc_post_tilelang(
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
d,
|
||||
x,
|
||||
hc: int,
|
||||
hidden: int,
|
||||
n_thr: int = 128,
|
||||
h_blk: int = 1024,
|
||||
) -> tilelang.JITKernel:
|
||||
# rename for shorter code
|
||||
n = T.dynamic("num_tokens")
|
||||
h = hidden
|
||||
|
||||
h_blk = math.gcd(hidden, h_blk)
|
||||
a: T.Tensor((n, hc, hc), T.float32) # type: ignore[no-redef, valid-type]
|
||||
b: T.Tensor((n, hc, h), T.bfloat16) # type: ignore[no-redef, valid-type]
|
||||
c: T.Tensor((n, hc), T.float32) # type: ignore[no-redef, valid-type]
|
||||
d: T.Tensor((n, h), T.bfloat16) # type: ignore[no-redef, valid-type]
|
||||
x: T.Tensor((n, hc, h), T.bfloat16) # type: ignore[no-redef, valid-type]
|
||||
with T.Kernel(n, threads=n_thr) as i_n:
|
||||
x_shared = T.alloc_shared((hc, h_blk), T.bfloat16)
|
||||
b_shared = T.alloc_shared((hc, h_blk), T.bfloat16)
|
||||
d_shared = T.alloc_shared(h_blk, T.bfloat16)
|
||||
|
||||
x_local = T.alloc_fragment((hc, h_blk), T.float32)
|
||||
b_local = T.alloc_fragment((hc, h_blk), T.float32)
|
||||
d_local = T.alloc_fragment(h_blk, T.float32)
|
||||
|
||||
a_local = T.alloc_fragment((hc, hc), T.float32)
|
||||
c_local = T.alloc_fragment(hc, T.float32)
|
||||
T.pdl_sync()
|
||||
T.copy(a[i_n, 0, 0], a_local)
|
||||
T.copy(c[i_n, 0], c_local)
|
||||
|
||||
for i0_h in T.Pipelined(T.ceildiv(h, h_blk), num_stages=2):
|
||||
T.copy(b[i_n, 0, i0_h * h_blk], b_shared)
|
||||
T.copy(d[i_n, i0_h * h_blk], d_shared)
|
||||
|
||||
T.copy(b_shared, b_local)
|
||||
T.copy(d_shared, d_local)
|
||||
for i_hco, i1_h in T.Parallel(hc, h_blk):
|
||||
x_local[i_hco, i1_h] = c_local[i_hco] * d_local[i1_h]
|
||||
for i_hci in T.serial(hc):
|
||||
x_local[i_hco, i1_h] += a_local[i_hci, i_hco] * b_local[i_hci, i1_h]
|
||||
T.copy(x_local, x_shared)
|
||||
|
||||
T.copy(x_shared, x[i_n, 0, i0_h * h_blk])
|
||||
T.pdl_trigger()
|
||||
|
||||
|
||||
def mhc_post(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
post_layer_mix: torch.Tensor,
|
||||
comb_res_mix: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
out = torch.empty_like(residual)
|
||||
mhc_post_tilelang(
|
||||
comb_res_mix,
|
||||
residual,
|
||||
post_layer_mix.squeeze(-1),
|
||||
x,
|
||||
out,
|
||||
residual.shape[-2],
|
||||
residual.shape[-1],
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _mhc_post_fake(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
post_layer_mix: torch.Tensor,
|
||||
comb_res_mix: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(residual)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="mhc_pre",
|
||||
op_func=mhc_pre,
|
||||
mutates_args=[],
|
||||
fake_impl=_mhc_pre_fake,
|
||||
)
|
||||
direct_register_custom_op(
|
||||
op_name="mhc_post",
|
||||
op_func=mhc_post,
|
||||
mutates_args=[],
|
||||
fake_impl=_mhc_post_fake,
|
||||
)
|
||||
@@ -32,6 +32,7 @@ QuantizationMethods = Literal[
|
||||
"inc",
|
||||
"mxfp4",
|
||||
"gpt_oss_mxfp4",
|
||||
"deepseek_v4_fp8",
|
||||
"cpu_awq",
|
||||
"online",
|
||||
# Below are values of the OnlineQuantScheme enum, specified as strings to
|
||||
@@ -112,6 +113,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
# lazy import to avoid triggering `torch.compile` too early
|
||||
from vllm.config.quantization import OnlineQuantScheme
|
||||
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
|
||||
from vllm.model_executor.models.deepseek_v4 import DeepseekV4FP8Config
|
||||
|
||||
from .awq import AWQConfig
|
||||
from .awq_marlin import AWQMarlinConfig
|
||||
@@ -163,6 +165,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
"inc": INCConfig,
|
||||
"mxfp4": Mxfp4Config,
|
||||
"gpt_oss_mxfp4": GptOssMxfp4Config,
|
||||
"deepseek_v4_fp8": DeepseekV4FP8Config,
|
||||
"cpu_awq": CPUAWQConfig,
|
||||
"humming": HummingConfig,
|
||||
"online": OnlineQuantizationConfig,
|
||||
|
||||
+1
@@ -265,6 +265,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.is_monolithic
|
||||
assert self.moe_kernel is not None
|
||||
|
||||
+1
@@ -305,6 +305,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet."
|
||||
assert layer.activation in (
|
||||
|
||||
+1
@@ -367,6 +367,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
|
||||
+1
@@ -168,6 +168,7 @@ class CompressedTensorsW8A8Mxfp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
|
||||
+1
@@ -517,6 +517,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.kernel_backend == "Flashinfer"
|
||||
return flashinfer_trtllm_mxint4_moe(
|
||||
|
||||
@@ -269,6 +269,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.is_scale_e8m0 = getattr(quant_config, "is_scale_e8m0", False)
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||
@@ -362,6 +363,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
input_size_per_partition,
|
||||
self.weight_block_size,
|
||||
weight_loader,
|
||||
scale_dtype=(torch.float8_e8m0fnu if self.is_scale_e8m0 else None),
|
||||
)
|
||||
# The weight_scale_inv name is intentional for deepseekv3
|
||||
layer.register_parameter("weight_scale_inv", scale)
|
||||
@@ -866,6 +868,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.is_monolithic
|
||||
assert self.moe_kernel is not None
|
||||
|
||||
@@ -950,6 +950,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.is_monolithic
|
||||
assert self.moe_kernel is not None
|
||||
@@ -1442,6 +1443,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.is_monolithic
|
||||
assert self.moe_kernel is not None
|
||||
@@ -1920,6 +1922,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from flashinfer.fused_moe.core import (
|
||||
ActivationType,
|
||||
|
||||
@@ -20,10 +20,12 @@ from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
|
||||
TRITON_BACKENDS,
|
||||
Mxfp4MoeBackend,
|
||||
convert_gpt_oss_weight_to_mxfp4_moe_kernel_format,
|
||||
convert_weight_to_mxfp4_moe_kernel_format,
|
||||
make_mxfp4_moe_kernel,
|
||||
make_mxfp4_moe_quant_config,
|
||||
mxfp4_round_up_hidden_size_and_intermediate_size,
|
||||
select_gpt_oss_mxfp4_moe_backend,
|
||||
select_mxfp4_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@@ -217,6 +219,7 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
w13_weight_scale.quant_method = "block"
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
@@ -242,6 +245,7 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
w2_weight_scale.quant_method = "block"
|
||||
|
||||
if self.moe.has_bias:
|
||||
w13_bias = torch.nn.Parameter(
|
||||
@@ -397,6 +401,9 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase):
|
||||
w2_scale=w2_scale,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
gemm1_alpha=1.702,
|
||||
gemm1_beta=1.0,
|
||||
swiglu_limit=7.0,
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
@@ -437,6 +444,332 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.is_monolithic
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
router_logits=router_logits,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
"""MXFP4 MoE quantization method."""
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
self.weight_dtype = "mxfp4"
|
||||
self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe)
|
||||
|
||||
self.max_capture_size = (
|
||||
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
|
||||
)
|
||||
|
||||
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
|
||||
self.moe_kernel: mk.FusedMoEKernel | None = None
|
||||
|
||||
# Used for triton kernel precision configs
|
||||
self.w13_precision_config = None
|
||||
self.w2_precision_config = None
|
||||
|
||||
@property
|
||||
def skip_forward_padding(self) -> bool:
|
||||
# SM100_FI_MXFP4_MXFP8_TRTLLM supports padding with mxfp8 quant
|
||||
# so can skip the padding in the forward before applying the moe method
|
||||
return self.mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8
|
||||
|
||||
def maybe_roundup_sizes(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
act_dtype: torch.dtype,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
) -> tuple[int, int]:
|
||||
hidden_size, intermediate_size_per_partition = super().maybe_roundup_sizes(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size_per_partition=intermediate_size_per_partition,
|
||||
act_dtype=act_dtype,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
)
|
||||
return mxfp4_round_up_hidden_size_and_intermediate_size(
|
||||
self.mxfp4_backend, hidden_size, intermediate_size_per_partition
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
self.num_experts = num_experts
|
||||
weight_dtype = torch.uint8
|
||||
scale_dtype = torch.uint8
|
||||
mxfp4_block = 32
|
||||
|
||||
layer.params_dtype = params_dtype
|
||||
layer.num_experts = num_experts
|
||||
self.intermediate_size = intermediate_size_per_partition
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // 2,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // mxfp4_block,
|
||||
dtype=scale_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
w13_weight_scale.quant_method = "block"
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // 2,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // mxfp4_block,
|
||||
dtype=scale_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
w2_weight_scale.quant_method = "block"
|
||||
|
||||
if self.moe.has_bias:
|
||||
w13_bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_bias", w13_bias)
|
||||
set_weight_attrs(w13_bias, extra_weight_attrs)
|
||||
|
||||
w2_bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_bias", w2_bias)
|
||||
set_weight_attrs(w2_bias, extra_weight_attrs)
|
||||
|
||||
def _setup_kernel(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
w13: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w13_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w13_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
num_experts = self.num_experts
|
||||
intermediate_size = self.intermediate_size
|
||||
hidden_size = self.hidden_size
|
||||
sf_block_size = 32
|
||||
|
||||
# Shape assertions
|
||||
assert (
|
||||
w13.dim() == 3
|
||||
and w13.shape[0] == num_experts
|
||||
and w13.shape[1] == intermediate_size * 2
|
||||
and w13.shape[2] == hidden_size // 2
|
||||
)
|
||||
assert (
|
||||
w13_scale.dim() == 3
|
||||
and w13_scale.shape[0] == num_experts
|
||||
and w13_scale.shape[1] == intermediate_size * 2
|
||||
and w13_scale.shape[2] == hidden_size // sf_block_size
|
||||
)
|
||||
assert (
|
||||
w2.dim() == 3
|
||||
and w2.shape[0] == num_experts
|
||||
and w2.shape[1] == hidden_size
|
||||
and w2.shape[2] == intermediate_size // 2
|
||||
)
|
||||
assert (
|
||||
w2_scale.dim() == 3
|
||||
and w2_scale.shape[1] == hidden_size
|
||||
and w2_scale.shape[2] == intermediate_size // sf_block_size
|
||||
)
|
||||
if w13_bias is not None:
|
||||
assert (
|
||||
w13_bias.dim() == 2
|
||||
and w13_bias.shape[0] == num_experts
|
||||
and w13_bias.shape[1] == intermediate_size * 2
|
||||
)
|
||||
if w2_bias is not None:
|
||||
assert (
|
||||
w2_bias.dim() == 2
|
||||
and w2_bias.shape[0] == num_experts
|
||||
and w2_bias.shape[1] == hidden_size
|
||||
)
|
||||
|
||||
# Convert weights to kernel format
|
||||
w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = (
|
||||
convert_weight_to_mxfp4_moe_kernel_format(
|
||||
mxfp4_backend=self.mxfp4_backend,
|
||||
layer=layer,
|
||||
w13_weight=w13,
|
||||
w2_weight=w2,
|
||||
w13_weight_scale=w13_scale,
|
||||
w2_weight_scale=w2_scale,
|
||||
w13_bias=w13_bias,
|
||||
w2_bias=w2_bias,
|
||||
_cache_permute_indices=self._cache_permute_indices,
|
||||
)
|
||||
)
|
||||
|
||||
# For TRITON backends, weights are wrapped tensors from triton_kernels
|
||||
# that don't support .detach(). Manually assign parameters.
|
||||
if self.mxfp4_backend not in TRITON_BACKENDS:
|
||||
replace_parameter(layer, "w13_weight", w13)
|
||||
replace_parameter(layer, "w2_weight", w2)
|
||||
replace_parameter(layer, "w13_weight_scale", w13_scale)
|
||||
replace_parameter(layer, "w2_weight_scale", w2_scale)
|
||||
else:
|
||||
layer.w13_weight = w13
|
||||
layer.w2_weight = w2
|
||||
self.w13_precision_config = w13_scale
|
||||
self.w2_precision_config = w2_scale
|
||||
|
||||
if w13_bias is not None and w2_bias is not None:
|
||||
replace_parameter(layer, "w13_bias", w13_bias)
|
||||
replace_parameter(layer, "w2_bias", w2_bias)
|
||||
|
||||
# Build quant config
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
|
||||
# Build kernel (modular or monolithic)
|
||||
if self.moe_quant_config is not None and self.experts_cls is not None:
|
||||
self.moe_kernel = make_mxfp4_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
mxfp4_backend=self.mxfp4_backend,
|
||||
experts_cls=self.experts_cls,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
shared_experts=layer.shared_experts,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
w13 = layer.w13_weight
|
||||
w2 = layer.w2_weight
|
||||
w13_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
w13_bias = getattr(layer, "w13_bias", None)
|
||||
w2_bias = getattr(layer, "w2_bias", None)
|
||||
|
||||
if self.mxfp4_backend == Mxfp4MoeBackend.NONE:
|
||||
return
|
||||
|
||||
self._setup_kernel(layer, w13, w2, w13_scale, w2_scale, w13_bias, w2_bias)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
w1_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
w1_bias = getattr(layer, "w13_bias", None)
|
||||
w2_bias = getattr(layer, "w2_bias", None)
|
||||
swiglu_limit = getattr(layer, "swiglu_limit", None)
|
||||
|
||||
if self.mxfp4_backend in TRITON_BACKENDS:
|
||||
assert self.w13_precision_config is not None
|
||||
assert self.w2_precision_config is not None
|
||||
w1_scale = self.w13_precision_config
|
||||
w2_scale = self.w2_precision_config
|
||||
|
||||
return make_mxfp4_moe_quant_config(
|
||||
mxfp4_backend=self.mxfp4_backend,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel "
|
||||
"initialization logic. This function should not be called."
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
assert not self.is_monolithic
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=layer.expert_map,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.is_monolithic
|
||||
assert self.moe_kernel is not None
|
||||
|
||||
@@ -130,6 +130,7 @@ class OnlineMoEMethodBase(FusedMoEMethodBase):
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.is_monolithic
|
||||
assert self.moe_kernel is not None
|
||||
|
||||
@@ -1457,6 +1457,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.is_monolithic
|
||||
assert self.moe_kernel is not None
|
||||
|
||||
@@ -149,6 +149,148 @@ def _per_token_group_quant_fp8(
|
||||
tl.store(y_s_ptr, y_s)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _silu_mul_quant_fp8_packed_kernel(
|
||||
input_ptr,
|
||||
output_q_ptr,
|
||||
output_scale_ptr,
|
||||
M,
|
||||
input_stride_m,
|
||||
output_q_stride_m,
|
||||
output_scale_stride_k,
|
||||
clamp_limit,
|
||||
N: tl.constexpr,
|
||||
NUM_GROUPS: tl.constexpr,
|
||||
fp8_min: tl.constexpr,
|
||||
fp8_max: tl.constexpr,
|
||||
GROUP_SIZE: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
HAS_CLAMP: tl.constexpr,
|
||||
):
|
||||
N_2: tl.constexpr = N // 2
|
||||
|
||||
pid_pack = tl.program_id(0)
|
||||
pid_m = tl.program_id(1)
|
||||
m_offset = pid_m * BLOCK_M
|
||||
|
||||
if m_offset >= M:
|
||||
return
|
||||
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, GROUP_SIZE)
|
||||
row_mask = (m_offset + offs_m) < M
|
||||
|
||||
base_row_offset = (m_offset + offs_m[:, None]) * input_stride_m
|
||||
base_out_offset = (m_offset + offs_m[:, None]) * output_q_stride_m
|
||||
|
||||
packed_scale = tl.zeros((BLOCK_M,), dtype=tl.int32)
|
||||
|
||||
for pack_idx in tl.static_range(4):
|
||||
group_id = pid_pack * 4 + pack_idx
|
||||
|
||||
if group_id < NUM_GROUPS:
|
||||
n_offset = group_id * GROUP_SIZE
|
||||
|
||||
act_ptrs = input_ptr + base_row_offset + n_offset + offs_n[None, :]
|
||||
act_in = tl.load(act_ptrs, mask=row_mask[:, None], other=0.0)
|
||||
|
||||
mul_ptrs = act_ptrs + N_2
|
||||
mul_in = tl.load(mul_ptrs, mask=row_mask[:, None], other=0.0)
|
||||
|
||||
act_f32 = act_in.to(tl.float32)
|
||||
mul_f32 = mul_in.to(tl.float32)
|
||||
|
||||
if HAS_CLAMP:
|
||||
act_f32 = tl.minimum(act_f32, clamp_limit)
|
||||
mul_f32 = tl.clamp(mul_f32, -clamp_limit, clamp_limit)
|
||||
|
||||
y = (act_f32 / (1.0 + tl.exp(-act_f32))) * mul_f32
|
||||
# Round through bf16 to match unfused precision path
|
||||
y = y.to(tl.bfloat16).to(tl.float32)
|
||||
|
||||
absmax = tl.max(tl.abs(y), axis=1)
|
||||
|
||||
scale_raw = tl.maximum(absmax / fp8_max, 1e-10)
|
||||
exponent = tl.ceil(tl.log2(scale_raw))
|
||||
scale = tl.math.exp2(exponent)
|
||||
|
||||
y_q = tl.clamp(y / scale[:, None], fp8_min, fp8_max)
|
||||
|
||||
out_q_ptrs = output_q_ptr + base_out_offset + n_offset + offs_n[None, :]
|
||||
tl.store(
|
||||
out_q_ptrs,
|
||||
y_q.to(output_q_ptr.dtype.element_ty),
|
||||
mask=row_mask[:, None],
|
||||
)
|
||||
|
||||
exponent_biased = tl.clamp(exponent + 127.0, 0.0, 255.0).to(tl.int32)
|
||||
packed_scale = packed_scale | (exponent_biased << (pack_idx * 8))
|
||||
|
||||
scale_ptrs = output_scale_ptr + pid_pack * output_scale_stride_k + m_offset + offs_m
|
||||
tl.store(scale_ptrs, packed_scale, mask=row_mask)
|
||||
|
||||
|
||||
def silu_mul_quant_fp8_packed_triton(
|
||||
input: torch.Tensor,
|
||||
group_size: int = 128,
|
||||
output_q: torch.Tensor | None = None,
|
||||
clamp_limit: float | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert input.dim() == 2
|
||||
assert input.is_contiguous()
|
||||
|
||||
M, N = input.shape
|
||||
N_2 = N // 2
|
||||
|
||||
assert N_2 % group_size == 0
|
||||
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
finfo = torch.finfo(fp8_dtype)
|
||||
fp8_min, fp8_max = finfo.min, finfo.max
|
||||
|
||||
num_groups_per_row = N_2 // group_size
|
||||
num_packed_groups = (num_groups_per_row + 3) // 4
|
||||
tma_aligned_M = ((M + 3) // 4) * 4
|
||||
|
||||
if output_q is None:
|
||||
output_q = torch.empty((M, N_2), dtype=fp8_dtype, device=input.device)
|
||||
|
||||
output_scale_packed = torch.zeros(
|
||||
(num_packed_groups, tma_aligned_M),
|
||||
dtype=torch.int32,
|
||||
device=input.device,
|
||||
).T[:M, :]
|
||||
|
||||
BLOCK_M = 8
|
||||
grid = (num_packed_groups, (M + BLOCK_M - 1) // BLOCK_M)
|
||||
|
||||
num_warps = max(4, group_size // 32)
|
||||
num_stages = 2
|
||||
|
||||
has_clamp = clamp_limit is not None
|
||||
_silu_mul_quant_fp8_packed_kernel[grid](
|
||||
input,
|
||||
output_q,
|
||||
output_scale_packed,
|
||||
M,
|
||||
input.stride(0),
|
||||
output_q.stride(0),
|
||||
output_scale_packed.stride(1),
|
||||
clamp_limit if has_clamp else 0.0,
|
||||
N=N,
|
||||
NUM_GROUPS=num_groups_per_row,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
GROUP_SIZE=group_size,
|
||||
BLOCK_M=BLOCK_M,
|
||||
HAS_CLAMP=has_clamp,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
|
||||
return output_q, output_scale_packed
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _silu_mul_per_token_group_quant_fp8_colmajor(
|
||||
y_ptr, # [M, N]
|
||||
@@ -823,19 +965,65 @@ def requant_weight_ue8m0_inplace(
|
||||
s_old.copy_(s_requant)
|
||||
|
||||
|
||||
def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor:
|
||||
"""Upcast E8M0 (exponent-only) scale to float32.
|
||||
|
||||
E8M0 stores only the 8-bit biased exponent (bias=127). To convert
|
||||
to float32 we place those 8 bits into the exponent field of an
|
||||
IEEE-754 float32 (bits 23-30) with sign=0 and mantissa=0.
|
||||
"""
|
||||
exp_bits = scale.view(torch.uint8).to(torch.int32)
|
||||
fp32_bits = exp_bits << 23
|
||||
return fp32_bits.view(torch.float32)
|
||||
|
||||
|
||||
def deepgemm_post_process_fp8_weight_block(
|
||||
wq: torch.Tensor, ws: torch.Tensor, quant_block_shape: tuple[int], use_e8m0: bool
|
||||
wq: torch.Tensor,
|
||||
ws: torch.Tensor,
|
||||
quant_block_shape: tuple[int, ...],
|
||||
use_e8m0: bool,
|
||||
is_bmm: bool = False,
|
||||
bmm_batch_size: int = 0,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert wq.dtype == torch.float8_e4m3fn, (
|
||||
"Expected quantized tensor dtype "
|
||||
f"to be torch.float8_e4m3fn, got {wq.dtype} instead."
|
||||
)
|
||||
assert ws.dtype == torch.float32, (
|
||||
f"Expected tensor scales dtype to be torch.float32, got {ws.dtype} instead"
|
||||
)
|
||||
|
||||
if use_e8m0:
|
||||
requant_weight_ue8m0_inplace(wq, ws, block_size=quant_block_shape)
|
||||
if ws.dtype == torch.float8_e8m0fnu:
|
||||
# Scales already in E8M0 from checkpoint — upcast to fp32
|
||||
# and skip requantization (weights already have power-of-two scales).
|
||||
ws = _upcast_e8m0_to_fp32(ws)
|
||||
else:
|
||||
assert ws.dtype == torch.float32, (
|
||||
f"Expected tensor scales dtype to be torch.float32 or "
|
||||
f"torch.float8_e8m0fnu, got {ws.dtype} instead"
|
||||
)
|
||||
if use_e8m0:
|
||||
requant_weight_ue8m0_inplace(wq, ws, block_size=quant_block_shape)
|
||||
|
||||
if is_bmm:
|
||||
# Reshape 2D weight/scale to 3D for grouped BMM (einsum):
|
||||
# wq: (g*r, d) -> (g, r, d)
|
||||
# ws: (g*r/128, d/128) -> (g, r/128, d/128)
|
||||
g = bmm_batch_size
|
||||
assert wq.ndim == 2 and ws.ndim == 2
|
||||
d = wq.size(1)
|
||||
r = wq.size(0) // g
|
||||
wq = wq.view(g, r, d)
|
||||
ws = ws.view(g, r // quant_block_shape[0], d // quant_block_shape[1])
|
||||
# Pre-transform scale with recipe=(1, 128, 128) to broadcast + pack
|
||||
# into TMA-aligned UE8M0 (INT32) layout. At runtime fp8_einsum uses
|
||||
# recipe=(1, 1, 128) which sees INT dtype and skips re-transform.
|
||||
dg_ws = transform_sf_into_required_layout(
|
||||
sf=ws,
|
||||
mn=r,
|
||||
k=d,
|
||||
recipe=(1, quant_block_shape[0], quant_block_shape[1]),
|
||||
num_groups=g,
|
||||
is_sfa=False,
|
||||
)
|
||||
return wq, dg_ws
|
||||
|
||||
original_ndim = wq.ndim
|
||||
if wq.ndim == 2:
|
||||
@@ -984,11 +1172,13 @@ def create_fp8_scale_parameter(
|
||||
input_size_per_partition: int,
|
||||
block_size: list[int] | None,
|
||||
weight_loader: Callable | None,
|
||||
scale_dtype: torch.dtype | None = None,
|
||||
) -> torch.nn.Parameter:
|
||||
"""Create scale parameter based on quantization strategy."""
|
||||
dtype = scale_dtype if scale_dtype is not None else torch.float32
|
||||
if parameter_type == ChannelQuantScaleParameter:
|
||||
scale = parameter_type(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=dtype),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
@@ -1000,7 +1190,7 @@ def create_fp8_scale_parameter(
|
||||
data=torch.empty(
|
||||
(output_size_per_partition + block_n - 1) // block_n,
|
||||
(input_size_per_partition + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
dtype=dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
@@ -1008,13 +1198,14 @@ def create_fp8_scale_parameter(
|
||||
)
|
||||
elif parameter_type == PerTensorScaleParameter:
|
||||
scale = parameter_type(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
data=torch.empty(len(output_partition_sizes), dtype=dtype),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown parameter type: {parameter_type}")
|
||||
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
if dtype == torch.float32:
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
return scale
|
||||
|
||||
|
||||
@@ -7,7 +7,10 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from .base import RotaryEmbedding
|
||||
from .deepseek_scaling_rope import DeepseekScalingRotaryEmbedding
|
||||
from .deepseek_scaling_rope import (
|
||||
DeepseekScalingRotaryEmbedding,
|
||||
DeepseekV4ScalingRotaryEmbedding,
|
||||
)
|
||||
from .dual_chunk_rope import DualChunkRotaryEmbedding
|
||||
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
|
||||
from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding
|
||||
@@ -60,11 +63,13 @@ def get_rope(
|
||||
rope_parameters = rope_parameters or {}
|
||||
base = rope_parameters.get("rope_theta", 10000)
|
||||
scaling_type = rope_parameters.get("rope_type", "default")
|
||||
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
|
||||
|
||||
if partial_rotary_factor <= 0.0 or partial_rotary_factor > 1.0:
|
||||
raise ValueError(f"{partial_rotary_factor=} must be between 0.0 and 1.0")
|
||||
rotary_dim = int(head_size * partial_rotary_factor)
|
||||
if rotary_dim := rope_parameters.get("rope_dim", None):
|
||||
pass
|
||||
else:
|
||||
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
|
||||
if partial_rotary_factor <= 0.0 or partial_rotary_factor > 1.0:
|
||||
raise ValueError(f"{partial_rotary_factor=} must be between 0.0 and 1.0")
|
||||
rotary_dim = int(head_size * partial_rotary_factor)
|
||||
|
||||
key = (
|
||||
head_size,
|
||||
@@ -289,7 +294,11 @@ def get_rope(
|
||||
"mscale_all_dim",
|
||||
)
|
||||
}
|
||||
rotary_emb = DeepseekScalingRotaryEmbedding(
|
||||
if rope_parameters.get("is_deepseek_v4", False):
|
||||
cls = DeepseekV4ScalingRotaryEmbedding
|
||||
else:
|
||||
cls = DeepseekScalingRotaryEmbedding
|
||||
rotary_emb = cls(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
original_max_position,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user