Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[https://nvbugs/5627710][fix] Fix synchronization bugs in KvCacheTransferManager that can cause corrupted blocks (#9056)

Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>
Co-authored-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Co-authored-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

Parent: b86256eb54
Commit: 95049eea86
@@ -824,6 +824,9 @@ public:
         return mBufferManager;
     }
 
+    //! \brief Sync internal streams used by transfer manager with buffer manager stream
+    void syncTransferManagerWithBufferManager();
+
     //! \brief Perform per-request bookkeeping
     void refreshBlocks();
 
@@ -1313,6 +1316,9 @@ public:
     //! \brief Store newest block for reuse
     void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
 
+    //! \brief Sync internal streams used by transfer manager with buffer manager stream
+    void syncTransferManagerWithBufferManager();
+
     //! \brief Perform per-request bookkeeping
     void refreshBlocks();
 
@@ -1584,6 +1590,7 @@ public:
     [[nodiscard]] virtual runtime::ITensor::SharedPtr getIndexerKCachePool() const = 0;
     [[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0;
 
+    virtual void syncTransferManagerWithBufferManager() = 0;
     virtual void refreshBlocks() = 0;
     virtual void flushIterationEvents() = 0;
     virtual void resetReuseState() = 0;
@@ -1965,6 +1972,11 @@ public:
         return mBlockManager.getPoolLayerIdx(layer_idx);
     }
 
+    void syncTransferManagerWithBufferManager() override
+    {
+        mBlockManager.syncTransferManagerWithBufferManager();
+    }
+
     //! \brief Perform per-iteration bookkeeping
     void refreshBlocks() override
     {
@@ -46,7 +46,15 @@ public:
         int numTokensToCopy = 0, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
         std::string const& directory = "");
 
-    //! \brief Synchronize the offload/onboard streams with the bufferManager stream.
+    //! \brief Synchronize internal streams with bufferManager stream.
+    //! \details The buffer manager uses the same stream as the prefill and decode kernels. This method ensures that the
+    //! internal kernels used for offloading and onboarding will wait for prefill and decode kernels before performing
+    //! any block copies. This method must be called before the first call to KVCacheManager::addSequence in every step.
+    void syncWithBufferManager();
+
+    //! \brief Synchronize bufferManager stream with internal streams. This method ensures that prefill and decode
+    //! kernels for the next step will wait for offloading and onboarding work that has already been scheduled. This
+    //! method must be called after the last call to KVCacheManager::addSequence in every step.
     void syncTransfers();
 
 private:
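The contract documented in these comments (internal streams wait on the compute stream before the first addSequence; the compute stream waits on the internal streams after the last one) is the standard CUDA event handshake between streams. A minimal sketch using the raw CUDA runtime API, with illustrative stream names rather than the actual members:

    #include <cuda_runtime.h>

    // One-way handshake: `waiter` will not run past this point until all work
    // currently enqueued on `producer` has completed.
    void makeStreamWait(cudaStream_t waiter, cudaStream_t producer)
    {
        cudaEvent_t ev;
        cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
        cudaEventRecord(ev, producer);      // mark the current tail of producer
        cudaStreamWaitEvent(waiter, ev, 0); // waiter stalls until that point
        cudaEventDestroy(ev);               // safe: the wait is already enqueued
    }

    // Hypothetical per-step driver mirroring the contract above:
    //   makeStreamWait(offloadStream, computeStream);  // syncWithBufferManager()
    //   makeStreamWait(onboardStream, computeStream);
    //   ... addSequence calls schedule offload/onboard/partial copies ...
    //   makeStreamWait(computeStream, offloadStream);  // syncTransfers()
    //   makeStreamWait(computeStream, onboardStream);

Neither direction blocks the host; the ordering is enforced entirely on the device.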
@@ -75,8 +83,10 @@ private:
     runtime::BufferManager mOnboardManager;
     runtime::BufferManager mOffloadManager;
 
-    // Track the block ids offloaded in this iteration.
-    std::unordered_map<int32_t, tr::CudaEvent> mPendingOffloads;
+    // Track reads and writes for blocks. Note that it is the memory pool index that
+    // identifies the raw memory blocks involved in I/O, not the block Id.
+    std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingReads;
+    std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingWrites;
     // Reference to parent loopback agent
     std::shared_ptr<kvc::BaseLoopbackAgent> mLoopbackAgent;
     int mDeviceId;
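Splitting the old mPendingOffloads map into mPendingReads and mPendingWrites lets a newly scheduled copy resolve the three classic data hazards separately: read-after-write on its source, and write-after-read plus write-after-write on its destination. A hedged sketch of that rule, with PoolIndex and Event as simplified stand-ins for kernels::KVCacheIndex::UnderlyingType and tr::CudaEvent:

    #include <cstdint>
    #include <unordered_map>

    using PoolIndex = std::int32_t;          // simplified stand-in
    struct Event { /* wraps a cudaEvent_t */ };
    struct Stream { void wait(Event const&) { /* cudaStreamWaitEvent */ } };

    std::unordered_map<PoolIndex, Event> pendingReads;
    std::unordered_map<PoolIndex, Event> pendingWrites;

    // Before enqueueing a copy src -> dst on stream s:
    void waitForHazards(Stream& s, PoolIndex src, PoolIndex dst)
    {
        if (auto it = pendingWrites.find(src); it != pendingWrites.end())
        {
            // RAW: the source must be fully written. Keep the entry; reading
            // does not change the state of src.
            s.wait(it->second);
        }
        if (auto it = pendingReads.find(dst); it != pendingReads.end())
        {
            s.wait(it->second); // WAR: earlier readers of dst must finish
            pendingReads.erase(it);
        }
        if (auto it = pendingWrites.find(dst); it != pendingWrites.end())
        {
            s.wait(it->second); // WAW: the earlier writer of dst must finish
            pendingWrites.erase(it);
        }
    }

After the copy is enqueued, a read event for src and a write event for dst are recorded, which is exactly what the implementation further below does per raw pool block.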
@@ -26,6 +26,8 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
     NVTX3_SCOPED_RANGE(allocateKvCache);
 
+    kvCacheManager.syncTransferManagerWithBufferManager();
+
     for (auto const& llmReq : contextRequests)
     {
         if (llmReq->isFirstContextChunk())
@@ -1336,6 +1336,19 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
     return numMatchedTokens;
 }
 
+void BlockManager::syncTransferManagerWithBufferManager()
+{
+    for (auto& [_, manager] : mWindowBlockManagers)
+    {
+        manager.syncTransferManagerWithBufferManager();
+    }
+}
+
+void WindowBlockManager::syncTransferManagerWithBufferManager()
+{
+    mTransferManager->syncWithBufferManager();
+}
+
 void BlockManager::refreshBlocks()
 {
     for (auto& [_, manager] : mWindowBlockManagers)
@@ -207,47 +207,140 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
     }
 }
 
-void KVCacheTransferManager::onboard(BlockPtr const& offloadBlock, BlockPtr const& block,
+//
+// Note about recording events to wait for cudaMemcpyAsync calls between blocks:
+// The memory copy involves raw memory blocks, which are pointed to by the
+// memory pool block index. When recording events, you must use getMemoryPoolBlockIndex()
+// as the raw memory block identifier. Using getBlockId() when recording events is wrong.
+// getBlockId() returns the logical block id, which has nothing to do with the raw memory
+// block pointers involved in a cudaMemcpy.
+//
+
+//
+// Notes about the need for synchronization:
+//
+// Relying on the decoder syncing GPU with CPU to ensure that blocks are ready
+// for offload/onboard/partial copy is dangerous. We have an asynchronous decoder
+// that may not synchronize, or may synchronize at a later point in the execution stream.
+// To avoid synchronization issues caused by changes to the decoder design, we rely on
+// KVCacheTransferManager::syncWithBufferManager(), which ensures that internal copy streams
+// will wait for prefill and decode kernels that have already been scheduled.
+//
+// Earlier versions of this code did not account for all possible cases where a new block copy
+// needed to wait for a previously scheduled copy to finish. For instance, it is possible
+// that two primary blocks are offloaded to the same secondary block in a single step;
+// scheduling the second offload without waiting for the first one to finish leads to
+// a corrupted block after offloading. It is also possible that partial reuse will copy
+// from a block that is currently being onboarded; scheduling the partial copy without
+// waiting for the onboarding to finish will lead to a corrupted block. To handle all
+// possible cases needing synchronization, we record separate events for reads and writes
+// to a block. When a new block copy is scheduled, we wait for all writes to the source
+// block and all reads and writes to the destination block.
+//
+// As before, syncTransfers() must be called after the last call to KVCacheManager::addSequence.
+// Failing to do so will eventually lead to corrupted blocks.
+//
+
+void KVCacheTransferManager::onboard(BlockPtr const& offloadedBlock, BlockPtr const& block,
     std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
     std::string const& directory)
 {
-    if (mode != executor::KvCacheTransferMode::DRAM
-        && mPendingOffloads.find(offloadBlock->getBlockId()) == mPendingOffloads.end())
+    // Wait for any pending writes before reading from offloadedBlock
+    auto offloadedBlockPendingWriteItr = mPendingWrites.find(offloadedBlock->getMemoryPoolBlockIndex());
+    if (offloadedBlockPendingWriteItr != mPendingWrites.end())
     {
-        TLLM_LOG_DEBUG("Skipping onboard for block %d because it was never previously offloaded to disk",
-            offloadBlock->getBlockId());
-        return;
+        mOnboardManager.getStream().wait(offloadedBlockPendingWriteItr->second);
+        // Don't erase, we are not changing the state of offloadedBlock
     }
+    // Wait for any pending reads before overwriting block
+    auto blockPendingReadItr = mPendingReads.find(block->getMemoryPoolBlockIndex());
+    if (blockPendingReadItr != mPendingReads.end())
+    {
+        mOnboardManager.getStream().wait(blockPendingReadItr->second);
+        mPendingReads.erase(blockPendingReadItr);
+    }
+    // Wait for any pending writes before overwriting block
+    auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
+    if (blockPendingWriteItr != mPendingWrites.end())
+    {
+        mOnboardManager.getStream().wait(blockPendingWriteItr->second);
+        mPendingWrites.erase(blockPendingWriteItr);
+    }
 
-    if (mPendingOffloads.find(offloadBlock->getBlockId()) != mPendingOffloads.end())
-    {
-        mOnboardManager.getStream().wait(mPendingOffloads[offloadBlock->getBlockId()]);
-    }
-    copyBlock(offloadBlock, block, pools, false, numTokensToCopy, mode, directory);
+    copyBlock(offloadedBlock, block, pools, false, numTokensToCopy, mode, directory);
+
+    // Record new pending read from offloadedBlock
+    mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
+    mOnboardManager.getStream().record(mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()]);
+    // Record new pending write to block
+    mPendingWrites[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
+    mOnboardManager.getStream().record(mPendingWrites[block->getMemoryPoolBlockIndex()]);
 }
 
 void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock,
     std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
     std::string const& directory)
 {
-    mPendingOffloads[block->getBlockId()] = tr::CudaEvent();
+    // Wait for any pending writes before reading from block
+    auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
+    if (blockPendingWriteItr != mPendingWrites.end())
+    {
+        mOffloadManager.getStream().wait(blockPendingWriteItr->second);
+        // Don't erase, we are not changing the state of block
+    }
+    // Wait for any pending reads before overwriting offloadBlock
+    auto offloadBlockPendingReadItr = mPendingReads.find(offloadBlock->getMemoryPoolBlockIndex());
+    if (offloadBlockPendingReadItr != mPendingReads.end())
+    {
+        mOffloadManager.getStream().wait(offloadBlockPendingReadItr->second);
+        mPendingReads.erase(offloadBlockPendingReadItr);
+    }
+    // Wait for any pending writes before overwriting offloadBlock
+    auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlock->getMemoryPoolBlockIndex());
+    if (offloadBlockPendingWriteItr != mPendingWrites.end())
+    {
+        mOffloadManager.getStream().wait(offloadBlockPendingWriteItr->second);
+        mPendingWrites.erase(offloadBlockPendingWriteItr);
+    }
+
     copyBlock(block, offloadBlock, pools, true, numTokensToCopy, mode, directory);
-    mOffloadManager.getStream().record(mPendingOffloads[block->getBlockId()]);
+
+    // Record new pending read from block
+    mPendingReads[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
+    mOffloadManager.getStream().record(mPendingReads[block->getMemoryPoolBlockIndex()]);
+    // Record new pending write to offloadBlock
+    mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
+    mOffloadManager.getStream().record(mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()]);
+}
+
+void KVCacheTransferManager::syncWithBufferManager()
+{
+    tr::CudaEvent readyForOffloadEvent;
+    mBufferManager.getStream().record(readyForOffloadEvent);
+    mOffloadManager.getStream().wait(readyForOffloadEvent);
+
+    tr::CudaEvent readyForOnboardEvent;
+    mBufferManager.getStream().record(readyForOnboardEvent);
+    mOnboardManager.getStream().wait(readyForOnboardEvent);
+
+    // Once we synchronize, clear our list of pending transfers.
+    mPendingReads.clear();
+    mPendingWrites.clear();
 }
 
 void KVCacheTransferManager::syncTransfers()
 {
     tr::CudaEvent offloadEvent;
     mOffloadManager.getStream().record(offloadEvent);
+    mBufferManager.getStream().wait(offloadEvent);
 
     tr::CudaEvent onboardEvent;
     mOnboardManager.getStream().record(onboardEvent);
 
-    mBufferManager.getStream().wait(offloadEvent);
     mBufferManager.getStream().wait(onboardEvent);
 
-    // Once we synchronize, clear our list of pending thransfers.
-    mPendingOffloads.clear();
+    // Once we synchronize, clear our list of pending transfers.
+    mPendingReads.clear();
+    mPendingWrites.clear();
 }
 
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager
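The failure mode fixed here is easy to reproduce with raw streams: two asynchronous copies that touch the same raw block from different streams race unless an event orders them. An illustrative sketch (buffer and function names hypothetical, not code from the commit):

    #include <cuda_runtime.h>

    // Two copies into the same destination from different streams. Without the
    // recorded event, the copies may interleave and leave dst with a mix of
    // srcA and srcB bytes -- the "corrupted block" described above.
    void orderedDoubleWrite(void* dst, void const* srcA, void const* srcB, size_t bytes,
        cudaStream_t s1, cudaStream_t s2)
    {
        cudaMemcpyAsync(dst, srcA, bytes, cudaMemcpyDeviceToDevice, s1);

        cudaEvent_t firstWrite;
        cudaEventCreateWithFlags(&firstWrite, cudaEventDisableTiming);
        cudaEventRecord(firstWrite, s1);
        cudaStreamWaitEvent(s2, firstWrite, 0); // WAW: order the second write after the first

        cudaMemcpyAsync(dst, srcB, bytes, cudaMemcpyDeviceToDevice, s2);
        cudaEventDestroy(firstWrite);
    }

The per-block mPendingReads/mPendingWrites maps generalize this: every copy records where the last read and write of each raw pool block happened, so any later copy can insert exactly the waits shown here.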
@@ -235,6 +235,11 @@ public:
         NB_OVERRIDE_PURE(getPoolLayerIdx, layer_idx);
     }
 
+    void syncTransferManagerWithBufferManager() override
+    {
+        NB_OVERRIDE_PURE(syncTransferManagerWithBufferManager);
+    }
+
     void refreshBlocks() override
     {
         NB_OVERRIDE_PURE(refreshBlocks);
@@ -481,6 +486,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
             nb::call_guard<nb::gil_scoped_release>())
         .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
             nb::call_guard<nb::gil_scoped_release>())
+        .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager,
+            nb::call_guard<nb::gil_scoped_release>())
+        .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, nb::call_guard<nb::gil_scoped_release>())
         .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard<nb::gil_scoped_release>())
         .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>())
         .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>());
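Both binding layers wrap the new method in a call guard that releases the GIL, since the call schedules device-side waits and should not stall other Python threads while it runs. A minimal self-contained nanobind sketch of that pattern (module and function names hypothetical):

    #include <nanobind/nanobind.h>

    namespace nb = nanobind;

    // Stands in for a call that does device work without touching Python state.
    void blocking_sync() { /* e.g. record/wait on CUDA events */ }

    NB_MODULE(demo_ext, m)
    {
        // gil_scoped_release drops the GIL for the duration of the C++ call
        // and reacquires it before returning to Python.
        m.def("blocking_sync", &blocking_sync, nb::call_guard<nb::gil_scoped_release>());
    }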
@@ -240,6 +240,11 @@ public:
         PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getPoolLayerIdx, layer_idx);
     }
 
+    void syncTransferManagerWithBufferManager() override
+    {
+        PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, syncTransferManagerWithBufferManager);
+    }
+
     void refreshBlocks() override
     {
         PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, refreshBlocks);
@@ -485,6 +490,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
             py::call_guard<py::gil_scoped_release>())
         .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
             py::call_guard<py::gil_scoped_release>())
+        .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager,
+            py::call_guard<py::gil_scoped_release>())
+        .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, py::call_guard<py::gil_scoped_release>())
         .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard<py::gil_scoped_release>())
         .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>())
         .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, py::call_guard<py::gil_scoped_release>());
@@ -434,6 +434,10 @@ class KVCacheManager(BaseResourceManager):
         with request_context(self.is_draft, scheduled_batch):
             context_batch = scheduled_batch.context_requests
             generation_batch = scheduled_batch.generation_requests
+
+            # wait for all pending work to finish before launching offload/onboarding/partial copy
+            self.impl.sync_transfer_manager_with_buffer_manager()
+
             # allocate KV Cache
             for req in context_batch:
                 req_beam_width = req.sampling_config.beam_width
@@ -475,6 +479,9 @@ class KVCacheManager(BaseResourceManager):
                 for _ in range(get_draft_token_length(req)):
                     self.impl.add_token(req.py_request_id)
 
+            # prefill and generation kernels wait for scheduled offload/onboard/partial copy work before launching
+            self.impl.refresh_blocks()
+
             if self.kv_connector_manager is not None:
                 self.kv_connector_manager.build_scheduler_output(
                     scheduled_batch, self)