diff --git a/CMakeLists.txt b/CMakeLists.txt index 9e8ccfa8718..0652a5f066e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -112,6 +112,8 @@ endif() # # spinloop extension (pure CXX; must stay above the non-CUDA device branch so # CPU builds define the target before the early return) +# This extension requires SABI 3.11 since it relies on Py_buffer support. Loading +# failure is handled gracefully on vLLM side for lower Python versions. # set(VLLM_SPINLOOP_EXT_SRC "csrc/spinloop.cpp") set(SPINLOOP_COMPILE_FLAGS "") @@ -309,14 +311,9 @@ set(VLLM_EXT_SRC "csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu" "csrc/quantization/activation_kernels.cu" "csrc/cuda_utils_kernels.cu" - "csrc/custom_all_reduce.cu" - "csrc/torch_bindings.cpp" - "csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu") + "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") - list(APPEND VLLM_EXT_SRC - "csrc/minimax_reduce_rms_kernel.cu") - SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. @@ -503,12 +500,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND ES_MXFP8_GROUPED_MM_ARCHS) set(SRCS - "csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu" - "csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu") + "csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu" + "csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${ES_MXFP8_GROUPED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") list(APPEND VLLM_GPU_FLAGS "-DENABLE_ES_MXFP8_GROUPED_MM_SM100=1") message(STATUS "Building ES MXFP8 grouped kernels for archs: ${ES_MXFP8_GROUPED_MM_ARCHS}") else() @@ -598,7 +595,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() if (VLLM_GPU_LANG STREQUAL "HIP") - # Add QuickReduce kernels + # Add QuickReduce kernels (ROCm-only; not part of stable ABI migration). list(APPEND VLLM_EXT_SRC "csrc/custom_quickreduce.cu" ) @@ -649,7 +646,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") "csrc/libtorch_stable/attention/paged_attention_v1.cu" "csrc/libtorch_stable/attention/paged_attention_v2.cu" "csrc/libtorch_stable/cache_kernels.cu" - "csrc/libtorch_stable/cache_kernels_fused.cu") + "csrc/libtorch_stable/cache_kernels.cu" + "csrc/libtorch_stable/cache_kernels_fused.cu" + "csrc/libtorch_stable/custom_all_reduce.cu" + "csrc/libtorch_stable/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_STABLE_EXT_SRC @@ -659,7 +659,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") "csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu" "csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/libtorch_stable/permute_cols.cu" - "csrc/libtorch_stable/quantization/awq/gemm_kernels.cu") + "csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu" + "csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu" + "csrc/libtorch_stable/quantization/awq/gemm_kernels.cu" + "csrc/libtorch_stable/minimax_reduce_rms_kernel.cu") set_gencode_flags_for_srcs( SRCS "${VLLM_STABLE_EXT_SRC}" diff --git a/csrc/custom_all_reduce.cu b/csrc/libtorch_stable/custom_all_reduce.cu similarity index 58% rename from csrc/custom_all_reduce.cu rename to csrc/libtorch_stable/custom_all_reduce.cu index a38d6fa24a2..0f7f759949a 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/libtorch_stable/custom_all_reduce.cu @@ -1,7 +1,11 @@ -#include -#include -#include -#include +#include "torch_utils.h" + +#include +#include +#include +#include +#include +#include #include "custom_all_reduce.cuh" @@ -11,7 +15,7 @@ using fptr_t = int64_t; static_assert(sizeof(void*) == sizeof(fptr_t)); fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, - torch::Tensor& rank_data, int64_t rank, + torch::stable::Tensor& rank_data, int64_t rank, bool fully_connected) { int world_size = fake_ipc_ptrs.size(); if (world_size > 8) @@ -25,9 +29,9 @@ fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, for (int i = 0; i < world_size; i++) { ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); } - return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(), - rank_data.numel(), rank, world_size, - fully_connected); + return (fptr_t) new vllm::CustomAllreduce( + ipc_ptrs, rank_data.mutable_data_ptr(), rank_data.numel(), rank, + world_size, fully_connected); } /** @@ -46,10 +50,14 @@ fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, * 5. A[None].expand(2, -1, -1, -1): Not OK * 6. A[:, 1:, 1:]: Not OK */ -bool _is_weak_contiguous(torch::Tensor& t) { - return t.is_contiguous() || - (t.storage().nbytes() - t.storage_offset() * t.element_size() == - t.numel() * t.element_size()); +bool _is_weak_contiguous(torch::stable::Tensor& t) { + if (t.is_contiguous()) { + return true; + } + int64_t storage_nbytes = 0; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_storage_size(t.get(), &storage_nbytes)); + return storage_nbytes - t.storage_offset() * t.element_size() == + static_cast(t.numel() * t.element_size()); } /** @@ -59,42 +67,45 @@ bool _is_weak_contiguous(torch::Tensor& t) { * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first * copied into _reg_buffer. */ -void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, - fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) { +void all_reduce(fptr_t _fa, torch::stable::Tensor& inp, + torch::stable::Tensor& out, fptr_t _reg_buffer, + int64_t reg_buffer_sz_bytes) { auto fa = reinterpret_cast(_fa); - const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); - auto stream = c10::cuda::getCurrentCUDAStream().stream(); + const torch::stable::accelerator::DeviceGuard device_guard( + inp.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(inp.get_device_index()); - TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); - TORCH_CHECK_EQ(inp.numel(), out.numel()); - TORCH_CHECK(_is_weak_contiguous(out)); - TORCH_CHECK(_is_weak_contiguous(inp)); + STD_TORCH_CHECK((inp.scalar_type()) == (out.scalar_type())); + STD_TORCH_CHECK((inp.numel()) == (out.numel())); + STD_TORCH_CHECK(_is_weak_contiguous(out)); + STD_TORCH_CHECK(_is_weak_contiguous(inp)); auto input_size = inp.numel() * inp.element_size(); auto reg_buffer = reinterpret_cast(_reg_buffer); if (reg_buffer) { - TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes); - AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size, - cudaMemcpyDeviceToDevice, stream)); + STD_TORCH_CHECK((input_size) <= (reg_buffer_sz_bytes)); + STD_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.const_data_ptr(), input_size, + cudaMemcpyDeviceToDevice, stream)); } else { - reg_buffer = inp.data_ptr(); + reg_buffer = inp.mutable_data_ptr(); } switch (out.scalar_type()) { - case at::ScalarType::Float: { + case torch::headeronly::ScalarType::Float: { fa->allreduce(stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out.data_ptr()), + reinterpret_cast(out.mutable_data_ptr()), out.numel()); break; } - case at::ScalarType::Half: { + case torch::headeronly::ScalarType::Half: { fa->allreduce(stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out.data_ptr()), out.numel()); + reinterpret_cast(out.mutable_data_ptr()), + out.numel()); break; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) - case at::ScalarType::BFloat16: { + case torch::headeronly::ScalarType::BFloat16: { fa->allreduce( stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out.data_ptr()), out.numel()); + reinterpret_cast(out.mutable_data_ptr()), out.numel()); break; } #endif @@ -112,7 +123,7 @@ int64_t meta_size() { return sizeof(vllm::Signal); } void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs) { auto fa = reinterpret_cast(_fa); - TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_); + STD_TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_); void* ipc_ptrs[8]; for (int i = 0; i < fake_ipc_ptrs.size(); i++) { ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); @@ -143,47 +154,49 @@ void register_graph_buffers(fptr_t _fa, fa->register_graph_buffers(bytes, offsets); } -std::tuple allocate_shared_buffer_and_handle( +std::tuple allocate_shared_buffer_and_handle( int64_t size) { - auto device_index = c10::cuda::current_device(); - at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index)); + int device_index; + STD_CUDA_CHECK(cudaGetDevice(&device_index)); + const torch::stable::accelerator::DeviceGuard device_guard(device_index); void* buffer; cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; - auto stream = c10::cuda::getCurrentCUDAStream().stream(); - AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); + const cudaStream_t stream = get_current_cuda_stream(device_index); + STD_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); // Allocate buffer #if defined(USE_ROCM) // data buffers need to be "uncached" for signal on MI200 - AT_CUDA_CHECK( + STD_CUDA_CHECK( hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached)); #else - AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size)); + STD_CUDA_CHECK(cudaMalloc((void**)&buffer, size)); #endif - AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream)); - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); - AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); + STD_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream)); + STD_CUDA_CHECK(cudaStreamSynchronize(stream)); + STD_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); // Create IPC memhandle for the allocated buffer. // Will use it in open_mem_handle. - auto options = - torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); - auto handle = - torch::empty({static_cast(sizeof(cudaIpcMemHandle_t))}, options); - AT_CUDA_CHECK( - cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer)); + auto handle = torch::stable::empty( + {static_cast(sizeof(cudaIpcMemHandle_t))}, + torch::headeronly::ScalarType::Byte, std::nullopt, + torch::stable::Device(torch::stable::DeviceType::CPU)); + STD_CUDA_CHECK(cudaIpcGetMemHandle( + (cudaIpcMemHandle_t*)handle.mutable_data_ptr(), buffer)); return std::make_tuple(reinterpret_cast(buffer), handle); } -fptr_t open_mem_handle(torch::Tensor& mem_handle) { +fptr_t open_mem_handle(torch::stable::Tensor& mem_handle) { void* ipc_ptr; - AT_CUDA_CHECK(cudaIpcOpenMemHandle( - (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()), + STD_CUDA_CHECK(cudaIpcOpenMemHandle( + (void**)&ipc_ptr, + *((const cudaIpcMemHandle_t*)mem_handle.const_data_ptr()), cudaIpcMemLazyEnablePeerAccess)); return reinterpret_cast(ipc_ptr); } void free_shared_buffer(fptr_t buffer) { - AT_CUDA_CHECK(cudaFree(reinterpret_cast(buffer))); + STD_CUDA_CHECK(cudaFree(reinterpret_cast(buffer))); } diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/libtorch_stable/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu similarity index 89% rename from csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu rename to csrc/libtorch_stable/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index e4d432cac97..a5f3f03de00 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/libtorch_stable/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -28,7 +28,20 @@ * [bs*576, bs*576 + bs*8): UE8M0 scales, 7 real + 1 pad per token */ +#include "torch_utils.h" + +#include +#include +#include +#include +#include +#include + #include +#include "cuda_compat.h" +#include "dispatch_utils.h" +#include "type_convert.cuh" + #ifndef USE_ROCM #include #else @@ -37,14 +50,6 @@ #include #include -#include -#include -#include - -#include "cuda_compat.h" -#include "dispatch_utils.h" -#include "type_convert.cuh" - #ifndef FINAL_MASK #ifdef USE_ROCM #define FINAL_MASK 0xffffffffffffffffULL @@ -70,7 +75,7 @@ namespace deepseek_v4_fused_ops { namespace { inline int getSMVersion() { - auto* props = at::cuda::getCurrentDeviceProperties(); + auto* props = get_device_prop(); return props->major * 10 + props->minor; } } // namespace @@ -564,7 +569,7 @@ static void launchFusedDeepseekV4Templated( // bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert is // unavailable there. Refuse the launch loudly instead of silently // skipping the work. - TORCH_CHECK( + STD_TORCH_CHECK( sm_version >= 80, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert requires sm_80+ " "(Ampere or newer); got sm_", @@ -635,7 +640,7 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( DISPATCH(64) DISPATCH(128) default: - TORCH_CHECK(false, + STD_TORCH_CHECK(false, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert: " "unsupported num_heads_q_padded=", num_heads_q_padded, @@ -650,71 +655,80 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( // ──────────────────────────────────────────────────────────────────────────── // Torch op wrapper // ──────────────────────────────────────────────────────────────────────────── -torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( - torch::Tensor const& q_in, // [N, num_heads_q, 512] bf16 - torch::Tensor const& kv, // [N, 512] bf16 (read-only) - torch::Tensor& k_cache, // [num_blocks, block_bytes] uint8 - torch::Tensor const& slot_mapping, // [N] int64 - torch::Tensor const& position_ids, // [N] int64 - torch::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16 - int64_t q_head_padded, // padded Q head count for output +torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( + torch::stable::Tensor const& q_in, // [N, num_heads_q, 512] bf16 + torch::stable::Tensor const& kv, // [N, 512] bf16 (read-only) + torch::stable::Tensor& k_cache, // [num_blocks, block_bytes] uint8 + torch::stable::Tensor const& slot_mapping, // [N] int64 + torch::stable::Tensor const& position_ids, // [N] int64 + torch::stable::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16 + int64_t q_head_padded, // padded Q head count for output double eps, int64_t cache_block_size) { - TORCH_CHECK(q_in.is_cuda() && q_in.is_contiguous(), - "q_in must be contiguous CUDA"); - TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA"); - TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA"); - TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64, - "slot_mapping must be int64 CUDA"); - TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64, - "position_ids must be int64 CUDA"); - TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA"); - TORCH_CHECK(q_in.dim() == 3 && q_in.size(2) == 512, - "q_in shape [N, num_heads_q, 512]"); - TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]"); - TORCH_CHECK(q_in.dtype() == kv.dtype(), "q_in and kv dtype must match"); - TORCH_CHECK(q_head_padded >= q_in.size(1), - "q_head_padded must be >= q_in.size(1) (num_heads_q)"); - TORCH_CHECK(k_cache.dtype() == torch::kUInt8, "k_cache must be uint8"); - TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64, - "cos_sin_cache shape [max_pos, 64]"); - TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32, - "cos_sin_cache must be float32"); + STD_TORCH_CHECK(q_in.device().is_cuda() && q_in.is_contiguous(), + "q_in must be contiguous CUDA"); + STD_TORCH_CHECK(kv.device().is_cuda() && kv.is_contiguous(), + "kv must be contiguous CUDA"); + STD_TORCH_CHECK(k_cache.device().is_cuda(), "k_cache must be CUDA"); + STD_TORCH_CHECK(slot_mapping.device().is_cuda() && + slot_mapping.scalar_type() == + torch::headeronly::ScalarType::Long, + "slot_mapping must be int64 CUDA"); + STD_TORCH_CHECK(position_ids.device().is_cuda() && + position_ids.scalar_type() == + torch::headeronly::ScalarType::Long, + "position_ids must be int64 CUDA"); + STD_TORCH_CHECK(cos_sin_cache.device().is_cuda(), "cos_sin_cache must be CUDA"); + STD_TORCH_CHECK(q_in.dim() == 3 && q_in.size(2) == 512, + "q_in shape [N, num_heads_q, 512]"); + STD_TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]"); + STD_TORCH_CHECK(q_in.scalar_type() == kv.scalar_type(), + "q_in and kv dtype must match"); + STD_TORCH_CHECK(q_head_padded >= q_in.size(1), + "q_head_padded must be >= q_in.size(1) (num_heads_q)"); + STD_TORCH_CHECK(k_cache.scalar_type() == torch::headeronly::ScalarType::Byte, + "k_cache must be uint8"); + STD_TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64, + "cos_sin_cache shape [max_pos, 64]"); + STD_TORCH_CHECK(cos_sin_cache.scalar_type() == + torch::headeronly::ScalarType::Float, + "cos_sin_cache must be float32"); // With DP padding, slot_mapping can be shorter than q/kv/positions. // Q-norm+RoPE runs on all q.size(0) rows (downstream attention uses them); // KV quant+insert runs only on the first slot_mapping.size(0) rows. int const num_tokens_full = static_cast(q_in.size(0)); int const num_tokens_insert = static_cast(slot_mapping.size(0)); - TORCH_CHECK(static_cast(kv.size(0)) == num_tokens_full && - static_cast(position_ids.size(0)) == num_tokens_full, - "q/kv/position_ids row counts must match"); - TORCH_CHECK(num_tokens_insert <= num_tokens_full, - "slot_mapping must not exceed q row count"); + STD_TORCH_CHECK(static_cast(kv.size(0)) == num_tokens_full && + static_cast(position_ids.size(0)) == num_tokens_full, + "q/kv/position_ids row counts must match"); + STD_TORCH_CHECK(num_tokens_insert <= num_tokens_full, + "slot_mapping must not exceed q row count"); int const num_heads_q = static_cast(q_in.size(1)); int const num_heads_q_padded = static_cast(q_head_padded); int const cache_block_size_i = static_cast(cache_block_size); int const kv_block_stride = static_cast(k_cache.stride(0)); - at::cuda::OptionalCUDAGuard device_guard(device_of(q_in)); - auto stream = at::cuda::getCurrentCUDAStream(); + const torch::stable::accelerator::DeviceGuard device_guard( + q_in.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(q_in.get_device_index()); // Allocate the padded q output. The kernel writes every element (live // region gets RMSNorm+RoPE; pad region gets zeros), so `empty` is safe. - torch::Tensor q_out = torch::empty( - {q_in.size(0), q_head_padded, q_in.size(2)}, q_in.options()); + auto q_out = torch::stable::new_empty( + q_in, {q_in.size(0), q_head_padded, q_in.size(2)}, q_in.scalar_type()); - VLLM_DISPATCH_HALF_TYPES( + VLLM_STABLE_DISPATCH_HALF_TYPES( q_in.scalar_type(), "fused_deepseek_v4_qnorm_rope_kv_insert", [&] { using qkv_scalar_t = scalar_t; vllm::deepseek_v4_fused_ops:: launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( - reinterpret_cast(q_in.data_ptr()), - reinterpret_cast(q_out.data_ptr()), - reinterpret_cast(kv.data_ptr()), - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(slot_mapping.data_ptr()), - reinterpret_cast(position_ids.data_ptr()), - cos_sin_cache.data_ptr(), static_cast(eps), + reinterpret_cast(q_in.const_data_ptr()), + reinterpret_cast(q_out.mutable_data_ptr()), + reinterpret_cast(kv.const_data_ptr()), + reinterpret_cast(k_cache.mutable_data_ptr()), + slot_mapping.const_data_ptr(), + position_ids.const_data_ptr(), + cos_sin_cache.const_data_ptr(), static_cast(eps), num_tokens_full, num_tokens_insert, num_heads_q, num_heads_q_padded, cache_block_size_i, kv_block_stride, stream); diff --git a/csrc/minimax_reduce_rms_kernel.cu b/csrc/libtorch_stable/minimax_reduce_rms_kernel.cu similarity index 87% rename from csrc/minimax_reduce_rms_kernel.cu rename to csrc/libtorch_stable/minimax_reduce_rms_kernel.cu index 6245b02d6e9..d9af0f5efe0 100644 --- a/csrc/minimax_reduce_rms_kernel.cu +++ b/csrc/libtorch_stable/minimax_reduce_rms_kernel.cu @@ -15,16 +15,19 @@ * limitations under the License. */ +#include "torch_utils.h" + +#include +#include +#include +#include +#include +#include + #include #include -#include -#include -#include - #include "cuda_compat.h" -#include "cuda_utils.h" -#include "core/registration.h" #include "minimax_reduce_rms_kernel.h" #include @@ -611,7 +614,7 @@ int get_sm_count() { static int sm_count = 0; if (sm_count == 0) { int device_id; - CUDA_CHECK(cudaGetDevice(&device_id)); + STD_CUDA_CHECK(cudaGetDevice(&device_id)); cudaDeviceProp device_prop; cudaGetDeviceProperties(&device_prop, device_id); sm_count = device_prop.multiProcessorCount; @@ -621,13 +624,13 @@ int get_sm_count() { inline int getSMVersion(bool queryRealSmArch = false) { int device{-1}; - CUDA_CHECK(cudaGetDevice(&device)); + STD_CUDA_CHECK(cudaGetDevice(&device)); int sm_major = 0; int sm_minor = 0; - CUDA_CHECK(cudaDeviceGetAttribute(&sm_major, - cudaDevAttrComputeCapabilityMajor, device)); - CUDA_CHECK(cudaDeviceGetAttribute(&sm_minor, - cudaDevAttrComputeCapabilityMinor, device)); + STD_CUDA_CHECK(cudaDeviceGetAttribute( + &sm_major, cudaDevAttrComputeCapabilityMajor, device)); + STD_CUDA_CHECK(cudaDeviceGetAttribute( + &sm_minor, cudaDevAttrComputeCapabilityMinor, device)); int sm = sm_major * 10 + sm_minor; if (sm == 121 && !queryRealSmArch) { return 120; @@ -639,7 +642,7 @@ template int get_max_active_blocks(KernelFunc kernel, int block_size, int dynamic_smem = 0) { int max_active = 0; - CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + STD_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active, kernel, block_size, dynamic_smem)); return std::max(max_active, 1); } @@ -678,27 +681,27 @@ void minimax_reduce_rms_kernel_launcher(MiniMaxReduceRMSParams const& params) { cfg.attrs = attribute; cfg.numAttrs = SM >= 90 ? 2 : 0; - CUDA_CHECK(cudaLaunchKernelEx( + STD_CUDA_CHECK(cudaLaunchKernelEx( &cfg, minimax_reduce_rms_kernel_lamport, params)); } template void minimax_reduce_rms_kernel_launcher_float4( MiniMaxReduceRMSParams const& params) { - TORCH_CHECK(params.size_q % params.hidden_dim == 0); - TORCH_CHECK(params.hidden_dim % kElemsPerAccess == 0); + STD_TORCH_CHECK(params.size_q % params.hidden_dim == 0); + STD_TORCH_CHECK(params.hidden_dim % kElemsPerAccess == 0); if (params.stride_q > 0) { - TORCH_CHECK(params.stride_q % kElemsPerAccess == 0); + STD_TORCH_CHECK(params.stride_q % kElemsPerAccess == 0); } - TORCH_CHECK(params.allreduce_in_k != nullptr, - "float4 QK kernel requires K input"); - TORCH_CHECK(params.hidden_dim >= params.hidden_dim_k); - TORCH_CHECK(params.size_k % params.hidden_dim_k == 0); - TORCH_CHECK(params.hidden_dim_k % kElemsPerAccess == 0); - TORCH_CHECK(params.size_q / params.hidden_dim == - params.size_k / params.hidden_dim_k); + STD_TORCH_CHECK(params.allreduce_in_k != nullptr, + "float4 QK kernel requires K input"); + STD_TORCH_CHECK(params.hidden_dim >= params.hidden_dim_k); + STD_TORCH_CHECK(params.size_k % params.hidden_dim_k == 0); + STD_TORCH_CHECK(params.hidden_dim_k % kElemsPerAccess == 0); + STD_TORCH_CHECK(params.size_q / params.hidden_dim == + params.size_k / params.hidden_dim_k); if (params.stride_k > 0) { - TORCH_CHECK(params.stride_k % kElemsPerAccess == 0); + STD_TORCH_CHECK(params.stride_k % kElemsPerAccess == 0); } int token_num = params.size_q / params.hidden_dim; @@ -746,7 +749,7 @@ void minimax_reduce_rms_kernel_launcher_float4( cfg.attrs = attribute; cfg.numAttrs = SM >= 90 ? 2 : 0; - CUDA_CHECK(cudaLaunchKernelEx(&cfg, kfn, params)); + STD_CUDA_CHECK(cudaLaunchKernelEx(&cfg, kfn, params)); } template @@ -759,21 +762,21 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) { (params.hidden_dim * params.nranks == 6144) && (params.hidden_dim_k * params.nranks == 1024); - if (params.dtype == at::ScalarType::Half) { + if (params.dtype == torch::headeronly::ScalarType::Half) { if (use_float4) { minimax_reduce_rms_kernel_launcher_float4( params); } else { minimax_reduce_rms_kernel_launcher(params); } - } else if (params.dtype == at::ScalarType::BFloat16) { + } else if (params.dtype == torch::headeronly::ScalarType::BFloat16) { if (use_float4) { minimax_reduce_rms_kernel_launcher_float4<__nv_bfloat16, NRanks, 6144, 1024>(params); } else { minimax_reduce_rms_kernel_launcher<__nv_bfloat16, NRanks>(params); } - } else if (params.dtype == at::ScalarType::Float) { + } else if (params.dtype == torch::headeronly::ScalarType::Float) { if (use_float4) { minimax_reduce_rms_kernel_launcher_float4( params); @@ -781,7 +784,7 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) { minimax_reduce_rms_kernel_launcher(params); } } else { - TORCH_CHECK(false, "Unsupported data type for minimax_reduce_rms_op"); + STD_TORCH_CHECK(false, "Unsupported data type for minimax_reduce_rms_op"); } } @@ -795,16 +798,18 @@ void minimax_reduce_rms_op(MiniMaxReduceRMSParams const& params) { } else if (params.nranks == 16) { dispatch_dtype<16>(params); } else { - TORCH_CHECK(false, "minimax_reduce_rms_op: unsupported ranks number!"); + STD_TORCH_CHECK(false, "minimax_reduce_rms_op: unsupported ranks number!"); } } } // namespace tensorrt_llm } // namespace vllm -torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, - torch::Tensor const& norm_weight, - torch::Tensor workspace, int64_t const rank, - int64_t const nranks, double const eps) { +torch::stable::Tensor minimax_allreduce_rms( + torch::stable::Tensor const& input, + torch::stable::Tensor const& norm_weight, torch::stable::Tensor workspace, + int64_t const rank, int64_t const nranks, double const eps) { + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); auto allreduce_params = vllm::tensorrt_llm::MiniMaxReduceRMSParams(); allreduce_params.nranks = static_cast(nranks); @@ -815,12 +820,12 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, allreduce_params.stride_q = allreduce_params.hidden_dim; allreduce_params.workspace = reinterpret_cast(workspace.mutable_data_ptr()); - allreduce_params.allreduce_in = input.data_ptr(); - allreduce_params.rms_gamma = norm_weight.data_ptr(); + allreduce_params.allreduce_in = const_cast(input.const_data_ptr()); + allreduce_params.rms_gamma = const_cast(norm_weight.const_data_ptr()); allreduce_params.rms_eps = static_cast(eps); - allreduce_params.stream = at::cuda::getCurrentCUDAStream(input.get_device()); + allreduce_params.stream = get_current_cuda_stream(input.get_device_index()); - torch::Tensor rms_norm_out = torch::empty_like(input); + torch::stable::Tensor rms_norm_out = torch::stable::empty_like(input); allreduce_params.rms_norm_out = rms_norm_out.mutable_data_ptr(); vllm::tensorrt_llm::minimax_reduce_rms_op(allreduce_params); @@ -828,26 +833,33 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, return rms_norm_out; } -std::tuple minimax_allreduce_rms_qk( - torch::Tensor qkv, torch::Tensor const& norm_weight_q, - torch::Tensor const& norm_weight_k, torch::Tensor workspace, - int64_t const q_size, int64_t const kv_size, int64_t const rank, - int64_t const nranks, double const eps) { - TORCH_CHECK(qkv.dim() == 2, "minimax_allreduce_rms_qk: qkv must be 2D"); - TORCH_CHECK(qkv.is_contiguous(), - "minimax_allreduce_rms_qk: qkv must be contiguous"); +std::tuple +minimax_allreduce_rms_qk(torch::stable::Tensor qkv, + torch::stable::Tensor const& norm_weight_q, + torch::stable::Tensor const& norm_weight_k, + torch::stable::Tensor workspace, int64_t const q_size, + int64_t const kv_size, int64_t const rank, + int64_t const nranks, double const eps) { + STD_TORCH_CHECK(qkv.dim() == 2, "minimax_allreduce_rms_qk: qkv must be 2D"); + STD_TORCH_CHECK(qkv.is_contiguous(), + "minimax_allreduce_rms_qk: qkv must be contiguous"); int64_t qkv_dim = qkv.size(-1); - TORCH_CHECK(qkv_dim == q_size + 2 * kv_size, - "minimax_allreduce_rms_qk: qkv last dim must equal " - "q_size + 2 * kv_size"); - TORCH_CHECK(rank < nranks, - "minimax_allreduce_rms_qk: rank must be less than nranks"); + STD_TORCH_CHECK(qkv_dim == q_size + 2 * kv_size, + "minimax_allreduce_rms_qk: qkv last dim must equal " + "q_size + 2 * kv_size"); + STD_TORCH_CHECK(rank < nranks, + "minimax_allreduce_rms_qk: rank must be less than nranks"); + + const torch::stable::accelerator::DeviceGuard device_guard( + qkv.get_device_index()); int64_t num_tokens = qkv.size(0); int elem_bytes = qkv.element_size(); - torch::Tensor q_out = torch::empty({num_tokens, q_size}, qkv.options()); - torch::Tensor k_out = torch::empty({num_tokens, kv_size}, qkv.options()); + torch::stable::Tensor q_out = + torch::stable::new_empty(qkv, {num_tokens, q_size}, qkv.scalar_type()); + torch::stable::Tensor k_out = + torch::stable::new_empty(qkv, {num_tokens, kv_size}, qkv.scalar_type()); auto params = vllm::tensorrt_llm::MiniMaxReduceRMSParams(); params.nranks = static_cast(nranks); @@ -863,13 +875,14 @@ std::tuple minimax_allreduce_rms_qk( params.stride_k_out = 0; // k_out is contiguous; kernel uses hidden_dim_k params.workspace = reinterpret_cast(workspace.mutable_data_ptr()); - uint8_t* base = static_cast(qkv.data_ptr()); + uint8_t* base = + const_cast(static_cast(qkv.const_data_ptr())); params.allreduce_in = base; params.allreduce_in_k = base + q_size * elem_bytes; - params.rms_gamma = norm_weight_q.data_ptr(); - params.rms_gamma_k = norm_weight_k.data_ptr(); + params.rms_gamma = const_cast(norm_weight_q.const_data_ptr()); + params.rms_gamma_k = const_cast(norm_weight_k.const_data_ptr()); params.rms_eps = static_cast(eps); - params.stream = at::cuda::getCurrentCUDAStream(qkv.get_device()); + params.stream = get_current_cuda_stream(qkv.get_device_index()); params.rms_norm_out = q_out.mutable_data_ptr(); params.rms_norm_out_k = k_out.mutable_data_ptr(); diff --git a/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu new file mode 100644 index 00000000000..fda9bc020da --- /dev/null +++ b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project +// Adapted from SGLang: +// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu + +#include +#include +#include "libtorch_stable/torch_utils.h" + +#include "cutlass_mxfp8_grouped_mm_launcher.cuh" + +void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a, + const torch::stable::Tensor& b, + const torch::stable::Tensor& sfa, + const torch::stable::Tensor& sfb, + torch::stable::Tensor& d, + const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& expert_offsets, + const torch::stable::Tensor& blockscale_offsets) { +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); + STD_TORCH_CHECK(problem_sizes.size(1) == 3, + "problem_sizes must have shape (num_experts, 3)"); + STD_TORCH_CHECK( + problem_sizes.size(0) == expert_offsets.size(0), + "Number of experts in problem_sizes must match expert_offsets"); + STD_TORCH_CHECK( + problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int, + "problem_sizes must be int32"); + STD_TORCH_CHECK( + expert_offsets.scalar_type() == torch::headeronly::ScalarType::Int, + "expert_offsets must be int32"); + STD_TORCH_CHECK( + blockscale_offsets.scalar_type() == torch::headeronly::ScalarType::Int, + "blockscale_offsets must be int32"); + STD_TORCH_CHECK(a.dim() == 2, + "a must be a 2D tensor of shape (num_tokens, k)"); + STD_TORCH_CHECK(b.dim() == 3, + "b must be a 3D tensor of shape (num_experts, k, n)"); + STD_TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0, + "k should align 128"); + STD_TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128"); + STD_TORCH_CHECK(a.stride(1) == 1, "a must be row major"); + STD_TORCH_CHECK(b.stride(1) == 1, "b must be column major"); + + const torch::stable::accelerator::DeviceGuard device_guard( + a.get_device_index()); + auto stream = get_current_cuda_stream(a.get_device_index()); + if (d.scalar_type() == torch::headeronly::ScalarType::BFloat16) { + expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype< + cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets, + blockscale_offsets, stream); + } else if (d.scalar_type() == torch::headeronly::ScalarType::Half) { + expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype< + cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets, + blockscale_offsets, stream); + } else { + STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16"); + } +#else + STD_TORCH_CHECK(false, + "No implemented cutlass_mxfp8_grouped_mm for " + "current device"); +#endif +} + +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + m.impl("cutlass_mxfp8_grouped_mm", TORCH_BOX(&cutlass_mxfp8_grouped_mm)); +} diff --git a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh similarity index 100% rename from csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh rename to csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh diff --git a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh similarity index 54% rename from csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh rename to csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh index 2c46e1fa725..82d6543b288 100644 --- a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh +++ b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh @@ -4,9 +4,9 @@ // https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_launcher.cuh #pragma once -#include -#include -#include + +#include +#include #include #include @@ -15,18 +15,22 @@ #include "cute/tensor.hpp" #include "cutlass_mxfp8_grouped_mm_functor.cuh" #include "cutlass_mxfp8_grouped_mm_traits.cuh" +#include "libtorch_stable/torch_utils.h" namespace expert_specialization { template void cutlass_mxfp8_grouped_mm_pre_compute( - torch::Tensor& a_ptrs, torch::Tensor& b_ptrs, torch::Tensor& sfa_ptrs, - torch::Tensor& sfb_ptrs, torch::Tensor& d_ptrs, torch::Tensor& stride_a, - torch::Tensor& stride_b, torch::Tensor& stride_d, torch::Tensor& layout_sfa, - torch::Tensor& layout_sfb, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& sfa, const torch::Tensor& sfb, const torch::Tensor& d, - const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets, - const torch::Tensor& blockscale_offsets, cudaStream_t stream) { + torch::stable::Tensor& a_ptrs, torch::stable::Tensor& b_ptrs, + torch::stable::Tensor& sfa_ptrs, torch::stable::Tensor& sfb_ptrs, + torch::stable::Tensor& d_ptrs, torch::stable::Tensor& stride_a, + torch::stable::Tensor& stride_b, torch::stable::Tensor& stride_d, + torch::stable::Tensor& layout_sfa, torch::stable::Tensor& layout_sfb, + const torch::stable::Tensor& a, const torch::stable::Tensor& b, + const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb, + const torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& expert_offsets, + const torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) { using OffsetFunctor = CutlassMxfp8GroupedMmOffsetFunctor; using ElementA = typename OffsetFunctor::ElementA; using ElementB = typename OffsetFunctor::ElementB; @@ -42,10 +46,10 @@ void cutlass_mxfp8_grouped_mm_pre_compute( using StrideB = typename StrideFunctor::StrideB; using StrideD = typename StrideFunctor::StrideD; - int num_experts = (int)expert_offsets.size(0); - TORCH_CHECK(num_experts <= 1024, - "Number of experts cannot exceed 1024, the maximum number of " - "threads per block."); + int num_experts = static_cast(expert_offsets.size(0)); + STD_TORCH_CHECK(num_experts <= 1024, + "Number of experts cannot exceed 1024, the maximum number of " + "threads per block."); OffsetFunctor offset_functor( reinterpret_cast(expert_offsets.data_ptr()), @@ -72,13 +76,18 @@ void cutlass_mxfp8_grouped_mm_pre_compute( } template -void cutlass_mxfp8_grouped_mm( - const torch::Tensor& a_ptrs, const torch::Tensor& b_ptrs, - const torch::Tensor& sfa_ptrs, const torch::Tensor& sfb_ptrs, - const torch::Tensor& d_ptrs, const torch::Tensor& stride_a, - const torch::Tensor& stride_b, const torch::Tensor& stride_d, - const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, - const torch::Tensor& problem_sizes, cudaStream_t stream) { +void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a_ptrs, + const torch::stable::Tensor& b_ptrs, + const torch::stable::Tensor& sfa_ptrs, + const torch::stable::Tensor& sfb_ptrs, + const torch::stable::Tensor& d_ptrs, + const torch::stable::Tensor& stride_a, + const torch::stable::Tensor& stride_b, + const torch::stable::Tensor& stride_d, + const torch::stable::Tensor& layout_sfa, + const torch::stable::Tensor& layout_sfb, + const torch::stable::Tensor& problem_sizes, + cudaStream_t stream) { using Gemm = typename GemmTraits::Gemm; using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; @@ -93,13 +102,12 @@ void cutlass_mxfp8_grouped_mm( typename GemmTraits::ProblemShape::UnderlyingProblemShape; cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = c10::cuda::current_device(); - hw_info.sm_count = - at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + hw_info.device_id = d_ptrs.get_device_index(); + hw_info.sm_count = get_device_prop()->multiProcessorCount; hw_info.cluster_shape = GemmTraits::MMAConfig::preferred_cluster; hw_info.cluster_shape_fallback = GemmTraits::MMAConfig::fallback_cluster; - int num_experts = (int)problem_sizes.size(0); + int num_experts = static_cast(problem_sizes.size(0)); UnderlyingProblemShape* underlying_problem_shape = reinterpret_cast(problem_sizes.data_ptr()); @@ -127,44 +135,55 @@ void cutlass_mxfp8_grouped_mm( Gemm gemm; auto can_implement_status = gemm.can_implement(arguments); - TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, - "Failed to implement GEMM"); + STD_TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, + "Failed to implement GEMM"); - torch::TensorOptions options_uint8 = - torch::TensorOptions().dtype(torch::kUInt8).device(d_ptrs.device()); size_t workspace_size = gemm.get_workspace_size(arguments); - torch::Tensor workspace = torch::empty(workspace_size, options_uint8); + torch::stable::Tensor workspace = torch::stable::empty( + {static_cast(workspace_size)}, + torch::headeronly::ScalarType::Byte, std::nullopt, d_ptrs.device()); auto status = gemm.initialize(arguments, workspace.data_ptr(), stream); - TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + STD_TORCH_CHECK(status == cutlass::Status::kSuccess, + "Failed to initialize GEMM"); status = gemm.run(stream, nullptr, true); // Enable PDL - TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); + STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); } template void cutlass_mxfp8_grouped_mm_dispatch_out_dtype( - const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& sfa, - const torch::Tensor& sfb, torch::Tensor& d, - const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets, - const torch::Tensor& blockscale_offsets, cudaStream_t stream) { - int num_experts = (int)problem_sizes.size(0); - torch::TensorOptions options_int64 = - torch::TensorOptions().dtype(torch::kInt64).device(a.device()); - torch::TensorOptions options_int32 = - torch::TensorOptions().dtype(torch::kInt32).device(a.device()); + const torch::stable::Tensor& a, const torch::stable::Tensor& b, + const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb, + torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& expert_offsets, + const torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) { + int num_experts = static_cast(problem_sizes.size(0)); + auto device = a.device(); - torch::Tensor a_ptrs = torch::empty(num_experts, options_int64); - torch::Tensor b_ptrs = torch::empty(num_experts, options_int64); - torch::Tensor sfa_ptrs = torch::empty(num_experts, options_int64); - torch::Tensor sfb_ptrs = torch::empty(num_experts, options_int64); - torch::Tensor d_ptrs = torch::empty(num_experts, options_int64); + torch::stable::Tensor a_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor b_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor sfa_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor sfb_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor d_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); - torch::Tensor stride_a = torch::empty(num_experts, options_int64); - torch::Tensor stride_b = torch::empty(num_experts, options_int64); - torch::Tensor stride_d = torch::empty(num_experts, options_int64); - torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int32); - torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int32); + torch::stable::Tensor stride_a = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor stride_b = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor stride_d = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor layout_sfa = + torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int, + std::nullopt, device); + torch::stable::Tensor layout_sfb = + torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int, + std::nullopt, device); using GemmTraits = CutlassMxfp8GroupedMmGemmTraits; cutlass_mxfp8_grouped_mm_pre_compute( @@ -176,4 +195,4 @@ void cutlass_mxfp8_grouped_mm_dispatch_out_dtype( layout_sfa, layout_sfb, problem_sizes, stream); } -} // namespace expert_specialization \ No newline at end of file +} // namespace expert_specialization diff --git a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh similarity index 100% rename from csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh rename to csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh diff --git a/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu b/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu new file mode 100644 index 00000000000..e075721c2a3 --- /dev/null +++ b/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project +// Adapted from SGLang: +// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu + +#include +#include +#include "libtorch_stable/torch_utils.h" + +#include "mxfp8_experts_quant.cuh" + +void mxfp8_experts_quant(const torch::stable::Tensor& input, + const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& expert_offsets, + const torch::stable::Tensor& blockscale_offsets, + torch::stable::Tensor& quant_output, + torch::stable::Tensor& scale_factor) { +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + STD_TORCH_CHECK(input.dim() == 2, "input must be 2D tensor"); + STD_TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128"); + STD_TORCH_CHECK(input.stride(1) == 1, "input must be row major"); + STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); + STD_TORCH_CHECK( + problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int, + "problem_sizes must be int32"); + STD_TORCH_CHECK( + expert_offsets.scalar_type() == torch::headeronly::ScalarType::Int, + "expert_offsets must be int32"); + STD_TORCH_CHECK( + blockscale_offsets.scalar_type() == torch::headeronly::ScalarType::Int, + "blockscale_offsets must be int32"); + + auto groups = problem_sizes.size(0); + STD_TORCH_CHECK( + expert_offsets.dim() == 1 && expert_offsets.size(0) == groups, + "expert_offsets must be 1D and have size equal to the number of groups"); + STD_TORCH_CHECK( + blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups, + "blockscale_offsets must be 1D and have size equal to the number of " + "groups"); + + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + if (input.scalar_type() == torch::headeronly::ScalarType::BFloat16) { + expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>( + input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, + scale_factor); + } else if (input.scalar_type() == torch::headeronly::ScalarType::Half) { + expert_specialization::launch_mxfp8_experts_quant<__half>( + input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, + scale_factor); + } else { + STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16"); + } +#else + STD_TORCH_CHECK(false, + "No implemented mxfp8_experts_quant for " + "current device"); +#endif +} + +// Registered here (not torch_bindings.cpp) because ENABLE_ES_MXFP8_GROUPED_MM +// is applied only under COMPILE_LANGUAGE:CUDA. +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + m.impl("mxfp8_experts_quant", TORCH_BOX(&mxfp8_experts_quant)); +} diff --git a/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh b/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cuh similarity index 95% rename from csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh rename to csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cuh index 9a85852080f..a57e00e76c3 100644 --- a/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh +++ b/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cuh @@ -4,16 +4,19 @@ // https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh #pragma once -#include -#include #include #include #include -#include + +#include +#include +#include +#include #include #include "cute/tensor.hpp" +#include "libtorch_stable/torch_utils.h" namespace expert_specialization { @@ -356,12 +359,12 @@ __global__ void mxfp8_experts_quant_kernel( } template -void launch_mxfp8_experts_quant(const torch::Tensor& input, - const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets, - const torch::Tensor& blockscale_offsets, - torch::Tensor& quant_output, - torch::Tensor& scale_factor) { +void launch_mxfp8_experts_quant(const torch::stable::Tensor& input, + const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& expert_offsets, + const torch::stable::Tensor& blockscale_offsets, + torch::stable::Tensor& quant_output, + torch::stable::Tensor& scale_factor) { ThrLayout thr_layout{}; ValLayout val_layout{}; SfR2SThrLayout r2s_thr_layout{}; @@ -386,19 +389,18 @@ void launch_mxfp8_experts_quant(const torch::Tensor& input, CopyAtomR2S{}, r2s_thr_layout, r2s_val_layout); // Tiler_MN: (16, 4) int max_active_blocks_per_sm = -1; - AT_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + STD_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_per_sm, mxfp8_experts_quant_kernel, THREAD_BLOCK_SIZE, 0)); - dim3 grid(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * - max_active_blocks_per_sm, + dim3 grid(get_device_prop()->multiProcessorCount * max_active_blocks_per_sm, 1, 1); dim3 block(THREAD_BLOCK_SIZE, 1, 1); - int num_experts = (int)problem_sizes.size(0); - auto stream = at::cuda::getCurrentCUDAStream(); + int num_experts = static_cast(problem_sizes.size(0)); + auto stream = get_current_cuda_stream(input.get_device_index()); mxfp8_experts_quant_kernel <<>>( diff --git a/csrc/libtorch_stable/ops.h b/csrc/libtorch_stable/ops.h index 0363ec7cdfc..dd27a6968d0 100644 --- a/csrc/libtorch_stable/ops.h +++ b/csrc/libtorch_stable/ops.h @@ -231,6 +231,27 @@ void fused_qk_norm_rope(torch::stable::Tensor& qkv, int64_t num_heads_q, torch::stable::Tensor& position_ids, int64_t forced_token_heads_per_warp); +torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( + torch::stable::Tensor const& q_in, torch::stable::Tensor const& kv, + torch::stable::Tensor& k_cache, torch::stable::Tensor const& slot_mapping, + torch::stable::Tensor const& position_ids, + torch::stable::Tensor const& cos_sin_cache, int64_t q_head_padded, + double eps, int64_t cache_block_size); + +#ifndef USE_ROCM +torch::stable::Tensor minimax_allreduce_rms( + torch::stable::Tensor const& input, + torch::stable::Tensor const& norm_weight, torch::stable::Tensor workspace, + int64_t const rank, int64_t const nranks, double const eps); +std::tuple +minimax_allreduce_rms_qk(torch::stable::Tensor qkv, + torch::stable::Tensor const& norm_weight_q, + torch::stable::Tensor const& norm_weight_k, + torch::stable::Tensor workspace, int64_t const q_size, + int64_t const kv_size, int64_t const rank, + int64_t const nranks, double const eps); +#endif + // Sampler kernels (shared CUDA/ROCm) void apply_repetition_penalties_( torch::stable::Tensor& logits, const torch::stable::Tensor& prompt_mask, @@ -273,6 +294,26 @@ void selective_scan_fwd( const std::optional& cu_chunk_seqlen, const std::optional& last_chunk_indices); +using fptr_t = int64_t; +fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, + torch::stable::Tensor& rank_data, int64_t rank, + bool fully_connected); +void all_reduce(fptr_t _fa, torch::stable::Tensor& inp, + torch::stable::Tensor& out, fptr_t reg_buffer, + int64_t reg_buffer_sz_bytes); +void dispose(fptr_t _fa); +int64_t meta_size(); +void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs); +std::tuple, std::vector> +get_graph_buffer_ipc_meta(fptr_t _fa); +void register_graph_buffers(fptr_t _fa, + const std::vector>& handles, + const std::vector>& offsets); +std::tuple allocate_shared_buffer_and_handle( + int64_t size); +int64_t open_mem_handle(torch::stable::Tensor& mem_handle); +void free_shared_buffer(int64_t buffer); + // Activation kernels (shared CUDA/ROCm) void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input); void silu_and_mul_clamp(torch::stable::Tensor& out, diff --git a/csrc/libtorch_stable/torch_bindings.cpp b/csrc/libtorch_stable/torch_bindings.cpp index 98cd31df13b..e9a62a8666c 100644 --- a/csrc/libtorch_stable/torch_bindings.cpp +++ b/csrc/libtorch_stable/torch_bindings.cpp @@ -337,6 +337,24 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) { "bool is_neox, Tensor position_ids, " "int forced_token_heads_per_warp=-1) -> ()"); + ops.def( + "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(" + "Tensor q_in, Tensor kv, Tensor! k_cache, " + "Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, " + "int q_head_padded, float eps, int cache_block_size) -> Tensor"); + +#ifndef USE_ROCM + ops.def( + "minimax_allreduce_rms(" + "Tensor input, Tensor norm_weight, Tensor workspace, " + "int rank, int nranks, float eps) -> Tensor"); + ops.def( + "minimax_allreduce_rms_qk(" + "Tensor qkv, Tensor norm_weight_q, Tensor norm_weight_k, " + "Tensor workspace, int q_size, int kv_size, int rank, int nranks, " + "float eps) -> (Tensor, Tensor)"); +#endif + // Apply repetition penalties to logits in-place. ops.def( "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " @@ -571,6 +589,12 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { // Positional encoding kernels (shared CUDA/ROCm) ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding)); ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope)); + ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", + TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert)); +#ifndef USE_ROCM + ops.impl("minimax_allreduce_rms", TORCH_BOX(&minimax_allreduce_rms)); + ops.impl("minimax_allreduce_rms_qk", TORCH_BOX(&minimax_allreduce_rms_qk)); +#endif // Sampler kernels (shared CUDA/ROCm) ops.impl("apply_repetition_penalties_", @@ -725,6 +749,45 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) { "dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()"); } +STABLE_TORCH_LIBRARY_FRAGMENT(_C_custom_ar, custom_ar) { + custom_ar.def( + "init_custom_ar(int[] ipc_tensors, Tensor rank_data, " + "int rank, bool fully_connected) -> int"); + custom_ar.def( + "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, " + "int reg_buffer_sz_bytes) -> ()"); + custom_ar.def("dispose(int fa) -> ()"); + custom_ar.def("meta_size() -> int"); + custom_ar.def("register_buffer(int fa, int[] ipc_tensors) -> ()"); + custom_ar.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])"); + custom_ar.def( + "register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()"); + custom_ar.def("allocate_shared_buffer_and_handle(int size) -> (int, Tensor)"); + custom_ar.def("open_mem_handle(Tensor mem_handle) -> int"); + custom_ar.def("free_shared_buffer(int ptr) -> ()"); +} + +STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CUDA, custom_ar) { + custom_ar.impl("init_custom_ar", TORCH_BOX(&init_custom_ar)); + custom_ar.impl("all_reduce", TORCH_BOX(&all_reduce)); +} + +STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CPU, custom_ar) { + custom_ar.impl("open_mem_handle", TORCH_BOX(&open_mem_handle)); +} + +STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CompositeExplicitAutograd, custom_ar) { + custom_ar.impl("dispose", TORCH_BOX(&dispose)); + custom_ar.impl("meta_size", TORCH_BOX(&meta_size)); + custom_ar.impl("register_buffer", TORCH_BOX(®ister_buffer)); + custom_ar.impl("get_graph_buffer_ipc_meta", + TORCH_BOX(&get_graph_buffer_ipc_meta)); + custom_ar.impl("register_graph_buffers", TORCH_BOX(®ister_graph_buffers)); + custom_ar.impl("allocate_shared_buffer_and_handle", + TORCH_BOX(&allocate_shared_buffer_and_handle)); + custom_ar.impl("free_shared_buffer", TORCH_BOX(&free_shared_buffer)); +} + STABLE_TORCH_LIBRARY_IMPL(_C_cache_ops, CPU, ops) { ops.impl("swap_blocks_batch", TORCH_BOX(&swap_blocks_batch)); } diff --git a/csrc/minimax_reduce_rms_kernel.h b/csrc/minimax_reduce_rms_kernel.h index e8c2d012247..c3d2dd5c599 100644 --- a/csrc/minimax_reduce_rms_kernel.h +++ b/csrc/minimax_reduce_rms_kernel.h @@ -19,7 +19,7 @@ #include #include -#include +#include namespace vllm { namespace tensorrt_llm { @@ -51,7 +51,7 @@ static constexpr int kElemsPerAccess = ElemsPerAccess::value; struct MiniMaxReduceRMSParams { int nranks{}; int rank{}; - at::ScalarType dtype{at::ScalarType::Undefined}; + torch::headeronly::ScalarType dtype{torch::headeronly::ScalarType::Undefined}; int size_q{}; int hidden_dim{}; int size_k{}; diff --git a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu b/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu deleted file mode 100644 index f507f9299b0..00000000000 --- a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu +++ /dev/null @@ -1,60 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright contributors to the vLLM project -// Adapted from SGLang: -// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu - -#include - -#include "cutlass_mxfp8_grouped_mm_launcher.cuh" - -void cutlass_mxfp8_grouped_mm(const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& sfa, - const torch::Tensor& sfb, torch::Tensor& d, - const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets, - const torch::Tensor& blockscale_offsets) { -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); - TORCH_CHECK(problem_sizes.size(1) == 3, - "problem_sizes must have shape (num_experts, 3)"); - TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0), - "Number of experts in problem_sizes must match expert_offsets"); - TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, - "problem_sizes must be int32"); - TORCH_CHECK(expert_offsets.dtype() == torch::kInt32, - "expert_offsets must be int32"); - TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32, - "blockscale_offsets must be int32"); - TORCH_CHECK(a.dim() == 2, "a must be a 2D tensor of shape (num_tokens, k)"); - TORCH_CHECK(b.dim() == 3, - "b must be a 3D tensor of shape (num_experts, k, n)"); - TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0, - "k should align 128"); - TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128"); - TORCH_CHECK(a.strides()[1] == 1, "a must be row major"); - TORCH_CHECK(b.strides()[1] == 1, "b must be column major"); - - auto stream = at::cuda::getCurrentCUDAStream(); - if (d.dtype() == torch::kBFloat16) { - expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype< - cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets, - blockscale_offsets, stream); - } else if (d.dtype() == torch::kFloat16) { - expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype< - cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets, - blockscale_offsets, stream); - } else { - TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16"); - } -#else - TORCH_CHECK(false, - "No implemented cutlass_mxfp8_grouped_mm for " - "current device"); -#endif -} - -#include "core/registration.h" - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("cutlass_mxfp8_grouped_mm", cutlass_mxfp8_grouped_mm); -} \ No newline at end of file diff --git a/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu b/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu deleted file mode 100644 index 2a93ab94d5c..00000000000 --- a/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu +++ /dev/null @@ -1,60 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright contributors to the vLLM project -// Adapted from SGLang: -// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu - -#include - -#include "mxfp8_experts_quant.cuh" - -void mxfp8_experts_quant(const torch::Tensor& input, - const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets, - const torch::Tensor& blockscale_offsets, - torch::Tensor& quant_output, - torch::Tensor& scale_factor) { -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - TORCH_CHECK(input.dim() == 2, "input must be 2D tensor"); - TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128"); - TORCH_CHECK(input.strides()[1] == 1, "input must be row major"); - TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); - TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, - "problem_sizes must be int32"); - TORCH_CHECK(expert_offsets.dtype() == torch::kInt32, - "expert_offsets must be int32"); - TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32, - "blockscale_offsets must be int32"); - - auto groups = problem_sizes.size(0); - TORCH_CHECK( - expert_offsets.dim() == 1 && expert_offsets.size(0) == groups, - "expert_offsets must be 1D and have size equal to the number of groups"); - TORCH_CHECK( - blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups, - "blockscale_offsets must be 1D and have size equal to the number of " - "groups"); - - auto stream = at::cuda::getCurrentCUDAStream(); - if (input.dtype() == torch::kBFloat16) { - expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>( - input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, - scale_factor); - } else if (input.dtype() == torch::kFloat16) { - expert_specialization::launch_mxfp8_experts_quant<__half>( - input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, - scale_factor); - } else { - TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16"); - } -#else - TORCH_CHECK(false, - "No implemented mxfp8_experts_quant for " - "current device"); -#endif -} - -#include "core/registration.h" - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("mxfp8_experts_quant", mxfp8_experts_quant); -} \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index f458f79d6f4..ed2fca26b0d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -40,12 +40,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); -torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( - torch::Tensor const& q_in, torch::Tensor const& kv, torch::Tensor& k_cache, - torch::Tensor const& slot_mapping, torch::Tensor const& position_ids, - torch::Tensor const& cos_sin_cache, int64_t q_head_padded, double eps, - int64_t cache_block_size); - void silu_and_mul_per_block_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales, int64_t group_size, @@ -107,24 +101,6 @@ torch::Tensor dynamic_4bit_int_moe_cpu( int64_t activation_kind); using fptr_t = int64_t; -fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, - torch::Tensor& rank_data, int64_t rank, - bool fully_connected); -void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, - fptr_t reg_buffer, int64_t reg_buffer_sz_bytes); -void dispose(fptr_t _fa); -int64_t meta_size(); -void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs); -std::tuple, std::vector> -get_graph_buffer_ipc_meta(fptr_t _fa); -void register_graph_buffers(fptr_t _fa, - const std::vector>& handles, - const std::vector>& offsets); -std::tuple allocate_shared_buffer_and_handle( - int64_t size); -int64_t open_mem_handle(torch::Tensor& mem_handle); -void free_shared_buffer(int64_t buffer); - #ifdef USE_ROCM fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional qr_max_size = std::nullopt); @@ -135,15 +111,3 @@ void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false); int64_t qr_max_size(); #endif - -#ifndef USE_ROCM -torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, - torch::Tensor const& norm_weight, - torch::Tensor workspace, int64_t const rank, - int64_t const nranks, double const eps); -std::tuple minimax_allreduce_rms_qk( - torch::Tensor qkv, torch::Tensor const& norm_weight_q, - torch::Tensor const& norm_weight_k, torch::Tensor workspace, - int64_t const q_size, int64_t const kv_size, int64_t const rank, - int64_t const nranks, double const eps); -#endif diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 01869474e0f..c078222bca0 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -55,14 +55,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and // GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one - // kernel launch. - ops.def( - "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(" - "Tensor q_in, Tensor kv, Tensor! k_cache, " - "Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, " - "int q_head_padded, float eps, int cache_block_size) -> Tensor"); - ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA, - &fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert); + // kernel launch. Registered in _C_stable_libtorch. // Quantization ops #ifndef USE_ROCM @@ -163,34 +156,27 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // conditionally compiled so impl registration is in source file #endif - -#ifndef USE_ROCM - ops.def( - "minimax_allreduce_rms(" - "Tensor input," - "Tensor norm_weight," - "Tensor workspace," - "int rank," - "int nranks," - "float eps) -> Tensor"); - ops.impl("minimax_allreduce_rms", torch::kCUDA, &minimax_allreduce_rms); - ops.def( - "minimax_allreduce_rms_qk(" - "Tensor qkv," - "Tensor norm_weight_q," - "Tensor norm_weight_k," - "Tensor workspace," - "int q_size," - "int kv_size," - "int rank," - "int nranks," - "float eps) -> (Tensor, Tensor)"); - ops.impl("minimax_allreduce_rms_qk", torch::kCUDA, &minimax_allreduce_rms_qk); - - // conditionally compiled so impl in source file -#endif } +#ifdef USE_ROCM +TORCH_LIBRARY_FRAGMENT(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { + // Quick Reduce all-reduce kernels (ROCm-only; stays on legacy _C). + custom_ar.def( + "qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool " + "cast_bf2half) -> ()"); + custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce); + + custom_ar.def("init_custom_qr", &init_custom_qr); + custom_ar.def("qr_destroy", &qr_destroy); + custom_ar.def("qr_get_handle", &qr_get_handle); + + custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()"); + custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles); + + custom_ar.def("qr_max_size", &qr_max_size); +} +#endif + TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { // Cuda utils @@ -205,48 +191,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { &get_max_shared_memory_per_block_device_attribute); } -TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { - // Custom all-reduce kernels - custom_ar.def( - "init_custom_ar(int[] ipc_tensors, Tensor rank_data, " - "int rank, bool fully_connected) -> int"); - custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); - custom_ar.def( - "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, " - "int reg_buffer_sz_bytes) -> ()"); - custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce); - - custom_ar.def("dispose", &dispose); - custom_ar.def("meta_size", &meta_size); - - custom_ar.def("register_buffer", ®ister_buffer); - custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); - custom_ar.def("register_graph_buffers", ®ister_graph_buffers); - - custom_ar.def("allocate_shared_buffer_and_handle", - &allocate_shared_buffer_and_handle); - custom_ar.def("open_mem_handle(Tensor mem_handle) -> int", &open_mem_handle); - custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle); - - custom_ar.def("free_shared_buffer", &free_shared_buffer); -#ifdef USE_ROCM - // Quick Reduce all-reduce kernels - custom_ar.def( - "qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool " - "cast_bf2half) -> ()"); - custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce); - - custom_ar.def("init_custom_qr", &init_custom_qr); - custom_ar.def("qr_destroy", &qr_destroy); - - custom_ar.def("qr_get_handle", &qr_get_handle); - - custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()"); - custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles); - - // Max input size in bytes - custom_ar.def("qr_max_size", &qr_max_size); -#endif -} - REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/type_convert.cuh b/csrc/type_convert.cuh index 9d939bb828f..8093c4bc871 100644 --- a/csrc/type_convert.cuh +++ b/csrc/type_convert.cuh @@ -50,7 +50,7 @@ struct _typeConvert { #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) // CUDA < 12.0 runs into issues with packed type conversion template <> -struct _typeConvert { +struct _typeConvert { static constexpr bool exists = true; using hip_type = __half; using packed_hip_type = __half2; @@ -73,7 +73,7 @@ struct _typeConvert { // CUDA_ARCH < 800 does not have BF16 support // ROCm 7.0+ supports bfloat16 template <> -struct _typeConvert { +struct _typeConvert { static constexpr bool exists = true; using hip_type = __nv_bfloat16; using packed_hip_type = __nv_bfloat162; diff --git a/tests/renderers/test_hf.py b/tests/renderers/test_hf.py index c2eb6556394..0545457eb7a 100644 --- a/tests/renderers/test_hf.py +++ b/tests/renderers/test_hf.py @@ -7,6 +7,9 @@ from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.renderers.hf import ( + _consolidate_system_messages, + _convert_developer_to_system, + _detect_developer_role_support, _get_hf_base_chat_template_params, _try_extract_ast, resolve_chat_template, @@ -592,3 +595,322 @@ def test_get_gen_prompt( f"The generated prompt does not match the expected output for " f"model {model} and template {template}" ) + + +class TestConvertDeveloperToSystem: + def test_converts_role(self): + conversation = [ + {"role": "developer", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + ] + result = _convert_developer_to_system(conversation) + assert result[0]["role"] == "system" + assert result[0]["content"] == "You are helpful." + assert result[1]["role"] == "user" + + def test_removes_tools_key(self): + conversation = [ + { + "role": "developer", + "content": "Instructions", + "tools": [{"type": "function"}], + }, + ] + result = _convert_developer_to_system(conversation) + assert "tools" not in result[0] + + def test_no_developer_messages_unchanged(self): + conversation = [ + {"role": "system", "content": "System prompt"}, + {"role": "user", "content": "Hello"}, + ] + result = _convert_developer_to_system(conversation) + assert result[0]["role"] == "system" + assert result[1]["role"] == "user" + + def test_does_not_mutate_original(self): + original = { + "role": "developer", + "content": "Instructions", + "tools": [{"type": "function"}], + } + conversation = [original] + _convert_developer_to_system(conversation) + assert original["role"] == "developer" + assert "tools" in original + + +# --- Developer role detection and conversion tests --- + +CHATML_TEMPLATE = ( + "{% for message in messages %}" + "{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}" + "{% if (loop.last and add_generation_prompt) or not loop.last %}" + "{{ '<|im_end|>' + '\\n'}}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}" + "{{ '<|im_start|>assistant\\n' }}" + "{% endif %}" +) + +TEMPLATE_WITH_DEVELOPER = ( + "{% for message in messages %}" + "{% if message['role'] == 'developer' %}" + "{{'<|im_start|>developer\\n' + message['content'] + '<|im_end|>\\n'}}" + "{% elif message['role'] == 'system' %}" + "{{'<|im_start|>system\\n' + message['content'] + '<|im_end|>\\n'}}" + "{% elif message['role'] == 'user' %}" + "{{'<|im_start|>user\\n' + message['content'] + '<|im_end|>\\n'}}" + "{% elif message['role'] == 'assistant' %}" + "{{'<|im_start|>assistant\\n' + message['content'] + '<|im_end|>\\n'}}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ '<|im_start|>assistant\\n' }}" + "{% endif %}" +) + +STRICT_ROLE_TEMPLATE = ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{'<|im_start|>system\\n' + message['content'] + '<|im_end|>\\n'}}" + "{% elif message['role'] == 'user' %}" + "{{'<|im_start|>user\\n' + message['content'] + '<|im_end|>\\n'}}" + "{% elif message['role'] == 'assistant' %}" + "{{'<|im_start|>assistant\\n' + message['content'] + '<|im_end|>\\n'}}" + "{% else %}" + "{{ raise_exception('Unexpected message role: ' + message['role']) }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ '<|im_start|>assistant\\n' }}" + "{% endif %}" +) + + +class TestDetectDeveloperRoleSupport: + def test_absent_in_chatml(self): + assert _detect_developer_role_support(CHATML_TEMPLATE) is False + + def test_present_double_quotes(self): + assert _detect_developer_role_support(TEMPLATE_WITH_DEVELOPER) is True + + def test_present_single_quotes(self): + template = TEMPLATE_WITH_DEVELOPER.replace('"developer"', "'developer'") + assert _detect_developer_role_support(template) is True + + def test_absent_in_strict_template(self): + assert _detect_developer_role_support(STRICT_ROLE_TEMPLATE) is False + + +class TestSafeApplyChatTemplateDeveloperRole: + @pytest.fixture + def model_config(self): + return ModelConfig( + "facebook/opt-125m", + tokenizer="facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + ) + + @pytest.fixture + def tokenizer(self): + return get_tokenizer("facebook/opt-125m") + + def test_developer_converted_to_system_for_chatml(self, model_config, tokenizer): + conversation = [ + {"role": "developer", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + ] + result = safe_apply_chat_template( + model_config, + tokenizer, + conversation, + chat_template=CHATML_TEMPLATE, + tokenize=False, + add_generation_prompt=True, + ) + assert "<|im_start|>system" in result + assert "You are a helpful assistant." in result + assert "<|im_start|>developer" not in result + + def test_developer_preserved_when_template_supports_it( + self, model_config, tokenizer + ): + conversation = [ + {"role": "developer", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + ] + result = safe_apply_chat_template( + model_config, + tokenizer, + conversation, + chat_template=TEMPLATE_WITH_DEVELOPER, + tokenize=False, + add_generation_prompt=True, + ) + assert "<|im_start|>developer" in result + assert "You are a helpful assistant." in result + + def test_developer_does_not_crash_strict_template(self, model_config, tokenizer): + conversation = [ + {"role": "developer", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + ] + result = safe_apply_chat_template( + model_config, + tokenizer, + conversation, + chat_template=STRICT_ROLE_TEMPLATE, + tokenize=False, + add_generation_prompt=True, + ) + assert "<|im_start|>system" in result + assert "You are a helpful assistant." in result + + def test_no_developer_messages_no_overhead(self, model_config, tokenizer): + conversation = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + ] + result = safe_apply_chat_template( + model_config, + tokenizer, + conversation, + chat_template=CHATML_TEMPLATE, + tokenize=False, + add_generation_prompt=True, + ) + assert "<|im_start|>system" in result + assert "You are helpful." in result + + def test_developer_at_non_first_position_consolidated( + self, model_config, tokenizer + ): + conversation = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "developer", "content": "Be concise."}, + {"role": "user", "content": "What is 2+2?"}, + ] + result = safe_apply_chat_template( + model_config, + tokenizer, + conversation, + chat_template=SYSTEM_FIRST_TEMPLATE, + tokenize=False, + add_generation_prompt=True, + ) + assert "<|im_start|>system" in result + assert "You are helpful." in result + assert "Be concise." in result + assert "What is 2+2?" in result + + def test_developer_only_no_prior_system(self, model_config, tokenizer): + conversation = [ + {"role": "user", "content": "Hello"}, + {"role": "developer", "content": "Be concise."}, + {"role": "user", "content": "What is 2+2?"}, + ] + result = safe_apply_chat_template( + model_config, + tokenizer, + conversation, + chat_template=SYSTEM_FIRST_TEMPLATE, + tokenize=False, + add_generation_prompt=True, + ) + assert "<|im_start|>system" in result + assert "Be concise." in result + + +SYSTEM_FIRST_TEMPLATE = ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{% if not loop.first %}" + "{{ raise_exception('System message must be at the beginning.') }}" + "{% endif %}" + "{{'<|im_start|>system\\n' + message['content'] + '<|im_end|>\\n'}}" + "{% elif message['role'] == 'user' %}" + "{{'<|im_start|>user\\n' + message['content'] + '<|im_end|>\\n'}}" + "{% elif message['role'] == 'assistant' %}" + "{{'<|im_start|>assistant\\n' + message['content'] + '<|im_end|>\\n'}}" + "{% else %}" + "{{ raise_exception('Unexpected message role: ' + message['role']) }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ '<|im_start|>assistant\\n' }}" + "{% endif %}" +) + + +class TestConsolidateSystemMessages: + def test_no_system_messages_unchanged(self): + conversation = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + ] + result = _consolidate_system_messages(conversation) + assert result == conversation + + def test_single_system_at_start_unchanged(self): + conversation = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + ] + result = _consolidate_system_messages(conversation) + assert result == conversation + + def test_system_at_non_first_position_moved(self): + conversation = [ + {"role": "user", "content": "Hello"}, + {"role": "system", "content": "You are helpful."}, + ] + result = _consolidate_system_messages(conversation) + assert result[0]["role"] == "system" + assert result[0]["content"] == "You are helpful." + assert result[1]["role"] == "user" + assert result[1]["content"] == "Hello" + + def test_multiple_system_messages_merged(self): + conversation = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + {"role": "system", "content": "Be concise."}, + ] + result = _consolidate_system_messages(conversation) + assert len(result) == 2 + assert result[0]["role"] == "system" + assert result[0]["content"] == "You are helpful.\n\nBe concise." + assert result[1]["role"] == "user" + + def test_list_content_handled(self): + conversation = [ + {"role": "user", "content": "Hello"}, + { + "role": "system", + "content": [ + {"type": "text", "text": "Rule 1."}, + {"type": "text", "text": "Rule 2."}, + ], + }, + ] + result = _consolidate_system_messages(conversation) + assert result[0]["role"] == "system" + assert result[0]["content"] == "Rule 1.\nRule 2." + assert result[1]["role"] == "user" + + def test_does_not_mutate_original(self): + conversation = [ + {"role": "user", "content": "Hello"}, + {"role": "system", "content": "You are helpful."}, + ] + original_len = len(conversation) + _consolidate_system_messages(conversation) + assert len(conversation) == original_len + assert conversation[0]["role"] == "user" + assert conversation[1]["role"] == "system" diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 2d9834d2e3a..7213a669c53 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -24,6 +24,7 @@ from vllm.utils.hashing import sha256 from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.core.single_type_kv_cache_manager import register_all_kvcache_specs from vllm.v1.kv_cache_interface import ( FullAttentionSpec, KVCacheConfig, @@ -160,6 +161,7 @@ def create_scheduler( ], ) cache_config.num_gpu_blocks = num_blocks + register_all_kvcache_specs(vllm_config) scheduler_cls = AsyncScheduler if async_scheduling else Scheduler return scheduler_cls( vllm_config=vllm_config, diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh index 040632249d3..d0a56304f2a 100755 --- a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh @@ -21,19 +21,20 @@ dp_ep_configs=( "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1) "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1) ) +# We assume HMA enabled by default. hybrid_ssm_configs=( - "VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code" + "VLLM_SSM_CONV_STATE_LAYOUT=DS GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code" # TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models. - "VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling" + "VLLM_SSM_CONV_STATE_LAYOUT=DS PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling" # GDN (Qwen3.5) - "VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=Qwen/Qwen3.5-0.8B" - "VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=Qwen/Qwen3.5-0.8B VLLM_SERVE_EXTRA_ARGS=--no-async-scheduling" + "VLLM_SSM_CONV_STATE_LAYOUT=DS GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=Qwen/Qwen3.5-0.8B" + "VLLM_SSM_CONV_STATE_LAYOUT=DS PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=Qwen/Qwen3.5-0.8B VLLM_SERVE_EXTRA_ARGS=--no-async-scheduling" ) sw_attn_configs=( # NOTE: gemma3 does not work with FlashInfer "GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192" # SW model - "ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192" - "ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192" + "GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192" + "GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192" ) # Select config array based on DP_EP env var @@ -50,14 +51,6 @@ else configs=("${tp_configs[@]}") fi -if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then - # Append ENABLE_HMA_FLAG=1 to each config in the selected array - echo "ENABLE_HMA_FLAG is set, appending ENABLE_HMA_FLAG=1 to each config" - for i in "${!configs[@]}"; do - configs[$i]="ENABLE_HMA_FLAG=1 ${configs[$i]}" - done -fi - run_tests() { local label=$1 local extra_args=$2 diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_spec_decode_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_spec_decode_test.sh index 313efc3968d..f55bd308a0a 100755 --- a/tests/v1/kv_connector/nixl_integration/config_sweep_spec_decode_test.sh +++ b/tests/v1/kv_connector/nixl_integration/config_sweep_spec_decode_test.sh @@ -11,7 +11,7 @@ SCRIPT="v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh" eagle3_config="SD_METHOD=eagle3 MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct SD_MODEL=RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3 NUM_SPEC_TOKENS=3" # MTP: Qwen3.5-0.8B-Base with hybrid SSM flags. -mtp_config="SD_METHOD=mtp MODEL_NAME=Qwen/Qwen3.5-0.8B-Base SD_MODEL=Qwen/Qwen3.5-0.8B-Base NUM_SPEC_TOKENS=1 BLOCK_SIZE=32 MAX_MODEL_LEN=4096 VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 KV_BUFFER_DEVICES=cuda" +mtp_config="SD_METHOD=mtp MODEL_NAME=Qwen/Qwen3.5-0.8B-Base SD_MODEL=Qwen/Qwen3.5-0.8B-Base NUM_SPEC_TOKENS=1 BLOCK_SIZE=32 MAX_MODEL_LEN=4096 VLLM_SSM_CONV_STATE_LAYOUT=DS KV_BUFFER_DEVICES=cuda" configs=( "$eagle3_config" diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index fc446a0e765..bde246c9b66 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -5,11 +5,6 @@ set -xe KV_BUFFER_DEVICE="cuda" # Default to cuda ATTENTION_BACKEND="" # Default to empty (use vllm default) CROSS_LAYERS_BLOCKS="False" -ENABLE_HMA_VAR="" # Default to empty (HMA disabled by default for kv connector) -# Check for ENABLE_HMA_FLAG environment variable -if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then - ENABLE_HMA_VAR="--no-disable-hybrid-kv-cache-manager" -fi while [[ $# -gt 0 ]]; do case $1 in @@ -37,9 +32,6 @@ echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE" if [[ -n "$ATTENTION_BACKEND" ]]; then echo "Using attention backend: $ATTENTION_BACKEND" fi -if [[ -n "$ENABLE_HMA_VAR" ]]; then - echo "HMA (Hybrid KV Cache Manager) enabled" -fi if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then echo "vLLM serve extra args: $VLLM_SERVE_EXTRA_ARGS" fi @@ -180,10 +172,6 @@ run_tests_for_model() { BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND" fi - # Add HMA flag if specified - if [[ -n "$ENABLE_HMA_VAR" ]]; then - BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR" - fi FULL_CMD="$BASE_CMD" eval "$FULL_CMD &" @@ -232,10 +220,6 @@ run_tests_for_model() { BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND" fi - # Add HMA flag if specified - if [[ -n "$ENABLE_HMA_VAR" ]]; then - BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR" - fi # DP-EP attention mode if [[ -z "$DP_EP" ]]; then diff --git a/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh b/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh index 2c5622a2f0e..bc90680a533 100755 --- a/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh +++ b/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh @@ -27,7 +27,6 @@ # ROCM_AITER_UNIFIED_ATTN # NVIDIA options: FLASH_ATTN, FLASHINFER # VLLM_SSM_CONV_STATE_LAYOUT - SSM conv state layout (e.g. "DS" required for Mamba models) -# ENABLE_HMA_FLAG - set to 1 to enable hybrid KV cache manager # VLLM_SERVE_EXTRA_ARGS - comma-separated extra args for vllm serve set -ex @@ -85,13 +84,7 @@ if [[ -z "${ATTENTION_BACKEND:-}" ]]; then fi echo "Using attention backend: ${ATTENTION_BACKEND}" -# ── HMA & extra serve args ──────────────────────────────────────────── - -ENABLE_HMA_VAR="" -if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then - ENABLE_HMA_VAR="--no-disable-hybrid-kv-cache-manager" - echo "HMA (Hybrid KV Cache Manager) enabled" -fi +# ── Extra serve args ───────────────────────────────────────────────── EXTRA_SERVE_ARGS=() if [[ -n "${VLLM_SERVE_EXTRA_ARGS:-}" ]]; then @@ -258,7 +251,6 @@ run_test_for_device() { --kv-transfer-config "$kv_config" \ --speculative-config "$PREFILL_SPEC_CONFIG" \ --attention-backend $ATTENTION_BACKEND \ - ${ENABLE_HMA_VAR} \ ${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} & local SERVER_PID=$! @@ -298,7 +290,6 @@ run_test_for_device() { --kv-transfer-config "$kv_config" \ --speculative-config "$DECODE_SPEC_CONFIG" \ --attention-backend $ATTENTION_BACKEND \ - ${ENABLE_HMA_VAR} \ ${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} & local SERVER_PID=$! diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 6d4e6565e37..8d54353f82a 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -386,8 +386,6 @@ def test_fewer_blocks_with_hma(monkeypatch, model_name, sw_size): "kv_transfer_config": kv_transfer_config, "max_model_len": 2048, "max_num_seqs": 1, - # NOTE: Make sure HMA is enabled - "disable_hybrid_kv_cache_manager": False, "max_num_batched_tokens": 2048, "enable_prefix_caching": False, "block_size": block_size, diff --git a/tests/v1/simple_kv_offload/test_scheduler.py b/tests/v1/simple_kv_offload/test_scheduler.py index 970e16e5279..e59905f504a 100644 --- a/tests/v1/simple_kv_offload/test_scheduler.py +++ b/tests/v1/simple_kv_offload/test_scheduler.py @@ -30,6 +30,9 @@ from vllm.v1.core.sched.output import ( NewRequestData, SchedulerOutput, ) +from vllm.v1.core.single_type_kv_cache_manager import ( + register_all_kvcache_specs, +) from vllm.v1.kv_cache_interface import ( FullAttentionSpec, KVCacheConfig, @@ -68,6 +71,9 @@ def _make_kv_cache_config( """Build a KVCacheConfig with non-empty kv_cache_tensors.""" groups = [] tensors = [] + register_all_kvcache_specs( + vllm_config=None + ) # Ensure specs are registered for tests for g in range(num_groups): layer_names = [f"layer_{g}"] groups.append( diff --git a/tests/v1/test_kv_cache_spec_registry.py b/tests/v1/test_kv_cache_spec_registry.py new file mode 100644 index 00000000000..c84c3ef1c8a --- /dev/null +++ b/tests/v1/test_kv_cache_spec_registry.py @@ -0,0 +1,322 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass, replace +from typing import Any + +import pytest +import torch + +from vllm.config import ( + CacheConfig, + DeviceConfig, + VllmConfig, +) +from vllm.v1.core.single_type_kv_cache_manager import ( + ChunkedLocalAttentionManager, + CrossAttentionManager, + FullAttentionManager, + MambaManager, + SingleTypeKVCacheManager, + SinkFullAttentionManager, + SlidingWindowManager, + register_all_kvcache_specs, +) +from vllm.v1.kv_cache_interface import ( + ChunkedLocalAttentionSpec, + CrossAttentionSpec, + FullAttentionSpec, + HiddenStateCacheSpec, + KVCacheSpec, + MambaSpec, + MLAAttentionSpec, + SinkFullAttentionSpec, + SlidingWindowMLASpec, + SlidingWindowSpec, + TQFullAttentionSpec, + UniformTypeKVCacheSpecs, +) +from vllm.v1.kv_cache_spec_registry import ( + _REGISTRY_KVCACHESPEC_LIST, + KVCacheSpecRegistry, + register_kv_cache_spec, +) + + +def make_vllm_config() -> VllmConfig: + return VllmConfig( + cache_config=CacheConfig( + block_size=64, + cache_dtype="bfloat16", + ), + device_config=DeviceConfig(device="cpu"), + ) + + +vllm_config = make_vllm_config() +register_all_kvcache_specs(vllm_config) + + +@pytest.fixture(autouse=True) +def restore_kv_cache_spec_registry(): + registry = _REGISTRY_KVCACHESPEC_LIST.copy() + yield + _REGISTRY_KVCACHESPEC_LIST.clear() + _REGISTRY_KVCACHESPEC_LIST.update(registry) + + +@dataclass(frozen=True) +class _TrulyUnregisteredSpec(KVCacheSpec): + """ + A spec that inherits directly from KVCacheSpec with no registered + ancestor in the MRO. Used to test that the registry correctly raises + when no entry can be found. + """ + + @property + def page_size_bytes(self) -> int: + return self.block_size * 128 + + def max_memory_usage_bytes(self, _) -> int: + return 0 + + +spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { + FullAttentionSpec: FullAttentionManager, + TQFullAttentionSpec: FullAttentionManager, + MLAAttentionSpec: FullAttentionManager, + HiddenStateCacheSpec: FullAttentionManager, + SlidingWindowSpec: SlidingWindowManager, + SlidingWindowMLASpec: SlidingWindowManager, + ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, + MambaSpec: MambaManager, + CrossAttentionSpec: CrossAttentionManager, + SinkFullAttentionSpec: SinkFullAttentionManager, +} + +spec_uniform_base_map: dict[type[KVCacheSpec], type[KVCacheSpec]] = { + FullAttentionSpec: FullAttentionSpec, + TQFullAttentionSpec: FullAttentionSpec, + MLAAttentionSpec: FullAttentionSpec, + HiddenStateCacheSpec: FullAttentionSpec, + SlidingWindowSpec: SlidingWindowSpec, + SlidingWindowMLASpec: SlidingWindowMLASpec, + ChunkedLocalAttentionSpec: ChunkedLocalAttentionSpec, + MambaSpec: MambaSpec, + CrossAttentionSpec: CrossAttentionSpec, + SinkFullAttentionSpec: FullAttentionSpec, +} + +spec_args_map: dict[type[KVCacheSpec], dict[str, Any]] = { + FullAttentionSpec: dict( + block_size=64, num_kv_heads=8, head_size=128, dtype=torch.bfloat16 + ), + TQFullAttentionSpec: dict( + block_size=64, + num_kv_heads=8, + head_size=128, + dtype=torch.bfloat16, + tq_slot_size=256, + ), + MLAAttentionSpec: dict( + block_size=64, num_kv_heads=1, head_size=128, dtype=torch.bfloat16 + ), + HiddenStateCacheSpec: dict( + block_size=64, num_kv_heads=1, head_size=128, dtype=torch.bfloat16 + ), + SlidingWindowSpec: dict( + block_size=64, + num_kv_heads=8, + head_size=128, + dtype=torch.bfloat16, + sliding_window=128, + ), + SlidingWindowMLASpec: dict( + block_size=64, + num_kv_heads=1, + head_size=128, + dtype=torch.bfloat16, + sliding_window=128, + ), + ChunkedLocalAttentionSpec: dict( + block_size=64, + num_kv_heads=8, + head_size=128, + dtype=torch.bfloat16, + attention_chunk_size=4, + ), + MambaSpec: dict( + block_size=64, + shapes=((2, 512), (3, 32, 32)), + dtypes=(torch.float32, torch.float32), + mamba_cache_mode="align", + num_speculative_blocks=2, + ), + CrossAttentionSpec: dict( + block_size=64, num_kv_heads=8, head_size=128, dtype=torch.bfloat16 + ), + SinkFullAttentionSpec: dict( + block_size=64, num_kv_heads=8, head_size=128, dtype=torch.bfloat16, sink_len=16 + ), +} + + +def make_spec(spec_cls: type[KVCacheSpec]) -> KVCacheSpec: + return spec_cls(**spec_args_map[spec_cls]) + + +def are_uniform_specs(*specs: KVCacheSpec) -> bool: + return UniformTypeKVCacheSpecs.is_uniform_type( + {f"layer_{i}": spec for i, spec in enumerate(specs)} + ) + + +class TestKVCacheSpecRegistry: + """Test the core registry functionality.""" + + def test_builtin_kvcache_specs_registered(self): + assert set(spec_manager_map) <= set(_REGISTRY_KVCACHESPEC_LIST) + for spec_cls, manager in spec_manager_map.items(): + spec = make_spec(spec_cls) + assert KVCacheSpecRegistry.get_manager_class(spec) is manager + assert ( + KVCacheSpecRegistry.get_uniform_type_base_spec(spec) + is spec_uniform_base_map[spec_cls] + ) + + @pytest.mark.parametrize("spec_cls", list(spec_manager_map)) + def test_custom_spec_register(self, spec_cls): + """A decorated custom spec resolves to the declared manager.""" + manager = spec_manager_map[spec_cls] + uniform_base_spec = spec_uniform_base_map[spec_cls] + + @register_kv_cache_spec( + manager_class=manager, + uniform_type_base_spec=uniform_base_spec, + ) + @dataclass(frozen=True, kw_only=True) + class _CustomSpec(spec_cls): # type: ignore[valid-type,misc] + custom_param: int = 16 + + spec = _CustomSpec(**spec_args_map[spec_cls], custom_param=100) + + assert KVCacheSpecRegistry.get_manager_class(spec) is manager + assert KVCacheSpecRegistry.get_uniform_type_base_spec(spec) is uniform_base_spec + + def test_custom_spec_register_requires_manager(self): + """Invalid register decorator arguments fail early.""" + + with pytest.raises(AssertionError, match="manager_class is required"): + + @register_kv_cache_spec( + uniform_type_base_spec=FullAttentionSpec, + ) + @dataclass(frozen=True, kw_only=True) + class _CustomFullSpecWithoutManager(FullAttentionSpec): + custom_param: int = 16 + + def test_unregistered_spec_no_registered_parent_raises(self): + """ + A spec whose entire MRO contains no registered class resolves to None. + Runtime callers should use check_kv_cache_spec_registry to fail early. + Subclasses of registered specs intentionally do not fail — they inherit + their parent's manager via MRO walking. + """ + spec = _TrulyUnregisteredSpec(block_size=16) + + assert KVCacheSpecRegistry.get_manager_class(spec) is None + assert KVCacheSpecRegistry.get_uniform_type_base_spec(spec) is None + + with pytest.raises( + ValueError, match="Unsupported KV cache spec type for layer layer_0" + ): + KVCacheSpecRegistry.check_kv_cache_spec_registry({"layer_0": spec}) + + with pytest.raises(AssertionError, match="Unsupported KV cache spec type"): + UniformTypeKVCacheSpecs.is_uniform_type({"layer_0": spec}) + + def test_unregistered_subclass_inherits_parent_manager(self): + """ + An unregistered subclass of a registered spec resolves via MRO + to its parent's manager — this is intentional registry behaviour. + """ + + @dataclass(frozen=True, kw_only=True) + class _ImplicitlyInheritedSpec(FullAttentionSpec): + pass + + spec = _ImplicitlyInheritedSpec( + block_size=16, num_kv_heads=8, head_size=128, dtype=torch.bfloat16 + ) + + # MRO walk finds FullAttentionSpec → FullAttentionManager + assert KVCacheSpecRegistry.get_manager_class(spec) is FullAttentionManager + + @pytest.mark.parametrize("spec_cls", list(spec_manager_map)) + def test_builtin_specs_are_uniform_with_same_spec_type(self, spec_cls): + spec = make_spec(spec_cls) + assert are_uniform_specs(spec, replace(spec)) + + def test_full_attention_family_specs_are_uniform(self): + specs = [ + make_spec(FullAttentionSpec), + make_spec(TQFullAttentionSpec), + make_spec(MLAAttentionSpec), + make_spec(HiddenStateCacheSpec), + make_spec(SinkFullAttentionSpec), + ] + + assert are_uniform_specs(*specs) + + @pytest.mark.parametrize( + ("spec_cls", "field", "value"), + [ + (SlidingWindowSpec, "sliding_window", 256), + (SlidingWindowMLASpec, "sliding_window", 256), + (ChunkedLocalAttentionSpec, "attention_chunk_size", 8), + (MambaSpec, "num_speculative_blocks", 4), + ], + ) + def test_specs_with_type_specific_uniform_fields(self, spec_cls, field, value): + spec = make_spec(spec_cls) + changed_spec = replace(spec, **{field: value}) + + assert not are_uniform_specs(spec, changed_spec) + + @pytest.mark.parametrize( + ("left_cls", "right_cls"), + [ + (FullAttentionSpec, CrossAttentionSpec), + (FullAttentionSpec, SlidingWindowSpec), + (FullAttentionSpec, ChunkedLocalAttentionSpec), + (FullAttentionSpec, MambaSpec), + (SlidingWindowMLASpec, SlidingWindowSpec), + (ChunkedLocalAttentionSpec, SlidingWindowSpec), + (MambaSpec, CrossAttentionSpec), + ], + ) + def test_different_uniform_groups_are_not_uniform(self, left_cls, right_cls): + assert not are_uniform_specs(make_spec(left_cls), make_spec(right_cls)) + + def test_different_block_sizes_are_not_uniform(self): + spec = make_spec(FullAttentionSpec) + + assert not are_uniform_specs(spec, replace(spec, block_size=32)) + + def test_registered_custom_spec_uses_base_uniform_rule(self): + @register_kv_cache_spec( + manager_class=FullAttentionManager, + uniform_type_base_spec=FullAttentionSpec, + ) + @dataclass(frozen=True, kw_only=True) + class _CustomFullSpec(FullAttentionSpec): + custom_param: int = 16 + + custom_spec = _CustomFullSpec( + block_size=64, + num_kv_heads=8, + head_size=128, + dtype=torch.bfloat16, + ) + + assert are_uniform_specs(custom_spec, make_spec(FullAttentionSpec)) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index fb6951ea7dd..7900c948480 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -53,7 +53,7 @@ class SchedulerConfig: In real usage, this should be set in `EngineArgs.create_engine_config`. """ - max_num_scheduled_tokens: int | None = None + max_num_scheduled_tokens: int | None = Field(default=None, ge=0) """Maximum number of tokens that the scheduler may issue in a single iteration. This is usually equal to max_num_batched_tokens, but can be smaller in cases diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index dc7e6d151a4..9482568461c 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -38,9 +38,19 @@ from vllm.utils.network_utils import ( is_valid_ipv6_address, ) -if envs.VLLM_USE_SPINLOOP_EXT: - from vllm.spinloop import spinloop +logger = init_logger(__name__) + +SPINLOOP_EXT_ENABLED = False +if envs.VLLM_USE_SPINLOOP_EXT: + try: + from vllm.spinloop import spinloop + + SPINLOOP_EXT_ENABLED = True + except ImportError: + logger.warning( + "spinloop extension could not be loaded, disabling VLLM_USE_SPINLOOP_EXT!" + ) SPINLOOP_TIMEOUT_SECONDS = 0.1 if TYPE_CHECKING: @@ -82,9 +92,6 @@ def to_bytes_big(value: int, size: int) -> bytes: return value.to_bytes(size, byteorder="big") -logger = init_logger(__name__) - - LONG_WAIT_TIME_LOG_MSG = ( "No available shared memory broadcast block found " "in %d seconds. This typically happens " @@ -552,7 +559,7 @@ class MessageQueue: written_flag = metadata_buffer[0] return not (written_flag and read_count != self.buffer.n_reader) - if envs.VLLM_USE_SPINLOOP_EXT and not check(): + if SPINLOOP_EXT_ENABLED and not check(): spinloop(metadata_buffer, check, timeout=SPINLOOP_TIMEOUT_SECONDS) if not check(): @@ -673,7 +680,7 @@ class MessageQueue: written_flag = metadata_buffer[0] return not (not written_flag or read_flag) - if envs.VLLM_USE_SPINLOOP_EXT and not check(): + if SPINLOOP_EXT_ENABLED and not check(): spinloop( metadata_buffer[0 : self.local_reader_rank + 1], check, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/coordinator.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/coordinator.py index a17f0b5f5ff..ad528140966 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/coordinator.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/coordinator.py @@ -13,7 +13,6 @@ from vllm.v1.core.kv_cache_utils import ( ) from vllm.v1.core.single_type_kv_cache_manager import ( SingleTypeKVCacheManager, - spec_manager_map, ) from vllm.v1.kv_cache_interface import ( FullAttentionSpec, @@ -21,6 +20,7 @@ from vllm.v1.kv_cache_interface import ( KVCacheSpec, UniformTypeKVCacheSpecs, ) +from vllm.v1.kv_cache_spec_registry import KVCacheSpecRegistry # Dummy placeholder hash for store_mask's template computation. _DUMMY_BLOCK_HASH = BlockHash(b"\x00" * 32) @@ -89,7 +89,10 @@ class MooncakeStoreCoordinator: ] = [] for i, g in enumerate(self.kv_cache_groups): spec = _unwrap_spec(g.kv_cache_spec) - manager_cls = spec_manager_map[type(spec)] + manager_cls = KVCacheSpecRegistry.get_manager_class(spec) + assert manager_cls is not None, ( + f"No manager registered for KVCacheSpec {spec}" + ) for existing_spec, group_ids, existing_cls in attention_groups: if existing_spec == spec: assert manager_cls is existing_cls diff --git a/vllm/entrypoints/openai/models/serving.py b/vllm/entrypoints/openai/models/serving.py index 347752c912c..504d30f69d2 100644 --- a/vllm/entrypoints/openai/models/serving.py +++ b/vllm/entrypoints/openai/models/serving.py @@ -194,10 +194,16 @@ class OpenAIServingModels: lora_request.lora_name, lora_request.lora_path ) ) in str(e): - raise LoRAAdapterNotFoundError( - lora_request.lora_name, lora_request.lora_path - ) from e - raise + return create_error_response( + LoRAAdapterNotFoundError( + lora_request.lora_name, lora_request.lora_path + ) + ) + return create_error_response( + message=str(e), + err_type="InternalServerError", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) self.lora_requests[lora_name] = lora_request logger.info( diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 546361229e1..b357c5798bf 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -698,6 +698,13 @@ class Platform: mamba_padding_pct, ) + @classmethod + def register_custom_kv_cache_specs(cls, vllm_config: "VllmConfig") -> None: + """ + Register custom KVCacheSpec class on current platform. + """ + pass + @classmethod def verify_model_arch(cls, model_arch: str) -> None: """ diff --git a/vllm/renderers/hf.py b/vllm/renderers/hf.py index e796607722a..e57d0586aa0 100644 --- a/vllm/renderers/hf.py +++ b/vllm/renderers/hf.py @@ -450,6 +450,66 @@ def _detect_content_format( return "openai" +@lru_cache(maxsize=32) +def _detect_developer_role_support(chat_template: str) -> bool: + return '"developer"' in chat_template or "'developer'" in chat_template + + +def _convert_developer_to_system( + conversation: list[ConversationMessage], +) -> list[ConversationMessage]: + converted: list[ConversationMessage] = [] + for msg in conversation: + if msg["role"] == "developer": + new_msg = dict(msg) + new_msg["role"] = "system" + new_msg.pop("tools", None) + converted.append(new_msg) # type: ignore[arg-type] + else: + converted.append(msg) + return converted + + +def _consolidate_system_messages( + conversation: list[ConversationMessage], +) -> list[ConversationMessage]: + """Merge all system messages into one at position 0. + + Some chat templates (e.g. Qwen 3.6) require the system message to be the + very first message. After developer-to-system conversion, system messages + may appear at non-first positions; this merges them into a single message. + """ + system_contents: list[str] = [] + non_system: list[ConversationMessage] = [] + needs_consolidation = False + for i, msg in enumerate(conversation): + if msg["role"] == "system": + if i > 0 or system_contents: + needs_consolidation = True + content = msg.get("content", "") + if isinstance(content, list): + parts = [] + for part in content: + if isinstance(part, dict) and "text" in part: + parts.append(part["text"]) + elif isinstance(part, str): + parts.append(part) + content = "\n".join(parts) + if content: + system_contents.append(content) + else: + non_system.append(msg) + + if not needs_consolidation: + return conversation + + merged: ConversationMessage = { + "role": "system", + "content": "\n\n".join(system_contents), + } + return [merged, *non_system] + + def _resolve_chat_template_content_format( chat_template: str | None, tools: list[dict[str, Any]] | None, @@ -653,7 +713,15 @@ def safe_apply_chat_template( "allowed, so you must provide a chat template if the tokenizer " "does not define one." ) - + if any( + msg["role"] == "developer" for msg in conversation + ) and not _detect_developer_role_support(chat_template): + conversation = _convert_developer_to_system(conversation) + conversation = _consolidate_system_messages(conversation) + logger.info_once( + "Chat template does not support the 'developer' message role. " + "Converting developer messages to 'system' role.", + ) resolved_kwargs = resolve_chat_template_kwargs( tokenizer=tokenizer, chat_template=chat_template, diff --git a/vllm/tool_parsers/glm4_moe_tool_parser.py b/vllm/tool_parsers/glm4_moe_tool_parser.py index 1779896e5b6..213a774535b 100644 --- a/vllm/tool_parsers/glm4_moe_tool_parser.py +++ b/vllm/tool_parsers/glm4_moe_tool_parser.py @@ -11,7 +11,6 @@ The fix streams string values incrementally as they arrive, providing a true streaming experience for long content. """ -import ast import json from collections.abc import Sequence from typing import Any @@ -42,6 +41,7 @@ from vllm.tool_parsers.utils import ( extract_types_from_schema, find_tool_properties, partial_tag_overlap, + safe_literal_eval, ) logger = init_logger(__name__) @@ -110,7 +110,7 @@ class Glm4MoeModelToolParser(ToolParser): pass try: - return ast.literal_eval(value) + return safe_literal_eval(value) except (ValueError, SyntaxError): pass diff --git a/vllm/tool_parsers/hy_v3_tool_parser.py b/vllm/tool_parsers/hy_v3_tool_parser.py index 496deb4f2d5..619be5e9cc2 100644 --- a/vllm/tool_parsers/hy_v3_tool_parser.py +++ b/vllm/tool_parsers/hy_v3_tool_parser.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import ast import json from collections.abc import Sequence from typing import Any @@ -27,6 +26,7 @@ from vllm.tool_parsers.abstract_tool_parser import ( Tool, ToolParser, ) +from vllm.tool_parsers.utils import safe_literal_eval logger = init_logger(__name__) @@ -183,13 +183,13 @@ class HYV3ToolParser(ToolParser): @staticmethod def _deserialize(value: str) -> Any: - """Deserialize a string value using json.loads then ast.literal_eval.""" + """Deserialize a string value using json.loads then safe_literal_eval.""" try: return json.loads(value) except Exception: pass try: - return ast.literal_eval(value) + return safe_literal_eval(value) except Exception: pass return value diff --git a/vllm/tool_parsers/minicpm5xml_tool_parser.py b/vllm/tool_parsers/minicpm5xml_tool_parser.py index fed9677411b..a5b5252415c 100644 --- a/vllm/tool_parsers/minicpm5xml_tool_parser.py +++ b/vllm/tool_parsers/minicpm5xml_tool_parser.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import ast import json from collections.abc import Sequence from typing import Any @@ -28,7 +27,7 @@ from vllm.tool_parsers.abstract_tool_parser import ( Tool, ToolParser, ) -from vllm.tool_parsers.utils import partial_tag_overlap +from vllm.tool_parsers.utils import partial_tag_overlap, safe_literal_eval from vllm.utils import random_uuid logger = init_logger(__name__) @@ -116,7 +115,7 @@ def _parse_arguments(json_value: str) -> tuple[Any, bool]: try: parsed_value = json.loads(json_value) except json.JSONDecodeError: - parsed_value = ast.literal_eval(json_value) + parsed_value = safe_literal_eval(json_value) return parsed_value, True except Exception: return json_value, False diff --git a/vllm/tool_parsers/poolside_v1_tool_parser.py b/vllm/tool_parsers/poolside_v1_tool_parser.py index f14b4736291..e515e1ce637 100644 --- a/vllm/tool_parsers/poolside_v1_tool_parser.py +++ b/vllm/tool_parsers/poolside_v1_tool_parser.py @@ -11,7 +11,6 @@ The fix streams string values incrementally as they arrive, providing a true streaming experience for long content. """ -import ast import json from collections.abc import Sequence from typing import Any @@ -41,6 +40,7 @@ from vllm.tool_parsers.abstract_tool_parser import ( Tool, ToolParser, ) +from vllm.tool_parsers.utils import safe_literal_eval logger = init_logger(__name__) @@ -106,7 +106,7 @@ class PoolsideV1ToolParser(ToolParser): pass try: - return ast.literal_eval(value) + return safe_literal_eval(value) except (ValueError, SyntaxError): pass diff --git a/vllm/tool_parsers/qwen3xml_tool_parser.py b/vllm/tool_parsers/qwen3xml_tool_parser.py index d5b87ea074e..e5d2b896e00 100644 --- a/vllm/tool_parsers/qwen3xml_tool_parser.py +++ b/vllm/tool_parsers/qwen3xml_tool_parser.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import ast import json from collections.abc import Sequence from typing import Any @@ -26,7 +25,7 @@ from vllm.tool_parsers.abstract_tool_parser import ( Tool, ToolParser, ) -from vllm.tool_parsers.utils import find_tool_properties +from vllm.tool_parsers.utils import find_tool_properties, safe_literal_eval logger = init_logger(__name__) @@ -824,7 +823,7 @@ class StreamingXMLToolCallParser: try: parsed_value = json.loads(raw_for_parse) except json.JSONDecodeError: - parsed_value = ast.literal_eval(raw_for_parse) + parsed_value = safe_literal_eval(raw_for_parse) output_arguments = json.dumps(parsed_value, ensure_ascii=False) except Exception: # Fallback: output as string as-is diff --git a/vllm/tool_parsers/step3p5_tool_parser.py b/vllm/tool_parsers/step3p5_tool_parser.py index b46f899ce2c..8a48e2686e5 100644 --- a/vllm/tool_parsers/step3p5_tool_parser.py +++ b/vllm/tool_parsers/step3p5_tool_parser.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import ast import json from collections.abc import Sequence from typing import Any @@ -23,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import ( from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser +from vllm.tool_parsers.utils import safe_literal_eval logger = init_logger(__name__) @@ -1016,7 +1016,7 @@ class StreamingXMLToolCallParser: raw_for_parse = raw_text + "\n" else: raw_for_parse = raw_text - parsed_value = ast.literal_eval(raw_for_parse) + parsed_value = safe_literal_eval(raw_for_parse) output_arguments = json.dumps(parsed_value, ensure_ascii=False) except Exception: # Fallback: output as string as-is diff --git a/vllm/tool_parsers/utils.py b/vllm/tool_parsers/utils.py index 1c7830b320f..6ee107433c5 100644 --- a/vllm/tool_parsers/utils.py +++ b/vllm/tool_parsers/utils.py @@ -3,6 +3,7 @@ import ast import json +import warnings from json import JSONDecodeError, JSONDecoder from typing import Any, TypeAlias @@ -31,6 +32,12 @@ Tool: TypeAlias = ChatCompletionToolsParam | ResponsesTool logger = init_logger(__name__) +def safe_literal_eval(text: str): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", SyntaxWarning) + return ast.literal_eval(text) + + def partial_tag_overlap(text: str, tag: str) -> int: """Length of the longest prefix of *tag* that matches a suffix of *text*. diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 7f3a5e4fdf3..cfa79f077a1 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -33,6 +33,7 @@ from vllm.v1.kv_cache_interface import ( SlidingWindowSpec, UniformTypeKVCacheSpecs, ) +from vllm.v1.kv_cache_spec_registry import KVCacheSpecRegistry from vllm.v1.request import Request from vllm.v1.utils import tensor_data @@ -1991,6 +1992,9 @@ def get_kv_cache_configs( "across workers. This is not supported yet." ) + # Check if the KV cache specs are registered correctly. + # This is to prevent that some layers are initialized with unregistered specs. + KVCacheSpecRegistry.check_kv_cache_spec_registry(merged_kv_cache_specs) # Get global KV cache groups. This also handles spec unification for # hybrid models when disable_hybrid_kv_cache_manager is enabled. # After this call, merged_kv_cache_specs may be modified in-place. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 40d7fc8b336..22e2858f908 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -104,7 +104,7 @@ class Scheduler(SchedulerInterface): self.max_num_running_reqs = self.scheduler_config.max_num_seqs self.max_num_scheduled_tokens = ( self.scheduler_config.max_num_scheduled_tokens - if self.scheduler_config.max_num_scheduled_tokens + if self.scheduler_config.max_num_scheduled_tokens is not None else self.scheduler_config.max_num_batched_tokens ) self.max_model_len = vllm_config.model_config.max_model_len diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 7b16d9c6f05..281b79639db 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -25,6 +25,7 @@ from vllm.v1.kv_cache_interface import ( SlidingWindowSpec, TQFullAttentionSpec, ) +from vllm.v1.kv_cache_spec_registry import KVCacheSpecRegistry from vllm.v1.request import Request @@ -1247,27 +1248,30 @@ class SinkFullAttentionManager(FullAttentionManager): self.sink_blocks = self.block_pool.free_block_queue.popleft_n(num_sink_block) -spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { - FullAttentionSpec: FullAttentionManager, - TQFullAttentionSpec: FullAttentionManager, - MLAAttentionSpec: FullAttentionManager, - HiddenStateCacheSpec: FullAttentionManager, - SlidingWindowSpec: SlidingWindowManager, - SlidingWindowMLASpec: SlidingWindowManager, - ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, - MambaSpec: MambaManager, - CrossAttentionSpec: CrossAttentionManager, - SinkFullAttentionSpec: SinkFullAttentionManager, -} - - def get_manager_for_kv_cache_spec( kv_cache_spec: KVCacheSpec, max_num_batched_tokens: int, max_model_len: int, **kwargs, ) -> SingleTypeKVCacheManager: - manager_class = spec_manager_map[type(kv_cache_spec)] + """ + Get the appropriate manager for a given KVCacheSpec. + + Uses the KVCacheSpecRegistry to look up the manager class, supporting + both built-in and custom specs registered via @register_kv_cache_spec + and KVCacheSpecRegistry.register. + + Args: + kv_cache_spec: The KVCacheSpec instance + max_num_batched_tokens: The maximum number of tokens in a batch + max_model_len: The maximum context length the model could serve + Returns: + An instance of the appropriate SingleTypeKVCacheManager subclass + """ + manager_class = KVCacheSpecRegistry.get_manager_class(kv_cache_spec) + assert manager_class is not None, ( + f"No manager registered for KVCacheSpec {type(kv_cache_spec)}" + ) # SlidingWindow / ChunkedLocalAttention managers recycle blocks across # chunks; the runtime admission cap must match the recycling-aware bound # the startup pool sizer uses (single source of truth: the spec method). @@ -1280,3 +1284,64 @@ def get_manager_for_kv_cache_spec( ) manager = manager_class(kv_cache_spec, **kwargs) return manager + + +def register_all_kvcache_specs(vllm_config): + """Built-in spec registration""" + KVCacheSpecRegistry.register( + FullAttentionSpec, + FullAttentionManager, + uniform_type_base_spec=FullAttentionSpec, + ) + + KVCacheSpecRegistry.register( + SlidingWindowSpec, + SlidingWindowManager, + uniform_type_base_spec=SlidingWindowSpec, + ) + KVCacheSpecRegistry.register( + SlidingWindowMLASpec, + SlidingWindowManager, + uniform_type_base_spec=SlidingWindowMLASpec, + ) + + KVCacheSpecRegistry.register( + MambaSpec, MambaManager, uniform_type_base_spec=MambaSpec + ) + KVCacheSpecRegistry.register( + ChunkedLocalAttentionSpec, + ChunkedLocalAttentionManager, + uniform_type_base_spec=ChunkedLocalAttentionSpec, + ) + KVCacheSpecRegistry.register( + CrossAttentionSpec, + CrossAttentionManager, + uniform_type_base_spec=CrossAttentionSpec, + ) + + # FullAttentionSpec subclasses — grouped with FullAttentionSpec + KVCacheSpecRegistry.register( + TQFullAttentionSpec, + FullAttentionManager, + uniform_type_base_spec=FullAttentionSpec, + ) + KVCacheSpecRegistry.register( + MLAAttentionSpec, FullAttentionManager, uniform_type_base_spec=FullAttentionSpec + ) + # NOTE(Mengqing): HiddenStateCacheSpec won't take part in + # grouping, thus the uniform_type_base_spec is just a + # placeholder. + KVCacheSpecRegistry.register( + HiddenStateCacheSpec, + FullAttentionManager, + uniform_type_base_spec=FullAttentionSpec, + ) + KVCacheSpecRegistry.register( + SinkFullAttentionSpec, + SinkFullAttentionManager, + uniform_type_base_spec=FullAttentionSpec, + ) + + from vllm.platforms import current_platform + + current_platform.register_custom_kv_cache_specs(vllm_config) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 50fd98e1fdf..b12aa9d0505 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -52,6 +52,7 @@ from vllm.v1.core.kv_cache_utils import ( ) from vllm.v1.core.sched.interface import PauseState, SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.core.single_type_kv_cache_manager import register_all_kvcache_specs from vllm.v1.engine import ( EEP_NOTIFICATION_CALL_ID, EEPNotificationType, @@ -235,6 +236,9 @@ class EngineCore: def _initialize_kv_caches(self, vllm_config: VllmConfig) -> KVCacheConfig: start = time.time() + # register all kvcache specs in enginecore process. + register_all_kvcache_specs(vllm_config) + # Get all kv cache needed by the model kv_cache_specs = self.model_executor.get_kv_cache_specs() diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 31ee89bc72a..3bbfba1a0fe 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -17,6 +17,7 @@ from vllm.logger import init_logger from vllm.utils.math_utils import cdiv, round_up from vllm.utils.torch_utils import get_dtype_size, nvfp4_kv_cache_full_dim from vllm.v1.attention.backends.registry import MambaAttentionBackendEnum +from vllm.v1.kv_cache_spec_registry import KVCacheSpecRegistry if TYPE_CHECKING: from vllm.config import VllmConfig @@ -139,6 +140,21 @@ class KVCacheSpec: ) return copy.deepcopy(specs[0]) + def is_uniform_with_collection( + self, kv_cache_specs: dict[str, KVCacheSpec] + ) -> bool: + """ + Whether this KVCacheSpec is uniform with all specs of all layers. + """ + uniform_type_base_spec = KVCacheSpecRegistry.get_uniform_type_base_spec(self) + assert uniform_type_base_spec is not None, ( + f"Unsupported KV cache spec type: {type(self)}. " + "Please register it using @register_kv_cache_spec decorator." + ) + return all( + isinstance(spec, uniform_type_base_spec) for spec in kv_cache_specs.values() + ) + @dataclass(frozen=True, kw_only=True) class AttentionSpec(KVCacheSpec): @@ -430,6 +446,15 @@ class ChunkedLocalAttentionSpec(AttentionSpec): ) return max_blocks * self.page_size_bytes + def is_uniform_with_collection( + self, kv_cache_specs: dict[str, KVCacheSpec] + ) -> bool: + return all( + isinstance(spec, ChunkedLocalAttentionSpec) + and spec.attention_chunk_size == self.attention_chunk_size + for spec in kv_cache_specs.values() + ) + @dataclass(frozen=True, kw_only=True) class SlidingWindowSpec(AttentionSpec): @@ -493,6 +518,15 @@ class SlidingWindowSpec(AttentionSpec): ) return max_blocks * self.page_size_bytes + def is_uniform_with_collection( + self, kv_cache_specs: dict[str, KVCacheSpec] + ) -> bool: + return all( + isinstance(spec, SlidingWindowSpec) + and spec.sliding_window == self.sliding_window + for spec in kv_cache_specs.values() + ) + @dataclass(frozen=True, kw_only=True) class SlidingWindowMLASpec(SlidingWindowSpec): @@ -558,6 +592,15 @@ class SlidingWindowMLASpec(SlidingWindowSpec): model_version=model_version_set.pop(), ) + def is_uniform_with_collection( + self, kv_cache_specs: dict[str, KVCacheSpec] + ) -> bool: + return all( + isinstance(spec, SlidingWindowMLASpec) + and spec.sliding_window == self.sliding_window + for spec in kv_cache_specs.values() + ) + @dataclass(frozen=True) class MambaSpec(KVCacheSpec): @@ -590,6 +633,15 @@ class MambaSpec(KVCacheSpec): else: return self.page_size_bytes * (1 + self.num_speculative_blocks) + def is_uniform_with_collection( + self, kv_cache_specs: dict[str, KVCacheSpec] + ) -> bool: + return all( + isinstance(spec, MambaSpec) + and spec.num_speculative_blocks == self.num_speculative_blocks + for spec in kv_cache_specs.values() + ) + @dataclass(frozen=True) class EncoderOnlyAttentionSpec(AttentionSpec): @@ -689,53 +741,16 @@ class UniformTypeKVCacheSpecs(KVCacheSpec): def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool: """ Whether all layers have the same type of KV cache spec. + + Uses the registry to determine grouping base classes, so custom specs + that inherit from FullAttentionSpec are treated as full attention. """ block_sizes = set(spec.block_size for spec in kv_cache_specs.values()) if len(block_sizes) > 1: # Different block sizes, not uniform. return False - one_spec = next(iter(kv_cache_specs.values())) - # NOTE: Check subclasses before parent classes since isinstance() - # returns True for subclasses. - if isinstance(one_spec, SlidingWindowMLASpec): - # SlidingWindowMLASpec is uniform if all specs are SlidingWindowMLASpec - # with the same sliding_window size. - return all( - isinstance(spec, SlidingWindowMLASpec) - and spec.sliding_window == one_spec.sliding_window - for spec in kv_cache_specs.values() - ) - elif isinstance(one_spec, FullAttentionSpec): - return all( - isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values() - ) - elif isinstance(one_spec, CrossAttentionSpec): - return all( - isinstance(spec, CrossAttentionSpec) for spec in kv_cache_specs.values() - ) - elif isinstance(one_spec, SlidingWindowSpec): - return all( - isinstance(spec, SlidingWindowSpec) - and spec.sliding_window == one_spec.sliding_window - for spec in kv_cache_specs.values() - ) - elif isinstance(one_spec, ChunkedLocalAttentionSpec): - return all( - isinstance(spec, ChunkedLocalAttentionSpec) - and spec.attention_chunk_size == one_spec.attention_chunk_size - for spec in kv_cache_specs.values() - ) - elif isinstance(one_spec, MambaSpec): - return all( - isinstance(spec, MambaSpec) - and spec.num_speculative_blocks == one_spec.num_speculative_blocks - for spec in kv_cache_specs.values() - ) - else: - # NOTE(Chen): Please add new branches for new KV cache spec types. - raise NotImplementedError( - f"Unsupported KV cache spec type: {type(one_spec)}" - ) + first_spec = next(iter(kv_cache_specs.values())) + return first_spec.is_uniform_with_collection(kv_cache_specs) @classmethod def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Self | None: diff --git a/vllm/v1/kv_cache_spec_registry.py b/vllm/v1/kv_cache_spec_registry.py new file mode 100644 index 00000000000..816a6862dae --- /dev/null +++ b/vllm/v1/kv_cache_spec_registry.py @@ -0,0 +1,209 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Registry for KVCacheSpec types and their associated managers. + +This module provides a pluggable architecture for registering custom KVCacheSpec +subclasses without modifying vLLM core code. Out-of-tree platforms can define +custom specs and managers by using the @register_kv_cache_spec decorator. +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.v1.core.single_type_kv_cache_manager import SingleTypeKVCacheManager + from vllm.v1.kv_cache_interface import KVCacheSpec + + +@dataclass(frozen=True) +class KVCacheSpecMetadata: + """Metadata for a registered KVCacheSpec.""" + + kvcache_spec_cls: type["KVCacheSpec"] + manager_class: type["SingleTypeKVCacheManager"] + # The base spec class for grouping compatibility checks. + # KVCacheSpecs with the same uniform_type_base_spec will be + # grouped into one kvcache group + uniform_type_base_spec: type["KVCacheSpec"] + + +_REGISTRY_KVCACHESPEC_LIST: dict[type["KVCacheSpec"], KVCacheSpecMetadata] = {} + + +class KVCacheSpecRegistry: + """Global registry for KVCacheSpec types and their associated managers.""" + + @classmethod + def _ensure_registered(cls, vllm_config=None) -> None: + """ + Run full KVCacheSpec registration if the registration is not done. + """ + if _REGISTRY_KVCACHESPEC_LIST: + return + + if vllm_config is None: + from vllm.config import get_current_vllm_config_or_none + + vllm_config = get_current_vllm_config_or_none() + + # lazy import to avoid circular dependency + from vllm.v1.core.single_type_kv_cache_manager import ( + register_all_kvcache_specs, + ) + + register_all_kvcache_specs(vllm_config) + + @classmethod + def register( + cls, + kvcache_spec_cls: type["KVCacheSpec"], + manager_class: type["SingleTypeKVCacheManager"] | None = None, + uniform_type_base_spec: type["KVCacheSpec"] | None = None, + ) -> None: + """ + Register a KVCacheSpec class with its manager and base spec. + + Args: + kvcache_spec_cls: The KVCacheSpec subclass to register + manager_class: The SingleTypeKVCacheManager to use for this spec + uniform_type_base_spec: The base spec class for grouping compatibility. + instead of being grouped to different kvcache group, `kvcache_spec_cls` + and `uniform_type_base_spec` will be trated as uniform type. + If None, defaults to kvcache_spec_cls itself (for built-in base specs). + """ + assert manager_class is not None, "manager_class is required" + if uniform_type_base_spec is None: + uniform_type_base_spec = kvcache_spec_cls + assert issubclass(kvcache_spec_cls, uniform_type_base_spec), ( + f"{kvcache_spec_cls.__name__} must inherit from its declared " + f"uniform_type_base_spec {uniform_type_base_spec.__name__}." + ) + + if kvcache_spec_cls in _REGISTRY_KVCACHESPEC_LIST: + registered_spec = _REGISTRY_KVCACHESPEC_LIST[kvcache_spec_cls] + is_same_registration = ( + manager_class == registered_spec.manager_class + and uniform_type_base_spec == registered_spec.uniform_type_base_spec + ) + assert is_same_registration, ( + f"Conflicting registration for KVCacheSpec " + f": {kvcache_spec_cls.__name__}" + ) + + _REGISTRY_KVCACHESPEC_LIST[kvcache_spec_cls] = KVCacheSpecMetadata( + kvcache_spec_cls=kvcache_spec_cls, + manager_class=manager_class, + uniform_type_base_spec=uniform_type_base_spec, + ) + + @classmethod + def get_manager_class( + cls, kvcache_spec: "KVCacheSpec" + ) -> type["SingleTypeKVCacheManager"] | None: + """ + Get the single type kvcache manager class for a given kvcache spec instance. + + Args: + kvcache_spec: A KVCacheSpec instance + + Returns: + The SingleTypeKVCacheManager class to use for this kvcache_spec + """ + cls._ensure_registered() + kvcache_spec_cls = type(kvcache_spec) + + # Walk up the MRO to find a registered base class + for base in kvcache_spec_cls.__mro__: + if base in _REGISTRY_KVCACHESPEC_LIST: + return _REGISTRY_KVCACHESPEC_LIST[base].manager_class + + return None + + @classmethod + def get_uniform_type_base_spec( + cls, kvcache_spec: "KVCacheSpec" + ) -> type["KVCacheSpec"] | None: + """ + Get the base kvcache spec class for grouping compatibility checks. + KVCacheSpecs with uniform_type_base_spec will be trated as one group. + + Args: + kvcache_spec: A KVCacheSpec instance + + Returns: + The base KVCacheSpec class for checking uniform type kvcache specs + """ + cls._ensure_registered() + kvcache_spec_cls = type(kvcache_spec) + + # Walk up the MRO to find a registered base spec + for base in kvcache_spec_cls.__mro__: + if base in _REGISTRY_KVCACHESPEC_LIST: + return _REGISTRY_KVCACHESPEC_LIST[base].uniform_type_base_spec + + return None + + @classmethod + def check_kv_cache_spec_registry( + cls, kv_cache_spec: dict[str, "KVCacheSpec"] + ) -> None: + """ + Check if the KVCacheSpecs of each layer are registered as expected. + """ + cls._ensure_registered() + for layer_name, spec in kv_cache_spec.items(): + # use raise instead of assert to make it effective in production environment + if cls.get_uniform_type_base_spec(spec) is None: + raise ValueError( + f"Unsupported KV cache spec type for layer {layer_name}: " + f"{type(spec)}. Please register it using " + f"@register_kv_cache_spec decorator." + ) + if cls.get_manager_class(spec) is None: + raise ValueError( + f"No manager found for KV cache spec type for layer " + f"{layer_name}: {type(spec)}. Please register it using " + f"@register_kv_cache_spec decorator." + ) + + +def register_kv_cache_spec( + manager_class: type["SingleTypeKVCacheManager"] | None = None, + uniform_type_base_spec: type["KVCacheSpec"] | None = None, +): + """ + Decorator to register a custom KVCacheSpec class. + + Args: + manager_class: The SingleTypeKVCacheManager to use for this spec. + Required for all registered specs. + uniform_type_base_spec: The base spec class for uniform type kv cache specs + compatibility. If None, the spec is treated as a new base + type. + + Examples: + - Register a new specs: + @register_kv_cache_spec( + manager_class=FullAttentionManager, + uniform_type_base_spec=FullAttentionSpec + ) + @dataclass(frozen=True, kw_only=True) + class CustomFullAttentionSpec(FullAttentionSpec): + pass + """ + + def decorator(kvcache_spec_cls: type["KVCacheSpec"]) -> type["KVCacheSpec"]: + KVCacheSpecRegistry.register( + kvcache_spec_cls=kvcache_spec_cls, + manager_class=manager_class, + uniform_type_base_spec=uniform_type_base_spec, + ) + return kvcache_spec_cls + + return decorator diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index f98b12a379d..66806ab8a9b 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -19,6 +19,69 @@ if HAS_TRITON: logger = init_logger(__name__) +_FLASHINFER_MIN_VERSION = "0.2.3" + + +def flashinfer_sampler_supported() -> bool: + """Decide whether FlashInfer's top-p/top-k sampler can be used. + + Returns False (with appropriate logging) when ``VLLM_USE_FLASHINFER_SAMPLER`` + is 0, when the platform isn't CUDA, when the GPU's compute capability is + unsupported, or when the installed flashinfer is missing or too old. Raises + ``RuntimeError`` if the user explicitly opted in via the env var but + FlashInfer is unavailable. + + Note: callers must additionally ensure ``logprobs_mode`` doesn't require + post-top-k/top-p logits/logprobs for any request whose logprobs will be + returned in this step, since FlashInfer doesn't expose those. + """ + if not current_platform.is_cuda(): + return False + if not envs.VLLM_USE_FLASHINFER_SAMPLER: + logger.info_once( + "FlashInfer top-p/top-k sampling disabled via " + "VLLM_USE_FLASHINFER_SAMPLER=0." + ) + return False + from vllm.v1.attention.backends.flashinfer import FlashInferBackend + + capability = current_platform.get_device_capability() + assert capability is not None + unsupported_reason: str | None = None + if not FlashInferBackend.supports_compute_capability(capability): + unsupported_reason = ( + f"unsupported compute capability {capability.as_version_str()}" + ) + else: + try: + import flashinfer + + if version.parse(flashinfer.__version__) < version.parse( + _FLASHINFER_MIN_VERSION + ): + unsupported_reason = ( + f"flashinfer {flashinfer.__version__} is too old " + f"(>={_FLASHINFER_MIN_VERSION} required)" + ) + except ImportError: + unsupported_reason = "flashinfer is not installed" + + if unsupported_reason is None: + logger.info_once("Using FlashInfer for top-p & top-k sampling.", scope="global") + return True + if envs.is_set("VLLM_USE_FLASHINFER_SAMPLER"): + raise RuntimeError( + f"FlashInfer top-p/top-k sampling unavailable: {unsupported_reason}. " + "Unset VLLM_USE_FLASHINFER_SAMPLER=1." + ) + logger.warning_once( + "FlashInfer top-p/top-k sampling unavailable: %s; falling back. " + "Set VLLM_USE_FLASHINFER_SAMPLER=0 to silence.", + unsupported_reason, + ) + return False + + class TopKTopPSampler(nn.Module): """ Module that performs optional top-k and top-p filtering followed by @@ -30,49 +93,16 @@ class TopKTopPSampler(nn.Module): def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: super().__init__() self.logprobs_mode = logprobs_mode - # flashinfer optimization does not apply if intermediate - # logprobs/logits after top_k/top_p need to be returned - if ( - logprobs_mode not in ("processed_logits", "processed_logprobs") - and current_platform.is_cuda() - ): - if envs.VLLM_USE_FLASHINFER_SAMPLER: - from vllm.v1.attention.backends.flashinfer import FlashInferBackend - - capability = current_platform.get_device_capability() - assert capability is not None - if FlashInferBackend.supports_compute_capability(capability): - logger.info_once( - "Using FlashInfer for top-p & top-k sampling.", - scope="global", - ) - self.forward = self.forward_cuda - elif envs.is_set("VLLM_USE_FLASHINFER_SAMPLER"): - # User explicitly opted in but the GPU can't run FlashInfer. - capability_str = capability.as_version_str() - raise RuntimeError( - "FlashInfer does not support compute capability " - f"{capability_str}, unset VLLM_USE_FLASHINFER_SAMPLER=1." - ) - else: - # Default-on path; hardware can't run FlashInfer → - # quietly fall back to the PyTorch-native sampler - # instead of failing server startup. - logger.warning_once( - "FlashInfer top-p/top-k sampling not supported on " - "compute capability %s; falling back to PyTorch-native " - "sampler. Set VLLM_USE_FLASHINFER_SAMPLER=0 to silence.", - capability.as_version_str(), - ) - self.forward = self.forward_native - else: - # User explicitly set VLLM_USE_FLASHINFER_SAMPLER=0. - logger.info_once( - "FlashInfer top-p/top-k sampling disabled via " - "VLLM_USE_FLASHINFER_SAMPLER=0; using PyTorch-native sampler." - ) - self.forward = self.forward_native - + if current_platform.is_cuda(): + # FlashInfer doesn't expose post-top-k/top-p logits/logprobs, + # so it can't be used when the configured mode requires them. + can_use_flashinfer = ( + logprobs_mode not in ("processed_logits", "processed_logprobs") + and flashinfer_sampler_supported() + ) + self.forward = ( + self.forward_cuda if can_use_flashinfer else self.forward_native + ) elif current_platform.is_cpu(): arch = current_platform.get_cpu_architecture() # Fall back to native implementation for POWERPC and RISCV. @@ -417,7 +447,7 @@ def flashinfer_sample( logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None, - generators: dict[int, torch.Generator], + generators: dict[int, torch.Generator] = {}, # noqa ) -> torch.Tensor: """Sample from the logits using FlashInfer. @@ -431,11 +461,6 @@ def flashinfer_sample( """ import flashinfer - if version.parse(flashinfer.__version__) < version.parse("0.2.3"): - raise ImportError( - "FlashInfer version >= 0.2.3 required for top-k and top-p sampling. " - ) - assert not (k is None and p is None) if k is None: # Top-p only. diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py index 8bf884fd9b3..6b545aef3a2 100644 --- a/vllm/v1/worker/gpu/sample/sampler.py +++ b/vllm/v1/worker/gpu/sample/sampler.py @@ -7,6 +7,11 @@ import torch import vllm.envs as envs from vllm.config.model import LogprobsMode from vllm.sampling_params import SamplingParams +from vllm.v1.sample.ops.topk_topp_sampler import ( + apply_top_k_top_p, + flashinfer_sample, + flashinfer_sampler_supported, +) from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.metrics.logits import get_num_nans from vllm.v1.worker.gpu.sample.bad_words import BadWordsState @@ -45,6 +50,7 @@ class Sampler: self.bad_words_state = BadWordsState(req_states) self.logprob_token_ids_state = LogprobTokenIdsState(max_num_reqs, device) self.num_speculative_tokens = num_speculative_tokens + self.use_flashinfer = flashinfer_sampler_supported() def add_request( self, req_idx: int, prompt_len: int, sampling_params: SamplingParams @@ -77,6 +83,13 @@ class Sampler: # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear # that num_nans is computed before applying penalties and temperature. num_nans = get_num_nans(logits) if self.compute_nans else None + + max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np) + max_per_req_token_ids = self.logprob_token_ids_state.max_num_token_ids( + idx_mapping_np + ) + return_logprobs = max_num_logprobs != NO_LOGPROBS or max_per_req_token_ids > 0 + sampled, processed_logits = self.sample( logits, expanded_idx_mapping, @@ -84,13 +97,10 @@ class Sampler: pos, input_ids, expanded_local_pos, + return_logprobs=return_logprobs, ) - max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np) - max_per_req_token_ids = self.logprob_token_ids_state.max_num_token_ids( - idx_mapping_np - ) - if max_num_logprobs != NO_LOGPROBS or max_per_req_token_ids > 0: + if return_logprobs: if self.logprobs_mode == "processed_logprobs": logits = processed_logits expanded_logits = logits.shape[0] != idx_mapping_np.shape[0] @@ -128,6 +138,7 @@ class Sampler: pos: torch.Tensor, input_ids: torch.Tensor, expanded_local_pos: torch.Tensor, + skip_top_k_top_p: bool = False, ) -> torch.Tensor: # Copy logits to a new FP32 tensor. logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits) @@ -163,6 +174,9 @@ class Sampler: # Apply min_p in place. self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np) + if skip_top_k_top_p: + return logits + # Apply top_k and/or top_p. This might or might not return a new tensor. return self.sampling_states.apply_top_k_top_p( logits, expanded_idx_mapping, idx_mapping_np @@ -176,6 +190,7 @@ class Sampler: pos: torch.Tensor, input_ids: torch.Tensor, expanded_local_pos: torch.Tensor, + return_logprobs: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: processed_logits = self.apply_sampling_params( logits, @@ -184,16 +199,33 @@ class Sampler: pos, input_ids, expanded_local_pos, + skip_top_k_top_p=True, + ) + top_k, top_p = self.sampling_states.get_top_k_top_p( + expanded_idx_mapping, idx_mapping_np + ) + use_flashinfer = self.use_flashinfer and not ( + # Don't use FI sampler if no requests use top_k/top_p, if there are + # any greedy requests or per-request seeds, or if post-processed + # logprobs need to be returned for any requests. + (top_k is None and top_p is None) + or (return_logprobs and self.logprobs_mode == "processed_logprobs") + or self.sampling_states.any_greedy(idx_mapping_np) + or self.sampling_states.any_explicit_seed(idx_mapping_np) ) # Sample the next token. - sampled = gumbel_sample( - processed_logits, - expanded_idx_mapping, - self.sampling_states.temperature.gpu, - self.sampling_states.seeds.gpu, - pos, - apply_temperature=False, - use_fp64=self.use_fp64_gumbel, - ) + if use_flashinfer: + sampled = flashinfer_sample(processed_logits, top_k, top_p).to(torch.int64) + else: + processed_logits = apply_top_k_top_p(processed_logits, top_k, top_p) + sampled = gumbel_sample( + processed_logits, + expanded_idx_mapping, + self.sampling_states.temperature.gpu, + self.sampling_states.seeds.gpu, + pos, + apply_temperature=False, + use_fp64=self.use_fp64_gumbel, + ) return sampled, processed_logits diff --git a/vllm/v1/worker/gpu/sample/states.py b/vllm/v1/worker/gpu/sample/states.py index f247acba07c..bf2f1ce78fe 100644 --- a/vllm/v1/worker/gpu/sample/states.py +++ b/vllm/v1/worker/gpu/sample/states.py @@ -24,6 +24,9 @@ class SamplingStates: self.top_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32) self.min_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32) self.seeds = UvaBackedTensor(max_num_reqs, dtype=torch.int64) + # Tracks whether `seed` was set explicitly by the user, so callers + # can fall back from RNG paths that don't honor per-request seeds. + self.seeds_set = np.zeros(max_num_reqs, dtype=bool) # Initialize top_k and top_p manually because 0 is an invalid value for them. self.top_k.np.fill(self.vocab_size) @@ -45,6 +48,7 @@ class SamplingStates: self.min_p.np[req_idx] = sampling_params.min_p seed = sampling_params.seed + self.seeds_set[req_idx] = seed is not None if seed is None: seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX) self.seeds.np[req_idx] = seed @@ -85,20 +89,31 @@ class SamplingStates: return apply_min_p(logits, expanded_idx_mapping, self.min_p.gpu) + def get_top_k_top_p( + self, expanded_idx_mapping: torch.Tensor, idx_mapping_np: np.ndarray + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size) + do_top_p = np.any(self.top_p.np[idx_mapping_np] != 1.0) + top_k = self.top_k.gpu[expanded_idx_mapping] if do_top_k else None + top_p = self.top_p.gpu[expanded_idx_mapping] if do_top_p else None + return top_k, top_p + def apply_top_k_top_p( self, logits: torch.Tensor, expanded_idx_mapping: torch.Tensor, idx_mapping_np: np.ndarray, ) -> torch.Tensor: - do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size) - do_top_p = np.any(self.top_p.np[idx_mapping_np] != 1.0) - if not (do_top_k or do_top_p): + top_k, top_p = self.get_top_k_top_p(expanded_idx_mapping, idx_mapping_np) + if top_k is None and top_p is None: return logits - - top_k = self.top_k.gpu[expanded_idx_mapping] if do_top_k else None - top_p = self.top_p.gpu[expanded_idx_mapping] if do_top_p else None return apply_top_k_top_p(logits, top_k, top_p) + def any_greedy(self, idx_mapping_np: np.ndarray) -> bool: + return bool(np.any(self.temperature.np[idx_mapping_np] == 0.0)) + + def any_explicit_seed(self, idx_mapping_np: np.ndarray) -> bool: + return bool(np.any(self.seeds_set[idx_mapping_np])) + def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int: return int(np.max(self.num_logprobs[idx_mapping_np])) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index dbbeafcbd40..5265c3a43a2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -152,6 +152,7 @@ from vllm.v1.kv_cache_interface import ( SlidingWindowSpec, UniformTypeKVCacheSpecs, ) +from vllm.v1.kv_cache_spec_registry import KVCacheSpecRegistry from vllm.v1.outputs import ( EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, @@ -6235,6 +6236,7 @@ class GPUModelRunner( ) kv_cache_spec = self.get_kv_cache_spec() + KVCacheSpecRegistry.check_kv_cache_spec_registry(kv_cache_spec) kv_cache_groups = get_kv_cache_groups(self.vllm_config, kv_cache_spec) min_blocks = self.compilation_config.max_cudagraph_capture_size or 1