[None][feat] add detailed KV cache transfer time breakdown (#8521)

Signed-off-by: zhengd-nv <200704041+zhengd-nv@users.noreply.github.com>

Author: Zheng Duan
Committed: 2025-10-29 10:11:09 +08:00 (via GitHub)
Parent: f444fe2deb
Commit: fea5bfbda7

12 changed files with 129 additions and 104 deletions

View File

@@ -1691,22 +1691,22 @@ public:
         mDecodingIter = iter;
     }

-    void setKvCacheTransferStart(TimePoint const& time)
+    void setKvCacheTransferStart(TimePoint time) const
     {
         mPerfMetrics.timingMetrics.kvCacheTransferStart = maybeToGlobalSteadyClock(time);
     }

-    void setKvCacheTransferEnd(TimePoint const& time)
+    void setKvCacheTransferEnd(TimePoint time) const
     {
         mPerfMetrics.timingMetrics.kvCacheTransferEnd = maybeToGlobalSteadyClock(time);
     }

-    TimePoint getKvCacheTransferStart()
+    TimePoint getKvCacheTransferStart() const
     {
         return mPerfMetrics.timingMetrics.kvCacheTransferStart;
     }

-    TimePoint getKvCacheTransferEnd()
+    TimePoint getKvCacheTransferEnd() const
     {
         return mPerfMetrics.timingMetrics.kvCacheTransferEnd;
     }

@@ -1865,13 +1865,11 @@ public:
         return mUseDraftModel;
     }

-    // If mGlobalSteadyClockOffset is set, return a global steady clock time point, otherwise return local steady clock
+    // If sGlobalSteadyClockOffset is set, return a global steady clock time point, otherwise return local steady clock
     // time point
-    [[nodiscard]] TimePoint getSteadyClockNow() const
+    [[nodiscard]] static TimePoint getSteadyClockNow()
     {
-        const TimePoint time_point = std::chrono::steady_clock::now();
-        return maybeToGlobalSteadyClock(time_point);
+        return maybeToGlobalSteadyClock(std::chrono::steady_clock::now());
     }

     RequestIdType mRequestId;

@@ -1894,7 +1892,7 @@ public:
     SizeType32 mPtableCurrentPosition{0};

     // The offset between local steady clock and global steady clock (at rank 0)
-    inline static std::optional<Duration> mGlobalSteadyClockOffset{std::nullopt};
+    inline static std::optional<Duration> sGlobalSteadyClockOffset{std::nullopt};

 protected:
     bool mIsStreaming;

@@ -2028,9 +2026,9 @@ protected:
     std::optional<TensorPtr> mSkipCrossAttnBlocks{std::nullopt};

-    // Performance metrics.
+    // Performance metrics. Should be updatable even from a const LlmRequest reference.
     bool mReturnPerfMetrics{false};
-    executor::RequestPerfMetrics mPerfMetrics;
+    mutable executor::RequestPerfMetrics mPerfMetrics;

     // Guided decoding params.
     std::optional<executor::GuidedDecodingParams> mGuidedDecodingParams{std::nullopt};

@@ -2183,16 +2181,13 @@ private:
         return tensor;
     }

-    TimePoint maybeToGlobalSteadyClock(TimePoint const& time_point) const
+    static TimePoint maybeToGlobalSteadyClock(TimePoint const& time_point)
     {
-        if (mGlobalSteadyClockOffset.has_value())
+        if (sGlobalSteadyClockOffset.has_value())
         {
-            return time_point + *mGlobalSteadyClockOffset;
-        }
-        else
-        {
-            return time_point;
+            return time_point + *sGlobalSteadyClockOffset;
         }
+        return time_point;
     }
 };
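A note for readers skimming the diff: steady-clock timestamps taken on different ranks are not directly comparable, so LlmRequest keeps a per-process offset to rank 0's clock and shifts every recorded time point through maybeToGlobalSteadyClock(). A minimal standalone sketch of that idea follows, with simplified names; how the offset is measured (e.g. an MPI clock exchange at startup) is assumed to happen elsewhere.

#include <chrono>
#include <optional>

using TimePoint = std::chrono::steady_clock::time_point;
using Duration = std::chrono::steady_clock::duration;

// Offset between this rank's steady clock and the reference (rank 0) clock.
inline std::optional<Duration> gGlobalSteadyClockOffset;

// Comparable across ranks once the offset is set; plain local time otherwise.
TimePoint steadyClockNow()
{
    auto now = std::chrono::steady_clock::now();
    return gGlobalSteadyClockOffset ? now + *gGlobalSteadyClockOffset : now;
}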

View File

@@ -451,7 +451,7 @@ struct RequestPerfMetrics
     /// @brief End time of the KV cache transfer for disaggregated serving
     TimePoint kvCacheTransferEnd;

     /// @brief KV Cache size transfer for disaggregated serving
-    mutable size_t kvCacheSize = 0;
+    size_t kvCacheSize = 0;
 };

 struct KvCacheMetrics

View File

@@ -227,6 +227,7 @@ std::vector<size_t> CacheFormatter::pickRecvConnections(
 void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
 {
     NVTX3_SCOPED_RANGE(CacheFormatter_format);
+    session.setTime(TransferSession::kTimeFormatter);
     auto const& llmRequest = session.getLlmRequest();
     TLLM_LOG_DEBUG(
         mpi::MpiComm::world().getRank(), "Start sending KV cache for request ID: %ld.", llmRequest.mRequestId);

@@ -249,9 +250,6 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
     auto const numPools = blockManager.getNumPools();
     // TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1...

-    auto lastTokenTime = llmRequest.getPerfMetrics().timingMetrics.lastTokenTime;
-    bool recordDelay = lastTokenTime != std::chrono::steady_clock::time_point();
-
     bool layerWise = common::getEnvDisaggLayerwise() && numPools == 1;
     if (layerWise)
     {

@@ -420,6 +418,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
             inputKvCacheBlocksPerWindow, outputSplitCaches, destConfig, selfConfig, selfIdx, bufferManager);

         bufferManager.getStream().synchronize();
+        session.setTime(TransferSession::kTimePreprocess);

         auto preAllocSendBuffer = mCacheTransBufferManager->getSendBuffer(cacheBufferId);
         if (preAllocSendBuffer != nullptr)

@@ -434,7 +433,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
         TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
         TLLM_CHECK(connections.size() > (processIdx / peerDuplicateHeadFactor));
         TLLM_CHECK(outputSplitCaches.size() > (processIdx / peerDuplicateHeadFactor));
-        auto startTime = llmRequest.getSteadyClockNow();
+        auto startTime = LlmRequest::getSteadyClockNow();
         size_t ppDomainSize = targetInfo.mDomainPPSize;
         size_t bufferTpRank = (processIdx / ppDomainSize) / peerDuplicateHeadFactor;

@@ -481,15 +480,8 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
             }
         }
-        auto endTime = llmRequest.getSteadyClockNow();
-        double delay = 0.0;
-        if (recordDelay)
-        {
-            delay = std::chrono::duration<double, std::milli>(startTime - lastTokenTime).count();
-        }
-        double cacheTransferTime
-            = std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
-        session.appendMeasure(delay, cacheTransferTime, size);
+        auto endTime = LlmRequest::getSteadyClockNow();
+        session.appendMeasure(startTime, endTime, size);
     };

     if (connections.size() > 1)

@@ -534,8 +526,10 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
         {
             sendBufferFun(deviceId, 0);
         }
+        session.setTime(TransferSession::kTimeTransmissions);

         mCacheTransBufferManager->freeBufferIndexForSend(cacheBufferId);
+        session.setTime(TransferSession::kTimePostprocess);
     }

     TLLM_LOG_DEBUG(
         mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID:%ld ", llmRequest.mRequestId);

@@ -544,6 +538,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
 void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
 {
     NVTX3_SCOPED_RANGE(CacheFormatter_unformat);
+    session.setTime(TransferSession::kTimeFormatter);
     auto const& llmRequest = session.getLlmRequest();
     auto const ctxReqId = llmRequest.getContextPhaseParams().value().getReqId();
     TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),

@@ -555,9 +550,6 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
     auto& bufferManager = session.getBufferManager();
     auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse());

-    auto arrivalTime = llmRequest.getPerfMetrics().timingMetrics.arrivalTime;
-    bool recordDelay = arrivalTime != std::chrono::steady_clock::time_point();
-
     auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig);
     TLLM_LOG_DEBUG("pickUpConnections size: %d connections size: %d", pickUpConnections.size(), connections.size());

@@ -779,6 +771,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
         // sync to alloc buffer
         bufferManager.getStream().synchronize();
     }
+    session.setTime(TransferSession::kTimePreprocess);

     runtime::ITensor::SharedPtr preAllocRecvBuffer = nullptr;
     if (cacheBufferId.has_value())

@@ -794,7 +787,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
         TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
         TLLM_CHECK(pickUpConnections.size() > processIdx);
         TLLM_CHECK(recvSplitCaches.size() > processIdx);
-        auto startTime = llmRequest.getSteadyClockNow();
+        auto startTime = LlmRequest::getSteadyClockNow();
         size_t size = 0;
         if (processIdx >= remainNoCoverTargetNum)

@@ -835,15 +828,8 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
             }
         }
-        auto endTime = llmRequest.getSteadyClockNow();
-        double delay = 0.0;
-        if (recordDelay)
-        {
-            delay = std::chrono::duration<double, std::milli>(startTime - arrivalTime).count();
-        }
-        double cacheTransferTime
-            = std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
-        session.appendMeasure(delay, cacheTransferTime, size);
+        auto endTime = LlmRequest::getSteadyClockNow();
+        session.appendMeasure(startTime, endTime, size);
     };

     if (pickUpConnections.size() > 1)
     {

@@ -891,6 +877,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
         {
             recvBufferFun(deviceId, 0);
         }
+        session.setTime(TransferSession::kTimeTransmissions);

         {
             NVTX3_SCOPED_RANGE(formatInputConcatenate);

@@ -904,6 +891,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
             mCacheTransBufferManager->freeBufferIndexForRecv(cacheBufferId);
         }
     }
+    session.setTime(TransferSession::kTimePostprocess);
 }
 }
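Net effect of the checkpoints added above: each transfer, on both the sending (format) and receiving (unformat) side, is split into consecutive phases — request-info handshake (kTimeRequestInfo), entry into the formatter (kTimeFormatter), cache staging up to the stream synchronize (kTimePreprocess), the transmissions themselves (kTimeTransmissions), and buffer release/concatenation (kTimePostprocess). exportMeasure() later reports each phase as a millisecond delta from the previous checkpoint.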

View File

@@ -603,7 +603,7 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
             it->first->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);
             // Gather the kv cache transfer time from all workers and update to leader rank
-            if (!common::getEnvKVCacheTransferOutputPath().empty())
+            if (!common::getEnvKVCacheTimeOutputPath().empty())
             {
                 auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
                 updateKVCacheTransferBW(syncComm, it->first);

View File

@@ -28,6 +28,7 @@
 #include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
 #include "tensorrt_llm/runtime/common.h"
 #include "tensorrt_llm/runtime/utils/mpiUtils.h"
+#include <chrono>
 #include <future>
 #include <map>
 #include <memory>

@@ -105,39 +106,65 @@ void TransferSession::setLlmRequest(LlmRequest const& llmRequest)
     mRequest = &llmRequest;
 }

-void TransferSession::appendMeasure(double delay, double duration, size_t size)
+void TransferSession::setTime(TimeNames name)
 {
-    if (!mRecordMeasure)
+    if (mTimes)
     {
-        return;
+        mTimes->times.at(name) = LlmRequest::getSteadyClockNow();
     }
-    auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps
-    mMeasures.emplace_back(Measure{delay, duration, bandwidth});
+}
+
+void TransferSession::appendMeasure(LlmRequest::TimePoint start, LlmRequest::TimePoint end, size_t size)
+{
+    if (mTimes)
+    {
+        mTimes->measures.emplace_back(Measure{start, end, size});
+    }
 }

 void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const
 {
-    if (mMeasures.empty())
+    if (!mTimes || mTimes->measures.empty())
     {
         return;
     }
     // write header if not exist
     if (outFile.tellp() == 0)
     {
-        outFile << "RequestID";
-        for (size_t i = 0; i < mMeasures.size(); i++)
+        outFile << "RequestID,RequestInfo,Preparation,Preprocess,Transmissions,Postprocess";
+        for (size_t i = 0; i < mTimes->measures.size(); i++)
         {
-            outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
+            outFile << ",Delay,Duration,Bandwidth(Gbps)";
         }
         outFile << '\n';
     }
-    // write measures
+    auto transferStart = mRequest->getPerfMetrics().timingMetrics.kvCacheTransferStart;
+    using Milliseconds = std::chrono::duration<double, std::milli>;
+    // write measures, time is in milliseconds
     TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value());
     auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId();
     outFile << reqId;
-    for (auto const& measure : mMeasures)
+    auto previousTime = transferStart;
+    for (auto time : mTimes->times)
     {
-        outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
+        if (time == LlmRequest::TimePoint())
+        {
+            // timepoint is unset, skip
+            outFile << ",0.0";
+            continue;
+        }
+        double delay = Milliseconds(time - previousTime).count();
+        previousTime = time;
+        outFile << "," << delay;
+    }
+    previousTime = mTimes->times[kTimePreprocess];
+    for (auto const& measure : mTimes->measures)
+    {
+        double delay = Milliseconds(measure.start - previousTime).count();
+        double duration = Milliseconds(measure.end - measure.start).count();
+        double bandwidth = static_cast<double>(measure.size) * 8.0 / duration / 1e6; // byte, ms => Gbps
+        outFile << "," << delay << "," << duration << "," << bandwidth;
     }
     outFile << '\n' << std::flush;
 }

@@ -158,7 +185,7 @@ int32_t tagFromRequestId(LlmRequest::RequestIdType requestId)
 std::filesystem::path getTransferOutputPath(char const* tag)
 {
     namespace fs = std::filesystem;
-    auto outputPath = common::getEnvKVCacheTransferOutputPath();
+    auto outputPath = common::getEnvKVCacheTimeOutputPath();
     if (!outputPath.empty())
     {
         auto rank = mpi::MpiComm::world().getRank();

@@ -273,6 +300,7 @@ public:
     {
         std::promise<void> promise;
         auto future = promise.get_future();
+        llmRequest.setKvCacheTransferStart(LlmRequest::getSteadyClockNow());
         {
             {
                 std::scoped_lock lkResp(mSenderMutex);

@@ -309,7 +337,7 @@ public:
         std::unique_lock<std::mutex> lk(mMtxForMap);
         auto it = mRequestToSession.find(requestId);
         TLLM_CHECK(it != mRequestToSession.end());
-        if (!common::getEnvKVCacheTransferOutputPath().empty())
+        if (!common::getEnvKVCacheTimeOutputPath().empty())
         {
             if (!mMeasuresFile.is_open())
             {

@@ -363,7 +391,8 @@ public:
             auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
                 DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager,
                 info.getIndexFromEnd(), info.getLastBlockKey(), nullptr,
-                !common::getEnvKVCacheTransferOutputPath().empty());
+                !common::getEnvKVCacheTimeOutputPath().empty());
+            session.setTime(TransferSession::kTimeRequestInfo);
             it = mRequestToSession.emplace(requestId, std::move(session)).first;
         }
         it->second.setConnection(peerIdx, connection);

@@ -382,6 +411,7 @@ public:
         }
         session->setLlmRequest(llmRequest);
         mFormatter->format(*session);
+        llmRequest.setKvCacheTransferEnd(LlmRequest::getSteadyClockNow());
     }

     bool cancelRequest(LlmRequest const& llmRequest)

@@ -751,7 +781,7 @@ public:
     void receiveSync(TransferSession& session)
     {
         mFormatter->unformat(session);
-        if (!common::getEnvKVCacheTransferOutputPath().empty())
+        if (!common::getEnvKVCacheTimeOutputPath().empty())
         {
             std::unique_lock<std::mutex> lock(mMeasuresFileMutex);
             if (!mMeasuresFile.is_open())

@@ -846,7 +876,7 @@ public:
         auto const& resource = getReceiveCacheResource(llmRequest);
         return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
             contextState, resource->mBufferManager, requestInfo.getIndexFromEnd(), requestInfo.getLastBlockKey(),
-            &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty());
+            &llmRequest, !common::getEnvKVCacheTimeOutputPath().empty());
     }

     std::unique_ptr<ReceiveCacheResource> const& getReceiveCacheResource(LlmRequest const& llmRequest)

@@ -957,6 +987,7 @@ private:
         llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now());
         TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
         auto session = sendRequestInfo(llmRequest);
+        session.setTime(TransferSession::kTimeRequestInfo);
         bool isReady = receiveReadySignal(session);
         if (!isReady)
         {
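To make the new export shape concrete, here is a hedged sample row (values invented): with one transmission, a row carries nine fields — the request id, five phase deltas in milliseconds (each relative to the previous checkpoint, starting from kvCacheTransferStart, with 0.0 written for unset checkpoints), then one Delay/Duration/Bandwidth triplet per transmission.

RequestID,RequestInfo,Preparation,Preprocess,Transmissions,Postprocess,Delay,Duration,Bandwidth(Gbps)
42,0.12,0.03,1.85,200.4,0.22,0.05,200.1,21.5

The bandwidth column follows size * 8 / duration / 1e6 (bytes and milliseconds to Gbps); for example, 536870912 bytes (512 MiB) moved in 200 ms gives 536870912 * 8 / 200 / 1e6 ≈ 21.5 Gbps.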

View File

@@ -56,29 +56,48 @@ using UniqueToken = tensorrt_llm::runtime::UniqueToken;
 class TransferSession
 {
 public:
+    // measures for each single transmission
     struct Measure
     {
-        double delay;     // from last token (ctx) or arrival time (gen), in ms
-        double duration;  // in ms
-        double bandwidth; // in Gbps
+        LlmRequest::TimePoint start;
+        LlmRequest::TimePoint end;
+        size_t size = 0;
+    };
+
+    enum TimeNames : uint8_t
+    {
+        kTimeRequestInfo = 0,
+        kTimeFormatter,
+        kTimePreprocess,
+        kTimeTransmissions,
+        kTimePostprocess,
+        kTimeCounts
+    };
+
+    struct KVCacheTimes
+    {
+        std::array<LlmRequest::TimePoint, kTimeCounts> times;
+        std::vector<Measure> measures;
     };

     TransferSession(std::vector<Connection const*> connections, DataContext dataContext,
         executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState,
         runtime::BufferManager const& bufferManager, int32_t indexFromEnd, BlockKey const& lastBlockKey,
-        LlmRequest const* llmRequest = nullptr, bool recordMeasure = false)
+        LlmRequest const* llmRequest = nullptr, bool recordTiming = false)
         : mConnections(std::move(connections))
         , mDataContext(std::move(dataContext))
         , mSelfState(&selfState)
         , mOtherState(std::move(otherState))
         , mBufferManager(&bufferManager)
         , mRequest(llmRequest)
-        , mMeasures()
-        , mRecordMeasure(recordMeasure)
         , mIndexFromEnd(indexFromEnd)
         , mLastBlockKey(lastBlockKey)
     {
         TLLM_CHECK(!mConnections.empty());
+        if (recordTiming)
+        {
+            mTimes = std::make_unique<KVCacheTimes>();
+        }
     }

     [[nodiscard]] std::vector<Connection const*> const& getConnections() const;

@@ -103,7 +122,9 @@ public:
     // in CacheSender, the LlmRequest is not available until the sendSync is called
     void setLlmRequest(LlmRequest const& llmRequest);

-    void appendMeasure(double delay, double duration, size_t size);
+    void setTime(TimeNames name);
+
+    void appendMeasure(LlmRequest::TimePoint start, LlmRequest::TimePoint end, size_t size);

     // TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file
     void exportMeasure(std::ofstream& outFile, bool isContext) const;

@@ -125,8 +146,7 @@ private:
     executor::DataTransceiverState mOtherState;
     runtime::BufferManager const* mBufferManager;
     LlmRequest const* mRequest;
-    std::vector<Measure> mMeasures;
-    bool mRecordMeasure{false};
+    std::unique_ptr<KVCacheTimes> mTimes;
     int32_t mIndexFromEnd{0};
     BlockKey mLastBlockKey{};
 };
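A minimal sketch of how a formatter is expected to drive this API, distilled from the call sites in this diff; the helper name and the doSend callback are illustrative, not part of the codebase, and every call is a no-op when the session was constructed with recordTiming = false:

#include <cstddef>
#include <functional>

// Hypothetical helper showing the checkpoint/measure protocol.
void sendWithTiming(TransferSession& session, std::function<void()> const& doSend, std::size_t bytes)
{
    session.setTime(TransferSession::kTimeFormatter);      // entered format()/unformat()
    // ... split and stage cache blocks, synchronize the stream ...
    session.setTime(TransferSession::kTimePreprocess);     // staging finished

    auto start = LlmRequest::getSteadyClockNow();
    doSend();                                              // one transmission over one connection
    auto end = LlmRequest::getSteadyClockNow();
    session.appendMeasure(start, end, bytes);              // per-transmission record

    session.setTime(TransferSession::kTimeTransmissions);  // all transmissions joined
    // ... free the staging buffers ...
    session.setTime(TransferSession::kTimePostprocess);    // cleanup finished
}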

View File

@@ -122,6 +122,7 @@ bool MLACacheFormatter::needSendCache(
 void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
 {
     NVTX3_SCOPED_RANGE(MLACacheFormatter_format);
+    session.setTime(TransferSession::kTimeFormatter);
     auto const& llmRequest = session.getLlmRequest();
     TLLM_LOG_DEBUG(
         mpi::MpiComm::world().getRank(), "Start sending KV cache for request ID: %ld.", llmRequest.mRequestId);

@@ -141,9 +142,6 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
     // diff end

-    auto lastTokenTime = llmRequest.getPerfMetrics().timingMetrics.lastTokenTime;
-    bool recordDelay = lastTokenTime != std::chrono::steady_clock::time_point();
-
     int blockNum = 0;
     std::vector<runtime::ITensor::SharedPtr> inputKvCacheBlocks;
     auto const numPools = mCacheManager->getBlockManager().getNumPools();

@@ -235,6 +233,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
             inputKvCacheBlocksPerWindow, outputSplitCaches, destConfig, selfConfig, selfIdx, bufferManager);

         bufferManager.getStream().synchronize();
+        session.setTime(TransferSession::kTimePreprocess);

         auto preAllocSendBuffer = mCacheTransBufferManager->getSendBuffer(cacheBufferId);
         if (preAllocSendBuffer != nullptr)

@@ -246,7 +245,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
         NVTX3_SCOPED_RANGE(sendBufferFun);

         TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
-        auto startTime = llmRequest.getSteadyClockNow();
+        auto startTime = LlmRequest::getSteadyClockNow();
         auto cacheIdx = processIdx % (pPDomainSize * cPDomainSize);
         if (cacheIdx < bufferCoverTargetNum)
         {

@@ -279,15 +278,8 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
                 remainSendSize -= sendSize;
             }
         }
-        auto endTime = llmRequest.getSteadyClockNow();
-        double delay = 0.0;
-        if (recordDelay)
-        {
-            delay = std::chrono::duration<double, std::milli>(startTime - lastTokenTime).count();
-        }
-        double cacheTransferTime
-            = std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
-        session.appendMeasure(delay, cacheTransferTime, outputSplitCaches.at(cacheIdx)->getSizeInBytes());
+        auto endTime = LlmRequest::getSteadyClockNow();
+        session.appendMeasure(startTime, endTime, outputSplitCaches.at(cacheIdx)->getSizeInBytes());
     };

     if (connections.size() > 1)

@@ -331,7 +323,9 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
     {
         sendBufferFun(deviceId, 0);
     }
+    session.setTime(TransferSession::kTimeTransmissions);

     mCacheTransBufferManager->freeBufferIndexForSend(cacheBufferId);
+    session.setTime(TransferSession::kTimePostprocess);

     TLLM_LOG_DEBUG(
         mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID: %ld.", llmRequest.mRequestId);

@@ -340,6 +334,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
 void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
 {
     NVTX3_SCOPED_RANGE(MLACacheFormatter_unformat);
+    session.setTime(TransferSession::kTimeFormatter);
    auto const& llmRequest = session.getLlmRequest();
     TLLM_CHECK_WITH_INFO(llmRequest.mSamplingConfig.beamWidth == 1, "Currently only supports beam width 1.");
     auto const ctxReqId = llmRequest.getContextPhaseParams().value().getReqId();

@@ -350,8 +345,6 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
     auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx();
     auto const& connections = session.getConnections();
     auto& bufferManager = session.getBufferManager();
-    auto arrivalTime = llmRequest.getPerfMetrics().timingMetrics.arrivalTime;
-    bool recordDelay = arrivalTime != std::chrono::steady_clock::time_point();
     auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig);
     auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse());
     std::vector<runtime::ITensor::SharedPtr> recvBufferTmps;

@@ -445,6 +438,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
         TLLM_CHECK(onlyUseDynamicBuffer == false);
     }
     bufferManager.getStream().synchronize();
+    session.setTime(TransferSession::kTimePreprocess);

     auto preAllocRecvBuffer = mCacheTransBufferManager->getRecvBuffer(cacheBufferId);
     if (preAllocRecvBuffer != nullptr)

@@ -456,7 +450,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
     {
         NVTX3_SCOPED_RANGE(recvBufferFun);
         TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
-        auto startTime = llmRequest.getSteadyClockNow();
+        auto startTime = LlmRequest::getSteadyClockNow();
         size_t size = 0;
         if (processIdx >= remainNoCoverTargetNum)
         {

@@ -489,15 +483,8 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
                 remainRecvSize -= recvSize;
             }
         }
-        auto endTime = llmRequest.getSteadyClockNow();
-        double delay = 0.0;
-        if (recordDelay)
-        {
-            delay = std::chrono::duration<double, std::milli>(startTime - arrivalTime).count();
-        }
-        double cacheTransferTime
-            = std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
-        session.appendMeasure(delay, cacheTransferTime, size);
+        auto endTime = LlmRequest::getSteadyClockNow();
+        session.appendMeasure(startTime, endTime, size);
     };

     if (pickUpConnections.size() > 1)

@@ -546,6 +533,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
     {
         recvBufferFun(deviceId, 0);
     }
+    session.setTime(TransferSession::kTimeTransmissions);

     {
         std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> outputCachesPerWindow;

@@ -564,6 +552,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
     {
         mCacheTransBufferManager->freeBufferIndexForRecv(cacheBufferId);
     }
+    session.setTime(TransferSession::kTimePostprocess);

     TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
         "End receiving KV cache for request ID: %ld, context request ID: %ld.", llmRequest.mRequestId,

View File

@@ -380,7 +380,7 @@ size_t getEnvAllReduceWorkspaceSize()
     return workspaceSize;
 }

-std::string const& getEnvKVCacheTransferOutputPath()
+std::string const& getEnvKVCacheTimeOutputPath()
 {
     static std::string outputPath = getStrEnv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH").value_or("");
     return outputPath;
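Note that only the C++ getter is renamed so it matches the TRTLLM_KVCACHE_TIME_OUTPUT_PATH variable it reads; the environment variable name itself is unchanged.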

View File

@@ -96,7 +96,7 @@ bool getEnvDisableKVCacheTransferOverlap();
 bool getEnvEnableReceiveKVCacheParallel();

-std::string const& getEnvKVCacheTransferOutputPath();
+std::string const& getEnvKVCacheTimeOutputPath();

 bool getEnvTryZCopyForKVCacheTransfer();

View File

@@ -383,7 +383,7 @@ void initBindings(nb::module_& m)
         .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
         .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter"))
         .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors)
-        .def_rw_static("global_steady_clock_offset", &tb::LlmRequest::mGlobalSteadyClockOffset);
+        .def_rw_static("global_steady_clock_offset", &tb::LlmRequest::sGlobalSteadyClockOffset);

     nb::class_<tb::SequenceSlotManager>(m, "SequenceSlotManager")
         .def(nb::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), nb::arg("max_num_slots"),

View File

@@ -389,7 +389,7 @@ void initBindings(pybind11::module_& m)
         .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
         .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter"))
         .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors)
-        .def_readwrite_static("global_steady_clock_offset", &tb::LlmRequest::mGlobalSteadyClockOffset);
+        .def_readwrite_static("global_steady_clock_offset", &tb::LlmRequest::sGlobalSteadyClockOffset);

     py::classh<tb::SequenceSlotManager>(m, "SequenceSlotManager")
         .def(py::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), py::arg("max_num_slots"),

View File

@@ -853,10 +853,12 @@ def test_disaggregated_kv_cache_time_output(disaggregated_test_root, llm_venv,
         lines = f.readlines()
         assert len(lines) > 1
         assert lines[0].startswith(
-            "RequestID,Delay(ms),Duration(ms),Bandwidth(Gbps)")
+            "RequestID,RequestInfo,Preparation,Preprocess,Transmissions,Postprocess"
+        )
+        assert ",Delay,Duration,Bandwidth(Gbps)" in lines[0]
         # get a send sample and match the recv
         sample = lines[1].split(',')
-        assert len(sample) >= 4
+        assert len(sample) >= 9
     with open(recv_file, "r") as f:
         lines = f.readlines()
         assert len(lines) > 1
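The >= 9 bound follows from the new layout: 1 RequestID field plus 5 phase fields plus at least one Delay/Duration/Bandwidth triplet (1 + 5 + 3 = 9).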