mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Bugfix] [ROCm] [Critical] fallback to regular abi for ROCm (#44648)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
+16
-8
@@ -590,8 +590,11 @@ endif()
|
|||||||
|
|
||||||
if (VLLM_GPU_LANG STREQUAL "HIP")
|
if (VLLM_GPU_LANG STREQUAL "HIP")
|
||||||
# Add QuickReduce kernels (ROCm-only; not part of stable ABI migration).
|
# Add QuickReduce kernels (ROCm-only; not part of stable ABI migration).
|
||||||
|
# TODO: Remove csrc/cuda_view.cu from this list once ROCm upgrades to torch 2.11.
|
||||||
list(APPEND VLLM_EXT_SRC
|
list(APPEND VLLM_EXT_SRC
|
||||||
"csrc/custom_quickreduce.cu"
|
"csrc/custom_quickreduce.cu"
|
||||||
|
"csrc/cuda_view.cu"
|
||||||
|
"csrc/libtorch_stable/cuda_utils_kernels.cu"
|
||||||
)
|
)
|
||||||
# if ROCM endif
|
# if ROCM endif
|
||||||
endif()
|
endif()
|
||||||
@@ -621,7 +624,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
|||||||
#
|
#
|
||||||
set(VLLM_STABLE_EXT_SRC
|
set(VLLM_STABLE_EXT_SRC
|
||||||
"csrc/libtorch_stable/torch_bindings.cpp"
|
"csrc/libtorch_stable/torch_bindings.cpp"
|
||||||
"csrc/libtorch_stable/cuda_view.cu"
|
|
||||||
"csrc/libtorch_stable/activation_kernels.cu"
|
"csrc/libtorch_stable/activation_kernels.cu"
|
||||||
"csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu"
|
"csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu"
|
||||||
"csrc/libtorch_stable/quantization/w8a8/fp8/common.cu"
|
"csrc/libtorch_stable/quantization/w8a8/fp8/common.cu"
|
||||||
@@ -649,6 +651,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
|||||||
|
|
||||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
list(APPEND VLLM_STABLE_EXT_SRC
|
list(APPEND VLLM_STABLE_EXT_SRC
|
||||||
|
"csrc/libtorch_stable/cuda_view.cu"
|
||||||
"csrc/libtorch_stable/cuda_utils_kernels.cu"
|
"csrc/libtorch_stable/cuda_utils_kernels.cu"
|
||||||
"csrc/libtorch_stable/cutlass_extensions/common.cpp"
|
"csrc/libtorch_stable/cutlass_extensions/common.cpp"
|
||||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
|
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
|
||||||
@@ -1071,20 +1074,25 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
|||||||
USE_SABI 3
|
USE_SABI 3
|
||||||
WITH_SOABI)
|
WITH_SOABI)
|
||||||
|
|
||||||
# Set TORCH_TARGET_VERSION for stable ABI compatibility.
|
|
||||||
# This ensures we only use C-shim APIs available in PyTorch 2.11.
|
|
||||||
# _C_stable_libtorch is abi compatible with PyTorch >= TORCH_TARGET_VERSION
|
|
||||||
# which is currently set to 2.11.
|
|
||||||
target_compile_definitions(_C_stable_libtorch PRIVATE
|
|
||||||
TORCH_TARGET_VERSION=0x020B000000000000ULL)
|
|
||||||
|
|
||||||
# Needed to use cuda/hip APIs from C-shim
|
# Needed to use cuda/hip APIs from C-shim
|
||||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
|
# Set TORCH_TARGET_VERSION for stable ABI compatibility.
|
||||||
|
# This ensures we only use C-shim APIs available in PyTorch 2.11.
|
||||||
|
# _C_stable_libtorch is abi compatible with PyTorch >= TORCH_TARGET_VERSION
|
||||||
|
# which is currently set to 2.11.
|
||||||
|
target_compile_definitions(_C_stable_libtorch PRIVATE
|
||||||
|
TORCH_TARGET_VERSION=0x020B000000000000ULL)
|
||||||
target_compile_definitions(_C_stable_libtorch PRIVATE USE_CUDA)
|
target_compile_definitions(_C_stable_libtorch PRIVATE USE_CUDA)
|
||||||
# Needed by CUTLASS kernels
|
# Needed by CUTLASS kernels
|
||||||
target_compile_definitions(_C_stable_libtorch PRIVATE
|
target_compile_definitions(_C_stable_libtorch PRIVATE
|
||||||
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
|
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
|
||||||
elseif(VLLM_GPU_LANG STREQUAL "HIP")
|
elseif(VLLM_GPU_LANG STREQUAL "HIP")
|
||||||
|
# Set TORCH_TARGET_VERSION for stable ABI compatibility.
|
||||||
|
# This ensures we only use C-shim APIs available in PyTorch 2.10.
|
||||||
|
# _C_stable_libtorch is abi compatible with PyTorch >= TORCH_TARGET_VERSION
|
||||||
|
# which is currently set to 2.10.
|
||||||
|
target_compile_definitions(_C_stable_libtorch PRIVATE
|
||||||
|
TORCH_TARGET_VERSION=0x020A000000000000ULL)
|
||||||
target_compile_definitions(_C_stable_libtorch PRIVATE USE_ROCM)
|
target_compile_definitions(_C_stable_libtorch PRIVATE USE_ROCM)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,60 @@
|
|||||||
|
// TODO: Remove this file once ROCm upgrades to torch 2.11.
|
||||||
|
#include <torch/all.h>
|
||||||
|
#include <torch/cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
// This function assumes that `cpu_tensor` is a CPU tensor,
|
||||||
|
// and that UVA (Unified Virtual Addressing) is enabled.
|
||||||
|
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
|
||||||
|
TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
|
||||||
|
|
||||||
|
// handle empty tensor
|
||||||
|
if (cpu_tensor.numel() == 0) {
|
||||||
|
return torch::empty(cpu_tensor.sizes(),
|
||||||
|
cpu_tensor.options().device(torch::kCUDA));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cpu_tensor.is_pinned()) {
|
||||||
|
// If CPU tensor is pinned, directly get the device pointer.
|
||||||
|
void* host_ptr = const_cast<void*>(cpu_tensor.data_ptr());
|
||||||
|
void* device_ptr = nullptr;
|
||||||
|
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
|
||||||
|
TORCH_CHECK(err == cudaSuccess,
|
||||||
|
"cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
|
||||||
|
|
||||||
|
return torch::from_blob(
|
||||||
|
device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(),
|
||||||
|
[base = cpu_tensor](void*) {}, // keep cpu tensor alive
|
||||||
|
cpu_tensor.options().device(torch::kCUDA));
|
||||||
|
}
|
||||||
|
|
||||||
|
// If CPU tensor is not pinned, allocate a new pinned memory buffer.
|
||||||
|
torch::Tensor contiguous_cpu = cpu_tensor.contiguous();
|
||||||
|
size_t nbytes = contiguous_cpu.nbytes();
|
||||||
|
|
||||||
|
void* host_ptr = nullptr;
|
||||||
|
cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped);
|
||||||
|
if (err != cudaSuccess) {
|
||||||
|
AT_ERROR("cudaHostAlloc failed: ", cudaGetErrorString(err));
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cudaMemcpy(host_ptr, contiguous_cpu.data_ptr(), nbytes,
|
||||||
|
cudaMemcpyDefault);
|
||||||
|
if (err != cudaSuccess) {
|
||||||
|
cudaFreeHost(host_ptr);
|
||||||
|
AT_ERROR("cudaMemcpy failed: ", cudaGetErrorString(err));
|
||||||
|
}
|
||||||
|
|
||||||
|
void* device_ptr = nullptr;
|
||||||
|
err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
|
||||||
|
if (err != cudaSuccess) {
|
||||||
|
cudaFreeHost(host_ptr);
|
||||||
|
AT_ERROR("cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); };
|
||||||
|
|
||||||
|
return torch::from_blob(device_ptr, contiguous_cpu.sizes(),
|
||||||
|
contiguous_cpu.strides(), deleter,
|
||||||
|
contiguous_cpu.options().device(torch::kCUDA));
|
||||||
|
}
|
||||||
@@ -162,12 +162,13 @@ torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel,
|
|||||||
// AllSpark ops: declarations are in the source files
|
// AllSpark ops: declarations are in the source files
|
||||||
// (allspark_repack.cu and allspark_qgemm_w8a16.cu)
|
// (allspark_repack.cu and allspark_qgemm_w8a16.cu)
|
||||||
|
|
||||||
#endif
|
// TODO: Move this out once ROCm upgrades its torch to 2.11.
|
||||||
|
// CPU tensor -> CUDA UVA view (shared CUDA)
|
||||||
// CPU tensor -> CUDA UVA view (shared CUDA/ROCm)
|
|
||||||
torch::stable::Tensor get_cuda_view_from_cpu_tensor(
|
torch::stable::Tensor get_cuda_view_from_cpu_tensor(
|
||||||
torch::stable::Tensor& cpu_tensor);
|
torch::stable::Tensor& cpu_tensor);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
// Attention kernels (shared CUDA/ROCm)
|
// Attention kernels (shared CUDA/ROCm)
|
||||||
void merge_attn_states(
|
void merge_attn_states(
|
||||||
torch::stable::Tensor& output,
|
torch::stable::Tensor& output,
|
||||||
|
|||||||
@@ -28,9 +28,11 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
|||||||
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
|
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
|
||||||
"()");
|
"()");
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
|
||||||
|
// TODO: Remove this once ROCm upgrade to torch 2.11.
|
||||||
ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
|
ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -676,11 +678,28 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
|||||||
ops.impl("paged_attention_v2", TORCH_BOX(&paged_attention_v2));
|
ops.impl("paged_attention_v2", TORCH_BOX(&paged_attention_v2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Remove this once ROCm upgrade to torch 2.11.
|
||||||
|
#ifndef USE_ROCM
|
||||||
STABLE_TORCH_LIBRARY_IMPL(_C, CPU, ops) {
|
STABLE_TORCH_LIBRARY_IMPL(_C, CPU, ops) {
|
||||||
ops.impl("get_cuda_view_from_cpu_tensor",
|
ops.impl("get_cuda_view_from_cpu_tensor",
|
||||||
TORCH_BOX(&get_cuda_view_from_cpu_tensor));
|
TORCH_BOX(&get_cuda_view_from_cpu_tensor));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Schema declarations for the `_C_cuda_utils` library fragment.
// Both ops take only integer arguments and return an integer.
STABLE_TORCH_LIBRARY_FRAGMENT(_C_cuda_utils, cuda_utils) {
  // Maximum shared memory per block for the given device.
  cuda_utils.def(
      "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");

  // Generic device-attribute query for the given device.
  cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
}
|
||||||
|
|
||||||
|
// Implementations for the `_C_cuda_utils` ops, registered under the
// CompositeExplicitAutograd dispatch key.
STABLE_TORCH_LIBRARY_IMPL(_C_cuda_utils, CompositeExplicitAutograd,
                          cuda_utils) {
  // Maximum shared memory per block for the given device.
  cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
                  TORCH_BOX(&get_max_shared_memory_per_block_device_attribute));

  // Generic device-attribute query for the given device.
  cuda_utils.impl("get_device_attribute", TORCH_BOX(&get_device_attribute));
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
// These capability-check functions take only primitive args (no tensors), so
|
// These capability-check functions take only primitive args (no tensors), so
|
||||||
// there is no device to dispatch on. CompositeExplicitAutograd makes them
|
// there is no device to dispatch on. CompositeExplicitAutograd makes them
|
||||||
// available for all backends. This is the stable ABI equivalent of calling
|
// available for all backends. This is the stable ABI equivalent of calling
|
||||||
@@ -701,19 +720,6 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
|
|||||||
ops.impl("ggml_moe_get_block_size", TORCH_BOX(&ggml_moe_get_block_size));
|
ops.impl("ggml_moe_get_block_size", TORCH_BOX(&ggml_moe_get_block_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
STABLE_TORCH_LIBRARY_FRAGMENT(_C_cuda_utils, cuda_utils) {
|
|
||||||
cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
|
|
||||||
cuda_utils.def(
|
|
||||||
"get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
|
|
||||||
}
|
|
||||||
|
|
||||||
STABLE_TORCH_LIBRARY_IMPL(_C_cuda_utils, CompositeExplicitAutograd,
|
|
||||||
cuda_utils) {
|
|
||||||
cuda_utils.impl("get_device_attribute", TORCH_BOX(&get_device_attribute));
|
|
||||||
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
|
|
||||||
TORCH_BOX(&get_max_shared_memory_per_block_device_attribute));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache ops
|
// Cache ops
|
||||||
STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) {
|
STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) {
|
||||||
// Swap in (out) the cache blocks from src to dst.
|
// Swap in (out) the cache blocks from src to dst.
|
||||||
|
|||||||
@@ -102,4 +102,7 @@ void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
|
|||||||
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||||
int64_t quant_level, bool cast_bf2half = false);
|
int64_t quant_level, bool cast_bf2half = false);
|
||||||
int64_t qr_max_size();
|
int64_t qr_max_size();
|
||||||
|
|
||||||
|
// TODO: Remove this declaration once ROCm upgrades to torch 2.11.
|
||||||
|
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
// cache.h, which is no longer included here after cache ops moved to
|
// cache.h, which is no longer included here after cache ops moved to
|
||||||
// _C_stable_libtorch).
|
// _C_stable_libtorch).
|
||||||
#include <torch/all.h>
|
#include <torch/all.h>
|
||||||
|
#include "cuda_utils.h"
|
||||||
#include "ops.h"
|
#include "ops.h"
|
||||||
#include "core/registration.h"
|
#include "core/registration.h"
|
||||||
#include <torch/library.h>
|
#include <torch/library.h>
|
||||||
@@ -31,6 +32,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
||||||
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
|
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
|
||||||
|
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
// TODO: Remove this once we upgrade to torch 2.11.
|
||||||
|
// ROCm still uses torch 2.10,
|
||||||
|
// So we still need to use unstable torch ABI for now.
|
||||||
|
ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
|
||||||
|
ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
|
||||||
|
&get_cuda_view_from_cpu_tensor);
|
||||||
|
#endif
|
||||||
|
|
||||||
// Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
|
// Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
|
||||||
ops.def(
|
ops.def(
|
||||||
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
||||||
@@ -159,6 +169,20 @@ TORCH_LIBRARY_FRAGMENT(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
|||||||
|
|
||||||
custom_ar.def("qr_max_size", &qr_max_size);
|
custom_ar.def("qr_max_size", &qr_max_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Remove this once ROCm upgrade to torch 2.11.
|
||||||
|
// Registers the CUDA utility ops in the `<extension>_cuda_utils` library:
// schemas are declared first, then bound to their implementations.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
  // --- Schemas ---
  // Gets the specified device attribute.
  cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
  // Gets the maximum shared memory per block device attribute.
  cuda_utils.def(
      "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");

  // --- Implementations ---
  cuda_utils.impl("get_device_attribute", &get_device_attribute);
  cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
                  &get_max_shared_memory_per_block_device_attribute);
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||||
|
|||||||
Reference in New Issue
Block a user