[None][feat] sm100 weight-only kernel (#10190)

Signed-off-by: Cheng Hang <chang@nvidia.com>
Cheng Hang 2026-01-05 09:44:36 +08:00 committed by GitHub
parent b5a1e10bc0
commit 656c705ff1
26 changed files with 2751 additions and 23 deletions

View File

@ -686,4 +686,212 @@ public:
}
};
template <class Collective>
struct MixedInputUtilsSM100
{
private:
using KernelSchedule = typename Collective::KernelSchedule;
using ConversionMode = typename Collective::ConversionMode;
using SmemLayoutA = typename Collective::SmemLayoutA;
using SmemLayoutB = typename Collective::SmemLayoutB;
using ElementScale = typename Collective::ElementScale;
using ElementZero = typename Collective::ElementZero;
static constexpr auto KernelConversionMode = Collective::KernelConversionMode;
public:
// Helper functions to select packing for conversion
template <class SrcType, class DstType, int Cosize>
struct select_packing
{ // Naive packing policy
static constexpr auto value()
{
return Int<cute::gcd(Cosize, 32 / cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>))>{};
}
};
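// Illustrative example (added note): for SrcType = int4b_t (4 bits), DstType = bfloat16_t (16 bits)
// and Cosize = 32, value() yields Int<gcd(32, 32 / min(4, 16))> = Int<8>, i.e. 8 elements are
// converted per NumericArrayConverter invocation below.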
/// Utilities to dequantize A
/// (designed for the separate transform pipeline on Blackwell).
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
CUTLASS_DEVICE static void dequantize_A_kblock_for_transform(Tensor<EngineIn, LayoutIn> const& tArA,
Tensor<EngineOut, LayoutOut>& tArACompute, cute::tuple<Ts...> const& partitioned_extra_info, int const k_block)
{
static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
using SrcType = typename EngineIn::value_type;
using DstType = typename EngineOut::value_type;
auto src = tArA(_, _, _, k_block);
auto dst = tArACompute(_, _, _, k_block);
auto pSrc = raw_pointer_cast(src.data());
auto pDst = const_cast<DstType*>(raw_pointer_cast(dst.data()));
constexpr int num_elements = decltype(size(src))::value;
constexpr int pack = decltype(select_packing<SrcType, DstType, num_elements>::value())::value;
using Converter
= cutlass::NumericArrayConverter<DstType, SrcType, pack, cutlass::FloatRoundStyle::round_to_nearest>;
using SrcArray = cutlass::Array<SrcType, pack>;
using DstArray = cutlass::Array<DstType, pack>;
constexpr int DstElementsPerReg = 32 / sizeof_bits_v<DstType>;
using RegArray = cutlass::AlignedArray<uint32_t, pack / DstElementsPerReg, sizeof(DstArray)>;
auto src_arr = recast<SrcArray>(src);
auto dst_arr = recast<DstArray>(dst);
Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, pack));
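// Note: DstElementsPerReg packs each 32-bit register with 32 / sizeof_bits<DstType> values
// (two for bfloat16), which is what lets the bfloat16 fast paths below apply scales and
// zeros two elements at a time with __hmul2 / __hadd2 on the raw registers.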
if constexpr (KernelConversionMode == ConversionMode::DirectConvert)
{
cute::transform(src_arr, dst_arr, Converter::convert);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale)
{
auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, _, k_block);
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
if constexpr (is_same_v<DstType, ElementScale>)
{
cute::transform(src_arr, dst_arr, Converter::convert);
using ScaleArray = cutlass::Array<ElementScale, pack>;
auto scale_arr = recast<ScaleArray>(filter_zeros(scales));
if constexpr (is_same_v<DstType, cutlass::bfloat16_t>)
{
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, pack));
for (int i = 0; i < size<1>(dst_vm); ++i)
{
auto&& r = cute::recast<RegArray>(dst_vm(_, i))(0);
auto&& scale_reg = cute::recast<RegArray>(scales_vm(_, i))(0);
CUTLASS_PRAGMA_UNROLL
for (size_t ii = 0; ii < RegArray::kElements; ++ii)
{
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
bf16x2_val = __hmul2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(scale_reg[ii]));
}
}
}
else
{
cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{});
}
}
else
{
constexpr int pack1 = decltype(select_packing<SrcType, ElementScale, num_elements>::value())::value;
constexpr int pack2 = decltype(select_packing<ElementScale, DstType, num_elements>::value())::value;
constexpr int pack = cute::gcd(pack1, pack2);
using Converter1 = cutlass::NumericArrayConverter<ElementScale, SrcType, pack,
cutlass::FloatRoundStyle::round_to_nearest>;
using Converter2 = cutlass::NumericArrayConverter<DstType, ElementScale, pack,
cutlass::FloatRoundStyle::round_to_nearest>;
using SrcArray = cutlass::Array<SrcType, pack>;
using DstArray = cutlass::Array<DstType, pack>;
using StageArray = cutlass::Array<ElementScale, pack>;
constexpr int iters = num_elements / pack;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < iters; ++i)
{
SrcArray const* pSrcArr = reinterpret_cast<SrcArray const*>(pSrc) + i;
DstArray* pDstArr = reinterpret_cast<DstArray*>(pDst) + i;
StageArray stageArr;
stageArr = Converter1::convert(*pSrcArr);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < pack; ++j)
{
stageArr[j] = stageArr[j] * scales[i * pack + j];
}
*pDstArr = Converter2::convert(stageArr);
}
}
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
{
static_assert(is_same_v<ElementScale, ElementZero>, "ElementScale and ElementZero must be the same.");
auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, _, k_block);
auto const& zeros = cute::get<3>(partitioned_extra_info)(_, _, _, k_block);
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
CUTE_STATIC_ASSERT_V(size(src) == size(zeros));
if constexpr (is_same_v<DstType, ElementZero>)
{
cute::transform(src_arr, dst_arr, Converter::convert);
using ScaleArray = cutlass::Array<ElementScale, pack>;
auto scale_arr = recast<ScaleArray>(filter_zeros(scales));
using ZeroArray = cutlass::Array<ElementZero, pack>;
auto zero_arr = recast<ZeroArray>(filter_zeros(zeros));
if constexpr (is_same_v<DstType, cutlass::bfloat16_t>)
{
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, pack));
Tensor zeros_vm = cute::group_modes<1, -1>(cute::zipped_divide(zeros, pack));
for (int i = 0; i < size<1>(dst_vm); ++i)
{
auto&& r = cute::recast<RegArray>(dst_vm(_, i))(0);
auto&& scale_reg = cute::recast<RegArray>(scales_vm(_, i))(0);
auto&& zero_reg = cute::recast<RegArray>(zeros_vm(_, i))(0);
CUTLASS_PRAGMA_UNROLL
for (size_t ii = 0; ii < RegArray::kElements; ++ii)
{
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
bf16x2_val = __hmul2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(scale_reg[ii]));
bf16x2_val = __hadd2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(zero_reg[ii]));
}
}
}
else
{
cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{});
cute::transform(dst_arr, zero_arr, dst_arr, cute::plus{});
}
}
else
{
constexpr int pack1 = decltype(select_packing<SrcType, ElementScale, num_elements>::value())::value;
constexpr int pack2 = decltype(select_packing<ElementScale, DstType, num_elements>::value())::value;
constexpr int pack = cute::gcd(pack1, pack2);
using Converter1 = cutlass::NumericArrayConverter<ElementScale, SrcType, pack,
cutlass::FloatRoundStyle::round_to_nearest>;
using Converter2 = cutlass::NumericArrayConverter<DstType, ElementScale, pack,
cutlass::FloatRoundStyle::round_to_nearest>;
using SrcArray = cutlass::Array<SrcType, pack>;
using DstArray = cutlass::Array<DstType, pack>;
using StageArray = cutlass::Array<ElementScale, pack>;
constexpr int iters = num_elements / pack;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < iters; ++i)
{
SrcArray const* pSrcArr = reinterpret_cast<SrcArray const*>(pSrc) + i;
DstArray* pDstArr = reinterpret_cast<DstArray*>(pDst) + i;
StageArray stageArr;
stageArr = Converter1::convert(*pSrcArr);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < pack; ++j)
{
stageArr[j] = stageArr[j] * scales[i * pack + j] + zeros[i * pack + j];
}
*pDstArr = Converter2::convert(stageArr);
}
}
}
else
{
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled for input partitioning.");
}
}
};
} // namespace cutlass::gemm::collective::detail
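For reference, a minimal host-side sketch (not part of this commit) of the per-element math the ConvertAndScale and ConvertAndScaleWithZero branches above perform; the device code additionally packs the conversion into NumericArrayConverter calls and, for bfloat16, fuses the scale/zero application into bf16x2 register operations. The grouping by group_size here is an illustrative assumption standing in for the partitioned scale/zero tensors.

#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar reference for weight dequantization: out[i] = float(q[i]) * scale (+ optional zero),
// with one scale/zero per quantization group. The group indexing is illustrative only.
std::vector<float> dequantize_reference(std::vector<int8_t> const& q, // one quantized value per entry
    std::vector<float> const& scales,                                 // one scale per group
    std::vector<float> const& zeros,                                  // one zero per group; empty => scale-only
    int group_size)
{
    std::vector<float> out(q.size());
    for (std::size_t i = 0; i < q.size(); ++i)
    {
        std::size_t g = i / static_cast<std::size_t>(group_size);
        float w = static_cast<float>(q[i]) * scales[g];
        if (!zeros.empty())
        {
            w += zeros[g];
        }
        out[i] = w;
    }
    return out;
}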

View File

@ -0,0 +1,294 @@
/*
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/gemm/collective/builders/sm100_common.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail
{
// Returns the pipeline stage counts (Load2Transform, Transform2Mma, Accumulator) that fit in the given SMEM capacity,
// or the manually specified stage count.
template <int CapacityBytes, class ElementA, class ElementAMma, class ElementScale, class ElementZero, class ElementB,
class CtaTileShape_MNK, class TiledMma, class KernelScheduleType, UMMA::Major UmmaMajorA, int ScaleGranularityK,
int stages>
constexpr cute::tuple<int, int, int> sm100_compute_stage_count_or_override_weightonly(StageCount<stages> stage_count)
{
constexpr int Load2TransformStageCount = stages;
constexpr int Transform2MmaStageCount = stages;
constexpr int AccumulatorStageCount = stages;
return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount);
}
template <int CapacityBytes, class ElementA, class ElementAMma, class ElementScale, class ElementZero, class ElementB,
class CtaTileShape_MNK, class TiledMma, class KernelScheduleType, UMMA::Major UmmaMajorA, int ScaleGranularityK,
int carveout_bytes>
constexpr cute::tuple<int, int, int> sm100_compute_stage_count_or_override_weightonly(
StageCountAutoCarveout<carveout_bytes> stage_count)
{
constexpr int CtaM = get<0>(CtaTileShape_MNK{});
constexpr int CtaN = get<1>(CtaTileShape_MNK{});
static_assert(CtaN <= 128, "CtaN > 128 tiles are not supported");
constexpr int CtaK = get<2>(CtaTileShape_MNK{});
using AtomThrID = typename TiledMma::AtomThrID;
constexpr int TmemColumns = 512;
constexpr bool IsAComputeinTmem = UmmaMajorA == cute::UMMA::Major::K
&& !cute::is_base_of_v<KernelTmaWarpSpecializedMixedInputSmemSm100, KernelScheduleType>;
constexpr bool IsAComputeinSmem = !IsAComputeinTmem;
// Detect 2x2 TMEM layout
constexpr int TmemAccWordsPerDP = (CtaM == 64 && size(AtomThrID{}) == 2) ? CtaN / 2 : CtaN;
constexpr int TmemAWordsPerDP = CtaK / 2;
constexpr int AccumulatorStageCount
= (IsAComputeinTmem) ? ((TmemAccWordsPerDP == 128) ? 2 : 3) : (TmemColumns / TmemAccWordsPerDP);
constexpr int SmemCapacityAfterMma2AccumCarveout = CapacityBytes - (carveout_bytes + AccumulatorStageCount * 32);
constexpr int TmemInAStageCount_Potential
= (IsAComputeinTmem) ? (TmemColumns - AccumulatorStageCount * TmemAccWordsPerDP) / TmemAWordsPerDP : 10000;
// Mainload2Transform Pipeline
constexpr auto load2transform_pipeline_bytes
= sizeof(typename cutlass::PipelineTmaTransformAsync<1>::SharedStorage);
constexpr auto a_bits = cute::sizeof_bits_v<ElementA>; // ElementA enters the byte count here
constexpr auto s_bits = cute::is_void_v<ElementScale> ? 0 : cute::sizeof_bits_v<ElementScale>;
constexpr auto z_bits = cute::is_void_v<ElementZero> ? 0 : cute::sizeof_bits_v<ElementZero>;
constexpr auto load2mma_pipeline_bytes = sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage);
constexpr auto b_bits = cute::sizeof_bits_v<ElementB>; // ElementB enters the byte count here
constexpr int ab_stage_bytes
= cutlass::bits_to_bytes(a_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}))
+ cutlass::bits_to_bytes(s_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK)
+ cutlass::bits_to_bytes(z_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK)
+ cutlass::bits_to_bytes(b_bits * size<1>(CtaTileShape_MNK{}) / size(AtomThrID{}) * size<2>(CtaTileShape_MNK{}))
+ static_cast<int>(load2transform_pipeline_bytes) + static_cast<int>(load2mma_pipeline_bytes);
// Transform2Mma Pipeline
constexpr auto transform2mma_pipeline_bytes = sizeof(typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage);
constexpr auto a_compute_bits = cute::sizeof_bits_v<ElementAMma>;
constexpr int ab_compute_stage_bytes = cutlass::bits_to_bytes(a_compute_bits * int(IsAComputeinSmem)
* size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}))
+ // If ACompute is in TMEM, the ACompute buffer takes 0 bytes.
static_cast<int>(transform2mma_pipeline_bytes);
constexpr int ABComputeStageCount_Potential
= SmemCapacityAfterMma2AccumCarveout / (ab_stage_bytes + ab_compute_stage_bytes);
// The number of SMEM buffers for A, B, ACompute (if in SMEM), and BCompute should be at least Transform2MmaStageCount.
constexpr int Transform2MmaStageCount = std::min(TmemInAStageCount_Potential, ABComputeStageCount_Potential);
constexpr int SmemCapacityAfterABComputeCarveout
= SmemCapacityAfterMma2AccumCarveout - (Transform2MmaStageCount * ab_compute_stage_bytes);
// Can we boost the number of buffers for A and B?
constexpr int Load2TransformStageCount = SmemCapacityAfterABComputeCarveout / ab_stage_bytes;
static_assert(Load2TransformStageCount >= 2 && Transform2MmaStageCount >= 2 && AccumulatorStageCount >= 2,
"Not enough SMEM or TMEM capacity for selected tile size");
return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount);
}
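// Summary of the budget above (descriptive note): with smem_after_acc = CapacityBytes - carveout_bytes
// - AccumulatorStageCount * 32, the Transform2Mma stage count is
//   min(TMEM-limited stages, smem_after_acc / (ab_stage_bytes + ab_compute_stage_bytes)),
// and the SMEM left over after the ACompute/BCompute buffers,
//   smem_after_acc - Transform2MmaStageCount * ab_compute_stage_bytes,
// is spent on additional A/B load buffers (Load2TransformStageCount).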
} // namespace detail
// Mixed Input MMA kernels builder
template <class ElementAOptionalTuple, class GmemLayoutATagTuple, int AlignmentA, class ElementBOptionalTuple,
class GmemLayoutBTag, int AlignmentB, class ElementAccumulator,
class TileShape_MNK, // The Cluster-level TileShape
class ClusterShape_MNK, class StageCountType, class KernelScheduleType>
struct CollectiveBuilderSm100WeightOnly<arch::Sm100, arch::OpClassTensorOp,
ElementAOptionalTuple, // ElementA
GmemLayoutATagTuple, // LayoutA
AlignmentA,
ElementBOptionalTuple, // ElementB
GmemLayoutBTag, // LayoutB
AlignmentB, ElementAccumulator,
TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK)
ClusterShape_MNK, // Static cluster shape or dynamic (int, int, int)
StageCountType, KernelScheduleType,
cute::enable_if_t<(cute::is_base_of_v<KernelScheduleSm100MixedInputGemm, KernelScheduleType>) &&(
(sizeof(float) * AlignmentA) % detail::tma_alignment_bytes == 0)
&& ((sizeof(float) * AlignmentB) % detail::tma_alignment_bytes == 0)>>
{
using GmemLayoutATag = detail::deduce_mixed_width_dtype_t<0, GmemLayoutATagTuple>;
using GmemLayoutScaleTag = detail::deduce_mixed_width_dtype_t<1, GmemLayoutATagTuple>;
static constexpr cute::UMMA::Major UmmaMajorA
= cutlass::gemm::collective::detail::tag_to_umma_major_A<GmemLayoutATag>();
static constexpr cute::UMMA::Major UmmaMajorB
= cutlass::gemm::collective::detail::tag_to_umma_major_B<GmemLayoutBTag>();
using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>;
using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>;
using ElementScale = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>;
using ElementZero = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>;
static constexpr bool NeitherIsTuple
= !cute::is_tuple<ElementAOptionalTuple>::value && !cute::is_tuple<ElementBOptionalTuple>::value;
static constexpr bool IsANarrow = cute::sizeof_bits_v<ElementA> < cute::sizeof_bits_v<ElementB>;
static constexpr bool IsMixedInput = cute::sizeof_bits_v<ElementA> != cute::sizeof_bits_v<ElementB>;
static_assert(IsMixedInput, "Mixed Input GEMM Kernel doesn't support regular gemm.");
static_assert(
(cute::is_tuple<ElementAOptionalTuple>::value ^ cute::is_tuple<ElementBOptionalTuple>::value
|| (NeitherIsTuple && (cute::sizeof_bits<ElementA>::value != cute::sizeof_bits<ElementB>::value))),
"Either A OR B must be a tuple or the widths of A and B must be different.");
using ElementPairA = cute::conditional_t<IsMixedInput && IsANarrow && NeitherIsTuple, cute::tuple<ElementA>,
ElementAOptionalTuple>;
using ElementPairB = cute::conditional_t<IsMixedInput && !IsANarrow && NeitherIsTuple, cute::tuple<ElementB>,
ElementBOptionalTuple>;
static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
static_assert(IsATransformed, "A matrix should be transformed.");
// For fp32 types, map to tf32 MMA value type.
using ElementMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
using ElementAMma = ElementMma;
using ElementBMma = ElementMma;
static constexpr int IsSubbyteA = cute::sizeof_bits_v<ElementA> < 8;
using TmaElementA = cute::conditional_t<IsSubbyteA, uint8_t, ElementA>;
static constexpr int ScalingFactor = 1;
using TiledMma = decltype(detail::sm100_make_trivial_mixed_input_tiled_mma<ElementAMma, ElementB,
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB, KernelScheduleType>());
using AtomThrID = typename TiledMma::AtomThrID;
using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMma::ThrLayoutVMNK{})), _1, _1>;
using CtaTileShape_MNK = decltype(shape_div(TileShape_MNK{}, AtomThrShapeMNK{}));
// ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K)
using MmaShapeA_MK = decltype(partition_shape_A(
TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), cute::size<2>(TileShape_MNK{}))));
// ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K)
using MmaShapeB_NK = decltype(partition_shape_B(
TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), cute::size<2>(TileShape_MNK{}))));
using BlockTileA_M = decltype(cute::size<0, 0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{}));
using BlockTileA_K = decltype(cute::size<0, 1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(cute::size<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm100_cluster_shape_to_tma_atom_B(ClusterShape_MNK{}, AtomThrID{}));
// Input transform kernel cannot use TMA 2SM instructions.
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorA, ElementA,
BlockTileA_M, BlockTileA_K>());
using SmemLayoutAtomACompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorA,
ElementAMma, BlockTileA_M, BlockTileA_K>());
using SmemLayoutAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType<SmemLayoutAtomA,
SmemLayoutAtomACompute>;
static constexpr int MMA_M = cute::size<0, 0>(MmaShapeA_MK{});
using CopyAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType<
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementA>,
cute::conditional_t<
(UmmaMajorA == cute::UMMA::Major::K
&& !cute::is_base_of_v<KernelTmaWarpSpecializedMixedInputSmemSm100, KernelScheduleType>),
cute::conditional_t<(MMA_M == 64 && size(AtomThrID{}) == 1), SM100_TMEM_STORE_16dp256b1x,
SM100_TMEM_STORE_32dp32b8x>, // TS Implementation
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementA>> // SS Implementation
>;
using BlockTileB_N = decltype(cute::size<0, 0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{}));
using BlockTileB_K = decltype(cute::size<0, 1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{}));
// Input transform kernel cannot use TMA 2SM instructions.
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorB, ElementB,
BlockTileB_N, BlockTileB_K>());
using SmemLayoutAtomBCompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorB,
ElementBMma, BlockTileB_N, BlockTileB_K>());
using SmemLayoutAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType<SmemLayoutAtomB,
SmemLayoutAtomBCompute>;
using CopyAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType<
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementB>,
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementMma>>;
// Create the stride of the transformed input.
using StrideA = cutlass::gemm::TagToStrideA_t<GmemLayoutATag>;
using LayoutScale = cutlass::gemm::TagToStrideA_t<GmemLayoutScaleTag>;
using VoidShapeScale
= Shape<Shape<Int<128>, _1>, Shape<Int<64>, _1>, _1>; // Dummy Value to create a dummy ScaleConfig
using VoidStrideScale = Stride<Stride<_0, _1>, Stride<_0, _1>, _1>;
using VoidLayoutScale = Layout<VoidShapeScale, VoidStrideScale>;
using NonVoidLayoutScale = cute::conditional_t<cute::is_void_v<LayoutScale>, VoidLayoutScale, LayoutScale>;
using StridePairA = decltype(cute::make_tuple(StrideA{}, NonVoidLayoutScale{}));
// SmemCarveout
static constexpr int SchedulerPipelineStageCount = 3;
static constexpr bool IsArrayOfPointersGemm
= (cute::is_base_of_v<KernelScheduleSm100PtrArrayFastFP32Gemm, KernelScheduleType>);
// CLCPipeline = PipelineCLCFetchAsync
static constexpr auto CLCPipelineStorage
= sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape_MNK>::SharedStorage);
// CLC (scheduler) response
static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize;
// CLC Throttle pipeline storage
static constexpr auto CLCThrottlePipelineStorage
= sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage);
// Tmem dealloc
static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier);
// Tmem ptr storage
static constexpr auto TmemBasePtrsStorage = sizeof(uint32_t);
// Tensormap Storage
static constexpr size_t TensorMapStorage
= IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
// Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage
static constexpr auto KernelSmemCarveout = static_cast<int>(CLCPipelineStorage + CLCResponseStorage
+ CLCThrottlePipelineStorage + TmemDeallocStorage + TmemBasePtrsStorage + TensorMapStorage);
// Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations
static constexpr int Sm100ReducedSmemCapacityBytes = detail::sm100_smem_capacity_bytes - KernelSmemCarveout;
static constexpr int ScaleGranularityK = get_ScaleGranularityK<LayoutScale>();
static constexpr auto stage_info
= cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_weightonly<
Sm100ReducedSmemCapacityBytes, TmaElementA, ElementAMma, ElementScale, ElementZero, ElementB,
CtaTileShape_MNK, TiledMma, KernelScheduleType, UmmaMajorA, ScaleGranularityK>(StageCountType{});
static constexpr int Load2TransformPipelineStageCount = get<0>(stage_info);
static constexpr int Transform2MmaPipelineStageCount = get<1>(stage_info);
static constexpr int AccumulatorPipelineStageCount = get<2>(stage_info);
static_assert(!IsArrayOfPointersGemm, "mixed input does not support grouped gemm on Blackwell");
using DispatchPolicy
= cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedMixedInput<Load2TransformPipelineStageCount,
Transform2MmaPipelineStageCount, SchedulerPipelineStageCount, AccumulatorPipelineStageCount,
ClusterShape_MNK>;
using CollectiveOp = cutlass::gemm::collective::CollectiveMmaSm100WeightOnly<DispatchPolicy, TileShape_MNK,
ElementPairA, StridePairA, ElementPairB, cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>, TiledMma,
GmemTiledCopyA, SmemLayoutAtomPairA, CopyAtomPairA, cute::identity, GmemTiledCopyB, SmemLayoutAtomPairB,
CopyAtomPairB, cute::identity>;
};
} // namespace cutlass::gemm::collective

View File

@ -0,0 +1,42 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_sm100_weightonly.hpp"
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
class GmemLayoutB, int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType, class Enable = void>
struct CollectiveBuilderSm100WeightOnly
{
static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/builders/sm100_umma_builder_weightonly.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,42 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/detail/dependent_false.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class DispatchPolicy, class TileShape, class ElementA, class StrideA, class ElementB, class StrideB,
class TiledMma, class GmemTiledCopyA, class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB, class TransformB>
struct CollectiveMmaSm100WeightOnly
{
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -533,8 +533,8 @@ struct GemmFpAIntB
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ == 890)
run_kernel<arch::Sm89>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 1000)
// Use SM80 implementation for GB10x, GB20x.
#elif (__CUDA_ARCH__ >= 1200)
// Use SM80 implementation for GB20x.
run_kernel<arch::Sm80>(params, shared_storage);
#else
CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.

View File

@ -87,7 +87,9 @@ public:
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
struct LayoutDetailsB<TypeA, uint8_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability != 100
&& Arch::kMinComputeCapability != 103>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
@ -102,7 +104,9 @@ public:
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
struct LayoutDetailsB<TypeA, uint4b_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability != 100
&& Arch::kMinComputeCapability != 103>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
@ -116,6 +120,26 @@ public:
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
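// Added note: unlike the SM75-SM90 specializations above, which signal interleaved-B dequantization via
// OpMultiplyAddDequantizeInterleavedBToA, the SM100/SM103 specializations below keep B plain column-major with a
// standard multiply-add; dequantization instead happens in the Blackwell transform pipeline (see
// MixedInputUtilsSM100 above).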
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability == 100 || Arch::kMinComputeCapability == 103>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability == 100 || Arch::kMinComputeCapability == 103>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@ -409,14 +409,14 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100_dynamic_cluster_shape
}
std::vector<std::pair<CutlassTileConfigSM100, ClusterShape>> tile_configs{
{CutlassTileConfigSM100::CtaShape128x128x128B, cluster1sm},
{CutlassTileConfigSM100::CtaShape128x256x128B, cluster1sm},
{CutlassTileConfigSM100::CtaShape128x32x128B, cluster1sm},
{CutlassTileConfigSM100::CtaShape64x64x128B, cluster1sm},
{CutlassTileConfigSM100::CtaShape64x32x128B, cluster1sm},
{CutlassTileConfigSM100::CtaShape64x64x128B, cluster1sm},
{CutlassTileConfigSM100::CtaShape64x128x128B, cluster1sm},
{CutlassTileConfigSM100::CtaShape64x256x128B, cluster1sm},
{CutlassTileConfigSM100::CtaShape128x32x128B, cluster1sm},
{CutlassTileConfigSM100::CtaShape128x64x128B, cluster1sm},
{CutlassTileConfigSM100::CtaShape128x128x128B, cluster1sm},
{CutlassTileConfigSM100::CtaShape128x256x128B, cluster1sm},
};
if (supports_2sm)
@ -479,6 +479,30 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100(
}
return candidate_configs;
}
else if (config & CutlassGemmConfig::WEIGHT_ONLY)
{
std::vector<CutlassTileConfigSM100> tile_configs{
CutlassTileConfigSM100::CtaShape64x128x128B,
CutlassTileConfigSM100::CtaShape128x128x128B,
};
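// These two CTA shapes match the ones instantiated by sm100_dispatch_gemm_to_cutlass in
// fpA_intB_gemm_template_sm100.h, so no other tiles are enumerated here.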
std::vector<CutlassGemmConfig> candidate_configs;
for (auto const& tile_config : tile_configs)
{
CutlassGemmConfig cutlassKernelConfig(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_1x1x1, ClusterShape::Undefined, ClusterShape::Undefined, sm);
candidate_configs.push_back(cutlassKernelConfig);
}
// Add the CUDA kernel profiler to the tactics for weight-only plugins.
CutlassGemmConfig cudaKernelConfig(CutlassTileConfigSM100::CtaShape64x128x128B, MainloopScheduleType::AUTO,
EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1, ClusterShape::Undefined,
ClusterShape::Undefined, sm);
cudaKernelConfig.enableCudaKernel = true;
candidate_configs.push_back(cudaKernelConfig);
return candidate_configs;
}
else
{
TLLM_THROW("Not Implemented: SM100 GEMM candidates have not been defined.");

View File

@ -134,8 +134,17 @@ LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
{
return getLayoutDetailsForArch<cutlass::arch::Sm90>(quant_type);
}
else if (arch >= 100)
else if (arch == 100)
{
return getLayoutDetailsForArch<cutlass::arch::Sm100>(quant_type);
}
else if (arch == 103)
{
return getLayoutDetailsForArch<cutlass::arch::Sm103>(quant_type);
}
else if (arch >= 120)
{
// Use SM80 implementation for GB20x.
return getLayoutDetailsForArch<cutlass::arch::Sm80>(quant_type);
}
else
@ -591,23 +600,31 @@ void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, in
// Works on row major data, so issue this permutation first.
if (details.uses_imma_ldsm)
{
TLLM_LOG_INFO("permute_B_rows_for_mixed_gemm");
permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch);
src_buf.swap(dst_buf);
}
if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR)
{
TLLM_LOG_INFO("subbyte_transpose");
subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type);
src_buf.swap(dst_buf);
}
if (details.columns_interleaved > 1 && arch != 90)
if (details.columns_interleaved > 1 && (arch != 90))
{
TLLM_LOG_INFO("interleave_column_major_tensor");
interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details);
src_buf.swap(dst_buf);
}
add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type);
if (arch != 100 && arch != 103)
{
TLLM_LOG_INFO("add_bias_and_interleave_quantized_tensor_inplace");
add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type);
}
TLLM_LOG_INFO("copy");
std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight);
}

View File

@ -40,6 +40,7 @@
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm100.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h"
namespace tk = tensorrt_llm::common;
@ -429,7 +430,7 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
}
else if ((sm_ >= 80 && sm_ < 89) || sm_ >= 100)
else if ((sm_ >= 80 && sm_ < 89) || sm_ >= 120)
{
dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, cutlass::arch::Sm80,
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
@ -454,9 +455,27 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
static_assert(!cutlass::platform::is_same<ActivationType, __nv_fp8_e4m3>::value
|| cutlass::platform::is_same<ScaleZeroType, half>::value,
"ScaleZeroType must be half for activation=fp8");
#ifdef COMPILE_HOPPER_TMA_GEMMS
cutlass_kernels_oss::sm90_dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType,
OutputType, QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k,
group_size, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
#else // COMPILE_HOPPER_TMA_GEMMS
throw std::runtime_error(
"[TensorRT LLM Error][fpA_intB Runner] Please recompile with support for hopper by passing 90-real as an "
"arch to build_wheel.py.");
#endif // COMPILE_HOPPER_TMA_GEMMS
}
else if (sm_ == 100 || sm_ == 103)
{
#ifdef COMPILE_BLACKWELL_TMA_GEMMS
cutlass_kernels_oss::sm100_dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType,
OutputType, QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k,
group_size, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
#else // COMPILE_BLACKWELL_TMA_GEMMS
throw std::runtime_error(
"[TensorRT LLM Error][fpA_intB Runner] Please recompile with support for blackwell by passing 100-real as "
"an arch to build_wheel.py.");
#endif // COMPILE_BLACKWELL_TMA_GEMMS
}
else
{
@ -537,8 +556,9 @@ CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, Bia
{
static constexpr bool is_weight_only = !std::is_same<ActivationType, WeightType>::value;
tkc::CutlassGemmConfig::CandidateConfigTypeParam config_type_param
= tkc::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER;
tkc::CutlassGemmConfig::CandidateConfigTypeParam config_type_param = (sm_ >= 100)
? tkc::CutlassGemmConfig::CandidateConfigTypeParam::BLACKWELL
: tkc::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER;
if (is_weight_only)
{
config_type_param = static_cast<tkc::CutlassGemmConfig::CandidateConfigTypeParam>(

View File

@ -0,0 +1,153 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cute/numeric/integral_constant.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels_oss
{
namespace tk = tensorrt_llm::common;
namespace tkc = tensorrt_llm::cutlass_extensions;
using namespace cute;
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
typename MainloopSchedule>
void sm100_dispatch_epilogue_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
switch (gemm_config.epilogue_schedule)
{
case tkc::EpilogueScheduleType::AUTO:
// TODO: use heuristics to select the epilogue schedule, depending on the CTA shape and cluster shape
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
sm100_generic_mixed_gemm_kernelLauncher<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType,
QuantOp, EpilogueTag, CTAShape, ClusterShape, MainloopSchedule, EpilogueSchedule>(A, B, weight_scales,
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream,
occupancy);
break;
default:
throw std::runtime_error(
"[TensorRT LLM Error][fpA_intB][sm100_dispatch_epilogue_schedules] Unsupported epilogue schedule for mixed "
"type GEMM.");
break;
}
}
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape>
void sm100_dispatch_mainloop_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
switch (gemm_config.mainloop_schedule)
{
case tkc::MainloopScheduleType::AUTO:
// TODO: use heuristics to select the mainloop schedule, depending on the CTA shape and cluster shape
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100;
sm100_dispatch_epilogue_schedules<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
EpilogueTag, CTAShape, ClusterShape, MainloopSchedule>(A, B, weight_scales, weight_zero_points, biases,
alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
default:
throw std::runtime_error(
"[TensorRT LLM Error][fpA_intB][sm100_dispatch_mainloop_schedules] Unsupported mainloop schedule for mixed "
"type GEMM.");
break;
}
}
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape>
void sm100_dispatch_gemm_config(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
switch (gemm_config.cluster_shape)
{
// TODO: add support for other cluster shapes
case tkc::ClusterShape::ClusterShape_1x1x1:
sm100_dispatch_mainloop_schedules<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
EpilogueTag, CTAShape, Shape<_1, _1, _1>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n,
k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
default:
throw std::runtime_error(
"[TensorRT LLM Error][fpA_intB][sm100_dispatch_gemm_config] Unsupported CTA and Cluster shapes for mixed "
"type GEMM.");
break;
}
}
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag>
void sm100_dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
cudaStream_t stream, int* occupancy = nullptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
// Note that SIMT configs are omitted here since they are not supported for fpA_intB.
// We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best
// for mixed type gemms.
constexpr int Ktile = 128 / sizeof(ActivationType);
using _Ktile = Int<Ktile>;
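// For example, Ktile = 64 elements for half/bfloat16 activations (2 bytes each) and 128 for fp8,
// matching the 128-byte K extent in the CtaShape*x128B names.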
switch (gemm_config.tile_config_sm100)
{
case tkc::CutlassTileConfigSM100::CtaShape64x128x128B:
sm100_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
EpilogueTag, Shape<_64, _128, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k,
group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::CutlassTileConfigSM100::CtaShape128x128x128B:
sm100_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
EpilogueTag, Shape<_128, _128, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k,
group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
default:
throw std::runtime_error(
"[TensorRT LLM Error][fpA_intB][sm100_dispatch_gemm_to_cutlass] Unsupported tile shape for mixed type "
"GEMM.");
break;
}
}
} // namespace cutlass_kernels_oss
} // namespace kernels
} // namespace tensorrt_llm
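For orientation, a hypothetical call into this dispatch chain, mirroring the SM100 branch added to the fpA_intB runner above; the concrete types (half activations, uint4b_t weights), the quantization mode, and the null bias/zero pointers are illustrative assumptions, not requirements of this file.

#include <cuda_fp16.h>

#include "cutlass/numeric_types.h"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm_configs.h"
#include "cutlass_extensions/weight_only_quant_op.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm100.h"

// Hypothetical wrapper: dispatch a half x int4 (fine-grained, scale-only) GEMM on SM100.
void example_sm100_weight_only_gemm(half const* A, cutlass::uint4b_t const* B, half const* weight_scales, half* C,
    int m, int n, int k, int group_size, char* workspace, size_t workspace_bytes,
    tensorrt_llm::cutlass_extensions::CutlassGemmConfig const& gemm_config, cudaStream_t stream)
{
    using namespace tensorrt_llm::kernels::cutlass_kernels_oss;
    // The template arguments must match one of the instantiations emitted by the generator script.
    sm100_dispatch_gemm_to_cutlass<half, cutlass::uint4b_t, half, half, half,
        cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, tensorrt_llm::cutlass_extensions::EpilogueOpBias>(A, B,
        weight_scales, /*weight_zero_points=*/nullptr, /*biases=*/nullptr, /*alpha=*/1.f, C, m, n, k, group_size,
        workspace, workspace_bytes, gemm_config, stream, /*occupancy=*/nullptr);
}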

View File

@ -0,0 +1,39 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cutlass_extensions/gemm_configs.h"
#include "cutlass_extensions/weight_only_quant_op.h"
#include <cuda_runtime_api.h>
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels_oss
{
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
typename MainloopScheduleType, typename EpilogueScheduleType>
void sm100_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B,
ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases,
float const alpha, OutputType* C, int m, int n, int k, int const group_size,
tensorrt_llm::cutlass_extensions::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr);
} // namespace cutlass_kernels_oss
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,286 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // __GNUC__
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm_configs.h"
#include "cutlass_extensions/gemm/collective/collective_builder_sm100_weightonly.hpp"
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif // __GNUC__
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels_oss
{
using namespace tensorrt_llm::kernels::cutlass_kernels;
namespace tk = tensorrt_llm::common;
namespace tkc = tensorrt_llm::cutlass_extensions;
using namespace cute;
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
typename MainloopScheduleType, typename EpilogueScheduleType>
void sm100_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B,
ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases,
float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config,
char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
using CutlassActivationType = typename TllmToCutlassTypeAdapter<ActivationType>::type;
#ifdef COMPILE_BLACKWELL_TMA_GEMMS
if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v<cutlass::arch::Sm100, CTAShape, ClusterShape,
false, ActivationType>)
{
using CutlassWeightType__ = typename TllmToCutlassTypeAdapter<WeightType>::type;
// We need to remap this since SM100 uses a different layout for the weight matrix.
using CutlassWeightType_ = std::conditional_t<std::is_same_v<CutlassWeightType__, cutlass::uint4b_t>,
cutlass::int4b_t, CutlassWeightType__>;
using CutlassWeightType
= std::conditional_t<std::is_same_v<CutlassWeightType_, uint8_t>, int8_t, CutlassWeightType_>;
using CutlassScaleZeroType = typename TllmToCutlassTypeAdapter<ScaleZeroType>::type;
using CutlassBiasType = typename TllmToCutlassTypeAdapter<BiasType>::type;
using CutlassOutputType = typename TllmToCutlassTypeAdapter<OutputType>::type;
static_assert(std::is_same_v<CutlassActivationType, cutlass::half_t>
|| std::is_same_v<CutlassActivationType, cutlass::bfloat16_t>
|| std::is_same_v<CutlassActivationType, cutlass::float_e4m3_t>
|| std::is_same_v<CutlassActivationType, cutlass::float_e5m2_t>,
"Activation type must be bfloat16, half, FP8");
static_assert(std::is_same_v<CutlassWeightType, int8_t> || std::is_same_v<CutlassWeightType, cutlass::int4b_t>
|| std::is_same_v<CutlassWeightType, cutlass::float_e4m3_t>
|| std::is_same_v<CutlassWeightType, cutlass::float_e5m2_t>,
"Weight type must be fp8, int8_t or int4_t");
static_assert(!std::is_same_v<CutlassActivationType, cutlass::float_e4m3_t>
|| std::is_same_v<CutlassScaleZeroType, cutlass::half_t>,
"Scale/Zero type must be half for fp8 activation");
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<CutlassActivationType>::value;
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<CutlassWeightType>::value;
// We manually swap and transpose A and B, so keep the transposed input layouts.
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using ElementZero = CutlassScaleZeroType;
using ElementScale = CutlassScaleZeroType;
// C/D matrix configuration. We reuse the C operand for the bias and set the stride for broadcast.
using LayoutBias = cutlass::layout::RowMajor;
constexpr int AlignmentBias = 128 / cutlass::sizeof_bits<CutlassBiasType>::value;
// D matrix configuration
using LayoutOutput = cutlass::layout::RowMajor;
constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<CutlassOutputType>::value;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
// using TileShape = CTAShape; // Threadblock-level tile size
constexpr static bool Is2SM = cute::size<0>(ClusterShape{}) == 2;
constexpr static int TileM = cute::size<0>(CTAShape{}) * (Is2SM ? 2 : 1);
constexpr static int TileN = cute::size<1>(CTAShape{});
constexpr static int TileK = cute::size<2>(CTAShape{});
using TileShape = cute::Shape<cute::Int<TileM>, cute::Int<TileN>, cute::Int<TileK>>;
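// TileShape is the MMA-level tile: the CTA-level M extent is doubled for a 2SM cluster (e.g. a 64x128 CTA tile
// becomes a 128x128 MMA tile); with the 1x1x1 clusters dispatched in fpA_intB_gemm_template_sm100.h, TileShape
// equals CTAShape.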
using MainloopSchedule = std::conditional_t<Is2SM, cutlass::gemm::KernelTmaWarpSpecialized2SmMixedInputSm100,
cutlass::gemm::KernelTmaWarpSpecialized1SmMixedInputSm100>;
using EpilogueSchedule = std::conditional_t<Is2SM, cutlass::epilogue::TmaWarpSpecialized2Sm,
cutlass::epilogue::TmaWarpSpecialized1Sm>;
static_assert(std::is_same_v<EpilogueTag, tensorrt_llm::cutlass_extensions::EpilogueOpBias>, "");
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<ArchTag, OperatorClass, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute,
// Transpose the layout of D here since we use the explicit swap + transpose trick.
// The C operand is reused for the bias and broadcast via stride_C below.
CutlassBiasType, typename cutlass::layout::LayoutTranspose<LayoutBias>::type, AlignmentBias,
CutlassOutputType, typename cutlass::layout::LayoutTranspose<LayoutOutput>::type, AlignmentOutput,
EpilogueSchedule>::CollectiveOp;
using PackedScaleZero = cute::tuple<CutlassWeightType, ElementScale, ElementZero>;
using PackedScale = cute::tuple<CutlassWeightType, ElementScale>;
using ElementBCollectiveInfo = std::conditional_t<cutlass::hasZero(QuantOp), PackedScaleZero, PackedScale>;
constexpr int ScaleGranularityN = 1; // Should be less than or equal to GEMM_N
constexpr int ScaleGranularityK = size<2>(TileShape{}); // Should be less than or equal to GEMM_K
using ScaleConfig = cutlass::detail::Sm100MixedInputBlockwiseScaleConfig<ScaleGranularityN, ScaleGranularityK>;
using LayoutScale = decltype(ScaleConfig::deduce_layout_scale()); // Layout type for SFA matrix operand
LayoutScale layout_S = ScaleConfig::tile_atom_to_shape_scale(make_shape(n, k, 1));
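// With ScaleGranularityN = 1 and ScaleGranularityK = TileK, layout_S stores one scale (and zero) entry per n index
// and per TileK-wide block of k; the group_size % cta_shape_k check below keeps quantization groups aligned to
// these blocks.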
// We swap A and B operands to the builder here
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilderSm100WeightOnly<ArchTag,
OperatorClass, ElementBCollectiveInfo, cute::tuple<LayoutB_Transpose, LayoutScale>, AlignmentB,
CutlassActivationType, LayoutA_Transpose, AlignmentA, ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename GemmKernel::StrideA;
using StrideB = typename GemmKernel::StrideB;
using StrideC = typename GemmKernel::StrideC;
using StrideD = typename GemmKernel::StrideD;
if (weight_scales == nullptr)
{
throw std::runtime_error("Weight scales must always be set to a non-null value.");
}
if constexpr (cutlass::isFinegrained(QuantOp))
{
int cta_shape_k = cute::size<2>(TileShape{});
if (group_size % cta_shape_k != 0)
{
std::string err_msg = "The group size must a multiple of " + std::to_string(cta_shape_k);
throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner]" + err_msg);
}
if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)
{
if (weight_zero_points != nullptr)
{
throw std::runtime_error("Weight zero pointer must be a nullptr for scale only fine grained");
}
}
else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
{
if (weight_zero_points == nullptr)
{
throw std::runtime_error("Weight zero pointer must be valid for scale and bias fine grained");
}
}
}
else
{
if (group_size != k)
{
throw std::runtime_error("Invalid group size for per column scaling kernels.");
}
if (weight_zero_points != nullptr)
{
throw std::runtime_error("Weight zero-points must be null when running per column scaling");
}
}
StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1));
StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(0, m, 1));
typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, {n, m, k, 1},
{reinterpret_cast<CutlassWeightType const*>(B), stride_B, reinterpret_cast<CutlassActivationType const*>(A),
stride_A, reinterpret_cast<ElementScale const*>(weight_scales), layout_S, group_size,
reinterpret_cast<ElementZero const*>(weight_zero_points)},
{{alpha}, reinterpret_cast<CutlassBiasType const*>(biases), stride_C,
reinterpret_cast<CutlassOutputType*>(C), stride_D}};
Gemm gemm;
if (gemm.get_workspace_size(args) > workspace_bytes)
{
TLLM_LOG_ERROR("[TensorRT LLM Error][fpA_intB Runner] given workspace size insufficient.");
}
auto can_implement = gemm.can_implement(args);
if (can_implement != cutlass::Status::kSuccess)
{
std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: "
+ std::string(cutlass::cutlassGetStatusString(can_implement));
std::cout << err_msg << std::endl;
throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg);
}
auto init_status = gemm.initialize(args, workspace, stream);
if (init_status != cutlass::Status::kSuccess)
{
std::string err_msg = "Failed to initialize cutlass fpA_intB gemm. Error: "
+ std::string(cutlassGetStatusString(init_status));
throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg);
}
auto run_status = gemm.run(stream);
if (run_status != cutlass::Status::kSuccess)
{
std::string err_msg
= "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status));
throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg);
}
}
else
{
std::stringstream ss;
ss << "[TensorRT LLM Error][fpA_intB Runner] Config (" << (int64_t) cute::size<0>(CTAShape{}) << ","
<< (int64_t) cute::size<1>(CTAShape{}) << "," << (int64_t) cute::size<2>(CTAShape{}) << ") ("
<< (int64_t) cute::size<0>(ClusterShape{}) << "," << (int64_t) cute::size<1>(ClusterShape{}) << ","
<< (int64_t) cute::size<2>(ClusterShape{}) << ") not compiled with FAST_BUILD.";
throw std::runtime_error(ss.str());
}
#else // COMPILE_BLACKWELL_TMA_GEMMS
throw std::runtime_error(
"[TensorRT LLM Error][fpA_intB Runner] Please recompile with support for Blackwell by passing 100-real as an "
"arch to build_wheel.py.");
#endif // COMPILE_BLACKWELL_TMA_GEMMS
}
} // namespace cutlass_kernels_oss
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -213,8 +213,10 @@ def instantiate_operation_tma_warp_specialized(operation):
if operation.gemm_kind == GemmKind.Gemm:
weight_tag = DataTypeTag[operation.weight_type]
# Use sm100_generic_mixed_gemm_kernelLauncher for SM100 and above
launcher_name = "sm100_generic_mixed_gemm_kernelLauncher" if operation.arch >= 100 else "sm90_generic_mixed_gemm_kernelLauncher"
instantiation = f"""
template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag},
template void {launcher_name}<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag},
{quant_op}, {epi_tag},
{cute_cta_shape}, {cute_cga_shape},
{kernel_sched}, {epi_sched}> (
@ -808,8 +810,69 @@ def generate_sm103_operations(is_arch_enabled):
return operations
def generate_sm100_mixed_gemm_operations(is_arch_enabled):
    if not is_arch_enabled:
        return []

    arch = 100

    # For SM100 (Blackwell), we support similar dtypes as SM90
    # Takes the form (activation_type, weight_type, scalezero_type, bias_type, output_type)
    supported_dtypes = [
        (DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
        (DataType.e4m3, DataType.u4, DataType.f16, DataType.bf16, DataType.bf16),
        (DataType.f16, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
        (DataType.bf16, DataType.u4, DataType.bf16, DataType.bf16, DataType.bf16),
        (DataType.f16, DataType.u8, DataType.f16, DataType.f16, DataType.f16),
        (DataType.bf16, DataType.u8, DataType.bf16, DataType.bf16, DataType.bf16)
    ]
    quant_ops = [
        TrtLlm_QuantOp.per_column_scale_only,
        TrtLlm_QuantOp.finegrained_scale_only,
        TrtLlm_QuantOp.finegrained_scale_and_zeros
    ]
    epi_tags = [TrtLlm_EpilogueTag.epilogue_op_bias]

    # SM100 uses different tile shapes
    M_TILES = [64, 128]
    N_TILES = [128]
    cta_shapes_mn = product(M_TILES, N_TILES)
    warp_shape = [4, 1, 1]
    stages = 0  # auto
    # SM100 currently only supports 1x1x1 cluster shape
    cga_shapes = [(1, 1, 1)]

    partial_args = product(supported_dtypes, quant_ops, epi_tags, cta_shapes_mn, cga_shapes)

    operations = list()
    for dtype_combo, quant_op, epi_tag, cta_shape_mn, cga_shape in partial_args:
        max_k_bits = 128 * 8
        cta_shape_k = max_k_bits // GetDataTypeBits(dtype_combo[0])
        cta_shape_mnk = cta_shape_mn + (cta_shape_k, )
        # SM100 uses 1SM schedule for mixed type GEMM
        mainloop_schedule = KernelScheduleType.TmaWarpSpecialized1SmSm100
        epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1Sm
        fpA_intB_operation = TrtLlm_GemmLauncher(GemmKind.Gemm, arch, *dtype_combo, quant_op, epi_tag,
                                                 cta_shape_mnk, warp_shape, stages, cga_shape,
                                                 mainloop_schedule, epi_schedule)
        if is_gemm_op_valid_sm100(fpA_intB_operation):
            operations.append(fpA_intB_operation)

    return operations
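# Worked example (illustrative): with max_k_bits = 128 * 8 and GetDataTypeBits returning the activation's
# storage width in bits, 16-bit activations (fp16/bf16) give cta_shape_k = 1024 // 16 = 64, while 8-bit e4m3
# activations give 1024 // 8 = 128, i.e. CTA tiles of (64|128, 128, 64) and (64|128, 128, 128) respectively.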
def generate_sm100_operations(is_arch_enabled):
    operations = generate_sm100_grouped_gemm_operations(is_arch_enabled, 100)
    operations.extend(generate_sm100_mixed_gemm_operations(is_arch_enabled))
    return operations
@ -878,6 +941,7 @@ if __name__ == "__main__":
output_dir = os.path.abspath(args.output_dir)
fpA_intB_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"
fpA_intB_sm100_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.inl"
moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl"
# moe_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl"
moe_mixed_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl"
@ -887,6 +951,8 @@ if __name__ == "__main__":
inl_map = {
(GemmKind.Gemm, 90): [fpA_intB_inl],
(GemmKind.Gemm, 100): [fpA_intB_sm100_inl],
(GemmKind.Gemm, 103): [fpA_intB_sm100_inl],
(GemmKind.Grouped, 90): [moe_gemm_inl],
(GemmKind.Grouped, 100): [moe_gemm_inl],
(GemmKind.Grouped, 103): [moe_gemm_inl],

View File

@ -0,0 +1,32 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false, 64);
// KTile=128 for w4a8
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false, 128);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,29 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajor, false, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,29 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajor, false, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,29 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
KernelType::BF16Int8PerChannel, BF16DetailsA, Int8DetailsW, ColumnMajor, false, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,32 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false, 64);
// KTile=128 for w4a8
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false, 128);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,29 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajor, false, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,29 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajor, false, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,29 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajor, false, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -52,7 +52,7 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s)
EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true);
EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
}
else if ((arch >= 80 && arch < 90) || arch >= 100)
else if ((arch >= 80 && arch < 90) || arch >= 120)
{
if (arch == 89 || arch >= 120)
{
@ -68,7 +68,7 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s)
EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
EXEC(KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
}
else if (arch >= 90)
else if (arch == 90)
{
// Dispatchers for W4A8 groupwise
EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true);
@ -83,6 +83,20 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s)
EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true);
EXEC(KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true);
}
else if (arch == 100 || arch == 103)
{
EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false);
EXEC_W4A8(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false);
EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajor, false);
EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajor, false);
EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false);
EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false);
EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajor, false);
EXEC(KernelType::BF16Int8PerChannel, BF16DetailsA, Int8DetailsW, ColumnMajor, false);
EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajor, false);
EXEC(KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajor, false);
}
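// Note: unlike the Hopper branch above, the sm100/sm103 dispatchers use the plain, non-interleaved
// ColumnMajor weight layout; this matches preprocess_weights_for_mixed_gemm, which this change updates
// to skip weight interleaving for sm 100/103.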
#undef EXEC_W4A8
#undef EXEC
}

View File

@ -123,7 +123,7 @@ The language component decides which quantization methods are supported by a giv
| Model | NVFP4 | MXFP4 | FP8(per tensor)| FP8(block scaling) | FP8(rowwise) | FP8 KV Cache | NVFP4 KV Cache | W4A8 AWQ | W4A16 AWQ | W4A8 GPTQ | W4A16 GPTQ |
| :------------- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :-------: | :-------: | :--------: | :--------: |
| Blackwell(sm120) | Y | Y | Y | . | . | Y | . | . | . | . | . |
| Blackwell(sm100/103) | Y | Y | Y | Y | . | Y | Y | . | . | . | . |
| Blackwell(sm100/103) | Y | Y | Y | Y | . | Y | Y | Y | Y | Y | Y |
| Hopper | . | . | Y | Y | Y | Y | . | Y | Y | Y | Y |
| Ada Lovelace | . | . | Y | . | . | Y | . | Y | Y | Y | Y |
| Ampere | . | . | . | . | . | Y | . | . | Y | . | Y |

View File

@ -1327,7 +1327,7 @@ def _(
class FinegrainedMixedDtypeGemm(TunableRunner):
_runner_dict = dict()
MAX_SUPPORTED_SM_VERSION = 90
MAX_SUPPORTED_SM_VERSION = 103
def __init__(self, activation_dtype: torch.dtype, output_dtype: torch.dtype,
quant_mode: int):
@ -1362,7 +1362,7 @@ class FinegrainedMixedDtypeGemm(TunableRunner):
if get_sm_version() > self.MAX_SUPPORTED_SM_VERSION:
raise ValueError(
f"SM version {get_sm_version()} is not supported for W4A16 GEMM"
f"SM version {get_sm_version()} is not supported for W4A16/W4A8 finegrained mixed dtype GEMM"
)
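# With MAX_SUPPORTED_SM_VERSION raised from 90 to 103, this guard now also admits Blackwell
# sm100/sm103 while continuing to reject newer, untested SM versions.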
activation, weights_packed, scales = inputs

View File

@ -961,8 +961,8 @@ def preprocess_weights_for_mixed_gemm(
tensor = tensor.unsqueeze(0)
elif sm_ >= 90:
sm_ = 80
if sm_ > 90:
sm_ = 80
if sm_ == 100 or sm_ == 103:
do_weight_interleave = False
permutation_map = {
"16_8": [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15],
@ -990,7 +990,7 @@ def preprocess_weights_for_mixed_gemm(
assert (num_rows % B_ROWS_PER_MMA == 0)
assert (num_cols % MMA_SHAPE_N == 0)
if do_weight_interleave:
if do_weight_interleave and sm_ < 100:
row_idx_list = [(row_idx // B_ROWS_PER_MMA) * B_ROWS_PER_MMA +
permutation_map[f"{BITS_PER_ELT_A}_{BITS_PER_ELT_B}"][
row_idx % B_ROWS_PER_MMA]