From 56aff0dd15c0d01abc3f1de2b773aff2ed61d7a0 Mon Sep 17 00:00:00 2001 From: Chris Leonard Date: Thu, 4 Jun 2026 23:14:43 -0400 Subject: [PATCH] [10/n] Migrate cuda_view and silu_and_mul_per_block_quant kernels to torch stale ABI. (#44334) --- CMakeLists.txt | 18 ++-- csrc/cuda_view.cu | 59 ------------- .../cuda_utils_kernels.cu | 0 csrc/libtorch_stable/cuda_view.cu | 76 +++++++++++++++++ .../cutlass_extensions/common.cpp | 2 +- .../cutlass_extensions/common.hpp | 0 csrc/libtorch_stable/ops.h | 11 +++ .../cutlass_w4a8/w4a8_grouped_mm_entry.cu | 2 +- .../cutlass_w4a8/w4a8_mm_entry.cu | 2 +- .../fp4/mxfp4_blockwise_moe_kernel.cu | 2 +- .../fp4/nvfp4_blockwise_moe_kernel.cu | 2 +- .../quantization/fp4/nvfp4_quant_entry.cu | 2 +- .../quantization/fp4/nvfp4_scaled_mm_entry.cu | 2 +- .../fp4/nvfp4_scaled_mm_kernels.cu | 2 +- .../fp4/nvfp4_scaled_mm_sm120_kernels.cu | 2 +- .../fused_silu_mul_block_quant.cu | 83 ++++++++++--------- .../w8a8/cutlass/c3x/cutlass_gemm_caller.cuh | 2 +- .../w8a8/cutlass/c3x/scaled_mm.cuh | 2 +- .../w8a8/cutlass/c3x/scaled_mm_helper.hpp | 2 +- .../w8a8/cutlass/moe/grouped_mm_c3x.cuh | 2 +- .../w8a8/cutlass/scaled_mm_c2x.cuh | 2 +- .../w8a8/cutlass/scaled_mm_entry.cu | 2 +- csrc/libtorch_stable/torch_bindings.cpp | 33 ++++++++ csrc/ops.h | 8 -- csrc/torch_bindings.cpp | 31 ------- 25 files changed, 186 insertions(+), 163 deletions(-) delete mode 100644 csrc/cuda_view.cu rename csrc/{ => libtorch_stable}/cuda_utils_kernels.cu (100%) create mode 100644 csrc/libtorch_stable/cuda_view.cu rename csrc/{ => libtorch_stable}/cutlass_extensions/common.cpp (90%) rename csrc/{ => libtorch_stable}/cutlass_extensions/common.hpp (100%) rename csrc/{ => libtorch_stable}/quantization/fused_kernels/fused_silu_mul_block_quant.cu (63%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0652a5f066e..cd4a9c0b590 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -307,10 +307,7 @@ endif() # set(VLLM_EXT_SRC - "csrc/cuda_view.cu" - "csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu" "csrc/quantization/activation_kernels.cu" - "csrc/cuda_utils_kernels.cu" "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") @@ -346,9 +343,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() FetchContent_MakeAvailable(cutlass) - list(APPEND VLLM_EXT_SRC - "csrc/cutlass_extensions/common.cpp") - set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" CUDA_ARCHS "${CUDA_ARCHS}") @@ -627,6 +621,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") # set(VLLM_STABLE_EXT_SRC "csrc/libtorch_stable/torch_bindings.cpp" + "csrc/libtorch_stable/cuda_view.cu" "csrc/libtorch_stable/activation_kernels.cu" "csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu" "csrc/libtorch_stable/quantization/w8a8/fp8/common.cu" @@ -639,6 +634,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") "csrc/libtorch_stable/layernorm_kernels.cu" "csrc/libtorch_stable/layernorm_quant_kernels.cu" "csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" + "csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu" "csrc/libtorch_stable/attention/merge_attn_states.cu" "csrc/libtorch_stable/sampler.cu" "csrc/libtorch_stable/topk.cu" @@ -653,8 +649,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_STABLE_EXT_SRC - "csrc/cuda_utils_kernels.cu" - "csrc/cutlass_extensions/common.cpp" + "csrc/libtorch_stable/cuda_utils_kernels.cu" + "csrc/libtorch_stable/cutlass_extensions/common.cpp" "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu" "csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu" "csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu" @@ -1076,11 +1072,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") WITH_SOABI) # Set TORCH_TARGET_VERSION for stable ABI compatibility. - # This ensures we only use C-shim APIs available in PyTorch 2.10. + # This ensures we only use C-shim APIs available in PyTorch 2.11. # _C_stable_libtorch is abi compatible with PyTorch >= TORCH_TARGET_VERSION - # which is currently set to 2.10. + # which is currently set to 2.11. target_compile_definitions(_C_stable_libtorch PRIVATE - TORCH_TARGET_VERSION=0x020A000000000000ULL) + TORCH_TARGET_VERSION=0x020B000000000000ULL) # Needed to use cuda/hip APIs from C-shim if(VLLM_GPU_LANG STREQUAL "CUDA") diff --git a/csrc/cuda_view.cu b/csrc/cuda_view.cu deleted file mode 100644 index 73b368cb600..00000000000 --- a/csrc/cuda_view.cu +++ /dev/null @@ -1,59 +0,0 @@ -#include -#include -#include - -// This function assumes that `cpu_tensor` is a CPU tensor, -// and that UVA (Unified Virtual Addressing) is enabled. -torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) { - TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU"); - - // handle empty tensor - if (cpu_tensor.numel() == 0) { - return torch::empty(cpu_tensor.sizes(), - cpu_tensor.options().device(torch::kCUDA)); - } - - if (cpu_tensor.is_pinned()) { - // If CPU tensor is pinned, directly get the device pointer. - void* host_ptr = const_cast(cpu_tensor.data_ptr()); - void* device_ptr = nullptr; - cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0); - TORCH_CHECK(err == cudaSuccess, - "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err)); - - return torch::from_blob( - device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(), - [base = cpu_tensor](void*) {}, // keep cpu tensor alive - cpu_tensor.options().device(torch::kCUDA)); - } - - // If CPU tensor is not pinned, allocate a new pinned memory buffer. - torch::Tensor contiguous_cpu = cpu_tensor.contiguous(); - size_t nbytes = contiguous_cpu.nbytes(); - - void* host_ptr = nullptr; - cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped); - if (err != cudaSuccess) { - AT_ERROR("cudaHostAlloc failed: ", cudaGetErrorString(err)); - } - - err = cudaMemcpy(host_ptr, contiguous_cpu.data_ptr(), nbytes, - cudaMemcpyDefault); - if (err != cudaSuccess) { - cudaFreeHost(host_ptr); - AT_ERROR("cudaMemcpy failed: ", cudaGetErrorString(err)); - } - - void* device_ptr = nullptr; - err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0); - if (err != cudaSuccess) { - cudaFreeHost(host_ptr); - AT_ERROR("cudaHostGetDevicePointer failed: ", cudaGetErrorString(err)); - } - - auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); }; - - return torch::from_blob(device_ptr, contiguous_cpu.sizes(), - contiguous_cpu.strides(), deleter, - contiguous_cpu.options().device(torch::kCUDA)); -} \ No newline at end of file diff --git a/csrc/cuda_utils_kernels.cu b/csrc/libtorch_stable/cuda_utils_kernels.cu similarity index 100% rename from csrc/cuda_utils_kernels.cu rename to csrc/libtorch_stable/cuda_utils_kernels.cu diff --git a/csrc/libtorch_stable/cuda_view.cu b/csrc/libtorch_stable/cuda_view.cu new file mode 100644 index 00000000000..7bf8267470e --- /dev/null +++ b/csrc/libtorch_stable/cuda_view.cu @@ -0,0 +1,76 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +// This function assumes that `cpu_tensor` is a CPU tensor, +// and that UVA (Unified Virtual Addressing) is enabled. +torch::stable::Tensor get_cuda_view_from_cpu_tensor( + torch::stable::Tensor& cpu_tensor) { + STD_TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU"); + + const auto dtype = cpu_tensor.scalar_type(); + const auto layout = cpu_tensor.layout(); + const torch::stable::Device cuda_dev(torch::headeronly::DeviceType::CUDA); + + // handle empty tensor + if (cpu_tensor.numel() == 0) { + return torch::stable::empty(cpu_tensor.sizes(), dtype, layout, cuda_dev); + } + + std::array is_pinned_stack{ + torch::stable::detail::from(cpu_tensor), + torch::stable::detail::from(std::nullopt)}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::is_pinned", "", is_pinned_stack.data(), TORCH_ABI_VERSION)); + if (torch::stable::detail::to(is_pinned_stack[0])) { + // If CPU tensor is pinned, directly get the device pointer. + void* host_ptr = const_cast(cpu_tensor.mutable_data_ptr()); + void* device_ptr = nullptr; + cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0); + STD_TORCH_CHECK(err == cudaSuccess, "cudaHostGetDevicePointer failed: ", + cudaGetErrorString(err)); + + return torch::stable::from_blob( + device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(), cuda_dev, dtype, + [base = cpu_tensor](void*) {}); // keep cpu tensor alive + } + + // If CPU tensor is not pinned, allocate a new pinned memory buffer. + torch::stable::Tensor contiguous_cpu = torch::stable::contiguous(cpu_tensor); + size_t nbytes = contiguous_cpu.numel() * contiguous_cpu.element_size(); + + void* host_ptr = nullptr; + cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaHostAlloc failed: ", cudaGetErrorString(err)); + } + + err = cudaMemcpy(host_ptr, contiguous_cpu.const_data_ptr(), nbytes, + cudaMemcpyDefault); + if (err != cudaSuccess) { + cudaFreeHost(host_ptr); + STD_TORCH_CHECK(false, "cudaMemcpy failed: ", cudaGetErrorString(err)); + } + + void* device_ptr = nullptr; + err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0); + if (err != cudaSuccess) { + cudaFreeHost(host_ptr); + STD_TORCH_CHECK( + false, "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err)); + } + + auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); }; + + return torch::stable::from_blob(device_ptr, contiguous_cpu.sizes(), + contiguous_cpu.strides(), cuda_dev, + contiguous_cpu.scalar_type(), deleter); +} diff --git a/csrc/cutlass_extensions/common.cpp b/csrc/libtorch_stable/cutlass_extensions/common.cpp similarity index 90% rename from csrc/cutlass_extensions/common.cpp rename to csrc/libtorch_stable/cutlass_extensions/common.cpp index 3d2093ab942..5bc9463bfa6 100644 --- a/csrc/cutlass_extensions/common.cpp +++ b/csrc/libtorch_stable/cutlass_extensions/common.cpp @@ -1,4 +1,4 @@ -#include "cutlass_extensions/common.hpp" +#include "common.hpp" int32_t get_sm_version_num() { int32_t major_capability, minor_capability; diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/libtorch_stable/cutlass_extensions/common.hpp similarity index 100% rename from csrc/cutlass_extensions/common.hpp rename to csrc/libtorch_stable/cutlass_extensions/common.hpp diff --git a/csrc/libtorch_stable/ops.h b/csrc/libtorch_stable/ops.h index 0a991de76ff..536693fee96 100644 --- a/csrc/libtorch_stable/ops.h +++ b/csrc/libtorch_stable/ops.h @@ -164,6 +164,10 @@ torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel, #endif +// CPU tensor -> CUDA UVA view (shared CUDA/ROCm) +torch::stable::Tensor get_cuda_view_from_cpu_tensor( + torch::stable::Tensor& cpu_tensor); + // Attention kernels (shared CUDA/ROCm) void merge_attn_states( torch::stable::Tensor& output, @@ -215,6 +219,13 @@ void rms_norm_per_block_quant(torch::stable::Tensor& out, std::optional residual, int64_t group_size, bool is_scale_transposed); +void silu_and_mul_per_block_quant(torch::stable::Tensor& out, + torch::stable::Tensor const& input, + torch::stable::Tensor& scales, + int64_t group_size, + std::optional scale_ub, + bool is_scale_transposed); + // Positional encoding kernels (shared CUDA/ROCm) void rotary_embedding(torch::stable::Tensor& positions, torch::stable::Tensor& query, diff --git a/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu index 1091d9d1230..53ffe521363 100644 --- a/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu +++ b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu @@ -18,7 +18,7 @@ #include #include "libtorch_stable/torch_utils.h" #include "cutlass_extensions/torch_utils.hpp" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "get_group_starts.cuh" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" diff --git a/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu index c2b8c0c00de..502f430b30b 100644 --- a/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu +++ b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu @@ -21,7 +21,7 @@ #include "cutlass/util/packed_stride.hpp" #include "cutlass/util/mixed_dtype_utils.hpp" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include diff --git a/csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu b/csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu index 8a493fdf22c..04e98b6076a 100644 --- a/csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu +++ b/csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu @@ -12,7 +12,7 @@ #include -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "cute/tensor.hpp" #include "cutlass/tensor_ref.h" diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu index b22308d25ca..88caf03fda3 100644 --- a/csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -20,7 +20,7 @@ #include -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "cute/tensor.hpp" #include "cutlass/tensor_ref.h" diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu index 8d4ba1accc7..e1e7e7a74da 100644 --- a/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu @@ -18,7 +18,7 @@ #include "libtorch_stable/torch_utils.h" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "nvfp4_utils.cuh" #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu index d7b2a18e29c..bfb526fcd40 100644 --- a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu @@ -18,7 +18,7 @@ #include "libtorch_stable/torch_utils.h" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D, diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu index fc83c6e8d34..86355bf7060 100644 --- a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu @@ -18,7 +18,7 @@ #include "libtorch_stable/torch_utils.h" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "cutlass/cutlass.h" diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu index 2baa00caa82..7adba6308fa 100644 --- a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu @@ -18,7 +18,7 @@ #include "libtorch_stable/torch_utils.h" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "cutlass/cutlass.h" diff --git a/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu b/csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu similarity index 63% rename from csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu rename to csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu index d5c76232599..b32a7bd271f 100644 --- a/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu +++ b/csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu @@ -1,11 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright contributors to the vLLM project -#include -#include +#include "../../torch_utils.h" #include "../../dispatch_utils.h" -#include "libtorch_stable/quantization/fused_kernels/quant_conversions.cuh" +#include "quant_conversions.cuh" namespace vllm { @@ -105,64 +104,70 @@ __global__ void silu_and_mul_per_block_quant_kernel( } // namespace vllm -void silu_and_mul_per_block_quant(torch::Tensor& out, - torch::Tensor const& input, - torch::Tensor& scales, int64_t group_size, - std::optional scale_ub, +void silu_and_mul_per_block_quant(torch::stable::Tensor& out, + torch::stable::Tensor const& input, + torch::stable::Tensor& scales, + int64_t group_size, + std::optional scale_ub, bool is_scale_transposed) { - static c10::ScalarType kFp8Type = is_fp8_ocp() - ? c10::ScalarType::Float8_e4m3fn - : c10::ScalarType::Float8_e4m3fnuz; + static torch::headeronly::ScalarType kFp8Type = + is_fp8_ocp() ? torch::headeronly::ScalarType::Float8_e4m3fn + : torch::headeronly::ScalarType::Float8_e4m3fnuz; - TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); - TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); - TORCH_CHECK( - input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16, + STD_TORCH_CHECK(out.scalar_type() == kFp8Type || + out.scalar_type() == torch::headeronly::ScalarType::Char); + STD_TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); + STD_TORCH_CHECK( + input.scalar_type() == torch::headeronly::ScalarType::Half || + input.scalar_type() == torch::headeronly::ScalarType::BFloat16, "Input must be FP16 or BF16"); - TORCH_CHECK(scales.dtype() == torch::kFloat32, "Scales must be FP32"); - TORCH_CHECK(group_size == 128 || group_size == 64, - "Unsupported group size: ", group_size); + STD_TORCH_CHECK(scales.scalar_type() == torch::headeronly::ScalarType::Float); + STD_TORCH_CHECK(group_size == 128 || group_size == 64, + "Unsupported group size: ", group_size); if (scale_ub.has_value()) { - TORCH_CHECK(out.dtype() == kFp8Type); + STD_TORCH_CHECK(out.scalar_type() == kFp8Type); } int32_t hidden_size = out.size(-1); auto num_tokens = input.size(0); int32_t num_groups = hidden_size / group_size; - TORCH_CHECK(input.size(-1) == hidden_size * 2, - "input last dim must be 2x output hidden_size"); - TORCH_CHECK(hidden_size % group_size == 0, - "hidden_size must be divisible by group_size"); + STD_TORCH_CHECK(input.size(-1) == hidden_size * 2, + "input last dim must be 2x output hidden_size"); + STD_TORCH_CHECK(hidden_size % group_size == 0, + "hidden_size must be divisible by group_size"); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(input.get_device_index()); dim3 grid(num_tokens, num_groups); dim3 block(group_size); - VLLM_DISPATCH_FLOATING_TYPES( + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "silu_and_mul_per_block_quant", [&] { using scalar_in_t = scalar_t; - VLLM_DISPATCH_QUANT_TYPES( + VLLM_STABLE_DISPATCH_QUANT_TYPES( out.scalar_type(), "silu_and_mul_per_block_quant", [&] { using scalar_out_t = scalar_t; - VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] { - VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] { - vllm::silu_and_mul_per_block_quant_kernel< - scalar_in_t, scalar_out_t, transpose_scale, gs> - <<>>( - out.data_ptr(), - scales.data_ptr(), - input.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() - : nullptr, - hidden_size); - }); + VLLM_STABLE_DISPATCH_GROUP_SIZE(group_size, gs, [&] { + VLLM_STABLE_DISPATCH_BOOL( + is_scale_transposed, transpose_scale, [&] { + vllm::silu_and_mul_per_block_quant_kernel< + scalar_in_t, scalar_out_t, transpose_scale, gs> + <<>>( + out.mutable_data_ptr(), + scales.mutable_data_ptr(), + input.const_data_ptr(), + scale_ub.has_value() + ? scale_ub->const_data_ptr() + : nullptr, + hidden_size); + }); }); }); }); -} \ No newline at end of file +} diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh index ae40c0989e0..1eed7579924 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh @@ -20,7 +20,7 @@ #include "cutlass/util/packed_stride.hpp" #include "core/math.hpp" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" // clang-format on namespace vllm::c3x { diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh index 952931103c6..4cb591be056 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh @@ -15,7 +15,7 @@ #include "cutlass/gemm/collective/collective_builder.hpp" #include "core/math.hpp" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" // clang-format on /* diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp index adb3de50fc1..913436186c3 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp @@ -1,7 +1,7 @@ #include #include #include "cuda_utils.h" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" template void dispatch_scaled_mm(torch::stable::Tensor& c, diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh index 49df3fa4e7f..b523d7baeaa 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh @@ -8,7 +8,7 @@ #include #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "get_group_starts.cuh" using namespace cute; diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cuh index 6eb2c051d00..7846e609fe7 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cuh @@ -23,7 +23,7 @@ #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" #include "core/math.hpp" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" // clang-format on using namespace cute; diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu index 2e5bbca4700..0f9873cbf88 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu @@ -4,7 +4,7 @@ #include "libtorch_stable/torch_utils.h" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" void cutlass_scaled_mm_sm75(torch::stable::Tensor& c, torch::stable::Tensor const& a, diff --git a/csrc/libtorch_stable/torch_bindings.cpp b/csrc/libtorch_stable/torch_bindings.cpp index 511a788eeae..d376fd43aa7 100644 --- a/csrc/libtorch_stable/torch_bindings.cpp +++ b/csrc/libtorch_stable/torch_bindings.cpp @@ -1,4 +1,5 @@ #include "ops.h" +#include "cuda_utils.h" #include "core/registration.h" #include @@ -27,6 +28,8 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) { "output_s, int group_size, float eps, float int8_min, float int8_max) -> " "()"); + ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor"); + #ifndef USE_ROCM ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); #endif @@ -321,6 +324,16 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) { "Tensor? scale_ub, Tensor!? residual, int group_size, " "bool is_scale_transposed) -> ()"); + // Fused SiLU+Mul + per-block quantization + ops.def( + "silu_and_mul_per_block_quant(" + "Tensor! out, " + "Tensor input, " + "Tensor! scales, " + "int group_size, " + "Tensor? scale_ub=None, " + "bool is_scale_transposed=False) -> ()"); + // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( @@ -599,6 +612,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { ops.impl("rms_norm_dynamic_per_token_quant", TORCH_BOX(&rms_norm_dynamic_per_token_quant)); ops.impl("rms_norm_per_block_quant", TORCH_BOX(&rms_norm_per_block_quant)); + ops.impl("silu_and_mul_per_block_quant", + TORCH_BOX(&silu_and_mul_per_block_quant)); // Positional encoding kernels (shared CUDA/ROCm) ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding)); @@ -661,6 +676,11 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { ops.impl("paged_attention_v2", TORCH_BOX(&paged_attention_v2)); } +STABLE_TORCH_LIBRARY_IMPL(_C, CPU, ops) { + ops.impl("get_cuda_view_from_cpu_tensor", + TORCH_BOX(&get_cuda_view_from_cpu_tensor)); +} + // These capability-check functions take only primitive args (no tensors), so // there is no device to dispatch on. CompositeExplicitAutograd makes them // available for all backends. This is the stable ABI equivalent of calling @@ -681,6 +701,19 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) { ops.impl("ggml_moe_get_block_size", TORCH_BOX(&ggml_moe_get_block_size)); } +STABLE_TORCH_LIBRARY_FRAGMENT(_C_cuda_utils, cuda_utils) { + cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int"); + cuda_utils.def( + "get_max_shared_memory_per_block_device_attribute(int device_id) -> int"); +} + +STABLE_TORCH_LIBRARY_IMPL(_C_cuda_utils, CompositeExplicitAutograd, + cuda_utils) { + cuda_utils.impl("get_device_attribute", TORCH_BOX(&get_device_attribute)); + cuda_utils.impl("get_max_shared_memory_per_block_device_attribute", + TORCH_BOX(&get_max_shared_memory_per_block_device_attribute)); +} + // Cache ops STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) { // Swap in (out) the cache blocks from src to dst. diff --git a/csrc/ops.h b/csrc/ops.h index ed2fca26b0d..11b704aff29 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -40,12 +40,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); -void silu_and_mul_per_block_quant(torch::Tensor& out, - torch::Tensor const& input, - torch::Tensor& scales, int64_t group_size, - std::optional scale_ub, - bool is_scale_transposed); - // rotary_embedding also exist in csrc/libtorch_stable/ops.h (torch::stable // ABI for CUDA). It remains here because the CPU build still uses these // torch::Tensor declarations. @@ -84,8 +78,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, torch::Tensor const& seq_lens, torch::Tensor const& page_table, double scale); -torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor); - void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale, std::optional const& azp); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 3351638f574..f29b941affd 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -2,7 +2,6 @@ // cache.h, which is no longer included here after cache ops moved to // _C_stable_libtorch). #include -#include "cuda_utils.h" #include "ops.h" #include "core/registration.h" #include @@ -32,27 +31,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor); - ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor"); - ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU, - &get_cuda_view_from_cpu_tensor); - // Activation ops (quantized only — basic ops moved to _C_stable_libtorch) ops.def( "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); - // Fused SiLU+Mul + per-block quantization - ops.def( - "silu_and_mul_per_block_quant(" - "Tensor! out, " - "Tensor input, " - "Tensor! scales, " - "int group_size, " - "Tensor? scale_ub=None, " - "bool is_scale_transposed=False) -> ()"); - ops.impl("silu_and_mul_per_block_quant", torch::kCUDA, - &silu_and_mul_per_block_quant); - // Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and // GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one // kernel launch. Registered in _C_stable_libtorch (incl. the FlashInfer V4 @@ -178,18 +161,4 @@ TORCH_LIBRARY_FRAGMENT(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { } #endif -TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { - // Cuda utils - - // Gets the specified device attribute. - cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int"); - cuda_utils.impl("get_device_attribute", &get_device_attribute); - - // Gets the maximum shared memory per block device attribute. - cuda_utils.def( - "get_max_shared_memory_per_block_device_attribute(int device_id) -> int"); - cuda_utils.impl("get_max_shared_memory_per_block_device_attribute", - &get_max_shared_memory_per_block_device_attribute); -} - REGISTER_EXTENSION(TORCH_EXTENSION_NAME)