// Mirror of https://github.com/vllm-project/vllm.git
// (synced 2026-06-06 00:16:14 +00:00, commit 284e6f543d).
// Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
// Signed-off-by: Chris Leonard <chleonar@redhat.com>
// Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
// Co-authored-by: Shengqi Chen <harry-chen@outlook.com>
// 387 lines, 15 KiB, C++.
#include "cache.h"
|
|
#include "cuda_utils.h"
|
|
#include "ops.h"
|
|
#include "core/registration.h"
|
|
#include <torch/library.h>
|
|
#include <torch/version.h>
|
|

// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
// They must be kept in sync with the signature for X. Generally, only
// functions that return Tensors require a meta function.
//
// See the following links for detailed docs on op registration and function
// schemas.
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations

// Registers the main vLLM custom ops in the TORCH_EXTENSION_NAME library.
//
// `ops.def(schema)` declares an op's schema; `ops.impl(name, key, &fn)` binds
// the C++ function `fn` as that op's kernel for dispatch key `key`
// (torch::kCUDA or torch::kCPU below). Ops that are only def'd here have
// their impl registered in their own conditionally compiled source files, as
// the per-op comments note.
//
// Schema conventions used throughout: `Tensor!` marks an argument the op
// mutates in place, `Tensor?` is an optional tensor, and `-> ()` means the op
// returns nothing.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // vLLM custom ops
  //

  // Writes its results into the mutable y_q / y_s outputs; CUDA-only kernel.
  ops.def(
      "persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! "
      "y_q, Tensor! y_s,"
      "bool use_ue8m0) -> ()");
  ops.impl("persistent_masked_m_silu_mul_quant", torch::kCUDA,
           &persistent_masked_m_silu_mul_quant);

  ops.def("weak_ref_tensor(Tensor input) -> Tensor");
  ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);

  // Registered for the CPU dispatch key because its input is a CPU tensor.
  // NOTE(review): presumably returns a CUDA-accessible view of that CPU
  // memory (per the name) — confirm against the implementation.
  ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
  ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
           &get_cuda_view_from_cpu_tensor);

  // Attention ops
  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention.
  ops.def(
      "paged_attention_v1("
      " Tensor! out, Tensor query, Tensor key_cache,"
      " Tensor value_cache, int num_kv_heads, float scale,"
      " Tensor block_tables, Tensor seq_lens, int block_size,"
      " int max_seq_len, Tensor? alibi_slopes,"
      " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
      " int tp_rank, int blocksparse_local_blocks,"
      " int blocksparse_vert_stride, int blocksparse_block_size,"
      " int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);

  // PagedAttention V2.
  // Same inputs as V1, plus the mutable scratch outputs exp_sums, max_logits
  // and tmp_out.
  ops.def(
      "paged_attention_v2("
      " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
      " Tensor! tmp_out, Tensor query, Tensor key_cache,"
      " Tensor value_cache, int num_kv_heads, float scale,"
      " Tensor block_tables, Tensor seq_lens, int block_size,"
      " int max_seq_len, Tensor? alibi_slopes,"
      " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
      " int tp_rank, int blocksparse_local_blocks,"
      " int blocksparse_vert_stride, int blocksparse_block_size,"
      " int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);

  // Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
  ops.def(
      "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
  ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);

  // Fused SiLU+Mul + per-block quantization.
  // scale_ub and is_scale_transposed carry schema defaults, so Python callers
  // may omit them.
  ops.def(
      "silu_and_mul_per_block_quant("
      "Tensor! out, "
      "Tensor input, "
      "Tensor! scales, "
      "int group_size, "
      "Tensor? scale_ub=None, "
      "bool is_scale_transposed=False) -> ()");
  ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
           &silu_and_mul_per_block_quant);

  // Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
  // GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
  // kernel launch.
  // Mutates k_cache in place and additionally returns a Tensor.
  ops.def(
      "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
      "Tensor q_in, Tensor kv, Tensor! k_cache, "
      "Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
      "int q_head_padded, float eps, int cache_block_size) -> Tensor");
  ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
           &fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);

  // Quantization ops
#ifndef USE_ROCM

  // Note about marlin kernel 'workspace' arguments:
  // Technically these should be mutable since they are modified by the kernel.
  // But since they are set back to zero once the kernel is finished we can
  // hand wave and say that they have no net effect.
  //
  // The reason to mark 'workspace' as immutable is so that they don't interfere
  // with using ScalarType arguments in the ops. If they are marked as mutable,
  // pytorch throws an assert in
  // 'torch._higher_order_ops._register_effectful_op' that prevents these
  // kernels from being torch.compile'd.
  // See the following document for more info on custom types and ops that use
  // custom types:
  // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA

  // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
  // Returns the list of schedule names (str[]) supported for the given
  // combination of types.
  ops.def(
      "machete_supported_schedules("
      " ScalarType a_type,"
      " int b_type,"
      " ScalarType? maybe_group_scales_type,"
      " ScalarType? maybe_group_zeros_type,"
      " ScalarType? maybe_channel_scales_type,"
      " ScalarType? maybe_token_scales_type,"
      " ScalarType? maybe_out_type"
      ") -> str[]");
  // The optional `schedule` argument takes a schedule name; the name
  // suggests it comes from machete_supported_schedules — confirm.
  ops.def(
      "machete_mm("
      " Tensor A,"
      " Tensor B,"
      " int b_type,"
      " ScalarType? out_type,"
      " Tensor? group_scales,"
      " Tensor? group_zeros,"
      " int? group_size,"
      " Tensor? channel_scales,"
      " Tensor? token_scales,"
      " str? schedule"
      ") -> Tensor");
  ops.def(
      "machete_prepack_B("
      " Tensor B,"
      " ScalarType a_type,"
      " int b_type,"
      " ScalarType? group_scales_type"
      ") -> Tensor");
  // conditionally compiled so impl registration is in source file

  // Marlin Optimized Quantized GEMM (supports GPTQ, AWQ, FP8, NVFP4, MXFP4).
  // `workspace` is intentionally declared immutable; see the note above.
  // size_m/size_n/size_k are SymInt so the op can be traced with symbolic
  // shapes.
  ops.def(
      "marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
      "Tensor? b_bias_or_none,Tensor b_scales, "
      "Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, "
      "Tensor? "
      "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_type_id, "
      "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
      "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
  // conditionally compiled so impl registration is in source file

  // gptq_marlin repack from GPTQ.
  ops.def(
      "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
      "SymInt size_k, SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
  // conditionally compiled so impl registrations are in source file

  // awq_marlin repack from AWQ.
  // Unlike gptq_marlin_repack, this op takes no `perm` argument.
  ops.def(
      "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
      "SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
  // conditionally compiled so impl registrations are in source file

  // preprocess W-int4A-fp8 weight for marlin kernel
  ops.def(
      "marlin_int4_fp8_preprocess(Tensor qweight, "
      "Tensor? qzeros_or_none, bool inplace) -> Tensor");
  // conditionally compiled so impl registrations are in source file

#endif

#ifndef USE_ROCM
  // Expert-specialization mxfp8 blockscaled grouped quantization (SM100+).
  // Writes into the mutable quant_output / scale_factor outputs.
  ops.def(
      "mxfp8_experts_quant("
      " Tensor input, Tensor problem_sizes, Tensor expert_offsets,"
      " Tensor blockscale_offsets, Tensor! quant_output, Tensor! scale_factor)"
      " -> ()");
  // conditionally compiled so impl registration is in source file

  // Expert-specialization mxfp8 blockscaled grouped GEMM (SM100+).
  // Writes into the mutable `out` tensor.
  ops.def(
      "cutlass_mxfp8_grouped_mm("
      " Tensor a, Tensor b, Tensor sfa, Tensor sfb, Tensor! out,"
      " Tensor problem_sizes, Tensor expert_offsets, Tensor blockscale_offsets)"
      " -> ()");
  // conditionally compiled so impl registration is in source file

#endif

#ifndef USE_ROCM
  // All-reduce + RMSNorm ops for MiniMax, both with CUDA impls registered
  // directly here.
  // NOTE(review): `workspace` is declared immutable; confirm the kernel does
  // not leave net modifications in it (cf. the marlin workspace note above).
  ops.def(
      "minimax_allreduce_rms("
      "Tensor input,"
      "Tensor norm_weight,"
      "Tensor workspace,"
      "int rank,"
      "int nranks,"
      "float eps) -> Tensor");
  ops.impl("minimax_allreduce_rms", torch::kCUDA, &minimax_allreduce_rms);
  // Variant with separate Q and K norm weights; returns a pair of tensors.
  ops.def(
      "minimax_allreduce_rms_qk("
      "Tensor qkv,"
      "Tensor norm_weight_q,"
      "Tensor norm_weight_k,"
      "Tensor workspace,"
      "int q_size,"
      "int kv_size,"
      "int rank,"
      "int nranks,"
      "float eps) -> (Tensor, Tensor)");
  ops.impl("minimax_allreduce_rms_qk", torch::kCUDA, &minimax_allreduce_rms_qk);

  // conditionally compiled so impl in source file
  // NOTE(review): both minimax ops above have their impl registered right
  // here, so the preceding line looks stale — confirm and remove.
#endif
}
|
|
|
|
// Registers the KV-cache manipulation ops in the
// TORCH_EXTENSION_NAME##_cache_ops library. Every op here returns nothing
// (`-> ()`) and communicates results through its `Tensor!` (mutated in place)
// arguments.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  // Cache ops
  // Swap in (out) the cache blocks from src to dst.
  cache_ops.def(
      "swap_blocks(Tensor src, Tensor! dst,"
      " int block_size_in_bytes, Tensor block_mapping) -> ()");
  cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);

  // Batch swap: submit all block copies in a single driver call.
  // Registered for the CPU dispatch key.
  // NOTE(review): the argument names suggest src_ptrs/dst_ptrs/sizes are
  // host-side tensors of raw addresses and byte counts — confirm.
  cache_ops.def(
      "swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
      " Tensor sizes,"
      " bool is_src_access_order_any=False) -> ()");
  cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);

  // Reshape the key and value tensors and cache them.
  cache_ops.def(
      "reshape_and_cache(Tensor key, Tensor value,"
      " Tensor! key_cache, Tensor! value_cache,"
      " Tensor slot_mapping,"
      " str kv_cache_dtype,"
      " Tensor k_scale, Tensor v_scale) -> ()");
  cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);

  // Reshape the key and value tensors and cache them.
  // Same schema as reshape_and_cache; the "_flash" variant presumably targets
  // a different cache layout — confirm against the kernel.
  cache_ops.def(
      "reshape_and_cache_flash(Tensor key, Tensor value,"
      " Tensor! key_cache,"
      " Tensor! value_cache,"
      " Tensor slot_mapping,"
      " str kv_cache_dtype,"
      " Tensor k_scale, Tensor v_scale) -> ()");
  cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
                 &reshape_and_cache_flash);

  // Concat kv_c and k_pe and cache them.
  cache_ops.def(
      "concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
      " Tensor! kv_cache,"
      " Tensor slot_mapping,"
      " str kv_cache_dtype,"
      " Tensor scale) -> ()");
  cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);

  // Rotate Q and K, then write to kv cache for MLA
  // q_pe and k_pe are rotated in place (both `Tensor!`), and kv_cache is
  // written.
  cache_ops.def(
      "concat_and_cache_mla_rope_fused("
      " Tensor positions,"
      " Tensor! q_pe,"
      " Tensor! k_pe,"
      " Tensor kv_c,"
      " Tensor cos_sin_cache,"
      " bool is_neox,"
      " Tensor slot_mapping,"
      " Tensor! kv_cache,"
      " str kv_cache_dtype,"
      " Tensor kv_cache_scale) -> ()");
  cache_ops.impl("concat_and_cache_mla_rope_fused", torch::kCUDA,
                 &concat_and_cache_mla_rope_fused);

  // Convert the key and value cache to fp8 data type.
  cache_ops.def(
      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
      "str kv_cache_dtype) -> ()");
  cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);

  // Gather cache blocks from src_cache to dst, dequantizing from
  // src_cache's dtype to dst's dtype if necessary.
  cache_ops.def(
      "gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
      " Tensor block_table, Tensor cu_seq_lens, "
      " Tensor token_to_seq, "
      " int num_tokens, "
      " str kv_cache_dtype, "
      " Tensor scale, Tensor? seq_starts) -> ()");
  cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
                 &gather_and_maybe_dequant_cache);

  // Gathers from src_cache into dst using block_table / cu_seq_lens, without
  // a dtype conversion argument (cf. gather_and_maybe_dequant_cache).
  cache_ops.def(
      "cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
      "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
  cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);

  // Gathers from src_cache into dst; per the name, also upconverts from an
  // fp8 KV cache — confirm the target dtype against the kernel.
  cache_ops.def(
      "cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
      "Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
      "batch_size) -> ()");
  cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
                 &cp_gather_and_upconvert_fp8_kv_cache);

  // Quantizes k with the given quant_block_size and writes it into kv_cache
  // at slot_mapping.
  cache_ops.def(
      "indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
      "slot_mapping, "
      "int quant_block_size, str kv_cache_dtype) -> ()");
  cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
                 &indexer_k_quant_and_cache);

  // Concatenates ql_nope and q_pe into the preallocated q_out.
  cache_ops.def(
      "concat_mla_q(Tensor ql_nope, Tensor q_pe, Tensor! q_out) -> ()");
  cache_ops.impl("concat_mla_q", torch::kCUDA, &concat_mla_q);

  // Gathers from kv_cache into the two mutable outputs dst_k and dst_scale.
  cache_ops.def(
      "cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
      "dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
  cache_ops.impl("cp_gather_indexer_k_quant_cache", torch::kCUDA,
                 &cp_gather_indexer_k_quant_cache);
}
|
|
|
|
// Registers device-query helpers in the TORCH_EXTENSION_NAME##_cuda_utils
// library. Both ops take and return only ints, and their impls are registered
// without a dispatch key, so they are not tied to a tensor backend.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
  // Cuda utils

  // Gets the specified device attribute.
  cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
  cuda_utils.impl("get_device_attribute", &get_device_attribute);

  // Gets the maximum shared memory per block device attribute.
  cuda_utils.def(
      "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
  cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
                  &get_max_shared_memory_per_block_device_attribute);
}
|
|
|
|
// Registers the custom all-reduce ops in the TORCH_EXTENSION_NAME##_custom_ar
// library.
//
// Two registration forms are used:
//   * `custom_ar.def(schema)` + `custom_ar.impl(name, key, &fn)`: explicit
//     schema, with a kernel bound to the given dispatch key.
//   * `custom_ar.def("name", &fn)`: the schema is inferred from the C++
//     signature of `fn`.
// The Quick Reduce ops are only registered in ROCm builds (USE_ROCM).
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
  // Custom all-reduce kernels
  custom_ar.def(
      "init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
      "int rank, bool fully_connected) -> int");
  custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
  // Writes the reduced result into the mutable `out` tensor.
  custom_ar.def(
      "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
      "int reg_buffer_sz_bytes) -> ()");
  custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);

  custom_ar.def("dispose", &dispose);
  custom_ar.def("meta_size", &meta_size);

  // Buffer registration helpers (schemas inferred from the C++ signatures).
  custom_ar.def("register_buffer", &register_buffer);
  custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  custom_ar.def("register_graph_buffers", &register_graph_buffers);

  custom_ar.def("allocate_shared_buffer_and_handle",
                &allocate_shared_buffer_and_handle);
  // Defined with an explicit schema plus a default kernel, and additionally
  // bound for the CPU dispatch key.
  custom_ar.def("open_mem_handle(Tensor mem_handle) -> int", &open_mem_handle);
  custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle);

  custom_ar.def("free_shared_buffer", &free_shared_buffer);
#ifdef USE_ROCM
  // Quick Reduce all-reduce kernels
  // NOTE(review): unlike `all_reduce` above, `out` is not annotated `Tensor!`
  // here. If qr_all_reduce writes into `out`, the schema should likely mark it
  // mutable so torch.compile/functionalization sees the write — confirm
  // against the kernel before changing the schema.
  custom_ar.def(
      "qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
      "cast_bf2half) -> ()");
  custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);

  custom_ar.def("init_custom_qr", &init_custom_qr);
  custom_ar.def("qr_destroy", &qr_destroy);

  custom_ar.def("qr_get_handle", &qr_get_handle);

  // `Tensor[](b!)` declares the handles list as mutated in place.
  custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
  custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);

  // Max input size in bytes
  custom_ar.def("qr_max_size", &qr_max_size);
#endif
}
|
|
|
|
// Extension entry point. REGISTER_EXTENSION comes from core/registration.h;
// NOTE(review): it presumably defines the Python module initializer for
// TORCH_EXTENSION_NAME so the libraries registered above are loaded on
// import — confirm against the macro definition.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)