Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 14:07:21 +08:00)
[https://nvbugs/5689235][fix] Fix cancellation+chunked prefill+disagg (#10111)
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Parent: 18a33764b5
Commit: 48b09e5a25
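In short, the diff below switches block pinning from "remember the last stored block id" to "remember every pinned block id": storeBlocks and storeBlocksForReuse now return a std::vector<KVCacheBlock::IdType> of the blocks they pinned, unpinBlocksById consumes that vector, InTransmissionItem carries the list between the store and unpin phases, and the PyExecutor path skips storing blocks for reuse when a context-only request finished due to cancellation (via the new is_finished_due_to_cancellation binding). The standalone sketch below is not part of the commit and does not use the real TensorRT-LLM types; Block, storeAndPin and unpinAll are invented names that only mirror the pin-on-store / unpin-by-ids contract.

// Toy illustration of the pin-on-store / unpin-by-ids lifecycle (invented types,
// not the TensorRT-LLM API): storing with pin=true bumps a ref count and reports
// every pinned id; the caller later unpins exactly that list.
#include <cstdio>
#include <unordered_map>
#include <vector>

struct Block
{
    int refCount = 0;
};

using IdType = int;

// Mimics storeBlocks(..., pinBlocks=true): returns the ids it pinned.
std::vector<IdType> storeAndPin(std::unordered_map<IdType, Block>& pool, std::vector<IdType> const& ids)
{
    std::vector<IdType> pinned;
    for (auto id : ids)
    {
        ++pool[id].refCount; // pin on store
        pinned.push_back(id);
    }
    return pinned;
}

// Mimics unpinBlocksById(std::vector<IdType> const&): drop one reference per id.
void unpinAll(std::unordered_map<IdType, Block>& pool, std::vector<IdType> const& pinned)
{
    for (auto id : pinned)
    {
        if (--pool[id].refCount == 0)
        {
            std::printf("block %d is free for eviction/reuse\n", id);
        }
    }
}

int main()
{
    std::unordered_map<IdType, Block> pool;
    auto pinned = storeAndPin(pool, {3, 7, 9}); // context phase done: blocks stored and pinned
    unpinAll(pool, pinned);                     // KV transfer done: release the whole list
    return 0;
}

In the real code the same round trip happens between Executor::Impl::populateNewResponses (store and pin) and terminateContextFinishedRequests (unpin), both shown in the hunks further down.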
@@ -648,7 +648,7 @@ public:

     void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx);

-    [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+    [[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
         GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false);

     void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
@@ -853,8 +853,8 @@ public:
     //! \param blockKeys Key of each block.
     //! \param blockIds Id of each block.
     //! \param pinBlocks If true, increment ref count for blocks while storing (pin on store).
-    //! \return Pair of (num blocks stored for reuse, id of the last block stored if any).
-    [[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
+    //! \return Pair of (num blocks stored for reuse, vector of pinned block IDs).
+    [[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
         std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
         bool pinBlocks = false);

@@ -886,8 +886,8 @@ public:

     [[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);

-    //! \brief Unpin blocks by starting from a block id and walking prev pointers.
-    void unpinBlocksById(KVCacheBlock::IdType blockId);
+    //! \brief Unpin blocks by block ids directly
+    void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);

     void initializeSequenceStorageValidity(LlmRequest::RequestIdType requestId)
     {
@@ -1103,7 +1103,7 @@ public:
     std::optional<KVCacheBlock::IdType> releaseBlocks(
         GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);

-    [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+    [[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
         GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);

     void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
@@ -1112,7 +1112,7 @@ public:
     /// @param sequence The generation request whose blocks should be pinned.
     void pinBlocks(GenerationRequest& sequence);

-    void unpinBlocksById(KVCacheBlock::IdType blockId);
+    void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);

     void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);

@@ -1133,7 +1133,7 @@ public:
     void offloadBlock(BlockPtr const& block, SizeType32 windowSize,
         executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");

-    [[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
+    [[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
         std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
         SizeType32 windowSize, bool pinBlocks = false)
     {
@@ -1584,7 +1584,7 @@ public:
     virtual void storeNewBlock(LlmRequest const& llmRequest) = 0;

     /// \brief Store blocks for reuse for a given request id
-    [[nodiscard]] virtual std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+    [[nodiscard]] virtual std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
         LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false)
         = 0;

@@ -1678,7 +1678,7 @@ public:
         BlockKey const& blockKey, SizeType32 windowSize)
         = 0;

-    virtual void unpinBlocksById(KVCacheBlock::IdType blockId) = 0;
+    virtual void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) = 0;
 };

 class KVCacheManager : public BaseKVCacheManager
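Because storeBlocksForReuse and unpinBlocksById are pure virtuals on BaseKVCacheManager, every implementer and every binding trampoline has to adopt the new shapes (the nanobind and pybind11 trampolines further down do exactly this). A minimal stand-in sketch, with IdType, RequestId, IKVCacheManagerLike and StubKVCacheManager all invented for illustration rather than taken from the library:

// Minimal stub mirroring the changed pure-virtual shapes (stand-in types only).
#include <vector>

using IdType = int;
using RequestId = unsigned long long;

class IKVCacheManagerLike
{
public:
    virtual ~IKVCacheManagerLike() = default;
    // Was: virtual std::optional<IdType> storeBlocksForReuse(...)
    virtual std::vector<IdType> storeBlocksForReuse(RequestId requestId, bool pinBlocks) = 0;
    // Was: virtual void unpinBlocksById(IdType blockId)
    virtual void unpinBlocksById(std::vector<IdType> const& blockIds) = 0;
};

class StubKVCacheManager : public IKVCacheManagerLike
{
public:
    std::vector<IdType> storeBlocksForReuse(RequestId /*requestId*/, bool pinBlocks) override
    {
        return pinBlocks ? std::vector<IdType>{1, 2, 3} : std::vector<IdType>{};
    }

    void unpinBlocksById(std::vector<IdType> const& blockIds) override
    {
        lastUnpinned = blockIds;
    }

    std::vector<IdType> lastUnpinned;
};

int main()
{
    StubKVCacheManager mgr;
    auto pinned = mgr.storeBlocksForReuse(/*requestId=*/42, /*pinBlocks=*/true);
    mgr.unpinBlocksById(pinned);
    return mgr.lastUnpinned.size() == 3 ? 0 : 1;
}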
@@ -1939,7 +1939,7 @@ public:
     //! \brief Store newest blocks for reuse
     void storeNewBlock(LlmRequest const& llmRequest) override;

-    [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+    [[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
         LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false) override;

     [[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
@@ -1960,7 +1960,7 @@ public:

     void pinBlocks(LlmRequest::RequestIdType requestId) override;

-    void unpinBlocksById(KVCacheBlock::IdType blockId) override;
+    void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) override;

     std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const override;

@@ -1667,6 +1667,12 @@ public:
             [](auto reason) { return reason == executor::FinishReason::kLENGTH; });
     }

+    [[nodiscard]] bool isFinishedDueToCancellation() const noexcept
+    {
+        return std::all_of(mFinishReasons.begin(), mFinishReasons.end(),
+            [](auto reason) { return reason == executor::FinishReason::kCANCELLED; });
+    }
+
     [[nodiscard]] bool isTimedOut() const
     {
         if (!mAllottedTimeMs.has_value())
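The new isFinishedDueToCancellation predicate mirrors the existing isFinishedDueToLength: both reduce the per-beam finish reasons with std::all_of, so a request only counts as finished-due-to-cancellation once every beam reports kCANCELLED. A self-contained toy version of that reduction (the enum and data here are invented, not the LlmRequest class):

// Toy version of the per-beam finish-reason predicates (invented enum/data).
#include <algorithm>
#include <cassert>
#include <vector>

enum class FinishReason
{
    kNOT_FINISHED,
    kLENGTH,
    kCANCELLED
};

bool isFinishedDueToLength(std::vector<FinishReason> const& reasons)
{
    return std::all_of(reasons.begin(), reasons.end(), [](auto r) { return r == FinishReason::kLENGTH; });
}

bool isFinishedDueToCancellation(std::vector<FinishReason> const& reasons)
{
    return std::all_of(reasons.begin(), reasons.end(), [](auto r) { return r == FinishReason::kCANCELLED; });
}

int main()
{
    // One beam cancelled, one still running: not yet "finished due to cancellation".
    assert(!isFinishedDueToCancellation({FinishReason::kCANCELLED, FinishReason::kNOT_FINISHED}));
    assert(isFinishedDueToCancellation({FinishReason::kCANCELLED, FinishReason::kCANCELLED}));
    assert(!isFinishedDueToLength({FinishReason::kCANCELLED, FinishReason::kCANCELLED}));
    return 0;
}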
@@ -1556,7 +1556,7 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm
     }
 }

-std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
+std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
     std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, bool pinBlocks)
 {
     SizeType32 numBlocksStoredForReuse = 0;
@@ -1569,7 +1569,7 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s

     auto numBlocks = blockKeys.size();
     std::vector<BlockPtr> storedBlocks;
-    std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
+    std::vector<KVCacheBlock::IdType> pinnedBlockIds;
     for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt)
     {
         auto const bid = blockIds[blockCnt];
@@ -1620,14 +1620,14 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
         if (pinBlocks)
         {
             searchRoot->incRefCount();
+            pinnedBlockIds.push_back(searchRoot->getBlockId());
         }
-        lastStoredId = searchRoot->getBlockId();
     }
     if (mEventManager)
     {
         mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize);
     }
-    return {numBlocksStoredForReuse, lastStoredId};
+    return {numBlocksStoredForReuse, pinnedBlockIds};
 }

 void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
@@ -1715,15 +1715,15 @@ std::deque<tle::KVCacheEvent> BlockManager::getLatestEvents(std::optional<std::c
     return mEventManager ? mEventManager->getEvents(timeout) : std::deque<tle::KVCacheEvent>{};
 }

-std::optional<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
+std::vector<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
     GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
 {
-    std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
+    std::vector<KVCacheBlock::IdType> pinnedBlockIds;
     for (auto& [_, manager] : mWindowBlockManagers)
     {
-        lastStoredId = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
+        pinnedBlockIds = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
     }
-    return lastStoredId;
+    return pinnedBlockIds;
 }

 std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks(
@@ -1767,7 +1767,7 @@ void BlockManager::pinBlocks(GenerationRequest& sequence)
     }
 }

-void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
+void BlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
 {
     // Use the first window size
     if (mWindowBlockManagers.empty())
@@ -1775,7 +1775,7 @@ void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
         return;
     }
     auto& firstManager = mWindowBlockManagers.begin()->second;
-    firstManager.unpinBlocksById(blockId);
+    firstManager.unpinBlocksById(blockIds);
 }

 void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
@@ -1788,21 +1788,26 @@ void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
     }
 }

-void WindowBlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
+void WindowBlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
 {
-    if (blockId < 0 || static_cast<size_t>(blockId) >= mAllBlocksById.size())
+    if (blockIds.empty())
     {
         return;
     }
-    auto block = mAllBlocksById[blockId];
-    while (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
+
+    for (auto const& blockId : blockIds)
     {
-        block->decRefCount();
-        if (!block->hasRefs())
+        TLLM_CHECK_WITH_INFO(blockId >= 0 && static_cast<size_t>(blockId) < mAllBlocksById.size(),
+            "Block id %d is out of range", blockId);
+        auto block = mAllBlocksById[blockId];
+        if (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
         {
-            mEvictionPolicy->releaseBlock(block);
+            block->decRefCount();
+            if (!block->hasRefs())
+            {
+                mEvictionPolicy->releaseBlock(block);
+            }
         }
-        block = std::move(block->getPrevBlock());
     }
 }

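The rewritten WindowBlockManager::unpinBlocksById no longer walks prev pointers from a single starting block; it visits exactly the ids it was handed, validates each one, drops one reference, and only returns a block to the eviction policy when no references remain. A compilable toy of that decrement-and-maybe-release loop, where SimpleBlock and FreePool are invented stand-ins for KVCacheBlock and the eviction policy:

// Toy decrement-and-release loop mirroring the new unpinBlocksById(vector) shape.
#include <cassert>
#include <cstddef>
#include <vector>

struct SimpleBlock
{
    int refCount = 0;
    bool inFreePool = false;
};

struct FreePool // stand-in for the eviction policy
{
    void releaseBlock(SimpleBlock& b)
    {
        b.inFreePool = true;
    }
};

void unpinBlocksById(std::vector<SimpleBlock>& allBlocks, FreePool& pool, std::vector<int> const& blockIds)
{
    if (blockIds.empty())
    {
        return;
    }
    for (auto id : blockIds)
    {
        assert(id >= 0 && static_cast<std::size_t>(id) < allBlocks.size()); // mirrors TLLM_CHECK_WITH_INFO
        auto& block = allBlocks[id];
        if (--block.refCount == 0) // no refs left: block becomes evictable again
        {
            pool.releaseBlock(block);
        }
    }
}

int main()
{
    std::vector<SimpleBlock> blocks(4);
    FreePool pool;
    blocks[1].refCount = 2; // pinned on store and still referenced elsewhere
    blocks[2].refCount = 1;

    unpinBlocksById(blocks, pool, {1, 2});
    assert(!blocks[1].inFreePool); // still referenced, stays out of the free pool
    assert(blocks[2].inFreePool);  // last reference dropped, released for eviction
    return 0;
}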
@@ -1870,7 +1875,7 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<
     (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
 }

-std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
+std::vector<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
     GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
 {
     auto constexpr beamIdx = 0;
@@ -1883,7 +1888,10 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
     auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
     auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true);
     auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
-    return storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks).second;
+
+    auto [numStored, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks);
+
+    return pinnedBlockIds;
 }

 std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
@@ -1922,7 +1930,7 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
         std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(),
             [](BlockPtr const& block) { return block->getBlockId(); });

-        auto [numBlocksStoredForReuse, lastStoredId] = storeBlocks(std::move(blockKeys), cacheBlockIds);
+        auto [numBlocksStoredForReuse, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds);
         TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(),
             sequence.getRequestId(), numBlocksStoredForReuse);
     }
@@ -2499,15 +2507,14 @@ std::optional<KVCacheBlock::IdType> KVCacheManager::removeSequence(
     return lastStoredId;
 }

-std::optional<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
+std::vector<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
     RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
 {
     TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
     auto& sequence = getSequence(requestId);
-    std::optional<KVCacheBlock::IdType> lastStoredId
-        = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
+    auto pinnedBlockIds = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
     TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
-    return lastStoredId;
+    return pinnedBlockIds;
 }

 void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId)
@@ -2522,9 +2529,9 @@ void KVCacheManager::pinBlocks(RequestIdType requestId)
     mBlockManager.pinBlocks(sequence);
 }

-void KVCacheManager::unpinBlocksById(KVCacheBlock::IdType blockId)
+void KVCacheManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
 {
-    mBlockManager.unpinBlocksById(blockId);
+    mBlockManager.unpinBlocksById(blockIds);
 }

 SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const
@@ -2179,11 +2179,11 @@ void Executor::Impl::terminateContextFinishedRequests(InTransList& inTransmissio
         auto req = item.request;
         if (req->isDisaggContextCompleteState())
         {
-            // If lastBlockId was tracked, unpin it. Otherwise, just terminate.
+            // If pinnedBlockIds were tracked, unpin them. Otherwise, just terminate.
             auto kvMgr = mModel->getKVCacheManager();
-            if (kvMgr && item.lastBlockId.has_value())
+            if (kvMgr && !item.pinnedBlockIds.empty())
             {
-                kvMgr->unpinBlocksById(item.lastBlockId.value());
+                kvMgr->unpinBlocksById(item.pinnedBlockIds);
             }
             else
             {
@@ -2234,14 +2234,14 @@ Executor::Impl::RequestList Executor::Impl::populateNewResponses(
             // move the in transmission requests to another tracker
             if (llmReq->isDisaggContextTransmissionState())
             {
-                std::optional<SizeType32> lastBlockId{};
+                std::vector<SizeType32> pinnedBlockIds{};
                 auto kvMgr = mModel->getKVCacheManager();
                 if (kvMgr && kvMgr->isEnableBlockReuse() && !kvMgr->getBlockManager().isVariableWindow())
                 {
-                    lastBlockId = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
+                    pinnedBlockIds = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
                     mModel->terminateRequest(llmReq);
                 }
-                inTransmissionRequests.push_back(InTransmissionItem{*it, lastBlockId});
+                inTransmissionRequests.push_back(InTransmissionItem{*it, pinnedBlockIds});
             }
             finishedRequests.push_back(*it);
             it = activeRequests.erase(it);
@@ -80,12 +80,12 @@ class Executor::Impl
     using RequestList = std::list<LlmRequestPtr>;

     // When block reuse is enabled for context worker for disaggregated serving,
-    // we need to store the last block id so that we can unpin the block when
+    // we need to store the pinned block ids so that we can unpin them when
     // the request is finished.
     struct InTransmissionItem
     {
         LlmRequestPtr request;
-        std::optional<SizeType32> lastBlockId;
+        std::vector<SizeType32> pinnedBlockIds;
     };

     using InTransList = std::list<InTransmissionItem>;
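The struct above is what keeps the pinned ids alive between the two executor phases: populateNewResponses fills pinnedBlockIds when a context-only request enters transmission, and terminateContextFinishedRequests later unpins exactly that list (or just terminates when nothing was pinned, e.g. block reuse disabled). A compact, self-contained rendering of that hand-off; RequestStub, CacheStub and Item are placeholders invented for this sketch:

// Toy rendering of the InTransmissionItem hand-off (placeholder types only).
#include <list>
#include <memory>
#include <vector>

struct RequestStub
{
    bool transmissionComplete = false;
};

struct CacheStub
{
    std::vector<int> storeBlocksForReuse(RequestStub const&, bool pinBlocks)
    {
        return pinBlocks ? std::vector<int>{4, 5, 6} : std::vector<int>{};
    }

    void unpinBlocksById(std::vector<int> const& ids)
    {
        unpinned.insert(unpinned.end(), ids.begin(), ids.end());
    }

    std::vector<int> unpinned;
};

struct Item
{
    std::shared_ptr<RequestStub> request;
    std::vector<int> pinnedBlockIds; // mirrors InTransmissionItem::pinnedBlockIds
};

int main()
{
    CacheStub cache;
    std::list<Item> inTransmission;

    // populateNewResponses(): store + pin when the request enters transmission.
    auto req = std::make_shared<RequestStub>();
    inTransmission.push_back(Item{req, cache.storeBlocksForReuse(*req, /*pinBlocks=*/true)});

    // terminateContextFinishedRequests(): unpin the tracked list once transfer is done.
    req->transmissionComplete = true;
    for (auto& item : inTransmission)
    {
        if (item.request->transmissionComplete && !item.pinnedBlockIds.empty())
        {
            cache.unpinBlocksById(item.pinnedBlockIds);
        }
    }
    return 0;
}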
@@ -161,6 +161,7 @@ void initBindings(nb::module_& m)
         .def("set_finished_reason", &GenLlmReq::setFinishedReason, nb::arg("finish_reason"), nb::arg("beam"))
         .def_prop_ro("is_finished", &GenLlmReq::isFinished)
         .def_prop_ro("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
+        .def_prop_ro("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
         .def_prop_rw(
             "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
         .def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
@@ -123,7 +123,7 @@ public:
         NB_OVERRIDE_PURE(removeSequence, requestId, llmRequest, pinOnRelease);
     }

-    std::optional<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
+    std::vector<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
         tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest, bool pinBlocks) override
     {
         NB_OVERRIDE_PURE(storeBlocksForReuse, requestId, llmRequest, pinBlocks);
@@ -165,6 +165,7 @@ void initBindings(pybind11::module_& m)
         .def("set_finished_reason", &GenLlmReq::setFinishedReason, py::arg("finish_reason"), py::arg("beam"))
         .def_property_readonly("is_finished", &GenLlmReq::isFinished)
         .def_property_readonly("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
+        .def_property_readonly("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
         .def_property(
             "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
         .def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
@@ -111,10 +111,10 @@ public:
             requestId, llmRequest, pinOnRelease);
     }

-    std::optional<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
+    std::vector<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
         tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest, bool pinBlocks) override
     {
-        PYBIND11_OVERLOAD_PURE(std::optional<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, storeBlocksForReuse,
+        PYBIND11_OVERLOAD_PURE(std::vector<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, storeBlocksForReuse,
             requestId, llmRequest, pinBlocks);
     }

@@ -4066,11 +4066,13 @@ TEST_F(KVCacheManagerTest, PinAndUnpinBlocksById)
     kvCacheManager.pinBlocks(requestId);
-    auto lastBlockIdOpt = kvCacheManager.getLastBlockId(requestId);
-    ASSERT_TRUE(lastBlockIdOpt.has_value());
+    auto const& allBlockIds = kvCacheManager.getCacheBlockIds(requestId, maxAttentionWindow)[0];
+    std::vector<SizeType32> pinnedBlockIds(allBlockIds.begin(), allBlockIds.end());
     (void) kvCacheManager.removeSequence(requestId, llmRequest);
     auto const freeAfterRemovePinned = kvCacheManager.getNumFreeBlocks();
     EXPECT_LT(freeAfterRemovePinned, totalBlocks);

-    kvCacheManager.unpinBlocksById(lastBlockIdOpt.value());
+    kvCacheManager.unpinBlocksById(pinnedBlockIds);
     auto const freeAfterUnpin = kvCacheManager.getNumFreeBlocks();
     EXPECT_EQ(freeAfterUnpin, totalBlocks);
 }
@@ -1167,7 +1167,8 @@ class PyExecutor:
             for req in previous_batch.scheduled_ctx_reqs:
                 if req.is_context_only_request and (
                         req.is_context_finished
-                        or req.is_finished_due_to_length):
+                        or req.is_finished_due_to_length
+                ) and not req.is_finished_due_to_cancellation:
                     block_id = self.kv_cache_manager.store_blocks_for_reuse(
                         req, True)
                     self.ctx_in_transmission_requests[
@@ -1436,7 +1437,8 @@ class PyExecutor:
             for req in scheduled_batch.context_requests:
                 if req.is_context_only_request and (
                         req.is_context_finished
-                        or req.is_finished_due_to_length):
+                        or req.is_finished_due_to_length
+                ) and not req.is_finished_due_to_cancellation:
                     block_id = self.kv_cache_manager.store_blocks_for_reuse(
                         req, True)
                     self.ctx_in_transmission_requests[
@@ -1686,7 +1688,8 @@ class PyExecutor:
             for req in self.previous_batch.sample_state.scheduled_requests.context_requests:
                 if req.is_context_only_request and (
                         req.is_context_finished
-                        or req.is_finished_due_to_length):
+                        or req.is_finished_due_to_length
+                ) and not req.is_finished_due_to_cancellation:
                     block_id = self.kv_cache_manager.store_blocks_for_reuse(
                         req, True)
                     self.ctx_in_transmission_requests[
@@ -2196,8 +2199,9 @@ class PyExecutor:
         if (scheduled_ctx_requests is None or len(scheduled_ctx_requests) == 0):
             return []
         for req in scheduled_ctx_requests:
-            if req.is_context_only_request and (req.is_context_finished or
-                                                req.is_finished_due_to_length):
+            if req.is_context_only_request and (
+                    req.is_context_finished or req.is_finished_due_to_length
+            ) and not req.is_finished_due_to_cancellation:
                 self.kv_cache_transceiver.respond_and_send_async(req)
                 for resource_mgr_type in (
                         ResourceManagerType.SEQ_SLOT_MANAGER,
@@ -1431,7 +1431,8 @@ class ResourceManager:
             resource_manager.update_resources(scheduled_batch)

     def free_resources(self, request: LlmRequest):
-        for _, resource_manager in reversed(self.resource_managers.items()):
+        for resource_type, resource_manager in reversed(
+                self.resource_managers.items()):
             if hasattr(resource_manager, "free_resources"):
                 resource_manager.free_resources(request)
@@ -0,0 +1,44 @@
+hostname: localhost
+port: 8000
+model: DeepSeek-V3-Lite/bf16
+backend: "pytorch"
+enable_autotuner: False
+context_servers:
+  disable_overlap_scheduler: True
+  num_instances: 1
+  tensor_parallel_size: 1
+  pipeline_parallel_size: 1
+  max_num_tokens: 16384
+  max_seq_len: 32768
+  enable_chunked_prefill: True
+  kv_cache_config:
+    enable_block_reuse: True
+    enable_partial_reuse: True
+    free_gpu_memory_fraction: 0.3
+  cache_transceiver_config:
+    backend: "DEFAULT"
+    max_tokens_in_buffer: 32768
+  cuda_graph_config:
+    enable_padding: True
+    max_batch_size: 1
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  tensor_parallel_size: 1
+  pipeline_parallel_size: 1
+  max_num_tokens: 2048
+  max_seq_len: 32768
+  enable_chunked_prefill: True
+  kv_cache_config:
+    enable_block_reuse: True
+    enable_partial_reuse: True
+    free_gpu_memory_fraction: 0.85
+  cache_transceiver_config:
+    backend: "DEFAULT"
+    max_tokens_in_buffer: 32768
+  cuda_graph_config:
+    enable_padding: True
+    max_batch_size: 64
+  urls:
+    - "localhost:8002"
@@ -0,0 +1,44 @@
+hostname: localhost
+port: 8000
+model: DeepSeek-V3-0324-FP4
+backend: "pytorch"
+enable_autotuner: False
+context_servers:
+  disable_overlap_scheduler: True
+  num_instances: 1
+  tensor_parallel_size: 4
+  pipeline_parallel_size: 1
+  max_num_tokens: 12000
+  max_seq_len: 262144
+  enable_chunked_prefill: True
+  kv_cache_config:
+    enable_block_reuse: True
+    enable_partial_reuse: True
+    free_gpu_memory_fraction: 0.2
+  cache_transceiver_config:
+    backend: "DEFAULT"
+    max_tokens_in_buffer: 262144
+  cuda_graph_config:
+    enable_padding: True
+    max_batch_size: 1
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  tensor_parallel_size: 4
+  pipeline_parallel_size: 1
+  max_num_tokens: 2048
+  max_seq_len: 262144
+  enable_chunked_prefill: True
+  kv_cache_config:
+    enable_block_reuse: True
+    enable_partial_reuse: True
+    free_gpu_memory_fraction: 0.3
+  cache_transceiver_config:
+    backend: "DEFAULT"
+    max_tokens_in_buffer: 262144
+  cuda_graph_config:
+    enable_padding: True
+    max_batch_size: 11
+  urls:
+    - "localhost:8002"
@@ -200,6 +200,10 @@ def get_test_config(test_desc, example_dir, test_root):
         "gpt_oss_120b_stress":
         (4,
          f"{test_configs_root}/disagg_config_ctxtp2_gentp2_gptoss_tllm.yaml"),
+        "cancel_stress_test":
+        (2, f"{test_configs_root}/disagg_config_cancel_stress_test.yaml"),
+        "cancel_stress_test_large":
+        (8, f"{test_configs_root}/disagg_config_cancel_stress_test_large.yaml"),
     }

     if test_desc not in config_map:
@@ -2098,3 +2102,211 @@ def test_disaggregated_stress_test(disaggregated_test_root,
                                threshold=test_config.accuracy_threshold,
                                env=llm_venv._new_env,
                                cwd=llm_venv.get_working_directory())
+
+
+def run_cancel_stress_test(server_url: str,
+                           num_bursts: int = 5,
+                           requests_per_burst: int = 32,
+                           prompt_len_range: tuple = (2000, 8000),
+                           cancel_after_range: tuple = (0.01, 0.1)):
+    """
+    Stress test that sends requests with large contexts and cancels them
+    during prefill to test resource cleanup under cancellation.
+
+    Args:
+        server_url: The server URL (e.g., "http://localhost:8000")
+        num_bursts: Number of request bursts to send
+        requests_per_burst: Number of concurrent requests per burst
+        prompt_len_range: (min, max) prompt length in tokens
+        cancel_after_range: (min, max) seconds to wait before cancelling
+    """
+    import asyncio
+    import random
+    import time
+
+    import aiohttp
+
+    async def spam_and_cancel(session, req_id, url, prompt_len_range,
+                              cancel_after_range):
+        """Send a request and cancel it during prefill."""
+        prompt_len = random.randint(prompt_len_range[0], prompt_len_range[1])
+        prompt = "test " * (prompt_len // 5)
+
+        payload = {
+            "model": "test-model",
+            "prompt": prompt,
+            "max_tokens": 10,
+            "stream": True
+        }
+
+        try:
+            cancel_after = random.uniform(cancel_after_range[0],
+                                          cancel_after_range[1])
+            start = time.time()
+            async with session.post(
+                    f"{url}/v1/completions",
+                    json=payload,
+                    timeout=aiohttp.ClientTimeout(total=60)) as resp:
+                async for line in resp.content:
+                    if time.time() - start > cancel_after:
+                        # Force disconnect during prefill
+                        break
+        except Exception:
+            pass  # Connection abort is expected
+
+    async def run_bursts():
+        async with aiohttp.ClientSession() as session:
+            for burst_idx in range(num_bursts):
+                tasks = [
+                    spam_and_cancel(session, i, server_url, prompt_len_range,
+                                    cancel_after_range)
+                    for i in range(requests_per_burst)
+                ]
+                await asyncio.gather(*tasks)
+                logger.info(
+                    f"Completed burst {burst_idx + 1}/{num_bursts} ({requests_per_burst} requests)"
+                )
+                await asyncio.sleep(0.05)
+
+    asyncio.run(run_bursts())
+
+
+def run_disaggregated_cancel_test(example_dir,
+                                  test_desc,
+                                  env=None,
+                                  cwd=None,
+                                  num_bursts=64,
+                                  requests_per_burst=64):
+    """Run disaggregated test with request cancellation stress test."""
+    cleanup_output_files()
+    run_env = env.copy()
+    run_env["UCX_TLS"] = "^ib"
+
+    num_ranks, config_file = get_test_config(test_desc, example_dir,
+                                              os.path.dirname(__file__))
+
+    workers_cmd = [
+        'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
+        str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
+        config_file
+    ]
+
+    server_start_timeout = 1200
+    server_cmd = [
+        'trtllm-serve', 'disaggregated', '--server_start_timeout',
+        str(server_start_timeout), '-c', config_file
+    ]
+    server_host, server_port = get_disagg_server_url_from_cfg(config_file)
+    server_url = f"http://{server_host}:{server_port}"
+
+    try:
+        with (open('output_workers.log', 'w') as output_workers,
+              popen(workers_cmd,
+                    stdout=output_workers,
+                    stderr=subprocess.STDOUT,
+                    env=run_env,
+                    cwd=cwd) as workers_proc, open('output_disagg.log', 'w') as
+              output_disagg,
+              popen(server_cmd,
+                    stdout=output_disagg,
+                    stderr=subprocess.STDOUT,
+                    env=run_env,
+                    cwd=cwd) as server_proc):
+
+            # Wait for server to be ready
+            if not wait_for_server(server_host,
+                                   server_port,
+                                   timeout_seconds=server_start_timeout):
+                raise RuntimeError(
+                    f"Disaggregated server did not become ready within {server_start_timeout} seconds"
+                )
+
+            # Run the cancel stress test
+            run_cancel_stress_test(server_url,
+                                   num_bursts=num_bursts,
+                                   requests_per_burst=requests_per_burst)
+
+            # Verify server is still healthy after stress test by sending a normal request
+            client_dir = f"{example_dir}/clients"
+            client_cmd = [
+                'python3', f'{client_dir}/disagg_client.py', '-c', config_file,
+                '-p', f'{client_dir}/prompts.json', '--ignore-eos',
+                '--server-start-timeout',
+                str(server_start_timeout)
+            ]
+            check_call(client_cmd,
+                       env=env,
+                       poll_procs=[workers_proc, server_proc])
+
+    except Exception:
+        logger.error("-------- Workers output --------")
+        with open('output_workers.log', 'r') as f:
+            logger.error(f.read())
+
+        logger.error("-------- Disagg server output --------")
+        with open('output_disagg.log', 'r') as f:
+            logger.error(f.read())
+        raise
+    finally:
+        if 'server_proc' in locals() and 'workers_proc' in locals():
+            server_proc.terminate()
+            workers_proc.terminate()
+            server_proc.wait()
+            workers_proc.wait()
+
+
+@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-bf16'],
+                         indirect=True)
+def test_disaggregated_cancel_large_context_requests(disaggregated_test_root,
+                                                     disaggregated_example_root,
+                                                     llm_venv,
+                                                     deepseek_v3_model_root):
+    """
+    Test that the disaggregated server handles request cancellations gracefully.
+
+    This test sends bursts of requests with large contexts and cancels them
+    during prefill to stress test resource cleanup.
+    """
+    src_dst_dict = {
+        deepseek_v3_model_root:
+        f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    run_disaggregated_cancel_test(disaggregated_example_root,
+                                  "cancel_stress_test",
+                                  env=llm_venv._new_env,
+                                  cwd=llm_venv.get_working_directory(),
+                                  num_bursts=5,
+                                  requests_per_burst=32)
+
+
+@pytest.mark.skip_less_device(8)
+@skip_pre_blackwell
+@pytest.mark.parametrize("model_path", ['DeepSeek-V3-0324-FP4'])
+def test_disaggregated_cancel_large_context_requests_long(
+        disaggregated_test_root, disaggregated_example_root, llm_venv,
+        model_path):
+    """Test that disaggregated server handles request cancellations gracefully.
+
+    This test sends bursts of requests with large contexts and cancels them
+    during prefill to stress test resource cleanup.
+    """
+    model_dir = f"{llm_models_root()}/{model_path}"
+    src_dst_dict = {
+        model_dir: f"{llm_venv.get_working_directory()}/{model_path}",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    run_disaggregated_cancel_test(disaggregated_example_root,
+                                  "cancel_stress_test_large",
+                                  env=llm_venv._new_env,
+                                  cwd=llm_venv.get_working_directory(),
+                                  num_bursts=1000,
+                                  requests_per_burst=32)
@@ -43,6 +43,7 @@ l0_dgx_h100:
       - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False]
       - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
      - unittest/llmapi/apps/test_disagg_serving_perf_metrics.py
+      - disaggregated/test_disaggregated.py::test_disaggregated_cancel_large_context_requests[DeepSeek-V3-Lite-bf16]
       # ------------- AutoDeploy tests ---------------
       - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2]
       # llmapi