/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dataTransceiver.h"
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/utils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

#include <condition_variable>
#include <deque>
#include <future>
#include <mutex>

namespace tensorrt_llm::batch_manager
{

using kv_cache_manager::BlockRange;
using runtime::SizeType32;

RequestInfo::RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTransceiverState transState)
    : mRequestId{requestId}
    , mTransState{std::move(transState)}
{
}

RequestInfo::RequestInfo(
    LlmRequest::RequestIdType requestId, std::vector<size_t> blockHashes, executor::DataTransceiverState transState)
    : mRequestId{requestId}
    , mBlockHashes{std::move(blockHashes)}
    , mTransState{std::move(transState)}
{
}

bool RequestInfo::operator==(RequestInfo const& rhs) const
{
    return mRequestId == rhs.mRequestId && mBlockHashes == rhs.mBlockHashes && mTransState == rhs.mTransState;
}

LlmRequest::RequestIdType RequestInfo::getRequestId() const noexcept
{
    return mRequestId;
}

executor::DataTransceiverState const& RequestInfo::getTransState() const noexcept
{
    return mTransState;
}

void RequestInfo::serialize(RequestInfo const& requestInfo, std::ostream& os)
{
    namespace su = executor::serialize_utils;
    su::serialize(requestInfo.mRequestId, os);
    su::serialize(requestInfo.mBlockHashes, os);
    su::serialize(requestInfo.mTransState, os);
}

RequestInfo RequestInfo::deserialize(std::istream& is)
{
    namespace su = executor::serialize_utils;
    auto requestId = su::deserialize<LlmRequest::RequestIdType>(is);
    auto blockHashes = su::deserialize<std::vector<size_t>>(is);
    auto transState = su::deserialize<executor::DataTransceiverState>(is);
    return RequestInfo{requestId, std::move(blockHashes), std::move(transState)};
}

std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo)
{
    namespace su = executor::serialize_utils;
    std::size_t totalSize = 0;
    totalSize += su::serializedSize(requestInfo.mRequestId);
    totalSize += su::serializedSize(requestInfo.mBlockHashes);
    totalSize += su::serializedSize(requestInfo.mTransState);
    return totalSize;
}

void TransferSession::appendMeasure(double delay, double duration, size_t size)
{
    if (!mRecordMeasure)
    {
        return;
    }
    auto bandwidth = size * 8 / (duration / 1000) / 1e9; // bytes, ms => Gbps
    mMeasures.emplace_back(Measure{delay, duration, bandwidth});
}

void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const
{
    if (mMeasures.empty())
    {
        return;
    }
    // Write the header only once, when the file is still empty.
    if (outFile.tellp() == 0)
    {
        outFile << "RequestID";
        for (size_t i = 0; i < mMeasures.size(); i++)
        {
            outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
        }
        outFile << '\n';
    }
    // Write one row per session: the request ID followed by all recorded measures.
    TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value());
    auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId();
    outFile << reqId;
    for (auto const& measure : mMeasures)
    {
        outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
    }
    outFile << '\n' << std::flush;
}
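// A worked example of the bandwidth arithmetic above (illustrative values, not
// taken from a real run): transferring 16 MiB (16'777'216 bytes) in 100 ms gives
//   16'777'216 * 8 / (100 / 1000) / 1e9 ≈ 1.342 Gbps,
// so a session with a single measure taken after a 0.5 ms delay is exported as:
//   RequestID,Delay(ms),Duration(ms),Bandwidth(Gbps)
//   42,0.5,100,1.342
// Note that the header is sized from the first session that writes to the file;
// later rows with more measures extend past the labeled columns.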
class DataResponder::Impl
{
public:
    using RequestIdType = LlmRequest::RequestIdType;

    Impl(std::unique_ptr<DataSender> sender)
        : mSender{std::move(sender)}
    {
        TLLM_CHECK(mSender);
        TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
        mCurrentRequest = std::nullopt;
        mResponseFuture = std::async(std::launch::async, &Impl::response, this);
    }

    [[nodiscard]] std::future<void> respondAndSendAsync(LlmRequest& llmRequest)
    {
        std::promise<void> promise;
        auto future = promise.get_future();
        {
            {
                std::unique_lock lkResp(mResponderMutex);
                mReadyResponses.emplace(
                    llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)});
            }
            std::unique_lock lkCond(mCondMutex);
            mAnyReady = true;
        }
        mResponderCv.notify_all();
        return future;
    }

    [[nodiscard]] executor::kv_cache::CommState const& getCommState() const
    {
        return mSender->getCommState();
    }

    void setCommState(executor::kv_cache::CommState commState)
    {
        mSender->setCommState(std::move(commState));
    }

    ~Impl()
    {
        terminate();
    }

private:
    struct Response
    {
        LlmRequest* mRequest;
        std::promise<void> mPromise;
    };

    void sendAndRemoveResponse(RequestIdType id, Response resp) noexcept
    {
        try
        {
            TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
            mSender->sendSync(*resp.mRequest);
            mSender->release(id);
            resp.mPromise.set_value();
        }
        catch (std::exception const& e)
        {
            TLLM_LOG_ERROR("Exception in sendAndRemoveResponse: %s", e.what());
            resp.mPromise.set_exception(std::current_exception());
        }
    }
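    // Worker loop run by mResponseFuture: it blocks until at least one response is
    // ready, receives the next RequestInfo from a counterpart rank, and sends the
    // matching KV cache once every expected counterpart has asked for it (tracked
    // via mRemainSendCount). On failure, the exception is propagated to all pending
    // promises so waiting callers are unblocked.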
    void response() noexcept
    {
        try
        {
            tensorrt_llm::common::setThreadName("dataTransResp");
            TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
            while (!mTerminate || !mAnyReady)
            {
                if (!mAnyReady)
                {
                    std::unique_lock lk(mCondMutex);
                    mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
                }
                if (mTerminate)
                {
                    break;
                }
                std::vector<size_t> blockHashes;
                if (!isSending() && !mReadyResponses.empty())
                {
                    auto const& requestInfo = mSender->recvRequestInfo();
                    auto reqId = requestInfo.getRequestId();
                    blockHashes = requestInfo.getBlockHashes();
                    mCurrentRequest = reqId;
                    if (mRemainSendCount.find(reqId) == mRemainSendCount.end())
                    {
                        mRemainSendCount[reqId] = mSender->getCounterpartsCount(reqId);
                    }
                }
                auto it = getCurrentResponse();
                if (it != mReadyResponses.end())
                {
                    auto reqId = mCurrentRequest.value();
                    auto count = --mRemainSendCount[reqId];
                    TLLM_CHECK(count >= 0);
                    if (count == 0)
                    {
                        mRemainSendCount.erase(reqId);
                        // TODO(zhengd): pass the hashes directly instead of updating llmRequest
                        auto llmRequest = it->second.mRequest;
                        llmRequest->setRequestedBlockHashes(std::move(blockHashes));

                        if (common::getEnvParallelCacheSend())
                        {
                            // TODO: Use a thread pool and check for thread safety.
                            std::thread(
                                &DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
                                .detach();
                        }
                        else
                        {
                            DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
                        }
                        removeResponse(it);
                    }
                    mCurrentRequest = std::nullopt;
                }
                else
                {
                    TLLM_CHECK_WITH_INFO(!mCurrentRequest.has_value(),
                        "This executor does not have a prepared KV cache for request ID: %zu, and the "
                        "mReadyResponses size is: %zu. MPI rank: %d",
                        mCurrentRequest.value(), mReadyResponses.size(), mpi::MpiComm::world().getRank());
                    std::unique_lock lk(mCondMutex);
                    mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
                }
            }
        }
        catch (std::exception const& err)
        {
            TLLM_LOG_ERROR("Exception in DataResponder response: %s", err.what());
            for (auto& it : mReadyResponses)
            {
                it.second.mPromise.set_exception(std::current_exception());
            }
        }
    }

    void terminate()
    {
        {
            std::unique_lock lk(mCondMutex);
            mTerminate = true;
        }
        // We don't have to wait for the future. If another thread is sending data, it won't pay attention
        // to the terminate flag.
        mResponderCv.notify_all();
    }

    void removeResponse(std::map<RequestIdType, Response>::iterator it)
    {
        {
            std::unique_lock lkResp(mResponderMutex);
            mReadyResponses.erase(it);
        }
        if (mReadyResponses.empty())
        {
            std::unique_lock lkCond(mCondMutex);
            mAnyReady = false;
        }
    }

    [[nodiscard]] bool isSending() const
    {
        return mCurrentRequest.has_value();
    }

    [[nodiscard]] RequestIdType getCurrentRequestId() const
    {
        return mCurrentRequest.value();
    }

    [[nodiscard]] std::map<RequestIdType, Response>::iterator getCurrentResponse()
    {
        std::unique_lock lk(mResponderMutex);
        return mReadyResponses.find(getCurrentRequestId());
    }

private:
    std::optional<RequestIdType> mCurrentRequest;
    std::map<RequestIdType, Response> mReadyResponses;
    std::mutex mResponderMutex, mCondMutex;
    std::atomic<bool> mAnyReady{false}, mTerminate{false};
    std::condition_variable mResponderCv;
    std::future<void> mResponseFuture;
    std::unique_ptr<DataSender> mSender;
    std::unordered_map<RequestIdType, int> mRemainSendCount;
    int mDeviceId{-1};
};
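// DataRequester::Impl drives the receiving side. requestAndReceiveAsync() spawns a
// one-off std::async thread per request, while requestAndReceiveAsyncMultiThreads()
// dispatches each request to a persistent worker thread: one worker per counterpart
// process (keyed by the CommState string) when concurrent KV-cache requests are
// enabled, otherwise a single shared "default" worker.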
class DataRequester::Impl
{
public:
    Impl(std::unique_ptr<DataReceiver> receiver)
        : mReceiver{std::move(receiver)}
    {
        TLLM_CHECK(mReceiver);
        TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
    }

    [[nodiscard]] std::future<void> requestAndReceiveAsync(LlmRequest& llmRequest)
    {
        // TODO: Modify the implementation here to avoid frequent thread creation.
        return std::async(std::launch::async, &DataRequester::Impl::requestSync, this, std::ref(llmRequest));
    }

    [[nodiscard]] std::future<void> requestAndReceiveAsyncMultiThreads(LlmRequest& llmRequest)
    {
        try
        {
            auto promise = std::make_unique<std::promise<void>>();
            auto future = promise->get_future();
            TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value());
            std::string processInfo = "default";
            if (common::getEnvRequestKVCacheConcurrent())
            {
                processInfo = llmRequest.getDataTransceiverState().getCommState()->toString();
            }
            if (mInstanceToAsyncResource.find(processInfo) == mInstanceToAsyncResource.end())
            {
                mInstanceToAsyncResource.emplace(processInfo, std::make_unique<AsyncResource>());
                auto requestFuture = std::async(std::launch::async, &DataRequester::Impl::request, this,
                    std::ref(*mInstanceToAsyncResource.at(processInfo)));
                mRequestFutures.emplace_back(std::move(requestFuture));
            }
            auto& asyncResource = mInstanceToAsyncResource.at(processInfo);
            {
                std::unique_lock lck(asyncResource->mMtxForQueue);
                asyncResource->mRequestsQueue.emplace_back(std::addressof(llmRequest), std::move(promise));
            }
            asyncResource->mCVforQueue.notify_all();
            return future;
        }
        catch (std::exception const& e)
        {
            TLLM_THROW("%s", e.what());
        }
    }

    ~Impl()
    {
        for (auto&& [processInfo, asyncResource] : mInstanceToAsyncResource)
        {
            asyncResource->mTerminate = true;
            asyncResource->mCVforQueue.notify_all();
        }
        for (auto&& future : mRequestFutures)
        {
            future.get();
        }
    }

private:
    void requestSync(LlmRequest& llmRequest)
    {
        TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
            "Start calling requestSync for request ID: %zu, context request ID: %zu.", llmRequest.mRequestId,
            llmRequest.getContextPhaseParams().value().getReqId());
        llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now());
        TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
        auto session = mReceiver->sendRequestInfo(llmRequest);
        mReceiver->receiveSync(session);
        llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());
        TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
            "End calling requestSync for request ID: %zu, context request ID: %zu.", llmRequest.mRequestId,
            llmRequest.getContextPhaseParams().value().getReqId());
    }

    struct RequestAndPromise
    {
        LlmRequest* mRequest;
        std::unique_ptr<std::promise<void>> mPromise;

        RequestAndPromise()
            : mRequest(nullptr)
            , mPromise(nullptr)
        {
        }

        RequestAndPromise(LlmRequest* request, std::unique_ptr<std::promise<void>>&& promise)
            : mRequest(request)
            , mPromise(std::move(promise))
        {
        }

        RequestAndPromise(RequestAndPromise const&) = delete;

        RequestAndPromise(RequestAndPromise&& other) noexcept
            : mRequest(other.mRequest)
            , mPromise(std::move(other.mPromise))
        {
            other.mRequest = nullptr;
        }

        RequestAndPromise& operator=(RequestAndPromise&& other) noexcept
        {
            if (this != &other)
            {
                // Moving the promise releases any promise previously held.
                mRequest = other.mRequest;
                mPromise = std::move(other.mPromise);
                other.mRequest = nullptr;
            }
            return *this;
        }
    };

    struct AsyncResource
    {
        std::deque<RequestAndPromise> mRequestsQueue;
        std::mutex mMtxForQueue;
        std::condition_variable mCVforQueue;
        std::atomic<bool> mTerminate{false};
    };
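    // Worker loop, one per AsyncResource: pops queued requests and runs each through
    // requestSync(), fulfilling the associated promise on success or recording the
    // exception on failure. On termination it exits without draining the queue and
    // only warns if requests are still pending.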
    void request(AsyncResource& resource)
    {
        tensorrt_llm::common::setThreadName("dataTransRequest");
        TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
        while (!resource.mTerminate)
        {
            RequestAndPromise requestAndPromise;
            {
                std::unique_lock lck(resource.mMtxForQueue);
                resource.mCVforQueue.wait(
                    lck, [&resource] { return !resource.mRequestsQueue.empty() || resource.mTerminate; });
                if (resource.mTerminate)
                {
                    if (!resource.mRequestsQueue.empty())
                    {
                        TLLM_LOG_WARNING("There are still %zu requests in the mRequestsQueue, but terminate was "
                                         "encountered.",
                            resource.mRequestsQueue.size());
                    }
                    break;
                }
                requestAndPromise = std::move(resource.mRequestsQueue.front());
                resource.mRequestsQueue.pop_front();
            }
            {
                try
                {
                    TLLM_CHECK_WITH_INFO(requestAndPromise.mRequest != nullptr, "requestAndPromise.mRequest is null");
                    requestSync(*requestAndPromise.mRequest);
                    requestAndPromise.mPromise->set_value();
                }
                catch (std::exception const& err)
                {
                    TLLM_LOG_ERROR(
                        "Exception in DataRequester request(): request ID: %ld, context request ID: %ld: %s",
                        requestAndPromise.mRequest->mRequestId,
                        requestAndPromise.mRequest->getContextPhaseParams().value().getReqId(), err.what());
                    requestAndPromise.mPromise->set_exception(std::current_exception());
                }
            }
        }
    }

    std::unique_ptr<DataReceiver> mReceiver;
    int mDeviceId{-1};
    std::vector<std::future<void>> mRequestFutures;
    std::unordered_map<std::string, std::unique_ptr<AsyncResource>> mInstanceToAsyncResource;
};

DataResponder::DataResponder(std::unique_ptr<DataSender> sender)
    : mImpl{std::make_unique<Impl>(std::move(sender))}
{
}

std::future<void> DataResponder::respondAndSendAsync(LlmRequest& llmRequest) const
{
    return mImpl->respondAndSendAsync(llmRequest);
}

executor::kv_cache::CommState const& DataResponder::getCommState() const
{
    return mImpl->getCommState();
}

void DataResponder::setCommState(executor::kv_cache::CommState commState)
{
    mImpl->setCommState(std::move(commState));
}

DataResponder::~DataResponder() = default;

DataRequester::DataRequester(std::unique_ptr<DataReceiver> receiver)
    : mImpl{std::make_unique<Impl>(std::move(receiver))}
{
}

std::future<void> DataRequester::requestAndReceiveAsync(LlmRequest& llmRequest) const
{
    return mImpl->requestAndReceiveAsyncMultiThreads(llmRequest);
}

DataRequester::~DataRequester() = default;

} // namespace tensorrt_llm::batch_manager
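// Illustrative call sequence (a minimal sketch; SomeDataSender and SomeDataReceiver
// are hypothetical DataSender/DataReceiver implementations, not part of this file):
//
//   DataResponder responder{std::make_unique<SomeDataSender>()};   // context executor
//   DataRequester requester{std::make_unique<SomeDataReceiver>()}; // generation executor
//   auto sent     = responder.respondAndSendAsync(contextRequest);
//   auto received = requester.requestAndReceiveAsync(generationRequest);
//   received.get(); // rethrows any exception recorded on the transfer promise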