mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[6/n] Migrate activation kernels, gptq, gguf, non cutlass w8a8 to libtorch stable ABI (continued) (#42663)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com> Signed-off-by: Chris Leonard <chleonar@redhat.com> Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com> Co-authored-by: Shengqi Chen <harry-chen@outlook.com>
This commit is contained in:
+41
-22
@@ -312,19 +312,14 @@ set(VLLM_EXT_SRC
|
||||
"csrc/attention/paged_attention_v2.cu"
|
||||
"csrc/attention/merge_attn_states.cu"
|
||||
"csrc/pos_encoding_kernels.cu"
|
||||
"csrc/activation_kernels.cu"
|
||||
"csrc/layernorm_kernels.cu"
|
||||
"csrc/fused_qknorm_rope_kernel.cu"
|
||||
"csrc/layernorm_quant_kernels.cu"
|
||||
"csrc/sampler.cu"
|
||||
"csrc/topk.cu"
|
||||
"csrc/cuda_view.cu"
|
||||
"csrc/quantization/gptq/q_gemm.cu"
|
||||
"csrc/quantization/w8a8/int8/scaled_quant.cu"
|
||||
"csrc/quantization/w8a8/fp8/common.cu"
|
||||
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
|
||||
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
|
||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||
"csrc/quantization/activation_kernels.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
@@ -628,33 +623,33 @@ define_extension_target(
|
||||
# Setting this variable sidesteps the issue by calling the driver directly.
|
||||
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
|
||||
|
||||
# add OR VLLM_GPU_LANG STREQUAL "HIP" here once
|
||||
# https://github.com/vllm-project/vllm/issues/35163 is resolved
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
#
|
||||
# _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
|
||||
#
|
||||
set(VLLM_STABLE_EXT_SRC
|
||||
"csrc/libtorch_stable/torch_bindings.cpp"
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu")
|
||||
"csrc/libtorch_stable/activation_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/fp8/common.cu"
|
||||
"csrc/libtorch_stable/quantization/gptq/q_gemm.cu"
|
||||
"csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||
"csrc/libtorch_stable/permute_cols.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu")
|
||||
endif()
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_STABLE_EXT_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
endif()
|
||||
|
||||
# DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
@@ -1034,6 +1029,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Building hadacore")
|
||||
endif()
|
||||
|
||||
# if CUDA endif
|
||||
endif()
|
||||
|
||||
message(STATUS "Enabling C_stable extension.")
|
||||
define_extension_target(
|
||||
_C_stable_libtorch
|
||||
@@ -1053,13 +1051,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
target_compile_definitions(_C_stable_libtorch PRIVATE
|
||||
TORCH_TARGET_VERSION=0x020A000000000000ULL)
|
||||
|
||||
# Needed to use cuda APIs from C-shim
|
||||
target_compile_definitions(_C_stable_libtorch PRIVATE
|
||||
USE_CUDA)
|
||||
# Needed to use cuda/hip APIs from C-shim
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
target_compile_definitions(_C_stable_libtorch PRIVATE USE_CUDA)
|
||||
# Needed by CUTLASS kernels
|
||||
target_compile_definitions(_C_stable_libtorch PRIVATE
|
||||
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
|
||||
elseif(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
target_compile_definitions(_C_stable_libtorch PRIVATE USE_ROCM)
|
||||
endif()
|
||||
|
||||
# Needed by CUTLASS kernels
|
||||
target_compile_definitions(_C_stable_libtorch PRIVATE
|
||||
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
|
||||
# On ROCm, _C_stable_libtorch calls raw HIP APIs (e.g. hipGetDevice in
|
||||
# get_device_prop()) which must resolve to the same libamdhip64.so that
|
||||
# PyTorch uses. When PyTorch bundles its own copy (pip/conda wheels),
|
||||
# the raw HIP calls would otherwise resolve to the system ROCm copy,
|
||||
# initializing a second HIP runtime that corrupts device state (wrong
|
||||
# device on DeviceGuard, core dumps on multi-GPU tests).
|
||||
#
|
||||
# If PyTorch doesn't bundle libamdhip64 (built from source against system
|
||||
# ROCm), there is only one copy in the process and no action is needed —
|
||||
# the HIP compiler already links the system libamdhip64 automatically.
|
||||
if(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
find_library(_STABLE_TORCH_AMDHIP64 amdhip64
|
||||
PATHS "${TORCH_INSTALL_PREFIX}/lib" NO_DEFAULT_PATH)
|
||||
if(_STABLE_TORCH_AMDHIP64)
|
||||
message(STATUS "Found PyTorch-bundled libamdhip64 at ${_STABLE_TORCH_AMDHIP64}")
|
||||
target_link_libraries(_C_stable_libtorch PRIVATE ${_STABLE_TORCH_AMDHIP64})
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "attention_generic.cuh"
|
||||
#include "torch_utils.h"
|
||||
|
||||
#include <stdint.h>
|
||||
#ifdef ENABLE_FP8
|
||||
@@ -30,7 +31,7 @@ inline Fp8KVCacheDataType get_fp8_kv_cache_data_type(
|
||||
} else if (dtype_str == "fp8_e5m2") {
|
||||
return Fp8KVCacheDataType::kFp8E5M2;
|
||||
}
|
||||
TORCH_CHECK(false, "Unsupported fp8 kv cache data type: ", dtype_str);
|
||||
TORCH_UTILS_CHECK(false, "Unsupported fp8 kv cache data type: ", dtype_str);
|
||||
}
|
||||
|
||||
// fp8 vector types for quantization of kv cache
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#else
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "torch_utils.h"
|
||||
|
||||
// This header is shared between _C (unstable ABI, used by machete) and
|
||||
// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION
|
||||
// is defined only for the stable target, so we switch includes and types
|
||||
@@ -8,13 +10,9 @@
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/util/BFloat16.h>
|
||||
#include <torch/headeronly/util/Half.h>
|
||||
#include <torch/headeronly/util/shim_utils.h> // for STD_TORCH_CHECK
|
||||
using TorchTensor = torch::stable::Tensor;
|
||||
#define TORCH_UTILS_CHECK STD_TORCH_CHECK
|
||||
#else
|
||||
#include <torch/all.h>
|
||||
using TorchTensor = torch::Tensor;
|
||||
#define TORCH_UTILS_CHECK TORCH_CHECK
|
||||
#endif
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "cuda_vec_utils.cuh"
|
||||
#include "../cuda_compat.h"
|
||||
#include "../cuda_vec_utils.cuh"
|
||||
#include "dispatch_utils.h"
|
||||
#include "torch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
@@ -210,64 +210,68 @@ packed_gelu_tanh_kernel(const packed_t& val) {
|
||||
return; \
|
||||
} \
|
||||
dim3 grid(num_tokens); \
|
||||
int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
|
||||
int cc_major = get_device_prop()->major; \
|
||||
int support_vec = \
|
||||
(CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
|
||||
? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
|
||||
: vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
|
||||
int vec_size = support_vec / at::elementSize(dtype); \
|
||||
int vec_size = support_vec / input.element_size(); \
|
||||
const bool use_vec = (d % vec_size == 0); \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
const torch::stable::accelerator::DeviceGuard device_guard( \
|
||||
input.get_device_index()); \
|
||||
const cudaStream_t stream = get_current_cuda_stream(); \
|
||||
if (use_vec) { \
|
||||
dim3 block(std::min(d / vec_size, 1024)); \
|
||||
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
|
||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
|
||||
vllm::act_and_mul_kernel< \
|
||||
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
|
||||
KERNEL<scalar_t>, \
|
||||
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
|
||||
ACT_FIRST, true, HAS_CLAMP, true><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT); \
|
||||
out.mutable_data_ptr<scalar_t>(), \
|
||||
input.const_data_ptr<scalar_t>(), d, LIMIT); \
|
||||
}); \
|
||||
} else { \
|
||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
|
||||
vllm::act_and_mul_kernel< \
|
||||
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
|
||||
KERNEL<scalar_t>, \
|
||||
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
|
||||
ACT_FIRST, true, HAS_CLAMP, false><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT); \
|
||||
out.mutable_data_ptr<scalar_t>(), \
|
||||
input.const_data_ptr<scalar_t>(), d, LIMIT); \
|
||||
}); \
|
||||
} \
|
||||
} else { \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
|
||||
vllm::act_and_mul_kernel< \
|
||||
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
|
||||
KERNEL<scalar_t>, \
|
||||
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
|
||||
ACT_FIRST, false, HAS_CLAMP><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT); \
|
||||
out.mutable_data_ptr<scalar_t>(), input.const_data_ptr<scalar_t>(), \
|
||||
d, LIMIT); \
|
||||
}); \
|
||||
}
|
||||
|
||||
void silu_and_mul(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
void silu_and_mul(torch::stable::Tensor& out, // [..., d]
|
||||
torch::stable::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
|
||||
true, false, 0.0f);
|
||||
}
|
||||
|
||||
void silu_and_mul_clamp(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., 2 * d]
|
||||
void silu_and_mul_clamp(torch::stable::Tensor& out, // [..., d]
|
||||
torch::stable::Tensor& input, // [..., 2 * d]
|
||||
double limit) {
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
|
||||
true, true, (float)limit);
|
||||
}
|
||||
|
||||
void mul_and_silu(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
void mul_and_silu(torch::stable::Tensor& out, // [..., d]
|
||||
torch::stable::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
// The difference between mul_and_silu and silu_and_mul is that mul_and_silu
|
||||
// applies the silu to the latter half of the input.
|
||||
@@ -275,15 +279,15 @@ void mul_and_silu(torch::Tensor& out, // [..., d]
|
||||
false, false, 0.0f);
|
||||
}
|
||||
|
||||
void gelu_and_mul(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
void gelu_and_mul(torch::stable::Tensor& out, // [..., d]
|
||||
torch::stable::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, vllm::packed_gelu_kernel,
|
||||
true, false, 0.0f);
|
||||
}
|
||||
|
||||
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
void gelu_tanh_and_mul(torch::stable::Tensor& out, // [..., d]
|
||||
torch::stable::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(
|
||||
vllm::gelu_tanh_kernel, vllm::packed_gelu_tanh_kernel, true, false, 0.0f);
|
||||
@@ -434,19 +438,20 @@ __global__ void swigluoai_and_mul_kernel(
|
||||
return; \
|
||||
} \
|
||||
dim3 grid(num_tokens); \
|
||||
int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
|
||||
int cc_major = get_device_prop()->major; \
|
||||
int support_vec = \
|
||||
(CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
|
||||
? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
|
||||
: vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
|
||||
int vec_size = support_vec / at::elementSize(dtype); \
|
||||
int vec_size = support_vec / input.element_size(); \
|
||||
const bool use_vec = (d % vec_size == 0); \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
const torch::stable::accelerator::DeviceGuard device_guard( \
|
||||
input.get_device_index()); \
|
||||
const cudaStream_t stream = get_current_cuda_stream(); \
|
||||
if (use_vec) { \
|
||||
dim3 block(std::min(d / vec_size, 1024)); \
|
||||
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES( \
|
||||
dtype, "act_and_mul_kernel_with_param", [&] { \
|
||||
vllm::act_and_mul_kernel_with_param< \
|
||||
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
|
||||
@@ -454,11 +459,11 @@ __global__ void swigluoai_and_mul_kernel(
|
||||
PACKED_KERNEL< \
|
||||
typename vllm::PackedTypeConverter<scalar_t>::Type>, \
|
||||
true, true><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, \
|
||||
PARAM); \
|
||||
out.mutable_data_ptr<scalar_t>(), \
|
||||
input.const_data_ptr<scalar_t>(), d, PARAM); \
|
||||
}); \
|
||||
} else { \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES( \
|
||||
dtype, "act_and_mul_kernel_with_param", [&] { \
|
||||
vllm::act_and_mul_kernel_with_param< \
|
||||
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
|
||||
@@ -466,45 +471,49 @@ __global__ void swigluoai_and_mul_kernel(
|
||||
PACKED_KERNEL< \
|
||||
typename vllm::PackedTypeConverter<scalar_t>::Type>, \
|
||||
true, false><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, \
|
||||
PARAM); \
|
||||
out.mutable_data_ptr<scalar_t>(), \
|
||||
input.const_data_ptr<scalar_t>(), d, PARAM); \
|
||||
}); \
|
||||
} \
|
||||
} else { \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \
|
||||
vllm::act_and_mul_kernel_with_param< \
|
||||
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
|
||||
KERNEL<scalar_t>, \
|
||||
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
|
||||
false><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, PARAM); \
|
||||
}); \
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES( \
|
||||
dtype, "act_and_mul_kernel_with_param", [&] { \
|
||||
vllm::act_and_mul_kernel_with_param< \
|
||||
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
|
||||
KERNEL<scalar_t>, \
|
||||
PACKED_KERNEL< \
|
||||
typename vllm::PackedTypeConverter<scalar_t>::Type>, \
|
||||
false><<<grid, block, 0, stream>>>( \
|
||||
out.mutable_data_ptr<scalar_t>(), \
|
||||
input.const_data_ptr<scalar_t>(), d, PARAM); \
|
||||
}); \
|
||||
}
|
||||
|
||||
#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \
|
||||
int d = input.size(-1) / 2; \
|
||||
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \
|
||||
vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
|
||||
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
||||
input.data_ptr<scalar_t>(), d, ALPHA, \
|
||||
LIMIT); \
|
||||
#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \
|
||||
int d = input.size(-1) / 2; \
|
||||
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
const torch::stable::accelerator::DeviceGuard device_guard( \
|
||||
input.get_device_index()); \
|
||||
const cudaStream_t stream = get_current_cuda_stream(); \
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \
|
||||
vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
|
||||
<<<grid, block, 0, stream>>>(out.mutable_data_ptr<scalar_t>(), \
|
||||
input.const_data_ptr<scalar_t>(), d, \
|
||||
ALPHA, LIMIT); \
|
||||
});
|
||||
|
||||
void fatrelu_and_mul(torch::Tensor& out, // [..., d],
|
||||
torch::Tensor& input, // [..., 2 * d]
|
||||
void fatrelu_and_mul(torch::stable::Tensor& out, // [..., d],
|
||||
torch::stable::Tensor& input, // [..., 2 * d]
|
||||
double threshold) {
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(
|
||||
vllm::fatrelu_kernel, vllm::packed_fatrelu_kernel, threshold);
|
||||
}
|
||||
void swigluoai_and_mul(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., 2 * d]
|
||||
void swigluoai_and_mul(torch::stable::Tensor& out, // [..., d]
|
||||
torch::stable::Tensor& input, // [..., 2 * d]
|
||||
double alpha, double limit) {
|
||||
LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit);
|
||||
}
|
||||
@@ -559,45 +568,46 @@ __global__ void activation_kernel(
|
||||
} // namespace vllm
|
||||
|
||||
// Launch element-wise activation kernel.
|
||||
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
|
||||
auto dtype = input.scalar_type(); \
|
||||
int d = input.size(-1); \
|
||||
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||
if (num_tokens == 0) { \
|
||||
return; \
|
||||
} \
|
||||
dim3 grid(num_tokens); \
|
||||
int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
|
||||
int support_vec = \
|
||||
(CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
|
||||
? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
|
||||
: vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
|
||||
int vec_size = support_vec / at::elementSize(dtype); \
|
||||
const bool use_vec = (d % vec_size == 0); \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
if (use_vec) { \
|
||||
dim3 block(std::min(d / vec_size, 1024)); \
|
||||
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
|
||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
|
||||
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, true> \
|
||||
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
||||
input.data_ptr<scalar_t>(), d); \
|
||||
}); \
|
||||
} else { \
|
||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
|
||||
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, false> \
|
||||
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
||||
input.data_ptr<scalar_t>(), d); \
|
||||
}); \
|
||||
} \
|
||||
} else { \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
|
||||
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, false> \
|
||||
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
||||
input.data_ptr<scalar_t>(), d); \
|
||||
}); \
|
||||
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
|
||||
auto dtype = input.scalar_type(); \
|
||||
int d = input.size(-1); \
|
||||
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||
if (num_tokens == 0) { \
|
||||
return; \
|
||||
} \
|
||||
dim3 grid(num_tokens); \
|
||||
int cc_major = get_device_prop()->major; \
|
||||
int support_vec = \
|
||||
(CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
|
||||
? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
|
||||
: vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
|
||||
int vec_size = support_vec / input.element_size(); \
|
||||
const bool use_vec = (d % vec_size == 0); \
|
||||
const torch::stable::accelerator::DeviceGuard device_guard( \
|
||||
input.get_device_index()); \
|
||||
const cudaStream_t stream = get_current_cuda_stream(); \
|
||||
if (use_vec) { \
|
||||
dim3 block(std::min(d / vec_size, 1024)); \
|
||||
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
|
||||
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, true> \
|
||||
<<<grid, block, 0, stream>>>(out.mutable_data_ptr<scalar_t>(), \
|
||||
input.const_data_ptr<scalar_t>(), d); \
|
||||
}); \
|
||||
} else { \
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
|
||||
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, false> \
|
||||
<<<grid, block, 0, stream>>>(out.mutable_data_ptr<scalar_t>(), \
|
||||
input.const_data_ptr<scalar_t>(), d); \
|
||||
}); \
|
||||
} \
|
||||
} else { \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
|
||||
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, false> \
|
||||
<<<grid, block, 0, stream>>>(out.mutable_data_ptr<scalar_t>(), \
|
||||
input.const_data_ptr<scalar_t>(), d); \
|
||||
}); \
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
@@ -625,20 +635,20 @@ __device__ __forceinline__ T gelu_quick_kernel(const T& x) {
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void gelu_new(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., d]
|
||||
void gelu_new(torch::stable::Tensor& out, // [..., d]
|
||||
torch::stable::Tensor& input) // [..., d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
|
||||
}
|
||||
|
||||
void gelu_fast(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., d]
|
||||
void gelu_fast(torch::stable::Tensor& out, // [..., d]
|
||||
torch::stable::Tensor& input) // [..., d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
|
||||
}
|
||||
|
||||
void gelu_quick(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., d]
|
||||
void gelu_quick(torch::stable::Tensor& out, // [..., d]
|
||||
torch::stable::Tensor& input) // [..., d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
|
||||
}
|
||||
@@ -166,3 +166,86 @@ torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel,
|
||||
|
||||
torch::stable::Tensor hadacore_transform(torch::stable::Tensor& x,
|
||||
bool inplace);
|
||||
|
||||
// Activation kernels (shared CUDA/ROCm)
|
||||
void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
||||
void silu_and_mul_clamp(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor& input, double limit);
|
||||
void mul_and_silu(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
||||
void gelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
||||
void gelu_tanh_and_mul(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor& input);
|
||||
void fatrelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input,
|
||||
double threshold);
|
||||
void swigluoai_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input,
|
||||
double alpha = 1.702, double limit = 7.0);
|
||||
void gelu_new(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
||||
void gelu_fast(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
||||
void gelu_quick(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
||||
|
||||
// INT8 quantization kernels (shared CUDA/ROCm)
|
||||
void static_scaled_int8_quant(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& scale,
|
||||
std::optional<torch::stable::Tensor> const& azp);
|
||||
|
||||
void dynamic_scaled_int8_quant(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor& scales,
|
||||
std::optional<torch::stable::Tensor> const& azp);
|
||||
|
||||
// FP8 quantization kernels (shared CUDA/ROCm)
|
||||
void static_scaled_fp8_quant(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& scale,
|
||||
std::optional<torch::headeronly::IntHeaderOnlyArrayRef> group_shape =
|
||||
std::nullopt);
|
||||
|
||||
void dynamic_scaled_fp8_quant(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor& scale);
|
||||
|
||||
void dynamic_per_token_scaled_fp8_quant(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor& scale,
|
||||
std::optional<torch::stable::Tensor> const& scale_ub);
|
||||
|
||||
// GPTQ kernels (shared CUDA/ROCm)
|
||||
torch::stable::Tensor gptq_gemm(torch::stable::Tensor a,
|
||||
torch::stable::Tensor b_q_weight,
|
||||
torch::stable::Tensor b_gptq_qzeros,
|
||||
torch::stable::Tensor b_gptq_scales,
|
||||
torch::stable::Tensor b_g_idx, bool use_exllama,
|
||||
bool use_v2_format, int64_t bit);
|
||||
|
||||
void gptq_shuffle(torch::stable::Tensor q_weight, torch::stable::Tensor q_perm,
|
||||
int64_t bit);
|
||||
|
||||
// GGML kernels (shared CUDA/ROCm)
|
||||
torch::stable::Tensor ggml_dequantize(
|
||||
torch::stable::Tensor W, int64_t type, int64_t m, int64_t n,
|
||||
std::optional<torch::headeronly::ScalarType> const& dtype);
|
||||
|
||||
torch::stable::Tensor ggml_mul_mat_vec_a8(torch::stable::Tensor W,
|
||||
torch::stable::Tensor X, int64_t type,
|
||||
int64_t row);
|
||||
|
||||
torch::stable::Tensor ggml_mul_mat_a8(torch::stable::Tensor W,
|
||||
torch::stable::Tensor X, int64_t type,
|
||||
int64_t row);
|
||||
|
||||
torch::stable::Tensor ggml_moe_a8(torch::stable::Tensor X,
|
||||
torch::stable::Tensor W,
|
||||
torch::stable::Tensor sorted_token_ids,
|
||||
torch::stable::Tensor expert_ids,
|
||||
torch::stable::Tensor num_tokens_post_padded,
|
||||
int64_t type, int64_t row, int64_t top_k,
|
||||
int64_t tokens);
|
||||
|
||||
torch::stable::Tensor ggml_moe_a8_vec(torch::stable::Tensor X,
|
||||
torch::stable::Tensor W,
|
||||
torch::stable::Tensor topk_ids,
|
||||
int64_t top_k, int64_t type, int64_t row,
|
||||
int64_t tokens);
|
||||
|
||||
int64_t ggml_moe_get_block_size(int64_t type);
|
||||
|
||||
+180
-162
@@ -1,17 +1,20 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "../../../cuda_compat.h"
|
||||
#include "../../dispatch_utils.h"
|
||||
#include "../../torch_utils.h"
|
||||
|
||||
#include "../../cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
|
||||
#include "ggml-common.h"
|
||||
#include "vecdotq.cuh"
|
||||
#include "dequantize.cuh"
|
||||
#include "mmvq.cuh"
|
||||
#include "mmq.cuh"
|
||||
// NOTE: These headers are intentionally kept in csrc/quantization/gguf/ (not
|
||||
// moved to libtorch_stable) to avoid unnecessary reformatting that would break
|
||||
// git rename detection and pollute blame history.
|
||||
#include "../../../quantization/gguf/ggml-common.h"
|
||||
#include "../../../quantization/gguf/vecdotq.cuh"
|
||||
#include "../../../quantization/gguf/dequantize.cuh"
|
||||
#include "../../../quantization/gguf/mmvq.cuh"
|
||||
#include "../../../quantization/gguf/mmq.cuh"
|
||||
#include "moe.cuh"
|
||||
#include "moe_vec.cuh"
|
||||
|
||||
@@ -71,16 +74,17 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
|
||||
int64_t type, int64_t m, int64_t n,
|
||||
std::optional<at::ScalarType> const& dtype) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
|
||||
auto dtype_ = dtype.value_or(torch::kFloat16);
|
||||
auto options = torch::TensorOptions().dtype(dtype_).device(W.device());
|
||||
at::Tensor DW = torch::empty({m, n}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
torch::stable::Tensor ggml_dequantize(
|
||||
torch::stable::Tensor W, // quant weight
|
||||
int64_t type, int64_t m, int64_t n,
|
||||
std::optional<torch::headeronly::ScalarType> const& dtype) {
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
W.get_device_index());
|
||||
auto dtype_ = dtype.value_or(torch::headeronly::ScalarType::Half);
|
||||
auto DW = torch::stable::empty({m, n}, dtype_, std::nullopt, W.device());
|
||||
cudaStream_t stream = get_current_cuda_stream();
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(DW.scalar_type(), "ggml_dequantize", [&] {
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(DW.scalar_type(), "ggml_dequantize", [&] {
|
||||
auto to_cuda = ggml_get_to_cuda<scalar_t>(type);
|
||||
to_cuda((void*)W.data_ptr(), (scalar_t*)DW.data_ptr(), m * n, stream);
|
||||
});
|
||||
@@ -88,135 +92,142 @@ torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
|
||||
return DW;
|
||||
}
|
||||
|
||||
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
|
||||
torch::Tensor X, // input
|
||||
int64_t type, int64_t row) {
|
||||
torch::stable::Tensor ggml_mul_mat_vec_a8(
|
||||
torch::stable::Tensor W, // quant weight
|
||||
torch::stable::Tensor X, // input
|
||||
int64_t type, int64_t row) {
|
||||
int col = X.sizes()[1];
|
||||
int vecs = X.sizes()[0];
|
||||
const int padded = (col + 512 - 1) / 512 * 512;
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
|
||||
at::Tensor Y = torch::empty({vecs, row}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
|
||||
at::Tensor quant_X = torch::empty({vecs, padded / 32 * 9}, options);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
|
||||
quantize_row_q8_1_cuda<scalar_t>(
|
||||
(scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, vecs, stream);
|
||||
switch (type) {
|
||||
case 2:
|
||||
mul_mat_vec_q4_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 3:
|
||||
mul_mat_vec_q4_1_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 6:
|
||||
mul_mat_vec_q5_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 7:
|
||||
mul_mat_vec_q5_1_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 8:
|
||||
mul_mat_vec_q8_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 10:
|
||||
mul_mat_vec_q2_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 11:
|
||||
mul_mat_vec_q3_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 12:
|
||||
mul_mat_vec_q4_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 13:
|
||||
mul_mat_vec_q5_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 14:
|
||||
mul_mat_vec_q6_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 16:
|
||||
mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 17:
|
||||
mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 18:
|
||||
mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 19:
|
||||
mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 20:
|
||||
mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 21:
|
||||
mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 22:
|
||||
mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 23:
|
||||
mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 29:
|
||||
mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
}
|
||||
});
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
X.get_device_index());
|
||||
auto Y = torch::stable::empty({vecs, row}, X.scalar_type(), std::nullopt,
|
||||
W.device());
|
||||
cudaStream_t stream = get_current_cuda_stream();
|
||||
auto quant_X = torch::stable::empty({vecs, padded / 32 * 9},
|
||||
torch::headeronly::ScalarType::Int,
|
||||
std::nullopt, W.device());
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
|
||||
X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
|
||||
quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
|
||||
(void*)quant_X.data_ptr(), col, vecs,
|
||||
stream);
|
||||
switch (type) {
|
||||
case 2:
|
||||
mul_mat_vec_q4_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 3:
|
||||
mul_mat_vec_q4_1_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 6:
|
||||
mul_mat_vec_q5_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 7:
|
||||
mul_mat_vec_q5_1_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 8:
|
||||
mul_mat_vec_q8_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 10:
|
||||
mul_mat_vec_q2_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 11:
|
||||
mul_mat_vec_q3_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 12:
|
||||
mul_mat_vec_q4_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 13:
|
||||
mul_mat_vec_q5_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 14:
|
||||
mul_mat_vec_q6_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 16:
|
||||
mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 17:
|
||||
mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 18:
|
||||
mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 19:
|
||||
mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 20:
|
||||
mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 21:
|
||||
mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 22:
|
||||
mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 23:
|
||||
mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
case 29:
|
||||
mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
|
||||
break;
|
||||
}
|
||||
});
|
||||
return Y;
|
||||
}
|
||||
|
||||
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
|
||||
torch::Tensor X, // input
|
||||
int64_t type, int64_t row) {
|
||||
torch::stable::Tensor ggml_mul_mat_a8(torch::stable::Tensor W, // quant weight
|
||||
torch::stable::Tensor X, // input
|
||||
int64_t type, int64_t row) {
|
||||
int col = X.sizes()[1];
|
||||
int padded = (col + 512 - 1) / 512 * 512;
|
||||
int batch = X.sizes()[0];
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
|
||||
at::Tensor Y = torch::empty({batch, row}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
|
||||
at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] {
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
X.get_device_index());
|
||||
auto Y = torch::stable::empty({batch, row}, X.scalar_type(), std::nullopt,
|
||||
W.device());
|
||||
cudaStream_t stream = get_current_cuda_stream();
|
||||
auto quant_X = torch::stable::empty({batch, padded / 32 * 9},
|
||||
torch::headeronly::ScalarType::Int,
|
||||
std::nullopt, W.device());
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] {
|
||||
quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
col, batch, stream);
|
||||
|
||||
@@ -276,21 +287,24 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
|
||||
return Y;
|
||||
}
|
||||
|
||||
torch::Tensor ggml_moe_a8(torch::Tensor X, // input
|
||||
torch::Tensor W, // expert weights
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_padded, int64_t type,
|
||||
int64_t row, int64_t top_k, int64_t tokens) {
|
||||
torch::stable::Tensor ggml_moe_a8(torch::stable::Tensor X, // input
|
||||
torch::stable::Tensor W, // expert weights
|
||||
torch::stable::Tensor sorted_token_ids,
|
||||
torch::stable::Tensor expert_ids,
|
||||
torch::stable::Tensor num_tokens_post_padded,
|
||||
int64_t type, int64_t row, int64_t top_k,
|
||||
int64_t tokens) {
|
||||
int col = X.sizes()[1];
|
||||
int padded = (col + 512 - 1) / 512 * 512;
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
|
||||
at::Tensor Y = torch::empty({tokens * top_k, row}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
|
||||
at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_a8", [&] {
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
X.get_device_index());
|
||||
auto Y = torch::stable::empty({tokens * top_k, row}, X.scalar_type(),
|
||||
std::nullopt, W.device());
|
||||
cudaStream_t stream = get_current_cuda_stream();
|
||||
auto quant_X = torch::stable::empty({tokens, padded / 32 * 9},
|
||||
torch::headeronly::ScalarType::Int,
|
||||
std::nullopt, W.device());
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_a8", [&] {
|
||||
quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
col, tokens, stream);
|
||||
switch (type) {
|
||||
@@ -379,19 +393,23 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input
|
||||
return Y;
|
||||
}
|
||||
|
||||
torch::Tensor ggml_moe_a8_vec(torch::Tensor X, // input
|
||||
torch::Tensor W, // expert weights
|
||||
torch::Tensor topk_ids, int64_t top_k,
|
||||
int64_t type, int64_t row, int64_t tokens) {
|
||||
torch::stable::Tensor ggml_moe_a8_vec(
|
||||
torch::stable::Tensor X, // input
|
||||
torch::stable::Tensor W, // expert weights
|
||||
torch::stable::Tensor topk_ids, int64_t top_k, int64_t type, int64_t row,
|
||||
int64_t tokens) {
|
||||
int col = X.sizes()[1];
|
||||
const int padded = (col + 512 - 1) / 512 * 512;
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
|
||||
at::Tensor Y = torch::zeros({tokens * top_k, row}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
|
||||
at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] {
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
X.get_device_index());
|
||||
auto Y = torch::stable::empty({tokens * top_k, row}, X.scalar_type(),
|
||||
std::nullopt, W.device());
|
||||
torch::stable::fill_(Y, 0.0);
|
||||
cudaStream_t stream = get_current_cuda_stream();
|
||||
auto quant_X = torch::stable::empty({tokens, padded / 32 * 9},
|
||||
torch::headeronly::ScalarType::Int,
|
||||
std::nullopt, W.device());
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] {
|
||||
quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
|
||||
(void*)quant_X.data_ptr(), col, tokens,
|
||||
stream);
|
||||
+30
-23
@@ -6,9 +6,8 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include "../../torch_utils.h"
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
@@ -735,7 +734,7 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight,
|
||||
fp_gemm_half_q_half_gptq_kernel kernel =
|
||||
pick_gemm_half_q_half_gptq_kernel(true, m_count, bit);
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const cudaStream_t stream = get_current_cuda_stream();
|
||||
kernel<<<gridDim, blockDim, 0, stream>>>(
|
||||
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c, size_m, size_n, size_k,
|
||||
groups, use_v2_format, b_q_perm);
|
||||
@@ -1164,7 +1163,7 @@ void reconstruct_exllama(const uint32_t* b_q_weight,
|
||||
reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel;
|
||||
}
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const cudaStream_t stream = get_current_cuda_stream();
|
||||
reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>(
|
||||
b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups,
|
||||
use_v2_format, out);
|
||||
@@ -1376,7 +1375,7 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight,
|
||||
kernel = gemm_half_q_half_alt_8bit_kernel;
|
||||
}
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const cudaStream_t stream = get_current_cuda_stream();
|
||||
kernel<<<gridDim, blockDim, 0, stream>>>(
|
||||
(const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx,
|
||||
size_m, size_k / 32 * bit, size_n, use_v2_format);
|
||||
@@ -1485,7 +1484,7 @@ void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros,
|
||||
gridDim.y = DIVIDE(height, 32);
|
||||
}
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const cudaStream_t stream = get_current_cuda_stream();
|
||||
kernel<<<gridDim, blockDim, 0, stream>>>(b_q_weight, b_gptq_scales,
|
||||
b_gptq_qzeros, b_g_idx, height,
|
||||
width, groups, use_v2_format, out);
|
||||
@@ -1794,7 +1793,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height,
|
||||
} else if (bit == 8) {
|
||||
kernel = make_sequential_8bit_kernel;
|
||||
}
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const cudaStream_t stream = get_current_cuda_stream();
|
||||
kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, new_qweight, q_perm,
|
||||
width);
|
||||
// Replace qweights
|
||||
@@ -1818,29 +1817,34 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height,
|
||||
} else if (bit == 8) {
|
||||
shuffle_kernel = shuffle_8bit_kernel;
|
||||
}
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const cudaStream_t stream = get_current_cuda_stream();
|
||||
shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
|
||||
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
|
||||
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
torch::Tensor b_gptq_qzeros,
|
||||
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
|
||||
bool use_exllama, bool use_v2_format, int64_t bit) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
||||
at::Tensor c = torch::zeros({a.size(0), b_q_weight.size(1)}, options);
|
||||
at::Tensor temp_dq = torch::empty(
|
||||
{b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options);
|
||||
torch::stable::Tensor gptq_gemm(torch::stable::Tensor a,
|
||||
torch::stable::Tensor b_q_weight,
|
||||
torch::stable::Tensor b_gptq_qzeros,
|
||||
torch::stable::Tensor b_gptq_scales,
|
||||
torch::stable::Tensor b_g_idx, bool use_exllama,
|
||||
bool use_v2_format, int64_t bit) {
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
a.get_device_index());
|
||||
auto c = torch::stable::new_zeros(a, {a.size(0), b_q_weight.size(1)});
|
||||
auto temp_dq =
|
||||
torch::stable::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)},
|
||||
a.scalar_type(), std::nullopt, a.device());
|
||||
|
||||
vllm::gptq::gemm_half_q_half_cuda(
|
||||
at::cuda::getCurrentCUDABlasHandle(), (const half*)a.data_ptr(),
|
||||
get_current_cuda_blas_handle(), (const half*)a.data_ptr(),
|
||||
(const uint32_t*)b_q_weight.data_ptr(),
|
||||
(const uint32_t*)b_gptq_qzeros.data_ptr(),
|
||||
(const half*)b_gptq_scales.data_ptr(),
|
||||
b_g_idx.device().is_meta() ? NULL : (const int*)b_g_idx.data_ptr(),
|
||||
b_g_idx.device().type() == torch::stable::DeviceType::Meta
|
||||
? NULL
|
||||
: (const int*)b_g_idx.data_ptr(),
|
||||
(half*)c.data_ptr(), (half*)temp_dq.data_ptr(),
|
||||
c.size(0), // m
|
||||
c.size(1), // n
|
||||
@@ -1850,11 +1854,14 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
return c;
|
||||
}
|
||||
|
||||
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
|
||||
void gptq_shuffle(torch::stable::Tensor q_weight, torch::stable::Tensor q_perm,
|
||||
int64_t bit) {
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
q_weight.get_device_index());
|
||||
vllm::gptq::shuffle_exllama_weight(
|
||||
(uint32_t*)q_weight.data_ptr(),
|
||||
q_perm.device().is_meta() || q_perm.numel() == 0
|
||||
q_perm.device().type() == torch::stable::DeviceType::Meta ||
|
||||
q_perm.numel() == 0
|
||||
? NULL
|
||||
: (int*)q_perm.data_ptr(),
|
||||
q_weight.size(0) * 32 / bit, q_weight.size(1), bit);
|
||||
+107
-88
@@ -1,11 +1,9 @@
|
||||
#include "common.cuh"
|
||||
#include "dispatch_utils.h"
|
||||
#include "cub_helpers.h"
|
||||
#include "libtorch_stable/quantization/vectorization_utils.cuh"
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <tuple>
|
||||
|
||||
#include "../../../../quantization/w8a8/fp8/common.cuh"
|
||||
#include "../../../dispatch_utils.h"
|
||||
#include "../../../../cub_helpers.h"
|
||||
#include "../../vectorization_utils.cuh"
|
||||
#include "../../../torch_utils.h"
|
||||
#include <torch/csrc/stable/macros.h>
|
||||
namespace vllm {
|
||||
|
||||
// STRIDE_I_ZERO: true if scale_stride_i == 0 (per-tensor or per-channel)
|
||||
@@ -183,16 +181,16 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
|
||||
} // namespace vllm
|
||||
|
||||
void static_scaled_fp8_quant(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor const& scale, // various shapes
|
||||
std::optional<std::tuple<int64_t, int64_t>>
|
||||
opt_group_shape) // optional explicit (group_m, group_n)
|
||||
torch::stable::Tensor& out, // [..., d]
|
||||
torch::stable::Tensor const& input, // [..., d]
|
||||
torch::stable::Tensor const& scale, // various shapes
|
||||
std::optional<torch::headeronly::IntHeaderOnlyArrayRef>
|
||||
opt_group_shape) // optional explicit [group_m, group_n]
|
||||
{
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
STD_TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
STD_TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
|
||||
const int hidden_size = input.size(-1); // N (columns)
|
||||
const int num_tokens = input.numel() / hidden_size; // M (rows)
|
||||
@@ -212,13 +210,18 @@ void static_scaled_fp8_quant(
|
||||
} else if (scale.dim() == 1) {
|
||||
// 1D scale: require explicit group_shape to disambiguate per-channel vs
|
||||
// per-token (avoids edge case where num_tokens == hidden_size)
|
||||
TORCH_CHECK(opt_group_shape.has_value(),
|
||||
"1D scale requires explicit group_shape to disambiguate "
|
||||
"per-channel vs per-token quantization. "
|
||||
"Use group_shape=(-1, 1) for per-channel or group_shape=(1, "
|
||||
"-1) for per-token.");
|
||||
STD_TORCH_CHECK(
|
||||
opt_group_shape.has_value(),
|
||||
"1D scale requires explicit group_shape to disambiguate "
|
||||
"per-channel vs per-token quantization. "
|
||||
"Use group_shape=(-1, 1) for per-channel or group_shape=(1, "
|
||||
"-1) for per-token.");
|
||||
STD_TORCH_CHECK(opt_group_shape->size() == 2,
|
||||
"group_shape must have exactly 2 elements, got ",
|
||||
opt_group_shape->size());
|
||||
|
||||
const auto& [opt_group_m, opt_group_n] = opt_group_shape.value();
|
||||
const auto opt_group_m = (*opt_group_shape)[0];
|
||||
const auto opt_group_n = (*opt_group_shape)[1];
|
||||
group_m = opt_group_m == -1 ? num_tokens : static_cast<int>(opt_group_m);
|
||||
group_n = opt_group_n == -1 ? hidden_size : static_cast<int>(opt_group_n);
|
||||
|
||||
@@ -228,11 +231,11 @@ void static_scaled_fp8_quant(
|
||||
const int64_t expected_scale_n = hidden_size / group_n;
|
||||
const int64_t expected_scale_numel = expected_scale_m * expected_scale_n;
|
||||
|
||||
TORCH_CHECK(scale_len == expected_scale_numel, "1D scale length (",
|
||||
scale_len, ") does not match expected size (",
|
||||
expected_scale_numel, ") for group_shape (", opt_group_m, ", ",
|
||||
opt_group_n, ") with input shape (", num_tokens, ", ",
|
||||
hidden_size, ")");
|
||||
STD_TORCH_CHECK(scale_len == expected_scale_numel, "1D scale length (",
|
||||
scale_len, ") does not match expected size (",
|
||||
expected_scale_numel, ") for group_shape (", opt_group_m,
|
||||
", ", opt_group_n, ") with input shape (", num_tokens, ", ",
|
||||
hidden_size, ")");
|
||||
|
||||
// For 1D scale, determine strides based on which dim is trivial
|
||||
// Scale indexing: scale[gi * scale_stride_i + gj * scale_stride_j]
|
||||
@@ -248,7 +251,7 @@ void static_scaled_fp8_quant(
|
||||
scale_stride_i = scale.stride(0);
|
||||
scale_stride_j = 0;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
STD_TORCH_CHECK(
|
||||
false,
|
||||
"1D scale can only be used when one of the scale dimensions is 1. "
|
||||
"For 2D group scaling, use a 2D scale tensor.");
|
||||
@@ -259,10 +262,12 @@ void static_scaled_fp8_quant(
|
||||
const int64_t scale_size_0 = scale.size(0);
|
||||
const int64_t scale_size_1 = scale.size(1);
|
||||
|
||||
TORCH_CHECK(num_tokens % scale_size_0 == 0, "num_tokens (", num_tokens,
|
||||
") must be divisible by scale.size(0) (", scale_size_0, ")");
|
||||
TORCH_CHECK(hidden_size % scale_size_1 == 0, "hidden_size (", hidden_size,
|
||||
") must be divisible by scale.size(1) (", scale_size_1, ")");
|
||||
STD_TORCH_CHECK(num_tokens % scale_size_0 == 0, "num_tokens (", num_tokens,
|
||||
") must be divisible by scale.size(0) (", scale_size_0,
|
||||
")");
|
||||
STD_TORCH_CHECK(hidden_size % scale_size_1 == 0, "hidden_size (",
|
||||
hidden_size, ") must be divisible by scale.size(1) (",
|
||||
scale_size_1, ")");
|
||||
|
||||
// Infer from 2D scale shape
|
||||
int inferred_group_m = num_tokens / scale_size_0;
|
||||
@@ -270,16 +275,21 @@ void static_scaled_fp8_quant(
|
||||
|
||||
// Use explicit if provided, otherwise use inferred
|
||||
if (opt_group_shape.has_value()) {
|
||||
const auto& [opt_group_m, opt_group_n] = opt_group_shape.value();
|
||||
STD_TORCH_CHECK(opt_group_shape->size() == 2,
|
||||
"group_shape must have exactly 2 elements, got ",
|
||||
opt_group_shape->size());
|
||||
const auto opt_group_m = (*opt_group_shape)[0];
|
||||
const auto opt_group_n = (*opt_group_shape)[1];
|
||||
group_m = opt_group_m == -1 ? num_tokens : static_cast<int>(opt_group_m);
|
||||
group_n = opt_group_n == -1 ? hidden_size : static_cast<int>(opt_group_n);
|
||||
|
||||
// Validate explicit matches inferred
|
||||
TORCH_CHECK(group_m == inferred_group_m && group_n == inferred_group_n,
|
||||
"Explicit group_shape (", opt_group_m, ", ", opt_group_n,
|
||||
") does not match inferred group shape (", inferred_group_m,
|
||||
", ", inferred_group_n, ") from 2D scale tensor shape (",
|
||||
scale_size_0, ", ", scale_size_1, ")");
|
||||
STD_TORCH_CHECK(
|
||||
group_m == inferred_group_m && group_n == inferred_group_n,
|
||||
"Explicit group_shape (", opt_group_m, ", ", opt_group_n,
|
||||
") does not match inferred group shape (", inferred_group_m, ", ",
|
||||
inferred_group_n, ") from 2D scale tensor shape (", scale_size_0,
|
||||
", ", scale_size_1, ")");
|
||||
} else {
|
||||
group_m = inferred_group_m;
|
||||
group_n = inferred_group_n;
|
||||
@@ -288,8 +298,8 @@ void static_scaled_fp8_quant(
|
||||
scale_stride_i = scale.stride(0);
|
||||
scale_stride_j = scale.stride(1);
|
||||
} else {
|
||||
TORCH_CHECK(false, "scale must be 0D, 1D, or 2D tensor, but got ",
|
||||
scale.dim(), "D");
|
||||
STD_TORCH_CHECK(false, "scale must be 0D, 1D, or 2D tensor, but got ",
|
||||
scale.dim(), "D");
|
||||
}
|
||||
|
||||
const int block_size = 256;
|
||||
@@ -299,37 +309,39 @@ void static_scaled_fp8_quant(
|
||||
const int64_t in_row_stride = input.stride(-2);
|
||||
const int64_t out_row_stride = out.stride(-2);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream();
|
||||
|
||||
// Dispatch to template-specialized kernel based on stride pattern
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
VLLM_STABLE_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
VLLM_DISPATCH_BOOL(scale_stride_i == 0, S0_ZERO, [&] {
|
||||
VLLM_DISPATCH_BOOL(scale_stride_j == 0, S1_ZERO, [&] {
|
||||
VLLM_STABLE_DISPATCH_BOOL(scale_stride_i == 0, S0_ZERO, [&] {
|
||||
VLLM_STABLE_DISPATCH_BOOL(scale_stride_j == 0, S1_ZERO, [&] {
|
||||
vllm::scaled_fp8_quant_kernel_strided_group_shape<
|
||||
scalar_t, fp8_t, S0_ZERO, S1_ZERO>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), hidden_size, in_row_stride,
|
||||
out_row_stride, group_m, group_n, scale_stride_i,
|
||||
scale_stride_j);
|
||||
out.mutable_data_ptr<fp8_t>(),
|
||||
input.const_data_ptr<scalar_t>(),
|
||||
scale.const_data_ptr<float>(), hidden_size,
|
||||
in_row_stride, out_row_stride, group_m, group_n,
|
||||
scale_stride_i, scale_stride_j);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
void dynamic_scaled_fp8_quant(torch::stable::Tensor& out, // [..., d]
|
||||
torch::stable::Tensor const& input, // [..., d]
|
||||
torch::stable::Tensor& scale) // [1]
|
||||
{
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
STD_TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
STD_TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
@@ -340,40 +352,43 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
const int64_t in_row_stride = input.stride(-2);
|
||||
const int64_t out_row_stride = out.stride(-2);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream();
|
||||
|
||||
// scale tensor should be initialised to <=0 before reduction
|
||||
AT_CUDA_CHECK(
|
||||
cudaMemsetAsync(scale.data_ptr<float>(), 0, sizeof(float), stream));
|
||||
STD_CUDA_CHECK(cudaMemsetAsync(scale.mutable_data_ptr<float>(), 0,
|
||||
sizeof(float), stream));
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
VLLM_STABLE_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::segmented_max_reduction_strided<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
|
||||
hidden_size, in_row_stride,
|
||||
static_cast<int64_t>(num_tokens));
|
||||
scale.mutable_data_ptr<float>(),
|
||||
input.const_data_ptr<scalar_t>(), hidden_size,
|
||||
in_row_stride, static_cast<int64_t>(num_tokens));
|
||||
|
||||
vllm::scaled_fp8_quant_kernel_strided_dynamic<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), hidden_size, in_row_stride,
|
||||
out_row_stride);
|
||||
<<<grid, block, 0, stream>>>(out.mutable_data_ptr<fp8_t>(),
|
||||
input.const_data_ptr<scalar_t>(),
|
||||
scale.const_data_ptr<float>(),
|
||||
hidden_size, in_row_stride,
|
||||
out_row_stride);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void dynamic_per_token_scaled_fp8_quant(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
torch::stable::Tensor& out, // [..., d]
|
||||
torch::stable::Tensor const& input, // [..., d]
|
||||
torch::stable::Tensor& scales,
|
||||
std::optional<torch::stable::Tensor> const& scale_ub) {
|
||||
STD_TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
STD_TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
@@ -384,20 +399,24 @@ void dynamic_per_token_scaled_fp8_quant(
|
||||
const int64_t in_row_stride = input.stride(-2);
|
||||
const int64_t out_row_stride = out.stride(-2);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream();
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
VLLM_STABLE_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(),
|
||||
"dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided<
|
||||
scalar_t, fp8_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
||||
hidden_size, in_row_stride, out_row_stride);
|
||||
vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided<scalar_t,
|
||||
fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.mutable_data_ptr<fp8_t>(),
|
||||
scales.mutable_data_ptr<float>(),
|
||||
input.const_data_ptr<scalar_t>(),
|
||||
scale_ub.has_value() ? scale_ub->const_data_ptr<float>()
|
||||
: nullptr,
|
||||
hidden_size, in_row_stride, out_row_stride);
|
||||
});
|
||||
});
|
||||
}
|
||||
+46
-40
@@ -1,12 +1,11 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
#include "libtorch_stable/quantization/vectorization_utils.cuh"
|
||||
#include "cub_helpers.h"
|
||||
#include "../../../dispatch_utils.h"
|
||||
#include "../../../torch_utils.h"
|
||||
#include "../../vectorization_utils.cuh"
|
||||
#include "../../../../cub_helpers.h"
|
||||
|
||||
static inline __device__ int8_t float_to_int8_rn(float x) {
|
||||
#ifdef USE_ROCM
|
||||
@@ -263,66 +262,73 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor const& input, // [..., hidden_size]
|
||||
torch::Tensor const& scale,
|
||||
std::optional<torch::Tensor> const& azp) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(scale.numel() == 1);
|
||||
TORCH_CHECK(!azp || azp->numel() == 1);
|
||||
void static_scaled_int8_quant(
|
||||
torch::stable::Tensor& out, // [..., hidden_size]
|
||||
torch::stable::Tensor const& input, // [..., hidden_size]
|
||||
torch::stable::Tensor const& scale,
|
||||
std::optional<torch::stable::Tensor> const& azp) {
|
||||
STD_TORCH_CHECK(input.is_contiguous());
|
||||
STD_TORCH_CHECK(out.is_contiguous());
|
||||
STD_TORCH_CHECK(scale.numel() == 1);
|
||||
STD_TORCH_CHECK(!azp || azp->numel() == 1);
|
||||
|
||||
int const hidden_size = input.size(-1);
|
||||
int const num_tokens = input.numel() / hidden_size;
|
||||
dim3 const grid(num_tokens);
|
||||
dim3 const block(std::min(hidden_size, 256));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream();
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
|
||||
if (!azp) {
|
||||
vllm::static_scaled_int8_quant_kernel<scalar_t, float>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), hidden_size);
|
||||
<<<grid, block, 0, stream>>>(input.const_data_ptr<scalar_t>(),
|
||||
out.mutable_data_ptr<int8_t>(),
|
||||
scale.const_data_ptr<float>(),
|
||||
hidden_size);
|
||||
} else {
|
||||
vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
|
||||
hidden_size);
|
||||
input.const_data_ptr<scalar_t>(),
|
||||
out.mutable_data_ptr<int8_t>(), scale.const_data_ptr<float>(),
|
||||
azp->const_data_ptr<int32_t>(), hidden_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void dynamic_scaled_int8_quant(
|
||||
torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor const& input, // [..., hidden_size]
|
||||
torch::Tensor& scales, std::optional<torch::Tensor> const& azp) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(scales.is_contiguous());
|
||||
TORCH_CHECK(!azp || azp->is_contiguous());
|
||||
torch::stable::Tensor& out, // [..., hidden_size]
|
||||
torch::stable::Tensor const& input, // [..., hidden_size]
|
||||
torch::stable::Tensor& scales,
|
||||
std::optional<torch::stable::Tensor> const& azp) {
|
||||
STD_TORCH_CHECK(input.is_contiguous());
|
||||
STD_TORCH_CHECK(out.is_contiguous());
|
||||
STD_TORCH_CHECK(scales.is_contiguous());
|
||||
STD_TORCH_CHECK(!azp || azp->is_contiguous());
|
||||
|
||||
int const hidden_size = input.size(-1);
|
||||
int const num_tokens = input.numel() / hidden_size;
|
||||
dim3 const grid(num_tokens);
|
||||
dim3 const block(std::min(hidden_size, 256));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream();
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
|
||||
if (!azp) {
|
||||
vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scales.data_ptr<float>(), hidden_size);
|
||||
<<<grid, block, 0, stream>>>(input.const_data_ptr<scalar_t>(),
|
||||
out.mutable_data_ptr<int8_t>(),
|
||||
scales.mutable_data_ptr<float>(),
|
||||
hidden_size);
|
||||
} else {
|
||||
vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
|
||||
hidden_size);
|
||||
<<<grid, block, 0, stream>>>(input.const_data_ptr<scalar_t>(),
|
||||
out.mutable_data_ptr<int8_t>(),
|
||||
scales.mutable_data_ptr<float>(),
|
||||
azp->mutable_data_ptr<int32_t>(),
|
||||
hidden_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -266,6 +266,109 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
// Hadamard transforms
|
||||
// conditionally compiled so impl registration is in source file
|
||||
ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");
|
||||
|
||||
// Activation ops
|
||||
// Activation function used in SwiGLU.
|
||||
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
|
||||
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
|
||||
|
||||
// SwiGLU activation with input clamping.
|
||||
ops.def(
|
||||
"silu_and_mul_with_clamp(Tensor! result, Tensor input, float limit) "
|
||||
"-> ()");
|
||||
|
||||
// Activation function used in GeGLU with `none` approximation.
|
||||
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
|
||||
|
||||
// Activation function used in GeGLU with `tanh` approximation.
|
||||
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
|
||||
|
||||
// FATReLU implementation.
|
||||
ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
|
||||
|
||||
ops.def(
|
||||
"swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float "
|
||||
"limit=7.0) "
|
||||
"-> ()");
|
||||
|
||||
// GELU implementation used in GPT-2.
|
||||
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
|
||||
|
||||
// Approximate GELU implementation.
|
||||
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
|
||||
|
||||
// Quick GELU implementation.
|
||||
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
|
||||
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
"static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
|
||||
"Tensor? azp) -> ()");
|
||||
|
||||
// Compute int8 quantized tensor and scaling factor
|
||||
ops.def(
|
||||
"dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, "
|
||||
"Tensor!? azp) -> ()");
|
||||
|
||||
// Compute FP8 quantized tensor for given scaling factor.
|
||||
// Supports per-tensor, per-channel, per-token, and arbitrary 2D group
|
||||
// scaling. Optional group_m/group_n specify the group shape explicitly;
|
||||
// required for 1D scales to disambiguate per-channel vs per-token.
|
||||
ops.def(
|
||||
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, "
|
||||
"int[]? group_shape=None) -> ()");
|
||||
|
||||
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
|
||||
ops.def(
|
||||
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
|
||||
"-> "
|
||||
"()");
|
||||
|
||||
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
|
||||
ops.def(
|
||||
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
|
||||
"Tensor! scale, Tensor? scale_ub) -> "
|
||||
"()");
|
||||
|
||||
// Quantized GEMM for GPTQ.
|
||||
// Note: even though the C++ inferred schema is correct for this op, it seems
|
||||
// to prevent the meta function registry.
|
||||
ops.def(
|
||||
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
|
||||
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool "
|
||||
"use_v2_format, int bit) "
|
||||
"-> Tensor");
|
||||
|
||||
// Post processing for GPTQ.
|
||||
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
|
||||
|
||||
// Dequantization for GGML.
|
||||
ops.def(
|
||||
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
|
||||
"dtype) -> Tensor");
|
||||
|
||||
// mmvq kernel for GGML.
|
||||
ops.def(
|
||||
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
|
||||
"-> Tensor");
|
||||
|
||||
// mmq kernel for GGML.
|
||||
ops.def(
|
||||
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
|
||||
|
||||
// moe kernel for GGML.
|
||||
ops.def(
|
||||
"ggml_moe_a8(Tensor X, Tensor W, "
|
||||
"Tensor sorted_token_ids, Tensor expert_ids, Tensor "
|
||||
"num_tokens_post_padded, "
|
||||
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
|
||||
|
||||
ops.def(
|
||||
"ggml_moe_a8_vec(Tensor X, Tensor W, "
|
||||
"Tensor topk_ids, int top_k, "
|
||||
"int type, SymInt row, SymInt tokens) -> Tensor");
|
||||
|
||||
ops.def("ggml_moe_get_block_size(int type) -> int");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
@@ -312,6 +415,39 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
// AllSpark ops: conditionally compiled so impl registrations are in source
|
||||
// files (allspark_repack.cu and allspark_qgemm_w8a16.cu)
|
||||
#endif
|
||||
|
||||
// Activation kernels (shared CUDA/ROCm)
|
||||
ops.impl("silu_and_mul", TORCH_BOX(&silu_and_mul));
|
||||
ops.impl("mul_and_silu", TORCH_BOX(&mul_and_silu));
|
||||
ops.impl("gelu_and_mul", TORCH_BOX(&gelu_and_mul));
|
||||
ops.impl("gelu_tanh_and_mul", TORCH_BOX(&gelu_tanh_and_mul));
|
||||
ops.impl("fatrelu_and_mul", TORCH_BOX(&fatrelu_and_mul));
|
||||
ops.impl("swigluoai_and_mul", TORCH_BOX(&swigluoai_and_mul));
|
||||
ops.impl("gelu_new", TORCH_BOX(&gelu_new));
|
||||
ops.impl("gelu_fast", TORCH_BOX(&gelu_fast));
|
||||
ops.impl("gelu_quick", TORCH_BOX(&gelu_quick));
|
||||
ops.impl("silu_and_mul_with_clamp", TORCH_BOX(&silu_and_mul_clamp));
|
||||
|
||||
// INT8 quantization kernels
|
||||
ops.impl("static_scaled_int8_quant", TORCH_BOX(&static_scaled_int8_quant));
|
||||
ops.impl("dynamic_scaled_int8_quant", TORCH_BOX(&dynamic_scaled_int8_quant));
|
||||
|
||||
// FP8 quantization kernels
|
||||
ops.impl("static_scaled_fp8_quant", TORCH_BOX(&static_scaled_fp8_quant));
|
||||
ops.impl("dynamic_scaled_fp8_quant", TORCH_BOX(&dynamic_scaled_fp8_quant));
|
||||
ops.impl("dynamic_per_token_scaled_fp8_quant",
|
||||
TORCH_BOX(&dynamic_per_token_scaled_fp8_quant));
|
||||
|
||||
// GPTQ kernels
|
||||
ops.impl("gptq_gemm", TORCH_BOX(&gptq_gemm));
|
||||
ops.impl("gptq_shuffle", TORCH_BOX(&gptq_shuffle));
|
||||
|
||||
// GGML kernels
|
||||
ops.impl("ggml_dequantize", TORCH_BOX(&ggml_dequantize));
|
||||
ops.impl("ggml_mul_mat_vec_a8", TORCH_BOX(&ggml_mul_mat_vec_a8));
|
||||
ops.impl("ggml_mul_mat_a8", TORCH_BOX(&ggml_mul_mat_a8));
|
||||
ops.impl("ggml_moe_a8", TORCH_BOX(&ggml_moe_a8));
|
||||
ops.impl("ggml_moe_a8_vec", TORCH_BOX(&ggml_moe_a8_vec));
|
||||
}
|
||||
|
||||
// These capability-check functions take only primitive args (no tensors), so
|
||||
@@ -329,6 +465,9 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
|
||||
ops.impl("cutlass_scaled_mm_supports_fp4",
|
||||
TORCH_BOX(&cutlass_scaled_mm_supports_fp4));
|
||||
#endif
|
||||
|
||||
// GGML block size lookup (no tensor args)
|
||||
ops.impl("ggml_moe_get_block_size", TORCH_BOX(&ggml_moe_get_block_size));
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(_C_stable_libtorch)
|
||||
|
||||
@@ -6,8 +6,12 @@
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_runtime.h>
|
||||
#else
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <deque>
|
||||
#include <mutex>
|
||||
|
||||
-47
@@ -149,17 +149,10 @@ void persistent_masked_m_silu_mul_quant(
|
||||
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
|
||||
bool use_ue8m0);
|
||||
|
||||
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
|
||||
double threshold);
|
||||
void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input,
|
||||
double alpha = 1.702, double limit = 7.0);
|
||||
|
||||
void gelu_new(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_fast(torch::Tensor& out, torch::Tensor& input);
|
||||
@@ -174,28 +167,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
|
||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
|
||||
|
||||
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
|
||||
int64_t n,
|
||||
std::optional<at::ScalarType> const& dtype);
|
||||
|
||||
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
|
||||
int64_t type, int64_t row);
|
||||
|
||||
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
|
||||
int64_t row);
|
||||
|
||||
torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_padded, int64_t type,
|
||||
int64_t row, int64_t top_k, int64_t tokens);
|
||||
|
||||
torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
|
||||
torch::Tensor topk_ids, int64_t top_k,
|
||||
int64_t type, int64_t row, int64_t tokens);
|
||||
|
||||
int64_t ggml_moe_get_block_size(int64_t type);
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor const& scale,
|
||||
std::optional<torch::Tensor> const& azp);
|
||||
@@ -204,24 +175,6 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor& scales,
|
||||
std::optional<torch::Tensor> const& azp);
|
||||
|
||||
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
torch::Tensor b_gptq_qzeros,
|
||||
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
|
||||
bool use_exllama, bool use_v2_format, int64_t bit);
|
||||
|
||||
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
|
||||
|
||||
void static_scaled_fp8_quant(
|
||||
torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale,
|
||||
std::optional<std::tuple<int64_t, int64_t>> group_shape = std::nullopt);
|
||||
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
void dynamic_per_token_scaled_fp8_quant(
|
||||
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
|
||||
std::optional<torch::Tensor> const& scale_ub);
|
||||
|
||||
void selective_scan_fwd(
|
||||
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
|
||||
const torch::Tensor& B, const torch::Tensor& C,
|
||||
|
||||
+16
-14
@@ -7,23 +7,23 @@
|
||||
*/
|
||||
|
||||
#include <cmath>
|
||||
#include <torch/types.h>
|
||||
#include <torch/headeronly/macros/Macros.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <torch/headeronly/util/Float8_e4m3fn.h>
|
||||
#define MAYBE_HOST_DEVICE C10_HOST_DEVICE
|
||||
#else
|
||||
#include <ATen/hip/HIPContext.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <c10/util/Float8_e4m3fnuz.h>
|
||||
#include <torch/headeronly/util/Float8_e4m3fn.h>
|
||||
#include <torch/headeronly/util/Float8_e4m3fnuz.h>
|
||||
// ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
|
||||
#define MAYBE_HOST_DEVICE
|
||||
#endif
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
|
||||
std::is_same_v<T, c10::Float8_e4m3fnuz> ||
|
||||
std::is_same_v<T, int8_t>>>
|
||||
typename = std::enable_if_t<
|
||||
std::is_same_v<T, torch::headeronly::Float8_e4m3fn> ||
|
||||
std::is_same_v<T, torch::headeronly::Float8_e4m3fnuz> ||
|
||||
std::is_same_v<T, int8_t>>>
|
||||
struct quant_type_max {
|
||||
static constexpr T val() { return std::numeric_limits<T>::max(); }
|
||||
};
|
||||
@@ -31,9 +31,10 @@ struct quant_type_max {
|
||||
// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
|
||||
// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
|
||||
template <>
|
||||
struct quant_type_max<c10::Float8_e4m3fnuz> {
|
||||
static constexpr c10::Float8_e4m3fnuz val() {
|
||||
return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
|
||||
struct quant_type_max<torch::headeronly::Float8_e4m3fnuz> {
|
||||
static constexpr torch::headeronly::Float8_e4m3fnuz val() {
|
||||
return torch::headeronly::Float8_e4m3fnuz(
|
||||
0x7E, torch::headeronly::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
};
|
||||
|
||||
@@ -42,9 +43,10 @@ MAYBE_HOST_DEVICE static constexpr T quant_type_max_v =
|
||||
quant_type_max<T>::val();
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
|
||||
std::is_same_v<T, c10::Float8_e4m3fnuz> ||
|
||||
std::is_same_v<T, int8_t>>>
|
||||
typename = std::enable_if_t<
|
||||
std::is_same_v<T, torch::headeronly::Float8_e4m3fn> ||
|
||||
std::is_same_v<T, torch::headeronly::Float8_e4m3fnuz> ||
|
||||
std::is_same_v<T, int8_t>>>
|
||||
struct min_scaling_factor {
|
||||
C10_DEVICE C10_ALWAYS_INLINE static float val() {
|
||||
return 1.0f / (quant_type_max_v<T> * 512.0f);
|
||||
|
||||
@@ -5,6 +5,19 @@
|
||||
|
||||
#include <cmath>
|
||||
|
||||
// This header is shared between _C and _C_stable_libtorch targets.
|
||||
// torch_utils.h provides get_device_prop(). We need to pass USE_CUDA
|
||||
// to the .so to expose some of the shims used by torch_utils.h. For now
|
||||
// this is only done for _C_stable_libtorch and not for _C, so we use the
|
||||
// non stable at::cuda::getCurrentDeviceProperties for _C for now.
|
||||
#ifdef TORCH_TARGET_VERSION
|
||||
#include "../../../libtorch_stable/torch_utils.h"
|
||||
#else
|
||||
#ifdef USE_ROCM
|
||||
#include <ATen/hip/HIPContext.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include "nvidia/quant_utils.cuh"
|
||||
#else
|
||||
@@ -18,7 +31,11 @@ static bool is_fp8_ocp() {
|
||||
#ifndef USE_ROCM
|
||||
return true;
|
||||
#else
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
#ifdef TORCH_TARGET_VERSION
|
||||
auto* dprops = get_device_prop();
|
||||
#else
|
||||
auto* dprops = at::cuda::getCurrentDeviceProperties();
|
||||
#endif
|
||||
std::string device_arch = dprops->gcnArchName;
|
||||
size_t substring = device_arch.find("gfx94");
|
||||
return substring == std::string::npos;
|
||||
|
||||
+1
-128
@@ -77,17 +77,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor? output_scale=None) -> ()");
|
||||
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
|
||||
|
||||
// Activation ops
|
||||
// Activation function used in SwiGLU.
|
||||
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
|
||||
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
||||
|
||||
// SwiGLU activation with input clamping.
|
||||
ops.def(
|
||||
"silu_and_mul_with_clamp(Tensor! result, Tensor input, float limit) "
|
||||
"-> ()");
|
||||
ops.impl("silu_and_mul_with_clamp", torch::kCUDA, &silu_and_mul_clamp);
|
||||
|
||||
// Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
|
||||
ops.def(
|
||||
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
||||
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
|
||||
@@ -104,39 +94,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
|
||||
&silu_and_mul_per_block_quant);
|
||||
|
||||
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
|
||||
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
|
||||
|
||||
// Activation function used in GeGLU with `none` approximation.
|
||||
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
|
||||
ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
|
||||
|
||||
// Activation function used in GeGLU with `tanh` approximation.
|
||||
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
|
||||
ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
|
||||
|
||||
// FATReLU implementation.
|
||||
ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
|
||||
ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);
|
||||
|
||||
ops.def(
|
||||
"swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float "
|
||||
"limit=7.0) "
|
||||
"-> ()");
|
||||
ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul);
|
||||
|
||||
// GELU implementation used in GPT-2.
|
||||
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
|
||||
ops.impl("gelu_new", torch::kCUDA, &gelu_new);
|
||||
|
||||
// Approximate GELU implementation.
|
||||
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
|
||||
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
|
||||
|
||||
// Quick GELU implementation.
|
||||
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
|
||||
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
|
||||
|
||||
// Layernorm
|
||||
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
||||
ops.def(
|
||||
@@ -318,39 +275,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
#endif
|
||||
|
||||
// Dequantization for GGML.
|
||||
ops.def(
|
||||
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
|
||||
"dtype) -> Tensor");
|
||||
ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
|
||||
|
||||
// mmvq kernel for GGML.
|
||||
ops.def(
|
||||
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
|
||||
"-> Tensor");
|
||||
ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
|
||||
|
||||
// mmq kernel for GGML.
|
||||
ops.def(
|
||||
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
|
||||
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
|
||||
|
||||
// moe kernel for GGML.
|
||||
ops.def(
|
||||
"ggml_moe_a8(Tensor X, Tensor W, "
|
||||
"Tensor sorted_token_ids, Tensor expert_ids, Tensor "
|
||||
"num_tokens_post_padded, "
|
||||
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
|
||||
ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
|
||||
|
||||
ops.def(
|
||||
"ggml_moe_a8_vec(Tensor X, Tensor W, "
|
||||
"Tensor topk_ids, int top_k, "
|
||||
"int type, SymInt row, SymInt tokens) -> Tensor");
|
||||
ops.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);
|
||||
|
||||
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Expert-specialization mxfp8 blockscaled grouped quantization (SM100+).
|
||||
ops.def(
|
||||
@@ -370,57 +294,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
#endif
|
||||
|
||||
// Quantized GEMM for GPTQ.
|
||||
// Note: even though the C++ inferred schema is correct for this op, it seems
|
||||
// to prevent the meta function registry.
|
||||
ops.def(
|
||||
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
|
||||
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool "
|
||||
"use_v2_format, int bit) "
|
||||
"-> Tensor");
|
||||
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
|
||||
|
||||
// Post processing for GPTQ.
|
||||
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
|
||||
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
|
||||
|
||||
// Compute FP8 quantized tensor for given scaling factor.
|
||||
// Supports per-tensor, per-channel, per-token, and arbitrary 2D group
|
||||
// scaling. Optional group_m/group_n specify the group shape explicitly;
|
||||
// required for 1D scales to disambiguate per-channel vs per-token.
|
||||
ops.def(
|
||||
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, "
|
||||
"(int, int)? group_shape=None) -> ()");
|
||||
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
|
||||
|
||||
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
|
||||
ops.def(
|
||||
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
|
||||
"-> "
|
||||
"()");
|
||||
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
|
||||
|
||||
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
|
||||
ops.def(
|
||||
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
|
||||
"Tensor! scale, Tensor? scale_ub) -> "
|
||||
"()");
|
||||
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
|
||||
&dynamic_per_token_scaled_fp8_quant);
|
||||
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
"static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
|
||||
"Tensor? azp) -> ()");
|
||||
ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
|
||||
|
||||
// Compute int8 quantized tensor and scaling factor
|
||||
ops.def(
|
||||
"dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, "
|
||||
"Tensor!? azp) -> ()");
|
||||
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
|
||||
&dynamic_scaled_int8_quant);
|
||||
|
||||
// Mamba selective scan kernel
|
||||
ops.def(
|
||||
"selective_scan_fwd(Tensor! u, Tensor! delta,"
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
#pragma once
|
||||
|
||||
// Shared TORCH_UTILS_CHECK across both libtorch stable and unstable source
|
||||
// files. Keep this header free of CUTLASS/CUTE so attention/quant headers can
|
||||
// use it.
|
||||
//
|
||||
// If TORCH_TARGET_VERSION is defined, we are building _C_stable_libtorch.so so
|
||||
// use STD_TORCH_CHECK via header-only.
|
||||
// Otherwise, use TORCH_CHECK via torch/all.h.
|
||||
|
||||
#ifdef TORCH_TARGET_VERSION
|
||||
#include <torch/headeronly/util/Exception.h>
|
||||
#define TORCH_UTILS_CHECK STD_TORCH_CHECK
|
||||
#else
|
||||
#include <torch/all.h>
|
||||
#define TORCH_UTILS_CHECK TORCH_CHECK
|
||||
#endif
|
||||
@@ -1045,9 +1045,7 @@ if _is_cpu():
|
||||
|
||||
if _build_custom_ops():
|
||||
ext_modules.append(CMakeExtension(name="vllm._C"))
|
||||
# also _is_hip() once https://github.com/vllm-project/vllm/issues/35163 is
|
||||
# fixed
|
||||
if _is_cuda():
|
||||
if _is_cuda() or _is_hip():
|
||||
ext_modules.append(CMakeExtension(name="vllm._C_stable_libtorch"))
|
||||
|
||||
package_data = {
|
||||
|
||||
@@ -44,6 +44,11 @@ except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._C with %r", e)
|
||||
|
||||
# import custom ops, trigger op registration
|
||||
try:
|
||||
import vllm._C_stable_libtorch # noqa: F401
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._C_stable_libtorch with %r", e)
|
||||
|
||||
try:
|
||||
import vllm._rocm_C # noqa: F401
|
||||
except ImportError as e:
|
||||
|
||||
Reference in New Issue
Block a user