Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

[fix] Fix the tileN % 16 != 0 case & support SM89 DeepGEMM BMM (#5531)
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
Parent: 7d21b55b5a
Commit: dc32f9ae73
@@ -1,277 +0,0 @@
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
// clang-format off
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/layout/layout.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
|
||||
// clang-format on
|
||||
namespace ada_blockwise_gemm
|
||||
{
|
||||
|
||||
template <typename Element, typename Layout, int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA;
|
||||
|
||||
template <typename Element, typename Layout, int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB;
|
||||
|
||||
// specializations for float_e4m3
|
||||
template <>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::float_e4m3_t, cutlass::layout::RowMajor, 16, 128>
|
||||
{
|
||||
// Smem
|
||||
using smem_layout = cute::Layout<cute::Shape<cute::_8, cute::_128>, cute::Stride<cute::_128, cute::_1>>;
|
||||
using SmemLayoutAtom = decltype(cute::composition(cute::Swizzle<3, 4, 3>{}, smem_layout{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::float_e4m3_t>;
|
||||
|
||||
// Gmem
|
||||
using copy_atom = decltype(cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::float_e4m3_t>{});
|
||||
using thr_layout = decltype(cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{});
|
||||
using val_layout = decltype(cute::Layout<cute::Shape<cute::_1, cute::_16>>{});
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(copy_atom{}, thr_layout{}, val_layout{}));
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::float_e4m3_t, cutlass::layout::ColumnMajor, 16, 128>
|
||||
{
|
||||
// Smem
|
||||
using smem_layout = cute::Layout<cute::Shape<cute::_8, cute::_128>, cute::Stride<cute::_128, cute::_1>>;
|
||||
using SmemLayoutAtom = decltype(cute::composition(cute::Swizzle<3, 4, 3>{}, smem_layout{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::float_e4m3_t>;
|
||||
|
||||
// Gmem
|
||||
using copy_atom = decltype(cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::float_e4m3_t>{});
|
||||
using thr_layout = decltype(cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{});
|
||||
using val_layout = decltype(cute::Layout<cute::Shape<cute::_16, cute::_1>>{});
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(copy_atom{}, thr_layout{}, val_layout{}));
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 64>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
|
||||
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
|
||||
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 64>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
|
||||
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
|
||||
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
|
||||
};
|
||||
|
||||
/// Operand A - Column-major (M-major)
|
||||
template <int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, 8, SizeK>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::half_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
|
||||
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
|
||||
cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
|
||||
};
|
||||
|
||||
template <int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, 8, SizeK>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::bfloat16_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
|
||||
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
|
||||
cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
|
||||
};
|
||||
|
||||
// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the
|
||||
// DefaultOperands
|
||||
|
||||
// Operand B - Column-Major (K-major)
|
||||
template <int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
|
||||
: DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
|
||||
{
|
||||
};
|
||||
|
||||
template <int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
|
||||
: DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
|
||||
{
|
||||
};
|
||||
|
||||
// Operand B - Row-Major (N-major)
|
||||
template <int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
|
||||
: DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
|
||||
{
|
||||
};
|
||||
|
||||
template <int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
|
||||
: DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
|
||||
{
|
||||
};
|
||||
|
||||
template <int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB<cute::float_e4m3_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
|
||||
: DefaultGemm_TensorOpSm80_OperandA<cute::float_e4m3_t, cutlass::layout::RowMajor, Alignment, SizeK>
|
||||
{
|
||||
};
|
||||
|
||||
template <int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB<cute::float_e4m3_t, cutlass::layout::RowMajor, Alignment, SizeK>
|
||||
: DefaultGemm_TensorOpSm80_OperandA<cute::float_e4m3_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
|
||||
{
|
||||
};
|
||||
|
||||
//
|
||||
// F16: 128-by-128-by-32 (small k-block)
|
||||
//
|
||||
|
||||
/// Operand A - Row-major (K-Major)
|
||||
template <>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 32>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<2, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
|
||||
cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
|
||||
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 32>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<2, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
|
||||
cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
|
||||
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
|
||||
};
|
||||
|
||||
struct CopyTraitsScaleA
|
||||
{
|
||||
using GmemCopyAtom = cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEALWAYS<cute::uint32_t>, float>;
|
||||
|
||||
// Gmem
|
||||
using GmemLayoutTVScale
|
||||
= cute::Layout<cute::Shape<cute::Shape<cute::_32, cute::_4>, cute::Shape<cute::_1, cute::_1>>,
|
||||
cute::Stride<cute::Stride<cute::_1, cute::_0>, cute::Stride<cute::_1, cute::_1>>>;
|
||||
using GmemTileShapeScale = cute::Shape<cute::_32, cute::_1>;
|
||||
using GmemTiledCopyScale
|
||||
= decltype(cute::make_tiled_copy_impl(GmemCopyAtom{}, GmemLayoutTVScale{}, GmemTileShapeScale{}));
|
||||
|
||||
// Smem
|
||||
using SmemCopyAtomScale = cute::Copy_Atom<cute::UniversalCopy<float>, float>;
|
||||
using SmemLayoutTVScale
|
||||
= cute::Layout<cute::Shape<cute::Shape<cute::_4, cute::_8, cute::_2, cute::_2>, cute::Shape<cute::_2>>,
|
||||
cute::Stride<cute::Stride<cute::_0, cute::_1, cute::_16, cute::_0>, cute::Stride<cute::_8>>>;
|
||||
using SmemTileShapeScale = cute::Shape<cute::_32, cute::_1>;
|
||||
using SmemTiledCopyScale
|
||||
= decltype(cute::make_tiled_copy_impl(SmemCopyAtomScale{}, SmemLayoutTVScale{}, SmemTileShapeScale{}));
|
||||
};
|
||||
|
||||
struct CopyTraitsScaleB
|
||||
{
|
||||
using GmemCopyAtom = cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEALWAYS<cute::uint32_t>, float>;
|
||||
|
||||
// Gmem
|
||||
using GmemLayoutTVScale
|
||||
= cute::Layout<cute::Shape<cute::Shape<cute::_32, cute::_4>, cute::Shape<cute::_1, cute::_1>>,
|
||||
cute::Stride<cute::Stride<cute::_0, cute::_0>, cute::Stride<cute::_1, cute::_1>>>;
|
||||
using GmemTileShapeScale = cute::Shape<cute::_1, cute::_1>;
|
||||
using GmemTiledCopyScale
|
||||
= decltype(cute::make_tiled_copy_impl(GmemCopyAtom{}, GmemLayoutTVScale{}, GmemTileShapeScale{}));
|
||||
|
||||
// Smem
|
||||
using SmemCopyAtomScale = cute::Copy_Atom<cute::UniversalCopy<float>, float>;
|
||||
using SmemLayoutTVScale
|
||||
= cute::Layout<cute::Shape<cute::Shape<cute::_4, cute::_8, cute::_2, cute::_2>, cute::Shape<cute::_1>>,
|
||||
cute::Stride<cute::Stride<cute::_0, cute::_0, cute::_0, cute::_0>, cute::Stride<cute::_0>>>;
|
||||
using SmemTileShapeScale = cute::Shape<cute::_1, cute::_1>;
|
||||
using SmemTiledCopyScale
|
||||
= decltype(cute::make_tiled_copy_impl(SmemCopyAtomScale{}, SmemLayoutTVScale{}, SmemTileShapeScale{}));
|
||||
};
|
||||
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
CUTE_DEVICE auto util_convert_type(cute::Tensor<Engine, Layout> const& tensor)
|
||||
{
|
||||
using From_type = typename Engine::value_type;
|
||||
constexpr int numel = decltype(cute::size(tensor))::value;
|
||||
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
||||
// HACK: this requires tensor to be "contiguous"
|
||||
auto frag = convert_op(*reinterpret_cast<cutlass::Array<From_type, numel> const*>(tensor.data()));
|
||||
return cute::make_tensor(cute::make_rmem_ptr<To_type>(&frag), tensor.layout());
|
||||
}
|
||||
|
||||
template <typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
CUTE_DEVICE void util_copy(
|
||||
TiledCopy const& tiled_copy, cute::Tensor<Engine0, Layout0> const& S, cute::Tensor<Engine1, Layout1>& D)
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(cute::rank(S) == cute::Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(cute::rank(D) == cute::Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(cute::size<0>(S) == cute::size<0>(D));
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(S) == cute::size<1>(D));
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(S) == cute::size<2>(D));
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < cute::size<1>(S); ++m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k = 0; k < cute::size<2>(S); ++k)
|
||||
{
|
||||
cute::copy(tiled_copy, S(cute::_, m, k), D(cute::_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ada_blockwise_gemm
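// Usage sketch (illustrative, not part of the original header): the operand definitions above are
// consumed by a kernel-traits class, which tiles the SmemLayoutAtom up to the CTA tile and uses the
// GmemTiledCopy for the gmem -> smem stage. The 128x128 tile and 3 stages below are assumed values
// chosen only for demonstration.
namespace ada_blockwise_gemm_example
{
using OperandA = ada_blockwise_gemm::DefaultGemm_TensorOpSm80_OperandA<cute::float_e4m3_t,
    cutlass::layout::RowMajor, 16, 128>;
using SmemLayoutA = decltype(cute::tile_to_shape(OperandA::SmemLayoutAtom{},
    cute::make_shape(cute::Int<128>{}, cute::Int<128>{}, cute::Int<3>{}))); // (BLK_M, BLK_K, Stages)
using GmemTiledCopyA = OperandA::GmemTiledCopy; // 16x8 = 128 threads, 16 e4m3 values per thread copy
} // namespace ada_blockwise_gemm_example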
@@ -1,173 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/gemm/device/gemm.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/trace.h>
|
||||
|
||||
#include "ada_blockwise_gemm_kernel.cuh"
|
||||
|
||||
#define CUTLASS_HOST_TRACE(x) \
|
||||
{ \
|
||||
std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; \
|
||||
}
|
||||
|
||||
namespace ada_blockwise_gemm
|
||||
{
|
||||
|
||||
template <typename GemmKernel>
|
||||
CUTLASS_GLOBAL void run_global(typename GemmKernel::Params params)
|
||||
{
|
||||
// Dynamic shared memory base pointer
|
||||
extern __shared__ int SharedStorageBase[];
|
||||
// Declare pointer to dynamic shared memory.
|
||||
typename GemmKernel::SharedStorage* shared_storage
|
||||
= reinterpret_cast<typename GemmKernel::SharedStorage*>(SharedStorageBase);
|
||||
|
||||
GemmKernel::invoke(params, *shared_storage);
|
||||
}
|
||||
|
||||
using namespace cutlass;
|
||||
|
||||
template <typename KT>
|
||||
struct AdaBlockwiseGemm
|
||||
{
|
||||
|
||||
using GemmKernel = AdaBlockwiseGemmKernel<KT>;
|
||||
|
||||
static constexpr int kSmemSize = GemmKernel::kSmemSize;
|
||||
static constexpr int kThreadCount = GemmKernel::kThreadCount;
|
||||
|
||||
/// Kernel parameters object
|
||||
typename GemmKernel::Params params_;
|
||||
|
||||
AdaBlockwiseGemm()
|
||||
: params_()
|
||||
{
|
||||
}
|
||||
|
||||
using Arguments = typename GemmKernel::Arguments;
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int smem_capacity = -1)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("AdaBlockwiseGemmKernel::maximum_active_blocks()");
|
||||
|
||||
CUTLASS_TRACE_HOST(" kSmemSize: " << kSmemSize << " bytes");
|
||||
|
||||
cudaError_t result;
|
||||
if (kSmemSize > (48 << 10))
|
||||
{
|
||||
result
|
||||
= cudaFuncSetAttribute(run_global<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
result = cudaGetLastError();
|
||||
CUTLASS_HOST_TRACE(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
int max_active_blocks = -1;
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, run_global<GemmKernel>, kThreadCount, kSmemSize);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
result = cudaGetLastError();
|
||||
CUTLASS_HOST_TRACE(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned "
|
||||
"error "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_TRACE(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
Status can_implement(Arguments const& args)
|
||||
{
|
||||
if (kSmemSize > (48 << 10))
|
||||
{
|
||||
cudaError_t result
|
||||
= cudaFuncSetAttribute(run_global<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
result = cudaGetLastError();
|
||||
CUTLASS_HOST_TRACE(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
|
||||
return Status::kInvalid;
|
||||
}
|
||||
}
|
||||
|
||||
if (args.problem_size.n() % KT::kTileN != 0)
|
||||
{
|
||||
CUTLASS_HOST_TRACE(" n:" << args.problem_size.n() << " % kTileN:" << KT::kTileN << " != 0");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
|
||||
if (args.problem_size.k() % KT::kTileK != 0)
|
||||
{
|
||||
CUTLASS_HOST_TRACE(" k:" << args.problem_size.k() << " % kTileK:" << KT::kTileK << " != 0");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
|
||||
{
|
||||
|
||||
params_ = GemmKernel::to_underlying_arguments(args);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
Status run(cudaStream_t stream = nullptr)
|
||||
{
|
||||
|
||||
// Configure grid and block dimensions
|
||||
|
||||
dim3 grid = GemmKernel::get_grid_shape(params_.problem_size);
|
||||
dim3 block = GemmKernel::get_block_shape();
|
||||
|
||||
// Launch kernel
|
||||
run_global<GemmKernel><<<grid, block, kSmemSize, stream>>>(params_);
|
||||
|
||||
// Query for errors
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
CUTLASS_HOST_TRACE(" grid launch failed with error " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ada_blockwise_gemm
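// Host-side usage sketch (illustrative only; the helper name, pointers, and sizes are assumptions,
// with KT expected to be something like AdaBlockwiseGemmTraits from the companion traits header).
namespace ada_blockwise_gemm
{
template <typename KT>
cutlass::Status example_run_blockwise_gemm(int m, int n, int k, void const* a, void const* b, void* d,
    float const* scale_a, float const* scale_b, cudaStream_t stream)
{
    AdaBlockwiseGemm<KT> gemm;
    typename AdaBlockwiseGemm<KT>::Arguments args({m, n, k}, a, b, d, scale_a, scale_b);

    cutlass::Status status = gemm.can_implement(args);
    if (status != cutlass::Status::kSuccess)
    {
        return status; // e.g. N/K not divisible by the tile shape, or the smem request was rejected
    }
    status = gemm.initialize(args);
    if (status != cutlass::Status::kSuccess)
    {
        return status;
    }
    return gemm.run(stream); // launches run_global<GemmKernel> with kSmemSize dynamic shared memory
}
} // namespace ada_blockwise_gemm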
@@ -1,513 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ada_blockwise_gemm_traits.cuh"
|
||||
|
||||
namespace ada_blockwise_gemm
|
||||
{
|
||||
|
||||
template <typename KT>
|
||||
struct AdaBlockwiseGemmKernel
|
||||
{
|
||||
using Params = typename KT::Params;
|
||||
using Arguments = typename KT::Arguments;
|
||||
using SharedStorage = typename KT::SharedStorage;
|
||||
|
||||
static constexpr int kThreadCount = KT::kThreadCount;
|
||||
static constexpr int kSmemSize = KT::kSmemSize;
|
||||
|
||||
static dim3 get_grid_shape(GemmCoord problem_size)
|
||||
{
|
||||
|
||||
int grid_m = (problem_size.m() + KT::kTileM - 1) / KT::kTileM;
|
||||
int grid_n = (problem_size.n() + KT::kTileN - 1) / KT::kTileN;
|
||||
int grid_k = 1;
|
||||
return dim3(grid_m, grid_n, grid_k);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape()
|
||||
{
|
||||
return dim3(kThreadCount, 1, 1);
|
||||
}
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args)
|
||||
{
|
||||
return KT::to_underlying_arguments(args);
|
||||
}
|
||||
|
||||
// Factory invocation
|
||||
CUTLASS_DEVICE
|
||||
static void invoke(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
AdaBlockwiseGemmKernel op;
|
||||
op(params, shared_storage);
|
||||
}
|
||||
|
||||
CUTE_DEVICE auto gmem_tensor_init(Params const& params)
|
||||
{
|
||||
using X = cute::Underscore;
|
||||
|
||||
int const M = params.problem_size.m();
|
||||
int const N = params.problem_size.n();
|
||||
int const K = params.problem_size.k();
|
||||
int const ScaleM = (((M + 3) >> 2) << 2); // round M up to a multiple of 4
|
||||
int const ScaleN = (N + KT::ScaleGranularityN - 1) / KT::ScaleGranularityN;
|
||||
int const ScaleK = (K + KT::ScaleGranularityK - 1) / KT::ScaleGranularityK;
|
||||
|
||||
typename KT::ElementA const* ptr_A_ = params.ptr_a;
|
||||
typename KT::ElementB const* ptr_B_ = params.ptr_b;
|
||||
typename KT::ElementOutput* ptr_output_ = params.ptr_output;
|
||||
typename KT::ElementBlockScale const* ptr_scale_a_ = params.ptr_scale_a;
|
||||
typename KT::ElementBlockScale const* ptr_scale_b_ = params.ptr_scale_b;
|
||||
|
||||
cute::Tensor mA_mk
|
||||
= cute::make_tensor(cute::make_gmem_ptr(ptr_A_), cute::make_shape(M, K), cute::make_stride(K, cute::_1{}));
|
||||
|
||||
cute::Tensor mB_nk
|
||||
= cute::make_tensor(cute::make_gmem_ptr(ptr_B_), cute::make_shape(N, K), cute::make_stride(K, cute::_1{}));
|
||||
|
||||
cute::Tensor mOutput_mn = cute::make_tensor(
|
||||
cute::make_gmem_ptr(ptr_output_), cute::make_shape(M, N), cute::make_stride(N, cute::_1{}));
|
||||
|
||||
cute::Tensor mScaleA_mk = cute::make_tensor(
|
||||
cute::make_gmem_ptr(ptr_scale_a_), cute::make_shape(ScaleM, ScaleK), cute::make_stride(cute::_1{}, ScaleM));
|
||||
|
||||
cute::Tensor mScaleB_nk = cute::make_tensor(
|
||||
cute::make_gmem_ptr(ptr_scale_b_), cute::make_shape(ScaleN, ScaleK), cute::make_stride(ScaleK, cute::_1{}));
|
||||
|
||||
// partition the gmem tensors for each CTA
|
||||
cute::Tensor gA_mk = cute::local_tile(mA_mk, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{}); // (BLK_M, BLK_K, m, k)
|
||||
|
||||
cute::Tensor gB_nk = cute::local_tile(mB_nk, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
|
||||
|
||||
cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
|
||||
|
||||
cute::Tensor gScaleA_mk = cute::local_tile(mScaleA_mk, typename KT::ScalePerTileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{}); // (BLK_M, BLK_K, m, k)
|
||||
|
||||
cute::Tensor gScaleB_nk = cute::local_tile(mScaleB_nk, typename KT::ScalePerTileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
|
||||
|
||||
return cute::make_tuple(gA_mk, gB_nk, gOutput_mn, gScaleA_mk, gScaleB_nk);
|
||||
}
|
||||
|
||||
template <class TensorAccum, class TensorScaleA, class TensorScaleB>
|
||||
CUTE_DEVICE void promote(
|
||||
TensorAccum& accum, TensorAccum const& temp_accum, TensorScaleA const& tCrScaleA, TensorScaleB const& tCrScaleB)
|
||||
{
|
||||
|
||||
using AccumType = typename TensorAccum::value_type;
|
||||
CUTE_UNROLL
|
||||
for (int mma_m = 0; mma_m < cute::get<1>(cute::shape<0>(accum)); ++mma_m)
|
||||
{
|
||||
AccumType sFA = tCrScaleA(mma_m);
|
||||
AccumType sFB = tCrScaleB(0);
|
||||
AccumType scale = sFA * sFB;
|
||||
CUTE_UNROLL
|
||||
for (int mma_n = 0; mma_n < cute::get<0>(cute::shape<0>(accum)); ++mma_n)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int mma_iter_m = 0; mma_iter_m < cute::size<1>(accum); ++mma_iter_m)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int mma_iter_n = 0; mma_iter_n < cute::size<2>(accum); ++mma_iter_n)
|
||||
{
|
||||
auto coord = cute::make_coord(cute::make_coord(mma_n, mma_m), mma_iter_m, mma_iter_n);
|
||||
accum(coord) += temp_accum(coord) * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
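// Reference semantics of promote() above (scalar sketch added for clarity): for each accumulator
// element at logical row m and column n within the CTA tile,
//   accum(m, n) += scale_a(m, k_tile) * scale_b(n_block, k_tile) * temp_accum(m, n),
// i.e. the partial sums of one 128-wide k tile are rescaled by the per-row A scale
// (granularity 1x128) and the per-128-column B scale (granularity 128x128) before being folded
// into the final accumulator.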
|
||||
/// Executes one GEMM
|
||||
CUTE_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
int const block_m_idx = blockIdx.x;
|
||||
int const block_n_idx = blockIdx.y;
|
||||
int const thread_idx = threadIdx.x;
|
||||
int const residue_m = params.problem_size.m() - block_m_idx * cute::size<0>(typename KT::TileShape{});
|
||||
int const residue_n = params.problem_size.n() - block_n_idx * cute::size<1>(typename KT::TileShape{});
|
||||
// gmem tensor partition ..
|
||||
auto [gA_mk, gB_nk, gOutput_mn, gScaleA_mk, gScaleB_nk] = gmem_tensor_init(params);
|
||||
|
||||
// smem tensor ..
|
||||
cute::Tensor sA = cute::make_tensor(
|
||||
cute::make_smem_ptr(shared_storage.smem_a.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage)
|
||||
cute::Tensor sB = cute::make_tensor(
|
||||
cute::make_smem_ptr(shared_storage.smem_b.data()), typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
|
||||
cute::Tensor sO = cute::make_tensor(
|
||||
cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N)
|
||||
|
||||
cute::Tensor sScaleA = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_scale_a.data()),
|
||||
typename KT::SmemLayoutScaleA{}); // (BLK_M, BLK_K, Stage)
|
||||
cute::Tensor sScaleB = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_scale_b.data()),
|
||||
typename KT::SmemLayoutScaleB{}); // (BLK_N, BLK_K, Stage)
|
||||
|
||||
// (1) main loop: stage the A/B tiles and their block scales from gmem through smem
|
||||
|
||||
// (1.1) get partition for gmem -> smem
|
||||
cute::Tensor gA = gA_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k)
|
||||
cute::Tensor gB = gB_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
|
||||
|
||||
cute::Tensor gScaleA = gScaleA_mk(cute::_, cute::_, block_m_idx, cute::_);
|
||||
cute::Tensor gScaleB = gScaleB_nk(cute::_, cute::_, block_n_idx, cute::_);
|
||||
|
||||
typename KT::GmemTiledCopyA gmem_tiled_copy_A;
|
||||
typename KT::GmemTiledCopyB gmem_tiled_copy_B;
|
||||
auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx);
|
||||
auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx);
|
||||
|
||||
cute::Tensor tAgA = gmem_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k)
|
||||
cute::Tensor tAsA = gmem_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,Stage)
|
||||
cute::Tensor tBgB = gmem_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k)
|
||||
cute::Tensor tBsB = gmem_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,Stage)
|
||||
|
||||
typename KT::GmemTiledCopyScaleA gmem_tiled_copy_ScaleA;
|
||||
typename KT::GmemTiledCopyScaleB gmem_tiled_copy_ScaleB;
|
||||
auto gmem_thr_copy_ScaleA = gmem_tiled_copy_ScaleA.get_slice(thread_idx);
|
||||
auto gmem_thr_copy_ScaleB = gmem_tiled_copy_ScaleB.get_slice(thread_idx);
|
||||
|
||||
cute::Tensor tAgScaleA = gmem_thr_copy_ScaleA.partition_S(gScaleA); // (ACPY,ACPY_M,ACPY_K,k)
|
||||
cute::Tensor tAsScaleA = gmem_thr_copy_ScaleA.partition_D(sScaleA); // (ACPY,ACPY_M,ACPY_K,Stage)
|
||||
cute::Tensor tBgScaleB = gmem_thr_copy_ScaleB.partition_S(gScaleB); // (BCPY,BCPY_N,BCPY_K,k)
|
||||
cute::Tensor tBsScaleB = gmem_thr_copy_ScaleB.partition_D(sScaleB); // (BCPY,BCPY_N,BCPY_K,Stage)
|
||||
|
||||
// Allocate boundary predicate tensors for A (m residue) and B (n residue)
|
||||
cute::Tensor tApA = cute::make_tensor<bool>(
|
||||
cute::make_shape(cute::size<1>(tAsA), cute::size<2>(tAsA)), cute::Stride<cute::_1, cute::_0>{});
|
||||
// Construct identity layout for sA
|
||||
cute::Tensor cA = make_identity_tensor(
|
||||
cute::make_shape(cute::size<0>(sA), cute::size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
|
||||
// Repeat the partitioning with identity layouts
|
||||
cute::Tensor tAcA = gmem_thr_copy_A.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
|
||||
// Set predicates for m bounds
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < cute::size<0>(tApA); ++m)
|
||||
{
|
||||
tApA(m, 0) = cute::get<0>(tAcA(0, m, 0)) < residue_m; // blk_m coord < residue_m
|
||||
}
|
||||
|
||||
cute::Tensor tBpB = cute::make_tensor<bool>(
|
||||
cute::make_shape(cute::size<1>(tBsB), cute::size<2>(tBsB)), cute::Stride<cute::_1, cute::_0>{});
|
||||
// Construct identity layout for sB
|
||||
cute::Tensor cB = make_identity_tensor(
|
||||
cute::make_shape(cute::size<0>(sB), cute::size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||
// Repeat the partitioning with identity layouts
|
||||
cute::Tensor tBcB = gmem_thr_copy_B.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
|
||||
|
||||
// Set predicates for n bounds
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < cute::size<0>(tBpB); ++n)
|
||||
{
|
||||
tBpB(n, 0) = cute::get<0>(tBcB(0, n, 0)) < residue_n; // blk_n coord < residue_n
|
||||
}
|
||||
|
||||
cute::Tensor tApSFA = cute::make_tensor<bool>(
|
||||
cute::make_shape(cute::size<1>(tAsScaleA), cute::size<2>(tAsScaleA)), cute::Stride<cute::_1, cute::_0>{});
|
||||
cute::Tensor tAcSFA = gmem_thr_copy_ScaleA.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < cute::size<0>(tApSFA); ++m)
|
||||
{
|
||||
tApSFA(m, 0) = cute::get<0>(tAcSFA(0, m, 0)) < residue_m; // blk_m coord < residue_m
|
||||
}
|
||||
// (1.2) prefetch gmem -> smem
|
||||
cute::clear(tAsA); // zero the smem tiles so predicated-off (out-of-bounds) elements read as zero
|
||||
cute::clear(tBsB);
|
||||
|
||||
cute::clear(tAsScaleA);
|
||||
cute::clear(tBsScaleB);
|
||||
|
||||
auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gA)); // k-tile iterator, starting at 0
|
||||
int k_tile_count = cute::size<2>(gA);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
|
||||
{
|
||||
if (k_tile_count <= 0)
|
||||
{
|
||||
cute::clear(tApA);
|
||||
cute::clear(tBpB);
|
||||
cute::clear(tApSFA);
|
||||
}
|
||||
cute::copy_if(gmem_tiled_copy_A, tApA, tAgA(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tAsA(cute::_, cute::_, cute::_, k_pipe));
|
||||
cute::copy_if(gmem_tiled_copy_B, tBpB, tBgB(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tBsB(cute::_, cute::_, cute::_, k_pipe));
|
||||
|
||||
cute::copy_if(gmem_tiled_copy_ScaleA, tApSFA, tAgScaleA(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tAsScaleA(cute::_, cute::_, cute::_, k_pipe));
|
||||
cute::copy(gmem_tiled_copy_ScaleB, tBgScaleB(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tBsScaleB(cute::_, cute::_, cute::_, k_pipe));
|
||||
|
||||
cute::cp_async_fence();
|
||||
k_tile_count--;
|
||||
if (k_tile_count > 0)
|
||||
{
|
||||
++k_tile_iter;
|
||||
}
|
||||
}
|
||||
|
||||
// (1.3) get partition for rf
|
||||
typename KT::TiledMma tiled_mma;
|
||||
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
cute::Tensor tCrA = thr_mma.partition_fragment_A(sA(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K)
|
||||
cute::Tensor tCrB = thr_mma.partition_fragment_B(sB(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
|
||||
|
||||
cute::Tensor accum
|
||||
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
|
||||
cute::Tensor temp_accum
|
||||
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
|
||||
cute::clear(accum);
|
||||
// sanity-check the shapes
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tCrA) == cute::size<1>(accum)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tCrB) == cute::size<2>(accum)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tCrA) == cute::size<2>(tCrB)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
|
||||
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));
|
||||
|
||||
// (1.4) retile the smem and register fragments for copy
|
||||
auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma);
|
||||
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
|
||||
cute::Tensor tOsA = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,Stage)
|
||||
cute::Tensor tCrA_write = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K)
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsA) == cute::size<1>(tCrA_write)); // CPY_M
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsA) == cute::size<2>(tCrA_write)); // CPY_K
|
||||
|
||||
auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma);
|
||||
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
|
||||
cute::Tensor tOsB = smem_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,Stage)
|
||||
cute::Tensor tCrB_write = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K)
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsB) == cute::size<1>(tCrB_write)); // CPY_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsB) == cute::size<2>(tCrB_write)); // CPY_K
|
||||
|
||||
typename KT::SmemCopyAtomScaleA smem_tiled_copy_ScaleA;
|
||||
typename KT::SmemCopyAtomScaleB smem_tiled_copy_ScaleB;
|
||||
auto smem_thr_copy_ScaleA = smem_tiled_copy_ScaleA.get_thread_slice(thread_idx);
|
||||
auto smem_thr_copy_ScaleB = smem_tiled_copy_ScaleB.get_thread_slice(thread_idx);
|
||||
|
||||
cute::Tensor tOsScaleA = smem_thr_copy_ScaleA.partition_S(sScaleA);
|
||||
cute::Tensor tCrScaleA = cute::make_fragment_like(tOsScaleA(cute::_, cute::_, cute::_, 0));
|
||||
cute::Tensor tOsScaleB = smem_thr_copy_ScaleB.partition_S(sScaleB);
|
||||
cute::Tensor tCrScaleB = cute::make_fragment_like(tOsScaleB(cute::_, cute::_, cute::_, 0));
|
||||
|
||||
// (1.5) mainloop
|
||||
// Current pipe index in smem to read from
|
||||
int smem_pipe_read = 0;
|
||||
// Current pipe index in smem to write to
|
||||
int smem_pipe_write = KT::Stages - 1;
|
||||
|
||||
cute::Tensor tOsA_read = tOsA(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::Tensor tOsB_read = tOsB(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
|
||||
cute::Tensor tOsScaleA_read = tOsScaleA(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::Tensor tOsScaleB_read = tOsScaleB(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
|
||||
constexpr int K_BLOCK_MAX = cute::size<2>(tCrA);
|
||||
// prefetch register pipeline
|
||||
if constexpr (K_BLOCK_MAX > 1)
|
||||
{
|
||||
cute::cp_async_wait<KT::Stages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Prefetch the first k-tile smem -> reg
|
||||
cute::copy(smem_tiled_copy_A, tOsA_read(cute::_, cute::_, cute::Int<0>{}),
|
||||
tCrA_write(cute::_, cute::_, cute::Int<0>{}));
|
||||
cute::copy(smem_tiled_copy_B, tOsB_read(cute::_, cute::_, cute::Int<0>{}),
|
||||
tCrB_write(cute::_, cute::_, cute::Int<0>{}));
|
||||
}
|
||||
// k loop for mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
cute::clear(temp_accum);
|
||||
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
|
||||
[&](auto k_block)
|
||||
{
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
tOsA_read = tOsA(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsB_read = tOsB(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
|
||||
tOsScaleA_read = tOsScaleA(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsScaleB_read = tOsScaleB(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
|
||||
cute::cp_async_wait<KT::Stages - 2>();
|
||||
__syncthreads();
|
||||
}
|
||||
// Load A, B smem -> reg for k_block+1
|
||||
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
|
||||
cute::copy(smem_tiled_copy_A, tOsA_read(cute::_, cute::_, k_block_next),
|
||||
tCrA_write(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsB_read(cute::_, cute::_, k_block_next),
|
||||
tCrB_write(cute::_, cute::_, k_block_next));
|
||||
// Copy gmem -> smem before computing gemm on each k-pipe
|
||||
if (k_block == 0)
|
||||
{
|
||||
cute::copy_if(gmem_tiled_copy_A, tApA, tAgA(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tAsA(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::copy_if(gmem_tiled_copy_B, tBpB, tBgB(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tBsB(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
|
||||
cute::copy_if(gmem_tiled_copy_ScaleA, tApSFA,
|
||||
tAgScaleA(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tAsScaleA(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::copy(gmem_tiled_copy_ScaleB, tBgScaleB(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tBsScaleB(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
|
||||
cute::cp_async_fence();
|
||||
if (k_tile_count - 1 > 0)
|
||||
{
|
||||
++k_tile_iter;
|
||||
}
|
||||
|
||||
cute::copy(smem_tiled_copy_ScaleA, tOsScaleA_read, tCrScaleA);
|
||||
cute::copy(smem_tiled_copy_ScaleB, tOsScaleB_read, tCrScaleB);
|
||||
|
||||
// Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe)
|
||||
smem_pipe_write = smem_pipe_read;
|
||||
++smem_pipe_read;
|
||||
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
|
||||
}
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(tiled_mma, temp_accum, tCrA(cute::_, cute::_, k_block), tCrB(cute::_, cute::_, k_block),
|
||||
temp_accum);
|
||||
});
|
||||
|
||||
promote(accum, temp_accum, tCrScaleA, tCrScaleB);
|
||||
}
|
||||
// load tail
|
||||
cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
|
||||
[&](auto WaitIndex)
|
||||
{
|
||||
k_tile_count--;
|
||||
using WaitIndex_t = decltype(WaitIndex);
|
||||
cute::clear(temp_accum);
|
||||
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
|
||||
[&](auto k_block)
|
||||
{
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
tOsA_read = tOsA(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsB_read = tOsB(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
|
||||
tOsScaleA_read = tOsScaleA(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsScaleB_read = tOsScaleB(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
|
||||
cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
|
||||
__syncthreads();
|
||||
}
|
||||
// Load A, B smem -> reg for k_block+1
|
||||
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
|
||||
cute::copy(smem_tiled_copy_A, tOsA_read(cute::_, cute::_, k_block_next),
|
||||
tCrA_write(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsB_read(cute::_, cute::_, k_block_next),
|
||||
tCrB_write(cute::_, cute::_, k_block_next));
|
||||
if (k_block == 0)
|
||||
{
|
||||
|
||||
cute::copy(smem_tiled_copy_ScaleA, tOsScaleA_read, tCrScaleA);
|
||||
cute::copy(smem_tiled_copy_ScaleB, tOsScaleB_read, tCrScaleB);
|
||||
|
||||
// only update smem_pipe_read
|
||||
++smem_pipe_read;
|
||||
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
|
||||
}
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(tiled_mma, temp_accum, tCrA(cute::_, cute::_, k_block),
|
||||
tCrB(cute::_, cute::_, k_block), temp_accum);
|
||||
});
|
||||
|
||||
promote(accum, temp_accum, tCrScaleA, tCrScaleB);
|
||||
});
|
||||
// mma tail
|
||||
cute::clear(temp_accum);
|
||||
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
|
||||
[&](auto k_block)
|
||||
{
|
||||
// Load A, B smem -> reg for k_block+1
|
||||
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
|
||||
cute::copy(smem_tiled_copy_A, tOsA_read(cute::_, cute::_, k_block_next),
|
||||
tCrA_write(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsB_read(cute::_, cute::_, k_block_next),
|
||||
tCrB_write(cute::_, cute::_, k_block_next));
|
||||
if (k_block == 0)
|
||||
{
|
||||
|
||||
cute::copy(smem_tiled_copy_ScaleA, tOsScaleA_read, tCrScaleA);
|
||||
cute::copy(smem_tiled_copy_ScaleB, tOsScaleB_read, tCrScaleB);
|
||||
}
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(tiled_mma, temp_accum, tCrA(cute::_, cute::_, k_block), tCrB(cute::_, cute::_, k_block),
|
||||
temp_accum);
|
||||
});
|
||||
|
||||
promote(accum, temp_accum, tCrScaleA, tCrScaleB);
|
||||
|
||||
// (4) push all the result to smem
|
||||
// (4.1) convert the result from ElementAccum to ElementOutput
|
||||
cute::Tensor epi = util_convert_type<KT::ElementOutput>(accum);
|
||||
|
||||
// (4.2) rf -> smem
|
||||
auto smem_tiled_copy_R2S = cute::make_tiled_copy_C(typename KT::SmemCopyAtomR2S{}, tiled_mma);
|
||||
auto smem_thr_copy_R2S = smem_tiled_copy_R2S.get_thread_slice(thread_idx);
|
||||
// cute::clear(sO);
|
||||
cute::Tensor tRS_rO = smem_thr_copy_R2S.retile_S(epi);
|
||||
cute::Tensor tRS_sO = smem_thr_copy_R2S.partition_D(sO);
|
||||
|
||||
cute::copy(smem_tiled_copy_R2S, tRS_rO, tRS_sO);
|
||||
__syncthreads();
|
||||
|
||||
// (4.3) smem -> rf
|
||||
|
||||
typename KT::SmemTiledCopyS2R smem_tiled_copy_S2R;
|
||||
auto smem_thr_copy_S2R = smem_tiled_copy_S2R.get_thread_slice(thread_idx);
|
||||
cute::Tensor tSR_sO = smem_thr_copy_S2R.partition_S(sO);
|
||||
cute::Tensor tSR_rO = cute::make_tensor<KT::ElementOutput>(cute::shape(tSR_sO));
|
||||
|
||||
cute::copy(smem_tiled_copy_S2R, tSR_sO, tSR_rO);
|
||||
__syncthreads();
|
||||
|
||||
// (4.4) rf -> gmem
|
||||
cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx);
|
||||
cute::Tensor cO = cute::make_identity_tensor(
|
||||
cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{})));
|
||||
auto tRG_rO = smem_thr_copy_S2R.retile_S(tSR_rO);
|
||||
auto tRG_gO = smem_thr_copy_S2R.partition_D(gO);
|
||||
auto tRG_cO = smem_thr_copy_S2R.partition_D(cO);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < cute::size<1>(tRG_cO); ++m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < cute::size<2>(tRG_cO); ++n)
|
||||
{
|
||||
if (cute::get<0>(tRG_cO(0, m, n)) < residue_m && cute::get<1>(tRG_cO(0, m, n)) < residue_n)
|
||||
{
|
||||
cute::copy(typename KT::GmemCopyAtomR2G{}, tRG_rO(cute::_, m, n), tRG_gO(cute::_, m, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ada_blockwise_gemm
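// Pipeline summary (descriptive comment added for clarity): the kernel above follows the usual
// cp.async multistage pattern: the prologue prefetches Stages-1 k-tiles gmem -> smem, the main
// k loop overlaps the MMA on the current stage with the gmem -> smem copy of the next stage, the
// "load tail" drains the Stages-2 outstanding stages, and the "mma tail" computes on the final
// stage. After each k tile, promote() folds the block-scaled partial sums into the final
// accumulator, and the epilogue stages the result through shared memory (rf -> smem -> rf -> gmem)
// so that the global stores are contiguous and predicated on the M/N residues.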
@@ -1,217 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
#include <cuda_runtime.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cute/int_tuple.hpp>
|
||||
#include <cute/layout.hpp>
|
||||
#include <cutlass/arch/mma.h>
|
||||
#include "ada_blockwise_mma_utils.cuh"
|
||||
#include "ada_blockwise_copy_utils.cuh"
|
||||
|
||||
// clang-format on
|
||||
using namespace cute;
|
||||
using namespace cutlass;
|
||||
using namespace cutlass::gemm;
|
||||
|
||||
namespace ada_blockwise_gemm
|
||||
{
|
||||
|
||||
template <typename ElementType, typename OutElementType, typename AccumElementType, typename BlockScaleElementType,
|
||||
int Stages_, int TileM_, int TileN_, int TileK_>
|
||||
struct AdaBlockwiseGemmTraits
|
||||
{
|
||||
using ElementA = ElementType;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
using ElementB = ElementType;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
using ElementAccum = AccumElementType;
|
||||
using ElementBlockScale = BlockScaleElementType;
|
||||
using ElementOutput = OutElementType;
|
||||
|
||||
using index_t = uint32_t;
|
||||
static_assert(TileM_ % 16 == 0);
|
||||
static_assert(TileN_ % 32 == 0);
|
||||
static_assert(TileK_ % 32 == 0);
|
||||
static constexpr int Stages = Stages_;
|
||||
static constexpr int kTileM = TileM_;
|
||||
static constexpr int kTileN = TileN_;
|
||||
static constexpr int kTileK = (kTileM > 16) ? (TileK_) : (TileK_ >= 64 ? TileK_ : 64);
|
||||
|
||||
// tile shape
|
||||
using TileShape = cute::Shape<cute::Int<kTileM>, cute::Int<kTileN>, cute::Int<kTileK>>;
|
||||
static constexpr int kWarpsCount = 4;
|
||||
static constexpr int kThreadCount = kWarpsCount * 32;
|
||||
|
||||
static constexpr int ScaleGranularityM = 1;
|
||||
static constexpr int ScaleGranularityN = 128;
|
||||
static constexpr int ScaleGranularityK = 128;
|
||||
|
||||
static constexpr int ScaleMsPerTile = (kTileM + ScaleGranularityM - 1) / ScaleGranularityM;
|
||||
static constexpr int ScaleNsPerTile = (kTileN + ScaleGranularityN - 1) / ScaleGranularityN;
|
||||
static constexpr int ScaleKsPerTile = (kTileK + ScaleGranularityK - 1) / ScaleGranularityK;
|
||||
|
||||
static_assert(ScaleKsPerTile >= 1, "ScaleKsPerTile must be greater than or equal to 1");
|
||||
|
||||
using ScaleGranularity
|
||||
= cute::Shape<cute::Int<ScaleGranularityM>, cute::Int<ScaleGranularityN>, cute::Int<ScaleGranularityK>>;
|
||||
using ScalePerTileShape
|
||||
= cute::Shape<cute::Int<ScaleMsPerTile>, cute::Int<ScaleNsPerTile>, cute::Int<ScaleKsPerTile>>;
|
||||
|
||||
// MMA atom arch and layout
|
||||
using TiledMma = DefaultGemm_TensorOp_MMA<cute::float_e4m3_t, cutlass::arch::Sm89>::TiledMma;
|
||||
|
||||
static constexpr int kBlockKSmem = 128;
|
||||
// A memory copy operand
|
||||
using DefaultOperandA
|
||||
= DefaultGemm_TensorOpSm80_OperandA<ElementA, cutlass::layout::RowMajor, AlignmentA, kBlockKSmem>;
|
||||
using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom;
|
||||
using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
|
||||
using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;
|
||||
|
||||
// ScaleA memory copy operand
|
||||
using SmemLayoutAtomScaleA = decltype(cute::make_layout(typename CopyTraitsScaleA::SmemTileShapeScale{}));
|
||||
using SmemCopyAtomScaleA = typename CopyTraitsScaleA::SmemTiledCopyScale;
|
||||
using GmemTiledCopyScaleA = typename CopyTraitsScaleA::GmemTiledCopyScale;
|
||||
|
||||
// B memory copy operand
|
||||
using DefaultOperandB
|
||||
= DefaultGemm_TensorOpSm80_OperandB<ElementB, cutlass::layout::ColumnMajor, AlignmentB, kBlockKSmem>;
|
||||
using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom;
|
||||
using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
|
||||
using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;
|
||||
|
||||
// ScaleB memory copy operand
|
||||
using SmemLayoutAtomScaleB = decltype(cute::make_layout(typename CopyTraitsScaleB::SmemTileShapeScale{}));
|
||||
using SmemCopyAtomScaleB = typename CopyTraitsScaleB::SmemTiledCopyScale;
|
||||
using GmemTiledCopyScaleB = typename CopyTraitsScaleB::GmemTiledCopyScale;
|
||||
|
||||
// Output memory copy operand
|
||||
using SmemLayoutAtomO = decltype(cute::composition(cute::Swizzle<3, 3, 3>{},
|
||||
cute::Layout<cute::Shape<cute::_8, cute::Shape<cute::_8, cute::_8>>,
|
||||
cute::Stride<cute::_8, cute::Stride<cute::_1, cute::_64>>>{}));
|
||||
|
||||
using SmemCopyAtomR2S = cute::Copy_Atom<cute::AutoVectorizingCopy, ElementOutput>;
|
||||
using SmemCopyAtomS2R = cute::Copy_Atom<cute::UniversalCopy<cute::uint128_t>, ElementOutput>;
|
||||
using GmemCopyAtomR2G = cute::Copy_Atom<cute::UniversalCopy<cute::uint128_t>, ElementOutput>;
|
||||
using SmemThrLayoutS2R
|
||||
= cute::Layout<cute::Shape<cute::Int<8>, cute::Int<16>>, cute::Stride<cute::Int<16>, cute::_1>>;
|
||||
using SmemValLayoutS2R = cute::Layout<cute::Shape<cute::Int<1>, cute::Int<8>>>;
|
||||
using SmemTiledCopyS2R = decltype(cute::make_tiled_copy(SmemCopyAtomS2R{}, SmemThrLayoutS2R{}, SmemValLayoutS2R{}));
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2);
|
||||
static_assert(cute::size<0>(TileShape{}) % cute::size<0>(SmemLayoutAtomA{}) == 0); // M
|
||||
static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomA{}) == 0); // K
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2);
|
||||
static_assert(cute::size<1>(TileShape{}) % cute::size<0>(SmemLayoutAtomB{}) == 0); // N
|
||||
static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomB{}) == 0); // K
|
||||
|
||||
using SmemLayoutA = decltype(cute::tile_to_shape(SmemLayoutAtomA{},
|
||||
cute::make_shape(
|
||||
cute::shape<0>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{}))); // BLK_M, BLK_K, Stages
|
||||
using SmemLayoutB = decltype(cute::tile_to_shape(SmemLayoutAtomB{},
|
||||
cute::make_shape(
|
||||
cute::shape<1>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{}))); // BLK_N, BLK_K, Stages
|
||||
using SmemLayoutO = decltype(cute::tile_to_shape(
|
||||
SmemLayoutAtomO{}, cute::make_shape(cute::shape<0>(TileShape{}), cute::shape<1>(TileShape{})))); // BLK_M, BLK_N
|
||||
|
||||
using SmemLayoutScaleA = decltype(cute::tile_to_shape(SmemLayoutAtomScaleA{},
|
||||
cute::make_shape(cute::shape<0>(ScalePerTileShape{}), cute::shape<2>(ScalePerTileShape{}),
|
||||
cute::Int<Stages>{}))); // BLK_M, BLK_K, Stages
|
||||
using SmemLayoutScaleB = decltype(cute::tile_to_shape(SmemLayoutAtomScaleB{},
|
||||
cute::make_shape(cute::shape<1>(ScalePerTileShape{}), cute::shape<2>(ScalePerTileShape{}),
|
||||
cute::Int<Stages>{}))); // BLK_N, BLK_K, Stages
|
||||
|
||||
// we need at least 2 stages..
|
||||
static_assert(Stages >= 2);
|
||||
|
||||
struct SharedStorage : cute::aligned_struct<128>
|
||||
{
|
||||
cute::array_aligned<ElementA, cute::cosize_v<SmemLayoutA>> smem_a;
|
||||
cute::array_aligned<ElementB, cute::cosize_v<SmemLayoutB>> smem_b;
|
||||
cute::array_aligned<ElementOutput, cute::cosize_v<SmemLayoutO>> smem_o;
|
||||
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleA>> smem_scale_a;
|
||||
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleB>> smem_scale_b;
|
||||
};
|
||||
|
||||
static constexpr int kSmemSize = static_cast<int>(sizeof(SharedStorage));
|
||||
|
||||
struct Params
|
||||
{
|
||||
GemmCoord problem_size{};
|
||||
ElementA const* ptr_a{};
|
||||
ElementB const* ptr_b{};
|
||||
ElementOutput* ptr_output{};
|
||||
BlockScaleElementType const* ptr_scale_a{};
|
||||
BlockScaleElementType const* ptr_scale_b{};
|
||||
|
||||
Params() {}
|
||||
|
||||
Params(GemmCoord problem_size_, ElementA const* ptr_a_, ElementB const* ptr_b_, ElementOutput* ptr_output_,
|
||||
BlockScaleElementType const* ptr_scale_a_, BlockScaleElementType const* ptr_scale_b_)
|
||||
: problem_size(problem_size_)
|
||||
, ptr_a(ptr_a_)
|
||||
, ptr_b(ptr_b_)
|
||||
, ptr_output(ptr_output_)
|
||||
, ptr_scale_a(ptr_scale_a_)
|
||||
, ptr_scale_b(ptr_scale_b_)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct Arguments
|
||||
{
|
||||
GemmCoord problem_size{};
|
||||
void const* ptr_a;
|
||||
void const* ptr_b;
|
||||
void* ptr_d;
|
||||
float const* ptr_scale_a;
|
||||
float const* ptr_scale_b;
|
||||
|
||||
Arguments(GemmCoord problem_size_, void const* ptr_a_, void const* ptr_b_, void* ptr_d_,
|
||||
float const* ptr_scale_a_, float const* ptr_scale_b_)
|
||||
: problem_size(problem_size_)
|
||||
, ptr_a(ptr_a_)
|
||||
, ptr_b(ptr_b_)
|
||||
, ptr_d(ptr_d_)
|
||||
, ptr_scale_a(ptr_scale_a_)
|
||||
, ptr_scale_b(ptr_scale_b_)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args)
|
||||
{
|
||||
auto ptr_a = reinterpret_cast<ElementA const*>(args.ptr_a);
|
||||
auto ptr_b = reinterpret_cast<ElementB const*>(args.ptr_b);
|
||||
auto ptr_d = reinterpret_cast<ElementOutput*>(args.ptr_d);
|
||||
auto ptr_scale_a = reinterpret_cast<ElementBlockScale const*>(args.ptr_scale_a);
|
||||
auto ptr_scale_b = reinterpret_cast<ElementBlockScale const*>(args.ptr_scale_b);
|
||||
|
||||
Params params(args.problem_size, ptr_a, ptr_b, ptr_d, ptr_scale_a, ptr_scale_b);
|
||||
|
||||
return params;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ada_blockwise_gemm
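// Illustrative instantiation (a sketch; the element types, stage count, and tile sizes below are
// assumptions chosen for demonstration, not values required by this header).
namespace ada_blockwise_gemm_example
{
using ExampleTraits = ada_blockwise_gemm::AdaBlockwiseGemmTraits<cute::float_e4m3_t, cute::bfloat16_t,
    float, float, /*Stages*/ 3, /*TileM*/ 128, /*TileN*/ 128, /*TileK*/ 128>;
static_assert(ExampleTraits::kThreadCount == 128, "four warps per CTA");
// ExampleTraits::kSmemSize is the dynamic shared memory (A/B tiles, block scales, and the output
// staging buffer) that a launch must request per CTA.
} // namespace ada_blockwise_gemm_example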
@@ -1,104 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include <cute/arch/mma.hpp>
|
||||
#include <cute/config.hpp>
|
||||
#include <cute/layout.hpp>
|
||||
#include <cutlass/arch/mma.h>
|
||||
|
||||
// Config
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890))
|
||||
#define CUTE_ARCH_MMA_F32_SM89_ENABLED
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
// MMA 16x8x32 TN
|
||||
struct SM89_16x8x32_F32F8F8F32_TN
|
||||
{
|
||||
using DRegisters = float[4];
|
||||
using ARegisters = uint32_t[4];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = float[4];
|
||||
|
||||
CUTE_HOST_DEVICE static void fma(float& d0, float& d1, float& d2, float& d3, uint32_t const& a0, uint32_t const& a1,
|
||||
uint32_t const& a2, uint32_t const& a3, uint32_t const& b0, uint32_t const& b1, float const& c0,
|
||||
float const& c1, float const& c2, float const& c3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED)
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0, %1, %2, %3},"
|
||||
"{%4, %5, %6, %7},"
|
||||
"{%8, %9},"
|
||||
"{%10, %11, %12, %13};\n"
|
||||
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
|
||||
: "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "f"(c0), "f"(c1), "f"(c2), "f"(c3));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH(
|
||||
"Attempting to use SM89_16x8x32_F32F8F8F32_TN without "
|
||||
"CUTE_ARCH_MMA_F32_SM89_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM89_16x8x32_F32F8F8F32_TN>
|
||||
{
|
||||
using ValTypeD = float;
|
||||
using ValTypeA = float_e4m3_t;
|
||||
using ValTypeB = float_e4m3_t;
|
||||
using ValTypeC = float;
|
||||
|
||||
using Shape_MNK = Shape<_16, _8, _32>;
|
||||
using ThrID = Layout<_32>;
|
||||
using ALayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2, _2>>, Stride<Stride<_64, _1>, Stride<_16, _8, _256>>>;
|
||||
using BLayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2>>, Stride<Stride<_32, _1>, Stride<_8, _128>>>;
|
||||
using CLayout = SM80_16x8_Row;
|
||||
};
|
||||
|
||||
} // namespace cute
|
||||
|
||||
namespace ada_blockwise_gemm
|
||||
{
|
||||
|
||||
template <typename Element, typename Arch>
|
||||
struct DefaultGemm_TensorOp_MMA;
|
||||
|
||||
template <>
|
||||
struct DefaultGemm_TensorOp_MMA<cute::bfloat16_t, cutlass::arch::Sm80>
|
||||
{
|
||||
using ArchTag = cutlass::arch::Sm80;
|
||||
using MMA_Atom_Arch = cute::MMA_Atom<cute::SM80_16x8x16_F32BF16BF16F32_TN>;
|
||||
using ThreadLayoutMNK = cute::Layout<cute::Shape<cute::_2, cute::_2, cute::_1>>;
|
||||
using ValLayoutMNK = cute::Tile<cute::_32, cute::_32, cute::_16>;
|
||||
using TiledMma = cute::TiledMMA<MMA_Atom_Arch, ThreadLayoutMNK, ValLayoutMNK>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DefaultGemm_TensorOp_MMA<cute::float_e4m3_t, cutlass::arch::Sm89>
|
||||
{
|
||||
using ArchTag = cutlass::arch::Sm89;
|
||||
using MMA_Atom_Arch = cute::MMA_Atom<cute::SM89_16x8x32_F32F8F8F32_TN>;
|
||||
using ThreadLayoutMNK = cute::Layout<cute::Shape<cute::_2, cute::_2, cute::_1>>;
|
||||
using ValLayoutMNK = cute::Tile<cute::_32, cute::_32, cute::_32>;
|
||||
using TiledMma = cute::TiledMMA<MMA_Atom_Arch, ThreadLayoutMNK, ValLayoutMNK>;
|
||||
};
|
||||
|
||||
} // namespace ada_blockwise_gemm

@ -0,0 +1,453 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "sm89_utils.cuh"

namespace ada_blockwise_gemm
{

template <typename GemmKernel>
CUTLASS_GLOBAL void sm89_fp8_gemm_1d1d_impl(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, void const* A,
    void const* B, void* D, float const* scales_a, float const* scales_b)
{
    GemmKernel op;
    op.invoke(shape_m, shape_n, shape_k, A, B, D, scales_a, scales_b);
}

template <typename GemmKernel>
CUTLASS_GLOBAL void sm89_fp8_bmm_1d1d_impl(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, __nv_fp8_e4m3* A,
    __nv_fp8_e4m3* B, __nv_bfloat16* D, float* scales_a, float* scales_b, uint64_t stride_a, uint64_t stride_b,
    uint64_t stride_d, uint64_t stride_scales_a, uint64_t stride_scales_b)
{
    GemmKernel op;

    auto ptr_a = reinterpret_cast<typename GemmKernel::ElementInput const*>(A + blockIdx.z * stride_a);
    auto ptr_b = reinterpret_cast<typename GemmKernel::ElementInput const*>(B + blockIdx.z * stride_b);
    auto ptr_scale_a
        = reinterpret_cast<typename GemmKernel::ElementBlockScale const*>(scales_a + blockIdx.z * stride_scales_a);
    auto ptr_scale_b
        = reinterpret_cast<typename GemmKernel::ElementBlockScale const*>(scales_b + blockIdx.z * stride_scales_b);
    auto ptr_output = reinterpret_cast<typename GemmKernel::ElementOutput*>(D + blockIdx.z * stride_d);

    op(ptr_a, ptr_b, ptr_scale_a, ptr_scale_b, ptr_output, shape_m, shape_n, shape_k);
}
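
Both entry points take raw FP8 operands plus per-block dequantization scales: A carries one float scale per row per 128-wide K block (1x128 granularity), while B carries one float scale per 128x128 block. As a reading aid only — the helper below is not part of the commit and its name and parameters are hypothetical — this is a scalar sketch of the math these kernels implement; the scale indexing mirrors the global-memory layouts set up in gmem_tensor_init further down (A scales are M-contiguous per K block, with M rounded up to a multiple of 4; B scales are K-block-contiguous per N block).

#include <algorithm>

// Illustrative scalar reference: D[m][n] = sum over K blocks kb of
// sfa[m, kb] * sfb[n/128, kb] * (sum of A_q[m][k] * B_q[n][k] over the 128 columns of block kb).
// A_q/B_q hold the FP8 values widened to float for clarity; D is row-major (M, N).
inline void blockwise_gemm_reference(
    int M, int N, int K, float const* A_q, float const* B_q, float const* sfa, float const* sfb, float* D)
{
    int const kBlocks = (K + 127) / 128;
    int const scaleM = (M + 3) / 4 * 4; // A-scale rows padded to a multiple of 4, as the kernel expects
    for (int m = 0; m < M; ++m)
    {
        for (int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for (int kb = 0; kb < kBlocks; ++kb)
            {
                float partial = 0.f; // unscaled FP8 products, the role of the kernel's `temp` accumulator
                for (int k = kb * 128; k < std::min(K, (kb + 1) * 128); ++k)
                {
                    partial += A_q[m * K + k] * B_q[n * K + k];
                }
                // promote step: fold the partial sum in, weighted by the A row scale and the B block scale
                acc += partial * sfa[kb * scaleM + m] * sfb[(n / 128) * kBlocks + kb];
            }
            D[m * N + n] = acc;
        }
    }
}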

template <typename KT>
struct AdaBlockwiseGemmKernel
{
    using SharedStorage = typename KT::SharedStorage;
    using ElementInput = typename KT::ElementInput;
    using ElementOutput = typename KT::ElementOutput;
    using ElementBlockScale = typename KT::ElementBlockScale;

    // Factory invocation
    CUTLASS_DEVICE
    void invoke(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, void const* A, void const* B, void* D,
        float const* scales_a, float const* scales_b)
    {
        auto ptr_a = reinterpret_cast<ElementInput const*>(A);
        auto ptr_b = reinterpret_cast<ElementInput const*>(B);
        auto ptr_scale_a = reinterpret_cast<ElementBlockScale const*>(scales_a);
        auto ptr_scale_b = reinterpret_cast<ElementBlockScale const*>(scales_b);
        auto ptr_output = reinterpret_cast<ElementOutput*>(D);

        (*this)(ptr_a, ptr_b, ptr_scale_a, ptr_scale_b, ptr_output, shape_m, shape_n, shape_k);
    }

    CUTE_DEVICE auto gmem_tensor_init(typename KT::ElementInput const* ptr_a, typename KT::ElementInput const* ptr_b,
        typename KT::ElementBlockScale const* ptr_scale_a, typename KT::ElementBlockScale const* ptr_scale_b,
        uint32_t M, uint32_t N, uint32_t K, int* SharedStorageBase)
    {
        using X = cute::Underscore;

        uint32_t const ScaleM = (((M + 3) >> 2) << 2); // align 4
        uint32_t const ScaleN = (N + KT::ScaleGranularityN - 1) / KT::ScaleGranularityN;
        uint32_t const ScaleK = (K + KT::ScaleGranularityK - 1) / KT::ScaleGranularityK;

        auto mA_mk
            = cute::make_tensor(cute::make_gmem_ptr(ptr_a), cute::make_shape(M, K), cute::make_stride(K, cute::_1{}));

        auto mB_nk
            = cute::make_tensor(cute::make_gmem_ptr(ptr_b), cute::make_shape(N, K), cute::make_stride(K, cute::_1{}));

        auto mSFA_mk = cute::make_tensor(
            cute::make_gmem_ptr(ptr_scale_a), cute::make_shape(ScaleM, ScaleK), cute::make_stride(cute::_1{}, ScaleM));

        auto mSFB_nk = cute::make_tensor(
            cute::make_gmem_ptr(ptr_scale_b), cute::make_shape(ScaleN, ScaleK), cute::make_stride(ScaleK, cute::_1{}));

        auto cta_coord = cute::make_coord(blockIdx.x, blockIdx.y, cute::_); // (m,n,k)
        auto gA
            = cute::local_tile(mA_mk, typename KT::TileShape{}, cta_coord, cute::Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
        auto gB
            = cute::local_tile(mB_nk, typename KT::TileShape{}, cta_coord, cute::Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
        auto gSFA = cute::local_tile(
            mSFA_mk, typename KT::ScalePerTileShape{}, cta_coord, cute::Step<_1, X, _1>{}); // (BLK_M,BLK_K)
        auto gSFB = cute::local_tile(
            mSFB_nk, typename KT::ScalePerTileShape{}, cta_coord, cute::Step<X, _1, _1>{}); // (BLK_N,BLK_K)

        typename KT::SharedStorageLoad* load_storage
            = reinterpret_cast<typename KT::SharedStorageLoad*>(SharedStorageBase);
        auto sA = cute::make_tensor(cute::make_smem_ptr(load_storage->smem_a.data()), typename KT::SmemLayoutA{});
        auto sB = cute::make_tensor(cute::make_smem_ptr(load_storage->smem_b.data()), typename KT::SmemLayoutB{});
        auto sSFA = cute::make_tensor(cute::make_smem_ptr(load_storage->smem_sfa.data()), typename KT::SmemLayoutSFA{});
        auto sSFB = cute::make_tensor(cute::make_smem_ptr(load_storage->smem_sfb.data()), typename KT::SmemLayoutSFB{});

        return cute::make_tuple(gA, gB, gSFA, gSFB, sA, sB, sSFA, sSFB);
    }

    template <class Accumulator, class SharedStorage, class ElementOutput>
    CUTE_DEVICE void epilogue_with_smem(
        Accumulator& accum, SharedStorage& shared_storage, ElementOutput* o, int M, int N)
    {
        // convert type
        auto epi = cute::make_fragment_like<ElementOutput>(accum);
        cute::for_each(cute::make_int_sequence<cute::size(epi)>{}, [&](auto i) { epi(i) = ElementOutput(accum(i)); });

        auto sO = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{});
        // copy rf -> smem
        typename KT::TiledMma mma;
        auto tiled_copy_R2S = cute::make_tiled_copy_C(typename KT::SmemCopyAtomR2S{}, mma);
        auto thr_copy_R2S = tiled_copy_R2S.get_slice(threadIdx.x);
        auto tRS_rO = thr_copy_R2S.retile_S(epi);
        auto tRS_sO = thr_copy_R2S.partition_D(sO);

        cute::copy(tiled_copy_R2S, tRS_rO, tRS_sO);
        __syncthreads();

        // copy smem -> rf
        typename KT::TiledCopyS2G tiled_copy_S2G;
        auto thr_copy_S2G = tiled_copy_S2G.get_slice(threadIdx.x);
        auto tSR_sO = thr_copy_S2G.partition_S(sO);
        auto tSR_rO = cute::make_tensor<typename KT::ElementOutput>(cute::shape(tSR_sO));

        cute::copy(tiled_copy_S2G, tSR_sO, tSR_rO);
        __syncthreads();

        // copy rf -> gmem
        auto mO = cute::make_tensor(cute::make_gmem_ptr(o), cute::make_shape(M, N), cute::make_stride(N, cute::_1{}));
        auto cta_coord = cute::make_coord(blockIdx.x, blockIdx.y, cute::_);
        auto gO = cute::local_tile(mO, typename KT::TileShape{}, cta_coord, cute::Step<cute::_1, cute::_1, X>{});
        auto cO = cute::make_identity_tensor(cute::make_shape(cute::Int<KT::kTileM>{}, cute::Int<KT::kTileN>{}));
        auto tRG_rO = thr_copy_S2G.retile_S(tSR_rO);
        auto tRG_gO = thr_copy_S2G.partition_D(gO);
        auto tRG_cO = thr_copy_S2G.partition_D(cO);

        int residue_m = M - KT::kTileM * blockIdx.x;
        int residue_n = N - KT::kTileN * blockIdx.y;
        CUTE_UNROLL
        for (int m = 0; m < cute::size<1>(tRG_gO); ++m)
        {
            CUTE_UNROLL
            for (int n = 0; n < cute::size<2>(tRG_gO); ++n)
            {
                if (cute::get<0>(tRG_cO(0, m, n)) < residue_m && cute::get<1>(tRG_cO(0, m, n)) < residue_n)
                {
                    cute::copy(typename KT::GmemCopyAtomR2G{}, tRG_rO(cute::_, m, n), tRG_gO(cute::_, m, n));
                }
            }
        }
    }

    template <class TensorD, class TensorC, class TensorScale, class Index>
    CUTE_DEVICE void promote(TensorD& accum, TensorC const& temp_accum, TensorScale const& scale, Index n_block)
    {

        using AccumType = typename TensorD::value_type;
        for (int mma_m = 0; mma_m < cute::get<1>(cute::shape<0>(accum)); ++mma_m)
        {
            CUTE_UNROLL
            for (int mma_n = 0; mma_n < cute::get<0>(cute::shape<0>(accum)); ++mma_n)
            {
                CUTE_UNROLL
                for (int mma_iter_m = 0; mma_iter_m < cute::size<1>(accum); ++mma_iter_m)
                {
                    CUTE_UNROLL
                    for (int mma_iter_n = 0; mma_iter_n < cute::size<2>(accum); ++mma_iter_n)
                    {
                        auto coord_d
                            = cute::make_coord(cute::make_coord(mma_n, mma_m), mma_iter_m, mma_iter_n, n_block);
                        auto coord_c = cute::make_coord(cute::make_coord(mma_n, mma_m), mma_iter_m, mma_iter_n);
                        accum(coord_d) += temp_accum(coord_c) * scale(mma_m, mma_iter_m, cute::_0{});
                    }
                }
            }
        }
    }

    /// Executes one GEMM
    CUTE_DEVICE
    void operator()(typename KT::ElementInput const* ptr_a, typename KT::ElementInput const* ptr_b,
        typename KT::ElementBlockScale const* ptr_scale_a, typename KT::ElementBlockScale const* ptr_scale_b,
        typename KT::ElementOutput* ptr_output, uint32_t M, uint32_t N, uint32_t K)
    {
        // Dynamic shared memory base pointer
        extern __shared__ int SharedStorageBase[];

        auto [gA, gB, gSFA, gSFB, sA, sB, sSFA, sSFB]
            = gmem_tensor_init(ptr_a, ptr_b, ptr_scale_a, ptr_scale_b, M, N, K, SharedStorageBase);
        typename KT::GmemTiledCopyA g2s_copy_A;
        typename KT::GmemTiledCopyB g2s_copy_B;
        auto g2s_thr_copy_A = g2s_copy_A.get_slice(threadIdx.x);
        auto g2s_thr_copy_B = g2s_copy_B.get_slice(threadIdx.x);

        auto tAgA = g2s_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k)
        auto tAsA = g2s_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,Stage)
        auto tBgB = g2s_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k)
        auto tBsB = g2s_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,Stage)

        typename KT::GmemTiledCopySFA g2s_copy_SFA;
        typename KT::GmemTiledCopySFB g2s_copy_SFB;
        auto g2s_thr_copy_SFA = g2s_copy_SFA.get_slice(threadIdx.x);
        auto g2s_thr_copy_SFB = g2s_copy_SFB.get_slice(threadIdx.x);

        auto tAgSFA = g2s_thr_copy_SFA.partition_S(gSFA); // (ACPY,ACPY_M,ACPY_K,Stage)
        auto tAsSFA = g2s_thr_copy_SFA.partition_D(sSFA); // (ACPY,ACPY_M,ACPY_K,Stage)
        auto tBgSFB = g2s_thr_copy_SFB.partition_S(gSFB); // (BCPY,BCPY_N,BCPY_K,Stage)
        auto tBsSFB = g2s_thr_copy_SFB.partition_D(sSFB); // (BCPY,BCPY_N,BCPY_K,Stage)

        auto cA = make_identity_tensor(cute::make_shape(cute::size<0>(sA), cute::size<1>(sA)));
        auto tAcA = g2s_thr_copy_A.partition_S(cA);

        auto cB = make_identity_tensor(cute::make_shape(cute::size<0>(sB), cute::size<1>(sB)));
        auto tBcB = g2s_thr_copy_B.partition_S(cB);

        auto cSFA = cute::make_identity_tensor(typename KT::GmemTiledCopySFA::Tiler_MN{});
        auto tAcSFA = g2s_thr_copy_SFA.partition_S(cSFA);

        int residue_m = M - KT::kTileM * blockIdx.x;
        int residue_n = N - KT::kTileN * blockIdx.y;
        residue_m = residue_m > KT::kTileM ? KT::kTileM : residue_m;
        residue_n = residue_n > KT::kTileN ? KT::kTileN : residue_n;

        auto tApA = cute::make_tensor<bool>(
            cute::make_shape(cute::size<1>(tAsA), cute::size<2>(tAsA)), cute::Stride<cute::_1, cute::_0>{});
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < cute::size<0>(tApA); ++m)
        {
            tApA(m, 0) = cute::get<0>(tAcA(0, m, 0)) < residue_m; // blk_m coord < residue_m
        }

        auto tBpB = cute::make_tensor<bool>(
            cute::make_shape(cute::size<1>(tBsB), cute::size<2>(tBsB)), cute::Stride<cute::_1, cute::_0>{});
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < cute::size<0>(tBpB); ++n)
        {
            tBpB(n, 0) = cute::get<0>(tBcB(0, n, 0)) < residue_n; // blk_n coord < residue_n
        }

        auto tApSFA = cute::make_tensor<bool>(
            cute::make_shape(cute::size<1>(tAsSFA), cute::size<2>(tAsSFA)), cute::Stride<cute::_1, cute::_0>{});
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < cute::size<0>(tApSFA); ++m)
        {
            tApSFA(m, 0) = cute::get<0>(tAcSFA(0, m, 0)) < residue_m; // blk_m coord < residue_m
        }

        // prefetch gmem A/B
        cute::clear(tAsA);
        cute::clear(tBsB);
        cute::clear(tAsSFA);

        int k_tile_count = cute::size<2>(gA);
        CUTLASS_PRAGMA_NO_UNROLL
        for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
        {
            if (k_pipe >= k_tile_count)
            {
                cute::clear(tApA);
                cute::clear(tBpB);
                cute::clear(tApSFA);
            }
            auto k_tile_iter = std::min(k_pipe, k_tile_count - 1);
            cute::copy_if(g2s_copy_A, tApA, tAgA(cute::_, cute::_, cute::_, k_tile_iter),
                tAsA(cute::_, cute::_, cute::_, k_pipe));
            cute::copy_if(g2s_copy_B, tBpB, tBgB(cute::_, cute::_, cute::_, k_tile_iter),
                tBsB(cute::_, cute::_, cute::_, k_pipe));
            cute::copy_if(g2s_copy_SFA, tApSFA, tAgSFA(cute::_, cute::_, cute::_, k_tile_iter),
                tAsSFA(cute::_, cute::_, cute::_, k_pipe));
            cute::copy(g2s_copy_SFB, tBgSFB(cute::_, cute::_, cute::_, k_tile_iter),
                tBsSFB(cute::_, cute::_, cute::_, k_pipe));

            cute::cp_async_fence();
        }

        typename KT::TiledMma mma;
        auto thr_mma = mma.get_slice(threadIdx.x);
        auto accum = cute::partition_fragment_C(mma,
            cute::make_shape(cute::Int<KT::kTileM>{}, cute::Int<KT::kMmaPermN>{},
                cute::Int<KT::NUM_GROUP_N>{})); // (MMA,MMA_M,MMA_N)
        auto temp = cute::partition_fragment_C(
            mma, cute::make_shape(cute::Int<KT::kTileM>{}, cute::Int<KT::kMmaPermN>{})); // (MMA,MMA_M,MMA_N)

        auto mma_shape_A
            = cute::partition_shape_A(mma, cute::make_shape(cute::Int<KT::kTileM>{}, cute::Int<KT::kTileK>{}));
        auto tCrA = cute::make_tensor<typename KT::ElementInput>(mma_shape_A);

        auto mma_shape_B = cute::partition_shape_B(
            mma, cute::make_shape(cute::Int<KT::kMmaPermN>{}, cute::Int<KT::kTileK>{}, cute::Int<KT::NUM_GROUP_N>{}));
        auto tCrB = cute::make_tensor<typename KT::ElementInput>(mma_shape_B);

        auto s2r_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, mma);
        auto s2r_thr_copy_A = s2r_copy_A.get_slice(threadIdx.x);
        auto tXsA = s2r_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,Stage)
        auto tXrA = s2r_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K)
        static_assert(is_static<decltype(tXrA.layout())>::value, "tXrA layout must be static");

        auto s2r_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, mma);
        auto s2r_thr_copy_B = s2r_copy_B.get_slice(threadIdx.x);
        auto tXsB = s2r_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,Stage)
        auto tXrB = s2r_thr_copy_B.retile_D(tCrB)(cute::_, cute::Int<0>{}, cute::_, cute::_);

        typename KT::SmemTiledCopySFA s2r_copy_SFA;
        typename KT::SmemTiledCopySFB s2r_copy_SFB;
        auto s2r_thr_copy_SFA = s2r_copy_SFA.get_slice(threadIdx.x);
        auto s2r_thr_copy_SFB = s2r_copy_SFB.get_slice(threadIdx.x);

        auto tXsSFA = s2r_thr_copy_SFA.partition_S(sSFA);
        auto tXrSFA = cute::make_fragment_like(tXsSFA(cute::_, cute::_, cute::_, 0));
        auto tXsSFB = s2r_thr_copy_SFB.partition_S(sSFB);
        auto tXrSFB = cute::make_fragment_like(tXsSFB(cute::_, cute::_, cute::_, 0));
        auto scale = cute::make_fragment_like(tXrSFA);

        int smem_pipe_write = KT::Stages - 1;
        int smem_pipe_read = 0;

        auto tXsA_read = tXsA(cute::_, cute::_, cute::_, smem_pipe_read);
        auto tXsB_read = tXsB(cute::_, cute::_, cute::_, smem_pipe_read);
        auto tXsSFA_read = tXsSFA(cute::_, cute::_, cute::_, smem_pipe_read);
        auto tXsSFB_read = tXsSFB(cute::_, cute::_, cute::_, smem_pipe_read);
        cute::cp_async_wait<KT::Stages - 2>();
        __syncthreads();
        // prefetch smem -> rf
        cute::copy(s2r_copy_SFA, tXsSFA_read, tXrSFA);
        cute::copy(s2r_copy_SFB, tXsSFB_read, tXrSFB);
        cute::copy(s2r_copy_A, tXsA_read, tXrA);
        cute::copy(s2r_copy_B, tXsB_read(cute::_, cute::Int<0>{}, cute::_), tXrB(cute::_, cute::_, cute::Int<0>{}));

        cute::clear(accum);
        int k_tile_iter = KT::Stages - 1;
        while (k_tile_iter < k_tile_count)
        {
            cute::for_each(cute::make_int_sequence<KT::NUM_GROUP_N>{},
                [&](auto n_block)
                {
                    if constexpr (n_block == KT::NUM_GROUP_N - 1)
                    {
                        tXsA_read = tXsA(cute::_, cute::_, cute::_, smem_pipe_read);
                        tXsB_read = tXsB(cute::_, cute::_, cute::_, smem_pipe_read);
                        tXsSFA_read = tXsSFA(cute::_, cute::_, cute::_, smem_pipe_read);
                        tXsSFB_read = tXsSFB(cute::_, cute::_, cute::_, smem_pipe_read);
                        cute::cp_async_wait<KT::Stages - 2>();
                        __syncthreads();
                        cute::copy(s2r_copy_SFA, tXsSFA_read, tXrSFA);
                        cute::copy(s2r_copy_SFB, tXsSFB_read, tXrSFB);
                    }
                    auto n_block_next = (n_block + cute::_1{}) % KT::NUM_GROUP_N;
                    cute::copy(
                        s2r_copy_B, tXsB_read(cute::_, n_block_next, cute::_), tXrB(cute::_, cute::_, n_block_next));
                    if constexpr (n_block == 0)
                    {
                        // gmem -> smem
                        cute::copy_if(g2s_copy_A, tApA, tAgA(cute::_, cute::_, cute::_, k_tile_iter),
                            tAsA(cute::_, cute::_, cute::_, smem_pipe_write));
                        cute::copy_if(g2s_copy_B, tBpB, tBgB(cute::_, cute::_, cute::_, k_tile_iter),
                            tBsB(cute::_, cute::_, cute::_, smem_pipe_write));
                        cute::copy_if(g2s_copy_SFA, tApSFA, tAgSFA(cute::_, cute::_, cute::_, k_tile_iter),
                            tAsSFA(cute::_, cute::_, cute::_, smem_pipe_write));
                        cute::copy(g2s_copy_SFB, tBgSFB(cute::_, cute::_, cute::_, k_tile_iter),
                            tBsSFB(cute::_, cute::_, cute::_, smem_pipe_write));
                        cute::cp_async_fence();
                        k_tile_iter++;
                        smem_pipe_write = smem_pipe_read;
                        ++smem_pipe_read;
                        smem_pipe_read = smem_pipe_read == KT::Stages ? 0 : smem_pipe_read;
                        cute::for_each(cute::make_int_sequence<cute::size(scale)>{},
                            [&](auto i) { scale(i) = tXrSFA(i) * tXrSFB(0); });
                    }
                    cute::clear(temp);
                    cute::gemm(mma, tCrA, tCrB(cute::_, cute::_, cute::_, n_block), temp);
                    if constexpr (n_block == KT::NUM_GROUP_N - 1)
                    {
                        cute::copy(s2r_copy_A, tXsA_read, tXrA);
                    }
                    promote(accum, temp, scale, n_block);
                });
        }
        // load tail
        cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
            [&](auto WaitIndex)
            {
                using WaitIndex_t = decltype(WaitIndex);
                cute::for_each(cute::make_int_sequence<KT::NUM_GROUP_N>{},
                    [&](auto n_block)
                    {
                        if constexpr (n_block == KT::NUM_GROUP_N - 1)
                        {
                            tXsA_read = tXsA(cute::_, cute::_, cute::_, smem_pipe_read);
                            tXsB_read = tXsB(cute::_, cute::_, cute::_, smem_pipe_read);
                            tXsSFA_read = tXsSFA(cute::_, cute::_, cute::_, smem_pipe_read);
                            tXsSFB_read = tXsSFB(cute::_, cute::_, cute::_, smem_pipe_read);
                            cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
                            __syncthreads();
                            cute::copy(s2r_copy_SFA, tXsSFA_read, tXrSFA);
                            cute::copy(s2r_copy_SFB, tXsSFB_read, tXrSFB);
                        }
                        auto n_block_next = (n_block + cute::_1{}) % KT::NUM_GROUP_N;
                        cute::copy(s2r_copy_B, tXsB_read(cute::_, n_block_next, cute::_),
                            tXrB(cute::_, cute::_, n_block_next));
                        if constexpr (n_block == 0)
                        {
                            ++smem_pipe_read;
                            smem_pipe_read = smem_pipe_read == KT::Stages ? 0 : smem_pipe_read;
                            cute::for_each(cute::make_int_sequence<cute::size(scale)>{},
                                [&](auto i) { scale(i) = tXrSFA(i) * tXrSFB(0); });
                        }
                        cute::clear(temp);
                        cute::gemm(mma, tCrA, tCrB(cute::_, cute::_, cute::_, n_block), temp);
                        if constexpr (n_block == KT::NUM_GROUP_N - 1)
                        {
                            cute::copy(s2r_copy_A, tXsA_read, tXrA);
                        }
                        promote(accum, temp, scale, n_block);
                    });
            });
        // mma tail
        cute::for_each(cute::make_int_sequence<KT::NUM_GROUP_N>{},
            [&](auto n_block)
            {
                auto n_block_next = (n_block + cute::_1{}) % KT::NUM_GROUP_N;
                cute::copy(s2r_copy_B, tXsB_read(cute::_, n_block_next, cute::_), tXrB(cute::_, cute::_, n_block_next));
                cute::clear(temp);
                if constexpr (n_block == 0)
                {
                    cute::for_each(cute::make_int_sequence<cute::size(scale)>{},
                        [&](auto i) { scale(i) = tXrSFA(i) * tXrSFB(0); });
                }
                cute::gemm(mma, tCrA, tCrB(cute::_, cute::_, cute::_, n_block), temp);
                promote(accum, temp, scale, n_block);
            });
        // epilogue
        __syncthreads(); // sync before using store smem
        typename KT::SharedStorageStore* store_storage
            = reinterpret_cast<typename KT::SharedStorageStore*>(SharedStorageBase);
        epilogue_with_smem(accum, *store_storage, ptr_output, M, N);
    }
};

} // namespace ada_blockwise_gemm
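
For orientation, this is roughly how the kernel above is driven from the host by the dispatch code further down in this commit (gemm_dispatch_sm89 and strided_batch_gemm_dispatch_sm89). The sketch below is illustrative and not part of the file: launch_sm89_fp8_bmm_sketch is a hypothetical name, and the per-problem scale strides are taken as parameters because they depend on how the quantization ops lay out the scale factors.

// Illustrative host-side launch of the batched kernel for G problems packed as A:(G,M,K) fp8,
// B:(G,N,K) fp8, D:(G,M,N) bf16, mirroring strided_batch_gemm_dispatch_sm89 below.
template <typename KT>
void launch_sm89_fp8_bmm_sketch(__nv_fp8_e4m3* A, __nv_fp8_e4m3* B, __nv_bfloat16* D, float* scales_a, float* scales_b,
    uint64_t stride_scales_a, uint64_t stride_scales_b, uint32_t num_problems, uint32_t shape_m, uint32_t shape_n,
    uint32_t shape_k, cudaStream_t stream)
{
    using GemmKernel = ada_blockwise_gemm::AdaBlockwiseGemmKernel<KT>;
    // One CTA per (M tile, N tile); blockIdx.z walks the batch.
    dim3 grid((shape_m + KT::kTileM - 1) / KT::kTileM, (shape_n + KT::kTileN - 1) / KT::kTileN, num_problems);
    dim3 block(KT::kThreadCount, 1, 1);
    // Element strides between consecutive problems in the packed buffers.
    uint64_t stride_a = uint64_t(shape_m) * shape_k;
    uint64_t stride_b = uint64_t(shape_n) * shape_k;
    uint64_t stride_d = uint64_t(shape_m) * shape_n;
    if (KT::kSmemSize > (48 << 10))
    {
        // Opt in to more than 48 KiB of dynamic shared memory, as the dispatch code does.
        cudaFuncSetAttribute(ada_blockwise_gemm::sm89_fp8_bmm_1d1d_impl<GemmKernel>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, KT::kSmemSize);
    }
    ada_blockwise_gemm::sm89_fp8_bmm_1d1d_impl<GemmKernel><<<grid, block, KT::kSmemSize, stream>>>(shape_m, shape_n,
        shape_k, A, B, D, scales_a, scales_b, stride_a, stride_b, stride_d, stride_scales_a, stride_scales_b);
}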

@ -0,0 +1,248 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include "cute/atom/mma_atom.hpp"
#include <cuda_runtime.h>
#include <cute/arch/mma.hpp>
#include <cute/config.hpp>
#include <cute/int_tuple.hpp>
#include <cute/layout.hpp>
#include <cutlass/arch/mma.h>
#include <cutlass/cutlass.h>

#define CUTLASS_HOST_TRACE(x) \
    { \
        std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; \
    }

// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890))
#define CUTE_ARCH_MMA_F32_SM89_ENABLED
#endif

namespace cute
{

// MMA 16x8x32 TN
struct SM89_16x8x32_F32F8F8F32_TN
{
    using DRegisters = float[4];
    using ARegisters = uint32_t[4];
    using BRegisters = uint32_t[2];
    using CRegisters = float[4];

    CUTE_HOST_DEVICE static void fma(float& d0, float& d1, float& d2, float& d3, uint32_t const& a0, uint32_t const& a1,
        uint32_t const& a2, uint32_t const& a3, uint32_t const& b0, uint32_t const& b1, float const& c0,
        float const& c1, float const& c2, float const& c3)
    {
#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED)
        asm volatile(
            "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
            "{%0, %1, %2, %3},"
            "{%4, %5, %6, %7},"
            "{%8, %9},"
            "{%10, %11, %12, %13};\n"
            : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
            : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "f"(c0), "f"(c1), "f"(c2), "f"(c3));
#else
        CUTE_INVALID_CONTROL_PATH(
            "Attempting to use SM89_16x8x32_F32F8F8F32_TN without "
            "CUTE_ARCH_MMA_F32_SM89_ENABLED");
#endif
    }
};

template <>
struct MMA_Traits<SM89_16x8x32_F32F8F8F32_TN>
{
    using ValTypeD = float;
    using ValTypeA = float_e4m3_t;
    using ValTypeB = float_e4m3_t;
    using ValTypeC = float;

    using Shape_MNK = Shape<_16, _8, _32>;
    using ThrID = Layout<_32>;
    using ALayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2, _2>>, Stride<Stride<_64, _1>, Stride<_16, _8, _256>>>;
    using BLayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2>>, Stride<Stride<_32, _1>, Stride<_8, _128>>>;
    using CLayout = SM80_16x8_Row;
};

} // namespace cute

using namespace cute;
using namespace cutlass;
using namespace cutlass::gemm;

namespace ada_blockwise_gemm
{

template <typename Element, typename Arch>
struct DefaultGemm_TensorOp_MMA;

template <>
struct DefaultGemm_TensorOp_MMA<cute::bfloat16_t, cutlass::arch::Sm80>
{
    using ArchTag = cutlass::arch::Sm80;
    using MMA_Atom_Arch = cute::MMA_Atom<cute::SM80_16x8x16_F32BF16BF16F32_TN>;
    using ThreadLayoutMNK = cute::Layout<cute::Shape<cute::_2, cute::_2, cute::_1>>;
    using ValLayoutMNK = cute::Tile<cute::_32, cute::_32, cute::_16>;
    using TiledMma = cute::TiledMMA<MMA_Atom_Arch, ThreadLayoutMNK, ValLayoutMNK>;
};

template <>
struct DefaultGemm_TensorOp_MMA<cute::float_e4m3_t, cutlass::arch::Sm89>
{
    using ArchTag = cutlass::arch::Sm89;
    using MMA_Atom_Arch = cute::MMA_Atom<cute::SM89_16x8x32_F32F8F8F32_TN>;
    using ThreadLayoutMNK = cute::Layout<cute::Shape<cute::_2, cute::_2, cute::_1>>;
    using ValLayoutMNK = cute::Tile<cute::_32, cute::_32, cute::_32>;
    using TiledMma = cute::TiledMMA<MMA_Atom_Arch, ThreadLayoutMNK, ValLayoutMNK>;
};

template <typename ElementType, typename OutElementType, typename AccumElementType, typename BlockScaleElementType,
    int Stages_, int TileM_, int TileN_, int TileK_>
struct AdaBlockwiseGemmTraits
{
    using ElementInput = ElementType;
    using ElementOutput = OutElementType;
    using ElementAccumulator = float;
    using ElementBlockScale = float;

    using index_t = uint32_t;
    static_assert(TileM_ % 16 == 0);
    static_assert(TileN_ % 32 == 0);
    static_assert(TileK_ % 32 == 0);
    static constexpr int Stages = Stages_;
    static constexpr int kTileM = TileM_;
    static constexpr int kTileN = TileN_;
    static constexpr int kTileK = TileK_;

    using TileShape = Shape<Int<kTileM>, Int<kTileN>, Int<kTileK>>;
    static constexpr int kWarpsCount = 4;
    static constexpr int kThreadCount = kWarpsCount * 32;

    static constexpr int ScaleGranularityM = 1;
    static constexpr int ScaleGranularityN = 128;
    static constexpr int ScaleGranularityK = 128;

    static constexpr int ScaleMsPerTile = (kTileM + ScaleGranularityM - 1) / ScaleGranularityM;
    static constexpr int ScaleNsPerTile = (kTileN + ScaleGranularityN - 1) / ScaleGranularityN;
    static constexpr int ScaleKsPerTile = (kTileK + ScaleGranularityK - 1) / ScaleGranularityK;

    using ScaleGranularity = Shape<Int<ScaleGranularityM>, Int<ScaleGranularityN>, Int<ScaleGranularityK>>;
    using ScalePerTileShape = Shape<Int<ScaleMsPerTile>, Int<ScaleNsPerTile>, Int<ScaleKsPerTile>>;

    // ====== mma ======
    static constexpr int kMmaPermM = 32;
    static constexpr int kMmaPermN = 32;
    static constexpr int kMmaPermK = 32;
    constexpr static int NUM_GROUP_M = kTileM / kMmaPermM;
    constexpr static int NUM_GROUP_N = kTileN / kMmaPermN;
    constexpr static int NUM_GROUP_K = kTileK / kMmaPermK;
    using MMA_Atom = MMA_Atom<SM89_16x8x32_F32F8F8F32_TN>;
    using AtomLayoutMNK = Layout<Shape<_2, _2, _1>>;
    using PermutationMNK = Tile<Int<kMmaPermM>, Int<kMmaPermN>, Int<kMmaPermK>>;
    using TiledMma = TiledMMA<MMA_Atom, AtomLayoutMNK, PermutationMNK>;

    // ====== load gmem -> smem ======
    using GmemTiledCopyLoad = decltype(make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, ElementInput>{},
        Layout<Shape<_16, _8>, Stride<_8, _1>>{}, Layout<Shape<_1, _16>>{}));

    using GmemTiledCopyA = GmemTiledCopyLoad;
    using GmemTiledCopyB = GmemTiledCopyLoad;

    // ====== load smem -> rf ======
    using SmemAtomLayoutLoad = decltype(composition(Swizzle<3, 4, 3>{}, Layout<Shape<_16, _128>, Stride<_128, _1>>{}));
    using SmemLayoutA = decltype(tile_to_shape(SmemAtomLayoutLoad{}, Shape<Int<kTileM>, Int<kTileK>, Int<Stages>>{}));
    using SmemLayoutB = decltype(tile_to_shape(SmemAtomLayoutLoad{}, Shape<Int<kTileN>, Int<kTileK>, Int<Stages>>{}));

    using SmemCopyAtomLoad = Copy_Atom<SM75_U32x4_LDSM_N, ElementInput>;
    using SmemCopyAtomA = SmemCopyAtomLoad;
    using SmemCopyAtomB = SmemCopyAtomLoad;

    // ====== store rf -> smem ======
    using SmemAtomLayoutStore = decltype(composition(
        Swizzle<3, 3, 3>{}, Layout<Shape<_8, Shape<_8, _8>>, Stride<_8, Stride<_1, _64>>>{})); // 8x64

    using SmemLayoutO = decltype(tile_to_shape(SmemAtomLayoutStore{}, Shape<Int<kTileM>, Int<kTileN>>{}));

    using SmemCopyAtomR2S = Copy_Atom<AutoVectorizingCopy, ElementOutput>;

    // ====== store smem -> gmem ======
    using SmemCopyAtomS2R = Copy_Atom<UniversalCopy<uint128_t>, ElementOutput>;
    using GmemCopyAtomR2G = SmemCopyAtomS2R;

    using TiledCopyS2G = decltype(make_tiled_copy(
        SmemCopyAtomS2R{}, Layout<Shape<_16, _8>, Stride<_8, _1>>{}, Layout<Shape<_1, _8>>{})); // 16x64

    // ====== load scale gmem -> smem ======
    using GmemCopyAtomScale = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
    using GmemLayoutTVSFA = Layout<Shape<Shape<Int<ScaleMsPerTile>, Int<kThreadCount / ScaleMsPerTile>>, Shape<_1, _1>>,
        Stride<Stride<_1, _0>, Stride<_1, _1>>>;
    using GmemTileShapeSFA = Shape<Int<ScaleMsPerTile>, Int<ScaleKsPerTile>>;
    using GmemTiledCopySFA = decltype(make_tiled_copy_impl(GmemCopyAtomScale{}, GmemLayoutTVSFA{}, GmemTileShapeSFA{}));

    using GmemLayoutTVSFB = Layout<Shape<Shape<_32, _4>, Shape<_1, _1>>, Stride<Stride<_0, _0>, Stride<_1, _1>>>;
    using GmemTileShapeSFB = Shape<Int<ScaleNsPerTile>, Int<ScaleKsPerTile>>;
    using GmemTiledCopySFB = decltype(make_tiled_copy_impl(GmemCopyAtomScale{}, GmemLayoutTVSFB{}, GmemTileShapeSFB{}));

    // ====== load scale smem -> rf ======
    using SmemCopyAtomScale = Copy_Atom<UniversalCopy<ElementBlockScale>, ElementBlockScale>;
    using SmemLayoutTVSFA
        = Layout<Shape<Shape<_4, _8, _2, _2>, Shape<_2>>, Stride<Stride<_0, _1, _16, _0>, Stride<_8, _0>>>;
    using SmemTileShapeSFA = Shape<Int<kMmaPermM>, _1>;
    using SmemTiledCopySFA = decltype(make_tiled_copy_impl(SmemCopyAtomScale{}, SmemLayoutTVSFA{}, SmemTileShapeSFA{}));

    using SmemLayoutSFA = decltype(tile_to_shape(make_layout(SmemTileShapeSFA{}),
        make_shape(
            shape<0>(ScalePerTileShape{}), shape<2>(ScalePerTileShape{}), Int<Stages>{}))); // BLK_M, BLK_K, Stages

    using SmemLayoutTVSFB
        = Layout<Shape<Shape<_4, _8, _2, _2>, Shape<_1>>, Stride<Stride<_0, _0, _0, _0>, Stride<_0, _0>>>;
    using SmemTileShapeSFB = Shape<_1, _1>;
    using SmemTiledCopySFB = decltype(make_tiled_copy_impl(SmemCopyAtomScale{}, SmemLayoutTVSFB{}, SmemTileShapeSFB{}));

    using SmemLayoutSFB = decltype(tile_to_shape(make_layout(SmemTileShapeSFB{}),
        make_shape(
            shape<1>(ScalePerTileShape{}), shape<2>(ScalePerTileShape{}), Int<Stages>{}))); // BLK_N, BLK_K, Stages

    // we need at least 2 stages..
    static_assert(Stages >= 2);

    struct SharedStorageLoad : aligned_struct<128>
    {
        array_aligned<ElementInput, cosize_v<SmemLayoutA>> smem_a;
        array_aligned<ElementInput, cosize_v<SmemLayoutB>> smem_b;
        array_aligned<float, cosize_v<SmemLayoutSFA>> smem_sfa;
        array_aligned<float, cosize_v<SmemLayoutSFB>> smem_sfb;
    };

    struct SharedStorageStore : aligned_struct<128>
    {
        array_aligned<ElementOutput, cosize_v<SmemLayoutO>> smem_o;
    };

    union SharedStorage
    {
        SharedStorageLoad load;
        SharedStorageStore store;
    };

    static constexpr int kSmemSize = static_cast<int>(sizeof(SharedStorage));
};

} // namespace ada_blockwise_gemm
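
To make the constants concrete: the SM89 dispatch below instantiates these traits with Stages = 3 and a 32x128x128 tile. The standalone sketch below works those numbers out from the formulas above; it is my own arithmetic, not part of the commit, and it ignores alignment padding inside SharedStorage, so the real kSmemSize is slightly larger.

// Worked numbers for AdaBlockwiseGemmTraits<float_e4m3_t, bfloat16_t, float, float, 3, 32, 128, 128>.
constexpr int kTileM = 32, kTileN = 128, kTileK = 128, kStages = 3;
constexpr int kThreads = 4 * 32;        // kWarpsCount * 32
constexpr int kNumGroupN = kTileN / 32; // 4 N groups of 32 columns per CTA tile
constexpr int kScaleRowsA = kTileM;     // 1x128 granularity: one A scale per row per K block
constexpr int kScaleBlocksB = 1;        // 128x128 granularity: one B scale per CTA tile per K block
// Raw shared-memory byte counts for the multistage load buffers (fp8 = 1 byte, scales = float).
constexpr long kSmemABytes = 1L * kTileM * kTileK * kStages; // 12 KiB
constexpr long kSmemBBytes = 1L * kTileN * kTileK * kStages; // 48 KiB
constexpr long kSmemScaleBytes = 4L * (kScaleRowsA + kScaleBlocksB) * kStages;
static_assert(kThreads == 128 && kNumGroupN == 4);
// More than 48 KiB of dynamic shared memory is needed, hence the
// cudaFuncAttributeMaxDynamicSharedMemorySize opt-in in the dispatch code below.
static_assert(kSmemABytes + kSmemBBytes + kSmemScaleBytes > (48L << 10));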

@ -27,7 +27,7 @@
#include <string>
#include <vector>

#include "ada_blockwise_gemm/ada_blockwise_gemm.cuh"
#include "ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh"
#include "fp8_blockscale_mma_utils.cuh"
#include "fp8_blockscale_tma_utils.cuh"
#include "tensorrt_llm/common/cudaUtils.h"

@ -1172,7 +1172,7 @@ template <typename InputType, typename OutputType, typename ScaleType = float>
__global__ void scale_1x128_reshape_kernel(
    OutputType* output, ScaleType* scales, InputType const* const input, int dim_x, int dim_h, int dim_y, int stride_x)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890))
    size_t scales_along_dim_x = div_up(dim_x, 128);
    size_t scales_along_dim_y = div_up(dim_y, 1);
    size_t scales_along_dim_h = div_up(dim_h, 1);

@ -1582,30 +1582,30 @@ void gemm_dispatch_sm89(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
    using ElementOutput = cute::bfloat16_t;
    using ElementAccum = float;
    using ElementBlockScale = float;
    static constexpr int Stages = 4;
    using TileShape = cutlass::gemm::GemmShape<32, 128, 128>; // only support 32x128x128 for now
    static constexpr int Stages = 3;
    using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
    using KT = ada_blockwise_gemm::AdaBlockwiseGemmTraits<ElementInput, ElementOutput, ElementAccum, ElementBlockScale,
        Stages, TileShape::kM, TileShape::kN, TileShape::kK>;
    using Gemm = ada_blockwise_gemm::AdaBlockwiseGemm<KT>;
    using GemmKernel = ada_blockwise_gemm::AdaBlockwiseGemmKernel<KT>;

    int gemm_m = shape_m;
    int gemm_n = shape_n;
    int gemm_k = shape_k;
    typename KT::Arguments args({gemm_m, gemm_n, gemm_k}, mat_a, mat_b, mat_d, scales_a, scales_b);
    static constexpr int kSmemSize = KT::kSmemSize;
    static constexpr int kThreadCount = KT::kThreadCount;
    int grid_m = (shape_m + KT::kTileM - 1) / KT::kTileM;
    int grid_n = (shape_n + KT::kTileN - 1) / KT::kTileN;
    int grid_k = 1;
    dim3 grid = dim3(grid_m, grid_n, grid_k);
    dim3 block = dim3(kThreadCount, 1, 1);

    Gemm gemm{};
    if (kSmemSize > (48 << 10))
    {
        cudaFuncSetAttribute(ada_blockwise_gemm::sm89_fp8_gemm_1d1d_impl<GemmKernel>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
        auto result = cudaGetLastError();
        TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm89 gemm kernel cannot launch: %s", cudaGetErrorString(result));
    }

    auto status = gemm.can_implement(args);
    TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "This kernel is not supported. Last CUDA error is: %s",
        cutlassGetStatusString(status));

    status = gemm.initialize(args);
    TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess,
        "Failed to initialize the CUTLASS kernel. Last CUDA error is: %s", cutlassGetStatusString(status));

    status = gemm.run(stream);
    TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess,
        "Failed to run the CUTLASS kernel. Last CUDA error is: %s", cutlassGetStatusString(status));
    ada_blockwise_gemm::sm89_fp8_gemm_1d1d_impl<GemmKernel>
        <<<grid, block, kSmemSize, stream>>>(shape_m, shape_n, shape_k, mat_a, mat_b, mat_d, scales_a, scales_b);
}

void fp8_gemm_run(__nv_fp8_e4m3* mat_a, int ld_a, __nv_fp8_e4m3* mat_b, int ld_b, __nv_bfloat16* mat_d, int ld_d,

@ -1788,6 +1788,48 @@ void strided_batch_gemm_dispatch(__nv_fp8_e4m3* mat_a, int ld_a, int stride_a, _
        static_cast<uint32_t>(best_smem_size));
}

void strided_batch_gemm_dispatch_sm89(__nv_fp8_e4m3* mat_a, int ld_a, int stride_a, __nv_fp8_e4m3* mat_b, int ld_b,
    int stride_b, __nv_bfloat16* mat_d, int ld_d, int stride_d, float* scales_a, int stride_scales_a, float* scales_b,
    uint32_t num_problems, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, cudaStream_t stream,
    int num_device_sms = kNumDeviceSMs)
{

    if (num_device_sms < 0)
    {
        num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
    }
    using ElementInput = cute::float_e4m3_t;
    using ElementOutput = cute::bfloat16_t;
    using ElementAccum = float;
    using ElementBlockScale = float;
    static constexpr int Stages = 3;
    using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
    using KT = ada_blockwise_gemm::AdaBlockwiseGemmTraits<ElementInput, ElementOutput, ElementAccum, ElementBlockScale,
        Stages, TileShape::kM, TileShape::kN, TileShape::kK>;
    using GemmKernel = ada_blockwise_gemm::AdaBlockwiseGemmKernel<KT>;

    static constexpr int kSmemSize = KT::kSmemSize;
    static constexpr int kThreadCount = KT::kThreadCount;
    int grid_m = (shape_m + KT::kTileM - 1) / KT::kTileM;
    int grid_n = (shape_n + KT::kTileN - 1) / KT::kTileN;
    int grid_k = num_problems;
    dim3 grid = dim3(grid_m, grid_n, grid_k);
    dim3 block = dim3(kThreadCount, 1, 1);

    int stride_scales_b = ((shape_n + 128 - 1) / 128) * ((shape_k + 128 - 1) / 128);

    if (kSmemSize > (48 << 10))
    {
        cudaFuncSetAttribute(ada_blockwise_gemm::sm89_fp8_bmm_1d1d_impl<GemmKernel>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
        auto result = cudaGetLastError();
        TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm89 gemm kernel cannot launch: %s", cudaGetErrorString(result));
    }
    ada_blockwise_gemm::sm89_fp8_bmm_1d1d_impl<GemmKernel><<<grid, block, kSmemSize, stream>>>(shape_m, shape_n,
        shape_k, mat_a, mat_b, mat_d, scales_a, scales_b, stride_a, stride_b, stride_d, stride_scales_a,
        stride_scales_b);
}

void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, float* scales_a, int ld_a,
    int stride_a, int stride_scales_a, __nv_bfloat16 const* mat_b, __nv_fp8_e4m3* fp8_mat_b, float* scales_b, int ld_b,
    int stride_b, __nv_bfloat16* mat_d, int ld_d, int stride_d, uint32_t num_problems, uint32_t shape_m,

@ -1814,6 +1856,13 @@ void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_ma
            fp8_mat_b, scales_b, mat_b, shape_k, shape_n * num_problems);
    }

    int arch = tensorrt_llm::common::getSMVersion();
    if (arch == 89)
    {
        strided_batch_gemm_dispatch_sm89(fp8_mat_a, ld_a, stride_a, fp8_mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
            scales_a, stride_scales_a, scales_b, num_problems, shape_m, shape_n, shape_k, stream);
        return;
    }
    if (kDeepGemmEnabled)
    {
        strided_batch_gemm_dispatch(fp8_mat_a, ld_a, stride_a, fp8_mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,

@ -359,7 +359,7 @@ def fp8_block_scaling_bmm_out(
    out: torch.Tensor,
) -> torch.Tensor:
    sm_version = get_sm_version()
    if sm_version == 90:
    if sm_version == 90 or sm_version == 89:
        mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
            mat1)
        torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8,
@ -41,8 +41,6 @@ from utils.util import getSMVersion
    [torch.bfloat16],
)
def test_fp8_block_scale_gemm(dtype, m, k, n):
    if getSMVersion() == 89 and k == 7168 and n == 2112:
        pytest.skip("https://nvbugs/5328184")

    torch.random.manual_seed(0)
    a = torch.randn((m, k), device='cuda', dtype=dtype) / k

@ -60,6 +58,55 @@ def test_fp8_block_scale_gemm(dtype, m, k, n):
    torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)


@pytest.mark.skipif(
    getSMVersion() != 90 and getSMVersion() != 89,
    reason="The test is for Hopper and Ada only. Current SM is %d." %
    getSMVersion(),
)
@pytest.mark.parametrize(
    "k, n",
    [(7168, 2112), (512, 32768), (16384, 7168), (2048, 7168)],
)
@pytest.mark.parametrize(
    "m",
    [7, 64, 128],
)
@pytest.mark.parametrize(
    "num_groups",
    [4, 8, 16],
)
@pytest.mark.parametrize(
    "dtype",
    [torch.bfloat16],
)
def test_fp8_block_scale_bmm(dtype, m, k, n, num_groups):

    torch.random.manual_seed(0)
    a = torch.randn((m, num_groups, k), device='cuda', dtype=dtype) / k

    a_fp8, a_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(a)

    b = torch.randn((num_groups, n, k), device='cuda', dtype=dtype) / k
    b_fp8 = torch.zeros_like(b, device='cuda', dtype=torch.float8_e4m3fn)
    b_scales = torch.zeros((num_groups, (n + 127) // 128, (k + 127) // 128),
                           device='cuda',
                           dtype=torch.float)

    for i in range(num_groups):
        b_fp8[i], b_scales[i] = per_block_cast_to_fp8(b[i])

    output_expected = torch.einsum('mgk,gnk->gmn', a, b)
    output = torch.empty((num_groups, m, n),
                         device='cuda',
                         dtype=torch.bfloat16)

    torch.ops.trtllm.fp8_block_scaling_bmm_out(a_fp8, b_fp8, a_scales, b_scales,
                                               output)
    diff = calc_diff(output, output_expected)
    assert diff < 1e-3
    torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)


def deepSeekFp8ComputeGemmReference(mM, mN, mK, valsC, dqSfsC, valsA, dqSfsA,
                                    valsB, dqSfsB, quantizeOutput, tileSize):
    for mi in range(mM):