[10/n] Migrate cuda_view and silu_and_mul_per_block_quant kernels to torch stale ABI. (#44334)

This commit is contained in:
Chris Leonard
2026-06-04 23:14:43 -04:00
committed by GitHub
parent 063ce98fb7
commit 56aff0dd15
25 changed files with 186 additions and 163 deletions
+7 -11
View File
@@ -307,10 +307,7 @@ endif()
#
set(VLLM_EXT_SRC
"csrc/cuda_view.cu"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
@@ -346,9 +343,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
FetchContent_MakeAvailable(cutlass)
list(APPEND VLLM_EXT_SRC
"csrc/cutlass_extensions/common.cpp")
set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}")
@@ -627,6 +621,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
#
set(VLLM_STABLE_EXT_SRC
"csrc/libtorch_stable/torch_bindings.cpp"
"csrc/libtorch_stable/cuda_view.cu"
"csrc/libtorch_stable/activation_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/common.cu"
@@ -639,6 +634,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/layernorm_kernels.cu"
"csrc/libtorch_stable/layernorm_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/libtorch_stable/attention/merge_attn_states.cu"
"csrc/libtorch_stable/sampler.cu"
"csrc/libtorch_stable/topk.cu"
@@ -653,8 +649,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
"csrc/cuda_utils_kernels.cu"
"csrc/cutlass_extensions/common.cpp"
"csrc/libtorch_stable/cuda_utils_kernels.cu"
"csrc/libtorch_stable/cutlass_extensions/common.cpp"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu"
@@ -1076,11 +1072,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
WITH_SOABI)
# Set TORCH_TARGET_VERSION for stable ABI compatibility.
# This ensures we only use C-shim APIs available in PyTorch 2.10.
# This ensures we only use C-shim APIs available in PyTorch 2.11.
# _C_stable_libtorch is abi compatible with PyTorch >= TORCH_TARGET_VERSION
# which is currently set to 2.10.
# which is currently set to 2.11.
target_compile_definitions(_C_stable_libtorch PRIVATE
TORCH_TARGET_VERSION=0x020A000000000000ULL)
TORCH_TARGET_VERSION=0x020B000000000000ULL)
# Needed to use cuda/hip APIs from C-shim
if(VLLM_GPU_LANG STREQUAL "CUDA")
-59
View File
@@ -1,59 +0,0 @@
#include <torch/all.h>
#include <torch/cuda.h>
#include <cuda_runtime.h>
// This function assumes that `cpu_tensor` is a CPU tensor,
// and that UVA (Unified Virtual Addressing) is enabled.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
// handle empty tensor
if (cpu_tensor.numel() == 0) {
return torch::empty(cpu_tensor.sizes(),
cpu_tensor.options().device(torch::kCUDA));
}
if (cpu_tensor.is_pinned()) {
// If CPU tensor is pinned, directly get the device pointer.
void* host_ptr = const_cast<void*>(cpu_tensor.data_ptr());
void* device_ptr = nullptr;
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
TORCH_CHECK(err == cudaSuccess,
"cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
return torch::from_blob(
device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(),
[base = cpu_tensor](void*) {}, // keep cpu tensor alive
cpu_tensor.options().device(torch::kCUDA));
}
// If CPU tensor is not pinned, allocate a new pinned memory buffer.
torch::Tensor contiguous_cpu = cpu_tensor.contiguous();
size_t nbytes = contiguous_cpu.nbytes();
void* host_ptr = nullptr;
cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped);
if (err != cudaSuccess) {
AT_ERROR("cudaHostAlloc failed: ", cudaGetErrorString(err));
}
err = cudaMemcpy(host_ptr, contiguous_cpu.data_ptr(), nbytes,
cudaMemcpyDefault);
if (err != cudaSuccess) {
cudaFreeHost(host_ptr);
AT_ERROR("cudaMemcpy failed: ", cudaGetErrorString(err));
}
void* device_ptr = nullptr;
err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
if (err != cudaSuccess) {
cudaFreeHost(host_ptr);
AT_ERROR("cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
}
auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); };
return torch::from_blob(device_ptr, contiguous_cpu.sizes(),
contiguous_cpu.strides(), deleter,
contiguous_cpu.options().device(torch::kCUDA));
}
+76
View File
@@ -0,0 +1,76 @@
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/csrc/stable/device.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/headeronly/version.h>
#include <cuda_runtime.h>
#include <array>
#include <optional>
// This function assumes that `cpu_tensor` is a CPU tensor,
// and that UVA (Unified Virtual Addressing) is enabled.
torch::stable::Tensor get_cuda_view_from_cpu_tensor(
torch::stable::Tensor& cpu_tensor) {
STD_TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
const auto dtype = cpu_tensor.scalar_type();
const auto layout = cpu_tensor.layout();
const torch::stable::Device cuda_dev(torch::headeronly::DeviceType::CUDA);
// handle empty tensor
if (cpu_tensor.numel() == 0) {
return torch::stable::empty(cpu_tensor.sizes(), dtype, layout, cuda_dev);
}
std::array<StableIValue, 2> is_pinned_stack{
torch::stable::detail::from(cpu_tensor),
torch::stable::detail::from(std::nullopt)};
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::is_pinned", "", is_pinned_stack.data(), TORCH_ABI_VERSION));
if (torch::stable::detail::to<bool>(is_pinned_stack[0])) {
// If CPU tensor is pinned, directly get the device pointer.
void* host_ptr = const_cast<void*>(cpu_tensor.mutable_data_ptr());
void* device_ptr = nullptr;
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
STD_TORCH_CHECK(err == cudaSuccess, "cudaHostGetDevicePointer failed: ",
cudaGetErrorString(err));
return torch::stable::from_blob(
device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(), cuda_dev, dtype,
[base = cpu_tensor](void*) {}); // keep cpu tensor alive
}
// If CPU tensor is not pinned, allocate a new pinned memory buffer.
torch::stable::Tensor contiguous_cpu = torch::stable::contiguous(cpu_tensor);
size_t nbytes = contiguous_cpu.numel() * contiguous_cpu.element_size();
void* host_ptr = nullptr;
cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped);
if (err != cudaSuccess) {
STD_TORCH_CHECK(false, "cudaHostAlloc failed: ", cudaGetErrorString(err));
}
err = cudaMemcpy(host_ptr, contiguous_cpu.const_data_ptr(), nbytes,
cudaMemcpyDefault);
if (err != cudaSuccess) {
cudaFreeHost(host_ptr);
STD_TORCH_CHECK(false, "cudaMemcpy failed: ", cudaGetErrorString(err));
}
void* device_ptr = nullptr;
err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
if (err != cudaSuccess) {
cudaFreeHost(host_ptr);
STD_TORCH_CHECK(
false, "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
}
auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); };
return torch::stable::from_blob(device_ptr, contiguous_cpu.sizes(),
contiguous_cpu.strides(), cuda_dev,
contiguous_cpu.scalar_type(), deleter);
}
@@ -1,4 +1,4 @@
#include "cutlass_extensions/common.hpp"
#include "common.hpp"
int32_t get_sm_version_num() {
int32_t major_capability, minor_capability;
+11
View File
@@ -164,6 +164,10 @@ torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel,
#endif
// CPU tensor -> CUDA UVA view (shared CUDA/ROCm)
torch::stable::Tensor get_cuda_view_from_cpu_tensor(
torch::stable::Tensor& cpu_tensor);
// Attention kernels (shared CUDA/ROCm)
void merge_attn_states(
torch::stable::Tensor& output,
@@ -215,6 +219,13 @@ void rms_norm_per_block_quant(torch::stable::Tensor& out,
std::optional<torch::stable::Tensor> residual,
int64_t group_size, bool is_scale_transposed);
void silu_and_mul_per_block_quant(torch::stable::Tensor& out,
torch::stable::Tensor const& input,
torch::stable::Tensor& scales,
int64_t group_size,
std::optional<torch::stable::Tensor> scale_ub,
bool is_scale_transposed);
// Positional encoding kernels (shared CUDA/ROCm)
void rotary_embedding(torch::stable::Tensor& positions,
torch::stable::Tensor& query,
@@ -18,7 +18,7 @@
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/torch_utils.hpp"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "get_group_starts.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
@@ -21,7 +21,7 @@
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/mixed_dtype_utils.hpp"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include <cuda_runtime.h>
@@ -12,7 +12,7 @@
#include <cutlass/arch/arch.h>
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
@@ -20,7 +20,7 @@
#include <cutlass/arch/arch.h>
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
@@ -18,7 +18,7 @@
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
@@ -18,7 +18,7 @@
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
@@ -18,7 +18,7 @@
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "cutlass/cutlass.h"
@@ -18,7 +18,7 @@
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "cutlass/cutlass.h"
@@ -1,11 +1,10 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../torch_utils.h"
#include "../../dispatch_utils.h"
#include "libtorch_stable/quantization/fused_kernels/quant_conversions.cuh"
#include "quant_conversions.cuh"
namespace vllm {
@@ -105,64 +104,70 @@ __global__ void silu_and_mul_per_block_quant_kernel(
} // namespace vllm
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
std::optional<torch::Tensor> scale_ub,
void silu_and_mul_per_block_quant(torch::stable::Tensor& out,
torch::stable::Tensor const& input,
torch::stable::Tensor& scales,
int64_t group_size,
std::optional<torch::stable::Tensor> scale_ub,
bool is_scale_transposed) {
static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
static torch::headeronly::ScalarType kFp8Type =
is_fp8_ocp() ? torch::headeronly::ScalarType::Float8_e4m3fn
: torch::headeronly::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
TORCH_CHECK(
input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16,
STD_TORCH_CHECK(out.scalar_type() == kFp8Type ||
out.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
STD_TORCH_CHECK(
input.scalar_type() == torch::headeronly::ScalarType::Half ||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"Input must be FP16 or BF16");
TORCH_CHECK(scales.dtype() == torch::kFloat32, "Scales must be FP32");
TORCH_CHECK(group_size == 128 || group_size == 64,
"Unsupported group size: ", group_size);
STD_TORCH_CHECK(scales.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(group_size == 128 || group_size == 64,
"Unsupported group size: ", group_size);
if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
STD_TORCH_CHECK(out.scalar_type() == kFp8Type);
}
int32_t hidden_size = out.size(-1);
auto num_tokens = input.size(0);
int32_t num_groups = hidden_size / group_size;
TORCH_CHECK(input.size(-1) == hidden_size * 2,
"input last dim must be 2x output hidden_size");
TORCH_CHECK(hidden_size % group_size == 0,
"hidden_size must be divisible by group_size");
STD_TORCH_CHECK(input.size(-1) == hidden_size * 2,
"input last dim must be 2x output hidden_size");
STD_TORCH_CHECK(hidden_size % group_size == 0,
"hidden_size must be divisible by group_size");
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
dim3 grid(num_tokens, num_groups);
dim3 block(group_size);
VLLM_DISPATCH_FLOATING_TYPES(
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "silu_and_mul_per_block_quant", [&] {
using scalar_in_t = scalar_t;
VLLM_DISPATCH_QUANT_TYPES(
VLLM_STABLE_DISPATCH_QUANT_TYPES(
out.scalar_type(), "silu_and_mul_per_block_quant", [&] {
using scalar_out_t = scalar_t;
VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
vllm::silu_and_mul_per_block_quant_kernel<
scalar_in_t, scalar_out_t, transpose_scale, gs>
<<<grid, block, 0, stream>>>(
out.data_ptr<scalar_out_t>(),
scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>()
: nullptr,
hidden_size);
});
VLLM_STABLE_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
VLLM_STABLE_DISPATCH_BOOL(
is_scale_transposed, transpose_scale, [&] {
vllm::silu_and_mul_per_block_quant_kernel<
scalar_in_t, scalar_out_t, transpose_scale, gs>
<<<grid, block, 0, stream>>>(
out.mutable_data_ptr<scalar_out_t>(),
scales.mutable_data_ptr<float>(),
input.const_data_ptr<scalar_in_t>(),
scale_ub.has_value()
? scale_ub->const_data_ptr<float>()
: nullptr,
hidden_size);
});
});
});
});
}
}
@@ -20,7 +20,7 @@
#include "cutlass/util/packed_stride.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
// clang-format on
namespace vllm::c3x {
@@ -15,7 +15,7 @@
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
// clang-format on
/*
@@ -1,7 +1,7 @@
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cuda_utils.h"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
void dispatch_scaled_mm(torch::stable::Tensor& c,
@@ -8,7 +8,7 @@
#include <torch/csrc/stable/ops.h>
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "get_group_starts.cuh"
using namespace cute;
@@ -23,7 +23,7 @@
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
// clang-format on
using namespace cute;
@@ -4,7 +4,7 @@
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
void cutlass_scaled_mm_sm75(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
+33
View File
@@ -1,4 +1,5 @@
#include "ops.h"
#include "cuda_utils.h"
#include "core/registration.h"
#include <torch/csrc/stable/library.h>
@@ -27,6 +28,8 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()");
ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
#ifndef USE_ROCM
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
#endif
@@ -321,6 +324,16 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"Tensor? scale_ub, Tensor!? residual, int group_size, "
"bool is_scale_transposed) -> ()");
// Fused SiLU+Mul + per-block quantization
ops.def(
"silu_and_mul_per_block_quant("
"Tensor! out, "
"Tensor input, "
"Tensor! scales, "
"int group_size, "
"Tensor? scale_ub=None, "
"bool is_scale_transposed=False) -> ()");
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
@@ -599,6 +612,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("rms_norm_dynamic_per_token_quant",
TORCH_BOX(&rms_norm_dynamic_per_token_quant));
ops.impl("rms_norm_per_block_quant", TORCH_BOX(&rms_norm_per_block_quant));
ops.impl("silu_and_mul_per_block_quant",
TORCH_BOX(&silu_and_mul_per_block_quant));
// Positional encoding kernels (shared CUDA/ROCm)
ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding));
@@ -661,6 +676,11 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("paged_attention_v2", TORCH_BOX(&paged_attention_v2));
}
STABLE_TORCH_LIBRARY_IMPL(_C, CPU, ops) {
ops.impl("get_cuda_view_from_cpu_tensor",
TORCH_BOX(&get_cuda_view_from_cpu_tensor));
}
// These capability-check functions take only primitive args (no tensors), so
// there is no device to dispatch on. CompositeExplicitAutograd makes them
// available for all backends. This is the stable ABI equivalent of calling
@@ -681,6 +701,19 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
ops.impl("ggml_moe_get_block_size", TORCH_BOX(&ggml_moe_get_block_size));
}
STABLE_TORCH_LIBRARY_FRAGMENT(_C_cuda_utils, cuda_utils) {
cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
cuda_utils.def(
"get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
}
STABLE_TORCH_LIBRARY_IMPL(_C_cuda_utils, CompositeExplicitAutograd,
cuda_utils) {
cuda_utils.impl("get_device_attribute", TORCH_BOX(&get_device_attribute));
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
TORCH_BOX(&get_max_shared_memory_per_block_device_attribute));
}
// Cache ops
STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) {
// Swap in (out) the cache blocks from src to dst.
-8
View File
@@ -40,12 +40,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
std::optional<torch::Tensor> scale_ub,
bool is_scale_transposed);
// rotary_embedding also exist in csrc/libtorch_stable/ops.h (torch::stable
// ABI for CUDA). It remains here because the CPU build still uses these
// torch::Tensor declarations.
@@ -84,8 +78,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table, double scale);
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale,
std::optional<torch::Tensor> const& azp);
-31
View File
@@ -2,7 +2,6 @@
// cache.h, which is no longer included here after cache ops moved to
// _C_stable_libtorch).
#include <torch/all.h>
#include "cuda_utils.h"
#include "ops.h"
#include "core/registration.h"
#include <torch/library.h>
@@ -32,27 +31,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
&get_cuda_view_from_cpu_tensor);
// Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
ops.def(
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
// Fused SiLU+Mul + per-block quantization
ops.def(
"silu_and_mul_per_block_quant("
"Tensor! out, "
"Tensor input, "
"Tensor! scales, "
"int group_size, "
"Tensor? scale_ub=None, "
"bool is_scale_transposed=False) -> ()");
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
&silu_and_mul_per_block_quant);
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
// kernel launch. Registered in _C_stable_libtorch (incl. the FlashInfer V4
@@ -178,18 +161,4 @@ TORCH_LIBRARY_FRAGMENT(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
}
#endif
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
// Cuda utils
// Gets the specified device attribute.
cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
cuda_utils.impl("get_device_attribute", &get_device_attribute);
// Gets the maximum shared memory per block device attribute.
cuda_utils.def(
"get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)