[None][chore] Async Transfer Manager (#9891)

Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
jthomson04 2026-01-20 09:12:47 -08:00 committed by GitHub
parent e61c942d1f
commit 2db3d7eeba
8 changed files with 468 additions and 164 deletions

View File

@@ -190,6 +190,14 @@ public:
std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt);
};
struct RequestStatuses
{
/// Requests that have completed their transfer successfully.
std::unordered_set<LlmRequest::RequestIdType> completedRequestIds;
/// Requests that have encountered an error during their transfer.
std::unordered_set<LlmRequest::RequestIdType> errorRequestIds;
};
class BaseCacheTransceiver
{
public:
@@ -202,7 +210,10 @@ public:
virtual void requestAndReceiveSync(LlmRequest* llmRequest) = 0;
virtual void requestAndReceiveAsync(LlmRequest* llmRequest) = 0;
virtual void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;
/// Check all requests transferring context, and return the requests that have completed or encountered an error.
virtual RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false)
= 0;
virtual void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;
@@ -243,7 +254,8 @@ public:
void requestAndReceiveSync(LlmRequest* llmRequest) override;
void requestAndReceiveAsync(LlmRequest* llmRequest) override;
void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override;
RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override;
void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override;

View File

@@ -427,7 +427,8 @@ void updateKVCacheTransferBW(std::shared_ptr<CacheTransceiverComm> const& mComm,
}
}
void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
RequestStatuses CacheTransceiver::checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum, bool markComplete)
{
bool blockAll = !atLeastRequestNum.has_value();
std::optional<int> senderFutureTimeoutMs = std::nullopt;
@@ -486,6 +487,8 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
toCompleteIdSet.insert(request->mRequestId);
}
RequestStatuses requestsStatus{};
// Complete all the requests in toCompleteIdSet
for (auto it = mSenderFutures.begin(); it != mSenderFutures.end();)
{
@@ -499,7 +502,11 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
if (status == std::future_status::ready || !senderFutureTimeoutMs.has_value())
{
future.get();
request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
requestsStatus.completedRequestIds.insert(request->mRequestId);
if (markComplete)
{
request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
}
it = mSenderFutures.erase(it);
}
else if (status == std::future_status::timeout)
@@ -514,6 +521,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
"Future returned unexpected status for request %ld. Marking as error", request->mRequestId);
request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
requestsStatus.errorRequestIds.insert(request->mRequestId);
it = mSenderFutures.erase(it);
}
}
@@ -522,6 +530,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
TLLM_LOG_ERROR(
"Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what());
request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
requestsStatus.errorRequestIds.insert(request->mRequestId);
it = mSenderFutures.erase(it);
}
}
@@ -530,6 +539,8 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
++it;
}
}
return requestsStatus;
}
void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastRequestNum)

View File

@@ -503,7 +503,7 @@ TrtGptModelInflightBatching::~TrtGptModelInflightBatching()
{
if (mCacheTransceiver)
{
mCacheTransceiver->checkContextTransferStatus(true);
mCacheTransceiver->checkContextTransferStatus(1, true);
TLLM_CHECK_WITH_INFO(mCacheTransceiver->checkGenTransferComplete(), "Generation transfer not complete");
}
if (mAsyncSendWaitThread)
@@ -932,7 +932,7 @@ void TrtGptModelInflightBatching::forwardSync()
}
if (mCacheTransceiver)
{
mCacheTransceiver->checkContextTransferStatus(0);
mCacheTransceiver->checkContextTransferStatus(0, true);
}
++mIterCounter;
@@ -1025,7 +1025,7 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
mIterCounter);
if (mCacheTransceiver)
{
mCacheTransceiver->checkContextTransferStatus(1);
mCacheTransceiver->checkContextTransferStatus(1, true);
// will free kvCache in next iteration.
}
}

View File

@@ -60,9 +60,10 @@ public:
NB_OVERRIDE_PURE(requestAndReceiveAsync, llmRequest);
}
void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
tb::RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override
{
NB_OVERRIDE_PURE(checkContextTransferStatus, atLeastRequestNum);
NB_OVERRIDE_PURE(checkContextTransferStatus, atLeastRequestNum, markComplete);
}
void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
@@ -88,8 +89,23 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m)
.def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync)
.def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync)
.def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync)
.def("check_context_transfer_status", &BaseCacheTransceiver::checkContextTransferStatus,
nb::call_guard<nb::gil_scoped_release>())
.def(
"check_context_transfer_status",
[](tb::BaseCacheTransceiver& self, std::optional<int> const& atLeastRequestNum, bool markComplete = false)
{
RequestStatuses result;
{
nb::gil_scoped_release release;
result = self.checkContextTransferStatus(atLeastRequestNum, markComplete);
}
auto completedRequestIds
= std::vector<int64_t>(result.completedRequestIds.begin(), result.completedRequestIds.end());
auto errorRequestIds
= std::vector<int64_t>(result.errorRequestIds.begin(), result.errorRequestIds.end());
return nb::make_tuple(completedRequestIds, errorRequestIds);
},
nb::arg("at_least_request_num") = std::nullopt, nb::arg("mark_complete") = false)
.def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus,
nb::call_guard<nb::gil_scoped_release>())
.def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete)
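With this change, the Python-facing check_context_transfer_status no longer returns None: the lambda above releases the GIL around the C++ call and converts the two id sets into lists returned as a tuple. A minimal sketch of how a caller might consume it, assuming `transceiver` is an already constructed cache transceiver binding (the name and call site are illustrative, not part of this commit):

# Hedged sketch: consuming the new tuple return value of the binding.
# `transceiver` is assumed to be a constructed BaseCacheTransceiver instance.
completed_ids, error_ids = transceiver.check_context_transfer_status(
    at_least_request_num=1,  # wait for at least one context transfer to resolve
    mark_complete=False,     # leave state transitions to the Python-side manager
)
for req_id in completed_ids:
    pass  # e.g. hand the request to AsyncTransferManager.end_transfer()
for req_id in error_ids:
    pass  # the C++ side has already marked these requests as DISAGG_TRANS_ERROR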

View File

@@ -56,9 +56,13 @@ public:
PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, requestAndReceiveAsync, llmRequest);
}
void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
using RequestStatuses = tb::RequestStatuses;
RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override
{
PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, checkContextTransferStatus, atLeastRequestNum);
PYBIND11_OVERLOAD_PURE(
RequestStatuses, tb::BaseCacheTransceiver, checkContextTransferStatus, atLeastRequestNum, markComplete);
}
void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
@@ -84,8 +88,23 @@ void tb::CacheTransceiverBindings::initBindings(py::module_& m)
.def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync)
.def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync)
.def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync)
.def("check_context_transfer_status", &BaseCacheTransceiver::checkContextTransferStatus,
py::call_guard<py::gil_scoped_release>())
.def(
"check_context_transfer_status",
[](tb::BaseCacheTransceiver& self, std::optional<int> const& atLeastRequestNum, bool markComplete = false)
{
RequestStatuses result;
{
py::gil_scoped_release release;
result = self.checkContextTransferStatus(atLeastRequestNum, markComplete);
}
auto completedRequestIds
= std::vector<int64_t>(result.completedRequestIds.begin(), result.completedRequestIds.end());
auto errorRequestIds
= std::vector<int64_t>(result.errorRequestIds.begin(), result.errorRequestIds.end());
return py::make_tuple(completedRequestIds, errorRequestIds);
},
py::arg("at_least_request_num") = std::nullopt, py::arg("mark_complete") = false)
.def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus,
py::call_guard<py::gil_scoped_release>())
.def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete)

View File

@@ -114,6 +114,125 @@ class BatchStatePP(BatchState):
finished_ctx_reqs: list[LlmRequest] = None
class AsyncTransferManager:
"""
Handle asynchronous transfer of KV cache after a request has completed.
When running with both the KV cache transceiver and the KV cache connector, we must ensure that BOTH transfers (if any) are completed before we can release the KV cache blocks.
The AsyncTransferManager has a few key responsibilities:
1. Track requests in transfer.
2. Pin blocks for reuse while blocks are in transfer.
3. Unpin blocks after all transfers are complete.
TODO(jthomson04): This only handles async sending/saving, not loading. Loading KV cache is handled through a separate code path. Eventually, we'll want to merge these two paths.
"""
class RequestTransferMetadata:
def __init__(self, block_id: Optional[int]):
self.block_id = block_id
self.counter = 0
def start_transfer(self):
self.counter += 1
def end_transfer(self) -> bool:
"""
Returns:
bool: True if there are no more transfers for this request
"""
self.counter -= 1
return self.counter == 0
def __init__(self,
resource_manager: "ResourceManager",
should_store_blocks: bool = True):
self.resource_manager = resource_manager
self.kv_cache_manager = resource_manager.resource_managers.get(
ResourceManagerType.KV_CACHE_MANAGER)
self.should_store_blocks = should_store_blocks
# Mapping of request id to the LlmRequest
self._requests_in_transfer: Dict[int, LlmRequest] = dict()
# Mapping of request id to the request metadata
self._request_transfer_metadata: Dict[
int, self.RequestTransferMetadata] = dict()
def requests_in_transfer(self) -> Dict[int, LlmRequest]:
return self._requests_in_transfer
def start_transfer(self, request: LlmRequest):
"""
Called when a cache transceiver or connector transfer is started.
1. Increments the transfer counter for the request.
2. Releases all resources except the KV cache, if not already released.
3. Stores the KV cache blocks for reuse.
"""
req_id = request.py_request_id
if req_id not in self._requests_in_transfer:
for resource_mgr_type in (
ResourceManagerType.SEQ_SLOT_MANAGER,
ResourceManagerType.SPEC_RESOURCE_MANAGER):
if resource_mgr_type in self.resource_manager.resource_managers and self.resource_manager.resource_managers[
resource_mgr_type] is not None:
self.resource_manager.resource_managers[
resource_mgr_type].free_resources(request)
request.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
if self.should_store_blocks:
block_id = self.kv_cache_manager.store_blocks_for_reuse(
request, True)
else:
block_id = None
self._requests_in_transfer[req_id] = request
self._request_transfer_metadata[
req_id] = self.RequestTransferMetadata(block_id)
self._request_transfer_metadata[req_id].start_transfer()
def end_transfer(self, request: LlmRequest) -> bool:
"""
Called after a KV cache send completes.
1. Decrements the transfer counter for the request.
2. If there are no more in-flight transfers for this request, unpins the blocks and marks the request as complete.
Returns:
bool: True if the request should be terminated after the call to end_transfer
"""
try:
transfer_metadata = self._request_transfer_metadata[
request.py_request_id]
except KeyError:
logger.warning(
f"Request {request.py_request_id} not found in transfer manager"
)
return False
if transfer_metadata.end_transfer():
self._requests_in_transfer.pop(request.py_request_id)
self._request_transfer_metadata.pop(request.py_request_id)
if self.should_store_blocks:
self.kv_cache_manager.unpin_blocks_by_id(
transfer_metadata.block_id)
# We don't want to overwrite any error state.
if request.state != LlmRequestState.DISAGG_TRANS_ERROR:
request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE
return True
return False
def has_any_inflight_requests(self) -> bool:
return len(self._requests_in_transfer) > 0
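A minimal usage sketch of the lifecycle described in the docstring above, under the assumption that `resource_manager` and `req` (an LlmRequest) are already constructed; both names are illustrative. Blocks stay pinned until every started transfer has ended:

# Hedged sketch of the AsyncTransferManager reference-counting lifecycle.
manager = AsyncTransferManager(resource_manager, should_store_blocks=True)

manager.start_transfer(req)  # transceiver send starts: resources freed, blocks pinned, counter = 1
manager.start_transfer(req)  # connector save starts for the same request: counter = 2

manager.end_transfer(req)    # False: one transfer still in flight, blocks remain pinned
done = manager.end_transfer(req)
# True: blocks are unpinned and the request is marked DISAGG_CONTEXT_COMPLETE
# (unless it already sits in DISAGG_TRANS_ERROR, which is preserved)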
class PyExecutor:
def __init__(self,
@@ -233,10 +352,10 @@ class PyExecutor:
self.max_num_active_requests = model_engine.get_max_num_sequences()
self.active_requests: List[LlmRequest] = []
self.expected_num_active_requests = 0
self.ctx_in_transmission_requests = dict()
self.ctx_in_transmission_counter = (1 if kv_cache_transceiver else
0) + (1 if kv_connector_manager else
0)
self.async_transfer_manager = AsyncTransferManager(
self.resource_manager,
should_store_blocks=self.block_reuse_enabled
and not self.kv_cache_manager.is_vswa)
self.previous_batch: Optional[BatchState] = None
self.has_previous_draft_tokens = False
self.num_scheduled_requests: int = 0
@@ -373,6 +492,10 @@ class PyExecutor:
module.register_forward_hook(
self.kv_connector_manager.layer_post_hook)
def _end_transfer_and_maybe_terminate(self, request: LlmRequest):
if self.async_transfer_manager.end_transfer(request):
self._terminate_request(request)
def _event_loop_wrapper(self):
try:
with customized_gc_thresholds(
@@ -951,7 +1074,7 @@ class PyExecutor:
raise RuntimeError(
"KV cache transceiver is not enabled, but current rank cannot run first PP's schedule result due to limited KV cache resources. This is not expected."
)
if not self.ctx_in_transmission_requests:
if not self.async_transfer_manager.has_any_inflight_requests():
raise RuntimeError(
"No context cache transmission is in progress, but current rank cannot run first PP's schedule result due to limited KV cache resources. This is not expected."
)
@@ -970,7 +1093,6 @@ class PyExecutor:
# Let cache transceiver finish at least one cache transmission and release requests' KV cache resources
self._check_disagg_ctx_cache_transfer_status(1)
self._check_kv_transfer_timeout()
self._terminate_disagg_ctx_finished_requests()
else:
raise RuntimeError(
f"Reach maximum PP retry count ({self.pp_scheduler_max_retry_count}) but still cannot run first PP's schedule result. Please consider increasing the KV cache size by setting `free_gpu_memory_fraction` to a larger value. Or you can set `TLLM_PP_SCHEDULER_MAX_RETRY_COUNT` to a larger value to allow more retries."
@@ -1186,21 +1308,8 @@ class PyExecutor:
sample_state.scheduled_requests.context_requests = previous_batch.finished_ctx_reqs
self._update_requests(previous_batch.sample_state)
if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
for req in previous_batch.scheduled_ctx_reqs:
if req.is_context_only_request and (
req.is_context_finished
or req.is_finished_due_to_length
) and not req.is_finished_due_to_cancellation:
block_id = self.kv_cache_manager.store_blocks_for_reuse(
req, True)
self.ctx_in_transmission_requests[
req.py_request_id] = (
(req, block_id,
self.ctx_in_transmission_counter))
if self.kv_cache_transceiver:
self._send_disagg_ctx_cache(
self._send_kv_async(
previous_batch.scheduled_ctx_reqs)
self._handle_canceled_requests()
@@ -1222,9 +1331,9 @@ class PyExecutor:
self.wait_on_pp_send_handles(prev_microbatch_id)
self.micro_batches[prev_microbatch_id] = None
if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
):
self._check_kv_transfer_timeout()
self._terminate_disagg_ctx_finished_requests()
if self._disagg_pp_termination_handler is not None:
self._disagg_pp_termination_handler.terminate_pending_requests(
@@ -1354,14 +1463,7 @@ class PyExecutor:
if self.kv_connector_manager:
reqs_to_terminate = self.kv_connector_manager.get_finished()
for req in reqs_to_terminate:
if req.py_request_id in self.ctx_in_transmission_requests:
request, block_id, counter = self.ctx_in_transmission_requests.pop(
req.py_request_id)
if counter == 1:
self.kv_cache_manager.unpin_blocks_by_id(block_id)
else:
self.ctx_in_transmission_requests[req.py_request_id] = (
request, block_id, counter - 1)
self._end_transfer_and_maybe_terminate(req)
def _kv_connector_wait_for_save(self):
if self.kv_connector_manager is not None:
@@ -1457,25 +1559,9 @@ class PyExecutor:
self._update_request_states(scheduled_batch)
self._update_requests(sample_state, self.resource_manager)
if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
for req in scheduled_batch.context_requests:
if req.is_context_only_request and (
req.is_context_finished
or req.is_finished_due_to_length
) and not req.is_finished_due_to_cancellation:
block_id = self.kv_cache_manager.store_blocks_for_reuse(
req, True)
self.ctx_in_transmission_requests[
req.py_request_id] = (
(req, block_id,
self.ctx_in_transmission_counter))
if self.kv_cache_transceiver:
ctx_transmission_reqs = self._send_disagg_ctx_cache(
scheduled_batch.context_requests)
# For context only req in transmission, we reset the state since sampler might have changed it
for req in ctx_transmission_reqs:
req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
self._send_kv_async(scheduled_batch.context_requests +
scheduled_batch.generation_requests)
self._handle_canceled_requests()
finished_requests = self._handle_responses()
@@ -1489,9 +1575,9 @@ class PyExecutor:
if self.enable_kv_cache_events:
self._add_kv_cache_events()
if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
):
self._check_kv_transfer_timeout()
self._terminate_disagg_ctx_finished_requests()
self._kv_connector_terminate_requests()
@@ -1709,19 +1795,6 @@ class PyExecutor:
if self.previous_batch is not None:
self._update_requests(self.previous_batch.sample_state)
if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
for req in self.previous_batch.sample_state.scheduled_requests.context_requests:
if req.is_context_only_request and (
req.is_context_finished
or req.is_finished_due_to_length
) and not req.is_finished_due_to_cancellation:
block_id = self.kv_cache_manager.store_blocks_for_reuse(
req, True)
self.ctx_in_transmission_requests[
req.py_request_id] = (
(req, block_id,
self.ctx_in_transmission_counter))
if self.drafter is not None and self.use_spec_decode:
# Cleanup previous draft resources used in the draft model
self.drafter.cleanup_previous_draft_resources()
@@ -1746,9 +1819,8 @@ class PyExecutor:
self._update_request_states(scheduled_batch)
ctx_transmission_reqs = self._send_disagg_ctx_cache(
scheduled_batch.context_requests
) if self.kv_cache_transceiver else []
ctx_transmission_reqs = self._send_kv_async(
scheduled_batch.all_requests())
if self.previous_batch is not None:
self._process_previous_batch()
@@ -1764,9 +1836,9 @@ class PyExecutor:
iter_stats=iter_stats,
ctx_transmission_reqs=ctx_transmission_reqs)
if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
):
self._check_kv_transfer_timeout()
self._terminate_disagg_ctx_finished_requests()
self._kv_connector_terminate_requests()
@@ -2101,7 +2173,7 @@ class PyExecutor:
)
req.py_kv_transfer_timed_out = True
for req, _, _ in self.ctx_in_transmission_requests.values():
for req in self.async_transfer_manager.requests_in_transfer().values():
flag_if_kv_transfer_timed_out(req, "context")
for req in self.active_requests:
@@ -2221,36 +2293,45 @@ class PyExecutor:
return
@nvtx_range("_send_disagg_ctx_cache")
def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
if (scheduled_ctx_requests is None or len(scheduled_ctx_requests) == 0):
return []
for req in scheduled_ctx_requests:
if req.is_context_only_request and (
req.is_context_finished or req.is_finished_due_to_length
) and not req.is_finished_due_to_cancellation:
self.kv_cache_transceiver.respond_and_send_async(req)
for resource_mgr_type in (
ResourceManagerType.SEQ_SLOT_MANAGER,
ResourceManagerType.SPEC_RESOURCE_MANAGER):
if resource_mgr_type in self.resource_manager.resource_managers and self.resource_manager.resource_managers[
resource_mgr_type] is not None:
self.resource_manager.resource_managers[
resource_mgr_type].free_resources(req)
@nvtx_range("_send_kv_async")
def _send_kv_async(self, scheduled_requests: List[LlmRequest]):
self._check_disagg_ctx_cache_transfer_status(0)
def kv_connector_request_finished(req: LlmRequest):
try:
cache_block_ids = self.kv_cache_manager.get_cache_indices(req)
except Exception as e:
logger.warning(
f"Unable to get cache blocks for request {req.py_request_id}. Skipping asynchronous saving: {e}"
)
else:
if self.kv_connector_manager.request_finished(
req, cache_block_ids):
self.async_transfer_manager.start_transfer(req)
# Keep track of ctx requests that are in transmission
ctx_transmission_reqs = [
req for req in scheduled_ctx_requests
if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
]
if self.kv_cache_transceiver:
for req in scheduled_requests:
if req.is_context_only_request and (
req.is_context_finished or req.is_finished_due_to_length
) and not req.is_finished_due_to_cancellation:
self.kv_cache_transceiver.respond_and_send_async(req)
if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
for req in ctx_transmission_reqs:
req.py_kv_transfer_start_time = time.time()
self.async_transfer_manager.start_transfer(req)
return ctx_transmission_reqs
if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
req.py_kv_transfer_start_time = time.time()
if self.kv_connector_manager:
if not self.disable_overlap_scheduler:
requests = self.previous_batch.sample_state.scheduled_requests.all_requests(
) if self.previous_batch is not None else []
else:
requests = scheduled_requests
for req in requests:
if req.is_finished:
kv_connector_request_finished(req)
if self.kv_cache_transceiver:
self._check_disagg_ctx_cache_transfer_status(0)
def _get_disagg_reqs_in_error_state(self):
return [
@@ -2268,7 +2349,41 @@ class PyExecutor:
@nvtx_range("_check_disagg_ctx_cache_transfer_status")
def _check_disagg_ctx_cache_transfer_status(self, atLeastNum: int = 0):
self.kv_cache_transceiver.check_context_transfer_status(atLeastNum)
finished_requests, error_requests = self.kv_cache_transceiver.check_context_transfer_status(
atLeastNum)
completed_req_ids = set(finished_requests + error_requests)
requests_in_transfer = self.async_transfer_manager.requests_in_transfer(
)
for request_id in completed_req_ids:
if request_id not in requests_in_transfer:
logger.warning(
f"Request {request_id} not found in transfer manager")
continue
request = requests_in_transfer[request_id]
self._end_transfer_and_maybe_terminate(request)
# The set of requests in transfer may have changed since we terminated some requests.
requests_in_transfer = self.async_transfer_manager.requests_in_transfer(
)
for request_id in list(requests_in_transfer.keys()):
request = requests_in_transfer[request_id]
if request.py_kv_transfer_timed_out and request_id not in completed_req_ids:
is_cancelled = self.kv_cache_transceiver.cancel_request(request)
# If cancel is successful, mark as complete so it can be cleaned up
# Otherwise, try at next iteration
if is_cancelled:
request.py_kv_transfer_start_time = None
request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE
self._end_transfer_and_maybe_terminate(request)
self._check_cache_transfer_errors("context requests")
@nvtx_range("_check_disagg_gen_cache_transfer_status")
@@ -2474,24 +2589,6 @@ class PyExecutor:
self._do_terminate_request(request)
def _do_terminate_request(self, request: LlmRequest):
if self.kv_connector_manager is not None:
# Only call request_finished on the connector if the request has already been added to the kv cache manager.
try:
cache_block_ids = self.kv_cache_manager.get_cache_indices(
request)
except IndexError:
# If the request has not yet been added to the kv cache manager,
# we still need to free resources corresponding to other resource managers.
self.resource_manager.free_resources(request)
else:
if self.kv_connector_manager.request_finished(
request,
cache_block_ids) and not self.kv_cache_transceiver:
block_id = self.kv_cache_manager.store_blocks_for_reuse(
request, True)
self.ctx_in_transmission_requests[request.py_request_id] = (
(request, block_id, self.ctx_in_transmission_counter))
self.resource_manager.free_resources(request)
if self.gather_all_responses or self.dist.rank == 0:
@@ -2678,12 +2775,7 @@ class PyExecutor:
if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa:
requests_to_terminate.append(request)
else:
if request.is_disagg_context_transmission_state:
self.ctx_in_transmission_requests[
request.py_request_id] = (
(request, None,
self.ctx_in_transmission_counter))
else:
if not request.is_disagg_context_transmission_state:
requests_to_terminate.append(request)
else:
new_active_requests.append(request)
@@ -2696,35 +2788,6 @@ class PyExecutor:
self._terminate_request(request)
return requests_to_terminate
@nvtx_range("_terminate_disagg_ctx_finished_requests")
def _terminate_disagg_ctx_finished_requests(self):
# make a copy of the keys, since we are modifying the dictionary in the loop
in_transmission_requests_id = list(
self.ctx_in_transmission_requests.keys())
for request_id in in_transmission_requests_id:
request, block_id, counter = self.ctx_in_transmission_requests[
request_id]
if request.py_kv_transfer_timed_out:
is_cancelled = self.kv_cache_transceiver.cancel_request(request)
# If cancel is successful, mark as complete so it can be cleaned up
# Otherwise, try at next iteration
if is_cancelled:
request.py_kv_transfer_start_time = None
request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE
if request.is_disagg_context_complete_state:
del self.ctx_in_transmission_requests[request_id]
if not self.block_reuse_enabled or self.kv_cache_manager.is_vswa:
self._terminate_request(request)
elif counter == 1:
self.kv_cache_manager.unpin_blocks_by_id(block_id)
else:
self.ctx_in_transmission_requests[request_id] = ((request,
block_id,
counter -
1))
def _handle_logits_communication(self, previous_batch, prev_microbatch_id):
"""Handle logits communication between pipeline parallel ranks.

View File

@@ -21,6 +21,7 @@ l0_a10:
- unittest/_torch/modeling/test_modeling_mistral.py
- unittest/_torch/modeling/test_modeling_pixtral.py
- unittest/_torch/sampler/test_trtllm_sampler.py
- unittest/_torch/executor/test_async_transfer_manager.py
- unittest/_torch/executor/test_scheduler_serializable_output.py
# NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
# test list either).

View File

@@ -0,0 +1,182 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock
from tensorrt_llm._torch.pyexecutor.py_executor import AsyncTransferManager
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm.bindings import LlmRequestState
def create_mock_request(request_id: int):
"""Create a mock LlmRequest with the given request ID."""
request = MagicMock()
request.py_request_id = request_id
request.state = LlmRequestState.GENERATION_IN_PROGRESS
return request
def create_mock_resource_manager(
kv_cache_manager=None,
seq_slot_manager=None,
spec_resource_manager=None,
):
"""Create a mock ResourceManager with the specified resource managers."""
resource_manager = MagicMock()
resource_manager.resource_managers = {}
if kv_cache_manager is not None:
resource_manager.resource_managers[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager
if seq_slot_manager is not None:
resource_manager.resource_managers[ResourceManagerType.SEQ_SLOT_MANAGER] = seq_slot_manager
if spec_resource_manager is not None:
resource_manager.resource_managers[ResourceManagerType.SPEC_RESOURCE_MANAGER] = (
spec_resource_manager
)
return resource_manager
def test_start_transfer_single_request():
"""Test starting a single transfer."""
kv_cache_manager = MagicMock()
kv_cache_manager.store_blocks_for_reuse.return_value = 100
seq_slot_manager = MagicMock()
resource_manager = create_mock_resource_manager(
kv_cache_manager=kv_cache_manager, seq_slot_manager=seq_slot_manager
)
manager = AsyncTransferManager(resource_manager)
request = create_mock_request(42)
manager.start_transfer(request)
# Check request is tracked
assert 42 in manager._requests_in_transfer
transfer_metadata = manager._request_transfer_metadata[42]
assert transfer_metadata.block_id == 100
assert transfer_metadata.counter == 1
# Check state was updated
assert request.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
# Check KV cache manager was called
kv_cache_manager.store_blocks_for_reuse.assert_called_once_with(request, True)
# Check seq slot manager was called to free resources
seq_slot_manager.free_resources.assert_called_once_with(request)
manager.end_transfer(request)
kv_cache_manager.unpin_blocks_by_id.assert_called_once()
def test_start_transfer_multiple_transfers_same_request():
"""Test starting multiple transfers for the same request."""
kv_cache_manager = MagicMock()
kv_cache_manager.store_blocks_for_reuse.return_value = 100
resource_manager = create_mock_resource_manager(kv_cache_manager=kv_cache_manager)
manager = AsyncTransferManager(resource_manager)
request = create_mock_request(42)
manager.start_transfer(request)
manager.start_transfer(request)
manager.start_transfer(request)
# Counter should be incremented
transfer_metadata = manager._request_transfer_metadata[42]
assert transfer_metadata.counter == 3
# store_blocks_for_reuse should only be called once
kv_cache_manager.store_blocks_for_reuse.assert_called_once()
for _ in range(2):
manager.end_transfer(request)
kv_cache_manager.unpin_blocks_by_id.assert_not_called()
manager.end_transfer(request)
kv_cache_manager.unpin_blocks_by_id.assert_called_once()
def test_transfer_without_storing_blocks():
"""Test starting a transfer with should_store_blocks=False."""
kv_cache_manager = MagicMock()
kv_cache_manager.store_blocks_for_reuse.return_value = 0
spec_resource_manager = MagicMock()
resource_manager = create_mock_resource_manager(
kv_cache_manager=kv_cache_manager, spec_resource_manager=spec_resource_manager
)
manager = AsyncTransferManager(resource_manager, should_store_blocks=False)
request = create_mock_request(42)
manager.start_transfer(request)
# Check request is tracked
assert 42 in manager._requests_in_transfer
transfer_metadata = manager._request_transfer_metadata[42]
assert transfer_metadata.block_id is None # No block stored
assert transfer_metadata.counter == 1
# Check KV cache manager was NOT called
kv_cache_manager.store_blocks_for_reuse.assert_not_called()
spec_resource_manager.free_resources.assert_called_once_with(request)
assert manager.end_transfer(request)
kv_cache_manager.unpin_blocks_by_id.assert_not_called()
def test_end_transfer_preserves_error_state():
"""Test that end_transfer does not overwrite error state."""
kv_cache_manager = MagicMock()
kv_cache_manager.store_blocks_for_reuse.return_value = 100
resource_manager = create_mock_resource_manager(kv_cache_manager=kv_cache_manager)
manager = AsyncTransferManager(resource_manager)
request = create_mock_request(42)
manager.start_transfer(request)
# Set error state before end_transfer
request.state = LlmRequestState.DISAGG_TRANS_ERROR
manager.end_transfer(request)
# Error state should be preserved
assert request.state == LlmRequestState.DISAGG_TRANS_ERROR
def test_requests_in_transfer():
"""Test that requests_in_transfer returns correct mapping."""
kv_cache_manager = MagicMock()
kv_cache_manager.store_blocks_for_reuse.return_value = 100
resource_manager = create_mock_resource_manager(kv_cache_manager=kv_cache_manager)
manager = AsyncTransferManager(resource_manager)
request1 = create_mock_request(1)
request2 = create_mock_request(2)
request3 = create_mock_request(3)
manager.start_transfer(request1)
manager.start_transfer(request2)
manager.start_transfer(request3)
in_transfer = manager.requests_in_transfer()
assert len(in_transfer) == 3
assert in_transfer[1] is request1
assert in_transfer[2] is request2
assert in_transfer[3] is request3