mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
131 lines
4.3 KiB
C++
131 lines
4.3 KiB
C++
/*
 * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#pragma once
|
|
|
|
#define EIGEN_STACK_ALLOCATION_LIMIT (1U << 20)
|
|
#include "../mha.h"
|
|
#include <Eigen/Dense>
|
|
|
|
/// Read-only, index-addressable view over a sequence of `GMemCacheHead` entries.
///
/// The primary template is intentionally left undefined; the accessor logic lives in the
/// explicit specializations for each (isPaged, useBeamSearch) combination.
///
/// @tparam isPaged       Selects the specializations that look entries up through `pageIndices`.
/// @tparam useBeamSearch Selects the specializations that look entries up through `cacheIndir`.
template <bool isPaged, bool useBeamSearch>
struct CacheSeq;
|
|
|
|
template <>
|
|
struct CacheSeq<false, false>
|
|
{
|
|
GMemCacheHead const& operator[](uint32_t i) const
|
|
{
|
|
return data[i];
|
|
}
|
|
|
|
GMemCacheHead const* data;
|
|
};
|
|
|
|
template <>
|
|
struct CacheSeq<false, true>
|
|
{
|
|
GMemCacheHead const& operator[](uint32_t i) const
|
|
{
|
|
return data[2 * nbKHeads * maxSeqLen * cacheIndir[i] + i];
|
|
}
|
|
|
|
uint32_t nbKHeads;
|
|
GMemCacheHead const* data;
|
|
uint32_t const* cacheIndir;
|
|
uint32_t maxSeqLen;
|
|
};
|
|
|
|
template <>
|
|
struct CacheSeq<true, false>
|
|
{
|
|
GMemCacheHead const& operator[](uint32_t i) const
|
|
{
|
|
uint32_t const pageIdx = pageIndices[i / tokensPerPage];
|
|
#if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE
|
|
return pool[nbHeads * tokensPerPage * pageIdx + (i % tokensPerPage) * nbHeads + idxHead];
|
|
#else
|
|
return pool[tokensPerPage * nbHeads * pageIdx + tokensPerPage * idxHead + i % tokensPerPage];
|
|
#endif
|
|
}
|
|
|
|
GMemCacheHead const* pool;
|
|
int32_t const* pageIndices;
|
|
uint32_t nbHeads;
|
|
uint32_t idxHead;
|
|
};
|
|
|
|
template <>
|
|
struct CacheSeq<true, true>
|
|
{
|
|
GMemCacheHead const& operator[](uint32_t i) const
|
|
{
|
|
uint32_t const pageIdx = pageIndices[cacheIndir[i] * 2 * maxNbPages + i / tokensPerPage];
|
|
return pool[tokensPerPage * nbHeads * pageIdx + tokensPerPage * idxHead + i % tokensPerPage];
|
|
}
|
|
|
|
GMemCacheHead const* pool;
|
|
int32_t const* pageIndices;
|
|
uint32_t maxNbPages;
|
|
uint32_t nbHeads;
|
|
uint32_t idxHead;
|
|
uint32_t const* cacheIndir;
|
|
};
|
|
|
|
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
|
|
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
|
|
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
|
|
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skipSoftmaxThresholdScaleFactor,
|
|
uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum);
|
|
|
|
template <typename MathElem, bool isPaged, bool useBeamSearch>
|
|
#if SPEC_DEC
|
|
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttention(IOHead const* q,
|
|
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
|
|
float kvScale, float xScale, uint32_t slidingWinSize, bool* hostMask, const uint32_t qSeqLen, const uint32_t q_len);
|
|
#else
|
|
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refAttention(IOHead const* q,
|
|
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
|
|
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks);
|
|
#endif
|
|
|
|
template <uint32_t ropeStyle>
|
|
InputHead applyRoPE(InputHead const& head, Vec<float, validElemsPerHead> const& ropeCosSin)
|
|
{
|
|
if constexpr (ropeStyle == 0)
|
|
{
|
|
return head;
|
|
}
|
|
constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2);
|
|
InputHead dst;
|
|
constexpr bool isNeox = (ropeStyle == 1);
|
|
for (uint32_t i = 0; i < nbPairs; i++)
|
|
{
|
|
float const c = ropeCosSin[i * 2];
|
|
float const s = ropeCosSin[i * 2 + 1];
|
|
Eigen::Matrix2f r;
|
|
r << c, -s, s, c;
|
|
Eigen::Vector2f v;
|
|
uint32_t const ix = (isNeox ? i : i * 2);
|
|
uint32_t const iy = (isNeox ? nbPairs + i : i * 2 + 1);
|
|
v << float(head[ix]), float(head[iy]);
|
|
auto const rv = (r * v).eval();
|
|
dst[ix] = InputElem{rv[0]};
|
|
dst[iy] = InputElem{rv[1]};
|
|
}
|
|
return dst;
|
|
}
|