Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[None][feat] add detailed KV cache transfer time breakdown (#8521)
Signed-off-by: zhengd-nv <200704041+zhengd-nv@users.noreply.github.com>
This commit is contained in: parent f444fe2deb, commit fea5bfbda7
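Summary of the change, as reflected in the hunks below: the single per-transfer delay/duration/bandwidth measurement is replaced with a detailed breakdown of KV cache transfer time for disaggregated serving. TransferSession now stamps a steady-clock timestamp at each stage of a transfer (request info, formatter entry, preprocessing, transmissions, postprocessing) and records a raw {start, end, size} Measure per transmission; delays, durations, and bandwidth are derived only when exportMeasure writes the CSV. To keep timestamps comparable across ranks, getSteadyClockNow and maybeToGlobalSteadyClock become static helpers on LlmRequest using the renamed sGlobalSteadyClockOffset, and mPerfMetrics becomes mutable so const code paths can record transfer start/end times. The environment getter is renamed from getEnvKVCacheTransferOutputPath to getEnvKVCacheTimeOutputPath to match the TRTLLM_KVCACHE_TIME_OUTPUT_PATH variable it reads.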
@@ -1691,22 +1691,22 @@ public:
         mDecodingIter = iter;
     }

-    void setKvCacheTransferStart(TimePoint const& time)
+    void setKvCacheTransferStart(TimePoint time) const
     {
         mPerfMetrics.timingMetrics.kvCacheTransferStart = maybeToGlobalSteadyClock(time);
     }

-    void setKvCacheTransferEnd(TimePoint const& time)
+    void setKvCacheTransferEnd(TimePoint time) const
     {
         mPerfMetrics.timingMetrics.kvCacheTransferEnd = maybeToGlobalSteadyClock(time);
     }

-    TimePoint getKvCacheTransferStart()
+    TimePoint getKvCacheTransferStart() const
     {
         return mPerfMetrics.timingMetrics.kvCacheTransferStart;
     }

-    TimePoint getKvCacheTransferEnd()
+    TimePoint getKvCacheTransferEnd() const
     {
         return mPerfMetrics.timingMetrics.kvCacheTransferEnd;
     }

@@ -1865,13 +1865,11 @@ public:
         return mUseDraftModel;
     }

-    // If mGlobalSteadyClockOffset is set, return a global steady clock time point, otherwise return local steady clock
+    // If sGlobalSteadyClockOffset is set, return a global steady clock time point, otherwise return local steady clock
     // time point
-    [[nodiscard]] TimePoint getSteadyClockNow() const
+    [[nodiscard]] static TimePoint getSteadyClockNow()
     {
-        const TimePoint time_point = std::chrono::steady_clock::now();
-
-        return maybeToGlobalSteadyClock(time_point);
+        return maybeToGlobalSteadyClock(std::chrono::steady_clock::now());
     }

     RequestIdType mRequestId;

@@ -1894,7 +1892,7 @@ public:
     SizeType32 mPtableCurrentPosition{0};

     // The offset between local steady clock and global steady clock (at rank 0)
-    inline static std::optional<Duration> mGlobalSteadyClockOffset{std::nullopt};
+    inline static std::optional<Duration> sGlobalSteadyClockOffset{std::nullopt};

 protected:
     bool mIsStreaming;

@@ -2028,9 +2026,9 @@ protected:

     std::optional<TensorPtr> mSkipCrossAttnBlocks{std::nullopt};

-    // Performance metrics.
+    // Performance metrics. Should be updatable even from a const LlmRequest reference.
     bool mReturnPerfMetrics{false};
-    executor::RequestPerfMetrics mPerfMetrics;
+    mutable executor::RequestPerfMetrics mPerfMetrics;

     // Guided decoding params.
     std::optional<executor::GuidedDecodingParams> mGuidedDecodingParams{std::nullopt};
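Making the transfer-time setters const (with pass-by-value TimePoints) works because mPerfMetrics is now mutable: the cache formatters and senders further down only hold the LlmRequest by const reference, yet still need to record timing into the request's perf metrics.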
@@ -2183,16 +2181,13 @@ private:
         return tensor;
     }

-    TimePoint maybeToGlobalSteadyClock(TimePoint const& time_point) const
+    static TimePoint maybeToGlobalSteadyClock(TimePoint const& time_point)
     {
-        if (mGlobalSteadyClockOffset.has_value())
+        if (sGlobalSteadyClockOffset.has_value())
         {
-            return time_point + *mGlobalSteadyClockOffset;
+            return time_point + *sGlobalSteadyClockOffset;
         }
-        else
-        {
-            return time_point;
-        }
+        return time_point;
     }
 };
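For readers outside the diff context: sGlobalSteadyClockOffset is the per-process shift that maps a local std::chrono::steady_clock reading onto rank 0's timeline, so timestamps taken on different ranks can be subtracted meaningfully (the bindings below expose it to Python as global_steady_clock_offset). A minimal self-contained sketch of that conversion, where the offset value and how it is measured are assumptions (e.g. exchanged at startup) and only the arithmetic mirrors the helper above:

#include <chrono>
#include <iostream>
#include <optional>

using TimePoint = std::chrono::steady_clock::time_point;
using Duration = std::chrono::steady_clock::duration;

// Assumed to be measured elsewhere, e.g. rank0_now - local_now exchanged at startup.
static std::optional<Duration> sGlobalSteadyClockOffset;

// Mirrors LlmRequest::maybeToGlobalSteadyClock: shift onto rank 0's timeline
// when an offset is known, otherwise stay on the local clock.
static TimePoint maybeToGlobalSteadyClock(TimePoint const& timePoint)
{
    return sGlobalSteadyClockOffset ? timePoint + *sGlobalSteadyClockOffset : timePoint;
}

int main()
{
    sGlobalSteadyClockOffset = std::chrono::milliseconds(250); // hypothetical skew
    auto local = std::chrono::steady_clock::now();
    auto global = maybeToGlobalSteadyClock(local);
    std::cout << std::chrono::duration<double, std::milli>(global - local).count() << " ms shift\n";
    return 0;
}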
@@ -451,7 +451,7 @@ struct RequestPerfMetrics
     /// @brief End time of the KV cache transfer for disaggregated serving
     TimePoint kvCacheTransferEnd;
     /// @brief KV Cache size transfer for disaggregated serving
-    mutable size_t kvCacheSize = 0;
+    size_t kvCacheSize = 0;
 };

 struct KvCacheMetrics
@@ -227,6 +227,7 @@ std::vector<size_t> CacheFormatter::pickRecvConnections(
 void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
 {
     NVTX3_SCOPED_RANGE(CacheFormatter_format);
+    session.setTime(TransferSession::kTimeFormatter);
     auto const& llmRequest = session.getLlmRequest();
     TLLM_LOG_DEBUG(
         mpi::MpiComm::world().getRank(), "Start sending KV cache for request ID: %ld.", llmRequest.mRequestId);

@@ -249,9 +250,6 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
     auto const numPools = blockManager.getNumPools();
     // TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1...

-    auto lastTokenTime = llmRequest.getPerfMetrics().timingMetrics.lastTokenTime;
-    bool recordDelay = lastTokenTime != std::chrono::steady_clock::time_point();
-
     bool layerWise = common::getEnvDisaggLayerwise() && numPools == 1;
     if (layerWise)
     {

@@ -420,6 +418,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
         inputKvCacheBlocksPerWindow, outputSplitCaches, destConfig, selfConfig, selfIdx, bufferManager);

     bufferManager.getStream().synchronize();
+    session.setTime(TransferSession::kTimePreprocess);

     auto preAllocSendBuffer = mCacheTransBufferManager->getSendBuffer(cacheBufferId);
     if (preAllocSendBuffer != nullptr)

@@ -434,7 +433,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
         TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
         TLLM_CHECK(connections.size() > (processIdx / peerDuplicateHeadFactor));
         TLLM_CHECK(outputSplitCaches.size() > (processIdx / peerDuplicateHeadFactor));
-        auto startTime = llmRequest.getSteadyClockNow();
+        auto startTime = LlmRequest::getSteadyClockNow();

         size_t ppDomainSize = targetInfo.mDomainPPSize;
         size_t bufferTpRank = (processIdx / ppDomainSize) / peerDuplicateHeadFactor;

@@ -481,15 +480,8 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
             }
         }

-        auto endTime = llmRequest.getSteadyClockNow();
-        double delay = 0.0;
-        if (recordDelay)
-        {
-            delay = std::chrono::duration<double, std::milli>(startTime - lastTokenTime).count();
-        }
-        double cacheTransferTime
-            = std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
-        session.appendMeasure(delay, cacheTransferTime, size);
+        auto endTime = LlmRequest::getSteadyClockNow();
+        session.appendMeasure(startTime, endTime, size);
     };

     if (connections.size() > 1)

@@ -534,8 +526,10 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
         {
             sendBufferFun(deviceId, 0);
         }
+        session.setTime(TransferSession::kTimeTransmissions);

         mCacheTransBufferManager->freeBufferIndexForSend(cacheBufferId);
+        session.setTime(TransferSession::kTimePostprocess);
     }
     TLLM_LOG_DEBUG(
         mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID:%ld ", llmRequest.mRequestId);

@@ -544,6 +538,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSessio
 void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
 {
     NVTX3_SCOPED_RANGE(CacheFormatter_unformat);
+    session.setTime(TransferSession::kTimeFormatter);
     auto const& llmRequest = session.getLlmRequest();
     auto const ctxReqId = llmRequest.getContextPhaseParams().value().getReqId();
     TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),

@@ -555,9 +550,6 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
     auto& bufferManager = session.getBufferManager();
     auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse());

-    auto arrivalTime = llmRequest.getPerfMetrics().timingMetrics.arrivalTime;
-    bool recordDelay = arrivalTime != std::chrono::steady_clock::time_point();
-
     auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig);

     TLLM_LOG_DEBUG("pickUpConnections size: %d connections size: %d", pickUpConnections.size(), connections.size());

@@ -779,6 +771,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
         // sync to alloc buffer
         bufferManager.getStream().synchronize();
     }
+    session.setTime(TransferSession::kTimePreprocess);

     runtime::ITensor::SharedPtr preAllocRecvBuffer = nullptr;
     if (cacheBufferId.has_value())

@@ -794,7 +787,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
         TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
         TLLM_CHECK(pickUpConnections.size() > processIdx);
         TLLM_CHECK(recvSplitCaches.size() > processIdx);
-        auto startTime = llmRequest.getSteadyClockNow();
+        auto startTime = LlmRequest::getSteadyClockNow();
         size_t size = 0;

         if (processIdx >= remainNoCoverTargetNum)

@@ -835,15 +828,8 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
             }
         }

-        auto endTime = llmRequest.getSteadyClockNow();
-        double delay = 0.0;
-        if (recordDelay)
-        {
-            delay = std::chrono::duration<double, std::milli>(startTime - arrivalTime).count();
-        }
-        double cacheTransferTime
-            = std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
-        session.appendMeasure(delay, cacheTransferTime, size);
+        auto endTime = LlmRequest::getSteadyClockNow();
+        session.appendMeasure(startTime, endTime, size);
     };
     if (pickUpConnections.size() > 1)
     {

@@ -891,6 +877,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
         {
             recvBufferFun(deviceId, 0);
         }
+        session.setTime(TransferSession::kTimeTransmissions);

         {
             NVTX3_SCOPED_RANGE(formatInputConcatenate);

@@ -904,6 +891,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
                 mCacheTransBufferManager->freeBufferIndexForRecv(cacheBufferId);
             }
         }
+        session.setTime(TransferSession::kTimePostprocess);
     }
 }
@@ -603,7 +603,7 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
             it->first->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);

             // Gather the kv cache transfer time from all workers and update to leader rank
-            if (!common::getEnvKVCacheTransferOutputPath().empty())
+            if (!common::getEnvKVCacheTimeOutputPath().empty())
             {
                 auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
                 updateKVCacheTransferBW(syncComm, it->first);
@@ -28,6 +28,7 @@
 #include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
 #include "tensorrt_llm/runtime/common.h"
 #include "tensorrt_llm/runtime/utils/mpiUtils.h"
+#include <chrono>
 #include <future>
 #include <map>
 #include <memory>
@@ -105,39 +106,65 @@ void TransferSession::setLlmRequest(LlmRequest const& llmRequest)
     mRequest = &llmRequest;
 }

-void TransferSession::appendMeasure(double delay, double duration, size_t size)
+void TransferSession::setTime(TimeNames name)
 {
-    if (!mRecordMeasure)
+    if (mTimes)
     {
-        return;
+        mTimes->times.at(name) = LlmRequest::getSteadyClockNow();
+    }
+}
+
+void TransferSession::appendMeasure(LlmRequest::TimePoint start, LlmRequest::TimePoint end, size_t size)
+{
+    if (mTimes)
+    {
+        mTimes->measures.emplace_back(Measure{start, end, size});
     }
-    auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps
-    mMeasures.emplace_back(Measure{delay, duration, bandwidth});
 }

 void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const
 {
-    if (mMeasures.empty())
+    if (!mTimes || mTimes->measures.empty())
     {
         return;
     }
     // write header if not exist
     if (outFile.tellp() == 0)
     {
-        outFile << "RequestID";
-        for (size_t i = 0; i < mMeasures.size(); i++)
+        outFile << "RequestID,RequestInfo,Preparation,Preprocess,Transmissions,Postprocess";
+        for (size_t i = 0; i < mTimes->measures.size(); i++)
         {
-            outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
+            outFile << ",Delay,Duration,Bandwidth(Gbps)";
         }
         outFile << '\n';
     }
-    // write measures
+    auto transferStart = mRequest->getPerfMetrics().timingMetrics.kvCacheTransferStart;
+    using Milliseconds = std::chrono::duration<double, std::milli>;
+
+    // write measures, time is in milliseconds
     TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value());
     auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId();
     outFile << reqId;
-    for (auto const& measure : mMeasures)
+    auto previousTime = transferStart;
+    for (auto time : mTimes->times)
     {
-        outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
+        if (time == LlmRequest::TimePoint())
+        {
+            // timepoint is unset, skip
+            outFile << ",0.0";
+            continue;
+        }
+        double delay = Milliseconds(time - previousTime).count();
+        previousTime = time;
+        outFile << "," << delay;
+    }
+    previousTime = mTimes->times[kTimePreprocess];
+    for (auto const& measure : mTimes->measures)
+    {
+        double delay = Milliseconds(measure.start - previousTime).count();
+        double duration = Milliseconds(measure.end - measure.start).count();
+        double bandwidth = static_cast<double>(measure.size) * 8.0 / duration / 1e6; // byte, ms => Gbps
+        outFile << "," << delay << "," << duration << "," << bandwidth;
     }
     outFile << '\n' << std::flush;
 }
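Read against the rewritten exportMeasure above: every CSV row starts with the request ID and five stage columns, each holding the delay in milliseconds since the previous recorded timestamp (starting from kvCacheTransferStart, with 0.0 for a stage that was never stamped), followed by one Delay,Duration,Bandwidth(Gbps) triple per transmission, whose Delay is measured from the kTimePreprocess stamp. Note that the kTimeFormatter stamp is exported under the column name Preparation. The resulting layout, with the triple repeated once per connection:

RequestID,RequestInfo,Preparation,Preprocess,Transmissions,Postprocess,Delay,Duration,Bandwidth(Gbps)[,Delay,Duration,Bandwidth(Gbps),...]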
@@ -158,7 +185,7 @@ int32_t tagFromRequestId(LlmRequest::RequestIdType requestId)
 std::filesystem::path getTransferOutputPath(char const* tag)
 {
     namespace fs = std::filesystem;
-    auto outputPath = common::getEnvKVCacheTransferOutputPath();
+    auto outputPath = common::getEnvKVCacheTimeOutputPath();
     if (!outputPath.empty())
     {
         auto rank = mpi::MpiComm::world().getRank();

@@ -273,6 +300,7 @@ public:
     {
         std::promise<void> promise;
         auto future = promise.get_future();
+        llmRequest.setKvCacheTransferStart(LlmRequest::getSteadyClockNow());
         {
             {
                 std::scoped_lock lkResp(mSenderMutex);

@@ -309,7 +337,7 @@ public:
         std::unique_lock<std::mutex> lk(mMtxForMap);
         auto it = mRequestToSession.find(requestId);
         TLLM_CHECK(it != mRequestToSession.end());
-        if (!common::getEnvKVCacheTransferOutputPath().empty())
+        if (!common::getEnvKVCacheTimeOutputPath().empty())
         {
             if (!mMeasuresFile.is_open())
             {

@@ -363,7 +391,8 @@ public:
             auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
                 DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager,
                 info.getIndexFromEnd(), info.getLastBlockKey(), nullptr,
-                !common::getEnvKVCacheTransferOutputPath().empty());
+                !common::getEnvKVCacheTimeOutputPath().empty());
+            session.setTime(TransferSession::kTimeRequestInfo);
             it = mRequestToSession.emplace(requestId, std::move(session)).first;
         }
         it->second.setConnection(peerIdx, connection);

@@ -382,6 +411,7 @@ public:
         }
         session->setLlmRequest(llmRequest);
         mFormatter->format(*session);
+        llmRequest.setKvCacheTransferEnd(LlmRequest::getSteadyClockNow());
     }

     bool cancelRequest(LlmRequest const& llmRequest)

@@ -751,7 +781,7 @@ public:
     void receiveSync(TransferSession& session)
     {
         mFormatter->unformat(session);
-        if (!common::getEnvKVCacheTransferOutputPath().empty())
+        if (!common::getEnvKVCacheTimeOutputPath().empty())
         {
             std::unique_lock<std::mutex> lock(mMeasuresFileMutex);
             if (!mMeasuresFile.is_open())

@@ -846,7 +876,7 @@ public:
         auto const& resource = getReceiveCacheResource(llmRequest);
         return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
             contextState, resource->mBufferManager, requestInfo.getIndexFromEnd(), requestInfo.getLastBlockKey(),
-            &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty());
+            &llmRequest, !common::getEnvKVCacheTimeOutputPath().empty());
     }

     std::unique_ptr<ReceiveCacheResource> const& getReceiveCacheResource(LlmRequest const& llmRequest)

@@ -957,6 +987,7 @@ private:
         llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now());
         TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
         auto session = sendRequestInfo(llmRequest);
+        session.setTime(TransferSession::kTimeRequestInfo);
         bool isReady = receiveReadySignal(session);
         if (!isReady)
         {
@@ -56,29 +56,48 @@ using UniqueToken = tensorrt_llm::runtime::UniqueToken;
 class TransferSession
 {
 public:
+    // measures for each single transmission
     struct Measure
     {
-        double delay;     // from last token (ctx) or arrival time (gen), in ms
-        double duration;  // in ms
-        double bandwidth; // in Gbps
+        LlmRequest::TimePoint start;
+        LlmRequest::TimePoint end;
+        size_t size = 0;
+    };
+
+    enum TimeNames : uint8_t
+    {
+        kTimeRequestInfo = 0,
+        kTimeFormatter,
+        kTimePreprocess,
+        kTimeTransmissions,
+        kTimePostprocess,
+        kTimeCounts
+    };
+
+    struct KVCacheTimes
+    {
+        std::array<LlmRequest::TimePoint, kTimeCounts> times;
+        std::vector<Measure> measures;
     };

     TransferSession(std::vector<Connection const*> connections, DataContext dataContext,
         executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState,
         runtime::BufferManager const& bufferManager, int32_t indexFromEnd, BlockKey const& lastBlockKey,
-        LlmRequest const* llmRequest = nullptr, bool recordMeasure = false)
+        LlmRequest const* llmRequest = nullptr, bool recordTiming = false)
         : mConnections(std::move(connections))
         , mDataContext(std::move(dataContext))
         , mSelfState(&selfState)
         , mOtherState(std::move(otherState))
         , mBufferManager(&bufferManager)
         , mRequest(llmRequest)
-        , mMeasures()
-        , mRecordMeasure(recordMeasure)
         , mIndexFromEnd(indexFromEnd)
         , mLastBlockKey(lastBlockKey)
     {
         TLLM_CHECK(!mConnections.empty());
+        if (recordTiming)
+        {
+            mTimes = std::make_unique<KVCacheTimes>();
+        }
     }

     [[nodiscard]] std::vector<Connection const*> const& getConnections() const;

@@ -103,7 +122,9 @@ public:
     // in CacheSender, the LlmRequest is not available until the sendSync is called
     void setLlmRequest(LlmRequest const& llmRequest);

-    void appendMeasure(double delay, double duration, size_t size);
+    void setTime(TimeNames name);
+
+    void appendMeasure(LlmRequest::TimePoint start, LlmRequest::TimePoint end, size_t size);

     // TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file
     void exportMeasure(std::ofstream& outFile, bool isContext) const;

@@ -125,8 +146,7 @@ private:
     executor::DataTransceiverState mOtherState;
     runtime::BufferManager const* mBufferManager;
     LlmRequest const* mRequest;
-    std::vector<Measure> mMeasures;
-    bool mRecordMeasure{false};
+    std::unique_ptr<KVCacheTimes> mTimes;
     int32_t mIndexFromEnd{0};
     BlockKey mLastBlockKey{};
 };
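To make the new bookkeeping concrete, here is a standalone, simplified sketch (not the TensorRT-LLM classes, though the names mirror the diff) of what a recording TransferSession accumulates: one steady-clock stamp per stage, one raw {start, end, size} measure per transmission, and bandwidth derived only at export time.

#include <array>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>

using TimePoint = std::chrono::steady_clock::time_point;

enum TimeNames : uint8_t
{
    kTimeRequestInfo = 0,
    kTimeFormatter,
    kTimePreprocess,
    kTimeTransmissions,
    kTimePostprocess,
    kTimeCounts
};

struct Measure
{
    TimePoint start;
    TimePoint end;
    size_t size = 0; // bytes moved by this transmission
};

struct KVCacheTimes
{
    std::array<TimePoint, kTimeCounts> times; // one stamp per stage
    std::vector<Measure> measures;            // one entry per transmission
};

int main()
{
    using Milliseconds = std::chrono::duration<double, std::milli>;
    KVCacheTimes t;
    auto stamp = [&t](TimeNames n) { t.times.at(n) = std::chrono::steady_clock::now(); };

    stamp(kTimeRequestInfo); // handshake done
    stamp(kTimeFormatter);   // formatter entered
    stamp(kTimePreprocess);  // buffers split and synchronized

    // One simulated transmission of 1 MiB; the sleep stands in for the actual send.
    Measure m{std::chrono::steady_clock::now(), {}, size_t{1} << 20};
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
    m.end = std::chrono::steady_clock::now();
    t.measures.push_back(m);

    stamp(kTimeTransmissions); // all sends joined
    stamp(kTimePostprocess);   // buffers freed

    for (auto const& measure : t.measures)
    {
        double duration = Milliseconds(measure.end - measure.start).count();
        // bytes over milliseconds => gigabits per second, same formula as exportMeasure
        double bandwidth = static_cast<double>(measure.size) * 8.0 / duration / 1e6;
        std::cout << "duration(ms)=" << duration << " bandwidth(Gbps)=" << bandwidth << '\n';
    }
    return 0;
}

Keeping raw time points instead of precomputed delay/duration/bandwidth is what lets exportMeasure relate each transmission to the stage stamps after the fact, at the cost of one array-plus-vector allocation per recording session.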
@@ -122,6 +122,7 @@ bool MLACacheFormatter::needSendCache(
 void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
 {
     NVTX3_SCOPED_RANGE(MLACacheFormatter_format);
+    session.setTime(TransferSession::kTimeFormatter);
     auto const& llmRequest = session.getLlmRequest();
     TLLM_LOG_DEBUG(
         mpi::MpiComm::world().getRank(), "Start sending KV cache for request ID: %ld.", llmRequest.mRequestId);

@@ -141,9 +142,6 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses

     // diff end

-    auto lastTokenTime = llmRequest.getPerfMetrics().timingMetrics.lastTokenTime;
-    bool recordDelay = lastTokenTime != std::chrono::steady_clock::time_point();
-
     int blockNum = 0;
     std::vector<runtime::ITensor::SharedPtr> inputKvCacheBlocks;
     auto const numPools = mCacheManager->getBlockManager().getNumPools();

@@ -235,6 +233,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
         inputKvCacheBlocksPerWindow, outputSplitCaches, destConfig, selfConfig, selfIdx, bufferManager);

     bufferManager.getStream().synchronize();
+    session.setTime(TransferSession::kTimePreprocess);

     auto preAllocSendBuffer = mCacheTransBufferManager->getSendBuffer(cacheBufferId);
     if (preAllocSendBuffer != nullptr)

@@ -246,7 +245,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
         NVTX3_SCOPED_RANGE(sendBufferFun);

         TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
-        auto startTime = llmRequest.getSteadyClockNow();
+        auto startTime = LlmRequest::getSteadyClockNow();
         auto cacheIdx = processIdx % (pPDomainSize * cPDomainSize);
         if (cacheIdx < bufferCoverTargetNum)
         {

@@ -279,15 +278,8 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
                 remainSendSize -= sendSize;
             }
         }
-        auto endTime = llmRequest.getSteadyClockNow();
-        double delay = 0.0;
-        if (recordDelay)
-        {
-            delay = std::chrono::duration<double, std::milli>(startTime - lastTokenTime).count();
-        }
-        double cacheTransferTime
-            = std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
-        session.appendMeasure(delay, cacheTransferTime, outputSplitCaches.at(cacheIdx)->getSizeInBytes());
+        auto endTime = LlmRequest::getSteadyClockNow();
+        session.appendMeasure(startTime, endTime, outputSplitCaches.at(cacheIdx)->getSizeInBytes());
     };

     if (connections.size() > 1)

@@ -331,7 +323,9 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
         {
             sendBufferFun(deviceId, 0);
         }
+        session.setTime(TransferSession::kTimeTransmissions);
         mCacheTransBufferManager->freeBufferIndexForSend(cacheBufferId);
+        session.setTime(TransferSession::kTimePostprocess);

     TLLM_LOG_DEBUG(
         mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID: %ld.", llmRequest.mRequestId);

@@ -340,6 +334,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
 void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
 {
     NVTX3_SCOPED_RANGE(MLACacheFormatter_unformat);
+    session.setTime(TransferSession::kTimeFormatter);
     auto const& llmRequest = session.getLlmRequest();
     TLLM_CHECK_WITH_INFO(llmRequest.mSamplingConfig.beamWidth == 1, "Currently only supports beam width 1.");
     auto const ctxReqId = llmRequest.getContextPhaseParams().value().getReqId();

@@ -350,8 +345,6 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
     auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx();
     auto const& connections = session.getConnections();
     auto& bufferManager = session.getBufferManager();
-    auto arrivalTime = llmRequest.getPerfMetrics().timingMetrics.arrivalTime;
-    bool recordDelay = arrivalTime != std::chrono::steady_clock::time_point();
     auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig);
     auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse());
     std::vector<runtime::ITensor::SharedPtr> recvBufferTmps;

@@ -445,6 +438,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
         TLLM_CHECK(onlyUseDynamicBuffer == false);
     }
     bufferManager.getStream().synchronize();
+    session.setTime(TransferSession::kTimePreprocess);

     auto preAllocRecvBuffer = mCacheTransBufferManager->getRecvBuffer(cacheBufferId);
     if (preAllocRecvBuffer != nullptr)

@@ -456,7 +450,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
     {
         NVTX3_SCOPED_RANGE(recvBufferFun);
         TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
-        auto startTime = llmRequest.getSteadyClockNow();
+        auto startTime = LlmRequest::getSteadyClockNow();
         size_t size = 0;
         if (processIdx >= remainNoCoverTargetNum)
         {

@@ -489,15 +483,8 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
                 remainRecvSize -= recvSize;
             }
         }
-        auto endTime = llmRequest.getSteadyClockNow();
-        double delay = 0.0;
-        if (recordDelay)
-        {
-            delay = std::chrono::duration<double, std::milli>(startTime - arrivalTime).count();
-        }
-        double cacheTransferTime
-            = std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
-        session.appendMeasure(delay, cacheTransferTime, size);
+        auto endTime = LlmRequest::getSteadyClockNow();
+        session.appendMeasure(startTime, endTime, size);
     };

     if (pickUpConnections.size() > 1)

@@ -546,6 +533,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
         {
             recvBufferFun(deviceId, 0);
         }
+        session.setTime(TransferSession::kTimeTransmissions);

         {
             std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> outputCachesPerWindow;

@@ -564,6 +552,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
         {
             mCacheTransBufferManager->freeBufferIndexForRecv(cacheBufferId);
         }
+        session.setTime(TransferSession::kTimePostprocess);

         TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
             "End receiving KV cache for request ID: %ld, context request ID: %ld.", llmRequest.mRequestId,
@@ -380,7 +380,7 @@ size_t getEnvAllReduceWorkspaceSize()
     return workspaceSize;
 }

-std::string const& getEnvKVCacheTransferOutputPath()
+std::string const& getEnvKVCacheTimeOutputPath()
 {
     static std::string outputPath = getStrEnv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH").value_or("");
     return outputPath;
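The getter rename tracks the variable it already read: timing is recorded and exported only when TRTLLM_KVCACHE_TIME_OUTPUT_PATH is set to a non-empty output path, which every call site above checks via getEnvKVCacheTimeOutputPath().empty().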
@@ -96,7 +96,7 @@ bool getEnvDisableKVCacheTransferOverlap();

 bool getEnvEnableReceiveKVCacheParallel();

-std::string const& getEnvKVCacheTransferOutputPath();
+std::string const& getEnvKVCacheTimeOutputPath();

 bool getEnvTryZCopyForKVCacheTransfer();
@@ -383,7 +383,7 @@ void initBindings(nb::module_& m)
         .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
         .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter"))
         .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors)
-        .def_rw_static("global_steady_clock_offset", &tb::LlmRequest::mGlobalSteadyClockOffset);
+        .def_rw_static("global_steady_clock_offset", &tb::LlmRequest::sGlobalSteadyClockOffset);

     nb::class_<tb::SequenceSlotManager>(m, "SequenceSlotManager")
         .def(nb::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), nb::arg("max_num_slots"),

@@ -389,7 +389,7 @@ void initBindings(pybind11::module_& m)
         .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
         .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter"))
         .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors)
-        .def_readwrite_static("global_steady_clock_offset", &tb::LlmRequest::mGlobalSteadyClockOffset);
+        .def_readwrite_static("global_steady_clock_offset", &tb::LlmRequest::sGlobalSteadyClockOffset);

     py::classh<tb::SequenceSlotManager>(m, "SequenceSlotManager")
         .def(py::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), py::arg("max_num_slots"),
@@ -853,10 +853,12 @@ def test_disaggregated_kv_cache_time_output(disaggregated_test_root, llm_venv,
         lines = f.readlines()
         assert len(lines) > 1
         assert lines[0].startswith(
-            "RequestID,Delay(ms),Duration(ms),Bandwidth(Gbps)")
+            "RequestID,RequestInfo,Preparation,Preprocess,Transmissions,Postprocess"
+        )
+        assert ",Delay,Duration,Bandwidth(Gbps)" in lines[0]
         # get a send sample and match the recv
         sample = lines[1].split(',')
-        assert len(sample) >= 4
+        assert len(sample) >= 9
     with open(recv_file, "r") as f:
         lines = f.readlines()
         assert len(lines) > 1