diff --git a/CMakeLists.txt b/CMakeLists.txt index 0b919e2314a..09f2a3baf27 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -486,32 +486,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") " in CUDA target architectures") endif() - # Expert-specialization MXFP8 blockscaled grouped kernels (SM100+). - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND ES_MXFP8_GROUPED_MM_ARCHS) - set(SRCS - "csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu" - "csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${ES_MXFP8_GROUPED_MM_ARCHS}") - list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_ES_MXFP8_GROUPED_MM_SM100=1") - message(STATUS "Building ES MXFP8 grouped kernels for archs: ${ES_MXFP8_GROUPED_MM_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 - AND ES_MXFP8_GROUPED_MM_ARCHS) - message(STATUS "Not building ES MXFP8 grouped kernels as CUDA Compiler version is " - "not >= 12.8.") - else() - message(STATUS "Not building ES MXFP8 grouped kernels as no compatible archs found " - "in CUDA target architectures.") - endif() - endif() # # Machete kernels diff --git a/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu deleted file mode 100644 index fda9bc020da..00000000000 --- a/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu +++ /dev/null @@ -1,69 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright contributors to the vLLM project -// Adapted from SGLang: -// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu - -#include -#include -#include "libtorch_stable/torch_utils.h" - -#include "cutlass_mxfp8_grouped_mm_launcher.cuh" - -void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a, - const torch::stable::Tensor& b, - const torch::stable::Tensor& sfa, - const torch::stable::Tensor& sfb, - torch::stable::Tensor& d, - const torch::stable::Tensor& problem_sizes, - const torch::stable::Tensor& expert_offsets, - const torch::stable::Tensor& blockscale_offsets) { -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); - STD_TORCH_CHECK(problem_sizes.size(1) == 3, - "problem_sizes must have shape (num_experts, 3)"); - STD_TORCH_CHECK( - problem_sizes.size(0) == expert_offsets.size(0), - "Number of experts in problem_sizes must match expert_offsets"); - STD_TORCH_CHECK( - problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int, - "problem_sizes must be int32"); - STD_TORCH_CHECK( - expert_offsets.scalar_type() == torch::headeronly::ScalarType::Int, - "expert_offsets must be int32"); - STD_TORCH_CHECK( - blockscale_offsets.scalar_type() == torch::headeronly::ScalarType::Int, - "blockscale_offsets must be int32"); - STD_TORCH_CHECK(a.dim() == 2, - "a must be a 2D tensor of shape (num_tokens, k)"); - STD_TORCH_CHECK(b.dim() == 3, - "b must be a 3D tensor of shape (num_experts, k, n)"); - STD_TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0, - "k should align 128"); - STD_TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128"); - STD_TORCH_CHECK(a.stride(1) == 1, "a must be row major"); - STD_TORCH_CHECK(b.stride(1) == 1, "b must be column major"); - - const torch::stable::accelerator::DeviceGuard device_guard( - a.get_device_index()); - auto stream = get_current_cuda_stream(a.get_device_index()); - if (d.scalar_type() == torch::headeronly::ScalarType::BFloat16) { - expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype< - cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets, - blockscale_offsets, stream); - } else if (d.scalar_type() == torch::headeronly::ScalarType::Half) { - expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype< - cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets, - blockscale_offsets, stream); - } else { - STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16"); - } -#else - STD_TORCH_CHECK(false, - "No implemented cutlass_mxfp8_grouped_mm for " - "current device"); -#endif -} - -STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { - m.impl("cutlass_mxfp8_grouped_mm", TORCH_BOX(&cutlass_mxfp8_grouped_mm)); -} diff --git a/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh deleted file mode 100644 index 9fb1dbf8eef..00000000000 --- a/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh +++ /dev/null @@ -1,141 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright contributors to the vLLM project -// Adapted from SGLang: -// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_functor.cuh - -#pragma once -#include - -#include "cute/tensor.hpp" -#include "cutlass/util/packed_stride.hpp" -#include "cutlass_mxfp8_grouped_mm_traits.cuh" - -namespace expert_specialization { - -using namespace cute; - -template -struct CutlassMxfp8GroupedMmOffsetFunctor { - using Gemm = typename GemmTraits::Gemm; - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementSF = typename GemmTraits::ElementSF; - using ElementD = typename GemmTraits::ElementOutput; - // Input - int* expert_offsets{nullptr}; - int* blockscale_offsets{nullptr}; - // Output - ElementA* a_base{nullptr}; - ElementB* b_base{nullptr}; - ElementSF* sfa_base{nullptr}; - ElementSF* sfb_base{nullptr}; - ElementD* d_base{nullptr}; - ElementA** a_offsets{nullptr}; - ElementB** b_offsets{nullptr}; - ElementSF** sfa_offsets{nullptr}; - ElementSF** sfb_offsets{nullptr}; - ElementD** d_offsets{nullptr}; - - CutlassMxfp8GroupedMmOffsetFunctor() = default; - CutlassMxfp8GroupedMmOffsetFunctor( - int* _expert_offsets, int* _blockscale_offsets, ElementA* _a_base, - ElementB* _b_base, ElementSF* _sfa_base, ElementSF* _sfb_base, - ElementD* _d_base, ElementA** _a_offsets, ElementB** _b_offsets, - ElementSF** _sfa_offsets, ElementSF** _sfb_offsets, ElementD** _d_offsets) - : expert_offsets{_expert_offsets}, - blockscale_offsets{_blockscale_offsets}, - a_base(_a_base), - b_base(_b_base), - sfa_base(_sfa_base), - sfb_base(_sfb_base), - d_base(_d_base), - a_offsets(_a_offsets), - b_offsets(_b_offsets), - sfa_offsets(_sfa_offsets), - sfb_offsets(_sfb_offsets), - d_offsets(_d_offsets) {} - - void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { - int64_t expert_offset = static_cast(expert_offsets[expert_id]); - int64_t blockscale_offset = - static_cast(blockscale_offsets[expert_id]); - int64_t a_stride = expert_offset * k; - int64_t b_stride = expert_id * k * n; - int64_t d_stride = expert_offset * n; - int64_t sfa_stride = blockscale_offset * (k / 32); - int64_t sfb_stride = expert_id * n * (k / 32); - - a_offsets[expert_id] = a_base + a_stride; - b_offsets[expert_id] = b_base + b_stride; - sfa_offsets[expert_id] = sfa_base + sfa_stride; - sfb_offsets[expert_id] = sfb_base + sfb_stride; - d_offsets[expert_id] = d_base + d_stride; - } -}; - -template -struct CutlassMxfp8GroupedMmLayoutFunctor { - using Sm1xxBlkScaledConfig = typename GemmTraits::Sm1xxBlkScaledConfig; - using LayoutSFA = typename GemmTraits::LayoutSFA; - using LayoutSFB = typename GemmTraits::LayoutSFB; - LayoutSFA* layout_sfa_base{nullptr}; - LayoutSFB* layout_sfb_base{nullptr}; - - CutlassMxfp8GroupedMmLayoutFunctor() = default; - CutlassMxfp8GroupedMmLayoutFunctor(LayoutSFA* _layout_sfa_base, - LayoutSFB* _layout_sfb_base) - : layout_sfa_base(_layout_sfa_base), layout_sfb_base(_layout_sfb_base) {} - - void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { - LayoutSFA* layout_sfa_ptr = layout_sfa_base + expert_id; - LayoutSFB* layout_sfb_ptr = layout_sfb_base + expert_id; - *layout_sfa_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA( - cute::make_shape(m, n, k, 1)); - *layout_sfb_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB( - cute::make_shape(m, n, k, 1)); - } -}; - -template -struct CutlassMxfp8GroupedMmStrideFunctor { - using StrideA = typename GemmTraits::StrideA; - using StrideB = typename GemmTraits::StrideB; - using StrideD = typename GemmTraits::StrideD; - StrideA* stride_A_base{nullptr}; - StrideB* stride_B_base{nullptr}; - StrideD* stride_D_base{nullptr}; - - CutlassMxfp8GroupedMmStrideFunctor() = default; - CutlassMxfp8GroupedMmStrideFunctor(StrideA* _stride_A_base, - StrideB* _stride_B_base, - StrideD* _stride_D_base) - : stride_A_base(_stride_A_base), - stride_B_base(_stride_B_base), - stride_D_base(_stride_D_base) {} - - void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { - StrideA* stride_A = stride_A_base + expert_id; - StrideB* stride_B = stride_B_base + expert_id; - StrideD* stride_D = stride_D_base + expert_id; - *stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); - *stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); - *stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1}); - } -}; - -template -__global__ void cutlassMxfp8GroupedMmPreComputeKernel( - int* problem_sizes, OffsetFunctor offset_functor, - LayoutFunctor layout_functor, StrideFunctor stride_functor) { - int64_t expert_id = static_cast(threadIdx.x); - int m = problem_sizes[expert_id * 3 + 0]; - int n = problem_sizes[expert_id * 3 + 1]; - int k = problem_sizes[expert_id * 3 + 2]; - - offset_functor(expert_id, m, n, k); - layout_functor(expert_id, m, n, k); - stride_functor(expert_id, m, n, k); -} - -} // namespace expert_specialization \ No newline at end of file diff --git a/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh deleted file mode 100644 index 82d6543b288..00000000000 --- a/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh +++ /dev/null @@ -1,198 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright contributors to the vLLM project -// Adapted from SGLang: -// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_launcher.cuh - -#pragma once - -#include -#include - -#include -#include -#include - -#include "cute/tensor.hpp" -#include "cutlass_mxfp8_grouped_mm_functor.cuh" -#include "cutlass_mxfp8_grouped_mm_traits.cuh" -#include "libtorch_stable/torch_utils.h" - -namespace expert_specialization { - -template -void cutlass_mxfp8_grouped_mm_pre_compute( - torch::stable::Tensor& a_ptrs, torch::stable::Tensor& b_ptrs, - torch::stable::Tensor& sfa_ptrs, torch::stable::Tensor& sfb_ptrs, - torch::stable::Tensor& d_ptrs, torch::stable::Tensor& stride_a, - torch::stable::Tensor& stride_b, torch::stable::Tensor& stride_d, - torch::stable::Tensor& layout_sfa, torch::stable::Tensor& layout_sfb, - const torch::stable::Tensor& a, const torch::stable::Tensor& b, - const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb, - const torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes, - const torch::stable::Tensor& expert_offsets, - const torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) { - using OffsetFunctor = CutlassMxfp8GroupedMmOffsetFunctor; - using ElementA = typename OffsetFunctor::ElementA; - using ElementB = typename OffsetFunctor::ElementB; - using ElementSF = typename OffsetFunctor::ElementSF; - using ElementD = typename OffsetFunctor::ElementD; - - using LayoutFunctor = CutlassMxfp8GroupedMmLayoutFunctor; - using LayoutSFA = typename LayoutFunctor::LayoutSFA; - using LayoutSFB = typename LayoutFunctor::LayoutSFB; - - using StrideFunctor = CutlassMxfp8GroupedMmStrideFunctor; - using StrideA = typename StrideFunctor::StrideA; - using StrideB = typename StrideFunctor::StrideB; - using StrideD = typename StrideFunctor::StrideD; - - int num_experts = static_cast(expert_offsets.size(0)); - STD_TORCH_CHECK(num_experts <= 1024, - "Number of experts cannot exceed 1024, the maximum number of " - "threads per block."); - - OffsetFunctor offset_functor( - reinterpret_cast(expert_offsets.data_ptr()), - reinterpret_cast(blockscale_offsets.data_ptr()), - reinterpret_cast(a.data_ptr()), - reinterpret_cast(b.data_ptr()), - reinterpret_cast(sfa.data_ptr()), - reinterpret_cast(sfb.data_ptr()), - reinterpret_cast(d.data_ptr()), - reinterpret_cast(a_ptrs.data_ptr()), - reinterpret_cast(b_ptrs.data_ptr()), - reinterpret_cast(sfa_ptrs.data_ptr()), - reinterpret_cast(sfb_ptrs.data_ptr()), - reinterpret_cast(d_ptrs.data_ptr())); - LayoutFunctor layout_functor( - reinterpret_cast(layout_sfa.data_ptr()), - reinterpret_cast(layout_sfb.data_ptr())); - StrideFunctor stride_functor(reinterpret_cast(stride_a.data_ptr()), - reinterpret_cast(stride_b.data_ptr()), - reinterpret_cast(stride_d.data_ptr())); - cutlassMxfp8GroupedMmPreComputeKernel<<<1, num_experts, 0, stream>>>( - static_cast(problem_sizes.data_ptr()), offset_functor, - layout_functor, stride_functor); -} - -template -void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a_ptrs, - const torch::stable::Tensor& b_ptrs, - const torch::stable::Tensor& sfa_ptrs, - const torch::stable::Tensor& sfb_ptrs, - const torch::stable::Tensor& d_ptrs, - const torch::stable::Tensor& stride_a, - const torch::stable::Tensor& stride_b, - const torch::stable::Tensor& stride_d, - const torch::stable::Tensor& layout_sfa, - const torch::stable::Tensor& layout_sfb, - const torch::stable::Tensor& problem_sizes, - cudaStream_t stream) { - using Gemm = typename GemmTraits::Gemm; - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementSF = typename GemmTraits::ElementSF; - using ElementD = typename GemmTraits::ElementOutput; - using StrideA = typename GemmTraits::StrideA; - using StrideB = typename GemmTraits::StrideB; - using StrideD = typename GemmTraits::StrideD; - using LayoutSFA = typename GemmTraits::LayoutSFA; - using LayoutSFB = typename GemmTraits::LayoutSFB; - using UnderlyingProblemShape = - typename GemmTraits::ProblemShape::UnderlyingProblemShape; - - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = d_ptrs.get_device_index(); - hw_info.sm_count = get_device_prop()->multiProcessorCount; - hw_info.cluster_shape = GemmTraits::MMAConfig::preferred_cluster; - hw_info.cluster_shape_fallback = GemmTraits::MMAConfig::fallback_cluster; - - int num_experts = static_cast(problem_sizes.size(0)); - - UnderlyingProblemShape* underlying_problem_shape = - reinterpret_cast(problem_sizes.data_ptr()); - - typename Gemm::Arguments arguments = { - cutlass::gemm::GemmUniversalMode::kGrouped, - {num_experts, underlying_problem_shape, nullptr}, - {reinterpret_cast(a_ptrs.data_ptr()), - reinterpret_cast(stride_a.data_ptr()), - reinterpret_cast(b_ptrs.data_ptr()), - reinterpret_cast(stride_b.data_ptr()), - reinterpret_cast(sfa_ptrs.data_ptr()), - reinterpret_cast(layout_sfa.data_ptr()), - reinterpret_cast(sfb_ptrs.data_ptr()), - reinterpret_cast(layout_sfb.data_ptr())}, - {{}, - nullptr, - nullptr, - reinterpret_cast(d_ptrs.data_ptr()), - reinterpret_cast(stride_d.data_ptr())}, - hw_info, - {} // Scheduler - }; - - Gemm gemm; - - auto can_implement_status = gemm.can_implement(arguments); - STD_TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, - "Failed to implement GEMM"); - - size_t workspace_size = gemm.get_workspace_size(arguments); - torch::stable::Tensor workspace = torch::stable::empty( - {static_cast(workspace_size)}, - torch::headeronly::ScalarType::Byte, std::nullopt, d_ptrs.device()); - - auto status = gemm.initialize(arguments, workspace.data_ptr(), stream); - STD_TORCH_CHECK(status == cutlass::Status::kSuccess, - "Failed to initialize GEMM"); - - status = gemm.run(stream, nullptr, true); // Enable PDL - STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); -} - -template -void cutlass_mxfp8_grouped_mm_dispatch_out_dtype( - const torch::stable::Tensor& a, const torch::stable::Tensor& b, - const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb, - torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes, - const torch::stable::Tensor& expert_offsets, - const torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) { - int num_experts = static_cast(problem_sizes.size(0)); - auto device = a.device(); - - torch::stable::Tensor a_ptrs = torch::stable::empty( - num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); - torch::stable::Tensor b_ptrs = torch::stable::empty( - num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); - torch::stable::Tensor sfa_ptrs = torch::stable::empty( - num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); - torch::stable::Tensor sfb_ptrs = torch::stable::empty( - num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); - torch::stable::Tensor d_ptrs = torch::stable::empty( - num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); - - torch::stable::Tensor stride_a = torch::stable::empty( - num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); - torch::stable::Tensor stride_b = torch::stable::empty( - num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); - torch::stable::Tensor stride_d = torch::stable::empty( - num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); - torch::stable::Tensor layout_sfa = - torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int, - std::nullopt, device); - torch::stable::Tensor layout_sfb = - torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int, - std::nullopt, device); - - using GemmTraits = CutlassMxfp8GroupedMmGemmTraits; - cutlass_mxfp8_grouped_mm_pre_compute( - a_ptrs, b_ptrs, sfa_ptrs, sfb_ptrs, d_ptrs, stride_a, stride_b, stride_d, - layout_sfa, layout_sfb, a, b, sfa, sfb, d, problem_sizes, expert_offsets, - blockscale_offsets, stream); - cutlass_mxfp8_grouped_mm( - a_ptrs, b_ptrs, sfa_ptrs, sfb_ptrs, d_ptrs, stride_a, stride_b, stride_d, - layout_sfa, layout_sfb, problem_sizes, stream); -} - -} // namespace expert_specialization diff --git a/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh b/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh deleted file mode 100644 index ed8cd7ce065..00000000000 --- a/csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh +++ /dev/null @@ -1,127 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright contributors to the vLLM project -// Adapted from SGLang: -// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_traits.cuh - -#pragma once - -// Misc -#include "cute/tensor.hpp" -#include "cutlass/arch/arch.h" -#include "cutlass/arch/mma.h" -#include "cutlass/cutlass.h" -#include "cutlass/detail/sm100_blockscaled_layout.hpp" -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/group_array_problem_shape.hpp" -#include "cutlass/layout/layout.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_size.h" - -// Collective Builder -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/gemm/collective/collective_builder.hpp" - -// Integration -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" - -namespace expert_specialization { - -using namespace cute; - -// Different configs for 1SM and 2SM MMA kernel -struct MMA1SMConfig { - using MmaTileShape = Shape<_128, _128, _128>; - using KernelSchedule = - cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; - const static dim3 preferred_cluster; - const static dim3 fallback_cluster; -}; -const dim3 MMA1SMConfig::preferred_cluster(1, 4, 1); -const dim3 MMA1SMConfig::fallback_cluster(1, 2, 1); - -template -struct CutlassMxfp8GroupedMmGemmTraits { - using MMAConfig = _MMAConfig; - using ElementInput = cutlass::float_e4m3_t; - using ElementOutput = OutputDtype; - using ProblemShape = cutlass::gemm::GroupProblemShape>; - - // A matrix configuration - using ElementA = cutlass::mx_float8_t; - using LayoutA = cutlass::layout::RowMajor; - constexpr static int AlignmentA = 32; - - // B matrix configuration - using ElementB = cutlass::mx_float8_t; - using LayoutB = cutlass::layout::ColumnMajor; - constexpr static int AlignmentB = 32; - - // C/D matrix configuration - using ElementC = void; - using ElementD = ElementOutput; - using LayoutC = cutlass::layout::RowMajor; - using LayoutD = cutlass::layout::RowMajor; - constexpr static int AlignmentC = 128 / cutlass::sizeof_bits::value; - constexpr static int AlignmentD = 128 / cutlass::sizeof_bits::value; - using ElementAccumulator = float; - - static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - using CustomEVTIdentity = // acc - cutlass::epilogue::fusion::Sm90EVT< - cutlass::epilogue::fusion::Sm90Compute< - cutlass::epilogue::thread::Identity, ElementD, ElementAccumulator, - RoundStyle>, - cutlass::epilogue::fusion::Sm90AccFetch>; - - // Core kernel configurations - using ArchTag = cutlass::arch::Sm100; - using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; - using StageCountType = cutlass::gemm::collective::StageCountAuto; - - // Runtime Cluster Shape - using ClusterShape = Shape; - - // Define Epilogue - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, typename MMAConfig::MmaTileShape, - ClusterShape, Shape<_64, _64>, ElementAccumulator, ElementAccumulator, - ElementC, LayoutC*, AlignmentC, ElementD, LayoutD*, AlignmentD, - typename MMAConfig::EpilogueSchedule, - CustomEVTIdentity>::CollectiveOp; - - // Define Mainloop - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, - LayoutB*, AlignmentB, ElementAccumulator, - typename MMAConfig::MmaTileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - typename MMAConfig::KernelSchedule>::CollectiveOp; - - // Define GemmKernel - using GemmKernel = - cutlass::gemm::kernel::GemmUniversal; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - using ElementSF = typename Gemm::GemmKernel::ElementSF; - using StrideA = typename Gemm::GemmKernel::InternalStrideA; - using StrideB = typename Gemm::GemmKernel::InternalStrideB; - using StrideC = typename Gemm::GemmKernel::InternalStrideC; - using StrideD = typename Gemm::GemmKernel::InternalStrideD; - using LayoutSFA = - typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; - using LayoutSFB = - typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; - using Sm1xxBlkScaledConfig = - typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; -}; - -} // namespace expert_specialization \ No newline at end of file diff --git a/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu b/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu deleted file mode 100644 index e075721c2a3..00000000000 --- a/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu +++ /dev/null @@ -1,66 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright contributors to the vLLM project -// Adapted from SGLang: -// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu - -#include -#include -#include "libtorch_stable/torch_utils.h" - -#include "mxfp8_experts_quant.cuh" - -void mxfp8_experts_quant(const torch::stable::Tensor& input, - const torch::stable::Tensor& problem_sizes, - const torch::stable::Tensor& expert_offsets, - const torch::stable::Tensor& blockscale_offsets, - torch::stable::Tensor& quant_output, - torch::stable::Tensor& scale_factor) { -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - STD_TORCH_CHECK(input.dim() == 2, "input must be 2D tensor"); - STD_TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128"); - STD_TORCH_CHECK(input.stride(1) == 1, "input must be row major"); - STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); - STD_TORCH_CHECK( - problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int, - "problem_sizes must be int32"); - STD_TORCH_CHECK( - expert_offsets.scalar_type() == torch::headeronly::ScalarType::Int, - "expert_offsets must be int32"); - STD_TORCH_CHECK( - blockscale_offsets.scalar_type() == torch::headeronly::ScalarType::Int, - "blockscale_offsets must be int32"); - - auto groups = problem_sizes.size(0); - STD_TORCH_CHECK( - expert_offsets.dim() == 1 && expert_offsets.size(0) == groups, - "expert_offsets must be 1D and have size equal to the number of groups"); - STD_TORCH_CHECK( - blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups, - "blockscale_offsets must be 1D and have size equal to the number of " - "groups"); - - const torch::stable::accelerator::DeviceGuard device_guard( - input.get_device_index()); - if (input.scalar_type() == torch::headeronly::ScalarType::BFloat16) { - expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>( - input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, - scale_factor); - } else if (input.scalar_type() == torch::headeronly::ScalarType::Half) { - expert_specialization::launch_mxfp8_experts_quant<__half>( - input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, - scale_factor); - } else { - STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16"); - } -#else - STD_TORCH_CHECK(false, - "No implemented mxfp8_experts_quant for " - "current device"); -#endif -} - -// Registered here (not torch_bindings.cpp) because ENABLE_ES_MXFP8_GROUPED_MM -// is applied only under COMPILE_LANGUAGE:CUDA. -STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { - m.impl("mxfp8_experts_quant", TORCH_BOX(&mxfp8_experts_quant)); -} diff --git a/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cuh b/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cuh deleted file mode 100644 index a57e00e76c3..00000000000 --- a/csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cuh +++ /dev/null @@ -1,416 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright contributors to the vLLM project -// Adapted from SGLang: -// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh - -#pragma once -#include -#include -#include - -#include -#include -#include -#include - -#include - -#include "cute/tensor.hpp" -#include "libtorch_stable/torch_utils.h" - -namespace expert_specialization { - -using namespace cute; - -constexpr uint32_t THREAD_BLOCK_SIZE = 128; -constexpr uint32_t WARP_SIZE = 32; -constexpr int BLOCK_M = 128; -constexpr int BLOCK_K = 128; -using ThrLayout = Layout, Stride<_8, _1>>; -using ValLayout = Layout>; -using SfR2SThrLayout = Layout, Stride<_4, _1>>; -using SfR2SValLayout = Layout>; -using ScaleFactorTileLayout = - Layout, _4>, Stride, _1>>; - -// Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); - return b; -} - -// Some code references TRT-LLM: -// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/quantization.cuh -template -__inline__ __device__ uint8_t cvt_warp_fp16_to_mxfp8(FragmentS& fragment_s, - FragmentD& fragment_d) { - using FragmentSLayout = typename FragmentS::layout_type; - using FragmentDLayout = typename FragmentD::layout_type; - FragmentSLayout fragment_s_layout; - FragmentDLayout fragment_d_layout; - static_assert(is_static::value && - size(fragment_s_layout) == 16); - static_assert(is_static::value && - size(fragment_d_layout) == 16); - - constexpr int eles_per_thr = 16; - using ValType = typename FragmentS::element_type; - using VecType = std::conditional_t, - __nv_bfloat162, __half2>; - VecType vec[8]; - // Assign vals - vec[0].x = fragment_s(Int<0>{}); - vec[0].y = fragment_s(Int<1>{}); - vec[1].x = fragment_s(Int<2>{}); - vec[1].y = fragment_s(Int<3>{}); - vec[2].x = fragment_s(Int<4>{}); - vec[2].y = fragment_s(Int<5>{}); - vec[3].x = fragment_s(Int<6>{}); - vec[3].y = fragment_s(Int<7>{}); - vec[4].x = fragment_s(Int<8>{}); - vec[4].y = fragment_s(Int<9>{}); - vec[5].x = fragment_s(Int<10>{}); - vec[5].y = fragment_s(Int<11>{}); - vec[6].x = fragment_s(Int<12>{}); - vec[6].y = fragment_s(Int<13>{}); - vec[7].x = fragment_s(Int<14>{}); - vec[7].y = fragment_s(Int<15>{}); - - auto local_max = __habs2(vec[0]); - for (int i = 1; i < eles_per_thr / 2; i++) { - local_max = __hmax2(__habs2(vec[i]), local_max); - } - local_max = __hmax2(__shfl_xor_sync(uint32_t(-1), local_max, 1), local_max); - - // Get the final absolute maximum values. - float block_max(0.0f); - if constexpr (std::is_same_v) { - block_max = __bfloat162float(__hmax(local_max.x, local_max.y)); - } else { - block_max = __half2float(__hmax(local_max.x, local_max.y)); - } - // Get the SF (max value of the vector / max value of mxfp8). - float sf_val = block_max * reciprocal_approximate_ftz(448.0f); - // 8 bits representation of the SF. - uint8_t fp8_sf_val; - - __nv_fp8_e8m0 tmp_sf_val; - tmp_sf_val.__x = - __nv_cvt_float_to_e8m0(sf_val, __NV_SATFINITE, cudaRoundPosInf); - sf_val = static_cast(tmp_sf_val); - fp8_sf_val = tmp_sf_val.__x; - // Get the output scale (reciprocal of the SFValue). - float output_scale = - block_max != 0.f ? reciprocal_approximate_ftz(sf_val) : 0.0f; - - // Convert the input to float. - float2 fp2_vals[eles_per_thr / 2]; - -#pragma unroll - for (int i = 0; i < eles_per_thr / 2; i++) { - if constexpr (std::is_same_v) { - fp2_vals[i] = __half22float2(vec[i]); - } else { - fp2_vals[i] = __bfloat1622float2(vec[i]); - } - fp2_vals[i].x *= output_scale; - fp2_vals[i].y *= output_scale; - } - union { - uint8_t bytes[16]; - __nv_fp8x2_e4m3 elts[8]; - } u; - u.elts[0] = __nv_fp8x2_e4m3(fp2_vals[0]); - u.elts[1] = __nv_fp8x2_e4m3(fp2_vals[1]); - u.elts[2] = __nv_fp8x2_e4m3(fp2_vals[2]); - u.elts[3] = __nv_fp8x2_e4m3(fp2_vals[3]); - u.elts[4] = __nv_fp8x2_e4m3(fp2_vals[4]); - u.elts[5] = __nv_fp8x2_e4m3(fp2_vals[5]); - u.elts[6] = __nv_fp8x2_e4m3(fp2_vals[6]); - u.elts[7] = __nv_fp8x2_e4m3(fp2_vals[7]); - fragment_d(Int<0>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[0]); - fragment_d(Int<1>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[1]); - fragment_d(Int<2>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[2]); - fragment_d(Int<3>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[3]); - fragment_d(Int<4>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[4]); - fragment_d(Int<5>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[5]); - fragment_d(Int<6>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[6]); - fragment_d(Int<7>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[7]); - fragment_d(Int<8>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[8]); - fragment_d(Int<9>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[9]); - fragment_d(Int<10>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[10]); - fragment_d(Int<11>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[11]); - fragment_d(Int<12>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[12]); - fragment_d(Int<13>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[13]); - fragment_d(Int<14>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[14]); - fragment_d(Int<15>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[15]); - return fp8_sf_val; -} - -template -__inline__ __device__ void mxfp8_experts_quant_tile( - TensorS& tensor_s, TensorP& tensor_p, TensorD& tensor_d, - TensorSharedSF& tensor_shared_sf, TensorSF& tensor_sf, int m, - TiledCopyG2R& tiled_copy_g2r, TiledCopyR2G& tiled_copy_r2g, - TiledCopyR2S& tiled_copy_r2s) { - static_assert(size(get<0>(typename TensorS::layout_type{})) == 128 && - size(get<1>(typename TensorS::layout_type{})) == 128 && - stride(get<1>(typename TensorS::layout_type{})) == 1); - static_assert(size(get<0>(typename TensorD::layout_type{})) == 128 && - size(get<1>(typename TensorD::layout_type{})) == 128 && - stride(get<1>(typename TensorD::layout_type{})) == 1); - static_assert(size(get<0>(typename TensorP::layout_type{})) == 128 && - size(get<1>(typename TensorP::layout_type{})) == 128); - static_assert(size(get<0>(typename TensorSharedSF::layout_type{})) == 128 && - size(get<1>(typename TensorSharedSF::layout_type{})) == 4); - static_assert(size(get<0>(typename TensorSF::layout_type{})) == 128 && - size(get<1>(typename TensorSF::layout_type{})) == 4); - - using Tiler_MN = typename TiledCopyG2R::Tiler_MN; - auto tiler_mn = Tiler_MN{}; - static_assert(size<0>(tiler_mn) == 16 && size<1>(tiler_mn) == 128); - - auto tiled_tensor_s = tiled_divide(tensor_s, tiler_mn); - auto tiled_tensor_p = tiled_divide(tensor_p, tiler_mn); - auto tiled_tensor_d = tiled_divide(tensor_d, tiler_mn); - static_assert(size<2>(tiled_tensor_s) == 1); - static_assert(size<2>(tiled_tensor_p) == 1); - static_assert(size<2>(tiled_tensor_d) == 1); - auto squeeze_tiled_tensor_s = take<0, 2>(tiled_tensor_s); - auto squeeze_tiled_tensor_p = take<0, 2>(tiled_tensor_p); - auto squeeze_tiled_tensor_d = take<0, 2>(tiled_tensor_d); - - using SF_Tiler_MN = typename TiledCopyR2S::Tiler_MN; - auto sf_tiler_mn = SF_Tiler_MN{}; - static_assert(size<0>(sf_tiler_mn) == 16 && size<1>(sf_tiler_mn) == 4); - - auto tiled_tensor_sf = tiled_divide(tensor_sf, sf_tiler_mn); - auto tiled_tensor_shared_sf = tiled_divide(tensor_shared_sf, sf_tiler_mn); - auto squeeze_tiled_tensor_sf = take<0, 2>(tiled_tensor_sf); - auto squeeze_tiled_tensor_shared_sf = take<0, 2>(tiled_tensor_shared_sf); - - constexpr int tile_loop_count = size<1>(tiled_tensor_s); - constexpr int rows_in_tile = 16; - // We don't need to clear shared memory - // clear(squeeze_tiled_tensor_shared_sf); -#pragma unroll 4 - for (int t = 0; t < tile_loop_count; t++) { - if (t * rows_in_tile >= m) { - break; - } - auto current_copy_tile_s = tensor<0>(squeeze_tiled_tensor_s(_, t)); - auto current_copy_tile_p = tensor<0>(squeeze_tiled_tensor_p(_, t)); - auto current_copy_tile_d = tensor<0>(squeeze_tiled_tensor_d(_, t)); - auto current_copy_tile_sf = tensor<0>(squeeze_tiled_tensor_sf(_, t)); - auto current_copy_tile_shared_sf = - tensor<0>(squeeze_tiled_tensor_shared_sf(_, t)); - - // Global to Register copy - auto thr_copy_g2r = tiled_copy_g2r.get_thread_slice(threadIdx.x); - auto thr_tile_g2r_s = thr_copy_g2r.partition_S(current_copy_tile_s); - auto thr_tile_g2r_p = thr_copy_g2r.partition_S(current_copy_tile_p); - auto input_fragment = make_fragment_like(thr_tile_g2r_s); - - // Register to Global copy - auto thr_copy_r2g = tiled_copy_r2g.get_thread_slice(threadIdx.x); - auto thr_tile_r2g_d = thr_copy_r2g.partition_D(current_copy_tile_d); - auto thr_tile_r2g_p = thr_copy_r2g.partition_D(current_copy_tile_p); - auto output_fragment = make_fragment_like(thr_tile_r2g_d); - - // Register to Shared copy - auto thr_copy_r2s = tiled_copy_r2s.get_thread_slice(threadIdx.x / 2); - auto thr_tile_r2s_shared_sf = - thr_copy_r2s.partition_D(current_copy_tile_shared_sf); - auto shared_sf_fragment = make_fragment_like(thr_tile_r2s_shared_sf); - - // CopyG2R & convert & CopyR2G - copy_if(tiled_copy_g2r, thr_tile_g2r_p, thr_tile_g2r_s, input_fragment); - uint8_t fp8_sf_val = - cvt_warp_fp16_to_mxfp8(input_fragment, output_fragment); - copy_if(tiled_copy_r2g, thr_tile_r2g_p, output_fragment, thr_tile_r2g_d); - shared_sf_fragment[0] = fp8_sf_val; - - // Before first copy r2s, clear shared memory and wait previous group - if (t == 0 && threadIdx.x == 0) { - // Wait for the group to have completed reading from shared memory. - cuda::ptx::cp_async_bulk_wait_group_read(cuda::ptx::n32_t<0>()); - } - __syncthreads(); - - if (threadIdx.x % 2 == 0) { - copy(tiled_copy_r2s, shared_sf_fragment, thr_tile_r2s_shared_sf); - } - __syncthreads(); - } - - // Wait for shared memory writes to be visible to TMA engine. - cuda::ptx::fence_proxy_async(cuda::ptx::space_shared); // b) - __syncthreads(); - - if (threadIdx.x == 0) { - cuda::ptx::cp_async_bulk(cuda::ptx::space_global, cuda::ptx::space_shared, - squeeze_tiled_tensor_sf.data().get(), - squeeze_tiled_tensor_shared_sf.data().get(), 512); - // Wait for TMA transfer to have finished reading shared memory. - // Create a "bulk async-group" out of the previous bulk copy operation. - cuda::ptx::cp_async_bulk_commit_group(); - } - __syncthreads(); -} - -template -__global__ void mxfp8_experts_quant_kernel( - const T_IN* input, const int* problem_sizes, const int* expert_offsets, - const int* blockscale_offsets, cutlass::float_e4m3_t* quant_output, - uint8_t* scale_factor, int groups, TiledCopyG2R tiled_copy_g2r, - TiledCopyR2G tiled_copy_r2g, TiledCopyR2S tiled_copy_r2s) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 - __shared__ __align__(512) uint8_t shared_memory[512]; - ScaleFactorTileLayout scale_factor_tile_layout{}; - auto scale_factor_shared = - make_tensor(make_smem_ptr(shared_memory), - scale_factor_tile_layout); // ((_32,_4), _4):((_16,_4), _1) - // TODO: Transform Groupwise Schedule into a more efficient Schedule - for (int g = 0; g < groups; g++) { - int m = problem_sizes[g * 3 + 0]; - int k = problem_sizes[g * 3 + 2]; - int64_t expert_offset = static_cast(expert_offsets[g]); - int64_t blockscale_offset = static_cast(blockscale_offsets[g]); - - auto input_tensor = make_tensor( - make_gmem_ptr(input + expert_offset * k), - make_layout(make_shape(m, k), - LayoutRight{})); // (M, K):(K, 1) half_t/bfloat16_t - - auto quant_output_tensor = make_tensor( - make_gmem_ptr(quant_output + expert_offset * k), - make_layout(make_shape(m, k), - LayoutRight{})); // (M, K):(K, 1) cutlass::float_e4m3_t - - auto scale_factor_shape = make_shape(ceil_div(m, 128) * 128, k / 32); - auto scale_factor_layout = tile_to_shape(scale_factor_tile_layout, - scale_factor_shape, LayoutRight{}); - // layout<0>(layout<0>(scale_factor_layout)) (_32,_4):(_16,_4) -- static - // layout<1>(layout<0>(scale_factor_layout)) M_align_128 / 128 -- dynamic - // shape dynamic stride layout<0>(layout<1>(scale_factor_layout)) _4:_1 -- - // static layout<1>(layout<1>(scale_factor_layout)) (K / 32) / 4 : _512 -- - // dynamic shape static stride - - // Reshape to zipped layout for 1D indexing - auto zipped_scale_factor_layout = make_layout( - make_layout(layout<0>(layout<0>(scale_factor_layout)), - layout<0>(layout<1>(scale_factor_layout))), - make_layout( - layout<1>(layout<0>(scale_factor_layout)), - layout<1>(layout<1>( - scale_factor_layout)))); // (((_32,_4),_4),(M_align_128 / - // 128,(K / 32) / - // 4)):(((_16,_4),_1),(?,_512)) - - auto scale_factor_tensor = - make_tensor(make_gmem_ptr(scale_factor + blockscale_offset * (k / 32)), - zipped_scale_factor_layout); - - // Used for cases where M is not divisible by 128 (most scenarios). - auto input_shape = shape(input_tensor); // (M, K):(K, 1) - auto identity_tensor = make_identity_tensor(input_shape); - auto predict_tensor = cute::lazy::transform( - identity_tensor, [&](auto c) { return elem_less(c, input_shape); }); - - // (_128, _128) - auto tiler = make_shape(Int{}, Int{}); - - auto tiled_input_tensor = zipped_divide( - input_tensor, tiler); // ((128, 128), (cdiv(M, 128), cdiv(K, 128))) - auto tiled_quant_output_tensor = - zipped_divide(quant_output_tensor, - tiler); // ((128, 128), (cdiv(M, 128), cdiv(K, 128))) - auto tiled_predict_tensor = zipped_divide( - predict_tensor, tiler); // ((128, 128), (cdiv(M, 128), cdiv(K, 128))) - - auto total_tiles = - size<1>(tiled_input_tensor); // cdiv(M, 128) * cdiv(K, 128) - decltype(total_tiles) blk_offset = blockIdx.x; - while (blk_offset < total_tiles) { - auto current_input_tile = tensor<0>(tiled_input_tensor(_, blk_offset)); - auto current_quant_output_tile = - tensor<0>(tiled_quant_output_tensor(_, blk_offset)); - auto current_predict_tile = - tensor<0>(tiled_predict_tensor(_, blk_offset)); - auto current_scale_factor_tile = - tensor<0>(scale_factor_tensor(_, blk_offset)); - - mxfp8_experts_quant_tile< - decltype(current_input_tile), decltype(current_predict_tile), - decltype(current_quant_output_tile), decltype(scale_factor_shared), - decltype(current_scale_factor_tile), TiledCopyG2R, TiledCopyR2G, - TiledCopyR2S>(current_input_tile, current_predict_tile, - current_quant_output_tile, scale_factor_shared, - current_scale_factor_tile, m, tiled_copy_g2r, - tiled_copy_r2g, tiled_copy_r2s); - blk_offset += gridDim.x; - } - } -#endif -} - -template -void launch_mxfp8_experts_quant(const torch::stable::Tensor& input, - const torch::stable::Tensor& problem_sizes, - const torch::stable::Tensor& expert_offsets, - const torch::stable::Tensor& blockscale_offsets, - torch::stable::Tensor& quant_output, - torch::stable::Tensor& scale_factor) { - ThrLayout thr_layout{}; - ValLayout val_layout{}; - SfR2SThrLayout r2s_thr_layout{}; - SfR2SValLayout r2s_val_layout{}; - - using CopyOpG2R = - UniversalCopy>; - using CopyAtomG2R = cute::Copy_Atom; - auto tiled_copy_g2r = cute::make_tiled_copy( - CopyAtomG2R{}, thr_layout, val_layout); // Tiler_MN: (16, 128) - - using CopyOpR2G = UniversalCopy< - cutlass::AlignedArray>; - using CopyAtomR2G = cute::Copy_Atom; - auto tiled_copy_r2g = cute::make_tiled_copy( - CopyAtomR2G{}, thr_layout, val_layout); // Tiler_MN: (16, 128) - - using CopyOpR2S = - UniversalCopy>; - using CopyAtomR2S = cute::Copy_Atom; - auto tiled_copy_r2s = cute::make_tiled_copy( - CopyAtomR2S{}, r2s_thr_layout, r2s_val_layout); // Tiler_MN: (16, 4) - - int max_active_blocks_per_sm = -1; - STD_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks_per_sm, - mxfp8_experts_quant_kernel, - THREAD_BLOCK_SIZE, 0)); - - dim3 grid(get_device_prop()->multiProcessorCount * max_active_blocks_per_sm, - 1, 1); - dim3 block(THREAD_BLOCK_SIZE, 1, 1); - int num_experts = static_cast(problem_sizes.size(0)); - auto stream = get_current_cuda_stream(input.get_device_index()); - mxfp8_experts_quant_kernel - <<>>( - reinterpret_cast(input.data_ptr()), - reinterpret_cast(problem_sizes.data_ptr()), - reinterpret_cast(expert_offsets.data_ptr()), - reinterpret_cast(blockscale_offsets.data_ptr()), - reinterpret_cast(quant_output.data_ptr()), - reinterpret_cast(scale_factor.data_ptr()), num_experts, - tiled_copy_g2r, tiled_copy_r2g, tiled_copy_r2s); -} - -} // namespace expert_specialization \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index c63e59c3b03..58524c4c5db 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -130,25 +130,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? qzeros_or_none, bool inplace) -> Tensor"); // conditionally compiled so impl registrations are in source file -#endif - -#ifndef USE_ROCM - // Expert-specialization mxfp8 blockscaled grouped quantization (SM100+). - ops.def( - "mxfp8_experts_quant(" - " Tensor input, Tensor problem_sizes, Tensor expert_offsets," - " Tensor blockscale_offsets, Tensor! quant_output, Tensor! scale_factor)" - " -> ()"); - // conditionally compiled so impl registration is in source file - - // Expert-specialization mxfp8 blockscaled grouped GEMM (SM100+). - ops.def( - "cutlass_mxfp8_grouped_mm(" - " Tensor a, Tensor b, Tensor sfa, Tensor sfb, Tensor! out," - " Tensor problem_sizes, Tensor expert_offsets, Tensor blockscale_offsets)" - " -> ()"); - // conditionally compiled so impl registration is in source file - #endif } diff --git a/tests/kernels/moe/test_cutlass_mxfp8_grouped_mm.py b/tests/kernels/moe/test_cutlass_mxfp8_grouped_mm.py deleted file mode 100644 index 3a154fbb84c..00000000000 --- a/tests/kernels/moe/test_cutlass_mxfp8_grouped_mm.py +++ /dev/null @@ -1,237 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Adapted from SGLang: -# https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/tests/test_es_fp8_blockwise_moe.py - -"""Tests for SM100 CUTLASS MXFP8 grouped MoE kernels.""" - -import random - -import pytest -import torch - -from tests.kernels.utils import torch_moe_single -from vllm import _custom_ops as ops -from vllm.platforms import current_platform -from vllm.utils.torch_utils import set_random_seed - -random.seed(42) -set_random_seed(42) - - -def align(val: int, alignment: int = 128) -> int: - return int((val + alignment - 1) // alignment * alignment) - - -# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py -def calc_diff(x, y): - x, y = x.double(), y.double() - denominator = (x * x + y * y).sum() - sim = 2 * (x * y).sum() / denominator - return 1 - sim - - -def is_sm100_supported() -> bool: - return current_platform.is_cuda() and current_platform.is_device_capability_family( - 100 - ) - - -def compute_ref_output( - input_tensor: torch.Tensor, - weight_list: list[torch.Tensor], - expert_offsets: list[int], - expert_offset: int, - num_experts: int, -) -> torch.Tensor: - # Build a top-1 routing score so each token maps to its owning expert. - score = torch.full( - (expert_offset, num_experts), - -1e9, - device=input_tensor.device, - dtype=torch.float32, - ) - for g in range(num_experts): - start = expert_offsets[g] - end = expert_offsets[g + 1] if g + 1 < num_experts else expert_offset - score[start:end, g] = 0.0 - - return torch_moe_single( - input_tensor, torch.stack(weight_list, dim=0), score, topk=1 - ) - - -def compute_kernel_output( - input_tensor: torch.Tensor, - weight_tensor: torch.Tensor, - problem_sizes: list[list[int]], - aux_problem_sizes: list[list[int]], - expert_offsets: list[int], - aux_expert_offsets: list[int], - input_blockscale_offsets: list[int], - weight_blockscale_offsets: list[int], - input_blockscale_offset: int, - n_g: int, - k_g: int, - num_experts: int, - expert_offset: int, - out_dtype: torch.dtype, -) -> torch.Tensor: - device = input_tensor.device - _problem_sizes = torch.tensor(problem_sizes).to(device=device, dtype=torch.int32) - _aux_problem_sizes = torch.tensor(aux_problem_sizes).to( - device=device, dtype=torch.int32 - ) - _expert_offsets = torch.tensor(expert_offsets).to(device=device, dtype=torch.int32) - _aux_expert_offsets = torch.tensor(aux_expert_offsets).to( - device=device, dtype=torch.int32 - ) - _input_blockscale_offsets = torch.tensor(input_blockscale_offsets).to( - device=device, dtype=torch.int32 - ) - _weight_blockscale_offsets = torch.tensor(weight_blockscale_offsets).to( - device=device, dtype=torch.int32 - ) - - input_quant = torch.zeros_like( - input_tensor, dtype=torch.float8_e4m3fn, device=device - ) - input_scale_factor = torch.zeros( - (input_blockscale_offset, k_g // 32), dtype=torch.uint8, device=device - ) - - weight_quant = torch.zeros_like( - weight_tensor, dtype=torch.float8_e4m3fn, device=device - ) - weight_scale_factor = torch.zeros( - (num_experts, n_g, k_g // 32), dtype=torch.uint8, device=device - ) - - ops.mxfp8_experts_quant( - input_tensor, - _problem_sizes, - _expert_offsets, - _input_blockscale_offsets, - input_quant, - input_scale_factor, - ) - - ops.mxfp8_experts_quant( - weight_tensor, - _aux_problem_sizes, - _aux_expert_offsets, - _weight_blockscale_offsets, - weight_quant, - weight_scale_factor, - ) - weight_quant = weight_quant.view(num_experts, n_g, k_g).transpose(1, 2) - weight_scale_factor = weight_scale_factor.view( - num_experts, n_g, k_g // 32 - ).transpose(1, 2) - - output = torch.empty((expert_offset, n_g), device=device, dtype=out_dtype) - ops.cutlass_mxfp8_grouped_mm( - input_quant, - weight_quant, - input_scale_factor, - weight_scale_factor, - output, - _problem_sizes, - _expert_offsets, - _input_blockscale_offsets, - ) - return output - - -@pytest.mark.skipif( - not is_sm100_supported(), - reason=( - "cutlass_mxfp8_grouped_mm and mxfp8_experts_quant " - "are only supported on CUDA SM100" - ), -) -@pytest.mark.parametrize("num_experts", [8, 16, 32, 64]) -@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16]) -def test_cutlass_mxfp8_grouped_mm(num_experts, out_dtype): - device = "cuda" - alignment = 128 - n_g = random.randint(1, 64) * alignment - k_g = random.randint(1, 64) * alignment - - expert_offset = 0 - expert_offsets = [] - aux_expert_offset = 0 - aux_expert_offsets = [] - input_blockscale_offset = 0 - input_blockscale_offsets = [] - weight_blockscale_offset = 0 - weight_blockscale_offsets = [] - problem_sizes = [] - aux_problem_sizes = [] - input_list = [] - weight_list = [] - - for g in range(num_experts): - m_g = random.randint(1, 512) - expert_offsets.append(expert_offset) - expert_offset += m_g - aux_expert_offsets.append(aux_expert_offset) - aux_expert_offset += n_g - input_blockscale_offsets.append(input_blockscale_offset) - input_blockscale_offset += align(m_g, 128) - weight_blockscale_offsets.append(weight_blockscale_offset) - weight_blockscale_offset += n_g # n_g already align to 128 - problem_sizes.append([m_g, n_g, k_g]) - aux_problem_sizes.append([n_g, m_g, k_g]) - - input_tensor = torch.normal( - 0.0, std=1.0, size=(m_g, k_g), device=device, dtype=out_dtype - ) # (M, K):(K, 1) - weight_tensor = torch.normal( - 0.0, std=1.0, size=(n_g, k_g), device=device, dtype=out_dtype - ) # (N, K):(K, 1) - - input_list.append(input_tensor) - weight_list.append(weight_tensor) - input_tensor = torch.concat(input_list, dim=0) - weight_tensor = torch.concat(weight_list, dim=0) - - ref_output = compute_ref_output( - input_tensor=input_tensor, - weight_list=weight_list, - expert_offsets=expert_offsets, - expert_offset=expert_offset, - num_experts=num_experts, - ) - output = compute_kernel_output( - input_tensor=input_tensor, - weight_tensor=weight_tensor, - problem_sizes=problem_sizes, - aux_problem_sizes=aux_problem_sizes, - expert_offsets=expert_offsets, - aux_expert_offsets=aux_expert_offsets, - input_blockscale_offsets=input_blockscale_offsets, - weight_blockscale_offsets=weight_blockscale_offsets, - input_blockscale_offset=input_blockscale_offset, - n_g=n_g, - k_g=k_g, - num_experts=num_experts, - expert_offset=expert_offset, - out_dtype=out_dtype, - ) - - for g in range(num_experts): - baseline = ref_output[ - expert_offsets[g] : (expert_offsets[g] + problem_sizes[g][0]) - ] - actual = output[expert_offsets[g] : (expert_offsets[g] + problem_sizes[g][0])] - diff = calc_diff(actual, baseline) - assert diff < 0.001 - print( - f"m_g={baseline.shape[0]} n_g={n_g} k_g={k_g} num_experts={num_experts}, " - f"out_dtype={out_dtype}, diff={diff:.5f}: OK" - ) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index bd8a19b6d2d..ee4f0dc850e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1132,76 +1132,6 @@ def cutlass_mxfp4_moe_mm( ) -def mxfp8_experts_quant( - input_tensor: torch.Tensor, - problem_sizes: torch.Tensor, - expert_offsets: torch.Tensor, - blockscale_offsets: torch.Tensor, - quant_output: torch.Tensor, - scale_factor: torch.Tensor, -) -> None: - torch.ops._C.mxfp8_experts_quant( - input_tensor, - problem_sizes, - expert_offsets, - blockscale_offsets, - quant_output, - scale_factor, - ) - - -def cutlass_mxfp8_grouped_mm( - a_tensors: torch.Tensor, - b_tensors: torch.Tensor, - a_scales: torch.Tensor, - b_scales: torch.Tensor, - out_tensors: torch.Tensor, - problem_sizes: torch.Tensor, - expert_offsets: torch.Tensor, - blockscale_offsets: torch.Tensor, -) -> None: - torch.ops._C.cutlass_mxfp8_grouped_mm( - a_tensors, - b_tensors, - a_scales, - b_scales, - out_tensors, - problem_sizes, - expert_offsets, - blockscale_offsets, - ) - - -if hasattr(torch.ops._C, "mxfp8_experts_quant"): - - @register_fake("_C::mxfp8_experts_quant") - def _mxfp8_experts_quant_fake( - input_tensor: torch.Tensor, - problem_sizes: torch.Tensor, - expert_offsets: torch.Tensor, - blockscale_offsets: torch.Tensor, - quant_output: torch.Tensor, - scale_factor: torch.Tensor, - ) -> None: - return None - - -if hasattr(torch.ops._C, "cutlass_mxfp8_grouped_mm"): - - @register_fake("_C::cutlass_mxfp8_grouped_mm") - def _cutlass_mxfp8_grouped_mm_fake( - a_tensors: torch.Tensor, - b_tensors: torch.Tensor, - a_scales: torch.Tensor, - b_scales: torch.Tensor, - out_tensors: torch.Tensor, - problem_sizes: torch.Tensor, - expert_offsets: torch.Tensor, - blockscale_offsets: torch.Tensor, - ) -> None: - return None - - # gptq_marlin def gptq_marlin_repack( b_q_weight: torch.Tensor,