[6/n] Migrate activation kernels, gptq, gguf, non cutlass w8a8 to libtorch stable ABI (continued) (#42663)

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Chris Leonard <chleonar@redhat.com>
Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Co-authored-by: Shengqi Chen <harry-chen@outlook.com>
This commit is contained in:
Chris Leonard
2026-05-20 03:18:12 -04:00
committed by GitHub
parent 40651c0207
commit 07aeaf9d4d
28 changed files with 810 additions and 639 deletions
+41 -22
View File
@@ -312,19 +312,14 @@ set(VLLM_EXT_SRC
"csrc/attention/paged_attention_v2.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/fused_qknorm_rope_kernel.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
"csrc/topk.cu"
"csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/w8a8/int8/scaled_quant.cu"
"csrc/quantization/w8a8/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/custom_all_reduce.cu"
@@ -628,33 +623,33 @@ define_extension_target(
# Setting this variable sidesteps the issue by calling the driver directly.
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
# add OR VLLM_GPU_LANG STREQUAL "HIP" here once
# https://github.com/vllm-project/vllm/issues/35163 is resolved
if(VLLM_GPU_LANG STREQUAL "CUDA")
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
#
# _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
#
set(VLLM_STABLE_EXT_SRC
"csrc/libtorch_stable/torch_bindings.cpp"
"csrc/cutlass_extensions/common.cpp"
"csrc/cuda_utils_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu")
"csrc/libtorch_stable/activation_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/common.cu"
"csrc/libtorch_stable/quantization/gptq/q_gemm.cu"
"csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
"csrc/cuda_utils_kernels.cu"
"csrc/cutlass_extensions/common.cpp"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/libtorch_stable/permute_cols.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu")
endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
set_gencode_flags_for_srcs(
SRCS "${VLLM_STABLE_EXT_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}")
endif()
# DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
@@ -1034,6 +1029,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Building hadacore")
endif()
# if CUDA endif
endif()
message(STATUS "Enabling C_stable extension.")
define_extension_target(
_C_stable_libtorch
@@ -1053,13 +1051,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
target_compile_definitions(_C_stable_libtorch PRIVATE
TORCH_TARGET_VERSION=0x020A000000000000ULL)
# Needed to use cuda APIs from C-shim
target_compile_definitions(_C_stable_libtorch PRIVATE
USE_CUDA)
# Needed to use cuda/hip APIs from C-shim
if(VLLM_GPU_LANG STREQUAL "CUDA")
target_compile_definitions(_C_stable_libtorch PRIVATE USE_CUDA)
# Needed by CUTLASS kernels
target_compile_definitions(_C_stable_libtorch PRIVATE
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
elseif(VLLM_GPU_LANG STREQUAL "HIP")
target_compile_definitions(_C_stable_libtorch PRIVATE USE_ROCM)
endif()
# Needed by CUTLASS kernels
target_compile_definitions(_C_stable_libtorch PRIVATE
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
# On ROCm, _C_stable_libtorch calls raw HIP APIs (e.g. hipGetDevice in
# get_device_prop()) which must resolve to the same libamdhip64.so that
# PyTorch uses. When PyTorch bundles its own copy (pip/conda wheels),
# the raw HIP calls would otherwise resolve to the system ROCm copy,
# initializing a second HIP runtime that corrupts device state (wrong
# device on DeviceGuard, core dumps on multi-GPU tests).
#
# If PyTorch doesn't bundle libamdhip64 (built from source against system
# ROCm), there is only one copy in the process and no action is needed —
# the HIP compiler already links the system libamdhip64 automatically.
if(VLLM_GPU_LANG STREQUAL "HIP")
find_library(_STABLE_TORCH_AMDHIP64 amdhip64
PATHS "${TORCH_INSTALL_PREFIX}/lib" NO_DEFAULT_PATH)
if(_STABLE_TORCH_AMDHIP64)
message(STATUS "Found PyTorch-bundled libamdhip64 at ${_STABLE_TORCH_AMDHIP64}")
target_link_libraries(_C_stable_libtorch PRIVATE ${_STABLE_TORCH_AMDHIP64})
endif()
endif()
endif()
#
+2 -1
View File
@@ -1,6 +1,7 @@
#pragma once
#include "attention_generic.cuh"
#include "torch_utils.h"
#include <stdint.h>
#ifdef ENABLE_FP8
@@ -30,7 +31,7 @@ inline Fp8KVCacheDataType get_fp8_kv_cache_data_type(
} else if (dtype_str == "fp8_e5m2") {
return Fp8KVCacheDataType::kFp8E5M2;
}
TORCH_CHECK(false, "Unsupported fp8 kv cache data type: ", dtype_str);
TORCH_UTILS_CHECK(false, "Unsupported fp8 kv cache data type: ", dtype_str);
}
// fp8 vector types for quantization of kv cache
+2
View File
@@ -9,6 +9,8 @@
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#else
#include <cuda_bf16.h>
#include <cuda_fp16.h>
+2 -4
View File
@@ -1,5 +1,7 @@
#pragma once
#include "torch_utils.h"
// This header is shared between _C (unstable ABI, used by machete) and
// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION
// is defined only for the stable target, so we switch includes and types
@@ -8,13 +10,9 @@
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Half.h>
#include <torch/headeronly/util/shim_utils.h> // for STD_TORCH_CHECK
using TorchTensor = torch::stable::Tensor;
#define TORCH_UTILS_CHECK STD_TORCH_CHECK
#else
#include <torch/all.h>
using TorchTensor = torch::Tensor;
#define TORCH_UTILS_CHECK TORCH_CHECK
#endif
#include "cute/layout.hpp"
@@ -1,12 +1,12 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <torch/csrc/stable/tensor.h>
#include <cmath>
#include "cuda_compat.h"
#include "cuda_vec_utils.cuh"
#include "../cuda_compat.h"
#include "../cuda_vec_utils.cuh"
#include "dispatch_utils.h"
#include "torch_utils.h"
namespace vllm {
@@ -210,64 +210,68 @@ packed_gelu_tanh_kernel(const packed_t& val) {
return; \
} \
dim3 grid(num_tokens); \
int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
int cc_major = get_device_prop()->major; \
int support_vec = \
(CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
: vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
int vec_size = support_vec / at::elementSize(dtype); \
int vec_size = support_vec / input.element_size(); \
const bool use_vec = (d % vec_size == 0); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
const torch::stable::accelerator::DeviceGuard device_guard( \
input.get_device_index()); \
const cudaStream_t stream = get_current_cuda_stream(); \
if (use_vec) { \
dim3 block(std::min(d / vec_size, 1024)); \
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
ACT_FIRST, true, HAS_CLAMP, true><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT); \
out.mutable_data_ptr<scalar_t>(), \
input.const_data_ptr<scalar_t>(), d, LIMIT); \
}); \
} else { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
ACT_FIRST, true, HAS_CLAMP, false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT); \
out.mutable_data_ptr<scalar_t>(), \
input.const_data_ptr<scalar_t>(), d, LIMIT); \
}); \
} \
} else { \
dim3 block(std::min(d, 1024)); \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
ACT_FIRST, false, HAS_CLAMP><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT); \
out.mutable_data_ptr<scalar_t>(), input.const_data_ptr<scalar_t>(), \
d, LIMIT); \
}); \
}
void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
void silu_and_mul(torch::stable::Tensor& out, // [..., d]
torch::stable::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
true, false, 0.0f);
}
void silu_and_mul_clamp(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
void silu_and_mul_clamp(torch::stable::Tensor& out, // [..., d]
torch::stable::Tensor& input, // [..., 2 * d]
double limit) {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
true, true, (float)limit);
}
void mul_and_silu(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
void mul_and_silu(torch::stable::Tensor& out, // [..., d]
torch::stable::Tensor& input) // [..., 2 * d]
{
// The difference between mul_and_silu and silu_and_mul is that mul_and_silu
// applies the silu to the latter half of the input.
@@ -275,15 +279,15 @@ void mul_and_silu(torch::Tensor& out, // [..., d]
false, false, 0.0f);
}
void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
void gelu_and_mul(torch::stable::Tensor& out, // [..., d]
torch::stable::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, vllm::packed_gelu_kernel,
true, false, 0.0f);
}
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
void gelu_tanh_and_mul(torch::stable::Tensor& out, // [..., d]
torch::stable::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(
vllm::gelu_tanh_kernel, vllm::packed_gelu_tanh_kernel, true, false, 0.0f);
@@ -434,19 +438,20 @@ __global__ void swigluoai_and_mul_kernel(
return; \
} \
dim3 grid(num_tokens); \
int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
int cc_major = get_device_prop()->major; \
int support_vec = \
(CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
: vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
int vec_size = support_vec / at::elementSize(dtype); \
int vec_size = support_vec / input.element_size(); \
const bool use_vec = (d % vec_size == 0); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
const torch::stable::accelerator::DeviceGuard device_guard( \
input.get_device_index()); \
const cudaStream_t stream = get_current_cuda_stream(); \
if (use_vec) { \
dim3 block(std::min(d / vec_size, 1024)); \
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_DISPATCH_FLOATING_TYPES( \
VLLM_STABLE_DISPATCH_FLOATING_TYPES( \
dtype, "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
@@ -454,11 +459,11 @@ __global__ void swigluoai_and_mul_kernel(
PACKED_KERNEL< \
typename vllm::PackedTypeConverter<scalar_t>::Type>, \
true, true><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, \
PARAM); \
out.mutable_data_ptr<scalar_t>(), \
input.const_data_ptr<scalar_t>(), d, PARAM); \
}); \
} else { \
VLLM_DISPATCH_FLOATING_TYPES( \
VLLM_STABLE_DISPATCH_FLOATING_TYPES( \
dtype, "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
@@ -466,45 +471,49 @@ __global__ void swigluoai_and_mul_kernel(
PACKED_KERNEL< \
typename vllm::PackedTypeConverter<scalar_t>::Type>, \
true, false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, \
PARAM); \
out.mutable_data_ptr<scalar_t>(), \
input.const_data_ptr<scalar_t>(), d, PARAM); \
}); \
} \
} else { \
dim3 block(std::min(d, 1024)); \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, PARAM); \
}); \
VLLM_STABLE_DISPATCH_FLOATING_TYPES( \
dtype, "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL< \
typename vllm::PackedTypeConverter<scalar_t>::Type>, \
false><<<grid, block, 0, stream>>>( \
out.mutable_data_ptr<scalar_t>(), \
input.const_data_ptr<scalar_t>(), d, PARAM); \
}); \
}
#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \
vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d, ALPHA, \
LIMIT); \
#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const torch::stable::accelerator::DeviceGuard device_guard( \
input.get_device_index()); \
const cudaStream_t stream = get_current_cuda_stream(); \
VLLM_STABLE_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \
vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.mutable_data_ptr<scalar_t>(), \
input.const_data_ptr<scalar_t>(), d, \
ALPHA, LIMIT); \
});
void fatrelu_and_mul(torch::Tensor& out, // [..., d],
torch::Tensor& input, // [..., 2 * d]
void fatrelu_and_mul(torch::stable::Tensor& out, // [..., d],
torch::stable::Tensor& input, // [..., 2 * d]
double threshold) {
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(
vllm::fatrelu_kernel, vllm::packed_fatrelu_kernel, threshold);
}
void swigluoai_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
void swigluoai_and_mul(torch::stable::Tensor& out, // [..., d]
torch::stable::Tensor& input, // [..., 2 * d]
double alpha, double limit) {
LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit);
}
@@ -559,45 +568,46 @@ __global__ void activation_kernel(
} // namespace vllm
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
auto dtype = input.scalar_type(); \
int d = input.size(-1); \
int64_t num_tokens = input.numel() / input.size(-1); \
if (num_tokens == 0) { \
return; \
} \
dim3 grid(num_tokens); \
int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
int support_vec = \
(CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
: vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
int vec_size = support_vec / at::elementSize(dtype); \
const bool use_vec = (d % vec_size == 0); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
if (use_vec) { \
dim3 block(std::min(d / vec_size, 1024)); \
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, true> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
}); \
} else { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, false> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
}); \
} \
} else { \
dim3 block(std::min(d, 1024)); \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, false> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
}); \
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
auto dtype = input.scalar_type(); \
int d = input.size(-1); \
int64_t num_tokens = input.numel() / input.size(-1); \
if (num_tokens == 0) { \
return; \
} \
dim3 grid(num_tokens); \
int cc_major = get_device_prop()->major; \
int support_vec = \
(CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
: vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
int vec_size = support_vec / input.element_size(); \
const bool use_vec = (d % vec_size == 0); \
const torch::stable::accelerator::DeviceGuard device_guard( \
input.get_device_index()); \
const cudaStream_t stream = get_current_cuda_stream(); \
if (use_vec) { \
dim3 block(std::min(d / vec_size, 1024)); \
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, true> \
<<<grid, block, 0, stream>>>(out.mutable_data_ptr<scalar_t>(), \
input.const_data_ptr<scalar_t>(), d); \
}); \
} else { \
VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, false> \
<<<grid, block, 0, stream>>>(out.mutable_data_ptr<scalar_t>(), \
input.const_data_ptr<scalar_t>(), d); \
}); \
} \
} else { \
dim3 block(std::min(d, 1024)); \
VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, false> \
<<<grid, block, 0, stream>>>(out.mutable_data_ptr<scalar_t>(), \
input.const_data_ptr<scalar_t>(), d); \
}); \
}
namespace vllm {
@@ -625,20 +635,20 @@ __device__ __forceinline__ T gelu_quick_kernel(const T& x) {
} // namespace vllm
void gelu_new(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
void gelu_new(torch::stable::Tensor& out, // [..., d]
torch::stable::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}
void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
void gelu_fast(torch::stable::Tensor& out, // [..., d]
torch::stable::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
void gelu_quick(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
void gelu_quick(torch::stable::Tensor& out, // [..., d]
torch::stable::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
}
+83
View File
@@ -166,3 +166,86 @@ torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel,
torch::stable::Tensor hadacore_transform(torch::stable::Tensor& x,
bool inplace);
// Activation kernels (shared CUDA/ROCm)
void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
void silu_and_mul_clamp(torch::stable::Tensor& out,
torch::stable::Tensor& input, double limit);
void mul_and_silu(torch::stable::Tensor& out, torch::stable::Tensor& input);
void gelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
void gelu_tanh_and_mul(torch::stable::Tensor& out,
torch::stable::Tensor& input);
void fatrelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input,
double threshold);
void swigluoai_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input,
double alpha = 1.702, double limit = 7.0);
void gelu_new(torch::stable::Tensor& out, torch::stable::Tensor& input);
void gelu_fast(torch::stable::Tensor& out, torch::stable::Tensor& input);
void gelu_quick(torch::stable::Tensor& out, torch::stable::Tensor& input);
// INT8 quantization kernels (shared CUDA/ROCm)
void static_scaled_int8_quant(torch::stable::Tensor& out,
torch::stable::Tensor const& input,
torch::stable::Tensor const& scale,
std::optional<torch::stable::Tensor> const& azp);
void dynamic_scaled_int8_quant(torch::stable::Tensor& out,
torch::stable::Tensor const& input,
torch::stable::Tensor& scales,
std::optional<torch::stable::Tensor> const& azp);
// FP8 quantization kernels (shared CUDA/ROCm)
void static_scaled_fp8_quant(
torch::stable::Tensor& out, torch::stable::Tensor const& input,
torch::stable::Tensor const& scale,
std::optional<torch::headeronly::IntHeaderOnlyArrayRef> group_shape =
std::nullopt);
void dynamic_scaled_fp8_quant(torch::stable::Tensor& out,
torch::stable::Tensor const& input,
torch::stable::Tensor& scale);
void dynamic_per_token_scaled_fp8_quant(
torch::stable::Tensor& out, torch::stable::Tensor const& input,
torch::stable::Tensor& scale,
std::optional<torch::stable::Tensor> const& scale_ub);
// GPTQ kernels (shared CUDA/ROCm)
torch::stable::Tensor gptq_gemm(torch::stable::Tensor a,
torch::stable::Tensor b_q_weight,
torch::stable::Tensor b_gptq_qzeros,
torch::stable::Tensor b_gptq_scales,
torch::stable::Tensor b_g_idx, bool use_exllama,
bool use_v2_format, int64_t bit);
void gptq_shuffle(torch::stable::Tensor q_weight, torch::stable::Tensor q_perm,
int64_t bit);
// GGML kernels (shared CUDA/ROCm)
torch::stable::Tensor ggml_dequantize(
torch::stable::Tensor W, int64_t type, int64_t m, int64_t n,
std::optional<torch::headeronly::ScalarType> const& dtype);
torch::stable::Tensor ggml_mul_mat_vec_a8(torch::stable::Tensor W,
torch::stable::Tensor X, int64_t type,
int64_t row);
torch::stable::Tensor ggml_mul_mat_a8(torch::stable::Tensor W,
torch::stable::Tensor X, int64_t type,
int64_t row);
torch::stable::Tensor ggml_moe_a8(torch::stable::Tensor X,
torch::stable::Tensor W,
torch::stable::Tensor sorted_token_ids,
torch::stable::Tensor expert_ids,
torch::stable::Tensor num_tokens_post_padded,
int64_t type, int64_t row, int64_t top_k,
int64_t tokens);
torch::stable::Tensor ggml_moe_a8_vec(torch::stable::Tensor X,
torch::stable::Tensor W,
torch::stable::Tensor topk_ids,
int64_t top_k, int64_t type, int64_t row,
int64_t tokens);
int64_t ggml_moe_get_block_size(int64_t type);
@@ -1,17 +1,20 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../../cuda_compat.h"
#include "../../dispatch_utils.h"
#include "../../torch_utils.h"
#include "../../cuda_compat.h"
#include "dispatch_utils.h"
#include <torch/csrc/stable/ops.h>
#include "ggml-common.h"
#include "vecdotq.cuh"
#include "dequantize.cuh"
#include "mmvq.cuh"
#include "mmq.cuh"
// NOTE: These headers are intentionally kept in csrc/quantization/gguf/ (not
// moved to libtorch_stable) to avoid unnecessary reformatting that would break
// git rename detection and pollute blame history.
#include "../../../quantization/gguf/ggml-common.h"
#include "../../../quantization/gguf/vecdotq.cuh"
#include "../../../quantization/gguf/dequantize.cuh"
#include "../../../quantization/gguf/mmvq.cuh"
#include "../../../quantization/gguf/mmq.cuh"
#include "moe.cuh"
#include "moe_vec.cuh"
@@ -71,16 +74,17 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
}
}
torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
int64_t type, int64_t m, int64_t n,
std::optional<at::ScalarType> const& dtype) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
auto dtype_ = dtype.value_or(torch::kFloat16);
auto options = torch::TensorOptions().dtype(dtype_).device(W.device());
at::Tensor DW = torch::empty({m, n}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
torch::stable::Tensor ggml_dequantize(
torch::stable::Tensor W, // quant weight
int64_t type, int64_t m, int64_t n,
std::optional<torch::headeronly::ScalarType> const& dtype) {
const torch::stable::accelerator::DeviceGuard device_guard(
W.get_device_index());
auto dtype_ = dtype.value_or(torch::headeronly::ScalarType::Half);
auto DW = torch::stable::empty({m, n}, dtype_, std::nullopt, W.device());
cudaStream_t stream = get_current_cuda_stream();
VLLM_DISPATCH_FLOATING_TYPES(DW.scalar_type(), "ggml_dequantize", [&] {
VLLM_STABLE_DISPATCH_FLOATING_TYPES(DW.scalar_type(), "ggml_dequantize", [&] {
auto to_cuda = ggml_get_to_cuda<scalar_t>(type);
to_cuda((void*)W.data_ptr(), (scalar_t*)DW.data_ptr(), m * n, stream);
});
@@ -88,135 +92,142 @@ torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
return DW;
}
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
torch::Tensor X, // input
int64_t type, int64_t row) {
torch::stable::Tensor ggml_mul_mat_vec_a8(
torch::stable::Tensor W, // quant weight
torch::stable::Tensor X, // input
int64_t type, int64_t row) {
int col = X.sizes()[1];
int vecs = X.sizes()[0];
const int padded = (col + 512 - 1) / 512 * 512;
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
at::Tensor Y = torch::empty({vecs, row}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
at::Tensor quant_X = torch::empty({vecs, padded / 32 * 9}, options);
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
quantize_row_q8_1_cuda<scalar_t>(
(scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, vecs, stream);
switch (type) {
case 2:
mul_mat_vec_q4_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 3:
mul_mat_vec_q4_1_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 6:
mul_mat_vec_q5_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 7:
mul_mat_vec_q5_1_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 8:
mul_mat_vec_q8_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 10:
mul_mat_vec_q2_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 11:
mul_mat_vec_q3_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 12:
mul_mat_vec_q4_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 13:
mul_mat_vec_q5_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 14:
mul_mat_vec_q6_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 16:
mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 17:
mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 18:
mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 19:
mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 20:
mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 21:
mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 22:
mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 23:
mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 29:
mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
}
});
const torch::stable::accelerator::DeviceGuard device_guard(
X.get_device_index());
auto Y = torch::stable::empty({vecs, row}, X.scalar_type(), std::nullopt,
W.device());
cudaStream_t stream = get_current_cuda_stream();
auto quant_X = torch::stable::empty({vecs, padded / 32 * 9},
torch::headeronly::ScalarType::Int,
std::nullopt, W.device());
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
(void*)quant_X.data_ptr(), col, vecs,
stream);
switch (type) {
case 2:
mul_mat_vec_q4_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 3:
mul_mat_vec_q4_1_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 6:
mul_mat_vec_q5_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 7:
mul_mat_vec_q5_1_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 8:
mul_mat_vec_q8_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 10:
mul_mat_vec_q2_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 11:
mul_mat_vec_q3_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 12:
mul_mat_vec_q4_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 13:
mul_mat_vec_q5_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 14:
mul_mat_vec_q6_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 16:
mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 17:
mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 18:
mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 19:
mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 20:
mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 21:
mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 22:
mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 23:
mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 29:
mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
}
});
return Y;
}
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
torch::Tensor X, // input
int64_t type, int64_t row) {
torch::stable::Tensor ggml_mul_mat_a8(torch::stable::Tensor W, // quant weight
torch::stable::Tensor X, // input
int64_t type, int64_t row) {
int col = X.sizes()[1];
int padded = (col + 512 - 1) / 512 * 512;
int batch = X.sizes()[0];
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
at::Tensor Y = torch::empty({batch, row}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] {
const torch::stable::accelerator::DeviceGuard device_guard(
X.get_device_index());
auto Y = torch::stable::empty({batch, row}, X.scalar_type(), std::nullopt,
W.device());
cudaStream_t stream = get_current_cuda_stream();
auto quant_X = torch::stable::empty({batch, padded / 32 * 9},
torch::headeronly::ScalarType::Int,
std::nullopt, W.device());
VLLM_STABLE_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] {
quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
col, batch, stream);
@@ -276,21 +287,24 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
return Y;
}
torch::Tensor ggml_moe_a8(torch::Tensor X, // input
torch::Tensor W, // expert weights
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_padded, int64_t type,
int64_t row, int64_t top_k, int64_t tokens) {
torch::stable::Tensor ggml_moe_a8(torch::stable::Tensor X, // input
torch::stable::Tensor W, // expert weights
torch::stable::Tensor sorted_token_ids,
torch::stable::Tensor expert_ids,
torch::stable::Tensor num_tokens_post_padded,
int64_t type, int64_t row, int64_t top_k,
int64_t tokens) {
int col = X.sizes()[1];
int padded = (col + 512 - 1) / 512 * 512;
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
at::Tensor Y = torch::empty({tokens * top_k, row}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_a8", [&] {
const torch::stable::accelerator::DeviceGuard device_guard(
X.get_device_index());
auto Y = torch::stable::empty({tokens * top_k, row}, X.scalar_type(),
std::nullopt, W.device());
cudaStream_t stream = get_current_cuda_stream();
auto quant_X = torch::stable::empty({tokens, padded / 32 * 9},
torch::headeronly::ScalarType::Int,
std::nullopt, W.device());
VLLM_STABLE_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_a8", [&] {
quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
col, tokens, stream);
switch (type) {
@@ -379,19 +393,23 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input
return Y;
}
torch::Tensor ggml_moe_a8_vec(torch::Tensor X, // input
torch::Tensor W, // expert weights
torch::Tensor topk_ids, int64_t top_k,
int64_t type, int64_t row, int64_t tokens) {
torch::stable::Tensor ggml_moe_a8_vec(
torch::stable::Tensor X, // input
torch::stable::Tensor W, // expert weights
torch::stable::Tensor topk_ids, int64_t top_k, int64_t type, int64_t row,
int64_t tokens) {
int col = X.sizes()[1];
const int padded = (col + 512 - 1) / 512 * 512;
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
at::Tensor Y = torch::zeros({tokens * top_k, row}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] {
const torch::stable::accelerator::DeviceGuard device_guard(
X.get_device_index());
auto Y = torch::stable::empty({tokens * top_k, row}, X.scalar_type(),
std::nullopt, W.device());
torch::stable::fill_(Y, 0.0);
cudaStream_t stream = get_current_cuda_stream();
auto quant_X = torch::stable::empty({tokens, padded / 32 * 9},
torch::headeronly::ScalarType::Int,
std::nullopt, W.device());
VLLM_STABLE_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] {
quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
(void*)quant_X.data_ptr(), col, tokens,
stream);
@@ -6,9 +6,8 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa
#include <cstdint>
#include <cstdio>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include "../../torch_utils.h"
#include <torch/csrc/stable/ops.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
@@ -735,7 +734,7 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight,
fp_gemm_half_q_half_gptq_kernel kernel =
pick_gemm_half_q_half_gptq_kernel(true, m_count, bit);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const cudaStream_t stream = get_current_cuda_stream();
kernel<<<gridDim, blockDim, 0, stream>>>(
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c, size_m, size_n, size_k,
groups, use_v2_format, b_q_perm);
@@ -1164,7 +1163,7 @@ void reconstruct_exllama(const uint32_t* b_q_weight,
reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const cudaStream_t stream = get_current_cuda_stream();
reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>(
b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups,
use_v2_format, out);
@@ -1376,7 +1375,7 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight,
kernel = gemm_half_q_half_alt_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const cudaStream_t stream = get_current_cuda_stream();
kernel<<<gridDim, blockDim, 0, stream>>>(
(const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx,
size_m, size_k / 32 * bit, size_n, use_v2_format);
@@ -1485,7 +1484,7 @@ void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros,
gridDim.y = DIVIDE(height, 32);
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const cudaStream_t stream = get_current_cuda_stream();
kernel<<<gridDim, blockDim, 0, stream>>>(b_q_weight, b_gptq_scales,
b_gptq_qzeros, b_g_idx, height,
width, groups, use_v2_format, out);
@@ -1794,7 +1793,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height,
} else if (bit == 8) {
kernel = make_sequential_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const cudaStream_t stream = get_current_cuda_stream();
kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, new_qweight, q_perm,
width);
// Replace qweights
@@ -1818,29 +1817,34 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height,
} else if (bit == 8) {
shuffle_kernel = shuffle_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const cudaStream_t stream = get_current_cuda_stream();
shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
}
} // namespace gptq
} // namespace vllm
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, bool use_v2_format, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
at::Tensor c = torch::zeros({a.size(0), b_q_weight.size(1)}, options);
at::Tensor temp_dq = torch::empty(
{b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options);
torch::stable::Tensor gptq_gemm(torch::stable::Tensor a,
torch::stable::Tensor b_q_weight,
torch::stable::Tensor b_gptq_qzeros,
torch::stable::Tensor b_gptq_scales,
torch::stable::Tensor b_g_idx, bool use_exllama,
bool use_v2_format, int64_t bit) {
const torch::stable::accelerator::DeviceGuard device_guard(
a.get_device_index());
auto c = torch::stable::new_zeros(a, {a.size(0), b_q_weight.size(1)});
auto temp_dq =
torch::stable::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)},
a.scalar_type(), std::nullopt, a.device());
vllm::gptq::gemm_half_q_half_cuda(
at::cuda::getCurrentCUDABlasHandle(), (const half*)a.data_ptr(),
get_current_cuda_blas_handle(), (const half*)a.data_ptr(),
(const uint32_t*)b_q_weight.data_ptr(),
(const uint32_t*)b_gptq_qzeros.data_ptr(),
(const half*)b_gptq_scales.data_ptr(),
b_g_idx.device().is_meta() ? NULL : (const int*)b_g_idx.data_ptr(),
b_g_idx.device().type() == torch::stable::DeviceType::Meta
? NULL
: (const int*)b_g_idx.data_ptr(),
(half*)c.data_ptr(), (half*)temp_dq.data_ptr(),
c.size(0), // m
c.size(1), // n
@@ -1850,11 +1854,14 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
return c;
}
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
void gptq_shuffle(torch::stable::Tensor q_weight, torch::stable::Tensor q_perm,
int64_t bit) {
const torch::stable::accelerator::DeviceGuard device_guard(
q_weight.get_device_index());
vllm::gptq::shuffle_exllama_weight(
(uint32_t*)q_weight.data_ptr(),
q_perm.device().is_meta() || q_perm.numel() == 0
q_perm.device().type() == torch::stable::DeviceType::Meta ||
q_perm.numel() == 0
? NULL
: (int*)q_perm.data_ptr(),
q_weight.size(0) * 32 / bit, q_weight.size(1), bit);
@@ -1,11 +1,9 @@
#include "common.cuh"
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "libtorch_stable/quantization/vectorization_utils.cuh"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h>
#include <tuple>
#include "../../../../quantization/w8a8/fp8/common.cuh"
#include "../../../dispatch_utils.h"
#include "../../../../cub_helpers.h"
#include "../../vectorization_utils.cuh"
#include "../../../torch_utils.h"
#include <torch/csrc/stable/macros.h>
namespace vllm {
// STRIDE_I_ZERO: true if scale_stride_i == 0 (per-tensor or per-channel)
@@ -183,16 +181,16 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
} // namespace vllm
void static_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale, // various shapes
std::optional<std::tuple<int64_t, int64_t>>
opt_group_shape) // optional explicit (group_m, group_n)
torch::stable::Tensor& out, // [..., d]
torch::stable::Tensor const& input, // [..., d]
torch::stable::Tensor const& scale, // various shapes
std::optional<torch::headeronly::IntHeaderOnlyArrayRef>
opt_group_shape) // optional explicit [group_m, group_n]
{
TORCH_CHECK(input.stride(-1) == 1,
"last dimension of input must be contiguous");
TORCH_CHECK(out.stride(-1) == 1,
"last dimension of output must be contiguous");
STD_TORCH_CHECK(input.stride(-1) == 1,
"last dimension of input must be contiguous");
STD_TORCH_CHECK(out.stride(-1) == 1,
"last dimension of output must be contiguous");
const int hidden_size = input.size(-1); // N (columns)
const int num_tokens = input.numel() / hidden_size; // M (rows)
@@ -212,13 +210,18 @@ void static_scaled_fp8_quant(
} else if (scale.dim() == 1) {
// 1D scale: require explicit group_shape to disambiguate per-channel vs
// per-token (avoids edge case where num_tokens == hidden_size)
TORCH_CHECK(opt_group_shape.has_value(),
"1D scale requires explicit group_shape to disambiguate "
"per-channel vs per-token quantization. "
"Use group_shape=(-1, 1) for per-channel or group_shape=(1, "
"-1) for per-token.");
STD_TORCH_CHECK(
opt_group_shape.has_value(),
"1D scale requires explicit group_shape to disambiguate "
"per-channel vs per-token quantization. "
"Use group_shape=(-1, 1) for per-channel or group_shape=(1, "
"-1) for per-token.");
STD_TORCH_CHECK(opt_group_shape->size() == 2,
"group_shape must have exactly 2 elements, got ",
opt_group_shape->size());
const auto& [opt_group_m, opt_group_n] = opt_group_shape.value();
const auto opt_group_m = (*opt_group_shape)[0];
const auto opt_group_n = (*opt_group_shape)[1];
group_m = opt_group_m == -1 ? num_tokens : static_cast<int>(opt_group_m);
group_n = opt_group_n == -1 ? hidden_size : static_cast<int>(opt_group_n);
@@ -228,11 +231,11 @@ void static_scaled_fp8_quant(
const int64_t expected_scale_n = hidden_size / group_n;
const int64_t expected_scale_numel = expected_scale_m * expected_scale_n;
TORCH_CHECK(scale_len == expected_scale_numel, "1D scale length (",
scale_len, ") does not match expected size (",
expected_scale_numel, ") for group_shape (", opt_group_m, ", ",
opt_group_n, ") with input shape (", num_tokens, ", ",
hidden_size, ")");
STD_TORCH_CHECK(scale_len == expected_scale_numel, "1D scale length (",
scale_len, ") does not match expected size (",
expected_scale_numel, ") for group_shape (", opt_group_m,
", ", opt_group_n, ") with input shape (", num_tokens, ", ",
hidden_size, ")");
// For 1D scale, determine strides based on which dim is trivial
// Scale indexing: scale[gi * scale_stride_i + gj * scale_stride_j]
@@ -248,7 +251,7 @@ void static_scaled_fp8_quant(
scale_stride_i = scale.stride(0);
scale_stride_j = 0;
} else {
TORCH_CHECK(
STD_TORCH_CHECK(
false,
"1D scale can only be used when one of the scale dimensions is 1. "
"For 2D group scaling, use a 2D scale tensor.");
@@ -259,10 +262,12 @@ void static_scaled_fp8_quant(
const int64_t scale_size_0 = scale.size(0);
const int64_t scale_size_1 = scale.size(1);
TORCH_CHECK(num_tokens % scale_size_0 == 0, "num_tokens (", num_tokens,
") must be divisible by scale.size(0) (", scale_size_0, ")");
TORCH_CHECK(hidden_size % scale_size_1 == 0, "hidden_size (", hidden_size,
") must be divisible by scale.size(1) (", scale_size_1, ")");
STD_TORCH_CHECK(num_tokens % scale_size_0 == 0, "num_tokens (", num_tokens,
") must be divisible by scale.size(0) (", scale_size_0,
")");
STD_TORCH_CHECK(hidden_size % scale_size_1 == 0, "hidden_size (",
hidden_size, ") must be divisible by scale.size(1) (",
scale_size_1, ")");
// Infer from 2D scale shape
int inferred_group_m = num_tokens / scale_size_0;
@@ -270,16 +275,21 @@ void static_scaled_fp8_quant(
// Use explicit if provided, otherwise use inferred
if (opt_group_shape.has_value()) {
const auto& [opt_group_m, opt_group_n] = opt_group_shape.value();
STD_TORCH_CHECK(opt_group_shape->size() == 2,
"group_shape must have exactly 2 elements, got ",
opt_group_shape->size());
const auto opt_group_m = (*opt_group_shape)[0];
const auto opt_group_n = (*opt_group_shape)[1];
group_m = opt_group_m == -1 ? num_tokens : static_cast<int>(opt_group_m);
group_n = opt_group_n == -1 ? hidden_size : static_cast<int>(opt_group_n);
// Validate explicit matches inferred
TORCH_CHECK(group_m == inferred_group_m && group_n == inferred_group_n,
"Explicit group_shape (", opt_group_m, ", ", opt_group_n,
") does not match inferred group shape (", inferred_group_m,
", ", inferred_group_n, ") from 2D scale tensor shape (",
scale_size_0, ", ", scale_size_1, ")");
STD_TORCH_CHECK(
group_m == inferred_group_m && group_n == inferred_group_n,
"Explicit group_shape (", opt_group_m, ", ", opt_group_n,
") does not match inferred group shape (", inferred_group_m, ", ",
inferred_group_n, ") from 2D scale tensor shape (", scale_size_0,
", ", scale_size_1, ")");
} else {
group_m = inferred_group_m;
group_n = inferred_group_n;
@@ -288,8 +298,8 @@ void static_scaled_fp8_quant(
scale_stride_i = scale.stride(0);
scale_stride_j = scale.stride(1);
} else {
TORCH_CHECK(false, "scale must be 0D, 1D, or 2D tensor, but got ",
scale.dim(), "D");
STD_TORCH_CHECK(false, "scale must be 0D, 1D, or 2D tensor, but got ",
scale.dim(), "D");
}
const int block_size = 256;
@@ -299,37 +309,39 @@ void static_scaled_fp8_quant(
const int64_t in_row_stride = input.stride(-2);
const int64_t out_row_stride = out.stride(-2);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
// Dispatch to template-specialized kernel based on stride pattern
VLLM_DISPATCH_FLOATING_TYPES(
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
VLLM_STABLE_DISPATCH_FP8_TYPES(
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
VLLM_DISPATCH_BOOL(scale_stride_i == 0, S0_ZERO, [&] {
VLLM_DISPATCH_BOOL(scale_stride_j == 0, S1_ZERO, [&] {
VLLM_STABLE_DISPATCH_BOOL(scale_stride_i == 0, S0_ZERO, [&] {
VLLM_STABLE_DISPATCH_BOOL(scale_stride_j == 0, S1_ZERO, [&] {
vllm::scaled_fp8_quant_kernel_strided_group_shape<
scalar_t, fp8_t, S0_ZERO, S1_ZERO>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), hidden_size, in_row_stride,
out_row_stride, group_m, group_n, scale_stride_i,
scale_stride_j);
out.mutable_data_ptr<fp8_t>(),
input.const_data_ptr<scalar_t>(),
scale.const_data_ptr<float>(), hidden_size,
in_row_stride, out_row_stride, group_m, group_n,
scale_stride_i, scale_stride_j);
});
});
});
});
}
void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1]
void dynamic_scaled_fp8_quant(torch::stable::Tensor& out, // [..., d]
torch::stable::Tensor const& input, // [..., d]
torch::stable::Tensor& scale) // [1]
{
TORCH_CHECK(input.stride(-1) == 1,
"last dimension of input must be contiguous");
TORCH_CHECK(out.stride(-1) == 1,
"last dimension of output must be contiguous");
STD_TORCH_CHECK(input.stride(-1) == 1,
"last dimension of input must be contiguous");
STD_TORCH_CHECK(out.stride(-1) == 1,
"last dimension of output must be contiguous");
const int hidden_size = input.size(-1);
const int num_tokens = input.numel() / hidden_size;
@@ -340,40 +352,43 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
const int64_t in_row_stride = input.stride(-2);
const int64_t out_row_stride = out.stride(-2);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
// scale tensor should be initialised to <=0 before reduction
AT_CUDA_CHECK(
cudaMemsetAsync(scale.data_ptr<float>(), 0, sizeof(float), stream));
STD_CUDA_CHECK(cudaMemsetAsync(scale.mutable_data_ptr<float>(), 0,
sizeof(float), stream));
VLLM_DISPATCH_FLOATING_TYPES(
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
VLLM_STABLE_DISPATCH_FP8_TYPES(
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
vllm::segmented_max_reduction_strided<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
hidden_size, in_row_stride,
static_cast<int64_t>(num_tokens));
scale.mutable_data_ptr<float>(),
input.const_data_ptr<scalar_t>(), hidden_size,
in_row_stride, static_cast<int64_t>(num_tokens));
vllm::scaled_fp8_quant_kernel_strided_dynamic<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), hidden_size, in_row_stride,
out_row_stride);
<<<grid, block, 0, stream>>>(out.mutable_data_ptr<fp8_t>(),
input.const_data_ptr<scalar_t>(),
scale.const_data_ptr<float>(),
hidden_size, in_row_stride,
out_row_stride);
});
});
}
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
TORCH_CHECK(input.stride(-1) == 1,
"last dimension of input must be contiguous");
TORCH_CHECK(out.stride(-1) == 1,
"last dimension of output must be contiguous");
torch::stable::Tensor& out, // [..., d]
torch::stable::Tensor const& input, // [..., d]
torch::stable::Tensor& scales,
std::optional<torch::stable::Tensor> const& scale_ub) {
STD_TORCH_CHECK(input.stride(-1) == 1,
"last dimension of input must be contiguous");
STD_TORCH_CHECK(out.stride(-1) == 1,
"last dimension of output must be contiguous");
const int hidden_size = input.size(-1);
const int num_tokens = input.numel() / hidden_size;
@@ -384,20 +399,24 @@ void dynamic_per_token_scaled_fp8_quant(
const int64_t in_row_stride = input.stride(-2);
const int64_t out_row_stride = out.stride(-2);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
VLLM_STABLE_DISPATCH_FP8_TYPES(
out.scalar_type(),
"dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided<
scalar_t, fp8_t><<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
hidden_size, in_row_stride, out_row_stride);
vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided<scalar_t,
fp8_t>
<<<grid, block, 0, stream>>>(
out.mutable_data_ptr<fp8_t>(),
scales.mutable_data_ptr<float>(),
input.const_data_ptr<scalar_t>(),
scale_ub.has_value() ? scale_ub->const_data_ptr<float>()
: nullptr,
hidden_size, in_row_stride, out_row_stride);
});
});
}
@@ -1,12 +1,11 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/stable/tensor.h>
#include <cmath>
#include "dispatch_utils.h"
#include "libtorch_stable/quantization/vectorization_utils.cuh"
#include "cub_helpers.h"
#include "../../../dispatch_utils.h"
#include "../../../torch_utils.h"
#include "../../vectorization_utils.cuh"
#include "../../../../cub_helpers.h"
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
@@ -263,66 +262,73 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
} // namespace vllm
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor const& scale,
std::optional<torch::Tensor> const& azp) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scale.numel() == 1);
TORCH_CHECK(!azp || azp->numel() == 1);
void static_scaled_int8_quant(
torch::stable::Tensor& out, // [..., hidden_size]
torch::stable::Tensor const& input, // [..., hidden_size]
torch::stable::Tensor const& scale,
std::optional<torch::stable::Tensor> const& azp) {
STD_TORCH_CHECK(input.is_contiguous());
STD_TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(scale.numel() == 1);
STD_TORCH_CHECK(!azp || azp->numel() == 1);
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 256));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
if (!azp) {
vllm::static_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), hidden_size);
<<<grid, block, 0, stream>>>(input.const_data_ptr<scalar_t>(),
out.mutable_data_ptr<int8_t>(),
scale.const_data_ptr<float>(),
hidden_size);
} else {
vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
hidden_size);
input.const_data_ptr<scalar_t>(),
out.mutable_data_ptr<int8_t>(), scale.const_data_ptr<float>(),
azp->const_data_ptr<int32_t>(), hidden_size);
}
});
}
void dynamic_scaled_int8_quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor& scales, std::optional<torch::Tensor> const& azp) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scales.is_contiguous());
TORCH_CHECK(!azp || azp->is_contiguous());
torch::stable::Tensor& out, // [..., hidden_size]
torch::stable::Tensor const& input, // [..., hidden_size]
torch::stable::Tensor& scales,
std::optional<torch::stable::Tensor> const& azp) {
STD_TORCH_CHECK(input.is_contiguous());
STD_TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(scales.is_contiguous());
STD_TORCH_CHECK(!azp || azp->is_contiguous());
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 256));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream();
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
if (!azp) {
vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scales.data_ptr<float>(), hidden_size);
<<<grid, block, 0, stream>>>(input.const_data_ptr<scalar_t>(),
out.mutable_data_ptr<int8_t>(),
scales.mutable_data_ptr<float>(),
hidden_size);
} else {
vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
hidden_size);
<<<grid, block, 0, stream>>>(input.const_data_ptr<scalar_t>(),
out.mutable_data_ptr<int8_t>(),
scales.mutable_data_ptr<float>(),
azp->mutable_data_ptr<int32_t>(),
hidden_size);
}
});
}
+139
View File
@@ -266,6 +266,109 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
// Hadamard transforms
// conditionally compiled so impl registration is in source file
ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");
// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
// SwiGLU activation with input clamping.
ops.def(
"silu_and_mul_with_clamp(Tensor! result, Tensor input, float limit) "
"-> ()");
// Activation function used in GeGLU with `none` approximation.
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
// Activation function used in GeGLU with `tanh` approximation.
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
// FATReLU implementation.
ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
ops.def(
"swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float "
"limit=7.0) "
"-> ()");
// GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
// Approximate GELU implementation.
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
// Quick GELU implementation.
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
"Tensor? azp) -> ()");
// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()");
// Compute FP8 quantized tensor for given scaling factor.
// Supports per-tensor, per-channel, per-token, and arbitrary 2D group
// scaling. Optional group_m/group_n specify the group shape explicitly;
// required for 1D scales to disambiguate per-channel vs per-token.
ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, "
"int[]? group_shape=None) -> ()");
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> "
"()");
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> "
"()");
// Quantized GEMM for GPTQ.
// Note: even though the C++ inferred schema is correct for this op, it seems
// to prevent the meta function registry.
ops.def(
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool "
"use_v2_format, int bit) "
"-> Tensor");
// Post processing for GPTQ.
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
// Dequantization for GGML.
ops.def(
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
"dtype) -> Tensor");
// mmvq kernel for GGML.
ops.def(
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
"-> Tensor");
// mmq kernel for GGML.
ops.def(
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
// moe kernel for GGML.
ops.def(
"ggml_moe_a8(Tensor X, Tensor W, "
"Tensor sorted_token_ids, Tensor expert_ids, Tensor "
"num_tokens_post_padded, "
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
ops.def(
"ggml_moe_a8_vec(Tensor X, Tensor W, "
"Tensor topk_ids, int top_k, "
"int type, SymInt row, SymInt tokens) -> Tensor");
ops.def("ggml_moe_get_block_size(int type) -> int");
}
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
@@ -312,6 +415,39 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
// AllSpark ops: conditionally compiled so impl registrations are in source
// files (allspark_repack.cu and allspark_qgemm_w8a16.cu)
#endif
// Activation kernels (shared CUDA/ROCm)
ops.impl("silu_and_mul", TORCH_BOX(&silu_and_mul));
ops.impl("mul_and_silu", TORCH_BOX(&mul_and_silu));
ops.impl("gelu_and_mul", TORCH_BOX(&gelu_and_mul));
ops.impl("gelu_tanh_and_mul", TORCH_BOX(&gelu_tanh_and_mul));
ops.impl("fatrelu_and_mul", TORCH_BOX(&fatrelu_and_mul));
ops.impl("swigluoai_and_mul", TORCH_BOX(&swigluoai_and_mul));
ops.impl("gelu_new", TORCH_BOX(&gelu_new));
ops.impl("gelu_fast", TORCH_BOX(&gelu_fast));
ops.impl("gelu_quick", TORCH_BOX(&gelu_quick));
ops.impl("silu_and_mul_with_clamp", TORCH_BOX(&silu_and_mul_clamp));
// INT8 quantization kernels
ops.impl("static_scaled_int8_quant", TORCH_BOX(&static_scaled_int8_quant));
ops.impl("dynamic_scaled_int8_quant", TORCH_BOX(&dynamic_scaled_int8_quant));
// FP8 quantization kernels
ops.impl("static_scaled_fp8_quant", TORCH_BOX(&static_scaled_fp8_quant));
ops.impl("dynamic_scaled_fp8_quant", TORCH_BOX(&dynamic_scaled_fp8_quant));
ops.impl("dynamic_per_token_scaled_fp8_quant",
TORCH_BOX(&dynamic_per_token_scaled_fp8_quant));
// GPTQ kernels
ops.impl("gptq_gemm", TORCH_BOX(&gptq_gemm));
ops.impl("gptq_shuffle", TORCH_BOX(&gptq_shuffle));
// GGML kernels
ops.impl("ggml_dequantize", TORCH_BOX(&ggml_dequantize));
ops.impl("ggml_mul_mat_vec_a8", TORCH_BOX(&ggml_mul_mat_vec_a8));
ops.impl("ggml_mul_mat_a8", TORCH_BOX(&ggml_mul_mat_a8));
ops.impl("ggml_moe_a8", TORCH_BOX(&ggml_moe_a8));
ops.impl("ggml_moe_a8_vec", TORCH_BOX(&ggml_moe_a8_vec));
}
// These capability-check functions take only primitive args (no tensors), so
@@ -329,6 +465,9 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
ops.impl("cutlass_scaled_mm_supports_fp4",
TORCH_BOX(&cutlass_scaled_mm_supports_fp4));
#endif
// GGML block size lookup (no tensor args)
ops.impl("ggml_moe_get_block_size", TORCH_BOX(&ggml_moe_get_block_size));
}
REGISTER_EXTENSION(_C_stable_libtorch)
+5 -1
View File
@@ -6,8 +6,12 @@
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/shim_utils.h>
#ifndef USE_ROCM
#include <cuda_runtime.h>
#else
#include <hip/hip_runtime.h>
#endif
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <deque>
#include <mutex>
-47
View File
@@ -149,17 +149,10 @@ void persistent_masked_m_silu_mul_quant(
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
bool use_ue8m0);
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
double threshold);
void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input,
double alpha = 1.702, double limit = 7.0);
void gelu_new(torch::Tensor& out, torch::Tensor& input);
void gelu_fast(torch::Tensor& out, torch::Tensor& input);
@@ -174,28 +167,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n,
std::optional<at::ScalarType> const& dtype);
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
int64_t type, int64_t row);
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
int64_t row);
torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_padded, int64_t type,
int64_t row, int64_t top_k, int64_t tokens);
torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
torch::Tensor topk_ids, int64_t top_k,
int64_t type, int64_t row, int64_t tokens);
int64_t ggml_moe_get_block_size(int64_t type);
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale,
std::optional<torch::Tensor> const& azp);
@@ -204,24 +175,6 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scales,
std::optional<torch::Tensor> const& azp);
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, bool use_v2_format, int64_t bit);
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void static_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale,
std::optional<std::tuple<int64_t, int64_t>> group_shape = std::nullopt);
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale);
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
std::optional<torch::Tensor> const& scale_ub);
void selective_scan_fwd(
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
const torch::Tensor& B, const torch::Tensor& C,
+16 -14
View File
@@ -7,23 +7,23 @@
*/
#include <cmath>
#include <torch/types.h>
#include <torch/headeronly/macros/Macros.h>
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
#include <torch/headeronly/util/Float8_e4m3fn.h>
#define MAYBE_HOST_DEVICE C10_HOST_DEVICE
#else
#include <ATen/hip/HIPContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <torch/headeronly/util/Float8_e4m3fn.h>
#include <torch/headeronly/util/Float8_e4m3fnuz.h>
// ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
#define MAYBE_HOST_DEVICE
#endif
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
std::is_same_v<T, c10::Float8_e4m3fnuz> ||
std::is_same_v<T, int8_t>>>
typename = std::enable_if_t<
std::is_same_v<T, torch::headeronly::Float8_e4m3fn> ||
std::is_same_v<T, torch::headeronly::Float8_e4m3fnuz> ||
std::is_same_v<T, int8_t>>>
struct quant_type_max {
static constexpr T val() { return std::numeric_limits<T>::max(); }
};
@@ -31,9 +31,10 @@ struct quant_type_max {
// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
template <>
struct quant_type_max<c10::Float8_e4m3fnuz> {
static constexpr c10::Float8_e4m3fnuz val() {
return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
struct quant_type_max<torch::headeronly::Float8_e4m3fnuz> {
static constexpr torch::headeronly::Float8_e4m3fnuz val() {
return torch::headeronly::Float8_e4m3fnuz(
0x7E, torch::headeronly::Float8_e4m3fnuz::from_bits());
}
};
@@ -42,9 +43,10 @@ MAYBE_HOST_DEVICE static constexpr T quant_type_max_v =
quant_type_max<T>::val();
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
std::is_same_v<T, c10::Float8_e4m3fnuz> ||
std::is_same_v<T, int8_t>>>
typename = std::enable_if_t<
std::is_same_v<T, torch::headeronly::Float8_e4m3fn> ||
std::is_same_v<T, torch::headeronly::Float8_e4m3fnuz> ||
std::is_same_v<T, int8_t>>>
struct min_scaling_factor {
C10_DEVICE C10_ALWAYS_INLINE static float val() {
return 1.0f / (quant_type_max_v<T> * 512.0f);
+18 -1
View File
@@ -5,6 +5,19 @@
#include <cmath>
// This header is shared between _C and _C_stable_libtorch targets.
// torch_utils.h provides get_device_prop(). We need to pass USE_CUDA
// to the .so to expose some of the shims used by torch_utils.h. For now
// this is only done for _C_stable_libtorch and not for _C, so we use the
// non stable at::cuda::getCurrentDeviceProperties for _C for now.
#ifdef TORCH_TARGET_VERSION
#include "../../../libtorch_stable/torch_utils.h"
#else
#ifdef USE_ROCM
#include <ATen/hip/HIPContext.h>
#endif
#endif
#ifndef USE_ROCM
#include "nvidia/quant_utils.cuh"
#else
@@ -18,7 +31,11 @@ static bool is_fp8_ocp() {
#ifndef USE_ROCM
return true;
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
#ifdef TORCH_TARGET_VERSION
auto* dprops = get_device_prop();
#else
auto* dprops = at::cuda::getCurrentDeviceProperties();
#endif
std::string device_arch = dprops->gcnArchName;
size_t substring = device_arch.find("gfx94");
return substring == std::string::npos;
+1 -128
View File
@@ -77,17 +77,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor? output_scale=None) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
// SwiGLU activation with input clamping.
ops.def(
"silu_and_mul_with_clamp(Tensor! result, Tensor input, float limit) "
"-> ()");
ops.impl("silu_and_mul_with_clamp", torch::kCUDA, &silu_and_mul_clamp);
// Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
ops.def(
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
@@ -104,39 +94,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
&silu_and_mul_per_block_quant);
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
// Activation function used in GeGLU with `none` approximation.
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
// Activation function used in GeGLU with `tanh` approximation.
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
// FATReLU implementation.
ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);
ops.def(
"swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float "
"limit=7.0) "
"-> ()");
ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul);
// GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_new", torch::kCUDA, &gelu_new);
// Approximate GELU implementation.
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
// Quick GELU implementation.
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
@@ -318,39 +275,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif
// Dequantization for GGML.
ops.def(
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
"dtype) -> Tensor");
ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
// mmvq kernel for GGML.
ops.def(
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
"-> Tensor");
ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
// mmq kernel for GGML.
ops.def(
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
// moe kernel for GGML.
ops.def(
"ggml_moe_a8(Tensor X, Tensor W, "
"Tensor sorted_token_ids, Tensor expert_ids, Tensor "
"num_tokens_post_padded, "
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
ops.def(
"ggml_moe_a8_vec(Tensor X, Tensor W, "
"Tensor topk_ids, int top_k, "
"int type, SymInt row, SymInt tokens) -> Tensor");
ops.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
#ifndef USE_ROCM
// Expert-specialization mxfp8 blockscaled grouped quantization (SM100+).
ops.def(
@@ -370,57 +294,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif
// Quantized GEMM for GPTQ.
// Note: even though the C++ inferred schema is correct for this op, it seems
// to prevent the meta function registry.
ops.def(
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool "
"use_v2_format, int bit) "
"-> Tensor");
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
// Post processing for GPTQ.
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// Compute FP8 quantized tensor for given scaling factor.
// Supports per-tensor, per-channel, per-token, and arbitrary 2D group
// scaling. Optional group_m/group_n specify the group shape explicitly;
// required for 1D scales to disambiguate per-channel vs per-token.
ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, "
"(int, int)? group_shape=None) -> ()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> "
"()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> "
"()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
"Tensor? azp) -> ()");
ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
&dynamic_scaled_int8_quant);
// Mamba selective scan kernel
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
+17
View File
@@ -0,0 +1,17 @@
#pragma once
// Shared TORCH_UTILS_CHECK across both libtorch stable and unstable source
// files. Keep this header free of CUTLASS/CUTE so attention/quant headers can
// use it.
//
// If TORCH_TARGET_VERSION is defined, we are building _C_stable_libtorch.so so
// use STD_TORCH_CHECK via header-only.
// Otherwise, use TORCH_CHECK via torch/all.h.
#ifdef TORCH_TARGET_VERSION
#include <torch/headeronly/util/Exception.h>
#define TORCH_UTILS_CHECK STD_TORCH_CHECK
#else
#include <torch/all.h>
#define TORCH_UTILS_CHECK TORCH_CHECK
#endif
+1 -3
View File
@@ -1045,9 +1045,7 @@ if _is_cpu():
if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))
# also _is_hip() once https://github.com/vllm-project/vllm/issues/35163 is
# fixed
if _is_cuda():
if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._C_stable_libtorch"))
package_data = {
+5
View File
@@ -44,6 +44,11 @@ except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)
# import custom ops, trigger op registration
try:
import vllm._C_stable_libtorch # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._C_stable_libtorch with %r", e)
try:
import vllm._rocm_C # noqa: F401
except ImportError as e: