mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-16 07:53:55 +08:00
[https://nvbugs/5674665][fix] Fix accuracy drop in VSWA with KV cache block reuse (#10875)
Signed-off-by: SimengLiu-nv <simengl@nvidia.com>
This commit is contained in:
parent
767b8dcab3
commit
d9fd8cc951
@ -642,7 +642,8 @@ public:
|
||||
void startScheduling();
|
||||
|
||||
//! \brief Assign blocks for new sequence. Try to reuse blocks.
|
||||
void addSequence(
|
||||
//! \return The number of tokens that were matched/prepopulated from cache (prepopulatedPromptLen)
|
||||
[[nodiscard]] SizeType32 addSequence(
|
||||
GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest);
|
||||
|
||||
//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
|
||||
@ -1091,8 +1092,9 @@ public:
|
||||
|
||||
void allocatePools(bool useUvm);
|
||||
|
||||
void addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
|
||||
LlmRequest& llmRequest, SizeType32 windowSize);
|
||||
//! \return The number of tokens that were matched/prepopulated from cache (prepopulatedPromptLen)
|
||||
[[nodiscard]] SizeType32 addSequence(GenerationRequest& sequence, SizeType32 inputLength,
|
||||
SizeType32 numContextBlocks, LlmRequest& llmRequest, SizeType32 windowSize);
|
||||
|
||||
//! \brief Assign blocks for a new sequence.
|
||||
//! \param sequence The GenerationRequest to process.
|
||||
|
||||
@ -34,6 +34,7 @@
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
@ -1397,15 +1398,16 @@ void WindowBlockManager::refreshBlocks()
|
||||
|
||||
// There are two versions of the BlockManager::addSequence function.
|
||||
// This version is called when block reuse is enabled.
|
||||
// NOTE(review): this span is a rendered diff whose +/- markers were lost. Each replaced line is
// immediately followed by its replacement; the `|` / `||||` rows are layout residue from the page.
// Pre-change signature: returned nothing.
void BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
|
||||
// Post-change signature: returns a SizeType32, matching the `[[nodiscard]] SizeType32 addSequence(...)`
// declaration that documents the result as the number of tokens matched/prepopulated from cache.
SizeType32 BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
|
||||
LlmRequest& llmRequest, SizeType32 windowSize)
|
||||
{
|
||||
// Pre-change body: forwarded to the manager stored under windowSize and discarded any result.
mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest);
|
||||
// Post-change body: same forwarding call, but the per-window result is handed back to the caller.
return mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest);
|
||||
}
|
||||
|
||||
// There are two versions of the WindowBlockManager::addSequence function.
|
||||
// This version is called when block reuse is enabled.
|
||||
void WindowBlockManager::addSequence(
|
||||
// Returns the total prepopulatedPromptLen (including connector matched tokens) for this window.
|
||||
SizeType32 WindowBlockManager::addSequence(
|
||||
GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest)
|
||||
{
|
||||
auto const requestId = sequence.getRequestId();
|
||||
@ -1457,9 +1459,13 @@ void WindowBlockManager::addSequence(
|
||||
numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen);
|
||||
}
|
||||
|
||||
llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock());
|
||||
TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d",
|
||||
llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens);
|
||||
// Return the total prepopulated length for this window (do not set on llmRequest here -
|
||||
// the caller KVCacheManager::addSequence will use the minimum across all windows)
|
||||
auto const totalPrepopulatedLen = prepopulatedPromptLen + numConnectorMatchedTokens;
|
||||
TLLM_LOG_DEBUG(
|
||||
"%s::addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d",
|
||||
mLogPrefix.c_str(), llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens);
|
||||
return totalPrepopulatedLen;
|
||||
}
|
||||
|
||||
void BlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence)
|
||||
@ -2449,6 +2455,9 @@ void KVCacheManager::addSequence(
|
||||
"[kv cache manager] Encounter existing sequence %d, skip sequence storage validity initialization",
|
||||
requestId);
|
||||
}
|
||||
// Track the minimum prepopulated length across all windows (for VSWA with mixed isSWA flags)
|
||||
SizeType32 minPrepopulatedPromptLen = std::numeric_limits<SizeType32>::max();
|
||||
|
||||
for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata())
|
||||
{
|
||||
// NOTE: Caller to KVCacheManager::addSequence should deal with the chunking
|
||||
@ -2460,7 +2469,11 @@ void KVCacheManager::addSequence(
|
||||
auto const numContextBlocks = tc::ceilDiv(effectiveInputLength, getTokensPerBlock());
|
||||
if (mEnableBlockReuse)
|
||||
{
|
||||
mBlockManager.addSequence(sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize);
|
||||
auto const prepopulatedLen
|
||||
= mBlockManager.addSequence(sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize);
|
||||
// Use the minimum prepopulated length across all windows to ensure correctness
|
||||
// when there's a mix of SWA and non-SWA windows (e.g., VSWA case)
|
||||
minPrepopulatedPromptLen = std::min(minPrepopulatedPromptLen, prepopulatedLen);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -2479,6 +2492,13 @@ void KVCacheManager::addSequence(
|
||||
mBlockManager.updateSequenceCacheBlockOffsets(sequence, windowSize);
|
||||
}
|
||||
|
||||
// Set the prepopulated prompt length once using the minimum across all windows
|
||||
if (llmRequest && mEnableBlockReuse)
|
||||
{
|
||||
TLLM_LOG_DEBUG("KVCacheManager::addSequence: Setting prepopulatedPromptLen to %d", minPrepopulatedPromptLen);
|
||||
llmRequest->setPrepopulatedPromptLen(minPrepopulatedPromptLen, getTokensPerBlock());
|
||||
}
|
||||
|
||||
if (llmRequest)
|
||||
{
|
||||
// Update statistics for block allocations/reuse per request.
|
||||
|
||||
@ -306,7 +306,9 @@ void runPartialCopyTest()
|
||||
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
|
||||
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq0.getRequestId());
|
||||
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen0
|
||||
= blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
|
||||
auto cacheBlockIds = seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx);
|
||||
EXPECT_THAT(cacheBlockIds, ::testing::ElementsAreArray({0, 1, 2}));
|
||||
@ -354,7 +356,9 @@ void runPartialCopyTest()
|
||||
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
|
||||
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq1.getRequestId());
|
||||
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen1
|
||||
= blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 16);
|
||||
auto cacheBlockIds1 = seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx);
|
||||
EXPECT_THAT(cacheBlockIds1, ::testing::ElementsAreArray({0, 1, 6}));
|
||||
@ -379,7 +383,9 @@ void runPartialCopyTest()
|
||||
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
|
||||
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq2.getRequestId());
|
||||
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen2
|
||||
= blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 11);
|
||||
auto cacheBlockIds2 = seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx);
|
||||
EXPECT_THAT(cacheBlockIds2, ::testing::ElementsAreArray({0, 2}));
|
||||
@ -756,7 +762,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
|
||||
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
|
||||
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq0.getRequestId());
|
||||
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen0
|
||||
= blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
|
||||
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
|
||||
llmRequest0->addNewToken(9, beamIdx); // block 2 contains [8]
|
||||
@ -783,7 +791,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
|
||||
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
|
||||
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq1.getRequestId());
|
||||
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen1
|
||||
= blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock);
|
||||
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3}));
|
||||
// at this point, block 3 contains [8]
|
||||
@ -810,7 +820,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
|
||||
promptLen0 = llmRequest0->getNumTokens(beamIdx);
|
||||
numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq0_dup.getRequestId());
|
||||
blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
prepopulatedPromptLen0
|
||||
= blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), promptLen0 - 1);
|
||||
EXPECT_THAT(seq0_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
|
||||
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
|
||||
@ -826,7 +838,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
|
||||
promptLen1 = llmRequest1->getNumTokens(beamIdx);
|
||||
numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq1_dup.getRequestId());
|
||||
blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
prepopulatedPromptLen1
|
||||
= blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock);
|
||||
EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4}));
|
||||
llmRequest1->addNewToken(10, beamIdx); // block 4 contains [8, 9, 10]
|
||||
@ -859,7 +873,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
|
||||
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
|
||||
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq2.getRequestId());
|
||||
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen2
|
||||
= blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), tokensPerBlock);
|
||||
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 5}));
|
||||
llmRequest2->addNewToken(5, beamIdx); // block 5 contains [4]
|
||||
@ -881,7 +897,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
|
||||
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
|
||||
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq3.getRequestId());
|
||||
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen3
|
||||
= blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
|
||||
llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest3->getContextCurrentPosition(), numTokens - 1);
|
||||
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4}));
|
||||
llmRequest3->addNewToken(11, beamIdx); // block 4 contains [8, 9, 11]
|
||||
@ -914,7 +932,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
|
||||
auto promptLen4 = llmRequest4->getNumTokens(beamIdx);
|
||||
auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq4.getRequestId());
|
||||
blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen4
|
||||
= blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
|
||||
llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), promptLen4 - 1);
|
||||
EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4}));
|
||||
numTokens = llmRequest4->getNumTokens(beamIdx);
|
||||
@ -946,7 +966,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
|
||||
promptLen4 = llmRequest4->getNumTokens(beamIdx);
|
||||
numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq4_dup.getRequestId());
|
||||
blockManager.addSequence(seq4_dup, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
|
||||
prepopulatedPromptLen4
|
||||
= blockManager.addSequence(seq4_dup, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
|
||||
llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), promptLen4 - 2);
|
||||
EXPECT_THAT(seq4_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
|
||||
numTokens = llmRequest4->getNumTokens(beamIdx);
|
||||
@ -973,7 +995,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
|
||||
auto promptLen5 = llmRequest5->getNumTokens(beamIdx);
|
||||
auto numContextBlocks5 = tc::ceilDiv(promptLen5, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq5.getRequestId());
|
||||
blockManager.addSequence(seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen5
|
||||
= blockManager.addSequence(seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow);
|
||||
llmRequest5->setPrepopulatedPromptLen(prepopulatedPromptLen5, blockManager.getTokensPerBlock());
|
||||
llmRequest5->addNewToken(0, beamIdx);
|
||||
EXPECT_EQ(llmRequest5->getContextCurrentPosition(), 1); // incidental reuse
|
||||
|
||||
@ -998,7 +1022,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
|
||||
auto promptLen6 = llmRequest6->getNumTokens(beamIdx);
|
||||
auto numContextBlocks6 = tc::ceilDiv(promptLen6, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq6.getRequestId());
|
||||
blockManager.addSequence(seq6, promptLen6, numContextBlocks6, *llmRequest6, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen6
|
||||
= blockManager.addSequence(seq6, promptLen6, numContextBlocks6, *llmRequest6, maxAttentionWindow);
|
||||
llmRequest6->setPrepopulatedPromptLen(prepopulatedPromptLen6, blockManager.getTokensPerBlock());
|
||||
llmRequest6->addNewToken(0, beamIdx);
|
||||
// no reuse occurs because we are unable to reuse last input token and inputLength6 == 1.
|
||||
EXPECT_EQ(llmRequest6->getContextCurrentPosition(), 0);
|
||||
@ -1068,7 +1094,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
|
||||
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
|
||||
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq0.getRequestId());
|
||||
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen0
|
||||
= blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
|
||||
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
|
||||
llmRequest0->addNewToken(3, beamIdx);
|
||||
@ -1100,7 +1128,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
|
||||
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
|
||||
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq1.getRequestId());
|
||||
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen1
|
||||
= blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock);
|
||||
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3}));
|
||||
llmRequest1->addNewToken(3, beamIdx);
|
||||
@ -1127,7 +1157,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
|
||||
promptLen0 = llmRequest0->getNumTokens(beamIdx);
|
||||
numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq0_dup.getRequestId());
|
||||
blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
prepopulatedPromptLen0
|
||||
= blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
|
||||
llmRequest0->addNewToken(3, beamIdx);
|
||||
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 2 * tokensPerBlock);
|
||||
EXPECT_THAT(seq0_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4}));
|
||||
@ -1149,7 +1181,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
|
||||
promptLen1 = llmRequest1->getNumTokens(beamIdx);
|
||||
numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq1_dup.getRequestId());
|
||||
blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
prepopulatedPromptLen1
|
||||
= blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), llmRequest1->getNumTokens(beamIdx) - 1);
|
||||
EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
|
||||
llmRequest1->addNewToken(5, beamIdx);
|
||||
@ -1183,7 +1217,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
|
||||
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
|
||||
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq2.getRequestId());
|
||||
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen2
|
||||
= blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0);
|
||||
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({5, 6, 7}));
|
||||
llmRequest2->addNewToken(3, beamIdx);
|
||||
@ -1209,7 +1245,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
|
||||
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
|
||||
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq3.getRequestId());
|
||||
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen3
|
||||
= blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
|
||||
llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest3->getContextCurrentPosition(), tokensPerBlock);
|
||||
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 8, 9}));
|
||||
llmRequest3->addNewToken(3, beamIdx);
|
||||
@ -1286,7 +1324,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest)
|
||||
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
|
||||
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq0.getRequestId());
|
||||
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen0
|
||||
= blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
|
||||
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
|
||||
llmRequest0->addNewToken(3, beamIdx);
|
||||
@ -1323,7 +1363,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest)
|
||||
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
|
||||
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq1.getRequestId());
|
||||
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen1
|
||||
= blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock);
|
||||
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3}));
|
||||
llmRequest1->addNewToken(3, beamIdx);
|
||||
@ -1358,7 +1400,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest)
|
||||
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
|
||||
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq2.getRequestId());
|
||||
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen2
|
||||
= blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0);
|
||||
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({4, 5, 6}));
|
||||
llmRequest2->addNewToken(9, beamIdx);
|
||||
@ -1391,7 +1435,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest)
|
||||
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
|
||||
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq3.getRequestId());
|
||||
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen3
|
||||
= blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
|
||||
llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest3->getContextCurrentPosition(),
|
||||
tokensPerBlock); // only reuse block 0 [100, 101, 102, 103] with same hash/offset
|
||||
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 7, 8}));
|
||||
@ -1462,7 +1508,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
|
||||
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
|
||||
// get new blocks 0, 1, 2 ([0,1,2,3], [4,5,6,7], [8])
|
||||
blockManager.holdSequence(seq0.getRequestId());
|
||||
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen0
|
||||
= blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
|
||||
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
|
||||
llmRequest0->addNewToken(9, beamIdx);
|
||||
@ -1492,7 +1540,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
|
||||
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
|
||||
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq1.getRequestId());
|
||||
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen1
|
||||
= blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock);
|
||||
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3}));
|
||||
llmRequest1->addNewToken(9, beamIdx);
|
||||
@ -1517,7 +1567,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
|
||||
promptLen0 = llmRequest0->getNumTokens(beamIdx);
|
||||
numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq0_dup.getRequestId());
|
||||
blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
prepopulatedPromptLen0
|
||||
= blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
|
||||
// nb! addNewToken adds new generated token, number of input tokens stay the same.
|
||||
// calling addNewToken before addSequence potentially triggers this error message:
|
||||
// Assertion failed: prepopulatedPromptLen < promptLen
|
||||
@ -1539,7 +1591,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
|
||||
numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
|
||||
// reuse 0, 1, 2(p) ([0,1,2,3], [4,5,6,7], [8])
|
||||
blockManager.holdSequence(seq1_dup.getRequestId());
|
||||
blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
prepopulatedPromptLen1
|
||||
= blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), llmRequest1->getNumTokens(beamIdx) - 1);
|
||||
EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
|
||||
llmRequest1->addNewToken(10, beamIdx);
|
||||
@ -1571,7 +1625,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
|
||||
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
|
||||
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq2.getRequestId());
|
||||
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen2
|
||||
= blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
|
||||
// no reuse expected. Input tokens match blocks 0 and 1, but lora task id differs.
|
||||
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0);
|
||||
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({5, 6, 7}));
|
||||
@ -1600,7 +1656,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
|
||||
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
|
||||
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq3.getRequestId());
|
||||
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen3
|
||||
= blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
|
||||
llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest3->getContextCurrentPosition(), promptLen3 - 2);
|
||||
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({5, 6, 7}));
|
||||
llmRequest3->addNewToken(11, beamIdx);
|
||||
@ -1630,7 +1688,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
|
||||
auto promptLen4 = llmRequest4->getNumTokens(beamIdx);
|
||||
auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq4.getRequestId());
|
||||
blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen4
|
||||
= blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
|
||||
llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), tokensPerBlock);
|
||||
EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 8}));
|
||||
llmRequest4->addNewToken(5, beamIdx);
|
||||
@ -1655,7 +1715,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
|
||||
auto promptLen5 = llmRequest5->getNumTokens(beamIdx);
|
||||
auto numContextBlocks5 = tc::ceilDiv(promptLen5, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq5.getRequestId());
|
||||
blockManager.addSequence(seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen5
|
||||
= blockManager.addSequence(seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow);
|
||||
llmRequest5->setPrepopulatedPromptLen(prepopulatedPromptLen5, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest5->getContextCurrentPosition(), 0);
|
||||
EXPECT_THAT(seq5.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({9, 10, 11}));
|
||||
llmRequest5->addNewToken(9, beamIdx);
|
||||
@ -1726,7 +1788,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
|
||||
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
|
||||
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq0.getRequestId());
|
||||
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen0
|
||||
= blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
|
||||
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
|
||||
llmRequest0->addNewToken(3, beamIdx);
|
||||
@ -1759,7 +1823,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
|
||||
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
|
||||
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq1.getRequestId());
|
||||
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen1
|
||||
= blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 0);
|
||||
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5}));
|
||||
llmRequest1->addNewToken(3, beamIdx);
|
||||
@ -1786,7 +1852,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
|
||||
numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
|
||||
// reuse blocks 0, 1 and get new block 6
|
||||
blockManager.holdSequence(seq0_dup.getRequestId());
|
||||
blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
prepopulatedPromptLen0
|
||||
= blockManager.addSequence(seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
|
||||
llmRequest0->addNewToken(3, beamIdx);
|
||||
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 2 * tokensPerBlock);
|
||||
EXPECT_THAT(seq0_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 6}));
|
||||
@ -1808,7 +1876,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
|
||||
promptLen1 = llmRequest1->getNumTokens(beamIdx);
|
||||
numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq1_dup.getRequestId());
|
||||
blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
prepopulatedPromptLen1
|
||||
= blockManager.addSequence(seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), llmRequest1->getNumTokens(beamIdx) - 1);
|
||||
EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5}));
|
||||
llmRequest1->addNewToken(5, beamIdx);
|
||||
@ -1841,7 +1911,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
|
||||
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
|
||||
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq2.getRequestId());
|
||||
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen2
|
||||
= blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0);
|
||||
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({7, 8, 9}));
|
||||
llmRequest2->addNewToken(3, beamIdx);
|
||||
@ -1867,7 +1939,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
|
||||
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
|
||||
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq3.getRequestId());
|
||||
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen3
|
||||
= blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
|
||||
llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest3->getContextCurrentPosition(), tokensPerBlock);
|
||||
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 10, 11}));
|
||||
llmRequest3->addNewToken(3, beamIdx);
|
||||
@ -1892,7 +1966,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest)
|
||||
auto promptLen4 = llmRequest4->getNumTokens(beamIdx);
|
||||
auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq4.getRequestId());
|
||||
blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen4
|
||||
= blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
|
||||
llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), tokensPerBlock);
|
||||
EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 12, 13}));
|
||||
llmRequest4->addNewToken(3, beamIdx);
|
||||
@ -1971,7 +2047,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
|
||||
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
|
||||
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq0.getRequestId());
|
||||
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen0
|
||||
= blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
|
||||
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
|
||||
|
||||
@ -2009,7 +2087,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
|
||||
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
|
||||
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq1.getRequestId());
|
||||
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen1
|
||||
= blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 0); // No reuse, starts from scratch
|
||||
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5}));
|
||||
|
||||
@ -2042,7 +2122,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
|
||||
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
|
||||
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq2.getRequestId());
|
||||
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen2
|
||||
= blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 3,4
|
||||
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 6}));
|
||||
|
||||
@ -2076,7 +2158,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
|
||||
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
|
||||
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq3.getRequestId());
|
||||
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen3
|
||||
= blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
|
||||
llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 0); // No reuse
|
||||
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({7, 8, 9}));
|
||||
|
||||
@ -2103,7 +2187,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest)
|
||||
auto promptLen4 = llmRequest4->getNumTokens(beamIdx);
|
||||
auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq4.getRequestId());
|
||||
blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen4
|
||||
= blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow);
|
||||
llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock());
|
||||
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 0,1
|
||||
EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 10}));
|
||||
|
||||
@ -2225,7 +2311,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest)
|
||||
GenerationRequest seq0{0, inputLength0, beamWidth, blockManager.getWindowSizesMetadata()};
|
||||
auto numContextBlocks0 = tc::ceilDiv(inputLength0, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq0.getRequestId());
|
||||
blockManager.addSequence(seq0, llmRequest0->getNumTokens(0), numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen0 = blockManager.addSequence(
|
||||
seq0, llmRequest0->getNumTokens(0), numContextBlocks0, *llmRequest0, maxAttentionWindow);
|
||||
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
|
||||
|
||||
// Add another sequence with different tokens, at a low priority
|
||||
auto inputTokens1 = std::make_shared<VecTokens>(VecTokens{8, 9, 10, 11, 12, 13, 14, 15});
|
||||
@ -2234,7 +2322,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest)
|
||||
GenerationRequest seq1{1, inputLength1, beamWidth, blockManager.getWindowSizesMetadata()};
|
||||
auto numContextBlocks1 = tc::ceilDiv(inputLength1, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq1.getRequestId());
|
||||
blockManager.addSequence(seq1, llmRequest1->getNumTokens(0), numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen1 = blockManager.addSequence(
|
||||
seq1, llmRequest1->getNumTokens(0), numContextBlocks1, *llmRequest1, maxAttentionWindow);
|
||||
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
|
||||
|
||||
// Release both sequences
|
||||
blockManager.releaseBlocks(seq0, llmRequest0);
|
||||
@ -2251,7 +2341,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest)
|
||||
GenerationRequest seq2{2, inputLength2, beamWidth, blockManager.getWindowSizesMetadata()};
|
||||
auto numContextBlocks2 = tc::ceilDiv(inputLength2, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq2.getRequestId());
|
||||
blockManager.addSequence(seq2, llmRequest2->getNumTokens(0), numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen2 = blockManager.addSequence(
|
||||
seq2, llmRequest2->getNumTokens(0), numContextBlocks2, *llmRequest2, maxAttentionWindow);
|
||||
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
|
||||
blockManager.releaseBlocks(seq2, llmRequest2);
|
||||
blockManager.releaseSequence(seq2.getRequestId());
|
||||
|
||||
@ -2262,7 +2354,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest)
|
||||
GenerationRequest seq3{3, inputLength3, beamWidth, blockManager.getWindowSizesMetadata()};
|
||||
auto numContextBlocks3 = tc::ceilDiv(inputLength3, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq3.getRequestId());
|
||||
blockManager.addSequence(seq3, llmRequest3->getNumTokens(0), numContextBlocks3, *llmRequest3, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen3 = blockManager.addSequence(
|
||||
seq3, llmRequest3->getNumTokens(0), numContextBlocks3, *llmRequest3, maxAttentionWindow);
|
||||
llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock());
|
||||
|
||||
EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 4);
|
||||
|
||||
@ -2277,7 +2371,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest)
|
||||
GenerationRequest seq4{4, inputLength3, beamWidth, blockManager.getWindowSizesMetadata()};
|
||||
auto numContextBlocks4 = tc::ceilDiv(inputLength4, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq4.getRequestId());
|
||||
blockManager.addSequence(seq4, llmRequest4->getNumTokens(0), numContextBlocks4, *llmRequest4, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen4 = blockManager.addSequence(
|
||||
seq4, llmRequest4->getNumTokens(0), numContextBlocks4, *llmRequest4, maxAttentionWindow);
|
||||
llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock());
|
||||
|
||||
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), 4);
|
||||
|
||||
@ -2288,7 +2384,9 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest)
|
||||
GenerationRequest seq5{5, inputLength5, beamWidth, blockManager.getWindowSizesMetadata()};
|
||||
auto numContextBlocks5 = tc::ceilDiv(inputLength5, blockManager.getTokensPerBlock());
|
||||
blockManager.holdSequence(seq5.getRequestId());
|
||||
blockManager.addSequence(seq5, llmRequest5->getNumTokens(0), numContextBlocks5, *llmRequest5, maxAttentionWindow);
|
||||
auto prepopulatedPromptLen5 = blockManager.addSequence(
|
||||
seq5, llmRequest5->getNumTokens(0), numContextBlocks5, *llmRequest5, maxAttentionWindow);
|
||||
llmRequest5->setPrepopulatedPromptLen(prepopulatedPromptLen5, blockManager.getTokensPerBlock());
|
||||
|
||||
EXPECT_EQ(llmRequest5->getContextCurrentPosition(), 0);
|
||||
}
|
||||
|
||||
@ -1217,10 +1217,6 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness):
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=
|
||||
"Currently failing due to accuracy drop, https://nvbugspro.nvidia.com/bug/5674665"
|
||||
)
|
||||
def test_auto_dtype_vswa_reuse_disable_overlap_scheduler(self):
|
||||
# NOTE: Test with VSWA kv cache config.
|
||||
kv_cache_config = KvCacheConfig(
|
||||
|
||||
@ -49,6 +49,7 @@ l0_h100:
|
||||
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse_low_memory_available_no_partial_reuse
|
||||
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse_low_memory_available_partial_reuse
|
||||
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_without_reuse_disable_overlap_scheduler
|
||||
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse_disable_overlap_scheduler
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] TIMEOUT (90)
|
||||
|
||||
@ -92,6 +92,7 @@ def test_kv_lens_runtime_with_eagle3_one_model():
|
||||
f"kv_lens should be {expected_kv_lens_with_extra.tolist()}, but got {kv_lens_internal.tolist()}"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="https://nvbugs/5856637")
|
||||
@pytest.mark.parametrize(
|
||||
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp,use_hf_speculative_model",
|
||||
[
|
||||
|
||||
Loading…
Reference in New Issue
Block a user