From 656c705ff1aba56ec1d7e3f43d52fcefe5dea5cc Mon Sep 17 00:00:00 2001
From: Cheng Hang
Date: Mon, 5 Jan 2026 09:44:36 +0800
Subject: [PATCH] [None][feat] sm100 weight-only kernel (#10190)

Signed-off-by: Cheng Hang
---
 .../detail/collective/mixed_input_utils.hpp   |  208 +++
 .../sm100_umma_builder_weightonly.inl         |  294 ++++
 .../collective_builder_sm100_weightonly.hpp   |   42 +
 .../collective_mma_sm100_weightonly.hpp       |   42 +
 .../sm100_mma_warpspecialized_mixed_input.hpp | 1261 +++++++++++++++++
 .../gemm/kernel/fpA_intB_gemm.h               |    4 +-
 .../gemm/kernel/mixed_gemm_B_layout.h         |   28 +-
 .../cutlass_kernels/cutlass_heuristic.cpp     |   32 +-
 .../cutlass_kernels/cutlass_preprocessors.cpp |   23 +-
 .../fpA_intB_gemm/fpA_intB_gemm_template.h    |   26 +-
 .../fpA_intB_gemm_template_sm100.h            |  153 ++
 .../launchers/fpA_intB_launcher_sm100.h       |   39 +
 .../launchers/fpA_intB_launcher_sm100.inl     |  286 ++++
 .../python/generate_kernels.py                |   68 +-
 ...atcherBf16Int4GroupwiseColumnMajorFalse.cu |   32 +
 ...tcherBf16Int4PerChannelColumnMajorFalse.cu |   29 +
 ...atcherBf16Int8GroupwiseColumnMajorFalse.cu |   29 +
 ...tcherBf16Int8PerChannelColumnMajorFalse.cu |   29 +
 ...atcherFp16Int4GroupwiseColumnMajorFalse.cu |   32 +
 ...tcherFp16Int4PerChannelColumnMajorFalse.cu |   29 +
 ...atcherFp16Int8GroupwiseColumnMajorFalse.cu |   29 +
 ...tcherFp16Int8PerChannelColumnMajorFalse.cu |   29 +
 .../weightOnlyBatchedGemv/kernelLauncher.h    |   18 +-
 docs/source/features/quantization.md          |    2 +-
 .../_torch/custom_ops/torch_custom_ops.py     |    4 +-
 tensorrt_llm/quantization/functional.py       |    6 +-
 26 files changed, 2751 insertions(+), 23 deletions(-)
 create mode 100644 cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm100_umma_builder_weightonly.inl
 create mode 100644 cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_sm100_weightonly.hpp
 create mode 100644 cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_sm100_weightonly.hpp
 create mode 100644 cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp
 create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm100.h
 create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.h
 create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.inl
 create mode 100644 cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4GroupwiseColumnMajorFalse.cu
 create mode 100644 cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4PerChannelColumnMajorFalse.cu
 create mode 100644 cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8GroupwiseColumnMajorFalse.cu
 create mode 100644 cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8PerChannelColumnMajorFalse.cu
 create mode 100644 cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4GroupwiseColumnMajorFalse.cu
 create mode 100644 cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4PerChannelColumnMajorFalse.cu
 create mode 100644 cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8GroupwiseColumnMajorFalse.cu
 create mode 100644 cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8PerChannelColumnMajorFalse.cu

diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp
b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp index 53dc9e053a..bc3591ad3a 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp @@ -686,4 +686,212 @@ public: } }; +template +struct MixedInputUtilsSM100 +{ +private: + using KernelSchedule = typename Collective::KernelSchedule; + using ConversionMode = typename Collective::ConversionMode; + using SmemLayoutA = typename Collective::SmemLayoutA; + using SmemLayoutB = typename Collective::SmemLayoutB; + using ElementScale = typename Collective::ElementScale; + using ElementZero = typename Collective::ElementZero; + static constexpr auto KernelConversionMode = Collective::KernelConversionMode; + +public: + // Helper functions to select packing for conversion + template + struct select_packing + { // Naive packing policy + + static constexpr auto value() + { + return Int, sizeof_bits_v))>{}; + } + }; + + /// (Designed for separate transform pipeline in Blackwell) + /// Utilities to dequantize A. + template + CUTLASS_DEVICE static void dequantize_A_kblock_for_transform(Tensor const& tArA, + Tensor& tArACompute, cute::tuple const& partitioned_extra_info, int const k_block) + { + + static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); + static_assert(cosize_v == cosize_v); + static_assert(size_v == cosize_v); + static_assert(size_v == cosize_v); + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + auto src = tArA(_, _, _, k_block); + auto dst = tArACompute(_, _, _, k_block); + auto pSrc = raw_pointer_cast(src.data()); + auto pDst = const_cast(raw_pointer_cast(dst.data())); + constexpr int num_elements = decltype(size(src))::value; + + constexpr int pack = decltype(select_packing::value())::value; + using Converter + = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + constexpr int DstElementsPerReg = 32 / sizeof_bits_v; + using RegArray = cutlass::AlignedArray; + + auto src_arr = recast(src); + auto dst_arr = recast(dst); + + Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, pack)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) + { + cute::transform(src_arr, dst_arr, Converter::convert); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) + { + + auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, _, k_block); + + CUTE_STATIC_ASSERT_V(size(src) == size(scales)); + + if constexpr (is_same_v) + { + cute::transform(src_arr, dst_arr, Converter::convert); + + using ScaleArray = cutlass::Array; + auto scale_arr = recast(filter_zeros(scales)); + + if constexpr (is_same_v) + { + Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, pack)); + + for (int i = 0; i < size<1>(dst_vm); ++i) + { + auto&& r = cute::recast(dst_vm(_, i))(0); + auto&& scale_reg = cute::recast(scales_vm(_, i))(0); + CUTLASS_PRAGMA_UNROLL + for (size_t ii = 0; ii < RegArray::kElements; ++ii) + { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hmul2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(scale_reg[ii])); + } + } + } + else + { + cute::transform(dst_arr, scale_arr, dst_arr, 
cute::multiplies{}); + } + } + else + { + constexpr int pack1 = decltype(select_packing::value())::value; + constexpr int pack2 = decltype(select_packing::value())::value; + constexpr int pack = cute::gcd(pack1, pack2); + using Converter1 = cutlass::NumericArrayConverter; + using Converter2 = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using StageArray = cutlass::Array; + constexpr int iters = num_elements / pack; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; ++i) + { + SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; + DstArray* pDstArr = reinterpret_cast(pDst) + i; + StageArray stageArr; + stageArr = Converter1::convert(*pSrcArr); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < pack; ++j) + { + stageArr[j] = stageArr[j] * scales[i * pack + j]; + } + *pDstArr = Converter2::convert(stageArr); + } + } + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) + { + static_assert(is_same_v, "ElementScale and ElementZero must be the same."); + + auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, _, k_block); + auto const& zeros = cute::get<3>(partitioned_extra_info)(_, _, _, k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(scales)); + CUTE_STATIC_ASSERT_V(size(src) == size(zeros)); + + if constexpr (is_same_v) + { + cute::transform(src_arr, dst_arr, Converter::convert); + + using ScaleArray = cutlass::Array; + auto scale_arr = recast(filter_zeros(scales)); + + using ZeroArray = cutlass::Array; + auto zero_arr = recast(filter_zeros(zeros)); + + if constexpr (is_same_v) + { + Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, pack)); + Tensor zeros_vm = cute::group_modes<1, -1>(cute::zipped_divide(zeros, pack)); + + for (int i = 0; i < size<1>(dst_vm); ++i) + { + auto&& r = cute::recast(dst_vm(_, i))(0); + auto&& scale_reg = cute::recast(scales_vm(_, i))(0); + auto&& zero_reg = cute::recast(zeros_vm(_, i))(0); + CUTLASS_PRAGMA_UNROLL + for (size_t ii = 0; ii < RegArray::kElements; ++ii) + { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hmul2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(scale_reg[ii])); + bf16x2_val = __hadd2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(zero_reg[ii])); + } + } + } + else + { + cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{}); + cute::transform(dst_arr, zero_arr, dst_arr, cute::plus{}); + } + } + else + { + constexpr int pack1 = decltype(select_packing::value())::value; + constexpr int pack2 = decltype(select_packing::value())::value; + constexpr int pack = cute::gcd(pack1, pack2); + using Converter1 = cutlass::NumericArrayConverter; + using Converter2 = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using StageArray = cutlass::Array; + constexpr int iters = num_elements / pack; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; ++i) + { + SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; + DstArray* pDstArr = reinterpret_cast(pDst) + i; + StageArray stageArr; + stageArr = Converter1::convert(*pSrcArr); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < pack; ++j) + { + stageArr[j] = stageArr[j] * scales[i * pack + j] + zeros[i * pack + j]; + } + *pDstArr = Converter2::convert(stageArr); + } + } + } + else + { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for input partitioning."); + } + } +}; } // namespace cutlass::gemm::collective::detail diff --git 
a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm100_umma_builder_weightonly.inl b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm100_umma_builder_weightonly.inl new file mode 100644 index 0000000000..2cfe7d36d5 --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm100_umma_builder_weightonly.inl @@ -0,0 +1,294 @@ +/* + * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail +{ + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr cute::tuple sm100_compute_stage_count_or_override_weightonly(StageCount stage_count) +{ + constexpr int Load2TransformStageCount = stages; + constexpr int Transform2MmaStageCount = stages; + constexpr int AccumulatorStageCount = stages; + return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount); +} + +template +constexpr cute::tuple sm100_compute_stage_count_or_override_weightonly( + StageCountAutoCarveout stage_count) +{ + + constexpr int CtaM = get<0>(CtaTileShape_MNK{}); + constexpr int CtaN = get<1>(CtaTileShape_MNK{}); + static_assert(CtaN <= 128, "Can't support CtaN>128 tiles"); + constexpr int CtaK = get<2>(CtaTileShape_MNK{}); + using AtomThrID = typename TiledMma::AtomThrID; + + constexpr int TmemColumns = 512; + + constexpr bool IsAComputeinTmem = UmmaMajorA == cute::UMMA::Major::K + && !cute::is_base_of_v; + constexpr bool IsAComputeinSmem = !IsAComputeinTmem; + + // Detect 2x2 TMEM layout + constexpr int TmemAccWordsPerDP = (CtaM == 64 && size(AtomThrID{}) == 2) ? CtaN / 2 : CtaN; + constexpr int TmemAWordsPerDP = CtaK / 2; + + constexpr int AccumulatorStageCount + = (IsAComputeinTmem) ? ((TmemAccWordsPerDP == 128) ? 2 : 3) : (TmemColumns / TmemAccWordsPerDP); + + constexpr int SmemCapacityAfterMma2AccumCarveout = CapacityBytes - (carveout_bytes + AccumulatorStageCount * 32); + + constexpr int TmemInAStageCount_Potential + = (IsAComputeinTmem) ? (TmemColumns - AccumulatorStageCount * TmemAccWordsPerDP) / TmemAWordsPerDP : 10000; + + // Mainload2Transform Pipeline + constexpr auto load2transform_pipeline_bytes + = sizeof(typename cutlass::PipelineTmaTransformAsync<1>::SharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; // ElementA introduce here + constexpr auto s_bits = cute::is_void_v ? 0 : cute::sizeof_bits_v; + constexpr auto z_bits = cute::is_void_v ? 
0 : cute::sizeof_bits_v; + + constexpr auto load2mma_pipeline_bytes = sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage); + constexpr auto b_bits = cute::sizeof_bits_v; // ElementB introduce here + + constexpr int ab_stage_bytes + = cutlass::bits_to_bytes(a_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{})) + + cutlass::bits_to_bytes(s_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK) + + cutlass::bits_to_bytes(z_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK) + + cutlass::bits_to_bytes(b_bits * size<1>(CtaTileShape_MNK{}) / size(AtomThrID{}) * size<2>(CtaTileShape_MNK{})) + + static_cast(load2transform_pipeline_bytes) + static_cast(load2mma_pipeline_bytes); + + // Transform2Mma Pipeline + constexpr auto transform2mma_pipeline_bytes = sizeof(typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage); + constexpr auto a_compute_bits = cute::sizeof_bits_v; + constexpr int ab_compute_stage_bytes = cutlass::bits_to_bytes(a_compute_bits * int(IsAComputeinSmem) + * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{})) + + // If ACompute is in TMEM, Acompute buffer has 0 bytes. + static_cast(transform2mma_pipeline_bytes); + + constexpr int ABComputeStageCount_Potential + = SmemCapacityAfterMma2AccumCarveout / (ab_stage_bytes + ab_compute_stage_bytes); + + // The number of SMEM buffers for A, B. ACompute (if in SMEM), BCompute should be at least Transform2MmaStageCount + constexpr int Transform2MmaStageCount = std::min(TmemInAStageCount_Potential, ABComputeStageCount_Potential); + + constexpr int SmemCapacityAfterABComputeCarveout + = SmemCapacityAfterMma2AccumCarveout - (Transform2MmaStageCount * ab_compute_stage_bytes); + + // Can we boost the number of buffers for A and B? 
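The budgeting above, together with the Load2TransformStageCount division that follows, splits the SM100 on-chip budget three ways: TMEM columns bound the accumulator stages, the SMEM left after the accumulator-barrier carveout bounds the Transform2Mma (converted-A/B) stages, and whatever SMEM remains is handed to the raw Load2Transform stages. A standalone sketch of that arithmetic with plain integers; the tile shape, byte counts, carveouts, and capacities below are illustrative assumptions (a 128x128x128 tile with bf16 activations, int4 weights, and 128-wide scale groups, with converted A kept in SMEM), not the builder's exact constants:

// Illustrative stage budgeting (assumed numbers; see the builder above for the real formula).
#include <algorithm>
#include <cstdio>

int main() {
    constexpr int CtaM = 128, CtaN = 128, CtaK = 128;   // assumed CTA tile
    constexpr int smem_capacity = 227 * 1024;           // assumed usable SMEM
    constexpr int epilogue_carveout = 32 * 1024;        // assumed epilogue reservation
    constexpr int tmem_columns = 512;

    // Accumulator stages limited by TMEM columns per accumulator tile
    // (1SM layout, converted A kept in SMEM rather than TMEM).
    constexpr int acc_words_per_dp = CtaN;
    constexpr int accum_stages = tmem_columns / acc_words_per_dp;

    // Bytes per stage of the raw inputs: bf16 A, int4 B, one bf16 scale per 128 K elements.
    constexpr int a_bytes = CtaM * CtaK * 2;
    constexpr int b_bytes = CtaN * CtaK / 2;
    constexpr int s_bytes = CtaM * (CtaK / 128) * 2;
    constexpr int barrier_bytes = 32;                   // rough per-stage pipeline barrier cost
    constexpr int ab_stage_bytes = a_bytes + b_bytes + s_bytes + barrier_bytes;

    // Bytes per stage of the converted A (ACompute) when it lives in SMEM.
    constexpr int a_compute_stage_bytes = CtaM * CtaK * 2 + barrier_bytes;

    int smem_left = smem_capacity - epilogue_carveout - accum_stages * 32;
    int transform2mma_stages = smem_left / (ab_stage_bytes + a_compute_stage_bytes);
    int load2transform_stages =
        (smem_left - transform2mma_stages * a_compute_stage_bytes) / ab_stage_bytes;

    std::printf("accum=%d transform2mma=%d load2transform=%d\n",
                accum_stages, transform2mma_stages,
                std::max(load2transform_stages, transform2mma_stages));
    return 0;
}

With these assumed numbers the model lands on 4 accumulator stages, 2 transform stages, and 3 load stages; the real builder additionally caps the transform stages by the TMEM space left for ACompute whenever the converted A operand lives in TMEM instead of SMEM.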
+ constexpr int Load2TransformStageCount = SmemCapacityAfterABComputeCarveout / ab_stage_bytes; + + static_assert(Load2TransformStageCount >= 2 && Transform2MmaStageCount >= 2 && AccumulatorStageCount >= 2, + "Not enough SMEM or TMEM capacity for selected tile size"); + return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount); +} + +} // namespace detail + +// Mixed Input MMA kernels builder +template +struct CollectiveBuilderSm100WeightOnly) &&( + (sizeof(float) * AlignmentA) % detail::tma_alignment_bytes == 0) + && ((sizeof(float) * AlignmentB) % detail::tma_alignment_bytes == 0)>> +{ + using GmemLayoutATag = detail::deduce_mixed_width_dtype_t<0, GmemLayoutATagTuple>; + using GmemLayoutScaleTag = detail::deduce_mixed_width_dtype_t<1, GmemLayoutATagTuple>; + + static constexpr cute::UMMA::Major UmmaMajorA + = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB + = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + using ElementScale = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>; + using ElementZero = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + + static constexpr bool NeitherIsTuple + = !cute::is_tuple::value && !cute::is_tuple::value; + static constexpr bool IsANarrow = cute::sizeof_bits_v < cute::sizeof_bits_v; + static constexpr bool IsMixedInput = cute::sizeof_bits_v != cute::sizeof_bits_v; + static_assert(IsMixedInput, "Mixed Input GEMM Kernel doesn't support regular gemm."); + + static_assert( + (cute::is_tuple::value ^ cute::is_tuple::value + || (NeitherIsTuple && (cute::sizeof_bits::value != cute::sizeof_bits::value))), + "Either A OR B must be a tuple or the widths of A and B must be different."); + using ElementPairA = cute::conditional_t, + ElementAOptionalTuple>; + using ElementPairB = cute::conditional_t, + ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + static_assert(IsATransformed, "A matrix should be transformed."); + + // For fp32 types, map to tf32 MMA value type. 
+ using ElementMma = cute::conditional_t, tfloat32_t, ElementB>; + + using ElementAMma = ElementMma; + using ElementBMma = ElementMma; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + + static constexpr int ScalingFactor = 1; + + using TiledMma = decltype(detail::sm100_make_trivial_mixed_input_tiled_mma()); + using AtomThrID = typename TiledMma::AtomThrID; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + using CtaTileShape_MNK = decltype(shape_div(TileShape_MNK{}, AtomThrShapeMNK{})); + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A( + TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B( + TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), cute::size<2>(TileShape_MNK{})))); + + using BlockTileA_M = decltype(cute::size<0, 0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{})); + using BlockTileA_K = decltype(cute::size<0, 1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(cute::size<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm100_cluster_shape_to_tma_atom_B(ClusterShape_MNK{}, AtomThrID{})); + + // Input transform kernel can not use TMA 2SM instructions. + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomACompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType; + static constexpr int MMA_M = cute::size<0, 0>(MmaShapeA_MK{}); + using CopyAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType< + Copy_Atom, ElementA>, + cute::conditional_t< + (UmmaMajorA == cute::UMMA::Major::K + && !cute::is_base_of_v), + cute::conditional_t<(MMA_M == 64 && size(AtomThrID{}) == 1), SM100_TMEM_STORE_16dp256b1x, + SM100_TMEM_STORE_32dp32b8x>, // TS Implementation + Copy_Atom, ElementA>> // SS Implementation + >; + + using BlockTileB_N = decltype(cute::size<0, 0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0, 1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + + // Input transform kernel can not use TMA 2SM instructions. 
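The builder deduces ElementScale and ElementZero from the optional members of the A-side tuple, and the mainloop (get_conversion_mode in sm100_mma_warpspecialized_mixed_input.hpp, further below) turns their void-ness into a conversion mode: no scale means DirectConvert, a scale without a zero means ConvertAndScale, and both means ConvertAndScaleWithZero. A compile-time sketch of that selection using stand-in types rather than the CUTLASS traits; the enum and helper names here are illustrative only:

// Stand-in for the void-based conversion-mode selection; types are illustrative only.
#include <cstdio>
#include <type_traits>

enum class ConversionMode { DirectConvert, ConvertAndScale, ConvertAndScaleWithZero };

template <class ElementScale, class ElementZero>
constexpr ConversionMode select_conversion_mode() {
    if constexpr (std::is_void_v<ElementScale>) {
        return ConversionMode::DirectConvert;           // weight-only cast, no scaling
    } else if constexpr (std::is_void_v<ElementZero>) {
        return ConversionMode::ConvertAndScale;         // per-channel / per-group scales
    } else {
        return ConversionMode::ConvertAndScaleWithZero; // scales plus zero points
    }
}

int main() {
    static_assert(select_conversion_mode<void, void>() == ConversionMode::DirectConvert);
    static_assert(select_conversion_mode<float, void>() == ConversionMode::ConvertAndScale);
    static_assert(select_conversion_mode<float, float>() == ConversionMode::ConvertAndScaleWithZero);
    std::printf("conversion modes resolved at compile time\n");
    return 0;
}

Encoding the mode in the type system this way lets every scale/zero branch in the mainloop be resolved with if constexpr, so the unused paths cost nothing at runtime.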
+ using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomBCompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType; + using CopyAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType< + Copy_Atom, ElementB>, + Copy_Atom, ElementMma>>; + + // Creating the stride of Transformed Input + using StrideA = cutlass::gemm::TagToStrideA_t; + using LayoutScale = cutlass::gemm::TagToStrideA_t; + + using VoidShapeScale + = Shape, _1>, Shape, _1>, _1>; // Dummy Value to create a dummy ScaleConfig + using VoidStrideScale = Stride, Stride<_0, _1>, _1>; + using VoidLayoutScale = Layout; + + using NonVoidLayoutScale = cute::conditional_t, VoidLayoutScale, LayoutScale>; + + using StridePairA = decltype(cute::make_tuple(StrideA{}, NonVoidLayoutScale{})); + + // SmemCarveout + static constexpr int SchedulerPipelineStageCount = 3; + static constexpr bool IsArrayOfPointersGemm + = (cute::is_base_of_v); + + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage + = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // CLC Throttle pipeline storage + static constexpr auto CLCThrottlePipelineStorage + = sizeof(typename cutlass::PipelineAsync::SharedStorage); + // Tmem dealloc + static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); + // Tmem ptr storage + static constexpr auto TmemBasePtrsStorage = sizeof(uint32_t); + // Tensormap Storage + static constexpr size_t TensorMapStorage + = IsArrayOfPointersGemm ? 
sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast(CLCPipelineStorage + CLCResponseStorage + + CLCThrottlePipelineStorage + TmemDeallocStorage + TmemBasePtrsStorage + TensorMapStorage); + + // Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations + static constexpr int Sm100ReducedSmemCapacityBytes = detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int ScaleGranularityK = get_ScaleGranularityK(); + + static constexpr auto stage_info + = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_weightonly< + Sm100ReducedSmemCapacityBytes, TmaElementA, ElementAMma, ElementScale, ElementZero, ElementB, + CtaTileShape_MNK, TiledMma, KernelScheduleType, UmmaMajorA, ScaleGranularityK>(StageCountType{}); + + static constexpr int Load2TransformPipelineStageCount = get<0>(stage_info); + static constexpr int Transform2MmaPipelineStageCount = get<1>(stage_info); + static constexpr int AccumulatorPipelineStageCount = get<2>(stage_info); + + static_assert(!IsArrayOfPointersGemm, "mixed input does not support grouped gemm on Blackwell"); + + using DispatchPolicy + = cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedMixedInput; + using CollectiveOp = cutlass::gemm::collective::CollectiveMmaSm100WeightOnly, TiledMma, + GmemTiledCopyA, SmemLayoutAtomPairA, CopyAtomPairA, cute::identity, GmemTiledCopyB, SmemLayoutAtomPairB, + CopyAtomPairB, cute::identity>; +}; + +} // namespace cutlass::gemm::collective diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_sm100_weightonly.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_sm100_weightonly.hpp new file mode 100644 index 0000000000..837e3d5375 --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_sm100_weightonly.hpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass_extensions/gemm/collective/collective_mma_sm100_weightonly.hpp" + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct CollectiveBuilderSm100WeightOnly +{ + static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/collective/builders/sm100_umma_builder_weightonly.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_sm100_weightonly.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_sm100_weightonly.hpp new file mode 100644 index 0000000000..2ffabb970b --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_sm100_weightonly.hpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct CollectiveMmaSm100WeightOnly +{ + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp new file mode 100644 index 0000000000..8cfb3cc692 --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp @@ -0,0 +1,1261 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once +#include + +#include "cutlass/cutlass.h" +#include "cutlass/detail/blockwise_scale_layout.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/detail/collective/mixed_input_utils.hpp" +#include "cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass_extensions/detail/collective/mixed_input_utils.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/mma_sm100.hpp" +#include "cute/atom/copy_atom.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop for Mixed Input Kernels +template +struct CollectiveMmaSm100WeightOnly< + MainloopSm100TmaUmmaWarpSpecializedMixedInput, + TileShape_, ElementAOptionalTuple_, StridePairA_, ElementBOptionalTuple_, StrideB_, TiledMma_, GmemTiledCopyA_, + SmemLayoutAtomsA_, CopyAtomsA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomsB_, CopyAtomsB_, TransformB_> +{ +public: + // + // Type Aliases + // + + using ConversionMode = cutlass::detail::ConversionMode; + // Determine MMA type: MMA_1SM vs MMA_2SM + using AtomThrShapeMNK = Shape(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>; + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedMixedInput; + using TileShape = TileShape_; + using TiledMma = TiledMma_; + using KernelSchedule = typename DispatchPolicy::Schedule; + static constexpr bool IsDynamicCluster = not cute::is_static_v; + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + using ElementAOptionalTuple = ElementAOptionalTuple_; + using ElementBOptionalTuple = ElementBOptionalTuple_; + +private: + template + friend struct detail::MixedInputUtils; + using CollectiveType = CollectiveMmaSm100WeightOnly; + using Utils = detail::MixedInputUtils; + using UtilsSM100 = detail::MixedInputUtilsSM100; + + using ElementScaleA = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple_>; + using ElementScaleB = detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>; + using ElementZeroA = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + using ElementZeroB = detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>; + +public: + static_assert(cute::is_tuple::value ^ cute::is_tuple::value, + "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale]," + "[ElementZero]}. Inputs in [] are optional."); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is + // void. 
+ using NonVoidElementScale = cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = cute::conditional_t, float, ElementZero>; + + using StrideA = cute::remove_cvref_t(StridePairA_{}))>; + using LayoutScale = cute::remove_cvref_t(StridePairA_{}))>; + using InternalStrideA = cute::remove_pointer_t; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + + static_assert((IsATransformed && cutlass::gemm::detail::is_k_major()) + || (!IsATransformed && cutlass::gemm::detail::is_k_major()), + "The transformed type must be K-major."); + + static_assert((IsATransformed && (sizeof(ElementB) == 2)) || (!IsATransformed && (sizeof(ElementA) == 2)) + || (cutlass::gemm::detail::is_k_major() && cutlass::gemm::detail::is_k_major()), + "The unscaled element must be 2 bytes OR both inputs must be K-major"); + + // Define A and B block shapes for reduced size TMA_LOADs + using CtaShapeA_MK + = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using CtaShapeB_NK + = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + + using ElementAccumulator = typename TiledMma::ValTypeC; + + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using GmemTiledCopyScale = GmemTiledCopyA_; + + using SmemLayoutAtomsA = SmemLayoutAtomsA_; + using SmemLayoutAtomsB = SmemLayoutAtomsB_; + using CopyAtomsA = CopyAtomsA_; + using CopyAtomsB = CopyAtomsB_; + using SmemCopyAtomScale = Copy_Atom; + + using SmemLayoutAtomA = typename SmemLayoutAtomsA::InputLayoutAtom; + using SmemLayoutAtomACompute = typename SmemLayoutAtomsA::ComputeLayoutAtom; + using SmemLayoutAtomB = typename SmemLayoutAtomsB::InputLayoutAtom; + using SmemLayoutAtomBCompute = typename SmemLayoutAtomsB::ComputeLayoutAtom; + + using InputCopyAtomA = typename CopyAtomsA::InputCopyAtom; + using ComputeCopyAtomA = typename CopyAtomsA::ComputeCopyAtom; + using InputCopyAtomB = typename CopyAtomsB::InputCopyAtom; + using ComputeCopyAtomB = typename CopyAtomsB::ComputeCopyAtom; + + // We must ensure the type to be scaled goes to RF + static constexpr bool SwapAB = !IsATransformed; + using InternalSmemLayoutAtomA = cute::conditional_t; + using InternalSmemLayoutAtomB = cute::conditional_t; + using InternalSmemLayoutAtomACompute = cute::conditional_t; + using InternalSmemLayoutAtomBCompute = cute::conditional_t; + + using InternalInputCopyAtomA = cute::conditional_t; + using InternalInputCopyAtomB = cute::conditional_t; + using InternalComputeCopyAtomA = cute::conditional_t; + using InternalComputeCopyAtomB = cute::conditional_t; + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
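These scale and zero element types feed the register-level dequantization added to mixed_input_utils.hpp at the top of this patch (MixedInputUtilsSM100::dequantize_A_kblock_for_transform): the quantized value is converted to the compute type, multiplied by its group's scale, and, in the zero-point mode, the pre-scaled zero is added afterwards. A scalar reference of that math, assuming int4 weights, float scales, and a small group size chosen for illustration; the kernel itself performs it on packed bf16x2 pairs with __hmul2/__hadd2:

// Scalar reference for the register dequantization path (illustrative, not the kernel code).
#include <cstdio>
#include <vector>

// One quantization group: every `group_size` values along K share one scale (and zero).
float dequantize(int q, float scale, float zero_point_scaled) {
    // Convert first, then scale, then (optionally) add the pre-scaled zero:
    // w = convert(q) * s + z, matching ConvertAndScaleWithZero.
    return static_cast<float>(q) * scale + zero_point_scaled;
}

int main() {
    const int group_size = 4;                                 // assumed small group
    std::vector<int>   q     = {-8, -1, 0, 7, 3, 3, -4, 5};   // int4 range [-8, 7]
    std::vector<float> scale = {0.05f, 0.02f};                // one per group
    std::vector<float> zero  = {0.10f, -0.06f};               // pre-scaled zero points

    for (size_t i = 0; i < q.size(); ++i) {
        size_t g = i / group_size;
        std::printf("w[%zu] = %.4f\n", i, dequantize(q[i], scale[g], zero[g]));
    }
    return 0;
}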
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using RealSwappedElementA = cute::conditional_t; + using RealSwappedElementB = cute::conditional_t; + using SwappedElementA = cute::conditional_t; + using SwappedElementB = cute::conditional_t; + using SwappedStrideA = cute::conditional_t; + using SwappedStrideB = cute::conditional_t; + using InternalSwappedStrideA = cute::conditional_t; + using InternalSwappedStrideB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using InternalTransformA = cute::conditional_t; + using InternalTransformB = cute::conditional_t; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + using TmaElementScale + = uint_bit_t>; // in case we have array. translating to uint to satisfy tma + // descriptor's specialization + + using ArchTag = typename DispatchPolicy::ArchTag; + static_assert(cute::is_same_v || cute::is_same_v + || cute::is_same_v, + "Compute type A should be cutlass::bfloat16_t or cutlass::half_t or cutlass::float_e4m3_t"); + + using Load2TransformPipeline + = cutlass::PipelineTmaTransformAsync; + using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + + using Load2MmaPipeline = cutlass::PipelineTmaUmmaAsync; + using Load2MmaPipelineState = typename Load2MmaPipeline::PipelineState; + + using Transform2MmaPipeline + = cutlass::PipelineUmmaConsumerAsync; + using Transform2MmaPipelineState = typename Transform2MmaPipeline::PipelineState; + + using Mma2AccumPipeline + = cutlass::PipelineUmmaAsync; + using Mma2AccumPipelineState = typename Mma2AccumPipeline::PipelineState; + + static constexpr int ScaleGranularityMN = size<0, 0>(LayoutScale{}); + static constexpr int ScaleGranularityK = size<1, 0>(LayoutScale{}); + using ScaleConfig = cutlass::detail::Sm100MixedInputBlockwiseScaleConfig; + + using ScaleTileShape + = cute::conditional_t(TileShape{}), size<2>(TileShape{}))), + decltype(make_shape(size<1>(TileShape{}), size<2>(TileShape{})))>; + + using SmemLayoutAtomScaleFull = decltype(ScaleConfig::smem_atom_layout_scale(ScaleTileShape{})); + + // Getting the SmemSizeMN and SmemSizeK from the mixed_dtype blockwise utils. 
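The scale layout above stores one value per ScaleGranularityK elements of K; at runtime to_underlying_arguments (further below) derives reload_factor = ceil(group_size / TileK), load_A re-issues the same scale K-slice for that many consecutive K tiles (scale_load_k = k_tile / reload_factor), and can_implement requires the group size to cover whole K tiles. A small driver sketching that indexing, with assumed tile and group sizes:

// How a K-tile index maps to the scale slice it consumes (illustrative driver).
#include <cassert>
#include <cstdio>

int main() {
    // Assumed shapes: 128-wide K tiles, group size 256, K = 1024.
    const int tile_k = 128;
    const int group_size = 256;
    const int k_tiles = 8;

    // Mirrors to_underlying_arguments: ceil-divide group_size by the K tile extent.
    const int reload_factor = (group_size + tile_k - 1) / tile_k;

    // Mirrors can_implement: groups must cover whole K tiles (or span all of K).
    assert(group_size % tile_k == 0 || group_size == k_tiles * tile_k);

    for (int k_tile = 0; k_tile < k_tiles; ++k_tile) {
        // Mirrors load_A: the same scale K-slice is re-issued for reload_factor tiles.
        const int scale_load_k = k_tile / reload_factor;
        std::printf("k_tile %d -> scale slice %d\n", k_tile, scale_load_k);
    }
    return 0;
}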
+ using SmemLayoutAtomScale + = decltype(slice(make_coord(make_coord(_, 0), make_coord(_, 0)), SmemLayoutAtomScaleFull{})); + + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomScale{}) == 2, "SmemLayoutAtomScale must be rank 2"); + static_assert( + (size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, + "SmemLayoutAtomScale must evenly divide tile k shape."); + + // Thread Counts + static constexpr uint32_t NumTransformationThreads = 128; + static constexpr uint32_t NumAccumThreads = 128; // Maintains compatibility with input_transform kernel + + // Get the Algorithm parameters + constexpr static int AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + constexpr static int StagesPerTile = size<2>(CtaShapeA_MK{}); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0, 0>(CtaShapeA_MK{}) * size<1>(CtaShapeA_MK{})) % size<0>(SmemLayoutAtomACompute{})) == 0, + "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0, 1>(CtaShapeA_MK{}) * size<2>(CtaShapeA_MK{})) % size<1>(SmemLayoutAtomACompute{})) == 0, + "SmemLayoutAtomCompute must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0, 0>(CtaShapeB_NK{}) * size<1>(CtaShapeB_NK{})) % size<0>(SmemLayoutAtomBCompute{})) == 0, + "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0, 1>(CtaShapeB_NK{}) * size<2>(CtaShapeB_NK{})) % size<1>(SmemLayoutAtomBCompute{})) == 0, + "SmemLayoutAtomCompute must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ using SmemLayoutA = decltype(UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}))); + + using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape(SmemLayoutAtomACompute{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}))); + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, + append(CtaShapeB_NK{}, Int{}), + (cute::conditional_t(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}))); + + using SmemLayoutScale = decltype(UMMA::tile_to_mma_shape(SmemLayoutAtomScale{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}))); + + static_assert(DispatchPolicy::Load2TransformPipelineStageCount >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert((cute::is_base_of::value + || cute::is_base_of::value) + && cute::is_base_of::value, + "MMA atom must A operand from SMEM or TMEM and B operand from SMEM for this mainloop."); + static_assert( + (cute::is_same_v || cute::is_same_v), + "GmemTiledCopyA - invalid TMA copy atom specified."); + +private: + static constexpr ConversionMode get_conversion_mode() + { + if constexpr (cute::is_void_v) + { + return ConversionMode::DirectConvert; + } + else if constexpr (cute::is_void_v) + { + return ConversionMode::ConvertAndScale; + } + else + { + return ConversionMode::ConvertAndScaleWithZero; + } + } + +public: + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale + || KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + static constexpr bool UseScaleLookupTable + = KernelConversionMode == ConversionMode::ConvertAndScale && cutlass::detail::is_Array_v; + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + + // Just pick the max alignment of A and B since it is required to be at least 128B + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct PipelineStorage + { + using Load2TransformPipelineStorage = typename Load2TransformPipeline::SharedStorage; + alignas(16) Load2TransformPipelineStorage load2transform_pipeline; + using Load2MmaPipelineStorage = typename Load2MmaPipeline::SharedStorage; + alignas(16) Load2MmaPipelineStorage load2mma_pipeline; + using Transform2MmaPipelineStorage = typename Transform2MmaPipeline::SharedStorage; + alignas(16) Transform2MmaPipelineStorage transform2mma_pipeline; + using Mma2AccumPipelineStorage = typename Mma2AccumPipeline::SharedStorage; + alignas(16) Mma2AccumPipelineStorage mma2accum_pipeline; + }; + + struct SharedStorage + { + static constexpr int scale_elements = Utils::elements_per_smem_scale(); + static constexpr int zero_elements = Utils::elements_per_smem_zero(); + + struct TensorStorage : cute::aligned_struct<128, _0> + { + + struct TensorStorageUntransformed + { + alignas(512) cute::ArrayEngine> smem_A; + alignas(1024) cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + }; + + struct TensorStorageTransformedAinSmem + { + // We require alignas(1024) here because the smem_ACompute may not be aligned to 1024 by default. 
+ // We need 1024B alignment of smem_ACompute because we are using Swizzle<3,4,3> here. + // The Swizzle<3,4,3> aligns with 1024B. If we don't align the data, the compiler cannot deduce + // the base pointer of the data. + // This alignment allows us to perform the function swizzle(layout(i) * base_ptr). + alignas(1024) cute::ArrayEngine> smem_ACompute; + }; + + union TensorStorageTransformedAinTmem + { + cute::ArrayEngine smem_ACompute; // No smem_ACompute + }; + + using TensorStorageTransformed = cute::conditional_t< + cute::is_base_of::value, + TensorStorageTransformedAinSmem, TensorStorageTransformedAinTmem>; + + TensorStorageUntransformed input; + TensorStorageTransformed compute; + } tensors; + + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + + // Different from other GEMM kernels, both CTAs should be aware of loads. Both CTAs will work on + // loaded input A and B matrices to convert the data type + static constexpr uint32_t TmaTransactionBytes_A + = cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + Utils::compute_tma_transaction_bytes_extra_transform(); + static constexpr uint32_t TmaTransactionBytes_B = cutlass::bits_to_bytes( + size(AtomThrShapeMNK{}) * cosize(take<0, 3>(SmemLayoutB{})) * cute::sizeof_bits_v); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytes_A + TmaTransactionBytes_B; + + // Host side kernel arguments + struct Arguments + { + ElementA const* ptr_A{nullptr}; + StrideA dA{}; + ElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementScale const* ptr_S{nullptr}; + LayoutScale layout_S{}; + int group_size = 0; + ElementZero const* ptr_Z{nullptr}; + }; + + struct TMAScaleParams + { + using ClusterLayout_VMNK + = decltype(tiled_divide(make_layout(conditional_return( + make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_Scale = decltype(make_tma_atom_A_sm100(GmemTiledCopyScale{}, + make_tensor(static_cast(nullptr), LayoutScale{}), + SmemLayoutScale{}(_, _, _, cute::Int<0>{}), TileShape{}, TiledMma{}, ClusterLayout_VMNK{})); + + TMA_Scale tma_load_scale; + TMA_Scale tma_load_zero; + }; + + struct EmptyScaleParams + { + }; + + // Device side kernel params + struct Params : public cute::conditional_t + { + + using ClusterLayout_VMNK + = decltype(tiled_divide(make_layout(conditional_return( + make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100(GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_, _, _, cute::Int<0>{}), TileShape{}, TiledMma{}, ClusterLayout_VMNK{})); + + using TMA_B = decltype(make_tma_atom_B_sm100(GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_, _, _, cute::Int<0>{}), TileShape{}, TiledMma{}, ClusterLayout_VMNK{})); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + + int reload_factor; + uint32_t tma_transaction_bytes{TmaTransactionBytes}; + SwappedStrideA dA{}; + SwappedStrideB dB{}; + }; + + CUTLASS_DEVICE + CollectiveMmaSm100WeightOnly(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) + { + if constexpr (IsDynamicCluster) + { + bool 
const is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x + && cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else + { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, + void* workspace, cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) + { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + Tensor tensor_a = make_tensor(args.ptr_A, make_layout(make_shape(M, K, L), args.dA)); + Tensor tensor_b = make_tensor(args.ptr_B, make_layout(make_shape(N, K, L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + + auto cluster_shape_fallback + = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback + = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100(GmemTiledCopyA{}, tensor_a, + SmemLayoutA{}(_, _, _, cute::Int<0>{}), TileShape{}, TiledMma{}, cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100(GmemTiledCopyB{}, tensor_b, + SmemLayoutB{}(_, _, _, cute::Int<0>{}), TileShape{}, TiledMma{}, cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100(GmemTiledCopyA{}, tensor_a, + SmemLayoutA{}(_, _, _, cute::Int<0>{}), TileShape{}, TiledMma{}, cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100(GmemTiledCopyB{}, tensor_b, + SmemLayoutB{}(_, _, _, cute::Int<0>{}), TileShape{}, TiledMma{}, cluster_layout_vmnk_fallback); + + uint32_t tma_transaction_bytes = TmaTransactionBytes; + int reload_factor = (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) + { + return {{}, tma_load_a, tma_load_b, tma_load_a_fallback, tma_load_b_fallback, + hw_info.cluster_shape_fallback, reload_factor, tma_transaction_bytes, args.dA, args.dB}; + } + else if constexpr (ModeHasScales) + { + ElementScale const* ptr_S = args.ptr_S; + + Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), args.layout_S); + typename Params::TMA_Scale tma_load_scale = make_tma_atom_A_sm100(GmemTiledCopyScale{}, + tensor_scale, SmemLayoutScale{}(_, _, _, cute::Int<0>{}), TileShape{}, TiledMma{}, cluster_layout_vmnk); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) + { + typename Params::TMAScaleParams scale_params{tma_load_scale, {}}; + return {scale_params, tma_load_a, tma_load_b, tma_load_a_fallback, tma_load_b_fallback, + hw_info.cluster_shape_fallback, reload_factor, tma_transaction_bytes, args.dA, args.dB}; + } + else if constexpr (KernelConversionMode == 
ConversionMode::ConvertAndScaleWithZero) + { + Tensor tensor_zero = make_tensor(detail::get_logical_ptr(args.ptr_Z), args.layout_S); + typename Params::TMA_Scale tma_load_zero + = make_tma_atom_A_sm100(GmemTiledCopyScale{}, tensor_zero, + SmemLayoutScale{}(_, _, _, cute::Int<0>{}), TileShape{}, TiledMma{}, cluster_layout_vmnk); + + typename Params::TMAScaleParams scale_params{tma_load_scale, tma_load_zero}; + return {scale_params, tma_load_a, tma_load_b, tma_load_a_fallback, tma_load_b_fallback, + hw_info.cluster_shape_fallback, reload_factor, tma_transaction_bytes, args.dA, args.dB}; + } + else + { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in to_underlying_arguments."); + } + } + else + { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in to_underlying_arguments."); + } + } + + template + static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) + { + + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_S = cutlass::detail::get_input_alignment_bits(); + + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cutlass::sizeof_bits::value; + bool check_aligned_A + = cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cutlass::sizeof_bits::value; + bool check_aligned_B + = cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); + + bool check_aligned_S = true; + bool check_aligned_Z = true; + bool check_mode_args = true; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) + { + check_mode_args = check_mode_args && (args.ptr_S == nullptr); + check_mode_args = check_mode_args && (args.ptr_Z == nullptr); + } + else if constexpr (ModeHasScales) + { + constexpr int min_tma_aligned_elements_scale + = tma_alignment_bits_S / cutlass::sizeof_bits::value; + check_aligned_S = cutlass::detail::check_alignment(args.layout_S); + check_mode_args + = check_mode_args && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); + check_mode_args = check_mode_args && args.group_size != 0; + check_mode_args = check_mode_args && (args.ptr_S != nullptr); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) + { + check_mode_args = check_mode_args && (args.ptr_Z == nullptr); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) + { + constexpr int min_tma_aligned_elements_zero + = tma_alignment_bits_S / cutlass::sizeof_bits::value; + check_aligned_Z = cutlass::detail::check_alignment(args.layout_S); + check_mode_args = check_mode_args && (args.ptr_Z != nullptr); + } + else + { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + else + { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + + if (!check_mode_args) + { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Invalid arguments for the selected conversion mode.\n"); + } + if (!check_aligned_A) + { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor A does not meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_B) + { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor B does not 
meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_S) + { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Tensor S (scale) does not meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_Z) + { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Tensor Z (zeros) does not meet the minimum alignment requirements for TMA.\n"); + } + + return check_mode_args && check_aligned_A && check_aligned_B && check_aligned_S && check_aligned_Z; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& params) + { + if constexpr (IsDynamicCluster) + { + dim3 cs = cute::cluster_shape(); + bool const is_fallback_cluster + = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + if (is_fallback_cluster) + { + cute::prefetch_tma_descriptor(params.tma_load_a_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b_fallback.get_tma_descriptor()); + } + else + { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + } + else + { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) + ; + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) + { + cute::prefetch_tma_descriptor(params.tma_load_scale.get_tma_descriptor()); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) + { + cute::prefetch_tma_descriptor(params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_zero.get_tma_descriptor()); + } + else + { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); + } + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto partition_accumulator_shape() + { + auto acc_shape + = partition_shape_C(TiledMma{}, take<0, 2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template + CUTLASS_DEVICE auto load_A(Params const& params, Load2TransformPipeline load2xform_pipeline, + Load2TransformPipelineState load2xform_pipeline_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, KTileIterator k_tile_iter, int k_tile_count) + { + + auto [unused_gA, unused_gB, tAgA_mkl, tBgB_nkl, tAsA, tBsB, mcast_mask_a, mcast_mask_b, extra_input_partitions] + = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tAgA + = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2xform_pipeline_flag = load2xform_pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + // Load2Mma and Load2Transform pipelines both have the same ProducerBarrierType + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + + // LOCK mainloop_load2xform_pipeline_state for _writing_ + load2xform_pipeline.producer_acquire(load2xform_pipeline_state, load2xform_pipeline_flag); + + int tile_A_write_stage = load2xform_pipeline_state.index(); + + 
BarrierType* load2xform_tma_barrier = load2xform_pipeline.producer_get_barrier(load2xform_pipeline_state); + + // Advance mainloop load2transform pipeline + ++load2xform_pipeline_state; + + skip_wait = (k_tile_count <= 1); + load2xform_pipeline_flag = load2xform_pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + // TMA load for A k_tile + copy(observed_tma_load_a_->with(*load2xform_tma_barrier, mcast_mask_a), tAgA(_, *k_tile_iter), + tAsA(_, tile_A_write_stage)); + + if constexpr (ModeHasScales) + { + auto tSgS_mkl = get<0>(extra_input_partitions); + auto tSgS = tSgS_mkl( + _, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + auto tSsS = get<1>(extra_input_partitions); + int const scale_load_k = *k_tile_iter / params.reload_factor; + copy(params.tma_load_scale.with(*load2xform_tma_barrier, mcast_mask_a), tSgS(_, scale_load_k), + tSsS(_, tile_A_write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) + { + auto tZgZ_mkl = get<2>(extra_input_partitions); + auto tZgZ = tZgZ_mkl( + _, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + auto tZsZ = get<3>(extra_input_partitions); + copy(params.tma_load_zero.with(*load2xform_tma_barrier, mcast_mask_a), tZgZ(_, scale_load_k), + tZsZ(_, tile_A_write_stage)); + } + } + else + { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) + ; + else + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for TMA copy op."); + } + + ++k_tile_iter; + } + + return cute::make_tuple(load2xform_pipeline_state, k_tile_iter); + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template + CUTLASS_DEVICE auto load_B(Params const& params, Load2MmaPipeline load2mma_pipeline, + Load2MmaPipelineState load2mma_pipeline_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, KTileIterator k_tile_iter, int k_tile_count) + { + + auto [unused_gA, unused_gB, tAgA_mkl, tBgB_nkl, tAsA, tBsB, mcast_mask_a, mcast_mask_b, extra_input_partitions] + = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2mma_pipeline_flag = load2mma_pipeline.producer_try_acquire(load2mma_pipeline_state, skip_wait); + + // Load2Mma and Load2Transform pipelines both have the same ProducerBarrierType + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + + // LOCK mainloop_load2mma_pipeline_state for _writing_ + load2mma_pipeline.producer_acquire(load2mma_pipeline_state, load2mma_pipeline_flag); + + int tile_B_write_stage = load2mma_pipeline_state.index(); + + BarrierType* load2mma_tma_barrier = load2mma_pipeline.producer_get_barrier(load2mma_pipeline_state); + + // Advance mainloop load2mma pipeline + ++load2mma_pipeline_state; + + skip_wait = (k_tile_count <= 1); + load2mma_pipeline_flag = load2mma_pipeline.producer_try_acquire(load2mma_pipeline_state, skip_wait); + + // TMA load for B k_tile + copy(observed_tma_load_b_->with(*load2mma_tma_barrier, mcast_mask_b), tBgB(_, *k_tile_iter), + tBsB(_, tile_B_write_stage)); + + ++k_tile_iter; + } + + return cute::make_tuple(load2mma_pipeline_state, k_tile_iter); + } + + /// Set up the data needed by this collective for load. 
+ /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + // Other inputs needed for load(): partitioned AB tensors for gmem and smem, and mcast masks + template + CUTLASS_DEVICE auto load_init( + ProblemShape_MNKL const& problem_shape_MNKL, Params const& params, TensorStorage& shared_storage) const + { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA + = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB + = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, get<2>(cta_coord_vmnk), + make_layout(size<2>(cta_layout_vmnk)), group_modes<0, 3>(sA), group_modes<0, 3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, get<1>(cta_coord_vmnk), + make_layout(size<1>(cta_layout_vmnk)), group_modes<0, 3>(sB), group_modes<0, 3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) + { + return cute::make_tuple(gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + cute::make_tuple()); + } + else if constexpr (ModeHasScales) + { + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + Tensor mS_mkl = params.tma_load_scale.get_tma_tensor(shape(LayoutScale{})); + Tensor gS_mkl = local_tile(mS_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); + + Tensor sS = make_tensor(make_smem_ptr(shared_storage.input.smem_scale.begin()), SmemLayoutScale{}); + + Tensor tCgS_mkl = cta_mma.partition_A(gS_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + // Project the cta_layout for tma_scale along the n-modes + auto [tSgS_mkl, tSsS] = tma_partition(params.tma_load_scale, get<2>(cta_coord_vmnk), + make_layout(size<2>(cta_layout_vmnk)), group_modes<0, 3>(sS), group_modes<0, 3>(tCgS_mkl)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) + { + return cute::make_tuple(gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + cute::make_tuple(tSgS_mkl, tSsS)); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) + { + Tensor mZ_mkl = params.tma_load_zero.get_tma_tensor(shape(LayoutScale{})); + Tensor gZ_mkl = local_tile(mZ_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor sZ = 
make_tensor(make_smem_ptr(shared_storage.input.smem_zero.begin()), SmemLayoutScale{}); + + Tensor tCgZ_mkl = cta_mma.partition_A(gZ_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + // Project the cta_layout for tma_scale along the n-modes + auto [tZgZ_mkl, tZsZ] = tma_partition(params.tma_load_zero, get<2>(cta_coord_vmnk), + make_layout(size<2>(cta_layout_vmnk)), group_modes<0, 3>(sZ), group_modes<0, 3>(tCgZ_mkl)); + return cute::make_tuple(gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + cute::make_tuple(tSgS_mkl, tSsS, tZgZ_mkl, tZsZ)); + } + else + { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + else + { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + + template + CUTLASS_DEVICE auto transform(Load2TransformPipeline load2transform_pipeline, + Load2TransformPipelineState load2transform_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, Transform2MmaPipelineState transform2mma_pipeline_producer_state, + Accumulator accumulators, + cute::tuple> input_operands, + KTileIterator k_tile_iter, int k_tile_count) + { + + static_assert( + cute::is_same_v, "ElementAMma and ElementBMma types should be the same."); + cutlass::arch::NamedBarrier transform_bar( + NumTransformationThreads, cutlass::arch::ReservedNamedBarriers::TransformBarrier); + + // tAsA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tAsACompute : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, SmemStages (In SMEM or TMEM) + auto [unused_tAgA, dst_copy_A, tAsA, tAsACompute, partitioned_extra_info] = input_operands; + + // Create the tensors in registers + auto tArA = make_tensor( + tAsA(_, _, _, _, 0).shape()); //(Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest (Register) + auto tArACompute = make_tensor(tAsA(_, _, _, _, 0).shape()); + constexpr int K_BLOCK_MAX = size<3>(tArA); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2transform_flag + = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + auto transform2mma_flag + = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + + load2transform_pipeline.consumer_wait(load2transform_pipeline_consumer_state, load2transform_flag); + + transform2mma_pipeline.producer_acquire(transform2mma_pipeline_producer_state, transform2mma_flag); + + int load2transform_consumer_index = load2transform_pipeline_consumer_state.index(); // read stage + int transform2mma_producer_index = transform2mma_pipeline_producer_state.index(); // write stage + + auto curr_load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state; + + // Copy the input A matrix from SMEM + copy(AutoVectorizingCopy{}, tAsA(_, _, _, _, load2transform_consumer_index), tArA); + // Copy scale/zero vector from SMEM + Utils::copy_scale_zeros_for_transform(partitioned_extra_info, load2transform_consumer_index); + + // Loads from SMEM are done. 
Signal the mainloop load as early as possible + transform_bar.sync(); + load2transform_pipeline.consumer_release(curr_load2transform_pipeline_consumer_state); + + auto curr_transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state; + + // Dequantize A with scale/zero in RF + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; k_block++) + { + UtilsSM100::dequantize_A_kblock_for_transform(tArA, tArACompute, partitioned_extra_info, k_block); + } + + // Dequantized A is stored into either Smem or Tmem + copy(dst_copy_A, tArACompute, tAsACompute(_, _, _, _, transform2mma_producer_index)); + + // fence for SMEM writes + cutlass::arch::fence_view_async_shared(); + if constexpr (is_tmem::value) + { + // fence for TMEM writes if A operand is coming from TMEM + cutlass::arch::fence_view_async_tmem_store(); + } + + // Let the MMA know we are done transforming + transform2mma_pipeline.producer_commit(curr_transform2mma_pipeline_producer_state); + // Next pipeline stage + ++load2transform_pipeline_consumer_state; + ++transform2mma_pipeline_producer_state; + + skip_wait = (k_tile_count <= 1); + // Peek the next pipeline stage's barriers + load2transform_flag + = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + transform2mma_flag + = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + } + return cute::make_tuple(load2transform_pipeline_consumer_state, transform2mma_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto transform_init(Params const& params, ProblemShape_MNKL const& problem_shape_MNKL, + Accumulator accumulators, TensorStorage& shared_storage) + { + + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + Tensor sA_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); + Tensor sA = as_position_independent_swizzle_tensor(sA_orig); + Tensor sACompute + = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + + Tensor sS = make_tensor(make_smem_ptr(shared_storage.input.smem_scale.begin()), SmemLayoutScale{}); + Tensor sZ = make_tensor(make_smem_ptr(shared_storage.input.smem_zero.begin()), SmemLayoutScale{}); + + // Map input, compute, and fragment tensors to + // Copy strategies and partitioned tensors. These will become the input + // operands of the transform function. Depending on MMA atom type, the + // operands can reside in SMEM or TMEM + auto setup_copy_ops = [&](auto tensor_input, auto input_copy_atom, auto tensor_compute, auto make_fragment, + auto compute_copy_atom) constexpr + { + auto fragment_compute = make_fragment(tensor_compute); + if constexpr (cute::is_tmem>::value) + { + // For M=128 with 2CTA MMA atoms, the TMEM tensor for A has a duplicated allocation. + // Instead of allocation a 64x16 TMEM tensor, we have a 128x16 allocation + // See: TmemAllocMode::Duplicated. 
+ Tensor tensor_input2x = [&]() constexpr + { + if constexpr (decltype(size<0, 0>(fragment_compute) == Int<128>{} + && size<0, 0>(tensor_input) == Int<64>{})::value) + { + return make_tensor(tensor_input.data(), + logical_product(tensor_input.layout(), make_tile(make_tile(Layout<_2, _0>{}, _), _, _, _))); + } + else + { + return tensor_input; + } + }(); + + fragment_compute.data() + = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + // If operand comes from TMEM, create the TMEM_STORE based copy + auto r2t_tiled_copy = make_tmem_copy(compute_copy_atom, fragment_compute(_, _, _, 0)); + auto thr_r2t_tiled_copy = r2t_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input + = thr_r2t_tiled_copy.partition_S(tensor_input2x); //(TMEM_STORE, TMEM_STORE_M, TMEM_STORE_N) + auto partitioned_tensor_compute + = thr_r2t_tiled_copy.partition_D(fragment_compute); //(TMEM_STORE, TMEM_STORE_M, TMEM_STORE_N) + + // Source copy is based on the source operand of TMEM_STORE copy. + auto smem2reg_tiled_copy = make_tiled_copy_S(Copy_Atom{}, r2t_tiled_copy); + return cute::make_tuple( + smem2reg_tiled_copy, r2t_tiled_copy, partitioned_tensor_input, partitioned_tensor_compute); + } + else + { + auto tensor_compute_ind_sw = as_position_independent_swizzle_tensor(tensor_compute); + auto r2s_tiled_copy = make_cotiled_copy( + compute_copy_atom, Layout, Stride<_8, _1>>{}, tensor_compute(_, _, _, 0).layout()); + + auto smem2reg_tiled_copy = make_tiled_copy_S(input_copy_atom, r2s_tiled_copy); + auto thr_r2s_tiled_copy = r2s_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input + = thr_r2s_tiled_copy.partition_S(tensor_input); //(SMEM_STORE, SMEM_STORE_M, SMEM_STORE_N) + + auto partitioned_tensor_compute + = thr_r2s_tiled_copy.partition_D(tensor_compute_ind_sw); //(SMEM_STORE, SMEM_STORE_M, SMEM_STORE_N) + + return cute::make_tuple( + smem2reg_tiled_copy, AutoVectorizingCopy{}, partitioned_tensor_input, partitioned_tensor_compute); + } + }; + + auto [src_copy_A, dst_copy_A, tAsA, tAsACompute] = setup_copy_ops( + sA, InputCopyAtomA{}, sACompute, [&](auto& arg) { return TiledMma::make_fragment_A(arg); }, + ComputeCopyAtomA{}); + + // Partition of thread -> shared and thread -> RF + auto fragment_compute = TiledMma::make_fragment_A(sS); + fragment_compute.data() + = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + auto r2t_tiled_copy = make_tmem_copy(ComputeCopyAtomA{}, fragment_compute(_, _, _, 0)); + auto src_copy_scale = make_tiled_copy_S(Copy_Atom{}, r2t_tiled_copy); + + auto partitioned_extra_info = Utils::partition_extra_transform_info(TiledMma{}, src_copy_scale, shared_storage); + + return cute::make_tuple(gA_mkl, dst_copy_A, tAsA, tAsACompute, partitioned_extra_info); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE auto mma(Load2MmaPipeline load2mma_pipeline, Load2MmaPipelineState load2mma_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, Transform2MmaPipelineState transform2mma_pipeline_consumer_state, + Mma2AccumPipeline mma2accum_pipeline, Mma2AccumPipelineState mma2accum_pipeline_producer_state, + cute::Tensor const& accumulators, cute::tuple const& input_operands, + int k_tile_count) + { + TiledMma tiled_mma; + + auto curr_load2mma_pipeline_consumer_state = load2mma_pipeline_consumer_state; + auto next_load2mma_pipeline_consumer_state = 
load2mma_pipeline_consumer_state; + + auto curr_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + auto next_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + + uint32_t skip_wait = (k_tile_count <= 0); + auto transform2mma_flag + = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + auto load2mma_flag = load2mma_pipeline.consumer_try_wait(next_load2mma_pipeline_consumer_state, skip_wait); + ++next_transform2mma_pipeline_consumer_state; + ++next_load2mma_pipeline_consumer_state; + + // tCrA : (MMA), MMA_M, MMA_K, SmemStage (In SMEM or TMEM) + // We use SMEM stages to match #buffers in Load <-> Convert + // tCrB : (MMA), MMA_N, MMA_K, SmemStages (In SMEM) + auto const [tCrA, tCrB] = input_operands; + + mma2accum_pipeline.producer_acquire(mma2accum_pipeline_producer_state); + + int mma2accum_pipeline_producer_state_index = mma2accum_pipeline_producer_state.index(); + auto tCtC = accumulators(_, _, _, mma2accum_pipeline_producer_state_index); + auto curr_mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state; + ++mma2accum_pipeline_producer_state; + + // + // PIPELINED MAIN LOOP + // + // Clear the accumulator + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + + load2mma_pipeline.consumer_wait(curr_load2mma_pipeline_consumer_state, load2mma_flag); + transform2mma_pipeline.consumer_wait(curr_transform2mma_pipeline_consumer_state, transform2mma_flag); + + int load2mma_pipeline_consumer_state_index = curr_load2mma_pipeline_consumer_state.index(); // read_stage + int transform2mma_pipeline_consumer_state_index + = curr_transform2mma_pipeline_consumer_state.index(); // read_stage + + auto tCrA0 = tCrA(_, _, _, transform2mma_pipeline_consumer_state_index); + auto tCrB0 = tCrB(_, _, _, load2mma_pipeline_consumer_state_index); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); k_block++) + { + cute::gemm(tiled_mma, tCrA0(_, _, k_block), tCrB0(_, _, k_block), tCtC); // A[0]*B[0] + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + load2mma_pipeline.consumer_release(curr_load2mma_pipeline_consumer_state); + transform2mma_pipeline.consumer_release(curr_transform2mma_pipeline_consumer_state); + + skip_wait = (k_tile_count <= 1); + load2mma_flag = load2mma_pipeline.consumer_try_wait(next_load2mma_pipeline_consumer_state, skip_wait); + transform2mma_flag + = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + + curr_load2mma_pipeline_consumer_state = next_load2mma_pipeline_consumer_state; + curr_transform2mma_pipeline_consumer_state = next_transform2mma_pipeline_consumer_state; + + ++next_load2mma_pipeline_consumer_state; + ++next_transform2mma_pipeline_consumer_state; + } + + mma2accum_pipeline.producer_commit(curr_mma2accum_pipeline_producer_state); + + return cute::make_tuple(curr_load2mma_pipeline_consumer_state, curr_transform2mma_pipeline_consumer_state, + mma2accum_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto mma_init( + cute::Tensor const& accumulators, TensorStorage& shared_storage) const + { + TiledMma tiled_mma; + + auto get_tCrA = [&]() constexpr + { + if constexpr (cute::is_base_of::value) + { + Tensor sACompute + = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + return tiled_mma.make_fragment_A(sACompute); + } + else + { + auto tCrA = 
tiled_mma.make_fragment_A(shape(SmemLayoutACompute{})); + tCrA.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + return tCrA; + } + }; + + Tensor tCrA = get_tCrA(); + Tensor sB = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); + Tensor tCrB = tiled_mma.make_fragment_B(sB); + return cute::make_tuple(tCrA, tCrB); + } + + template + CUTLASS_DEVICE auto accum_init( + cute::Tensor const& accumulators, TmemCopyAtom tmem_cp_atom, EpilogueTile epilogue_tile) + { + return accumulators; + } + +private: + template + CUTLASS_DEVICE constexpr auto tile_input_tensors( + Params const& params, ProblemShape_MNKL const& problem_shape_MNKL) const + { + using X = cute::Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M, K, L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N, K, L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + typename Params::TMA_A const* observed_tma_load_a_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + + ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h index b94595485b..b67a5bcf52 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -533,8 +533,8 @@ struct GemmFpAIntB run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ == 890) run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 1000) - // Use SM80 implementation for GB10x, GB20x. +#elif (__CUDA_ARCH__ >= 1200) + // Use SM80 implementation for GB20x. run_kernel(params, shared_storage); #else CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h index 17690c278b..4615b88815 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -87,7 +87,9 @@ public: // Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, // which signals that we want to dequantize after loading from smem. 
template -struct LayoutDetailsB= 75>::type> +struct LayoutDetailsB= 75 && Arch::kMinComputeCapability != 100 + && Arch::kMinComputeCapability != 103>::type> { static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; @@ -102,7 +104,9 @@ public: }; template -struct LayoutDetailsB= 75>::type> +struct LayoutDetailsB= 75 && Arch::kMinComputeCapability != 100 + && Arch::kMinComputeCapability != 103>::type> { static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; @@ -116,6 +120,26 @@ public: using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; +template +struct LayoutDetailsB::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + } // namespace kernel } // namespace gemm } // namespace cutlass diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp index 028effc68f..50794769b5 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp @@ -409,14 +409,14 @@ std::vector get_candidate_configs_sm100_dynamic_cluster_shape } std::vector> tile_configs{ - {CutlassTileConfigSM100::CtaShape128x128x128B, cluster1sm}, - {CutlassTileConfigSM100::CtaShape128x256x128B, cluster1sm}, - {CutlassTileConfigSM100::CtaShape128x32x128B, cluster1sm}, - {CutlassTileConfigSM100::CtaShape64x64x128B, cluster1sm}, {CutlassTileConfigSM100::CtaShape64x32x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape64x64x128B, cluster1sm}, {CutlassTileConfigSM100::CtaShape64x128x128B, cluster1sm}, {CutlassTileConfigSM100::CtaShape64x256x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape128x32x128B, cluster1sm}, {CutlassTileConfigSM100::CtaShape128x64x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape128x128x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape128x256x128B, cluster1sm}, }; if (supports_2sm) @@ -479,6 +479,30 @@ std::vector get_candidate_configs_sm100( } return candidate_configs; } + else if (config & CutlassGemmConfig::WEIGHT_ONLY) + { + std::vector tile_configs{ + CutlassTileConfigSM100::CtaShape64x128x128B, + CutlassTileConfigSM100::CtaShape128x128x128B, + }; + std::vector candidate_configs; + for (auto const& tile_config : tile_configs) + { + CutlassGemmConfig cutlassKernelConfig(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_1x1x1, ClusterShape::Undefined, ClusterShape::Undefined, sm); + candidate_configs.push_back(cutlassKernelConfig); + } + + // add cuda kernel profiler to tactics for weight-only plugins + CutlassGemmConfig cudaKernelConfig(CutlassTileConfigSM100::CtaShape64x128x128B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1, ClusterShape::Undefined, + ClusterShape::Undefined, sm); + + cudaKernelConfig.enableCudaKernel = true; + candidate_configs.push_back(cudaKernelConfig); + + return candidate_configs; + } else { TLLM_THROW("Not Implemented: SM100 GEMM candidates have not been defined."); 
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp index 896f7e76f5..6bd24a972f 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp @@ -134,8 +134,17 @@ LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) { return getLayoutDetailsForArch(quant_type); } - else if (arch >= 100) + else if (arch == 100) { + return getLayoutDetailsForArch(quant_type); + } + else if (arch == 103) + { + return getLayoutDetailsForArch(quant_type); + } + else if (arch >= 120) + { + // Use SM80 implementation for GB20x. return getLayoutDetailsForArch(quant_type); } else @@ -591,23 +600,31 @@ void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, in // Works on row major data, so issue this permutation first. if (details.uses_imma_ldsm) { + TLLM_LOG_INFO("permute_B_rows_for_mixed_gemm"); permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch); src_buf.swap(dst_buf); } if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) { + TLLM_LOG_INFO("subbyte_transpose"); subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type); src_buf.swap(dst_buf); } - if (details.columns_interleaved > 1 && arch != 90) + if (details.columns_interleaved > 1 && (arch != 90)) { + TLLM_LOG_INFO("interleave_column_major_tensor"); interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details); src_buf.swap(dst_buf); } - add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type); + if (arch != 100 && arch != 103) + { + TLLM_LOG_INFO("add_bias_and_interleave_quantized_tensor_inplace"); + add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type); + } + TLLM_LOG_INFO("copy"); std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight); } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h index 1ebaecaa11..b554ea2c8d 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -40,6 +40,7 @@ #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" #include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h" +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm100.h" #include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h" namespace tk = tensorrt_llm::common; @@ -429,7 +430,7 @@ void CutlassFpAIntBGemmRunner(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); } - else if ((sm_ >= 80 && sm_ < 89) || sm_ >= 100) + else if ((sm_ >= 80 && sm_ < 89) || sm_ >= 120) { dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, @@ -454,9 +455,27 @@ void CutlassFpAIntBGemmRunner::value || cutlass::platform::is_same::value, "ScaleZeroType must be half for activation=fp8"); +#ifdef COMPILE_HOPPER_TMA_GEMMS cutlass_kernels_oss::sm90_dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, 
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); +#else // COMPILE_HOPPER_TMA_GEMMS + throw std::runtime_error( + "[TensorRT LLM Error][fpA_intB Runner] Please recompile with support for hopper by passing 90-real as an " + "arch to build_wheel.py."); +#endif // COMPILE_HOPPER_TMA_GEMMS + } + else if (sm_ == 100 || sm_ == 103) + { +#ifdef COMPILE_BLACKWELL_TMA_GEMMS + cutlass_kernels_oss::sm100_dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, + group_size, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); +#else // COMPILE_BLACKWELL_TMA_GEMMS + throw std::runtime_error( + "[TensorRT LLM Error][fpA_intB Runner] Please recompile with support for blackwell by passing 100-real as " + "an arch to build_wheel.py."); +#endif // COMPILE_BLACKWELL_TMA_GEMMS } else { @@ -537,8 +556,9 @@ CutlassFpAIntBGemmRunner::value; - tkc::CutlassGemmConfig::CandidateConfigTypeParam config_type_param - = tkc::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER; + tkc::CutlassGemmConfig::CandidateConfigTypeParam config_type_param = (sm_ >= 100) + ? tkc::CutlassGemmConfig::CandidateConfigTypeParam::BLACKWELL + : tkc::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER; if (is_weight_only) { config_type_param = static_cast( diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm100.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm100.h new file mode 100644 index 0000000000..c95b80693c --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm100.h @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "cute/numeric/integral_constant.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels_oss +{ +namespace tk = tensorrt_llm::common; +namespace tkc = tensorrt_llm::cutlass_extensions; + +using namespace cute; + +template +void sm100_dispatch_epilogue_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + switch (gemm_config.epilogue_schedule) + { + case tkc::EpilogueScheduleType::AUTO: + // TODO: use heuristics to select the epilogue schedule, depending on the CTA shape and cluster shape + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + sm100_generic_mixed_gemm_kernelLauncher(A, B, weight_scales, + weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, + occupancy); + break; + default: + throw std::runtime_error( + "[TensorRT LLM Error][fpA_intB][sm100_dispatch_epilogue_schedules] Unsupported epilogue schedule for mixed " + "type GEMM."); + break; + } +} + +template +void sm100_dispatch_mainloop_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + switch (gemm_config.mainloop_schedule) + { + case tkc::MainloopScheduleType::AUTO: + // TODO: use heuristics to select the mainloop schedule, depending on the CTA shape and cluster shape + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + sm100_dispatch_epilogue_schedules(A, B, weight_scales, weight_zero_points, biases, + alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + default: + throw std::runtime_error( + "[TensorRT LLM Error][fpA_intB][sm100_dispatch_mainloop_schedules] Unsupported mainloop schedule for mixed " + "type GEMM."); + break; + } +} + +template +void sm100_dispatch_gemm_config(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + switch (gemm_config.cluster_shape) + { + // TODO: add support for other cluster shapes + case tkc::ClusterShape::ClusterShape_1x1x1: + sm100_dispatch_mainloop_schedules>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, + k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + 
default: + throw std::runtime_error( + "[TensorRT LLM Error][fpA_intB][sm100_dispatch_gemm_config] Unsupported CTA and Cluster shapes for mixed " + "type GEMM."); + break; + } +} + +template +void sm100_dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config, + cudaStream_t stream, int* occupancy = nullptr) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + // Note that SIMT configs are omitted here since they are not supported for fpA_intB. + // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best + // for mixed type gemms. + + constexpr int Ktile = 128 / sizeof(ActivationType); + using _Ktile = Int; + switch (gemm_config.tile_config_sm100) + { + case tkc::CutlassTileConfigSM100::CtaShape64x128x128B: + sm100_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM100::CtaShape128x128x128B: + sm100_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + default: + throw std::runtime_error( + "[TensorRT LLM Error][fpA_intB][sm100_dispatch_gemm_to_cutlass] Unsupported tile shape for mixed type " + "GEMM."); + break; + } +} + +} // namespace cutlass_kernels_oss +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.h new file mode 100644 index 0000000000..38c436bbd5 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "cutlass_extensions/gemm_configs.h" +#include "cutlass_extensions/weight_only_quant_op.h" +#include + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels_oss +{ + +template +void sm100_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B, + ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases, + float const alpha, OutputType* C, int m, int n, int k, int const group_size, + tensorrt_llm::cutlass_extensions::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr); + +} // namespace cutlass_kernels_oss +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.inl b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.inl new file mode 100644 index 0000000000..44c97db125 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.inl @@ -0,0 +1,286 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/packed_stride.hpp" + +#include "cutlass_extensions/compute_occupancy.h" +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/gemm_configs.h" + +#include "cutlass_extensions/gemm/collective/collective_builder_sm100_weightonly.hpp" + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC__ + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels_oss +{ +using namespace tensorrt_llm::kernels::cutlass_kernels; +namespace tk = tensorrt_llm::common; +namespace tkc = tensorrt_llm::cutlass_extensions; + +using namespace cute; + +template +void sm100_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B, + ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases, + float const alpha, OutputType* C, int m, int n, int k, int const 
group_size, tkc::CutlassGemmConfig gemm_config, + char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + using CutlassActivationType = typename TllmToCutlassTypeAdapter::type; + +#ifdef COMPILE_BLACKWELL_TMA_GEMMS + if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v) + { + using CutlassWeightType__ = typename TllmToCutlassTypeAdapter::type; + // We need to remap this since SM100 uses a different layout for the weight matrix. + using CutlassWeightType_ = std::conditional_t, + cutlass::int4b_t, CutlassWeightType__>; + + using CutlassWeightType + = std::conditional_t, int8_t, CutlassWeightType_>; + + using CutlassScaleZeroType = typename TllmToCutlassTypeAdapter::type; + using CutlassBiasType = typename TllmToCutlassTypeAdapter::type; + using CutlassOutputType = typename TllmToCutlassTypeAdapter::type; + + static_assert(std::is_same_v + || std::is_same_v + || std::is_same_v + || std::is_same_v, + "Activation type must be bfloat16, half, FP8"); + + static_assert(std::is_same_v || std::is_same_v + || std::is_same_v + || std::is_same_v, + "Weight type must be fp8, int8_t or int4_t"); + + static_assert(!std::is_same_v + || std::is_same_v, + "Scale/Zero type must be half for fp8 activation"); + + using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + // This example manually swaps and transposes, so keep transpose of input layouts + using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; + using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + + using ElementZero = CutlassScaleZeroType; + using ElementScale = CutlassScaleZeroType; + + // C/D matrix configuration. We reuse the C operand for the bias and set the stride for broadcast. + using LayoutBias = cutlass::layout::RowMajor; + constexpr int AlignmentBias = 128 / cutlass::sizeof_bits::value; + + // D matrix configuration + using LayoutOutput = cutlass::layout::RowMajor; + constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ElementCompute = float; // Element type for epilogue computation + using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + // using TileShape = CTAShape; // Threadblock-level tile size + constexpr static bool Is2SM = cute::size<0>(ClusterShape{}) == 2; + constexpr static int TileM = cute::size<0>(CTAShape{}) * (Is2SM ? 
2 : 1); + constexpr static int TileN = cute::size<1>(CTAShape{}); + constexpr static int TileK = cute::size<2>(CTAShape{}); + using TileShape = cute::Shape, cute::Int, cute::Int>; + using MainloopSchedule = std::conditional_t; + using EpilogueSchedule = std::conditional_t; + + static_assert(std::is_same_v, ""); + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder::type, AlignmentBias, + CutlassOutputType, typename cutlass::layout::LayoutTranspose::type, AlignmentOutput, + EpilogueSchedule>::CollectiveOp; + + using PackedScaleZero = cute::tuple; + using PackedScale = cute::tuple; + using ElementBCollectiveInfo = std::conditional_t; + + constexpr int ScaleGranularityN = 1; // Should be less than or equal to GEMM_N + constexpr int ScaleGranularityK = size<2>(TileShape{}); // Should be less than or equal to GEMM_K + using ScaleConfig = cutlass::detail::Sm100MixedInputBlockwiseScaleConfig; + using LayoutScale = decltype(ScaleConfig::deduce_layout_scale()); // Layout type for SFA matrix operand + LayoutScale layout_S = ScaleConfig::tile_atom_to_shape_scale(make_shape(n, k, 1)); + + // We swap A and B operands to the builder here + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilderSm100WeightOnly, AlignmentB, + CutlassActivationType, LayoutA_Transpose, AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename GemmKernel::StrideA; + using StrideB = typename GemmKernel::StrideB; + using StrideC = typename GemmKernel::StrideC; + using StrideD = typename GemmKernel::StrideD; + + if (weight_scales == nullptr) + { + throw std::runtime_error("Weight scales must always be set to a non-null value."); + } + + if constexpr (cutlass::isFinegrained(QuantOp)) + { + int cta_shape_k = cute::size<2>(TileShape{}); + if (group_size % cta_shape_k != 0) + { + std::string err_msg = "The group size must a multiple of " + std::to_string(cta_shape_k); + throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner]" + err_msg); + } + + if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY) + { + if (weight_zero_points != nullptr) + { + throw std::runtime_error("Weight zero pointer must be a nullptr for scale only fine grained"); + } + } + else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) + { + if (weight_zero_points == nullptr) + { + throw std::runtime_error("Weight zero pointer must be valid for scale and bias fine grained"); + } + } + } + else + { + if (group_size != k) + { + throw std::runtime_error("Invalid group size for per column scaling kernels."); + } + + if (weight_zero_points != nullptr) + { + throw std::runtime_error("Weight zero-points must be null when running per column scaling"); + } + } + + StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1)); + StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(0, m, 1)); + + typename Gemm::Arguments 
args{cutlass::gemm::GemmUniversalMode::kGemm, {n, m, k, 1}, + {reinterpret_cast(B), stride_B, reinterpret_cast(A), + stride_A, reinterpret_cast(weight_scales), layout_S, group_size, + reinterpret_cast(weight_zero_points)}, + {{alpha}, reinterpret_cast(biases), stride_C, + reinterpret_cast(C), stride_D}}; + + Gemm gemm; + if (gemm.get_workspace_size(args) > workspace_bytes) + { + TLLM_LOG_ERROR("[TensorRT LLM Error][fpA_intB Runner] given workspace size insufficient."); + } + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) + { + std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + + std::string(cutlass::cutlassGetStatusString(can_implement)); + std::cout << err_msg << std::endl; + throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg); + } + + auto init_status = gemm.initialize(args, workspace, stream); + if (init_status != cutlass::Status::kSuccess) + { + std::string err_msg = "Failed to initialize cutlass fpA_intB gemm. Error: " + + std::string(cutlassGetStatusString(init_status)); + throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) + { + std::string err_msg + = "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status)); + throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg); + } + } + else + { + std::stringstream ss; + ss << "[TensorRT LLM Error][fpA_intB Runner] Config (" << (int64_t) cute::size<0>(CTAShape{}) << "," + << (int64_t) cute::size<1>(CTAShape{}) << "," << (int64_t) cute::size<2>(CTAShape{}) << ") (" + << (int64_t) cute::size<0>(ClusterShape{}) << "," << (int64_t) cute::size<1>(ClusterShape{}) << "," + << (int64_t) cute::size<2>(ClusterShape{}) << ") not compiled with FAST_BUILD."; + + throw std::runtime_error(ss.str()); + } + +#else // COMPILE_BLACKWELL_TMA_GEMMS + throw std::runtime_error( + "[TensorRT LLM Error][fpA_intB Runner] Please recompile with support for blackwell by passing 100-real as an " + "arch " + "to build_wheel.py."); +#endif // COMPILE_BLACKWELL_TMA_GEMMS +} + +} // namespace cutlass_kernels_oss +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py b/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py index 61070281c4..082198a4ca 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py @@ -213,8 +213,10 @@ def instantiate_operation_tma_warp_specialized(operation): if operation.gemm_kind == GemmKind.Gemm: weight_tag = DataTypeTag[operation.weight_type] + # Use sm100_generic_mixed_gemm_kernelLauncher for SM100 and above + launcher_name = "sm100_generic_mixed_gemm_kernelLauncher" if operation.arch >= 100 else "sm90_generic_mixed_gemm_kernelLauncher" instantiation = f""" -template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag}, +template void {launcher_name}<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag}, {quant_op}, {epi_tag}, {cute_cta_shape}, {cute_cga_shape}, {kernel_sched}, {epi_sched}> ( @@ -808,8 +810,69 @@ def generate_sm103_operations(is_arch_enabled): return operations +def generate_sm100_mixed_gemm_operations(is_arch_enabled): + if not is_arch_enabled: + return [] + arch = 100 + + # For SM100 (Blackwell), 
we support similar dtypes as SM90 + # Takes the form (activation_type, weight_type, scalezero_type, bias_type, output_type) + supported_dtypes = [ + (DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16), + (DataType.e4m3, DataType.u4, DataType.f16, DataType.bf16, + DataType.bf16), + (DataType.f16, DataType.u4, DataType.f16, DataType.f16, DataType.f16), + (DataType.bf16, DataType.u4, DataType.bf16, DataType.bf16, + DataType.bf16), + (DataType.f16, DataType.u8, DataType.f16, DataType.f16, DataType.f16), + (DataType.bf16, DataType.u8, DataType.bf16, DataType.bf16, + DataType.bf16) + ] + + quant_ops = [ + TrtLlm_QuantOp.per_column_scale_only, + TrtLlm_QuantOp.finegrained_scale_only, + TrtLlm_QuantOp.finegrained_scale_and_zeros + ] + + epi_tags = [TrtLlm_EpilogueTag.epilogue_op_bias] + + # SM100 uses different tile shapes + M_TILES = [64, 128] + N_TILES = [128] + cta_shapes_mn = product(M_TILES, N_TILES) + + warp_shape = [4, 1, 1] + stages = 0 # auto + + # SM100 currently only supports 1x1x1 cluster shape + cga_shapes = [(1, 1, 1)] + + partial_args = product(supported_dtypes, quant_ops, epi_tags, cta_shapes_mn, + cga_shapes) + + operations = list() + for dtype_combo, quant_op, epi_tag, cta_shape_mn, cga_shape in partial_args: + max_k_bits = 128 * 8 + cta_shape_k = max_k_bits // GetDataTypeBits(dtype_combo[0]) + cta_shape_mnk = cta_shape_mn + (cta_shape_k, ) + + # SM100 uses 1SM schedule for mixed type GEMM + mainloop_schedule = KernelScheduleType.TmaWarpSpecialized1SmSm100 + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1Sm + + fpA_intB_operation = TrtLlm_GemmLauncher(GemmKind.Gemm, arch, *dtype_combo, quant_op, epi_tag, cta_shape_mnk, \ + warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule) + + if is_gemm_op_valid_sm100(fpA_intB_operation): + operations.append(fpA_intB_operation) + + return operations + + def generate_sm100_operations(is_arch_enabled): operations = generate_sm100_grouped_gemm_operations(is_arch_enabled, 100) + operations.extend(generate_sm100_mixed_gemm_operations(is_arch_enabled)) return operations @@ -878,6 +941,7 @@ if __name__ == "__main__": output_dir = os.path.abspath(args.output_dir) fpA_intB_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl" + fpA_intB_sm100_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.inl" moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl" # moe_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl" moe_mixed_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl" @@ -887,6 +951,8 @@ if __name__ == "__main__": inl_map = { (GemmKind.Gemm, 90): [fpA_intB_inl], + (GemmKind.Gemm, 100): [fpA_intB_sm100_inl], + (GemmKind.Gemm, 103): [fpA_intB_sm100_inl], (GemmKind.Grouped, 90): [moe_gemm_inl], (GemmKind.Grouped, 100): [moe_gemm_inl], (GemmKind.Grouped, 103): [moe_gemm_inl], diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4GroupwiseColumnMajorFalse.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4GroupwiseColumnMajorFalse.cu new file mode 100644 index 0000000000..f7eeb0879a --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4GroupwiseColumnMajorFalse.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false, 64); +// KTile=128 for w4a8 +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false, 128); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4PerChannelColumnMajorFalse.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4PerChannelColumnMajorFalse.cu new file mode 100644 index 0000000000..9cae268bf6 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4PerChannelColumnMajorFalse.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajor, false, 64); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8GroupwiseColumnMajorFalse.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8GroupwiseColumnMajorFalse.cu new file mode 100644 index 0000000000..41f6f9edaf --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8GroupwiseColumnMajorFalse.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajor, false, 64); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8PerChannelColumnMajorFalse.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8PerChannelColumnMajorFalse.cu new file mode 100644 index 0000000000..07a6433b84 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8PerChannelColumnMajorFalse.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int8PerChannel, BF16DetailsA, Int8DetailsW, ColumnMajor, false, 64); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4GroupwiseColumnMajorFalse.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4GroupwiseColumnMajorFalse.cu new file mode 100644 index 0000000000..c89d05ae5a --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4GroupwiseColumnMajorFalse.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false, 64); +// KTile=128 for w4a8 +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false, 128); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4PerChannelColumnMajorFalse.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4PerChannelColumnMajorFalse.cu new file mode 100644 index 0000000000..e4c0b2d17a --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4PerChannelColumnMajorFalse.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajor, false, 64); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8GroupwiseColumnMajorFalse.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8GroupwiseColumnMajorFalse.cu new file mode 100644 index 0000000000..17dce0221c --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8GroupwiseColumnMajorFalse.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajor, false, 64); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8PerChannelColumnMajorFalse.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8PerChannelColumnMajorFalse.cu new file mode 100644 index 0000000000..6c4ae98083 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8PerChannelColumnMajorFalse.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajor, false, 64); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h index 4562562754..ff5b95ba16 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h @@ -52,7 +52,7 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s) EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); } - else if ((arch >= 80 && arch < 90) || arch >= 100) + else if ((arch >= 80 && arch < 90) || arch >= 120) { if (arch == 89 || arch >= 120) { @@ -68,7 +68,7 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s) EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); EXEC(KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); } - else if (arch >= 90) + else if (arch == 90) { // Dispatchers for W4A8 groupwise EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); @@ -83,6 +83,20 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s) EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); EXEC(KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); } + else if (arch == 100 || arch == 103) + { + EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false); + EXEC_W4A8(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false); + 
EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajor, false); + EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajor, false); + EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false); + EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false); + EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajor, false); + EXEC(KernelType::BF16Int8PerChannel, BF16DetailsA, Int8DetailsW, ColumnMajor, false); + EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajor, false); + EXEC(KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajor, false); + } +#undef EXEC_W4A8 #undef EXEC } diff --git a/docs/source/features/quantization.md b/docs/source/features/quantization.md index f1a10e6dac..0fef31dace 100644 --- a/docs/source/features/quantization.md +++ b/docs/source/features/quantization.md @@ -123,7 +123,7 @@ The language component decides which quantization methods are supported by a giv | Model | NVFP4 | MXFP4 | FP8(per tensor)| FP8(block scaling) | FP8(rowwise) | FP8 KV Cache | NVFP4 KV Cache | W4A8 AWQ | W4A16 AWQ | W4A8 GPTQ | W4A16 GPTQ | | :------------- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :-------: | :-------: | :--------: | :--------: | | Blackwell(sm120) | Y | Y | Y | . | . | Y | . | . | . | . | . | -| Blackwell(sm100/103) | Y | Y | Y | Y | . | Y | Y | . | . | . | . | +| Blackwell(sm100/103) | Y | Y | Y | Y | . | Y | Y | Y | Y | Y | Y | | Hopper | . | . | Y | Y | Y | Y | . | Y | Y | Y | Y | | Ada Lovelace | . | . | Y | . | . | Y | . | Y | Y | Y | Y | | Ampere | . | . | . | . | . | Y | . | . | Y | . | Y | diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 5a3774e91a..b26b687ced 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -1327,7 +1327,7 @@ def _( class FinegrainedMixedDtypeGemm(TunableRunner): _runner_dict = dict() - MAX_SUPPORTED_SM_VERSION = 90 + MAX_SUPPORTED_SM_VERSION = 103 def __init__(self, activation_dtype: torch.dtype, output_dtype: torch.dtype, quant_mode: int): @@ -1362,7 +1362,7 @@ class FinegrainedMixedDtypeGemm(TunableRunner): if get_sm_version() > self.MAX_SUPPORTED_SM_VERSION: raise ValueError( - f"SM version {get_sm_version()} is not supported for W4A16 GEMM" + f"SM version {get_sm_version()} is not supported for W4A16/W4A8 finegrained mixed dtype GEMM" ) activation, weights_packed, scales = inputs diff --git a/tensorrt_llm/quantization/functional.py b/tensorrt_llm/quantization/functional.py index 30986412c8..bdfa6a0796 100644 --- a/tensorrt_llm/quantization/functional.py +++ b/tensorrt_llm/quantization/functional.py @@ -961,8 +961,8 @@ def preprocess_weights_for_mixed_gemm( tensor = tensor.unsqueeze(0) elif sm_ >= 90: sm_ = 80 - if sm_ > 90: - sm_ = 80 + if sm_ == 100 or sm_ == 103: + do_weight_interleave = False permutation_map = { "16_8": [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15], @@ -990,7 +990,7 @@ def preprocess_weights_for_mixed_gemm( assert (num_rows % B_ROWS_PER_MMA == 0) assert (num_cols % MMA_SHAPE_N == 0) - if do_weight_interleave: + if do_weight_interleave and sm_ < 100: row_idx_list = [(row_idx // B_ROWS_PER_MMA) * B_ROWS_PER_MMA + permutation_map[f"{BITS_PER_ELT_A}_{BITS_PER_ELT_B}"][ row_idx % B_ROWS_PER_MMA]
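A minimal sketch (not part of the patch) of the CTA K-extent rule used by generate_sm100_mixed_gemm_operations above: each SM100 K tile covers 128 bytes of activation data, so cta_shape_k = 128 * 8 // GetDataTypeBits(activation). The dtype bit widths below are assumptions for illustration only, not the generator's own helpers.

    # Sketch only: mirrors `max_k_bits // GetDataTypeBits(dtype_combo[0])` from
    # generate_kernels.py, assuming the usual bit widths for the listed
    # activation dtypes.
    ACTIVATION_BITS = {"e4m3": 8, "f16": 16, "bf16": 16}

    def sm100_cta_shape_k(activation_dtype: str) -> int:
        max_k_bits = 128 * 8  # fixed 128-byte K extent per CTA tile
        return max_k_bits // ACTIVATION_BITS[activation_dtype]

    # e4m3 -> 128, f16/bf16 -> 64; paired with (M, N) tiles from {64, 128} x {128}
    for dtype in ("e4m3", "f16", "bf16"):
        print(dtype, sm100_cta_shape_k(dtype))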
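A similarly minimal sketch of the weight-layout decision this patch introduces for Blackwell: the SM100/SM103 weight-only dispatch in kernelLauncher.h uses plain ColumnMajor with interleave = false, and preprocess_weights_for_mixed_gemm in quantization/functional.py therefore skips the per-MMA row interleave for sm 100/103. The helper below is hypothetical and only illustrates that decision, assuming it receives the same sm version and do_weight_interleave flag as the library code.

    # Hypothetical helper, not the library API: mirrors `do_weight_interleave = False`
    # for sm 100/103 and the `if do_weight_interleave and sm_ < 100` guard above.
    def use_row_interleave(sm_version: int, do_weight_interleave: bool = True) -> bool:
        if sm_version in (100, 103):
            return False  # SM100/SM103 kernels consume non-interleaved ColumnMajor weights
        return do_weight_interleave

    for sm in (80, 89, 90, 100, 103):
        print(sm, use_row_interleave(sm))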