mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[10/n] Migrate cuda_view and silu_and_mul_per_block_quant kernels to torch stale ABI. (#44334)
This commit is contained in:
+7
-11
@@ -307,10 +307,7 @@ endif()
|
||||
#
|
||||
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cuda_view.cu"
|
||||
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
|
||||
"csrc/quantization/activation_kernels.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/torch_bindings.cpp")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
@@ -346,9 +343,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
FetchContent_MakeAvailable(cutlass)
|
||||
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/cutlass_extensions/common.cpp")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_EXT_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
@@ -627,6 +621,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
#
|
||||
set(VLLM_STABLE_EXT_SRC
|
||||
"csrc/libtorch_stable/torch_bindings.cpp"
|
||||
"csrc/libtorch_stable/cuda_view.cu"
|
||||
"csrc/libtorch_stable/activation_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/fp8/common.cu"
|
||||
@@ -639,6 +634,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
"csrc/libtorch_stable/layernorm_kernels.cu"
|
||||
"csrc/libtorch_stable/layernorm_quant_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
|
||||
"csrc/libtorch_stable/attention/merge_attn_states.cu"
|
||||
"csrc/libtorch_stable/sampler.cu"
|
||||
"csrc/libtorch_stable/topk.cu"
|
||||
@@ -653,8 +649,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/libtorch_stable/cuda_utils_kernels.cu"
|
||||
"csrc/libtorch_stable/cutlass_extensions/common.cpp"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||
@@ -1076,11 +1072,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
WITH_SOABI)
|
||||
|
||||
# Set TORCH_TARGET_VERSION for stable ABI compatibility.
|
||||
# This ensures we only use C-shim APIs available in PyTorch 2.10.
|
||||
# This ensures we only use C-shim APIs available in PyTorch 2.11.
|
||||
# _C_stable_libtorch is abi compatible with PyTorch >= TORCH_TARGET_VERSION
|
||||
# which is currently set to 2.10.
|
||||
# which is currently set to 2.11.
|
||||
target_compile_definitions(_C_stable_libtorch PRIVATE
|
||||
TORCH_TARGET_VERSION=0x020A000000000000ULL)
|
||||
TORCH_TARGET_VERSION=0x020B000000000000ULL)
|
||||
|
||||
# Needed to use cuda/hip APIs from C-shim
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
#include <torch/all.h>
|
||||
#include <torch/cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
// This function assumes that `cpu_tensor` is a CPU tensor,
|
||||
// and that UVA (Unified Virtual Addressing) is enabled.
|
||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
|
||||
TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
|
||||
|
||||
// handle empty tensor
|
||||
if (cpu_tensor.numel() == 0) {
|
||||
return torch::empty(cpu_tensor.sizes(),
|
||||
cpu_tensor.options().device(torch::kCUDA));
|
||||
}
|
||||
|
||||
if (cpu_tensor.is_pinned()) {
|
||||
// If CPU tensor is pinned, directly get the device pointer.
|
||||
void* host_ptr = const_cast<void*>(cpu_tensor.data_ptr());
|
||||
void* device_ptr = nullptr;
|
||||
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
|
||||
TORCH_CHECK(err == cudaSuccess,
|
||||
"cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
|
||||
|
||||
return torch::from_blob(
|
||||
device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(),
|
||||
[base = cpu_tensor](void*) {}, // keep cpu tensor alive
|
||||
cpu_tensor.options().device(torch::kCUDA));
|
||||
}
|
||||
|
||||
// If CPU tensor is not pinned, allocate a new pinned memory buffer.
|
||||
torch::Tensor contiguous_cpu = cpu_tensor.contiguous();
|
||||
size_t nbytes = contiguous_cpu.nbytes();
|
||||
|
||||
void* host_ptr = nullptr;
|
||||
cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped);
|
||||
if (err != cudaSuccess) {
|
||||
AT_ERROR("cudaHostAlloc failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
err = cudaMemcpy(host_ptr, contiguous_cpu.data_ptr(), nbytes,
|
||||
cudaMemcpyDefault);
|
||||
if (err != cudaSuccess) {
|
||||
cudaFreeHost(host_ptr);
|
||||
AT_ERROR("cudaMemcpy failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
void* device_ptr = nullptr;
|
||||
err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
|
||||
if (err != cudaSuccess) {
|
||||
cudaFreeHost(host_ptr);
|
||||
AT_ERROR("cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); };
|
||||
|
||||
return torch::from_blob(device_ptr, contiguous_cpu.sizes(),
|
||||
contiguous_cpu.strides(), deleter,
|
||||
contiguous_cpu.options().device(torch::kCUDA));
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/csrc/stable/accelerator.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/csrc/stable/device.h>
|
||||
#include <torch/csrc/stable/c/shim.h>
|
||||
#include <torch/headeronly/version.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <array>
|
||||
#include <optional>
|
||||
|
||||
// This function assumes that `cpu_tensor` is a CPU tensor,
|
||||
// and that UVA (Unified Virtual Addressing) is enabled.
|
||||
torch::stable::Tensor get_cuda_view_from_cpu_tensor(
|
||||
torch::stable::Tensor& cpu_tensor) {
|
||||
STD_TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
|
||||
|
||||
const auto dtype = cpu_tensor.scalar_type();
|
||||
const auto layout = cpu_tensor.layout();
|
||||
const torch::stable::Device cuda_dev(torch::headeronly::DeviceType::CUDA);
|
||||
|
||||
// handle empty tensor
|
||||
if (cpu_tensor.numel() == 0) {
|
||||
return torch::stable::empty(cpu_tensor.sizes(), dtype, layout, cuda_dev);
|
||||
}
|
||||
|
||||
std::array<StableIValue, 2> is_pinned_stack{
|
||||
torch::stable::detail::from(cpu_tensor),
|
||||
torch::stable::detail::from(std::nullopt)};
|
||||
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
|
||||
"aten::is_pinned", "", is_pinned_stack.data(), TORCH_ABI_VERSION));
|
||||
if (torch::stable::detail::to<bool>(is_pinned_stack[0])) {
|
||||
// If CPU tensor is pinned, directly get the device pointer.
|
||||
void* host_ptr = const_cast<void*>(cpu_tensor.mutable_data_ptr());
|
||||
void* device_ptr = nullptr;
|
||||
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
|
||||
STD_TORCH_CHECK(err == cudaSuccess, "cudaHostGetDevicePointer failed: ",
|
||||
cudaGetErrorString(err));
|
||||
|
||||
return torch::stable::from_blob(
|
||||
device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(), cuda_dev, dtype,
|
||||
[base = cpu_tensor](void*) {}); // keep cpu tensor alive
|
||||
}
|
||||
|
||||
// If CPU tensor is not pinned, allocate a new pinned memory buffer.
|
||||
torch::stable::Tensor contiguous_cpu = torch::stable::contiguous(cpu_tensor);
|
||||
size_t nbytes = contiguous_cpu.numel() * contiguous_cpu.element_size();
|
||||
|
||||
void* host_ptr = nullptr;
|
||||
cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped);
|
||||
if (err != cudaSuccess) {
|
||||
STD_TORCH_CHECK(false, "cudaHostAlloc failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
err = cudaMemcpy(host_ptr, contiguous_cpu.const_data_ptr(), nbytes,
|
||||
cudaMemcpyDefault);
|
||||
if (err != cudaSuccess) {
|
||||
cudaFreeHost(host_ptr);
|
||||
STD_TORCH_CHECK(false, "cudaMemcpy failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
void* device_ptr = nullptr;
|
||||
err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
|
||||
if (err != cudaSuccess) {
|
||||
cudaFreeHost(host_ptr);
|
||||
STD_TORCH_CHECK(
|
||||
false, "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); };
|
||||
|
||||
return torch::stable::from_blob(device_ptr, contiguous_cpu.sizes(),
|
||||
contiguous_cpu.strides(), cuda_dev,
|
||||
contiguous_cpu.scalar_type(), deleter);
|
||||
}
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
int32_t get_sm_version_num() {
|
||||
int32_t major_capability, minor_capability;
|
||||
@@ -164,6 +164,10 @@ torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel,
|
||||
|
||||
#endif
|
||||
|
||||
// CPU tensor -> CUDA UVA view (shared CUDA/ROCm)
|
||||
torch::stable::Tensor get_cuda_view_from_cpu_tensor(
|
||||
torch::stable::Tensor& cpu_tensor);
|
||||
|
||||
// Attention kernels (shared CUDA/ROCm)
|
||||
void merge_attn_states(
|
||||
torch::stable::Tensor& output,
|
||||
@@ -215,6 +219,13 @@ void rms_norm_per_block_quant(torch::stable::Tensor& out,
|
||||
std::optional<torch::stable::Tensor> residual,
|
||||
int64_t group_size, bool is_scale_transposed);
|
||||
|
||||
void silu_and_mul_per_block_quant(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor& scales,
|
||||
int64_t group_size,
|
||||
std::optional<torch::stable::Tensor> scale_ub,
|
||||
bool is_scale_transposed);
|
||||
|
||||
// Positional encoding kernels (shared CUDA/ROCm)
|
||||
void rotary_embedding(torch::stable::Tensor& positions,
|
||||
torch::stable::Tensor& query,
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
#include "get_group_starts.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
|
||||
#include <cutlass/arch/arch.h>
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
|
||||
#include <cutlass/arch/arch.h>
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
#include "nvfp4_utils.cuh"
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
|
||||
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
|
||||
+34
-29
@@ -1,11 +1,10 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "../../torch_utils.h"
|
||||
|
||||
#include "../../dispatch_utils.h"
|
||||
#include "libtorch_stable/quantization/fused_kernels/quant_conversions.cuh"
|
||||
#include "quant_conversions.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
@@ -105,60 +104,66 @@ __global__ void silu_and_mul_per_block_quant_kernel(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void silu_and_mul_per_block_quant(torch::Tensor& out,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor& scales, int64_t group_size,
|
||||
std::optional<torch::Tensor> scale_ub,
|
||||
void silu_and_mul_per_block_quant(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor& scales,
|
||||
int64_t group_size,
|
||||
std::optional<torch::stable::Tensor> scale_ub,
|
||||
bool is_scale_transposed) {
|
||||
static c10::ScalarType kFp8Type = is_fp8_ocp()
|
||||
? c10::ScalarType::Float8_e4m3fn
|
||||
: c10::ScalarType::Float8_e4m3fnuz;
|
||||
static torch::headeronly::ScalarType kFp8Type =
|
||||
is_fp8_ocp() ? torch::headeronly::ScalarType::Float8_e4m3fn
|
||||
: torch::headeronly::ScalarType::Float8_e4m3fnuz;
|
||||
|
||||
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
|
||||
TORCH_CHECK(
|
||||
input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16,
|
||||
STD_TORCH_CHECK(out.scalar_type() == kFp8Type ||
|
||||
out.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
|
||||
STD_TORCH_CHECK(
|
||||
input.scalar_type() == torch::headeronly::ScalarType::Half ||
|
||||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
|
||||
"Input must be FP16 or BF16");
|
||||
TORCH_CHECK(scales.dtype() == torch::kFloat32, "Scales must be FP32");
|
||||
TORCH_CHECK(group_size == 128 || group_size == 64,
|
||||
STD_TORCH_CHECK(scales.scalar_type() == torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(group_size == 128 || group_size == 64,
|
||||
"Unsupported group size: ", group_size);
|
||||
|
||||
if (scale_ub.has_value()) {
|
||||
TORCH_CHECK(out.dtype() == kFp8Type);
|
||||
STD_TORCH_CHECK(out.scalar_type() == kFp8Type);
|
||||
}
|
||||
|
||||
int32_t hidden_size = out.size(-1);
|
||||
auto num_tokens = input.size(0);
|
||||
int32_t num_groups = hidden_size / group_size;
|
||||
|
||||
TORCH_CHECK(input.size(-1) == hidden_size * 2,
|
||||
STD_TORCH_CHECK(input.size(-1) == hidden_size * 2,
|
||||
"input last dim must be 2x output hidden_size");
|
||||
TORCH_CHECK(hidden_size % group_size == 0,
|
||||
STD_TORCH_CHECK(hidden_size % group_size == 0,
|
||||
"hidden_size must be divisible by group_size");
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
|
||||
|
||||
dim3 grid(num_tokens, num_groups);
|
||||
dim3 block(group_size);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "silu_and_mul_per_block_quant", [&] {
|
||||
using scalar_in_t = scalar_t;
|
||||
|
||||
VLLM_DISPATCH_QUANT_TYPES(
|
||||
VLLM_STABLE_DISPATCH_QUANT_TYPES(
|
||||
out.scalar_type(), "silu_and_mul_per_block_quant", [&] {
|
||||
using scalar_out_t = scalar_t;
|
||||
|
||||
VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
|
||||
VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
|
||||
VLLM_STABLE_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
|
||||
VLLM_STABLE_DISPATCH_BOOL(
|
||||
is_scale_transposed, transpose_scale, [&] {
|
||||
vllm::silu_and_mul_per_block_quant_kernel<
|
||||
scalar_in_t, scalar_out_t, transpose_scale, gs>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_out_t>(),
|
||||
scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_in_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>()
|
||||
out.mutable_data_ptr<scalar_out_t>(),
|
||||
scales.mutable_data_ptr<float>(),
|
||||
input.const_data_ptr<scalar_in_t>(),
|
||||
scale_ub.has_value()
|
||||
? scale_ub->const_data_ptr<float>()
|
||||
: nullptr,
|
||||
hidden_size);
|
||||
});
|
||||
@@ -20,7 +20,7 @@
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
// clang-format on
|
||||
|
||||
namespace vllm::c3x {
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
// clang-format on
|
||||
|
||||
/*
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include "cuda_utils.h"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
|
||||
void dispatch_scaled_mm(torch::stable::Tensor& c,
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
#include "get_group_starts.cuh"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
void cutlass_scaled_mm_sm75(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "ops.h"
|
||||
#include "cuda_utils.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <torch/csrc/stable/library.h>
|
||||
@@ -27,6 +28,8 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
|
||||
"()");
|
||||
|
||||
ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||
#endif
|
||||
@@ -321,6 +324,16 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
"Tensor? scale_ub, Tensor!? residual, int group_size, "
|
||||
"bool is_scale_transposed) -> ()");
|
||||
|
||||
// Fused SiLU+Mul + per-block quantization
|
||||
ops.def(
|
||||
"silu_and_mul_per_block_quant("
|
||||
"Tensor! out, "
|
||||
"Tensor input, "
|
||||
"Tensor! scales, "
|
||||
"int group_size, "
|
||||
"Tensor? scale_ub=None, "
|
||||
"bool is_scale_transposed=False) -> ()");
|
||||
|
||||
// Rotary embedding
|
||||
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
|
||||
ops.def(
|
||||
@@ -599,6 +612,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
ops.impl("rms_norm_dynamic_per_token_quant",
|
||||
TORCH_BOX(&rms_norm_dynamic_per_token_quant));
|
||||
ops.impl("rms_norm_per_block_quant", TORCH_BOX(&rms_norm_per_block_quant));
|
||||
ops.impl("silu_and_mul_per_block_quant",
|
||||
TORCH_BOX(&silu_and_mul_per_block_quant));
|
||||
|
||||
// Positional encoding kernels (shared CUDA/ROCm)
|
||||
ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding));
|
||||
@@ -661,6 +676,11 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
ops.impl("paged_attention_v2", TORCH_BOX(&paged_attention_v2));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CPU, ops) {
|
||||
ops.impl("get_cuda_view_from_cpu_tensor",
|
||||
TORCH_BOX(&get_cuda_view_from_cpu_tensor));
|
||||
}
|
||||
|
||||
// These capability-check functions take only primitive args (no tensors), so
|
||||
// there is no device to dispatch on. CompositeExplicitAutograd makes them
|
||||
// available for all backends. This is the stable ABI equivalent of calling
|
||||
@@ -681,6 +701,19 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
|
||||
ops.impl("ggml_moe_get_block_size", TORCH_BOX(&ggml_moe_get_block_size));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(_C_cuda_utils, cuda_utils) {
|
||||
cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
|
||||
cuda_utils.def(
|
||||
"get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C_cuda_utils, CompositeExplicitAutograd,
|
||||
cuda_utils) {
|
||||
cuda_utils.impl("get_device_attribute", TORCH_BOX(&get_device_attribute));
|
||||
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
|
||||
TORCH_BOX(&get_max_shared_memory_per_block_device_attribute));
|
||||
}
|
||||
|
||||
// Cache ops
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) {
|
||||
// Swap in (out) the cache blocks from src to dst.
|
||||
|
||||
@@ -40,12 +40,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||
torch::Tensor& weight, double epsilon);
|
||||
|
||||
void silu_and_mul_per_block_quant(torch::Tensor& out,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor& scales, int64_t group_size,
|
||||
std::optional<torch::Tensor> scale_ub,
|
||||
bool is_scale_transposed);
|
||||
|
||||
// rotary_embedding also exist in csrc/libtorch_stable/ops.h (torch::stable
|
||||
// ABI for CUDA). It remains here because the CPU build still uses these
|
||||
// torch::Tensor declarations.
|
||||
@@ -84,8 +78,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale);
|
||||
|
||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor const& scale,
|
||||
std::optional<torch::Tensor> const& azp);
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
// cache.h, which is no longer included here after cache ops moved to
|
||||
// _C_stable_libtorch).
|
||||
#include <torch/all.h>
|
||||
#include "cuda_utils.h"
|
||||
#include "ops.h"
|
||||
#include "core/registration.h"
|
||||
#include <torch/library.h>
|
||||
@@ -32,27 +31,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
||||
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
|
||||
|
||||
ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
|
||||
ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
|
||||
&get_cuda_view_from_cpu_tensor);
|
||||
|
||||
// Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
|
||||
ops.def(
|
||||
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
||||
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
|
||||
|
||||
// Fused SiLU+Mul + per-block quantization
|
||||
ops.def(
|
||||
"silu_and_mul_per_block_quant("
|
||||
"Tensor! out, "
|
||||
"Tensor input, "
|
||||
"Tensor! scales, "
|
||||
"int group_size, "
|
||||
"Tensor? scale_ub=None, "
|
||||
"bool is_scale_transposed=False) -> ()");
|
||||
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
|
||||
&silu_and_mul_per_block_quant);
|
||||
|
||||
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
|
||||
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
|
||||
// kernel launch. Registered in _C_stable_libtorch (incl. the FlashInfer V4
|
||||
@@ -178,18 +161,4 @@ TORCH_LIBRARY_FRAGMENT(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
||||
}
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||
// Cuda utils
|
||||
|
||||
// Gets the specified device attribute.
|
||||
cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
|
||||
cuda_utils.impl("get_device_attribute", &get_device_attribute);
|
||||
|
||||
// Gets the maximum shared memory per block device attribute.
|
||||
cuda_utils.def(
|
||||
"get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
|
||||
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
|
||||
&get_max_shared_memory_per_block_device_attribute);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
||||
Reference in New Issue
Block a user