[9/n] Migrate attention and cache kernels to torch stable ABI (continued) (#43717)

Signed-off-by: Chris Leonard <chleonar@redhat.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Co-authored-by: Shengqi Chen <harry-chen@outlook.com>
This commit is contained in:
Chris Leonard
2026-05-29 00:44:45 -04:00
committed by GitHub
parent 710f077617
commit 22a58640b4
23 changed files with 909 additions and 720 deletions
+9 -11
View File
@@ -305,10 +305,6 @@ endif()
#
set(VLLM_EXT_SRC
"csrc/cache_kernels.cu"
"csrc/cache_kernels_fused.cu"
"csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu"
"csrc/cuda_view.cu"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/quantization/activation_kernels.cu"
@@ -647,7 +643,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/attention/merge_attn_states.cu"
"csrc/libtorch_stable/sampler.cu"
"csrc/libtorch_stable/topk.cu"
"csrc/libtorch_stable/mamba/selective_scan_fwd.cu")
"csrc/libtorch_stable/mamba/selective_scan_fwd.cu"
"csrc/libtorch_stable/attention/paged_attention_v1.cu"
"csrc/libtorch_stable/attention/paged_attention_v2.cu"
"csrc/libtorch_stable/cache_kernels.cu"
"csrc/libtorch_stable/cache_kernels_fused.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
@@ -924,13 +924,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
# nvfp4_kv_cache_kernels uses non-stable torch API and is called directly
# from cache_kernels.cu, so it belongs in _C rather than _C_stable.
set(NVFP4_KV_SRC "csrc/nvfp4_kv_cache_kernels.cu")
set(NVFP4_KV_SRC "csrc/libtorch_stable/nvfp4_kv_cache_kernels.cu")
set_gencode_flags_for_srcs(
SRCS "${NVFP4_KV_SRC}"
CUDA_ARCHS "${FP4_ARCHS}")
target_sources(_C PRIVATE ${NVFP4_KV_SRC})
list(APPEND VLLM_STABLE_EXT_SRC "${NVFP4_KV_SRC}")
target_compile_definitions(_C PRIVATE ENABLE_NVFP4_SM120=1)
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
@@ -960,11 +958,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
set(NVFP4_KV_SRC "csrc/nvfp4_kv_cache_kernels.cu")
set(NVFP4_KV_SRC "csrc/libtorch_stable/nvfp4_kv_cache_kernels.cu")
set_gencode_flags_for_srcs(
SRCS "${NVFP4_KV_SRC}"
CUDA_ARCHS "${FP4_ARCHS}")
target_sources(_C PRIVATE ${NVFP4_KV_SRC})
list(APPEND VLLM_STABLE_EXT_SRC "${NVFP4_KV_SRC}")
target_compile_definitions(_C PRIVATE ENABLE_NVFP4_SM100=1)
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
+1 -1
View File
@@ -4,7 +4,7 @@
#include <cmath>
#include "../cuda_compat.h"
#include "../cuda_vec_utils.cuh"
#include "cuda_vec_utils.cuh"
#include "dispatch_utils.h"
#include "torch_utils.h"
@@ -17,21 +17,18 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "../../attention/attention_dtypes.h"
#include "attention_utils.cuh"
#include "../cuda_compat.h"
#include "../../cuda_compat.h"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/w8a8/fp8/amd/quant_utils.cuh"
#include "../../quantization/w8a8/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#include "../../quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
@@ -18,8 +18,8 @@
*/
#pragma once
#include "../cuda_compat.h"
#include "attention_dtypes.h"
#include "../../cuda_compat.h"
#include "../../attention/attention_dtypes.h"
#include <float.h>
#include <type_traits>
@@ -7,7 +7,7 @@
#include <torch/headeronly/core/ScalarType.h>
#include "../../attention/attention_dtypes.h"
#include "../../attention/attention_utils.cuh"
#include "attention_utils.cuh"
#include "../../quantization/w8a8/fp8/common.cuh"
namespace vllm {
@@ -16,8 +16,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../torch_utils.h"
#include "attention_kernels.cuh"
#include "../cuda_compat.h"
#include "../../cuda_compat.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -44,13 +45,15 @@ template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
torch::stable::Tensor& out, torch::stable::Tensor& query,
torch::stable::Tensor& key_cache, torch::stable::Tensor& value_cache,
int num_kv_heads, float scale, torch::stable::Tensor& block_tables,
torch::stable::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::stable::Tensor>& alibi_slopes,
torch::stable::Tensor& k_scale, torch::stable::Tensor& v_scale,
const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -69,8 +72,8 @@ void paged_attention_v1_launcher(
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
int* block_tables_ptr = block_tables.mutable_data_ptr<int>();
int* seq_lens_ptr = seq_lens.mutable_data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
@@ -85,8 +88,9 @@ void paged_attention_v1_launcher(
dim3 grid(num_heads, num_seqs, 1);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
query.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
@@ -119,7 +123,7 @@ void paged_attention_v1_launcher(
LAUNCH_PAGED_ATTENTION_V1(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
STD_TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
@@ -141,43 +145,43 @@ void paged_attention_v1_launcher(
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
STD_TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v1(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
torch::stable::Tensor& out, // [num_seqs, num_heads, head_size]
torch::stable::Tensor& query, // [num_seqs, num_heads, head_size]
torch::stable::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
torch::stable::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
torch::stable::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::stable::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const std::optional<torch::stable::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::stable::Tensor& k_scale,
torch::stable::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
DISPATCH_BY_KV_CACHE_DTYPE(query.scalar_type(), kv_cache_dtype,
CALL_V1_LAUNCHER_BLOCK_SIZE)
}
@@ -16,8 +16,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../torch_utils.h"
#include "attention_kernels.cuh"
#include "../cuda_compat.h"
#include "../../cuda_compat.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -44,14 +45,16 @@ template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 128, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
torch::stable::Tensor& out, torch::stable::Tensor& exp_sums,
torch::stable::Tensor& max_logits, torch::stable::Tensor& tmp_out,
torch::stable::Tensor& query, torch::stable::Tensor& key_cache,
torch::stable::Tensor& value_cache, int num_kv_heads, float scale,
torch::stable::Tensor& block_tables, torch::stable::Tensor& seq_lens,
int max_seq_len, const std::optional<torch::stable::Tensor>& alibi_slopes,
torch::stable::Tensor& k_scale, torch::stable::Tensor& v_scale,
const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -73,8 +76,8 @@ void paged_attention_v2_launcher(
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
int* block_tables_ptr = block_tables.mutable_data_ptr<int>();
int* seq_lens_ptr = seq_lens.mutable_data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
@@ -91,8 +94,9 @@ void paged_attention_v2_launcher(
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
query.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
@@ -125,7 +129,7 @@ void paged_attention_v2_launcher(
LAUNCH_PAGED_ATTENTION_V2(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
STD_TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
@@ -148,46 +152,48 @@ void paged_attention_v2_launcher(
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
STD_TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
torch::stable::Tensor& out, // [num_seqs, num_heads, head_size]
torch::stable::Tensor&
exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::stable::Tensor&
max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::stable::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
torch::stable::Tensor& query, // [num_seqs, num_heads, head_size]
torch::stable::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
torch::stable::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
torch::stable::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::stable::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const std::optional<torch::stable::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::stable::Tensor& k_scale,
torch::stable::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
DISPATCH_BY_KV_CACHE_DTYPE(query.scalar_type(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
}
File diff suppressed because it is too large Load Diff
@@ -1,15 +1,13 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "torch_utils.h"
#include "dispatch_utils.h"
#include "quantization/w8a8/fp8/common.cuh"
#include "../cuda_compat.h"
#include "../quantization/w8a8/fp8/common.cuh"
#ifdef USE_ROCM
#include "quantization/w8a8/fp8/amd/quant_utils.cuh"
#include "../quantization/w8a8/fp8/amd/quant_utils.cuh"
#else
#include "quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#endif
#ifdef USE_ROCM
@@ -164,43 +162,52 @@ __global__ void concat_and_cache_mla_rope_fused_kernel(
} // namespace vllm
#define CALL_CONCAT_AND_CACHE_MLA_ROPE_FUSED(RAW_KV_T, CACHE_T, KV_DTYPE) \
do { \
VLLM_DISPATCH_FLOATING_TYPES(q_pe.scalar_type(), "qk_scalar_type", [&] { \
using qk_t = scalar_t; \
VLLM_DISPATCH_FLOATING_TYPES( \
rope_cos_sin_cache.scalar_type(), "rope_cos_sin_cache_scalar_type", \
[&] { \
using cos_sin_t = scalar_t; \
if (rope_is_neox) { \
vllm::concat_and_cache_mla_rope_fused_kernel< \
qk_t, cos_sin_t, true, RAW_KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
positions.data_ptr<int64_t>(), q_pe.data_ptr<qk_t>(), \
k_pe.data_ptr<qk_t>(), kv_c.data_ptr<qk_t>(), \
rope_cos_sin_cache.data_ptr<cos_sin_t>(), rot_dim, \
q_pe_stride_token, q_pe_stride_head, k_pe_stride, \
kv_c_stride, num_q_heads, \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, \
entry_stride, kv_lora_rank, block_size, \
kv_cache_quant_scale.data_ptr<float>()); \
} else { \
vllm::concat_and_cache_mla_rope_fused_kernel< \
qk_t, cos_sin_t, false, RAW_KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
positions.data_ptr<int64_t>(), q_pe.data_ptr<qk_t>(), \
k_pe.data_ptr<qk_t>(), kv_c.data_ptr<qk_t>(), \
rope_cos_sin_cache.data_ptr<cos_sin_t>(), rot_dim, \
q_pe_stride_token, q_pe_stride_head, k_pe_stride, \
kv_c_stride, num_q_heads, \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, \
entry_stride, kv_lora_rank, block_size, \
kv_cache_quant_scale.data_ptr<float>()); \
} \
}); \
}); \
#define CALL_CONCAT_AND_CACHE_MLA_ROPE_FUSED(RAW_KV_T, CACHE_T, KV_DTYPE) \
do { \
VLLM_STABLE_DISPATCH_FLOATING_TYPES( \
q_pe.scalar_type(), "qk_scalar_type", [&] { \
using qk_t = scalar_t; \
VLLM_STABLE_DISPATCH_FLOATING_TYPES( \
rope_cos_sin_cache.scalar_type(), \
"rope_cos_sin_cache_scalar_type", [&] { \
using cos_sin_t = scalar_t; \
if (rope_is_neox) { \
vllm::concat_and_cache_mla_rope_fused_kernel< \
qk_t, cos_sin_t, true, RAW_KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
positions.const_data_ptr<int64_t>(), \
q_pe.mutable_data_ptr<qk_t>(), \
k_pe.mutable_data_ptr<qk_t>(), \
kv_c.const_data_ptr<qk_t>(), \
rope_cos_sin_cache.const_data_ptr<cos_sin_t>(), \
rot_dim, q_pe_stride_token, q_pe_stride_head, \
k_pe_stride, kv_c_stride, num_q_heads, \
reinterpret_cast<CACHE_T*>( \
kv_cache.mutable_data_ptr()), \
slot_mapping.const_data_ptr<int64_t>(), \
block_stride, entry_stride, kv_lora_rank, \
block_size, \
kv_cache_quant_scale.const_data_ptr<float>()); \
} else { \
vllm::concat_and_cache_mla_rope_fused_kernel< \
qk_t, cos_sin_t, false, RAW_KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
positions.const_data_ptr<int64_t>(), \
q_pe.mutable_data_ptr<qk_t>(), \
k_pe.mutable_data_ptr<qk_t>(), \
kv_c.const_data_ptr<qk_t>(), \
rope_cos_sin_cache.const_data_ptr<cos_sin_t>(), \
rot_dim, q_pe_stride_token, q_pe_stride_head, \
k_pe_stride, kv_c_stride, num_q_heads, \
reinterpret_cast<CACHE_T*>( \
kv_cache.mutable_data_ptr()), \
slot_mapping.const_data_ptr<int64_t>(), \
block_stride, entry_stride, kv_lora_rank, \
block_size, \
kv_cache_quant_scale.const_data_ptr<float>()); \
} \
}); \
}); \
} while (false)
// Executes RoPE on q_pe and k_pe, then writes k_pe and kv_c in the kv cache.
@@ -208,64 +215,69 @@ __global__ void concat_and_cache_mla_rope_fused_kernel(
// Replaces DeepseekScalingRotaryEmbedding.self.rotary_emb and
// concat_and_cache_mla.
void concat_and_cache_mla_rope_fused(
torch::Tensor& positions, // [num_tokens]
torch::Tensor& q_pe, // [num_tokens, num_q_heads, rot_dim]
torch::Tensor& k_pe, // [num_tokens, rot_dim]
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& rope_cos_sin_cache, // [max_position, rot_dim]
torch::stable::Tensor& positions, // [num_tokens]
torch::stable::Tensor& q_pe, // [num_tokens, num_q_heads, rot_dim]
torch::stable::Tensor& k_pe, // [num_tokens, rot_dim]
torch::stable::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::stable::Tensor& rope_cos_sin_cache, // [max_position, rot_dim]
bool rope_is_neox,
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
torch::Tensor&
torch::stable::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
torch::stable::Tensor&
kv_cache, // [num_blocks, block_size, (kv_lora_rank + rot_dim)]
const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale) {
const std::string& kv_cache_dtype,
torch::stable::Tensor& kv_cache_quant_scale) {
// NOTE(woosuk): In vLLM V1, query/key/position.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
// both include padding.
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
// since key includes padding for CUDA graphs, while slot_mapping does not.
// In this case, slot_mapping.size(0) represents the actual number of tokens
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0)
// because both include padding.
// In vLLM V1, however, key.size(0) can be larger than
// slot_mapping.size(0) since key includes padding for CUDA graphs,
// while slot_mapping does not. In this case,
// slot_mapping.size(0) represents the actual number of tokens
// before padding.
// For compatibility with both cases, we use slot_mapping.size(0) as the
// number of tokens.
int num_tokens = slot_mapping.size(0);
int num_padded_tokens = q_pe.size(0);
TORCH_CHECK_GE(num_padded_tokens, num_tokens);
// For compatibility with both cases, we use slot_mapping.size(0) as
// the number of tokens.
const int64_t num_tokens = slot_mapping.size(0);
const int64_t num_padded_tokens = q_pe.size(0);
STD_TORCH_CHECK(num_padded_tokens >= num_tokens);
const int num_q_heads = q_pe.size(1);
const int rot_dim = q_pe.size(2);
const int kv_lora_rank = kv_c.size(1);
TORCH_CHECK_EQ(positions.size(0), num_padded_tokens);
TORCH_CHECK_EQ(positions.dim(), 1);
TORCH_CHECK_EQ(positions.scalar_type(), c10::ScalarType::Long);
STD_TORCH_CHECK(positions.size(0) == num_padded_tokens);
STD_TORCH_CHECK(positions.dim() == 1);
STD_TORCH_CHECK(positions.scalar_type() ==
torch::headeronly::ScalarType::Long);
TORCH_CHECK_EQ(q_pe.dim(), 3);
TORCH_CHECK_EQ(q_pe.size(0), num_padded_tokens);
TORCH_CHECK_EQ(q_pe.size(1), num_q_heads);
TORCH_CHECK_EQ(q_pe.size(2), rot_dim);
STD_TORCH_CHECK(q_pe.dim() == 3);
STD_TORCH_CHECK(q_pe.size(0) == num_padded_tokens);
STD_TORCH_CHECK(q_pe.size(1) == num_q_heads);
STD_TORCH_CHECK(q_pe.size(2) == rot_dim);
TORCH_CHECK_EQ(k_pe.dim(), 2);
TORCH_CHECK_EQ(k_pe.size(0), num_padded_tokens);
TORCH_CHECK_EQ(k_pe.size(1), rot_dim);
TORCH_CHECK_EQ(k_pe.scalar_type(), q_pe.scalar_type());
STD_TORCH_CHECK(k_pe.dim() == 2);
STD_TORCH_CHECK(k_pe.size(0) == num_padded_tokens);
STD_TORCH_CHECK(k_pe.size(1) == rot_dim);
STD_TORCH_CHECK(k_pe.scalar_type() == q_pe.scalar_type());
TORCH_CHECK_EQ(kv_c.dim(), 2);
TORCH_CHECK_EQ(kv_c.size(0), num_padded_tokens);
TORCH_CHECK_EQ(kv_c.size(1), kv_lora_rank);
TORCH_CHECK_EQ(kv_c.scalar_type(), q_pe.scalar_type());
TORCH_CHECK_EQ(kv_c.dtype(), q_pe.dtype());
STD_TORCH_CHECK(kv_c.dim() == 2);
STD_TORCH_CHECK(kv_c.size(0) == num_padded_tokens);
STD_TORCH_CHECK(kv_c.size(1) == kv_lora_rank);
STD_TORCH_CHECK(kv_c.scalar_type() == q_pe.scalar_type());
TORCH_CHECK_EQ(rope_cos_sin_cache.size(1), rot_dim);
STD_TORCH_CHECK(rope_cos_sin_cache.size(1) == rot_dim);
STD_TORCH_CHECK(rope_cos_sin_cache.scalar_type() == q_pe.scalar_type());
TORCH_CHECK_EQ(slot_mapping.size(0), num_tokens);
TORCH_CHECK_EQ(slot_mapping.scalar_type(), c10::ScalarType::Long);
STD_TORCH_CHECK(slot_mapping.size(0) == num_tokens);
STD_TORCH_CHECK(slot_mapping.scalar_type() ==
torch::headeronly::ScalarType::Long);
TORCH_CHECK_EQ(kv_cache.size(2), kv_lora_rank + rot_dim);
TORCH_CHECK_EQ(kv_cache.dim(), 3);
STD_TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + rot_dim);
STD_TORCH_CHECK(kv_cache.dim() == 3);
TORCH_CHECK_EQ(kv_cache_quant_scale.numel(), 1);
TORCH_CHECK_EQ(kv_cache_quant_scale.scalar_type(), c10::ScalarType::Float);
STD_TORCH_CHECK(kv_cache_quant_scale.numel() == 1);
STD_TORCH_CHECK(kv_cache_quant_scale.scalar_type() ==
torch::headeronly::ScalarType::Float);
int64_t q_pe_stride_token = q_pe.stride(0);
int64_t q_pe_stride_head = q_pe.stride(1);
@@ -286,9 +298,10 @@ void concat_and_cache_mla_rope_fused(
dim3 grid(num_tokens, 1, 1);
dim3 block(thread_block_size, 1, 1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(positions));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
positions.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.scalar_type(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA_ROPE_FUSED);
}
@@ -17,11 +17,8 @@
#define NVFP4_ENABLE_ELTS16 1
#include "libtorch_stable/quantization/fp4/nvfp4_utils.cuh"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "dispatch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "libtorch_stable/torch_utils.h"
namespace vllm {
@@ -184,12 +181,13 @@ __global__ void reshape_and_cache_nvfp4_kernel(
// Receives key_cache/value_cache as kv_cache[:, 0] and kv_cache[:, 1].
// Each KV side contains both data and scale:
// page = [K_data | K_scale | V_data | V_scale]
void reshape_and_cache_nvfp4_dispatch(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
torch::Tensor& k_scale,
torch::Tensor& v_scale) {
void reshape_and_cache_nvfp4_dispatch(torch::stable::Tensor& key,
torch::stable::Tensor& value,
torch::stable::Tensor& key_cache,
torch::stable::Tensor& value_cache,
torch::stable::Tensor& slot_mapping,
torch::stable::Tensor& k_scale,
torch::stable::Tensor& v_scale) {
int num_tokens = slot_mapping.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
@@ -200,17 +198,18 @@ void reshape_and_cache_nvfp4_dispatch(torch::Tensor& key, torch::Tensor& value,
// key_cache is kv_cache[:, 0] with shape
// [num_blocks, block_size, num_heads, full_dim] in logical order.
// Strides encode the physical layout (HND or NHD).
TORCH_CHECK(key_cache.dim() == 4, "key_cache must be 4D");
TORCH_CHECK(key_cache.size(3) == full_dim,
"key_cache last dim must be data_dim + scale_dim, got ",
key_cache.size(3), " expected ", full_dim);
STD_TORCH_CHECK(key_cache.dim() == 4, "key_cache must be 4D");
STD_TORCH_CHECK(key_cache.size(3) == full_dim,
"key_cache last dim must be data_dim + scale_dim, got ",
key_cache.size(3), " expected ", full_dim);
int block_size = key_cache.size(1);
TORCH_CHECK(head_size % 16 == 0,
"head_size must be divisible by 16 for NVFP4 KV cache");
TORCH_CHECK(block_size % 4 == 0,
"block_size must be divisible by 4 for NVFP4 KV cache swizzle");
STD_TORCH_CHECK(head_size % 16 == 0,
"head_size must be divisible by 16 for NVFP4 KV cache");
STD_TORCH_CHECK(block_size % 4 == 0,
"block_size must be divisible by 4 for NVFP4 KV cache "
"swizzle");
// Detect physical layout from strides (based on full_dim).
// HND: head stride > block_offset stride.
@@ -230,8 +229,9 @@ void reshape_and_cache_nvfp4_dispatch(torch::Tensor& key, torch::Tensor& value,
// Scale follows data within each KV side.
int64_t data_per_kv = (int64_t)num_heads * block_size * data_dim;
uint8_t* key_scale_ptr = key_cache.data_ptr<uint8_t>() + data_per_kv;
uint8_t* value_scale_ptr = value_cache.data_ptr<uint8_t>() + data_per_kv;
uint8_t* key_scale_ptr = key_cache.mutable_data_ptr<uint8_t>() + data_per_kv;
uint8_t* value_scale_ptr =
value_cache.mutable_data_ptr<uint8_t>() + data_per_kv;
// Scale strides: same page stride, inner strides from layout.
int64_t scale_block_stride = data_block_stride;
@@ -244,8 +244,8 @@ void reshape_and_cache_nvfp4_dispatch(torch::Tensor& key, torch::Tensor& value,
scale_block_offset_stride = (int64_t)num_heads * scale_dim;
}
const float* k_scale_ptr = k_scale.data_ptr<float>();
const float* v_scale_ptr = v_scale.data_ptr<float>();
const float* k_scale_ptr = k_scale.const_data_ptr<float>();
const float* v_scale_ptr = v_scale.const_data_ptr<float>();
int groups_per_head = head_size / CVT_FP4_SF_VEC_SIZE;
int total_groups = num_heads * groups_per_head;
@@ -256,20 +256,22 @@ void reshape_and_cache_nvfp4_dispatch(torch::Tensor& key, torch::Tensor& value,
dim3 grid(num_tokens);
dim3 block(num_threads);
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
key.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
AT_DISPATCH_REDUCED_FLOATING_TYPES(
VLLM_STABLE_DISPATCH_HALF_TYPES(
key.scalar_type(), "reshape_and_cache_nvfp4", [&] {
vllm::reshape_and_cache_nvfp4_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<uint8_t>(), value_cache.data_ptr<uint8_t>(),
key_scale_ptr, value_scale_ptr,
slot_mapping.data_ptr<int64_t>(), k_scale_ptr, v_scale_ptr,
key.stride(0), value.stride(0), num_heads, head_size,
block_size, data_block_stride, data_head_stride,
data_block_offset_stride, scale_block_stride, scale_head_stride,
scale_block_offset_stride);
key.const_data_ptr<scalar_t>(),
value.const_data_ptr<scalar_t>(),
key_cache.mutable_data_ptr<uint8_t>(),
value_cache.mutable_data_ptr<uint8_t>(), key_scale_ptr,
value_scale_ptr, slot_mapping.const_data_ptr<int64_t>(),
k_scale_ptr, v_scale_ptr, key.stride(0), value.stride(0),
num_heads, head_size, block_size, data_block_stride,
data_head_stride, data_block_offset_stride, scale_block_stride,
scale_head_stride, scale_block_offset_stride);
});
}
+129
View File
@@ -355,3 +355,132 @@ torch::stable::Tensor ggml_moe_a8_vec(torch::stable::Tensor X,
int64_t tokens);
int64_t ggml_moe_get_block_size(int64_t type);
void paged_attention_v1(
torch::stable::Tensor& out, torch::stable::Tensor& query,
torch::stable::Tensor& key_cache, torch::stable::Tensor& value_cache,
int64_t num_kv_heads, double scale, torch::stable::Tensor& block_tables,
torch::stable::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len,
const std::optional<torch::stable::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::stable::Tensor& k_scale,
torch::stable::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2(
torch::stable::Tensor& out, torch::stable::Tensor& exp_sums,
torch::stable::Tensor& max_logits, torch::stable::Tensor& tmp_out,
torch::stable::Tensor& query, torch::stable::Tensor& key_cache,
torch::stable::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::stable::Tensor& block_tables, torch::stable::Tensor& seq_lens,
int64_t block_size, int64_t max_seq_len,
const std::optional<torch::stable::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::stable::Tensor& k_scale,
torch::stable::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
// Cache ops (shared CUDA/ROCm)
void swap_blocks(torch::stable::Tensor& src, torch::stable::Tensor& dst,
int64_t block_size_in_bytes,
const torch::stable::Tensor& block_mapping);
// Batch swap: submit all block copies in a single driver call.
void swap_blocks_batch(const torch::stable::Tensor& src_ptrs,
const torch::stable::Tensor& dst_ptrs,
const torch::stable::Tensor& sizes,
bool is_src_access_order_any);
void reshape_and_cache(torch::stable::Tensor& key, torch::stable::Tensor& value,
torch::stable::Tensor& key_cache,
torch::stable::Tensor& value_cache,
torch::stable::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::stable::Tensor& k_scale,
torch::stable::Tensor& v_scale);
void reshape_and_cache_flash(
torch::stable::Tensor& key, torch::stable::Tensor& value,
torch::stable::Tensor& key_cache, torch::stable::Tensor& value_cache,
torch::stable::Tensor& slot_mapping, const std::string& kv_cache_dtype,
torch::stable::Tensor& k_scale, torch::stable::Tensor& v_scale);
void concat_and_cache_mla(torch::stable::Tensor& kv_c,
torch::stable::Tensor& k_pe,
torch::stable::Tensor& kv_cache,
torch::stable::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::stable::Tensor& scale);
// NOTE: k_pe and kv_c order is flipped compared to concat_and_cache_mla
void concat_and_cache_mla_rope_fused(
torch::stable::Tensor& positions, torch::stable::Tensor& q_pe,
torch::stable::Tensor& k_pe, torch::stable::Tensor& kv_c,
torch::stable::Tensor& rope_cos_sin_cache, bool rope_is_neox,
torch::stable::Tensor& slot_mapping, torch::stable::Tensor& kv_cache,
const std::string& kv_cache_dtype,
torch::stable::Tensor& kv_cache_quant_scale);
// Just for unittest
void convert_fp8(torch::stable::Tensor& dst_cache,
torch::stable::Tensor& src_cache, const double scale,
const std::string& kv_cache_dtype);
void gather_and_maybe_dequant_cache(
torch::stable::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRIES...]
torch::stable::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::stable::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::stable::Tensor const& cu_seq_lens, // [BATCH+1]
torch::stable::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
int64_t num_tokens, const std::string& kv_cache_dtype,
torch::stable::Tensor const& scale,
std::optional<torch::stable::Tensor> seq_starts = std::nullopt);
// TODO(hc): cp_gather_cache need support scaled kvcahe in the future.
void cp_gather_cache(
torch::stable::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRIES...]
torch::stable::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::stable::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::stable::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size,
std::optional<torch::stable::Tensor> seq_starts = std::nullopt);
// Gather and upconvert FP8 KV cache to BF16 workspace
void cp_gather_and_upconvert_fp8_kv_cache(
torch::stable::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// 656]
torch::stable::Tensor const& dst, // [TOT_TOKENS, 576]
torch::stable::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::stable::Tensor const& seq_lens, // [BATCH]
torch::stable::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size);
// Indexer K quantization and cache function
void indexer_k_quant_and_cache(
torch::stable::Tensor& k, // [num_tokens, head_dim]
torch::stable::Tensor& kv_cache, // [num_blocks, block_size,
// cache_stride]
torch::stable::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt);
// Concatenate query nope and rope for MLA/DSA attention
void concat_mla_q(
torch::stable::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
torch::stable::Tensor& q_pe, // [num_tokens, num_heads, rope_dim]
torch::stable::Tensor& q_out); // [num_tokens, num_heads, nope_dim +
// rope_dim]
// Extract function to gather quantized K cache
void cp_gather_indexer_k_quant_cache(
const torch::stable::Tensor& kv_cache, // [num_blocks, block_size,
// cache_stride]
torch::stable::Tensor& dst_k, // [num_tokens, head_dim]
torch::stable::Tensor& dst_scale, // [num_tokens, head_dim /
// quant_block_size * 4]
const torch::stable::Tensor& block_table, // [batch_size, num_blocks]
const torch::stable::Tensor& cu_seq_lens); // [batch_size + 1]
@@ -17,7 +17,7 @@
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include "../../cuda_vec_utils.cuh"
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
@@ -27,7 +27,7 @@
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include "../../cuda_vec_utils.cuh"
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
@@ -17,7 +17,7 @@
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include "../../cuda_vec_utils.cuh"
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
@@ -23,7 +23,7 @@
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include "../../cuda_vec_utils.cuh"
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
@@ -20,7 +20,7 @@
#include <cuda_fp8.h>
#include <utility>
#include "cuda_vec_utils.cuh"
#include "../../cuda_vec_utils.cuh"
#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
CUDA_VERSION >= 12090
+141
View File
@@ -474,6 +474,33 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"Tensor? initial_state_idx,"
"Tensor? cu_chunk_seqlen,"
"Tensor? last_chunk_indices) -> ()");
// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
ops.def(
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
}
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
@@ -581,6 +608,9 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("ggml_moe_a8", TORCH_BOX(&ggml_moe_a8));
ops.impl("ggml_moe_a8_vec", TORCH_BOX(&ggml_moe_a8_vec));
ops.impl("selective_scan_fwd", TORCH_BOX(&selective_scan_fwd));
ops.impl("paged_attention_v1", TORCH_BOX(&paged_attention_v1));
ops.impl("paged_attention_v2", TORCH_BOX(&paged_attention_v2));
}
// These capability-check functions take only primitive args (no tensors), so
@@ -603,4 +633,115 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
ops.impl("ggml_moe_get_block_size", TORCH_BOX(&ggml_moe_get_block_size));
}
// Cache ops
STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) {
// Swap in (out) the cache blocks from src to dst.
ops.def(
"swap_blocks(Tensor src, Tensor! dst,"
" int block_size_in_bytes, Tensor block_mapping) -> ()");
// Batch swap: submit all block copies in a single driver call.
ops.def(
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
" Tensor sizes,"
" bool is_src_access_order_any=False) -> ()");
// Reshape the key and value tensors and cache them.
ops.def(
"reshape_and_cache(Tensor key, Tensor value,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
// Reshape the key and value tensors and cache them.
ops.def(
"reshape_and_cache_flash(Tensor key, Tensor value,"
" Tensor! key_cache,"
" Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
// Concat kv_c and k_pe and cache them.
ops.def(
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
// Rotate Q and K, then write to kv cache for MLA
ops.def(
"concat_and_cache_mla_rope_fused("
" Tensor positions,"
" Tensor! q_pe,"
" Tensor! k_pe,"
" Tensor kv_c,"
" Tensor cos_sin_cache,"
" bool is_neox,"
" Tensor slot_mapping,"
" Tensor! kv_cache,"
" str kv_cache_dtype,"
" Tensor kv_cache_scale) -> ()");
// Convert the key and value cache to fp8 data type.
ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
"str kv_cache_dtype) -> ()");
// Gather cache blocks from src_cache to dst, dequantizing from
// src_cache's dtype to dst's dtype if necessary.
ops.def(
"gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
" Tensor block_table, Tensor cu_seq_lens, "
" Tensor token_to_seq, "
" int num_tokens, "
" str kv_cache_dtype, "
" Tensor scale, Tensor? seq_starts) -> ()");
ops.def(
"cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
ops.def(
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
"batch_size) -> ()");
ops.def(
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
"int quant_block_size, str kv_cache_dtype) -> ()");
ops.def("concat_mla_q(Tensor ql_nope, Tensor q_pe, Tensor! q_out) -> ()");
ops.def(
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
}
STABLE_TORCH_LIBRARY_IMPL(_C_cache_ops, CPU, ops) {
ops.impl("swap_blocks_batch", TORCH_BOX(&swap_blocks_batch));
}
STABLE_TORCH_LIBRARY_IMPL(_C_cache_ops, CUDA, ops) {
ops.impl("swap_blocks", TORCH_BOX(&swap_blocks));
ops.impl("reshape_and_cache", TORCH_BOX(&reshape_and_cache));
ops.impl("reshape_and_cache_flash", TORCH_BOX(&reshape_and_cache_flash));
ops.impl("concat_and_cache_mla", TORCH_BOX(&concat_and_cache_mla));
ops.impl("concat_and_cache_mla_rope_fused",
TORCH_BOX(&concat_and_cache_mla_rope_fused));
ops.impl("convert_fp8", TORCH_BOX(&convert_fp8));
ops.impl("gather_and_maybe_dequant_cache",
TORCH_BOX(&gather_and_maybe_dequant_cache));
ops.impl("cp_gather_cache", TORCH_BOX(&cp_gather_cache));
ops.impl("cp_gather_and_upconvert_fp8_kv_cache",
TORCH_BOX(&cp_gather_and_upconvert_fp8_kv_cache));
ops.impl("indexer_k_quant_and_cache", TORCH_BOX(&indexer_k_quant_and_cache));
ops.impl("concat_mla_q", TORCH_BOX(&concat_mla_q));
ops.impl("cp_gather_indexer_k_quant_cache",
TORCH_BOX(&cp_gather_indexer_k_quant_cache));
}
REGISTER_EXTENSION(_C_stable_libtorch)
-23
View File
@@ -31,29 +31,6 @@ torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
return new_tensor;
}
void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
// rms_norm and fused_add_rms_norm declarations also exist in
// csrc/libtorch_stable/ops.h (torch::stable ABI for CUDA). They remain here
// because the CPU build still uses these torch::Tensor declarations.
+12 -9
View File
@@ -6,6 +6,7 @@
#include <hip/hip_bfloat16.h>
#include "../../../../attention/attention_dtypes.h"
#include <torch/headeronly/core/ScalarType.h>
namespace vllm {
#ifdef USE_ROCM
@@ -642,27 +643,29 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
vllm::Fp8KVCacheDataType KV_CACHE_DTYPE = \
vllm::get_fp8_kv_cache_data_type(KV_DTYPE); \
if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == torch::headeronly::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::Half) { \
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::BFloat16) { \
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
STD_TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E4M3) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == torch::headeronly::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
STD_TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
STD_TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
}
} // namespace fp8
@@ -1,6 +1,7 @@
#pragma once
#include "../../../../attention/attention_dtypes.h"
#include <torch/headeronly/core/ScalarType.h>
#include <assert.h>
#include <float.h>
#include <stdint.h>
@@ -546,37 +547,40 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
vllm::Fp8KVCacheDataType KV_CACHE_DTYPE = \
vllm::get_fp8_kv_cache_data_type(KV_DTYPE); \
if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == torch::headeronly::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::Half) { \
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::BFloat16) { \
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
STD_TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E4M3) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == torch::headeronly::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
STD_TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E5M2) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == torch::headeronly::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
} else if (SRC_DTYPE == torch::headeronly::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
STD_TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
STD_TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
}
} // namespace fp8
+4 -138
View File
@@ -1,4 +1,7 @@
#include "cache.h"
// Provides torch::Tensor for ops.h (previously included transitively via
// cache.h, which is no longer included here after cache ops moved to
// _C_stable_libtorch).
#include <torch/all.h>
#include "cuda_utils.h"
#include "ops.h"
#include "core/registration.h"
@@ -33,35 +36,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
&get_cuda_view_from_cpu_tensor);
// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
ops.def(
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
ops.def(
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
@@ -217,114 +191,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Cache ops
// Swap in (out) the cache blocks from src to dst.
cache_ops.def(
"swap_blocks(Tensor src, Tensor! dst,"
" int block_size_in_bytes, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
// Batch swap: submit all block copies in a single driver call.
cache_ops.def(
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
" Tensor sizes,"
" bool is_src_access_order_any=False) -> ()");
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache_flash(Tensor key, Tensor value,"
" Tensor! key_cache,"
" Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
&reshape_and_cache_flash);
// Concat kv_c and k_pe and cache them.
cache_ops.def(
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
// Rotate Q and K, then write to kv cache for MLA
cache_ops.def(
"concat_and_cache_mla_rope_fused("
" Tensor positions,"
" Tensor! q_pe,"
" Tensor! k_pe,"
" Tensor kv_c,"
" Tensor cos_sin_cache,"
" bool is_neox,"
" Tensor slot_mapping,"
" Tensor! kv_cache,"
" str kv_cache_dtype,"
" Tensor kv_cache_scale) -> ()");
cache_ops.impl("concat_and_cache_mla_rope_fused", torch::kCUDA,
&concat_and_cache_mla_rope_fused);
// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
"str kv_cache_dtype) -> ()");
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
// Gather cache blocks from src_cache to dst, dequantizing from
// src_cache's dtype to dst's dtype if necessary.
cache_ops.def(
"gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
" Tensor block_table, Tensor cu_seq_lens, "
" Tensor token_to_seq, "
" int num_tokens, "
" str kv_cache_dtype, "
" Tensor scale, Tensor? seq_starts) -> ()");
cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
&gather_and_maybe_dequant_cache);
cache_ops.def(
"cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
cache_ops.def(
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
"batch_size) -> ()");
cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
&cp_gather_and_upconvert_fp8_kv_cache);
cache_ops.def(
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
"int quant_block_size, str kv_cache_dtype) -> ()");
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
&indexer_k_quant_and_cache);
cache_ops.def(
"concat_mla_q(Tensor ql_nope, Tensor q_pe, Tensor! q_out) -> ()");
cache_ops.impl("concat_mla_q", torch::kCUDA, &concat_mla_q);
cache_ops.def(
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
cache_ops.impl("cp_gather_indexer_k_quant_cache", torch::kCUDA,
&cp_gather_indexer_k_quant_cache);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
// Cuda utils