TensorRT-LLMs/cpp/tensorrt_llm/kernels/kvCacheUtils.h
2023-09-20 00:29:41 -07:00

174 lines
5.9 KiB
C++

/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/assert.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
namespace tensorrt_llm
{
namespace kernels
{
// Internal for K and V cache indexing
enum class KVIdxType : int32_t
{
K_IDX = 0,
V_IDX = 1
};
struct KVBlockArray
{
// Struct operates on paged kv cache providing
// functions for accessing blocks of in K and V caches
// and elements inside these blocks
// Max number of blocks per sequence
int32_t mMaxBlocksPerSeq;
// Current number of sequences
int32_t mMaxSeqs;
// Number of tokens. It must be power of 2.
int32_t mTokensPerBlock;
// Exponent of number of tokens with base 2.
// E.g. for mTokensPerBlock 64, mTokensPerBlockLog2 equals to 6
int32_t mTokensPerBlockLog2;
// Table maps logical block idx to the data pointer of k/v cache block pool
// Shape [B, W, 2, M], where 2 is table for K and V,
// B is current number of sequences
// W is beam width
// M is Max number of blocks per sequence
// int64_t reinterpred to void* pointing to the KV cache data
int64_t* data;
KVBlockArray() {}
KVBlockArray(int32_t batchSize, int32_t maxBlocksPerSeq, int32_t tokensPerBlock, int32_t sizePerToken)
: mMaxSeqs(batchSize)
, mMaxBlocksPerSeq(maxBlocksPerSeq)
, mTokensPerBlock(tokensPerBlock)
{
const float tokensPerBlockSeqLog2 = log2(mTokensPerBlock);
TLLM_CHECK_WITH_INFO(
ceil(tokensPerBlockSeqLog2) == floor(tokensPerBlockSeqLog2), "tokensPerBlock must be power of 2");
mTokensPerBlockLog2 = static_cast<int>(tokensPerBlockSeqLog2);
}
__host__ __device__ inline void** getRowPtr(KVIdxType kvIdx, int32_t seqIdx)
{
// Returns pointer to array of pointers to K or V cache for one specific sequence seqIdx.
// seqIdx is in range [0; B]
return reinterpret_cast<void**>(
data + seqIdx * mMaxBlocksPerSeq * 2 + static_cast<int32_t>(kvIdx) * mMaxBlocksPerSeq);
}
__host__ __device__ inline void* getBlockPtr(void** pointer, int32_t tokenIdx)
{
return pointer[tokenIdx >> mTokensPerBlockLog2];
}
__host__ __device__ inline void* getBlockPtr(int32_t seqIdx, int32_t tokenIdx, KVIdxType kvIdx)
{
return getBlockPtr(getRowPtr(kvIdx, seqIdx), tokenIdx);
}
__host__ __device__ inline void* getKBlockPtr(int32_t seqIdx, int32_t tokenIdx)
{
return getBlockPtr(seqIdx, tokenIdx, KVIdxType::K_IDX);
}
__host__ __device__ inline void* getVBlockPtr(int32_t seqIdx, int32_t tokenIdx)
{
return getBlockPtr(seqIdx, tokenIdx, KVIdxType::V_IDX);
}
__host__ __device__ inline int32_t getLocalIdx(int32_t globalIdx)
{
return globalIdx & ((1 << mTokensPerBlockLog2) - 1);
}
__host__ __device__ inline int32_t getKVLocalIdx(
int32_t globalTokenIdx, int32_t headIdx, int32_t dimsPerHead, int32_t channelIdx)
{
// For K or V, the hidden dimension per head is *not* decomposed. The layout of each block of K or V is:
// [numHeads, tokensPerBlock, hiddenSizePerHead].
// This member function computes the corresponding linear index.
// NOTE: we have remapped K layout as the same of V.
return headIdx * mTokensPerBlock * dimsPerHead + getLocalIdx(globalTokenIdx) * dimsPerHead + channelIdx;
}
};
struct KVLinearBuffer
{
// Struct operates on contiguous kv cache providing
// functions for accessing specific elements in K and V caches
// Current number of sequences
int32_t mMaxSeqs;
// Max sequence length
int32_t mMaxSeqLen;
// Bytes per sequence (H*D*M_S*sizeof(DataType))
int32_t mBytesPerSeq;
// Pointer to the of K/V cache data
// Shape [B, 2, S*H*D], where 2 is for K and V,
// B is current number of sequences and
// H is number of heads
// S is maximum sequence length
// D is dimension per head
// K shape is [B, 1, H, S, D]
// V shape is [B, 1, H, S, D]
// NOTE: we have remapped K layout as the same of V.
int8_t* data;
KVLinearBuffer() {}
KVLinearBuffer(int32_t batchSize, int32_t maxBlocksPerSeq, int32_t tokensPerBlock, int32_t sizePerToken)
: mMaxSeqs(batchSize)
, mMaxSeqLen(tokensPerBlock)
, mBytesPerSeq(tokensPerBlock * sizePerToken)
{
}
__host__ __device__ inline void** getRowPtr(KVIdxType kvIdx, int32_t seqIdx)
{
return reinterpret_cast<void**>(data + seqIdx * mBytesPerSeq * 2 + static_cast<int32_t>(kvIdx) * mBytesPerSeq);
}
__host__ __device__ inline void* getBlockPtr(void** pointer, int32_t tokenIdx)
{
return reinterpret_cast<void*>(pointer);
}
__host__ __device__ inline void* getKBlockPtr(int32_t seqIdx, int32_t /*tokenIdx*/)
{
return reinterpret_cast<void*>(getRowPtr(KVIdxType::K_IDX, seqIdx));
}
__host__ __device__ inline void* getVBlockPtr(int32_t seqIdx, int32_t /*tokenIdx*/)
{
return reinterpret_cast<void*>(getRowPtr(KVIdxType::V_IDX, seqIdx));
}
__host__ __device__ inline int32_t getKVLocalIdx(
int32_t tokenIdx, int32_t headIdx, int32_t dimsPerHead, int32_t channelIdx)
{
return headIdx * mMaxSeqLen * dimsPerHead + tokenIdx * dimsPerHead + channelIdx;
}
};
} // namespace kernels
} // namespace tensorrt_llm