[https://nvbugs/5721661][fix] Prevent out-of-bounds read (#9879)

Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
This commit is contained in:
Thor Johnsen 2026-01-15 10:51:40 -06:00 committed by GitHub
parent dfac07c045
commit 0998a7bf20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 96 additions and 52 deletions

View File

@ -288,6 +288,9 @@ public:
void removeNextBlock(BlockKey const& blockKey);
void freeDescendantsRecursively();
void freeBlockAndAllDescendants();
//! \brief Find block matching blockKey. If allowPartial is true, the returned block may match only a prefix of
//! blockKey.
//! @return tuple of [partialMatch, numMatched, block], partialMatch is true if not all the tokens of the block were

View File

@ -477,6 +477,32 @@ void KVCacheBlock::removeNextBlock(BlockKey const& blockKey)
mNextBlocks.erase(blockKey);
}
void KVCacheBlock::freeDescendantsRecursively()
{
bool hasChildren = !mNextBlocks.empty();
if (hasChildren)
{
for (auto it = mNextBlocks.begin(); it != mNextBlocks.end();)
{
it->second->freeDescendantsRecursively();
TLLM_LOG_DEBUG("KVCacheBlock::freeDescendantsRecursively - Freeing block %d", it->second->getBlockId());
it = mNextBlocks.erase(it);
}
}
mPrevBlock = nullptr;
}
void KVCacheBlock::freeBlockAndAllDescendants()
{
// free from previous block
if (mPrevBlock != nullptr)
{
mPrevBlock->removeNextBlock(mBlockKey);
mPrevBlock = nullptr;
}
freeDescendantsRecursively();
}
bool KVCacheBlock::isFull() const
{
return mIsFull;
@ -956,19 +982,14 @@ void WindowBlockManager::freeLeafBlock(BlockPtr const& block)
void WindowBlockManager::freeChildren(BlockPtr const& block)
{
// Free all descendants of block
for (auto const& p : block->getNextBlocks())
{
auto childBlock = p.second;
freeChildren(childBlock);
}
// Free block
// Tell event manager we are freeing block
if (mEventManager && blockInRadixTree(block))
{
mEventManager->enqueueRemovedEvent(block, mWindowSize);
}
freeLeafBlock(block);
// Free block and all it's descendants from radix tree
block->freeBlockAndAllDescendants();
}
BlockPtr WindowBlockManager::getFreeBlock(GenerationRequest& sequence, executor::RetentionPriority priority,
@ -1567,60 +1588,80 @@ std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> WindowBlockManager::sto
auto searchRoot = mCachedBlocksRoot;
bool needMatch = true;
auto numBlocks = blockKeys.size();
// There is no guarantee that these vectors will be the same length.
// Only iterate as long as we have valid blockKey and blockId.
auto numBlocks = std::min(blockKeys.size(), blockIds.size());
std::vector<BlockPtr> storedBlocks;
std::vector<KVCacheBlock::IdType> pinnedBlockIds;
for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt)
{
auto const bid = blockIds[blockCnt];
TLLM_LOG_DEBUG("%s::storeBlocks - Searching match for block %d", mLogPrefix.c_str(), bid);
auto& block = mAllBlocksById[bid];
auto const& blockKey = blockKeys[blockCnt];
auto [partialMatch, numMatched, matchedBlock]
= needMatch ? searchRoot->findMatchingBlock(blockKey, false, false) : std::make_tuple(false, 0, nullptr);
if (matchedBlock != nullptr)
try
{
// Found match
TLLM_LOG_DEBUG(
"%s::storeBlocks - Found matching block %d, traverse", mLogPrefix.c_str(), matchedBlock->getBlockId());
searchRoot = matchedBlock;
// TODO possible optimization: if bid != matchedBlock->getBlockId(),
// block can be freed and inserted at mFreePrimaryBlocks.begin()
}
else
{
// No match
TLLM_LOG_DEBUG("%s::storeBlocks - No match, inserting block %d into search structure", mLogPrefix.c_str(),
block->getBlockId());
TLLM_CHECK_WITH_INFO(block->getBlockId() == bid,
"Block id mismatch " + std::to_string(block->getBlockId()) + " != " + std::to_string(bid));
needMatch = false; // no matching needed for following blocks
block->setBlockKey(blockKey, static_cast<SizeType32>(blockKey.uniqueTokens.size()) == mTokensPerBlock);
block->setPrevBlock(searchRoot);
block->setPrevBlockInSeq(searchRoot);
searchRoot->addNextBlock(blockKey, block);
// Protect against blockIds being shorter than blockKeys.
auto const bid = blockIds.at(blockCnt);
TLLM_LOG_DEBUG("%s::storeBlocks - Searching match for block %d", mLogPrefix.c_str(), bid);
// We set blockId to an invalid value to indicate that a block has been released early for a limited
// attention layer. Make sure we don't store an invalid block because of this.
auto& block = mAllBlocksById.at(bid);
// Protect against blockKeys being shorter than blockIds.
auto const& blockKey = blockKeys.at(blockCnt);
// Sanity check. The list of stored blocks should be connected.
TLLM_CHECK(storedBlocks.empty() || block->getPrevBlock() == storedBlocks.back());
// If either of the above error conditions occur, std::vector::at will throw an exception, which is caught
// further down. This will prevent an invalid block from being stored for reuse. The catch clause exits loop
// early, preventing blocks following an invalid block from being reused.
storedBlocks.push_back(block);
TLLM_CHECK(block->getPrevBlockInSeq() == nullptr
|| block->getPrevBlockInSeq()->getHash() == searchRoot->getHash());
auto oldHash = block->getHash();
auto newHash = BlockKeyHasher()(blockKey, searchRoot->getHash());
if (oldHash != newHash)
auto [partialMatch, numMatched, matchedBlock] = needMatch
? searchRoot->findMatchingBlock(blockKey, false, false)
: std::make_tuple(false, 0, nullptr);
if (matchedBlock != nullptr)
{
TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash);
block->setHash(newHash);
// Found match
TLLM_LOG_DEBUG("%s::storeBlocks - Found matching block %d, traverse", mLogPrefix.c_str(),
matchedBlock->getBlockId());
searchRoot = matchedBlock;
// TODO possible optimization: if bid != matchedBlock->getBlockId(),
// block can be freed and inserted at mFreePrimaryBlocks.begin()
}
else
{
// No match
TLLM_LOG_DEBUG("%s::storeBlocks - No match, inserting block %d into search structure",
mLogPrefix.c_str(), block->getBlockId());
TLLM_CHECK_WITH_INFO(block->getBlockId() == bid,
"Block id mismatch " + std::to_string(block->getBlockId()) + " != " + std::to_string(bid));
needMatch = false; // no matching needed for following blocks
block->setBlockKey(blockKey, static_cast<SizeType32>(blockKey.uniqueTokens.size()) == mTokensPerBlock);
block->setPrevBlock(searchRoot);
block->setPrevBlockInSeq(searchRoot);
searchRoot->addNextBlock(blockKey, block);
// Sanity check. The list of stored blocks should be connected.
TLLM_CHECK(storedBlocks.empty() || block->getPrevBlock() == storedBlocks.back());
storedBlocks.push_back(block);
TLLM_CHECK(block->getPrevBlockInSeq() == nullptr
|| block->getPrevBlockInSeq()->getHash() == searchRoot->getHash());
auto oldHash = block->getHash();
auto newHash = BlockKeyHasher()(blockKey, searchRoot->getHash());
if (oldHash != newHash)
{
TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash);
block->setHash(newHash);
}
searchRoot = block;
numBlocksStoredForReuse++;
}
if (pinBlocks)
{
searchRoot->incRefCount();
pinnedBlockIds.push_back(searchRoot->getBlockId());
}
searchRoot = block;
numBlocksStoredForReuse++;
}
if (pinBlocks)
catch (std::out_of_range const& ex)
{
searchRoot->incRefCount();
pinnedBlockIds.push_back(searchRoot->getBlockId());
TLLM_LOG_WARNING("Out of range access, terminating storeBlocks early.");
// Prevent blocks following an invalid block from being reused.
break;
}
}
if (mEventManager)