[TRTLLM-8044][refactor] Rename data -> cache for cacheTransceiver (#7659)

This commit is contained in:
Iman Tabrizian 2025-09-16 08:43:56 -04:00 committed by GitHub
parent 8226ef23dc
commit 6ce0624208
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 629 additions and 824 deletions

View File

@ -34,8 +34,14 @@ namespace tensorrt_llm::batch_manager
class ContextProgress;
class BaseCacheTransceiver;
class DataResponder;
class DataRequester;
namespace kv_cache_manager
{
class BaseKVCacheManager;
} // namespace kv_cache_manager
class CacheSender;
class CacheReceiver;
class CacheTransceiverFactory
{
@ -110,9 +116,9 @@ private:
void setContextState(LlmRequest* llmRequest);
std::unique_ptr<DataResponder> mDataResponder;
std::unique_ptr<DataRequester> mDataRequester;
std::vector<std::pair<LlmRequest*, std::future<void>>> mResponderFutures;
std::unique_ptr<CacheSender> mCacheSender;
std::unique_ptr<CacheReceiver> mCacheReceiver;
std::vector<std::pair<LlmRequest*, std::future<void>>> mSenderFutures;
std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr};
std::shared_ptr<mpi::MpiComm> mMpiGroupTensorParaComm, mMpiGroupPipeParaComm, mMpiGroupDataComm,

View File

@ -24,7 +24,6 @@ set(SRCS
createNewDecoderRequests.cpp
contextProgress.cpp
dataTransceiver.cpp
dataTransceiverImpl.cpp
decoderBuffers.cpp
encoderBuffers.cpp
guidedDecoder.cpp

View File

@ -19,6 +19,7 @@
#include "mlaCacheFormatter.h"
#include "tensorrt_llm/batch_manager/contextProgress.h"
#include "tensorrt_llm/batch_manager/dataTransceiver.h"
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
@ -154,7 +155,7 @@ std::vector<size_t> CacheFormatter::pickRecvConnections(
return ret;
}
void CacheFormatter::format(TransferSession& session)
void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
{
NVTX3_SCOPED_RANGE(CacheFormatter_format);
auto const& llmRequest = session.getLlmRequest();
@ -468,7 +469,7 @@ void CacheFormatter::format(TransferSession& session)
mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID:%ld ", llmRequest.mRequestId);
}
void CacheFormatter::unformat(TransferSession& session)
void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
{
NVTX3_SCOPED_RANGE(CacheFormatter_unformat);
auto const& llmRequest = session.getLlmRequest();

View File

@ -18,11 +18,12 @@
#pragma once
#include "cacheTransBuffer.h"
#include "dataTransceiver.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/cacheCommunicator.h"
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/runtime/bufferManager.h"
@ -30,10 +31,25 @@
#include <NvInferRuntimeBase.h>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <vector>
// Forward declare TransferSession in the correct global namespace scope
namespace tensorrt_llm::batch_manager
{
class TransferSession;
}
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
using Connection = tensorrt_llm::executor::kv_cache::Connection;
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using BaseKVCacheManager = kv_cache_manager::BaseKVCacheManager;
using CacheTransBufferManager = kv_cache_manager::CacheTransBufferManager;
using BlockRange = kv_cache_manager::BlockRange;
BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest);
BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest);
@ -42,16 +58,15 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques
class BaseCacheFormatter
{
public:
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using CacheState = executor::kv_cache::CacheState;
/// @brief Format the cache data into bytes for sending.
/// @param session The transfer session.
virtual void format(TransferSession& session) = 0;
virtual void format(tensorrt_llm::batch_manager::TransferSession& session) = 0;
/// @brief Unformat the cache data from received bytes.
/// @param session The transfer session.
virtual void unformat(TransferSession& session) = 0;
virtual void unformat(tensorrt_llm::batch_manager::TransferSession& session) = 0;
/// @brief Determine whether the sender is applicable to the source and target.
/// @param selfConfig Source data arrangement.
@ -91,9 +106,9 @@ public:
TLLM_CHECK(mCacheTransBufferManager);
}
void format(TransferSession& session) override;
void format(tensorrt_llm::batch_manager::TransferSession& session) override;
void unformat(TransferSession& session) override;
void unformat(tensorrt_llm::batch_manager::TransferSession& session) override;
[[nodiscard]] bool inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const override;

View File

@ -37,8 +37,9 @@
#include "tensorrt_llm/batch_manager/cacheFormatter.h"
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/contextProgress.h"
#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/kvCacheType.h"
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/mlaCacheFormatter.h"
#include "tensorrt_llm/common/envUtils.h"
@ -116,7 +117,6 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
: mMpiGroupComm(std::addressof(tensorrt_llm::mpi::MpiComm::session()))
, mCacheTransceiverConfig{cacheTransceiverConfig}
{
using tensorrt_llm::batch_manager::kv_cache_manager::CacheFormatter;
if (worldConfig.isPipelineParallel())
{
mMpiGroupPipeParaComm = std::make_shared<tensorrt_llm::mpi::MpiComm>(
@ -200,14 +200,12 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
TLLM_THROW("Unsupported cache transceiver backend type ");
}
using tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter;
auto makeFormatter = [cacheManager, isMLA, this]()
{ return createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); };
mDataResponder = std::make_unique<DataResponder>(
std::make_unique<DataSenderImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
mDataRequester = std::make_unique<DataRequester>(
std::make_unique<DataReceiverImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
mCacheSender = std::make_unique<CacheSender>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter());
mCacheReceiver
= std::make_unique<CacheReceiver>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter());
initializeCommState();
}
@ -223,7 +221,7 @@ CacheTransceiver::~CacheTransceiver()
void CacheTransceiver::initializeCommState()
{
mCommState = std::addressof(mDataResponder->getCommState());
mCommState = std::addressof(mCacheSender->getCommState());
}
void CacheTransceiver::setContextState(LlmRequest* llmRequest)
@ -259,8 +257,8 @@ void CacheTransceiver::respondAndSendAsync(LlmRequest* llmRequest)
return;
}
setContextState(llmRequest);
auto future = mDataResponder->respondAndSendAsync(*llmRequest);
mResponderFutures.emplace_back(llmRequest, std::move(future));
auto future = mCacheSender->sendAsync(*llmRequest);
mSenderFutures.emplace_back(llmRequest, std::move(future));
}
void CacheTransceiver::respondAndSendLayerWise(
@ -275,8 +273,8 @@ void CacheTransceiver::respondAndSendLayerWise(
llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS);
setContextState(llmRequest.get());
auto future = mDataResponder->respondAndSendAsync(*llmRequest);
mResponderFutures.emplace_back(llmRequest.get(), std::move(future));
auto future = mCacheSender->sendAsync(*llmRequest);
mSenderFutures.emplace_back(llmRequest.get(), std::move(future));
}
}
@ -284,7 +282,7 @@ void CacheTransceiver::requestAndReceiveSync(LlmRequest* llmRequest)
{
TLLM_CHECK(llmRequest && llmRequest->isGenerationOnlyRequest());
{
auto future = mDataRequester->requestAndReceiveAsync(*llmRequest);
auto future = mCacheReceiver->receiveAsync(*llmRequest);
future.get();
}
llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);
@ -302,7 +300,7 @@ void CacheTransceiver::requestAndReceiveAsync(LlmRequest* llmRequest)
return;
}
auto future = mDataRequester->requestAndReceiveAsync(*llmRequest);
auto future = mCacheReceiver->receiveAsync(*llmRequest);
mRequesterFutures.emplace_back(llmRequest, std::move(future));
llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS);
}
@ -382,7 +380,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
bool blockAll = !atLeastRequestNum.has_value();
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mMpiGroupTPInDPComm : mMpiGroupTensorParaComm;
std::vector<LlmRequest::RequestIdType> contextCompleteRequestIds;
for (auto&& [request, future] : mResponderFutures)
for (auto&& [request, future] : mSenderFutures)
{
if (future.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready)
{
@ -422,16 +420,15 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
// Make sure there are at least atLeastRequestNum requests in toCompleteIdSet.
// This will preserve the order of insertion for KVCache transfer requests.
for (auto it = mResponderFutures.begin();
atLeastRequestNum.value_or(0) > static_cast<int>(toCompleteIdSet.size()) && it != mResponderFutures.end();
++it)
for (auto it = mSenderFutures.begin();
atLeastRequestNum.value_or(0) > static_cast<int>(toCompleteIdSet.size()) && it != mSenderFutures.end(); ++it)
{
auto& [request, future] = *it;
toCompleteIdSet.insert(request->mRequestId);
}
// Complete all the requests in toCompleteIdSet
for (auto it = mResponderFutures.begin(); it != mResponderFutures.end();)
for (auto it = mSenderFutures.begin(); it != mSenderFutures.end();)
{
auto& [request, future] = *it;
if (blockAll || (toCompleteIdSet.find(request->mRequestId) != toCompleteIdSet.end()))
@ -447,7 +444,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
"Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what());
request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
}
it = mResponderFutures.erase(it);
it = mSenderFutures.erase(it);
}
else
{

View File

@ -17,6 +17,7 @@
#include "dataTransceiver.h"
#include "tensorrt_llm/batch_manager/cacheFormatter.h"
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
@ -24,6 +25,7 @@
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/tllmException.h"
#include "tensorrt_llm/common/utils.h"
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <future>
#include <map>
@ -33,8 +35,138 @@
namespace tensorrt_llm::batch_manager
{
using kv_cache_manager::BlockRange;
using BlockRange = tensorrt_llm::batch_manager::kv_cache_manager::BlockRange;
// Returns the peer connections for this session (non-owning pointers; slots may be
// nullptr until filled in by setConnection()).
std::vector<Connection const*> const& TransferSession::getConnections() const
{
    return mConnections;
}
// Installs the connection for peer slot `idx`; throws std::out_of_range via at() if
// idx is outside the connection vector.
void TransferSession::setConnection(size_t idx, Connection const* conn)
{
    mConnections.at(idx) = conn;
}
// Tag/context used for all sends and receives belonging to this session.
DataContext const& TransferSession::getDataContext() const
{
    return mDataContext;
}
// Local rank's transceiver state. mSelfState is dereferenced unchecked — assumed
// non-null for the session's lifetime (set at construction; TODO confirm).
executor::DataTransceiverState const& TransferSession::getSelfState() const
{
    return *mSelfState;
}
// Peer's transceiver state (stored by value, unlike the self state).
executor::DataTransceiverState const& TransferSession::getOtherState() const
{
    return mOtherState;
}
// Buffer manager used for staging transfers. mBufferManager is dereferenced
// unchecked — assumed non-null for the session's lifetime (TODO confirm).
runtime::BufferManager const& TransferSession::getBufferManager() const
{
    return *mBufferManager;
}
// Sends `size` bytes over the idx-th connection, rethrowing any transport failure as
// a RequestSpecificException tagged kNETWORK_ERROR with this session's request id.
// NOTE(review): mRequest is dereferenced for the id — assumes a request has been
// attached via setLlmRequest/constructor before sending; confirm at call sites.
void TransferSession::send(size_t idx, void const* data, size_t size)
{
    try
    {
        mConnections.at(idx)->send(mDataContext, data, size);
    }
    catch (std::exception const& e)
    {
        // Preserve the original message but attach request identity for diagnostics.
        throw common::RequestSpecificException(
            __FILE__, __LINE__, e.what(), mRequest->mRequestId, common::RequestErrorCode::kNETWORK_ERROR);
    }
}
// Receives `size` bytes from the idx-th connection, rethrowing any transport failure
// as a RequestSpecificException tagged kNETWORK_ERROR with this session's request id.
// NOTE(review): mRequest is dereferenced for the id — assumes a request is attached.
void TransferSession::recv(size_t idx, void* data, size_t size)
{
    try
    {
        mConnections.at(idx)->recv(mDataContext, data, size);
    }
    catch (std::exception const& e)
    {
        // Preserve the original message but attach request identity for diagnostics.
        throw common::RequestSpecificException(
            __FILE__, __LINE__, e.what(), mRequest->mRequestId, common::RequestErrorCode::kNETWORK_ERROR);
    }
}
// Returns the request being transferred; asserts that one has been attached.
LlmRequest const& TransferSession::getLlmRequest() const
{
    TLLM_CHECK(mRequest != nullptr);
    return *mRequest;
}
// Attaches the request to this session. Only a non-owning pointer is stored — the
// caller must keep llmRequest alive for as long as the session uses it.
void TransferSession::setLlmRequest(LlmRequest const& llmRequest)
{
    mRequest = &llmRequest;
}
/// Records one transfer measurement; no-op unless measurement recording is enabled.
/// `delay` and `duration` are in milliseconds, `size` in bytes; the stored bandwidth
/// is in Gbps.
void TransferSession::appendMeasure(double delay, double duration, size_t size)
{
    if (mRecordMeasure)
    {
        // bytes & ms -> Gbps: *8 to bits, /1000 to per-second, /1e9 to giga.
        double const bandwidth = size * 8 / (duration / 1000) / 1e9;
        mMeasures.push_back(Measure{delay, duration, bandwidth});
    }
}
// Appends this session's measurements as one CSV row to `outFile`, emitting a header
// line first when the file is still empty (tellp() == 0). On the generation side
// (isContext == false) the row is keyed by the context-phase request id instead of
// the local request id.
// NOTE(review): mRequest is dereferenced unchecked — assumes a request is attached.
void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const
{
    if (mMeasures.empty())
    {
        return;
    }
    // Write the CSV header only once, when the file has no content yet.
    if (outFile.tellp() == 0)
    {
        outFile << "RequestID";
        for (size_t i = 0; i < mMeasures.size(); i++)
        {
            outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
        }
        outFile << '\n';
    }
    // Generation side must carry context-phase params to recover the context req id.
    TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value());
    auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId();
    outFile << reqId;
    for (auto const& measure : mMeasures)
    {
        outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
    }
    outFile << '\n' << std::flush;
}
// Hashes of the cache blocks requested (empty when selective transfer is disabled).
std::vector<size_t> const& RequestInfo::getBlockHashes() const noexcept
{
    return mBlockHashes;
}
using runtime::SizeType32;
using AgentConnectionManager = tensorrt_llm::executor::kv_cache::AgentConnectionManager;
using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
// Builds a per-request message tag: the low 12 bits of the request id occupy bits
// 8..19 and the constant data tag occupies the low byte. Ids equal mod 4096 share a
// tag.
static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId)
{
    constexpr int32_t kDATA_TAG{43};
    auto const low12 = static_cast<int32_t>(requestId & 0xFFF);
    return (low12 << 8) | (kDATA_TAG & 0xFF);
}
// Per-peer receive-side resources: a buffer manager (with its own CUDA stream) and a
// CUDA event, created once per peer instance and reused across transfers.
struct ReceiveCacheResource
{
    runtime::BufferManager mBufferManager;
    runtime::CudaEvent mCudaEvent;

    ReceiveCacheResource(runtime::BufferManager&& bufferManager, runtime::CudaEvent&& cudaEvent)
        : mBufferManager(std::move(bufferManager)) // fix: was copy-initialized despite the rvalue parameter
        , mCudaEvent(std::move(cudaEvent))
    {
    }
};
RequestInfo::RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTransceiverState transState)
: mRequestId{requestId}
@ -92,82 +224,128 @@ std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo)
return totalSize;
}
void TransferSession::appendMeasure(double delay, double duration, size_t size)
{
if (!mRecordMeasure)
{
return;
}
auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps
mMeasures.emplace_back(Measure{delay, duration, bandwidth});
}
void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const
{
if (mMeasures.empty())
{
return;
}
// write header if not exist
if (outFile.tellp() == 0)
{
outFile << "RequestID";
for (size_t i = 0; i < mMeasures.size(); i++)
{
outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
}
outFile << '\n';
}
// write measures
TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value());
auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId();
outFile << reqId;
for (auto const& measure : mMeasures)
{
outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
}
outFile << '\n' << std::flush;
}
class DataResponder::Impl
class CacheSender::Impl
{
public:
using RequestIdType = LlmRequest::RequestIdType;
/// Sets up the sender: records the connection manager, seeds the self state with the
/// manager's comm state, takes ownership of the formatter, and launches the
/// background response loop as an async task.
/// (Removed interleaved pre-refactor ctor lines that referenced the deleted
/// DataSender member.)
Impl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
    SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
    : mManager{manager}
    , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
    , mFormatter{std::move(formatter)}
    , mBufferManager{std::make_shared<runtime::CudaStream>()}
{
    TLLM_CHECK(mManager);
    TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
    TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
    mCurrentRequest = std::nullopt;
    // The response loop runs for the lifetime of this Impl.
    mResponseFuture = std::async(std::launch::async, &Impl::response, this);
}
[[nodiscard]] std::future<void> respondAndSendAsync(LlmRequest& llmRequest)
[[nodiscard]] std::future<void> sendAsync(LlmRequest& llmRequest)
{
std::promise<void> promise;
auto future = promise.get_future();
{
{
std::unique_lock lkResp(mResponderMutex);
std::unique_lock lkResp(mSenderMutex);
mReadyResponses.emplace(
llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)});
}
std::unique_lock lkCond(mCondMutex);
mAnyReady = true;
}
mResponderCv.notify_all();
mSenderCv.notify_all();
return future;
}
/// Communicator state advertised to peers; value() asserts the state is present.
/// (Removed the unreachable pre-rename return through the deleted mSender member.)
[[nodiscard]] executor::kv_cache::CommState const& getCommState() const
{
    return mSelfState.getCommState().value();
}
/// Overrides the communicator state stored in the self state.
/// (Removed the stale pre-rename call through the deleted mSender member.)
void setCommState(executor::kv_cache::CommState commState)
{
    mSelfState.setCommState(std::move(commState));
}
// Number of peer connections participating in the transfer for `requestId`; asserts
// the session exists.
// NOTE(review): reads mRequestToSession without holding mMtxForMap — confirm all
// callers run on the single response thread (recvRequestInfo inserts under the lock).
[[nodiscard]] size_t getCounterpartsCount(LlmRequest::RequestIdType requestId) const
{
    auto it = mRequestToSession.find(requestId);
    TLLM_CHECK(it != mRequestToSession.end());
    return it->second.getConnections().size();
}
/// Drops the transfer session associated with `requestId`; asserts it exists.
/// Fix: the lookup previously happened BEFORE taking mMtxForMap, racing with the
/// insertions recvRequestInfo performs under that lock (release can run on a
/// detached parallel-send thread). Lock first, then find and erase.
void release(LlmRequest::RequestIdType requestId)
{
    std::unique_lock<std::mutex> lk(mMtxForMap);
    auto it = mRequestToSession.find(requestId);
    TLLM_CHECK(it != mRequestToSession.end());
    mRequestToSession.erase(it);
}
/// Blocks until a peer announces a cache request, registers (or completes) the
/// corresponding TransferSession, and returns the parsed RequestInfo.
/// Agent connections deliver the request info in-band; other transports send an id
/// tag followed by a size-prefixed serialized RequestInfo blob.
[[nodiscard]] RequestInfo recvRequestInfo()
{
    auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
    bool isAgent = agentConnectionManager != nullptr;
    auto agentRecvFun = [&](RequestInfo& requestInfo)
    {
        auto const* connection = agentConnectionManager->recvConnectionAndRequestInfo(requestInfo);
        return connection;
    };
    TransceiverTag::Id id;
    RequestInfo info;
    auto const* connection = isAgent ? agentRecvFun(info)
                                     : mManager->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id));
    if (!isAgent)
    {
        // Non-agent path: expect REQUEST_SEND, then the blob size, then the blob.
        TLLM_CHECK(id == TransceiverTag::Id::REQUEST_SEND);
        std::uint64_t infoSize{0};
        connection->recv(
            executor::kv_cache::DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize));
        std::string serializedInfo;
        serializedInfo.resize(infoSize);
        connection->recv(
            executor::kv_cache::DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize);
        std::istringstream iss(serializedInfo);
        info = RequestInfo::deserialize(iss);
    }
    auto requestId = info.getRequestId();
    // Reject peers whose cache layout this formatter cannot serve.
    TLLM_CHECK_WITH_INFO(mFormatter->inquireSupport(
                             mSelfState.getCacheState().value(), info.getTransState().getCacheState().value()),
        "Disagg server does not currently support these cacheState, please check the cacheState of the context and "
        "gen "
        "executors");
    // Position of this peer among the ranks expected to contact us for this layout.
    auto peerRelativeRanks = executor::kv_cache::targetIRanks(info.getTransState().getCacheState().value(),
        mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx())
                                 .mIRanks;
    int peerIdx = std::distance(peerRelativeRanks.begin(),
        std::find(
            peerRelativeRanks.begin(), peerRelativeRanks.end(), info.getTransState().getCommState()->getSelfIdx()));
    {
        // Create the session lazily on first contact; later peers only fill their slot.
        std::unique_lock<std::mutex> lk(mMtxForMap);
        auto it = mRequestToSession.find(requestId);
        if (it == mRequestToSession.end())
        {
            auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
                DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager);
            it = mRequestToSession.emplace(requestId, std::move(session)).first;
        }
        it->second.setConnection(peerIdx, connection);
    }
    return info;
}
// Formats and sends the cache blocks for an already-registered request; asserts the
// session exists (recvRequestInfo must have run for this request id first).
// NOTE(review): looks up mRequestToSession without mMtxForMap — confirm this only
// runs on the response thread after the session was inserted.
void sendSync(LlmRequest const& llmRequest)
{
    auto it = mRequestToSession.find(llmRequest.mRequestId);
    TLLM_CHECK(it != mRequestToSession.end());
    auto& session = it->second;
    session.setLlmRequest(llmRequest);
    mFormatter->format(session);
}
~Impl()
@ -187,8 +365,8 @@ private:
try
{
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
mSender->sendSync(*resp.mRequest);
mSender->release(id);
sendSync(*resp.mRequest);
release(id);
resp.mPromise.set_value();
}
catch (tensorrt_llm::common::RequestSpecificException const& e)
@ -220,12 +398,11 @@ private:
if (common::getEnvParallelCacheSend())
{
// TODO: Use a thread pool and check for thread safety.
std::thread(&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
.detach();
std::thread(&CacheSender::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second)).detach();
}
else
{
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
CacheSender::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
}
removeResponse(it);
}
@ -243,7 +420,7 @@ private:
if (!mAnyReady)
{
std::unique_lock lk(mCondMutex);
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
mSenderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
}
if (mTerminate)
{
@ -252,14 +429,14 @@ private:
std::vector<size_t> blockHashes;
if (!isSending() && !mReadyResponses.empty())
{
auto const& requestInfo = mSender->recvRequestInfo();
auto const& requestInfo = recvRequestInfo();
auto reqId = requestInfo.getRequestId();
blockHashes = requestInfo.getBlockHashes();
mCurrentRequest = reqId;
if (mRemainSendCount.find(reqId) == mRemainSendCount.end())
{
mRemainSendCount[reqId] = mSender->getCounterpartsCount(reqId);
mRemainSendCount[reqId] = getCounterpartsCount(reqId);
}
}
auto it = getCurrentResponse();
@ -273,7 +450,7 @@ private:
while (it == mReadyResponses.end())
{
std::unique_lock lk(mCondMutex);
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
mSenderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
if (mTerminate)
{
break;
@ -286,7 +463,7 @@ private:
}
catch (std::exception const& err)
{
TLLM_LOG_ERROR("Exception in DataResponder response: %s", err.what());
TLLM_LOG_ERROR("Exception in CacheSender response: %s", err.what());
for (auto& it : mReadyResponses)
{
it.second.mPromise.set_exception(std::current_exception());
@ -302,13 +479,13 @@ private:
}
// We don't have to wait for the future. If another thread is sending data, it won't pay attention
// to the terminate flag.
mResponderCv.notify_all();
mSenderCv.notify_all();
}
void removeResponse(std::map<RequestIdType, Response>::iterator it)
{
{
std::unique_lock lkResp(mResponderMutex);
std::unique_lock lkResp(mSenderMutex);
mReadyResponses.erase(it);
}
if (mReadyResponses.empty())
@ -330,36 +507,47 @@ private:
/// Looks up the queued response for the in-flight request, guarded by mSenderMutex.
/// (Removed the stale duplicate lock declaration on the deleted mResponderMutex,
/// which redeclared `lk` and could not compile.)
[[nodiscard]] std::map<RequestIdType, Response>::iterator getCurrentResponse()
{
    std::unique_lock lk(mSenderMutex);
    return mReadyResponses.find(getCurrentRequestId());
}
private:
std::optional<RequestIdType> mCurrentRequest;
std::map<RequestIdType, Response> mReadyResponses;
std::mutex mResponderMutex, mCondMutex;
std::mutex mSenderMutex, mCondMutex;
std::atomic<bool> mAnyReady{false}, mTerminate{false};
std::condition_variable mResponderCv;
std::condition_variable mSenderCv;
std::future<void> mResponseFuture;
std::unique_ptr<DataSender> mSender;
std::unordered_map<LlmRequest::RequestIdType, int> mRemainSendCount;
int mDeviceId{-1};
executor::kv_cache::ConnectionManager* mManager;
std::map<LlmRequest::RequestIdType, TransferSession> mRequestToSession;
executor::DataTransceiverState mSelfState;
std::unique_ptr<BaseCacheFormatter> mFormatter;
std::mutex mMtxForMap;
runtime::BufferManager mBufferManager;
};
class DataRequester::Impl
class CacheReceiver::Impl
{
public:
/// Sets up the receiver: records the connection manager, seeds the self state with
/// the manager's comm state, and takes ownership of the formatter.
/// (Removed interleaved pre-refactor ctor lines that referenced the deleted
/// DataReceiver member.)
Impl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
    SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
    : mManager{manager}
    , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
    , mFormatter{std::move(formatter)}
    , mBufferManager{std::make_shared<runtime::CudaStream>()}
{
    TLLM_CHECK(mManager);
    TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
    TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
}
[[nodiscard]] std::future<void> requestAndReceiveAsync(LlmRequest& llmRequest)
[[nodiscard]] std::future<void> receiveAsync(LlmRequest& llmRequest)
{
// TODO: Modify the implementation here to avoid frequent thread creation.
return std::async(std::launch::async, &DataRequester::Impl::requestSync, this, std::ref(llmRequest));
return std::async(std::launch::async, &CacheReceiver::Impl::requestSync, this, std::ref(llmRequest));
}
[[nodiscard]] std::future<void> requestAndReceiveAsyncMultiThreads(LlmRequest& llmRequest)
@ -378,7 +566,7 @@ public:
{
mInstanceToAsyncResource.emplace(processInfo, std::make_unique<AsyncResource>());
auto requestFuture = std::async(std::launch::async, &DataRequester::Impl::request, this,
auto requestFuture = std::async(std::launch::async, &CacheReceiver::Impl::request, this,
std::ref(*mInstanceToAsyncResource.at(processInfo)));
mRequestFutures.emplace_back(std::move(requestFuture));
}
@ -396,6 +584,107 @@ public:
}
}
// Runs the formatter's unformat step, moving the received bytes into the local cache.
void receiveSync(TransferSession& session)
{
    mFormatter->unformat(session);
}
/// Announces a cache request to every counterpart rank and builds the TransferSession
/// that will receive the blocks. Agent connections get the request plus buffer info
/// via sendRequestAndBufferInfo; other transports get the serialized RequestInfo.
TransferSession sendRequestInfo(LlmRequest const& llmRequest)
{
    uint64_t requestId = llmRequest.getContextPhaseParams().value().getReqId();
    auto const& contextState = llmRequest.getDataTransceiverState();
    auto const& commState = contextState.getCommState().value();
    auto const& destCacheState = contextState.getCacheState().value();
    TLLM_CHECK_WITH_INFO(mFormatter->inquireSupport(mSelfState.getCacheState().value(), destCacheState),
        "Disagg server does not currently support these cacheState.");
    RequestInfo requestInfo(requestId, mSelfState);
    // Selective transfer attaches block hashes so the sender can skip reused blocks;
    // it is disabled by env var or when the cache has multiple pools.
    auto disableSelectiveCacheTransfer = common::getEnvDisableSelectiveCacheTransfer()
        || (mFormatter->getCacheManager()->getBlockManager().getNumPools() > 1);
    if (!disableSelectiveCacheTransfer)
    {
        auto* cacheManager = mFormatter->getCacheManager();
        auto blockRange
            = kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
        requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState);
    }
    auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
    std::optional<size_t> cacheBufferId = std::nullopt;
    if (agentConnectionManager != nullptr)
    {
        // Agent transport pre-assigns a receive buffer; its index travels with the
        // request so the peer can target it.
        cacheBufferId = agentConnectionManager->getCacheTransBufferManager()->assignBufferIndexForRecv();
        TLLM_CHECK(cacheBufferId.has_value());
    }
    auto counterParts = mFormatter->getCounterparts(
        mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx(), destCacheState);
    auto connections = mManager->getConnections(commState);
    std::vector<executor::kv_cache::Connection const*> counterPartConnections;
    for (auto index : counterParts)
    {
        auto const* connection = connections.at(index);
        counterPartConnections.emplace_back(connection);
    }
    auto pickUpIdx = mFormatter->pickRecvConnections(counterParts.size(), mSelfState.getCacheState().value(),
        mSelfState.getCommState().value().getSelfIdx(), destCacheState);
    for (size_t i = 0; i < counterPartConnections.size(); i++)
    {
        auto const* connection = counterPartConnections[i];
        // if Manager is agentConnectionManager, then send request info to agent
        // NOTE(review): this dynamic_cast shadows the outer agentConnectionManager —
        // same value, could be hoisted.
        auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
        if (agentConnectionManager != nullptr)
        {
            // TODO: index -> validConnectionIdx conversion
            auto valideConnectionIdx = std::find(pickUpIdx.begin(), pickUpIdx.end(), i) - pickUpIdx.begin();
            auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connection);
            TLLM_CHECK(agentConnection != nullptr);
            TLLM_CHECK(cacheBufferId.has_value());
            const_cast<executor::kv_cache::AgentConnection*>(agentConnection)
                ->sendRequestAndBufferInfo(requestInfo, cacheBufferId, valideConnectionIdx);
        }
        else
        {
            sendRequestInfo(connection, requestInfo);
        }
    }
    auto const& resource = getReceiveCacheResource(llmRequest);
    return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
        contextState, resource->mBufferManager, &llmRequest);
}
/// Returns the per-peer receive resources, creating them on first use under the
/// resource mutex. All requests share one "default" resource unless concurrent
/// per-instance KV-cache transfers are enabled, in which case resources are keyed by
/// the peer's comm-state string.
std::unique_ptr<ReceiveCacheResource> const& getReceiveCacheResource(LlmRequest const& llmRequest)
{
    std::scoped_lock<std::mutex> lock(mProcessIoResouceMutex);
    TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value());
    std::string processString = "default";
    if (common::getEnvRequestKVCacheConcurrent())
    {
        processString = llmRequest.getDataTransceiverState().getCommState()->toString();
    }
    // Single lookup and insert-if-absent (previously find() + emplace() + at()
    // performed up to three hash lookups per call).
    auto it = mProcessToResources.find(processString);
    if (it == mProcessToResources.end())
    {
        it = mProcessToResources
                 .emplace(processString,
                     std::make_unique<ReceiveCacheResource>(
                         runtime::BufferManager{std::make_shared<runtime::CudaStream>()}, runtime::CudaEvent{}))
                 .first;
    }
    return it->second;
}
// Sends a serialized RequestInfo to one peer. Wire format: REQUEST_SEND id tag,
// then the blob's size, then the blob itself.
void sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info)
{
    std::ostringstream oss;
    RequestInfo::serialize(info, oss);
    auto const& serializedInfo = oss.str();
    std::size_t const infoSize = serializedInfo.size();
    TransceiverTag::Id id{TransceiverTag::Id::REQUEST_SEND};
    connection->send(executor::kv_cache::DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id));
    connection->send(executor::kv_cache::DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize));
    connection->send(executor::kv_cache::DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize);
}
~Impl()
{
for (auto&& [processInfo, asyncResource] : mInstanceToAsyncResource)
@ -417,8 +706,8 @@ private:
llmRequest.getContextPhaseParams().value().getReqId());
llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now());
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
auto session = mReceiver->sendRequestInfo(llmRequest);
mReceiver->receiveSync(session);
auto session = sendRequestInfo(llmRequest);
receiveSync(session);
llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
@ -524,7 +813,7 @@ private:
}
catch (std::exception const& err)
{
TLLM_LOG_ERROR("Exception in DataRequester request(): request id:%ld , request context id:%ld : %s",
TLLM_LOG_ERROR("Exception in CacheReceiver request(): request id:%ld , request context id:%ld : %s",
requestAndPromise.mRequest->mRequestId,
requestAndPromise.mRequest->getContextPhaseParams().value().getReqId(), err.what());
requestAndPromise.mPromise->set_exception(std::current_exception());
@ -533,45 +822,81 @@ private:
}
}
std::unique_ptr<DataReceiver> mReceiver;
int mDeviceId{-1};
std::vector<std::future<void>> mRequestFutures;
std::unordered_map<std::string, std::unique_ptr<AsyncResource>> mInstanceToAsyncResource;
executor::kv_cache::ConnectionManager* mManager;
executor::DataTransceiverState mSelfState;
std::unique_ptr<BaseCacheFormatter> mFormatter;
std::unordered_map<std::string, std::unique_ptr<ReceiveCacheResource>> mProcessToResources;
std::mutex mProcessIoResouceMutex;
runtime::BufferManager mBufferManager;
};
DataResponder::DataResponder(std::unique_ptr<DataSender> sender)
: mImpl{std::make_unique<Impl>(std::move(sender))}
// Out-of-line deleter so the header can hold unique_ptr<Impl> with Impl incomplete.
void CacheSender::ImplDeleter::operator()(Impl* ptr)
{
    delete ptr;
}
void CacheReceiver::ImplDeleter::operator()(Impl* ptr)
{
delete ptr;
}
CacheSender::CacheSender(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
: mImpl{std::unique_ptr<Impl, ImplDeleter>(new Impl(manager, selfCacheState, selfIndex, std::move(formatter)))}
{
}
std::future<void> DataResponder::respondAndSendAsync(LlmRequest& llmRequest) const
std::future<void> CacheSender::sendAsync(LlmRequest& llmRequest) const
{
return mImpl->respondAndSendAsync(llmRequest);
return mImpl->sendAsync(llmRequest);
}
executor::kv_cache::CommState const& DataResponder::getCommState() const
executor::kv_cache::CommState const& CacheSender::getCommState() const
{
return mImpl->getCommState();
}
void DataResponder::setCommState(executor::kv_cache::CommState commState)
void CacheSender::setCommState(executor::kv_cache::CommState commState)
{
mImpl->setCommState(std::move(commState));
}
DataResponder::~DataResponder() = default;
CacheSender::~CacheSender() = default;
DataRequester::DataRequester(std::unique_ptr<DataReceiver> receiver)
: mImpl{std::make_unique<Impl>(std::move(receiver))}
/// @brief Synchronously send the KV cache for @p llmRequest; forwards to the pimpl.
void CacheSender::sendSync(LlmRequest const& llmRequest)
{
    mImpl->sendSync(llmRequest);
}

/// @brief Receive the next incoming request announcement; forwards to the pimpl.
RequestInfo CacheSender::recvRequestInfo()
{
    return mImpl->recvRequestInfo();
}
CacheReceiver::CacheReceiver(executor::kv_cache::ConnectionManager* manager,
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
: mImpl{std::unique_ptr<Impl, ImplDeleter>(new Impl(manager, selfCacheState, selfIndex, std::move(formatter)))}
{
}
std::future<void> DataRequester::requestAndReceiveAsync(LlmRequest& llmRequest) const
std::future<void> CacheReceiver::receiveAsync(LlmRequest& llmRequest) const
{
return mImpl->requestAndReceiveAsyncMultiThreads(llmRequest);
}
DataRequester::~DataRequester() = default;
CacheReceiver::~CacheReceiver() = default;
/// @brief Announce @p llmRequest to the counterpart and open a transfer session; forwards to the pimpl.
TransferSession CacheReceiver::sendRequestInfo(LlmRequest const& llmRequest)
{
    return mImpl->sendRequestInfo(llmRequest);
}

/// @brief Synchronously receive the KV cache for an open @p session; forwards to the pimpl.
void CacheReceiver::receiveSync(TransferSession& session)
{
    mImpl->receiveSync(session);
}
} // namespace tensorrt_llm::batch_manager

View File

@ -20,6 +20,7 @@
#include <future>
#include <map>
#include <string>
#include <vector>
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/assert.h"
@ -28,16 +29,100 @@
#include "tensorrt_llm/executor/cacheCommunicator.h"
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/executor/serializeUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
namespace tensorrt_llm::batch_manager
{
namespace kv_cache_manager
{
class BaseCacheFormatter;
}
using BaseCacheFormatter = kv_cache_manager::BaseCacheFormatter;
// TODO: unify the following class into a namespace like tensorrt_llm::transmission
using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
using Connection = tensorrt_llm::executor::kv_cache::Connection;
using ConnectionManager = tensorrt_llm::executor::kv_cache::ConnectionManager;
using SizeType32 = tensorrt_llm::runtime::SizeType32;
/// @brief One KV-cache transfer between a sender and a receiver.
///
/// Holds non-owning pointers to the peer connections, the self transceiver
/// state and the buffer manager (all must outlive the session), a copy of the
/// peer's state, and — when recordMeasure is set — per-transfer timing records.
class TransferSession
{
public:
    /// Timing record for one transferred chunk.
    struct Measure
    {
        double delay;     // from last token (ctx) or arrival time (gen), in ms
        double duration;  // in ms
        double bandwidth; // in Gbps
    };

    /// @param connections Peer connections (non-owning); must be non-empty.
    /// @param dataContext Tag/context used for every send/recv of this session.
    /// @param selfState This transceiver's state (held by pointer).
    /// @param otherState The peer transceiver's state (copied).
    /// @param bufferManager Staging buffer manager (held by pointer).
    /// @param llmRequest Optional request; may be attached later via setLlmRequest().
    /// @param recordMeasure Whether appendMeasure() should record entries.
    TransferSession(std::vector<Connection const*> connections, DataContext dataContext,
        executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState,
        runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false)
        : mConnections(std::move(connections))
        , mDataContext(dataContext)
        , mSelfState(&selfState)
        , mOtherState(std::move(otherState))
        , mBufferManager(&bufferManager)
        , mRequest(llmRequest)
        , mRecordMeasure(recordMeasure)
    {
        TLLM_CHECK(!mConnections.empty());
    }

    [[nodiscard]] std::vector<Connection const*> const& getConnections() const;

    // should be called only during the initialization of the TransferSession
    void setConnection(size_t idx, Connection const* conn);

    [[nodiscard]] DataContext const& getDataContext() const;

    [[nodiscard]] executor::DataTransceiverState const& getSelfState() const;

    [[nodiscard]] executor::DataTransceiverState const& getOtherState() const;

    [[nodiscard]] runtime::BufferManager const& getBufferManager() const;

    /// Send through connection @p idx using this session's data context.
    void send(size_t idx, void const* data, size_t size);

    /// Receive through connection @p idx using this session's data context.
    void recv(size_t idx, void* data, size_t size);

    /// @return The attached request (see setLlmRequest()).
    [[nodiscard]] LlmRequest const& getLlmRequest() const;

    // in CacheSender, the LlmRequest is not available until the sendSync is called
    void setLlmRequest(LlmRequest const& llmRequest);

    /// Record one timing entry (delay ms, duration ms, transferred bytes).
    void appendMeasure(double delay, double duration, size_t size);

    // TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file
    void exportMeasure(std::ofstream& outFile, bool isContext) const;

private:
    std::vector<Connection const*> mConnections; // non-owning
    DataContext mDataContext;
    executor::DataTransceiverState const* mSelfState; // stored in CacheReceiver/CacheSender
    executor::DataTransceiverState mOtherState;
    runtime::BufferManager const* mBufferManager; // non-owning
    LlmRequest const* mRequest;                   // non-owning; may be null until attached
    std::vector<Measure> mMeasures;
    bool mRecordMeasure{false};
};
/// @brief Message ids and wire tags shared by the cache sender/receiver for the
/// request-info handshake: the Id is sent under kID_TAG, then the serialized
/// RequestInfo's byte count under kINFO_SIZE_TAG, then the payload under kINFO_TAG.
struct TransceiverTag
{
    /// First message sent on a new connection.
    enum class Id : uint64_t
    {
        REQUEST_SEND = 1,
        TERMINATION = 2
    };

    static constexpr int32_t kID_TAG{19};        // carries the Id value
    static constexpr int32_t kINFO_SIZE_TAG{22}; // carries the serialized-info byte count
    static constexpr int32_t kINFO_TAG{32};      // carries the serialized RequestInfo payload
};
// Used to store the information that needs to be sent to the context executor to ensure the generation
// executor smoothly receives the data.
@ -61,10 +146,7 @@ public:
/// @return The request ID.
[[nodiscard]] LlmRequest::RequestIdType getRequestId() const noexcept;
[[nodiscard]] std::vector<size_t> const& getBlockHashes() const noexcept
{
return mBlockHashes;
}
[[nodiscard]] std::vector<size_t> const& getBlockHashes() const noexcept;
/// @brief Return the state of the data transceiver.
/// @return The state of the data transceiver.
@ -94,207 +176,81 @@ private:
executor::DataTransceiverState mTransState;
};
class TransferSession
{
public:
struct Measure
{
double delay; // from last token (ctx) or arrival time (gen), in ms
double duration; // in ms
double bandwidth; // in Gbps
};
TransferSession(std::vector<Connection const*> connections, DataContext dataContext,
executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState,
runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false)
: mConnections(std::move(connections))
, mDataContext(dataContext)
, mSelfState(&selfState)
, mOtherState(std::move(otherState))
, mBufferManager(&bufferManager)
, mRequest(llmRequest)
, mRecordMeasure(recordMeasure)
{
TLLM_CHECK(!mConnections.empty());
}
[[nodiscard]] std::vector<Connection const*> const& getConnections() const
{
return mConnections;
}
// should be called only during the initialization of the TransferSession
void setConnection(size_t idx, Connection const* conn)
{
mConnections.at(idx) = conn;
}
[[nodiscard]] DataContext const& getDataContext() const
{
return mDataContext;
}
[[nodiscard]] executor::DataTransceiverState const& getSelfState() const
{
return *mSelfState;
}
[[nodiscard]] executor::DataTransceiverState const& getOtherState() const
{
return mOtherState;
}
[[nodiscard]] runtime::BufferManager const& getBufferManager() const
{
return *mBufferManager;
}
void send(size_t idx, void const* data, size_t size)
{
try
{
mConnections.at(idx)->send(mDataContext, data, size);
}
catch (std::exception const& e)
{
throw common::RequestSpecificException(
__FILE__, __LINE__, e.what(), mRequest->mRequestId, common::RequestErrorCode::kNETWORK_ERROR);
}
}
void recv(size_t idx, void* data, size_t size)
{
try
{
mConnections.at(idx)->recv(mDataContext, data, size);
}
catch (std::exception const& e)
{
throw common::RequestSpecificException(
__FILE__, __LINE__, e.what(), mRequest->mRequestId, common::RequestErrorCode::kNETWORK_ERROR);
}
}
[[nodiscard]] LlmRequest const& getLlmRequest() const
{
TLLM_CHECK(mRequest != nullptr);
return *mRequest;
}
// in DataSender, the LlmRequest is not available until the sendSync is called
void setLlmRequest(LlmRequest const& llmRequest)
{
mRequest = &llmRequest;
}
void appendMeasure(double delay, double duration, size_t size);
// TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file
void exportMeasure(std::ofstream& outFile, bool isContext) const;
private:
std::vector<Connection const*> mConnections;
DataContext mDataContext;
executor::DataTransceiverState const* mSelfState; // stored in DataRequester/DataResponder
executor::DataTransceiverState mOtherState;
runtime::BufferManager const* mBufferManager;
LlmRequest const* mRequest;
bool mRecordMeasure;
std::vector<Measure> mMeasures;
};
// Operators required for data transmission in specific communication protocols.
class DataSender
{
public:
/// @brief Receive the request information.
/// @return The request information.
[[nodiscard]] virtual RequestInfo recvRequestInfo() = 0;
/// @brief Synchronously send data.
/// @param llmRequest The request object to which the data belongs.
virtual void sendSync(LlmRequest const& llmRequest) = 0;
/// @brief Return the internal communicator status.
/// @return The communicator status.
[[nodiscard]] virtual executor::kv_cache::CommState const& getCommState() const = 0;
/// @brief Reset the internal communicator status.
/// @param commState The communicator status.
virtual void setCommState(executor::kv_cache::CommState commState) = 0;
[[nodiscard]] virtual size_t getCounterpartsCount(LlmRequest::RequestIdType requestId) const = 0;
virtual void release(LlmRequest::RequestIdType requestId) = 0;
/// @brief Destructor.
virtual ~DataSender() = default;
};
// Operators required for data transmission in specific communication protocols.
class DataReceiver
{
public:
/// @brief Send the request information.
/// @param llmRequest The request object to which the information belongs.
virtual TransferSession sendRequestInfo(LlmRequest const& llmRequest) = 0;
/// @brief Synchronously receive data.
/// @param session The transfer session.
virtual void receiveSync(TransferSession& session) = 0;
/// @brief Destructor.
virtual ~DataReceiver() = default;
};
class DataResponder
class CacheSender
{
public:
/// @brief Constructor.
/// @param sender The sender used at the underlying level.
explicit DataResponder(std::unique_ptr<DataSender> sender);
CacheSender(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);
CacheSender() = default;
/// @brief Asynchronously respond to the request and send data.
/// @param llmRequest Request object. Its data should be ready when called, and the data for this request
/// should remain valid until future synchronization.
/// @return Once the data is fully sent, the future object will become valid.
[[nodiscard]] std::future<void> respondAndSendAsync(LlmRequest& llmRequest) const;
[[nodiscard]] virtual std::future<void> sendAsync(LlmRequest& llmRequest) const;
/// @brief Return the internal communicator status.
/// @return The communicator status.
[[nodiscard]] executor::kv_cache::CommState const& getCommState() const;
[[nodiscard]] virtual executor::kv_cache::CommState const& getCommState() const;
/// @brief Reset the internal communicator status.
/// @param commState The communicator status.
void setCommState(executor::kv_cache::CommState commState);
virtual void setCommState(executor::kv_cache::CommState commState);
/// @brief Synchronously send data.
/// @param llmRequest The request object to which the data belongs.
virtual void sendSync(LlmRequest const& llmRequest);
/// @brief Receive request information.
/// @param llmRequest The request object to which the data belongs.
virtual RequestInfo recvRequestInfo();
/// @brief Destructor.
~DataResponder();
virtual ~CacheSender();
private:
class Impl;
std::unique_ptr<Impl> mImpl;
struct ImplDeleter
{
void operator()(Impl*);
};
std::unique_ptr<Impl, ImplDeleter> mImpl;
};
class DataRequester
class CacheReceiver
{
public:
/// @brief Constructor.
/// @param receiver The receiver used at the underlying level.
explicit DataRequester(std::unique_ptr<DataReceiver> receiver);
CacheReceiver(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);
CacheReceiver() = default;
/// @brief Asynchronously send a request to receive data.
/// @param llmRequest Request object. Its data should be in an allocated but unwritten state when called, and the
/// data for this request should remain intact only after future synchronization.
/// @return Once the data is fully received, the future object will become valid.
[[nodiscard]] std::future<void> requestAndReceiveAsync(LlmRequest& llmRequest) const;
[[nodiscard]] virtual std::future<void> receiveAsync(LlmRequest& llmRequest) const;
virtual TransferSession sendRequestInfo(LlmRequest const& llmRequest);
virtual void receiveSync(TransferSession& session);
/// @brief Destructor.
~DataRequester();
virtual ~CacheReceiver();
private:
class Impl;
std::unique_ptr<Impl> mImpl;
struct ImplDeleter
{
void operator()(Impl*);
};
std::unique_ptr<Impl, ImplDeleter> mImpl;
};
} // namespace tensorrt_llm::batch_manager

View File

@ -1,285 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataTransceiverImpl.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <filesystem>
namespace tensorrt_llm::batch_manager
{
// Build the per-request wire tag: the low byte is the fixed data tag, the next
// twelve bits come from the request id (ids colliding mod 4096 share a tag).
static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId)
{
    constexpr int32_t kDATA_TAG{43};
    auto const requestBits = static_cast<int32_t>(requestId & 0xFFF);
    return (requestBits << 8) | (kDATA_TAG & 0xFF);
}
namespace fs = std::filesystem;
// Resolve the per-rank CSV path for KV-cache transfer measurements, creating
// the output directory on demand. Returns an empty path when the output
// directory environment variable is unset.
static fs::path getTransferOutputPath(char const* tag)
{
    auto const outputDir = common::getEnvKVCacheTransferOutputPath();
    if (outputDir.empty())
    {
        return {};
    }
    fs::path dir{outputDir};
    fs::create_directories(dir);
    auto const fileName = "rank_" + std::to_string(mpi::MpiComm::world().getRank()) + "_" + tag + ".csv";
    return dir / fileName;
}
/// @brief Construct the sender side of the KV-cache transceiver.
/// @param manager Connection manager; borrowed pointer — assumed to outlive this object (TODO confirm).
/// @param selfCacheState Cache state of this executor (moved into mSelfState).
/// @param selfIndex Rank of this process; must match the comm state's self index.
/// @param formatter Serializes KV-cache blocks onto connections.
DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
    executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
    : mManager{manager}
    , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
    , mFormatter(std::move(formatter))
    // Dedicated CUDA stream backing the staging buffer manager.
    , mBufferManager{std::make_shared<runtime::CudaStream>()}
{
    TLLM_CHECK(mManager);
    // Caller-provided rank must agree with the connection manager's view.
    TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
}
/// @brief Receive a peer's request announcement, register its connection in the
/// per-request TransferSession (creating the session on first peer), and return
/// the parsed RequestInfo.
[[nodiscard]] RequestInfo DataSenderImpl::recvRequestInfo()
{
    using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
    // Agent-based managers deliver connection + request info in one call;
    // otherwise the handshake is performed manually over the raw connection.
    auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
    bool isAgent = agentConnectionManager != nullptr;
    auto agentRecvFun = [&](RequestInfo& requestInfo)
    {
        auto const* connection = agentConnectionManager->recvConnectionAndRequestInfo(requestInfo);
        return connection;
    };
    Id id;
    RequestInfo info;
    auto const* connection
        = isAgent ? agentRecvFun(info) : mManager->recvConnect(DataContext{kID_TAG}, &id, sizeof(id));
    if (!isAgent)
    {
        // Manual handshake: Id first, then the payload size (u64), then the
        // serialized RequestInfo itself.
        TLLM_CHECK(id == Id::REQUEST_SEND);
        std::uint64_t infoSize{0};
        connection->recv(executor::kv_cache::DataContext{kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize));
        std::string serializedInfo;
        serializedInfo.resize(infoSize);
        connection->recv(executor::kv_cache::DataContext{kINFO_TAG}, serializedInfo.data(), infoSize);
        std::istringstream iss(serializedInfo);
        info = RequestInfo::deserialize(iss);
    }
    auto requestId = info.getRequestId();
    TLLM_CHECK_WITH_INFO(
        mFormatter->inquireSupport(mSelfState.getCacheState().value(), info.getTransState().getCacheState().value()),
        "Disagg server does not currently support these cacheState, please check the cacheState of the context and gen "
        "executors");
    // Locate the announcing peer among the ranks that target this rank so the
    // connection can be slotted at a stable index in the session.
    auto peerRelativeRanks = executor::kv_cache::targetIRanks(info.getTransState().getCacheState().value(),
        mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx())
                                 .mIRanks;
    int peerIdx = std::distance(peerRelativeRanks.begin(),
        std::find(
            peerRelativeRanks.begin(), peerRelativeRanks.end(), info.getTransState().getCommState()->getSelfIdx()));
    {
        // Lazily create the session on the first peer, under the map mutex.
        std::unique_lock<std::mutex> lk(mMtxForMap);
        auto it = mRequestToSession.find(requestId);
        if (it == mRequestToSession.end())
        {
            auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
                DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, nullptr,
                !common::getEnvKVCacheTransferOutputPath().empty());
            it = mRequestToSession.emplace(requestId, std::move(session)).first;
        }
        it->second.setConnection(peerIdx, connection);
    }
    return info;
}
/// @brief Format and send the KV cache for @p llmRequest over its session.
/// The session must already exist (created by recvRequestInfo()).
/// NOTE(review): reads mRequestToSession without holding mMtxForMap —
/// presumably safe because recvRequestInfo() for this id completed first; confirm.
void DataSenderImpl::sendSync(LlmRequest const& llmRequest)
{
    auto it = mRequestToSession.find(llmRequest.mRequestId);
    TLLM_CHECK(it != mRequestToSession.end());
    auto& session = it->second;
    // Attach the request now; it was unknown when the session was created.
    session.setLlmRequest(llmRequest);
    mFormatter->format(session);
}
/// @return This sender's communicator state.
[[nodiscard]] executor::kv_cache::CommState const& DataSenderImpl::getCommState() const
{
    return mSelfState.getCommState().value();
}

/// @brief Replace this sender's communicator state.
void DataSenderImpl::setCommState(executor::kv_cache::CommState commState)
{
    mSelfState.setCommState(std::move(commState));
}
/// @brief Number of peer connections registered for @p requestId's session.
/// NOTE(review): reads mRequestToSession without holding mMtxForMap — presumably
/// only called once the session is fully established; confirm.
[[nodiscard]] size_t DataSenderImpl::getCounterpartsCount(LlmRequest::RequestIdType requestId) const
{
    auto it = mRequestToSession.find(requestId);
    TLLM_CHECK(it != mRequestToSession.end());
    return it->second.getConnections().size();
}
/// @brief Drop the transfer session for @p requestId, first exporting its
/// timing measurements when the transfer-output path env var is set.
/// @param requestId Id of a request whose session must exist (checked).
void DataSenderImpl::release(LlmRequest::RequestIdType requestId)
{
    // Fix: take mMtxForMap BEFORE the lookup. recvRequestInfo() inserts into
    // mRequestToSession under this mutex, so the original's unlocked find()
    // raced with concurrent insertions.
    std::unique_lock<std::mutex> lk(mMtxForMap);
    auto it = mRequestToSession.find(requestId);
    TLLM_CHECK(it != mRequestToSession.end());
    if (!common::getEnvKVCacheTransferOutputPath().empty())
    {
        // Lazily open the per-rank CSV on first export.
        if (!mMeasuresFile.is_open())
        {
            auto outputPath = getTransferOutputPath("send");
            mMeasuresFile.open(outputPath);
            TLLM_CHECK_WITH_INFO(
                mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str());
        }
        it->second.exportMeasure(mMeasuresFile, true);
    }
    mRequestToSession.erase(it);
}
/// @brief Construct the receiver side of the KV-cache transceiver.
/// @param manager Connection manager; borrowed pointer — assumed to outlive this object (TODO confirm).
/// @param selfCacheState Cache state of this executor (moved into mSelfState).
/// @param selfIndex Rank of this process; must match the comm state's self index.
/// @param formatter Unformats received cache data into the local cache; must be non-null.
DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manager,
    executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
    : mManager{manager}
    , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
    , mFormatter(std::move(formatter))
{
    TLLM_CHECK(mManager);
    // Caller-provided rank must agree with the connection manager's view.
    TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
    TLLM_CHECK(mFormatter);
}
/// @brief Announce @p llmRequest to all counterpart context ranks and open a
/// TransferSession through which the KV cache will be received.
/// @param llmRequest Request carrying the context-phase id and transceiver state.
/// @return A session wired to every counterpart connection.
TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
{
    uint64_t requestId = llmRequest.getContextPhaseParams().value().getReqId();
    auto const& contextState = llmRequest.getDataTransceiverState();
    auto const& commState = contextState.getCommState().value();
    auto const& destCacheState = contextState.getCacheState().value();
    TLLM_CHECK_WITH_INFO(mFormatter->inquireSupport(mSelfState.getCacheState().value(), destCacheState),
        "Disagg server does not currently support these cacheState.");
    RequestInfo requestInfo(requestId, mSelfState);
    // Selective transfer advertises the newly-allocated blocks' hashes; it is
    // disabled via env var or when the cache manager uses more than one pool.
    auto disableSelectiveCacheTransfer = common::getEnvDisableSelectiveCacheTransfer()
        || (mFormatter->getCacheManager()->getBlockManager().getNumPools() > 1);
    if (!disableSelectiveCacheTransfer)
    {
        auto* cacheManager = mFormatter->getCacheManager();
        auto blockRange
            = kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
        requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState);
    }
    // Computed once here; the loop below reuses it (the original redundantly
    // re-ran this dynamic_cast — and shadowed this variable — every iteration).
    auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
    std::optional<size_t> cacheBufferId = std::nullopt;
    if (agentConnectionManager != nullptr)
    {
        cacheBufferId = agentConnectionManager->getCacheTransBufferManager()->assignBufferIndexForRecv();
        TLLM_CHECK(cacheBufferId.has_value());
        // memory Desp , validSegmentIdx send
    }
    auto counterParts = mFormatter->getCounterparts(
        mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx(), destCacheState);
    auto connections = mManager->getConnections(commState);
    std::vector<executor::kv_cache::Connection const*> counterPartConnections;
    for (auto index : counterParts)
    {
        auto const* connection = connections.at(index);
        counterPartConnections.emplace_back(connection);
    }
    auto pickUpIdx = mFormatter->pickRecvConnections(counterParts.size(), mSelfState.getCacheState().value(),
        mSelfState.getCommState().value().getSelfIdx(), destCacheState);
    for (size_t i = 0; i < counterPartConnections.size(); i++)
    {
        auto const* connection = counterPartConnections[i];
        // if Manager is agentConnectionManager, then send request info to agent
        if (agentConnectionManager != nullptr)
        {
            // TODO: index -> validConnectionIdx conversion
            auto validConnectionIdx = std::find(pickUpIdx.begin(), pickUpIdx.end(), i) - pickUpIdx.begin();
            auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connection);
            TLLM_CHECK(agentConnection != nullptr);
            TLLM_CHECK(cacheBufferId.has_value());
            const_cast<executor::kv_cache::AgentConnection*>(agentConnection)
                ->sendRequestAndBufferInfo(requestInfo, cacheBufferId, validConnectionIdx);
        }
        else
        {
            // Raw-connection path: serialize the request info onto the wire.
            sendRequestInfo(connection, requestInfo);
        }
    }
    auto const& resource = getReceiveCacheResource(llmRequest);
    return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
        contextState, resource->mBufferManager, &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty());
}
/// @brief Receive (unformat) the KV cache for @p session, then append the
/// session's timing measurements to the per-rank CSV when export is enabled.
void DataReceiverImpl::receiveSync(TransferSession& session)
{
    mFormatter->unformat(session);
    if (common::getEnvKVCacheTransferOutputPath().empty())
    {
        return;
    }
    // Measurement export enabled: lazily open the output file under its lock.
    std::unique_lock<std::mutex> lock(mMeasuresFileMutex);
    if (!mMeasuresFile.is_open())
    {
        auto outputPath = getTransferOutputPath("recv");
        mMeasuresFile.open(outputPath);
        TLLM_CHECK_WITH_INFO(
            mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str());
    }
    session.exportMeasure(mMeasuresFile, false);
}
/// @brief Serialize @p info and send it over @p connection using the
/// three-message handshake: Id, payload size, payload.
void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info)
{
    std::ostringstream oss;
    RequestInfo::serialize(info, oss);
    auto const& serializedInfo = oss.str();
    // Fix: the receiver (DataSenderImpl::recvRequestInfo) reads this size as a
    // std::uint64_t, so send exactly 64 bits; the original sent a std::size_t,
    // which mismatches on platforms where size_t is not 64 bits.
    std::uint64_t const infoSize = serializedInfo.size();
    Id id{Id::REQUEST_SEND};
    connection->send(executor::kv_cache::DataContext{kID_TAG}, &id, sizeof(id));
    connection->send(executor::kv_cache::DataContext{kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize));
    connection->send(executor::kv_cache::DataContext{kINFO_TAG}, serializedInfo.data(), serializedInfo.size());
}
/// @brief Get (or lazily create) the receive resource for @p llmRequest,
/// keyed by peer comm state when concurrent receives are enabled, otherwise
/// shared under the "default" key. Thread-safe via mProcessIoResouceMutex.
std::unique_ptr<DataReceiverImpl::ReceiveCacheResource> const& DataReceiverImpl::getReceiveCacheResource(
    LlmRequest const& llmRequest)
{
    std::scoped_lock<std::mutex> lock(mProcessIoResouceMutex);
    TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value());
    std::string processString = "default";
    if (common::getEnvRequestKVCacheConcurrent())
    {
        processString = llmRequest.getDataTransceiverState().getCommState()->toString();
    }
    // try_emplace: single map lookup (the original did find + emplace + at);
    // the resource is only constructed when the key is new.
    auto [it, inserted] = mProcessToResources.try_emplace(processString);
    if (inserted)
    {
        it->second = std::make_unique<ReceiveCacheResource>(
            runtime::BufferManager{std::make_shared<runtime::CudaStream>()}, runtime::CudaEvent{});
    }
    return it->second;
}
} // namespace tensorrt_llm::batch_manager

View File

@ -1,113 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cacheFormatter.h"
#include "cacheTransBuffer.h"
#include "dataTransceiver.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
#include <fstream>
namespace tensorrt_llm::batch_manager
{
struct TransceiverTag
{
enum class Id : uint64_t
{
REQUEST_SEND = 1,
TERMINATION = 2
};
static constexpr int32_t kID_TAG{19};
static constexpr int32_t kINFO_SIZE_TAG{22};
static constexpr int32_t kINFO_TAG{32};
};
using BaseCacheFormatter = kv_cache_manager::BaseCacheFormatter;
/// @brief Concrete DataSender: serves context-side KV-cache sends over the
/// connections supplied by a ConnectionManager, one TransferSession per request.
class DataSenderImpl : public DataSender, public TransceiverTag
{
public:
    using SizeType32 = tensorrt_llm::runtime::SizeType32;

    /// @param manager Connection manager (non-owning pointer).
    /// @param selfCacheState Cache state of this executor.
    /// @param selfIndex Rank of this process within the comm state.
    /// @param formatter Serializes KV-cache blocks onto connections.
    DataSenderImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
        SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);

    /// Receive the next request announcement and register its connection.
    [[nodiscard]] RequestInfo recvRequestInfo() override;

    /// Format and send the KV cache for @p llmRequest over its session.
    void sendSync(LlmRequest const& llmRequest) override;

    [[nodiscard]] executor::kv_cache::CommState const& getCommState() const override;

    void setCommState(executor::kv_cache::CommState commState) override;

    /// Number of peer connections registered for @p requestId's session.
    [[nodiscard]] size_t getCounterpartsCount(LlmRequest::RequestIdType requestId) const override;

    /// Drop @p requestId's session, exporting measurements first if enabled.
    void release(LlmRequest::RequestIdType requestId) override;

private:
    executor::kv_cache::ConnectionManager* mManager; // non-owning
    std::map<LlmRequest::RequestIdType, TransferSession> mRequestToSession;
    executor::DataTransceiverState mSelfState;
    std::unique_ptr<BaseCacheFormatter> mFormatter;
    std::mutex mMtxForMap; // guards mRequestToSession
    runtime::BufferManager mBufferManager;
    std::ofstream mMeasuresFile; // lazily opened when measurement export is enabled
};
/// @brief Concrete DataReceiver: announces generation-side requests to context
/// executors and receives the resulting KV-cache data.
class DataReceiverImpl : public DataReceiver, public TransceiverTag
{
public:
    using SizeType32 = tensorrt_llm::runtime::SizeType32;

    /// @param manager Connection manager (non-owning pointer).
    /// @param selfCacheState Cache state of this executor.
    /// @param selfIndex Rank of this process within the comm state.
    /// @param formatter Unformats received cache data into the local cache.
    DataReceiverImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
        SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);

    /// Announce @p llmRequest to its counterpart ranks and open a session.
    TransferSession sendRequestInfo(LlmRequest const& llmRequest) override;

    /// Receive (unformat) the KV cache for an open @p session.
    void receiveSync(TransferSession& session) override;

private:
    /// Per-key receive resources: a buffer manager plus a CUDA event.
    struct ReceiveCacheResource
    {
        runtime::BufferManager mBufferManager;
        runtime::CudaEvent mCudaEvent;

        // NOTE(review): bufferManager is copied here despite the rvalue
        // parameter (no std::move) — confirm whether this is intentional.
        ReceiveCacheResource(runtime::BufferManager&& bufferManager, runtime::CudaEvent&& cudaEvent)
            : mBufferManager(bufferManager)
            , mCudaEvent(std::move(cudaEvent))
        {
        }
    };

    /// Serialize @p info and send it over @p connection (id, size, payload).
    static void sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info);

    /// Get (or lazily create) the receive resource for @p llmRequest.
    [[nodiscard]] std::unique_ptr<ReceiveCacheResource> const& getReceiveCacheResource(LlmRequest const& llmRequest);

    executor::kv_cache::ConnectionManager* mManager; // non-owning
    executor::DataTransceiverState mSelfState;
    std::unique_ptr<BaseCacheFormatter> mFormatter;
    std::unordered_map<std::string, std::unique_ptr<ReceiveCacheResource>> mProcessToResources;
    std::mutex mProcessIoResouceMutex; // guards mProcessToResources
    std::ofstream mMeasuresFile;       // lazily opened when measurement export is enabled
    std::mutex mMeasuresFileMutex;     // guards mMeasuresFile
};
} // namespace tensorrt_llm::batch_manager

View File

@ -84,7 +84,7 @@ bool MLACacheFormatter::needSendCache(
return selfTpRank % (selfTPNum / destTPNum) == 0;
}
void MLACacheFormatter::format(TransferSession& session)
void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
{
NVTX3_SCOPED_RANGE(MLACacheFormatter_format);
auto const& llmRequest = session.getLlmRequest();
@ -292,7 +292,7 @@ void MLACacheFormatter::format(TransferSession& session)
mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID: %ld.", llmRequest.mRequestId);
}
void MLACacheFormatter::unformat(TransferSession& session)
void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
{
NVTX3_SCOPED_RANGE(MLACacheFormatter_unformat);
auto const& llmRequest = session.getLlmRequest();

View File

@ -35,9 +35,9 @@ public:
TLLM_CHECK(mCacheTransBufferManager);
}
void format(TransferSession& session) override;
void format(tensorrt_llm::batch_manager::TransferSession& session) override;
void unformat(TransferSession& session) override;
void unformat(tensorrt_llm::batch_manager::TransferSession& session) override;
[[nodiscard]] bool inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const override;

View File

@ -80,6 +80,8 @@ using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
namespace tk = tensorrt_llm::kernels;
using tensorrt_llm::batch_manager::CacheTransceiverFactory;
namespace tensorrt_llm::batch_manager
{

View File

@ -55,6 +55,11 @@ namespace tensorrt_llm::mpi
class MpiWaitThread;
} // namespace tensorrt_llm::mpi
namespace tensorrt_llm::batch_manager
{
class BaseCacheTransceiver;
}
namespace tensorrt_llm::batch_manager
{
@ -79,7 +84,6 @@ class LlmRequest;
class RuntimeBuffers;
class BasePeftCacheManager;
class GuidedDecoder;
class BaseCacheTransceiver;
// Algorithms
class CapacityScheduler;

View File

@ -104,7 +104,7 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size
MemoryDesc srcDesc{
reinterpret_cast<uintptr_t>(data), size, static_cast<uint32_t>(mAgentConnectionManager->getDeviceId())};
MemoryDescs srcDescs{MemoryType::kVRAM, {srcDesc}};
auto dstBaseDesc = mSenderState.mReceiverBufferDesc;
auto dstBaseDesc = mSenderState.mCacheReceiverBufferDesc;
auto offset = size / mSenderState.mOffsetRatio.second * mSenderState.mOffsetRatio.first;
MemoryDesc dstDesc{dstBaseDesc.getAddr() + offset, size, dstBaseDesc.getDeviceId()};
TLLM_LOG_DEBUG(
@ -162,9 +162,9 @@ void AgentConnection::sendRequestAndBufferInfo(
}
void AgentConnection::setSenderState(
MemoryDesc mReceiverBufferDesc, int validSegmentIdx, std::pair<size_t, size_t> offsetRatio)
MemoryDesc mCacheReceiverBufferDesc, int validSegmentIdx, std::pair<size_t, size_t> offsetRatio)
{
mSenderState.mReceiverBufferDesc = mReceiverBufferDesc;
mSenderState.mCacheReceiverBufferDesc = mCacheReceiverBufferDesc;
mSenderState.validSegmentIdx = validSegmentIdx;
mSenderState.mOffsetRatio = offsetRatio;
}

View File

@ -175,7 +175,8 @@ public:
void recv(DataContext const& ctx, void* data, size_t size) const override;
void sendRequestAndBufferInfo(
batch_manager::RequestInfo& requestInfo, std::optional<size_t> cacheBufferId, int validConnectionIdx);
void setSenderState(MemoryDesc mReceiverBufferDesc, int valideSegmentIdx, std::pair<size_t, size_t> offsetRatio);
void setSenderState(
MemoryDesc mCacheReceiverBufferDesc, int valideSegmentIdx, std::pair<size_t, size_t> offsetRatio);
[[nodiscard]] std::optional<size_t> getCacheBufferId() const;
void setHasLoadRemoteAgent(bool hasLoadRemoteAgent);
[[nodiscard]] bool hasLoadRemoteAgent() const;
@ -186,7 +187,7 @@ private:
struct SenderState
{
MemoryDesc mReceiverBufferDesc{nullptr, 0, 0};
MemoryDesc mCacheReceiverBufferDesc{nullptr, 0, 0};
int validSegmentIdx{0};
std::pair<size_t, size_t> mOffsetRatio;
SenderState() = default;

View File

@ -18,7 +18,7 @@
#include "ucxCacheCommunicator.h"
#if ENABLE_UCX
#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h"
#include "tensorrt_llm/batch_manager/dataTransceiver.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/tllmException.h"
#include "tensorrt_llm/executor/cache_transmission/ucx_utils/connection.h"
@ -114,8 +114,8 @@ void UcxConnection::sendConnectionId(DataContext const& ctx, void const* data, s
std::future<void> future = promise.get_future();
auto completionCallback = [&](ucs_status_t, ucxx::RequestCallbackUserData) -> void { promise.set_value(); };
uint64_t tag
= ((mSendTagPrefix & 0xFFFFFFFF) << 32) | static_cast<uint64_t>(batch_manager::TransceiverTag::kID_TAG);
uint64_t tag = ((mSendTagPrefix & 0xFFFFFFFF) << 32)
| static_cast<uint64_t>(tensorrt_llm::batch_manager::TransceiverTag::kID_TAG);
std::vector<char> buffer(size + sizeof(mConnectionId));
memcpy(buffer.data(), data, size);
memcpy(buffer.data() + size, &mConnectionIdInPeer, sizeof(mConnectionIdInPeer));
@ -133,7 +133,7 @@ void UcxConnection::sendConnectionId(DataContext const& ctx, void const* data, s
void UcxConnection::send(DataContext const& ctx, void const* data, size_t size) const
{
if (ctx.getTag() == batch_manager::TransceiverTag::kID_TAG)
if (ctx.getTag() == tensorrt_llm::batch_manager::TransceiverTag::kID_TAG)
{
sendConnectionId(ctx, data, size);
return;

View File

@ -26,7 +26,7 @@
#include "tensorrt_llm/batch_manager/cacheFormatter.h"
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h"
#include "tensorrt_llm/batch_manager/dataTransceiver.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
@ -45,7 +45,6 @@
#include <gmock/gmock.h>
#include <memory>
#include <random>
#include <tensorrt_llm/batch_manager/dataTransceiverImpl.h>
#include <tensorrt_llm/batch_manager/mlaCacheFormatter.h>
#include <tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h>
@ -84,6 +83,7 @@ class UcxCommTest : public ::testing::Test
};
using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
using TransceiverTag = tensorrt_llm::batch_manager::TransceiverTag;
TEST_F(UcxCommTest, Basic)
{

View File

@ -26,7 +26,6 @@
#include "tensorrt_llm/batch_manager/cacheFormatter.h"
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
@ -154,106 +153,6 @@ TEST_F(CacheConfigTest, EqualTo)
EXPECT_EQ(state0, state1);
}
// ---------------------------------------
// MockTransceiverTest
// ---------------------------------------
class MockDataSender : public DataSender
{
public:
MockDataSender()
{
ON_CALL(*this, getCommState).WillByDefault(ReturnRef(mState));
ON_CALL(*this, recvRequestInfo)
.WillByDefault(Return(RequestInfo{0,
texec::DataTransceiverState{
texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, {10}, nvinfer1::DataType::kFLOAT},
texec::kv_cache::CommState{std::vector<SizeType32>{0}, 0}}}));
ON_CALL(*this, getCounterpartsCount).WillByDefault(Return(1));
}
MOCK_METHOD(RequestInfo, recvRequestInfo, (), (override));
MOCK_METHOD(void, sendSync, (LlmRequest const&), (override));
MOCK_METHOD(texec::kv_cache::CommState const&, getCommState, (), (const));
MOCK_METHOD(void, setCommState, (texec::kv_cache::CommState), (override));
MOCK_METHOD(size_t, getCounterpartsCount, (LlmRequest::RequestIdType), (const));
MOCK_METHOD(void, release, (LlmRequest::RequestIdType), (override));
private:
static texec::kv_cache::CommState mState;
};
texec::kv_cache::CommState MockDataSender::mState;
class MockDataReceiver : public DataReceiver
{
public:
MOCK_METHOD(TransferSession, sendRequestInfo, (LlmRequest const&), (override));
MOCK_METHOD(void, receiveSync, (TransferSession&), (override));
};
class MockTransceiverTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init)
{
public:
void SetUp() override {}
void TearDown() override {}
static auto makeLlmRequest(
LlmRequest::RequestIdType requestId = 0, SizeType32 maxNewTokens = 1, VecTokens inputTokens = {-1})
{
texec::Request request{std::move(inputTokens), maxNewTokens};
auto state = std::make_unique<texec::DataTransceiverState>();
auto stats = texec::ContextPhaseParams({}, requestId, state.release(), std::nullopt);
request.setContextPhaseParams(std::move(stats));
return std::make_unique<LlmRequest>(requestId, std::move(request));
}
};
TEST_F(MockTransceiverTest, MpiResponderBasic)
{
if (tensorrt_llm::mpi::MpiComm::world().getSize() > 2)
{
GTEST_SKIP() << "mpirun with procs<=2 is required to run this test.";
}
auto sender = std::make_unique<MockDataSender>();
EXPECT_CALL(*sender, recvRequestInfo)
.WillOnce(Return(RequestInfo{0,
texec::DataTransceiverState{
texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, {4}, nvinfer1::DataType::kFLOAT},
texec::kv_cache::CommState{std::vector<SizeType32>{0}, 0}}}));
EXPECT_CALL(*sender, sendSync).WillOnce(Return());
EXPECT_CALL(*sender, getCounterpartsCount).WillOnce(Return(1));
EXPECT_CALL(*sender, release).WillOnce(Return());
DataResponder responder{std::move(sender)};
auto request = makeLlmRequest(0);
auto future = responder.respondAndSendAsync(*request);
future.get();
}
TEST_F(MockTransceiverTest, MpiRequesterBasic)
{
if (tensorrt_llm::mpi::MpiComm::world().getSize() > 2)
{
GTEST_SKIP() << "mpirun with procs<=2 is required to run this test.";
}
auto receiver = std::make_unique<MockDataReceiver>();
auto state = std::make_unique<texec::DataTransceiverState>();
state->setCommState(texec::kv_cache::CommState{std::vector<int>{0}});
EXPECT_CALL(*receiver, sendRequestInfo)
.WillOnce(Return(TransferSession({nullptr}, DataContext{0}, *state, *state,
tensorrt_llm::runtime::BufferManager{std::make_shared<tr::CudaStream>()}, nullptr)));
EXPECT_CALL(*receiver, receiveSync).WillOnce(Return());
DataRequester requester{std::move(receiver)};
auto request = makeLlmRequest(0);
auto stats = texec::ContextPhaseParams({}, 0, state.release(), std::nullopt);
request->setContextPhaseParams(std::move(stats));
auto future = requester.requestAndReceiveAsync(*request);
future.get();
}
// TODO: Restore multi-rank tests.
// ---------------------------------------
@ -397,15 +296,13 @@ protected:
mCacheTransBufferManager = std::make_unique<CacheTransBufferManager>(mManager.get(), maxNumTokens);
if (isSender)
{
mResponder = std::make_unique<DataResponder>(
std::make_unique<DataSenderImpl>(mConnectionManager.get(), *mCacheState, mlocalRank,
std::make_unique<CacheFormatter>(mManager.get(), mCacheTransBufferManager.get())));
mSender = std::make_unique<CacheSender>(mConnectionManager.get(), *mCacheState, mlocalRank,
std::make_unique<CacheFormatter>(mManager.get(), mCacheTransBufferManager.get()));
}
else
{
mRequester = std::make_unique<DataRequester>(
std::make_unique<DataReceiverImpl>(mConnectionManager.get(), *mCacheState, mlocalRank,
std::make_unique<CacheFormatter>(mManager.get(), mCacheTransBufferManager.get())));
mRequester = std::make_unique<CacheReceiver>(mConnectionManager.get(), *mCacheState, mlocalRank,
std::make_unique<CacheFormatter>(mManager.get(), mCacheTransBufferManager.get()));
}
}
@ -435,11 +332,11 @@ protected:
// fill cache with tokens (= request length), for reuse test
TLLM_CUDA_CHECK(cudaMemset(block.data(), llmRequest->getPromptLen(), block.getSizeInBytes()));
}
mFutures.emplace_back(mResponder->respondAndSendAsync(*llmRequest));
mFutures.emplace_back(mSender->sendAsync(*llmRequest));
}
else
{
auto future = mRequester->requestAndReceiveAsync(*llmRequest);
auto future = mRequester->receiveAsync(*llmRequest);
future.get();
TLLM_CUDA_CHECK(cudaDeviceSynchronize());
auto blockRange = BlockRange::fromAllBlockIds(*mManager, llmRequest->mRequestId);
@ -460,8 +357,8 @@ protected:
SizeType32 mMaxNumSequences{};
std::unique_ptr<KVCacheManager> mManager;
std::unique_ptr<CacheTransBufferManager> mCacheTransBufferManager;
std::unique_ptr<DataResponder> mResponder;
std::unique_ptr<DataRequester> mRequester;
std::unique_ptr<CacheSender> mSender;
std::unique_ptr<CacheReceiver> mRequester;
std::unique_ptr<texec::kv_cache::CacheState> mCacheState;
std::unique_ptr<texec::kv_cache::CommState> mContextCommState;
std::vector<std::future<void>> mFutures;
@ -789,13 +686,13 @@ protected:
if (mIsContext)
{
mResponder = std::make_unique<DataResponder>(std::make_unique<DataSenderImpl>(
mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter()));
mSender = std::make_unique<CacheSender>(
mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter());
}
else
{
mRequester = std::make_unique<DataRequester>(std::make_unique<DataReceiverImpl>(
mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter()));
mRequester = std::make_unique<CacheReceiver>(
mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter());
}
std::vector<int> contextRankVec(mContextRankSize);
@ -930,7 +827,7 @@ protected:
auto const onlyWindowSize = blockManager.getPoolWindowSize(0);
blockManager.getBufferManager(onlyWindowSize).getStream().synchronize();
auto future = mResponder->respondAndSendAsync(*llmRequest);
auto future = mSender->sendAsync(*llmRequest);
return future;
}
@ -940,7 +837,7 @@ protected:
auto constexpr beamWidth{1};
mManager->addSequence(llmRequest->mRequestId, llmRequest->getNumTokens(beamIdx), beamWidth, llmRequest);
return mRequester->requestAndReceiveAsync(*llmRequest);
return mRequester->receiveAsync(*llmRequest);
}
void generationVerifyKVCache(std::shared_ptr<LlmRequest> const& llmRequest)
@ -1162,8 +1059,8 @@ protected:
SizeType32 mMaxNumSequences{};
std::unique_ptr<KVCacheManager> mManager;
std::unique_ptr<CacheTransBufferManager> mCacheTransBufferManager;
std::unique_ptr<DataResponder> mResponder;
std::unique_ptr<DataRequester> mRequester;
std::unique_ptr<CacheSender> mSender;
std::unique_ptr<CacheReceiver> mRequester;
std::unique_ptr<texec::kv_cache::CacheState> mCacheState;
std::unique_ptr<texec::kv_cache::CacheState> mContextCacheState;
std::unique_ptr<texec::kv_cache::CommState> mContextCommState;