mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-03 17:52:19 +08:00
[https://nvbugs/5721661][fix] Prevent out-of-bounds read (#9879)
Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
This commit is contained in:
parent
dfac07c045
commit
0998a7bf20
@ -288,6 +288,9 @@ public:
|
||||
|
||||
void removeNextBlock(BlockKey const& blockKey);
|
||||
|
||||
void freeDescendantsRecursively();
|
||||
void freeBlockAndAllDescendants();
|
||||
|
||||
//! \brief Find block matching blockKey. If allowPartial is true, the returned block may match only a prefix of
|
||||
//! blockKey.
|
||||
//! @return tuple of [partialMatch, numMatched, block], partialMatch is true if not all the tokens of the block were
|
||||
|
||||
@ -477,6 +477,32 @@ void KVCacheBlock::removeNextBlock(BlockKey const& blockKey)
|
||||
mNextBlocks.erase(blockKey);
|
||||
}
|
||||
|
||||
void KVCacheBlock::freeDescendantsRecursively()
|
||||
{
|
||||
bool hasChildren = !mNextBlocks.empty();
|
||||
if (hasChildren)
|
||||
{
|
||||
for (auto it = mNextBlocks.begin(); it != mNextBlocks.end();)
|
||||
{
|
||||
it->second->freeDescendantsRecursively();
|
||||
TLLM_LOG_DEBUG("KVCacheBlock::freeDescendantsRecursively - Freeing block %d", it->second->getBlockId());
|
||||
it = mNextBlocks.erase(it);
|
||||
}
|
||||
}
|
||||
mPrevBlock = nullptr;
|
||||
}
|
||||
|
||||
void KVCacheBlock::freeBlockAndAllDescendants()
|
||||
{
|
||||
// free from previous block
|
||||
if (mPrevBlock != nullptr)
|
||||
{
|
||||
mPrevBlock->removeNextBlock(mBlockKey);
|
||||
mPrevBlock = nullptr;
|
||||
}
|
||||
freeDescendantsRecursively();
|
||||
}
|
||||
|
||||
bool KVCacheBlock::isFull() const
|
||||
{
|
||||
return mIsFull;
|
||||
@ -956,19 +982,14 @@ void WindowBlockManager::freeLeafBlock(BlockPtr const& block)
|
||||
|
||||
void WindowBlockManager::freeChildren(BlockPtr const& block)
|
||||
{
|
||||
// Free all descendants of block
|
||||
for (auto const& p : block->getNextBlocks())
|
||||
{
|
||||
auto childBlock = p.second;
|
||||
freeChildren(childBlock);
|
||||
}
|
||||
|
||||
// Free block
|
||||
// Tell event manager we are freeing block
|
||||
if (mEventManager && blockInRadixTree(block))
|
||||
{
|
||||
mEventManager->enqueueRemovedEvent(block, mWindowSize);
|
||||
}
|
||||
freeLeafBlock(block);
|
||||
|
||||
// Free block and all it's descendants from radix tree
|
||||
block->freeBlockAndAllDescendants();
|
||||
}
|
||||
|
||||
BlockPtr WindowBlockManager::getFreeBlock(GenerationRequest& sequence, executor::RetentionPriority priority,
|
||||
@ -1567,60 +1588,80 @@ std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> WindowBlockManager::sto
|
||||
auto searchRoot = mCachedBlocksRoot;
|
||||
bool needMatch = true;
|
||||
|
||||
auto numBlocks = blockKeys.size();
|
||||
// There is no guarantee that these vectors will be the same length.
|
||||
// Only iterate as long as we have valid blockKey and blockId.
|
||||
auto numBlocks = std::min(blockKeys.size(), blockIds.size());
|
||||
std::vector<BlockPtr> storedBlocks;
|
||||
std::vector<KVCacheBlock::IdType> pinnedBlockIds;
|
||||
for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt)
|
||||
{
|
||||
auto const bid = blockIds[blockCnt];
|
||||
TLLM_LOG_DEBUG("%s::storeBlocks - Searching match for block %d", mLogPrefix.c_str(), bid);
|
||||
auto& block = mAllBlocksById[bid];
|
||||
auto const& blockKey = blockKeys[blockCnt];
|
||||
|
||||
auto [partialMatch, numMatched, matchedBlock]
|
||||
= needMatch ? searchRoot->findMatchingBlock(blockKey, false, false) : std::make_tuple(false, 0, nullptr);
|
||||
if (matchedBlock != nullptr)
|
||||
try
|
||||
{
|
||||
// Found match
|
||||
TLLM_LOG_DEBUG(
|
||||
"%s::storeBlocks - Found matching block %d, traverse", mLogPrefix.c_str(), matchedBlock->getBlockId());
|
||||
searchRoot = matchedBlock;
|
||||
// TODO possible optimization: if bid != matchedBlock->getBlockId(),
|
||||
// block can be freed and inserted at mFreePrimaryBlocks.begin()
|
||||
}
|
||||
else
|
||||
{
|
||||
// No match
|
||||
TLLM_LOG_DEBUG("%s::storeBlocks - No match, inserting block %d into search structure", mLogPrefix.c_str(),
|
||||
block->getBlockId());
|
||||
TLLM_CHECK_WITH_INFO(block->getBlockId() == bid,
|
||||
"Block id mismatch " + std::to_string(block->getBlockId()) + " != " + std::to_string(bid));
|
||||
needMatch = false; // no matching needed for following blocks
|
||||
block->setBlockKey(blockKey, static_cast<SizeType32>(blockKey.uniqueTokens.size()) == mTokensPerBlock);
|
||||
block->setPrevBlock(searchRoot);
|
||||
block->setPrevBlockInSeq(searchRoot);
|
||||
searchRoot->addNextBlock(blockKey, block);
|
||||
// Protect against blockIds being shorter than blockKeys.
|
||||
auto const bid = blockIds.at(blockCnt);
|
||||
TLLM_LOG_DEBUG("%s::storeBlocks - Searching match for block %d", mLogPrefix.c_str(), bid);
|
||||
// We set blockId to an invalid value to indicate that a block has been released early for a limited
|
||||
// attention layer. Make sure we don't store an invalid block because of this.
|
||||
auto& block = mAllBlocksById.at(bid);
|
||||
// Protect against blockKeys being shorter than blockIds.
|
||||
auto const& blockKey = blockKeys.at(blockCnt);
|
||||
|
||||
// Sanity check. The list of stored blocks should be connected.
|
||||
TLLM_CHECK(storedBlocks.empty() || block->getPrevBlock() == storedBlocks.back());
|
||||
// If either of the above error conditions occur, std::vector::at will throw an exception, which is caught
|
||||
// further down. This will prevent an invalid block from being stored for reuse. The catch clause exits loop
|
||||
// early, preventing blocks following an invalid block from being reused.
|
||||
|
||||
storedBlocks.push_back(block);
|
||||
TLLM_CHECK(block->getPrevBlockInSeq() == nullptr
|
||||
|| block->getPrevBlockInSeq()->getHash() == searchRoot->getHash());
|
||||
auto oldHash = block->getHash();
|
||||
auto newHash = BlockKeyHasher()(blockKey, searchRoot->getHash());
|
||||
if (oldHash != newHash)
|
||||
auto [partialMatch, numMatched, matchedBlock] = needMatch
|
||||
? searchRoot->findMatchingBlock(blockKey, false, false)
|
||||
: std::make_tuple(false, 0, nullptr);
|
||||
if (matchedBlock != nullptr)
|
||||
{
|
||||
TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash);
|
||||
block->setHash(newHash);
|
||||
// Found match
|
||||
TLLM_LOG_DEBUG("%s::storeBlocks - Found matching block %d, traverse", mLogPrefix.c_str(),
|
||||
matchedBlock->getBlockId());
|
||||
searchRoot = matchedBlock;
|
||||
// TODO possible optimization: if bid != matchedBlock->getBlockId(),
|
||||
// block can be freed and inserted at mFreePrimaryBlocks.begin()
|
||||
}
|
||||
else
|
||||
{
|
||||
// No match
|
||||
TLLM_LOG_DEBUG("%s::storeBlocks - No match, inserting block %d into search structure",
|
||||
mLogPrefix.c_str(), block->getBlockId());
|
||||
TLLM_CHECK_WITH_INFO(block->getBlockId() == bid,
|
||||
"Block id mismatch " + std::to_string(block->getBlockId()) + " != " + std::to_string(bid));
|
||||
needMatch = false; // no matching needed for following blocks
|
||||
block->setBlockKey(blockKey, static_cast<SizeType32>(blockKey.uniqueTokens.size()) == mTokensPerBlock);
|
||||
block->setPrevBlock(searchRoot);
|
||||
block->setPrevBlockInSeq(searchRoot);
|
||||
searchRoot->addNextBlock(blockKey, block);
|
||||
|
||||
// Sanity check. The list of stored blocks should be connected.
|
||||
TLLM_CHECK(storedBlocks.empty() || block->getPrevBlock() == storedBlocks.back());
|
||||
|
||||
storedBlocks.push_back(block);
|
||||
TLLM_CHECK(block->getPrevBlockInSeq() == nullptr
|
||||
|| block->getPrevBlockInSeq()->getHash() == searchRoot->getHash());
|
||||
auto oldHash = block->getHash();
|
||||
auto newHash = BlockKeyHasher()(blockKey, searchRoot->getHash());
|
||||
if (oldHash != newHash)
|
||||
{
|
||||
TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash);
|
||||
block->setHash(newHash);
|
||||
}
|
||||
searchRoot = block;
|
||||
numBlocksStoredForReuse++;
|
||||
}
|
||||
if (pinBlocks)
|
||||
{
|
||||
searchRoot->incRefCount();
|
||||
pinnedBlockIds.push_back(searchRoot->getBlockId());
|
||||
}
|
||||
searchRoot = block;
|
||||
numBlocksStoredForReuse++;
|
||||
}
|
||||
if (pinBlocks)
|
||||
catch (std::out_of_range const& ex)
|
||||
{
|
||||
searchRoot->incRefCount();
|
||||
pinnedBlockIds.push_back(searchRoot->getBlockId());
|
||||
TLLM_LOG_WARNING("Out of range access, terminating storeBlocks early.");
|
||||
// Prevent blocks following an invalid block from being reused.
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (mEventManager)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user