mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[10b/n] Migrate custom all-reduce, DeepSeek V4 fused MLA, MiniMax reduce-RMS, and MXFP8 MoE to libtorch stable ABI (#44365)
Signed-off-by: Chris Leonard <chleonar@redhat.com> Signed-off-by: Shengqi Chen <harry-chen@outlook.com> Co-authored-by: Shengqi Chen <harry-chen@outlook.com>
This commit is contained in:
+13
-12
@@ -311,14 +311,9 @@ set(VLLM_EXT_SRC
|
||||
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
|
||||
"csrc/quantization/activation_kernels.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
"csrc/torch_bindings.cpp"
|
||||
"csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")
|
||||
"csrc/torch_bindings.cpp")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/minimax_reduce_rms_kernel.cu")
|
||||
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
|
||||
|
||||
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
|
||||
@@ -505,12 +500,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND ES_MXFP8_GROUPED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu"
|
||||
"csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu")
|
||||
"csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu"
|
||||
"csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${ES_MXFP8_GROUPED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_ES_MXFP8_GROUPED_MM_SM100=1")
|
||||
message(STATUS "Building ES MXFP8 grouped kernels for archs: ${ES_MXFP8_GROUPED_MM_ARCHS}")
|
||||
else()
|
||||
@@ -600,7 +595,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
if (VLLM_GPU_LANG STREQUAL "HIP")
|
||||
# Add QuickReduce kernels
|
||||
# Add QuickReduce kernels (ROCm-only; not part of stable ABI migration).
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/custom_quickreduce.cu"
|
||||
)
|
||||
@@ -651,7 +646,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
"csrc/libtorch_stable/attention/paged_attention_v1.cu"
|
||||
"csrc/libtorch_stable/attention/paged_attention_v2.cu"
|
||||
"csrc/libtorch_stable/cache_kernels.cu"
|
||||
"csrc/libtorch_stable/cache_kernels_fused.cu")
|
||||
"csrc/libtorch_stable/cache_kernels.cu"
|
||||
"csrc/libtorch_stable/cache_kernels_fused.cu"
|
||||
"csrc/libtorch_stable/custom_all_reduce.cu"
|
||||
"csrc/libtorch_stable/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC
|
||||
@@ -661,7 +659,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||
"csrc/libtorch_stable/permute_cols.cu"
|
||||
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu")
|
||||
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/libtorch_stable/minimax_reduce_rms_kernel.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_STABLE_EXT_SRC}"
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
#include "torch_utils.h"
|
||||
|
||||
#include <torch/csrc/stable/macros.h>
|
||||
#include <torch/csrc/stable/accelerator.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/csrc/stable/device.h>
|
||||
|
||||
#include "custom_all_reduce.cuh"
|
||||
|
||||
@@ -11,7 +15,7 @@ using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
|
||||
torch::Tensor& rank_data, int64_t rank,
|
||||
torch::stable::Tensor& rank_data, int64_t rank,
|
||||
bool fully_connected) {
|
||||
int world_size = fake_ipc_ptrs.size();
|
||||
if (world_size > 8)
|
||||
@@ -25,9 +29,9 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
|
||||
}
|
||||
return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
|
||||
rank_data.numel(), rank, world_size,
|
||||
fully_connected);
|
||||
return (fptr_t) new vllm::CustomAllreduce(
|
||||
ipc_ptrs, rank_data.mutable_data_ptr(), rank_data.numel(), rank,
|
||||
world_size, fully_connected);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -46,10 +50,14 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
|
||||
* 5. A[None].expand(2, -1, -1, -1): Not OK
|
||||
* 6. A[:, 1:, 1:]: Not OK
|
||||
*/
|
||||
bool _is_weak_contiguous(torch::Tensor& t) {
|
||||
return t.is_contiguous() ||
|
||||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
|
||||
t.numel() * t.element_size());
|
||||
bool _is_weak_contiguous(torch::stable::Tensor& t) {
|
||||
if (t.is_contiguous()) {
|
||||
return true;
|
||||
}
|
||||
int64_t storage_nbytes = 0;
|
||||
TORCH_ERROR_CODE_CHECK(aoti_torch_get_storage_size(t.get(), &storage_nbytes));
|
||||
return storage_nbytes - t.storage_offset() * t.element_size() ==
|
||||
static_cast<int64_t>(t.numel() * t.element_size());
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -59,42 +67,45 @@ bool _is_weak_contiguous(torch::Tensor& t) {
|
||||
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
|
||||
* copied into _reg_buffer.
|
||||
*/
|
||||
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
|
||||
void all_reduce(fptr_t _fa, torch::stable::Tensor& inp,
|
||||
torch::stable::Tensor& out, fptr_t _reg_buffer,
|
||||
int64_t reg_buffer_sz_bytes) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
inp.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream(inp.get_device_index());
|
||||
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK(_is_weak_contiguous(out));
|
||||
TORCH_CHECK(_is_weak_contiguous(inp));
|
||||
STD_TORCH_CHECK((inp.scalar_type()) == (out.scalar_type()));
|
||||
STD_TORCH_CHECK((inp.numel()) == (out.numel()));
|
||||
STD_TORCH_CHECK(_is_weak_contiguous(out));
|
||||
STD_TORCH_CHECK(_is_weak_contiguous(inp));
|
||||
auto input_size = inp.numel() * inp.element_size();
|
||||
auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
|
||||
if (reg_buffer) {
|
||||
TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
|
||||
AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
|
||||
cudaMemcpyDeviceToDevice, stream));
|
||||
STD_TORCH_CHECK((input_size) <= (reg_buffer_sz_bytes));
|
||||
STD_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.const_data_ptr(), input_size,
|
||||
cudaMemcpyDeviceToDevice, stream));
|
||||
} else {
|
||||
reg_buffer = inp.data_ptr();
|
||||
reg_buffer = inp.mutable_data_ptr();
|
||||
}
|
||||
switch (out.scalar_type()) {
|
||||
case at::ScalarType::Float: {
|
||||
case torch::headeronly::ScalarType::Float: {
|
||||
fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
|
||||
reinterpret_cast<float*>(out.data_ptr()),
|
||||
reinterpret_cast<float*>(out.mutable_data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
case at::ScalarType::Half: {
|
||||
case torch::headeronly::ScalarType::Half: {
|
||||
fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
|
||||
reinterpret_cast<half*>(out.data_ptr()), out.numel());
|
||||
reinterpret_cast<half*>(out.mutable_data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
case at::ScalarType::BFloat16: {
|
||||
case torch::headeronly::ScalarType::BFloat16: {
|
||||
fa->allreduce<nv_bfloat16>(
|
||||
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
|
||||
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
|
||||
reinterpret_cast<nv_bfloat16*>(out.mutable_data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
@@ -112,7 +123,7 @@ int64_t meta_size() { return sizeof(vllm::Signal); }
|
||||
|
||||
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
|
||||
STD_TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
|
||||
void* ipc_ptrs[8];
|
||||
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
|
||||
ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
|
||||
@@ -143,47 +154,49 @@ void register_graph_buffers(fptr_t _fa,
|
||||
fa->register_graph_buffers(bytes, offsets);
|
||||
}
|
||||
|
||||
std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
|
||||
std::tuple<fptr_t, torch::stable::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size) {
|
||||
auto device_index = c10::cuda::current_device();
|
||||
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
|
||||
int device_index;
|
||||
STD_CUDA_CHECK(cudaGetDevice(&device_index));
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(device_index);
|
||||
void* buffer;
|
||||
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
const cudaStream_t stream = get_current_cuda_stream(device_index);
|
||||
STD_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
|
||||
// Allocate buffer
|
||||
#if defined(USE_ROCM)
|
||||
// data buffers need to be "uncached" for signal on MI200
|
||||
AT_CUDA_CHECK(
|
||||
STD_CUDA_CHECK(
|
||||
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
|
||||
#else
|
||||
AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
|
||||
STD_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
|
||||
#endif
|
||||
AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
|
||||
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
STD_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
|
||||
STD_CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
STD_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
|
||||
// Create IPC memhandle for the allocated buffer.
|
||||
// Will use it in open_mem_handle.
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||
auto handle =
|
||||
torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
|
||||
AT_CUDA_CHECK(
|
||||
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer));
|
||||
auto handle = torch::stable::empty(
|
||||
{static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))},
|
||||
torch::headeronly::ScalarType::Byte, std::nullopt,
|
||||
torch::stable::Device(torch::stable::DeviceType::CPU));
|
||||
STD_CUDA_CHECK(cudaIpcGetMemHandle(
|
||||
(cudaIpcMemHandle_t*)handle.mutable_data_ptr(), buffer));
|
||||
|
||||
return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
|
||||
}
|
||||
|
||||
fptr_t open_mem_handle(torch::Tensor& mem_handle) {
|
||||
fptr_t open_mem_handle(torch::stable::Tensor& mem_handle) {
|
||||
void* ipc_ptr;
|
||||
AT_CUDA_CHECK(cudaIpcOpenMemHandle(
|
||||
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()),
|
||||
STD_CUDA_CHECK(cudaIpcOpenMemHandle(
|
||||
(void**)&ipc_ptr,
|
||||
*((const cudaIpcMemHandle_t*)mem_handle.const_data_ptr()),
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
return reinterpret_cast<fptr_t>(ipc_ptr);
|
||||
}
|
||||
|
||||
void free_shared_buffer(fptr_t buffer) {
|
||||
AT_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
|
||||
STD_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
|
||||
}
|
||||
+70
-56
@@ -28,7 +28,20 @@
|
||||
* [bs*576, bs*576 + bs*8): UE8M0 scales, 7 real + 1 pad per token
|
||||
*/
|
||||
|
||||
#include "torch_utils.h"
|
||||
|
||||
#include <torch/csrc/stable/macros.h>
|
||||
#include <torch/csrc/stable/accelerator.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/csrc/stable/device.h>
|
||||
|
||||
#include <cmath>
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#include "type_convert.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_fp8.h>
|
||||
#else
|
||||
@@ -37,14 +50,6 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/cuda.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#include "type_convert.cuh"
|
||||
|
||||
#ifndef FINAL_MASK
|
||||
#ifdef USE_ROCM
|
||||
#define FINAL_MASK 0xffffffffffffffffULL
|
||||
@@ -70,7 +75,7 @@ namespace deepseek_v4_fused_ops {
|
||||
|
||||
namespace {
|
||||
inline int getSMVersion() {
|
||||
auto* props = at::cuda::getCurrentDeviceProperties();
|
||||
auto* props = get_device_prop();
|
||||
return props->major * 10 + props->minor;
|
||||
}
|
||||
} // namespace
|
||||
@@ -564,7 +569,7 @@ static void launchFusedDeepseekV4Templated(
|
||||
// bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert<BFloat16> is
|
||||
// unavailable there. Refuse the launch loudly instead of silently
|
||||
// skipping the work.
|
||||
TORCH_CHECK(
|
||||
STD_TORCH_CHECK(
|
||||
sm_version >= 80,
|
||||
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert requires sm_80+ "
|
||||
"(Ampere or newer); got sm_",
|
||||
@@ -635,7 +640,7 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
|
||||
DISPATCH(64)
|
||||
DISPATCH(128)
|
||||
default:
|
||||
TORCH_CHECK(false,
|
||||
STD_TORCH_CHECK(false,
|
||||
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert: "
|
||||
"unsupported num_heads_q_padded=",
|
||||
num_heads_q_padded,
|
||||
@@ -650,71 +655,80 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// Torch op wrapper
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
||||
torch::Tensor const& q_in, // [N, num_heads_q, 512] bf16
|
||||
torch::Tensor const& kv, // [N, 512] bf16 (read-only)
|
||||
torch::Tensor& k_cache, // [num_blocks, block_bytes] uint8
|
||||
torch::Tensor const& slot_mapping, // [N] int64
|
||||
torch::Tensor const& position_ids, // [N] int64
|
||||
torch::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16
|
||||
int64_t q_head_padded, // padded Q head count for output
|
||||
torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
||||
torch::stable::Tensor const& q_in, // [N, num_heads_q, 512] bf16
|
||||
torch::stable::Tensor const& kv, // [N, 512] bf16 (read-only)
|
||||
torch::stable::Tensor& k_cache, // [num_blocks, block_bytes] uint8
|
||||
torch::stable::Tensor const& slot_mapping, // [N] int64
|
||||
torch::stable::Tensor const& position_ids, // [N] int64
|
||||
torch::stable::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16
|
||||
int64_t q_head_padded, // padded Q head count for output
|
||||
double eps, int64_t cache_block_size) {
|
||||
TORCH_CHECK(q_in.is_cuda() && q_in.is_contiguous(),
|
||||
"q_in must be contiguous CUDA");
|
||||
TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA");
|
||||
TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA");
|
||||
TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64,
|
||||
"slot_mapping must be int64 CUDA");
|
||||
TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64,
|
||||
"position_ids must be int64 CUDA");
|
||||
TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA");
|
||||
TORCH_CHECK(q_in.dim() == 3 && q_in.size(2) == 512,
|
||||
"q_in shape [N, num_heads_q, 512]");
|
||||
TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
|
||||
TORCH_CHECK(q_in.dtype() == kv.dtype(), "q_in and kv dtype must match");
|
||||
TORCH_CHECK(q_head_padded >= q_in.size(1),
|
||||
"q_head_padded must be >= q_in.size(1) (num_heads_q)");
|
||||
TORCH_CHECK(k_cache.dtype() == torch::kUInt8, "k_cache must be uint8");
|
||||
TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
|
||||
"cos_sin_cache shape [max_pos, 64]");
|
||||
TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32,
|
||||
"cos_sin_cache must be float32");
|
||||
STD_TORCH_CHECK(q_in.device().is_cuda() && q_in.is_contiguous(),
|
||||
"q_in must be contiguous CUDA");
|
||||
STD_TORCH_CHECK(kv.device().is_cuda() && kv.is_contiguous(),
|
||||
"kv must be contiguous CUDA");
|
||||
STD_TORCH_CHECK(k_cache.device().is_cuda(), "k_cache must be CUDA");
|
||||
STD_TORCH_CHECK(slot_mapping.device().is_cuda() &&
|
||||
slot_mapping.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Long,
|
||||
"slot_mapping must be int64 CUDA");
|
||||
STD_TORCH_CHECK(position_ids.device().is_cuda() &&
|
||||
position_ids.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Long,
|
||||
"position_ids must be int64 CUDA");
|
||||
STD_TORCH_CHECK(cos_sin_cache.device().is_cuda(), "cos_sin_cache must be CUDA");
|
||||
STD_TORCH_CHECK(q_in.dim() == 3 && q_in.size(2) == 512,
|
||||
"q_in shape [N, num_heads_q, 512]");
|
||||
STD_TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
|
||||
STD_TORCH_CHECK(q_in.scalar_type() == kv.scalar_type(),
|
||||
"q_in and kv dtype must match");
|
||||
STD_TORCH_CHECK(q_head_padded >= q_in.size(1),
|
||||
"q_head_padded must be >= q_in.size(1) (num_heads_q)");
|
||||
STD_TORCH_CHECK(k_cache.scalar_type() == torch::headeronly::ScalarType::Byte,
|
||||
"k_cache must be uint8");
|
||||
STD_TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
|
||||
"cos_sin_cache shape [max_pos, 64]");
|
||||
STD_TORCH_CHECK(cos_sin_cache.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float,
|
||||
"cos_sin_cache must be float32");
|
||||
|
||||
// With DP padding, slot_mapping can be shorter than q/kv/positions.
|
||||
// Q-norm+RoPE runs on all q.size(0) rows (downstream attention uses them);
|
||||
// KV quant+insert runs only on the first slot_mapping.size(0) rows.
|
||||
int const num_tokens_full = static_cast<int>(q_in.size(0));
|
||||
int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
|
||||
TORCH_CHECK(static_cast<int>(kv.size(0)) == num_tokens_full &&
|
||||
static_cast<int>(position_ids.size(0)) == num_tokens_full,
|
||||
"q/kv/position_ids row counts must match");
|
||||
TORCH_CHECK(num_tokens_insert <= num_tokens_full,
|
||||
"slot_mapping must not exceed q row count");
|
||||
STD_TORCH_CHECK(static_cast<int>(kv.size(0)) == num_tokens_full &&
|
||||
static_cast<int>(position_ids.size(0)) == num_tokens_full,
|
||||
"q/kv/position_ids row counts must match");
|
||||
STD_TORCH_CHECK(num_tokens_insert <= num_tokens_full,
|
||||
"slot_mapping must not exceed q row count");
|
||||
int const num_heads_q = static_cast<int>(q_in.size(1));
|
||||
int const num_heads_q_padded = static_cast<int>(q_head_padded);
|
||||
int const cache_block_size_i = static_cast<int>(cache_block_size);
|
||||
int const kv_block_stride = static_cast<int>(k_cache.stride(0));
|
||||
|
||||
at::cuda::OptionalCUDAGuard device_guard(device_of(q_in));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
q_in.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream(q_in.get_device_index());
|
||||
|
||||
// Allocate the padded q output. The kernel writes every element (live
|
||||
// region gets RMSNorm+RoPE; pad region gets zeros), so `empty` is safe.
|
||||
torch::Tensor q_out = torch::empty(
|
||||
{q_in.size(0), q_head_padded, q_in.size(2)}, q_in.options());
|
||||
auto q_out = torch::stable::new_empty(
|
||||
q_in, {q_in.size(0), q_head_padded, q_in.size(2)}, q_in.scalar_type());
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||
q_in.scalar_type(), "fused_deepseek_v4_qnorm_rope_kv_insert", [&] {
|
||||
using qkv_scalar_t = scalar_t;
|
||||
vllm::deepseek_v4_fused_ops::
|
||||
launchFusedDeepseekV4QNormRopeKVRopeQuantInsert<qkv_scalar_t>(
|
||||
reinterpret_cast<qkv_scalar_t const*>(q_in.data_ptr()),
|
||||
reinterpret_cast<qkv_scalar_t*>(q_out.data_ptr()),
|
||||
reinterpret_cast<qkv_scalar_t const*>(kv.data_ptr()),
|
||||
reinterpret_cast<uint8_t*>(k_cache.data_ptr()),
|
||||
reinterpret_cast<int64_t const*>(slot_mapping.data_ptr()),
|
||||
reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
|
||||
cos_sin_cache.data_ptr<float>(), static_cast<float>(eps),
|
||||
reinterpret_cast<qkv_scalar_t const*>(q_in.const_data_ptr()),
|
||||
reinterpret_cast<qkv_scalar_t*>(q_out.mutable_data_ptr()),
|
||||
reinterpret_cast<qkv_scalar_t const*>(kv.const_data_ptr()),
|
||||
reinterpret_cast<uint8_t*>(k_cache.mutable_data_ptr()),
|
||||
slot_mapping.const_data_ptr<int64_t>(),
|
||||
position_ids.const_data_ptr<int64_t>(),
|
||||
cos_sin_cache.const_data_ptr<float>(), static_cast<float>(eps),
|
||||
num_tokens_full, num_tokens_insert, num_heads_q,
|
||||
num_heads_q_padded, cache_block_size_i, kv_block_stride,
|
||||
stream);
|
||||
+71
-58
@@ -15,16 +15,19 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "torch_utils.h"
|
||||
|
||||
#include <torch/csrc/stable/macros.h>
|
||||
#include <torch/csrc/stable/accelerator.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/csrc/stable/device.h>
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <torch/cuda.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "cuda_utils.h"
|
||||
#include "core/registration.h"
|
||||
#include "minimax_reduce_rms_kernel.h"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -611,7 +614,7 @@ int get_sm_count() {
|
||||
static int sm_count = 0;
|
||||
if (sm_count == 0) {
|
||||
int device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&device_id));
|
||||
STD_CUDA_CHECK(cudaGetDevice(&device_id));
|
||||
cudaDeviceProp device_prop;
|
||||
cudaGetDeviceProperties(&device_prop, device_id);
|
||||
sm_count = device_prop.multiProcessorCount;
|
||||
@@ -621,13 +624,13 @@ int get_sm_count() {
|
||||
|
||||
inline int getSMVersion(bool queryRealSmArch = false) {
|
||||
int device{-1};
|
||||
CUDA_CHECK(cudaGetDevice(&device));
|
||||
STD_CUDA_CHECK(cudaGetDevice(&device));
|
||||
int sm_major = 0;
|
||||
int sm_minor = 0;
|
||||
CUDA_CHECK(cudaDeviceGetAttribute(&sm_major,
|
||||
cudaDevAttrComputeCapabilityMajor, device));
|
||||
CUDA_CHECK(cudaDeviceGetAttribute(&sm_minor,
|
||||
cudaDevAttrComputeCapabilityMinor, device));
|
||||
STD_CUDA_CHECK(cudaDeviceGetAttribute(
|
||||
&sm_major, cudaDevAttrComputeCapabilityMajor, device));
|
||||
STD_CUDA_CHECK(cudaDeviceGetAttribute(
|
||||
&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
|
||||
int sm = sm_major * 10 + sm_minor;
|
||||
if (sm == 121 && !queryRealSmArch) {
|
||||
return 120;
|
||||
@@ -639,7 +642,7 @@ template <typename KernelFunc>
|
||||
int get_max_active_blocks(KernelFunc kernel, int block_size,
|
||||
int dynamic_smem = 0) {
|
||||
int max_active = 0;
|
||||
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
STD_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active, kernel, block_size, dynamic_smem));
|
||||
return std::max(max_active, 1);
|
||||
}
|
||||
@@ -678,27 +681,27 @@ void minimax_reduce_rms_kernel_launcher(MiniMaxReduceRMSParams const& params) {
|
||||
cfg.attrs = attribute;
|
||||
cfg.numAttrs = SM >= 90 ? 2 : 0;
|
||||
|
||||
CUDA_CHECK(cudaLaunchKernelEx(
|
||||
STD_CUDA_CHECK(cudaLaunchKernelEx(
|
||||
&cfg, minimax_reduce_rms_kernel_lamport<DType, NRanks>, params));
|
||||
}
|
||||
|
||||
template <typename DType, int NRanks, int OriginQDim, int OriginKDim>
|
||||
void minimax_reduce_rms_kernel_launcher_float4(
|
||||
MiniMaxReduceRMSParams const& params) {
|
||||
TORCH_CHECK(params.size_q % params.hidden_dim == 0);
|
||||
TORCH_CHECK(params.hidden_dim % kElemsPerAccess<DType> == 0);
|
||||
STD_TORCH_CHECK(params.size_q % params.hidden_dim == 0);
|
||||
STD_TORCH_CHECK(params.hidden_dim % kElemsPerAccess<DType> == 0);
|
||||
if (params.stride_q > 0) {
|
||||
TORCH_CHECK(params.stride_q % kElemsPerAccess<DType> == 0);
|
||||
STD_TORCH_CHECK(params.stride_q % kElemsPerAccess<DType> == 0);
|
||||
}
|
||||
TORCH_CHECK(params.allreduce_in_k != nullptr,
|
||||
"float4 QK kernel requires K input");
|
||||
TORCH_CHECK(params.hidden_dim >= params.hidden_dim_k);
|
||||
TORCH_CHECK(params.size_k % params.hidden_dim_k == 0);
|
||||
TORCH_CHECK(params.hidden_dim_k % kElemsPerAccess<DType> == 0);
|
||||
TORCH_CHECK(params.size_q / params.hidden_dim ==
|
||||
params.size_k / params.hidden_dim_k);
|
||||
STD_TORCH_CHECK(params.allreduce_in_k != nullptr,
|
||||
"float4 QK kernel requires K input");
|
||||
STD_TORCH_CHECK(params.hidden_dim >= params.hidden_dim_k);
|
||||
STD_TORCH_CHECK(params.size_k % params.hidden_dim_k == 0);
|
||||
STD_TORCH_CHECK(params.hidden_dim_k % kElemsPerAccess<DType> == 0);
|
||||
STD_TORCH_CHECK(params.size_q / params.hidden_dim ==
|
||||
params.size_k / params.hidden_dim_k);
|
||||
if (params.stride_k > 0) {
|
||||
TORCH_CHECK(params.stride_k % kElemsPerAccess<DType> == 0);
|
||||
STD_TORCH_CHECK(params.stride_k % kElemsPerAccess<DType> == 0);
|
||||
}
|
||||
|
||||
int token_num = params.size_q / params.hidden_dim;
|
||||
@@ -746,7 +749,7 @@ void minimax_reduce_rms_kernel_launcher_float4(
|
||||
cfg.attrs = attribute;
|
||||
cfg.numAttrs = SM >= 90 ? 2 : 0;
|
||||
|
||||
CUDA_CHECK(cudaLaunchKernelEx(&cfg, kfn, params));
|
||||
STD_CUDA_CHECK(cudaLaunchKernelEx(&cfg, kfn, params));
|
||||
}
|
||||
|
||||
template <int NRanks>
|
||||
@@ -759,21 +762,21 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) {
|
||||
(params.hidden_dim * params.nranks == 6144) &&
|
||||
(params.hidden_dim_k * params.nranks == 1024);
|
||||
|
||||
if (params.dtype == at::ScalarType::Half) {
|
||||
if (params.dtype == torch::headeronly::ScalarType::Half) {
|
||||
if (use_float4) {
|
||||
minimax_reduce_rms_kernel_launcher_float4<half, NRanks, 6144, 1024>(
|
||||
params);
|
||||
} else {
|
||||
minimax_reduce_rms_kernel_launcher<half, NRanks>(params);
|
||||
}
|
||||
} else if (params.dtype == at::ScalarType::BFloat16) {
|
||||
} else if (params.dtype == torch::headeronly::ScalarType::BFloat16) {
|
||||
if (use_float4) {
|
||||
minimax_reduce_rms_kernel_launcher_float4<__nv_bfloat16, NRanks, 6144,
|
||||
1024>(params);
|
||||
} else {
|
||||
minimax_reduce_rms_kernel_launcher<__nv_bfloat16, NRanks>(params);
|
||||
}
|
||||
} else if (params.dtype == at::ScalarType::Float) {
|
||||
} else if (params.dtype == torch::headeronly::ScalarType::Float) {
|
||||
if (use_float4) {
|
||||
minimax_reduce_rms_kernel_launcher_float4<float, NRanks, 6144, 1024>(
|
||||
params);
|
||||
@@ -781,7 +784,7 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) {
|
||||
minimax_reduce_rms_kernel_launcher<float, NRanks>(params);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type for minimax_reduce_rms_op");
|
||||
STD_TORCH_CHECK(false, "Unsupported data type for minimax_reduce_rms_op");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -795,16 +798,18 @@ void minimax_reduce_rms_op(MiniMaxReduceRMSParams const& params) {
|
||||
} else if (params.nranks == 16) {
|
||||
dispatch_dtype<16>(params);
|
||||
} else {
|
||||
TORCH_CHECK(false, "minimax_reduce_rms_op: unsupported ranks number!");
|
||||
STD_TORCH_CHECK(false, "minimax_reduce_rms_op: unsupported ranks number!");
|
||||
}
|
||||
}
|
||||
} // namespace tensorrt_llm
|
||||
} // namespace vllm
|
||||
|
||||
torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
|
||||
torch::Tensor const& norm_weight,
|
||||
torch::Tensor workspace, int64_t const rank,
|
||||
int64_t const nranks, double const eps) {
|
||||
torch::stable::Tensor minimax_allreduce_rms(
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& norm_weight, torch::stable::Tensor workspace,
|
||||
int64_t const rank, int64_t const nranks, double const eps) {
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
auto allreduce_params = vllm::tensorrt_llm::MiniMaxReduceRMSParams();
|
||||
|
||||
allreduce_params.nranks = static_cast<int>(nranks);
|
||||
@@ -815,12 +820,12 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
|
||||
allreduce_params.stride_q = allreduce_params.hidden_dim;
|
||||
allreduce_params.workspace =
|
||||
reinterpret_cast<void**>(workspace.mutable_data_ptr());
|
||||
allreduce_params.allreduce_in = input.data_ptr();
|
||||
allreduce_params.rms_gamma = norm_weight.data_ptr();
|
||||
allreduce_params.allreduce_in = const_cast<void*>(input.const_data_ptr());
|
||||
allreduce_params.rms_gamma = const_cast<void*>(norm_weight.const_data_ptr());
|
||||
allreduce_params.rms_eps = static_cast<float>(eps);
|
||||
allreduce_params.stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
allreduce_params.stream = get_current_cuda_stream(input.get_device_index());
|
||||
|
||||
torch::Tensor rms_norm_out = torch::empty_like(input);
|
||||
torch::stable::Tensor rms_norm_out = torch::stable::empty_like(input);
|
||||
allreduce_params.rms_norm_out = rms_norm_out.mutable_data_ptr();
|
||||
|
||||
vllm::tensorrt_llm::minimax_reduce_rms_op(allreduce_params);
|
||||
@@ -828,26 +833,33 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
|
||||
return rms_norm_out;
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
|
||||
torch::Tensor qkv, torch::Tensor const& norm_weight_q,
|
||||
torch::Tensor const& norm_weight_k, torch::Tensor workspace,
|
||||
int64_t const q_size, int64_t const kv_size, int64_t const rank,
|
||||
int64_t const nranks, double const eps) {
|
||||
TORCH_CHECK(qkv.dim() == 2, "minimax_allreduce_rms_qk: qkv must be 2D");
|
||||
TORCH_CHECK(qkv.is_contiguous(),
|
||||
"minimax_allreduce_rms_qk: qkv must be contiguous");
|
||||
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
|
||||
minimax_allreduce_rms_qk(torch::stable::Tensor qkv,
|
||||
torch::stable::Tensor const& norm_weight_q,
|
||||
torch::stable::Tensor const& norm_weight_k,
|
||||
torch::stable::Tensor workspace, int64_t const q_size,
|
||||
int64_t const kv_size, int64_t const rank,
|
||||
int64_t const nranks, double const eps) {
|
||||
STD_TORCH_CHECK(qkv.dim() == 2, "minimax_allreduce_rms_qk: qkv must be 2D");
|
||||
STD_TORCH_CHECK(qkv.is_contiguous(),
|
||||
"minimax_allreduce_rms_qk: qkv must be contiguous");
|
||||
int64_t qkv_dim = qkv.size(-1);
|
||||
TORCH_CHECK(qkv_dim == q_size + 2 * kv_size,
|
||||
"minimax_allreduce_rms_qk: qkv last dim must equal "
|
||||
"q_size + 2 * kv_size");
|
||||
TORCH_CHECK(rank < nranks,
|
||||
"minimax_allreduce_rms_qk: rank must be less than nranks");
|
||||
STD_TORCH_CHECK(qkv_dim == q_size + 2 * kv_size,
|
||||
"minimax_allreduce_rms_qk: qkv last dim must equal "
|
||||
"q_size + 2 * kv_size");
|
||||
STD_TORCH_CHECK(rank < nranks,
|
||||
"minimax_allreduce_rms_qk: rank must be less than nranks");
|
||||
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
qkv.get_device_index());
|
||||
|
||||
int64_t num_tokens = qkv.size(0);
|
||||
int elem_bytes = qkv.element_size();
|
||||
|
||||
torch::Tensor q_out = torch::empty({num_tokens, q_size}, qkv.options());
|
||||
torch::Tensor k_out = torch::empty({num_tokens, kv_size}, qkv.options());
|
||||
torch::stable::Tensor q_out =
|
||||
torch::stable::new_empty(qkv, {num_tokens, q_size}, qkv.scalar_type());
|
||||
torch::stable::Tensor k_out =
|
||||
torch::stable::new_empty(qkv, {num_tokens, kv_size}, qkv.scalar_type());
|
||||
|
||||
auto params = vllm::tensorrt_llm::MiniMaxReduceRMSParams();
|
||||
params.nranks = static_cast<int>(nranks);
|
||||
@@ -863,13 +875,14 @@ std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
|
||||
params.stride_k_out = 0; // k_out is contiguous; kernel uses hidden_dim_k
|
||||
params.workspace = reinterpret_cast<void**>(workspace.mutable_data_ptr());
|
||||
|
||||
uint8_t* base = static_cast<uint8_t*>(qkv.data_ptr());
|
||||
uint8_t* base =
|
||||
const_cast<uint8_t*>(static_cast<const uint8_t*>(qkv.const_data_ptr()));
|
||||
params.allreduce_in = base;
|
||||
params.allreduce_in_k = base + q_size * elem_bytes;
|
||||
params.rms_gamma = norm_weight_q.data_ptr();
|
||||
params.rms_gamma_k = norm_weight_k.data_ptr();
|
||||
params.rms_gamma = const_cast<void*>(norm_weight_q.const_data_ptr());
|
||||
params.rms_gamma_k = const_cast<void*>(norm_weight_k.const_data_ptr());
|
||||
params.rms_eps = static_cast<float>(eps);
|
||||
params.stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
|
||||
params.stream = get_current_cuda_stream(qkv.get_device_index());
|
||||
|
||||
params.rms_norm_out = q_out.mutable_data_ptr();
|
||||
params.rms_norm_out_k = k_out.mutable_data_ptr();
|
||||
@@ -0,0 +1,69 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
// Adapted from SGLang:
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
|
||||
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_mxfp8_grouped_mm_launcher.cuh"
|
||||
|
||||
void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a,
|
||||
const torch::stable::Tensor& b,
|
||||
const torch::stable::Tensor& sfa,
|
||||
const torch::stable::Tensor& sfb,
|
||||
torch::stable::Tensor& d,
|
||||
const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& blockscale_offsets) {
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
STD_TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have shape (num_experts, 3)");
|
||||
STD_TORCH_CHECK(
|
||||
problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
STD_TORCH_CHECK(
|
||||
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"problem_sizes must be int32");
|
||||
STD_TORCH_CHECK(
|
||||
expert_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"expert_offsets must be int32");
|
||||
STD_TORCH_CHECK(
|
||||
blockscale_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"blockscale_offsets must be int32");
|
||||
STD_TORCH_CHECK(a.dim() == 2,
|
||||
"a must be a 2D tensor of shape (num_tokens, k)");
|
||||
STD_TORCH_CHECK(b.dim() == 3,
|
||||
"b must be a 3D tensor of shape (num_experts, k, n)");
|
||||
STD_TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0,
|
||||
"k should align 128");
|
||||
STD_TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128");
|
||||
STD_TORCH_CHECK(a.stride(1) == 1, "a must be row major");
|
||||
STD_TORCH_CHECK(b.stride(1) == 1, "b must be column major");
|
||||
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
a.get_device_index());
|
||||
auto stream = get_current_cuda_stream(a.get_device_index());
|
||||
if (d.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
|
||||
cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
|
||||
blockscale_offsets, stream);
|
||||
} else if (d.scalar_type() == torch::headeronly::ScalarType::Half) {
|
||||
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
|
||||
cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
|
||||
blockscale_offsets, stream);
|
||||
} else {
|
||||
STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
|
||||
}
|
||||
#else
|
||||
STD_TORCH_CHECK(false,
|
||||
"No implemented cutlass_mxfp8_grouped_mm for "
|
||||
"current device");
|
||||
#endif
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
|
||||
m.impl("cutlass_mxfp8_grouped_mm", TORCH_BOX(&cutlass_mxfp8_grouped_mm));
|
||||
}
|
||||
+71
-52
@@ -4,9 +4,9 @@
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_launcher.cuh
|
||||
|
||||
#pragma once
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/util/Exception.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
@@ -15,18 +15,22 @@
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass_mxfp8_grouped_mm_functor.cuh"
|
||||
#include "cutlass_mxfp8_grouped_mm_traits.cuh"
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
namespace expert_specialization {
|
||||
|
||||
template <typename GemmTraits>
|
||||
void cutlass_mxfp8_grouped_mm_pre_compute(
|
||||
torch::Tensor& a_ptrs, torch::Tensor& b_ptrs, torch::Tensor& sfa_ptrs,
|
||||
torch::Tensor& sfb_ptrs, torch::Tensor& d_ptrs, torch::Tensor& stride_a,
|
||||
torch::Tensor& stride_b, torch::Tensor& stride_d, torch::Tensor& layout_sfa,
|
||||
torch::Tensor& layout_sfb, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& sfa, const torch::Tensor& sfb, const torch::Tensor& d,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
|
||||
torch::stable::Tensor& a_ptrs, torch::stable::Tensor& b_ptrs,
|
||||
torch::stable::Tensor& sfa_ptrs, torch::stable::Tensor& sfb_ptrs,
|
||||
torch::stable::Tensor& d_ptrs, torch::stable::Tensor& stride_a,
|
||||
torch::stable::Tensor& stride_b, torch::stable::Tensor& stride_d,
|
||||
torch::stable::Tensor& layout_sfa, torch::stable::Tensor& layout_sfb,
|
||||
const torch::stable::Tensor& a, const torch::stable::Tensor& b,
|
||||
const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb,
|
||||
const torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) {
|
||||
using OffsetFunctor = CutlassMxfp8GroupedMmOffsetFunctor<GemmTraits>;
|
||||
using ElementA = typename OffsetFunctor::ElementA;
|
||||
using ElementB = typename OffsetFunctor::ElementB;
|
||||
@@ -42,10 +46,10 @@ void cutlass_mxfp8_grouped_mm_pre_compute(
|
||||
using StrideB = typename StrideFunctor::StrideB;
|
||||
using StrideD = typename StrideFunctor::StrideD;
|
||||
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
TORCH_CHECK(num_experts <= 1024,
|
||||
"Number of experts cannot exceed 1024, the maximum number of "
|
||||
"threads per block.");
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
STD_TORCH_CHECK(num_experts <= 1024,
|
||||
"Number of experts cannot exceed 1024, the maximum number of "
|
||||
"threads per block.");
|
||||
|
||||
OffsetFunctor offset_functor(
|
||||
reinterpret_cast<int*>(expert_offsets.data_ptr()),
|
||||
@@ -72,13 +76,18 @@ void cutlass_mxfp8_grouped_mm_pre_compute(
|
||||
}
|
||||
|
||||
template <typename GemmTraits>
|
||||
void cutlass_mxfp8_grouped_mm(
|
||||
const torch::Tensor& a_ptrs, const torch::Tensor& b_ptrs,
|
||||
const torch::Tensor& sfa_ptrs, const torch::Tensor& sfb_ptrs,
|
||||
const torch::Tensor& d_ptrs, const torch::Tensor& stride_a,
|
||||
const torch::Tensor& stride_b, const torch::Tensor& stride_d,
|
||||
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& problem_sizes, cudaStream_t stream) {
|
||||
void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a_ptrs,
|
||||
const torch::stable::Tensor& b_ptrs,
|
||||
const torch::stable::Tensor& sfa_ptrs,
|
||||
const torch::stable::Tensor& sfb_ptrs,
|
||||
const torch::stable::Tensor& d_ptrs,
|
||||
const torch::stable::Tensor& stride_a,
|
||||
const torch::stable::Tensor& stride_b,
|
||||
const torch::stable::Tensor& stride_d,
|
||||
const torch::stable::Tensor& layout_sfa,
|
||||
const torch::stable::Tensor& layout_sfb,
|
||||
const torch::stable::Tensor& problem_sizes,
|
||||
cudaStream_t stream) {
|
||||
using Gemm = typename GemmTraits::Gemm;
|
||||
using ElementA = typename Gemm::ElementA;
|
||||
using ElementB = typename Gemm::ElementB;
|
||||
@@ -93,13 +102,12 @@ void cutlass_mxfp8_grouped_mm(
|
||||
typename GemmTraits::ProblemShape::UnderlyingProblemShape;
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = c10::cuda::current_device();
|
||||
hw_info.sm_count =
|
||||
at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
|
||||
hw_info.device_id = d_ptrs.get_device_index();
|
||||
hw_info.sm_count = get_device_prop()->multiProcessorCount;
|
||||
hw_info.cluster_shape = GemmTraits::MMAConfig::preferred_cluster;
|
||||
hw_info.cluster_shape_fallback = GemmTraits::MMAConfig::fallback_cluster;
|
||||
|
||||
int num_experts = (int)problem_sizes.size(0);
|
||||
int num_experts = static_cast<int>(problem_sizes.size(0));
|
||||
|
||||
UnderlyingProblemShape* underlying_problem_shape =
|
||||
reinterpret_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
|
||||
@@ -127,44 +135,55 @@ void cutlass_mxfp8_grouped_mm(
|
||||
Gemm gemm;
|
||||
|
||||
auto can_implement_status = gemm.can_implement(arguments);
|
||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
||||
"Failed to implement GEMM");
|
||||
STD_TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
||||
"Failed to implement GEMM");
|
||||
|
||||
torch::TensorOptions options_uint8 =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(d_ptrs.device());
|
||||
size_t workspace_size = gemm.get_workspace_size(arguments);
|
||||
torch::Tensor workspace = torch::empty(workspace_size, options_uint8);
|
||||
torch::stable::Tensor workspace = torch::stable::empty(
|
||||
{static_cast<int64_t>(workspace_size)},
|
||||
torch::headeronly::ScalarType::Byte, std::nullopt, d_ptrs.device());
|
||||
|
||||
auto status = gemm.initialize(arguments, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
||||
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
|
||||
"Failed to initialize GEMM");
|
||||
|
||||
status = gemm.run(stream, nullptr, true); // Enable PDL
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_mxfp8_grouped_mm_dispatch_out_dtype(
|
||||
const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& sfa,
|
||||
const torch::Tensor& sfb, torch::Tensor& d,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
|
||||
int num_experts = (int)problem_sizes.size(0);
|
||||
torch::TensorOptions options_int64 =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
torch::TensorOptions options_int32 =
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(a.device());
|
||||
const torch::stable::Tensor& a, const torch::stable::Tensor& b,
|
||||
const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb,
|
||||
torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) {
|
||||
int num_experts = static_cast<int>(problem_sizes.size(0));
|
||||
auto device = a.device();
|
||||
|
||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int64);
|
||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int64);
|
||||
torch::Tensor sfa_ptrs = torch::empty(num_experts, options_int64);
|
||||
torch::Tensor sfb_ptrs = torch::empty(num_experts, options_int64);
|
||||
torch::Tensor d_ptrs = torch::empty(num_experts, options_int64);
|
||||
torch::stable::Tensor a_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor b_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor sfa_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor sfb_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor d_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
|
||||
torch::Tensor stride_a = torch::empty(num_experts, options_int64);
|
||||
torch::Tensor stride_b = torch::empty(num_experts, options_int64);
|
||||
torch::Tensor stride_d = torch::empty(num_experts, options_int64);
|
||||
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int32);
|
||||
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int32);
|
||||
torch::stable::Tensor stride_a = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor stride_b = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor stride_d = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor layout_sfa =
|
||||
torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int,
|
||||
std::nullopt, device);
|
||||
torch::stable::Tensor layout_sfb =
|
||||
torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int,
|
||||
std::nullopt, device);
|
||||
|
||||
using GemmTraits = CutlassMxfp8GroupedMmGemmTraits<MMA1SMConfig, OutType>;
|
||||
cutlass_mxfp8_grouped_mm_pre_compute<GemmTraits>(
|
||||
@@ -176,4 +195,4 @@ void cutlass_mxfp8_grouped_mm_dispatch_out_dtype(
|
||||
layout_sfa, layout_sfb, problem_sizes, stream);
|
||||
}
|
||||
|
||||
} // namespace expert_specialization
|
||||
} // namespace expert_specialization
|
||||
@@ -0,0 +1,66 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
// Adapted from SGLang:
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
|
||||
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "mxfp8_experts_quant.cuh"
|
||||
|
||||
void mxfp8_experts_quant(const torch::stable::Tensor& input,
|
||||
const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& blockscale_offsets,
|
||||
torch::stable::Tensor& quant_output,
|
||||
torch::stable::Tensor& scale_factor) {
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
STD_TORCH_CHECK(input.dim() == 2, "input must be 2D tensor");
|
||||
STD_TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128");
|
||||
STD_TORCH_CHECK(input.stride(1) == 1, "input must be row major");
|
||||
STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
STD_TORCH_CHECK(
|
||||
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"problem_sizes must be int32");
|
||||
STD_TORCH_CHECK(
|
||||
expert_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"expert_offsets must be int32");
|
||||
STD_TORCH_CHECK(
|
||||
blockscale_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"blockscale_offsets must be int32");
|
||||
|
||||
auto groups = problem_sizes.size(0);
|
||||
STD_TORCH_CHECK(
|
||||
expert_offsets.dim() == 1 && expert_offsets.size(0) == groups,
|
||||
"expert_offsets must be 1D and have size equal to the number of groups");
|
||||
STD_TORCH_CHECK(
|
||||
blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
|
||||
"blockscale_offsets must be 1D and have size equal to the number of "
|
||||
"groups");
|
||||
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
if (input.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>(
|
||||
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
|
||||
scale_factor);
|
||||
} else if (input.scalar_type() == torch::headeronly::ScalarType::Half) {
|
||||
expert_specialization::launch_mxfp8_experts_quant<__half>(
|
||||
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
|
||||
scale_factor);
|
||||
} else {
|
||||
STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
|
||||
}
|
||||
#else
|
||||
STD_TORCH_CHECK(false,
|
||||
"No implemented mxfp8_experts_quant for "
|
||||
"current device");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Registered here (not torch_bindings.cpp) because ENABLE_ES_MXFP8_GROUPED_MM
|
||||
// is applied only under COMPILE_LANGUAGE:CUDA.
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
|
||||
m.impl("mxfp8_experts_quant", TORCH_BOX(&mxfp8_experts_quant));
|
||||
}
|
||||
+16
-14
@@ -4,16 +4,19 @@
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh
|
||||
|
||||
#pragma once
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/stable/macros.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/util/Exception.h>
|
||||
|
||||
#include <cuda/ptx>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
namespace expert_specialization {
|
||||
|
||||
@@ -356,12 +359,12 @@ __global__ void mxfp8_experts_quant_kernel(
|
||||
}
|
||||
|
||||
template <typename T_IN>
|
||||
void launch_mxfp8_experts_quant(const torch::Tensor& input,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& blockscale_offsets,
|
||||
torch::Tensor& quant_output,
|
||||
torch::Tensor& scale_factor) {
|
||||
void launch_mxfp8_experts_quant(const torch::stable::Tensor& input,
|
||||
const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& blockscale_offsets,
|
||||
torch::stable::Tensor& quant_output,
|
||||
torch::stable::Tensor& scale_factor) {
|
||||
ThrLayout thr_layout{};
|
||||
ValLayout val_layout{};
|
||||
SfR2SThrLayout r2s_thr_layout{};
|
||||
@@ -386,19 +389,18 @@ void launch_mxfp8_experts_quant(const torch::Tensor& input,
|
||||
CopyAtomR2S{}, r2s_thr_layout, r2s_val_layout); // Tiler_MN: (16, 4)
|
||||
|
||||
int max_active_blocks_per_sm = -1;
|
||||
AT_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
STD_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks_per_sm,
|
||||
mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
|
||||
decltype(tiled_copy_r2g),
|
||||
decltype(tiled_copy_r2s)>,
|
||||
THREAD_BLOCK_SIZE, 0));
|
||||
|
||||
dim3 grid(at::cuda::getCurrentDeviceProperties()->multiProcessorCount *
|
||||
max_active_blocks_per_sm,
|
||||
dim3 grid(get_device_prop()->multiProcessorCount * max_active_blocks_per_sm,
|
||||
1, 1);
|
||||
dim3 block(THREAD_BLOCK_SIZE, 1, 1);
|
||||
int num_experts = (int)problem_sizes.size(0);
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
int num_experts = static_cast<int>(problem_sizes.size(0));
|
||||
auto stream = get_current_cuda_stream(input.get_device_index());
|
||||
mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
|
||||
decltype(tiled_copy_r2g), decltype(tiled_copy_r2s)>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
@@ -231,6 +231,27 @@ void fused_qk_norm_rope(torch::stable::Tensor& qkv, int64_t num_heads_q,
|
||||
torch::stable::Tensor& position_ids,
|
||||
int64_t forced_token_heads_per_warp);
|
||||
|
||||
torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
||||
torch::stable::Tensor const& q_in, torch::stable::Tensor const& kv,
|
||||
torch::stable::Tensor& k_cache, torch::stable::Tensor const& slot_mapping,
|
||||
torch::stable::Tensor const& position_ids,
|
||||
torch::stable::Tensor const& cos_sin_cache, int64_t q_head_padded,
|
||||
double eps, int64_t cache_block_size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::stable::Tensor minimax_allreduce_rms(
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& norm_weight, torch::stable::Tensor workspace,
|
||||
int64_t const rank, int64_t const nranks, double const eps);
|
||||
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
|
||||
minimax_allreduce_rms_qk(torch::stable::Tensor qkv,
|
||||
torch::stable::Tensor const& norm_weight_q,
|
||||
torch::stable::Tensor const& norm_weight_k,
|
||||
torch::stable::Tensor workspace, int64_t const q_size,
|
||||
int64_t const kv_size, int64_t const rank,
|
||||
int64_t const nranks, double const eps);
|
||||
#endif
|
||||
|
||||
// Sampler kernels (shared CUDA/ROCm)
|
||||
void apply_repetition_penalties_(
|
||||
torch::stable::Tensor& logits, const torch::stable::Tensor& prompt_mask,
|
||||
@@ -273,6 +294,26 @@ void selective_scan_fwd(
|
||||
const std::optional<torch::stable::Tensor>& cu_chunk_seqlen,
|
||||
const std::optional<torch::stable::Tensor>& last_chunk_indices);
|
||||
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||
torch::stable::Tensor& rank_data, int64_t rank,
|
||||
bool fully_connected);
|
||||
void all_reduce(fptr_t _fa, torch::stable::Tensor& inp,
|
||||
torch::stable::Tensor& out, fptr_t reg_buffer,
|
||||
int64_t reg_buffer_sz_bytes);
|
||||
void dispose(fptr_t _fa);
|
||||
int64_t meta_size();
|
||||
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
|
||||
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
|
||||
get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
void register_graph_buffers(fptr_t _fa,
|
||||
const std::vector<std::vector<int64_t>>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets);
|
||||
std::tuple<int64_t, torch::stable::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size);
|
||||
int64_t open_mem_handle(torch::stable::Tensor& mem_handle);
|
||||
void free_shared_buffer(int64_t buffer);
|
||||
|
||||
// Activation kernels (shared CUDA/ROCm)
|
||||
void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
|
||||
void silu_and_mul_clamp(torch::stable::Tensor& out,
|
||||
|
||||
@@ -337,6 +337,24 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
"bool is_neox, Tensor position_ids, "
|
||||
"int forced_token_heads_per_warp=-1) -> ()");
|
||||
|
||||
ops.def(
|
||||
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
|
||||
"Tensor q_in, Tensor kv, Tensor! k_cache, "
|
||||
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
|
||||
"int q_head_padded, float eps, int cache_block_size) -> Tensor");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
ops.def(
|
||||
"minimax_allreduce_rms("
|
||||
"Tensor input, Tensor norm_weight, Tensor workspace, "
|
||||
"int rank, int nranks, float eps) -> Tensor");
|
||||
ops.def(
|
||||
"minimax_allreduce_rms_qk("
|
||||
"Tensor qkv, Tensor norm_weight_q, Tensor norm_weight_k, "
|
||||
"Tensor workspace, int q_size, int kv_size, int rank, int nranks, "
|
||||
"float eps) -> (Tensor, Tensor)");
|
||||
#endif
|
||||
|
||||
// Apply repetition penalties to logits in-place.
|
||||
ops.def(
|
||||
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
|
||||
@@ -571,6 +589,12 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
// Positional encoding kernels (shared CUDA/ROCm)
|
||||
ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding));
|
||||
ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope));
|
||||
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert",
|
||||
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert));
|
||||
#ifndef USE_ROCM
|
||||
ops.impl("minimax_allreduce_rms", TORCH_BOX(&minimax_allreduce_rms));
|
||||
ops.impl("minimax_allreduce_rms_qk", TORCH_BOX(&minimax_allreduce_rms_qk));
|
||||
#endif
|
||||
|
||||
// Sampler kernels (shared CUDA/ROCm)
|
||||
ops.impl("apply_repetition_penalties_",
|
||||
@@ -725,6 +749,45 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) {
|
||||
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(_C_custom_ar, custom_ar) {
|
||||
custom_ar.def(
|
||||
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
|
||||
"int rank, bool fully_connected) -> int");
|
||||
custom_ar.def(
|
||||
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
|
||||
"int reg_buffer_sz_bytes) -> ()");
|
||||
custom_ar.def("dispose(int fa) -> ()");
|
||||
custom_ar.def("meta_size() -> int");
|
||||
custom_ar.def("register_buffer(int fa, int[] ipc_tensors) -> ()");
|
||||
custom_ar.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])");
|
||||
custom_ar.def(
|
||||
"register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
|
||||
custom_ar.def("allocate_shared_buffer_and_handle(int size) -> (int, Tensor)");
|
||||
custom_ar.def("open_mem_handle(Tensor mem_handle) -> int");
|
||||
custom_ar.def("free_shared_buffer(int ptr) -> ()");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CUDA, custom_ar) {
|
||||
custom_ar.impl("init_custom_ar", TORCH_BOX(&init_custom_ar));
|
||||
custom_ar.impl("all_reduce", TORCH_BOX(&all_reduce));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CPU, custom_ar) {
|
||||
custom_ar.impl("open_mem_handle", TORCH_BOX(&open_mem_handle));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CompositeExplicitAutograd, custom_ar) {
|
||||
custom_ar.impl("dispose", TORCH_BOX(&dispose));
|
||||
custom_ar.impl("meta_size", TORCH_BOX(&meta_size));
|
||||
custom_ar.impl("register_buffer", TORCH_BOX(®ister_buffer));
|
||||
custom_ar.impl("get_graph_buffer_ipc_meta",
|
||||
TORCH_BOX(&get_graph_buffer_ipc_meta));
|
||||
custom_ar.impl("register_graph_buffers", TORCH_BOX(®ister_graph_buffers));
|
||||
custom_ar.impl("allocate_shared_buffer_and_handle",
|
||||
TORCH_BOX(&allocate_shared_buffer_and_handle));
|
||||
custom_ar.impl("free_shared_buffer", TORCH_BOX(&free_shared_buffer));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C_cache_ops, CPU, ops) {
|
||||
ops.impl("swap_blocks_batch", TORCH_BOX(&swap_blocks_batch));
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include <torch/types.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
namespace vllm {
|
||||
namespace tensorrt_llm {
|
||||
@@ -51,7 +51,7 @@ static constexpr int kElemsPerAccess = ElemsPerAccess<DType>::value;
|
||||
struct MiniMaxReduceRMSParams {
|
||||
int nranks{};
|
||||
int rank{};
|
||||
at::ScalarType dtype{at::ScalarType::Undefined};
|
||||
torch::headeronly::ScalarType dtype{torch::headeronly::ScalarType::Undefined};
|
||||
int size_q{};
|
||||
int hidden_dim{};
|
||||
int size_k{};
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
// Adapted from SGLang:
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass_mxfp8_grouped_mm_launcher.cuh"
|
||||
|
||||
void cutlass_mxfp8_grouped_mm(const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& sfa,
|
||||
const torch::Tensor& sfb, torch::Tensor& d,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& blockscale_offsets) {
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have shape (num_experts, 3)");
|
||||
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
|
||||
"expert_offsets must be int32");
|
||||
TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
|
||||
"blockscale_offsets must be int32");
|
||||
TORCH_CHECK(a.dim() == 2, "a must be a 2D tensor of shape (num_tokens, k)");
|
||||
TORCH_CHECK(b.dim() == 3,
|
||||
"b must be a 3D tensor of shape (num_experts, k, n)");
|
||||
TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0,
|
||||
"k should align 128");
|
||||
TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128");
|
||||
TORCH_CHECK(a.strides()[1] == 1, "a must be row major");
|
||||
TORCH_CHECK(b.strides()[1] == 1, "b must be column major");
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
if (d.dtype() == torch::kBFloat16) {
|
||||
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
|
||||
cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
|
||||
blockscale_offsets, stream);
|
||||
} else if (d.dtype() == torch::kFloat16) {
|
||||
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
|
||||
cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
|
||||
blockscale_offsets, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false,
|
||||
"No implemented cutlass_mxfp8_grouped_mm for "
|
||||
"current device");
|
||||
#endif
|
||||
}
|
||||
|
||||
#include "core/registration.h"
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("cutlass_mxfp8_grouped_mm", cutlass_mxfp8_grouped_mm);
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
// Adapted from SGLang:
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "mxfp8_experts_quant.cuh"
|
||||
|
||||
void mxfp8_experts_quant(const torch::Tensor& input,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& blockscale_offsets,
|
||||
torch::Tensor& quant_output,
|
||||
torch::Tensor& scale_factor) {
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
TORCH_CHECK(input.dim() == 2, "input must be 2D tensor");
|
||||
TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128");
|
||||
TORCH_CHECK(input.strides()[1] == 1, "input must be row major");
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
|
||||
"expert_offsets must be int32");
|
||||
TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
|
||||
"blockscale_offsets must be int32");
|
||||
|
||||
auto groups = problem_sizes.size(0);
|
||||
TORCH_CHECK(
|
||||
expert_offsets.dim() == 1 && expert_offsets.size(0) == groups,
|
||||
"expert_offsets must be 1D and have size equal to the number of groups");
|
||||
TORCH_CHECK(
|
||||
blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
|
||||
"blockscale_offsets must be 1D and have size equal to the number of "
|
||||
"groups");
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
if (input.dtype() == torch::kBFloat16) {
|
||||
expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>(
|
||||
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
|
||||
scale_factor);
|
||||
} else if (input.dtype() == torch::kFloat16) {
|
||||
expert_specialization::launch_mxfp8_experts_quant<__half>(
|
||||
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
|
||||
scale_factor);
|
||||
} else {
|
||||
TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false,
|
||||
"No implemented mxfp8_experts_quant for "
|
||||
"current device");
|
||||
#endif
|
||||
}
|
||||
|
||||
#include "core/registration.h"
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("mxfp8_experts_quant", mxfp8_experts_quant);
|
||||
}
|
||||
-36
@@ -40,12 +40,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||
torch::Tensor& weight, double epsilon);
|
||||
|
||||
torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
||||
torch::Tensor const& q_in, torch::Tensor const& kv, torch::Tensor& k_cache,
|
||||
torch::Tensor const& slot_mapping, torch::Tensor const& position_ids,
|
||||
torch::Tensor const& cos_sin_cache, int64_t q_head_padded, double eps,
|
||||
int64_t cache_block_size);
|
||||
|
||||
void silu_and_mul_per_block_quant(torch::Tensor& out,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor& scales, int64_t group_size,
|
||||
@@ -107,24 +101,6 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
|
||||
int64_t activation_kind);
|
||||
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||
torch::Tensor& rank_data, int64_t rank,
|
||||
bool fully_connected);
|
||||
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
|
||||
void dispose(fptr_t _fa);
|
||||
int64_t meta_size();
|
||||
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
|
||||
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
|
||||
get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
void register_graph_buffers(fptr_t _fa,
|
||||
const std::vector<std::vector<int64_t>>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets);
|
||||
std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size);
|
||||
int64_t open_mem_handle(torch::Tensor& mem_handle);
|
||||
void free_shared_buffer(int64_t buffer);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
fptr_t init_custom_qr(int64_t rank, int64_t world_size,
|
||||
std::optional<int64_t> qr_max_size = std::nullopt);
|
||||
@@ -135,15 +111,3 @@ void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
int64_t quant_level, bool cast_bf2half = false);
|
||||
int64_t qr_max_size();
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
|
||||
torch::Tensor const& norm_weight,
|
||||
torch::Tensor workspace, int64_t const rank,
|
||||
int64_t const nranks, double const eps);
|
||||
std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
|
||||
torch::Tensor qkv, torch::Tensor const& norm_weight_q,
|
||||
torch::Tensor const& norm_weight_k, torch::Tensor workspace,
|
||||
int64_t const q_size, int64_t const kv_size, int64_t const rank,
|
||||
int64_t const nranks, double const eps);
|
||||
#endif
|
||||
|
||||
+20
-78
@@ -55,14 +55,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
|
||||
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
|
||||
// kernel launch.
|
||||
ops.def(
|
||||
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
|
||||
"Tensor q_in, Tensor kv, Tensor! k_cache, "
|
||||
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
|
||||
"int q_head_padded, float eps, int cache_block_size) -> Tensor");
|
||||
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
|
||||
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);
|
||||
// kernel launch. Registered in _C_stable_libtorch.
|
||||
|
||||
// Quantization ops
|
||||
#ifndef USE_ROCM
|
||||
@@ -163,34 +156,27 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
ops.def(
|
||||
"minimax_allreduce_rms("
|
||||
"Tensor input,"
|
||||
"Tensor norm_weight,"
|
||||
"Tensor workspace,"
|
||||
"int rank,"
|
||||
"int nranks,"
|
||||
"float eps) -> Tensor");
|
||||
ops.impl("minimax_allreduce_rms", torch::kCUDA, &minimax_allreduce_rms);
|
||||
ops.def(
|
||||
"minimax_allreduce_rms_qk("
|
||||
"Tensor qkv,"
|
||||
"Tensor norm_weight_q,"
|
||||
"Tensor norm_weight_k,"
|
||||
"Tensor workspace,"
|
||||
"int q_size,"
|
||||
"int kv_size,"
|
||||
"int rank,"
|
||||
"int nranks,"
|
||||
"float eps) -> (Tensor, Tensor)");
|
||||
ops.impl("minimax_allreduce_rms_qk", torch::kCUDA, &minimax_allreduce_rms_qk);
|
||||
|
||||
// conditionally compiled so impl in source file
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
TORCH_LIBRARY_FRAGMENT(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
||||
// Quick Reduce all-reduce kernels (ROCm-only; stays on legacy _C).
|
||||
custom_ar.def(
|
||||
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
|
||||
"cast_bf2half) -> ()");
|
||||
custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
|
||||
|
||||
custom_ar.def("init_custom_qr", &init_custom_qr);
|
||||
custom_ar.def("qr_destroy", &qr_destroy);
|
||||
custom_ar.def("qr_get_handle", &qr_get_handle);
|
||||
|
||||
custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
|
||||
custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
|
||||
|
||||
custom_ar.def("qr_max_size", &qr_max_size);
|
||||
}
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||
// Cuda utils
|
||||
|
||||
@@ -205,48 +191,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||
&get_max_shared_memory_per_block_device_attribute);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
||||
// Custom all-reduce kernels
|
||||
custom_ar.def(
|
||||
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
|
||||
"int rank, bool fully_connected) -> int");
|
||||
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
|
||||
custom_ar.def(
|
||||
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
|
||||
"int reg_buffer_sz_bytes) -> ()");
|
||||
custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);
|
||||
|
||||
custom_ar.def("dispose", &dispose);
|
||||
custom_ar.def("meta_size", &meta_size);
|
||||
|
||||
custom_ar.def("register_buffer", ®ister_buffer);
|
||||
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
||||
custom_ar.def("register_graph_buffers", ®ister_graph_buffers);
|
||||
|
||||
custom_ar.def("allocate_shared_buffer_and_handle",
|
||||
&allocate_shared_buffer_and_handle);
|
||||
custom_ar.def("open_mem_handle(Tensor mem_handle) -> int", &open_mem_handle);
|
||||
custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle);
|
||||
|
||||
custom_ar.def("free_shared_buffer", &free_shared_buffer);
|
||||
#ifdef USE_ROCM
|
||||
// Quick Reduce all-reduce kernels
|
||||
custom_ar.def(
|
||||
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
|
||||
"cast_bf2half) -> ()");
|
||||
custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
|
||||
|
||||
custom_ar.def("init_custom_qr", &init_custom_qr);
|
||||
custom_ar.def("qr_destroy", &qr_destroy);
|
||||
|
||||
custom_ar.def("qr_get_handle", &qr_get_handle);
|
||||
|
||||
custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
|
||||
custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
|
||||
|
||||
// Max input size in bytes
|
||||
custom_ar.def("qr_max_size", &qr_max_size);
|
||||
#endif
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
||||
@@ -50,7 +50,7 @@ struct _typeConvert<float> {
|
||||
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
|
||||
// CUDA < 12.0 runs into issues with packed type conversion
|
||||
template <>
|
||||
struct _typeConvert<c10::Half> {
|
||||
struct _typeConvert<torch::headeronly::Half> {
|
||||
static constexpr bool exists = true;
|
||||
using hip_type = __half;
|
||||
using packed_hip_type = __half2;
|
||||
@@ -73,7 +73,7 @@ struct _typeConvert<c10::Half> {
|
||||
// CUDA_ARCH < 800 does not have BF16 support
|
||||
// ROCm 7.0+ supports bfloat16
|
||||
template <>
|
||||
struct _typeConvert<c10::BFloat16> {
|
||||
struct _typeConvert<torch::headeronly::BFloat16> {
|
||||
static constexpr bool exists = true;
|
||||
using hip_type = __nv_bfloat16;
|
||||
using packed_hip_type = __nv_bfloat162;
|
||||
|
||||
Reference in New Issue
Block a user