mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][feat] move kv cache measure into transfer session (#6633)
Signed-off-by: zhengd-nv <200704041+zhengd-nv@users.noreply.github.com>
This commit is contained in:
parent
32ad7f3c12
commit
ebdc43e69d
@ -361,7 +361,7 @@ void CacheFormatter::format(TransferSession& session)
|
||||
}
|
||||
double cacheTransferTime
|
||||
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
|
||||
kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, delay, cacheTransferTime, size);
|
||||
session.appendMeasure(delay, cacheTransferTime, size);
|
||||
};
|
||||
|
||||
if (connections.size() > 1)
|
||||
@ -713,7 +713,7 @@ void CacheFormatter::unformat(TransferSession& session)
|
||||
}
|
||||
double cacheTransferTime
|
||||
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
|
||||
kvCacheMeasureHelper.appendKVCacheTransfer(ctxReqId, delay, cacheTransferTime, size);
|
||||
session.appendMeasure(delay, cacheTransferTime, size);
|
||||
};
|
||||
if (pickUpConnections.size() > 1)
|
||||
{
|
||||
|
||||
@ -76,15 +76,6 @@ public:
|
||||
|
||||
/// @brief Destructor.
|
||||
virtual ~BaseCacheFormatter() = default;
|
||||
|
||||
// TODO: better way for context/generation tagging
// Tags the formatter's measurement helper as context (sender) or generation
// (receiver) side; this selects the "send"/"recv" suffix of the CSV file the
// helper writes on destruction.
void markAsSender(bool isSender)
{
    kvCacheMeasureHelper.markAsSender(isSender);
}
|
||||
|
||||
protected:
|
||||
KvCacheMeasureHelper kvCacheMeasureHelper{common::getEnvKVCacheTransferOutputPath()};
|
||||
};
|
||||
|
||||
// Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
|
||||
|
||||
@ -91,6 +91,43 @@ std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo)
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
void TransferSession::appendMeasure(double delay, double duration, size_t size)
|
||||
{
|
||||
if (!mRecordMeasure)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps
|
||||
mMeasures.emplace_back(Measure{delay, duration, bandwidth});
|
||||
}
|
||||
|
||||
void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const
|
||||
{
|
||||
if (mMeasures.empty())
|
||||
{
|
||||
return;
|
||||
}
|
||||
// write header if not exist
|
||||
if (outFile.tellp() == 0)
|
||||
{
|
||||
outFile << "RequestID";
|
||||
for (size_t i = 0; i < mMeasures.size(); i++)
|
||||
{
|
||||
outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
|
||||
}
|
||||
outFile << '\n';
|
||||
}
|
||||
// write measures
|
||||
TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value());
|
||||
auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId();
|
||||
outFile << reqId;
|
||||
for (auto const& measure : mMeasures)
|
||||
{
|
||||
outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
|
||||
}
|
||||
outFile << '\n' << std::flush;
|
||||
}
|
||||
|
||||
class DataResponder::Impl
|
||||
{
|
||||
public:
|
||||
|
||||
@ -97,15 +97,23 @@ private:
|
||||
class TransferSession
|
||||
{
|
||||
public:
|
||||
struct Measure
|
||||
{
|
||||
double delay; // from last token (ctx) or arrival time (gen), in ms
|
||||
double duration; // in ms
|
||||
double bandwidth; // in Gbps
|
||||
};
|
||||
|
||||
TransferSession(std::vector<Connection const*> connections, DataContext dataContext,
|
||||
executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState,
|
||||
runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr)
|
||||
runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false)
|
||||
: mConnections(std::move(connections))
|
||||
, mDataContext(dataContext)
|
||||
, mSelfState(&selfState)
|
||||
, mOtherState(std::move(otherState))
|
||||
, mBufferManager(&bufferManager)
|
||||
, mRequest(llmRequest)
|
||||
, mRecordMeasure(recordMeasure)
|
||||
{
|
||||
TLLM_CHECK(!mConnections.empty());
|
||||
}
|
||||
@ -163,6 +171,11 @@ public:
|
||||
mRequest = &llmRequest;
|
||||
}
|
||||
|
||||
void appendMeasure(double delay, double duration, size_t size);
|
||||
|
||||
// TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file
|
||||
void exportMeasure(std::ofstream& outFile, bool isContext) const;
|
||||
|
||||
private:
|
||||
std::vector<Connection const*> mConnections;
|
||||
DataContext mDataContext;
|
||||
@ -170,6 +183,8 @@ private:
|
||||
executor::DataTransceiverState mOtherState;
|
||||
runtime::BufferManager const* mBufferManager;
|
||||
LlmRequest const* mRequest;
|
||||
bool mRecordMeasure;
|
||||
std::vector<Measure> mMeasures;
|
||||
};
|
||||
|
||||
// Operators required for data transmission in specific communication protocols.
|
||||
@ -266,79 +281,4 @@ private:
|
||||
std::unique_ptr<Impl> mImpl;
|
||||
};
|
||||
|
||||
class KvCacheMeasureHelper
|
||||
{
|
||||
public:
|
||||
struct Measure
|
||||
{
|
||||
double delay; // from last token (ctx) or arrival time (gen), in ms
|
||||
double duration; // in ms
|
||||
double bandwidth; // in Gbps
|
||||
};
|
||||
|
||||
KvCacheMeasureHelper(std::string output_path)
|
||||
: mOutputPath(std::move(output_path))
|
||||
{
|
||||
}
|
||||
|
||||
void markAsSender(bool isSender)
|
||||
{
|
||||
mIsSender = isSender;
|
||||
}
|
||||
|
||||
void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double delay, double duration, size_t size)
|
||||
{
|
||||
auto bandwidth = size * 8 / (duration / 1000) / 1e9;
|
||||
if (mOutputPath.empty())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
mRequestKVCacheTranfserMeasure[requestId].emplace_back(Measure{delay, duration, bandwidth});
|
||||
}
|
||||
|
||||
~KvCacheMeasureHelper()
|
||||
{
|
||||
if (!mRequestKVCacheTranfserMeasure.empty() && !mOutputPath.empty())
|
||||
{
|
||||
TLLM_CHECK(mIsSender.has_value());
|
||||
auto rank = mpi::MpiComm::world().getRank();
|
||||
std::string outFilePath
|
||||
= mOutputPath + "rank_" + std::to_string(rank) + "_" + (mIsSender.value() ? "send" : "recv") + ".csv";
|
||||
std::ofstream outFile(outFilePath);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(outFile.is_open(), "Cannot write to file " + outFilePath);
|
||||
|
||||
size_t numTransferMeasure = mRequestKVCacheTranfserMeasure.begin()->second.size();
|
||||
|
||||
outFile << "RequestID";
|
||||
for (size_t i = 0; i < numTransferMeasure; i++)
|
||||
{
|
||||
outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
|
||||
}
|
||||
outFile << '\n';
|
||||
|
||||
for (auto const& [requestID, measures] : mRequestKVCacheTranfserMeasure)
|
||||
{
|
||||
outFile << requestID;
|
||||
|
||||
for (auto const& measure : measures)
|
||||
{
|
||||
outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
|
||||
}
|
||||
outFile << '\n';
|
||||
}
|
||||
|
||||
outFile.close();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<LlmRequest::RequestIdType, std::vector<Measure>> mRequestKVCacheTranfserMeasure;
|
||||
std::string mOutputPath;
|
||||
std::mutex mMutex;
|
||||
std::optional<bool> mIsSender;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -21,6 +21,8 @@
|
||||
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
|
||||
#include <filesystem>
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
|
||||
@ -30,6 +32,21 @@ static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId)
|
||||
return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF);
|
||||
}
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
// Builds "<env path>/rank_<rank>_<tag>.csv" under the configured KV-cache
// transfer output directory, creating the directory if needed. Returns an
// empty path when the output-path environment setting is unset.
static fs::path getTransferOutputPath(char const* tag)
{
    auto const& outputPath = common::getEnvKVCacheTransferOutputPath();
    if (outputPath.empty())
    {
        return {};
    }
    fs::path dir{outputPath};
    fs::create_directories(dir);
    auto fileName = "rank_" + std::to_string(mpi::MpiComm::world().getRank()) + "_" + tag + ".csv";
    return dir / fileName;
}
|
||||
|
||||
DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
|
||||
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
|
||||
: mManager{manager}
|
||||
@ -39,7 +56,6 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
|
||||
{
|
||||
TLLM_CHECK(mManager);
|
||||
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
|
||||
mFormatter->markAsSender(true);
|
||||
}
|
||||
|
||||
[[nodiscard]] RequestInfo DataSenderImpl::recvRequestInfo()
|
||||
@ -86,7 +102,8 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
|
||||
if (it == mRequestToSession.end())
|
||||
{
|
||||
auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
|
||||
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager);
|
||||
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, nullptr,
|
||||
!common::getEnvKVCacheTransferOutputPath().empty());
|
||||
it = mRequestToSession.emplace(requestId, std::move(session)).first;
|
||||
}
|
||||
it->second.setConnection(peerIdx, connection);
|
||||
@ -125,6 +142,17 @@ void DataSenderImpl::release(LlmRequest::RequestIdType requestId)
|
||||
auto it = mRequestToSession.find(requestId);
|
||||
TLLM_CHECK(it != mRequestToSession.end());
|
||||
std::unique_lock<std::mutex> lk(mMtxForMap);
|
||||
if (!common::getEnvKVCacheTransferOutputPath().empty())
|
||||
{
|
||||
if (!mMeasuresFile.is_open())
|
||||
{
|
||||
auto outputPath = getTransferOutputPath("send");
|
||||
mMeasuresFile.open(outputPath);
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str());
|
||||
}
|
||||
it->second.exportMeasure(mMeasuresFile, true);
|
||||
}
|
||||
mRequestToSession.erase(it);
|
||||
}
|
||||
|
||||
@ -137,7 +165,6 @@ DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manage
|
||||
TLLM_CHECK(mManager);
|
||||
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
|
||||
TLLM_CHECK(mFormatter);
|
||||
mFormatter->markAsSender(false);
|
||||
}
|
||||
|
||||
TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
|
||||
@ -203,12 +230,24 @@ TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
|
||||
}
|
||||
auto const& resource = getReceiveCacheResource(llmRequest);
|
||||
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
|
||||
contextState, resource->mBufferManager, &llmRequest);
|
||||
contextState, resource->mBufferManager, &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty());
|
||||
}
|
||||
|
||||
/// @brief Receive the KV-cache blocks for @p session, then flush the session's
/// transfer measurements to the per-rank "recv" CSV when measurement output is
/// enabled via the environment.
void DataReceiverImpl::receiveSync(TransferSession& session)
{
    mFormatter->unformat(session);
    if (common::getEnvKVCacheTransferOutputPath().empty())
    {
        return;
    }
    std::lock_guard<std::mutex> lock(mMeasuresFileMutex);
    // Lazily open the output file on first use.
    if (!mMeasuresFile.is_open())
    {
        auto const csvPath = getTransferOutputPath("recv");
        mMeasuresFile.open(csvPath);
        TLLM_CHECK_WITH_INFO(
            mMeasuresFile.is_open(), "Failed to open transfer output file: %s", csvPath.string().c_str());
    }
    session.exportMeasure(mMeasuresFile, false);
}
|
||||
|
||||
void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info)
|
||||
|
||||
@ -23,6 +23,8 @@
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
struct TransceiverTag
|
||||
@ -67,6 +69,7 @@ private:
|
||||
std::unique_ptr<BaseCacheFormatter> mFormatter;
|
||||
std::mutex mMtxForMap;
|
||||
runtime::BufferManager mBufferManager;
|
||||
std::ofstream mMeasuresFile;
|
||||
};
|
||||
|
||||
class DataReceiverImpl : public DataReceiver, public TransceiverTag
|
||||
@ -103,6 +106,8 @@ private:
|
||||
std::unique_ptr<BaseCacheFormatter> mFormatter;
|
||||
std::unordered_map<std::string, std::unique_ptr<ReceiveCacheResource>> mProcessToResources;
|
||||
std::mutex mProcessIoResouceMutex;
|
||||
std::ofstream mMeasuresFile;
|
||||
std::mutex mMeasuresFileMutex;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -244,7 +244,7 @@ void MLACacheFormatter::format(TransferSession& session)
|
||||
}
|
||||
double cacheTransferTime
|
||||
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
|
||||
kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, delay, cacheTransferTime, size);
|
||||
session.appendMeasure(delay, cacheTransferTime, size);
|
||||
};
|
||||
|
||||
if (connections.size() > 1)
|
||||
@ -441,7 +441,7 @@ void MLACacheFormatter::unformat(TransferSession& session)
|
||||
}
|
||||
double cacheTransferTime
|
||||
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
|
||||
kvCacheMeasureHelper.appendKVCacheTransfer(ctxReqId, delay, cacheTransferTime, size);
|
||||
session.appendMeasure(delay, cacheTransferTime, size);
|
||||
};
|
||||
|
||||
if (pickUpConnections.size() > 1)
|
||||
|
||||
@ -386,7 +386,7 @@ size_t getEnvAllReduceWorkspaceSize()
|
||||
return workspaceSize;
|
||||
}
|
||||
|
||||
std::string getEnvKVCacheTransferOutputPath()
|
||||
std::string const& getEnvKVCacheTransferOutputPath()
|
||||
{
|
||||
static std::string outputPath = getStrEnv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH").value_or("");
|
||||
return outputPath;
|
||||
|
||||
@ -76,7 +76,7 @@ bool getEnvDisableKVCacheTransferOverlap();
|
||||
|
||||
bool getEnvEnableReceiveKVCacheParallel();
|
||||
|
||||
std::string getEnvKVCacheTransferOutputPath();
|
||||
std::string const& getEnvKVCacheTransferOutputPath();
|
||||
|
||||
bool getEnvTryZCopyForKVCacheTransfer();
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user