[None][feat] move kv cache measure into transfer session (#6633)

Signed-off-by: zhengd-nv <200704041+zhengd-nv@users.noreply.github.com>
This commit is contained in:
Zheng Duan 2025-08-08 17:49:22 +08:00 committed by GitHub
parent 32ad7f3c12
commit ebdc43e69d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 107 additions and 95 deletions

View File

@ -361,7 +361,7 @@ void CacheFormatter::format(TransferSession& session)
}
double cacheTransferTime
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, delay, cacheTransferTime, size);
session.appendMeasure(delay, cacheTransferTime, size);
};
if (connections.size() > 1)
@ -713,7 +713,7 @@ void CacheFormatter::unformat(TransferSession& session)
}
double cacheTransferTime
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
kvCacheMeasureHelper.appendKVCacheTransfer(ctxReqId, delay, cacheTransferTime, size);
session.appendMeasure(delay, cacheTransferTime, size);
};
if (pickUpConnections.size() > 1)
{

View File

@ -76,15 +76,6 @@ public:
/// @brief Destructor.
virtual ~BaseCacheFormatter() = default;
// TODO: better way for context/generation tagging
void markAsSender(bool isSender)
{
kvCacheMeasureHelper.markAsSender(isSender);
}
protected:
KvCacheMeasureHelper kvCacheMeasureHelper{common::getEnvKVCacheTransferOutputPath()};
};
// Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the

View File

@ -91,6 +91,43 @@ std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo)
return totalSize;
}
void TransferSession::appendMeasure(double delay, double duration, size_t size)
{
if (!mRecordMeasure)
{
return;
}
auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps
mMeasures.emplace_back(Measure{delay, duration, bandwidth});
}
void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const
{
if (mMeasures.empty())
{
return;
}
// write header if not exist
if (outFile.tellp() == 0)
{
outFile << "RequestID";
for (size_t i = 0; i < mMeasures.size(); i++)
{
outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
}
outFile << '\n';
}
// write measures
TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value());
auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId();
outFile << reqId;
for (auto const& measure : mMeasures)
{
outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
}
outFile << '\n' << std::flush;
}
class DataResponder::Impl
{
public:

View File

@ -97,15 +97,23 @@ private:
class TransferSession
{
public:
struct Measure
{
double delay; // from last token (ctx) or arrival time (gen), in ms
double duration; // in ms
double bandwidth; // in Gbps
};
TransferSession(std::vector<Connection const*> connections, DataContext dataContext,
executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState,
runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr)
runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false)
: mConnections(std::move(connections))
, mDataContext(dataContext)
, mSelfState(&selfState)
, mOtherState(std::move(otherState))
, mBufferManager(&bufferManager)
, mRequest(llmRequest)
, mRecordMeasure(recordMeasure)
{
TLLM_CHECK(!mConnections.empty());
}
@ -163,6 +171,11 @@ public:
mRequest = &llmRequest;
}
void appendMeasure(double delay, double duration, size_t size);
// TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file
void exportMeasure(std::ofstream& outFile, bool isContext) const;
private:
std::vector<Connection const*> mConnections;
DataContext mDataContext;
@ -170,6 +183,8 @@ private:
executor::DataTransceiverState mOtherState;
runtime::BufferManager const* mBufferManager;
LlmRequest const* mRequest;
bool mRecordMeasure;
std::vector<Measure> mMeasures;
};
// Operators required for data transmission in specific communication protocols.
@ -266,79 +281,4 @@ private:
std::unique_ptr<Impl> mImpl;
};
class KvCacheMeasureHelper
{
public:
struct Measure
{
double delay; // from last token (ctx) or arrival time (gen), in ms
double duration; // in ms
double bandwidth; // in Gbps
};
KvCacheMeasureHelper(std::string output_path)
: mOutputPath(std::move(output_path))
{
}
void markAsSender(bool isSender)
{
mIsSender = isSender;
}
void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double delay, double duration, size_t size)
{
auto bandwidth = size * 8 / (duration / 1000) / 1e9;
if (mOutputPath.empty())
{
return;
}
std::lock_guard<std::mutex> lock(mMutex);
mRequestKVCacheTranfserMeasure[requestId].emplace_back(Measure{delay, duration, bandwidth});
}
~KvCacheMeasureHelper()
{
if (!mRequestKVCacheTranfserMeasure.empty() && !mOutputPath.empty())
{
TLLM_CHECK(mIsSender.has_value());
auto rank = mpi::MpiComm::world().getRank();
std::string outFilePath
= mOutputPath + "rank_" + std::to_string(rank) + "_" + (mIsSender.value() ? "send" : "recv") + ".csv";
std::ofstream outFile(outFilePath);
TLLM_CHECK_WITH_INFO(outFile.is_open(), "Cannot write to file " + outFilePath);
size_t numTransferMeasure = mRequestKVCacheTranfserMeasure.begin()->second.size();
outFile << "RequestID";
for (size_t i = 0; i < numTransferMeasure; i++)
{
outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
}
outFile << '\n';
for (auto const& [requestID, measures] : mRequestKVCacheTranfserMeasure)
{
outFile << requestID;
for (auto const& measure : measures)
{
outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
}
outFile << '\n';
}
outFile.close();
}
}
private:
std::map<LlmRequest::RequestIdType, std::vector<Measure>> mRequestKVCacheTranfserMeasure;
std::string mOutputPath;
std::mutex mMutex;
std::optional<bool> mIsSender;
};
} // namespace tensorrt_llm::batch_manager

View File

@ -21,6 +21,8 @@
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <filesystem>
namespace tensorrt_llm::batch_manager
{
@ -30,6 +32,21 @@ static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId)
return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF);
}
namespace fs = std::filesystem;
static fs::path getTransferOutputPath(char const* tag)
{
auto outputPath = common::getEnvKVCacheTransferOutputPath();
if (!outputPath.empty())
{
auto rank = mpi::MpiComm::world().getRank();
auto path = fs::path(outputPath);
fs::create_directories(path);
return path / ("rank_" + std::to_string(rank) + "_" + tag + ".csv");
}
return {};
}
DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
: mManager{manager}
@ -39,7 +56,6 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
{
TLLM_CHECK(mManager);
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
mFormatter->markAsSender(true);
}
[[nodiscard]] RequestInfo DataSenderImpl::recvRequestInfo()
@ -86,7 +102,8 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
if (it == mRequestToSession.end())
{
auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager);
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, nullptr,
!common::getEnvKVCacheTransferOutputPath().empty());
it = mRequestToSession.emplace(requestId, std::move(session)).first;
}
it->second.setConnection(peerIdx, connection);
@ -125,6 +142,17 @@ void DataSenderImpl::release(LlmRequest::RequestIdType requestId)
auto it = mRequestToSession.find(requestId);
TLLM_CHECK(it != mRequestToSession.end());
std::unique_lock<std::mutex> lk(mMtxForMap);
if (!common::getEnvKVCacheTransferOutputPath().empty())
{
if (!mMeasuresFile.is_open())
{
auto outputPath = getTransferOutputPath("send");
mMeasuresFile.open(outputPath);
TLLM_CHECK_WITH_INFO(
mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str());
}
it->second.exportMeasure(mMeasuresFile, true);
}
mRequestToSession.erase(it);
}
@ -137,7 +165,6 @@ DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manage
TLLM_CHECK(mManager);
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
TLLM_CHECK(mFormatter);
mFormatter->markAsSender(false);
}
TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
@ -203,12 +230,24 @@ TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
}
auto const& resource = getReceiveCacheResource(llmRequest);
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
contextState, resource->mBufferManager, &llmRequest);
contextState, resource->mBufferManager, &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty());
}
void DataReceiverImpl::receiveSync(TransferSession& session)
{
mFormatter->unformat(session);
if (!common::getEnvKVCacheTransferOutputPath().empty())
{
std::unique_lock<std::mutex> lock(mMeasuresFileMutex);
if (!mMeasuresFile.is_open())
{
auto outputPath = getTransferOutputPath("recv");
mMeasuresFile.open(outputPath);
TLLM_CHECK_WITH_INFO(
mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str());
}
session.exportMeasure(mMeasuresFile, false);
}
}
void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info)

View File

@ -23,6 +23,8 @@
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
#include <fstream>
namespace tensorrt_llm::batch_manager
{
struct TransceiverTag
@ -67,6 +69,7 @@ private:
std::unique_ptr<BaseCacheFormatter> mFormatter;
std::mutex mMtxForMap;
runtime::BufferManager mBufferManager;
std::ofstream mMeasuresFile;
};
class DataReceiverImpl : public DataReceiver, public TransceiverTag
@ -103,6 +106,8 @@ private:
std::unique_ptr<BaseCacheFormatter> mFormatter;
std::unordered_map<std::string, std::unique_ptr<ReceiveCacheResource>> mProcessToResources;
std::mutex mProcessIoResouceMutex;
std::ofstream mMeasuresFile;
std::mutex mMeasuresFileMutex;
};
} // namespace tensorrt_llm::batch_manager

View File

@ -244,7 +244,7 @@ void MLACacheFormatter::format(TransferSession& session)
}
double cacheTransferTime
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, delay, cacheTransferTime, size);
session.appendMeasure(delay, cacheTransferTime, size);
};
if (connections.size() > 1)
@ -441,7 +441,7 @@ void MLACacheFormatter::unformat(TransferSession& session)
}
double cacheTransferTime
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
kvCacheMeasureHelper.appendKVCacheTransfer(ctxReqId, delay, cacheTransferTime, size);
session.appendMeasure(delay, cacheTransferTime, size);
};
if (pickUpConnections.size() > 1)

View File

@ -386,7 +386,7 @@ size_t getEnvAllReduceWorkspaceSize()
return workspaceSize;
}
std::string getEnvKVCacheTransferOutputPath()
std::string const& getEnvKVCacheTransferOutputPath()
{
static std::string outputPath = getStrEnv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH").value_or("");
return outputPath;

View File

@ -76,7 +76,7 @@ bool getEnvDisableKVCacheTransferOverlap();
bool getEnvEnableReceiveKVCacheParallel();
std::string getEnvKVCacheTransferOutputPath();
std::string const& getEnvKVCacheTransferOutputPath();
bool getEnvTryZCopyForKVCacheTransfer();