[https://nvbugs/5689235][fix] Fix cancellation+chunked prefill+disagg (#10111)

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
This commit is contained in:
Iman Tabrizian 2026-01-12 15:23:26 -08:00 committed by GitHub
parent 18a33764b5
commit 48b09e5a25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 380 additions and 57 deletions

View File

@ -648,7 +648,7 @@ public:
void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx);
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
[[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false);
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
@ -853,8 +853,8 @@ public:
//! \param blockKeys Key of each block.
//! \param blockIds Id of each block.
//! \param pinBlocks If true, increment ref count for blocks while storing (pin on store).
//! \return Pair of (num blocks stored for reuse, id of the last block stored if any).
[[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
//! \return Pair of (num blocks stored for reuse, vector of pinned block IDs).
[[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
bool pinBlocks = false);
@ -886,8 +886,8 @@ public:
[[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);
//! \brief Unpin blocks by starting from a block id and walking prev pointers.
void unpinBlocksById(KVCacheBlock::IdType blockId);
//! \brief Unpin blocks by block ids directly
void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);
void initializeSequenceStorageValidity(LlmRequest::RequestIdType requestId)
{
@ -1103,7 +1103,7 @@ public:
std::optional<KVCacheBlock::IdType> releaseBlocks(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
[[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);
void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
@ -1112,7 +1112,7 @@ public:
/// @param sequence The generation request whose blocks should be pinned.
void pinBlocks(GenerationRequest& sequence);
void unpinBlocksById(KVCacheBlock::IdType blockId);
void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);
void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);
@ -1133,7 +1133,7 @@ public:
void offloadBlock(BlockPtr const& block, SizeType32 windowSize,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");
[[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
[[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
SizeType32 windowSize, bool pinBlocks = false)
{
@ -1584,7 +1584,7 @@ public:
virtual void storeNewBlock(LlmRequest const& llmRequest) = 0;
/// \brief Store blocks for reuse for a given request id
[[nodiscard]] virtual std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
[[nodiscard]] virtual std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false)
= 0;
@ -1678,7 +1678,7 @@ public:
BlockKey const& blockKey, SizeType32 windowSize)
= 0;
virtual void unpinBlocksById(KVCacheBlock::IdType blockId) = 0;
virtual void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) = 0;
};
class KVCacheManager : public BaseKVCacheManager
@ -1939,7 +1939,7 @@ public:
//! \brief Store newest blocks for reuse
void storeNewBlock(LlmRequest const& llmRequest) override;
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
[[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false) override;
[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
@ -1960,7 +1960,7 @@ public:
void pinBlocks(LlmRequest::RequestIdType requestId) override;
void unpinBlocksById(KVCacheBlock::IdType blockId) override;
void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) override;
std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const override;

View File

@ -1667,6 +1667,12 @@ public:
[](auto reason) { return reason == executor::FinishReason::kLENGTH; });
}
[[nodiscard]] bool isFinishedDueToCancellation() const noexcept
{
return std::all_of(mFinishReasons.begin(), mFinishReasons.end(),
[](auto reason) { return reason == executor::FinishReason::kCANCELLED; });
}
[[nodiscard]] bool isTimedOut() const
{
if (!mAllottedTimeMs.has_value())

View File

@ -1556,7 +1556,7 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm
}
}
std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, bool pinBlocks)
{
SizeType32 numBlocksStoredForReuse = 0;
@ -1569,7 +1569,7 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
auto numBlocks = blockKeys.size();
std::vector<BlockPtr> storedBlocks;
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
std::vector<KVCacheBlock::IdType> pinnedBlockIds;
for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt)
{
auto const bid = blockIds[blockCnt];
@ -1620,14 +1620,14 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
if (pinBlocks)
{
searchRoot->incRefCount();
pinnedBlockIds.push_back(searchRoot->getBlockId());
}
lastStoredId = searchRoot->getBlockId();
}
if (mEventManager)
{
mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize);
}
return {numBlocksStoredForReuse, lastStoredId};
return {numBlocksStoredForReuse, pinnedBlockIds};
}
void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
@ -1715,15 +1715,15 @@ std::deque<tle::KVCacheEvent> BlockManager::getLatestEvents(std::optional<std::c
return mEventManager ? mEventManager->getEvents(timeout) : std::deque<tle::KVCacheEvent>{};
}
std::optional<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
std::vector<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
std::vector<KVCacheBlock::IdType> pinnedBlockIds;
for (auto& [_, manager] : mWindowBlockManagers)
{
lastStoredId = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
pinnedBlockIds = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
}
return lastStoredId;
return pinnedBlockIds;
}
std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks(
@ -1767,7 +1767,7 @@ void BlockManager::pinBlocks(GenerationRequest& sequence)
}
}
void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
void BlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
{
// Use the first window size
if (mWindowBlockManagers.empty())
@ -1775,7 +1775,7 @@ void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
return;
}
auto& firstManager = mWindowBlockManagers.begin()->second;
firstManager.unpinBlocksById(blockId);
firstManager.unpinBlocksById(blockIds);
}
void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
@ -1788,21 +1788,26 @@ void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
}
}
void WindowBlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
void WindowBlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
{
if (blockId < 0 || static_cast<size_t>(blockId) >= mAllBlocksById.size())
if (blockIds.empty())
{
return;
}
for (auto const& blockId : blockIds)
{
TLLM_CHECK_WITH_INFO(blockId >= 0 && static_cast<size_t>(blockId) < mAllBlocksById.size(),
"Block id %d is out of range", blockId);
auto block = mAllBlocksById[blockId];
while (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
if (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
{
block->decRefCount();
if (!block->hasRefs())
{
mEvictionPolicy->releaseBlock(block);
}
block = std::move(block->getPrevBlock());
}
}
}
@ -1870,7 +1875,7 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<
(void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
}
std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
std::vector<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{
auto constexpr beamIdx = 0;
@ -1883,7 +1888,10 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true);
auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
return storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks).second;
auto [numStored, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks);
return pinnedBlockIds;
}
std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
@ -1922,7 +1930,7 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(),
[](BlockPtr const& block) { return block->getBlockId(); });
auto [numBlocksStoredForReuse, lastStoredId] = storeBlocks(std::move(blockKeys), cacheBlockIds);
auto [numBlocksStoredForReuse, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds);
TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(),
sequence.getRequestId(), numBlocksStoredForReuse);
}
@ -2499,15 +2507,14 @@ std::optional<KVCacheBlock::IdType> KVCacheManager::removeSequence(
return lastStoredId;
}
std::optional<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
std::vector<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{
TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
auto& sequence = getSequence(requestId);
std::optional<KVCacheBlock::IdType> lastStoredId
= mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
auto pinnedBlockIds = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
return lastStoredId;
return pinnedBlockIds;
}
void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId)
@ -2522,9 +2529,9 @@ void KVCacheManager::pinBlocks(RequestIdType requestId)
mBlockManager.pinBlocks(sequence);
}
void KVCacheManager::unpinBlocksById(KVCacheBlock::IdType blockId)
void KVCacheManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
{
mBlockManager.unpinBlocksById(blockId);
mBlockManager.unpinBlocksById(blockIds);
}
SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const

View File

@ -2179,11 +2179,11 @@ void Executor::Impl::terminateContextFinishedRequests(InTransList& inTransmissio
auto req = item.request;
if (req->isDisaggContextCompleteState())
{
// If lastBlockId was tracked, unpin it. Otherwise, just terminate.
// If pinnedBlockIds were tracked, unpin them. Otherwise, just terminate.
auto kvMgr = mModel->getKVCacheManager();
if (kvMgr && item.lastBlockId.has_value())
if (kvMgr && !item.pinnedBlockIds.empty())
{
kvMgr->unpinBlocksById(item.lastBlockId.value());
kvMgr->unpinBlocksById(item.pinnedBlockIds);
}
else
{
@ -2234,14 +2234,14 @@ Executor::Impl::RequestList Executor::Impl::populateNewResponses(
// move the in transmission requests to another tracker
if (llmReq->isDisaggContextTransmissionState())
{
std::optional<SizeType32> lastBlockId{};
std::vector<SizeType32> pinnedBlockIds{};
auto kvMgr = mModel->getKVCacheManager();
if (kvMgr && kvMgr->isEnableBlockReuse() && !kvMgr->getBlockManager().isVariableWindow())
{
lastBlockId = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
pinnedBlockIds = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
mModel->terminateRequest(llmReq);
}
inTransmissionRequests.push_back(InTransmissionItem{*it, lastBlockId});
inTransmissionRequests.push_back(InTransmissionItem{*it, pinnedBlockIds});
}
finishedRequests.push_back(*it);
it = activeRequests.erase(it);

View File

@ -80,12 +80,12 @@ class Executor::Impl
using RequestList = std::list<LlmRequestPtr>;
// When block reuse is enabled for context worker for disaggregated serving,
// we need to store the last block id so that we can unpin the block when
// we need to store the pinned block ids so that we can unpin them when
// the request is finished.
struct InTransmissionItem
{
LlmRequestPtr request;
std::optional<SizeType32> lastBlockId;
std::vector<SizeType32> pinnedBlockIds;
};
using InTransList = std::list<InTransmissionItem>;

View File

@ -161,6 +161,7 @@ void initBindings(nb::module_& m)
.def("set_finished_reason", &GenLlmReq::setFinishedReason, nb::arg("finish_reason"), nb::arg("beam"))
.def_prop_ro("is_finished", &GenLlmReq::isFinished)
.def_prop_ro("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
.def_prop_ro("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
.def_prop_rw(
"context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
.def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)

View File

@ -123,7 +123,7 @@ public:
NB_OVERRIDE_PURE(removeSequence, requestId, llmRequest, pinOnRelease);
}
std::optional<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
std::vector<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest, bool pinBlocks) override
{
NB_OVERRIDE_PURE(storeBlocksForReuse, requestId, llmRequest, pinBlocks);

View File

@ -165,6 +165,7 @@ void initBindings(pybind11::module_& m)
.def("set_finished_reason", &GenLlmReq::setFinishedReason, py::arg("finish_reason"), py::arg("beam"))
.def_property_readonly("is_finished", &GenLlmReq::isFinished)
.def_property_readonly("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
.def_property_readonly("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
.def_property(
"context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
.def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)

View File

@ -111,10 +111,10 @@ public:
requestId, llmRequest, pinOnRelease);
}
std::optional<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
std::vector<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest, bool pinBlocks) override
{
PYBIND11_OVERLOAD_PURE(std::optional<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, storeBlocksForReuse,
PYBIND11_OVERLOAD_PURE(std::vector<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, storeBlocksForReuse,
requestId, llmRequest, pinBlocks);
}

View File

@ -4066,11 +4066,13 @@ TEST_F(KVCacheManagerTest, PinAndUnpinBlocksById)
kvCacheManager.pinBlocks(requestId);
auto lastBlockIdOpt = kvCacheManager.getLastBlockId(requestId);
ASSERT_TRUE(lastBlockIdOpt.has_value());
auto const& allBlockIds = kvCacheManager.getCacheBlockIds(requestId, maxAttentionWindow)[0];
std::vector<SizeType32> pinnedBlockIds(allBlockIds.begin(), allBlockIds.end());
(void) kvCacheManager.removeSequence(requestId, llmRequest);
auto const freeAfterRemovePinned = kvCacheManager.getNumFreeBlocks();
EXPECT_LT(freeAfterRemovePinned, totalBlocks);
kvCacheManager.unpinBlocksById(lastBlockIdOpt.value());
kvCacheManager.unpinBlocksById(pinnedBlockIds);
auto const freeAfterUnpin = kvCacheManager.getNumFreeBlocks();
EXPECT_EQ(freeAfterUnpin, totalBlocks);
}

View File

@ -1167,7 +1167,8 @@ class PyExecutor:
for req in previous_batch.scheduled_ctx_reqs:
if req.is_context_only_request and (
req.is_context_finished
or req.is_finished_due_to_length):
or req.is_finished_due_to_length
) and not req.is_finished_due_to_cancellation:
block_id = self.kv_cache_manager.store_blocks_for_reuse(
req, True)
self.ctx_in_transmission_requests[
@ -1436,7 +1437,8 @@ class PyExecutor:
for req in scheduled_batch.context_requests:
if req.is_context_only_request and (
req.is_context_finished
or req.is_finished_due_to_length):
or req.is_finished_due_to_length
) and not req.is_finished_due_to_cancellation:
block_id = self.kv_cache_manager.store_blocks_for_reuse(
req, True)
self.ctx_in_transmission_requests[
@ -1686,7 +1688,8 @@ class PyExecutor:
for req in self.previous_batch.sample_state.scheduled_requests.context_requests:
if req.is_context_only_request and (
req.is_context_finished
or req.is_finished_due_to_length):
or req.is_finished_due_to_length
) and not req.is_finished_due_to_cancellation:
block_id = self.kv_cache_manager.store_blocks_for_reuse(
req, True)
self.ctx_in_transmission_requests[
@ -2196,8 +2199,9 @@ class PyExecutor:
if (scheduled_ctx_requests is None or len(scheduled_ctx_requests) == 0):
return []
for req in scheduled_ctx_requests:
if req.is_context_only_request and (req.is_context_finished or
req.is_finished_due_to_length):
if req.is_context_only_request and (
req.is_context_finished or req.is_finished_due_to_length
) and not req.is_finished_due_to_cancellation:
self.kv_cache_transceiver.respond_and_send_async(req)
for resource_mgr_type in (
ResourceManagerType.SEQ_SLOT_MANAGER,

View File

@ -1431,7 +1431,8 @@ class ResourceManager:
resource_manager.update_resources(scheduled_batch)
def free_resources(self, request: LlmRequest):
for _, resource_manager in reversed(self.resource_managers.items()):
for resource_type, resource_manager in reversed(
self.resource_managers.items()):
if hasattr(resource_manager, "free_resources"):
resource_manager.free_resources(request)

View File

@ -0,0 +1,44 @@
hostname: localhost
port: 8000
model: DeepSeek-V3-Lite/bf16
backend: "pytorch"
enable_autotuner: False
context_servers:
disable_overlap_scheduler: True
num_instances: 1
tensor_parallel_size: 1
pipeline_parallel_size: 1
max_num_tokens: 16384
max_seq_len: 32768
enable_chunked_prefill: True
kv_cache_config:
enable_block_reuse: True
enable_partial_reuse: True
free_gpu_memory_fraction: 0.3
cache_transceiver_config:
backend: "DEFAULT"
max_tokens_in_buffer: 32768
cuda_graph_config:
enable_padding: True
max_batch_size: 1
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
tensor_parallel_size: 1
pipeline_parallel_size: 1
max_num_tokens: 2048
max_seq_len: 32768
enable_chunked_prefill: True
kv_cache_config:
enable_block_reuse: True
enable_partial_reuse: True
free_gpu_memory_fraction: 0.85
cache_transceiver_config:
backend: "DEFAULT"
max_tokens_in_buffer: 32768
cuda_graph_config:
enable_padding: True
max_batch_size: 64
urls:
- "localhost:8002"

View File

@ -0,0 +1,44 @@
hostname: localhost
port: 8000
model: DeepSeek-V3-0324-FP4
backend: "pytorch"
enable_autotuner: False
context_servers:
disable_overlap_scheduler: True
num_instances: 1
tensor_parallel_size: 4
pipeline_parallel_size: 1
max_num_tokens: 12000
max_seq_len: 262144
enable_chunked_prefill: True
kv_cache_config:
enable_block_reuse: True
enable_partial_reuse: True
free_gpu_memory_fraction: 0.2
cache_transceiver_config:
backend: "DEFAULT"
max_tokens_in_buffer: 262144
cuda_graph_config:
enable_padding: True
max_batch_size: 1
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
tensor_parallel_size: 4
pipeline_parallel_size: 1
max_num_tokens: 2048
max_seq_len: 262144
enable_chunked_prefill: True
kv_cache_config:
enable_block_reuse: True
enable_partial_reuse: True
free_gpu_memory_fraction: 0.3
cache_transceiver_config:
backend: "DEFAULT"
max_tokens_in_buffer: 262144
cuda_graph_config:
enable_padding: True
max_batch_size: 11
urls:
- "localhost:8002"

View File

@ -200,6 +200,10 @@ def get_test_config(test_desc, example_dir, test_root):
"gpt_oss_120b_stress":
(4,
f"{test_configs_root}/disagg_config_ctxtp2_gentp2_gptoss_tllm.yaml"),
"cancel_stress_test":
(2, f"{test_configs_root}/disagg_config_cancel_stress_test.yaml"),
"cancel_stress_test_large":
(8, f"{test_configs_root}/disagg_config_cancel_stress_test_large.yaml"),
}
if test_desc not in config_map:
@ -2098,3 +2102,211 @@ def test_disaggregated_stress_test(disaggregated_test_root,
threshold=test_config.accuracy_threshold,
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
def run_cancel_stress_test(server_url: str,
num_bursts: int = 5,
requests_per_burst: int = 32,
prompt_len_range: tuple = (2000, 8000),
cancel_after_range: tuple = (0.01, 0.1)):
"""
Stress test that sends requests with large contexts and cancels them
during prefill to test resource cleanup under cancellation.
Args:
server_url: The server URL (e.g., "http://localhost:8000")
num_bursts: Number of request bursts to send
requests_per_burst: Number of concurrent requests per burst
prompt_len_range: (min, max) prompt length in tokens
cancel_after_range: (min, max) seconds to wait before cancelling
"""
import asyncio
import random
import time
import aiohttp
async def spam_and_cancel(session, req_id, url, prompt_len_range,
cancel_after_range):
"""Send a request and cancel it during prefill."""
prompt_len = random.randint(prompt_len_range[0], prompt_len_range[1])
prompt = "test " * (prompt_len // 5)
payload = {
"model": "test-model",
"prompt": prompt,
"max_tokens": 10,
"stream": True
}
try:
cancel_after = random.uniform(cancel_after_range[0],
cancel_after_range[1])
start = time.time()
async with session.post(
f"{url}/v1/completions",
json=payload,
timeout=aiohttp.ClientTimeout(total=60)) as resp:
async for line in resp.content:
if time.time() - start > cancel_after:
# Force disconnect during prefill
break
except Exception:
pass # Connection abort is expected
async def run_bursts():
async with aiohttp.ClientSession() as session:
for burst_idx in range(num_bursts):
tasks = [
spam_and_cancel(session, i, server_url, prompt_len_range,
cancel_after_range)
for i in range(requests_per_burst)
]
await asyncio.gather(*tasks)
logger.info(
f"Completed burst {burst_idx + 1}/{num_bursts} ({requests_per_burst} requests)"
)
await asyncio.sleep(0.05)
asyncio.run(run_bursts())
def run_disaggregated_cancel_test(example_dir,
test_desc,
env=None,
cwd=None,
num_bursts=64,
requests_per_burst=64):
"""Run disaggregated test with request cancellation stress test."""
cleanup_output_files()
run_env = env.copy()
run_env["UCX_TLS"] = "^ib"
num_ranks, config_file = get_test_config(test_desc, example_dir,
os.path.dirname(__file__))
workers_cmd = [
'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
config_file
]
server_start_timeout = 1200
server_cmd = [
'trtllm-serve', 'disaggregated', '--server_start_timeout',
str(server_start_timeout), '-c', config_file
]
server_host, server_port = get_disagg_server_url_from_cfg(config_file)
server_url = f"http://{server_host}:{server_port}"
try:
with (open('output_workers.log', 'w') as output_workers,
popen(workers_cmd,
stdout=output_workers,
stderr=subprocess.STDOUT,
env=run_env,
cwd=cwd) as workers_proc, open('output_disagg.log', 'w') as
output_disagg,
popen(server_cmd,
stdout=output_disagg,
stderr=subprocess.STDOUT,
env=run_env,
cwd=cwd) as server_proc):
# Wait for server to be ready
if not wait_for_server(server_host,
server_port,
timeout_seconds=server_start_timeout):
raise RuntimeError(
f"Disaggregated server did not become ready within {server_start_timeout} seconds"
)
# Run the cancel stress test
run_cancel_stress_test(server_url,
num_bursts=num_bursts,
requests_per_burst=requests_per_burst)
# Verify server is still healthy after stress test by sending a normal request
client_dir = f"{example_dir}/clients"
client_cmd = [
'python3', f'{client_dir}/disagg_client.py', '-c', config_file,
'-p', f'{client_dir}/prompts.json', '--ignore-eos',
'--server-start-timeout',
str(server_start_timeout)
]
check_call(client_cmd,
env=env,
poll_procs=[workers_proc, server_proc])
except Exception:
logger.error("-------- Workers output --------")
with open('output_workers.log', 'r') as f:
logger.error(f.read())
logger.error("-------- Disagg server output --------")
with open('output_disagg.log', 'r') as f:
logger.error(f.read())
raise
finally:
if 'server_proc' in locals() and 'workers_proc' in locals():
server_proc.terminate()
workers_proc.terminate()
server_proc.wait()
workers_proc.wait()
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-bf16'],
indirect=True)
def test_disaggregated_cancel_large_context_requests(disaggregated_test_root,
disaggregated_example_root,
llm_venv,
deepseek_v3_model_root):
"""
Test that the disaggregated server handles request cancellations gracefully.
This test sends bursts of requests with large contexts and cancels them
during prefill to stress test resource cleanup.
"""
src_dst_dict = {
deepseek_v3_model_root:
f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)
run_disaggregated_cancel_test(disaggregated_example_root,
"cancel_stress_test",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory(),
num_bursts=5,
requests_per_burst=32)
@pytest.mark.skip_less_device(8)
@skip_pre_blackwell
@pytest.mark.parametrize("model_path", ['DeepSeek-V3-0324-FP4'])
def test_disaggregated_cancel_large_context_requests_long(
disaggregated_test_root, disaggregated_example_root, llm_venv,
model_path):
"""Test that disaggregated server handles request cancellations gracefully.
This test sends bursts of requests with large contexts and cancels them
during prefill to stress test resource cleanup.
"""
model_dir = f"{llm_models_root()}/{model_path}"
src_dst_dict = {
model_dir: f"{llm_venv.get_working_directory()}/{model_path}",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)
run_disaggregated_cancel_test(disaggregated_example_root,
"cancel_stress_test_large",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory(),
num_bursts=1000,
requests_per_burst=32)

View File

@ -43,6 +43,7 @@ l0_dgx_h100:
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
- unittest/llmapi/apps/test_disagg_serving_perf_metrics.py
- disaggregated/test_disaggregated.py::test_disaggregated_cancel_large_context_requests[DeepSeek-V3-Lite-bf16]
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2]
# llmapi