TensorRT-LLMs/cpp/kernels/xqa/mla_sm120.cuh
Yao Yao 0788c5d0d6
[perf] improve XQA-MLA perf (#5468)
Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com>
2025-06-26 18:09:13 +08:00

90 lines
2.5 KiB
Plaintext

#pragma once
#include "mha_components.cuh"
#include "mha_stdheaders.cuh"
#include "tma.h"
#include "utils.cuh"
// Reads this lane's row-max entries for one tile out of `shm` and returns them as a
// per-thread vector.
//
// Result element k is taken from shm[tileBaseRow + k * warp_size + lane], so consecutive
// result elements are warp_size entries apart in `shm`.
//
//   shm          source vector of totalNbRows floats (the name suggests shared memory;
//                NOTE(review): confirm against the callers).
//   tileBaseRow  index in `shm` of the first row belonging to the tile.
//   lane         per-thread offset added to every index; defaults to laneId().
//
// This function performs no range check of its own on the computed indices.
template <uint32_t tileM, uint32_t totalNbRows>
__device__ inline ThrdRegRowMaxT<tileM> loadShmRowMax(
    Vec<float, totalNbRows> const& shm, uint32_t tileBaseRow, uint32_t lane = laneId())
{
    ThrdRegRowMaxT<tileM> rowMax{};
#pragma unroll
    for (uint32_t k = 0; k < rowMax.size; k++)
    {
        uint32_t const row = tileBaseRow + k * warp_size + lane;
        rowMax[k] = shm[row];
    }
    return rowMax;
}
// Writes this lane's row-max entries for one tile into `shm` with plain assignments.
//
// Source element k is written to shm[tileBaseRow + k * warp_size + lane]; this is the same
// index mapping loadShmRowMax reads with.
//
//   shm          destination vector of totalNbRows floats (the name suggests shared memory;
//                NOTE(review): confirm against the callers).
//   src          per-thread values to write.
//   tileBaseRow  index in `shm` of the first row belonging to the tile.
//   lane         per-thread offset added to every index; defaults to laneId().
//
// No synchronization is issued here and no range check is performed on the indices.
template <uint32_t tileM, uint32_t totalNbRows>
__device__ inline void storeRowMax(
    Vec<float, totalNbRows>& shm, ThrdRegRowMaxT<tileM> const& src, uint32_t tileBaseRow, uint32_t lane = laneId())
{
#pragma unroll
    for (uint32_t k = 0; k < src.size; k++)
    {
        uint32_t const row = tileBaseRow + k * warp_size + lane;
        shm[row] = src[k];
    }
}
// Variant of storeRowMax that hands every element to tma::storeAsync instead of assigning it.
//
// For each source element k, tma::storeAsync is called with the address of
// shm[tileBaseRow + k * warp_size + lane], the value src[k], and `bar`.
//
//   bar          barrier forwarded unchanged to every tma::storeAsync call
//                (NOTE(review): presumably signalled when the store lands; confirm in tma.h).
//   shm          destination vector of totalNbRows floats.
//   src          per-thread values to write.
//   tileBaseRow  index in `shm` of the first row belonging to the tile.
//   lane         per-thread offset added to every index; defaults to laneId().
//
// This function does not wait on `bar` itself.
template <uint32_t tileM, uint32_t totalNbRows>
__device__ inline void storeRowMaxAsync(CgaBarrier& bar, Vec<float, totalNbRows>& shm, ThrdRegRowMaxT<tileM> const& src,
    uint32_t tileBaseRow, uint32_t lane = laneId())
{
#pragma unroll
    for (uint32_t k = 0; k < src.size; k++)
    {
        uint32_t const row = tileBaseRow + k * warp_size + lane;
        tma::storeAsync(&shm[row], src[k], bar);
    }
}
// Computes, for every accumulator row this thread holds, the maximum over the row's columns,
// then combines that maximum across each group of four neighbouring lanes.
//
// Result slot (m * InstAcc::rows + i) corresponds to row i of the instruction tile in
// accumulator row-block m. The values are the raw maxima of `acc`; no scaling is applied in
// this function.
//
// The shuffles use a full participation mask (~0U), so all 32 lanes of the warp have to
// execute this call together.
template <uint32_t tileM, uint32_t tileN>
__device__ inline QuadRegRowMaxT<tileM> computeRowMax(WarpAccT<tileM, tileN> const& acc)
{
    QuadRegRowMaxT<tileM> quadRowMax{};

    // Step 1: per-thread maximum over all column elements this thread owns.
#pragma unroll
    for (uint32_t n = 0; n < acc.cols; n++)
    {
#pragma unroll
        for (uint32_t j = 0; j < InstAcc::cols; j++)
        {
            // The very first column visited seeds the running maximum instead of being compared
            // against the value-initialized slot.
            bool const isFirstColumn = (n == 0 && j == 0);
#pragma unroll
            for (uint32_t m = 0; m < acc.rows; m++)
            {
#pragma unroll
                for (uint32_t i = 0; i < InstAcc::rows; i++)
                {
                    float const value = acc(m, n)(i, j);
                    float& slot = quadRowMax[m * InstAcc::rows + i];
                    if (isFirstColumn)
                    {
                        slot = value;
                    }
                    else
                    {
                        slot = fmaxf(slot, value);
                    }
                }
            }
        }
    }

    // Step 2: butterfly exchange with XOR masks 2 and then 1, which merges the maxima of lanes
    // whose IDs differ only in their two lowest bits (a group of four lanes, not the whole warp).
#pragma unroll
    for (uint32_t laneXor = 2; laneXor > 0; laneXor >>= 1)
    {
#pragma unroll
        for (uint32_t r = 0; r < quadRowMax.size; r++)
        {
            float const peer = __shfl_xor_sync(~0U, quadRowMax[r], laneXor);
            quadRowMax[r] = fmaxf(quadRowMax[r], peer);
        }
    }
    return quadRowMax;
}
// Folds the raw 32-bit patterns of all n elements of `data` together with XOR and returns the
// result (0 when n == 0).
//
// T must be exactly 4 bytes wide, which is enforced at compile time. Because XOR is
// order-independent, permuting the elements does not change the returned value.
//
// NOTE(review): the bit pattern is obtained through a reference cast to uint32_t, i.e. by
// type punning rather than a value conversion.
template <typename T, uint32_t n>
__device__ inline uint32_t hashRegData(Vec<T, n> const& data)
{
    static_assert(sizeof(T) == 4);
    uint32_t digest = 0;
#pragma unroll
    for (uint32_t k = 0; k < n; k++)
    {
        uint32_t const bits = reinterpret_cast<uint32_t const&>(data[k]);
        digest = digest ^ bits;
    }
    return digest;
}