[https://nvbugs/5627710][fix] Fix synchronization bugs in KvCacheTransferManager that can cause corrupted blocks (#9056)
Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>
Co-authored-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Co-authored-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
commit 95049eea86
parent b86256eb54
@@ -824,6 +824,9 @@ public:
         return mBufferManager;
     }
 
+    //! \brief Sync internal streams used by transfer manager with buffer manager stream
+    void syncTransferManagerWithBufferManager();
+
     //! \brief Perform per-request bookkeeping
     void refreshBlocks();
@@ -1313,6 +1316,9 @@ public:
     //! \brief Store newest block for reuse
     void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
 
+    //! \brief Sync internal streams used by transfer manager with buffer manager stream
+    void syncTransferManagerWithBufferManager();
+
     //! \brief Perform per-request bookkeeping
     void refreshBlocks();
@@ -1584,6 +1590,7 @@ public:
     [[nodiscard]] virtual runtime::ITensor::SharedPtr getIndexerKCachePool() const = 0;
     [[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0;
 
+    virtual void syncTransferManagerWithBufferManager() = 0;
     virtual void refreshBlocks() = 0;
     virtual void flushIterationEvents() = 0;
     virtual void resetReuseState() = 0;
@@ -1965,6 +1972,11 @@ public:
         return mBlockManager.getPoolLayerIdx(layer_idx);
     }
 
+    void syncTransferManagerWithBufferManager() override
+    {
+        mBlockManager.syncTransferManagerWithBufferManager();
+    }
+
     //! \brief Perform per-iteration bookkeeping
     void refreshBlocks() override
     {
@@ -46,7 +46,15 @@ public:
         int numTokensToCopy = 0, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
         std::string const& directory = "");
 
-    //! \brief Synchronize the offload/onboard streams with the bufferManager stream.
+    //! \brief Synchronize internal streams with the bufferManager stream.
+    //! \details The buffer manager uses the same stream as the prefill and decode kernels. This method ensures that
+    //! the internal kernels used for offloading and onboarding will wait for prefill and decode kernels before
+    //! performing any block copies. It must be called before the first call to KVCacheManager::addSequence in every step.
+    void syncWithBufferManager();
+
+    //! \brief Synchronize the bufferManager stream with the internal streams. This method ensures that prefill and
+    //! decode kernels for the next step will wait for offloading and onboarding work that has already been scheduled.
+    //! It must be called after the last call to KVCacheManager::addSequence in every step.
     void syncTransfers();
 
 private:
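The two doc comments above pin down a per-step protocol: synchronize the internal streams with the buffer manager before scheduling any block copies, then hand the scheduled copies back to the buffer manager stream once the last sequence has been added. A minimal sketch of that ordering, assuming a hypothetical runStep driver (only the three KVCacheTransferManager calls and the default onboard arguments come from this diff; everything else is illustrative):

void runStep(KVCacheTransferManager& transferManager, BlockPtr const& offloadedBlock, BlockPtr const& block,
    std::vector<KVCacheBlockPool> const& pools)
{
    // 1. Offload/onboard streams wait for prefill/decode kernels already queued on
    //    the buffer manager stream; must precede the first addSequence of the step.
    transferManager.syncWithBufferManager();

    // 2. Block movement scheduled for this step (in the real code this happens
    //    inside KVCacheManager::addSequence via the block managers).
    transferManager.onboard(offloadedBlock, block, pools);

    // 3. The buffer manager stream, i.e. the next prefill/decode kernels, waits for
    //    every copy scheduled above; must follow the last addSequence of the step.
    transferManager.syncTransfers();
}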
@@ -75,8 +83,10 @@ private:
     runtime::BufferManager mOnboardManager;
     runtime::BufferManager mOffloadManager;
 
-    // Track the block ids offloaded in this iteration.
-    std::unordered_map<int32_t, tr::CudaEvent> mPendingOffloads;
+    // Track reads and writes for blocks. Note that it is the memory pool index that
+    // identifies the raw memory blocks involved in I/O, not the block Id.
+    std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingReads;
+    std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingWrites;
     // Reference to parent loopback agent
     std::shared_ptr<kvc::BaseLoopbackAgent> mLoopbackAgent;
     int mDeviceId;
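The comment on mPendingReads/mPendingWrites is the key invariant of this fix: pending copies are keyed by the memory pool block index, because that is what identifies the bytes a cudaMemcpyAsync actually touches. A simplified sketch of why the logical block id is the wrong key (the Block struct and aliases below are hypothetical stand-ins for the real KVCacheBlock and kernels::KVCacheIndex types):

#include <cstdint>
#include <unordered_map>

struct CudaEvent { /* wraps a cudaEvent_t in the real code */ };

using PoolIndex = std::int32_t; // stand-in for kernels::KVCacheIndex::UnderlyingType

struct Block
{
    std::int32_t blockId;           // logical id; stable while the block migrates between primary and secondary memory
    PoolIndex memoryPoolBlockIndex; // raw pool slot; this is what a pending copy actually reads or writes
};

// A pool slot can host different logical blocks over time, and a logical block can move
// between slots, so pending I/O must be tracked per slot, never per blockId.
std::unordered_map<PoolIndex, CudaEvent> pendingReads;
std::unordered_map<PoolIndex, CudaEvent> pendingWrites;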
@@ -26,6 +26,8 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
     NVTX3_SCOPED_RANGE(allocateKvCache);
 
+    kvCacheManager.syncTransferManagerWithBufferManager();
+
     for (auto const& llmReq : contextRequests)
     {
         if (llmReq->isFirstContextChunk())
@@ -1336,6 +1336,19 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
     return numMatchedTokens;
 }
 
+void BlockManager::syncTransferManagerWithBufferManager()
+{
+    for (auto& [_, manager] : mWindowBlockManagers)
+    {
+        manager.syncTransferManagerWithBufferManager();
+    }
+}
+
+void WindowBlockManager::syncTransferManagerWithBufferManager()
+{
+    mTransferManager->syncWithBufferManager();
+}
+
 void BlockManager::refreshBlocks()
 {
     for (auto& [_, manager] : mWindowBlockManagers)
@@ -207,47 +207,140 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
     }
 }
 
-void KVCacheTransferManager::onboard(BlockPtr const& offloadBlock, BlockPtr const& block,
+//
+// Note about recording events to wait for cudaMemcpyAsync calls between blocks:
+// The memory copy involves raw memory blocks, which are pointed to by the
+// memory pool block index. When recording events, you must use getMemoryPoolBlockIndex()
+// as the raw memory block identifier. Using getBlockId() when recording events is wrong.
+// getBlockId() returns the logical block id, which has nothing to do with the raw memory
+// block pointers involved in a cudaMemcpy.
+//
+
+//
+// Notes about the need for synchronization:
+//
+// Relying on the decoder syncing GPU with CPU to ensure that blocks are ready
+// for offload/onboard/partial copy is dangerous. We have an asynchronous decoder
+// that may not synchronize, or may synchronize at a later point in the execution stream.
+// To avoid synchronization issues caused by changes to the decoder design, we rely on
+// KVCacheTransferManager::syncWithBufferManager(), which ensures that the internal copy streams
+// will wait for prefill and decode kernels that have already been scheduled.
+//
+// Earlier versions of this code did not account for all possible cases where a new block copy
+// needed to wait for a previously scheduled copy to finish. For instance, it is possible
+// that two primary blocks are offloaded to the same secondary block in a single step;
+// scheduling the second offload without waiting for the first one to finish leads to
+// a corrupted block after offloading. It is also possible that partial reuse will copy
+// from a block that is currently being onboarded; scheduling the partial copy without
+// waiting for the onboarding to finish will lead to a corrupted block. To handle all
+// possible cases needing synchronization, we record separate events for reads and writes
+// to a block. When a new block copy is scheduled, we wait for all writes to the source
+// block and all reads and writes to the destination block.
+//
+// As before, syncTransfers() must be called after the last call to KVCacheManager::addSequence.
+// Failing to do so will eventually lead to corrupted blocks.
+//
+
+void KVCacheTransferManager::onboard(BlockPtr const& offloadedBlock, BlockPtr const& block,
     std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
     std::string const& directory)
 {
-    if (mode != executor::KvCacheTransferMode::DRAM
-        && mPendingOffloads.find(offloadBlock->getBlockId()) == mPendingOffloads.end())
+    // Wait for any pending writes before reading from offloadedBlock
+    auto offloadedBlockPendingWriteItr = mPendingWrites.find(offloadedBlock->getMemoryPoolBlockIndex());
+    if (offloadedBlockPendingWriteItr != mPendingWrites.end())
     {
-        TLLM_LOG_DEBUG("Skipping onboard for block %d because it was never previously offloaded to disk",
-            offloadBlock->getBlockId());
-        return;
+        mOnboardManager.getStream().wait(offloadedBlockPendingWriteItr->second);
+        // Don't erase, we are not changing the state of offloadedBlock
     }
+    // Wait for any pending reads before overwriting block
+    auto blockPendingReadItr = mPendingReads.find(block->getMemoryPoolBlockIndex());
+    if (blockPendingReadItr != mPendingReads.end())
+    {
+        mOnboardManager.getStream().wait(blockPendingReadItr->second);
+        mPendingReads.erase(blockPendingReadItr);
+    }
+    // Wait for any pending writes before overwriting block
+    auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
+    if (blockPendingWriteItr != mPendingWrites.end())
+    {
+        mOnboardManager.getStream().wait(blockPendingWriteItr->second);
+        mPendingWrites.erase(blockPendingWriteItr);
+    }
 
-    if (mPendingOffloads.find(offloadBlock->getBlockId()) != mPendingOffloads.end())
-    {
-        mOnboardManager.getStream().wait(mPendingOffloads[offloadBlock->getBlockId()]);
-    }
-    copyBlock(offloadBlock, block, pools, false, numTokensToCopy, mode, directory);
+    copyBlock(offloadedBlock, block, pools, false, numTokensToCopy, mode, directory);
+
+    // Record new pending read from offloadedBlock
+    mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
+    mOnboardManager.getStream().record(mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()]);
+    // Record new pending write to block
+    mPendingWrites[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
+    mOnboardManager.getStream().record(mPendingWrites[block->getMemoryPoolBlockIndex()]);
 }
 
 void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock,
     std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
     std::string const& directory)
 {
-    mPendingOffloads[block->getBlockId()] = tr::CudaEvent();
+    // Wait for any pending writes before reading from block
+    auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
+    if (blockPendingWriteItr != mPendingWrites.end())
+    {
+        mOffloadManager.getStream().wait(blockPendingWriteItr->second);
+        // Don't erase, we are not changing the state of block
+    }
+    // Wait for any pending reads before overwriting offloadBlock
+    auto offloadBlockPendingReadItr = mPendingReads.find(offloadBlock->getMemoryPoolBlockIndex());
+    if (offloadBlockPendingReadItr != mPendingReads.end())
+    {
+        mOffloadManager.getStream().wait(offloadBlockPendingReadItr->second);
+        mPendingReads.erase(offloadBlockPendingReadItr);
+    }
+    // Wait for any pending writes before overwriting offloadBlock
+    auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlock->getMemoryPoolBlockIndex());
+    if (offloadBlockPendingWriteItr != mPendingWrites.end())
+    {
+        mOffloadManager.getStream().wait(offloadBlockPendingWriteItr->second);
+        mPendingWrites.erase(offloadBlockPendingWriteItr);
+    }
 
     copyBlock(block, offloadBlock, pools, true, numTokensToCopy, mode, directory);
-    mOffloadManager.getStream().record(mPendingOffloads[block->getBlockId()]);
+
+    // Record new pending read from block
+    mPendingReads[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
+    mOffloadManager.getStream().record(mPendingReads[block->getMemoryPoolBlockIndex()]);
+    // Record new pending write to offloadBlock
+    mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
+    mOffloadManager.getStream().record(mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()]);
 }
 
+void KVCacheTransferManager::syncWithBufferManager()
+{
+    tr::CudaEvent readyForOffloadEvent;
+    mBufferManager.getStream().record(readyForOffloadEvent);
+    mOffloadManager.getStream().wait(readyForOffloadEvent);
+
+    tr::CudaEvent readyForOnboardEvent;
+    mBufferManager.getStream().record(readyForOnboardEvent);
+    mOnboardManager.getStream().wait(readyForOnboardEvent);
+
+    // Once we synchronize, clear our list of pending transfers.
+    mPendingReads.clear();
+    mPendingWrites.clear();
+}
+
 void KVCacheTransferManager::syncTransfers()
 {
     tr::CudaEvent offloadEvent;
     mOffloadManager.getStream().record(offloadEvent);
-    mBufferManager.getStream().wait(offloadEvent);
 
     tr::CudaEvent onboardEvent;
     mOnboardManager.getStream().record(onboardEvent);
+
+    mBufferManager.getStream().wait(offloadEvent);
     mBufferManager.getStream().wait(onboardEvent);
 
     // Once we synchronize, clear our list of pending transfers.
-    mPendingOffloads.clear();
+    mPendingReads.clear();
+    mPendingWrites.clear();
 }
 
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager
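All of the wait/record calls above are instances of the standard CUDA idiom for ordering work across streams without blocking the host: record an event on the stream that produces data, and make the consuming stream wait on it. A self-contained sketch with the raw runtime API (the production code goes through runtime::BufferManager streams and tr::CudaEvent instead, and the buffers and streams here are placeholders):

#include <cuda_runtime.h>
#include <cstddef>

// Serialize a second copy into `dst` (e.g. a second offload targeting the same
// secondary block, or a partial-reuse read) behind a first copy that writes `dst`.
void orderedCopies(void* dst, void const* srcA, void const* srcB, std::size_t bytes,
    cudaStream_t streamA, cudaStream_t streamB)
{
    cudaEvent_t writeDone;
    cudaEventCreateWithFlags(&writeDone, cudaEventDisableTiming);

    // First copy writes dst on streamA, then records an event marking the write.
    cudaMemcpyAsync(dst, srcA, bytes, cudaMemcpyDefault, streamA);
    cudaEventRecord(writeDone, streamA);

    // streamB must not touch dst until that write finishes, so it waits on the
    // event (device-side wait; the host is never blocked).
    cudaStreamWaitEvent(streamB, writeDone, 0);
    cudaMemcpyAsync(dst, srcB, bytes, cudaMemcpyDefault, streamB);

    cudaEventDestroy(writeDone);
}

syncWithBufferManager() and syncTransfers() use the same primitive in both directions: the former makes the offload/onboard streams wait on events recorded on the buffer manager stream, the latter makes the buffer manager stream wait on events recorded on the offload and onboard streams.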
@@ -235,6 +235,11 @@ public:
         NB_OVERRIDE_PURE(getPoolLayerIdx, layer_idx);
     }
 
+    void syncTransferManagerWithBufferManager() override
+    {
+        NB_OVERRIDE_PURE(syncTransferManagerWithBufferManager);
+    }
+
     void refreshBlocks() override
     {
         NB_OVERRIDE_PURE(refreshBlocks);
@@ -481,6 +486,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
             nb::call_guard<nb::gil_scoped_release>())
         .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
             nb::call_guard<nb::gil_scoped_release>())
+        .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager,
+            nb::call_guard<nb::gil_scoped_release>())
         .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, nb::call_guard<nb::gil_scoped_release>())
         .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard<nb::gil_scoped_release>())
         .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>())
         .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>());
@@ -240,6 +240,11 @@ public:
         PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getPoolLayerIdx, layer_idx);
     }
 
+    void syncTransferManagerWithBufferManager() override
+    {
+        PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, syncTransferManagerWithBufferManager);
+    }
+
     void refreshBlocks() override
     {
         PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, refreshBlocks);
@@ -485,6 +490,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
             py::call_guard<py::gil_scoped_release>())
         .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
             py::call_guard<py::gil_scoped_release>())
+        .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager,
+            py::call_guard<py::gil_scoped_release>())
         .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, py::call_guard<py::gil_scoped_release>())
         .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard<py::gil_scoped_release>())
         .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>())
         .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, py::call_guard<py::gil_scoped_release>());
@@ -434,6 +434,10 @@ class KVCacheManager(BaseResourceManager):
         with request_context(self.is_draft, scheduled_batch):
             context_batch = scheduled_batch.context_requests
             generation_batch = scheduled_batch.generation_requests
 
+            # wait for all pending work to finish before launching offload/onboarding/partial copy
+            self.impl.sync_transfer_manager_with_buffer_manager()
+
             # allocate KV Cache
             for req in context_batch:
                 req_beam_width = req.sampling_config.beam_width
@@ -475,6 +479,9 @@ class KVCacheManager(BaseResourceManager):
                 for _ in range(get_draft_token_length(req)):
                     self.impl.add_token(req.py_request_id)
 
+            # prefill and generation kernels wait for scheduled offload/onboard/partial copy work before launching
+            self.impl.refresh_blocks()
+
             if self.kv_connector_manager is not None:
                 self.kv_connector_manager.build_scheduler_output(
                     scheduled_batch, self)
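Through the bindings above, the Python resource manager ends up driving the same per-step contract described by the C++ doc comments: sync before the first addSequence of the step, refresh (and hence hand the scheduled copies back to the buffer manager stream) after the last one. Roughly, in terms of the native interface (a sketch only, assuming the BaseKVCacheManager declarations from this diff are in scope; the Request struct, its fields, and the addSequence argument list are schematic placeholders, and only the two bracketing calls correspond to what this change adds):

struct Request
{
    std::uint64_t requestId;
    std::int32_t inputLength;
    std::int32_t beamWidth;
};

void prepareKvCacheForStep(BaseKVCacheManager& kvCacheManager, std::vector<Request> const& requests)
{
    // Offload/onboard streams wait for the previous step's prefill/decode kernels.
    kvCacheManager.syncTransferManagerWithBufferManager();

    for (auto const& req : requests)
    {
        // Stand-in for the per-request allocation done in prepare_resources above.
        kvCacheManager.addSequence(req.requestId, req.inputLength, req.beamWidth);
    }

    // The next prefill/decode kernels wait for every offload/onboard/partial copy scheduled above.
    kvCacheManager.refreshBlocks();
}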