mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[8/n] Migrate merge_attn_states, mamba, sampler to torch stable ABI (continued) (#43361)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com> Signed-off-by: Chris Leonard <chleonar@redhat.com> Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com> Co-authored-by: Shengqi Chen <harry-chen@outlook.com>
This commit is contained in:
+5
-5
@@ -305,14 +305,10 @@ endif()
|
||||
#
|
||||
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
||||
"csrc/cache_kernels.cu"
|
||||
"csrc/cache_kernels_fused.cu"
|
||||
"csrc/attention/paged_attention_v1.cu"
|
||||
"csrc/attention/paged_attention_v2.cu"
|
||||
"csrc/attention/merge_attn_states.cu"
|
||||
"csrc/sampler.cu"
|
||||
"csrc/topk.cu"
|
||||
"csrc/cuda_view.cu"
|
||||
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
|
||||
"csrc/quantization/activation_kernels.cu"
|
||||
@@ -633,7 +629,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
"csrc/libtorch_stable/fused_qknorm_rope_kernel.cu"
|
||||
"csrc/libtorch_stable/layernorm_kernels.cu"
|
||||
"csrc/libtorch_stable/layernorm_quant_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu")
|
||||
"csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
|
||||
"csrc/libtorch_stable/attention/merge_attn_states.cu"
|
||||
"csrc/libtorch_stable/sampler.cu"
|
||||
"csrc/libtorch_stable/topk.cu"
|
||||
"csrc/libtorch_stable/mamba/selective_scan_fwd.cu")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC
|
||||
|
||||
+44
-37
@@ -1,14 +1,14 @@
|
||||
#include <optional>
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
#include "../quantization/w8a8/fp8/common.cuh"
|
||||
#include "../torch_utils.h"
|
||||
#include "../dispatch_utils.h"
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
#include "../../attention/attention_dtypes.h"
|
||||
#include "../../attention/attention_utils.cuh"
|
||||
#include "../../quantization/w8a8/fp8/common.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
@@ -198,14 +198,14 @@ __global__ void merge_attn_states_kernel(
|
||||
// template<typename scalar_t>.
|
||||
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
|
||||
{ \
|
||||
if (scalar_dtype == at::ScalarType::Float) { \
|
||||
if (scalar_dtype == torch::headeronly::ScalarType::Float) { \
|
||||
fn(float); \
|
||||
} else if (scalar_dtype == at::ScalarType::Half) { \
|
||||
} else if (scalar_dtype == torch::headeronly::ScalarType::Half) { \
|
||||
fn(uint16_t); \
|
||||
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
|
||||
} else if (scalar_dtype == torch::headeronly::ScalarType::BFloat16) { \
|
||||
fn(__nv_bfloat16); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
|
||||
STD_TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
|
||||
} \
|
||||
}
|
||||
|
||||
@@ -245,11 +245,14 @@ __global__ void merge_attn_states_kernel(
|
||||
*/
|
||||
template <typename scalar_t>
|
||||
void merge_attn_states_launcher(
|
||||
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
|
||||
torch::stable::Tensor& output,
|
||||
std::optional<torch::stable::Tensor> output_lse,
|
||||
const torch::stable::Tensor& prefix_output,
|
||||
const torch::stable::Tensor& prefix_lse,
|
||||
const torch::stable::Tensor& suffix_output,
|
||||
const torch::stable::Tensor& suffix_lse,
|
||||
const std::optional<int64_t> prefill_tokens_with_context,
|
||||
const std::optional<torch::Tensor>& output_scale) {
|
||||
const std::optional<torch::stable::Tensor>& output_scale) {
|
||||
constexpr uint NUM_THREADS = 128;
|
||||
const uint num_tokens = output.size(0);
|
||||
const uint num_heads = output.size(1);
|
||||
@@ -258,23 +261,23 @@ void merge_attn_states_launcher(
|
||||
const uint output_head_stride = output.stride(1);
|
||||
// Thread mapping is based on input BF16 pack_size
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
TORCH_CHECK(head_size % pack_size == 0,
|
||||
STD_TORCH_CHECK(head_size % pack_size == 0,
|
||||
"headsize must be multiple of pack_size:", pack_size);
|
||||
|
||||
const uint prefix_num_tokens =
|
||||
prefill_tokens_with_context.has_value()
|
||||
? static_cast<uint>(prefill_tokens_with_context.value())
|
||||
: num_tokens;
|
||||
TORCH_CHECK(prefix_num_tokens <= num_tokens,
|
||||
STD_TORCH_CHECK(prefix_num_tokens <= num_tokens,
|
||||
"prefix_num_tokens must be <= num_tokens");
|
||||
|
||||
float* output_lse_ptr = nullptr;
|
||||
if (output_lse.has_value()) {
|
||||
output_lse_ptr = output_lse.value().data_ptr<float>();
|
||||
output_lse_ptr = output_lse.value().mutable_data_ptr<float>();
|
||||
}
|
||||
float* output_scale_ptr = nullptr;
|
||||
if (output_scale.has_value()) {
|
||||
output_scale_ptr = output_scale.value().data_ptr<float>();
|
||||
output_scale_ptr = output_scale.value().mutable_data_ptr<float>();
|
||||
}
|
||||
// Process one pack elements per thread. for float, the
|
||||
// pack_size is 4 for half/bf16, the pack_size is 8.
|
||||
@@ -284,14 +287,15 @@ void merge_attn_states_launcher(
|
||||
dim3 block(NUM_THREADS);
|
||||
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
|
||||
|
||||
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
prefix_output.get_device_index());
|
||||
auto stream = get_current_cuda_stream();
|
||||
|
||||
if (output_scale.has_value()) {
|
||||
// FP8 output path - dispatch on output FP8 type
|
||||
VLLM_DISPATCH_FP8_TYPES(output.scalar_type(), "merge_attn_states_fp8", [&] {
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true);
|
||||
});
|
||||
VLLM_STABLE_DISPATCH_FP8_TYPES(
|
||||
output.scalar_type(), "merge_attn_states_fp8",
|
||||
[&] { LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true); });
|
||||
} else {
|
||||
// Original BF16/FP16/FP32 output path
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, scalar_t, NUM_THREADS, false);
|
||||
@@ -305,26 +309,29 @@ void merge_attn_states_launcher(
|
||||
suffix_lse, prefill_tokens_with_context, output_scale); \
|
||||
}
|
||||
|
||||
void merge_attn_states(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse,
|
||||
std::optional<int64_t> prefill_tokens_with_context,
|
||||
const std::optional<torch::Tensor>& output_scale) {
|
||||
void merge_attn_states(
|
||||
torch::stable::Tensor& output,
|
||||
std::optional<torch::stable::Tensor> output_lse,
|
||||
const torch::stable::Tensor& prefix_output,
|
||||
const torch::stable::Tensor& prefix_lse,
|
||||
const torch::stable::Tensor& suffix_output,
|
||||
const torch::stable::Tensor& suffix_lse,
|
||||
const std::optional<int64_t> prefill_tokens_with_context,
|
||||
const std::optional<torch::stable::Tensor>& output_scale) {
|
||||
if (output_scale.has_value()) {
|
||||
TORCH_CHECK(output.scalar_type() == at::ScalarType::Float8_e4m3fn ||
|
||||
output.scalar_type() == at::ScalarType::Float8_e4m3fnuz,
|
||||
STD_TORCH_CHECK(
|
||||
output.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn ||
|
||||
output.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fnuz,
|
||||
"output must be FP8 when output_scale is provided, got: ",
|
||||
output.scalar_type());
|
||||
} else {
|
||||
TORCH_CHECK(output.scalar_type() == prefix_output.scalar_type(),
|
||||
"output dtype (", output.scalar_type(),
|
||||
") must match prefix_output dtype (",
|
||||
STD_TORCH_CHECK(
|
||||
output.scalar_type() == prefix_output.scalar_type(), "output dtype (",
|
||||
output.scalar_type(), ") must match prefix_output dtype (",
|
||||
prefix_output.scalar_type(), ") when output_scale is not set");
|
||||
}
|
||||
// Always dispatch on prefix_output (input) dtype
|
||||
DISPATCH_BY_SCALAR_DTYPE(prefix_output.dtype(),
|
||||
DISPATCH_BY_SCALAR_DTYPE(prefix_output.scalar_type(),
|
||||
CALL_MERGE_ATTN_STATES_LAUNCHER);
|
||||
}
|
||||
@@ -12,6 +12,9 @@
|
||||
#include <hip/hip_bf16.h>
|
||||
#endif
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include <torch/headeronly/util/Half.h>
|
||||
#include <torch/headeronly/util/BFloat16.h>
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SSMParamsBase {
|
||||
@@ -159,8 +162,8 @@ struct Converter{
|
||||
};
|
||||
|
||||
template<int N>
|
||||
struct Converter<at::Half, N>{
|
||||
static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
|
||||
struct Converter<torch::headeronly::Half, N>{
|
||||
static inline __device__ void to_float(const torch::headeronly::Half (&src)[N], float (&dst)[N]) {
|
||||
static_assert(N % 2 == 0);
|
||||
auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
|
||||
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
||||
@@ -171,8 +174,8 @@ struct Converter<at::Half, N>{
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
template<int N>
|
||||
struct Converter<at::BFloat16, N>{
|
||||
static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
|
||||
struct Converter<torch::headeronly::BFloat16, N>{
|
||||
static inline __device__ void to_float(const torch::headeronly::BFloat16 (&src)[N], float (&dst)[N]) {
|
||||
static_assert(N % 2 == 0);
|
||||
auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
|
||||
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
||||
+98
-107
@@ -1,18 +1,9 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "../torch_utils.h"
|
||||
#include <torch/csrc/stable/macros.h>
|
||||
#include "selective_scan.h"
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#ifdef USE_ROCM
|
||||
#include <c10/hip/HIPException.h> // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK
|
||||
#else
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
@@ -416,15 +407,15 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
|
||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
#ifdef USE_ROCM
|
||||
C10_HIP_CHECK(hipFuncSetAttribute(
|
||||
STD_CUDA_CHECK(hipFuncSetAttribute(
|
||||
reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
#else
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
STD_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
#endif
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
STD_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -462,46 +453,46 @@ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
|
||||
#endif
|
||||
}
|
||||
|
||||
template void selective_scan_fwd_cuda<at::BFloat16, float, at::BFloat16>(SSMParamsBase &params, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<at::BFloat16, float, float>(SSMParamsBase &params, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<at::Half, float, at::Half>(SSMParamsBase &params, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<at::Half, float, float>(SSMParamsBase &params, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<torch::headeronly::BFloat16, float, torch::headeronly::BFloat16>(SSMParamsBase &params, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<torch::headeronly::BFloat16, float, float>(SSMParamsBase &params, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<torch::headeronly::Half, float, torch::headeronly::Half>(SSMParamsBase &params, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<torch::headeronly::Half, float, float>(SSMParamsBase &params, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<float, float, float>(SSMParamsBase &params, cudaStream_t stream);
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
#define CHECK_SHAPE(x, ...) STD_TORCH_CHECK(x.sizes().equals(torch::headeronly::IntHeaderOnlyArrayRef({__VA_ARGS__})), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, STYPE, NAME, ...) \
|
||||
if (ITYPE == at::ScalarType::Half) { \
|
||||
using input_t = at::Half; \
|
||||
if (ITYPE == torch::headeronly::ScalarType::Half) { \
|
||||
using input_t = torch::headeronly::Half; \
|
||||
using weight_t = float; \
|
||||
if (STYPE == at::ScalarType::Half) { \
|
||||
using state_t = at::Half; \
|
||||
if (STYPE == torch::headeronly::ScalarType::Half) { \
|
||||
using state_t = torch::headeronly::Half; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (STYPE == at::ScalarType::Float) { \
|
||||
} else if (STYPE == torch::headeronly::ScalarType::Float) { \
|
||||
using state_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
|
||||
STD_TORCH_CHECK(false, #NAME " not implemented for state type '", STYPE, "'"); \
|
||||
} \
|
||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||
using input_t = at::BFloat16; \
|
||||
} else if (ITYPE == torch::headeronly::ScalarType::BFloat16) { \
|
||||
using input_t = torch::headeronly::BFloat16; \
|
||||
using weight_t = float; \
|
||||
if (STYPE == at::ScalarType::BFloat16) { \
|
||||
using state_t = at::BFloat16; \
|
||||
if (STYPE == torch::headeronly::ScalarType::BFloat16) { \
|
||||
using state_t = torch::headeronly::BFloat16; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (STYPE == at::ScalarType::Float) { \
|
||||
} else if (STYPE == torch::headeronly::ScalarType::Float) { \
|
||||
using state_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
|
||||
STD_TORCH_CHECK(false, #NAME " not implemented for state type '", STYPE, "'"); \
|
||||
} \
|
||||
} else if (ITYPE == at::ScalarType::Float) { \
|
||||
} else if (ITYPE == torch::headeronly::ScalarType::Float) { \
|
||||
using input_t = float; \
|
||||
using weight_t = float; \
|
||||
using state_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||
STD_TORCH_CHECK(false, #NAME " not implemented for input type '", ITYPE, "'"); \
|
||||
}
|
||||
|
||||
|
||||
@@ -518,30 +509,30 @@ void set_ssm_params_fwd(SSMParamsBase &params,
|
||||
const bool is_variable_B,
|
||||
const bool is_variable_C,
|
||||
// device pointers
|
||||
const torch::Tensor u,
|
||||
const torch::Tensor delta,
|
||||
const torch::Tensor A,
|
||||
const torch::Tensor B,
|
||||
const torch::Tensor C,
|
||||
const torch::Tensor out,
|
||||
const torch::Tensor z,
|
||||
const torch::Tensor out_z,
|
||||
const std::optional<at::Tensor>& D,
|
||||
const std::optional<at::Tensor>& delta_bias,
|
||||
const torch::Tensor ssm_states,
|
||||
const torch::stable::Tensor u,
|
||||
const torch::stable::Tensor delta,
|
||||
const torch::stable::Tensor A,
|
||||
const torch::stable::Tensor B,
|
||||
const torch::stable::Tensor C,
|
||||
const torch::stable::Tensor out,
|
||||
const torch::stable::Tensor z,
|
||||
const torch::stable::Tensor out_z,
|
||||
const std::optional<torch::stable::Tensor>& D,
|
||||
const std::optional<torch::stable::Tensor>& delta_bias,
|
||||
const torch::stable::Tensor ssm_states,
|
||||
bool has_z,
|
||||
bool delta_softplus,
|
||||
const std::optional<at::Tensor>& query_start_loc,
|
||||
const std::optional<at::Tensor>& cache_indices,
|
||||
const std::optional<at::Tensor>& has_initial_state,
|
||||
const std::optional<torch::stable::Tensor>& query_start_loc,
|
||||
const std::optional<torch::stable::Tensor>& cache_indices,
|
||||
const std::optional<torch::stable::Tensor>& has_initial_state,
|
||||
bool varlen,
|
||||
int64_t null_block_id,
|
||||
int64_t block_size,
|
||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||
const std::optional<torch::Tensor> &initial_state_idx,
|
||||
const std::optional<torch::Tensor> &cu_chunk_seqlen,
|
||||
const std::optional<torch::Tensor> &last_chunk_indices) {
|
||||
const std::optional<torch::stable::Tensor> &block_idx_first_scheduled_token,
|
||||
const std::optional<torch::stable::Tensor> &block_idx_last_scheduled_token,
|
||||
const std::optional<torch::stable::Tensor> &initial_state_idx,
|
||||
const std::optional<torch::stable::Tensor> &cu_chunk_seqlen,
|
||||
const std::optional<torch::stable::Tensor> &last_chunk_indices) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(&params, 0, sizeof(params));
|
||||
@@ -654,45 +645,45 @@ void set_ssm_params_fwd(SSMParamsBase &params,
|
||||
}
|
||||
}
|
||||
|
||||
void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
|
||||
const std::optional<torch::Tensor> &D_,
|
||||
const std::optional<torch::Tensor> &z_,
|
||||
const std::optional<torch::Tensor> &delta_bias_,
|
||||
void selective_scan_fwd(const torch::stable::Tensor &u, const torch::stable::Tensor &delta,
|
||||
const torch::stable::Tensor &A, const torch::stable::Tensor &B, const torch::stable::Tensor &C,
|
||||
const std::optional<torch::stable::Tensor> &D_,
|
||||
const std::optional<torch::stable::Tensor> &z_,
|
||||
const std::optional<torch::stable::Tensor> &delta_bias_,
|
||||
bool delta_softplus,
|
||||
const std::optional<torch::Tensor> &query_start_loc,
|
||||
const std::optional<torch::Tensor> &cache_indices,
|
||||
const std::optional<torch::Tensor> &has_initial_state,
|
||||
const torch::Tensor &ssm_states,
|
||||
const std::optional<torch::stable::Tensor> &query_start_loc,
|
||||
const std::optional<torch::stable::Tensor> &cache_indices,
|
||||
const std::optional<torch::stable::Tensor> &has_initial_state,
|
||||
const torch::stable::Tensor &ssm_states,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t null_block_id,
|
||||
int64_t block_size,
|
||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||
const std::optional<torch::Tensor> &initial_state_idx,
|
||||
const std::optional<torch::Tensor> &cu_chunk_seqlen,
|
||||
const std::optional<torch::Tensor> &last_chunk_indices) {
|
||||
const std::optional<torch::stable::Tensor> &block_idx_first_scheduled_token,
|
||||
const std::optional<torch::stable::Tensor> &block_idx_last_scheduled_token,
|
||||
const std::optional<torch::stable::Tensor> &initial_state_idx,
|
||||
const std::optional<torch::stable::Tensor> &cu_chunk_seqlen,
|
||||
const std::optional<torch::stable::Tensor> &last_chunk_indices) {
|
||||
auto input_type = u.scalar_type();
|
||||
auto weight_type = A.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float);
|
||||
STD_TORCH_CHECK(input_type == torch::headeronly::ScalarType::Float || input_type == torch::headeronly::ScalarType::Half || input_type == torch::headeronly::ScalarType::BFloat16);
|
||||
STD_TORCH_CHECK(weight_type == torch::headeronly::ScalarType::Float);
|
||||
|
||||
const bool is_variable_B = B.dim() >= 3;
|
||||
const bool is_variable_C = C.dim() >= 3;
|
||||
|
||||
TORCH_CHECK(delta.scalar_type() == input_type);
|
||||
TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
|
||||
TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
|
||||
STD_TORCH_CHECK(delta.scalar_type() == input_type);
|
||||
STD_TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
|
||||
STD_TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
|
||||
|
||||
TORCH_CHECK(u.is_cuda());
|
||||
TORCH_CHECK(delta.is_cuda());
|
||||
TORCH_CHECK(A.is_cuda());
|
||||
TORCH_CHECK(B.is_cuda());
|
||||
TORCH_CHECK(C.is_cuda());
|
||||
STD_TORCH_CHECK(u.is_cuda());
|
||||
STD_TORCH_CHECK(delta.is_cuda());
|
||||
STD_TORCH_CHECK(A.is_cuda());
|
||||
STD_TORCH_CHECK(B.is_cuda());
|
||||
STD_TORCH_CHECK(C.is_cuda());
|
||||
|
||||
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
||||
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
||||
STD_TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
||||
STD_TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
||||
|
||||
const auto sizes = u.sizes();
|
||||
const bool varlen = query_start_loc.has_value();
|
||||
@@ -702,7 +693,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
const int dstate = A.size(1);
|
||||
const int n_groups = varlen ? B.size(0) : B.size(1);
|
||||
|
||||
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
|
||||
STD_TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
|
||||
|
||||
if (varlen) {
|
||||
CHECK_SHAPE(u, dim, seqlen);
|
||||
@@ -712,78 +703,78 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
||||
}
|
||||
CHECK_SHAPE(A, dim, dstate);
|
||||
TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size")
|
||||
STD_TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size");
|
||||
if (varlen) {
|
||||
CHECK_SHAPE(B, n_groups, dstate, seqlen);
|
||||
} else {
|
||||
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
|
||||
}
|
||||
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
||||
STD_TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
||||
|
||||
TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size")
|
||||
STD_TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size");
|
||||
if (varlen) {
|
||||
CHECK_SHAPE(C, n_groups, dstate, seqlen);
|
||||
} else {
|
||||
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
|
||||
}
|
||||
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
||||
STD_TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
||||
|
||||
if (D_.has_value()) {
|
||||
auto D = D_.value();
|
||||
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
||||
TORCH_CHECK(D.is_cuda());
|
||||
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
||||
STD_TORCH_CHECK(D.scalar_type() == torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(D.is_cuda());
|
||||
STD_TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
||||
CHECK_SHAPE(D, dim);
|
||||
}
|
||||
|
||||
if (delta_bias_.has_value()) {
|
||||
auto delta_bias = delta_bias_.value();
|
||||
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
||||
TORCH_CHECK(delta_bias.is_cuda());
|
||||
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
||||
STD_TORCH_CHECK(delta_bias.scalar_type() == torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(delta_bias.is_cuda());
|
||||
STD_TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
||||
CHECK_SHAPE(delta_bias, dim);
|
||||
}
|
||||
|
||||
|
||||
if (has_initial_state.has_value()) {
|
||||
auto has_initial_state_ = has_initial_state.value();
|
||||
TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
|
||||
TORCH_CHECK(has_initial_state_.is_cuda());
|
||||
STD_TORCH_CHECK(has_initial_state_.scalar_type() == torch::headeronly::ScalarType::Bool);
|
||||
STD_TORCH_CHECK(has_initial_state_.is_cuda());
|
||||
CHECK_SHAPE(has_initial_state_, batch_size);
|
||||
}
|
||||
|
||||
|
||||
if (query_start_loc.has_value()) {
|
||||
auto query_start_loc_ = query_start_loc.value();
|
||||
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(query_start_loc_.is_cuda());
|
||||
STD_TORCH_CHECK(query_start_loc_.scalar_type() == torch::headeronly::ScalarType::Int);
|
||||
STD_TORCH_CHECK(query_start_loc_.is_cuda());
|
||||
}
|
||||
|
||||
|
||||
if (cache_indices.has_value()) {
|
||||
auto cache_indices_ = cache_indices.value();
|
||||
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(cache_indices_.is_cuda());
|
||||
STD_TORCH_CHECK(cache_indices_.scalar_type() == torch::headeronly::ScalarType::Int);
|
||||
STD_TORCH_CHECK(cache_indices_.is_cuda());
|
||||
|
||||
// cache_indices can be either 1D (batch_size,) for non-APC mode
|
||||
// or 2D (batch_size, max_positions) for APC mode
|
||||
const bool is_apc_mode = block_idx_first_scheduled_token.has_value();
|
||||
if (is_apc_mode) {
|
||||
TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode");
|
||||
TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size");
|
||||
STD_TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode");
|
||||
STD_TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size");
|
||||
} else {
|
||||
CHECK_SHAPE(cache_indices_, batch_size);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
at::Tensor z, out_z;
|
||||
torch::stable::Tensor z, out_z;
|
||||
const bool has_z = z_.has_value();
|
||||
if (has_z) {
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
STD_TORCH_CHECK(z.scalar_type() == input_type);
|
||||
STD_TORCH_CHECK(z.is_cuda());
|
||||
STD_TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(z, dim, seqlen);
|
||||
} else {
|
||||
@@ -794,12 +785,12 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
}
|
||||
|
||||
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
||||
at::Tensor out = delta;
|
||||
torch::stable::Tensor out = delta;
|
||||
// ssm_states can now be either the same as input_type or float32
|
||||
auto state_type = ssm_states.scalar_type();
|
||||
TORCH_CHECK(state_type == input_type || state_type == at::ScalarType::Float);
|
||||
TORCH_CHECK(ssm_states.is_cuda());
|
||||
TORCH_CHECK(ssm_states.stride(-1) == 1);
|
||||
STD_TORCH_CHECK(state_type == input_type || state_type == torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(ssm_states.is_cuda());
|
||||
STD_TORCH_CHECK(ssm_states.stride(-1) == 1);
|
||||
|
||||
SSMParamsBase params;
|
||||
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C,
|
||||
@@ -823,8 +814,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
);
|
||||
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(u));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(u.get_device_index());
|
||||
auto stream = get_current_cuda_stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), ssm_states.scalar_type(), "selective_scan_fwd", [&] {
|
||||
selective_scan_fwd_cuda<input_t, weight_t, state_t>(params, stream);
|
||||
});
|
||||
@@ -164,6 +164,17 @@ torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel,
|
||||
|
||||
#endif
|
||||
|
||||
// Attention kernels (shared CUDA/ROCm)
|
||||
void merge_attn_states(
|
||||
torch::stable::Tensor& output,
|
||||
std::optional<torch::stable::Tensor> output_lse,
|
||||
const torch::stable::Tensor& prefix_output,
|
||||
const torch::stable::Tensor& prefix_lse,
|
||||
const torch::stable::Tensor& suffix_output,
|
||||
const torch::stable::Tensor& suffix_lse,
|
||||
const std::optional<int64_t> prefill_tokens_with_context,
|
||||
const std::optional<torch::stable::Tensor>& output_scale = std::nullopt);
|
||||
|
||||
torch::stable::Tensor hadacore_transform(torch::stable::Tensor& x,
|
||||
bool inplace);
|
||||
|
||||
@@ -220,6 +231,48 @@ void fused_qk_norm_rope(torch::stable::Tensor& qkv, int64_t num_heads_q,
|
||||
torch::stable::Tensor& position_ids,
|
||||
int64_t forced_token_heads_per_warp);
|
||||
|
||||
// Sampler kernels (shared CUDA/ROCm)
|
||||
void apply_repetition_penalties_(
|
||||
torch::stable::Tensor& logits, const torch::stable::Tensor& prompt_mask,
|
||||
const torch::stable::Tensor& output_mask,
|
||||
const torch::stable::Tensor& repetition_penalties);
|
||||
|
||||
void top_k_per_row_prefill(const torch::stable::Tensor& logits,
|
||||
const torch::stable::Tensor& rowStarts,
|
||||
const torch::stable::Tensor& rowEnds,
|
||||
torch::stable::Tensor& indices, int64_t numRows,
|
||||
int64_t stride0, int64_t stride1, int64_t topK);
|
||||
|
||||
void top_k_per_row_decode(const torch::stable::Tensor& logits, int64_t next_n,
|
||||
const torch::stable::Tensor& seqLens,
|
||||
torch::stable::Tensor& indices, int64_t numRows,
|
||||
int64_t stride0, int64_t stride1, int64_t topK);
|
||||
|
||||
void persistent_topk(const torch::stable::Tensor& logits,
|
||||
const torch::stable::Tensor& lengths,
|
||||
torch::stable::Tensor& output,
|
||||
torch::stable::Tensor& workspace, int64_t k,
|
||||
int64_t max_seq_len);
|
||||
|
||||
void selective_scan_fwd(
|
||||
const torch::stable::Tensor& u, const torch::stable::Tensor& delta,
|
||||
const torch::stable::Tensor& A, const torch::stable::Tensor& B,
|
||||
const torch::stable::Tensor& C,
|
||||
const std::optional<torch::stable::Tensor>& D_,
|
||||
const std::optional<torch::stable::Tensor>& z_,
|
||||
const std::optional<torch::stable::Tensor>& delta_bias_,
|
||||
bool delta_softplus,
|
||||
const std::optional<torch::stable::Tensor>& query_start_loc,
|
||||
const std::optional<torch::stable::Tensor>& cache_indices,
|
||||
const std::optional<torch::stable::Tensor>& has_initial_state,
|
||||
const torch::stable::Tensor& ssm_states, int64_t null_block_id,
|
||||
int64_t block_size,
|
||||
const std::optional<torch::stable::Tensor>& block_idx_first_scheduled_token,
|
||||
const std::optional<torch::stable::Tensor>& block_idx_last_scheduled_token,
|
||||
const std::optional<torch::stable::Tensor>& initial_state_idx,
|
||||
const std::optional<torch::stable::Tensor>& cu_chunk_seqlen,
|
||||
const std::optional<torch::stable::Tensor>& last_chunk_indices);
|
||||
|
||||
// Activation kernels (shared CUDA/ROCm)
|
||||
void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
||||
void silu_and_mul_clamp(torch::stable::Tensor& out,
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
#include "cuda_compat.h"
|
||||
#include "../cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include <torch/cuda.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "torch_utils.h"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/cub.cuh>
|
||||
@@ -618,14 +616,14 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
|
||||
} // namespace vllm
|
||||
|
||||
void apply_repetition_penalties_(
|
||||
torch::Tensor& logits, // [num_seqs, vocab_size], in-place
|
||||
const torch::Tensor& prompt_mask, // [num_seqs, vocab_size]
|
||||
const torch::Tensor& output_mask, // [num_seqs, vocab_size]
|
||||
const torch::Tensor& repetition_penalties) { // [num_seqs]
|
||||
TORCH_CHECK(logits.is_contiguous());
|
||||
TORCH_CHECK(prompt_mask.is_contiguous());
|
||||
TORCH_CHECK(output_mask.is_contiguous());
|
||||
TORCH_CHECK(repetition_penalties.is_contiguous());
|
||||
torch::stable::Tensor& logits, // [num_seqs, vocab_size], in-place
|
||||
const torch::stable::Tensor& prompt_mask, // [num_seqs, vocab_size]
|
||||
const torch::stable::Tensor& output_mask, // [num_seqs, vocab_size]
|
||||
const torch::stable::Tensor& repetition_penalties) { // [num_seqs]
|
||||
STD_TORCH_CHECK(logits.is_contiguous());
|
||||
STD_TORCH_CHECK(prompt_mask.is_contiguous());
|
||||
STD_TORCH_CHECK(output_mask.is_contiguous());
|
||||
STD_TORCH_CHECK(repetition_penalties.is_contiguous());
|
||||
|
||||
int vocab_size = logits.size(-1);
|
||||
int num_seqs = logits.size(0);
|
||||
@@ -635,7 +633,7 @@ void apply_repetition_penalties_(
|
||||
// Get number of SMs on the current device
|
||||
int sms = 0;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
|
||||
logits.get_device());
|
||||
logits.get_device_index());
|
||||
|
||||
// Compute tile_num and tile_size
|
||||
int tile_num =
|
||||
@@ -645,27 +643,29 @@ void apply_repetition_penalties_(
|
||||
// Each block handles one sequence and a tile of vocab
|
||||
dim3 grid(num_seqs, tile_num);
|
||||
dim3 block(std::min(tile_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(logits));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
logits.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream();
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
|
||||
logits.scalar_type(), "apply_repetition_penalties_kernel", [&] {
|
||||
vllm::apply_repetition_penalties_kernel<scalar_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
logits.data_ptr<scalar_t>(), prompt_mask.data_ptr<bool>(),
|
||||
output_mask.data_ptr<bool>(),
|
||||
repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
|
||||
tile_size);
|
||||
logits.mutable_data_ptr<scalar_t>(),
|
||||
prompt_mask.const_data_ptr<bool>(),
|
||||
output_mask.const_data_ptr<bool>(),
|
||||
repetition_penalties.const_data_ptr<scalar_t>(), num_seqs,
|
||||
vocab_size, tile_size);
|
||||
});
|
||||
}
|
||||
|
||||
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
||||
const torch::Tensor& seqLens, torch::Tensor& indices,
|
||||
int64_t numRows, int64_t stride0, int64_t stride1,
|
||||
int64_t topK) {
|
||||
void top_k_per_row_decode(const torch::stable::Tensor& logits, int64_t next_n,
|
||||
const torch::stable::Tensor& seqLens,
|
||||
torch::stable::Tensor& indices, int64_t numRows,
|
||||
int64_t stride0, int64_t stride1, int64_t topK) {
|
||||
constexpr int kSortingAlgorithmThreshold = 12288;
|
||||
constexpr int kSplitWorkThreshold = 200 * 1000;
|
||||
constexpr int kNumThreadsPerBlock = 512;
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const cudaStream_t stream = get_current_cuda_stream();
|
||||
const auto numColumns = logits.size(1);
|
||||
|
||||
// True if seqLens is 2D (B, next_n): each logit row has its own pre-computed
|
||||
@@ -677,72 +677,75 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
||||
// Use insertion sort
|
||||
vllm::topKPerRowDecode<kNumThreadsPerBlock, false>
|
||||
<<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
|
||||
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||
indices.data_ptr<int>(), static_cast<int>(stride0),
|
||||
logits.const_data_ptr<float>(), seqLens.const_data_ptr<int>(),
|
||||
indices.mutable_data_ptr<int>(), static_cast<int>(stride0),
|
||||
static_cast<int>(stride1), static_cast<int>(topK),
|
||||
static_cast<int>(next_n), seqLensIs2D);
|
||||
} else if (numColumns < kSplitWorkThreshold) {
|
||||
// From this threshold, use radix sort instead
|
||||
vllm::topKPerRowDecode<kNumThreadsPerBlock, true>
|
||||
<<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
|
||||
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||
indices.data_ptr<int>(), static_cast<int>(stride0),
|
||||
logits.const_data_ptr<float>(), seqLens.const_data_ptr<int>(),
|
||||
indices.mutable_data_ptr<int>(), static_cast<int>(stride0),
|
||||
static_cast<int>(stride1), static_cast<int>(topK),
|
||||
static_cast<int>(next_n), seqLensIs2D);
|
||||
} else {
|
||||
// Long sequences are run in two steps
|
||||
constexpr auto multipleBlocksPerRowConfig = 10;
|
||||
|
||||
const auto outIndicesAux =
|
||||
torch::empty({numRows, multipleBlocksPerRowConfig, topK},
|
||||
torch::dtype(torch::kInt32).device(logits.device()));
|
||||
const auto outLogitsAux =
|
||||
torch::empty({numRows, multipleBlocksPerRowConfig, topK},
|
||||
torch::dtype(torch::kFloat).device(logits.device()));
|
||||
const auto outIndicesAux = torch::stable::empty(
|
||||
{numRows, multipleBlocksPerRowConfig, topK},
|
||||
torch::headeronly::ScalarType::Int, std::nullopt, logits.device());
|
||||
const auto outLogitsAux = torch::stable::empty(
|
||||
{numRows, multipleBlocksPerRowConfig, topK},
|
||||
torch::headeronly::ScalarType::Float, std::nullopt, logits.device());
|
||||
|
||||
vllm::topKPerRowDecode<kNumThreadsPerBlock, true, true>
|
||||
<<<dim3(numRows, multipleBlocksPerRowConfig), kNumThreadsPerBlock,
|
||||
2 * topK * sizeof(int32_t), stream>>>(
|
||||
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||
outIndicesAux.data_ptr<int>(), static_cast<int>(stride0),
|
||||
logits.const_data_ptr<float>(), seqLens.const_data_ptr<int>(),
|
||||
outIndicesAux.mutable_data_ptr<int>(), static_cast<int>(stride0),
|
||||
static_cast<int>(stride1), static_cast<int>(topK),
|
||||
static_cast<int>(next_n), seqLensIs2D,
|
||||
outLogitsAux.data_ptr<float>());
|
||||
outLogitsAux.mutable_data_ptr<float>());
|
||||
|
||||
constexpr int kNumThreadsPerBlockMerge = 1024;
|
||||
vllm::topKPerRowDecode<kNumThreadsPerBlockMerge, true, false, true>
|
||||
<<<numRows, kNumThreadsPerBlockMerge, topK * sizeof(int32_t), stream>>>(
|
||||
outLogitsAux.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||
indices.data_ptr<int>(), multipleBlocksPerRowConfig * topK, 1,
|
||||
static_cast<int>(topK), static_cast<int>(next_n), seqLensIs2D,
|
||||
nullptr, multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
|
||||
outLogitsAux.const_data_ptr<float>(), seqLens.const_data_ptr<int>(),
|
||||
indices.mutable_data_ptr<int>(), multipleBlocksPerRowConfig * topK,
|
||||
1, static_cast<int>(topK), static_cast<int>(next_n), seqLensIs2D,
|
||||
nullptr, multipleBlocksPerRowConfig,
|
||||
outIndicesAux.const_data_ptr<int>());
|
||||
}
|
||||
}
|
||||
|
||||
void top_k_per_row_prefill(const torch::Tensor& logits,
|
||||
const torch::Tensor& rowStarts,
|
||||
const torch::Tensor& rowEnds, torch::Tensor& indices,
|
||||
int64_t numRows, int64_t stride0, int64_t stride1,
|
||||
int64_t topK) {
|
||||
void top_k_per_row_prefill(const torch::stable::Tensor& logits,
|
||||
const torch::stable::Tensor& rowStarts,
|
||||
const torch::stable::Tensor& rowEnds,
|
||||
torch::stable::Tensor& indices, int64_t numRows,
|
||||
int64_t stride0, int64_t stride1, int64_t topK) {
|
||||
constexpr int kSortingAlgorithmThreshold = 12288;
|
||||
constexpr int kNumThreadsPerBlock = 512;
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const cudaStream_t stream = get_current_cuda_stream();
|
||||
|
||||
int numInsertionBlocks =
|
||||
std::min(static_cast<int>(numRows), kSortingAlgorithmThreshold);
|
||||
vllm::topKPerRowPrefill<kNumThreadsPerBlock, false>
|
||||
<<<numInsertionBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
|
||||
stream>>>(logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
|
||||
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
|
||||
static_cast<int>(stride0), static_cast<int>(stride1),
|
||||
static_cast<int>(topK), 0);
|
||||
stream>>>(logits.const_data_ptr<float>(),
|
||||
rowStarts.const_data_ptr<int>(),
|
||||
rowEnds.const_data_ptr<int>(),
|
||||
indices.mutable_data_ptr<int>(), static_cast<int>(stride0),
|
||||
static_cast<int>(stride1), static_cast<int>(topK), 0);
|
||||
|
||||
if (numRows > kSortingAlgorithmThreshold) {
|
||||
int numRadixBlocks = numRows - kSortingAlgorithmThreshold;
|
||||
vllm::topKPerRowPrefill<kNumThreadsPerBlock, true>
|
||||
<<<numRadixBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
|
||||
stream>>>(logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
|
||||
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
|
||||
stream>>>(
|
||||
logits.const_data_ptr<float>(), rowStarts.const_data_ptr<int>(),
|
||||
rowEnds.const_data_ptr<int>(), indices.mutable_data_ptr<int>(),
|
||||
static_cast<int>(stride0), static_cast<int>(stride1),
|
||||
static_cast<int>(topK), kSortingAlgorithmThreshold);
|
||||
}
|
||||
@@ -1,49 +1,51 @@
|
||||
// Persistent TopK kernel for DeepSeek V3 sparse attention indexer.
|
||||
// See persistent_topk.cuh for kernel implementation.
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <algorithm>
|
||||
|
||||
#include "torch_utils.h"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include "persistent_topk.cuh"
|
||||
#include "../persistent_topk.cuh"
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
#ifndef USE_ROCM
|
||||
template <int TopK>
|
||||
void launch_persistent_topk(const torch::Tensor& logits,
|
||||
const torch::Tensor& lengths, torch::Tensor& output,
|
||||
torch::Tensor& workspace, int64_t max_seq_len) {
|
||||
void launch_persistent_topk(const torch::stable::Tensor& logits,
|
||||
const torch::stable::Tensor& lengths,
|
||||
torch::stable::Tensor& output,
|
||||
torch::stable::Tensor& workspace,
|
||||
int64_t max_seq_len) {
|
||||
namespace P = vllm::persistent;
|
||||
|
||||
const int64_t num_rows = logits.size(0);
|
||||
const int64_t stride = logits.stride(0);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const cudaStream_t stream = get_current_cuda_stream();
|
||||
|
||||
static int num_sms = 0;
|
||||
static int max_smem_per_block = 0;
|
||||
if (num_sms == 0) {
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device);
|
||||
cudaDeviceGetAttribute(&max_smem_per_block,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
const cudaDeviceProp* device_prop = get_device_prop();
|
||||
num_sms = device_prop->multiProcessorCount;
|
||||
max_smem_per_block = device_prop->sharedMemPerBlockOptin;
|
||||
}
|
||||
|
||||
if (num_rows > 32 && max_smem_per_block >= 128 * 1024) {
|
||||
cudaError_t status =
|
||||
vllm::FilteredTopKRaggedTransform<float, int32_t, TopK>(
|
||||
logits.data_ptr<float>(), output.data_ptr<int32_t>(),
|
||||
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
|
||||
logits.const_data_ptr<float>(), output.mutable_data_ptr<int32_t>(),
|
||||
lengths.const_data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
|
||||
static_cast<uint32_t>(TopK), static_cast<uint32_t>(stride), stream);
|
||||
TORCH_CHECK(status == cudaSuccess,
|
||||
STD_TORCH_CHECK(status == cudaSuccess,
|
||||
"FilteredTopK failed: ", cudaGetErrorString(status));
|
||||
} else {
|
||||
TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor");
|
||||
TORCH_CHECK(workspace.dtype() == torch::kUInt8, "workspace must be uint8");
|
||||
STD_TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor");
|
||||
STD_TORCH_CHECK(
|
||||
workspace.scalar_type() == torch::headeronly::ScalarType::Byte,
|
||||
"workspace must be uint8");
|
||||
|
||||
int effective_max_smem;
|
||||
if (num_rows <= 4) {
|
||||
@@ -99,7 +101,7 @@ void launch_persistent_topk(const torch::Tensor& logits,
|
||||
&occupancy, P::persistent_topk_kernel<TopK, 1>, P::kThreadsPerBlock,
|
||||
smem_size);
|
||||
}
|
||||
TORCH_CHECK(occ_err == cudaSuccess,
|
||||
STD_TORCH_CHECK(occ_err == cudaSuccess,
|
||||
"persistent_topk occupancy query failed: ",
|
||||
cudaGetErrorString(occ_err));
|
||||
if (occupancy < 1) occupancy = 1;
|
||||
@@ -131,7 +133,8 @@ void launch_persistent_topk(const torch::Tensor& logits,
|
||||
// If the cooperative launch wouldn't fit, fall back to FilteredTopK
|
||||
// instead of deadlocking. Only relevant when needs_cooperative.
|
||||
if (needs_cooperative && total_ctas > hw_resident_cap) {
|
||||
TORCH_CHECK(max_smem_per_block >= 128 * 1024,
|
||||
STD_TORCH_CHECK(
|
||||
max_smem_per_block >= 128 * 1024,
|
||||
"persistent_topk would oversubscribe and the FilteredTopK "
|
||||
"fallback requires >=128KB smem per block (have ",
|
||||
max_smem_per_block, "). total_ctas=", total_ctas,
|
||||
@@ -140,17 +143,18 @@ void launch_persistent_topk(const torch::Tensor& logits,
|
||||
", smem=", smem_size, ").");
|
||||
cudaError_t status =
|
||||
vllm::FilteredTopKRaggedTransform<float, int32_t, TopK>(
|
||||
logits.data_ptr<float>(), output.data_ptr<int32_t>(),
|
||||
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
|
||||
static_cast<uint32_t>(TopK), static_cast<uint32_t>(stride),
|
||||
stream);
|
||||
TORCH_CHECK(status == cudaSuccess,
|
||||
"FilteredTopK fallback failed: ", cudaGetErrorString(status));
|
||||
logits.const_data_ptr<float>(),
|
||||
output.mutable_data_ptr<int32_t>(),
|
||||
lengths.const_data_ptr<int32_t>(),
|
||||
static_cast<uint32_t>(num_rows), static_cast<uint32_t>(TopK),
|
||||
static_cast<uint32_t>(stride), stream);
|
||||
STD_TORCH_CHECK(status == cudaSuccess, "FilteredTopK fallback failed: ",
|
||||
cudaGetErrorString(status));
|
||||
return;
|
||||
}
|
||||
|
||||
size_t state_bytes = num_groups * sizeof(P::RadixRowState);
|
||||
TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
|
||||
STD_TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
|
||||
"workspace too small, need ", state_bytes, " bytes");
|
||||
|
||||
// Zero the per-group RadixRowState region before launch.
|
||||
@@ -179,22 +183,22 @@ void launch_persistent_topk(const torch::Tensor& logits,
|
||||
// first red_release. cudaMemsetAsync is stream-ordered: the zero
|
||||
// is globally visible before any CTA runs.
|
||||
{
|
||||
cudaError_t mz_err = cudaMemsetAsync(workspace.data_ptr<uint8_t>(), 0,
|
||||
state_bytes, stream);
|
||||
TORCH_CHECK(mz_err == cudaSuccess,
|
||||
cudaError_t mz_err = cudaMemsetAsync(
|
||||
workspace.mutable_data_ptr<uint8_t>(), 0, state_bytes, stream);
|
||||
STD_TORCH_CHECK(mz_err == cudaSuccess,
|
||||
"row_states memset failed: ", cudaGetErrorString(mz_err));
|
||||
}
|
||||
|
||||
P::PersistentTopKParams params;
|
||||
params.input = logits.data_ptr<float>();
|
||||
params.output = output.data_ptr<int32_t>();
|
||||
params.lengths = lengths.data_ptr<int32_t>();
|
||||
params.input = logits.const_data_ptr<float>();
|
||||
params.output = output.mutable_data_ptr<int32_t>();
|
||||
params.lengths = lengths.const_data_ptr<int32_t>();
|
||||
params.num_rows = static_cast<uint32_t>(num_rows);
|
||||
params.stride = static_cast<uint32_t>(stride);
|
||||
params.top_k = static_cast<uint32_t>(TopK);
|
||||
params.chunk_size = chunk_size;
|
||||
params.row_states =
|
||||
reinterpret_cast<P::RadixRowState*>(workspace.data_ptr<uint8_t>());
|
||||
params.row_states = reinterpret_cast<P::RadixRowState*>(
|
||||
workspace.mutable_data_ptr<uint8_t>());
|
||||
params.ctas_per_group = ctas_per_group;
|
||||
params.max_seq_len = static_cast<uint32_t>(max_seq_len);
|
||||
|
||||
@@ -203,7 +207,7 @@ void launch_persistent_topk(const torch::Tensor& logits,
|
||||
auto kernel = &P::persistent_topk_kernel<TOPK_VAL, VS>; \
|
||||
cudaError_t err = cudaFuncSetAttribute( \
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \
|
||||
TORCH_CHECK(err == cudaSuccess, \
|
||||
STD_TORCH_CHECK(err == cudaSuccess, \
|
||||
"Failed to set smem: ", cudaGetErrorString(err)); \
|
||||
kernel<<<total_ctas, P::kThreadsPerBlock, smem_size, stream>>>(params); \
|
||||
} while (0)
|
||||
@@ -219,36 +223,41 @@ void launch_persistent_topk(const torch::Tensor& logits,
|
||||
}
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
TORCH_CHECK(err == cudaSuccess,
|
||||
STD_TORCH_CHECK(err == cudaSuccess,
|
||||
"persistent_topk failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
#endif
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
||||
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
|
||||
void persistent_topk(const torch::stable::Tensor& logits,
|
||||
const torch::stable::Tensor& lengths,
|
||||
torch::stable::Tensor& output,
|
||||
torch::stable::Tensor& workspace, int64_t k,
|
||||
int64_t max_seq_len) {
|
||||
#ifndef USE_ROCM
|
||||
TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
|
||||
TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
|
||||
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
|
||||
TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported");
|
||||
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
|
||||
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
|
||||
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
|
||||
TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
|
||||
STD_TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
|
||||
STD_TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
|
||||
STD_TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
|
||||
STD_TORCH_CHECK(logits.scalar_type() == torch::headeronly::ScalarType::Float,
|
||||
"Only float32 supported");
|
||||
STD_TORCH_CHECK(lengths.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"lengths must be int32");
|
||||
STD_TORCH_CHECK(output.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"output must be int32");
|
||||
STD_TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
|
||||
STD_TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
|
||||
"lengths must be 1D or 2D");
|
||||
TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
|
||||
TORCH_CHECK(output.dim() == 2, "output must be 2D");
|
||||
STD_TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
|
||||
STD_TORCH_CHECK(output.dim() == 2, "output must be 2D");
|
||||
|
||||
const int64_t num_rows = logits.size(0);
|
||||
const int64_t stride = logits.stride(0);
|
||||
|
||||
TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
|
||||
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
|
||||
STD_TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
|
||||
STD_TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
|
||||
"output size mismatch");
|
||||
TORCH_CHECK(k == 512 || k == 1024 || k == 2048,
|
||||
STD_TORCH_CHECK(
|
||||
k == 512 || k == 1024 || k == 2048,
|
||||
"persistent_topk supports k=512, k=1024, or k=2048, got k=", k);
|
||||
|
||||
if (k == 512) {
|
||||
@@ -262,6 +271,6 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
||||
max_seq_len);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "persistent_topk is not supported on ROCm");
|
||||
STD_TORCH_CHECK(false, "persistent_topk is not supported on ROCm");
|
||||
#endif
|
||||
}
|
||||
@@ -263,6 +263,20 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor");
|
||||
#endif
|
||||
|
||||
// Merge attn states
|
||||
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
// can be used to combine partial attention results (in the split-KV case)
|
||||
ops.def(
|
||||
"merge_attn_states("
|
||||
" Tensor! output,"
|
||||
" Tensor!? output_lse,"
|
||||
" Tensor prefix_output,"
|
||||
" Tensor prefix_lse,"
|
||||
" Tensor suffix_output,"
|
||||
" Tensor suffix_lse,"
|
||||
" int!? prefill_tokens_with_context,"
|
||||
" Tensor? output_scale=None) -> ()");
|
||||
|
||||
// Hadamard transforms
|
||||
// conditionally compiled so impl registration is in source file
|
||||
ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");
|
||||
@@ -319,6 +333,26 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
"bool is_neox, Tensor position_ids, "
|
||||
"int forced_token_heads_per_warp=-1) -> ()");
|
||||
|
||||
// Apply repetition penalties to logits in-place.
|
||||
ops.def(
|
||||
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
|
||||
"Tensor output_mask, Tensor repetition_penalties) -> ()");
|
||||
|
||||
// Optimized top-k per row operations.
|
||||
ops.def(
|
||||
"top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
|
||||
"Tensor! indices, int numRows, int stride0, "
|
||||
"int stride1, int topK) -> ()");
|
||||
|
||||
ops.def(
|
||||
"top_k_per_row_decode(Tensor logits, int next_n, "
|
||||
"Tensor seq_lens, Tensor! indices, "
|
||||
"int numRows, int stride0, int stride1, int topK) -> ()");
|
||||
|
||||
ops.def(
|
||||
"persistent_topk(Tensor logits, Tensor lengths, Tensor! output, "
|
||||
"Tensor workspace, int k, int max_seq_len) -> ()");
|
||||
|
||||
// Activation ops
|
||||
// Activation function used in SwiGLU.
|
||||
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
|
||||
@@ -422,6 +456,24 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
"int type, SymInt row, SymInt tokens) -> Tensor");
|
||||
|
||||
ops.def("ggml_moe_get_block_size(int type) -> int");
|
||||
|
||||
// Mamba selective scan kernel
|
||||
ops.def(
|
||||
"selective_scan_fwd(Tensor! u, Tensor! delta,"
|
||||
"Tensor! A, Tensor! B, Tensor! C,"
|
||||
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
|
||||
"bool delta_softplus,"
|
||||
"Tensor? query_start_loc,"
|
||||
"Tensor? cache_indices,"
|
||||
"Tensor? has_initial_state,"
|
||||
"Tensor! ssm_states,"
|
||||
"int null_block_id,"
|
||||
"int block_size,"
|
||||
"Tensor? block_idx_first_scheduled_token,"
|
||||
"Tensor? block_idx_last_scheduled_token,"
|
||||
"Tensor? initial_state_idx,"
|
||||
"Tensor? cu_chunk_seqlen,"
|
||||
"Tensor? last_chunk_indices) -> ()");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
@@ -469,6 +521,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
// files (allspark_repack.cu and allspark_qgemm_w8a16.cu)
|
||||
#endif
|
||||
|
||||
ops.impl("merge_attn_states", TORCH_BOX(&merge_attn_states));
|
||||
|
||||
// Layernorm kernels (shared CUDA/ROCm)
|
||||
ops.impl("rms_norm", TORCH_BOX(&rms_norm));
|
||||
ops.impl("fused_add_rms_norm", TORCH_BOX(&fused_add_rms_norm));
|
||||
@@ -487,6 +541,13 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding));
|
||||
ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope));
|
||||
|
||||
// Sampler kernels (shared CUDA/ROCm)
|
||||
ops.impl("apply_repetition_penalties_",
|
||||
TORCH_BOX(&apply_repetition_penalties_));
|
||||
ops.impl("top_k_per_row_prefill", TORCH_BOX(&top_k_per_row_prefill));
|
||||
ops.impl("top_k_per_row_decode", TORCH_BOX(&top_k_per_row_decode));
|
||||
ops.impl("persistent_topk", TORCH_BOX(&persistent_topk));
|
||||
|
||||
// Activation kernels (shared CUDA/ROCm)
|
||||
ops.impl("silu_and_mul", TORCH_BOX(&silu_and_mul));
|
||||
ops.impl("mul_and_silu", TORCH_BOX(&mul_and_silu));
|
||||
@@ -519,6 +580,7 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
ops.impl("ggml_mul_mat_a8", TORCH_BOX(&ggml_mul_mat_a8));
|
||||
ops.impl("ggml_moe_a8", TORCH_BOX(&ggml_moe_a8));
|
||||
ops.impl("ggml_moe_a8_vec", TORCH_BOX(&ggml_moe_a8_vec));
|
||||
ops.impl("selective_scan_fwd", TORCH_BOX(&selective_scan_fwd));
|
||||
}
|
||||
|
||||
// These capability-check functions take only primitive args (no tensors), so
|
||||
|
||||
-43
@@ -54,13 +54,6 @@ void paged_attention_v2(
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step);
|
||||
|
||||
void merge_attn_states(
|
||||
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
|
||||
const std::optional<int64_t> prefill_tokens_with_context,
|
||||
const std::optional<torch::Tensor>& output_scale = std::nullopt);
|
||||
|
||||
// rms_norm and fused_add_rms_norm declarations also exist in
|
||||
// csrc/libtorch_stable/ops.h (torch::stable ABI for CUDA). They remain here
|
||||
// because the CPU build still uses these torch::Tensor declarations.
|
||||
@@ -76,26 +69,6 @@ torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
||||
torch::Tensor const& cos_sin_cache, int64_t q_head_padded, double eps,
|
||||
int64_t cache_block_size);
|
||||
|
||||
void apply_repetition_penalties_(torch::Tensor& logits,
|
||||
const torch::Tensor& prompt_mask,
|
||||
const torch::Tensor& output_mask,
|
||||
const torch::Tensor& repetition_penalties);
|
||||
|
||||
void top_k_per_row_prefill(const torch::Tensor& logits,
|
||||
const torch::Tensor& rowStarts,
|
||||
const torch::Tensor& rowEnds, torch::Tensor& indices,
|
||||
int64_t numRows, int64_t stride0, int64_t stride1,
|
||||
int64_t topK);
|
||||
|
||||
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
||||
const torch::Tensor& seqLens, torch::Tensor& indices,
|
||||
int64_t numRows, int64_t stride0, int64_t stride1,
|
||||
int64_t topK);
|
||||
|
||||
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
||||
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
|
||||
int64_t max_seq_len);
|
||||
|
||||
void silu_and_mul_per_block_quant(torch::Tensor& out,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor& scales, int64_t group_size,
|
||||
@@ -150,22 +123,6 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor& scales,
|
||||
std::optional<torch::Tensor> const& azp);
|
||||
|
||||
void selective_scan_fwd(
|
||||
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
|
||||
const torch::Tensor& B, const torch::Tensor& C,
|
||||
const std::optional<torch::Tensor>& D_,
|
||||
const std::optional<torch::Tensor>& z_,
|
||||
const std::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
|
||||
const std::optional<torch::Tensor>& query_start_loc,
|
||||
const std::optional<torch::Tensor>& cache_indices,
|
||||
const std::optional<torch::Tensor>& has_initial_state,
|
||||
const torch::Tensor& ssm_states, int64_t null_block_id, int64_t block_size,
|
||||
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
|
||||
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
|
||||
const std::optional<torch::Tensor>& initial_state_idx,
|
||||
const std::optional<torch::Tensor>& cu_chunk_seqlen,
|
||||
const std::optional<torch::Tensor>& last_chunk_indices);
|
||||
|
||||
torch::Tensor dynamic_4bit_int_moe_cpu(
|
||||
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
|
||||
torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
|
||||
|
||||
@@ -128,7 +128,7 @@ struct RadixRowState {
|
||||
struct PersistentTopKParams {
|
||||
const float* __restrict__ input; // [num_rows, stride]
|
||||
int32_t* __restrict__ output; // [num_rows, top_k]
|
||||
int32_t* __restrict__ lengths; // [num_rows]
|
||||
const int32_t* __restrict__ lengths; // [num_rows]
|
||||
RadixRowState* row_states; // large path: per-group state
|
||||
uint32_t num_rows;
|
||||
uint32_t stride;
|
||||
@@ -1269,9 +1269,11 @@ constexpr int ComputeFilteredTopKVecSize(uint32_t max_len) {
|
||||
}
|
||||
|
||||
template <typename DType, typename IdType, uint32_t MAX_K = 2048>
|
||||
cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices,
|
||||
IdType* lengths, uint32_t num_rows,
|
||||
uint32_t top_k_val, uint32_t max_len,
|
||||
cudaError_t FilteredTopKRaggedTransform(const DType* input,
|
||||
IdType* output_indices,
|
||||
const IdType* lengths,
|
||||
uint32_t num_rows, uint32_t top_k_val,
|
||||
uint32_t max_len,
|
||||
cudaStream_t stream = 0) {
|
||||
constexpr size_t smem_size = FILTERED_TOPK_SMEM_DYNAMIC;
|
||||
constexpr int MAX_VEC = 16 / sizeof(DType);
|
||||
|
||||
@@ -62,21 +62,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
|
||||
|
||||
// Merge attn states
|
||||
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
// can be used to combine partial attention results (in the split-KV case)
|
||||
ops.def(
|
||||
"merge_attn_states("
|
||||
" Tensor! output,"
|
||||
" Tensor!? output_lse,"
|
||||
" Tensor prefix_output,"
|
||||
" Tensor prefix_lse,"
|
||||
" Tensor suffix_output,"
|
||||
" Tensor suffix_lse,"
|
||||
" int!? prefill_tokens_with_context,"
|
||||
" Tensor? output_scale=None) -> ()");
|
||||
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
|
||||
|
||||
// Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
|
||||
ops.def(
|
||||
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
||||
@@ -105,31 +90,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
|
||||
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);
|
||||
|
||||
// Apply repetition penalties to logits in-place
|
||||
ops.def(
|
||||
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
|
||||
"Tensor output_mask, Tensor repetition_penalties) -> ()");
|
||||
ops.impl("apply_repetition_penalties_", torch::kCUDA,
|
||||
&apply_repetition_penalties_);
|
||||
|
||||
// Optimized top-k per row operation
|
||||
ops.def(
|
||||
"top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
|
||||
"Tensor! indices, int numRows, int stride0, "
|
||||
"int stride1, int topK) -> ()");
|
||||
ops.impl("top_k_per_row_prefill", torch::kCUDA, &top_k_per_row_prefill);
|
||||
|
||||
ops.def(
|
||||
"top_k_per_row_decode(Tensor logits, int next_n, "
|
||||
"Tensor seq_lens, Tensor! indices, "
|
||||
"int numRows, int stride0, int stride1, int topK) -> ()");
|
||||
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
|
||||
|
||||
ops.def(
|
||||
"persistent_topk(Tensor logits, Tensor lengths, Tensor! output, "
|
||||
"Tensor workspace, int k, int max_seq_len) -> ()");
|
||||
ops.impl("persistent_topk", torch::kCUDA, &persistent_topk);
|
||||
|
||||
// Quantization ops
|
||||
#ifndef USE_ROCM
|
||||
|
||||
@@ -230,25 +190,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
#endif
|
||||
|
||||
// Mamba selective scan kernel
|
||||
ops.def(
|
||||
"selective_scan_fwd(Tensor! u, Tensor! delta,"
|
||||
"Tensor! A, Tensor! B, Tensor! C,"
|
||||
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
|
||||
"bool delta_softplus,"
|
||||
"Tensor? query_start_loc,"
|
||||
"Tensor? cache_indices,"
|
||||
"Tensor? has_initial_state,"
|
||||
"Tensor! ssm_states,"
|
||||
"int null_block_id,"
|
||||
"int block_size,"
|
||||
"Tensor? block_idx_first_scheduled_token,"
|
||||
"Tensor? block_idx_last_scheduled_token,"
|
||||
"Tensor? initial_state_idx,"
|
||||
"Tensor? cu_chunk_seqlen,"
|
||||
"Tensor? last_chunk_indices) -> ()");
|
||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
ops.def(
|
||||
"minimax_allreduce_rms("
|
||||
|
||||
@@ -141,6 +141,7 @@ bbc5b7ede = "bbc5b7ede"
|
||||
NOOPs = "NOOPs"
|
||||
nin_shortcut = "nin_shortcut"
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin = "cudaDevAttrMaxSharedMemoryPerBlockOptin"
|
||||
sharedMemPerBlockOptin = "sharedMemPerBlockOptin"
|
||||
|
||||
depthwise_seperable_out_channel = "depthwise_seperable_out_channel"
|
||||
pard_token = "pard_token"
|
||||
|
||||
Reference in New Issue
Block a user