mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[8/n] Migrate merge_attn_states, mamba, sampler to torch stable ABI (continued) (#43361)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com> Signed-off-by: Chris Leonard <chleonar@redhat.com> Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com> Co-authored-by: Shengqi Chen <harry-chen@outlook.com>
This commit is contained in:
+5
-5
@@ -305,14 +305,10 @@ endif()
|
|||||||
#
|
#
|
||||||
|
|
||||||
set(VLLM_EXT_SRC
|
set(VLLM_EXT_SRC
|
||||||
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
|
||||||
"csrc/cache_kernels.cu"
|
"csrc/cache_kernels.cu"
|
||||||
"csrc/cache_kernels_fused.cu"
|
"csrc/cache_kernels_fused.cu"
|
||||||
"csrc/attention/paged_attention_v1.cu"
|
"csrc/attention/paged_attention_v1.cu"
|
||||||
"csrc/attention/paged_attention_v2.cu"
|
"csrc/attention/paged_attention_v2.cu"
|
||||||
"csrc/attention/merge_attn_states.cu"
|
|
||||||
"csrc/sampler.cu"
|
|
||||||
"csrc/topk.cu"
|
|
||||||
"csrc/cuda_view.cu"
|
"csrc/cuda_view.cu"
|
||||||
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
|
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
|
||||||
"csrc/quantization/activation_kernels.cu"
|
"csrc/quantization/activation_kernels.cu"
|
||||||
@@ -633,7 +629,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
|||||||
"csrc/libtorch_stable/fused_qknorm_rope_kernel.cu"
|
"csrc/libtorch_stable/fused_qknorm_rope_kernel.cu"
|
||||||
"csrc/libtorch_stable/layernorm_kernels.cu"
|
"csrc/libtorch_stable/layernorm_kernels.cu"
|
||||||
"csrc/libtorch_stable/layernorm_quant_kernels.cu"
|
"csrc/libtorch_stable/layernorm_quant_kernels.cu"
|
||||||
"csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu")
|
"csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
|
||||||
|
"csrc/libtorch_stable/attention/merge_attn_states.cu"
|
||||||
|
"csrc/libtorch_stable/sampler.cu"
|
||||||
|
"csrc/libtorch_stable/topk.cu"
|
||||||
|
"csrc/libtorch_stable/mamba/selective_scan_fwd.cu")
|
||||||
|
|
||||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
list(APPEND VLLM_STABLE_EXT_SRC
|
list(APPEND VLLM_STABLE_EXT_SRC
|
||||||
|
|||||||
+56
-49
@@ -1,14 +1,14 @@
|
|||||||
#include <optional>
|
#include <optional>
|
||||||
#include <torch/all.h>
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
#include "attention_dtypes.h"
|
#include "../torch_utils.h"
|
||||||
#include "attention_utils.cuh"
|
|
||||||
#include "../quantization/w8a8/fp8/common.cuh"
|
|
||||||
#include "../dispatch_utils.h"
|
#include "../dispatch_utils.h"
|
||||||
|
#include <torch/headeronly/core/ScalarType.h>
|
||||||
|
|
||||||
|
#include "../../attention/attention_dtypes.h"
|
||||||
|
#include "../../attention/attention_utils.cuh"
|
||||||
|
#include "../../quantization/w8a8/fp8/common.cuh"
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
@@ -196,17 +196,17 @@ __global__ void merge_attn_states_kernel(
|
|||||||
// The following macro is used to dispatch the conversion function based on
|
// The following macro is used to dispatch the conversion function based on
|
||||||
// the output data type. The FN is a macro that calls a function with
|
// the output data type. The FN is a macro that calls a function with
|
||||||
// template<typename scalar_t>.
|
// template<typename scalar_t>.
|
||||||
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
|
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
|
||||||
{ \
|
{ \
|
||||||
if (scalar_dtype == at::ScalarType::Float) { \
|
if (scalar_dtype == torch::headeronly::ScalarType::Float) { \
|
||||||
fn(float); \
|
fn(float); \
|
||||||
} else if (scalar_dtype == at::ScalarType::Half) { \
|
} else if (scalar_dtype == torch::headeronly::ScalarType::Half) { \
|
||||||
fn(uint16_t); \
|
fn(uint16_t); \
|
||||||
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
|
} else if (scalar_dtype == torch::headeronly::ScalarType::BFloat16) { \
|
||||||
fn(__nv_bfloat16); \
|
fn(__nv_bfloat16); \
|
||||||
} else { \
|
} else { \
|
||||||
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
|
STD_TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, output_t, NUM_THREADS, \
|
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, output_t, NUM_THREADS, \
|
||||||
@@ -245,11 +245,14 @@ __global__ void merge_attn_states_kernel(
|
|||||||
*/
|
*/
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
void merge_attn_states_launcher(
|
void merge_attn_states_launcher(
|
||||||
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
|
torch::stable::Tensor& output,
|
||||||
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
|
std::optional<torch::stable::Tensor> output_lse,
|
||||||
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
|
const torch::stable::Tensor& prefix_output,
|
||||||
|
const torch::stable::Tensor& prefix_lse,
|
||||||
|
const torch::stable::Tensor& suffix_output,
|
||||||
|
const torch::stable::Tensor& suffix_lse,
|
||||||
const std::optional<int64_t> prefill_tokens_with_context,
|
const std::optional<int64_t> prefill_tokens_with_context,
|
||||||
const std::optional<torch::Tensor>& output_scale) {
|
const std::optional<torch::stable::Tensor>& output_scale) {
|
||||||
constexpr uint NUM_THREADS = 128;
|
constexpr uint NUM_THREADS = 128;
|
||||||
const uint num_tokens = output.size(0);
|
const uint num_tokens = output.size(0);
|
||||||
const uint num_heads = output.size(1);
|
const uint num_heads = output.size(1);
|
||||||
@@ -258,23 +261,23 @@ void merge_attn_states_launcher(
|
|||||||
const uint output_head_stride = output.stride(1);
|
const uint output_head_stride = output.stride(1);
|
||||||
// Thread mapping is based on input BF16 pack_size
|
// Thread mapping is based on input BF16 pack_size
|
||||||
const uint pack_size = 16 / sizeof(scalar_t);
|
const uint pack_size = 16 / sizeof(scalar_t);
|
||||||
TORCH_CHECK(head_size % pack_size == 0,
|
STD_TORCH_CHECK(head_size % pack_size == 0,
|
||||||
"headsize must be multiple of pack_size:", pack_size);
|
"headsize must be multiple of pack_size:", pack_size);
|
||||||
|
|
||||||
const uint prefix_num_tokens =
|
const uint prefix_num_tokens =
|
||||||
prefill_tokens_with_context.has_value()
|
prefill_tokens_with_context.has_value()
|
||||||
? static_cast<uint>(prefill_tokens_with_context.value())
|
? static_cast<uint>(prefill_tokens_with_context.value())
|
||||||
: num_tokens;
|
: num_tokens;
|
||||||
TORCH_CHECK(prefix_num_tokens <= num_tokens,
|
STD_TORCH_CHECK(prefix_num_tokens <= num_tokens,
|
||||||
"prefix_num_tokens must be <= num_tokens");
|
"prefix_num_tokens must be <= num_tokens");
|
||||||
|
|
||||||
float* output_lse_ptr = nullptr;
|
float* output_lse_ptr = nullptr;
|
||||||
if (output_lse.has_value()) {
|
if (output_lse.has_value()) {
|
||||||
output_lse_ptr = output_lse.value().data_ptr<float>();
|
output_lse_ptr = output_lse.value().mutable_data_ptr<float>();
|
||||||
}
|
}
|
||||||
float* output_scale_ptr = nullptr;
|
float* output_scale_ptr = nullptr;
|
||||||
if (output_scale.has_value()) {
|
if (output_scale.has_value()) {
|
||||||
output_scale_ptr = output_scale.value().data_ptr<float>();
|
output_scale_ptr = output_scale.value().mutable_data_ptr<float>();
|
||||||
}
|
}
|
||||||
// Process one pack elements per thread. for float, the
|
// Process one pack elements per thread. for float, the
|
||||||
// pack_size is 4 for half/bf16, the pack_size is 8.
|
// pack_size is 4 for half/bf16, the pack_size is 8.
|
||||||
@@ -284,14 +287,15 @@ void merge_attn_states_launcher(
|
|||||||
dim3 block(NUM_THREADS);
|
dim3 block(NUM_THREADS);
|
||||||
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
|
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
|
||||||
|
|
||||||
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
|
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||||
auto stream = at::cuda::getCurrentCUDAStream();
|
prefix_output.get_device_index());
|
||||||
|
auto stream = get_current_cuda_stream();
|
||||||
|
|
||||||
if (output_scale.has_value()) {
|
if (output_scale.has_value()) {
|
||||||
// FP8 output path - dispatch on output FP8 type
|
// FP8 output path - dispatch on output FP8 type
|
||||||
VLLM_DISPATCH_FP8_TYPES(output.scalar_type(), "merge_attn_states_fp8", [&] {
|
VLLM_STABLE_DISPATCH_FP8_TYPES(
|
||||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true);
|
output.scalar_type(), "merge_attn_states_fp8",
|
||||||
});
|
[&] { LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true); });
|
||||||
} else {
|
} else {
|
||||||
// Original BF16/FP16/FP32 output path
|
// Original BF16/FP16/FP32 output path
|
||||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, scalar_t, NUM_THREADS, false);
|
LAUNCH_MERGE_ATTN_STATES(scalar_t, scalar_t, NUM_THREADS, false);
|
||||||
@@ -305,26 +309,29 @@ void merge_attn_states_launcher(
|
|||||||
suffix_lse, prefill_tokens_with_context, output_scale); \
|
suffix_lse, prefill_tokens_with_context, output_scale); \
|
||||||
}
|
}
|
||||||
|
|
||||||
void merge_attn_states(torch::Tensor& output,
|
void merge_attn_states(
|
||||||
std::optional<torch::Tensor> output_lse,
|
torch::stable::Tensor& output,
|
||||||
const torch::Tensor& prefix_output,
|
std::optional<torch::stable::Tensor> output_lse,
|
||||||
const torch::Tensor& prefix_lse,
|
const torch::stable::Tensor& prefix_output,
|
||||||
const torch::Tensor& suffix_output,
|
const torch::stable::Tensor& prefix_lse,
|
||||||
const torch::Tensor& suffix_lse,
|
const torch::stable::Tensor& suffix_output,
|
||||||
std::optional<int64_t> prefill_tokens_with_context,
|
const torch::stable::Tensor& suffix_lse,
|
||||||
const std::optional<torch::Tensor>& output_scale) {
|
const std::optional<int64_t> prefill_tokens_with_context,
|
||||||
|
const std::optional<torch::stable::Tensor>& output_scale) {
|
||||||
if (output_scale.has_value()) {
|
if (output_scale.has_value()) {
|
||||||
TORCH_CHECK(output.scalar_type() == at::ScalarType::Float8_e4m3fn ||
|
STD_TORCH_CHECK(
|
||||||
output.scalar_type() == at::ScalarType::Float8_e4m3fnuz,
|
output.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn ||
|
||||||
"output must be FP8 when output_scale is provided, got: ",
|
output.scalar_type() ==
|
||||||
output.scalar_type());
|
torch::headeronly::ScalarType::Float8_e4m3fnuz,
|
||||||
|
"output must be FP8 when output_scale is provided, got: ",
|
||||||
|
output.scalar_type());
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(output.scalar_type() == prefix_output.scalar_type(),
|
STD_TORCH_CHECK(
|
||||||
"output dtype (", output.scalar_type(),
|
output.scalar_type() == prefix_output.scalar_type(), "output dtype (",
|
||||||
") must match prefix_output dtype (",
|
output.scalar_type(), ") must match prefix_output dtype (",
|
||||||
prefix_output.scalar_type(), ") when output_scale is not set");
|
prefix_output.scalar_type(), ") when output_scale is not set");
|
||||||
}
|
}
|
||||||
// Always dispatch on prefix_output (input) dtype
|
// Always dispatch on prefix_output (input) dtype
|
||||||
DISPATCH_BY_SCALAR_DTYPE(prefix_output.dtype(),
|
DISPATCH_BY_SCALAR_DTYPE(prefix_output.scalar_type(),
|
||||||
CALL_MERGE_ATTN_STATES_LAUNCHER);
|
CALL_MERGE_ATTN_STATES_LAUNCHER);
|
||||||
}
|
}
|
||||||
@@ -12,6 +12,9 @@
|
|||||||
#include <hip/hip_bf16.h>
|
#include <hip/hip_bf16.h>
|
||||||
#endif
|
#endif
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
#include <torch/headeronly/util/Half.h>
|
||||||
|
#include <torch/headeronly/util/BFloat16.h>
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
struct SSMParamsBase {
|
struct SSMParamsBase {
|
||||||
@@ -159,8 +162,8 @@ struct Converter{
|
|||||||
};
|
};
|
||||||
|
|
||||||
template<int N>
|
template<int N>
|
||||||
struct Converter<at::Half, N>{
|
struct Converter<torch::headeronly::Half, N>{
|
||||||
static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
|
static inline __device__ void to_float(const torch::headeronly::Half (&src)[N], float (&dst)[N]) {
|
||||||
static_assert(N % 2 == 0);
|
static_assert(N % 2 == 0);
|
||||||
auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
|
auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
|
||||||
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
||||||
@@ -171,8 +174,8 @@ struct Converter<at::Half, N>{
|
|||||||
|
|
||||||
#if __CUDA_ARCH__ >= 800
|
#if __CUDA_ARCH__ >= 800
|
||||||
template<int N>
|
template<int N>
|
||||||
struct Converter<at::BFloat16, N>{
|
struct Converter<torch::headeronly::BFloat16, N>{
|
||||||
static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
|
static inline __device__ void to_float(const torch::headeronly::BFloat16 (&src)[N], float (&dst)[N]) {
|
||||||
static_assert(N % 2 == 0);
|
static_assert(N % 2 == 0);
|
||||||
auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
|
auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
|
||||||
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
||||||
+102
-111
@@ -1,18 +1,9 @@
|
|||||||
// clang-format off
|
// clang-format off
|
||||||
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh
|
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh
|
||||||
#include <torch/all.h>
|
#include "../torch_utils.h"
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <torch/csrc/stable/macros.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
#include "selective_scan.h"
|
#include "selective_scan.h"
|
||||||
|
|
||||||
#include <c10/util/BFloat16.h>
|
|
||||||
#include <c10/util/Half.h>
|
|
||||||
#ifdef USE_ROCM
|
|
||||||
#include <c10/hip/HIPException.h> // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK
|
|
||||||
#else
|
|
||||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
#include <cub/block/block_load.cuh>
|
#include <cub/block/block_load.cuh>
|
||||||
#include <cub/block/block_store.cuh>
|
#include <cub/block/block_store.cuh>
|
||||||
@@ -416,15 +407,15 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
|||||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||||
if (kSmemSize >= 48 * 1024) {
|
if (kSmemSize >= 48 * 1024) {
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
C10_HIP_CHECK(hipFuncSetAttribute(
|
STD_CUDA_CHECK(hipFuncSetAttribute(
|
||||||
reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||||
#else
|
#else
|
||||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
STD_CUDA_CHECK(cudaFuncSetAttribute(
|
||||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
STD_CUDA_KERNEL_LAUNCH_CHECK();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -462,46 +453,46 @@ void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template void selective_scan_fwd_cuda<at::BFloat16, float, at::BFloat16>(SSMParamsBase ¶ms, cudaStream_t stream);
|
template void selective_scan_fwd_cuda<torch::headeronly::BFloat16, float, torch::headeronly::BFloat16>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||||
template void selective_scan_fwd_cuda<at::BFloat16, float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
template void selective_scan_fwd_cuda<torch::headeronly::BFloat16, float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||||
template void selective_scan_fwd_cuda<at::Half, float, at::Half>(SSMParamsBase ¶ms, cudaStream_t stream);
|
template void selective_scan_fwd_cuda<torch::headeronly::Half, float, torch::headeronly::Half>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||||
template void selective_scan_fwd_cuda<at::Half, float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
template void selective_scan_fwd_cuda<torch::headeronly::Half, float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||||
template void selective_scan_fwd_cuda<float, float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
template void selective_scan_fwd_cuda<float, float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||||
|
|
||||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
#define CHECK_SHAPE(x, ...) STD_TORCH_CHECK(x.sizes().equals(torch::headeronly::IntHeaderOnlyArrayRef({__VA_ARGS__})), #x " must have shape (" #__VA_ARGS__ ")")
|
||||||
|
|
||||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, STYPE, NAME, ...) \
|
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, STYPE, NAME, ...) \
|
||||||
if (ITYPE == at::ScalarType::Half) { \
|
if (ITYPE == torch::headeronly::ScalarType::Half) { \
|
||||||
using input_t = at::Half; \
|
using input_t = torch::headeronly::Half; \
|
||||||
using weight_t = float; \
|
using weight_t = float; \
|
||||||
if (STYPE == at::ScalarType::Half) { \
|
if (STYPE == torch::headeronly::ScalarType::Half) { \
|
||||||
using state_t = at::Half; \
|
using state_t = torch::headeronly::Half; \
|
||||||
__VA_ARGS__(); \
|
__VA_ARGS__(); \
|
||||||
} else if (STYPE == at::ScalarType::Float) { \
|
} else if (STYPE == torch::headeronly::ScalarType::Float) { \
|
||||||
using state_t = float; \
|
using state_t = float; \
|
||||||
__VA_ARGS__(); \
|
__VA_ARGS__(); \
|
||||||
} else { \
|
} else { \
|
||||||
AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
|
STD_TORCH_CHECK(false, #NAME " not implemented for state type '", STYPE, "'"); \
|
||||||
} \
|
} \
|
||||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
} else if (ITYPE == torch::headeronly::ScalarType::BFloat16) { \
|
||||||
using input_t = at::BFloat16; \
|
using input_t = torch::headeronly::BFloat16; \
|
||||||
using weight_t = float; \
|
using weight_t = float; \
|
||||||
if (STYPE == at::ScalarType::BFloat16) { \
|
if (STYPE == torch::headeronly::ScalarType::BFloat16) { \
|
||||||
using state_t = at::BFloat16; \
|
using state_t = torch::headeronly::BFloat16; \
|
||||||
__VA_ARGS__(); \
|
__VA_ARGS__(); \
|
||||||
} else if (STYPE == at::ScalarType::Float) { \
|
} else if (STYPE == torch::headeronly::ScalarType::Float) { \
|
||||||
using state_t = float; \
|
using state_t = float; \
|
||||||
__VA_ARGS__(); \
|
__VA_ARGS__(); \
|
||||||
} else { \
|
} else { \
|
||||||
AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
|
STD_TORCH_CHECK(false, #NAME " not implemented for state type '", STYPE, "'"); \
|
||||||
} \
|
} \
|
||||||
} else if (ITYPE == at::ScalarType::Float) { \
|
} else if (ITYPE == torch::headeronly::ScalarType::Float) { \
|
||||||
using input_t = float; \
|
using input_t = float; \
|
||||||
using weight_t = float; \
|
using weight_t = float; \
|
||||||
using state_t = float; \
|
using state_t = float; \
|
||||||
__VA_ARGS__(); \
|
__VA_ARGS__(); \
|
||||||
} else { \
|
} else { \
|
||||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
STD_TORCH_CHECK(false, #NAME " not implemented for input type '", ITYPE, "'"); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -518,30 +509,30 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
|||||||
const bool is_variable_B,
|
const bool is_variable_B,
|
||||||
const bool is_variable_C,
|
const bool is_variable_C,
|
||||||
// device pointers
|
// device pointers
|
||||||
const torch::Tensor u,
|
const torch::stable::Tensor u,
|
||||||
const torch::Tensor delta,
|
const torch::stable::Tensor delta,
|
||||||
const torch::Tensor A,
|
const torch::stable::Tensor A,
|
||||||
const torch::Tensor B,
|
const torch::stable::Tensor B,
|
||||||
const torch::Tensor C,
|
const torch::stable::Tensor C,
|
||||||
const torch::Tensor out,
|
const torch::stable::Tensor out,
|
||||||
const torch::Tensor z,
|
const torch::stable::Tensor z,
|
||||||
const torch::Tensor out_z,
|
const torch::stable::Tensor out_z,
|
||||||
const std::optional<at::Tensor>& D,
|
const std::optional<torch::stable::Tensor>& D,
|
||||||
const std::optional<at::Tensor>& delta_bias,
|
const std::optional<torch::stable::Tensor>& delta_bias,
|
||||||
const torch::Tensor ssm_states,
|
const torch::stable::Tensor ssm_states,
|
||||||
bool has_z,
|
bool has_z,
|
||||||
bool delta_softplus,
|
bool delta_softplus,
|
||||||
const std::optional<at::Tensor>& query_start_loc,
|
const std::optional<torch::stable::Tensor>& query_start_loc,
|
||||||
const std::optional<at::Tensor>& cache_indices,
|
const std::optional<torch::stable::Tensor>& cache_indices,
|
||||||
const std::optional<at::Tensor>& has_initial_state,
|
const std::optional<torch::stable::Tensor>& has_initial_state,
|
||||||
bool varlen,
|
bool varlen,
|
||||||
int64_t null_block_id,
|
int64_t null_block_id,
|
||||||
int64_t block_size,
|
int64_t block_size,
|
||||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
const std::optional<torch::stable::Tensor> &block_idx_first_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
const std::optional<torch::stable::Tensor> &block_idx_last_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &initial_state_idx,
|
const std::optional<torch::stable::Tensor> &initial_state_idx,
|
||||||
const std::optional<torch::Tensor> &cu_chunk_seqlen,
|
const std::optional<torch::stable::Tensor> &cu_chunk_seqlen,
|
||||||
const std::optional<torch::Tensor> &last_chunk_indices) {
|
const std::optional<torch::stable::Tensor> &last_chunk_indices) {
|
||||||
|
|
||||||
// Reset the parameters
|
// Reset the parameters
|
||||||
memset(¶ms, 0, sizeof(params));
|
memset(¶ms, 0, sizeof(params));
|
||||||
@@ -654,45 +645,45 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
void selective_scan_fwd(const torch::stable::Tensor &u, const torch::stable::Tensor &delta,
|
||||||
const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
|
const torch::stable::Tensor &A, const torch::stable::Tensor &B, const torch::stable::Tensor &C,
|
||||||
const std::optional<torch::Tensor> &D_,
|
const std::optional<torch::stable::Tensor> &D_,
|
||||||
const std::optional<torch::Tensor> &z_,
|
const std::optional<torch::stable::Tensor> &z_,
|
||||||
const std::optional<torch::Tensor> &delta_bias_,
|
const std::optional<torch::stable::Tensor> &delta_bias_,
|
||||||
bool delta_softplus,
|
bool delta_softplus,
|
||||||
const std::optional<torch::Tensor> &query_start_loc,
|
const std::optional<torch::stable::Tensor> &query_start_loc,
|
||||||
const std::optional<torch::Tensor> &cache_indices,
|
const std::optional<torch::stable::Tensor> &cache_indices,
|
||||||
const std::optional<torch::Tensor> &has_initial_state,
|
const std::optional<torch::stable::Tensor> &has_initial_state,
|
||||||
const torch::Tensor &ssm_states,
|
const torch::stable::Tensor &ssm_states,
|
||||||
// used to identify padding entries if cache_indices provided
|
// used to identify padding entries if cache_indices provided
|
||||||
// in case of padding, the kernel will return early
|
// in case of padding, the kernel will return early
|
||||||
int64_t null_block_id,
|
int64_t null_block_id,
|
||||||
int64_t block_size,
|
int64_t block_size,
|
||||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
const std::optional<torch::stable::Tensor> &block_idx_first_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
const std::optional<torch::stable::Tensor> &block_idx_last_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &initial_state_idx,
|
const std::optional<torch::stable::Tensor> &initial_state_idx,
|
||||||
const std::optional<torch::Tensor> &cu_chunk_seqlen,
|
const std::optional<torch::stable::Tensor> &cu_chunk_seqlen,
|
||||||
const std::optional<torch::Tensor> &last_chunk_indices) {
|
const std::optional<torch::stable::Tensor> &last_chunk_indices) {
|
||||||
auto input_type = u.scalar_type();
|
auto input_type = u.scalar_type();
|
||||||
auto weight_type = A.scalar_type();
|
auto weight_type = A.scalar_type();
|
||||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
STD_TORCH_CHECK(input_type == torch::headeronly::ScalarType::Float || input_type == torch::headeronly::ScalarType::Half || input_type == torch::headeronly::ScalarType::BFloat16);
|
||||||
TORCH_CHECK(weight_type == at::ScalarType::Float);
|
STD_TORCH_CHECK(weight_type == torch::headeronly::ScalarType::Float);
|
||||||
|
|
||||||
const bool is_variable_B = B.dim() >= 3;
|
const bool is_variable_B = B.dim() >= 3;
|
||||||
const bool is_variable_C = C.dim() >= 3;
|
const bool is_variable_C = C.dim() >= 3;
|
||||||
|
|
||||||
TORCH_CHECK(delta.scalar_type() == input_type);
|
STD_TORCH_CHECK(delta.scalar_type() == input_type);
|
||||||
TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
|
STD_TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
|
||||||
TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
|
STD_TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
|
||||||
|
|
||||||
TORCH_CHECK(u.is_cuda());
|
STD_TORCH_CHECK(u.is_cuda());
|
||||||
TORCH_CHECK(delta.is_cuda());
|
STD_TORCH_CHECK(delta.is_cuda());
|
||||||
TORCH_CHECK(A.is_cuda());
|
STD_TORCH_CHECK(A.is_cuda());
|
||||||
TORCH_CHECK(B.is_cuda());
|
STD_TORCH_CHECK(B.is_cuda());
|
||||||
TORCH_CHECK(C.is_cuda());
|
STD_TORCH_CHECK(C.is_cuda());
|
||||||
|
|
||||||
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
STD_TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
||||||
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
STD_TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
||||||
|
|
||||||
const auto sizes = u.sizes();
|
const auto sizes = u.sizes();
|
||||||
const bool varlen = query_start_loc.has_value();
|
const bool varlen = query_start_loc.has_value();
|
||||||
@@ -702,7 +693,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
|||||||
const int dstate = A.size(1);
|
const int dstate = A.size(1);
|
||||||
const int n_groups = varlen ? B.size(0) : B.size(1);
|
const int n_groups = varlen ? B.size(0) : B.size(1);
|
||||||
|
|
||||||
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
|
STD_TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
|
||||||
|
|
||||||
if (varlen) {
|
if (varlen) {
|
||||||
CHECK_SHAPE(u, dim, seqlen);
|
CHECK_SHAPE(u, dim, seqlen);
|
||||||
@@ -712,94 +703,94 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
|||||||
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
||||||
}
|
}
|
||||||
CHECK_SHAPE(A, dim, dstate);
|
CHECK_SHAPE(A, dim, dstate);
|
||||||
TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size")
|
STD_TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size");
|
||||||
if (varlen) {
|
if (varlen) {
|
||||||
CHECK_SHAPE(B, n_groups, dstate, seqlen);
|
CHECK_SHAPE(B, n_groups, dstate, seqlen);
|
||||||
} else {
|
} else {
|
||||||
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
|
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
|
||||||
}
|
}
|
||||||
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
STD_TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
||||||
|
|
||||||
TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size")
|
STD_TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size");
|
||||||
if (varlen) {
|
if (varlen) {
|
||||||
CHECK_SHAPE(C, n_groups, dstate, seqlen);
|
CHECK_SHAPE(C, n_groups, dstate, seqlen);
|
||||||
} else {
|
} else {
|
||||||
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
|
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
|
||||||
}
|
}
|
||||||
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
STD_TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
||||||
|
|
||||||
if (D_.has_value()) {
|
if (D_.has_value()) {
|
||||||
auto D = D_.value();
|
auto D = D_.value();
|
||||||
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
STD_TORCH_CHECK(D.scalar_type() == torch::headeronly::ScalarType::Float);
|
||||||
TORCH_CHECK(D.is_cuda());
|
STD_TORCH_CHECK(D.is_cuda());
|
||||||
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
STD_TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
||||||
CHECK_SHAPE(D, dim);
|
CHECK_SHAPE(D, dim);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (delta_bias_.has_value()) {
|
if (delta_bias_.has_value()) {
|
||||||
auto delta_bias = delta_bias_.value();
|
auto delta_bias = delta_bias_.value();
|
||||||
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
STD_TORCH_CHECK(delta_bias.scalar_type() == torch::headeronly::ScalarType::Float);
|
||||||
TORCH_CHECK(delta_bias.is_cuda());
|
STD_TORCH_CHECK(delta_bias.is_cuda());
|
||||||
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
STD_TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
||||||
CHECK_SHAPE(delta_bias, dim);
|
CHECK_SHAPE(delta_bias, dim);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if (has_initial_state.has_value()) {
|
if (has_initial_state.has_value()) {
|
||||||
auto has_initial_state_ = has_initial_state.value();
|
auto has_initial_state_ = has_initial_state.value();
|
||||||
TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
|
STD_TORCH_CHECK(has_initial_state_.scalar_type() == torch::headeronly::ScalarType::Bool);
|
||||||
TORCH_CHECK(has_initial_state_.is_cuda());
|
STD_TORCH_CHECK(has_initial_state_.is_cuda());
|
||||||
CHECK_SHAPE(has_initial_state_, batch_size);
|
CHECK_SHAPE(has_initial_state_, batch_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if (query_start_loc.has_value()) {
|
if (query_start_loc.has_value()) {
|
||||||
auto query_start_loc_ = query_start_loc.value();
|
auto query_start_loc_ = query_start_loc.value();
|
||||||
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
|
STD_TORCH_CHECK(query_start_loc_.scalar_type() == torch::headeronly::ScalarType::Int);
|
||||||
TORCH_CHECK(query_start_loc_.is_cuda());
|
STD_TORCH_CHECK(query_start_loc_.is_cuda());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if (cache_indices.has_value()) {
|
if (cache_indices.has_value()) {
|
||||||
auto cache_indices_ = cache_indices.value();
|
auto cache_indices_ = cache_indices.value();
|
||||||
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
|
STD_TORCH_CHECK(cache_indices_.scalar_type() == torch::headeronly::ScalarType::Int);
|
||||||
TORCH_CHECK(cache_indices_.is_cuda());
|
STD_TORCH_CHECK(cache_indices_.is_cuda());
|
||||||
|
|
||||||
// cache_indices can be either 1D (batch_size,) for non-APC mode
|
// cache_indices can be either 1D (batch_size,) for non-APC mode
|
||||||
// or 2D (batch_size, max_positions) for APC mode
|
// or 2D (batch_size, max_positions) for APC mode
|
||||||
const bool is_apc_mode = block_idx_first_scheduled_token.has_value();
|
const bool is_apc_mode = block_idx_first_scheduled_token.has_value();
|
||||||
if (is_apc_mode) {
|
if (is_apc_mode) {
|
||||||
TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode");
|
STD_TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode");
|
||||||
TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size");
|
STD_TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size");
|
||||||
} else {
|
} else {
|
||||||
CHECK_SHAPE(cache_indices_, batch_size);
|
CHECK_SHAPE(cache_indices_, batch_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
at::Tensor z, out_z;
|
|
||||||
|
torch::stable::Tensor z, out_z;
|
||||||
const bool has_z = z_.has_value();
|
const bool has_z = z_.has_value();
|
||||||
if (has_z) {
|
if (has_z) {
|
||||||
z = z_.value();
|
z = z_.value();
|
||||||
TORCH_CHECK(z.scalar_type() == input_type);
|
STD_TORCH_CHECK(z.scalar_type() == input_type);
|
||||||
TORCH_CHECK(z.is_cuda());
|
STD_TORCH_CHECK(z.is_cuda());
|
||||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
STD_TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||||
if (varlen){
|
if (varlen){
|
||||||
CHECK_SHAPE(z, dim, seqlen);
|
CHECK_SHAPE(z, dim, seqlen);
|
||||||
} else {
|
} else {
|
||||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||||
}
|
}
|
||||||
|
|
||||||
out_z = z;
|
out_z = z;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
||||||
at::Tensor out = delta;
|
torch::stable::Tensor out = delta;
|
||||||
// ssm_states can now be either the same as input_type or float32
|
// ssm_states can now be either the same as input_type or float32
|
||||||
auto state_type = ssm_states.scalar_type();
|
auto state_type = ssm_states.scalar_type();
|
||||||
TORCH_CHECK(state_type == input_type || state_type == at::ScalarType::Float);
|
STD_TORCH_CHECK(state_type == input_type || state_type == torch::headeronly::ScalarType::Float);
|
||||||
TORCH_CHECK(ssm_states.is_cuda());
|
STD_TORCH_CHECK(ssm_states.is_cuda());
|
||||||
TORCH_CHECK(ssm_states.stride(-1) == 1);
|
STD_TORCH_CHECK(ssm_states.stride(-1) == 1);
|
||||||
|
|
||||||
SSMParamsBase params;
|
SSMParamsBase params;
|
||||||
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C,
|
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C,
|
||||||
@@ -823,8 +814,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
|||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(u));
|
const torch::stable::accelerator::DeviceGuard device_guard(u.get_device_index());
|
||||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
auto stream = get_current_cuda_stream();
|
||||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), ssm_states.scalar_type(), "selective_scan_fwd", [&] {
|
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), ssm_states.scalar_type(), "selective_scan_fwd", [&] {
|
||||||
selective_scan_fwd_cuda<input_t, weight_t, state_t>(params, stream);
|
selective_scan_fwd_cuda<input_t, weight_t, state_t>(params, stream);
|
||||||
});
|
});
|
||||||
@@ -164,6 +164,17 @@ torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel,
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Attention kernels (shared CUDA/ROCm)
|
||||||
|
void merge_attn_states(
|
||||||
|
torch::stable::Tensor& output,
|
||||||
|
std::optional<torch::stable::Tensor> output_lse,
|
||||||
|
const torch::stable::Tensor& prefix_output,
|
||||||
|
const torch::stable::Tensor& prefix_lse,
|
||||||
|
const torch::stable::Tensor& suffix_output,
|
||||||
|
const torch::stable::Tensor& suffix_lse,
|
||||||
|
const std::optional<int64_t> prefill_tokens_with_context,
|
||||||
|
const std::optional<torch::stable::Tensor>& output_scale = std::nullopt);
|
||||||
|
|
||||||
torch::stable::Tensor hadacore_transform(torch::stable::Tensor& x,
|
torch::stable::Tensor hadacore_transform(torch::stable::Tensor& x,
|
||||||
bool inplace);
|
bool inplace);
|
||||||
|
|
||||||
@@ -220,6 +231,48 @@ void fused_qk_norm_rope(torch::stable::Tensor& qkv, int64_t num_heads_q,
|
|||||||
torch::stable::Tensor& position_ids,
|
torch::stable::Tensor& position_ids,
|
||||||
int64_t forced_token_heads_per_warp);
|
int64_t forced_token_heads_per_warp);
|
||||||
|
|
||||||
|
// Sampler kernels (shared CUDA/ROCm)
|
||||||
|
void apply_repetition_penalties_(
|
||||||
|
torch::stable::Tensor& logits, const torch::stable::Tensor& prompt_mask,
|
||||||
|
const torch::stable::Tensor& output_mask,
|
||||||
|
const torch::stable::Tensor& repetition_penalties);
|
||||||
|
|
||||||
|
void top_k_per_row_prefill(const torch::stable::Tensor& logits,
|
||||||
|
const torch::stable::Tensor& rowStarts,
|
||||||
|
const torch::stable::Tensor& rowEnds,
|
||||||
|
torch::stable::Tensor& indices, int64_t numRows,
|
||||||
|
int64_t stride0, int64_t stride1, int64_t topK);
|
||||||
|
|
||||||
|
void top_k_per_row_decode(const torch::stable::Tensor& logits, int64_t next_n,
|
||||||
|
const torch::stable::Tensor& seqLens,
|
||||||
|
torch::stable::Tensor& indices, int64_t numRows,
|
||||||
|
int64_t stride0, int64_t stride1, int64_t topK);
|
||||||
|
|
||||||
|
void persistent_topk(const torch::stable::Tensor& logits,
|
||||||
|
const torch::stable::Tensor& lengths,
|
||||||
|
torch::stable::Tensor& output,
|
||||||
|
torch::stable::Tensor& workspace, int64_t k,
|
||||||
|
int64_t max_seq_len);
|
||||||
|
|
||||||
|
void selective_scan_fwd(
|
||||||
|
const torch::stable::Tensor& u, const torch::stable::Tensor& delta,
|
||||||
|
const torch::stable::Tensor& A, const torch::stable::Tensor& B,
|
||||||
|
const torch::stable::Tensor& C,
|
||||||
|
const std::optional<torch::stable::Tensor>& D_,
|
||||||
|
const std::optional<torch::stable::Tensor>& z_,
|
||||||
|
const std::optional<torch::stable::Tensor>& delta_bias_,
|
||||||
|
bool delta_softplus,
|
||||||
|
const std::optional<torch::stable::Tensor>& query_start_loc,
|
||||||
|
const std::optional<torch::stable::Tensor>& cache_indices,
|
||||||
|
const std::optional<torch::stable::Tensor>& has_initial_state,
|
||||||
|
const torch::stable::Tensor& ssm_states, int64_t null_block_id,
|
||||||
|
int64_t block_size,
|
||||||
|
const std::optional<torch::stable::Tensor>& block_idx_first_scheduled_token,
|
||||||
|
const std::optional<torch::stable::Tensor>& block_idx_last_scheduled_token,
|
||||||
|
const std::optional<torch::stable::Tensor>& initial_state_idx,
|
||||||
|
const std::optional<torch::stable::Tensor>& cu_chunk_seqlen,
|
||||||
|
const std::optional<torch::stable::Tensor>& last_chunk_indices);
|
||||||
|
|
||||||
// Activation kernels (shared CUDA/ROCm)
|
// Activation kernels (shared CUDA/ROCm)
|
||||||
void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
||||||
void silu_and_mul_clamp(torch::stable::Tensor& out,
|
void silu_and_mul_clamp(torch::stable::Tensor& out,
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
#include "cuda_compat.h"
|
#include "../cuda_compat.h"
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
|
#include "torch_utils.h"
|
||||||
#include <torch/cuda.h>
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
#include <cub/cub.cuh>
|
#include <cub/cub.cuh>
|
||||||
@@ -618,14 +616,14 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
|
|||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void apply_repetition_penalties_(
|
void apply_repetition_penalties_(
|
||||||
torch::Tensor& logits, // [num_seqs, vocab_size], in-place
|
torch::stable::Tensor& logits, // [num_seqs, vocab_size], in-place
|
||||||
const torch::Tensor& prompt_mask, // [num_seqs, vocab_size]
|
const torch::stable::Tensor& prompt_mask, // [num_seqs, vocab_size]
|
||||||
const torch::Tensor& output_mask, // [num_seqs, vocab_size]
|
const torch::stable::Tensor& output_mask, // [num_seqs, vocab_size]
|
||||||
const torch::Tensor& repetition_penalties) { // [num_seqs]
|
const torch::stable::Tensor& repetition_penalties) { // [num_seqs]
|
||||||
TORCH_CHECK(logits.is_contiguous());
|
STD_TORCH_CHECK(logits.is_contiguous());
|
||||||
TORCH_CHECK(prompt_mask.is_contiguous());
|
STD_TORCH_CHECK(prompt_mask.is_contiguous());
|
||||||
TORCH_CHECK(output_mask.is_contiguous());
|
STD_TORCH_CHECK(output_mask.is_contiguous());
|
||||||
TORCH_CHECK(repetition_penalties.is_contiguous());
|
STD_TORCH_CHECK(repetition_penalties.is_contiguous());
|
||||||
|
|
||||||
int vocab_size = logits.size(-1);
|
int vocab_size = logits.size(-1);
|
||||||
int num_seqs = logits.size(0);
|
int num_seqs = logits.size(0);
|
||||||
@@ -635,7 +633,7 @@ void apply_repetition_penalties_(
|
|||||||
// Get number of SMs on the current device
|
// Get number of SMs on the current device
|
||||||
int sms = 0;
|
int sms = 0;
|
||||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
|
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
|
||||||
logits.get_device());
|
logits.get_device_index());
|
||||||
|
|
||||||
// Compute tile_num and tile_size
|
// Compute tile_num and tile_size
|
||||||
int tile_num =
|
int tile_num =
|
||||||
@@ -645,27 +643,29 @@ void apply_repetition_penalties_(
|
|||||||
// Each block handles one sequence and a tile of vocab
|
// Each block handles one sequence and a tile of vocab
|
||||||
dim3 grid(num_seqs, tile_num);
|
dim3 grid(num_seqs, tile_num);
|
||||||
dim3 block(std::min(tile_size, 1024));
|
dim3 block(std::min(tile_size, 1024));
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(logits));
|
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
logits.get_device_index());
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
const cudaStream_t stream = get_current_cuda_stream();
|
||||||
|
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
|
||||||
logits.scalar_type(), "apply_repetition_penalties_kernel", [&] {
|
logits.scalar_type(), "apply_repetition_penalties_kernel", [&] {
|
||||||
vllm::apply_repetition_penalties_kernel<scalar_t>
|
vllm::apply_repetition_penalties_kernel<scalar_t>
|
||||||
<<<grid, block, 0, stream>>>(
|
<<<grid, block, 0, stream>>>(
|
||||||
logits.data_ptr<scalar_t>(), prompt_mask.data_ptr<bool>(),
|
logits.mutable_data_ptr<scalar_t>(),
|
||||||
output_mask.data_ptr<bool>(),
|
prompt_mask.const_data_ptr<bool>(),
|
||||||
repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
|
output_mask.const_data_ptr<bool>(),
|
||||||
tile_size);
|
repetition_penalties.const_data_ptr<scalar_t>(), num_seqs,
|
||||||
|
vocab_size, tile_size);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
void top_k_per_row_decode(const torch::stable::Tensor& logits, int64_t next_n,
|
||||||
const torch::Tensor& seqLens, torch::Tensor& indices,
|
const torch::stable::Tensor& seqLens,
|
||||||
int64_t numRows, int64_t stride0, int64_t stride1,
|
torch::stable::Tensor& indices, int64_t numRows,
|
||||||
int64_t topK) {
|
int64_t stride0, int64_t stride1, int64_t topK) {
|
||||||
constexpr int kSortingAlgorithmThreshold = 12288;
|
constexpr int kSortingAlgorithmThreshold = 12288;
|
||||||
constexpr int kSplitWorkThreshold = 200 * 1000;
|
constexpr int kSplitWorkThreshold = 200 * 1000;
|
||||||
constexpr int kNumThreadsPerBlock = 512;
|
constexpr int kNumThreadsPerBlock = 512;
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = get_current_cuda_stream();
|
||||||
const auto numColumns = logits.size(1);
|
const auto numColumns = logits.size(1);
|
||||||
|
|
||||||
// True if seqLens is 2D (B, next_n): each logit row has its own pre-computed
|
// True if seqLens is 2D (B, next_n): each logit row has its own pre-computed
|
||||||
@@ -677,73 +677,76 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
|||||||
// Use insertion sort
|
// Use insertion sort
|
||||||
vllm::topKPerRowDecode<kNumThreadsPerBlock, false>
|
vllm::topKPerRowDecode<kNumThreadsPerBlock, false>
|
||||||
<<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
|
<<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
|
||||||
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
logits.const_data_ptr<float>(), seqLens.const_data_ptr<int>(),
|
||||||
indices.data_ptr<int>(), static_cast<int>(stride0),
|
indices.mutable_data_ptr<int>(), static_cast<int>(stride0),
|
||||||
static_cast<int>(stride1), static_cast<int>(topK),
|
static_cast<int>(stride1), static_cast<int>(topK),
|
||||||
static_cast<int>(next_n), seqLensIs2D);
|
static_cast<int>(next_n), seqLensIs2D);
|
||||||
} else if (numColumns < kSplitWorkThreshold) {
|
} else if (numColumns < kSplitWorkThreshold) {
|
||||||
// From this threshold, use radix sort instead
|
// From this threshold, use radix sort instead
|
||||||
vllm::topKPerRowDecode<kNumThreadsPerBlock, true>
|
vllm::topKPerRowDecode<kNumThreadsPerBlock, true>
|
||||||
<<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
|
<<<numRows, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
|
||||||
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
logits.const_data_ptr<float>(), seqLens.const_data_ptr<int>(),
|
||||||
indices.data_ptr<int>(), static_cast<int>(stride0),
|
indices.mutable_data_ptr<int>(), static_cast<int>(stride0),
|
||||||
static_cast<int>(stride1), static_cast<int>(topK),
|
static_cast<int>(stride1), static_cast<int>(topK),
|
||||||
static_cast<int>(next_n), seqLensIs2D);
|
static_cast<int>(next_n), seqLensIs2D);
|
||||||
} else {
|
} else {
|
||||||
// Long sequences are run in two steps
|
// Long sequences are run in two steps
|
||||||
constexpr auto multipleBlocksPerRowConfig = 10;
|
constexpr auto multipleBlocksPerRowConfig = 10;
|
||||||
|
|
||||||
const auto outIndicesAux =
|
const auto outIndicesAux = torch::stable::empty(
|
||||||
torch::empty({numRows, multipleBlocksPerRowConfig, topK},
|
{numRows, multipleBlocksPerRowConfig, topK},
|
||||||
torch::dtype(torch::kInt32).device(logits.device()));
|
torch::headeronly::ScalarType::Int, std::nullopt, logits.device());
|
||||||
const auto outLogitsAux =
|
const auto outLogitsAux = torch::stable::empty(
|
||||||
torch::empty({numRows, multipleBlocksPerRowConfig, topK},
|
{numRows, multipleBlocksPerRowConfig, topK},
|
||||||
torch::dtype(torch::kFloat).device(logits.device()));
|
torch::headeronly::ScalarType::Float, std::nullopt, logits.device());
|
||||||
|
|
||||||
vllm::topKPerRowDecode<kNumThreadsPerBlock, true, true>
|
vllm::topKPerRowDecode<kNumThreadsPerBlock, true, true>
|
||||||
<<<dim3(numRows, multipleBlocksPerRowConfig), kNumThreadsPerBlock,
|
<<<dim3(numRows, multipleBlocksPerRowConfig), kNumThreadsPerBlock,
|
||||||
2 * topK * sizeof(int32_t), stream>>>(
|
2 * topK * sizeof(int32_t), stream>>>(
|
||||||
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
logits.const_data_ptr<float>(), seqLens.const_data_ptr<int>(),
|
||||||
outIndicesAux.data_ptr<int>(), static_cast<int>(stride0),
|
outIndicesAux.mutable_data_ptr<int>(), static_cast<int>(stride0),
|
||||||
static_cast<int>(stride1), static_cast<int>(topK),
|
static_cast<int>(stride1), static_cast<int>(topK),
|
||||||
static_cast<int>(next_n), seqLensIs2D,
|
static_cast<int>(next_n), seqLensIs2D,
|
||||||
outLogitsAux.data_ptr<float>());
|
outLogitsAux.mutable_data_ptr<float>());
|
||||||
|
|
||||||
constexpr int kNumThreadsPerBlockMerge = 1024;
|
constexpr int kNumThreadsPerBlockMerge = 1024;
|
||||||
vllm::topKPerRowDecode<kNumThreadsPerBlockMerge, true, false, true>
|
vllm::topKPerRowDecode<kNumThreadsPerBlockMerge, true, false, true>
|
||||||
<<<numRows, kNumThreadsPerBlockMerge, topK * sizeof(int32_t), stream>>>(
|
<<<numRows, kNumThreadsPerBlockMerge, topK * sizeof(int32_t), stream>>>(
|
||||||
outLogitsAux.data_ptr<float>(), seqLens.data_ptr<int>(),
|
outLogitsAux.const_data_ptr<float>(), seqLens.const_data_ptr<int>(),
|
||||||
indices.data_ptr<int>(), multipleBlocksPerRowConfig * topK, 1,
|
indices.mutable_data_ptr<int>(), multipleBlocksPerRowConfig * topK,
|
||||||
static_cast<int>(topK), static_cast<int>(next_n), seqLensIs2D,
|
1, static_cast<int>(topK), static_cast<int>(next_n), seqLensIs2D,
|
||||||
nullptr, multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
|
nullptr, multipleBlocksPerRowConfig,
|
||||||
|
outIndicesAux.const_data_ptr<int>());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void top_k_per_row_prefill(const torch::Tensor& logits,
|
void top_k_per_row_prefill(const torch::stable::Tensor& logits,
|
||||||
const torch::Tensor& rowStarts,
|
const torch::stable::Tensor& rowStarts,
|
||||||
const torch::Tensor& rowEnds, torch::Tensor& indices,
|
const torch::stable::Tensor& rowEnds,
|
||||||
int64_t numRows, int64_t stride0, int64_t stride1,
|
torch::stable::Tensor& indices, int64_t numRows,
|
||||||
int64_t topK) {
|
int64_t stride0, int64_t stride1, int64_t topK) {
|
||||||
constexpr int kSortingAlgorithmThreshold = 12288;
|
constexpr int kSortingAlgorithmThreshold = 12288;
|
||||||
constexpr int kNumThreadsPerBlock = 512;
|
constexpr int kNumThreadsPerBlock = 512;
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = get_current_cuda_stream();
|
||||||
|
|
||||||
int numInsertionBlocks =
|
int numInsertionBlocks =
|
||||||
std::min(static_cast<int>(numRows), kSortingAlgorithmThreshold);
|
std::min(static_cast<int>(numRows), kSortingAlgorithmThreshold);
|
||||||
vllm::topKPerRowPrefill<kNumThreadsPerBlock, false>
|
vllm::topKPerRowPrefill<kNumThreadsPerBlock, false>
|
||||||
<<<numInsertionBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
|
<<<numInsertionBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
|
||||||
stream>>>(logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
|
stream>>>(logits.const_data_ptr<float>(),
|
||||||
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
|
rowStarts.const_data_ptr<int>(),
|
||||||
static_cast<int>(stride0), static_cast<int>(stride1),
|
rowEnds.const_data_ptr<int>(),
|
||||||
static_cast<int>(topK), 0);
|
indices.mutable_data_ptr<int>(), static_cast<int>(stride0),
|
||||||
|
static_cast<int>(stride1), static_cast<int>(topK), 0);
|
||||||
|
|
||||||
if (numRows > kSortingAlgorithmThreshold) {
|
if (numRows > kSortingAlgorithmThreshold) {
|
||||||
int numRadixBlocks = numRows - kSortingAlgorithmThreshold;
|
int numRadixBlocks = numRows - kSortingAlgorithmThreshold;
|
||||||
vllm::topKPerRowPrefill<kNumThreadsPerBlock, true>
|
vllm::topKPerRowPrefill<kNumThreadsPerBlock, true>
|
||||||
<<<numRadixBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
|
<<<numRadixBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t),
|
||||||
stream>>>(logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
|
stream>>>(
|
||||||
rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
|
logits.const_data_ptr<float>(), rowStarts.const_data_ptr<int>(),
|
||||||
static_cast<int>(stride0), static_cast<int>(stride1),
|
rowEnds.const_data_ptr<int>(), indices.mutable_data_ptr<int>(),
|
||||||
static_cast<int>(topK), kSortingAlgorithmThreshold);
|
static_cast<int>(stride0), static_cast<int>(stride1),
|
||||||
|
static_cast<int>(topK), kSortingAlgorithmThreshold);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,49 +1,51 @@
|
|||||||
// Persistent TopK kernel for DeepSeek V3 sparse attention indexer.
|
// Persistent TopK kernel for DeepSeek V3 sparse attention indexer.
|
||||||
// See persistent_topk.cuh for kernel implementation.
|
// See persistent_topk.cuh for kernel implementation.
|
||||||
|
|
||||||
#include <torch/all.h>
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "torch_utils.h"
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
#include "persistent_topk.cuh"
|
#include "../persistent_topk.cuh"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
template <int TopK>
|
template <int TopK>
|
||||||
void launch_persistent_topk(const torch::Tensor& logits,
|
void launch_persistent_topk(const torch::stable::Tensor& logits,
|
||||||
const torch::Tensor& lengths, torch::Tensor& output,
|
const torch::stable::Tensor& lengths,
|
||||||
torch::Tensor& workspace, int64_t max_seq_len) {
|
torch::stable::Tensor& output,
|
||||||
|
torch::stable::Tensor& workspace,
|
||||||
|
int64_t max_seq_len) {
|
||||||
namespace P = vllm::persistent;
|
namespace P = vllm::persistent;
|
||||||
|
|
||||||
const int64_t num_rows = logits.size(0);
|
const int64_t num_rows = logits.size(0);
|
||||||
const int64_t stride = logits.stride(0);
|
const int64_t stride = logits.stride(0);
|
||||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = get_current_cuda_stream();
|
||||||
|
|
||||||
static int num_sms = 0;
|
static int num_sms = 0;
|
||||||
static int max_smem_per_block = 0;
|
static int max_smem_per_block = 0;
|
||||||
if (num_sms == 0) {
|
if (num_sms == 0) {
|
||||||
int device;
|
const cudaDeviceProp* device_prop = get_device_prop();
|
||||||
cudaGetDevice(&device);
|
num_sms = device_prop->multiProcessorCount;
|
||||||
cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device);
|
max_smem_per_block = device_prop->sharedMemPerBlockOptin;
|
||||||
cudaDeviceGetAttribute(&max_smem_per_block,
|
|
||||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (num_rows > 32 && max_smem_per_block >= 128 * 1024) {
|
if (num_rows > 32 && max_smem_per_block >= 128 * 1024) {
|
||||||
cudaError_t status =
|
cudaError_t status =
|
||||||
vllm::FilteredTopKRaggedTransform<float, int32_t, TopK>(
|
vllm::FilteredTopKRaggedTransform<float, int32_t, TopK>(
|
||||||
logits.data_ptr<float>(), output.data_ptr<int32_t>(),
|
logits.const_data_ptr<float>(), output.mutable_data_ptr<int32_t>(),
|
||||||
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
|
lengths.const_data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
|
||||||
static_cast<uint32_t>(TopK), static_cast<uint32_t>(stride), stream);
|
static_cast<uint32_t>(TopK), static_cast<uint32_t>(stride), stream);
|
||||||
TORCH_CHECK(status == cudaSuccess,
|
STD_TORCH_CHECK(status == cudaSuccess,
|
||||||
"FilteredTopK failed: ", cudaGetErrorString(status));
|
"FilteredTopK failed: ", cudaGetErrorString(status));
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor");
|
STD_TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor");
|
||||||
TORCH_CHECK(workspace.dtype() == torch::kUInt8, "workspace must be uint8");
|
STD_TORCH_CHECK(
|
||||||
|
workspace.scalar_type() == torch::headeronly::ScalarType::Byte,
|
||||||
|
"workspace must be uint8");
|
||||||
|
|
||||||
int effective_max_smem;
|
int effective_max_smem;
|
||||||
if (num_rows <= 4) {
|
if (num_rows <= 4) {
|
||||||
@@ -99,9 +101,9 @@ void launch_persistent_topk(const torch::Tensor& logits,
|
|||||||
&occupancy, P::persistent_topk_kernel<TopK, 1>, P::kThreadsPerBlock,
|
&occupancy, P::persistent_topk_kernel<TopK, 1>, P::kThreadsPerBlock,
|
||||||
smem_size);
|
smem_size);
|
||||||
}
|
}
|
||||||
TORCH_CHECK(occ_err == cudaSuccess,
|
STD_TORCH_CHECK(occ_err == cudaSuccess,
|
||||||
"persistent_topk occupancy query failed: ",
|
"persistent_topk occupancy query failed: ",
|
||||||
cudaGetErrorString(occ_err));
|
cudaGetErrorString(occ_err));
|
||||||
if (occupancy < 1) occupancy = 1;
|
if (occupancy < 1) occupancy = 1;
|
||||||
|
|
||||||
// The cooperative spin-wait barrier only runs when at least one row hits
|
// The cooperative spin-wait barrier only runs when at least one row hits
|
||||||
@@ -131,27 +133,29 @@ void launch_persistent_topk(const torch::Tensor& logits,
|
|||||||
// If the cooperative launch wouldn't fit, fall back to FilteredTopK
|
// If the cooperative launch wouldn't fit, fall back to FilteredTopK
|
||||||
// instead of deadlocking. Only relevant when needs_cooperative.
|
// instead of deadlocking. Only relevant when needs_cooperative.
|
||||||
if (needs_cooperative && total_ctas > hw_resident_cap) {
|
if (needs_cooperative && total_ctas > hw_resident_cap) {
|
||||||
TORCH_CHECK(max_smem_per_block >= 128 * 1024,
|
STD_TORCH_CHECK(
|
||||||
"persistent_topk would oversubscribe and the FilteredTopK "
|
max_smem_per_block >= 128 * 1024,
|
||||||
"fallback requires >=128KB smem per block (have ",
|
"persistent_topk would oversubscribe and the FilteredTopK "
|
||||||
max_smem_per_block, "). total_ctas=", total_ctas,
|
"fallback requires >=128KB smem per block (have ",
|
||||||
" > num_sms*occupancy=", hw_resident_cap, " (TopK=", TopK,
|
max_smem_per_block, "). total_ctas=", total_ctas,
|
||||||
", vec_size=", vec_size, ", ctas_per_group=", ctas_per_group,
|
" > num_sms*occupancy=", hw_resident_cap, " (TopK=", TopK,
|
||||||
", smem=", smem_size, ").");
|
", vec_size=", vec_size, ", ctas_per_group=", ctas_per_group,
|
||||||
|
", smem=", smem_size, ").");
|
||||||
cudaError_t status =
|
cudaError_t status =
|
||||||
vllm::FilteredTopKRaggedTransform<float, int32_t, TopK>(
|
vllm::FilteredTopKRaggedTransform<float, int32_t, TopK>(
|
||||||
logits.data_ptr<float>(), output.data_ptr<int32_t>(),
|
logits.const_data_ptr<float>(),
|
||||||
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
|
output.mutable_data_ptr<int32_t>(),
|
||||||
static_cast<uint32_t>(TopK), static_cast<uint32_t>(stride),
|
lengths.const_data_ptr<int32_t>(),
|
||||||
stream);
|
static_cast<uint32_t>(num_rows), static_cast<uint32_t>(TopK),
|
||||||
TORCH_CHECK(status == cudaSuccess,
|
static_cast<uint32_t>(stride), stream);
|
||||||
"FilteredTopK fallback failed: ", cudaGetErrorString(status));
|
STD_TORCH_CHECK(status == cudaSuccess, "FilteredTopK fallback failed: ",
|
||||||
|
cudaGetErrorString(status));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t state_bytes = num_groups * sizeof(P::RadixRowState);
|
size_t state_bytes = num_groups * sizeof(P::RadixRowState);
|
||||||
TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
|
STD_TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
|
||||||
"workspace too small, need ", state_bytes, " bytes");
|
"workspace too small, need ", state_bytes, " bytes");
|
||||||
|
|
||||||
// Zero the per-group RadixRowState region before launch.
|
// Zero the per-group RadixRowState region before launch.
|
||||||
//
|
//
|
||||||
@@ -179,22 +183,22 @@ void launch_persistent_topk(const torch::Tensor& logits,
|
|||||||
// first red_release. cudaMemsetAsync is stream-ordered: the zero
|
// first red_release. cudaMemsetAsync is stream-ordered: the zero
|
||||||
// is globally visible before any CTA runs.
|
// is globally visible before any CTA runs.
|
||||||
{
|
{
|
||||||
cudaError_t mz_err = cudaMemsetAsync(workspace.data_ptr<uint8_t>(), 0,
|
cudaError_t mz_err = cudaMemsetAsync(
|
||||||
state_bytes, stream);
|
workspace.mutable_data_ptr<uint8_t>(), 0, state_bytes, stream);
|
||||||
TORCH_CHECK(mz_err == cudaSuccess,
|
STD_TORCH_CHECK(mz_err == cudaSuccess,
|
||||||
"row_states memset failed: ", cudaGetErrorString(mz_err));
|
"row_states memset failed: ", cudaGetErrorString(mz_err));
|
||||||
}
|
}
|
||||||
|
|
||||||
P::PersistentTopKParams params;
|
P::PersistentTopKParams params;
|
||||||
params.input = logits.data_ptr<float>();
|
params.input = logits.const_data_ptr<float>();
|
||||||
params.output = output.data_ptr<int32_t>();
|
params.output = output.mutable_data_ptr<int32_t>();
|
||||||
params.lengths = lengths.data_ptr<int32_t>();
|
params.lengths = lengths.const_data_ptr<int32_t>();
|
||||||
params.num_rows = static_cast<uint32_t>(num_rows);
|
params.num_rows = static_cast<uint32_t>(num_rows);
|
||||||
params.stride = static_cast<uint32_t>(stride);
|
params.stride = static_cast<uint32_t>(stride);
|
||||||
params.top_k = static_cast<uint32_t>(TopK);
|
params.top_k = static_cast<uint32_t>(TopK);
|
||||||
params.chunk_size = chunk_size;
|
params.chunk_size = chunk_size;
|
||||||
params.row_states =
|
params.row_states = reinterpret_cast<P::RadixRowState*>(
|
||||||
reinterpret_cast<P::RadixRowState*>(workspace.data_ptr<uint8_t>());
|
workspace.mutable_data_ptr<uint8_t>());
|
||||||
params.ctas_per_group = ctas_per_group;
|
params.ctas_per_group = ctas_per_group;
|
||||||
params.max_seq_len = static_cast<uint32_t>(max_seq_len);
|
params.max_seq_len = static_cast<uint32_t>(max_seq_len);
|
||||||
|
|
||||||
@@ -203,8 +207,8 @@ void launch_persistent_topk(const torch::Tensor& logits,
|
|||||||
auto kernel = &P::persistent_topk_kernel<TOPK_VAL, VS>; \
|
auto kernel = &P::persistent_topk_kernel<TOPK_VAL, VS>; \
|
||||||
cudaError_t err = cudaFuncSetAttribute( \
|
cudaError_t err = cudaFuncSetAttribute( \
|
||||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \
|
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \
|
||||||
TORCH_CHECK(err == cudaSuccess, \
|
STD_TORCH_CHECK(err == cudaSuccess, \
|
||||||
"Failed to set smem: ", cudaGetErrorString(err)); \
|
"Failed to set smem: ", cudaGetErrorString(err)); \
|
||||||
kernel<<<total_ctas, P::kThreadsPerBlock, smem_size, stream>>>(params); \
|
kernel<<<total_ctas, P::kThreadsPerBlock, smem_size, stream>>>(params); \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
@@ -219,37 +223,42 @@ void launch_persistent_topk(const torch::Tensor& logits,
|
|||||||
}
|
}
|
||||||
|
|
||||||
cudaError_t err = cudaGetLastError();
|
cudaError_t err = cudaGetLastError();
|
||||||
TORCH_CHECK(err == cudaSuccess,
|
STD_TORCH_CHECK(err == cudaSuccess,
|
||||||
"persistent_topk failed: ", cudaGetErrorString(err));
|
"persistent_topk failed: ", cudaGetErrorString(err));
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
void persistent_topk(const torch::stable::Tensor& logits,
|
||||||
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
|
const torch::stable::Tensor& lengths,
|
||||||
|
torch::stable::Tensor& output,
|
||||||
|
torch::stable::Tensor& workspace, int64_t k,
|
||||||
int64_t max_seq_len) {
|
int64_t max_seq_len) {
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
|
STD_TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
|
||||||
TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
|
STD_TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
|
||||||
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
|
STD_TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
|
||||||
TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported");
|
STD_TORCH_CHECK(logits.scalar_type() == torch::headeronly::ScalarType::Float,
|
||||||
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
|
"Only float32 supported");
|
||||||
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
|
STD_TORCH_CHECK(lengths.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||||
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
|
"lengths must be int32");
|
||||||
TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
|
STD_TORCH_CHECK(output.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||||
"lengths must be 1D or 2D");
|
"output must be int32");
|
||||||
TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
|
STD_TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
|
||||||
TORCH_CHECK(output.dim() == 2, "output must be 2D");
|
STD_TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
|
||||||
|
"lengths must be 1D or 2D");
|
||||||
|
STD_TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
|
||||||
|
STD_TORCH_CHECK(output.dim() == 2, "output must be 2D");
|
||||||
|
|
||||||
const int64_t num_rows = logits.size(0);
|
const int64_t num_rows = logits.size(0);
|
||||||
const int64_t stride = logits.stride(0);
|
|
||||||
|
|
||||||
TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
|
STD_TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
|
||||||
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
|
STD_TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
|
||||||
"output size mismatch");
|
"output size mismatch");
|
||||||
TORCH_CHECK(k == 512 || k == 1024 || k == 2048,
|
STD_TORCH_CHECK(
|
||||||
"persistent_topk supports k=512, k=1024, or k=2048, got k=", k);
|
k == 512 || k == 1024 || k == 2048,
|
||||||
|
"persistent_topk supports k=512, k=1024, or k=2048, got k=", k);
|
||||||
|
|
||||||
if (k == 512) {
|
if (k == 512) {
|
||||||
launch_persistent_topk<512>(logits, lengths, output, workspace,
|
launch_persistent_topk<512>(logits, lengths, output, workspace,
|
||||||
@@ -262,6 +271,6 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
|||||||
max_seq_len);
|
max_seq_len);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
TORCH_CHECK(false, "persistent_topk is not supported on ROCm");
|
STD_TORCH_CHECK(false, "persistent_topk is not supported on ROCm");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
@@ -263,6 +263,20 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
|||||||
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor");
|
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor");
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Merge attn states
|
||||||
|
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||||
|
// can be used to combine partial attention results (in the split-KV case)
|
||||||
|
ops.def(
|
||||||
|
"merge_attn_states("
|
||||||
|
" Tensor! output,"
|
||||||
|
" Tensor!? output_lse,"
|
||||||
|
" Tensor prefix_output,"
|
||||||
|
" Tensor prefix_lse,"
|
||||||
|
" Tensor suffix_output,"
|
||||||
|
" Tensor suffix_lse,"
|
||||||
|
" int!? prefill_tokens_with_context,"
|
||||||
|
" Tensor? output_scale=None) -> ()");
|
||||||
|
|
||||||
// Hadamard transforms
|
// Hadamard transforms
|
||||||
// conditionally compiled so impl registration is in source file
|
// conditionally compiled so impl registration is in source file
|
||||||
ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");
|
ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");
|
||||||
@@ -319,6 +333,26 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
|||||||
"bool is_neox, Tensor position_ids, "
|
"bool is_neox, Tensor position_ids, "
|
||||||
"int forced_token_heads_per_warp=-1) -> ()");
|
"int forced_token_heads_per_warp=-1) -> ()");
|
||||||
|
|
||||||
|
// Apply repetition penalties to logits in-place.
|
||||||
|
ops.def(
|
||||||
|
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
|
||||||
|
"Tensor output_mask, Tensor repetition_penalties) -> ()");
|
||||||
|
|
||||||
|
// Optimized top-k per row operations.
|
||||||
|
ops.def(
|
||||||
|
"top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
|
||||||
|
"Tensor! indices, int numRows, int stride0, "
|
||||||
|
"int stride1, int topK) -> ()");
|
||||||
|
|
||||||
|
ops.def(
|
||||||
|
"top_k_per_row_decode(Tensor logits, int next_n, "
|
||||||
|
"Tensor seq_lens, Tensor! indices, "
|
||||||
|
"int numRows, int stride0, int stride1, int topK) -> ()");
|
||||||
|
|
||||||
|
ops.def(
|
||||||
|
"persistent_topk(Tensor logits, Tensor lengths, Tensor! output, "
|
||||||
|
"Tensor workspace, int k, int max_seq_len) -> ()");
|
||||||
|
|
||||||
// Activation ops
|
// Activation ops
|
||||||
// Activation function used in SwiGLU.
|
// Activation function used in SwiGLU.
|
||||||
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
|
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
|
||||||
@@ -422,6 +456,24 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
|||||||
"int type, SymInt row, SymInt tokens) -> Tensor");
|
"int type, SymInt row, SymInt tokens) -> Tensor");
|
||||||
|
|
||||||
ops.def("ggml_moe_get_block_size(int type) -> int");
|
ops.def("ggml_moe_get_block_size(int type) -> int");
|
||||||
|
|
||||||
|
// Mamba selective scan kernel
|
||||||
|
ops.def(
|
||||||
|
"selective_scan_fwd(Tensor! u, Tensor! delta,"
|
||||||
|
"Tensor! A, Tensor! B, Tensor! C,"
|
||||||
|
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
|
||||||
|
"bool delta_softplus,"
|
||||||
|
"Tensor? query_start_loc,"
|
||||||
|
"Tensor? cache_indices,"
|
||||||
|
"Tensor? has_initial_state,"
|
||||||
|
"Tensor! ssm_states,"
|
||||||
|
"int null_block_id,"
|
||||||
|
"int block_size,"
|
||||||
|
"Tensor? block_idx_first_scheduled_token,"
|
||||||
|
"Tensor? block_idx_last_scheduled_token,"
|
||||||
|
"Tensor? initial_state_idx,"
|
||||||
|
"Tensor? cu_chunk_seqlen,"
|
||||||
|
"Tensor? last_chunk_indices) -> ()");
|
||||||
}
|
}
|
||||||
|
|
||||||
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||||
@@ -469,6 +521,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
|||||||
// files (allspark_repack.cu and allspark_qgemm_w8a16.cu)
|
// files (allspark_repack.cu and allspark_qgemm_w8a16.cu)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
ops.impl("merge_attn_states", TORCH_BOX(&merge_attn_states));
|
||||||
|
|
||||||
// Layernorm kernels (shared CUDA/ROCm)
|
// Layernorm kernels (shared CUDA/ROCm)
|
||||||
ops.impl("rms_norm", TORCH_BOX(&rms_norm));
|
ops.impl("rms_norm", TORCH_BOX(&rms_norm));
|
||||||
ops.impl("fused_add_rms_norm", TORCH_BOX(&fused_add_rms_norm));
|
ops.impl("fused_add_rms_norm", TORCH_BOX(&fused_add_rms_norm));
|
||||||
@@ -487,6 +541,13 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
|||||||
ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding));
|
ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding));
|
||||||
ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope));
|
ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope));
|
||||||
|
|
||||||
|
// Sampler kernels (shared CUDA/ROCm)
|
||||||
|
ops.impl("apply_repetition_penalties_",
|
||||||
|
TORCH_BOX(&apply_repetition_penalties_));
|
||||||
|
ops.impl("top_k_per_row_prefill", TORCH_BOX(&top_k_per_row_prefill));
|
||||||
|
ops.impl("top_k_per_row_decode", TORCH_BOX(&top_k_per_row_decode));
|
||||||
|
ops.impl("persistent_topk", TORCH_BOX(&persistent_topk));
|
||||||
|
|
||||||
// Activation kernels (shared CUDA/ROCm)
|
// Activation kernels (shared CUDA/ROCm)
|
||||||
ops.impl("silu_and_mul", TORCH_BOX(&silu_and_mul));
|
ops.impl("silu_and_mul", TORCH_BOX(&silu_and_mul));
|
||||||
ops.impl("mul_and_silu", TORCH_BOX(&mul_and_silu));
|
ops.impl("mul_and_silu", TORCH_BOX(&mul_and_silu));
|
||||||
@@ -519,6 +580,7 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
|||||||
ops.impl("ggml_mul_mat_a8", TORCH_BOX(&ggml_mul_mat_a8));
|
ops.impl("ggml_mul_mat_a8", TORCH_BOX(&ggml_mul_mat_a8));
|
||||||
ops.impl("ggml_moe_a8", TORCH_BOX(&ggml_moe_a8));
|
ops.impl("ggml_moe_a8", TORCH_BOX(&ggml_moe_a8));
|
||||||
ops.impl("ggml_moe_a8_vec", TORCH_BOX(&ggml_moe_a8_vec));
|
ops.impl("ggml_moe_a8_vec", TORCH_BOX(&ggml_moe_a8_vec));
|
||||||
|
ops.impl("selective_scan_fwd", TORCH_BOX(&selective_scan_fwd));
|
||||||
}
|
}
|
||||||
|
|
||||||
// These capability-check functions take only primitive args (no tensors), so
|
// These capability-check functions take only primitive args (no tensors), so
|
||||||
|
|||||||
-43
@@ -54,13 +54,6 @@ void paged_attention_v2(
|
|||||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
const int64_t blocksparse_head_sliding_step);
|
const int64_t blocksparse_head_sliding_step);
|
||||||
|
|
||||||
void merge_attn_states(
|
|
||||||
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
|
|
||||||
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
|
|
||||||
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
|
|
||||||
const std::optional<int64_t> prefill_tokens_with_context,
|
|
||||||
const std::optional<torch::Tensor>& output_scale = std::nullopt);
|
|
||||||
|
|
||||||
// rms_norm and fused_add_rms_norm declarations also exist in
|
// rms_norm and fused_add_rms_norm declarations also exist in
|
||||||
// csrc/libtorch_stable/ops.h (torch::stable ABI for CUDA). They remain here
|
// csrc/libtorch_stable/ops.h (torch::stable ABI for CUDA). They remain here
|
||||||
// because the CPU build still uses these torch::Tensor declarations.
|
// because the CPU build still uses these torch::Tensor declarations.
|
||||||
@@ -76,26 +69,6 @@ torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
|||||||
torch::Tensor const& cos_sin_cache, int64_t q_head_padded, double eps,
|
torch::Tensor const& cos_sin_cache, int64_t q_head_padded, double eps,
|
||||||
int64_t cache_block_size);
|
int64_t cache_block_size);
|
||||||
|
|
||||||
void apply_repetition_penalties_(torch::Tensor& logits,
|
|
||||||
const torch::Tensor& prompt_mask,
|
|
||||||
const torch::Tensor& output_mask,
|
|
||||||
const torch::Tensor& repetition_penalties);
|
|
||||||
|
|
||||||
void top_k_per_row_prefill(const torch::Tensor& logits,
|
|
||||||
const torch::Tensor& rowStarts,
|
|
||||||
const torch::Tensor& rowEnds, torch::Tensor& indices,
|
|
||||||
int64_t numRows, int64_t stride0, int64_t stride1,
|
|
||||||
int64_t topK);
|
|
||||||
|
|
||||||
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
|
||||||
const torch::Tensor& seqLens, torch::Tensor& indices,
|
|
||||||
int64_t numRows, int64_t stride0, int64_t stride1,
|
|
||||||
int64_t topK);
|
|
||||||
|
|
||||||
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
|
||||||
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
|
|
||||||
int64_t max_seq_len);
|
|
||||||
|
|
||||||
void silu_and_mul_per_block_quant(torch::Tensor& out,
|
void silu_and_mul_per_block_quant(torch::Tensor& out,
|
||||||
torch::Tensor const& input,
|
torch::Tensor const& input,
|
||||||
torch::Tensor& scales, int64_t group_size,
|
torch::Tensor& scales, int64_t group_size,
|
||||||
@@ -150,22 +123,6 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
|||||||
torch::Tensor& scales,
|
torch::Tensor& scales,
|
||||||
std::optional<torch::Tensor> const& azp);
|
std::optional<torch::Tensor> const& azp);
|
||||||
|
|
||||||
void selective_scan_fwd(
|
|
||||||
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
|
|
||||||
const torch::Tensor& B, const torch::Tensor& C,
|
|
||||||
const std::optional<torch::Tensor>& D_,
|
|
||||||
const std::optional<torch::Tensor>& z_,
|
|
||||||
const std::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
|
|
||||||
const std::optional<torch::Tensor>& query_start_loc,
|
|
||||||
const std::optional<torch::Tensor>& cache_indices,
|
|
||||||
const std::optional<torch::Tensor>& has_initial_state,
|
|
||||||
const torch::Tensor& ssm_states, int64_t null_block_id, int64_t block_size,
|
|
||||||
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
|
|
||||||
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
|
|
||||||
const std::optional<torch::Tensor>& initial_state_idx,
|
|
||||||
const std::optional<torch::Tensor>& cu_chunk_seqlen,
|
|
||||||
const std::optional<torch::Tensor>& last_chunk_indices);
|
|
||||||
|
|
||||||
torch::Tensor dynamic_4bit_int_moe_cpu(
|
torch::Tensor dynamic_4bit_int_moe_cpu(
|
||||||
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
|
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
|
||||||
torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
|
torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
|
||||||
|
|||||||
@@ -126,10 +126,10 @@ struct RadixRowState {
|
|||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
||||||
struct PersistentTopKParams {
|
struct PersistentTopKParams {
|
||||||
const float* __restrict__ input; // [num_rows, stride]
|
const float* __restrict__ input; // [num_rows, stride]
|
||||||
int32_t* __restrict__ output; // [num_rows, top_k]
|
int32_t* __restrict__ output; // [num_rows, top_k]
|
||||||
int32_t* __restrict__ lengths; // [num_rows]
|
const int32_t* __restrict__ lengths; // [num_rows]
|
||||||
RadixRowState* row_states; // large path: per-group state
|
RadixRowState* row_states; // large path: per-group state
|
||||||
uint32_t num_rows;
|
uint32_t num_rows;
|
||||||
uint32_t stride;
|
uint32_t stride;
|
||||||
uint32_t top_k; // actual k value for output stride
|
uint32_t top_k; // actual k value for output stride
|
||||||
@@ -1269,9 +1269,11 @@ constexpr int ComputeFilteredTopKVecSize(uint32_t max_len) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename DType, typename IdType, uint32_t MAX_K = 2048>
|
template <typename DType, typename IdType, uint32_t MAX_K = 2048>
|
||||||
cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices,
|
cudaError_t FilteredTopKRaggedTransform(const DType* input,
|
||||||
IdType* lengths, uint32_t num_rows,
|
IdType* output_indices,
|
||||||
uint32_t top_k_val, uint32_t max_len,
|
const IdType* lengths,
|
||||||
|
uint32_t num_rows, uint32_t top_k_val,
|
||||||
|
uint32_t max_len,
|
||||||
cudaStream_t stream = 0) {
|
cudaStream_t stream = 0) {
|
||||||
constexpr size_t smem_size = FILTERED_TOPK_SMEM_DYNAMIC;
|
constexpr size_t smem_size = FILTERED_TOPK_SMEM_DYNAMIC;
|
||||||
constexpr int MAX_VEC = 16 / sizeof(DType);
|
constexpr int MAX_VEC = 16 / sizeof(DType);
|
||||||
|
|||||||
@@ -62,21 +62,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" int blocksparse_head_sliding_step) -> ()");
|
" int blocksparse_head_sliding_step) -> ()");
|
||||||
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
|
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
|
||||||
|
|
||||||
// Merge attn states
|
|
||||||
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
|
||||||
// can be used to combine partial attention results (in the split-KV case)
|
|
||||||
ops.def(
|
|
||||||
"merge_attn_states("
|
|
||||||
" Tensor! output,"
|
|
||||||
" Tensor!? output_lse,"
|
|
||||||
" Tensor prefix_output,"
|
|
||||||
" Tensor prefix_lse,"
|
|
||||||
" Tensor suffix_output,"
|
|
||||||
" Tensor suffix_lse,"
|
|
||||||
" int!? prefill_tokens_with_context,"
|
|
||||||
" Tensor? output_scale=None) -> ()");
|
|
||||||
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
|
|
||||||
|
|
||||||
// Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
|
// Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
|
||||||
ops.def(
|
ops.def(
|
||||||
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
||||||
@@ -105,31 +90,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
|
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
|
||||||
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);
|
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);
|
||||||
|
|
||||||
// Apply repetition penalties to logits in-place
|
|
||||||
ops.def(
|
|
||||||
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
|
|
||||||
"Tensor output_mask, Tensor repetition_penalties) -> ()");
|
|
||||||
ops.impl("apply_repetition_penalties_", torch::kCUDA,
|
|
||||||
&apply_repetition_penalties_);
|
|
||||||
|
|
||||||
// Optimized top-k per row operation
|
|
||||||
ops.def(
|
|
||||||
"top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
|
|
||||||
"Tensor! indices, int numRows, int stride0, "
|
|
||||||
"int stride1, int topK) -> ()");
|
|
||||||
ops.impl("top_k_per_row_prefill", torch::kCUDA, &top_k_per_row_prefill);
|
|
||||||
|
|
||||||
ops.def(
|
|
||||||
"top_k_per_row_decode(Tensor logits, int next_n, "
|
|
||||||
"Tensor seq_lens, Tensor! indices, "
|
|
||||||
"int numRows, int stride0, int stride1, int topK) -> ()");
|
|
||||||
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
|
|
||||||
|
|
||||||
ops.def(
|
|
||||||
"persistent_topk(Tensor logits, Tensor lengths, Tensor! output, "
|
|
||||||
"Tensor workspace, int k, int max_seq_len) -> ()");
|
|
||||||
ops.impl("persistent_topk", torch::kCUDA, &persistent_topk);
|
|
||||||
|
|
||||||
// Quantization ops
|
// Quantization ops
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
|
|
||||||
@@ -230,25 +190,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Mamba selective scan kernel
|
|
||||||
ops.def(
|
|
||||||
"selective_scan_fwd(Tensor! u, Tensor! delta,"
|
|
||||||
"Tensor! A, Tensor! B, Tensor! C,"
|
|
||||||
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
|
|
||||||
"bool delta_softplus,"
|
|
||||||
"Tensor? query_start_loc,"
|
|
||||||
"Tensor? cache_indices,"
|
|
||||||
"Tensor? has_initial_state,"
|
|
||||||
"Tensor! ssm_states,"
|
|
||||||
"int null_block_id,"
|
|
||||||
"int block_size,"
|
|
||||||
"Tensor? block_idx_first_scheduled_token,"
|
|
||||||
"Tensor? block_idx_last_scheduled_token,"
|
|
||||||
"Tensor? initial_state_idx,"
|
|
||||||
"Tensor? cu_chunk_seqlen,"
|
|
||||||
"Tensor? last_chunk_indices) -> ()");
|
|
||||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
ops.def(
|
ops.def(
|
||||||
"minimax_allreduce_rms("
|
"minimax_allreduce_rms("
|
||||||
|
|||||||
@@ -141,6 +141,7 @@ bbc5b7ede = "bbc5b7ede"
|
|||||||
NOOPs = "NOOPs"
|
NOOPs = "NOOPs"
|
||||||
nin_shortcut = "nin_shortcut"
|
nin_shortcut = "nin_shortcut"
|
||||||
cudaDevAttrMaxSharedMemoryPerBlockOptin = "cudaDevAttrMaxSharedMemoryPerBlockOptin"
|
cudaDevAttrMaxSharedMemoryPerBlockOptin = "cudaDevAttrMaxSharedMemoryPerBlockOptin"
|
||||||
|
sharedMemPerBlockOptin = "sharedMemPerBlockOptin"
|
||||||
|
|
||||||
depthwise_seperable_out_channel = "depthwise_seperable_out_channel"
|
depthwise_seperable_out_channel = "depthwise_seperable_out_channel"
|
||||||
pard_token = "pard_token"
|
pard_token = "pard_token"
|
||||||
|
|||||||
Reference in New Issue
Block a user