mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-10 04:53:38 +08:00
* add fmha repo Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> * fix format Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> * fix code style Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> * fix header Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> * fix header kernel_traits.h Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> * add .gitignore file Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> * add SLIDING_WINDOW_ATTENTION Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> * fix style Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> * fix format Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> * update setup.py Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> * update build_wheel.py Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> --------- Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> Signed-off-by: qsang-nv <200703406+qsang-nv@users.noreply.github.com>
2619 lines
87 KiB
C++
2619 lines
87 KiB
C++
/*
|
|
* SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
|
|
*
|
|
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
|
* property and proprietary rights in and to this material, related
|
|
* documentation and any modifications thereto. Any use, reproduction,
|
|
* disclosure or distribution of this material and related documentation
|
|
* without an express license agreement from NVIDIA CORPORATION or
|
|
* its affiliates is strictly prohibited.
|
|
*/
|
|
|
|
#pragma once
|
|
|
|
#include <fmha/fragment.h>
|
|
#include <fmha/traits.h>
|
|
#include <fmha/utils.h>
|
|
|
|
namespace fmha
|
|
{
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The description of the tile computed by this CTA.
|
|
typename Cta_tile,
|
|
// The number of rows in the 2D shared memory buffer.
|
|
int M_,
|
|
// The number of cols.
|
|
int N_,
|
|
// The size in bits of each element.
|
|
int BITS_PER_ELEMENT_,
|
|
// The number of bytes per STS.
|
|
int BYTES_PER_STS_ = 16,
|
|
// The number of buffers. (Used in multistage and double buffer cases.)
|
|
int BUFFERS_PER_TILE_ = 1,
|
|
// Do we enable the fast path for LDS.128 and friends.
|
|
int ENABLE_LDS_FAST_PATH_ = 0,
|
|
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
|
|
int ROWS_PER_XOR_PATTERN_ = 8,
|
|
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
|
|
int COLS_PER_XOR_PATTERN_ = 1,
|
|
// Use or not predicates
|
|
bool USE_PREDICATES_ = true,
|
|
// Use TMA or not,
|
|
bool USE_TMA_ = false,
|
|
// The leading dim elements in shared memory
|
|
int LEAD_DIM_ELEMENTS_ = N_>
|
|
struct Smem_tile_without_skews
|
|
{
|
|
|
|
// The type of this tile
|
|
using Smem_tile_ = Smem_tile_without_skews<Cta_tile, M_, N_, BITS_PER_ELEMENT_, BYTES_PER_STS_, BUFFERS_PER_TILE_,
|
|
ENABLE_LDS_FAST_PATH_, ROWS_PER_XOR_PATTERN_, COLS_PER_XOR_PATTERN_, USE_PREDICATES_>;
|
|
|
|
static constexpr bool USE_TMA = USE_TMA_;
|
|
|
|
// The size in bits of each element.
|
|
enum
|
|
{
|
|
BITS_PER_ELEMENT = BITS_PER_ELEMENT_
|
|
};
|
|
|
|
// The size in bytes of a single STS.
|
|
enum
|
|
{
|
|
BYTES_PER_STS = BYTES_PER_STS_
|
|
};
|
|
|
|
// The number of elements per STS.
|
|
enum
|
|
{
|
|
ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT
|
|
};
|
|
|
|
// To support arbitrary N, we pad some values to a power-of-2.
|
|
enum
|
|
{
|
|
N_WITH_PADDING = Next_power_of_two<LEAD_DIM_ELEMENTS_>::VALUE
|
|
};
|
|
|
|
// The number of bytes per row without packing of rows.
|
|
enum
|
|
{
|
|
BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8
|
|
};
|
|
|
|
// The number of bytes per row -- we want at least 128B per row.
|
|
enum
|
|
{
|
|
BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE
|
|
};
|
|
|
|
// The number of rows in shared memory (two rows may be packed into a single one).
|
|
enum
|
|
{
|
|
ROWS = M_ * N_ / LEAD_DIM_ELEMENTS_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW
|
|
};
|
|
|
|
// The number of threads per row.
|
|
enum
|
|
{
|
|
THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS
|
|
};
|
|
|
|
// The number of threads per row.
|
|
enum
|
|
{
|
|
THREADS_PER_ROW = Min<Cta_tile::THREADS_PER_CTA, THREADS_PER_ROW_UNBOUNDED>::VALUE
|
|
};
|
|
|
|
// The number of STS per row.
|
|
enum
|
|
{
|
|
STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS
|
|
};
|
|
|
|
// It must be at least one.
|
|
static_assert(STS_PER_ROW >= 1, "");
|
|
|
|
// The number of rows written with a single STS.
|
|
enum
|
|
{
|
|
ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW
|
|
};
|
|
|
|
// Make sure we write to at least one row per STS. Thanks Dr. Obvious ;)
|
|
static_assert(ROWS_PER_STS >= 1, "");
|
|
|
|
// The number of STS needed to store all rows.
|
|
enum
|
|
{
|
|
STS_PER_COL = Div_up<ROWS, ROWS_PER_STS>::VALUE
|
|
};
|
|
|
|
// The number of STS in total.
|
|
enum
|
|
{
|
|
STS = STS_PER_COL * STS_PER_ROW
|
|
};
|
|
|
|
// The size of one buffer in bytes in shared memory.
|
|
enum
|
|
{
|
|
BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA
|
|
};
|
|
|
|
// The number of buffers.
|
|
enum
|
|
{
|
|
BUFFERS_PER_TILE = BUFFERS_PER_TILE_
|
|
};
|
|
|
|
// The size in bytes of total buffers.
|
|
enum
|
|
{
|
|
BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE
|
|
};
|
|
|
|
// The boundary for smem_read_offset and smem_write_offset increment.
|
|
enum
|
|
{
|
|
BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER
|
|
};
|
|
|
|
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
|
|
enum
|
|
{
|
|
ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_
|
|
};
|
|
|
|
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
|
|
enum
|
|
{
|
|
COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS
|
|
};
|
|
|
|
// Use or not predicates
|
|
enum
|
|
{
|
|
USE_PREDICATES = USE_PREDICATES_
|
|
};
|
|
|
|
// The bytes of one shmem row
|
|
enum
|
|
{
|
|
BYTES_PER_SHMEM_ROW = 128
|
|
};
|
|
|
|
// The type of elements that are stored in shared memory by each thread.
|
|
using Store_type = typename Uint_from_size_in_bytes<BYTES_PER_STS>::Type;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_without_skews(void* smem, int tidx)
|
|
: smem_(__nvvm_get_smem_pointer(smem))
|
|
{
|
|
|
|
// The row written by a thread. See doc/mma_smem_layout.xlsx.
|
|
int smem_write_row = tidx / THREADS_PER_ROW;
|
|
|
|
// The XOR pattern.
|
|
int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN;
|
|
// Compute the column and apply the XOR pattern.
|
|
int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor;
|
|
|
|
// The offset.
|
|
this->smem_write_offset_ = smem_write_row * BYTES_PER_ROW + smem_write_col * BYTES_PER_STS;
|
|
|
|
// That code is expected to trigger the utilization of the URF by the compiler.
|
|
this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);
|
|
this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);
|
|
}
|
|
|
|
// Compute the store pointers.
|
|
template <int N, int K = 1>
|
|
inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N])
|
|
{
|
|
|
|
#pragma unroll
|
|
for (int ii = 0; ii < N; ++ii)
|
|
{
|
|
// Decompose the STS into row/col.
|
|
int row = ii % STS_PER_COL;
|
|
int col = ii / STS_PER_COL;
|
|
|
|
// Compute the immediate.
|
|
int imm = row;
|
|
|
|
// Assemble the offset.
|
|
int offset = smem_write_offset_ + imm * ROWS_PER_STS * BYTES_PER_ROW;
|
|
|
|
// Take the column into account.
|
|
if (STS_PER_ROW > 1)
|
|
{
|
|
offset += col * THREADS_PER_ROW * BYTES_PER_STS;
|
|
}
|
|
|
|
// Apply the XOR pattern if needed.
|
|
if (ROWS_PER_STS < ROWS_PER_XOR_PATTERN)
|
|
{
|
|
int const m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN;
|
|
offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS;
|
|
}
|
|
|
|
// Assemble the final pointer :)
|
|
#pragma unroll
|
|
for (int k = 0; k < K; k++)
|
|
{
|
|
ptrs[ii * K + k] = smem_ + offset + k * (BYTES_PER_STS / K) + smem_write_buffer_;
|
|
}
|
|
}
|
|
}
|
|
|
|
inline __device__ void debug_reset()
|
|
{
|
|
for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER)
|
|
{
|
|
for (int row = 0; row < ROWS; ++row)
|
|
{
|
|
for (int col = 0; col < BYTES_PER_ROW; col += 4)
|
|
{
|
|
if (threadIdx.x == 0)
|
|
{
|
|
uint32_t val = 0x0;
|
|
sts(val, smem_ + row * BYTES_PER_ROW + col + buffer);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Print the content of the tile (only for debug ;)).
|
|
inline __device__ void debug_print() const
|
|
{
|
|
for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER)
|
|
{
|
|
for (int row = 0; row < ROWS; ++row)
|
|
{
|
|
for (int col = 0; col < BYTES_PER_ROW; col += 4)
|
|
{
|
|
if (threadIdx.x == 0)
|
|
{
|
|
uint32_t val;
|
|
lds(val, smem_ + row * BYTES_PER_ROW + col + buffer);
|
|
printf(
|
|
"block=(x=%2d, y=%2d, z=%2d) (smem_=0x%08x, buffer=%2d, row=%2d, "
|
|
"byte=%4d)=0x%08x\n",
|
|
blockIdx.x, blockIdx.y, blockIdx.z, smem_, buffer, row, col, val);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Move the read offset to next buffer.
|
|
inline __device__ void move_to_next_read_buffer()
|
|
{
|
|
if (BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY)
|
|
{
|
|
this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
|
|
}
|
|
else if (BUFFERS_PER_TILE > 1)
|
|
{
|
|
this->smem_read_buffer_ += BYTES_PER_BUFFER;
|
|
}
|
|
}
|
|
|
|
// Move the read offset to next buffer. TODO: Remove this member function!!!
|
|
inline __device__ void move_next_read_buffer()
|
|
{
|
|
this->move_to_next_read_buffer();
|
|
}
|
|
|
|
// Move the read offset to next N buffer (circular-buffer).
|
|
inline __device__ void move_to_next_read_buffer(int N)
|
|
{
|
|
if (BUFFERS_PER_TILE > 1)
|
|
{
|
|
this->smem_read_buffer_ += N * BYTES_PER_BUFFER;
|
|
this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;
|
|
}
|
|
}
|
|
|
|
// Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!!
|
|
inline __device__ void move_next_read_buffer(int N)
|
|
{
|
|
this->move_to_next_read_buffer(N);
|
|
}
|
|
|
|
// Move the write offset to next buffer.
|
|
inline __device__ void move_to_next_write_buffer()
|
|
{
|
|
if (BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY)
|
|
{
|
|
this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
|
|
}
|
|
else if (BUFFERS_PER_TILE > 1)
|
|
{
|
|
this->smem_write_buffer_ += BYTES_PER_BUFFER;
|
|
}
|
|
}
|
|
|
|
// Move the write offset to next buffer. TODO: Remove that member function!
|
|
inline __device__ void move_next_write_buffer()
|
|
{
|
|
this->move_to_next_write_buffer();
|
|
}
|
|
|
|
// Move the read offset.
|
|
inline __device__ void move_read_offset(int delta)
|
|
{
|
|
this->smem_read_offset_ += delta;
|
|
}
|
|
|
|
// Move the write offset.
|
|
inline __device__ void move_write_offset(int delta)
|
|
{
|
|
this->smem_write_offset_ += delta;
|
|
}
|
|
|
|
// Store to the tile in shared memory.
|
|
template <int N>
|
|
inline __device__ void store(Store_type const (&data)[N])
|
|
{
|
|
uint32_t smem_ptrs[N];
|
|
this->compute_store_pointers(smem_ptrs);
|
|
sts(smem_ptrs, data);
|
|
}
|
|
|
|
// Store to the tile in shared memory.
|
|
template <int N, int M>
|
|
inline __device__ void store(Store_type const (&data)[N], uint32_t (&preds)[M])
|
|
{
|
|
uint32_t smem_ptrs[N];
|
|
this->compute_store_pointers(smem_ptrs);
|
|
sts(smem_ptrs, data, preds);
|
|
}
|
|
|
|
// Store to the tile in shared memory.
|
|
template <int N>
|
|
inline __device__ void store(Store_type const (&data)[N], uint32_t preds)
|
|
{
|
|
this->store(data, preds);
|
|
}
|
|
|
|
// Store to the tile in shared memory. TODO: Remove last template arguments.
|
|
template <int N, int M>
|
|
inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t (&preds)[M])
|
|
{
|
|
uint32_t smem_ptrs[N];
|
|
this->compute_store_pointers<N>(smem_ptrs);
|
|
ldgsts<N, M>(smem_ptrs, gmem_ptrs, preds);
|
|
}
|
|
|
|
// Store to the tile in shared memory.
|
|
template <int N>
|
|
inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0)
|
|
{
|
|
uint32_t tmp[1] = {preds};
|
|
this->store(gmem_ptrs, tmp);
|
|
}
|
|
|
|
// Store to the tile in shared memory.
|
|
template <int N>
|
|
inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t preds)
|
|
{
|
|
uint32_t tmp[1] = {preds};
|
|
this->store(gmem_ptrs, tmp);
|
|
}
|
|
|
|
inline __device__ void add_smem_barrier_base(uint64_t*) {}
|
|
|
|
// The shared memory pointer.
|
|
uint32_t smem_;
|
|
// The read offset. Reserve 4 offsets if needed.
|
|
int smem_read_offset_;
|
|
// The write offset.
|
|
int smem_write_offset_;
|
|
// The buffer base offset for read.
|
|
int smem_read_buffer_;
|
|
// The buffer base offset for write.
|
|
int smem_write_buffer_;
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// Use TMA
|
|
template <
|
|
// The description of the tile computed by this CTA.
|
|
typename Cta_tile,
|
|
// The number of rows in the 2D shared memory buffer.
|
|
int M_,
|
|
// The number of cols.
|
|
int N_,
|
|
// The size in bits of each element.
|
|
int BITS_PER_ELEMENT_,
|
|
// The number of bytes per STS. Not relevant for TMA
|
|
int BYTES_PER_STS_,
|
|
// The number of buffers. (Used in multistage and double buffer cases.)
|
|
int BUFFERS_PER_TILE_,
|
|
// Do we enable the fast path for LDS.128 and friends.
|
|
int ENABLE_LDS_FAST_PATH_,
|
|
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
|
|
int ROWS_PER_XOR_PATTERN_,
|
|
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
|
|
int COLS_PER_XOR_PATTERN_,
|
|
// Use or not predicates
|
|
bool USE_PREDICATES_,
|
|
// The leading dim elements in shared memory
|
|
int LEAD_DIM_ELEMENTS_>
|
|
struct Smem_tile_without_skews<Cta_tile, M_, N_, BITS_PER_ELEMENT_, BYTES_PER_STS_, BUFFERS_PER_TILE_,
|
|
ENABLE_LDS_FAST_PATH_, ROWS_PER_XOR_PATTERN_, COLS_PER_XOR_PATTERN_, USE_PREDICATES_, true, LEAD_DIM_ELEMENTS_>
|
|
: public Smem_tile_without_skews<Cta_tile, M_, N_, BITS_PER_ELEMENT_, BYTES_PER_STS_, BUFFERS_PER_TILE_,
|
|
ENABLE_LDS_FAST_PATH_, ROWS_PER_XOR_PATTERN_, COLS_PER_XOR_PATTERN_, USE_PREDICATES_, false,
|
|
LEAD_DIM_ELEMENTS_>
|
|
{
|
|
// Base struct
|
|
using Base = Smem_tile_without_skews<Cta_tile, M_, N_, BITS_PER_ELEMENT_, BYTES_PER_STS_, BUFFERS_PER_TILE_,
|
|
ENABLE_LDS_FAST_PATH_, ROWS_PER_XOR_PATTERN_, COLS_PER_XOR_PATTERN_, USE_PREDICATES_, false,
|
|
LEAD_DIM_ELEMENTS_>;
|
|
static constexpr bool USE_TMA = true;
|
|
|
|
// Tile size overrides. STS per thread not relevant for TMA
|
|
static constexpr int BYTES_PER_BUFFER = M_ * N_ * Base::BITS_PER_ELEMENT / 8;
|
|
static constexpr int BYTES_PER_TILE = BYTES_PER_BUFFER * Base::BUFFERS_PER_TILE;
|
|
static constexpr int BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER;
|
|
// The number of bytes per barrier
|
|
static constexpr int BYTES_PER_BARRIER = 8;
|
|
|
|
// Ctor
|
|
inline __device__ Smem_tile_without_skews(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
this->smem_write_offset_ = __nvvm_get_smem_pointer(smem);
|
|
this->smem_barrier_offset_ = 0;
|
|
this->elect_one_ = elect_one_sync();
|
|
}
|
|
|
|
inline __device__ void add_smem_barrier_base(uint64_t* smem_barrier)
|
|
{
|
|
this->smem_barrier_ = smem_barrier;
|
|
this->smem_barrier_offset_ = __nvvm_get_smem_pointer(this->smem_barrier_);
|
|
}
|
|
|
|
/**
|
|
* \brief load tensor blocks from global memory and stores to shared memory using tma instructions
|
|
*
|
|
* \param p_desc pointer to tma descriptor masked as const void* pointer
|
|
* \param smem_offset shared memory offset in bytes relative to smem_write_buffer_
|
|
* \param coord0 tensor access coordinate in dimension 1, used by tma load
|
|
* \param coord1 tensor access coordinate in dimension 2, used by tma load
|
|
* \param coord2 tensor access coordinate in dimension 3, used by tma load
|
|
* \param coord3 tensor access coordinate in dimension 4, used by tma load
|
|
* \param coord4 tensor access coordinate in dimension 5, used by tma load
|
|
* \param filter_offsets encodes multicast cta id and filter offsets
|
|
*/
|
|
template <uint32_t DIM, cudaTmaDescType DESC_TYPE, unsigned COPY_BYTES, bool USE_TMA_MULTICAST = false>
|
|
inline __device__ void store(void const* p_desc, unsigned const& smem_offset, int32_t coord0, int32_t coord1,
|
|
int32_t coord2, int32_t coord3, int32_t coord4, uint16_t filter_offsets, uint16_t mcast_cta_mask,
|
|
uint64_t mem_desc)
|
|
{
|
|
uint32_t smem = this->smem_write_offset_ + smem_offset;
|
|
fmha::utmaldg<DIM, DESC_TYPE, USE_TMA_MULTICAST>(reinterpret_cast<cudaTmaDesc const*>(p_desc), smem,
|
|
unsigned(this->smem_barrier_offset_), coord0, coord1, coord2, coord3, coord4, filter_offsets,
|
|
mcast_cta_mask, mem_desc, this->elect_one_);
|
|
}
|
|
|
|
// Same function as above but for runtime cga dimension
|
|
template <uint32_t DIM, cudaTmaDescType DESC_TYPE, unsigned COPY_BYTES>
|
|
inline __device__ void store(void const* p_desc, unsigned const& smem_offset, int32_t coord0, int32_t coord1,
|
|
int32_t coord2, int32_t coord3, int32_t coord4, uint16_t filter_offsets, uint16_t mcast_cta_mask,
|
|
uint64_t mem_desc, bool mcast_enabled)
|
|
{
|
|
uint32_t smem = this->smem_write_offset_ + smem_offset;
|
|
fmha::utmaldg<DIM, DESC_TYPE>(reinterpret_cast<cudaTmaDesc const*>(p_desc), smem,
|
|
unsigned(this->smem_barrier_offset_), coord0, coord1, coord2, coord3, coord4, filter_offsets,
|
|
mcast_cta_mask, mcast_enabled, mem_desc, this->elect_one_);
|
|
}
|
|
|
|
// Move the write offset to next buffer.
|
|
inline __device__ void move_next_write_buffer()
|
|
{
|
|
if (Base::BUFFERS_PER_TILE > 1)
|
|
{
|
|
this->smem_write_offset_ += (this->smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY)
|
|
? -BYTES_PER_TILE_INC_BOUNDARY
|
|
: BYTES_PER_BUFFER;
|
|
this->smem_barrier_offset_ += (this->smem_barrier_offset_ >= Base::BUFFERS_PER_TILE * BYTES_PER_BARRIER)
|
|
? -Base::BUFFERS_PER_TILE * BYTES_PER_BARRIER
|
|
: BYTES_PER_BARRIER;
|
|
}
|
|
}
|
|
|
|
inline __device__ void move_next_write_buffer(int buffer_id)
|
|
{
|
|
if (Base::BUFFERS_PER_TILE > 1)
|
|
{
|
|
this->smem_write_offset_ = this->smem_ + buffer_id * BYTES_PER_BUFFER;
|
|
}
|
|
this->smem_barrier_offset_ = __nvvm_get_smem_pointer(this->smem_barrier_ + buffer_id);
|
|
}
|
|
|
|
// Move the read offset to next buffer.
|
|
// do nothing, as it is controlled by gmma desc
|
|
inline __device__ void move_next_read_buffer() {}
|
|
|
|
uint64_t* smem_barrier_;
|
|
uint32_t smem_barrier_offset_;
|
|
// elect one thread to issue utmaldg
|
|
uint32_t elect_one_;
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Primary template of the shared-memory tile for the A operand. It is intentionally empty: usable
// tiles are provided by specializations on the traits / layout arguments.
template <
    // Instruction traits.
    typename Traits,
    // Dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // Layout of the tile.
    typename Layout,
    // Size of one STS in bytes.
    int BYTES_PER_STS = 16,
    // Number of buffers per tile.
    int BUFFERS_PER_TILE = 1,
    // Whether predicates are used.
    bool USE_PREDICATES = true>
struct Smem_tile_a
{
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Chooses the number of rows of the XOR pattern for a Volta A tile from the width in bits of N
// elements of A: 1 row up to 256 bits, 2 rows up to 512 bits, 4 rows beyond that.
template <typename Traits, int N>
struct Rows_per_xor_pattern_volta_a
{
    // Width of N elements of A, in bits.
    enum
    {
        N_IN_BITS = N * Traits::BITS_PER_ELEMENT_A
    };

    // The resulting number of rows.
    enum
    {
        VALUE = N_IN_BITS > 512 ? 4 : (N_IN_BITS > 256 ? 2 : 1)
    };
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Compile-time computation of the mask that the tiles' reset_read_offset() XORs (scaled by the LDS
// size) into the read offset. The general case peels off half of the padded count at each step and
// recurses on the remainder; the two partial specializations below terminate the recursion.
template <int MMAS_K, int MMAS_K_WITH_PADDING>
struct Compute_reset_mask
{
    // Half of the padded count: the bit this level may contribute.
    enum
    {
        HALF = MMAS_K_WITH_PADDING / 2
    };

    // What is left of MMAS_K once whole halves are removed.
    enum
    {
        MOD = MMAS_K % HALF
    };

    // HALF is contributed unless MMAS_K was already smaller than HALF (i.e. MOD == MMAS_K); the
    // lower bits come from the recursion on (MOD, HALF).
    enum
    {
        VALUE = (MOD != MMAS_K ? HALF : 0) | Compute_reset_mask<MOD, HALF>::VALUE
    };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Nothing left to account for: the mask is empty.
template <int PADDED_MMAS_K>
struct Compute_reset_mask<0, PADDED_MMAS_K>
{
    enum
    {
        VALUE = 0
    };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// No padding (both counts are equal): the mask has all the bits below the count set.
template <int COUNT>
struct Compute_reset_mask<COUNT, COUNT>
{
    enum
    {
        VALUE = COUNT - 1
    };
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The instruction traits.
|
|
typename Traits,
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE,
|
|
// How many rows to use for the XOR pattern to avoid bank conflicts?
|
|
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_volta_a<Traits, Cta_tile::K>::VALUE>
|
|
struct Smem_tile_volta_row_a : public Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K, 16, BYTES_PER_STS,
|
|
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>
|
|
{
|
|
|
|
// The MMA tile.
|
|
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
|
|
// The base class.
|
|
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K, 16, BYTES_PER_STS, BUFFERS_PER_TILE, 0,
|
|
ROWS_PER_XOR_PATTERN_, 1>;
|
|
// The fragment.
|
|
using Fragment = Fragment_a<Traits, Row>;
|
|
|
|
// When we use padding to reach a power of two, special care has to be taken.
|
|
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Traits, Cta_tile>;
|
|
// The number of MMAs.
|
|
using Mma_tile_with_padding = typename Traits::template Mma_tile<Cta_tile_with_padding>;
|
|
|
|
// The size of a single LDS in bytes.
|
|
enum
|
|
{
|
|
BYTES_PER_LDS = 16
|
|
};
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_volta_row_a(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
|
|
// For documentation on the layout, see doc/xmma_smem_layout.xlsx.
|
|
|
|
// The number of warps.
|
|
int const WARPS_M = Cta_tile::WARPS_M;
|
|
int const WARPS_N = Cta_tile::WARPS_N;
|
|
int const WARPS_K = Cta_tile::WARPS_K;
|
|
|
|
// The masks to select the warps.
|
|
int const WARP_MASK_M = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::M;
|
|
int const WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
|
|
|
|
// The divisor for the warps.
|
|
int const WARP_DIV_M = 1 * 1 * Cta_tile::THREADS_PER_WARP;
|
|
int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
|
|
|
|
// The row and column read by the thread.
|
|
int smem_read_row, smem_read_col;
|
|
if (Base::N_WITH_PADDING >= 64)
|
|
{
|
|
smem_read_row
|
|
= (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 1 + (tidx & 0x10) / 2 + (tidx & 0x07);
|
|
smem_read_col = (tidx & 0x03);
|
|
}
|
|
else if (Base::N_WITH_PADDING == 32)
|
|
{
|
|
smem_read_row
|
|
= (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 2 + (tidx & 0x10) / 4 + (tidx & 0x06) / 2;
|
|
smem_read_col = (tidx & 0x02) / 2 + (tidx & 0x01) * 4;
|
|
}
|
|
else
|
|
{
|
|
assert(false);
|
|
}
|
|
|
|
// For WARPS_K > 1, we do not support Base::N_WITH_PADDING < 64 for the moment.
|
|
static_assert(WARPS_K <= 2 && (WARPS_K == 1 || Base::N_WITH_PADDING >= 64), "");
|
|
|
|
// We "swap" the block for the second warp working on the in-CTA split-K.
|
|
if (WARPS_K == 2)
|
|
{
|
|
smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K;
|
|
}
|
|
|
|
// The shared memory offset.
|
|
this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
|
|
}
|
|
|
|
// Rewind smem_read_offset for last LDS phase in main loop.-
|
|
inline __device__ void reverse_smem_read_offset(int ki = 0)
|
|
{
|
|
// Move the offset to the next position. See doc/xmma_smem_layout.xlsx.
|
|
this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS;
|
|
}
|
|
|
|
// Load from shared memory.
|
|
inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki)
|
|
{
|
|
#pragma unroll
|
|
for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi)
|
|
{
|
|
// Jump over as many rows as needed.
|
|
int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
|
|
|
|
// TODO: Could we fuse smem_read_buffer and smem_read_offset?
|
|
uint4 tmp;
|
|
lds(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
|
|
a[mi].reg(0) = tmp.x;
|
|
a[mi].reg(1) = tmp.y;
|
|
a[mi].reg(2) = tmp.z;
|
|
a[mi].reg(3) = tmp.w;
|
|
}
|
|
|
|
// Move the offset to the next position. See doc/xmma_smem_layout.xlsx.
|
|
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
|
|
if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15)
|
|
{
|
|
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7)
|
|
{
|
|
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3)
|
|
{
|
|
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1)
|
|
{
|
|
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 2)
|
|
{
|
|
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS;
|
|
}
|
|
}
|
|
|
|
// Reset the read offset.
|
|
inline __device__ void reset_read_offset()
|
|
{
|
|
// The number of MMAs in the K dimension.
|
|
enum
|
|
{
|
|
MMAS_K = Mma_tile::MMAS_K
|
|
};
|
|
|
|
// The number of MMAs in the K dimension when we include padding.
|
|
enum
|
|
{
|
|
MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K
|
|
};
|
|
|
|
// Assemble the mask.
|
|
enum
|
|
{
|
|
MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE
|
|
};
|
|
|
|
// Reset the read offset.
|
|
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS;
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_a<Volta_hmma_fp16_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_volta_row_a<Volta_hmma_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = fmha::Volta_hmma_fp16_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_volta_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_a(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Chooses the number of rows of the XOR pattern for a Turing A tile from the width in bits of N
// elements of A: 1 row up to 128 bits, 2 up to 256 bits, 4 up to 512 bits, 8 beyond that.
template <typename Traits, int N>
struct Rows_per_xor_pattern_turing_a
{
    // Width of N elements of A, in bits.
    enum
    {
        N_IN_BITS = N * Traits::BITS_PER_ELEMENT_A
    };

    // The resulting number of rows.
    enum
    {
        VALUE = N_IN_BITS > 512 ? 8 : (N_IN_BITS > 256 ? 4 : (N_IN_BITS > 128 ? 2 : 1))
    };
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The instruction traits.
|
|
typename Traits,
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE,
|
|
// How many rows to use for the XOR pattern to avoid bank conflicts?
|
|
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_turing_a<Traits, Cta_tile::K>::VALUE>
|
|
struct Smem_tile_turing_row_a
|
|
: public Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K, Traits::BITS_PER_ELEMENT_A, BYTES_PER_STS,
|
|
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>
|
|
{
|
|
|
|
// The MMA tile.
|
|
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
|
|
// The base class.
|
|
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K, Traits::BITS_PER_ELEMENT_A, BYTES_PER_STS,
|
|
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>;
|
|
// The fragment.
|
|
using Fragment = Fragment_a<Traits, Row>;
|
|
|
|
// When we use padding to reach a power of two, special care has to be taken.
|
|
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Traits, Cta_tile>;
|
|
// The number of MMAs.
|
|
using Mma_tile_with_padding = typename Traits::template Mma_tile<Cta_tile_with_padding>;
|
|
|
|
// The size of a single LDS in bytes.
|
|
enum
|
|
{
|
|
BYTES_PER_LDS = 16
|
|
};
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_turing_row_a(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
|
|
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
|
|
|
|
// The number of warps.
|
|
int const WARPS_M = Cta_tile::WARPS_M;
|
|
int const WARPS_N = Cta_tile::WARPS_N;
|
|
int const WARPS_K = Cta_tile::WARPS_K;
|
|
|
|
// The masks to select the warps.
|
|
int const WARP_MASK_M = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::M;
|
|
int const WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
|
|
|
|
// The divisor for the warps.
|
|
int const WARP_DIV_M = 1 * 1 * Cta_tile::THREADS_PER_WARP;
|
|
int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
|
|
|
|
// The row and column read by the thread.
|
|
int smem_read_row, smem_read_col;
|
|
|
|
static_assert(Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4
|
|
|| Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 1,
|
|
"");
|
|
|
|
if (Base::ROWS_PER_XOR_PATTERN == 8)
|
|
{
|
|
smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 1 + (tidx & 0x0f);
|
|
smem_read_col = (tidx & 0x07);
|
|
}
|
|
else if (Base::ROWS_PER_XOR_PATTERN == 4)
|
|
{
|
|
smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 2 + (tidx & 0x0e) / 2;
|
|
smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4;
|
|
}
|
|
else if (Base::ROWS_PER_XOR_PATTERN == 2)
|
|
{
|
|
smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 4 + (tidx & 0x0c) / 4;
|
|
smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2;
|
|
}
|
|
else if (Base::ROWS_PER_XOR_PATTERN == 1)
|
|
{
|
|
smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 8 + (tidx & 0x1f) / 8;
|
|
smem_read_col = (tidx & 0x07);
|
|
}
|
|
|
|
static_assert(WARPS_K <= 2, "");
|
|
|
|
// We "swap" the block for the second warp working on the in-CTA split-K.
|
|
if (WARPS_K == 2)
|
|
{
|
|
smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K;
|
|
}
|
|
|
|
// The shared memory offset.
|
|
this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
|
|
}
|
|
|
|
// Rewind smem_read_offset for last LDS phase in main loop.-
|
|
inline __device__ void reverse_smem_read_offset(int ki = 0)
|
|
{
|
|
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
|
|
this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS;
|
|
}
|
|
|
|
// Load from shared memory.
|
|
inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki)
{
    // Fill one fragment per MMA along M; each ldsm call yields a uint2 that maps to registers 0 and 1.
#pragma unroll
    for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi)
    {
        // Displacement of the mi-th MMA: M_PER_MMA_PER_CTA rows of BYTES_PER_ROW_BEFORE_PACKING each.
        int const mma_offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;

        uint2 regs;
        ldsm(regs, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + mma_offset);
        a[mi].reg(0) = regs.x;
        a[mi].reg(1) = regs.y;
    }

    // Advance the read offset for the next ki. See doc/mma_smem_layout.xlsx.
    static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");

    // Number of LDS widths to XOR into the offset; 0 leaves the offset untouched (MMAS_K == 1).
    int xor_steps = 0;
    if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15)
    {
        xor_steps = 31;
    }
    else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7)
    {
        xor_steps = 15;
    }
    else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3)
    {
        xor_steps = 7;
    }
    else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1)
    {
        xor_steps = 3;
    }
    else if (Mma_tile_with_padding::MMAS_K >= 2)
    {
        xor_steps = 1;
    }
    this->smem_read_offset_ ^= xor_steps * BYTES_PER_LDS;
}
|
|
|
|
// Reset the read offset.
|
|
inline __device__ void reset_read_offset()
{
    // Compile-time mask derived from the number of MMAs along K without and with padding.
    enum
    {
        RESET_MASK = Compute_reset_mask<Mma_tile::MMAS_K, Mma_tile_with_padding::MMAS_K>::VALUE
    };

    // XOR the mask (scaled to LDS widths) into the read offset.
    this->smem_read_offset_ ^= RESET_MASK * BYTES_PER_LDS;
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_a<Turing_hmma_fp16_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_turing_row_a<Turing_hmma_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Turing_hmma_fp16_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_turing_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_a(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_a<Turing_hmma_fp32_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_turing_row_a<Turing_hmma_fp32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Turing_hmma_fp32_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_turing_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_a(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_a<Turing_imma_int8_int32_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_turing_row_a<Turing_imma_int8_int32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Turing_imma_int8_int32_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_turing_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_a(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Compile-time choice of the XOR-pattern row count for A, based on the width of N elements.
template <typename Traits, int N>
struct Rows_per_xor_pattern_ampere_a
{
    enum
    {
        // Width of N elements of A, in bits.
        N_IN_BITS = N * Traits::BITS_PER_ELEMENT_A,
        // 8 rows above 512 bits, 4 rows above 256 bits, otherwise 2 rows.
        VALUE = N_IN_BITS > 512 ? 8 : (N_IN_BITS > 256 ? 4 : 2)
    };
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename Traits, int N>
|
|
struct Rows_per_xor_pattern_ampere_row_a : public Rows_per_xor_pattern_ampere_a<Traits, N>
|
|
{
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The instruction traits.
|
|
typename Traits,
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE,
|
|
// How many rows to use for the XOR pattern to avoid bank conflicts?
|
|
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_ampere_row_a<Traits, Cta_tile::K>::VALUE>
|
|
struct Smem_tile_ampere_row_a
|
|
: public Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K, Traits::BITS_PER_ELEMENT_A, BYTES_PER_STS,
|
|
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>
|
|
{
|
|
// The MMA tile.
|
|
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
|
|
// The base class.
|
|
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K, Traits::BITS_PER_ELEMENT_A, BYTES_PER_STS,
|
|
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>;
|
|
// The fragment.
|
|
using Fragment = Fragment_a<Traits, Row>;
|
|
|
|
// When we use padding to reach a power of two, special care has to be taken.
|
|
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Traits, Cta_tile>;
|
|
// The number of MMAs.
|
|
using Mma_tile_with_padding = typename Traits::template Mma_tile<Cta_tile_with_padding>;
|
|
|
|
// The size of a single LDS in bytes.
|
|
enum
|
|
{
|
|
BYTES_PER_LDS = 16
|
|
};
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_ampere_row_a(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
|
|
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
|
|
|
|
// The number of warps.
|
|
int const WARPS_M = Cta_tile::WARPS_M;
|
|
int const WARPS_N = Cta_tile::WARPS_N;
|
|
int const WARPS_K = Cta_tile::WARPS_K;
|
|
|
|
// The masks to select the warps.
|
|
int const WARP_MASK_M = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::M;
|
|
int const WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
|
|
|
|
// The divisor for the warps.
|
|
int const WARP_DIV_M = 1 * 1 * Cta_tile::THREADS_PER_WARP;
|
|
int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
|
|
|
|
// The row and column read by the thread.
|
|
int smem_read_row, smem_read_col;
|
|
|
|
static_assert(
|
|
Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 2, "");
|
|
|
|
if (Base::ROWS_PER_XOR_PATTERN == 8)
|
|
{
|
|
smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 1 + (tidx & 0x0f);
|
|
smem_read_col = (tidx & 0x07);
|
|
smem_read_col ^= (tidx & 0x10) / 16;
|
|
}
|
|
else if (Base::ROWS_PER_XOR_PATTERN == 4)
|
|
{
|
|
smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 2 + (tidx & 0x0e) / 2;
|
|
smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4;
|
|
smem_read_col ^= (tidx & 0x10) / 16;
|
|
}
|
|
else if (Base::ROWS_PER_XOR_PATTERN == 2)
|
|
{
|
|
smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 4 + (tidx & 0x0c) / 4;
|
|
smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2;
|
|
smem_read_col ^= (tidx & 0x10) / 16;
|
|
}
|
|
|
|
static_assert(WARPS_K <= 2, "");
|
|
static_assert(WARPS_K != 2 || Base::ROWS_PER_XOR_PATTERN != 2, "");
|
|
|
|
// We "swap" the block for the second warp working on the same outputs in-CTA split-K.
|
|
if (WARPS_K == 2)
|
|
{
|
|
smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K * 2;
|
|
}
|
|
|
|
// The shared memory offset.
|
|
this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
|
|
}
|
|
|
|
// Rewind smem_read_offset for last LDS phase in main loop.
|
|
inline __device__ void reverse_smem_read_offset(int ki = 0)
|
|
{
|
|
// Undo the pointer increment for the next ni.
|
|
// Should match the load function below for ki = 0.
|
|
if (Mma_tile_with_padding::MMAS_K >= 2)
|
|
{
|
|
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
|
|
}
|
|
}
|
|
|
|
// Load from shared memory.
|
|
inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki)
|
|
{
|
|
if (ki < Mma_tile::VALID_MMAS_K)
|
|
{
|
|
#pragma unroll
|
|
for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi)
|
|
{
|
|
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
|
|
int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
|
|
|
|
// Load using LDSM.M88.4.
|
|
uint4 tmp;
|
|
ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
|
|
|
|
// Store the value into the fragment.
|
|
a[mi].reg(0) = tmp.x;
|
|
a[mi].reg(1) = tmp.y;
|
|
a[mi].reg(2) = tmp.z;
|
|
a[mi].reg(3) = tmp.w;
|
|
}
|
|
}
|
|
|
|
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
|
|
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
|
|
if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15)
|
|
{
|
|
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7)
|
|
{
|
|
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3)
|
|
{
|
|
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1)
|
|
{
|
|
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 2)
|
|
{
|
|
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
|
|
}
|
|
}
|
|
|
|
// Reset the read offset.
|
|
inline __device__ void reset_read_offset()
|
|
{
|
|
// The number of MMAs in the K dimension.
|
|
enum
|
|
{
|
|
MMAS_K = Mma_tile::MMAS_K
|
|
};
|
|
|
|
// The number of MMAs in the K dimension when we include padding.
|
|
enum
|
|
{
|
|
MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K
|
|
};
|
|
|
|
// Assemble the mask.
|
|
enum
|
|
{
|
|
MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE
|
|
};
|
|
|
|
// Reset the read offset.
|
|
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_a<Ampere_hmma_fp16_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_ampere_row_a<Ampere_hmma_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
// The traits class.
|
|
using Traits = Ampere_hmma_fp16_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_ampere_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_a(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_a<Ampere_hmma_fp32_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_ampere_row_a<Ampere_hmma_fp32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
// The traits class.
|
|
using Traits = Ampere_hmma_fp32_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_ampere_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_a(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_a<Ampere_hmma_bf16_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_ampere_row_a<Ampere_hmma_bf16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
// The traits class.
|
|
using Traits = Ampere_hmma_bf16_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_ampere_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_a(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_a<Ampere_imma_int8_int32_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_ampere_row_a<Ampere_imma_int8_int32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Ampere_imma_int8_int32_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_ampere_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_a(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_a<Ada_qmma_e4m3_fp32_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_ampere_row_a<Ada_qmma_e4m3_fp32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Ada_qmma_e4m3_fp32_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_ampere_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_a(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_a<Ada_qmma_e4m3_fp16_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_ampere_row_a<Ada_qmma_e4m3_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Ada_qmma_e4m3_fp16_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_ampere_row_a<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_a(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Primary template of the shared-memory tile for B. It is intentionally empty: the usable
// implementations are the partial specializations selected by Traits and Layout.
//
// Template parameters, in order:
//   Traits            the instruction traits
//   Cta_tile          the dimensions of the tile computed by the CTA
//   Layout            the layout of the tile
//   BYTES_PER_STS     the size of the STS (default 16)
//   BUFFERS_PER_TILE  the number of buffers per tile (default 1)
//   USE_PREDICATES    use or not predicates (default true)
template <typename Traits, typename Cta_tile, typename Layout, int BYTES_PER_STS = 16, int BUFFERS_PER_TILE = 1,
    bool USE_PREDICATES = true>
struct Smem_tile_b
{
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Compile-time choice of the XOR-pattern row count for B on Volta, based on the width of N elements.
template <typename Traits, int N>
struct Rows_per_xor_pattern_volta_b
{
    enum
    {
        // Width of N elements of B, in bits.
        N_IN_BITS = N * Traits::BITS_PER_ELEMENT_B,
        // 4 rows above 512 bits, 2 rows above 256 bits, otherwise 1 row.
        VALUE = N_IN_BITS > 512 ? 4 : (N_IN_BITS > 256 ? 2 : 1)
    };
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The instruction traits.
|
|
typename Traits,
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE,
|
|
// How many rows to use for the XOR pattern to avoid bank conflicts?
|
|
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_volta_b<Traits, Cta_tile::K>::VALUE>
|
|
struct Smem_tile_volta_col_b : public Smem_tile_without_skews<Cta_tile, Cta_tile::N, Cta_tile::K, 16, BYTES_PER_STS,
|
|
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>
|
|
{
|
|
|
|
// The MMA tile.
|
|
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
|
|
// The base class.
|
|
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::N, Cta_tile::K, 16, BYTES_PER_STS, BUFFERS_PER_TILE, 0,
|
|
ROWS_PER_XOR_PATTERN_, 1>;
|
|
|
|
// When we use padding to reach a power of two, special care has to be taken.
|
|
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Traits, Cta_tile>;
|
|
// The number of MMAs.
|
|
using Mma_tile_with_padding = typename Traits::template Mma_tile<Cta_tile_with_padding>;
|
|
// The fragment.
|
|
using Fragment = Fragment_b<Traits, Col>;
|
|
|
|
// The size of a single LDS in bytes.
|
|
enum
|
|
{
|
|
BYTES_PER_LDS = 16
|
|
};
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_volta_col_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
|
|
// For documentation on the layout, see doc/xmma_smem_layout.xlsx.
|
|
|
|
// The number of warps.
|
|
int const WARPS_M = Cta_tile::WARPS_M;
|
|
int const WARPS_N = Cta_tile::WARPS_N;
|
|
int const WARPS_K = Cta_tile::WARPS_K;
|
|
|
|
// The masks to select the warps.
|
|
int const WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
|
|
int const WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
|
|
|
|
// The divisor for the warps.
|
|
int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
|
|
int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
|
|
|
|
// The row and column read by the thread.
|
|
int smem_read_row, smem_read_col;
|
|
|
|
if (Base::N_WITH_PADDING >= 64)
|
|
{
|
|
smem_read_row
|
|
= (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 1 + (tidx & 0x18) / 2 + (tidx & 0x03);
|
|
smem_read_col = (tidx & 0x03);
|
|
}
|
|
else if (Base::N_WITH_PADDING == 32)
|
|
{
|
|
smem_read_row
|
|
= (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 2 + (tidx & 0x18) / 4 + (tidx & 0x02) / 2;
|
|
smem_read_col = (tidx & 0x02) / 2 + (tidx & 0x01) * 4;
|
|
}
|
|
else
|
|
{
|
|
assert(false);
|
|
}
|
|
|
|
// For WARPS_K > 1, we do not support Base::N_WITH_PADDING < 64 for the moment.
|
|
static_assert(WARPS_K <= 2 && (WARPS_K == 1 || Base::N_WITH_PADDING >= 64), "");
|
|
|
|
// We "swap" the block for the second warp working on the in-CTA split-K.
|
|
if (WARPS_K == 2)
|
|
{
|
|
smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K;
|
|
}
|
|
|
|
// The shared memory offset.
|
|
this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
|
|
}
|
|
|
|
// Rewind smem_read_offset for last LDS phase in main loop.-
|
|
inline __device__ void reverse_smem_read_offset(int ki = 0)
|
|
{
|
|
// Move the offset to the next position. See doc/xmma_smem_layout.xlsx.
|
|
this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS;
|
|
}
|
|
|
|
// Load from shared memory.
|
|
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki)
|
|
{
|
|
#pragma unroll
|
|
for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni)
|
|
{
|
|
|
|
// Jump over as many rows as needed.
|
|
int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
|
|
|
|
// TODO: Can we fuse read_offset and read_buffer?
|
|
uint4 tmp;
|
|
lds(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
|
|
b[ni].reg(0) = tmp.x;
|
|
b[ni].reg(1) = tmp.y;
|
|
b[ni].reg(2) = tmp.z;
|
|
b[ni].reg(3) = tmp.w;
|
|
}
|
|
|
|
// Move the offset to the next position. See doc/xmma_smem_layout.xlsx.
|
|
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
|
|
if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15)
|
|
{
|
|
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7)
|
|
{
|
|
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3)
|
|
{
|
|
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1)
|
|
{
|
|
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 2)
|
|
{
|
|
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS;
|
|
}
|
|
}
|
|
|
|
// Reset the read offset.
|
|
inline __device__ void reset_read_offset()
|
|
{
|
|
// The number of MMAs in the K dimension.
|
|
enum
|
|
{
|
|
MMAS_K = Mma_tile::MMAS_K
|
|
};
|
|
|
|
// The number of MMAs in the K dimension when we include padding.
|
|
enum
|
|
{
|
|
MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K
|
|
};
|
|
|
|
// Assemble the mask.
|
|
enum
|
|
{
|
|
MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE
|
|
};
|
|
|
|
// Reset the read offset.
|
|
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS;
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_b<Volta_hmma_fp16_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_volta_col_b<Volta_hmma_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = fmha::Volta_hmma_fp16_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_volta_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Compile-time choice of the XOR-pattern row count for B on Turing, based on the width of N elements.
template <typename Traits, int N>
struct Rows_per_xor_pattern_turing_b
{
    enum
    {
        // Width of N elements of B, in bits.
        N_IN_BITS = N * Traits::BITS_PER_ELEMENT_B,
        // 8 rows above 512 bits, 4 rows above 256 bits, 2 rows above 128 bits, otherwise 1 row.
        VALUE = N_IN_BITS > 512 ? 8 : (N_IN_BITS > 256 ? 4 : (N_IN_BITS > 128 ? 2 : 1))
    };
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The instruction traits.
|
|
typename Traits,
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE,
|
|
// How many rows to use for the XOR pattern to avoid bank conflicts?
|
|
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_turing_b<Traits, Cta_tile::K>::VALUE>
|
|
struct Smem_tile_turing_col_b
|
|
: public Smem_tile_without_skews<Cta_tile, Cta_tile::N, Cta_tile::K, Traits::BITS_PER_ELEMENT_B, BYTES_PER_STS,
|
|
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>
|
|
{
|
|
|
|
// The MMA tile.
|
|
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
|
|
// The base class.
|
|
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::N, Cta_tile::K, Traits::BITS_PER_ELEMENT_B, BYTES_PER_STS,
|
|
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>;
|
|
// The fragment.
|
|
using Fragment = Fragment_b<Traits, Col>;
|
|
|
|
// When we use padding to reach a power of two, special care has to be taken.
|
|
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Traits, Cta_tile>;
|
|
// The number of MMAs.
|
|
using Mma_tile_with_padding = typename Traits::template Mma_tile<Cta_tile_with_padding>;
|
|
|
|
// The size of a single LDS in bytes.
|
|
enum
|
|
{
|
|
BYTES_PER_LDS = 16
|
|
};
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_turing_col_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
|
|
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
|
|
|
|
// The number of warps.
|
|
int const WARPS_M = Cta_tile::WARPS_M;
|
|
int const WARPS_N = Cta_tile::WARPS_N;
|
|
int const WARPS_K = Cta_tile::WARPS_K;
|
|
|
|
// The masks to select the warps.
|
|
int const WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
|
|
int const WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
|
|
|
|
// The divisor for the warps.
|
|
int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
|
|
int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
|
|
|
|
// The row and column read by the thread.
|
|
int smem_read_row, smem_read_col;
|
|
|
|
static_assert(Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4
|
|
|| Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 1,
|
|
"");
|
|
|
|
if (Base::ROWS_PER_XOR_PATTERN == 8)
|
|
{
|
|
// For group fprop. B is divided into 2 halves along N dimension.
|
|
// The fist warp takes the first half and the second warp takes the second half.
|
|
smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 1 + (tidx & 0x0f);
|
|
smem_read_col = (tidx & 0x07);
|
|
}
|
|
else if (Base::ROWS_PER_XOR_PATTERN == 4)
|
|
{
|
|
smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 2 + (tidx & 0x0e) / 2;
|
|
smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4;
|
|
}
|
|
else if (Base::ROWS_PER_XOR_PATTERN == 2)
|
|
{
|
|
smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 4 + (tidx & 0x0c) / 4;
|
|
smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2;
|
|
}
|
|
else if (Base::ROWS_PER_XOR_PATTERN == 1)
|
|
{
|
|
smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 8 + (tidx & 0x1f) / 8;
|
|
smem_read_col = (tidx & 0x07);
|
|
}
|
|
|
|
static_assert(WARPS_K <= 2, "");
|
|
|
|
// We "swap" the block for the second warp working on the in-CTA split-K.
|
|
if (WARPS_K == 2)
|
|
{
|
|
smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K;
|
|
}
|
|
|
|
// The shared memory offset.
|
|
this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
|
|
}
|
|
|
|
// Rewind smem_read_offset for last LDS phase in main loop.-
|
|
inline __device__ void reverse_smem_read_offset(int ki = 0)
|
|
{
|
|
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
|
|
this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS;
|
|
}
|
|
|
|
// Load from shared memory.
|
|
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki)
|
|
{
|
|
#pragma unroll
|
|
for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni)
|
|
{
|
|
int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
|
|
uint2 tmp;
|
|
ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
|
|
b[ni].reg(0) = tmp.x;
|
|
b[ni].reg(1) = tmp.y;
|
|
}
|
|
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
|
|
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
|
|
if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15)
|
|
{
|
|
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7)
|
|
{
|
|
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3)
|
|
{
|
|
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1)
|
|
{
|
|
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 2)
|
|
{
|
|
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS;
|
|
}
|
|
}
|
|
|
|
// Reset the read offset.
|
|
inline __device__ void reset_read_offset()
|
|
{
|
|
// The number of MMAs in the K dimension.
|
|
enum
|
|
{
|
|
MMAS_K = Mma_tile::MMAS_K
|
|
};
|
|
|
|
// The number of MMAs in the K dimension when we include padding.
|
|
enum
|
|
{
|
|
MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K
|
|
};
|
|
|
|
// Assemble the mask.
|
|
enum
|
|
{
|
|
MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE
|
|
};
|
|
|
|
// Reset the read offset.
|
|
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS;
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_b<Turing_hmma_fp16_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_turing_col_b<Turing_hmma_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Turing_hmma_fp16_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_turing_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_b<Turing_hmma_fp32_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_turing_col_b<Turing_hmma_fp32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Turing_hmma_fp32_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_turing_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_b<Turing_imma_int8_int32_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_turing_col_b<Turing_imma_int8_int32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Turing_imma_int8_int32_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_turing_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Compile-time selection of the number of rows used by the XOR pattern for the B operand.
//
// Traits must expose BITS_PER_ELEMENT_B; N is an element count. The result depends only on
// the product N * BITS_PER_ELEMENT_B:
//   up to 256 bits -> 2 rows, up to 512 bits -> 4 rows, anything larger -> 8 rows.
template <typename Traits, int N>
struct Rows_per_xor_pattern_ampere_b
{
    // Width of N elements, expressed in bits.
    enum
    {
        N_IN_BITS = N * Traits::BITS_PER_ELEMENT_B
    };

    // The selected number of rows (checked from the widest bucket downwards).
    enum
    {
        VALUE = N_IN_BITS > 512 ? 8 : (N_IN_BITS > 256 ? 4 : 2)
    };
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename Traits, int N>
|
|
struct Rows_per_xor_pattern_ampere_col_b : public Rows_per_xor_pattern_ampere_b<Traits, N>
|
|
{
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The instruction traits.
|
|
typename Traits,
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE,
|
|
// How many rows to use for the XOR pattern to avoid bank conflicts?
|
|
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_ampere_col_b<Traits, Cta_tile::K>::VALUE>
|
|
struct Smem_tile_ampere_col_b
|
|
: public Smem_tile_without_skews<Cta_tile, Cta_tile::N, Cta_tile::K, Traits::BITS_PER_ELEMENT_B, BYTES_PER_STS,
|
|
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>
|
|
{
|
|
// The MMA tile.
|
|
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
|
|
// The base class.
|
|
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::N, Cta_tile::K, Traits::BITS_PER_ELEMENT_B, BYTES_PER_STS,
|
|
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>;
|
|
// The fragment.
|
|
using Fragment = Fragment_b<Traits, Col>;
|
|
|
|
// When we use padding to reach a power of two, special care has to be taken.
|
|
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Traits, Cta_tile>;
|
|
// The number of MMAs.
|
|
using Mma_tile_with_padding = typename Traits::template Mma_tile<Cta_tile_with_padding>;
|
|
|
|
// The size of a single LDS in bytes.
|
|
enum
|
|
{
|
|
BYTES_PER_LDS = 16
|
|
};
|
|
|
|
// The number of STS per thread
|
|
enum
|
|
{
|
|
STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA
|
|
};
|
|
|
|
// The number of STS per thread must be at least 1.
|
|
enum
|
|
{
|
|
STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE
|
|
};
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_ampere_col_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
|
|
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
|
|
|
|
// The number of warps.
|
|
int const WARPS_M = Cta_tile::WARPS_M;
|
|
int const WARPS_N = Cta_tile::WARPS_N;
|
|
int const WARPS_K = Cta_tile::WARPS_K;
|
|
|
|
// The masks to select the warps.
|
|
int const WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
|
|
int const WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
|
|
|
|
// The divisor for the warps.
|
|
int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
|
|
int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
|
|
|
|
// The row and column read by the thread.
|
|
int smem_read_row, smem_read_col;
|
|
|
|
static_assert(
|
|
Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 2, "");
|
|
|
|
if (Base::ROWS_PER_XOR_PATTERN == 8)
|
|
{
|
|
// For group fprop. B is divided into 2 halves along N dimension.
|
|
// The fist warp takes the first half and the second warp takes the second half.
|
|
smem_read_row
|
|
= (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 1 + (tidx & 0x07) + (tidx & 0x10) / 2;
|
|
smem_read_col = (tidx & 0x07);
|
|
smem_read_col ^= (tidx & 0x08) / 8;
|
|
}
|
|
else if (Base::ROWS_PER_XOR_PATTERN == 4)
|
|
{
|
|
smem_read_row
|
|
= (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 2 + (tidx & 0x06) / 2 + (tidx & 0x10) / 4;
|
|
smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4;
|
|
smem_read_col ^= (tidx & 0x08) / 8;
|
|
}
|
|
else if (Base::ROWS_PER_XOR_PATTERN == 2)
|
|
{
|
|
smem_read_row
|
|
= (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 4 + (tidx & 0x04) / 4 + (tidx & 0x10) / 8;
|
|
smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2;
|
|
smem_read_col ^= (tidx & 0x08) / 8;
|
|
}
|
|
|
|
static_assert(WARPS_K <= 2, "");
|
|
static_assert(WARPS_K != 2 || Base::ROWS_PER_XOR_PATTERN != 2, "");
|
|
|
|
// We "swap" the block for the second warp working on the in-CTA split-K.
|
|
if (WARPS_K == 2)
|
|
{
|
|
smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K * 2;
|
|
}
|
|
|
|
// The shared memory offset.
|
|
this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
|
|
}
|
|
|
|
// Rewind smem_read_offset for last LDS phase in main loop.
|
|
inline __device__ void reverse_smem_read_offset(int ki = 0)
|
|
{
|
|
// Undo the pointer increment for the next ni.
|
|
// Should match the load function below for ki = 0.
|
|
if (Mma_tile_with_padding::MMAS_K >= 2)
|
|
{
|
|
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
|
|
}
|
|
}
|
|
|
|
// Load from shared memory.
|
|
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki)
|
|
{
|
|
if (ki < Mma_tile::VALID_MMAS_K)
|
|
{
|
|
#pragma unroll
|
|
for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni)
|
|
{
|
|
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
|
|
int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
|
|
|
|
// Load using LDSM.M88.4.
|
|
uint4 tmp;
|
|
ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
|
|
|
|
// Store the value into the fragment.
|
|
b[ni].reg(0) = tmp.x;
|
|
b[ni].reg(1) = tmp.y;
|
|
b[ni].reg(2) = tmp.z;
|
|
b[ni].reg(3) = tmp.w;
|
|
}
|
|
}
|
|
|
|
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
|
|
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
|
|
if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15)
|
|
{
|
|
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7)
|
|
{
|
|
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3)
|
|
{
|
|
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1)
|
|
{
|
|
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
|
|
}
|
|
else if (Mma_tile_with_padding::MMAS_K >= 2)
|
|
{
|
|
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
|
|
}
|
|
}
|
|
|
|
// Reset the read offset.
|
|
inline __device__ void reset_read_offset()
|
|
{
|
|
// The number of MMAs in the K dimension.
|
|
enum
|
|
{
|
|
MMAS_K = Mma_tile::MMAS_K
|
|
};
|
|
|
|
// The number of MMAs in the K dimension when we include padding.
|
|
enum
|
|
{
|
|
MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K
|
|
};
|
|
|
|
// Assemble the mask.
|
|
enum
|
|
{
|
|
MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE
|
|
};
|
|
|
|
// Reset the read offset.
|
|
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_b<Ampere_hmma_fp16_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_ampere_col_b<Ampere_hmma_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Ampere_hmma_fp16_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_ampere_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_b<Ampere_hmma_fp32_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_ampere_col_b<Ampere_hmma_fp32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Ampere_hmma_fp32_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_ampere_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_b<Ampere_hmma_bf16_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_ampere_col_b<Ampere_hmma_bf16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Ampere_hmma_bf16_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_ampere_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_b<Ampere_imma_int8_int32_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_ampere_col_b<Ampere_imma_int8_int32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Ampere_imma_int8_int32_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_ampere_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_b<Ada_qmma_e4m3_fp32_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_ampere_col_b<Ada_qmma_e4m3_fp32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Ada_qmma_e4m3_fp32_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_ampere_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_b<Ada_qmma_e4m3_fp16_traits, Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_ampere_col_b<Ada_qmma_e4m3_fp16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Ada_qmma_e4m3_fp16_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_ampere_col_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename Traits, int N>
|
|
struct Rows_per_xor_pattern_ampere_row_b : public Rows_per_xor_pattern_ampere_b<Traits, N>
|
|
{
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The instruction traits.
|
|
typename Traits,
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE,
|
|
// How many rows to use for the XOR pattern to avoid bank conflicts?
|
|
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_ampere_row_b<Traits, Cta_tile::N>::VALUE,
|
|
// How many cols to use for the XOR pattern to avoid bank conflicts?
|
|
int COLS_PER_XOR_PATTERN_ = 1>
|
|
struct Smem_tile_ampere_row_b
|
|
: public Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, Traits::BITS_PER_ELEMENT_B, BYTES_PER_STS,
|
|
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, COLS_PER_XOR_PATTERN_>
|
|
{
|
|
|
|
// The MMA tile.
|
|
using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
|
|
// The base class.
|
|
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, Traits::BITS_PER_ELEMENT_B, BYTES_PER_STS,
|
|
BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, COLS_PER_XOR_PATTERN_>;
|
|
// The fragment.
|
|
using Fragment = Fragment_b<Traits, Row>;
|
|
|
|
// Can we use LDSM? No if the data type is 32-bit large.
|
|
enum
|
|
{
|
|
USE_LDSMT = Traits::BITS_PER_ELEMENT_B == 16
|
|
};
|
|
|
|
// The size of a single LDS in bytes.
|
|
enum
|
|
{
|
|
BYTES_PER_LDS = USE_LDSMT ? 16 : 4
|
|
};
|
|
|
|
// The number of elements per LDS.
|
|
enum
|
|
{
|
|
ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / Traits::BITS_PER_ELEMENT_B
|
|
};
|
|
|
|
// The number of STS per thread
|
|
enum
|
|
{
|
|
STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA
|
|
};
|
|
|
|
// The number of STS per thread must be at least 1.
|
|
enum
|
|
{
|
|
STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE
|
|
};
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_ampere_row_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
|
|
// For documentation on the layout, see doc/xmma_smem_layout.xlsx.
|
|
|
|
// The number of warps.
|
|
int const WARPS_M = Cta_tile::WARPS_M;
|
|
int const WARPS_N = Cta_tile::WARPS_N;
|
|
int const WARPS_K = Cta_tile::WARPS_K;
|
|
|
|
// The masks to select the warps.
|
|
int const WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
|
|
int const WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
|
|
|
|
// The divisor for the warps.
|
|
int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
|
|
int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
|
|
|
|
// The row/col read by the thread.
|
|
int smem_read_row, smem_read_col;
|
|
|
|
static_assert((USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 8) || Base::ROWS_PER_XOR_PATTERN == 4
|
|
|| Base::ROWS_PER_XOR_PATTERN == 2,
|
|
"");
|
|
|
|
if (USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 8)
|
|
{
|
|
// For group dgrad. B is divided into 2 halves along K dimension.
|
|
// The fist warp takes the first half and the second warp takes the second half.
|
|
smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 + (tidx & 0x07) + (tidx & 0x08);
|
|
smem_read_col = (tidx & 0x07);
|
|
}
|
|
else if (USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 4)
|
|
{
|
|
smem_read_row
|
|
= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 8 + (tidx & 0x06) / 2 + (tidx & 0x08) / 2;
|
|
smem_read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2;
|
|
}
|
|
else if (USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 2)
|
|
{
|
|
smem_read_row
|
|
= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 4 + (tidx & 0x04) / 4 + (tidx & 0x08) / 4;
|
|
smem_read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4;
|
|
}
|
|
else if (Base::ROWS_PER_XOR_PATTERN == 4 && Base::COLS_PER_XOR_PATTERN == 2)
|
|
{
|
|
smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 8 + (tidx & 0x03);
|
|
smem_read_col = (tidx & 0x1c) / 4 + (tidx & 0x03) * 8;
|
|
}
|
|
|
|
// Each half-warp applies a different XOR pattern -- see the Excel document.
|
|
if (USE_LDSMT)
|
|
{
|
|
smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16;
|
|
}
|
|
else
|
|
{
|
|
smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 16;
|
|
}
|
|
|
|
// The shared memory offset.
|
|
this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
|
|
|
|
// Fill zeroes for group conv
|
|
}
|
|
|
|
// Rewind smem_read_offset for last LDS phase in main loop.
|
|
inline __device__ void reverse_smem_read_offset(int ki = 0)
|
|
{
|
|
// The size of each element in bits.
|
|
int const BITS_PER_ELT = Traits::BITS_PER_ELEMENT_B;
|
|
// The size in bytes of the data needed to compute an MMA per CTA.
|
|
int const BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
|
|
|
|
#pragma unroll
|
|
for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni)
|
|
{
|
|
// Undo the pointer increment for the next ni.
|
|
// Should match the load function below for ki = 0.
|
|
if (BYTES_PER_MMA_PER_CTA >= 128)
|
|
{
|
|
// Nothing to do!
|
|
}
|
|
else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1)
|
|
{
|
|
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
|
|
}
|
|
else if (BYTES_PER_MMA_PER_CTA == 64)
|
|
{
|
|
// Nothing to do!
|
|
}
|
|
else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4)
|
|
{
|
|
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
|
|
}
|
|
else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2)
|
|
{
|
|
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
|
|
}
|
|
}
|
|
|
|
// Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
|
|
if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1)
|
|
{
|
|
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
|
|
}
|
|
}
|
|
|
|
// Load from shared memory.
|
|
inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki)
|
|
{
|
|
// The size of each element in bits.
|
|
int const BITS_PER_ELT = Traits::BITS_PER_ELEMENT_B;
|
|
// The size in bytes of the data needed to compute an MMA per CTA.
|
|
int const BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
|
|
|
|
#pragma unroll
|
|
for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni)
|
|
{
|
|
// Prepare the offset.
|
|
int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW;
|
|
if (BYTES_PER_MMA_PER_CTA == 32)
|
|
{
|
|
offset += this->smem_read_offset_;
|
|
}
|
|
else if (BYTES_PER_MMA_PER_CTA == 64)
|
|
{
|
|
offset += this->smem_read_offset_ + (ni / 2) * BYTES_PER_MMA_PER_CTA * 2;
|
|
}
|
|
else
|
|
{
|
|
offset += this->smem_read_offset_ + (ni) *BYTES_PER_MMA_PER_CTA;
|
|
}
|
|
|
|
// Load the data using LDSM.MT88.2.
|
|
uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset;
|
|
|
|
if (ni < Mma_tile::VALID_MMAS_N)
|
|
{
|
|
uint4 tmp;
|
|
if (USE_LDSMT)
|
|
{
|
|
ldsmt(tmp, ptr);
|
|
}
|
|
else
|
|
{
|
|
lds(tmp.x, (ptr) + 0 * Base::BYTES_PER_ROW);
|
|
lds(tmp.y, (ptr) + 4 * Base::BYTES_PER_ROW);
|
|
lds(tmp.z, (ptr ^ 32) + 0 * Base::BYTES_PER_ROW);
|
|
lds(tmp.w, (ptr ^ 32) + 4 * Base::BYTES_PER_ROW);
|
|
}
|
|
|
|
// Store those values in the fragment.
|
|
b[ni].reg(0) = tmp.x;
|
|
b[ni].reg(1) = tmp.y;
|
|
b[ni].reg(2) = tmp.z;
|
|
b[ni].reg(3) = tmp.w;
|
|
}
|
|
|
|
// static_assert(BYTES_PER_MMA_PER_CTA >= 128 ||
|
|
// BYTES_PER_MMA_PER_CTA == 64 ||
|
|
// (BYTES_PER_MMA_PER_CTA == 32 &&
|
|
// (Mma_tile::MMAS_M == 4 ||
|
|
// Mma_tile::MMAS_M == 2 ||
|
|
// Mma_tile::MMAS_M == 1)), "");
|
|
|
|
// Move the pointer for the next ni. I expect the compiler to not recompute those.
|
|
if (BYTES_PER_MMA_PER_CTA >= 128)
|
|
{
|
|
// Nothing to do!
|
|
}
|
|
else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1)
|
|
{
|
|
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
|
|
}
|
|
else if (BYTES_PER_MMA_PER_CTA == 64)
|
|
{
|
|
// Nothing to do!
|
|
}
|
|
else if (BYTES_PER_MMA_PER_CTA == 32)
|
|
{
|
|
if ((ni & 1) == 0)
|
|
{
|
|
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
|
|
}
|
|
else if (Mma_tile::MMAS_N >= 16 && (ni & 7) == 7)
|
|
{
|
|
this->smem_read_offset_ ^= BYTES_PER_LDS * 30;
|
|
}
|
|
else if (Mma_tile::MMAS_N >= 8 && (ni & 3) == 3)
|
|
{
|
|
this->smem_read_offset_ ^= BYTES_PER_LDS * 14;
|
|
}
|
|
else if (Mma_tile::MMAS_N >= 4 && (ni & 1) == 1)
|
|
{
|
|
this->smem_read_offset_ ^= BYTES_PER_LDS * 6;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
|
|
if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1)
|
|
{
|
|
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
|
|
}
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_b<Ampere_hmma_fp32_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_ampere_row_b<Ampere_hmma_fp32_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Ampere_hmma_fp32_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_ampere_row_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <
|
|
// The dimensions of the tile computed by the CTA.
|
|
typename Cta_tile,
|
|
// The size of the STS.
|
|
int BYTES_PER_STS,
|
|
// The number of buffers per tile.
|
|
int BUFFERS_PER_TILE>
|
|
struct Smem_tile_b<Ampere_hmma_bf16_traits, Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
: public Smem_tile_ampere_row_b<Ampere_hmma_bf16_traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>
|
|
{
|
|
|
|
// The traits class.
|
|
using Traits = Ampere_hmma_bf16_traits;
|
|
// The base class.
|
|
using Base = Smem_tile_ampere_row_b<Traits, Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
|
|
|
|
// Ctor.
|
|
inline __device__ Smem_tile_b(void* smem, int tidx)
|
|
: Base(smem, tidx)
|
|
{
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
} // namespace fmha
|