Merge branch 'main' into wentao-fix-NixlConnector-PD-+-Spec-Decode-acceptance-(2-GPUs)

This commit is contained in:
yewentao256
2026-06-03 18:14:26 +00:00
50 changed files with 1855 additions and 681 deletions
+15 -12
View File
@@ -112,6 +112,8 @@ endif()
#
# spinloop extension (pure CXX; must stay above the non-CUDA device branch so
# CPU builds define the target before the early return)
# This extension requires the Python stable ABI (SABI) 3.11 or newer because it
# relies on Py_buffer support. On lower Python versions the resulting load
# failure is handled gracefully on the vLLM side.
#
set(VLLM_SPINLOOP_EXT_SRC "csrc/spinloop.cpp")
set(SPINLOOP_COMPILE_FLAGS "")
@@ -309,14 +311,9 @@ set(VLLM_EXT_SRC
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/custom_all_reduce.cu"
"csrc/torch_bindings.cpp"
"csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")
"csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC
"csrc/minimax_reduce_rms_kernel.cu")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
@@ -503,12 +500,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND ES_MXFP8_GROUPED_MM_ARCHS)
set(SRCS
"csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu"
"csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu")
"csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu"
"csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${ES_MXFP8_GROUPED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_ES_MXFP8_GROUPED_MM_SM100=1")
message(STATUS "Building ES MXFP8 grouped kernels for archs: ${ES_MXFP8_GROUPED_MM_ARCHS}")
else()
@@ -598,7 +595,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
if (VLLM_GPU_LANG STREQUAL "HIP")
# Add QuickReduce kernels
# Add QuickReduce kernels (ROCm-only; not part of stable ABI migration).
list(APPEND VLLM_EXT_SRC
"csrc/custom_quickreduce.cu"
)
@@ -649,7 +646,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/attention/paged_attention_v1.cu"
"csrc/libtorch_stable/attention/paged_attention_v2.cu"
"csrc/libtorch_stable/cache_kernels.cu"
"csrc/libtorch_stable/cache_kernels_fused.cu")
"csrc/libtorch_stable/cache_kernels.cu"
"csrc/libtorch_stable/cache_kernels_fused.cu"
"csrc/libtorch_stable/custom_all_reduce.cu"
"csrc/libtorch_stable/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
@@ -659,7 +659,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/libtorch_stable/permute_cols.cu"
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu")
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu"
"csrc/libtorch_stable/minimax_reduce_rms_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${VLLM_STABLE_EXT_SRC}"
@@ -1,7 +1,11 @@
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include "torch_utils.h"
#include <torch/csrc/stable/macros.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/csrc/stable/device.h>
#include "custom_all_reduce.cuh"
@@ -11,7 +15,7 @@ using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank,
torch::stable::Tensor& rank_data, int64_t rank,
bool fully_connected) {
int world_size = fake_ipc_ptrs.size();
if (world_size > 8)
@@ -25,9 +29,9 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
for (int i = 0; i < world_size; i++) {
ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
}
return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
rank_data.numel(), rank, world_size,
fully_connected);
return (fptr_t) new vllm::CustomAllreduce(
ipc_ptrs, rank_data.mutable_data_ptr(), rank_data.numel(), rank,
world_size, fully_connected);
}
/**
@@ -46,10 +50,14 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size());
bool _is_weak_contiguous(torch::stable::Tensor& t) {
  if (!t.is_contiguous()) {
    // Not contiguous: still accept the tensor when the bytes remaining in its
    // storage after the storage offset are exactly the bytes the tensor
    // itself needs (numel * element_size).
    int64_t storage_nbytes = 0;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_storage_size(t.get(), &storage_nbytes));
    const auto offset_nbytes = t.storage_offset() * t.element_size();
    const auto tensor_nbytes =
        static_cast<int64_t>(t.numel() * t.element_size());
    return storage_nbytes - offset_nbytes == tensor_nbytes;
  }
  // Contiguous tensors always qualify.
  return true;
}
/**
@@ -59,42 +67,45 @@ bool _is_weak_contiguous(torch::Tensor& t) {
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
void all_reduce(fptr_t _fa, torch::stable::Tensor& inp,
torch::stable::Tensor& out, fptr_t _reg_buffer,
int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream();
const torch::stable::accelerator::DeviceGuard device_guard(
inp.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(inp.get_device_index());
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(_is_weak_contiguous(out));
TORCH_CHECK(_is_weak_contiguous(inp));
STD_TORCH_CHECK((inp.scalar_type()) == (out.scalar_type()));
STD_TORCH_CHECK((inp.numel()) == (out.numel()));
STD_TORCH_CHECK(_is_weak_contiguous(out));
STD_TORCH_CHECK(_is_weak_contiguous(inp));
auto input_size = inp.numel() * inp.element_size();
auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
if (reg_buffer) {
TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
cudaMemcpyDeviceToDevice, stream));
STD_TORCH_CHECK((input_size) <= (reg_buffer_sz_bytes));
STD_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.const_data_ptr(), input_size,
cudaMemcpyDeviceToDevice, stream));
} else {
reg_buffer = inp.data_ptr();
reg_buffer = inp.mutable_data_ptr();
}
switch (out.scalar_type()) {
case at::ScalarType::Float: {
case torch::headeronly::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
reinterpret_cast<float*>(out.data_ptr()),
reinterpret_cast<float*>(out.mutable_data_ptr()),
out.numel());
break;
}
case at::ScalarType::Half: {
case torch::headeronly::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
reinterpret_cast<half*>(out.data_ptr()), out.numel());
reinterpret_cast<half*>(out.mutable_data_ptr()),
out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
case torch::headeronly::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
reinterpret_cast<nv_bfloat16*>(out.mutable_data_ptr()), out.numel());
break;
}
#endif
@@ -112,7 +123,7 @@ int64_t meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
STD_TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
void* ipc_ptrs[8];
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
@@ -143,47 +154,49 @@ void register_graph_buffers(fptr_t _fa,
fa->register_graph_buffers(bytes, offsets);
}
std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
std::tuple<fptr_t, torch::stable::Tensor> allocate_shared_buffer_and_handle(
int64_t size) {
auto device_index = c10::cuda::current_device();
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
int device_index;
STD_CUDA_CHECK(cudaGetDevice(&device_index));
const torch::stable::accelerator::DeviceGuard device_guard(device_index);
void* buffer;
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
auto stream = c10::cuda::getCurrentCUDAStream().stream();
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
const cudaStream_t stream = get_current_cuda_stream(device_index);
STD_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
// Allocate buffer
#if defined(USE_ROCM)
// data buffers need to be "uncached" for signal on MI200
AT_CUDA_CHECK(
STD_CUDA_CHECK(
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
#else
AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
STD_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
#endif
AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
STD_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
STD_CUDA_CHECK(cudaStreamSynchronize(stream));
STD_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
// Create IPC memhandle for the allocated buffer.
// Will use it in open_mem_handle.
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handle =
torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
AT_CUDA_CHECK(
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer));
auto handle = torch::stable::empty(
{static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))},
torch::headeronly::ScalarType::Byte, std::nullopt,
torch::stable::Device(torch::stable::DeviceType::CPU));
STD_CUDA_CHECK(cudaIpcGetMemHandle(
(cudaIpcMemHandle_t*)handle.mutable_data_ptr(), buffer));
return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
}
fptr_t open_mem_handle(torch::Tensor& mem_handle) {
fptr_t open_mem_handle(torch::stable::Tensor& mem_handle) {
void* ipc_ptr;
AT_CUDA_CHECK(cudaIpcOpenMemHandle(
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()),
STD_CUDA_CHECK(cudaIpcOpenMemHandle(
(void**)&ipc_ptr,
*((const cudaIpcMemHandle_t*)mem_handle.const_data_ptr()),
cudaIpcMemLazyEnablePeerAccess));
return reinterpret_cast<fptr_t>(ipc_ptr);
}
void free_shared_buffer(fptr_t buffer) {
AT_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
STD_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}
@@ -28,7 +28,20 @@
* [bs*576, bs*576 + bs*8): UE8M0 scales, 7 real + 1 pad per token
*/
#include "torch_utils.h"
#include <torch/csrc/stable/macros.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/csrc/stable/device.h>
#include <cmath>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "type_convert.cuh"
#ifndef USE_ROCM
#include <cuda_fp8.h>
#else
@@ -37,14 +50,6 @@
#include <cuda_runtime.h>
#include <type_traits>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/cuda.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "type_convert.cuh"
#ifndef FINAL_MASK
#ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL
@@ -70,7 +75,7 @@ namespace deepseek_v4_fused_ops {
namespace {
inline int getSMVersion() {
auto* props = at::cuda::getCurrentDeviceProperties();
auto* props = get_device_prop();
return props->major * 10 + props->minor;
}
} // namespace
@@ -564,7 +569,7 @@ static void launchFusedDeepseekV4Templated(
// bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert<BFloat16> is
// unavailable there. Refuse the launch loudly instead of silently
// skipping the work.
TORCH_CHECK(
STD_TORCH_CHECK(
sm_version >= 80,
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert requires sm_80+ "
"(Ampere or newer); got sm_",
@@ -635,7 +640,7 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
DISPATCH(64)
DISPATCH(128)
default:
TORCH_CHECK(false,
STD_TORCH_CHECK(false,
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert: "
"unsupported num_heads_q_padded=",
num_heads_q_padded,
@@ -650,71 +655,80 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
// ────────────────────────────────────────────────────────────────────────────
// Torch op wrapper
// ────────────────────────────────────────────────────────────────────────────
torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::Tensor const& q_in, // [N, num_heads_q, 512] bf16
torch::Tensor const& kv, // [N, 512] bf16 (read-only)
torch::Tensor& k_cache, // [num_blocks, block_bytes] uint8
torch::Tensor const& slot_mapping, // [N] int64
torch::Tensor const& position_ids, // [N] int64
torch::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16
int64_t q_head_padded, // padded Q head count for output
torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::stable::Tensor const& q_in, // [N, num_heads_q, 512] bf16
torch::stable::Tensor const& kv, // [N, 512] bf16 (read-only)
torch::stable::Tensor& k_cache, // [num_blocks, block_bytes] uint8
torch::stable::Tensor const& slot_mapping, // [N] int64
torch::stable::Tensor const& position_ids, // [N] int64
torch::stable::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16
int64_t q_head_padded, // padded Q head count for output
double eps, int64_t cache_block_size) {
TORCH_CHECK(q_in.is_cuda() && q_in.is_contiguous(),
"q_in must be contiguous CUDA");
TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA");
TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA");
TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64,
"slot_mapping must be int64 CUDA");
TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64,
"position_ids must be int64 CUDA");
TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA");
TORCH_CHECK(q_in.dim() == 3 && q_in.size(2) == 512,
"q_in shape [N, num_heads_q, 512]");
TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
TORCH_CHECK(q_in.dtype() == kv.dtype(), "q_in and kv dtype must match");
TORCH_CHECK(q_head_padded >= q_in.size(1),
"q_head_padded must be >= q_in.size(1) (num_heads_q)");
TORCH_CHECK(k_cache.dtype() == torch::kUInt8, "k_cache must be uint8");
TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
"cos_sin_cache shape [max_pos, 64]");
TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32,
"cos_sin_cache must be float32");
STD_TORCH_CHECK(q_in.device().is_cuda() && q_in.is_contiguous(),
"q_in must be contiguous CUDA");
STD_TORCH_CHECK(kv.device().is_cuda() && kv.is_contiguous(),
"kv must be contiguous CUDA");
STD_TORCH_CHECK(k_cache.device().is_cuda(), "k_cache must be CUDA");
STD_TORCH_CHECK(slot_mapping.device().is_cuda() &&
slot_mapping.scalar_type() ==
torch::headeronly::ScalarType::Long,
"slot_mapping must be int64 CUDA");
STD_TORCH_CHECK(position_ids.device().is_cuda() &&
position_ids.scalar_type() ==
torch::headeronly::ScalarType::Long,
"position_ids must be int64 CUDA");
STD_TORCH_CHECK(cos_sin_cache.device().is_cuda(), "cos_sin_cache must be CUDA");
STD_TORCH_CHECK(q_in.dim() == 3 && q_in.size(2) == 512,
"q_in shape [N, num_heads_q, 512]");
STD_TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
STD_TORCH_CHECK(q_in.scalar_type() == kv.scalar_type(),
"q_in and kv dtype must match");
STD_TORCH_CHECK(q_head_padded >= q_in.size(1),
"q_head_padded must be >= q_in.size(1) (num_heads_q)");
STD_TORCH_CHECK(k_cache.scalar_type() == torch::headeronly::ScalarType::Byte,
"k_cache must be uint8");
STD_TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
"cos_sin_cache shape [max_pos, 64]");
STD_TORCH_CHECK(cos_sin_cache.scalar_type() ==
torch::headeronly::ScalarType::Float,
"cos_sin_cache must be float32");
// With DP padding, slot_mapping can be shorter than q/kv/positions.
// Q-norm+RoPE runs on all q.size(0) rows (downstream attention uses them);
// KV quant+insert runs only on the first slot_mapping.size(0) rows.
int const num_tokens_full = static_cast<int>(q_in.size(0));
int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
TORCH_CHECK(static_cast<int>(kv.size(0)) == num_tokens_full &&
static_cast<int>(position_ids.size(0)) == num_tokens_full,
"q/kv/position_ids row counts must match");
TORCH_CHECK(num_tokens_insert <= num_tokens_full,
"slot_mapping must not exceed q row count");
STD_TORCH_CHECK(static_cast<int>(kv.size(0)) == num_tokens_full &&
static_cast<int>(position_ids.size(0)) == num_tokens_full,
"q/kv/position_ids row counts must match");
STD_TORCH_CHECK(num_tokens_insert <= num_tokens_full,
"slot_mapping must not exceed q row count");
int const num_heads_q = static_cast<int>(q_in.size(1));
int const num_heads_q_padded = static_cast<int>(q_head_padded);
int const cache_block_size_i = static_cast<int>(cache_block_size);
int const kv_block_stride = static_cast<int>(k_cache.stride(0));
at::cuda::OptionalCUDAGuard device_guard(device_of(q_in));
auto stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
q_in.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(q_in.get_device_index());
// Allocate the padded q output. The kernel writes every element (live
// region gets RMSNorm+RoPE; pad region gets zeros), so `empty` is safe.
torch::Tensor q_out = torch::empty(
{q_in.size(0), q_head_padded, q_in.size(2)}, q_in.options());
auto q_out = torch::stable::new_empty(
q_in, {q_in.size(0), q_head_padded, q_in.size(2)}, q_in.scalar_type());
VLLM_DISPATCH_HALF_TYPES(
VLLM_STABLE_DISPATCH_HALF_TYPES(
q_in.scalar_type(), "fused_deepseek_v4_qnorm_rope_kv_insert", [&] {
using qkv_scalar_t = scalar_t;
vllm::deepseek_v4_fused_ops::
launchFusedDeepseekV4QNormRopeKVRopeQuantInsert<qkv_scalar_t>(
reinterpret_cast<qkv_scalar_t const*>(q_in.data_ptr()),
reinterpret_cast<qkv_scalar_t*>(q_out.data_ptr()),
reinterpret_cast<qkv_scalar_t const*>(kv.data_ptr()),
reinterpret_cast<uint8_t*>(k_cache.data_ptr()),
reinterpret_cast<int64_t const*>(slot_mapping.data_ptr()),
reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
cos_sin_cache.data_ptr<float>(), static_cast<float>(eps),
reinterpret_cast<qkv_scalar_t const*>(q_in.const_data_ptr()),
reinterpret_cast<qkv_scalar_t*>(q_out.mutable_data_ptr()),
reinterpret_cast<qkv_scalar_t const*>(kv.const_data_ptr()),
reinterpret_cast<uint8_t*>(k_cache.mutable_data_ptr()),
slot_mapping.const_data_ptr<int64_t>(),
position_ids.const_data_ptr<int64_t>(),
cos_sin_cache.const_data_ptr<float>(), static_cast<float>(eps),
num_tokens_full, num_tokens_insert, num_heads_q,
num_heads_q_padded, cache_block_size_i, kv_block_stride,
stream);
@@ -15,16 +15,19 @@
* limitations under the License.
*/
#include "torch_utils.h"
#include <torch/csrc/stable/macros.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/csrc/stable/device.h>
#include <cooperative_groups.h>
#include <cuda_runtime.h>
#include <torch/cuda.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "cuda_utils.h"
#include "core/registration.h"
#include "minimax_reduce_rms_kernel.h"
#include <algorithm>
@@ -611,7 +614,7 @@ int get_sm_count() {
static int sm_count = 0;
if (sm_count == 0) {
int device_id;
CUDA_CHECK(cudaGetDevice(&device_id));
STD_CUDA_CHECK(cudaGetDevice(&device_id));
cudaDeviceProp device_prop;
cudaGetDeviceProperties(&device_prop, device_id);
sm_count = device_prop.multiProcessorCount;
@@ -621,13 +624,13 @@ int get_sm_count() {
inline int getSMVersion(bool queryRealSmArch = false) {
int device{-1};
CUDA_CHECK(cudaGetDevice(&device));
STD_CUDA_CHECK(cudaGetDevice(&device));
int sm_major = 0;
int sm_minor = 0;
CUDA_CHECK(cudaDeviceGetAttribute(&sm_major,
cudaDevAttrComputeCapabilityMajor, device));
CUDA_CHECK(cudaDeviceGetAttribute(&sm_minor,
cudaDevAttrComputeCapabilityMinor, device));
STD_CUDA_CHECK(cudaDeviceGetAttribute(
&sm_major, cudaDevAttrComputeCapabilityMajor, device));
STD_CUDA_CHECK(cudaDeviceGetAttribute(
&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
int sm = sm_major * 10 + sm_minor;
if (sm == 121 && !queryRealSmArch) {
return 120;
@@ -639,7 +642,7 @@ template <typename KernelFunc>
int get_max_active_blocks(KernelFunc kernel, int block_size,
int dynamic_smem = 0) {
int max_active = 0;
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
STD_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active, kernel, block_size, dynamic_smem));
return std::max(max_active, 1);
}
@@ -678,27 +681,27 @@ void minimax_reduce_rms_kernel_launcher(MiniMaxReduceRMSParams const& params) {
cfg.attrs = attribute;
cfg.numAttrs = SM >= 90 ? 2 : 0;
CUDA_CHECK(cudaLaunchKernelEx(
STD_CUDA_CHECK(cudaLaunchKernelEx(
&cfg, minimax_reduce_rms_kernel_lamport<DType, NRanks>, params));
}
template <typename DType, int NRanks, int OriginQDim, int OriginKDim>
void minimax_reduce_rms_kernel_launcher_float4(
MiniMaxReduceRMSParams const& params) {
TORCH_CHECK(params.size_q % params.hidden_dim == 0);
TORCH_CHECK(params.hidden_dim % kElemsPerAccess<DType> == 0);
STD_TORCH_CHECK(params.size_q % params.hidden_dim == 0);
STD_TORCH_CHECK(params.hidden_dim % kElemsPerAccess<DType> == 0);
if (params.stride_q > 0) {
TORCH_CHECK(params.stride_q % kElemsPerAccess<DType> == 0);
STD_TORCH_CHECK(params.stride_q % kElemsPerAccess<DType> == 0);
}
TORCH_CHECK(params.allreduce_in_k != nullptr,
"float4 QK kernel requires K input");
TORCH_CHECK(params.hidden_dim >= params.hidden_dim_k);
TORCH_CHECK(params.size_k % params.hidden_dim_k == 0);
TORCH_CHECK(params.hidden_dim_k % kElemsPerAccess<DType> == 0);
TORCH_CHECK(params.size_q / params.hidden_dim ==
params.size_k / params.hidden_dim_k);
STD_TORCH_CHECK(params.allreduce_in_k != nullptr,
"float4 QK kernel requires K input");
STD_TORCH_CHECK(params.hidden_dim >= params.hidden_dim_k);
STD_TORCH_CHECK(params.size_k % params.hidden_dim_k == 0);
STD_TORCH_CHECK(params.hidden_dim_k % kElemsPerAccess<DType> == 0);
STD_TORCH_CHECK(params.size_q / params.hidden_dim ==
params.size_k / params.hidden_dim_k);
if (params.stride_k > 0) {
TORCH_CHECK(params.stride_k % kElemsPerAccess<DType> == 0);
STD_TORCH_CHECK(params.stride_k % kElemsPerAccess<DType> == 0);
}
int token_num = params.size_q / params.hidden_dim;
@@ -746,7 +749,7 @@ void minimax_reduce_rms_kernel_launcher_float4(
cfg.attrs = attribute;
cfg.numAttrs = SM >= 90 ? 2 : 0;
CUDA_CHECK(cudaLaunchKernelEx(&cfg, kfn, params));
STD_CUDA_CHECK(cudaLaunchKernelEx(&cfg, kfn, params));
}
template <int NRanks>
@@ -759,21 +762,21 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) {
(params.hidden_dim * params.nranks == 6144) &&
(params.hidden_dim_k * params.nranks == 1024);
if (params.dtype == at::ScalarType::Half) {
if (params.dtype == torch::headeronly::ScalarType::Half) {
if (use_float4) {
minimax_reduce_rms_kernel_launcher_float4<half, NRanks, 6144, 1024>(
params);
} else {
minimax_reduce_rms_kernel_launcher<half, NRanks>(params);
}
} else if (params.dtype == at::ScalarType::BFloat16) {
} else if (params.dtype == torch::headeronly::ScalarType::BFloat16) {
if (use_float4) {
minimax_reduce_rms_kernel_launcher_float4<__nv_bfloat16, NRanks, 6144,
1024>(params);
} else {
minimax_reduce_rms_kernel_launcher<__nv_bfloat16, NRanks>(params);
}
} else if (params.dtype == at::ScalarType::Float) {
} else if (params.dtype == torch::headeronly::ScalarType::Float) {
if (use_float4) {
minimax_reduce_rms_kernel_launcher_float4<float, NRanks, 6144, 1024>(
params);
@@ -781,7 +784,7 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) {
minimax_reduce_rms_kernel_launcher<float, NRanks>(params);
}
} else {
TORCH_CHECK(false, "Unsupported data type for minimax_reduce_rms_op");
STD_TORCH_CHECK(false, "Unsupported data type for minimax_reduce_rms_op");
}
}
@@ -795,16 +798,18 @@ void minimax_reduce_rms_op(MiniMaxReduceRMSParams const& params) {
} else if (params.nranks == 16) {
dispatch_dtype<16>(params);
} else {
TORCH_CHECK(false, "minimax_reduce_rms_op: unsupported ranks number!");
STD_TORCH_CHECK(false, "minimax_reduce_rms_op: unsupported ranks number!");
}
}
} // namespace tensorrt_llm
} // namespace vllm
torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
torch::Tensor const& norm_weight,
torch::Tensor workspace, int64_t const rank,
int64_t const nranks, double const eps) {
torch::stable::Tensor minimax_allreduce_rms(
torch::stable::Tensor const& input,
torch::stable::Tensor const& norm_weight, torch::stable::Tensor workspace,
int64_t const rank, int64_t const nranks, double const eps) {
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
auto allreduce_params = vllm::tensorrt_llm::MiniMaxReduceRMSParams();
allreduce_params.nranks = static_cast<int>(nranks);
@@ -815,12 +820,12 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
allreduce_params.stride_q = allreduce_params.hidden_dim;
allreduce_params.workspace =
reinterpret_cast<void**>(workspace.mutable_data_ptr());
allreduce_params.allreduce_in = input.data_ptr();
allreduce_params.rms_gamma = norm_weight.data_ptr();
allreduce_params.allreduce_in = const_cast<void*>(input.const_data_ptr());
allreduce_params.rms_gamma = const_cast<void*>(norm_weight.const_data_ptr());
allreduce_params.rms_eps = static_cast<float>(eps);
allreduce_params.stream = at::cuda::getCurrentCUDAStream(input.get_device());
allreduce_params.stream = get_current_cuda_stream(input.get_device_index());
torch::Tensor rms_norm_out = torch::empty_like(input);
torch::stable::Tensor rms_norm_out = torch::stable::empty_like(input);
allreduce_params.rms_norm_out = rms_norm_out.mutable_data_ptr();
vllm::tensorrt_llm::minimax_reduce_rms_op(allreduce_params);
@@ -828,26 +833,33 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
return rms_norm_out;
}
std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
torch::Tensor qkv, torch::Tensor const& norm_weight_q,
torch::Tensor const& norm_weight_k, torch::Tensor workspace,
int64_t const q_size, int64_t const kv_size, int64_t const rank,
int64_t const nranks, double const eps) {
TORCH_CHECK(qkv.dim() == 2, "minimax_allreduce_rms_qk: qkv must be 2D");
TORCH_CHECK(qkv.is_contiguous(),
"minimax_allreduce_rms_qk: qkv must be contiguous");
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
minimax_allreduce_rms_qk(torch::stable::Tensor qkv,
torch::stable::Tensor const& norm_weight_q,
torch::stable::Tensor const& norm_weight_k,
torch::stable::Tensor workspace, int64_t const q_size,
int64_t const kv_size, int64_t const rank,
int64_t const nranks, double const eps) {
STD_TORCH_CHECK(qkv.dim() == 2, "minimax_allreduce_rms_qk: qkv must be 2D");
STD_TORCH_CHECK(qkv.is_contiguous(),
"minimax_allreduce_rms_qk: qkv must be contiguous");
int64_t qkv_dim = qkv.size(-1);
TORCH_CHECK(qkv_dim == q_size + 2 * kv_size,
"minimax_allreduce_rms_qk: qkv last dim must equal "
"q_size + 2 * kv_size");
TORCH_CHECK(rank < nranks,
"minimax_allreduce_rms_qk: rank must be less than nranks");
STD_TORCH_CHECK(qkv_dim == q_size + 2 * kv_size,
"minimax_allreduce_rms_qk: qkv last dim must equal "
"q_size + 2 * kv_size");
STD_TORCH_CHECK(rank < nranks,
"minimax_allreduce_rms_qk: rank must be less than nranks");
const torch::stable::accelerator::DeviceGuard device_guard(
qkv.get_device_index());
int64_t num_tokens = qkv.size(0);
int elem_bytes = qkv.element_size();
torch::Tensor q_out = torch::empty({num_tokens, q_size}, qkv.options());
torch::Tensor k_out = torch::empty({num_tokens, kv_size}, qkv.options());
torch::stable::Tensor q_out =
torch::stable::new_empty(qkv, {num_tokens, q_size}, qkv.scalar_type());
torch::stable::Tensor k_out =
torch::stable::new_empty(qkv, {num_tokens, kv_size}, qkv.scalar_type());
auto params = vllm::tensorrt_llm::MiniMaxReduceRMSParams();
params.nranks = static_cast<int>(nranks);
@@ -863,13 +875,14 @@ std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
params.stride_k_out = 0; // k_out is contiguous; kernel uses hidden_dim_k
params.workspace = reinterpret_cast<void**>(workspace.mutable_data_ptr());
uint8_t* base = static_cast<uint8_t*>(qkv.data_ptr());
uint8_t* base =
const_cast<uint8_t*>(static_cast<const uint8_t*>(qkv.const_data_ptr()));
params.allreduce_in = base;
params.allreduce_in_k = base + q_size * elem_bytes;
params.rms_gamma = norm_weight_q.data_ptr();
params.rms_gamma_k = norm_weight_k.data_ptr();
params.rms_gamma = const_cast<void*>(norm_weight_q.const_data_ptr());
params.rms_gamma_k = const_cast<void*>(norm_weight_k.const_data_ptr());
params.rms_eps = static_cast<float>(eps);
params.stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
params.stream = get_current_cuda_stream(qkv.get_device_index());
params.rms_norm_out = q_out.mutable_data_ptr();
params.rms_norm_out_k = k_out.mutable_data_ptr();
@@ -0,0 +1,69 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_mxfp8_grouped_mm_launcher.cuh"
// Host-side entry point for the MXFP8 grouped matrix multiply.
//
// Validates the per-expert problem description and the layouts of `a` and
// `b`, switches to the device of `a`, and forwards every tensor plus the
// current CUDA stream to the dispatcher instantiated for the dtype of `d`
// (bfloat16 or half). When the build does not define
// CUTLASS_ARCH_MMA_SM100_SUPPORTED the call always fails with an error.
void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a,
                              const torch::stable::Tensor& b,
                              const torch::stable::Tensor& sfa,
                              const torch::stable::Tensor& sfb,
                              torch::stable::Tensor& d,
                              const torch::stable::Tensor& problem_sizes,
                              const torch::stable::Tensor& expert_offsets,
                              const torch::stable::Tensor& blockscale_offsets) {
#if !defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
  STD_TORCH_CHECK(false,
                  "No implemented cutlass_mxfp8_grouped_mm for "
                  "current device");
#else
  using ScalarType = torch::headeronly::ScalarType;

  // --- problem description: shape (num_experts, 3), one row per expert ---
  STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
  STD_TORCH_CHECK(problem_sizes.size(1) == 3,
                  "problem_sizes must have shape (num_experts, 3)");
  STD_TORCH_CHECK(
      problem_sizes.size(0) == expert_offsets.size(0),
      "Number of experts in problem_sizes must match expert_offsets");

  // --- index tensors must all be int32 ---
  STD_TORCH_CHECK(problem_sizes.scalar_type() == ScalarType::Int,
                  "problem_sizes must be int32");
  STD_TORCH_CHECK(expert_offsets.scalar_type() == ScalarType::Int,
                  "expert_offsets must be int32");
  STD_TORCH_CHECK(blockscale_offsets.scalar_type() == ScalarType::Int,
                  "blockscale_offsets must be int32");

  // --- operand ranks, 128-alignment of k and n, and memory layouts ---
  STD_TORCH_CHECK(a.dim() == 2,
                  "a must be a 2D tensor of shape (num_tokens, k)");
  STD_TORCH_CHECK(b.dim() == 3,
                  "b must be a 3D tensor of shape (num_experts, k, n)");
  STD_TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0,
                  "k should align 128");
  STD_TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128");
  STD_TORCH_CHECK(a.stride(1) == 1, "a must be row major");
  STD_TORCH_CHECK(b.stride(1) == 1, "b must be column major");

  // Run on the device that owns `a`, using that device's current stream.
  const torch::stable::accelerator::DeviceGuard device_guard(
      a.get_device_index());
  auto cuda_stream = get_current_cuda_stream(a.get_device_index());

  // Select the output element type from the dtype of `d`.
  switch (d.scalar_type()) {
    case ScalarType::BFloat16:
      expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
          cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes,
                               expert_offsets, blockscale_offsets,
                               cuda_stream);
      break;
    case ScalarType::Half:
      expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
          cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
                           blockscale_offsets, cuda_stream);
      break;
    default:
      STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
  }
#endif
}
// Binds cutlass_mxfp8_grouped_mm (defined above) as the implementation of the
// `cutlass_mxfp8_grouped_mm` op for the CUDA dispatch key of the `_C` library.
// NOTE(review): the op schema is not declared in this block; presumably it is
// declared in the stable torch bindings -- confirm.
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_mxfp8_grouped_mm", TORCH_BOX(&cutlass_mxfp8_grouped_mm));
}
@@ -4,9 +4,9 @@
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_launcher.cuh
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/Exception.h>
#include <cassert>
#include <iostream>
@@ -15,18 +15,22 @@
#include "cute/tensor.hpp"
#include "cutlass_mxfp8_grouped_mm_functor.cuh"
#include "cutlass_mxfp8_grouped_mm_traits.cuh"
#include "libtorch_stable/torch_utils.h"
namespace expert_specialization {
template <typename GemmTraits>
void cutlass_mxfp8_grouped_mm_pre_compute(
torch::Tensor& a_ptrs, torch::Tensor& b_ptrs, torch::Tensor& sfa_ptrs,
torch::Tensor& sfb_ptrs, torch::Tensor& d_ptrs, torch::Tensor& stride_a,
torch::Tensor& stride_b, torch::Tensor& stride_d, torch::Tensor& layout_sfa,
torch::Tensor& layout_sfb, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& sfa, const torch::Tensor& sfb, const torch::Tensor& d,
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
torch::stable::Tensor& a_ptrs, torch::stable::Tensor& b_ptrs,
torch::stable::Tensor& sfa_ptrs, torch::stable::Tensor& sfb_ptrs,
torch::stable::Tensor& d_ptrs, torch::stable::Tensor& stride_a,
torch::stable::Tensor& stride_b, torch::stable::Tensor& stride_d,
torch::stable::Tensor& layout_sfa, torch::stable::Tensor& layout_sfb,
const torch::stable::Tensor& a, const torch::stable::Tensor& b,
const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb,
const torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) {
using OffsetFunctor = CutlassMxfp8GroupedMmOffsetFunctor<GemmTraits>;
using ElementA = typename OffsetFunctor::ElementA;
using ElementB = typename OffsetFunctor::ElementB;
@@ -42,10 +46,10 @@ void cutlass_mxfp8_grouped_mm_pre_compute(
using StrideB = typename StrideFunctor::StrideB;
using StrideD = typename StrideFunctor::StrideD;
int num_experts = (int)expert_offsets.size(0);
TORCH_CHECK(num_experts <= 1024,
"Number of experts cannot exceed 1024, the maximum number of "
"threads per block.");
int num_experts = static_cast<int>(expert_offsets.size(0));
STD_TORCH_CHECK(num_experts <= 1024,
"Number of experts cannot exceed 1024, the maximum number of "
"threads per block.");
OffsetFunctor offset_functor(
reinterpret_cast<int*>(expert_offsets.data_ptr()),
@@ -72,13 +76,18 @@ void cutlass_mxfp8_grouped_mm_pre_compute(
}
template <typename GemmTraits>
void cutlass_mxfp8_grouped_mm(
const torch::Tensor& a_ptrs, const torch::Tensor& b_ptrs,
const torch::Tensor& sfa_ptrs, const torch::Tensor& sfb_ptrs,
const torch::Tensor& d_ptrs, const torch::Tensor& stride_a,
const torch::Tensor& stride_b, const torch::Tensor& stride_d,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes, cudaStream_t stream) {
void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a_ptrs,
const torch::stable::Tensor& b_ptrs,
const torch::stable::Tensor& sfa_ptrs,
const torch::stable::Tensor& sfb_ptrs,
const torch::stable::Tensor& d_ptrs,
const torch::stable::Tensor& stride_a,
const torch::stable::Tensor& stride_b,
const torch::stable::Tensor& stride_d,
const torch::stable::Tensor& layout_sfa,
const torch::stable::Tensor& layout_sfb,
const torch::stable::Tensor& problem_sizes,
cudaStream_t stream) {
using Gemm = typename GemmTraits::Gemm;
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
@@ -93,13 +102,12 @@ void cutlass_mxfp8_grouped_mm(
typename GemmTraits::ProblemShape::UnderlyingProblemShape;
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = c10::cuda::current_device();
hw_info.sm_count =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
hw_info.device_id = d_ptrs.get_device_index();
hw_info.sm_count = get_device_prop()->multiProcessorCount;
hw_info.cluster_shape = GemmTraits::MMAConfig::preferred_cluster;
hw_info.cluster_shape_fallback = GemmTraits::MMAConfig::fallback_cluster;
int num_experts = (int)problem_sizes.size(0);
int num_experts = static_cast<int>(problem_sizes.size(0));
UnderlyingProblemShape* underlying_problem_shape =
reinterpret_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
@@ -127,44 +135,55 @@ void cutlass_mxfp8_grouped_mm(
Gemm gemm;
auto can_implement_status = gemm.can_implement(arguments);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM");
STD_TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM");
torch::TensorOptions options_uint8 =
torch::TensorOptions().dtype(torch::kUInt8).device(d_ptrs.device());
size_t workspace_size = gemm.get_workspace_size(arguments);
torch::Tensor workspace = torch::empty(workspace_size, options_uint8);
torch::stable::Tensor workspace = torch::stable::empty(
{static_cast<int64_t>(workspace_size)},
torch::headeronly::ScalarType::Byte, std::nullopt, d_ptrs.device());
auto status = gemm.initialize(arguments, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM");
status = gemm.run(stream, nullptr, true); // Enable PDL
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
template <typename OutType>
void cutlass_mxfp8_grouped_mm_dispatch_out_dtype(
const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& sfa,
const torch::Tensor& sfb, torch::Tensor& d,
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
int num_experts = (int)problem_sizes.size(0);
torch::TensorOptions options_int64 =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::TensorOptions options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(a.device());
const torch::stable::Tensor& a, const torch::stable::Tensor& b,
const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb,
torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) {
int num_experts = static_cast<int>(problem_sizes.size(0));
auto device = a.device();
torch::Tensor a_ptrs = torch::empty(num_experts, options_int64);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int64);
torch::Tensor sfa_ptrs = torch::empty(num_experts, options_int64);
torch::Tensor sfb_ptrs = torch::empty(num_experts, options_int64);
torch::Tensor d_ptrs = torch::empty(num_experts, options_int64);
torch::stable::Tensor a_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor sfa_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor sfb_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor d_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::Tensor stride_a = torch::empty(num_experts, options_int64);
torch::Tensor stride_b = torch::empty(num_experts, options_int64);
torch::Tensor stride_d = torch::empty(num_experts, options_int64);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int32);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int32);
torch::stable::Tensor stride_a = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor stride_b = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor stride_d = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor layout_sfa =
torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int,
std::nullopt, device);
torch::stable::Tensor layout_sfb =
torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int,
std::nullopt, device);
using GemmTraits = CutlassMxfp8GroupedMmGemmTraits<MMA1SMConfig, OutType>;
cutlass_mxfp8_grouped_mm_pre_compute<GemmTraits>(
@@ -176,4 +195,4 @@ void cutlass_mxfp8_grouped_mm_dispatch_out_dtype(
layout_sfa, layout_sfb, problem_sizes, stream);
}
} // namespace expert_specialization
} // namespace expert_specialization
@@ -0,0 +1,66 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "mxfp8_experts_quant.cuh"
// Entry point for per-expert MXFP8 quantization: validates the arguments and
// forwards them to expert_specialization::launch_mxfp8_experts_quant,
// instantiated for the input element type (bfloat16 or float16).
//
// Preconditions enforced here (each raises via STD_TORCH_CHECK on violation):
//   input               2D, stride(1) == 1, size(1) a multiple of 128
//   problem_sizes       2D, int32; size(0) is the number of groups
//   expert_offsets      1D, int32, one entry per group
//   blockscale_offsets  1D, int32, one entry per group
// quant_output and scale_factor are passed through to the launcher without
// any checks in this function.
//
// When CUTLASS_ARCH_MMA_SM100_SUPPORTED is not defined the function always
// raises.
void mxfp8_experts_quant(const torch::stable::Tensor& input,
                         const torch::stable::Tensor& problem_sizes,
                         const torch::stable::Tensor& expert_offsets,
                         const torch::stable::Tensor& blockscale_offsets,
                         torch::stable::Tensor& quant_output,
                         torch::stable::Tensor& scale_factor) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
  using torch::headeronly::ScalarType;

  // Layout checks.
  STD_TORCH_CHECK(input.dim() == 2, "input must be 2D tensor");
  STD_TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128");
  STD_TORCH_CHECK(input.stride(1) == 1, "input must be row major");
  STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");

  // Index tensors must all be int32.
  STD_TORCH_CHECK(problem_sizes.scalar_type() == ScalarType::Int,
                  "problem_sizes must be int32");
  STD_TORCH_CHECK(expert_offsets.scalar_type() == ScalarType::Int,
                  "expert_offsets must be int32");
  STD_TORCH_CHECK(blockscale_offsets.scalar_type() == ScalarType::Int,
                  "blockscale_offsets must be int32");

  // Both offset tensors carry exactly one entry per group.
  const auto num_groups = problem_sizes.size(0);
  STD_TORCH_CHECK(
      expert_offsets.dim() == 1 && expert_offsets.size(0) == num_groups,
      "expert_offsets must be 1D and have size equal to the number of groups");
  STD_TORCH_CHECK(
      blockscale_offsets.dim() == 1 &&
          blockscale_offsets.size(0) == num_groups,
      "blockscale_offsets must be 1D and have size equal to the number of "
      "groups");

  // Guard on the input's device index for the duration of the launch.
  const torch::stable::accelerator::DeviceGuard device_guard(
      input.get_device_index());

  // Select the launcher instantiation from the input element type.
  switch (input.scalar_type()) {
    case ScalarType::BFloat16:
      expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>(
          input, problem_sizes, expert_offsets, blockscale_offsets,
          quant_output, scale_factor);
      break;
    case ScalarType::Half:
      expert_specialization::launch_mxfp8_experts_quant<__half>(
          input, problem_sizes, expert_offsets, blockscale_offsets,
          quant_output, scale_factor);
      break;
    default:
      STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
  }
#else
  STD_TORCH_CHECK(false,
                  "No implemented mxfp8_experts_quant for "
                  "current device");
#endif
}
// Registered here (not torch_bindings.cpp) because ENABLE_ES_MXFP8_GROUPED_MM
// is applied only under COMPILE_LANGUAGE:CUDA.
// Binds mxfp8_experts_quant (defined above) as the implementation of the
// `mxfp8_experts_quant` op for the CUDA dispatch key of the `_C` library.
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("mxfp8_experts_quant", TORCH_BOX(&mxfp8_experts_quant));
}
@@ -4,16 +4,19 @@
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/all.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/macros.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/Exception.h>
#include <cuda/ptx>
#include "cute/tensor.hpp"
#include "libtorch_stable/torch_utils.h"
namespace expert_specialization {
@@ -356,12 +359,12 @@ __global__ void mxfp8_experts_quant_kernel(
}
template <typename T_IN>
void launch_mxfp8_experts_quant(const torch::Tensor& input,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& blockscale_offsets,
torch::Tensor& quant_output,
torch::Tensor& scale_factor) {
void launch_mxfp8_experts_quant(const torch::stable::Tensor& input,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& blockscale_offsets,
torch::stable::Tensor& quant_output,
torch::stable::Tensor& scale_factor) {
ThrLayout thr_layout{};
ValLayout val_layout{};
SfR2SThrLayout r2s_thr_layout{};
@@ -386,19 +389,18 @@ void launch_mxfp8_experts_quant(const torch::Tensor& input,
CopyAtomR2S{}, r2s_thr_layout, r2s_val_layout); // Tiler_MN: (16, 4)
int max_active_blocks_per_sm = -1;
AT_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
STD_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_per_sm,
mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
decltype(tiled_copy_r2g),
decltype(tiled_copy_r2s)>,
THREAD_BLOCK_SIZE, 0));
dim3 grid(at::cuda::getCurrentDeviceProperties()->multiProcessorCount *
max_active_blocks_per_sm,
dim3 grid(get_device_prop()->multiProcessorCount * max_active_blocks_per_sm,
1, 1);
dim3 block(THREAD_BLOCK_SIZE, 1, 1);
int num_experts = (int)problem_sizes.size(0);
auto stream = at::cuda::getCurrentCUDAStream();
int num_experts = static_cast<int>(problem_sizes.size(0));
auto stream = get_current_cuda_stream(input.get_device_index());
mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
decltype(tiled_copy_r2g), decltype(tiled_copy_r2s)>
<<<grid, block, 0, stream>>>(
+41
View File
@@ -231,6 +231,27 @@ void fused_qk_norm_rope(torch::stable::Tensor& qkv, int64_t num_heads_q,
torch::stable::Tensor& position_ids,
int64_t forced_token_heads_per_warp);
// Horizontally-fused DeepseekV4-MLA kernel: per-head RMSNorm + GPT-J RoPE for
// Q, and GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
// kernel launch. k_cache is the only tensor taken by non-const reference.
torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::stable::Tensor const& q_in, torch::stable::Tensor const& kv,
torch::stable::Tensor& k_cache, torch::stable::Tensor const& slot_mapping,
torch::stable::Tensor const& position_ids,
torch::stable::Tensor const& cos_sin_cache, int64_t q_head_padded,
double eps, int64_t cache_block_size);
// MiniMax all-reduce + RMS kernels; declared only when USE_ROCM is not
// defined. NOTE(review): semantics inferred from the op names -- confirm
// against the kernel source.
#ifndef USE_ROCM
torch::stable::Tensor minimax_allreduce_rms(
torch::stable::Tensor const& input,
torch::stable::Tensor const& norm_weight, torch::stable::Tensor workspace,
int64_t const rank, int64_t const nranks, double const eps);
// Variant taking a single qkv tensor with separate q/k norm weights and
// explicit q_size / kv_size; returns a pair of tensors.
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
minimax_allreduce_rms_qk(torch::stable::Tensor qkv,
torch::stable::Tensor const& norm_weight_q,
torch::stable::Tensor const& norm_weight_k,
torch::stable::Tensor workspace, int64_t const q_size,
int64_t const kv_size, int64_t const rank,
int64_t const nranks, double const eps);
#endif
// Sampler kernels (shared CUDA/ROCm)
void apply_repetition_penalties_(
torch::stable::Tensor& logits, const torch::stable::Tensor& prompt_mask,
@@ -273,6 +294,26 @@ void selective_scan_fwd(
const std::optional<torch::stable::Tensor>& cu_chunk_seqlen,
const std::optional<torch::stable::Tensor>& last_chunk_indices);
// Custom all-reduce kernels
// fptr_t is the integer type used for the `_fa` / `reg_buffer` arguments and
// for the value returned by init_custom_ar.
// NOTE(review): presumably a pointer-sized handle carried as an int64 so it
// can cross the op boundary -- confirm.
using fptr_t = int64_t;
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
torch::stable::Tensor& rank_data, int64_t rank,
bool fully_connected);
void all_reduce(fptr_t _fa, torch::stable::Tensor& inp,
torch::stable::Tensor& out, fptr_t reg_buffer,
int64_t reg_buffer_sz_bytes);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
// Shared-buffer helpers: allocate returns an (int64, Tensor) pair; the int64
// returned by open_mem_handle / passed to free_shared_buffer is an integer
// buffer value. NOTE(review): presumably raw device addresses -- confirm.
std::tuple<int64_t, torch::stable::Tensor> allocate_shared_buffer_and_handle(
int64_t size);
int64_t open_mem_handle(torch::stable::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
// Activation kernels (shared CUDA/ROCm)
void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
void silu_and_mul_clamp(torch::stable::Tensor& out,
+63
View File
@@ -337,6 +337,24 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"bool is_neox, Tensor position_ids, "
"int forced_token_heads_per_warp=-1) -> ()");
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
"Tensor q_in, Tensor kv, Tensor! k_cache, "
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
"int q_head_padded, float eps, int cache_block_size) -> Tensor");
#ifndef USE_ROCM
ops.def(
"minimax_allreduce_rms("
"Tensor input, Tensor norm_weight, Tensor workspace, "
"int rank, int nranks, float eps) -> Tensor");
ops.def(
"minimax_allreduce_rms_qk("
"Tensor qkv, Tensor norm_weight_q, Tensor norm_weight_k, "
"Tensor workspace, int q_size, int kv_size, int rank, int nranks, "
"float eps) -> (Tensor, Tensor)");
#endif
// Apply repetition penalties to logits in-place.
ops.def(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
@@ -571,6 +589,12 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
// Positional encoding kernels (shared CUDA/ROCm)
ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding));
ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope));
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert",
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert));
#ifndef USE_ROCM
ops.impl("minimax_allreduce_rms", TORCH_BOX(&minimax_allreduce_rms));
ops.impl("minimax_allreduce_rms_qk", TORCH_BOX(&minimax_allreduce_rms_qk));
#endif
// Sampler kernels (shared CUDA/ROCm)
ops.impl("apply_repetition_penalties_",
@@ -725,6 +749,45 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) {
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
}
// Op schemas for the custom all-reduce kernels, declared as a fragment of the
// `_C_custom_ar` library. This block only declares signatures; the kernels
// are bound per dispatch key in the STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar,
// ...) blocks.
STABLE_TORCH_LIBRARY_FRAGMENT(_C_custom_ar, custom_ar) {
custom_ar.def(
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
"int rank, bool fully_connected) -> int");
// `Tensor! out` marks `out` as written by the op.
custom_ar.def(
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
// The remaining ops, except open_mem_handle, take no Tensor arguments.
custom_ar.def("dispose(int fa) -> ()");
custom_ar.def("meta_size() -> int");
custom_ar.def("register_buffer(int fa, int[] ipc_tensors) -> ()");
custom_ar.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])");
custom_ar.def(
"register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
custom_ar.def("allocate_shared_buffer_and_handle(int size) -> (int, Tensor)");
custom_ar.def("open_mem_handle(Tensor mem_handle) -> int");
custom_ar.def("free_shared_buffer(int ptr) -> ()");
}
// CUDA dispatch-key implementations for the `_C_custom_ar` ops: only
// init_custom_ar and all_reduce are bound here.
STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CUDA, custom_ar) {
custom_ar.impl("init_custom_ar", TORCH_BOX(&init_custom_ar));
custom_ar.impl("all_reduce", TORCH_BOX(&all_reduce));
}
// CPU dispatch-key implementation for `_C_custom_ar`: open_mem_handle is the
// only op bound under CPU.
// NOTE(review): implies `mem_handle` is expected to be a CPU tensor -- confirm
// against callers.
STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CPU, custom_ar) {
custom_ar.impl("open_mem_handle", TORCH_BOX(&open_mem_handle));
}
// CompositeExplicitAutograd implementations for `_C_custom_ar`. None of the
// ops bound here takes a Tensor argument in its schema.
// NOTE(review): presumably that is why they are registered under a
// backend-independent key rather than CUDA/CPU -- confirm.
STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CompositeExplicitAutograd, custom_ar) {
custom_ar.impl("dispose", TORCH_BOX(&dispose));
custom_ar.impl("meta_size", TORCH_BOX(&meta_size));
custom_ar.impl("register_buffer", TORCH_BOX(&register_buffer));
custom_ar.impl("get_graph_buffer_ipc_meta",
TORCH_BOX(&get_graph_buffer_ipc_meta));
custom_ar.impl("register_graph_buffers", TORCH_BOX(&register_graph_buffers));
custom_ar.impl("allocate_shared_buffer_and_handle",
TORCH_BOX(&allocate_shared_buffer_and_handle));
custom_ar.impl("free_shared_buffer", TORCH_BOX(&free_shared_buffer));
}
STABLE_TORCH_LIBRARY_IMPL(_C_cache_ops, CPU, ops) {
ops.impl("swap_blocks_batch", TORCH_BOX(&swap_blocks_batch));
}
+2 -2
View File
@@ -19,7 +19,7 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/types.h>
#include <torch/headeronly/core/ScalarType.h>
namespace vllm {
namespace tensorrt_llm {
@@ -51,7 +51,7 @@ static constexpr int kElemsPerAccess = ElemsPerAccess<DType>::value;
struct MiniMaxReduceRMSParams {
int nranks{};
int rank{};
at::ScalarType dtype{at::ScalarType::Undefined};
torch::headeronly::ScalarType dtype{torch::headeronly::ScalarType::Undefined};
int size_q{};
int hidden_dim{};
int size_k{};
@@ -1,60 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
#include <torch/all.h>
#include "cutlass_mxfp8_grouped_mm_launcher.cuh"
// Entry point for the MXFP8 grouped GEMM: validates the arguments, fetches the
// current CUDA stream and dispatches on the dtype of the output tensor `d`
// (bfloat16 or float16).
//
// Preconditions enforced here (each raises via TORCH_CHECK on violation):
//   problem_sizes       2D int32 with size(1) == 3 and size(0) equal to
//                       expert_offsets.size(0)
//   expert_offsets      int32
//   blockscale_offsets  int32
//   a                   2D, stride 1 along dim 1
//   b                   3D, stride 1 along dim 1
//   a.size(1) == b.size(1), a multiple of 128; b.size(2) a multiple of 128
// sfa and sfb are forwarded without checks in this function.
//
// When CUTLASS_ARCH_MMA_SM100_SUPPORTED is not defined the function always
// raises.
void cutlass_mxfp8_grouped_mm(const torch::Tensor& a, const torch::Tensor& b,
                              const torch::Tensor& sfa,
                              const torch::Tensor& sfb, torch::Tensor& d,
                              const torch::Tensor& problem_sizes,
                              const torch::Tensor& expert_offsets,
                              const torch::Tensor& blockscale_offsets) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
  namespace es = expert_specialization;

  // Per-expert metadata tensors.
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
  TORCH_CHECK(problem_sizes.size(1) == 3,
              "problem_sizes must have shape (num_experts, 3)");
  TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
              "Number of experts in problem_sizes must match expert_offsets");
  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
              "problem_sizes must be int32");
  TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
              "expert_offsets must be int32");
  TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
              "blockscale_offsets must be int32");

  // Operand ranks, alignment and layout.
  TORCH_CHECK(a.dim() == 2, "a must be a 2D tensor of shape (num_tokens, k)");
  TORCH_CHECK(b.dim() == 3,
              "b must be a 3D tensor of shape (num_experts, k, n)");
  const auto k = a.size(1);
  TORCH_CHECK(k == b.size(1) && k % 128 == 0, "k should align 128");
  TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128");
  TORCH_CHECK(a.stride(1) == 1, "a must be row major");
  TORCH_CHECK(b.stride(1) == 1, "b must be column major");

  auto cuda_stream = at::cuda::getCurrentCUDAStream();

  // The output element type selects the kernel instantiation.
  const auto out_dtype = d.dtype();
  if (out_dtype == torch::kBFloat16) {
    es::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<cutlass::bfloat16_t>(
        a, b, sfa, sfb, d, problem_sizes, expert_offsets, blockscale_offsets,
        cuda_stream);
    return;
  }
  if (out_dtype == torch::kFloat16) {
    es::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<cutlass::half_t>(
        a, b, sfa, sfb, d, problem_sizes, expert_offsets, blockscale_offsets,
        cuda_stream);
    return;
  }
  TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
#else
  TORCH_CHECK(false,
              "No implemented cutlass_mxfp8_grouped_mm for "
              "current device");
#endif
}
#include "core/registration.h"
// Binds cutlass_mxfp8_grouped_mm as the CUDA dispatch-key implementation of
// the op of the same name in the TORCH_EXTENSION_NAME library.
// NOTE(review): TORCH_LIBRARY_IMPL_EXPAND is presumably provided by
// core/registration.h, included just above -- confirm.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_mxfp8_grouped_mm", cutlass_mxfp8_grouped_mm);
}
-60
View File
@@ -1,60 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
#include <torch/all.h>
#include "mxfp8_experts_quant.cuh"
// Entry point for per-expert MXFP8 quantization: validates the arguments and
// forwards them to expert_specialization::launch_mxfp8_experts_quant,
// instantiated for the input element type (bfloat16 or float16).
//
// Preconditions enforced here (each raises via TORCH_CHECK on violation):
//   input               2D, stride 1 along dim 1, size(1) a multiple of 128
//   problem_sizes       2D, int32; size(0) is the number of groups
//   expert_offsets      1D, int32, one entry per group
//   blockscale_offsets  1D, int32, one entry per group
// quant_output and scale_factor are passed through to the launcher without
// any checks in this function.
//
// When CUTLASS_ARCH_MMA_SM100_SUPPORTED is not defined the function always
// raises.
void mxfp8_experts_quant(const torch::Tensor& input,
                         const torch::Tensor& problem_sizes,
                         const torch::Tensor& expert_offsets,
                         const torch::Tensor& blockscale_offsets,
                         torch::Tensor& quant_output,
                         torch::Tensor& scale_factor) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
  TORCH_CHECK(input.dim() == 2, "input must be 2D tensor");
  TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128");
  TORCH_CHECK(input.strides()[1] == 1, "input must be row major");
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
              "problem_sizes must be int32");
  TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
              "expert_offsets must be int32");
  TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
              "blockscale_offsets must be int32");

  // Both offset tensors carry exactly one entry per group.
  auto groups = problem_sizes.size(0);
  TORCH_CHECK(
      expert_offsets.dim() == 1 && expert_offsets.size(0) == groups,
      "expert_offsets must be 1D and have size equal to the number of groups");
  TORCH_CHECK(
      blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
      "blockscale_offsets must be 1D and have size equal to the number of "
      "groups");

  // No stream is fetched here: the launcher takes no stream argument, so the
  // local `stream` that used to be created at this point was dead code.
  if (input.dtype() == torch::kBFloat16) {
    expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>(
        input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
        scale_factor);
  } else if (input.dtype() == torch::kFloat16) {
    expert_specialization::launch_mxfp8_experts_quant<__half>(
        input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
        scale_factor);
  } else {
    TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
  }
#else
  TORCH_CHECK(false,
              "No implemented mxfp8_experts_quant for "
              "current device");
#endif
}
#include "core/registration.h"
// Binds mxfp8_experts_quant as the CUDA dispatch-key implementation of the op
// of the same name in the TORCH_EXTENSION_NAME library.
// NOTE(review): TORCH_LIBRARY_IMPL_EXPAND is presumably provided by
// core/registration.h, included just above -- confirm.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("mxfp8_experts_quant", mxfp8_experts_quant);
}
-36
View File
@@ -40,12 +40,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::Tensor const& q_in, torch::Tensor const& kv, torch::Tensor& k_cache,
torch::Tensor const& slot_mapping, torch::Tensor const& position_ids,
torch::Tensor const& cos_sin_cache, int64_t q_head_padded, double eps,
int64_t cache_block_size);
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
@@ -107,24 +101,6 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
int64_t activation_kind);
using fptr_t = int64_t;
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank,
bool fully_connected);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t size);
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
#ifdef USE_ROCM
fptr_t init_custom_qr(int64_t rank, int64_t world_size,
std::optional<int64_t> qr_max_size = std::nullopt);
@@ -135,15 +111,3 @@ void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size();
#endif
#ifndef USE_ROCM
torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
torch::Tensor const& norm_weight,
torch::Tensor workspace, int64_t const rank,
int64_t const nranks, double const eps);
std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
torch::Tensor qkv, torch::Tensor const& norm_weight_q,
torch::Tensor const& norm_weight_k, torch::Tensor workspace,
int64_t const q_size, int64_t const kv_size, int64_t const rank,
int64_t const nranks, double const eps);
#endif
+20 -78
View File
@@ -55,14 +55,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
// kernel launch.
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
"Tensor q_in, Tensor kv, Tensor! k_cache, "
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
"int q_head_padded, float eps, int cache_block_size) -> Tensor");
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);
// kernel launch. Registered in _C_stable_libtorch.
// Quantization ops
#ifndef USE_ROCM
@@ -163,34 +156,27 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// conditionally compiled so impl registration is in source file
#endif
#ifndef USE_ROCM
ops.def(
"minimax_allreduce_rms("
"Tensor input,"
"Tensor norm_weight,"
"Tensor workspace,"
"int rank,"
"int nranks,"
"float eps) -> Tensor");
ops.impl("minimax_allreduce_rms", torch::kCUDA, &minimax_allreduce_rms);
ops.def(
"minimax_allreduce_rms_qk("
"Tensor qkv,"
"Tensor norm_weight_q,"
"Tensor norm_weight_k,"
"Tensor workspace,"
"int q_size,"
"int kv_size,"
"int rank,"
"int nranks,"
"float eps) -> (Tensor, Tensor)");
ops.impl("minimax_allreduce_rms_qk", torch::kCUDA, &minimax_allreduce_rms_qk);
// conditionally compiled so impl in source file
#endif
}
#ifdef USE_ROCM
// Adds the quick-reduce ops to the `<TORCH_EXTENSION_NAME>_custom_ar` library
// as a fragment. Ops given a schema string get an explicit per-device impl;
// ops given only a function pointer are defined directly from that function.
TORCH_LIBRARY_FRAGMENT(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Quick Reduce all-reduce kernels (ROCm-only; stays on legacy _C).
custom_ar.def(
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
"cast_bf2half) -> ()");
custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
custom_ar.def("init_custom_qr", &init_custom_qr);
custom_ar.def("qr_destroy", &qr_destroy);
custom_ar.def("qr_get_handle", &qr_get_handle);
// qr_open_handles is bound for the CPU key.
custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
// Max input size in bytes
custom_ar.def("qr_max_size", &qr_max_size);
}
#endif
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
// Cuda utils
@@ -205,48 +191,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
&get_max_shared_memory_per_block_device_attribute);
}
// Declares the ops of the `<TORCH_EXTENSION_NAME>_custom_ar` library and binds
// their implementations in the same block. Ops given a schema string get an
// explicit per-device impl; ops given only a function pointer are defined
// directly from that function.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Custom all-reduce kernels
custom_ar.def(
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
"int rank, bool fully_connected) -> int");
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
// `Tensor! out` marks `out` as written by the op.
custom_ar.def(
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);
custom_ar.def("dispose", &dispose);
custom_ar.def("meta_size", &meta_size);
custom_ar.def("register_buffer", &register_buffer);
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
custom_ar.def("register_graph_buffers", &register_graph_buffers);
custom_ar.def("allocate_shared_buffer_and_handle",
&allocate_shared_buffer_and_handle);
// open_mem_handle is defined with both a schema and a function, and is
// additionally bound for the CPU key.
custom_ar.def("open_mem_handle(Tensor mem_handle) -> int", &open_mem_handle);
custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle);
custom_ar.def("free_shared_buffer", &free_shared_buffer);
#ifdef USE_ROCM
// Quick Reduce all-reduce kernels
custom_ar.def(
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
"cast_bf2half) -> ()");
custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
custom_ar.def("init_custom_qr", &init_custom_qr);
custom_ar.def("qr_destroy", &qr_destroy);
custom_ar.def("qr_get_handle", &qr_get_handle);
custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
// Max input size in bytes
custom_ar.def("qr_max_size", &qr_max_size);
#endif
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
+2 -2
View File
@@ -50,7 +50,7 @@ struct _typeConvert<float> {
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
struct _typeConvert<torch::headeronly::Half> {
static constexpr bool exists = true;
using hip_type = __half;
using packed_hip_type = __half2;
@@ -73,7 +73,7 @@ struct _typeConvert<c10::Half> {
// CUDA_ARCH < 800 does not have BF16 support
// ROCm 7.0+ supports bfloat16
template <>
struct _typeConvert<c10::BFloat16> {
struct _typeConvert<torch::headeronly::BFloat16> {
static constexpr bool exists = true;
using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162;
+322
View File
@@ -7,6 +7,9 @@ from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.renderers.hf import (
_consolidate_system_messages,
_convert_developer_to_system,
_detect_developer_role_support,
_get_hf_base_chat_template_params,
_try_extract_ast,
resolve_chat_template,
@@ -592,3 +595,322 @@ def test_get_gen_prompt(
f"The generated prompt does not match the expected output for "
f"model {model} and template {template}"
)
class TestConvertDeveloperToSystem:
    """Unit tests for ``_convert_developer_to_system``."""

    def test_converts_role(self):
        """A developer turn is relabelled as system; other turns keep their role."""
        messages = [
            {"role": "developer", "content": "You are helpful."},
            {"role": "user", "content": "Hello"},
        ]
        converted = _convert_developer_to_system(messages)
        first, second = converted[0], converted[1]
        assert first["role"] == "system"
        assert first["content"] == "You are helpful."
        assert second["role"] == "user"

    def test_removes_tools_key(self):
        """The ``tools`` entry is dropped from a converted developer message."""
        developer_message = {
            "role": "developer",
            "content": "Instructions",
            "tools": [{"type": "function"}],
        }
        converted = _convert_developer_to_system([developer_message])
        assert "tools" not in converted[0]

    def test_no_developer_messages_unchanged(self):
        """Conversations without developer turns keep their roles as-is."""
        messages = [
            {"role": "system", "content": "System prompt"},
            {"role": "user", "content": "Hello"},
        ]
        converted = _convert_developer_to_system(messages)
        assert converted[0]["role"] == "system"
        assert converted[1]["role"] == "user"

    def test_does_not_mutate_original(self):
        """The input message dict must be left untouched by the conversion."""
        source_message = {
            "role": "developer",
            "content": "Instructions",
            "tools": [{"type": "function"}],
        }
        _convert_developer_to_system([source_message])
        assert source_message["role"] == "developer"
        assert "tools" in source_message
# --- Developer role detection and conversion tests ---
CHATML_TEMPLATE = (
"{% for message in messages %}"
"{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}"
"{% if (loop.last and add_generation_prompt) or not loop.last %}"
"{{ '<|im_end|>' + '\\n'}}"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}"
"{{ '<|im_start|>assistant\\n' }}"
"{% endif %}"
)
TEMPLATE_WITH_DEVELOPER = (
"{% for message in messages %}"
"{% if message['role'] == 'developer' %}"
"{{'<|im_start|>developer\\n' + message['content'] + '<|im_end|>\\n'}}"
"{% elif message['role'] == 'system' %}"
"{{'<|im_start|>system\\n' + message['content'] + '<|im_end|>\\n'}}"
"{% elif message['role'] == 'user' %}"
"{{'<|im_start|>user\\n' + message['content'] + '<|im_end|>\\n'}}"
"{% elif message['role'] == 'assistant' %}"
"{{'<|im_start|>assistant\\n' + message['content'] + '<|im_end|>\\n'}}"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|im_start|>assistant\\n' }}"
"{% endif %}"
)
STRICT_ROLE_TEMPLATE = (
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
"{{'<|im_start|>system\\n' + message['content'] + '<|im_end|>\\n'}}"
"{% elif message['role'] == 'user' %}"
"{{'<|im_start|>user\\n' + message['content'] + '<|im_end|>\\n'}}"
"{% elif message['role'] == 'assistant' %}"
"{{'<|im_start|>assistant\\n' + message['content'] + '<|im_end|>\\n'}}"
"{% else %}"
"{{ raise_exception('Unexpected message role: ' + message['role']) }}"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|im_start|>assistant\\n' }}"
"{% endif %}"
)
class TestDetectDeveloperRoleSupport:
    """Unit tests for ``_detect_developer_role_support``.

    ``TEMPLATE_WITH_DEVELOPER`` spells the role with single quotes, so the
    double-quote case is derived from it by rewriting the quoted literal.
    Each quoting style is therefore exercised by exactly one test.
    """

    def test_absent_in_chatml(self):
        """A generic ChatML template never names the developer role."""
        assert _detect_developer_role_support(CHATML_TEMPLATE) is False

    def test_present_double_quotes(self):
        """A template comparing against ``"developer"`` is detected."""
        template = TEMPLATE_WITH_DEVELOPER.replace("'developer'", '"developer"')
        # Guard against a silent no-op replace: the rewritten template must
        # contain only the double-quoted spelling.
        assert '"developer"' in template
        assert "'developer'" not in template
        assert _detect_developer_role_support(template) is True

    def test_present_single_quotes(self):
        """A template comparing against ``'developer'`` is detected."""
        # Make sure the fixture really uses the single-quoted spelling.
        assert "'developer'" in TEMPLATE_WITH_DEVELOPER
        assert '"developer"' not in TEMPLATE_WITH_DEVELOPER
        assert _detect_developer_role_support(TEMPLATE_WITH_DEVELOPER) is True

    def test_absent_in_strict_template(self):
        """A template that rejects unknown roles does not name developer."""
        assert _detect_developer_role_support(STRICT_ROLE_TEMPLATE) is False
class TestSafeApplyChatTemplateDeveloperRole:
@pytest.fixture
def model_config(self):
return ModelConfig(
"facebook/opt-125m",
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float16",
)
@pytest.fixture
def tokenizer(self):
return get_tokenizer("facebook/opt-125m")
def test_developer_converted_to_system_for_chatml(self, model_config, tokenizer):
conversation = [
{"role": "developer", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
]
result = safe_apply_chat_template(
model_config,
tokenizer,
conversation,
chat_template=CHATML_TEMPLATE,
tokenize=False,
add_generation_prompt=True,
)
assert "<|im_start|>system" in result
assert "You are a helpful assistant." in result
assert "<|im_start|>developer" not in result
def test_developer_preserved_when_template_supports_it(
self, model_config, tokenizer
):
conversation = [
{"role": "developer", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
]
result = safe_apply_chat_template(
model_config,
tokenizer,
conversation,
chat_template=TEMPLATE_WITH_DEVELOPER,
tokenize=False,
add_generation_prompt=True,
)
assert "<|im_start|>developer" in result
assert "You are a helpful assistant." in result
def test_developer_does_not_crash_strict_template(self, model_config, tokenizer):
conversation = [
{"role": "developer", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
]
result = safe_apply_chat_template(
model_config,
tokenizer,
conversation,
chat_template=STRICT_ROLE_TEMPLATE,
tokenize=False,
add_generation_prompt=True,
)
assert "<|im_start|>system" in result
assert "You are a helpful assistant." in result
def test_no_developer_messages_no_overhead(self, model_config, tokenizer):
conversation = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello"},
]
result = safe_apply_chat_template(
model_config,
tokenizer,
conversation,
chat_template=CHATML_TEMPLATE,
tokenize=False,
add_generation_prompt=True,
)
assert "<|im_start|>system" in result
assert "You are helpful." in result
def test_developer_at_non_first_position_consolidated(
self, model_config, tokenizer
):
conversation = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "developer", "content": "Be concise."},
{"role": "user", "content": "What is 2+2?"},
]
result = safe_apply_chat_template(
model_config,
tokenizer,
conversation,
chat_template=SYSTEM_FIRST_TEMPLATE,
tokenize=False,
add_generation_prompt=True,
)
assert "<|im_start|>system" in result
assert "You are helpful." in result
assert "Be concise." in result
assert "What is 2+2?" in result
def test_developer_only_no_prior_system(self, model_config, tokenizer):
conversation = [
{"role": "user", "content": "Hello"},
{"role": "developer", "content": "Be concise."},
{"role": "user", "content": "What is 2+2?"},
]
result = safe_apply_chat_template(
model_config,
tokenizer,
conversation,
chat_template=SYSTEM_FIRST_TEMPLATE,
tokenize=False,
add_generation_prompt=True,
)
assert "<|im_start|>system" in result
assert "Be concise." in result
SYSTEM_FIRST_TEMPLATE = (
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
"{% if not loop.first %}"
"{{ raise_exception('System message must be at the beginning.') }}"
"{% endif %}"
"{{'<|im_start|>system\\n' + message['content'] + '<|im_end|>\\n'}}"
"{% elif message['role'] == 'user' %}"
"{{'<|im_start|>user\\n' + message['content'] + '<|im_end|>\\n'}}"
"{% elif message['role'] == 'assistant' %}"
"{{'<|im_start|>assistant\\n' + message['content'] + '<|im_end|>\\n'}}"
"{% else %}"
"{{ raise_exception('Unexpected message role: ' + message['role']) }}"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|im_start|>assistant\\n' }}"
"{% endif %}"
)
class TestConsolidateSystemMessages:
    """Unit tests for ``_consolidate_system_messages``."""

    def test_no_system_messages_unchanged(self):
        """Without any system turn the conversation comes back equal."""
        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi"},
        ]
        assert _consolidate_system_messages(messages) == messages

    def test_single_system_at_start_unchanged(self):
        """A lone leading system turn needs no consolidation."""
        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello"},
        ]
        assert _consolidate_system_messages(messages) == messages

    def test_system_at_non_first_position_moved(self):
        """A system turn found later in the conversation is moved to the front."""
        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "system", "content": "You are helpful."},
        ]
        consolidated = _consolidate_system_messages(messages)
        head, tail = consolidated[0], consolidated[1]
        assert head["role"] == "system"
        assert head["content"] == "You are helpful."
        assert tail["role"] == "user"
        assert tail["content"] == "Hello"

    def test_multiple_system_messages_merged(self):
        """Several system turns collapse into one, joined by a blank line."""
        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello"},
            {"role": "system", "content": "Be concise."},
        ]
        consolidated = _consolidate_system_messages(messages)
        assert len(consolidated) == 2
        assert consolidated[0]["role"] == "system"
        assert consolidated[0]["content"] == "You are helpful.\n\nBe concise."
        assert consolidated[1]["role"] == "user"

    def test_list_content_handled(self):
        """List-style content parts are flattened to newline-joined text."""
        rule_parts = [
            {"type": "text", "text": "Rule 1."},
            {"type": "text", "text": "Rule 2."},
        ]
        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "system", "content": rule_parts},
        ]
        consolidated = _consolidate_system_messages(messages)
        assert consolidated[0]["role"] == "system"
        assert consolidated[0]["content"] == "Rule 1.\nRule 2."
        assert consolidated[1]["role"] == "user"

    def test_does_not_mutate_original(self):
        """The caller's list keeps its length and message order."""
        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "system", "content": "You are helpful."},
        ]
        length_before = len(messages)
        _consolidate_system_messages(messages)
        assert len(messages) == length_before
        assert messages[0]["role"] == "user"
        assert messages[1]["role"] == "system"
+2
View File
@@ -24,6 +24,7 @@ from vllm.utils.hashing import sha256
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.core.single_type_kv_cache_manager import register_all_kvcache_specs
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
@@ -160,6 +161,7 @@ def create_scheduler(
],
)
cache_config.num_gpu_blocks = num_blocks
register_all_kvcache_specs(vllm_config)
scheduler_cls = AsyncScheduler if async_scheduling else Scheduler
return scheduler_cls(
vllm_config=vllm_config,
@@ -21,19 +21,20 @@ dp_ep_configs=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
)
# We assume HMA enabled by default.
hybrid_ssm_configs=(
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
"VLLM_SSM_CONV_STATE_LAYOUT=DS GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
"VLLM_SSM_CONV_STATE_LAYOUT=DS PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
# GDN (Qwen3.5)
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=Qwen/Qwen3.5-0.8B"
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=Qwen/Qwen3.5-0.8B VLLM_SERVE_EXTRA_ARGS=--no-async-scheduling"
"VLLM_SSM_CONV_STATE_LAYOUT=DS GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=Qwen/Qwen3.5-0.8B"
"VLLM_SSM_CONV_STATE_LAYOUT=DS PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=Qwen/Qwen3.5-0.8B VLLM_SERVE_EXTRA_ARGS=--no-async-scheduling"
)
sw_attn_configs=(
# NOTE: gemma3 does not work with FlashInfer
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192" # SW model
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
)
# Select config array based on DP_EP env var
@@ -50,14 +51,6 @@ else
configs=("${tp_configs[@]}")
fi
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
# Append ENABLE_HMA_FLAG=1 to each config in the selected array
echo "ENABLE_HMA_FLAG is set, appending ENABLE_HMA_FLAG=1 to each config"
for i in "${!configs[@]}"; do
configs[$i]="ENABLE_HMA_FLAG=1 ${configs[$i]}"
done
fi
run_tests() {
local label=$1
local extra_args=$2
@@ -11,7 +11,7 @@ SCRIPT="v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh"
eagle3_config="SD_METHOD=eagle3 MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct SD_MODEL=RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3 NUM_SPEC_TOKENS=3"
# MTP: Qwen3.5-0.8B-Base with hybrid SSM flags.
mtp_config="SD_METHOD=mtp MODEL_NAME=Qwen/Qwen3.5-0.8B-Base SD_MODEL=Qwen/Qwen3.5-0.8B-Base NUM_SPEC_TOKENS=1 BLOCK_SIZE=32 MAX_MODEL_LEN=4096 VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 KV_BUFFER_DEVICES=cuda"
mtp_config="SD_METHOD=mtp MODEL_NAME=Qwen/Qwen3.5-0.8B-Base SD_MODEL=Qwen/Qwen3.5-0.8B-Base NUM_SPEC_TOKENS=1 BLOCK_SIZE=32 MAX_MODEL_LEN=4096 VLLM_SSM_CONV_STATE_LAYOUT=DS KV_BUFFER_DEVICES=cuda"
configs=(
"$eagle3_config"
@@ -5,11 +5,6 @@ set -xe
KV_BUFFER_DEVICE="cuda" # Default to cuda
ATTENTION_BACKEND="" # Default to empty (use vllm default)
CROSS_LAYERS_BLOCKS="False"
ENABLE_HMA_VAR="" # Default to empty (HMA disabled by default for kv connector)
# Check for ENABLE_HMA_FLAG environment variable
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
ENABLE_HMA_VAR="--no-disable-hybrid-kv-cache-manager"
fi
while [[ $# -gt 0 ]]; do
case $1 in
@@ -37,9 +32,6 @@ echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
if [[ -n "$ATTENTION_BACKEND" ]]; then
echo "Using attention backend: $ATTENTION_BACKEND"
fi
if [[ -n "$ENABLE_HMA_VAR" ]]; then
echo "HMA (Hybrid KV Cache Manager) enabled"
fi
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
echo "vLLM serve extra args: $VLLM_SERVE_EXTRA_ARGS"
fi
@@ -180,10 +172,6 @@ run_tests_for_model() {
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
# Add HMA flag if specified
if [[ -n "$ENABLE_HMA_VAR" ]]; then
BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR"
fi
FULL_CMD="$BASE_CMD"
eval "$FULL_CMD &"
@@ -232,10 +220,6 @@ run_tests_for_model() {
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
# Add HMA flag if specified
if [[ -n "$ENABLE_HMA_VAR" ]]; then
BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR"
fi
# DP-EP attention mode
if [[ -z "$DP_EP" ]]; then
@@ -27,7 +27,6 @@
# ROCM_AITER_UNIFIED_ATTN
# NVIDIA options: FLASH_ATTN, FLASHINFER
# VLLM_SSM_CONV_STATE_LAYOUT - SSM conv state layout (e.g. "DS" required for Mamba models)
# ENABLE_HMA_FLAG - set to 1 to enable hybrid KV cache manager
# VLLM_SERVE_EXTRA_ARGS - comma-separated extra args for vllm serve
set -ex
@@ -85,13 +84,7 @@ if [[ -z "${ATTENTION_BACKEND:-}" ]]; then
fi
echo "Using attention backend: ${ATTENTION_BACKEND}"
# ── HMA & extra serve args ────────────────────────────────────────────
ENABLE_HMA_VAR=""
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
ENABLE_HMA_VAR="--no-disable-hybrid-kv-cache-manager"
echo "HMA (Hybrid KV Cache Manager) enabled"
fi
# ── Extra serve args ─────────────────────────────────────────────────
EXTRA_SERVE_ARGS=()
if [[ -n "${VLLM_SERVE_EXTRA_ARGS:-}" ]]; then
@@ -258,7 +251,6 @@ run_test_for_device() {
--kv-transfer-config "$kv_config" \
--speculative-config "$PREFILL_SPEC_CONFIG" \
--attention-backend $ATTENTION_BACKEND \
${ENABLE_HMA_VAR} \
${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} &
local SERVER_PID=$!
@@ -298,7 +290,6 @@ run_test_for_device() {
--kv-transfer-config "$kv_config" \
--speculative-config "$DECODE_SPEC_CONFIG" \
--attention-backend $ATTENTION_BACKEND \
${ENABLE_HMA_VAR} \
${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} &
local SERVER_PID=$!
@@ -386,8 +386,6 @@ def test_fewer_blocks_with_hma(monkeypatch, model_name, sw_size):
"kv_transfer_config": kv_transfer_config,
"max_model_len": 2048,
"max_num_seqs": 1,
# NOTE: Make sure HMA is enabled
"disable_hybrid_kv_cache_manager": False,
"max_num_batched_tokens": 2048,
"enable_prefix_caching": False,
"block_size": block_size,
@@ -30,6 +30,9 @@ from vllm.v1.core.sched.output import (
NewRequestData,
SchedulerOutput,
)
from vllm.v1.core.single_type_kv_cache_manager import (
register_all_kvcache_specs,
)
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
@@ -68,6 +71,9 @@ def _make_kv_cache_config(
"""Build a KVCacheConfig with non-empty kv_cache_tensors."""
groups = []
tensors = []
register_all_kvcache_specs(
vllm_config=None
) # Ensure specs are registered for tests
for g in range(num_groups):
layer_names = [f"layer_{g}"]
groups.append(
+322
View File
@@ -0,0 +1,322 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, replace
from typing import Any
import pytest
import torch
from vllm.config import (
CacheConfig,
DeviceConfig,
VllmConfig,
)
from vllm.v1.core.single_type_kv_cache_manager import (
ChunkedLocalAttentionManager,
CrossAttentionManager,
FullAttentionManager,
MambaManager,
SingleTypeKVCacheManager,
SinkFullAttentionManager,
SlidingWindowManager,
register_all_kvcache_specs,
)
from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec,
CrossAttentionSpec,
FullAttentionSpec,
HiddenStateCacheSpec,
KVCacheSpec,
MambaSpec,
MLAAttentionSpec,
SinkFullAttentionSpec,
SlidingWindowMLASpec,
SlidingWindowSpec,
TQFullAttentionSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.kv_cache_spec_registry import (
_REGISTRY_KVCACHESPEC_LIST,
KVCacheSpecRegistry,
register_kv_cache_spec,
)
def make_vllm_config() -> VllmConfig:
return VllmConfig(
cache_config=CacheConfig(
block_size=64,
cache_dtype="bfloat16",
),
device_config=DeviceConfig(device="cpu"),
)
vllm_config = make_vllm_config()
register_all_kvcache_specs(vllm_config)
@pytest.fixture(autouse=True)
def restore_kv_cache_spec_registry():
    """Snapshot the spec registry before each test and put it back afterwards.

    Tests in this module register throw-away spec classes; restoring the
    snapshot keeps those registrations from leaking into later tests.
    """
    saved_entries = _REGISTRY_KVCACHESPEC_LIST.copy()
    yield
    # Reset in place so other modules holding a reference see the restored state.
    _REGISTRY_KVCACHESPEC_LIST.clear()
    _REGISTRY_KVCACHESPEC_LIST.update(saved_entries)
@dataclass(frozen=True)
class _TrulyUnregisteredSpec(KVCacheSpec):
"""
A spec that inherits directly from KVCacheSpec with no registered
ancestor in the MRO. Used to test that the registry correctly raises
when no entry can be found.
"""
@property
def page_size_bytes(self) -> int:
return self.block_size * 128
def max_memory_usage_bytes(self, _) -> int:
return 0
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
TQFullAttentionSpec: FullAttentionManager,
MLAAttentionSpec: FullAttentionManager,
HiddenStateCacheSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
SlidingWindowMLASpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
MambaSpec: MambaManager,
CrossAttentionSpec: CrossAttentionManager,
SinkFullAttentionSpec: SinkFullAttentionManager,
}
spec_uniform_base_map: dict[type[KVCacheSpec], type[KVCacheSpec]] = {
FullAttentionSpec: FullAttentionSpec,
TQFullAttentionSpec: FullAttentionSpec,
MLAAttentionSpec: FullAttentionSpec,
HiddenStateCacheSpec: FullAttentionSpec,
SlidingWindowSpec: SlidingWindowSpec,
SlidingWindowMLASpec: SlidingWindowMLASpec,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionSpec,
MambaSpec: MambaSpec,
CrossAttentionSpec: CrossAttentionSpec,
SinkFullAttentionSpec: FullAttentionSpec,
}
spec_args_map: dict[type[KVCacheSpec], dict[str, Any]] = {
FullAttentionSpec: dict(
block_size=64, num_kv_heads=8, head_size=128, dtype=torch.bfloat16
),
TQFullAttentionSpec: dict(
block_size=64,
num_kv_heads=8,
head_size=128,
dtype=torch.bfloat16,
tq_slot_size=256,
),
MLAAttentionSpec: dict(
block_size=64, num_kv_heads=1, head_size=128, dtype=torch.bfloat16
),
HiddenStateCacheSpec: dict(
block_size=64, num_kv_heads=1, head_size=128, dtype=torch.bfloat16
),
SlidingWindowSpec: dict(
block_size=64,
num_kv_heads=8,
head_size=128,
dtype=torch.bfloat16,
sliding_window=128,
),
SlidingWindowMLASpec: dict(
block_size=64,
num_kv_heads=1,
head_size=128,
dtype=torch.bfloat16,
sliding_window=128,
),
ChunkedLocalAttentionSpec: dict(
block_size=64,
num_kv_heads=8,
head_size=128,
dtype=torch.bfloat16,
attention_chunk_size=4,
),
MambaSpec: dict(
block_size=64,
shapes=((2, 512), (3, 32, 32)),
dtypes=(torch.float32, torch.float32),
mamba_cache_mode="align",
num_speculative_blocks=2,
),
CrossAttentionSpec: dict(
block_size=64, num_kv_heads=8, head_size=128, dtype=torch.bfloat16
),
SinkFullAttentionSpec: dict(
block_size=64, num_kv_heads=8, head_size=128, dtype=torch.bfloat16, sink_len=16
),
}
def make_spec(spec_cls: type[KVCacheSpec]) -> KVCacheSpec:
return spec_cls(**spec_args_map[spec_cls])
def are_uniform_specs(*specs: KVCacheSpec) -> bool:
return UniformTypeKVCacheSpecs.is_uniform_type(
{f"layer_{i}": spec for i, spec in enumerate(specs)}
)
class TestKVCacheSpecRegistry:
"""Test the core registry functionality."""
def test_builtin_kvcache_specs_registered(self):
assert set(spec_manager_map) <= set(_REGISTRY_KVCACHESPEC_LIST)
for spec_cls, manager in spec_manager_map.items():
spec = make_spec(spec_cls)
assert KVCacheSpecRegistry.get_manager_class(spec) is manager
assert (
KVCacheSpecRegistry.get_uniform_type_base_spec(spec)
is spec_uniform_base_map[spec_cls]
)
@pytest.mark.parametrize("spec_cls", list(spec_manager_map))
def test_custom_spec_register(self, spec_cls):
"""A decorated custom spec resolves to the declared manager."""
manager = spec_manager_map[spec_cls]
uniform_base_spec = spec_uniform_base_map[spec_cls]
@register_kv_cache_spec(
manager_class=manager,
uniform_type_base_spec=uniform_base_spec,
)
@dataclass(frozen=True, kw_only=True)
class _CustomSpec(spec_cls): # type: ignore[valid-type,misc]
custom_param: int = 16
spec = _CustomSpec(**spec_args_map[spec_cls], custom_param=100)
assert KVCacheSpecRegistry.get_manager_class(spec) is manager
assert KVCacheSpecRegistry.get_uniform_type_base_spec(spec) is uniform_base_spec
def test_custom_spec_register_requires_manager(self):
"""Invalid register decorator arguments fail early."""
with pytest.raises(AssertionError, match="manager_class is required"):
@register_kv_cache_spec(
uniform_type_base_spec=FullAttentionSpec,
)
@dataclass(frozen=True, kw_only=True)
class _CustomFullSpecWithoutManager(FullAttentionSpec):
custom_param: int = 16
def test_unregistered_spec_no_registered_parent_raises(self):
"""
A spec whose entire MRO contains no registered class resolves to None.
Runtime callers should use check_kv_cache_spec_registry to fail early.
Subclasses of registered specs intentionally do not fail; they inherit
their parent's manager via MRO walking.
"""
spec = _TrulyUnregisteredSpec(block_size=16)
# Plain lookups return None for an unregistered spec instead of raising.
assert KVCacheSpecRegistry.get_manager_class(spec) is None
assert KVCacheSpecRegistry.get_uniform_type_base_spec(spec) is None
# The explicit registry check is the call that raises, naming the layer.
with pytest.raises(
ValueError, match="Unsupported KV cache spec type for layer layer_0"
):
KVCacheSpecRegistry.check_kv_cache_spec_registry({"layer_0": spec})
# The uniform-type check rejects the same spec with an AssertionError.
with pytest.raises(AssertionError, match="Unsupported KV cache spec type"):
UniformTypeKVCacheSpecs.is_uniform_type({"layer_0": spec})
def test_unregistered_subclass_inherits_parent_manager(self):
"""
An unregistered subclass of a registered spec resolves via MRO
to its parent's manager — this is intentional registry behaviour.
"""
@dataclass(frozen=True, kw_only=True)
class _ImplicitlyInheritedSpec(FullAttentionSpec):
pass
spec = _ImplicitlyInheritedSpec(
block_size=16, num_kv_heads=8, head_size=128, dtype=torch.bfloat16
)
# MRO walk finds FullAttentionSpec → FullAttentionManager
assert KVCacheSpecRegistry.get_manager_class(spec) is FullAttentionManager
@pytest.mark.parametrize("spec_cls", list(spec_manager_map))
def test_builtin_specs_are_uniform_with_same_spec_type(self, spec_cls):
spec = make_spec(spec_cls)
assert are_uniform_specs(spec, replace(spec))
def test_full_attention_family_specs_are_uniform(self):
specs = [
make_spec(FullAttentionSpec),
make_spec(TQFullAttentionSpec),
make_spec(MLAAttentionSpec),
make_spec(HiddenStateCacheSpec),
make_spec(SinkFullAttentionSpec),
]
assert are_uniform_specs(*specs)
@pytest.mark.parametrize(
("spec_cls", "field", "value"),
[
(SlidingWindowSpec, "sliding_window", 256),
(SlidingWindowMLASpec, "sliding_window", 256),
(ChunkedLocalAttentionSpec, "attention_chunk_size", 8),
(MambaSpec, "num_speculative_blocks", 4),
],
)
def test_specs_with_type_specific_uniform_fields(self, spec_cls, field, value):
spec = make_spec(spec_cls)
changed_spec = replace(spec, **{field: value})
assert not are_uniform_specs(spec, changed_spec)
@pytest.mark.parametrize(
("left_cls", "right_cls"),
[
(FullAttentionSpec, CrossAttentionSpec),
(FullAttentionSpec, SlidingWindowSpec),
(FullAttentionSpec, ChunkedLocalAttentionSpec),
(FullAttentionSpec, MambaSpec),
(SlidingWindowMLASpec, SlidingWindowSpec),
(ChunkedLocalAttentionSpec, SlidingWindowSpec),
(MambaSpec, CrossAttentionSpec),
],
)
def test_different_uniform_groups_are_not_uniform(self, left_cls, right_cls):
assert not are_uniform_specs(make_spec(left_cls), make_spec(right_cls))
def test_different_block_sizes_are_not_uniform(self):
spec = make_spec(FullAttentionSpec)
assert not are_uniform_specs(spec, replace(spec, block_size=32))
def test_registered_custom_spec_uses_base_uniform_rule(self):
@register_kv_cache_spec(
manager_class=FullAttentionManager,
uniform_type_base_spec=FullAttentionSpec,
)
@dataclass(frozen=True, kw_only=True)
class _CustomFullSpec(FullAttentionSpec):
custom_param: int = 16
custom_spec = _CustomFullSpec(
block_size=64,
num_kv_heads=8,
head_size=128,
dtype=torch.bfloat16,
)
assert are_uniform_specs(custom_spec, make_spec(FullAttentionSpec))
+1 -1
View File
@@ -53,7 +53,7 @@ class SchedulerConfig:
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
max_num_scheduled_tokens: int | None = None
max_num_scheduled_tokens: int | None = Field(default=None, ge=0)
"""Maximum number of tokens that the scheduler may issue in a single iteration.
This is usually equal to max_num_batched_tokens, but can be smaller in cases
@@ -38,9 +38,19 @@ from vllm.utils.network_utils import (
is_valid_ipv6_address,
)
if envs.VLLM_USE_SPINLOOP_EXT:
from vllm.spinloop import spinloop
logger = init_logger(__name__)
SPINLOOP_EXT_ENABLED = False
if envs.VLLM_USE_SPINLOOP_EXT:
try:
from vllm.spinloop import spinloop
SPINLOOP_EXT_ENABLED = True
except ImportError:
logger.warning(
"spinloop extension could not be loaded, disabling VLLM_USE_SPINLOOP_EXT!"
)
SPINLOOP_TIMEOUT_SECONDS = 0.1
if TYPE_CHECKING:
@@ -82,9 +92,6 @@ def to_bytes_big(value: int, size: int) -> bytes:
return value.to_bytes(size, byteorder="big")
logger = init_logger(__name__)
LONG_WAIT_TIME_LOG_MSG = (
"No available shared memory broadcast block found "
"in %d seconds. This typically happens "
@@ -552,7 +559,7 @@ class MessageQueue:
written_flag = metadata_buffer[0]
return not (written_flag and read_count != self.buffer.n_reader)
if envs.VLLM_USE_SPINLOOP_EXT and not check():
if SPINLOOP_EXT_ENABLED and not check():
spinloop(metadata_buffer, check, timeout=SPINLOOP_TIMEOUT_SECONDS)
if not check():
@@ -673,7 +680,7 @@ class MessageQueue:
written_flag = metadata_buffer[0]
return not (not written_flag or read_flag)
if envs.VLLM_USE_SPINLOOP_EXT and not check():
if SPINLOOP_EXT_ENABLED and not check():
spinloop(
metadata_buffer[0 : self.local_reader_rank + 1],
check,
@@ -13,7 +13,6 @@ from vllm.v1.core.kv_cache_utils import (
)
from vllm.v1.core.single_type_kv_cache_manager import (
SingleTypeKVCacheManager,
spec_manager_map,
)
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
@@ -21,6 +20,7 @@ from vllm.v1.kv_cache_interface import (
KVCacheSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.kv_cache_spec_registry import KVCacheSpecRegistry
# Dummy placeholder hash for store_mask's template computation.
_DUMMY_BLOCK_HASH = BlockHash(b"\x00" * 32)
@@ -89,7 +89,10 @@ class MooncakeStoreCoordinator:
] = []
for i, g in enumerate(self.kv_cache_groups):
spec = _unwrap_spec(g.kv_cache_spec)
manager_cls = spec_manager_map[type(spec)]
manager_cls = KVCacheSpecRegistry.get_manager_class(spec)
assert manager_cls is not None, (
f"No manager registered for KVCacheSpec {spec}"
)
for existing_spec, group_ids, existing_cls in attention_groups:
if existing_spec == spec:
assert manager_cls is existing_cls
+10 -4
View File
@@ -194,10 +194,16 @@ class OpenAIServingModels:
lora_request.lora_name, lora_request.lora_path
)
) in str(e):
raise LoRAAdapterNotFoundError(
lora_request.lora_name, lora_request.lora_path
) from e
raise
return create_error_response(
LoRAAdapterNotFoundError(
lora_request.lora_name, lora_request.lora_path
)
)
return create_error_response(
message=str(e),
err_type="InternalServerError",
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)
self.lora_requests[lora_name] = lora_request
logger.info(
+7
View File
@@ -698,6 +698,13 @@ class Platform:
mamba_padding_pct,
)
@classmethod
def register_custom_kv_cache_specs(cls, vllm_config: "VllmConfig") -> None:
    """
    Register custom KVCacheSpec classes for the current platform.

    This implementation registers nothing and ignores ``vllm_config``.
    NOTE(review): presumably a hook for platform subclasses to override;
    confirm against the callers.
    """
    return None
@classmethod
def verify_model_arch(cls, model_arch: str) -> None:
"""
+69 -1
View File
@@ -450,6 +450,66 @@ def _detect_content_format(
return "openai"
@lru_cache(maxsize=32)
def _detect_developer_role_support(chat_template: str) -> bool:
return '"developer"' in chat_template or "'developer'" in chat_template
def _convert_developer_to_system(
conversation: list[ConversationMessage],
) -> list[ConversationMessage]:
converted: list[ConversationMessage] = []
for msg in conversation:
if msg["role"] == "developer":
new_msg = dict(msg)
new_msg["role"] = "system"
new_msg.pop("tools", None)
converted.append(new_msg) # type: ignore[arg-type]
else:
converted.append(msg)
return converted
def _consolidate_system_messages(
conversation: list[ConversationMessage],
) -> list[ConversationMessage]:
"""Merge all system messages into one at position 0.
Some chat templates (e.g. Qwen 3.6) require the system message to be the
very first message. After developer-to-system conversion, system messages
may appear at non-first positions; this merges them into a single message.
"""
system_contents: list[str] = []
non_system: list[ConversationMessage] = []
needs_consolidation = False
for i, msg in enumerate(conversation):
if msg["role"] == "system":
if i > 0 or system_contents:
needs_consolidation = True
content = msg.get("content", "")
if isinstance(content, list):
parts = []
for part in content:
if isinstance(part, dict) and "text" in part:
parts.append(part["text"])
elif isinstance(part, str):
parts.append(part)
content = "\n".join(parts)
if content:
system_contents.append(content)
else:
non_system.append(msg)
if not needs_consolidation:
return conversation
merged: ConversationMessage = {
"role": "system",
"content": "\n\n".join(system_contents),
}
return [merged, *non_system]
def _resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
@@ -653,7 +713,15 @@ def safe_apply_chat_template(
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
)
if any(
msg["role"] == "developer" for msg in conversation
) and not _detect_developer_role_support(chat_template):
conversation = _convert_developer_to_system(conversation)
conversation = _consolidate_system_messages(conversation)
logger.info_once(
"Chat template does not support the 'developer' message role. "
"Converting developer messages to 'system' role.",
)
resolved_kwargs = resolve_chat_template_kwargs(
tokenizer=tokenizer,
chat_template=chat_template,
+2 -2
View File
@@ -11,7 +11,6 @@ The fix streams string values incrementally as they arrive, providing a true
streaming experience for long content.
"""
import ast
import json
from collections.abc import Sequence
from typing import Any
@@ -42,6 +41,7 @@ from vllm.tool_parsers.utils import (
extract_types_from_schema,
find_tool_properties,
partial_tag_overlap,
safe_literal_eval,
)
logger = init_logger(__name__)
@@ -110,7 +110,7 @@ class Glm4MoeModelToolParser(ToolParser):
pass
try:
return ast.literal_eval(value)
return safe_literal_eval(value)
except (ValueError, SyntaxError):
pass
+3 -3
View File
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any
@@ -27,6 +26,7 @@ from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import safe_literal_eval
logger = init_logger(__name__)
@@ -183,13 +183,13 @@ class HYV3ToolParser(ToolParser):
@staticmethod
def _deserialize(value: str) -> Any:
"""Deserialize a string value using json.loads then ast.literal_eval."""
"""Deserialize a string value using json.loads then safe_literal_eval."""
try:
return json.loads(value)
except Exception:
pass
try:
return ast.literal_eval(value)
return safe_literal_eval(value)
except Exception:
pass
return value
+2 -3
View File
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any
@@ -28,7 +27,7 @@ from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import partial_tag_overlap
from vllm.tool_parsers.utils import partial_tag_overlap, safe_literal_eval
from vllm.utils import random_uuid
logger = init_logger(__name__)
@@ -116,7 +115,7 @@ def _parse_arguments(json_value: str) -> tuple[Any, bool]:
try:
parsed_value = json.loads(json_value)
except json.JSONDecodeError:
parsed_value = ast.literal_eval(json_value)
parsed_value = safe_literal_eval(json_value)
return parsed_value, True
except Exception:
return json_value, False
+2 -2
View File
@@ -11,7 +11,6 @@ The fix streams string values incrementally as they arrive, providing a true
streaming experience for long content.
"""
import ast
import json
from collections.abc import Sequence
from typing import Any
@@ -41,6 +40,7 @@ from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import safe_literal_eval
logger = init_logger(__name__)
@@ -106,7 +106,7 @@ class PoolsideV1ToolParser(ToolParser):
pass
try:
return ast.literal_eval(value)
return safe_literal_eval(value)
except (ValueError, SyntaxError):
pass
+2 -3
View File
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any
@@ -26,7 +25,7 @@ from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import find_tool_properties
from vllm.tool_parsers.utils import find_tool_properties, safe_literal_eval
logger = init_logger(__name__)
@@ -824,7 +823,7 @@ class StreamingXMLToolCallParser:
try:
parsed_value = json.loads(raw_for_parse)
except json.JSONDecodeError:
parsed_value = ast.literal_eval(raw_for_parse)
parsed_value = safe_literal_eval(raw_for_parse)
output_arguments = json.dumps(parsed_value, ensure_ascii=False)
except Exception:
# Fallback: output as string as-is
+2 -2
View File
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any
@@ -23,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
from vllm.tool_parsers.utils import safe_literal_eval
logger = init_logger(__name__)
@@ -1016,7 +1016,7 @@ class StreamingXMLToolCallParser:
raw_for_parse = raw_text + "\n"
else:
raw_for_parse = raw_text
parsed_value = ast.literal_eval(raw_for_parse)
parsed_value = safe_literal_eval(raw_for_parse)
output_arguments = json.dumps(parsed_value, ensure_ascii=False)
except Exception:
# Fallback: output as string as-is
+7
View File
@@ -3,6 +3,7 @@
import ast
import json
import warnings
from json import JSONDecodeError, JSONDecoder
from typing import Any, TypeAlias
@@ -31,6 +32,12 @@ Tool: TypeAlias = ChatCompletionToolsParam | ResponsesTool
logger = init_logger(__name__)
def safe_literal_eval(text: str):
    """Evaluate ``text`` with ``ast.literal_eval`` while muting SyntaxWarning.

    The warning filter is scoped to this call. Exceptions raised by
    ``ast.literal_eval`` propagate unchanged.
    """
    with warnings.catch_warnings():
        warnings.simplefilter(action="ignore", category=SyntaxWarning)
        value = ast.literal_eval(text)
    return value
def partial_tag_overlap(text: str, tag: str) -> int:
"""Length of the longest prefix of *tag* that matches a suffix of *text*.
+4
View File
@@ -33,6 +33,7 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.kv_cache_spec_registry import KVCacheSpecRegistry
from vllm.v1.request import Request
from vllm.v1.utils import tensor_data
@@ -1991,6 +1992,9 @@ def get_kv_cache_configs(
"across workers. This is not supported yet."
)
# Check if the KV cache specs are registered correctly.
# This guards against layers being initialized with unregistered specs.
KVCacheSpecRegistry.check_kv_cache_spec_registry(merged_kv_cache_specs)
# Get global KV cache groups. This also handles spec unification for
# hybrid models when disable_hybrid_kv_cache_manager is enabled.
# After this call, merged_kv_cache_specs may be modified in-place.
+1 -1
View File
@@ -104,7 +104,7 @@ class Scheduler(SchedulerInterface):
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = (
self.scheduler_config.max_num_scheduled_tokens
if self.scheduler_config.max_num_scheduled_tokens
if self.scheduler_config.max_num_scheduled_tokens is not None
else self.scheduler_config.max_num_batched_tokens
)
self.max_model_len = vllm_config.model_config.max_model_len
+80 -15
View File
@@ -25,6 +25,7 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec,
TQFullAttentionSpec,
)
from vllm.v1.kv_cache_spec_registry import KVCacheSpecRegistry
from vllm.v1.request import Request
@@ -1247,27 +1248,30 @@ class SinkFullAttentionManager(FullAttentionManager):
self.sink_blocks = self.block_pool.free_block_queue.popleft_n(num_sink_block)
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
TQFullAttentionSpec: FullAttentionManager,
MLAAttentionSpec: FullAttentionManager,
HiddenStateCacheSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
SlidingWindowMLASpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
MambaSpec: MambaManager,
CrossAttentionSpec: CrossAttentionManager,
SinkFullAttentionSpec: SinkFullAttentionManager,
}
def get_manager_for_kv_cache_spec(
kv_cache_spec: KVCacheSpec,
max_num_batched_tokens: int,
max_model_len: int,
**kwargs,
) -> SingleTypeKVCacheManager:
manager_class = spec_manager_map[type(kv_cache_spec)]
"""
Get the appropriate manager for a given KVCacheSpec.
Uses the KVCacheSpecRegistry to look up the manager class, supporting
both built-in and custom specs registered via @register_kv_cache_spec
and KVCacheSpecRegistry.register.
Args:
kv_cache_spec: The KVCacheSpec instance
max_num_batched_tokens: The maximum number of tokens in a batch
max_model_len: The maximum context length the model could serve
Returns:
An instance of the appropriate SingleTypeKVCacheManager subclass
"""
manager_class = KVCacheSpecRegistry.get_manager_class(kv_cache_spec)
assert manager_class is not None, (
f"No manager registered for KVCacheSpec {type(kv_cache_spec)}"
)
# SlidingWindow / ChunkedLocalAttention managers recycle blocks across
# chunks; the runtime admission cap must match the recycling-aware bound
# the startup pool sizer uses (single source of truth: the spec method).
@@ -1280,3 +1284,64 @@ def get_manager_for_kv_cache_spec(
)
manager = manager_class(kv_cache_spec, **kwargs)
return manager
def register_all_kvcache_specs(vllm_config):
    """Register the built-in KVCacheSpec classes, then the platform's own.

    Each row below is (spec class, manager class, uniform-type base spec) and
    is registered in order. Afterwards the current platform's
    ``register_custom_kv_cache_specs`` hook is called with ``vllm_config``.
    """
    builtin_registrations = (
        # Specs that act as their own grouping base.
        (FullAttentionSpec, FullAttentionManager, FullAttentionSpec),
        (SlidingWindowSpec, SlidingWindowManager, SlidingWindowSpec),
        (SlidingWindowMLASpec, SlidingWindowManager, SlidingWindowMLASpec),
        (MambaSpec, MambaManager, MambaSpec),
        (
            ChunkedLocalAttentionSpec,
            ChunkedLocalAttentionManager,
            ChunkedLocalAttentionSpec,
        ),
        (CrossAttentionSpec, CrossAttentionManager, CrossAttentionSpec),
        # FullAttentionSpec subclasses — grouped with FullAttentionSpec
        (TQFullAttentionSpec, FullAttentionManager, FullAttentionSpec),
        (MLAAttentionSpec, FullAttentionManager, FullAttentionSpec),
        # NOTE(Mengqing): HiddenStateCacheSpec won't take part in
        # grouping, thus the uniform_type_base_spec is just a
        # placeholder.
        (HiddenStateCacheSpec, FullAttentionManager, FullAttentionSpec),
        (SinkFullAttentionSpec, SinkFullAttentionManager, FullAttentionSpec),
    )
    for spec_cls, manager_cls, base_spec_cls in builtin_registrations:
        KVCacheSpecRegistry.register(
            spec_cls,
            manager_cls,
            uniform_type_base_spec=base_spec_cls,
        )

    from vllm.platforms import current_platform

    current_platform.register_custom_kv_cache_specs(vllm_config)
+4
View File
@@ -52,6 +52,7 @@ from vllm.v1.core.kv_cache_utils import (
)
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.single_type_kv_cache_manager import register_all_kvcache_specs
from vllm.v1.engine import (
EEP_NOTIFICATION_CALL_ID,
EEPNotificationType,
@@ -235,6 +236,9 @@ class EngineCore:
def _initialize_kv_caches(self, vllm_config: VllmConfig) -> KVCacheConfig:
start = time.time()
# register all kvcache specs in enginecore process.
register_all_kvcache_specs(vllm_config)
# Get all kv cache needed by the model
kv_cache_specs = self.model_executor.get_kv_cache_specs()
+57 -42
View File
@@ -17,6 +17,7 @@ from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import get_dtype_size, nvfp4_kv_cache_full_dim
from vllm.v1.attention.backends.registry import MambaAttentionBackendEnum
from vllm.v1.kv_cache_spec_registry import KVCacheSpecRegistry
if TYPE_CHECKING:
from vllm.config import VllmConfig
@@ -139,6 +140,21 @@ class KVCacheSpec:
)
return copy.deepcopy(specs[0])
def is_uniform_with_collection(
    self, kv_cache_specs: dict[str, KVCacheSpec]
) -> bool:
    """
    Report whether every layer's spec belongs to this spec's uniform type.

    The uniform-type base class is looked up in KVCacheSpecRegistry for
    ``self``; the result is True only if each value in ``kv_cache_specs``
    is an instance of that base class.
    """
    base_spec_cls = KVCacheSpecRegistry.get_uniform_type_base_spec(self)
    assert base_spec_cls is not None, (
        f"Unsupported KV cache spec type: {type(self)}. "
        "Please register it using @register_kv_cache_spec decorator."
    )
    for layer_spec in kv_cache_specs.values():
        if not isinstance(layer_spec, base_spec_cls):
            return False
    return True
@dataclass(frozen=True, kw_only=True)
class AttentionSpec(KVCacheSpec):
@@ -430,6 +446,15 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
)
return max_blocks * self.page_size_bytes
def is_uniform_with_collection(
    self, kv_cache_specs: dict[str, KVCacheSpec]
) -> bool:
    """
    True only if every spec is a ChunkedLocalAttentionSpec whose
    ``attention_chunk_size`` equals this spec's.
    """
    for layer_spec in kv_cache_specs.values():
        if not isinstance(layer_spec, ChunkedLocalAttentionSpec):
            return False
        if not (layer_spec.attention_chunk_size == self.attention_chunk_size):
            return False
    return True
@dataclass(frozen=True, kw_only=True)
class SlidingWindowSpec(AttentionSpec):
@@ -493,6 +518,15 @@ class SlidingWindowSpec(AttentionSpec):
)
return max_blocks * self.page_size_bytes
def is_uniform_with_collection(
    self, kv_cache_specs: dict[str, KVCacheSpec]
) -> bool:
    """
    True only if every spec is a SlidingWindowSpec whose ``sliding_window``
    equals this spec's.
    """
    for layer_spec in kv_cache_specs.values():
        if not isinstance(layer_spec, SlidingWindowSpec):
            return False
        if not (layer_spec.sliding_window == self.sliding_window):
            return False
    return True
@dataclass(frozen=True, kw_only=True)
class SlidingWindowMLASpec(SlidingWindowSpec):
@@ -558,6 +592,15 @@ class SlidingWindowMLASpec(SlidingWindowSpec):
model_version=model_version_set.pop(),
)
def is_uniform_with_collection(
    self, kv_cache_specs: dict[str, KVCacheSpec]
) -> bool:
    """
    True only if every spec is a SlidingWindowMLASpec whose
    ``sliding_window`` equals this spec's.
    """
    for layer_spec in kv_cache_specs.values():
        if not isinstance(layer_spec, SlidingWindowMLASpec):
            return False
        if not (layer_spec.sliding_window == self.sliding_window):
            return False
    return True
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
@@ -590,6 +633,15 @@ class MambaSpec(KVCacheSpec):
else:
return self.page_size_bytes * (1 + self.num_speculative_blocks)
def is_uniform_with_collection(
    self, kv_cache_specs: dict[str, KVCacheSpec]
) -> bool:
    """
    True only if every spec is a MambaSpec whose ``num_speculative_blocks``
    equals this spec's.
    """
    for layer_spec in kv_cache_specs.values():
        if not isinstance(layer_spec, MambaSpec):
            return False
        if not (layer_spec.num_speculative_blocks == self.num_speculative_blocks):
            return False
    return True
@dataclass(frozen=True)
class EncoderOnlyAttentionSpec(AttentionSpec):
@@ -689,53 +741,16 @@ class UniformTypeKVCacheSpecs(KVCacheSpec):
def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
"""
Whether all layers have the same type of KV cache spec.
Uses the registry to determine grouping base classes, so custom specs
that inherit from FullAttentionSpec are treated as full attention.
"""
block_sizes = set(spec.block_size for spec in kv_cache_specs.values())
if len(block_sizes) > 1:
# Different block sizes, not uniform.
return False
one_spec = next(iter(kv_cache_specs.values()))
# NOTE: Check subclasses before parent classes since isinstance()
# returns True for subclasses.
if isinstance(one_spec, SlidingWindowMLASpec):
# SlidingWindowMLASpec is uniform if all specs are SlidingWindowMLASpec
# with the same sliding_window size.
return all(
isinstance(spec, SlidingWindowMLASpec)
and spec.sliding_window == one_spec.sliding_window
for spec in kv_cache_specs.values()
)
elif isinstance(one_spec, FullAttentionSpec):
return all(
isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values()
)
elif isinstance(one_spec, CrossAttentionSpec):
return all(
isinstance(spec, CrossAttentionSpec) for spec in kv_cache_specs.values()
)
elif isinstance(one_spec, SlidingWindowSpec):
return all(
isinstance(spec, SlidingWindowSpec)
and spec.sliding_window == one_spec.sliding_window
for spec in kv_cache_specs.values()
)
elif isinstance(one_spec, ChunkedLocalAttentionSpec):
return all(
isinstance(spec, ChunkedLocalAttentionSpec)
and spec.attention_chunk_size == one_spec.attention_chunk_size
for spec in kv_cache_specs.values()
)
elif isinstance(one_spec, MambaSpec):
return all(
isinstance(spec, MambaSpec)
and spec.num_speculative_blocks == one_spec.num_speculative_blocks
for spec in kv_cache_specs.values()
)
else:
# NOTE(Chen): Please add new branches for new KV cache spec types.
raise NotImplementedError(
f"Unsupported KV cache spec type: {type(one_spec)}"
)
first_spec = next(iter(kv_cache_specs.values()))
return first_spec.is_uniform_with_collection(kv_cache_specs)
@classmethod
def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Self | None:
+209
View File
@@ -0,0 +1,209 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Registry for KVCacheSpec types and their associated managers.
This module provides a pluggable architecture for registering custom KVCacheSpec
subclasses without modifying vLLM core code. Out-of-tree platforms can define
custom specs and managers by using the @register_kv_cache_spec decorator.
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING
from vllm.logger import init_logger
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.core.single_type_kv_cache_manager import SingleTypeKVCacheManager
from vllm.v1.kv_cache_interface import KVCacheSpec
@dataclass(frozen=True)
class KVCacheSpecMetadata:
    """Immutable record describing one registered KVCacheSpec class."""

    # The spec class this record was registered for.
    kvcache_spec_cls: type["KVCacheSpec"]
    # The manager class to use for specs of this class.
    manager_class: type["SingleTypeKVCacheManager"]
    # Base spec class used for grouping compatibility checks: specs that
    # share the same uniform_type_base_spec are placed in one KV cache group.
    uniform_type_base_spec: type["KVCacheSpec"]
_REGISTRY_KVCACHESPEC_LIST: dict[type["KVCacheSpec"], KVCacheSpecMetadata] = {}
class KVCacheSpecRegistry:
"""Global registry for KVCacheSpec types and their associated managers."""
@classmethod
def _ensure_registered(cls, vllm_config=None) -> None:
"""
Run full KVCacheSpec registration if the registration is not done.
"""
if _REGISTRY_KVCACHESPEC_LIST:
return
if vllm_config is None:
from vllm.config import get_current_vllm_config_or_none
vllm_config = get_current_vllm_config_or_none()
# lazy import to avoid circular dependency
from vllm.v1.core.single_type_kv_cache_manager import (
register_all_kvcache_specs,
)
register_all_kvcache_specs(vllm_config)
@classmethod
def register(
cls,
kvcache_spec_cls: type["KVCacheSpec"],
manager_class: type["SingleTypeKVCacheManager"] | None = None,
uniform_type_base_spec: type["KVCacheSpec"] | None = None,
) -> None:
"""
Register a KVCacheSpec class with its manager and base spec.
Args:
kvcache_spec_cls: The KVCacheSpec subclass to register
manager_class: The SingleTypeKVCacheManager to use for this spec
uniform_type_base_spec: The base spec class for grouping compatibility.
instead of being grouped to different kvcache group, `kvcache_spec_cls`
and `uniform_type_base_spec` will be trated as uniform type.
If None, defaults to kvcache_spec_cls itself (for built-in base specs).
"""
assert manager_class is not None, "manager_class is required"
if uniform_type_base_spec is None:
uniform_type_base_spec = kvcache_spec_cls
assert issubclass(kvcache_spec_cls, uniform_type_base_spec), (
f"{kvcache_spec_cls.__name__} must inherit from its declared "
f"uniform_type_base_spec {uniform_type_base_spec.__name__}."
)
if kvcache_spec_cls in _REGISTRY_KVCACHESPEC_LIST:
registered_spec = _REGISTRY_KVCACHESPEC_LIST[kvcache_spec_cls]
is_same_registration = (
manager_class == registered_spec.manager_class
and uniform_type_base_spec == registered_spec.uniform_type_base_spec
)
assert is_same_registration, (
f"Conflicting registration for KVCacheSpec "
f": {kvcache_spec_cls.__name__}"
)
_REGISTRY_KVCACHESPEC_LIST[kvcache_spec_cls] = KVCacheSpecMetadata(
kvcache_spec_cls=kvcache_spec_cls,
manager_class=manager_class,
uniform_type_base_spec=uniform_type_base_spec,
)
@classmethod
def get_manager_class(
cls, kvcache_spec: "KVCacheSpec"
) -> type["SingleTypeKVCacheManager"] | None:
"""
Get the single type kvcache manager class for a given kvcache spec instance.
Args:
kvcache_spec: A KVCacheSpec instance
Returns:
The SingleTypeKVCacheManager class to use for this kvcache_spec
"""
cls._ensure_registered()
kvcache_spec_cls = type(kvcache_spec)
# Walk up the MRO to find a registered base class
for base in kvcache_spec_cls.__mro__:
if base in _REGISTRY_KVCACHESPEC_LIST:
return _REGISTRY_KVCACHESPEC_LIST[base].manager_class
return None
@classmethod
def get_uniform_type_base_spec(
cls, kvcache_spec: "KVCacheSpec"
) -> type["KVCacheSpec"] | None:
"""
Get the base kvcache spec class for grouping compatibility checks.
KVCacheSpecs with uniform_type_base_spec will be trated as one group.
Args:
kvcache_spec: A KVCacheSpec instance
Returns:
The base KVCacheSpec class for checking uniform type kvcache specs
"""
cls._ensure_registered()
kvcache_spec_cls = type(kvcache_spec)
# Walk up the MRO to find a registered base spec
for base in kvcache_spec_cls.__mro__:
if base in _REGISTRY_KVCACHESPEC_LIST:
return _REGISTRY_KVCACHESPEC_LIST[base].uniform_type_base_spec
return None
@classmethod
def check_kv_cache_spec_registry(
cls, kv_cache_spec: dict[str, "KVCacheSpec"]
) -> None:
"""
Check if the KVCacheSpecs of each layer are registered as expected.
"""
cls._ensure_registered()
for layer_name, spec in kv_cache_spec.items():
# use raise instead of assert to make it effective in production environment
if cls.get_uniform_type_base_spec(spec) is None:
raise ValueError(
f"Unsupported KV cache spec type for layer {layer_name}: "
f"{type(spec)}. Please register it using "
f"@register_kv_cache_spec decorator."
)
if cls.get_manager_class(spec) is None:
raise ValueError(
f"No manager found for KV cache spec type for layer "
f"{layer_name}: {type(spec)}. Please register it using "
f"@register_kv_cache_spec decorator."
)
def register_kv_cache_spec(
manager_class: type["SingleTypeKVCacheManager"] | None = None,
uniform_type_base_spec: type["KVCacheSpec"] | None = None,
):
"""
Decorator to register a custom KVCacheSpec class.
Args:
manager_class: The SingleTypeKVCacheManager to use for this spec.
Required for all registered specs.
uniform_type_base_spec: The base spec class for uniform type kv cache specs
compatibility. If None, the spec is treated as a new base
type.
Examples:
- Register a new specs:
@register_kv_cache_spec(
manager_class=FullAttentionManager,
uniform_type_base_spec=FullAttentionSpec
)
@dataclass(frozen=True, kw_only=True)
class CustomFullAttentionSpec(FullAttentionSpec):
pass
"""
def decorator(kvcache_spec_cls: type["KVCacheSpec"]) -> type["KVCacheSpec"]:
KVCacheSpecRegistry.register(
kvcache_spec_cls=kvcache_spec_cls,
manager_class=manager_class,
uniform_type_base_spec=uniform_type_base_spec,
)
return kvcache_spec_cls
return decorator
+74 -49
View File
@@ -19,6 +19,69 @@ if HAS_TRITON:
logger = init_logger(__name__)
_FLASHINFER_MIN_VERSION = "0.2.3"
def flashinfer_sampler_supported() -> bool:
    """Report whether FlashInfer's top-p/top-k sampler can be used.

    Returns False when the platform is not CUDA, when
    ``VLLM_USE_FLASHINFER_SAMPLER`` is falsy (logged at info level), or when
    FlashInfer is unusable and the env var was not set explicitly (logged as
    a warning). "Unusable" means an unsupported compute capability, a missing
    flashinfer package, or one older than ``_FLASHINFER_MIN_VERSION``.

    Raises:
        RuntimeError: If ``VLLM_USE_FLASHINFER_SAMPLER`` was explicitly set
            (and is truthy) but FlashInfer is unusable.

    Note: callers must additionally ensure ``logprobs_mode`` doesn't require
    post-top-k/top-p logits/logprobs for any request whose logprobs will be
    returned in this step, since FlashInfer doesn't expose those.
    """
    if not current_platform.is_cuda():
        return False

    if not envs.VLLM_USE_FLASHINFER_SAMPLER:
        logger.info_once(
            "FlashInfer top-p/top-k sampling disabled via "
            "VLLM_USE_FLASHINFER_SAMPLER=0."
        )
        return False

    from vllm.v1.attention.backends.flashinfer import FlashInferBackend

    capability = current_platform.get_device_capability()
    assert capability is not None

    def _find_blocker() -> str | None:
        # Returns a description of what rules FlashInfer out, or None.
        if not FlashInferBackend.supports_compute_capability(capability):
            return f"unsupported compute capability {capability.as_version_str()}"
        try:
            import flashinfer

            if version.parse(flashinfer.__version__) < version.parse(
                _FLASHINFER_MIN_VERSION
            ):
                return (
                    f"flashinfer {flashinfer.__version__} is too old "
                    f"(>={_FLASHINFER_MIN_VERSION} required)"
                )
        except ImportError:
            return "flashinfer is not installed"
        return None

    unsupported_reason = _find_blocker()

    if unsupported_reason is None:
        logger.info_once("Using FlashInfer for top-p & top-k sampling.", scope="global")
        return True

    # An explicit opt-in that cannot be honored is an error, not a fallback.
    if envs.is_set("VLLM_USE_FLASHINFER_SAMPLER"):
        raise RuntimeError(
            f"FlashInfer top-p/top-k sampling unavailable: {unsupported_reason}. "
            "Unset VLLM_USE_FLASHINFER_SAMPLER=1."
        )

    logger.warning_once(
        "FlashInfer top-p/top-k sampling unavailable: %s; falling back. "
        "Set VLLM_USE_FLASHINFER_SAMPLER=0 to silence.",
        unsupported_reason,
    )
    return False
class TopKTopPSampler(nn.Module):
"""
Module that performs optional top-k and top-p filtering followed by
@@ -30,49 +93,16 @@ class TopKTopPSampler(nn.Module):
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
super().__init__()
self.logprobs_mode = logprobs_mode
# flashinfer optimization does not apply if intermediate
# logprobs/logits after top_k/top_p need to be returned
if (
logprobs_mode not in ("processed_logits", "processed_logprobs")
and current_platform.is_cuda()
):
if envs.VLLM_USE_FLASHINFER_SAMPLER:
from vllm.v1.attention.backends.flashinfer import FlashInferBackend
capability = current_platform.get_device_capability()
assert capability is not None
if FlashInferBackend.supports_compute_capability(capability):
logger.info_once(
"Using FlashInfer for top-p & top-k sampling.",
scope="global",
)
self.forward = self.forward_cuda
elif envs.is_set("VLLM_USE_FLASHINFER_SAMPLER"):
# User explicitly opted in but the GPU can't run FlashInfer.
capability_str = capability.as_version_str()
raise RuntimeError(
"FlashInfer does not support compute capability "
f"{capability_str}, unset VLLM_USE_FLASHINFER_SAMPLER=1."
)
else:
# Default-on path; hardware can't run FlashInfer →
# quietly fall back to the PyTorch-native sampler
# instead of failing server startup.
logger.warning_once(
"FlashInfer top-p/top-k sampling not supported on "
"compute capability %s; falling back to PyTorch-native "
"sampler. Set VLLM_USE_FLASHINFER_SAMPLER=0 to silence.",
capability.as_version_str(),
)
self.forward = self.forward_native
else:
# User explicitly set VLLM_USE_FLASHINFER_SAMPLER=0.
logger.info_once(
"FlashInfer top-p/top-k sampling disabled via "
"VLLM_USE_FLASHINFER_SAMPLER=0; using PyTorch-native sampler."
)
self.forward = self.forward_native
if current_platform.is_cuda():
# FlashInfer doesn't expose post-top-k/top-p logits/logprobs,
# so it can't be used when the configured mode requires them.
can_use_flashinfer = (
logprobs_mode not in ("processed_logits", "processed_logprobs")
and flashinfer_sampler_supported()
)
self.forward = (
self.forward_cuda if can_use_flashinfer else self.forward_native
)
elif current_platform.is_cpu():
arch = current_platform.get_cpu_architecture()
# Fall back to native implementation for POWERPC and RISCV.
@@ -417,7 +447,7 @@ def flashinfer_sample(
logits: torch.Tensor,
k: torch.Tensor | None,
p: torch.Tensor | None,
generators: dict[int, torch.Generator],
generators: dict[int, torch.Generator] = {}, # noqa
) -> torch.Tensor:
"""Sample from the logits using FlashInfer.
@@ -431,11 +461,6 @@ def flashinfer_sample(
"""
import flashinfer
if version.parse(flashinfer.__version__) < version.parse("0.2.3"):
raise ImportError(
"FlashInfer version >= 0.2.3 required for top-k and top-p sampling. "
)
assert not (k is None and p is None)
if k is None:
# Top-p only.
+46 -14
View File
@@ -7,6 +7,11 @@ import torch
import vllm.envs as envs
from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.ops.topk_topp_sampler import (
apply_top_k_top_p,
flashinfer_sample,
flashinfer_sampler_supported,
)
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.bad_words import BadWordsState
@@ -45,6 +50,7 @@ class Sampler:
self.bad_words_state = BadWordsState(req_states)
self.logprob_token_ids_state = LogprobTokenIdsState(max_num_reqs, device)
self.num_speculative_tokens = num_speculative_tokens
self.use_flashinfer = flashinfer_sampler_supported()
def add_request(
self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
@@ -77,6 +83,13 @@ class Sampler:
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature.
num_nans = get_num_nans(logits) if self.compute_nans else None
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
max_per_req_token_ids = self.logprob_token_ids_state.max_num_token_ids(
idx_mapping_np
)
return_logprobs = max_num_logprobs != NO_LOGPROBS or max_per_req_token_ids > 0
sampled, processed_logits = self.sample(
logits,
expanded_idx_mapping,
@@ -84,13 +97,10 @@ class Sampler:
pos,
input_ids,
expanded_local_pos,
return_logprobs=return_logprobs,
)
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
max_per_req_token_ids = self.logprob_token_ids_state.max_num_token_ids(
idx_mapping_np
)
if max_num_logprobs != NO_LOGPROBS or max_per_req_token_ids > 0:
if return_logprobs:
if self.logprobs_mode == "processed_logprobs":
logits = processed_logits
expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
@@ -128,6 +138,7 @@ class Sampler:
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
skip_top_k_top_p: bool = False,
) -> torch.Tensor:
# Copy logits to a new FP32 tensor.
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
@@ -163,6 +174,9 @@ class Sampler:
# Apply min_p in place.
self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np)
if skip_top_k_top_p:
return logits
# Apply top_k and/or top_p. This might or might not return a new tensor.
return self.sampling_states.apply_top_k_top_p(
logits, expanded_idx_mapping, idx_mapping_np
@@ -176,6 +190,7 @@ class Sampler:
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
return_logprobs: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
processed_logits = self.apply_sampling_params(
logits,
@@ -184,16 +199,33 @@ class Sampler:
pos,
input_ids,
expanded_local_pos,
skip_top_k_top_p=True,
)
top_k, top_p = self.sampling_states.get_top_k_top_p(
expanded_idx_mapping, idx_mapping_np
)
use_flashinfer = self.use_flashinfer and not (
# Don't use FI sampler if no requests use top_k/top_p, if there are
# any greedy requests or per-request seeds, or if post-processed
# logprobs need to be returned for any requests.
(top_k is None and top_p is None)
or (return_logprobs and self.logprobs_mode == "processed_logprobs")
or self.sampling_states.any_greedy(idx_mapping_np)
or self.sampling_states.any_explicit_seed(idx_mapping_np)
)
# Sample the next token.
sampled = gumbel_sample(
processed_logits,
expanded_idx_mapping,
self.sampling_states.temperature.gpu,
self.sampling_states.seeds.gpu,
pos,
apply_temperature=False,
use_fp64=self.use_fp64_gumbel,
)
if use_flashinfer:
sampled = flashinfer_sample(processed_logits, top_k, top_p).to(torch.int64)
else:
processed_logits = apply_top_k_top_p(processed_logits, top_k, top_p)
sampled = gumbel_sample(
processed_logits,
expanded_idx_mapping,
self.sampling_states.temperature.gpu,
self.sampling_states.seeds.gpu,
pos,
apply_temperature=False,
use_fp64=self.use_fp64_gumbel,
)
return sampled, processed_logits
+21 -6
View File
@@ -24,6 +24,9 @@ class SamplingStates:
self.top_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.min_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.seeds = UvaBackedTensor(max_num_reqs, dtype=torch.int64)
# Tracks whether `seed` was set explicitly by the user, so callers
# can fall back from RNG paths that don't honor per-request seeds.
self.seeds_set = np.zeros(max_num_reqs, dtype=bool)
# Initialize top_k and top_p manually because 0 is an invalid value for them.
self.top_k.np.fill(self.vocab_size)
@@ -45,6 +48,7 @@ class SamplingStates:
self.min_p.np[req_idx] = sampling_params.min_p
seed = sampling_params.seed
self.seeds_set[req_idx] = seed is not None
if seed is None:
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
self.seeds.np[req_idx] = seed
@@ -85,20 +89,31 @@ class SamplingStates:
return
apply_min_p(logits, expanded_idx_mapping, self.min_p.gpu)
def get_top_k_top_p(
self, expanded_idx_mapping: torch.Tensor, idx_mapping_np: np.ndarray
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
do_top_p = np.any(self.top_p.np[idx_mapping_np] != 1.0)
top_k = self.top_k.gpu[expanded_idx_mapping] if do_top_k else None
top_p = self.top_p.gpu[expanded_idx_mapping] if do_top_p else None
return top_k, top_p
def apply_top_k_top_p(
self,
logits: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> torch.Tensor:
do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
do_top_p = np.any(self.top_p.np[idx_mapping_np] != 1.0)
if not (do_top_k or do_top_p):
top_k, top_p = self.get_top_k_top_p(expanded_idx_mapping, idx_mapping_np)
if top_k is None and top_p is None:
return logits
top_k = self.top_k.gpu[expanded_idx_mapping] if do_top_k else None
top_p = self.top_p.gpu[expanded_idx_mapping] if do_top_p else None
return apply_top_k_top_p(logits, top_k, top_p)
def any_greedy(self, idx_mapping_np: np.ndarray) -> bool:
    """Return True if any selected request has a temperature of exactly 0.0."""
    selected_temperatures = self.temperature.np[idx_mapping_np]
    return bool((selected_temperatures == 0.0).any())
def any_explicit_seed(self, idx_mapping_np: np.ndarray) -> bool:
    """Return True if ``seeds_set`` is flagged for any selected request."""
    selected_flags = self.seeds_set[idx_mapping_np]
    return bool(selected_flags.any())
def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
    """Return the largest ``num_logprobs`` entry among the selected requests."""
    selected_counts = self.num_logprobs[idx_mapping_np]
    return int(selected_counts.max())
+2
View File
@@ -152,6 +152,7 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.kv_cache_spec_registry import KVCacheSpecRegistry
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
AsyncModelRunnerOutput,
@@ -6235,6 +6236,7 @@ class GPUModelRunner(
)
kv_cache_spec = self.get_kv_cache_spec()
KVCacheSpecRegistry.check_kv_cache_spec_registry(kv_cache_spec)
kv_cache_groups = get_kv_cache_groups(self.vllm_config, kv_cache_spec)
min_blocks = self.compilation_config.max_cudagraph_capture_size or 1