diff --git a/CMakeLists.txt b/CMakeLists.txt index 06f267ee53a..0652a5f066e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -311,14 +311,9 @@ set(VLLM_EXT_SRC "csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu" "csrc/quantization/activation_kernels.cu" "csrc/cuda_utils_kernels.cu" - "csrc/custom_all_reduce.cu" - "csrc/torch_bindings.cpp" - "csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu") + "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") - list(APPEND VLLM_EXT_SRC - "csrc/minimax_reduce_rms_kernel.cu") - SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. @@ -505,12 +500,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND ES_MXFP8_GROUPED_MM_ARCHS) set(SRCS - "csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu" - "csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu") + "csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu" + "csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${ES_MXFP8_GROUPED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") list(APPEND VLLM_GPU_FLAGS "-DENABLE_ES_MXFP8_GROUPED_MM_SM100=1") message(STATUS "Building ES MXFP8 grouped kernels for archs: ${ES_MXFP8_GROUPED_MM_ARCHS}") else() @@ -600,7 +595,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() if (VLLM_GPU_LANG STREQUAL "HIP") - # Add QuickReduce kernels + # Add QuickReduce kernels (ROCm-only; not part of stable ABI migration). 
list(APPEND VLLM_EXT_SRC "csrc/custom_quickreduce.cu" ) @@ -651,7 +646,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") "csrc/libtorch_stable/attention/paged_attention_v1.cu" "csrc/libtorch_stable/attention/paged_attention_v2.cu" "csrc/libtorch_stable/cache_kernels.cu" - "csrc/libtorch_stable/cache_kernels_fused.cu") + "csrc/libtorch_stable/cache_kernels.cu" + "csrc/libtorch_stable/cache_kernels_fused.cu" + "csrc/libtorch_stable/custom_all_reduce.cu" + "csrc/libtorch_stable/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_STABLE_EXT_SRC @@ -661,7 +659,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") "csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu" "csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/libtorch_stable/permute_cols.cu" - "csrc/libtorch_stable/quantization/awq/gemm_kernels.cu") + "csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu" + "csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu" + "csrc/libtorch_stable/quantization/awq/gemm_kernels.cu" + "csrc/libtorch_stable/minimax_reduce_rms_kernel.cu") set_gencode_flags_for_srcs( SRCS "${VLLM_STABLE_EXT_SRC}" diff --git a/csrc/custom_all_reduce.cu b/csrc/libtorch_stable/custom_all_reduce.cu similarity index 58% rename from csrc/custom_all_reduce.cu rename to csrc/libtorch_stable/custom_all_reduce.cu index a38d6fa24a2..0f7f759949a 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/libtorch_stable/custom_all_reduce.cu @@ -1,7 +1,11 @@ -#include -#include -#include -#include +#include "torch_utils.h" + +#include +#include +#include +#include +#include +#include #include "custom_all_reduce.cuh" @@ -11,7 +15,7 @@ using fptr_t = int64_t; static_assert(sizeof(void*) == sizeof(fptr_t)); fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, - torch::Tensor& rank_data, int64_t rank, + torch::stable::Tensor& rank_data, int64_t rank, bool fully_connected) { 
int world_size = fake_ipc_ptrs.size(); if (world_size > 8) @@ -25,9 +29,9 @@ fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, for (int i = 0; i < world_size; i++) { ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); } - return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(), - rank_data.numel(), rank, world_size, - fully_connected); + return (fptr_t) new vllm::CustomAllreduce( + ipc_ptrs, rank_data.mutable_data_ptr(), rank_data.numel(), rank, + world_size, fully_connected); } /** @@ -46,10 +50,14 @@ fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, * 5. A[None].expand(2, -1, -1, -1): Not OK * 6. A[:, 1:, 1:]: Not OK */ -bool _is_weak_contiguous(torch::Tensor& t) { - return t.is_contiguous() || - (t.storage().nbytes() - t.storage_offset() * t.element_size() == - t.numel() * t.element_size()); +bool _is_weak_contiguous(torch::stable::Tensor& t) { + if (t.is_contiguous()) { + return true; + } + int64_t storage_nbytes = 0; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_storage_size(t.get(), &storage_nbytes)); + return storage_nbytes - t.storage_offset() * t.element_size() == + static_cast(t.numel() * t.element_size()); } /** @@ -59,42 +67,45 @@ bool _is_weak_contiguous(torch::Tensor& t) { * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first * copied into _reg_buffer. 
*/ -void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, - fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) { +void all_reduce(fptr_t _fa, torch::stable::Tensor& inp, + torch::stable::Tensor& out, fptr_t _reg_buffer, + int64_t reg_buffer_sz_bytes) { auto fa = reinterpret_cast(_fa); - const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); - auto stream = c10::cuda::getCurrentCUDAStream().stream(); + const torch::stable::accelerator::DeviceGuard device_guard( + inp.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(inp.get_device_index()); - TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); - TORCH_CHECK_EQ(inp.numel(), out.numel()); - TORCH_CHECK(_is_weak_contiguous(out)); - TORCH_CHECK(_is_weak_contiguous(inp)); + STD_TORCH_CHECK((inp.scalar_type()) == (out.scalar_type())); + STD_TORCH_CHECK((inp.numel()) == (out.numel())); + STD_TORCH_CHECK(_is_weak_contiguous(out)); + STD_TORCH_CHECK(_is_weak_contiguous(inp)); auto input_size = inp.numel() * inp.element_size(); auto reg_buffer = reinterpret_cast(_reg_buffer); if (reg_buffer) { - TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes); - AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size, - cudaMemcpyDeviceToDevice, stream)); + STD_TORCH_CHECK((input_size) <= (reg_buffer_sz_bytes)); + STD_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.const_data_ptr(), input_size, + cudaMemcpyDeviceToDevice, stream)); } else { - reg_buffer = inp.data_ptr(); + reg_buffer = inp.mutable_data_ptr(); } switch (out.scalar_type()) { - case at::ScalarType::Float: { + case torch::headeronly::ScalarType::Float: { fa->allreduce(stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out.data_ptr()), + reinterpret_cast(out.mutable_data_ptr()), out.numel()); break; } - case at::ScalarType::Half: { + case torch::headeronly::ScalarType::Half: { fa->allreduce(stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out.data_ptr()), out.numel()); + reinterpret_cast(out.mutable_data_ptr()), 
+ out.numel()); break; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) - case at::ScalarType::BFloat16: { + case torch::headeronly::ScalarType::BFloat16: { fa->allreduce( stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out.data_ptr()), out.numel()); + reinterpret_cast(out.mutable_data_ptr()), out.numel()); break; } #endif @@ -112,7 +123,7 @@ int64_t meta_size() { return sizeof(vllm::Signal); } void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs) { auto fa = reinterpret_cast(_fa); - TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_); + STD_TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_); void* ipc_ptrs[8]; for (int i = 0; i < fake_ipc_ptrs.size(); i++) { ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); @@ -143,47 +154,49 @@ void register_graph_buffers(fptr_t _fa, fa->register_graph_buffers(bytes, offsets); } -std::tuple allocate_shared_buffer_and_handle( +std::tuple allocate_shared_buffer_and_handle( int64_t size) { - auto device_index = c10::cuda::current_device(); - at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index)); + int device_index; + STD_CUDA_CHECK(cudaGetDevice(&device_index)); + const torch::stable::accelerator::DeviceGuard device_guard(device_index); void* buffer; cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; - auto stream = c10::cuda::getCurrentCUDAStream().stream(); - AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); + const cudaStream_t stream = get_current_cuda_stream(device_index); + STD_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); // Allocate buffer #if defined(USE_ROCM) // data buffers need to be "uncached" for signal on MI200 - AT_CUDA_CHECK( + STD_CUDA_CHECK( hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached)); #else - AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size)); + STD_CUDA_CHECK(cudaMalloc((void**)&buffer, size)); #endif - AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream)); - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); 
- AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); + STD_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream)); + STD_CUDA_CHECK(cudaStreamSynchronize(stream)); + STD_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); // Create IPC memhandle for the allocated buffer. // Will use it in open_mem_handle. - auto options = - torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); - auto handle = - torch::empty({static_cast(sizeof(cudaIpcMemHandle_t))}, options); - AT_CUDA_CHECK( - cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer)); + auto handle = torch::stable::empty( + {static_cast(sizeof(cudaIpcMemHandle_t))}, + torch::headeronly::ScalarType::Byte, std::nullopt, + torch::stable::Device(torch::stable::DeviceType::CPU)); + STD_CUDA_CHECK(cudaIpcGetMemHandle( + (cudaIpcMemHandle_t*)handle.mutable_data_ptr(), buffer)); return std::make_tuple(reinterpret_cast(buffer), handle); } -fptr_t open_mem_handle(torch::Tensor& mem_handle) { +fptr_t open_mem_handle(torch::stable::Tensor& mem_handle) { void* ipc_ptr; - AT_CUDA_CHECK(cudaIpcOpenMemHandle( - (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()), + STD_CUDA_CHECK(cudaIpcOpenMemHandle( + (void**)&ipc_ptr, + *((const cudaIpcMemHandle_t*)mem_handle.const_data_ptr()), cudaIpcMemLazyEnablePeerAccess)); return reinterpret_cast(ipc_ptr); } void free_shared_buffer(fptr_t buffer) { - AT_CUDA_CHECK(cudaFree(reinterpret_cast(buffer))); + STD_CUDA_CHECK(cudaFree(reinterpret_cast(buffer))); } diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/libtorch_stable/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu similarity index 89% rename from csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu rename to csrc/libtorch_stable/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index e4d432cac97..a5f3f03de00 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/libtorch_stable/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -28,7 
+28,20 @@ * [bs*576, bs*576 + bs*8): UE8M0 scales, 7 real + 1 pad per token */ +#include "torch_utils.h" + +#include +#include +#include +#include +#include +#include + #include +#include "cuda_compat.h" +#include "dispatch_utils.h" +#include "type_convert.cuh" + #ifndef USE_ROCM #include #else @@ -37,14 +50,6 @@ #include #include -#include -#include -#include - -#include "cuda_compat.h" -#include "dispatch_utils.h" -#include "type_convert.cuh" - #ifndef FINAL_MASK #ifdef USE_ROCM #define FINAL_MASK 0xffffffffffffffffULL @@ -70,7 +75,7 @@ namespace deepseek_v4_fused_ops { namespace { inline int getSMVersion() { - auto* props = at::cuda::getCurrentDeviceProperties(); + auto* props = get_device_prop(); return props->major * 10 + props->minor; } } // namespace @@ -564,7 +569,7 @@ static void launchFusedDeepseekV4Templated( // bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert is // unavailable there. Refuse the launch loudly instead of silently // skipping the work. - TORCH_CHECK( + STD_TORCH_CHECK( sm_version >= 80, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert requires sm_80+ " "(Ampere or newer); got sm_", @@ -635,7 +640,7 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( DISPATCH(64) DISPATCH(128) default: - TORCH_CHECK(false, + STD_TORCH_CHECK(false, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert: " "unsupported num_heads_q_padded=", num_heads_q_padded, @@ -650,71 +655,80 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( // ──────────────────────────────────────────────────────────────────────────── // Torch op wrapper // ──────────────────────────────────────────────────────────────────────────── -torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( - torch::Tensor const& q_in, // [N, num_heads_q, 512] bf16 - torch::Tensor const& kv, // [N, 512] bf16 (read-only) - torch::Tensor& k_cache, // [num_blocks, block_bytes] uint8 - torch::Tensor const& slot_mapping, // [N] int64 - torch::Tensor const& position_ids, // [N] int64 - 
torch::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16 - int64_t q_head_padded, // padded Q head count for output +torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( + torch::stable::Tensor const& q_in, // [N, num_heads_q, 512] bf16 + torch::stable::Tensor const& kv, // [N, 512] bf16 (read-only) + torch::stable::Tensor& k_cache, // [num_blocks, block_bytes] uint8 + torch::stable::Tensor const& slot_mapping, // [N] int64 + torch::stable::Tensor const& position_ids, // [N] int64 + torch::stable::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16 + int64_t q_head_padded, // padded Q head count for output double eps, int64_t cache_block_size) { - TORCH_CHECK(q_in.is_cuda() && q_in.is_contiguous(), - "q_in must be contiguous CUDA"); - TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA"); - TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA"); - TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64, - "slot_mapping must be int64 CUDA"); - TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64, - "position_ids must be int64 CUDA"); - TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA"); - TORCH_CHECK(q_in.dim() == 3 && q_in.size(2) == 512, - "q_in shape [N, num_heads_q, 512]"); - TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]"); - TORCH_CHECK(q_in.dtype() == kv.dtype(), "q_in and kv dtype must match"); - TORCH_CHECK(q_head_padded >= q_in.size(1), - "q_head_padded must be >= q_in.size(1) (num_heads_q)"); - TORCH_CHECK(k_cache.dtype() == torch::kUInt8, "k_cache must be uint8"); - TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64, - "cos_sin_cache shape [max_pos, 64]"); - TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32, - "cos_sin_cache must be float32"); + STD_TORCH_CHECK(q_in.device().is_cuda() && q_in.is_contiguous(), + "q_in must be contiguous CUDA"); + STD_TORCH_CHECK(kv.device().is_cuda() && 
kv.is_contiguous(), + "kv must be contiguous CUDA"); + STD_TORCH_CHECK(k_cache.device().is_cuda(), "k_cache must be CUDA"); + STD_TORCH_CHECK(slot_mapping.device().is_cuda() && + slot_mapping.scalar_type() == + torch::headeronly::ScalarType::Long, + "slot_mapping must be int64 CUDA"); + STD_TORCH_CHECK(position_ids.device().is_cuda() && + position_ids.scalar_type() == + torch::headeronly::ScalarType::Long, + "position_ids must be int64 CUDA"); + STD_TORCH_CHECK(cos_sin_cache.device().is_cuda(), "cos_sin_cache must be CUDA"); + STD_TORCH_CHECK(q_in.dim() == 3 && q_in.size(2) == 512, + "q_in shape [N, num_heads_q, 512]"); + STD_TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]"); + STD_TORCH_CHECK(q_in.scalar_type() == kv.scalar_type(), + "q_in and kv dtype must match"); + STD_TORCH_CHECK(q_head_padded >= q_in.size(1), + "q_head_padded must be >= q_in.size(1) (num_heads_q)"); + STD_TORCH_CHECK(k_cache.scalar_type() == torch::headeronly::ScalarType::Byte, + "k_cache must be uint8"); + STD_TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64, + "cos_sin_cache shape [max_pos, 64]"); + STD_TORCH_CHECK(cos_sin_cache.scalar_type() == + torch::headeronly::ScalarType::Float, + "cos_sin_cache must be float32"); // With DP padding, slot_mapping can be shorter than q/kv/positions. // Q-norm+RoPE runs on all q.size(0) rows (downstream attention uses them); // KV quant+insert runs only on the first slot_mapping.size(0) rows. 
int const num_tokens_full = static_cast(q_in.size(0)); int const num_tokens_insert = static_cast(slot_mapping.size(0)); - TORCH_CHECK(static_cast(kv.size(0)) == num_tokens_full && - static_cast(position_ids.size(0)) == num_tokens_full, - "q/kv/position_ids row counts must match"); - TORCH_CHECK(num_tokens_insert <= num_tokens_full, - "slot_mapping must not exceed q row count"); + STD_TORCH_CHECK(static_cast(kv.size(0)) == num_tokens_full && + static_cast(position_ids.size(0)) == num_tokens_full, + "q/kv/position_ids row counts must match"); + STD_TORCH_CHECK(num_tokens_insert <= num_tokens_full, + "slot_mapping must not exceed q row count"); int const num_heads_q = static_cast(q_in.size(1)); int const num_heads_q_padded = static_cast(q_head_padded); int const cache_block_size_i = static_cast(cache_block_size); int const kv_block_stride = static_cast(k_cache.stride(0)); - at::cuda::OptionalCUDAGuard device_guard(device_of(q_in)); - auto stream = at::cuda::getCurrentCUDAStream(); + const torch::stable::accelerator::DeviceGuard device_guard( + q_in.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(q_in.get_device_index()); // Allocate the padded q output. The kernel writes every element (live // region gets RMSNorm+RoPE; pad region gets zeros), so `empty` is safe. 
- torch::Tensor q_out = torch::empty( - {q_in.size(0), q_head_padded, q_in.size(2)}, q_in.options()); + auto q_out = torch::stable::new_empty( + q_in, {q_in.size(0), q_head_padded, q_in.size(2)}, q_in.scalar_type()); - VLLM_DISPATCH_HALF_TYPES( + VLLM_STABLE_DISPATCH_HALF_TYPES( q_in.scalar_type(), "fused_deepseek_v4_qnorm_rope_kv_insert", [&] { using qkv_scalar_t = scalar_t; vllm::deepseek_v4_fused_ops:: launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( - reinterpret_cast(q_in.data_ptr()), - reinterpret_cast(q_out.data_ptr()), - reinterpret_cast(kv.data_ptr()), - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(slot_mapping.data_ptr()), - reinterpret_cast(position_ids.data_ptr()), - cos_sin_cache.data_ptr(), static_cast(eps), + reinterpret_cast(q_in.const_data_ptr()), + reinterpret_cast(q_out.mutable_data_ptr()), + reinterpret_cast(kv.const_data_ptr()), + reinterpret_cast(k_cache.mutable_data_ptr()), + slot_mapping.const_data_ptr(), + position_ids.const_data_ptr(), + cos_sin_cache.const_data_ptr(), static_cast(eps), num_tokens_full, num_tokens_insert, num_heads_q, num_heads_q_padded, cache_block_size_i, kv_block_stride, stream); diff --git a/csrc/minimax_reduce_rms_kernel.cu b/csrc/libtorch_stable/minimax_reduce_rms_kernel.cu similarity index 87% rename from csrc/minimax_reduce_rms_kernel.cu rename to csrc/libtorch_stable/minimax_reduce_rms_kernel.cu index 6245b02d6e9..d9af0f5efe0 100644 --- a/csrc/minimax_reduce_rms_kernel.cu +++ b/csrc/libtorch_stable/minimax_reduce_rms_kernel.cu @@ -15,16 +15,19 @@ * limitations under the License. 
*/ +#include "torch_utils.h" + +#include +#include +#include +#include +#include +#include + #include #include -#include -#include -#include - #include "cuda_compat.h" -#include "cuda_utils.h" -#include "core/registration.h" #include "minimax_reduce_rms_kernel.h" #include @@ -611,7 +614,7 @@ int get_sm_count() { static int sm_count = 0; if (sm_count == 0) { int device_id; - CUDA_CHECK(cudaGetDevice(&device_id)); + STD_CUDA_CHECK(cudaGetDevice(&device_id)); cudaDeviceProp device_prop; cudaGetDeviceProperties(&device_prop, device_id); sm_count = device_prop.multiProcessorCount; @@ -621,13 +624,13 @@ int get_sm_count() { inline int getSMVersion(bool queryRealSmArch = false) { int device{-1}; - CUDA_CHECK(cudaGetDevice(&device)); + STD_CUDA_CHECK(cudaGetDevice(&device)); int sm_major = 0; int sm_minor = 0; - CUDA_CHECK(cudaDeviceGetAttribute(&sm_major, - cudaDevAttrComputeCapabilityMajor, device)); - CUDA_CHECK(cudaDeviceGetAttribute(&sm_minor, - cudaDevAttrComputeCapabilityMinor, device)); + STD_CUDA_CHECK(cudaDeviceGetAttribute( + &sm_major, cudaDevAttrComputeCapabilityMajor, device)); + STD_CUDA_CHECK(cudaDeviceGetAttribute( + &sm_minor, cudaDevAttrComputeCapabilityMinor, device)); int sm = sm_major * 10 + sm_minor; if (sm == 121 && !queryRealSmArch) { return 120; @@ -639,7 +642,7 @@ template int get_max_active_blocks(KernelFunc kernel, int block_size, int dynamic_smem = 0) { int max_active = 0; - CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + STD_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active, kernel, block_size, dynamic_smem)); return std::max(max_active, 1); } @@ -678,27 +681,27 @@ void minimax_reduce_rms_kernel_launcher(MiniMaxReduceRMSParams const& params) { cfg.attrs = attribute; cfg.numAttrs = SM >= 90 ? 
2 : 0; - CUDA_CHECK(cudaLaunchKernelEx( + STD_CUDA_CHECK(cudaLaunchKernelEx( &cfg, minimax_reduce_rms_kernel_lamport, params)); } template void minimax_reduce_rms_kernel_launcher_float4( MiniMaxReduceRMSParams const& params) { - TORCH_CHECK(params.size_q % params.hidden_dim == 0); - TORCH_CHECK(params.hidden_dim % kElemsPerAccess == 0); + STD_TORCH_CHECK(params.size_q % params.hidden_dim == 0); + STD_TORCH_CHECK(params.hidden_dim % kElemsPerAccess == 0); if (params.stride_q > 0) { - TORCH_CHECK(params.stride_q % kElemsPerAccess == 0); + STD_TORCH_CHECK(params.stride_q % kElemsPerAccess == 0); } - TORCH_CHECK(params.allreduce_in_k != nullptr, - "float4 QK kernel requires K input"); - TORCH_CHECK(params.hidden_dim >= params.hidden_dim_k); - TORCH_CHECK(params.size_k % params.hidden_dim_k == 0); - TORCH_CHECK(params.hidden_dim_k % kElemsPerAccess == 0); - TORCH_CHECK(params.size_q / params.hidden_dim == - params.size_k / params.hidden_dim_k); + STD_TORCH_CHECK(params.allreduce_in_k != nullptr, + "float4 QK kernel requires K input"); + STD_TORCH_CHECK(params.hidden_dim >= params.hidden_dim_k); + STD_TORCH_CHECK(params.size_k % params.hidden_dim_k == 0); + STD_TORCH_CHECK(params.hidden_dim_k % kElemsPerAccess == 0); + STD_TORCH_CHECK(params.size_q / params.hidden_dim == + params.size_k / params.hidden_dim_k); if (params.stride_k > 0) { - TORCH_CHECK(params.stride_k % kElemsPerAccess == 0); + STD_TORCH_CHECK(params.stride_k % kElemsPerAccess == 0); } int token_num = params.size_q / params.hidden_dim; @@ -746,7 +749,7 @@ void minimax_reduce_rms_kernel_launcher_float4( cfg.attrs = attribute; cfg.numAttrs = SM >= 90 ? 
2 : 0; - CUDA_CHECK(cudaLaunchKernelEx(&cfg, kfn, params)); + STD_CUDA_CHECK(cudaLaunchKernelEx(&cfg, kfn, params)); } template @@ -759,21 +762,21 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) { (params.hidden_dim * params.nranks == 6144) && (params.hidden_dim_k * params.nranks == 1024); - if (params.dtype == at::ScalarType::Half) { + if (params.dtype == torch::headeronly::ScalarType::Half) { if (use_float4) { minimax_reduce_rms_kernel_launcher_float4( params); } else { minimax_reduce_rms_kernel_launcher(params); } - } else if (params.dtype == at::ScalarType::BFloat16) { + } else if (params.dtype == torch::headeronly::ScalarType::BFloat16) { if (use_float4) { minimax_reduce_rms_kernel_launcher_float4<__nv_bfloat16, NRanks, 6144, 1024>(params); } else { minimax_reduce_rms_kernel_launcher<__nv_bfloat16, NRanks>(params); } - } else if (params.dtype == at::ScalarType::Float) { + } else if (params.dtype == torch::headeronly::ScalarType::Float) { if (use_float4) { minimax_reduce_rms_kernel_launcher_float4( params); @@ -781,7 +784,7 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) { minimax_reduce_rms_kernel_launcher(params); } } else { - TORCH_CHECK(false, "Unsupported data type for minimax_reduce_rms_op"); + STD_TORCH_CHECK(false, "Unsupported data type for minimax_reduce_rms_op"); } } @@ -795,16 +798,18 @@ void minimax_reduce_rms_op(MiniMaxReduceRMSParams const& params) { } else if (params.nranks == 16) { dispatch_dtype<16>(params); } else { - TORCH_CHECK(false, "minimax_reduce_rms_op: unsupported ranks number!"); + STD_TORCH_CHECK(false, "minimax_reduce_rms_op: unsupported ranks number!"); } } } // namespace tensorrt_llm } // namespace vllm -torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, - torch::Tensor const& norm_weight, - torch::Tensor workspace, int64_t const rank, - int64_t const nranks, double const eps) { +torch::stable::Tensor minimax_allreduce_rms( + torch::stable::Tensor const& input, + torch::stable::Tensor const& 
norm_weight, torch::stable::Tensor workspace, + int64_t const rank, int64_t const nranks, double const eps) { + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); auto allreduce_params = vllm::tensorrt_llm::MiniMaxReduceRMSParams(); allreduce_params.nranks = static_cast(nranks); @@ -815,12 +820,12 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, allreduce_params.stride_q = allreduce_params.hidden_dim; allreduce_params.workspace = reinterpret_cast(workspace.mutable_data_ptr()); - allreduce_params.allreduce_in = input.data_ptr(); - allreduce_params.rms_gamma = norm_weight.data_ptr(); + allreduce_params.allreduce_in = const_cast(input.const_data_ptr()); + allreduce_params.rms_gamma = const_cast(norm_weight.const_data_ptr()); allreduce_params.rms_eps = static_cast(eps); - allreduce_params.stream = at::cuda::getCurrentCUDAStream(input.get_device()); + allreduce_params.stream = get_current_cuda_stream(input.get_device_index()); - torch::Tensor rms_norm_out = torch::empty_like(input); + torch::stable::Tensor rms_norm_out = torch::stable::empty_like(input); allreduce_params.rms_norm_out = rms_norm_out.mutable_data_ptr(); vllm::tensorrt_llm::minimax_reduce_rms_op(allreduce_params); @@ -828,26 +833,33 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, return rms_norm_out; } -std::tuple minimax_allreduce_rms_qk( - torch::Tensor qkv, torch::Tensor const& norm_weight_q, - torch::Tensor const& norm_weight_k, torch::Tensor workspace, - int64_t const q_size, int64_t const kv_size, int64_t const rank, - int64_t const nranks, double const eps) { - TORCH_CHECK(qkv.dim() == 2, "minimax_allreduce_rms_qk: qkv must be 2D"); - TORCH_CHECK(qkv.is_contiguous(), - "minimax_allreduce_rms_qk: qkv must be contiguous"); +std::tuple +minimax_allreduce_rms_qk(torch::stable::Tensor qkv, + torch::stable::Tensor const& norm_weight_q, + torch::stable::Tensor const& norm_weight_k, + torch::stable::Tensor workspace, int64_t const 
q_size, + int64_t const kv_size, int64_t const rank, + int64_t const nranks, double const eps) { + STD_TORCH_CHECK(qkv.dim() == 2, "minimax_allreduce_rms_qk: qkv must be 2D"); + STD_TORCH_CHECK(qkv.is_contiguous(), + "minimax_allreduce_rms_qk: qkv must be contiguous"); int64_t qkv_dim = qkv.size(-1); - TORCH_CHECK(qkv_dim == q_size + 2 * kv_size, - "minimax_allreduce_rms_qk: qkv last dim must equal " - "q_size + 2 * kv_size"); - TORCH_CHECK(rank < nranks, - "minimax_allreduce_rms_qk: rank must be less than nranks"); + STD_TORCH_CHECK(qkv_dim == q_size + 2 * kv_size, + "minimax_allreduce_rms_qk: qkv last dim must equal " + "q_size + 2 * kv_size"); + STD_TORCH_CHECK(rank < nranks, + "minimax_allreduce_rms_qk: rank must be less than nranks"); + + const torch::stable::accelerator::DeviceGuard device_guard( + qkv.get_device_index()); int64_t num_tokens = qkv.size(0); int elem_bytes = qkv.element_size(); - torch::Tensor q_out = torch::empty({num_tokens, q_size}, qkv.options()); - torch::Tensor k_out = torch::empty({num_tokens, kv_size}, qkv.options()); + torch::stable::Tensor q_out = + torch::stable::new_empty(qkv, {num_tokens, q_size}, qkv.scalar_type()); + torch::stable::Tensor k_out = + torch::stable::new_empty(qkv, {num_tokens, kv_size}, qkv.scalar_type()); auto params = vllm::tensorrt_llm::MiniMaxReduceRMSParams(); params.nranks = static_cast(nranks); @@ -863,13 +875,14 @@ std::tuple minimax_allreduce_rms_qk( params.stride_k_out = 0; // k_out is contiguous; kernel uses hidden_dim_k params.workspace = reinterpret_cast(workspace.mutable_data_ptr()); - uint8_t* base = static_cast(qkv.data_ptr()); + uint8_t* base = + const_cast(static_cast(qkv.const_data_ptr())); params.allreduce_in = base; params.allreduce_in_k = base + q_size * elem_bytes; - params.rms_gamma = norm_weight_q.data_ptr(); - params.rms_gamma_k = norm_weight_k.data_ptr(); + params.rms_gamma = const_cast(norm_weight_q.const_data_ptr()); + params.rms_gamma_k = const_cast(norm_weight_k.const_data_ptr()); 
params.rms_eps = static_cast(eps); - params.stream = at::cuda::getCurrentCUDAStream(qkv.get_device()); + params.stream = get_current_cuda_stream(qkv.get_device_index()); params.rms_norm_out = q_out.mutable_data_ptr(); params.rms_norm_out_k = k_out.mutable_data_ptr(); diff --git a/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu new file mode 100644 index 00000000000..fda9bc020da --- /dev/null +++ b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project +// Adapted from SGLang: +// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu + +#include +#include +#include "libtorch_stable/torch_utils.h" + +#include "cutlass_mxfp8_grouped_mm_launcher.cuh" + +void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a, + const torch::stable::Tensor& b, + const torch::stable::Tensor& sfa, + const torch::stable::Tensor& sfb, + torch::stable::Tensor& d, + const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& expert_offsets, + const torch::stable::Tensor& blockscale_offsets) { +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); + STD_TORCH_CHECK(problem_sizes.size(1) == 3, + "problem_sizes must have shape (num_experts, 3)"); + STD_TORCH_CHECK( + problem_sizes.size(0) == expert_offsets.size(0), + "Number of experts in problem_sizes must match expert_offsets"); + STD_TORCH_CHECK( + problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int, + "problem_sizes must be int32"); + STD_TORCH_CHECK( + expert_offsets.scalar_type() == torch::headeronly::ScalarType::Int, + "expert_offsets must be int32"); + STD_TORCH_CHECK( + blockscale_offsets.scalar_type() == 
torch::headeronly::ScalarType::Int, + "blockscale_offsets must be int32"); + STD_TORCH_CHECK(a.dim() == 2, + "a must be a 2D tensor of shape (num_tokens, k)"); + STD_TORCH_CHECK(b.dim() == 3, + "b must be a 3D tensor of shape (num_experts, k, n)"); + STD_TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0, + "k should align 128"); + STD_TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128"); + STD_TORCH_CHECK(a.stride(1) == 1, "a must be row major"); + STD_TORCH_CHECK(b.stride(1) == 1, "b must be column major"); + + const torch::stable::accelerator::DeviceGuard device_guard( + a.get_device_index()); + auto stream = get_current_cuda_stream(a.get_device_index()); + if (d.scalar_type() == torch::headeronly::ScalarType::BFloat16) { + expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype< + cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets, + blockscale_offsets, stream); + } else if (d.scalar_type() == torch::headeronly::ScalarType::Half) { + expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype< + cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets, + blockscale_offsets, stream); + } else { + STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16"); + } +#else + STD_TORCH_CHECK(false, + "No implemented cutlass_mxfp8_grouped_mm for " + "current device"); +#endif +} + +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + m.impl("cutlass_mxfp8_grouped_mm", TORCH_BOX(&cutlass_mxfp8_grouped_mm)); +} diff --git a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh similarity index 100% rename from csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh rename to csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh diff --git a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh similarity index 54% rename from 
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh rename to csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh index 2c46e1fa725..82d6543b288 100644 --- a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh +++ b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh @@ -4,9 +4,9 @@ // https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_launcher.cuh #pragma once -#include -#include -#include + +#include +#include #include #include @@ -15,18 +15,22 @@ #include "cute/tensor.hpp" #include "cutlass_mxfp8_grouped_mm_functor.cuh" #include "cutlass_mxfp8_grouped_mm_traits.cuh" +#include "libtorch_stable/torch_utils.h" namespace expert_specialization { template void cutlass_mxfp8_grouped_mm_pre_compute( - torch::Tensor& a_ptrs, torch::Tensor& b_ptrs, torch::Tensor& sfa_ptrs, - torch::Tensor& sfb_ptrs, torch::Tensor& d_ptrs, torch::Tensor& stride_a, - torch::Tensor& stride_b, torch::Tensor& stride_d, torch::Tensor& layout_sfa, - torch::Tensor& layout_sfb, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& sfa, const torch::Tensor& sfb, const torch::Tensor& d, - const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets, - const torch::Tensor& blockscale_offsets, cudaStream_t stream) { + torch::stable::Tensor& a_ptrs, torch::stable::Tensor& b_ptrs, + torch::stable::Tensor& sfa_ptrs, torch::stable::Tensor& sfb_ptrs, + torch::stable::Tensor& d_ptrs, torch::stable::Tensor& stride_a, + torch::stable::Tensor& stride_b, torch::stable::Tensor& stride_d, + torch::stable::Tensor& layout_sfa, torch::stable::Tensor& layout_sfb, + const torch::stable::Tensor& a, const torch::stable::Tensor& b, + const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb, + const torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& expert_offsets, + const 
torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) { using OffsetFunctor = CutlassMxfp8GroupedMmOffsetFunctor; using ElementA = typename OffsetFunctor::ElementA; using ElementB = typename OffsetFunctor::ElementB; @@ -42,10 +46,10 @@ void cutlass_mxfp8_grouped_mm_pre_compute( using StrideB = typename StrideFunctor::StrideB; using StrideD = typename StrideFunctor::StrideD; - int num_experts = (int)expert_offsets.size(0); - TORCH_CHECK(num_experts <= 1024, - "Number of experts cannot exceed 1024, the maximum number of " - "threads per block."); + int num_experts = static_cast(expert_offsets.size(0)); + STD_TORCH_CHECK(num_experts <= 1024, + "Number of experts cannot exceed 1024, the maximum number of " + "threads per block."); OffsetFunctor offset_functor( reinterpret_cast(expert_offsets.data_ptr()), @@ -72,13 +76,18 @@ void cutlass_mxfp8_grouped_mm_pre_compute( } template -void cutlass_mxfp8_grouped_mm( - const torch::Tensor& a_ptrs, const torch::Tensor& b_ptrs, - const torch::Tensor& sfa_ptrs, const torch::Tensor& sfb_ptrs, - const torch::Tensor& d_ptrs, const torch::Tensor& stride_a, - const torch::Tensor& stride_b, const torch::Tensor& stride_d, - const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, - const torch::Tensor& problem_sizes, cudaStream_t stream) { +void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a_ptrs, + const torch::stable::Tensor& b_ptrs, + const torch::stable::Tensor& sfa_ptrs, + const torch::stable::Tensor& sfb_ptrs, + const torch::stable::Tensor& d_ptrs, + const torch::stable::Tensor& stride_a, + const torch::stable::Tensor& stride_b, + const torch::stable::Tensor& stride_d, + const torch::stable::Tensor& layout_sfa, + const torch::stable::Tensor& layout_sfb, + const torch::stable::Tensor& problem_sizes, + cudaStream_t stream) { using Gemm = typename GemmTraits::Gemm; using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; @@ -93,13 +102,12 @@ void cutlass_mxfp8_grouped_mm( typename 
GemmTraits::ProblemShape::UnderlyingProblemShape; cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = c10::cuda::current_device(); - hw_info.sm_count = - at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + hw_info.device_id = d_ptrs.get_device_index(); + hw_info.sm_count = get_device_prop()->multiProcessorCount; hw_info.cluster_shape = GemmTraits::MMAConfig::preferred_cluster; hw_info.cluster_shape_fallback = GemmTraits::MMAConfig::fallback_cluster; - int num_experts = (int)problem_sizes.size(0); + int num_experts = static_cast(problem_sizes.size(0)); UnderlyingProblemShape* underlying_problem_shape = reinterpret_cast(problem_sizes.data_ptr()); @@ -127,44 +135,55 @@ void cutlass_mxfp8_grouped_mm( Gemm gemm; auto can_implement_status = gemm.can_implement(arguments); - TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, - "Failed to implement GEMM"); + STD_TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, + "Failed to implement GEMM"); - torch::TensorOptions options_uint8 = - torch::TensorOptions().dtype(torch::kUInt8).device(d_ptrs.device()); size_t workspace_size = gemm.get_workspace_size(arguments); - torch::Tensor workspace = torch::empty(workspace_size, options_uint8); + torch::stable::Tensor workspace = torch::stable::empty( + {static_cast(workspace_size)}, + torch::headeronly::ScalarType::Byte, std::nullopt, d_ptrs.device()); auto status = gemm.initialize(arguments, workspace.data_ptr(), stream); - TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + STD_TORCH_CHECK(status == cutlass::Status::kSuccess, + "Failed to initialize GEMM"); status = gemm.run(stream, nullptr, true); // Enable PDL - TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); + STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); } template void cutlass_mxfp8_grouped_mm_dispatch_out_dtype( - const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& sfa, - const 
torch::Tensor& sfb, torch::Tensor& d, - const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets, - const torch::Tensor& blockscale_offsets, cudaStream_t stream) { - int num_experts = (int)problem_sizes.size(0); - torch::TensorOptions options_int64 = - torch::TensorOptions().dtype(torch::kInt64).device(a.device()); - torch::TensorOptions options_int32 = - torch::TensorOptions().dtype(torch::kInt32).device(a.device()); + const torch::stable::Tensor& a, const torch::stable::Tensor& b, + const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb, + torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& expert_offsets, + const torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) { + int num_experts = static_cast(problem_sizes.size(0)); + auto device = a.device(); - torch::Tensor a_ptrs = torch::empty(num_experts, options_int64); - torch::Tensor b_ptrs = torch::empty(num_experts, options_int64); - torch::Tensor sfa_ptrs = torch::empty(num_experts, options_int64); - torch::Tensor sfb_ptrs = torch::empty(num_experts, options_int64); - torch::Tensor d_ptrs = torch::empty(num_experts, options_int64); + torch::stable::Tensor a_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor b_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor sfa_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor sfb_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor d_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); - torch::Tensor stride_a = torch::empty(num_experts, options_int64); - torch::Tensor stride_b = torch::empty(num_experts, options_int64); - torch::Tensor 
stride_d = torch::empty(num_experts, options_int64); - torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int32); - torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int32); + torch::stable::Tensor stride_a = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor stride_b = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor stride_d = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor layout_sfa = + torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int, + std::nullopt, device); + torch::stable::Tensor layout_sfb = + torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int, + std::nullopt, device); using GemmTraits = CutlassMxfp8GroupedMmGemmTraits; cutlass_mxfp8_grouped_mm_pre_compute( @@ -176,4 +195,4 @@ void cutlass_mxfp8_grouped_mm_dispatch_out_dtype( layout_sfa, layout_sfb, problem_sizes, stream); } -} // namespace expert_specialization \ No newline at end of file +} // namespace expert_specialization diff --git a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh similarity index 100% rename from csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh rename to csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh diff --git a/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu b/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu new file mode 100644 index 00000000000..e075721c2a3 --- /dev/null +++ b/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project +// Adapted from SGLang: +// 
https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu + +#include +#include +#include "libtorch_stable/torch_utils.h" + +#include "mxfp8_experts_quant.cuh" + +void mxfp8_experts_quant(const torch::stable::Tensor& input, + const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& expert_offsets, + const torch::stable::Tensor& blockscale_offsets, + torch::stable::Tensor& quant_output, + torch::stable::Tensor& scale_factor) { +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + STD_TORCH_CHECK(input.dim() == 2, "input must be 2D tensor"); + STD_TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128"); + STD_TORCH_CHECK(input.stride(1) == 1, "input must be row major"); + STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); + STD_TORCH_CHECK( + problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int, + "problem_sizes must be int32"); + STD_TORCH_CHECK( + expert_offsets.scalar_type() == torch::headeronly::ScalarType::Int, + "expert_offsets must be int32"); + STD_TORCH_CHECK( + blockscale_offsets.scalar_type() == torch::headeronly::ScalarType::Int, + "blockscale_offsets must be int32"); + + auto groups = problem_sizes.size(0); + STD_TORCH_CHECK( + expert_offsets.dim() == 1 && expert_offsets.size(0) == groups, + "expert_offsets must be 1D and have size equal to the number of groups"); + STD_TORCH_CHECK( + blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups, + "blockscale_offsets must be 1D and have size equal to the number of " + "groups"); + + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + if (input.scalar_type() == torch::headeronly::ScalarType::BFloat16) { + expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>( + input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, + scale_factor); + } else if (input.scalar_type() == 
torch::headeronly::ScalarType::Half) { + expert_specialization::launch_mxfp8_experts_quant<__half>( + input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, + scale_factor); + } else { + STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16"); + } +#else + STD_TORCH_CHECK(false, + "No implemented mxfp8_experts_quant for " + "current device"); +#endif +} + +// Registered here (not torch_bindings.cpp) because ENABLE_ES_MXFP8_GROUPED_MM +// is applied only under COMPILE_LANGUAGE:CUDA. +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + m.impl("mxfp8_experts_quant", TORCH_BOX(&mxfp8_experts_quant)); +} diff --git a/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh b/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cuh similarity index 95% rename from csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh rename to csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cuh index 9a85852080f..a57e00e76c3 100644 --- a/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh +++ b/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cuh @@ -4,16 +4,19 @@ // https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh #pragma once -#include -#include #include #include #include -#include + +#include +#include +#include +#include #include #include "cute/tensor.hpp" +#include "libtorch_stable/torch_utils.h" namespace expert_specialization { @@ -356,12 +359,12 @@ __global__ void mxfp8_experts_quant_kernel( } template -void launch_mxfp8_experts_quant(const torch::Tensor& input, - const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets, - const torch::Tensor& blockscale_offsets, - torch::Tensor& quant_output, - torch::Tensor& scale_factor) { +void launch_mxfp8_experts_quant(const torch::stable::Tensor& input, + const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& expert_offsets, + const torch::stable::Tensor& blockscale_offsets, + 
torch::stable::Tensor& quant_output, + torch::stable::Tensor& scale_factor) { ThrLayout thr_layout{}; ValLayout val_layout{}; SfR2SThrLayout r2s_thr_layout{}; @@ -386,19 +389,18 @@ void launch_mxfp8_experts_quant(const torch::Tensor& input, CopyAtomR2S{}, r2s_thr_layout, r2s_val_layout); // Tiler_MN: (16, 4) int max_active_blocks_per_sm = -1; - AT_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + STD_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_per_sm, mxfp8_experts_quant_kernel, THREAD_BLOCK_SIZE, 0)); - dim3 grid(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * - max_active_blocks_per_sm, + dim3 grid(get_device_prop()->multiProcessorCount * max_active_blocks_per_sm, 1, 1); dim3 block(THREAD_BLOCK_SIZE, 1, 1); - int num_experts = (int)problem_sizes.size(0); - auto stream = at::cuda::getCurrentCUDAStream(); + int num_experts = static_cast(problem_sizes.size(0)); + auto stream = get_current_cuda_stream(input.get_device_index()); mxfp8_experts_quant_kernel <<>>( diff --git a/csrc/libtorch_stable/ops.h b/csrc/libtorch_stable/ops.h index 0363ec7cdfc..dd27a6968d0 100644 --- a/csrc/libtorch_stable/ops.h +++ b/csrc/libtorch_stable/ops.h @@ -231,6 +231,27 @@ void fused_qk_norm_rope(torch::stable::Tensor& qkv, int64_t num_heads_q, torch::stable::Tensor& position_ids, int64_t forced_token_heads_per_warp); +torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( + torch::stable::Tensor const& q_in, torch::stable::Tensor const& kv, + torch::stable::Tensor& k_cache, torch::stable::Tensor const& slot_mapping, + torch::stable::Tensor const& position_ids, + torch::stable::Tensor const& cos_sin_cache, int64_t q_head_padded, + double eps, int64_t cache_block_size); + +#ifndef USE_ROCM +torch::stable::Tensor minimax_allreduce_rms( + torch::stable::Tensor const& input, + torch::stable::Tensor const& norm_weight, torch::stable::Tensor workspace, + int64_t const rank, int64_t const nranks, double const eps); 
+std::tuple +minimax_allreduce_rms_qk(torch::stable::Tensor qkv, + torch::stable::Tensor const& norm_weight_q, + torch::stable::Tensor const& norm_weight_k, + torch::stable::Tensor workspace, int64_t const q_size, + int64_t const kv_size, int64_t const rank, + int64_t const nranks, double const eps); +#endif + // Sampler kernels (shared CUDA/ROCm) void apply_repetition_penalties_( torch::stable::Tensor& logits, const torch::stable::Tensor& prompt_mask, @@ -273,6 +294,26 @@ void selective_scan_fwd( const std::optional& cu_chunk_seqlen, const std::optional& last_chunk_indices); +using fptr_t = int64_t; +fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, + torch::stable::Tensor& rank_data, int64_t rank, + bool fully_connected); +void all_reduce(fptr_t _fa, torch::stable::Tensor& inp, + torch::stable::Tensor& out, fptr_t reg_buffer, + int64_t reg_buffer_sz_bytes); +void dispose(fptr_t _fa); +int64_t meta_size(); +void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs); +std::tuple, std::vector> +get_graph_buffer_ipc_meta(fptr_t _fa); +void register_graph_buffers(fptr_t _fa, + const std::vector>& handles, + const std::vector>& offsets); +std::tuple allocate_shared_buffer_and_handle( + int64_t size); +int64_t open_mem_handle(torch::stable::Tensor& mem_handle); +void free_shared_buffer(int64_t buffer); + // Activation kernels (shared CUDA/ROCm) void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input); void silu_and_mul_clamp(torch::stable::Tensor& out, diff --git a/csrc/libtorch_stable/torch_bindings.cpp b/csrc/libtorch_stable/torch_bindings.cpp index 98cd31df13b..e9a62a8666c 100644 --- a/csrc/libtorch_stable/torch_bindings.cpp +++ b/csrc/libtorch_stable/torch_bindings.cpp @@ -337,6 +337,24 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) { "bool is_neox, Tensor position_ids, " "int forced_token_heads_per_warp=-1) -> ()"); + ops.def( + "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(" + "Tensor q_in, Tensor kv, Tensor! 
k_cache, " + "Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, " + "int q_head_padded, float eps, int cache_block_size) -> Tensor"); + +#ifndef USE_ROCM + ops.def( + "minimax_allreduce_rms(" + "Tensor input, Tensor norm_weight, Tensor workspace, " + "int rank, int nranks, float eps) -> Tensor"); + ops.def( + "minimax_allreduce_rms_qk(" + "Tensor qkv, Tensor norm_weight_q, Tensor norm_weight_k, " + "Tensor workspace, int q_size, int kv_size, int rank, int nranks, " + "float eps) -> (Tensor, Tensor)"); +#endif + // Apply repetition penalties to logits in-place. ops.def( "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " @@ -571,6 +589,12 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { // Positional encoding kernels (shared CUDA/ROCm) ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding)); ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope)); + ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", + TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert)); +#ifndef USE_ROCM + ops.impl("minimax_allreduce_rms", TORCH_BOX(&minimax_allreduce_rms)); + ops.impl("minimax_allreduce_rms_qk", TORCH_BOX(&minimax_allreduce_rms_qk)); +#endif // Sampler kernels (shared CUDA/ROCm) ops.impl("apply_repetition_penalties_", @@ -725,6 +749,45 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) { "dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()"); } +STABLE_TORCH_LIBRARY_FRAGMENT(_C_custom_ar, custom_ar) { + custom_ar.def( + "init_custom_ar(int[] ipc_tensors, Tensor rank_data, " + "int rank, bool fully_connected) -> int"); + custom_ar.def( + "all_reduce(int fa, Tensor inp, Tensor! 
out, int reg_buffer, " + "int reg_buffer_sz_bytes) -> ()"); + custom_ar.def("dispose(int fa) -> ()"); + custom_ar.def("meta_size() -> int"); + custom_ar.def("register_buffer(int fa, int[] ipc_tensors) -> ()"); + custom_ar.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])"); + custom_ar.def( + "register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()"); + custom_ar.def("allocate_shared_buffer_and_handle(int size) -> (int, Tensor)"); + custom_ar.def("open_mem_handle(Tensor mem_handle) -> int"); + custom_ar.def("free_shared_buffer(int ptr) -> ()"); +} + +STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CUDA, custom_ar) { + custom_ar.impl("init_custom_ar", TORCH_BOX(&init_custom_ar)); + custom_ar.impl("all_reduce", TORCH_BOX(&all_reduce)); +} + +STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CPU, custom_ar) { + custom_ar.impl("open_mem_handle", TORCH_BOX(&open_mem_handle)); +} + +STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CompositeExplicitAutograd, custom_ar) { + custom_ar.impl("dispose", TORCH_BOX(&dispose)); + custom_ar.impl("meta_size", TORCH_BOX(&meta_size)); + custom_ar.impl("register_buffer", TORCH_BOX(&register_buffer)); + custom_ar.impl("get_graph_buffer_ipc_meta", + TORCH_BOX(&get_graph_buffer_ipc_meta)); + custom_ar.impl("register_graph_buffers", TORCH_BOX(&register_graph_buffers)); + custom_ar.impl("allocate_shared_buffer_and_handle", + TORCH_BOX(&allocate_shared_buffer_and_handle)); + custom_ar.impl("free_shared_buffer", TORCH_BOX(&free_shared_buffer)); +} + STABLE_TORCH_LIBRARY_IMPL(_C_cache_ops, CPU, ops) { ops.impl("swap_blocks_batch", TORCH_BOX(&swap_blocks_batch)); } diff --git a/csrc/minimax_reduce_rms_kernel.h b/csrc/minimax_reduce_rms_kernel.h index e8c2d012247..c3d2dd5c599 100644 --- a/csrc/minimax_reduce_rms_kernel.h +++ b/csrc/minimax_reduce_rms_kernel.h @@ -19,7 +19,7 @@ #include #include -#include +#include namespace vllm { namespace tensorrt_llm { @@ -51,7 +51,7 @@ static constexpr int kElemsPerAccess = ElemsPerAccess::value; struct
MiniMaxReduceRMSParams { int nranks{}; int rank{}; - at::ScalarType dtype{at::ScalarType::Undefined}; + torch::headeronly::ScalarType dtype{torch::headeronly::ScalarType::Undefined}; int size_q{}; int hidden_dim{}; int size_k{}; diff --git a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu b/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu deleted file mode 100644 index f507f9299b0..00000000000 --- a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu +++ /dev/null @@ -1,60 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright contributors to the vLLM project -// Adapted from SGLang: -// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu - -#include - -#include "cutlass_mxfp8_grouped_mm_launcher.cuh" - -void cutlass_mxfp8_grouped_mm(const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& sfa, - const torch::Tensor& sfb, torch::Tensor& d, - const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets, - const torch::Tensor& blockscale_offsets) { -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); - TORCH_CHECK(problem_sizes.size(1) == 3, - "problem_sizes must have shape (num_experts, 3)"); - TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0), - "Number of experts in problem_sizes must match expert_offsets"); - TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, - "problem_sizes must be int32"); - TORCH_CHECK(expert_offsets.dtype() == torch::kInt32, - "expert_offsets must be int32"); - TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32, - "blockscale_offsets must be int32"); - TORCH_CHECK(a.dim() == 2, "a must be a 2D tensor of shape (num_tokens, k)"); - TORCH_CHECK(b.dim() == 3, - "b must be a 3D tensor of shape (num_experts, k, n)"); - TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0, - "k should align 128"); - 
TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128"); - TORCH_CHECK(a.strides()[1] == 1, "a must be row major"); - TORCH_CHECK(b.strides()[1] == 1, "b must be column major"); - - auto stream = at::cuda::getCurrentCUDAStream(); - if (d.dtype() == torch::kBFloat16) { - expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype< - cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets, - blockscale_offsets, stream); - } else if (d.dtype() == torch::kFloat16) { - expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype< - cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets, - blockscale_offsets, stream); - } else { - TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16"); - } -#else - TORCH_CHECK(false, - "No implemented cutlass_mxfp8_grouped_mm for " - "current device"); -#endif -} - -#include "core/registration.h" - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("cutlass_mxfp8_grouped_mm", cutlass_mxfp8_grouped_mm); -} \ No newline at end of file diff --git a/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu b/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu deleted file mode 100644 index 2a93ab94d5c..00000000000 --- a/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu +++ /dev/null @@ -1,60 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright contributors to the vLLM project -// Adapted from SGLang: -// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu - -#include - -#include "mxfp8_experts_quant.cuh" - -void mxfp8_experts_quant(const torch::Tensor& input, - const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets, - const torch::Tensor& blockscale_offsets, - torch::Tensor& quant_output, - torch::Tensor& scale_factor) { -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - TORCH_CHECK(input.dim() == 2, "input must be 2D tensor"); - TORCH_CHECK(input.size(1) % 
128 == 0, "k must align to 128"); - TORCH_CHECK(input.strides()[1] == 1, "input must be row major"); - TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); - TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, - "problem_sizes must be int32"); - TORCH_CHECK(expert_offsets.dtype() == torch::kInt32, - "expert_offsets must be int32"); - TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32, - "blockscale_offsets must be int32"); - - auto groups = problem_sizes.size(0); - TORCH_CHECK( - expert_offsets.dim() == 1 && expert_offsets.size(0) == groups, - "expert_offsets must be 1D and have size equal to the number of groups"); - TORCH_CHECK( - blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups, - "blockscale_offsets must be 1D and have size equal to the number of " - "groups"); - - auto stream = at::cuda::getCurrentCUDAStream(); - if (input.dtype() == torch::kBFloat16) { - expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>( - input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, - scale_factor); - } else if (input.dtype() == torch::kFloat16) { - expert_specialization::launch_mxfp8_experts_quant<__half>( - input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, - scale_factor); - } else { - TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16"); - } -#else - TORCH_CHECK(false, - "No implemented mxfp8_experts_quant for " - "current device"); -#endif -} - -#include "core/registration.h" - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("mxfp8_experts_quant", mxfp8_experts_quant); -} \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index f458f79d6f4..ed2fca26b0d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -40,12 +40,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); -torch::Tensor 
fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( - torch::Tensor const& q_in, torch::Tensor const& kv, torch::Tensor& k_cache, - torch::Tensor const& slot_mapping, torch::Tensor const& position_ids, - torch::Tensor const& cos_sin_cache, int64_t q_head_padded, double eps, - int64_t cache_block_size); - void silu_and_mul_per_block_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales, int64_t group_size, @@ -107,24 +101,6 @@ torch::Tensor dynamic_4bit_int_moe_cpu( int64_t activation_kind); using fptr_t = int64_t; -fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, - torch::Tensor& rank_data, int64_t rank, - bool fully_connected); -void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, - fptr_t reg_buffer, int64_t reg_buffer_sz_bytes); -void dispose(fptr_t _fa); -int64_t meta_size(); -void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs); -std::tuple, std::vector> -get_graph_buffer_ipc_meta(fptr_t _fa); -void register_graph_buffers(fptr_t _fa, - const std::vector>& handles, - const std::vector>& offsets); -std::tuple allocate_shared_buffer_and_handle( - int64_t size); -int64_t open_mem_handle(torch::Tensor& mem_handle); -void free_shared_buffer(int64_t buffer); - #ifdef USE_ROCM fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional qr_max_size = std::nullopt); @@ -135,15 +111,3 @@ void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false); int64_t qr_max_size(); #endif - -#ifndef USE_ROCM -torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, - torch::Tensor const& norm_weight, - torch::Tensor workspace, int64_t const rank, - int64_t const nranks, double const eps); -std::tuple minimax_allreduce_rms_qk( - torch::Tensor qkv, torch::Tensor const& norm_weight_q, - torch::Tensor const& norm_weight_k, torch::Tensor workspace, - int64_t const q_size, int64_t const kv_size, int64_t const rank, - int64_t const nranks, double const eps); 
-#endif diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 01869474e0f..c078222bca0 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -55,14 +55,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and // GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one - // kernel launch. - ops.def( - "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(" - "Tensor q_in, Tensor kv, Tensor! k_cache, " - "Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, " - "int q_head_padded, float eps, int cache_block_size) -> Tensor"); - ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA, - &fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert); + // kernel launch. Registered in _C_stable_libtorch. // Quantization ops #ifndef USE_ROCM @@ -163,34 +156,27 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // conditionally compiled so impl registration is in source file #endif - -#ifndef USE_ROCM - ops.def( - "minimax_allreduce_rms(" - "Tensor input," - "Tensor norm_weight," - "Tensor workspace," - "int rank," - "int nranks," - "float eps) -> Tensor"); - ops.impl("minimax_allreduce_rms", torch::kCUDA, &minimax_allreduce_rms); - ops.def( - "minimax_allreduce_rms_qk(" - "Tensor qkv," - "Tensor norm_weight_q," - "Tensor norm_weight_k," - "Tensor workspace," - "int q_size," - "int kv_size," - "int rank," - "int nranks," - "float eps) -> (Tensor, Tensor)"); - ops.impl("minimax_allreduce_rms_qk", torch::kCUDA, &minimax_allreduce_rms_qk); - - // conditionally compiled so impl in source file -#endif } +#ifdef USE_ROCM +TORCH_LIBRARY_FRAGMENT(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { + // Quick Reduce all-reduce kernels (ROCm-only; stays on legacy _C). 
+ custom_ar.def( + "qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool " + "cast_bf2half) -> ()"); + custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce); + + custom_ar.def("init_custom_qr", &init_custom_qr); + custom_ar.def("qr_destroy", &qr_destroy); + custom_ar.def("qr_get_handle", &qr_get_handle); + + custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()"); + custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles); + + custom_ar.def("qr_max_size", &qr_max_size); +} +#endif + TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { // Cuda utils @@ -205,48 +191,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { &get_max_shared_memory_per_block_device_attribute); } -TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { - // Custom all-reduce kernels - custom_ar.def( - "init_custom_ar(int[] ipc_tensors, Tensor rank_data, " - "int rank, bool fully_connected) -> int"); - custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); - custom_ar.def( - "all_reduce(int fa, Tensor inp, Tensor! 
out, int reg_buffer, " - "int reg_buffer_sz_bytes) -> ()"); - custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce); - - custom_ar.def("dispose", &dispose); - custom_ar.def("meta_size", &meta_size); - - custom_ar.def("register_buffer", &register_buffer); - custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); - custom_ar.def("register_graph_buffers", &register_graph_buffers); - - custom_ar.def("allocate_shared_buffer_and_handle", - &allocate_shared_buffer_and_handle); - custom_ar.def("open_mem_handle(Tensor mem_handle) -> int", &open_mem_handle); - custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle); - - custom_ar.def("free_shared_buffer", &free_shared_buffer); -#ifdef USE_ROCM - // Quick Reduce all-reduce kernels - custom_ar.def( - "qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool " - "cast_bf2half) -> ()"); - custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce); - - custom_ar.def("init_custom_qr", &init_custom_qr); - custom_ar.def("qr_destroy", &qr_destroy); - - custom_ar.def("qr_get_handle", &qr_get_handle); - - custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) 
handles) -> ()"); - custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles); - - // Max input size in bytes - custom_ar.def("qr_max_size", &qr_max_size); -#endif -} - REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/type_convert.cuh b/csrc/type_convert.cuh index 9d939bb828f..8093c4bc871 100644 --- a/csrc/type_convert.cuh +++ b/csrc/type_convert.cuh @@ -50,7 +50,7 @@ struct _typeConvert { #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) // CUDA < 12.0 runs into issues with packed type conversion template <> -struct _typeConvert { +struct _typeConvert { static constexpr bool exists = true; using hip_type = __half; using packed_hip_type = __half2; @@ -73,7 +73,7 @@ struct _typeConvert { // CUDA_ARCH < 800 does not have BF16 support // ROCm 7.0+ supports bfloat16 template <> -struct _typeConvert { +struct _typeConvert { static constexpr bool exists = true; using hip_type = __nv_bfloat16; using packed_hip_type = __nv_bfloat162;