[None][feat] CUTLASS MoE FC2+Finalize fusion (#3294)

Signed-off-by: Sergey Klevtsov <sklevtsov@nvidia.com>
Sergey Klevtsov 2025-08-12 00:56:48 -07:00 committed by GitHub
parent 0dc4b4e699
commit 27fc35175e
GPG Key ID: B5690EEEBB952194
34 changed files with 895 additions and 953 deletions

View File

@@ -366,6 +366,12 @@ bool getEnvForceDeterministicMOE()
return forceDeterministic;
}
bool getEnvMOEDisableFinalizeFusion()
{
static bool const moeDisableFinalizeFusion = getBoolEnv("TRTLLM_MOE_DISABLE_FINALIZE_FUSION");
return moeDisableFinalizeFusion;
}
bool getEnvForceDeterministicAttention()
{
static bool const forceDeterministic

View File

@@ -86,6 +86,9 @@ bool getEnvForceDeterministic();
// Force deterministic behavior for MoE plugin.
bool getEnvForceDeterministicMOE();
// Disable finalize fusion in MoE plugin.
bool getEnvMOEDisableFinalizeFusion();
// Force deterministic behavior for attention plugin.
bool getEnvForceDeterministicAttention();
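The getter added above follows the same pattern as the other environment toggles in this file: the variable is read once through getBoolEnv and cached in a function-local static. A minimal sketch of a hypothetical call site (not part of this diff; kernelSupportsFusion stands in for whatever capability check the caller already performs):
// Hypothetical gating helper, assuming the getter is visible from the caller's namespace.
bool shouldUseFinalizeFusion(bool kernelSupportsFusion)
{
    // Setting TRTLLM_MOE_DISABLE_FINALIZE_FUSION=1 turns the fusion off even when the kernel supports it.
    return kernelSupportsFusion && !getEnvMOEDisableFinalizeFusion();
}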

View File

@@ -1,568 +0,0 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief Functor performing elementwise operations used by epilogues.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/fast_math.h"
#include "cute/numeric/numeric_types.hpp"
#include "cute/tensor.hpp"
#include "cutlass/trace.h"
#include "cutlass_extensions/arch/copy_red_global.hpp"
#include "cutlass_extensions/util/gather_tensor.hpp"
#include "cutlass/epilogue/collective/builders/sm90_builder.inl"
#include "cutlass/epilogue/collective/builders/sm90_common.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace epilogue
{
namespace collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class StrideC_, class ElementD_, class StrideD_, class ThreadEpilogueOp_, class ElementBias, class StrideBias,
class ElementScale, class StrideScale, class EpilogueTile, class SmemLayoutAtomD, class CopyOpR2S, class CopyOpS2R,
class CopyOpR2G>
class EpilogueMoeFusedFinalize
{
public:
using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized;
using DispatchPolicy = PtrArrayNoSmemWarpSpecialized;
using ThreadEpilogueOp = ThreadEpilogueOp_;
using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
using ElementIntermediate = typename ThreadEpilogueOp::ElementD;
using ElementC = typename ThreadEpilogueOp::ElementC;
using StrideC = StrideC_;
using InternalStrideC = cute::remove_pointer_t<StrideC>;
using ElementD = ElementD_;
using StrideD = StrideD_;
using InternalStrideD = cute::remove_pointer_t<StrideD>;
static_assert(!is_same_v<InternalStrideC, StrideC>, "Stride C must be a pointer");
static_assert(is_same_v<InternalStrideD, StrideD>, "Stride D must not be a pointer");
using CopyAtomR2S = Copy_Atom<CopyOpR2S, ElementAccumulator>;
using CopyAtomS2R = Copy_Atom<CopyOpS2R, ElementAccumulator>;
using CopyAtomR2G = Copy_Atom<CopyOpR2G, ElementD>;
static constexpr int AlignmentD = CopyAtomR2G::NumValSrc;
using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{}));
constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});
struct SharedStorage
{
alignas(SmemAlignmentD) cute::ArrayEngine<ElementAccumulator, cosize_v<SmemLayoutD>> smem_D;
};
struct TensorMapStorage
{
};
struct Arguments
{
typename ThreadEpilogueOp::Params thread{};
ElementC const** ptr_C{};
StrideC dC{};
ElementD* ptr_D{};
StrideD dD{};
ElementBias const* ptr_bias;
StrideBias dBias{};
ElementScale const* ptr_scale;
StrideScale dScale{};
int64_t const* group_offset{};
int32_t const* scatter_index{};
cutlass::FastDivmod num_rows_in_final_output;
};
using Params = Arguments;
//
// Methods
//
template <class ProblemShape>
static constexpr Params to_underlying_arguments(
ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace)
{
return args;
}
template <class ProblemShape>
static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0)
{
return 0;
}
template <class ProblemShape>
static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args,
void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr)
{
return cutlass::Status::kSuccess;
}
template <class ProblemShape>
CUTLASS_HOST_DEVICE static bool can_implement(
[[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args)
{
bool implementable = true;
if (problem_shape.is_host_problem_shape_available())
{
// Check alignment for all problem sizes
for (int i = 0; i < problem_shape.groups(); i++)
{
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1);
auto [M, N, K, L] = problem_shape_MNKL;
implementable = implementable
&& cutlass::detail::check_alignment<AlignmentD>(cute::make_shape(M, N, L), InternalStrideD{});
}
}
if (!implementable)
{
CUTLASS_TRACE_HOST(
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global "
"reduction instruction.\n");
}
return implementable;
}
CUTLASS_HOST_DEVICE
EpilogueMoeFusedFinalize(Params const& params_)
: params(params_)
{
}
CUTLASS_DEVICE
bool is_source_needed()
{
// For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta.
return params.ptr_C != nullptr
&& (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0);
}
template <class ProblemShapeMNKL, class BlockShapeMNK, class BlockCoordMNKL, class FrgEngine, class FrgLayout,
class TiledMma, class ResidueMNK>
CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK,
BlockCoordMNKL blk_coord_mnkl, cute::Tensor<FrgEngine, FrgLayout> const& accumulators, TiledMma tiled_mma,
ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf)
{
using namespace cute;
using X = Underscore;
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");
auto synchronize = [&]()
{ cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
// Separate out problem shape for convenience
auto M = get<0>(problem_shape_mnkl);
auto N = get<1>(problem_shape_mnkl);
auto L = get<3>(problem_shape_mnkl);
auto mma_tile_m = tile_size<0>(tiled_mma);
auto mma_tile_n = tile_size<1>(tiled_mma);
auto epi_tile_m = size<0>(EpilogueTile{});
auto epi_tile_n = size<1>(EpilogueTile{});
CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M");
CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");
// Batches are managed by using appropriate pointers to C and D matrices
int32_t const mock_L = 1;
int32_t const mock_l_coord = 0;
// Slice to get the tile this CTA is responsible for
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
// If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups.
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups,
// we get the correct alpha/beta values for the current batch/group using group index.
ThreadEpilogueOp epilogue_op(params.thread, l_coord);
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{});
Tensor sD = as_position_independent_swizzle_tensor(sD_);
// Function to scatter output rows
auto& num_rows = params.num_rows_in_final_output;
auto read_scatter_map = tensorrt_llm::cutlass_extensions::IndexedGather(
make_gmem_ptr(params.scatter_index + params.group_offset[l_coord]));
auto get_scatter_idx = [&](auto i)
{
auto scatter = read_scatter_map(i);
int quot, rem;
num_rows(quot, rem, scatter);
return rem;
};
// Represent the full output tensor
ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr;
auto dC = epilogue_op.is_source_needed() ? params.dC[l_coord] : InternalStrideC{};
Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l)
Tensor mD_mnl = tensorrt_llm::cutlass_extensions::make_gather_tensor(
make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l)
// Use fake shape for bias, it doesn't matter
bool const is_bias_needed = params.ptr_bias != nullptr;
Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias);
Tensor mScale_mnl = make_tensor(
make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale);
Tensor gC_mnl
= local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gD_mnl
= local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gBias_mnl
= local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gScale_mnl
= local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N)
Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N)
Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
// Get the smallest tiled copy we can use to retile the accumulators
TiledCopy tiled_copy_C_atom
= make_tiled_copy_C_atom(Copy_Atom<SM90_U32x4_STSM_N, cutlass::half_t>{}, tiled_mma);
TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom);
auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx);
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N)
Tensor tRS_rD = make_tensor<ElementAccumulator>(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N)
// Make a tiled copy vectorized along major direction of D
auto tiled_s2r = [&]()
{
if constexpr (cutlass::gemm::detail::is_k_major<StrideD>())
{
constexpr int NumThreadsMajor = epi_tile_n / AlignmentD;
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
return make_tiled_copy(CopyAtomS2R{},
Layout<Shape<Int<NumThreadsMinor>, Int<NumThreadsMajor>>, Stride<Int<NumThreadsMajor>, _1>>{},
Layout<Shape<_1, Int<AlignmentD>>>{});
}
else if constexpr (cutlass::gemm::detail::is_mn_major<StrideD>())
{
constexpr int NumThreadsMajor = epi_tile_m / AlignmentD;
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
return make_tiled_copy(CopyAtomS2R{},
Layout<Shape<Int<NumThreadsMajor>, Int<NumThreadsMinor>>, Stride<_1, Int<NumThreadsMajor>>>{},
Layout<Shape<Int<AlignmentD>, _1>>{});
}
else
{
static_assert(cute::is_void_v<StrideD>, "Unsupported D gmem layout.");
}
}();
auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx);
Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
// Allocate intermediate registers for a single subtile
Tensor tSR_rD = make_tensor<ElementAccumulator>(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rD_final = make_tensor<ElementD>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rC = make_tensor<ElementC>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rBias = make_tensor<ElementBias>(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rScale = make_tensor<ElementScale>(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
// Make an identity coordinate tensor for predicating our output MN tile
Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
// epilogue subtile loop
CUTLASS_PRAGMA_UNROLL
for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m)
{
CUTLASS_PRAGMA_UNROLL
for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n)
{
int mma_m = (epi_m * epi_tile_m) / mma_tile_m;
int mma_n = (epi_n * epi_tile_n) / mma_tile_n;
Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n);
int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n);
int r2s_v = epi_n_in_mma * size(tRS_rD);
CUTLASS_PRAGMA_UNROLL
for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v)
{
tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v);
}
copy(tiled_r2s, tRS_rD, tRS_sD);
synchronize();
copy(tiled_s2r, tSR_sD, tSR_rD);
synchronize();
Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n);
Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n);
Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n);
Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n);
Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n);
if (epilogue_op.is_source_needed())
{
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < size<1>(tSR_rD); ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size<2>(tSR_rD); ++n)
{
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
{
copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n));
if (is_bias_needed)
{
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
}
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(tSR_rD); ++i)
{
auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n));
if (is_bias_needed)
{
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
}
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
}
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
}
}
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < size<1>(tSR_rD); ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size<2>(tSR_rD); ++n)
{
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
{
if (is_bias_needed)
{
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
}
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(tSR_rD); ++i)
{
auto epi_value = epilogue_op(tSR_rD(i, m, n));
if (is_bias_needed)
{
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
}
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
}
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
}
}
}
}
}
}
}
private:
Params params;
};
namespace detail
{
template <class Element, class MaxVec>
constexpr auto get_vectorized_atomic_add_op()
{
using namespace cute;
auto constexpr MaxVecSize = size(MaxVec{});
if constexpr (is_same_v<Element, cutlass::half_t>)
{
if constexpr (MaxVecSize >= 8)
{
return SM90_RED_ADD_NOFTZ_F16x2_V4{};
}
else if constexpr (MaxVecSize >= 4)
{
return SM90_RED_ADD_NOFTZ_F16x2_V2{};
}
else if constexpr (MaxVecSize >= 2)
{
return SM70_RED_ADD_NOFTZ_F16x2{};
}
else
{
return SM70_RED_ADD_NOFTZ_F16{};
}
}
else if constexpr (is_same_v<Element, cutlass::bfloat16_t>)
{
if constexpr (MaxVecSize >= 8)
{
return SM90_RED_ADD_NOFTZ_BF16x2_V4{};
}
else if constexpr (MaxVecSize >= 4)
{
return SM90_RED_ADD_NOFTZ_BF16x2_V2{};
}
else if constexpr (MaxVecSize >= 2)
{
return SM90_RED_ADD_NOFTZ_BF16x2{};
}
else
{
return SM90_RED_ADD_NOFTZ_BF16{};
}
}
else
{
// non-vectorized atomic add for all other types until supported
return TypedAtomicAdd<Element>{};
}
}
} // namespace detail
template <class Arch, class TileShape, class ElementC, class StrideC, class ElementD, class StrideD,
class ElementAccumulator, class ElementCompute, class ElementBias, class StrideBias, class ElementScale,
class StrideScale>
struct EpilogueMoeFusedFinalizeBuilder
{
// assuming cooperative kernel schedule
using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{}));
using EpilogueTile = Shape<_128, EpiTileN>;
// Output of linear combination is ElementCompute instead of ElementD
// since we will be doing more compute on it, no need to cast yet.
using ThreadEpilogueOp
= cutlass::epilogue::thread::LinearCombination<ElementCompute, 1, ElementAccumulator, ElementCompute,
cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, ElementC>;
using SmemLayoutAtomD
= decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<StrideD, ElementAccumulator, EpilogueTile>());
using CopyAtomR2S
= decltype(detail::sm90_get_smem_store_op_for_accumulator<StrideD, ElementAccumulator, EpilogueTile>());
using CopyAtomS2R = DefaultCopy;
using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op<ElementD, EpiTileN>());
template <class Base, class EpilogueOp>
struct TmaWarpSpecializedAdapterWithSmemStorageImpl : Base
{
// We need to override this one using-declaration because otherwise we double up on the smem
using TensorMapStorage = typename EpilogueOp::TensorMapStorage;
// using Base = detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>;
CUTLASS_HOST_DEVICE
TmaWarpSpecializedAdapterWithSmemStorageImpl(
typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors)
: Base(params)
{
}
CUTLASS_DEVICE auto load_init([[maybe_unused]] typename EpilogueOp::Params const& params,
[[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] int32_t sm_count,
[[maybe_unused]] int32_t sm_idx)
{
return cute::make_tuple(nullptr);
}
CUTLASS_DEVICE auto store_init([[maybe_unused]] typename EpilogueOp::Params const& params,
[[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] int32_t sm_count,
[[maybe_unused]] int32_t sm_idx, [[maybe_unused]] int32_t warp_group_idx)
{
return cute::make_tuple(nullptr);
}
// Dummy methods to perform different parts of TMA/Tensormap modifications
template <bool IsLoad, class ProblemShapeMNKL>
CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormaps,
[[maybe_unused]] typename EpilogueOp::Params const& params,
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] ProblemShapeMNKL problem_shape,
[[maybe_unused]] int32_t next_batch, [[maybe_unused]] int32_t warp_group_idx)
{
}
template <bool IsLoad>
CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormaps,
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t warp_group_idx)
{
}
template <bool IsLoad>
CUTLASS_DEVICE void tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap)
{
}
};
template <class EpilogueOp>
using TmaWarpSpecializedAdapterWithSmemStorage = TmaWarpSpecializedAdapterWithSmemStorageImpl<
std::conditional_t<Arch::kMinComputeCapability >= 100, detail::Sm100TmaWarpSpecializedAdapter<EpilogueOp>,
detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>>,
EpilogueOp>;
using CollectiveOp = TmaWarpSpecializedAdapterWithSmemStorage<
EpilogueMoeFusedFinalize<StrideC, ElementD, StrideD, ThreadEpilogueOp, ElementBias, StrideBias, ElementScale,
StrideScale, EpilogueTile, SmemLayoutAtomD, CopyAtomR2S, CopyAtomS2R, CopyAtomR2G>>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace collective
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,547 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Visitor tree store operations for the sm90 TMA warp-specialized (ws) epilogue
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/fusion/operations.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
#include "cutlass_extensions/arch/copy_red_global.hpp"
#include "cutlass_extensions/util/gather_tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
// clang-format off
namespace cutlass::epilogue::fusion {
using namespace cute;
using namespace detail;
template <
class EpilogueTile,
class StrideOutput,
class SmemLayoutAtom,
class CopyOpR2S,
class ElementOutput,
int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>,
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
struct Sm90ScatterPtrArray {
using SmemShape = decltype(make_shape(size(make_layout(get<0>(EpilogueTile{}))), size(make_layout(get<1>(EpilogueTile{})))));
using SmemLayout = decltype(tile_to_shape(SmemLayoutAtom{}, SmemShape{}));
using ElementIndex = int32_t;
// TODO: more generic treatment, or pass StrideIndex via template param?
using StrideIndex = conditional_t<cutlass::gemm::detail::is_mn_major<StrideOutput>(), Stride<_0,_1,_0>, Stride<_1,_0,_0>>;
struct SharedStorage {};
struct Arguments {
ElementOutput* ptr_out = nullptr;
StrideOutput dOut = {};
ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index
int index_modulo{}; // modulo used to transform the index before store
bool use_reduction = true;
};
struct Params {
ElementOutput* ptr_out = nullptr;
StrideOutput dOut = {};
ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index
cutlass::FastDivmod index_divmod{}; // modulo used to transform the index before store
bool use_reduction = true;
};
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return {
args.ptr_out,
args.dOut,
args.ptr_index,
cutlass::FastDivmod(args.index_modulo),
args.use_reduction
};
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_HOST_DEVICE
Sm90ScatterPtrArray() { }
CUTLASS_HOST_DEVICE
Sm90ScatterPtrArray(Params const& params, SharedStorage const& shared_storage)
: params_ptr(&params) { }
Params const* params_ptr;
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template<
class ArgsTuple
>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(ArgsTuple&& args_tuple)
: args_tuple(std::move(args_tuple)) {}
ArgsTuple args_tuple;
template <typename ElementAccumulator, typename ElementInput, int FragmentSize>
CUTLASS_DEVICE auto
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n,
Array<ElementInput, FragmentSize> const& frg_input) {
auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple;
using ConvertInput = NumericArrayConverter<ElementOutput, ElementInput, FragmentSize, RoundStyle>;
ConvertInput convert_input{};
Tensor tC_rOut_frg = recast<Array<ElementOutput, FragmentSize>>(coalesce(tC_rOut)); // (EPI_V)
tC_rOut_frg(epi_v) = convert_input(frg_input);
return tC_rOut_frg(epi_v);
}
template <class STensor, class SyncFn, class VTensor>
CUTLASS_DEVICE void
reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) {
auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple;
Tensor byte_buffer = recast<uint8_t>(reduction_buffer);
static_assert(cosize(byte_buffer.layout()) * sizeof_bits_v<uint8_t> >= cosize(SmemLayout{}) * sizeof_bits_v<ElementOutput>,
"Not enough space in scratch smem buffer");
Tensor sOut = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(recast_ptr<ElementOutput>(byte_buffer.data())), SmemLayout{}));
auto thread_r2s = tiled_r2s.get_slice(thread_idx);
Tensor tRS_sOut_epi = thread_r2s.partition_D(sOut);
Tensor tRS_rOut_epi = thread_r2s.retile_S(tC_rOut);
auto thread_r2g = tiled_r2g_red.get_slice(thread_idx);
Tensor tRG_gOut_epi = tRG_gOut(_,_,_,epi_m,epi_n);
Tensor tRG_sOut_epi = thread_r2g.partition_D(sOut);
Tensor tRG_rOut_epi = thread_r2g.retile_S(make_tensor(tC_rOut.data(), shape(tRG_sOut_epi))); // reuse D registers
// sanity check for register reuse
CUTE_STATIC_ASSERT_V(cosize(tC_rOut.layout()) == cosize(tRG_rOut_epi.layout()), "Invalid register count for R2G");
copy(tiled_r2s, tRS_rOut_epi, tRS_sOut_epi);
sync_fn();
copy(tRG_sOut_epi, tRG_rOut_epi);
auto residue = residue_cD; // capturing structured bindings is a C++20 feature
Tensor tRG_cD_epi = tRG_cD(0,_,_,epi_m,epi_n);
auto pred = cute::lazy::transform(tRG_cD_epi, [&](auto c){ return elem_less(c, residue); });
if (use_reduction) {
copy_if(tiled_r2g_red, pred, tRG_rOut_epi, tRG_gOut_epi);
}
else {
copy_if(tiled_r2g_stg, pred, tRG_rOut_epi, tRG_gOut_epi);
}
}
};
template <class Element, int MaxVecSize>
static constexpr auto get_reduction_op()
{
using namespace cute;
// For now only support red.add
if constexpr (is_same_v<Element, cutlass::half_t>) {
if constexpr (MaxVecSize % 8 == 0) {
return SM90_RED_ADD_NOFTZ_F16x2_V4{};
}
else if constexpr (MaxVecSize % 4 == 0) {
return SM90_RED_ADD_NOFTZ_F16x2_V2{};
}
else if constexpr (MaxVecSize % 2 == 0) {
return SM70_RED_ADD_NOFTZ_F16x2{};
}
else {
return SM70_RED_ADD_NOFTZ_F16{};
}
}
else if constexpr (is_same_v<Element, cutlass::bfloat16_t>) {
if constexpr (MaxVecSize % 8 == 0) {
return SM90_RED_ADD_NOFTZ_BF16x2_V4{};
}
else if constexpr (MaxVecSize % 4 == 0) {
return SM90_RED_ADD_NOFTZ_BF16x2_V2{};
}
else if constexpr (MaxVecSize % 2 == 0) {
return SM90_RED_ADD_NOFTZ_BF16x2{};
}
else {
return SM90_RED_ADD_NOFTZ_BF16{};
}
}
else {
// non-vectorized atomic add for all other types until supported
return TypedAtomicAdd<Element>{};
}
}
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
auto index_read = [index = params_ptr->ptr_index[l], divmod = params_ptr->index_divmod](auto i){ return divmod.rem(index[i]); };
Tensor mOut = cutlass::util::make_gather_tensor(params_ptr->ptr_out, make_shape(M,N,Int<1>{}), params_ptr->dOut, index_read); // (M,N,_1)
Tensor gOut = local_tile(mOut, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N)
Tensor gOut_epi = flat_divide(gOut, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor mIdx = make_tensor(params_ptr->ptr_index[l], make_shape(M,N,Int<1>{}), StrideIndex{}); // (M,N,_1)
Tensor gIdx = local_tile(mIdx, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N)
Tensor gIdx_epi = flat_divide(gIdx, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor cD_epi = flat_divide(args.cD, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor tC_gOut = sm90_partition_for_epilogue<ReferenceSrc>(gOut, args.epi_tile, args.tiled_copy, args.thread_idx); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Tensor tC_rOut = make_tensor<ElementOutput>(take<0,3>(shape(tC_gOut))); // (CPY,CPY_M,CPY_N)
auto tiled_r2s = conditional_return<ReferenceSrc>(
make_tiled_copy_S(Copy_Atom<CopyOpR2S,ElementOutput>{}, args.tiled_copy),
make_tiled_copy_D(Copy_Atom<CopyOpR2S,ElementOutput>{}, args.tiled_copy)
);
// Vectorization must not exceed alignment and also the number of values per thread in the tile
int constexpr NumThreads = CUTE_STATIC_V(size(args.tiled_copy));
int constexpr NumValTile = product(take<0,2>(shape(cD_epi)));
int constexpr MaxVecSize = cute::min(AlignmentOutput, NumValTile / NumThreads);
// Choose the largest available red.global op and an st.global op with matching vectorization
using CopyOpR2GRed = decltype(get_reduction_op<ElementOutput, MaxVecSize>());
using CopyOpR2GStg = UniversalCopy<uint_bit_t<Copy_Atom<CopyOpR2GRed,ElementOutput>::NumValSrc * sizeof_bits_v<ElementOutput>>>;
auto make_tiled_r2g = [&](auto copy_op)
{
using CopyAtomR2G = Copy_Atom<decltype(copy_op),ElementOutput>;
constexpr int VecSize = CopyAtomR2G::NumValSrc;
if constexpr (cutlass::gemm::detail::is_k_major<StrideOutput>()) {
constexpr int ThreadsMajor = size<1>(args.epi_tile) / VecSize;
constexpr int ThreadsMinor = NumThreads / ThreadsMajor;
return make_tiled_copy(CopyAtomR2G{},
Layout<Shape<Int<ThreadsMinor>, Int<ThreadsMajor>>, Stride<Int<ThreadsMajor>, _1>>{},
Layout<Shape<_1, Int<VecSize>>>{});
}
else if constexpr (cutlass::gemm::detail::is_mn_major<StrideOutput>()) {
constexpr int ThreadsMajor = size<0>(args.epi_tile) / VecSize;
constexpr int ThreadsMinor = NumThreads / ThreadsMajor;
return make_tiled_copy(CopyAtomR2G{},
Layout<Shape<Int<ThreadsMajor>, Int<ThreadsMinor>>, Stride<_1, Int<ThreadsMajor>>>{},
Layout<Shape<Int<VecSize>, _1>>{});
}
else {
static_assert(cute::is_void_v<StrideOutput>, "Unsupported D gmem layout.");
}
};
auto tiled_r2g_red = make_tiled_r2g(CopyOpR2GRed{});
auto tiled_r2g_stg = make_tiled_r2g(CopyOpR2GStg{});
// Sanity checks - since we will be using one tiled copy with tensors partitioned with the other tiled copy,
// ensure they have matching layouts/tilers
using TiledR2GRed = decltype(tiled_r2g_red);
using TiledR2GStg = decltype(tiled_r2g_stg);
static_assert(typename TiledR2GRed::AtomLayoutSrc{} == typename TiledR2GStg::AtomLayoutSrc{}, "Mismatching AtomLayoutSrc");
static_assert(typename TiledR2GRed::AtomLayoutDst{} == typename TiledR2GStg::AtomLayoutDst{}, "Mismatching AtomLayoutDst");
static_assert(typename TiledR2GRed::TiledLayout_TV{} == typename TiledR2GStg::TiledLayout_TV{}, "Mismatching TiledLayout_TV");
static_assert(typename TiledR2GRed::Tiler_MN{} == typename TiledR2GStg::Tiler_MN{}, "Mismatching Tiler_MN");
auto thread_r2g = tiled_r2g_red.get_slice(args.thread_idx);
Tensor tRG_gOut = thread_r2g.partition_D(gOut_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N)
Tensor tRG_cD = thread_r2g.partition_D(cD_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N)
auto args_tuple = make_tuple(
cute::move(tC_rOut),
tiled_r2s,
tRG_gOut,
tRG_cD,
tiled_r2g_red,
tiled_r2g_stg,
params_ptr->use_reduction,
args.thread_idx,
args.residue_cD);
return ConsumerStoreCallbacks<decltype(args_tuple)>(std::move(args_tuple));
}
};
template<
class ElementOutput_,
class ElementCompute_,
class ElementBias_ = ElementOutput_,
class ElementScalar_ = ElementCompute_,
int AlignmentBias_ = 128 / cute::sizeof_bits_v<ElementBias_>,
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
>
struct ScaledAccPerRowBias
: ScaledAcc<ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_>
{
using ElementBias = ElementBias_;
static constexpr int AlignmentBias = AlignmentBias_;
static constexpr bool IsPerRowBiasSupported = true;
};
template<
class GmemLayoutTagOut,
class ElementOutput,
class ElementCompute,
class ElementBias = ElementOutput,
class ElementScale = ElementCompute,
class ElementScalar = ElementCompute,
int AlignmentBias = 128 / cute::sizeof_bits_v<ElementBias>,
int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>,
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
struct ScaledAccPerRowBiasPerColScaleScatter
: ScaledAccPerRowBias<ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>
{
using ElementAux = ElementOutput;
using GmemLayoutTagAux = GmemLayoutTagOut;
static constexpr int AlignmentAux = AlignmentOutput;
static constexpr bool IsAuxOutSupported = true;
};
// D = alpha * acc + per-row bias
template<
class CtaTileShapeMNK,
class ElementOutput,
class ElementCompute,
class ElementBias = ElementOutput,
class ElementScalar = ElementCompute,
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90ScaledAccPerRowBiasPtrArray =
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // alpha * acc + bias
Sm90ScalarBroadcastPtrArray<ElementScalar, Stride<_0,_0,int64_t>>, // alpha
Sm90AccFetch, // acc
Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias *, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias
>;
template<
class CtaTileShapeMNK,
class EpilogueTile,
class StrideOutput,
class SmemLayoutAtom,
class CopyOpR2S,
class ElementOutput,
class ElementCompute,
class ElementBias = ElementOutput,
class ElementScale = ElementCompute,
class ElementScalar = ElementCompute,
int AlignmentBias = 128 / cute::sizeof_bits_v<ElementBias>,
int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>,
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray =
Sm90EVT<Sm90ScatterPtrArray<EpilogueTile, StrideOutput, SmemLayoutAtom, CopyOpR2S, ElementOutput, AlignmentOutput, RoundStyle>, // scatter store
Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // scale * (alpha * acc + bias)
Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar *, ElementCompute, Stride<_0,_1,int64_t>, 1>, // scale
Sm90ScaledAccPerRowBiasPtrArray<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle> // alpha * acc + bias
>
>;
template <
int StagesC,
int StagesD,
int FragmentSize,
bool ReuseSmemC,
bool DelayTmaStore,
int NumEpilogueWarpGroups,
class GmemLayoutTagOut,
class ElementOutput,
class ElementCompute,
class ElementBias,
class ElementScale,
class ElementScalar,
int AlignmentBias,
int AlignmentOutput,
FloatRoundStyle RoundStyle,
class CtaTileShapeMNK,
class EpilogueTile,
class SmemLayoutAtom,
class CopyOpR2S
>
struct FusionCallbacks<
epilogue::Sm90PtrArrayTmaWarpSpecialized<StagesC,
StagesD,
FragmentSize,
ReuseSmemC,
DelayTmaStore,
NumEpilogueWarpGroups
>,
fusion::ScaledAccPerRowBiasPerColScaleScatter<GmemLayoutTagOut,
ElementOutput,
ElementCompute,
ElementBias,
ElementScale,
ElementScalar,
AlignmentBias,
AlignmentOutput,
RoundStyle>,
CtaTileShapeMNK,
EpilogueTile,
SmemLayoutAtom,
CopyOpR2S
> : Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray<
CtaTileShapeMNK,
EpilogueTile,
cutlass::gemm::TagToStrideC_t<GmemLayoutTagOut>,
SmemLayoutAtom, CopyOpR2S,
ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar,
AlignmentBias, AlignmentOutput, RoundStyle
> {
using StrideOutput = cutlass::gemm::TagToStrideC_t<GmemLayoutTagOut>;
using Impl = Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray<
CtaTileShapeMNK,
EpilogueTile,
StrideOutput,
SmemLayoutAtom, CopyOpR2S,
ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar,
AlignmentBias, AlignmentOutput, RoundStyle
>;
using Operation = fusion::ScaledAccPerRowBiasPerColScaleScatter<
GmemLayoutTagOut,
ElementOutput,
ElementCompute,
ElementBias,
ElementScale,
ElementScalar,
AlignmentBias,
AlignmentOutput,
RoundStyle>;
struct Arguments {
using StrideAlpha = Stride<_0,_0,int64_t>;
ElementScalar alpha = ElementScalar(1);
ElementScalar const* alpha_ptr{};
ElementScalar const* const* alpha_ptr_array{};
StrideAlpha dAlpha{};
using StrideBias = Stride<_1,_0,int64_t>;
ElementBias const* const* bias_ptr{};
StrideBias dBias{};
using StrideScale = Stride<_0,_1,int64_t>;
ElementScalar const* const* scale_ptr_array{};
StrideScale dScale{};
// Nested args not usable due to a compiler bug with constexpr evaluation
// using ScatterArguments = typename Sm90ScatterPtrArray<EpilogueTile, StrideOutput, SmemLayoutAtom, CopyOpR2S, ElementOutput, AlignmentOutput, RoundStyle>::Arguments;
// ScatterArguments scatter{};
ElementOutput* ptr_out = nullptr;
StrideOutput dOut = {};
int const* const* ptr_index{}; // per-group pointer to the scatter index
int index_modulo{}; // modulo used to transform the index before store
bool use_reduction = true;
operator typename Impl::Arguments() const {
return
{ // unary op: scatter(scale * (alpha * acc + bias))
{ // binary op: scale * (alpha * acc + bias)
{ scale_ptr_array, ElementScalar(1), dScale }, // leaf args : scale broadcast
{ // ternary op : alpha * acc + bias
{{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha
{}, // leaf args : acc
{bias_ptr, ElementBias(0), dBias}, // leaf args : bias
{} // ternary args : multiply_add
}, // end binary op
{} // binary args: multiply
}, // end binary op
//scatter // unary args: reduce
{ ptr_out, dOut, ptr_index, index_modulo, use_reduction }
}; // end unary op
}
};
// Ctor inheritance
using Impl::Impl;
};
} // namespace cutlass::epilogue::fusion
// clang-format on
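The consumer store callback above resolves each expanded row's destination through the per-group scatter index, reduced modulo the final output row count (the index_read lambda), and then issues either a vectorized red.global.add or a plain st.global depending on use_reduction. A small illustrative sketch of that index transform, with scatter_index, num_rows_in_final_output and i as assumed placeholder names:
// Illustrative only: mapping an expanded (permuted) row i to its destination row.
// cutlass::FastDivmod precomputes magic numbers so the device-side modulo is cheap.
cutlass::FastDivmod index_divmod(num_rows_in_final_output);
int dst_row = index_divmod.rem(scatter_index[i]); // == scatter_index[i] % num_rows_in_final_output
// With use_reduction set, multiple expanded rows mapping to the same dst_row are
// accumulated into the final output via the red.add atoms chosen in get_reduction_op().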

View File

@@ -133,11 +133,6 @@ enum class CutlassTileConfigSM100
CtaShape128x256x128B,
CtaShape128x128x256B,
CtaShape128x256x256B,
// M=256
CtaShape256x64x128B,
CtaShape256x128x128B,
CtaShape256x256x128B,
};
enum class CutlassTileConfigSM120

View File

@@ -19,7 +19,7 @@
#include "cute/tensor.hpp"
#include "cute/util/print.hpp"
namespace tensorrt_llm::cutlass_extensions
namespace cutlass::util
{
/// Function object that applies an index to its argument
@@ -81,7 +81,7 @@ struct CustomStride
template <class Div>
CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div)
{
return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div));
return CustomStride<Func, decltype(cute::safe_div(s.stride_, div))>(s.func_, cute::safe_div(s.stride_, div));
}
// Circumvent the requirement on make_layout that shape and stride are integral
@@ -116,7 +116,7 @@ CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, S
Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
}
} // namespace tensorrt_llm::cutlass_extensions
} // namespace cutlass::util
namespace cute
{

View File

@@ -377,72 +377,62 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100(CutlassGemmConfig::Ca
if (config & CutlassGemmConfig::GROUPED_GEMM)
{
std::vector<CutlassGemmConfig> candidate_configs;
if ((config & CutlassGemmConfig::FP4_ONLY) != 0)
if (config & CutlassGemmConfig::FP4_ONLY)
{
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x128x128B,
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x256x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x64x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
return candidate_configs;
}
for (int cluster_m = 1; cluster_m <= 2; cluster_m++)
std::vector<std::pair<CutlassTileConfigSM100, ClusterShape>> tile_configs{
{CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_1x2x1},
{CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x2x1},
{CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x2x1},
{CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x2x1},
{CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x2x1},
{CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x2x1},
{CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape64x32x128B, ClusterShape::ClusterShape_1x2x1},
{CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x2x1},
{CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x2x1},
{CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x2x1},
};
if (config & CutlassGemmConfig::FP8_ONLY)
{
bool Is2SM = cluster_m == 2;
for (int cluster_n = 1; cluster_n <= 2; cluster_n++)
{
std::vector base = {// M=128
CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B};
tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x16x128B, ClusterShape::ClusterShape_1x1x1});
// TODO: re-enable when handled by the MoE GEMM dispatch
// tile_configs.push_back({ CutlassTileConfigSM100::CtaShape128x8x256B, ClusterShape::ClusterShape_1x1x1 });
}
if (Is2SM)
{
if (cluster_n == 1)
{
base.push_back(CutlassTileConfigSM100::CtaShape128x64x128B);
base.push_back(CutlassTileConfigSM100::CtaShape256x64x128B);
}
std::vector twosm = {// M=256
CutlassTileConfigSM100::CtaShape256x128x128B, CutlassTileConfigSM100::CtaShape256x256x128B};
std::copy(twosm.begin(), twosm.end(), std::back_inserter(base));
}
else
{
if (cluster_n == 1)
{
base.push_back(CutlassTileConfigSM100::CtaShape128x32x128B);
if ((config & CutlassGemmConfig::FP8_ONLY) != 0)
{
base.push_back(CutlassTileConfigSM100::CtaShape128x16x128B);
}
}
std::vector onesm{CutlassTileConfigSM100::CtaShape64x64x128B,
CutlassTileConfigSM100::CtaShape64x128x128B, CutlassTileConfigSM100::CtaShape64x256x128B,
CutlassTileConfigSM100::CtaShape128x64x128B};
std::copy(onesm.begin(), onesm.end(), std::back_inserter(base));
}
constexpr std::array cluster_shapes
= {std::array{ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1},
std::array{ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1}};
auto cluster = cluster_shapes[cluster_m - 1][cluster_n - 1];
for (auto tile : base)
{
CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster};
candidate_configs.push_back(config);
}
}
for (auto [tile, cluster] : tile_configs)
{
CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster};
candidate_configs.push_back(config);
}
return candidate_configs;
}

View File

@@ -37,11 +37,6 @@
namespace tensorrt_llm::kernels::cutlass_kernels
{
template <class T>
constexpr auto transpose_stride(T const& t)
{
return cute::prepend(cute::prepend(cute::take<2, cute::rank_v<T>>(t), cute::get<0>(t)), cute::get<1>(t));
}
template <typename AType, typename BType, typename BScaleType, typename OType>
struct GroupedGemmInput
@@ -72,8 +67,6 @@ struct GroupedGemmInput
struct TmaWarpSpecializedGroupedGemmInput
{
template <class T>
using TransposeStride = decltype(transpose_stride<T>(T{}));
template <class Tag>
using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;
@@ -86,6 +79,7 @@ struct TmaWarpSpecializedGroupedGemmInput
using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for A matrix operand
using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand
using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for C matrix operand
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
constexpr static int NVFP4BlockScaleVectorSize = 16;
constexpr static int MXFPXBlockScaleVectorSize = 32;
@@ -121,6 +115,7 @@ struct TmaWarpSpecializedGroupedGemmInput
using StrideB
= std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped
using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;
#ifdef ENABLE_FP8
template <class T>
@@ -147,37 +142,26 @@ struct TmaWarpSpecializedGroupedGemmInput
StrideC* stride_c = nullptr;
void const** ptr_c = nullptr;
struct DefaultEpilogue
{
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;
StrideD* stride_d = nullptr;
void** ptr_d = nullptr;
};
// D is used in all cases except fused finalize
StrideD* stride_d = nullptr;
void** ptr_d = nullptr;
struct FusedFinalizeEpilogue
{
using StrideFinalOutput = DefaultEpilogue::StrideD;
using StrideBias = TransposeStride<cute::Stride<cute::_0, cute::_1, int>>;
using StrideRouterScales = TransposeStride<cute::Stride<cute::_1, cute::_0>>;
using StrideFinalOutput = cutlass::detail::TagToStrideC_t<LayoutD>;
void* ptr_final_output = nullptr;
StrideFinalOutput stride_final_output{};
void const* ptr_bias = nullptr;
StrideBias stride_bias{};
void const** ptr_bias = nullptr;
float const** ptr_router_scales = nullptr;
float const* ptr_router_scales = nullptr;
StrideRouterScales stride_router_scales{};
int const** ptr_source_token_index = nullptr;
int num_rows_in_final_output = 0;
int64_t const* ptr_expert_first_token_offset = nullptr;
int const* ptr_source_token_index = nullptr;
size_t num_rows_in_final_output = 0;
bool use_reduction = true;
};
DefaultEpilogue default_epilogue;
FusedFinalizeEpilogue fused_finalize_epilogue;
enum class EpilogueFusion
@@ -235,7 +219,7 @@ struct TmaWarpSpecializedGroupedGemmInput
uint8_t* gemm_workspace = nullptr;
size_t gemm_workspace_size = 0;
static std::array<size_t, 17> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type);
static std::array<size_t, 20> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type);
static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type);
@@ -247,9 +231,7 @@ struct TmaWarpSpecializedGroupedGemmInput
return stride_a != nullptr && ptr_a != nullptr;
}
void setFinalizeFusionParams(void* final_output, float const* router_scales,
int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
int num_output_tokens);
void setFinalizeFusionParams(void* final_output, int hidden_size, int num_output_tokens, bool use_reduction);
std::string toString() const;
};
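With the router scales, bias and source-token indices now supplied through the grouped GEMM setup (see the computeStridesTmaWarpSpecialized changes in the next file) rather than through this setter, the call shrinks to the declaration above. A hedged example of the new call shape, with final_output_ptr, hidden_size and num_output_tokens as placeholder names:
// Hypothetical call site for the reduced setter: only the output buffer, its hidden size,
// the number of output tokens, and the reduction toggle are configured here now.
TmaWarpSpecializedGroupedGemmInput tma_ws_input; // assumed to be filled in by the usual setup path
tma_ws_input.setFinalizeFusionParams(final_output_ptr, hidden_size, num_output_tokens,
    /*use_reduction=*/true);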

View File

@@ -495,7 +495,8 @@ public:
void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1,
void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream)
void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales,
int const* permuted_row_to_unpermuted_row, cudaStream_t stream)
= 0;
virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@@ -512,13 +513,13 @@ public:
virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0;
bool is_profiler = false;
bool use_deterministic_hopper_reduce_ = false;
bool use_fused_finalize_ = true;
};
// Assumes input activations are row major. Weights need to be preprocessed by th_op/weight_quantize.cc.
// Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive.
// Avoid making several duplicates of this class.
template <typename T, /*The type used for activations*/
template <typename T, /* The type used for activations */
typename WeightType, /* The type for the MoE weights */
typename OutputType = T, /* The type for the MoE final output */
typename InputType = T, /* The type for the MoE input */
@@ -709,7 +710,8 @@ public:
void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1,
void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) override
void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales,
int const* permuted_row_to_unpermuted_row, cudaStream_t stream) override
{
return Self::computeStridesTmaWarpSpecialized(expert_first_token_offset, layout_info1, layout_info2, num_tokens,
expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node,
@@ -718,7 +720,8 @@ public:
alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, quant_params,
reinterpret_cast<ScaleBiasType const*>(bias1), reinterpret_cast<ScaleBiasType const*>(bias2),
reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output),
reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), stream);
reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), router_scales, permuted_row_to_unpermuted_row,
stream);
}
std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@@ -760,7 +763,8 @@ private:
float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params,
ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output,
UnfusedGemmOutputType* gemm2_output, cudaStream_t stream);
UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row,
cudaStream_t stream);
static std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1,
TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k,
@@ -790,8 +794,8 @@ private:
bool mayHaveFinalizeFused() const
{
return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() == 90
&& !use_deterministic_hopper_reduce_ && !use_w4_groupwise;
return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && use_fused_finalize_
&& !use_w4_groupwise;
}
// TODO: This should eventually take the quant params to give more flexibility

View File

@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -26,21 +26,14 @@
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/fusion/operations.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
@@ -189,17 +182,19 @@ using SafeBF16 = void;
TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, int const multi_processor_count, \
cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size) \
{ \
constexpr static EpilogueFusion FUSION = EpilogueFusion::FUSION_; \
/* constexpr static bool BIAS = BIAS_; */ /* Always false */ \
using ArchTag = cutlass::arch::ArchTag_; \
constexpr static EpilogueFusion FUSION = EpilogueFusion::FUSION_; \
constexpr static bool IsMXFPX = MXFPX_; \
constexpr bool IsBlackwell = ArchTag::kMinComputeCapability >= 100; \
constexpr bool IsSM120 = ArchTag::kMinComputeCapability == 120 || ArchTag::kMinComputeCapability == 121; \
constexpr bool Is2SM = IsBlackwell && (CGA_M_ % 2 == 0); \
/* constexpr static bool BIAS = BIAS_; */ /* Always false */ \
using T = DataType_; \
using WeightType = WeightType_; \
using OutputType = OutputType_; \
using EpilogueTag = tensorrt_llm::cutlass_extensions::EpilogueTag_; \
using TileShape = cute::Shape<cute::Int<CTA_M_>, cute::Int<CTA_N_>, cute::Int<CTA_K_>>; \
using MmaTileShape = cute::Shape<cute::Int<CTA_M_*(Is2SM ? 2 : 1)>, cute::Int<CTA_N_>, cute::Int<CTA_K_>>; \
using ClusterShape = cute::Shape<cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>>; \
constexpr static bool IsMXFPX = MXFPX_; \
\
if constexpr (!COMPILE_HOPPER_TMA_GROUPED_GEMMS_ENABLED && ArchTag::kMinComputeCapability >= 90 \
&& ArchTag::kMinComputeCapability < 100) \
{ \
@ -217,18 +212,15 @@ using SafeBF16 = void;
TLLM_THROW( \
"Please recompile with support for blackwell by passing 120-real as an arch to build_wheel.py."); \
} \
else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v<ArchTag, TileShape, ClusterShape, \
T>) \
else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v<ArchTag, MmaTileShape, \
ClusterShape, T>) \
{ \
using namespace cute; \
/* Helper class for defining all the cutlass types \
// template <typename ArchTag, typename T, typename WeightType, typename OutputType, typename EpilogueTag, \
// typename TileShape, typename ClusterShape, bool BIAS, EpilogueFusion FUSION> \
// typename MmaTileShape, typename ClusterShape, bool BIAS, EpilogueFusion FUSION> \
// struct TmaWarpSpecializedGroupedGemmInfo \
{ */ \
using Arch = ArchTag; \
constexpr static bool IsBlackwell = Arch::kMinComputeCapability >= 100; \
constexpr static bool IsSM120 = Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121; \
constexpr static bool IsWFP4AFP8 = cutlass::platform::is_same<WeightType, SafeFP4>::value \
&& cutlass::platform::is_same<T, SafeFP8>::value; \
constexpr static bool IsFP4 = cutlass::platform::is_same<T, SafeFP4>::value; \
@ -308,8 +300,8 @@ using SafeBF16 = void;
// units of elements (up to 16 bytes)*/ \
\
/* D matrix configuration */ \
using LayoutD = TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::LayoutD; \
using StrideD = TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::StrideD; \
using LayoutD = TmaWarpSpecializedGroupedGemmInput::LayoutD; \
using StrideD = TmaWarpSpecializedGroupedGemmInput::StrideD; \
constexpr static int AlignmentD \
= 128 / cutlass::sizeof_bits<ElementD>::value; /* Memory access granularity/alignment of D matrix \
// in units of elements (up to 16 bytes) */ \
@ -327,30 +319,24 @@ using SafeBF16 = void;
// cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, \
// cutlass::epilogue::?????????????????? /// <<<<<< what supports activations \
// >;*/ \
using EpilogueScheduleSM90 = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; \
using EpilogueScheduleSM90 = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
\
constexpr static bool Is2SM = IsBlackwell && (cute::size<0>(ClusterShape{}) % 2) == 0; \
using EpilogueScheduleSM100 = std::conditional_t<Is2SM, cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm, \
cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm>; \
using EpilogueScheduleSM120 = cutlass::epilogue::TmaWarpSpecialized; \
using EpilogueScheduleBW = std ::conditional_t<IsSM120, EpilogueScheduleSM120, EpilogueScheduleSM100>; \
using EpilogueScheduleBW = std::conditional_t<IsSM120, EpilogueScheduleSM120, EpilogueScheduleSM100>; \
using EpilogueSchedule = std::conditional_t<IsBlackwell, EpilogueScheduleBW, EpilogueScheduleSM90>; \
\
using EpilogueTileShapeSm90 = TileShape; \
using AtomClusterDiv = std::conditional_t<Is2SM, _2, _1>; \
using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<AtomClusterDiv, _1, _1>{})); \
using EpilogueTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \
using EpilogueTileShape = std::conditional_t<IsBlackwell, EpilogueTileShapeSm100, EpilogueTileShapeSm90>; \
using EpilogueElementC = std::conditional_t<IsSM120, ElementCSafe, ElementC>; \
using EpilogueTensorOp = std::conditional_t<IsBlackwell && IsBlockScaled, \
cutlass::arch::OpClassBlockScaledTensorOp, cutlass::arch::OpClassTensorOp>; \
using EpilogueSubTile \
= std::conditional_t<Arch::kMinComputeCapability == 100 && IsFP4 && CTA_N_ == 256, /* SM100 Exactly */ \
cute::Shape<cute::_128, cute::_64>, cutlass::epilogue::collective::EpilogueTileAuto>; \
using EpilogueSubTile = std::conditional_t<ArchTag::kMinComputeCapability == 100 && IsFP4 \
&& CTA_N_ == 256, /* SM100 Exactly */ \
cute::Shape<cute::_128, cute::_64>, cutlass::epilogue::collective::EpilogueTileAuto>; \
/* Epilogue For Default Finalize */ \
using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder</**/ \
Arch, EpilogueTensorOp, /**/ \
EpilogueTileShape, ClusterShape, /**/ \
ArchTag, EpilogueTensorOp, /**/ \
MmaTileShape, ClusterShape, /**/ \
EpilogueSubTile, /**/ \
ElementAccumulator, ElementAccumulator, /**/ \
EpilogueElementC, LayoutC*, AlignmentC, /**/ \
@ -358,18 +344,17 @@ using SafeBF16 = void;
EpilogueSchedule>::CollectiveOp; \
\
/* Epilogue For Fused Finalize */ \
using CollectiveEpilogueFinalize = \
typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< /**/ \
Arch, EpilogueTileShape, /**/ \
ElementCSafe, StrideC*, /**/ \
ElementFinalOutput, \
TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, /**/ \
ElementAccumulator, /**/ \
ElementAccumulator, /**/ \
ElementBias, TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, /**/ \
ElementRouterScales, \
TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales /**/ \
>::CollectiveOp; \
using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::CollectiveBuilder</**/ \
ArchTag, EpilogueTensorOp, /**/ \
MmaTileShape, ClusterShape, /**/ \
EpilogueSubTile, /**/ \
ElementAccumulator, ElementAccumulator, /**/ \
EpilogueElementC, LayoutC*, AlignmentC, /**/ \
void, LayoutD*, AlignmentD, /**/ \
EpilogueSchedule, /**/ \
cutlass::epilogue::fusion::ScaledAccPerRowBiasPerColScaleScatter< /**/ \
LayoutD, ElementFinalOutput, ElementAccumulator, ElementBias, ElementRouterScales> /**/ \
>::CollectiveOp; \
\
using CollectiveEpilogue = std::conditional_t<FUSION == EpilogueFusion::FINALIZE, \
CollectiveEpilogueFinalize, CollectiveEpilogueDefault>; \
@ -405,16 +390,12 @@ using SafeBF16 = void;
using MainloopElementA = std::conditional_t<IsBlackwell && IsBlockScaled, ElementABlockScaled, ElementA>; \
using MainloopElementB = std::conditional_t<IsBlackwell && IsBlockScaled, ElementBBlockScaled, ElementB>; \
\
using MainloopTileShapeSm90 = TileShape; \
using MainloopTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \
using MainloopTileShape = std::conditional_t<IsBlackwell, MainloopTileShapeSm100, MainloopTileShapeSm90>; \
\
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder</**/ \
Arch, TensorOp, /**/ \
ArchTag, TensorOp, /**/ \
MainloopElementB, LayoutB*, AlignmentB, /* A & B swapped here */ \
MainloopElementA, LayoutA*, AlignmentA, /**/ \
ElementAccumulator, /**/ \
MainloopTileShape, ClusterShape, /**/ \
MmaTileShape, ClusterShape, /**/ \
StageCountAutoCarveout, KernelSchedule>::CollectiveOp; \
\
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<TmaWarpSpecializedGroupedGemmInput::ProblemShape, \
@ -422,11 +403,11 @@ using SafeBF16 = void;
\
using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; \
/*}; \
\ \
// \
// using namespace cute; \
// using GemmInfo = TmaWarpSpecializedGroupedGemmInfo;<ArchTag, T, WeightType, OutputType, \
EpilogueTag, \
// TileShape, \
// MmaTileShape, \
// ClusterShape, BIAS, FUSION>; \
// \
// using ElementAccumulator = typename GemmInfo::ElementAccumulator; \
@ -478,7 +459,7 @@ using SafeBF16 = void;
TLLM_CHECK(tma_ws_input.ptr_a); \
TLLM_CHECK(tma_ws_input.ptr_b); \
\
auto make_mainloop_params = [&]() -> MainloopArguments \
MainloopArguments const mainloop_args = [&] \
{ \
if constexpr (IsBlockScaled) \
{ \
@ -498,67 +479,46 @@ using SafeBF16 = void;
reinterpret_cast<ElementB const**>(tma_ws_input.ptr_b), tma_ws_input.stride_b, \
reinterpret_cast<ElementA const**>(tma_ws_input.ptr_a), tma_ws_input.stride_a); \
} \
}; \
\
auto const mainloop_params = make_mainloop_params(); \
\
using EpilogueArguments = typename CollectiveEpilogue::Arguments; \
using EpilogueScalars = decltype(EpilogueArguments{}.thread); \
auto make_epilogue_scalars = [&]() \
}(); \
using FusionArguments = typename CollectiveEpilogue::FusionCallbacks::Arguments; \
FusionArguments fusion_args = [&] \
{ \
if constexpr (IsBlackwell) \
if constexpr (FUSION == EpilogueFusion::FINALIZE) \
{ \
return construct_if_true<IsBlackwell, EpilogueScalars>(ElementAccumulator(1.f), \
tma_ws_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f), nullptr, nullptr, \
tma_ws_input.alpha_scale_ptr_array, nullptr, \
cute::Shape<_0, _0, int64_t>{ \
cute::_0{}, cute::_0{}, (tma_ws_input.alpha_scale_ptr_array != nullptr) ? 1 : 0}, \
cute::Shape<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 0}); \
} \
else if (tma_ws_input.alpha_scale_ptr_array) \
{ \
return construct_if_true<!IsBlackwell, EpilogueScalars>(tma_ws_input.alpha_scale_ptr_array); \
auto epi_params = tma_ws_input.fused_finalize_epilogue; \
return construct_if_true<FUSION == EpilogueFusion::FINALIZE, FusionArguments>( \
ElementAccumulator(1), nullptr, tma_ws_input.alpha_scale_ptr_array, \
Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, /* alpha */ \
reinterpret_cast<ElementBias const* const*>(epi_params.ptr_bias), \
Stride<_1, _0, int64_t>{}, /* bias */ \
epi_params.ptr_router_scales, Stride<_0, _1, int64_t>{}, /* scale */ \
reinterpret_cast<ElementFinalOutput*>(epi_params.ptr_final_output), \
epi_params.stride_final_output, epi_params.ptr_source_token_index, \
epi_params.num_rows_in_final_output, epi_params.use_reduction); \
} \
else \
{ \
return construct_if_true<!IsBlackwell, EpilogueScalars>(ElementAccumulator(1.f), \
tma_ws_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); \
return construct_if_true<FUSION != EpilogueFusion::FINALIZE, FusionArguments>( \
ElementAccumulator(1), ElementAccumulator(0), nullptr, nullptr, \
tma_ws_input.alpha_scale_ptr_array, nullptr, \
Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, Stride<_0, _0, int64_t>{}); \
} \
}; \
auto epilogue_scalars = make_epilogue_scalars(); \
/* TODO ptr_c casts to ElementCSafe** because there is a workaround in CUTLASS */ \
auto make_epi_args = [&]() \
{ \
static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE, \
"Unimplemented fusion provided to TMA WS MoE gemm launcher"); \
}(); \
\
if constexpr (FUSION == EpilogueFusion::NONE) \
using EpilogueArguments = typename CollectiveEpilogue::Arguments; \
EpilogueArguments epilogue_args = [&] \
{ \
if constexpr (FUSION == EpilogueFusion::FINALIZE) \
{ \
auto epi_params = tma_ws_input.default_epilogue; \
return construct_if_true<FUSION == EpilogueFusion::NONE, EpilogueArguments>(epilogue_scalars, \
nullptr, tma_ws_input.stride_c, reinterpret_cast<ElementD**>(epi_params.ptr_d), \
epi_params.stride_d); \
} \
else if constexpr (FUSION == EpilogueFusion::FINALIZE) \
{ \
/* Parameters for fused finalize */ \
auto epi_params = tma_ws_input.fused_finalize_epilogue; \
return construct_if_true<FUSION == EpilogueFusion::FINALIZE, EpilogueArguments>( \
epilogue_scalars, /* Parameters to underlying epilogue */ \
nullptr, tma_ws_input.stride_c, /* C params */ \
reinterpret_cast<ElementFinalOutput*>(epi_params.ptr_final_output), \
epi_params.stride_final_output, /* D (output) params */ \
reinterpret_cast<ElementBias const*>(epi_params.ptr_bias), \
epi_params.stride_bias, /* Bias params */ \
epi_params.ptr_router_scales, epi_params.stride_router_scales, /* Router scales */ \
epi_params.ptr_expert_first_token_offset, /* Offset of this expert's token in the \
router scales */ \
epi_params.ptr_source_token_index, /* Index of the source token to sum into */ \
epi_params.num_rows_in_final_output /* Number of tokens in the output buffer */ \
); \
fusion_args, nullptr, nullptr, nullptr, nullptr); \
} \
}; \
EpilogueArguments const epilogue_params = make_epi_args(); \
else \
{ \
return construct_if_true<FUSION != EpilogueFusion::FINALIZE, EpilogueArguments>(fusion_args, \
nullptr, nullptr, reinterpret_cast<ElementD**>(tma_ws_input.ptr_d), tma_ws_input.stride_d); \
} \
}(); \
/* EpilogueArguments const epilogue_params = make_epi_args<EpilogueArguments, EpilogueScalars, \
ElementCSafe, ElementD, ElementFinalOutput, ElementBias, FUSION>( \
// tma_ws_input, epilogue_scalars \
@ -568,7 +528,7 @@ using SafeBF16 = void;
1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; \
\
const typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, \
tma_ws_input.shape_info, mainloop_params, epilogue_params, hw_info, scheduler_args}; \
tma_ws_input.shape_info, mainloop_args, epilogue_args, hw_info, scheduler_args}; \
\
size_t calculated_ws_size = gemm.get_workspace_size(args); \
TLLM_CHECK_WITH_INFO(calculated_ws_size <= tma_ws_input.gemm_workspace_size, \

View File

@ -197,8 +197,7 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
reinterpret_cast<ElementScalePacked const**>(hopper_inputs.int4_groupwise_params.ptr_s_a),
hopper_inputs.int4_groupwise_params.stride_s_a, group_size},
{fusion_args, reinterpret_cast<ElementC const**>(hopper_inputs.ptr_c), hopper_inputs.stride_c,
reinterpret_cast<ElementD**>(hopper_inputs.default_epilogue.ptr_d),
hopper_inputs.default_epilogue.stride_d},
reinterpret_cast<ElementD**>(hopper_inputs.ptr_d), hopper_inputs.stride_d},
hw_info};
*workspace_size = gemm.get_workspace_size(args);
return;
@ -211,8 +210,7 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
reinterpret_cast<ElementScalePacked const**>(hopper_inputs.int4_groupwise_params.ptr_s_a),
hopper_inputs.int4_groupwise_params.stride_s_a, group_size},
{fusion_args, reinterpret_cast<ElementC const**>(hopper_inputs.ptr_c), hopper_inputs.stride_c,
reinterpret_cast<ElementD**>(hopper_inputs.default_epilogue.ptr_d),
hopper_inputs.default_epilogue.stride_d},
reinterpret_cast<ElementD**>(hopper_inputs.ptr_d), hopper_inputs.stride_d},
hw_info};
if (gemm.get_workspace_size(arguments) > hopper_inputs.gemm_workspace_size)

View File

@ -138,11 +138,11 @@ void dispatchMoeGemmSelectBiasTmaWarpSpecialized(TmaWarpSpecializedGroupedGemmIn
}
}
template <typename ClusterTileShape, typename ClusterShape, typename DataType, typename WeightType>
template <typename CtaShape, typename ClusterShape, typename DataType, typename WeightType>
constexpr bool are_tile_shapes_supported_sm100()
{
using namespace cute;
using CtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{}));
// This is the epilogue shape. The MMA shape will be twice this for 2SM
constexpr auto TileM = size<0>(CtaShape{});
constexpr auto TileN = size<1>(CtaShape{});
@ -353,6 +353,7 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized(TmaWarpSpecializedGroupedG
{
switch (gemm_config.tile_config_sm100)
{
SHAPE_CASE(100, 64, 32, 128)
SHAPE_CASE(100, 64, 64, 128)
SHAPE_CASE(100, 64, 128, 128)
SHAPE_CASE(100, 64, 256, 128)
@ -363,13 +364,8 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized(TmaWarpSpecializedGroupedG
SHAPE_CASE(100, 128, 128, 128)
SHAPE_CASE(100, 128, 256, 128)
SHAPE_CASE(100, 256, 64, 128)
SHAPE_CASE(100, 256, 128, 128)
SHAPE_CASE(100, 256, 256, 128)
// SHAPE_CASE(100, 128, 128, 64)
// SHAPE_CASE(100, 128, 256, 64)
// SHAPE_CASE(100, 256, 256, 64)
DEFAULT_CASE(100)
}
}

View File

@ -27,14 +27,14 @@
namespace tensorrt_llm::kernels::cutlass_kernels
{
std::array<size_t, 17> TmaWarpSpecializedGroupedGemmInput::workspaceBuffers(
std::array<size_t, 20> TmaWarpSpecializedGroupedGemmInput::workspaceBuffers(
int num_experts, FpXBlockScalingType scaling_type)
{
size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts;
size_t stride_a_size = sizeof(StrideA) * num_experts;
size_t stride_b_size = sizeof(StrideB) * num_experts;
size_t stride_c_size = sizeof(StrideC) * num_experts;
size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts;
size_t stride_d_size = sizeof(StrideD) * num_experts;
size_t ptr_buf_size = sizeof(void*) * num_experts;
size_t scale_buf_size = sizeof(float*) * num_experts;
@ -53,9 +53,12 @@ std::array<size_t, 17> TmaWarpSpecializedGroupedGemmInput::workspaceBuffers(
size_t int4_groupwise_sf_a_size = sizeof(INT4GroupwiseParams::SFA*) * num_experts;
size_t int4_groupwise_stride_sf_a_size = sizeof(INT4GroupwiseParams::StrideSFA) * num_experts;
size_t ptr_token_map_size = sizeof(int**) * num_experts;
return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size,
ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size, sf_a_size, sf_b_size, stride_sf_a_size,
stride_sf_b_size, int4_groupwise_problem_shape_size, int4_groupwise_sf_a_size, int4_groupwise_stride_sf_a_size};
stride_sf_b_size, int4_groupwise_problem_shape_size, int4_groupwise_sf_a_size, int4_groupwise_stride_sf_a_size,
ptr_buf_size, scale_buf_size, ptr_token_map_size};
}
size_t TmaWarpSpecializedGroupedGemmInput::workspaceSize(int num_experts, FpXBlockScalingType scaling_type)
@ -68,7 +71,7 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i
size_t gemm_workspace_size, FpXBlockScalingType scaling_type)
{
auto buffers = workspaceBuffers(num_experts, scaling_type);
std::array<int8_t*, 17> pointers{};
std::array<int8_t*, 20> pointers{};
TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers");
for (int i = 0; i < buffers.size(); i++)
{
@ -82,12 +85,12 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i
stride_a = reinterpret_cast<StrideA*>(pointers[1]);
stride_b = reinterpret_cast<StrideB*>(pointers[2]);
stride_c = reinterpret_cast<StrideC*>(pointers[3]);
default_epilogue.stride_d = reinterpret_cast<DefaultEpilogue::StrideD*>(pointers[4]);
stride_d = reinterpret_cast<StrideD*>(pointers[4]);
ptr_a = reinterpret_cast<void const**>(pointers[5]);
ptr_b = reinterpret_cast<void const**>(pointers[6]);
ptr_c = reinterpret_cast<void const**>(pointers[7]);
default_epilogue.ptr_d = reinterpret_cast<void**>(pointers[8]);
ptr_d = reinterpret_cast<void**>(pointers[8]);
alpha_scale_ptr_array = reinterpret_cast<float const**>(pointers[9]);
@ -103,28 +106,24 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i
int4_groupwise_params.ptr_s_a = reinterpret_cast<INT4GroupwiseParams::SFA const**>(pointers[15]);
int4_groupwise_params.stride_s_a = reinterpret_cast<INT4GroupwiseParams::StrideSFA*>(pointers[16]);
fused_finalize_epilogue.ptr_bias = reinterpret_cast<void const**>(pointers[17]);
fused_finalize_epilogue.ptr_router_scales = reinterpret_cast<float const**>(pointers[18]);
fused_finalize_epilogue.ptr_source_token_index = reinterpret_cast<int const**>(pointers[19]);
this->gemm_workspace = reinterpret_cast<uint8_t*>(gemm_workspace);
this->gemm_workspace_size = gemm_workspace_size;
}
void TmaWarpSpecializedGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales,
int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
int num_output_tokens)
void TmaWarpSpecializedGroupedGemmInput::setFinalizeFusionParams(
void* final_output, int hidden_size, int num_output_tokens, bool use_reduction)
{
fused_finalize_epilogue.ptr_final_output = final_output;
fused_finalize_epilogue.ptr_router_scales = router_scales;
fused_finalize_epilogue.ptr_bias = bias;
fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset;
fused_finalize_epilogue.ptr_source_token_index = source_token_index;
fused_finalize_epilogue.stride_final_output
= cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{},
transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1)));
fused_finalize_epilogue.stride_bias
= transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size));
fused_finalize_epilogue.stride_router_scales = {};
fused_finalize_epilogue.stride_final_output = cutlass::make_cute_packed_stride(
FusedFinalizeEpilogue::StrideFinalOutput{}, cute::make_shape(hidden_size, num_output_tokens, 1));
fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens;
fused_finalize_epilogue.use_reduction = use_reduction;
}
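With the reduced signature, the host only describes the output buffer; the per-expert bias, router-scale and source-token-index arrays now live in the GEMM workspace and are filled on the device by computeTmaWarpSpecializedInputPointers. A minimal host-side sketch, assuming gemm2_input is a configured TmaWarpSpecializedGroupedGemmInput and the other names are placeholders:

```cpp
// Hedged sketch of the new call sequence.
gemm2_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
bool const use_reduction = expanded_num_rows > num_rows; // top-k > 1: rows must be accumulated
gemm2_input.setFinalizeFusionParams(final_output, hidden_size, num_rows, use_reduction);
// router_scales and permuted_row_to_unpermuted_row are no longer passed here; they go to the
// stride/pointer kernel, which slices them per expert into the workspace arrays set up above.
```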
std::string TmaWarpSpecializedGroupedGemmInput::toString() const
@ -143,16 +142,13 @@ std::string TmaWarpSpecializedGroupedGemmInput::toString() const
ss << "Final Output: " << (PrintType) fused_finalize_epilogue.ptr_final_output;
ss << " with Stride: " << fused_finalize_epilogue.stride_final_output;
ss << ",\nBias: " << (PrintType) fused_finalize_epilogue.ptr_bias;
ss << " with Stride: " << fused_finalize_epilogue.stride_bias;
ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales;
ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales;
ss << ",\nExpert Offset: " << (PrintType) fused_finalize_epilogue.ptr_expert_first_token_offset;
ss << ", Source Map: " << (PrintType) fused_finalize_epilogue.ptr_source_token_index;
}
else
{
ss << "Ptr D: " << (PrintType) default_epilogue.ptr_d;
ss << " with Stride: " << (PrintType) default_epilogue.stride_d;
ss << "Ptr D: " << (PrintType) ptr_d;
ss << " with Stride: " << (PrintType) stride_d;
}
ss << '\n';
ss << "Alpha scale ptr: " << (PrintType) alpha_scale_ptr_array << "\n";

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -1165,8 +1165,8 @@ __device__ void computeTmaWarpSpecializedInputStrides(
}
if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE)
{
layout_info.default_epilogue.stride_d[out_idx] = cutlass::make_cute_packed_stride(
TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::StrideD{}, cute::make_shape(gemm_n, gemm_m, 1));
layout_info.stride_d[out_idx] = cutlass::make_cute_packed_stride(
TmaWarpSpecializedGroupedGemmInput::StrideD{}, cute::make_shape(gemm_n, gemm_m, 1));
}
if (layout_info.int4_groupwise_params.enabled)
{
@ -1185,7 +1185,8 @@ template <class T, class WeightType, class OutputType, class ScaleBiasType>
__device__ void computeTmaWarpSpecializedInputPointers(TmaWarpSpecializedGroupedGemmInput& layout_info, int64_t gemm_m,
int64_t gemm_n, int64_t gemm_k, int num_tokens_before_expert, int64_t expert, T const* in,
WeightType const* weights, TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA const* w4a8_weight_scale,
ScaleBiasType const* bias, OutputType* output, int64_t const out_idx)
ScaleBiasType const* bias, OutputType* output, float const* router_scales,
int const* permuted_row_to_unpermuted_row, int64_t const out_idx)
{
// The input prior to this contains K elements per token, with `num_tokens_before_expert` tokens
layout_info.ptr_a[out_idx] = safe_inc_ptr(in, num_tokens_before_expert * gemm_k);
@ -1196,7 +1197,18 @@ __device__ void computeTmaWarpSpecializedInputPointers(TmaWarpSpecializedGrouped
if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE)
{
// The output prior to this contains N elements per token, with `num_tokens_before_expert` tokens
layout_info.default_epilogue.ptr_d[out_idx] = safe_inc_ptr(output, num_tokens_before_expert * gemm_n);
layout_info.ptr_d[out_idx] = safe_inc_ptr(output, num_tokens_before_expert * gemm_n);
}
if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE)
{
layout_info.fused_finalize_epilogue.ptr_source_token_index[expert]
= permuted_row_to_unpermuted_row + num_tokens_before_expert;
layout_info.fused_finalize_epilogue.ptr_router_scales[expert] = router_scales + num_tokens_before_expert;
if (layout_info.fused_finalize_epilogue.ptr_bias != nullptr)
{
layout_info.fused_finalize_epilogue.ptr_bias[expert] = bias + gemm_n * expert;
}
}
if (layout_info.int4_groupwise_params.enabled)
{
@ -1219,7 +1231,8 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir
WeightType const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params,
ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* gemm1_output, OutputType* gemm2_output)
ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* gemm1_output, OutputType* gemm2_output,
float const* router_scales, int const* permuted_row_to_unpermuted_row)
{
// First, compute the global tid. We only need 1 thread per expert.
int const expert = blockIdx.x * blockDim.x + threadIdx.x;
@ -1297,12 +1310,12 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir
gemm1_in, weights1,
reinterpret_cast<TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA const*>(
quant_params.groupwise.fc1.weight_scales),
bias1, gemm1_output, expert);
bias1, gemm1_output, nullptr, nullptr, expert);
computeTmaWarpSpecializedInputPointers(layout_info2, gemm_m, gemm2_n, gemm2_k, num_tokens_before_expert, expert,
gemm2_in, weights2,
reinterpret_cast<TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA const*>(
quant_params.groupwise.fc2.weight_scales),
bias2, gemm2_output, expert);
bias2, gemm2_output, router_scales, permuted_row_to_unpermuted_row, expert);
}
template <class T, class WeightType, class OutputType, class ScaleBiasType>
@ -1420,12 +1433,12 @@ __global__ void computeStridesTmaWarpSpecializedLowLatencyKernel(TmaWarpSpeciali
layout_info2.ptr_b[expert] = safe_inc_ptr(weights2, local_expert * (gemm1_n * gemm2_k));
assert(layout_info1.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE);
layout_info1.default_epilogue.ptr_d[expert] = safe_inc_ptr(output1, expert * num_tokens * gemm1_n);
layout_info1.ptr_d[expert] = safe_inc_ptr(output1, expert * num_tokens * gemm1_n);
if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE)
{
// The output prior to this contains N elements per token, with `num_tokens` tokens
layout_info2.default_epilogue.ptr_d[expert] = safe_inc_ptr(output2, expert * num_tokens * gemm2_n);
layout_info2.ptr_d[expert] = safe_inc_ptr(output2, expert * num_tokens * gemm2_n);
}
}
else
@ -1435,10 +1448,10 @@ __global__ void computeStridesTmaWarpSpecializedLowLatencyKernel(TmaWarpSpeciali
layout_info1.ptr_b[expert] = nullptr;
layout_info2.ptr_b[expert] = nullptr;
layout_info1.default_epilogue.ptr_d[expert] = nullptr;
layout_info1.ptr_d[expert] = nullptr;
if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE)
{
layout_info2.default_epilogue.ptr_d[expert] = nullptr;
layout_info2.ptr_d[expert] = nullptr;
}
}
}
@ -2015,8 +2028,8 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro
#define INSTANTIATE_FINALIZE_MOE_ROUTING(OutputT, GemmOutputT, ScaleBiasT) \
template void finalizeMoeRoutingKernelLauncher<OutputT, GemmOutputT, ScaleBiasT>( \
GemmOutputT const* expanded_permuted_rows, OutputT* reduced_unpermuted_output, ScaleBiasT const* bias, \
float const* final_scales, int const* expanded_source_row_to_expanded_dest_row, \
int const* expanded_dest_row_to_expanded_source_row, int const* expert_for_source_row, \
float const* final_scales, int const* unpermuted_row_to_permuted_row, \
int const* permuted_row_to_unpermuted_row, int const* expert_for_source_row, \
int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const cols, \
int64_t const experts_per_token, int64_t const num_experts_per_node, MOEParallelismConfig parallelism_config, \
bool const enable_alltoall, cudaStream_t stream);
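The renamed index arrays are the two inverse permutations between the expanded (token x top-k) row space and the expert-sorted row space. A small illustrative check of that relationship, ignoring padded or dropped rows (the loop and assert are for exposition only):

```cpp
// Hedged sketch: the two maps are inverses of each other over valid expanded rows.
for (int64_t permuted_row = 0; permuted_row < expanded_num_rows; ++permuted_row)
{
    int const unpermuted_row = permuted_row_to_unpermuted_row[permuted_row];
    assert(unpermuted_row_to_permuted_row[unpermuted_row] == permuted_row);
}
```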
@ -3295,9 +3308,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
}
bool has_different_output_type_ampere = (use_w4afp8 || use_fp8) && !using_tma_ws_gemm2;
bool using_hopper_fused_finalize
= tma_ws_input.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
bool has_different_output_type_tma_ws = !using_hopper_fused_finalize && using_tma_ws_gemm2;
bool using_fused_finalize = tma_ws_input.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
bool has_different_output_type_tma_ws = !using_fused_finalize && using_tma_ws_gemm2;
if (has_different_output_type_ampere || has_different_output_type_tma_ws)
{
@ -3815,7 +3827,8 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
float const* fp8_dequant2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params,
ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output,
UnfusedGemmOutputType* gemm2_output, cudaStream_t stream)
UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row,
cudaStream_t stream)
{
// Always nullptr
layout_info1.ptr_c = nullptr;
@ -3823,6 +3836,12 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
layout_info2.ptr_c = nullptr;
layout_info2.stride_c = nullptr;
layout_info1.fused_finalize_epilogue.ptr_bias = nullptr;
if (!bias2)
{
layout_info2.fused_finalize_epilogue.ptr_bias = nullptr;
}
auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale
: use_wfp4afp8 ? quant_params.fp8_mxfp4.fc1.global_scale
: use_fp8 ? fp8_dequant1
@ -3863,7 +3882,7 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
cudaLaunchKernelEx(&config, kernel_instance, expert_first_token_offset, layout_info1, layout_info2, num_tokens,
expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node, gemm1_in, gemm2_in, weights1,
weights2, alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, quant_params, bias1, bias2,
gemm1_output, gemm2_output);
gemm1_output, gemm2_output, router_scales, permuted_row_to_unpermuted_row);
return std::make_pair(layout_info1, layout_info2);
}
@ -3986,15 +4005,15 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
bool apply_bias = parallelism_config.tp_rank == 0;
bool using_hopper_fused_finalize
= !use_deterministic_hopper_reduce_ && gemm2_config_->sm_version == 90 && !use_w4_groupwise && !use_lora;
if (using_hopper_fused_finalize)
auto* fc2_bias = apply_bias ? fc2_expert_biases : nullptr;
bool using_fused_finalize
= use_fused_finalize_ && gemm2_config_->sm_version >= 90 && !use_w4_groupwise && !use_lora;
if (using_fused_finalize)
{
assert(min_latency_mode == false);
bool use_reduction = expanded_num_rows > num_rows;
gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
gemm2_tma_ws_input.setFinalizeFusionParams(final_output, permuted_token_final_scales_,
expert_first_token_offset_, permuted_row_to_unpermuted_row_, apply_bias ? fc2_expert_biases : nullptr,
hidden_size, num_rows);
gemm2_tma_ws_input.setFinalizeFusionParams(final_output, hidden_size, num_rows, use_reduction);
}
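Two small behavioural points are worth spelling out here: fc2_bias is only non-null on TP rank 0, so the bias is not added once per tensor-parallel rank, and use_reduction tracks whether more than one expert contributes per token. A short restatement, assuming expanded_num_rows is num_rows * experts_per_token before any padding or dropping:

```cpp
// Illustrative restatement of the two gates above.
auto const* fc2_bias = (parallelism_config.tp_rank == 0) ? fc2_expert_biases : nullptr;
bool const use_reduction = expanded_num_rows > num_rows; // top-1 routing can store instead of reduce
```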
// fp8_mxfp4 memsets the scaling factors to 1.0f
@ -4028,9 +4047,10 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
gemm2_tma_ws_input, num_rows, expanded_num_rows, fc1_out_size, hidden_size, hidden_size, inter_size,
num_experts_per_node, reinterpret_cast<T const*>(gemm1_input), reinterpret_cast<T const*>(gemm2_input),
fc1_expert_weights, fc2_expert_weights, quant_params.fp8.dequant_fc1, quant_params.fp8.dequant_fc2,
fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases, fc2_expert_biases,
fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases, fc2_bias,
reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output),
reinterpret_cast<UnfusedGemmOutputType*>(fc2_result_), stream);
reinterpret_cast<UnfusedGemmOutputType*>(fc2_result_), permuted_token_final_scales_,
permuted_row_to_unpermuted_row_, stream);
}
}
@ -4591,20 +4611,17 @@ void GemmProfilerBackend::prepareTmaWsInputs(
gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
bool apply_bias = true;
bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4);
bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16)
&& mWType == nvinfer1::DataType::kUINT8);
bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;
bool using_fused_finalize
= !mInterface->use_deterministic_hopper_reduce_ && mSM == 90 && !mMinLatencyMode && !use_w4_groupwise;
= mInterface->use_fused_finalize_ && mSM >= 90 && !mMinLatencyMode && !use_w4_groupwise;
if (using_fused_finalize)
{
assert(!mMinLatencyMode);
gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
gemm2_tma_ws_input.setFinalizeFusionParams(output, token_topk_unpermuted_scales,
expert_first_token_offset, permuted_row_to_unpermuted_row, apply_bias ? bias : nullptr,
mExpertHiddenSize, num_tokens);
gemm2_tma_ws_input.setFinalizeFusionParams(output, mExpertHiddenSize, num_tokens, mK > 1);
}
auto fc1_output_size = isGatedActivation(mActivationType) ? mExpertInterSize * 2 : mExpertInterSize;
@ -4625,7 +4642,7 @@ void GemmProfilerBackend::prepareTmaWsInputs(
fc1_output_size, mExpertHiddenSize, mExpertHiddenSize, mExpertInterSize, mNumExpertsPerNode, input,
input, weights_sel, weights_sel, mQuantParams.fp8.dequant_fc1, mQuantParams.fp8.dequant_fc2,
fp4_act_scale_flat, fp4_act_scale_flat, mQuantParams, nullptr, nullptr, intermediate, intermediate,
stream);
token_topk_unpermuted_scales, permuted_row_to_unpermuted_row, stream);
}
sync_check_cuda_error(stream);
}

View File

@ -35,8 +35,7 @@ constexpr bool isValidSM120MOESpecialisation()
{
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) // TODO Is there a better choice
return cutlass::platform::is_same<T, __nv_fp4_e2m1>::value && cutlass::platform::is_same<T, WeightType>::value
&& cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value
&& Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
&& cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value;
#else
return false; // CUTLASS_ARCH_MMA_SM120_SUPPORTED is set when SM120 Blackwell kernels are enabled
#endif
@ -51,8 +50,7 @@ constexpr bool isValidBlackwellMOESpecialisation()
return (cutlass::platform::is_same<T, WeightType>::value
|| (cutlass::platform::is_same<T, __nv_fp8_e4m3>::value
&& cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value))
&& cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value
&& Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
&& cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value;
#else
return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled
#endif

View File

@ -212,8 +212,7 @@ template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {s
{kernel_sched}, {epi_sched}> (
const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float,
{out_tag}*, int, int, int, const int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
);
"""
);"""
elif operation.gemm_kind == GemmKind.Grouped:
if operation.act_type != operation.weight_type and (
operation.act_type != DataType.e4m3
@ -261,11 +260,9 @@ GroupedGemmInput<{act_tag}, {weight_tag}, {out_tag}, {out_tag}>inputs, TmaWarpSp
# (TmaWarpSpecializedGroupedGemmInput, int, int, cudaStream_t, int*, size_t*);
# """
instantiation = f"""
#if {guard_act} && {guard_weight}\n
INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag},
{epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, {"true" if operation.is_mx_fpx else "false"}, false);\n
#endif
"""
#if {guard_act} && {guard_weight}
INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag}, {epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, {"true" if operation.is_mx_fpx else "false"}, false);
#endif"""
return instantiation
@ -276,8 +273,7 @@ def instantiate_operation_sm80(operation):
instantiation = f"""
template void sm80_generic_fused_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.stage}, {epi_tag}>
({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy);
"""
({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy);"""
return instantiation
@ -340,16 +336,13 @@ def write_file(launcher_inl_files, operations, output_file):
f.write(content)
from operator import mul, truediv
def elementwise(x, y, f):
return tuple(f(a, b) for (a, b) in zip(x, y))
def is_gemm_op_valid_sm100(op):
# TODO These are much more restricted than theory dictates, investigate if more can be enabled in future
tile_m, tile_n, _ = elementwise(op.cta_shape, op.cga_shape, truediv)
tile_m, tile_n, _ = op.cta_shape
cga_m, cga_n, _ = op.cga_shape
# Default shapes
@ -366,13 +359,11 @@ def is_gemm_op_valid_sm100(op):
return False
# Shapes for fp8 small N shapes
if (op.act_type == DataType.e4m3 and (tile_n == 16 or tile_n == 8)
and (cga_m == 1 and cga_n == 1)):
# todo: double check why this is disable in CUTLASS backend. @yuhan
if tile_m == 128 and tile_n == 8:
return False
else:
return True
if (op.act_type == DataType.e4m3) and (tile_n == 16
or tile_n == 8) and (cga_m == 1
and cga_n == 1):
# todo: double check why tile_n = 8 is disabled in CUTLASS backend. @yuhan
return tile_m != 128 or tile_n % 16 == 0
# Default alignment requirements
if tile_n % 32 != 0 or tile_n < 32 or tile_n > 256:
@ -617,8 +608,6 @@ def calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype):
cta_shape_k = max_k_bits // GetDataTypeBits(dtype)
if dtype == DataType.e4m3 and (cta_shape_mn[1] == 8):
cta_shape_k = 256
if dtype == DataType.e4m3 and (cta_shape_mn[1] == 16):
cta_shape_k = 128
return cta_shape_mn + (cta_shape_k, )
@ -638,7 +627,7 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled):
epi_fusions = [
TrtLlm_EpilogueFusion.epilogue_fusion_none,
# TrtLlm_EpilogueFusion.epilogue_fusion_finalize
TrtLlm_EpilogueFusion.epilogue_fusion_finalize
]
cga_shapes = [[1, 1, 1]]
@ -648,7 +637,6 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled):
operations = list()
for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mnk, cga_shape in partial_args:
cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul)
# Ignored
mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative
@ -661,8 +649,8 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled):
for otype in otypes:
moe_gemm_operation = TrtLlm_GemmLauncher(
GemmKind.Grouped, arch, dtype, dtype, dtype, dtype, otype,
quant_op, epi_tag, cga_tile_shape_mnk, warp_shape, stages,
cga_shape, mainloop_schedule, epi_schedule, epi_fusion)
quant_op, epi_tag, cta_shape_mnk, warp_shape, stages, cga_shape,
mainloop_schedule, epi_schedule, epi_fusion)
operations.append(moe_gemm_operation)
return operations
@ -692,7 +680,7 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled):
epi_fusions = [
TrtLlm_EpilogueFusion.epilogue_fusion_none,
# TrtLlm_EpilogueFusion.epilogue_fusion_finalize
TrtLlm_EpilogueFusion.epilogue_fusion_finalize
]
cga_shapes = list(product([1, 2], [1, 2], [1]))
@ -708,7 +696,6 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled):
weight_type = dtype
cta_shape_mnk = calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype)
cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul)
# Ignored
mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative
@ -729,7 +716,7 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled):
otype,
quant_op,
epi_tag,
cga_tile_shape_mnk,
cta_shape_mnk,
warp_shape,
stages,
cga_shape,

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b25eb3a8bc1fae83eb43f9e0cf8fd93bb00f412d6cbd1bf7e2214e878bec3b4a
size 64735372
oid sha256:86586b9f6845e91e8ba0accad53a5a3418c50d8fd30ad49fa8837470c72b5dcf
size 67051604

View File

@ -1,2 +1,2 @@
16bae34717995b98ee8cff17bc8ec080c0e1b1aca02e5949be171eb8d40eff39 libtensorrt_llm_internal_cutlass_kernels_static.a
commit 995030f9b86258f3db876df6b1dbc46a7c5dae50
568cb6ca2413c93b0f5839dd05577c0c57bc4b5f2359366c79d0ace665de4bd6 libtensorrt_llm_internal_cutlass_kernels_static.a
commit 9c0a42825905952beaf9b35d5a35d58de1a123fa

View File

@ -39,11 +39,6 @@
namespace tensorrt_llm
{
template <class T>
constexpr auto transpose_stride(T const& t)
{
return cute::prepend(cute::prepend(cute::take<2, cute::rank_v<T>>(t), cute::get<0>(t)), cute::get<1>(t));
}
// Note update moe.py to match
enum class ActivationType
@ -87,8 +82,6 @@ struct GroupedGemmInput
struct TmaWarpSpecializedGroupedGemmInput
{
template <class T>
using TransposeStride = decltype(transpose_stride<T>(T{}));
template <class Tag>
using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;
@ -101,6 +94,7 @@ struct TmaWarpSpecializedGroupedGemmInput
using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for A matrix operand
using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand
using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for C matrix operand
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
constexpr static int NVFP4BlockScaleVectorSize = 16;
constexpr static int MXFPXBlockScaleVectorSize = 32;
@ -122,6 +116,7 @@ struct TmaWarpSpecializedGroupedGemmInput
using StrideB
= std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped
using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;
#ifdef ENABLE_FP8
template <class T>
@ -148,37 +143,26 @@ struct TmaWarpSpecializedGroupedGemmInput
StrideC* stride_c = nullptr;
void const** ptr_c = nullptr;
struct DefaultEpilogue
{
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;
StrideD* stride_d = nullptr;
void** ptr_d = nullptr;
};
// D is used in all cases except fused finalize
StrideD* stride_d = nullptr;
void** ptr_d = nullptr;
struct FusedFinalizeEpilogue
{
using StrideFinalOutput = DefaultEpilogue::StrideD;
using StrideBias = TransposeStride<cute::Stride<cute::_0, cute::_1, int>>;
using StrideRouterScales = TransposeStride<cute::Stride<cute::_1, cute::_0>>;
using StrideFinalOutput = cutlass::detail::TagToStrideC_t<LayoutD>;
void* ptr_final_output = nullptr;
StrideFinalOutput stride_final_output{};
void const* ptr_bias = nullptr;
StrideBias stride_bias{};
void const** ptr_bias = nullptr;
float const** ptr_router_scales = nullptr;
float const* ptr_router_scales = nullptr;
StrideRouterScales stride_router_scales{};
int const** ptr_source_token_index = nullptr;
int num_rows_in_final_output = 0;
int64_t const* ptr_expert_first_token_offset = nullptr;
int const* ptr_source_token_index = nullptr;
size_t num_rows_in_final_output = 0;
bool use_reduction = true;
};
DefaultEpilogue default_epilogue;
FusedFinalizeEpilogue fused_finalize_epilogue;
enum class EpilogueFusion
@ -235,7 +219,7 @@ struct TmaWarpSpecializedGroupedGemmInput
uint8_t* gemm_workspace = nullptr;
size_t gemm_workspace_size = 0;
static std::array<size_t, 17> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type);
static std::array<size_t, 20> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type);
static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type);
@ -247,9 +231,7 @@ struct TmaWarpSpecializedGroupedGemmInput
return stride_a != nullptr && ptr_a != nullptr;
}
void setFinalizeFusionParams(void* final_output, float const* router_scales,
int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
int num_output_tokens);
void setFinalizeFusionParams(void* final_output, int hidden_size, int num_output_tokens, bool use_reduction);
std::string toString() const;
};
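Taken together, the FusedFinalizeEpilogue fields describe a scaled scatter(-reduce) from the expert-sorted GEMM2 output straight into the dense final output, replacing the separate finalize-routing pass. A hedged scalar reference of the per-element behaviour as I read the fusion arguments in this commit, written in logical (token, hidden) terms and ignoring the internal A/B operand swap; epi, acc, rows_for_expert, to_float, final_output and the pointer casts are placeholders for this sketch:

```cpp
// Illustrative reference for one expert e: each of its rows is scaled by the routing weight,
// optionally biased, and scattered (or reduced) into the unpermuted output token.
for (int r = 0; r < rows_for_expert; ++r)
{
    int const dst     = epi.ptr_source_token_index[e][r]; // unpermuted destination token
    float const scale = epi.ptr_router_scales[e][r];      // top-k routing weight for this row
    for (int h = 0; h < hidden_size; ++h)
    {
        float const bias = epi.ptr_bias ? to_float(epi.ptr_bias[e][h]) : 0.0f;
        float const val  = scale * (acc[r][h] + bias);     // bias applied before the routing scale
        if (epi.use_reduction)
            final_output[dst][h] += val;                   // top-k > 1: accumulate contributions
        else
            final_output[dst][h] = val;                    // top-k == 1: plain scatter store
    }
}
```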

View File

@ -426,7 +426,7 @@ public:
ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config, bool use_lora,
int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool use_lora,
LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
= 0;
@ -450,8 +450,8 @@ public:
void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
QuantParams quant_params, float const* const token_topk_unpermuted_scales,
float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
@ -468,7 +468,8 @@ public:
void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1,
void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream)
void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales,
int const* permuted_row_to_unpermuted_row, cudaStream_t stream)
= 0;
virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@ -485,13 +486,13 @@ public:
virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0;
bool is_profiler = false;
bool use_deterministic_hopper_reduce_ = false;
bool use_fused_finalize_ = true;
};
// Assumes inputs activations are row major. Weights need to be preprocessed by th_op/weight_quantize.cc .
// Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive.
// Avoid making several duplicates of this class.
template <typename T, /*The type used for activations*/
template <typename T, /* The type used for activations */
typename WeightType, /* The type for the MoE weights */
typename OutputType = T, /* The type for the MoE final output */
typename InputType = T, /* The type for the MoE input */
@ -573,7 +574,7 @@ public:
ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config, bool use_lora,
int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool use_lora,
LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
@ -603,8 +604,8 @@ public:
ScaleBiasType const* const fc2_expert_biases, ScaleBiasType const* const fc2_int_scales,
float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
QuantParams quant_params, float const* const token_topk_unpermuted_scales,
float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
@ -639,8 +640,8 @@ public:
void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
QuantParams quant_params, float const* const token_topk_unpermuted_scales,
float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
@ -653,11 +654,11 @@ public:
static_cast<OutputType*>(final_output), expert_first_token_offset, tma_ws_input_template,
static_cast<WeightType const*>(fc2_expert_weights), static_cast<ScaleBiasType const*>(fc2_expert_biases),
static_cast<ScaleBiasType const*>(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params,
token_topk_unpermuted_scales, token_topk_permuted_scales, expanded_source_row_to_expanded_dest_row,
expanded_dest_row_to_expanded_source_row, expert_for_source_row, num_valid_tokens_ptr, num_rows,
expanded_num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array,
use_lora, fc2_lora, stream, parallelism_config, config, min_latency_mode, num_active_experts_per,
active_expert_global_ids, start_expert);
token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row,
permuted_row_to_unpermuted_row, expert_for_source_row, num_valid_tokens_ptr, num_rows, expanded_num_rows,
hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora,
stream, parallelism_config, config, min_latency_mode, num_active_experts_per, active_expert_global_ids,
start_expert);
}
virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override
@ -673,7 +674,8 @@ public:
void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1,
void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) override
void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales,
int const* permuted_row_to_unpermuted_row, cudaStream_t stream) override
{
return Self::computeStridesTmaWarpSpecialized(expert_first_token_offset, layout_info1, layout_info2, num_tokens,
expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node,
@ -682,7 +684,8 @@ public:
alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, quant_params,
reinterpret_cast<ScaleBiasType const*>(bias1), reinterpret_cast<ScaleBiasType const*>(bias2),
reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output),
reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), stream);
reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), router_scales, permuted_row_to_unpermuted_row,
stream);
}
std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@ -724,7 +727,8 @@ private:
float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params,
ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output,
UnfusedGemmOutputType* gemm2_output, cudaStream_t stream);
UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row,
cudaStream_t stream);
static std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1,
TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k,
@ -754,8 +758,8 @@ private:
bool mayHaveFinalizeFused() const
{
return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() == 90
&& !use_deterministic_hopper_reduce_ && !use_w4afp8;
return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && use_fused_finalize_
&& !use_w4afp8;
}
// TODO: This should eventually take the quant params to give more flexibility
@ -791,7 +795,7 @@ private:
static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output,
OutputType* const final_output, int64_t const* const expert_first_token_offset,
WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases,
float const* const token_topk_unpermuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* const expert_for_source_row, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows,
int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, int64_t const k, MOEParallelismConfig parallelism_config,

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3767deac592a204493b09f6798d50269c90d4571971b1746a5e6d0009a6d6d65
size 64229720
oid sha256:6489751f16a4dadf42664738ded03fbbd60195619f2d5f80af8190554318257d
size 66872936

View File

@ -1,2 +1,2 @@
f68113dae0236968594276bf4f8b0a6f9161d3fbbac6fcb9ea1a438d16055490 libtensorrt_llm_internal_cutlass_kernels_static.a
commit 995030f9b86258f3db876df6b1dbc46a7c5dae50
813c237a565664b2acf2313f0e436f66f24deeb16a84d273dc007af55795e55f libtensorrt_llm_internal_cutlass_kernels_static.a
commit 9c0a42825905952beaf9b35d5a35d58de1a123fa

View File

@ -334,12 +334,13 @@ void MixtureOfExpertsPlugin::init()
static_cast<int>(mType), static_cast<int>(mWeightType), static_cast<int>(mOutputType));
}
mMOERunner->use_deterministic_hopper_reduce_ = mExpertsPerToken > 2 && mUseDeterministicKernels;
mMOERunner->use_fused_finalize_
= (mExpertsPerToken < 3 || !mUseDeterministicKernels) && !getEnvMOEDisableFinalizeFusion();
mGemmId1 = GemmIDMoe{1, mNumExperts, mExpertsPerToken, mParallelismConfig, mExpertHiddenSize, mExpertInterSize,
mGroupSize, mActivationType, mType, mWeightType, mQuantMode, mMOERunner->use_deterministic_hopper_reduce_};
mGroupSize, mActivationType, mType, mWeightType, mQuantMode, !mMOERunner->use_fused_finalize_};
mGemmId2 = GemmIDMoe{2, mNumExperts, mExpertsPerToken, mParallelismConfig, mExpertHiddenSize, mExpertInterSize,
mGroupSize, mActivationType, mType, mWeightType, mQuantMode, mMOERunner->use_deterministic_hopper_reduce_};
mGroupSize, mActivationType, mType, mWeightType, mQuantMode, !mMOERunner->use_fused_finalize_};
mGemmProfiler->setMaxProfileM(16384 * mNumExperts / mExpertsPerToken);
if (hasLora())
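
Taken together with the updated mayHaveFinalizeFused() above, the effective gating can be summarized by a small sketch. The helper name and flat argument list are assumptions for illustration; the conditions mirror the lines changed in this commit (the TMA warp-specialized support check is elided).

```python
def uses_fused_finalize(sm: int, experts_per_token: int, use_deterministic_kernels: bool,
                        is_w4afp8: bool, env_disable: bool) -> bool:
    """Fusion is requested only when it cannot break determinism guarantees (k <= 2 or
    determinism not requested, and no env override), and the runner additionally
    requires SM >= 90 and no W4A-FP8 path."""
    use_fused_finalize = (experts_per_token < 3 or not use_deterministic_kernels) and not env_disable
    return use_fused_finalize and sm >= 90 and not is_w4afp8
```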

View File

@ -95,7 +95,8 @@ public:
};
FusedMoeRunner(c10::ScalarType activation_dtype, c10::ScalarType weight_dtype, c10::ScalarType output_dtype,
bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_mxfp8_act_scaling)
bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_mxfp8_act_scaling,
bool use_fused_finalize)
{
mActivationDtype = activation_dtype;
mWeightDtype = weight_dtype;
@ -103,6 +104,7 @@ public:
mUseDeepSeekFP8BlockScaling = use_deepseek_fp8_block_scale;
mUseW4GroupScaling = use_w4_group_scaling;
mUseMxfp8ActScaling = use_mxfp8_act_scaling;
mUseFusedFinalize = use_fused_finalize;
mInnerDimMultiplier = 1;
// keep consistent with cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp
@ -213,6 +215,8 @@ public:
<< ", Output: " << torch::toString(mOutputDtype));
}
mKernelRunner->use_fused_finalize_ = mUseFusedFinalize;
mProfiler = std::make_shared<kernels::GemmProfilerBackend>();
mAllProfiles = mKernelRunner->getTactics();
}
@ -674,6 +678,7 @@ private:
bool mUseDeepSeekFP8BlockScaling = false;
bool mUseW4GroupScaling = false;
bool mUseMxfp8ActScaling = false;
bool mUseFusedFinalize = true;
using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
std::vector<Profile> mAllProfiles;
@ -1045,7 +1050,7 @@ private:
TORCH_LIBRARY(trtllm, m)
{
m.class_<torch_ext::FusedMoeRunner>("FusedMoeRunner")
.def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool>())
.def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool, bool>())
.def("run_gemm_profile", &torch_ext::FusedMoeRunner::runGemmProfile)
.def("get_tactic_num", &torch_ext::FusedMoeRunner::getTacticNum)
.def("run_moe", &torch_ext::FusedMoeRunner::runMoe)

View File

@ -1,3 +1,19 @@
/*
* Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h"
@ -355,7 +371,7 @@ protected:
float mSparseMixerEpsilon = 0.2f;
// Default this to true. This only matters for K>2, and so by doing this we will test the fused and unfused paths
bool mUseDeterminsiticHopperReduce = true;
bool mUseDeterministicHopperReduce = true;
// Disable this for long running tests to speed up runtime
bool mIsLongTest = false;
@ -440,7 +456,7 @@ protected:
{
managed_buffers.clear();
mMoERunner.use_deterministic_hopper_reduce_ = k > 2 && mUseDeterminsiticHopperReduce;
mMoERunner.use_fused_finalize_ = k < 3 || !mUseDeterministicHopperReduce;
mHiddenSize = hidden_size;
mInterSize = hidden_size * mInterSizeFraction;
@ -1614,7 +1630,7 @@ void MixtureOfExpertsTest<TypeParam_>::BasicPermuteTest(
runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k);
bool should_be_deterministic
= mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
= mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
if (should_be_deterministic && !mIsLongTest)
{
auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
@ -1733,7 +1749,7 @@ TYPED_TEST(MixtureOfExpertsTest, PermuteSwigluBias)
TYPED_TEST(MixtureOfExpertsTest, PermuteNonDeterministic)
{
this->mUseDeterminsiticHopperReduce = false;
this->mUseDeterministicHopperReduce = false;
// Just test case 3, cases 1&2 always use the fused paths
this->BasicPermuteTest(3);
}
@ -1881,7 +1897,7 @@ void MixtureOfExpertsTest<TypeParam_>::ParallelismTest(
runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k,
MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall);
bool should_be_deterministic
= mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
= mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
if (should_be_deterministic && !mIsLongTest)
{
auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
@ -1897,7 +1913,7 @@ void MixtureOfExpertsTest<TypeParam_>::ParallelismTest(
{
runMoEPermute(MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall);
bool should_be_deterministic
= mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
= mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
if (should_be_deterministic && !mIsLongTest)
{
auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);

View File

@ -68,7 +68,7 @@ ignore_patterns = [
[tool.codespell]
skip = ".git,3rdparty,tests/integration/test_input_files**,**.jsonl,**.json"
exclude-file = "examples/models/core/whisper/tokenizer.py"
ignore-words-list = "rouge,inout,atleast,strat,nd,subtile,thrid,improbe,NotIn,te,iteract,anythin,tru,Tracin,vEw"
ignore-words-list = "rouge,inout,atleast,strat,nd,subtile,thrid,improbe,NotIn,te,iteract,anythin,tru,Tracin,vEw,dOut"
[tool.autoflake]
in-place = true

View File

@ -458,7 +458,7 @@ def _register_fake():
gemm2_output: torch.Tensor,
fc2_expert_biases: torch.Tensor,
unpermuted_final_scales: torch.Tensor,
expanded_source_row_to_expanded_dest_row: torch.Tensor,
unpermuted_row_to_permuted_row: torch.Tensor,
expert_for_source_row: torch.Tensor,
expert_first_token_offset_tensor: torch.Tensor,
num_rows: torch.SymInt,

View File

@ -42,6 +42,7 @@ class MoERunner(TunableRunner):
use_w4_group_scaling: bool,
use_mxfp8_act_scaling: bool,
min_latency_mode: bool,
use_fused_finalize: bool,
):
self.x_dtype = x_dtype
self.weight_dtype = weight_dtype
@ -59,6 +60,8 @@ class MoERunner(TunableRunner):
self.use_w4_group_scaling = use_w4_group_scaling
self.use_mxfp8_act_scaling = use_mxfp8_act_scaling
self.min_latency_mode = min_latency_mode
self.use_fused_finalize = use_fused_finalize
instance_key = (x_dtype, weight_dtype, output_dtype,
use_deepseek_fp8_block_scale, use_w4_group_scaling,
use_mxfp8_act_scaling)
@ -68,7 +71,7 @@ class MoERunner(TunableRunner):
instance_key] = torch.classes.trtllm.FusedMoeRunner(
x_dtype, weight_dtype, output_dtype,
use_deepseek_fp8_block_scale, use_w4_group_scaling,
use_mxfp8_act_scaling)
use_mxfp8_act_scaling, use_fused_finalize)
self.fused_moe_runner = MoERunner.runner_dict[instance_key]
def get_valid_tactics(
@ -143,6 +146,7 @@ def fused_moe(
use_w4_group_scaling: bool = False,
use_mxfp8_act_scaling: bool = False,
min_latency_mode: bool = False,
use_fused_finalize: bool = True,
tune_max_num_tokens: int = 8192,
tuner_num_tokens: Optional[int] = None,
tuner_top_k: Optional[int] = None,
@ -179,6 +183,7 @@ def fused_moe(
use_w4_group_scaling=use_w4_group_scaling,
use_mxfp8_act_scaling=use_mxfp8_act_scaling,
min_latency_mode=min_latency_mode,
use_fused_finalize=use_fused_finalize,
)
_, gemm_tactic_1 = tuner.choose_one(
@ -259,6 +264,7 @@ def _(
use_w4_group_scaling: bool = False,
use_mxfp8_act_scaling: bool = False,
min_latency_mode: bool = False,
use_fused_finalize: bool = True,
tune_max_num_tokens: int = 8192,
):
seq_len = input.shape[0]

View File

@ -83,6 +83,9 @@ class ModelConfig(Generic[TConfig]):
attn_backend: str = 'TRTLLM'
moe_backend: str = 'CUTLASS' # options can be CUTLASS, TRTLLM
    # If true, disables FC2+finalize fusion in the CUTLASS MoE backend
moe_disable_finalize_fusion: bool = False
allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO
# If true, enable min-latency mode. Currently only used for Llama4.

View File

@ -152,6 +152,8 @@ class CutlassFusedMoE(MoE):
# If True, the router weight will be multiplied on the input rather than at the end of FC2
self.apply_router_weight_on_input = apply_router_weight_on_input
self.use_fused_finalize = not model_config.moe_disable_finalize_fusion
self._weights_created = False
if not model_config.skip_create_weights_in_init:
self.create_weights()
@ -421,6 +423,7 @@ class CutlassFusedMoE(MoE):
use_w4_group_scaling=use_w4_group_scaling,
use_mxfp8_act_scaling=use_mxfp8_act_scaling,
min_latency_mode=False,
use_fused_finalize=self.use_fused_finalize,
tune_max_num_tokens=self.tune_max_num_tokens,
tuner_num_tokens=tuner_num_tokens,
tuner_top_k=tuner_top_k,

View File

@ -53,6 +53,8 @@ class PyTorchConfig:
attn_backend: str = 'TRTLLM'
moe_backend: str = 'CUTLASS'
moe_disable_finalize_fusion: bool = False
enable_mixed_sampler: bool = False
"""
If true, will iterate over sampling_params of each request and use the

View File

@ -305,6 +305,8 @@ class PyTorchModelEngine(ModelEngine):
checkpoint_loader=checkpoint_loader,
attn_backend=attn_backend,
moe_backend=pytorch_backend_config.moe_backend,
moe_disable_finalize_fusion=pytorch_backend_config.
moe_disable_finalize_fusion,
load_format=pytorch_backend_config.load_format,
max_num_tokens=max_num_tokens,
moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,

View File

@ -183,6 +183,12 @@ class MoeConfig(StrictBaseModel):
description="Configuration for MoE load balancing.",
json_schema_extra={"type": "Union[MoeLoadBalancerConfig, str]"})
disable_finalize_fusion: bool = Field(
default=False,
description=
"Disable FC2+finalize kernel fusion in CUTLASS MoE backend. Setting this to True recovers deterministic numerical behavior with top-k > 2."
)
@classmethod
def from_dict(cls, data: dict):
return cls(**data)
@ -2365,6 +2371,7 @@ class TorchLlmArgs(BaseLlmArgs):
enable_layerwise_nvtx_marker=self.enable_layerwise_nvtx_marker,
load_format=self.load_format,
enable_min_latency=self.enable_min_latency,
moe_disable_finalize_fusion=self.moe_config.disable_finalize_fusion,
stream_interval=self.stream_interval,
force_dynamic_quantization=self.force_dynamic_quantization,
allreduce_strategy=self.allreduce_strategy,
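
From the user's side, the new knob is reachable through the MoE config; a hypothetical sketch follows, where the import path and LLM keyword arguments are assumptions based on the fields touched above.

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import MoeConfig  # import path assumed for illustration

# Opt out of the FC2+finalize fusion, e.g. to recover deterministic reductions for top-k > 2.
llm = LLM(model="/path/to/model",
          moe_config=MoeConfig(disable_finalize_fusion=True))
```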

View File

@ -1075,7 +1075,7 @@ def test_fused_moe_nvfp4(dtype):
# compare
torch.cuda.synchronize()
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.15)
@skip_neither_ada_nor_hopper_unittest
@ -1320,7 +1320,7 @@ def test_fused_moe_mxfp4_mxpf8(moe_backend, bias):
# compare
torch.cuda.synchronize()
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.15)
@skip_non_hopper_unittest