/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
#include "tensorrt_llm/executor/types.h"

#include <algorithm>
#include <chrono>
#include <future>

#define UCX_WRAPPER_LIB_NAME "tensorrt_llm_ucx_wrapper"

#if defined(_WIN32)
#include <windows.h>
#define dllOpen(name) LoadLibrary(name ".dll")
#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
#define dllGetSym(handle, name) static_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), name))
#else // For non-Windows platforms
#include <dlfcn.h>
#define dllOpen(name) dlopen("lib" name ".so", RTLD_LAZY)
#define dllClose(handle) dlclose(handle)
#define dllGetSym(handle, name) dlsym(handle, name)
#endif // defined(_WIN32)

#include "tensorrt_llm/batch_manager/cacheFormatter.h"
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/contextProgress.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/kvCacheType.h"
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/mlaCacheFormatter.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/cache_transmission/mpi_utils/connection.h"
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/executor/serializeUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"

#include <mutex>
#include <numeric>
#include <sstream>
#include <unordered_set>

namespace tensorrt_llm::batch_manager
{

std::mutex CacheTransceiver::mDllMutex;

std::unique_ptr<BaseCacheTransceiver> CacheTransceiverFactory::createCacheTransceiver(
    kv_cache_manager::BaseKVCacheManager* cacheManager, runtime::ModelConfig const& modelConfig,
    runtime::WorldConfig const& worldConfig, executor::kv_cache::CacheState::AttentionType attentionType,
    std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig)
{
    if (!cacheTransceiverConfig.has_value() || !cacheTransceiverConfig.value().getBackendType().has_value())
    {
        TLLM_LOG_INFO("CacheTransceiver is disabled.");
        return nullptr;
    }
    auto backendType = cacheTransceiverConfig.value().getBackendType();
    if (backendType.value() == executor::CacheTransceiverConfig::BackendType::DEFAULT)
    {
        if (common::getEnvUseUCXKvCache())
        {
            backendType = executor::CacheTransceiverConfig::BackendType::UCX;
            TLLM_LOG_INFO("Enable UCX KV cache transport.");
        }
        else if (common::getEnvUseNixlKvCache())
        {
            backendType = executor::CacheTransceiverConfig::BackendType::NIXL;
            TLLM_LOG_INFO("Enable NIXL KV cache transport.");
        }
        else if (common::getEnvUseMooncakeKvCache())
        {
            backendType = executor::CacheTransceiverConfig::BackendType::MOONCAKE;
            TLLM_LOG_INFO("Enable MOONCAKE KV cache transport.");
        }
        else if (common::getEnvUseMPIKvCache())
        {
            backendType = executor::CacheTransceiverConfig::BackendType::MPI;
            TLLM_LOG_INFO("Enable MPI KV cache transport.");
            TLLM_LOG_WARNING("MPI KV cache transport is deprecated, please use UCX or NIXL instead.");
        }
        else
        {
            backendType = executor::CacheTransceiverConfig::BackendType::NIXL;
        }
    }
    cacheTransceiverConfig.value().setBackendType(backendType);

    executor::kv_cache::CacheState::ModelConfig cacheStateCfg{
        modelConfig.getNumKvHeadsPerLayer(), modelConfig.getSizePerHead(), modelConfig.getTokensPerBlock()};

    auto ppSize = worldConfig.getPipelineParallelism();
    std::vector<SizeType32> attentionLayerNumPerPP(ppSize, 0);
    for (int ppRank = 0; ppRank < ppSize; ppRank++)
    {
        attentionLayerNumPerPP[ppRank] = modelConfig.getNbAttentionLayers(ppSize, ppRank);
    }

    return std::make_unique<CacheTransceiver>(cacheManager, cacheStateCfg, worldConfig, attentionLayerNumPerPP,
        modelConfig.getKvDataType(), attentionType, cacheTransceiverConfig);
}

CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager,
    executor::kv_cache::CacheState::ModelConfig const& cacheStateModelCfg, runtime::WorldConfig const& worldConfig,
    std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
    executor::kv_cache::CacheState::AttentionType attentionType,
    std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig)
    : mCacheTransceiverConfig{cacheTransceiverConfig}
{
    using tensorrt_llm::batch_manager::kv_cache_manager::CacheFormatter;
    if (useMPI())
    {
        mGroupComm = std::make_shared<CacheTransceiverComm>(std::addressof(tensorrt_llm::mpi::MpiComm::session()));
    }
    else
    {
        mGroupComm = std::make_shared<CacheTransceiverComm>(tensorrt_llm::pg_utils::get_world_pg());
    }
    if (worldConfig.isTensorParallel())
    {
        mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
            mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
    }
    int kvFactor = 2;
    if (cacheManager->getCacheType() == kv_cache_manager::CacheType::kSELFKONLY)
    {
        kvFactor = 1;
    }
    mCacheState = std::make_unique<executor::kv_cache::CacheState>(cacheStateModelCfg, worldConfig,
        attentionLayerNumPerPP, dataType, attentionType, kvFactor, cacheManager->isEnableBlockReuse(),
        cacheManager->isEnableIndexerKCache(), cacheManager->getIndexerKCacheIndexHeadDim(),
        cacheManager->getIndexerKCacheQuantBlockSize());
    if (mCacheState->getParallelConfig().mEnableAttentionDP)
    {
        int TPSizeInDPGroup
            = mCacheState->getParallelConfig().mTensorParallelism / mCacheState->getParallelConfig().mDPsize;
        int DPSize = mCacheState->getParallelConfig().mDPsize;
        int TPRankInDPGroup = worldConfig.getTensorParallelRank() % TPSizeInDPGroup;
        int DPRank = (worldConfig.getRank() - TPSizeInDPGroup * DPSize * worldConfig.getPipelineParallelRank()
                         - TPRankInDPGroup)
            / TPSizeInDPGroup;
        // mGroupDataComm = std::make_shared<CacheTransceiverComm>(mGroupComm->split(DPRank, worldConfig.getRank()));
        if (worldConfig.isTensorParallel())
        {
            mGroupTPInDPComm = std::make_shared<CacheTransceiverComm>(
                mGroupComm->split(worldConfig.getRank() / TPSizeInDPGroup, worldConfig.getRank()));
        }
    }
    bool isMLA = attentionType == executor::kv_cache::CacheState::AttentionType::kMLA;
    TLLM_CHECK_WITH_INFO(mCacheTransceiverConfig.has_value(), "CacheTransceiverConfig is not set.");
    auto backendType = mCacheTransceiverConfig.value().getBackendType();
    TLLM_CHECK_WITH_INFO(
        backendType.has_value() && (backendType.value() != executor::CacheTransceiverConfig::BackendType::DEFAULT),
        "CacheTransceiverConfig::BackendType is not set.");
    std::optional<size_t> maxNumTokens = mCacheTransceiverConfig.value().getMaxTokensInBuffer();
    mCacheTransBufferManagers.push_back(
        std::make_unique<kv_cache_manager::CacheTransBufferManager>(cacheManager, maxNumTokens));
    if (isMLA && cacheManager->isEnableIndexerKCache())
    {
        mCacheTransBufferManagers.push_back(
            std::make_unique<kv_cache_manager::CacheTransBufferManager>(cacheManager, maxNumTokens, true));
    }
    mCacheTransBufferManagerPtrs.clear();
    mCacheTransBufferManagerPtrs.reserve(mCacheTransBufferManagers.size());
    for (auto& manager : mCacheTransBufferManagers)
    {
        mCacheTransBufferManagerPtrs.push_back(manager.get());
    }

    if (backendType.value() == executor::CacheTransceiverConfig::BackendType::UCX)
    {
        std::lock_guard<std::mutex> lock(mDllMutex);
        mWrapperLibHandle = dllOpen(UCX_WRAPPER_LIB_NAME);
        TLLM_CHECK_WITH_INFO(
            mWrapperLibHandle != nullptr, "Failed to open the UCX wrapper library, error: %s", dlerror());
        auto load_sym = [](void* handle, char const* name)
        {
            void* ret = dllGetSym(handle, name);
            TLLM_CHECK_WITH_INFO(ret != nullptr,
                "Unable to load UCX wrapper library symbol, possible cause is that TensorRT LLM library is not "
                "built with UCX support, please rebuild in UCX-enabled environment.");
            return ret;
        };
        std::unique_ptr<executor::kv_cache::ConnectionManager> (*makeUcxConnectionManager)();
        *(void**) (&makeUcxConnectionManager) = load_sym(mWrapperLibHandle, "makeUcxConnectionManager");
        mManager = makeUcxConnectionManager();
        TLLM_LOG_INFO("UCX Connection Manager created");
    }
    else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL)
    {
        mManager = std::make_unique<executor::kv_cache::AgentConnectionManager>(
            mCacheTransBufferManagerPtrs, *mCacheState, "nixl");
        TLLM_LOG_INFO("NIXL Connection Manager created");
    }
    else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MOONCAKE)
    {
        mManager = std::make_unique<executor::kv_cache::AgentConnectionManager>(
            mCacheTransBufferManagerPtrs, *mCacheState, "mooncake");
        TLLM_LOG_INFO("MOONCAKE Connection Manager created");
    }
    else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI)
    {
        mMpiWorldComm = std::addressof(tensorrt_llm::mpi::MpiComm::world());
        mManager = std::make_unique<executor::kv_cache::MpiConnectionManager>(mMpiWorldComm);
        TLLM_LOG_INFO("MPI Connection Manager created");
    }
    else
    {
        TLLM_THROW("Unsupported cache transceiver backend type");
    }

    auto makeFormatter = [cacheManager, isMLA, this]()
    { return createCacheFormatter(cacheManager, mCacheTransBufferManagerPtrs, isMLA); };

    mCacheSender = std::make_unique<CacheSender>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter());
    mCacheReceiver
        = std::make_unique<CacheReceiver>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter());
    initializeCommState();
}

CacheTransceiver::~CacheTransceiver()
{
    if (mWrapperLibHandle)
    {
        std::lock_guard<std::mutex> lock(mDllMutex);
        dllClose(mWrapperLibHandle);
    }
}

void CacheTransceiver::initializeCommState()
{
    mCommState = std::addressof(mCacheSender->getCommState());
}

void CacheTransceiver::setContextState(LlmRequest* llmRequest)
{
    TLLM_CHECK(llmRequest && llmRequest->isContextOnlyRequest());
    auto contextState = std::make_unique<executor::DataTransceiverState>();
    contextState->setCommState(*mCommState);
    contextState->setCacheState(*mCacheState);
    if (!llmRequest->hasDraftTokens())
    {
        llmRequest->setContextPhaseParams(
            executor::ContextPhaseParams{{}, llmRequest->mRequestId, contextState.release(), std::nullopt});
    }
    else
    {
        llmRequest->setContextPhaseParams(executor::ContextPhaseParams{
            {}, llmRequest->mRequestId, contextState.release(), *llmRequest->getDraftTokens()});
    }
}

void CacheTransceiver::respondAndSendAsync(LlmRequest* llmRequest)
{
    TLLM_CHECK(llmRequest && llmRequest->isContextOnlyRequest());
    llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_TRANS_IN_PROGRESS);
    // If the context phase params are already set, the KV cache transfer is already in progress.
    if (llmRequest->getContextPhaseParams().has_value())
    {
        if (llmRequest->getContextProgress() == nullptr)
        {
            TLLM_LOG_WARNING("Request %ld is already responding", llmRequest->mRequestId);
        }
        return;
    }
    setContextState(llmRequest);
    auto future = mCacheSender->sendAsync(*llmRequest);
    mSenderFutures.emplace_back(llmRequest, std::move(future));
}

void CacheTransceiver::respondAndSendLayerWise(
    RequestVector const& requests, std::shared_ptr<ContextProgress> const& progress)
{
    for (auto const& llmRequest : requests)
    {
        TLLM_CHECK(llmRequest && llmRequest->isContextOnlyRequest());
        TLLM_CHECK(!llmRequest->getContextPhaseParams().has_value());
        llmRequest->setContextProgress(progress);
        TLLM_LOG_DEBUG("Request %ld is being sent layer-wise.", llmRequest->mRequestId);
        llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS);

        setContextState(llmRequest.get());
        auto future = mCacheSender->sendAsync(*llmRequest);
        mSenderFutures.emplace_back(llmRequest.get(), std::move(future));
    }
}

void CacheTransceiver::requestAndReceiveSync(LlmRequest* llmRequest)
{
    TLLM_CHECK(llmRequest && llmRequest->isGenerationOnlyRequest());
    {
        auto future = mCacheReceiver->receiveAsync(*llmRequest);
        future.get();
    }
    llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);
}

void CacheTransceiver::requestAndReceiveAsync(LlmRequest* llmRequest)
{
    TLLM_CHECK(llmRequest && llmRequest->isGenerationOnlyRequest());

    if (std::find_if(mRequesterFutures.begin(), mRequesterFutures.end(),
            [llmRequest](auto const& pair) { return pair.first->mRequestId == llmRequest->mRequestId; })
        != mRequesterFutures.end())
    {
        TLLM_LOG_WARNING("Request ID %zu is already in mRequesterFutures.", llmRequest->mRequestId);
        return;
    }

    auto future = mCacheReceiver->receiveAsync(*llmRequest);
    mRequesterFutures.emplace_back(llmRequest, std::move(future));
    llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS);
}

std::vector<LlmRequest::RequestIdType> gatherRequestIds(
    std::shared_ptr<CacheTransceiverComm> const& mComm, std::vector<LlmRequest::RequestIdType> const& requestIds)
{
    int localSize = static_cast<int>(requestIds.size());
    std::vector<int> sizes(mComm->getSize());
    std::vector<LlmRequest::RequestIdType> retData;
    if (useMPI())
    {
        mComm->allgather(&localSize, sizes.data(), 1, mpi::MpiType::kINT32);
        std::vector<int> displs(mComm->getSize());
        size_t totalSize = 0;
        for (int i = 0; i < mComm->getSize(); i++)
        {
            displs[i] = totalSize;
            totalSize += sizes[i];
        }
        retData.resize(totalSize);
        mComm->allgatherv(requestIds.data(), static_cast<int>(requestIds.size()), mpi::MpiType::kUINT64,
            retData.data(), sizes, displs, mpi::MpiType::kUINT64);
    }
    else
    {
        mComm->allgather(&localSize, std::ref(sizes), {});
        size_t totalSize = std::accumulate(sizes.begin(), sizes.end(), 0);
        retData.resize(totalSize);
        mComm->allgatherv(std::ref(requestIds), std::ref(retData), std::cref(sizes), {});
    }
    return retData;
}

void updateKVCacheTransferBW(std::shared_ptr<CacheTransceiverComm> const& mComm, LlmRequest* request)
{
    namespace su = executor::serialize_utils;

    int worldSize = mComm->getSize();
    std::ostringstream oStream;
    su::serialize(request->getKvCacheTransferStart(), oStream);
    su::serialize(request->getKvCacheTransferEnd(), oStream);
    auto str = oStream.str();
    std::vector<char> sendBuffer(str.begin(), str.end());
    auto sendBufferSize = sendBuffer.size();
    auto recvBufferSize = sendBufferSize * worldSize;
    std::vector<char> recvBuffer(recvBufferSize);
    if (useMPI())
    {
        mComm->allgather(sendBuffer.data(), recvBuffer.data(), sendBufferSize, mpi::MpiType::kCHAR);
    }
    else
    {
        mComm->allgather(std::ref(sendBuffer), std::ref(recvBuffer), {});
    }

    su::VectorWrapBuf<char> strbuf(recvBuffer);
    std::istream is(&strbuf);
    auto minStartTime = executor::RequestPerfMetrics::TimePoint::max();
    auto maxEndTime = executor::RequestPerfMetrics::TimePoint::min();
    for (int rank = 0; rank < worldSize; rank++)
    {
        minStartTime = std::min(su::deserialize<executor::RequestPerfMetrics::TimePoint>(is), minStartTime);
        maxEndTime = std::max(su::deserialize<executor::RequestPerfMetrics::TimePoint>(is), maxEndTime);
    }

    // Handle the KV cache size separately - gather all sizes to the leader rank
    std::size_t localKVCacheSize = request->getKvCacheSize();
    std::vector<std::size_t> allKVCacheSizes(worldSize, 0);
    if (useMPI())
    {
        mComm->allgather(&localKVCacheSize, allKVCacheSizes.data(), 1, mpi::MpiType::kUINT64);
    }
    else
    {
        mComm->allgather(&localKVCacheSize, std::ref(allKVCacheSizes), {});
    }
    std::size_t totalKVCacheSize = 0;
    for (int rank = 0; rank < worldSize; rank++)
    {
        totalKVCacheSize += allKVCacheSizes[rank];
    }

    // Update the latest KV cache transfer time on the leader rank
    if (mComm->getRank() == 0)
    {
        request->setKvCacheTransferStart(minStartTime);
        request->setKvCacheTransferEnd(maxEndTime);
        request->setKvCacheSize(totalKVCacheSize);
    }
}

void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
{
    bool blockAll = !atLeastRequestNum.has_value();
    std::optional<int> senderFutureTimeoutMs = std::nullopt;
    // If blockAll is true, we want to block and not use a timeout
    if (!blockAll && mCacheTransceiverConfig.has_value())
    {
        senderFutureTimeoutMs = mCacheTransceiverConfig->getKvTransferSenderFutureTimeoutMs();
    }
    auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupTPInDPComm : mGroupTensorParaComm;
    std::vector<LlmRequest::RequestIdType> contextCompleteRequestIds;
    for (auto&& [request, future] : mSenderFutures)
    {
        if (future.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready)
        {
            contextCompleteRequestIds.push_back(request->mRequestId);
        }
    }
    std::unordered_map<LlmRequest::RequestIdType, int> frequencyMap;
    if ((syncComm) && syncComm->getSize() > 1)
    {
        auto gatherRequestIdVec = gatherRequestIds(syncComm, contextCompleteRequestIds);
        for (auto&& requestId : gatherRequestIdVec)
        {
            frequencyMap[requestId]++;
        }
    }
    else
    {
        for (auto&& requestId : contextCompleteRequestIds)
        {
            frequencyMap[requestId]++;
        }
    }
    std::vector<std::pair<LlmRequest::RequestIdType, int>> freqVec(frequencyMap.begin(), frequencyMap.end());
    std::sort(freqVec.begin(), freqVec.end(),
        [](std::pair<LlmRequest::RequestIdType, int> const& left,
            std::pair<LlmRequest::RequestIdType, int> const& right) { return left.second > right.second; });
    std::unordered_set<LlmRequest::RequestIdType> toCompleteIdSet;
    for (auto&& [requestId, freq] : freqVec)
    {
        if (freq == ((syncComm) ? syncComm->getSize() : 1))
        {
            toCompleteIdSet.insert(requestId);
        }
    }
    // Make sure there are at least atLeastRequestNum requests in toCompleteIdSet.
    // This preserves the insertion order of KV cache transfer requests.
    for (auto it = mSenderFutures.begin();
         atLeastRequestNum.value_or(0) > static_cast<int>(toCompleteIdSet.size()) && it != mSenderFutures.end(); ++it)
    {
        auto& [request, future] = *it;
        toCompleteIdSet.insert(request->mRequestId);
    }

    // Complete all the requests in toCompleteIdSet
    for (auto it = mSenderFutures.begin(); it != mSenderFutures.end();)
    {
        auto& [request, future] = *it;
        if (blockAll || (toCompleteIdSet.find(request->mRequestId) != toCompleteIdSet.end()))
        {
            try
            {
                // Wait for up to the specified timeout
                auto status = future.wait_for(std::chrono::milliseconds(senderFutureTimeoutMs.value_or(0)));
                if (status == std::future_status::ready || !senderFutureTimeoutMs.has_value())
                {
                    future.get();
                    request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
                    it = mSenderFutures.erase(it);
                }
                else if (status == std::future_status::timeout)
                {
                    TLLM_LOG_WARNING("Timed out waiting for context KV cache transfer after %d milliseconds.",
                        senderFutureTimeoutMs.value());
                    ++it;
                }
                else
                {
                    TLLM_LOG_ERROR(
                        "Future returned unexpected status for request %ld. Marking as error.", request->mRequestId);
                    request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
                    it = mSenderFutures.erase(it);
                }
            }
            catch (std::exception const& e)
            {
                TLLM_LOG_ERROR(
                    "Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what());
                request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
                it = mSenderFutures.erase(it);
            }
        }
        else
        {
            ++it;
        }
    }
}

void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastRequestNum)
{
    bool blockAll = !atLeastRequestNum.has_value();
    std::vector<LlmRequest::RequestIdType> genTransferReadyRequestIds;
    for (auto&& [request, future] : mRequesterFutures)
    {
        if (future.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready)
        {
            genTransferReadyRequestIds.push_back(request->mRequestId);
        }
    }
    std::unordered_map<LlmRequest::RequestIdType, int> frequencyMap;
    std::vector<LlmRequest::RequestIdType> toBlockRequestIds;
    auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
    if ((syncComm) && syncComm->getSize() > 1)
    {
        auto gatherRequestIdVec = gatherRequestIds(syncComm, genTransferReadyRequestIds);
        for (auto&& requestId : gatherRequestIdVec)
        {
            frequencyMap[requestId]++;
        }
    }
    else
    {
        for (auto&& requestId : genTransferReadyRequestIds)
        {
            frequencyMap[requestId]++;
        }
    }
    std::vector<std::pair<LlmRequest::RequestIdType, int>> freqVec(frequencyMap.begin(), frequencyMap.end());
    std::sort(freqVec.begin(), freqVec.end(),
        [](std::pair<LlmRequest::RequestIdType, int> const& left,
            std::pair<LlmRequest::RequestIdType, int> const& right) { return left.second > right.second; });
    std::unordered_set<LlmRequest::RequestIdType> toCompleteIdSet;
    size_t idx = 0;
    while (atLeastRequestNum.value_or(0) > static_cast<int>(toCompleteIdSet.size()))
    {
        if (idx >= freqVec.size())
        {
            break;
        }
        toCompleteIdSet.insert(freqVec.at(idx).first);
        if (useMPI())
        {
            TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
                " checkGenTransferStatus at least from freqVec requestId: %zu ", freqVec.at(idx).first);
        }
        else
        {
            TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
                " checkGenTransferStatus at least from freqVec requestId: %zu ", freqVec.at(idx).first);
        }
        idx++;
    }
    idx = 0;
    // Fill up the remainder in insertion order
    while (atLeastRequestNum.value_or(0) > static_cast<int>(toCompleteIdSet.size()))
    {
        if (idx >= mRequesterFutures.size())
        {
            break;
        }
        if (toCompleteIdSet.find(mRequesterFutures.at(idx).first->mRequestId) == toCompleteIdSet.end())
        {
            toCompleteIdSet.insert(mRequesterFutures.at(idx).first->mRequestId);
            if (useMPI())
            {
                TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
                    " checkGenTransferStatus at least from RequesterFuture requestId: %zu atLeastRequestNum:%d",
                    mRequesterFutures.at(idx).first->mRequestId, atLeastRequestNum.value_or(0));
            }
            else
            {
                TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
                    " checkGenTransferStatus at least from RequesterFuture requestId: %zu atLeastRequestNum:%d",
                    mRequesterFutures.at(idx).first->mRequestId, atLeastRequestNum.value_or(0));
            }
        }
        idx++;
    }
    for (auto&& [requestId, freq] : freqVec)
    {
        if (freq == ((syncComm != nullptr) ? syncComm->getSize() : 1))
        {
            toCompleteIdSet.insert(requestId);
        }
        if (useMPI())
        {
            TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " checkGenTransferStatus freqVec requestId: %zu,freq:%d ",
                requestId, freq);
        }
        else
        {
            TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
                " checkGenTransferStatus freqVec requestId: %zu,freq:%d ", requestId, freq);
        }
    }
    if (useMPI())
    {
        TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
            " checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ", toCompleteIdSet.size(),
            atLeastRequestNum.value_or(0));
    }
    else
    {
        TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
            " checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ", toCompleteIdSet.size(),
            atLeastRequestNum.value_or(0));
    }
    for (auto it = mRequesterFutures.begin(); it != mRequesterFutures.end();)
    {
        if (blockAll || toCompleteIdSet.find(it->first->mRequestId) != toCompleteIdSet.end())
        {
            try
            {
                it->second.get();
                it->first->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);
                // Gather the KV cache transfer time from all workers and update it on the leader rank
                if (!common::getEnvKVCacheTimeOutputPath().empty())
                {
                    auto syncComm
                        = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
                    updateKVCacheTransferBW(syncComm, it->first);
                }
            }
            catch (std::exception const& e)
            {
                TLLM_LOG_ERROR("Error occurred during generation transfer for request %ld: %s",
                    it->first->mRequestId, e.what());
                it->first->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
            }
            if (useMPI())
            {
                TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
                    "**** it->first->mRequestId: %ld, context request ID: %ld ******** get future ***",
                    it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
            }
            else
            {
                TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
                    "**** it->first->mRequestId: %ld, context request ID: %ld ******** get future ***",
                    it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
            }
            it = mRequesterFutures.erase(it);
        }
        else
        {
            ++it;
        }
    }
}

bool CacheTransceiver::checkGenTransferComplete() const
{
    return mRequesterFutures.empty();
}

bool CacheTransceiver::cancelRequest(LlmRequest* llmRequest)
{
    if (llmRequest->isContextOnlyRequest())
    {
        return mCacheSender->cancelRequest(*llmRequest);
    }
    else if (llmRequest->isGenerationOnlyRequest())
    {
        return mCacheReceiver->cancelRequest(*llmRequest);
    }
    return false;
}

} // namespace tensorrt_llm::batch_manager