[8/n] Migrate merge_attn_states, mamba, sampler to torch stable ABI (continued) (#43361)

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Chris Leonard <chleonar@redhat.com>
Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Co-authored-by: Shengqi Chen <harry-chen@outlook.com>
This commit is contained in:
Chris Leonard
2026-05-27 12:35:24 -04:00
committed by GitHub
parent 05c50c721e
commit 284e6f543d
13 changed files with 432 additions and 403 deletions
+5 -5
View File
@@ -305,14 +305,10 @@ endif()
#
set(VLLM_EXT_SRC
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
"csrc/cache_kernels.cu"
"csrc/cache_kernels_fused.cu"
"csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/sampler.cu"
"csrc/topk.cu"
"csrc/cuda_view.cu"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/quantization/activation_kernels.cu"
@@ -633,7 +629,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/fused_qknorm_rope_kernel.cu"
"csrc/libtorch_stable/layernorm_kernels.cu"
"csrc/libtorch_stable/layernorm_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu")
"csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/libtorch_stable/attention/merge_attn_states.cu"
"csrc/libtorch_stable/sampler.cu"
"csrc/libtorch_stable/topk.cu"
"csrc/libtorch_stable/mamba/selective_scan_fwd.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
@@ -1,14 +1,14 @@
#include <optional>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <limits>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "../quantization/w8a8/fp8/common.cuh"
#include "../torch_utils.h"
#include "../dispatch_utils.h"
#include <torch/headeronly/core/ScalarType.h>
#include "../../attention/attention_dtypes.h"
#include "../../attention/attention_utils.cuh"
#include "../../quantization/w8a8/fp8/common.cuh"
namespace vllm {
@@ -196,17 +196,17 @@ __global__ void merge_attn_states_kernel(
// The following macro is used to dispatch the conversion function based on
// the output data type. The FN is a macro that calls a function with
// template<typename scalar_t>.
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
{ \
if (scalar_dtype == at::ScalarType::Float) { \
fn(float); \
} else if (scalar_dtype == at::ScalarType::Half) { \
fn(uint16_t); \
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
fn(__nv_bfloat16); \
} else { \
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
} \
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
{ \
if (scalar_dtype == torch::headeronly::ScalarType::Float) { \
fn(float); \
} else if (scalar_dtype == torch::headeronly::ScalarType::Half) { \
fn(uint16_t); \
} else if (scalar_dtype == torch::headeronly::ScalarType::BFloat16) { \
fn(__nv_bfloat16); \
} else { \
STD_TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
} \
}
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, output_t, NUM_THREADS, \
@@ -245,11 +245,14 @@ __global__ void merge_attn_states_kernel(
*/
template <typename scalar_t>
void merge_attn_states_launcher(
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
torch::stable::Tensor& output,
std::optional<torch::stable::Tensor> output_lse,
const torch::stable::Tensor& prefix_output,
const torch::stable::Tensor& prefix_lse,
const torch::stable::Tensor& suffix_output,
const torch::stable::Tensor& suffix_lse,
const std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::Tensor>& output_scale) {
const std::optional<torch::stable::Tensor>& output_scale) {
constexpr uint NUM_THREADS = 128;
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
@@ -258,23 +261,23 @@ void merge_attn_states_launcher(
const uint output_head_stride = output.stride(1);
// Thread mapping is based on input BF16 pack_size
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
STD_TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
const uint prefix_num_tokens =
prefill_tokens_with_context.has_value()
? static_cast<uint>(prefill_tokens_with_context.value())
: num_tokens;
TORCH_CHECK(prefix_num_tokens <= num_tokens,
"prefix_num_tokens must be <= num_tokens");
STD_TORCH_CHECK(prefix_num_tokens <= num_tokens,
"prefix_num_tokens must be <= num_tokens");
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();
output_lse_ptr = output_lse.value().mutable_data_ptr<float>();
}
float* output_scale_ptr = nullptr;
if (output_scale.has_value()) {
output_scale_ptr = output_scale.value().data_ptr<float>();
output_scale_ptr = output_scale.value().mutable_data_ptr<float>();
}
// Process one pack of elements per thread. For float, the pack_size is 4;
// for half/bf16, the pack_size is 8.
@@ -284,14 +287,15 @@ void merge_attn_states_launcher(
dim3 block(NUM_THREADS);
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
auto stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
prefix_output.get_device_index());
auto stream = get_current_cuda_stream();
if (output_scale.has_value()) {
// FP8 output path - dispatch on output FP8 type
VLLM_DISPATCH_FP8_TYPES(output.scalar_type(), "merge_attn_states_fp8", [&] {
LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true);
});
VLLM_STABLE_DISPATCH_FP8_TYPES(
output.scalar_type(), "merge_attn_states_fp8",
[&] { LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true); });
} else {
// Original BF16/FP16/FP32 output path
LAUNCH_MERGE_ATTN_STATES(scalar_t, scalar_t, NUM_THREADS, false);
@@ -305,26 +309,29 @@ void merge_attn_states_launcher(
suffix_lse, prefill_tokens_with_context, output_scale); \
}
void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse,
std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::Tensor>& output_scale) {
void merge_attn_states(
torch::stable::Tensor& output,
std::optional<torch::stable::Tensor> output_lse,
const torch::stable::Tensor& prefix_output,
const torch::stable::Tensor& prefix_lse,
const torch::stable::Tensor& suffix_output,
const torch::stable::Tensor& suffix_lse,
const std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::stable::Tensor>& output_scale) {
if (output_scale.has_value()) {
TORCH_CHECK(output.scalar_type() == at::ScalarType::Float8_e4m3fn ||
output.scalar_type() == at::ScalarType::Float8_e4m3fnuz,
"output must be FP8 when output_scale is provided, got: ",
output.scalar_type());
STD_TORCH_CHECK(
output.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn ||
output.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fnuz,
"output must be FP8 when output_scale is provided, got: ",
output.scalar_type());
} else {
TORCH_CHECK(output.scalar_type() == prefix_output.scalar_type(),
"output dtype (", output.scalar_type(),
") must match prefix_output dtype (",
prefix_output.scalar_type(), ") when output_scale is not set");
STD_TORCH_CHECK(
output.scalar_type() == prefix_output.scalar_type(), "output dtype (",
output.scalar_type(), ") must match prefix_output dtype (",
prefix_output.scalar_type(), ") when output_scale is not set");
}
// Always dispatch on prefix_output (input) dtype
DISPATCH_BY_SCALAR_DTYPE(prefix_output.dtype(),
DISPATCH_BY_SCALAR_DTYPE(prefix_output.scalar_type(),
CALL_MERGE_ATTN_STATES_LAUNCHER);
}
@@ -12,6 +12,9 @@
#include <hip/hip_bf16.h>
#endif
#include <cuda_fp16.h>
#include <torch/headeronly/util/Half.h>
#include <torch/headeronly/util/BFloat16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SSMParamsBase {
@@ -159,8 +162,8 @@ struct Converter{
};
template<int N>
struct Converter<at::Half, N>{
static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
struct Converter<torch::headeronly::Half, N>{
static inline __device__ void to_float(const torch::headeronly::Half (&src)[N], float (&dst)[N]) {
static_assert(N % 2 == 0);
auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
@@ -171,8 +174,8 @@ struct Converter<at::Half, N>{
#if __CUDA_ARCH__ >= 800
template<int N>
struct Converter<at::BFloat16, N>{
static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
struct Converter<torch::headeronly::BFloat16, N>{
static inline __device__ void to_float(const torch::headeronly::BFloat16 (&src)[N], float (&dst)[N]) {
static_assert(N % 2 == 0);
auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
@@ -1,18 +1,9 @@
// clang-format off
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../torch_utils.h"
#include <torch/csrc/stable/macros.h>
#include "selective_scan.h"
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#ifdef USE_ROCM
#include <c10/hip/HIPException.h> // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK
#else
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#endif
#ifndef USE_ROCM
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
@@ -416,15 +407,15 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
#ifdef USE_ROCM
C10_HIP_CHECK(hipFuncSetAttribute(
STD_CUDA_CHECK(hipFuncSetAttribute(
reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#else
C10_CUDA_CHECK(cudaFuncSetAttribute(
STD_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#endif
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
STD_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
@@ -462,46 +453,46 @@ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
#endif
}
template void selective_scan_fwd_cuda<at::BFloat16, float, at::BFloat16>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::BFloat16, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, float, at::Half>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<torch::headeronly::BFloat16, float, torch::headeronly::BFloat16>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<torch::headeronly::BFloat16, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<torch::headeronly::Half, float, torch::headeronly::Half>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<torch::headeronly::Half, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<float, float, float>(SSMParamsBase &params, cudaStream_t stream);
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_SHAPE(x, ...) STD_TORCH_CHECK(x.sizes().equals(torch::headeronly::IntHeaderOnlyArrayRef({__VA_ARGS__})), #x " must have shape (" #__VA_ARGS__ ")")
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, STYPE, NAME, ...) \
if (ITYPE == at::ScalarType::Half) { \
using input_t = at::Half; \
if (ITYPE == torch::headeronly::ScalarType::Half) { \
using input_t = torch::headeronly::Half; \
using weight_t = float; \
if (STYPE == at::ScalarType::Half) { \
using state_t = at::Half; \
if (STYPE == torch::headeronly::ScalarType::Half) { \
using state_t = torch::headeronly::Half; \
__VA_ARGS__(); \
} else if (STYPE == at::ScalarType::Float) { \
} else if (STYPE == torch::headeronly::ScalarType::Float) { \
using state_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
STD_TORCH_CHECK(false, #NAME " not implemented for state type '", STYPE, "'"); \
} \
} else if (ITYPE == at::ScalarType::BFloat16) { \
using input_t = at::BFloat16; \
} else if (ITYPE == torch::headeronly::ScalarType::BFloat16) { \
using input_t = torch::headeronly::BFloat16; \
using weight_t = float; \
if (STYPE == at::ScalarType::BFloat16) { \
using state_t = at::BFloat16; \
if (STYPE == torch::headeronly::ScalarType::BFloat16) { \
using state_t = torch::headeronly::BFloat16; \
__VA_ARGS__(); \
} else if (STYPE == at::ScalarType::Float) { \
} else if (STYPE == torch::headeronly::ScalarType::Float) { \
using state_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
STD_TORCH_CHECK(false, #NAME " not implemented for state type '", STYPE, "'"); \
} \
} else if (ITYPE == at::ScalarType::Float) { \
} else if (ITYPE == torch::headeronly::ScalarType::Float) { \
using input_t = float; \
using weight_t = float; \
using state_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
STD_TORCH_CHECK(false, #NAME " not implemented for input type '", ITYPE, "'"); \
}
@@ -518,30 +509,30 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const bool is_variable_B,
const bool is_variable_C,
// device pointers
const torch::Tensor u,
const torch::Tensor delta,
const torch::Tensor A,
const torch::Tensor B,
const torch::Tensor C,
const torch::Tensor out,
const torch::Tensor z,
const torch::Tensor out_z,
const std::optional<at::Tensor>& D,
const std::optional<at::Tensor>& delta_bias,
const torch::Tensor ssm_states,
const torch::stable::Tensor u,
const torch::stable::Tensor delta,
const torch::stable::Tensor A,
const torch::stable::Tensor B,
const torch::stable::Tensor C,
const torch::stable::Tensor out,
const torch::stable::Tensor z,
const torch::stable::Tensor out_z,
const std::optional<torch::stable::Tensor>& D,
const std::optional<torch::stable::Tensor>& delta_bias,
const torch::stable::Tensor ssm_states,
bool has_z,
bool delta_softplus,
const std::optional<at::Tensor>& query_start_loc,
const std::optional<at::Tensor>& cache_indices,
const std::optional<at::Tensor>& has_initial_state,
const std::optional<torch::stable::Tensor>& query_start_loc,
const std::optional<torch::stable::Tensor>& cache_indices,
const std::optional<torch::stable::Tensor>& has_initial_state,
bool varlen,
int64_t null_block_id,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::Tensor> &initial_state_idx,
const std::optional<torch::Tensor> &cu_chunk_seqlen,
const std::optional<torch::Tensor> &last_chunk_indices) {
const std::optional<torch::stable::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::stable::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::stable::Tensor> &initial_state_idx,
const std::optional<torch::stable::Tensor> &cu_chunk_seqlen,
const std::optional<torch::stable::Tensor> &last_chunk_indices) {
// Reset the parameters
memset(&params, 0, sizeof(params));
@@ -654,45 +645,45 @@ void set_ssm_params_fwd(SSMParamsBase &params,
}
}
void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
const std::optional<torch::Tensor> &D_,
const std::optional<torch::Tensor> &z_,
const std::optional<torch::Tensor> &delta_bias_,
void selective_scan_fwd(const torch::stable::Tensor &u, const torch::stable::Tensor &delta,
const torch::stable::Tensor &A, const torch::stable::Tensor &B, const torch::stable::Tensor &C,
const std::optional<torch::stable::Tensor> &D_,
const std::optional<torch::stable::Tensor> &z_,
const std::optional<torch::stable::Tensor> &delta_bias_,
bool delta_softplus,
const std::optional<torch::Tensor> &query_start_loc,
const std::optional<torch::Tensor> &cache_indices,
const std::optional<torch::Tensor> &has_initial_state,
const torch::Tensor &ssm_states,
const std::optional<torch::stable::Tensor> &query_start_loc,
const std::optional<torch::stable::Tensor> &cache_indices,
const std::optional<torch::stable::Tensor> &has_initial_state,
const torch::stable::Tensor &ssm_states,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t null_block_id,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::Tensor> &initial_state_idx,
const std::optional<torch::Tensor> &cu_chunk_seqlen,
const std::optional<torch::Tensor> &last_chunk_indices) {
const std::optional<torch::stable::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::stable::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::stable::Tensor> &initial_state_idx,
const std::optional<torch::stable::Tensor> &cu_chunk_seqlen,
const std::optional<torch::stable::Tensor> &last_chunk_indices) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
TORCH_CHECK(weight_type == at::ScalarType::Float);
STD_TORCH_CHECK(input_type == torch::headeronly::ScalarType::Float || input_type == torch::headeronly::ScalarType::Half || input_type == torch::headeronly::ScalarType::BFloat16);
STD_TORCH_CHECK(weight_type == torch::headeronly::ScalarType::Float);
const bool is_variable_B = B.dim() >= 3;
const bool is_variable_C = C.dim() >= 3;
TORCH_CHECK(delta.scalar_type() == input_type);
TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
STD_TORCH_CHECK(delta.scalar_type() == input_type);
STD_TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
STD_TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
TORCH_CHECK(u.is_cuda());
TORCH_CHECK(delta.is_cuda());
TORCH_CHECK(A.is_cuda());
TORCH_CHECK(B.is_cuda());
TORCH_CHECK(C.is_cuda());
STD_TORCH_CHECK(u.is_cuda());
STD_TORCH_CHECK(delta.is_cuda());
STD_TORCH_CHECK(A.is_cuda());
STD_TORCH_CHECK(B.is_cuda());
STD_TORCH_CHECK(C.is_cuda());
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
STD_TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
STD_TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
const auto sizes = u.sizes();
const bool varlen = query_start_loc.has_value();
@@ -702,7 +693,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const int dstate = A.size(1);
const int n_groups = varlen ? B.size(0) : B.size(1);
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
STD_TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
if (varlen) {
CHECK_SHAPE(u, dim, seqlen);
@@ -712,94 +703,94 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
CHECK_SHAPE(delta, batch_size, dim, seqlen);
}
CHECK_SHAPE(A, dim, dstate);
TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size")
STD_TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size");
if (varlen) {
CHECK_SHAPE(B, n_groups, dstate, seqlen);
} else {
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
}
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
STD_TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size")
STD_TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size");
if (varlen) {
CHECK_SHAPE(C, n_groups, dstate, seqlen);
} else {
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
}
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
STD_TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
if (D_.has_value()) {
auto D = D_.value();
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
TORCH_CHECK(D.is_cuda());
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
STD_TORCH_CHECK(D.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(D.is_cuda());
STD_TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
CHECK_SHAPE(D, dim);
}
if (delta_bias_.has_value()) {
auto delta_bias = delta_bias_.value();
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
TORCH_CHECK(delta_bias.is_cuda());
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
STD_TORCH_CHECK(delta_bias.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(delta_bias.is_cuda());
STD_TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
CHECK_SHAPE(delta_bias, dim);
}
if (has_initial_state.has_value()) {
auto has_initial_state_ = has_initial_state.value();
TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
TORCH_CHECK(has_initial_state_.is_cuda());
STD_TORCH_CHECK(has_initial_state_.scalar_type() == torch::headeronly::ScalarType::Bool);
STD_TORCH_CHECK(has_initial_state_.is_cuda());
CHECK_SHAPE(has_initial_state_, batch_size);
}
if (query_start_loc.has_value()) {
auto query_start_loc_ = query_start_loc.value();
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(query_start_loc_.is_cuda());
STD_TORCH_CHECK(query_start_loc_.scalar_type() == torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(query_start_loc_.is_cuda());
}
if (cache_indices.has_value()) {
auto cache_indices_ = cache_indices.value();
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(cache_indices_.is_cuda());
STD_TORCH_CHECK(cache_indices_.scalar_type() == torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(cache_indices_.is_cuda());
// cache_indices can be either 1D (batch_size,) for non-APC mode
// or 2D (batch_size, max_positions) for APC mode
const bool is_apc_mode = block_idx_first_scheduled_token.has_value();
if (is_apc_mode) {
TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode");
TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size");
STD_TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode");
STD_TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size");
} else {
CHECK_SHAPE(cache_indices_, batch_size);
}
}
at::Tensor z, out_z;
torch::stable::Tensor z, out_z;
const bool has_z = z_.has_value();
if (has_z) {
z = z_.value();
TORCH_CHECK(z.scalar_type() == input_type);
TORCH_CHECK(z.is_cuda());
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
STD_TORCH_CHECK(z.scalar_type() == input_type);
STD_TORCH_CHECK(z.is_cuda());
STD_TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
if (varlen){
CHECK_SHAPE(z, dim, seqlen);
} else {
CHECK_SHAPE(z, batch_size, dim, seqlen);
}
out_z = z;
}
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at::Tensor out = delta;
torch::stable::Tensor out = delta;
// ssm_states can now be either the same as input_type or float32
auto state_type = ssm_states.scalar_type();
TORCH_CHECK(state_type == input_type || state_type == at::ScalarType::Float);
TORCH_CHECK(ssm_states.is_cuda());
TORCH_CHECK(ssm_states.stride(-1) == 1);
STD_TORCH_CHECK(state_type == input_type || state_type == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(ssm_states.is_cuda());
STD_TORCH_CHECK(ssm_states.stride(-1) == 1);
SSMParamsBase params;
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C,
@@ -823,8 +814,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
);
const at::cuda::OptionalCUDAGuard device_guard(device_of(u));
auto stream = at::cuda::getCurrentCUDAStream().stream();
const torch::stable::accelerator::DeviceGuard device_guard(u.get_device_index());
auto stream = get_current_cuda_stream();
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), ssm_states.scalar_type(), "selective_scan_fwd", [&] {
selective_scan_fwd_cuda<input_t, weight_t, state_t>(params, stream);
});
+53
View File
@@ -164,6 +164,17 @@ torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel,
#endif
// Attention kernels (shared CUDA/ROCm)
void merge_attn_states(
torch::stable::Tensor& output,
std::optional<torch::stable::Tensor> output_lse,
const torch::stable::Tensor& prefix_output,
const torch::stable::Tensor& prefix_lse,
const torch::stable::Tensor& suffix_output,
const torch::stable::Tensor& suffix_lse,
const std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::stable::Tensor>& output_scale = std::nullopt);
torch::stable::Tensor hadacore_transform(torch::stable::Tensor& x,
bool inplace);
@@ -220,6 +231,48 @@ void fused_qk_norm_rope(torch::stable::Tensor& qkv, int64_t num_heads_q,
torch::stable::Tensor& position_ids,
int64_t forced_token_heads_per_warp);
// Sampler kernels (shared CUDA/ROCm)
void apply_repetition_penalties_(
torch::stable::Tensor& logits, const torch::stable::Tensor& prompt_mask,
const torch::stable::Tensor& output_mask,
const torch::stable::Tensor& repetition_penalties);
void top_k_per_row_prefill(const torch::stable::Tensor& logits,
const torch::stable::Tensor& rowStarts,
const torch::stable::Tensor& rowEnds,
torch::stable::Tensor& indices, int64_t numRows,
int64_t stride0, int64_t stride1, int64_t topK);
void top_k_per_row_decode(const torch::stable::Tensor& logits, int64_t next_n,
const torch::stable::Tensor& seqLens,
torch::stable::Tensor& indices, int64_t numRows,
int64_t stride0, int64_t stride1, int64_t topK);
void persistent_topk(const torch::stable::Tensor& logits,
const torch::stable::Tensor& lengths,
torch::stable::Tensor& output,
torch::stable::Tensor& workspace, int64_t k,
int64_t max_seq_len);
void selective_scan_fwd(
const torch::stable::Tensor& u, const torch::stable::Tensor& delta,
const torch::stable::Tensor& A, const torch::stable::Tensor& B,
const torch::stable::Tensor& C,
const std::optional<torch::stable::Tensor>& D_,
const std::optional<torch::stable::Tensor>& z_,
const std::optional<torch::stable::Tensor>& delta_bias_,
bool delta_softplus,
const std::optional<torch::stable::Tensor>& query_start_loc,
const std::optional<torch::stable::Tensor>& cache_indices,
const std::optional<torch::stable::Tensor>& has_initial_state,
const torch::stable::Tensor& ssm_states, int64_t null_block_id,
int64_t block_size,
const std::optional<torch::stable::Tensor>& block_idx_first_scheduled_token,
const std::optional<torch::stable::Tensor>& block_idx_last_scheduled_token,
const std::optional<torch::stable::Tensor>& initial_state_idx,
const std::optional<torch::stable::Tensor>& cu_chunk_seqlen,
const std::optional<torch::stable::Tensor>& last_chunk_indices);
// Activation kernels (shared CUDA/ROCm)
void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
void silu_and_mul_clamp(torch::stable::Tensor& out,
@@ -1,8 +1,6 @@
#include "cuda_compat.h"
#include "../cuda_compat.h"
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include "torch_utils.h"
#ifndef USE_ROCM
#include <cub/cub.cuh>
@@ -618,14 +616,14 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
} // namespace vllm
void apply_repetition_penalties_(
torch::Tensor& logits, // [num_seqs, vocab_size], in-place
const torch::Tensor& prompt_mask, // [num_seqs, vocab_size]
const torch::Tensor& output_mask, // [num_seqs, vocab_size]
const torch::Tensor& repetition_penalties) { // [num_seqs]
TORCH_CHECK(logits.is_contiguous());
TORCH_CHECK(prompt_mask.is_contiguous());
TORCH_CHECK(output_mask.is_contiguous());
TORCH_CHECK(repetition_penalties.is_contiguous());
torch::stable::Tensor& logits, // [num_seqs, vocab_size], in-place
const torch::stable::Tensor& prompt_mask, // [num_seqs, vocab_size]
const torch::stable::Tensor& output_mask, // [num_seqs, vocab_size]
const torch::stable::Tensor& repetition_penalties) { // [num_seqs]
STD_TORCH_CHECK(logits.is_contiguous());
STD_TORCH_CHECK(prompt_mask.is_contiguous());
STD_TORCH_CHECK(output_mask.is_contiguous());
STD_TORCH_CHECK(repetition_penalties.is_contiguous());
int vocab_size = logits.size(-1);
int num_seqs = logits.size(0);
@@ -635,7 +633,7 @@ void apply_repetition_penalties_(
// Get number of SMs on the current device
int sms = 0;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
logits.get_device());
logits.get_device_index());
// Compute tile_num and tile_size
int tile_num =
@@ -645,27 +643,29 @@ void apply_repetition_penalties_(
// Each block handles one sequence and a tile of vocab
dim3 grid(num_seqs, tile_num);
dim3 block(std::min(tile_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(logits));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
const torch::stable::accelerator::DeviceGuard device_guard(
logits.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
logits.scalar_type(), "apply_repetition_penalties_kernel", [&] {
vllm::apply_repetition_penalties_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
logits.data_ptr<scalar_t>(), prompt_mask.data_ptr<bool>(),
output_mask.data_ptr<bool>(),
repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
tile_size);
logits.mutable_data_ptr<scalar_t>(),
prompt_mask.const_data_ptr<bool>(),
output_mask.const_data_ptr<bool>(),
repetition_penalties.const_data_ptr<scalar_t>(), num_seqs,
vocab_size, tile_size);
});
}
// Host-side launcher for vllm::topKPerRowDecode on the current CUDA stream.
//
// Arguments (as consumed by this function):
//   logits  - read as float; size(1) is taken as the number of columns.
//   next_n  - narrowed to int and forwarded to the kernel.
//   seqLens - read as int and forwarded to the kernel.
//   indices - written as int by the final launch on every path.
//   numRows - used as the grid size (x dimension) of every launch.
//   stride0, stride1, topK - narrowed to int and forwarded to the kernel.
//
// Three launch paths are visible below: two single-launch paths and a
// two-step path for inputs with at least kSplitWorkThreshold columns.
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
const torch::Tensor& seqLens, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK) {
void top_k_per_row_decode(const torch::stable::Tensor& logits, int64_t next_n,
const torch::stable::Tensor& seqLens,
torch::stable::Tensor& indices, int64_t numRows,
int64_t stride0, int64_t stride1, int64_t topK) {
// NOTE(review): kSortingAlgorithmThreshold presumably gates the first
// (insertion sort) branch; its condition is not visible in this excerpt.
constexpr int kSortingAlgorithmThreshold = 12288;
// Column count at or above which the work for one row is split across
// several blocks (see the final else branch).
constexpr int kSplitWorkThreshold = 200 * 1000;
constexpr int kNumThreadsPerBlock = 512;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const cudaStream_t stream = get_current_cuda_stream();
const auto numColumns = logits.size(1);
// True if seqLens is 2D (B, next_n): each logit row has its own pre-computed
@@ -677,73 +677,76 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
// Use insertion sort
vllm::topKPerRowDecode<kNumThreadsPerBlock, false>
<<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), static_cast<int>(stride0),
logits.const_data_ptr<float>(), seqLens.const_data_ptr<int>(),
indices.mutable_data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK),
static_cast<int>(next_n), seqLensIs2D);
} else if (numColumns < kSplitWorkThreshold) {
// From this threshold, use radix sort instead
vllm::topKPerRowDecode<kNumThreadsPerBlock, true>
<<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), static_cast<int>(stride0),
logits.const_data_ptr<float>(), seqLens.const_data_ptr<int>(),
indices.mutable_data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK),
static_cast<int>(next_n), seqLensIs2D);
} else {
// Long sequences are run in two steps
// Number of blocks launched per row in the first step (grid y dimension).
constexpr auto multipleBlocksPerRowConfig = 10;
const auto outIndicesAux =
torch::empty({numRows, multipleBlocksPerRowConfig, topK},
torch::dtype(torch::kInt32).device(logits.device()));
const auto outLogitsAux =
torch::empty({numRows, multipleBlocksPerRowConfig, topK},
torch::dtype(torch::kFloat).device(logits.device()));
// Scratch buffers of shape [numRows, multipleBlocksPerRowConfig, topK]
// allocated on the same device as logits: the first launch writes them and
// the second launch reads them.
const auto outIndicesAux = torch::stable::empty(
{numRows, multipleBlocksPerRowConfig, topK},
torch::headeronly::ScalarType::Int, std::nullopt, logits.device());
const auto outLogitsAux = torch::stable::empty(
{numRows, multipleBlocksPerRowConfig, topK},
torch::headeronly::ScalarType::Float, std::nullopt, logits.device());
// Step 1: a (numRows x multipleBlocksPerRowConfig) grid writes into the
// scratch buffers; this launch requests twice the dynamic shared memory of
// the single-launch paths.
vllm::topKPerRowDecode<kNumThreadsPerBlock, true, true>
<<<dim3(numRows, multipleBlocksPerRowConfig), kNumThreadsPerBlock,
2 * topK * sizeof(int32_t), stream>>>(
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
outIndicesAux.data_ptr<int>(), static_cast<int>(stride0),
logits.const_data_ptr<float>(), seqLens.const_data_ptr<int>(),
outIndicesAux.mutable_data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK),
static_cast<int>(next_n), seqLensIs2D,
outLogitsAux.data_ptr<float>());
outLogitsAux.mutable_data_ptr<float>());
// Step 2: a grid of numRows blocks reads the scratch logits (row stride
// multipleBlocksPerRowConfig * topK, element stride 1) together with the
// scratch indices and writes the final result into indices.
constexpr int kNumThreadsPerBlockMerge = 1024;
vllm::topKPerRowDecode<kNumThreadsPerBlockMerge, true, false, true>
<<<numRows, kNumThreadsPerBlockMerge, topK * sizeof(int32_t), stream>>>(
outLogitsAux.data_ptr<float>(), seqLens.data_ptr<int>(),
indices.data_ptr<int>(), multipleBlocksPerRowConfig * topK, 1,
static_cast<int>(topK), static_cast<int>(next_n), seqLensIs2D,
nullptr, multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
outLogitsAux.const_data_ptr<float>(), seqLens.const_data_ptr<int>(),
indices.mutable_data_ptr<int>(), multipleBlocksPerRowConfig * topK,
1, static_cast<int>(topK), static_cast<int>(next_n), seqLensIs2D,
nullptr, multipleBlocksPerRowConfig,
outIndicesAux.const_data_ptr<int>());
}
}
void top_k_per_row_prefill(const torch::Tensor& logits,
const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK) {
void top_k_per_row_prefill(const torch::stable::Tensor& logits,
const torch::stable::Tensor& rowStarts,
const torch::stable::Tensor& rowEnds,
torch::stable::Tensor& indices, int64_t numRows,
int64_t stride0, int64_t stride1, int64_t topK) {
constexpr int kSortingAlgorithmThreshold = 12288;
constexpr int kNumThreadsPerBlock = 512;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const cudaStream_t stream = get_current_cuda_stream();
int numInsertionBlocks =
std::min(static_cast<int>(numRows), kSortingAlgorithmThreshold);
vllm::topKPerRowPrefill<kNumThreadsPerBlock, false>
<<<numInsertionBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
stream>>>(logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
static_cast<int>(stride0), static_cast<int>(stride1),
static_cast<int>(topK), 0);
stream>>>(logits.const_data_ptr<float>(),
rowStarts.const_data_ptr<int>(),
rowEnds.const_data_ptr<int>(),
indices.mutable_data_ptr<int>(), static_cast<int>(stride0),
static_cast<int>(stride1), static_cast<int>(topK), 0);
if (numRows > kSortingAlgorithmThreshold) {
int numRadixBlocks = numRows - kSortingAlgorithmThreshold;
vllm::topKPerRowPrefill<kNumThreadsPerBlock, true>
<<<numRadixBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
stream>>>(logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
static_cast<int>(stride0), static_cast<int>(stride1),
static_cast<int>(topK), kSortingAlgorithmThreshold);
stream>>>(
logits.const_data_ptr<float>(), rowStarts.const_data_ptr<int>(),
rowEnds.const_data_ptr<int>(), indices.mutable_data_ptr<int>(),
static_cast<int>(stride0), static_cast<int>(stride1),
static_cast<int>(topK), kSortingAlgorithmThreshold);
}
}
+78 -69
View File
@@ -1,49 +1,51 @@
// Persistent TopK kernel for DeepSeek V3 sparse attention indexer.
// See persistent_topk.cuh for kernel implementation.
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <algorithm>
#include "torch_utils.h"
#ifndef USE_ROCM
#include "persistent_topk.cuh"
#include "../persistent_topk.cuh"
#endif
namespace {
#ifndef USE_ROCM
// Launches the top-k selection for a compile-time TopK on the current CUDA
// stream.
//
// Arguments (as consumed by this function):
//   logits      - read as float; size(0) is the row count and stride(0) the
//                 row stride passed to the kernels.
//   lengths     - read as int32.
//   output      - written as int32.
//   workspace   - only used on the persistent-kernel path, where it must be a
//                 CUDA uint8 tensor with at least
//                 num_groups * sizeof(RadixRowState) bytes.
//   max_seq_len - narrowed to uint32 and stored in the kernel parameters.
//
// Paths visible in this excerpt:
//   1. num_rows > 32 and at least 128 KiB of opt-in shared memory per block:
//      vllm::FilteredTopKRaggedTransform.
//   2. Otherwise: P::persistent_topk_kernel, with a fallback to
//      FilteredTopKRaggedTransform when a cooperative launch would need more
//      CTAs than hw_resident_cap.
// Every CUDA status checked here is reported through STD_TORCH_CHECK.
template <int TopK>
void launch_persistent_topk(const torch::Tensor& logits,
const torch::Tensor& lengths, torch::Tensor& output,
torch::Tensor& workspace, int64_t max_seq_len) {
void launch_persistent_topk(const torch::stable::Tensor& logits,
const torch::stable::Tensor& lengths,
torch::stable::Tensor& output,
torch::stable::Tensor& workspace,
int64_t max_seq_len) {
namespace P = vllm::persistent;
const int64_t num_rows = logits.size(0);
const int64_t stride = logits.stride(0);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const cudaStream_t stream = get_current_cuda_stream();
// Device properties are queried once and cached in function-local statics
// (one copy per TopK instantiation).
// NOTE(review): the cached values come from whichever device the first call
// observed and the initialization is not synchronized - confirm this is
// acceptable for multi-GPU or multi-threaded callers.
static int num_sms = 0;
static int max_smem_per_block = 0;
if (num_sms == 0) {
int device;
cudaGetDevice(&device);
cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device);
cudaDeviceGetAttribute(&max_smem_per_block,
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
const cudaDeviceProp* device_prop = get_device_prop();
num_sms = device_prop->multiProcessorCount;
max_smem_per_block = device_prop->sharedMemPerBlockOptin;
}
// Path 1: more than 32 rows and at least 128 KiB of opt-in shared memory.
if (num_rows > 32 && max_smem_per_block >= 128 * 1024) {
cudaError_t status =
vllm::FilteredTopKRaggedTransform<float, int32_t, TopK>(
logits.data_ptr<float>(), output.data_ptr<int32_t>(),
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
logits.const_data_ptr<float>(), output.mutable_data_ptr<int32_t>(),
lengths.const_data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
static_cast<uint32_t>(TopK), static_cast<uint32_t>(stride), stream);
TORCH_CHECK(status == cudaSuccess,
"FilteredTopK failed: ", cudaGetErrorString(status));
STD_TORCH_CHECK(status == cudaSuccess,
"FilteredTopK failed: ", cudaGetErrorString(status));
} else {
// Path 2: the persistent kernel; the workspace is only validated here.
TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor");
TORCH_CHECK(workspace.dtype() == torch::kUInt8, "workspace must be uint8");
STD_TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor");
STD_TORCH_CHECK(
workspace.scalar_type() == torch::headeronly::ScalarType::Byte,
"workspace must be uint8");
int effective_max_smem;
if (num_rows <= 4) {
@@ -99,9 +101,9 @@ void launch_persistent_topk(const torch::Tensor& logits,
&occupancy, P::persistent_topk_kernel<TopK, 1>, P::kThreadsPerBlock,
smem_size);
}
TORCH_CHECK(occ_err == cudaSuccess,
"persistent_topk occupancy query failed: ",
cudaGetErrorString(occ_err));
STD_TORCH_CHECK(occ_err == cudaSuccess,
"persistent_topk occupancy query failed: ",
cudaGetErrorString(occ_err));
// Treat a reported occupancy below 1 as 1.
if (occupancy < 1) occupancy = 1;
// The cooperative spin-wait barrier only runs when at least one row hits
@@ -131,27 +133,29 @@ void launch_persistent_topk(const torch::Tensor& logits,
// If the cooperative launch wouldn't fit, fall back to FilteredTopK
// instead of deadlocking. Only relevant when needs_cooperative.
if (needs_cooperative && total_ctas > hw_resident_cap) {
TORCH_CHECK(max_smem_per_block >= 128 * 1024,
"persistent_topk would oversubscribe and the FilteredTopK "
"fallback requires >=128KB smem per block (have ",
max_smem_per_block, "). total_ctas=", total_ctas,
" > num_sms*occupancy=", hw_resident_cap, " (TopK=", TopK,
", vec_size=", vec_size, ", ctas_per_group=", ctas_per_group,
", smem=", smem_size, ").");
STD_TORCH_CHECK(
max_smem_per_block >= 128 * 1024,
"persistent_topk would oversubscribe and the FilteredTopK "
"fallback requires >=128KB smem per block (have ",
max_smem_per_block, "). total_ctas=", total_ctas,
" > num_sms*occupancy=", hw_resident_cap, " (TopK=", TopK,
", vec_size=", vec_size, ", ctas_per_group=", ctas_per_group,
", smem=", smem_size, ").");
cudaError_t status =
vllm::FilteredTopKRaggedTransform<float, int32_t, TopK>(
logits.data_ptr<float>(), output.data_ptr<int32_t>(),
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
static_cast<uint32_t>(TopK), static_cast<uint32_t>(stride),
stream);
TORCH_CHECK(status == cudaSuccess,
"FilteredTopK fallback failed: ", cudaGetErrorString(status));
logits.const_data_ptr<float>(),
output.mutable_data_ptr<int32_t>(),
lengths.const_data_ptr<int32_t>(),
static_cast<uint32_t>(num_rows), static_cast<uint32_t>(TopK),
static_cast<uint32_t>(stride), stream);
STD_TORCH_CHECK(status == cudaSuccess, "FilteredTopK fallback failed: ",
cudaGetErrorString(status));
return;
}
// The workspace must hold one RadixRowState per group.
size_t state_bytes = num_groups * sizeof(P::RadixRowState);
TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
"workspace too small, need ", state_bytes, " bytes");
STD_TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
"workspace too small, need ", state_bytes, " bytes");
// Zero the per-group RadixRowState region before launch.
//
@@ -179,22 +183,22 @@ void launch_persistent_topk(const torch::Tensor& logits,
// first red_release. cudaMemsetAsync is stream-ordered: the zero
// is globally visible before any CTA runs.
{
cudaError_t mz_err = cudaMemsetAsync(workspace.data_ptr<uint8_t>(), 0,
state_bytes, stream);
TORCH_CHECK(mz_err == cudaSuccess,
"row_states memset failed: ", cudaGetErrorString(mz_err));
cudaError_t mz_err = cudaMemsetAsync(
workspace.mutable_data_ptr<uint8_t>(), 0, state_bytes, stream);
STD_TORCH_CHECK(mz_err == cudaSuccess,
"row_states memset failed: ", cudaGetErrorString(mz_err));
}
// Kernel parameters; row_states aliases the uint8 workspace buffer.
P::PersistentTopKParams params;
params.input = logits.data_ptr<float>();
params.output = output.data_ptr<int32_t>();
params.lengths = lengths.data_ptr<int32_t>();
params.input = logits.const_data_ptr<float>();
params.output = output.mutable_data_ptr<int32_t>();
params.lengths = lengths.const_data_ptr<int32_t>();
params.num_rows = static_cast<uint32_t>(num_rows);
params.stride = static_cast<uint32_t>(stride);
params.top_k = static_cast<uint32_t>(TopK);
params.chunk_size = chunk_size;
params.row_states =
reinterpret_cast<P::RadixRowState*>(workspace.data_ptr<uint8_t>());
params.row_states = reinterpret_cast<P::RadixRowState*>(
workspace.mutable_data_ptr<uint8_t>());
params.ctas_per_group = ctas_per_group;
params.max_seq_len = static_cast<uint32_t>(max_seq_len);
@@ -203,8 +207,8 @@ void launch_persistent_topk(const torch::Tensor& logits,
auto kernel = &P::persistent_topk_kernel<TOPK_VAL, VS>; \
cudaError_t err = cudaFuncSetAttribute( \
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \
TORCH_CHECK(err == cudaSuccess, \
"Failed to set smem: ", cudaGetErrorString(err)); \
STD_TORCH_CHECK(err == cudaSuccess, \
"Failed to set smem: ", cudaGetErrorString(err)); \
kernel<<<total_ctas, P::kThreadsPerBlock, smem_size, stream>>>(params); \
} while (0)
@@ -219,37 +223,42 @@ void launch_persistent_topk(const torch::Tensor& logits,
}
// Surface any launch-time error from the kernel launch above.
cudaError_t err = cudaGetLastError();
TORCH_CHECK(err == cudaSuccess,
"persistent_topk failed: ", cudaGetErrorString(err));
STD_TORCH_CHECK(err == cudaSuccess,
"persistent_topk failed: ", cudaGetErrorString(err));
}
#endif
} // anonymous namespace
// Public entry point: validates its arguments and dispatches to
// launch_persistent_topk<k>. On ROCm builds it always fails with
// "persistent_topk is not supported on ROCm".
//
// Preconditions enforced below (each failure is reported via STD_TORCH_CHECK):
//   - logits, lengths and output are CUDA tensors.
//   - logits is float32 and 2D; output is int32 and 2D.
//   - lengths is int32, 1D or 2D, contiguous, with numel() == logits.size(0).
//   - output has shape [logits.size(0), k].
//   - k is 512, 1024 or 2048.
// workspace and max_seq_len are forwarded unchanged; workspace is not
// validated in the lines visible here.
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
void persistent_topk(const torch::stable::Tensor& logits,
const torch::stable::Tensor& lengths,
torch::stable::Tensor& output,
torch::stable::Tensor& workspace, int64_t k,
int64_t max_seq_len) {
#ifndef USE_ROCM
TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported");
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
"lengths must be 1D or 2D");
TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
TORCH_CHECK(output.dim() == 2, "output must be 2D");
// Device, dtype and rank checks.
STD_TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
STD_TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
STD_TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
STD_TORCH_CHECK(logits.scalar_type() == torch::headeronly::ScalarType::Float,
"Only float32 supported");
STD_TORCH_CHECK(lengths.scalar_type() == torch::headeronly::ScalarType::Int,
"lengths must be int32");
STD_TORCH_CHECK(output.scalar_type() == torch::headeronly::ScalarType::Int,
"output must be int32");
STD_TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
STD_TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
"lengths must be 1D or 2D");
STD_TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
STD_TORCH_CHECK(output.dim() == 2, "output must be 2D");
const int64_t num_rows = logits.size(0);
const int64_t stride = logits.stride(0);
TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
"output size mismatch");
TORCH_CHECK(k == 512 || k == 1024 || k == 2048,
"persistent_topk supports k=512, k=1024, or k=2048, got k=", k);
// Shape checks against the row count, then the supported values of k.
STD_TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
STD_TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
"output size mismatch");
STD_TORCH_CHECK(
k == 512 || k == 1024 || k == 2048,
"persistent_topk supports k=512, k=1024, or k=2048, got k=", k);
// k is a runtime value; select the matching compile-time instantiation.
if (k == 512) {
launch_persistent_topk<512>(logits, lengths, output, workspace,
@@ -262,6 +271,6 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
max_seq_len);
}
#else
TORCH_CHECK(false, "persistent_topk is not supported on ROCm");
STD_TORCH_CHECK(false, "persistent_topk is not supported on ROCm");
#endif
}
+62
View File
@@ -263,6 +263,20 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor");
#endif
// Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
ops.def(
"merge_attn_states("
" Tensor! output,"
" Tensor!? output_lse,"
" Tensor prefix_output,"
" Tensor prefix_lse,"
" Tensor suffix_output,"
" Tensor suffix_lse,"
" int!? prefill_tokens_with_context,"
" Tensor? output_scale=None) -> ()");
// Hadamard transforms
// conditionally compiled so impl registration is in source file
ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");
@@ -319,6 +333,26 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"bool is_neox, Tensor position_ids, "
"int forced_token_heads_per_warp=-1) -> ()");
// Apply repetition penalties to logits in-place.
ops.def(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
"Tensor output_mask, Tensor repetition_penalties) -> ()");
// Optimized top-k per row operations.
ops.def(
"top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
"Tensor! indices, int numRows, int stride0, "
"int stride1, int topK) -> ()");
ops.def(
"top_k_per_row_decode(Tensor logits, int next_n, "
"Tensor seq_lens, Tensor! indices, "
"int numRows, int stride0, int stride1, int topK) -> ()");
ops.def(
"persistent_topk(Tensor logits, Tensor lengths, Tensor! output, "
"Tensor workspace, int k, int max_seq_len) -> ()");
// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
@@ -422,6 +456,24 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"int type, SymInt row, SymInt tokens) -> Tensor");
ops.def("ggml_moe_get_block_size(int type) -> int");
// Mamba selective scan kernel
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states,"
"int null_block_id,"
"int block_size,"
"Tensor? block_idx_first_scheduled_token,"
"Tensor? block_idx_last_scheduled_token,"
"Tensor? initial_state_idx,"
"Tensor? cu_chunk_seqlen,"
"Tensor? last_chunk_indices) -> ()");
}
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
@@ -469,6 +521,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
// files (allspark_repack.cu and allspark_qgemm_w8a16.cu)
#endif
ops.impl("merge_attn_states", TORCH_BOX(&merge_attn_states));
// Layernorm kernels (shared CUDA/ROCm)
ops.impl("rms_norm", TORCH_BOX(&rms_norm));
ops.impl("fused_add_rms_norm", TORCH_BOX(&fused_add_rms_norm));
@@ -487,6 +541,13 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding));
ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope));
// Sampler kernels (shared CUDA/ROCm)
ops.impl("apply_repetition_penalties_",
TORCH_BOX(&apply_repetition_penalties_));
ops.impl("top_k_per_row_prefill", TORCH_BOX(&top_k_per_row_prefill));
ops.impl("top_k_per_row_decode", TORCH_BOX(&top_k_per_row_decode));
ops.impl("persistent_topk", TORCH_BOX(&persistent_topk));
// Activation kernels (shared CUDA/ROCm)
ops.impl("silu_and_mul", TORCH_BOX(&silu_and_mul));
ops.impl("mul_and_silu", TORCH_BOX(&mul_and_silu));
@@ -519,6 +580,7 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("ggml_mul_mat_a8", TORCH_BOX(&ggml_mul_mat_a8));
ops.impl("ggml_moe_a8", TORCH_BOX(&ggml_moe_a8));
ops.impl("ggml_moe_a8_vec", TORCH_BOX(&ggml_moe_a8_vec));
ops.impl("selective_scan_fwd", TORCH_BOX(&selective_scan_fwd));
}
// These capability-check functions take only primitive args (no tensors), so
-43
View File
@@ -54,13 +54,6 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void merge_attn_states(
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
const std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::Tensor>& output_scale = std::nullopt);
// rms_norm and fused_add_rms_norm declarations also exist in
// csrc/libtorch_stable/ops.h (torch::stable ABI for CUDA). They remain here
// because the CPU build still uses these torch::Tensor declarations.
@@ -76,26 +69,6 @@ torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::Tensor const& cos_sin_cache, int64_t q_head_padded, double eps,
int64_t cache_block_size);
void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask,
const torch::Tensor& output_mask,
const torch::Tensor& repetition_penalties);
void top_k_per_row_prefill(const torch::Tensor& logits,
const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK);
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
const torch::Tensor& seqLens, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK);
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
int64_t max_seq_len);
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
@@ -150,22 +123,6 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scales,
std::optional<torch::Tensor> const& azp);
void selective_scan_fwd(
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
const torch::Tensor& B, const torch::Tensor& C,
const std::optional<torch::Tensor>& D_,
const std::optional<torch::Tensor>& z_,
const std::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
const std::optional<torch::Tensor>& query_start_loc,
const std::optional<torch::Tensor>& cache_indices,
const std::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states, int64_t null_block_id, int64_t block_size,
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
const std::optional<torch::Tensor>& initial_state_idx,
const std::optional<torch::Tensor>& cu_chunk_seqlen,
const std::optional<torch::Tensor>& last_chunk_indices);
torch::Tensor dynamic_4bit_int_moe_cpu(
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
+9 -7
View File
@@ -126,10 +126,10 @@ struct RadixRowState {
// ============================================================================
struct PersistentTopKParams {
const float* __restrict__ input; // [num_rows, stride]
int32_t* __restrict__ output; // [num_rows, top_k]
int32_t* __restrict__ lengths; // [num_rows]
RadixRowState* row_states; // large path: per-group state
const float* __restrict__ input; // [num_rows, stride]
int32_t* __restrict__ output; // [num_rows, top_k]
const int32_t* __restrict__ lengths; // [num_rows]
RadixRowState* row_states; // large path: per-group state
uint32_t num_rows;
uint32_t stride;
uint32_t top_k; // actual k value for output stride
@@ -1269,9 +1269,11 @@ constexpr int ComputeFilteredTopKVecSize(uint32_t max_len) {
}
template <typename DType, typename IdType, uint32_t MAX_K = 2048>
cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices,
IdType* lengths, uint32_t num_rows,
uint32_t top_k_val, uint32_t max_len,
cudaError_t FilteredTopKRaggedTransform(const DType* input,
IdType* output_indices,
const IdType* lengths,
uint32_t num_rows, uint32_t top_k_val,
uint32_t max_len,
cudaStream_t stream = 0) {
constexpr size_t smem_size = FILTERED_TOPK_SMEM_DYNAMIC;
constexpr int MAX_VEC = 16 / sizeof(DType);
-59
View File
@@ -62,21 +62,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
ops.def(
"merge_attn_states("
" Tensor! output,"
" Tensor!? output_lse,"
" Tensor prefix_output,"
" Tensor prefix_lse,"
" Tensor suffix_output,"
" Tensor suffix_lse,"
" int!? prefill_tokens_with_context,"
" Tensor? output_scale=None) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
// Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
ops.def(
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
@@ -105,31 +90,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);
// Apply repetition penalties to logits in-place
ops.def(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
"Tensor output_mask, Tensor repetition_penalties) -> ()");
ops.impl("apply_repetition_penalties_", torch::kCUDA,
&apply_repetition_penalties_);
// Optimized top-k per row operation
ops.def(
"top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
"Tensor! indices, int numRows, int stride0, "
"int stride1, int topK) -> ()");
ops.impl("top_k_per_row_prefill", torch::kCUDA, &top_k_per_row_prefill);
ops.def(
"top_k_per_row_decode(Tensor logits, int next_n, "
"Tensor seq_lens, Tensor! indices, "
"int numRows, int stride0, int stride1, int topK) -> ()");
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
ops.def(
"persistent_topk(Tensor logits, Tensor lengths, Tensor! output, "
"Tensor workspace, int k, int max_seq_len) -> ()");
ops.impl("persistent_topk", torch::kCUDA, &persistent_topk);
// Quantization ops
#ifndef USE_ROCM
@@ -230,25 +190,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif
// Mamba selective scan kernel
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states,"
"int null_block_id,"
"int block_size,"
"Tensor? block_idx_first_scheduled_token,"
"Tensor? block_idx_last_scheduled_token,"
"Tensor? initial_state_idx,"
"Tensor? cu_chunk_seqlen,"
"Tensor? last_chunk_indices) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
#ifndef USE_ROCM
ops.def(
"minimax_allreduce_rms("
+1
View File
@@ -141,6 +141,7 @@ bbc5b7ede = "bbc5b7ede"
NOOPs = "NOOPs"
nin_shortcut = "nin_shortcut"
cudaDevAttrMaxSharedMemoryPerBlockOptin = "cudaDevAttrMaxSharedMemoryPerBlockOptin"
sharedMemPerBlockOptin = "sharedMemPerBlockOptin"
depthwise_seperable_out_channel = "depthwise_seperable_out_channel"
pard_token = "pard_token"