[https://nvbugs/5627710][fix] Fix synchronization bugs in KvCacheTransferManager that can cause corrupted blocks (#9056)

Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>
Co-authored-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Co-authored-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
This commit is contained in:
Thor Johnsen 2025-12-02 09:10:21 -06:00 committed by GitHub
parent b86256eb54
commit 95049eea86
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 172 additions and 19 deletions

View File

@ -824,6 +824,9 @@ public:
return mBufferManager; return mBufferManager;
} }
//! \brief Sync internal streams used by transfer manager with buffer manager stream
void syncTransferManagerWithBufferManager();
//! \brief Perform per-request bookkeeping //! \brief Perform per-request bookkeeping
void refreshBlocks(); void refreshBlocks();
@ -1313,6 +1316,9 @@ public:
//! \brief Store newest block for reuse //! \brief Store newest block for reuse
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest); void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
//! \brief Sync internal streams used by transfer manager with buffer manager stream
void syncTransferManagerWithBufferManager();
//! \brief Perform per-request bookkeeping //! \brief Perform per-request bookkeeping
void refreshBlocks(); void refreshBlocks();
@ -1584,6 +1590,7 @@ public:
[[nodiscard]] virtual runtime::ITensor::SharedPtr getIndexerKCachePool() const = 0; [[nodiscard]] virtual runtime::ITensor::SharedPtr getIndexerKCachePool() const = 0;
[[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0; [[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0;
virtual void syncTransferManagerWithBufferManager() = 0;
virtual void refreshBlocks() = 0; virtual void refreshBlocks() = 0;
virtual void flushIterationEvents() = 0; virtual void flushIterationEvents() = 0;
virtual void resetReuseState() = 0; virtual void resetReuseState() = 0;
@ -1965,6 +1972,11 @@ public:
return mBlockManager.getPoolLayerIdx(layer_idx); return mBlockManager.getPoolLayerIdx(layer_idx);
} }
void syncTransferManagerWithBufferManager() override
{
mBlockManager.syncTransferManagerWithBufferManager();
}
//! \brief Perform per-iteration bookkeeping //! \brief Perform per-iteration bookkeeping
void refreshBlocks() override void refreshBlocks() override
{ {

View File

@ -46,7 +46,15 @@ public:
int numTokensToCopy = 0, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, int numTokensToCopy = 0, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
std::string const& directory = ""); std::string const& directory = "");
//! \brief Synchronize the offload/onboard streams with the bufferManager stream. //! \brief Synchronize internal streams with bufferManager stream.
//! \details The buffer manager uses the same stream as the prefill and decode kernels. This method ensures that the
//! internal kernels used for offloading and onboarding will wait for prefill and decode kernels before performing
//! any block copies. This method must be called before the first call to KVCacheManager::addSequence in every step.
void syncWithBufferManager();
//! \brief Synchronize bufferManager stream with internal streams. This method ensures that prefill and decode
//! kernels for next step will wait for offloading and onboarding work that has already been scheduled. This method
//! must be called after last call to KVCacheManager::addSequence in every step.
void syncTransfers(); void syncTransfers();
private: private:
@ -75,8 +83,10 @@ private:
runtime::BufferManager mOnboardManager; runtime::BufferManager mOnboardManager;
runtime::BufferManager mOffloadManager; runtime::BufferManager mOffloadManager;
// Track the block ids offloaded in this iteration. // Track reads and writes for blocks. Note that it is the memory pool index that
std::unordered_map<int32_t, tr::CudaEvent> mPendingOffloads; // identifies the raw memory blocks involved in I/O, not the block Id.
std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingReads;
std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingWrites;
// Reference to parent loopback agent // Reference to parent loopback agent
std::shared_ptr<kvc::BaseLoopbackAgent> mLoopbackAgent; std::shared_ptr<kvc::BaseLoopbackAgent> mLoopbackAgent;
int mDeviceId; int mDeviceId;

View File

@ -26,6 +26,8 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(allocateKvCache); NVTX3_SCOPED_RANGE(allocateKvCache);
kvCacheManager.syncTransferManagerWithBufferManager();
for (auto const& llmReq : contextRequests) for (auto const& llmReq : contextRequests)
{ {
if (llmReq->isFirstContextChunk()) if (llmReq->isFirstContextChunk())

View File

@ -1336,6 +1336,19 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
return numMatchedTokens; return numMatchedTokens;
} }
void BlockManager::syncTransferManagerWithBufferManager()
{
for (auto& [_, manager] : mWindowBlockManagers)
{
manager.syncTransferManagerWithBufferManager();
}
}
void WindowBlockManager::syncTransferManagerWithBufferManager()
{
mTransferManager->syncWithBufferManager();
}
void BlockManager::refreshBlocks() void BlockManager::refreshBlocks()
{ {
for (auto& [_, manager] : mWindowBlockManagers) for (auto& [_, manager] : mWindowBlockManagers)

View File

@ -207,47 +207,140 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
} }
} }
void KVCacheTransferManager::onboard(BlockPtr const& offloadBlock, BlockPtr const& block, //
// Note about recording events to wait for cudaMempyAsync calls between blocks:
// The memory copy involves raw memory blocks, which are pointed to by the
// memory pool block index. When recording events, you must use getMemoryPoolBlockIndex()
// as the raw memory block identifier. Using getBlockId() when recording events is wrong.
// getBlockId() returns the logical block id, which has nothing to do with the raw memory
// block pointers involved in a cudaMemcpy.
//
//
// Notes about need for synchronization:
//
// Relying on decoder syncing GPU with CPU to ensure that blocks are ready
// for offload/onboard/partial copy is dangerous. We have an asynchronous decoder
// that may not synchronize or synchronize at a later point in the execution stream.
// To avoid synchronization issues caused by changes to decoder design we rely on
// KVCacheTransferManager::syncWithBufferManager() that ensures that internal copy streams
// will wait for prefill and decode kernels that have already been scheduled.
//
// Earlier versions of this code did not account for all possible cases where a new block copy
// needed to wait for a previously scheduled copy to finish. For instance, it is possible
// that two primary blocks are offloaded to the same secondary block in a single step,
// scheduling the second offloading without waiting for the first one to finish leads to
// a corrupted block after offloading. It is possible that partial reuse will copy
// from a block that is currently being onboarded, scheduling the partial copy without
// waiting for the onboarding to finish will lead to a corrupted block. To handle all
// possible cases needing synchronization we record separate events for reads and writes
// to a block. When a new block copy is scheduled, we wait for all writes to the source
// block and all reads and writes to a destination block.
//
// As before, syncTransfers() must be called after last call to KVCacheManager::addSequence.
// Failing to do so will lead to corrupted blocks eventually.
//
void KVCacheTransferManager::onboard(BlockPtr const& offloadedBlock, BlockPtr const& block,
std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode, std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
std::string const& directory) std::string const& directory)
{ {
if (mode != executor::KvCacheTransferMode::DRAM // Wait for any pending writes before reading from offloadedBlock
&& mPendingOffloads.find(offloadBlock->getBlockId()) == mPendingOffloads.end()) auto offloadedBlockPendingWriteItr = mPendingWrites.find(offloadedBlock->getMemoryPoolBlockIndex());
if (offloadedBlockPendingWriteItr != mPendingWrites.end())
{ {
TLLM_LOG_DEBUG("Skipping onboard for block %d because it was never previously offloaded to disk", mOnboardManager.getStream().wait(offloadedBlockPendingWriteItr->second);
offloadBlock->getBlockId()); // Don't erase, we are not changing state of offloadedBlock
return; }
// Wait for any pending reads before overwriting block
auto blockPendingReadItr = mPendingReads.find(block->getMemoryPoolBlockIndex());
if (blockPendingReadItr != mPendingReads.end())
{
mOnboardManager.getStream().wait(blockPendingReadItr->second);
mPendingReads.erase(blockPendingReadItr);
}
// Wait for any pending writes before overwriting block
auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
if (blockPendingWriteItr != mPendingWrites.end())
{
mOnboardManager.getStream().wait(blockPendingWriteItr->second);
mPendingWrites.erase(blockPendingWriteItr);
} }
if (mPendingOffloads.find(offloadBlock->getBlockId()) != mPendingOffloads.end()) copyBlock(offloadedBlock, block, pools, false, numTokensToCopy, mode, directory);
{
mOnboardManager.getStream().wait(mPendingOffloads[offloadBlock->getBlockId()]); // Record new pending read from offloadedBlock
} mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
copyBlock(offloadBlock, block, pools, false, numTokensToCopy, mode, directory); mOnboardManager.getStream().record(mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()]);
// Record new pending write to block
mPendingWrites[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
mOnboardManager.getStream().record(mPendingWrites[block->getMemoryPoolBlockIndex()]);
} }
void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock, void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock,
std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode, std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
std::string const& directory) std::string const& directory)
{ {
mPendingOffloads[block->getBlockId()] = tr::CudaEvent(); // Wait for any pending writes before reading from block
auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
if (blockPendingWriteItr != mPendingWrites.end())
{
mOffloadManager.getStream().wait(blockPendingWriteItr->second);
// Don't erase, we are not changing state of block
}
// Wait for any pending reads before overwriting offloadBlock
auto offloadBlockPendingReadItr = mPendingReads.find(offloadBlock->getMemoryPoolBlockIndex());
if (offloadBlockPendingReadItr != mPendingReads.end())
{
mOffloadManager.getStream().wait(offloadBlockPendingReadItr->second);
mPendingReads.erase(offloadBlockPendingReadItr);
}
// Wait for any pending writes before overwriting offloadBlock
auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlock->getMemoryPoolBlockIndex());
if (offloadBlockPendingWriteItr != mPendingWrites.end())
{
mOffloadManager.getStream().wait(offloadBlockPendingWriteItr->second);
mPendingWrites.erase(offloadBlockPendingWriteItr);
}
copyBlock(block, offloadBlock, pools, true, numTokensToCopy, mode, directory); copyBlock(block, offloadBlock, pools, true, numTokensToCopy, mode, directory);
mOffloadManager.getStream().record(mPendingOffloads[block->getBlockId()]);
// Record new pending read from block
mPendingReads[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
mOffloadManager.getStream().record(mPendingReads[block->getMemoryPoolBlockIndex()]);
// Record new pending write to offloadBlock
mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
mOffloadManager.getStream().record(mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()]);
}
void KVCacheTransferManager::syncWithBufferManager()
{
tr::CudaEvent readyForOffloadEvent;
mBufferManager.getStream().record(readyForOffloadEvent);
mOffloadManager.getStream().wait(readyForOffloadEvent);
tr::CudaEvent readyForOnboardEvent;
mBufferManager.getStream().record(readyForOnboardEvent);
mOnboardManager.getStream().wait(readyForOnboardEvent);
// Once we synchronize, clear our list of pending thransfers.
mPendingReads.clear();
mPendingWrites.clear();
} }
void KVCacheTransferManager::syncTransfers() void KVCacheTransferManager::syncTransfers()
{ {
tr::CudaEvent offloadEvent; tr::CudaEvent offloadEvent;
mOffloadManager.getStream().record(offloadEvent); mOffloadManager.getStream().record(offloadEvent);
mBufferManager.getStream().wait(offloadEvent);
tr::CudaEvent onboardEvent; tr::CudaEvent onboardEvent;
mOnboardManager.getStream().record(onboardEvent); mOnboardManager.getStream().record(onboardEvent);
mBufferManager.getStream().wait(offloadEvent);
mBufferManager.getStream().wait(onboardEvent); mBufferManager.getStream().wait(onboardEvent);
// Once we synchronize, clear our list of pending thransfers. // Once we synchronize, clear our list of pending thransfers.
mPendingOffloads.clear(); mPendingReads.clear();
mPendingWrites.clear();
} }
} // namespace tensorrt_llm::batch_manager::kv_cache_manager } // namespace tensorrt_llm::batch_manager::kv_cache_manager

View File

@ -235,6 +235,11 @@ public:
NB_OVERRIDE_PURE(getPoolLayerIdx, layer_idx); NB_OVERRIDE_PURE(getPoolLayerIdx, layer_idx);
} }
void syncTransferManagerWithBufferManager() override
{
NB_OVERRIDE_PURE(syncTransferManagerWithBufferManager);
}
void refreshBlocks() override void refreshBlocks() override
{ {
NB_OVERRIDE_PURE(refreshBlocks); NB_OVERRIDE_PURE(refreshBlocks);
@ -481,6 +486,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
nb::call_guard<nb::gil_scoped_release>()) nb::call_guard<nb::gil_scoped_release>())
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
nb::call_guard<nb::gil_scoped_release>()) nb::call_guard<nb::gil_scoped_release>())
.def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager,
nb::call_guard<nb::gil_scoped_release>())
.def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, nb::call_guard<nb::gil_scoped_release>())
.def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard<nb::gil_scoped_release>()) .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard<nb::gil_scoped_release>())
.def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>()) .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>())
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>()); .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>());

View File

@ -240,6 +240,11 @@ public:
PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getPoolLayerIdx, layer_idx); PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getPoolLayerIdx, layer_idx);
} }
void syncTransferManagerWithBufferManager() override
{
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, syncTransferManagerWithBufferManager);
}
void refreshBlocks() override void refreshBlocks() override
{ {
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, refreshBlocks); PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, refreshBlocks);
@ -485,6 +490,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager,
py::call_guard<py::gil_scoped_release>())
.def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, py::call_guard<py::gil_scoped_release>())
.def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard<py::gil_scoped_release>()) .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard<py::gil_scoped_release>())
.def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>()) .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>())
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, py::call_guard<py::gil_scoped_release>()); .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, py::call_guard<py::gil_scoped_release>());

View File

@ -434,6 +434,10 @@ class KVCacheManager(BaseResourceManager):
with request_context(self.is_draft, scheduled_batch): with request_context(self.is_draft, scheduled_batch):
context_batch = scheduled_batch.context_requests context_batch = scheduled_batch.context_requests
generation_batch = scheduled_batch.generation_requests generation_batch = scheduled_batch.generation_requests
# wait for all pending work to finish before launching offload/onboarding/partial copy
self.impl.sync_transfer_manager_with_buffer_manager()
# allocate KV Cache # allocate KV Cache
for req in context_batch: for req in context_batch:
req_beam_width = req.sampling_config.beam_width req_beam_width = req.sampling_config.beam_width
@ -475,6 +479,9 @@ class KVCacheManager(BaseResourceManager):
for _ in range(get_draft_token_length(req)): for _ in range(get_draft_token_length(req)):
self.impl.add_token(req.py_request_id) self.impl.add_token(req.py_request_id)
# prefill and generation kernels wait for scheduled offload/onboard/partial copy work before launching
self.impl.refresh_blocks()
if self.kv_connector_manager is not None: if self.kv_connector_manager is not None:
self.kv_connector_manager.build_scheduler_output( self.kv_connector_manager.build_scheduler_output(
scheduled_batch, self) scheduled_batch, self)