diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 9a48b3391e..d81be0d1e9 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -642,7 +642,8 @@ public: void startScheduling(); //! \brief Assign blocks for new sequence. Try to reuse blocks. - void addSequence( + //! \return The number of tokens that were matched/prepopulated from cache (prepopulatedPromptLen) + [[nodiscard]] SizeType32 addSequence( GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest); //! \brief Assign blocks for new sequence. Does not try to reuse blocks. @@ -1091,8 +1092,9 @@ public: void allocatePools(bool useUvm); - void addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, - LlmRequest& llmRequest, SizeType32 windowSize); + //! \return The number of tokens that were matched/prepopulated from cache (prepopulatedPromptLen) + [[nodiscard]] SizeType32 addSequence(GenerationRequest& sequence, SizeType32 inputLength, + SizeType32 numContextBlocks, LlmRequest& llmRequest, SizeType32 windowSize); //! \brief Assign blocks for a new sequence. //! \param sequence The GenerationRequest to process. diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index c30bfbdbc3..fd5f5487f2 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -34,6 +34,7 @@ #include "tensorrt_llm/runtime/worldConfig.h" #include +#include #include #include #include @@ -1397,15 +1398,16 @@ void WindowBlockManager::refreshBlocks() // There are two versions of BlockManager::addSequence function. // This is called when block reuse is enabled. -void BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, +SizeType32 BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest, SizeType32 windowSize) { - mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest); + return mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest); } // There are two versions of WindowBlockManager::addSequence function. // This is called when block reuse is enabled. -void WindowBlockManager::addSequence( +// Returns the total prepopulatedPromptLen (including connector matched tokens) for this window. +SizeType32 WindowBlockManager::addSequence( GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest) { auto const requestId = sequence.getRequestId(); @@ -1457,9 +1459,13 @@ void WindowBlockManager::addSequence( numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen); } - llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock()); - TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d", - llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens); + // Return the total prepopulated length for this window (do not set on llmRequest here - + // the caller KVCacheManager::addSequence will use the minimum across all windows) + auto const totalPrepopulatedLen = prepopulatedPromptLen + numConnectorMatchedTokens; + TLLM_LOG_DEBUG( + "%s::addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d", + mLogPrefix.c_str(), llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens); + return totalPrepopulatedLen; } void BlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence) @@ -2449,6 +2455,9 @@ void KVCacheManager::addSequence( "[kv cache manager] Encounter existing sequence %d, skip sequence storage validity initialization", requestId); } + // Track the minimum prepopulated length across all windows (for VSWA with mixed isSWA flags) + SizeType32 minPrepopulatedPromptLen = std::numeric_limits::max(); + for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata()) { // NOTE: Caller to KVCacheManager::addSequence should deal with the chunking @@ -2460,7 +2469,11 @@ void KVCacheManager::addSequence( auto const numContextBlocks = tc::ceilDiv(effectiveInputLength, getTokensPerBlock()); if (mEnableBlockReuse) { - mBlockManager.addSequence(sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize); + auto const prepopulatedLen + = mBlockManager.addSequence(sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize); + // Use the minimum prepopulated length across all windows to ensure correctness + // when there's a mix of SWA and non-SWA windows (e.g., VSWA case) + minPrepopulatedPromptLen = std::min(minPrepopulatedPromptLen, prepopulatedLen); } else { @@ -2479,6 +2492,13 @@ void KVCacheManager::addSequence( mBlockManager.updateSequenceCacheBlockOffsets(sequence, windowSize); } + // Set the prepopulated prompt length once using the minimum across all windows + if (llmRequest && mEnableBlockReuse) + { + TLLM_LOG_DEBUG("KVCacheManager::addSequence: Setting prepopulatedPromptLen to %d", minPrepopulatedPromptLen); + llmRequest->setPrepopulatedPromptLen(minPrepopulatedPromptLen, getTokensPerBlock()); + } + if (llmRequest) { // Update statistics for block allocations/reuse per request. diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index 0803140f1c..b3535c7b11 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -306,7 +306,9 @@ void runPartialCopyTest() auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0.getRequestId()); - blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + auto prepopulatedPromptLen0 + = blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); auto cacheBlockIds = seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx); EXPECT_THAT(cacheBlockIds, ::testing::ElementsAreArray({0, 1, 2})); @@ -354,7 +356,9 @@ void runPartialCopyTest() auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1.getRequestId()); - blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + auto prepopulatedPromptLen1 + = blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 16); auto cacheBlockIds1 = seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx); EXPECT_THAT(cacheBlockIds1, ::testing::ElementsAreArray({0, 1, 6})); @@ -379,7 +383,9 @@ void runPartialCopyTest() auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq2.getRequestId()); - blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + auto prepopulatedPromptLen2 + = blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 11); auto cacheBlockIds2 = seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx); EXPECT_THAT(cacheBlockIds2, ::testing::ElementsAreArray({0, 2})); @@ -756,7 +762,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0.getRequestId()); - blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + auto prepopulatedPromptLen0 + = blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); llmRequest0->addNewToken(9, beamIdx); // block 2 contains [8] @@ -783,7 +791,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1.getRequestId()); - blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + auto prepopulatedPromptLen1 + = blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3})); // at this point, block 3 contains [8] @@ -810,7 +820,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) promptLen0 = llmRequest0->getNumTokens(beamIdx); numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0_dup.getRequestId()); - blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + prepopulatedPromptLen0 + = blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), promptLen0 - 1); EXPECT_THAT(seq0_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); @@ -826,7 +838,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) promptLen1 = llmRequest1->getNumTokens(beamIdx); numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1_dup.getRequestId()); - blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + prepopulatedPromptLen1 + = blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock); EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4})); llmRequest1->addNewToken(10, beamIdx); // block 4 contains [8, 9, 10] @@ -859,7 +873,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq2.getRequestId()); - blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + auto prepopulatedPromptLen2 + = blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), tokensPerBlock); EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 5})); llmRequest2->addNewToken(5, beamIdx); // block 5 contains [4] @@ -881,7 +897,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq3.getRequestId()); - blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); + auto prepopulatedPromptLen3 + = blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); + llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), numTokens - 1); EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4})); llmRequest3->addNewToken(11, beamIdx); // block 4 contains [8, 9, 11] @@ -914,7 +932,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) auto promptLen4 = llmRequest4->getNumTokens(beamIdx); auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq4.getRequestId()); - blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); + auto prepopulatedPromptLen4 + = blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); + llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest4->getContextCurrentPosition(), promptLen4 - 1); EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4})); numTokens = llmRequest4->getNumTokens(beamIdx); @@ -946,7 +966,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) promptLen4 = llmRequest4->getNumTokens(beamIdx); numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq4_dup.getRequestId()); - blockManager.addSequence(seq4_dup, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); + prepopulatedPromptLen4 + = blockManager.addSequence(seq4_dup, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); + llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest4->getContextCurrentPosition(), promptLen4 - 2); EXPECT_THAT(seq4_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); numTokens = llmRequest4->getNumTokens(beamIdx); @@ -973,7 +995,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) auto promptLen5 = llmRequest5->getNumTokens(beamIdx); auto numContextBlocks5 = tc::ceilDiv(promptLen5, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq5.getRequestId()); - blockManager.addSequence(seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow); + auto prepopulatedPromptLen5 + = blockManager.addSequence(seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow); + llmRequest5->setPrepopulatedPromptLen(prepopulatedPromptLen5, blockManager.getTokensPerBlock()); llmRequest5->addNewToken(0, beamIdx); EXPECT_EQ(llmRequest5->getContextCurrentPosition(), 1); // incidental reuse @@ -998,7 +1022,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) auto promptLen6 = llmRequest6->getNumTokens(beamIdx); auto numContextBlocks6 = tc::ceilDiv(promptLen6, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq6.getRequestId()); - blockManager.addSequence(seq6, promptLen6, numContextBlocks6, *llmRequest6, maxAttentionWindow); + auto prepopulatedPromptLen6 + = blockManager.addSequence(seq6, promptLen6, numContextBlocks6, *llmRequest6, maxAttentionWindow); + llmRequest6->setPrepopulatedPromptLen(prepopulatedPromptLen6, blockManager.getTokensPerBlock()); llmRequest6->addNewToken(0, beamIdx); // no reuse occurs because we are unable to reuse last input token and inputLength6 == 1. EXPECT_EQ(llmRequest6->getContextCurrentPosition(), 0); @@ -1068,7 +1094,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0.getRequestId()); - blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + auto prepopulatedPromptLen0 + = blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); llmRequest0->addNewToken(3, beamIdx); @@ -1100,7 +1128,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1.getRequestId()); - blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + auto prepopulatedPromptLen1 + = blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3})); llmRequest1->addNewToken(3, beamIdx); @@ -1127,7 +1157,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) promptLen0 = llmRequest0->getNumTokens(beamIdx); numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0_dup.getRequestId()); - blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + prepopulatedPromptLen0 + = blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); llmRequest0->addNewToken(3, beamIdx); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 2 * tokensPerBlock); EXPECT_THAT(seq0_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4})); @@ -1149,7 +1181,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) promptLen1 = llmRequest1->getNumTokens(beamIdx); numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1_dup.getRequestId()); - blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + prepopulatedPromptLen1 + = blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), llmRequest1->getNumTokens(beamIdx) - 1); EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); llmRequest1->addNewToken(5, beamIdx); @@ -1183,7 +1217,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq2.getRequestId()); - blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + auto prepopulatedPromptLen2 + = blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0); EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({5, 6, 7})); llmRequest2->addNewToken(3, beamIdx); @@ -1209,7 +1245,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq3.getRequestId()); - blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); + auto prepopulatedPromptLen3 + = blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); + llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), tokensPerBlock); EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 8, 9})); llmRequest3->addNewToken(3, beamIdx); @@ -1286,7 +1324,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0.getRequestId()); - blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + auto prepopulatedPromptLen0 + = blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); llmRequest0->addNewToken(3, beamIdx); @@ -1323,7 +1363,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1.getRequestId()); - blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + auto prepopulatedPromptLen1 + = blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3})); llmRequest1->addNewToken(3, beamIdx); @@ -1358,7 +1400,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq2.getRequestId()); - blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + auto prepopulatedPromptLen2 + = blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0); EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({4, 5, 6})); llmRequest2->addNewToken(9, beamIdx); @@ -1391,7 +1435,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq3.getRequestId()); - blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); + auto prepopulatedPromptLen3 + = blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); + llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), tokensPerBlock); // only reuse block 0 [100, 101, 102, 103] with same hash/offset EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 7, 8})); @@ -1462,7 +1508,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); // get new blocks 0, 1, 2 ([0,1,2,3], [4,5,6,7], [8]) blockManager.holdSequence(seq0.getRequestId()); - blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + auto prepopulatedPromptLen0 + = blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); llmRequest0->addNewToken(9, beamIdx); @@ -1492,7 +1540,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1.getRequestId()); - blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + auto prepopulatedPromptLen1 + = blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3})); llmRequest1->addNewToken(9, beamIdx); @@ -1517,7 +1567,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) promptLen0 = llmRequest0->getNumTokens(beamIdx); numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0_dup.getRequestId()); - blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + prepopulatedPromptLen0 + = blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); // nb! addNewToken adds new generated token, number of input tokens stay the same. // calling addNewToken before addSequence potentially triggers this error message: // Assertion failed: prepopulatedPromptLen < promptLen @@ -1539,7 +1591,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); // reuse 0, 1, 2(p) ([0,1,2,3], [4,5,6,7], [8]) blockManager.holdSequence(seq1_dup.getRequestId()); - blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + prepopulatedPromptLen1 + = blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), llmRequest1->getNumTokens(beamIdx) - 1); EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); llmRequest1->addNewToken(10, beamIdx); @@ -1571,7 +1625,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq2.getRequestId()); - blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + auto prepopulatedPromptLen2 + = blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock()); // no reuse expected. Input tokens match blocks 0 and 1, but lora task id differs. EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0); EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({5, 6, 7})); @@ -1600,7 +1656,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq3.getRequestId()); - blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); + auto prepopulatedPromptLen3 + = blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); + llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), promptLen3 - 2); EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({5, 6, 7})); llmRequest3->addNewToken(11, beamIdx); @@ -1630,7 +1688,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) auto promptLen4 = llmRequest4->getNumTokens(beamIdx); auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq4.getRequestId()); - blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); + auto prepopulatedPromptLen4 + = blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); + llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest4->getContextCurrentPosition(), tokensPerBlock); EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 8})); llmRequest4->addNewToken(5, beamIdx); @@ -1655,7 +1715,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) auto promptLen5 = llmRequest5->getNumTokens(beamIdx); auto numContextBlocks5 = tc::ceilDiv(promptLen5, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq5.getRequestId()); - blockManager.addSequence(seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow); + auto prepopulatedPromptLen5 + = blockManager.addSequence(seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow); + llmRequest5->setPrepopulatedPromptLen(prepopulatedPromptLen5, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest5->getContextCurrentPosition(), 0); EXPECT_THAT(seq5.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({9, 10, 11})); llmRequest5->addNewToken(9, beamIdx); @@ -1726,7 +1788,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0.getRequestId()); - blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + auto prepopulatedPromptLen0 + = blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); llmRequest0->addNewToken(3, beamIdx); @@ -1759,7 +1823,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1.getRequestId()); - blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + auto prepopulatedPromptLen1 + = blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 0); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5})); llmRequest1->addNewToken(3, beamIdx); @@ -1786,7 +1852,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); // reuse blocks 0, 1 and get new block 6 blockManager.holdSequence(seq0_dup.getRequestId()); - blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + prepopulatedPromptLen0 + = blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); llmRequest0->addNewToken(3, beamIdx); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 2 * tokensPerBlock); EXPECT_THAT(seq0_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 6})); @@ -1808,7 +1876,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) promptLen1 = llmRequest1->getNumTokens(beamIdx); numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1_dup.getRequestId()); - blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + prepopulatedPromptLen1 + = blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), llmRequest1->getNumTokens(beamIdx) - 1); EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5})); llmRequest1->addNewToken(5, beamIdx); @@ -1841,7 +1911,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq2.getRequestId()); - blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + auto prepopulatedPromptLen2 + = blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0); EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({7, 8, 9})); llmRequest2->addNewToken(3, beamIdx); @@ -1867,7 +1939,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq3.getRequestId()); - blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); + auto prepopulatedPromptLen3 + = blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); + llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), tokensPerBlock); EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 10, 11})); llmRequest3->addNewToken(3, beamIdx); @@ -1892,7 +1966,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) auto promptLen4 = llmRequest4->getNumTokens(beamIdx); auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq4.getRequestId()); - blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); + auto prepopulatedPromptLen4 + = blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); + llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest4->getContextCurrentPosition(), tokensPerBlock); EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 12, 13})); llmRequest4->addNewToken(3, beamIdx); @@ -1971,7 +2047,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0.getRequestId()); - blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + auto prepopulatedPromptLen0 + = blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -2009,7 +2087,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1.getRequestId()); - blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + auto prepopulatedPromptLen1 + = blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 0); // No reuse, starts from scratch EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5})); @@ -2042,7 +2122,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq2.getRequestId()); - blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + auto prepopulatedPromptLen2 + = blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 3,4 EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 6})); @@ -2076,7 +2158,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq3.getRequestId()); - blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); + auto prepopulatedPromptLen3 + = blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); + llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 0); // No reuse EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({7, 8, 9})); @@ -2103,7 +2187,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto promptLen4 = llmRequest4->getNumTokens(beamIdx); auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq4.getRequestId()); - blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); + auto prepopulatedPromptLen4 + = blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); + llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest4->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 0,1 EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 10})); @@ -2225,7 +2311,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) GenerationRequest seq0{0, inputLength0, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks0 = tc::ceilDiv(inputLength0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0.getRequestId()); - blockManager.addSequence(seq0, llmRequest0->getNumTokens(0), numContextBlocks0, *llmRequest0, maxAttentionWindow); + auto prepopulatedPromptLen0 = blockManager.addSequence( + seq0, llmRequest0->getNumTokens(0), numContextBlocks0, *llmRequest0, maxAttentionWindow); + llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); // Add another sequence with different tokens, at a low priority auto inputTokens1 = std::make_shared(VecTokens{8, 9, 10, 11, 12, 13, 14, 15}); @@ -2234,7 +2322,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) GenerationRequest seq1{1, inputLength1, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks1 = tc::ceilDiv(inputLength1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1.getRequestId()); - blockManager.addSequence(seq1, llmRequest1->getNumTokens(0), numContextBlocks1, *llmRequest1, maxAttentionWindow); + auto prepopulatedPromptLen1 = blockManager.addSequence( + seq1, llmRequest1->getNumTokens(0), numContextBlocks1, *llmRequest1, maxAttentionWindow); + llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); // Release both sequences blockManager.releaseBlocks(seq0, llmRequest0); @@ -2251,7 +2341,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) GenerationRequest seq2{2, inputLength2, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks2 = tc::ceilDiv(inputLength2, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq2.getRequestId()); - blockManager.addSequence(seq2, llmRequest2->getNumTokens(0), numContextBlocks2, *llmRequest2, maxAttentionWindow); + auto prepopulatedPromptLen2 = blockManager.addSequence( + seq2, llmRequest2->getNumTokens(0), numContextBlocks2, *llmRequest2, maxAttentionWindow); + llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock()); blockManager.releaseBlocks(seq2, llmRequest2); blockManager.releaseSequence(seq2.getRequestId()); @@ -2262,7 +2354,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) GenerationRequest seq3{3, inputLength3, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks3 = tc::ceilDiv(inputLength3, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq3.getRequestId()); - blockManager.addSequence(seq3, llmRequest3->getNumTokens(0), numContextBlocks3, *llmRequest3, maxAttentionWindow); + auto prepopulatedPromptLen3 = blockManager.addSequence( + seq3, llmRequest3->getNumTokens(0), numContextBlocks3, *llmRequest3, maxAttentionWindow); + llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 4); @@ -2277,7 +2371,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) GenerationRequest seq4{4, inputLength3, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks4 = tc::ceilDiv(inputLength4, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq4.getRequestId()); - blockManager.addSequence(seq4, llmRequest4->getNumTokens(0), numContextBlocks4, *llmRequest4, maxAttentionWindow); + auto prepopulatedPromptLen4 = blockManager.addSequence( + seq4, llmRequest4->getNumTokens(0), numContextBlocks4, *llmRequest4, maxAttentionWindow); + llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest4->getContextCurrentPosition(), 4); @@ -2288,7 +2384,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) GenerationRequest seq5{5, inputLength5, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks5 = tc::ceilDiv(inputLength5, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq5.getRequestId()); - blockManager.addSequence(seq5, llmRequest5->getNumTokens(0), numContextBlocks5, *llmRequest5, maxAttentionWindow); + auto prepopulatedPromptLen5 = blockManager.addSequence( + seq5, llmRequest5->getNumTokens(0), numContextBlocks5, *llmRequest5, maxAttentionWindow); + llmRequest5->setPrepopulatedPromptLen(prepopulatedPromptLen5, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest5->getContextCurrentPosition(), 0); } diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 8be3b35091..672906b05d 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -1217,10 +1217,6 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness): task = MMLU(self.MODEL_NAME) task.evaluate(llm) - @pytest.mark.skip( - reason= - "Currently failing due to accuracy drop, https://nvbugspro.nvidia.com/bug/5674665" - ) def test_auto_dtype_vswa_reuse_disable_overlap_scheduler(self): # NOTE: Test with VSWA kv cache config. kv_cache_config = KvCacheConfig( diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 69582ecc09..d7f55c98d1 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -49,6 +49,7 @@ l0_h100: - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse_low_memory_available_no_partial_reuse - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse_low_memory_available_partial_reuse - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_without_reuse_disable_overlap_scheduler + - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse_disable_overlap_scheduler - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] TIMEOUT (90) diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py index c8ede3ed8c..7d08f44677 100644 --- a/tests/unittest/_torch/speculative/test_eagle3.py +++ b/tests/unittest/_torch/speculative/test_eagle3.py @@ -92,6 +92,7 @@ def test_kv_lens_runtime_with_eagle3_one_model(): f"kv_lens should be {expected_kv_lens_with_extra.tolist()}, but got {kv_lens_internal.tolist()}" +@pytest.mark.skip(reason="https://nvbugs/5856637") @pytest.mark.parametrize( "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp,use_hf_speculative_model", [