[10b/n] Migrate custom all-reduce, DeepSeek V4 fused MLA, MiniMax reduce-RMS, and MXFP8 MoE to libtorch stable ABI (#44365)

Signed-off-by: Chris Leonard <chleonar@redhat.com>
Signed-off-by: Shengqi Chen <harry-chen@outlook.com>
Co-authored-by: Shengqi Chen <harry-chen@outlook.com>
This commit is contained in:
Chris Leonard
2026-06-03 12:29:46 -04:00
committed by GitHub
parent 0a5cbf633e
commit 59d0236193
18 changed files with 568 additions and 481 deletions
+13 -12
View File
@@ -311,14 +311,9 @@ set(VLLM_EXT_SRC
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/custom_all_reduce.cu"
"csrc/torch_bindings.cpp"
"csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")
"csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC
"csrc/minimax_reduce_rms_kernel.cu")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
@@ -505,12 +500,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND ES_MXFP8_GROUPED_MM_ARCHS)
set(SRCS
"csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu"
"csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu")
"csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu"
"csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${ES_MXFP8_GROUPED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_ES_MXFP8_GROUPED_MM_SM100=1")
message(STATUS "Building ES MXFP8 grouped kernels for archs: ${ES_MXFP8_GROUPED_MM_ARCHS}")
else()
@@ -600,7 +595,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
if (VLLM_GPU_LANG STREQUAL "HIP")
# Add QuickReduce kernels
# Add QuickReduce kernels (ROCm-only; not part of stable ABI migration).
list(APPEND VLLM_EXT_SRC
"csrc/custom_quickreduce.cu"
)
@@ -651,7 +646,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/attention/paged_attention_v1.cu"
"csrc/libtorch_stable/attention/paged_attention_v2.cu"
"csrc/libtorch_stable/cache_kernels.cu"
"csrc/libtorch_stable/cache_kernels_fused.cu")
"csrc/libtorch_stable/cache_kernels.cu"
"csrc/libtorch_stable/cache_kernels_fused.cu"
"csrc/libtorch_stable/custom_all_reduce.cu"
"csrc/libtorch_stable/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
@@ -661,7 +659,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/libtorch_stable/permute_cols.cu"
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu")
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu"
"csrc/libtorch_stable/minimax_reduce_rms_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${VLLM_STABLE_EXT_SRC}"
@@ -1,7 +1,11 @@
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include "torch_utils.h"
#include <torch/csrc/stable/macros.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/csrc/stable/device.h>
#include "custom_all_reduce.cuh"
@@ -11,7 +15,7 @@ using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank,
torch::stable::Tensor& rank_data, int64_t rank,
bool fully_connected) {
int world_size = fake_ipc_ptrs.size();
if (world_size > 8)
@@ -25,9 +29,9 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
for (int i = 0; i < world_size; i++) {
ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
}
return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
rank_data.numel(), rank, world_size,
fully_connected);
return (fptr_t) new vllm::CustomAllreduce(
ipc_ptrs, rank_data.mutable_data_ptr(), rank_data.numel(), rank,
world_size, fully_connected);
}
/**
@@ -46,10 +50,14 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size());
bool _is_weak_contiguous(torch::stable::Tensor& t) {
if (t.is_contiguous()) {
return true;
}
int64_t storage_nbytes = 0;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_storage_size(t.get(), &storage_nbytes));
return storage_nbytes - t.storage_offset() * t.element_size() ==
static_cast<int64_t>(t.numel() * t.element_size());
}
/**
@@ -59,42 +67,45 @@ bool _is_weak_contiguous(torch::Tensor& t) {
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
void all_reduce(fptr_t _fa, torch::stable::Tensor& inp,
torch::stable::Tensor& out, fptr_t _reg_buffer,
int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream();
const torch::stable::accelerator::DeviceGuard device_guard(
inp.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(inp.get_device_index());
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(_is_weak_contiguous(out));
TORCH_CHECK(_is_weak_contiguous(inp));
STD_TORCH_CHECK((inp.scalar_type()) == (out.scalar_type()));
STD_TORCH_CHECK((inp.numel()) == (out.numel()));
STD_TORCH_CHECK(_is_weak_contiguous(out));
STD_TORCH_CHECK(_is_weak_contiguous(inp));
auto input_size = inp.numel() * inp.element_size();
auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
if (reg_buffer) {
TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
STD_TORCH_CHECK((input_size) <= (reg_buffer_sz_bytes));
STD_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.const_data_ptr(), input_size,
cudaMemcpyDeviceToDevice, stream));
} else {
reg_buffer = inp.data_ptr();
reg_buffer = inp.mutable_data_ptr();
}
switch (out.scalar_type()) {
case at::ScalarType::Float: {
case torch::headeronly::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
reinterpret_cast<float*>(out.data_ptr()),
reinterpret_cast<float*>(out.mutable_data_ptr()),
out.numel());
break;
}
case at::ScalarType::Half: {
case torch::headeronly::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
reinterpret_cast<half*>(out.data_ptr()), out.numel());
reinterpret_cast<half*>(out.mutable_data_ptr()),
out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
case torch::headeronly::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
reinterpret_cast<nv_bfloat16*>(out.mutable_data_ptr()), out.numel());
break;
}
#endif
@@ -112,7 +123,7 @@ int64_t meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
STD_TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
void* ipc_ptrs[8];
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
@@ -143,47 +154,49 @@ void register_graph_buffers(fptr_t _fa,
fa->register_graph_buffers(bytes, offsets);
}
std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
std::tuple<fptr_t, torch::stable::Tensor> allocate_shared_buffer_and_handle(
int64_t size) {
auto device_index = c10::cuda::current_device();
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
int device_index;
STD_CUDA_CHECK(cudaGetDevice(&device_index));
const torch::stable::accelerator::DeviceGuard device_guard(device_index);
void* buffer;
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
auto stream = c10::cuda::getCurrentCUDAStream().stream();
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
const cudaStream_t stream = get_current_cuda_stream(device_index);
STD_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
// Allocate buffer
#if defined(USE_ROCM)
// data buffers need to be "uncached" for signal on MI200
AT_CUDA_CHECK(
STD_CUDA_CHECK(
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
#else
AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
STD_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
#endif
AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
STD_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
STD_CUDA_CHECK(cudaStreamSynchronize(stream));
STD_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
// Create IPC memhandle for the allocated buffer.
// Will use it in open_mem_handle.
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handle =
torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
AT_CUDA_CHECK(
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer));
auto handle = torch::stable::empty(
{static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))},
torch::headeronly::ScalarType::Byte, std::nullopt,
torch::stable::Device(torch::stable::DeviceType::CPU));
STD_CUDA_CHECK(cudaIpcGetMemHandle(
(cudaIpcMemHandle_t*)handle.mutable_data_ptr(), buffer));
return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
}
fptr_t open_mem_handle(torch::Tensor& mem_handle) {
fptr_t open_mem_handle(torch::stable::Tensor& mem_handle) {
void* ipc_ptr;
AT_CUDA_CHECK(cudaIpcOpenMemHandle(
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()),
STD_CUDA_CHECK(cudaIpcOpenMemHandle(
(void**)&ipc_ptr,
*((const cudaIpcMemHandle_t*)mem_handle.const_data_ptr()),
cudaIpcMemLazyEnablePeerAccess));
return reinterpret_cast<fptr_t>(ipc_ptr);
}
void free_shared_buffer(fptr_t buffer) {
AT_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
STD_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}
@@ -28,7 +28,20 @@
* [bs*576, bs*576 + bs*8): UE8M0 scales, 7 real + 1 pad per token
*/
#include "torch_utils.h"
#include <torch/csrc/stable/macros.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/csrc/stable/device.h>
#include <cmath>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "type_convert.cuh"
#ifndef USE_ROCM
#include <cuda_fp8.h>
#else
@@ -37,14 +50,6 @@
#include <cuda_runtime.h>
#include <type_traits>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/cuda.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "type_convert.cuh"
#ifndef FINAL_MASK
#ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL
@@ -70,7 +75,7 @@ namespace deepseek_v4_fused_ops {
namespace {
inline int getSMVersion() {
auto* props = at::cuda::getCurrentDeviceProperties();
auto* props = get_device_prop();
return props->major * 10 + props->minor;
}
} // namespace
@@ -564,7 +569,7 @@ static void launchFusedDeepseekV4Templated(
// bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert<BFloat16> is
// unavailable there. Refuse the launch loudly instead of silently
// skipping the work.
TORCH_CHECK(
STD_TORCH_CHECK(
sm_version >= 80,
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert requires sm_80+ "
"(Ampere or newer); got sm_",
@@ -635,7 +640,7 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
DISPATCH(64)
DISPATCH(128)
default:
TORCH_CHECK(false,
STD_TORCH_CHECK(false,
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert: "
"unsupported num_heads_q_padded=",
num_heads_q_padded,
@@ -650,34 +655,42 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
// ────────────────────────────────────────────────────────────────────────────
// Torch op wrapper
// ────────────────────────────────────────────────────────────────────────────
torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::Tensor const& q_in, // [N, num_heads_q, 512] bf16
torch::Tensor const& kv, // [N, 512] bf16 (read-only)
torch::Tensor& k_cache, // [num_blocks, block_bytes] uint8
torch::Tensor const& slot_mapping, // [N] int64
torch::Tensor const& position_ids, // [N] int64
torch::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16
torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::stable::Tensor const& q_in, // [N, num_heads_q, 512] bf16
torch::stable::Tensor const& kv, // [N, 512] bf16 (read-only)
torch::stable::Tensor& k_cache, // [num_blocks, block_bytes] uint8
torch::stable::Tensor const& slot_mapping, // [N] int64
torch::stable::Tensor const& position_ids, // [N] int64
torch::stable::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16
int64_t q_head_padded, // padded Q head count for output
double eps, int64_t cache_block_size) {
TORCH_CHECK(q_in.is_cuda() && q_in.is_contiguous(),
STD_TORCH_CHECK(q_in.device().is_cuda() && q_in.is_contiguous(),
"q_in must be contiguous CUDA");
TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA");
TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA");
TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64,
STD_TORCH_CHECK(kv.device().is_cuda() && kv.is_contiguous(),
"kv must be contiguous CUDA");
STD_TORCH_CHECK(k_cache.device().is_cuda(), "k_cache must be CUDA");
STD_TORCH_CHECK(slot_mapping.device().is_cuda() &&
slot_mapping.scalar_type() ==
torch::headeronly::ScalarType::Long,
"slot_mapping must be int64 CUDA");
TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64,
STD_TORCH_CHECK(position_ids.device().is_cuda() &&
position_ids.scalar_type() ==
torch::headeronly::ScalarType::Long,
"position_ids must be int64 CUDA");
TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA");
TORCH_CHECK(q_in.dim() == 3 && q_in.size(2) == 512,
STD_TORCH_CHECK(cos_sin_cache.device().is_cuda(), "cos_sin_cache must be CUDA");
STD_TORCH_CHECK(q_in.dim() == 3 && q_in.size(2) == 512,
"q_in shape [N, num_heads_q, 512]");
TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
TORCH_CHECK(q_in.dtype() == kv.dtype(), "q_in and kv dtype must match");
TORCH_CHECK(q_head_padded >= q_in.size(1),
STD_TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
STD_TORCH_CHECK(q_in.scalar_type() == kv.scalar_type(),
"q_in and kv dtype must match");
STD_TORCH_CHECK(q_head_padded >= q_in.size(1),
"q_head_padded must be >= q_in.size(1) (num_heads_q)");
TORCH_CHECK(k_cache.dtype() == torch::kUInt8, "k_cache must be uint8");
TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
STD_TORCH_CHECK(k_cache.scalar_type() == torch::headeronly::ScalarType::Byte,
"k_cache must be uint8");
STD_TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
"cos_sin_cache shape [max_pos, 64]");
TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32,
STD_TORCH_CHECK(cos_sin_cache.scalar_type() ==
torch::headeronly::ScalarType::Float,
"cos_sin_cache must be float32");
// With DP padding, slot_mapping can be shorter than q/kv/positions.
@@ -685,36 +698,37 @@ torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
// KV quant+insert runs only on the first slot_mapping.size(0) rows.
int const num_tokens_full = static_cast<int>(q_in.size(0));
int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
TORCH_CHECK(static_cast<int>(kv.size(0)) == num_tokens_full &&
STD_TORCH_CHECK(static_cast<int>(kv.size(0)) == num_tokens_full &&
static_cast<int>(position_ids.size(0)) == num_tokens_full,
"q/kv/position_ids row counts must match");
TORCH_CHECK(num_tokens_insert <= num_tokens_full,
STD_TORCH_CHECK(num_tokens_insert <= num_tokens_full,
"slot_mapping must not exceed q row count");
int const num_heads_q = static_cast<int>(q_in.size(1));
int const num_heads_q_padded = static_cast<int>(q_head_padded);
int const cache_block_size_i = static_cast<int>(cache_block_size);
int const kv_block_stride = static_cast<int>(k_cache.stride(0));
at::cuda::OptionalCUDAGuard device_guard(device_of(q_in));
auto stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
q_in.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(q_in.get_device_index());
// Allocate the padded q output. The kernel writes every element (live
// region gets RMSNorm+RoPE; pad region gets zeros), so `empty` is safe.
torch::Tensor q_out = torch::empty(
{q_in.size(0), q_head_padded, q_in.size(2)}, q_in.options());
auto q_out = torch::stable::new_empty(
q_in, {q_in.size(0), q_head_padded, q_in.size(2)}, q_in.scalar_type());
VLLM_DISPATCH_HALF_TYPES(
VLLM_STABLE_DISPATCH_HALF_TYPES(
q_in.scalar_type(), "fused_deepseek_v4_qnorm_rope_kv_insert", [&] {
using qkv_scalar_t = scalar_t;
vllm::deepseek_v4_fused_ops::
launchFusedDeepseekV4QNormRopeKVRopeQuantInsert<qkv_scalar_t>(
reinterpret_cast<qkv_scalar_t const*>(q_in.data_ptr()),
reinterpret_cast<qkv_scalar_t*>(q_out.data_ptr()),
reinterpret_cast<qkv_scalar_t const*>(kv.data_ptr()),
reinterpret_cast<uint8_t*>(k_cache.data_ptr()),
reinterpret_cast<int64_t const*>(slot_mapping.data_ptr()),
reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
cos_sin_cache.data_ptr<float>(), static_cast<float>(eps),
reinterpret_cast<qkv_scalar_t const*>(q_in.const_data_ptr()),
reinterpret_cast<qkv_scalar_t*>(q_out.mutable_data_ptr()),
reinterpret_cast<qkv_scalar_t const*>(kv.const_data_ptr()),
reinterpret_cast<uint8_t*>(k_cache.mutable_data_ptr()),
slot_mapping.const_data_ptr<int64_t>(),
position_ids.const_data_ptr<int64_t>(),
cos_sin_cache.const_data_ptr<float>(), static_cast<float>(eps),
num_tokens_full, num_tokens_insert, num_heads_q,
num_heads_q_padded, cache_block_size_i, kv_block_stride,
stream);
@@ -15,16 +15,19 @@
* limitations under the License.
*/
#include "torch_utils.h"
#include <torch/csrc/stable/macros.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/csrc/stable/device.h>
#include <cooperative_groups.h>
#include <cuda_runtime.h>
#include <torch/cuda.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "cuda_utils.h"
#include "core/registration.h"
#include "minimax_reduce_rms_kernel.h"
#include <algorithm>
@@ -611,7 +614,7 @@ int get_sm_count() {
static int sm_count = 0;
if (sm_count == 0) {
int device_id;
CUDA_CHECK(cudaGetDevice(&device_id));
STD_CUDA_CHECK(cudaGetDevice(&device_id));
cudaDeviceProp device_prop;
cudaGetDeviceProperties(&device_prop, device_id);
sm_count = device_prop.multiProcessorCount;
@@ -621,13 +624,13 @@ int get_sm_count() {
inline int getSMVersion(bool queryRealSmArch = false) {
int device{-1};
CUDA_CHECK(cudaGetDevice(&device));
STD_CUDA_CHECK(cudaGetDevice(&device));
int sm_major = 0;
int sm_minor = 0;
CUDA_CHECK(cudaDeviceGetAttribute(&sm_major,
cudaDevAttrComputeCapabilityMajor, device));
CUDA_CHECK(cudaDeviceGetAttribute(&sm_minor,
cudaDevAttrComputeCapabilityMinor, device));
STD_CUDA_CHECK(cudaDeviceGetAttribute(
&sm_major, cudaDevAttrComputeCapabilityMajor, device));
STD_CUDA_CHECK(cudaDeviceGetAttribute(
&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
int sm = sm_major * 10 + sm_minor;
if (sm == 121 && !queryRealSmArch) {
return 120;
@@ -639,7 +642,7 @@ template <typename KernelFunc>
int get_max_active_blocks(KernelFunc kernel, int block_size,
int dynamic_smem = 0) {
int max_active = 0;
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
STD_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active, kernel, block_size, dynamic_smem));
return std::max(max_active, 1);
}
@@ -678,27 +681,27 @@ void minimax_reduce_rms_kernel_launcher(MiniMaxReduceRMSParams const& params) {
cfg.attrs = attribute;
cfg.numAttrs = SM >= 90 ? 2 : 0;
CUDA_CHECK(cudaLaunchKernelEx(
STD_CUDA_CHECK(cudaLaunchKernelEx(
&cfg, minimax_reduce_rms_kernel_lamport<DType, NRanks>, params));
}
template <typename DType, int NRanks, int OriginQDim, int OriginKDim>
void minimax_reduce_rms_kernel_launcher_float4(
MiniMaxReduceRMSParams const& params) {
TORCH_CHECK(params.size_q % params.hidden_dim == 0);
TORCH_CHECK(params.hidden_dim % kElemsPerAccess<DType> == 0);
STD_TORCH_CHECK(params.size_q % params.hidden_dim == 0);
STD_TORCH_CHECK(params.hidden_dim % kElemsPerAccess<DType> == 0);
if (params.stride_q > 0) {
TORCH_CHECK(params.stride_q % kElemsPerAccess<DType> == 0);
STD_TORCH_CHECK(params.stride_q % kElemsPerAccess<DType> == 0);
}
TORCH_CHECK(params.allreduce_in_k != nullptr,
STD_TORCH_CHECK(params.allreduce_in_k != nullptr,
"float4 QK kernel requires K input");
TORCH_CHECK(params.hidden_dim >= params.hidden_dim_k);
TORCH_CHECK(params.size_k % params.hidden_dim_k == 0);
TORCH_CHECK(params.hidden_dim_k % kElemsPerAccess<DType> == 0);
TORCH_CHECK(params.size_q / params.hidden_dim ==
STD_TORCH_CHECK(params.hidden_dim >= params.hidden_dim_k);
STD_TORCH_CHECK(params.size_k % params.hidden_dim_k == 0);
STD_TORCH_CHECK(params.hidden_dim_k % kElemsPerAccess<DType> == 0);
STD_TORCH_CHECK(params.size_q / params.hidden_dim ==
params.size_k / params.hidden_dim_k);
if (params.stride_k > 0) {
TORCH_CHECK(params.stride_k % kElemsPerAccess<DType> == 0);
STD_TORCH_CHECK(params.stride_k % kElemsPerAccess<DType> == 0);
}
int token_num = params.size_q / params.hidden_dim;
@@ -746,7 +749,7 @@ void minimax_reduce_rms_kernel_launcher_float4(
cfg.attrs = attribute;
cfg.numAttrs = SM >= 90 ? 2 : 0;
CUDA_CHECK(cudaLaunchKernelEx(&cfg, kfn, params));
STD_CUDA_CHECK(cudaLaunchKernelEx(&cfg, kfn, params));
}
template <int NRanks>
@@ -759,21 +762,21 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) {
(params.hidden_dim * params.nranks == 6144) &&
(params.hidden_dim_k * params.nranks == 1024);
if (params.dtype == at::ScalarType::Half) {
if (params.dtype == torch::headeronly::ScalarType::Half) {
if (use_float4) {
minimax_reduce_rms_kernel_launcher_float4<half, NRanks, 6144, 1024>(
params);
} else {
minimax_reduce_rms_kernel_launcher<half, NRanks>(params);
}
} else if (params.dtype == at::ScalarType::BFloat16) {
} else if (params.dtype == torch::headeronly::ScalarType::BFloat16) {
if (use_float4) {
minimax_reduce_rms_kernel_launcher_float4<__nv_bfloat16, NRanks, 6144,
1024>(params);
} else {
minimax_reduce_rms_kernel_launcher<__nv_bfloat16, NRanks>(params);
}
} else if (params.dtype == at::ScalarType::Float) {
} else if (params.dtype == torch::headeronly::ScalarType::Float) {
if (use_float4) {
minimax_reduce_rms_kernel_launcher_float4<float, NRanks, 6144, 1024>(
params);
@@ -781,7 +784,7 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) {
minimax_reduce_rms_kernel_launcher<float, NRanks>(params);
}
} else {
TORCH_CHECK(false, "Unsupported data type for minimax_reduce_rms_op");
STD_TORCH_CHECK(false, "Unsupported data type for minimax_reduce_rms_op");
}
}
@@ -795,16 +798,18 @@ void minimax_reduce_rms_op(MiniMaxReduceRMSParams const& params) {
} else if (params.nranks == 16) {
dispatch_dtype<16>(params);
} else {
TORCH_CHECK(false, "minimax_reduce_rms_op: unsupported ranks number!");
STD_TORCH_CHECK(false, "minimax_reduce_rms_op: unsupported ranks number!");
}
}
} // namespace tensorrt_llm
} // namespace vllm
torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
torch::Tensor const& norm_weight,
torch::Tensor workspace, int64_t const rank,
int64_t const nranks, double const eps) {
torch::stable::Tensor minimax_allreduce_rms(
torch::stable::Tensor const& input,
torch::stable::Tensor const& norm_weight, torch::stable::Tensor workspace,
int64_t const rank, int64_t const nranks, double const eps) {
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
auto allreduce_params = vllm::tensorrt_llm::MiniMaxReduceRMSParams();
allreduce_params.nranks = static_cast<int>(nranks);
@@ -815,12 +820,12 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
allreduce_params.stride_q = allreduce_params.hidden_dim;
allreduce_params.workspace =
reinterpret_cast<void**>(workspace.mutable_data_ptr());
allreduce_params.allreduce_in = input.data_ptr();
allreduce_params.rms_gamma = norm_weight.data_ptr();
allreduce_params.allreduce_in = const_cast<void*>(input.const_data_ptr());
allreduce_params.rms_gamma = const_cast<void*>(norm_weight.const_data_ptr());
allreduce_params.rms_eps = static_cast<float>(eps);
allreduce_params.stream = at::cuda::getCurrentCUDAStream(input.get_device());
allreduce_params.stream = get_current_cuda_stream(input.get_device_index());
torch::Tensor rms_norm_out = torch::empty_like(input);
torch::stable::Tensor rms_norm_out = torch::stable::empty_like(input);
allreduce_params.rms_norm_out = rms_norm_out.mutable_data_ptr();
vllm::tensorrt_llm::minimax_reduce_rms_op(allreduce_params);
@@ -828,26 +833,33 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
return rms_norm_out;
}
std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
torch::Tensor qkv, torch::Tensor const& norm_weight_q,
torch::Tensor const& norm_weight_k, torch::Tensor workspace,
int64_t const q_size, int64_t const kv_size, int64_t const rank,
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
minimax_allreduce_rms_qk(torch::stable::Tensor qkv,
torch::stable::Tensor const& norm_weight_q,
torch::stable::Tensor const& norm_weight_k,
torch::stable::Tensor workspace, int64_t const q_size,
int64_t const kv_size, int64_t const rank,
int64_t const nranks, double const eps) {
TORCH_CHECK(qkv.dim() == 2, "minimax_allreduce_rms_qk: qkv must be 2D");
TORCH_CHECK(qkv.is_contiguous(),
STD_TORCH_CHECK(qkv.dim() == 2, "minimax_allreduce_rms_qk: qkv must be 2D");
STD_TORCH_CHECK(qkv.is_contiguous(),
"minimax_allreduce_rms_qk: qkv must be contiguous");
int64_t qkv_dim = qkv.size(-1);
TORCH_CHECK(qkv_dim == q_size + 2 * kv_size,
STD_TORCH_CHECK(qkv_dim == q_size + 2 * kv_size,
"minimax_allreduce_rms_qk: qkv last dim must equal "
"q_size + 2 * kv_size");
TORCH_CHECK(rank < nranks,
STD_TORCH_CHECK(rank < nranks,
"minimax_allreduce_rms_qk: rank must be less than nranks");
const torch::stable::accelerator::DeviceGuard device_guard(
qkv.get_device_index());
int64_t num_tokens = qkv.size(0);
int elem_bytes = qkv.element_size();
torch::Tensor q_out = torch::empty({num_tokens, q_size}, qkv.options());
torch::Tensor k_out = torch::empty({num_tokens, kv_size}, qkv.options());
torch::stable::Tensor q_out =
torch::stable::new_empty(qkv, {num_tokens, q_size}, qkv.scalar_type());
torch::stable::Tensor k_out =
torch::stable::new_empty(qkv, {num_tokens, kv_size}, qkv.scalar_type());
auto params = vllm::tensorrt_llm::MiniMaxReduceRMSParams();
params.nranks = static_cast<int>(nranks);
@@ -863,13 +875,14 @@ std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
params.stride_k_out = 0; // k_out is contiguous; kernel uses hidden_dim_k
params.workspace = reinterpret_cast<void**>(workspace.mutable_data_ptr());
uint8_t* base = static_cast<uint8_t*>(qkv.data_ptr());
uint8_t* base =
const_cast<uint8_t*>(static_cast<const uint8_t*>(qkv.const_data_ptr()));
params.allreduce_in = base;
params.allreduce_in_k = base + q_size * elem_bytes;
params.rms_gamma = norm_weight_q.data_ptr();
params.rms_gamma_k = norm_weight_k.data_ptr();
params.rms_gamma = const_cast<void*>(norm_weight_q.const_data_ptr());
params.rms_gamma_k = const_cast<void*>(norm_weight_k.const_data_ptr());
params.rms_eps = static_cast<float>(eps);
params.stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
params.stream = get_current_cuda_stream(qkv.get_device_index());
params.rms_norm_out = q_out.mutable_data_ptr();
params.rms_norm_out_k = k_out.mutable_data_ptr();
@@ -0,0 +1,69 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_mxfp8_grouped_mm_launcher.cuh"
void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& sfa,
const torch::stable::Tensor& sfb,
torch::stable::Tensor& d,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& blockscale_offsets) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
STD_TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have shape (num_experts, 3)");
STD_TORCH_CHECK(
problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
STD_TORCH_CHECK(
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
"problem_sizes must be int32");
STD_TORCH_CHECK(
expert_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
"expert_offsets must be int32");
STD_TORCH_CHECK(
blockscale_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
"blockscale_offsets must be int32");
STD_TORCH_CHECK(a.dim() == 2,
"a must be a 2D tensor of shape (num_tokens, k)");
STD_TORCH_CHECK(b.dim() == 3,
"b must be a 3D tensor of shape (num_experts, k, n)");
STD_TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0,
"k should align 128");
STD_TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128");
STD_TORCH_CHECK(a.stride(1) == 1, "a must be row major");
STD_TORCH_CHECK(b.stride(1) == 1, "b must be column major");
const torch::stable::accelerator::DeviceGuard device_guard(
a.get_device_index());
auto stream = get_current_cuda_stream(a.get_device_index());
if (d.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
blockscale_offsets, stream);
} else if (d.scalar_type() == torch::headeronly::ScalarType::Half) {
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
blockscale_offsets, stream);
} else {
STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
}
#else
STD_TORCH_CHECK(false,
"No implemented cutlass_mxfp8_grouped_mm for "
"current device");
#endif
}
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_mxfp8_grouped_mm", TORCH_BOX(&cutlass_mxfp8_grouped_mm));
}
@@ -4,9 +4,9 @@
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_launcher.cuh
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/Exception.h>
#include <cassert>
#include <iostream>
@@ -15,18 +15,22 @@
#include "cute/tensor.hpp"
#include "cutlass_mxfp8_grouped_mm_functor.cuh"
#include "cutlass_mxfp8_grouped_mm_traits.cuh"
#include "libtorch_stable/torch_utils.h"
namespace expert_specialization {
template <typename GemmTraits>
void cutlass_mxfp8_grouped_mm_pre_compute(
torch::Tensor& a_ptrs, torch::Tensor& b_ptrs, torch::Tensor& sfa_ptrs,
torch::Tensor& sfb_ptrs, torch::Tensor& d_ptrs, torch::Tensor& stride_a,
torch::Tensor& stride_b, torch::Tensor& stride_d, torch::Tensor& layout_sfa,
torch::Tensor& layout_sfb, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& sfa, const torch::Tensor& sfb, const torch::Tensor& d,
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
torch::stable::Tensor& a_ptrs, torch::stable::Tensor& b_ptrs,
torch::stable::Tensor& sfa_ptrs, torch::stable::Tensor& sfb_ptrs,
torch::stable::Tensor& d_ptrs, torch::stable::Tensor& stride_a,
torch::stable::Tensor& stride_b, torch::stable::Tensor& stride_d,
torch::stable::Tensor& layout_sfa, torch::stable::Tensor& layout_sfb,
const torch::stable::Tensor& a, const torch::stable::Tensor& b,
const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb,
const torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) {
using OffsetFunctor = CutlassMxfp8GroupedMmOffsetFunctor<GemmTraits>;
using ElementA = typename OffsetFunctor::ElementA;
using ElementB = typename OffsetFunctor::ElementB;
@@ -42,8 +46,8 @@ void cutlass_mxfp8_grouped_mm_pre_compute(
using StrideB = typename StrideFunctor::StrideB;
using StrideD = typename StrideFunctor::StrideD;
int num_experts = (int)expert_offsets.size(0);
TORCH_CHECK(num_experts <= 1024,
int num_experts = static_cast<int>(expert_offsets.size(0));
STD_TORCH_CHECK(num_experts <= 1024,
"Number of experts cannot exceed 1024, the maximum number of "
"threads per block.");
@@ -72,13 +76,18 @@ void cutlass_mxfp8_grouped_mm_pre_compute(
}
template <typename GemmTraits>
void cutlass_mxfp8_grouped_mm(
const torch::Tensor& a_ptrs, const torch::Tensor& b_ptrs,
const torch::Tensor& sfa_ptrs, const torch::Tensor& sfb_ptrs,
const torch::Tensor& d_ptrs, const torch::Tensor& stride_a,
const torch::Tensor& stride_b, const torch::Tensor& stride_d,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes, cudaStream_t stream) {
void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a_ptrs,
const torch::stable::Tensor& b_ptrs,
const torch::stable::Tensor& sfa_ptrs,
const torch::stable::Tensor& sfb_ptrs,
const torch::stable::Tensor& d_ptrs,
const torch::stable::Tensor& stride_a,
const torch::stable::Tensor& stride_b,
const torch::stable::Tensor& stride_d,
const torch::stable::Tensor& layout_sfa,
const torch::stable::Tensor& layout_sfb,
const torch::stable::Tensor& problem_sizes,
cudaStream_t stream) {
using Gemm = typename GemmTraits::Gemm;
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
@@ -93,13 +102,12 @@ void cutlass_mxfp8_grouped_mm(
typename GemmTraits::ProblemShape::UnderlyingProblemShape;
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = c10::cuda::current_device();
hw_info.sm_count =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
hw_info.device_id = d_ptrs.get_device_index();
hw_info.sm_count = get_device_prop()->multiProcessorCount;
hw_info.cluster_shape = GemmTraits::MMAConfig::preferred_cluster;
hw_info.cluster_shape_fallback = GemmTraits::MMAConfig::fallback_cluster;
int num_experts = (int)problem_sizes.size(0);
int num_experts = static_cast<int>(problem_sizes.size(0));
UnderlyingProblemShape* underlying_problem_shape =
reinterpret_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
@@ -127,44 +135,55 @@ void cutlass_mxfp8_grouped_mm(
Gemm gemm;
auto can_implement_status = gemm.can_implement(arguments);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
STD_TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM");
torch::TensorOptions options_uint8 =
torch::TensorOptions().dtype(torch::kUInt8).device(d_ptrs.device());
size_t workspace_size = gemm.get_workspace_size(arguments);
torch::Tensor workspace = torch::empty(workspace_size, options_uint8);
torch::stable::Tensor workspace = torch::stable::empty(
{static_cast<int64_t>(workspace_size)},
torch::headeronly::ScalarType::Byte, std::nullopt, d_ptrs.device());
auto status = gemm.initialize(arguments, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM");
status = gemm.run(stream, nullptr, true); // Enable PDL
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
template <typename OutType>
void cutlass_mxfp8_grouped_mm_dispatch_out_dtype(
const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& sfa,
const torch::Tensor& sfb, torch::Tensor& d,
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
int num_experts = (int)problem_sizes.size(0);
torch::TensorOptions options_int64 =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::TensorOptions options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(a.device());
const torch::stable::Tensor& a, const torch::stable::Tensor& b,
const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb,
torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) {
int num_experts = static_cast<int>(problem_sizes.size(0));
auto device = a.device();
torch::Tensor a_ptrs = torch::empty(num_experts, options_int64);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int64);
torch::Tensor sfa_ptrs = torch::empty(num_experts, options_int64);
torch::Tensor sfb_ptrs = torch::empty(num_experts, options_int64);
torch::Tensor d_ptrs = torch::empty(num_experts, options_int64);
torch::stable::Tensor a_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor sfa_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor sfb_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor d_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::Tensor stride_a = torch::empty(num_experts, options_int64);
torch::Tensor stride_b = torch::empty(num_experts, options_int64);
torch::Tensor stride_d = torch::empty(num_experts, options_int64);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int32);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int32);
torch::stable::Tensor stride_a = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor stride_b = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor stride_d = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor layout_sfa =
torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int,
std::nullopt, device);
torch::stable::Tensor layout_sfb =
torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int,
std::nullopt, device);
using GemmTraits = CutlassMxfp8GroupedMmGemmTraits<MMA1SMConfig, OutType>;
cutlass_mxfp8_grouped_mm_pre_compute<GemmTraits>(
@@ -0,0 +1,66 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "mxfp8_experts_quant.cuh"
void mxfp8_experts_quant(const torch::stable::Tensor& input,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& blockscale_offsets,
torch::stable::Tensor& quant_output,
torch::stable::Tensor& scale_factor) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
STD_TORCH_CHECK(input.dim() == 2, "input must be 2D tensor");
STD_TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128");
STD_TORCH_CHECK(input.stride(1) == 1, "input must be row major");
STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
STD_TORCH_CHECK(
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
"problem_sizes must be int32");
STD_TORCH_CHECK(
expert_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
"expert_offsets must be int32");
STD_TORCH_CHECK(
blockscale_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
"blockscale_offsets must be int32");
auto groups = problem_sizes.size(0);
STD_TORCH_CHECK(
expert_offsets.dim() == 1 && expert_offsets.size(0) == groups,
"expert_offsets must be 1D and have size equal to the number of groups");
STD_TORCH_CHECK(
blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
"blockscale_offsets must be 1D and have size equal to the number of "
"groups");
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
if (input.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>(
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
scale_factor);
} else if (input.scalar_type() == torch::headeronly::ScalarType::Half) {
expert_specialization::launch_mxfp8_experts_quant<__half>(
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
scale_factor);
} else {
STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
}
#else
STD_TORCH_CHECK(false,
"No implemented mxfp8_experts_quant for "
"current device");
#endif
}
// Registered here (not torch_bindings.cpp) because ENABLE_ES_MXFP8_GROUPED_MM
// is applied only under COMPILE_LANGUAGE:CUDA.
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("mxfp8_experts_quant", TORCH_BOX(&mxfp8_experts_quant));
}
@@ -4,16 +4,19 @@
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/all.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/macros.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/Exception.h>
#include <cuda/ptx>
#include "cute/tensor.hpp"
#include "libtorch_stable/torch_utils.h"
namespace expert_specialization {
@@ -356,12 +359,12 @@ __global__ void mxfp8_experts_quant_kernel(
}
template <typename T_IN>
void launch_mxfp8_experts_quant(const torch::Tensor& input,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& blockscale_offsets,
torch::Tensor& quant_output,
torch::Tensor& scale_factor) {
void launch_mxfp8_experts_quant(const torch::stable::Tensor& input,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& blockscale_offsets,
torch::stable::Tensor& quant_output,
torch::stable::Tensor& scale_factor) {
ThrLayout thr_layout{};
ValLayout val_layout{};
SfR2SThrLayout r2s_thr_layout{};
@@ -386,19 +389,18 @@ void launch_mxfp8_experts_quant(const torch::Tensor& input,
CopyAtomR2S{}, r2s_thr_layout, r2s_val_layout); // Tiler_MN: (16, 4)
int max_active_blocks_per_sm = -1;
AT_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
STD_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_per_sm,
mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
decltype(tiled_copy_r2g),
decltype(tiled_copy_r2s)>,
THREAD_BLOCK_SIZE, 0));
dim3 grid(at::cuda::getCurrentDeviceProperties()->multiProcessorCount *
max_active_blocks_per_sm,
dim3 grid(get_device_prop()->multiProcessorCount * max_active_blocks_per_sm,
1, 1);
dim3 block(THREAD_BLOCK_SIZE, 1, 1);
int num_experts = (int)problem_sizes.size(0);
auto stream = at::cuda::getCurrentCUDAStream();
int num_experts = static_cast<int>(problem_sizes.size(0));
auto stream = get_current_cuda_stream(input.get_device_index());
mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
decltype(tiled_copy_r2g), decltype(tiled_copy_r2s)>
<<<grid, block, 0, stream>>>(
+41
View File
@@ -231,6 +231,27 @@ void fused_qk_norm_rope(torch::stable::Tensor& qkv, int64_t num_heads_q,
torch::stable::Tensor& position_ids,
int64_t forced_token_heads_per_warp);
torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::stable::Tensor const& q_in, torch::stable::Tensor const& kv,
torch::stable::Tensor& k_cache, torch::stable::Tensor const& slot_mapping,
torch::stable::Tensor const& position_ids,
torch::stable::Tensor const& cos_sin_cache, int64_t q_head_padded,
double eps, int64_t cache_block_size);
#ifndef USE_ROCM
torch::stable::Tensor minimax_allreduce_rms(
torch::stable::Tensor const& input,
torch::stable::Tensor const& norm_weight, torch::stable::Tensor workspace,
int64_t const rank, int64_t const nranks, double const eps);
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
minimax_allreduce_rms_qk(torch::stable::Tensor qkv,
torch::stable::Tensor const& norm_weight_q,
torch::stable::Tensor const& norm_weight_k,
torch::stable::Tensor workspace, int64_t const q_size,
int64_t const kv_size, int64_t const rank,
int64_t const nranks, double const eps);
#endif
// Sampler kernels (shared CUDA/ROCm)
void apply_repetition_penalties_(
torch::stable::Tensor& logits, const torch::stable::Tensor& prompt_mask,
@@ -273,6 +294,26 @@ void selective_scan_fwd(
const std::optional<torch::stable::Tensor>& cu_chunk_seqlen,
const std::optional<torch::stable::Tensor>& last_chunk_indices);
using fptr_t = int64_t;
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
torch::stable::Tensor& rank_data, int64_t rank,
bool fully_connected);
void all_reduce(fptr_t _fa, torch::stable::Tensor& inp,
torch::stable::Tensor& out, fptr_t reg_buffer,
int64_t reg_buffer_sz_bytes);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
std::tuple<int64_t, torch::stable::Tensor> allocate_shared_buffer_and_handle(
int64_t size);
int64_t open_mem_handle(torch::stable::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
// Activation kernels (shared CUDA/ROCm)
void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
void silu_and_mul_clamp(torch::stable::Tensor& out,
+63
View File
@@ -337,6 +337,24 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"bool is_neox, Tensor position_ids, "
"int forced_token_heads_per_warp=-1) -> ()");
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
"Tensor q_in, Tensor kv, Tensor! k_cache, "
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
"int q_head_padded, float eps, int cache_block_size) -> Tensor");
#ifndef USE_ROCM
ops.def(
"minimax_allreduce_rms("
"Tensor input, Tensor norm_weight, Tensor workspace, "
"int rank, int nranks, float eps) -> Tensor");
ops.def(
"minimax_allreduce_rms_qk("
"Tensor qkv, Tensor norm_weight_q, Tensor norm_weight_k, "
"Tensor workspace, int q_size, int kv_size, int rank, int nranks, "
"float eps) -> (Tensor, Tensor)");
#endif
// Apply repetition penalties to logits in-place.
ops.def(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
@@ -571,6 +589,12 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
// Positional encoding kernels (shared CUDA/ROCm)
ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding));
ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope));
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert",
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert));
#ifndef USE_ROCM
ops.impl("minimax_allreduce_rms", TORCH_BOX(&minimax_allreduce_rms));
ops.impl("minimax_allreduce_rms_qk", TORCH_BOX(&minimax_allreduce_rms_qk));
#endif
// Sampler kernels (shared CUDA/ROCm)
ops.impl("apply_repetition_penalties_",
@@ -725,6 +749,45 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) {
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
}
STABLE_TORCH_LIBRARY_FRAGMENT(_C_custom_ar, custom_ar) {
custom_ar.def(
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
"int rank, bool fully_connected) -> int");
custom_ar.def(
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
custom_ar.def("dispose(int fa) -> ()");
custom_ar.def("meta_size() -> int");
custom_ar.def("register_buffer(int fa, int[] ipc_tensors) -> ()");
custom_ar.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])");
custom_ar.def(
"register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
custom_ar.def("allocate_shared_buffer_and_handle(int size) -> (int, Tensor)");
custom_ar.def("open_mem_handle(Tensor mem_handle) -> int");
custom_ar.def("free_shared_buffer(int ptr) -> ()");
}
STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CUDA, custom_ar) {
custom_ar.impl("init_custom_ar", TORCH_BOX(&init_custom_ar));
custom_ar.impl("all_reduce", TORCH_BOX(&all_reduce));
}
STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CPU, custom_ar) {
custom_ar.impl("open_mem_handle", TORCH_BOX(&open_mem_handle));
}
STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CompositeExplicitAutograd, custom_ar) {
custom_ar.impl("dispose", TORCH_BOX(&dispose));
custom_ar.impl("meta_size", TORCH_BOX(&meta_size));
custom_ar.impl("register_buffer", TORCH_BOX(&register_buffer));
custom_ar.impl("get_graph_buffer_ipc_meta",
TORCH_BOX(&get_graph_buffer_ipc_meta));
custom_ar.impl("register_graph_buffers", TORCH_BOX(&register_graph_buffers));
custom_ar.impl("allocate_shared_buffer_and_handle",
TORCH_BOX(&allocate_shared_buffer_and_handle));
custom_ar.impl("free_shared_buffer", TORCH_BOX(&free_shared_buffer));
}
STABLE_TORCH_LIBRARY_IMPL(_C_cache_ops, CPU, ops) {
ops.impl("swap_blocks_batch", TORCH_BOX(&swap_blocks_batch));
}
+2 -2
View File
@@ -19,7 +19,7 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/types.h>
#include <torch/headeronly/core/ScalarType.h>
namespace vllm {
namespace tensorrt_llm {
@@ -51,7 +51,7 @@ static constexpr int kElemsPerAccess = ElemsPerAccess<DType>::value;
struct MiniMaxReduceRMSParams {
int nranks{};
int rank{};
at::ScalarType dtype{at::ScalarType::Undefined};
torch::headeronly::ScalarType dtype{torch::headeronly::ScalarType::Undefined};
int size_q{};
int hidden_dim{};
int size_k{};
@@ -1,60 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
#include <torch/all.h>
#include "cutlass_mxfp8_grouped_mm_launcher.cuh"
void cutlass_mxfp8_grouped_mm(const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& sfa,
const torch::Tensor& sfb, torch::Tensor& d,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& blockscale_offsets) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have shape (num_experts, 3)");
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32");
TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
"expert_offsets must be int32");
TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
"blockscale_offsets must be int32");
TORCH_CHECK(a.dim() == 2, "a must be a 2D tensor of shape (num_tokens, k)");
TORCH_CHECK(b.dim() == 3,
"b must be a 3D tensor of shape (num_experts, k, n)");
TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0,
"k should align 128");
TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128");
TORCH_CHECK(a.strides()[1] == 1, "a must be row major");
TORCH_CHECK(b.strides()[1] == 1, "b must be column major");
auto stream = at::cuda::getCurrentCUDAStream();
if (d.dtype() == torch::kBFloat16) {
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
blockscale_offsets, stream);
} else if (d.dtype() == torch::kFloat16) {
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
blockscale_offsets, stream);
} else {
TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
}
#else
TORCH_CHECK(false,
"No implemented cutlass_mxfp8_grouped_mm for "
"current device");
#endif
}
#include "core/registration.h"
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_mxfp8_grouped_mm", cutlass_mxfp8_grouped_mm);
}
-60
View File
@@ -1,60 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
#include <torch/all.h>
#include "mxfp8_experts_quant.cuh"
void mxfp8_experts_quant(const torch::Tensor& input,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& blockscale_offsets,
torch::Tensor& quant_output,
torch::Tensor& scale_factor) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
TORCH_CHECK(input.dim() == 2, "input must be 2D tensor");
TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128");
TORCH_CHECK(input.strides()[1] == 1, "input must be row major");
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32");
TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
"expert_offsets must be int32");
TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
"blockscale_offsets must be int32");
auto groups = problem_sizes.size(0);
TORCH_CHECK(
expert_offsets.dim() == 1 && expert_offsets.size(0) == groups,
"expert_offsets must be 1D and have size equal to the number of groups");
TORCH_CHECK(
blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
"blockscale_offsets must be 1D and have size equal to the number of "
"groups");
auto stream = at::cuda::getCurrentCUDAStream();
if (input.dtype() == torch::kBFloat16) {
expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>(
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
scale_factor);
} else if (input.dtype() == torch::kFloat16) {
expert_specialization::launch_mxfp8_experts_quant<__half>(
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
scale_factor);
} else {
TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
}
#else
TORCH_CHECK(false,
"No implemented mxfp8_experts_quant for "
"current device");
#endif
}
#include "core/registration.h"
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("mxfp8_experts_quant", mxfp8_experts_quant);
}
-36
View File
@@ -40,12 +40,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::Tensor const& q_in, torch::Tensor const& kv, torch::Tensor& k_cache,
torch::Tensor const& slot_mapping, torch::Tensor const& position_ids,
torch::Tensor const& cos_sin_cache, int64_t q_head_padded, double eps,
int64_t cache_block_size);
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
@@ -107,24 +101,6 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
int64_t activation_kind);
using fptr_t = int64_t;
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank,
bool fully_connected);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t size);
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
#ifdef USE_ROCM
fptr_t init_custom_qr(int64_t rank, int64_t world_size,
std::optional<int64_t> qr_max_size = std::nullopt);
@@ -135,15 +111,3 @@ void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size();
#endif
#ifndef USE_ROCM
torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
torch::Tensor const& norm_weight,
torch::Tensor workspace, int64_t const rank,
int64_t const nranks, double const eps);
std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
torch::Tensor qkv, torch::Tensor const& norm_weight_q,
torch::Tensor const& norm_weight_k, torch::Tensor workspace,
int64_t const q_size, int64_t const kv_size, int64_t const rank,
int64_t const nranks, double const eps);
#endif
+20 -78
View File
@@ -55,14 +55,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
// kernel launch.
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
"Tensor q_in, Tensor kv, Tensor! k_cache, "
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
"int q_head_padded, float eps, int cache_block_size) -> Tensor");
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);
// kernel launch. Registered in _C_stable_libtorch.
// Quantization ops
#ifndef USE_ROCM
@@ -163,34 +156,27 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// conditionally compiled so impl registration is in source file
#endif
#ifndef USE_ROCM
ops.def(
"minimax_allreduce_rms("
"Tensor input,"
"Tensor norm_weight,"
"Tensor workspace,"
"int rank,"
"int nranks,"
"float eps) -> Tensor");
ops.impl("minimax_allreduce_rms", torch::kCUDA, &minimax_allreduce_rms);
ops.def(
"minimax_allreduce_rms_qk("
"Tensor qkv,"
"Tensor norm_weight_q,"
"Tensor norm_weight_k,"
"Tensor workspace,"
"int q_size,"
"int kv_size,"
"int rank,"
"int nranks,"
"float eps) -> (Tensor, Tensor)");
ops.impl("minimax_allreduce_rms_qk", torch::kCUDA, &minimax_allreduce_rms_qk);
// conditionally compiled so impl in source file
#endif
}
#ifdef USE_ROCM
TORCH_LIBRARY_FRAGMENT(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Quick Reduce all-reduce kernels (ROCm-only; stays on legacy _C).
custom_ar.def(
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
"cast_bf2half) -> ()");
custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
custom_ar.def("init_custom_qr", &init_custom_qr);
custom_ar.def("qr_destroy", &qr_destroy);
custom_ar.def("qr_get_handle", &qr_get_handle);
custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
custom_ar.def("qr_max_size", &qr_max_size);
}
#endif
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
// Cuda utils
@@ -205,48 +191,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
&get_max_shared_memory_per_block_device_attribute);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Custom all-reduce kernels
custom_ar.def(
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
"int rank, bool fully_connected) -> int");
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
custom_ar.def(
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);
custom_ar.def("dispose", &dispose);
custom_ar.def("meta_size", &meta_size);
custom_ar.def("register_buffer", &register_buffer);
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
custom_ar.def("register_graph_buffers", &register_graph_buffers);
custom_ar.def("allocate_shared_buffer_and_handle",
&allocate_shared_buffer_and_handle);
custom_ar.def("open_mem_handle(Tensor mem_handle) -> int", &open_mem_handle);
custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle);
custom_ar.def("free_shared_buffer", &free_shared_buffer);
#ifdef USE_ROCM
// Quick Reduce all-reduce kernels
custom_ar.def(
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
"cast_bf2half) -> ()");
custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
custom_ar.def("init_custom_qr", &init_custom_qr);
custom_ar.def("qr_destroy", &qr_destroy);
custom_ar.def("qr_get_handle", &qr_get_handle);
custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
// Max input size in bytes
custom_ar.def("qr_max_size", &qr_max_size);
#endif
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
+2 -2
View File
@@ -50,7 +50,7 @@ struct _typeConvert<float> {
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
struct _typeConvert<torch::headeronly::Half> {
static constexpr bool exists = true;
using hip_type = __half;
using packed_hip_type = __half2;
@@ -73,7 +73,7 @@ struct _typeConvert<c10::Half> {
// CUDA_ARCH < 800 does not have BF16 support
// ROCm 7.0+ supports bfloat16
template <>
struct _typeConvert<c10::BFloat16> {
struct _typeConvert<torch::headeronly::BFloat16> {
static constexpr bool exists = true;
using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162;