From 338744fba6a91147b739b7f02d19b37bc19aa17a Mon Sep 17 00:00:00 2001
From: dongxuy04 <78518666+dongxuy04@users.noreply.github.com>
Date: Fri, 23 May 2025 09:24:23 +0800
Subject: [PATCH] fix[nvbug-5295425]: [TRTLLM-5385] fix race condition in MoeLoadBalancer (#4573)

fix moe possible race cond and add bypass worker thread for no updates

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
---
Reviewer notes (not part of the commit message; ignored by git-am):
  * finalizeModel(): generateUpdatePlan() and startThreads() now run only when
    mLayerUpdatesPerIter > 0; otherwise mWorkerThreadStopped is set to true
    and no thread is started.
  * startIter(): the per-iteration flags are handed over as an IterInfo entry
    pushed onto mIterInfoQueue instead of being written to the removed
    mStatisticEnabled / mUpdateWeightsEnabled members.
  * workerThread(): waits until the queue is non-empty or the stop flag is
    set, leaves the loop only when the queue is empty AND the stop flag is
    set, and checks that each popped iterId equals the previous one plus 1.
  * TODO(review): startIter() still pushes an IterInfo when no worker was
    started (mLayerUpdatesPerIter == 0), so nothing visible in this patch
    drains the queue in that mode. Confirm, or skip the push there.
  * TODO(review): the header now uses std::queue but this patch adds no
    #include <queue>. Confirm it is already included.
  * TODO(review): the startIter() hunk does not show a lock around
    mIterInfoQueue.push(). Confirm mWorkerThreadMutex is held at that point.

 cpp/tensorrt_llm/runtime/moeLoadBalancer.cpp | 34 +++++++++++++-------
 cpp/tensorrt_llm/runtime/moeLoadBalancer.h   | 13 ++++++--
 2 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/cpp/tensorrt_llm/runtime/moeLoadBalancer.cpp b/cpp/tensorrt_llm/runtime/moeLoadBalancer.cpp
index 338d06a46c..4f71815eaf 100644
--- a/cpp/tensorrt_llm/runtime/moeLoadBalancer.cpp
+++ b/cpp/tensorrt_llm/runtime/moeLoadBalancer.cpp
@@ -728,8 +728,15 @@ void MoeLoadBalancer::finalizeModel()
     {
         layer->finalizeModel();
     }
-    generateUpdatePlan();
-    startThreads();
+    if (mLayerUpdatesPerIter > 0)
+    {
+        generateUpdatePlan();
+        startThreads();
+    }
+    else
+    {
+        mWorkerThreadStopped = true;
+    }
     mModelFinalized = true;
 }
 
@@ -751,10 +758,12 @@ void MoeLoadBalancer::startIter(int64_t iterId, bool enableStatistic, bool enabl
     TLLM_CHECK_WITH_INFO(mIterId + 1 == iterId, "Expected iterId=%ld, but got %ld", mIterId + 1, iterId);
 
     mIterId = iterId;
-    mStatisticEnabled = enableStatistic;
     // disable update for warm up iters.
     bool isWarmUpIter = mIterId <= mWarmUpUntilIter;
-    mUpdateWeightsEnabled = enableUpdateWeights && !isWarmUpIter;
+    bool fixedUpdateWeightsEnabled = enableUpdateWeights && !isWarmUpIter;
+
+    IterInfo iterInfo{iterId, enableStatistic, fixedUpdateWeightsEnabled};
+    mIterInfoQueue.push(iterInfo);
     mWorkerThreadCondition.notify_one();
 }
 
@@ -780,21 +789,24 @@ void MoeLoadBalancer::shutdown()
 void MoeLoadBalancer::workerThread()
 {
     TLLM_CUDA_CHECK(cudaSetDevice(mCudaDeviceId));
+    int64_t iterId = -1;
     while (true)
     {
-        int64_t iterId;
         bool iterUpdateWeightsEnabled, iterStatisticEnabled;
         {
             std::unique_lock<std::mutex> lock(mWorkerThreadMutex);
-            mWorkerThreadCondition.wait(lock, [this] { return mWaitIterId == mIterId || mWorkerThreadStopped; });
-            iterId = mIterId;
-            if (mWorkerThreadStopped)
+            mWorkerThreadCondition.wait(lock, [this] { return !mIterInfoQueue.empty() || mWorkerThreadStopped; });
+            if (mIterInfoQueue.empty() && mWorkerThreadStopped)
             {
                 break;
             }
-            mWaitIterId = mIterId + 1;
-            iterUpdateWeightsEnabled = mUpdateWeightsEnabled;
-            iterStatisticEnabled = mStatisticEnabled;
+            auto iterInfo = mIterInfoQueue.front();
+            mIterInfoQueue.pop();
+            TLLM_CHECK_WITH_INFO(iterInfo.iterId == iterId + 1, "Jump detected, iterId=%ld, but got next iterId=%ld",
+                iterId, iterInfo.iterId);
+            iterId = iterInfo.iterId;
+            iterUpdateWeightsEnabled = iterInfo.updateWeightsEnabled;
+            iterStatisticEnabled = iterInfo.statisticEnabled;
         }
         for (int layerId = 0; static_cast<size_t>(layerId) < mLayers.size(); ++layerId)
         {
diff --git a/cpp/tensorrt_llm/runtime/moeLoadBalancer.h b/cpp/tensorrt_llm/runtime/moeLoadBalancer.h
index 5ce690ceb8..5eb8070421 100644
--- a/cpp/tensorrt_llm/runtime/moeLoadBalancer.h
+++ b/cpp/tensorrt_llm/runtime/moeLoadBalancer.h
@@ -238,7 +238,6 @@ private:
     std::mutex mWorkerThreadMutex;
     std::condition_variable mWorkerThreadCondition;
     bool mWorkerThreadStopped = false;
-    int64_t mWaitIterId = 0;
     int64_t mWarmUpUntilIter = -1;
 
     // we use a separate thread to compute and update weights to avoid possible blocking for next layer due to slow
@@ -252,8 +251,16 @@ private:
     std::vector<std::shared_ptr<SingleLayerMoeLoadBalancer>> mLayers;
 
     int64_t mIterId = -1;
-    bool mStatisticEnabled = true;
-    bool mUpdateWeightsEnabled = true;
+
+    struct IterInfo
+    {
+        int64_t iterId = -1;
+        bool statisticEnabled = true;
+        bool updateWeightsEnabled = true;
+    };
+
+    std::queue<IterInfo> mIterInfoQueue;
+
     bool mModelFinalized = false;
 
     int mEpRank = 0;