diff --git a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h index 934cb39972..98296c8a03 100644 --- a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h +++ b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h @@ -34,8 +34,14 @@ namespace tensorrt_llm::batch_manager class ContextProgress; class BaseCacheTransceiver; -class DataResponder; -class DataRequester; + +namespace kv_cache_manager +{ +class BaseKVCacheManager; +} // namespace kv_cache_manager + +class CacheSender; +class CacheReceiver; class CacheTransceiverFactory { @@ -110,9 +116,9 @@ private: void setContextState(LlmRequest* llmRequest); - std::unique_ptr mDataResponder; - std::unique_ptr mDataRequester; - std::vector>> mResponderFutures; + std::unique_ptr mCacheSender; + std::unique_ptr mCacheReceiver; + std::vector>> mSenderFutures; std::vector>> mRequesterFutures; mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr}; std::shared_ptr mMpiGroupTensorParaComm, mMpiGroupPipeParaComm, mMpiGroupDataComm, diff --git a/cpp/tensorrt_llm/batch_manager/CMakeLists.txt b/cpp/tensorrt_llm/batch_manager/CMakeLists.txt index 5f7d774c0b..b0e5b2ddf6 100644 --- a/cpp/tensorrt_llm/batch_manager/CMakeLists.txt +++ b/cpp/tensorrt_llm/batch_manager/CMakeLists.txt @@ -24,7 +24,6 @@ set(SRCS createNewDecoderRequests.cpp contextProgress.cpp dataTransceiver.cpp - dataTransceiverImpl.cpp decoderBuffers.cpp encoderBuffers.cpp guidedDecoder.cpp diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index b5a68b74b4..306cd64187 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -19,6 +19,7 @@ #include "mlaCacheFormatter.h" #include "tensorrt_llm/batch_manager/contextProgress.h" +#include "tensorrt_llm/batch_manager/dataTransceiver.h" #include "tensorrt_llm/batch_manager/kvCacheUtils.h" #include 
"tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" @@ -154,7 +155,7 @@ std::vector CacheFormatter::pickRecvConnections( return ret; } -void CacheFormatter::format(TransferSession& session) +void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session) { NVTX3_SCOPED_RANGE(CacheFormatter_format); auto const& llmRequest = session.getLlmRequest(); @@ -468,7 +469,7 @@ void CacheFormatter::format(TransferSession& session) mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID:%ld ", llmRequest.mRequestId); } -void CacheFormatter::unformat(TransferSession& session) +void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session) { NVTX3_SCOPED_RANGE(CacheFormatter_unformat); auto const& llmRequest = session.getLlmRequest(); diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h index 8ae8ee5f2c..0071627af6 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h @@ -18,11 +18,12 @@ #pragma once #include "cacheTransBuffer.h" -#include "dataTransceiver.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/batch_manager/kvCacheUtils.h" +#include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/envUtils.h" #include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/executor/cacheCommunicator.h" #include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h" #include "tensorrt_llm/executor/dataTransceiverState.h" #include "tensorrt_llm/runtime/bufferManager.h" @@ -30,10 +31,25 @@ #include #include #include +#include +#include + +// Forward declare TransferSession in the correct global namespace scope +namespace tensorrt_llm::batch_manager +{ +class TransferSession; +} namespace tensorrt_llm::batch_manager::kv_cache_manager { +using DataContext = tensorrt_llm::executor::kv_cache::DataContext; +using Connection = 
tensorrt_llm::executor::kv_cache::Connection; +using SizeType32 = tensorrt_llm::runtime::SizeType32; +using BaseKVCacheManager = kv_cache_manager::BaseKVCacheManager; +using CacheTransBufferManager = kv_cache_manager::CacheTransBufferManager; +using BlockRange = kv_cache_manager::BlockRange; + BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest); BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest); @@ -42,16 +58,15 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques class BaseCacheFormatter { public: - using SizeType32 = tensorrt_llm::runtime::SizeType32; using CacheState = executor::kv_cache::CacheState; /// @brief Format the cache data into bytes for sending. /// @param session The transfer session. - virtual void format(TransferSession& session) = 0; + virtual void format(tensorrt_llm::batch_manager::TransferSession& session) = 0; /// @brief Unformat the cache data from received bytes. /// @param session The transfer session. - virtual void unformat(TransferSession& session) = 0; + virtual void unformat(tensorrt_llm::batch_manager::TransferSession& session) = 0; /// @brief Determine whether the sender is applicable to the source and target. /// @param selfConfig Source data arrangement. 
@@ -91,9 +106,9 @@ public: TLLM_CHECK(mCacheTransBufferManager); } - void format(TransferSession& session) override; + void format(tensorrt_llm::batch_manager::TransferSession& session) override; - void unformat(TransferSession& session) override; + void unformat(tensorrt_llm::batch_manager::TransferSession& session) override; [[nodiscard]] bool inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const override; diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp index 2cd8742877..74af6ff90a 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp @@ -37,8 +37,9 @@ #include "tensorrt_llm/batch_manager/cacheFormatter.h" #include "tensorrt_llm/batch_manager/cacheTransceiver.h" #include "tensorrt_llm/batch_manager/contextProgress.h" -#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" +#include "tensorrt_llm/batch_manager/kvCacheType.h" +#include "tensorrt_llm/batch_manager/kvCacheUtils.h" #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/batch_manager/mlaCacheFormatter.h" #include "tensorrt_llm/common/envUtils.h" @@ -116,7 +117,6 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa : mMpiGroupComm(std::addressof(tensorrt_llm::mpi::MpiComm::session())) , mCacheTransceiverConfig{cacheTransceiverConfig} { - using tensorrt_llm::batch_manager::kv_cache_manager::CacheFormatter; if (worldConfig.isPipelineParallel()) { mMpiGroupPipeParaComm = std::make_shared( @@ -200,14 +200,12 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa TLLM_THROW("Unsupported cache transceiver backend type "); } - using tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter; auto makeFormatter = [cacheManager, isMLA, this]() { return createCacheFormatter(cacheManager, 
mCacheTransBufferManager.get(), isMLA); }; - mDataResponder = std::make_unique( - std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter())); - mDataRequester = std::make_unique( - std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter())); + mCacheSender = std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()); + mCacheReceiver + = std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()); initializeCommState(); } @@ -223,7 +221,7 @@ CacheTransceiver::~CacheTransceiver() void CacheTransceiver::initializeCommState() { - mCommState = std::addressof(mDataResponder->getCommState()); + mCommState = std::addressof(mCacheSender->getCommState()); } void CacheTransceiver::setContextState(LlmRequest* llmRequest) @@ -259,8 +257,8 @@ void CacheTransceiver::respondAndSendAsync(LlmRequest* llmRequest) return; } setContextState(llmRequest); - auto future = mDataResponder->respondAndSendAsync(*llmRequest); - mResponderFutures.emplace_back(llmRequest, std::move(future)); + auto future = mCacheSender->sendAsync(*llmRequest); + mSenderFutures.emplace_back(llmRequest, std::move(future)); } void CacheTransceiver::respondAndSendLayerWise( @@ -275,8 +273,8 @@ void CacheTransceiver::respondAndSendLayerWise( llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS); setContextState(llmRequest.get()); - auto future = mDataResponder->respondAndSendAsync(*llmRequest); - mResponderFutures.emplace_back(llmRequest.get(), std::move(future)); + auto future = mCacheSender->sendAsync(*llmRequest); + mSenderFutures.emplace_back(llmRequest.get(), std::move(future)); } } @@ -284,7 +282,7 @@ void CacheTransceiver::requestAndReceiveSync(LlmRequest* llmRequest) { TLLM_CHECK(llmRequest && llmRequest->isGenerationOnlyRequest()); { - auto future = mDataRequester->requestAndReceiveAsync(*llmRequest); + auto future = mCacheReceiver->receiveAsync(*llmRequest); future.get(); } 
llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE); @@ -302,7 +300,7 @@ void CacheTransceiver::requestAndReceiveAsync(LlmRequest* llmRequest) return; } - auto future = mDataRequester->requestAndReceiveAsync(*llmRequest); + auto future = mCacheReceiver->receiveAsync(*llmRequest); mRequesterFutures.emplace_back(llmRequest, std::move(future)); llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS); } @@ -382,7 +380,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional const& atLe bool blockAll = !atLeastRequestNum.has_value(); auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mMpiGroupTPInDPComm : mMpiGroupTensorParaComm; std::vector contextCompleteRequestIds; - for (auto&& [request, future] : mResponderFutures) + for (auto&& [request, future] : mSenderFutures) { if (future.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready) { @@ -422,16 +420,15 @@ void CacheTransceiver::checkContextTransferStatus(std::optional const& atLe // Make sure there are at least atLeastRequestNum requests in toCompleteIdSet. // This will preserve the order of insertion for KVCache transfer requests. 
- for (auto it = mResponderFutures.begin(); - atLeastRequestNum.value_or(0) > static_cast(toCompleteIdSet.size()) && it != mResponderFutures.end(); - ++it) + for (auto it = mSenderFutures.begin(); + atLeastRequestNum.value_or(0) > static_cast(toCompleteIdSet.size()) && it != mSenderFutures.end(); ++it) { auto& [request, future] = *it; toCompleteIdSet.insert(request->mRequestId); } // Complete all the requests in toCompleteIdSet - for (auto it = mResponderFutures.begin(); it != mResponderFutures.end();) + for (auto it = mSenderFutures.begin(); it != mSenderFutures.end();) { auto& [request, future] = *it; if (blockAll || (toCompleteIdSet.find(request->mRequestId) != toCompleteIdSet.end())) @@ -447,7 +444,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional const& atLe "Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what()); request->setState(LlmRequestState::kDISAGG_TRANS_ERROR); } - it = mResponderFutures.erase(it); + it = mSenderFutures.erase(it); } else { diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index 8b4a8773d1..3dbf65c42a 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -17,6 +17,7 @@ #include "dataTransceiver.h" +#include "tensorrt_llm/batch_manager/cacheFormatter.h" #include "tensorrt_llm/batch_manager/common.h" #include "tensorrt_llm/batch_manager/kvCacheUtils.h" #include "tensorrt_llm/batch_manager/runtimeBuffers.h" @@ -24,6 +25,7 @@ #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/tllmException.h" #include "tensorrt_llm/common/utils.h" +#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" #include #include @@ -33,8 +35,138 @@ namespace tensorrt_llm::batch_manager { -using kv_cache_manager::BlockRange; +using BlockRange = 
tensorrt_llm::batch_manager::kv_cache_manager::BlockRange; + +std::vector const& TransferSession::getConnections() const +{ + return mConnections; +} + +void TransferSession::setConnection(size_t idx, Connection const* conn) +{ + mConnections.at(idx) = conn; +} + +DataContext const& TransferSession::getDataContext() const +{ + return mDataContext; +} + +executor::DataTransceiverState const& TransferSession::getSelfState() const +{ + return *mSelfState; +} + +executor::DataTransceiverState const& TransferSession::getOtherState() const +{ + return mOtherState; +} + +runtime::BufferManager const& TransferSession::getBufferManager() const +{ + return *mBufferManager; +} + +void TransferSession::send(size_t idx, void const* data, size_t size) +{ + try + { + mConnections.at(idx)->send(mDataContext, data, size); + } + catch (std::exception const& e) + { + throw common::RequestSpecificException( + __FILE__, __LINE__, e.what(), mRequest->mRequestId, common::RequestErrorCode::kNETWORK_ERROR); + } +} + +void TransferSession::recv(size_t idx, void* data, size_t size) +{ + try + { + mConnections.at(idx)->recv(mDataContext, data, size); + } + catch (std::exception const& e) + { + throw common::RequestSpecificException( + __FILE__, __LINE__, e.what(), mRequest->mRequestId, common::RequestErrorCode::kNETWORK_ERROR); + } +} + +LlmRequest const& TransferSession::getLlmRequest() const +{ + TLLM_CHECK(mRequest != nullptr); + return *mRequest; +} + +void TransferSession::setLlmRequest(LlmRequest const& llmRequest) +{ + mRequest = &llmRequest; +} + +void TransferSession::appendMeasure(double delay, double duration, size_t size) +{ + if (!mRecordMeasure) + { + return; + } + auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps + mMeasures.emplace_back(Measure{delay, duration, bandwidth}); +} + +void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const +{ + if (mMeasures.empty()) + { + return; + } + // write header if not exist + if 
(outFile.tellp() == 0) + { + outFile << "RequestID"; + for (size_t i = 0; i < mMeasures.size(); i++) + { + outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)"; + } + outFile << '\n'; + } + // write measures + TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value()); + auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId(); + outFile << reqId; + for (auto const& measure : mMeasures) + { + outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth; + } + outFile << '\n' << std::flush; +} + +std::vector const& RequestInfo::getBlockHashes() const noexcept +{ + return mBlockHashes; +} + using runtime::SizeType32; +using AgentConnectionManager = tensorrt_llm::executor::kv_cache::AgentConnectionManager; +using DataContext = tensorrt_llm::executor::kv_cache::DataContext; + +static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId) +{ + constexpr int32_t kDATA_TAG{43}; + return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF); +} + +struct ReceiveCacheResource +{ + runtime::BufferManager mBufferManager; + runtime::CudaEvent mCudaEvent; + + ReceiveCacheResource(runtime::BufferManager&& bufferManager, runtime::CudaEvent&& cudaEvent) + : mBufferManager(bufferManager) + , mCudaEvent(std::move(cudaEvent)) + { + } +}; RequestInfo::RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTransceiverState transState) : mRequestId{requestId} @@ -92,82 +224,128 @@ std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo) return totalSize; } -void TransferSession::appendMeasure(double delay, double duration, size_t size) -{ - if (!mRecordMeasure) - { - return; - } - auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps - mMeasures.emplace_back(Measure{delay, duration, bandwidth}); -} - -void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const -{ - if (mMeasures.empty()) - { - return; - } - // write header if not exist - if 
(outFile.tellp() == 0) - { - outFile << "RequestID"; - for (size_t i = 0; i < mMeasures.size(); i++) - { - outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)"; - } - outFile << '\n'; - } - // write measures - TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value()); - auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId(); - outFile << reqId; - for (auto const& measure : mMeasures) - { - outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth; - } - outFile << '\n' << std::flush; -} - -class DataResponder::Impl +class CacheSender::Impl { public: using RequestIdType = LlmRequest::RequestIdType; - Impl(std::unique_ptr sender) - : mSender{std::move(sender)} + Impl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, + SizeType32 selfIndex, std::unique_ptr formatter) + : mManager{manager} + , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}} + , mFormatter{std::move(formatter)} + , mBufferManager{std::make_shared()} { - TLLM_CHECK(mSender); + TLLM_CHECK(mManager); + TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex); TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId)); mCurrentRequest = std::nullopt; mResponseFuture = std::async(std::launch::async, &Impl::response, this); } - [[nodiscard]] std::future respondAndSendAsync(LlmRequest& llmRequest) + [[nodiscard]] std::future sendAsync(LlmRequest& llmRequest) { std::promise promise; auto future = promise.get_future(); { { - std::unique_lock lkResp(mResponderMutex); + std::unique_lock lkResp(mSenderMutex); mReadyResponses.emplace( llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)}); } std::unique_lock lkCond(mCondMutex); mAnyReady = true; } - mResponderCv.notify_all(); + mSenderCv.notify_all(); return future; } [[nodiscard]] executor::kv_cache::CommState const& getCommState() const { - return mSender->getCommState(); 
+ return mSelfState.getCommState().value(); } void setCommState(executor::kv_cache::CommState commState) { - mSender->setCommState(std::move(commState)); + mSelfState.setCommState(std::move(commState)); + } + + [[nodiscard]] size_t getCounterpartsCount(LlmRequest::RequestIdType requestId) const + { + auto it = mRequestToSession.find(requestId); + TLLM_CHECK(it != mRequestToSession.end()); + return it->second.getConnections().size(); + } + + void release(LlmRequest::RequestIdType requestId) + { + auto it = mRequestToSession.find(requestId); + TLLM_CHECK(it != mRequestToSession.end()); + std::unique_lock lk(mMtxForMap); + mRequestToSession.erase(it); + } + + [[nodiscard]] RequestInfo recvRequestInfo() + { + auto* agentConnectionManager = dynamic_cast(mManager); + bool isAgent = agentConnectionManager != nullptr; + + auto agentRecvFun = [&](RequestInfo& requestInfo) + { + auto const* connection = agentConnectionManager->recvConnectionAndRequestInfo(requestInfo); + return connection; + }; + TransceiverTag::Id id; + RequestInfo info; + auto const* connection = isAgent ? 
agentRecvFun(info) + : mManager->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id)); + if (!isAgent) + { + TLLM_CHECK(id == TransceiverTag::Id::REQUEST_SEND); + std::uint64_t infoSize{0}; + connection->recv( + executor::kv_cache::DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); + std::string serializedInfo; + serializedInfo.resize(infoSize); + connection->recv( + executor::kv_cache::DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize); + std::istringstream iss(serializedInfo); + info = RequestInfo::deserialize(iss); + } + + auto requestId = info.getRequestId(); + TLLM_CHECK_WITH_INFO(mFormatter->inquireSupport( + mSelfState.getCacheState().value(), info.getTransState().getCacheState().value()), + "Disagg server does not currently support these cacheState, please check the cacheState of the context and " + "gen " + "executors"); + auto peerRelativeRanks = executor::kv_cache::targetIRanks(info.getTransState().getCacheState().value(), + mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx()) + .mIRanks; + int peerIdx = std::distance(peerRelativeRanks.begin(), + std::find( + peerRelativeRanks.begin(), peerRelativeRanks.end(), info.getTransState().getCommState()->getSelfIdx())); + { + std::unique_lock lk(mMtxForMap); + auto it = mRequestToSession.find(requestId); + if (it == mRequestToSession.end()) + { + auto session = TransferSession(std::vector(peerRelativeRanks.size(), nullptr), + DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager); + it = mRequestToSession.emplace(requestId, std::move(session)).first; + } + it->second.setConnection(peerIdx, connection); + } + return info; + } + + void sendSync(LlmRequest const& llmRequest) + { + auto it = mRequestToSession.find(llmRequest.mRequestId); + TLLM_CHECK(it != mRequestToSession.end()); + auto& session = it->second; + session.setLlmRequest(llmRequest); + mFormatter->format(session); } ~Impl() @@ 
-187,8 +365,8 @@ private: try { TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId)); - mSender->sendSync(*resp.mRequest); - mSender->release(id); + sendSync(*resp.mRequest); + release(id); resp.mPromise.set_value(); } catch (tensorrt_llm::common::RequestSpecificException const& e) @@ -220,12 +398,11 @@ private: if (common::getEnvParallelCacheSend()) { // TODO: Use a thread pool and check for thread safety. - std::thread(&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second)) - .detach(); + std::thread(&CacheSender::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second)).detach(); } else { - DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second)); + CacheSender::Impl::sendAndRemoveResponse(it->first, std::move(it->second)); } removeResponse(it); } @@ -243,7 +420,7 @@ private: if (!mAnyReady) { std::unique_lock lk(mCondMutex); - mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); }); + mSenderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); }); } if (mTerminate) { @@ -252,14 +429,14 @@ private: std::vector blockHashes; if (!isSending() && !mReadyResponses.empty()) { - auto const& requestInfo = mSender->recvRequestInfo(); + auto const& requestInfo = recvRequestInfo(); auto reqId = requestInfo.getRequestId(); blockHashes = requestInfo.getBlockHashes(); mCurrentRequest = reqId; if (mRemainSendCount.find(reqId) == mRemainSendCount.end()) { - mRemainSendCount[reqId] = mSender->getCounterpartsCount(reqId); + mRemainSendCount[reqId] = getCounterpartsCount(reqId); } } auto it = getCurrentResponse(); @@ -273,7 +450,7 @@ private: while (it == mReadyResponses.end()) { std::unique_lock lk(mCondMutex); - mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); }); + mSenderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); }); if (mTerminate) { break; @@ -286,7 +463,7 @@ private: } catch (std::exception const& err) { - TLLM_LOG_ERROR("Exception in DataResponder response: %s", err.what()); 
+ TLLM_LOG_ERROR("Exception in CacheSender response: %s", err.what()); for (auto& it : mReadyResponses) { it.second.mPromise.set_exception(std::current_exception()); @@ -302,13 +479,13 @@ private: } // We don't have to wait for the future. If another thread is sending data, it won't pay attention // to the terminate flag. - mResponderCv.notify_all(); + mSenderCv.notify_all(); } void removeResponse(std::map::iterator it) { { - std::unique_lock lkResp(mResponderMutex); + std::unique_lock lkResp(mSenderMutex); mReadyResponses.erase(it); } if (mReadyResponses.empty()) @@ -330,36 +507,47 @@ private: [[nodiscard]] std::map::iterator getCurrentResponse() { - std::unique_lock lk(mResponderMutex); + std::unique_lock lk(mSenderMutex); return mReadyResponses.find(getCurrentRequestId()); } private: std::optional mCurrentRequest; std::map mReadyResponses; - std::mutex mResponderMutex, mCondMutex; + std::mutex mSenderMutex, mCondMutex; std::atomic mAnyReady{false}, mTerminate{false}; - std::condition_variable mResponderCv; + std::condition_variable mSenderCv; std::future mResponseFuture; - std::unique_ptr mSender; std::unordered_map mRemainSendCount; int mDeviceId{-1}; + + executor::kv_cache::ConnectionManager* mManager; + std::map mRequestToSession; + executor::DataTransceiverState mSelfState; + std::unique_ptr mFormatter; + std::mutex mMtxForMap; + runtime::BufferManager mBufferManager; }; -class DataRequester::Impl +class CacheReceiver::Impl { public: - Impl(std::unique_ptr receiver) - : mReceiver{std::move(receiver)} + Impl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, + SizeType32 selfIndex, std::unique_ptr formatter) + : mManager{manager} + , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}} + , mFormatter{std::move(formatter)} + , mBufferManager{std::make_shared()} { - TLLM_CHECK(mReceiver); + TLLM_CHECK(mManager); + TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex); 
TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId)); } - [[nodiscard]] std::future requestAndReceiveAsync(LlmRequest& llmRequest) + [[nodiscard]] std::future receiveAsync(LlmRequest& llmRequest) { // TODO: Modify the implementation here to avoid frequent thread creation. - return std::async(std::launch::async, &DataRequester::Impl::requestSync, this, std::ref(llmRequest)); + return std::async(std::launch::async, &CacheReceiver::Impl::requestSync, this, std::ref(llmRequest)); } [[nodiscard]] std::future requestAndReceiveAsyncMultiThreads(LlmRequest& llmRequest) @@ -378,7 +566,7 @@ public: { mInstanceToAsyncResource.emplace(processInfo, std::make_unique()); - auto requestFuture = std::async(std::launch::async, &DataRequester::Impl::request, this, + auto requestFuture = std::async(std::launch::async, &CacheReceiver::Impl::request, this, std::ref(*mInstanceToAsyncResource.at(processInfo))); mRequestFutures.emplace_back(std::move(requestFuture)); } @@ -396,6 +584,107 @@ public: } } + void receiveSync(TransferSession& session) + { + mFormatter->unformat(session); + } + + TransferSession sendRequestInfo(LlmRequest const& llmRequest) + { + uint64_t requestId = llmRequest.getContextPhaseParams().value().getReqId(); + auto const& contextState = llmRequest.getDataTransceiverState(); + auto const& commState = contextState.getCommState().value(); + auto const& destCacheState = contextState.getCacheState().value(); + TLLM_CHECK_WITH_INFO(mFormatter->inquireSupport(mSelfState.getCacheState().value(), destCacheState), + "Disagg server does not currently support these cacheState."); + + RequestInfo requestInfo(requestId, mSelfState); + + auto disableSelectiveCacheTransfer = common::getEnvDisableSelectiveCacheTransfer() + || (mFormatter->getCacheManager()->getBlockManager().getNumPools() > 1); + if (!disableSelectiveCacheTransfer) + { + auto* cacheManager = mFormatter->getCacheManager(); + auto blockRange + = kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, 
llmRequest.mRequestId); + requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState); + } + + auto* agentConnectionManager = dynamic_cast(mManager); + std::optional cacheBufferId = std::nullopt; + if (agentConnectionManager != nullptr) + { + cacheBufferId = agentConnectionManager->getCacheTransBufferManager()->assignBufferIndexForRecv(); + TLLM_CHECK(cacheBufferId.has_value()); + // memory Desp , validSegmentIdx send + } + auto counterParts = mFormatter->getCounterparts( + mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx(), destCacheState); + + auto connections = mManager->getConnections(commState); + std::vector counterPartConnections; + for (auto index : counterParts) + { + auto const* connection = connections.at(index); + counterPartConnections.emplace_back(connection); + } + auto pickUpIdx = mFormatter->pickRecvConnections(counterParts.size(), mSelfState.getCacheState().value(), + mSelfState.getCommState().value().getSelfIdx(), destCacheState); + for (size_t i = 0; i < counterPartConnections.size(); i++) + { + auto const* connection = counterPartConnections[i]; + // if Manager is agentConnectionManager, then send request info to agent + auto* agentConnectionManager = dynamic_cast(mManager); + if (agentConnectionManager != nullptr) + { + // TODO: index -> validConnectionIdx conversion + auto valideConnectionIdx = std::find(pickUpIdx.begin(), pickUpIdx.end(), i) - pickUpIdx.begin(); + auto* agentConnection = dynamic_cast(connection); + TLLM_CHECK(agentConnection != nullptr); + TLLM_CHECK(cacheBufferId.has_value()); + const_cast(agentConnection) + ->sendRequestAndBufferInfo(requestInfo, cacheBufferId, valideConnectionIdx); + } + else + { + sendRequestInfo(connection, requestInfo); + } + } + auto const& resource = getReceiveCacheResource(llmRequest); + return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState, + contextState, resource->mBufferManager, 
&llmRequest); + } + + std::unique_ptr const& getReceiveCacheResource(LlmRequest const& llmRequest) + { + std::scoped_lock lock(mProcessIoResouceMutex); + TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value()); + std::string processString = "default"; + if (common::getEnvRequestKVCacheConcurrent()) + { + processString = llmRequest.getDataTransceiverState().getCommState()->toString(); + } + if (mProcessToResources.find(processString) == mProcessToResources.end()) + { + mProcessToResources.emplace(processString, + std::make_unique( + runtime::BufferManager{std::make_shared()}, runtime::CudaEvent{})); + } + return mProcessToResources.at(processString); + } + + void sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info) + { + std::ostringstream oss; + RequestInfo::serialize(info, oss); + auto const& serializedInfo = oss.str(); + std::size_t const infoSize = serializedInfo.size(); + TransceiverTag::Id id{TransceiverTag::Id::REQUEST_SEND}; + connection->send(executor::kv_cache::DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id)); + connection->send(executor::kv_cache::DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); + connection->send(executor::kv_cache::DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize); + } + ~Impl() { for (auto&& [processInfo, asyncResource] : mInstanceToAsyncResource) @@ -417,8 +706,8 @@ private: llmRequest.getContextPhaseParams().value().getReqId()); llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now()); TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId)); - auto session = mReceiver->sendRequestInfo(llmRequest); - mReceiver->receiveSync(session); + auto session = sendRequestInfo(llmRequest); + receiveSync(session); llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now()); TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), @@ -524,7 +813,7 @@ private: } catch (std::exception const& err) { - TLLM_LOG_ERROR("Exception in DataRequester 
request(): request id:%ld , request context id:%ld : %s", + TLLM_LOG_ERROR("Exception in CacheReceiver request(): request id:%ld , request context id:%ld : %s", requestAndPromise.mRequest->mRequestId, requestAndPromise.mRequest->getContextPhaseParams().value().getReqId(), err.what()); requestAndPromise.mPromise->set_exception(std::current_exception()); @@ -533,45 +822,81 @@ private: } } - std::unique_ptr mReceiver; int mDeviceId{-1}; - std::vector> mRequestFutures; std::unordered_map> mInstanceToAsyncResource; + executor::kv_cache::ConnectionManager* mManager; + executor::DataTransceiverState mSelfState; + std::unique_ptr mFormatter; + std::unordered_map> mProcessToResources; + std::mutex mProcessIoResouceMutex; + runtime::BufferManager mBufferManager; }; -DataResponder::DataResponder(std::unique_ptr sender) - : mImpl{std::make_unique(std::move(sender))} +void CacheSender::ImplDeleter::operator()(Impl* ptr) +{ + delete ptr; +} + +void CacheReceiver::ImplDeleter::operator()(Impl* ptr) +{ + delete ptr; +} + +CacheSender::CacheSender(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, + SizeType32 selfIndex, std::unique_ptr formatter) + : mImpl{std::unique_ptr(new Impl(manager, selfCacheState, selfIndex, std::move(formatter)))} { } -std::future DataResponder::respondAndSendAsync(LlmRequest& llmRequest) const +std::future CacheSender::sendAsync(LlmRequest& llmRequest) const { - return mImpl->respondAndSendAsync(llmRequest); + return mImpl->sendAsync(llmRequest); } -executor::kv_cache::CommState const& DataResponder::getCommState() const +executor::kv_cache::CommState const& CacheSender::getCommState() const { return mImpl->getCommState(); } -void DataResponder::setCommState(executor::kv_cache::CommState commState) +void CacheSender::setCommState(executor::kv_cache::CommState commState) { mImpl->setCommState(std::move(commState)); } -DataResponder::~DataResponder() = default; +CacheSender::~CacheSender() = default; 
-DataRequester::DataRequester(std::unique_ptr receiver) - : mImpl{std::make_unique(std::move(receiver))} +void CacheSender::sendSync(LlmRequest const& llmRequest) +{ + mImpl->sendSync(llmRequest); +} + +RequestInfo CacheSender::recvRequestInfo() +{ + return mImpl->recvRequestInfo(); +} + +CacheReceiver::CacheReceiver(executor::kv_cache::ConnectionManager* manager, + executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr formatter) + : mImpl{std::unique_ptr(new Impl(manager, selfCacheState, selfIndex, std::move(formatter)))} { } -std::future DataRequester::requestAndReceiveAsync(LlmRequest& llmRequest) const +std::future CacheReceiver::receiveAsync(LlmRequest& llmRequest) const { return mImpl->requestAndReceiveAsyncMultiThreads(llmRequest); } -DataRequester::~DataRequester() = default; +CacheReceiver::~CacheReceiver() = default; + +TransferSession CacheReceiver::sendRequestInfo(LlmRequest const& llmRequest) +{ + return mImpl->sendRequestInfo(llmRequest); +} + +void CacheReceiver::receiveSync(TransferSession& session) +{ + mImpl->receiveSync(session); +} } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h index f51a85c484..2de48dc0bc 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/common/assert.h" @@ -28,16 +29,100 @@ #include "tensorrt_llm/executor/cacheCommunicator.h" #include "tensorrt_llm/executor/dataTransceiverState.h" #include "tensorrt_llm/executor/serializeUtils.h" +#include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/cudaEvent.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" namespace tensorrt_llm::batch_manager { +namespace kv_cache_manager +{ +class BaseCacheFormatter; +} + +using BaseCacheFormatter = 
kv_cache_manager::BaseCacheFormatter; + // TODO: unify the following class into a namespace like tensorrt_llm::transmission using DataContext = tensorrt_llm::executor::kv_cache::DataContext; using Connection = tensorrt_llm::executor::kv_cache::Connection; using ConnectionManager = tensorrt_llm::executor::kv_cache::ConnectionManager; +using SizeType32 = tensorrt_llm::runtime::SizeType32; + +class TransferSession +{ +public: + struct Measure + { + double delay; // from last token (ctx) or arrival time (gen), in ms + double duration; // in ms + double bandwidth; // in Gbps + }; + + TransferSession(std::vector connections, DataContext dataContext, + executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState, + runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false) + : mConnections(std::move(connections)) + , mDataContext(dataContext) + , mSelfState(&selfState) + , mOtherState(std::move(otherState)) + , mBufferManager(&bufferManager) + , mRequest(llmRequest) + , mRecordMeasure(recordMeasure) + { + TLLM_CHECK(!mConnections.empty()); + } + + [[nodiscard]] std::vector const& getConnections() const; + + // should be called only during the initialization of the TransferSession + void setConnection(size_t idx, Connection const* conn); + + [[nodiscard]] DataContext const& getDataContext() const; + + [[nodiscard]] executor::DataTransceiverState const& getSelfState() const; + + [[nodiscard]] executor::DataTransceiverState const& getOtherState() const; + + [[nodiscard]] runtime::BufferManager const& getBufferManager() const; + + void send(size_t idx, void const* data, size_t size); + + void recv(size_t idx, void* data, size_t size); + + [[nodiscard]] LlmRequest const& getLlmRequest() const; + + // in CacheSender, the LlmRequest is not available until the sendSync is called + void setLlmRequest(LlmRequest const& llmRequest); + + void appendMeasure(double delay, double duration, size_t size); 
+ + // TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file + void exportMeasure(std::ofstream& outFile, bool isContext) const; + +private: + std::vector mConnections; + DataContext mDataContext; + executor::DataTransceiverState const* mSelfState; // stored in CacheReceiver/CacheSender + executor::DataTransceiverState mOtherState; + runtime::BufferManager const* mBufferManager; + LlmRequest const* mRequest; + std::vector mMeasures; + bool mRecordMeasure{false}; +}; + +struct TransceiverTag +{ + enum class Id : uint64_t + { + REQUEST_SEND = 1, + TERMINATION = 2 + }; + + static constexpr int32_t kID_TAG{19}; + static constexpr int32_t kINFO_SIZE_TAG{22}; + static constexpr int32_t kINFO_TAG{32}; +}; // Used to store the information that needs to be sent to the context executor to ensure the generation // executor smoothly receives the data. @@ -61,10 +146,7 @@ public: /// @return The request ID. [[nodiscard]] LlmRequest::RequestIdType getRequestId() const noexcept; - [[nodiscard]] std::vector const& getBlockHashes() const noexcept - { - return mBlockHashes; - } + [[nodiscard]] std::vector const& getBlockHashes() const noexcept; /// @brief Return the state of the data transceiver. /// @return The state of the data transceiver. 
@@ -94,207 +176,81 @@ private: executor::DataTransceiverState mTransState; }; -class TransferSession -{ -public: - struct Measure - { - double delay; // from last token (ctx) or arrival time (gen), in ms - double duration; // in ms - double bandwidth; // in Gbps - }; - - TransferSession(std::vector connections, DataContext dataContext, - executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState, - runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false) - : mConnections(std::move(connections)) - , mDataContext(dataContext) - , mSelfState(&selfState) - , mOtherState(std::move(otherState)) - , mBufferManager(&bufferManager) - , mRequest(llmRequest) - , mRecordMeasure(recordMeasure) - { - TLLM_CHECK(!mConnections.empty()); - } - - [[nodiscard]] std::vector const& getConnections() const - { - return mConnections; - } - - // should be called only during the initialization of the TransferSession - void setConnection(size_t idx, Connection const* conn) - { - mConnections.at(idx) = conn; - } - - [[nodiscard]] DataContext const& getDataContext() const - { - return mDataContext; - } - - [[nodiscard]] executor::DataTransceiverState const& getSelfState() const - { - return *mSelfState; - } - - [[nodiscard]] executor::DataTransceiverState const& getOtherState() const - { - return mOtherState; - } - - [[nodiscard]] runtime::BufferManager const& getBufferManager() const - { - return *mBufferManager; - } - - void send(size_t idx, void const* data, size_t size) - { - try - { - mConnections.at(idx)->send(mDataContext, data, size); - } - catch (std::exception const& e) - { - throw common::RequestSpecificException( - __FILE__, __LINE__, e.what(), mRequest->mRequestId, common::RequestErrorCode::kNETWORK_ERROR); - } - } - - void recv(size_t idx, void* data, size_t size) - { - try - { - mConnections.at(idx)->recv(mDataContext, data, size); - } - catch (std::exception const& e) - { - throw 
common::RequestSpecificException( - __FILE__, __LINE__, e.what(), mRequest->mRequestId, common::RequestErrorCode::kNETWORK_ERROR); - } - } - - [[nodiscard]] LlmRequest const& getLlmRequest() const - { - TLLM_CHECK(mRequest != nullptr); - return *mRequest; - } - - // in DataSender, the LlmRequest is not available until the sendSync is called - void setLlmRequest(LlmRequest const& llmRequest) - { - mRequest = &llmRequest; - } - - void appendMeasure(double delay, double duration, size_t size); - - // TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file - void exportMeasure(std::ofstream& outFile, bool isContext) const; - -private: - std::vector mConnections; - DataContext mDataContext; - executor::DataTransceiverState const* mSelfState; // stored in DataRequester/DataResponder - executor::DataTransceiverState mOtherState; - runtime::BufferManager const* mBufferManager; - LlmRequest const* mRequest; - bool mRecordMeasure; - std::vector mMeasures; -}; - -// Operators required for data transmission in specific communication protocols. -class DataSender -{ -public: - /// @brief Receive the request information. - /// @return The request information. - [[nodiscard]] virtual RequestInfo recvRequestInfo() = 0; - - /// @brief Synchronously send data. - /// @param llmRequest The request object to which the data belongs. - virtual void sendSync(LlmRequest const& llmRequest) = 0; - - /// @brief Return the internal communicator status. - /// @return The communicator status. - [[nodiscard]] virtual executor::kv_cache::CommState const& getCommState() const = 0; - - /// @brief Reset the internal communicator status. - /// @param commState The communicator status. - virtual void setCommState(executor::kv_cache::CommState commState) = 0; - - [[nodiscard]] virtual size_t getCounterpartsCount(LlmRequest::RequestIdType requestId) const = 0; - - virtual void release(LlmRequest::RequestIdType requestId) = 0; - - /// @brief Destructor. 
- virtual ~DataSender() = default; -}; - -// Operators required for data transmission in specific communication protocols. -class DataReceiver -{ -public: - /// @brief Send the request information. - /// @param llmRequest The request object to which the information belongs. - virtual TransferSession sendRequestInfo(LlmRequest const& llmRequest) = 0; - - /// @brief Synchronously receive data. - /// @param session The transfer session. - virtual void receiveSync(TransferSession& session) = 0; - - /// @brief Destructor. - virtual ~DataReceiver() = default; -}; - -class DataResponder +class CacheSender { public: /// @brief Constructor. - /// @param sender The sender used at the underlying level. - explicit DataResponder(std::unique_ptr sender); + CacheSender(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, + SizeType32 selfIndex, std::unique_ptr formatter); + + CacheSender() = default; /// @brief Asynchronously respond to the request and send data. /// @param llmRequest Request object. Its data should be ready when called, and the data for this request /// should remain valid until future synchronization. /// @return Once the data is fully sent, the future object will become valid. - [[nodiscard]] std::future respondAndSendAsync(LlmRequest& llmRequest) const; + [[nodiscard]] virtual std::future sendAsync(LlmRequest& llmRequest) const; /// @brief Return the internal communicator status. /// @return The communicator status. - [[nodiscard]] executor::kv_cache::CommState const& getCommState() const; + [[nodiscard]] virtual executor::kv_cache::CommState const& getCommState() const; /// @brief Reset the internal communicator status. /// @param commState The communicator status. - void setCommState(executor::kv_cache::CommState commState); + virtual void setCommState(executor::kv_cache::CommState commState); + + /// @brief Synchronously send data. + /// @param llmRequest The request object to which the data belongs. 
+ virtual void sendSync(LlmRequest const& llmRequest); + + /// @brief Receive request information. + /// @return The received request information. + virtual RequestInfo recvRequestInfo(); /// @brief Destructor. - ~DataResponder(); + virtual ~CacheSender(); private: class Impl; - std::unique_ptr mImpl; + + struct ImplDeleter + { + void operator()(Impl*); + }; + + std::unique_ptr mImpl; }; -class DataRequester +class CacheReceiver { public: /// @brief Constructor. - /// @param receiver The receiver used at the underlying level. - explicit DataRequester(std::unique_ptr receiver); + CacheReceiver(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, + SizeType32 selfIndex, std::unique_ptr formatter); + + CacheReceiver() = default; /// @brief Asynchronously send a request to receive data. /// @param llmRequest Request object. Its data should be in an allocated but unwritten state when called, and the /// data for this request should remain intact only after future synchronization. /// @return Once the data is fully received, the future object will become valid. - [[nodiscard]] std::future requestAndReceiveAsync(LlmRequest& llmRequest) const; + [[nodiscard]] std::future receiveAsync(LlmRequest& llmRequest) const; + virtual TransferSession sendRequestInfo(LlmRequest const& llmRequest); + + virtual void receiveSync(TransferSession& session); /// @brief Destructor. 
- ~DataRequester(); + virtual ~CacheReceiver(); private: class Impl; - std::unique_ptr mImpl; + + struct ImplDeleter + { + void operator()(Impl*); + }; + + std::unique_ptr mImpl; }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp deleted file mode 100644 index 1a5c7fab4d..0000000000 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp +++ /dev/null @@ -1,285 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "dataTransceiverImpl.h" - -#include "tensorrt_llm/common/envUtils.h" -#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h" -#include "tensorrt_llm/runtime/utils/mpiUtils.h" - -#include - -namespace tensorrt_llm::batch_manager -{ - -static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId) -{ - constexpr int32_t kDATA_TAG{43}; - return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF); -} - -namespace fs = std::filesystem; - -static fs::path getTransferOutputPath(char const* tag) -{ - auto outputPath = common::getEnvKVCacheTransferOutputPath(); - if (!outputPath.empty()) - { - auto rank = mpi::MpiComm::world().getRank(); - auto path = fs::path(outputPath); - fs::create_directories(path); - return path / ("rank_" + std::to_string(rank) + "_" + tag + ".csv"); - } - return {}; -} - -DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager, - executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr formatter) - : mManager{manager} - , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}} - , mFormatter(std::move(formatter)) - , mBufferManager{std::make_shared()} -{ - TLLM_CHECK(mManager); - TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex); -} - -[[nodiscard]] RequestInfo DataSenderImpl::recvRequestInfo() -{ - using DataContext = tensorrt_llm::executor::kv_cache::DataContext; - auto* agentConnectionManager = dynamic_cast(mManager); - bool isAgent = agentConnectionManager != nullptr; - - auto agentRecvFun = [&](RequestInfo& requestInfo) - { - auto const* connection = agentConnectionManager->recvConnectionAndRequestInfo(requestInfo); - return connection; - }; - Id id; - RequestInfo info; - auto const* connection - = isAgent ? 
agentRecvFun(info) : mManager->recvConnect(DataContext{kID_TAG}, &id, sizeof(id)); - if (!isAgent) - { - TLLM_CHECK(id == Id::REQUEST_SEND); - std::uint64_t infoSize{0}; - connection->recv(executor::kv_cache::DataContext{kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); - std::string serializedInfo; - serializedInfo.resize(infoSize); - connection->recv(executor::kv_cache::DataContext{kINFO_TAG}, serializedInfo.data(), infoSize); - std::istringstream iss(serializedInfo); - info = RequestInfo::deserialize(iss); - } - - auto requestId = info.getRequestId(); - TLLM_CHECK_WITH_INFO( - mFormatter->inquireSupport(mSelfState.getCacheState().value(), info.getTransState().getCacheState().value()), - "Disagg server does not currently support these cacheState, please check the cacheState of the context and gen " - "executors"); - auto peerRelativeRanks = executor::kv_cache::targetIRanks(info.getTransState().getCacheState().value(), - mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx()) - .mIRanks; - int peerIdx = std::distance(peerRelativeRanks.begin(), - std::find( - peerRelativeRanks.begin(), peerRelativeRanks.end(), info.getTransState().getCommState()->getSelfIdx())); - { - std::unique_lock lk(mMtxForMap); - auto it = mRequestToSession.find(requestId); - if (it == mRequestToSession.end()) - { - auto session = TransferSession(std::vector(peerRelativeRanks.size(), nullptr), - DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, nullptr, - !common::getEnvKVCacheTransferOutputPath().empty()); - it = mRequestToSession.emplace(requestId, std::move(session)).first; - } - it->second.setConnection(peerIdx, connection); - } - return info; -} - -void DataSenderImpl::sendSync(LlmRequest const& llmRequest) -{ - auto it = mRequestToSession.find(llmRequest.mRequestId); - TLLM_CHECK(it != mRequestToSession.end()); - auto& session = it->second; - session.setLlmRequest(llmRequest); - mFormatter->format(session); -} - 
-[[nodiscard]] executor::kv_cache::CommState const& DataSenderImpl::getCommState() const -{ - return mSelfState.getCommState().value(); -} - -void DataSenderImpl::setCommState(executor::kv_cache::CommState commState) -{ - mSelfState.setCommState(std::move(commState)); -} - -[[nodiscard]] size_t DataSenderImpl::getCounterpartsCount(LlmRequest::RequestIdType requestId) const -{ - auto it = mRequestToSession.find(requestId); - TLLM_CHECK(it != mRequestToSession.end()); - return it->second.getConnections().size(); -} - -void DataSenderImpl::release(LlmRequest::RequestIdType requestId) -{ - auto it = mRequestToSession.find(requestId); - TLLM_CHECK(it != mRequestToSession.end()); - std::unique_lock lk(mMtxForMap); - if (!common::getEnvKVCacheTransferOutputPath().empty()) - { - if (!mMeasuresFile.is_open()) - { - auto outputPath = getTransferOutputPath("send"); - mMeasuresFile.open(outputPath); - TLLM_CHECK_WITH_INFO( - mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str()); - } - it->second.exportMeasure(mMeasuresFile, true); - } - mRequestToSession.erase(it); -} - -DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manager, - executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr formatter) - : mManager{manager} - , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}} - , mFormatter(std::move(formatter)) -{ - TLLM_CHECK(mManager); - TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex); - TLLM_CHECK(mFormatter); -} - -TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest) -{ - uint64_t requestId = llmRequest.getContextPhaseParams().value().getReqId(); - auto const& contextState = llmRequest.getDataTransceiverState(); - auto const& commState = contextState.getCommState().value(); - auto const& destCacheState = contextState.getCacheState().value(); - 
TLLM_CHECK_WITH_INFO(mFormatter->inquireSupport(mSelfState.getCacheState().value(), destCacheState), - "Disagg server does not currently support these cacheState."); - - RequestInfo requestInfo(requestId, mSelfState); - - auto disableSelectiveCacheTransfer = common::getEnvDisableSelectiveCacheTransfer() - || (mFormatter->getCacheManager()->getBlockManager().getNumPools() > 1); - if (!disableSelectiveCacheTransfer) - { - auto* cacheManager = mFormatter->getCacheManager(); - auto blockRange - = kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId); - requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState); - } - - auto* agentConnectionManager = dynamic_cast(mManager); - std::optional cacheBufferId = std::nullopt; - if (agentConnectionManager != nullptr) - { - cacheBufferId = agentConnectionManager->getCacheTransBufferManager()->assignBufferIndexForRecv(); - TLLM_CHECK(cacheBufferId.has_value()); - // memory Desp , validSegmentIdx send - } - auto counterParts = mFormatter->getCounterparts( - mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx(), destCacheState); - - auto connections = mManager->getConnections(commState); - std::vector counterPartConnections; - for (auto index : counterParts) - { - auto const* connection = connections.at(index); - counterPartConnections.emplace_back(connection); - } - auto pickUpIdx = mFormatter->pickRecvConnections(counterParts.size(), mSelfState.getCacheState().value(), - mSelfState.getCommState().value().getSelfIdx(), destCacheState); - for (size_t i = 0; i < counterPartConnections.size(); i++) - { - auto const* connection = counterPartConnections[i]; - // if Manager is agentConnectionManager, then send request info to agent - auto* agentConnectionManager = dynamic_cast(mManager); - if (agentConnectionManager != nullptr) - { - // TODO: index -> validConnectionIdx conversion - auto valideConnectionIdx = std::find(pickUpIdx.begin(), 
pickUpIdx.end(), i) - pickUpIdx.begin(); - auto* agentConnection = dynamic_cast(connection); - TLLM_CHECK(agentConnection != nullptr); - TLLM_CHECK(cacheBufferId.has_value()); - const_cast(agentConnection) - ->sendRequestAndBufferInfo(requestInfo, cacheBufferId, valideConnectionIdx); - } - else - { - sendRequestInfo(connection, requestInfo); - } - } - auto const& resource = getReceiveCacheResource(llmRequest); - return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState, - contextState, resource->mBufferManager, &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty()); -} - -void DataReceiverImpl::receiveSync(TransferSession& session) -{ - mFormatter->unformat(session); - if (!common::getEnvKVCacheTransferOutputPath().empty()) - { - std::unique_lock lock(mMeasuresFileMutex); - if (!mMeasuresFile.is_open()) - { - auto outputPath = getTransferOutputPath("recv"); - mMeasuresFile.open(outputPath); - TLLM_CHECK_WITH_INFO( - mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str()); - } - session.exportMeasure(mMeasuresFile, false); - } -} - -void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info) -{ - std::ostringstream oss; - RequestInfo::serialize(info, oss); - auto const& serializedInfo = oss.str(); - std::size_t const infoSize = serializedInfo.size(); - Id id{Id::REQUEST_SEND}; - connection->send(executor::kv_cache::DataContext{kID_TAG}, &id, sizeof(id)); - connection->send(executor::kv_cache::DataContext{kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); - connection->send(executor::kv_cache::DataContext{kINFO_TAG}, serializedInfo.data(), infoSize); -} - -std::unique_ptr const& DataReceiverImpl::getReceiveCacheResource( - LlmRequest const& llmRequest) -{ - std::scoped_lock lock(mProcessIoResouceMutex); - TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value()); - std::string processString = 
"default"; - if (common::getEnvRequestKVCacheConcurrent()) - { - processString = llmRequest.getDataTransceiverState().getCommState()->toString(); - } - if (mProcessToResources.find(processString) == mProcessToResources.end()) - { - mProcessToResources.emplace(processString, - std::make_unique( - runtime::BufferManager{std::make_shared()}, runtime::CudaEvent{})); - } - - return mProcessToResources.at(processString); -} - -} // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h deleted file mode 100644 index 2f277f14ff..0000000000 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h +++ /dev/null @@ -1,113 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "cacheFormatter.h" -#include "cacheTransBuffer.h" -#include "dataTransceiver.h" -#include "tensorrt_llm/common/envUtils.h" -#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h" - -#include - -namespace tensorrt_llm::batch_manager -{ -struct TransceiverTag -{ - enum class Id : uint64_t - { - REQUEST_SEND = 1, - TERMINATION = 2 - }; - - static constexpr int32_t kID_TAG{19}; - static constexpr int32_t kINFO_SIZE_TAG{22}; - static constexpr int32_t kINFO_TAG{32}; -}; - -using BaseCacheFormatter = kv_cache_manager::BaseCacheFormatter; - -class DataSenderImpl : public DataSender, public TransceiverTag -{ -public: - using SizeType32 = tensorrt_llm::runtime::SizeType32; - - DataSenderImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, - SizeType32 selfIndex, std::unique_ptr formatter); - - [[nodiscard]] RequestInfo recvRequestInfo() override; - - void sendSync(LlmRequest const& llmRequest) override; - - [[nodiscard]] executor::kv_cache::CommState const& getCommState() const override; - - void setCommState(executor::kv_cache::CommState commState) override; - - [[nodiscard]] size_t getCounterpartsCount(LlmRequest::RequestIdType requestId) const override; - - void release(LlmRequest::RequestIdType requestId) override; - -private: - executor::kv_cache::ConnectionManager* mManager; - std::map mRequestToSession; - executor::DataTransceiverState mSelfState; - std::unique_ptr mFormatter; - std::mutex mMtxForMap; - runtime::BufferManager mBufferManager; - std::ofstream mMeasuresFile; -}; - -class DataReceiverImpl : public DataReceiver, public TransceiverTag -{ -public: - using SizeType32 = tensorrt_llm::runtime::SizeType32; - - DataReceiverImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, - SizeType32 selfIndex, std::unique_ptr formatter); - - TransferSession sendRequestInfo(LlmRequest const& llmRequest) override; - - void 
receiveSync(TransferSession& session) override; - -private: - struct ReceiveCacheResource - { - runtime::BufferManager mBufferManager; - runtime::CudaEvent mCudaEvent; - - ReceiveCacheResource(runtime::BufferManager&& bufferManager, runtime::CudaEvent&& cudaEvent) - : mBufferManager(bufferManager) - , mCudaEvent(std::move(cudaEvent)) - { - } - }; - - static void sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info); - - [[nodiscard]] std::unique_ptr const& getReceiveCacheResource(LlmRequest const& llmRequest); - - executor::kv_cache::ConnectionManager* mManager; - executor::DataTransceiverState mSelfState; - std::unique_ptr mFormatter; - std::unordered_map> mProcessToResources; - std::mutex mProcessIoResouceMutex; - std::ofstream mMeasuresFile; - std::mutex mMeasuresFileMutex; -}; - -} // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index 474a0614d7..b0fba65363 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -84,7 +84,7 @@ bool MLACacheFormatter::needSendCache( return selfTpRank % (selfTPNum / destTPNum) == 0; } -void MLACacheFormatter::format(TransferSession& session) +void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session) { NVTX3_SCOPED_RANGE(MLACacheFormatter_format); auto const& llmRequest = session.getLlmRequest(); @@ -292,7 +292,7 @@ void MLACacheFormatter::format(TransferSession& session) mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID: %ld.", llmRequest.mRequestId); } -void MLACacheFormatter::unformat(TransferSession& session) +void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session) { NVTX3_SCOPED_RANGE(MLACacheFormatter_unformat); auto const& llmRequest = session.getLlmRequest(); diff --git 
a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h index 17c671519a..acaf231363 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h @@ -35,9 +35,9 @@ public: TLLM_CHECK(mCacheTransBufferManager); } - void format(TransferSession& session) override; + void format(tensorrt_llm::batch_manager::TransferSession& session) override; - void unformat(TransferSession& session) override; + void unformat(tensorrt_llm::batch_manager::TransferSession& session) override; [[nodiscard]] bool inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const override; diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index 08cb4d407c..ccab558e90 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -80,6 +80,8 @@ using namespace tensorrt_llm::runtime; namespace tc = tensorrt_llm::common; namespace tk = tensorrt_llm::kernels; +using tensorrt_llm::batch_manager::CacheTransceiverFactory; + namespace tensorrt_llm::batch_manager { diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h index 28d1767525..079d6340d5 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h @@ -55,6 +55,11 @@ namespace tensorrt_llm::mpi class MpiWaitThread; } // namespace tensorrt_llm::mpi +namespace tensorrt_llm::batch_manager +{ +class BaseCacheTransceiver; +} + namespace tensorrt_llm::batch_manager { @@ -79,7 +84,6 @@ class LlmRequest; class RuntimeBuffers; class BasePeftCacheManager; class GuidedDecoder; -class BaseCacheTransceiver; // Algorithms class CapacityScheduler; diff --git 
a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp index a9ba23e414..6ee50ab8e4 100644 --- a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp +++ b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp @@ -104,7 +104,7 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size MemoryDesc srcDesc{ reinterpret_cast(data), size, static_cast(mAgentConnectionManager->getDeviceId())}; MemoryDescs srcDescs{MemoryType::kVRAM, {srcDesc}}; - auto dstBaseDesc = mSenderState.mReceiverBufferDesc; + auto dstBaseDesc = mSenderState.mCacheReceiverBufferDesc; auto offset = size / mSenderState.mOffsetRatio.second * mSenderState.mOffsetRatio.first; MemoryDesc dstDesc{dstBaseDesc.getAddr() + offset, size, dstBaseDesc.getDeviceId()}; TLLM_LOG_DEBUG( @@ -162,9 +162,9 @@ void AgentConnection::sendRequestAndBufferInfo( } void AgentConnection::setSenderState( - MemoryDesc mReceiverBufferDesc, int validSegmentIdx, std::pair offsetRatio) + MemoryDesc mCacheReceiverBufferDesc, int validSegmentIdx, std::pair offsetRatio) { - mSenderState.mReceiverBufferDesc = mReceiverBufferDesc; + mSenderState.mCacheReceiverBufferDesc = mCacheReceiverBufferDesc; mSenderState.validSegmentIdx = validSegmentIdx; mSenderState.mOffsetRatio = offsetRatio; } diff --git a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h index 0ee171632f..ddcdc60103 100644 --- a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h +++ b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h @@ -175,7 +175,8 @@ public: void recv(DataContext const& ctx, void* data, size_t size) const override; void sendRequestAndBufferInfo( batch_manager::RequestInfo& requestInfo, std::optional cacheBufferId, int validConnectionIdx); - void setSenderState(MemoryDesc 
mReceiverBufferDesc, int valideSegmentIdx, std::pair offsetRatio); + void setSenderState( + MemoryDesc mCacheReceiverBufferDesc, int valideSegmentIdx, std::pair offsetRatio); [[nodiscard]] std::optional getCacheBufferId() const; void setHasLoadRemoteAgent(bool hasLoadRemoteAgent); [[nodiscard]] bool hasLoadRemoteAgent() const; @@ -186,7 +187,7 @@ private: struct SenderState { - MemoryDesc mReceiverBufferDesc{nullptr, 0, 0}; + MemoryDesc mCacheReceiverBufferDesc{nullptr, 0, 0}; int validSegmentIdx{0}; std::pair mOffsetRatio; SenderState() = default; diff --git a/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/connection.cpp b/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/connection.cpp index 1cd027af33..73f7aca8cf 100644 --- a/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/connection.cpp +++ b/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/connection.cpp @@ -18,7 +18,7 @@ #include "ucxCacheCommunicator.h" #if ENABLE_UCX -#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h" +#include "tensorrt_llm/batch_manager/dataTransceiver.h" #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/tllmException.h" #include "tensorrt_llm/executor/cache_transmission/ucx_utils/connection.h" @@ -114,8 +114,8 @@ void UcxConnection::sendConnectionId(DataContext const& ctx, void const* data, s std::future future = promise.get_future(); auto completionCallback = [&](ucs_status_t, ucxx::RequestCallbackUserData) -> void { promise.set_value(); }; - uint64_t tag - = ((mSendTagPrefix & 0xFFFFFFFF) << 32) | static_cast(batch_manager::TransceiverTag::kID_TAG); + uint64_t tag = ((mSendTagPrefix & 0xFFFFFFFF) << 32) + | static_cast(tensorrt_llm::batch_manager::TransceiverTag::kID_TAG); std::vector buffer(size + sizeof(mConnectionId)); memcpy(buffer.data(), data, size); memcpy(buffer.data() + size, &mConnectionIdInPeer, sizeof(mConnectionIdInPeer)); @@ -133,7 +133,7 @@ void UcxConnection::sendConnectionId(DataContext const& ctx, void const* 
data, s void UcxConnection::send(DataContext const& ctx, void const* data, size_t size) const { - if (ctx.getTag() == batch_manager::TransceiverTag::kID_TAG) + if (ctx.getTag() == tensorrt_llm::batch_manager::TransceiverTag::kID_TAG) { sendConnectionId(ctx, data, size); return; diff --git a/cpp/tests/unit_tests/executor/ucxCommTest.cpp b/cpp/tests/unit_tests/executor/ucxCommTest.cpp index 5895ac0947..4caec5022f 100644 --- a/cpp/tests/unit_tests/executor/ucxCommTest.cpp +++ b/cpp/tests/unit_tests/executor/ucxCommTest.cpp @@ -26,7 +26,7 @@ #include "tensorrt_llm/batch_manager/cacheFormatter.h" #include "tensorrt_llm/batch_manager/cacheTransceiver.h" -#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h" +#include "tensorrt_llm/batch_manager/dataTransceiver.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" @@ -45,7 +45,6 @@ #include #include #include -#include #include #include @@ -84,6 +83,7 @@ class UcxCommTest : public ::testing::Test }; using DataContext = tensorrt_llm::executor::kv_cache::DataContext; +using TransceiverTag = tensorrt_llm::batch_manager::TransceiverTag; TEST_F(UcxCommTest, Basic) { diff --git a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp index 6b4298909a..2d28e6d4e5 100644 --- a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp +++ b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp @@ -26,7 +26,6 @@ #include "tensorrt_llm/batch_manager/cacheFormatter.h" #include "tensorrt_llm/batch_manager/cacheTransceiver.h" -#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" @@ -154,106 +153,6 @@ TEST_F(CacheConfigTest, EqualTo) EXPECT_EQ(state0, state1); } -// --------------------------------------- -// MockTransceiverTest -// 
--------------------------------------- - -class MockDataSender : public DataSender -{ -public: - MockDataSender() - { - ON_CALL(*this, getCommState).WillByDefault(ReturnRef(mState)); - ON_CALL(*this, recvRequestInfo) - .WillByDefault(Return(RequestInfo{0, - texec::DataTransceiverState{ - texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, {10}, nvinfer1::DataType::kFLOAT}, - texec::kv_cache::CommState{std::vector{0}, 0}}})); - ON_CALL(*this, getCounterpartsCount).WillByDefault(Return(1)); - } - - MOCK_METHOD(RequestInfo, recvRequestInfo, (), (override)); - MOCK_METHOD(void, sendSync, (LlmRequest const&), (override)); - MOCK_METHOD(texec::kv_cache::CommState const&, getCommState, (), (const)); - MOCK_METHOD(void, setCommState, (texec::kv_cache::CommState), (override)); - MOCK_METHOD(size_t, getCounterpartsCount, (LlmRequest::RequestIdType), (const)); - MOCK_METHOD(void, release, (LlmRequest::RequestIdType), (override)); - -private: - static texec::kv_cache::CommState mState; -}; - -texec::kv_cache::CommState MockDataSender::mState; - -class MockDataReceiver : public DataReceiver -{ -public: - MOCK_METHOD(TransferSession, sendRequestInfo, (LlmRequest const&), (override)); - MOCK_METHOD(void, receiveSync, (TransferSession&), (override)); -}; - -class MockTransceiverTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init) -{ -public: - void SetUp() override {} - - void TearDown() override {} - - static auto makeLlmRequest( - LlmRequest::RequestIdType requestId = 0, SizeType32 maxNewTokens = 1, VecTokens inputTokens = {-1}) - { - texec::Request request{std::move(inputTokens), maxNewTokens}; - auto state = std::make_unique(); - auto stats = texec::ContextPhaseParams({}, requestId, state.release(), std::nullopt); - request.setContextPhaseParams(std::move(stats)); - return std::make_unique(requestId, std::move(request)); - } -}; - -TEST_F(MockTransceiverTest, MpiResponderBasic) -{ - if (tensorrt_llm::mpi::MpiComm::world().getSize() > 2) - { - 
GTEST_SKIP() << "mpirun with procs<=2 is required to run this test."; - } - auto sender = std::make_unique(); - EXPECT_CALL(*sender, recvRequestInfo) - .WillOnce(Return(RequestInfo{0, - texec::DataTransceiverState{ - texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, {4}, nvinfer1::DataType::kFLOAT}, - texec::kv_cache::CommState{std::vector{0}, 0}}})); - EXPECT_CALL(*sender, sendSync).WillOnce(Return()); - EXPECT_CALL(*sender, getCounterpartsCount).WillOnce(Return(1)); - EXPECT_CALL(*sender, release).WillOnce(Return()); - - DataResponder responder{std::move(sender)}; - auto request = makeLlmRequest(0); - auto future = responder.respondAndSendAsync(*request); - future.get(); -} - -TEST_F(MockTransceiverTest, MpiRequesterBasic) -{ - - if (tensorrt_llm::mpi::MpiComm::world().getSize() > 2) - { - GTEST_SKIP() << "mpirun with procs<=2 is required to run this test."; - } - auto receiver = std::make_unique(); - auto state = std::make_unique(); - state->setCommState(texec::kv_cache::CommState{std::vector{0}}); - EXPECT_CALL(*receiver, sendRequestInfo) - .WillOnce(Return(TransferSession({nullptr}, DataContext{0}, *state, *state, - tensorrt_llm::runtime::BufferManager{std::make_shared()}, nullptr))); - EXPECT_CALL(*receiver, receiveSync).WillOnce(Return()); - DataRequester requester{std::move(receiver)}; - auto request = makeLlmRequest(0); - auto stats = texec::ContextPhaseParams({}, 0, state.release(), std::nullopt); - request->setContextPhaseParams(std::move(stats)); - auto future = requester.requestAndReceiveAsync(*request); - future.get(); -} - // TODO: Restore multi-rank tests. 
// --------------------------------------- @@ -397,15 +296,13 @@ protected: mCacheTransBufferManager = std::make_unique(mManager.get(), maxNumTokens); if (isSender) { - mResponder = std::make_unique( - std::make_unique(mConnectionManager.get(), *mCacheState, mlocalRank, - std::make_unique(mManager.get(), mCacheTransBufferManager.get()))); + mSender = std::make_unique(mConnectionManager.get(), *mCacheState, mlocalRank, + std::make_unique(mManager.get(), mCacheTransBufferManager.get())); } else { - mRequester = std::make_unique( - std::make_unique(mConnectionManager.get(), *mCacheState, mlocalRank, - std::make_unique(mManager.get(), mCacheTransBufferManager.get()))); + mRequester = std::make_unique(mConnectionManager.get(), *mCacheState, mlocalRank, + std::make_unique(mManager.get(), mCacheTransBufferManager.get())); } } @@ -435,11 +332,11 @@ protected: // fill cache with tokens (= request length), for reuse test TLLM_CUDA_CHECK(cudaMemset(block.data(), llmRequest->getPromptLen(), block.getSizeInBytes())); } - mFutures.emplace_back(mResponder->respondAndSendAsync(*llmRequest)); + mFutures.emplace_back(mSender->sendAsync(*llmRequest)); } else { - auto future = mRequester->requestAndReceiveAsync(*llmRequest); + auto future = mRequester->receiveAsync(*llmRequest); future.get(); TLLM_CUDA_CHECK(cudaDeviceSynchronize()); auto blockRange = BlockRange::fromAllBlockIds(*mManager, llmRequest->mRequestId); @@ -460,8 +357,8 @@ protected: SizeType32 mMaxNumSequences{}; std::unique_ptr mManager; std::unique_ptr mCacheTransBufferManager; - std::unique_ptr mResponder; - std::unique_ptr mRequester; + std::unique_ptr mSender; + std::unique_ptr mRequester; std::unique_ptr mCacheState; std::unique_ptr mContextCommState; std::vector> mFutures; @@ -789,13 +686,13 @@ protected: if (mIsContext) { - mResponder = std::make_unique(std::make_unique( - mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter())); + mSender = std::make_unique( + mConnectionManager.get(), 
*mCacheState, mRankInInstance, makeFormatter()); } else { - mRequester = std::make_unique(std::make_unique( - mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter())); + mRequester = std::make_unique( + mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter()); } std::vector contextRankVec(mContextRankSize); @@ -930,7 +827,7 @@ protected: auto const onlyWindowSize = blockManager.getPoolWindowSize(0); blockManager.getBufferManager(onlyWindowSize).getStream().synchronize(); - auto future = mResponder->respondAndSendAsync(*llmRequest); + auto future = mSender->sendAsync(*llmRequest); return future; } @@ -940,7 +837,7 @@ protected: auto constexpr beamWidth{1}; mManager->addSequence(llmRequest->mRequestId, llmRequest->getNumTokens(beamIdx), beamWidth, llmRequest); - return mRequester->requestAndReceiveAsync(*llmRequest); + return mRequester->receiveAsync(*llmRequest); } void generationVerifyKVCache(std::shared_ptr const& llmRequest) @@ -1162,8 +1059,8 @@ protected: SizeType32 mMaxNumSequences{}; std::unique_ptr mManager; std::unique_ptr mCacheTransBufferManager; - std::unique_ptr mResponder; - std::unique_ptr mRequester; + std::unique_ptr mSender; + std::unique_ptr mRequester; std::unique_ptr mCacheState; std::unique_ptr mContextCacheState; std::unique_ptr mContextCommState;