#pragma once
#include "mma.cuh"
#include "utils.cuh"

// NOTE(review): this header arrived with all lines joined and every '<...>' template argument list stripped.
// The argument lists below were reconstructed from how each alias/function is used; confirm against
// version control.

// Per-thread fp32 fragment of one 16x8 MMA accumulator tile: InstAcc::rows x InstAcc::cols = 2 x 2.
// Inside an 8-column group, this thread's element column j sits at InstAcc::cols * (laneId() % 4) + j.
using InstAcc = Array2D<float, 2, 2>;

// Per-thread view of a warp's tileM x tileN fp32 accumulator, as a grid of 16x8 InstAcc tiles.
template <uint32_t tileM, uint32_t tileN>
using WarpAccT = Array2D<InstAcc, exactDiv(tileM, 16), exactDiv(tileN, 8)>;

// Writes mha::numeric_limits<float>::lowest() into every element of `acc` whose column lies outside
// [validColBeg, validColEnd). Elements whose column is inside the range are left unchanged.
// The column of element acc(m, n)(i, j) is 8 * n + InstAcc::cols * (laneId() % 4) + j. It depends only on
// (n, j), so the validity test is done once per column and applied to all rows (m, i).
// `warp` is not used by the body.
template <uint32_t accRows, uint32_t accCols>
__device__ inline void applyMask(
    Warp const& warp, Array2D<InstAcc, accRows, accCols>& acc, uint32_t validColBeg, uint32_t validColEnd)
{
    uint32_t const idxInQuad = laneId() % 4;
#pragma unroll
    for (uint32_t n = 0; n < acc.cols; n++)
    {
#pragma unroll
        for (uint32_t j = 0; j < InstAcc::cols; j++)
        {
            uint32_t const col = 8 * n + InstAcc::cols * idxInQuad + j;
            if (col >= validColBeg && col < validColEnd)
            {
                continue;
            }
#pragma unroll
            for (uint32_t m = 0; m < acc.rows; m++)
            {
#pragma unroll
                for (uint32_t i = 0; i < InstAcc::rows; i++)
                {
                    acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
                }
            }
        }
    }
}

// Per-row values for tileM rows, 4 entries per (started) group of warp_size rows.
template <uint32_t tileM>
using QuadRegRowMaxT
    = Vec<float, (tileM + warp_size - 1) / warp_size * 4>; // data is replicated across 4 threads in a MMA quad.
// Per-row values for tileM rows, 1 entry per (started) group of warp_size rows on each lane.
template <uint32_t tileM>
using ThrdRegRowMaxT = Vec<float, (tileM + warp_size - 1) / warp_size>; // unlike QuadRegRowMax, not replicated.
// NOTE(review): not referenced in this header; element type and size are a best-effort reconstruction - confirm.
template <uint32_t tileM>
using UniformRescaleMaskT = Vec<uint32_t, (tileM + warp_size - 1) / warp_size>; // uniform and stored in UR

// Number of 4-lane MMA quads in one warp.
inline constexpr uint32_t quadPerWarp = warp_size / 4;

// In the helpers that follow, idxMat8 is the reduced row index in 8-row unit.
// NOTE(review): the '<...>' template argument lists in this section were stripped and have been reconstructed
// from how each function indexes its arguments; confirm against version control.
//
// Two row-vector layouts are converted between below (as implied by the shuffle arithmetic):
//  - thread-register layout (ThrdRegRowMaxT): entry i on lane l holds row warp_size * i + l;
//  - quad-replicated layout (QuadRegRowMaxT): entry k holds row 8 * k + laneId() / 4, and the same value is
//    present on all 4 lanes of that quad.
// All helpers shuffle with a full mask (~0U), so all 32 lanes must call them together.

// idxMat8 is the row index reduced to 8-row units, i.e. the entry index k of the quad-replicated layout.
// Returns quad-replicated entry idxMat8 taken from the thread-register vector `src`.
template <uint32_t n>
__device__ inline float replicateValForQuad(Warp const& warp, Vec<float, n> const& src, uint32_t idxMat8)
{
    assertWarpConverged();
    uint32_t const i = idxMat8 / 4;
    uint32_t const j = idxMat8 % 4;
    return __shfl_sync(~0U, src[i], quadPerWarp * j + laneId() / 4);
}

// Converts a thread-register row vector (n entries per lane) to the quad-replicated layout (4 * n entries).
template <uint32_t n>
__device__ inline QuadRegRowMaxT<n * warp_size> replicateForQuad(Warp const& warp, Vec<float, n> const& src)
{
    assertWarpConverged();
    QuadRegRowMaxT<n * warp_size> dst;
#pragma unroll
    for (uint32_t i = 0; i < src.size; i++)
    {
#pragma unroll
        for (uint32_t j = 0; j < 4; j++)
        {
            dst[i * 4 + j] = __shfl_sync(~0U, src[i], quadPerWarp * j + laneId() / 4);
            assert(dst[i * 4 + j] == replicateValForQuad(warp, src, i * 4 + j));
        }
    }
    return dst;
}

// Inverse of replicateForQuad: converts a quad-replicated row vector (n entries, identical across each quad)
// back to the thread-register layout (n / 4 entries per lane).
template <uint32_t n>
__device__ inline ThrdRegRowMaxT<8 * n> dedupFromQuad(Warp const& warp, Vec<float, n> const& src)
{
    // The main loop reads src[i * 4 + j] for j < 4; a size that is not a multiple of 4 would read past the end.
    static_assert(n % 4 == 0, "dedupFromQuad expects a quad-replicated vector whose size is a multiple of 4");
#ifndef NDEBUG
    // Input must really be replicated: every lane matches lane 0 of its quad.
    for (uint32_t i = 0; i < src.size; i++)
    {
        assert(src[i] == __shfl_sync(~0U, src[i], laneId() / 4 * 4));
    }
#endif
    ThrdRegRowMaxT<8 * n> dst;
    uint32_t const lane = laneId();
    uint32_t const idxMat = lane / 8;
    uint32_t const idxRow = lane % 8;
#pragma unroll
    for (uint32_t i = 0; i < dst.size; i++)
    {
#pragma unroll
        for (uint32_t j = 0; j < 4; j++)
        {
            // Every lane takes part in the shuffle; only lanes whose 8-lane group matches j keep the value.
            float const val = __shfl_sync(~0U, src[i * 4 + j], 4 * idxRow);
            if (idxMat == j)
            {
                dst[i] = val;
            }
        }
    }
#ifndef NDEBUG
    // Round-trip check: replicating the result must reproduce the input.
    auto const rep = replicateForQuad(warp, dst);
#pragma unroll
    for (uint32_t i = 0; i < n; i++)
    {
        assert(src[i] == rep[i]);
        __syncwarp();
    }
#endif
    return dst;
}

// Sums each row of a warp's tileM x tileN FP8 (e4m3) tile and returns the sums in the thread-register layout.
// The tile is held as MMA A-operand fragments: src(i, c) is this thread's part of the 16x16 block at block-row i,
// block-column c, and must be two 32-bit registers. Two blocks along k (assumed adjacent in memory) are
// reinterpreted as the uint32_t[2][2] A operand of one mma<__nv_fp8_e4m3> call.
// Multiplying by an all-ones B operand with fp32 accumulation makes every output column equal the row sum;
// column 0 is kept.
template <uint32_t tileM, uint32_t tileN, typename SrcFrag>
__device__ inline ThrdRegRowMaxT<tileM> computeRowSumF8(
    Warp const& warp, Array2D<SrcFrag, exactDiv(tileM, 16), exactDiv(tileN, 16)> const& src)
{
    static_assert(sizeof(SrcFrag) == 2 * sizeof(uint32_t),
        "each 16x16 FP8 block must occupy two 32-bit registers per thread");
    using WarpAcc = WarpAccT<tileM, 8>;
    WarpAcc acc{};
    // All-ones B operand: each 32-bit register packs four e4m3 values equal to 1.
    Vec<__nv_fp8x2_e4m3, 2> const bWord = {__nv_fp8x2_e4m3{float2{1, 1}}, __nv_fp8x2_e4m3{float2{1, 1}}};
    uint32_t const ones = reinterpret_cast<uint32_t const&>(bWord);
    uint32_t const b[2][1] = {{ones}, {ones}};
#pragma unroll
    for (uint32_t i = 0; i < WarpAcc::rows; i++)
    {
#pragma unroll
        for (uint32_t k = 0; k < exactDiv(src.cols, 2); k++)
        {
            mma<__nv_fp8_e4m3>(reinterpret_cast<float(&)[2][2]>(acc(i, 0)),
                reinterpret_cast<uint32_t const(&)[2][2]>(src(i, k * 2)), b);
        }
    }
    // Value-initialised: QuadRegRowMaxT is padded to a multiple of 4 entries. When tileM is not a multiple of
    // warp_size, the padding entries are never written below but are still read by dedupFromQuad.
    QuadRegRowMaxT<tileM> rowSum{};
#pragma unroll
    for (uint32_t i = 0; i < WarpAcc::rows; i++)
    {
#pragma unroll
        for (uint32_t m = 0; m < InstAcc::rows; m++)
        {
#ifndef NDEBUG
            assert(acc(i, 0)(m, 0) == acc(i, 0)(m, 1));
            assert(acc(i, 0)(m, 0) == __shfl_sync(~0U, acc(i, 0)(m, 0), laneId() / 4 * 4));
#endif
            rowSum[i * InstAcc::rows + m] = acc(i, 0)(m, 0);
        }
    }
    return dedupFromQuad(warp, rowSum);
}

// Sums each row of a warp's tileM x tileN fp32 accumulator tile and returns the sums in the thread-register
// layout. Each lane first adds up the columns it owns. A two-step XOR shuffle (masks 2, then 1) then adds
// across the 4 lanes of the quad, leaving the full row sum on every lane of the quad.
template <uint32_t tileM, uint32_t tileN>
__device__ inline ThrdRegRowMaxT<tileM> computeRowSumF32(Warp const& warp, WarpAccT<tileM, tileN> const& src)
{
    QuadRegRowMaxT<tileM> rowSum{};
#pragma unroll
    for (uint32_t n = 0; n < src.cols; n++)
    {
#pragma unroll
        for (uint32_t j = 0; j < InstAcc::cols; j++)
        {
#pragma unroll
            for (uint32_t m = 0; m < src.rows; m++)
            {
#pragma unroll
                for (uint32_t i = 0; i < InstAcc::rows; i++)
                {
                    if (n == 0 && j == 0)
                    {
                        rowSum[m * InstAcc::rows + i] = src(m, n)(i, j);
                    }
                    else
                    {
                        rowSum[m * InstAcc::rows + i] += src(m, n)(i, j);
                    }
                }
            }
        }
    }
#pragma unroll
    for (uint32_t mask = 2; mask != 0; mask /= 2)
    {
#pragma unroll
        for (uint32_t i = 0; i < rowSum.size; i++)
        {
            rowSum[i] += __shfl_xor_sync(~0U, rowSum[i], mask);
        }
    }
    return dedupFromQuad(warp, rowSum);
}