mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
545 lines
27 KiB
C++
545 lines
27 KiB
C++
#pragma once
|
|
|
|
#include <torch/csrc/stable/library.h>
|
|
#include <torch/csrc/stable/tensor.h>
|
|
|
|
// Per-token-group FP8 quantisation entry point (declaration only; the
// implementation is not in this header).
// `input` is read-only, while `output_q` and `output_s` are taken by non-const
// reference and the function returns void, so results are presumably written
// into those two tensors — TODO confirm against the kernel source.
// NOTE(review): the `dummy_` prefix on the last two flags suggests they are
// accepted only for signature compatibility and are ignored — confirm.
void per_token_group_quant_fp8(const torch::stable::Tensor& input,
|
|
torch::stable::Tensor& output_q,
|
|
torch::stable::Tensor& output_s,
|
|
int64_t group_size, double eps, double fp8_min,
|
|
double fp8_max, bool scale_ue8m0,
|
|
bool dummy_is_scale_transposed,
|
|
bool dummy_is_tma_aligned);
|
|
|
|
// Fused activation quantisation + DeepGEMM-compatible UE8M0-packed scales.
|
|
void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
|
|
torch::stable::Tensor& output_q,
|
|
torch::stable::Tensor& output_s_packed,
|
|
int64_t group_size, double eps,
|
|
double min_8bit, double max_8bit);
|
|
|
|
void per_token_group_quant_int8(const torch::stable::Tensor& input,
|
|
torch::stable::Tensor& output_q,
|
|
torch::stable::Tensor& output_s,
|
|
int64_t group_size, double eps, double int8_min,
|
|
double int8_max);
|
|
|
|
#ifndef USE_ROCM
|
|
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
|
|
torch::stable::Tensor const& perm);
|
|
|
|
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
|
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
|
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
|
|
|
|
void cutlass_scaled_mm(torch::stable::Tensor& out,
|
|
torch::stable::Tensor const& a,
|
|
torch::stable::Tensor const& b,
|
|
torch::stable::Tensor const& a_scales,
|
|
torch::stable::Tensor const& b_scales,
|
|
std::optional<torch::stable::Tensor> const& bias);
|
|
|
|
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
|
|
torch::stable::Tensor const& a_tensors,
|
|
torch::stable::Tensor const& b_tensors,
|
|
torch::stable::Tensor const& a_scales,
|
|
torch::stable::Tensor const& b_scales,
|
|
torch::stable::Tensor const& expert_offsets,
|
|
torch::stable::Tensor const& problem_sizes,
|
|
torch::stable::Tensor const& a_strides,
|
|
torch::stable::Tensor const& b_strides,
|
|
torch::stable::Tensor const& c_strides, bool per_act_token,
|
|
bool per_out_ch);
|
|
|
|
void cutlass_scaled_mm_azp(torch::stable::Tensor& out,
|
|
torch::stable::Tensor const& a,
|
|
torch::stable::Tensor const& b,
|
|
torch::stable::Tensor const& a_scales,
|
|
torch::stable::Tensor const& b_scales,
|
|
torch::stable::Tensor const& azp_adj,
|
|
std::optional<torch::stable::Tensor> const& azp,
|
|
std::optional<torch::stable::Tensor> const& bias);
|
|
|
|
void get_cutlass_moe_mm_data(
|
|
const torch::stable::Tensor& topk_ids,
|
|
torch::stable::Tensor& expert_offsets,
|
|
torch::stable::Tensor& problem_sizes1,
|
|
torch::stable::Tensor& problem_sizes2,
|
|
torch::stable::Tensor& input_permutation,
|
|
torch::stable::Tensor& output_permutation, const int64_t num_experts,
|
|
const int64_t n, const int64_t k,
|
|
const std::optional<torch::stable::Tensor>& blockscale_offsets,
|
|
const bool is_gated);
|
|
|
|
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
|
const torch::stable::Tensor& expert_first_token_offset,
|
|
torch::stable::Tensor& problem_sizes1,
|
|
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
|
|
const bool swap_ab);
|
|
|
|
void get_cutlass_batched_moe_mm_data(
|
|
torch::stable::Tensor& expert_offsets,
|
|
torch::stable::Tensor& problem_sizes1,
|
|
torch::stable::Tensor& problem_sizes2,
|
|
const torch::stable::Tensor& expert_num_tokens,
|
|
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
|
const int64_t k);
|
|
|
|
// FP4/NVFP4 ops
|
|
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
|
|
|
|
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
|
|
torch::stable::Tensor const& A,
|
|
torch::stable::Tensor const& B,
|
|
torch::stable::Tensor const& A_sf,
|
|
torch::stable::Tensor const& B_sf,
|
|
torch::stable::Tensor const& alpha);
|
|
|
|
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
|
|
const torch::stable::Tensor& a,
|
|
const torch::stable::Tensor& b,
|
|
const torch::stable::Tensor& a_blockscale,
|
|
const torch::stable::Tensor& b_blockscales,
|
|
const torch::stable::Tensor& alphas,
|
|
const torch::stable::Tensor& problem_sizes,
|
|
const torch::stable::Tensor& expert_offsets,
|
|
const torch::stable::Tensor& sf_offsets);
|
|
|
|
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
|
|
torch::stable::Tensor const& input,
|
|
torch::stable::Tensor const& input_scale, bool is_sf_swizzled_layout);
|
|
|
|
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
|
|
torch::stable::Tensor const& input_scale,
|
|
bool is_sf_swizzled_layout,
|
|
torch::stable::Tensor& output,
|
|
torch::stable::Tensor& output_scale);
|
|
|
|
void scaled_fp4_experts_quant(
|
|
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
|
torch::stable::Tensor const& input,
|
|
torch::stable::Tensor const& input_global_scale,
|
|
torch::stable::Tensor const& input_offset_by_experts,
|
|
torch::stable::Tensor const& output_scale_offset_by_experts);
|
|
|
|
void silu_and_mul_scaled_fp4_experts_quant(
|
|
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
|
torch::stable::Tensor const& input,
|
|
torch::stable::Tensor const& input_global_scale,
|
|
torch::stable::Tensor const& input_offset_by_experts,
|
|
torch::stable::Tensor const& output_scale_offset_by_experts);
|
|
|
|
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
|
|
torch::stable::Tensor& output_block_scale,
|
|
torch::stable::Tensor& input,
|
|
torch::stable::Tensor& input_global_scale);
|
|
|
|
void cutlass_mxfp4_group_mm(torch::stable::Tensor& output,
|
|
const torch::stable::Tensor& a,
|
|
const torch::stable::Tensor& b,
|
|
const torch::stable::Tensor& a_blockscale,
|
|
const torch::stable::Tensor& b_blockscales,
|
|
const torch::stable::Tensor& problem_sizes,
|
|
const torch::stable::Tensor& expert_offsets,
|
|
const torch::stable::Tensor& sf_offsets);
|
|
|
|
// AWQ ops
|
|
torch::stable::Tensor awq_gemm(torch::stable::Tensor _in_feats,
|
|
torch::stable::Tensor _kernel,
|
|
torch::stable::Tensor _scaling_factors,
|
|
torch::stable::Tensor _zeros,
|
|
int64_t split_k_iters);
|
|
|
|
torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel,
|
|
torch::stable::Tensor _scaling_factors,
|
|
torch::stable::Tensor _zeros,
|
|
int64_t split_k_iters, int64_t thx,
|
|
int64_t thy);
|
|
|
|
// DSV3 fused A GEMM: conditionally compiled so declaration and impl
|
|
// registration are in the source file (dsv3_fused_a_gemm.cu)
|
|
|
|
// AllSpark ops: declarations are in the source files
|
|
// (allspark_repack.cu and allspark_qgemm_w8a16.cu)
|
|
|
|
#endif
|
|
|
|
// Attention kernels (shared CUDA/ROCm)
|
|
// Declaration only. `output` is the sole non-const tensor reference in the
// signature; `output_lse` and `prefill_tokens_with_context` are optionals
// passed by value, and `output_scale` defaults to std::nullopt so callers may
// omit it.
// NOTE(review): presumably combines the prefix and suffix partial results
// using their `*_lse` tensors and writes into `output` — confirm against the
// implementation.
void merge_attn_states(
|
|
torch::stable::Tensor& output,
|
|
std::optional<torch::stable::Tensor> output_lse,
|
|
const torch::stable::Tensor& prefix_output,
|
|
const torch::stable::Tensor& prefix_lse,
|
|
const torch::stable::Tensor& suffix_output,
|
|
const torch::stable::Tensor& suffix_lse,
|
|
const std::optional<int64_t> prefill_tokens_with_context,
|
|
const std::optional<torch::stable::Tensor>& output_scale = std::nullopt);
|
|
|
|
// Declaration only. Returns a tensor and takes `x` by non-const reference.
// NOTE(review): presumably `inplace` selects whether `x` itself is overwritten
// (and returned) rather than a fresh tensor — confirm in the implementation.
torch::stable::Tensor hadacore_transform(torch::stable::Tensor& x,
|
|
bool inplace);
|
|
|
|
// Layernorm kernels (shared CUDA/ROCm)
|
|
void rms_norm(torch::stable::Tensor& out, torch::stable::Tensor& input,
|
|
torch::stable::Tensor& weight, double epsilon);
|
|
|
|
void fused_add_rms_norm(torch::stable::Tensor& input,
|
|
torch::stable::Tensor& residual,
|
|
torch::stable::Tensor& weight, double epsilon);
|
|
|
|
// Layernorm-quant kernels (shared CUDA/ROCm)
|
|
void rms_norm_static_fp8_quant(torch::stable::Tensor& out,
|
|
torch::stable::Tensor& input,
|
|
torch::stable::Tensor& weight,
|
|
torch::stable::Tensor& scale, double epsilon);
|
|
|
|
void fused_add_rms_norm_static_fp8_quant(torch::stable::Tensor& out,
|
|
torch::stable::Tensor& input,
|
|
torch::stable::Tensor& residual,
|
|
torch::stable::Tensor& weight,
|
|
torch::stable::Tensor& scale,
|
|
double epsilon);
|
|
|
|
// Fused layernorm + dynamic per-token quant kernels (shared CUDA/ROCm)
|
|
void rms_norm_dynamic_per_token_quant(
|
|
torch::stable::Tensor& out, torch::stable::Tensor const& input,
|
|
torch::stable::Tensor const& weight, torch::stable::Tensor& scales,
|
|
double const var_epsilon, std::optional<torch::stable::Tensor> scale_ub,
|
|
std::optional<torch::stable::Tensor> residual);
|
|
|
|
void rms_norm_per_block_quant(torch::stable::Tensor& out,
|
|
torch::stable::Tensor const& input,
|
|
torch::stable::Tensor const& weight,
|
|
torch::stable::Tensor& scales,
|
|
double const var_epsilon,
|
|
std::optional<torch::stable::Tensor> scale_ub,
|
|
std::optional<torch::stable::Tensor> residual,
|
|
int64_t group_size, bool is_scale_transposed);
|
|
|
|
// Positional encoding kernels (shared CUDA/ROCm)
|
|
void rotary_embedding(torch::stable::Tensor& positions,
|
|
torch::stable::Tensor& query,
|
|
std::optional<torch::stable::Tensor> key,
|
|
int64_t head_size, torch::stable::Tensor& cos_sin_cache,
|
|
bool is_neox, int64_t rope_dim_offset, bool inverse);
|
|
|
|
void fused_qk_norm_rope(torch::stable::Tensor& qkv, int64_t num_heads_q,
|
|
int64_t num_heads_k, int64_t num_heads_v,
|
|
int64_t head_dim, double eps,
|
|
torch::stable::Tensor& q_weight,
|
|
torch::stable::Tensor& k_weight,
|
|
torch::stable::Tensor& cos_sin_cache, bool is_neox,
|
|
torch::stable::Tensor& position_ids,
|
|
int64_t forced_token_heads_per_warp);
|
|
|
|
torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
|
torch::stable::Tensor const& q_in, torch::stable::Tensor const& kv,
|
|
torch::stable::Tensor& k_cache, torch::stable::Tensor const& slot_mapping,
|
|
torch::stable::Tensor const& position_ids,
|
|
torch::stable::Tensor const& cos_sin_cache, int64_t q_head_padded,
|
|
double eps, int64_t cache_block_size);
|
|
|
|
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert(
|
|
torch::stable::Tensor& q, torch::stable::Tensor const& kv,
|
|
torch::stable::Tensor& k_cache, torch::stable::Tensor const& slot_mapping,
|
|
torch::stable::Tensor const& position_ids,
|
|
torch::stable::Tensor const& cos_sin_cache, double eps,
|
|
int64_t cache_block_size);
|
|
|
|
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert(
|
|
torch::stable::Tensor const& q, torch::stable::Tensor const& kv,
|
|
torch::stable::Tensor& q_fp8, torch::stable::Tensor& k_cache,
|
|
torch::stable::Tensor const& slot_mapping,
|
|
torch::stable::Tensor const& position_ids,
|
|
torch::stable::Tensor const& cos_sin_cache,
|
|
torch::stable::Tensor const& fp8_scale,
|
|
torch::stable::Tensor const& q_fp8_scale_inv, double eps,
|
|
int64_t cache_block_size);
|
|
|
|
#ifndef USE_ROCM
|
|
torch::stable::Tensor minimax_allreduce_rms(
|
|
torch::stable::Tensor const& input,
|
|
torch::stable::Tensor const& norm_weight, torch::stable::Tensor workspace,
|
|
int64_t const rank, int64_t const nranks, double const eps);
|
|
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
|
|
minimax_allreduce_rms_qk(torch::stable::Tensor qkv,
|
|
torch::stable::Tensor const& norm_weight_q,
|
|
torch::stable::Tensor const& norm_weight_k,
|
|
torch::stable::Tensor workspace, int64_t const q_size,
|
|
int64_t const kv_size, int64_t const rank,
|
|
int64_t const nranks, double const eps);
|
|
#endif
|
|
|
|
// Sampler kernels (shared CUDA/ROCm)
|
|
void apply_repetition_penalties_(
|
|
torch::stable::Tensor& logits, const torch::stable::Tensor& prompt_mask,
|
|
const torch::stable::Tensor& output_mask,
|
|
const torch::stable::Tensor& repetition_penalties);
|
|
|
|
void top_k_per_row_prefill(const torch::stable::Tensor& logits,
|
|
const torch::stable::Tensor& rowStarts,
|
|
const torch::stable::Tensor& rowEnds,
|
|
torch::stable::Tensor& indices, int64_t numRows,
|
|
int64_t stride0, int64_t stride1, int64_t topK);
|
|
|
|
void top_k_per_row_decode(const torch::stable::Tensor& logits, int64_t next_n,
|
|
const torch::stable::Tensor& seqLens,
|
|
torch::stable::Tensor& indices, int64_t numRows,
|
|
int64_t stride0, int64_t stride1, int64_t topK);
|
|
|
|
void persistent_topk(const torch::stable::Tensor& logits,
|
|
const torch::stable::Tensor& lengths,
|
|
torch::stable::Tensor& output,
|
|
torch::stable::Tensor& workspace, int64_t k,
|
|
int64_t max_seq_len);
|
|
|
|
void selective_scan_fwd(
|
|
const torch::stable::Tensor& u, const torch::stable::Tensor& delta,
|
|
const torch::stable::Tensor& A, const torch::stable::Tensor& B,
|
|
const torch::stable::Tensor& C,
|
|
const std::optional<torch::stable::Tensor>& D_,
|
|
const std::optional<torch::stable::Tensor>& z_,
|
|
const std::optional<torch::stable::Tensor>& delta_bias_,
|
|
bool delta_softplus,
|
|
const std::optional<torch::stable::Tensor>& query_start_loc,
|
|
const std::optional<torch::stable::Tensor>& cache_indices,
|
|
const std::optional<torch::stable::Tensor>& has_initial_state,
|
|
const torch::stable::Tensor& ssm_states, int64_t null_block_id,
|
|
int64_t block_size,
|
|
const std::optional<torch::stable::Tensor>& block_idx_first_scheduled_token,
|
|
const std::optional<torch::stable::Tensor>& block_idx_last_scheduled_token,
|
|
const std::optional<torch::stable::Tensor>& initial_state_idx,
|
|
const std::optional<torch::stable::Tensor>& cu_chunk_seqlen,
|
|
const std::optional<torch::stable::Tensor>& last_chunk_indices);
|
|
|
|
// Integer handle type for the custom all-reduce declarations that follow:
// init_custom_ar returns it and the other calls receive it as `_fa`.
// NOTE(review): presumably an opaque pointer carried as int64_t — confirm.
using fptr_t = int64_t;
|
|
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
|
torch::stable::Tensor& rank_data, int64_t rank,
|
|
bool fully_connected);
|
|
void all_reduce(fptr_t _fa, torch::stable::Tensor& inp,
|
|
torch::stable::Tensor& out, fptr_t reg_buffer,
|
|
int64_t reg_buffer_sz_bytes);
|
|
void dispose(fptr_t _fa);
|
|
int64_t meta_size();
|
|
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
|
|
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
|
|
get_graph_buffer_ipc_meta(fptr_t _fa);
|
|
void register_graph_buffers(fptr_t _fa,
|
|
const std::vector<std::vector<int64_t>>& handles,
|
|
const std::vector<std::vector<int64_t>>& offsets);
|
|
std::tuple<int64_t, torch::stable::Tensor> allocate_shared_buffer_and_handle(
|
|
int64_t size);
|
|
int64_t open_mem_handle(torch::stable::Tensor& mem_handle);
|
|
void free_shared_buffer(int64_t buffer);
|
|
|
|
// Activation kernels (shared CUDA/ROCm)
|
|
void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
|
void silu_and_mul_clamp(torch::stable::Tensor& out,
|
|
torch::stable::Tensor& input, double limit);
|
|
void mul_and_silu(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
|
void gelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
|
void gelu_tanh_and_mul(torch::stable::Tensor& out,
|
|
torch::stable::Tensor& input);
|
|
void fatrelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input,
|
|
double threshold);
|
|
// Unlike the neighbouring activation declarations, `alpha` and `limit` carry
// default arguments (1.702 and 7.0), so callers may pass only `out` and
// `input`.
void swigluoai_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input,
|
|
double alpha = 1.702, double limit = 7.0);
|
|
void gelu_new(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
|
void gelu_fast(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
|
void gelu_quick(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
|
|
|
// INT8 quantization kernels (shared CUDA/ROCm)
|
|
void static_scaled_int8_quant(torch::stable::Tensor& out,
|
|
torch::stable::Tensor const& input,
|
|
torch::stable::Tensor const& scale,
|
|
std::optional<torch::stable::Tensor> const& azp);
|
|
|
|
void dynamic_scaled_int8_quant(torch::stable::Tensor& out,
|
|
torch::stable::Tensor const& input,
|
|
torch::stable::Tensor& scales,
|
|
std::optional<torch::stable::Tensor> const& azp);
|
|
|
|
// FP8 quantization kernels (shared CUDA/ROCm)
|
|
// Declaration only. `scale` is read-only here (const&), in contrast to the
// dynamic FP8 variants declared after it, which take a mutable scale tensor.
// `group_shape` is optional and defaults to std::nullopt, so callers may omit
// it.
void static_scaled_fp8_quant(
|
|
torch::stable::Tensor& out, torch::stable::Tensor const& input,
|
|
torch::stable::Tensor const& scale,
|
|
std::optional<torch::headeronly::IntHeaderOnlyArrayRef> group_shape =
|
|
std::nullopt);
|
|
|
|
void dynamic_scaled_fp8_quant(torch::stable::Tensor& out,
|
|
torch::stable::Tensor const& input,
|
|
torch::stable::Tensor& scale);
|
|
|
|
void dynamic_per_token_scaled_fp8_quant(
|
|
torch::stable::Tensor& out, torch::stable::Tensor const& input,
|
|
torch::stable::Tensor& scale,
|
|
std::optional<torch::stable::Tensor> const& scale_ub);
|
|
|
|
// GPTQ kernels (shared CUDA/ROCm)
|
|
torch::stable::Tensor gptq_gemm(torch::stable::Tensor a,
|
|
torch::stable::Tensor b_q_weight,
|
|
torch::stable::Tensor b_gptq_qzeros,
|
|
torch::stable::Tensor b_gptq_scales,
|
|
torch::stable::Tensor b_g_idx, bool use_exllama,
|
|
bool use_v2_format, int64_t bit);
|
|
|
|
void gptq_shuffle(torch::stable::Tensor q_weight, torch::stable::Tensor q_perm,
|
|
int64_t bit);
|
|
|
|
// GGML kernels (shared CUDA/ROCm)
|
|
torch::stable::Tensor ggml_dequantize(
|
|
torch::stable::Tensor W, int64_t type, int64_t m, int64_t n,
|
|
std::optional<torch::headeronly::ScalarType> const& dtype);
|
|
|
|
torch::stable::Tensor ggml_mul_mat_vec_a8(torch::stable::Tensor W,
|
|
torch::stable::Tensor X, int64_t type,
|
|
int64_t row);
|
|
|
|
torch::stable::Tensor ggml_mul_mat_a8(torch::stable::Tensor W,
|
|
torch::stable::Tensor X, int64_t type,
|
|
int64_t row);
|
|
|
|
torch::stable::Tensor ggml_moe_a8(torch::stable::Tensor X,
|
|
torch::stable::Tensor W,
|
|
torch::stable::Tensor sorted_token_ids,
|
|
torch::stable::Tensor expert_ids,
|
|
torch::stable::Tensor num_tokens_post_padded,
|
|
int64_t type, int64_t row, int64_t top_k,
|
|
int64_t tokens);
|
|
|
|
torch::stable::Tensor ggml_moe_a8_vec(torch::stable::Tensor X,
|
|
torch::stable::Tensor W,
|
|
torch::stable::Tensor topk_ids,
|
|
int64_t top_k, int64_t type, int64_t row,
|
|
int64_t tokens);
|
|
|
|
int64_t ggml_moe_get_block_size(int64_t type);
|
|
|
|
void paged_attention_v1(
|
|
torch::stable::Tensor& out, torch::stable::Tensor& query,
|
|
torch::stable::Tensor& key_cache, torch::stable::Tensor& value_cache,
|
|
int64_t num_kv_heads, double scale, torch::stable::Tensor& block_tables,
|
|
torch::stable::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len,
|
|
const std::optional<torch::stable::Tensor>& alibi_slopes,
|
|
const std::string& kv_cache_dtype, torch::stable::Tensor& k_scale,
|
|
torch::stable::Tensor& v_scale, const int64_t tp_rank,
|
|
const int64_t blocksparse_local_blocks,
|
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
|
const int64_t blocksparse_head_sliding_step);
|
|
|
|
void paged_attention_v2(
|
|
torch::stable::Tensor& out, torch::stable::Tensor& exp_sums,
|
|
torch::stable::Tensor& max_logits, torch::stable::Tensor& tmp_out,
|
|
torch::stable::Tensor& query, torch::stable::Tensor& key_cache,
|
|
torch::stable::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
|
torch::stable::Tensor& block_tables, torch::stable::Tensor& seq_lens,
|
|
int64_t block_size, int64_t max_seq_len,
|
|
const std::optional<torch::stable::Tensor>& alibi_slopes,
|
|
const std::string& kv_cache_dtype, torch::stable::Tensor& k_scale,
|
|
torch::stable::Tensor& v_scale, const int64_t tp_rank,
|
|
const int64_t blocksparse_local_blocks,
|
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
|
const int64_t blocksparse_head_sliding_step);
|
|
|
|
// Cache ops (shared CUDA/ROCm)
|
|
void swap_blocks(torch::stable::Tensor& src, torch::stable::Tensor& dst,
|
|
int64_t block_size_in_bytes,
|
|
const torch::stable::Tensor& block_mapping);
|
|
|
|
// Batch swap: submit all block copies in a single driver call.
|
|
void swap_blocks_batch(const torch::stable::Tensor& src_ptrs,
|
|
const torch::stable::Tensor& dst_ptrs,
|
|
const torch::stable::Tensor& sizes,
|
|
bool is_src_access_order_any);
|
|
|
|
void reshape_and_cache(torch::stable::Tensor& key, torch::stable::Tensor& value,
|
|
torch::stable::Tensor& key_cache,
|
|
torch::stable::Tensor& value_cache,
|
|
torch::stable::Tensor& slot_mapping,
|
|
const std::string& kv_cache_dtype,
|
|
torch::stable::Tensor& k_scale,
|
|
torch::stable::Tensor& v_scale);
|
|
|
|
void reshape_and_cache_flash(
|
|
torch::stable::Tensor& key, torch::stable::Tensor& value,
|
|
torch::stable::Tensor& key_cache, torch::stable::Tensor& value_cache,
|
|
torch::stable::Tensor& slot_mapping, const std::string& kv_cache_dtype,
|
|
torch::stable::Tensor& k_scale, torch::stable::Tensor& v_scale);
|
|
|
|
void concat_and_cache_mla(torch::stable::Tensor& kv_c,
|
|
torch::stable::Tensor& k_pe,
|
|
torch::stable::Tensor& kv_cache,
|
|
torch::stable::Tensor& slot_mapping,
|
|
const std::string& kv_cache_dtype,
|
|
torch::stable::Tensor& scale);
|
|
|
|
// NOTE: k_pe and kv_c order is flipped compared to concat_and_cache_mla
|
|
void concat_and_cache_mla_rope_fused(
|
|
torch::stable::Tensor& positions, torch::stable::Tensor& q_pe,
|
|
torch::stable::Tensor& k_pe, torch::stable::Tensor& kv_c,
|
|
torch::stable::Tensor& rope_cos_sin_cache, bool rope_is_neox,
|
|
torch::stable::Tensor& slot_mapping, torch::stable::Tensor& kv_cache,
|
|
const std::string& kv_cache_dtype,
|
|
torch::stable::Tensor& kv_cache_quant_scale);
|
|
|
|
// For unit tests only
|
|
void convert_fp8(torch::stable::Tensor& dst_cache,
|
|
torch::stable::Tensor& src_cache, const double scale,
|
|
const std::string& kv_cache_dtype);
|
|
|
|
// Declaration only; the bracketed comment on each tensor parameter gives its
// expected layout. `seq_starts` is optional and defaults to std::nullopt.
// NOTE(review): `dst` is declared const& although its name suggests it is the
// destination — presumably the underlying storage is written through the
// handle; confirm against the implementation.
void gather_and_maybe_dequant_cache(
|
|
torch::stable::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
|
|
// ENTRIES...]
|
|
torch::stable::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
|
|
torch::stable::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
|
torch::stable::Tensor const& cu_seq_lens, // [BATCH+1]
|
|
torch::stable::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
|
|
int64_t num_tokens, const std::string& kv_cache_dtype,
|
|
torch::stable::Tensor const& scale,
|
|
std::optional<torch::stable::Tensor> seq_starts = std::nullopt);
|
|
|
|
// TODO(hc): cp_gather_cache needs to support scaled KV cache in the future.
|
|
// Declaration only; the bracketed comment on each tensor parameter gives its
// expected layout. `seq_starts` is optional and defaults to std::nullopt.
// NOTE(review): `dst` is declared const& although its name suggests it is the
// destination — presumably the underlying storage is written through the
// handle; confirm against the implementation.
void cp_gather_cache(
|
|
torch::stable::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
|
|
// ENTRIES...]
|
|
torch::stable::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
|
|
torch::stable::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
|
torch::stable::Tensor const& cu_seq_lens, // [BATCH+1]
|
|
int64_t batch_size,
|
|
std::optional<torch::stable::Tensor> seq_starts = std::nullopt);
|
|
|
|
// Gather and upconvert FP8 KV cache to BF16 workspace
|
|
void cp_gather_and_upconvert_fp8_kv_cache(
|
|
torch::stable::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
|
|
// 656]
|
|
torch::stable::Tensor const& dst, // [TOT_TOKENS, 576]
|
|
torch::stable::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
|
torch::stable::Tensor const& seq_lens, // [BATCH]
|
|
torch::stable::Tensor const& workspace_starts, // [BATCH]
|
|
int64_t batch_size);
|
|
|
|
// Indexer K quantization and cache function
|
|
void indexer_k_quant_and_cache(
|
|
torch::stable::Tensor& k, // [num_tokens, head_dim]
|
|
torch::stable::Tensor& kv_cache, // [num_blocks, block_size,
|
|
// cache_stride]
|
|
torch::stable::Tensor& slot_mapping, // [num_tokens]
|
|
int64_t quant_block_size, // quantization block size
|
|
const std::string& scale_fmt);
|
|
|
|
// Concatenate query nope and rope for MLA/DSA attention
|
|
// Declaration only. All three tensors are non-const references; per the shape
// comments, `q_out`'s last dimension is nope_dim + rope_dim, i.e. the sum of
// the last dimensions of `ql_nope` and `q_pe`.
// NOTE(review): presumably `q_out` is the only tensor written — confirm.
void concat_mla_q(
|
|
torch::stable::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
|
|
torch::stable::Tensor& q_pe, // [num_tokens, num_heads, rope_dim]
|
|
torch::stable::Tensor& q_out); // [num_tokens, num_heads, nope_dim +
|
|
// rope_dim]
|
|
|
|
// Extract function to gather quantized K cache
|
|
void cp_gather_indexer_k_quant_cache(
|
|
const torch::stable::Tensor& kv_cache, // [num_blocks, block_size,
|
|
// cache_stride]
|
|
torch::stable::Tensor& dst_k, // [num_tokens, head_dim]
|
|
torch::stable::Tensor& dst_scale, // [num_tokens, head_dim /
|
|
// quant_block_size * 4]
|
|
const torch::stable::Tensor& block_table, // [batch_size, num_blocks]
|
|
const torch::stable::Tensor& cu_seq_lens); // [batch_size + 1]
|