[https://nvbugs/5689235][fix] Fix cancellation+chunked prefill+disagg (#10111)

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
This commit is contained in:
Iman Tabrizian 2026-01-12 15:23:26 -08:00 committed by GitHub
parent 18a33764b5
commit 48b09e5a25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 380 additions and 57 deletions

View File

@ -648,7 +648,7 @@ public:
void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx); void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx);
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse( [[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false); GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false);
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest); void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
@ -853,8 +853,8 @@ public:
//! \param blockKeys Key of each block. //! \param blockKeys Key of each block.
//! \param blockIds Id of each block. //! \param blockIds Id of each block.
//! \param pinBlocks If true, increment ref count for blocks while storing (pin on store). //! \param pinBlocks If true, increment ref count for blocks while storing (pin on store).
//! \return Pair of (num blocks stored for reuse, id of the last block stored if any). //! \return Pair of (num blocks stored for reuse, vector of pinned block IDs).
[[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks( [[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
bool pinBlocks = false); bool pinBlocks = false);
@ -886,8 +886,8 @@ public:
[[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey); [[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);
//! \brief Unpin blocks by starting from a block id and walking prev pointers. //! \brief Unpin blocks by block ids directly
void unpinBlocksById(KVCacheBlock::IdType blockId); void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);
void initializeSequenceStorageValidity(LlmRequest::RequestIdType requestId) void initializeSequenceStorageValidity(LlmRequest::RequestIdType requestId)
{ {
@ -1103,7 +1103,7 @@ public:
std::optional<KVCacheBlock::IdType> releaseBlocks( std::optional<KVCacheBlock::IdType> releaseBlocks(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false); GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse( [[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false); GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);
void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId); void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
@ -1112,7 +1112,7 @@ public:
/// @param sequence The generation request whose blocks should be pinned. /// @param sequence The generation request whose blocks should be pinned.
void pinBlocks(GenerationRequest& sequence); void pinBlocks(GenerationRequest& sequence);
void unpinBlocksById(KVCacheBlock::IdType blockId); void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);
void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize); void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);
@ -1133,7 +1133,7 @@ public:
void offloadBlock(BlockPtr const& block, SizeType32 windowSize, void offloadBlock(BlockPtr const& block, SizeType32 windowSize,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = ""); executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");
[[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks( [[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
SizeType32 windowSize, bool pinBlocks = false) SizeType32 windowSize, bool pinBlocks = false)
{ {
@ -1584,7 +1584,7 @@ public:
virtual void storeNewBlock(LlmRequest const& llmRequest) = 0; virtual void storeNewBlock(LlmRequest const& llmRequest) = 0;
/// \brief Store blocks for reuse for a given request id /// \brief Store blocks for reuse for a given request id
[[nodiscard]] virtual std::optional<KVCacheBlock::IdType> storeBlocksForReuse( [[nodiscard]] virtual std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false) LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false)
= 0; = 0;
@ -1678,7 +1678,7 @@ public:
BlockKey const& blockKey, SizeType32 windowSize) BlockKey const& blockKey, SizeType32 windowSize)
= 0; = 0;
virtual void unpinBlocksById(KVCacheBlock::IdType blockId) = 0; virtual void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) = 0;
}; };
class KVCacheManager : public BaseKVCacheManager class KVCacheManager : public BaseKVCacheManager
@ -1939,7 +1939,7 @@ public:
//! \brief Store newest blocks for reuse //! \brief Store newest blocks for reuse
void storeNewBlock(LlmRequest const& llmRequest) override; void storeNewBlock(LlmRequest const& llmRequest) override;
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse( [[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false) override; LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false) override;
[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock); [[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
@ -1960,7 +1960,7 @@ public:
void pinBlocks(LlmRequest::RequestIdType requestId) override; void pinBlocks(LlmRequest::RequestIdType requestId) override;
void unpinBlocksById(KVCacheBlock::IdType blockId) override; void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) override;
std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const override; std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const override;

View File

@ -1667,6 +1667,12 @@ public:
[](auto reason) { return reason == executor::FinishReason::kLENGTH; }); [](auto reason) { return reason == executor::FinishReason::kLENGTH; });
} }
[[nodiscard]] bool isFinishedDueToCancellation() const noexcept
{
return std::all_of(mFinishReasons.begin(), mFinishReasons.end(),
[](auto reason) { return reason == executor::FinishReason::kCANCELLED; });
}
[[nodiscard]] bool isTimedOut() const [[nodiscard]] bool isTimedOut() const
{ {
if (!mAllottedTimeMs.has_value()) if (!mAllottedTimeMs.has_value())

View File

@ -1556,7 +1556,7 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm
} }
} }
std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks( std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, bool pinBlocks) std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, bool pinBlocks)
{ {
SizeType32 numBlocksStoredForReuse = 0; SizeType32 numBlocksStoredForReuse = 0;
@ -1569,7 +1569,7 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
auto numBlocks = blockKeys.size(); auto numBlocks = blockKeys.size();
std::vector<BlockPtr> storedBlocks; std::vector<BlockPtr> storedBlocks;
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt; std::vector<KVCacheBlock::IdType> pinnedBlockIds;
for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt) for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt)
{ {
auto const bid = blockIds[blockCnt]; auto const bid = blockIds[blockCnt];
@ -1620,14 +1620,14 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
if (pinBlocks) if (pinBlocks)
{ {
searchRoot->incRefCount(); searchRoot->incRefCount();
pinnedBlockIds.push_back(searchRoot->getBlockId());
} }
lastStoredId = searchRoot->getBlockId();
} }
if (mEventManager) if (mEventManager)
{ {
mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize); mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize);
} }
return {numBlocksStoredForReuse, lastStoredId}; return {numBlocksStoredForReuse, pinnedBlockIds};
} }
void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx) void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
@ -1715,15 +1715,15 @@ std::deque<tle::KVCacheEvent> BlockManager::getLatestEvents(std::optional<std::c
return mEventManager ? mEventManager->getEvents(timeout) : std::deque<tle::KVCacheEvent>{}; return mEventManager ? mEventManager->getEvents(timeout) : std::deque<tle::KVCacheEvent>{};
} }
std::optional<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse( std::vector<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks) GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{ {
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt; std::vector<KVCacheBlock::IdType> pinnedBlockIds;
for (auto& [_, manager] : mWindowBlockManagers) for (auto& [_, manager] : mWindowBlockManagers)
{ {
lastStoredId = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks); pinnedBlockIds = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
} }
return lastStoredId; return pinnedBlockIds;
} }
std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks( std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks(
@ -1767,7 +1767,7 @@ void BlockManager::pinBlocks(GenerationRequest& sequence)
} }
} }
void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId) void BlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
{ {
// Use the first window size // Use the first window size
if (mWindowBlockManagers.empty()) if (mWindowBlockManagers.empty())
@ -1775,7 +1775,7 @@ void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
return; return;
} }
auto& firstManager = mWindowBlockManagers.begin()->second; auto& firstManager = mWindowBlockManagers.begin()->second;
firstManager.unpinBlocksById(blockId); firstManager.unpinBlocksById(blockIds);
} }
void WindowBlockManager::pinBlocks(GenerationRequest& sequence) void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
@ -1788,21 +1788,26 @@ void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
} }
} }
void WindowBlockManager::unpinBlocksById(KVCacheBlock::IdType blockId) void WindowBlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
{ {
if (blockId < 0 || static_cast<size_t>(blockId) >= mAllBlocksById.size()) if (blockIds.empty())
{ {
return; return;
} }
for (auto const& blockId : blockIds)
{
TLLM_CHECK_WITH_INFO(blockId >= 0 && static_cast<size_t>(blockId) < mAllBlocksById.size(),
"Block id %d is out of range", blockId);
auto block = mAllBlocksById[blockId]; auto block = mAllBlocksById[blockId];
while (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId) if (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
{ {
block->decRefCount(); block->decRefCount();
if (!block->hasRefs()) if (!block->hasRefs())
{ {
mEvictionPolicy->releaseBlock(block); mEvictionPolicy->releaseBlock(block);
} }
block = std::move(block->getPrevBlock()); }
} }
} }
@ -1870,7 +1875,7 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<
(void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
} }
std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse( std::vector<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks) GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{ {
auto constexpr beamIdx = 0; auto constexpr beamIdx = 0;
@ -1883,7 +1888,10 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1; auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true); auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true);
auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
return storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks).second;
auto [numStored, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks);
return pinnedBlockIds;
} }
std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks( std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
@ -1922,7 +1930,7 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(), std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(),
[](BlockPtr const& block) { return block->getBlockId(); }); [](BlockPtr const& block) { return block->getBlockId(); });
auto [numBlocksStoredForReuse, lastStoredId] = storeBlocks(std::move(blockKeys), cacheBlockIds); auto [numBlocksStoredForReuse, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds);
TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(), TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(),
sequence.getRequestId(), numBlocksStoredForReuse); sequence.getRequestId(), numBlocksStoredForReuse);
} }
@ -2499,15 +2507,14 @@ std::optional<KVCacheBlock::IdType> KVCacheManager::removeSequence(
return lastStoredId; return lastStoredId;
} }
std::optional<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse( std::vector<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks) RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{ {
TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
auto& sequence = getSequence(requestId); auto& sequence = getSequence(requestId);
std::optional<KVCacheBlock::IdType> lastStoredId auto pinnedBlockIds = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
= mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__); TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
return lastStoredId; return pinnedBlockIds;
} }
void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId) void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId)
@ -2522,9 +2529,9 @@ void KVCacheManager::pinBlocks(RequestIdType requestId)
mBlockManager.pinBlocks(sequence); mBlockManager.pinBlocks(sequence);
} }
void KVCacheManager::unpinBlocksById(KVCacheBlock::IdType blockId) void KVCacheManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
{ {
mBlockManager.unpinBlocksById(blockId); mBlockManager.unpinBlocksById(blockIds);
} }
SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const

View File

@ -2179,11 +2179,11 @@ void Executor::Impl::terminateContextFinishedRequests(InTransList& inTransmissio
auto req = item.request; auto req = item.request;
if (req->isDisaggContextCompleteState()) if (req->isDisaggContextCompleteState())
{ {
// If lastBlockId was tracked, unpin it. Otherwise, just terminate. // If pinnedBlockIds were tracked, unpin them. Otherwise, just terminate.
auto kvMgr = mModel->getKVCacheManager(); auto kvMgr = mModel->getKVCacheManager();
if (kvMgr && item.lastBlockId.has_value()) if (kvMgr && !item.pinnedBlockIds.empty())
{ {
kvMgr->unpinBlocksById(item.lastBlockId.value()); kvMgr->unpinBlocksById(item.pinnedBlockIds);
} }
else else
{ {
@ -2234,14 +2234,14 @@ Executor::Impl::RequestList Executor::Impl::populateNewResponses(
// move the in transmission requests to another tracker // move the in transmission requests to another tracker
if (llmReq->isDisaggContextTransmissionState()) if (llmReq->isDisaggContextTransmissionState())
{ {
std::optional<SizeType32> lastBlockId{}; std::vector<SizeType32> pinnedBlockIds{};
auto kvMgr = mModel->getKVCacheManager(); auto kvMgr = mModel->getKVCacheManager();
if (kvMgr && kvMgr->isEnableBlockReuse() && !kvMgr->getBlockManager().isVariableWindow()) if (kvMgr && kvMgr->isEnableBlockReuse() && !kvMgr->getBlockManager().isVariableWindow())
{ {
lastBlockId = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true); pinnedBlockIds = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
mModel->terminateRequest(llmReq); mModel->terminateRequest(llmReq);
} }
inTransmissionRequests.push_back(InTransmissionItem{*it, lastBlockId}); inTransmissionRequests.push_back(InTransmissionItem{*it, pinnedBlockIds});
} }
finishedRequests.push_back(*it); finishedRequests.push_back(*it);
it = activeRequests.erase(it); it = activeRequests.erase(it);

View File

@ -80,12 +80,12 @@ class Executor::Impl
using RequestList = std::list<LlmRequestPtr>; using RequestList = std::list<LlmRequestPtr>;
// When block reuse is enabled for context worker for disaggregated serving, // When block reuse is enabled for context worker for disaggregated serving,
// we need to store the last block id so that we can unpin the block when // we need to store the pinned block ids so that we can unpin them when
// the request is finished. // the request is finished.
struct InTransmissionItem struct InTransmissionItem
{ {
LlmRequestPtr request; LlmRequestPtr request;
std::optional<SizeType32> lastBlockId; std::vector<SizeType32> pinnedBlockIds;
}; };
using InTransList = std::list<InTransmissionItem>; using InTransList = std::list<InTransmissionItem>;

View File

@ -161,6 +161,7 @@ void initBindings(nb::module_& m)
.def("set_finished_reason", &GenLlmReq::setFinishedReason, nb::arg("finish_reason"), nb::arg("beam")) .def("set_finished_reason", &GenLlmReq::setFinishedReason, nb::arg("finish_reason"), nb::arg("beam"))
.def_prop_ro("is_finished", &GenLlmReq::isFinished) .def_prop_ro("is_finished", &GenLlmReq::isFinished)
.def_prop_ro("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength) .def_prop_ro("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
.def_prop_ro("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
.def_prop_rw( .def_prop_rw(
"context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition) "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
.def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen) .def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)

View File

@ -123,7 +123,7 @@ public:
NB_OVERRIDE_PURE(removeSequence, requestId, llmRequest, pinOnRelease); NB_OVERRIDE_PURE(removeSequence, requestId, llmRequest, pinOnRelease);
} }
std::optional<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId, std::vector<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest, bool pinBlocks) override tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest, bool pinBlocks) override
{ {
NB_OVERRIDE_PURE(storeBlocksForReuse, requestId, llmRequest, pinBlocks); NB_OVERRIDE_PURE(storeBlocksForReuse, requestId, llmRequest, pinBlocks);

View File

@ -165,6 +165,7 @@ void initBindings(pybind11::module_& m)
.def("set_finished_reason", &GenLlmReq::setFinishedReason, py::arg("finish_reason"), py::arg("beam")) .def("set_finished_reason", &GenLlmReq::setFinishedReason, py::arg("finish_reason"), py::arg("beam"))
.def_property_readonly("is_finished", &GenLlmReq::isFinished) .def_property_readonly("is_finished", &GenLlmReq::isFinished)
.def_property_readonly("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength) .def_property_readonly("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
.def_property_readonly("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
.def_property( .def_property(
"context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition) "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
.def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen) .def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)

View File

@ -111,10 +111,10 @@ public:
requestId, llmRequest, pinOnRelease); requestId, llmRequest, pinOnRelease);
} }
std::optional<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId, std::vector<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest, bool pinBlocks) override tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest, bool pinBlocks) override
{ {
PYBIND11_OVERLOAD_PURE(std::optional<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, storeBlocksForReuse, PYBIND11_OVERLOAD_PURE(std::vector<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, storeBlocksForReuse,
requestId, llmRequest, pinBlocks); requestId, llmRequest, pinBlocks);
} }

View File

@ -4066,11 +4066,13 @@ TEST_F(KVCacheManagerTest, PinAndUnpinBlocksById)
kvCacheManager.pinBlocks(requestId); kvCacheManager.pinBlocks(requestId);
auto lastBlockIdOpt = kvCacheManager.getLastBlockId(requestId); auto lastBlockIdOpt = kvCacheManager.getLastBlockId(requestId);
ASSERT_TRUE(lastBlockIdOpt.has_value()); ASSERT_TRUE(lastBlockIdOpt.has_value());
auto const& allBlockIds = kvCacheManager.getCacheBlockIds(requestId, maxAttentionWindow)[0];
std::vector<SizeType32> pinnedBlockIds(allBlockIds.begin(), allBlockIds.end());
(void) kvCacheManager.removeSequence(requestId, llmRequest); (void) kvCacheManager.removeSequence(requestId, llmRequest);
auto const freeAfterRemovePinned = kvCacheManager.getNumFreeBlocks(); auto const freeAfterRemovePinned = kvCacheManager.getNumFreeBlocks();
EXPECT_LT(freeAfterRemovePinned, totalBlocks); EXPECT_LT(freeAfterRemovePinned, totalBlocks);
kvCacheManager.unpinBlocksById(lastBlockIdOpt.value()); kvCacheManager.unpinBlocksById(pinnedBlockIds);
auto const freeAfterUnpin = kvCacheManager.getNumFreeBlocks(); auto const freeAfterUnpin = kvCacheManager.getNumFreeBlocks();
EXPECT_EQ(freeAfterUnpin, totalBlocks); EXPECT_EQ(freeAfterUnpin, totalBlocks);
} }

View File

@ -1167,7 +1167,8 @@ class PyExecutor:
for req in previous_batch.scheduled_ctx_reqs: for req in previous_batch.scheduled_ctx_reqs:
if req.is_context_only_request and ( if req.is_context_only_request and (
req.is_context_finished req.is_context_finished
or req.is_finished_due_to_length): or req.is_finished_due_to_length
) and not req.is_finished_due_to_cancellation:
block_id = self.kv_cache_manager.store_blocks_for_reuse( block_id = self.kv_cache_manager.store_blocks_for_reuse(
req, True) req, True)
self.ctx_in_transmission_requests[ self.ctx_in_transmission_requests[
@ -1436,7 +1437,8 @@ class PyExecutor:
for req in scheduled_batch.context_requests: for req in scheduled_batch.context_requests:
if req.is_context_only_request and ( if req.is_context_only_request and (
req.is_context_finished req.is_context_finished
or req.is_finished_due_to_length): or req.is_finished_due_to_length
) and not req.is_finished_due_to_cancellation:
block_id = self.kv_cache_manager.store_blocks_for_reuse( block_id = self.kv_cache_manager.store_blocks_for_reuse(
req, True) req, True)
self.ctx_in_transmission_requests[ self.ctx_in_transmission_requests[
@ -1686,7 +1688,8 @@ class PyExecutor:
for req in self.previous_batch.sample_state.scheduled_requests.context_requests: for req in self.previous_batch.sample_state.scheduled_requests.context_requests:
if req.is_context_only_request and ( if req.is_context_only_request and (
req.is_context_finished req.is_context_finished
or req.is_finished_due_to_length): or req.is_finished_due_to_length
) and not req.is_finished_due_to_cancellation:
block_id = self.kv_cache_manager.store_blocks_for_reuse( block_id = self.kv_cache_manager.store_blocks_for_reuse(
req, True) req, True)
self.ctx_in_transmission_requests[ self.ctx_in_transmission_requests[
@ -2196,8 +2199,9 @@ class PyExecutor:
if (scheduled_ctx_requests is None or len(scheduled_ctx_requests) == 0): if (scheduled_ctx_requests is None or len(scheduled_ctx_requests) == 0):
return [] return []
for req in scheduled_ctx_requests: for req in scheduled_ctx_requests:
if req.is_context_only_request and (req.is_context_finished or if req.is_context_only_request and (
req.is_finished_due_to_length): req.is_context_finished or req.is_finished_due_to_length
) and not req.is_finished_due_to_cancellation:
self.kv_cache_transceiver.respond_and_send_async(req) self.kv_cache_transceiver.respond_and_send_async(req)
for resource_mgr_type in ( for resource_mgr_type in (
ResourceManagerType.SEQ_SLOT_MANAGER, ResourceManagerType.SEQ_SLOT_MANAGER,

View File

@ -1431,7 +1431,8 @@ class ResourceManager:
resource_manager.update_resources(scheduled_batch) resource_manager.update_resources(scheduled_batch)
def free_resources(self, request: LlmRequest): def free_resources(self, request: LlmRequest):
for _, resource_manager in reversed(self.resource_managers.items()): for resource_type, resource_manager in reversed(
self.resource_managers.items()):
if hasattr(resource_manager, "free_resources"): if hasattr(resource_manager, "free_resources"):
resource_manager.free_resources(request) resource_manager.free_resources(request)

View File

@ -0,0 +1,44 @@
hostname: localhost
port: 8000
model: DeepSeek-V3-Lite/bf16
backend: "pytorch"
enable_autotuner: False
context_servers:
disable_overlap_scheduler: True
num_instances: 1
tensor_parallel_size: 1
pipeline_parallel_size: 1
max_num_tokens: 16384
max_seq_len: 32768
enable_chunked_prefill: True
kv_cache_config:
enable_block_reuse: True
enable_partial_reuse: True
free_gpu_memory_fraction: 0.3
cache_transceiver_config:
backend: "DEFAULT"
max_tokens_in_buffer: 32768
cuda_graph_config:
enable_padding: True
max_batch_size: 1
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
tensor_parallel_size: 1
pipeline_parallel_size: 1
max_num_tokens: 2048
max_seq_len: 32768
enable_chunked_prefill: True
kv_cache_config:
enable_block_reuse: True
enable_partial_reuse: True
free_gpu_memory_fraction: 0.85
cache_transceiver_config:
backend: "DEFAULT"
max_tokens_in_buffer: 32768
cuda_graph_config:
enable_padding: True
max_batch_size: 64
urls:
- "localhost:8002"

View File

@ -0,0 +1,44 @@
hostname: localhost
port: 8000
model: DeepSeek-V3-0324-FP4
backend: "pytorch"
enable_autotuner: False
context_servers:
disable_overlap_scheduler: True
num_instances: 1
tensor_parallel_size: 4
pipeline_parallel_size: 1
max_num_tokens: 12000
max_seq_len: 262144
enable_chunked_prefill: True
kv_cache_config:
enable_block_reuse: True
enable_partial_reuse: True
free_gpu_memory_fraction: 0.2
cache_transceiver_config:
backend: "DEFAULT"
max_tokens_in_buffer: 262144
cuda_graph_config:
enable_padding: True
max_batch_size: 1
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
tensor_parallel_size: 4
pipeline_parallel_size: 1
max_num_tokens: 2048
max_seq_len: 262144
enable_chunked_prefill: True
kv_cache_config:
enable_block_reuse: True
enable_partial_reuse: True
free_gpu_memory_fraction: 0.3
cache_transceiver_config:
backend: "DEFAULT"
max_tokens_in_buffer: 262144
cuda_graph_config:
enable_padding: True
max_batch_size: 11
urls:
- "localhost:8002"

View File

@ -200,6 +200,10 @@ def get_test_config(test_desc, example_dir, test_root):
"gpt_oss_120b_stress": "gpt_oss_120b_stress":
(4, (4,
f"{test_configs_root}/disagg_config_ctxtp2_gentp2_gptoss_tllm.yaml"), f"{test_configs_root}/disagg_config_ctxtp2_gentp2_gptoss_tllm.yaml"),
"cancel_stress_test":
(2, f"{test_configs_root}/disagg_config_cancel_stress_test.yaml"),
"cancel_stress_test_large":
(8, f"{test_configs_root}/disagg_config_cancel_stress_test_large.yaml"),
} }
if test_desc not in config_map: if test_desc not in config_map:
@ -2098,3 +2102,211 @@ def test_disaggregated_stress_test(disaggregated_test_root,
threshold=test_config.accuracy_threshold, threshold=test_config.accuracy_threshold,
env=llm_venv._new_env, env=llm_venv._new_env,
cwd=llm_venv.get_working_directory()) cwd=llm_venv.get_working_directory())
def run_cancel_stress_test(server_url: str,
num_bursts: int = 5,
requests_per_burst: int = 32,
prompt_len_range: tuple = (2000, 8000),
cancel_after_range: tuple = (0.01, 0.1)):
"""
Stress test that sends requests with large contexts and cancels them
during prefill to test resource cleanup under cancellation.
Args:
server_url: The server URL (e.g., "http://localhost:8000")
num_bursts: Number of request bursts to send
requests_per_burst: Number of concurrent requests per burst
prompt_len_range: (min, max) prompt length in tokens
cancel_after_range: (min, max) seconds to wait before cancelling
"""
import asyncio
import random
import time
import aiohttp
async def spam_and_cancel(session, req_id, url, prompt_len_range,
cancel_after_range):
"""Send a request and cancel it during prefill."""
prompt_len = random.randint(prompt_len_range[0], prompt_len_range[1])
prompt = "test " * (prompt_len // 5)
payload = {
"model": "test-model",
"prompt": prompt,
"max_tokens": 10,
"stream": True
}
try:
cancel_after = random.uniform(cancel_after_range[0],
cancel_after_range[1])
start = time.time()
async with session.post(
f"{url}/v1/completions",
json=payload,
timeout=aiohttp.ClientTimeout(total=60)) as resp:
async for line in resp.content:
if time.time() - start > cancel_after:
# Force disconnect during prefill
break
except Exception:
pass # Connection abort is expected
async def run_bursts():
async with aiohttp.ClientSession() as session:
for burst_idx in range(num_bursts):
tasks = [
spam_and_cancel(session, i, server_url, prompt_len_range,
cancel_after_range)
for i in range(requests_per_burst)
]
await asyncio.gather(*tasks)
logger.info(
f"Completed burst {burst_idx + 1}/{num_bursts} ({requests_per_burst} requests)"
)
await asyncio.sleep(0.05)
asyncio.run(run_bursts())
def run_disaggregated_cancel_test(example_dir,
test_desc,
env=None,
cwd=None,
num_bursts=64,
requests_per_burst=64):
"""Run disaggregated test with request cancellation stress test."""
cleanup_output_files()
run_env = env.copy()
run_env["UCX_TLS"] = "^ib"
num_ranks, config_file = get_test_config(test_desc, example_dir,
os.path.dirname(__file__))
workers_cmd = [
'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
config_file
]
server_start_timeout = 1200
server_cmd = [
'trtllm-serve', 'disaggregated', '--server_start_timeout',
str(server_start_timeout), '-c', config_file
]
server_host, server_port = get_disagg_server_url_from_cfg(config_file)
server_url = f"http://{server_host}:{server_port}"
try:
with (open('output_workers.log', 'w') as output_workers,
popen(workers_cmd,
stdout=output_workers,
stderr=subprocess.STDOUT,
env=run_env,
cwd=cwd) as workers_proc, open('output_disagg.log', 'w') as
output_disagg,
popen(server_cmd,
stdout=output_disagg,
stderr=subprocess.STDOUT,
env=run_env,
cwd=cwd) as server_proc):
# Wait for server to be ready
if not wait_for_server(server_host,
server_port,
timeout_seconds=server_start_timeout):
raise RuntimeError(
f"Disaggregated server did not become ready within {server_start_timeout} seconds"
)
# Run the cancel stress test
run_cancel_stress_test(server_url,
num_bursts=num_bursts,
requests_per_burst=requests_per_burst)
# Verify server is still healthy after stress test by sending a normal request
client_dir = f"{example_dir}/clients"
client_cmd = [
'python3', f'{client_dir}/disagg_client.py', '-c', config_file,
'-p', f'{client_dir}/prompts.json', '--ignore-eos',
'--server-start-timeout',
str(server_start_timeout)
]
check_call(client_cmd,
env=env,
poll_procs=[workers_proc, server_proc])
except Exception:
logger.error("-------- Workers output --------")
with open('output_workers.log', 'r') as f:
logger.error(f.read())
logger.error("-------- Disagg server output --------")
with open('output_disagg.log', 'r') as f:
logger.error(f.read())
raise
finally:
if 'server_proc' in locals() and 'workers_proc' in locals():
server_proc.terminate()
workers_proc.terminate()
server_proc.wait()
workers_proc.wait()
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-bf16'],
indirect=True)
def test_disaggregated_cancel_large_context_requests(disaggregated_test_root,
disaggregated_example_root,
llm_venv,
deepseek_v3_model_root):
"""
Test that the disaggregated server handles request cancellations gracefully.
This test sends bursts of requests with large contexts and cancels them
during prefill to stress test resource cleanup.
"""
src_dst_dict = {
deepseek_v3_model_root:
f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)
run_disaggregated_cancel_test(disaggregated_example_root,
"cancel_stress_test",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory(),
num_bursts=5,
requests_per_burst=32)
@pytest.mark.skip_less_device(8)
@skip_pre_blackwell
@pytest.mark.parametrize("model_path", ['DeepSeek-V3-0324-FP4'])
def test_disaggregated_cancel_large_context_requests_long(
disaggregated_test_root, disaggregated_example_root, llm_venv,
model_path):
"""Test that disaggregated server handles request cancellations gracefully.
This test sends bursts of requests with large contexts and cancels them
during prefill to stress test resource cleanup.
"""
model_dir = f"{llm_models_root()}/{model_path}"
src_dst_dict = {
model_dir: f"{llm_venv.get_working_directory()}/{model_path}",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)
run_disaggregated_cancel_test(disaggregated_example_root,
"cancel_stress_test_large",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory(),
num_bursts=1000,
requests_per_burst=32)

View File

@ -43,6 +43,7 @@ l0_dgx_h100:
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
- unittest/llmapi/apps/test_disagg_serving_perf_metrics.py - unittest/llmapi/apps/test_disagg_serving_perf_metrics.py
- disaggregated/test_disaggregated.py::test_disaggregated_cancel_large_context_requests[DeepSeek-V3-Lite-bf16]
# ------------- AutoDeploy tests --------------- # ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2] - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2]
# llmapi # llmapi