[fix] fix tileN cannot % 16==0 & support sm89 deepgemm bmm (#5531)

Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
This commit is contained in:
CarstyYou 2025-07-10 15:16:18 +08:00 committed by GitHub
parent 7d21b55b5a
commit dc32f9ae73
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 821 additions and 1308 deletions


@@ -1,277 +0,0 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
// clang-format off
#include <cute/tensor.hpp>
#include <cutlass/layout/layout.h>
#include <cutlass/numeric_conversion.h>
// clang-format on
namespace ada_blockwise_gemm
{
template <typename Element, typename Layout, int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA;
template <typename Element, typename Layout, int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB;
// specializations for float_e4m3
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::float_e4m3_t, cutlass::layout::RowMajor, 16, 128>
{
// Smem
using smem_layout = cute::Layout<cute::Shape<cute::_8, cute::_128>, cute::Stride<cute::_128, cute::_1>>;
using SmemLayoutAtom = decltype(cute::composition(cute::Swizzle<3, 4, 3>{}, smem_layout{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::float_e4m3_t>;
// Gmem
using copy_atom = decltype(cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::float_e4m3_t>{});
using thr_layout = decltype(cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{});
using val_layout = decltype(cute::Layout<cute::Shape<cute::_1, cute::_16>>{});
using GmemTiledCopy = decltype(cute::make_tiled_copy(copy_atom{}, thr_layout{}, val_layout{}));
};
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::float_e4m3_t, cutlass::layout::ColumnMajor, 16, 128>
{
// Smem
using smem_layout = cute::Layout<cute::Shape<cute::_8, cute::_128>, cute::Stride<cute::_128, cute::_1>>;
using SmemLayoutAtom = decltype(cute::composition(cute::Swizzle<3, 4, 3>{}, smem_layout{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::float_e4m3_t>;
// Gmem
using copy_atom = decltype(cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::float_e4m3_t>{});
using thr_layout = decltype(cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{});
using val_layout = decltype(cute::Layout<cute::Shape<cute::_16, cute::_1>>{});
using GmemTiledCopy = decltype(cute::make_tiled_copy(copy_atom{}, thr_layout{}, val_layout{}));
};
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 64>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 64>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
/// Operand A - Column-major (M-major)
template <int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, 8, SizeK>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::half_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
};
template <int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, 8, SizeK>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::bfloat16_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
};
// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the
// DefaultOperands
// Operand B - Column-Major (K-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
{
};
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
{
};
// Operand B - Row-Major (N-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{
};
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{
};
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::float_e4m3_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::float_e4m3_t, cutlass::layout::RowMajor, Alignment, SizeK>
{
};
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::float_e4m3_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::float_e4m3_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{
};
//
// F16: 128-by-128-by-32 (small k-block)
//
/// Operand A - Row-major (K-Major)
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 32>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<2, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 32>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<2, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
struct CopyTraitsScaleA
{
using GmemCopyAtom = cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEALWAYS<cute::uint32_t>, float>;
// Gmem
using GmemLayoutTVScale
= cute::Layout<cute::Shape<cute::Shape<cute::_32, cute::_4>, cute::Shape<cute::_1, cute::_1>>,
cute::Stride<cute::Stride<cute::_1, cute::_0>, cute::Stride<cute::_1, cute::_1>>>;
using GmemTileShapeScale = cute::Shape<cute::_32, cute::_1>;
using GmemTiledCopyScale
= decltype(cute::make_tiled_copy_impl(GmemCopyAtom{}, GmemLayoutTVScale{}, GmemTileShapeScale{}));
// Smem
using SmemCopyAtomScale = cute::Copy_Atom<cute::UniversalCopy<float>, float>;
using SmemLayoutTVScale
= cute::Layout<cute::Shape<cute::Shape<cute::_4, cute::_8, cute::_2, cute::_2>, cute::Shape<cute::_2>>,
cute::Stride<cute::Stride<cute::_0, cute::_1, cute::_16, cute::_0>, cute::Stride<cute::_8>>>;
using SmemTileShapeScale = cute::Shape<cute::_32, cute::_1>;
using SmemTiledCopyScale
= decltype(cute::make_tiled_copy_impl(SmemCopyAtomScale{}, SmemLayoutTVScale{}, SmemTileShapeScale{}));
};
struct CopyTraitsScaleB
{
using GmemCopyAtom = cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEALWAYS<cute::uint32_t>, float>;
// Gmem
using GmemLayoutTVScale
= cute::Layout<cute::Shape<cute::Shape<cute::_32, cute::_4>, cute::Shape<cute::_1, cute::_1>>,
cute::Stride<cute::Stride<cute::_0, cute::_0>, cute::Stride<cute::_1, cute::_1>>>;
using GmemTileShapeScale = cute::Shape<cute::_1, cute::_1>;
using GmemTiledCopyScale
= decltype(cute::make_tiled_copy_impl(GmemCopyAtom{}, GmemLayoutTVScale{}, GmemTileShapeScale{}));
// Smem
using SmemCopyAtomScale = cute::Copy_Atom<cute::UniversalCopy<float>, float>;
using SmemLayoutTVScale
= cute::Layout<cute::Shape<cute::Shape<cute::_4, cute::_8, cute::_2, cute::_2>, cute::Shape<cute::_1>>,
cute::Stride<cute::Stride<cute::_0, cute::_0, cute::_0, cute::_0>, cute::Stride<cute::_0>>>;
using SmemTileShapeScale = cute::Shape<cute::_1, cute::_1>;
using SmemTiledCopyScale
= decltype(cute::make_tiled_copy_impl(SmemCopyAtomScale{}, SmemLayoutTVScale{}, SmemTileShapeScale{}));
};
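// util_convert_type (below) is used by the GEMM epilogue elsewhere in this diff to
// convert the fp32 accumulator fragment to the output element type while it is still
// held in registers.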
template <typename To_type, typename Engine, typename Layout>
CUTE_DEVICE auto util_convert_type(cute::Tensor<Engine, Layout> const& tensor)
{
using From_type = typename Engine::value_type;
constexpr int numel = decltype(cute::size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// HACK: this requires tensor to be "contiguous"
auto frag = convert_op(*reinterpret_cast<cutlass::Array<From_type, numel> const*>(tensor.data()));
return cute::make_tensor(cute::make_rmem_ptr<To_type>(&frag), tensor.layout());
}
template <typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
CUTE_DEVICE void util_copy(
TiledCopy const& tiled_copy, cute::Tensor<Engine0, Layout0> const& S, cute::Tensor<Engine1, Layout1>& D)
{
CUTE_STATIC_ASSERT_V(cute::rank(S) == cute::Int<3>{});
CUTE_STATIC_ASSERT_V(cute::rank(D) == cute::Int<3>{});
CUTE_STATIC_ASSERT_V(cute::size<0>(S) == cute::size<0>(D));
CUTE_STATIC_ASSERT_V(cute::size<1>(S) == cute::size<1>(D));
CUTE_STATIC_ASSERT_V(cute::size<2>(S) == cute::size<2>(D));
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<1>(S); ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < cute::size<2>(S); ++k)
{
cute::copy(tiled_copy, S(cute::_, m, k), D(cute::_, m, k));
}
}
}
} // namespace ada_blockwise_gemm


@@ -1,173 +0,0 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
#include <cutlass/trace.h>
#include "ada_blockwise_gemm_kernel.cuh"
#define CUTLASS_HOST_TRACE(x) \
{ \
std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; \
}
namespace ada_blockwise_gemm
{
template <typename GemmKernel>
CUTLASS_GLOBAL void run_global(typename GemmKernel::Params params)
{
// Dynamic shared memory base pointer
extern __shared__ int SharedStorageBase[];
// Declare pointer to dynamic shared memory.
typename GemmKernel::SharedStorage* shared_storage
= reinterpret_cast<typename GemmKernel::SharedStorage*>(SharedStorageBase);
GemmKernel::invoke(params, *shared_storage);
}
using namespace cutlass;
template <typename KT>
struct AdaBlockwiseGemm
{
using GemmKernel = AdaBlockwiseGemmKernel<KT>;
static constexpr int kSmemSize = GemmKernel::kSmemSize;
static constexpr int kThreadCount = GemmKernel::kThreadCount;
/// Kernel parameters object
typename GemmKernel::Params params_;
AdaBlockwiseGemm()
: params_()
{
}
using Arguments = typename GemmKernel::Arguments;
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1)
{
CUTLASS_TRACE_HOST("AdaBlockwiseGemmKernel::maximum_active_blocks()");
CUTLASS_TRACE_HOST(" kSmemSize: " << kSmemSize << " bytes");
cudaError_t result;
if (kSmemSize > (48 << 10))
{
result
= cudaFuncSetAttribute(run_global<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_HOST_TRACE(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
return -1;
}
}
int max_active_blocks = -1;
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, run_global<GemmKernel>, kThreadCount, kSmemSize);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_HOST_TRACE(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned "
"error "
<< cudaGetErrorString(result));
return -1;
}
CUTLASS_HOST_TRACE(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
Status can_implement(Arguments const& args)
{
if (kSmemSize > (48 << 10))
{
cudaError_t result
= cudaFuncSetAttribute(run_global<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_HOST_TRACE(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
return Status::kInvalid;
}
}
if (args.problem_size.n() % KT::kTileN != 0)
{
CUTLASS_HOST_TRACE(" n:" << args.problem_size.n() << " % kTileN:" << KT::kTileN << " != 0");
return Status::kInvalid;
}
if (args.problem_size.k() % KT::kTileK != 0)
{
CUTLASS_HOST_TRACE(" k:" << args.problem_size.k() << " % kTileK:" << KT::kTileK << " != 0");
return Status::kInvalid;
}
return Status::kSuccess;
}
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
{
params_ = GemmKernel::to_underlying_arguments(args);
return Status::kSuccess;
}
Status run(cudaStream_t stream = nullptr)
{
// Configure grid and block dimensions
dim3 grid = GemmKernel::get_grid_shape(params_.problem_size);
dim3 block = GemmKernel::get_block_shape();
// Launch kernel
run_global<GemmKernel><<<grid, block, kSmemSize, stream>>>(params_);
// Query for errors
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess)
{
CUTLASS_HOST_TRACE(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
};
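// A typical host-side call sequence for this adapter (a sketch based on the members
// above; problem_size, the pointers and the stream are assumed, not part of this file):
//
//   AdaBlockwiseGemm<KT> gemm;
//   typename AdaBlockwiseGemm<KT>::Arguments args(
//       problem_size, ptr_a, ptr_b, ptr_d, ptr_scale_a, ptr_scale_b);
//   if (gemm.can_implement(args) == Status::kSuccess)
//   {
//       gemm.initialize(args);
//       gemm.run(stream);
//   }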
} // namespace ada_blockwise_gemm


@@ -1,513 +0,0 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "ada_blockwise_gemm_traits.cuh"
namespace ada_blockwise_gemm
{
template <typename KT>
struct AdaBlockwiseGemmKernel
{
using Params = typename KT::Params;
using Arguments = typename KT::Arguments;
using SharedStorage = typename KT::SharedStorage;
static constexpr int kThreadCount = KT::kThreadCount;
static constexpr int kSmemSize = KT::kSmemSize;
static dim3 get_grid_shape(GemmCoord problem_size)
{
int grid_m = (problem_size.m() + KT::kTileM - 1) / KT::kTileM;
int grid_n = (problem_size.n() + KT::kTileN - 1) / KT::kTileN;
int grid_k = 1;
return dim3(grid_m, grid_n, grid_k);
}
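// Example (hypothetical problem size): M = 4096, N = 7168 with kTileM = kTileN = 128
// yields a (32, 56, 1) grid from the computation above, one CTA per 128x128 output
// tile; a ragged final tile in M is handled by the predicated copies and stores below.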
static dim3 get_block_shape()
{
return dim3(kThreadCount, 1, 1);
}
static Params to_underlying_arguments(Arguments const& args)
{
return KT::to_underlying_arguments(args);
}
// Factory invocation
CUTLASS_DEVICE
static void invoke(Params const& params, SharedStorage& shared_storage)
{
AdaBlockwiseGemmKernel op;
op(params, shared_storage);
}
CUTE_DEVICE auto gmem_tensor_init(Params const& params)
{
using X = cute::Underscore;
int const M = params.problem_size.m();
int const N = params.problem_size.n();
int const K = params.problem_size.k();
int const ScaleM = (((M + 3) >> 2) << 2); // align 4
int const ScaleN = (N + KT::ScaleGranularityN - 1) / KT::ScaleGranularityN;
int const ScaleK = (K + KT::ScaleGranularityK - 1) / KT::ScaleGranularityK;
typename KT::ElementA const* ptr_A_ = params.ptr_a;
typename KT::ElementB const* ptr_B_ = params.ptr_b;
typename KT::ElementOutput* ptr_output_ = params.ptr_output;
typename KT::ElementBlockScale const* ptr_scale_a_ = params.ptr_scale_a;
typename KT::ElementBlockScale const* ptr_scale_b_ = params.ptr_scale_b;
cute::Tensor mA_mk
= cute::make_tensor(cute::make_gmem_ptr(ptr_A_), cute::make_shape(M, K), cute::make_stride(K, cute::_1{}));
cute::Tensor mB_nk
= cute::make_tensor(cute::make_gmem_ptr(ptr_B_), cute::make_shape(N, K), cute::make_stride(K, cute::_1{}));
cute::Tensor mOutput_mn = cute::make_tensor(
cute::make_gmem_ptr(ptr_output_), cute::make_shape(M, N), cute::make_stride(N, cute::_1{}));
cute::Tensor mScaleA_mk = cute::make_tensor(
cute::make_gmem_ptr(ptr_scale_a_), cute::make_shape(ScaleM, ScaleK), cute::make_stride(cute::_1{}, ScaleM));
cute::Tensor mScaleB_nk = cute::make_tensor(
cute::make_gmem_ptr(ptr_scale_b_), cute::make_shape(ScaleN, ScaleK), cute::make_stride(ScaleK, cute::_1{}));
// partition the gmem tensor for each Cta
cute::Tensor gA_mk = cute::local_tile(mA_mk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{}); // (BLK_M, BLK_K, m, k)
cute::Tensor gB_nk = cute::local_tile(mB_nk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
cute::Tensor gScaleA_mk = cute::local_tile(mScaleA_mk, typename KT::ScalePerTileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{}); // (BLK_M, BLK_K, m, k)
cute::Tensor gScaleB_nk = cute::local_tile(mScaleB_nk, typename KT::ScalePerTileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
return cute::make_tuple(gA_mk, gB_nk, gOutput_mn, gScaleA_mk, gScaleB_nk);
}
template <class TensorAccum, class TensorScaleA, class TensorScaleB>
CUTE_DEVICE void promote(
TensorAccum& accum, TensorAccum const& temp_accum, TensorScaleA const& tCrScaleA, TensorScaleB const& tCrScaleB)
{
using AccumType = typename TensorAccum::value_type;
CUTE_UNROLL
for (int mma_m = 0; mma_m < cute::get<1>(cute::shape<0>(accum)); ++mma_m)
{
AccumType sFA = tCrScaleA(mma_m);
AccumType sFB = tCrScaleB(0);
AccumType scale = sFA * sFB;
CUTE_UNROLL
for (int mma_n = 0; mma_n < cute::get<0>(cute::shape<0>(accum)); ++mma_n)
{
CUTE_UNROLL
for (int mma_iter_m = 0; mma_iter_m < cute::size<1>(accum); ++mma_iter_m)
{
CUTE_UNROLL
for (int mma_iter_n = 0; mma_iter_n < cute::size<2>(accum); ++mma_iter_n)
{
auto coord = cute::make_coord(cute::make_coord(mma_n, mma_m), mma_iter_m, mma_iter_n);
accum(coord) += temp_accum(coord) * scale;
}
}
}
}
}
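// In effect, promote() applies the blockwise scales for one k block:
//   accum(m, n) += temp_accum(m, n) * sA(m) * sB
// where temp_accum holds the unscaled fp32 partial sums of the current k block,
// sA is the per-row A scale and sB is the single B scale covering this (n-tile, k-block).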
/// Executes one GEMM
CUTE_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
int const block_m_idx = blockIdx.x;
int const block_n_idx = blockIdx.y;
int const thread_idx = threadIdx.x;
int const residue_m = params.problem_size.m() - block_m_idx * cute::size<0>(typename KT::TileShape{});
int const residue_n = params.problem_size.n() - block_n_idx * cute::size<1>(typename KT::TileShape{});
// gmem tensor partition ..
auto [gA_mk, gB_nk, gOutput_mn, gScaleA_mk, gScaleB_nk] = gmem_tensor_init(params);
// smem tensor ..
cute::Tensor sA = cute::make_tensor(
cute::make_smem_ptr(shared_storage.smem_a.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage)
cute::Tensor sB = cute::make_tensor(
cute::make_smem_ptr(shared_storage.smem_b.data()), typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
cute::Tensor sO = cute::make_tensor(
cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N)
cute::Tensor sScaleA = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_scale_a.data()),
typename KT::SmemLayoutScaleA{}); // (BLK_M, BLK_K, Stage)
cute::Tensor sScaleB = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_scale_b.data()),
typename KT::SmemLayoutScaleB{}); // (BLK_N, BLK_K, Stage)
// (1) first step: set up the gmem -> smem loads for the A/B tiles and their scales
// (1.1) get partition for gmem -> smem
cute::Tensor gA = gA_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k)
cute::Tensor gB = gB_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
cute::Tensor gScaleA = gScaleA_mk(cute::_, cute::_, block_m_idx, cute::_);
cute::Tensor gScaleB = gScaleB_nk(cute::_, cute::_, block_n_idx, cute::_);
typename KT::GmemTiledCopyA gmem_tiled_copy_A;
typename KT::GmemTiledCopyB gmem_tiled_copy_B;
auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx);
auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx);
cute::Tensor tAgA = gmem_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k)
cute::Tensor tAsA = gmem_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,Stage)
cute::Tensor tBgB = gmem_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k)
cute::Tensor tBsB = gmem_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,Stage)
typename KT::GmemTiledCopyScaleA gmem_tiled_copy_ScaleA;
typename KT::GmemTiledCopyScaleB gmem_tiled_copy_ScaleB;
auto gmem_thr_copy_ScaleA = gmem_tiled_copy_ScaleA.get_slice(thread_idx);
auto gmem_thr_copy_ScaleB = gmem_tiled_copy_ScaleB.get_slice(thread_idx);
cute::Tensor tAgScaleA = gmem_thr_copy_ScaleA.partition_S(gScaleA); // (ACPY,ACPY_M,ACPY_K,k)
cute::Tensor tAsScaleA = gmem_thr_copy_ScaleA.partition_D(sScaleA); // (ACPY,ACPY_M,ACPY_K,Stage)
cute::Tensor tBgScaleB = gmem_thr_copy_ScaleB.partition_S(gScaleB); // (BCPY,BCPY_N,BCPY_K,k)
cute::Tensor tBsScaleB = gmem_thr_copy_ScaleB.partition_D(sScaleB); // (BCPY,BCPY_N,BCPY_K,Stage)
// Allocate predicate tensors for the m/n bounds of the A and B tile copies
cute::Tensor tApA = cute::make_tensor<bool>(
cute::make_shape(cute::size<1>(tAsA), cute::size<2>(tAsA)), cute::Stride<cute::_1, cute::_0>{});
// Construct identity layout for sA
cute::Tensor cA = make_identity_tensor(
cute::make_shape(cute::size<0>(sA), cute::size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
cute::Tensor tAcA = gmem_thr_copy_A.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
// Set predicates for m bounds
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<0>(tApA); ++m)
{
tApA(m, 0) = cute::get<0>(tAcA(0, m, 0)) < residue_m; // blk_m coord < residue_m
}
cute::Tensor tBpB = cute::make_tensor<bool>(
cute::make_shape(cute::size<1>(tBsB), cute::size<2>(tBsB)), cute::Stride<cute::_1, cute::_0>{});
// Construct identity layout for sB
cute::Tensor cB = make_identity_tensor(
cute::make_shape(cute::size<0>(sB), cute::size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
// Repeat the partitioning with identity layouts
cute::Tensor tBcB = gmem_thr_copy_B.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
// Set predicates for n bounds
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < cute::size<0>(tBpB); ++n)
{
tBpB(n, 0) = cute::get<0>(tBcB(0, n, 0)) < residue_n; // blk_n coord < residue_n
}
cute::Tensor tApSFA = cute::make_tensor<bool>(
cute::make_shape(cute::size<1>(tAsScaleA), cute::size<2>(tAsScaleA)), cute::Stride<cute::_1, cute::_0>{});
cute::Tensor tAcSFA = gmem_thr_copy_ScaleA.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<0>(tApSFA); ++m)
{
tApSFA(m, 0) = cute::get<0>(tAcSFA(0, m, 0)) < residue_m; // blk_m coord < residue_m
}
// (1.2) prefetch gmem -> smem
cute::clear(tAsA); // zero-fill the smem staging buffers so predicated-off elements read as zero
cute::clear(tBsB);
cute::clear(tAsScaleA);
cute::clear(tBsScaleB);
auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gA)); // k-tile iterator, starting from 0
int k_tile_count = cute::size<2>(gA);
CUTLASS_PRAGMA_UNROLL
for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
{
if (k_tile_count <= 0)
{
cute::clear(tApA);
cute::clear(tBpB);
cute::clear(tApSFA);
}
cute::copy_if(gmem_tiled_copy_A, tApA, tAgA(cute::_, cute::_, cute::_, *k_tile_iter),
tAsA(cute::_, cute::_, cute::_, k_pipe));
cute::copy_if(gmem_tiled_copy_B, tBpB, tBgB(cute::_, cute::_, cute::_, *k_tile_iter),
tBsB(cute::_, cute::_, cute::_, k_pipe));
cute::copy_if(gmem_tiled_copy_ScaleA, tApSFA, tAgScaleA(cute::_, cute::_, cute::_, *k_tile_iter),
tAsScaleA(cute::_, cute::_, cute::_, k_pipe));
cute::copy(gmem_tiled_copy_ScaleB, tBgScaleB(cute::_, cute::_, cute::_, *k_tile_iter),
tBsScaleB(cute::_, cute::_, cute::_, k_pipe));
cute::cp_async_fence();
k_tile_count--;
if (k_tile_count > 0)
{
++k_tile_iter;
}
}
// (1.3) get partition for rf
typename KT::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
cute::Tensor tCrA = thr_mma.partition_fragment_A(sA(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K)
cute::Tensor tCrB = thr_mma.partition_fragment_B(sB(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
cute::Tensor accum
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
cute::Tensor temp_accum
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
cute::clear(accum);
// check the shapes
CUTE_STATIC_ASSERT_V(cute::size<1>(tCrA) == cute::size<1>(accum)); // MMA_M
CUTE_STATIC_ASSERT_V(cute::size<1>(tCrB) == cute::size<2>(accum)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tCrA) == cute::size<2>(tCrB)); // MMA_K
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));
// (1.4) retile the smem and rf tensors for the copies
auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
cute::Tensor tOsA = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,Stage)
cute::Tensor tCrA_write = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K)
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsA) == cute::size<1>(tCrA_write)); // CPY_M
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsA) == cute::size<2>(tCrA_write)); // CPY_K
auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma);
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
cute::Tensor tOsB = smem_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,Stage)
cute::Tensor tCrB_write = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K)
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsB) == cute::size<1>(tCrB_write)); // CPY_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsB) == cute::size<2>(tCrB_write)); // CPY_K
typename KT::SmemCopyAtomScaleA smem_tiled_copy_ScaleA;
typename KT::SmemCopyAtomScaleB smem_tiled_copy_ScaleB;
auto smem_thr_copy_ScaleA = smem_tiled_copy_ScaleA.get_thread_slice(thread_idx);
auto smem_thr_copy_ScaleB = smem_tiled_copy_ScaleB.get_thread_slice(thread_idx);
cute::Tensor tOsScaleA = smem_thr_copy_ScaleA.partition_S(sScaleA);
cute::Tensor tCrScaleA = cute::make_fragment_like(tOsScaleA(cute::_, cute::_, cute::_, 0));
cute::Tensor tOsScaleB = smem_thr_copy_ScaleB.partition_S(sScaleB);
cute::Tensor tCrScaleB = cute::make_fragment_like(tOsScaleB(cute::_, cute::_, cute::_, 0));
// (1.5) mainloop
// Current pipe index in smem to read from
int smem_pipe_read = 0;
// Current pipe index in smem to write to
int smem_pipe_write = KT::Stages - 1;
cute::Tensor tOsA_read = tOsA(cute::_, cute::_, cute::_, smem_pipe_read);
cute::Tensor tOsB_read = tOsB(cute::_, cute::_, cute::_, smem_pipe_read);
cute::Tensor tOsScaleA_read = tOsScaleA(cute::_, cute::_, cute::_, smem_pipe_read);
cute::Tensor tOsScaleB_read = tOsScaleB(cute::_, cute::_, cute::_, smem_pipe_read);
constexpr int K_BLOCK_MAX = cute::size<2>(tCrA);
// prefetch register pipeline
if constexpr (K_BLOCK_MAX > 1)
{
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
// Prefetch the first k-tile smem -> reg
cute::copy(smem_tiled_copy_A, tOsA_read(cute::_, cute::_, cute::Int<0>{}),
tCrA_write(cute::_, cute::_, cute::Int<0>{}));
cute::copy(smem_tiled_copy_B, tOsB_read(cute::_, cute::_, cute::Int<0>{}),
tCrB_write(cute::_, cute::_, cute::Int<0>{}));
}
// k loop for mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
cute::clear(temp_accum);
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
tOsA_read = tOsA(cute::_, cute::_, cute::_, smem_pipe_read);
tOsB_read = tOsB(cute::_, cute::_, cute::_, smem_pipe_read);
tOsScaleA_read = tOsScaleA(cute::_, cute::_, cute::_, smem_pipe_read);
tOsScaleB_read = tOsScaleB(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
}
// Load A, B smem -> reg for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsA_read(cute::_, cute::_, k_block_next),
tCrA_write(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsB_read(cute::_, cute::_, k_block_next),
tCrB_write(cute::_, cute::_, k_block_next));
// Copy gmem -> smem before computing gemm on each k-pipe
if (k_block == 0)
{
cute::copy_if(gmem_tiled_copy_A, tApA, tAgA(cute::_, cute::_, cute::_, *k_tile_iter),
tAsA(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy_if(gmem_tiled_copy_B, tBpB, tBgB(cute::_, cute::_, cute::_, *k_tile_iter),
tBsB(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy_if(gmem_tiled_copy_ScaleA, tApSFA,
tAgScaleA(cute::_, cute::_, cute::_, *k_tile_iter),
tAsScaleA(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy(gmem_tiled_copy_ScaleB, tBgScaleB(cute::_, cute::_, cute::_, *k_tile_iter),
tBsScaleB(cute::_, cute::_, cute::_, smem_pipe_write));
cute::cp_async_fence();
if (k_tile_count - 1 > 0)
{
++k_tile_iter;
}
cute::copy(smem_tiled_copy_ScaleA, tOsScaleA_read, tCrScaleA);
cute::copy(smem_tiled_copy_ScaleB, tOsScaleB_read, tCrScaleB);
// Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe)
smem_pipe_write = smem_pipe_read;
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
}
// Thread-level register gemm for k_block
cute::gemm(tiled_mma, temp_accum, tCrA(cute::_, cute::_, k_block), tCrB(cute::_, cute::_, k_block),
temp_accum);
});
promote(accum, temp_accum, tCrScaleA, tCrScaleB);
}
// load tail
cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
[&](auto WaitIndex)
{
k_tile_count--;
using WaitIndex_t = decltype(WaitIndex);
cute::clear(temp_accum);
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
tOsA_read = tOsA(cute::_, cute::_, cute::_, smem_pipe_read);
tOsB_read = tOsB(cute::_, cute::_, cute::_, smem_pipe_read);
tOsScaleA_read = tOsScaleA(cute::_, cute::_, cute::_, smem_pipe_read);
tOsScaleB_read = tOsScaleB(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
__syncthreads();
}
// Load A, B smem -> reg for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsA_read(cute::_, cute::_, k_block_next),
tCrA_write(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsB_read(cute::_, cute::_, k_block_next),
tCrB_write(cute::_, cute::_, k_block_next));
if (k_block == 0)
{
cute::copy(smem_tiled_copy_ScaleA, tOsScaleA_read, tCrScaleA);
cute::copy(smem_tiled_copy_ScaleB, tOsScaleB_read, tCrScaleB);
// only update smem_pipe_read
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
}
// Thread-level register gemm for k_block
cute::gemm(tiled_mma, temp_accum, tCrA(cute::_, cute::_, k_block),
tCrB(cute::_, cute::_, k_block), temp_accum);
});
promote(accum, temp_accum, tCrScaleA, tCrScaleB);
});
// mma tail
cute::clear(temp_accum);
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
// Load A, B smem -> reg for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsA_read(cute::_, cute::_, k_block_next),
tCrA_write(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsB_read(cute::_, cute::_, k_block_next),
tCrB_write(cute::_, cute::_, k_block_next));
if (k_block == 0)
{
cute::copy(smem_tiled_copy_ScaleA, tOsScaleA_read, tCrScaleA);
cute::copy(smem_tiled_copy_ScaleB, tOsScaleB_read, tCrScaleB);
}
// Thread-level register gemm for k_block
cute::gemm(tiled_mma, temp_accum, tCrA(cute::_, cute::_, k_block), tCrB(cute::_, cute::_, k_block),
temp_accum);
});
promote(accum, temp_accum, tCrScaleA, tCrScaleB);
// (4) push all the result to smem
// (4.1) convert result from ElementAccum to ElementOutput
cute::Tensor epi = util_convert_type<KT::ElementOutput>(accum);
// (4.2) rf -> smem
auto smem_tiled_copy_R2S = cute::make_tiled_copy_C(typename KT::SmemCopyAtomR2S{}, tiled_mma);
auto smem_thr_copy_R2S = smem_tiled_copy_R2S.get_thread_slice(thread_idx);
// cute::clear(sO);
cute::Tensor tRS_rO = smem_thr_copy_R2S.retile_S(epi);
cute::Tensor tRS_sO = smem_thr_copy_R2S.partition_D(sO);
cute::copy(smem_tiled_copy_R2S, tRS_rO, tRS_sO);
__syncthreads();
// (4.3) smem -> rf
typename KT::SmemTiledCopyS2R smem_tiled_copy_S2R;
auto smem_thr_copy_S2R = smem_tiled_copy_S2R.get_thread_slice(thread_idx);
cute::Tensor tSR_sO = smem_thr_copy_S2R.partition_S(sO);
cute::Tensor tSR_rO = cute::make_tensor<KT::ElementOutput>(cute::shape(tSR_sO));
cute::copy(smem_tiled_copy_S2R, tSR_sO, tSR_rO);
__syncthreads();
// (4.4) rf -> gmem
cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx);
cute::Tensor cO = cute::make_identity_tensor(
cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{})));
auto tRG_rO = smem_thr_copy_S2R.retile_S(tSR_rO);
auto tRG_gO = smem_thr_copy_S2R.partition_D(gO);
auto tRG_cO = smem_thr_copy_S2R.partition_D(cO);
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<1>(tRG_cO); ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < cute::size<2>(tRG_cO); ++n)
{
if (cute::get<0>(tRG_cO(0, m, n)) < residue_m && cute::get<1>(tRG_cO(0, m, n)) < residue_n)
{
cute::copy(typename KT::GmemCopyAtomR2G{}, tRG_rO(cute::_, m, n), tRG_gO(cute::_, m, n));
}
}
}
}
};
} // namespace ada_blockwise_gemm


@@ -1,217 +0,0 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
// clang-format off
#include <cuda_runtime.h>
#include <cutlass/cutlass.h>
#include <cute/int_tuple.hpp>
#include <cute/layout.hpp>
#include <cutlass/arch/mma.h>
#include "ada_blockwise_mma_utils.cuh"
#include "ada_blockwise_copy_utils.cuh"
// clang-format on
using namespace cute;
using namespace cutlass;
using namespace cutlass::gemm;
namespace ada_blockwise_gemm
{
template <typename ElementType, typename OutElementType, typename AccumElementType, typename BlockScaleElementType,
int Stages_, int TileM_, int TileN_, int TileK_>
struct AdaBlockwiseGemmTraits
{
using ElementA = ElementType;
using LayoutA = cutlass::layout::RowMajor;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
using ElementB = ElementType;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
using ElementAccum = AccumElementType;
using ElementBlockScale = BlockScaleElementType;
using ElementOutput = OutElementType;
using index_t = uint32_t;
static_assert(TileM_ % 16 == 0);
static_assert(TileN_ % 32 == 0);
static_assert(TileK_ % 32 == 0);
static constexpr int Stages = Stages_;
static constexpr int kTileM = TileM_;
static constexpr int kTileN = TileN_;
static constexpr int kTileK = (kTileM > 16) ? (TileK_) : (TileK_ >= 64 ? TileK_ : 64);
// tile shape
using TileShape = cute::Shape<cute::Int<kTileM>, cute::Int<kTileN>, cute::Int<kTileK>>;
static constexpr int kWarpsCount = 4;
static constexpr int kThreadCount = kWarpsCount * 32;
static constexpr int ScaleGranularityM = 1;
static constexpr int ScaleGranularityN = 128;
static constexpr int ScaleGranularityK = 128;
static constexpr int ScaleMsPerTile = (kTileM + ScaleGranularityM - 1) / ScaleGranularityM;
static constexpr int ScaleNsPerTile = (kTileN + ScaleGranularityN - 1) / ScaleGranularityN;
static constexpr int ScaleKsPerTile = (kTileK + ScaleGranularityK - 1) / ScaleGranularityK;
static_assert(ScaleKsPerTile >= 1, "ScaleKsPerTile must be greater than or equal to 1");
using ScaleGranularity
= cute::Shape<cute::Int<ScaleGranularityM>, cute::Int<ScaleGranularityN>, cute::Int<ScaleGranularityK>>;
using ScalePerTileShape
= cute::Shape<cute::Int<ScaleMsPerTile>, cute::Int<ScaleNsPerTile>, cute::Int<ScaleKsPerTile>>;
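// With a hypothetical 128x128x128 tile this gives ScaleMsPerTile = 128 (one A scale per
// row), ScaleNsPerTile = 1 and ScaleKsPerTile = 1 (one B scale and one scale step in K
// per tile), i.e. the 1x128 (A) / 128x128 (B) blockwise-scaling scheme encoded by the
// granularity constants above.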
// MMA atom arch and layout
using TiledMma = DefaultGemm_TensorOp_MMA<cute::float_e4m3_t, cutlass::arch::Sm89>::TiledMma;
static constexpr int kBlockKSmem = 128;
// A memory copy operand
using DefaultOperandA
= DefaultGemm_TensorOpSm80_OperandA<ElementA, cutlass::layout::RowMajor, AlignmentA, kBlockKSmem>;
using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom;
using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;
// ScaleA memory copy operand
using SmemLayoutAtomScaleA = decltype(cute::make_layout(typename CopyTraitsScaleA::SmemTileShapeScale{}));
using SmemCopyAtomScaleA = typename CopyTraitsScaleA::SmemTiledCopyScale;
using GmemTiledCopyScaleA = typename CopyTraitsScaleA::GmemTiledCopyScale;
// B memory copy operand
using DefaultOperandB
= DefaultGemm_TensorOpSm80_OperandB<ElementB, cutlass::layout::ColumnMajor, AlignmentB, kBlockKSmem>;
using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom;
using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;
// ScaleB memory copy operand
using SmemLayoutAtomScaleB = decltype(cute::make_layout(typename CopyTraitsScaleB::SmemTileShapeScale{}));
using SmemCopyAtomScaleB = typename CopyTraitsScaleB::SmemTiledCopyScale;
using GmemTiledCopyScaleB = typename CopyTraitsScaleB::GmemTiledCopyScale;
// Output memory copy operand
using SmemLayoutAtomO = decltype(cute::composition(cute::Swizzle<3, 3, 3>{},
cute::Layout<cute::Shape<cute::_8, cute::Shape<cute::_8, cute::_8>>,
cute::Stride<cute::_8, cute::Stride<cute::_1, cute::_64>>>{}));
using SmemCopyAtomR2S = cute::Copy_Atom<cute::AutoVectorizingCopy, ElementOutput>;
using SmemCopyAtomS2R = cute::Copy_Atom<cute::UniversalCopy<cute::uint128_t>, ElementOutput>;
using GmemCopyAtomR2G = cute::Copy_Atom<cute::UniversalCopy<cute::uint128_t>, ElementOutput>;
using SmemThrLayoutS2R
= cute::Layout<cute::Shape<cute::Int<8>, cute::Int<16>>, cute::Stride<cute::Int<16>, cute::_1>>;
using SmemValLayoutS2R = cute::Layout<cute::Shape<cute::Int<1>, cute::Int<8>>>;
using SmemTiledCopyS2R = decltype(cute::make_tiled_copy(SmemCopyAtomS2R{}, SmemThrLayoutS2R{}, SmemValLayoutS2R{}));
static_assert(cute::rank(SmemLayoutAtomA{}) == 2);
static_assert(cute::size<0>(TileShape{}) % cute::size<0>(SmemLayoutAtomA{}) == 0); // M
static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomA{}) == 0); // K
static_assert(cute::rank(SmemLayoutAtomB{}) == 2);
static_assert(cute::size<1>(TileShape{}) % cute::size<0>(SmemLayoutAtomB{}) == 0); // N
static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomB{}) == 0); // K
using SmemLayoutA = decltype(cute::tile_to_shape(SmemLayoutAtomA{},
cute::make_shape(
cute::shape<0>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{}))); // BLK_M, BLK_K, Stages
using SmemLayoutB = decltype(cute::tile_to_shape(SmemLayoutAtomB{},
cute::make_shape(
cute::shape<1>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{}))); // BLK_N, BLK_K, Stages
using SmemLayoutO = decltype(cute::tile_to_shape(
SmemLayoutAtomO{}, cute::make_shape(cute::shape<0>(TileShape{}), cute::shape<1>(TileShape{})))); // BLK_M, BLK_N
using SmemLayoutScaleA = decltype(cute::tile_to_shape(SmemLayoutAtomScaleA{},
cute::make_shape(cute::shape<0>(ScalePerTileShape{}), cute::shape<2>(ScalePerTileShape{}),
cute::Int<Stages>{}))); // BLK_M, BLK_K, Stages
using SmemLayoutScaleB = decltype(cute::tile_to_shape(SmemLayoutAtomScaleB{},
cute::make_shape(cute::shape<1>(ScalePerTileShape{}), cute::shape<2>(ScalePerTileShape{}),
cute::Int<Stages>{}))); // BLK_N, BLK_K, Stages
// we need at least 2 stages..
static_assert(Stages >= 2);
struct SharedStorage : cute::aligned_struct<128>
{
cute::array_aligned<ElementA, cute::cosize_v<SmemLayoutA>> smem_a;
cute::array_aligned<ElementB, cute::cosize_v<SmemLayoutB>> smem_b;
cute::array_aligned<ElementOutput, cute::cosize_v<SmemLayoutO>> smem_o;
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleA>> smem_scale_a;
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleB>> smem_scale_b;
};
static constexpr int kSmemSize = static_cast<int>(sizeof(SharedStorage));
struct Params
{
GemmCoord problem_size{};
ElementA const* ptr_a{};
ElementB const* ptr_b{};
ElementOutput* ptr_output{};
BlockScaleElementType const* ptr_scale_a{};
BlockScaleElementType const* ptr_scale_b{};
Params() {}
Params(GemmCoord problem_size_, ElementA const* ptr_a_, ElementB const* ptr_b_, ElementOutput* ptr_output_,
BlockScaleElementType const* ptr_scale_a_, BlockScaleElementType const* ptr_scale_b_)
: problem_size(problem_size_)
, ptr_a(ptr_a_)
, ptr_b(ptr_b_)
, ptr_output(ptr_output_)
, ptr_scale_a(ptr_scale_a_)
, ptr_scale_b(ptr_scale_b_)
{
}
};
struct Arguments
{
GemmCoord problem_size{};
void const* ptr_a;
void const* ptr_b;
void* ptr_d;
float const* ptr_scale_a;
float const* ptr_scale_b;
Arguments(GemmCoord problem_size_, void const* ptr_a_, void const* ptr_b_, void* ptr_d_,
float const* ptr_scale_a_, float const* ptr_scale_b_)
: problem_size(problem_size_)
, ptr_a(ptr_a_)
, ptr_b(ptr_b_)
, ptr_d(ptr_d_)
, ptr_scale_a(ptr_scale_a_)
, ptr_scale_b(ptr_scale_b_)
{
}
};
static Params to_underlying_arguments(Arguments const& args)
{
auto ptr_a = reinterpret_cast<ElementA const*>(args.ptr_a);
auto ptr_b = reinterpret_cast<ElementB const*>(args.ptr_b);
auto ptr_d = reinterpret_cast<ElementOutput*>(args.ptr_d);
auto ptr_scale_a = reinterpret_cast<ElementBlockScale const*>(args.ptr_scale_a);
auto ptr_scale_b = reinterpret_cast<ElementBlockScale const*>(args.ptr_scale_b);
Params params(args.problem_size, ptr_a, ptr_b, ptr_d, ptr_scale_a, ptr_scale_b);
return params;
}
};
} // namespace ada_blockwise_gemm

View File

@ -1,104 +0,0 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cute/atom/mma_atom.hpp"
#include <cute/arch/mma.hpp>
#include <cute/config.hpp>
#include <cute/layout.hpp>
#include <cutlass/arch/mma.h>
// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890))
#define CUTE_ARCH_MMA_F32_SM89_ENABLED
#endif
namespace cute
{
// MMA 16x8x32 TN
struct SM89_16x8x32_F32F8F8F32_TN
{
using DRegisters = float[4];
using ARegisters = uint32_t[4];
using BRegisters = uint32_t[2];
using CRegisters = float[4];
CUTE_HOST_DEVICE static void fma(float& d0, float& d1, float& d2, float& d3, uint32_t const& a0, uint32_t const& a1,
uint32_t const& a2, uint32_t const& a3, uint32_t const& b0, uint32_t const& b1, float const& c0,
float const& c1, float const& c2, float const& c3)
{
#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED)
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
: "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "f"(c0), "f"(c1), "f"(c2), "f"(c3));
#else
CUTE_INVALID_CONTROL_PATH(
"Attempting to use SM89_16x8x32_F32F8F8F32_TN without "
"CUTE_ARCH_MMA_F32_SM89_ENABLED");
#endif
}
};
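// Per the fragment declarations above, each of the 32 threads in the warp carries
// 4 x 32-bit A fragments (16 e4m3 values), 2 x 32-bit B fragments (8 e4m3 values) and
// 4 fp32 accumulators, which together cover the full 16x8x32 tile across the warp.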
template <>
struct MMA_Traits<SM89_16x8x32_F32F8F8F32_TN>
{
using ValTypeD = float;
using ValTypeA = float_e4m3_t;
using ValTypeB = float_e4m3_t;
using ValTypeC = float;
using Shape_MNK = Shape<_16, _8, _32>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2, _2>>, Stride<Stride<_64, _1>, Stride<_16, _8, _256>>>;
using BLayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2>>, Stride<Stride<_32, _1>, Stride<_8, _128>>>;
using CLayout = SM80_16x8_Row;
};
} // namespace cute
namespace ada_blockwise_gemm
{
template <typename Element, typename Arch>
struct DefaultGemm_TensorOp_MMA;
template <>
struct DefaultGemm_TensorOp_MMA<cute::bfloat16_t, cutlass::arch::Sm80>
{
using ArchTag = cutlass::arch::Sm80;
using MMA_Atom_Arch = cute::MMA_Atom<cute::SM80_16x8x16_F32BF16BF16F32_TN>;
using ThreadLayoutMNK = cute::Layout<cute::Shape<cute::_2, cute::_2, cute::_1>>;
using ValLayoutMNK = cute::Tile<cute::_32, cute::_32, cute::_16>;
using TiledMma = cute::TiledMMA<MMA_Atom_Arch, ThreadLayoutMNK, ValLayoutMNK>;
};
template <>
struct DefaultGemm_TensorOp_MMA<cute::float_e4m3_t, cutlass::arch::Sm89>
{
using ArchTag = cutlass::arch::Sm89;
using MMA_Atom_Arch = cute::MMA_Atom<cute::SM89_16x8x32_F32F8F8F32_TN>;
using ThreadLayoutMNK = cute::Layout<cute::Shape<cute::_2, cute::_2, cute::_1>>;
using ValLayoutMNK = cute::Tile<cute::_32, cute::_32, cute::_32>;
using TiledMma = cute::TiledMMA<MMA_Atom_Arch, ThreadLayoutMNK, ValLayoutMNK>;
};
} // namespace ada_blockwise_gemm


@@ -0,0 +1,453 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "sm89_utils.cuh"
namespace ada_blockwise_gemm
{
template <typename GemmKernel>
CUTLASS_GLOBAL void sm89_fp8_gemm_1d1d_impl(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, void const* A,
void const* B, void* D, float const* scales_a, float const* scales_b)
{
GemmKernel op;
op.invoke(shape_m, shape_n, shape_k, A, B, D, scales_a, scales_b);
}
template <typename GemmKernel>
CUTLASS_GLOBAL void sm89_fp8_bmm_1d1d_impl(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, __nv_fp8_e4m3* A,
__nv_fp8_e4m3* B, __nv_bfloat16* D, float* scales_a, float* scales_b, uint64_t stride_a, uint64_t stride_b,
uint64_t stride_d, uint64_t stride_scales_a, uint64_t stride_scales_b)
{
GemmKernel op;
auto ptr_a = reinterpret_cast<typename GemmKernel::ElementInput const*>(A + blockIdx.z * stride_a);
auto ptr_b = reinterpret_cast<typename GemmKernel::ElementInput const*>(B + blockIdx.z * stride_b);
auto ptr_scale_a
= reinterpret_cast<typename GemmKernel::ElementBlockScale const*>(scales_a + blockIdx.z * stride_scales_a);
auto ptr_scale_b
= reinterpret_cast<typename GemmKernel::ElementBlockScale const*>(scales_b + blockIdx.z * stride_scales_b);
auto ptr_output = reinterpret_cast<typename GemmKernel::ElementOutput*>(D + blockIdx.z * stride_d);
op(ptr_a, ptr_b, ptr_scale_a, ptr_scale_b, ptr_output, shape_m, shape_n, shape_k);
}
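// A minimal host-side launch sketch for the batched kernel above (illustration only,
// not part of this commit): batches map to gridDim.z and the per-batch base pointers
// are offset by the element strides passed in. ceil_div, kThreadCount, kSmemSize and
// the stride expressions below are assumptions.
//
//   dim3 grid(ceil_div(m, KT::kTileM), ceil_div(n, KT::kTileN), batch_count);
//   sm89_fp8_bmm_1d1d_impl<AdaBlockwiseGemmKernel<KT>>
//       <<<grid, KT::kThreadCount, KT::kSmemSize, stream>>>(m, n, k, A, B, D, scales_a,
//           scales_b, /*stride_a=*/uint64_t(m) * k, /*stride_b=*/uint64_t(n) * k,
//           /*stride_d=*/uint64_t(m) * n, stride_scales_a, stride_scales_b);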
template <typename KT>
struct AdaBlockwiseGemmKernel
{
using SharedStorage = typename KT::SharedStorage;
using ElementInput = typename KT::ElementInput;
using ElementOutput = typename KT::ElementOutput;
using ElementBlockScale = typename KT::ElementBlockScale;
// Factory invocation
CUTLASS_DEVICE
void invoke(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, void const* A, void const* B, void* D,
float const* scales_a, float const* scales_b)
{
auto ptr_a = reinterpret_cast<ElementInput const*>(A);
auto ptr_b = reinterpret_cast<ElementInput const*>(B);
auto ptr_scale_a = reinterpret_cast<ElementBlockScale const*>(scales_a);
auto ptr_scale_b = reinterpret_cast<ElementBlockScale const*>(scales_b);
auto ptr_output = reinterpret_cast<ElementOutput*>(D);
(*this)(ptr_a, ptr_b, ptr_scale_a, ptr_scale_b, ptr_output, shape_m, shape_n, shape_k);
}
CUTE_DEVICE auto gmem_tensor_init(typename KT::ElementInput const* ptr_a, typename KT::ElementInput const* ptr_b,
typename KT::ElementBlockScale const* ptr_scale_a, typename KT::ElementBlockScale const* ptr_scale_b,
uint32_t M, uint32_t N, uint32_t K, int* SharedStorageBase)
{
using X = cute::Underscore;
uint32_t const ScaleM = (((M + 3) >> 2) << 2); // align 4
uint32_t const ScaleN = (N + KT::ScaleGranularityN - 1) / KT::ScaleGranularityN;
uint32_t const ScaleK = (K + KT::ScaleGranularityK - 1) / KT::ScaleGranularityK;
auto mA_mk
= cute::make_tensor(cute::make_gmem_ptr(ptr_a), cute::make_shape(M, K), cute::make_stride(K, cute::_1{}));
auto mB_nk
= cute::make_tensor(cute::make_gmem_ptr(ptr_b), cute::make_shape(N, K), cute::make_stride(K, cute::_1{}));
auto mSFA_mk = cute::make_tensor(
cute::make_gmem_ptr(ptr_scale_a), cute::make_shape(ScaleM, ScaleK), cute::make_stride(cute::_1{}, ScaleM));
auto mSFB_nk = cute::make_tensor(
cute::make_gmem_ptr(ptr_scale_b), cute::make_shape(ScaleN, ScaleK), cute::make_stride(ScaleK, cute::_1{}));
auto cta_coord = cute::make_coord(blockIdx.x, blockIdx.y, cute::_); // (m,n,k)
auto gA
= cute::local_tile(mA_mk, typename KT::TileShape{}, cta_coord, cute::Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
auto gB
= cute::local_tile(mB_nk, typename KT::TileShape{}, cta_coord, cute::Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
auto gSFA = cute::local_tile(
mSFA_mk, typename KT::ScalePerTileShape{}, cta_coord, cute::Step<_1, X, _1>{}); // (BLK_M,BLK_K)
auto gSFB = cute::local_tile(
mSFB_nk, typename KT::ScalePerTileShape{}, cta_coord, cute::Step<X, _1, _1>{}); // (BLK_N,BLK_K)
typename KT::SharedStorageLoad* load_storage
= reinterpret_cast<typename KT::SharedStorageLoad*>(SharedStorageBase);
auto sA = cute::make_tensor(cute::make_smem_ptr(load_storage->smem_a.data()), typename KT::SmemLayoutA{});
auto sB = cute::make_tensor(cute::make_smem_ptr(load_storage->smem_b.data()), typename KT::SmemLayoutB{});
auto sSFA = cute::make_tensor(cute::make_smem_ptr(load_storage->smem_sfa.data()), typename KT::SmemLayoutSFA{});
auto sSFB = cute::make_tensor(cute::make_smem_ptr(load_storage->smem_sfb.data()), typename KT::SmemLayoutSFB{});
return cute::make_tuple(gA, gB, gSFA, gSFB, sA, sB, sSFA, sSFB);
}
template <class Accumulator, class SharedStorage, class ElementOutput>
CUTE_DEVICE void epilogue_with_smem(
Accumulator& accum, SharedStorage& shared_storage, ElementOutput* o, int M, int N)
{
// convert type
auto epi = cute::make_fragment_like<ElementOutput>(accum);
cute::for_each(cute::make_int_sequence<cute::size(epi)>{}, [&](auto i) { epi(i) = ElementOutput(accum(i)); });
auto sO = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{});
// copy rf -> smem
typename KT::TiledMma mma;
auto tiled_copy_R2S = cute::make_tiled_copy_C(typename KT::SmemCopyAtomR2S{}, mma);
auto thr_copy_R2S = tiled_copy_R2S.get_slice(threadIdx.x);
auto tRS_rO = thr_copy_R2S.retile_S(epi);
auto tRS_sO = thr_copy_R2S.partition_D(sO);
cute::copy(tiled_copy_R2S, tRS_rO, tRS_sO);
__syncthreads();
// copy smem -> rf
typename KT::TiledCopyS2G tiled_copy_S2G;
auto thr_copy_S2G = tiled_copy_S2G.get_slice(threadIdx.x);
auto tSR_sO = thr_copy_S2G.partition_S(sO);
auto tSR_rO = cute::make_tensor<KT::ElementOutput>(cute::shape(tSR_sO));
cute::copy(tiled_copy_S2G, tSR_sO, tSR_rO);
__syncthreads();
// copy rf -> gmem
auto mO = cute::make_tensor(cute::make_gmem_ptr(o), cute::make_shape(M, N), cute::make_stride(N, cute::_1{}));
auto cta_coord = cute::make_coord(blockIdx.x, blockIdx.y, cute::_);
auto gO = cute::local_tile(mO, typename KT::TileShape{}, cta_coord, cute::Step<cute::_1, cute::_1, X>{});
auto cO = cute::make_identity_tensor(cute::make_shape(cute::Int<KT::kTileM>{}, cute::Int<KT::kTileN>{}));
auto tRG_rO = thr_copy_S2G.retile_S(tSR_rO);
auto tRG_gO = thr_copy_S2G.partition_D(gO);
auto tRG_cO = thr_copy_S2G.partition_D(cO);
int residue_m = M - KT::kTileM * blockIdx.x;
int residue_n = N - KT::kTileN * blockIdx.y;
CUTE_UNROLL
for (int m = 0; m < cute::size<1>(tRG_gO); ++m)
{
CUTE_UNROLL
for (int n = 0; n < cute::size<2>(tRG_gO); ++n)
{
if (cute::get<0>(tRG_cO(0, m, n)) < residue_m && cute::get<1>(tRG_cO(0, m, n)) < residue_n)
{
cute::copy(typename KT::GmemCopyAtomR2G{}, tRG_rO(cute::_, m, n), tRG_gO(cute::_, m, n));
}
}
}
}
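// The epilogue above stages the accumulators through shared memory (register -> smem,
// then smem -> register under a different thread mapping) so the final global store
// can be issued as wide, row-contiguous accesses; the residue_m/residue_n checks
// handle ragged tiles at the M/N edges.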
template <class TensorD, class TensorC, class TensorScale, class Index>
CUTE_DEVICE void promote(TensorD& accum, TensorC const& temp_accum, TensorScale const& scale, Index n_block)
{
using AccumType = typename TensorD::value_type;
for (int mma_m = 0; mma_m < cute::get<1>(cute::shape<0>(accum)); ++mma_m)
{
CUTE_UNROLL
for (int mma_n = 0; mma_n < cute::get<0>(cute::shape<0>(accum)); ++mma_n)
{
CUTE_UNROLL
for (int mma_iter_m = 0; mma_iter_m < cute::size<1>(accum); ++mma_iter_m)
{
CUTE_UNROLL
for (int mma_iter_n = 0; mma_iter_n < cute::size<2>(accum); ++mma_iter_n)
{
auto coord_d
= cute::make_coord(cute::make_coord(mma_n, mma_m), mma_iter_m, mma_iter_n, n_block);
auto coord_c = cute::make_coord(cute::make_coord(mma_n, mma_m), mma_iter_m, mma_iter_n);
accum(coord_d) += temp_accum(coord_c) * scale(mma_m, mma_iter_m, cute::_0{});
}
}
}
}
}
/// Executes one GEMM
CUTE_DEVICE
void operator()(typename KT::ElementInput const* ptr_a, typename KT::ElementInput const* ptr_b,
typename KT::ElementBlockScale const* ptr_scale_a, typename KT::ElementBlockScale const* ptr_scale_b,
typename KT::ElementOutput* ptr_output, uint32_t M, uint32_t N, uint32_t K)
{
// Dynamic shared memory base pointer
extern __shared__ int SharedStorageBase[];
auto [gA, gB, gSFA, gSFB, sA, sB, sSFA, sSFB]
= gmem_tensor_init(ptr_a, ptr_b, ptr_scale_a, ptr_scale_b, M, N, K, SharedStorageBase);
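// per-thread partitioning of the gmem -> smem (cp.async) copies for A, B and their block scales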
typename KT::GmemTiledCopyA g2s_copy_A;
typename KT::GmemTiledCopyB g2s_copy_B;
auto g2s_thr_copy_A = g2s_copy_A.get_slice(threadIdx.x);
auto g2s_thr_copy_B = g2s_copy_B.get_slice(threadIdx.x);
auto tAgA = g2s_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k)
auto tAsA = g2s_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,Stage)
auto tBgB = g2s_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k)
auto tBsB = g2s_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,Stage)
typename KT::GmemTiledCopySFA g2s_copy_SFA;
typename KT::GmemTiledCopySFB g2s_copy_SFB;
auto g2s_thr_copy_SFA = g2s_copy_SFA.get_slice(threadIdx.x);
auto g2s_thr_copy_SFB = g2s_copy_SFB.get_slice(threadIdx.x);
auto tAgSFA = g2s_thr_copy_SFA.partition_S(gSFA); // (ACPY,ACPY_M,ACPY_K,k)
auto tAsSFA = g2s_thr_copy_SFA.partition_D(sSFA); // (ACPY,ACPY_M,ACPY_K,Stage)
auto tBgSFB = g2s_thr_copy_SFB.partition_S(gSFB); // (BCPY,BCPY_N,BCPY_K,k)
auto tBsSFB = g2s_thr_copy_SFB.partition_D(sSFB); // (BCPY,BCPY_N,BCPY_K,Stage)
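// identity tensors carry the logical (m,k)/(n,k) coordinate of each copied element; the predicate
// tensors built below use them to mask rows/columns beyond the M/N residue of a partial tile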
auto cA = cute::make_identity_tensor(cute::make_shape(cute::size<0>(sA), cute::size<1>(sA)));
auto tAcA = g2s_thr_copy_A.partition_S(cA);
auto cB = cute::make_identity_tensor(cute::make_shape(cute::size<0>(sB), cute::size<1>(sB)));
auto tBcB = g2s_thr_copy_B.partition_S(cB);
auto cSFA = cute::make_identity_tensor(typename KT::GmemTiledCopySFA::Tiler_MN{});
auto tAcSFA = g2s_thr_copy_SFA.partition_S(cSFA);
int residue_m = M - KT::kTileM * blockIdx.x;
int residue_n = N - KT::kTileN * blockIdx.y;
residue_m = residue_m > KT::kTileM ? KT::kTileM : residue_m;
residue_n = residue_n > KT::kTileN ? KT::kTileN : residue_n;
auto tApA = cute::make_tensor<bool>(
cute::make_shape(cute::size<1>(tAsA), cute::size<2>(tAsA)), cute::Stride<cute::_1, cute::_0>{});
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<0>(tApA); ++m)
{
tApA(m, 0) = cute::get<0>(tAcA(0, m, 0)) < residue_m; // blk_m coord < residue_m
}
auto tBpB = cute::make_tensor<bool>(
cute::make_shape(cute::size<1>(tBsB), cute::size<2>(tBsB)), cute::Stride<cute::_1, cute::_0>{});
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < cute::size<0>(tBpB); ++n)
{
tBpB(n, 0) = cute::get<0>(tBcB(0, n, 0)) < residue_n; // blk_n coord < residue_n
}
auto tApSFA = cute::make_tensor<bool>(
cute::make_shape(cute::size<1>(tAsSFA), cute::size<2>(tAsSFA)), cute::Stride<cute::_1, cute::_0>{});
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<0>(tApSFA); ++m)
{
tApSFA(m, 0) = cute::get<0>(tAcSFA(0, m, 0)) < residue_m; // blk_m coord < residue_m
}
// prefetch gmem A/B
cute::clear(tAsA);
cute::clear(tBsB);
cute::clear(tAsSFA);
int k_tile_count = cute::size<2>(gA);
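// prologue: issue the first (Stages - 1) cp.async stages; if there are fewer k tiles than
// stages, the predicates are cleared (A/B/SFA copies are skipped) and the k index is clamped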
CUTLASS_PRAGMA_NO_UNROLL
for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
{
if (k_pipe >= k_tile_count)
{
cute::clear(tApA);
cute::clear(tBpB);
cute::clear(tApSFA);
}
auto k_tile_iter = cute::min(k_pipe, k_tile_count - 1);
cute::copy_if(g2s_copy_A, tApA, tAgA(cute::_, cute::_, cute::_, k_tile_iter),
tAsA(cute::_, cute::_, cute::_, k_pipe));
cute::copy_if(g2s_copy_B, tBpB, tBgB(cute::_, cute::_, cute::_, k_tile_iter),
tBsB(cute::_, cute::_, cute::_, k_pipe));
cute::copy_if(g2s_copy_SFA, tApSFA, tAgSFA(cute::_, cute::_, cute::_, k_tile_iter),
tAsSFA(cute::_, cute::_, cute::_, k_pipe));
cute::copy(g2s_copy_SFB, tBgSFB(cute::_, cute::_, cute::_, k_tile_iter),
tBsSFB(cute::_, cute::_, cute::_, k_pipe));
cute::cp_async_fence();
}
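// set up register fragments and smem -> register (ldmatrix) copies; B fragments are kept per
// 32-wide N group and loaded one group ahead of the MMA so smem loads overlap with math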
typename KT::TiledMma mma;
auto thr_mma = mma.get_slice(threadIdx.x);
auto accum = cute::partition_fragment_C(mma,
cute::make_shape(cute::Int<KT::kTileM>{}, cute::Int<KT::kMmaPermN>{},
cute::Int<KT::NUM_GROUP_N>{})); // (MMA,MMA_M,MMA_N,NUM_GROUP_N)
auto temp = cute::partition_fragment_C(
mma, cute::make_shape(cute::Int<KT::kTileM>{}, cute::Int<KT::kMmaPermN>{})); // (MMA,MMA_M,MMA_N)
auto mma_shape_A
= cute::partition_shape_A(mma, cute::make_shape(cute::Int<KT::kTileM>{}, cute::Int<KT::kTileK>{}));
auto tCrA = cute::make_tensor<typename KT::ElementInput>(mma_shape_A);
auto mma_shape_B = cute::partition_shape_B(
mma, cute::make_shape(cute::Int<KT::kMmaPermN>{}, cute::Int<KT::kTileK>{}, cute::Int<KT::NUM_GROUP_N>{}));
auto tCrB = cute::make_tensor<typename KT::ElementInput>(mma_shape_B);
auto s2r_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, mma);
auto s2r_thr_copy_A = s2r_copy_A.get_slice(threadIdx.x);
auto tXsA = s2r_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,Stage)
auto tXrA = s2r_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K)
static_assert(cute::is_static<decltype(tXrA.layout())>::value, "tXrA layout must be static");
auto s2r_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, mma);
auto s2r_thr_copy_B = s2r_copy_B.get_slice(threadIdx.x);
auto tXsB = s2r_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,Stage)
auto tXrB = s2r_thr_copy_B.retile_D(tCrB)(cute::_, cute::Int<0>{}, cute::_, cute::_);
typename KT::SmemTiledCopySFA s2r_copy_SFA;
typename KT::SmemTiledCopySFB s2r_copy_SFB;
auto s2r_thr_copy_SFA = s2r_copy_SFA.get_slice(threadIdx.x);
auto s2r_thr_copy_SFB = s2r_copy_SFB.get_slice(threadIdx.x);
auto tXsSFA = s2r_thr_copy_SFA.partition_S(sSFA);
auto tXrSFA = cute::make_fragment_like(tXsSFA(cute::_, cute::_, cute::_, 0));
auto tXsSFB = s2r_thr_copy_SFB.partition_S(sSFB);
auto tXrSFB = cute::make_fragment_like(tXsSFB(cute::_, cute::_, cute::_, 0));
auto scale = cute::make_fragment_like(tXrSFA);
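// smem_pipe_write is the next stage to be filled by cp.async, smem_pipe_read the stage the MMA
// consumes next; both wrap modulo Stages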
int smem_pipe_write = KT::Stages - 1;
int smem_pipe_read = 0;
auto tXsA_read = tXsA(cute::_, cute::_, cute::_, smem_pipe_read);
auto tXsB_read = tXsB(cute::_, cute::_, cute::_, smem_pipe_read);
auto tXsSFA_read = tXsSFA(cute::_, cute::_, cute::_, smem_pipe_read);
auto tXsSFB_read = tXsSFB(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
// prefetch smem -> rf
cute::copy(s2r_copy_SFA, tXsSFA_read, tXrSFA);
cute::copy(s2r_copy_SFB, tXsSFB_read, tXrSFB);
cute::copy(s2r_copy_A, tXsA_read, tXrA);
cute::copy(s2r_copy_B, tXsB_read(cute::_, cute::Int<0>{}, cute::_), tXrB(cute::_, cute::_, cute::Int<0>{}));
cute::clear(accum);
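// main loop: for every k tile, sweep the NUM_GROUP_N register groups; at n_block == 0 the next
// gmem stage is prefetched and the fused scale is refreshed, and at the last group the pipeline
// waits on cp.async, then reloads the scale factors and A fragments from the next smem stage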
int k_tile_iter = KT::Stages - 1;
while (k_tile_iter < k_tile_count)
{
cute::for_each(cute::make_int_sequence<KT::NUM_GROUP_N>{},
[&](auto n_block)
{
if constexpr (n_block == KT::NUM_GROUP_N - 1)
{
tXsA_read = tXsA(cute::_, cute::_, cute::_, smem_pipe_read);
tXsB_read = tXsB(cute::_, cute::_, cute::_, smem_pipe_read);
tXsSFA_read = tXsSFA(cute::_, cute::_, cute::_, smem_pipe_read);
tXsSFB_read = tXsSFB(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
cute::copy(s2r_copy_SFA, tXsSFA_read, tXrSFA);
cute::copy(s2r_copy_SFB, tXsSFB_read, tXrSFB);
}
auto n_block_next = (n_block + cute::_1{}) % KT::NUM_GROUP_N;
cute::copy(
s2r_copy_B, tXsB_read(cute::_, n_block_next, cute::_), tXrB(cute::_, cute::_, n_block_next));
if constexpr (n_block == 0)
{
// gmem -> smem
cute::copy_if(g2s_copy_A, tApA, tAgA(cute::_, cute::_, cute::_, k_tile_iter),
tAsA(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy_if(g2s_copy_B, tBpB, tBgB(cute::_, cute::_, cute::_, k_tile_iter),
tBsB(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy_if(g2s_copy_SFA, tApSFA, tAgSFA(cute::_, cute::_, cute::_, k_tile_iter),
tAsSFA(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy(g2s_copy_SFB, tBgSFB(cute::_, cute::_, cute::_, k_tile_iter),
tBsSFB(cute::_, cute::_, cute::_, smem_pipe_write));
cute::cp_async_fence();
k_tile_iter++;
smem_pipe_write = smem_pipe_read;
++smem_pipe_read;
smem_pipe_read = smem_pipe_read == KT::Stages ? 0 : smem_pipe_read;
cute::for_each(cute::make_int_sequence<cute::size(scale)>{},
[&](auto i) { scale(i) = tXrSFA(i) * tXrSFB(0); });
}
cute::clear(temp);
cute::gemm(mma, tCrA, tCrB(cute::_, cute::_, cute::_, n_block), temp);
if constexpr (n_block == KT::NUM_GROUP_N - 1)
{
cute::copy(s2r_copy_A, tXsA_read, tXrA);
}
promote(accum, temp, scale, n_block);
});
}
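// drain the software pipeline: the load tail consumes the smem stages that were prefetched but
// not yet computed, and the mma tail finishes the final stage without issuing further gmem copies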
// load tail
cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
[&](auto WaitIndex)
{
using WaitIndex_t = decltype(WaitIndex);
cute::for_each(cute::make_int_sequence<KT::NUM_GROUP_N>{},
[&](auto n_block)
{
if constexpr (n_block == KT::NUM_GROUP_N - 1)
{
tXsA_read = tXsA(cute::_, cute::_, cute::_, smem_pipe_read);
tXsB_read = tXsB(cute::_, cute::_, cute::_, smem_pipe_read);
tXsSFA_read = tXsSFA(cute::_, cute::_, cute::_, smem_pipe_read);
tXsSFB_read = tXsSFB(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
__syncthreads();
cute::copy(s2r_copy_SFA, tXsSFA_read, tXrSFA);
cute::copy(s2r_copy_SFB, tXsSFB_read, tXrSFB);
}
auto n_block_next = (n_block + cute::_1{}) % KT::NUM_GROUP_N;
cute::copy(s2r_copy_B, tXsB_read(cute::_, n_block_next, cute::_),
tXrB(cute::_, cute::_, n_block_next));
if constexpr (n_block == 0)
{
++smem_pipe_read;
smem_pipe_read = smem_pipe_read == KT::Stages ? 0 : smem_pipe_read;
cute::for_each(cute::make_int_sequence<cute::size(scale)>{},
[&](auto i) { scale(i) = tXrSFA(i) * tXrSFB(0); });
}
cute::clear(temp);
cute::gemm(mma, tCrA, tCrB(cute::_, cute::_, cute::_, n_block), temp);
if constexpr (n_block == KT::NUM_GROUP_N - 1)
{
cute::copy(s2r_copy_A, tXsA_read, tXrA);
}
promote(accum, temp, scale, n_block);
});
});
// mma tail
cute::for_each(cute::make_int_sequence<KT::NUM_GROUP_N>{},
[&](auto n_block)
{
auto n_block_next = (n_block + cute::_1{}) % KT::NUM_GROUP_N;
cute::copy(s2r_copy_B, tXsB_read(cute::_, n_block_next, cute::_), tXrB(cute::_, cute::_, n_block_next));
cute::clear(temp);
if constexpr (n_block == 0)
{
cute::for_each(cute::make_int_sequence<cute::size(scale)>{},
[&](auto i) { scale(i) = tXrSFA(i) * tXrSFB(0); });
}
cute::gemm(mma, tCrA, tCrB(cute::_, cute::_, cute::_, n_block), temp);
promote(accum, temp, scale, n_block);
});
// epilogue
__syncthreads(); // sync before using store smem
typename KT::SharedStorageStore* store_storage
= reinterpret_cast<typename KT::SharedStorageStore*>(SharedStorageBase);
epilogue_with_smem(accum, *store_storage, ptr_output, M, N);
}
};
} // namespace ada_blockwise_gemm

View File

@ -0,0 +1,248 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cute/atom/mma_atom.hpp"
#include <cuda_runtime.h>
#include <iostream> // std::cout used by CUTLASS_HOST_TRACE
#include <cute/arch/mma.hpp>
#include <cute/config.hpp>
#include <cute/int_tuple.hpp>
#include <cute/layout.hpp>
#include <cutlass/arch/mma.h>
#include <cutlass/cutlass.h>
#define CUTLASS_HOST_TRACE(x) \
{ \
std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; \
}
// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890))
#define CUTE_ARCH_MMA_F32_SM89_ENABLED
#endif
namespace cute
{
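// SM89 (Ada) fp8 tensor core MMA: mma.sync.aligned.m16n8k32 with e4m3 A/B and fp32 accumulation,
// exposed as a CuTe MMA operation plus MMA_Traits so TiledMMA can consume it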
// MMA 16x8x32 TN
struct SM89_16x8x32_F32F8F8F32_TN
{
using DRegisters = float[4];
using ARegisters = uint32_t[4];
using BRegisters = uint32_t[2];
using CRegisters = float[4];
CUTE_HOST_DEVICE static void fma(float& d0, float& d1, float& d2, float& d3, uint32_t const& a0, uint32_t const& a1,
uint32_t const& a2, uint32_t const& a3, uint32_t const& b0, uint32_t const& b1, float const& c0,
float const& c1, float const& c2, float const& c3)
{
#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED)
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
: "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "f"(c0), "f"(c1), "f"(c2), "f"(c3));
#else
CUTE_INVALID_CONTROL_PATH(
"Attempting to use SM89_16x8x32_F32F8F8F32_TN without "
"CUTE_ARCH_MMA_F32_SM89_ENABLED");
#endif
}
};
template <>
struct MMA_Traits<SM89_16x8x32_F32F8F8F32_TN>
{
using ValTypeD = float;
using ValTypeA = float_e4m3_t;
using ValTypeB = float_e4m3_t;
using ValTypeC = float;
using Shape_MNK = Shape<_16, _8, _32>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2, _2>>, Stride<Stride<_64, _1>, Stride<_16, _8, _256>>>;
using BLayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2>>, Stride<Stride<_32, _1>, Stride<_8, _128>>>;
using CLayout = SM80_16x8_Row;
};
} // namespace cute
using namespace cute;
using namespace cutlass;
using namespace cutlass::gemm;
namespace ada_blockwise_gemm
{
template <typename Element, typename Arch>
struct DefaultGemm_TensorOp_MMA;
template <>
struct DefaultGemm_TensorOp_MMA<cute::bfloat16_t, cutlass::arch::Sm80>
{
using ArchTag = cutlass::arch::Sm80;
using MMA_Atom_Arch = cute::MMA_Atom<cute::SM80_16x8x16_F32BF16BF16F32_TN>;
using ThreadLayoutMNK = cute::Layout<cute::Shape<cute::_2, cute::_2, cute::_1>>;
using ValLayoutMNK = cute::Tile<cute::_32, cute::_32, cute::_16>;
using TiledMma = cute::TiledMMA<MMA_Atom_Arch, ThreadLayoutMNK, ValLayoutMNK>;
};
template <>
struct DefaultGemm_TensorOp_MMA<cute::float_e4m3_t, cutlass::arch::Sm89>
{
using ArchTag = cutlass::arch::Sm89;
using MMA_Atom_Arch = cute::MMA_Atom<cute::SM89_16x8x32_F32F8F8F32_TN>;
using ThreadLayoutMNK = cute::Layout<cute::Shape<cute::_2, cute::_2, cute::_1>>;
using ValLayoutMNK = cute::Tile<cute::_32, cute::_32, cute::_32>;
using TiledMma = cute::TiledMMA<MMA_Atom_Arch, ThreadLayoutMNK, ValLayoutMNK>;
};
template <typename ElementType, typename OutElementType, typename AccumElementType, typename BlockScaleElementType,
int Stages_, int TileM_, int TileN_, int TileK_>
struct AdaBlockwiseGemmTraits
{
using ElementInput = ElementType;
using ElementOutput = OutElementType;
using ElementAccumulator = float; // accumulation is kept in fp32 regardless of AccumElementType
using ElementBlockScale = float;  // block scales are kept in fp32 regardless of BlockScaleElementType
using index_t = uint32_t;
static_assert(TileM_ % 16 == 0);
static_assert(TileN_ % 32 == 0);
static_assert(TileK_ % 32 == 0);
static constexpr int Stages = Stages_;
static constexpr int kTileM = TileM_;
static constexpr int kTileN = TileN_;
static constexpr int kTileK = TileK_;
using TileShape = Shape<Int<kTileM>, Int<kTileN>, Int<kTileK>>;
static constexpr int kWarpsCount = 4;
static constexpr int kThreadCount = kWarpsCount * 32;
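// blockwise fp8 scale granularity: one scale per row of A per 128-wide K block (1x128) and one
// scale per 128x128 tile of B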
static constexpr int ScaleGranularityM = 1;
static constexpr int ScaleGranularityN = 128;
static constexpr int ScaleGranularityK = 128;
static constexpr int ScaleMsPerTile = (kTileM + ScaleGranularityM - 1) / ScaleGranularityM;
static constexpr int ScaleNsPerTile = (kTileN + ScaleGranularityN - 1) / ScaleGranularityN;
static constexpr int ScaleKsPerTile = (kTileK + ScaleGranularityK - 1) / ScaleGranularityK;
using ScaleGranularity = Shape<Int<ScaleGranularityM>, Int<ScaleGranularityN>, Int<ScaleGranularityK>>;
using ScalePerTileShape = Shape<Int<ScaleMsPerTile>, Int<ScaleNsPerTile>, Int<ScaleKsPerTile>>;
// ====== mma ======
static constexpr int kMmaPermM = 32;
static constexpr int kMmaPermN = 32;
static constexpr int kMmaPermK = 32;
constexpr static int NUM_GROUP_M = kTileM / kMmaPermM;
constexpr static int NUM_GROUP_N = kTileN / kMmaPermN;
constexpr static int NUM_GROUP_K = kTileK / kMmaPermK;
using MMA_Atom = MMA_Atom<SM89_16x8x32_F32F8F8F32_TN>;
using AtomLayoutMNK = Layout<Shape<_2, _2, _1>>;
using PermutationMNK = Tile<Int<kMmaPermM>, Int<kMmaPermN>, Int<kMmaPermK>>;
using TiledMma = TiledMMA<MMA_Atom, AtomLayoutMNK, PermutationMNK>;
// ====== load gmem -> smem ======
using GmemTiledCopyLoad = decltype(make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, ElementInput>{},
Layout<Shape<_16, _8>, Stride<_8, _1>>{}, Layout<Shape<_1, _16>>{}));
using GmemTiledCopyA = GmemTiledCopyLoad;
using GmemTiledCopyB = GmemTiledCopyLoad;
// ====== load smem -> rf ======
using SmemAtomLayoutLoad = decltype(composition(Swizzle<3, 4, 3>{}, Layout<Shape<_16, _128>, Stride<_128, _1>>{}));
using SmemLayoutA = decltype(tile_to_shape(SmemAtomLayoutLoad{}, Shape<Int<kTileM>, Int<kTileK>, Int<Stages>>{}));
using SmemLayoutB = decltype(tile_to_shape(SmemAtomLayoutLoad{}, Shape<Int<kTileN>, Int<kTileK>, Int<Stages>>{}));
using SmemCopyAtomLoad = Copy_Atom<SM75_U32x4_LDSM_N, ElementInput>;
using SmemCopyAtomA = SmemCopyAtomLoad;
using SmemCopyAtomB = SmemCopyAtomLoad;
// ====== store rf -> smem ======
using SmemAtomLayoutStore = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_8, Shape<_8, _8>>, Stride<_8, Stride<_1, _64>>>{})); // 8x64
using SmemLayoutO = decltype(tile_to_shape(SmemAtomLayoutStore{}, Shape<Int<kTileM>, Int<kTileN>>{}));
using SmemCopyAtomR2S = Copy_Atom<AutoVectorizingCopy, ElementOutput>;
// ====== store smem -> gmem ======
using SmemCopyAtomS2R = Copy_Atom<UniversalCopy<uint128_t>, ElementOutput>;
using GmemCopyAtomR2G = SmemCopyAtomS2R;
using TiledCopyS2G = decltype(make_tiled_copy(
SmemCopyAtomS2R{}, Layout<Shape<_16, _8>, Stride<_8, _1>>{}, Layout<Shape<_1, _8>>{})); // 16x64
// ====== load scale gmem -> smem ======
using GmemCopyAtomScale = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
using GmemLayoutTVSFA = Layout<Shape<Shape<Int<ScaleMsPerTile>, Int<kThreadCount / ScaleMsPerTile>>, Shape<_1, _1>>,
Stride<Stride<_1, _0>, Stride<_1, _1>>>;
using GmemTileShapeSFA = Shape<Int<ScaleMsPerTile>, Int<ScaleKsPerTile>>;
using GmemTiledCopySFA = decltype(make_tiled_copy_impl(GmemCopyAtomScale{}, GmemLayoutTVSFA{}, GmemTileShapeSFA{}));
using GmemLayoutTVSFB = Layout<Shape<Shape<_32, _4>, Shape<_1, _1>>, Stride<Stride<_0, _0>, Stride<_1, _1>>>;
using GmemTileShapeSFB = Shape<Int<ScaleNsPerTile>, Int<ScaleKsPerTile>>;
using GmemTiledCopySFB = decltype(make_tiled_copy_impl(GmemCopyAtomScale{}, GmemLayoutTVSFB{}, GmemTileShapeSFB{}));
// ====== load scale smem -> rf ======
using SmemCopyAtomScale = Copy_Atom<UniversalCopy<ElementBlockScale>, ElementBlockScale>;
using SmemLayoutTVSFA
= Layout<Shape<Shape<_4, _8, _2, _2>, Shape<_2>>, Stride<Stride<_0, _1, _16, _0>, Stride<_8, _0>>>;
using SmemTileShapeSFA = Shape<Int<kMmaPermM>, _1>;
using SmemTiledCopySFA = decltype(make_tiled_copy_impl(SmemCopyAtomScale{}, SmemLayoutTVSFA{}, SmemTileShapeSFA{}));
using SmemLayoutSFA = decltype(tile_to_shape(make_layout(SmemTileShapeSFA{}),
make_shape(
shape<0>(ScalePerTileShape{}), shape<2>(ScalePerTileShape{}), Int<Stages>{}))); // BLK_M, BLK_K, Stages
using SmemLayoutTVSFB
= Layout<Shape<Shape<_4, _8, _2, _2>, Shape<_1>>, Stride<Stride<_0, _0, _0, _0>, Stride<_0, _0>>>;
using SmemTileShapeSFB = Shape<_1, _1>;
using SmemTiledCopySFB = decltype(make_tiled_copy_impl(SmemCopyAtomScale{}, SmemLayoutTVSFB{}, SmemTileShapeSFB{}));
using SmemLayoutSFB = decltype(tile_to_shape(make_layout(SmemTileShapeSFB{}),
make_shape(
shape<1>(ScalePerTileShape{}), shape<2>(ScalePerTileShape{}), Int<Stages>{}))); // BLK_N, BLK_K, Stages
// we need at least 2 stages..
static_assert(Stages >= 2);
struct SharedStorageLoad : aligned_struct<128>
{
array_aligned<ElementInput, cosize_v<SmemLayoutA>> smem_a;
array_aligned<ElementInput, cosize_v<SmemLayoutB>> smem_b;
array_aligned<float, cosize_v<SmemLayoutSFA>> smem_sfa;
array_aligned<float, cosize_v<SmemLayoutSFB>> smem_sfb;
};
struct SharedStorageStore : aligned_struct<128>
{
array_aligned<ElementOutput, cosize_v<SmemLayoutO>> smem_o;
};
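// mainloop and epilogue storage are overlaid in a union; the kernel synchronizes before reusing
// the same shared memory as the output staging buffer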
union SharedStorage
{
SharedStorageLoad load;
SharedStorageStore store;
};
static constexpr int kSmemSize = static_cast<int>(sizeof(SharedStorage));
};
} // namespace ada_blockwise_gemm

View File

@ -27,7 +27,7 @@
#include <string>
#include <vector>
#include "ada_blockwise_gemm/ada_blockwise_gemm.cuh"
#include "ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh"
#include "fp8_blockscale_mma_utils.cuh"
#include "fp8_blockscale_tma_utils.cuh"
#include "tensorrt_llm/common/cudaUtils.h"
@ -1172,7 +1172,7 @@ template <typename InputType, typename OutputType, typename ScaleType = float>
__global__ void scale_1x128_reshape_kernel(
OutputType* output, ScaleType* scales, InputType const* const input, int dim_x, int dim_h, int dim_y, int stride_x)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890))
size_t scales_along_dim_x = div_up(dim_x, 128);
size_t scales_along_dim_y = div_up(dim_y, 1);
size_t scales_along_dim_h = div_up(dim_h, 1);
@ -1582,30 +1582,30 @@ void gemm_dispatch_sm89(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
using ElementOutput = cute::bfloat16_t;
using ElementAccum = float;
using ElementBlockScale = float;
static constexpr int Stages = 4;
using TileShape = cutlass::gemm::GemmShape<32, 128, 128>; // only support 32x128x128 for now
static constexpr int Stages = 3;
using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
using KT = ada_blockwise_gemm::AdaBlockwiseGemmTraits<ElementInput, ElementOutput, ElementAccum, ElementBlockScale,
Stages, TileShape::kM, TileShape::kN, TileShape::kK>;
using Gemm = ada_blockwise_gemm::AdaBlockwiseGemm<KT>;
using GemmKernel = ada_blockwise_gemm::AdaBlockwiseGemmKernel<KT>;
int gemm_m = shape_m;
int gemm_n = shape_n;
int gemm_k = shape_k;
typename KT::Arguments args({gemm_m, gemm_n, gemm_k}, mat_a, mat_b, mat_d, scales_a, scales_b);
static constexpr int kSmemSize = KT::kSmemSize;
static constexpr int kThreadCount = KT::kThreadCount;
int grid_m = (shape_m + KT::kTileM - 1) / KT::kTileM;
int grid_n = (shape_n + KT::kTileN - 1) / KT::kTileN;
int grid_k = 1;
dim3 grid = dim3(grid_m, grid_n, grid_k);
dim3 block = dim3(kThreadCount, 1, 1);
Gemm gemm{};
if (kSmemSize > (48 << 10))
{
cudaFuncSetAttribute(ada_blockwise_gemm::sm89_fp8_gemm_1d1d_impl<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
auto result = cudaGetLastError();
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm89 gemm kernel cannot launch: %s", cudaGetErrorString(result));
}
auto status = gemm.can_implement(args);
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "This kernel is not supported. Last CUDA error is: %s",
cutlassGetStatusString(status));
status = gemm.initialize(args);
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess,
"Failed to initialize the CUTLASS kernel. Last CUDA error is: %s", cutlassGetStatusString(status));
status = gemm.run(stream);
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess,
"Failed to run the CUTLASS kernel. Last CUDA error is: %s", cutlassGetStatusString(status));
ada_blockwise_gemm::sm89_fp8_gemm_1d1d_impl<GemmKernel>
<<<grid, block, kSmemSize, stream>>>(shape_m, shape_n, shape_k, mat_a, mat_b, mat_d, scales_a, scales_b);
}
void fp8_gemm_run(__nv_fp8_e4m3* mat_a, int ld_a, __nv_fp8_e4m3* mat_b, int ld_b, __nv_bfloat16* mat_d, int ld_d,
@ -1788,6 +1788,48 @@ void strided_batch_gemm_dispatch(__nv_fp8_e4m3* mat_a, int ld_a, int stride_a, _
static_cast<uint32_t>(best_smem_size));
}
void strided_batch_gemm_dispatch_sm89(__nv_fp8_e4m3* mat_a, int ld_a, int stride_a, __nv_fp8_e4m3* mat_b, int ld_b,
int stride_b, __nv_bfloat16* mat_d, int ld_d, int stride_d, float* scales_a, int stride_scales_a, float* scales_b,
uint32_t num_problems, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, cudaStream_t stream,
int num_device_sms = kNumDeviceSMs)
{
if (num_device_sms < 0)
{
num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
}
using ElementInput = cute::float_e4m3_t;
using ElementOutput = cute::bfloat16_t;
using ElementAccum = float;
using ElementBlockScale = float;
static constexpr int Stages = 3;
using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
using KT = ada_blockwise_gemm::AdaBlockwiseGemmTraits<ElementInput, ElementOutput, ElementAccum, ElementBlockScale,
Stages, TileShape::kM, TileShape::kN, TileShape::kK>;
using GemmKernel = ada_blockwise_gemm::AdaBlockwiseGemmKernel<KT>;
static constexpr int kSmemSize = KT::kSmemSize;
static constexpr int kThreadCount = KT::kThreadCount;
int grid_m = (shape_m + KT::kTileM - 1) / KT::kTileM;
int grid_n = (shape_n + KT::kTileN - 1) / KT::kTileN;
int grid_k = num_problems;
dim3 grid = dim3(grid_m, grid_n, grid_k);
dim3 block = dim3(kThreadCount, 1, 1);
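// one CUDA block per (M tile, N tile, problem): gridDim.z walks the batch, and each problem
// advances the B scales by ceil(N/128) * ceil(K/128) entries (one scale per 128x128 block)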
int stride_scales_b = ((shape_n + 128 - 1) / 128) * ((shape_k + 128 - 1) / 128);
if (kSmemSize > (48 << 10))
{
cudaFuncSetAttribute(ada_blockwise_gemm::sm89_fp8_bmm_1d1d_impl<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
auto result = cudaGetLastError();
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm89 gemm kernel cannot launch: %s", cudaGetErrorString(result));
}
ada_blockwise_gemm::sm89_fp8_bmm_1d1d_impl<GemmKernel><<<grid, block, kSmemSize, stream>>>(shape_m, shape_n,
shape_k, mat_a, mat_b, mat_d, scales_a, scales_b, stride_a, stride_b, stride_d, stride_scales_a,
stride_scales_b);
}
void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, float* scales_a, int ld_a,
int stride_a, int stride_scales_a, __nv_bfloat16 const* mat_b, __nv_fp8_e4m3* fp8_mat_b, float* scales_b, int ld_b,
int stride_b, __nv_bfloat16* mat_d, int ld_d, int stride_d, uint32_t num_problems, uint32_t shape_m,
@ -1814,6 +1856,13 @@ void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_ma
fp8_mat_b, scales_b, mat_b, shape_k, shape_n * num_problems);
}
int arch = tensorrt_llm::common::getSMVersion();
if (arch == 89)
{
strided_batch_gemm_dispatch_sm89(fp8_mat_a, ld_a, stride_a, fp8_mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
scales_a, stride_scales_a, scales_b, num_problems, shape_m, shape_n, shape_k, stream);
return;
}
if (kDeepGemmEnabled)
{
strided_batch_gemm_dispatch(fp8_mat_a, ld_a, stride_a, fp8_mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,

View File

@ -359,7 +359,7 @@ def fp8_block_scaling_bmm_out(
out: torch.Tensor,
) -> torch.Tensor:
sm_version = get_sm_version()
if sm_version == 90:
if sm_version == 90 or sm_version == 89:
mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
mat1)
torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8,

View File

@ -41,8 +41,6 @@ from utils.util import getSMVersion
[torch.bfloat16],
)
def test_fp8_block_scale_gemm(dtype, m, k, n):
if getSMVersion() == 89 and k == 7168 and n == 2112:
pytest.skip("https://nvbugs/5328184")
torch.random.manual_seed(0)
a = torch.randn((m, k), device='cuda', dtype=dtype) / k
@ -60,6 +58,55 @@ def test_fp8_block_scale_gemm(dtype, m, k, n):
torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
@pytest.mark.skipif(
getSMVersion() != 90 and getSMVersion() != 89,
reason="The test is for Hopper and Ada only. Current SM is %d." %
getSMVersion(),
)
@pytest.mark.parametrize(
"k, n",
[(7168, 2112), (512, 32768), (16384, 7168), (2048, 7168)],
)
@pytest.mark.parametrize(
"m",
[7, 64, 128],
)
@pytest.mark.parametrize(
"num_groups",
[4, 8, 16],
)
@pytest.mark.parametrize(
"dtype",
[torch.bfloat16],
)
def test_fp8_block_scale_bmm(dtype, m, k, n, num_groups):
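# quantize A with 1x128 blocks along K (fp8_batched_quantize_1x128_permute102) and B with
# 128x128 blocks per group, then compare the fp8 block-scaling BMM against a bf16 einsum reference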
torch.random.manual_seed(0)
a = torch.randn((m, num_groups, k), device='cuda', dtype=dtype) / k
a_fp8, a_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(a)
b = torch.randn((num_groups, n, k), device='cuda', dtype=dtype) / k
b_fp8 = torch.zeros_like(b, device='cuda', dtype=torch.float8_e4m3fn)
b_scales = torch.zeros((num_groups, (n + 127) // 128, (k + 127) // 128),
device='cuda',
dtype=torch.float)
for i in range(num_groups):
b_fp8[i], b_scales[i] = per_block_cast_to_fp8(b[i])
output_expected = torch.einsum('mgk,gnk->gmn', a, b)
output = torch.empty((num_groups, m, n),
device='cuda',
dtype=torch.bfloat16)
torch.ops.trtllm.fp8_block_scaling_bmm_out(a_fp8, b_fp8, a_scales, b_scales,
output)
diff = calc_diff(output, output_expected)
assert diff < 1e-3
torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
def deepSeekFp8ComputeGemmReference(mM, mN, mK, valsC, dqSfsC, valsA, dqSfsA,
valsB, dqSfsB, quantizeOutput, tileSize):
for mi in range(mM):