fix[nvbug-5295425]: [TRTLLM-5385] fix race condition in MoeLoadBalancer (#4573)

fix moe possible race cond and add bypass worker thread for no updates

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
This commit is contained in:
dongxuy04 2025-05-23 09:24:23 +08:00 committed by GitHub
parent 1e55d616da
commit 338744fba6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 14 deletions

View File

@ -728,8 +728,15 @@ void MoeLoadBalancer::finalizeModel()
{
layer->finalizeModel();
}
generateUpdatePlan();
startThreads();
if (mLayerUpdatesPerIter > 0)
{
generateUpdatePlan();
startThreads();
}
else
{
mWorkerThreadStopped = true;
}
mModelFinalized = true;
}
@ -751,10 +758,12 @@ void MoeLoadBalancer::startIter(int64_t iterId, bool enableStatistic, bool enabl
TLLM_CHECK_WITH_INFO(mIterId + 1 == iterId, "Expected iterId=%ld, but got %ld", mIterId + 1, iterId);
mIterId = iterId;
mStatisticEnabled = enableStatistic;
// disable update for warm up iters.
bool isWarmUpIter = mIterId <= mWarmUpUntilIter;
mUpdateWeightsEnabled = enableUpdateWeights && !isWarmUpIter;
bool fixedUpdateWeightsEnabled = enableUpdateWeights && !isWarmUpIter;
IterInfo iterInfo{iterId, enableStatistic, fixedUpdateWeightsEnabled};
mIterInfoQueue.push(iterInfo);
mWorkerThreadCondition.notify_one();
}
@ -780,21 +789,24 @@ void MoeLoadBalancer::shutdown()
void MoeLoadBalancer::workerThread()
{
TLLM_CUDA_CHECK(cudaSetDevice(mCudaDeviceId));
int64_t iterId = -1;
while (true)
{
int64_t iterId;
bool iterUpdateWeightsEnabled, iterStatisticEnabled;
{
std::unique_lock<std::mutex> lock(mWorkerThreadMutex);
mWorkerThreadCondition.wait(lock, [this] { return mWaitIterId == mIterId || mWorkerThreadStopped; });
iterId = mIterId;
if (mWorkerThreadStopped)
mWorkerThreadCondition.wait(lock, [this] { return !mIterInfoQueue.empty() || mWorkerThreadStopped; });
if (mIterInfoQueue.empty() && mWorkerThreadStopped)
{
break;
}
mWaitIterId = mIterId + 1;
iterUpdateWeightsEnabled = mUpdateWeightsEnabled;
iterStatisticEnabled = mStatisticEnabled;
auto iterInfo = mIterInfoQueue.front();
mIterInfoQueue.pop();
TLLM_CHECK_WITH_INFO(iterInfo.iterId == iterId + 1, "Jump detected, iterId=%ld, but got next iterId=%ld",
iterId, iterInfo.iterId);
iterId = iterInfo.iterId;
iterUpdateWeightsEnabled = iterInfo.updateWeightsEnabled;
iterStatisticEnabled = iterInfo.statisticEnabled;
}
for (int layerId = 0; static_cast<size_t>(layerId) < mLayers.size(); ++layerId)
{

View File

@ -238,7 +238,6 @@ private:
std::mutex mWorkerThreadMutex;
std::condition_variable mWorkerThreadCondition;
bool mWorkerThreadStopped = false;
int64_t mWaitIterId = 0;
int64_t mWarmUpUntilIter = -1;
// we use a separate thread to compute and update weights to avoid possible blocking for next layer due to slow
@ -252,8 +251,16 @@ private:
std::vector<std::shared_ptr<SingleLayerMoeLoadBalancer>> mLayers;
int64_t mIterId = -1;
bool mStatisticEnabled = true;
bool mUpdateWeightsEnabled = true;
struct IterInfo
{
int64_t iterId = -1;
bool statisticEnabled = true;
bool updateWeightsEnabled = true;
};
std::queue<IterInfo> mIterInfoQueue;
bool mModelFinalized = false;
int mEpRank = 0;