mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Merge branch 'main' into wentao-fix-NixlConnector-PD-+-Spec-Decode-acceptance-(2-GPUs)
This commit is contained in:
+15
-12
@@ -112,6 +112,8 @@ endif()
|
||||
#
|
||||
# spinloop extension (pure CXX; must stay above the non-CUDA device branch so
|
||||
# CPU builds define the target before the early return)
|
||||
# This extension requires SABI 3.11 since it relies on Py_buffer support. Loading
|
||||
# failure is handled gracefully on vLLM side for lower Python versions.
|
||||
#
|
||||
set(VLLM_SPINLOOP_EXT_SRC "csrc/spinloop.cpp")
|
||||
set(SPINLOOP_COMPILE_FLAGS "")
|
||||
@@ -309,14 +311,9 @@ set(VLLM_EXT_SRC
|
||||
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
|
||||
"csrc/quantization/activation_kernels.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
"csrc/torch_bindings.cpp"
|
||||
"csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")
|
||||
"csrc/torch_bindings.cpp")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/minimax_reduce_rms_kernel.cu")
|
||||
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
|
||||
|
||||
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
|
||||
@@ -503,12 +500,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND ES_MXFP8_GROUPED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu"
|
||||
"csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu")
|
||||
"csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu"
|
||||
"csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${ES_MXFP8_GROUPED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_ES_MXFP8_GROUPED_MM_SM100=1")
|
||||
message(STATUS "Building ES MXFP8 grouped kernels for archs: ${ES_MXFP8_GROUPED_MM_ARCHS}")
|
||||
else()
|
||||
@@ -598,7 +595,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
if (VLLM_GPU_LANG STREQUAL "HIP")
|
||||
# Add QuickReduce kernels
|
||||
# Add QuickReduce kernels (ROCm-only; not part of stable ABI migration).
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/custom_quickreduce.cu"
|
||||
)
|
||||
@@ -649,7 +646,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
"csrc/libtorch_stable/attention/paged_attention_v1.cu"
|
||||
"csrc/libtorch_stable/attention/paged_attention_v2.cu"
|
||||
"csrc/libtorch_stable/cache_kernels.cu"
|
||||
"csrc/libtorch_stable/cache_kernels_fused.cu")
|
||||
"csrc/libtorch_stable/cache_kernels.cu"
|
||||
"csrc/libtorch_stable/cache_kernels_fused.cu"
|
||||
"csrc/libtorch_stable/custom_all_reduce.cu"
|
||||
"csrc/libtorch_stable/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC
|
||||
@@ -659,7 +659,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||
"csrc/libtorch_stable/permute_cols.cu"
|
||||
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu")
|
||||
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/libtorch_stable/minimax_reduce_rms_kernel.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_STABLE_EXT_SRC}"
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
#include "torch_utils.h"
|
||||
|
||||
#include <torch/csrc/stable/macros.h>
|
||||
#include <torch/csrc/stable/accelerator.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/csrc/stable/device.h>
|
||||
|
||||
#include "custom_all_reduce.cuh"
|
||||
|
||||
@@ -11,7 +15,7 @@ using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
|
||||
torch::Tensor& rank_data, int64_t rank,
|
||||
torch::stable::Tensor& rank_data, int64_t rank,
|
||||
bool fully_connected) {
|
||||
int world_size = fake_ipc_ptrs.size();
|
||||
if (world_size > 8)
|
||||
@@ -25,9 +29,9 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
|
||||
}
|
||||
return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
|
||||
rank_data.numel(), rank, world_size,
|
||||
fully_connected);
|
||||
return (fptr_t) new vllm::CustomAllreduce(
|
||||
ipc_ptrs, rank_data.mutable_data_ptr(), rank_data.numel(), rank,
|
||||
world_size, fully_connected);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -46,10 +50,14 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
|
||||
* 5. A[None].expand(2, -1, -1, -1): Not OK
|
||||
* 6. A[:, 1:, 1:]: Not OK
|
||||
*/
|
||||
bool _is_weak_contiguous(torch::Tensor& t) {
|
||||
return t.is_contiguous() ||
|
||||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
|
||||
t.numel() * t.element_size());
|
||||
bool _is_weak_contiguous(torch::stable::Tensor& t) {
|
||||
if (t.is_contiguous()) {
|
||||
return true;
|
||||
}
|
||||
int64_t storage_nbytes = 0;
|
||||
TORCH_ERROR_CODE_CHECK(aoti_torch_get_storage_size(t.get(), &storage_nbytes));
|
||||
return storage_nbytes - t.storage_offset() * t.element_size() ==
|
||||
static_cast<int64_t>(t.numel() * t.element_size());
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -59,42 +67,45 @@ bool _is_weak_contiguous(torch::Tensor& t) {
|
||||
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
|
||||
* copied into _reg_buffer.
|
||||
*/
|
||||
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
|
||||
void all_reduce(fptr_t _fa, torch::stable::Tensor& inp,
|
||||
torch::stable::Tensor& out, fptr_t _reg_buffer,
|
||||
int64_t reg_buffer_sz_bytes) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
inp.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream(inp.get_device_index());
|
||||
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK(_is_weak_contiguous(out));
|
||||
TORCH_CHECK(_is_weak_contiguous(inp));
|
||||
STD_TORCH_CHECK((inp.scalar_type()) == (out.scalar_type()));
|
||||
STD_TORCH_CHECK((inp.numel()) == (out.numel()));
|
||||
STD_TORCH_CHECK(_is_weak_contiguous(out));
|
||||
STD_TORCH_CHECK(_is_weak_contiguous(inp));
|
||||
auto input_size = inp.numel() * inp.element_size();
|
||||
auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
|
||||
if (reg_buffer) {
|
||||
TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
|
||||
AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
|
||||
cudaMemcpyDeviceToDevice, stream));
|
||||
STD_TORCH_CHECK((input_size) <= (reg_buffer_sz_bytes));
|
||||
STD_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.const_data_ptr(), input_size,
|
||||
cudaMemcpyDeviceToDevice, stream));
|
||||
} else {
|
||||
reg_buffer = inp.data_ptr();
|
||||
reg_buffer = inp.mutable_data_ptr();
|
||||
}
|
||||
switch (out.scalar_type()) {
|
||||
case at::ScalarType::Float: {
|
||||
case torch::headeronly::ScalarType::Float: {
|
||||
fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
|
||||
reinterpret_cast<float*>(out.data_ptr()),
|
||||
reinterpret_cast<float*>(out.mutable_data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
case at::ScalarType::Half: {
|
||||
case torch::headeronly::ScalarType::Half: {
|
||||
fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
|
||||
reinterpret_cast<half*>(out.data_ptr()), out.numel());
|
||||
reinterpret_cast<half*>(out.mutable_data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
case at::ScalarType::BFloat16: {
|
||||
case torch::headeronly::ScalarType::BFloat16: {
|
||||
fa->allreduce<nv_bfloat16>(
|
||||
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
|
||||
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
|
||||
reinterpret_cast<nv_bfloat16*>(out.mutable_data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
@@ -112,7 +123,7 @@ int64_t meta_size() { return sizeof(vllm::Signal); }
|
||||
|
||||
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
|
||||
STD_TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
|
||||
void* ipc_ptrs[8];
|
||||
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
|
||||
ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
|
||||
@@ -143,47 +154,49 @@ void register_graph_buffers(fptr_t _fa,
|
||||
fa->register_graph_buffers(bytes, offsets);
|
||||
}
|
||||
|
||||
std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
|
||||
std::tuple<fptr_t, torch::stable::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size) {
|
||||
auto device_index = c10::cuda::current_device();
|
||||
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
|
||||
int device_index;
|
||||
STD_CUDA_CHECK(cudaGetDevice(&device_index));
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(device_index);
|
||||
void* buffer;
|
||||
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
const cudaStream_t stream = get_current_cuda_stream(device_index);
|
||||
STD_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
|
||||
// Allocate buffer
|
||||
#if defined(USE_ROCM)
|
||||
// data buffers need to be "uncached" for signal on MI200
|
||||
AT_CUDA_CHECK(
|
||||
STD_CUDA_CHECK(
|
||||
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
|
||||
#else
|
||||
AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
|
||||
STD_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
|
||||
#endif
|
||||
AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
|
||||
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
STD_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
|
||||
STD_CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
STD_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
|
||||
// Create IPC memhandle for the allocated buffer.
|
||||
// Will use it in open_mem_handle.
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||
auto handle =
|
||||
torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
|
||||
AT_CUDA_CHECK(
|
||||
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer));
|
||||
auto handle = torch::stable::empty(
|
||||
{static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))},
|
||||
torch::headeronly::ScalarType::Byte, std::nullopt,
|
||||
torch::stable::Device(torch::stable::DeviceType::CPU));
|
||||
STD_CUDA_CHECK(cudaIpcGetMemHandle(
|
||||
(cudaIpcMemHandle_t*)handle.mutable_data_ptr(), buffer));
|
||||
|
||||
return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
|
||||
}
|
||||
|
||||
fptr_t open_mem_handle(torch::Tensor& mem_handle) {
|
||||
fptr_t open_mem_handle(torch::stable::Tensor& mem_handle) {
|
||||
void* ipc_ptr;
|
||||
AT_CUDA_CHECK(cudaIpcOpenMemHandle(
|
||||
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()),
|
||||
STD_CUDA_CHECK(cudaIpcOpenMemHandle(
|
||||
(void**)&ipc_ptr,
|
||||
*((const cudaIpcMemHandle_t*)mem_handle.const_data_ptr()),
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
return reinterpret_cast<fptr_t>(ipc_ptr);
|
||||
}
|
||||
|
||||
void free_shared_buffer(fptr_t buffer) {
|
||||
AT_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
|
||||
STD_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
|
||||
}
|
||||
+70
-56
@@ -28,7 +28,20 @@
|
||||
* [bs*576, bs*576 + bs*8): UE8M0 scales, 7 real + 1 pad per token
|
||||
*/
|
||||
|
||||
#include "torch_utils.h"
|
||||
|
||||
#include <torch/csrc/stable/macros.h>
|
||||
#include <torch/csrc/stable/accelerator.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/csrc/stable/device.h>
|
||||
|
||||
#include <cmath>
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#include "type_convert.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_fp8.h>
|
||||
#else
|
||||
@@ -37,14 +50,6 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/cuda.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#include "type_convert.cuh"
|
||||
|
||||
#ifndef FINAL_MASK
|
||||
#ifdef USE_ROCM
|
||||
#define FINAL_MASK 0xffffffffffffffffULL
|
||||
@@ -70,7 +75,7 @@ namespace deepseek_v4_fused_ops {
|
||||
|
||||
namespace {
|
||||
inline int getSMVersion() {
|
||||
auto* props = at::cuda::getCurrentDeviceProperties();
|
||||
auto* props = get_device_prop();
|
||||
return props->major * 10 + props->minor;
|
||||
}
|
||||
} // namespace
|
||||
@@ -564,7 +569,7 @@ static void launchFusedDeepseekV4Templated(
|
||||
// bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert<BFloat16> is
|
||||
// unavailable there. Refuse the launch loudly instead of silently
|
||||
// skipping the work.
|
||||
TORCH_CHECK(
|
||||
STD_TORCH_CHECK(
|
||||
sm_version >= 80,
|
||||
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert requires sm_80+ "
|
||||
"(Ampere or newer); got sm_",
|
||||
@@ -635,7 +640,7 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
|
||||
DISPATCH(64)
|
||||
DISPATCH(128)
|
||||
default:
|
||||
TORCH_CHECK(false,
|
||||
STD_TORCH_CHECK(false,
|
||||
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert: "
|
||||
"unsupported num_heads_q_padded=",
|
||||
num_heads_q_padded,
|
||||
@@ -650,71 +655,80 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// Torch op wrapper
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
||||
torch::Tensor const& q_in, // [N, num_heads_q, 512] bf16
|
||||
torch::Tensor const& kv, // [N, 512] bf16 (read-only)
|
||||
torch::Tensor& k_cache, // [num_blocks, block_bytes] uint8
|
||||
torch::Tensor const& slot_mapping, // [N] int64
|
||||
torch::Tensor const& position_ids, // [N] int64
|
||||
torch::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16
|
||||
int64_t q_head_padded, // padded Q head count for output
|
||||
torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
||||
torch::stable::Tensor const& q_in, // [N, num_heads_q, 512] bf16
|
||||
torch::stable::Tensor const& kv, // [N, 512] bf16 (read-only)
|
||||
torch::stable::Tensor& k_cache, // [num_blocks, block_bytes] uint8
|
||||
torch::stable::Tensor const& slot_mapping, // [N] int64
|
||||
torch::stable::Tensor const& position_ids, // [N] int64
|
||||
torch::stable::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16
|
||||
int64_t q_head_padded, // padded Q head count for output
|
||||
double eps, int64_t cache_block_size) {
|
||||
TORCH_CHECK(q_in.is_cuda() && q_in.is_contiguous(),
|
||||
"q_in must be contiguous CUDA");
|
||||
TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA");
|
||||
TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA");
|
||||
TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64,
|
||||
"slot_mapping must be int64 CUDA");
|
||||
TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64,
|
||||
"position_ids must be int64 CUDA");
|
||||
TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA");
|
||||
TORCH_CHECK(q_in.dim() == 3 && q_in.size(2) == 512,
|
||||
"q_in shape [N, num_heads_q, 512]");
|
||||
TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
|
||||
TORCH_CHECK(q_in.dtype() == kv.dtype(), "q_in and kv dtype must match");
|
||||
TORCH_CHECK(q_head_padded >= q_in.size(1),
|
||||
"q_head_padded must be >= q_in.size(1) (num_heads_q)");
|
||||
TORCH_CHECK(k_cache.dtype() == torch::kUInt8, "k_cache must be uint8");
|
||||
TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
|
||||
"cos_sin_cache shape [max_pos, 64]");
|
||||
TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32,
|
||||
"cos_sin_cache must be float32");
|
||||
STD_TORCH_CHECK(q_in.device().is_cuda() && q_in.is_contiguous(),
|
||||
"q_in must be contiguous CUDA");
|
||||
STD_TORCH_CHECK(kv.device().is_cuda() && kv.is_contiguous(),
|
||||
"kv must be contiguous CUDA");
|
||||
STD_TORCH_CHECK(k_cache.device().is_cuda(), "k_cache must be CUDA");
|
||||
STD_TORCH_CHECK(slot_mapping.device().is_cuda() &&
|
||||
slot_mapping.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Long,
|
||||
"slot_mapping must be int64 CUDA");
|
||||
STD_TORCH_CHECK(position_ids.device().is_cuda() &&
|
||||
position_ids.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Long,
|
||||
"position_ids must be int64 CUDA");
|
||||
STD_TORCH_CHECK(cos_sin_cache.device().is_cuda(), "cos_sin_cache must be CUDA");
|
||||
STD_TORCH_CHECK(q_in.dim() == 3 && q_in.size(2) == 512,
|
||||
"q_in shape [N, num_heads_q, 512]");
|
||||
STD_TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
|
||||
STD_TORCH_CHECK(q_in.scalar_type() == kv.scalar_type(),
|
||||
"q_in and kv dtype must match");
|
||||
STD_TORCH_CHECK(q_head_padded >= q_in.size(1),
|
||||
"q_head_padded must be >= q_in.size(1) (num_heads_q)");
|
||||
STD_TORCH_CHECK(k_cache.scalar_type() == torch::headeronly::ScalarType::Byte,
|
||||
"k_cache must be uint8");
|
||||
STD_TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
|
||||
"cos_sin_cache shape [max_pos, 64]");
|
||||
STD_TORCH_CHECK(cos_sin_cache.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float,
|
||||
"cos_sin_cache must be float32");
|
||||
|
||||
// With DP padding, slot_mapping can be shorter than q/kv/positions.
|
||||
// Q-norm+RoPE runs on all q.size(0) rows (downstream attention uses them);
|
||||
// KV quant+insert runs only on the first slot_mapping.size(0) rows.
|
||||
int const num_tokens_full = static_cast<int>(q_in.size(0));
|
||||
int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
|
||||
TORCH_CHECK(static_cast<int>(kv.size(0)) == num_tokens_full &&
|
||||
static_cast<int>(position_ids.size(0)) == num_tokens_full,
|
||||
"q/kv/position_ids row counts must match");
|
||||
TORCH_CHECK(num_tokens_insert <= num_tokens_full,
|
||||
"slot_mapping must not exceed q row count");
|
||||
STD_TORCH_CHECK(static_cast<int>(kv.size(0)) == num_tokens_full &&
|
||||
static_cast<int>(position_ids.size(0)) == num_tokens_full,
|
||||
"q/kv/position_ids row counts must match");
|
||||
STD_TORCH_CHECK(num_tokens_insert <= num_tokens_full,
|
||||
"slot_mapping must not exceed q row count");
|
||||
int const num_heads_q = static_cast<int>(q_in.size(1));
|
||||
int const num_heads_q_padded = static_cast<int>(q_head_padded);
|
||||
int const cache_block_size_i = static_cast<int>(cache_block_size);
|
||||
int const kv_block_stride = static_cast<int>(k_cache.stride(0));
|
||||
|
||||
at::cuda::OptionalCUDAGuard device_guard(device_of(q_in));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
q_in.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream(q_in.get_device_index());
|
||||
|
||||
// Allocate the padded q output. The kernel writes every element (live
|
||||
// region gets RMSNorm+RoPE; pad region gets zeros), so `empty` is safe.
|
||||
torch::Tensor q_out = torch::empty(
|
||||
{q_in.size(0), q_head_padded, q_in.size(2)}, q_in.options());
|
||||
auto q_out = torch::stable::new_empty(
|
||||
q_in, {q_in.size(0), q_head_padded, q_in.size(2)}, q_in.scalar_type());
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||
q_in.scalar_type(), "fused_deepseek_v4_qnorm_rope_kv_insert", [&] {
|
||||
using qkv_scalar_t = scalar_t;
|
||||
vllm::deepseek_v4_fused_ops::
|
||||
launchFusedDeepseekV4QNormRopeKVRopeQuantInsert<qkv_scalar_t>(
|
||||
reinterpret_cast<qkv_scalar_t const*>(q_in.data_ptr()),
|
||||
reinterpret_cast<qkv_scalar_t*>(q_out.data_ptr()),
|
||||
reinterpret_cast<qkv_scalar_t const*>(kv.data_ptr()),
|
||||
reinterpret_cast<uint8_t*>(k_cache.data_ptr()),
|
||||
reinterpret_cast<int64_t const*>(slot_mapping.data_ptr()),
|
||||
reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
|
||||
cos_sin_cache.data_ptr<float>(), static_cast<float>(eps),
|
||||
reinterpret_cast<qkv_scalar_t const*>(q_in.const_data_ptr()),
|
||||
reinterpret_cast<qkv_scalar_t*>(q_out.mutable_data_ptr()),
|
||||
reinterpret_cast<qkv_scalar_t const*>(kv.const_data_ptr()),
|
||||
reinterpret_cast<uint8_t*>(k_cache.mutable_data_ptr()),
|
||||
slot_mapping.const_data_ptr<int64_t>(),
|
||||
position_ids.const_data_ptr<int64_t>(),
|
||||
cos_sin_cache.const_data_ptr<float>(), static_cast<float>(eps),
|
||||
num_tokens_full, num_tokens_insert, num_heads_q,
|
||||
num_heads_q_padded, cache_block_size_i, kv_block_stride,
|
||||
stream);
|
||||
+71
-58
@@ -15,16 +15,19 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "torch_utils.h"
|
||||
|
||||
#include <torch/csrc/stable/macros.h>
|
||||
#include <torch/csrc/stable/accelerator.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/csrc/stable/device.h>
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <torch/cuda.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "cuda_utils.h"
|
||||
#include "core/registration.h"
|
||||
#include "minimax_reduce_rms_kernel.h"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -611,7 +614,7 @@ int get_sm_count() {
|
||||
static int sm_count = 0;
|
||||
if (sm_count == 0) {
|
||||
int device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&device_id));
|
||||
STD_CUDA_CHECK(cudaGetDevice(&device_id));
|
||||
cudaDeviceProp device_prop;
|
||||
cudaGetDeviceProperties(&device_prop, device_id);
|
||||
sm_count = device_prop.multiProcessorCount;
|
||||
@@ -621,13 +624,13 @@ int get_sm_count() {
|
||||
|
||||
inline int getSMVersion(bool queryRealSmArch = false) {
|
||||
int device{-1};
|
||||
CUDA_CHECK(cudaGetDevice(&device));
|
||||
STD_CUDA_CHECK(cudaGetDevice(&device));
|
||||
int sm_major = 0;
|
||||
int sm_minor = 0;
|
||||
CUDA_CHECK(cudaDeviceGetAttribute(&sm_major,
|
||||
cudaDevAttrComputeCapabilityMajor, device));
|
||||
CUDA_CHECK(cudaDeviceGetAttribute(&sm_minor,
|
||||
cudaDevAttrComputeCapabilityMinor, device));
|
||||
STD_CUDA_CHECK(cudaDeviceGetAttribute(
|
||||
&sm_major, cudaDevAttrComputeCapabilityMajor, device));
|
||||
STD_CUDA_CHECK(cudaDeviceGetAttribute(
|
||||
&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
|
||||
int sm = sm_major * 10 + sm_minor;
|
||||
if (sm == 121 && !queryRealSmArch) {
|
||||
return 120;
|
||||
@@ -639,7 +642,7 @@ template <typename KernelFunc>
|
||||
int get_max_active_blocks(KernelFunc kernel, int block_size,
|
||||
int dynamic_smem = 0) {
|
||||
int max_active = 0;
|
||||
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
STD_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active, kernel, block_size, dynamic_smem));
|
||||
return std::max(max_active, 1);
|
||||
}
|
||||
@@ -678,27 +681,27 @@ void minimax_reduce_rms_kernel_launcher(MiniMaxReduceRMSParams const& params) {
|
||||
cfg.attrs = attribute;
|
||||
cfg.numAttrs = SM >= 90 ? 2 : 0;
|
||||
|
||||
CUDA_CHECK(cudaLaunchKernelEx(
|
||||
STD_CUDA_CHECK(cudaLaunchKernelEx(
|
||||
&cfg, minimax_reduce_rms_kernel_lamport<DType, NRanks>, params));
|
||||
}
|
||||
|
||||
template <typename DType, int NRanks, int OriginQDim, int OriginKDim>
|
||||
void minimax_reduce_rms_kernel_launcher_float4(
|
||||
MiniMaxReduceRMSParams const& params) {
|
||||
TORCH_CHECK(params.size_q % params.hidden_dim == 0);
|
||||
TORCH_CHECK(params.hidden_dim % kElemsPerAccess<DType> == 0);
|
||||
STD_TORCH_CHECK(params.size_q % params.hidden_dim == 0);
|
||||
STD_TORCH_CHECK(params.hidden_dim % kElemsPerAccess<DType> == 0);
|
||||
if (params.stride_q > 0) {
|
||||
TORCH_CHECK(params.stride_q % kElemsPerAccess<DType> == 0);
|
||||
STD_TORCH_CHECK(params.stride_q % kElemsPerAccess<DType> == 0);
|
||||
}
|
||||
TORCH_CHECK(params.allreduce_in_k != nullptr,
|
||||
"float4 QK kernel requires K input");
|
||||
TORCH_CHECK(params.hidden_dim >= params.hidden_dim_k);
|
||||
TORCH_CHECK(params.size_k % params.hidden_dim_k == 0);
|
||||
TORCH_CHECK(params.hidden_dim_k % kElemsPerAccess<DType> == 0);
|
||||
TORCH_CHECK(params.size_q / params.hidden_dim ==
|
||||
params.size_k / params.hidden_dim_k);
|
||||
STD_TORCH_CHECK(params.allreduce_in_k != nullptr,
|
||||
"float4 QK kernel requires K input");
|
||||
STD_TORCH_CHECK(params.hidden_dim >= params.hidden_dim_k);
|
||||
STD_TORCH_CHECK(params.size_k % params.hidden_dim_k == 0);
|
||||
STD_TORCH_CHECK(params.hidden_dim_k % kElemsPerAccess<DType> == 0);
|
||||
STD_TORCH_CHECK(params.size_q / params.hidden_dim ==
|
||||
params.size_k / params.hidden_dim_k);
|
||||
if (params.stride_k > 0) {
|
||||
TORCH_CHECK(params.stride_k % kElemsPerAccess<DType> == 0);
|
||||
STD_TORCH_CHECK(params.stride_k % kElemsPerAccess<DType> == 0);
|
||||
}
|
||||
|
||||
int token_num = params.size_q / params.hidden_dim;
|
||||
@@ -746,7 +749,7 @@ void minimax_reduce_rms_kernel_launcher_float4(
|
||||
cfg.attrs = attribute;
|
||||
cfg.numAttrs = SM >= 90 ? 2 : 0;
|
||||
|
||||
CUDA_CHECK(cudaLaunchKernelEx(&cfg, kfn, params));
|
||||
STD_CUDA_CHECK(cudaLaunchKernelEx(&cfg, kfn, params));
|
||||
}
|
||||
|
||||
template <int NRanks>
|
||||
@@ -759,21 +762,21 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) {
|
||||
(params.hidden_dim * params.nranks == 6144) &&
|
||||
(params.hidden_dim_k * params.nranks == 1024);
|
||||
|
||||
if (params.dtype == at::ScalarType::Half) {
|
||||
if (params.dtype == torch::headeronly::ScalarType::Half) {
|
||||
if (use_float4) {
|
||||
minimax_reduce_rms_kernel_launcher_float4<half, NRanks, 6144, 1024>(
|
||||
params);
|
||||
} else {
|
||||
minimax_reduce_rms_kernel_launcher<half, NRanks>(params);
|
||||
}
|
||||
} else if (params.dtype == at::ScalarType::BFloat16) {
|
||||
} else if (params.dtype == torch::headeronly::ScalarType::BFloat16) {
|
||||
if (use_float4) {
|
||||
minimax_reduce_rms_kernel_launcher_float4<__nv_bfloat16, NRanks, 6144,
|
||||
1024>(params);
|
||||
} else {
|
||||
minimax_reduce_rms_kernel_launcher<__nv_bfloat16, NRanks>(params);
|
||||
}
|
||||
} else if (params.dtype == at::ScalarType::Float) {
|
||||
} else if (params.dtype == torch::headeronly::ScalarType::Float) {
|
||||
if (use_float4) {
|
||||
minimax_reduce_rms_kernel_launcher_float4<float, NRanks, 6144, 1024>(
|
||||
params);
|
||||
@@ -781,7 +784,7 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) {
|
||||
minimax_reduce_rms_kernel_launcher<float, NRanks>(params);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type for minimax_reduce_rms_op");
|
||||
STD_TORCH_CHECK(false, "Unsupported data type for minimax_reduce_rms_op");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -795,16 +798,18 @@ void minimax_reduce_rms_op(MiniMaxReduceRMSParams const& params) {
|
||||
} else if (params.nranks == 16) {
|
||||
dispatch_dtype<16>(params);
|
||||
} else {
|
||||
TORCH_CHECK(false, "minimax_reduce_rms_op: unsupported ranks number!");
|
||||
STD_TORCH_CHECK(false, "minimax_reduce_rms_op: unsupported ranks number!");
|
||||
}
|
||||
}
|
||||
} // namespace tensorrt_llm
|
||||
} // namespace vllm
|
||||
|
||||
torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
|
||||
torch::Tensor const& norm_weight,
|
||||
torch::Tensor workspace, int64_t const rank,
|
||||
int64_t const nranks, double const eps) {
|
||||
torch::stable::Tensor minimax_allreduce_rms(
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& norm_weight, torch::stable::Tensor workspace,
|
||||
int64_t const rank, int64_t const nranks, double const eps) {
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
auto allreduce_params = vllm::tensorrt_llm::MiniMaxReduceRMSParams();
|
||||
|
||||
allreduce_params.nranks = static_cast<int>(nranks);
|
||||
@@ -815,12 +820,12 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
|
||||
allreduce_params.stride_q = allreduce_params.hidden_dim;
|
||||
allreduce_params.workspace =
|
||||
reinterpret_cast<void**>(workspace.mutable_data_ptr());
|
||||
allreduce_params.allreduce_in = input.data_ptr();
|
||||
allreduce_params.rms_gamma = norm_weight.data_ptr();
|
||||
allreduce_params.allreduce_in = const_cast<void*>(input.const_data_ptr());
|
||||
allreduce_params.rms_gamma = const_cast<void*>(norm_weight.const_data_ptr());
|
||||
allreduce_params.rms_eps = static_cast<float>(eps);
|
||||
allreduce_params.stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
allreduce_params.stream = get_current_cuda_stream(input.get_device_index());
|
||||
|
||||
torch::Tensor rms_norm_out = torch::empty_like(input);
|
||||
torch::stable::Tensor rms_norm_out = torch::stable::empty_like(input);
|
||||
allreduce_params.rms_norm_out = rms_norm_out.mutable_data_ptr();
|
||||
|
||||
vllm::tensorrt_llm::minimax_reduce_rms_op(allreduce_params);
|
||||
@@ -828,26 +833,33 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
|
||||
return rms_norm_out;
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
|
||||
torch::Tensor qkv, torch::Tensor const& norm_weight_q,
|
||||
torch::Tensor const& norm_weight_k, torch::Tensor workspace,
|
||||
int64_t const q_size, int64_t const kv_size, int64_t const rank,
|
||||
int64_t const nranks, double const eps) {
|
||||
TORCH_CHECK(qkv.dim() == 2, "minimax_allreduce_rms_qk: qkv must be 2D");
|
||||
TORCH_CHECK(qkv.is_contiguous(),
|
||||
"minimax_allreduce_rms_qk: qkv must be contiguous");
|
||||
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
|
||||
minimax_allreduce_rms_qk(torch::stable::Tensor qkv,
|
||||
torch::stable::Tensor const& norm_weight_q,
|
||||
torch::stable::Tensor const& norm_weight_k,
|
||||
torch::stable::Tensor workspace, int64_t const q_size,
|
||||
int64_t const kv_size, int64_t const rank,
|
||||
int64_t const nranks, double const eps) {
|
||||
STD_TORCH_CHECK(qkv.dim() == 2, "minimax_allreduce_rms_qk: qkv must be 2D");
|
||||
STD_TORCH_CHECK(qkv.is_contiguous(),
|
||||
"minimax_allreduce_rms_qk: qkv must be contiguous");
|
||||
int64_t qkv_dim = qkv.size(-1);
|
||||
TORCH_CHECK(qkv_dim == q_size + 2 * kv_size,
|
||||
"minimax_allreduce_rms_qk: qkv last dim must equal "
|
||||
"q_size + 2 * kv_size");
|
||||
TORCH_CHECK(rank < nranks,
|
||||
"minimax_allreduce_rms_qk: rank must be less than nranks");
|
||||
STD_TORCH_CHECK(qkv_dim == q_size + 2 * kv_size,
|
||||
"minimax_allreduce_rms_qk: qkv last dim must equal "
|
||||
"q_size + 2 * kv_size");
|
||||
STD_TORCH_CHECK(rank < nranks,
|
||||
"minimax_allreduce_rms_qk: rank must be less than nranks");
|
||||
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
qkv.get_device_index());
|
||||
|
||||
int64_t num_tokens = qkv.size(0);
|
||||
int elem_bytes = qkv.element_size();
|
||||
|
||||
torch::Tensor q_out = torch::empty({num_tokens, q_size}, qkv.options());
|
||||
torch::Tensor k_out = torch::empty({num_tokens, kv_size}, qkv.options());
|
||||
torch::stable::Tensor q_out =
|
||||
torch::stable::new_empty(qkv, {num_tokens, q_size}, qkv.scalar_type());
|
||||
torch::stable::Tensor k_out =
|
||||
torch::stable::new_empty(qkv, {num_tokens, kv_size}, qkv.scalar_type());
|
||||
|
||||
auto params = vllm::tensorrt_llm::MiniMaxReduceRMSParams();
|
||||
params.nranks = static_cast<int>(nranks);
|
||||
@@ -863,13 +875,14 @@ std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
|
||||
params.stride_k_out = 0; // k_out is contiguous; kernel uses hidden_dim_k
|
||||
params.workspace = reinterpret_cast<void**>(workspace.mutable_data_ptr());
|
||||
|
||||
uint8_t* base = static_cast<uint8_t*>(qkv.data_ptr());
|
||||
uint8_t* base =
|
||||
const_cast<uint8_t*>(static_cast<const uint8_t*>(qkv.const_data_ptr()));
|
||||
params.allreduce_in = base;
|
||||
params.allreduce_in_k = base + q_size * elem_bytes;
|
||||
params.rms_gamma = norm_weight_q.data_ptr();
|
||||
params.rms_gamma_k = norm_weight_k.data_ptr();
|
||||
params.rms_gamma = const_cast<void*>(norm_weight_q.const_data_ptr());
|
||||
params.rms_gamma_k = const_cast<void*>(norm_weight_k.const_data_ptr());
|
||||
params.rms_eps = static_cast<float>(eps);
|
||||
params.stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
|
||||
params.stream = get_current_cuda_stream(qkv.get_device_index());
|
||||
|
||||
params.rms_norm_out = q_out.mutable_data_ptr();
|
||||
params.rms_norm_out_k = k_out.mutable_data_ptr();
|
||||
@@ -0,0 +1,69 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
// Adapted from SGLang:
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
|
||||
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_mxfp8_grouped_mm_launcher.cuh"
|
||||
|
||||
void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a,
|
||||
const torch::stable::Tensor& b,
|
||||
const torch::stable::Tensor& sfa,
|
||||
const torch::stable::Tensor& sfb,
|
||||
torch::stable::Tensor& d,
|
||||
const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& blockscale_offsets) {
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
STD_TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have shape (num_experts, 3)");
|
||||
STD_TORCH_CHECK(
|
||||
problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
STD_TORCH_CHECK(
|
||||
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"problem_sizes must be int32");
|
||||
STD_TORCH_CHECK(
|
||||
expert_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"expert_offsets must be int32");
|
||||
STD_TORCH_CHECK(
|
||||
blockscale_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"blockscale_offsets must be int32");
|
||||
STD_TORCH_CHECK(a.dim() == 2,
|
||||
"a must be a 2D tensor of shape (num_tokens, k)");
|
||||
STD_TORCH_CHECK(b.dim() == 3,
|
||||
"b must be a 3D tensor of shape (num_experts, k, n)");
|
||||
STD_TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0,
|
||||
"k should align 128");
|
||||
STD_TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128");
|
||||
STD_TORCH_CHECK(a.stride(1) == 1, "a must be row major");
|
||||
STD_TORCH_CHECK(b.stride(1) == 1, "b must be column major");
|
||||
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
a.get_device_index());
|
||||
auto stream = get_current_cuda_stream(a.get_device_index());
|
||||
if (d.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
|
||||
cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
|
||||
blockscale_offsets, stream);
|
||||
} else if (d.scalar_type() == torch::headeronly::ScalarType::Half) {
|
||||
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
|
||||
cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
|
||||
blockscale_offsets, stream);
|
||||
} else {
|
||||
STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
|
||||
}
|
||||
#else
|
||||
STD_TORCH_CHECK(false,
|
||||
"No implemented cutlass_mxfp8_grouped_mm for "
|
||||
"current device");
|
||||
#endif
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
|
||||
m.impl("cutlass_mxfp8_grouped_mm", TORCH_BOX(&cutlass_mxfp8_grouped_mm));
|
||||
}
|
||||
+71
-52
@@ -4,9 +4,9 @@
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_launcher.cuh
|
||||
|
||||
#pragma once
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/util/Exception.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
@@ -15,18 +15,22 @@
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass_mxfp8_grouped_mm_functor.cuh"
|
||||
#include "cutlass_mxfp8_grouped_mm_traits.cuh"
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
namespace expert_specialization {
|
||||
|
||||
template <typename GemmTraits>
|
||||
void cutlass_mxfp8_grouped_mm_pre_compute(
|
||||
torch::Tensor& a_ptrs, torch::Tensor& b_ptrs, torch::Tensor& sfa_ptrs,
|
||||
torch::Tensor& sfb_ptrs, torch::Tensor& d_ptrs, torch::Tensor& stride_a,
|
||||
torch::Tensor& stride_b, torch::Tensor& stride_d, torch::Tensor& layout_sfa,
|
||||
torch::Tensor& layout_sfb, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& sfa, const torch::Tensor& sfb, const torch::Tensor& d,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
|
||||
torch::stable::Tensor& a_ptrs, torch::stable::Tensor& b_ptrs,
|
||||
torch::stable::Tensor& sfa_ptrs, torch::stable::Tensor& sfb_ptrs,
|
||||
torch::stable::Tensor& d_ptrs, torch::stable::Tensor& stride_a,
|
||||
torch::stable::Tensor& stride_b, torch::stable::Tensor& stride_d,
|
||||
torch::stable::Tensor& layout_sfa, torch::stable::Tensor& layout_sfb,
|
||||
const torch::stable::Tensor& a, const torch::stable::Tensor& b,
|
||||
const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb,
|
||||
const torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) {
|
||||
using OffsetFunctor = CutlassMxfp8GroupedMmOffsetFunctor<GemmTraits>;
|
||||
using ElementA = typename OffsetFunctor::ElementA;
|
||||
using ElementB = typename OffsetFunctor::ElementB;
|
||||
@@ -42,10 +46,10 @@ void cutlass_mxfp8_grouped_mm_pre_compute(
|
||||
using StrideB = typename StrideFunctor::StrideB;
|
||||
using StrideD = typename StrideFunctor::StrideD;
|
||||
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
TORCH_CHECK(num_experts <= 1024,
|
||||
"Number of experts cannot exceed 1024, the maximum number of "
|
||||
"threads per block.");
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
STD_TORCH_CHECK(num_experts <= 1024,
|
||||
"Number of experts cannot exceed 1024, the maximum number of "
|
||||
"threads per block.");
|
||||
|
||||
OffsetFunctor offset_functor(
|
||||
reinterpret_cast<int*>(expert_offsets.data_ptr()),
|
||||
@@ -72,13 +76,18 @@ void cutlass_mxfp8_grouped_mm_pre_compute(
|
||||
}
|
||||
|
||||
template <typename GemmTraits>
|
||||
void cutlass_mxfp8_grouped_mm(
|
||||
const torch::Tensor& a_ptrs, const torch::Tensor& b_ptrs,
|
||||
const torch::Tensor& sfa_ptrs, const torch::Tensor& sfb_ptrs,
|
||||
const torch::Tensor& d_ptrs, const torch::Tensor& stride_a,
|
||||
const torch::Tensor& stride_b, const torch::Tensor& stride_d,
|
||||
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& problem_sizes, cudaStream_t stream) {
|
||||
void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a_ptrs,
|
||||
const torch::stable::Tensor& b_ptrs,
|
||||
const torch::stable::Tensor& sfa_ptrs,
|
||||
const torch::stable::Tensor& sfb_ptrs,
|
||||
const torch::stable::Tensor& d_ptrs,
|
||||
const torch::stable::Tensor& stride_a,
|
||||
const torch::stable::Tensor& stride_b,
|
||||
const torch::stable::Tensor& stride_d,
|
||||
const torch::stable::Tensor& layout_sfa,
|
||||
const torch::stable::Tensor& layout_sfb,
|
||||
const torch::stable::Tensor& problem_sizes,
|
||||
cudaStream_t stream) {
|
||||
using Gemm = typename GemmTraits::Gemm;
|
||||
using ElementA = typename Gemm::ElementA;
|
||||
using ElementB = typename Gemm::ElementB;
|
||||
@@ -93,13 +102,12 @@ void cutlass_mxfp8_grouped_mm(
|
||||
typename GemmTraits::ProblemShape::UnderlyingProblemShape;
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = c10::cuda::current_device();
|
||||
hw_info.sm_count =
|
||||
at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
|
||||
hw_info.device_id = d_ptrs.get_device_index();
|
||||
hw_info.sm_count = get_device_prop()->multiProcessorCount;
|
||||
hw_info.cluster_shape = GemmTraits::MMAConfig::preferred_cluster;
|
||||
hw_info.cluster_shape_fallback = GemmTraits::MMAConfig::fallback_cluster;
|
||||
|
||||
int num_experts = (int)problem_sizes.size(0);
|
||||
int num_experts = static_cast<int>(problem_sizes.size(0));
|
||||
|
||||
UnderlyingProblemShape* underlying_problem_shape =
|
||||
reinterpret_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
|
||||
@@ -127,44 +135,55 @@ void cutlass_mxfp8_grouped_mm(
|
||||
Gemm gemm;
|
||||
|
||||
auto can_implement_status = gemm.can_implement(arguments);
|
||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
||||
"Failed to implement GEMM");
|
||||
STD_TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
||||
"Failed to implement GEMM");
|
||||
|
||||
torch::TensorOptions options_uint8 =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(d_ptrs.device());
|
||||
size_t workspace_size = gemm.get_workspace_size(arguments);
|
||||
torch::Tensor workspace = torch::empty(workspace_size, options_uint8);
|
||||
torch::stable::Tensor workspace = torch::stable::empty(
|
||||
{static_cast<int64_t>(workspace_size)},
|
||||
torch::headeronly::ScalarType::Byte, std::nullopt, d_ptrs.device());
|
||||
|
||||
auto status = gemm.initialize(arguments, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
||||
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
|
||||
"Failed to initialize GEMM");
|
||||
|
||||
status = gemm.run(stream, nullptr, true); // Enable PDL
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_mxfp8_grouped_mm_dispatch_out_dtype(
|
||||
const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& sfa,
|
||||
const torch::Tensor& sfb, torch::Tensor& d,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
|
||||
int num_experts = (int)problem_sizes.size(0);
|
||||
torch::TensorOptions options_int64 =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
torch::TensorOptions options_int32 =
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(a.device());
|
||||
const torch::stable::Tensor& a, const torch::stable::Tensor& b,
|
||||
const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb,
|
||||
torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) {
|
||||
int num_experts = static_cast<int>(problem_sizes.size(0));
|
||||
auto device = a.device();
|
||||
|
||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int64);
|
||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int64);
|
||||
torch::Tensor sfa_ptrs = torch::empty(num_experts, options_int64);
|
||||
torch::Tensor sfb_ptrs = torch::empty(num_experts, options_int64);
|
||||
torch::Tensor d_ptrs = torch::empty(num_experts, options_int64);
|
||||
torch::stable::Tensor a_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor b_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor sfa_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor sfb_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor d_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
|
||||
torch::Tensor stride_a = torch::empty(num_experts, options_int64);
|
||||
torch::Tensor stride_b = torch::empty(num_experts, options_int64);
|
||||
torch::Tensor stride_d = torch::empty(num_experts, options_int64);
|
||||
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int32);
|
||||
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int32);
|
||||
torch::stable::Tensor stride_a = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor stride_b = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor stride_d = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor layout_sfa =
|
||||
torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int,
|
||||
std::nullopt, device);
|
||||
torch::stable::Tensor layout_sfb =
|
||||
torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int,
|
||||
std::nullopt, device);
|
||||
|
||||
using GemmTraits = CutlassMxfp8GroupedMmGemmTraits<MMA1SMConfig, OutType>;
|
||||
cutlass_mxfp8_grouped_mm_pre_compute<GemmTraits>(
|
||||
@@ -176,4 +195,4 @@ void cutlass_mxfp8_grouped_mm_dispatch_out_dtype(
|
||||
layout_sfa, layout_sfb, problem_sizes, stream);
|
||||
}
|
||||
|
||||
} // namespace expert_specialization
|
||||
} // namespace expert_specialization
|
||||
@@ -0,0 +1,66 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
// Adapted from SGLang:
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
|
||||
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "mxfp8_experts_quant.cuh"
|
||||
|
||||
void mxfp8_experts_quant(const torch::stable::Tensor& input,
|
||||
const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& blockscale_offsets,
|
||||
torch::stable::Tensor& quant_output,
|
||||
torch::stable::Tensor& scale_factor) {
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
STD_TORCH_CHECK(input.dim() == 2, "input must be 2D tensor");
|
||||
STD_TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128");
|
||||
STD_TORCH_CHECK(input.stride(1) == 1, "input must be row major");
|
||||
STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
STD_TORCH_CHECK(
|
||||
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"problem_sizes must be int32");
|
||||
STD_TORCH_CHECK(
|
||||
expert_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"expert_offsets must be int32");
|
||||
STD_TORCH_CHECK(
|
||||
blockscale_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"blockscale_offsets must be int32");
|
||||
|
||||
auto groups = problem_sizes.size(0);
|
||||
STD_TORCH_CHECK(
|
||||
expert_offsets.dim() == 1 && expert_offsets.size(0) == groups,
|
||||
"expert_offsets must be 1D and have size equal to the number of groups");
|
||||
STD_TORCH_CHECK(
|
||||
blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
|
||||
"blockscale_offsets must be 1D and have size equal to the number of "
|
||||
"groups");
|
||||
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
if (input.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>(
|
||||
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
|
||||
scale_factor);
|
||||
} else if (input.scalar_type() == torch::headeronly::ScalarType::Half) {
|
||||
expert_specialization::launch_mxfp8_experts_quant<__half>(
|
||||
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
|
||||
scale_factor);
|
||||
} else {
|
||||
STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
|
||||
}
|
||||
#else
|
||||
STD_TORCH_CHECK(false,
|
||||
"No implemented mxfp8_experts_quant for "
|
||||
"current device");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Registered here (not torch_bindings.cpp) because ENABLE_ES_MXFP8_GROUPED_MM
|
||||
// is applied only under COMPILE_LANGUAGE:CUDA.
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
|
||||
m.impl("mxfp8_experts_quant", TORCH_BOX(&mxfp8_experts_quant));
|
||||
}
|
||||
+16
-14
@@ -4,16 +4,19 @@
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh
|
||||
|
||||
#pragma once
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/stable/macros.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/util/Exception.h>
|
||||
|
||||
#include <cuda/ptx>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
namespace expert_specialization {
|
||||
|
||||
@@ -356,12 +359,12 @@ __global__ void mxfp8_experts_quant_kernel(
|
||||
}
|
||||
|
||||
template <typename T_IN>
|
||||
void launch_mxfp8_experts_quant(const torch::Tensor& input,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& blockscale_offsets,
|
||||
torch::Tensor& quant_output,
|
||||
torch::Tensor& scale_factor) {
|
||||
void launch_mxfp8_experts_quant(const torch::stable::Tensor& input,
|
||||
const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& blockscale_offsets,
|
||||
torch::stable::Tensor& quant_output,
|
||||
torch::stable::Tensor& scale_factor) {
|
||||
ThrLayout thr_layout{};
|
||||
ValLayout val_layout{};
|
||||
SfR2SThrLayout r2s_thr_layout{};
|
||||
@@ -386,19 +389,18 @@ void launch_mxfp8_experts_quant(const torch::Tensor& input,
|
||||
CopyAtomR2S{}, r2s_thr_layout, r2s_val_layout); // Tiler_MN: (16, 4)
|
||||
|
||||
int max_active_blocks_per_sm = -1;
|
||||
AT_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
STD_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks_per_sm,
|
||||
mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
|
||||
decltype(tiled_copy_r2g),
|
||||
decltype(tiled_copy_r2s)>,
|
||||
THREAD_BLOCK_SIZE, 0));
|
||||
|
||||
dim3 grid(at::cuda::getCurrentDeviceProperties()->multiProcessorCount *
|
||||
max_active_blocks_per_sm,
|
||||
dim3 grid(get_device_prop()->multiProcessorCount * max_active_blocks_per_sm,
|
||||
1, 1);
|
||||
dim3 block(THREAD_BLOCK_SIZE, 1, 1);
|
||||
int num_experts = (int)problem_sizes.size(0);
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
int num_experts = static_cast<int>(problem_sizes.size(0));
|
||||
auto stream = get_current_cuda_stream(input.get_device_index());
|
||||
mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
|
||||
decltype(tiled_copy_r2g), decltype(tiled_copy_r2s)>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
@@ -231,6 +231,27 @@ void fused_qk_norm_rope(torch::stable::Tensor& qkv, int64_t num_heads_q,
|
||||
torch::stable::Tensor& position_ids,
|
||||
int64_t forced_token_heads_per_warp);
|
||||
|
||||
torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
||||
torch::stable::Tensor const& q_in, torch::stable::Tensor const& kv,
|
||||
torch::stable::Tensor& k_cache, torch::stable::Tensor const& slot_mapping,
|
||||
torch::stable::Tensor const& position_ids,
|
||||
torch::stable::Tensor const& cos_sin_cache, int64_t q_head_padded,
|
||||
double eps, int64_t cache_block_size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::stable::Tensor minimax_allreduce_rms(
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& norm_weight, torch::stable::Tensor workspace,
|
||||
int64_t const rank, int64_t const nranks, double const eps);
|
||||
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
|
||||
minimax_allreduce_rms_qk(torch::stable::Tensor qkv,
|
||||
torch::stable::Tensor const& norm_weight_q,
|
||||
torch::stable::Tensor const& norm_weight_k,
|
||||
torch::stable::Tensor workspace, int64_t const q_size,
|
||||
int64_t const kv_size, int64_t const rank,
|
||||
int64_t const nranks, double const eps);
|
||||
#endif
|
||||
|
||||
// Sampler kernels (shared CUDA/ROCm)
|
||||
void apply_repetition_penalties_(
|
||||
torch::stable::Tensor& logits, const torch::stable::Tensor& prompt_mask,
|
||||
@@ -273,6 +294,26 @@ void selective_scan_fwd(
|
||||
const std::optional<torch::stable::Tensor>& cu_chunk_seqlen,
|
||||
const std::optional<torch::stable::Tensor>& last_chunk_indices);
|
||||
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||
torch::stable::Tensor& rank_data, int64_t rank,
|
||||
bool fully_connected);
|
||||
void all_reduce(fptr_t _fa, torch::stable::Tensor& inp,
|
||||
torch::stable::Tensor& out, fptr_t reg_buffer,
|
||||
int64_t reg_buffer_sz_bytes);
|
||||
void dispose(fptr_t _fa);
|
||||
int64_t meta_size();
|
||||
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
|
||||
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
|
||||
get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
void register_graph_buffers(fptr_t _fa,
|
||||
const std::vector<std::vector<int64_t>>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets);
|
||||
std::tuple<int64_t, torch::stable::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size);
|
||||
int64_t open_mem_handle(torch::stable::Tensor& mem_handle);
|
||||
void free_shared_buffer(int64_t buffer);
|
||||
|
||||
// Activation kernels (shared CUDA/ROCm)
|
||||
void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
||||
void silu_and_mul_clamp(torch::stable::Tensor& out,
|
||||
|
||||
@@ -337,6 +337,24 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
"bool is_neox, Tensor position_ids, "
|
||||
"int forced_token_heads_per_warp=-1) -> ()");
|
||||
|
||||
ops.def(
|
||||
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
|
||||
"Tensor q_in, Tensor kv, Tensor! k_cache, "
|
||||
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
|
||||
"int q_head_padded, float eps, int cache_block_size) -> Tensor");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
ops.def(
|
||||
"minimax_allreduce_rms("
|
||||
"Tensor input, Tensor norm_weight, Tensor workspace, "
|
||||
"int rank, int nranks, float eps) -> Tensor");
|
||||
ops.def(
|
||||
"minimax_allreduce_rms_qk("
|
||||
"Tensor qkv, Tensor norm_weight_q, Tensor norm_weight_k, "
|
||||
"Tensor workspace, int q_size, int kv_size, int rank, int nranks, "
|
||||
"float eps) -> (Tensor, Tensor)");
|
||||
#endif
|
||||
|
||||
// Apply repetition penalties to logits in-place.
|
||||
ops.def(
|
||||
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
|
||||
@@ -571,6 +589,12 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
// Positional encoding kernels (shared CUDA/ROCm)
|
||||
ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding));
|
||||
ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope));
|
||||
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert",
|
||||
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert));
|
||||
#ifndef USE_ROCM
|
||||
ops.impl("minimax_allreduce_rms", TORCH_BOX(&minimax_allreduce_rms));
|
||||
ops.impl("minimax_allreduce_rms_qk", TORCH_BOX(&minimax_allreduce_rms_qk));
|
||||
#endif
|
||||
|
||||
// Sampler kernels (shared CUDA/ROCm)
|
||||
ops.impl("apply_repetition_penalties_",
|
||||
@@ -725,6 +749,45 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) {
|
||||
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(_C_custom_ar, custom_ar) {
|
||||
custom_ar.def(
|
||||
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
|
||||
"int rank, bool fully_connected) -> int");
|
||||
custom_ar.def(
|
||||
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
|
||||
"int reg_buffer_sz_bytes) -> ()");
|
||||
custom_ar.def("dispose(int fa) -> ()");
|
||||
custom_ar.def("meta_size() -> int");
|
||||
custom_ar.def("register_buffer(int fa, int[] ipc_tensors) -> ()");
|
||||
custom_ar.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])");
|
||||
custom_ar.def(
|
||||
"register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
|
||||
custom_ar.def("allocate_shared_buffer_and_handle(int size) -> (int, Tensor)");
|
||||
custom_ar.def("open_mem_handle(Tensor mem_handle) -> int");
|
||||
custom_ar.def("free_shared_buffer(int ptr) -> ()");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CUDA, custom_ar) {
|
||||
custom_ar.impl("init_custom_ar", TORCH_BOX(&init_custom_ar));
|
||||
custom_ar.impl("all_reduce", TORCH_BOX(&all_reduce));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CPU, custom_ar) {
|
||||
custom_ar.impl("open_mem_handle", TORCH_BOX(&open_mem_handle));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CompositeExplicitAutograd, custom_ar) {
|
||||
custom_ar.impl("dispose", TORCH_BOX(&dispose));
|
||||
custom_ar.impl("meta_size", TORCH_BOX(&meta_size));
|
||||
custom_ar.impl("register_buffer", TORCH_BOX(®ister_buffer));
|
||||
custom_ar.impl("get_graph_buffer_ipc_meta",
|
||||
TORCH_BOX(&get_graph_buffer_ipc_meta));
|
||||
custom_ar.impl("register_graph_buffers", TORCH_BOX(®ister_graph_buffers));
|
||||
custom_ar.impl("allocate_shared_buffer_and_handle",
|
||||
TORCH_BOX(&allocate_shared_buffer_and_handle));
|
||||
custom_ar.impl("free_shared_buffer", TORCH_BOX(&free_shared_buffer));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C_cache_ops, CPU, ops) {
|
||||
ops.impl("swap_blocks_batch", TORCH_BOX(&swap_blocks_batch));
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include <torch/types.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
namespace vllm {
|
||||
namespace tensorrt_llm {
|
||||
@@ -51,7 +51,7 @@ static constexpr int kElemsPerAccess = ElemsPerAccess<DType>::value;
|
||||
struct MiniMaxReduceRMSParams {
|
||||
int nranks{};
|
||||
int rank{};
|
||||
at::ScalarType dtype{at::ScalarType::Undefined};
|
||||
torch::headeronly::ScalarType dtype{torch::headeronly::ScalarType::Undefined};
|
||||
int size_q{};
|
||||
int hidden_dim{};
|
||||
int size_k{};
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
// Adapted from SGLang:
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass_mxfp8_grouped_mm_launcher.cuh"
|
||||
|
||||
void cutlass_mxfp8_grouped_mm(const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& sfa,
|
||||
const torch::Tensor& sfb, torch::Tensor& d,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& blockscale_offsets) {
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have shape (num_experts, 3)");
|
||||
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
|
||||
"expert_offsets must be int32");
|
||||
TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
|
||||
"blockscale_offsets must be int32");
|
||||
TORCH_CHECK(a.dim() == 2, "a must be a 2D tensor of shape (num_tokens, k)");
|
||||
TORCH_CHECK(b.dim() == 3,
|
||||
"b must be a 3D tensor of shape (num_experts, k, n)");
|
||||
TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0,
|
||||
"k should align 128");
|
||||
TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128");
|
||||
TORCH_CHECK(a.strides()[1] == 1, "a must be row major");
|
||||
TORCH_CHECK(b.strides()[1] == 1, "b must be column major");
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
if (d.dtype() == torch::kBFloat16) {
|
||||
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
|
||||
cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
|
||||
blockscale_offsets, stream);
|
||||
} else if (d.dtype() == torch::kFloat16) {
|
||||
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
|
||||
cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
|
||||
blockscale_offsets, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false,
|
||||
"No implemented cutlass_mxfp8_grouped_mm for "
|
||||
"current device");
|
||||
#endif
|
||||
}
|
||||
|
||||
#include "core/registration.h"
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("cutlass_mxfp8_grouped_mm", cutlass_mxfp8_grouped_mm);
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
// Adapted from SGLang:
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "mxfp8_experts_quant.cuh"
|
||||
|
||||
void mxfp8_experts_quant(const torch::Tensor& input,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& blockscale_offsets,
|
||||
torch::Tensor& quant_output,
|
||||
torch::Tensor& scale_factor) {
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
TORCH_CHECK(input.dim() == 2, "input must be 2D tensor");
|
||||
TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128");
|
||||
TORCH_CHECK(input.strides()[1] == 1, "input must be row major");
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
|
||||
"expert_offsets must be int32");
|
||||
TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
|
||||
"blockscale_offsets must be int32");
|
||||
|
||||
auto groups = problem_sizes.size(0);
|
||||
TORCH_CHECK(
|
||||
expert_offsets.dim() == 1 && expert_offsets.size(0) == groups,
|
||||
"expert_offsets must be 1D and have size equal to the number of groups");
|
||||
TORCH_CHECK(
|
||||
blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
|
||||
"blockscale_offsets must be 1D and have size equal to the number of "
|
||||
"groups");
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
if (input.dtype() == torch::kBFloat16) {
|
||||
expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>(
|
||||
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
|
||||
scale_factor);
|
||||
} else if (input.dtype() == torch::kFloat16) {
|
||||
expert_specialization::launch_mxfp8_experts_quant<__half>(
|
||||
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
|
||||
scale_factor);
|
||||
} else {
|
||||
TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false,
|
||||
"No implemented mxfp8_experts_quant for "
|
||||
"current device");
|
||||
#endif
|
||||
}
|
||||
|
||||
#include "core/registration.h"
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("mxfp8_experts_quant", mxfp8_experts_quant);
|
||||
}
|
||||
-36
@@ -40,12 +40,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||
torch::Tensor& weight, double epsilon);
|
||||
|
||||
torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
||||
torch::Tensor const& q_in, torch::Tensor const& kv, torch::Tensor& k_cache,
|
||||
torch::Tensor const& slot_mapping, torch::Tensor const& position_ids,
|
||||
torch::Tensor const& cos_sin_cache, int64_t q_head_padded, double eps,
|
||||
int64_t cache_block_size);
|
||||
|
||||
void silu_and_mul_per_block_quant(torch::Tensor& out,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor& scales, int64_t group_size,
|
||||
@@ -107,24 +101,6 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
|
||||
int64_t activation_kind);
|
||||
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||
torch::Tensor& rank_data, int64_t rank,
|
||||
bool fully_connected);
|
||||
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
|
||||
void dispose(fptr_t _fa);
|
||||
int64_t meta_size();
|
||||
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
|
||||
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
|
||||
get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
void register_graph_buffers(fptr_t _fa,
|
||||
const std::vector<std::vector<int64_t>>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets);
|
||||
std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size);
|
||||
int64_t open_mem_handle(torch::Tensor& mem_handle);
|
||||
void free_shared_buffer(int64_t buffer);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
fptr_t init_custom_qr(int64_t rank, int64_t world_size,
|
||||
std::optional<int64_t> qr_max_size = std::nullopt);
|
||||
@@ -135,15 +111,3 @@ void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
int64_t quant_level, bool cast_bf2half = false);
|
||||
int64_t qr_max_size();
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
|
||||
torch::Tensor const& norm_weight,
|
||||
torch::Tensor workspace, int64_t const rank,
|
||||
int64_t const nranks, double const eps);
|
||||
std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
|
||||
torch::Tensor qkv, torch::Tensor const& norm_weight_q,
|
||||
torch::Tensor const& norm_weight_k, torch::Tensor workspace,
|
||||
int64_t const q_size, int64_t const kv_size, int64_t const rank,
|
||||
int64_t const nranks, double const eps);
|
||||
#endif
|
||||
|
||||
+20
-78
@@ -55,14 +55,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
|
||||
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
|
||||
// kernel launch.
|
||||
ops.def(
|
||||
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
|
||||
"Tensor q_in, Tensor kv, Tensor! k_cache, "
|
||||
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
|
||||
"int q_head_padded, float eps, int cache_block_size) -> Tensor");
|
||||
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
|
||||
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);
|
||||
// kernel launch. Registered in _C_stable_libtorch.
|
||||
|
||||
// Quantization ops
|
||||
#ifndef USE_ROCM
|
||||
@@ -163,34 +156,27 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
ops.def(
|
||||
"minimax_allreduce_rms("
|
||||
"Tensor input,"
|
||||
"Tensor norm_weight,"
|
||||
"Tensor workspace,"
|
||||
"int rank,"
|
||||
"int nranks,"
|
||||
"float eps) -> Tensor");
|
||||
ops.impl("minimax_allreduce_rms", torch::kCUDA, &minimax_allreduce_rms);
|
||||
ops.def(
|
||||
"minimax_allreduce_rms_qk("
|
||||
"Tensor qkv,"
|
||||
"Tensor norm_weight_q,"
|
||||
"Tensor norm_weight_k,"
|
||||
"Tensor workspace,"
|
||||
"int q_size,"
|
||||
"int kv_size,"
|
||||
"int rank,"
|
||||
"int nranks,"
|
||||
"float eps) -> (Tensor, Tensor)");
|
||||
ops.impl("minimax_allreduce_rms_qk", torch::kCUDA, &minimax_allreduce_rms_qk);
|
||||
|
||||
// conditionally compiled so impl in source file
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
TORCH_LIBRARY_FRAGMENT(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
||||
// Quick Reduce all-reduce kernels (ROCm-only; stays on legacy _C).
|
||||
custom_ar.def(
|
||||
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
|
||||
"cast_bf2half) -> ()");
|
||||
custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
|
||||
|
||||
custom_ar.def("init_custom_qr", &init_custom_qr);
|
||||
custom_ar.def("qr_destroy", &qr_destroy);
|
||||
custom_ar.def("qr_get_handle", &qr_get_handle);
|
||||
|
||||
custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
|
||||
custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
|
||||
|
||||
custom_ar.def("qr_max_size", &qr_max_size);
|
||||
}
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||
// Cuda utils
|
||||
|
||||
@@ -205,48 +191,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||
&get_max_shared_memory_per_block_device_attribute);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
||||
// Custom all-reduce kernels
|
||||
custom_ar.def(
|
||||
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
|
||||
"int rank, bool fully_connected) -> int");
|
||||
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
|
||||
custom_ar.def(
|
||||
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
|
||||
"int reg_buffer_sz_bytes) -> ()");
|
||||
custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);
|
||||
|
||||
custom_ar.def("dispose", &dispose);
|
||||
custom_ar.def("meta_size", &meta_size);
|
||||
|
||||
custom_ar.def("register_buffer", ®ister_buffer);
|
||||
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
||||
custom_ar.def("register_graph_buffers", ®ister_graph_buffers);
|
||||
|
||||
custom_ar.def("allocate_shared_buffer_and_handle",
|
||||
&allocate_shared_buffer_and_handle);
|
||||
custom_ar.def("open_mem_handle(Tensor mem_handle) -> int", &open_mem_handle);
|
||||
custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle);
|
||||
|
||||
custom_ar.def("free_shared_buffer", &free_shared_buffer);
|
||||
#ifdef USE_ROCM
|
||||
// Quick Reduce all-reduce kernels
|
||||
custom_ar.def(
|
||||
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
|
||||
"cast_bf2half) -> ()");
|
||||
custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
|
||||
|
||||
custom_ar.def("init_custom_qr", &init_custom_qr);
|
||||
custom_ar.def("qr_destroy", &qr_destroy);
|
||||
|
||||
custom_ar.def("qr_get_handle", &qr_get_handle);
|
||||
|
||||
custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
|
||||
custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
|
||||
|
||||
// Max input size in bytes
|
||||
custom_ar.def("qr_max_size", &qr_max_size);
|
||||
#endif
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
||||
@@ -50,7 +50,7 @@ struct _typeConvert<float> {
|
||||
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
|
||||
// CUDA < 12.0 runs into issues with packed type conversion
|
||||
template <>
|
||||
struct _typeConvert<c10::Half> {
|
||||
struct _typeConvert<torch::headeronly::Half> {
|
||||
static constexpr bool exists = true;
|
||||
using hip_type = __half;
|
||||
using packed_hip_type = __half2;
|
||||
@@ -73,7 +73,7 @@ struct _typeConvert<c10::Half> {
|
||||
// CUDA_ARCH < 800 does not have BF16 support
|
||||
// ROCm 7.0+ supports bfloat16
|
||||
template <>
|
||||
struct _typeConvert<c10::BFloat16> {
|
||||
struct _typeConvert<torch::headeronly::BFloat16> {
|
||||
static constexpr bool exists = true;
|
||||
using hip_type = __nv_bfloat16;
|
||||
using packed_hip_type = __nv_bfloat162;
|
||||
|
||||
@@ -7,6 +7,9 @@ from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import load_chat_template
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.renderers.hf import (
|
||||
_consolidate_system_messages,
|
||||
_convert_developer_to_system,
|
||||
_detect_developer_role_support,
|
||||
_get_hf_base_chat_template_params,
|
||||
_try_extract_ast,
|
||||
resolve_chat_template,
|
||||
@@ -592,3 +595,322 @@ def test_get_gen_prompt(
|
||||
f"The generated prompt does not match the expected output for "
|
||||
f"model {model} and template {template}"
|
||||
)
|
||||
|
||||
|
||||
class TestConvertDeveloperToSystem:
|
||||
def test_converts_role(self):
|
||||
conversation = [
|
||||
{"role": "developer", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
result = _convert_developer_to_system(conversation)
|
||||
assert result[0]["role"] == "system"
|
||||
assert result[0]["content"] == "You are helpful."
|
||||
assert result[1]["role"] == "user"
|
||||
|
||||
def test_removes_tools_key(self):
|
||||
conversation = [
|
||||
{
|
||||
"role": "developer",
|
||||
"content": "Instructions",
|
||||
"tools": [{"type": "function"}],
|
||||
},
|
||||
]
|
||||
result = _convert_developer_to_system(conversation)
|
||||
assert "tools" not in result[0]
|
||||
|
||||
def test_no_developer_messages_unchanged(self):
|
||||
conversation = [
|
||||
{"role": "system", "content": "System prompt"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
result = _convert_developer_to_system(conversation)
|
||||
assert result[0]["role"] == "system"
|
||||
assert result[1]["role"] == "user"
|
||||
|
||||
def test_does_not_mutate_original(self):
|
||||
original = {
|
||||
"role": "developer",
|
||||
"content": "Instructions",
|
||||
"tools": [{"type": "function"}],
|
||||
}
|
||||
conversation = [original]
|
||||
_convert_developer_to_system(conversation)
|
||||
assert original["role"] == "developer"
|
||||
assert "tools" in original
|
||||
|
||||
|
||||
# --- Developer role detection and conversion tests ---
|
||||
|
||||
CHATML_TEMPLATE = (
|
||||
"{% for message in messages %}"
|
||||
"{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}"
|
||||
"{% if (loop.last and add_generation_prompt) or not loop.last %}"
|
||||
"{{ '<|im_end|>' + '\\n'}}"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}"
|
||||
"{{ '<|im_start|>assistant\\n' }}"
|
||||
"{% endif %}"
|
||||
)
|
||||
|
||||
TEMPLATE_WITH_DEVELOPER = (
|
||||
"{% for message in messages %}"
|
||||
"{% if message['role'] == 'developer' %}"
|
||||
"{{'<|im_start|>developer\\n' + message['content'] + '<|im_end|>\\n'}}"
|
||||
"{% elif message['role'] == 'system' %}"
|
||||
"{{'<|im_start|>system\\n' + message['content'] + '<|im_end|>\\n'}}"
|
||||
"{% elif message['role'] == 'user' %}"
|
||||
"{{'<|im_start|>user\\n' + message['content'] + '<|im_end|>\\n'}}"
|
||||
"{% elif message['role'] == 'assistant' %}"
|
||||
"{{'<|im_start|>assistant\\n' + message['content'] + '<|im_end|>\\n'}}"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}"
|
||||
"{{ '<|im_start|>assistant\\n' }}"
|
||||
"{% endif %}"
|
||||
)
|
||||
|
||||
STRICT_ROLE_TEMPLATE = (
|
||||
"{% for message in messages %}"
|
||||
"{% if message['role'] == 'system' %}"
|
||||
"{{'<|im_start|>system\\n' + message['content'] + '<|im_end|>\\n'}}"
|
||||
"{% elif message['role'] == 'user' %}"
|
||||
"{{'<|im_start|>user\\n' + message['content'] + '<|im_end|>\\n'}}"
|
||||
"{% elif message['role'] == 'assistant' %}"
|
||||
"{{'<|im_start|>assistant\\n' + message['content'] + '<|im_end|>\\n'}}"
|
||||
"{% else %}"
|
||||
"{{ raise_exception('Unexpected message role: ' + message['role']) }}"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}"
|
||||
"{{ '<|im_start|>assistant\\n' }}"
|
||||
"{% endif %}"
|
||||
)
|
||||
|
||||
|
||||
class TestDetectDeveloperRoleSupport:
|
||||
def test_absent_in_chatml(self):
|
||||
assert _detect_developer_role_support(CHATML_TEMPLATE) is False
|
||||
|
||||
def test_present_double_quotes(self):
|
||||
assert _detect_developer_role_support(TEMPLATE_WITH_DEVELOPER) is True
|
||||
|
||||
def test_present_single_quotes(self):
|
||||
template = TEMPLATE_WITH_DEVELOPER.replace('"developer"', "'developer'")
|
||||
assert _detect_developer_role_support(template) is True
|
||||
|
||||
def test_absent_in_strict_template(self):
|
||||
assert _detect_developer_role_support(STRICT_ROLE_TEMPLATE) is False
|
||||
|
||||
|
||||
class TestSafeApplyChatTemplateDeveloperRole:
|
||||
@pytest.fixture
|
||||
def model_config(self):
|
||||
return ModelConfig(
|
||||
"facebook/opt-125m",
|
||||
tokenizer="facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer(self):
|
||||
return get_tokenizer("facebook/opt-125m")
|
||||
|
||||
def test_developer_converted_to_system_for_chatml(self, model_config, tokenizer):
|
||||
conversation = [
|
||||
{"role": "developer", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
result = safe_apply_chat_template(
|
||||
model_config,
|
||||
tokenizer,
|
||||
conversation,
|
||||
chat_template=CHATML_TEMPLATE,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
assert "<|im_start|>system" in result
|
||||
assert "You are a helpful assistant." in result
|
||||
assert "<|im_start|>developer" not in result
|
||||
|
||||
def test_developer_preserved_when_template_supports_it(
|
||||
self, model_config, tokenizer
|
||||
):
|
||||
conversation = [
|
||||
{"role": "developer", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
result = safe_apply_chat_template(
|
||||
model_config,
|
||||
tokenizer,
|
||||
conversation,
|
||||
chat_template=TEMPLATE_WITH_DEVELOPER,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
assert "<|im_start|>developer" in result
|
||||
assert "You are a helpful assistant." in result
|
||||
|
||||
def test_developer_does_not_crash_strict_template(self, model_config, tokenizer):
|
||||
conversation = [
|
||||
{"role": "developer", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
result = safe_apply_chat_template(
|
||||
model_config,
|
||||
tokenizer,
|
||||
conversation,
|
||||
chat_template=STRICT_ROLE_TEMPLATE,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
assert "<|im_start|>system" in result
|
||||
assert "You are a helpful assistant." in result
|
||||
|
||||
def test_no_developer_messages_no_overhead(self, model_config, tokenizer):
|
||||
conversation = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
result = safe_apply_chat_template(
|
||||
model_config,
|
||||
tokenizer,
|
||||
conversation,
|
||||
chat_template=CHATML_TEMPLATE,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
assert "<|im_start|>system" in result
|
||||
assert "You are helpful." in result
|
||||
|
||||
def test_developer_at_non_first_position_consolidated(
|
||||
self, model_config, tokenizer
|
||||
):
|
||||
conversation = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "developer", "content": "Be concise."},
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
]
|
||||
result = safe_apply_chat_template(
|
||||
model_config,
|
||||
tokenizer,
|
||||
conversation,
|
||||
chat_template=SYSTEM_FIRST_TEMPLATE,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
assert "<|im_start|>system" in result
|
||||
assert "You are helpful." in result
|
||||
assert "Be concise." in result
|
||||
assert "What is 2+2?" in result
|
||||
|
||||
def test_developer_only_no_prior_system(self, model_config, tokenizer):
|
||||
conversation = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "developer", "content": "Be concise."},
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
]
|
||||
result = safe_apply_chat_template(
|
||||
model_config,
|
||||
tokenizer,
|
||||
conversation,
|
||||
chat_template=SYSTEM_FIRST_TEMPLATE,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
assert "<|im_start|>system" in result
|
||||
assert "Be concise." in result
|
||||
|
||||
|
||||
SYSTEM_FIRST_TEMPLATE = (
|
||||
"{% for message in messages %}"
|
||||
"{% if message['role'] == 'system' %}"
|
||||
"{% if not loop.first %}"
|
||||
"{{ raise_exception('System message must be at the beginning.') }}"
|
||||
"{% endif %}"
|
||||
"{{'<|im_start|>system\\n' + message['content'] + '<|im_end|>\\n'}}"
|
||||
"{% elif message['role'] == 'user' %}"
|
||||
"{{'<|im_start|>user\\n' + message['content'] + '<|im_end|>\\n'}}"
|
||||
"{% elif message['role'] == 'assistant' %}"
|
||||
"{{'<|im_start|>assistant\\n' + message['content'] + '<|im_end|>\\n'}}"
|
||||
"{% else %}"
|
||||
"{{ raise_exception('Unexpected message role: ' + message['role']) }}"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}"
|
||||
"{{ '<|im_start|>assistant\\n' }}"
|
||||
"{% endif %}"
|
||||
)
|
||||
|
||||
|
||||
class TestConsolidateSystemMessages:
|
||||
def test_no_system_messages_unchanged(self):
|
||||
conversation = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
]
|
||||
result = _consolidate_system_messages(conversation)
|
||||
assert result == conversation
|
||||
|
||||
def test_single_system_at_start_unchanged(self):
|
||||
conversation = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
result = _consolidate_system_messages(conversation)
|
||||
assert result == conversation
|
||||
|
||||
def test_system_at_non_first_position_moved(self):
|
||||
conversation = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
]
|
||||
result = _consolidate_system_messages(conversation)
|
||||
assert result[0]["role"] == "system"
|
||||
assert result[0]["content"] == "You are helpful."
|
||||
assert result[1]["role"] == "user"
|
||||
assert result[1]["content"] == "Hello"
|
||||
|
||||
def test_multiple_system_messages_merged(self):
|
||||
conversation = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "system", "content": "Be concise."},
|
||||
]
|
||||
result = _consolidate_system_messages(conversation)
|
||||
assert len(result) == 2
|
||||
assert result[0]["role"] == "system"
|
||||
assert result[0]["content"] == "You are helpful.\n\nBe concise."
|
||||
assert result[1]["role"] == "user"
|
||||
|
||||
def test_list_content_handled(self):
|
||||
conversation = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "Rule 1."},
|
||||
{"type": "text", "text": "Rule 2."},
|
||||
],
|
||||
},
|
||||
]
|
||||
result = _consolidate_system_messages(conversation)
|
||||
assert result[0]["role"] == "system"
|
||||
assert result[0]["content"] == "Rule 1.\nRule 2."
|
||||
assert result[1]["role"] == "user"
|
||||
|
||||
def test_does_not_mutate_original(self):
|
||||
conversation = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
]
|
||||
original_len = len(conversation)
|
||||
_consolidate_system_messages(conversation)
|
||||
assert len(conversation) == original_len
|
||||
assert conversation[0]["role"] == "user"
|
||||
assert conversation[1]["role"] == "system"
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm.utils.hashing import sha256
|
||||
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
|
||||
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.core.single_type_kv_cache_manager import register_all_kvcache_specs
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
@@ -160,6 +161,7 @@ def create_scheduler(
|
||||
],
|
||||
)
|
||||
cache_config.num_gpu_blocks = num_blocks
|
||||
register_all_kvcache_specs(vllm_config)
|
||||
scheduler_cls = AsyncScheduler if async_scheduling else Scheduler
|
||||
return scheduler_cls(
|
||||
vllm_config=vllm_config,
|
||||
|
||||
@@ -21,19 +21,20 @@ dp_ep_configs=(
|
||||
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
|
||||
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
|
||||
)
|
||||
# We assume HMA enabled by default.
|
||||
hybrid_ssm_configs=(
|
||||
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
|
||||
"VLLM_SSM_CONV_STATE_LAYOUT=DS GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
|
||||
# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
|
||||
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
|
||||
"VLLM_SSM_CONV_STATE_LAYOUT=DS PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
|
||||
# GDN (Qwen3.5)
|
||||
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=Qwen/Qwen3.5-0.8B"
|
||||
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=Qwen/Qwen3.5-0.8B VLLM_SERVE_EXTRA_ARGS=--no-async-scheduling"
|
||||
"VLLM_SSM_CONV_STATE_LAYOUT=DS GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=Qwen/Qwen3.5-0.8B"
|
||||
"VLLM_SSM_CONV_STATE_LAYOUT=DS PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=Qwen/Qwen3.5-0.8B VLLM_SERVE_EXTRA_ARGS=--no-async-scheduling"
|
||||
)
|
||||
sw_attn_configs=(
|
||||
# NOTE: gemma3 does not work with FlashInfer
|
||||
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192" # SW model
|
||||
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
|
||||
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
|
||||
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
|
||||
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
|
||||
)
|
||||
|
||||
# Select config array based on DP_EP env var
|
||||
@@ -50,14 +51,6 @@ else
|
||||
configs=("${tp_configs[@]}")
|
||||
fi
|
||||
|
||||
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
|
||||
# Append ENABLE_HMA_FLAG=1 to each config in the selected array
|
||||
echo "ENABLE_HMA_FLAG is set, appending ENABLE_HMA_FLAG=1 to each config"
|
||||
for i in "${!configs[@]}"; do
|
||||
configs[$i]="ENABLE_HMA_FLAG=1 ${configs[$i]}"
|
||||
done
|
||||
fi
|
||||
|
||||
run_tests() {
|
||||
local label=$1
|
||||
local extra_args=$2
|
||||
|
||||
@@ -11,7 +11,7 @@ SCRIPT="v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh"
|
||||
eagle3_config="SD_METHOD=eagle3 MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct SD_MODEL=RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3 NUM_SPEC_TOKENS=3"
|
||||
|
||||
# MTP: Qwen3.5-0.8B-Base with hybrid SSM flags.
|
||||
mtp_config="SD_METHOD=mtp MODEL_NAME=Qwen/Qwen3.5-0.8B-Base SD_MODEL=Qwen/Qwen3.5-0.8B-Base NUM_SPEC_TOKENS=1 BLOCK_SIZE=32 MAX_MODEL_LEN=4096 VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 KV_BUFFER_DEVICES=cuda"
|
||||
mtp_config="SD_METHOD=mtp MODEL_NAME=Qwen/Qwen3.5-0.8B-Base SD_MODEL=Qwen/Qwen3.5-0.8B-Base NUM_SPEC_TOKENS=1 BLOCK_SIZE=32 MAX_MODEL_LEN=4096 VLLM_SSM_CONV_STATE_LAYOUT=DS KV_BUFFER_DEVICES=cuda"
|
||||
|
||||
configs=(
|
||||
"$eagle3_config"
|
||||
|
||||
@@ -5,11 +5,6 @@ set -xe
|
||||
KV_BUFFER_DEVICE="cuda" # Default to cuda
|
||||
ATTENTION_BACKEND="" # Default to empty (use vllm default)
|
||||
CROSS_LAYERS_BLOCKS="False"
|
||||
ENABLE_HMA_VAR="" # Default to empty (HMA disabled by default for kv connector)
|
||||
# Check for ENABLE_HMA_FLAG environment variable
|
||||
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
|
||||
ENABLE_HMA_VAR="--no-disable-hybrid-kv-cache-manager"
|
||||
fi
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
@@ -37,9 +32,6 @@ echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
|
||||
if [[ -n "$ATTENTION_BACKEND" ]]; then
|
||||
echo "Using attention backend: $ATTENTION_BACKEND"
|
||||
fi
|
||||
if [[ -n "$ENABLE_HMA_VAR" ]]; then
|
||||
echo "HMA (Hybrid KV Cache Manager) enabled"
|
||||
fi
|
||||
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
|
||||
echo "vLLM serve extra args: $VLLM_SERVE_EXTRA_ARGS"
|
||||
fi
|
||||
@@ -180,10 +172,6 @@ run_tests_for_model() {
|
||||
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
|
||||
fi
|
||||
|
||||
# Add HMA flag if specified
|
||||
if [[ -n "$ENABLE_HMA_VAR" ]]; then
|
||||
BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR"
|
||||
fi
|
||||
|
||||
FULL_CMD="$BASE_CMD"
|
||||
eval "$FULL_CMD &"
|
||||
@@ -232,10 +220,6 @@ run_tests_for_model() {
|
||||
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
|
||||
fi
|
||||
|
||||
# Add HMA flag if specified
|
||||
if [[ -n "$ENABLE_HMA_VAR" ]]; then
|
||||
BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR"
|
||||
fi
|
||||
|
||||
# DP-EP attention mode
|
||||
if [[ -z "$DP_EP" ]]; then
|
||||
|
||||
@@ -27,7 +27,6 @@
|
||||
# ROCM_AITER_UNIFIED_ATTN
|
||||
# NVIDIA options: FLASH_ATTN, FLASHINFER
|
||||
# VLLM_SSM_CONV_STATE_LAYOUT - SSM conv state layout (e.g. "DS" required for Mamba models)
|
||||
# ENABLE_HMA_FLAG - set to 1 to enable hybrid KV cache manager
|
||||
# VLLM_SERVE_EXTRA_ARGS - comma-separated extra args for vllm serve
|
||||
set -ex
|
||||
|
||||
@@ -85,13 +84,7 @@ if [[ -z "${ATTENTION_BACKEND:-}" ]]; then
|
||||
fi
|
||||
echo "Using attention backend: ${ATTENTION_BACKEND}"
|
||||
|
||||
# ── HMA & extra serve args ────────────────────────────────────────────
|
||||
|
||||
ENABLE_HMA_VAR=""
|
||||
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
|
||||
ENABLE_HMA_VAR="--no-disable-hybrid-kv-cache-manager"
|
||||
echo "HMA (Hybrid KV Cache Manager) enabled"
|
||||
fi
|
||||
# ── Extra serve args ─────────────────────────────────────────────────
|
||||
|
||||
EXTRA_SERVE_ARGS=()
|
||||
if [[ -n "${VLLM_SERVE_EXTRA_ARGS:-}" ]]; then
|
||||
@@ -258,7 +251,6 @@ run_test_for_device() {
|
||||
--kv-transfer-config "$kv_config" \
|
||||
--speculative-config "$PREFILL_SPEC_CONFIG" \
|
||||
--attention-backend $ATTENTION_BACKEND \
|
||||
${ENABLE_HMA_VAR} \
|
||||
${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} &
|
||||
local SERVER_PID=$!
|
||||
|
||||
@@ -298,7 +290,6 @@ run_test_for_device() {
|
||||
--kv-transfer-config "$kv_config" \
|
||||
--speculative-config "$DECODE_SPEC_CONFIG" \
|
||||
--attention-backend $ATTENTION_BACKEND \
|
||||
${ENABLE_HMA_VAR} \
|
||||
${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} &
|
||||
local SERVER_PID=$!
|
||||
|
||||
|
||||
@@ -386,8 +386,6 @@ def test_fewer_blocks_with_hma(monkeypatch, model_name, sw_size):
|
||||
"kv_transfer_config": kv_transfer_config,
|
||||
"max_model_len": 2048,
|
||||
"max_num_seqs": 1,
|
||||
# NOTE: Make sure HMA is enabled
|
||||
"disable_hybrid_kv_cache_manager": False,
|
||||
"max_num_batched_tokens": 2048,
|
||||
"enable_prefix_caching": False,
|
||||
"block_size": block_size,
|
||||
|
||||
@@ -30,6 +30,9 @@ from vllm.v1.core.sched.output import (
|
||||
NewRequestData,
|
||||
SchedulerOutput,
|
||||
)
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
register_all_kvcache_specs,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
@@ -68,6 +71,9 @@ def _make_kv_cache_config(
|
||||
"""Build a KVCacheConfig with non-empty kv_cache_tensors."""
|
||||
groups = []
|
||||
tensors = []
|
||||
register_all_kvcache_specs(
|
||||
vllm_config=None
|
||||
) # Ensure specs are registered for tests
|
||||
for g in range(num_groups):
|
||||
layer_names = [f"layer_{g}"]
|
||||
groups.append(
|
||||
|
||||
@@ -0,0 +1,322 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
DeviceConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
ChunkedLocalAttentionManager,
|
||||
CrossAttentionManager,
|
||||
FullAttentionManager,
|
||||
MambaManager,
|
||||
SingleTypeKVCacheManager,
|
||||
SinkFullAttentionManager,
|
||||
SlidingWindowManager,
|
||||
register_all_kvcache_specs,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
ChunkedLocalAttentionSpec,
|
||||
CrossAttentionSpec,
|
||||
FullAttentionSpec,
|
||||
HiddenStateCacheSpec,
|
||||
KVCacheSpec,
|
||||
MambaSpec,
|
||||
MLAAttentionSpec,
|
||||
SinkFullAttentionSpec,
|
||||
SlidingWindowMLASpec,
|
||||
SlidingWindowSpec,
|
||||
TQFullAttentionSpec,
|
||||
UniformTypeKVCacheSpecs,
|
||||
)
|
||||
from vllm.v1.kv_cache_spec_registry import (
|
||||
_REGISTRY_KVCACHESPEC_LIST,
|
||||
KVCacheSpecRegistry,
|
||||
register_kv_cache_spec,
|
||||
)
|
||||
|
||||
|
||||
def make_vllm_config() -> VllmConfig:
|
||||
return VllmConfig(
|
||||
cache_config=CacheConfig(
|
||||
block_size=64,
|
||||
cache_dtype="bfloat16",
|
||||
),
|
||||
device_config=DeviceConfig(device="cpu"),
|
||||
)
|
||||
|
||||
|
||||
vllm_config = make_vllm_config()
|
||||
register_all_kvcache_specs(vllm_config)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def restore_kv_cache_spec_registry():
|
||||
registry = _REGISTRY_KVCACHESPEC_LIST.copy()
|
||||
yield
|
||||
_REGISTRY_KVCACHESPEC_LIST.clear()
|
||||
_REGISTRY_KVCACHESPEC_LIST.update(registry)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _TrulyUnregisteredSpec(KVCacheSpec):
|
||||
"""
|
||||
A spec that inherits directly from KVCacheSpec with no registered
|
||||
ancestor in the MRO. Used to test that the registry correctly raises
|
||||
when no entry can be found.
|
||||
"""
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
return self.block_size * 128
|
||||
|
||||
def max_memory_usage_bytes(self, _) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
||||
FullAttentionSpec: FullAttentionManager,
|
||||
TQFullAttentionSpec: FullAttentionManager,
|
||||
MLAAttentionSpec: FullAttentionManager,
|
||||
HiddenStateCacheSpec: FullAttentionManager,
|
||||
SlidingWindowSpec: SlidingWindowManager,
|
||||
SlidingWindowMLASpec: SlidingWindowManager,
|
||||
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
|
||||
MambaSpec: MambaManager,
|
||||
CrossAttentionSpec: CrossAttentionManager,
|
||||
SinkFullAttentionSpec: SinkFullAttentionManager,
|
||||
}
|
||||
|
||||
spec_uniform_base_map: dict[type[KVCacheSpec], type[KVCacheSpec]] = {
|
||||
FullAttentionSpec: FullAttentionSpec,
|
||||
TQFullAttentionSpec: FullAttentionSpec,
|
||||
MLAAttentionSpec: FullAttentionSpec,
|
||||
HiddenStateCacheSpec: FullAttentionSpec,
|
||||
SlidingWindowSpec: SlidingWindowSpec,
|
||||
SlidingWindowMLASpec: SlidingWindowMLASpec,
|
||||
ChunkedLocalAttentionSpec: ChunkedLocalAttentionSpec,
|
||||
MambaSpec: MambaSpec,
|
||||
CrossAttentionSpec: CrossAttentionSpec,
|
||||
SinkFullAttentionSpec: FullAttentionSpec,
|
||||
}
|
||||
|
||||
spec_args_map: dict[type[KVCacheSpec], dict[str, Any]] = {
|
||||
FullAttentionSpec: dict(
|
||||
block_size=64, num_kv_heads=8, head_size=128, dtype=torch.bfloat16
|
||||
),
|
||||
TQFullAttentionSpec: dict(
|
||||
block_size=64,
|
||||
num_kv_heads=8,
|
||||
head_size=128,
|
||||
dtype=torch.bfloat16,
|
||||
tq_slot_size=256,
|
||||
),
|
||||
MLAAttentionSpec: dict(
|
||||
block_size=64, num_kv_heads=1, head_size=128, dtype=torch.bfloat16
|
||||
),
|
||||
HiddenStateCacheSpec: dict(
|
||||
block_size=64, num_kv_heads=1, head_size=128, dtype=torch.bfloat16
|
||||
),
|
||||
SlidingWindowSpec: dict(
|
||||
block_size=64,
|
||||
num_kv_heads=8,
|
||||
head_size=128,
|
||||
dtype=torch.bfloat16,
|
||||
sliding_window=128,
|
||||
),
|
||||
SlidingWindowMLASpec: dict(
|
||||
block_size=64,
|
||||
num_kv_heads=1,
|
||||
head_size=128,
|
||||
dtype=torch.bfloat16,
|
||||
sliding_window=128,
|
||||
),
|
||||
ChunkedLocalAttentionSpec: dict(
|
||||
block_size=64,
|
||||
num_kv_heads=8,
|
||||
head_size=128,
|
||||
dtype=torch.bfloat16,
|
||||
attention_chunk_size=4,
|
||||
),
|
||||
MambaSpec: dict(
|
||||
block_size=64,
|
||||
shapes=((2, 512), (3, 32, 32)),
|
||||
dtypes=(torch.float32, torch.float32),
|
||||
mamba_cache_mode="align",
|
||||
num_speculative_blocks=2,
|
||||
),
|
||||
CrossAttentionSpec: dict(
|
||||
block_size=64, num_kv_heads=8, head_size=128, dtype=torch.bfloat16
|
||||
),
|
||||
SinkFullAttentionSpec: dict(
|
||||
block_size=64, num_kv_heads=8, head_size=128, dtype=torch.bfloat16, sink_len=16
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def make_spec(spec_cls: type[KVCacheSpec]) -> KVCacheSpec:
|
||||
return spec_cls(**spec_args_map[spec_cls])
|
||||
|
||||
|
||||
def are_uniform_specs(*specs: KVCacheSpec) -> bool:
|
||||
return UniformTypeKVCacheSpecs.is_uniform_type(
|
||||
{f"layer_{i}": spec for i, spec in enumerate(specs)}
|
||||
)
|
||||
|
||||
|
||||
class TestKVCacheSpecRegistry:
|
||||
"""Test the core registry functionality."""
|
||||
|
||||
def test_builtin_kvcache_specs_registered(self):
|
||||
assert set(spec_manager_map) <= set(_REGISTRY_KVCACHESPEC_LIST)
|
||||
for spec_cls, manager in spec_manager_map.items():
|
||||
spec = make_spec(spec_cls)
|
||||
assert KVCacheSpecRegistry.get_manager_class(spec) is manager
|
||||
assert (
|
||||
KVCacheSpecRegistry.get_uniform_type_base_spec(spec)
|
||||
is spec_uniform_base_map[spec_cls]
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("spec_cls", list(spec_manager_map))
|
||||
def test_custom_spec_register(self, spec_cls):
|
||||
"""A decorated custom spec resolves to the declared manager."""
|
||||
manager = spec_manager_map[spec_cls]
|
||||
uniform_base_spec = spec_uniform_base_map[spec_cls]
|
||||
|
||||
@register_kv_cache_spec(
|
||||
manager_class=manager,
|
||||
uniform_type_base_spec=uniform_base_spec,
|
||||
)
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class _CustomSpec(spec_cls): # type: ignore[valid-type,misc]
|
||||
custom_param: int = 16
|
||||
|
||||
spec = _CustomSpec(**spec_args_map[spec_cls], custom_param=100)
|
||||
|
||||
assert KVCacheSpecRegistry.get_manager_class(spec) is manager
|
||||
assert KVCacheSpecRegistry.get_uniform_type_base_spec(spec) is uniform_base_spec
|
||||
|
||||
def test_custom_spec_register_requires_manager(self):
|
||||
"""Invalid register decorator arguments fail early."""
|
||||
|
||||
with pytest.raises(AssertionError, match="manager_class is required"):
|
||||
|
||||
@register_kv_cache_spec(
|
||||
uniform_type_base_spec=FullAttentionSpec,
|
||||
)
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class _CustomFullSpecWithoutManager(FullAttentionSpec):
|
||||
custom_param: int = 16
|
||||
|
||||
def test_unregistered_spec_no_registered_parent_raises(self):
|
||||
"""
|
||||
A spec whose entire MRO contains no registered class resolves to None.
|
||||
Runtime callers should use check_kv_cache_spec_registry to fail early.
|
||||
Subclasses of registered specs intentionally do not fail — they inherit
|
||||
their parent's manager via MRO walking.
|
||||
"""
|
||||
spec = _TrulyUnregisteredSpec(block_size=16)
|
||||
|
||||
assert KVCacheSpecRegistry.get_manager_class(spec) is None
|
||||
assert KVCacheSpecRegistry.get_uniform_type_base_spec(spec) is None
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Unsupported KV cache spec type for layer layer_0"
|
||||
):
|
||||
KVCacheSpecRegistry.check_kv_cache_spec_registry({"layer_0": spec})
|
||||
|
||||
with pytest.raises(AssertionError, match="Unsupported KV cache spec type"):
|
||||
UniformTypeKVCacheSpecs.is_uniform_type({"layer_0": spec})
|
||||
|
||||
def test_unregistered_subclass_inherits_parent_manager(self):
|
||||
"""
|
||||
An unregistered subclass of a registered spec resolves via MRO
|
||||
to its parent's manager — this is intentional registry behaviour.
|
||||
"""
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class _ImplicitlyInheritedSpec(FullAttentionSpec):
|
||||
pass
|
||||
|
||||
spec = _ImplicitlyInheritedSpec(
|
||||
block_size=16, num_kv_heads=8, head_size=128, dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
# MRO walk finds FullAttentionSpec → FullAttentionManager
|
||||
assert KVCacheSpecRegistry.get_manager_class(spec) is FullAttentionManager
|
||||
|
||||
@pytest.mark.parametrize("spec_cls", list(spec_manager_map))
|
||||
def test_builtin_specs_are_uniform_with_same_spec_type(self, spec_cls):
|
||||
spec = make_spec(spec_cls)
|
||||
assert are_uniform_specs(spec, replace(spec))
|
||||
|
||||
def test_full_attention_family_specs_are_uniform(self):
|
||||
specs = [
|
||||
make_spec(FullAttentionSpec),
|
||||
make_spec(TQFullAttentionSpec),
|
||||
make_spec(MLAAttentionSpec),
|
||||
make_spec(HiddenStateCacheSpec),
|
||||
make_spec(SinkFullAttentionSpec),
|
||||
]
|
||||
|
||||
assert are_uniform_specs(*specs)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("spec_cls", "field", "value"),
|
||||
[
|
||||
(SlidingWindowSpec, "sliding_window", 256),
|
||||
(SlidingWindowMLASpec, "sliding_window", 256),
|
||||
(ChunkedLocalAttentionSpec, "attention_chunk_size", 8),
|
||||
(MambaSpec, "num_speculative_blocks", 4),
|
||||
],
|
||||
)
|
||||
def test_specs_with_type_specific_uniform_fields(self, spec_cls, field, value):
|
||||
spec = make_spec(spec_cls)
|
||||
changed_spec = replace(spec, **{field: value})
|
||||
|
||||
assert not are_uniform_specs(spec, changed_spec)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("left_cls", "right_cls"),
|
||||
[
|
||||
(FullAttentionSpec, CrossAttentionSpec),
|
||||
(FullAttentionSpec, SlidingWindowSpec),
|
||||
(FullAttentionSpec, ChunkedLocalAttentionSpec),
|
||||
(FullAttentionSpec, MambaSpec),
|
||||
(SlidingWindowMLASpec, SlidingWindowSpec),
|
||||
(ChunkedLocalAttentionSpec, SlidingWindowSpec),
|
||||
(MambaSpec, CrossAttentionSpec),
|
||||
],
|
||||
)
|
||||
def test_different_uniform_groups_are_not_uniform(self, left_cls, right_cls):
|
||||
assert not are_uniform_specs(make_spec(left_cls), make_spec(right_cls))
|
||||
|
||||
def test_different_block_sizes_are_not_uniform(self):
|
||||
spec = make_spec(FullAttentionSpec)
|
||||
|
||||
assert not are_uniform_specs(spec, replace(spec, block_size=32))
|
||||
|
||||
def test_registered_custom_spec_uses_base_uniform_rule(self):
|
||||
@register_kv_cache_spec(
|
||||
manager_class=FullAttentionManager,
|
||||
uniform_type_base_spec=FullAttentionSpec,
|
||||
)
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class _CustomFullSpec(FullAttentionSpec):
|
||||
custom_param: int = 16
|
||||
|
||||
custom_spec = _CustomFullSpec(
|
||||
block_size=64,
|
||||
num_kv_heads=8,
|
||||
head_size=128,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
assert are_uniform_specs(custom_spec, make_spec(FullAttentionSpec))
|
||||
@@ -53,7 +53,7 @@ class SchedulerConfig:
|
||||
In real usage, this should be set in `EngineArgs.create_engine_config`.
|
||||
"""
|
||||
|
||||
max_num_scheduled_tokens: int | None = None
|
||||
max_num_scheduled_tokens: int | None = Field(default=None, ge=0)
|
||||
"""Maximum number of tokens that the scheduler may issue in a single iteration.
|
||||
|
||||
This is usually equal to max_num_batched_tokens, but can be smaller in cases
|
||||
|
||||
@@ -38,9 +38,19 @@ from vllm.utils.network_utils import (
|
||||
is_valid_ipv6_address,
|
||||
)
|
||||
|
||||
if envs.VLLM_USE_SPINLOOP_EXT:
|
||||
from vllm.spinloop import spinloop
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
SPINLOOP_EXT_ENABLED = False
|
||||
if envs.VLLM_USE_SPINLOOP_EXT:
|
||||
try:
|
||||
from vllm.spinloop import spinloop
|
||||
|
||||
SPINLOOP_EXT_ENABLED = True
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"spinloop extension could not be loaded, disabling VLLM_USE_SPINLOOP_EXT!"
|
||||
)
|
||||
SPINLOOP_TIMEOUT_SECONDS = 0.1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -82,9 +92,6 @@ def to_bytes_big(value: int, size: int) -> bytes:
|
||||
return value.to_bytes(size, byteorder="big")
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
LONG_WAIT_TIME_LOG_MSG = (
|
||||
"No available shared memory broadcast block found "
|
||||
"in %d seconds. This typically happens "
|
||||
@@ -552,7 +559,7 @@ class MessageQueue:
|
||||
written_flag = metadata_buffer[0]
|
||||
return not (written_flag and read_count != self.buffer.n_reader)
|
||||
|
||||
if envs.VLLM_USE_SPINLOOP_EXT and not check():
|
||||
if SPINLOOP_EXT_ENABLED and not check():
|
||||
spinloop(metadata_buffer, check, timeout=SPINLOOP_TIMEOUT_SECONDS)
|
||||
|
||||
if not check():
|
||||
@@ -673,7 +680,7 @@ class MessageQueue:
|
||||
written_flag = metadata_buffer[0]
|
||||
return not (not written_flag or read_flag)
|
||||
|
||||
if envs.VLLM_USE_SPINLOOP_EXT and not check():
|
||||
if SPINLOOP_EXT_ENABLED and not check():
|
||||
spinloop(
|
||||
metadata_buffer[0 : self.local_reader_rank + 1],
|
||||
check,
|
||||
|
||||
@@ -13,7 +13,6 @@ from vllm.v1.core.kv_cache_utils import (
|
||||
)
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
SingleTypeKVCacheManager,
|
||||
spec_manager_map,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
@@ -21,6 +20,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
KVCacheSpec,
|
||||
UniformTypeKVCacheSpecs,
|
||||
)
|
||||
from vllm.v1.kv_cache_spec_registry import KVCacheSpecRegistry
|
||||
|
||||
# Dummy placeholder hash for store_mask's template computation.
|
||||
_DUMMY_BLOCK_HASH = BlockHash(b"\x00" * 32)
|
||||
@@ -89,7 +89,10 @@ class MooncakeStoreCoordinator:
|
||||
] = []
|
||||
for i, g in enumerate(self.kv_cache_groups):
|
||||
spec = _unwrap_spec(g.kv_cache_spec)
|
||||
manager_cls = spec_manager_map[type(spec)]
|
||||
manager_cls = KVCacheSpecRegistry.get_manager_class(spec)
|
||||
assert manager_cls is not None, (
|
||||
f"No manager registered for KVCacheSpec {spec}"
|
||||
)
|
||||
for existing_spec, group_ids, existing_cls in attention_groups:
|
||||
if existing_spec == spec:
|
||||
assert manager_cls is existing_cls
|
||||
|
||||
@@ -194,10 +194,16 @@ class OpenAIServingModels:
|
||||
lora_request.lora_name, lora_request.lora_path
|
||||
)
|
||||
) in str(e):
|
||||
raise LoRAAdapterNotFoundError(
|
||||
lora_request.lora_name, lora_request.lora_path
|
||||
) from e
|
||||
raise
|
||||
return create_error_response(
|
||||
LoRAAdapterNotFoundError(
|
||||
lora_request.lora_name, lora_request.lora_path
|
||||
)
|
||||
)
|
||||
return create_error_response(
|
||||
message=str(e),
|
||||
err_type="InternalServerError",
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
self.lora_requests[lora_name] = lora_request
|
||||
logger.info(
|
||||
|
||||
@@ -698,6 +698,13 @@ class Platform:
|
||||
mamba_padding_pct,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def register_custom_kv_cache_specs(cls, vllm_config: "VllmConfig") -> None:
|
||||
"""
|
||||
Register custom KVCacheSpec class on current platform.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def verify_model_arch(cls, model_arch: str) -> None:
|
||||
"""
|
||||
|
||||
+69
-1
@@ -450,6 +450,66 @@ def _detect_content_format(
|
||||
return "openai"
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _detect_developer_role_support(chat_template: str) -> bool:
|
||||
return '"developer"' in chat_template or "'developer'" in chat_template
|
||||
|
||||
|
||||
def _convert_developer_to_system(
|
||||
conversation: list[ConversationMessage],
|
||||
) -> list[ConversationMessage]:
|
||||
converted: list[ConversationMessage] = []
|
||||
for msg in conversation:
|
||||
if msg["role"] == "developer":
|
||||
new_msg = dict(msg)
|
||||
new_msg["role"] = "system"
|
||||
new_msg.pop("tools", None)
|
||||
converted.append(new_msg) # type: ignore[arg-type]
|
||||
else:
|
||||
converted.append(msg)
|
||||
return converted
|
||||
|
||||
|
||||
def _consolidate_system_messages(
|
||||
conversation: list[ConversationMessage],
|
||||
) -> list[ConversationMessage]:
|
||||
"""Merge all system messages into one at position 0.
|
||||
|
||||
Some chat templates (e.g. Qwen 3.6) require the system message to be the
|
||||
very first message. After developer-to-system conversion, system messages
|
||||
may appear at non-first positions; this merges them into a single message.
|
||||
"""
|
||||
system_contents: list[str] = []
|
||||
non_system: list[ConversationMessage] = []
|
||||
needs_consolidation = False
|
||||
for i, msg in enumerate(conversation):
|
||||
if msg["role"] == "system":
|
||||
if i > 0 or system_contents:
|
||||
needs_consolidation = True
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for part in content:
|
||||
if isinstance(part, dict) and "text" in part:
|
||||
parts.append(part["text"])
|
||||
elif isinstance(part, str):
|
||||
parts.append(part)
|
||||
content = "\n".join(parts)
|
||||
if content:
|
||||
system_contents.append(content)
|
||||
else:
|
||||
non_system.append(msg)
|
||||
|
||||
if not needs_consolidation:
|
||||
return conversation
|
||||
|
||||
merged: ConversationMessage = {
|
||||
"role": "system",
|
||||
"content": "\n\n".join(system_contents),
|
||||
}
|
||||
return [merged, *non_system]
|
||||
|
||||
|
||||
def _resolve_chat_template_content_format(
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
@@ -653,7 +713,15 @@ def safe_apply_chat_template(
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one."
|
||||
)
|
||||
|
||||
if any(
|
||||
msg["role"] == "developer" for msg in conversation
|
||||
) and not _detect_developer_role_support(chat_template):
|
||||
conversation = _convert_developer_to_system(conversation)
|
||||
conversation = _consolidate_system_messages(conversation)
|
||||
logger.info_once(
|
||||
"Chat template does not support the 'developer' message role. "
|
||||
"Converting developer messages to 'system' role.",
|
||||
)
|
||||
resolved_kwargs = resolve_chat_template_kwargs(
|
||||
tokenizer=tokenizer,
|
||||
chat_template=chat_template,
|
||||
|
||||
@@ -11,7 +11,6 @@ The fix streams string values incrementally as they arrive, providing a true
|
||||
streaming experience for long content.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
@@ -42,6 +41,7 @@ from vllm.tool_parsers.utils import (
|
||||
extract_types_from_schema,
|
||||
find_tool_properties,
|
||||
partial_tag_overlap,
|
||||
safe_literal_eval,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -110,7 +110,7 @@ class Glm4MoeModelToolParser(ToolParser):
|
||||
pass
|
||||
|
||||
try:
|
||||
return ast.literal_eval(value)
|
||||
return safe_literal_eval(value)
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
@@ -27,6 +26,7 @@ from vllm.tool_parsers.abstract_tool_parser import (
|
||||
Tool,
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import safe_literal_eval
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -183,13 +183,13 @@ class HYV3ToolParser(ToolParser):
|
||||
|
||||
@staticmethod
|
||||
def _deserialize(value: str) -> Any:
|
||||
"""Deserialize a string value using json.loads then ast.literal_eval."""
|
||||
"""Deserialize a string value using json.loads then safe_literal_eval."""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return ast.literal_eval(value)
|
||||
return safe_literal_eval(value)
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
@@ -28,7 +27,7 @@ from vllm.tool_parsers.abstract_tool_parser import (
|
||||
Tool,
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import partial_tag_overlap
|
||||
from vllm.tool_parsers.utils import partial_tag_overlap, safe_literal_eval
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -116,7 +115,7 @@ def _parse_arguments(json_value: str) -> tuple[Any, bool]:
|
||||
try:
|
||||
parsed_value = json.loads(json_value)
|
||||
except json.JSONDecodeError:
|
||||
parsed_value = ast.literal_eval(json_value)
|
||||
parsed_value = safe_literal_eval(json_value)
|
||||
return parsed_value, True
|
||||
except Exception:
|
||||
return json_value, False
|
||||
|
||||
@@ -11,7 +11,6 @@ The fix streams string values incrementally as they arrive, providing a true
|
||||
streaming experience for long content.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
@@ -41,6 +40,7 @@ from vllm.tool_parsers.abstract_tool_parser import (
|
||||
Tool,
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import safe_literal_eval
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -106,7 +106,7 @@ class PoolsideV1ToolParser(ToolParser):
|
||||
pass
|
||||
|
||||
try:
|
||||
return ast.literal_eval(value)
|
||||
return safe_literal_eval(value)
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
@@ -26,7 +25,7 @@ from vllm.tool_parsers.abstract_tool_parser import (
|
||||
Tool,
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import find_tool_properties
|
||||
from vllm.tool_parsers.utils import find_tool_properties, safe_literal_eval
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -824,7 +823,7 @@ class StreamingXMLToolCallParser:
|
||||
try:
|
||||
parsed_value = json.loads(raw_for_parse)
|
||||
except json.JSONDecodeError:
|
||||
parsed_value = ast.literal_eval(raw_for_parse)
|
||||
parsed_value = safe_literal_eval(raw_for_parse)
|
||||
output_arguments = json.dumps(parsed_value, ensure_ascii=False)
|
||||
except Exception:
|
||||
# Fallback: output as string as-is
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
@@ -23,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
|
||||
from vllm.tool_parsers.utils import safe_literal_eval
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -1016,7 +1016,7 @@ class StreamingXMLToolCallParser:
|
||||
raw_for_parse = raw_text + "\n"
|
||||
else:
|
||||
raw_for_parse = raw_text
|
||||
parsed_value = ast.literal_eval(raw_for_parse)
|
||||
parsed_value = safe_literal_eval(raw_for_parse)
|
||||
output_arguments = json.dumps(parsed_value, ensure_ascii=False)
|
||||
except Exception:
|
||||
# Fallback: output as string as-is
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import ast
|
||||
import json
|
||||
import warnings
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
@@ -31,6 +32,12 @@ Tool: TypeAlias = ChatCompletionToolsParam | ResponsesTool
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def safe_literal_eval(text: str):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", SyntaxWarning)
|
||||
return ast.literal_eval(text)
|
||||
|
||||
|
||||
def partial_tag_overlap(text: str, tag: str) -> int:
|
||||
"""Length of the longest prefix of *tag* that matches a suffix of *text*.
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
SlidingWindowSpec,
|
||||
UniformTypeKVCacheSpecs,
|
||||
)
|
||||
from vllm.v1.kv_cache_spec_registry import KVCacheSpecRegistry
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.utils import tensor_data
|
||||
|
||||
@@ -1991,6 +1992,9 @@ def get_kv_cache_configs(
|
||||
"across workers. This is not supported yet."
|
||||
)
|
||||
|
||||
# Check if the KV cache specs are registered correctly.
|
||||
# This is to prevent that some layers are initialized with unregistered specs.
|
||||
KVCacheSpecRegistry.check_kv_cache_spec_registry(merged_kv_cache_specs)
|
||||
# Get global KV cache groups. This also handles spec unification for
|
||||
# hybrid models when disable_hybrid_kv_cache_manager is enabled.
|
||||
# After this call, merged_kv_cache_specs may be modified in-place.
|
||||
|
||||
@@ -104,7 +104,7 @@ class Scheduler(SchedulerInterface):
|
||||
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
|
||||
self.max_num_scheduled_tokens = (
|
||||
self.scheduler_config.max_num_scheduled_tokens
|
||||
if self.scheduler_config.max_num_scheduled_tokens
|
||||
if self.scheduler_config.max_num_scheduled_tokens is not None
|
||||
else self.scheduler_config.max_num_batched_tokens
|
||||
)
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
|
||||
@@ -25,6 +25,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
SlidingWindowSpec,
|
||||
TQFullAttentionSpec,
|
||||
)
|
||||
from vllm.v1.kv_cache_spec_registry import KVCacheSpecRegistry
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@@ -1247,27 +1248,30 @@ class SinkFullAttentionManager(FullAttentionManager):
|
||||
self.sink_blocks = self.block_pool.free_block_queue.popleft_n(num_sink_block)
|
||||
|
||||
|
||||
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
||||
FullAttentionSpec: FullAttentionManager,
|
||||
TQFullAttentionSpec: FullAttentionManager,
|
||||
MLAAttentionSpec: FullAttentionManager,
|
||||
HiddenStateCacheSpec: FullAttentionManager,
|
||||
SlidingWindowSpec: SlidingWindowManager,
|
||||
SlidingWindowMLASpec: SlidingWindowManager,
|
||||
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
|
||||
MambaSpec: MambaManager,
|
||||
CrossAttentionSpec: CrossAttentionManager,
|
||||
SinkFullAttentionSpec: SinkFullAttentionManager,
|
||||
}
|
||||
|
||||
|
||||
def get_manager_for_kv_cache_spec(
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
max_num_batched_tokens: int,
|
||||
max_model_len: int,
|
||||
**kwargs,
|
||||
) -> SingleTypeKVCacheManager:
|
||||
manager_class = spec_manager_map[type(kv_cache_spec)]
|
||||
"""
|
||||
Get the appropriate manager for a given KVCacheSpec.
|
||||
|
||||
Uses the KVCacheSpecRegistry to look up the manager class, supporting
|
||||
both built-in and custom specs registered via @register_kv_cache_spec
|
||||
and KVCacheSpecRegistry.register.
|
||||
|
||||
Args:
|
||||
kv_cache_spec: The KVCacheSpec instance
|
||||
max_num_batched_tokens: The maximum number of tokens in a batch
|
||||
max_model_len: The maximum context length the model could serve
|
||||
Returns:
|
||||
An instance of the appropriate SingleTypeKVCacheManager subclass
|
||||
"""
|
||||
manager_class = KVCacheSpecRegistry.get_manager_class(kv_cache_spec)
|
||||
assert manager_class is not None, (
|
||||
f"No manager registered for KVCacheSpec {type(kv_cache_spec)}"
|
||||
)
|
||||
# SlidingWindow / ChunkedLocalAttention managers recycle blocks across
|
||||
# chunks; the runtime admission cap must match the recycling-aware bound
|
||||
# the startup pool sizer uses (single source of truth: the spec method).
|
||||
@@ -1280,3 +1284,64 @@ def get_manager_for_kv_cache_spec(
|
||||
)
|
||||
manager = manager_class(kv_cache_spec, **kwargs)
|
||||
return manager
|
||||
|
||||
|
||||
def register_all_kvcache_specs(vllm_config):
|
||||
"""Built-in spec registration"""
|
||||
KVCacheSpecRegistry.register(
|
||||
FullAttentionSpec,
|
||||
FullAttentionManager,
|
||||
uniform_type_base_spec=FullAttentionSpec,
|
||||
)
|
||||
|
||||
KVCacheSpecRegistry.register(
|
||||
SlidingWindowSpec,
|
||||
SlidingWindowManager,
|
||||
uniform_type_base_spec=SlidingWindowSpec,
|
||||
)
|
||||
KVCacheSpecRegistry.register(
|
||||
SlidingWindowMLASpec,
|
||||
SlidingWindowManager,
|
||||
uniform_type_base_spec=SlidingWindowMLASpec,
|
||||
)
|
||||
|
||||
KVCacheSpecRegistry.register(
|
||||
MambaSpec, MambaManager, uniform_type_base_spec=MambaSpec
|
||||
)
|
||||
KVCacheSpecRegistry.register(
|
||||
ChunkedLocalAttentionSpec,
|
||||
ChunkedLocalAttentionManager,
|
||||
uniform_type_base_spec=ChunkedLocalAttentionSpec,
|
||||
)
|
||||
KVCacheSpecRegistry.register(
|
||||
CrossAttentionSpec,
|
||||
CrossAttentionManager,
|
||||
uniform_type_base_spec=CrossAttentionSpec,
|
||||
)
|
||||
|
||||
# FullAttentionSpec subclasses — grouped with FullAttentionSpec
|
||||
KVCacheSpecRegistry.register(
|
||||
TQFullAttentionSpec,
|
||||
FullAttentionManager,
|
||||
uniform_type_base_spec=FullAttentionSpec,
|
||||
)
|
||||
KVCacheSpecRegistry.register(
|
||||
MLAAttentionSpec, FullAttentionManager, uniform_type_base_spec=FullAttentionSpec
|
||||
)
|
||||
# NOTE(Mengqing): HiddenStateCacheSpec won't take part in
|
||||
# grouping, thus the uniform_type_base_spec is just a
|
||||
# placeholder.
|
||||
KVCacheSpecRegistry.register(
|
||||
HiddenStateCacheSpec,
|
||||
FullAttentionManager,
|
||||
uniform_type_base_spec=FullAttentionSpec,
|
||||
)
|
||||
KVCacheSpecRegistry.register(
|
||||
SinkFullAttentionSpec,
|
||||
SinkFullAttentionManager,
|
||||
uniform_type_base_spec=FullAttentionSpec,
|
||||
)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
current_platform.register_custom_kv_cache_specs(vllm_config)
|
||||
|
||||
@@ -52,6 +52,7 @@ from vllm.v1.core.kv_cache_utils import (
|
||||
)
|
||||
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.core.single_type_kv_cache_manager import register_all_kvcache_specs
|
||||
from vllm.v1.engine import (
|
||||
EEP_NOTIFICATION_CALL_ID,
|
||||
EEPNotificationType,
|
||||
@@ -235,6 +236,9 @@ class EngineCore:
|
||||
def _initialize_kv_caches(self, vllm_config: VllmConfig) -> KVCacheConfig:
|
||||
start = time.time()
|
||||
|
||||
# register all kvcache specs in enginecore process.
|
||||
register_all_kvcache_specs(vllm_config)
|
||||
|
||||
# Get all kv cache needed by the model
|
||||
kv_cache_specs = self.model_executor.get_kv_cache_specs()
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
from vllm.utils.torch_utils import get_dtype_size, nvfp4_kv_cache_full_dim
|
||||
from vllm.v1.attention.backends.registry import MambaAttentionBackendEnum
|
||||
from vllm.v1.kv_cache_spec_registry import KVCacheSpecRegistry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
@@ -139,6 +140,21 @@ class KVCacheSpec:
|
||||
)
|
||||
return copy.deepcopy(specs[0])
|
||||
|
||||
def is_uniform_with_collection(
|
||||
self, kv_cache_specs: dict[str, KVCacheSpec]
|
||||
) -> bool:
|
||||
"""
|
||||
Whether this KVCacheSpec is uniform with all specs of all layers.
|
||||
"""
|
||||
uniform_type_base_spec = KVCacheSpecRegistry.get_uniform_type_base_spec(self)
|
||||
assert uniform_type_base_spec is not None, (
|
||||
f"Unsupported KV cache spec type: {type(self)}. "
|
||||
"Please register it using @register_kv_cache_spec decorator."
|
||||
)
|
||||
return all(
|
||||
isinstance(spec, uniform_type_base_spec) for spec in kv_cache_specs.values()
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class AttentionSpec(KVCacheSpec):
|
||||
@@ -430,6 +446,15 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
|
||||
)
|
||||
return max_blocks * self.page_size_bytes
|
||||
|
||||
def is_uniform_with_collection(
|
||||
self, kv_cache_specs: dict[str, KVCacheSpec]
|
||||
) -> bool:
|
||||
return all(
|
||||
isinstance(spec, ChunkedLocalAttentionSpec)
|
||||
and spec.attention_chunk_size == self.attention_chunk_size
|
||||
for spec in kv_cache_specs.values()
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class SlidingWindowSpec(AttentionSpec):
|
||||
@@ -493,6 +518,15 @@ class SlidingWindowSpec(AttentionSpec):
|
||||
)
|
||||
return max_blocks * self.page_size_bytes
|
||||
|
||||
def is_uniform_with_collection(
|
||||
self, kv_cache_specs: dict[str, KVCacheSpec]
|
||||
) -> bool:
|
||||
return all(
|
||||
isinstance(spec, SlidingWindowSpec)
|
||||
and spec.sliding_window == self.sliding_window
|
||||
for spec in kv_cache_specs.values()
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class SlidingWindowMLASpec(SlidingWindowSpec):
|
||||
@@ -558,6 +592,15 @@ class SlidingWindowMLASpec(SlidingWindowSpec):
|
||||
model_version=model_version_set.pop(),
|
||||
)
|
||||
|
||||
def is_uniform_with_collection(
|
||||
self, kv_cache_specs: dict[str, KVCacheSpec]
|
||||
) -> bool:
|
||||
return all(
|
||||
isinstance(spec, SlidingWindowMLASpec)
|
||||
and spec.sliding_window == self.sliding_window
|
||||
for spec in kv_cache_specs.values()
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MambaSpec(KVCacheSpec):
|
||||
@@ -590,6 +633,15 @@ class MambaSpec(KVCacheSpec):
|
||||
else:
|
||||
return self.page_size_bytes * (1 + self.num_speculative_blocks)
|
||||
|
||||
def is_uniform_with_collection(
|
||||
self, kv_cache_specs: dict[str, KVCacheSpec]
|
||||
) -> bool:
|
||||
return all(
|
||||
isinstance(spec, MambaSpec)
|
||||
and spec.num_speculative_blocks == self.num_speculative_blocks
|
||||
for spec in kv_cache_specs.values()
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EncoderOnlyAttentionSpec(AttentionSpec):
|
||||
@@ -689,53 +741,16 @@ class UniformTypeKVCacheSpecs(KVCacheSpec):
|
||||
def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
|
||||
"""
|
||||
Whether all layers have the same type of KV cache spec.
|
||||
|
||||
Uses the registry to determine grouping base classes, so custom specs
|
||||
that inherit from FullAttentionSpec are treated as full attention.
|
||||
"""
|
||||
block_sizes = set(spec.block_size for spec in kv_cache_specs.values())
|
||||
if len(block_sizes) > 1:
|
||||
# Different block sizes, not uniform.
|
||||
return False
|
||||
one_spec = next(iter(kv_cache_specs.values()))
|
||||
# NOTE: Check subclasses before parent classes since isinstance()
|
||||
# returns True for subclasses.
|
||||
if isinstance(one_spec, SlidingWindowMLASpec):
|
||||
# SlidingWindowMLASpec is uniform if all specs are SlidingWindowMLASpec
|
||||
# with the same sliding_window size.
|
||||
return all(
|
||||
isinstance(spec, SlidingWindowMLASpec)
|
||||
and spec.sliding_window == one_spec.sliding_window
|
||||
for spec in kv_cache_specs.values()
|
||||
)
|
||||
elif isinstance(one_spec, FullAttentionSpec):
|
||||
return all(
|
||||
isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values()
|
||||
)
|
||||
elif isinstance(one_spec, CrossAttentionSpec):
|
||||
return all(
|
||||
isinstance(spec, CrossAttentionSpec) for spec in kv_cache_specs.values()
|
||||
)
|
||||
elif isinstance(one_spec, SlidingWindowSpec):
|
||||
return all(
|
||||
isinstance(spec, SlidingWindowSpec)
|
||||
and spec.sliding_window == one_spec.sliding_window
|
||||
for spec in kv_cache_specs.values()
|
||||
)
|
||||
elif isinstance(one_spec, ChunkedLocalAttentionSpec):
|
||||
return all(
|
||||
isinstance(spec, ChunkedLocalAttentionSpec)
|
||||
and spec.attention_chunk_size == one_spec.attention_chunk_size
|
||||
for spec in kv_cache_specs.values()
|
||||
)
|
||||
elif isinstance(one_spec, MambaSpec):
|
||||
return all(
|
||||
isinstance(spec, MambaSpec)
|
||||
and spec.num_speculative_blocks == one_spec.num_speculative_blocks
|
||||
for spec in kv_cache_specs.values()
|
||||
)
|
||||
else:
|
||||
# NOTE(Chen): Please add new branches for new KV cache spec types.
|
||||
raise NotImplementedError(
|
||||
f"Unsupported KV cache spec type: {type(one_spec)}"
|
||||
)
|
||||
first_spec = next(iter(kv_cache_specs.values()))
|
||||
return first_spec.is_uniform_with_collection(kv_cache_specs)
|
||||
|
||||
@classmethod
|
||||
def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Self | None:
|
||||
|
||||
@@ -0,0 +1,209 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Registry for KVCacheSpec types and their associated managers.
|
||||
|
||||
This module provides a pluggable architecture for registering custom KVCacheSpec
|
||||
subclasses without modifying vLLM core code. Out-of-tree platforms can define
|
||||
custom specs and managers by using the @register_kv_cache_spec decorator.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.single_type_kv_cache_manager import SingleTypeKVCacheManager
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class KVCacheSpecMetadata:
|
||||
"""Metadata for a registered KVCacheSpec."""
|
||||
|
||||
kvcache_spec_cls: type["KVCacheSpec"]
|
||||
manager_class: type["SingleTypeKVCacheManager"]
|
||||
# The base spec class for grouping compatibility checks.
|
||||
# KVCacheSpecs with the same uniform_type_base_spec will be
|
||||
# grouped into one kvcache group
|
||||
uniform_type_base_spec: type["KVCacheSpec"]
|
||||
|
||||
|
||||
_REGISTRY_KVCACHESPEC_LIST: dict[type["KVCacheSpec"], KVCacheSpecMetadata] = {}
|
||||
|
||||
|
||||
class KVCacheSpecRegistry:
|
||||
"""Global registry for KVCacheSpec types and their associated managers."""
|
||||
|
||||
@classmethod
|
||||
def _ensure_registered(cls, vllm_config=None) -> None:
|
||||
"""
|
||||
Run full KVCacheSpec registration if the registration is not done.
|
||||
"""
|
||||
if _REGISTRY_KVCACHESPEC_LIST:
|
||||
return
|
||||
|
||||
if vllm_config is None:
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
|
||||
vllm_config = get_current_vllm_config_or_none()
|
||||
|
||||
# lazy import to avoid circular dependency
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
register_all_kvcache_specs,
|
||||
)
|
||||
|
||||
register_all_kvcache_specs(vllm_config)
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls,
|
||||
kvcache_spec_cls: type["KVCacheSpec"],
|
||||
manager_class: type["SingleTypeKVCacheManager"] | None = None,
|
||||
uniform_type_base_spec: type["KVCacheSpec"] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Register a KVCacheSpec class with its manager and base spec.
|
||||
|
||||
Args:
|
||||
kvcache_spec_cls: The KVCacheSpec subclass to register
|
||||
manager_class: The SingleTypeKVCacheManager to use for this spec
|
||||
uniform_type_base_spec: The base spec class for grouping compatibility.
|
||||
instead of being grouped to different kvcache group, `kvcache_spec_cls`
|
||||
and `uniform_type_base_spec` will be trated as uniform type.
|
||||
If None, defaults to kvcache_spec_cls itself (for built-in base specs).
|
||||
"""
|
||||
assert manager_class is not None, "manager_class is required"
|
||||
if uniform_type_base_spec is None:
|
||||
uniform_type_base_spec = kvcache_spec_cls
|
||||
assert issubclass(kvcache_spec_cls, uniform_type_base_spec), (
|
||||
f"{kvcache_spec_cls.__name__} must inherit from its declared "
|
||||
f"uniform_type_base_spec {uniform_type_base_spec.__name__}."
|
||||
)
|
||||
|
||||
if kvcache_spec_cls in _REGISTRY_KVCACHESPEC_LIST:
|
||||
registered_spec = _REGISTRY_KVCACHESPEC_LIST[kvcache_spec_cls]
|
||||
is_same_registration = (
|
||||
manager_class == registered_spec.manager_class
|
||||
and uniform_type_base_spec == registered_spec.uniform_type_base_spec
|
||||
)
|
||||
assert is_same_registration, (
|
||||
f"Conflicting registration for KVCacheSpec "
|
||||
f": {kvcache_spec_cls.__name__}"
|
||||
)
|
||||
|
||||
_REGISTRY_KVCACHESPEC_LIST[kvcache_spec_cls] = KVCacheSpecMetadata(
|
||||
kvcache_spec_cls=kvcache_spec_cls,
|
||||
manager_class=manager_class,
|
||||
uniform_type_base_spec=uniform_type_base_spec,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_manager_class(
|
||||
cls, kvcache_spec: "KVCacheSpec"
|
||||
) -> type["SingleTypeKVCacheManager"] | None:
|
||||
"""
|
||||
Get the single type kvcache manager class for a given kvcache spec instance.
|
||||
|
||||
Args:
|
||||
kvcache_spec: A KVCacheSpec instance
|
||||
|
||||
Returns:
|
||||
The SingleTypeKVCacheManager class to use for this kvcache_spec
|
||||
"""
|
||||
cls._ensure_registered()
|
||||
kvcache_spec_cls = type(kvcache_spec)
|
||||
|
||||
# Walk up the MRO to find a registered base class
|
||||
for base in kvcache_spec_cls.__mro__:
|
||||
if base in _REGISTRY_KVCACHESPEC_LIST:
|
||||
return _REGISTRY_KVCACHESPEC_LIST[base].manager_class
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_uniform_type_base_spec(
|
||||
cls, kvcache_spec: "KVCacheSpec"
|
||||
) -> type["KVCacheSpec"] | None:
|
||||
"""
|
||||
Get the base kvcache spec class for grouping compatibility checks.
|
||||
KVCacheSpecs with uniform_type_base_spec will be trated as one group.
|
||||
|
||||
Args:
|
||||
kvcache_spec: A KVCacheSpec instance
|
||||
|
||||
Returns:
|
||||
The base KVCacheSpec class for checking uniform type kvcache specs
|
||||
"""
|
||||
cls._ensure_registered()
|
||||
kvcache_spec_cls = type(kvcache_spec)
|
||||
|
||||
# Walk up the MRO to find a registered base spec
|
||||
for base in kvcache_spec_cls.__mro__:
|
||||
if base in _REGISTRY_KVCACHESPEC_LIST:
|
||||
return _REGISTRY_KVCACHESPEC_LIST[base].uniform_type_base_spec
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def check_kv_cache_spec_registry(
|
||||
cls, kv_cache_spec: dict[str, "KVCacheSpec"]
|
||||
) -> None:
|
||||
"""
|
||||
Check if the KVCacheSpecs of each layer are registered as expected.
|
||||
"""
|
||||
cls._ensure_registered()
|
||||
for layer_name, spec in kv_cache_spec.items():
|
||||
# use raise instead of assert to make it effective in production environment
|
||||
if cls.get_uniform_type_base_spec(spec) is None:
|
||||
raise ValueError(
|
||||
f"Unsupported KV cache spec type for layer {layer_name}: "
|
||||
f"{type(spec)}. Please register it using "
|
||||
f"@register_kv_cache_spec decorator."
|
||||
)
|
||||
if cls.get_manager_class(spec) is None:
|
||||
raise ValueError(
|
||||
f"No manager found for KV cache spec type for layer "
|
||||
f"{layer_name}: {type(spec)}. Please register it using "
|
||||
f"@register_kv_cache_spec decorator."
|
||||
)
|
||||
|
||||
|
||||
def register_kv_cache_spec(
|
||||
manager_class: type["SingleTypeKVCacheManager"] | None = None,
|
||||
uniform_type_base_spec: type["KVCacheSpec"] | None = None,
|
||||
):
|
||||
"""
|
||||
Decorator to register a custom KVCacheSpec class.
|
||||
|
||||
Args:
|
||||
manager_class: The SingleTypeKVCacheManager to use for this spec.
|
||||
Required for all registered specs.
|
||||
uniform_type_base_spec: The base spec class for uniform type kv cache specs
|
||||
compatibility. If None, the spec is treated as a new base
|
||||
type.
|
||||
|
||||
Examples:
|
||||
- Register a new specs:
|
||||
@register_kv_cache_spec(
|
||||
manager_class=FullAttentionManager,
|
||||
uniform_type_base_spec=FullAttentionSpec
|
||||
)
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class CustomFullAttentionSpec(FullAttentionSpec):
|
||||
pass
|
||||
"""
|
||||
|
||||
def decorator(kvcache_spec_cls: type["KVCacheSpec"]) -> type["KVCacheSpec"]:
|
||||
KVCacheSpecRegistry.register(
|
||||
kvcache_spec_cls=kvcache_spec_cls,
|
||||
manager_class=manager_class,
|
||||
uniform_type_base_spec=uniform_type_base_spec,
|
||||
)
|
||||
return kvcache_spec_cls
|
||||
|
||||
return decorator
|
||||
@@ -19,6 +19,69 @@ if HAS_TRITON:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
_FLASHINFER_MIN_VERSION = "0.2.3"
|
||||
|
||||
|
||||
def flashinfer_sampler_supported() -> bool:
|
||||
"""Decide whether FlashInfer's top-p/top-k sampler can be used.
|
||||
|
||||
Returns False (with appropriate logging) when ``VLLM_USE_FLASHINFER_SAMPLER``
|
||||
is 0, when the platform isn't CUDA, when the GPU's compute capability is
|
||||
unsupported, or when the installed flashinfer is missing or too old. Raises
|
||||
``RuntimeError`` if the user explicitly opted in via the env var but
|
||||
FlashInfer is unavailable.
|
||||
|
||||
Note: callers must additionally ensure ``logprobs_mode`` doesn't require
|
||||
post-top-k/top-p logits/logprobs for any request whose logprobs will be
|
||||
returned in this step, since FlashInfer doesn't expose those.
|
||||
"""
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
if not envs.VLLM_USE_FLASHINFER_SAMPLER:
|
||||
logger.info_once(
|
||||
"FlashInfer top-p/top-k sampling disabled via "
|
||||
"VLLM_USE_FLASHINFER_SAMPLER=0."
|
||||
)
|
||||
return False
|
||||
from vllm.v1.attention.backends.flashinfer import FlashInferBackend
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
assert capability is not None
|
||||
unsupported_reason: str | None = None
|
||||
if not FlashInferBackend.supports_compute_capability(capability):
|
||||
unsupported_reason = (
|
||||
f"unsupported compute capability {capability.as_version_str()}"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
import flashinfer
|
||||
|
||||
if version.parse(flashinfer.__version__) < version.parse(
|
||||
_FLASHINFER_MIN_VERSION
|
||||
):
|
||||
unsupported_reason = (
|
||||
f"flashinfer {flashinfer.__version__} is too old "
|
||||
f"(>={_FLASHINFER_MIN_VERSION} required)"
|
||||
)
|
||||
except ImportError:
|
||||
unsupported_reason = "flashinfer is not installed"
|
||||
|
||||
if unsupported_reason is None:
|
||||
logger.info_once("Using FlashInfer for top-p & top-k sampling.", scope="global")
|
||||
return True
|
||||
if envs.is_set("VLLM_USE_FLASHINFER_SAMPLER"):
|
||||
raise RuntimeError(
|
||||
f"FlashInfer top-p/top-k sampling unavailable: {unsupported_reason}. "
|
||||
"Unset VLLM_USE_FLASHINFER_SAMPLER=1."
|
||||
)
|
||||
logger.warning_once(
|
||||
"FlashInfer top-p/top-k sampling unavailable: %s; falling back. "
|
||||
"Set VLLM_USE_FLASHINFER_SAMPLER=0 to silence.",
|
||||
unsupported_reason,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
class TopKTopPSampler(nn.Module):
|
||||
"""
|
||||
Module that performs optional top-k and top-p filtering followed by
|
||||
@@ -30,49 +93,16 @@ class TopKTopPSampler(nn.Module):
|
||||
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
|
||||
super().__init__()
|
||||
self.logprobs_mode = logprobs_mode
|
||||
# flashinfer optimization does not apply if intermediate
|
||||
# logprobs/logits after top_k/top_p need to be returned
|
||||
if (
|
||||
logprobs_mode not in ("processed_logits", "processed_logprobs")
|
||||
and current_platform.is_cuda()
|
||||
):
|
||||
if envs.VLLM_USE_FLASHINFER_SAMPLER:
|
||||
from vllm.v1.attention.backends.flashinfer import FlashInferBackend
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
assert capability is not None
|
||||
if FlashInferBackend.supports_compute_capability(capability):
|
||||
logger.info_once(
|
||||
"Using FlashInfer for top-p & top-k sampling.",
|
||||
scope="global",
|
||||
)
|
||||
self.forward = self.forward_cuda
|
||||
elif envs.is_set("VLLM_USE_FLASHINFER_SAMPLER"):
|
||||
# User explicitly opted in but the GPU can't run FlashInfer.
|
||||
capability_str = capability.as_version_str()
|
||||
raise RuntimeError(
|
||||
"FlashInfer does not support compute capability "
|
||||
f"{capability_str}, unset VLLM_USE_FLASHINFER_SAMPLER=1."
|
||||
)
|
||||
else:
|
||||
# Default-on path; hardware can't run FlashInfer →
|
||||
# quietly fall back to the PyTorch-native sampler
|
||||
# instead of failing server startup.
|
||||
logger.warning_once(
|
||||
"FlashInfer top-p/top-k sampling not supported on "
|
||||
"compute capability %s; falling back to PyTorch-native "
|
||||
"sampler. Set VLLM_USE_FLASHINFER_SAMPLER=0 to silence.",
|
||||
capability.as_version_str(),
|
||||
)
|
||||
self.forward = self.forward_native
|
||||
else:
|
||||
# User explicitly set VLLM_USE_FLASHINFER_SAMPLER=0.
|
||||
logger.info_once(
|
||||
"FlashInfer top-p/top-k sampling disabled via "
|
||||
"VLLM_USE_FLASHINFER_SAMPLER=0; using PyTorch-native sampler."
|
||||
)
|
||||
self.forward = self.forward_native
|
||||
|
||||
if current_platform.is_cuda():
|
||||
# FlashInfer doesn't expose post-top-k/top-p logits/logprobs,
|
||||
# so it can't be used when the configured mode requires them.
|
||||
can_use_flashinfer = (
|
||||
logprobs_mode not in ("processed_logits", "processed_logprobs")
|
||||
and flashinfer_sampler_supported()
|
||||
)
|
||||
self.forward = (
|
||||
self.forward_cuda if can_use_flashinfer else self.forward_native
|
||||
)
|
||||
elif current_platform.is_cpu():
|
||||
arch = current_platform.get_cpu_architecture()
|
||||
# Fall back to native implementation for POWERPC and RISCV.
|
||||
@@ -417,7 +447,7 @@ def flashinfer_sample(
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor | None,
|
||||
p: torch.Tensor | None,
|
||||
generators: dict[int, torch.Generator],
|
||||
generators: dict[int, torch.Generator] = {}, # noqa
|
||||
) -> torch.Tensor:
|
||||
"""Sample from the logits using FlashInfer.
|
||||
|
||||
@@ -431,11 +461,6 @@ def flashinfer_sample(
|
||||
"""
|
||||
import flashinfer
|
||||
|
||||
if version.parse(flashinfer.__version__) < version.parse("0.2.3"):
|
||||
raise ImportError(
|
||||
"FlashInfer version >= 0.2.3 required for top-k and top-p sampling. "
|
||||
)
|
||||
|
||||
assert not (k is None and p is None)
|
||||
if k is None:
|
||||
# Top-p only.
|
||||
|
||||
@@ -7,6 +7,11 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm.config.model import LogprobsMode
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import (
|
||||
apply_top_k_top_p,
|
||||
flashinfer_sample,
|
||||
flashinfer_sampler_supported,
|
||||
)
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
|
||||
from vllm.v1.worker.gpu.sample.bad_words import BadWordsState
|
||||
@@ -45,6 +50,7 @@ class Sampler:
|
||||
self.bad_words_state = BadWordsState(req_states)
|
||||
self.logprob_token_ids_state = LogprobTokenIdsState(max_num_reqs, device)
|
||||
self.num_speculative_tokens = num_speculative_tokens
|
||||
self.use_flashinfer = flashinfer_sampler_supported()
|
||||
|
||||
def add_request(
|
||||
self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
|
||||
@@ -77,6 +83,13 @@ class Sampler:
|
||||
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
|
||||
# that num_nans is computed before applying penalties and temperature.
|
||||
num_nans = get_num_nans(logits) if self.compute_nans else None
|
||||
|
||||
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
|
||||
max_per_req_token_ids = self.logprob_token_ids_state.max_num_token_ids(
|
||||
idx_mapping_np
|
||||
)
|
||||
return_logprobs = max_num_logprobs != NO_LOGPROBS or max_per_req_token_ids > 0
|
||||
|
||||
sampled, processed_logits = self.sample(
|
||||
logits,
|
||||
expanded_idx_mapping,
|
||||
@@ -84,13 +97,10 @@ class Sampler:
|
||||
pos,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
return_logprobs=return_logprobs,
|
||||
)
|
||||
|
||||
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
|
||||
max_per_req_token_ids = self.logprob_token_ids_state.max_num_token_ids(
|
||||
idx_mapping_np
|
||||
)
|
||||
if max_num_logprobs != NO_LOGPROBS or max_per_req_token_ids > 0:
|
||||
if return_logprobs:
|
||||
if self.logprobs_mode == "processed_logprobs":
|
||||
logits = processed_logits
|
||||
expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
|
||||
@@ -128,6 +138,7 @@ class Sampler:
|
||||
pos: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
skip_top_k_top_p: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# Copy logits to a new FP32 tensor.
|
||||
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
|
||||
@@ -163,6 +174,9 @@ class Sampler:
|
||||
# Apply min_p in place.
|
||||
self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np)
|
||||
|
||||
if skip_top_k_top_p:
|
||||
return logits
|
||||
|
||||
# Apply top_k and/or top_p. This might or might not return a new tensor.
|
||||
return self.sampling_states.apply_top_k_top_p(
|
||||
logits, expanded_idx_mapping, idx_mapping_np
|
||||
@@ -176,6 +190,7 @@ class Sampler:
|
||||
pos: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
return_logprobs: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
processed_logits = self.apply_sampling_params(
|
||||
logits,
|
||||
@@ -184,16 +199,33 @@ class Sampler:
|
||||
pos,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
skip_top_k_top_p=True,
|
||||
)
|
||||
top_k, top_p = self.sampling_states.get_top_k_top_p(
|
||||
expanded_idx_mapping, idx_mapping_np
|
||||
)
|
||||
use_flashinfer = self.use_flashinfer and not (
|
||||
# Don't use FI sampler if no requests use top_k/top_p, if there are
|
||||
# any greedy requests or per-request seeds, or if post-processed
|
||||
# logprobs need to be returned for any requests.
|
||||
(top_k is None and top_p is None)
|
||||
or (return_logprobs and self.logprobs_mode == "processed_logprobs")
|
||||
or self.sampling_states.any_greedy(idx_mapping_np)
|
||||
or self.sampling_states.any_explicit_seed(idx_mapping_np)
|
||||
)
|
||||
|
||||
# Sample the next token.
|
||||
sampled = gumbel_sample(
|
||||
processed_logits,
|
||||
expanded_idx_mapping,
|
||||
self.sampling_states.temperature.gpu,
|
||||
self.sampling_states.seeds.gpu,
|
||||
pos,
|
||||
apply_temperature=False,
|
||||
use_fp64=self.use_fp64_gumbel,
|
||||
)
|
||||
if use_flashinfer:
|
||||
sampled = flashinfer_sample(processed_logits, top_k, top_p).to(torch.int64)
|
||||
else:
|
||||
processed_logits = apply_top_k_top_p(processed_logits, top_k, top_p)
|
||||
sampled = gumbel_sample(
|
||||
processed_logits,
|
||||
expanded_idx_mapping,
|
||||
self.sampling_states.temperature.gpu,
|
||||
self.sampling_states.seeds.gpu,
|
||||
pos,
|
||||
apply_temperature=False,
|
||||
use_fp64=self.use_fp64_gumbel,
|
||||
)
|
||||
return sampled, processed_logits
|
||||
|
||||
@@ -24,6 +24,9 @@ class SamplingStates:
|
||||
self.top_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
|
||||
self.min_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
|
||||
self.seeds = UvaBackedTensor(max_num_reqs, dtype=torch.int64)
|
||||
# Tracks whether `seed` was set explicitly by the user, so callers
|
||||
# can fall back from RNG paths that don't honor per-request seeds.
|
||||
self.seeds_set = np.zeros(max_num_reqs, dtype=bool)
|
||||
|
||||
# Initialize top_k and top_p manually because 0 is an invalid value for them.
|
||||
self.top_k.np.fill(self.vocab_size)
|
||||
@@ -45,6 +48,7 @@ class SamplingStates:
|
||||
self.min_p.np[req_idx] = sampling_params.min_p
|
||||
|
||||
seed = sampling_params.seed
|
||||
self.seeds_set[req_idx] = seed is not None
|
||||
if seed is None:
|
||||
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
|
||||
self.seeds.np[req_idx] = seed
|
||||
@@ -85,20 +89,31 @@ class SamplingStates:
|
||||
return
|
||||
apply_min_p(logits, expanded_idx_mapping, self.min_p.gpu)
|
||||
|
||||
def get_top_k_top_p(
|
||||
self, expanded_idx_mapping: torch.Tensor, idx_mapping_np: np.ndarray
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||
do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
|
||||
do_top_p = np.any(self.top_p.np[idx_mapping_np] != 1.0)
|
||||
top_k = self.top_k.gpu[expanded_idx_mapping] if do_top_k else None
|
||||
top_p = self.top_p.gpu[expanded_idx_mapping] if do_top_p else None
|
||||
return top_k, top_p
|
||||
|
||||
def apply_top_k_top_p(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
) -> torch.Tensor:
|
||||
do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
|
||||
do_top_p = np.any(self.top_p.np[idx_mapping_np] != 1.0)
|
||||
if not (do_top_k or do_top_p):
|
||||
top_k, top_p = self.get_top_k_top_p(expanded_idx_mapping, idx_mapping_np)
|
||||
if top_k is None and top_p is None:
|
||||
return logits
|
||||
|
||||
top_k = self.top_k.gpu[expanded_idx_mapping] if do_top_k else None
|
||||
top_p = self.top_p.gpu[expanded_idx_mapping] if do_top_p else None
|
||||
return apply_top_k_top_p(logits, top_k, top_p)
|
||||
|
||||
def any_greedy(self, idx_mapping_np: np.ndarray) -> bool:
|
||||
return bool(np.any(self.temperature.np[idx_mapping_np] == 0.0))
|
||||
|
||||
def any_explicit_seed(self, idx_mapping_np: np.ndarray) -> bool:
|
||||
return bool(np.any(self.seeds_set[idx_mapping_np]))
|
||||
|
||||
def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
|
||||
return int(np.max(self.num_logprobs[idx_mapping_np]))
|
||||
|
||||
@@ -152,6 +152,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
SlidingWindowSpec,
|
||||
UniformTypeKVCacheSpecs,
|
||||
)
|
||||
from vllm.v1.kv_cache_spec_registry import KVCacheSpecRegistry
|
||||
from vllm.v1.outputs import (
|
||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||
AsyncModelRunnerOutput,
|
||||
@@ -6235,6 +6236,7 @@ class GPUModelRunner(
|
||||
)
|
||||
|
||||
kv_cache_spec = self.get_kv_cache_spec()
|
||||
KVCacheSpecRegistry.check_kv_cache_spec_registry(kv_cache_spec)
|
||||
kv_cache_groups = get_kv_cache_groups(self.vllm_config, kv_cache_spec)
|
||||
min_blocks = self.compilation_config.max_cudagraph_capture_size or 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user