#include #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h" #include "tensorrt_llm/kernels/gptKernels.h" #include "tensorrt_llm/kernels/kvCacheUtils.h" #include "tensorrt_llm/kernels/unfusedAttentionKernels.h" #include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/cudaStream.h" #include "tensorrt_llm/runtime/runtimeKernels.h" #include using namespace tensorrt_llm::runtime; using namespace tensorrt_llm::kernels; namespace tc = tensorrt_llm::common; namespace trk = tensorrt_llm::runtime::kernels; namespace { template void initRandom(T* ptr, size_t size, float minval, float maxval) { for (size_t i = 0; i < size; ++i) { float val = static_cast(rand()) / static_cast(RAND_MAX); val *= (maxval - minval); ptr[i] = static_cast(minval + val); } } template struct SATypeConverter { using Type = T; }; template <> struct SATypeConverter { using Type = uint16_t; }; template __global__ void applyRoPE(KVCacheBuffer kCacheRead, KVLinearBuffer kCacheWrite, const int sizePerHead, const int beam_width, const int* token_read_idxs, const int* token_write_idxs, const int* token_pos_idxs, const int* token_seq_idxs, const int* sequence_lengths, const int* input_lengths, const int rotary_embedding_dim, float rotary_embedding_base, RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, const int rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type) { // We allow only fp32/fp16/bf16 as the data types to apply rotary static_assert(sizeof(T) == 4 || sizeof(T) == 2, ""); extern __shared__ __align__(sizeof(float2)) char smem_[]; // align on largest vector type // Each thread will handle 16 bytes. constexpr int vec_size = 16u / sizeof(T); using Vec_k = typename mmha::packed_type::type; const int sizePerHeadDivX = sizePerHead / vec_size; // The position idx const int token_idx = token_seq_idxs[blockIdx.x]; const int token_read_idx = token_read_idxs[blockIdx.x]; const int token_write_idx = token_write_idxs[blockIdx.x]; const int token_pos_idx = token_pos_idxs[blockIdx.x]; // Head const int head_idx = blockIdx.y; // The batch beam idx const int batch_beam_idx = blockIdx.z; // The beam idx const int beam_idx = batch_beam_idx % beam_width; // Thread idx const int tidx = threadIdx.x; // The actual sequence length excluding the paddings. 
    // Read k.
    Vec_k k;
    T* k_cache = reinterpret_cast<T*>(kCacheRead.getKBlockPtr(batch_beam_idx, token_read_idx));
    int inBlockIdx_r = kCacheRead.getKVLocalIdx(token_read_idx, head_idx, sizePerHead, tidx * vec_size);
    k = *reinterpret_cast<Vec_k*>(&k_cache[inBlockIdx_r]);

    // Apply the position embedding.
    switch (position_embedding_type)
    {
    case PositionEmbeddingType::kROPE_GPTJ:
    {
        mmha::apply_rotary_embedding(
            k, tidx, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale, token_pos_idx);
        break;
    }
    case PositionEmbeddingType::kROPE_GPT_NEOX:
    {
        const bool do_rotary = vec_size * tidx < rotary_embedding_dim;

        T* k_smem = reinterpret_cast<T*>(smem_);

        const int half_rotary_dim = rotary_embedding_dim / 2;
        const int half_idx = (tidx * vec_size) / half_rotary_dim;
        const int intra_half_idx = (tidx * vec_size) % half_rotary_dim;
        const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts?

        if (do_rotary)
        {
            *reinterpret_cast<Vec_k*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
        }

        __syncthreads();

        const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
        constexpr int tidx_factor = vec_size / 2;

        if (do_rotary)
        {
            mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
            mmha::apply_rotary_embedding(k, transpose_idx / tidx_factor, rotary_embedding_dim, rotary_embedding_base,
                rotary_embedding_scale, token_pos_idx);
            mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
        }

        __syncthreads();

        if (do_rotary)
        {
            k = *reinterpret_cast<Vec_k*>(k_smem + half_idx * smem_pitch + intra_half_idx);
        }
        break;
    }
    }

    // Write back to the cache.
    T* kDst = reinterpret_cast<T*>(kCacheWrite.getKBlockPtr(batch_beam_idx, token_write_idx));
    int inBlockIdx_w = kCacheWrite.getKVLocalIdx(token_write_idx, head_idx, sizePerHeadDivX, tidx);
    reinterpret_cast<Vec_k*>(kDst)[inBlockIdx_w] = k;
}

template <typename T, typename KVCacheBuffer>
void invokeApplyRoPE(KVCacheBuffer kCacheRead, KVLinearBuffer kCacheWrite, const int sizePerHead, const int batch_beam,
    const int kv_head_num, const int beam_width, const int* token_read_idxs, const int* token_write_idxs,
    const int* token_pos_idxs, const int* token_seq_idxs, const int token_num, const int* sequence_lengths,
    const int* input_lengths, const int rotary_embedding_dim, float rotary_embedding_base,
    RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, const int rotary_embedding_max_positions,
    PositionEmbeddingType const position_embedding_type, cudaStream_t stream)
{
    // One block handles one K tile.
    const int vec_size = 16u / sizeof(T);
    dim3 block((sizePerHead / vec_size + 31) / 32 * 32);
    dim3 grid(token_num, kv_head_num, batch_beam);
    size_t smem_size = (position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
            ? 2 * rotary_embedding_dim * sizeof(T)
            : 0);

    applyRoPE<T, KVCacheBuffer><<<grid, block, smem_size, stream>>>(kCacheRead, kCacheWrite, sizePerHead, beam_width,
        token_read_idxs, token_write_idxs, token_pos_idxs, token_seq_idxs, sequence_lengths, input_lengths,
        rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale,
        rotary_embedding_max_positions, position_embedding_type);
}
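// Launch-shape sanity check (a worked example, not code used by the tests): with headSize = 64 and T = half,
// vec_size = 16 / 2 = 8, so 64 / 8 = 8 active lanes are rounded up to a full warp, i.e. block = dim3(32).
// The grid is (token_num, kv_head_num, batch_beam): one block per (token, head, sequence) triple.
// For kROPE_GPT_NEOX with rotary_embedding_dim = 64, the dynamic smem is 2 * 64 * sizeof(half) = 256 bytes.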
template <typename T>
class ShiftKCacheKernelTest : public ::testing::Test
{
public:
    using TensorPtr = ITensor::SharedPtr;

    void SetUp() override
    {
        mStream = std::make_shared<CudaStream>();
        mBufferManager = std::make_shared<BufferManager>(mStream);

        auto const deviceCount = tc::getDeviceCount();
        if (deviceCount == 0)
        {
            GTEST_SKIP();
        }
    }

    void TearDown() override {}

    void initData(int32_t batchSize, int32_t beamWidth, int32_t numHeads, int32_t maxAttentionWindow, int32_t headSize,
        bool pagedKvCache, int32_t maxBlocksPerSeq, int32_t tokensPerBlock, const std::vector<int32_t>& seqLengths,
        const std::vector<int32_t>& inputLengths, const std::vector<int32_t>& tokenReadIdxs,
        const std::vector<int32_t>& tokenWriteIdxs, const std::vector<int32_t>& tokenPosIdxs,
        const std::vector<int32_t>& tokenSeqIdxs)
    {
        // Allocate the buffers.
        mSeqLengthsHost = mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32);
        mSeqLengthsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32);
        mInputLengthsHost = mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32);
        mInputLengthsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32);
        mKScaleQuantOrigDevice = mBufferManager->gpu(ITensor::makeShape({1}), nvinfer1::DataType::kFLOAT);
        mTokenReadIdxsHost = mBufferManager->pinned(
            ITensor::makeShape({static_cast<SizeType>(tokenReadIdxs.size())}), nvinfer1::DataType::kINT32);
        mTokenReadIdxsDevice = mBufferManager->gpu(
            ITensor::makeShape({static_cast<SizeType>(tokenReadIdxs.size())}), nvinfer1::DataType::kINT32);
        mTokenWriteIdxsHost = mBufferManager->pinned(
            ITensor::makeShape({static_cast<SizeType>(tokenWriteIdxs.size())}), nvinfer1::DataType::kINT32);
        mTokenWriteIdxsDevice = mBufferManager->gpu(
            ITensor::makeShape({static_cast<SizeType>(tokenWriteIdxs.size())}), nvinfer1::DataType::kINT32);
        mTokenPosIdxsHost = mBufferManager->pinned(
            ITensor::makeShape({static_cast<SizeType>(tokenPosIdxs.size())}), nvinfer1::DataType::kINT32);
        mTokenPosIdxsDevice = mBufferManager->gpu(
            ITensor::makeShape({static_cast<SizeType>(tokenPosIdxs.size())}), nvinfer1::DataType::kINT32);
        mTokenSeqIdxsHost = mBufferManager->pinned(
            ITensor::makeShape({static_cast<SizeType>(tokenSeqIdxs.size())}), nvinfer1::DataType::kINT32);
        mTokenSeqIdxsDevice = mBufferManager->gpu(
            ITensor::makeShape({static_cast<SizeType>(tokenSeqIdxs.size())}), nvinfer1::DataType::kINT32);

        int32_t batchBeam = batchSize * beamWidth;
        if (pagedKvCache)
        {
            mInputDataHost = mBufferManager->pinned(
                ITensor::makeShape({batchSize, beamWidth, 2, maxBlocksPerSeq, numHeads * tokensPerBlock * headSize}),
                TRTDataType<T>::value);
            mInputDataDevice = mBufferManager->gpu(
                ITensor::makeShape({batchSize, beamWidth, 2, maxBlocksPerSeq, numHeads * tokensPerBlock * headSize}),
                TRTDataType<T>::value);
            mInputBlockPtrsHost = mBufferManager->pinned(
                ITensor::makeShape({batchSize, beamWidth, 2, maxBlocksPerSeq}), nvinfer1::DataType::kINT64);
            mInputBlockPtrsDevice = mBufferManager->gpu(
                ITensor::makeShape({batchSize, beamWidth, 2, maxBlocksPerSeq}), nvinfer1::DataType::kINT64);
        }
        else
        {
            mInputDataHost = mBufferManager->pinned(
                ITensor::makeShape({batchBeam, 2, numHeads, maxAttentionWindow, headSize}), TRTDataType<T>::value);
            mInputDataDevice = mBufferManager->gpu(
                ITensor::makeShape({batchBeam, 2, numHeads, maxAttentionWindow, headSize}), TRTDataType<T>::value);
        }
        mOutputDataHost = mBufferManager->pinned(
            ITensor::makeShape({batchBeam, 1, numHeads, maxAttentionWindow, headSize}), TRTDataType<T>::value);
        mOutputDataDevice = mBufferManager->gpu(
            ITensor::makeShape({batchBeam, 1, numHeads, maxAttentionWindow, headSize}), TRTDataType<T>::value);
        mRefOutputDataHost = mBufferManager->pinned(
            ITensor::makeShape({batchBeam, 1, numHeads, maxAttentionWindow, headSize}), TRTDataType<T>::value);
        mRefOutputDataDevice = mBufferManager->gpu(
            ITensor::makeShape({batchBeam, 1, numHeads, maxAttentionWindow, headSize}), TRTDataType<T>::value);

        // Initialize the data.
        auto inputDataHostPtr = bufferCast<T>(*mInputDataHost);
        initRandom(
            inputDataHostPtr, batchSize * beamWidth * 2 * numHeads * maxAttentionWindow * headSize, -3.0f, 3.0f);
        trk::invokeFill(*mKScaleQuantOrigDevice, float{1.0f}, *mStream);
        auto seqLengthsHostPtr = bufferCast<int32_t>(*mSeqLengthsHost);
        auto inputLengthsHostPtr = bufferCast<int32_t>(*mInputLengthsHost);
        auto tokenReadIdxsHostPtr = bufferCast<int32_t>(*mTokenReadIdxsHost);
        auto tokenWriteIdxsHostPtr = bufferCast<int32_t>(*mTokenWriteIdxsHost);
        auto tokenPosIdxsHostPtr = bufferCast<int32_t>(*mTokenPosIdxsHost);
        auto tokenSeqIdxsHostPtr = bufferCast<int32_t>(*mTokenSeqIdxsHost);
        for (SizeType bi = 0; bi < batchSize; ++bi)
        {
            seqLengthsHostPtr[bi] = seqLengths[bi];
            inputLengthsHostPtr[bi] = inputLengths[bi];
        }
        for (size_t idx = 0; idx < tokenReadIdxs.size(); ++idx)
        {
            tokenReadIdxsHostPtr[idx] = tokenReadIdxs[idx];
            tokenWriteIdxsHostPtr[idx] = tokenWriteIdxs[idx];
            tokenPosIdxsHostPtr[idx] = tokenPosIdxs[idx];
            tokenSeqIdxsHostPtr[idx] = tokenSeqIdxs[idx];
        }

        if (pagedKvCache)
        {
            // Build the block pointer table: consecutive block pointers step through one flat device allocation.
            auto inputDataDevicePtr = bufferCast<T>(*mInputDataDevice);
            auto inputBlockPtrsHostPtr = reinterpret_cast<T**>(bufferCast<int64_t>(*mInputBlockPtrsHost));
            const int32_t num_per_block = tokensPerBlock * numHeads * headSize;
            inputBlockPtrsHostPtr[0] = inputDataDevicePtr;
            for (SizeType idx = 1; idx < batchBeam * 2 * maxBlocksPerSeq; idx++)
            {
                inputBlockPtrsHostPtr[idx] = inputBlockPtrsHostPtr[idx - 1] + num_per_block;
            }
            mBufferManager->copy(*mInputBlockPtrsHost, *mInputBlockPtrsDevice);
        }
        mBufferManager->copy(*mInputDataHost, *mInputDataDevice);
        mBufferManager->copy(*mSeqLengthsHost, *mSeqLengthsDevice);
        mBufferManager->copy(*mInputLengthsHost, *mInputLengthsDevice);
        mBufferManager->copy(*mTokenReadIdxsHost, *mTokenReadIdxsDevice);
        mBufferManager->copy(*mTokenWriteIdxsHost, *mTokenWriteIdxsDevice);
        mBufferManager->copy(*mTokenPosIdxsHost, *mTokenPosIdxsDevice);
        mBufferManager->copy(*mTokenSeqIdxsHost, *mTokenSeqIdxsDevice);
    }
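    // compareResults mirrors the kernel's masking on the host: a written slot is checked only when its
    // sequence index is below tlength and it is not a context-phase token on a beam other than 0.
    // It returns the sum of absolute element-wise differences, so an exact match yields 0.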
    float compareResults(KVLinearBuffer kCacheOut, KVLinearBuffer kCacheRef, int32_t batchBeam, int32_t beamWidth,
        int32_t numHeads, int32_t headSize, int32_t validTokenNum, const int32_t* seqLengths,
        const int32_t* inputLengths, const int32_t* tokenWriteIdxs, const int32_t* tokenSeqIdxs)
    {
        mBufferManager->copy(*mOutputDataDevice, *mOutputDataHost);
        mBufferManager->copy(*mRefOutputDataDevice, *mRefOutputDataHost);

        // Synchronize.
        mStream->synchronize();

        // Compare the results.
        float tot_diff = 0.f;
        for (SizeType bi = 0; bi < batchBeam; ++bi)
        {
            const int tlength = seqLengths[bi] - 1;
            const int inlength = inputLengths[bi];
            const int beam_idx = bi % beamWidth;
            for (SizeType hi = 0; hi < numHeads; ++hi)
            {
                for (SizeType ti = 0; ti < validTokenNum; ++ti)
                {
                    const int token_seq_idx = tokenSeqIdxs[ti];
                    const int token_write_idx = tokenWriteIdxs[ti];
                    const bool valid_seq = token_seq_idx < tlength && !(token_seq_idx < inlength && beam_idx > 0);
                    if (!valid_seq)
                    {
                        continue;
                    }
                    for (SizeType ci = 0; ci < headSize; ++ci)
                    {
                        T* kRes = reinterpret_cast<T*>(kCacheOut.getKBlockPtr(bi, token_write_idx));
                        int resIdx = kCacheOut.getKVLocalIdx(token_write_idx, hi, headSize, ci);
                        T* kRef = reinterpret_cast<T*>(kCacheRef.getKBlockPtr(bi, token_write_idx));
                        int refIdx = kCacheRef.getKVLocalIdx(token_write_idx, hi, headSize, ci);
                        float res = static_cast<float>(kRes[resIdx]);
                        float ref = static_cast<float>(kRef[refIdx]);
                        float diff = std::abs(res - ref);
                        tot_diff += diff;
                    }
                }
            }
        }
        return tot_diff;
    }
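    // runTest drives both paths over the same inputs: invokeShiftKCache (the kernel under test) receives only
    // the cache geometry and lengths, so it must derive the read/write/position indices on the device, whereas
    // the applyRoPE reference consumes the index vectors precomputed on the host in each test. Both paths are
    // expected to produce bit-identical results, which is why the comparison uses EXPECT_EQ(diff, 0) rather
    // than a tolerance.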
    void runTest(int32_t batchSize, int32_t beamWidth, int32_t numHeads, int32_t headSize, int32_t maxAttentionWindow,
        int32_t sinkTokenLength, int32_t pastKCacheLength, int32_t validTokenNum, bool pagedKvCache,
        int32_t maxBlocksPerSeq, int32_t tokensPerBlock, int rotaryEmbeddingDim, float rotaryEmbeddingBase,
        RotaryScalingType const rotaryScaleType, float rotaryEmbeddingScale, const int rotaryEmbeddingMaxPositions,
        PositionEmbeddingType const positionEmbeddingType)
    {
        // Synchronize.
        mStream->synchronize();

        // Get the k caches.
        const int32_t batchBeam = batchSize * beamWidth;
        const auto elemSize = sizeof(T);
        KVLinearBuffer shiftKCacheBuffer = KVLinearBuffer(
            batchBeam, 1, maxAttentionWindow, numHeads * headSize * elemSize, maxAttentionWindow, sinkTokenLength, true);
        shiftKCacheBuffer.data = reinterpret_cast<int8_t*>(bufferCast<T>(*mOutputDataDevice));
        KVLinearBuffer refShiftKCacheBuffer = KVLinearBuffer(
            batchBeam, 1, maxAttentionWindow, numHeads * headSize * elemSize, maxAttentionWindow, sinkTokenLength, true);
        refShiftKCacheBuffer.data = reinterpret_cast<int8_t*>(bufferCast<T>(*mRefOutputDataDevice));

        // Run the shift-k-cache kernel.
        const KvCacheDataType kv_cache_type = KvCacheDataType::BASE;
        using DataType = typename SATypeConverter<T>::Type;
        if (pagedKvCache)
        {
            KVBlockArray kvCacheBuffer = KVBlockArray(batchBeam, maxBlocksPerSeq, tokensPerBlock,
                numHeads * headSize * elemSize, maxAttentionWindow, sinkTokenLength, false);
            kvCacheBuffer.data = reinterpret_cast<int64_t*>(bufferCast<int64_t>(*mInputBlockPtrsDevice));
            invokeShiftKCache<DataType, KVBlockArray>(kvCacheBuffer, shiftKCacheBuffer, kv_cache_type, headSize,
                pastKCacheLength, batchBeam, numHeads, beamWidth, maxAttentionWindow, sinkTokenLength,
                bufferCast<float>(*mKScaleQuantOrigDevice), bufferCast<int32_t>(*mSeqLengthsDevice),
                bufferCast<int32_t>(*mInputLengthsDevice), rotaryEmbeddingDim, rotaryEmbeddingBase, rotaryScaleType,
                rotaryEmbeddingScale, rotaryEmbeddingMaxPositions, positionEmbeddingType, mStream->get());
            // Run the reference.
            invokeApplyRoPE<DataType, KVBlockArray>(kvCacheBuffer, refShiftKCacheBuffer, headSize, batchBeam, numHeads,
                beamWidth, bufferCast<int32_t>(*mTokenReadIdxsDevice), bufferCast<int32_t>(*mTokenWriteIdxsDevice),
                bufferCast<int32_t>(*mTokenPosIdxsDevice), bufferCast<int32_t>(*mTokenSeqIdxsDevice), validTokenNum,
                bufferCast<int32_t>(*mSeqLengthsDevice), bufferCast<int32_t>(*mInputLengthsDevice), rotaryEmbeddingDim,
                rotaryEmbeddingBase, rotaryScaleType, rotaryEmbeddingScale, rotaryEmbeddingMaxPositions,
                positionEmbeddingType, mStream->get());
        }
        else
        {
            KVLinearBuffer kvCacheBuffer = KVLinearBuffer(batchBeam, 1, maxAttentionWindow,
                numHeads * headSize * elemSize, maxAttentionWindow, sinkTokenLength, false);
            kvCacheBuffer.data = reinterpret_cast<int8_t*>(bufferCast<T>(*mInputDataDevice));
            invokeShiftKCache<DataType, KVLinearBuffer>(kvCacheBuffer, shiftKCacheBuffer, kv_cache_type, headSize,
                pastKCacheLength, batchBeam, numHeads, beamWidth, maxAttentionWindow, sinkTokenLength,
                bufferCast<float>(*mKScaleQuantOrigDevice), bufferCast<int32_t>(*mSeqLengthsDevice),
                bufferCast<int32_t>(*mInputLengthsDevice), rotaryEmbeddingDim, rotaryEmbeddingBase, rotaryScaleType,
                rotaryEmbeddingScale, rotaryEmbeddingMaxPositions, positionEmbeddingType, mStream->get());
            // Run the reference.
            invokeApplyRoPE<DataType, KVLinearBuffer>(kvCacheBuffer, refShiftKCacheBuffer, headSize, batchBeam,
                numHeads, beamWidth, bufferCast<int32_t>(*mTokenReadIdxsDevice),
                bufferCast<int32_t>(*mTokenWriteIdxsDevice), bufferCast<int32_t>(*mTokenPosIdxsDevice),
                bufferCast<int32_t>(*mTokenSeqIdxsDevice), validTokenNum, bufferCast<int32_t>(*mSeqLengthsDevice),
                bufferCast<int32_t>(*mInputLengthsDevice), rotaryEmbeddingDim, rotaryEmbeddingBase, rotaryScaleType,
                rotaryEmbeddingScale, rotaryEmbeddingMaxPositions, positionEmbeddingType, mStream->get());
        }

        // Synchronize.
        mStream->synchronize();

        // Compare the outputs on the host.
        shiftKCacheBuffer.data = reinterpret_cast<int8_t*>(bufferCast<T>(*mOutputDataHost));
        refShiftKCacheBuffer.data = reinterpret_cast<int8_t*>(bufferCast<T>(*mRefOutputDataHost));
        float diff = compareResults(shiftKCacheBuffer, refShiftKCacheBuffer, batchBeam, beamWidth, numHeads, headSize,
            validTokenNum, bufferCast<int32_t>(*mSeqLengthsHost), bufferCast<int32_t>(*mInputLengthsHost),
            bufferCast<int32_t>(*mTokenWriteIdxsHost), bufferCast<int32_t>(*mTokenSeqIdxsHost));
        EXPECT_EQ(diff, 0);
    }
protected:
    std::shared_ptr<BufferManager> mBufferManager;
    std::shared_ptr<CudaStream> mStream;

    TensorPtr mSeqLengthsHost;
    TensorPtr mSeqLengthsDevice;
    TensorPtr mInputLengthsHost;
    TensorPtr mInputLengthsDevice;
    TensorPtr mInputBlockPtrsHost;
    TensorPtr mInputBlockPtrsDevice;
    TensorPtr mKScaleQuantOrigDevice;
    TensorPtr mTokenReadIdxsHost;
    TensorPtr mTokenReadIdxsDevice;
    TensorPtr mTokenWriteIdxsHost;
    TensorPtr mTokenWriteIdxsDevice;
    TensorPtr mTokenPosIdxsHost;
    TensorPtr mTokenPosIdxsDevice;
    TensorPtr mTokenSeqIdxsHost;
    TensorPtr mTokenSeqIdxsDevice;
    TensorPtr mInputDataHost;
    TensorPtr mInputDataDevice;
    TensorPtr mOutputDataHost;
    TensorPtr mOutputDataDevice;
    TensorPtr mRefOutputDataHost;
    TensorPtr mRefOutputDataDevice;
};

typedef testing::Types<float, half> FloatAndHalfTypes;

TYPED_TEST_SUITE(ShiftKCacheKernelTest, FloatAndHalfTypes);

TYPED_TEST(ShiftKCacheKernelTest, UncyclicShiftKCache)
{
    auto constexpr batchSize = 1;
    auto constexpr beamWidth = 1;
    auto constexpr numHeads = 2;
    auto constexpr headSize = 64;
    auto constexpr maxAttentionWindow = 32;
    auto constexpr sinkTokenLength = 0;
    auto constexpr timestep = 17;
    auto constexpr pastKCacheLength = timestep - 1;
    auto constexpr validTokenNum = std::min(pastKCacheLength, maxAttentionWindow);
    auto constexpr rotaryEmbeddingDim = headSize;
    auto constexpr rotaryEmbeddingBase = 10000.0;
    auto constexpr rotaryScaleType = RotaryScalingType::kNONE;
    auto constexpr rotaryEmbeddingScale = 1.0;
    auto constexpr rotaryEmbeddingMaxPositions = 4096;
    auto constexpr positionEmbeddingType = PositionEmbeddingType::kROPE_GPT_NEOX;

    auto pagedKvCaches = std::vector<bool>{false, true};
    for (auto pagedKvCache : pagedKvCaches)
    {
        const SizeType maxBlocksPerSeq = (pagedKvCache) ? 2 : 0;
        const SizeType tokensPerBlock = (pagedKvCache) ? 16 : 0;
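        // No wraparound here: timestep (17) never exceeds maxAttentionWindow (32), so read slot, write slot,
        // RoPE position and sequence index are all the identity mapping 0..15.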
        // Include one more token for the current time step in seqLengths.
        std::vector<int32_t> seqLengths = {timestep};
        std::vector<int32_t> inputLengths = {8};
        std::vector<int32_t> tokenReadIdxs;
        std::vector<int32_t> tokenWriteIdxs;
        std::vector<int32_t> tokenPosIdxs;
        std::vector<int32_t> tokenSeqIdxs;
        for (SizeType idx = 0; idx < timestep - 1; ++idx)
        {
            tokenReadIdxs.push_back(idx);
            tokenWriteIdxs.push_back(idx);
            tokenPosIdxs.push_back(idx);
            tokenSeqIdxs.push_back(idx);
        }

        this->initData(batchSize, beamWidth, numHeads, maxAttentionWindow, headSize, pagedKvCache, maxBlocksPerSeq,
            tokensPerBlock, seqLengths, inputLengths, tokenReadIdxs, tokenWriteIdxs, tokenPosIdxs, tokenSeqIdxs);
        this->runTest(batchSize, beamWidth, numHeads, headSize, maxAttentionWindow, sinkTokenLength, pastKCacheLength,
            validTokenNum, pagedKvCache, maxBlocksPerSeq, tokensPerBlock, rotaryEmbeddingDim, rotaryEmbeddingBase,
            rotaryScaleType, rotaryEmbeddingScale, rotaryEmbeddingMaxPositions, positionEmbeddingType);
    }
}

TYPED_TEST(ShiftKCacheKernelTest, CyclicShiftKCacheSimple)
{
    auto constexpr batchSize = 1;
    auto constexpr beamWidth = 1;
    auto constexpr numHeads = 2;
    auto constexpr headSize = 64;
    auto constexpr maxAttentionWindow = 32;
    auto constexpr sinkTokenLength = 0;
    auto constexpr timestep = 45;
    auto constexpr pastKCacheLength = timestep - 1;
    auto constexpr validTokenNum = std::min(pastKCacheLength, maxAttentionWindow);
    auto constexpr rotaryEmbeddingDim = headSize;
    auto constexpr rotaryEmbeddingBase = 10000.0;
    auto constexpr rotaryScaleType = RotaryScalingType::kNONE;
    auto constexpr rotaryEmbeddingScale = 1.0;
    auto constexpr rotaryEmbeddingMaxPositions = 4096;
    auto constexpr positionEmbeddingType = PositionEmbeddingType::kROPE_GPT_NEOX;

    auto pagedKvCaches = std::vector<bool>{false, true};
    for (auto pagedKvCache : pagedKvCaches)
    {
        const SizeType maxBlocksPerSeq = (pagedKvCache) ? 2 : 0;
        const SizeType tokensPerBlock = (pagedKvCache) ? 16 : 0;
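        // The cache has wrapped: with timestep = 45 and a window of 32, only tokens 12..43 survive. Token idx
        // lives in slot idx % 32 and is re-rotated to position idx - 12, i.e. positions are renumbered
        // relative to the start of the attention window.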
        // Include one more token for the current time step in seqLengths.
        std::vector<int32_t> seqLengths = {timestep};
        std::vector<int32_t> inputLengths = {8};
        std::vector<int32_t> tokenReadIdxs;
        std::vector<int32_t> tokenWriteIdxs;
        std::vector<int32_t> tokenPosIdxs;
        std::vector<int32_t> tokenSeqIdxs;
        for (SizeType idx = pastKCacheLength - maxAttentionWindow; idx < pastKCacheLength; ++idx)
        {
            tokenReadIdxs.push_back(idx % maxAttentionWindow);
            tokenWriteIdxs.push_back(idx % maxAttentionWindow);
            tokenPosIdxs.push_back(idx - pastKCacheLength + maxAttentionWindow);
            tokenSeqIdxs.push_back(idx);
        }

        this->initData(batchSize, beamWidth, numHeads, maxAttentionWindow, headSize, pagedKvCache, maxBlocksPerSeq,
            tokensPerBlock, seqLengths, inputLengths, tokenReadIdxs, tokenWriteIdxs, tokenPosIdxs, tokenSeqIdxs);
        this->runTest(batchSize, beamWidth, numHeads, headSize, maxAttentionWindow, sinkTokenLength, pastKCacheLength,
            validTokenNum, pagedKvCache, maxBlocksPerSeq, tokensPerBlock, rotaryEmbeddingDim, rotaryEmbeddingBase,
            rotaryScaleType, rotaryEmbeddingScale, rotaryEmbeddingMaxPositions, positionEmbeddingType);
    }
}

TYPED_TEST(ShiftKCacheKernelTest, CyclicShiftKCacheSink)
{
    auto constexpr batchSize = 1;
    auto constexpr beamWidth = 1;
    auto constexpr numHeads = 2;
    auto constexpr headSize = 64;
    auto constexpr maxAttentionWindow = 32;
    auto constexpr sinkTokenLength = 4;
    auto constexpr timestep = 67;
    auto constexpr pastKCacheLength = timestep - 1;
    auto constexpr validTokenNum = std::min(pastKCacheLength, maxAttentionWindow);
    auto constexpr rotaryEmbeddingDim = headSize;
    auto constexpr rotaryEmbeddingBase = 10000.0;
    auto constexpr rotaryScaleType = RotaryScalingType::kNONE;
    auto constexpr rotaryEmbeddingScale = 1.0;
    auto constexpr rotaryEmbeddingMaxPositions = 4096;
    auto constexpr positionEmbeddingType = PositionEmbeddingType::kROPE_GPT_NEOX;

    auto pagedKvCaches = std::vector<bool>{false, true};
    for (auto pagedKvCache : pagedKvCaches)
    {
        const SizeType maxBlocksPerSeq = (pagedKvCache) ? 3 : 0;
        const SizeType tokensPerBlock = (pagedKvCache) ? 16 : 1;
        const SizeType sinkTokensInLastBlock = sinkTokenLength % tokensPerBlock;
        const SizeType bubbleLength = sinkTokensInLastBlock == 0 ? 0 : tokensPerBlock - sinkTokensInLastBlock;
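        // The first sinkTokenLength (4) tokens are pinned: slots, positions and sequence indices 0..3. The
        // remaining 28-slot cyclic section holds tokens 38..65, re-rotated to positions 4..31 (pos = idx - 34).
        // In the paged case the sink tokens only partially fill a 16-token block, leaving a bubble of
        // 16 - 4 = 12 unused slots between the sink and cyclic sections on the read side; the shifted output
        // cache is compact, so the write indices skip the bubble.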
        // Include one more token for the current time step in seqLengths.
        std::vector<int32_t> seqLengths = {timestep};
        std::vector<int32_t> inputLengths = {8};
        std::vector<int32_t> tokenReadIdxs = {0, 1, 2, 3};
        std::vector<int32_t> tokenWriteIdxs = {0, 1, 2, 3};
        std::vector<int32_t> tokenPosIdxs = {0, 1, 2, 3};
        std::vector<int32_t> tokenSeqIdxs = {0, 1, 2, 3};
        const int cyclicLength = maxAttentionWindow - sinkTokenLength;
        for (SizeType idx = pastKCacheLength - cyclicLength; idx < pastKCacheLength; ++idx)
        {
            tokenReadIdxs.push_back(sinkTokenLength + bubbleLength + (idx - sinkTokenLength) % cyclicLength);
            tokenWriteIdxs.push_back(sinkTokenLength + (idx - sinkTokenLength) % cyclicLength);
            tokenPosIdxs.push_back(sinkTokenLength + idx - pastKCacheLength + cyclicLength);
            tokenSeqIdxs.push_back(idx);
        }

        this->initData(batchSize, beamWidth, numHeads, maxAttentionWindow, headSize, pagedKvCache, maxBlocksPerSeq,
            tokensPerBlock, seqLengths, inputLengths, tokenReadIdxs, tokenWriteIdxs, tokenPosIdxs, tokenSeqIdxs);
        this->runTest(batchSize, beamWidth, numHeads, headSize, maxAttentionWindow, sinkTokenLength, pastKCacheLength,
            validTokenNum, pagedKvCache, maxBlocksPerSeq, tokensPerBlock, rotaryEmbeddingDim, rotaryEmbeddingBase,
            rotaryScaleType, rotaryEmbeddingScale, rotaryEmbeddingMaxPositions, positionEmbeddingType);
    }
}

TYPED_TEST(ShiftKCacheKernelTest, CyclicShiftKCacheSinkOneMoreBlock)
{
    auto constexpr batchSize = 1;
    auto constexpr beamWidth = 2;
    auto constexpr numHeads = 2;
    auto constexpr headSize = 64;
    auto constexpr maxAttentionWindow = 32;
    auto constexpr sinkTokenLength = 4;
    auto constexpr timestep = 67;
    auto constexpr pastKCacheLength = timestep - 1;
    auto constexpr validTokenNum = std::min(pastKCacheLength, maxAttentionWindow);
    auto constexpr rotaryEmbeddingDim = headSize;
    auto constexpr rotaryEmbeddingBase = 10000.0;
    auto constexpr rotaryScaleType = RotaryScalingType::kNONE;
    auto constexpr rotaryEmbeddingScale = 1.0;
    auto constexpr rotaryEmbeddingMaxPositions = 4096;
    auto constexpr positionEmbeddingType = PositionEmbeddingType::kROPE_GPT_NEOX;
    auto constexpr pagedKvCache = true;
    auto constexpr maxBlocksPerSeq = 4;
    auto constexpr tokensPerBlock = 16;
    auto constexpr sinkTokensInLastBlock = sinkTokenLength % tokensPerBlock;
    auto constexpr bubbleLength = sinkTokensInLastBlock == 0 ? 0 : tokensPerBlock - sinkTokensInLastBlock;
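    // Same sink setup as above, but with beamWidth = 2 and a paged cache holding one extra block
    // (maxBlocksPerSeq = 4): reads wrap with period cyclicLength + tokensPerBlock = 44, while the compact
    // shifted cache still wraps writes at cyclicLength = 28 (assumption: the extra block keeps recently
    // overwritten tokens addressable during beam search, which beamWidth = 2 exercises through the
    // beam_idx > 0 masking in the kernels).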
    // Include one more token for the current time step in seqLengths.
    std::vector<int32_t> seqLengths = {timestep};
    std::vector<int32_t> inputLengths = {8};
    std::vector<int32_t> tokenReadIdxs = {0, 1, 2, 3};
    std::vector<int32_t> tokenWriteIdxs = {0, 1, 2, 3};
    std::vector<int32_t> tokenPosIdxs = {0, 1, 2, 3};
    std::vector<int32_t> tokenSeqIdxs = {0, 1, 2, 3};
    auto constexpr cyclicLength = maxAttentionWindow - sinkTokenLength;
    auto constexpr rCyclicLength = maxAttentionWindow - sinkTokenLength + tokensPerBlock;
    auto constexpr wCyclicLength = cyclicLength;
    for (SizeType idx = pastKCacheLength - cyclicLength; idx < pastKCacheLength; ++idx)
    {
        tokenReadIdxs.push_back(sinkTokenLength + bubbleLength + (idx - sinkTokenLength) % rCyclicLength);
        tokenWriteIdxs.push_back(sinkTokenLength + (idx - sinkTokenLength) % wCyclicLength);
        tokenPosIdxs.push_back(sinkTokenLength + idx - pastKCacheLength + cyclicLength);
        tokenSeqIdxs.push_back(idx);
    }

    this->initData(batchSize, beamWidth, numHeads, maxAttentionWindow, headSize, pagedKvCache, maxBlocksPerSeq,
        tokensPerBlock, seqLengths, inputLengths, tokenReadIdxs, tokenWriteIdxs, tokenPosIdxs, tokenSeqIdxs);
    this->runTest(batchSize, beamWidth, numHeads, headSize, maxAttentionWindow, sinkTokenLength, pastKCacheLength,
        validTokenNum, pagedKvCache, maxBlocksPerSeq, tokensPerBlock, rotaryEmbeddingDim, rotaryEmbeddingBase,
        rotaryScaleType, rotaryEmbeddingScale, rotaryEmbeddingMaxPositions, positionEmbeddingType);
}

} // end of namespace