/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/assert.h"

#include <cmath>
#include <cstdint>

namespace tensorrt_llm
{
namespace kernels
{

// Internal enum for K and V cache indexing
enum class KVIdxType : int32_t
{
    K_IDX = 0,
    V_IDX = 1
};

struct KVBlockArray
{
    // Struct operates on paged kv cache, providing
    // functions for accessing blocks in the K and V caches
    // and elements inside these blocks.

    // Max number of blocks per sequence
    int32_t mMaxBlocksPerSeq;
    // Current number of sequences
    int32_t mMaxSeqs;
    // Number of tokens per block. It must be a power of 2.
    int32_t mTokensPerBlock;
    // Exponent of the number of tokens per block with base 2.
    // E.g. for mTokensPerBlock 64, mTokensPerBlockLog2 equals 6.
    int32_t mTokensPerBlockLog2;
    // Table mapping logical block idx to the data pointer of the k/v cache block pool.
    // Shape [B, W, 2, M], where 2 is the table for K and V,
    // B is current number of sequences,
    // W is beam width,
    // M is max number of blocks per sequence.
    // int64_t reinterpreted to void* pointing to the KV cache data.
    int64_t* data;

    KVBlockArray() {}

    KVBlockArray(int32_t batchSize, int32_t maxBlocksPerSeq, int32_t tokensPerBlock, int32_t sizePerToken)
        : mMaxSeqs(batchSize)
        , mMaxBlocksPerSeq(maxBlocksPerSeq)
        , mTokensPerBlock(tokensPerBlock)
    {
        const float tokensPerBlockSeqLog2 = log2(mTokensPerBlock);
        TLLM_CHECK_WITH_INFO(
            ceil(tokensPerBlockSeqLog2) == floor(tokensPerBlockSeqLog2), "tokensPerBlock must be power of 2");
        mTokensPerBlockLog2 = static_cast<int32_t>(tokensPerBlockSeqLog2);
    }

    __host__ __device__ inline void** getRowPtr(KVIdxType kvIdx, int32_t seqIdx)
    {
        // Returns pointer to the array of pointers to K or V cache blocks for one specific sequence seqIdx.
        // seqIdx is in range [0; B)
        return reinterpret_cast<void**>(
            data + seqIdx * mMaxBlocksPerSeq * 2 + static_cast<int32_t>(kvIdx) * mMaxBlocksPerSeq);
    }

    __host__ __device__ inline void* getBlockPtr(void** pointer, int32_t tokenIdx)
    {
        return pointer[tokenIdx >> mTokensPerBlockLog2];
    }

    __host__ __device__ inline void* getBlockPtr(int32_t seqIdx, int32_t tokenIdx, KVIdxType kvIdx)
    {
        return getBlockPtr(getRowPtr(kvIdx, seqIdx), tokenIdx);
    }

    __host__ __device__ inline void* getKBlockPtr(int32_t seqIdx, int32_t tokenIdx)
    {
        return getBlockPtr(seqIdx, tokenIdx, KVIdxType::K_IDX);
    }

    __host__ __device__ inline void* getVBlockPtr(int32_t seqIdx, int32_t tokenIdx)
    {
        return getBlockPtr(seqIdx, tokenIdx, KVIdxType::V_IDX);
    }

    __host__ __device__ inline int32_t getLocalIdx(int32_t globalIdx)
    {
        return globalIdx & ((1 << mTokensPerBlockLog2) - 1);
    }

    __host__ __device__ inline int32_t getKVLocalIdx(
        int32_t globalTokenIdx, int32_t headIdx, int32_t dimsPerHead, int32_t channelIdx)
    {
        // For K or V, the hidden dimension per head is *not* decomposed. The layout of each block of K or V is:
        // [numHeads, tokensPerBlock, hiddenSizePerHead].
        // This member function computes the corresponding linear index.
        // NOTE: the K layout has been remapped to match the V layout.
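        // The linear index inside the block is therefore:
        //   headIdx * mTokensPerBlock * dimsPerHead + (globalTokenIdx % mTokensPerBlock) * dimsPerHead + channelIdx
        // where the modulo is computed by getLocalIdx as a bit mask, since mTokensPerBlock is a power of 2.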
        return headIdx * mTokensPerBlock * dimsPerHead + getLocalIdx(globalTokenIdx) * dimsPerHead + channelIdx;
    }
};

struct KVLinearBuffer
{
    // Struct operates on contiguous kv cache, providing
    // functions for accessing specific elements in the K and V caches.

    // Current number of sequences
    int32_t mMaxSeqs;
    // Max sequence length
    int32_t mMaxSeqLen;
    // Bytes per sequence (S*H*D*sizeof(DataType))
    int32_t mBytesPerSeq;
    // Pointer to the K/V cache data.
    // Shape [B, 2, S*H*D], where 2 is for K and V,
    // B is current number of sequences,
    // H is number of heads,
    // S is maximum sequence length,
    // D is dimension per head.
    // K shape is [B, 1, H, S, D]
    // V shape is [B, 1, H, S, D]
    // NOTE: the K layout has been remapped to match the V layout.
    int8_t* data;

    KVLinearBuffer() {}

    KVLinearBuffer(int32_t batchSize, int32_t maxBlocksPerSeq, int32_t tokensPerBlock, int32_t sizePerToken)
        : mMaxSeqs(batchSize)
        , mMaxSeqLen(tokensPerBlock)
        , mBytesPerSeq(tokensPerBlock * sizePerToken)
    {
    }

    __host__ __device__ inline void** getRowPtr(KVIdxType kvIdx, int32_t seqIdx)
    {
        return reinterpret_cast<void**>(
            data + seqIdx * mBytesPerSeq * 2 + static_cast<int32_t>(kvIdx) * mBytesPerSeq);
    }

    __host__ __device__ inline void* getBlockPtr(void** pointer, int32_t tokenIdx)
    {
        return reinterpret_cast<void*>(pointer);
    }

    __host__ __device__ inline void* getKBlockPtr(int32_t seqIdx, int32_t /*tokenIdx*/)
    {
        return reinterpret_cast<void*>(getRowPtr(KVIdxType::K_IDX, seqIdx));
    }

    __host__ __device__ inline void* getVBlockPtr(int32_t seqIdx, int32_t /*tokenIdx*/)
    {
        return reinterpret_cast<void*>(getRowPtr(KVIdxType::V_IDX, seqIdx));
    }

    __host__ __device__ inline int32_t getKVLocalIdx(
        int32_t tokenIdx, int32_t headIdx, int32_t dimsPerHead, int32_t channelIdx)
    {
        return headIdx * mMaxSeqLen * dimsPerHead + tokenIdx * dimsPerHead + channelIdx;
    }
};

} // namespace kernels
} // namespace tensorrt_llm
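
// Illustrative usage sketch (not part of this header; names and the element type T are assumptions for
// documentation only): resolving the address of one K-cache element for (seqIdx, tokenIdx, headIdx, channelIdx)
// with a paged KVBlockArray whose block-pointer table in `data` has already been populated.
//
//   template <typename T>
//   __device__ inline T* exampleKElementPtr(tensorrt_llm::kernels::KVBlockArray kvCache, int32_t seqIdx,
//       int32_t tokenIdx, int32_t headIdx, int32_t dimsPerHead, int32_t channelIdx)
//   {
//       // Look up the block that holds tokenIdx for this sequence in the K table.
//       T* kBlock = reinterpret_cast<T*>(kvCache.getKBlockPtr(seqIdx, tokenIdx));
//       // Compute the element offset inside that block ([numHeads, tokensPerBlock, dimsPerHead] layout).
//       int32_t const localIdx = kvCache.getKVLocalIdx(tokenIdx, headIdx, dimsPerHead, channelIdx);
//       return kBlock + localIdx;
//   }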