diff --git a/CMakeLists.txt b/CMakeLists.txt index fd6c7eeffd0..cf5754137f1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -369,7 +369,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_MakeAvailable(cutlass) list(APPEND VLLM_EXT_SRC - "csrc/quantization/awq/gemm_kernels.cu" "csrc/cutlass_extensions/common.cpp") set_gencode_flags_for_srcs( @@ -501,46 +500,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") " in CUDA target architectures") endif() - # Only build AllSpark kernels if we are building for at least some compatible archs. - cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}") - if (ALLSPARK_ARCHS) - set(ALLSPARK_SRCS - "csrc/quantization/gptq_allspark/allspark_repack.cu" - "csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu") - set_gencode_flags_for_srcs( - SRCS "${ALLSPARK_SRCS}" - CUDA_ARCHS "${ALLSPARK_ARCHS}") - list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}") - message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}") - else() - message(STATUS "Not building AllSpark kernels as no compatible archs found" - " in CUDA target architectures") - endif() - - # CUTLASS MLA Archs and flags - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) - set(SRCS - "csrc/attention/mla/sm100_cutlass_mla_kernel.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${MLA_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1") - # Add MLA-specific include directories only to MLA source files - set_source_files_properties(${SRCS} - PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common") - message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}") - else() - message(STATUS "Not building CUTLASS MLA as no compatible archs were found.") - # clear MLA_ARCHS - set(MLA_ARCHS) - endif() - # Expert-specialization MXFP8 blockscaled grouped kernels (SM100+). if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") @@ -568,24 +527,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() - # DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later) - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_FUSED_A_GEMM_ARCHS) - set(DSV3_FUSED_A_GEMM_SRC "csrc/dsv3_fused_a_gemm.cu") - set_gencode_flags_for_srcs( - SRCS "${DSV3_FUSED_A_GEMM_SRC}" - CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}") - list(APPEND VLLM_EXT_SRC ${DSV3_FUSED_A_GEMM_SRC}) - message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}") - else() - message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found " - "in CUDA target architectures.") - endif() - # # Machete kernels @@ -657,16 +598,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() - # Hadacore kernels - cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") - if(HADACORE_ARCHS) - set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${HADACORE_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - message(STATUS "Building hadacore") - endif() # if CUDA endif endif() @@ -716,7 +647,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_STABLE_EXT_SRC "csrc/libtorch_stable/permute_cols.cu" "csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu" - "csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu") + "csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu" + "csrc/libtorch_stable/quantization/awq/gemm_kernels.cu") endif() if(VLLM_GPU_LANG STREQUAL "CUDA") @@ -725,6 +657,40 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") CUDA_ARCHS "${CUDA_ARCHS}") endif() + # DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_FUSED_A_GEMM_ARCHS) + set(SRCS "csrc/libtorch_stable/dsv3_fused_a_gemm.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}") + else() + message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found " + "in CUDA target architectures.") + endif() + + # Only build AllSpark kernels if we are building for at least some compatible archs. + cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}") + if (ALLSPARK_ARCHS) + set(SRCS + "csrc/libtorch_stable/quantization/gptq_allspark/allspark_repack.cu" + "csrc/libtorch_stable/quantization/gptq_allspark/allspark_qgemm_w8a16.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${ALLSPARK_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}") + else() + message(STATUS "Not building AllSpark kernels as no compatible archs found" + " in CUDA target architectures") + endif() + # # CUTLASS scaled_mm kernels (moved from _C to _C_stable_libtorch) # @@ -1034,6 +1000,41 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() + # CUTLASS MLA Archs and flags + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) + set(SRCS + "csrc/libtorch_stable/attention/mla/sm100_cutlass_mla_kernel.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${MLA_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1") + # Add MLA-specific include directories only to MLA source files + set_source_files_properties(${SRCS} + PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common") + message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}") + else() + message(STATUS "Not building CUTLASS MLA as no compatible archs were found.") + # clear MLA_ARCHS + set(MLA_ARCHS) + endif() + + # Hadacore kernels + cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") + if(HADACORE_ARCHS) + set(SRCS "csrc/libtorch_stable/quantization/hadamard/hadacore/hadamard_transform_cuda.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${HADACORE_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + message(STATUS "Building hadacore") + endif() + message(STATUS "Enabling C_stable extension.") define_extension_target( _C_stable_libtorch diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index 68a8750f583..b6f39ed795f 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -1,7 +1,13 @@ #pragma once -// For TORCH_CHECK -#include +#include +#include +#include +#include +#include + +// For STD_TORCH_CHECK +#include namespace vllm { @@ -45,7 +51,7 @@ class ScalarType { // IEEE 754 compliant floating point type static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) { - TORCH_CHECK(mantissa > 0 && exponent > 0); + STD_TORCH_CHECK(mantissa > 0 && exponent > 0); return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); } @@ -53,11 +59,12 @@ class ScalarType { static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) { - TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); - TORCH_CHECK(mantissa > 0 && exponent > 0); - TORCH_CHECK(nan_repr != NAN_IEEE_754, - "use `float_IEEE754` constructor for floating point types that " - "follow IEEE 754 conventions"); + STD_TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); + STD_TORCH_CHECK(mantissa > 0 && exponent > 0); + STD_TORCH_CHECK( + nan_repr != NAN_IEEE_754, + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions"); return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr); } @@ -176,8 +183,8 @@ class ScalarType { private: double _floating_point_max() const { - TORCH_CHECK(mantissa <= 52 && exponent <= 11, - "Cannot represent max/min as a double for type ", str()); + STD_TORCH_CHECK(mantissa <= 52 && exponent <= 11, + "Cannot represent max/min as a double for type ", str()); uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { @@ -186,8 +193,8 @@ class ScalarType { uint64_t max_exponent = (uint64_t(1) << exponent) - 2; if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { - TORCH_CHECK(exponent < 11, - "Cannot represent max/min as a double for type ", str()); + STD_TORCH_CHECK(exponent < 11, + "Cannot represent max/min as a double for type ", str()); max_exponent += 1; } @@ -216,16 +223,17 @@ class ScalarType { if (is_floating_point()) { return {_floating_point_max()}; } else { - TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), - "Cannot represent max as a int64_t"); + STD_TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), + "Cannot represent max as a int64_t"); return {(int64_t(1) << mantissa) - 1}; } } constexpr std::variant _raw_min() const { if (is_floating_point()) { - TORCH_CHECK(is_signed(), - "We currently assume all floating point types are signed"); + STD_TORCH_CHECK( + is_signed(), + "We currently assume all floating point types are signed"); constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); double max = _floating_point_max(); @@ -233,8 +241,8 @@ class ScalarType { uint64_t min_raw = max_raw | sign_bit_double; return {*reinterpret_cast(&min_raw)}; } else { - TORCH_CHECK(!is_signed() || size_bits() <= 64, - "Cannot represent min as a int64_t"); + STD_TORCH_CHECK(!is_signed() || size_bits() <= 64, + "Cannot represent min as a int64_t"); if (is_signed()) { // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 // then perform an arithmetic shift right to set all the bits above diff --git a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp b/csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp similarity index 100% rename from csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp rename to csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp b/csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp similarity index 100% rename from csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp rename to csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp similarity index 100% rename from csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp rename to csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp b/csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp similarity index 100% rename from csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp rename to csrc/libtorch_stable/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/libtorch_stable/attention/mla/sm100_cutlass_mla_kernel.cu similarity index 77% rename from csrc/attention/mla/sm100_cutlass_mla_kernel.cu rename to csrc/libtorch_stable/attention/mla/sm100_cutlass_mla_kernel.cu index d1874515cc8..55d75383476 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/libtorch_stable/attention/mla/sm100_cutlass_mla_kernel.cu @@ -18,13 +18,12 @@ limitations under the License. * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 * by Alcanderian JieXin Liang */ -#include "core/registration.h" +#include "libtorch_stable/torch_utils.h" + +#include -#include -#include #include #include -#include #include #include @@ -35,27 +34,27 @@ limitations under the License. // clang-format off #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 void sm100_cutlass_mla_decode( - torch::Tensor const& out, - torch::Tensor const& lse, - torch::Tensor const& q_nope, - torch::Tensor const& q_pe, - torch::Tensor const& kv_c_and_k_pe_cache, - torch::Tensor const& seq_lens, - torch::Tensor const& page_table, - torch::Tensor const& workspace, + torch::stable::Tensor const& out, + torch::stable::Tensor const& lse, + torch::stable::Tensor const& q_nope, + torch::stable::Tensor const& q_pe, + torch::stable::Tensor const& kv_c_and_k_pe_cache, + torch::stable::Tensor const& seq_lens, + torch::stable::Tensor const& page_table, + torch::stable::Tensor const& workspace, double sm_scale, int64_t num_kv_splits) { - TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode"); + STD_TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode"); } int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) { - TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size"); + STD_TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size"); } #else #define CUTLASS_CHECK(status) \ { \ cutlass::Status error = status; \ - TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ + STD_TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ } using namespace cute; @@ -100,23 +99,23 @@ struct MlaSm100 { template typename T::Fmha::Arguments args_from_options( - at::Tensor const& out, - at::Tensor const& lse, - at::Tensor const& q_nope, - at::Tensor const& q_pe, - at::Tensor const& kv_c_and_k_pe_cache, - at::Tensor const& seq_lens, - at::Tensor const& page_table, + torch::stable::Tensor const& out, + torch::stable::Tensor const& lse, + torch::stable::Tensor const& q_nope, + torch::stable::Tensor const& q_pe, + torch::stable::Tensor const& kv_c_and_k_pe_cache, + torch::stable::Tensor const& seq_lens, + torch::stable::Tensor const& page_table, double sm_scale, int64_t num_kv_splits) { cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = q_nope.device().index(); + hw_info.device_id = q_nope.get_device_index(); hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - int batches = q_nope.sizes()[0]; - int page_count_per_seq = page_table.sizes()[1]; - int page_count_total = kv_c_and_k_pe_cache.sizes()[0]; - int page_size = kv_c_and_k_pe_cache.sizes()[1]; + int batches = q_nope.size(0); + int page_count_per_seq = page_table.size(1); + int page_count_total = kv_c_and_k_pe_cache.size(0); + int page_size = kv_c_and_k_pe_cache.size(1); int max_seq_len = page_size * page_count_per_seq; using TileShapeH = typename T::TileShapeH; using TileShapeD = typename T::TileShapeD; @@ -186,14 +185,14 @@ typename T::Fmha::Arguments args_from_options( template void runMla( - at::Tensor const& out, - at::Tensor const& lse, - at::Tensor const& q_nope, - at::Tensor const& q_pe, - at::Tensor const& kv_c_and_k_pe_cache, - at::Tensor const& seq_lens, - at::Tensor const& page_table, - at::Tensor const& workspace, + torch::stable::Tensor const& out, + torch::stable::Tensor const& lse, + torch::stable::Tensor const& q_nope, + torch::stable::Tensor const& q_pe, + torch::stable::Tensor const& kv_c_and_k_pe_cache, + torch::stable::Tensor const& seq_lens, + torch::stable::Tensor const& page_table, + torch::stable::Tensor const& workspace, double sm_scale, int64_t num_kv_splits, cudaStream_t stream) { @@ -220,37 +219,37 @@ void runMla( }() void sm100_cutlass_mla_decode( - torch::Tensor const& out, - torch::Tensor const& lse, - torch::Tensor const& q_nope, - torch::Tensor const& q_pe, - torch::Tensor const& kv_c_and_k_pe_cache, - torch::Tensor const& seq_lens, - torch::Tensor const& page_table, - torch::Tensor const& workspace, + torch::stable::Tensor const& out, + torch::stable::Tensor const& lse, + torch::stable::Tensor const& q_nope, + torch::stable::Tensor const& q_pe, + torch::stable::Tensor const& kv_c_and_k_pe_cache, + torch::stable::Tensor const& seq_lens, + torch::stable::Tensor const& page_table, + torch::stable::Tensor const& workspace, double sm_scale, int64_t num_kv_splits) { - auto in_dtype = q_nope.dtype(); - at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()}; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device()); - const int page_size = kv_c_and_k_pe_cache.sizes()[1]; - + auto in_dtype = q_nope.scalar_type(); + torch::stable::accelerator::DeviceGuard device_guard(q_nope.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(q_nope.get_device_index()); + const int page_size = kv_c_and_k_pe_cache.size(1); + // NOTE(alcanderian): IsPersistent has bug with manual split_kv. // Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8) // Maybe per batch split kv will fix this. DISPATCH_BOOL(page_size == 128, IsPaged128, [&] { DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] { - if (in_dtype == at::ScalarType::Half) { + if (in_dtype == torch::headeronly::ScalarType::Half) { runMla>( out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); - } else if (in_dtype == at::ScalarType::BFloat16) { + } else if (in_dtype == torch::headeronly::ScalarType::BFloat16) { runMla>( out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); - } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { + } else if (in_dtype == torch::headeronly::ScalarType::Float8_e4m3fn) { runMla>( out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else { - TORCH_CHECK(false, "Unsupported input data type of MLA"); + STD_TORCH_CHECK(false, "Unsupported input data type of MLA"); } return true; }); @@ -280,12 +279,12 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba #endif -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode); +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + m.impl("sm100_cutlass_mla_decode", TORCH_BOX(&sm100_cutlass_mla_decode)); } -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) { - m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size); +STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, m) { + m.impl("sm100_cutlass_mla_get_workspace_size", TORCH_BOX(&sm100_cutlass_mla_get_workspace_size)); } // clang-format on diff --git a/csrc/dsv3_fused_a_gemm.cu b/csrc/libtorch_stable/dsv3_fused_a_gemm.cu similarity index 93% rename from csrc/dsv3_fused_a_gemm.cu rename to csrc/libtorch_stable/dsv3_fused_a_gemm.cu index 65dff9c84ba..bdf749ddfcf 100644 --- a/csrc/dsv3_fused_a_gemm.cu +++ b/csrc/libtorch_stable/dsv3_fused_a_gemm.cu @@ -20,13 +20,15 @@ * limitations under the License. */ -#include -#include -#include -#include -#include +#include +#include +#include #include "core/registration.h" +#include "libtorch_stable/torch_utils.h" + +#include +#include #include #include @@ -34,7 +36,7 @@ namespace { inline int getSMVersion() { - auto* props = at::cuda::getCurrentDeviceProperties(); + auto* props = get_device_prop(); return props->major * 10 + props->minor; } @@ -700,37 +702,40 @@ template void invokeFusedAGemm<__nv_bfloat16, 7168, 2112, 16>( __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, int num_tokens, cudaStream_t); -void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, - torch::Tensor const& mat_b) { - TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2 && output.dim() == 2); +void dsv3_fused_a_gemm(torch::stable::Tensor& output, + torch::stable::Tensor const& mat_a, + torch::stable::Tensor const& mat_b) { + STD_TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2 && output.dim() == 2); int const num_tokens = mat_a.size(0); int const hd_in = mat_a.size(1); int const hd_out = mat_b.size(1); constexpr int kHdIn = 7168; constexpr int kHdOut = 2112; - TORCH_CHECK(num_tokens >= 1 && num_tokens <= 16, - "required 1 <= mat_a.shape[0] <= 16") - TORCH_CHECK(hd_in == kHdIn, "required mat_a.shape[1] == 7168") - TORCH_CHECK(hd_out == kHdOut, "required mat_b.shape[1] == 2112") - TORCH_CHECK(output.size(0) == num_tokens, - "required output.shape[0] == mat_a.shape[0]") - TORCH_CHECK(output.size(1) == hd_out, - "required output.shape[1] == mat_b.shape[1]") + STD_TORCH_CHECK(num_tokens >= 1 && num_tokens <= 16, + "required 1 <= mat_a.shape[0] <= 16"); + STD_TORCH_CHECK(hd_in == kHdIn, "required mat_a.shape[1] == 7168"); + STD_TORCH_CHECK(hd_out == kHdOut, "required mat_b.shape[1] == 2112"); + STD_TORCH_CHECK(output.size(0) == num_tokens, + "required output.shape[0] == mat_a.shape[0]"); + STD_TORCH_CHECK(output.size(1) == hd_out, + "required output.shape[1] == mat_b.shape[1]"); - TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); - TORCH_CHECK(output.stride(1) == 1, "output must be a row major tensor"); - TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor"); + STD_TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); + STD_TORCH_CHECK(output.stride(1) == 1, "output must be a row major tensor"); + STD_TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor"); - TORCH_CHECK(mat_a.scalar_type() == torch::kBFloat16 && - mat_b.scalar_type() == torch::kBFloat16, - "Only BFloat16 input dtype is supported") - TORCH_CHECK(output.scalar_type() == torch::kBFloat16, - "Only BFloat16 output dtype is supported") + STD_TORCH_CHECK( + mat_a.scalar_type() == torch::headeronly::ScalarType::BFloat16 && + mat_b.scalar_type() == torch::headeronly::ScalarType::BFloat16, + "Only BFloat16 input dtype is supported"); + STD_TORCH_CHECK( + output.scalar_type() == torch::headeronly::ScalarType::BFloat16, + "Only BFloat16 output dtype is supported"); - TORCH_CHECK(getSMVersion() >= 90, "required CUDA ARCH >= SM_90"); + STD_TORCH_CHECK(getSMVersion() >= 90, "required CUDA ARCH >= SM_90"); - auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device()); + auto stream = get_current_cuda_stream(mat_a.get_device_index()); if (num_tokens <= 8) { invokeFusedAGemm<__nv_bfloat16, kHdIn, kHdOut, 8>( reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()), @@ -746,6 +751,6 @@ void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, } } -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("dsv3_fused_a_gemm", &dsv3_fused_a_gemm); +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + m.impl("dsv3_fused_a_gemm", TORCH_BOX(&dsv3_fused_a_gemm)); } diff --git a/csrc/libtorch_stable/ops.h b/csrc/libtorch_stable/ops.h index 176cd500633..cdae5fff60f 100644 --- a/csrc/libtorch_stable/ops.h +++ b/csrc/libtorch_stable/ops.h @@ -143,4 +143,26 @@ void cutlass_mxfp4_group_mm(torch::stable::Tensor& output, const torch::stable::Tensor& expert_offsets, const torch::stable::Tensor& sf_offsets); +// AWQ ops +torch::stable::Tensor awq_gemm(torch::stable::Tensor _in_feats, + torch::stable::Tensor _kernel, + torch::stable::Tensor _scaling_factors, + torch::stable::Tensor _zeros, + int64_t split_k_iters); + +torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel, + torch::stable::Tensor _scaling_factors, + torch::stable::Tensor _zeros, + int64_t split_k_iters, int64_t thx, + int64_t thy); + +// DSV3 fused A GEMM: conditionally compiled so declaration and impl +// registration are in the source file (dsv3_fused_a_gemm.cu) + +// AllSpark ops: declarations are in the source files +// (allspark_repack.cu and allspark_qgemm_w8a16.cu) + #endif + +torch::stable::Tensor hadacore_transform(torch::stable::Tensor& x, + bool inplace); diff --git a/csrc/quantization/awq/dequantize.cuh b/csrc/libtorch_stable/quantization/awq/dequantize.cuh similarity index 100% rename from csrc/quantization/awq/dequantize.cuh rename to csrc/libtorch_stable/quantization/awq/dequantize.cuh diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/libtorch_stable/quantization/awq/gemm_kernels.cu similarity index 89% rename from csrc/quantization/awq/gemm_kernels.cu rename to csrc/libtorch_stable/quantization/awq/gemm_kernels.cu index 53c47679cdd..c3702c52efc 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/libtorch_stable/quantization/awq/gemm_kernels.cu @@ -7,10 +7,11 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ -#include -#include +#include +#include +#include "libtorch_stable/torch_utils.h" -#include "dequantize.cuh" +#include "libtorch_stable/quantization/awq/dequantize.cuh" #include @@ -410,10 +411,11 @@ __global__ void __launch_bounds__(64) } // namespace awq } // namespace vllm -torch::Tensor awq_dequantize(torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, int64_t split_k_iters, - int64_t thx, int64_t thy) { +torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel, + torch::stable::Tensor _scaling_factors, + torch::stable::Tensor _zeros, + int64_t split_k_iters, int64_t thx, + int64_t thy) { int in_c = _kernel.size(0); int qout_c = _kernel.size(1); int out_c = qout_c * 8; @@ -437,23 +439,24 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, y_blocks = (int)(in_c / 8); } - const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); + const torch::stable::accelerator::DeviceGuard device_guard( + _scaling_factors.get_device_index()); - auto options = torch::TensorOptions() - .dtype(_scaling_factors.dtype()) - .device(_scaling_factors.device()); - at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); + auto _de_kernel = + torch::stable::empty({in_c, out_c}, _scaling_factors.scalar_type(), + std::nullopt, _scaling_factors.device()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); - auto scaling_factors = - reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto kernel = reinterpret_cast(_kernel.mutable_data_ptr()); + auto de_kernel = reinterpret_cast( + _de_kernel.mutable_data_ptr()); + auto scaling_factors = reinterpret_cast( + _scaling_factors.mutable_data_ptr()); + auto zeros = reinterpret_cast(_zeros.mutable_data_ptr()); dim3 num_blocks(x_blocks, y_blocks); dim3 threads_per_block(x_thread, y_thread); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); vllm::awq::dequantize_weights<<>>( kernel, scaling_factors, zeros, de_kernel, G); @@ -466,27 +469,30 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] // assume that batch_size < 16 for now -torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, - int64_t split_k_iters) { +torch::stable::Tensor awq_gemm(torch::stable::Tensor _in_feats, + torch::stable::Tensor _kernel, + torch::stable::Tensor _scaling_factors, + torch::stable::Tensor _zeros, + int64_t split_k_iters) { int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + const torch::stable::accelerator::DeviceGuard device_guard( + _in_feats.get_device_index()); - auto options = torch::TensorOptions() - .dtype(_in_feats.dtype()) - .device(_in_feats.device()); - at::Tensor _out_feats = - torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); + auto _out_feats = torch::stable::empty( + {split_k_iters, num_in_feats, _kernel.size(1) * 8}, + _in_feats.scalar_type(), std::nullopt, _in_feats.device()); int num_out_feats = _out_feats.size(-2); int num_out_channels = _out_feats.size(-1); - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = - reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto in_feats = reinterpret_cast( + _in_feats.mutable_data_ptr()); + auto kernel = reinterpret_cast(_kernel.mutable_data_ptr()); + auto out_feats = reinterpret_cast( + _out_feats.mutable_data_ptr()); + auto scaling_factors = reinterpret_cast( + _scaling_factors.mutable_data_ptr()); + auto zeros = reinterpret_cast(_zeros.mutable_data_ptr()); int group_size = num_in_channels / _scaling_factors.size(0); if (num_out_channels % 64 != 0) @@ -498,7 +504,7 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, if (num_out_channels % group_size != 0) throw std::invalid_argument("OC is not multiple of Group size"); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const cudaStream_t stream = get_current_cuda_stream(); if (num_out_channels % 128 == 0) { int j_factors1 = num_out_channels / 128 / 1; dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); @@ -522,5 +528,5 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); } - return _out_feats.sum(0); + return torch::stable::sum(_out_feats, 0); } diff --git a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu b/csrc/libtorch_stable/quantization/gptq_allspark/allspark_qgemm_w8a16.cu similarity index 92% rename from csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu rename to csrc/libtorch_stable/quantization/gptq_allspark/allspark_qgemm_w8a16.cu index e306ff02605..96dc3ecfc86 100644 --- a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu +++ b/csrc/libtorch_stable/quantization/gptq_allspark/allspark_qgemm_w8a16.cu @@ -1,20 +1,28 @@ #include "allspark_utils.cuh" -#include -#include "core/registration.h" + +#include +#include +#include +#include + #include -at::Tensor as_g_workspace; +#include "core/registration.h" +#include "libtorch_stable/torch_utils.h" + +torch::stable::Tensor as_g_workspace; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -torch::Tensor allspark_w8a16_gemm( - torch::Tensor const& a, torch::Tensor const& b_qweight, - torch::Tensor const& b_scales, std::optional const& b_qzeros, - int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version, +torch::stable::Tensor allspark_w8a16_gemm( + torch::stable::Tensor const& a, torch::stable::Tensor const& b_qweight, + torch::stable::Tensor const& b_scales, + std::optional const& b_qzeros, int64_t n, + int64_t group_size, int64_t sm_count, int64_t sm_version, int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) { - TORCH_CHECK_NOT_IMPLEMENTED( + STD_TORCH_CHECK_NOT_IMPLEMENTED( false, "allspark_w8a16_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); + return torch::stable::empty({1, 1}); } #else @@ -848,8 +856,8 @@ void restore_N32_K16_dequantize_rhs_w8a16(const QT* qdata, const FT* scales, const int N_32align, const int N, const int K, const int GroupSize, cudaStream_t stream) { - TORCH_CHECK(N % 8 == 0 && K % 16 == 0 && N_32align % 32 == 0, - "Unsupported shape"); + STD_TORCH_CHECK(N % 8 == 0 && K % 16 == 0 && N_32align % 32 == 0, + "Unsupported shape"); if (GroupSize == -1) { const int BLOCK = 128; dim3 grid(N_32align / 32, ((K / 16) + 3) / 4); @@ -859,7 +867,7 @@ void restore_N32_K16_dequantize_rhs_w8a16(const QT* qdata, const FT* scales, } // TODO: Support SubChannel else { - TORCH_CHECK(false, "Now only support PerChannel"); + STD_TORCH_CHECK(false, "Now only support PerChannel"); } } @@ -916,24 +924,27 @@ void allspark_qgemm_w8a16_perc_ampere( } // namespace allspark -torch::Tensor allspark_w8a16_gemm( - torch::Tensor const& a, torch::Tensor const& b_qweight, - torch::Tensor const& b_scales, std::optional const& b_qzeros, - int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version, +torch::stable::Tensor allspark_w8a16_gemm( + torch::stable::Tensor const& a, torch::stable::Tensor const& b_qweight, + torch::stable::Tensor const& b_scales, + std::optional const& b_qzeros, int64_t n, + int64_t group_size, int64_t sm_count, int64_t sm_version, int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) { // Verify device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + STD_TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + STD_TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU"); - TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous"); + STD_TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU"); + STD_TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous"); - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + STD_TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + STD_TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); if (has_zp) { - TORCH_CHECK(b_qzeros.value().device().is_cuda(), "b_qzeros is not on GPU"); - TORCH_CHECK(b_qzeros.value().is_contiguous(), "b_qzeros is not contiguous"); + STD_TORCH_CHECK(b_qzeros.value().device().is_cuda(), + "b_qzeros is not on GPU"); + STD_TORCH_CHECK(b_qzeros.value().is_contiguous(), + "b_qzeros is not contiguous"); } int m = a.size(0); @@ -941,16 +952,17 @@ torch::Tensor allspark_w8a16_gemm( int k = a.size(1); // Verify shape - TORCH_CHECK(b_qweight.size(0) == n_32align, - "Shape mismatch: b_qweight.size(0) = ", b_qweight.size(0), - ", n_32align = ", n_32align); - TORCH_CHECK(b_qweight.size(1) == k, - "Shape mismatch: b_qweight.size(1) = ", b_qweight.size(1), - ", k = ", k); + STD_TORCH_CHECK(b_qweight.size(0) == n_32align, + "Shape mismatch: b_qweight.size(0) = ", b_qweight.size(0), + ", n_32align = ", n_32align); + STD_TORCH_CHECK(b_qweight.size(1) == k, + "Shape mismatch: b_qweight.size(1) = ", b_qweight.size(1), + ", k = ", k); - TORCH_CHECK(group_size == -1, "Currently only supports group_size = -1"); + STD_TORCH_CHECK(group_size == -1, "Currently only supports group_size = -1"); - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + const torch::stable::accelerator::DeviceGuard device_guard( + a.get_device_index()); const void* a_ptr = reinterpret_cast(a.data_ptr()); const uint8_t* b_ptr = reinterpret_cast(b_qweight.data_ptr()); const void* b_scale_ptr = reinterpret_cast(b_scales.data_ptr()); @@ -959,12 +971,12 @@ torch::Tensor allspark_w8a16_gemm( b_zero_ptr = reinterpret_cast(b_qzeros.value().data_ptr()); } - auto c_options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({m, n}, c_options); - void* c_ptr = reinterpret_cast(c.data_ptr()); + auto c = + torch::stable::empty({m, n}, a.scalar_type(), std::nullopt, a.device()); + void* c_ptr = reinterpret_cast(c.mutable_data_ptr()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cudaStream_t stream = get_current_cuda_stream(); + cublasHandle_t handle = get_current_cuda_blas_handle(); allspark::BlockTileSplitkParams fused_gemm_params; @@ -976,14 +988,15 @@ torch::Tensor allspark_w8a16_gemm( m, n, k, sm_count, fused_gemm_params); } - auto ws_options = torch::TensorOptions().dtype(at::kChar).device(a.device()); if (as_g_workspace.numel() < ws_size) { // ws_options: kChar, so numel() is bytes - as_g_workspace = torch::empty({long(ws_size)}, ws_options); + as_g_workspace = torch::stable::empty({static_cast(ws_size)}, + torch::headeronly::ScalarType::Char, + std::nullopt, a.device()); } void* ws = reinterpret_cast(as_g_workspace.data_ptr()); - if (a.dtype() == at::ScalarType::Half) { + if (a.scalar_type() == torch::headeronly::ScalarType::Half) { allspark::allspark_qgemm_w8a16_perc_ampere<__half, uint8_t>( reinterpret_cast(a_ptr), b_ptr, reinterpret_cast(b_scale_ptr), @@ -991,7 +1004,7 @@ torch::Tensor allspark_w8a16_gemm( reinterpret_cast<__half*>(c_ptr), m, n_32align, n, k, ws, fused_gemm_params, group_size, CUBLAS_M_THRESHOLD, sm_version, stream, handle); - } else if (a.dtype() == at::ScalarType::BFloat16) { + } else if (a.scalar_type() == torch::headeronly::ScalarType::BFloat16) { allspark::allspark_qgemm_w8a16_perc_ampere<__nv_bfloat16, uint8_t>( reinterpret_cast(a_ptr), b_ptr, reinterpret_cast(b_scale_ptr), @@ -1006,6 +1019,6 @@ torch::Tensor allspark_w8a16_gemm( #endif -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("allspark_w8a16_gemm", &allspark_w8a16_gemm); +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + m.impl("allspark_w8a16_gemm", TORCH_BOX(&allspark_w8a16_gemm)); } diff --git a/csrc/quantization/gptq_allspark/allspark_repack.cu b/csrc/libtorch_stable/quantization/gptq_allspark/allspark_repack.cu similarity index 67% rename from csrc/quantization/gptq_allspark/allspark_repack.cu rename to csrc/libtorch_stable/quantization/gptq_allspark/allspark_repack.cu index 7a5b2f95cc2..b325d30a041 100644 --- a/csrc/quantization/gptq_allspark/allspark_repack.cu +++ b/csrc/libtorch_stable/quantization/gptq_allspark/allspark_repack.cu @@ -1,6 +1,11 @@ #include "allspark_utils.cuh" -#include + +#include +#include +#include + #include "core/registration.h" +#include "libtorch_stable/torch_utils.h" namespace allspark { @@ -99,36 +104,40 @@ void rearrange_kn_weight_as_n32k16_order_ldg16( } // namespace allspark void rearrange_kn_weight_as_n32k16_order( - torch::Tensor const& b_qweight, torch::Tensor const& b_scales, - std::optional const& b_zeros, bool has_zp, - torch::Tensor& b_qweight_reorder, torch::Tensor& b_scales_reorder, - std::optional const& b_zeros_reorder, const int64_t K, - const int64_t N, const int64_t N_32align) { + torch::stable::Tensor const& b_qweight, + torch::stable::Tensor const& b_scales, + std::optional const& b_zeros, bool has_zp, + torch::stable::Tensor& b_qweight_reorder, + torch::stable::Tensor& b_scales_reorder, + std::optional const& b_zeros_reorder, + const int64_t K, const int64_t N, const int64_t N_32align) { // Verify device and strides - TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU"); - TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous"); + STD_TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU"); + STD_TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous"); - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + STD_TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + STD_TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - TORCH_CHECK(b_qweight_reorder.device().is_cuda(), - "b_qweight_reorder is not on GPU"); - TORCH_CHECK(b_qweight_reorder.is_contiguous(), - "b_qweight_reorder is not contiguous"); + STD_TORCH_CHECK(b_qweight_reorder.device().is_cuda(), + "b_qweight_reorder is not on GPU"); + STD_TORCH_CHECK(b_qweight_reorder.is_contiguous(), + "b_qweight_reorder is not contiguous"); - TORCH_CHECK(b_scales_reorder.device().is_cuda(), - "b_scales_reorder is not on GPU"); - TORCH_CHECK(b_scales_reorder.is_contiguous(), - "b_scales_reorder is not contiguous"); + STD_TORCH_CHECK(b_scales_reorder.device().is_cuda(), + "b_scales_reorder is not on GPU"); + STD_TORCH_CHECK(b_scales_reorder.is_contiguous(), + "b_scales_reorder is not contiguous"); if (has_zp) { - TORCH_CHECK(b_zeros.value().device().is_cuda(), "b_zeros is not on GPU"); - TORCH_CHECK(b_zeros.value().is_contiguous(), "b_zeros is not contiguous"); + STD_TORCH_CHECK(b_zeros.value().device().is_cuda(), + "b_zeros is not on GPU"); + STD_TORCH_CHECK(b_zeros.value().is_contiguous(), + "b_zeros is not contiguous"); - TORCH_CHECK(b_zeros_reorder.value().device().is_cuda(), - "b_zeros_reorder is not on GPU"); - TORCH_CHECK(b_zeros_reorder.value().is_contiguous(), - "b_zeros_reorder is not contiguous"); + STD_TORCH_CHECK(b_zeros_reorder.value().device().is_cuda(), + "b_zeros_reorder is not on GPU"); + STD_TORCH_CHECK(b_zeros_reorder.value().is_contiguous(), + "b_zeros_reorder is not contiguous"); } const uint8_t* matB = reinterpret_cast(b_qweight.data_ptr()); @@ -136,18 +145,20 @@ void rearrange_kn_weight_as_n32k16_order( const void* b_zero = has_zp ? b_zeros.value().data_ptr() : nullptr; uint8_t* matB_reorder = - reinterpret_cast(b_qweight_reorder.data_ptr()); - void* b_scale_reorder = b_scales_reorder.data_ptr(); - void* b_zero_reorder = has_zp ? b_zeros_reorder.value().data_ptr() : nullptr; + reinterpret_cast(b_qweight_reorder.mutable_data_ptr()); + void* b_scale_reorder = b_scales_reorder.mutable_data_ptr(); + void* b_zero_reorder = + has_zp ? b_zeros_reorder.value().mutable_data_ptr() : nullptr; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (b_scales.dtype() == at::ScalarType::Half) { + cudaStream_t stream = get_current_cuda_stream(); + if (b_scales.scalar_type() == torch::headeronly::ScalarType::Half) { allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__half>( matB, reinterpret_cast(b_scale), reinterpret_cast(b_zero), matB_reorder, reinterpret_cast<__half*>(b_scale_reorder), reinterpret_cast<__half*>(b_zero_reorder), K, N, N_32align, stream); - } else if (b_scales.dtype() == at::ScalarType::BFloat16) { + } else if (b_scales.scalar_type() == + torch::headeronly::ScalarType::BFloat16) { allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__nv_bfloat16>( matB, reinterpret_cast(b_scale), reinterpret_cast(b_zero), matB_reorder, @@ -157,7 +168,7 @@ void rearrange_kn_weight_as_n32k16_order( } } -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { m.impl("rearrange_kn_weight_as_n32k16_order", - &rearrange_kn_weight_as_n32k16_order); + TORCH_BOX(&rearrange_kn_weight_as_n32k16_order)); } diff --git a/csrc/quantization/gptq_allspark/allspark_utils.cuh b/csrc/libtorch_stable/quantization/gptq_allspark/allspark_utils.cuh similarity index 99% rename from csrc/quantization/gptq_allspark/allspark_utils.cuh rename to csrc/libtorch_stable/quantization/gptq_allspark/allspark_utils.cuh index c7a6e96aff4..ce96c2d11fe 100644 --- a/csrc/quantization/gptq_allspark/allspark_utils.cuh +++ b/csrc/libtorch_stable/quantization/gptq_allspark/allspark_utils.cuh @@ -1,13 +1,12 @@ #pragma once -#include -#include -#include -#include -#include #include +#include +#include + #include -#include "../marlin/marlin_dtypes.cuh" + +#include "quantization/marlin/marlin_dtypes.cuh" using marlin::MarlinScalarType2; namespace allspark { diff --git a/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu b/csrc/libtorch_stable/quantization/hadamard/hadacore/hadamard_transform_cuda.cu similarity index 93% rename from csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu rename to csrc/libtorch_stable/quantization/hadamard/hadacore/hadamard_transform_cuda.cu index aff11326d78..665585caa46 100644 --- a/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu +++ b/csrc/libtorch_stable/quantization/hadamard/hadacore/hadamard_transform_cuda.cu @@ -11,18 +11,16 @@ Redistribution and use in source and binary forms, with or without modification, THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ***********/ -#include +#include "libtorch_stable/torch_utils.h" +#include "libtorch_stable/dispatch_utils.h" + +#include +#include + #include #include #include #include -#include - -#include -#include - -#include "core/registration.h" -#include "dispatch_utils.h" namespace hadacore { @@ -65,12 +63,12 @@ constexpr int launch_configs_big[7][3] = { }; // a 4x2, b 2x2, c 2x2 -template +template __device__ __forceinline__ void mma_m16_n8_k16_b16_b16_b16_noacc(b32 a0, b32 a1, b32 a2, b32 a3, b32 b0, b32 b1, b32& c0, b32& c1){ - static_assert(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16); + static_assert(dtype == torch::headeronly::ScalarType::Half || dtype == torch::headeronly::ScalarType::BFloat16); // d, a, b, c b32 zero = 0; - if constexpr(dtype == torch::ScalarType::Half) { + if constexpr(dtype == torch::headeronly::ScalarType::Half) { asm ( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n\t" @@ -89,7 +87,7 @@ __device__ __forceinline__ void mma_m16_n8_k16_b16_b16_b16_noacc(b32 a0, b32 a1, } // a 4x2, b 4x2, c 4x2 -template +template __device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(b32 a0, b32 a1, b32 a2, b32 a3, b32 b0, b32 b1, b32 b2, b32 b3, b32& c0, b32& c1, b32& c2, b32& c3){ mma_m16_n8_k16_b16_b16_b16_noacc(a0, a1, a2, a3, b0, b1, c0, c1); mma_m16_n8_k16_b16_b16_b16_noacc(a0, a1, a2, a3, b2, b3, c2, c3); @@ -108,11 +106,11 @@ __device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(b32& a0) { #define n_p(i) ((val_1n[i] & 0x0000FFFF) | val_1p[i] << 16) #define n_n(i) ((val_1n[i] & 0x0000FFFF) | val_1n[i] << 16) -template +template __global__ void __launch_bounds__(32 * warps_per_block, blocks_per_sm) // a is column major, b is row major hadamard_transform_kernel(b16* a, b16* out, int total_num_chunks) { - static_assert(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16, "Only fp16 and bf16 supported currently"); + static_assert(dtype == torch::headeronly::ScalarType::Half || dtype == torch::headeronly::ScalarType::BFloat16, "Only fp16 and bf16 supported currently"); b32 b_frag_all[num_chunks][4]; // for all chunks, holds matrix fragment (which takes 4 regs of b16x2 * 32 threads) @@ -162,8 +160,8 @@ hadamard_transform_kernel(b16* a, b16* out, int total_num_chunks) { constexpr b16 bf16_1p[4] = {0b0011111100110101, 0b0011111100000000, 0b0011111010110101, 0b0011111010000000}; constexpr b16 bf16_1n[4] = {0b1011111100110101, 0b1011111100000000, 0b1011111010110101, 0b1011111010000000}; - #define val_type_1p(i) (((dtype) == torch::ScalarType::Half) ? (fp16_1p[i]) : (bf16_1p[i])) - #define val_type_1n(i) (((dtype) == torch::ScalarType::Half) ? (fp16_1n[i]) : (bf16_1n[i])) + #define val_type_1p(i) (((dtype) == torch::headeronly::ScalarType::Half) ? (fp16_1p[i]) : (bf16_1p[i])) + #define val_type_1n(i) (((dtype) == torch::headeronly::ScalarType::Half) ? (fp16_1n[i]) : (bf16_1n[i])) constexpr b16 val_1p[4] = {val_type_1p(0), val_type_1p(1), val_type_1p(2), val_type_1p(3)}; constexpr b16 val_1n[4] = {val_type_1n(0), val_type_1n(1), val_type_1n(2), val_type_1n(3)}; @@ -684,14 +682,14 @@ constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; } -template +template void __forceinline__ run_kernel(b16* a_mat, b16* out, int64_t num_chunks, cudaStream_t stream) { int64_t shared_size = chunks_per_warp * warps_per_block * 128 * 4; dim3 block_size = 32 * warps_per_block; #define CHECK_SHARED_LIM() { \ if (shared_size > 48 * 1024) { \ - C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536)); \ + STD_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536)); \ } \ } \ @@ -714,10 +712,10 @@ void __forceinline__ run_kernel(b16* a_mat, b16* out, int64_t num_chunks, cudaSt kernel<<>>(a_mat, out, num_chunks); } - C10_CUDA_KERNEL_LAUNCH_CHECK(); + STD_CUDA_KERNEL_LAUNCH_CHECK(); } -template +template void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream) { int64_t num_chunks = numel / 256; // caller required to ensure divisible by 256 // for size 256, use (2, 1) @@ -764,54 +762,54 @@ void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cu } } -template void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream); -template void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream); +template void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream); +template void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream); } // namespace hadacore constexpr bool is_power_of_two(int x) { return x && !(x & (x - 1)); } -torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace) { +torch::stable::Tensor hadacore_transform(torch::stable::Tensor& x, bool inplace) { auto dtype = x.scalar_type(); - TORCH_CHECK(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16, "Only fp16 and bf16 supported currently"); - TORCH_CHECK(x.is_cuda()); - + STD_TORCH_CHECK(dtype == torch::headeronly::ScalarType::Half || dtype == torch::headeronly::ScalarType::BFloat16, "Only fp16 and bf16 supported currently"); + STD_TORCH_CHECK(x.is_cuda()); + const int had_size = x.size(-1); - TORCH_CHECK(is_power_of_two(had_size) && (had_size <= (1U << 15)), + STD_TORCH_CHECK(is_power_of_two(had_size) && (had_size <= (1U << 15)), "Only power of two Hadamard sizes up to 2^15 are supported, got ", had_size); - + const auto res_shape = x.sizes(); - x = x.reshape({-1, had_size}); - + x = torch::stable::reshape(x, {-1, had_size}); + auto numel = x.numel(); if (numel % 256 != 0) { - x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, 0, 0, (256 - numel % 256) / had_size})); + x = torch::stable::pad(x, {0, 0, 0, (256 - numel % 256) / had_size}); } - + if (x.stride(-1) != 1) { - x = x.contiguous(); + x = torch::stable::contiguous(x); } - torch::Tensor out = inplace ? x : torch::empty_like(x); + torch::stable::Tensor out = inplace ? x : torch::stable::empty_like(x); - at::cuda::CUDAGuard device_guard{(char)x.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); + torch::stable::accelerator::DeviceGuard device_guard(x.get_device_index()); + auto stream = get_current_cuda_stream(); - VLLM_DISPATCH_HALF_TYPES(x.scalar_type(), "hadacore_transform_runfht", [&] { - auto constexpr SCALAR_TYPE = c10::CppTypeToScalarType::value; + VLLM_STABLE_DISPATCH_HALF_TYPES(x.scalar_type(), "hadacore_transform_runfht", [&] { + auto constexpr SCALAR_TYPE = torch::headeronly::CppTypeToScalarType::value; hadacore::run_fht(x.data_ptr(), x.data_ptr(), x.numel(), had_size, stream); }); if (numel % 256 != 0) { - out = out.narrow(0, 0, numel / had_size); + out = torch::stable::narrow(out, 0, 0, numel / had_size); } if (inplace && out.data_ptr() != x.data_ptr()) { - x.copy_(out.view(res_shape)); + torch::stable::copy_(x, torch::stable::view(out, res_shape)); return x; } - return out.reshape(res_shape); + return torch::stable::reshape(out, res_shape); } -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("hadacore_transform", &hadacore_transform); +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + m.impl("hadacore_transform", TORCH_BOX(&hadacore_transform)); } diff --git a/csrc/libtorch_stable/torch_bindings.cpp b/csrc/libtorch_stable/torch_bindings.cpp index 124512e8162..0bbccd4222f 100644 --- a/csrc/libtorch_stable/torch_bindings.cpp +++ b/csrc/libtorch_stable/torch_bindings.cpp @@ -218,7 +218,54 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) { ops.def( "cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, " "Tensor)"); + + // SM100 CUTLASS MLA decode + // conditionally compiled so impl registrations are in source file + ops.def( + "sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope," + " Tensor q_pe, Tensor kv_c_and_k_pe_cache," + " Tensor seq_lens, Tensor page_table," + " Tensor workspace, float scale," + " int num_kv_splits) -> ()"); + + ops.def( + "sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches," + " int sm_count, int num_kv_splits) " + "-> int"); + // Quantized GEMM for AWQ. + ops.def( + "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, " + "Tensor _zeros, SymInt split_k_iters) -> Tensor"); + + // Dequantization for AWQ. + ops.def( + "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, " + "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor"); + + // DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens). + // conditionally compiled so impl registration is in source file + ops.def( + "dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); + + // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel + ops.def( + "rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, " + "Tensor? b_zeros, " + "bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, " + "Tensor!? b_zeros_reorder, " + "int K, int N, int N_32align) -> ()"); + + // AllSpark quantization ops + ops.def( + "allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, " + "Tensor? b_qzeros, " + "SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt " + "CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor"); #endif + + // Hadamard transforms + // conditionally compiled so impl registration is in source file + ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor"); } STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { @@ -254,6 +301,16 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant)); // mxfp4_experts_quant: registered in mxfp4_experts_quant.cu (SM100 only). // W4A8 ops: registered in w4a8_mm_entry.cu / w4a8_grouped_mm_entry.cu. + + // AWQ ops + ops.impl("awq_gemm", TORCH_BOX(&awq_gemm)); + ops.impl("awq_dequantize", TORCH_BOX(&awq_dequantize)); + + // DSV3 fused A GEMM: conditionally compiled so impl registration is in + // source file (dsv3_fused_a_gemm.cu) + + // AllSpark ops: conditionally compiled so impl registrations are in source + // files (allspark_repack.cu and allspark_qgemm_w8a16.cu) #endif } diff --git a/csrc/libtorch_stable/torch_utils.h b/csrc/libtorch_stable/torch_utils.h index f5a80d63e1e..db2ff557c41 100644 --- a/csrc/libtorch_stable/torch_utils.h +++ b/csrc/libtorch_stable/torch_utils.h @@ -6,12 +6,71 @@ #include #include +#include #include +#include +#include +#include +#include + // Stable ABI equivalent of TORCH_CHECK_NOT_IMPLEMENTED. #define STD_TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \ STD_TORCH_CHECK(cond, "NotImplementedError: ", __VA_ARGS__) +// Device properties cache for stable ABI compatibility. +// Uses raw CUDA/HIP APIs instead of ATen functions. +// Using inline ensures a single instance across all translation units. +inline std::deque device_flags; +inline std::vector device_properties; +inline std::once_flag vectors_init_flag; + +inline void do_init_device_vectors() { + int device_count; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaGetDeviceCount failed: " + + std::string(cudaGetErrorString(err))); + } + device_flags.resize(device_count); + device_properties.resize(device_count); +} + +inline void initDeviceVectors() { + std::call_once(vectors_init_flag, do_init_device_vectors); +} + +inline void initDeviceProperty(int device_index) { + cudaDeviceProp device_prop{}; + cudaError_t err = cudaGetDeviceProperties(&device_prop, device_index); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " + + std::string(cudaGetErrorString(err))); + } + device_properties[device_index] = device_prop; +} + +// Get device properties using raw CUDA/HIP APIs (stable ABI compatible). +// Caches results per device so cudaGetDeviceProperties is called at most once +// per device. +inline cudaDeviceProp* get_device_prop() { + initDeviceVectors(); + int device_index; + cudaError_t err = cudaGetDevice(&device_index); + if (err != cudaSuccess) { + STD_TORCH_CHECK( + false, "cudaGetDevice failed: " + std::string(cudaGetErrorString(err))); + } + STD_TORCH_CHECK(device_index >= 0 && static_cast(device_index) < + device_properties.size(), + "CUDA device index " + std::to_string(device_index) + + " out of range [0, " + + std::to_string(device_properties.size()) + ")"); + + std::call_once(device_flags[device_index], initDeviceProperty, device_index); + return &device_properties[device_index]; +} + // Utility to get the current CUDA stream for a given device using stable APIs. // Returns a cudaStream_t for use in kernel launches. inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) { @@ -20,3 +79,10 @@ inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) { aoti_torch_get_current_cuda_stream(device_index, &stream_ptr)); return reinterpret_cast(stream_ptr); } + +// Utility to get the current cuBLAS handle using stable APIs. +inline cublasHandle_t get_current_cuda_blas_handle() { + void* blas_handle_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(torch_get_current_cuda_blas_handle(&blas_handle_ptr)); + return reinterpret_cast(blas_handle_ptr); +} diff --git a/csrc/ops.h b/csrc/ops.h index 16a78f570cf..2a0819618e3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -200,19 +200,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor); -#ifndef USE_ROCM - -torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, - int64_t split_k_iters); - -torch::Tensor awq_dequantize(torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, int64_t split_k_iters, - int64_t thx, int64_t thy); - -#endif - torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n, std::optional const& dtype); @@ -302,8 +289,6 @@ std::tuple allocate_shared_buffer_and_handle( int64_t open_mem_handle(torch::Tensor& mem_handle); void free_shared_buffer(int64_t buffer); -torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace); - #ifdef USE_ROCM fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional qr_max_size = std::nullopt); @@ -315,11 +300,6 @@ void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t qr_max_size(); #endif -#ifndef USE_ROCM -void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, - torch::Tensor const& mat_b); -#endif - #ifndef USE_ROCM torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, torch::Tensor const& norm_weight, diff --git a/csrc/quantization/marlin/marlin.cuh b/csrc/quantization/marlin/marlin.cuh index 33fe52f605b..d3a91568349 100644 --- a/csrc/quantization/marlin/marlin.cuh +++ b/csrc/quantization/marlin/marlin.cuh @@ -2,10 +2,14 @@ #ifndef _marlin_cuh #define _marlin_cuh - #include - - #include - #include + // These torch headers are only needed by non-stable callers (e.g. ops.cu). + // Guard them so that stable ABI targets can still include marlin.cuh + // for Vec, constants, and cp_async helpers without pulling in torch/all.h. + #ifndef TORCH_TARGET_VERSION + #include + #include + #include + #endif #include #include #include diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7562d90c0b9..f0f3641d29a 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -263,22 +263,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantization ops #ifndef USE_ROCM - // DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens). - ops.def( - "dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); - // conditionally compiled so impl registration is in source file - - // Quantized GEMM for AWQ. - ops.def( - "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, " - "Tensor _zeros, SymInt split_k_iters) -> Tensor"); - ops.impl("awq_gemm", torch::kCUDA, &awq_gemm); - - // Dequantization for AWQ. - ops.def( - "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, " - "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor"); - ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); // Note about marlin kernel 'workspace' arguments: // Technically these should be mutable since they are modified by the kernel. @@ -408,22 +392,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " -> ()"); // conditionally compiled so impl registration is in source file - // SM100 CUTLASS MLA decode - ops.def( - "sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope," - " Tensor q_pe, Tensor kv_c_and_k_pe_cache," - " Tensor seq_lens, Tensor page_table," - " Tensor workspace, float scale," - " int num_kv_splits) -> ()"); - // conditionally compiled so impl in source file - - // SM100 CUTLASS MLA workspace - ops.def( - "sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches," - " int sm_count, int num_kv_splits) " - "-> int"); - // conditionally compiled so impl in source file - #endif // Quantized GEMM for GPTQ. @@ -496,26 +464,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? last_chunk_indices) -> ()"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); - // Hadamard transforms - ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor"); - #ifndef USE_ROCM - // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel - ops.def( - "rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, " - "Tensor? b_zeros, " - "bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, " - "Tensor!? b_zeros_reorder, " - "int K, int N, int N_32align) -> ()"); - // conditionally compiled so impl in source file - - // AllSpark quantization ops - ops.def( - "allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, " - "Tensor? b_qzeros, " - "SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt " - "CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor"); - ops.def( "minimax_allreduce_rms(" "Tensor input,"