#pragma once
#include "mha_components.cuh"
#include "mha_stdheaders.cuh"
#include "tma.h"
#include "utils.cuh"

// NOTE(review): throughout this file the `template` keywords carry no parameter list and
// the Vec / ThrdRegRowMaxT / QuadRegRowMaxT / WarpAccT types carry no template arguments.
// This looks like angle-bracket text lost in extraction - restore it from the original
// header before compiling. The comments below describe only what the visible code does.

// Gathers one thread's row-max values from `shm` into a ThrdRegRowMaxT.
//
// Element i of the result is read from shm[tileBaseRow + i * warp_size + lane], so
// consecutive lanes read consecutive entries and each thread steps by warp_size.
//
// shm         - source vector (presumably in shared memory, judging by the name - confirm).
// tileBaseRow - index offset added to every access.
// lane        - per-thread offset within each warp_size-wide group; defaults to laneId().
// Returns the gathered values; the element count is result.size.
template __device__ inline ThrdRegRowMaxT loadShmRowMax(
    Vec const& shm, uint32_t tileBaseRow, uint32_t lane = laneId())
{
    ThrdRegRowMaxT result{};
#pragma unroll
    for (uint32_t i = 0; i < result.size; i++)
    {
        result[i] = shm[tileBaseRow + i * warp_size + lane];
    }
    return result;
}

// Scatters one thread's row-max values from `src` into `shm`.
//
// Uses the same indexing as loadShmRowMax: src[i] is written to
// shm[tileBaseRow + i * warp_size + lane] for every i in [0, src.size).
// No synchronization is performed here; ordering against readers is the caller's job.
template __device__ inline void storeRowMax(
    Vec& shm, ThrdRegRowMaxT const& src, uint32_t tileBaseRow, uint32_t lane = laneId())
{
#pragma unroll
    for (uint32_t i = 0; i < src.size; i++)
    {
        shm[tileBaseRow + i * warp_size + lane] = src[i];
    }
}

// Same scatter as storeRowMax, but every element goes through tma::storeAsync together
// with the barrier `bar`.
//
// One tma::storeAsync call is issued per element of `src`.
// NOTE(review): presumably each store signals completion on `bar`, so the waiter must
// expect src.size stores per calling thread - confirm against tma.h.
template __device__ inline void storeRowMaxAsync(CgaBarrier& bar, Vec& shm, ThrdRegRowMaxT const& src,
    uint32_t tileBaseRow, uint32_t lane = laneId())
{
#pragma unroll
    for (uint32_t i = 0; i < src.size; i++)
    {
        tma::storeAsync(&shm[tileBaseRow + i * warp_size + lane], src[i], bar);
    }
}

// Computes row maxima of the accumulator `acc`.
//
// `acc` is indexed as acc(m, n)(i, j): an acc.rows x acc.cols grid of InstAcc fragments,
// each InstAcc::rows x InstAcc::cols. Result element (m * InstAcc::rows + i) is the max
// over all n and j of acc(m, n)(i, j), further reduced across the 4 lanes whose lane ids
// differ only in their two lowest bits.
//
// Must be called by all 32 lanes of the warp: the shuffle uses the full mask ~0U.
// NOTE(review): the local name `rowMaxLog2e` suggests log2(e)-scaled values, but no scaling
// happens in this function - confirm whether the caller pre-scales `acc`.
template __device__ inline QuadRegRowMaxT computeRowMax(WarpAccT const& acc)
{
    QuadRegRowMaxT rowMaxLog2e{};
// compute per-thread row max
#pragma unroll
    for (uint32_t n = 0; n < acc.cols; n++)
    {
#pragma unroll
        for (uint32_t j = 0; j < InstAcc::cols; j++)
        {
#pragma unroll
            for (uint32_t m = 0; m < acc.rows; m++)
            {
#pragma unroll
                for (uint32_t i = 0; i < InstAcc::rows; i++)
                {
                    float& dst = rowMaxLog2e[m * InstAcc::rows + i];
                    // The first column visited (n == 0, j == 0) seeds dst, so the
                    // value-initialized contents of rowMaxLog2e never enter the max.
                    dst = ((n == 0 && j == 0) ? acc(m, n)(i, j) : fmaxf(dst, acc(m, n)(i, j)));
                }
            }
        }
    }
// compute warp row max
// Butterfly reduction with xor masks 2 then 1: afterwards every lane in a group of 4
// consecutive lanes (same laneId / 4) holds the same maximum for each row.
#pragma unroll
    for (uint32_t xorMask = 2; xorMask != 0; xorMask /= 2)
    {
#pragma unroll
        for (uint32_t i = 0; i < rowMaxLog2e.size; i++)
        {
            rowMaxLog2e[i] = fmaxf(rowMaxLog2e[i], __shfl_xor_sync(~0U, rowMaxLog2e[i], xorMask));
        }
    }
    return rowMaxLog2e;
}

// XOR-folds the elements of `data` into a single 32-bit value.
//
// Each element must be exactly 4 bytes wide (enforced by the static_assert); its bits are
// XORed into the result. Returns 0 when n == 0.
// NOTE(review): `T` and `n` are presumably the stripped template parameters, and the
// reinterpret_cast is missing its target type (presumably `uint32_t const&`) - confirm
// against the original header.
template __device__ inline uint32_t hashRegData(Vec const& data)
{
    static_assert(sizeof(T) == 4);
    uint32_t result = 0;
#pragma unroll
    for (uint32_t i = 0; i < n; i++)
    {
        result ^= reinterpret_cast(data[i]);
    }
    return result;
}