Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-05 02:31:33 +08:00)

[None][chore] Async Transfer Manager (#9891)

Signed-off-by: jthomson04 <jwillthomson19@gmail.com>

Parent: e61c942d1f
Commit: 2db3d7eeba
@@ -190,6 +190,14 @@ public:
        std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt);
};

struct RequestStatuses
{
    /// Requests that have completed their transfer successfully.
    std::unordered_set<LlmRequest::RequestIdType> completedRequestIds;
    /// Requests that have encountered an error during their transfer.
    std::unordered_set<LlmRequest::RequestIdType> errorRequestIds;
};

class BaseCacheTransceiver
{
public:
@@ -202,7 +210,10 @@ public:
    virtual void requestAndReceiveSync(LlmRequest* llmRequest) = 0;
    virtual void requestAndReceiveAsync(LlmRequest* llmRequest) = 0;

-   virtual void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;
+   /// Check all requests transferring context, and return the requests that have completed or encountered an error.
+   virtual RequestStatuses checkContextTransferStatus(
+       std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false)
+       = 0;

    virtual void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;
@@ -243,7 +254,8 @@ public:
    void requestAndReceiveSync(LlmRequest* llmRequest) override;
    void requestAndReceiveAsync(LlmRequest* llmRequest) override;

-   void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override;
+   RequestStatuses checkContextTransferStatus(
+       std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override;

    void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override;
@@ -427,7 +427,8 @@ void updateKVCacheTransferBW(std::shared_ptr<CacheTransceiverComm> const& mComm,
    }
}

-void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
+RequestStatuses CacheTransceiver::checkContextTransferStatus(
+    std::optional<int> const& atLeastRequestNum, bool markComplete)
{
    bool blockAll = !atLeastRequestNum.has_value();
    std::optional<int> senderFutureTimeoutMs = std::nullopt;
@@ -486,6 +487,8 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
        toCompleteIdSet.insert(request->mRequestId);
    }

    RequestStatuses requestsStatus{};

    // Complete all the requests in toCompleteIdSet
    for (auto it = mSenderFutures.begin(); it != mSenderFutures.end();)
    {
@@ -499,7 +502,11 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
            if (status == std::future_status::ready || !senderFutureTimeoutMs.has_value())
            {
                future.get();
-               request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
+               requestsStatus.completedRequestIds.insert(request->mRequestId);
+               if (markComplete)
+               {
+                   request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
+               }
                it = mSenderFutures.erase(it);
            }
            else if (status == std::future_status::timeout)
@@ -514,6 +521,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
                    "Future returned unexpected status for request %ld. Marking as error", request->mRequestId);

                request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
+               requestsStatus.errorRequestIds.insert(request->mRequestId);
                it = mSenderFutures.erase(it);
            }
        }
@@ -522,6 +530,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
            TLLM_LOG_ERROR(
                "Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what());
            request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
+           requestsStatus.errorRequestIds.insert(request->mRequestId);
            it = mSenderFutures.erase(it);
        }
    }
@@ -530,6 +539,8 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
            ++it;
        }
    }

+   return requestsStatus;
}

void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastRequestNum)
@@ -503,7 +503,7 @@ TrtGptModelInflightBatching::~TrtGptModelInflightBatching()
{
    if (mCacheTransceiver)
    {
-       mCacheTransceiver->checkContextTransferStatus(true);
+       mCacheTransceiver->checkContextTransferStatus(1, true);
        TLLM_CHECK_WITH_INFO(mCacheTransceiver->checkGenTransferComplete(), "Generation transfer not complete");
    }
    if (mAsyncSendWaitThread)
@@ -932,7 +932,7 @@ void TrtGptModelInflightBatching::forwardSync()
    }
    if (mCacheTransceiver)
    {
-       mCacheTransceiver->checkContextTransferStatus(0);
+       mCacheTransceiver->checkContextTransferStatus(0, true);
    }
    ++mIterCounter;

@@ -1025,7 +1025,7 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
        mIterCounter);
    if (mCacheTransceiver)
    {
-       mCacheTransceiver->checkContextTransferStatus(1);
+       mCacheTransceiver->checkContextTransferStatus(1, true);
        // will free kvCache in next iteration.
    }
}
@@ -60,9 +60,10 @@ public:
        NB_OVERRIDE_PURE(requestAndReceiveAsync, llmRequest);
    }

-   void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
+   tb::RequestStatuses checkContextTransferStatus(
+       std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override
    {
-       NB_OVERRIDE_PURE(checkContextTransferStatus, atLeastRequestNum);
+       NB_OVERRIDE_PURE(checkContextTransferStatus, atLeastRequestNum, markComplete);
    }

    void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
@@ -88,8 +89,23 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m)
        .def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync)
        .def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync)
        .def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync)
-       .def("check_context_transfer_status", &BaseCacheTransceiver::checkContextTransferStatus,
-           nb::call_guard<nb::gil_scoped_release>())
+       .def(
+           "check_context_transfer_status",
+           [](tb::BaseCacheTransceiver& self, std::optional<int> const& atLeastRequestNum, bool markComplete = false)
+           {
+               RequestStatuses result;
+               {
+                   nb::gil_scoped_release release;
+                   result = self.checkContextTransferStatus(atLeastRequestNum, markComplete);
+               }
+
+               auto completedRequestIds
+                   = std::vector<int64_t>(result.completedRequestIds.begin(), result.completedRequestIds.end());
+               auto errorRequestIds
+                   = std::vector<int64_t>(result.errorRequestIds.begin(), result.errorRequestIds.end());
+               return nb::make_tuple(completedRequestIds, errorRequestIds);
+           },
+           nb::arg("at_least_request_num") = std::nullopt, nb::arg("mark_complete") = false)
        .def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus,
            nb::call_guard<nb::gil_scoped_release>())
        .def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete)
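On the Python side, the rebound `check_context_transfer_status` now returns the two ID sets as a tuple of lists. A minimal polling sketch under that assumption; the `poll_context_transfers` wrapper, its `transceiver` argument, and the loop bodies are illustrative stand-ins, not part of this change:

def poll_context_transfers(transceiver) -> None:
    # `transceiver` stands in for a bound BaseCacheTransceiver created elsewhere.
    completed_ids, error_ids = transceiver.check_context_transfer_status(
        at_least_request_num=0,  # poll without blocking on every in-flight send
        mark_complete=False)     # leave request-state transitions to the caller
    for req_id in completed_ids:
        print(f"context transfer finished for request {req_id}")
    for req_id in error_ids:
        print(f"context transfer failed for request {req_id}")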
@@ -56,9 +56,13 @@ public:
        PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, requestAndReceiveAsync, llmRequest);
    }

-   void checkContextTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
+   using RequestStatuses = tb::RequestStatuses;
+
+   RequestStatuses checkContextTransferStatus(
+       std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override
    {
-       PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, checkContextTransferStatus, atLeastRequestNum);
+       PYBIND11_OVERLOAD_PURE(
+           RequestStatuses, tb::BaseCacheTransceiver, checkContextTransferStatus, atLeastRequestNum, markComplete);
    }

    void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
@@ -84,8 +88,23 @@ void tb::CacheTransceiverBindings::initBindings(py::module_& m)
        .def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync)
        .def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync)
        .def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync)
-       .def("check_context_transfer_status", &BaseCacheTransceiver::checkContextTransferStatus,
-           py::call_guard<py::gil_scoped_release>())
+       .def(
+           "check_context_transfer_status",
+           [](tb::BaseCacheTransceiver& self, std::optional<int> const& atLeastRequestNum, bool markComplete = false)
+           {
+               RequestStatuses result;
+               {
+                   py::gil_scoped_release release;
+                   result = self.checkContextTransferStatus(atLeastRequestNum, markComplete);
+               }
+
+               auto completedRequestIds
+                   = std::vector<int64_t>(result.completedRequestIds.begin(), result.completedRequestIds.end());
+               auto errorRequestIds
+                   = std::vector<int64_t>(result.errorRequestIds.begin(), result.errorRequestIds.end());
+               return py::make_tuple(completedRequestIds, errorRequestIds);
+           },
+           py::arg("at_least_request_num") = std::nullopt, py::arg("mark_complete") = false)
        .def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus,
            py::call_guard<py::gil_scoped_release>())
        .def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete)
@@ -114,6 +114,125 @@ class BatchStatePP(BatchState):
    finished_ctx_reqs: list[LlmRequest] = None


class AsyncTransferManager:
    """
    Handle asynchronous transfer of KV cache after a request has completed.
    When running with both the KV cache transceiver and the KV cache connector, we must ensure that BOTH transfers (if any) are completed before we can release the KV cache blocks.
    The AsyncTransferManager has a few key responsibilities:
    1. Track requests in transfer.
    2. Pin blocks for reuse while blocks are in transfer.
    3. Unpin blocks after all transfers are complete.

    TODO(jthomson04): This only handles async send/saving, not loading. Loading KV cache is handled through a separate codepath. Eventually, we'll want to merge these two paths.
    """

    class RequestTransferMetadata:

        def __init__(self, block_id: Optional[int]):
            self.block_id = block_id
            self.counter = 0

        def start_transfer(self):
            self.counter += 1

        def end_transfer(self) -> bool:
            """
            Returns:
                bool: True if there are no more transfers for this request
            """
            self.counter -= 1
            return self.counter == 0

    def __init__(self,
                 resource_manager: "ResourceManager",
                 should_store_blocks: bool = True):
        self.resource_manager = resource_manager
        self.kv_cache_manager = resource_manager.resource_managers.get(
            ResourceManagerType.KV_CACHE_MANAGER)

        self.should_store_blocks = should_store_blocks

        # Mapping of request id to the LlmRequest
        self._requests_in_transfer: Dict[int, LlmRequest] = dict()

        # Mapping of request id to the request metadata
        self._request_transfer_metadata: Dict[
            int, self.RequestTransferMetadata] = dict()

    def requests_in_transfer(self) -> Dict[int, LlmRequest]:
        return self._requests_in_transfer

    def start_transfer(self, request: LlmRequest):
        """
        Called when a cache transceiver or connector transfer is started.
        1. Increments the counter for the request.
        2. Releases all resources except for the KV cache, if not already released.
        3. Stores KV cache blocks for reuse.
        """

        req_id = request.py_request_id

        if req_id not in self._requests_in_transfer:
            for resource_mgr_type in (
                    ResourceManagerType.SEQ_SLOT_MANAGER,
                    ResourceManagerType.SPEC_RESOURCE_MANAGER):
                if resource_mgr_type in self.resource_manager.resource_managers and self.resource_manager.resource_managers[
                        resource_mgr_type] is not None:
                    self.resource_manager.resource_managers[
                        resource_mgr_type].free_resources(request)

            request.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS

            if self.should_store_blocks:
                block_id = self.kv_cache_manager.store_blocks_for_reuse(
                    request, True)
            else:
                block_id = None

            self._requests_in_transfer[req_id] = request
            self._request_transfer_metadata[
                req_id] = self.RequestTransferMetadata(block_id)

        self._request_transfer_metadata[req_id].start_transfer()

    def end_transfer(self, request: LlmRequest) -> bool:
        """
        Called after a send of KV cache is complete.
        1. Decrements the counter for the request.
        2. If there are no more inflight transfers for this request, unpins the blocks and marks the request as complete.

        Returns:
            bool: True if the request should be terminated after the call to end_transfer
        """
        try:
            transfer_metadata = self._request_transfer_metadata[
                request.py_request_id]
        except KeyError:
            logger.warning(
                f"Request {request.py_request_id} not found in transfer manager"
            )
            return False

        if transfer_metadata.end_transfer():
            self._requests_in_transfer.pop(request.py_request_id)
            self._request_transfer_metadata.pop(request.py_request_id)

            if self.should_store_blocks:
                self.kv_cache_manager.unpin_blocks_by_id(
                    transfer_metadata.block_id)

            # We don't want to overwrite any error state.
            if request.state != LlmRequestState.DISAGG_TRANS_ERROR:
                request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE

            return True

        return False

    def has_any_inflight_requests(self) -> bool:
        return len(self._requests_in_transfer) > 0

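The nested counter turns block pinning into a simple reference count across the two possible transfer paths (transceiver send and connector save). A minimal sketch of the protocol, using mocks in the same style as the unit tests added at the end of this commit; every object here is a stand-in and the block id 7 is made up:

from unittest.mock import MagicMock

from tensorrt_llm._torch.pyexecutor.py_executor import AsyncTransferManager
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType

# Stand-in resource manager wiring; only the KV cache manager matters here.
kv_cache_manager = MagicMock()
kv_cache_manager.store_blocks_for_reuse.return_value = 7  # hypothetical block id
resource_manager = MagicMock()
resource_manager.resource_managers = {
    ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager
}

mgr = AsyncTransferManager(resource_manager)
req = MagicMock()
req.py_request_id = 1

mgr.start_transfer(req)           # transceiver send begins; blocks pinned once
mgr.start_transfer(req)           # connector save begins; counter goes to 2
assert not mgr.end_transfer(req)  # counter 2 -> 1; blocks stay pinned
assert mgr.end_transfer(req)      # counter 1 -> 0; blocks unpinned, request complete
kv_cache_manager.unpin_blocks_by_id.assert_called_once_with(7)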
class PyExecutor:

    def __init__(self,
@@ -233,10 +352,10 @@ class PyExecutor:
        self.max_num_active_requests = model_engine.get_max_num_sequences()
        self.active_requests: List[LlmRequest] = []
        self.expected_num_active_requests = 0
-       self.ctx_in_transmission_requests = dict()
-       self.ctx_in_transmission_counter = (1 if kv_cache_transceiver else
-                                           0) + (1 if kv_connector_manager else
-                                                 0)
+       self.async_transfer_manager = AsyncTransferManager(
+           self.resource_manager,
+           should_store_blocks=self.block_reuse_enabled
+           and not self.kv_cache_manager.is_vswa)
        self.previous_batch: Optional[BatchState] = None
        self.has_previous_draft_tokens = False
        self.num_scheduled_requests: int = 0
@@ -373,6 +492,10 @@ class PyExecutor:
                    module.register_forward_hook(
                        self.kv_connector_manager.layer_post_hook)

    def _end_transfer_and_maybe_terminate(self, request: LlmRequest):
        if self.async_transfer_manager.end_transfer(request):
            self._terminate_request(request)

    def _event_loop_wrapper(self):
        try:
            with customized_gc_thresholds(
@@ -951,7 +1074,7 @@ class PyExecutor:
                raise RuntimeError(
                    "KV cache transceiver is not enabled, but current rank cannot run first PP's schedule result due to limited KV cache resources. This is not expected."
                )
-           if not self.ctx_in_transmission_requests:
+           if not self.async_transfer_manager.has_any_inflight_requests():
                raise RuntimeError(
                    "No context cache transmission is in progress, but current rank cannot run first PP's schedule result due to limited KV cache resources. This is not expected."
                )
@@ -970,7 +1093,6 @@ class PyExecutor:
                # Let cache transceiver finish at least one cache transmission and release requests' KV cache resources
                self._check_disagg_ctx_cache_transfer_status(1)
                self._check_kv_transfer_timeout()
-               self._terminate_disagg_ctx_finished_requests()
            else:
                raise RuntimeError(
                    f"Reach maximum PP retry count ({self.pp_scheduler_max_retry_count}) but still cannot run first PP's schedule result. Please consider increasing the KV cache size by setting `free_gpu_memory_fraction` to a larger value. Or you can set `TLLM_PP_SCHEDULER_MAX_RETRY_COUNT` to a larger value to allow more retries."
@@ -1186,21 +1308,8 @@ class PyExecutor:
            sample_state.scheduled_requests.context_requests = previous_batch.finished_ctx_reqs
        self._update_requests(previous_batch.sample_state)

-       if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
-           for req in previous_batch.scheduled_ctx_reqs:
-               if req.is_context_only_request and (
-                       req.is_context_finished
-                       or req.is_finished_due_to_length
-               ) and not req.is_finished_due_to_cancellation:
-                   block_id = self.kv_cache_manager.store_blocks_for_reuse(
-                       req, True)
-                   self.ctx_in_transmission_requests[
-                       req.py_request_id] = (
-                           (req, block_id,
-                            self.ctx_in_transmission_counter))
-
        if self.kv_cache_transceiver:
-           self._send_disagg_ctx_cache(
+           self._send_kv_async(
                previous_batch.scheduled_ctx_reqs)
        self._handle_canceled_requests()
@@ -1222,9 +1331,9 @@ class PyExecutor:
            self.wait_on_pp_send_handles(prev_microbatch_id)
            self.micro_batches[prev_microbatch_id] = None

-           if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+           if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
+           ):
                self._check_kv_transfer_timeout()
-               self._terminate_disagg_ctx_finished_requests()

            if self._disagg_pp_termination_handler is not None:
                self._disagg_pp_termination_handler.terminate_pending_requests(
@@ -1354,14 +1463,7 @@ class PyExecutor:
        if self.kv_connector_manager:
            reqs_to_terminate = self.kv_connector_manager.get_finished()
            for req in reqs_to_terminate:
-               if req.py_request_id in self.ctx_in_transmission_requests:
-                   request, block_id, counter = self.ctx_in_transmission_requests.pop(
-                       req.py_request_id)
-                   if counter == 1:
-                       self.kv_cache_manager.unpin_blocks_by_id(block_id)
-                   else:
-                       self.ctx_in_transmission_requests[req.py_request_id] = (
-                           request, block_id, counter - 1)
+               self._end_transfer_and_maybe_terminate(req)

    def _kv_connector_wait_for_save(self):
        if self.kv_connector_manager is not None:
@@ -1457,25 +1559,9 @@ class PyExecutor:

            self._update_request_states(scheduled_batch)
            self._update_requests(sample_state, self.resource_manager)
-           if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
-               for req in scheduled_batch.context_requests:
-                   if req.is_context_only_request and (
-                           req.is_context_finished
-                           or req.is_finished_due_to_length
-                   ) and not req.is_finished_due_to_cancellation:
-                       block_id = self.kv_cache_manager.store_blocks_for_reuse(
-                           req, True)
-                       self.ctx_in_transmission_requests[
-                           req.py_request_id] = (
-                               (req, block_id,
-                                self.ctx_in_transmission_counter))

            if self.kv_cache_transceiver:
-               ctx_transmission_reqs = self._send_disagg_ctx_cache(
-                   scheduled_batch.context_requests)
-               # For context only req in transmission, we reset the state since sampler might have changed it
-               for req in ctx_transmission_reqs:
-                   req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
+               self._send_kv_async(scheduled_batch.context_requests +
+                                   scheduled_batch.generation_requests)

            self._handle_canceled_requests()
            finished_requests = self._handle_responses()
@@ -1489,9 +1575,9 @@ class PyExecutor:
            if self.enable_kv_cache_events:
                self._add_kv_cache_events()

-           if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+           if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
+           ):
                self._check_kv_transfer_timeout()
-               self._terminate_disagg_ctx_finished_requests()

            self._kv_connector_terminate_requests()
@@ -1709,19 +1795,6 @@ class PyExecutor:
        if self.previous_batch is not None:
            self._update_requests(self.previous_batch.sample_state)

-           if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
-               for req in self.previous_batch.sample_state.scheduled_requests.context_requests:
-                   if req.is_context_only_request and (
-                           req.is_context_finished
-                           or req.is_finished_due_to_length
-                   ) and not req.is_finished_due_to_cancellation:
-                       block_id = self.kv_cache_manager.store_blocks_for_reuse(
-                           req, True)
-                       self.ctx_in_transmission_requests[
-                           req.py_request_id] = (
-                               (req, block_id,
-                                self.ctx_in_transmission_counter))

        if self.drafter is not None and self.use_spec_decode:
            # Cleanup previous draft resources used in the draft model
            self.drafter.cleanup_previous_draft_resources()
@@ -1746,9 +1819,8 @@ class PyExecutor:

        self._update_request_states(scheduled_batch)

-       ctx_transmission_reqs = self._send_disagg_ctx_cache(
-           scheduled_batch.context_requests
-       ) if self.kv_cache_transceiver else []
+       ctx_transmission_reqs = self._send_kv_async(
+           scheduled_batch.all_requests())

        if self.previous_batch is not None:
            self._process_previous_batch()
@@ -1764,9 +1836,9 @@ class PyExecutor:
            iter_stats=iter_stats,
            ctx_transmission_reqs=ctx_transmission_reqs)

-       if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+       if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
+       ):
            self._check_kv_transfer_timeout()
-           self._terminate_disagg_ctx_finished_requests()

        self._kv_connector_terminate_requests()
@@ -2101,7 +2173,7 @@ class PyExecutor:
                )
                req.py_kv_transfer_timed_out = True

-       for req, _, _ in self.ctx_in_transmission_requests.values():
+       for req in self.async_transfer_manager.requests_in_transfer().values():
            flag_if_kv_transfer_timed_out(req, "context")

        for req in self.active_requests:
@@ -2221,36 +2293,45 @@ class PyExecutor:

            return

-   @nvtx_range("_send_disagg_ctx_cache")
-   def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
-       if (scheduled_ctx_requests is None or len(scheduled_ctx_requests) == 0):
-           return []
-       for req in scheduled_ctx_requests:
-           if req.is_context_only_request and (
-                   req.is_context_finished or req.is_finished_due_to_length
-           ) and not req.is_finished_due_to_cancellation:
-               self.kv_cache_transceiver.respond_and_send_async(req)
-               for resource_mgr_type in (
-                       ResourceManagerType.SEQ_SLOT_MANAGER,
-                       ResourceManagerType.SPEC_RESOURCE_MANAGER):
-                   if resource_mgr_type in self.resource_manager.resource_managers and self.resource_manager.resource_managers[
-                           resource_mgr_type] is not None:
-                       self.resource_manager.resource_managers[
-                           resource_mgr_type].free_resources(req)
+   @nvtx_range("_send_kv_async")
+   def _send_kv_async(self, scheduled_requests: List[LlmRequest]):

-       self._check_disagg_ctx_cache_transfer_status(0)
+       def kv_connector_request_finished(req: LlmRequest):
+           try:
+               cache_block_ids = self.kv_cache_manager.get_cache_indices(req)
+           except Exception as e:
+               logger.warning(
+                   f"Unable to get cache blocks for request {req.py_request_id}. Skipping asynchronous saving: {e}"
+               )
+           else:
+               if self.kv_connector_manager.request_finished(
+                       req, cache_block_ids):
+                   self.async_transfer_manager.start_transfer(req)

-       # Keep track of ctx requests that are in transmission
-       ctx_transmission_reqs = [
-           req for req in scheduled_ctx_requests
-           if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
-       ]
+       if self.kv_cache_transceiver:
+           for req in scheduled_requests:
+               if req.is_context_only_request and (
+                       req.is_context_finished or req.is_finished_due_to_length
+               ) and not req.is_finished_due_to_cancellation:
+                   self.kv_cache_transceiver.respond_and_send_async(req)

-       if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
-           for req in ctx_transmission_reqs:
-               req.py_kv_transfer_start_time = time.time()
+                   self.async_transfer_manager.start_transfer(req)

-       return ctx_transmission_reqs
+                   if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
+                       req.py_kv_transfer_start_time = time.time()

+       if self.kv_connector_manager:
+           if not self.disable_overlap_scheduler:
+               requests = self.previous_batch.sample_state.scheduled_requests.all_requests(
+               ) if self.previous_batch is not None else []
+           else:
+               requests = scheduled_requests
+           for req in requests:
+               if req.is_finished:
+                   kv_connector_request_finished(req)

+       if self.kv_cache_transceiver:
+           self._check_disagg_ctx_cache_transfer_status(0)

    def _get_disagg_reqs_in_error_state(self):
        return [
@@ -2268,7 +2349,41 @@ class PyExecutor:

    @nvtx_range("_check_disagg_ctx_cache_transfer_status")
    def _check_disagg_ctx_cache_transfer_status(self, atLeastNum: int = 0):
-       self.kv_cache_transceiver.check_context_transfer_status(atLeastNum)
+       finished_requests, error_requests = self.kv_cache_transceiver.check_context_transfer_status(
+           atLeastNum)
+
+       completed_req_ids = set(finished_requests + error_requests)
+
+       requests_in_transfer = self.async_transfer_manager.requests_in_transfer(
+       )
+
+       for request_id in completed_req_ids:
+
+           if request_id not in requests_in_transfer:
+               logger.warning(
+                   f"Request {request_id} not found in transfer manager")
+               continue
+
+           request = requests_in_transfer[request_id]
+
+           self._end_transfer_and_maybe_terminate(request)
+
+       # The set of requests in transfer may have changed since we terminated some requests.
+       requests_in_transfer = self.async_transfer_manager.requests_in_transfer(
+       )
+
+       for request_id in list(requests_in_transfer.keys()):
+           request = requests_in_transfer[request_id]
+           if request.py_kv_transfer_timed_out and request_id not in completed_req_ids:
+               is_cancelled = self.kv_cache_transceiver.cancel_request(request)
+               # If cancel is successful, mark as complete so it can be cleaned up
+               # Otherwise, try at next iteration
+               if is_cancelled:
+                   request.py_kv_transfer_start_time = None
+                   request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE
+
+                   self._end_transfer_and_maybe_terminate(request)

        self._check_cache_transfer_errors("context requests")

    @nvtx_range("_check_disagg_gen_cache_transfer_status")
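Taken together, the pieces above reduce the executor's per-iteration context-transfer handling to: poll the transceiver once, then settle bookkeeping through the transfer manager. A condensed, hypothetical restatement (the wrapper function is illustrative; the attribute and method names are the ones used above):

def settle_context_transfers(executor, at_least: int = 0) -> None:
    """Illustrative wrapper: one polling pass over in-flight context sends."""
    done_ids, error_ids = executor.kv_cache_transceiver.check_context_transfer_status(at_least)
    for req_id in set(done_ids + error_ids):
        req = executor.async_transfer_manager.requests_in_transfer().get(req_id)
        if req is not None:
            # Drops one transfer reference; terminates the request on the last one.
            executor._end_transfer_and_maybe_terminate(req)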
@@ -2474,24 +2589,6 @@ class PyExecutor:
            self._do_terminate_request(request)

    def _do_terminate_request(self, request: LlmRequest):
-       if self.kv_connector_manager is not None:
-           # Only call request_finished on the connector if the request has already been added to the kv cache manager.
-           try:
-               cache_block_ids = self.kv_cache_manager.get_cache_indices(
-                   request)
-           except IndexError:
-               # If the request has not yet been added to the kv cache manager,
-               # we still need to free resources corresponding to other resource managers.
-               self.resource_manager.free_resources(request)
-           else:
-               if self.kv_connector_manager.request_finished(
-                       request,
-                       cache_block_ids) and not self.kv_cache_transceiver:
-                   block_id = self.kv_cache_manager.store_blocks_for_reuse(
-                       request, True)
-                   self.ctx_in_transmission_requests[request.py_request_id] = (
-                       (request, block_id, self.ctx_in_transmission_counter))
-
        self.resource_manager.free_resources(request)

        if self.gather_all_responses or self.dist.rank == 0:
@@ -2678,12 +2775,7 @@ class PyExecutor:
                if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa:
                    requests_to_terminate.append(request)
                else:
-                   if request.is_disagg_context_transmission_state:
-                       self.ctx_in_transmission_requests[
-                           request.py_request_id] = (
-                               (request, None,
-                                self.ctx_in_transmission_counter))
-                   else:
+                   if not request.is_disagg_context_transmission_state:
                        requests_to_terminate.append(request)
            else:
                new_active_requests.append(request)
@@ -2696,35 +2788,6 @@ class PyExecutor:
                self._terminate_request(request)
        return requests_to_terminate

-   @nvtx_range("_terminate_disagg_ctx_finished_requests")
-   def _terminate_disagg_ctx_finished_requests(self):
-       # make a copy of the keys, since we are modifying the dictionary in the loop
-       in_transmission_requests_id = list(
-           self.ctx_in_transmission_requests.keys())
-       for request_id in in_transmission_requests_id:
-           request, block_id, counter = self.ctx_in_transmission_requests[
-               request_id]
-
-           if request.py_kv_transfer_timed_out:
-               is_cancelled = self.kv_cache_transceiver.cancel_request(request)
-               # If cancel is successful, mark as complete so it can be cleaned up
-               # Otherwise, try at next iteration
-               if is_cancelled:
-                   request.py_kv_transfer_start_time = None
-                   request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE
-
-           if request.is_disagg_context_complete_state:
-               del self.ctx_in_transmission_requests[request_id]
-               if not self.block_reuse_enabled or self.kv_cache_manager.is_vswa:
-                   self._terminate_request(request)
-               elif counter == 1:
-                   self.kv_cache_manager.unpin_blocks_by_id(block_id)
-               else:
-                   self.ctx_in_transmission_requests[request_id] = ((request,
-                                                                     block_id,
-                                                                     counter -
-                                                                     1))

    def _handle_logits_communication(self, previous_batch, prev_microbatch_id):
        """Handle logits communication between pipeline parallel ranks.
@@ -21,6 +21,7 @@ l0_a10:
  - unittest/_torch/modeling/test_modeling_mistral.py
  - unittest/_torch/modeling/test_modeling_pixtral.py
  - unittest/_torch/sampler/test_trtllm_sampler.py
  - unittest/_torch/executor/test_async_transfer_manager.py
  - unittest/_torch/executor/test_scheduler_serializable_output.py
  # NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
  # test list either).

tests/unittest/_torch/executor/test_async_transfer_manager.py (new file, 182 lines)
@@ -0,0 +1,182 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock

from tensorrt_llm._torch.pyexecutor.py_executor import AsyncTransferManager
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm.bindings import LlmRequestState


def create_mock_request(request_id: int):
    """Create a mock LlmRequest with the given request ID."""
    request = MagicMock()
    request.py_request_id = request_id
    request.state = LlmRequestState.GENERATION_IN_PROGRESS
    return request


def create_mock_resource_manager(
    kv_cache_manager=None,
    seq_slot_manager=None,
    spec_resource_manager=None,
):
    """Create a mock ResourceManager with the specified resource managers."""
    resource_manager = MagicMock()
    resource_manager.resource_managers = {}

    if kv_cache_manager is not None:
        resource_manager.resource_managers[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager

    if seq_slot_manager is not None:
        resource_manager.resource_managers[ResourceManagerType.SEQ_SLOT_MANAGER] = seq_slot_manager

    if spec_resource_manager is not None:
        resource_manager.resource_managers[ResourceManagerType.SPEC_RESOURCE_MANAGER] = (
            spec_resource_manager
        )

    return resource_manager


def test_start_transfer_single_request():
    """Test starting a single transfer."""
    kv_cache_manager = MagicMock()
    kv_cache_manager.store_blocks_for_reuse.return_value = 100
    seq_slot_manager = MagicMock()
    resource_manager = create_mock_resource_manager(
        kv_cache_manager=kv_cache_manager, seq_slot_manager=seq_slot_manager
    )
    manager = AsyncTransferManager(resource_manager)

    request = create_mock_request(42)
    manager.start_transfer(request)

    # Check request is tracked
    assert 42 in manager._requests_in_transfer

    transfer_metadata = manager._request_transfer_metadata[42]

    assert transfer_metadata.block_id == 100
    assert transfer_metadata.counter == 1

    # Check state was updated
    assert request.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS

    # Check KV cache manager was called
    kv_cache_manager.store_blocks_for_reuse.assert_called_once_with(request, True)

    # Check seq slot manager was called to free resources
    seq_slot_manager.free_resources.assert_called_once_with(request)

    manager.end_transfer(request)
    kv_cache_manager.unpin_blocks_by_id.assert_called_once()


def test_start_transfer_multiple_transfers_same_request():
    """Test starting multiple transfers for the same request."""
    kv_cache_manager = MagicMock()
    kv_cache_manager.store_blocks_for_reuse.return_value = 100
    resource_manager = create_mock_resource_manager(kv_cache_manager=kv_cache_manager)
    manager = AsyncTransferManager(resource_manager)

    request = create_mock_request(42)
    manager.start_transfer(request)
    manager.start_transfer(request)
    manager.start_transfer(request)

    # Counter should be incremented
    transfer_metadata = manager._request_transfer_metadata[42]
    assert transfer_metadata.counter == 3

    # store_blocks_for_reuse should only be called once
    kv_cache_manager.store_blocks_for_reuse.assert_called_once()

    for _ in range(2):
        manager.end_transfer(request)
        kv_cache_manager.unpin_blocks_by_id.assert_not_called()

    manager.end_transfer(request)
    kv_cache_manager.unpin_blocks_by_id.assert_called_once()


def test_transfer_without_storing_blocks():
    """Test starting a transfer with should_store_blocks=False."""
    kv_cache_manager = MagicMock()
    kv_cache_manager.store_blocks_for_reuse.return_value = 0
    spec_resource_manager = MagicMock()
    resource_manager = create_mock_resource_manager(
        kv_cache_manager=kv_cache_manager, spec_resource_manager=spec_resource_manager
    )
    manager = AsyncTransferManager(resource_manager, should_store_blocks=False)

    request = create_mock_request(42)
    manager.start_transfer(request)

    # Check request is tracked
    assert 42 in manager._requests_in_transfer
    transfer_metadata = manager._request_transfer_metadata[42]
    assert transfer_metadata.block_id is None  # No block stored
    assert transfer_metadata.counter == 1

    # Check KV cache manager was NOT called
    kv_cache_manager.store_blocks_for_reuse.assert_not_called()
    spec_resource_manager.free_resources.assert_called_once_with(request)

    assert manager.end_transfer(request)

    kv_cache_manager.unpin_blocks_by_id.assert_not_called()


def test_end_transfer_preserves_error_state():
    """Test that end_transfer does not overwrite error state."""
    kv_cache_manager = MagicMock()
    kv_cache_manager.store_blocks_for_reuse.return_value = 100
    resource_manager = create_mock_resource_manager(kv_cache_manager=kv_cache_manager)
    manager = AsyncTransferManager(resource_manager)

    request = create_mock_request(42)
    manager.start_transfer(request)

    # Set error state before end_transfer
    request.state = LlmRequestState.DISAGG_TRANS_ERROR

    manager.end_transfer(request)

    # Error state should be preserved
    assert request.state == LlmRequestState.DISAGG_TRANS_ERROR


def test_requests_in_transfer():
    """Test that requests_in_transfer returns correct mapping."""
    kv_cache_manager = MagicMock()
    kv_cache_manager.store_blocks_for_reuse.return_value = 100
    resource_manager = create_mock_resource_manager(kv_cache_manager=kv_cache_manager)
    manager = AsyncTransferManager(resource_manager)

    request1 = create_mock_request(1)
    request2 = create_mock_request(2)
    request3 = create_mock_request(3)

    manager.start_transfer(request1)
    manager.start_transfer(request2)
    manager.start_transfer(request3)

    in_transfer = manager.requests_in_transfer()

    assert len(in_transfer) == 3
    assert in_transfer[1] is request1
    assert in_transfer[2] is request2
    assert in_transfer[3] is request3