TensorRT-LLM/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataTransceiver.h"
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/tllmException.h"
#include "tensorrt_llm/common/utils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <future>
#include <map>
#include <memory>
#include <unordered_map>
namespace tensorrt_llm::batch_manager
{
using kv_cache_manager::BlockRange;
using runtime::SizeType32;
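// RequestInfo is the message a requester sends to a responder to initiate a transfer: the id of
// the request whose cache blocks should be sent, optionally the hashes of the requested blocks,
// and the requester's DataTransceiverState.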
RequestInfo::RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTransceiverState transState)
: mRequestId{requestId}
, mTransState{std::move(transState)}
{
}
RequestInfo::RequestInfo(
LlmRequest::RequestIdType requestId, std::vector<size_t> blockHashes, executor::DataTransceiverState transState)
: mRequestId{requestId}
, mBlockHashes{std::move(blockHashes)}
, mTransState{std::move(transState)}
{
}
bool RequestInfo::operator==(RequestInfo const& rhs) const
{
return mRequestId == rhs.mRequestId && mBlockHashes == rhs.mBlockHashes && mTransState == rhs.mTransState;
}
LlmRequest::RequestIdType RequestInfo::getRequestId() const noexcept
{
return mRequestId;
}
executor::DataTransceiverState const& RequestInfo::getTransState() const noexcept
{
return mTransState;
}
void RequestInfo::serialize(RequestInfo const& requestInfo, std::ostream& os)
{
namespace su = executor::serialize_utils;
su::serialize(requestInfo.mRequestId, os);
su::serialize(requestInfo.mBlockHashes, os);
su::serialize(requestInfo.mTransState, os);
}
RequestInfo RequestInfo::deserialize(std::istream& is)
{
namespace su = executor::serialize_utils;
auto requestId = su::deserialize<decltype(mRequestId)>(is);
auto blockHashes = su::deserialize<decltype(mBlockHashes)>(is);
auto transState = su::deserialize<decltype(mTransState)>(is);
return RequestInfo{requestId, std::move(blockHashes), std::move(transState)};
}
std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo)
{
namespace su = executor::serialize_utils;
std::size_t totalSize = 0;
totalSize += su::serializedSize(requestInfo.mRequestId);
totalSize += su::serializedSize(requestInfo.mBlockHashes);
totalSize += su::serializedSize(requestInfo.mTransState);
return totalSize;
}
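// TransferSession measurement helpers: appendMeasure() records one (delay, duration, bandwidth)
// sample when measurement recording is enabled; exportMeasure() appends the collected samples as
// a single CSV row keyed by the request id, writing the header on first use.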
void TransferSession::appendMeasure(double delay, double duration, size_t size)
{
if (!mRecordMeasure)
{
return;
}
auto bandwidth = size * 8 / (duration / 1000) / 1e9; // size in bytes, duration in ms -> bandwidth in Gbps
mMeasures.emplace_back(Measure{delay, duration, bandwidth});
}
void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const
{
if (mMeasures.empty())
{
return;
}
// write the CSV header only if the file is still empty
if (outFile.tellp() == 0)
{
outFile << "RequestID";
for (size_t i = 0; i < mMeasures.size(); i++)
{
outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
}
outFile << '\n';
}
// write measures
TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value());
auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId();
outFile << reqId;
for (auto const& measure : mMeasures)
{
outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
}
outFile << '\n' << std::flush;
}
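// Sending side of the KV-cache transfer. respondAndSendAsync() registers a request whose cache
// data is ready to be served, and a dedicated background thread (see response()) matches incoming
// RequestInfo messages from requesters against those registered responses and pushes the data
// through the DataSender.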
class DataResponder::Impl
{
public:
using RequestIdType = LlmRequest::RequestIdType;
Impl(std::unique_ptr<DataSender> sender)
: mSender{std::move(sender)}
{
TLLM_CHECK(mSender);
TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
mCurrentRequest = std::nullopt;
mResponseFuture = std::async(std::launch::async, &Impl::response, this);
}
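// Registers llmRequest as ready to be served and wakes the response thread. The returned future
// completes once the cache data has been sent, or carries any exception raised while sending.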
[[nodiscard]] std::future<void> respondAndSendAsync(LlmRequest& llmRequest)
{
std::promise<void> promise;
auto future = promise.get_future();
{
{
std::unique_lock lkResp(mResponderMutex);
mReadyResponses.emplace(
llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)});
}
std::unique_lock lkCond(mCondMutex);
mAnyReady = true;
}
mResponderCv.notify_all();
return future;
}
[[nodiscard]] executor::kv_cache::CommState const& getCommState() const
{
return mSender->getCommState();
}
void setCommState(executor::kv_cache::CommState commState)
{
mSender->setCommState(std::move(commState));
}
~Impl()
{
terminate();
}
private:
struct Response
{
LlmRequest* mRequest;
std::promise<void> mPromise;
};
void sendAndRemoveResponse(RequestIdType id, Response resp) noexcept
{
try
{
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
mSender->sendSync(*resp.mRequest);
mSender->release(id);
resp.mPromise.set_value();
}
catch (tensorrt_llm::common::RequestSpecificException const& e)
{
TLLM_LOG_ERROR("Exception in sendAndRemoveResponse: %s ", e.what());
auto new_exception = TLLM_REQUEST_EXCEPTION(id, e.getErrorCode(), "%s", e.what());
resp.mPromise.set_exception(std::make_exception_ptr(new_exception));
}
catch (std::exception const& e)
{
TLLM_LOG_ERROR("Exception in sendAndRemoveResponse: %s ", e.what());
resp.mPromise.set_exception(std::current_exception());
}
}
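// Called once per RequestInfo received for the current request. The actual send happens only
// after every counterpart rank has sent its RequestInfo, i.e. when the remaining-send counter
// reaches zero; the response is then removed from the ready map.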
void sendResponse(std::vector<size_t> const& blockHashes, std::map<RequestIdType, Response>::iterator it)
{
auto reqId = mCurrentRequest.value();
auto count = --mRemainSendCount[reqId];
TLLM_CHECK(count >= 0);
if (count == 0)
{
mRemainSendCount.erase(reqId);
// TODO(zhengd): pass the hashes directly instead of updating llmRequest
auto llmRequest = it->second.mRequest;
llmRequest->setRequestedBlockHashes(std::move(blockHashes));
if (common::getEnvParallelCacheSend())
{
// TODO: Use a thread pool and check for thread safety.
std::thread(&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
.detach();
}
else
{
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
}
removeResponse(it);
}
mCurrentRequest = std::nullopt;
}
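// Background loop of the responder thread: wait until at least one response is ready (or
// termination is requested), receive the next RequestInfo from a requester, then send the
// matching response as soon as it becomes ready locally.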
void response() noexcept
{
try
{
tensorrt_llm::common::setThreadName("dataTransResp");
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
while (!mTerminate || !mAnyReady)
{
if (!mAnyReady)
{
std::unique_lock lk(mCondMutex);
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
}
if (mTerminate)
{
break;
}
std::vector<size_t> blockHashes;
if (!isSending() && !mReadyResponses.empty())
{
auto const& requestInfo = mSender->recvRequestInfo();
auto reqId = requestInfo.getRequestId();
blockHashes = requestInfo.getBlockHashes();
mCurrentRequest = reqId;
if (mRemainSendCount.find(reqId) == mRemainSendCount.end())
{
mRemainSendCount[reqId] = mSender->getCounterpartsCount(reqId);
}
}
auto it = getCurrentResponse();
if (it != mReadyResponses.end())
{
sendResponse(blockHashes, it);
}
else
{
// Wait until the matching response is registered locally, or termination is requested.
while (it == mReadyResponses.end())
{
std::unique_lock lk(mCondMutex);
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
if (mTerminate)
{
break;
}
it = getCurrentResponse();
}
// Guard against the terminate wake-up: never use an end() iterator.
if (it != mReadyResponses.end())
{
sendResponse(blockHashes, it);
}
}
}
}
catch (std::exception const& err)
{
TLLM_LOG_ERROR("Exception in DataResponder response: %s", err.what());
for (auto& it : mReadyResponses)
{
it.second.mPromise.set_exception(std::current_exception());
}
}
}
void terminate()
{
{
std::unique_lock lk(mCondMutex);
mTerminate = true;
}
// No need to wait for mResponseFuture here: if the response thread is currently sending data,
// it will not observe the terminate flag until that send has completed anyway.
mResponderCv.notify_all();
}
void removeResponse(std::map<RequestIdType, Response>::iterator it)
{
{
std::unique_lock lkResp(mResponderMutex);
mReadyResponses.erase(it);
}
if (mReadyResponses.empty())
{
std::unique_lock lkCond(mCondMutex);
mAnyReady = false;
}
}
[[nodiscard]] bool isSending() const
{
return mCurrentRequest.has_value();
}
[[nodiscard]] RequestIdType getCurrentRequestId() const
{
return mCurrentRequest.value();
}
[[nodiscard]] std::map<RequestIdType, Response>::iterator getCurrentResponse()
{
std::unique_lock lk(mResponderMutex);
return mReadyResponses.find(getCurrentRequestId());
}
private:
std::optional<RequestIdType> mCurrentRequest;
std::map<RequestIdType, Response> mReadyResponses;
std::mutex mResponderMutex, mCondMutex;
std::atomic<bool> mAnyReady{false}, mTerminate{false};
std::condition_variable mResponderCv;
std::future<void> mResponseFuture;
std::unique_ptr<DataSender> mSender;
std::unordered_map<LlmRequest::RequestIdType, int> mRemainSendCount;
int mDeviceId{-1};
};
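// Receiving side of the KV-cache transfer. Requests are queued per peer process (a single
// "default" queue unless concurrent KV-cache requests are enabled via the environment), and one
// long-lived worker thread per queue performs the transfers through the DataReceiver.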
class DataRequester::Impl
{
public:
Impl(std::unique_ptr<DataReceiver> receiver)
: mReceiver{std::move(receiver)}
{
TLLM_CHECK(mReceiver);
TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
}
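// Per-call variant that spawns a new std::async task for every request. The public
// DataRequester::requestAndReceiveAsync() forwards to requestAndReceiveAsyncMultiThreads() below,
// which reuses long-lived per-queue worker threads instead.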
[[nodiscard]] std::future<void> requestAndReceiveAsync(LlmRequest& llmRequest)
{
// TODO: Modify the implementation here to avoid frequent thread creation.
return std::async(std::launch::async, &DataRequester::Impl::requestSync, this, std::ref(llmRequest));
}
[[nodiscard]] std::future<void> requestAndReceiveAsyncMultiThreads(LlmRequest& llmRequest)
{
try
{
auto promise = std::make_unique<std::promise<void>>();
auto future = promise->get_future();
TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value());
std::string processInfo = "default";
if (common::getEnvRequestKVCacheConcurrent())
{
processInfo = llmRequest.getDataTransceiverState().getCommState()->toString();
}
if (mInstanceToAsyncResource.find(processInfo) == mInstanceToAsyncResource.end())
{
mInstanceToAsyncResource.emplace(processInfo, std::make_unique<AsyncResource>());
auto requestFuture = std::async(std::launch::async, &DataRequester::Impl::request, this,
std::ref(*mInstanceToAsyncResource.at(processInfo)));
mRequestFutures.emplace_back(std::move(requestFuture));
}
auto& asyncResource = mInstanceToAsyncResource.at(processInfo);
{
std::unique_lock<std::mutex> lck(asyncResource->mMtxForQueue);
asyncResource->mRequestsQueue.emplace_back(std::addressof(llmRequest), std::move(promise));
}
asyncResource->mCVforQueue.notify_all();
return future;
}
catch (std::exception const& e)
{
TLLM_THROW("%s", e.what());
}
}
~Impl()
{
for (auto&& [processInfo, asyncResource] : mInstanceToAsyncResource)
{
asyncResource->mTerminate = true;
asyncResource->mCVforQueue.notify_all();
}
for (auto&& future : mRequestFutures)
{
future.get();
}
}
private:
void requestSync(LlmRequest& llmRequest)
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"Start calling requestSync for request ID: %zu, context request ID: %zu.", llmRequest.mRequestId,
llmRequest.getContextPhaseParams().value().getReqId());
llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now());
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
auto session = mReceiver->sendRequestInfo(llmRequest);
mReceiver->receiveSync(session);
llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"End calling requestSync for request ID: %zu, context request ID: %zu.", llmRequest.mRequestId,
llmRequest.getContextPhaseParams().value().getReqId());
}
struct RequestAndPromise
{
LlmRequest* mRequest;
std::unique_ptr<std::promise<void>> mPromise;
RequestAndPromise()
: mRequest(nullptr)
, mPromise(nullptr)
{
}
RequestAndPromise(LlmRequest* request, std::unique_ptr<std::promise<void>>&& promise)
: mRequest(request)
, mPromise(std::move(promise))
{
}
RequestAndPromise(RequestAndPromise const&) = delete;
RequestAndPromise(RequestAndPromise&& other) noexcept
: mRequest(other.mRequest)
, mPromise(std::move(other.mPromise))
{
other.mRequest = nullptr;
}
RequestAndPromise& operator=(RequestAndPromise&& other) noexcept
{
if (this != &other)
{
mRequest = nullptr;
if (mPromise)
{
mPromise.reset();
}
mRequest = other.mRequest;
mPromise = std::move(other.mPromise);
other.mRequest = nullptr;
}
return *this;
}
};
struct AsyncResource
{
std::deque<RequestAndPromise> mRequestsQueue;
std::mutex mMtxForQueue;
std::condition_variable mCVforQueue;
std::atomic<bool> mTerminate{false};
};
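// Worker loop bound to one AsyncResource queue: waits until a request is queued or termination is
// requested, performs the synchronous transfer via requestSync(), and fulfils or fails the
// per-request promise accordingly.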
void request(AsyncResource& resource)
{
tensorrt_llm::common::setThreadName("dataTransRequest");
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
while (!resource.mTerminate)
{
RequestAndPromise requestAndPromise;
{
std::unique_lock lck(resource.mMtxForQueue);
resource.mCVforQueue.wait(
lck, [&resource] { return !resource.mRequestsQueue.empty() || resource.mTerminate; });
if (resource.mTerminate)
{
if (!resource.mRequestsQueue.empty())
{
TLLM_LOG_WARNING(
"Terminate requested while %zu requests are still pending in mRequestsQueue.",
resource.mRequestsQueue.size());
}
break;
}
requestAndPromise = std::move(resource.mRequestsQueue.front());
resource.mRequestsQueue.pop_front();
}
{
try
{
TLLM_CHECK_WITH_INFO(requestAndPromise.mRequest != nullptr, "requestAndPromise.mRequest is null");
requestSync(*requestAndPromise.mRequest);
requestAndPromise.mPromise->set_value();
}
catch (tensorrt_llm::common::RequestSpecificException const& err)
{
TLLM_LOG_ERROR("Exception in DataRequester request(): request id:%zu , request context id:%zu : %s",
requestAndPromise.mRequest->mRequestId,
requestAndPromise.mRequest->getContextPhaseParams().value().getReqId(), err.what());
auto new_exception = TLLM_REQUEST_EXCEPTION(
requestAndPromise.mRequest->mRequestId, err.getErrorCode(), "%s", err.what());
requestAndPromise.mPromise->set_exception(std::make_exception_ptr(new_exception));
}
catch (std::exception const& err)
{
TLLM_LOG_ERROR("Exception in DataRequester request(): request id:%ld , request context id:%ld : %s",
requestAndPromise.mRequest->mRequestId,
requestAndPromise.mRequest->getContextPhaseParams().value().getReqId(), err.what());
requestAndPromise.mPromise->set_exception(std::current_exception());
}
}
}
}
std::unique_ptr<DataReceiver> mReceiver;
int mDeviceId{-1};
std::vector<std::future<void>> mRequestFutures;
std::unordered_map<std::string, std::unique_ptr<AsyncResource>> mInstanceToAsyncResource;
};
DataResponder::DataResponder(std::unique_ptr<DataSender> sender)
: mImpl{std::make_unique<Impl>(std::move(sender))}
{
}
std::future<void> DataResponder::respondAndSendAsync(LlmRequest& llmRequest) const
{
return mImpl->respondAndSendAsync(llmRequest);
}
executor::kv_cache::CommState const& DataResponder::getCommState() const
{
return mImpl->getCommState();
}
void DataResponder::setCommState(executor::kv_cache::CommState commState)
{
mImpl->setCommState(std::move(commState));
}
DataResponder::~DataResponder() = default;
DataRequester::DataRequester(std::unique_ptr<DataReceiver> receiver)
: mImpl{std::make_unique<Impl>(std::move(receiver))}
{
}
std::future<void> DataRequester::requestAndReceiveAsync(LlmRequest& llmRequest) const
{
return mImpl->requestAndReceiveAsyncMultiThreads(llmRequest);
}
DataRequester::~DataRequester() = default;
} // namespace tensorrt_llm::batch_manager