TensorRT-LLMs/cpp/kernels/xqa/test/refAttention.h
Ming Wei ed887940d4
infra: open source XQA kernels (#3762)
Replace libtensorrt_llm_nvrtc_wrapper.so with its source code, which
consists of two parts:

1. NVRTC glue code
2. XQA kernel code

During TensorRT-LLM build, XQA kernel code is embedded as C++ arrays via
gen_cpp_header.py and passed to NVRTC for JIT compilation.

Signed-off-by: Ming Wei <2345434+ming-wei@users.noreply.github.com>
2025-04-30 18:05:15 +08:00

120 lines
3.8 KiB
C++

/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once
#include "../mha.h"
#include <Eigen/Dense>
// Read-only, index-addressable view over a sequence of GMemCacheHead entries.
// The primary template is declared but never defined: only the four explicit
// specializations for each <isPaged, useBeamSearch> combination that follow are usable.
template <bool isPaged, bool useBeamSearch>
struct CacheSeq;
// CacheSeq variant for a non-paged cache without beam search:
// element i is simply the i-th entry starting at `data`.
template <>
struct CacheSeq<false, false>
{
    // Returns a reference to the i-th cache head. No bounds checking is performed.
    GMemCacheHead const& operator[](uint32_t i) const
    {
        return *(data + i);
    }

    // Start of the contiguous cache-head sequence (not owned by this view).
    GMemCacheHead const* data;
};
// CacheSeq variant for a non-paged cache with beam search:
// each position i is redirected through `cacheIndir[i]` before `data` is read.
template <>
struct CacheSeq<false, true>
{
    // Returns the cache head for position i.
    // The entry is located at `2 * nbKHeads * maxSeqLen * cacheIndir[i] + i` from `data`.
    // NOTE(review): cacheIndir[i] is presumably the beam that owns position i, and the
    // factor 2 presumably accounts for K and V being stored together - confirm against
    // the cache layout used by the caller.
    GMemCacheHead const& operator[](uint32_t i) const
    {
        uint32_t const indirection = cacheIndir[i];
        uint32_t const strideHeads = 2 * nbKHeads * maxSeqLen;
        return data[strideHeads * indirection + i];
    }

    uint32_t nbKHeads;          // multiplier used in the per-indirection stride
    GMemCacheHead const* data;  // base pointer of the cache (not owned)
    uint32_t const* cacheIndir; // per-position indirection table (not owned)
    uint32_t maxSeqLen;         // multiplier used in the per-indirection stride
};
// CacheSeq variant for a paged cache without beam search:
// position i is mapped to a page via `pageIndices`, then to a slot inside that page.
template <>
struct CacheSeq<true, false>
{
    // Returns the cache head for position i.
    // The page is looked up with i / tokensPerPage; within `pool`, each page holds
    // `nbHeads` groups of `tokensPerPage` entries and this view reads group `idxHead`.
    GMemCacheHead const& operator[](uint32_t i) const
    {
        auto const pageSlot = i / tokensPerPage;
        auto const tokenInPage = i % tokensPerPage;
        // The int32_t table entry is converted to uint32_t, exactly as stored.
        uint32_t const pageIdx = pageIndices[pageSlot];
        auto const pageBase = tokensPerPage * nbHeads * pageIdx;
        auto const headBase = tokensPerPage * idxHead;
        return pool[pageBase + headBase + tokenInPage];
    }

    GMemCacheHead const* pool;  // base pointer of the page pool (not owned)
    int32_t const* pageIndices; // page table: logical page slot -> page index in pool
    uint32_t nbHeads;           // number of head groups stored per page
    uint32_t idxHead;           // head group addressed by this view
};
// CacheSeq variant for a paged cache with beam search:
// position i first selects a page-table row through `cacheIndir[i]`, then a page,
// then a slot inside that page.
template <>
struct CacheSeq<true, true>
{
    // Returns the cache head for position i.
    // The page-table entry read is `cacheIndir[i] * 2 * maxNbPages + i / tokensPerPage`.
    // NOTE(review): the factor 2 presumably reflects separate K and V page lists per
    // indirection entry - confirm against the page-table layout used by the caller.
    GMemCacheHead const& operator[](uint32_t i) const
    {
        auto const tableRowBase = cacheIndir[i] * 2 * maxNbPages;
        auto const pageSlot = i / tokensPerPage;
        auto const tokenInPage = i % tokensPerPage;
        // The int32_t table entry is converted to uint32_t, exactly as stored.
        uint32_t const pageIdx = pageIndices[tableRowBase + pageSlot];
        auto const pageBase = tokensPerPage * nbHeads * pageIdx;
        auto const headBase = tokensPerPage * idxHead;
        return pool[pageBase + headBase + tokenInPage];
    }

    GMemCacheHead const* pool;  // base pointer of the page pool (not owned)
    int32_t const* pageIndices; // page table indexed by indirection row and page slot
    uint32_t maxNbPages;        // page-table row stride is 2 * maxNbPages entries
    uint32_t nbHeads;           // number of head groups stored per page
    uint32_t idxHead;           // head group addressed by this view
    uint32_t const* cacheIndir; // per-position indirection table (not owned)
};
// Declaration only; the definition is not part of this header.
// Takes a query head pointer `q`, K and V cache views of the same CacheSeq flavour,
// a sequence length, three scale factors and a sliding-window size, and returns a
// row-major headGrpSize x validElemsPerHead float matrix.
// NOTE(review): presumably a tiled (tileSize) flash-attention style reference whose
// result is compared against the kernel output - confirm against the definition.
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
float kvScale, float xScale, uint32_t slidingWinSize);
// Declaration only; the definition is not part of this header.
// Returns a row-major headGrpSize x validElemsPerHead float matrix computed from the
// query head pointer `q` and the K/V cache views.
// The parameter list depends on the SPEC_DEC build flag:
//   - SPEC_DEC enabled:  takes `hostMask`, `qSeqLen` and `q_len` after the scale factors.
//   - SPEC_DEC disabled: takes `slidingWinSize` after the scale factors.
// NOTE(review): `hostMask` is presumably the host-side attention mask used for
// speculative decoding - confirm against the definition.
template <typename MathElem, bool isPaged, bool useBeamSearch>
#if SPEC_DEC
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttention(IOHead const* q,
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
float kvScale, float xScale, bool* hostMask, const uint32_t qSeqLen, const uint32_t q_len);
#else
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttention(IOHead const* q,
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
float kvScale, float xScale, uint32_t slidingWinSize);
#endif
// Applies a rotary position embedding to one head and returns the rotated copy.
//
// ropeStyle selects how elements are paired for rotation:
//   0      - no rotation; `head` is returned unchanged.
//   1      - "Neox" pairing: element i is paired with element nbPairs + i.
//   other  - adjacent pairing: element 2*i is paired with element 2*i + 1.
//
// ropeCosSin holds, for pair i, the cosine at index 2*i and the sine at index 2*i + 1.
// Each pair (x, y) is multiplied by the 2x2 rotation matrix [[c, -s], [s, c]] in float
// and converted back to InputElem.
template <uint32_t ropeStyle>
InputHead applyRoPE(InputHead const& head, Vec<float, validElemsPerHead> const& ropeCosSin)
{
    if constexpr (ropeStyle == 0)
    {
        return head;
    }
    constexpr bool isNeox = (ropeStyle == 1);
    // exactDiv is expected to reject an odd validElemsPerHead.
    constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2);
    InputHead rotated;
    for (uint32_t pair = 0; pair < nbPairs; ++pair)
    {
        // Locate the two elements that form this rotation pair.
        uint32_t idxX;
        uint32_t idxY;
        if constexpr (isNeox)
        {
            idxX = pair;
            idxY = nbPairs + pair;
        }
        else
        {
            idxX = pair * 2;
            idxY = pair * 2 + 1;
        }

        float const cosVal = ropeCosSin[pair * 2];
        float const sinVal = ropeCosSin[pair * 2 + 1];

        Eigen::Matrix2f rotation;
        rotation << cosVal, -sinVal, sinVal, cosVal;
        Eigen::Vector2f const xy(float(head[idxX]), float(head[idxY]));
        Eigen::Vector2f const result = rotation * xy;

        rotated[idxX] = InputElem{result[0]};
        rotated[idxY] = InputElem{result[1]};
    }
    return rotated;
}