Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Synced 2026-01-13 22:18:36 +08:00
[https://nvbugs/5689235][fix] Fix cancellation + chunked prefill + disagg (#10111)

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>

parent 18a33764b5
commit 48b09e5a25
@@ -648,7 +648,7 @@ public:
 
     void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx);
 
-    [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+    [[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
         GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false);
 
     void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
@@ -853,8 +853,8 @@ public:
     //! \param blockKeys Key of each block.
     //! \param blockIds Id of each block.
     //! \param pinBlocks If true, increment ref count for blocks while storing (pin on store).
-    //! \return Pair of (num blocks stored for reuse, id of the last block stored if any).
-    [[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
+    //! \return Pair of (num blocks stored for reuse, vector of pinned block IDs).
+    [[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
         std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
         bool pinBlocks = false);
 
@@ -886,8 +886,8 @@ public:
 
     [[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);
 
-    //! \brief Unpin blocks by starting from a block id and walking prev pointers.
-    void unpinBlocksById(KVCacheBlock::IdType blockId);
+    //! \brief Unpin blocks by block ids directly.
+    void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);
 
     void initializeSequenceStorageValidity(LlmRequest::RequestIdType requestId)
     {
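The signature change above is the heart of the patch: pinning now records every block id it touches, and unpinning decrements exactly those ids instead of walking prev pointers from a single last block id. A minimal, self-contained sketch of that pin-on-store / unpin-by-id pattern (toy types and names, not the TensorRT-LLM API):

#include <cassert>
#include <unordered_map>
#include <vector>

struct ToyBlock
{
    int refCount = 0;
};

class ToyBlockPool
{
public:
    // Store blocks for reuse; when pin == true, bump each ref count and
    // report which ids were pinned so the caller can unpin them later.
    std::vector<int> storeForReuse(std::vector<int> const& ids, bool pin)
    {
        std::vector<int> pinned;
        for (int id : ids)
        {
            if (pin)
            {
                ++mBlocks[id].refCount;
                pinned.push_back(id);
            }
        }
        return pinned;
    }

    // Unpin by explicit id list: exactly one decrement per pinned id.
    void unpinByIds(std::vector<int> const& ids)
    {
        for (int id : ids)
        {
            assert(mBlocks[id].refCount > 0);
            --mBlocks[id].refCount; // block becomes evictable again at zero
        }
    }

private:
    std::unordered_map<int, ToyBlock> mBlocks;
};

int main()
{
    ToyBlockPool pool;
    auto pinned = pool.storeForReuse({3, 4, 5}, /*pin=*/true);
    pool.unpinByIds(pinned); // releases exactly the blocks that were pinned
    return 0;
}

The vector return matters once more than one block can be pinned per store: unpinning by the explicit id list cannot over- or under-decrement, whereas walking prev pointers from one last block id could cross blocks that were never pinned.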
@@ -1103,7 +1103,7 @@ public:
     std::optional<KVCacheBlock::IdType> releaseBlocks(
         GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);
 
-    [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+    [[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
         GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);
 
     void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
@@ -1112,7 +1112,7 @@ public:
     /// @param sequence The generation request whose blocks should be pinned.
     void pinBlocks(GenerationRequest& sequence);
 
-    void unpinBlocksById(KVCacheBlock::IdType blockId);
+    void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);
 
     void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);
 
@@ -1133,7 +1133,7 @@ public:
     void offloadBlock(BlockPtr const& block, SizeType32 windowSize,
         executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");
 
-    [[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
+    [[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
         std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
         SizeType32 windowSize, bool pinBlocks = false)
     {
@@ -1584,7 +1584,7 @@ public:
     virtual void storeNewBlock(LlmRequest const& llmRequest) = 0;
 
     /// \brief Store blocks for reuse for a given request id.
-    [[nodiscard]] virtual std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+    [[nodiscard]] virtual std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
         LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false)
         = 0;
 
@@ -1678,7 +1678,7 @@ public:
         BlockKey const& blockKey, SizeType32 windowSize)
         = 0;
 
-    virtual void unpinBlocksById(KVCacheBlock::IdType blockId) = 0;
+    virtual void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) = 0;
 };
 
 class KVCacheManager : public BaseKVCacheManager
@@ -1939,7 +1939,7 @@ public:
     //! \brief Store newest blocks for reuse.
     void storeNewBlock(LlmRequest const& llmRequest) override;
 
-    [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
+    [[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
         LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false) override;
 
     [[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
@@ -1960,7 +1960,7 @@ public:
 
     void pinBlocks(LlmRequest::RequestIdType requestId) override;
 
-    void unpinBlocksById(KVCacheBlock::IdType blockId) override;
+    void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) override;
 
     std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const override;
 
@@ -1667,6 +1667,12 @@ public:
             [](auto reason) { return reason == executor::FinishReason::kLENGTH; });
     }
 
+    [[nodiscard]] bool isFinishedDueToCancellation() const noexcept
+    {
+        return std::all_of(mFinishReasons.begin(), mFinishReasons.end(),
+            [](auto reason) { return reason == executor::FinishReason::kCANCELLED; });
+    }
+
     [[nodiscard]] bool isTimedOut() const
     {
         if (!mAllottedTimeMs.has_value())
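The new predicate mirrors the existing isFinishedDueToLength check: a request counts as cancelled only when every beam finished with kCANCELLED. A standalone sketch of the semantics (toy enum, not the real executor::FinishReason); note that std::all_of is true on an empty range, the same as in the kLENGTH sibling:

#include <algorithm>
#include <cassert>
#include <vector>

enum class Reason { kNOT_FINISHED, kLENGTH, kCANCELLED };

// True only if every beam's finish reason is kCANCELLED.
bool isFinishedDueToCancellation(std::vector<Reason> const& reasons)
{
    return std::all_of(reasons.begin(), reasons.end(),
        [](Reason r) { return r == Reason::kCANCELLED; });
}

int main()
{
    assert(isFinishedDueToCancellation({Reason::kCANCELLED, Reason::kCANCELLED}));
    assert(!isFinishedDueToCancellation({Reason::kCANCELLED, Reason::kLENGTH}));
    return 0;
}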
@@ -1556,7 +1556,7 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm
         }
     }
 
-std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
+std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
     std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, bool pinBlocks)
 {
     SizeType32 numBlocksStoredForReuse = 0;
@@ -1569,7 +1569,7 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
 
     auto numBlocks = blockKeys.size();
     std::vector<BlockPtr> storedBlocks;
-    std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
+    std::vector<KVCacheBlock::IdType> pinnedBlockIds;
     for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt)
     {
         auto const bid = blockIds[blockCnt];
@@ -1620,14 +1620,14 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
         if (pinBlocks)
         {
             searchRoot->incRefCount();
+            pinnedBlockIds.push_back(searchRoot->getBlockId());
         }
-        lastStoredId = searchRoot->getBlockId();
     }
     if (mEventManager)
     {
         mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize);
     }
-    return {numBlocksStoredForReuse, lastStoredId};
+    return {numBlocksStoredForReuse, pinnedBlockIds};
 }
 
 void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
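Callers now take the pair apart with a structured binding, as releaseBlocks and storeBlocksForReuse do further down in this diff. A compilable toy sketch of that shape (assumed simplified types, not the real signatures):

#include <cstdio>
#include <utility>
#include <vector>

using IdType = int;

// Toy stand-in for storeBlocks: returns how many blocks were stored for
// reuse plus the ids that were pinned (empty unless pin == true).
std::pair<int, std::vector<IdType>> storeBlocksToy(std::vector<IdType> const& ids, bool pin)
{
    std::vector<IdType> pinned;
    if (pin)
    {
        pinned = ids; // this toy pins every stored block
    }
    return {static_cast<int>(ids.size()), pinned};
}

int main()
{
    auto [numStored, pinnedIds] = storeBlocksToy({7, 8, 9}, /*pin=*/true);
    std::printf("%d blocks stored, %zu pinned\n", numStored, pinnedIds.size());
    return 0;
}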
@@ -1715,15 +1715,15 @@ std::deque<tle::KVCacheEvent> BlockManager::getLatestEvents(std::optional<std::c
     return mEventManager ? mEventManager->getEvents(timeout) : std::deque<tle::KVCacheEvent>{};
 }
 
-std::optional<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
+std::vector<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
     GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
 {
-    std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
+    std::vector<KVCacheBlock::IdType> pinnedBlockIds;
     for (auto& [_, manager] : mWindowBlockManagers)
     {
-        lastStoredId = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
+        pinnedBlockIds = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
     }
-    return lastStoredId;
+    return pinnedBlockIds;
 }
 
 std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks(
@@ -1767,7 +1767,7 @@ void BlockManager::pinBlocks(GenerationRequest& sequence)
     }
 }
 
-void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
+void BlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
 {
     // Use the first window size
     if (mWindowBlockManagers.empty())
@@ -1775,7 +1775,7 @@ void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
         return;
     }
     auto& firstManager = mWindowBlockManagers.begin()->second;
-    firstManager.unpinBlocksById(blockId);
+    firstManager.unpinBlocksById(blockIds);
 }
 
 void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
@@ -1788,21 +1788,26 @@ void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
         }
     }
 
-void WindowBlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
+void WindowBlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
 {
-    if (blockId < 0 || static_cast<size_t>(blockId) >= mAllBlocksById.size())
+    if (blockIds.empty())
     {
         return;
     }
 
+    for (auto const& blockId : blockIds)
+    {
+        TLLM_CHECK_WITH_INFO(blockId >= 0 && static_cast<size_t>(blockId) < mAllBlocksById.size(),
+            "Block id %d is out of range", blockId);
         auto block = mAllBlocksById[blockId];
-    while (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
+        if (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
         {
             block->decRefCount();
             if (!block->hasRefs())
             {
                 mEvictionPolicy->releaseBlock(block);
             }
-        block = std::move(block->getPrevBlock());
+        }
     }
 }
@@ -1870,7 +1875,7 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<
     (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
 }
 
-std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
+std::vector<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
     GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
 {
     auto constexpr beamIdx = 0;
@@ -1883,7 +1888,10 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
     auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
     auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true);
     auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
-    return storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks).second;
+
+    auto [numStored, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks);
+
+    return pinnedBlockIds;
 }
 
 std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
@@ -1922,7 +1930,7 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
     std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(),
         [](BlockPtr const& block) { return block->getBlockId(); });
 
-    auto [numBlocksStoredForReuse, lastStoredId] = storeBlocks(std::move(blockKeys), cacheBlockIds);
+    auto [numBlocksStoredForReuse, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds);
     TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(),
         sequence.getRequestId(), numBlocksStoredForReuse);
 }
@@ -2499,15 +2507,14 @@ std::optional<KVCacheBlock::IdType> KVCacheManager::removeSequence(
     return lastStoredId;
 }
 
-std::optional<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
+std::vector<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
     RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
 {
     TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
     auto& sequence = getSequence(requestId);
-    std::optional<KVCacheBlock::IdType> lastStoredId
-        = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
+    auto pinnedBlockIds = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
    TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
-    return lastStoredId;
+    return pinnedBlockIds;
 }
 
 void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId)
@@ -2522,9 +2529,9 @@ void KVCacheManager::pinBlocks(RequestIdType requestId)
     mBlockManager.pinBlocks(sequence);
 }
 
-void KVCacheManager::unpinBlocksById(KVCacheBlock::IdType blockId)
+void KVCacheManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
 {
-    mBlockManager.unpinBlocksById(blockId);
+    mBlockManager.unpinBlocksById(blockIds);
 }
 
 SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const
@@ -2179,11 +2179,11 @@ void Executor::Impl::terminateContextFinishedRequests(InTransList& inTransmissio
         auto req = item.request;
         if (req->isDisaggContextCompleteState())
         {
-            // If lastBlockId was tracked, unpin it. Otherwise, just terminate.
+            // If pinnedBlockIds were tracked, unpin them. Otherwise, just terminate.
             auto kvMgr = mModel->getKVCacheManager();
-            if (kvMgr && item.lastBlockId.has_value())
+            if (kvMgr && !item.pinnedBlockIds.empty())
            {
-                kvMgr->unpinBlocksById(item.lastBlockId.value());
+                kvMgr->unpinBlocksById(item.pinnedBlockIds);
            }
             else
             {
@@ -2234,14 +2234,14 @@ Executor::Impl::RequestList Executor::Impl::populateNewResponses(
             // move the in transmission requests to another tracker
             if (llmReq->isDisaggContextTransmissionState())
             {
-                std::optional<SizeType32> lastBlockId{};
+                std::vector<SizeType32> pinnedBlockIds{};
                 auto kvMgr = mModel->getKVCacheManager();
                 if (kvMgr && kvMgr->isEnableBlockReuse() && !kvMgr->getBlockManager().isVariableWindow())
                 {
-                    lastBlockId = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
+                    pinnedBlockIds = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
                     mModel->terminateRequest(llmReq);
                 }
-                inTransmissionRequests.push_back(InTransmissionItem{*it, lastBlockId});
+                inTransmissionRequests.push_back(InTransmissionItem{*it, pinnedBlockIds});
             }
             finishedRequests.push_back(*it);
             it = activeRequests.erase(it);
@@ -80,12 +80,12 @@ class Executor::Impl
     using RequestList = std::list<LlmRequestPtr>;
 
     // When block reuse is enabled for context worker for disaggregated serving,
-    // we need to store the last block id so that we can unpin the block when
+    // we need to store the pinned block ids so that we can unpin them when
     // the request is finished.
     struct InTransmissionItem
     {
         LlmRequestPtr request;
-        std::optional<SizeType32> lastBlockId;
+        std::vector<SizeType32> pinnedBlockIds;
     };
 
     using InTransList = std::list<InTransmissionItem>;
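The tracking lifecycle around this struct: when a context-only request enters transmission, the ids pinned at store time travel with it; once transmission completes, exactly those ids are unpinned. A simplified sketch of that flow (toy types; the real unpin call is shown as a comment):

#include <list>
#include <memory>
#include <vector>

struct Request
{
    unsigned long id;
};
using RequestPtr = std::shared_ptr<Request>;

struct InTransmissionItem
{
    RequestPtr request;
    std::vector<int> pinnedBlockIds; // empty when nothing was pinned
};

int main()
{
    std::list<InTransmissionItem> inTransmission;

    // populateNewResponses(): pin on store and remember the ids.
    auto req = std::make_shared<Request>(Request{42});
    std::vector<int> pinned{3, 4, 5}; // would come from storeBlocksForReuse(..., /*pinBlocks=*/true)
    inTransmission.push_back(InTransmissionItem{req, pinned});

    // terminateContextFinishedRequests(): unpin only if ids were tracked.
    for (auto const& item : inTransmission)
    {
        if (!item.pinnedBlockIds.empty())
        {
            // kvMgr->unpinBlocksById(item.pinnedBlockIds);
        }
    }
    return 0;
}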
@@ -161,6 +161,7 @@ void initBindings(nb::module_& m)
         .def("set_finished_reason", &GenLlmReq::setFinishedReason, nb::arg("finish_reason"), nb::arg("beam"))
         .def_prop_ro("is_finished", &GenLlmReq::isFinished)
         .def_prop_ro("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
+        .def_prop_ro("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
         .def_prop_rw(
             "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
         .def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
@@ -123,7 +123,7 @@ public:
         NB_OVERRIDE_PURE(removeSequence, requestId, llmRequest, pinOnRelease);
     }
 
-    std::optional<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
+    std::vector<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
         tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest, bool pinBlocks) override
     {
         NB_OVERRIDE_PURE(storeBlocksForReuse, requestId, llmRequest, pinBlocks);
@@ -165,6 +165,7 @@ void initBindings(pybind11::module_& m)
         .def("set_finished_reason", &GenLlmReq::setFinishedReason, py::arg("finish_reason"), py::arg("beam"))
         .def_property_readonly("is_finished", &GenLlmReq::isFinished)
         .def_property_readonly("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
+        .def_property_readonly("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
         .def_property(
             "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
         .def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
@@ -111,10 +111,10 @@ public:
             requestId, llmRequest, pinOnRelease);
     }
 
-    std::optional<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
+    std::vector<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
         tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest, bool pinBlocks) override
     {
-        PYBIND11_OVERLOAD_PURE(std::optional<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, storeBlocksForReuse,
+        PYBIND11_OVERLOAD_PURE(std::vector<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, storeBlocksForReuse,
             requestId, llmRequest, pinBlocks);
     }
 
@@ -4066,11 +4066,13 @@ TEST_F(KVCacheManagerTest, PinAndUnpinBlocksById)
     kvCacheManager.pinBlocks(requestId);
     auto lastBlockIdOpt = kvCacheManager.getLastBlockId(requestId);
     ASSERT_TRUE(lastBlockIdOpt.has_value());
+    auto const& allBlockIds = kvCacheManager.getCacheBlockIds(requestId, maxAttentionWindow)[0];
+    std::vector<SizeType32> pinnedBlockIds(allBlockIds.begin(), allBlockIds.end());
     (void) kvCacheManager.removeSequence(requestId, llmRequest);
     auto const freeAfterRemovePinned = kvCacheManager.getNumFreeBlocks();
     EXPECT_LT(freeAfterRemovePinned, totalBlocks);
 
-    kvCacheManager.unpinBlocksById(lastBlockIdOpt.value());
+    kvCacheManager.unpinBlocksById(pinnedBlockIds);
     auto const freeAfterUnpin = kvCacheManager.getNumFreeBlocks();
     EXPECT_EQ(freeAfterUnpin, totalBlocks);
 }
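The invariant this test now checks, restated as a self-contained toy (an assumed, simplified model of the free pool, not the real KVCacheManager): removing a sequence whose blocks are pinned must not return them to the free pool, and unpinning exactly the recorded ids must restore it in full.

#include <cassert>
#include <set>
#include <vector>

int main()
{
    int const totalBlocks = 8;
    std::set<int> freeBlocks{0, 1, 2, 3, 4, 5, 6, 7};

    std::vector<int> pinnedIds{0, 1, 2};  // ids recorded at pin time
    for (int id : pinnedIds)
    {
        freeBlocks.erase(id);             // pinned blocks stay allocated after removeSequence
    }
    assert(static_cast<int>(freeBlocks.size()) < totalBlocks);

    for (int id : pinnedIds)
    {
        freeBlocks.insert(id);            // unpinBlocksById(pinnedBlockIds)
    }
    assert(static_cast<int>(freeBlocks.size()) == totalBlocks);
    return 0;
}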
@@ -1167,7 +1167,8 @@ class PyExecutor:
             for req in previous_batch.scheduled_ctx_reqs:
                 if req.is_context_only_request and (
                         req.is_context_finished
-                        or req.is_finished_due_to_length):
+                        or req.is_finished_due_to_length
+                ) and not req.is_finished_due_to_cancellation:
                     block_id = self.kv_cache_manager.store_blocks_for_reuse(
                         req, True)
                     self.ctx_in_transmission_requests[
@@ -1436,7 +1437,8 @@ class PyExecutor:
             for req in scheduled_batch.context_requests:
                 if req.is_context_only_request and (
                         req.is_context_finished
-                        or req.is_finished_due_to_length):
+                        or req.is_finished_due_to_length
+                ) and not req.is_finished_due_to_cancellation:
                     block_id = self.kv_cache_manager.store_blocks_for_reuse(
                         req, True)
                     self.ctx_in_transmission_requests[
@@ -1686,7 +1688,8 @@ class PyExecutor:
             for req in self.previous_batch.sample_state.scheduled_requests.context_requests:
                 if req.is_context_only_request and (
                         req.is_context_finished
-                        or req.is_finished_due_to_length):
+                        or req.is_finished_due_to_length
+                ) and not req.is_finished_due_to_cancellation:
                     block_id = self.kv_cache_manager.store_blocks_for_reuse(
                         req, True)
                     self.ctx_in_transmission_requests[
@@ -2196,8 +2199,9 @@ class PyExecutor:
         if (scheduled_ctx_requests is None or len(scheduled_ctx_requests) == 0):
             return []
         for req in scheduled_ctx_requests:
-            if req.is_context_only_request and (req.is_context_finished or
-                                                req.is_finished_due_to_length):
+            if req.is_context_only_request and (
+                    req.is_context_finished or req.is_finished_due_to_length
+            ) and not req.is_finished_due_to_cancellation:
                 self.kv_cache_transceiver.respond_and_send_async(req)
         for resource_mgr_type in (
                 ResourceManagerType.SEQ_SLOT_MANAGER,
@@ -1431,7 +1431,8 @@ class ResourceManager:
             resource_manager.update_resources(scheduled_batch)
 
     def free_resources(self, request: LlmRequest):
-        for _, resource_manager in reversed(self.resource_managers.items()):
+        for resource_type, resource_manager in reversed(
+                self.resource_managers.items()):
             if hasattr(resource_manager, "free_resources"):
                 resource_manager.free_resources(request)
 
@@ -0,0 +1,44 @@
+hostname: localhost
+port: 8000
+model: DeepSeek-V3-Lite/bf16
+backend: "pytorch"
+enable_autotuner: False
+context_servers:
+  disable_overlap_scheduler: True
+  num_instances: 1
+  tensor_parallel_size: 1
+  pipeline_parallel_size: 1
+  max_num_tokens: 16384
+  max_seq_len: 32768
+  enable_chunked_prefill: True
+  kv_cache_config:
+    enable_block_reuse: True
+    enable_partial_reuse: True
+    free_gpu_memory_fraction: 0.3
+  cache_transceiver_config:
+    backend: "DEFAULT"
+    max_tokens_in_buffer: 32768
+  cuda_graph_config:
+    enable_padding: True
+    max_batch_size: 1
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  tensor_parallel_size: 1
+  pipeline_parallel_size: 1
+  max_num_tokens: 2048
+  max_seq_len: 32768
+  enable_chunked_prefill: True
+  kv_cache_config:
+    enable_block_reuse: True
+    enable_partial_reuse: True
+    free_gpu_memory_fraction: 0.85
+  cache_transceiver_config:
+    backend: "DEFAULT"
+    max_tokens_in_buffer: 32768
+  cuda_graph_config:
+    enable_padding: True
+    max_batch_size: 64
+  urls:
+    - "localhost:8002"
@@ -0,0 +1,44 @@
+hostname: localhost
+port: 8000
+model: DeepSeek-V3-0324-FP4
+backend: "pytorch"
+enable_autotuner: False
+context_servers:
+  disable_overlap_scheduler: True
+  num_instances: 1
+  tensor_parallel_size: 4
+  pipeline_parallel_size: 1
+  max_num_tokens: 12000
+  max_seq_len: 262144
+  enable_chunked_prefill: True
+  kv_cache_config:
+    enable_block_reuse: True
+    enable_partial_reuse: True
+    free_gpu_memory_fraction: 0.2
+  cache_transceiver_config:
+    backend: "DEFAULT"
+    max_tokens_in_buffer: 262144
+  cuda_graph_config:
+    enable_padding: True
+    max_batch_size: 1
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  tensor_parallel_size: 4
+  pipeline_parallel_size: 1
+  max_num_tokens: 2048
+  max_seq_len: 262144
+  enable_chunked_prefill: True
+  kv_cache_config:
+    enable_block_reuse: True
+    enable_partial_reuse: True
+    free_gpu_memory_fraction: 0.3
+  cache_transceiver_config:
+    backend: "DEFAULT"
+    max_tokens_in_buffer: 262144
+  cuda_graph_config:
+    enable_padding: True
+    max_batch_size: 11
+  urls:
+    - "localhost:8002"
@@ -200,6 +200,10 @@ def get_test_config(test_desc, example_dir, test_root):
         "gpt_oss_120b_stress":
         (4,
          f"{test_configs_root}/disagg_config_ctxtp2_gentp2_gptoss_tllm.yaml"),
+        "cancel_stress_test":
+        (2, f"{test_configs_root}/disagg_config_cancel_stress_test.yaml"),
+        "cancel_stress_test_large":
+        (8, f"{test_configs_root}/disagg_config_cancel_stress_test_large.yaml"),
     }
 
     if test_desc not in config_map:
@@ -2098,3 +2102,211 @@ def test_disaggregated_stress_test(disaggregated_test_root,
                          threshold=test_config.accuracy_threshold,
                          env=llm_venv._new_env,
                          cwd=llm_venv.get_working_directory())
+
+
+def run_cancel_stress_test(server_url: str,
+                           num_bursts: int = 5,
+                           requests_per_burst: int = 32,
+                           prompt_len_range: tuple = (2000, 8000),
+                           cancel_after_range: tuple = (0.01, 0.1)):
+    """
+    Stress test that sends requests with large contexts and cancels them
+    during prefill to test resource cleanup under cancellation.
+
+    Args:
+        server_url: The server URL (e.g., "http://localhost:8000")
+        num_bursts: Number of request bursts to send
+        requests_per_burst: Number of concurrent requests per burst
+        prompt_len_range: (min, max) prompt length in tokens
+        cancel_after_range: (min, max) seconds to wait before cancelling
+    """
+    import asyncio
+    import random
+    import time
+
+    import aiohttp
+
+    async def spam_and_cancel(session, req_id, url, prompt_len_range,
+                              cancel_after_range):
+        """Send a request and cancel it during prefill."""
+        prompt_len = random.randint(prompt_len_range[0], prompt_len_range[1])
+        prompt = "test " * (prompt_len // 5)
+
+        payload = {
+            "model": "test-model",
+            "prompt": prompt,
+            "max_tokens": 10,
+            "stream": True
+        }
+
+        try:
+            cancel_after = random.uniform(cancel_after_range[0],
+                                          cancel_after_range[1])
+            start = time.time()
+            async with session.post(
+                    f"{url}/v1/completions",
+                    json=payload,
+                    timeout=aiohttp.ClientTimeout(total=60)) as resp:
+                async for line in resp.content:
+                    if time.time() - start > cancel_after:
+                        # Force disconnect during prefill
+                        break
+        except Exception:
+            pass  # Connection abort is expected
+
+    async def run_bursts():
+        async with aiohttp.ClientSession() as session:
+            for burst_idx in range(num_bursts):
+                tasks = [
+                    spam_and_cancel(session, i, server_url, prompt_len_range,
+                                    cancel_after_range)
+                    for i in range(requests_per_burst)
+                ]
+                await asyncio.gather(*tasks)
+                logger.info(
+                    f"Completed burst {burst_idx + 1}/{num_bursts} ({requests_per_burst} requests)"
+                )
+                await asyncio.sleep(0.05)
+
+    asyncio.run(run_bursts())
+
+
+def run_disaggregated_cancel_test(example_dir,
+                                  test_desc,
+                                  env=None,
+                                  cwd=None,
+                                  num_bursts=64,
+                                  requests_per_burst=64):
+    """Run disaggregated test with request cancellation stress test."""
+    cleanup_output_files()
+    run_env = env.copy()
+    run_env["UCX_TLS"] = "^ib"
+
+    num_ranks, config_file = get_test_config(test_desc, example_dir,
+                                             os.path.dirname(__file__))
+
+    workers_cmd = [
+        'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
+        str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
+        config_file
+    ]
+
+    server_start_timeout = 1200
+    server_cmd = [
+        'trtllm-serve', 'disaggregated', '--server_start_timeout',
+        str(server_start_timeout), '-c', config_file
+    ]
+    server_host, server_port = get_disagg_server_url_from_cfg(config_file)
+    server_url = f"http://{server_host}:{server_port}"
+
+    try:
+        with (open('output_workers.log', 'w') as output_workers,
+              popen(workers_cmd,
+                    stdout=output_workers,
+                    stderr=subprocess.STDOUT,
+                    env=run_env,
+                    cwd=cwd) as workers_proc, open('output_disagg.log', 'w') as
+              output_disagg,
+              popen(server_cmd,
+                    stdout=output_disagg,
+                    stderr=subprocess.STDOUT,
+                    env=run_env,
+                    cwd=cwd) as server_proc):
+
+            # Wait for server to be ready
+            if not wait_for_server(server_host,
+                                   server_port,
+                                   timeout_seconds=server_start_timeout):
+                raise RuntimeError(
+                    f"Disaggregated server did not become ready within {server_start_timeout} seconds"
+                )
+
+            # Run the cancel stress test
+            run_cancel_stress_test(server_url,
+                                   num_bursts=num_bursts,
+                                   requests_per_burst=requests_per_burst)
+
+            # Verify server is still healthy after stress test by sending a normal request
+            client_dir = f"{example_dir}/clients"
+            client_cmd = [
+                'python3', f'{client_dir}/disagg_client.py', '-c', config_file,
+                '-p', f'{client_dir}/prompts.json', '--ignore-eos',
+                '--server-start-timeout',
+                str(server_start_timeout)
+            ]
+            check_call(client_cmd,
+                       env=env,
+                       poll_procs=[workers_proc, server_proc])
+
+    except Exception:
+        logger.error("-------- Workers output --------")
+        with open('output_workers.log', 'r') as f:
+            logger.error(f.read())
+
+        logger.error("-------- Disagg server output --------")
+        with open('output_disagg.log', 'r') as f:
+            logger.error(f.read())
+        raise
+    finally:
+        if 'server_proc' in locals() and 'workers_proc' in locals():
+            server_proc.terminate()
+            workers_proc.terminate()
+            server_proc.wait()
+            workers_proc.wait()
+
+
+@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-bf16'],
+                         indirect=True)
+def test_disaggregated_cancel_large_context_requests(disaggregated_test_root,
+                                                     disaggregated_example_root,
+                                                     llm_venv,
+                                                     deepseek_v3_model_root):
+    """
+    Test that the disaggregated server handles request cancellations gracefully.
+
+    This test sends bursts of requests with large contexts and cancels them
+    during prefill to stress test resource cleanup.
+    """
+    src_dst_dict = {
+        deepseek_v3_model_root:
+        f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    run_disaggregated_cancel_test(disaggregated_example_root,
+                                  "cancel_stress_test",
+                                  env=llm_venv._new_env,
+                                  cwd=llm_venv.get_working_directory(),
+                                  num_bursts=5,
+                                  requests_per_burst=32)
+
+
+@pytest.mark.skip_less_device(8)
+@skip_pre_blackwell
+@pytest.mark.parametrize("model_path", ['DeepSeek-V3-0324-FP4'])
+def test_disaggregated_cancel_large_context_requests_long(
+        disaggregated_test_root, disaggregated_example_root, llm_venv,
+        model_path):
+    """Test that disaggregated server handles request cancellations gracefully.
+
+    This test sends bursts of requests with large contexts and cancels them
+    during prefill to stress test resource cleanup.
+    """
+    model_dir = f"{llm_models_root()}/{model_path}"
+    src_dst_dict = {
+        model_dir: f"{llm_venv.get_working_directory()}/{model_path}",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    run_disaggregated_cancel_test(disaggregated_example_root,
+                                  "cancel_stress_test_large",
+                                  env=llm_venv._new_env,
+                                  cwd=llm_venv.get_working_directory(),
+                                  num_bursts=1000,
+                                  requests_per_burst=32)
@@ -43,6 +43,7 @@ l0_dgx_h100:
       - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False]
       - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
       - unittest/llmapi/apps/test_disagg_serving_perf_metrics.py
+      - disaggregated/test_disaggregated.py::test_disaggregated_cancel_large_context_requests[DeepSeek-V3-Lite-bf16]
       # ------------- AutoDeploy tests ---------------
       - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2]
       # llmapi