mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com> Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com> Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com> Co-authored-by: Yao Yao <lowsfer@users.noreply.github.com> Co-authored-by: peaceh-nv <103117813+peaceh-nv@users.noreply.github.com>
90 lines
2.5 KiB
Plaintext
90 lines
2.5 KiB
Plaintext
#pragma once
|
|
#include "mha_components.cuh"
|
|
#include "mha_stdheaders.cuh"
|
|
#include "tma.h"
|
|
#include "utils.cuh"
|
|
|
|
template <uint32_t tileM, uint32_t totalNbRows>
|
|
__device__ inline ThrdRegRowMaxT<tileM> loadShmRowMax(
|
|
Vec<float, totalNbRows> const& shm, uint32_t tileBaseRow, uint32_t lane = laneId())
|
|
{
|
|
ThrdRegRowMaxT<tileM> result;
|
|
#pragma unroll
|
|
for (uint32_t i = 0; i < result.size; i++)
|
|
{
|
|
result[i] = shm[tileBaseRow + i * warp_size + lane];
|
|
}
|
|
return result;
|
|
}
|
|
|
|
template <uint32_t tileM, uint32_t totalNbRows>
|
|
__device__ inline void storeRowMax(
|
|
Vec<float, totalNbRows>& shm, ThrdRegRowMaxT<tileM> const& src, uint32_t tileBaseRow, uint32_t lane = laneId())
|
|
{
|
|
#pragma unroll
|
|
for (uint32_t i = 0; i < src.size; i++)
|
|
{
|
|
shm[tileBaseRow + i * warp_size + lane] = src[i];
|
|
}
|
|
}
|
|
|
|
template <uint32_t tileM, uint32_t totalNbRows>
|
|
__device__ inline void storeRowMaxAsync(CgaBarrier& bar, Vec<float, totalNbRows>& shm, ThrdRegRowMaxT<tileM> const& src,
|
|
uint32_t tileBaseRow, uint32_t lane = laneId())
|
|
{
|
|
#pragma unroll
|
|
for (uint32_t i = 0; i < src.size; i++)
|
|
{
|
|
tma::storeAsync(&shm[tileBaseRow + i * warp_size + lane], src[i], bar);
|
|
}
|
|
}
|
|
|
|
template <uint32_t tileM, uint32_t tileN>
|
|
__device__ inline QuadRegRowMaxT<tileM> computeRowMax(WarpAccT<tileM, tileN> const& acc)
|
|
{
|
|
QuadRegRowMaxT<tileM> rowMaxLog2e;
|
|
// compute per-thread row max
|
|
#pragma unroll
|
|
for (uint32_t n = 0; n < acc.cols; n++)
|
|
{
|
|
#pragma unroll
|
|
for (uint32_t j = 0; j < InstAcc::cols; j++)
|
|
{
|
|
#pragma unroll
|
|
for (uint32_t m = 0; m < acc.rows; m++)
|
|
{
|
|
#pragma unroll
|
|
for (uint32_t i = 0; i < InstAcc::rows; i++)
|
|
{
|
|
float& dst = rowMaxLog2e[m * InstAcc::rows + i];
|
|
dst = ((n == 0 && j == 0) ? acc(m, n)(i, j) : fmaxf(dst, acc(m, n)(i, j)));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// compute warp row max
|
|
#pragma unroll
|
|
for (uint32_t xorMask = 2; xorMask != 0; xorMask /= 2)
|
|
{
|
|
#pragma unroll
|
|
for (uint32_t i = 0; i < rowMaxLog2e.size; i++)
|
|
{
|
|
rowMaxLog2e[i] = fmaxf(rowMaxLog2e[i], __shfl_xor_sync(~0U, rowMaxLog2e[i], xorMask));
|
|
}
|
|
}
|
|
return rowMaxLog2e;
|
|
}
|
|
|
|
template <typename T, uint32_t n>
|
|
__device__ inline uint32_t hashRegData(Vec<T, n> const& data)
|
|
{
|
|
static_assert(sizeof(T) == 4);
|
|
uint32_t result = 0;
|
|
#pragma unroll
|
|
for (uint32_t i = 0; i < n; i++)
|
|
{
|
|
result ^= reinterpret_cast<uint32_t const&>(data[i]);
|
|
}
|
|
return result;
|
|
}
|