From 2db3d7eeba2259f67af283410a10de0449ee970f Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 20 Jan 2026 09:12:47 -0800 Subject: [PATCH] [None][chore] Async Transfer Manager (#9891) Signed-off-by: jthomson04 --- .../batch_manager/cacheTransceiver.h | 16 +- .../batch_manager/cacheTransceiver.cpp | 15 +- .../trtGptModelInflightBatching.cpp | 6 +- .../batch_manager/cacheTransceiver.cpp | 24 +- .../pybind/batch_manager/cacheTransceiver.cpp | 27 +- tensorrt_llm/_torch/pyexecutor/py_executor.py | 361 ++++++++++-------- .../integration/test_lists/test-db/l0_a10.yml | 1 + .../executor/test_async_transfer_manager.py | 182 +++++++++ 8 files changed, 468 insertions(+), 164 deletions(-) create mode 100644 tests/unittest/_torch/executor/test_async_transfer_manager.py diff --git a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h index de68e9805e..4da26f72d5 100644 --- a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h +++ b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h @@ -190,6 +190,14 @@ public: std::optional cacheTransceiverConfig = std::nullopt); }; +struct RequestStatuses +{ + /// Requests that have completed their transfer successfully. + std::unordered_set completedRequestIds; + /// Requests that have encountered an error during their transfer. + std::unordered_set errorRequestIds; +}; + class BaseCacheTransceiver { public: @@ -202,7 +210,10 @@ public: virtual void requestAndReceiveSync(LlmRequest* llmRequest) = 0; virtual void requestAndReceiveAsync(LlmRequest* llmRequest) = 0; - virtual void checkContextTransferStatus(std::optional const& atLeastRequestNum = std::nullopt) = 0; + /// Check all requests transferring context, and return the requests that have completed or encountered an error. 
+ virtual RequestStatuses checkContextTransferStatus( + std::optional const& atLeastRequestNum = std::nullopt, bool markComplete = false) + = 0; virtual void checkGenTransferStatus(std::optional const& atLeastRequestNum = std::nullopt) = 0; @@ -243,7 +254,8 @@ public: void requestAndReceiveSync(LlmRequest* llmRequest) override; void requestAndReceiveAsync(LlmRequest* llmRequest) override; - void checkContextTransferStatus(std::optional const& atLeastRequestNum = std::nullopt) override; + RequestStatuses checkContextTransferStatus( + std::optional const& atLeastRequestNum = std::nullopt, bool markComplete = false) override; void checkGenTransferStatus(std::optional const& atLeastRequestNum = std::nullopt) override; diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp index 7e4c26bfd7..2170370d55 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp @@ -427,7 +427,8 @@ void updateKVCacheTransferBW(std::shared_ptr const& mComm, } } -void CacheTransceiver::checkContextTransferStatus(std::optional const& atLeastRequestNum) +RequestStatuses CacheTransceiver::checkContextTransferStatus( + std::optional const& atLeastRequestNum, bool markComplete) { bool blockAll = !atLeastRequestNum.has_value(); std::optional senderFutureTimeoutMs = std::nullopt; @@ -486,6 +487,8 @@ void CacheTransceiver::checkContextTransferStatus(std::optional const& atLe toCompleteIdSet.insert(request->mRequestId); } + RequestStatuses requestsStatus{}; + // Complete all the requests in toCompleteIdSet for (auto it = mSenderFutures.begin(); it != mSenderFutures.end();) { @@ -499,7 +502,11 @@ void CacheTransceiver::checkContextTransferStatus(std::optional const& atLe if (status == std::future_status::ready || !senderFutureTimeoutMs.has_value()) { future.get(); - request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE); + requestsStatus.completedRequestIds.insert(request->mRequestId); + if (markComplete) + { + request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE); + } it = mSenderFutures.erase(it); } else if (status == std::future_status::timeout) @@ -514,6 +521,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional const& atLe "Future returned unexpected status for request %ld. 
Marking as error", request->mRequestId); request->setState(LlmRequestState::kDISAGG_TRANS_ERROR); + requestsStatus.errorRequestIds.insert(request->mRequestId); it = mSenderFutures.erase(it); } } @@ -522,6 +530,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional const& atLe TLLM_LOG_ERROR( "Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what()); request->setState(LlmRequestState::kDISAGG_TRANS_ERROR); + requestsStatus.errorRequestIds.insert(request->mRequestId); it = mSenderFutures.erase(it); } } @@ -530,6 +539,8 @@ void CacheTransceiver::checkContextTransferStatus(std::optional const& atLe ++it; } } + + return requestsStatus; } void CacheTransceiver::checkGenTransferStatus(std::optional const& atLeastRequestNum) diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index 9210fe9587..e93a908aa8 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -503,7 +503,7 @@ TrtGptModelInflightBatching::~TrtGptModelInflightBatching() { if (mCacheTransceiver) { - mCacheTransceiver->checkContextTransferStatus(true); + mCacheTransceiver->checkContextTransferStatus(1, true); TLLM_CHECK_WITH_INFO(mCacheTransceiver->checkGenTransferComplete(), "Generation transfer not complete"); } if (mAsyncSendWaitThread) @@ -932,7 +932,7 @@ void TrtGptModelInflightBatching::forwardSync() } if (mCacheTransceiver) { - mCacheTransceiver->checkContextTransferStatus(0); + mCacheTransceiver->checkContextTransferStatus(0, true); } ++mIterCounter; @@ -1025,7 +1025,7 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests mIterCounter); if (mCacheTransceiver) { - mCacheTransceiver->checkContextTransferStatus(1); + mCacheTransceiver->checkContextTransferStatus(1, true); // will free kvCache in next iteration. 
} } diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp index c018a1e0d1..dd3452f0f6 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp @@ -60,9 +60,10 @@ public: NB_OVERRIDE_PURE(requestAndReceiveAsync, llmRequest); } - void checkContextTransferStatus(std::optional const& atLeastRequestNum = std::nullopt) override + tb::RequestStatuses checkContextTransferStatus( + std::optional const& atLeastRequestNum = std::nullopt, bool markComplete = false) override { - NB_OVERRIDE_PURE(checkContextTransferStatus, atLeastRequestNum); + NB_OVERRIDE_PURE(checkContextTransferStatus, atLeastRequestNum, markComplete); } void checkGenTransferStatus(std::optional const& atLeastRequestNum = std::nullopt) override @@ -88,8 +89,23 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m) .def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync) .def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync) .def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync) - .def("check_context_transfer_status", &BaseCacheTransceiver::checkContextTransferStatus, - nb::call_guard()) + .def( + "check_context_transfer_status", + [](tb::BaseCacheTransceiver& self, std::optional const& atLeastRequestNum, bool markComplete = false) + { + RequestStatuses result; + { + nb::gil_scoped_release release; + result = self.checkContextTransferStatus(atLeastRequestNum, markComplete); + } + + auto completedRequestIds + = std::vector(result.completedRequestIds.begin(), result.completedRequestIds.end()); + auto errorRequestIds + = std::vector(result.errorRequestIds.begin(), result.errorRequestIds.end()); + return nb::make_tuple(completedRequestIds, errorRequestIds); + }, + nb::arg("at_least_request_num") = std::nullopt, nb::arg("mark_complete") = false) .def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus, nb::call_guard()) .def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp index 30bf411c9b..7ab2ba0241 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp @@ -56,9 +56,13 @@ public: PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, requestAndReceiveAsync, llmRequest); } - void checkContextTransferStatus(std::optional const& atLeastRequestNum = std::nullopt) override + using RequestStatuses = tb::RequestStatuses; + + RequestStatuses checkContextTransferStatus( + std::optional const& atLeastRequestNum = std::nullopt, bool markComplete = false) override { - PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, checkContextTransferStatus, atLeastRequestNum); + PYBIND11_OVERLOAD_PURE( + RequestStatuses, tb::BaseCacheTransceiver, checkContextTransferStatus, atLeastRequestNum, markComplete); } void checkGenTransferStatus(std::optional const& atLeastRequestNum = std::nullopt) override @@ -84,8 +88,23 @@ void tb::CacheTransceiverBindings::initBindings(py::module_& m) .def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync) .def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync) .def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync) - 
.def("check_context_transfer_status", &BaseCacheTransceiver::checkContextTransferStatus, - py::call_guard()) + .def( + "check_context_transfer_status", + [](tb::BaseCacheTransceiver& self, std::optional const& atLeastRequestNum, bool markComplete = false) + { + RequestStatuses result; + { + py::gil_scoped_release release; + result = self.checkContextTransferStatus(atLeastRequestNum, markComplete); + } + + auto completedRequestIds + = std::vector(result.completedRequestIds.begin(), result.completedRequestIds.end()); + auto errorRequestIds + = std::vector(result.errorRequestIds.begin(), result.errorRequestIds.end()); + return py::make_tuple(completedRequestIds, errorRequestIds); + }, + py::arg("at_least_request_num") = std::nullopt, py::arg("mark_complete") = false) .def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus, py::call_guard()) .def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index ac1c15147c..aef8d34941 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -114,6 +114,125 @@ class BatchStatePP(BatchState): finished_ctx_reqs: list[LlmRequest] = None +class AsyncTransferManager: + """ + Handle asynchronous transfer or KV cache after a request has completed. + When running with both the KV cache transceiver and the KV cache connector, we must ensure that BOTH transfers (if any) are completed before we can release the KV cache blocks. + The AsyncTransferManager has a few key responsibilities: + 1. Track requests in transfer. + 2. Pin blocks for reuse while blocks are in transfer. + 3. Unpin blocks after all transfers are complete. + + TODO(jthomson04): This only handles async send/saving, not loading. Loading kv cache is handled through a separate codepath. Eventually, we'll want to merge these two paths. + """ + + class RequestTransferMetadata: + + def __init__(self, block_id: Optional[int]): + self.block_id = block_id + self.counter = 0 + + def start_transfer(self): + self.counter += 1 + + def end_transfer(self) -> bool: + """ + Returns: + bool: True if there are no more transfers for this request + """ + self.counter -= 1 + return self.counter == 0 + + def __init__(self, + resource_manager: "ResourceManager", + should_store_blocks: bool = True): + self.resource_manager = resource_manager + self.kv_cache_manager = resource_manager.resource_managers.get( + ResourceManagerType.KV_CACHE_MANAGER) + + self.should_store_blocks = should_store_blocks + + # Mapping of request id to the LlmRequest + self._requests_in_transfer: Dict[int, LlmRequest] = dict() + + # Mapping of request id to the the request metadata + self._request_transfer_metadata: Dict[ + int, self.RequestTransferMetadata] = dict() + + def requests_in_transfer(self) -> Dict[int, LlmRequest]: + return self._requests_in_transfer + + def start_transfer(self, request: LlmRequest): + """ + Called when a Cache transceiver or connector transfer is started. + 1. Increment the counter for the request. + 2. Releases all resources except for the KV cache, if not already released. + 3. Store KV cache blocks for reuse. 
+ """ + + req_id = request.py_request_id + + if req_id not in self._requests_in_transfer: + for resource_mgr_type in ( + ResourceManagerType.SEQ_SLOT_MANAGER, + ResourceManagerType.SPEC_RESOURCE_MANAGER): + if resource_mgr_type in self.resource_manager.resource_managers and self.resource_manager.resource_managers[ + resource_mgr_type] is not None: + self.resource_manager.resource_managers[ + resource_mgr_type].free_resources(request) + + request.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS + + if self.should_store_blocks: + block_id = self.kv_cache_manager.store_blocks_for_reuse( + request, True) + else: + block_id = None + + self._requests_in_transfer[req_id] = request + self._request_transfer_metadata[ + req_id] = self.RequestTransferMetadata(block_id) + + self._request_transfer_metadata[req_id].start_transfer() + + def end_transfer(self, request: LlmRequest) -> bool: + """ + Called after a send of KV cache is complete. + 1. Decrements counter for request. + 2. If there are no more inflight transfers for this request, unpin the blocks and mark the request as complete. + + Returns: + bool: True if the request should be terminated after call to end_transfer + """ + try: + transfer_metadata = self._request_transfer_metadata[ + request.py_request_id] + except KeyError: + logger.warning( + f"Request {request.py_request_id} not found in transfer manager" + ) + return + + if transfer_metadata.end_transfer(): + self._requests_in_transfer.pop(request.py_request_id) + self._request_transfer_metadata.pop(request.py_request_id) + + if self.should_store_blocks: + self.kv_cache_manager.unpin_blocks_by_id( + transfer_metadata.block_id) + + # We don't want to overwrite any error state. + if request.state != LlmRequestState.DISAGG_TRANS_ERROR: + request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE + + return True + + return False + + def has_any_inflight_requests(self) -> bool: + return len(self._requests_in_transfer) > 0 + + class PyExecutor: def __init__(self, @@ -233,10 +352,10 @@ class PyExecutor: self.max_num_active_requests = model_engine.get_max_num_sequences() self.active_requests: List[LlmRequest] = [] self.expected_num_active_requests = 0 - self.ctx_in_transmission_requests = dict() - self.ctx_in_transmission_counter = (1 if kv_cache_transceiver else - 0) + (1 if kv_connector_manager else - 0) + self.async_transfer_manager = AsyncTransferManager( + self.resource_manager, + should_store_blocks=self.block_reuse_enabled + and not self.kv_cache_manager.is_vswa) self.previous_batch: Optional[BatchState] = None self.has_previous_draft_tokens = False self.num_scheduled_requests: int = 0 @@ -373,6 +492,10 @@ class PyExecutor: module.register_forward_hook( self.kv_connector_manager.layer_post_hook) + def _end_transfer_and_maybe_terminate(self, request: LlmRequest): + if self.async_transfer_manager.end_transfer(request): + self._terminate_request(request) + def _event_loop_wrapper(self): try: with customized_gc_thresholds( @@ -951,7 +1074,7 @@ class PyExecutor: raise RuntimeError( "KV cache transceiver is not enabled, but current rank cannot run first PP's schedule result due to limited KV cache resources. This is not expected." ) - if not self.ctx_in_transmission_requests: + if not self.async_transfer_manager.has_any_inflight_requests(): raise RuntimeError( "No context cache transmission is in progress, but current rank cannot run first PP's schedule result due to limited KV cache resources. This is not expected." 
) @@ -970,7 +1093,6 @@ class PyExecutor: # Let cache transceiver finish at least one cache transmission and release requests' KV cache resources self._check_disagg_ctx_cache_transfer_status(1) self._check_kv_transfer_timeout() - self._terminate_disagg_ctx_finished_requests() else: raise RuntimeError( f"Reach maximum PP retry count ({self.pp_scheduler_max_retry_count}) but still cannot run first PP's schedule result. Please consider increasing the KV cache size by setting `free_gpu_memory_fraction` to a larger value. Or you can set `TLLM_PP_SCHEDULER_MAX_RETRY_COUNT` to a larger value to allow more retries." @@ -1186,21 +1308,8 @@ class PyExecutor: sample_state.scheduled_requests.context_requests = previous_batch.finished_ctx_reqs self._update_requests(previous_batch.sample_state) - if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: - for req in previous_batch.scheduled_ctx_reqs: - if req.is_context_only_request and ( - req.is_context_finished - or req.is_finished_due_to_length - ) and not req.is_finished_due_to_cancellation: - block_id = self.kv_cache_manager.store_blocks_for_reuse( - req, True) - self.ctx_in_transmission_requests[ - req.py_request_id] = ( - (req, block_id, - self.ctx_in_transmission_counter)) - if self.kv_cache_transceiver: - self._send_disagg_ctx_cache( + self._send_kv_async( previous_batch.scheduled_ctx_reqs) self._handle_canceled_requests() @@ -1222,9 +1331,9 @@ class PyExecutor: self.wait_on_pp_send_handles(prev_microbatch_id) self.micro_batches[prev_microbatch_id] = None - if self.kv_cache_transceiver and self.ctx_in_transmission_requests: + if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests( + ): self._check_kv_transfer_timeout() - self._terminate_disagg_ctx_finished_requests() if self._disagg_pp_termination_handler is not None: self._disagg_pp_termination_handler.terminate_pending_requests( @@ -1354,14 +1463,7 @@ class PyExecutor: if self.kv_connector_manager: reqs_to_terminate = self.kv_connector_manager.get_finished() for req in reqs_to_terminate: - if req.py_request_id in self.ctx_in_transmission_requests: - request, block_id, counter = self.ctx_in_transmission_requests.pop( - req.py_request_id) - if counter == 1: - self.kv_cache_manager.unpin_blocks_by_id(block_id) - else: - self.ctx_in_transmission_requests[req.py_request_id] = ( - request, block_id, counter - 1) + self._end_transfer_and_maybe_terminate(req) def _kv_connector_wait_for_save(self): if self.kv_connector_manager is not None: @@ -1457,25 +1559,9 @@ class PyExecutor: self._update_request_states(scheduled_batch) self._update_requests(sample_state, self.resource_manager) - if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: - for req in scheduled_batch.context_requests: - if req.is_context_only_request and ( - req.is_context_finished - or req.is_finished_due_to_length - ) and not req.is_finished_due_to_cancellation: - block_id = self.kv_cache_manager.store_blocks_for_reuse( - req, True) - self.ctx_in_transmission_requests[ - req.py_request_id] = ( - (req, block_id, - self.ctx_in_transmission_counter)) - if self.kv_cache_transceiver: - ctx_transmission_reqs = self._send_disagg_ctx_cache( - scheduled_batch.context_requests) - # For context only req in transmission, we reset the state since sampler might have changed it - for req in ctx_transmission_reqs: - req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS + self._send_kv_async(scheduled_batch.context_requests + + 
scheduled_batch.generation_requests) self._handle_canceled_requests() finished_requests = self._handle_responses() @@ -1489,9 +1575,9 @@ class PyExecutor: if self.enable_kv_cache_events: self._add_kv_cache_events() - if self.kv_cache_transceiver and self.ctx_in_transmission_requests: + if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests( + ): self._check_kv_transfer_timeout() - self._terminate_disagg_ctx_finished_requests() self._kv_connector_terminate_requests() @@ -1709,19 +1795,6 @@ class PyExecutor: if self.previous_batch is not None: self._update_requests(self.previous_batch.sample_state) - if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: - for req in self.previous_batch.sample_state.scheduled_requests.context_requests: - if req.is_context_only_request and ( - req.is_context_finished - or req.is_finished_due_to_length - ) and not req.is_finished_due_to_cancellation: - block_id = self.kv_cache_manager.store_blocks_for_reuse( - req, True) - self.ctx_in_transmission_requests[ - req.py_request_id] = ( - (req, block_id, - self.ctx_in_transmission_counter)) - if self.drafter is not None and self.use_spec_decode: # Cleanup previous draft resources used in the draft model self.drafter.cleanup_previous_draft_resources() @@ -1746,9 +1819,8 @@ class PyExecutor: self._update_request_states(scheduled_batch) - ctx_transmission_reqs = self._send_disagg_ctx_cache( - scheduled_batch.context_requests - ) if self.kv_cache_transceiver else [] + ctx_transmission_reqs = self._send_kv_async( + scheduled_batch.all_requests()) if self.previous_batch is not None: self._process_previous_batch() @@ -1764,9 +1836,9 @@ class PyExecutor: iter_stats=iter_stats, ctx_transmission_reqs=ctx_transmission_reqs) - if self.kv_cache_transceiver and self.ctx_in_transmission_requests: + if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests( + ): self._check_kv_transfer_timeout() - self._terminate_disagg_ctx_finished_requests() self._kv_connector_terminate_requests() @@ -2101,7 +2173,7 @@ class PyExecutor: ) req.py_kv_transfer_timed_out = True - for req, _, _ in self.ctx_in_transmission_requests.values(): + for req in self.async_transfer_manager.requests_in_transfer().values(): flag_if_kv_transfer_timed_out(req, "context") for req in self.active_requests: @@ -2221,36 +2293,45 @@ class PyExecutor: return - @nvtx_range("_send_disagg_ctx_cache") - def _send_disagg_ctx_cache(self, scheduled_ctx_requests): - if (scheduled_ctx_requests is None or len(scheduled_ctx_requests) == 0): - return [] - for req in scheduled_ctx_requests: - if req.is_context_only_request and ( - req.is_context_finished or req.is_finished_due_to_length - ) and not req.is_finished_due_to_cancellation: - self.kv_cache_transceiver.respond_and_send_async(req) - for resource_mgr_type in ( - ResourceManagerType.SEQ_SLOT_MANAGER, - ResourceManagerType.SPEC_RESOURCE_MANAGER): - if resource_mgr_type in self.resource_manager.resource_managers and self.resource_manager.resource_managers[ - resource_mgr_type] is not None: - self.resource_manager.resource_managers[ - resource_mgr_type].free_resources(req) + @nvtx_range("_send_kv_async") + def _send_kv_async(self, scheduled_requests: List[LlmRequest]): - self._check_disagg_ctx_cache_transfer_status(0) + def kv_connector_request_finished(req: LlmRequest): + try: + cache_block_ids = self.kv_cache_manager.get_cache_indices(req) + except Exception as e: + logger.warning( + f"Unable to get cache blocks for request 
{req.py_request_id}. Skipping asynchronous saving: {e}" + ) + else: + if self.kv_connector_manager.request_finished( + req, cache_block_ids): + self.async_transfer_manager.start_transfer(req) - # Keep track of ctx requests that are in transmission - ctx_transmission_reqs = [ - req for req in scheduled_ctx_requests - if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS - ] + if self.kv_cache_transceiver: + for req in scheduled_requests: + if req.is_context_only_request and ( + req.is_context_finished or req.is_finished_due_to_length + ) and not req.is_finished_due_to_cancellation: + self.kv_cache_transceiver.respond_and_send_async(req) - if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None: - for req in ctx_transmission_reqs: - req.py_kv_transfer_start_time = time.time() + self.async_transfer_manager.start_transfer(req) - return ctx_transmission_reqs + if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None: + req.py_kv_transfer_start_time = time.time() + + if self.kv_connector_manager: + if not self.disable_overlap_scheduler: + requests = self.previous_batch.sample_state.scheduled_requests.all_requests( + ) if self.previous_batch is not None else [] + else: + requests = scheduled_requests + for req in requests: + if req.is_finished: + kv_connector_request_finished(req) + + if self.kv_cache_transceiver: + self._check_disagg_ctx_cache_transfer_status(0) def _get_disagg_reqs_in_error_state(self): return [ @@ -2268,7 +2349,41 @@ class PyExecutor: @nvtx_range("_check_disagg_ctx_cache_transfer_status") def _check_disagg_ctx_cache_transfer_status(self, atLeastNum: int = 0): - self.kv_cache_transceiver.check_context_transfer_status(atLeastNum) + finished_requests, error_requests = self.kv_cache_transceiver.check_context_transfer_status( + atLeastNum) + + completed_req_ids = set(finished_requests + error_requests) + + requests_in_transfer = self.async_transfer_manager.requests_in_transfer( + ) + + for request_id in completed_req_ids: + + if request_id not in requests_in_transfer: + logger.warning( + f"Request {request_id} not found in transfer manager") + continue + + request = requests_in_transfer[request_id] + + self._end_transfer_and_maybe_terminate(request) + + # The set of requests in transfer may have changed since we terminated some requests. + requests_in_transfer = self.async_transfer_manager.requests_in_transfer( + ) + + for request_id in list(requests_in_transfer.keys()): + request = requests_in_transfer[request_id] + if request.py_kv_transfer_timed_out and request_id not in completed_req_ids: + is_cancelled = self.kv_cache_transceiver.cancel_request(request) + # If cancel is successful, mark as complete so it can be cleaned up + # Otherwise, try at next iteration + if is_cancelled: + request.py_kv_transfer_start_time = None + request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE + + self._end_transfer_and_maybe_terminate(request) + self._check_cache_transfer_errors("context requests") @nvtx_range("_check_disagg_gen_cache_transfer_status") @@ -2474,24 +2589,6 @@ class PyExecutor: self._do_terminate_request(request) def _do_terminate_request(self, request: LlmRequest): - if self.kv_connector_manager is not None: - # Only call request_finished on the connector if the request has already been added to the kv cache manager. 
- try: - cache_block_ids = self.kv_cache_manager.get_cache_indices( - request) - except IndexError: - # If the request has not yet been added to the kv cache manager, - # we still need to free resources corresponding to other resource managers. - self.resource_manager.free_resources(request) - else: - if self.kv_connector_manager.request_finished( - request, - cache_block_ids) and not self.kv_cache_transceiver: - block_id = self.kv_cache_manager.store_blocks_for_reuse( - request, True) - self.ctx_in_transmission_requests[request.py_request_id] = ( - (request, block_id, self.ctx_in_transmission_counter)) - self.resource_manager.free_resources(request) if self.gather_all_responses or self.dist.rank == 0: @@ -2678,12 +2775,7 @@ class PyExecutor: if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa: requests_to_terminate.append(request) else: - if request.is_disagg_context_transmission_state: - self.ctx_in_transmission_requests[ - request.py_request_id] = ( - (request, None, - self.ctx_in_transmission_counter)) - else: + if not request.is_disagg_context_transmission_state: requests_to_terminate.append(request) else: new_active_requests.append(request) @@ -2696,35 +2788,6 @@ class PyExecutor: self._terminate_request(request) return requests_to_terminate - @nvtx_range("_terminate_disagg_ctx_finished_requests") - def _terminate_disagg_ctx_finished_requests(self): - # make a copy of the keys, since we are modifying the dictionary in the loop - in_transmission_requests_id = list( - self.ctx_in_transmission_requests.keys()) - for request_id in in_transmission_requests_id: - request, block_id, counter = self.ctx_in_transmission_requests[ - request_id] - - if request.py_kv_transfer_timed_out: - is_cancelled = self.kv_cache_transceiver.cancel_request(request) - # If cancel is successful, mark as complete so it can be cleaned up - # Otherwise, try at next iteration - if is_cancelled: - request.py_kv_transfer_start_time = None - request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE - - if request.is_disagg_context_complete_state: - del self.ctx_in_transmission_requests[request_id] - if not self.block_reuse_enabled or self.kv_cache_manager.is_vswa: - self._terminate_request(request) - elif counter == 1: - self.kv_cache_manager.unpin_blocks_by_id(block_id) - else: - self.ctx_in_transmission_requests[request_id] = ((request, - block_id, - counter - - 1)) - def _handle_logits_communication(self, previous_batch, prev_microbatch_id): """Handle logits communication between pipeline parallel ranks. diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index df380e1b04..6521d65766 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -21,6 +21,7 @@ l0_a10: - unittest/_torch/modeling/test_modeling_mistral.py - unittest/_torch/modeling/test_modeling_pixtral.py - unittest/_torch/sampler/test_trtllm_sampler.py + - unittest/_torch/executor/test_async_transfer_manager.py - unittest/_torch/executor/test_scheduler_serializable_output.py # NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no # test list either). 
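The key contract of AsyncTransferManager (added to py_executor.py above) is reference counting: when both the KV cache transceiver and the KV cache connector register a send for the same request, start_transfer pins the reusable blocks once and bumps a per-request counter, and end_transfer unpins them only when the last in-flight transfer finishes. The executor drives it from the (completed_request_ids, error_request_ids) tuple now returned by check_context_transfer_status. Below is a minimal sketch of that counting behavior using MagicMock stand-ins for the resource managers; the mock setup is illustrative rather than taken from the diff, and the unit tests added in the new file below exercise the same paths in more detail.

from unittest.mock import MagicMock

from tensorrt_llm._torch.pyexecutor.py_executor import AsyncTransferManager
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm.bindings import LlmRequestState

kv_cache_manager = MagicMock()
kv_cache_manager.store_blocks_for_reuse.return_value = 0  # stand-in block id

resource_manager = MagicMock()
resource_manager.resource_managers = {
    ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager
}

manager = AsyncTransferManager(resource_manager)

request = MagicMock()
request.py_request_id = 1

# A transceiver send and a connector save both register the same request:
# the blocks are pinned exactly once and the per-request counter goes to 2.
manager.start_transfer(request)
manager.start_transfer(request)
kv_cache_manager.store_blocks_for_reuse.assert_called_once_with(request, True)

# The first completion only decrements the counter; blocks stay pinned.
assert not manager.end_transfer(request)
kv_cache_manager.unpin_blocks_by_id.assert_not_called()

# The last completion unpins the blocks and marks the request complete,
# so _end_transfer_and_maybe_terminate can terminate it.
assert manager.end_transfer(request)
kv_cache_manager.unpin_blocks_by_id.assert_called_once_with(0)
assert request.state == LlmRequestState.DISAGG_CONTEXT_COMPLETE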
diff --git a/tests/unittest/_torch/executor/test_async_transfer_manager.py b/tests/unittest/_torch/executor/test_async_transfer_manager.py new file mode 100644 index 0000000000..1f2f9013d9 --- /dev/null +++ b/tests/unittest/_torch/executor/test_async_transfer_manager.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock + +from tensorrt_llm._torch.pyexecutor.py_executor import AsyncTransferManager +from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType +from tensorrt_llm.bindings import LlmRequestState + + +def create_mock_request(request_id: int): + """Create a mock LlmRequest with the given request ID.""" + request = MagicMock() + request.py_request_id = request_id + request.state = LlmRequestState.GENERATION_IN_PROGRESS + return request + + +def create_mock_resource_manager( + kv_cache_manager=None, + seq_slot_manager=None, + spec_resource_manager=None, +): + """Create a mock ResourceManager with the specified resource managers.""" + resource_manager = MagicMock() + resource_manager.resource_managers = {} + + if kv_cache_manager is not None: + resource_manager.resource_managers[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager + + if seq_slot_manager is not None: + resource_manager.resource_managers[ResourceManagerType.SEQ_SLOT_MANAGER] = seq_slot_manager + + if spec_resource_manager is not None: + resource_manager.resource_managers[ResourceManagerType.SPEC_RESOURCE_MANAGER] = ( + spec_resource_manager + ) + + return resource_manager + + +def test_start_transfer_single_request(): + """Test starting a single transfer.""" + kv_cache_manager = MagicMock() + kv_cache_manager.store_blocks_for_reuse.return_value = 100 + seq_slot_manager = MagicMock() + resource_manager = create_mock_resource_manager( + kv_cache_manager=kv_cache_manager, seq_slot_manager=seq_slot_manager + ) + manager = AsyncTransferManager(resource_manager) + + request = create_mock_request(42) + manager.start_transfer(request) + + # Check request is tracked + assert 42 in manager._requests_in_transfer + + transfer_metadata = manager._request_transfer_metadata[42] + + assert transfer_metadata.block_id == 100 + assert transfer_metadata.counter == 1 + + # Check state was updated + assert request.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS + + # Check KV cache manager was called + kv_cache_manager.store_blocks_for_reuse.assert_called_once_with(request, True) + + # Check seq slot manager was called to free resources + seq_slot_manager.free_resources.assert_called_once_with(request) + + manager.end_transfer(request) + kv_cache_manager.unpin_blocks_by_id.assert_called_once() + + +def test_start_transfer_multiple_transfers_same_request(): + """Test starting multiple transfers for the same request.""" + kv_cache_manager = MagicMock() + kv_cache_manager.store_blocks_for_reuse.return_value = 100 + 
resource_manager = create_mock_resource_manager(kv_cache_manager=kv_cache_manager) + manager = AsyncTransferManager(resource_manager) + + request = create_mock_request(42) + manager.start_transfer(request) + manager.start_transfer(request) + manager.start_transfer(request) + + # Counter should be incremented + transfer_metadata = manager._request_transfer_metadata[42] + assert transfer_metadata.counter == 3 + + # store_blocks_for_reuse should only be called once + kv_cache_manager.store_blocks_for_reuse.assert_called_once() + + for _ in range(2): + manager.end_transfer(request) + kv_cache_manager.unpin_blocks_by_id.assert_not_called() + + manager.end_transfer(request) + kv_cache_manager.unpin_blocks_by_id.assert_called_once() + + +def test_transfer_without_storing_blocks(): + """Test starting a transfer with should_store_blocks=False.""" + kv_cache_manager = MagicMock() + kv_cache_manager.store_blocks_for_reuse.return_value = 0 + spec_resource_manager = MagicMock() + resource_manager = create_mock_resource_manager( + kv_cache_manager=kv_cache_manager, spec_resource_manager=spec_resource_manager + ) + manager = AsyncTransferManager(resource_manager, should_store_blocks=False) + + request = create_mock_request(42) + manager.start_transfer(request) + + # Check request is tracked + assert 42 in manager._requests_in_transfer + transfer_metadata = manager._request_transfer_metadata[42] + assert transfer_metadata.block_id is None # No block stored + assert transfer_metadata.counter == 1 + + # Check KV cache manager was NOT called + kv_cache_manager.store_blocks_for_reuse.assert_not_called() + spec_resource_manager.free_resources.assert_called_once_with(request) + + assert manager.end_transfer(request) + + kv_cache_manager.unpin_blocks_by_id.assert_not_called() + + +def test_end_transfer_preserves_error_state(): + """Test that end_transfer does not overwrite error state.""" + kv_cache_manager = MagicMock() + kv_cache_manager.store_blocks_for_reuse.return_value = 100 + resource_manager = create_mock_resource_manager(kv_cache_manager=kv_cache_manager) + manager = AsyncTransferManager(resource_manager) + + request = create_mock_request(42) + manager.start_transfer(request) + + # Set error state before end_transfer + request.state = LlmRequestState.DISAGG_TRANS_ERROR + + manager.end_transfer(request) + + # Error state should be preserved + assert request.state == LlmRequestState.DISAGG_TRANS_ERROR + + +def test_requests_in_transfer(): + """Test that requests_in_transfer returns correct mapping.""" + kv_cache_manager = MagicMock() + kv_cache_manager.store_blocks_for_reuse.return_value = 100 + resource_manager = create_mock_resource_manager(kv_cache_manager=kv_cache_manager) + manager = AsyncTransferManager(resource_manager) + + request1 = create_mock_request(1) + request2 = create_mock_request(2) + request3 = create_mock_request(3) + + manager.start_transfer(request1) + manager.start_transfer(request2) + manager.start_transfer(request3) + + in_transfer = manager.requests_in_transfer() + + assert len(in_transfer) == 3 + assert in_transfer[1] is request1 + assert in_transfer[2] is request2 + assert in_transfer[3] is request3
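end_transfer also has a fallback path worth pinning down: for a request that was never registered via start_transfer (or was already fully drained), the manager only logs a warning and returns a falsy value, so _end_transfer_and_maybe_terminate skips termination and nothing is unpinned. The following sketch is an additional check in the same mock-based style, reusing the create_mock_request and create_mock_resource_manager helpers defined above; it is illustrative and not part of the diff.

def test_end_transfer_unknown_request():
    """end_transfer on an unregistered request is a no-op."""
    kv_cache_manager = MagicMock()
    resource_manager = create_mock_resource_manager(kv_cache_manager=kv_cache_manager)
    manager = AsyncTransferManager(resource_manager)

    request = create_mock_request(7)

    # No start_transfer was issued for this request, so the manager
    # warns and returns a falsy value instead of raising.
    assert not manager.end_transfer(request)

    # Nothing was pinned, so nothing should be unpinned or tracked.
    kv_cache_manager.unpin_blocks_by_id.assert_not_called()
    assert not manager.has_any_inflight_requests()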