Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

[TRTLLM-8044][refactor] Rename data -> cache for cacheTransceiver (#7659)

This commit is contained in:
parent 8226ef23dc
commit 6ce0624208
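The rename follows one pattern throughout the hunks below: DataResponder becomes CacheSender, DataRequester becomes CacheReceiver, respondAndSendAsync becomes sendAsync, requestAndReceiveAsync becomes receiveAsync, and mResponderFutures becomes mSenderFutures. The DataSender/DataReceiver interfaces and their DataSenderImpl/DataReceiverImpl implementations are folded directly into CacheSender::Impl and CacheReceiver::Impl, so dataTransceiverImpl.cpp is removed from the build and deleted (last hunk).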
@@ -34,8 +34,14 @@ namespace tensorrt_llm::batch_manager

class ContextProgress;
class BaseCacheTransceiver;
class DataResponder;
class DataRequester;

namespace kv_cache_manager
{
class BaseKVCacheManager;
} // namespace kv_cache_manager

class CacheSender;
class CacheReceiver;

class CacheTransceiverFactory
{

@@ -110,9 +116,9 @@ private:

void setContextState(LlmRequest* llmRequest);

std::unique_ptr<DataResponder> mDataResponder;
std::unique_ptr<DataRequester> mDataRequester;
std::vector<std::pair<LlmRequest*, std::future<void>>> mResponderFutures;
std::unique_ptr<CacheSender> mCacheSender;
std::unique_ptr<CacheReceiver> mCacheReceiver;
std::vector<std::pair<LlmRequest*, std::future<void>>> mSenderFutures;
std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr};
std::shared_ptr<mpi::MpiComm> mMpiGroupTensorParaComm, mMpiGroupPipeParaComm, mMpiGroupDataComm,

@@ -24,7 +24,6 @@ set(SRCS
createNewDecoderRequests.cpp
contextProgress.cpp
dataTransceiver.cpp
dataTransceiverImpl.cpp
decoderBuffers.cpp
encoderBuffers.cpp
guidedDecoder.cpp

@@ -19,6 +19,7 @@
#include "mlaCacheFormatter.h"

#include "tensorrt_llm/batch_manager/contextProgress.h"
#include "tensorrt_llm/batch_manager/dataTransceiver.h"
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"

@@ -154,7 +155,7 @@ std::vector<size_t> CacheFormatter::pickRecvConnections(
return ret;
}

void CacheFormatter::format(TransferSession& session)
void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
{
NVTX3_SCOPED_RANGE(CacheFormatter_format);
auto const& llmRequest = session.getLlmRequest();

@@ -468,7 +469,7 @@ void CacheFormatter::format(TransferSession& session)
mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID:%ld ", llmRequest.mRequestId);
}

void CacheFormatter::unformat(TransferSession& session)
void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
{
NVTX3_SCOPED_RANGE(CacheFormatter_unformat);
auto const& llmRequest = session.getLlmRequest();

@@ -18,11 +18,12 @@
#pragma once

#include "cacheTransBuffer.h"
#include "dataTransceiver.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/cacheCommunicator.h"
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/runtime/bufferManager.h"

@@ -30,10 +31,25 @@
#include <NvInferRuntimeBase.h>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <vector>

// Forward declare TransferSession in the correct global namespace scope
namespace tensorrt_llm::batch_manager
{
class TransferSession;
}

namespace tensorrt_llm::batch_manager::kv_cache_manager
{

using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
using Connection = tensorrt_llm::executor::kv_cache::Connection;
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using BaseKVCacheManager = kv_cache_manager::BaseKVCacheManager;
using CacheTransBufferManager = kv_cache_manager::CacheTransBufferManager;
using BlockRange = kv_cache_manager::BlockRange;

BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest);

BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest);

@@ -42,16 +58,15 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques
class BaseCacheFormatter
{
public:
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using CacheState = executor::kv_cache::CacheState;

/// @brief Format the cache data into bytes for sending.
/// @param session The transfer session.
virtual void format(TransferSession& session) = 0;
virtual void format(tensorrt_llm::batch_manager::TransferSession& session) = 0;

/// @brief Unformat the cache data from received bytes.
/// @param session The transfer session.
virtual void unformat(TransferSession& session) = 0;
virtual void unformat(tensorrt_llm::batch_manager::TransferSession& session) = 0;

/// @brief Determine whether the sender is applicable to the source and target.
/// @param selfConfig Source data arrangement.

@@ -91,9 +106,9 @@ public:
TLLM_CHECK(mCacheTransBufferManager);
}

void format(TransferSession& session) override;
void format(tensorrt_llm::batch_manager::TransferSession& session) override;

void unformat(TransferSession& session) override;
void unformat(tensorrt_llm::batch_manager::TransferSession& session) override;

[[nodiscard]] bool inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const override;

@@ -37,8 +37,9 @@
#include "tensorrt_llm/batch_manager/cacheFormatter.h"
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/contextProgress.h"
#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/kvCacheType.h"
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/mlaCacheFormatter.h"
#include "tensorrt_llm/common/envUtils.h"

@@ -116,7 +117,6 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
: mMpiGroupComm(std::addressof(tensorrt_llm::mpi::MpiComm::session()))
, mCacheTransceiverConfig{cacheTransceiverConfig}
{
using tensorrt_llm::batch_manager::kv_cache_manager::CacheFormatter;
if (worldConfig.isPipelineParallel())
{
mMpiGroupPipeParaComm = std::make_shared<tensorrt_llm::mpi::MpiComm>(

@@ -200,14 +200,12 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
TLLM_THROW("Unsupported cache transceiver backend type ");
}

using tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter;
auto makeFormatter = [cacheManager, isMLA, this]()
{ return createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); };

mDataResponder = std::make_unique<DataResponder>(
std::make_unique<DataSenderImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
mDataRequester = std::make_unique<DataRequester>(
std::make_unique<DataReceiverImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
mCacheSender = std::make_unique<CacheSender>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter());
mCacheReceiver
= std::make_unique<CacheReceiver>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter());

initializeCommState();
}

@@ -223,7 +221,7 @@ CacheTransceiver::~CacheTransceiver()

void CacheTransceiver::initializeCommState()
{
mCommState = std::addressof(mDataResponder->getCommState());
mCommState = std::addressof(mCacheSender->getCommState());
}

void CacheTransceiver::setContextState(LlmRequest* llmRequest)

@@ -259,8 +257,8 @@ void CacheTransceiver::respondAndSendAsync(LlmRequest* llmRequest)
return;
}
setContextState(llmRequest);
auto future = mDataResponder->respondAndSendAsync(*llmRequest);
mResponderFutures.emplace_back(llmRequest, std::move(future));
auto future = mCacheSender->sendAsync(*llmRequest);
mSenderFutures.emplace_back(llmRequest, std::move(future));
}

void CacheTransceiver::respondAndSendLayerWise(

@@ -275,8 +273,8 @@ void CacheTransceiver::respondAndSendLayerWise(

llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS);
setContextState(llmRequest.get());
auto future = mDataResponder->respondAndSendAsync(*llmRequest);
mResponderFutures.emplace_back(llmRequest.get(), std::move(future));
auto future = mCacheSender->sendAsync(*llmRequest);
mSenderFutures.emplace_back(llmRequest.get(), std::move(future));
}
}

@@ -284,7 +282,7 @@ void CacheTransceiver::requestAndReceiveSync(LlmRequest* llmRequest)
{
TLLM_CHECK(llmRequest && llmRequest->isGenerationOnlyRequest());
{
auto future = mDataRequester->requestAndReceiveAsync(*llmRequest);
auto future = mCacheReceiver->receiveAsync(*llmRequest);
future.get();
}
llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);

@@ -302,7 +300,7 @@ void CacheTransceiver::requestAndReceiveAsync(LlmRequest* llmRequest)
return;
}

auto future = mDataRequester->requestAndReceiveAsync(*llmRequest);
auto future = mCacheReceiver->receiveAsync(*llmRequest);
mRequesterFutures.emplace_back(llmRequest, std::move(future));
llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS);
}

@@ -382,7 +380,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
bool blockAll = !atLeastRequestNum.has_value();
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mMpiGroupTPInDPComm : mMpiGroupTensorParaComm;
std::vector<LlmRequest::RequestIdType> contextCompleteRequestIds;
for (auto&& [request, future] : mResponderFutures)
for (auto&& [request, future] : mSenderFutures)
{
if (future.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready)
{

@@ -422,16 +420,15 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe

// Make sure there are at least atLeastRequestNum requests in toCompleteIdSet.
// This will preserve the order of insertion for KVCache transfer requests.
for (auto it = mResponderFutures.begin();
atLeastRequestNum.value_or(0) > static_cast<int>(toCompleteIdSet.size()) && it != mResponderFutures.end();
++it)
for (auto it = mSenderFutures.begin();
atLeastRequestNum.value_or(0) > static_cast<int>(toCompleteIdSet.size()) && it != mSenderFutures.end(); ++it)
{
auto& [request, future] = *it;
toCompleteIdSet.insert(request->mRequestId);
}

// Complete all the requests in toCompleteIdSet
for (auto it = mResponderFutures.begin(); it != mResponderFutures.end();)
for (auto it = mSenderFutures.begin(); it != mSenderFutures.end();)
{
auto& [request, future] = *it;
if (blockAll || (toCompleteIdSet.find(request->mRequestId) != toCompleteIdSet.end()))

@@ -447,7 +444,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
"Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what());
request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
}
it = mResponderFutures.erase(it);
it = mSenderFutures.erase(it);
}
else
{

@@ -17,6 +17,7 @@

#include "dataTransceiver.h"

#include "tensorrt_llm/batch_manager/cacheFormatter.h"
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
#include "tensorrt_llm/batch_manager/runtimeBuffers.h"

@@ -24,6 +25,7 @@
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/tllmException.h"
#include "tensorrt_llm/common/utils.h"
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <future>
#include <map>

@@ -33,8 +35,138 @@
namespace tensorrt_llm::batch_manager
{

using kv_cache_manager::BlockRange;
using BlockRange = tensorrt_llm::batch_manager::kv_cache_manager::BlockRange;

std::vector<Connection const*> const& TransferSession::getConnections() const
{
return mConnections;
}

void TransferSession::setConnection(size_t idx, Connection const* conn)
{
mConnections.at(idx) = conn;
}

DataContext const& TransferSession::getDataContext() const
{
return mDataContext;
}

executor::DataTransceiverState const& TransferSession::getSelfState() const
{
return *mSelfState;
}

executor::DataTransceiverState const& TransferSession::getOtherState() const
{
return mOtherState;
}

runtime::BufferManager const& TransferSession::getBufferManager() const
{
return *mBufferManager;
}

void TransferSession::send(size_t idx, void const* data, size_t size)
{
try
{
mConnections.at(idx)->send(mDataContext, data, size);
}
catch (std::exception const& e)
{
throw common::RequestSpecificException(
__FILE__, __LINE__, e.what(), mRequest->mRequestId, common::RequestErrorCode::kNETWORK_ERROR);
}
}

void TransferSession::recv(size_t idx, void* data, size_t size)
{
try
{
mConnections.at(idx)->recv(mDataContext, data, size);
}
catch (std::exception const& e)
{
throw common::RequestSpecificException(
__FILE__, __LINE__, e.what(), mRequest->mRequestId, common::RequestErrorCode::kNETWORK_ERROR);
}
}

LlmRequest const& TransferSession::getLlmRequest() const
{
TLLM_CHECK(mRequest != nullptr);
return *mRequest;
}

void TransferSession::setLlmRequest(LlmRequest const& llmRequest)
{
mRequest = &llmRequest;
}

void TransferSession::appendMeasure(double delay, double duration, size_t size)
{
if (!mRecordMeasure)
{
return;
}
auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps
mMeasures.emplace_back(Measure{delay, duration, bandwidth});
}

void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const
{
if (mMeasures.empty())
{
return;
}
// write header if not exist
if (outFile.tellp() == 0)
{
outFile << "RequestID";
for (size_t i = 0; i < mMeasures.size(); i++)
{
outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
}
outFile << '\n';
}
// write measures
TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value());
auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId();
outFile << reqId;
for (auto const& measure : mMeasures)
{
outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
}
outFile << '\n' << std::flush;
}
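For orientation, a quick unit check of the bandwidth expression used in appendMeasure above (bytes and milliseconds in, Gbps out); the numbers are made up for illustration and are not from the source:

#include <cassert>
#include <cstddef>
#include <cstdio>

int main()
{
    std::size_t size = 2'000'000'000; // 2 GB moved in one transfer (example value)
    double duration = 500.0;          // measured duration in ms (example value)
    // Same expression as TransferSession::appendMeasure: bytes -> bits, ms -> s, then to Gbps.
    double bandwidth = size * 8 / (duration / 1000) / 1e9;
    std::printf("%.1f Gbps\n", bandwidth); // prints 32.0
    assert(bandwidth == 32.0);
    return 0;
}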
std::vector<size_t> const& RequestInfo::getBlockHashes() const noexcept
{
return mBlockHashes;
}

using runtime::SizeType32;
using AgentConnectionManager = tensorrt_llm::executor::kv_cache::AgentConnectionManager;
using DataContext = tensorrt_llm::executor::kv_cache::DataContext;

static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId)
{
constexpr int32_t kDATA_TAG{43};
return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF);
}
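The tag packing above places the low 12 bits of the request id in bits 8..19 and the constant kDATA_TAG in the low byte, so request ids that differ only above bit 11 map to the same tag. A standalone sketch of that behaviour, with uint64_t standing in for LlmRequest::RequestIdType:

#include <cassert>
#include <cstdint>

static int32_t tagFromRequestId(uint64_t requestId)
{
    constexpr int32_t kDATA_TAG{43};
    return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF);
}

int main()
{
    assert(tagFromRequestId(0x1ABC) == ((0xABC << 8) | 43));       // 0xABC2B
    assert(tagFromRequestId(0x1ABC) == tagFromRequestId(0xF1ABC)); // only the low 12 bits matter
    return 0;
}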
struct ReceiveCacheResource
{
runtime::BufferManager mBufferManager;
runtime::CudaEvent mCudaEvent;

ReceiveCacheResource(runtime::BufferManager&& bufferManager, runtime::CudaEvent&& cudaEvent)
: mBufferManager(bufferManager)
, mCudaEvent(std::move(cudaEvent))
{
}
};

RequestInfo::RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTransceiverState transState)
: mRequestId{requestId}

@@ -92,82 +224,128 @@ std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo)
return totalSize;
}

void TransferSession::appendMeasure(double delay, double duration, size_t size)
{
if (!mRecordMeasure)
{
return;
}
auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps
mMeasures.emplace_back(Measure{delay, duration, bandwidth});
}

void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const
{
if (mMeasures.empty())
{
return;
}
// write header if not exist
if (outFile.tellp() == 0)
{
outFile << "RequestID";
for (size_t i = 0; i < mMeasures.size(); i++)
{
outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
}
outFile << '\n';
}
// write measures
TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value());
auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId();
outFile << reqId;
for (auto const& measure : mMeasures)
{
outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
}
outFile << '\n' << std::flush;
}

class DataResponder::Impl
class CacheSender::Impl
{
public:
using RequestIdType = LlmRequest::RequestIdType;

Impl(std::unique_ptr<DataSender> sender)
: mSender{std::move(sender)}
Impl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
: mManager{manager}
, mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
, mFormatter{std::move(formatter)}
, mBufferManager{std::make_shared<runtime::CudaStream>()}
{
TLLM_CHECK(mSender);
TLLM_CHECK(mManager);
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
mCurrentRequest = std::nullopt;
mResponseFuture = std::async(std::launch::async, &Impl::response, this);
}

[[nodiscard]] std::future<void> respondAndSendAsync(LlmRequest& llmRequest)
[[nodiscard]] std::future<void> sendAsync(LlmRequest& llmRequest)
{
std::promise<void> promise;
auto future = promise.get_future();
{
{
std::unique_lock lkResp(mResponderMutex);
std::unique_lock lkResp(mSenderMutex);
mReadyResponses.emplace(
llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)});
}
std::unique_lock lkCond(mCondMutex);
mAnyReady = true;
}
mResponderCv.notify_all();
mSenderCv.notify_all();
return future;
}

[[nodiscard]] executor::kv_cache::CommState const& getCommState() const
{
return mSender->getCommState();
return mSelfState.getCommState().value();
}

void setCommState(executor::kv_cache::CommState commState)
{
mSender->setCommState(std::move(commState));
mSelfState.setCommState(std::move(commState));
}

[[nodiscard]] size_t getCounterpartsCount(LlmRequest::RequestIdType requestId) const
{
auto it = mRequestToSession.find(requestId);
TLLM_CHECK(it != mRequestToSession.end());
return it->second.getConnections().size();
}

void release(LlmRequest::RequestIdType requestId)
{
auto it = mRequestToSession.find(requestId);
TLLM_CHECK(it != mRequestToSession.end());
std::unique_lock<std::mutex> lk(mMtxForMap);
mRequestToSession.erase(it);
}

[[nodiscard]] RequestInfo recvRequestInfo()
{
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
bool isAgent = agentConnectionManager != nullptr;

auto agentRecvFun = [&](RequestInfo& requestInfo)
{
auto const* connection = agentConnectionManager->recvConnectionAndRequestInfo(requestInfo);
return connection;
};
TransceiverTag::Id id;
RequestInfo info;
auto const* connection = isAgent ? agentRecvFun(info)
: mManager->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id));
if (!isAgent)
{
TLLM_CHECK(id == TransceiverTag::Id::REQUEST_SEND);
std::uint64_t infoSize{0};
connection->recv(
executor::kv_cache::DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize));
std::string serializedInfo;
serializedInfo.resize(infoSize);
connection->recv(
executor::kv_cache::DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize);
std::istringstream iss(serializedInfo);
info = RequestInfo::deserialize(iss);
}

auto requestId = info.getRequestId();
TLLM_CHECK_WITH_INFO(mFormatter->inquireSupport(
mSelfState.getCacheState().value(), info.getTransState().getCacheState().value()),
"Disagg server does not currently support these cacheState, please check the cacheState of the context and "
"gen "
"executors");
auto peerRelativeRanks = executor::kv_cache::targetIRanks(info.getTransState().getCacheState().value(),
mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx())
.mIRanks;
int peerIdx = std::distance(peerRelativeRanks.begin(),
std::find(
peerRelativeRanks.begin(), peerRelativeRanks.end(), info.getTransState().getCommState()->getSelfIdx()));
{
std::unique_lock<std::mutex> lk(mMtxForMap);
auto it = mRequestToSession.find(requestId);
if (it == mRequestToSession.end())
{
auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager);
it = mRequestToSession.emplace(requestId, std::move(session)).first;
}
it->second.setConnection(peerIdx, connection);
}
return info;
}

void sendSync(LlmRequest const& llmRequest)
{
auto it = mRequestToSession.find(llmRequest.mRequestId);
TLLM_CHECK(it != mRequestToSession.end());
auto& session = it->second;
session.setLlmRequest(llmRequest);
mFormatter->format(session);
}

~Impl()

@@ -187,8 +365,8 @@ private:
try
{
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
mSender->sendSync(*resp.mRequest);
mSender->release(id);
sendSync(*resp.mRequest);
release(id);
resp.mPromise.set_value();
}
catch (tensorrt_llm::common::RequestSpecificException const& e)

@@ -220,12 +398,11 @@ private:
if (common::getEnvParallelCacheSend())
{
// TODO: Use a thread pool and check for thread safety.
std::thread(&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
.detach();
std::thread(&CacheSender::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second)).detach();
}
else
{
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
CacheSender::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
}
removeResponse(it);
}

@@ -243,7 +420,7 @@ private:
if (!mAnyReady)
{
std::unique_lock lk(mCondMutex);
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
mSenderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
}
if (mTerminate)
{

@@ -252,14 +429,14 @@ private:
std::vector<size_t> blockHashes;
if (!isSending() && !mReadyResponses.empty())
{
auto const& requestInfo = mSender->recvRequestInfo();
auto const& requestInfo = recvRequestInfo();
auto reqId = requestInfo.getRequestId();
blockHashes = requestInfo.getBlockHashes();

mCurrentRequest = reqId;
if (mRemainSendCount.find(reqId) == mRemainSendCount.end())
{
mRemainSendCount[reqId] = mSender->getCounterpartsCount(reqId);
mRemainSendCount[reqId] = getCounterpartsCount(reqId);
}
}
auto it = getCurrentResponse();

@@ -273,7 +450,7 @@ private:
while (it == mReadyResponses.end())
{
std::unique_lock lk(mCondMutex);
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
mSenderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
if (mTerminate)
{
break;

@@ -286,7 +463,7 @@ private:
}
catch (std::exception const& err)
{
TLLM_LOG_ERROR("Exception in DataResponder response: %s", err.what());
TLLM_LOG_ERROR("Exception in CacheSender response: %s", err.what());
for (auto& it : mReadyResponses)
{
it.second.mPromise.set_exception(std::current_exception());

@@ -302,13 +479,13 @@ private:
}
// We don't have to wait for the future. If another thread is sending data, it won't pay attention
// to the terminate flag.
mResponderCv.notify_all();
mSenderCv.notify_all();
}

void removeResponse(std::map<RequestIdType, Response>::iterator it)
{
{
std::unique_lock lkResp(mResponderMutex);
std::unique_lock lkResp(mSenderMutex);
mReadyResponses.erase(it);
}
if (mReadyResponses.empty())

@@ -330,36 +507,47 @@ private:

[[nodiscard]] std::map<RequestIdType, Response>::iterator getCurrentResponse()
{
std::unique_lock lk(mResponderMutex);
std::unique_lock lk(mSenderMutex);
return mReadyResponses.find(getCurrentRequestId());
}

private:
std::optional<RequestIdType> mCurrentRequest;
std::map<RequestIdType, Response> mReadyResponses;
std::mutex mResponderMutex, mCondMutex;
std::mutex mSenderMutex, mCondMutex;
std::atomic<bool> mAnyReady{false}, mTerminate{false};
std::condition_variable mResponderCv;
std::condition_variable mSenderCv;
std::future<void> mResponseFuture;
std::unique_ptr<DataSender> mSender;
std::unordered_map<LlmRequest::RequestIdType, int> mRemainSendCount;
int mDeviceId{-1};

executor::kv_cache::ConnectionManager* mManager;
std::map<LlmRequest::RequestIdType, TransferSession> mRequestToSession;
executor::DataTransceiverState mSelfState;
std::unique_ptr<BaseCacheFormatter> mFormatter;
std::mutex mMtxForMap;
runtime::BufferManager mBufferManager;
};
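The sendAsync/response pair above is a queue-and-notify pattern: callers enqueue a request together with a std::promise and keep the matching std::future, while a background thread drains the queue, performs the blocking send, and fulfils the promise. A self-contained model of just that pattern (AsyncSender and its members are illustrative names, not TensorRT-LLM API):

#include <condition_variable>
#include <cstdint>
#include <future>
#include <map>
#include <mutex>
#include <thread>
#include <utility>

class AsyncSender
{
public:
    AsyncSender()
        : mWorker(&AsyncSender::run, this)
    {
    }

    std::future<void> sendAsync(uint64_t requestId)
    {
        std::promise<void> promise;
        auto future = promise.get_future();
        {
            std::unique_lock lk(mMutex);
            mReady.emplace(requestId, std::move(promise));
        }
        mCv.notify_all();
        return future;
    }

    ~AsyncSender()
    {
        {
            std::unique_lock lk(mMutex);
            mTerminate = true;
        }
        mCv.notify_all();
        mWorker.join();
    }

private:
    void run()
    {
        std::unique_lock lk(mMutex);
        while (true)
        {
            mCv.wait(lk, [this] { return !mReady.empty() || mTerminate; });
            if (mTerminate)
            {
                break; // pending promises are abandoned here; the real code sets exceptions instead
            }
            auto it = mReady.begin();
            auto promise = std::move(it->second);
            mReady.erase(it);
            lk.unlock();
            // ... format and transmit the KV cache for this request here ...
            promise.set_value();
            lk.lock();
        }
    }

    std::map<uint64_t, std::promise<void>> mReady;
    std::mutex mMutex;
    std::condition_variable mCv;
    bool mTerminate{false};
    std::thread mWorker; // declared last so the members above are initialized before the thread starts
};

A caller would hold the returned future and wait on it later (auto f = sender.sendAsync(42); f.get();), which is how CacheTransceiver uses the futures it stores in mSenderFutures.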
class DataRequester::Impl
class CacheReceiver::Impl
{
public:
Impl(std::unique_ptr<DataReceiver> receiver)
: mReceiver{std::move(receiver)}
Impl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
: mManager{manager}
, mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
, mFormatter{std::move(formatter)}
, mBufferManager{std::make_shared<runtime::CudaStream>()}
{
TLLM_CHECK(mReceiver);
TLLM_CHECK(mManager);
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
}

[[nodiscard]] std::future<void> requestAndReceiveAsync(LlmRequest& llmRequest)
[[nodiscard]] std::future<void> receiveAsync(LlmRequest& llmRequest)
{
// TODO: Modify the implementation here to avoid frequent thread creation.
return std::async(std::launch::async, &DataRequester::Impl::requestSync, this, std::ref(llmRequest));
return std::async(std::launch::async, &CacheReceiver::Impl::requestSync, this, std::ref(llmRequest));
}

[[nodiscard]] std::future<void> requestAndReceiveAsyncMultiThreads(LlmRequest& llmRequest)

@@ -378,7 +566,7 @@ public:
{

mInstanceToAsyncResource.emplace(processInfo, std::make_unique<AsyncResource>());
auto requestFuture = std::async(std::launch::async, &DataRequester::Impl::request, this,
auto requestFuture = std::async(std::launch::async, &CacheReceiver::Impl::request, this,
std::ref(*mInstanceToAsyncResource.at(processInfo)));
mRequestFutures.emplace_back(std::move(requestFuture));
}

@@ -396,6 +584,107 @@ public:
}
}

void receiveSync(TransferSession& session)
{
mFormatter->unformat(session);
}

TransferSession sendRequestInfo(LlmRequest const& llmRequest)
{
uint64_t requestId = llmRequest.getContextPhaseParams().value().getReqId();
auto const& contextState = llmRequest.getDataTransceiverState();
auto const& commState = contextState.getCommState().value();
auto const& destCacheState = contextState.getCacheState().value();
TLLM_CHECK_WITH_INFO(mFormatter->inquireSupport(mSelfState.getCacheState().value(), destCacheState),
"Disagg server does not currently support these cacheState.");

RequestInfo requestInfo(requestId, mSelfState);

auto disableSelectiveCacheTransfer = common::getEnvDisableSelectiveCacheTransfer()
|| (mFormatter->getCacheManager()->getBlockManager().getNumPools() > 1);
if (!disableSelectiveCacheTransfer)
{
auto* cacheManager = mFormatter->getCacheManager();
auto blockRange
= kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState);
}

auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
std::optional<size_t> cacheBufferId = std::nullopt;
if (agentConnectionManager != nullptr)
{
cacheBufferId = agentConnectionManager->getCacheTransBufferManager()->assignBufferIndexForRecv();
TLLM_CHECK(cacheBufferId.has_value());
// memory Desp , validSegmentIdx send
}
auto counterParts = mFormatter->getCounterparts(
mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx(), destCacheState);

auto connections = mManager->getConnections(commState);
std::vector<executor::kv_cache::Connection const*> counterPartConnections;
for (auto index : counterParts)
{
auto const* connection = connections.at(index);
counterPartConnections.emplace_back(connection);
}
auto pickUpIdx = mFormatter->pickRecvConnections(counterParts.size(), mSelfState.getCacheState().value(),
mSelfState.getCommState().value().getSelfIdx(), destCacheState);
for (size_t i = 0; i < counterPartConnections.size(); i++)
{
auto const* connection = counterPartConnections[i];
// if Manager is agentConnectionManager, then send request info to agent
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
if (agentConnectionManager != nullptr)
{
// TODO: index -> validConnectionIdx conversion
auto valideConnectionIdx = std::find(pickUpIdx.begin(), pickUpIdx.end(), i) - pickUpIdx.begin();
auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connection);
TLLM_CHECK(agentConnection != nullptr);
TLLM_CHECK(cacheBufferId.has_value());
const_cast<executor::kv_cache::AgentConnection*>(agentConnection)
->sendRequestAndBufferInfo(requestInfo, cacheBufferId, valideConnectionIdx);
}
else
{
sendRequestInfo(connection, requestInfo);
}
}
auto const& resource = getReceiveCacheResource(llmRequest);
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
contextState, resource->mBufferManager, &llmRequest);
}

std::unique_ptr<ReceiveCacheResource> const& getReceiveCacheResource(LlmRequest const& llmRequest)
{
std::scoped_lock<std::mutex> lock(mProcessIoResouceMutex);
TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value());
std::string processString = "default";
if (common::getEnvRequestKVCacheConcurrent())
{
processString = llmRequest.getDataTransceiverState().getCommState()->toString();
}
if (mProcessToResources.find(processString) == mProcessToResources.end())
{
mProcessToResources.emplace(processString,
std::make_unique<ReceiveCacheResource>(
runtime::BufferManager{std::make_shared<runtime::CudaStream>()}, runtime::CudaEvent{}));
}
return mProcessToResources.at(processString);
}

void sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info)
{
std::ostringstream oss;
RequestInfo::serialize(info, oss);
auto const& serializedInfo = oss.str();
std::size_t const infoSize = serializedInfo.size();
TransceiverTag::Id id{TransceiverTag::Id::REQUEST_SEND};
connection->send(executor::kv_cache::DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id));
connection->send(executor::kv_cache::DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize));
connection->send(executor::kv_cache::DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize);
}

~Impl()
{
for (auto&& [processInfo, asyncResource] : mInstanceToAsyncResource)

@@ -417,8 +706,8 @@ private:
llmRequest.getContextPhaseParams().value().getReqId());
llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now());
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
auto session = mReceiver->sendRequestInfo(llmRequest);
mReceiver->receiveSync(session);
auto session = sendRequestInfo(llmRequest);
receiveSync(session);
llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());

TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),

@@ -524,7 +813,7 @@ private:
}
catch (std::exception const& err)
{
TLLM_LOG_ERROR("Exception in DataRequester request(): request id:%ld , request context id:%ld : %s",
TLLM_LOG_ERROR("Exception in CacheReceiver request(): request id:%ld , request context id:%ld : %s",
requestAndPromise.mRequest->mRequestId,
requestAndPromise.mRequest->getContextPhaseParams().value().getReqId(), err.what());
requestAndPromise.mPromise->set_exception(std::current_exception());

@@ -533,45 +822,81 @@ private:
}
}

std::unique_ptr<DataReceiver> mReceiver;
int mDeviceId{-1};

std::vector<std::future<void>> mRequestFutures;
std::unordered_map<std::string, std::unique_ptr<AsyncResource>> mInstanceToAsyncResource;
executor::kv_cache::ConnectionManager* mManager;
executor::DataTransceiverState mSelfState;
std::unique_ptr<BaseCacheFormatter> mFormatter;
std::unordered_map<std::string, std::unique_ptr<ReceiveCacheResource>> mProcessToResources;
std::mutex mProcessIoResouceMutex;
runtime::BufferManager mBufferManager;
};

DataResponder::DataResponder(std::unique_ptr<DataSender> sender)
: mImpl{std::make_unique<Impl>(std::move(sender))}
void CacheSender::ImplDeleter::operator()(Impl* ptr)
{
delete ptr;
}

void CacheReceiver::ImplDeleter::operator()(Impl* ptr)
{
delete ptr;
}
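CacheSender and CacheReceiver hold their Impl through std::unique_ptr<Impl, ImplDeleter>, with the deleter's operator() defined out of line as above; that keeps Impl an incomplete type in the header while still giving unique_ptr a way to destroy it. A minimal standalone sketch of the same pattern (Widget and its Impl are illustrative names, not from the codebase):

#include <memory>

// --- header ---------------------------------------------------------------
class Widget
{
public:
    Widget();
    ~Widget();

private:
    class Impl; // incomplete here

    struct ImplDeleter
    {
        void operator()(Impl*); // defined where Impl is complete
    };

    std::unique_ptr<Impl, ImplDeleter> mImpl;
};

// --- source file ----------------------------------------------------------
class Widget::Impl
{
public:
    int mValue{42};
};

void Widget::ImplDeleter::operator()(Impl* ptr)
{
    delete ptr; // Impl is a complete type at this point
}

Widget::Widget()
    : mImpl{std::unique_ptr<Impl, ImplDeleter>(new Impl())}
{
}

Widget::~Widget() = default;

int main()
{
    Widget w; // construct and destroy through the pimpl
    return 0;
}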
CacheSender::CacheSender(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
: mImpl{std::unique_ptr<Impl, ImplDeleter>(new Impl(manager, selfCacheState, selfIndex, std::move(formatter)))}
{
}

std::future<void> DataResponder::respondAndSendAsync(LlmRequest& llmRequest) const
std::future<void> CacheSender::sendAsync(LlmRequest& llmRequest) const
{
return mImpl->respondAndSendAsync(llmRequest);
return mImpl->sendAsync(llmRequest);
}

executor::kv_cache::CommState const& DataResponder::getCommState() const
executor::kv_cache::CommState const& CacheSender::getCommState() const
{
return mImpl->getCommState();
}

void DataResponder::setCommState(executor::kv_cache::CommState commState)
void CacheSender::setCommState(executor::kv_cache::CommState commState)
{
mImpl->setCommState(std::move(commState));
}

DataResponder::~DataResponder() = default;
CacheSender::~CacheSender() = default;

DataRequester::DataRequester(std::unique_ptr<DataReceiver> receiver)
: mImpl{std::make_unique<Impl>(std::move(receiver))}
void CacheSender::sendSync(LlmRequest const& llmRequest)
{
mImpl->sendSync(llmRequest);
}

RequestInfo CacheSender::recvRequestInfo()
{
return mImpl->recvRequestInfo();
}

CacheReceiver::CacheReceiver(executor::kv_cache::ConnectionManager* manager,
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
: mImpl{std::unique_ptr<Impl, ImplDeleter>(new Impl(manager, selfCacheState, selfIndex, std::move(formatter)))}
{
}

std::future<void> DataRequester::requestAndReceiveAsync(LlmRequest& llmRequest) const
std::future<void> CacheReceiver::receiveAsync(LlmRequest& llmRequest) const
{
return mImpl->requestAndReceiveAsyncMultiThreads(llmRequest);
}

DataRequester::~DataRequester() = default;
CacheReceiver::~CacheReceiver() = default;

TransferSession CacheReceiver::sendRequestInfo(LlmRequest const& llmRequest)
{
return mImpl->sendRequestInfo(llmRequest);
}

void CacheReceiver::receiveSync(TransferSession& session)
{
mImpl->receiveSync(session);
}

} // namespace tensorrt_llm::batch_manager
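The request-info handshake in this file (non-agent path) is a three-step framing: the generation side sends an Id token on kID_TAG, the serialized payload size on kINFO_SIZE_TAG, and the payload itself on kINFO_TAG (CacheReceiver::Impl::sendRequestInfo), and the context side reads the same three frames back in CacheSender::Impl::recvRequestInfo. A self-contained model of that ordering, with FakeWire standing in for executor::kv_cache::Connection:

#include <cassert>
#include <cstdint>
#include <cstring>
#include <map>
#include <string>

struct FakeWire
{
    std::map<int32_t, std::string> channels; // tag -> buffered bytes

    void send(int32_t tag, void const* data, std::size_t size)
    {
        channels[tag].append(static_cast<char const*>(data), size);
    }

    void recv(int32_t tag, void* data, std::size_t size)
    {
        auto& buf = channels.at(tag);
        std::memcpy(data, buf.data(), size);
        buf.erase(0, size);
    }
};

int main()
{
    constexpr int32_t kID_TAG{19}, kINFO_SIZE_TAG{22}, kINFO_TAG{32};
    FakeWire wire;

    // Generation side (cf. sendRequestInfo): id, then size, then payload.
    std::string serializedInfo = "serialized RequestInfo bytes";
    uint64_t id = 1; // REQUEST_SEND
    uint64_t infoSize = serializedInfo.size();
    wire.send(kID_TAG, &id, sizeof(id));
    wire.send(kINFO_SIZE_TAG, &infoSize, sizeof(infoSize));
    wire.send(kINFO_TAG, serializedInfo.data(), infoSize);

    // Context side (cf. recvRequestInfo): mirror the same three frames.
    uint64_t recvId{0}, recvSize{0};
    wire.recv(kID_TAG, &recvId, sizeof(recvId));
    wire.recv(kINFO_SIZE_TAG, &recvSize, sizeof(recvSize));
    std::string payload(recvSize, '\0');
    wire.recv(kINFO_TAG, payload.data(), recvSize);
    assert(recvId == 1 && payload == serializedInfo);
    return 0;
}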
@ -20,6 +20,7 @@
|
||||
#include <future>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
@ -28,16 +29,100 @@
|
||||
#include "tensorrt_llm/executor/cacheCommunicator.h"
|
||||
#include "tensorrt_llm/executor/dataTransceiverState.h"
|
||||
#include "tensorrt_llm/executor/serializeUtils.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/cudaEvent.h"
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
|
||||
namespace kv_cache_manager
|
||||
{
|
||||
class BaseCacheFormatter;
|
||||
}
|
||||
|
||||
using BaseCacheFormatter = kv_cache_manager::BaseCacheFormatter;
|
||||
|
||||
// TODO: unify the following class into a namespace like tensorrt_llm::transmission
|
||||
using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
|
||||
using Connection = tensorrt_llm::executor::kv_cache::Connection;
|
||||
using ConnectionManager = tensorrt_llm::executor::kv_cache::ConnectionManager;
|
||||
using SizeType32 = tensorrt_llm::runtime::SizeType32;
|
||||
|
||||
class TransferSession
|
||||
{
|
||||
public:
|
||||
struct Measure
|
||||
{
|
||||
double delay; // from last token (ctx) or arrival time (gen), in ms
|
||||
double duration; // in ms
|
||||
double bandwidth; // in Gbps
|
||||
};
|
||||
|
||||
TransferSession(std::vector<Connection const*> connections, DataContext dataContext,
|
||||
executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState,
|
||||
runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false)
|
||||
: mConnections(std::move(connections))
|
||||
, mDataContext(dataContext)
|
||||
, mSelfState(&selfState)
|
||||
, mOtherState(std::move(otherState))
|
||||
, mBufferManager(&bufferManager)
|
||||
, mRequest(llmRequest)
|
||||
, mRecordMeasure(recordMeasure)
|
||||
{
|
||||
TLLM_CHECK(!mConnections.empty());
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<Connection const*> const& getConnections() const;
|
||||
|
||||
// should be called only during the initialization of the TransferSession
|
||||
void setConnection(size_t idx, Connection const* conn);
|
||||
|
||||
[[nodiscard]] DataContext const& getDataContext() const;
|
||||
|
||||
[[nodiscard]] executor::DataTransceiverState const& getSelfState() const;
|
||||
|
||||
[[nodiscard]] executor::DataTransceiverState const& getOtherState() const;
|
||||
|
||||
[[nodiscard]] runtime::BufferManager const& getBufferManager() const;
|
||||
|
||||
void send(size_t idx, void const* data, size_t size);
|
||||
|
||||
void recv(size_t idx, void* data, size_t size);
|
||||
|
||||
[[nodiscard]] LlmRequest const& getLlmRequest() const;
|
||||
|
||||
// in CacheSender, the LlmRequest is not available until the sendSync is called
|
||||
void setLlmRequest(LlmRequest const& llmRequest);
|
||||
|
||||
void appendMeasure(double delay, double duration, size_t size);
|
||||
|
||||
// TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file
|
||||
void exportMeasure(std::ofstream& outFile, bool isContext) const;
|
||||
|
||||
private:
|
||||
std::vector<Connection const*> mConnections;
|
||||
DataContext mDataContext;
|
||||
executor::DataTransceiverState const* mSelfState; // stored in CacheReceiver/CacheSender
|
||||
executor::DataTransceiverState mOtherState;
|
||||
runtime::BufferManager const* mBufferManager;
|
||||
LlmRequest const* mRequest;
|
||||
std::vector<Measure> mMeasures;
|
||||
bool mRecordMeasure{false};
|
||||
};
|
||||
|
||||
struct TransceiverTag
|
||||
{
|
||||
enum class Id : uint64_t
|
||||
{
|
||||
REQUEST_SEND = 1,
|
||||
TERMINATION = 2
|
||||
};
|
||||
|
||||
static constexpr int32_t kID_TAG{19};
|
||||
static constexpr int32_t kINFO_SIZE_TAG{22};
|
||||
static constexpr int32_t kINFO_TAG{32};
|
||||
};
|
||||
|
||||
// Used to store the information that needs to be sent to the context executor to ensure the generation
|
||||
// executor smoothly receives the data.
|
||||
@ -61,10 +146,7 @@ public:
|
||||
/// @return The request ID.
|
||||
[[nodiscard]] LlmRequest::RequestIdType getRequestId() const noexcept;
|
||||
|
||||
[[nodiscard]] std::vector<size_t> const& getBlockHashes() const noexcept
|
||||
{
|
||||
return mBlockHashes;
|
||||
}
|
||||
[[nodiscard]] std::vector<size_t> const& getBlockHashes() const noexcept;
|
||||
|
||||
/// @brief Return the state of the data transceiver.
|
||||
/// @return The state of the data transceiver.
|
||||
@ -94,207 +176,81 @@ private:
|
||||
executor::DataTransceiverState mTransState;
|
||||
};
|
||||
|
||||
class TransferSession
|
||||
{
|
||||
public:
|
||||
struct Measure
|
||||
{
|
||||
double delay; // from last token (ctx) or arrival time (gen), in ms
|
||||
double duration; // in ms
|
||||
double bandwidth; // in Gbps
|
||||
};
|
||||
|
||||
TransferSession(std::vector<Connection const*> connections, DataContext dataContext,
|
||||
executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState,
|
||||
runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false)
|
||||
: mConnections(std::move(connections))
|
||||
, mDataContext(dataContext)
|
||||
, mSelfState(&selfState)
|
||||
, mOtherState(std::move(otherState))
|
||||
, mBufferManager(&bufferManager)
|
||||
, mRequest(llmRequest)
|
||||
, mRecordMeasure(recordMeasure)
|
||||
{
|
||||
TLLM_CHECK(!mConnections.empty());
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<Connection const*> const& getConnections() const
|
||||
{
|
||||
return mConnections;
|
||||
}
|
||||
|
||||
// should be called only during the initialization of the TransferSession
|
||||
void setConnection(size_t idx, Connection const* conn)
|
||||
{
|
||||
mConnections.at(idx) = conn;
|
||||
}
|
||||
|
||||
[[nodiscard]] DataContext const& getDataContext() const
|
||||
{
|
||||
return mDataContext;
|
||||
}
|
||||
|
||||
[[nodiscard]] executor::DataTransceiverState const& getSelfState() const
|
||||
{
|
||||
return *mSelfState;
|
||||
}
|
||||
|
||||
[[nodiscard]] executor::DataTransceiverState const& getOtherState() const
|
||||
{
|
||||
return mOtherState;
|
||||
}
|
||||
|
||||
[[nodiscard]] runtime::BufferManager const& getBufferManager() const
|
||||
{
|
||||
return *mBufferManager;
|
||||
}
|
||||
|
||||
void send(size_t idx, void const* data, size_t size)
|
||||
{
|
||||
try
|
||||
{
|
||||
mConnections.at(idx)->send(mDataContext, data, size);
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
throw common::RequestSpecificException(
|
||||
__FILE__, __LINE__, e.what(), mRequest->mRequestId, common::RequestErrorCode::kNETWORK_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
void recv(size_t idx, void* data, size_t size)
|
||||
{
|
||||
try
|
||||
{
|
||||
mConnections.at(idx)->recv(mDataContext, data, size);
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
throw common::RequestSpecificException(
|
||||
__FILE__, __LINE__, e.what(), mRequest->mRequestId, common::RequestErrorCode::kNETWORK_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]] LlmRequest const& getLlmRequest() const
|
||||
{
|
||||
TLLM_CHECK(mRequest != nullptr);
|
||||
return *mRequest;
|
||||
}
|
||||
|
||||
// in DataSender, the LlmRequest is not available until the sendSync is called
|
||||
void setLlmRequest(LlmRequest const& llmRequest)
|
||||
{
|
||||
mRequest = &llmRequest;
|
||||
}
|
||||
|
||||
void appendMeasure(double delay, double duration, size_t size);
|
||||
|
||||
// TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file
|
||||
void exportMeasure(std::ofstream& outFile, bool isContext) const;
|
||||
|
||||
private:
|
||||
std::vector<Connection const*> mConnections;
|
||||
DataContext mDataContext;
|
||||
executor::DataTransceiverState const* mSelfState; // stored in DataRequester/DataResponder
|
||||
executor::DataTransceiverState mOtherState;
|
||||
runtime::BufferManager const* mBufferManager;
|
||||
LlmRequest const* mRequest;
|
||||
bool mRecordMeasure;
|
||||
std::vector<Measure> mMeasures;
|
||||
};
|
||||
|
||||
// Operators required for data transmission in specific communication protocols.
|
||||
class DataSender
|
||||
{
|
||||
public:
|
||||
/// @brief Receive the request information.
|
||||
/// @return The request information.
|
||||
[[nodiscard]] virtual RequestInfo recvRequestInfo() = 0;
|
||||
|
||||
/// @brief Synchronously send data.
|
||||
/// @param llmRequest The request object to which the data belongs.
|
||||
virtual void sendSync(LlmRequest const& llmRequest) = 0;
|
||||
|
||||
/// @brief Return the internal communicator status.
|
||||
/// @return The communicator status.
|
||||
[[nodiscard]] virtual executor::kv_cache::CommState const& getCommState() const = 0;
|
||||
|
||||
/// @brief Reset the internal communicator status.
|
||||
/// @param commState The communicator status.
|
||||
virtual void setCommState(executor::kv_cache::CommState commState) = 0;
|
||||
|
||||
[[nodiscard]] virtual size_t getCounterpartsCount(LlmRequest::RequestIdType requestId) const = 0;
|
||||
|
||||
virtual void release(LlmRequest::RequestIdType requestId) = 0;
|
||||
|
||||
/// @brief Destructor.
|
||||
virtual ~DataSender() = default;
|
||||
};
|
||||
|
||||
// Operators required for data transmission in specific communication protocols.
|
||||
class DataReceiver
|
||||
{
|
||||
public:
|
||||
/// @brief Send the request information.
|
||||
/// @param llmRequest The request object to which the information belongs.
|
||||
virtual TransferSession sendRequestInfo(LlmRequest const& llmRequest) = 0;
|
||||
|
||||
/// @brief Synchronously receive data.
|
||||
/// @param session The transfer session.
|
||||
virtual void receiveSync(TransferSession& session) = 0;
|
||||
|
||||
/// @brief Destructor.
|
||||
virtual ~DataReceiver() = default;
|
||||
};
|
||||
|
||||
class DataResponder
|
||||
class CacheSender
|
||||
{
|
||||
public:
|
||||
/// @brief Constructor.
|
||||
/// @param sender The sender used at the underlying level.
|
||||
explicit DataResponder(std::unique_ptr<DataSender> sender);
|
||||
CacheSender(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
|
||||
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);
|
||||
|
||||
CacheSender() = default;
|
||||
|
||||
/// @brief Asynchronously respond to the request and send data.
|
||||
/// @param llmRequest Request object. Its data should be ready when called, and the data for this request
|
||||
/// should remain valid until future synchronization.
|
||||
/// @return Once the data is fully sent, the future object will become valid.
|
||||
[[nodiscard]] std::future<void> respondAndSendAsync(LlmRequest& llmRequest) const;
|
||||
[[nodiscard]] virtual std::future<void> sendAsync(LlmRequest& llmRequest) const;
|
||||
|
||||
/// @brief Return the internal communicator status.
|
||||
/// @return The communicator status.
|
||||
[[nodiscard]] executor::kv_cache::CommState const& getCommState() const;
|
||||
[[nodiscard]] virtual executor::kv_cache::CommState const& getCommState() const;
|
||||
|
||||
/// @brief Reset the internal communicator status.
|
||||
/// @param commState The communicator status.
|
||||
void setCommState(executor::kv_cache::CommState commState);
|
||||
virtual void setCommState(executor::kv_cache::CommState commState);
|
||||
|
||||
/// @brief Synchronously send data.
|
||||
/// @param llmRequest The request object to which the data belongs.
|
||||
virtual void sendSync(LlmRequest const& llmRequest);
|
||||
|
||||
/// @brief Receive request information.
|
||||
/// @param llmRequest The request object to which the data belongs.
|
||||
virtual RequestInfo recvRequestInfo();
|
||||
|
||||
/// @brief Destructor.
|
||||
~DataResponder();
|
||||
virtual ~CacheSender();
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> mImpl;
|
||||
|
||||
struct ImplDeleter
|
||||
{
|
||||
void operator()(Impl*);
|
||||
};
|
||||
|
||||
std::unique_ptr<Impl, ImplDeleter> mImpl;
|
||||
};
|
||||
|
||||
class DataRequester
|
||||
class CacheReceiver
|
||||
{
|
||||
public:
|
||||
/// @brief Constructor.
|
||||
/// @param receiver The receiver used at the underlying level.
|
||||
explicit DataRequester(std::unique_ptr<DataReceiver> receiver);
|
||||
CacheReceiver(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
|
||||
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);
|
||||
|
||||
CacheReceiver() = default;
|
||||
|
||||
/// @brief Asynchronously send a request to receive data.
|
||||
/// @param llmRequest Request object. Its data should be in an allocated but unwritten state when called, and the
|
||||
/// data for this request should remain intact only after future synchronization.
|
||||
/// @return Once the data is fully received, the future object will become valid.
|
||||
[[nodiscard]] std::future<void> requestAndReceiveAsync(LlmRequest& llmRequest) const;
|
||||
[[nodiscard]] virtual std::future<void> receiveAsync(LlmRequest& llmRequest) const;
|
||||
|
||||
virtual TransferSession sendRequestInfo(LlmRequest const& llmRequest);
|
||||
|
||||
virtual void receiveSync(TransferSession& session);
|
||||
/// @brief Destructor.
|
||||
~DataRequester();
|
||||
virtual ~CacheReceiver();
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> mImpl;
|
||||
|
||||
struct ImplDeleter
|
||||
{
|
||||
void operator()(Impl*);
|
||||
};
|
||||
|
||||
std::unique_ptr<Impl, ImplDeleter> mImpl;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -1,285 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "dataTransceiverImpl.h"
|
||||
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
|
||||
#include <filesystem>
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
|
||||
static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId)
|
||||
{
|
||||
constexpr int32_t kDATA_TAG{43};
|
||||
return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF);
|
||||
}
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
static fs::path getTransferOutputPath(char const* tag)
|
||||
{
|
||||
auto outputPath = common::getEnvKVCacheTransferOutputPath();
|
||||
if (!outputPath.empty())
|
||||
{
|
||||
auto rank = mpi::MpiComm::world().getRank();
|
||||
auto path = fs::path(outputPath);
|
||||
fs::create_directories(path);
|
||||
return path / ("rank_" + std::to_string(rank) + "_" + tag + ".csv");
|
||||
}
|
||||
return {};
|
||||
}
|
||||

DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
: mManager{manager}
, mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
, mFormatter(std::move(formatter))
, mBufferManager{std::make_shared<runtime::CudaStream>()}
{
TLLM_CHECK(mManager);
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
}

[[nodiscard]] RequestInfo DataSenderImpl::recvRequestInfo()
{
using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
bool isAgent = agentConnectionManager != nullptr;

auto agentRecvFun = [&](RequestInfo& requestInfo)
{
auto const* connection = agentConnectionManager->recvConnectionAndRequestInfo(requestInfo);
return connection;
};
Id id;
RequestInfo info;
auto const* connection
= isAgent ? agentRecvFun(info) : mManager->recvConnect(DataContext{kID_TAG}, &id, sizeof(id));
if (!isAgent)
{
TLLM_CHECK(id == Id::REQUEST_SEND);
std::uint64_t infoSize{0};
connection->recv(executor::kv_cache::DataContext{kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize));
std::string serializedInfo;
serializedInfo.resize(infoSize);
connection->recv(executor::kv_cache::DataContext{kINFO_TAG}, serializedInfo.data(), infoSize);
std::istringstream iss(serializedInfo);
info = RequestInfo::deserialize(iss);
}

auto requestId = info.getRequestId();
TLLM_CHECK_WITH_INFO(
mFormatter->inquireSupport(mSelfState.getCacheState().value(), info.getTransState().getCacheState().value()),
"Disagg server does not currently support these cacheState, please check the cacheState of the context and gen "
"executors");
auto peerRelativeRanks = executor::kv_cache::targetIRanks(info.getTransState().getCacheState().value(),
mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx())
.mIRanks;
int peerIdx = std::distance(peerRelativeRanks.begin(),
std::find(
peerRelativeRanks.begin(), peerRelativeRanks.end(), info.getTransState().getCommState()->getSelfIdx()));
{
std::unique_lock<std::mutex> lk(mMtxForMap);
auto it = mRequestToSession.find(requestId);
if (it == mRequestToSession.end())
{
auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, nullptr,
!common::getEnvKVCacheTransferOutputPath().empty());
it = mRequestToSession.emplace(requestId, std::move(session)).first;
}
it->second.setConnection(peerIdx, connection);
}
return info;
}

void DataSenderImpl::sendSync(LlmRequest const& llmRequest)
{
auto it = mRequestToSession.find(llmRequest.mRequestId);
TLLM_CHECK(it != mRequestToSession.end());
auto& session = it->second;
session.setLlmRequest(llmRequest);
mFormatter->format(session);
}

[[nodiscard]] executor::kv_cache::CommState const& DataSenderImpl::getCommState() const
{
return mSelfState.getCommState().value();
}

void DataSenderImpl::setCommState(executor::kv_cache::CommState commState)
{
mSelfState.setCommState(std::move(commState));
}

[[nodiscard]] size_t DataSenderImpl::getCounterpartsCount(LlmRequest::RequestIdType requestId) const
{
auto it = mRequestToSession.find(requestId);
TLLM_CHECK(it != mRequestToSession.end());
return it->second.getConnections().size();
}

void DataSenderImpl::release(LlmRequest::RequestIdType requestId)
{
auto it = mRequestToSession.find(requestId);
TLLM_CHECK(it != mRequestToSession.end());
std::unique_lock<std::mutex> lk(mMtxForMap);
if (!common::getEnvKVCacheTransferOutputPath().empty())
{
if (!mMeasuresFile.is_open())
{
auto outputPath = getTransferOutputPath("send");
mMeasuresFile.open(outputPath);
TLLM_CHECK_WITH_INFO(
mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str());
}
it->second.exportMeasure(mMeasuresFile, true);
}
mRequestToSession.erase(it);
}

DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manager,
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
: mManager{manager}
, mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
, mFormatter(std::move(formatter))
{
TLLM_CHECK(mManager);
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
TLLM_CHECK(mFormatter);
}

TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
{
uint64_t requestId = llmRequest.getContextPhaseParams().value().getReqId();
auto const& contextState = llmRequest.getDataTransceiverState();
auto const& commState = contextState.getCommState().value();
auto const& destCacheState = contextState.getCacheState().value();
TLLM_CHECK_WITH_INFO(mFormatter->inquireSupport(mSelfState.getCacheState().value(), destCacheState),
"Disagg server does not currently support these cacheState.");

RequestInfo requestInfo(requestId, mSelfState);

auto disableSelectiveCacheTransfer = common::getEnvDisableSelectiveCacheTransfer()
|| (mFormatter->getCacheManager()->getBlockManager().getNumPools() > 1);
if (!disableSelectiveCacheTransfer)
{
auto* cacheManager = mFormatter->getCacheManager();
auto blockRange
= kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState);
}

auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
std::optional<size_t> cacheBufferId = std::nullopt;
if (agentConnectionManager != nullptr)
{
cacheBufferId = agentConnectionManager->getCacheTransBufferManager()->assignBufferIndexForRecv();
TLLM_CHECK(cacheBufferId.has_value());
// The buffer's memory descriptor and validSegmentIdx are sent via sendRequestAndBufferInfo below.
}
auto counterParts = mFormatter->getCounterparts(
mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx(), destCacheState);

auto connections = mManager->getConnections(commState);
std::vector<executor::kv_cache::Connection const*> counterPartConnections;
for (auto index : counterParts)
{
auto const* connection = connections.at(index);
counterPartConnections.emplace_back(connection);
}
auto pickUpIdx = mFormatter->pickRecvConnections(counterParts.size(), mSelfState.getCacheState().value(),
mSelfState.getCommState().value().getSelfIdx(), destCacheState);
for (size_t i = 0; i < counterPartConnections.size(); i++)
{
auto const* connection = counterPartConnections[i];
// if Manager is agentConnectionManager, then send request info to agent
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
if (agentConnectionManager != nullptr)
{
// TODO: index -> validConnectionIdx conversion
auto valideConnectionIdx = std::find(pickUpIdx.begin(), pickUpIdx.end(), i) - pickUpIdx.begin();
auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connection);
TLLM_CHECK(agentConnection != nullptr);
TLLM_CHECK(cacheBufferId.has_value());
const_cast<executor::kv_cache::AgentConnection*>(agentConnection)
->sendRequestAndBufferInfo(requestInfo, cacheBufferId, valideConnectionIdx);
}
else
{
sendRequestInfo(connection, requestInfo);
}
}
auto const& resource = getReceiveCacheResource(llmRequest);
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
contextState, resource->mBufferManager, &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty());
}

void DataReceiverImpl::receiveSync(TransferSession& session)
{
mFormatter->unformat(session);
if (!common::getEnvKVCacheTransferOutputPath().empty())
{
std::unique_lock<std::mutex> lock(mMeasuresFileMutex);
if (!mMeasuresFile.is_open())
{
auto outputPath = getTransferOutputPath("recv");
mMeasuresFile.open(outputPath);
TLLM_CHECK_WITH_INFO(
mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str());
}
session.exportMeasure(mMeasuresFile, false);
}
}

void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info)
{
std::ostringstream oss;
RequestInfo::serialize(info, oss);
auto const& serializedInfo = oss.str();
std::size_t const infoSize = serializedInfo.size();
Id id{Id::REQUEST_SEND};
connection->send(executor::kv_cache::DataContext{kID_TAG}, &id, sizeof(id));
connection->send(executor::kv_cache::DataContext{kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize));
connection->send(executor::kv_cache::DataContext{kINFO_TAG}, serializedInfo.data(), infoSize);
}
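For orientation (a summary inferred from this function and DataSenderImpl::recvRequestInfo above, not separately documented): on the non-agent path the request-info handshake is three messages on one connection, consumed by the sender side in the same order.

// 1. kID_TAG        -> Id::REQUEST_SEND marker
// 2. kINFO_SIZE_TAG -> payload size as an unsigned 64-bit integer
// 3. kINFO_TAG      -> RequestInfo serialized into that many bytes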

std::unique_ptr<DataReceiverImpl::ReceiveCacheResource> const& DataReceiverImpl::getReceiveCacheResource(
LlmRequest const& llmRequest)
{
std::scoped_lock<std::mutex> lock(mProcessIoResouceMutex);
TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value());
std::string processString = "default";
if (common::getEnvRequestKVCacheConcurrent())
{
processString = llmRequest.getDataTransceiverState().getCommState()->toString();
}
if (mProcessToResources.find(processString) == mProcessToResources.end())
{
mProcessToResources.emplace(processString,
std::make_unique<ReceiveCacheResource>(
runtime::BufferManager{std::make_shared<runtime::CudaStream>()}, runtime::CudaEvent{}));
}

return mProcessToResources.at(processString);
}

} // namespace tensorrt_llm::batch_manager
@ -1,113 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "cacheFormatter.h"
#include "cacheTransBuffer.h"
#include "dataTransceiver.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"

#include <fstream>

namespace tensorrt_llm::batch_manager
{
struct TransceiverTag
{
enum class Id : uint64_t
{
REQUEST_SEND = 1,
TERMINATION = 2
};

static constexpr int32_t kID_TAG{19};
static constexpr int32_t kINFO_SIZE_TAG{22};
static constexpr int32_t kINFO_TAG{32};
};

using BaseCacheFormatter = kv_cache_manager::BaseCacheFormatter;

class DataSenderImpl : public DataSender, public TransceiverTag
{
public:
using SizeType32 = tensorrt_llm::runtime::SizeType32;

DataSenderImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);

[[nodiscard]] RequestInfo recvRequestInfo() override;

void sendSync(LlmRequest const& llmRequest) override;

[[nodiscard]] executor::kv_cache::CommState const& getCommState() const override;

void setCommState(executor::kv_cache::CommState commState) override;

[[nodiscard]] size_t getCounterpartsCount(LlmRequest::RequestIdType requestId) const override;

void release(LlmRequest::RequestIdType requestId) override;

private:
executor::kv_cache::ConnectionManager* mManager;
std::map<LlmRequest::RequestIdType, TransferSession> mRequestToSession;
executor::DataTransceiverState mSelfState;
std::unique_ptr<BaseCacheFormatter> mFormatter;
std::mutex mMtxForMap;
runtime::BufferManager mBufferManager;
std::ofstream mMeasuresFile;
};

class DataReceiverImpl : public DataReceiver, public TransceiverTag
{
public:
using SizeType32 = tensorrt_llm::runtime::SizeType32;

DataReceiverImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);

TransferSession sendRequestInfo(LlmRequest const& llmRequest) override;

void receiveSync(TransferSession& session) override;

private:
struct ReceiveCacheResource
{
runtime::BufferManager mBufferManager;
runtime::CudaEvent mCudaEvent;

ReceiveCacheResource(runtime::BufferManager&& bufferManager, runtime::CudaEvent&& cudaEvent)
: mBufferManager(bufferManager)
, mCudaEvent(std::move(cudaEvent))
{
}
};

static void sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info);

[[nodiscard]] std::unique_ptr<ReceiveCacheResource> const& getReceiveCacheResource(LlmRequest const& llmRequest);

executor::kv_cache::ConnectionManager* mManager;
executor::DataTransceiverState mSelfState;
std::unique_ptr<BaseCacheFormatter> mFormatter;
std::unordered_map<std::string, std::unique_ptr<ReceiveCacheResource>> mProcessToResources;
std::mutex mProcessIoResouceMutex;
std::ofstream mMeasuresFile;
std::mutex mMeasuresFileMutex;
};

} // namespace tensorrt_llm::batch_manager
@ -84,7 +84,7 @@ bool MLACacheFormatter::needSendCache(
return selfTpRank % (selfTPNum / destTPNum) == 0;
}

void MLACacheFormatter::format(TransferSession& session)
void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
{
NVTX3_SCOPED_RANGE(MLACacheFormatter_format);
auto const& llmRequest = session.getLlmRequest();
@ -292,7 +292,7 @@ void MLACacheFormatter::format(TransferSession& session)
mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID: %ld.", llmRequest.mRequestId);
}

void MLACacheFormatter::unformat(TransferSession& session)
void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
{
NVTX3_SCOPED_RANGE(MLACacheFormatter_unformat);
auto const& llmRequest = session.getLlmRequest();

@ -35,9 +35,9 @@ public:
TLLM_CHECK(mCacheTransBufferManager);
}

void format(TransferSession& session) override;
void format(tensorrt_llm::batch_manager::TransferSession& session) override;

void unformat(TransferSession& session) override;
void unformat(tensorrt_llm::batch_manager::TransferSession& session) override;

[[nodiscard]] bool inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const override;


@ -80,6 +80,8 @@ using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
namespace tk = tensorrt_llm::kernels;

using tensorrt_llm::batch_manager::CacheTransceiverFactory;

namespace tensorrt_llm::batch_manager
{


@ -55,6 +55,11 @@ namespace tensorrt_llm::mpi
class MpiWaitThread;
} // namespace tensorrt_llm::mpi

namespace tensorrt_llm::batch_manager
{
class BaseCacheTransceiver;
}

namespace tensorrt_llm::batch_manager
{

@ -79,7 +84,6 @@ class LlmRequest;
class RuntimeBuffers;
class BasePeftCacheManager;
class GuidedDecoder;
class BaseCacheTransceiver;

// Algorithms
class CapacityScheduler;

@ -104,7 +104,7 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size
MemoryDesc srcDesc{
reinterpret_cast<uintptr_t>(data), size, static_cast<uint32_t>(mAgentConnectionManager->getDeviceId())};
MemoryDescs srcDescs{MemoryType::kVRAM, {srcDesc}};
auto dstBaseDesc = mSenderState.mReceiverBufferDesc;
auto dstBaseDesc = mSenderState.mCacheReceiverBufferDesc;
auto offset = size / mSenderState.mOffsetRatio.second * mSenderState.mOffsetRatio.first;
MemoryDesc dstDesc{dstBaseDesc.getAddr() + offset, size, dstBaseDesc.getDeviceId()};
TLLM_LOG_DEBUG(
@ -162,9 +162,9 @@ void AgentConnection::sendRequestAndBufferInfo(
}

void AgentConnection::setSenderState(
MemoryDesc mReceiverBufferDesc, int validSegmentIdx, std::pair<size_t, size_t> offsetRatio)
MemoryDesc mCacheReceiverBufferDesc, int validSegmentIdx, std::pair<size_t, size_t> offsetRatio)
{
mSenderState.mReceiverBufferDesc = mReceiverBufferDesc;
mSenderState.mCacheReceiverBufferDesc = mCacheReceiverBufferDesc;
mSenderState.validSegmentIdx = validSegmentIdx;
mSenderState.mOffsetRatio = offsetRatio;
}

@ -175,7 +175,8 @@ public:
void recv(DataContext const& ctx, void* data, size_t size) const override;
void sendRequestAndBufferInfo(
batch_manager::RequestInfo& requestInfo, std::optional<size_t> cacheBufferId, int validConnectionIdx);
void setSenderState(MemoryDesc mReceiverBufferDesc, int valideSegmentIdx, std::pair<size_t, size_t> offsetRatio);
void setSenderState(
MemoryDesc mCacheReceiverBufferDesc, int valideSegmentIdx, std::pair<size_t, size_t> offsetRatio);
[[nodiscard]] std::optional<size_t> getCacheBufferId() const;
void setHasLoadRemoteAgent(bool hasLoadRemoteAgent);
[[nodiscard]] bool hasLoadRemoteAgent() const;
@ -186,7 +187,7 @@ private:

struct SenderState
{
MemoryDesc mReceiverBufferDesc{nullptr, 0, 0};
MemoryDesc mCacheReceiverBufferDesc{nullptr, 0, 0};
int validSegmentIdx{0};
std::pair<size_t, size_t> mOffsetRatio;
SenderState() = default;

@ -18,7 +18,7 @@
#include "ucxCacheCommunicator.h"
#if ENABLE_UCX

#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h"
#include "tensorrt_llm/batch_manager/dataTransceiver.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/tllmException.h"
#include "tensorrt_llm/executor/cache_transmission/ucx_utils/connection.h"
@ -114,8 +114,8 @@ void UcxConnection::sendConnectionId(DataContext const& ctx, void const* data, s
std::future<void> future = promise.get_future();
auto completionCallback = [&](ucs_status_t, ucxx::RequestCallbackUserData) -> void { promise.set_value(); };

uint64_t tag
= ((mSendTagPrefix & 0xFFFFFFFF) << 32) | static_cast<uint64_t>(batch_manager::TransceiverTag::kID_TAG);
uint64_t tag = ((mSendTagPrefix & 0xFFFFFFFF) << 32)
| static_cast<uint64_t>(tensorrt_llm::batch_manager::TransceiverTag::kID_TAG);
std::vector<char> buffer(size + sizeof(mConnectionId));
memcpy(buffer.data(), data, size);
memcpy(buffer.data() + size, &mConnectionIdInPeer, sizeof(mConnectionIdInPeer));
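A worked example of the composed UCX tag (the prefix value is hypothetical): with mSendTagPrefix = 0x7 and TransceiverTag::kID_TAG = 19 (0x13),

// tag = (0x7ULL << 32) | 0x13  ==  0x0000000700000013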
@ -133,7 +133,7 @@ void UcxConnection::sendConnectionId(DataContext const& ctx, void const* data, s

void UcxConnection::send(DataContext const& ctx, void const* data, size_t size) const
{
if (ctx.getTag() == batch_manager::TransceiverTag::kID_TAG)
if (ctx.getTag() == tensorrt_llm::batch_manager::TransceiverTag::kID_TAG)
{
sendConnectionId(ctx, data, size);
return;

@ -26,7 +26,7 @@

#include "tensorrt_llm/batch_manager/cacheFormatter.h"
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h"
#include "tensorrt_llm/batch_manager/dataTransceiver.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
@ -45,7 +45,6 @@
#include <gmock/gmock.h>
#include <memory>
#include <random>
#include <tensorrt_llm/batch_manager/dataTransceiverImpl.h>
#include <tensorrt_llm/batch_manager/mlaCacheFormatter.h>
#include <tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h>

@ -84,6 +83,7 @@ class UcxCommTest : public ::testing::Test
};

using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
using TransceiverTag = tensorrt_llm::batch_manager::TransceiverTag;

TEST_F(UcxCommTest, Basic)
{

@ -26,7 +26,6 @@

#include "tensorrt_llm/batch_manager/cacheFormatter.h"
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
@ -154,106 +153,6 @@ TEST_F(CacheConfigTest, EqualTo)
EXPECT_EQ(state0, state1);
}

// ---------------------------------------
// MockTransceiverTest
// ---------------------------------------

class MockDataSender : public DataSender
{
public:
MockDataSender()
{
ON_CALL(*this, getCommState).WillByDefault(ReturnRef(mState));
ON_CALL(*this, recvRequestInfo)
.WillByDefault(Return(RequestInfo{0,
texec::DataTransceiverState{
texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, {10}, nvinfer1::DataType::kFLOAT},
texec::kv_cache::CommState{std::vector<SizeType32>{0}, 0}}}));
ON_CALL(*this, getCounterpartsCount).WillByDefault(Return(1));
}

MOCK_METHOD(RequestInfo, recvRequestInfo, (), (override));
MOCK_METHOD(void, sendSync, (LlmRequest const&), (override));
MOCK_METHOD(texec::kv_cache::CommState const&, getCommState, (), (const));
MOCK_METHOD(void, setCommState, (texec::kv_cache::CommState), (override));
MOCK_METHOD(size_t, getCounterpartsCount, (LlmRequest::RequestIdType), (const));
MOCK_METHOD(void, release, (LlmRequest::RequestIdType), (override));

private:
static texec::kv_cache::CommState mState;
};

texec::kv_cache::CommState MockDataSender::mState;

class MockDataReceiver : public DataReceiver
{
public:
MOCK_METHOD(TransferSession, sendRequestInfo, (LlmRequest const&), (override));
MOCK_METHOD(void, receiveSync, (TransferSession&), (override));
};

class MockTransceiverTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init)
{
public:
void SetUp() override {}

void TearDown() override {}

static auto makeLlmRequest(
LlmRequest::RequestIdType requestId = 0, SizeType32 maxNewTokens = 1, VecTokens inputTokens = {-1})
{
texec::Request request{std::move(inputTokens), maxNewTokens};
auto state = std::make_unique<texec::DataTransceiverState>();
auto stats = texec::ContextPhaseParams({}, requestId, state.release(), std::nullopt);
request.setContextPhaseParams(std::move(stats));
return std::make_unique<LlmRequest>(requestId, std::move(request));
}
};

TEST_F(MockTransceiverTest, MpiResponderBasic)
{
if (tensorrt_llm::mpi::MpiComm::world().getSize() > 2)
{
GTEST_SKIP() << "mpirun with procs<=2 is required to run this test.";
}
auto sender = std::make_unique<MockDataSender>();
EXPECT_CALL(*sender, recvRequestInfo)
.WillOnce(Return(RequestInfo{0,
texec::DataTransceiverState{
texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, {4}, nvinfer1::DataType::kFLOAT},
texec::kv_cache::CommState{std::vector<SizeType32>{0}, 0}}}));
EXPECT_CALL(*sender, sendSync).WillOnce(Return());
EXPECT_CALL(*sender, getCounterpartsCount).WillOnce(Return(1));
EXPECT_CALL(*sender, release).WillOnce(Return());

DataResponder responder{std::move(sender)};
auto request = makeLlmRequest(0);
auto future = responder.respondAndSendAsync(*request);
future.get();
}

TEST_F(MockTransceiverTest, MpiRequesterBasic)
{

if (tensorrt_llm::mpi::MpiComm::world().getSize() > 2)
{
GTEST_SKIP() << "mpirun with procs<=2 is required to run this test.";
}
auto receiver = std::make_unique<MockDataReceiver>();
auto state = std::make_unique<texec::DataTransceiverState>();
state->setCommState(texec::kv_cache::CommState{std::vector<int>{0}});
EXPECT_CALL(*receiver, sendRequestInfo)
.WillOnce(Return(TransferSession({nullptr}, DataContext{0}, *state, *state,
tensorrt_llm::runtime::BufferManager{std::make_shared<tr::CudaStream>()}, nullptr)));
EXPECT_CALL(*receiver, receiveSync).WillOnce(Return());
DataRequester requester{std::move(receiver)};
auto request = makeLlmRequest(0);
auto stats = texec::ContextPhaseParams({}, 0, state.release(), std::nullopt);
request->setContextPhaseParams(std::move(stats));
auto future = requester.requestAndReceiveAsync(*request);
future.get();
}

// TODO: Restore multi-rank tests.

// ---------------------------------------
@ -397,15 +296,13 @@ protected:
mCacheTransBufferManager = std::make_unique<CacheTransBufferManager>(mManager.get(), maxNumTokens);
if (isSender)
{
mResponder = std::make_unique<DataResponder>(
std::make_unique<DataSenderImpl>(mConnectionManager.get(), *mCacheState, mlocalRank,
std::make_unique<CacheFormatter>(mManager.get(), mCacheTransBufferManager.get())));
mSender = std::make_unique<CacheSender>(mConnectionManager.get(), *mCacheState, mlocalRank,
std::make_unique<CacheFormatter>(mManager.get(), mCacheTransBufferManager.get()));
}
else
{
mRequester = std::make_unique<DataRequester>(
std::make_unique<DataReceiverImpl>(mConnectionManager.get(), *mCacheState, mlocalRank,
std::make_unique<CacheFormatter>(mManager.get(), mCacheTransBufferManager.get())));
mRequester = std::make_unique<CacheReceiver>(mConnectionManager.get(), *mCacheState, mlocalRank,
std::make_unique<CacheFormatter>(mManager.get(), mCacheTransBufferManager.get()));
}
}

@ -435,11 +332,11 @@ protected:
// fill cache with tokens (= request length), for reuse test
TLLM_CUDA_CHECK(cudaMemset(block.data(), llmRequest->getPromptLen(), block.getSizeInBytes()));
}
mFutures.emplace_back(mResponder->respondAndSendAsync(*llmRequest));
mFutures.emplace_back(mSender->sendAsync(*llmRequest));
}
else
{
auto future = mRequester->requestAndReceiveAsync(*llmRequest);
auto future = mRequester->receiveAsync(*llmRequest);
future.get();
TLLM_CUDA_CHECK(cudaDeviceSynchronize());
auto blockRange = BlockRange::fromAllBlockIds(*mManager, llmRequest->mRequestId);
@ -460,8 +357,8 @@ protected:
SizeType32 mMaxNumSequences{};
std::unique_ptr<KVCacheManager> mManager;
std::unique_ptr<CacheTransBufferManager> mCacheTransBufferManager;
std::unique_ptr<DataResponder> mResponder;
std::unique_ptr<DataRequester> mRequester;
std::unique_ptr<CacheSender> mSender;
std::unique_ptr<CacheReceiver> mRequester;
std::unique_ptr<texec::kv_cache::CacheState> mCacheState;
std::unique_ptr<texec::kv_cache::CommState> mContextCommState;
std::vector<std::future<void>> mFutures;
@ -789,13 +686,13 @@ protected:

if (mIsContext)
{
mResponder = std::make_unique<DataResponder>(std::make_unique<DataSenderImpl>(
mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter()));
mSender = std::make_unique<CacheSender>(
mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter());
}
else
{
mRequester = std::make_unique<DataRequester>(std::make_unique<DataReceiverImpl>(
mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter()));
mRequester = std::make_unique<CacheReceiver>(
mConnectionManager.get(), *mCacheState, mRankInInstance, makeFormatter());
}

std::vector<int> contextRankVec(mContextRankSize);
@ -930,7 +827,7 @@ protected:
auto const onlyWindowSize = blockManager.getPoolWindowSize(0);

blockManager.getBufferManager(onlyWindowSize).getStream().synchronize();
auto future = mResponder->respondAndSendAsync(*llmRequest);
auto future = mSender->sendAsync(*llmRequest);
return future;
}

@ -940,7 +837,7 @@ protected:
auto constexpr beamWidth{1};
mManager->addSequence(llmRequest->mRequestId, llmRequest->getNumTokens(beamIdx), beamWidth, llmRequest);

return mRequester->requestAndReceiveAsync(*llmRequest);
return mRequester->receiveAsync(*llmRequest);
}

void generationVerifyKVCache(std::shared_ptr<LlmRequest> const& llmRequest)
@ -1162,8 +1059,8 @@ protected:
SizeType32 mMaxNumSequences{};
std::unique_ptr<KVCacheManager> mManager;
std::unique_ptr<CacheTransBufferManager> mCacheTransBufferManager;
std::unique_ptr<DataResponder> mResponder;
std::unique_ptr<DataRequester> mRequester;
std::unique_ptr<CacheSender> mSender;
std::unique_ptr<CacheReceiver> mRequester;
std::unique_ptr<texec::kv_cache::CacheState> mCacheState;
std::unique_ptr<texec::kv_cache::CacheState> mContextCacheState;
std::unique_ptr<texec::kv_cache::CommState> mContextCommState;
