TensorRT-LLMs/cpp/tensorrt_llm/kernels/kvCacheUtils.h
Yihan Wang 9df4dad3b6
[None][fix] Introduce inline namespace to avoid symbol collision (#9541)
Signed-off-by: Yihan Wang <yihwang@nvidia.com>
2025-12-12 23:32:15 +08:00

331 lines
12 KiB
C++

/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/kernels/kvCacheIndex.h"
#include <cmath>
#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <limits>
TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
// Selector used internally to address either the K or the V half of a cache.
// The numeric values are used directly in offset arithmetic (see the getRowPtr
// members of the cache structs), so they must stay 0 for K and 1 for V.
enum class KVIdxType : int32_t
{
    // Key cache.
    K_IDX = 0,
    // Value cache.
    V_IDX = 1,
};
// Struct operates on paged kv cache providing
// only the fields necessary for context FMHA
struct KVBlockArrayForContextFMHA
{
    using DataType = ::tensorrt_llm::kernels::KVCacheIndex const;

    // The maximum number of sequences supported by the kv-cache.
    int32_t mMaxSeqs;
    // Max number of blocks per sequence
    int32_t mMaxBlocksPerSeq;
    // Number of tokens per block. It must be a power of 2.
    int32_t mTokensPerBlock;
    // Exponent of number of tokens with base 2.
    // E.g. for mTokensPerBlock 64, mTokensPerBlockLog2 equals to 6
    int32_t mTokensPerBlockLog2;
    // Size of KV cache blocks in bytes (H*D*T*sizeof(element)).
    // The constructor sets it to tokensPerBlock * bytesPerToken.
    int32_t mBytesPerBlock;
    // Pointer to beginning of pool.
    void* mPrimaryPoolPtr;
    // Pointer to block offsets.
    // Table maps logical block idx to the data pointer of k/v cache block pool
    // Shape [B, W, 2, M], where 2 is table for K and V,
    // B is current number of sequences
    // W is beam width
    // M is Max number of blocks per sequence
    DataType* data;

    // Creates an empty view: all sizes are zero and all pointers are null.
    KVBlockArrayForContextFMHA()
        : mMaxSeqs{0}
        , mMaxBlocksPerSeq{0}
        , mTokensPerBlock{0}
        , mTokensPerBlockLog2{0}
        , mBytesPerBlock{0}
        , mPrimaryPoolPtr{nullptr}
        , data{nullptr}
    {
    }

    // Creates a view over an existing block-offset table.
    //   batchSize       stored as mMaxSeqs
    //   maxBlocksPerSeq stored as mMaxBlocksPerSeq
    //   tokensPerBlock  must be a positive power of 2
    //   bytesPerToken   multiplied by tokensPerBlock to obtain mBytesPerBlock
    //   primaryPoolPtr  stored as mPrimaryPoolPtr (not owned)
    //   data            stored as the block-offset table pointer (not owned)
    // Both preconditions below are enforced with TLLM_CHECK_WITH_INFO.
    KVBlockArrayForContextFMHA(int32_t batchSize, int32_t maxBlocksPerSeq, int32_t tokensPerBlock,
        int32_t bytesPerToken, void* primaryPoolPtr, DataType* data)
        : mMaxSeqs(batchSize)
        , mMaxBlocksPerSeq(maxBlocksPerSeq)
        , mTokensPerBlock(tokensPerBlock)
        , mTokensPerBlockLog2{0}
        , mBytesPerBlock{tokensPerBlock * bytesPerToken}
        , mPrimaryPoolPtr{primaryPoolPtr}
        , data{data}
    {
        // Exact integer test. A floating-point log2 based test accepts 0 (log2(0) is -inf, and
        // ceil(-inf) == floor(-inf)) and can accept large non-powers of 2 after rounding to float.
        bool const isPowerOfTwo = tokensPerBlock > 0 && (tokensPerBlock & (tokensPerBlock - 1)) == 0;
        TLLM_CHECK_WITH_INFO(isPowerOfTwo, "tokensPerBlock must be power of 2");
        // NOTE: pointer offset arithmetic offset is performed on int32_t (see this.getRowPtr).
        // If needed, we could do it on uint32_t or even uint64_t, but that might have performance implications
        TLLM_CHECK_WITH_INFO(static_cast<int64_t>(mMaxSeqs - 1) * mMaxBlocksPerSeq * 2 + maxBlocksPerSeq
                <= std::numeric_limits<int32_t>::max(),
            "kv cache is too large for gpt_attention_plugin");
        // tokensPerBlock is a power of 2 here, so counting right shifts yields its exact log2.
        for (int32_t remaining = tokensPerBlock; remaining > 1; remaining >>= 1)
        {
            ++mTokensPerBlockLog2;
        }
    }
};
// Struct operates on paged kv cache providing
// functions for accessing blocks in K and V caches
// and elements inside these blocks
struct KVBlockArray : public KVBlockArrayForContextFMHA
{
    // Pointer to beginning of the secondary pool; used for offsets for which isPrimary() is false.
    void* mSecondaryPoolPtr;
    // Maximum kv cache length per sequence
    int32_t mMaxAttentionWindow;
    // Number of sink tokens.
    int32_t mSinkTokens;
    // Cyclic cache length.
    int32_t mCyclicCacheLen;
    // Bubble length: number of unused token slots that pad the sink tokens up to a whole block.
    int32_t mBubbleLen;
    // Enable one more block to save the kv tokens
    bool mEnableOneMoreBlock;

    // Creates an empty array: the base is default-constructed and all own fields are zero/null/false.
    KVBlockArray()
        : mSecondaryPoolPtr(nullptr)
        , mMaxAttentionWindow{0}
        , mSinkTokens{0}
        , mCyclicCacheLen{0}
        , mBubbleLen{0}
        , mEnableOneMoreBlock{false}
    {
    }

    // Creates a paged kv cache view.
    // The first four arguments, primaryPoolPtr and data are forwarded to the base constructor.
    // mBubbleLen, mEnableOneMoreBlock and mCyclicCacheLen are derived from the remaining arguments.
    KVBlockArray(int32_t batchSize, int32_t maxBlocksPerSeq, int32_t tokensPerBlock, int32_t bytesPerToken,
        int32_t maxAttentionWindow, int32_t maxAttentionWindowAllLayer, int32_t sinkTokenLen, bool canUseOneMoreBlock,
        void* primaryPoolPtr, void* secondaryPoolPtr, DataType* data)
        : KVBlockArrayForContextFMHA(batchSize, maxBlocksPerSeq, tokensPerBlock, bytesPerToken, primaryPoolPtr, data)
        , mSecondaryPoolPtr{secondaryPoolPtr}
        , mMaxAttentionWindow(maxAttentionWindow)
        , mSinkTokens(sinkTokenLen)
    {
        // Sink tokens that spill into a partially filled block are padded up to the block boundary.
        auto sinkTokensInLastBlock = mSinkTokens % mTokensPerBlock;
        mBubbleLen = sinkTokensInLastBlock == 0 ? 0 : mTokensPerBlock - sinkTokensInLastBlock;
        // The extra block is only usable when all blocks but one already hold the window plus the bubble,
        // and the caller allows it.
        mEnableOneMoreBlock = (maxBlocksPerSeq - 1) * tokensPerBlock >= maxAttentionWindowAllLayer + mBubbleLen;
        mEnableOneMoreBlock &= canUseOneMoreBlock;
        mCyclicCacheLen = (mEnableOneMoreBlock) ? mMaxAttentionWindow + mTokensPerBlock - mSinkTokens
                                                : mMaxAttentionWindow - mSinkTokens;
    }

    // Returns a copy of the base-class part of this array.
    [[nodiscard]] KVBlockArrayForContextFMHA copyKVBlockArrayForContextFMHA() const
    {
        // Copy the base sub-object directly. Rebuilding it through the constructor would need
        // mBytesPerBlock / mTokensPerBlock, which divides by zero for a default-constructed array,
        // and would re-run validation on fields that are already stored.
        return static_cast<KVBlockArrayForContextFMHA const&>(*this);
    }

    // True when tokenIdx is below the number of sink tokens.
    __host__ __device__ [[nodiscard]] inline bool isSinkToken(int32_t tokenIdx) const
    {
        return tokenIdx < mSinkTokens;
    }

    // Maps a token index to its slot in the cache.
    // Sink tokens keep their index; all other tokens wrap around within mCyclicCacheLen slots
    // that start after the sink tokens and the bubble.
    __host__ __device__ [[nodiscard]] inline int32_t getKVTokenIdx(int32_t tokenIdx) const
    {
        if (!isSinkToken(tokenIdx))
        {
            // Apply cyclic kv cache if tokenIdx >= max_attention_window_size.
            return mSinkTokens + mBubbleLen + (tokenIdx - mSinkTokens) % mCyclicCacheLen;
        }
        return tokenIdx;
    }

    __host__ __device__ [[nodiscard]] inline DataType const* getRowPtr(KVIdxType kvIdx, int32_t seqIdx) const
    {
        // Returns pointer to array of offsets to K or V cache for one specific sequence seqIdx.
        // seqIdx is in range [0, mMaxSeqs). Each sequence owns 2 * mMaxBlocksPerSeq entries:
        // the K offsets first, then the V offsets.
        return data + (seqIdx * mMaxBlocksPerSeq * 2 + static_cast<int32_t>(kvIdx) * mMaxBlocksPerSeq);
    }

    // Returns the start of the block that holds tokenIdx, given the offset row of one sequence.
    __host__ __device__ inline void* getBlockPtr(DataType const* offsets, int32_t tokenIdx) const
    {
        // tokenIdx >> log2(tokensPerBlock) is the index of the block within the sequence.
        auto const offset = offsets[tokenIdx >> mTokensPerBlockLog2];
        // The byte offset is computed in 64 bits before it is added to the pool base pointer.
        return reinterpret_cast<void*>(
            reinterpret_cast<char*>(getPoolPtr(offset)) + offset.get() * static_cast<uint64_t>(mBytesPerBlock));
    }

    // Returns the start of the K or V block that holds tokenIdx of sequence seqIdx.
    __host__ __device__ [[nodiscard]] inline void* getBlockPtr(int32_t seqIdx, int32_t tokenIdx, KVIdxType kvIdx) const
    {
        return getBlockPtr(getRowPtr(kvIdx, seqIdx), tokenIdx);
    }

    // Shorthand for getBlockPtr(seqIdx, tokenIdx, KVIdxType::K_IDX).
    __host__ __device__ [[nodiscard]] inline void* getKBlockPtr(int32_t seqIdx, int32_t tokenIdx) const
    {
        return getBlockPtr(seqIdx, tokenIdx, KVIdxType::K_IDX);
    }

    // Shorthand for getBlockPtr(seqIdx, tokenIdx, KVIdxType::V_IDX).
    __host__ __device__ [[nodiscard]] inline void* getVBlockPtr(int32_t seqIdx, int32_t tokenIdx) const
    {
        return getBlockPtr(seqIdx, tokenIdx, KVIdxType::V_IDX);
    }

    // Returns the position of a token inside its block (globalIdx modulo tokens per block).
    __host__ __device__ [[nodiscard]] inline int32_t getLocalIdx(int32_t globalIdx) const
    {
        return globalIdx & ((1 << mTokensPerBlockLog2) - 1);
    }

    __host__ __device__ [[nodiscard]] inline int32_t getKVLocalIdx(
        int32_t globalTokenIdx, int32_t headIdx, int32_t dimsPerHead, int32_t channelIdx) const
    {
        // For K or V, the hidden dimension per head is *not* decomposed. The layout of each block of K or V is:
        // [numHeads, tokensPerBlock, hiddenSizePerHead].
        // This member function computes the corresponding linear index.
        // NOTE: we have remapped K layout as the same of V.
        return headIdx * mTokensPerBlock * dimsPerHead + getLocalIdx(globalTokenIdx) * dimsPerHead + channelIdx;
    }

private:
    // Selects the pool a block offset refers to: primary when offset.isPrimary(), secondary otherwise.
    __host__ __device__ [[nodiscard]] void* getPoolPtr(DataType offset) const
    {
        return offset.isPrimary() ? mPrimaryPoolPtr : mSecondaryPoolPtr;
    }
};
// Struct operates on contiguous kv cache providing
// functions for accessing specific elements in K and V caches
struct KVLinearBuffer
{
    using DataType = int8_t;

    // Current number of sequences
    int32_t mMaxSeqs;
    // Max sequence length
    int32_t mMaxSeqLen;
    // Bytes per sequence (H*D*M_S*sizeof(DataType))
    int32_t mBytesPerSeq;
    // Maximum kv cache length per sequence
    int32_t mMaxAttentionWindow;
    // Number of sink tokens.
    int32_t mSinkTokens;
    // Cyclic cache length.
    int32_t mCyclicCacheLen;
    // Bubble length.
    int32_t mBubbleLen;
    // Valid rows per sequence: 1 when the buffer holds only K or only V, 2 when it holds both.
    int32_t mValidRowsPerSeq;
    // Enable one more block to save the kv tokens
    bool mEnableOneMoreBlock;
    // Pointer to the K/V cache data
    // Shape [B, 2, S*H*D], where 2 is for K and V,
    // B is current number of sequences and
    // H is number of heads
    // S is maximum sequence length
    // D is dimension per head
    // K shape is [B, 1, H, S, D]
    // V shape is [B, 1, H, S, D]
    // NOTE: we have remapped K layout as the same of V.
    DataType* data;

    // Creates an empty buffer: all sizes are zero, data is null.
    KVLinearBuffer()
        : mMaxSeqs{0}
        , mMaxSeqLen{0}
        , mBytesPerSeq{0}
        , mMaxAttentionWindow{0}
        , mSinkTokens{0}
        , mCyclicCacheLen{0}
        , mBubbleLen{0}
        , mValidRowsPerSeq{0}
        , mEnableOneMoreBlock{false}
        , data{nullptr}
    {
    }

    // Creates a view over a contiguous kv cache.
    //   batchSize          stored as mMaxSeqs
    //   tokensPerBlock     stored as mMaxSeqLen (the whole sequence is treated as one block)
    //   sizePerToken       multiplied by tokensPerBlock to obtain mBytesPerSeq
    //   maxAttentionWindow stored as mMaxAttentionWindow
    //   sinkTokenLen       stored as mSinkTokens
    //   onlyKorV           true when the buffer holds a single row (K or V) per sequence
    //   data               stored as the cache base pointer (not owned)
    // Fails a TLLM_CHECK_WITH_INFO when the byte offsets do not fit into int32_t.
    KVLinearBuffer(int32_t batchSize, int32_t tokensPerBlock, int32_t sizePerToken, int32_t maxAttentionWindow,
        int32_t sinkTokenLen, bool onlyKorV, DataType* data)
        : mMaxSeqs(batchSize)
        , mMaxSeqLen(tokensPerBlock)
        , mBytesPerSeq(0)
        , mMaxAttentionWindow(maxAttentionWindow)
        , mSinkTokens(sinkTokenLen)
        , data(data)
    {
        // Compute the per-sequence size in 64 bits first. Multiplying in int32_t could overflow
        // (e.g. 131072 * 16384 == 2^31) before the size guard below ever sees the value.
        int64_t const bytesPerSeq = static_cast<int64_t>(tokensPerBlock) * sizePerToken;
        bool const fitsInInt32 = bytesPerSeq >= std::numeric_limits<int32_t>::min()
            && bytesPerSeq <= std::numeric_limits<int32_t>::max();
        TLLM_CHECK_WITH_INFO(fitsInInt32, "kv cache is too large for gpt_attention_plugin");
        mBytesPerSeq = static_cast<int32_t>(bytesPerSeq);
        // NOTE: pointer offset arithmetic offset is performed on int32_t (see this.getRowPtr).
        // If needed, we could do it on uint32_t or even uint64_t, but that might have performance implications
        TLLM_CHECK_WITH_INFO(
            static_cast<int64_t>(mMaxSeqs - 1) * mBytesPerSeq * 2 + mBytesPerSeq <= std::numeric_limits<int32_t>::max(),
            "kv cache is too large for gpt_attention_plugin");
        mCyclicCacheLen = mMaxAttentionWindow - mSinkTokens;
        mBubbleLen = 0;
        mValidRowsPerSeq = (onlyKorV) ? 1 : 2;
        mEnableOneMoreBlock = false;
    }

    // True when tokenIdx is below the number of sink tokens.
    __host__ __device__ [[nodiscard]] inline bool isSinkToken(int32_t tokenIdx) const
    {
        return tokenIdx < mSinkTokens;
    }

    // Returns the start of the K or V row of sequence seqIdx, typed as void**.
    // With one valid row per sequence the kvIdx term is multiplied by 0, so K and V resolve to the same row.
    // With two valid rows the V row starts mBytesPerSeq bytes after the K row.
    __host__ __device__ [[nodiscard]] inline void** getRowPtr(KVIdxType kvIdx, int32_t seqIdx) const
    {
        return reinterpret_cast<void**>(data + seqIdx * mBytesPerSeq * mValidRowsPerSeq
            + static_cast<int32_t>(kvIdx) * mBytesPerSeq * (mValidRowsPerSeq - 1));
    }

    // Maps a token index to its slot in the cache.
    // Sink tokens keep their index; all other tokens wrap around within mCyclicCacheLen slots
    // that start right after the sink tokens.
    __host__ __device__ [[nodiscard]] inline int32_t getKVTokenIdx(int32_t tokenIdx) const
    {
        if (!isSinkToken(tokenIdx))
        {
            // Apply cyclic kv cache if tokenIdx >= max_attention_window_size.
            return mSinkTokens + (tokenIdx - mSinkTokens) % mCyclicCacheLen;
        }
        return tokenIdx;
    }

    // Returns the row pointer unchanged; the token index is ignored.
    __host__ __device__ static inline void* getBlockPtr(void** pointer, int32_t /*tokenIdx*/)
    {
        return reinterpret_cast<void*>(pointer);
    }

    // Returns the start of the K row of sequence seqIdx; the token index is ignored.
    __host__ __device__ [[nodiscard]] inline void* getKBlockPtr(int32_t seqIdx, int32_t /*tokenIdx*/) const
    {
        return reinterpret_cast<void*>(getRowPtr(KVIdxType::K_IDX, seqIdx));
    }

    // Returns the start of the V row of sequence seqIdx; the token index is ignored.
    __host__ __device__ [[nodiscard]] inline void* getVBlockPtr(int32_t seqIdx, int32_t /*tokenIdx*/) const
    {
        return reinterpret_cast<void*>(getRowPtr(KVIdxType::V_IDX, seqIdx));
    }

    // Linear element index inside one row, for a row laid out as [head, token, channel]
    // with mMaxSeqLen tokens per head.
    __host__ __device__ [[nodiscard]] inline int32_t getKVLocalIdx(
        int32_t tokenIdx, int32_t headIdx, int32_t dimsPerHead, int32_t channelIdx) const
    {
        return headIdx * mMaxSeqLen * dimsPerHead + tokenIdx * dimsPerHead + channelIdx;
    }
};
} // namespace kernels
TRTLLM_NAMESPACE_END