mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][feat] sm100 weight-only kernel (#10190)
Signed-off-by: Cheng Hang <chang@nvidia.com>
This commit is contained in:
parent
b5a1e10bc0
commit
656c705ff1
@ -686,4 +686,212 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template <class Collective>
|
||||
struct MixedInputUtilsSM100
|
||||
{
|
||||
private:
|
||||
using KernelSchedule = typename Collective::KernelSchedule;
|
||||
using ConversionMode = typename Collective::ConversionMode;
|
||||
using SmemLayoutA = typename Collective::SmemLayoutA;
|
||||
using SmemLayoutB = typename Collective::SmemLayoutB;
|
||||
using ElementScale = typename Collective::ElementScale;
|
||||
using ElementZero = typename Collective::ElementZero;
|
||||
static constexpr auto KernelConversionMode = Collective::KernelConversionMode;
|
||||
|
||||
public:
|
||||
// Helper functions to select packing for conversion
|
||||
template <class SrcType, class DstType, int Cosize>
|
||||
struct select_packing
|
||||
{ // Naive packing policy
|
||||
|
||||
static constexpr auto value()
|
||||
{
|
||||
return Int<cute::gcd(Cosize, 32 / cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>))>{};
|
||||
}
|
||||
};
|
||||
|
||||
/// (Designed for separate transform pipeline in Blackwell)
|
||||
/// Utilities to dequantize A.
|
||||
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
|
||||
CUTLASS_DEVICE static void dequantize_A_kblock_for_transform(Tensor<EngineIn, LayoutIn> const& tArA,
|
||||
Tensor<EngineOut, LayoutOut>& tArACompute, cute::tuple<Ts...> const& partitioned_extra_info, int const k_block)
|
||||
{
|
||||
|
||||
static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
|
||||
static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
|
||||
static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
|
||||
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
|
||||
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
|
||||
using SrcType = typename EngineIn::value_type;
|
||||
using DstType = typename EngineOut::value_type;
|
||||
|
||||
auto src = tArA(_, _, _, k_block);
|
||||
auto dst = tArACompute(_, _, _, k_block);
|
||||
auto pSrc = raw_pointer_cast(src.data());
|
||||
auto pDst = const_cast<DstType*>(raw_pointer_cast(dst.data()));
|
||||
constexpr int num_elements = decltype(size(src))::value;
|
||||
|
||||
constexpr int pack = decltype(select_packing<SrcType, DstType, num_elements>::value())::value;
|
||||
using Converter
|
||||
= cutlass::NumericArrayConverter<DstType, SrcType, pack, cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
using SrcArray = cutlass::Array<SrcType, pack>;
|
||||
using DstArray = cutlass::Array<DstType, pack>;
|
||||
constexpr int DstElementsPerReg = 32 / sizeof_bits_v<DstType>;
|
||||
using RegArray = cutlass::AlignedArray<uint32_t, pack / DstElementsPerReg, sizeof(DstArray)>;
|
||||
|
||||
auto src_arr = recast<SrcArray>(src);
|
||||
auto dst_arr = recast<DstArray>(dst);
|
||||
|
||||
Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, pack));
|
||||
|
||||
if constexpr (KernelConversionMode == ConversionMode::DirectConvert)
|
||||
{
|
||||
cute::transform(src_arr, dst_arr, Converter::convert);
|
||||
}
|
||||
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale)
|
||||
{
|
||||
|
||||
auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, _, k_block);
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
|
||||
|
||||
if constexpr (is_same_v<DstType, ElementScale>)
|
||||
{
|
||||
cute::transform(src_arr, dst_arr, Converter::convert);
|
||||
|
||||
using ScaleArray = cutlass::Array<ElementScale, pack>;
|
||||
auto scale_arr = recast<ScaleArray>(filter_zeros(scales));
|
||||
|
||||
if constexpr (is_same_v<DstType, cutlass::bfloat16_t>)
|
||||
{
|
||||
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, pack));
|
||||
|
||||
for (int i = 0; i < size<1>(dst_vm); ++i)
|
||||
{
|
||||
auto&& r = cute::recast<RegArray>(dst_vm(_, i))(0);
|
||||
auto&& scale_reg = cute::recast<RegArray>(scales_vm(_, i))(0);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (size_t ii = 0; ii < RegArray::kElements; ++ii)
|
||||
{
|
||||
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
bf16x2_val = __hmul2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(scale_reg[ii]));
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr int pack1 = decltype(select_packing<SrcType, ElementScale, num_elements>::value())::value;
|
||||
constexpr int pack2 = decltype(select_packing<ElementScale, DstType, num_elements>::value())::value;
|
||||
constexpr int pack = cute::gcd(pack1, pack2);
|
||||
using Converter1 = cutlass::NumericArrayConverter<ElementScale, SrcType, pack,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
using Converter2 = cutlass::NumericArrayConverter<DstType, ElementScale, pack,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
using SrcArray = cutlass::Array<SrcType, pack>;
|
||||
using DstArray = cutlass::Array<DstType, pack>;
|
||||
using StageArray = cutlass::Array<ElementScale, pack>;
|
||||
constexpr int iters = num_elements / pack;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < iters; ++i)
|
||||
{
|
||||
SrcArray const* pSrcArr = reinterpret_cast<SrcArray const*>(pSrc) + i;
|
||||
DstArray* pDstArr = reinterpret_cast<DstArray*>(pDst) + i;
|
||||
StageArray stageArr;
|
||||
stageArr = Converter1::convert(*pSrcArr);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < pack; ++j)
|
||||
{
|
||||
stageArr[j] = stageArr[j] * scales[i * pack + j];
|
||||
}
|
||||
*pDstArr = Converter2::convert(stageArr);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
|
||||
{
|
||||
static_assert(is_same_v<ElementScale, ElementZero>, "ElementScale and ElementZero must be the same.");
|
||||
|
||||
auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, _, k_block);
|
||||
auto const& zeros = cute::get<3>(partitioned_extra_info)(_, _, _, k_block);
|
||||
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
|
||||
CUTE_STATIC_ASSERT_V(size(src) == size(zeros));
|
||||
|
||||
if constexpr (is_same_v<DstType, ElementZero>)
|
||||
{
|
||||
cute::transform(src_arr, dst_arr, Converter::convert);
|
||||
|
||||
using ScaleArray = cutlass::Array<ElementScale, pack>;
|
||||
auto scale_arr = recast<ScaleArray>(filter_zeros(scales));
|
||||
|
||||
using ZeroArray = cutlass::Array<ElementZero, pack>;
|
||||
auto zero_arr = recast<ZeroArray>(filter_zeros(zeros));
|
||||
|
||||
if constexpr (is_same_v<DstType, cutlass::bfloat16_t>)
|
||||
{
|
||||
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, pack));
|
||||
Tensor zeros_vm = cute::group_modes<1, -1>(cute::zipped_divide(zeros, pack));
|
||||
|
||||
for (int i = 0; i < size<1>(dst_vm); ++i)
|
||||
{
|
||||
auto&& r = cute::recast<RegArray>(dst_vm(_, i))(0);
|
||||
auto&& scale_reg = cute::recast<RegArray>(scales_vm(_, i))(0);
|
||||
auto&& zero_reg = cute::recast<RegArray>(zeros_vm(_, i))(0);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (size_t ii = 0; ii < RegArray::kElements; ++ii)
|
||||
{
|
||||
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
bf16x2_val = __hmul2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(scale_reg[ii]));
|
||||
bf16x2_val = __hadd2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(zero_reg[ii]));
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{});
|
||||
cute::transform(dst_arr, zero_arr, dst_arr, cute::plus{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr int pack1 = decltype(select_packing<SrcType, ElementScale, num_elements>::value())::value;
|
||||
constexpr int pack2 = decltype(select_packing<ElementScale, DstType, num_elements>::value())::value;
|
||||
constexpr int pack = cute::gcd(pack1, pack2);
|
||||
using Converter1 = cutlass::NumericArrayConverter<ElementScale, SrcType, pack,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
using Converter2 = cutlass::NumericArrayConverter<DstType, ElementScale, pack,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
using SrcArray = cutlass::Array<SrcType, pack>;
|
||||
using DstArray = cutlass::Array<DstType, pack>;
|
||||
using StageArray = cutlass::Array<ElementScale, pack>;
|
||||
constexpr int iters = num_elements / pack;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < iters; ++i)
|
||||
{
|
||||
SrcArray const* pSrcArr = reinterpret_cast<SrcArray const*>(pSrc) + i;
|
||||
DstArray* pDstArr = reinterpret_cast<DstArray*>(pDst) + i;
|
||||
StageArray stageArr;
|
||||
stageArr = Converter1::convert(*pSrcArr);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < pack; ++j)
|
||||
{
|
||||
stageArr[j] = stageArr[j] * scales[i * pack + j] + zeros[i * pack + j];
|
||||
}
|
||||
*pDstArr = Converter2::convert(stageArr);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
|
||||
"Conversion mode not handled for input partitioning.");
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace cutlass::gemm::collective::detail
|
||||
|
||||
@ -0,0 +1,294 @@
|
||||
/*
|
||||
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/collective/builders/sm100_common.inl"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
||||
template <int CapacityBytes, class ElementA, class ElementAMma, class ElementScale, class ElementZero, class ElementB,
|
||||
class CtaTileShape_MNK, class TiledMma, class KernelScheduleType, UMMA::Major UmmaMajorA, int ScaleGranularityK,
|
||||
int stages>
|
||||
constexpr cute::tuple<int, int, int> sm100_compute_stage_count_or_override_weightonly(StageCount<stages> stage_count)
|
||||
{
|
||||
constexpr int Load2TransformStageCount = stages;
|
||||
constexpr int Transform2MmaStageCount = stages;
|
||||
constexpr int AccumulatorStageCount = stages;
|
||||
return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount);
|
||||
}
|
||||
|
||||
template <int CapacityBytes, class ElementA, class ElementAMma, class ElementScale, class ElementZero, class ElementB,
|
||||
class CtaTileShape_MNK, class TiledMma, class KernelScheduleType, UMMA::Major UmmaMajorA, int ScaleGranularityK,
|
||||
int carveout_bytes>
|
||||
constexpr cute::tuple<int, int, int> sm100_compute_stage_count_or_override_weightonly(
|
||||
StageCountAutoCarveout<carveout_bytes> stage_count)
|
||||
{
|
||||
|
||||
constexpr int CtaM = get<0>(CtaTileShape_MNK{});
|
||||
constexpr int CtaN = get<1>(CtaTileShape_MNK{});
|
||||
static_assert(CtaN <= 128, "Can't support CtaN>128 tiles");
|
||||
constexpr int CtaK = get<2>(CtaTileShape_MNK{});
|
||||
using AtomThrID = typename TiledMma::AtomThrID;
|
||||
|
||||
constexpr int TmemColumns = 512;
|
||||
|
||||
constexpr bool IsAComputeinTmem = UmmaMajorA == cute::UMMA::Major::K
|
||||
&& !cute::is_base_of_v<KernelTmaWarpSpecializedMixedInputSmemSm100, KernelScheduleType>;
|
||||
constexpr bool IsAComputeinSmem = !IsAComputeinTmem;
|
||||
|
||||
// Detect 2x2 TMEM layout
|
||||
constexpr int TmemAccWordsPerDP = (CtaM == 64 && size(AtomThrID{}) == 2) ? CtaN / 2 : CtaN;
|
||||
constexpr int TmemAWordsPerDP = CtaK / 2;
|
||||
|
||||
constexpr int AccumulatorStageCount
|
||||
= (IsAComputeinTmem) ? ((TmemAccWordsPerDP == 128) ? 2 : 3) : (TmemColumns / TmemAccWordsPerDP);
|
||||
|
||||
constexpr int SmemCapacityAfterMma2AccumCarveout = CapacityBytes - (carveout_bytes + AccumulatorStageCount * 32);
|
||||
|
||||
constexpr int TmemInAStageCount_Potential
|
||||
= (IsAComputeinTmem) ? (TmemColumns - AccumulatorStageCount * TmemAccWordsPerDP) / TmemAWordsPerDP : 10000;
|
||||
|
||||
// Mainload2Transform Pipeline
|
||||
constexpr auto load2transform_pipeline_bytes
|
||||
= sizeof(typename cutlass::PipelineTmaTransformAsync<1>::SharedStorage);
|
||||
constexpr auto a_bits = cute::sizeof_bits_v<ElementA>; // ElementA introduce here
|
||||
constexpr auto s_bits = cute::is_void_v<ElementScale> ? 0 : cute::sizeof_bits_v<ElementScale>;
|
||||
constexpr auto z_bits = cute::is_void_v<ElementZero> ? 0 : cute::sizeof_bits_v<ElementZero>;
|
||||
|
||||
constexpr auto load2mma_pipeline_bytes = sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage);
|
||||
constexpr auto b_bits = cute::sizeof_bits_v<ElementB>; // ElementB introduce here
|
||||
|
||||
constexpr int ab_stage_bytes
|
||||
= cutlass::bits_to_bytes(a_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}))
|
||||
+ cutlass::bits_to_bytes(s_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK)
|
||||
+ cutlass::bits_to_bytes(z_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK)
|
||||
+ cutlass::bits_to_bytes(b_bits * size<1>(CtaTileShape_MNK{}) / size(AtomThrID{}) * size<2>(CtaTileShape_MNK{}))
|
||||
+ static_cast<int>(load2transform_pipeline_bytes) + static_cast<int>(load2mma_pipeline_bytes);
|
||||
|
||||
// Transform2Mma Pipeline
|
||||
constexpr auto transform2mma_pipeline_bytes = sizeof(typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage);
|
||||
constexpr auto a_compute_bits = cute::sizeof_bits_v<ElementAMma>;
|
||||
constexpr int ab_compute_stage_bytes = cutlass::bits_to_bytes(a_compute_bits * int(IsAComputeinSmem)
|
||||
* size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}))
|
||||
+ // If ACompute is in TMEM, Acompute buffer has 0 bytes.
|
||||
static_cast<int>(transform2mma_pipeline_bytes);
|
||||
|
||||
constexpr int ABComputeStageCount_Potential
|
||||
= SmemCapacityAfterMma2AccumCarveout / (ab_stage_bytes + ab_compute_stage_bytes);
|
||||
|
||||
// The number of SMEM buffers for A, B. ACompute (if in SMEM), BCompute should be at least Transform2MmaStageCount
|
||||
constexpr int Transform2MmaStageCount = std::min(TmemInAStageCount_Potential, ABComputeStageCount_Potential);
|
||||
|
||||
constexpr int SmemCapacityAfterABComputeCarveout
|
||||
= SmemCapacityAfterMma2AccumCarveout - (Transform2MmaStageCount * ab_compute_stage_bytes);
|
||||
|
||||
// Can we boost the number of buffers for A and B?
|
||||
constexpr int Load2TransformStageCount = SmemCapacityAfterABComputeCarveout / ab_stage_bytes;
|
||||
|
||||
static_assert(Load2TransformStageCount >= 2 && Transform2MmaStageCount >= 2 && AccumulatorStageCount >= 2,
|
||||
"Not enough SMEM or TMEM capacity for selected tile size");
|
||||
return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Mixed Input MMA kernels builder
|
||||
template <class ElementAOptionalTuple, class GmemLayoutATagTuple, int AlignmentA, class ElementBOptionalTuple,
|
||||
class GmemLayoutBTag, int AlignmentB, class ElementAccumulator,
|
||||
class TileShape_MNK, // The Cluster-level TileShape
|
||||
class ClusterShape_MNK, class StageCountType, class KernelScheduleType>
|
||||
struct CollectiveBuilderSm100WeightOnly<arch::Sm100, arch::OpClassTensorOp,
|
||||
ElementAOptionalTuple, // ElementA
|
||||
GmemLayoutATagTuple, // LayoutA
|
||||
AlignmentA,
|
||||
ElementBOptionalTuple, // ElementB
|
||||
GmemLayoutBTag, // LayoutB
|
||||
AlignmentB, ElementAccumulator,
|
||||
TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK)
|
||||
ClusterShape_MNK, // Static cluster shape or dynamic (int, int, int)
|
||||
StageCountType, KernelScheduleType,
|
||||
cute::enable_if_t<(cute::is_base_of_v<KernelScheduleSm100MixedInputGemm, KernelScheduleType>) &&(
|
||||
(sizeof(float) * AlignmentA) % detail::tma_alignment_bytes == 0)
|
||||
&& ((sizeof(float) * AlignmentB) % detail::tma_alignment_bytes == 0)>>
|
||||
{
|
||||
using GmemLayoutATag = detail::deduce_mixed_width_dtype_t<0, GmemLayoutATagTuple>;
|
||||
using GmemLayoutScaleTag = detail::deduce_mixed_width_dtype_t<1, GmemLayoutATagTuple>;
|
||||
|
||||
static constexpr cute::UMMA::Major UmmaMajorA
|
||||
= cutlass::gemm::collective::detail::tag_to_umma_major_A<GmemLayoutATag>();
|
||||
static constexpr cute::UMMA::Major UmmaMajorB
|
||||
= cutlass::gemm::collective::detail::tag_to_umma_major_B<GmemLayoutBTag>();
|
||||
|
||||
using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>;
|
||||
using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>;
|
||||
using ElementScale = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>;
|
||||
using ElementZero = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>;
|
||||
|
||||
static constexpr bool NeitherIsTuple
|
||||
= !cute::is_tuple<ElementAOptionalTuple>::value && !cute::is_tuple<ElementBOptionalTuple>::value;
|
||||
static constexpr bool IsANarrow = cute::sizeof_bits_v<ElementA> < cute::sizeof_bits_v<ElementB>;
|
||||
static constexpr bool IsMixedInput = cute::sizeof_bits_v<ElementA> != cute::sizeof_bits_v<ElementB>;
|
||||
static_assert(IsMixedInput, "Mixed Input GEMM Kernel doesn't support regular gemm.");
|
||||
|
||||
static_assert(
|
||||
(cute::is_tuple<ElementAOptionalTuple>::value ^ cute::is_tuple<ElementBOptionalTuple>::value
|
||||
|| (NeitherIsTuple && (cute::sizeof_bits<ElementA>::value != cute::sizeof_bits<ElementB>::value))),
|
||||
"Either A OR B must be a tuple or the widths of A and B must be different.");
|
||||
using ElementPairA = cute::conditional_t<IsMixedInput && IsANarrow && NeitherIsTuple, cute::tuple<ElementA>,
|
||||
ElementAOptionalTuple>;
|
||||
using ElementPairB = cute::conditional_t<IsMixedInput && !IsANarrow && NeitherIsTuple, cute::tuple<ElementB>,
|
||||
ElementBOptionalTuple>;
|
||||
static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
|
||||
static_assert(IsATransformed, "A matrix should be transformed.");
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type.
|
||||
using ElementMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
|
||||
|
||||
using ElementAMma = ElementMma;
|
||||
using ElementBMma = ElementMma;
|
||||
|
||||
static constexpr int IsSubbyteA = cute::sizeof_bits_v<ElementA> < 8;
|
||||
using TmaElementA = cute::conditional_t<IsSubbyteA, uint8_t, ElementA>;
|
||||
|
||||
static constexpr int ScalingFactor = 1;
|
||||
|
||||
using TiledMma = decltype(detail::sm100_make_trivial_mixed_input_tiled_mma<ElementAMma, ElementB,
|
||||
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB, KernelScheduleType>());
|
||||
using AtomThrID = typename TiledMma::AtomThrID;
|
||||
using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMma::ThrLayoutVMNK{})), _1, _1>;
|
||||
using CtaTileShape_MNK = decltype(shape_div(TileShape_MNK{}, AtomThrShapeMNK{}));
|
||||
|
||||
// ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K)
|
||||
using MmaShapeA_MK = decltype(partition_shape_A(
|
||||
TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), cute::size<2>(TileShape_MNK{}))));
|
||||
// ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K)
|
||||
using MmaShapeB_NK = decltype(partition_shape_B(
|
||||
TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), cute::size<2>(TileShape_MNK{}))));
|
||||
|
||||
using BlockTileA_M = decltype(cute::size<0, 0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{}));
|
||||
using BlockTileA_K = decltype(cute::size<0, 1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(cute::size<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm100_cluster_shape_to_tma_atom_B(ClusterShape_MNK{}, AtomThrID{}));
|
||||
|
||||
// Input transform kernel can not use TMA 2SM instructions.
|
||||
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorA, ElementA,
|
||||
BlockTileA_M, BlockTileA_K>());
|
||||
using SmemLayoutAtomACompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorA,
|
||||
ElementAMma, BlockTileA_M, BlockTileA_K>());
|
||||
using SmemLayoutAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType<SmemLayoutAtomA,
|
||||
SmemLayoutAtomACompute>;
|
||||
static constexpr int MMA_M = cute::size<0, 0>(MmaShapeA_MK{});
|
||||
using CopyAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType<
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementA>,
|
||||
cute::conditional_t<
|
||||
(UmmaMajorA == cute::UMMA::Major::K
|
||||
&& !cute::is_base_of_v<KernelTmaWarpSpecializedMixedInputSmemSm100, KernelScheduleType>),
|
||||
cute::conditional_t<(MMA_M == 64 && size(AtomThrID{}) == 1), SM100_TMEM_STORE_16dp256b1x,
|
||||
SM100_TMEM_STORE_32dp32b8x>, // TS Implementation
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementA>> // SS Implementation
|
||||
>;
|
||||
|
||||
using BlockTileB_N = decltype(cute::size<0, 0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{}));
|
||||
using BlockTileB_K = decltype(cute::size<0, 1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{}));
|
||||
|
||||
// Input transform kernel can not use TMA 2SM instructions.
|
||||
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorB, ElementB,
|
||||
BlockTileB_N, BlockTileB_K>());
|
||||
using SmemLayoutAtomBCompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorB,
|
||||
ElementBMma, BlockTileB_N, BlockTileB_K>());
|
||||
using SmemLayoutAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType<SmemLayoutAtomB,
|
||||
SmemLayoutAtomBCompute>;
|
||||
using CopyAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType<
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementB>,
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementMma>>;
|
||||
|
||||
// Creating the stride of Transformed Input
|
||||
using StrideA = cutlass::gemm::TagToStrideA_t<GmemLayoutATag>;
|
||||
using LayoutScale = cutlass::gemm::TagToStrideA_t<GmemLayoutScaleTag>;
|
||||
|
||||
using VoidShapeScale
|
||||
= Shape<Shape<Int<128>, _1>, Shape<Int<64>, _1>, _1>; // Dummy Value to create a dummy ScaleConfig
|
||||
using VoidStrideScale = Stride<Stride<_0, _1>, Stride<_0, _1>, _1>;
|
||||
using VoidLayoutScale = Layout<VoidShapeScale, VoidStrideScale>;
|
||||
|
||||
using NonVoidLayoutScale = cute::conditional_t<cute::is_void_v<LayoutScale>, VoidLayoutScale, LayoutScale>;
|
||||
|
||||
using StridePairA = decltype(cute::make_tuple(StrideA{}, NonVoidLayoutScale{}));
|
||||
|
||||
// SmemCarveout
|
||||
static constexpr int SchedulerPipelineStageCount = 3;
|
||||
static constexpr bool IsArrayOfPointersGemm
|
||||
= (cute::is_base_of_v<KernelScheduleSm100PtrArrayFastFP32Gemm, KernelScheduleType>);
|
||||
|
||||
// CLCPipeline = PipelineCLCFetchAsync
|
||||
static constexpr auto CLCPipelineStorage
|
||||
= sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape_MNK>::SharedStorage);
|
||||
// CLC (scheduler) response
|
||||
static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize;
|
||||
// CLC Throttle pipeline storage
|
||||
static constexpr auto CLCThrottlePipelineStorage
|
||||
= sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage);
|
||||
// Tmem dealloc
|
||||
static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier);
|
||||
// Tmem ptr storage
|
||||
static constexpr auto TmemBasePtrsStorage = sizeof(uint32_t);
|
||||
// Tensormap Storage
|
||||
static constexpr size_t TensorMapStorage
|
||||
= IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
|
||||
|
||||
// Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage
|
||||
static constexpr auto KernelSmemCarveout = static_cast<int>(CLCPipelineStorage + CLCResponseStorage
|
||||
+ CLCThrottlePipelineStorage + TmemDeallocStorage + TmemBasePtrsStorage + TensorMapStorage);
|
||||
|
||||
// Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations
|
||||
static constexpr int Sm100ReducedSmemCapacityBytes = detail::sm100_smem_capacity_bytes - KernelSmemCarveout;
|
||||
|
||||
static constexpr int ScaleGranularityK = get_ScaleGranularityK<LayoutScale>();
|
||||
|
||||
static constexpr auto stage_info
|
||||
= cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_weightonly<
|
||||
Sm100ReducedSmemCapacityBytes, TmaElementA, ElementAMma, ElementScale, ElementZero, ElementB,
|
||||
CtaTileShape_MNK, TiledMma, KernelScheduleType, UmmaMajorA, ScaleGranularityK>(StageCountType{});
|
||||
|
||||
static constexpr int Load2TransformPipelineStageCount = get<0>(stage_info);
|
||||
static constexpr int Transform2MmaPipelineStageCount = get<1>(stage_info);
|
||||
static constexpr int AccumulatorPipelineStageCount = get<2>(stage_info);
|
||||
|
||||
static_assert(!IsArrayOfPointersGemm, "mixed input does not support grouped gemm on Blackwell");
|
||||
|
||||
using DispatchPolicy
|
||||
= cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedMixedInput<Load2TransformPipelineStageCount,
|
||||
Transform2MmaPipelineStageCount, SchedulerPipelineStageCount, AccumulatorPipelineStageCount,
|
||||
ClusterShape_MNK>;
|
||||
using CollectiveOp = cutlass::gemm::collective::CollectiveMmaSm100WeightOnly<DispatchPolicy, TileShape_MNK,
|
||||
ElementPairA, StridePairA, ElementPairB, cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>, TiledMma,
|
||||
GmemTiledCopyA, SmemLayoutAtomPairA, CopyAtomPairA, cute::identity, GmemTiledCopyB, SmemLayoutAtomPairB,
|
||||
CopyAtomPairB, cute::identity>;
|
||||
};
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
@ -0,0 +1,42 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_mma_sm100_weightonly.hpp"
|
||||
|
||||
namespace cutlass::gemm::collective
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
|
||||
class GmemLayoutB, int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
|
||||
class StageCountType, class KernelScheduleType, class Enable = void>
|
||||
struct CollectiveBuilderSm100WeightOnly
|
||||
{
|
||||
static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/builders/sm100_umma_builder_weightonly.inl"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,42 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/detail/dependent_false.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class DispatchPolicy, class TileShape, class ElementA, class StrideA, class ElementB, class StrideB,
|
||||
class TiledMma, class GmemTiledCopyA, class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
|
||||
class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB, class TransformB>
|
||||
struct CollectiveMmaSm100WeightOnly
|
||||
{
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
File diff suppressed because it is too large
Load Diff
@ -533,8 +533,8 @@ struct GemmFpAIntB
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ == 890)
|
||||
run_kernel<arch::Sm89>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 1000)
|
||||
// Use SM80 implementation for GB10x, GB20x.
|
||||
#elif (__CUDA_ARCH__ >= 1200)
|
||||
// Use SM80 implementation for GB20x.
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#else
|
||||
CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
|
||||
|
||||
@ -87,7 +87,9 @@ public:
|
||||
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
|
||||
// which signals that we want to dequantize after loading from smem.
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
|
||||
struct LayoutDetailsB<TypeA, uint8_t, Arch,
|
||||
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability != 100
|
||||
&& Arch::kMinComputeCapability != 103>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
|
||||
@ -102,7 +104,9 @@ public:
|
||||
};
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
|
||||
struct LayoutDetailsB<TypeA, uint4b_t, Arch,
|
||||
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability != 100
|
||||
&& Arch::kMinComputeCapability != 103>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
|
||||
@ -116,6 +120,26 @@ public:
|
||||
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
|
||||
};
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, uint8_t, Arch,
|
||||
typename platform::enable_if<Arch::kMinComputeCapability == 100 || Arch::kMinComputeCapability == 103>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, uint4b_t, Arch,
|
||||
typename platform::enable_if<Arch::kMinComputeCapability == 100 || Arch::kMinComputeCapability == 103>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
@ -409,14 +409,14 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100_dynamic_cluster_shape
|
||||
}
|
||||
|
||||
std::vector<std::pair<CutlassTileConfigSM100, ClusterShape>> tile_configs{
|
||||
{CutlassTileConfigSM100::CtaShape128x128x128B, cluster1sm},
|
||||
{CutlassTileConfigSM100::CtaShape128x256x128B, cluster1sm},
|
||||
{CutlassTileConfigSM100::CtaShape128x32x128B, cluster1sm},
|
||||
{CutlassTileConfigSM100::CtaShape64x64x128B, cluster1sm},
|
||||
{CutlassTileConfigSM100::CtaShape64x32x128B, cluster1sm},
|
||||
{CutlassTileConfigSM100::CtaShape64x64x128B, cluster1sm},
|
||||
{CutlassTileConfigSM100::CtaShape64x128x128B, cluster1sm},
|
||||
{CutlassTileConfigSM100::CtaShape64x256x128B, cluster1sm},
|
||||
{CutlassTileConfigSM100::CtaShape128x32x128B, cluster1sm},
|
||||
{CutlassTileConfigSM100::CtaShape128x64x128B, cluster1sm},
|
||||
{CutlassTileConfigSM100::CtaShape128x128x128B, cluster1sm},
|
||||
{CutlassTileConfigSM100::CtaShape128x256x128B, cluster1sm},
|
||||
};
|
||||
|
||||
if (supports_2sm)
|
||||
@ -479,6 +479,30 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100(
|
||||
}
|
||||
return candidate_configs;
|
||||
}
|
||||
else if (config & CutlassGemmConfig::WEIGHT_ONLY)
|
||||
{
|
||||
std::vector<CutlassTileConfigSM100> tile_configs{
|
||||
CutlassTileConfigSM100::CtaShape64x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x128x128B,
|
||||
};
|
||||
std::vector<CutlassGemmConfig> candidate_configs;
|
||||
for (auto const& tile_config : tile_configs)
|
||||
{
|
||||
CutlassGemmConfig cutlassKernelConfig(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
|
||||
ClusterShape::ClusterShape_1x1x1, ClusterShape::Undefined, ClusterShape::Undefined, sm);
|
||||
candidate_configs.push_back(cutlassKernelConfig);
|
||||
}
|
||||
|
||||
// add cuda kernel profiler to tactics for weight-only plugins
|
||||
CutlassGemmConfig cudaKernelConfig(CutlassTileConfigSM100::CtaShape64x128x128B, MainloopScheduleType::AUTO,
|
||||
EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1, ClusterShape::Undefined,
|
||||
ClusterShape::Undefined, sm);
|
||||
|
||||
cudaKernelConfig.enableCudaKernel = true;
|
||||
candidate_configs.push_back(cudaKernelConfig);
|
||||
|
||||
return candidate_configs;
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Not Implemented: SM100 GEMM candidates have not been defined.");
|
||||
|
||||
@ -134,8 +134,17 @@ LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
|
||||
{
|
||||
return getLayoutDetailsForArch<cutlass::arch::Sm90>(quant_type);
|
||||
}
|
||||
else if (arch >= 100)
|
||||
else if (arch == 100)
|
||||
{
|
||||
return getLayoutDetailsForArch<cutlass::arch::Sm100>(quant_type);
|
||||
}
|
||||
else if (arch == 103)
|
||||
{
|
||||
return getLayoutDetailsForArch<cutlass::arch::Sm103>(quant_type);
|
||||
}
|
||||
else if (arch >= 120)
|
||||
{
|
||||
// Use SM80 implementation for GB20x.
|
||||
return getLayoutDetailsForArch<cutlass::arch::Sm80>(quant_type);
|
||||
}
|
||||
else
|
||||
@ -591,23 +600,31 @@ void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, in
|
||||
// Works on row major data, so issue this permutation first.
|
||||
if (details.uses_imma_ldsm)
|
||||
{
|
||||
TLLM_LOG_INFO("permute_B_rows_for_mixed_gemm");
|
||||
permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch);
|
||||
src_buf.swap(dst_buf);
|
||||
}
|
||||
|
||||
if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR)
|
||||
{
|
||||
TLLM_LOG_INFO("subbyte_transpose");
|
||||
subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type);
|
||||
src_buf.swap(dst_buf);
|
||||
}
|
||||
|
||||
if (details.columns_interleaved > 1 && arch != 90)
|
||||
if (details.columns_interleaved > 1 && (arch != 90))
|
||||
{
|
||||
TLLM_LOG_INFO("interleave_column_major_tensor");
|
||||
interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details);
|
||||
src_buf.swap(dst_buf);
|
||||
}
|
||||
|
||||
add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type);
|
||||
if (arch != 100 && arch != 103)
|
||||
{
|
||||
TLLM_LOG_INFO("add_bias_and_interleave_quantized_tensor_inplace");
|
||||
add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type);
|
||||
}
|
||||
TLLM_LOG_INFO("copy");
|
||||
std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight);
|
||||
}
|
||||
|
||||
|
||||
@ -40,6 +40,7 @@
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm100.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h"
|
||||
|
||||
namespace tk = tensorrt_llm::common;
|
||||
@ -429,7 +430,7 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
|
||||
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
|
||||
}
|
||||
else if ((sm_ >= 80 && sm_ < 89) || sm_ >= 100)
|
||||
else if ((sm_ >= 80 && sm_ < 89) || sm_ >= 120)
|
||||
{
|
||||
dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, cutlass::arch::Sm80,
|
||||
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
@ -454,9 +455,27 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
|
||||
static_assert(!cutlass::platform::is_same<ActivationType, __nv_fp8_e4m3>::value
|
||||
|| cutlass::platform::is_same<ScaleZeroType, half>::value,
|
||||
"ScaleZeroType must be half for activation=fp8");
|
||||
#ifdef COMPILE_HOPPER_TMA_GEMMS
|
||||
cutlass_kernels_oss::sm90_dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType,
|
||||
OutputType, QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k,
|
||||
group_size, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
|
||||
#else // COMPILE_HOPPER_TMA_GEMMS
|
||||
throw std::runtime_error(
|
||||
"[TensorRT LLM Error][fpA_intB Runner] Please recompile with support for hopper by passing 90-real as an "
|
||||
"arch to build_wheel.py.");
|
||||
#endif // COMPILE_HOPPER_TMA_GEMMS
|
||||
}
|
||||
else if (sm_ == 100 || sm_ == 103)
|
||||
{
|
||||
#ifdef COMPILE_BLACKWELL_TMA_GEMMS
|
||||
cutlass_kernels_oss::sm100_dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType,
|
||||
OutputType, QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k,
|
||||
group_size, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
|
||||
#else // COMPILE_BLACKWELL_TMA_GEMMS
|
||||
throw std::runtime_error(
|
||||
"[TensorRT LLM Error][fpA_intB Runner] Please recompile with support for blackwell by passing 100-real as "
|
||||
"an arch to build_wheel.py.");
|
||||
#endif // COMPILE_BLACKWELL_TMA_GEMMS
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -537,8 +556,9 @@ CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, Bia
|
||||
{
|
||||
|
||||
static constexpr bool is_weight_only = !std::is_same<ActivationType, WeightType>::value;
|
||||
tkc::CutlassGemmConfig::CandidateConfigTypeParam config_type_param
|
||||
= tkc::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER;
|
||||
tkc::CutlassGemmConfig::CandidateConfigTypeParam config_type_param = (sm_ >= 100)
|
||||
? tkc::CutlassGemmConfig::CandidateConfigTypeParam::BLACKWELL
|
||||
: tkc::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER;
|
||||
if (is_weight_only)
|
||||
{
|
||||
config_type_param = static_cast<tkc::CutlassGemmConfig::CandidateConfigTypeParam>(
|
||||
|
||||
@ -0,0 +1,153 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/numeric/integral_constant.hpp"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels_oss
|
||||
{
|
||||
namespace tk = tensorrt_llm::common;
|
||||
namespace tkc = tensorrt_llm::cutlass_extensions;
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
|
||||
typename MainloopSchedule>
|
||||
void sm100_dispatch_epilogue_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
|
||||
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
|
||||
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
switch (gemm_config.epilogue_schedule)
|
||||
{
|
||||
case tkc::EpilogueScheduleType::AUTO:
|
||||
// TODO: use heuristics to select the epilogue schedule, depending on the CTA shape and cluster shape
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
sm100_generic_mixed_gemm_kernelLauncher<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType,
|
||||
QuantOp, EpilogueTag, CTAShape, ClusterShape, MainloopSchedule, EpilogueSchedule>(A, B, weight_scales,
|
||||
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream,
|
||||
occupancy);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[TensorRT LLM Error][fpA_intB][sm100_dispatch_epilogue_schedules] Unsupported epilogue schedule for mixed "
|
||||
"type GEMM.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape>
|
||||
void sm100_dispatch_mainloop_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
|
||||
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
|
||||
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
|
||||
switch (gemm_config.mainloop_schedule)
|
||||
{
|
||||
case tkc::MainloopScheduleType::AUTO:
|
||||
// TODO: use heuristics to select the mainloop schedule, depending on the CTA shape and cluster shape
|
||||
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100;
|
||||
sm100_dispatch_epilogue_schedules<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
|
||||
EpilogueTag, CTAShape, ClusterShape, MainloopSchedule>(A, B, weight_scales, weight_zero_points, biases,
|
||||
alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[TensorRT LLM Error][fpA_intB][sm100_dispatch_mainloop_schedules] Unsupported mainloop schedule for mixed "
|
||||
"type GEMM.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape>
|
||||
void sm100_dispatch_gemm_config(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
|
||||
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
|
||||
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
switch (gemm_config.cluster_shape)
|
||||
{
|
||||
// TODO: add support for other cluster shapes
|
||||
case tkc::ClusterShape::ClusterShape_1x1x1:
|
||||
sm100_dispatch_mainloop_schedules<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
|
||||
EpilogueTag, CTAShape, Shape<_1, _1, _1>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n,
|
||||
k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[TensorRT LLM Error][fpA_intB][sm100_dispatch_gemm_config] Unsupported CTA and Cluster shapes for mixed "
|
||||
"type GEMM.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag>
|
||||
void sm100_dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
|
||||
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
|
||||
int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
// Note that SIMT configs are omitted here since they are not supported for fpA_intB.
|
||||
// We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best
|
||||
// for mixed type gemms.
|
||||
|
||||
constexpr int Ktile = 128 / sizeof(ActivationType);
|
||||
using _Ktile = Int<Ktile>;
|
||||
switch (gemm_config.tile_config_sm100)
|
||||
{
|
||||
case tkc::CutlassTileConfigSM100::CtaShape64x128x128B:
|
||||
sm100_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
|
||||
EpilogueTag, Shape<_64, _128, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k,
|
||||
group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM100::CtaShape128x128x128B:
|
||||
sm100_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
|
||||
EpilogueTag, Shape<_128, _128, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k,
|
||||
group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[TensorRT LLM Error][fpA_intB][sm100_dispatch_gemm_to_cutlass] Unsupported tile shape for mixed type "
|
||||
"GEMM.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cutlass_kernels_oss
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,39 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cutlass_extensions/gemm_configs.h"
|
||||
#include "cutlass_extensions/weight_only_quant_op.h"
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels_oss
|
||||
{
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
|
||||
typename MainloopScheduleType, typename EpilogueScheduleType>
|
||||
void sm100_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B,
|
||||
ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases,
|
||||
float const alpha, OutputType* C, int m, int n, int k, int const group_size,
|
||||
tensorrt_llm::cutlass_extensions::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy = nullptr);
|
||||
|
||||
} // namespace cutlass_kernels_oss
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,286 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
#endif // __GNUC__
|
||||
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
#include "cutlass_extensions/compute_occupancy.h"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
#include "cutlass_extensions/gemm_configs.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder_sm100_weightonly.hpp"
|
||||
|
||||
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
|
||||
#pragma GCC diagnostic pop
|
||||
#endif // __GNUC__
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels_oss
|
||||
{
|
||||
using namespace tensorrt_llm::kernels::cutlass_kernels;
|
||||
namespace tk = tensorrt_llm::common;
|
||||
namespace tkc = tensorrt_llm::cutlass_extensions;
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
|
||||
typename MainloopScheduleType, typename EpilogueScheduleType>
|
||||
void sm100_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B,
|
||||
ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases,
|
||||
float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config,
|
||||
char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
|
||||
using CutlassActivationType = typename TllmToCutlassTypeAdapter<ActivationType>::type;
|
||||
|
||||
#ifdef COMPILE_BLACKWELL_TMA_GEMMS
|
||||
if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v<cutlass::arch::Sm100, CTAShape, ClusterShape,
|
||||
false, ActivationType>)
|
||||
{
|
||||
using CutlassWeightType__ = typename TllmToCutlassTypeAdapter<WeightType>::type;
|
||||
// We need to remap this since SM100 uses a different layout for the weight matrix.
|
||||
using CutlassWeightType_ = std::conditional_t<std::is_same_v<CutlassWeightType__, cutlass::uint4b_t>,
|
||||
cutlass::int4b_t, CutlassWeightType__>;
|
||||
|
||||
using CutlassWeightType
|
||||
= std::conditional_t<std::is_same_v<CutlassWeightType_, uint8_t>, int8_t, CutlassWeightType_>;
|
||||
|
||||

        using CutlassScaleZeroType = typename TllmToCutlassTypeAdapter<ScaleZeroType>::type;
        using CutlassBiasType = typename TllmToCutlassTypeAdapter<BiasType>::type;
        using CutlassOutputType = typename TllmToCutlassTypeAdapter<OutputType>::type;

        static_assert(std::is_same_v<CutlassActivationType, cutlass::half_t>
                || std::is_same_v<CutlassActivationType, cutlass::bfloat16_t>
                || std::is_same_v<CutlassActivationType, cutlass::float_e4m3_t>
                || std::is_same_v<CutlassActivationType, cutlass::float_e5m2_t>,
            "Activation type must be bfloat16, half, or FP8");

        static_assert(std::is_same_v<CutlassWeightType, int8_t> || std::is_same_v<CutlassWeightType, cutlass::int4b_t>
                || std::is_same_v<CutlassWeightType, cutlass::float_e4m3_t>
                || std::is_same_v<CutlassWeightType, cutlass::float_e5m2_t>,
            "Weight type must be fp8, int8_t, or int4b_t");

        static_assert(!std::is_same_v<CutlassActivationType, cutlass::float_e4m3_t>
                || std::is_same_v<CutlassScaleZeroType, cutlass::half_t>,
            "Scale/Zero type must be half for fp8 activation");

        using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
        constexpr int AlignmentA = 128 / cutlass::sizeof_bits<CutlassActivationType>::value;

        using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
        constexpr int AlignmentB = 128 / cutlass::sizeof_bits<CutlassWeightType>::value;

        // This kernel manually swaps and transposes the operands, so keep the transpose of the input layouts.
        using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
        using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;

        using ElementZero = CutlassScaleZeroType;
        using ElementScale = CutlassScaleZeroType;

        // C/D matrix configuration. We reuse the C operand for the bias and set the stride for broadcast.
        using LayoutBias = cutlass::layout::RowMajor;
        constexpr int AlignmentBias = 128 / cutlass::sizeof_bits<CutlassBiasType>::value;

        // D matrix configuration
        using LayoutOutput = cutlass::layout::RowMajor;
        constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<CutlassOutputType>::value;

        // Core kernel configurations
        using ElementAccumulator = float;     // Element type for internal accumulation
        using ElementCompute = float;         // Element type for epilogue computation
        using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
        using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
        // using TileShape = CTAShape; // Threadblock-level tile size
        constexpr static bool Is2SM = cute::size<0>(ClusterShape{}) == 2;
        constexpr static int TileM = cute::size<0>(CTAShape{}) * (Is2SM ? 2 : 1);
        constexpr static int TileN = cute::size<1>(CTAShape{});
        constexpr static int TileK = cute::size<2>(CTAShape{});
        using TileShape = cute::Shape<cute::Int<TileM>, cute::Int<TileN>, cute::Int<TileK>>;
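        // (For reference: with CTAShape = (128, 128, 64) and a 2x1x1 cluster, Is2SM is true and the MMA tile
        // becomes 256x128x64, spanning both CTAs of the cluster; with a 1x1x1 cluster it equals the CTA tile.
        // Illustrative numbers only; the generator below currently emits 1x1x1 clusters for SM100.)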
        using MainloopSchedule = std::conditional_t<Is2SM, cutlass::gemm::KernelTmaWarpSpecialized2SmMixedInputSm100,
            cutlass::gemm::KernelTmaWarpSpecialized1SmMixedInputSm100>;
        using EpilogueSchedule = std::conditional_t<Is2SM, cutlass::epilogue::TmaWarpSpecialized2Sm,
            cutlass::epilogue::TmaWarpSpecialized1Sm>;

        static_assert(std::is_same_v<EpilogueTag, tensorrt_llm::cutlass_extensions::EpilogueOpBias>, "");

        using CollectiveEpilogue =
            typename cutlass::epilogue::collective::CollectiveBuilder<ArchTag, OperatorClass, TileShape, ClusterShape,
                cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute,
                // Transpose the layouts of C and D here since we use the explicit swap + transpose trick.
                // C is not voided: it carries the broadcast bias (see the C/D configuration above).
                CutlassBiasType, typename cutlass::layout::LayoutTranspose<LayoutBias>::type, AlignmentBias,
                CutlassOutputType, typename cutlass::layout::LayoutTranspose<LayoutOutput>::type, AlignmentOutput,
                EpilogueSchedule>::CollectiveOp;

        using PackedScaleZero = cute::tuple<CutlassWeightType, ElementScale, ElementZero>;
        using PackedScale = cute::tuple<CutlassWeightType, ElementScale>;
        using ElementBCollectiveInfo = std::conditional_t<cutlass::hasZero(QuantOp), PackedScaleZero, PackedScale>;

        constexpr int ScaleGranularityN = 1;                    // Should be less than or equal to GEMM_N
        constexpr int ScaleGranularityK = size<2>(TileShape{}); // Should be less than or equal to GEMM_K
        using ScaleConfig = cutlass::detail::Sm100MixedInputBlockwiseScaleConfig<ScaleGranularityN, ScaleGranularityK>;
        using LayoutScale = decltype(ScaleConfig::deduce_layout_scale()); // Layout type for SFA matrix operand
        LayoutScale layout_S = ScaleConfig::tile_atom_to_shape_scale(make_shape(n, k, 1));
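        // (As configured above, ScaleGranularityN = 1 and ScaleGranularityK = TileK, so layout_S describes one
        // scale (and zero, when present) per weight column for each TileK-wide slice of K; the runtime group_size
        // is passed separately in the arguments and is checked below to be a multiple of TileK.)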

        // We swap the A and B operands passed to the builder here.
        using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilderSm100WeightOnly<ArchTag,
            OperatorClass, ElementBCollectiveInfo, cute::tuple<LayoutB_Transpose, LayoutScale>, AlignmentB,
            CutlassActivationType, LayoutA_Transpose, AlignmentA, ElementAccumulator, TileShape, ClusterShape,
            cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
                sizeof(typename CollectiveEpilogue::SharedStorage))>,
            MainloopSchedule>::CollectiveOp;

        using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, // Indicates ProblemShape
            CollectiveMainloop, CollectiveEpilogue>;

        using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

        using StrideA = typename GemmKernel::StrideA;
        using StrideB = typename GemmKernel::StrideB;
        using StrideC = typename GemmKernel::StrideC;
        using StrideD = typename GemmKernel::StrideD;

        if (weight_scales == nullptr)
        {
            throw std::runtime_error("Weight scales must always be set to a non-null value.");
        }

        if constexpr (cutlass::isFinegrained(QuantOp))
        {
            int cta_shape_k = cute::size<2>(TileShape{});
            if (group_size % cta_shape_k != 0)
            {
                std::string err_msg = "The group size must be a multiple of " + std::to_string(cta_shape_k);
                throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg);
            }

            if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)
            {
                if (weight_zero_points != nullptr)
                {
                    throw std::runtime_error(
                        "Weight zero-points must be null for scale-only fine-grained quantization");
                }
            }
            else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
            {
                if (weight_zero_points == nullptr)
                {
                    throw std::runtime_error(
                        "Weight zero-points must be valid for scale-and-zeros fine-grained quantization");
                }
            }
        }
        else
        {
            if (group_size != k)
            {
                throw std::runtime_error("Invalid group size for per-column scaling kernels.");
            }

            if (weight_zero_points != nullptr)
            {
                throw std::runtime_error("Weight zero-points must be null when running per-column scaling");
            }
        }

        StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
        StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
        StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1));
        StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(0, m, 1));
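        // (Because of the operand swap, the problem is posed as n x m x k: the quantized weights act as the
        // "A" operand of shape (n, k), the activations as the "B" operand of shape (m, k), and D is written as
        // an n x m matrix. The zero extent passed for stride_C broadcasts the per-channel bias across m.)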

        typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, {n, m, k, 1},
            {reinterpret_cast<CutlassWeightType const*>(B), stride_B, reinterpret_cast<CutlassActivationType const*>(A),
                stride_A, reinterpret_cast<ElementScale const*>(weight_scales), layout_S, group_size,
                reinterpret_cast<ElementZero const*>(weight_zero_points)},
            {{alpha}, reinterpret_cast<CutlassBiasType const*>(biases), stride_C,
                reinterpret_cast<CutlassOutputType*>(C), stride_D}};

        Gemm gemm;
        if (gemm.get_workspace_size(args) > workspace_bytes)
        {
            TLLM_LOG_ERROR("[TensorRT LLM Error][fpA_intB Runner] The given workspace size is insufficient.");
        }

        auto can_implement = gemm.can_implement(args);
        if (can_implement != cutlass::Status::kSuccess)
        {
            std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: "
                + std::string(cutlass::cutlassGetStatusString(can_implement));
            std::cout << err_msg << std::endl;
            throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg);
        }

        auto init_status = gemm.initialize(args, workspace, stream);
        if (init_status != cutlass::Status::kSuccess)
        {
            std::string err_msg = "Failed to initialize cutlass fpA_intB gemm. Error: "
                + std::string(cutlassGetStatusString(init_status));
            throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg);
        }

        auto run_status = gemm.run(stream);
        if (run_status != cutlass::Status::kSuccess)
        {
            std::string err_msg
                = "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status));
            throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg);
        }
    }
    else
    {
        std::stringstream ss;
        ss << "[TensorRT LLM Error][fpA_intB Runner] Config (" << (int64_t) cute::size<0>(CTAShape{}) << ","
           << (int64_t) cute::size<1>(CTAShape{}) << "," << (int64_t) cute::size<2>(CTAShape{}) << ") ("
           << (int64_t) cute::size<0>(ClusterShape{}) << "," << (int64_t) cute::size<1>(ClusterShape{}) << ","
           << (int64_t) cute::size<2>(ClusterShape{}) << ") not compiled with FAST_BUILD.";

        throw std::runtime_error(ss.str());
    }

#else  // COMPILE_BLACKWELL_TMA_GEMMS
    throw std::runtime_error(
        "[TensorRT LLM Error][fpA_intB Runner] Please recompile with support for Blackwell by passing 100-real as an "
        "arch to build_wheel.py.");
#endif // COMPILE_BLACKWELL_TMA_GEMMS
}

} // namespace cutlass_kernels_oss
} // namespace kernels
} // namespace tensorrt_llm
@ -213,8 +213,10 @@ def instantiate_operation_tma_warp_specialized(operation):

    if operation.gemm_kind == GemmKind.Gemm:
        weight_tag = DataTypeTag[operation.weight_type]
        # Use sm100_generic_mixed_gemm_kernelLauncher for SM100 and above
        launcher_name = "sm100_generic_mixed_gemm_kernelLauncher" if operation.arch >= 100 else "sm90_generic_mixed_gemm_kernelLauncher"
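        # (The sm100 launcher takes the same template parameters and arguments as the sm90 one, so only the
        # function name in the emitted explicit instantiation changes.)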
        instantiation = f"""
template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag},
template void {launcher_name}<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag},
{quant_op}, {epi_tag},
{cute_cta_shape}, {cute_cga_shape},
{kernel_sched}, {epi_sched}> (
@ -808,8 +810,69 @@ def generate_sm103_operations(is_arch_enabled):
    return operations


def generate_sm100_mixed_gemm_operations(is_arch_enabled):
    if not is_arch_enabled:
        return []
    arch = 100

    # For SM100 (Blackwell), we support similar dtypes as SM90
    # Takes the form (activation_type, weight_type, scalezero_type, bias_type, output_type)
    supported_dtypes = [
        (DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
        (DataType.e4m3, DataType.u4, DataType.f16, DataType.bf16,
         DataType.bf16),
        (DataType.f16, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
        (DataType.bf16, DataType.u4, DataType.bf16, DataType.bf16,
         DataType.bf16),
        (DataType.f16, DataType.u8, DataType.f16, DataType.f16, DataType.f16),
        (DataType.bf16, DataType.u8, DataType.bf16, DataType.bf16,
         DataType.bf16)
    ]

    quant_ops = [
        TrtLlm_QuantOp.per_column_scale_only,
        TrtLlm_QuantOp.finegrained_scale_only,
        TrtLlm_QuantOp.finegrained_scale_and_zeros
    ]

    epi_tags = [TrtLlm_EpilogueTag.epilogue_op_bias]

    # SM100 uses different tile shapes
    M_TILES = [64, 128]
    N_TILES = [128]
    cta_shapes_mn = product(M_TILES, N_TILES)

    warp_shape = [4, 1, 1]
    stages = 0  # auto

    # SM100 currently only supports 1x1x1 cluster shape
    cga_shapes = [(1, 1, 1)]

    partial_args = product(supported_dtypes, quant_ops, epi_tags, cta_shapes_mn,
                           cga_shapes)

    operations = list()
    for dtype_combo, quant_op, epi_tag, cta_shape_mn, cga_shape in partial_args:
        max_k_bits = 128 * 8
        cta_shape_k = max_k_bits // GetDataTypeBits(dtype_combo[0])
        cta_shape_mnk = cta_shape_mn + (cta_shape_k, )
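        # e.g. 128 * 8 = 1024 bits of K per tile: fp16/bf16 activations (16-bit) give cta_shape_k = 64,
        # while e4m3 activations (8-bit) give cta_shape_k = 128, so the K tile always covers 128 bytes
        # of activation data.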

        # SM100 uses 1SM schedule for mixed type GEMM
        mainloop_schedule = KernelScheduleType.TmaWarpSpecialized1SmSm100
        epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1Sm

        fpA_intB_operation = TrtLlm_GemmLauncher(GemmKind.Gemm, arch, *dtype_combo, quant_op, epi_tag, cta_shape_mnk, \
                                                 warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule)

        if is_gemm_op_valid_sm100(fpA_intB_operation):
            operations.append(fpA_intB_operation)

    return operations


def generate_sm100_operations(is_arch_enabled):
    operations = generate_sm100_grouped_gemm_operations(is_arch_enabled, 100)
    operations.extend(generate_sm100_mixed_gemm_operations(is_arch_enabled))
    return operations


@ -878,6 +941,7 @@ if __name__ == "__main__":
    output_dir = os.path.abspath(args.output_dir)

    fpA_intB_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"
    fpA_intB_sm100_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm100.inl"
    moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl"
    # moe_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl"
    moe_mixed_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl"
@ -887,6 +951,8 @@ if __name__ == "__main__":

    inl_map = {
        (GemmKind.Gemm, 90): [fpA_intB_inl],
        (GemmKind.Gemm, 100): [fpA_intB_sm100_inl],
        (GemmKind.Gemm, 103): [fpA_intB_sm100_inl],
        (GemmKind.Grouped, 90): [moe_gemm_inl],
        (GemmKind.Grouped, 100): [moe_gemm_inl],
        (GemmKind.Grouped, 103): [moe_gemm_inl],
@ -0,0 +1,32 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"

namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
    KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false, 64);
// KTile=128 for w4a8
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
    KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false, 128);
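// (The two instantiations differ only in KTile; per the comment above, the KTile=128 variant backs the w4a8
// groupwise path, which the SM100/103 dispatch selects via EXEC_W4A8 further below.)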
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm
@ -0,0 +1,29 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"

namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
    KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajor, false, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm
@ -0,0 +1,29 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"

namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
    KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajor, false, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm
@ -0,0 +1,29 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"

namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
    KernelType::BF16Int8PerChannel, BF16DetailsA, Int8DetailsW, ColumnMajor, false, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm
@ -0,0 +1,32 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"

namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
    KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false, 64);
// KTile=128 for w4a8
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
    KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false, 128);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm
@ -0,0 +1,29 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"

namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
    KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajor, false, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm
@ -0,0 +1,29 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"

namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
    KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajor, false, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm
@ -0,0 +1,29 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"

namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
    KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajor, false, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm
@ -52,7 +52,7 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s)
        EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true);
        EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
    }
    else if ((arch >= 80 && arch < 90) || arch >= 100)
    else if ((arch >= 80 && arch < 90) || arch >= 120)
    {
        if (arch == 89 || arch >= 120)
        {
@ -68,7 +68,7 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s)
        EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
        EXEC(KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
    }
    else if (arch >= 90)
    else if (arch == 90)
    {
        // Dispatchers for W4A8 groupwise
        EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true);
@ -83,6 +83,20 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s)
        EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true);
        EXEC(KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true);
    }
    else if (arch == 100 || arch == 103)
    {
        EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false);
        EXEC_W4A8(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false);
        EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajor, false);
        EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajor, false);
        EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false);
        EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false);
        EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajor, false);
        EXEC(KernelType::BF16Int8PerChannel, BF16DetailsA, Int8DetailsW, ColumnMajor, false);
        EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajor, false);
        EXEC(KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajor, false);
    }
#undef EXEC_W4A8
#undef EXEC
}

@ -123,7 +123,7 @@ The language component decides which quantization methods are supported by a giv
| Model | NVFP4 | MXFP4 | FP8(per tensor)| FP8(block scaling) | FP8(rowwise) | FP8 KV Cache | NVFP4 KV Cache | W4A8 AWQ | W4A16 AWQ | W4A8 GPTQ | W4A16 GPTQ |
| :------------- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :-------: | :-------: | :--------: | :--------: |
| Blackwell(sm120) | Y | Y | Y | . | . | Y | . | . | . | . | . |
| Blackwell(sm100/103) | Y | Y | Y | Y | . | Y | Y | . | . | . | . |
| Blackwell(sm100/103) | Y | Y | Y | Y | . | Y | Y | Y | Y | Y | Y |
| Hopper | . | . | Y | Y | Y | Y | . | Y | Y | Y | Y |
| Ada Lovelace | . | . | Y | . | . | Y | . | Y | Y | Y | Y |
| Ampere | . | . | . | . | . | Y | . | . | Y | . | Y |

@ -1327,7 +1327,7 @@ def _(

class FinegrainedMixedDtypeGemm(TunableRunner):
    _runner_dict = dict()
    MAX_SUPPORTED_SM_VERSION = 90
    MAX_SUPPORTED_SM_VERSION = 103
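    # (Raising the cap from 90 to 103 makes the SM100/SM103 Blackwell launchers reachable; newer SM versions
    # are still rejected by the get_sm_version() check below.)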

    def __init__(self, activation_dtype: torch.dtype, output_dtype: torch.dtype,
                 quant_mode: int):
@ -1362,7 +1362,7 @@ class FinegrainedMixedDtypeGemm(TunableRunner):

        if get_sm_version() > self.MAX_SUPPORTED_SM_VERSION:
            raise ValueError(
                f"SM version {get_sm_version()} is not supported for W4A16 GEMM"
                f"SM version {get_sm_version()} is not supported for W4A16/W4A8 finegrained mixed dtype GEMM"
            )

        activation, weights_packed, scales = inputs

@ -961,8 +961,8 @@ def preprocess_weights_for_mixed_gemm(
        tensor = tensor.unsqueeze(0)
    elif sm_ >= 90:
        sm_ = 80
    if sm_ > 90:
        sm_ = 80
    if sm_ == 100 or sm_ == 103:
        do_weight_interleave = False
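    # (SM100/103 keep the weights in plain column-major form for the new launcher, so the per-MMA row
    # interleave applied below for older architectures is skipped here.)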

    permutation_map = {
        "16_8": [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15],
@ -990,7 +990,7 @@ def preprocess_weights_for_mixed_gemm(
    assert (num_rows % B_ROWS_PER_MMA == 0)
    assert (num_cols % MMA_SHAPE_N == 0)

    if do_weight_interleave:
    if do_weight_interleave and sm_ < 100:
        row_idx_list = [(row_idx // B_ROWS_PER_MMA) * B_ROWS_PER_MMA +
                        permutation_map[f"{BITS_PER_ELT_A}_{BITS_PER_ELT_B}"][
                            row_idx % B_ROWS_PER_MMA]