Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

[None][feat] CUTLASS MoE FC2+Finalize fusion (#3294)
Signed-off-by: Sergey Klevtsov <sklevtsov@nvidia.com>

This commit is contained in:
parent 0dc4b4e699
commit 27fc35175e
@@ -366,6 +366,12 @@ bool getEnvForceDeterministicMOE()
    return forceDeterministic;
}

bool getEnvMOEDisableFinalizeFusion()
{
    static bool const moeDisableFinalizeFusion = getBoolEnv("TRTLLM_MOE_DISABLE_FINALIZE_FUSION");
    return moeDisableFinalizeFusion;
}

bool getEnvForceDeterministicAttention()
{
    static bool const forceDeterministic
@@ -86,6 +86,9 @@ bool getEnvForceDeterministic();
// Force deterministic behavior for MoE plugin.
bool getEnvForceDeterministicMOE();

// Disable finalize fusion in MoE plugin.
bool getEnvMOEDisableFinalizeFusion();

// Force deterministic behavior for attention plugin.
bool getEnvForceDeterministicAttention();
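For reference, a minimal sketch (assumed names, not part of this diff) of how the new flag is typically consumed: the fused FC2+finalize path stays on by default and the environment variable opts out.

// Sketch only: simplified stand-ins for the envUtils helpers; shouldUseFusedFinalize
// is a hypothetical dispatch helper, not a function introduced by this commit.
#include <cstdlib>
#include <string_view>

namespace sketch
{

bool getBoolEnv(char const* name)
{
    char const* value = std::getenv(name);
    return value != nullptr && std::string_view(value) == "1";
}

bool getEnvMOEDisableFinalizeFusion()
{
    // Cached once, matching the pattern of the other getEnv* helpers in envUtils.cpp.
    static bool const moeDisableFinalizeFusion = getBoolEnv("TRTLLM_MOE_DISABLE_FINALIZE_FUSION");
    return moeDisableFinalizeFusion;
}

// Setting TRTLLM_MOE_DISABLE_FINALIZE_FUSION=1 falls back to the unfused finalize path,
// e.g. for debugging or comparison runs.
bool shouldUseFusedFinalize()
{
    return !getEnvMOEDisableFinalizeFusion();
}

} // namespace sketch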
@@ -1,568 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/*! \file
|
||||
\brief Functor performing elementwise operations used by epilogues.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/fast_math.h"
|
||||
|
||||
#include "cute/numeric/numeric_types.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/arch/copy_red_global.hpp"
|
||||
#include "cutlass_extensions/util/gather_tensor.hpp"
|
||||
|
||||
#include "cutlass/epilogue/collective/builders/sm90_builder.inl"
|
||||
#include "cutlass/epilogue/collective/builders/sm90_common.inl"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace epilogue
|
||||
{
|
||||
namespace collective
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class StrideC_, class ElementD_, class StrideD_, class ThreadEpilogueOp_, class ElementBias, class StrideBias,
|
||||
class ElementScale, class StrideScale, class EpilogueTile, class SmemLayoutAtomD, class CopyOpR2S, class CopyOpS2R,
|
||||
class CopyOpR2G>
|
||||
class EpilogueMoeFusedFinalize
|
||||
{
|
||||
public:
|
||||
using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized;
|
||||
using DispatchPolicy = PtrArrayNoSmemWarpSpecialized;
|
||||
|
||||
using ThreadEpilogueOp = ThreadEpilogueOp_;
|
||||
using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
|
||||
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
|
||||
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
|
||||
using ElementIntermediate = typename ThreadEpilogueOp::ElementD;
|
||||
|
||||
using ElementC = typename ThreadEpilogueOp::ElementC;
|
||||
using StrideC = StrideC_;
|
||||
using InternalStrideC = cute::remove_pointer_t<StrideC>;
|
||||
using ElementD = ElementD_;
|
||||
using StrideD = StrideD_;
|
||||
using InternalStrideD = cute::remove_pointer_t<StrideD>;
|
||||
|
||||
static_assert(!is_same_v<InternalStrideC, StrideC>, "Stride C must be a pointer");
|
||||
static_assert(is_same_v<InternalStrideD, StrideD>, "Stride D must not be a pointer");
|
||||
|
||||
using CopyAtomR2S = Copy_Atom<CopyOpR2S, ElementAccumulator>;
|
||||
using CopyAtomS2R = Copy_Atom<CopyOpS2R, ElementAccumulator>;
|
||||
using CopyAtomR2G = Copy_Atom<CopyOpR2G, ElementD>;
|
||||
static constexpr int AlignmentD = CopyAtomR2G::NumValSrc;
|
||||
|
||||
using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{}));
|
||||
|
||||
constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
alignas(SmemAlignmentD) cute::ArrayEngine<ElementAccumulator, cosize_v<SmemLayoutD>> smem_D;
|
||||
};
|
||||
|
||||
struct TensorMapStorage
|
||||
{
|
||||
};
|
||||
|
||||
struct Arguments
|
||||
{
|
||||
typename ThreadEpilogueOp::Params thread{};
|
||||
ElementC const** ptr_C{};
|
||||
StrideC dC{};
|
||||
ElementD* ptr_D{};
|
||||
StrideD dD{};
|
||||
ElementBias const* ptr_bias;
|
||||
StrideBias dBias{};
|
||||
ElementScale const* ptr_scale;
|
||||
StrideScale dScale{};
|
||||
int64_t const* group_offset{};
|
||||
int32_t const* scatter_index{};
|
||||
cutlass::FastDivmod num_rows_in_final_output;
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params to_underlying_arguments(
|
||||
ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace)
|
||||
{
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args,
|
||||
void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr)
|
||||
{
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
CUTLASS_HOST_DEVICE static bool can_implement(
|
||||
[[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args)
|
||||
{
|
||||
bool implementable = true;
|
||||
if (problem_shape.is_host_problem_shape_available())
|
||||
{
|
||||
// Check alignment for all problem sizes
|
||||
for (int i = 0; i < problem_shape.groups(); i++)
|
||||
{
|
||||
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
implementable = implementable
|
||||
&& cutlass::detail::check_alignment<AlignmentD>(cute::make_shape(M, N, L), InternalStrideD{});
|
||||
}
|
||||
}
|
||||
|
||||
if (!implementable)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global "
|
||||
"reduction instruction.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
EpilogueMoeFusedFinalize(Params const& params_)
|
||||
: params(params_)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_source_needed()
|
||||
{
|
||||
// For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta.
|
||||
return params.ptr_C != nullptr
|
||||
&& (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0);
|
||||
}
|
||||
|
||||
template <class ProblemShapeMNKL, class BlockShapeMNK, class BlockCoordMNKL, class FrgEngine, class FrgLayout,
|
||||
class TiledMma, class ResidueMNK>
|
||||
CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK,
|
||||
BlockCoordMNKL blk_coord_mnkl, cute::Tensor<FrgEngine, FrgLayout> const& accumulators, TiledMma tiled_mma,
|
||||
ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf)
|
||||
{
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
|
||||
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
|
||||
static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
|
||||
static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");
|
||||
|
||||
auto synchronize = [&]()
|
||||
{ cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
|
||||
// Separate out problem shape for convenience
|
||||
auto M = get<0>(problem_shape_mnkl);
|
||||
auto N = get<1>(problem_shape_mnkl);
|
||||
auto L = get<3>(problem_shape_mnkl);
|
||||
|
||||
auto mma_tile_m = tile_size<0>(tiled_mma);
|
||||
auto mma_tile_n = tile_size<1>(tiled_mma);
|
||||
auto epi_tile_m = size<0>(EpilogueTile{});
|
||||
auto epi_tile_n = size<1>(EpilogueTile{});
|
||||
|
||||
CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M");
|
||||
CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");
|
||||
|
||||
// Batches are managed by using appropriate pointers to C and D matrices
|
||||
int32_t const mock_L = 1;
|
||||
int32_t const mock_l_coord = 0;
|
||||
|
||||
// Slice to get the tile this CTA is responsible for
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
|
||||
|
||||
// If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups.
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups,
|
||||
// we get the correct alpha/beta values for the current batch/group using group index.
|
||||
ThreadEpilogueOp epilogue_op(params.thread, l_coord);
|
||||
|
||||
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
||||
|
||||
Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{});
|
||||
Tensor sD = as_position_independent_swizzle_tensor(sD_);
|
||||
|
||||
// Function to scatter output rows
|
||||
auto& num_rows = params.num_rows_in_final_output;
|
||||
auto read_scatter_map = tensorrt_llm::cutlass_extensions::IndexedGather(
|
||||
make_gmem_ptr(params.scatter_index + params.group_offset[l_coord]));
|
||||
auto get_scatter_idx = [&](auto i)
|
||||
{
|
||||
auto scatter = read_scatter_map(i);
|
||||
int quot, rem;
|
||||
num_rows(quot, rem, scatter);
|
||||
return rem;
|
||||
};
|
||||
|
||||
// Represent the full output tensor
|
||||
ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr;
|
||||
auto dC = epilogue_op.is_source_needed() ? params.dC[l_coord] : InternalStrideC{};
|
||||
Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l)
|
||||
Tensor mD_mnl = tensorrt_llm::cutlass_extensions::make_gather_tensor(
|
||||
make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l)
|
||||
|
||||
// Use fake shape for bias, it doesn't matter
|
||||
bool const is_bias_needed = params.ptr_bias != nullptr;
|
||||
Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias);
|
||||
Tensor mScale_mnl = make_tensor(
|
||||
make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale);
|
||||
|
||||
Tensor gC_mnl
|
||||
= local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
Tensor gD_mnl
|
||||
= local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
|
||||
Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
|
||||
Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
|
||||
|
||||
Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
Tensor gBias_mnl
|
||||
= local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
Tensor gScale_mnl
|
||||
= local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
|
||||
Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N)
|
||||
Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N)
|
||||
|
||||
Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
// Get the smallest tiled copy we can use to retile the accumulators
|
||||
TiledCopy tiled_copy_C_atom
|
||||
= make_tiled_copy_C_atom(Copy_Atom<SM90_U32x4_STSM_N, cutlass::half_t>{}, tiled_mma);
|
||||
TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom);
|
||||
|
||||
auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx);
|
||||
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
|
||||
Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N)
|
||||
Tensor tRS_rD = make_tensor<ElementAccumulator>(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N)
|
||||
|
||||
// Make a tiled copy vectorized along major direction of D
|
||||
auto tiled_s2r = [&]()
|
||||
{
|
||||
if constexpr (cutlass::gemm::detail::is_k_major<StrideD>())
|
||||
{
|
||||
constexpr int NumThreadsMajor = epi_tile_n / AlignmentD;
|
||||
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
|
||||
return make_tiled_copy(CopyAtomS2R{},
|
||||
Layout<Shape<Int<NumThreadsMinor>, Int<NumThreadsMajor>>, Stride<Int<NumThreadsMajor>, _1>>{},
|
||||
Layout<Shape<_1, Int<AlignmentD>>>{});
|
||||
}
|
||||
else if constexpr (cutlass::gemm::detail::is_mn_major<StrideD>())
|
||||
{
|
||||
constexpr int NumThreadsMajor = epi_tile_m / AlignmentD;
|
||||
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
|
||||
return make_tiled_copy(CopyAtomS2R{},
|
||||
Layout<Shape<Int<NumThreadsMajor>, Int<NumThreadsMinor>>, Stride<_1, Int<NumThreadsMajor>>>{},
|
||||
Layout<Shape<Int<AlignmentD>, _1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(cute::is_void_v<StrideD>, "Unsupported D gmem layout.");
|
||||
}
|
||||
}();
|
||||
|
||||
auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx);
|
||||
Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
|
||||
// Allocate intermediate registers for a single subtile
|
||||
Tensor tSR_rD = make_tensor<ElementAccumulator>(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rD_final = make_tensor<ElementD>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rC = make_tensor<ElementC>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rBias = make_tensor<ElementBias>(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rScale = make_tensor<ElementScale>(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
|
||||
// Make an identity coordinate tensor for predicating our output MN tile
|
||||
Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
|
||||
Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
|
||||
// epilogue subtile loop
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n)
|
||||
{
|
||||
int mma_m = (epi_m * epi_tile_m) / mma_tile_m;
|
||||
int mma_n = (epi_n * epi_tile_n) / mma_tile_n;
|
||||
Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n);
|
||||
|
||||
int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n);
|
||||
int r2s_v = epi_n_in_mma * size(tRS_rD);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v)
|
||||
{
|
||||
tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v);
|
||||
}
|
||||
|
||||
copy(tiled_r2s, tRS_rD, tRS_sD);
|
||||
synchronize();
|
||||
|
||||
copy(tiled_s2r, tSR_sD, tSR_rD);
|
||||
synchronize();
|
||||
|
||||
Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n);
|
||||
|
||||
if (epilogue_op.is_source_needed())
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < size<1>(tSR_rD); ++m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < size<2>(tSR_rD); ++n)
|
||||
{
|
||||
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
|
||||
{
|
||||
copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n));
|
||||
if (is_bias_needed)
|
||||
{
|
||||
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
|
||||
}
|
||||
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(tSR_rD); ++i)
|
||||
{
|
||||
auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n));
|
||||
if (is_bias_needed)
|
||||
{
|
||||
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
|
||||
}
|
||||
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
|
||||
}
|
||||
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < size<1>(tSR_rD); ++m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < size<2>(tSR_rD); ++n)
|
||||
{
|
||||
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
|
||||
{
|
||||
if (is_bias_needed)
|
||||
{
|
||||
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
|
||||
}
|
||||
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(tSR_rD); ++i)
|
||||
{
|
||||
auto epi_value = epilogue_op(tSR_rD(i, m, n));
|
||||
if (is_bias_needed)
|
||||
{
|
||||
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
|
||||
}
|
||||
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
|
||||
}
|
||||
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Params params;
|
||||
};
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
template <class Element, class MaxVec>
|
||||
constexpr auto get_vectorized_atomic_add_op()
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
auto constexpr MaxVecSize = size(MaxVec{});
|
||||
|
||||
if constexpr (is_same_v<Element, cutlass::half_t>)
|
||||
{
|
||||
if constexpr (MaxVecSize >= 8)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_F16x2_V4{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 4)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_F16x2_V2{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 2)
|
||||
{
|
||||
return SM70_RED_ADD_NOFTZ_F16x2{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return SM70_RED_ADD_NOFTZ_F16{};
|
||||
}
|
||||
}
|
||||
else if constexpr (is_same_v<Element, cutlass::bfloat16_t>)
|
||||
{
|
||||
if constexpr (MaxVecSize >= 8)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2_V4{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 4)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2_V2{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 2)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16{};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// non-vectorized atomic add for all other types until supported
|
||||
return TypedAtomicAdd<Element>{};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <class Arch, class TileShape, class ElementC, class StrideC, class ElementD, class StrideD,
|
||||
class ElementAccumulator, class ElementCompute, class ElementBias, class StrideBias, class ElementScale,
|
||||
class StrideScale>
|
||||
struct EpilogueMoeFusedFinalizeBuilder
|
||||
{
|
||||
|
||||
// assuming cooperative kernel schedule
|
||||
using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{}));
|
||||
using EpilogueTile = Shape<_128, EpiTileN>;
|
||||
|
||||
// Output of linear combination is ElementCompute instead of ElementD
|
||||
// since we will be doing more compute on it, no need to cast yet.
|
||||
using ThreadEpilogueOp
|
||||
= cutlass::epilogue::thread::LinearCombination<ElementCompute, 1, ElementAccumulator, ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, ElementC>;
|
||||
|
||||
using SmemLayoutAtomD
|
||||
= decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<StrideD, ElementAccumulator, EpilogueTile>());
|
||||
using CopyAtomR2S
|
||||
= decltype(detail::sm90_get_smem_store_op_for_accumulator<StrideD, ElementAccumulator, EpilogueTile>());
|
||||
using CopyAtomS2R = DefaultCopy;
|
||||
using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op<ElementD, EpiTileN>());
|
||||
|
||||
template <class Base, class EpilogueOp>
|
||||
struct TmaWarpSpecializedAdapterWithSmemStorageImpl : Base
|
||||
{
|
||||
// We need to override this one using declaration because otherwise we double up on the smem
|
||||
using TensorMapStorage = typename EpilogueOp::TensorMapStorage;
|
||||
|
||||
// using Base = detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
TmaWarpSpecializedAdapterWithSmemStorageImpl(
|
||||
typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors)
|
||||
: Base(params)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE auto load_init([[maybe_unused]] typename EpilogueOp::Params const& params,
|
||||
[[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] int32_t sm_count,
|
||||
[[maybe_unused]] int32_t sm_idx)
|
||||
{
|
||||
return cute::make_tuple(nullptr);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE auto store_init([[maybe_unused]] typename EpilogueOp::Params const& params,
|
||||
[[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] int32_t sm_count,
|
||||
[[maybe_unused]] int32_t sm_idx, [[maybe_unused]] int32_t warp_group_idx)
|
||||
{
|
||||
return cute::make_tuple(nullptr);
|
||||
}
|
||||
|
||||
// Dummy methods to perform different parts of TMA/Tensormap modifications
|
||||
|
||||
template <bool IsLoad, class ProblemShapeMNKL>
|
||||
CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormaps,
|
||||
[[maybe_unused]] typename EpilogueOp::Params const& params,
|
||||
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] ProblemShapeMNKL problem_shape,
|
||||
[[maybe_unused]] int32_t next_batch, [[maybe_unused]] int32_t warp_group_idx)
|
||||
{
|
||||
}
|
||||
|
||||
template <bool IsLoad>
|
||||
CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormaps,
|
||||
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t warp_group_idx)
|
||||
{
|
||||
}
|
||||
|
||||
template <bool IsLoad>
|
||||
CUTLASS_DEVICE void tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template <class EpilogueOp>
|
||||
using TmaWarpSpecializedAdapterWithSmemStorage = TmaWarpSpecializedAdapterWithSmemStorageImpl<
|
||||
std::conditional_t<Arch::kMinComputeCapability >= 100, detail::Sm100TmaWarpSpecializedAdapter<EpilogueOp>,
|
||||
detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>>,
|
||||
EpilogueOp>;
|
||||
|
||||
using CollectiveOp = TmaWarpSpecializedAdapterWithSmemStorage<
|
||||
EpilogueMoeFusedFinalize<StrideC, ElementD, StrideD, ThreadEpilogueOp, ElementBias, StrideBias, ElementScale,
|
||||
StrideScale, EpilogueTile, SmemLayoutAtomD, CopyAtomR2S, CopyAtomS2R, CopyAtomR2G>>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace collective
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,547 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Visitor tree store operations for the sm90 TMA warp-specialized (ws) epilogue
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/epilogue/fusion/operations.hpp"
|
||||
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
|
||||
|
||||
#include "cutlass_extensions/arch/copy_red_global.hpp"
|
||||
#include "cutlass_extensions/util/gather_tensor.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace cutlass::epilogue::fusion {
|
||||
|
||||
using namespace cute;
|
||||
using namespace detail;
|
||||
|
||||
template <
|
||||
class EpilogueTile,
|
||||
class StrideOutput,
|
||||
class SmemLayoutAtom,
|
||||
class CopyOpR2S,
|
||||
class ElementOutput,
|
||||
int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>,
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
struct Sm90ScatterPtrArray {
|
||||
|
||||
using SmemShape = decltype(make_shape(size(make_layout(get<0>(EpilogueTile{}))), size(make_layout(get<1>(EpilogueTile{})))));
|
||||
using SmemLayout = decltype(tile_to_shape(SmemLayoutAtom{}, SmemShape{}));
|
||||
|
||||
using ElementIndex = int32_t;
|
||||
// TODO: more generic treatment, or pass StrideIndex via template param?
|
||||
using StrideIndex = conditional_t<cutlass::gemm::detail::is_mn_major<StrideOutput>(), Stride<_0,_1,_0>, Stride<_1,_0,_0>>;
|
||||
|
||||
struct SharedStorage {};
|
||||
|
||||
struct Arguments {
|
||||
ElementOutput* ptr_out = nullptr;
|
||||
StrideOutput dOut = {};
|
||||
ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index
|
||||
int index_modulo{}; // modulo used to transform the index before store
|
||||
bool use_reduction = true;
|
||||
};
|
||||
|
||||
struct Params {
|
||||
ElementOutput* ptr_out = nullptr;
|
||||
StrideOutput dOut = {};
|
||||
ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index
|
||||
cutlass::FastDivmod index_divmod{}; // modulo used to transform the index before store
|
||||
bool use_reduction = true;
|
||||
};
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return {
|
||||
args.ptr_out,
|
||||
args.dOut,
|
||||
args.ptr_index,
|
||||
cutlass::FastDivmod(args.index_modulo),
|
||||
args.use_reduction
|
||||
};
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ScatterPtrArray() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ScatterPtrArray(Params const& params, SharedStorage const& shared_storage)
|
||||
: params_ptr(¶ms) { }
|
||||
|
||||
Params const* params_ptr;
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template<
|
||||
class ArgsTuple
|
||||
>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(ArgsTuple&& args_tuple)
|
||||
: args_tuple(std::move(args_tuple)) {}
|
||||
|
||||
ArgsTuple args_tuple;
|
||||
|
||||
template <typename ElementAccumulator, typename ElementInput, int FragmentSize>
|
||||
CUTLASS_DEVICE auto
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n,
|
||||
Array<ElementInput, FragmentSize> const& frg_input) {
|
||||
|
||||
auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple;
|
||||
|
||||
using ConvertInput = NumericArrayConverter<ElementOutput, ElementInput, FragmentSize, RoundStyle>;
|
||||
ConvertInput convert_input{};
|
||||
|
||||
Tensor tC_rOut_frg = recast<Array<ElementOutput, FragmentSize>>(coalesce(tC_rOut)); // (EPI_V)
|
||||
tC_rOut_frg(epi_v) = convert_input(frg_input);
|
||||
|
||||
return tC_rOut_frg(epi_v);
|
||||
}
|
||||
|
||||
template <class STensor, class SyncFn, class VTensor>
|
||||
CUTLASS_DEVICE void
|
||||
reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) {
|
||||
|
||||
auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple;
|
||||
|
||||
Tensor byte_buffer = recast<uint8_t>(reduction_buffer);
|
||||
static_assert(cosize(byte_buffer.layout()) * sizeof_bits_v<uint8_t> >= cosize(SmemLayout{}) * sizeof_bits_v<ElementOutput>,
|
||||
"Not enough space in scratch smem buffer");
|
||||
|
||||
Tensor sOut = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(recast_ptr<ElementOutput>(byte_buffer.data())), SmemLayout{}));
|
||||
|
||||
auto thread_r2s = tiled_r2s.get_slice(thread_idx);
|
||||
Tensor tRS_sOut_epi = thread_r2s.partition_D(sOut);
|
||||
Tensor tRS_rOut_epi = thread_r2s.retile_S(tC_rOut);
|
||||
|
||||
auto thread_r2g = tiled_r2g_red.get_slice(thread_idx);
|
||||
Tensor tRG_gOut_epi = tRG_gOut(_,_,_,epi_m,epi_n);
|
||||
Tensor tRG_sOut_epi = thread_r2g.partition_D(sOut);
|
||||
Tensor tRG_rOut_epi = thread_r2g.retile_S(make_tensor(tC_rOut.data(), shape(tRG_sOut_epi))); // reuse D registers
|
||||
|
||||
// sanity check for register reuse
|
||||
CUTE_STATIC_ASSERT_V(cosize(tC_rOut.layout()) == cosize(tRG_rOut_epi.layout()), "Invalid register count for R2G");
|
||||
|
||||
copy(tiled_r2s, tRS_rOut_epi, tRS_sOut_epi);
|
||||
sync_fn();
|
||||
copy(tRG_sOut_epi, tRG_rOut_epi);
|
||||
|
||||
auto residue = residue_cD; // capturing structured bindings is a C++20 feature
|
||||
Tensor tRG_cD_epi = tRG_cD(0,_,_,epi_m,epi_n);
|
||||
auto pred = cute::lazy::transform(tRG_cD_epi, [&](auto c){ return elem_less(c, residue); });
|
||||
|
||||
if (use_reduction) {
|
||||
copy_if(tiled_r2g_red, pred, tRG_rOut_epi, tRG_gOut_epi);
|
||||
}
|
||||
else {
|
||||
copy_if(tiled_r2g_stg, pred, tRG_rOut_epi, tRG_gOut_epi);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class Element, int MaxVecSize>
|
||||
static constexpr auto get_reduction_op()
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
// For now only support red.add
|
||||
if constexpr (is_same_v<Element, cutlass::half_t>) {
|
||||
if constexpr (MaxVecSize % 8 == 0) {
|
||||
return SM90_RED_ADD_NOFTZ_F16x2_V4{};
|
||||
}
|
||||
else if constexpr (MaxVecSize % 4 == 0) {
|
||||
return SM90_RED_ADD_NOFTZ_F16x2_V2{};
|
||||
}
|
||||
else if constexpr (MaxVecSize % 2 == 0) {
|
||||
return SM70_RED_ADD_NOFTZ_F16x2{};
|
||||
}
|
||||
else {
|
||||
return SM70_RED_ADD_NOFTZ_F16{};
|
||||
}
|
||||
}
|
||||
else if constexpr (is_same_v<Element, cutlass::bfloat16_t>) {
|
||||
if constexpr (MaxVecSize % 8 == 0) {
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2_V4{};
|
||||
}
|
||||
else if constexpr (MaxVecSize % 4 == 0) {
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2_V2{};
|
||||
}
|
||||
else if constexpr (MaxVecSize % 2 == 0) {
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2{};
|
||||
}
|
||||
else {
|
||||
return SM90_RED_ADD_NOFTZ_BF16{};
|
||||
}
|
||||
}
|
||||
else {
|
||||
// non-vectorized atomic add for all other types until supported
|
||||
return TypedAtomicAdd<Element>{};
|
||||
}
|
||||
}
|
||||
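The branches above simply pick the widest red.global.add atom whose vector width divides the number of values each thread owns. A standalone sketch of that dispatch in plain C++ (illustrative only; the returned widths stand in for the SM90_RED_* / SM70_RED_* copy ops):

// Sketch: return the largest vector width (in elements) usable for the reduction,
// mirroring the MaxVecSize % N == 0 checks in get_reduction_op above.
constexpr int pick_reduction_vector_width(int max_vec_size)
{
    if (max_vec_size % 8 == 0) { return 8; } // e.g. SM90_RED_ADD_NOFTZ_F16x2_V4
    if (max_vec_size % 4 == 0) { return 4; } // e.g. SM90_RED_ADD_NOFTZ_F16x2_V2
    if (max_vec_size % 2 == 0) { return 2; } // e.g. SM70_RED_ADD_NOFTZ_F16x2
    return 1;                                // scalar red.add / TypedAtomicAdd fallback
}

static_assert(pick_reduction_vector_width(16) == 8, "16 values per thread -> x4 paired atom");
static_assert(pick_reduction_vector_width(6) == 2, "6 values per thread -> x2 atom");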
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
|
||||
auto index_read = [index = params_ptr->ptr_index[l], divmod = params_ptr->index_divmod](auto i){ return divmod.rem(index[i]); };
|
||||
Tensor mOut = cutlass::util::make_gather_tensor(params_ptr->ptr_out, make_shape(M,N,Int<1>{}), params_ptr->dOut, index_read); // (M,N,_1)
|
||||
Tensor gOut = local_tile(mOut, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N)
|
||||
Tensor gOut_epi = flat_divide(gOut, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
Tensor mIdx = make_tensor(params_ptr->ptr_index[l], make_shape(M,N,Int<1>{}), StrideIndex{}); // (M,N,_1)
|
||||
Tensor gIdx = local_tile(mIdx, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N)
|
||||
Tensor gIdx_epi = flat_divide(gIdx, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
Tensor cD_epi = flat_divide(args.cD, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
Tensor tC_gOut = sm90_partition_for_epilogue<ReferenceSrc>(gOut, args.epi_tile, args.tiled_copy, args.thread_idx); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
Tensor tC_rOut = make_tensor<ElementOutput>(take<0,3>(shape(tC_gOut))); // (CPY,CPY_M,CPY_N)
|
||||
|
||||
auto tiled_r2s = conditional_return<ReferenceSrc>(
|
||||
make_tiled_copy_S(Copy_Atom<CopyOpR2S,ElementOutput>{}, args.tiled_copy),
|
||||
make_tiled_copy_D(Copy_Atom<CopyOpR2S,ElementOutput>{}, args.tiled_copy)
|
||||
);
|
||||
|
||||
// Vectorization must not exceed alignment and also the number of values per thread in the tile
|
||||
int constexpr NumThreads = CUTE_STATIC_V(size(args.tiled_copy));
|
||||
int constexpr NumValTile = product(take<0,2>(shape(cD_epi)));
|
||||
int constexpr MaxVecSize = cute::min(AlignmentOutput, NumValTile / NumThreads);
|
||||
|
||||
// Choose the largest available red.global op and an st.global op with matching vectorization
|
||||
using CopyOpR2GRed = decltype(get_reduction_op<ElementOutput, MaxVecSize>());
|
||||
using CopyOpR2GStg = UniversalCopy<uint_bit_t<Copy_Atom<CopyOpR2GRed,ElementOutput>::NumValSrc * sizeof_bits_v<ElementOutput>>>;
|
||||
|
||||
auto make_tiled_r2g = [&](auto copy_op)
|
||||
{
|
||||
using CopyAtomR2G = Copy_Atom<decltype(copy_op),ElementOutput>;
|
||||
constexpr int VecSize = CopyAtomR2G::NumValSrc;
|
||||
if constexpr (cutlass::gemm::detail::is_k_major<StrideOutput>()) {
|
||||
constexpr int ThreadsMajor = size<1>(args.epi_tile) / VecSize;
|
||||
constexpr int ThreadsMinor = NumThreads / ThreadsMajor;
|
||||
return make_tiled_copy(CopyAtomR2G{},
|
||||
Layout<Shape<Int<ThreadsMinor>, Int<ThreadsMajor>>, Stride<Int<ThreadsMajor>, _1>>{},
|
||||
Layout<Shape<_1, Int<VecSize>>>{});
|
||||
}
|
||||
else if constexpr (cutlass::gemm::detail::is_mn_major<StrideOutput>()) {
|
||||
constexpr int ThreadsMajor = size<0>(args.epi_tile) / VecSize;
|
||||
constexpr int ThreadsMinor = NumThreads / ThreadsMajor;
|
||||
return make_tiled_copy(CopyAtomR2G{},
|
||||
Layout<Shape<Int<ThreadsMajor>, Int<ThreadsMinor>>, Stride<_1, Int<ThreadsMajor>>>{},
|
||||
Layout<Shape<Int<VecSize>, _1>>{});
|
||||
}
|
||||
else {
|
||||
static_assert(cute::is_void_v<StrideOutput>, "Unsupported D gmem layout.");
|
||||
}
|
||||
};
|
||||
|
||||
auto tiled_r2g_red = make_tiled_r2g(CopyOpR2GRed{});
|
||||
auto tiled_r2g_stg = make_tiled_r2g(CopyOpR2GStg{});
|
||||
|
||||
// Sanity checks - since we will be using one tiled copy with tensors partitioned with the other tiled copy,
|
||||
// ensure they have matching layouts/tilers
|
||||
using TiledR2GRed = decltype(tiled_r2g_red);
|
||||
using TiledR2GStg = decltype(tiled_r2g_stg);
|
||||
static_assert(typename TiledR2GRed::AtomLayoutSrc{} == typename TiledR2GStg::AtomLayoutSrc{}, "Mismatching AtomLayoutSrc");
|
||||
static_assert(typename TiledR2GRed::AtomLayoutDst{} == typename TiledR2GStg::AtomLayoutDst{}, "Mismatching AtomLayoutDst");
|
||||
static_assert(typename TiledR2GRed::TiledLayout_TV{} == typename TiledR2GStg::TiledLayout_TV{}, "Mismatching TiledLayout_TV");
|
||||
static_assert(typename TiledR2GRed::Tiler_MN{} == typename TiledR2GStg::Tiler_MN{}, "Mismatching Tiler_MN");
|
||||
|
||||
auto thread_r2g = tiled_r2g_red.get_slice(args.thread_idx);
|
||||
Tensor tRG_gOut = thread_r2g.partition_D(gOut_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N)
|
||||
Tensor tRG_cD = thread_r2g.partition_D(cD_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N)
|
||||
|
||||
auto args_tuple = make_tuple(
|
||||
cute::move(tC_rOut),
|
||||
tiled_r2s,
|
||||
tRG_gOut,
|
||||
tRG_cD,
|
||||
tiled_r2g_red,
|
||||
tiled_r2g_stg,
|
||||
params_ptr->use_reduction,
|
||||
args.thread_idx,
|
||||
args.residue_cD);
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(args_tuple)>(std::move(args_tuple));
|
||||
}
|
||||
};
|
||||
|
||||
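Functionally, this store node implements the MoE finalize that previously ran as a separate kernel: each row of the expanded FC2 output is scaled by its router weight and written into the final-output row selected by the per-group scatter index (taken modulo the final row count), accumulating when use_reduction is set. A simplified CUDA reference of that per-row semantics follows; the names and the flat row-major float layout are assumptions, the alpha/beta linear-combination scalars are ignored, and the real epilogue works on tiles, stages through shared memory, and uses the vectorized red.global/st.global ops chosen above.

// Reference semantics only; not the epilogue implementation.
__global__ void moe_finalize_reference(float* final_output, // [num_rows, hidden]
    float const* fc2_output,                                 // [num_expanded_rows, hidden]
    float const* bias,                                       // per-expert bias slice, may be nullptr
    float const* router_scales,                              // [num_expanded_rows]
    int const* scatter_index,                                // one entry per expanded row
    int num_rows, int num_expanded_rows, int hidden, bool use_reduction)
{
    int row = blockIdx.x;
    if (row >= num_expanded_rows)
        return;
    int out_row = scatter_index[row] % num_rows; // index_modulo / FastDivmod in the real code
    float scale = router_scales[row];
    for (int col = threadIdx.x; col < hidden; col += blockDim.x)
    {
        float value = fc2_output[row * hidden + col];
        if (bias != nullptr)
        {
            value += bias[col];
        }
        if (use_reduction)
        {
            // red.global path: contributions from different experts accumulate
            atomicAdd(&final_output[out_row * hidden + col], scale * value);
        }
        else
        {
            // plain st.global path (each output row has a single contribution)
            final_output[out_row * hidden + col] = scale * value;
        }
    }
}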
template<
|
||||
class ElementOutput_,
|
||||
class ElementCompute_,
|
||||
class ElementBias_ = ElementOutput_,
|
||||
class ElementScalar_ = ElementCompute_,
|
||||
int AlignmentBias_ = 128 / cute::sizeof_bits_v<ElementBias_>,
|
||||
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
struct ScaledAccPerRowBias
|
||||
: ScaledAcc<ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_>
|
||||
{
|
||||
using ElementBias = ElementBias_;
|
||||
static constexpr int AlignmentBias = AlignmentBias_;
|
||||
static constexpr bool IsPerRowBiasSupported = true;
|
||||
};
|
||||
|
||||
template<
|
||||
class GmemLayoutTagOut,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementScale = ElementCompute,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentBias = 128 / cute::sizeof_bits_v<ElementBias>,
|
||||
int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>,
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
struct ScaledAccPerRowBiasPerColScaleScatter
|
||||
: ScaledAccPerRowBias<ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>
|
||||
{
|
||||
using ElementAux = ElementOutput;
|
||||
using GmemLayoutTagAux = GmemLayoutTagOut;
|
||||
static constexpr int AlignmentAux = AlignmentOutput;
|
||||
static constexpr bool IsAuxOutSupported = true;
|
||||
};
|
||||
|
||||
// D = alpha * acc + per-row bias
|
||||
template<
|
||||
class CtaTileShapeMNK,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
using Sm90ScaledAccPerRowBiasPtrArray =
|
||||
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // alpha * acc + bias
|
||||
Sm90ScalarBroadcastPtrArray<ElementScalar, Stride<_0,_0,int64_t>>, // alpha
|
||||
Sm90AccFetch, // acc
|
||||
Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias *, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias
|
||||
>;
|
||||
|
||||
template<
|
||||
class CtaTileShapeMNK,
|
||||
class EpilogueTile,
|
||||
class StrideOutput,
|
||||
class SmemLayoutAtom,
|
||||
class CopyOpR2S,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias = ElementOutput,
|
||||
class ElementScale = ElementCompute,
|
||||
class ElementScalar = ElementCompute,
|
||||
int AlignmentBias = 128 / cute::sizeof_bits_v<ElementBias>,
|
||||
int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>,
|
||||
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
using Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray =
|
||||
Sm90EVT<Sm90ScatterPtrArray<EpilogueTile, StrideOutput, SmemLayoutAtom, CopyOpR2S, ElementOutput, AlignmentOutput, RoundStyle>, // scatter store
|
||||
Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // scale * (alpha * acc + bias)
|
||||
Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar *, ElementCompute, Stride<_0,_1,int64_t>, 1>, // scale
|
||||
Sm90ScaledAccPerRowBiasPtrArray<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle> // alpha * acc + bias
|
||||
>
|
||||
>;
|
||||
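Read bottom-up, the composed EVT evaluates, for each element of a group's GEMM tile, D = router_scale * (alpha * acc + per-row bias). The outer Sm90ScatterPtrArray node then writes each finished output row to final_output at scatter_index[row] mod num_rows, accumulating with red.global when use_reduction is set and using a plain vectorized store otherwise (this reading follows the inline comments above).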
|
||||
template <
|
||||
int StagesC,
|
||||
int StagesD,
|
||||
int FragmentSize,
|
||||
bool ReuseSmemC,
|
||||
bool DelayTmaStore,
|
||||
int NumEpilogueWarpGroups,
|
||||
class GmemLayoutTagOut,
|
||||
class ElementOutput,
|
||||
class ElementCompute,
|
||||
class ElementBias,
|
||||
class ElementScale,
|
||||
class ElementScalar,
|
||||
int AlignmentBias,
|
||||
int AlignmentOutput,
|
||||
FloatRoundStyle RoundStyle,
|
||||
class CtaTileShapeMNK,
|
||||
class EpilogueTile,
|
||||
class SmemLayoutAtom,
|
||||
class CopyOpR2S
|
||||
>
|
||||
struct FusionCallbacks<
|
||||
epilogue::Sm90PtrArrayTmaWarpSpecialized<StagesC,
|
||||
StagesD,
|
||||
FragmentSize,
|
||||
ReuseSmemC,
|
||||
DelayTmaStore,
|
||||
NumEpilogueWarpGroups
|
||||
>,
|
||||
fusion::ScaledAccPerRowBiasPerColScaleScatter<GmemLayoutTagOut,
|
||||
ElementOutput,
|
||||
ElementCompute,
|
||||
ElementBias,
|
||||
ElementScale,
|
||||
ElementScalar,
|
||||
AlignmentBias,
|
||||
AlignmentOutput,
|
||||
RoundStyle>,
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile,
|
||||
SmemLayoutAtom,
|
||||
CopyOpR2S
|
||||
> : Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray<
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile,
|
||||
cutlass::gemm::TagToStrideC_t<GmemLayoutTagOut>,
|
||||
SmemLayoutAtom, CopyOpR2S,
|
||||
ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar,
|
||||
AlignmentBias, AlignmentOutput, RoundStyle
|
||||
> {
|
||||
|
||||
using StrideOutput = cutlass::gemm::TagToStrideC_t<GmemLayoutTagOut>;
|
||||
|
||||
using Impl = Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray<
|
||||
CtaTileShapeMNK,
|
||||
EpilogueTile,
|
||||
StrideOutput,
|
||||
SmemLayoutAtom, CopyOpR2S,
|
||||
ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar,
|
||||
AlignmentBias, AlignmentOutput, RoundStyle
|
||||
>;
|
||||
using Operation = fusion::ScaledAccPerRowBiasPerColScaleScatter<
|
||||
GmemLayoutTagOut,
|
||||
ElementOutput,
|
||||
ElementCompute,
|
||||
ElementBias,
|
||||
ElementScale,
|
||||
ElementScalar,
|
||||
AlignmentBias,
|
||||
AlignmentOutput,
|
||||
RoundStyle>;
|
||||
|
||||
struct Arguments {
|
||||
|
||||
using StrideAlpha = Stride<_0,_0,int64_t>;
|
||||
ElementScalar alpha = ElementScalar(1);
|
||||
ElementScalar const* alpha_ptr{};
|
||||
ElementScalar const* const* alpha_ptr_array{};
|
||||
StrideAlpha dAlpha{};
|
||||
|
||||
using StrideBias = Stride<_1,_0,int64_t>;
|
||||
ElementBias const* const* bias_ptr{};
|
||||
StrideBias dBias{};
|
||||
|
||||
using StrideScale = Stride<_0,_1,int64_t>;
|
||||
ElementScalar const* const* scale_ptr_array{};
|
||||
StrideScale dScale{};
|
||||
|
||||
// Nested args not usable due to a compiler bug with constexpr evaluation
|
||||
// using ScatterArguments = typename Sm90ScatterPtrArray<EpilogueTile, StrideOutput, SmemLayoutAtom, CopyOpR2S, ElementOutput, AlignmentOutput, RoundStyle>::Arguments;
|
||||
// ScatterArguments scatter{};
|
||||
|
||||
ElementOutput* ptr_out = nullptr;
|
||||
StrideOutput dOut = {};
|
||||
int const* const* ptr_index{}; // per-group pointer to the scatter index
|
||||
int index_modulo{}; // modulo used to transform the index before store
|
||||
bool use_reduction = true;
|
||||
|
||||
operator typename Impl::Arguments() const {
|
||||
return
|
||||
{ // unary op: reduce(scale * (beta * C + (alpha * acc)))
|
||||
{ // binary op: scale * (beta * C + (alpha * acc))
|
||||
{ scale_ptr_array, ElementScalar(1), dScale }, // leaf args : scale broadcast
|
||||
{ // ternary op : alpha * acc + bias
|
||||
{{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha
|
||||
{}, // leaf args : acc
|
||||
{bias_ptr, ElementBias(0), dBias}, // leaf args : bias
|
||||
{} // ternary args : multiply_add
|
||||
}, // end binary op
|
||||
{} // binary args: multiply
|
||||
}, // end binary op
|
||||
//scatter // unary args: reduce
|
||||
{ ptr_out, dOut, ptr_index, index_modulo, use_reduction }
|
||||
}; // end unary op
|
||||
}
|
||||
};
|
||||
|
||||
// Ctor inheritance
|
||||
using Impl::Impl;
|
||||
|
||||
};
|
||||
|
||||
} // namespace cutlass::epilogue::fusion
|
||||
|
||||
// clang-format on
|
||||
@@ -133,11 +133,6 @@ enum class CutlassTileConfigSM100
    CtaShape128x256x128B,
    CtaShape128x128x256B,
    CtaShape128x256x256B,

    // M=256
    CtaShape256x64x128B,
    CtaShape256x128x128B,
    CtaShape256x256x128B,
};

enum class CutlassTileConfigSM120
@@ -19,7 +19,7 @@
#include "cute/tensor.hpp"
#include "cute/util/print.hpp"

namespace tensorrt_llm::cutlass_extensions
namespace cutlass::util
{

/// Function object that applies an index to its argument
@@ -81,7 +81,7 @@ struct CustomStride
    template <class Div>
    CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div)
    {
        return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div));
        return CustomStride<Func, decltype(cute::safe_div(s.stride_, div))>(s.func_, cute::safe_div(s.stride_, div));
    }

    // Circumvent the requirement on make_layout that shape and stride are integral
@@ -116,7 +116,7 @@ CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, S
    Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
    return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
}
} // namespace tensorrt_llm::cutlass_extensions
} // namespace cutlass::util

namespace cute
{
@@ -377,72 +377,62 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100(CutlassGemmConfig::Ca
    if (config & CutlassGemmConfig::GROUPED_GEMM)
    {
        std::vector<CutlassGemmConfig> candidate_configs;
        if ((config & CutlassGemmConfig::FP4_ONLY) != 0)
        if (config & CutlassGemmConfig::FP4_ONLY)
        {
            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x128x128B,
            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1});
            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x256x128B,
                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1});
            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x64x128B,
                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
            return candidate_configs;
        }

        for (int cluster_m = 1; cluster_m <= 2; cluster_m++)
        std::vector<std::pair<CutlassTileConfigSM100, ClusterShape>> tile_configs{
            {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x1x1},
            {CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_1x1x1},
            {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_1x2x1},
            {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x2x1},
            {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x1x1},
            {CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_2x1x1},
            {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x2x1},
            {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x2x1},
            {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x1x1},
            {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x1x1},
            {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x1x1},
            {CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_2x1x1},
            {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x2x1},
            {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x2x1},
            {CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x1x1},
            {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x1x1},
            {CutlassTileConfigSM100::CtaShape64x32x128B, ClusterShape::ClusterShape_1x2x1},
            {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x1x1},
            {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x2x1},
            {CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_1x1x1},
            {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x2x1},
            {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_1x1x1},
            {CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x2x1},
        };

        if (config & CutlassGemmConfig::FP8_ONLY)
        {
            bool Is2SM = cluster_m == 2;
            for (int cluster_n = 1; cluster_n <= 2; cluster_n++)
            {
                std::vector base = {// M=128
                    CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B};
            tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x16x128B, ClusterShape::ClusterShape_1x1x1});
            // TODO: re-enable when handled by the MoE GEMM dispatch
            // tile_configs.push_back({ CutlassTileConfigSM100::CtaShape128x8x256B, ClusterShape::ClusterShape_1x1x1 });
        }

                if (Is2SM)
                {
                    if (cluster_n == 1)
                    {
                        base.push_back(CutlassTileConfigSM100::CtaShape128x64x128B);
                        base.push_back(CutlassTileConfigSM100::CtaShape256x64x128B);
                    }

                    std::vector twosm = {// M=256
                        CutlassTileConfigSM100::CtaShape256x128x128B, CutlassTileConfigSM100::CtaShape256x256x128B};
                    std::copy(twosm.begin(), twosm.end(), std::back_inserter(base));
                }
                else
                {
                    if (cluster_n == 1)
                    {
                        base.push_back(CutlassTileConfigSM100::CtaShape128x32x128B);
                        if ((config & CutlassGemmConfig::FP8_ONLY) != 0)
                        {
                            base.push_back(CutlassTileConfigSM100::CtaShape128x16x128B);
                        }
                    }

                    std::vector onesm{CutlassTileConfigSM100::CtaShape64x64x128B,
                        CutlassTileConfigSM100::CtaShape64x128x128B, CutlassTileConfigSM100::CtaShape64x256x128B,
                        CutlassTileConfigSM100::CtaShape128x64x128B};
                    std::copy(onesm.begin(), onesm.end(), std::back_inserter(base));
                }

                constexpr std::array cluster_shapes
                    = {std::array{ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1},
                        std::array{ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1}};
                auto cluster = cluster_shapes[cluster_m - 1][cluster_n - 1];
                for (auto tile : base)
                {
                    CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster};
                    candidate_configs.push_back(config);
                }
            }
        for (auto [tile, cluster] : tile_configs)
        {
            CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster};
            candidate_configs.push_back(config);
        }
        return candidate_configs;
    }
@ -37,11 +37,6 @@

namespace tensorrt_llm::kernels::cutlass_kernels
{
template <class T>
constexpr auto transpose_stride(T const& t)
{
return cute::prepend(cute::prepend(cute::take<2, cute::rank_v<T>>(t), cute::get<0>(t)), cute::get<1>(t));
}

template <typename AType, typename BType, typename BScaleType, typename OType>
struct GroupedGemmInput
@ -72,8 +67,6 @@ struct GroupedGemmInput

struct TmaWarpSpecializedGroupedGemmInput
{
template <class T>
using TransposeStride = decltype(transpose_stride<T>(T{}));
template <class Tag>
using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;
@ -86,6 +79,7 @@ struct TmaWarpSpecializedGroupedGemmInput
using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for A matrix operand
using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand
using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for C matrix operand
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand

constexpr static int NVFP4BlockScaleVectorSize = 16;
constexpr static int MXFPXBlockScaleVectorSize = 32;
@ -121,6 +115,7 @@ struct TmaWarpSpecializedGroupedGemmInput
using StrideB
= std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped
using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;

#ifdef ENABLE_FP8
template <class T>
@ -147,37 +142,26 @@ struct TmaWarpSpecializedGroupedGemmInput
StrideC* stride_c = nullptr;
void const** ptr_c = nullptr;

struct DefaultEpilogue
{
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;

StrideD* stride_d = nullptr;
void** ptr_d = nullptr;
};
// D is used in all cases except fused finalize
StrideD* stride_d = nullptr;
void** ptr_d = nullptr;

struct FusedFinalizeEpilogue
{
using StrideFinalOutput = DefaultEpilogue::StrideD;
using StrideBias = TransposeStride<cute::Stride<cute::_0, cute::_1, int>>;
using StrideRouterScales = TransposeStride<cute::Stride<cute::_1, cute::_0>>;
using StrideFinalOutput = cutlass::detail::TagToStrideC_t<LayoutD>;

void* ptr_final_output = nullptr;
StrideFinalOutput stride_final_output{};

void const* ptr_bias = nullptr;
StrideBias stride_bias{};
void const** ptr_bias = nullptr;
float const** ptr_router_scales = nullptr;

float const* ptr_router_scales = nullptr;
StrideRouterScales stride_router_scales{};
int const** ptr_source_token_index = nullptr;
int num_rows_in_final_output = 0;

int64_t const* ptr_expert_first_token_offset = nullptr;
int const* ptr_source_token_index = nullptr;

size_t num_rows_in_final_output = 0;
bool use_reduction = true;
};
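// Reading of the fields above (descriptive sketch, not authoritative): for each permuted row r that
// GEMM2 produces for a given expert, the fused finalize epilogue computes
//   out = ptr_router_scales[expert][r] * (gemm2_row + bias_for_expert)
// and writes it to final_output at row ptr_source_token_index[expert][r], accumulating with `+=`
// when use_reduction is set (experts-per-token > 1) and storing directly otherwise. The per-expert
// pointer arrays are workspace buffers that are filled on the device alongside the other grouped
// GEMM pointers.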

DefaultEpilogue default_epilogue;
FusedFinalizeEpilogue fused_finalize_epilogue;

enum class EpilogueFusion
@ -235,7 +219,7 @@ struct TmaWarpSpecializedGroupedGemmInput
uint8_t* gemm_workspace = nullptr;
size_t gemm_workspace_size = 0;

static std::array<size_t, 17> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type);
static std::array<size_t, 20> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type);

static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type);

@ -247,9 +231,7 @@ struct TmaWarpSpecializedGroupedGemmInput
return stride_a != nullptr && ptr_a != nullptr;
}

void setFinalizeFusionParams(void* final_output, float const* router_scales,
int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
int num_output_tokens);
void setFinalizeFusionParams(void* final_output, int hidden_size, int num_output_tokens, bool use_reduction);

std::string toString() const;
};
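// Minimal usage sketch (illustrative only; `gemm2_input`, `final_output`, `hidden_size`,
// `num_output_tokens` and `top_k` are assumed caller-side names, not part of this header):
//
//   gemm2_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
//   gemm2_input.setFinalizeFusionParams(
//       final_output, hidden_size, num_output_tokens, /*use_reduction=*/top_k > 1);
//
// The router scales, source-token indices and optional bias are no longer passed here; they are
// written into the per-expert workspace arrays when the problem layout is computed on the device.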

@ -495,7 +495,8 @@ public:
void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1,
void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream)
void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales,
int const* permuted_row_to_unpermuted_row, cudaStream_t stream)
= 0;

virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@ -512,13 +513,13 @@ public:
virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0;

bool is_profiler = false;
bool use_deterministic_hopper_reduce_ = false;
bool use_fused_finalize_ = true;
};

// Assumes inputs activations are row major. Weights need to be preprocessed by th_op/weight_quantize.cc .
// Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive.
// Avoid making several duplicates of this class.
template <typename T, /*The type used for activations*/
template <typename T, /* The type used for activations */
typename WeightType, /* The type for the MoE weights */
typename OutputType = T, /* The type for the MoE final output */
typename InputType = T, /* The type for the MoE input */
@ -709,7 +710,8 @@ public:
void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1,
void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) override
void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales,
int const* permuted_row_to_unpermuted_row, cudaStream_t stream) override
{
return Self::computeStridesTmaWarpSpecialized(expert_first_token_offset, layout_info1, layout_info2, num_tokens,
expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node,
@ -718,7 +720,8 @@ public:
alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, quant_params,
reinterpret_cast<ScaleBiasType const*>(bias1), reinterpret_cast<ScaleBiasType const*>(bias2),
reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output),
reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), stream);
reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), router_scales, permuted_row_to_unpermuted_row,
stream);
}

std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@ -760,7 +763,8 @@ private:
float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params,
ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output,
UnfusedGemmOutputType* gemm2_output, cudaStream_t stream);
UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row,
cudaStream_t stream);
static std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1,
TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k,
@ -790,8 +794,8 @@ private:

bool mayHaveFinalizeFused() const
{
return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() == 90
&& !use_deterministic_hopper_reduce_ && !use_w4_groupwise;
return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && use_fused_finalize_
&& !use_w4_groupwise;
}

// TODO: This should eventually take the quant params to give more flexibility

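// Illustrative gating sketch (the `runner` object and the policy flag are assumed names): finalize
// fusion is now opt-out through the public `use_fused_finalize_` member rather than being tied to
// the Hopper-only deterministic-reduce flag, and it applies to SM90 and newer parts.
//
//   runner.use_fused_finalize_ = allow_fused_finalize; // e.g. false when a deterministic reduction is required
//   bool const may_fuse = runner.mayHaveFinalizeFused();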
@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,21 +26,14 @@
#include "cute/tensor.hpp"

#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/fusion/operations.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"

#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp"

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
@ -189,17 +182,19 @@ using SafeBF16 = void;
|
||||
TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, int const multi_processor_count, \
|
||||
cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size) \
|
||||
{ \
|
||||
constexpr static EpilogueFusion FUSION = EpilogueFusion::FUSION_; \
|
||||
/* constexpr static bool BIAS = BIAS_; */ /* Always false */ \
|
||||
using ArchTag = cutlass::arch::ArchTag_; \
|
||||
constexpr static EpilogueFusion FUSION = EpilogueFusion::FUSION_; \
|
||||
constexpr static bool IsMXFPX = MXFPX_; \
|
||||
constexpr bool IsBlackwell = ArchTag::kMinComputeCapability >= 100; \
|
||||
constexpr bool IsSM120 = ArchTag::kMinComputeCapability == 120 || ArchTag::kMinComputeCapability == 121; \
|
||||
constexpr bool Is2SM = IsBlackwell && (CGA_M_ % 2 == 0); \
|
||||
/* constexpr static bool BIAS = BIAS_; */ /* Always false */ \
|
||||
using T = DataType_; \
|
||||
using WeightType = WeightType_; \
|
||||
using OutputType = OutputType_; \
|
||||
using EpilogueTag = tensorrt_llm::cutlass_extensions::EpilogueTag_; \
|
||||
using TileShape = cute::Shape<cute::Int<CTA_M_>, cute::Int<CTA_N_>, cute::Int<CTA_K_>>; \
|
||||
using MmaTileShape = cute::Shape<cute::Int<CTA_M_*(Is2SM ? 2 : 1)>, cute::Int<CTA_N_>, cute::Int<CTA_K_>>; \
|
||||
using ClusterShape = cute::Shape<cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>>; \
|
||||
constexpr static bool IsMXFPX = MXFPX_; \
|
||||
\
|
||||
if constexpr (!COMPILE_HOPPER_TMA_GROUPED_GEMMS_ENABLED && ArchTag::kMinComputeCapability >= 90 \
|
||||
&& ArchTag::kMinComputeCapability < 100) \
|
||||
{ \
|
||||
@ -217,18 +212,15 @@ using SafeBF16 = void;
|
||||
TLLM_THROW( \
|
||||
"Please recompile with support for blackwell by passing 120-real as an arch to build_wheel.py."); \
|
||||
} \
|
||||
else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v<ArchTag, TileShape, ClusterShape, \
|
||||
T>) \
|
||||
else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v<ArchTag, MmaTileShape, \
|
||||
ClusterShape, T>) \
|
||||
{ \
|
||||
using namespace cute; \
|
||||
/* Helper class for defining all the cutlass types \
|
||||
// template <typename ArchTag, typename T, typename WeightType, typename OutputType, typename EpilogueTag, \
|
||||
// typename TileShape, typename ClusterShape, bool BIAS, EpilogueFusion FUSION> \
|
||||
// typename MmaTileShape, typename ClusterShape, bool BIAS, EpilogueFusion FUSION> \
|
||||
// struct TmaWarpSpecializedGroupedGemmInfo \
|
||||
{ */ \
|
||||
using Arch = ArchTag; \
|
||||
constexpr static bool IsBlackwell = Arch::kMinComputeCapability >= 100; \
|
||||
constexpr static bool IsSM120 = Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121; \
|
||||
constexpr static bool IsWFP4AFP8 = cutlass::platform::is_same<WeightType, SafeFP4>::value \
|
||||
&& cutlass::platform::is_same<T, SafeFP8>::value; \
|
||||
constexpr static bool IsFP4 = cutlass::platform::is_same<T, SafeFP4>::value; \
|
||||
@ -308,8 +300,8 @@ using SafeBF16 = void;
|
||||
// units of elements (up to 16 bytes)*/ \
|
||||
\
|
||||
/* D matrix configuration */ \
|
||||
using LayoutD = TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::LayoutD; \
|
||||
using StrideD = TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::StrideD; \
|
||||
using LayoutD = TmaWarpSpecializedGroupedGemmInput::LayoutD; \
|
||||
using StrideD = TmaWarpSpecializedGroupedGemmInput::StrideD; \
|
||||
constexpr static int AlignmentD \
|
||||
= 128 / cutlass::sizeof_bits<ElementD>::value; /* Memory access granularity/alignment of D matrix \
|
||||
// in units of elements (up to 16 bytes) */ \
|
||||
@ -327,30 +319,24 @@ using SafeBF16 = void;
|
||||
// cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, \
|
||||
// cutlass::epilogue::?????????????????? /// <<<<<< what supports activations \
|
||||
// >;*/ \
|
||||
using EpilogueScheduleSM90 = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; \
|
||||
using EpilogueScheduleSM90 = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
|
||||
\
|
||||
constexpr static bool Is2SM = IsBlackwell && (cute::size<0>(ClusterShape{}) % 2) == 0; \
|
||||
using EpilogueScheduleSM100 = std::conditional_t<Is2SM, cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm, \
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm>; \
|
||||
using EpilogueScheduleSM120 = cutlass::epilogue::TmaWarpSpecialized; \
|
||||
using EpilogueScheduleBW = std ::conditional_t<IsSM120, EpilogueScheduleSM120, EpilogueScheduleSM100>; \
|
||||
using EpilogueScheduleBW = std::conditional_t<IsSM120, EpilogueScheduleSM120, EpilogueScheduleSM100>; \
|
||||
using EpilogueSchedule = std::conditional_t<IsBlackwell, EpilogueScheduleBW, EpilogueScheduleSM90>; \
|
||||
\
|
||||
using EpilogueTileShapeSm90 = TileShape; \
|
||||
using AtomClusterDiv = std::conditional_t<Is2SM, _2, _1>; \
|
||||
using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<AtomClusterDiv, _1, _1>{})); \
|
||||
using EpilogueTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \
|
||||
using EpilogueTileShape = std::conditional_t<IsBlackwell, EpilogueTileShapeSm100, EpilogueTileShapeSm90>; \
|
||||
using EpilogueElementC = std::conditional_t<IsSM120, ElementCSafe, ElementC>; \
|
||||
using EpilogueTensorOp = std::conditional_t<IsBlackwell && IsBlockScaled, \
|
||||
cutlass::arch::OpClassBlockScaledTensorOp, cutlass::arch::OpClassTensorOp>; \
|
||||
using EpilogueSubTile \
|
||||
= std::conditional_t<Arch::kMinComputeCapability == 100 && IsFP4 && CTA_N_ == 256, /* SM100 Exactly */ \
|
||||
cute::Shape<cute::_128, cute::_64>, cutlass::epilogue::collective::EpilogueTileAuto>; \
|
||||
using EpilogueSubTile = std::conditional_t<ArchTag::kMinComputeCapability == 100 && IsFP4 \
|
||||
&& CTA_N_ == 256, /* SM100 Exactly */ \
|
||||
cute::Shape<cute::_128, cute::_64>, cutlass::epilogue::collective::EpilogueTileAuto>; \
|
||||
/* Epilogue For Default Finalize */ \
|
||||
using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder</**/ \
|
||||
Arch, EpilogueTensorOp, /**/ \
|
||||
EpilogueTileShape, ClusterShape, /**/ \
|
||||
ArchTag, EpilogueTensorOp, /**/ \
|
||||
MmaTileShape, ClusterShape, /**/ \
|
||||
EpilogueSubTile, /**/ \
|
||||
ElementAccumulator, ElementAccumulator, /**/ \
|
||||
EpilogueElementC, LayoutC*, AlignmentC, /**/ \
|
||||
@ -358,18 +344,17 @@ using SafeBF16 = void;
|
||||
EpilogueSchedule>::CollectiveOp; \
|
||||
\
|
||||
/* Epilogue For Fused Finalize */ \
|
||||
using CollectiveEpilogueFinalize = \
|
||||
typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< /**/ \
|
||||
Arch, EpilogueTileShape, /**/ \
|
||||
ElementCSafe, StrideC*, /**/ \
|
||||
ElementFinalOutput, \
|
||||
TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, /**/ \
|
||||
ElementAccumulator, /**/ \
|
||||
ElementAccumulator, /**/ \
|
||||
ElementBias, TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, /**/ \
|
||||
ElementRouterScales, \
|
||||
TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales /**/ \
|
||||
>::CollectiveOp; \
|
||||
using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::CollectiveBuilder</**/ \
|
||||
ArchTag, EpilogueTensorOp, /**/ \
|
||||
MmaTileShape, ClusterShape, /**/ \
|
||||
EpilogueSubTile, /**/ \
|
||||
ElementAccumulator, ElementAccumulator, /**/ \
|
||||
EpilogueElementC, LayoutC*, AlignmentC, /**/ \
|
||||
void, LayoutD*, AlignmentD, /**/ \
|
||||
EpilogueSchedule, /**/ \
|
||||
cutlass::epilogue::fusion::ScaledAccPerRowBiasPerColScaleScatter< /**/ \
|
||||
LayoutD, ElementFinalOutput, ElementAccumulator, ElementBias, ElementRouterScales> /**/ \
|
||||
>::CollectiveOp; \
|
||||
\
|
||||
using CollectiveEpilogue = std::conditional_t<FUSION == EpilogueFusion::FINALIZE, \
|
||||
CollectiveEpilogueFinalize, CollectiveEpilogueDefault>; \
|
||||
@ -405,16 +390,12 @@ using SafeBF16 = void;
|
||||
using MainloopElementA = std::conditional_t<IsBlackwell && IsBlockScaled, ElementABlockScaled, ElementA>; \
|
||||
using MainloopElementB = std::conditional_t<IsBlackwell && IsBlockScaled, ElementBBlockScaled, ElementB>; \
|
||||
\
|
||||
using MainloopTileShapeSm90 = TileShape; \
|
||||
using MainloopTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \
|
||||
using MainloopTileShape = std::conditional_t<IsBlackwell, MainloopTileShapeSm100, MainloopTileShapeSm90>; \
|
||||
\
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder</**/ \
|
||||
Arch, TensorOp, /**/ \
|
||||
ArchTag, TensorOp, /**/ \
|
||||
MainloopElementB, LayoutB*, AlignmentB, /* A & B swapped here */ \
|
||||
MainloopElementA, LayoutA*, AlignmentA, /**/ \
|
||||
ElementAccumulator, /**/ \
|
||||
MainloopTileShape, ClusterShape, /**/ \
|
||||
MmaTileShape, ClusterShape, /**/ \
|
||||
StageCountAutoCarveout, KernelSchedule>::CollectiveOp; \
|
||||
\
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<TmaWarpSpecializedGroupedGemmInput::ProblemShape, \
|
||||
@ -422,11 +403,11 @@ using SafeBF16 = void;
|
||||
\
|
||||
using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; \
|
||||
/*}; \
|
||||
\ \
|
||||
// \
|
||||
// using namespace cute; \
|
||||
// using GemmInfo = TmaWarpSpecializedGroupedGemmInfo;<ArchTag, T, WeightType, OutputType, \
|
||||
EpilogueTag, \
|
||||
// TileShape, \
|
||||
// MmaTileShape, \
|
||||
// ClusterShape, BIAS, FUSION>; \
|
||||
// \
|
||||
// using ElementAccumulator = typename GemmInfo::ElementAccumulator; \
|
||||
@ -478,7 +459,7 @@ using SafeBF16 = void;
|
||||
TLLM_CHECK(tma_ws_input.ptr_a); \
|
||||
TLLM_CHECK(tma_ws_input.ptr_b); \
|
||||
\
|
||||
auto make_mainloop_params = [&]() -> MainloopArguments \
|
||||
MainloopArguments const mainloop_args = [&] \
|
||||
{ \
|
||||
if constexpr (IsBlockScaled) \
|
||||
{ \
|
||||
@ -498,67 +479,46 @@ using SafeBF16 = void;
|
||||
reinterpret_cast<ElementB const**>(tma_ws_input.ptr_b), tma_ws_input.stride_b, \
|
||||
reinterpret_cast<ElementA const**>(tma_ws_input.ptr_a), tma_ws_input.stride_a); \
|
||||
} \
|
||||
}; \
|
||||
\
|
||||
auto const mainloop_params = make_mainloop_params(); \
|
||||
\
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments; \
|
||||
using EpilogueScalars = decltype(EpilogueArguments{}.thread); \
|
||||
auto make_epilogue_scalars = [&]() \
|
||||
}(); \
|
||||
using FusionArguments = typename CollectiveEpilogue::FusionCallbacks::Arguments; \
|
||||
FusionArguments fusion_args = [&] \
|
||||
{ \
|
||||
if constexpr (IsBlackwell) \
|
||||
if constexpr (FUSION == EpilogueFusion::FINALIZE) \
|
||||
{ \
|
||||
return construct_if_true<IsBlackwell, EpilogueScalars>(ElementAccumulator(1.f), \
|
||||
tma_ws_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f), nullptr, nullptr, \
|
||||
tma_ws_input.alpha_scale_ptr_array, nullptr, \
|
||||
cute::Shape<_0, _0, int64_t>{ \
|
||||
cute::_0{}, cute::_0{}, (tma_ws_input.alpha_scale_ptr_array != nullptr) ? 1 : 0}, \
|
||||
cute::Shape<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 0}); \
|
||||
} \
|
||||
else if (tma_ws_input.alpha_scale_ptr_array) \
|
||||
{ \
|
||||
return construct_if_true<!IsBlackwell, EpilogueScalars>(tma_ws_input.alpha_scale_ptr_array); \
|
||||
auto epi_params = tma_ws_input.fused_finalize_epilogue; \
|
||||
return construct_if_true<FUSION == EpilogueFusion::FINALIZE, FusionArguments>( \
|
||||
ElementAccumulator(1), nullptr, tma_ws_input.alpha_scale_ptr_array, \
|
||||
Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, /* alpha */ \
|
||||
reinterpret_cast<ElementBias const* const*>(epi_params.ptr_bias), \
|
||||
Stride<_1, _0, int64_t>{}, /* bias */ \
|
||||
epi_params.ptr_router_scales, Stride<_0, _1, int64_t>{}, /* scale */ \
|
||||
reinterpret_cast<ElementFinalOutput*>(epi_params.ptr_final_output), \
|
||||
epi_params.stride_final_output, epi_params.ptr_source_token_index, \
|
||||
epi_params.num_rows_in_final_output, epi_params.use_reduction); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
return construct_if_true<!IsBlackwell, EpilogueScalars>(ElementAccumulator(1.f), \
|
||||
tma_ws_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); \
|
||||
return construct_if_true<FUSION != EpilogueFusion::FINALIZE, FusionArguments>( \
|
||||
ElementAccumulator(1), ElementAccumulator(0), nullptr, nullptr, \
|
||||
tma_ws_input.alpha_scale_ptr_array, nullptr, \
|
||||
Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, Stride<_0, _0, int64_t>{}); \
|
||||
} \
|
||||
}; \
|
||||
auto epilogue_scalars = make_epilogue_scalars(); \
|
||||
/* TODO ptr_c casts to ElementCSafe** because there is a workaround in CUTLASS */ \
|
||||
auto make_epi_args = [&]() \
|
||||
{ \
|
||||
static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE, \
|
||||
"Unimplemented fusion provided to TMA WS MoE gemm launcher"); \
|
||||
}(); \
|
||||
\
|
||||
if constexpr (FUSION == EpilogueFusion::NONE) \
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments; \
|
||||
EpilogueArguments epilogue_args = [&] \
|
||||
{ \
|
||||
if constexpr (FUSION == EpilogueFusion::FINALIZE) \
|
||||
{ \
|
||||
auto epi_params = tma_ws_input.default_epilogue; \
|
||||
return construct_if_true<FUSION == EpilogueFusion::NONE, EpilogueArguments>(epilogue_scalars, \
|
||||
nullptr, tma_ws_input.stride_c, reinterpret_cast<ElementD**>(epi_params.ptr_d), \
|
||||
epi_params.stride_d); \
|
||||
} \
|
||||
else if constexpr (FUSION == EpilogueFusion::FINALIZE) \
|
||||
{ \
|
||||
/* Parameters for fused finalize */ \
|
||||
auto epi_params = tma_ws_input.fused_finalize_epilogue; \
|
||||
return construct_if_true<FUSION == EpilogueFusion::FINALIZE, EpilogueArguments>( \
|
||||
epilogue_scalars, /* Parameters to underlying epilogue */ \
|
||||
nullptr, tma_ws_input.stride_c, /* C params */ \
|
||||
reinterpret_cast<ElementFinalOutput*>(epi_params.ptr_final_output), \
|
||||
epi_params.stride_final_output, /* D (output) params */ \
|
||||
reinterpret_cast<ElementBias const*>(epi_params.ptr_bias), \
|
||||
epi_params.stride_bias, /* Bias params */ \
|
||||
epi_params.ptr_router_scales, epi_params.stride_router_scales, /* Router scales */ \
|
||||
epi_params.ptr_expert_first_token_offset, /* Offset of this expert's token in the \
|
||||
router scales */ \
|
||||
epi_params.ptr_source_token_index, /* Index of the source token to sum into */ \
|
||||
epi_params.num_rows_in_final_output /* Number of tokens in the output buffer */ \
|
||||
); \
|
||||
fusion_args, nullptr, nullptr, nullptr, nullptr); \
|
||||
} \
|
||||
}; \
|
||||
EpilogueArguments const epilogue_params = make_epi_args(); \
|
||||
else \
|
||||
{ \
|
||||
return construct_if_true<FUSION != EpilogueFusion::FINALIZE, EpilogueArguments>(fusion_args, \
|
||||
nullptr, nullptr, reinterpret_cast<ElementD**>(tma_ws_input.ptr_d), tma_ws_input.stride_d); \
|
||||
} \
|
||||
}(); \
|
||||
/* EpilogueArguments const epilogue_params = make_epi_args<EpilogueArguments, EpilogueScalars, \
|
||||
ElementCSafe, ElementD, ElementFinalOutput, ElementBias, FUSION>( \
|
||||
// tma_ws_input, epilogue_scalars \
|
||||
@ -568,7 +528,7 @@ using SafeBF16 = void;
|
||||
1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; \
|
||||
\
|
||||
const typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, \
|
||||
tma_ws_input.shape_info, mainloop_params, epilogue_params, hw_info, scheduler_args}; \
|
||||
tma_ws_input.shape_info, mainloop_args, epilogue_args, hw_info, scheduler_args}; \
|
||||
\
|
||||
size_t calculated_ws_size = gemm.get_workspace_size(args); \
|
||||
TLLM_CHECK_WITH_INFO(calculated_ws_size <= tma_ws_input.gemm_workspace_size, \

@ -197,8 +197,7 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
reinterpret_cast<ElementScalePacked const**>(hopper_inputs.int4_groupwise_params.ptr_s_a),
hopper_inputs.int4_groupwise_params.stride_s_a, group_size},
{fusion_args, reinterpret_cast<ElementC const**>(hopper_inputs.ptr_c), hopper_inputs.stride_c,
reinterpret_cast<ElementD**>(hopper_inputs.default_epilogue.ptr_d),
hopper_inputs.default_epilogue.stride_d},
reinterpret_cast<ElementD**>(hopper_inputs.ptr_d), hopper_inputs.stride_d},
hw_info};
*workspace_size = gemm.get_workspace_size(args);
return;
@ -211,8 +210,7 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
reinterpret_cast<ElementScalePacked const**>(hopper_inputs.int4_groupwise_params.ptr_s_a),
hopper_inputs.int4_groupwise_params.stride_s_a, group_size},
{fusion_args, reinterpret_cast<ElementC const**>(hopper_inputs.ptr_c), hopper_inputs.stride_c,
reinterpret_cast<ElementD**>(hopper_inputs.default_epilogue.ptr_d),
hopper_inputs.default_epilogue.stride_d},
reinterpret_cast<ElementD**>(hopper_inputs.ptr_d), hopper_inputs.stride_d},
hw_info};

if (gemm.get_workspace_size(arguments) > hopper_inputs.gemm_workspace_size)

@ -138,11 +138,11 @@ void dispatchMoeGemmSelectBiasTmaWarpSpecialized(TmaWarpSpecializedGroupedGemmIn
}
}

template <typename ClusterTileShape, typename ClusterShape, typename DataType, typename WeightType>
template <typename CtaShape, typename ClusterShape, typename DataType, typename WeightType>
constexpr bool are_tile_shapes_supported_sm100()
{
using namespace cute;
using CtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{}));

// This is the epilogue shape. The MMA shape will be twice this for 2SM
constexpr auto TileM = size<0>(CtaShape{});
constexpr auto TileN = size<1>(CtaShape{});
@ -353,6 +353,7 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized(TmaWarpSpecializedGroupedG
{
switch (gemm_config.tile_config_sm100)
{
SHAPE_CASE(100, 64, 32, 128)
SHAPE_CASE(100, 64, 64, 128)
SHAPE_CASE(100, 64, 128, 128)
SHAPE_CASE(100, 64, 256, 128)
@ -363,13 +364,8 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized(TmaWarpSpecializedGroupedG
SHAPE_CASE(100, 128, 128, 128)
SHAPE_CASE(100, 128, 256, 128)

SHAPE_CASE(100, 256, 64, 128)
SHAPE_CASE(100, 256, 128, 128)
SHAPE_CASE(100, 256, 256, 128)

// SHAPE_CASE(100, 128, 128, 64)
// SHAPE_CASE(100, 128, 256, 64)
// SHAPE_CASE(100, 256, 256, 64)
DEFAULT_CASE(100)
}
}

@ -27,14 +27,14 @@

namespace tensorrt_llm::kernels::cutlass_kernels
{
std::array<size_t, 17> TmaWarpSpecializedGroupedGemmInput::workspaceBuffers(
std::array<size_t, 20> TmaWarpSpecializedGroupedGemmInput::workspaceBuffers(
int num_experts, FpXBlockScalingType scaling_type)
{
size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts;
size_t stride_a_size = sizeof(StrideA) * num_experts;
size_t stride_b_size = sizeof(StrideB) * num_experts;
size_t stride_c_size = sizeof(StrideC) * num_experts;
size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts;
size_t stride_d_size = sizeof(StrideD) * num_experts;

size_t ptr_buf_size = sizeof(void*) * num_experts;
size_t scale_buf_size = sizeof(float*) * num_experts;
@ -53,9 +53,12 @@ std::array<size_t, 17> TmaWarpSpecializedGroupedGemmInput::workspaceBuffers(
size_t int4_groupwise_sf_a_size = sizeof(INT4GroupwiseParams::SFA*) * num_experts;
size_t int4_groupwise_stride_sf_a_size = sizeof(INT4GroupwiseParams::StrideSFA) * num_experts;

size_t ptr_token_map_size = sizeof(int**) * num_experts;

return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size,
ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size, sf_a_size, sf_b_size, stride_sf_a_size,
stride_sf_b_size, int4_groupwise_problem_shape_size, int4_groupwise_sf_a_size, int4_groupwise_stride_sf_a_size};
stride_sf_b_size, int4_groupwise_problem_shape_size, int4_groupwise_sf_a_size, int4_groupwise_stride_sf_a_size,
ptr_buf_size, scale_buf_size, ptr_token_map_size};
}

size_t TmaWarpSpecializedGroupedGemmInput::workspaceSize(int num_experts, FpXBlockScalingType scaling_type)
@ -68,7 +71,7 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i
size_t gemm_workspace_size, FpXBlockScalingType scaling_type)
{
auto buffers = workspaceBuffers(num_experts, scaling_type);
std::array<int8_t*, 17> pointers{};
std::array<int8_t*, 20> pointers{};
TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers");
for (int i = 0; i < buffers.size(); i++)
{
@ -82,12 +85,12 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i
stride_a = reinterpret_cast<StrideA*>(pointers[1]);
stride_b = reinterpret_cast<StrideB*>(pointers[2]);
stride_c = reinterpret_cast<StrideC*>(pointers[3]);
default_epilogue.stride_d = reinterpret_cast<DefaultEpilogue::StrideD*>(pointers[4]);
stride_d = reinterpret_cast<StrideD*>(pointers[4]);

ptr_a = reinterpret_cast<void const**>(pointers[5]);
ptr_b = reinterpret_cast<void const**>(pointers[6]);
ptr_c = reinterpret_cast<void const**>(pointers[7]);
default_epilogue.ptr_d = reinterpret_cast<void**>(pointers[8]);
ptr_d = reinterpret_cast<void**>(pointers[8]);

alpha_scale_ptr_array = reinterpret_cast<float const**>(pointers[9]);

@ -103,28 +106,24 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i
int4_groupwise_params.ptr_s_a = reinterpret_cast<INT4GroupwiseParams::SFA const**>(pointers[15]);
int4_groupwise_params.stride_s_a = reinterpret_cast<INT4GroupwiseParams::StrideSFA*>(pointers[16]);

fused_finalize_epilogue.ptr_bias = reinterpret_cast<void const**>(pointers[17]);
fused_finalize_epilogue.ptr_router_scales = reinterpret_cast<float const**>(pointers[18]);
fused_finalize_epilogue.ptr_source_token_index = reinterpret_cast<int const**>(pointers[19]);

this->gemm_workspace = reinterpret_cast<uint8_t*>(gemm_workspace);
this->gemm_workspace_size = gemm_workspace_size;
}

void TmaWarpSpecializedGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales,
int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
int num_output_tokens)
void TmaWarpSpecializedGroupedGemmInput::setFinalizeFusionParams(
void* final_output, int hidden_size, int num_output_tokens, bool use_reduction)
{
fused_finalize_epilogue.ptr_final_output = final_output;
fused_finalize_epilogue.ptr_router_scales = router_scales;
fused_finalize_epilogue.ptr_bias = bias;
fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset;
fused_finalize_epilogue.ptr_source_token_index = source_token_index;

fused_finalize_epilogue.stride_final_output
= cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{},
transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1)));
fused_finalize_epilogue.stride_bias
= transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size));
fused_finalize_epilogue.stride_router_scales = {};
fused_finalize_epilogue.stride_final_output = cutlass::make_cute_packed_stride(
FusedFinalizeEpilogue::StrideFinalOutput{}, cute::make_shape(hidden_size, num_output_tokens, 1));

fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens;
fused_finalize_epilogue.use_reduction = use_reduction;
}
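// Layout note (sketch, assuming the transposed/A-B-swapped convention used elsewhere in this file):
// the packed stride built from cute::make_shape(hidden_size, num_output_tokens, 1) keeps the hidden
// dimension contiguous, so the final output behaves as a dense [num_output_tokens, hidden_size]
// row-major buffer into which the epilogue scatters one row per permuted token.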

std::string TmaWarpSpecializedGroupedGemmInput::toString() const
@ -143,16 +142,13 @@ std::string TmaWarpSpecializedGroupedGemmInput::toString() const
ss << "Final Output: " << (PrintType) fused_finalize_epilogue.ptr_final_output;
ss << " with Stride: " << fused_finalize_epilogue.stride_final_output;
ss << ",\nBias: " << (PrintType) fused_finalize_epilogue.ptr_bias;
ss << " with Stride: " << fused_finalize_epilogue.stride_bias;
ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales;
ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales;
ss << ",\nExpert Offset: " << (PrintType) fused_finalize_epilogue.ptr_expert_first_token_offset;
ss << ", Source Map: " << (PrintType) fused_finalize_epilogue.ptr_source_token_index;
}
else
{
ss << "Ptr D: " << (PrintType) default_epilogue.ptr_d;
ss << " with Stride: " << (PrintType) default_epilogue.stride_d;
ss << "Ptr D: " << (PrintType) ptr_d;
ss << " with Stride: " << (PrintType) stride_d;
}
ss << '\n';
ss << "Alpha scale ptr: " << (PrintType) alpha_scale_ptr_array << "\n";

@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -1165,8 +1165,8 @@ __device__ void computeTmaWarpSpecializedInputStrides(
}
if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE)
{
layout_info.default_epilogue.stride_d[out_idx] = cutlass::make_cute_packed_stride(
TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::StrideD{}, cute::make_shape(gemm_n, gemm_m, 1));
layout_info.stride_d[out_idx] = cutlass::make_cute_packed_stride(
TmaWarpSpecializedGroupedGemmInput::StrideD{}, cute::make_shape(gemm_n, gemm_m, 1));
}
if (layout_info.int4_groupwise_params.enabled)
{
@ -1185,7 +1185,8 @@ template <class T, class WeightType, class OutputType, class ScaleBiasType>
__device__ void computeTmaWarpSpecializedInputPointers(TmaWarpSpecializedGroupedGemmInput& layout_info, int64_t gemm_m,
int64_t gemm_n, int64_t gemm_k, int num_tokens_before_expert, int64_t expert, T const* in,
WeightType const* weights, TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA const* w4a8_weight_scale,
ScaleBiasType const* bias, OutputType* output, int64_t const out_idx)
ScaleBiasType const* bias, OutputType* output, float const* router_scales,
int const* permuted_row_to_unpermuted_row, int64_t const out_idx)
{
// The input prior to this contains K elements per token, with `num_tokens_before_expert` tokens
layout_info.ptr_a[out_idx] = safe_inc_ptr(in, num_tokens_before_expert * gemm_k);
@ -1196,7 +1197,18 @@ __device__ void computeTmaWarpSpecializedInputPointers(TmaWarpSpecializedGrouped
if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE)
{
// The output prior to this contains N elements per token, with `num_tokens_before_expert` tokens
layout_info.default_epilogue.ptr_d[out_idx] = safe_inc_ptr(output, num_tokens_before_expert * gemm_n);
layout_info.ptr_d[out_idx] = safe_inc_ptr(output, num_tokens_before_expert * gemm_n);
}
if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE)
{

layout_info.fused_finalize_epilogue.ptr_source_token_index[expert]
= permuted_row_to_unpermuted_row + num_tokens_before_expert;
layout_info.fused_finalize_epilogue.ptr_router_scales[expert] = router_scales + num_tokens_before_expert;
if (layout_info.fused_finalize_epilogue.ptr_bias != nullptr)
{
layout_info.fused_finalize_epilogue.ptr_bias[expert] = bias + gemm_n * expert;
}
}
if (layout_info.int4_groupwise_params.enabled)
{
@ -1219,7 +1231,8 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir
WeightType const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params,
ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* gemm1_output, OutputType* gemm2_output)
ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* gemm1_output, OutputType* gemm2_output,
float const* router_scales, int const* permuted_row_to_unpermuted_row)
{
// First, compute the global tid. We only need 1 thread per expert.
int const expert = blockIdx.x * blockDim.x + threadIdx.x;
@ -1297,12 +1310,12 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir
gemm1_in, weights1,
reinterpret_cast<TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA const*>(
quant_params.groupwise.fc1.weight_scales),
bias1, gemm1_output, expert);
bias1, gemm1_output, nullptr, nullptr, expert);
computeTmaWarpSpecializedInputPointers(layout_info2, gemm_m, gemm2_n, gemm2_k, num_tokens_before_expert, expert,
gemm2_in, weights2,
reinterpret_cast<TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA const*>(
quant_params.groupwise.fc2.weight_scales),
bias2, gemm2_output, expert);
bias2, gemm2_output, router_scales, permuted_row_to_unpermuted_row, expert);
}

template <class T, class WeightType, class OutputType, class ScaleBiasType>
@ -1420,12 +1433,12 @@ __global__ void computeStridesTmaWarpSpecializedLowLatencyKernel(TmaWarpSpeciali
layout_info2.ptr_b[expert] = safe_inc_ptr(weights2, local_expert * (gemm1_n * gemm2_k));

assert(layout_info1.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE);
layout_info1.default_epilogue.ptr_d[expert] = safe_inc_ptr(output1, expert * num_tokens * gemm1_n);
layout_info1.ptr_d[expert] = safe_inc_ptr(output1, expert * num_tokens * gemm1_n);

if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE)
{
// The output prior to this contains N elements per token, with `num_tokens` tokens
layout_info2.default_epilogue.ptr_d[expert] = safe_inc_ptr(output2, expert * num_tokens * gemm2_n);
layout_info2.ptr_d[expert] = safe_inc_ptr(output2, expert * num_tokens * gemm2_n);
}
}
else
@ -1435,10 +1448,10 @@ __global__ void computeStridesTmaWarpSpecializedLowLatencyKernel(TmaWarpSpeciali
layout_info1.ptr_b[expert] = nullptr;
layout_info2.ptr_b[expert] = nullptr;

layout_info1.default_epilogue.ptr_d[expert] = nullptr;
layout_info1.ptr_d[expert] = nullptr;
if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE)
{
layout_info2.default_epilogue.ptr_d[expert] = nullptr;
layout_info2.ptr_d[expert] = nullptr;
}
}
}
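// Indexing sketch (assumption, consistent with the pointer setup above): for expert e,
// num_tokens_before_expert corresponds to expert_first_token_offset[e], so the per-expert views
// handed to the fused finalize epilogue are
//   ptr_router_scales[e]      -> router_scales                  + expert_first_token_offset[e]
//   ptr_source_token_index[e] -> permuted_row_to_unpermuted_row + expert_first_token_offset[e]
// i.e. both arrays are then indexed by the expert-local permuted row.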
@ -2015,8 +2028,8 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro
|
||||
#define INSTANTIATE_FINALIZE_MOE_ROUTING(OutputT, GemmOutputT, ScaleBiasT) \
|
||||
template void finalizeMoeRoutingKernelLauncher<OutputT, GemmOutputT, ScaleBiasT>( \
|
||||
GemmOutputT const* expanded_permuted_rows, OutputT* reduced_unpermuted_output, ScaleBiasT const* bias, \
|
||||
float const* final_scales, int const* expanded_source_row_to_expanded_dest_row, \
|
||||
int const* expanded_dest_row_to_expanded_source_row, int const* expert_for_source_row, \
|
||||
float const* final_scales, int const* unpermuted_row_to_permuted_row, \
|
||||
int const* permuted_row_to_unpermuted_row, int const* expert_for_source_row, \
|
||||
int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const cols, \
|
||||
int64_t const experts_per_token, int64_t const num_experts_per_node, MOEParallelismConfig parallelism_config, \
|
||||
bool const enable_alltoall, cudaStream_t stream);
|
||||
@ -3295,9 +3308,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
|
||||
}
|
||||
|
||||
bool has_different_output_type_ampere = (use_w4afp8 || use_fp8) && !using_tma_ws_gemm2;
|
||||
bool using_hopper_fused_finalize
|
||||
= tma_ws_input.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
|
||||
bool has_different_output_type_tma_ws = !using_hopper_fused_finalize && using_tma_ws_gemm2;
|
||||
bool using_fused_finalize = tma_ws_input.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
|
||||
bool has_different_output_type_tma_ws = !using_fused_finalize && using_tma_ws_gemm2;
|
||||
|
||||
if (has_different_output_type_ampere || has_different_output_type_tma_ws)
|
||||
{
|
||||
@ -3815,7 +3827,8 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
|
||||
float const* fp8_dequant2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
|
||||
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params,
|
||||
ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output,
|
||||
UnfusedGemmOutputType* gemm2_output, cudaStream_t stream)
|
||||
UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
// Always nullptr
|
||||
layout_info1.ptr_c = nullptr;
|
||||
@ -3823,6 +3836,12 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
|
||||
layout_info2.ptr_c = nullptr;
|
||||
layout_info2.stride_c = nullptr;
|
||||
|
||||
layout_info1.fused_finalize_epilogue.ptr_bias = nullptr;
|
||||
if (!bias2)
|
||||
{
|
||||
layout_info2.fused_finalize_epilogue.ptr_bias = nullptr;
|
||||
}
|
||||
|
||||
auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale
|
||||
: use_wfp4afp8 ? quant_params.fp8_mxfp4.fc1.global_scale
|
||||
: use_fp8 ? fp8_dequant1
|
||||
@ -3863,7 +3882,7 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
|
||||
cudaLaunchKernelEx(&config, kernel_instance, expert_first_token_offset, layout_info1, layout_info2, num_tokens,
|
||||
expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node, gemm1_in, gemm2_in, weights1,
|
||||
weights2, alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, quant_params, bias1, bias2,
|
||||
gemm1_output, gemm2_output);
|
||||
gemm1_output, gemm2_output, router_scales, permuted_row_to_unpermuted_row);
|
||||
|
||||
return std::make_pair(layout_info1, layout_info2);
|
||||
}
|
||||
@ -3986,15 +4005,15 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
|
||||
gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
|
||||
|
||||
bool apply_bias = parallelism_config.tp_rank == 0;
|
||||
bool using_hopper_fused_finalize
|
||||
= !use_deterministic_hopper_reduce_ && gemm2_config_->sm_version == 90 && !use_w4_groupwise && !use_lora;
|
||||
if (using_hopper_fused_finalize)
|
||||
auto* fc2_bias = apply_bias ? fc2_expert_biases : nullptr;
|
||||
bool using_fused_finalize
|
||||
= use_fused_finalize_ && gemm2_config_->sm_version >= 90 && !use_w4_groupwise && !use_lora;
|
||||
if (using_fused_finalize)
|
||||
{
|
||||
assert(min_latency_mode == false);
|
||||
bool use_reduction = expanded_num_rows > num_rows;
|
||||
gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
|
||||
gemm2_tma_ws_input.setFinalizeFusionParams(final_output, permuted_token_final_scales_,
|
||||
expert_first_token_offset_, permuted_row_to_unpermuted_row_, apply_bias ? fc2_expert_biases : nullptr,
|
||||
hidden_size, num_rows);
|
||||
gemm2_tma_ws_input.setFinalizeFusionParams(final_output, hidden_size, num_rows, use_reduction);
|
||||
}
|
||||
|
||||
// fp8_mxfp4 memsets the scaling factors to 1.0f
|
||||
@ -4028,9 +4047,10 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
|
||||
gemm2_tma_ws_input, num_rows, expanded_num_rows, fc1_out_size, hidden_size, hidden_size, inter_size,
|
||||
num_experts_per_node, reinterpret_cast<T const*>(gemm1_input), reinterpret_cast<T const*>(gemm2_input),
|
||||
fc1_expert_weights, fc2_expert_weights, quant_params.fp8.dequant_fc1, quant_params.fp8.dequant_fc2,
|
||||
fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases, fc2_expert_biases,
|
||||
fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases, fc2_bias,
|
||||
reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output),
|
||||
reinterpret_cast<UnfusedGemmOutputType*>(fc2_result_), stream);
|
||||
reinterpret_cast<UnfusedGemmOutputType*>(fc2_result_), permuted_token_final_scales_,
|
||||
permuted_row_to_unpermuted_row_, stream);
|
||||
}
|
||||
}
|
||||
|
||||
@ -4591,20 +4611,17 @@ void GemmProfilerBackend::prepareTmaWsInputs(
|
||||
gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
|
||||
gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
|
||||
|
||||
bool apply_bias = true;
|
||||
bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4);
|
||||
bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16)
|
||||
&& mWType == nvinfer1::DataType::kUINT8);
|
||||
bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;
|
||||
bool using_fused_finalize
|
||||
= !mInterface->use_deterministic_hopper_reduce_ && mSM == 90 && !mMinLatencyMode && !use_w4_groupwise;
|
||||
= mInterface->use_fused_finalize_ && mSM >= 90 && !mMinLatencyMode && !use_w4_groupwise;
|
||||
if (using_fused_finalize)
|
||||
{
|
||||
assert(!mMinLatencyMode);
|
||||
gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
|
||||
gemm2_tma_ws_input.setFinalizeFusionParams(output, token_topk_unpermuted_scales,
|
||||
expert_first_token_offset, permuted_row_to_unpermuted_row, apply_bias ? bias : nullptr,
|
||||
mExpertHiddenSize, num_tokens);
|
||||
gemm2_tma_ws_input.setFinalizeFusionParams(output, mExpertHiddenSize, num_tokens, mK > 1);
|
||||
}
|
||||
|
||||
auto fc1_output_size = isGatedActivation(mActivationType) ? mExpertInterSize * 2 : mExpertInterSize;
|
||||
@ -4625,7 +4642,7 @@ void GemmProfilerBackend::prepareTmaWsInputs(
|
||||
fc1_output_size, mExpertHiddenSize, mExpertHiddenSize, mExpertInterSize, mNumExpertsPerNode, input,
|
||||
input, weights_sel, weights_sel, mQuantParams.fp8.dequant_fc1, mQuantParams.fp8.dequant_fc2,
|
||||
fp4_act_scale_flat, fp4_act_scale_flat, mQuantParams, nullptr, nullptr, intermediate, intermediate,
|
||||
stream);
|
||||
token_topk_unpermuted_scales, permuted_row_to_unpermuted_row, stream);
|
||||
}
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
|
||||
@ -35,8 +35,7 @@ constexpr bool isValidSM120MOESpecialisation()
{
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) // TODO Is there a better choice
return cutlass::platform::is_same<T, __nv_fp4_e2m1>::value && cutlass::platform::is_same<T, WeightType>::value
&& cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value
&& Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
&& cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value;
#else
return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled
#endif
@ -51,8 +50,7 @@ constexpr bool isValidBlackwellMOESpecialisation()
return (cutlass::platform::is_same<T, WeightType>::value
|| (cutlass::platform::is_same<T, __nv_fp8_e4m3>::value
&& cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value))
&& cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value
&& Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
&& cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value;
#else
return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled
#endif

@ -212,8 +212,7 @@ template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {s
|
||||
{kernel_sched}, {epi_sched}> (
|
||||
const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float,
|
||||
{out_tag}*, int, int, int, const int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
|
||||
);
|
||||
"""
|
||||
);"""
|
||||
elif operation.gemm_kind == GemmKind.Grouped:
|
||||
if operation.act_type != operation.weight_type and (
|
||||
operation.act_type != DataType.e4m3
|
||||
@ -261,11 +260,9 @@ GroupedGemmInput<{act_tag}, {weight_tag}, {out_tag}, {out_tag}>inputs, TmaWarpSp
|
||||
# (TmaWarpSpecializedGroupedGemmInput, int, int, cudaStream_t, int*, size_t*);
|
||||
# """
|
||||
instantiation = f"""
|
||||
#if {guard_act} && {guard_weight}\n
|
||||
INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag},
|
||||
{epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, {"true" if operation.is_mx_fpx else "false"}, false);\n
|
||||
#endif
|
||||
"""
|
||||
#if {guard_act} && {guard_weight}
|
||||
INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag}, {epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, {"true" if operation.is_mx_fpx else "false"}, false);
|
||||
#endif"""
|
||||
return instantiation
|
||||
|
||||
|
||||
@ -276,8 +273,7 @@ def instantiate_operation_sm80(operation):
|
||||
|
||||
instantiation = f"""
|
||||
template void sm80_generic_fused_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.stage}, {epi_tag}>
|
||||
({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy);
|
||||
"""
|
||||
({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy);"""
|
||||
return instantiation
|
||||
|
||||
|
||||
@ -340,16 +336,13 @@ def write_file(launcher_inl_files, operations, output_file):
|
||||
f.write(content)
|
||||
|
||||
|
||||
from operator import mul, truediv
|
||||
|
||||
|
||||
def elementwise(x, y, f):
|
||||
return tuple(f(a, b) for (a, b) in zip(x, y))
|
||||
|
||||
|
||||
def is_gemm_op_valid_sm100(op):
|
||||
# TODO These are much more restricted than theory dictates, investigate if more can be enabled in future
|
||||
tile_m, tile_n, _ = elementwise(op.cta_shape, op.cga_shape, truediv)
|
||||
tile_m, tile_n, _ = op.cta_shape
|
||||
cga_m, cga_n, _ = op.cga_shape
|
||||
|
||||
# Default shapes
|
||||
@ -366,13 +359,11 @@ def is_gemm_op_valid_sm100(op):
|
||||
return False
|
||||
|
||||
# Shapes for fp8 small N shapes
|
||||
if (op.act_type == DataType.e4m3 and (tile_n == 16 or tile_n == 8)
|
||||
and (cga_m == 1 and cga_n == 1)):
|
||||
# todo: double check why this is disable in CUTLASS backend. @yuhan
|
||||
if tile_m == 128 and tile_n == 8:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
if (op.act_type == DataType.e4m3) and (tile_n == 16
|
||||
or tile_n == 8) and (cga_m == 1
|
||||
and cga_n == 1):
|
||||
# todo: double check why tile_n = 8 is disabled in CUTLASS backend. @yuhan
|
||||
return tile_m != 128 or tile_n % 16 == 0
|
||||
|
||||
# Default alignment requirements
|
||||
if tile_n % 32 != 0 or tile_n < 32 or tile_n > 256:
|
||||
@ -617,8 +608,6 @@ def calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype):
|
||||
cta_shape_k = max_k_bits // GetDataTypeBits(dtype)
|
||||
if dtype == DataType.e4m3 and (cta_shape_mn[1] == 8):
|
||||
cta_shape_k = 256
|
||||
if dtype == DataType.e4m3 and (cta_shape_mn[1] == 16):
|
||||
cta_shape_k = 128
|
||||
return cta_shape_mn + (cta_shape_k, )
|
||||
|
||||
|
||||
@ -638,7 +627,7 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled):
|
||||
|
||||
epi_fusions = [
|
||||
TrtLlm_EpilogueFusion.epilogue_fusion_none,
|
||||
# TrtLlm_EpilogueFusion.epilogue_fusion_finalize
|
||||
TrtLlm_EpilogueFusion.epilogue_fusion_finalize
|
||||
]
|
||||
|
||||
cga_shapes = [[1, 1, 1]]
|
||||
@ -648,7 +637,6 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled):
|
||||
|
||||
operations = list()
|
||||
for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mnk, cga_shape in partial_args:
|
||||
cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul)
|
||||
|
||||
# Ignored
|
||||
mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative
|
||||
@ -661,8 +649,8 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled):
|
||||
for otype in otypes:
|
||||
moe_gemm_operation = TrtLlm_GemmLauncher(
|
||||
GemmKind.Grouped, arch, dtype, dtype, dtype, dtype, otype,
|
||||
quant_op, epi_tag, cga_tile_shape_mnk, warp_shape, stages,
|
||||
cga_shape, mainloop_schedule, epi_schedule, epi_fusion)
|
||||
quant_op, epi_tag, cta_shape_mnk, warp_shape, stages, cga_shape,
|
||||
mainloop_schedule, epi_schedule, epi_fusion)
|
||||
|
||||
operations.append(moe_gemm_operation)
|
||||
return operations

@ -692,7 +680,7 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled):
epi_fusions = [
TrtLlm_EpilogueFusion.epilogue_fusion_none,
# TrtLlm_EpilogueFusion.epilogue_fusion_finalize
TrtLlm_EpilogueFusion.epilogue_fusion_finalize
]

cga_shapes = list(product([1, 2], [1, 2], [1]))

@ -708,7 +696,6 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled):
weight_type = dtype

cta_shape_mnk = calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype)
cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul)

# Ignored
mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative

@ -729,7 +716,7 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled):
otype,
quant_op,
epi_tag,
cga_tile_shape_mnk,
cta_shape_mnk,
warp_shape,
stages,
cga_shape,

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b25eb3a8bc1fae83eb43f9e0cf8fd93bb00f412d6cbd1bf7e2214e878bec3b4a
size 64735372
oid sha256:86586b9f6845e91e8ba0accad53a5a3418c50d8fd30ad49fa8837470c72b5dcf
size 67051604

@ -1,2 +1,2 @@
16bae34717995b98ee8cff17bc8ec080c0e1b1aca02e5949be171eb8d40eff39 libtensorrt_llm_internal_cutlass_kernels_static.a
commit 995030f9b86258f3db876df6b1dbc46a7c5dae50
568cb6ca2413c93b0f5839dd05577c0c57bc4b5f2359366c79d0ace665de4bd6 libtensorrt_llm_internal_cutlass_kernels_static.a
commit 9c0a42825905952beaf9b35d5a35d58de1a123fa

@ -39,11 +39,6 @@
namespace tensorrt_llm
{
template <class T>
constexpr auto transpose_stride(T const& t)
{
return cute::prepend(cute::prepend(cute::take<2, cute::rank_v<T>>(t), cute::get<0>(t)), cute::get<1>(t));
}

// Note update moe.py to match
enum class ActivationType

@ -87,8 +82,6 @@ struct GroupedGemmInput
struct TmaWarpSpecializedGroupedGemmInput
{
template <class T>
using TransposeStride = decltype(transpose_stride<T>(T{}));
template <class Tag>
using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;

@ -101,6 +94,7 @@ struct TmaWarpSpecializedGroupedGemmInput
using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for A matrix operand
using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand
using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for C matrix operand
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand

constexpr static int NVFP4BlockScaleVectorSize = 16;
constexpr static int MXFPXBlockScaleVectorSize = 32;

@ -122,6 +116,7 @@ struct TmaWarpSpecializedGroupedGemmInput
using StrideB
= std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped
using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;

#ifdef ENABLE_FP8
template <class T>

@ -148,37 +143,26 @@ struct TmaWarpSpecializedGroupedGemmInput
StrideC* stride_c = nullptr;
void const** ptr_c = nullptr;

struct DefaultEpilogue
{
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;

StrideD* stride_d = nullptr;
void** ptr_d = nullptr;
};
// D is used in all cases except fused finalize
StrideD* stride_d = nullptr;
void** ptr_d = nullptr;

struct FusedFinalizeEpilogue
{
using StrideFinalOutput = DefaultEpilogue::StrideD;
using StrideBias = TransposeStride<cute::Stride<cute::_0, cute::_1, int>>;
using StrideRouterScales = TransposeStride<cute::Stride<cute::_1, cute::_0>>;
using StrideFinalOutput = cutlass::detail::TagToStrideC_t<LayoutD>;

void* ptr_final_output = nullptr;
StrideFinalOutput stride_final_output{};

void const* ptr_bias = nullptr;
StrideBias stride_bias{};
void const** ptr_bias = nullptr;
float const** ptr_router_scales = nullptr;

float const* ptr_router_scales = nullptr;
StrideRouterScales stride_router_scales{};
int const** ptr_source_token_index = nullptr;
int num_rows_in_final_output = 0;

int64_t const* ptr_expert_first_token_offset = nullptr;
int const* ptr_source_token_index = nullptr;

size_t num_rows_in_final_output = 0;
bool use_reduction = true;
};

DefaultEpilogue default_epilogue;
FusedFinalizeEpilogue fused_finalize_epilogue;

enum class EpilogueFusion
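
Editor's aside: FusedFinalizeEpilogue now carries per-expert arrays of router scales and source-token indices plus a use_reduction flag. For orientation, here is a reference (unfused) version of the finalize step that this epilogue folds into the FC2 GEMM; names and shapes are illustrative, and bias/per-expert details are omitted.

# Reference (unfused) finalize: scatter-reduce router-scaled FC2 rows back to tokens.
import numpy as np

def finalize_reference(fc2_output, router_scales, source_token_index,
                       num_output_tokens, use_reduction=True):
    # fc2_output:         [expanded_rows, hidden]  permuted FC2 output (one row per token/expert pair)
    # router_scales:      [expanded_rows]          routing weight applied to each expanded row
    # source_token_index: [expanded_rows]          original token each expanded row belongs to
    out = np.zeros((num_output_tokens, fc2_output.shape[1]), dtype=fc2_output.dtype)
    for row in range(fc2_output.shape[0]):
        contrib = router_scales[row] * fc2_output[row]
        if use_reduction:
            out[source_token_index[row]] += contrib  # top-k contributions reduce into one token
        else:
            out[source_token_index[row]] = contrib   # no reduction: plain scatter
    return out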

@ -235,7 +219,7 @@ struct TmaWarpSpecializedGroupedGemmInput
uint8_t* gemm_workspace = nullptr;
size_t gemm_workspace_size = 0;

static std::array<size_t, 17> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type);
static std::array<size_t, 20> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type);

static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type);

@ -247,9 +231,7 @@ struct TmaWarpSpecializedGroupedGemmInput
return stride_a != nullptr && ptr_a != nullptr;
}

void setFinalizeFusionParams(void* final_output, float const* router_scales,
int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
int num_output_tokens);
void setFinalizeFusionParams(void* final_output, int hidden_size, int num_output_tokens, bool use_reduction);

std::string toString() const;
};

@ -426,7 +426,7 @@ public:
ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config, bool use_lora,
int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool use_lora,
LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
= 0;

@ -450,8 +450,8 @@ public:
void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
QuantParams quant_params, float const* const token_topk_unpermuted_scales,
float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,

@ -468,7 +468,8 @@ public:
void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1,
void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream)
void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales,
int const* permuted_row_to_unpermuted_row, cudaStream_t stream)
= 0;

virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>

@ -485,13 +486,13 @@ public:
virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0;

bool is_profiler = false;
bool use_deterministic_hopper_reduce_ = false;
bool use_fused_finalize_ = true;
};

// Assumes inputs activations are row major. Weights need to be preprocessed by th_op/weight_quantize.cc .
// Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive.
// Avoid making several duplicates of this class.
template <typename T, /*The type used for activations*/
template <typename T, /* The type used for activations */
typename WeightType, /* The type for the MoE weights */
typename OutputType = T, /* The type for the MoE final output */
typename InputType = T, /* The type for the MoE input */

@ -573,7 +574,7 @@ public:
ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config, bool use_lora,
int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool use_lora,
LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;

@ -603,8 +604,8 @@ public:
ScaleBiasType const* const fc2_expert_biases, ScaleBiasType const* const fc2_int_scales,
float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
QuantParams quant_params, float const* const token_topk_unpermuted_scales,
float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,

@ -639,8 +640,8 @@ public:
void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
QuantParams quant_params, float const* const token_topk_unpermuted_scales,
float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,

@ -653,11 +654,11 @@ public:
static_cast<OutputType*>(final_output), expert_first_token_offset, tma_ws_input_template,
static_cast<WeightType const*>(fc2_expert_weights), static_cast<ScaleBiasType const*>(fc2_expert_biases),
static_cast<ScaleBiasType const*>(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params,
token_topk_unpermuted_scales, token_topk_permuted_scales, expanded_source_row_to_expanded_dest_row,
expanded_dest_row_to_expanded_source_row, expert_for_source_row, num_valid_tokens_ptr, num_rows,
expanded_num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array,
use_lora, fc2_lora, stream, parallelism_config, config, min_latency_mode, num_active_experts_per,
active_expert_global_ids, start_expert);
token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row,
permuted_row_to_unpermuted_row, expert_for_source_row, num_valid_tokens_ptr, num_rows, expanded_num_rows,
hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora,
stream, parallelism_config, config, min_latency_mode, num_active_experts_per, active_expert_global_ids,
start_expert);
}

virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override

@ -673,7 +674,8 @@ public:
void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1,
void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) override
void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales,
int const* permuted_row_to_unpermuted_row, cudaStream_t stream) override
{
return Self::computeStridesTmaWarpSpecialized(expert_first_token_offset, layout_info1, layout_info2, num_tokens,
expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node,

@ -682,7 +684,8 @@ public:
alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, quant_params,
reinterpret_cast<ScaleBiasType const*>(bias1), reinterpret_cast<ScaleBiasType const*>(bias2),
reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output),
reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), stream);
reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), router_scales, permuted_row_to_unpermuted_row,
stream);
}

std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>

@ -724,7 +727,8 @@ private:
float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params,
ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output,
UnfusedGemmOutputType* gemm2_output, cudaStream_t stream);
UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row,
cudaStream_t stream);
static std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1,
TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k,

@ -754,8 +758,8 @@ private:

bool mayHaveFinalizeFused() const
{
return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() == 90
&& !use_deterministic_hopper_reduce_ && !use_w4afp8;
return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && use_fused_finalize_
&& !use_w4afp8;
}

// TODO: This should eventually take the quant params to give more flexibility

@ -791,7 +795,7 @@ private:
static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output,
OutputType* const final_output, int64_t const* const expert_first_token_offset,
WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases,
float const* const token_topk_unpermuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* const expert_for_source_row, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows,
int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, int64_t const k, MOEParallelismConfig parallelism_config,

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3767deac592a204493b09f6798d50269c90d4571971b1746a5e6d0009a6d6d65
size 64229720
oid sha256:6489751f16a4dadf42664738ded03fbbd60195619f2d5f80af8190554318257d
size 66872936

@ -1,2 +1,2 @@
f68113dae0236968594276bf4f8b0a6f9161d3fbbac6fcb9ea1a438d16055490 libtensorrt_llm_internal_cutlass_kernels_static.a
commit 995030f9b86258f3db876df6b1dbc46a7c5dae50
813c237a565664b2acf2313f0e436f66f24deeb16a84d273dc007af55795e55f libtensorrt_llm_internal_cutlass_kernels_static.a
commit 9c0a42825905952beaf9b35d5a35d58de1a123fa

@ -334,12 +334,13 @@ void MixtureOfExpertsPlugin::init()
static_cast<int>(mType), static_cast<int>(mWeightType), static_cast<int>(mOutputType));
}

mMOERunner->use_deterministic_hopper_reduce_ = mExpertsPerToken > 2 && mUseDeterministicKernels;
mMOERunner->use_fused_finalize_
= (mExpertsPerToken < 3 || !mUseDeterministicKernels) && !getEnvMOEDisableFinalizeFusion();

mGemmId1 = GemmIDMoe{1, mNumExperts, mExpertsPerToken, mParallelismConfig, mExpertHiddenSize, mExpertInterSize,
mGroupSize, mActivationType, mType, mWeightType, mQuantMode, mMOERunner->use_deterministic_hopper_reduce_};
mGroupSize, mActivationType, mType, mWeightType, mQuantMode, !mMOERunner->use_fused_finalize_};
mGemmId2 = GemmIDMoe{2, mNumExperts, mExpertsPerToken, mParallelismConfig, mExpertHiddenSize, mExpertInterSize,
mGroupSize, mActivationType, mType, mWeightType, mQuantMode, mMOERunner->use_deterministic_hopper_reduce_};
mGroupSize, mActivationType, mType, mWeightType, mQuantMode, !mMOERunner->use_fused_finalize_};
mGemmProfiler->setMaxProfileM(16384 * mNumExperts / mExpertsPerToken);

if (hasLora())
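
Editor's aside: condensing the two gating changes above (mayHaveFinalizeFused() and MixtureOfExpertsPlugin::init()) into one place. This is a Python restatement for readability, not the C++ API.

# Fused finalize is used when top-k <= 2 or deterministic kernels are off, and the
# environment override checked by getEnvMOEDisableFinalizeFusion() is not set.
def use_fused_finalize(experts_per_token, use_deterministic_kernels, env_disable_finalize_fusion):
    return (experts_per_token < 3 or not use_deterministic_kernels) and not env_disable_finalize_fusion

# The fused path is now reachable on SM >= 90 (previously SM 90 only), unless w4afp8 is in use.
def may_have_finalize_fused(supports_tma_ws, sm, fused_finalize_enabled, use_w4afp8):
    return supports_tma_ws and sm >= 90 and fused_finalize_enabled and not use_w4afp8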

@ -95,7 +95,8 @@ public:
};

FusedMoeRunner(c10::ScalarType activation_dtype, c10::ScalarType weight_dtype, c10::ScalarType output_dtype,
bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_mxfp8_act_scaling)
bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_mxfp8_act_scaling,
bool use_fused_finalize)
{
mActivationDtype = activation_dtype;
mWeightDtype = weight_dtype;

@ -103,6 +104,7 @@ public:
mUseDeepSeekFP8BlockScaling = use_deepseek_fp8_block_scale;
mUseW4GroupScaling = use_w4_group_scaling;
mUseMxfp8ActScaling = use_mxfp8_act_scaling;
mUseFusedFinalize = use_fused_finalize;
mInnerDimMultiplier = 1;

// keep consistent with cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp

@ -213,6 +215,8 @@ public:
<< ", Output: " << torch::toString(mOutputDtype));
}

mKernelRunner->use_fused_finalize_ = mUseFusedFinalize;

mProfiler = std::make_shared<kernels::GemmProfilerBackend>();
mAllProfiles = mKernelRunner->getTactics();
}

@ -674,6 +678,7 @@ private:
bool mUseDeepSeekFP8BlockScaling = false;
bool mUseW4GroupScaling = false;
bool mUseMxfp8ActScaling = false;
bool mUseFusedFinalize = true;

using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
std::vector<Profile> mAllProfiles;

@ -1045,7 +1050,7 @@ private:
TORCH_LIBRARY(trtllm, m)
{
m.class_<torch_ext::FusedMoeRunner>("FusedMoeRunner")
.def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool>())
.def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool, bool>())
.def("run_gemm_profile", &torch_ext::FusedMoeRunner::runGemmProfile)
.def("get_tactic_num", &torch_ext::FusedMoeRunner::getTacticNum)
.def("run_moe", &torch_ext::FusedMoeRunner::runMoe)

@ -1,3 +1,19 @@
/*
 * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h"

@ -355,7 +371,7 @@ protected:
float mSparseMixerEpsilon = 0.2f;

// Default this to true. This only matters for K>2, and so by doing this we will test the fused and unfused paths
bool mUseDeterminsiticHopperReduce = true;
bool mUseDeterministicHopperReduce = true;

// Disable this for long running tests to speed up runtime
bool mIsLongTest = false;

@ -440,7 +456,7 @@ protected:
{
managed_buffers.clear();

mMoERunner.use_deterministic_hopper_reduce_ = k > 2 && mUseDeterminsiticHopperReduce;
mMoERunner.use_fused_finalize_ = k < 3 || !mUseDeterministicHopperReduce;

mHiddenSize = hidden_size;
mInterSize = hidden_size * mInterSizeFraction;

@ -1614,7 +1630,7 @@ void MixtureOfExpertsTest<TypeParam_>::BasicPermuteTest(
runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k);
bool should_be_deterministic
= mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
= mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
if (should_be_deterministic && !mIsLongTest)
{
auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);

@ -1733,7 +1749,7 @@ TYPED_TEST(MixtureOfExpertsTest, PermuteSwigluBias)

TYPED_TEST(MixtureOfExpertsTest, PermuteNonDeterministic)
{
this->mUseDeterminsiticHopperReduce = false;
this->mUseDeterministicHopperReduce = false;
// Just test case 3, cases 1&2 always use the fused paths
this->BasicPermuteTest(3);
}

@ -1881,7 +1897,7 @@ void MixtureOfExpertsTest<TypeParam_>::ParallelismTest(
runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k,
MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall);
bool should_be_deterministic
= mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
= mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
if (should_be_deterministic && !mIsLongTest)
{
auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);

@ -1897,7 +1913,7 @@ void MixtureOfExpertsTest<TypeParam_>::ParallelismTest(
{
runMoEPermute(MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall);
bool should_be_deterministic
= mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
= mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
if (should_be_deterministic && !mIsLongTest)
{
auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);

@ -68,7 +68,7 @@ ignore_patterns = [
[tool.codespell]
skip = ".git,3rdparty,tests/integration/test_input_files**,**.jsonl,**.json"
exclude-file = "examples/models/core/whisper/tokenizer.py"
ignore-words-list = "rouge,inout,atleast,strat,nd,subtile,thrid,improbe,NotIn,te,iteract,anythin,tru,Tracin,vEw"
ignore-words-list = "rouge,inout,atleast,strat,nd,subtile,thrid,improbe,NotIn,te,iteract,anythin,tru,Tracin,vEw,dOut"

[tool.autoflake]
in-place = true

@ -458,7 +458,7 @@ def _register_fake():
gemm2_output: torch.Tensor,
fc2_expert_biases: torch.Tensor,
unpermuted_final_scales: torch.Tensor,
expanded_source_row_to_expanded_dest_row: torch.Tensor,
unpermuted_row_to_permuted_row: torch.Tensor,
expert_for_source_row: torch.Tensor,
expert_first_token_offset_tensor: torch.Tensor,
num_rows: torch.SymInt,

@ -42,6 +42,7 @@ class MoERunner(TunableRunner):
use_w4_group_scaling: bool,
use_mxfp8_act_scaling: bool,
min_latency_mode: bool,
use_fused_finalize: bool,
):
self.x_dtype = x_dtype
self.weight_dtype = weight_dtype

@ -59,6 +60,8 @@ class MoERunner(TunableRunner):
self.use_w4_group_scaling = use_w4_group_scaling
self.use_mxfp8_act_scaling = use_mxfp8_act_scaling
self.min_latency_mode = min_latency_mode
self.use_fused_finalize = use_fused_finalize

instance_key = (x_dtype, weight_dtype, output_dtype,
use_deepseek_fp8_block_scale, use_w4_group_scaling,
use_mxfp8_act_scaling)

@ -68,7 +71,7 @@ class MoERunner(TunableRunner):
instance_key] = torch.classes.trtllm.FusedMoeRunner(
x_dtype, weight_dtype, output_dtype,
use_deepseek_fp8_block_scale, use_w4_group_scaling,
use_mxfp8_act_scaling)
use_mxfp8_act_scaling, use_fused_finalize)
self.fused_moe_runner = MoERunner.runner_dict[instance_key]

def get_valid_tactics(

@ -143,6 +146,7 @@ def fused_moe(
use_w4_group_scaling: bool = False,
use_mxfp8_act_scaling: bool = False,
min_latency_mode: bool = False,
use_fused_finalize: bool = True,
tune_max_num_tokens: int = 8192,
tuner_num_tokens: Optional[int] = None,
tuner_top_k: Optional[int] = None,

@ -179,6 +183,7 @@ def fused_moe(
use_w4_group_scaling=use_w4_group_scaling,
use_mxfp8_act_scaling=use_mxfp8_act_scaling,
min_latency_mode=min_latency_mode,
use_fused_finalize=use_fused_finalize,
)

_, gemm_tactic_1 = tuner.choose_one(

@ -259,6 +264,7 @@ def _(
use_w4_group_scaling: bool = False,
use_mxfp8_act_scaling: bool = False,
min_latency_mode: bool = False,
use_fused_finalize: bool = True,
tune_max_num_tokens: int = 8192,
):
seq_len = input.shape[0]

@ -83,6 +83,9 @@ class ModelConfig(Generic[TConfig]):

attn_backend: str = 'TRTLLM'
moe_backend: str = 'CUTLASS' # options can be CUTLASS, TRTLLM
# IF true, disables FC2+finalize fusion in CUTLASS MoE backend
moe_disable_finalize_fusion: bool = False

allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO

# If true, enable min-latency mode. Currently only used for Llama4.

@ -152,6 +152,8 @@ class CutlassFusedMoE(MoE):
# If True, the router weight will be multiplied on the input rather than at the end of FC2
self.apply_router_weight_on_input = apply_router_weight_on_input

self.use_fused_finalize = not model_config.moe_disable_finalize_fusion

self._weights_created = False
if not model_config.skip_create_weights_in_init:
self.create_weights()

@ -421,6 +423,7 @@ class CutlassFusedMoE(MoE):
use_w4_group_scaling=use_w4_group_scaling,
use_mxfp8_act_scaling=use_mxfp8_act_scaling,
min_latency_mode=False,
use_fused_finalize=self.use_fused_finalize,
tune_max_num_tokens=self.tune_max_num_tokens,
tuner_num_tokens=tuner_num_tokens,
tuner_top_k=tuner_top_k,

@ -53,6 +53,8 @@ class PyTorchConfig:
attn_backend: str = 'TRTLLM'
moe_backend: str = 'CUTLASS'

moe_disable_finalize_fusion: bool = False

enable_mixed_sampler: bool = False
"""
If true, will iterate over sampling_params of each request and use the

@ -305,6 +305,8 @@ class PyTorchModelEngine(ModelEngine):
checkpoint_loader=checkpoint_loader,
attn_backend=attn_backend,
moe_backend=pytorch_backend_config.moe_backend,
moe_disable_finalize_fusion=pytorch_backend_config.
moe_disable_finalize_fusion,
load_format=pytorch_backend_config.load_format,
max_num_tokens=max_num_tokens,
moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,

@ -183,6 +183,12 @@ class MoeConfig(StrictBaseModel):
description="Configuration for MoE load balancing.",
json_schema_extra={"type": "Union[MoeLoadBalancerConfig, str]"})

disable_finalize_fusion: bool = Field(
default=False,
description=
"Disable FC2+finalize kernel fusion in CUTLASS MoE backend. Setting this to True recovers deterministic numerical behavior with top-k > 2."
)

@classmethod
def from_dict(cls, data: dict):
return cls(**data)
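
Editor's aside: a hedged usage sketch of the new disable_finalize_fusion field. The field and from_dict() are defined above; the import path is an assumption for illustration. Per the other hunks in this diff, the flag flows MoeConfig.disable_finalize_fusion -> PyTorchConfig.moe_disable_finalize_fusion -> ModelConfig.moe_disable_finalize_fusion -> CutlassFusedMoE.use_fused_finalize.

# Opt out of FC2+finalize fusion, e.g. to recover deterministic numerics with top-k > 2.
from tensorrt_llm.llmapi.llm_args import MoeConfig  # assumed module path

moe_config = MoeConfig.from_dict({"disable_finalize_fusion": True})
assert moe_config.disable_finalize_fusion is True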

@ -2365,6 +2371,7 @@ class TorchLlmArgs(BaseLlmArgs):
enable_layerwise_nvtx_marker=self.enable_layerwise_nvtx_marker,
load_format=self.load_format,
enable_min_latency=self.enable_min_latency,
moe_disable_finalize_fusion=self.moe_config.disable_finalize_fusion,
stream_interval=self.stream_interval,
force_dynamic_quantization=self.force_dynamic_quantization,
allreduce_strategy=self.allreduce_strategy,

@ -1075,7 +1075,7 @@ def test_fused_moe_nvfp4(dtype):

# compare
torch.cuda.synchronize()
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.15)

@skip_neither_ada_nor_hopper_unittest

@ -1320,7 +1320,7 @@ def test_fused_moe_mxfp4_mxpf8(moe_backend, bias):

# compare
torch.cuda.synchronize()
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.15)

@skip_non_hopper_unittest