[https://nvbugs/5674665][fix] Fix accuracy drop in VSWA with KV cache block reuse (#10875)

Signed-off-by: SimengLiu-nv <simengl@nvidia.com>
This commit is contained in:
Simeng Liu 2026-02-04 09:46:31 -08:00 committed by GitHub
parent 767b8dcab3
commit d9fd8cc951
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 181 additions and 63 deletions

View File

@ -642,7 +642,8 @@ public:
void startScheduling();
//! \brief Assign blocks for new sequence. Try to reuse blocks.
void addSequence(
//! \return The number of tokens that were matched/prepopulated from cache (prepopulatedPromptLen)
[[nodiscard]] SizeType32 addSequence(
GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest);
//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
@ -1091,8 +1092,9 @@ public:
void allocatePools(bool useUvm);
void addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
LlmRequest& llmRequest, SizeType32 windowSize);
//! \return The number of tokens that were matched/prepopulated from cache (prepopulatedPromptLen)
[[nodiscard]] SizeType32 addSequence(GenerationRequest& sequence, SizeType32 inputLength,
SizeType32 numContextBlocks, LlmRequest& llmRequest, SizeType32 windowSize);
//! \brief Assign blocks for a new sequence.
//! \param sequence The GenerationRequest to process.

View File

@ -34,6 +34,7 @@
#include "tensorrt_llm/runtime/worldConfig.h"
#include <algorithm>
#include <limits>
#include <map>
#include <optional>
#include <utility>
@ -1397,15 +1398,16 @@ void WindowBlockManager::refreshBlocks()
// There are two versions of the BlockManager::addSequence function.
// This one is called when block reuse is enabled.
void BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
SizeType32 BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
LlmRequest& llmRequest, SizeType32 windowSize)
{
mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest);
return mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest);
}
// There are two versions of the WindowBlockManager::addSequence function.
// This one is called when block reuse is enabled.
void WindowBlockManager::addSequence(
// Returns the total prepopulatedPromptLen (including connector matched tokens) for this window.
SizeType32 WindowBlockManager::addSequence(
GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest)
{
auto const requestId = sequence.getRequestId();
@ -1457,9 +1459,13 @@ void WindowBlockManager::addSequence(
numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen);
}
llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock());
TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d",
llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens);
// Return the total prepopulated length for this window (do not set on llmRequest here -
// the caller KVCacheManager::addSequence will use the minimum across all windows)
auto const totalPrepopulatedLen = prepopulatedPromptLen + numConnectorMatchedTokens;
TLLM_LOG_DEBUG(
"%s::addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d",
mLogPrefix.c_str(), llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens);
return totalPrepopulatedLen;
}
void BlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence)
@ -2449,6 +2455,9 @@ void KVCacheManager::addSequence(
"[kv cache manager] Encounter existing sequence %d, skip sequence storage validity initialization",
requestId);
}
// Track the minimum prepopulated length across all windows (for VSWA with mixed isSWA flags)
SizeType32 minPrepopulatedPromptLen = std::numeric_limits<SizeType32>::max();
for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata())
{
// NOTE: Caller to KVCacheManager::addSequence should deal with the chunking
@ -2460,7 +2469,11 @@ void KVCacheManager::addSequence(
auto const numContextBlocks = tc::ceilDiv(effectiveInputLength, getTokensPerBlock());
if (mEnableBlockReuse)
{
mBlockManager.addSequence(sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize);
auto const prepopulatedLen
= mBlockManager.addSequence(sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize);
// Use the minimum prepopulated length across all windows to ensure correctness
// when there's a mix of SWA and non-SWA windows (e.g., VSWA case)
minPrepopulatedPromptLen = std::min(minPrepopulatedPromptLen, prepopulatedLen);
}
else
{
@ -2479,6 +2492,13 @@ void KVCacheManager::addSequence(
mBlockManager.updateSequenceCacheBlockOffsets(sequence, windowSize);
}
// Set the prepopulated prompt length once using the minimum across all windows
if (llmRequest && mEnableBlockReuse)
{
TLLM_LOG_DEBUG("KVCacheManager::addSequence: Setting prepopulatedPromptLen to %d", minPrepopulatedPromptLen);
llmRequest->setPrepopulatedPromptLen(minPrepopulatedPromptLen, getTokensPerBlock());
}
if (llmRequest)
{
// Update statistics for block allocations/reuse per request.

View File

@ -306,7 +306,9 @@ void runPartialCopyTest()
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq0.getRequestId());
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
auto prepopulatedPromptLen0
= blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
auto cacheBlockIds = seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx);
EXPECT_THAT(cacheBlockIds, ::testing::ElementsAreArray({0, 1, 2}));
@ -354,7 +356,9 @@ void runPartialCopyTest()
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq1.getRequestId());
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
auto prepopulatedPromptLen1
= blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 16);
auto cacheBlockIds1 = seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx);
EXPECT_THAT(cacheBlockIds1, ::testing::ElementsAreArray({0, 1, 6}));
@ -379,7 +383,9 @@ void runPartialCopyTest()
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq2.getRequestId());
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
auto prepopulatedPromptLen2
= blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 11);
auto cacheBlockIds2 = seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx);
EXPECT_THAT(cacheBlockIds2, ::testing::ElementsAreArray({0, 2}));
@ -756,7 +762,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq0.getRequestId());
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
auto prepopulatedPromptLen0
= blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
llmRequest0->addNewToken(9, beamIdx); // block 2 contains [8]
@ -783,7 +791,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq1.getRequestId());
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
auto prepopulatedPromptLen1
= blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock);
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3}));
// at this point, block 3 contains [8]
@ -810,7 +820,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
promptLen0 = llmRequest0->getNumTokens(beamIdx);
numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq0_dup.getRequestId());
blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
prepopulatedPromptLen0
= blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), promptLen0 - 1);
EXPECT_THAT(seq0_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
@ -826,7 +838,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
promptLen1 = llmRequest1->getNumTokens(beamIdx);
numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq1_dup.getRequestId());
blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
prepopulatedPromptLen1
= blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock);
EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4}));
llmRequest1->addNewToken(10, beamIdx); // block 4 contains [8, 9, 10]
@ -859,7 +873,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq2.getRequestId());
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
auto prepopulatedPromptLen2
= blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), tokensPerBlock);
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 5}));
llmRequest2->addNewToken(5, beamIdx); // block 5 contains [4]
@ -881,7 +897,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq3.getRequestId());
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
auto prepopulatedPromptLen3
= blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest3->getContextCurrentPosition(), numTokens - 1);
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4}));
llmRequest3->addNewToken(11, beamIdx); // block 4 contains [8, 9, 11]
@ -914,7 +932,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
auto promptLen4 = llmRequest4->getNumTokens(beamIdx);
auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq4.getRequestId());
blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
auto prepopulatedPromptLen4
= blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), promptLen4 - 1);
EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4}));
numTokens = llmRequest4->getNumTokens(beamIdx);
@ -946,7 +966,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
promptLen4 = llmRequest4->getNumTokens(beamIdx);
numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq4_dup.getRequestId());
blockManager.addSequence(seq4_dup, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
prepopulatedPromptLen4
= blockManager.addSequence(seq4_dup, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), promptLen4 - 2);
EXPECT_THAT(seq4_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
numTokens = llmRequest4->getNumTokens(beamIdx);
@ -973,7 +995,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
auto promptLen5 = llmRequest5->getNumTokens(beamIdx);
auto numContextBlocks5 = tc::ceilDiv(promptLen5, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq5.getRequestId());
blockManager.addSequence(seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow);
auto prepopulatedPromptLen5
= blockManager.addSequence(seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow);
llmRequest5->setPrepopulatedPromptLen(prepopulatedPromptLen5, blockManager.getTokensPerBlock());
llmRequest5->addNewToken(0, beamIdx);
EXPECT_EQ(llmRequest5->getContextCurrentPosition(), 1); // incidental reuse
@ -998,7 +1022,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
auto promptLen6 = llmRequest6->getNumTokens(beamIdx);
auto numContextBlocks6 = tc::ceilDiv(promptLen6, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq6.getRequestId());
blockManager.addSequence(seq6, promptLen6, numContextBlocks6, *llmRequest6, maxAttentionWindow);
auto prepopulatedPromptLen6
= blockManager.addSequence(seq6, promptLen6, numContextBlocks6, *llmRequest6, maxAttentionWindow);
llmRequest6->setPrepopulatedPromptLen(prepopulatedPromptLen6, blockManager.getTokensPerBlock());
llmRequest6->addNewToken(0, beamIdx);
// no reuse occurs because we are unable to reuse last input token and inputLength6 == 1.
EXPECT_EQ(llmRequest6->getContextCurrentPosition(), 0);
@ -1068,7 +1094,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq0.getRequestId());
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
auto prepopulatedPromptLen0
= blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
llmRequest0->addNewToken(3, beamIdx);
@ -1100,7 +1128,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq1.getRequestId());
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
auto prepopulatedPromptLen1
= blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock);
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3}));
llmRequest1->addNewToken(3, beamIdx);
@ -1127,7 +1157,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
promptLen0 = llmRequest0->getNumTokens(beamIdx);
numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq0_dup.getRequestId());
blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
prepopulatedPromptLen0
= blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
llmRequest0->addNewToken(3, beamIdx);
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 2 * tokensPerBlock);
EXPECT_THAT(seq0_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4}));
@ -1149,7 +1181,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
promptLen1 = llmRequest1->getNumTokens(beamIdx);
numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq1_dup.getRequestId());
blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
prepopulatedPromptLen1
= blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), llmRequest1->getNumTokens(beamIdx) - 1);
EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
llmRequest1->addNewToken(5, beamIdx);
@ -1183,7 +1217,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq2.getRequestId());
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
auto prepopulatedPromptLen2
= blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0);
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({5, 6, 7}));
llmRequest2->addNewToken(3, beamIdx);
@ -1209,7 +1245,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq3.getRequestId());
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
auto prepopulatedPromptLen3
= blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest3->getContextCurrentPosition(), tokensPerBlock);
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 8, 9}));
llmRequest3->addNewToken(3, beamIdx);
@ -1286,7 +1324,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest)
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq0.getRequestId());
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
auto prepopulatedPromptLen0
= blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
llmRequest0->addNewToken(3, beamIdx);
@ -1323,7 +1363,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest)
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq1.getRequestId());
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
auto prepopulatedPromptLen1
= blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock);
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3}));
llmRequest1->addNewToken(3, beamIdx);
@ -1358,7 +1400,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest)
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq2.getRequestId());
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
auto prepopulatedPromptLen2
= blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0);
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({4, 5, 6}));
llmRequest2->addNewToken(9, beamIdx);
@ -1391,7 +1435,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest)
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq3.getRequestId());
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
auto prepopulatedPromptLen3
= blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest3->getContextCurrentPosition(),
tokensPerBlock); // only reuse block 0 [100, 101, 102, 103] with same hash/offset
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 7, 8}));
@ -1462,7 +1508,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
// get new blocks 0, 1, 2 ([0,1,2,3], [4,5,6,7], [8])
blockManager.holdSequence(seq0.getRequestId());
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
auto prepopulatedPromptLen0
= blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
llmRequest0->addNewToken(9, beamIdx);
@ -1492,7 +1540,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq1.getRequestId());
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
auto prepopulatedPromptLen1
= blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock);
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3}));
llmRequest1->addNewToken(9, beamIdx);
@ -1517,7 +1567,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
promptLen0 = llmRequest0->getNumTokens(beamIdx);
numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq0_dup.getRequestId());
blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
prepopulatedPromptLen0
= blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
// NB: addNewToken adds a newly generated token; the number of input tokens stays the same.
// Calling addNewToken before addSequence potentially triggers this error message:
// Assertion failed: prepopulatedPromptLen < promptLen
@ -1539,7 +1591,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
// reuse 0, 1, 2(p) ([0,1,2,3], [4,5,6,7], [8])
blockManager.holdSequence(seq1_dup.getRequestId());
blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
prepopulatedPromptLen1
= blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), llmRequest1->getNumTokens(beamIdx) - 1);
EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
llmRequest1->addNewToken(10, beamIdx);
@ -1571,7 +1625,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq2.getRequestId());
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
auto prepopulatedPromptLen2
= blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
// no reuse expected. Input tokens match blocks 0 and 1, but lora task id differs.
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0);
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({5, 6, 7}));
@ -1600,7 +1656,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq3.getRequestId());
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
auto prepopulatedPromptLen3
= blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest3->getContextCurrentPosition(), promptLen3 - 2);
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({5, 6, 7}));
llmRequest3->addNewToken(11, beamIdx);
@ -1630,7 +1688,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
auto promptLen4 = llmRequest4->getNumTokens(beamIdx);
auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq4.getRequestId());
blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
auto prepopulatedPromptLen4
= blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), tokensPerBlock);
EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 8}));
llmRequest4->addNewToken(5, beamIdx);
@ -1655,7 +1715,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
auto promptLen5 = llmRequest5->getNumTokens(beamIdx);
auto numContextBlocks5 = tc::ceilDiv(promptLen5, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq5.getRequestId());
blockManager.addSequence(seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow);
auto prepopulatedPromptLen5
= blockManager.addSequence(seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow);
llmRequest5->setPrepopulatedPromptLen(prepopulatedPromptLen5, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest5->getContextCurrentPosition(), 0);
EXPECT_THAT(seq5.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({9, 10, 11}));
llmRequest5->addNewToken(9, beamIdx);
@ -1726,7 +1788,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq0.getRequestId());
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
auto prepopulatedPromptLen0
= blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
llmRequest0->addNewToken(3, beamIdx);
@ -1759,7 +1823,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq1.getRequestId());
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
auto prepopulatedPromptLen1
= blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 0);
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5}));
llmRequest1->addNewToken(3, beamIdx);
@ -1786,7 +1852,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
// reuse blocks 0, 1 and get new block 6
blockManager.holdSequence(seq0_dup.getRequestId());
blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
prepopulatedPromptLen0
= blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
llmRequest0->addNewToken(3, beamIdx);
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 2 * tokensPerBlock);
EXPECT_THAT(seq0_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 6}));
@ -1808,7 +1876,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
promptLen1 = llmRequest1->getNumTokens(beamIdx);
numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq1_dup.getRequestId());
blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
prepopulatedPromptLen1
= blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), llmRequest1->getNumTokens(beamIdx) - 1);
EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5}));
llmRequest1->addNewToken(5, beamIdx);
@ -1841,7 +1911,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq2.getRequestId());
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
auto prepopulatedPromptLen2
= blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0);
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({7, 8, 9}));
llmRequest2->addNewToken(3, beamIdx);
@ -1867,7 +1939,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq3.getRequestId());
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
auto prepopulatedPromptLen3
= blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest3->getContextCurrentPosition(), tokensPerBlock);
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 10, 11}));
llmRequest3->addNewToken(3, beamIdx);
@ -1892,7 +1966,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
auto promptLen4 = llmRequest4->getNumTokens(beamIdx);
auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq4.getRequestId());
blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
auto prepopulatedPromptLen4
= blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), tokensPerBlock);
EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 12, 13}));
llmRequest4->addNewToken(3, beamIdx);
@ -1971,7 +2047,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq0.getRequestId());
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
auto prepopulatedPromptLen0
= blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
@ -2009,7 +2087,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq1.getRequestId());
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
auto prepopulatedPromptLen1
= blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 0); // No reuse, starts from scratch
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5}));
@ -2042,7 +2122,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq2.getRequestId());
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
auto prepopulatedPromptLen2
= blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 3,4
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 6}));
@ -2076,7 +2158,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq3.getRequestId());
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
auto prepopulatedPromptLen3
= blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 0); // No reuse
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({7, 8, 9}));
@ -2103,7 +2187,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
auto promptLen4 = llmRequest4->getNumTokens(beamIdx);
auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq4.getRequestId());
blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
auto prepopulatedPromptLen4
= blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 0,1
EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 10}));
@ -2225,7 +2311,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest)
GenerationRequest seq0{0, inputLength0, beamWidth, blockManager.getWindowSizesMetadata()};
auto numContextBlocks0 = tc::ceilDiv(inputLength0, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq0.getRequestId());
blockManager.addSequence(seq0, llmRequest0->getNumTokens(0), numContextBlocks0, *llmRequest0, maxAttentionWindow);
auto prepopulatedPromptLen0 = blockManager.addSequence(
seq0, llmRequest0->getNumTokens(0), numContextBlocks0, *llmRequest0, maxAttentionWindow);
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
// Add another sequence with different tokens, at a low priority
auto inputTokens1 = std::make_shared<VecTokens>(VecTokens{8, 9, 10, 11, 12, 13, 14, 15});
@ -2234,7 +2322,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest)
GenerationRequest seq1{1, inputLength1, beamWidth, blockManager.getWindowSizesMetadata()};
auto numContextBlocks1 = tc::ceilDiv(inputLength1, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq1.getRequestId());
blockManager.addSequence(seq1, llmRequest1->getNumTokens(0), numContextBlocks1, *llmRequest1, maxAttentionWindow);
auto prepopulatedPromptLen1 = blockManager.addSequence(
seq1, llmRequest1->getNumTokens(0), numContextBlocks1, *llmRequest1, maxAttentionWindow);
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
// Release both sequences
blockManager.releaseBlocks(seq0, llmRequest0);
@ -2251,7 +2341,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest)
GenerationRequest seq2{2, inputLength2, beamWidth, blockManager.getWindowSizesMetadata()};
auto numContextBlocks2 = tc::ceilDiv(inputLength2, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq2.getRequestId());
blockManager.addSequence(seq2, llmRequest2->getNumTokens(0), numContextBlocks2, *llmRequest2, maxAttentionWindow);
auto prepopulatedPromptLen2 = blockManager.addSequence(
seq2, llmRequest2->getNumTokens(0), numContextBlocks2, *llmRequest2, maxAttentionWindow);
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
blockManager.releaseBlocks(seq2, llmRequest2);
blockManager.releaseSequence(seq2.getRequestId());
@ -2262,7 +2354,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest)
GenerationRequest seq3{3, inputLength3, beamWidth, blockManager.getWindowSizesMetadata()};
auto numContextBlocks3 = tc::ceilDiv(inputLength3, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq3.getRequestId());
blockManager.addSequence(seq3, llmRequest3->getNumTokens(0), numContextBlocks3, *llmRequest3, maxAttentionWindow);
auto prepopulatedPromptLen3 = blockManager.addSequence(
seq3, llmRequest3->getNumTokens(0), numContextBlocks3, *llmRequest3, maxAttentionWindow);
llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 4);
@ -2277,7 +2371,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest)
// Size sequence 4 from its own prompt length (inputLength4), consistent with numContextBlocks4 and
// llmRequest4 below; the previous inputLength3 was a copy-paste slip from the seq3 setup.
GenerationRequest seq4{4, inputLength4, beamWidth, blockManager.getWindowSizesMetadata()};
auto numContextBlocks4 = tc::ceilDiv(inputLength4, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq4.getRequestId());
blockManager.addSequence(seq4, llmRequest4->getNumTokens(0), numContextBlocks4, *llmRequest4, maxAttentionWindow);
auto prepopulatedPromptLen4 = blockManager.addSequence(
seq4, llmRequest4->getNumTokens(0), numContextBlocks4, *llmRequest4, maxAttentionWindow);
llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), 4);
@ -2288,7 +2384,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest)
GenerationRequest seq5{5, inputLength5, beamWidth, blockManager.getWindowSizesMetadata()};
auto numContextBlocks5 = tc::ceilDiv(inputLength5, blockManager.getTokensPerBlock());
blockManager.holdSequence(seq5.getRequestId());
blockManager.addSequence(seq5, llmRequest5->getNumTokens(0), numContextBlocks5, *llmRequest5, maxAttentionWindow);
auto prepopulatedPromptLen5 = blockManager.addSequence(
seq5, llmRequest5->getNumTokens(0), numContextBlocks5, *llmRequest5, maxAttentionWindow);
llmRequest5->setPrepopulatedPromptLen(prepopulatedPromptLen5, blockManager.getTokensPerBlock());
EXPECT_EQ(llmRequest5->getContextCurrentPosition(), 0);
}

View File

@ -1217,10 +1217,6 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness):
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
@pytest.mark.skip(
reason=
"Currently failing due to accuracy drop, https://nvbugspro.nvidia.com/bug/5674665"
)
def test_auto_dtype_vswa_reuse_disable_overlap_scheduler(self):
# NOTE: Test with VSWA kv cache config.
kv_cache_config = KvCacheConfig(

View File

@ -49,6 +49,7 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse_low_memory_available_no_partial_reuse
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse_low_memory_available_partial_reuse
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_without_reuse_disable_overlap_scheduler
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse_disable_overlap_scheduler
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] TIMEOUT (90)

View File

@ -92,6 +92,7 @@ def test_kv_lens_runtime_with_eagle3_one_model():
f"kv_lens should be {expected_kv_lens_with_extra.tolist()}, but got {kv_lens_internal.tolist()}"
@pytest.mark.skip(reason="https://nvbugs/5856637")
@pytest.mark.parametrize(
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp,use_hf_speculative_model",
[