[9/n] Migrate attention and cache kernels to torch stable ABI (continued) (#43717)

Signed-off-by: Chris Leonard <chleonar@redhat.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Co-authored-by: Shengqi Chen <harry-chen@outlook.com>
This commit is contained in:
Chris Leonard
2026-05-29 00:44:45 -04:00
committed by GitHub
parent 710f077617
commit 22a58640b4
23 changed files with 909 additions and 720 deletions
+9 -11
View File
@@ -305,10 +305,6 @@ endif()
#
set(VLLM_EXT_SRC
"csrc/cache_kernels.cu"
"csrc/cache_kernels_fused.cu"
"csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu"
"csrc/cuda_view.cu"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/quantization/activation_kernels.cu"
@@ -647,7 +643,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/attention/merge_attn_states.cu"
"csrc/libtorch_stable/sampler.cu"
"csrc/libtorch_stable/topk.cu"
"csrc/libtorch_stable/mamba/selective_scan_fwd.cu")
"csrc/libtorch_stable/mamba/selective_scan_fwd.cu"
"csrc/libtorch_stable/attention/paged_attention_v1.cu"
"csrc/libtorch_stable/attention/paged_attention_v2.cu"
"csrc/libtorch_stable/cache_kernels.cu"
"csrc/libtorch_stable/cache_kernels_fused.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
@@ -924,13 +924,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
# nvfp4_kv_cache_kernels uses non-stable torch API and is called directly
# from cache_kernels.cu, so it belongs in _C rather than _C_stable.
set(NVFP4_KV_SRC "csrc/nvfp4_kv_cache_kernels.cu")
set(NVFP4_KV_SRC "csrc/libtorch_stable/nvfp4_kv_cache_kernels.cu")
set_gencode_flags_for_srcs(
SRCS "${NVFP4_KV_SRC}"
CUDA_ARCHS "${FP4_ARCHS}")
target_sources(_C PRIVATE ${NVFP4_KV_SRC})
list(APPEND VLLM_STABLE_EXT_SRC "${NVFP4_KV_SRC}")
target_compile_definitions(_C PRIVATE ENABLE_NVFP4_SM120=1)
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
@@ -960,11 +958,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
set(NVFP4_KV_SRC "csrc/nvfp4_kv_cache_kernels.cu")
set(NVFP4_KV_SRC "csrc/libtorch_stable/nvfp4_kv_cache_kernels.cu")
set_gencode_flags_for_srcs(
SRCS "${NVFP4_KV_SRC}"
CUDA_ARCHS "${FP4_ARCHS}")
target_sources(_C PRIVATE ${NVFP4_KV_SRC})
list(APPEND VLLM_STABLE_EXT_SRC "${NVFP4_KV_SRC}")
target_compile_definitions(_C PRIVATE ENABLE_NVFP4_SM100=1)
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
+1 -1
View File
@@ -4,7 +4,7 @@
#include <cmath>
#include "../cuda_compat.h"
#include "../cuda_vec_utils.cuh"
#include "cuda_vec_utils.cuh"
#include "dispatch_utils.h"
#include "torch_utils.h"
@@ -17,21 +17,18 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "../../attention/attention_dtypes.h"
#include "attention_utils.cuh"
#include "../cuda_compat.h"
#include "../../cuda_compat.h"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/w8a8/fp8/amd/quant_utils.cuh"
#include "../../quantization/w8a8/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#include "../../quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
@@ -18,8 +18,8 @@
*/
#pragma once
#include "../cuda_compat.h"
#include "attention_dtypes.h"
#include "../../cuda_compat.h"
#include "../../attention/attention_dtypes.h"
#include <float.h>
#include <type_traits>
@@ -7,7 +7,7 @@
#include <torch/headeronly/core/ScalarType.h>
#include "../../attention/attention_dtypes.h"
#include "../../attention/attention_utils.cuh"
#include "attention_utils.cuh"
#include "../../quantization/w8a8/fp8/common.cuh"
namespace vllm {
@@ -16,8 +16,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../torch_utils.h"
#include "attention_kernels.cuh"
#include "../cuda_compat.h"
#include "../../cuda_compat.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -44,13 +45,15 @@ template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
torch::stable::Tensor& out, torch::stable::Tensor& query,
torch::stable::Tensor& key_cache, torch::stable::Tensor& value_cache,
int num_kv_heads, float scale, torch::stable::Tensor& block_tables,
torch::stable::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::stable::Tensor>& alibi_slopes,
torch::stable::Tensor& k_scale, torch::stable::Tensor& v_scale,
const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -69,8 +72,8 @@ void paged_attention_v1_launcher(
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
int* block_tables_ptr = block_tables.mutable_data_ptr<int>();
int* seq_lens_ptr = seq_lens.mutable_data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
@@ -85,8 +88,9 @@ void paged_attention_v1_launcher(
dim3 grid(num_heads, num_seqs, 1);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
query.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
@@ -119,7 +123,7 @@ void paged_attention_v1_launcher(
LAUNCH_PAGED_ATTENTION_V1(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
STD_TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
@@ -153,31 +157,31 @@ void paged_attention_v1_launcher(
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
STD_TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v1(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
torch::stable::Tensor& out, // [num_seqs, num_heads, head_size]
torch::stable::Tensor& query, // [num_seqs, num_heads, head_size]
torch::stable::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
torch::stable::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
torch::stable::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::stable::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const std::optional<torch::stable::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::stable::Tensor& k_scale,
torch::stable::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
DISPATCH_BY_KV_CACHE_DTYPE(query.scalar_type(), kv_cache_dtype,
CALL_V1_LAUNCHER_BLOCK_SIZE)
}
@@ -16,8 +16,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../torch_utils.h"
#include "attention_kernels.cuh"
#include "../cuda_compat.h"
#include "../../cuda_compat.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -44,14 +45,16 @@ template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 128, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
torch::stable::Tensor& out, torch::stable::Tensor& exp_sums,
torch::stable::Tensor& max_logits, torch::stable::Tensor& tmp_out,
torch::stable::Tensor& query, torch::stable::Tensor& key_cache,
torch::stable::Tensor& value_cache, int num_kv_heads, float scale,
torch::stable::Tensor& block_tables, torch::stable::Tensor& seq_lens,
int max_seq_len, const std::optional<torch::stable::Tensor>& alibi_slopes,
torch::stable::Tensor& k_scale, torch::stable::Tensor& v_scale,
const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -73,8 +76,8 @@ void paged_attention_v2_launcher(
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
int* block_tables_ptr = block_tables.mutable_data_ptr<int>();
int* seq_lens_ptr = seq_lens.mutable_data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
@@ -91,8 +94,9 @@ void paged_attention_v2_launcher(
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
query.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
@@ -125,7 +129,7 @@ void paged_attention_v2_launcher(
LAUNCH_PAGED_ATTENTION_V2(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
STD_TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
@@ -160,34 +164,36 @@ void paged_attention_v2_launcher(
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
STD_TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
torch::stable::Tensor& out, // [num_seqs, num_heads, head_size]
torch::stable::Tensor&
exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::stable::Tensor&
max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::stable::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
torch::stable::Tensor& query, // [num_seqs, num_heads, head_size]
torch::stable::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
torch::stable::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
torch::stable::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::stable::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const std::optional<torch::stable::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::stable::Tensor& k_scale,
torch::stable::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
DISPATCH_BY_KV_CACHE_DTYPE(query.scalar_type(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
}
@@ -1,20 +1,16 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/Optional.h>
#include "cuda_utils.h"
#include "cuda_compat.h"
#include "torch_utils.h"
#include "dispatch_utils.h"
#include "libtorch_stable/quantization/vectorization_utils.cuh"
#include "../cuda_utils.h"
#include "../cuda_compat.h"
#include "quantization/vectorization_utils.cuh"
#include "concat_mla_q.cuh"
#ifdef USE_ROCM
#include "quantization/w8a8/fp8/amd/quant_utils.cuh"
#include "../quantization/w8a8/fp8/amd/quant_utils.cuh"
#else
#include "quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#endif
#include <algorithm>
@@ -34,14 +30,14 @@ constexpr float kFp8ScaleDivisor = 224.f;
constexpr float kFp8ScaleDivisor = 448.f;
#endif
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
void swap_blocks(torch::stable::Tensor& src, torch::stable::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
const torch::stable::Tensor& block_mapping) {
torch::stable::Device src_device = src.device();
torch::stable::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type;
if (src_device.is_cuda() && dst_device.is_cuda()) {
TORCH_CHECK(src_device.index() == dst_device.index(),
STD_TORCH_CHECK(src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
memcpy_type = cudaMemcpyDeviceToDevice;
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
@@ -49,25 +45,30 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
} else if (src_device.is_cpu() && dst_device.is_cuda()) {
memcpy_type = cudaMemcpyHostToDevice;
} else {
TORCH_CHECK(false, "Invalid device combination");
STD_TORCH_CHECK(false, "Invalid device combination");
}
// NOTE(youkaichao): keep in mind that `block_mapping` should be
// a cpu tensor, otherwise every `item` call will require a gpu-cpu
// synchronization.
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
STD_TORCH_CHECK(block_mapping.device().is_cpu(),
"block_mapping must be on CPU");
char* src_ptr = static_cast<char*>(src.data_ptr());
char* dst_ptr = static_cast<char*>(dst.data_ptr());
const at::cuda::OptionalCUDAGuard device_guard(
src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto guard_device = src_device.is_cuda() ? src_device : dst_device;
const torch::stable::accelerator::DeviceGuard device_guard(
guard_device.index());
const cudaStream_t stream = get_current_cuda_stream();
// NOTE(woosuk): This can be slow if the number of blocks is large.
const int64_t num_blocks = block_mapping.size(0);
const int64_t* bm_ptr = block_mapping.const_data_ptr<int64_t>();
const int64_t bm_stride0 = block_mapping.stride(0);
const int64_t bm_stride1 = block_mapping.stride(1);
for (size_t i = 0; i < num_blocks; i++) {
int64_t src_block_number = block_mapping[i][0].item<int64_t>();
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
int64_t src_block_number = bm_ptr[i * bm_stride0];
int64_t dst_block_number = bm_ptr[i * bm_stride0 + bm_stride1];
int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;
cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
@@ -75,20 +76,23 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
}
}
void swap_blocks_batch(const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes,
void swap_blocks_batch(const torch::stable::Tensor& src_ptrs,
const torch::stable::Tensor& dst_ptrs,
const torch::stable::Tensor& sizes,
bool is_src_access_order_any) {
TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
TORCH_CHECK(src_ptrs.dtype() == torch::kInt64, "src_ptrs must be int64");
TORCH_CHECK(dst_ptrs.dtype() == torch::kInt64, "dst_ptrs must be int64");
TORCH_CHECK(sizes.dtype() == torch::kInt64, "sizes must be int64");
STD_TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
STD_TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
STD_TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
STD_TORCH_CHECK(src_ptrs.scalar_type() == torch::headeronly::ScalarType::Long,
"src_ptrs must be int64");
STD_TORCH_CHECK(dst_ptrs.scalar_type() == torch::headeronly::ScalarType::Long,
"dst_ptrs must be int64");
STD_TORCH_CHECK(sizes.scalar_type() == torch::headeronly::ScalarType::Long,
"sizes must be int64");
const int64_t n = src_ptrs.size(0);
TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");
STD_TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
STD_TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");
if (n == 0) return;
@@ -96,7 +100,7 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs,
int64_t* dst_data = dst_ptrs.mutable_data_ptr<int64_t>();
int64_t* size_data = sizes.mutable_data_ptr<int64_t>();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const cudaStream_t stream = get_current_cuda_stream();
// Use cuMemcpyBatchAsync / hipMemcpyBatchAsync to submit all copies in a
// single driver call, amortizing per-copy submission overhead. int64_t
@@ -138,8 +142,9 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs,
reinterpret_cast<size_t*>(size_data),
static_cast<size_t>(n), &attr, &attrs_idx, 1,
&fail_idx, static_cast<CUstream>(stream));
TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
fail_idx, " with error ", result);
STD_TORCH_CHECK(result == CUDA_SUCCESS,
"cuMemcpyBatchAsync failed at index ", fail_idx,
" with error ", result);
return;
}
#elif defined(USE_ROCM) && defined(HIP_VERSION) && HIP_VERSION >= 70100000
@@ -155,8 +160,9 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs,
reinterpret_cast<void**>(dst_data), reinterpret_cast<void**>(src_data),
reinterpret_cast<size_t*>(size_data), static_cast<size_t>(n), &attr,
&attrs_idx, 0, &fail_idx, static_cast<hipStream_t>(stream));
TORCH_CHECK(result == hipSuccess, "hipMemcpyBatchAsync failed at index ",
fail_idx, " with error ", result);
STD_TORCH_CHECK(result == hipSuccess,
"hipMemcpyBatchAsync failed at index ", fail_idx,
" with error ", result);
return;
}
#endif
@@ -682,21 +688,21 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
slot_mapping.const_data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, x, \
reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor&
torch::stable::Tensor& key, // [num_tokens, num_heads, head_size]
torch::stable::Tensor& value, // [num_tokens, num_heads, head_size]
torch::stable::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
torch::stable::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
torch::stable::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, torch::stable::Tensor& k_scale,
torch::stable::Tensor& v_scale) {
int num_tokens = slot_mapping.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
@@ -709,10 +715,11 @@ void reshape_and_cache(
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_div_x, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
key.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
DISPATCH_BY_KV_CACHE_DTYPE(key.scalar_type(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE);
}
@@ -726,22 +733,23 @@ void reshape_and_cache(
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, page_stride, \
slot_mapping.const_data_ptr<int64_t>(), block_stride, page_stride, \
head_stride, key_stride, value_stride, num_heads, head_size, \
block_size, reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()), \
kv_scale_stride);
void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor&
torch::stable::Tensor& key, // [num_tokens, num_heads, head_size]
torch::stable::Tensor& value, // [num_tokens, num_heads, head_size]
torch::stable::Tensor&
key_cache, // [num_blocks, block_size, num_heads, head_size]
torch::stable::Tensor&
value_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
torch::stable::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, // [1] or [num_heads]
torch::Tensor& v_scale) { // [1] or [num_heads]
torch::stable::Tensor& k_scale, // [1] or [num_heads]
torch::stable::Tensor& v_scale) { // [1] or [num_heads]
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
@@ -756,21 +764,24 @@ void reshape_and_cache_flash(
int num_heads = key.size(1);
int head_size = key.size(2);
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
key.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
if (kv_cache_dtype == "nvfp4") {
#if defined(ENABLE_NVFP4_SM100) || defined(ENABLE_NVFP4_SM120)
// NVFP4 dispatch is compiled separately for SM100+.
extern void reshape_and_cache_nvfp4_dispatch(
torch::Tensor & key, torch::Tensor & value, torch::Tensor & key_cache,
torch::Tensor & value_cache, torch::Tensor & slot_mapping,
torch::Tensor & k_scale, torch::Tensor & v_scale);
torch::stable::Tensor & key, torch::stable::Tensor & value,
torch::stable::Tensor & key_cache, torch::stable::Tensor & value_cache,
torch::stable::Tensor & slot_mapping, torch::stable::Tensor & k_scale,
torch::stable::Tensor & v_scale);
reshape_and_cache_nvfp4_dispatch(key, value, key_cache, value_cache,
slot_mapping, k_scale, v_scale);
return;
#else
TORCH_CHECK(false,
STD_TORCH_CHECK(
false,
"NVFP4 KV cache requires SM100+ (Blackwell). "
"Please rebuild vllm with a Blackwell-compatible CUDA target.");
#endif
@@ -784,18 +795,18 @@ void reshape_and_cache_flash(
int64_t block_stride = key_cache.stride(0);
int64_t page_stride = key_cache.stride(1);
int64_t head_stride = key_cache.stride(2);
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
STD_TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
TORCH_CHECK(k_scale.sizes() == v_scale.sizes(),
STD_TORCH_CHECK(k_scale.sizes().equals(v_scale.sizes()),
"k_scale and v_scale must have the same shape");
TORCH_CHECK(k_scale.numel() == 1 || k_scale.numel() == num_heads,
STD_TORCH_CHECK(k_scale.numel() == 1 || k_scale.numel() == num_heads,
"k_scale and v_scale must be of shape [1] or [num_heads]");
int kv_scale_stride = (k_scale.numel() > 1) ? 1 : 0;
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
DISPATCH_BY_KV_CACHE_DTYPE(key.scalar_type(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE_FLASH);
}
@@ -808,7 +819,7 @@ void reshape_and_cache_flash(
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
slot_mapping.const_data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
@@ -820,17 +831,17 @@ void reshape_and_cache_flash(
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
slot_mapping.const_data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& scale) {
torch::stable::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::stable::Tensor& k_pe, // [num_tokens, pe_dim]
torch::stable::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
torch::stable::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::stable::Tensor& scale) {
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
@@ -847,16 +858,17 @@ void concat_and_cache_mla(
int block_size = kv_cache.size(1);
if (kv_cache_dtype == "fp8_ds_mla") {
TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla");
TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla");
TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(),
STD_TORCH_CHECK(kv_lora_rank == 512,
"kv_lora_rank must be 512 for fp8_ds_mla");
STD_TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla");
STD_TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.element_size(),
"kv_cache.size(2) must be 656 bytes for fp8_ds_mla");
TORCH_CHECK(kv_c.itemsize() == 2,
"kv_c.itemsize() must be 2 for fp8_ds_mla");
TORCH_CHECK(k_pe.itemsize() == 2,
"k_pe.itemsize() must be 2 for fp8_ds_mla");
STD_TORCH_CHECK(kv_c.element_size() == 2,
"kv_c.element_size() must be 2 for fp8_ds_mla");
STD_TORCH_CHECK(k_pe.element_size() == 2,
"k_pe.element_size() must be 2 for fp8_ds_mla");
} else {
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
STD_TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
}
int kv_c_stride = kv_c.stride(0);
@@ -864,8 +876,9 @@ void concat_and_cache_mla(
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
kv_c.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
if (kv_cache_dtype == "fp8_ds_mla") {
dim3 grid(num_tokens);
@@ -875,12 +888,12 @@ void concat_and_cache_mla(
// The RoPE part (last 64 elements) is handled by another 1 warp (32
// threads). So in total, we use 3 warps (96 threads) per block.
dim3 block(96);
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.scalar_type(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_DS_MLA);
} else {
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.scalar_type(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA);
}
}
@@ -908,55 +921,62 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);
// Only for testing.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype) {
torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
TORCH_CHECK(src_device.index() == dst_device.index(),
void convert_fp8(torch::stable::Tensor& dst_cache,
torch::stable::Tensor& src_cache, const double scale,
const std::string& kv_cache_dtype) {
torch::stable::Device src_device = src_cache.device();
torch::stable::Device dst_device = dst_cache.device();
STD_TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
STD_TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
STD_TORCH_CHECK(src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
at::cuda::OptionalCUDAGuard device_guard(src_device);
torch::stable::accelerator::DeviceGuard device_guard(src_device.index());
int64_t num_blocks = src_cache.size(0);
int64_t block_stride = src_cache.stride(0);
dim3 grid(num_blocks);
dim3 block(std::min(block_stride, int64_t(512)));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const cudaStream_t stream = get_current_cuda_stream();
if (kv_cache_dtype == "auto") {
if (src_cache.dtype() == at::ScalarType::Float) {
if (src_cache.scalar_type() == torch::headeronly::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
} else if (src_cache.dtype() == at::ScalarType::Half) {
} else if (src_cache.scalar_type() == torch::headeronly::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
} else if (src_cache.scalar_type() ==
torch::headeronly::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
} else if (dst_cache.scalar_type() ==
torch::headeronly::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
} else if (dst_cache.scalar_type() == torch::headeronly::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
} else if (dst_cache.scalar_type() ==
torch::headeronly::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
}
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
if (src_cache.dtype() == at::ScalarType::Float) {
if (src_cache.scalar_type() == torch::headeronly::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::Half) {
} else if (src_cache.scalar_type() == torch::headeronly::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
} else if (src_cache.scalar_type() ==
torch::headeronly::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
} else if (dst_cache.scalar_type() ==
torch::headeronly::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
} else if (dst_cache.scalar_type() == torch::headeronly::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
} else if (dst_cache.scalar_type() ==
torch::headeronly::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
}
} else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
STD_TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
}
}
@@ -1053,8 +1073,9 @@ __global__ void gather_and_maybe_dequant_cache(
<<<grid, block, 0, stream>>>( \
reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
reinterpret_cast<SCALAR_T*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
token_to_seq.data_ptr<int32_t>(), num_tokens, block_size, \
block_table.const_data_ptr<int32_t>(), \
cu_seq_lens.const_data_ptr<int32_t>(), \
token_to_seq.const_data_ptr<int32_t>(), num_tokens, block_size, \
block_table_stride, cache_block_stride, cache_entry_stride, \
dst_entry_stride, reinterpret_cast<const float*>(scale.data_ptr()), \
seq_starts_ptr);
@@ -1072,41 +1093,46 @@ __global__ void gather_and_maybe_dequant_cache(
// - Optionally, seq_starts (if provided) offsets the starting block index by
// (seq_starts[bid] / page_size)
void gather_and_maybe_dequant_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
torch::stable::Tensor const&
src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::stable::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::stable::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::stable::Tensor const& cu_seq_lens, // [BATCH+1]
torch::stable::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
int64_t num_tokens, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional<torch::Tensor> seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::stable::Tensor const& scale,
std::optional<torch::stable::Tensor> seq_starts = std::nullopt) {
torch::stable::accelerator::DeviceGuard device_guard(
src_cache.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
int32_t block_size = src_cache.size(1);
int32_t head_dim = dst.size(-1);
TORCH_CHECK(block_table.dtype() == torch::kInt32,
STD_TORCH_CHECK(
block_table.scalar_type() == torch::headeronly::ScalarType::Int,
"block_table must be int32");
TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
STD_TORCH_CHECK(
cu_seq_lens.scalar_type() == torch::headeronly::ScalarType::Int,
"cu_seq_lens must be int32");
if (seq_starts.has_value()) {
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
STD_TORCH_CHECK(
seq_starts.value().scalar_type() == torch::headeronly::ScalarType::Int,
"seq_starts must be int32");
}
TORCH_CHECK(
STD_TORCH_CHECK(
head_dim == 320 || head_dim == 576,
"gather_and_maybe_dequant_cache only support the head_dim to 320 or 576 "
"for better performance")
TORCH_CHECK(src_cache.device() == dst.device(),
STD_TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
STD_TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
STD_TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
"src_cache and cu_seq_lens must be on the same device");
if (seq_starts.has_value()) {
TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
STD_TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
"src_cache and seq_starts must be on the same device");
}
@@ -1120,13 +1146,14 @@ void gather_and_maybe_dequant_cache(
dim3 block(thread_block_size);
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
seq_starts.has_value() ? seq_starts.value().const_data_ptr<int32_t>()
: nullptr;
if (head_dim == 576) {
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype,
DISPATCH_BY_KV_CACHE_DTYPE(dst.scalar_type(), kv_cache_dtype,
CALL_GATHER_CACHE_576);
} else {
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype,
DISPATCH_BY_KV_CACHE_DTYPE(dst.scalar_type(), kv_cache_dtype,
CALL_GATHER_CACHE_320);
}
}
@@ -1271,9 +1298,10 @@ __global__ void cp_gather_cache(
vllm::cp_gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
block_size, entry_size, block_table_stride, cache_block_stride, \
cache_entry_stride, dst_entry_stride, seq_starts_ptr);
block_table.const_data_ptr<int32_t>(), \
cu_seq_lens.const_data_ptr<int32_t>(), block_size, entry_size, \
block_table_stride, cache_block_stride, cache_entry_stride, \
dst_entry_stride, seq_starts_ptr);
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
@@ -1281,35 +1309,40 @@ __global__ void cp_gather_cache(
// - Optionally, seq_starts (if provided) offsets the starting slot index by
// seq_starts[bid]
void cp_gather_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
torch::stable::Tensor const&
src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::stable::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::stable::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::stable::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size,
std::optional<torch::Tensor> seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
std::optional<torch::stable::Tensor> seq_starts = std::nullopt) {
torch::stable::accelerator::DeviceGuard device_guard(
src_cache.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
int32_t block_size = src_cache.size(1);
int32_t entry_size = src_cache.flatten(2, -1).size(2);
int32_t entry_size = torch::stable::flatten(src_cache, 2, -1).size(2);
TORCH_CHECK(block_table.dtype() == torch::kInt32,
STD_TORCH_CHECK(
block_table.scalar_type() == torch::headeronly::ScalarType::Int,
"block_table must be int32");
TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
STD_TORCH_CHECK(
cu_seq_lens.scalar_type() == torch::headeronly::ScalarType::Int,
"cu_seq_lens must be int32");
if (seq_starts.has_value()) {
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
STD_TORCH_CHECK(
seq_starts.value().scalar_type() == torch::headeronly::ScalarType::Int,
"seq_starts must be int32");
}
TORCH_CHECK(src_cache.device() == dst.device(),
STD_TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
STD_TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
STD_TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
"src_cache and cu_seq_lens must be on the same device");
if (seq_starts.has_value()) {
TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
STD_TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
"src_cache and seq_starts must be on the same device");
}
@@ -1323,12 +1356,13 @@ void cp_gather_cache(
dim3 grid(batch_size, num_splits);
dim3 block(1024);
TORCH_CHECK(src_cache.dtype() == dst.dtype(),
STD_TORCH_CHECK(src_cache.scalar_type() == dst.scalar_type(),
"src_cache and dst must have the same dtype");
const int dtype_bits = src_cache.element_size() * 8;
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
seq_starts.has_value() ? seq_starts.value().const_data_ptr<int32_t>()
: nullptr;
if (dtype_bits == 32) {
CALL_CP_GATHER_CACHE(uint32_t);
@@ -1337,46 +1371,51 @@ void cp_gather_cache(
} else if (dtype_bits == 8) {
CALL_CP_GATHER_CACHE(uint8_t);
} else {
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
STD_TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
}
}
void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH]
torch::stable::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::stable::Tensor const& dst, // [TOT_TOKENS, 576]
torch::stable::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::stable::Tensor const& seq_lens, // [BATCH]
torch::stable::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::stable::accelerator::DeviceGuard device_guard(
src_cache.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
int32_t block_size = src_cache.size(1);
int32_t head_dim = dst.size(1);
TORCH_CHECK(block_table.dtype() == torch::kInt32,
STD_TORCH_CHECK(
block_table.scalar_type() == torch::headeronly::ScalarType::Int,
"block_table must be int32");
TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32");
TORCH_CHECK(workspace_starts.dtype() == torch::kInt32,
STD_TORCH_CHECK(seq_lens.scalar_type() == torch::headeronly::ScalarType::Int,
"seq_lens must be int32");
STD_TORCH_CHECK(
workspace_starts.scalar_type() == torch::headeronly::ScalarType::Int,
"workspace_starts must be int32");
TORCH_CHECK(src_cache.device() == dst.device(),
STD_TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
STD_TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == seq_lens.device(),
STD_TORCH_CHECK(src_cache.device() == seq_lens.device(),
"src_cache and seq_lens must be on the same device");
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
STD_TORCH_CHECK(src_cache.device() == workspace_starts.device(),
"src_cache and workspace_starts must be on the same device");
auto dtype = src_cache.scalar_type();
TORCH_CHECK(
dtype == at::ScalarType::Byte || // uint8
dtype == at::ScalarType::Float8_e4m3fn || // fp8 e4m3
dtype == at::ScalarType::Float8_e5m2, // fp8 e5m2
STD_TORCH_CHECK(
dtype == torch::headeronly::ScalarType::Byte || // uint8
dtype == torch::headeronly::ScalarType::Float8_e4m3fn || // fp8 e4m3
dtype == torch::headeronly::ScalarType::Float8_e5m2, // fp8 e5m2
"src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got ",
src_cache.dtype());
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
src_cache.scalar_type());
STD_TORCH_CHECK(dst.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"dst must be bfloat16");
STD_TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
int64_t block_table_stride = block_table.stride(0);
int64_t cache_block_stride = src_cache.stride(0);
@@ -1384,8 +1423,8 @@ void cp_gather_and_upconvert_fp8_kv_cache(
int64_t dst_entry_stride = dst.stride(0);
const uint8_t* src_ptr = nullptr;
if (dtype == at::ScalarType::Byte) {
src_ptr = src_cache.data_ptr<uint8_t>();
if (dtype == torch::headeronly::ScalarType::Byte) {
src_ptr = src_cache.const_data_ptr<uint8_t>();
} else {
// float8_e4m3fn or float8_e5m2
src_ptr = reinterpret_cast<const uint8_t*>(src_cache.data_ptr());
@@ -1399,7 +1438,8 @@ void cp_gather_and_upconvert_fp8_kv_cache(
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid_size, block_size_threads, 0,
stream>>>(
src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
block_table.data_ptr<int32_t>(), workspace_starts.data_ptr<int32_t>(),
block_table.const_data_ptr<int32_t>(),
workspace_starts.const_data_ptr<int32_t>(),
static_cast<int32_t>(batch_size), block_size, total_tokens,
block_table_stride, cache_block_stride, cache_entry_stride,
dst_entry_stride);
@@ -1411,13 +1451,13 @@ void cp_gather_and_upconvert_fp8_kv_cache(
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(k.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), head_dim, quant_block_size, \
slot_mapping.const_data_ptr<int64_t>(), head_dim, quant_block_size, \
cache_block_size, cache_stride, use_ue8m0);
void indexer_k_quant_and_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
torch::stable::Tensor& k, // [num_tokens, head_dim]
torch::stable::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::stable::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt) {
int num_tokens = k.size(0);
@@ -1426,22 +1466,23 @@ void indexer_k_quant_and_cache(
int cache_stride = kv_cache.size(2);
bool use_ue8m0 = scale_fmt == "ue8m0";
TORCH_CHECK(k.device() == kv_cache.device(),
STD_TORCH_CHECK(k.device() == kv_cache.device(),
"k and kv_cache must be on the same device");
TORCH_CHECK(k.device() == slot_mapping.device(),
STD_TORCH_CHECK(k.device() == slot_mapping.device(),
"k and slot_mapping must be on the same device");
TORCH_CHECK(head_dim % quant_block_size == 0,
STD_TORCH_CHECK(head_dim % quant_block_size == 0,
"head_dim must be divisible by quant_block_size");
constexpr int vec_size = 4;
dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) /
(quant_block_size * vec_size));
dim3 block(32, vec_size);
const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
k.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
static const std::string kv_cache_dtype = "fp8_e4m3";
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), kv_cache_dtype,
DISPATCH_BY_KV_CACHE_DTYPE(k.scalar_type(), kv_cache_dtype,
CALL_INDEXER_K_QUANT_AND_CACHE);
}
@@ -1454,37 +1495,41 @@ void indexer_k_quant_and_cache(
reinterpret_cast<char*>(kv_cache.data_ptr()), \
reinterpret_cast<char*>(dst_k.data_ptr()), \
reinterpret_cast<char*>(dst_scale.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
batch_size, dst_k.stride(0), dst_k.size(1), kv_cache.stride(0), \
kv_cache.stride(1), kv_cache.size(1), block_table.size(1), \
num_tokens, quant_block_size);
block_table.const_data_ptr<int32_t>(), \
cu_seq_lens.const_data_ptr<int32_t>(), batch_size, dst_k.stride(0), \
dst_k.size(1), kv_cache.stride(0), kv_cache.stride(1), \
kv_cache.size(1), block_table.size(1), num_tokens, \
quant_block_size);
void cp_gather_indexer_k_quant_cache(
const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& dst_k, // [num_tokens, head_dim]
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
const torch::Tensor& block_table, // [batch_size, num_blocks]
const torch::Tensor& cu_seq_lens // [batch_size + 1]
const torch::stable::Tensor&
kv_cache, // [num_blocks, block_size, cache_stride]
torch::stable::Tensor& dst_k, // [num_tokens, head_dim]
torch::stable::Tensor&
dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
const torch::stable::Tensor& block_table, // [batch_size, num_blocks]
const torch::stable::Tensor& cu_seq_lens // [batch_size + 1]
) {
int batch_size = block_table.size(0);
int num_tokens = dst_k.size(0);
int head_dim = dst_k.size(1);
int quant_block_size = head_dim * 4 / dst_scale.size(1);
TORCH_CHECK(kv_cache.device() == dst_k.device(),
STD_TORCH_CHECK(kv_cache.device() == dst_k.device(),
"kv_cache and dst_k must be on the same device");
TORCH_CHECK(kv_cache.device() == dst_scale.device(),
STD_TORCH_CHECK(kv_cache.device() == dst_scale.device(),
"kv_cache and dst_scale must be on the same device");
TORCH_CHECK(kv_cache.device() == block_table.device(),
STD_TORCH_CHECK(kv_cache.device() == block_table.device(),
"kv_cache and block_table must be on the same device");
TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(),
STD_TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(),
"kv_cache and cu_seq_lens must be on the same device");
TORCH_CHECK(head_dim % quant_block_size == 0,
STD_TORCH_CHECK(head_dim % quant_block_size == 0,
"head_dim must be divisible by quant_block_size");
constexpr int vec_size = 16;
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_cache));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
kv_cache.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
if (num_tokens < 32) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(1);
@@ -1503,9 +1548,10 @@ void cp_gather_indexer_k_quant_cache(
// Concatenate ql_nope and q_pe into a contiguous q_out tensor for MLA/DSA.
// Replaces torch.cat((ql_nope, q_pe), dim=-1).
void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim]
torch::Tensor& q_out // [num_tokens, num_heads, nope_dim +
void concat_mla_q(
torch::stable::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
torch::stable::Tensor& q_pe, // [num_tokens, num_heads, rope_dim]
torch::stable::Tensor& q_out // [num_tokens, num_heads, nope_dim +
// rope_dim]
) {
const int num_tokens = ql_nope.size(0);
@@ -1513,16 +1559,18 @@ void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
const int nope_dim = ql_nope.size(2);
const int rope_dim = q_pe.size(2);
TORCH_CHECK(nope_dim % 512 == 0, "nope_dim must be a multiple of 512, got ",
nope_dim);
TORCH_CHECK(rope_dim == 64, "rope_dim must be 64, got ", rope_dim);
TORCH_CHECK(q_out.size(2) == nope_dim + rope_dim);
STD_TORCH_CHECK(nope_dim % 512 == 0,
"nope_dim must be a multiple of 512, got ", nope_dim);
STD_TORCH_CHECK(rope_dim == 64, "rope_dim must be 64, got ", rope_dim);
STD_TORCH_CHECK(q_out.size(2) == nope_dim + rope_dim);
TORCH_CHECK(ql_nope.stride(2) == 1, "ql_nope must have stride 1 in dim 2");
TORCH_CHECK(q_pe.stride(2) == 1, "q_pe must have stride 1 in dim 2");
TORCH_CHECK(q_out.stride(2) == 1, "q_out must have stride 1 in dim 2");
TORCH_CHECK(ql_nope.scalar_type() == at::ScalarType::Half ||
ql_nope.scalar_type() == at::ScalarType::BFloat16,
STD_TORCH_CHECK(ql_nope.stride(2) == 1,
"ql_nope must have stride 1 in dim 2");
STD_TORCH_CHECK(q_pe.stride(2) == 1, "q_pe must have stride 1 in dim 2");
STD_TORCH_CHECK(q_out.stride(2) == 1, "q_out must have stride 1 in dim 2");
STD_TORCH_CHECK(
ql_nope.scalar_type() == torch::headeronly::ScalarType::Half ||
ql_nope.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"ql_nope must be float16 or bfloat16 dtype");
if (num_tokens == 0) return;
@@ -1532,13 +1580,14 @@ void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
const int grid_size = (total_warps + warps_per_block - 1) / warps_per_block;
const int block_size = warps_per_block * 32;
const at::cuda::OptionalCUDAGuard device_guard(device_of(ql_nope));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
ql_nope.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
VLLM_DISPATCH_HALF_TYPES(ql_nope.scalar_type(), "concat_mla_q", [&] {
VLLM_STABLE_DISPATCH_HALF_TYPES(ql_nope.scalar_type(), "concat_mla_q", [&] {
vllm::ConcatMLAQKernel<scalar_t, 512><<<grid_size, block_size, 0, stream>>>(
q_out.data_ptr<scalar_t>(), ql_nope.data_ptr<scalar_t>(),
q_pe.data_ptr<scalar_t>(), num_tokens, num_heads, q_out.stride(0),
q_out.mutable_data_ptr<scalar_t>(), ql_nope.const_data_ptr<scalar_t>(),
q_pe.const_data_ptr<scalar_t>(), num_tokens, num_heads, q_out.stride(0),
q_out.stride(1), ql_nope.stride(0), ql_nope.stride(1), q_pe.stride(0),
q_pe.stride(1));
});
@@ -1,15 +1,13 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "torch_utils.h"
#include "dispatch_utils.h"
#include "quantization/w8a8/fp8/common.cuh"
#include "../cuda_compat.h"
#include "../quantization/w8a8/fp8/common.cuh"
#ifdef USE_ROCM
#include "quantization/w8a8/fp8/amd/quant_utils.cuh"
#include "../quantization/w8a8/fp8/amd/quant_utils.cuh"
#else
#include "quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#endif
#ifdef USE_ROCM
@@ -166,38 +164,47 @@ __global__ void concat_and_cache_mla_rope_fused_kernel(
#define CALL_CONCAT_AND_CACHE_MLA_ROPE_FUSED(RAW_KV_T, CACHE_T, KV_DTYPE) \
do { \
VLLM_DISPATCH_FLOATING_TYPES(q_pe.scalar_type(), "qk_scalar_type", [&] { \
VLLM_STABLE_DISPATCH_FLOATING_TYPES( \
q_pe.scalar_type(), "qk_scalar_type", [&] { \
using qk_t = scalar_t; \
VLLM_DISPATCH_FLOATING_TYPES( \
rope_cos_sin_cache.scalar_type(), "rope_cos_sin_cache_scalar_type", \
[&] { \
VLLM_STABLE_DISPATCH_FLOATING_TYPES( \
rope_cos_sin_cache.scalar_type(), \
"rope_cos_sin_cache_scalar_type", [&] { \
using cos_sin_t = scalar_t; \
if (rope_is_neox) { \
vllm::concat_and_cache_mla_rope_fused_kernel< \
qk_t, cos_sin_t, true, RAW_KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
positions.data_ptr<int64_t>(), q_pe.data_ptr<qk_t>(), \
k_pe.data_ptr<qk_t>(), kv_c.data_ptr<qk_t>(), \
rope_cos_sin_cache.data_ptr<cos_sin_t>(), rot_dim, \
q_pe_stride_token, q_pe_stride_head, k_pe_stride, \
kv_c_stride, num_q_heads, \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, \
entry_stride, kv_lora_rank, block_size, \
kv_cache_quant_scale.data_ptr<float>()); \
positions.const_data_ptr<int64_t>(), \
q_pe.mutable_data_ptr<qk_t>(), \
k_pe.mutable_data_ptr<qk_t>(), \
kv_c.const_data_ptr<qk_t>(), \
rope_cos_sin_cache.const_data_ptr<cos_sin_t>(), \
rot_dim, q_pe_stride_token, q_pe_stride_head, \
k_pe_stride, kv_c_stride, num_q_heads, \
reinterpret_cast<CACHE_T*>( \
kv_cache.mutable_data_ptr()), \
slot_mapping.const_data_ptr<int64_t>(), \
block_stride, entry_stride, kv_lora_rank, \
block_size, \
kv_cache_quant_scale.const_data_ptr<float>()); \
} else { \
vllm::concat_and_cache_mla_rope_fused_kernel< \
qk_t, cos_sin_t, false, RAW_KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
positions.data_ptr<int64_t>(), q_pe.data_ptr<qk_t>(), \
k_pe.data_ptr<qk_t>(), kv_c.data_ptr<qk_t>(), \
rope_cos_sin_cache.data_ptr<cos_sin_t>(), rot_dim, \
q_pe_stride_token, q_pe_stride_head, k_pe_stride, \
kv_c_stride, num_q_heads, \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, \
entry_stride, kv_lora_rank, block_size, \
kv_cache_quant_scale.data_ptr<float>()); \
positions.const_data_ptr<int64_t>(), \
q_pe.mutable_data_ptr<qk_t>(), \
k_pe.mutable_data_ptr<qk_t>(), \
kv_c.const_data_ptr<qk_t>(), \
rope_cos_sin_cache.const_data_ptr<cos_sin_t>(), \
rot_dim, q_pe_stride_token, q_pe_stride_head, \
k_pe_stride, kv_c_stride, num_q_heads, \
reinterpret_cast<CACHE_T*>( \
kv_cache.mutable_data_ptr()), \
slot_mapping.const_data_ptr<int64_t>(), \
block_stride, entry_stride, kv_lora_rank, \
block_size, \
kv_cache_quant_scale.const_data_ptr<float>()); \
} \
}); \
}); \
@@ -208,64 +215,69 @@ __global__ void concat_and_cache_mla_rope_fused_kernel(
// Replaces DeepseekScalingRotaryEmbedding.self.rotary_emb and
// concat_and_cache_mla.
void concat_and_cache_mla_rope_fused(
torch::Tensor& positions, // [num_tokens]
torch::Tensor& q_pe, // [num_tokens, num_q_heads, rot_dim]
torch::Tensor& k_pe, // [num_tokens, rot_dim]
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& rope_cos_sin_cache, // [max_position, rot_dim]
torch::stable::Tensor& positions, // [num_tokens]
torch::stable::Tensor& q_pe, // [num_tokens, num_q_heads, rot_dim]
torch::stable::Tensor& k_pe, // [num_tokens, rot_dim]
torch::stable::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::stable::Tensor& rope_cos_sin_cache, // [max_position, rot_dim]
bool rope_is_neox,
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
torch::Tensor&
torch::stable::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
torch::stable::Tensor&
kv_cache, // [num_blocks, block_size, (kv_lora_rank + rot_dim)]
const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale) {
const std::string& kv_cache_dtype,
torch::stable::Tensor& kv_cache_quant_scale) {
// NOTE(woosuk): In vLLM V1, query/key/position.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
// both include padding.
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
// since key includes padding for CUDA graphs, while slot_mapping does not.
// In this case, slot_mapping.size(0) represents the actual number of tokens
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0)
// because both include padding.
// In vLLM V1, however, key.size(0) can be larger than
// slot_mapping.size(0) since key includes padding for CUDA graphs,
// while slot_mapping does not. In this case,
// slot_mapping.size(0) represents the actual number of tokens
// before padding.
// For compatibility with both cases, we use slot_mapping.size(0) as the
// number of tokens.
int num_tokens = slot_mapping.size(0);
int num_padded_tokens = q_pe.size(0);
TORCH_CHECK_GE(num_padded_tokens, num_tokens);
// For compatibility with both cases, we use slot_mapping.size(0) as
// the number of tokens.
const int64_t num_tokens = slot_mapping.size(0);
const int64_t num_padded_tokens = q_pe.size(0);
STD_TORCH_CHECK(num_padded_tokens >= num_tokens);
const int num_q_heads = q_pe.size(1);
const int rot_dim = q_pe.size(2);
const int kv_lora_rank = kv_c.size(1);
TORCH_CHECK_EQ(positions.size(0), num_padded_tokens);
TORCH_CHECK_EQ(positions.dim(), 1);
TORCH_CHECK_EQ(positions.scalar_type(), c10::ScalarType::Long);
STD_TORCH_CHECK(positions.size(0) == num_padded_tokens);
STD_TORCH_CHECK(positions.dim() == 1);
STD_TORCH_CHECK(positions.scalar_type() ==
torch::headeronly::ScalarType::Long);
TORCH_CHECK_EQ(q_pe.dim(), 3);
TORCH_CHECK_EQ(q_pe.size(0), num_padded_tokens);
TORCH_CHECK_EQ(q_pe.size(1), num_q_heads);
TORCH_CHECK_EQ(q_pe.size(2), rot_dim);
STD_TORCH_CHECK(q_pe.dim() == 3);
STD_TORCH_CHECK(q_pe.size(0) == num_padded_tokens);
STD_TORCH_CHECK(q_pe.size(1) == num_q_heads);
STD_TORCH_CHECK(q_pe.size(2) == rot_dim);
TORCH_CHECK_EQ(k_pe.dim(), 2);
TORCH_CHECK_EQ(k_pe.size(0), num_padded_tokens);
TORCH_CHECK_EQ(k_pe.size(1), rot_dim);
TORCH_CHECK_EQ(k_pe.scalar_type(), q_pe.scalar_type());
STD_TORCH_CHECK(k_pe.dim() == 2);
STD_TORCH_CHECK(k_pe.size(0) == num_padded_tokens);
STD_TORCH_CHECK(k_pe.size(1) == rot_dim);
STD_TORCH_CHECK(k_pe.scalar_type() == q_pe.scalar_type());
TORCH_CHECK_EQ(kv_c.dim(), 2);
TORCH_CHECK_EQ(kv_c.size(0), num_padded_tokens);
TORCH_CHECK_EQ(kv_c.size(1), kv_lora_rank);
TORCH_CHECK_EQ(kv_c.scalar_type(), q_pe.scalar_type());
TORCH_CHECK_EQ(kv_c.dtype(), q_pe.dtype());
STD_TORCH_CHECK(kv_c.dim() == 2);
STD_TORCH_CHECK(kv_c.size(0) == num_padded_tokens);
STD_TORCH_CHECK(kv_c.size(1) == kv_lora_rank);
STD_TORCH_CHECK(kv_c.scalar_type() == q_pe.scalar_type());
TORCH_CHECK_EQ(rope_cos_sin_cache.size(1), rot_dim);
STD_TORCH_CHECK(rope_cos_sin_cache.size(1) == rot_dim);
STD_TORCH_CHECK(rope_cos_sin_cache.scalar_type() == q_pe.scalar_type());
TORCH_CHECK_EQ(slot_mapping.size(0), num_tokens);
TORCH_CHECK_EQ(slot_mapping.scalar_type(), c10::ScalarType::Long);
STD_TORCH_CHECK(slot_mapping.size(0) == num_tokens);
STD_TORCH_CHECK(slot_mapping.scalar_type() ==
torch::headeronly::ScalarType::Long);
TORCH_CHECK_EQ(kv_cache.size(2), kv_lora_rank + rot_dim);
TORCH_CHECK_EQ(kv_cache.dim(), 3);
STD_TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + rot_dim);
STD_TORCH_CHECK(kv_cache.dim() == 3);
TORCH_CHECK_EQ(kv_cache_quant_scale.numel(), 1);
TORCH_CHECK_EQ(kv_cache_quant_scale.scalar_type(), c10::ScalarType::Float);
STD_TORCH_CHECK(kv_cache_quant_scale.numel() == 1);
STD_TORCH_CHECK(kv_cache_quant_scale.scalar_type() ==
torch::headeronly::ScalarType::Float);
int64_t q_pe_stride_token = q_pe.stride(0);
int64_t q_pe_stride_head = q_pe.stride(1);
@@ -286,9 +298,10 @@ void concat_and_cache_mla_rope_fused(
dim3 grid(num_tokens, 1, 1);
dim3 block(thread_block_size, 1, 1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(positions));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
positions.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.scalar_type(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA_ROPE_FUSED);
}
@@ -17,11 +17,8 @@
#define NVFP4_ENABLE_ELTS16 1
#include "libtorch_stable/quantization/fp4/nvfp4_utils.cuh"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "dispatch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "libtorch_stable/torch_utils.h"
namespace vllm {
@@ -184,12 +181,13 @@ __global__ void reshape_and_cache_nvfp4_kernel(
// Receives key_cache/value_cache as kv_cache[:, 0] and kv_cache[:, 1].
// Each KV side contains both data and scale:
// page = [K_data | K_scale | V_data | V_scale]
void reshape_and_cache_nvfp4_dispatch(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
torch::Tensor& k_scale,
torch::Tensor& v_scale) {
void reshape_and_cache_nvfp4_dispatch(torch::stable::Tensor& key,
torch::stable::Tensor& value,
torch::stable::Tensor& key_cache,
torch::stable::Tensor& value_cache,
torch::stable::Tensor& slot_mapping,
torch::stable::Tensor& k_scale,
torch::stable::Tensor& v_scale) {
int num_tokens = slot_mapping.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
@@ -200,17 +198,18 @@ void reshape_and_cache_nvfp4_dispatch(torch::Tensor& key, torch::Tensor& value,
// key_cache is kv_cache[:, 0] with shape
// [num_blocks, block_size, num_heads, full_dim] in logical order.
// Strides encode the physical layout (HND or NHD).
TORCH_CHECK(key_cache.dim() == 4, "key_cache must be 4D");
TORCH_CHECK(key_cache.size(3) == full_dim,
STD_TORCH_CHECK(key_cache.dim() == 4, "key_cache must be 4D");
STD_TORCH_CHECK(key_cache.size(3) == full_dim,
"key_cache last dim must be data_dim + scale_dim, got ",
key_cache.size(3), " expected ", full_dim);
int block_size = key_cache.size(1);
TORCH_CHECK(head_size % 16 == 0,
STD_TORCH_CHECK(head_size % 16 == 0,
"head_size must be divisible by 16 for NVFP4 KV cache");
TORCH_CHECK(block_size % 4 == 0,
"block_size must be divisible by 4 for NVFP4 KV cache swizzle");
STD_TORCH_CHECK(block_size % 4 == 0,
"block_size must be divisible by 4 for NVFP4 KV cache "
"swizzle");
// Detect physical layout from strides (based on full_dim).
// HND: head stride > block_offset stride.
@@ -230,8 +229,9 @@ void reshape_and_cache_nvfp4_dispatch(torch::Tensor& key, torch::Tensor& value,
// Scale follows data within each KV side.
int64_t data_per_kv = (int64_t)num_heads * block_size * data_dim;
uint8_t* key_scale_ptr = key_cache.data_ptr<uint8_t>() + data_per_kv;
uint8_t* value_scale_ptr = value_cache.data_ptr<uint8_t>() + data_per_kv;
uint8_t* key_scale_ptr = key_cache.mutable_data_ptr<uint8_t>() + data_per_kv;
uint8_t* value_scale_ptr =
value_cache.mutable_data_ptr<uint8_t>() + data_per_kv;
// Scale strides: same page stride, inner strides from layout.
int64_t scale_block_stride = data_block_stride;
@@ -244,8 +244,8 @@ void reshape_and_cache_nvfp4_dispatch(torch::Tensor& key, torch::Tensor& value,
scale_block_offset_stride = (int64_t)num_heads * scale_dim;
}
const float* k_scale_ptr = k_scale.data_ptr<float>();
const float* v_scale_ptr = v_scale.data_ptr<float>();
const float* k_scale_ptr = k_scale.const_data_ptr<float>();
const float* v_scale_ptr = v_scale.const_data_ptr<float>();
int groups_per_head = head_size / CVT_FP4_SF_VEC_SIZE;
int total_groups = num_heads * groups_per_head;
@@ -256,20 +256,22 @@ void reshape_and_cache_nvfp4_dispatch(torch::Tensor& key, torch::Tensor& value,
dim3 grid(num_tokens);
dim3 block(num_threads);
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
key.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
AT_DISPATCH_REDUCED_FLOATING_TYPES(
VLLM_STABLE_DISPATCH_HALF_TYPES(
key.scalar_type(), "reshape_and_cache_nvfp4", [&] {
vllm::reshape_and_cache_nvfp4_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<uint8_t>(), value_cache.data_ptr<uint8_t>(),
key_scale_ptr, value_scale_ptr,
slot_mapping.data_ptr<int64_t>(), k_scale_ptr, v_scale_ptr,
key.stride(0), value.stride(0), num_heads, head_size,
block_size, data_block_stride, data_head_stride,
data_block_offset_stride, scale_block_stride, scale_head_stride,
scale_block_offset_stride);
key.const_data_ptr<scalar_t>(),
value.const_data_ptr<scalar_t>(),
key_cache.mutable_data_ptr<uint8_t>(),
value_cache.mutable_data_ptr<uint8_t>(), key_scale_ptr,
value_scale_ptr, slot_mapping.const_data_ptr<int64_t>(),
k_scale_ptr, v_scale_ptr, key.stride(0), value.stride(0),
num_heads, head_size, block_size, data_block_stride,
data_head_stride, data_block_offset_stride, scale_block_stride,
scale_head_stride, scale_block_offset_stride);
});
}
+129
View File
@@ -355,3 +355,132 @@ torch::stable::Tensor ggml_moe_a8_vec(torch::stable::Tensor X,
int64_t tokens);
int64_t ggml_moe_get_block_size(int64_t type);
void paged_attention_v1(
torch::stable::Tensor& out, torch::stable::Tensor& query,
torch::stable::Tensor& key_cache, torch::stable::Tensor& value_cache,
int64_t num_kv_heads, double scale, torch::stable::Tensor& block_tables,
torch::stable::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len,
const std::optional<torch::stable::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::stable::Tensor& k_scale,
torch::stable::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2(
torch::stable::Tensor& out, torch::stable::Tensor& exp_sums,
torch::stable::Tensor& max_logits, torch::stable::Tensor& tmp_out,
torch::stable::Tensor& query, torch::stable::Tensor& key_cache,
torch::stable::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::stable::Tensor& block_tables, torch::stable::Tensor& seq_lens,
int64_t block_size, int64_t max_seq_len,
const std::optional<torch::stable::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::stable::Tensor& k_scale,
torch::stable::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
// Cache ops (shared CUDA/ROCm)
void swap_blocks(torch::stable::Tensor& src, torch::stable::Tensor& dst,
int64_t block_size_in_bytes,
const torch::stable::Tensor& block_mapping);
// Batch swap: submit all block copies in a single driver call.
void swap_blocks_batch(const torch::stable::Tensor& src_ptrs,
const torch::stable::Tensor& dst_ptrs,
const torch::stable::Tensor& sizes,
bool is_src_access_order_any);
void reshape_and_cache(torch::stable::Tensor& key, torch::stable::Tensor& value,
torch::stable::Tensor& key_cache,
torch::stable::Tensor& value_cache,
torch::stable::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::stable::Tensor& k_scale,
torch::stable::Tensor& v_scale);
void reshape_and_cache_flash(
torch::stable::Tensor& key, torch::stable::Tensor& value,
torch::stable::Tensor& key_cache, torch::stable::Tensor& value_cache,
torch::stable::Tensor& slot_mapping, const std::string& kv_cache_dtype,
torch::stable::Tensor& k_scale, torch::stable::Tensor& v_scale);
void concat_and_cache_mla(torch::stable::Tensor& kv_c,
torch::stable::Tensor& k_pe,
torch::stable::Tensor& kv_cache,
torch::stable::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::stable::Tensor& scale);
// NOTE: k_pe and kv_c order is flipped compared to concat_and_cache_mla
void concat_and_cache_mla_rope_fused(
torch::stable::Tensor& positions, torch::stable::Tensor& q_pe,
torch::stable::Tensor& k_pe, torch::stable::Tensor& kv_c,
torch::stable::Tensor& rope_cos_sin_cache, bool rope_is_neox,
torch::stable::Tensor& slot_mapping, torch::stable::Tensor& kv_cache,
const std::string& kv_cache_dtype,
torch::stable::Tensor& kv_cache_quant_scale);
// Just for unittest
void convert_fp8(torch::stable::Tensor& dst_cache,
torch::stable::Tensor& src_cache, const double scale,
const std::string& kv_cache_dtype);
void gather_and_maybe_dequant_cache(
torch::stable::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRIES...]
torch::stable::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::stable::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::stable::Tensor const& cu_seq_lens, // [BATCH+1]
torch::stable::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
int64_t num_tokens, const std::string& kv_cache_dtype,
torch::stable::Tensor const& scale,
std::optional<torch::stable::Tensor> seq_starts = std::nullopt);
// TODO(hc): cp_gather_cache need support scaled kvcahe in the future.
void cp_gather_cache(
torch::stable::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRIES...]
torch::stable::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::stable::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::stable::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size,
std::optional<torch::stable::Tensor> seq_starts = std::nullopt);
// Gather and upconvert FP8 KV cache to BF16 workspace
void cp_gather_and_upconvert_fp8_kv_cache(
torch::stable::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// 656]
torch::stable::Tensor const& dst, // [TOT_TOKENS, 576]
torch::stable::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::stable::Tensor const& seq_lens, // [BATCH]
torch::stable::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size);
// Indexer K quantization and cache function
void indexer_k_quant_and_cache(
torch::stable::Tensor& k, // [num_tokens, head_dim]
torch::stable::Tensor& kv_cache, // [num_blocks, block_size,
// cache_stride]
torch::stable::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt);
// Concatenate query nope and rope for MLA/DSA attention
void concat_mla_q(
torch::stable::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
torch::stable::Tensor& q_pe, // [num_tokens, num_heads, rope_dim]
torch::stable::Tensor& q_out); // [num_tokens, num_heads, nope_dim +
// rope_dim]
// Extract function to gather quantized K cache
void cp_gather_indexer_k_quant_cache(
const torch::stable::Tensor& kv_cache, // [num_blocks, block_size,
// cache_stride]
torch::stable::Tensor& dst_k, // [num_tokens, head_dim]
torch::stable::Tensor& dst_scale, // [num_tokens, head_dim /
// quant_block_size * 4]
const torch::stable::Tensor& block_table, // [batch_size, num_blocks]
const torch::stable::Tensor& cu_seq_lens); // [batch_size + 1]
@@ -17,7 +17,7 @@
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include "../../cuda_vec_utils.cuh"
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
@@ -27,7 +27,7 @@
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include "../../cuda_vec_utils.cuh"
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
@@ -17,7 +17,7 @@
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include "../../cuda_vec_utils.cuh"
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
@@ -23,7 +23,7 @@
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include "../../cuda_vec_utils.cuh"
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
@@ -20,7 +20,7 @@
#include <cuda_fp8.h>
#include <utility>
#include "cuda_vec_utils.cuh"
#include "../../cuda_vec_utils.cuh"
#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
CUDA_VERSION >= 12090
+141
View File
@@ -474,6 +474,33 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"Tensor? initial_state_idx,"
"Tensor? cu_chunk_seqlen,"
"Tensor? last_chunk_indices) -> ()");
// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
ops.def(
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
}
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
@@ -581,6 +608,9 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("ggml_moe_a8", TORCH_BOX(&ggml_moe_a8));
ops.impl("ggml_moe_a8_vec", TORCH_BOX(&ggml_moe_a8_vec));
ops.impl("selective_scan_fwd", TORCH_BOX(&selective_scan_fwd));
ops.impl("paged_attention_v1", TORCH_BOX(&paged_attention_v1));
ops.impl("paged_attention_v2", TORCH_BOX(&paged_attention_v2));
}
// These capability-check functions take only primitive args (no tensors), so
@@ -603,4 +633,115 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
ops.impl("ggml_moe_get_block_size", TORCH_BOX(&ggml_moe_get_block_size));
}
// Cache ops
STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) {
// Swap in (out) the cache blocks from src to dst.
ops.def(
"swap_blocks(Tensor src, Tensor! dst,"
" int block_size_in_bytes, Tensor block_mapping) -> ()");
// Batch swap: submit all block copies in a single driver call.
ops.def(
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
" Tensor sizes,"
" bool is_src_access_order_any=False) -> ()");
// Reshape the key and value tensors and cache them.
ops.def(
"reshape_and_cache(Tensor key, Tensor value,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
// Reshape the key and value tensors and cache them.
ops.def(
"reshape_and_cache_flash(Tensor key, Tensor value,"
" Tensor! key_cache,"
" Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
// Concat kv_c and k_pe and cache them.
ops.def(
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
// Rotate Q and K, then write to kv cache for MLA
ops.def(
"concat_and_cache_mla_rope_fused("
" Tensor positions,"
" Tensor! q_pe,"
" Tensor! k_pe,"
" Tensor kv_c,"
" Tensor cos_sin_cache,"
" bool is_neox,"
" Tensor slot_mapping,"
" Tensor! kv_cache,"
" str kv_cache_dtype,"
" Tensor kv_cache_scale) -> ()");
// Convert the key and value cache to fp8 data type.
ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
"str kv_cache_dtype) -> ()");
// Gather cache blocks from src_cache to dst, dequantizing from
// src_cache's dtype to dst's dtype if necessary.
ops.def(
"gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
" Tensor block_table, Tensor cu_seq_lens, "
" Tensor token_to_seq, "
" int num_tokens, "
" str kv_cache_dtype, "
" Tensor scale, Tensor? seq_starts) -> ()");
ops.def(
"cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
ops.def(
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
"batch_size) -> ()");
ops.def(
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
"int quant_block_size, str kv_cache_dtype) -> ()");
ops.def("concat_mla_q(Tensor ql_nope, Tensor q_pe, Tensor! q_out) -> ()");
ops.def(
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
}
STABLE_TORCH_LIBRARY_IMPL(_C_cache_ops, CPU, ops) {
ops.impl("swap_blocks_batch", TORCH_BOX(&swap_blocks_batch));
}
STABLE_TORCH_LIBRARY_IMPL(_C_cache_ops, CUDA, ops) {
ops.impl("swap_blocks", TORCH_BOX(&swap_blocks));
ops.impl("reshape_and_cache", TORCH_BOX(&reshape_and_cache));
ops.impl("reshape_and_cache_flash", TORCH_BOX(&reshape_and_cache_flash));
ops.impl("concat_and_cache_mla", TORCH_BOX(&concat_and_cache_mla));
ops.impl("concat_and_cache_mla_rope_fused",
TORCH_BOX(&concat_and_cache_mla_rope_fused));
ops.impl("convert_fp8", TORCH_BOX(&convert_fp8));
ops.impl("gather_and_maybe_dequant_cache",
TORCH_BOX(&gather_and_maybe_dequant_cache));
ops.impl("cp_gather_cache", TORCH_BOX(&cp_gather_cache));
ops.impl("cp_gather_and_upconvert_fp8_kv_cache",
TORCH_BOX(&cp_gather_and_upconvert_fp8_kv_cache));
ops.impl("indexer_k_quant_and_cache", TORCH_BOX(&indexer_k_quant_and_cache));
ops.impl("concat_mla_q", TORCH_BOX(&concat_mla_q));
ops.impl("cp_gather_indexer_k_quant_cache",
TORCH_BOX(&cp_gather_indexer_k_quant_cache));
}
REGISTER_EXTENSION(_C_stable_libtorch)
-23
View File
@@ -31,29 +31,6 @@ torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
return new_tensor;
}
void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
// rms_norm and fused_add_rms_norm declarations also exist in
// csrc/libtorch_stable/ops.h (torch::stable ABI for CUDA). They remain here
// because the CPU build still uses these torch::Tensor declarations.
+12 -9
View File
@@ -6,6 +6,7 @@
#include <hip/hip_bfloat16.h>
#include "../../../../attention/attention_dtypes.h"
#include <torch/headeronly/core/ScalarType.h>
namespace vllm {
#ifdef USE_ROCM
@@ -642,27 +643,29 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
vllm::Fp8KVCacheDataType KV_CACHE_DTYPE = \
vllm::get_fp8_kv_cache_data_type(KV_DTYPE); \
if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == torch::headeronly::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::Half) { \
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::BFloat16) { \
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
STD_TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E4M3) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == torch::headeronly::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
STD_TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
STD_TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
}
} // namespace fp8
@@ -1,6 +1,7 @@
#pragma once
#include "../../../../attention/attention_dtypes.h"
#include <torch/headeronly/core/ScalarType.h>
#include <assert.h>
#include <float.h>
#include <stdint.h>
@@ -546,37 +547,40 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
vllm::Fp8KVCacheDataType KV_CACHE_DTYPE = \
vllm::get_fp8_kv_cache_data_type(KV_DTYPE); \
if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == torch::headeronly::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::Half) { \
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::BFloat16) { \
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
STD_TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E4M3) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == torch::headeronly::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
STD_TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E5M2) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == torch::headeronly::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
STD_TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
STD_TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
}
} // namespace fp8
+4 -138
View File
@@ -1,4 +1,7 @@
#include "cache.h"
// Provides torch::Tensor for ops.h (previously included transitively via
// cache.h, which is no longer included here after cache ops moved to
// _C_stable_libtorch).
#include <torch/all.h>
#include "cuda_utils.h"
#include "ops.h"
#include "core/registration.h"
@@ -33,35 +36,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
&get_cuda_view_from_cpu_tensor);
// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
ops.def(
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
ops.def(
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
@@ -217,114 +191,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Cache ops
// Swap in (out) the cache blocks from src to dst.
cache_ops.def(
"swap_blocks(Tensor src, Tensor! dst,"
" int block_size_in_bytes, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
// Batch swap: submit all block copies in a single driver call.
cache_ops.def(
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
" Tensor sizes,"
" bool is_src_access_order_any=False) -> ()");
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache_flash(Tensor key, Tensor value,"
" Tensor! key_cache,"
" Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
&reshape_and_cache_flash);
// Concat kv_c and k_pe and cache them.
cache_ops.def(
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
// Rotate Q and K, then write to kv cache for MLA
cache_ops.def(
"concat_and_cache_mla_rope_fused("
" Tensor positions,"
" Tensor! q_pe,"
" Tensor! k_pe,"
" Tensor kv_c,"
" Tensor cos_sin_cache,"
" bool is_neox,"
" Tensor slot_mapping,"
" Tensor! kv_cache,"
" str kv_cache_dtype,"
" Tensor kv_cache_scale) -> ()");
cache_ops.impl("concat_and_cache_mla_rope_fused", torch::kCUDA,
&concat_and_cache_mla_rope_fused);
// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
"str kv_cache_dtype) -> ()");
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
// Gather cache blocks from src_cache to dst, dequantizing from
// src_cache's dtype to dst's dtype if necessary.
cache_ops.def(
"gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
" Tensor block_table, Tensor cu_seq_lens, "
" Tensor token_to_seq, "
" int num_tokens, "
" str kv_cache_dtype, "
" Tensor scale, Tensor? seq_starts) -> ()");
cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
&gather_and_maybe_dequant_cache);
cache_ops.def(
"cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
cache_ops.def(
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
"batch_size) -> ()");
cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
&cp_gather_and_upconvert_fp8_kv_cache);
cache_ops.def(
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
"int quant_block_size, str kv_cache_dtype) -> ()");
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
&indexer_k_quant_and_cache);
cache_ops.def(
"concat_mla_q(Tensor ql_nope, Tensor q_pe, Tensor! q_out) -> ()");
cache_ops.impl("concat_mla_q", torch::kCUDA, &concat_mla_q);
cache_ops.def(
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
cache_ops.impl("cp_gather_indexer_k_quant_cache", torch::kCUDA,
&cp_gather_indexer_k_quant_cache);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
// Cuda utils