TensorRT-LLMs/cpp/kernels/fmha_v2/src/fmha/smem_tile.h

/*
* SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once
#include <fmha/fragment.h>
#include <fmha/traits.h>
#include <fmha/utils.h>
namespace fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The description of the tile computed by this CTA.
typename Cta_tile,
// The number of rows in the 2D shared memory buffer.
int M_,
// The number of cols.
int N_,
// The size in bits of each element.
int BITS_PER_ELEMENT_,
// The number of bytes per STS.
int BYTES_PER_STS_ = 16,
// The number of buffers. (Used in multistage and double buffer cases.)
int BUFFERS_PER_TILE_ = 1,
// Do we enable the fast path for LDS.128 and friends?
int ENABLE_LDS_FAST_PATH_ = 0,
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
int ROWS_PER_XOR_PATTERN_ = 8,
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
int COLS_PER_XOR_PATTERN_ = 1,
// Whether to use predicates.
bool USE_PREDICATES_ = true,
// Whether to use TMA.
bool USE_TMA_ = false,
// The number of leading-dimension elements in shared memory.
int LEAD_DIM_ELEMENTS_ = N_>
struct Smem_tile_without_skews
{
// The type of this tile
using Smem_tile_ = Smem_tile_without_skews<Cta_tile, M_, N_, BITS_PER_ELEMENT_, BYTES_PER_STS_, BUFFERS_PER_TILE_,
ENABLE_LDS_FAST_PATH_, ROWS_PER_XOR_PATTERN_, COLS_PER_XOR_PATTERN_, USE_PREDICATES_>;
static constexpr bool USE_TMA = USE_TMA_;
// The size in bits of each element.
enum
{
BITS_PER_ELEMENT = BITS_PER_ELEMENT_
};
// The size in bytes of a single STS.
enum
{
BYTES_PER_STS = BYTES_PER_STS_
};
// The number of elements per STS.
enum
{
ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT
};
// To support arbitrary N, we pad some values to a power-of-2.
enum
{
N_WITH_PADDING = Next_power_of_two<LEAD_DIM_ELEMENTS_>::VALUE
};
// The number of bytes per row without packing of rows.
enum
{
BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8
};
// The number of bytes per row -- we want at least 128B per row.
enum
{
BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE
};
// The number of rows in shared memory (two rows may be packed into a single one).
enum
{
ROWS = M_ * N_ / LEAD_DIM_ELEMENTS_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW
};
// The number of threads needed to cover one full row.
enum
{
THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS
};
// The number of threads per row, bounded by the CTA size.
enum
{
THREADS_PER_ROW = Min<Cta_tile::THREADS_PER_CTA, THREADS_PER_ROW_UNBOUNDED>::VALUE
};
// The number of STS per row.
enum
{
STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS
};
// It must be at least one.
static_assert(STS_PER_ROW >= 1, "");
// The number of rows written with a single STS.
enum
{
ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW
};
// Make sure we write to at least one row per STS. Thanks Dr. Obvious ;)
static_assert(ROWS_PER_STS >= 1, "");
// The number of STS needed to store all rows.
enum
{
STS_PER_COL = Div_up<ROWS, ROWS_PER_STS>::VALUE
};
// The number of STS in total.
enum
{
STS = STS_PER_COL * STS_PER_ROW
};
// The size of one buffer in bytes in shared memory.
enum
{
BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA
};
// The number of buffers.
enum
{
BUFFERS_PER_TILE = BUFFERS_PER_TILE_
};
// The size in bytes of total buffers.
enum
{
BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE
};
// The boundary for smem_read_offset and smem_write_offset increment.
enum
{
BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER
};
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
enum
{
ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_
};
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
enum
{
COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS
};
// Whether to use predicates.
enum
{
USE_PREDICATES = USE_PREDICATES_
};
// The number of bytes in one shared memory row.
enum
{
BYTES_PER_SHMEM_ROW = 128
};
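// Worked example (illustrative only, not an extra constraint): for a CTA of 128 threads
// with M_ = 64, N_ = 64, 16-bit elements and 16B STS, the constants above evaluate to:
//   BYTES_PER_ROW_BEFORE_PACKING = 64 * 16 / 8        = 128
//   BYTES_PER_ROW                = max(128, 128)      = 128
//   ROWS                         = 64
//   THREADS_PER_ROW              = min(128, 128 / 16) = 8
//   STS_PER_ROW                  = 128 / 8 / 16       = 1
//   ROWS_PER_STS                 = 128 / 8            = 16
//   STS_PER_COL                  = ceil(64 / 16)      = 4
//   BYTES_PER_BUFFER             = 4 * 16 * 128       = 8192 (= 64 rows x 128B)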
// The type of elements that are stored in shared memory by each thread.
using Store_type = typename Uint_from_size_in_bytes<BYTES_PER_STS>::Type;
// Ctor.
inline __device__ Smem_tile_without_skews(void* smem, int tidx)
: smem_(__nvvm_get_smem_pointer(smem))
{
// The row written by a thread. See doc/mma_smem_layout.xlsx.
int smem_write_row = tidx / THREADS_PER_ROW;
// The XOR pattern.
int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN;
// Compute the column and apply the XOR pattern.
int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor;
// The offset.
this->smem_write_offset_ = smem_write_row * BYTES_PER_ROW + smem_write_col * BYTES_PER_STS;
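// Illustrative example (assuming THREADS_PER_ROW = 8, ROWS_PER_XOR_PATTERN = 8 and
// COLS_PER_XOR_PATTERN = 1): thread 17 writes row 17 / 8 = 2, XOR pattern 2 % 8 = 2,
// column (17 % 8) ^ 2 = 3, i.e. byte offset 2 * BYTES_PER_ROW + 3 * BYTES_PER_STS.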
// This code is expected to trigger the use of the uniform register file (URF) by the compiler.
this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);
this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);
}
// Compute the store pointers.
template <int N, int K = 1>
inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N])
{
#pragma unroll
for (int ii = 0; ii < N; ++ii)
{
// Decompose the STS into row/col.
int row = ii % STS_PER_COL;
int col = ii / STS_PER_COL;
// Compute the immediate.
int imm = row;
// Assemble the offset.
int offset = smem_write_offset_ + imm * ROWS_PER_STS * BYTES_PER_ROW;
// Take the column into account.
if (STS_PER_ROW > 1)
{
offset += col * THREADS_PER_ROW * BYTES_PER_STS;
}
// Apply the XOR pattern if needed.
if (ROWS_PER_STS < ROWS_PER_XOR_PATTERN)
{
int const m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN;
offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS;
}
// Assemble the final pointer :)
#pragma unroll
for (int k = 0; k < K; k++)
{
ptrs[ii * K + k] = smem_ + offset + k * (BYTES_PER_STS / K) + smem_write_buffer_;
}
}
}
inline __device__ void debug_reset()
{
for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER)
{
for (int row = 0; row < ROWS; ++row)
{
for (int col = 0; col < BYTES_PER_ROW; col += 4)
{
if (threadIdx.x == 0)
{
uint32_t val = 0x0;
sts(val, smem_ + row * BYTES_PER_ROW + col + buffer);
}
}
}
}
}
// Print the content of the tile (only for debug ;)).
inline __device__ void debug_print() const
{
for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER)
{
for (int row = 0; row < ROWS; ++row)
{
for (int col = 0; col < BYTES_PER_ROW; col += 4)
{
if (threadIdx.x == 0)
{
uint32_t val;
lds(val, smem_ + row * BYTES_PER_ROW + col + buffer);
printf(
"block=(x=%2d, y=%2d, z=%2d) (smem_=0x%08x, buffer=%2d, row=%2d, "
"byte=%4d)=0x%08x\n",
blockIdx.x, blockIdx.y, blockIdx.z, smem_, buffer, row, col, val);
}
}
}
}
}
// Move the read offset to the next buffer.
inline __device__ void move_to_next_read_buffer()
{
if (BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY)
{
this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
}
else if (BUFFERS_PER_TILE > 1)
{
this->smem_read_buffer_ += BYTES_PER_BUFFER;
}
}
// Move the read offset to the next buffer. TODO: Remove this member function!!!
inline __device__ void move_next_read_buffer()
{
this->move_to_next_read_buffer();
}
// Advance the read offset by N buffers (circular buffer).
inline __device__ void move_to_next_read_buffer(int N)
{
if (BUFFERS_PER_TILE > 1)
{
this->smem_read_buffer_ += N * BYTES_PER_BUFFER;
this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;
}
}
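// E.g. with BUFFERS_PER_TILE = 3, repeated calls to move_to_next_read_buffer(1)
// step the read base through 0, B, 2B, 0, ... where B = BYTES_PER_BUFFER.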
// Advance the read offset by N buffers (circular buffer). TODO: Remove this member function!!!
inline __device__ void move_next_read_buffer(int N)
{
this->move_to_next_read_buffer(N);
}
// Move the write offset to the next buffer.
inline __device__ void move_to_next_write_buffer()
{
if (BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY)
{
this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
}
else if (BUFFERS_PER_TILE > 1)
{
this->smem_write_buffer_ += BYTES_PER_BUFFER;
}
}
// Move the write offset to the next buffer. TODO: Remove this member function!
inline __device__ void move_next_write_buffer()
{
this->move_to_next_write_buffer();
}
// Move the read offset.
inline __device__ void move_read_offset(int delta)
{
this->smem_read_offset_ += delta;
}
// Move the write offset.
inline __device__ void move_write_offset(int delta)
{
this->smem_write_offset_ += delta;
}
// Store to the tile in shared memory.
template <int N>
inline __device__ void store(Store_type const (&data)[N])
{
uint32_t smem_ptrs[N];
this->compute_store_pointers(smem_ptrs);
sts(smem_ptrs, data);
}
// Store to the tile in shared memory.
template <int N, int M>
inline __device__ void store(Store_type const (&data)[N], uint32_t (&preds)[M])
{
uint32_t smem_ptrs[N];
this->compute_store_pointers(smem_ptrs);
sts(smem_ptrs, data, preds);
}
// Store to the tile in shared memory.
template <int N>
inline __device__ void store(Store_type const (&data)[N], uint32_t preds)
{
// Wrap the scalar predicate into an array to reach the array overload above
// (forwarding (data, preds) unchanged would recursively call this same overload).
uint32_t tmp[1] = {preds};
this->store(data, tmp);
}
// Store to the tile in shared memory. TODO: Remove last template arguments.
template <int N, int M>
inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t (&preds)[M])
{
uint32_t smem_ptrs[N];
this->compute_store_pointers<N>(smem_ptrs);
ldgsts<N, M>(smem_ptrs, gmem_ptrs, preds);
}
// Store to the tile in shared memory.
template <int N>
inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0)
{
uint32_t tmp[1] = {preds};
this->store(gmem_ptrs, tmp);
}
// Store to the tile in shared memory.
template <int N>
inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t preds)
{
uint32_t tmp[1] = {preds};
this->store(gmem_ptrs, tmp);
}
inline __device__ void add_smem_barrier_base(uint64_t*) {}
// The shared memory pointer.
uint32_t smem_;
// The read offset. Reserve 4 offsets if needed.
int smem_read_offset_;
// The write offset.
int smem_write_offset_;
// The buffer base offset for read.
int smem_read_buffer_;
// The buffer base offset for write.
int smem_write_buffer_;
};
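////////////////////////////////////////////////////////////////////////////////////////////////////
// Usage sketch (illustrative only; the tile sizes and variable names below are hypothetical).
// A typical double-buffered pipeline built on this tile looks like:
//
//   using Tile = Smem_tile_without_skews<Cta_tile, 64, 64, 16, 16, /*BUFFERS*/ 2>;
//   __shared__ __align__(16) char smem[Tile::BYTES_PER_TILE];
//   Tile tile(smem, threadIdx.x);
//   uint32_t preds[1] = { /* guard out-of-bounds rows */ };
//   tile.store(gmem_ptrs, preds);         // Issue LDGSTS into buffer 0.
//   tile.move_to_next_write_buffer();     // Later stores target buffer 1.
//   // ... commit/wait on the async copy group, then consumers read buffer 0 while
//   // buffer 1 is being filled, flipping the read/write buffers each iteration.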
////////////////////////////////////////////////////////////////////////////////////////////////////
// Use TMA
template <
// The description of the tile computed by this CTA.
typename Cta_tile,
// The number of rows in the 2D shared memory buffer.
int M_,
// The number of cols.
int N_,
// The size in bits of each element.
int BITS_PER_ELEMENT_,
// The number of bytes per STS. Not relevant for TMA
int BYTES_PER_STS_,
// The number of buffers. (Used in multistage and double buffer cases.)
int BUFFERS_PER_TILE_,
// Do we enable the fast path for LDS.128 and friends?
int ENABLE_LDS_FAST_PATH_,
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
int ROWS_PER_XOR_PATTERN_,
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
int COLS_PER_XOR_PATTERN_,
// Whether to use predicates.
bool USE_PREDICATES_,
// The number of leading-dimension elements in shared memory.
int LEAD_DIM_ELEMENTS_>
struct Smem_tile_without_skews<Cta_tile, M_, N_, BITS_PER_ELEMENT_, BYTES_PER_STS_, BUFFERS_PER_TILE_,
ENABLE_LDS_FAST_PATH_, ROWS_PER_XOR_PATTERN_, COLS_PER_XOR_PATTERN_, USE_PREDICATES_, true, LEAD_DIM_ELEMENTS_>
: public Smem_tile_without_skews<Cta_tile, M_, N_, BITS_PER_ELEMENT_, BYTES_PER_STS_, BUFFERS_PER_TILE_,
ENABLE_LDS_FAST_PATH_, ROWS_PER_XOR_PATTERN_, COLS_PER_XOR_PATTERN_, USE_PREDICATES_, false,
LEAD_DIM_ELEMENTS_>
{
// Base struct
using Base = Smem_tile_without_skews<Cta_tile, M_, N_, BITS_PER_ELEMENT_, BYTES_PER_STS_, BUFFERS_PER_TILE_,
ENABLE_LDS_FAST_PATH_, ROWS_PER_XOR_PATTERN_, COLS_PER_XOR_PATTERN_, USE_PREDICATES_, false,
LEAD_DIM_ELEMENTS_>;
static constexpr bool USE_TMA = true;
// Tile size overrides. The number of STS per thread is not relevant for TMA.
static constexpr int BYTES_PER_BUFFER = M_ * N_ * Base::BITS_PER_ELEMENT / 8;
static constexpr int BYTES_PER_TILE = BYTES_PER_BUFFER * Base::BUFFERS_PER_TILE;
static constexpr int BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER;
// The number of bytes per barrier
static constexpr int BYTES_PER_BARRIER = 8;
// Ctor
inline __device__ Smem_tile_without_skews(void* smem, int tidx)
: Base(smem, tidx)
{
this->smem_write_offset_ = __nvvm_get_smem_pointer(smem);
this->smem_barrier_offset_ = 0;
this->elect_one_ = elect_one_sync();
}
inline __device__ void add_smem_barrier_base(uint64_t* smem_barrier)
{
this->smem_barrier_ = smem_barrier;
this->smem_barrier_offset_ = __nvvm_get_smem_pointer(this->smem_barrier_);
}
/**
* \brief Loads tensor blocks from global memory and stores them to shared memory using TMA instructions.
*
* \param p_desc pointer to the TMA descriptor, passed as a const void* pointer
* \param smem_offset shared memory offset in bytes relative to smem_write_buffer_
* \param coord0 tensor access coordinate in dimension 0, used by the TMA load
* \param coord1 tensor access coordinate in dimension 1, used by the TMA load
* \param coord2 tensor access coordinate in dimension 2, used by the TMA load
* \param coord3 tensor access coordinate in dimension 3, used by the TMA load
* \param coord4 tensor access coordinate in dimension 4, used by the TMA load
* \param filter_offsets encodes the multicast CTA id and the filter offsets
*/
template <uint32_t DIM, cudaTmaDescType DESC_TYPE, unsigned COPY_BYTES, bool USE_TMA_MULTICAST = false>
inline __device__ void store(void const* p_desc, unsigned const& smem_offset, int32_t coord0, int32_t coord1,
int32_t coord2, int32_t coord3, int32_t coord4, uint16_t filter_offsets, uint16_t mcast_cta_mask,
uint64_t mem_desc)
{
uint32_t smem = this->smem_write_offset_ + smem_offset;
fmha::utmaldg<DIM, DESC_TYPE, USE_TMA_MULTICAST>(reinterpret_cast<cudaTmaDesc const*>(p_desc), smem,
unsigned(this->smem_barrier_offset_), coord0, coord1, coord2, coord3, coord4, filter_offsets,
mcast_cta_mask, mem_desc, this->elect_one_);
}
// The same as above, but for a runtime CGA dimension.
template <uint32_t DIM, cudaTmaDescType DESC_TYPE, unsigned COPY_BYTES>
inline __device__ void store(void const* p_desc, unsigned const& smem_offset, int32_t coord0, int32_t coord1,
int32_t coord2, int32_t coord3, int32_t coord4, uint16_t filter_offsets, uint16_t mcast_cta_mask,
uint64_t mem_desc, bool mcast_enabled)
{
uint32_t smem = this->smem_write_offset_ + smem_offset;
fmha::utmaldg<DIM, DESC_TYPE>(reinterpret_cast<cudaTmaDesc const*>(p_desc), smem,
unsigned(this->smem_barrier_offset_), coord0, coord1, coord2, coord3, coord4, filter_offsets,
mcast_cta_mask, mcast_enabled, mem_desc, this->elect_one_);
}
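// Illustrative call from the DMA warp (names are hypothetical; DESC_TYPE and COPY_BYTES
// depend on the kernel traits):
//   tile.add_smem_barrier_base(smem_barriers);
//   tile.store<3 /*DIM*/, DESC_TYPE, COPY_BYTES>(params.tma_desc_q, /*smem_offset*/ 0,
//       /*coord0*/ col, /*coord1*/ row, /*coord2*/ batch, 0, 0,
//       /*filter_offsets*/ 0, /*mcast_cta_mask*/ 0, /*mem_desc*/ 0);
// The caller then waits on the corresponding smem barrier before consuming the data.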
// Move the write offset to next buffer.
inline __device__ void move_next_write_buffer()
{
if (Base::BUFFERS_PER_TILE > 1)
{
this->smem_write_offset_ += (this->smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY)
? -BYTES_PER_TILE_INC_BOUNDARY
: BYTES_PER_BUFFER;
this->smem_barrier_offset_ += (this->smem_barrier_offset_ >= Base::BUFFERS_PER_TILE * BYTES_PER_BARRIER)
? -Base::BUFFERS_PER_TILE * BYTES_PER_BARRIER
: BYTES_PER_BARRIER;
}
}
inline __device__ void move_next_write_buffer(int buffer_id)
{
if (Base::BUFFERS_PER_TILE > 1)
{
this->smem_write_offset_ = this->smem_ + buffer_id * BYTES_PER_BUFFER;
}
this->smem_barrier_offset_ = __nvvm_get_smem_pointer(this->smem_barrier_ + buffer_id);
}
// Move the read offset to the next buffer.
// This is a no-op, as the read offset is controlled by the GMMA descriptor.
inline __device__ void move_next_read_buffer() {}
uint64_t* smem_barrier_;
uint32_t smem_barrier_offset_;
// Elect one thread to issue the utmaldg.
uint32_t elect_one_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The layout of the tile.
typename Layout,
// The size of the STS.
int BYTES_PER_STS = 16,
// The number of buffers per tile.
int BUFFERS_PER_TILE = 1,
// Use or not predicates
bool USE_PREDICATES = true>
struct Smem_tile_a
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, int N>
struct Rows_per_xor_pattern_volta_a
{
// The size in bits.
enum
{
N_IN_BITS = N * Traits::BITS_PER_ELEMENT_A
};
// The number of rows.
enum
{
VALUE = N_IN_BITS <= 256 ? 1 : (N_IN_BITS <= 512 ? 2 : 4)
};
};
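// E.g. for 16-bit elements with K = 64: N_IN_BITS = 1024 > 512, so VALUE = 4;
// with K = 16: N_IN_BITS = 256, so VALUE = 1.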
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int MMAS_K, int MMAS_K_WITH_PADDING>
struct Compute_reset_mask
{
// The potential mask.
enum
{
HALF = MMAS_K_WITH_PADDING / 2
};
// The remainder.
enum
{
MOD = MMAS_K % HALF
};
// The final value.
enum
{
VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask<MOD, HALF>::VALUE
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int MMAS_K_WITH_PADDING>
struct Compute_reset_mask<0, MMAS_K_WITH_PADDING>
{
enum
{
VALUE = 0
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int MMAS_K>
struct Compute_reset_mask<MMAS_K, MMAS_K>
{
enum
{
VALUE = MMAS_K - 1
};
};
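// Worked example: Compute_reset_mask<3, 4>::VALUE. HALF = 2, MOD = 3 % 2 = 1, so
// VALUE = 2 | Compute_reset_mask<1, 2>::VALUE = 2 | 1 = 3, which matches the XOR
// factors 1 ^ 3 ^ 1 = 3 accumulated by three successive load() calls below.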
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_volta_a<Traits, Cta_tile::K>::VALUE>
struct Smem_tile_volta_row_a : public Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K, 16, BYTES_PER_STS,
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>
{
// The MMA tile.
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K, 16, BYTES_PER_STS, BUFFERS_PER_TILE, 0,
ROWS_PER_XOR_PATTERN_, 1>;
// The fragment.
using Fragment = Fragment_a<Traits, Row>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Traits, Cta_tile>;
// The number of MMAs.
using Mma_tile_with_padding = typename Traits::template Mma_tile<Cta_tile_with_padding>;
// The size of a single LDS in bytes.
enum
{
BYTES_PER_LDS = 16
};
// Ctor.
inline __device__ Smem_tile_volta_row_a(void* smem, int tidx)
: Base(smem, tidx)
{
// For documentation on the layout, see doc/xmma_smem_layout.xlsx.
// The number of warps.
int const WARPS_M = Cta_tile::WARPS_M;
int const WARPS_N = Cta_tile::WARPS_N;
int const WARPS_K = Cta_tile::WARPS_K;
// The masks to select the warps.
int const WARP_MASK_M = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::M;
int const WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
// The divisor for the warps.
int const WARP_DIV_M = 1 * 1 * Cta_tile::THREADS_PER_WARP;
int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
// The row and column read by the thread.
int smem_read_row, smem_read_col;
if (Base::N_WITH_PADDING >= 64)
{
smem_read_row
= (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 1 + (tidx & 0x10) / 2 + (tidx & 0x07);
smem_read_col = (tidx & 0x03);
}
else if (Base::N_WITH_PADDING == 32)
{
smem_read_row
= (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 2 + (tidx & 0x10) / 4 + (tidx & 0x06) / 2;
smem_read_col = (tidx & 0x02) / 2 + (tidx & 0x01) * 4;
}
else
{
assert(false);
}
// For WARPS_K > 1, we do not support Base::N_WITH_PADDING < 64 for the moment.
static_assert(WARPS_K <= 2 && (WARPS_K == 1 || Base::N_WITH_PADDING >= 64), "");
// We "swap" the block for the second warp working on the in-CTA split-K.
if (WARPS_K == 2)
{
smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K;
}
// The shared memory offset.
this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
}
// Rewind smem_read_offset for the last LDS phase in the main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0)
{
// Move the offset to the next position. See doc/xmma_smem_layout.xlsx.
this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS;
}
// Load from shared memory.
inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki)
{
#pragma unroll
for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi)
{
// Jump over as many rows as needed.
int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
// TODO: Could we fuse smem_read_buffer and smem_read_offset?
uint4 tmp;
lds(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
a[mi].reg(0) = tmp.x;
a[mi].reg(1) = tmp.y;
a[mi].reg(2) = tmp.z;
a[mi].reg(3) = tmp.w;
}
// Move the offset to the next position. See doc/xmma_smem_layout.xlsx.
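// E.g. for Mma_tile_with_padding::MMAS_K == 8, the XOR factors applied for
// ki = 0..6 are 1, 3, 1, 7, 1, 3, 1 (in units of BYTES_PER_LDS).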
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15)
{
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7)
{
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3)
{
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1)
{
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 2)
{
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS;
}
}
// Reset the read offset.
inline __device__ void reset_read_offset()
{
// The number of MMAs in the K dimension.
enum
{
MMAS_K = Mma_tile::MMAS_K
};
// The number of MMAs in the K dimension when we include padding.
enum
{
MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K
};
// Assemble the mask.
enum
{
MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE
};
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_a<Volta_hmma_fp16_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_volta_row_a<Volta_hmma_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = fmha::Volta_hmma_fp16_traits;
// The base class.
using Base = Smem_tile_volta_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_a(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, int N>
struct Rows_per_xor_pattern_turing_a
{
// The size in bits.
enum
{
N_IN_BITS = N * Traits::BITS_PER_ELEMENT_A
};
// The number of rows.
enum
{
VALUE = N_IN_BITS <= 128 ? 1 : (N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8))
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_turing_a<Traits, Cta_tile::K>::VALUE>
struct Smem_tile_turing_row_a
: public Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K, Traits::BITS_PER_ELEMENT_A, BYTES_PER_STS,
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>
{
// The MMA tile.
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K, Traits::BITS_PER_ELEMENT_A, BYTES_PER_STS,
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>;
// The fragment.
using Fragment = Fragment_a<Traits, Row>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Traits, Cta_tile>;
// The number of MMAs.
using Mma_tile_with_padding = typename Traits::template Mma_tile<Cta_tile_with_padding>;
// The size of a single LDS in bytes.
enum
{
BYTES_PER_LDS = 16
};
// Ctor.
inline __device__ Smem_tile_turing_row_a(void* smem, int tidx)
: Base(smem, tidx)
{
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
int const WARPS_M = Cta_tile::WARPS_M;
int const WARPS_N = Cta_tile::WARPS_N;
int const WARPS_K = Cta_tile::WARPS_K;
// The masks to select the warps.
int const WARP_MASK_M = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::M;
int const WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
// The divisor for the warps.
int const WARP_DIV_M = 1 * 1 * Cta_tile::THREADS_PER_WARP;
int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
// The row and column read by the thread.
int smem_read_row, smem_read_col;
static_assert(Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4
|| Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 1,
"");
if (Base::ROWS_PER_XOR_PATTERN == 8)
{
smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 1 + (tidx & 0x0f);
smem_read_col = (tidx & 0x07);
}
else if (Base::ROWS_PER_XOR_PATTERN == 4)
{
smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 2 + (tidx & 0x0e) / 2;
smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4;
}
else if (Base::ROWS_PER_XOR_PATTERN == 2)
{
smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 4 + (tidx & 0x0c) / 4;
smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2;
}
else if (Base::ROWS_PER_XOR_PATTERN == 1)
{
smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 8 + (tidx & 0x1f) / 8;
smem_read_col = (tidx & 0x07);
}
static_assert(WARPS_K <= 2, "");
// We "swap" the block for the second warp working on the in-CTA split-K.
if (WARPS_K == 2)
{
smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K;
}
// The shared memory offset.
this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
}
// Rewind smem_read_offset for the last LDS phase in the main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0)
{
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS;
}
// Load from shared memory.
inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki)
{
#pragma unroll
for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi)
{
int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
uint2 tmp;
ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
a[mi].reg(0) = tmp.x;
a[mi].reg(1) = tmp.y;
}
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15)
{
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7)
{
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3)
{
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1)
{
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 2)
{
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS;
}
}
// Reset the read offset.
inline __device__ void reset_read_offset()
{
// The number of MMAs in the K dimension.
enum
{
MMAS_K = Mma_tile::MMAS_K
};
// The number of MMAs in the K dimension when we include padding.
enum
{
MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K
};
// Assemble the mask.
enum
{
MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE
};
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_a<Turing_hmma_fp16_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_turing_row_a<Turing_hmma_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Turing_hmma_fp16_traits;
// The base class.
using Base = Smem_tile_turing_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_a(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_a<Turing_hmma_fp32_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_turing_row_a<Turing_hmma_fp32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Turing_hmma_fp32_traits;
// The base class.
using Base = Smem_tile_turing_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_a(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_a<Turing_imma_int8_int32_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_turing_row_a<Turing_imma_int8_int32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Turing_imma_int8_int32_traits;
// The base class.
using Base = Smem_tile_turing_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_a(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, int N>
struct Rows_per_xor_pattern_ampere_a
{
// The size in bits.
enum
{
N_IN_BITS = N * Traits::BITS_PER_ELEMENT_A
};
// The number of rows.
enum
{
VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8)
};
};
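// E.g. for 16-bit elements with K = 32: N_IN_BITS = 512, so VALUE = 4; for
// 8-bit elements with K = 32: N_IN_BITS = 256, so VALUE = 2.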
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, int N>
struct Rows_per_xor_pattern_ampere_row_a : public Rows_per_xor_pattern_ampere_a<Traits, N>
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_ampere_row_a<Traits, Cta_tile::K>::VALUE>
struct Smem_tile_ampere_row_a
: public Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K, Traits::BITS_PER_ELEMENT_A, BYTES_PER_STS,
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>
{
// The MMA tile.
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K, Traits::BITS_PER_ELEMENT_A, BYTES_PER_STS,
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>;
// The fragment.
using Fragment = Fragment_a<Traits, Row>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Traits, Cta_tile>;
// The number of MMAs.
using Mma_tile_with_padding = typename Traits::template Mma_tile<Cta_tile_with_padding>;
// The size of a single LDS in bytes.
enum
{
BYTES_PER_LDS = 16
};
// Ctor.
inline __device__ Smem_tile_ampere_row_a(void* smem, int tidx)
: Base(smem, tidx)
{
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
int const WARPS_M = Cta_tile::WARPS_M;
int const WARPS_N = Cta_tile::WARPS_N;
int const WARPS_K = Cta_tile::WARPS_K;
// The masks to select the warps.
int const WARP_MASK_M = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::M;
int const WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
// The divisor for the warps.
int const WARP_DIV_M = 1 * 1 * Cta_tile::THREADS_PER_WARP;
int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
// The row and column read by the thread.
int smem_read_row, smem_read_col;
static_assert(
Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 2, "");
if (Base::ROWS_PER_XOR_PATTERN == 8)
{
smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 1 + (tidx & 0x0f);
smem_read_col = (tidx & 0x07);
smem_read_col ^= (tidx & 0x10) / 16;
}
else if (Base::ROWS_PER_XOR_PATTERN == 4)
{
smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 2 + (tidx & 0x0e) / 2;
smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4;
smem_read_col ^= (tidx & 0x10) / 16;
}
else if (Base::ROWS_PER_XOR_PATTERN == 2)
{
smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 4 + (tidx & 0x0c) / 4;
smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2;
smem_read_col ^= (tidx & 0x10) / 16;
}
static_assert(WARPS_K <= 2, "");
static_assert(WARPS_K != 2 || Base::ROWS_PER_XOR_PATTERN != 2, "");
// We "swap" the block for the second warp working on the same outputs in-CTA split-K.
if (WARPS_K == 2)
{
smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K * 2;
}
// The shared memory offset.
this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
}
// Rewind smem_read_offset for the last LDS phase in the main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0)
{
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if (Mma_tile_with_padding::MMAS_K >= 2)
{
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki)
{
if (ki < Mma_tile::VALID_MMAS_K)
{
#pragma unroll
for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi)
{
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
// Load using LDSM.M88.4.
uint4 tmp;
ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
// Store the value into the fragment.
a[mi].reg(0) = tmp.x;
a[mi].reg(1) = tmp.y;
a[mi].reg(2) = tmp.z;
a[mi].reg(3) = tmp.w;
}
}
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
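// The same progression as the Turing/Volta tiles, but with the factors doubled (each
// Ampere MMA consumes two 16B columns per K step): e.g. for MMAS_K == 8 with padding,
// the XOR factors for ki = 0..6 are 2, 6, 2, 14, 2, 6, 2 (in units of BYTES_PER_LDS).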
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15)
{
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
}
else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7)
{
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
}
else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3)
{
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
}
else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1)
{
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
}
else if (Mma_tile_with_padding::MMAS_K >= 2)
{
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
}
}
// Reset the read offset.
inline __device__ void reset_read_offset()
{
// The number of MMAs in the K dimension.
enum
{
MMAS_K = Mma_tile::MMAS_K
};
// The number of MMAs in the K dimension when we include padding.
enum
{
MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K
};
// Assemble the mask.
enum
{
MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE
};
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_a<Ampere_hmma_fp16_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_ampere_row_a<Ampere_hmma_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Ampere_hmma_fp16_traits;
// The base class.
using Base = Smem_tile_ampere_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_a(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_a<Ampere_hmma_fp32_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_ampere_row_a<Ampere_hmma_fp32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Ampere_hmma_fp32_traits;
// The base class.
using Base = Smem_tile_ampere_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_a(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_a<Ampere_hmma_bf16_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_ampere_row_a<Ampere_hmma_bf16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Ampere_hmma_bf16_traits;
// The base class.
using Base = Smem_tile_ampere_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_a(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_a<Ampere_imma_int8_int32_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_ampere_row_a<Ampere_imma_int8_int32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Ampere_imma_int8_int32_traits;
// The base class.
using Base = Smem_tile_ampere_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_a(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_a<Ada_qmma_e4m3_fp32_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_ampere_row_a<Ada_qmma_e4m3_fp32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Ada_qmma_e4m3_fp32_traits;
// The base class.
using Base = Smem_tile_ampere_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_a(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_a<Ada_qmma_e4m3_fp16_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_ampere_row_a<Ada_qmma_e4m3_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Ada_qmma_e4m3_fp16_traits;
// The base class.
using Base = Smem_tile_ampere_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_a(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The layout of the tile.
typename Layout,
// The size of the STS.
int BYTES_PER_STS = 16,
// The number of buffers per tile.
int BUFFERS_PER_TILE = 1,
// Use or not predicates
bool USE_PREDICATES = true>
struct Smem_tile_b
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, int N>
struct Rows_per_xor_pattern_volta_b
{
// The size in bits.
enum
{
N_IN_BITS = N * Traits::BITS_PER_ELEMENT_B
};
// The number of rows.
enum
{
VALUE = N_IN_BITS <= 256 ? 1 : (N_IN_BITS <= 512 ? 2 : 4)
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_volta_b<Traits, Cta_tile::K>::VALUE>
struct Smem_tile_volta_col_b : public Smem_tile_without_skews<Cta_tile, Cta_tile::N, Cta_tile::K, 16, BYTES_PER_STS,
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>
{
// The MMA tile.
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::N, Cta_tile::K, 16, BYTES_PER_STS, BUFFERS_PER_TILE, 0,
ROWS_PER_XOR_PATTERN_, 1>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Traits, Cta_tile>;
// The number of MMAs.
using Mma_tile_with_padding = typename Traits::template Mma_tile<Cta_tile_with_padding>;
// The fragment.
using Fragment = Fragment_b<Traits, Col>;
// The size of a single LDS in bytes.
enum
{
BYTES_PER_LDS = 16
};
// Ctor.
inline __device__ Smem_tile_volta_col_b(void* smem, int tidx)
: Base(smem, tidx)
{
// For documentation on the layout, see doc/xmma_smem_layout.xlsx.
// The number of warps.
int const WARPS_M = Cta_tile::WARPS_M;
int const WARPS_N = Cta_tile::WARPS_N;
int const WARPS_K = Cta_tile::WARPS_K;
// The masks to select the warps.
int const WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
int const WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
// The divisor for the warps.
int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
// The row and column read by the thread.
int smem_read_row, smem_read_col;
if (Base::N_WITH_PADDING >= 64)
{
smem_read_row
= (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 1 + (tidx & 0x18) / 2 + (tidx & 0x03);
smem_read_col = (tidx & 0x03);
}
else if (Base::N_WITH_PADDING == 32)
{
smem_read_row
= (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 2 + (tidx & 0x18) / 4 + (tidx & 0x02) / 2;
smem_read_col = (tidx & 0x02) / 2 + (tidx & 0x01) * 4;
}
else
{
assert(false);
}
// For WARPS_K > 1, we do not support Base::N_WITH_PADDING < 64 for the moment.
static_assert(WARPS_K <= 2 && (WARPS_K == 1 || Base::N_WITH_PADDING >= 64), "");
// We "swap" the block for the second warp working on the in-CTA split-K.
if (WARPS_K == 2)
{
smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K;
}
// The shared memory offset.
this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
}
// Rewind smem_read_offset for the last LDS phase in the main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0)
{
// Move the offset to the next position. See doc/xmma_smem_layout.xlsx.
this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS;
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki)
{
#pragma unroll
for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni)
{
// Jump over as many rows as needed.
int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
// TODO: Can we fuse read_offset and read_buffer?
uint4 tmp;
lds(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
}
// Move the offset to the next position. See doc/xmma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15)
{
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7)
{
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3)
{
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1)
{
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 2)
{
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS;
}
}
// Reset the read offset.
inline __device__ void reset_read_offset()
{
// The number of MMAs in the K dimension.
enum
{
MMAS_K = Mma_tile::MMAS_K
};
// The number of MMAs in the K dimension when we include padding.
enum
{
MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K
};
// Assemble the mask.
enum
{
MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE
};
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_b<Volta_hmma_fp16_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_volta_col_b<Volta_hmma_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = fmha::Volta_hmma_fp16_traits;
// The base class.
using Base = Smem_tile_volta_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, int N>
struct Rows_per_xor_pattern_turing_b
{
// The size in bits.
enum
{
N_IN_BITS = N * Traits::BITS_PER_ELEMENT_B
};
// The number of rows.
enum
{
VALUE = N_IN_BITS <= 128 ? 1 : (N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8))
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_turing_b<Traits, Cta_tile::K>::VALUE>
struct Smem_tile_turing_col_b
: public Smem_tile_without_skews<Cta_tile, Cta_tile::N, Cta_tile::K, Traits::BITS_PER_ELEMENT_B, BYTES_PER_STS,
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>
{
// The MMA tile.
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::N, Cta_tile::K, Traits::BITS_PER_ELEMENT_B, BYTES_PER_STS,
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>;
// The fragment.
using Fragment = Fragment_b<Traits, Col>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Traits, Cta_tile>;
// The number of MMAs.
using Mma_tile_with_padding = typename Traits::template Mma_tile<Cta_tile_with_padding>;
// The size of a single LDS in bytes.
enum
{
BYTES_PER_LDS = 16
};
// Ctor.
inline __device__ Smem_tile_turing_col_b(void* smem, int tidx)
: Base(smem, tidx)
{
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
int const WARPS_M = Cta_tile::WARPS_M;
int const WARPS_N = Cta_tile::WARPS_N;
int const WARPS_K = Cta_tile::WARPS_K;
// The masks to select the warps.
int const WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
int const WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
// The divisor for the warps.
int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
// The row and column read by the thread.
int smem_read_row, smem_read_col;
static_assert(Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4
|| Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 1,
"");
if (Base::ROWS_PER_XOR_PATTERN == 8)
{
// For group fprop, B is divided into 2 halves along the N dimension.
// The first warp takes the first half and the second warp takes the second half.
smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 1 + (tidx & 0x0f);
smem_read_col = (tidx & 0x07);
}
else if (Base::ROWS_PER_XOR_PATTERN == 4)
{
smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 2 + (tidx & 0x0e) / 2;
smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4;
}
else if (Base::ROWS_PER_XOR_PATTERN == 2)
{
smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 4 + (tidx & 0x0c) / 4;
smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2;
}
else if (Base::ROWS_PER_XOR_PATTERN == 1)
{
smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 8 + (tidx & 0x1f) / 8;
smem_read_col = (tidx & 0x07);
}
static_assert(WARPS_K <= 2, "");
// We "swap" the block for the second warp working on the in-CTA split-K.
if (WARPS_K == 2)
{
smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K;
}
// The shared memory offset.
this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
}
// Rewind smem_read_offset for the last LDS phase in the main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0)
{
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS;
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki)
{
#pragma unroll
for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni)
{
int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
uint2 tmp;
ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
}
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15)
{
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7)
{
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3)
{
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1)
{
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS;
}
else if (Mma_tile_with_padding::MMAS_K >= 2)
{
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS;
}
}
// Reset the read offset.
inline __device__ void reset_read_offset()
{
// The number of MMAs in the K dimension.
enum
{
MMAS_K = Mma_tile::MMAS_K
};
// The number of MMAs in the K dimension when we include padding.
enum
{
MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K
};
// Assemble the mask.
enum
{
MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE
};
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_b<Turing_hmma_fp16_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_turing_col_b<Turing_hmma_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Turing_hmma_fp16_traits;
// The base class.
using Base = Smem_tile_turing_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_b<Turing_hmma_fp32_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_turing_col_b<Turing_hmma_fp32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Turing_hmma_fp32_traits;
// The base class.
using Base = Smem_tile_turing_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_b<Turing_imma_int8_int32_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_turing_col_b<Turing_imma_int8_int32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Turing_imma_int8_int32_traits;
// The base class.
using Base = Smem_tile_turing_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, int N>
struct Rows_per_xor_pattern_ampere_b
{
// The size in bits.
enum
{
N_IN_BITS = N * Traits::BITS_PER_ELEMENT_B
};
// The number of rows.
enum
{
VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8)
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, int N>
struct Rows_per_xor_pattern_ampere_col_b : public Rows_per_xor_pattern_ampere_b<Traits, N>
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_ampere_col_b<Traits, Cta_tile::K>::VALUE>
struct Smem_tile_ampere_col_b
: public Smem_tile_without_skews<Cta_tile, Cta_tile::N, Cta_tile::K, Traits::BITS_PER_ELEMENT_B, BYTES_PER_STS,
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>
{
// The MMA tile.
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::N, Cta_tile::K, Traits::BITS_PER_ELEMENT_B, BYTES_PER_STS,
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>;
// The fragment.
using Fragment = Fragment_b<Traits, Col>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Traits, Cta_tile>;
// The number of MMAs.
using Mma_tile_with_padding = typename Traits::template Mma_tile<Cta_tile_with_padding>;
// The size of a single LDS in bytes.
enum
{
BYTES_PER_LDS = 16
};
    // The number of STS per thread.
enum
{
STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA
};
// The number of STS per thread must be at least 1.
enum
{
STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE
};
// Ctor.
inline __device__ Smem_tile_ampere_col_b(void* smem, int tidx)
: Base(smem, tidx)
{
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
int const WARPS_M = Cta_tile::WARPS_M;
int const WARPS_N = Cta_tile::WARPS_N;
int const WARPS_K = Cta_tile::WARPS_K;
// The masks to select the warps.
int const WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
int const WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
// The divisor for the warps.
int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
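        // (tidx & WARP_MASK_N) / WARP_DIV_N extracts this warp's index along N, and
        // likewise for K: warps are linearized with M fastest, then N, then K.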
// The row and column read by the thread.
int smem_read_row, smem_read_col;
static_assert(
Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 2, "");
if (Base::ROWS_PER_XOR_PATTERN == 8)
{
            // For group fprop. B is divided into 2 halves along the N dimension.
            // The first warp takes the first half and the second warp takes the second half.
smem_read_row
= (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 1 + (tidx & 0x07) + (tidx & 0x10) / 2;
smem_read_col = (tidx & 0x07);
smem_read_col ^= (tidx & 0x08) / 8;
}
else if (Base::ROWS_PER_XOR_PATTERN == 4)
{
smem_read_row
= (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 2 + (tidx & 0x06) / 2 + (tidx & 0x10) / 4;
smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4;
smem_read_col ^= (tidx & 0x08) / 8;
}
else if (Base::ROWS_PER_XOR_PATTERN == 2)
{
smem_read_row
= (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 4 + (tidx & 0x04) / 4 + (tidx & 0x10) / 8;
smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2;
smem_read_col ^= (tidx & 0x08) / 8;
}
static_assert(WARPS_K <= 2, "");
static_assert(WARPS_K != 2 || Base::ROWS_PER_XOR_PATTERN != 2, "");
// We "swap" the block for the second warp working on the in-CTA split-K.
if (WARPS_K == 2)
{
smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K * 2;
}
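        // The XOR above shifts the second split-K warp's starting column by the full
        // padded K extent (two LDS columns per K step), so the two warps stream
        // disjoint halves of the K range.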
// The shared memory offset.
this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0)
{
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if (Mma_tile_with_padding::MMAS_K >= 2)
{
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki)
{
if (ki < Mma_tile::VALID_MMAS_K)
{
#pragma unroll
for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni)
{
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
// Load using LDSM.M88.4.
uint4 tmp;
ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
// Store the value into the fragment.
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
}
}
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15)
{
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
}
else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7)
{
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
}
else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3)
{
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
}
else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1)
{
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
}
else if (Mma_tile_with_padding::MMAS_K >= 2)
{
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
}
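        // Same binary-increment walk as on Turing, but with stride BYTES_PER_LDS * 2:
        // here each K step spans two 16-byte LDS columns.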
}
// Reset the read offset.
inline __device__ void reset_read_offset()
{
// The number of MMAs in the K dimension.
enum
{
MMAS_K = Mma_tile::MMAS_K
};
// The number of MMAs in the K dimension when we include padding.
enum
{
MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K
};
// Assemble the mask.
enum
{
MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE
};
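        // MASK undoes the XORs accumulated by load() over MMAS_K steps; the stride is
        // BYTES_PER_LDS * 2 to match the two-column K step used above.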
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
}
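    // Typical main-loop pattern (illustrative; smem_b is an instance of this tile):
    //
    //   Fragment frag_b[Mma_tile::MMAS_K][Mma_tile::MMAS_N];
    //   #pragma unroll
    //   for (int ki = 0; ki < Mma_tile::MMAS_K; ++ki) {
    //       smem_b.load(frag_b[ki], ki);
    //   }
    //   smem_b.reset_read_offset(); // rewind the XOR walk before the next tile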
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_b<Ampere_hmma_fp16_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_ampere_col_b<Ampere_hmma_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Ampere_hmma_fp16_traits;
// The base class.
using Base = Smem_tile_ampere_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_b<Ampere_hmma_fp32_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_ampere_col_b<Ampere_hmma_fp32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Ampere_hmma_fp32_traits;
// The base class.
using Base = Smem_tile_ampere_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_b<Ampere_hmma_bf16_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_ampere_col_b<Ampere_hmma_bf16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Ampere_hmma_bf16_traits;
// The base class.
using Base = Smem_tile_ampere_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_b<Ampere_imma_int8_int32_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_ampere_col_b<Ampere_imma_int8_int32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Ampere_imma_int8_int32_traits;
// The base class.
using Base = Smem_tile_ampere_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_b<Ada_qmma_e4m3_fp32_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_ampere_col_b<Ada_qmma_e4m3_fp32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Ada_qmma_e4m3_fp32_traits;
// The base class.
using Base = Smem_tile_ampere_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_b<Ada_qmma_e4m3_fp16_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_ampere_col_b<Ada_qmma_e4m3_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Ada_qmma_e4m3_fp16_traits;
// The base class.
using Base = Smem_tile_ampere_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits, int N>
struct Rows_per_xor_pattern_ampere_row_b : public Rows_per_xor_pattern_ampere_b<Traits, N>
{
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_ampere_row_b<Traits, Cta_tile::N>::VALUE,
// How many cols to use for the XOR pattern to avoid bank conflicts?
int COLS_PER_XOR_PATTERN_ = 1>
struct Smem_tile_ampere_row_b
: public Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, Traits::BITS_PER_ELEMENT_B, BYTES_PER_STS,
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, COLS_PER_XOR_PATTERN_>
{
// The MMA tile.
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, Traits::BITS_PER_ELEMENT_B, BYTES_PER_STS,
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, COLS_PER_XOR_PATTERN_>;
// The fragment.
using Fragment = Fragment_b<Traits, Row>;
    // Can we use LDSM? Not if the data type is 32 bits wide: LDSM (ldmatrix) operates on 16-bit elements.
enum
{
USE_LDSMT = Traits::BITS_PER_ELEMENT_B == 16
};
// The size of a single LDS in bytes.
enum
{
BYTES_PER_LDS = USE_LDSMT ? 16 : 4
};
// The number of elements per LDS.
enum
{
ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / Traits::BITS_PER_ELEMENT_B
};
    // The number of STS per thread.
enum
{
STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA
};
// The number of STS per thread must be at least 1.
enum
{
STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE
};
// Ctor.
inline __device__ Smem_tile_ampere_row_b(void* smem, int tidx)
: Base(smem, tidx)
{
// For documentation on the layout, see doc/xmma_smem_layout.xlsx.
// The number of warps.
int const WARPS_M = Cta_tile::WARPS_M;
int const WARPS_N = Cta_tile::WARPS_N;
int const WARPS_K = Cta_tile::WARPS_K;
// The masks to select the warps.
int const WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
int const WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
// The divisor for the warps.
int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
// The row/col read by the thread.
int smem_read_row, smem_read_col;
static_assert((USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 8) || Base::ROWS_PER_XOR_PATTERN == 4
|| Base::ROWS_PER_XOR_PATTERN == 2,
"");
if (USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 8)
{
            // For group dgrad. B is divided into 2 halves along the K dimension.
            // The first warp takes the first half and the second warp takes the second half.
smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 + (tidx & 0x07) + (tidx & 0x08);
smem_read_col = (tidx & 0x07);
}
else if (USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 4)
{
smem_read_row
= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 8 + (tidx & 0x06) / 2 + (tidx & 0x08) / 2;
smem_read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2;
}
else if (USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 2)
{
smem_read_row
= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 4 + (tidx & 0x04) / 4 + (tidx & 0x08) / 4;
smem_read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4;
}
else if (Base::ROWS_PER_XOR_PATTERN == 4 && Base::COLS_PER_XOR_PATTERN == 2)
{
smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 8 + (tidx & 0x03);
smem_read_col = (tidx & 0x1c) / 4 + (tidx & 0x03) * 8;
}
// Each half-warp applies a different XOR pattern -- see the Excel document.
if (USE_LDSMT)
{
smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16;
}
else
{
smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 16;
}
// The shared memory offset.
this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
// Fill zeroes for group conv
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0)
{
// The size of each element in bits.
int const BITS_PER_ELT = Traits::BITS_PER_ELEMENT_B;
// The size in bytes of the data needed to compute an MMA per CTA.
int const BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
#pragma unroll
for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni)
{
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if (BYTES_PER_MMA_PER_CTA >= 128)
{
// Nothing to do!
}
else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1)
{
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
}
else if (BYTES_PER_MMA_PER_CTA == 64)
{
// Nothing to do!
}
else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4)
{
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
}
else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2)
{
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
        // Reset smem_read_offset for odd MMAS_N > 1 (non-power-of-2 "npo2" kernels).
if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1)
{
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki)
{
// The size of each element in bits.
int const BITS_PER_ELT = Traits::BITS_PER_ELEMENT_B;
// The size in bytes of the data needed to compute an MMA per CTA.
int const BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
#pragma unroll
for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni)
{
// Prepare the offset.
int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW;
if (BYTES_PER_MMA_PER_CTA == 32)
{
offset += this->smem_read_offset_;
}
else if (BYTES_PER_MMA_PER_CTA == 64)
{
offset += this->smem_read_offset_ + (ni / 2) * BYTES_PER_MMA_PER_CTA * 2;
}
else
{
                offset += this->smem_read_offset_ + ni * BYTES_PER_MMA_PER_CTA;
}
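            // Offset selection: for 32B MMAs the XOR ladder below does all the walking;
            // for 64B MMAs pairs of ni share one XOR-ed column, so only ni / 2 advances
            // the base; for >= 128B the ni term advances linearly.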
            // Load the data using LDSM.MT88.4 (or four 32-bit LDS when LDSM is unavailable).
uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset;
if (ni < Mma_tile::VALID_MMAS_N)
{
uint4 tmp;
if (USE_LDSMT)
{
ldsmt(tmp, ptr);
}
else
{
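                    // No LDSM for 32-bit types: cover the same 8x8 footprint with four
                    // 32-bit LDS, rows 0 and 4 at two columns 32 bytes apart (ptr ^ 32
                    // preserves the swizzle).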
lds(tmp.x, (ptr) + 0 * Base::BYTES_PER_ROW);
lds(tmp.y, (ptr) + 4 * Base::BYTES_PER_ROW);
lds(tmp.z, (ptr ^ 32) + 0 * Base::BYTES_PER_ROW);
lds(tmp.w, (ptr ^ 32) + 4 * Base::BYTES_PER_ROW);
}
// Store those values in the fragment.
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
}
// static_assert(BYTES_PER_MMA_PER_CTA >= 128 ||
// BYTES_PER_MMA_PER_CTA == 64 ||
// (BYTES_PER_MMA_PER_CTA == 32 &&
// (Mma_tile::MMAS_M == 4 ||
// Mma_tile::MMAS_M == 2 ||
// Mma_tile::MMAS_M == 1)), "");
// Move the pointer for the next ni. I expect the compiler to not recompute those.
if (BYTES_PER_MMA_PER_CTA >= 128)
{
// Nothing to do!
}
else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1)
{
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
}
else if (BYTES_PER_MMA_PER_CTA == 64)
{
// Nothing to do!
}
else if (BYTES_PER_MMA_PER_CTA == 32)
{
if ((ni & 1) == 0)
{
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
else if (Mma_tile::MMAS_N >= 16 && (ni & 7) == 7)
{
this->smem_read_offset_ ^= BYTES_PER_LDS * 30;
}
else if (Mma_tile::MMAS_N >= 8 && (ni & 3) == 3)
{
this->smem_read_offset_ ^= BYTES_PER_LDS * 14;
}
else if (Mma_tile::MMAS_N >= 4 && (ni & 1) == 1)
{
this->smem_read_offset_ ^= BYTES_PER_LDS * 6;
}
}
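            // For 32B MMAs these XORs again binary-increment the N-slice index, in
            // units of two LDS columns, wrapping within the XOR pattern.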
}
        // Reset smem_read_offset for odd MMAS_N > 1 (non-power-of-2 "npo2" kernels).
if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1)
{
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_b<Ampere_hmma_fp32_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_ampere_row_b<Ampere_hmma_fp32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Ampere_hmma_fp32_traits;
// The base class.
using Base = Smem_tile_ampere_row_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE>
struct Smem_tile_b<Ampere_hmma_bf16_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_ampere_row_b<Ampere_hmma_bf16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
{
// The traits class.
using Traits = Ampere_hmma_bf16_traits;
// The base class.
using Base = Smem_tile_ampere_row_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void* smem, int tidx)
: Base(smem, tidx)
{
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha