From 07aeaf9d4df870a76d5a0dc19d6a7e74b4be5d3b Mon Sep 17 00:00:00 2001 From: Chris Leonard Date: Wed, 20 May 2026 03:18:12 -0400 Subject: [PATCH] [6/n] Migrate activation kernels, gptq, gguf, non cutlass w8a8 to libtorch stable ABI (continued) (#42663) Signed-off-by: Mikayla Gawarecki Signed-off-by: Chris Leonard Co-authored-by: Mikayla Gawarecki Co-authored-by: Shengqi Chen --- CMakeLists.txt | 63 ++-- csrc/attention/dtype_fp8.cuh | 3 +- csrc/cuda_vec_utils.cuh | 2 + csrc/cutlass_extensions/torch_utils.hpp | 6 +- .../activation_kernels.cu | 220 +++++------ csrc/libtorch_stable/ops.h | 83 +++++ .../quantization/gguf/gguf_kernel.cu | 342 +++++++++--------- .../quantization/gguf/moe.cuh | 0 .../quantization/gguf/moe_vec.cuh | 0 .../quantization/gptq/compat.cuh | 0 .../quantization/gptq/matrix_view.cuh | 0 .../quantization/gptq/q_gemm.cu | 53 +-- .../quantization/gptq/qdq_2.cuh | 0 .../quantization/gptq/qdq_3.cuh | 0 .../quantization/gptq/qdq_4.cuh | 0 .../quantization/gptq/qdq_8.cuh | 0 .../quantization/gptq/qdq_util.cuh | 0 .../quantization/w8a8/fp8/common.cu | 195 +++++----- .../quantization/w8a8/int8/scaled_quant.cu | 86 +++-- csrc/libtorch_stable/torch_bindings.cpp | 139 +++++++ csrc/libtorch_stable/torch_utils.h | 6 +- csrc/ops.h | 47 --- csrc/quantization/utils.cuh | 30 +- csrc/quantization/w8a8/fp8/common.cuh | 19 +- csrc/torch_bindings.cpp | 129 +------ csrc/torch_utils.h | 17 + setup.py | 4 +- vllm/platforms/rocm.py | 5 + 28 files changed, 810 insertions(+), 639 deletions(-) rename csrc/{ => libtorch_stable}/activation_kernels.cu (79%) rename csrc/{ => libtorch_stable}/quantization/gguf/gguf_kernel.cu (61%) rename csrc/{ => libtorch_stable}/quantization/gguf/moe.cuh (100%) rename csrc/{ => libtorch_stable}/quantization/gguf/moe_vec.cuh (100%) rename csrc/{ => libtorch_stable}/quantization/gptq/compat.cuh (100%) rename csrc/{ => libtorch_stable}/quantization/gptq/matrix_view.cuh (100%) rename csrc/{ => libtorch_stable}/quantization/gptq/q_gemm.cu (97%) rename csrc/{ => libtorch_stable}/quantization/gptq/qdq_2.cuh (100%) rename csrc/{ => libtorch_stable}/quantization/gptq/qdq_3.cuh (100%) rename csrc/{ => libtorch_stable}/quantization/gptq/qdq_4.cuh (100%) rename csrc/{ => libtorch_stable}/quantization/gptq/qdq_8.cuh (100%) rename csrc/{ => libtorch_stable}/quantization/gptq/qdq_util.cuh (100%) rename csrc/{ => libtorch_stable}/quantization/w8a8/fp8/common.cu (66%) rename csrc/{ => libtorch_stable}/quantization/w8a8/int8/scaled_quant.cu (79%) create mode 100644 csrc/torch_utils.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 35171b473d3..d5039470d41 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -312,19 +312,14 @@ set(VLLM_EXT_SRC "csrc/attention/paged_attention_v2.cu" "csrc/attention/merge_attn_states.cu" "csrc/pos_encoding_kernels.cu" - "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" "csrc/fused_qknorm_rope_kernel.cu" "csrc/layernorm_quant_kernels.cu" "csrc/sampler.cu" "csrc/topk.cu" "csrc/cuda_view.cu" - "csrc/quantization/gptq/q_gemm.cu" - "csrc/quantization/w8a8/int8/scaled_quant.cu" - "csrc/quantization/w8a8/fp8/common.cu" "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu" - "csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/activation_kernels.cu" "csrc/cuda_utils_kernels.cu" "csrc/custom_all_reduce.cu" @@ -628,33 +623,33 @@ define_extension_target( # Setting this variable sidesteps the issue by calling the driver directly. target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) -# add OR VLLM_GPU_LANG STREQUAL "HIP" here once -# https://github.com/vllm-project/vllm/issues/35163 is resolved -if(VLLM_GPU_LANG STREQUAL "CUDA") +if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") # # _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY) # set(VLLM_STABLE_EXT_SRC "csrc/libtorch_stable/torch_bindings.cpp" - "csrc/cutlass_extensions/common.cpp" - "csrc/cuda_utils_kernels.cu" - "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu" - "csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu" - "csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu") + "csrc/libtorch_stable/activation_kernels.cu" + "csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu" + "csrc/libtorch_stable/quantization/w8a8/fp8/common.cu" + "csrc/libtorch_stable/quantization/gptq/q_gemm.cu" + "csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_STABLE_EXT_SRC + "csrc/cuda_utils_kernels.cu" + "csrc/cutlass_extensions/common.cpp" + "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu" + "csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu" + "csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/libtorch_stable/permute_cols.cu" "csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu" "csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu" "csrc/libtorch_stable/quantization/awq/gemm_kernels.cu") - endif() - if(VLLM_GPU_LANG STREQUAL "CUDA") set_gencode_flags_for_srcs( SRCS "${VLLM_STABLE_EXT_SRC}" CUDA_ARCHS "${CUDA_ARCHS}") - endif() # DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) @@ -1034,6 +1029,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Building hadacore") endif() + # if CUDA endif + endif() + message(STATUS "Enabling C_stable extension.") define_extension_target( _C_stable_libtorch @@ -1053,13 +1051,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") target_compile_definitions(_C_stable_libtorch PRIVATE TORCH_TARGET_VERSION=0x020A000000000000ULL) - # Needed to use cuda APIs from C-shim - target_compile_definitions(_C_stable_libtorch PRIVATE - USE_CUDA) + # Needed to use cuda/hip APIs from C-shim + if(VLLM_GPU_LANG STREQUAL "CUDA") + target_compile_definitions(_C_stable_libtorch PRIVATE USE_CUDA) + # Needed by CUTLASS kernels + target_compile_definitions(_C_stable_libtorch PRIVATE + CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) + elseif(VLLM_GPU_LANG STREQUAL "HIP") + target_compile_definitions(_C_stable_libtorch PRIVATE USE_ROCM) + endif() - # Needed by CUTLASS kernels - target_compile_definitions(_C_stable_libtorch PRIVATE - CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) + # On ROCm, _C_stable_libtorch calls raw HIP APIs (e.g. hipGetDevice in + # get_device_prop()) which must resolve to the same libamdhip64.so that + # PyTorch uses. When PyTorch bundles its own copy (pip/conda wheels), + # the raw HIP calls would otherwise resolve to the system ROCm copy, + # initializing a second HIP runtime that corrupts device state (wrong + # device on DeviceGuard, core dumps on multi-GPU tests). + # + # If PyTorch doesn't bundle libamdhip64 (built from source against system + # ROCm), there is only one copy in the process and no action is needed — + # the HIP compiler already links the system libamdhip64 automatically. + if(VLLM_GPU_LANG STREQUAL "HIP") + find_library(_STABLE_TORCH_AMDHIP64 amdhip64 + PATHS "${TORCH_INSTALL_PREFIX}/lib" NO_DEFAULT_PATH) + if(_STABLE_TORCH_AMDHIP64) + message(STATUS "Found PyTorch-bundled libamdhip64 at ${_STABLE_TORCH_AMDHIP64}") + target_link_libraries(_C_stable_libtorch PRIVATE ${_STABLE_TORCH_AMDHIP64}) + endif() + endif() endif() # diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh index 1afec3c3997..3d56859d8fc 100644 --- a/csrc/attention/dtype_fp8.cuh +++ b/csrc/attention/dtype_fp8.cuh @@ -1,6 +1,7 @@ #pragma once #include "attention_generic.cuh" +#include "torch_utils.h" #include #ifdef ENABLE_FP8 @@ -30,7 +31,7 @@ inline Fp8KVCacheDataType get_fp8_kv_cache_data_type( } else if (dtype_str == "fp8_e5m2") { return Fp8KVCacheDataType::kFp8E5M2; } - TORCH_CHECK(false, "Unsupported fp8 kv cache data type: ", dtype_str); + TORCH_UTILS_CHECK(false, "Unsupported fp8 kv cache data type: ", dtype_str); } // fp8 vector types for quantization of kv cache diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh index 91e181c5856..efbb09994d2 100644 --- a/csrc/cuda_vec_utils.cuh +++ b/csrc/cuda_vec_utils.cuh @@ -9,6 +9,8 @@ #ifdef USE_ROCM #include + #include + #include #else #include #include diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp index 45f25ea3d39..5f973033fd7 100644 --- a/csrc/cutlass_extensions/torch_utils.hpp +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -1,5 +1,7 @@ #pragma once +#include "torch_utils.h" + // This header is shared between _C (unstable ABI, used by machete) and // _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION // is defined only for the stable target, so we switch includes and types @@ -8,13 +10,9 @@ #include #include #include - #include // for STD_TORCH_CHECK using TorchTensor = torch::stable::Tensor; - #define TORCH_UTILS_CHECK STD_TORCH_CHECK #else - #include using TorchTensor = torch::Tensor; - #define TORCH_UTILS_CHECK TORCH_CHECK #endif #include "cute/layout.hpp" diff --git a/csrc/activation_kernels.cu b/csrc/libtorch_stable/activation_kernels.cu similarity index 79% rename from csrc/activation_kernels.cu rename to csrc/libtorch_stable/activation_kernels.cu index 303433392c3..28fdce5c305 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/libtorch_stable/activation_kernels.cu @@ -1,12 +1,12 @@ -#include -#include -#include +#include +#include #include -#include "cuda_compat.h" -#include "cuda_vec_utils.cuh" +#include "../cuda_compat.h" +#include "../cuda_vec_utils.cuh" #include "dispatch_utils.h" +#include "torch_utils.h" namespace vllm { @@ -210,64 +210,68 @@ packed_gelu_tanh_kernel(const packed_t& val) { return; \ } \ dim3 grid(num_tokens); \ - int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ + int cc_major = get_device_prop()->major; \ int support_vec = \ (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ - int vec_size = support_vec / at::elementSize(dtype); \ + int vec_size = support_vec / input.element_size(); \ const bool use_vec = (d % vec_size == 0); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + const torch::stable::accelerator::DeviceGuard device_guard( \ + input.get_device_index()); \ + const cudaStream_t stream = get_current_cuda_stream(); \ if (use_vec) { \ dim3 block(std::min(d / vec_size, 1024)); \ if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ vllm::act_and_mul_kernel< \ scalar_t, typename vllm::PackedTypeConverter::Type, \ KERNEL, \ PACKED_KERNEL::Type>, \ ACT_FIRST, true, HAS_CLAMP, true><<>>( \ - out.data_ptr(), input.data_ptr(), d, LIMIT); \ + out.mutable_data_ptr(), \ + input.const_data_ptr(), d, LIMIT); \ }); \ } else { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ vllm::act_and_mul_kernel< \ scalar_t, typename vllm::PackedTypeConverter::Type, \ KERNEL, \ PACKED_KERNEL::Type>, \ ACT_FIRST, true, HAS_CLAMP, false><<>>( \ - out.data_ptr(), input.data_ptr(), d, LIMIT); \ + out.mutable_data_ptr(), \ + input.const_data_ptr(), d, LIMIT); \ }); \ } \ } else { \ dim3 block(std::min(d, 1024)); \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ vllm::act_and_mul_kernel< \ scalar_t, typename vllm::PackedTypeConverter::Type, \ KERNEL, \ PACKED_KERNEL::Type>, \ ACT_FIRST, false, HAS_CLAMP><<>>( \ - out.data_ptr(), input.data_ptr(), d, LIMIT); \ + out.mutable_data_ptr(), input.const_data_ptr(), \ + d, LIMIT); \ }); \ } -void silu_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void silu_and_mul(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel, true, false, 0.0f); } -void silu_and_mul_clamp(torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., 2 * d] +void silu_and_mul_clamp(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input, // [..., 2 * d] double limit) { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel, true, true, (float)limit); } -void mul_and_silu(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void mul_and_silu(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., 2 * d] { // The difference between mul_and_silu and silu_and_mul is that mul_and_silu // applies the silu to the latter half of the input. @@ -275,15 +279,15 @@ void mul_and_silu(torch::Tensor& out, // [..., d] false, false, 0.0f); } -void gelu_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_and_mul(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, vllm::packed_gelu_kernel, true, false, 0.0f); } -void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_tanh_and_mul(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL( vllm::gelu_tanh_kernel, vllm::packed_gelu_tanh_kernel, true, false, 0.0f); @@ -434,19 +438,20 @@ __global__ void swigluoai_and_mul_kernel( return; \ } \ dim3 grid(num_tokens); \ - int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ + int cc_major = get_device_prop()->major; \ int support_vec = \ (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ - int vec_size = support_vec / at::elementSize(dtype); \ + int vec_size = support_vec / input.element_size(); \ const bool use_vec = (d % vec_size == 0); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + const torch::stable::accelerator::DeviceGuard device_guard( \ + input.get_device_index()); \ + const cudaStream_t stream = get_current_cuda_stream(); \ if (use_vec) { \ dim3 block(std::min(d / vec_size, 1024)); \ if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ - VLLM_DISPATCH_FLOATING_TYPES( \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES( \ dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ scalar_t, typename vllm::PackedTypeConverter::Type, \ @@ -454,11 +459,11 @@ __global__ void swigluoai_and_mul_kernel( PACKED_KERNEL< \ typename vllm::PackedTypeConverter::Type>, \ true, true><<>>( \ - out.data_ptr(), input.data_ptr(), d, \ - PARAM); \ + out.mutable_data_ptr(), \ + input.const_data_ptr(), d, PARAM); \ }); \ } else { \ - VLLM_DISPATCH_FLOATING_TYPES( \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES( \ dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ scalar_t, typename vllm::PackedTypeConverter::Type, \ @@ -466,45 +471,49 @@ __global__ void swigluoai_and_mul_kernel( PACKED_KERNEL< \ typename vllm::PackedTypeConverter::Type>, \ true, false><<>>( \ - out.data_ptr(), input.data_ptr(), d, \ - PARAM); \ + out.mutable_data_ptr(), \ + input.const_data_ptr(), d, PARAM); \ }); \ } \ } else { \ dim3 block(std::min(d, 1024)); \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \ - vllm::act_and_mul_kernel_with_param< \ - scalar_t, typename vllm::PackedTypeConverter::Type, \ - KERNEL, \ - PACKED_KERNEL::Type>, \ - false><<>>( \ - out.data_ptr(), input.data_ptr(), d, PARAM); \ - }); \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES( \ + dtype, "act_and_mul_kernel_with_param", [&] { \ + vllm::act_and_mul_kernel_with_param< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL< \ + typename vllm::PackedTypeConverter::Type>, \ + false><<>>( \ + out.mutable_data_ptr(), \ + input.const_data_ptr(), d, PARAM); \ + }); \ } -#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \ - vllm::swigluoai_and_mul_kernel> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d, ALPHA, \ - LIMIT); \ +#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const torch::stable::accelerator::DeviceGuard device_guard( \ + input.get_device_index()); \ + const cudaStream_t stream = get_current_cuda_stream(); \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \ + vllm::swigluoai_and_mul_kernel> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d, \ + ALPHA, LIMIT); \ }); -void fatrelu_and_mul(torch::Tensor& out, // [..., d], - torch::Tensor& input, // [..., 2 * d] +void fatrelu_and_mul(torch::stable::Tensor& out, // [..., d], + torch::stable::Tensor& input, // [..., 2 * d] double threshold) { LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM( vllm::fatrelu_kernel, vllm::packed_fatrelu_kernel, threshold); } -void swigluoai_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., 2 * d] +void swigluoai_and_mul(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input, // [..., 2 * d] double alpha, double limit) { LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit); } @@ -559,45 +568,46 @@ __global__ void activation_kernel( } // namespace vllm // Launch element-wise activation kernel. -#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ - auto dtype = input.scalar_type(); \ - int d = input.size(-1); \ - int64_t num_tokens = input.numel() / input.size(-1); \ - if (num_tokens == 0) { \ - return; \ - } \ - dim3 grid(num_tokens); \ - int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ - int support_vec = \ - (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ - ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ - : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ - int vec_size = support_vec / at::elementSize(dtype); \ - const bool use_vec = (d % vec_size == 0); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - if (use_vec) { \ - dim3 block(std::min(d / vec_size, 1024)); \ - if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ - vllm::activation_kernel, true, true> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d); \ - }); \ - } else { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ - vllm::activation_kernel, true, false> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d); \ - }); \ - } \ - } else { \ - dim3 block(std::min(d, 1024)); \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ - vllm::activation_kernel, false> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d); \ - }); \ +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + auto dtype = input.scalar_type(); \ + int d = input.size(-1); \ + int64_t num_tokens = input.numel() / input.size(-1); \ + if (num_tokens == 0) { \ + return; \ + } \ + dim3 grid(num_tokens); \ + int cc_major = get_device_prop()->major; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ + int vec_size = support_vec / input.element_size(); \ + const bool use_vec = (d % vec_size == 0); \ + const torch::stable::accelerator::DeviceGuard device_guard( \ + input.get_device_index()); \ + const cudaStream_t stream = get_current_cuda_stream(); \ + if (use_vec) { \ + dim3 block(std::min(d / vec_size, 1024)); \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ + vllm::activation_kernel, true, true> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d); \ + }); \ + } else { \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ + vllm::activation_kernel, true, false> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d); \ + }); \ + } \ + } else { \ + dim3 block(std::min(d, 1024)); \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ + vllm::activation_kernel, false> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d); \ + }); \ } namespace vllm { @@ -625,20 +635,20 @@ __device__ __forceinline__ T gelu_quick_kernel(const T& x) { } // namespace vllm -void gelu_new(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_new(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); } -void gelu_fast(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_fast(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); } -void gelu_quick(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_quick(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel); } diff --git a/csrc/libtorch_stable/ops.h b/csrc/libtorch_stable/ops.h index cdae5fff60f..5ebcb2034f5 100644 --- a/csrc/libtorch_stable/ops.h +++ b/csrc/libtorch_stable/ops.h @@ -166,3 +166,86 @@ torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel, torch::stable::Tensor hadacore_transform(torch::stable::Tensor& x, bool inplace); + +// Activation kernels (shared CUDA/ROCm) +void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input); +void silu_and_mul_clamp(torch::stable::Tensor& out, + torch::stable::Tensor& input, double limit); +void mul_and_silu(torch::stable::Tensor& out, torch::stable::Tensor& input); +void gelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input); +void gelu_tanh_and_mul(torch::stable::Tensor& out, + torch::stable::Tensor& input); +void fatrelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input, + double threshold); +void swigluoai_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input, + double alpha = 1.702, double limit = 7.0); +void gelu_new(torch::stable::Tensor& out, torch::stable::Tensor& input); +void gelu_fast(torch::stable::Tensor& out, torch::stable::Tensor& input); +void gelu_quick(torch::stable::Tensor& out, torch::stable::Tensor& input); + +// INT8 quantization kernels (shared CUDA/ROCm) +void static_scaled_int8_quant(torch::stable::Tensor& out, + torch::stable::Tensor const& input, + torch::stable::Tensor const& scale, + std::optional const& azp); + +void dynamic_scaled_int8_quant(torch::stable::Tensor& out, + torch::stable::Tensor const& input, + torch::stable::Tensor& scales, + std::optional const& azp); + +// FP8 quantization kernels (shared CUDA/ROCm) +void static_scaled_fp8_quant( + torch::stable::Tensor& out, torch::stable::Tensor const& input, + torch::stable::Tensor const& scale, + std::optional group_shape = + std::nullopt); + +void dynamic_scaled_fp8_quant(torch::stable::Tensor& out, + torch::stable::Tensor const& input, + torch::stable::Tensor& scale); + +void dynamic_per_token_scaled_fp8_quant( + torch::stable::Tensor& out, torch::stable::Tensor const& input, + torch::stable::Tensor& scale, + std::optional const& scale_ub); + +// GPTQ kernels (shared CUDA/ROCm) +torch::stable::Tensor gptq_gemm(torch::stable::Tensor a, + torch::stable::Tensor b_q_weight, + torch::stable::Tensor b_gptq_qzeros, + torch::stable::Tensor b_gptq_scales, + torch::stable::Tensor b_g_idx, bool use_exllama, + bool use_v2_format, int64_t bit); + +void gptq_shuffle(torch::stable::Tensor q_weight, torch::stable::Tensor q_perm, + int64_t bit); + +// GGML kernels (shared CUDA/ROCm) +torch::stable::Tensor ggml_dequantize( + torch::stable::Tensor W, int64_t type, int64_t m, int64_t n, + std::optional const& dtype); + +torch::stable::Tensor ggml_mul_mat_vec_a8(torch::stable::Tensor W, + torch::stable::Tensor X, int64_t type, + int64_t row); + +torch::stable::Tensor ggml_mul_mat_a8(torch::stable::Tensor W, + torch::stable::Tensor X, int64_t type, + int64_t row); + +torch::stable::Tensor ggml_moe_a8(torch::stable::Tensor X, + torch::stable::Tensor W, + torch::stable::Tensor sorted_token_ids, + torch::stable::Tensor expert_ids, + torch::stable::Tensor num_tokens_post_padded, + int64_t type, int64_t row, int64_t top_k, + int64_t tokens); + +torch::stable::Tensor ggml_moe_a8_vec(torch::stable::Tensor X, + torch::stable::Tensor W, + torch::stable::Tensor topk_ids, + int64_t top_k, int64_t type, int64_t row, + int64_t tokens); + +int64_t ggml_moe_get_block_size(int64_t type); diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu similarity index 61% rename from csrc/quantization/gguf/gguf_kernel.cu rename to csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu index 76fe73e9504..0fdfcafab8c 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu @@ -1,17 +1,20 @@ #include #include -#include -#include +#include "../../../cuda_compat.h" +#include "../../dispatch_utils.h" +#include "../../torch_utils.h" -#include "../../cuda_compat.h" -#include "dispatch_utils.h" +#include -#include "ggml-common.h" -#include "vecdotq.cuh" -#include "dequantize.cuh" -#include "mmvq.cuh" -#include "mmq.cuh" +// NOTE: These headers are intentionally kept in csrc/quantization/gguf/ (not +// moved to libtorch_stable) to avoid unnecessary reformatting that would break +// git rename detection and pollute blame history. +#include "../../../quantization/gguf/ggml-common.h" +#include "../../../quantization/gguf/vecdotq.cuh" +#include "../../../quantization/gguf/dequantize.cuh" +#include "../../../quantization/gguf/mmvq.cuh" +#include "../../../quantization/gguf/mmq.cuh" #include "moe.cuh" #include "moe_vec.cuh" @@ -71,16 +74,17 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx, } } -torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight - int64_t type, int64_t m, int64_t n, - std::optional const& dtype) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(W)); - auto dtype_ = dtype.value_or(torch::kFloat16); - auto options = torch::TensorOptions().dtype(dtype_).device(W.device()); - at::Tensor DW = torch::empty({m, n}, options); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); +torch::stable::Tensor ggml_dequantize( + torch::stable::Tensor W, // quant weight + int64_t type, int64_t m, int64_t n, + std::optional const& dtype) { + const torch::stable::accelerator::DeviceGuard device_guard( + W.get_device_index()); + auto dtype_ = dtype.value_or(torch::headeronly::ScalarType::Half); + auto DW = torch::stable::empty({m, n}, dtype_, std::nullopt, W.device()); + cudaStream_t stream = get_current_cuda_stream(); - VLLM_DISPATCH_FLOATING_TYPES(DW.scalar_type(), "ggml_dequantize", [&] { + VLLM_STABLE_DISPATCH_FLOATING_TYPES(DW.scalar_type(), "ggml_dequantize", [&] { auto to_cuda = ggml_get_to_cuda(type); to_cuda((void*)W.data_ptr(), (scalar_t*)DW.data_ptr(), m * n, stream); }); @@ -88,135 +92,142 @@ torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight return DW; } -torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight - torch::Tensor X, // input - int64_t type, int64_t row) { +torch::stable::Tensor ggml_mul_mat_vec_a8( + torch::stable::Tensor W, // quant weight + torch::stable::Tensor X, // input + int64_t type, int64_t row) { int col = X.sizes()[1]; int vecs = X.sizes()[0]; const int padded = (col + 512 - 1) / 512 * 512; - const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); - auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device()); - at::Tensor Y = torch::empty({vecs, row}, options); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); - at::Tensor quant_X = torch::empty({vecs, padded / 32 * 9}, options); - VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] { - quantize_row_q8_1_cuda( - (scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, vecs, stream); - switch (type) { - case 2: - mul_mat_vec_q4_0_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 3: - mul_mat_vec_q4_1_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 6: - mul_mat_vec_q5_0_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 7: - mul_mat_vec_q5_1_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 8: - mul_mat_vec_q8_0_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 10: - mul_mat_vec_q2_K_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 11: - mul_mat_vec_q3_K_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 12: - mul_mat_vec_q4_K_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 13: - mul_mat_vec_q5_K_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 14: - mul_mat_vec_q6_K_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 16: - mul_mat_vec_iq2_xxs_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 17: - mul_mat_vec_iq2_xs_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 18: - mul_mat_vec_iq3_xxs_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 19: - mul_mat_vec_iq1_s_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 20: - mul_mat_vec_iq4_nl_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 21: - mul_mat_vec_iq3_s_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 22: - mul_mat_vec_iq2_s_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 23: - mul_mat_vec_iq4_xs_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - case 29: - mul_mat_vec_iq1_m_q8_1_cuda( - (void*)W.data_ptr(), (void*)quant_X.data_ptr(), - (scalar_t*)Y.data_ptr(), col, row, vecs, stream); - break; - } - }); + const torch::stable::accelerator::DeviceGuard device_guard( + X.get_device_index()); + auto Y = torch::stable::empty({vecs, row}, X.scalar_type(), std::nullopt, + W.device()); + cudaStream_t stream = get_current_cuda_stream(); + auto quant_X = torch::stable::empty({vecs, padded / 32 * 9}, + torch::headeronly::ScalarType::Int, + std::nullopt, W.device()); + VLLM_STABLE_DISPATCH_FLOATING_TYPES( + X.scalar_type(), "ggml_mul_mat_vec_a8", [&] { + quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), + (void*)quant_X.data_ptr(), col, vecs, + stream); + switch (type) { + case 2: + mul_mat_vec_q4_0_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 3: + mul_mat_vec_q4_1_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 6: + mul_mat_vec_q5_0_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 7: + mul_mat_vec_q5_1_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 8: + mul_mat_vec_q8_0_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 10: + mul_mat_vec_q2_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 11: + mul_mat_vec_q3_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 12: + mul_mat_vec_q4_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 13: + mul_mat_vec_q5_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 14: + mul_mat_vec_q6_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 16: + mul_mat_vec_iq2_xxs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 17: + mul_mat_vec_iq2_xs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 18: + mul_mat_vec_iq3_xxs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 19: + mul_mat_vec_iq1_s_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 20: + mul_mat_vec_iq4_nl_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 21: + mul_mat_vec_iq3_s_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 22: + mul_mat_vec_iq2_s_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 23: + mul_mat_vec_iq4_xs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + case 29: + mul_mat_vec_iq1_m_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), col, row, vecs, stream); + break; + } + }); return Y; } -torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight - torch::Tensor X, // input - int64_t type, int64_t row) { +torch::stable::Tensor ggml_mul_mat_a8(torch::stable::Tensor W, // quant weight + torch::stable::Tensor X, // input + int64_t type, int64_t row) { int col = X.sizes()[1]; int padded = (col + 512 - 1) / 512 * 512; int batch = X.sizes()[0]; - const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); - auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device()); - at::Tensor Y = torch::empty({batch, row}, options); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); - at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options); - VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] { + const torch::stable::accelerator::DeviceGuard device_guard( + X.get_device_index()); + auto Y = torch::stable::empty({batch, row}, X.scalar_type(), std::nullopt, + W.device()); + cudaStream_t stream = get_current_cuda_stream(); + auto quant_X = torch::stable::empty({batch, padded / 32 * 9}, + torch::headeronly::ScalarType::Int, + std::nullopt, W.device()); + VLLM_STABLE_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] { quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, batch, stream); @@ -276,21 +287,24 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight return Y; } -torch::Tensor ggml_moe_a8(torch::Tensor X, // input - torch::Tensor W, // expert weights - torch::Tensor sorted_token_ids, - torch::Tensor expert_ids, - torch::Tensor num_tokens_post_padded, int64_t type, - int64_t row, int64_t top_k, int64_t tokens) { +torch::stable::Tensor ggml_moe_a8(torch::stable::Tensor X, // input + torch::stable::Tensor W, // expert weights + torch::stable::Tensor sorted_token_ids, + torch::stable::Tensor expert_ids, + torch::stable::Tensor num_tokens_post_padded, + int64_t type, int64_t row, int64_t top_k, + int64_t tokens) { int col = X.sizes()[1]; int padded = (col + 512 - 1) / 512 * 512; - const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); - auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device()); - at::Tensor Y = torch::empty({tokens * top_k, row}, options); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); - at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options); - VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_a8", [&] { + const torch::stable::accelerator::DeviceGuard device_guard( + X.get_device_index()); + auto Y = torch::stable::empty({tokens * top_k, row}, X.scalar_type(), + std::nullopt, W.device()); + cudaStream_t stream = get_current_cuda_stream(); + auto quant_X = torch::stable::empty({tokens, padded / 32 * 9}, + torch::headeronly::ScalarType::Int, + std::nullopt, W.device()); + VLLM_STABLE_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_a8", [&] { quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, tokens, stream); switch (type) { @@ -379,19 +393,23 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input return Y; } -torch::Tensor ggml_moe_a8_vec(torch::Tensor X, // input - torch::Tensor W, // expert weights - torch::Tensor topk_ids, int64_t top_k, - int64_t type, int64_t row, int64_t tokens) { +torch::stable::Tensor ggml_moe_a8_vec( + torch::stable::Tensor X, // input + torch::stable::Tensor W, // expert weights + torch::stable::Tensor topk_ids, int64_t top_k, int64_t type, int64_t row, + int64_t tokens) { int col = X.sizes()[1]; const int padded = (col + 512 - 1) / 512 * 512; - const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); - auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device()); - at::Tensor Y = torch::zeros({tokens * top_k, row}, options); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); - at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options); - VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] { + const torch::stable::accelerator::DeviceGuard device_guard( + X.get_device_index()); + auto Y = torch::stable::empty({tokens * top_k, row}, X.scalar_type(), + std::nullopt, W.device()); + torch::stable::fill_(Y, 0.0); + cudaStream_t stream = get_current_cuda_stream(); + auto quant_X = torch::stable::empty({tokens, padded / 32 * 9}, + torch::headeronly::ScalarType::Int, + std::nullopt, W.device()); + VLLM_STABLE_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] { quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, tokens, stream); diff --git a/csrc/quantization/gguf/moe.cuh b/csrc/libtorch_stable/quantization/gguf/moe.cuh similarity index 100% rename from csrc/quantization/gguf/moe.cuh rename to csrc/libtorch_stable/quantization/gguf/moe.cuh diff --git a/csrc/quantization/gguf/moe_vec.cuh b/csrc/libtorch_stable/quantization/gguf/moe_vec.cuh similarity index 100% rename from csrc/quantization/gguf/moe_vec.cuh rename to csrc/libtorch_stable/quantization/gguf/moe_vec.cuh diff --git a/csrc/quantization/gptq/compat.cuh b/csrc/libtorch_stable/quantization/gptq/compat.cuh similarity index 100% rename from csrc/quantization/gptq/compat.cuh rename to csrc/libtorch_stable/quantization/gptq/compat.cuh diff --git a/csrc/quantization/gptq/matrix_view.cuh b/csrc/libtorch_stable/quantization/gptq/matrix_view.cuh similarity index 100% rename from csrc/quantization/gptq/matrix_view.cuh rename to csrc/libtorch_stable/quantization/gptq/matrix_view.cuh diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/libtorch_stable/quantization/gptq/q_gemm.cu similarity index 97% rename from csrc/quantization/gptq/q_gemm.cu rename to csrc/libtorch_stable/quantization/gptq/q_gemm.cu index 8a29ad5ab2d..e3f79c5a6b8 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/libtorch_stable/quantization/gptq/q_gemm.cu @@ -6,9 +6,8 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa #include #include -#include -#include -#include +#include "../../torch_utils.h" +#include #include #include @@ -735,7 +734,7 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight, fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); kernel<<>>( a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c, size_m, size_n, size_k, groups, use_v2_format, b_q_perm); @@ -1164,7 +1163,7 @@ void reconstruct_exllama(const uint32_t* b_q_weight, reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); reconstruct_exllama_kernel<<>>( b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups, use_v2_format, out); @@ -1376,7 +1375,7 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight, kernel = gemm_half_q_half_alt_8bit_kernel; } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); kernel<<>>( (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx, size_m, size_k / 32 * bit, size_n, use_v2_format); @@ -1485,7 +1484,7 @@ void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, gridDim.y = DIVIDE(height, 32); } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); kernel<<>>(b_q_weight, b_gptq_scales, b_gptq_qzeros, b_g_idx, height, width, groups, use_v2_format, out); @@ -1794,7 +1793,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, } else if (bit == 8) { kernel = make_sequential_8bit_kernel; } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); kernel<<>>(q_weight, new_qweight, q_perm, width); // Replace qweights @@ -1818,29 +1817,34 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, } else if (bit == 8) { shuffle_kernel = shuffle_8bit_kernel; } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); shuffle_kernel<<>>(q_weight, height, width); } } // namespace gptq } // namespace vllm -torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, bool use_v2_format, int64_t bit) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - at::Tensor c = torch::zeros({a.size(0), b_q_weight.size(1)}, options); - at::Tensor temp_dq = torch::empty( - {b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options); +torch::stable::Tensor gptq_gemm(torch::stable::Tensor a, + torch::stable::Tensor b_q_weight, + torch::stable::Tensor b_gptq_qzeros, + torch::stable::Tensor b_gptq_scales, + torch::stable::Tensor b_g_idx, bool use_exllama, + bool use_v2_format, int64_t bit) { + const torch::stable::accelerator::DeviceGuard device_guard( + a.get_device_index()); + auto c = torch::stable::new_zeros(a, {a.size(0), b_q_weight.size(1)}); + auto temp_dq = + torch::stable::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, + a.scalar_type(), std::nullopt, a.device()); vllm::gptq::gemm_half_q_half_cuda( - at::cuda::getCurrentCUDABlasHandle(), (const half*)a.data_ptr(), + get_current_cuda_blas_handle(), (const half*)a.data_ptr(), (const uint32_t*)b_q_weight.data_ptr(), (const uint32_t*)b_gptq_qzeros.data_ptr(), (const half*)b_gptq_scales.data_ptr(), - b_g_idx.device().is_meta() ? NULL : (const int*)b_g_idx.data_ptr(), + b_g_idx.device().type() == torch::stable::DeviceType::Meta + ? NULL + : (const int*)b_g_idx.data_ptr(), (half*)c.data_ptr(), (half*)temp_dq.data_ptr(), c.size(0), // m c.size(1), // n @@ -1850,11 +1854,14 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, return c; } -void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); +void gptq_shuffle(torch::stable::Tensor q_weight, torch::stable::Tensor q_perm, + int64_t bit) { + const torch::stable::accelerator::DeviceGuard device_guard( + q_weight.get_device_index()); vllm::gptq::shuffle_exllama_weight( (uint32_t*)q_weight.data_ptr(), - q_perm.device().is_meta() || q_perm.numel() == 0 + q_perm.device().type() == torch::stable::DeviceType::Meta || + q_perm.numel() == 0 ? NULL : (int*)q_perm.data_ptr(), q_weight.size(0) * 32 / bit, q_weight.size(1), bit); diff --git a/csrc/quantization/gptq/qdq_2.cuh b/csrc/libtorch_stable/quantization/gptq/qdq_2.cuh similarity index 100% rename from csrc/quantization/gptq/qdq_2.cuh rename to csrc/libtorch_stable/quantization/gptq/qdq_2.cuh diff --git a/csrc/quantization/gptq/qdq_3.cuh b/csrc/libtorch_stable/quantization/gptq/qdq_3.cuh similarity index 100% rename from csrc/quantization/gptq/qdq_3.cuh rename to csrc/libtorch_stable/quantization/gptq/qdq_3.cuh diff --git a/csrc/quantization/gptq/qdq_4.cuh b/csrc/libtorch_stable/quantization/gptq/qdq_4.cuh similarity index 100% rename from csrc/quantization/gptq/qdq_4.cuh rename to csrc/libtorch_stable/quantization/gptq/qdq_4.cuh diff --git a/csrc/quantization/gptq/qdq_8.cuh b/csrc/libtorch_stable/quantization/gptq/qdq_8.cuh similarity index 100% rename from csrc/quantization/gptq/qdq_8.cuh rename to csrc/libtorch_stable/quantization/gptq/qdq_8.cuh diff --git a/csrc/quantization/gptq/qdq_util.cuh b/csrc/libtorch_stable/quantization/gptq/qdq_util.cuh similarity index 100% rename from csrc/quantization/gptq/qdq_util.cuh rename to csrc/libtorch_stable/quantization/gptq/qdq_util.cuh diff --git a/csrc/quantization/w8a8/fp8/common.cu b/csrc/libtorch_stable/quantization/w8a8/fp8/common.cu similarity index 66% rename from csrc/quantization/w8a8/fp8/common.cu rename to csrc/libtorch_stable/quantization/w8a8/fp8/common.cu index 52e159d6501..d02fc2296e6 100644 --- a/csrc/quantization/w8a8/fp8/common.cu +++ b/csrc/libtorch_stable/quantization/w8a8/fp8/common.cu @@ -1,11 +1,9 @@ -#include "common.cuh" -#include "dispatch_utils.h" -#include "cub_helpers.h" -#include "libtorch_stable/quantization/vectorization_utils.cuh" -#include -#include -#include - +#include "../../../../quantization/w8a8/fp8/common.cuh" +#include "../../../dispatch_utils.h" +#include "../../../../cub_helpers.h" +#include "../../vectorization_utils.cuh" +#include "../../../torch_utils.h" +#include namespace vllm { // STRIDE_I_ZERO: true if scale_stride_i == 0 (per-tensor or per-channel) @@ -183,16 +181,16 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided( } // namespace vllm void static_scaled_fp8_quant( - torch::Tensor& out, // [..., d] - torch::Tensor const& input, // [..., d] - torch::Tensor const& scale, // various shapes - std::optional> - opt_group_shape) // optional explicit (group_m, group_n) + torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor const& input, // [..., d] + torch::stable::Tensor const& scale, // various shapes + std::optional + opt_group_shape) // optional explicit [group_m, group_n] { - TORCH_CHECK(input.stride(-1) == 1, - "last dimension of input must be contiguous"); - TORCH_CHECK(out.stride(-1) == 1, - "last dimension of output must be contiguous"); + STD_TORCH_CHECK(input.stride(-1) == 1, + "last dimension of input must be contiguous"); + STD_TORCH_CHECK(out.stride(-1) == 1, + "last dimension of output must be contiguous"); const int hidden_size = input.size(-1); // N (columns) const int num_tokens = input.numel() / hidden_size; // M (rows) @@ -212,13 +210,18 @@ void static_scaled_fp8_quant( } else if (scale.dim() == 1) { // 1D scale: require explicit group_shape to disambiguate per-channel vs // per-token (avoids edge case where num_tokens == hidden_size) - TORCH_CHECK(opt_group_shape.has_value(), - "1D scale requires explicit group_shape to disambiguate " - "per-channel vs per-token quantization. " - "Use group_shape=(-1, 1) for per-channel or group_shape=(1, " - "-1) for per-token."); + STD_TORCH_CHECK( + opt_group_shape.has_value(), + "1D scale requires explicit group_shape to disambiguate " + "per-channel vs per-token quantization. " + "Use group_shape=(-1, 1) for per-channel or group_shape=(1, " + "-1) for per-token."); + STD_TORCH_CHECK(opt_group_shape->size() == 2, + "group_shape must have exactly 2 elements, got ", + opt_group_shape->size()); - const auto& [opt_group_m, opt_group_n] = opt_group_shape.value(); + const auto opt_group_m = (*opt_group_shape)[0]; + const auto opt_group_n = (*opt_group_shape)[1]; group_m = opt_group_m == -1 ? num_tokens : static_cast(opt_group_m); group_n = opt_group_n == -1 ? hidden_size : static_cast(opt_group_n); @@ -228,11 +231,11 @@ void static_scaled_fp8_quant( const int64_t expected_scale_n = hidden_size / group_n; const int64_t expected_scale_numel = expected_scale_m * expected_scale_n; - TORCH_CHECK(scale_len == expected_scale_numel, "1D scale length (", - scale_len, ") does not match expected size (", - expected_scale_numel, ") for group_shape (", opt_group_m, ", ", - opt_group_n, ") with input shape (", num_tokens, ", ", - hidden_size, ")"); + STD_TORCH_CHECK(scale_len == expected_scale_numel, "1D scale length (", + scale_len, ") does not match expected size (", + expected_scale_numel, ") for group_shape (", opt_group_m, + ", ", opt_group_n, ") with input shape (", num_tokens, ", ", + hidden_size, ")"); // For 1D scale, determine strides based on which dim is trivial // Scale indexing: scale[gi * scale_stride_i + gj * scale_stride_j] @@ -248,7 +251,7 @@ void static_scaled_fp8_quant( scale_stride_i = scale.stride(0); scale_stride_j = 0; } else { - TORCH_CHECK( + STD_TORCH_CHECK( false, "1D scale can only be used when one of the scale dimensions is 1. " "For 2D group scaling, use a 2D scale tensor."); @@ -259,10 +262,12 @@ void static_scaled_fp8_quant( const int64_t scale_size_0 = scale.size(0); const int64_t scale_size_1 = scale.size(1); - TORCH_CHECK(num_tokens % scale_size_0 == 0, "num_tokens (", num_tokens, - ") must be divisible by scale.size(0) (", scale_size_0, ")"); - TORCH_CHECK(hidden_size % scale_size_1 == 0, "hidden_size (", hidden_size, - ") must be divisible by scale.size(1) (", scale_size_1, ")"); + STD_TORCH_CHECK(num_tokens % scale_size_0 == 0, "num_tokens (", num_tokens, + ") must be divisible by scale.size(0) (", scale_size_0, + ")"); + STD_TORCH_CHECK(hidden_size % scale_size_1 == 0, "hidden_size (", + hidden_size, ") must be divisible by scale.size(1) (", + scale_size_1, ")"); // Infer from 2D scale shape int inferred_group_m = num_tokens / scale_size_0; @@ -270,16 +275,21 @@ void static_scaled_fp8_quant( // Use explicit if provided, otherwise use inferred if (opt_group_shape.has_value()) { - const auto& [opt_group_m, opt_group_n] = opt_group_shape.value(); + STD_TORCH_CHECK(opt_group_shape->size() == 2, + "group_shape must have exactly 2 elements, got ", + opt_group_shape->size()); + const auto opt_group_m = (*opt_group_shape)[0]; + const auto opt_group_n = (*opt_group_shape)[1]; group_m = opt_group_m == -1 ? num_tokens : static_cast(opt_group_m); group_n = opt_group_n == -1 ? hidden_size : static_cast(opt_group_n); // Validate explicit matches inferred - TORCH_CHECK(group_m == inferred_group_m && group_n == inferred_group_n, - "Explicit group_shape (", opt_group_m, ", ", opt_group_n, - ") does not match inferred group shape (", inferred_group_m, - ", ", inferred_group_n, ") from 2D scale tensor shape (", - scale_size_0, ", ", scale_size_1, ")"); + STD_TORCH_CHECK( + group_m == inferred_group_m && group_n == inferred_group_n, + "Explicit group_shape (", opt_group_m, ", ", opt_group_n, + ") does not match inferred group shape (", inferred_group_m, ", ", + inferred_group_n, ") from 2D scale tensor shape (", scale_size_0, + ", ", scale_size_1, ")"); } else { group_m = inferred_group_m; group_n = inferred_group_n; @@ -288,8 +298,8 @@ void static_scaled_fp8_quant( scale_stride_i = scale.stride(0); scale_stride_j = scale.stride(1); } else { - TORCH_CHECK(false, "scale must be 0D, 1D, or 2D tensor, but got ", - scale.dim(), "D"); + STD_TORCH_CHECK(false, "scale must be 0D, 1D, or 2D tensor, but got ", + scale.dim(), "D"); } const int block_size = 256; @@ -299,37 +309,39 @@ void static_scaled_fp8_quant( const int64_t in_row_stride = input.stride(-2); const int64_t out_row_stride = out.stride(-2); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); // Dispatch to template-specialized kernel based on stride pattern - VLLM_DISPATCH_FLOATING_TYPES( + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { - VLLM_DISPATCH_FP8_TYPES( + VLLM_STABLE_DISPATCH_FP8_TYPES( out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] { - VLLM_DISPATCH_BOOL(scale_stride_i == 0, S0_ZERO, [&] { - VLLM_DISPATCH_BOOL(scale_stride_j == 0, S1_ZERO, [&] { + VLLM_STABLE_DISPATCH_BOOL(scale_stride_i == 0, S0_ZERO, [&] { + VLLM_STABLE_DISPATCH_BOOL(scale_stride_j == 0, S1_ZERO, [&] { vllm::scaled_fp8_quant_kernel_strided_group_shape< scalar_t, fp8_t, S0_ZERO, S1_ZERO> <<>>( - out.data_ptr(), input.data_ptr(), - scale.data_ptr(), hidden_size, in_row_stride, - out_row_stride, group_m, group_n, scale_stride_i, - scale_stride_j); + out.mutable_data_ptr(), + input.const_data_ptr(), + scale.const_data_ptr(), hidden_size, + in_row_stride, out_row_stride, group_m, group_n, + scale_stride_i, scale_stride_j); }); }); }); }); } -void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] - torch::Tensor const& input, // [..., d] - torch::Tensor& scale) // [1] +void dynamic_scaled_fp8_quant(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor const& input, // [..., d] + torch::stable::Tensor& scale) // [1] { - TORCH_CHECK(input.stride(-1) == 1, - "last dimension of input must be contiguous"); - TORCH_CHECK(out.stride(-1) == 1, - "last dimension of output must be contiguous"); + STD_TORCH_CHECK(input.stride(-1) == 1, + "last dimension of input must be contiguous"); + STD_TORCH_CHECK(out.stride(-1) == 1, + "last dimension of output must be contiguous"); const int hidden_size = input.size(-1); const int num_tokens = input.numel() / hidden_size; @@ -340,40 +352,43 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] const int64_t in_row_stride = input.stride(-2); const int64_t out_row_stride = out.stride(-2); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); // scale tensor should be initialised to <=0 before reduction - AT_CUDA_CHECK( - cudaMemsetAsync(scale.data_ptr(), 0, sizeof(float), stream)); + STD_CUDA_CHECK(cudaMemsetAsync(scale.mutable_data_ptr(), 0, + sizeof(float), stream)); - VLLM_DISPATCH_FLOATING_TYPES( + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { - VLLM_DISPATCH_FP8_TYPES( + VLLM_STABLE_DISPATCH_FP8_TYPES( out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] { vllm::segmented_max_reduction_strided <<>>( - scale.data_ptr(), input.data_ptr(), - hidden_size, in_row_stride, - static_cast(num_tokens)); + scale.mutable_data_ptr(), + input.const_data_ptr(), hidden_size, + in_row_stride, static_cast(num_tokens)); vllm::scaled_fp8_quant_kernel_strided_dynamic - <<>>( - out.data_ptr(), input.data_ptr(), - scale.data_ptr(), hidden_size, in_row_stride, - out_row_stride); + <<>>(out.mutable_data_ptr(), + input.const_data_ptr(), + scale.const_data_ptr(), + hidden_size, in_row_stride, + out_row_stride); }); }); } void dynamic_per_token_scaled_fp8_quant( - torch::Tensor& out, // [..., d] - torch::Tensor const& input, // [..., d] - torch::Tensor& scales, std::optional const& scale_ub) { - TORCH_CHECK(input.stride(-1) == 1, - "last dimension of input must be contiguous"); - TORCH_CHECK(out.stride(-1) == 1, - "last dimension of output must be contiguous"); + torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor const& input, // [..., d] + torch::stable::Tensor& scales, + std::optional const& scale_ub) { + STD_TORCH_CHECK(input.stride(-1) == 1, + "last dimension of input must be contiguous"); + STD_TORCH_CHECK(out.stride(-1) == 1, + "last dimension of output must be contiguous"); const int hidden_size = input.size(-1); const int num_tokens = input.numel() / hidden_size; @@ -384,20 +399,24 @@ void dynamic_per_token_scaled_fp8_quant( const int64_t in_row_stride = input.stride(-2); const int64_t out_row_stride = out.stride(-2); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] { - VLLM_DISPATCH_FP8_TYPES( + VLLM_STABLE_DISPATCH_FP8_TYPES( out.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] { - vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided< - scalar_t, fp8_t><<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - hidden_size, in_row_stride, out_row_stride); + vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided + <<>>( + out.mutable_data_ptr(), + scales.mutable_data_ptr(), + input.const_data_ptr(), + scale_ub.has_value() ? scale_ub->const_data_ptr() + : nullptr, + hidden_size, in_row_stride, out_row_stride); }); }); } diff --git a/csrc/quantization/w8a8/int8/scaled_quant.cu b/csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu similarity index 79% rename from csrc/quantization/w8a8/int8/scaled_quant.cu rename to csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu index ae1395a363c..ede7913a355 100644 --- a/csrc/quantization/w8a8/int8/scaled_quant.cu +++ b/csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu @@ -1,12 +1,11 @@ -#include -#include -#include +#include #include -#include "dispatch_utils.h" -#include "libtorch_stable/quantization/vectorization_utils.cuh" -#include "cub_helpers.h" +#include "../../../dispatch_utils.h" +#include "../../../torch_utils.h" +#include "../../vectorization_utils.cuh" +#include "../../../../cub_helpers.h" static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM @@ -263,66 +262,73 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( } // namespace vllm -void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] - torch::Tensor const& input, // [..., hidden_size] - torch::Tensor const& scale, - std::optional const& azp) { - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(scale.numel() == 1); - TORCH_CHECK(!azp || azp->numel() == 1); +void static_scaled_int8_quant( + torch::stable::Tensor& out, // [..., hidden_size] + torch::stable::Tensor const& input, // [..., hidden_size] + torch::stable::Tensor const& scale, + std::optional const& azp) { + STD_TORCH_CHECK(input.is_contiguous()); + STD_TORCH_CHECK(out.is_contiguous()); + STD_TORCH_CHECK(scale.numel() == 1); + STD_TORCH_CHECK(!azp || azp->numel() == 1); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); dim3 const block(std::min(hidden_size, 256)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { if (!azp) { vllm::static_scaled_int8_quant_kernel - <<>>( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), hidden_size); + <<>>(input.const_data_ptr(), + out.mutable_data_ptr(), + scale.const_data_ptr(), + hidden_size); } else { vllm::static_scaled_int8_azp_quant_kernel <<>>( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), azp->data_ptr(), - hidden_size); + input.const_data_ptr(), + out.mutable_data_ptr(), scale.const_data_ptr(), + azp->const_data_ptr(), hidden_size); } }); } void dynamic_scaled_int8_quant( - torch::Tensor& out, // [..., hidden_size] - torch::Tensor const& input, // [..., hidden_size] - torch::Tensor& scales, std::optional const& azp) { - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(scales.is_contiguous()); - TORCH_CHECK(!azp || azp->is_contiguous()); + torch::stable::Tensor& out, // [..., hidden_size] + torch::stable::Tensor const& input, // [..., hidden_size] + torch::stable::Tensor& scales, + std::optional const& azp) { + STD_TORCH_CHECK(input.is_contiguous()); + STD_TORCH_CHECK(out.is_contiguous()); + STD_TORCH_CHECK(scales.is_contiguous()); + STD_TORCH_CHECK(!azp || azp->is_contiguous()); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); dim3 const block(std::min(hidden_size, 256)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(); + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { if (!azp) { vllm::dynamic_scaled_int8_quant_kernel - <<>>( - input.data_ptr(), out.data_ptr(), - scales.data_ptr(), hidden_size); + <<>>(input.const_data_ptr(), + out.mutable_data_ptr(), + scales.mutable_data_ptr(), + hidden_size); } else { vllm::dynamic_scaled_int8_azp_quant_kernel - <<>>( - input.data_ptr(), out.data_ptr(), - scales.data_ptr(), azp->data_ptr(), - hidden_size); + <<>>(input.const_data_ptr(), + out.mutable_data_ptr(), + scales.mutable_data_ptr(), + azp->mutable_data_ptr(), + hidden_size); } }); } \ No newline at end of file diff --git a/csrc/libtorch_stable/torch_bindings.cpp b/csrc/libtorch_stable/torch_bindings.cpp index 0bbccd4222f..ee0af3da560 100644 --- a/csrc/libtorch_stable/torch_bindings.cpp +++ b/csrc/libtorch_stable/torch_bindings.cpp @@ -266,6 +266,109 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) { // Hadamard transforms // conditionally compiled so impl registration is in source file ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor"); + + // Activation ops + // Activation function used in SwiGLU. + ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()"); + ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); + + // SwiGLU activation with input clamping. + ops.def( + "silu_and_mul_with_clamp(Tensor! result, Tensor input, float limit) " + "-> ()"); + + // Activation function used in GeGLU with `none` approximation. + ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); + + // Activation function used in GeGLU with `tanh` approximation. + ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); + + // FATReLU implementation. + ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()"); + + ops.def( + "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float " + "limit=7.0) " + "-> ()"); + + // GELU implementation used in GPT-2. + ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); + + // Approximate GELU implementation. + ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); + + // Quick GELU implementation. + ops.def("gelu_quick(Tensor! out, Tensor input) -> ()"); + + // Compute int8 quantized tensor for given scaling factor. + ops.def( + "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale," + "Tensor? azp) -> ()"); + + // Compute int8 quantized tensor and scaling factor + ops.def( + "dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, " + "Tensor!? azp) -> ()"); + + // Compute FP8 quantized tensor for given scaling factor. + // Supports per-tensor, per-channel, per-token, and arbitrary 2D group + // scaling. Optional group_m/group_n specify the group shape explicitly; + // required for 1D scales to disambiguate per-channel vs per-token. + ops.def( + "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, " + "int[]? group_shape=None) -> ()"); + + // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. + ops.def( + "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) " + "-> " + "()"); + + // Compute dynamic-per-token FP8 quantized tensor and scaling factor. + ops.def( + "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, " + "Tensor! scale, Tensor? scale_ub) -> " + "()"); + + // Quantized GEMM for GPTQ. + // Note: even though the C++ inferred schema is correct for this op, it seems + // to prevent the meta function registry. + ops.def( + "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, " + "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool " + "use_v2_format, int bit) " + "-> Tensor"); + + // Post processing for GPTQ. + ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"); + + // Dequantization for GGML. + ops.def( + "ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? " + "dtype) -> Tensor"); + + // mmvq kernel for GGML. + ops.def( + "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) " + "-> Tensor"); + + // mmq kernel for GGML. + ops.def( + "ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor"); + + // moe kernel for GGML. + ops.def( + "ggml_moe_a8(Tensor X, Tensor W, " + "Tensor sorted_token_ids, Tensor expert_ids, Tensor " + "num_tokens_post_padded, " + "int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor"); + + ops.def( + "ggml_moe_a8_vec(Tensor X, Tensor W, " + "Tensor topk_ids, int top_k, " + "int type, SymInt row, SymInt tokens) -> Tensor"); + + ops.def("ggml_moe_get_block_size(int type) -> int"); } STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { @@ -312,6 +415,39 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { // AllSpark ops: conditionally compiled so impl registrations are in source // files (allspark_repack.cu and allspark_qgemm_w8a16.cu) #endif + + // Activation kernels (shared CUDA/ROCm) + ops.impl("silu_and_mul", TORCH_BOX(&silu_and_mul)); + ops.impl("mul_and_silu", TORCH_BOX(&mul_and_silu)); + ops.impl("gelu_and_mul", TORCH_BOX(&gelu_and_mul)); + ops.impl("gelu_tanh_and_mul", TORCH_BOX(&gelu_tanh_and_mul)); + ops.impl("fatrelu_and_mul", TORCH_BOX(&fatrelu_and_mul)); + ops.impl("swigluoai_and_mul", TORCH_BOX(&swigluoai_and_mul)); + ops.impl("gelu_new", TORCH_BOX(&gelu_new)); + ops.impl("gelu_fast", TORCH_BOX(&gelu_fast)); + ops.impl("gelu_quick", TORCH_BOX(&gelu_quick)); + ops.impl("silu_and_mul_with_clamp", TORCH_BOX(&silu_and_mul_clamp)); + + // INT8 quantization kernels + ops.impl("static_scaled_int8_quant", TORCH_BOX(&static_scaled_int8_quant)); + ops.impl("dynamic_scaled_int8_quant", TORCH_BOX(&dynamic_scaled_int8_quant)); + + // FP8 quantization kernels + ops.impl("static_scaled_fp8_quant", TORCH_BOX(&static_scaled_fp8_quant)); + ops.impl("dynamic_scaled_fp8_quant", TORCH_BOX(&dynamic_scaled_fp8_quant)); + ops.impl("dynamic_per_token_scaled_fp8_quant", + TORCH_BOX(&dynamic_per_token_scaled_fp8_quant)); + + // GPTQ kernels + ops.impl("gptq_gemm", TORCH_BOX(&gptq_gemm)); + ops.impl("gptq_shuffle", TORCH_BOX(&gptq_shuffle)); + + // GGML kernels + ops.impl("ggml_dequantize", TORCH_BOX(&ggml_dequantize)); + ops.impl("ggml_mul_mat_vec_a8", TORCH_BOX(&ggml_mul_mat_vec_a8)); + ops.impl("ggml_mul_mat_a8", TORCH_BOX(&ggml_mul_mat_a8)); + ops.impl("ggml_moe_a8", TORCH_BOX(&ggml_moe_a8)); + ops.impl("ggml_moe_a8_vec", TORCH_BOX(&ggml_moe_a8_vec)); } // These capability-check functions take only primitive args (no tensors), so @@ -329,6 +465,9 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) { ops.impl("cutlass_scaled_mm_supports_fp4", TORCH_BOX(&cutlass_scaled_mm_supports_fp4)); #endif + + // GGML block size lookup (no tensor args) + ops.impl("ggml_moe_get_block_size", TORCH_BOX(&ggml_moe_get_block_size)); } REGISTER_EXTENSION(_C_stable_libtorch) diff --git a/csrc/libtorch_stable/torch_utils.h b/csrc/libtorch_stable/torch_utils.h index db2ff557c41..1adbb4d4986 100644 --- a/csrc/libtorch_stable/torch_utils.h +++ b/csrc/libtorch_stable/torch_utils.h @@ -6,8 +6,12 @@ #include #include +#ifndef USE_ROCM + #include +#else + #include +#endif #include -#include #include #include diff --git a/csrc/ops.h b/csrc/ops.h index c0689534bea..d3db38aebfb 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -149,17 +149,10 @@ void persistent_masked_m_silu_mul_quant( at::Tensor& y_s, // (E, T, H//group_size) [OUT] bool use_ue8m0); -void mul_and_silu(torch::Tensor& out, torch::Tensor& input); - void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); -void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input, - double threshold); -void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input, - double alpha = 1.702, double limit = 7.0); - void gelu_new(torch::Tensor& out, torch::Tensor& input); void gelu_fast(torch::Tensor& out, torch::Tensor& input); @@ -174,28 +167,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor); -torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, - int64_t n, - std::optional const& dtype); - -torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, - int64_t type, int64_t row); - -torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, - int64_t row); - -torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W, - torch::Tensor sorted_token_ids, - torch::Tensor expert_ids, - torch::Tensor num_tokens_post_padded, int64_t type, - int64_t row, int64_t top_k, int64_t tokens); - -torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W, - torch::Tensor topk_ids, int64_t top_k, - int64_t type, int64_t row, int64_t tokens); - -int64_t ggml_moe_get_block_size(int64_t type); - void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale, std::optional const& azp); @@ -204,24 +175,6 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales, std::optional const& azp); -torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, bool use_v2_format, int64_t bit); - -void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); - -void static_scaled_fp8_quant( - torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale, - std::optional> group_shape = std::nullopt); - -void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor& scale); - -void dynamic_per_token_scaled_fp8_quant( - torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, - std::optional const& scale_ub); - void selective_scan_fwd( const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& C, diff --git a/csrc/quantization/utils.cuh b/csrc/quantization/utils.cuh index 73055a15287..6bb9b9fc563 100644 --- a/csrc/quantization/utils.cuh +++ b/csrc/quantization/utils.cuh @@ -7,23 +7,23 @@ */ #include -#include +#include #ifndef USE_ROCM - #include + #include #define MAYBE_HOST_DEVICE C10_HOST_DEVICE #else - #include - #include - #include + #include + #include // ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr #define MAYBE_HOST_DEVICE #endif template || - std::is_same_v || - std::is_same_v>> + typename = std::enable_if_t< + std::is_same_v || + std::is_same_v || + std::is_same_v>> struct quant_type_max { static constexpr T val() { return std::numeric_limits::max(); } }; @@ -31,9 +31,10 @@ struct quant_type_max { // Using the default max value from pytorch (240.0 0x7F) will cause accuracy // issues when running dynamic quantization. Here use 224.0 0x7E for rocm. template <> -struct quant_type_max { - static constexpr c10::Float8_e4m3fnuz val() { - return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits()); +struct quant_type_max { + static constexpr torch::headeronly::Float8_e4m3fnuz val() { + return torch::headeronly::Float8_e4m3fnuz( + 0x7E, torch::headeronly::Float8_e4m3fnuz::from_bits()); } }; @@ -42,9 +43,10 @@ MAYBE_HOST_DEVICE static constexpr T quant_type_max_v = quant_type_max::val(); template || - std::is_same_v || - std::is_same_v>> + typename = std::enable_if_t< + std::is_same_v || + std::is_same_v || + std::is_same_v>> struct min_scaling_factor { C10_DEVICE C10_ALWAYS_INLINE static float val() { return 1.0f / (quant_type_max_v * 512.0f); diff --git a/csrc/quantization/w8a8/fp8/common.cuh b/csrc/quantization/w8a8/fp8/common.cuh index 7576f717950..087f5099165 100644 --- a/csrc/quantization/w8a8/fp8/common.cuh +++ b/csrc/quantization/w8a8/fp8/common.cuh @@ -5,6 +5,19 @@ #include +// This header is shared between _C and _C_stable_libtorch targets. +// torch_utils.h provides get_device_prop(). We need to pass USE_CUDA +// to the .so to expose some of the shims used by torch_utils.h. For now +// this is only done for _C_stable_libtorch and not for _C, so we use the +// non stable at::cuda::getCurrentDeviceProperties for _C for now. +#ifdef TORCH_TARGET_VERSION + #include "../../../libtorch_stable/torch_utils.h" +#else + #ifdef USE_ROCM + #include + #endif +#endif + #ifndef USE_ROCM #include "nvidia/quant_utils.cuh" #else @@ -18,7 +31,11 @@ static bool is_fp8_ocp() { #ifndef USE_ROCM return true; #else - auto dprops = at::cuda::getCurrentDeviceProperties(); + #ifdef TORCH_TARGET_VERSION + auto* dprops = get_device_prop(); + #else + auto* dprops = at::cuda::getCurrentDeviceProperties(); + #endif std::string device_arch = dprops->gcnArchName; size_t substring = device_arch.find("gfx94"); return substring == std::string::npos; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index dde4b3028b2..b88e2bb4e68 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -77,17 +77,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor? output_scale=None) -> ()"); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); - // Activation ops - // Activation function used in SwiGLU. - ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()"); - ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); - - // SwiGLU activation with input clamping. - ops.def( - "silu_and_mul_with_clamp(Tensor! result, Tensor input, float limit) " - "-> ()"); - ops.impl("silu_and_mul_with_clamp", torch::kCUDA, &silu_and_mul_clamp); - + // Activation ops (quantized only — basic ops moved to _C_stable_libtorch) ops.def( "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); @@ -104,39 +94,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("silu_and_mul_per_block_quant", torch::kCUDA, &silu_and_mul_per_block_quant); - ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); - ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); - - // Activation function used in GeGLU with `none` approximation. - ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); - - // Activation function used in GeGLU with `tanh` approximation. - ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); - - // FATReLU implementation. - ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()"); - ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul); - - ops.def( - "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float " - "limit=7.0) " - "-> ()"); - ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul); - - // GELU implementation used in GPT-2. - ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_new", torch::kCUDA, &gelu_new); - - // Approximate GELU implementation. - ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_fast", torch::kCUDA, &gelu_fast); - - // Quick GELU implementation. - ops.def("gelu_quick(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_quick", torch::kCUDA, &gelu_quick); - // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( @@ -318,39 +275,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { #endif - // Dequantization for GGML. - ops.def( - "ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? " - "dtype) -> Tensor"); - ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize); - - // mmvq kernel for GGML. - ops.def( - "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) " - "-> Tensor"); - ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8); - - // mmq kernel for GGML. - ops.def( - "ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor"); - ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8); - - // moe kernel for GGML. - ops.def( - "ggml_moe_a8(Tensor X, Tensor W, " - "Tensor sorted_token_ids, Tensor expert_ids, Tensor " - "num_tokens_post_padded, " - "int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor"); - ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8); - - ops.def( - "ggml_moe_a8_vec(Tensor X, Tensor W, " - "Tensor topk_ids, int top_k, " - "int type, SymInt row, SymInt tokens) -> Tensor"); - ops.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec); - - ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); - #ifndef USE_ROCM // Expert-specialization mxfp8 blockscaled grouped quantization (SM100+). ops.def( @@ -370,57 +294,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { #endif - // Quantized GEMM for GPTQ. - // Note: even though the C++ inferred schema is correct for this op, it seems - // to prevent the meta function registry. - ops.def( - "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, " - "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool " - "use_v2_format, int bit) " - "-> Tensor"); - ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); - - // Post processing for GPTQ. - ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"); - ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); - - // Compute FP8 quantized tensor for given scaling factor. - // Supports per-tensor, per-channel, per-token, and arbitrary 2D group - // scaling. Optional group_m/group_n specify the group shape explicitly; - // required for 1D scales to disambiguate per-channel vs per-token. - ops.def( - "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, " - "(int, int)? group_shape=None) -> ()"); - ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); - - // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. - ops.def( - "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) " - "-> " - "()"); - ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); - - // Compute dynamic-per-token FP8 quantized tensor and scaling factor. - ops.def( - "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, " - "Tensor! scale, Tensor? scale_ub) -> " - "()"); - ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, - &dynamic_per_token_scaled_fp8_quant); - - // Compute int8 quantized tensor for given scaling factor. - ops.def( - "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale," - "Tensor? azp) -> ()"); - ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); - - // Compute int8 quantized tensor and scaling factor - ops.def( - "dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, " - "Tensor!? azp) -> ()"); - ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, - &dynamic_scaled_int8_quant); - // Mamba selective scan kernel ops.def( "selective_scan_fwd(Tensor! u, Tensor! delta," diff --git a/csrc/torch_utils.h b/csrc/torch_utils.h new file mode 100644 index 00000000000..898b9e113e1 --- /dev/null +++ b/csrc/torch_utils.h @@ -0,0 +1,17 @@ +#pragma once + +// Shared TORCH_UTILS_CHECK across both libtorch stable and unstable source +// files. Keep this header free of CUTLASS/CUTE so attention/quant headers can +// use it. +// +// If TORCH_TARGET_VERSION is defined, we are building _C_stable_libtorch.so so +// use STD_TORCH_CHECK via header-only. +// Otherwise, use TORCH_CHECK via torch/all.h. + +#ifdef TORCH_TARGET_VERSION + #include + #define TORCH_UTILS_CHECK STD_TORCH_CHECK +#else + #include + #define TORCH_UTILS_CHECK TORCH_CHECK +#endif diff --git a/setup.py b/setup.py index d8b97e33e5c..d07a5e0dad5 100644 --- a/setup.py +++ b/setup.py @@ -1045,9 +1045,7 @@ if _is_cpu(): if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm._C")) - # also _is_hip() once https://github.com/vllm-project/vllm/issues/35163 is - # fixed - if _is_cuda(): + if _is_cuda() or _is_hip(): ext_modules.append(CMakeExtension(name="vllm._C_stable_libtorch")) package_data = { diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 6c6b2c3399a..114d236f131 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -44,6 +44,11 @@ except ImportError as e: logger.warning("Failed to import from vllm._C with %r", e) # import custom ops, trigger op registration +try: + import vllm._C_stable_libtorch # noqa: F401 +except ImportError as e: + logger.warning("Failed to import from vllm._C_stable_libtorch with %r", e) + try: import vllm._rocm_C # noqa: F401 except ImportError as e: