Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
fix[nvbug-5295425]: [TRTLLM-5385] fix race condition in MoeLoadBalancer (#4573)
Fix a possible race condition in the MoE load balancer and bypass the worker thread when there are no updates.

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
parent 1e55d616da
commit 338744fba6
@@ -728,8 +728,15 @@ void MoeLoadBalancer::finalizeModel()
     {
         layer->finalizeModel();
     }
-    generateUpdatePlan();
-    startThreads();
+    if (mLayerUpdatesPerIter > 0)
+    {
+        generateUpdatePlan();
+        startThreads();
+    }
+    else
+    {
+        mWorkerThreadStopped = true;
+    }
     mModelFinalized = true;
 }

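In the hunk above, finalizeModel() now only generates the update plan and starts the background threads when mLayerUpdatesPerIter > 0; otherwise it immediately marks the worker thread as stopped, which is the "bypass worker thread for no updates" part of the commit message. Below is a minimal standalone sketch of that shape, using made-up names (Planner, updatesPerIter) rather than anything from the repository:

// Hypothetical sketch: start a background worker only when it has work to do,
// otherwise record the "stopped" state up front so shutdown stays unconditional.
#include <atomic>
#include <chrono>
#include <cstdio>
#include <thread>

class Planner
{
public:
    explicit Planner(int updatesPerIter)
        : mUpdatesPerIter(updatesPerIter)
    {
    }

    void finalize()
    {
        if (mUpdatesPerIter > 0)
        {
            // The real code would build an update plan here before launching the thread.
            mWorker = std::thread([this] { run(); });
        }
        else
        {
            // No updates requested: never start the thread, but mark it stopped so
            // shutdown() does not wait on a thread that was never created.
            mStopped = true;
        }
    }

    void shutdown()
    {
        mStopped = true;
        if (mWorker.joinable())
        {
            mWorker.join();
        }
    }

private:
    void run()
    {
        while (!mStopped)
        {
            // Placeholder for per-iteration update work.
            std::this_thread::sleep_for(std::chrono::milliseconds(1));
        }
    }

    int mUpdatesPerIter;
    std::thread mWorker;
    std::atomic<bool> mStopped{false};
};

int main()
{
    Planner noUpdates(0);
    noUpdates.finalize();
    noUpdates.shutdown(); // safe: no thread was ever started
    std::printf("worker thread bypassed\n");
    return 0;
}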
@@ -751,10 +758,12 @@ void MoeLoadBalancer::startIter(int64_t iterId, bool enableStatistic, bool enableUpdateWeights)
     TLLM_CHECK_WITH_INFO(mIterId + 1 == iterId, "Expected iterId=%ld, but got %ld", mIterId + 1, iterId);

     mIterId = iterId;
-    mStatisticEnabled = enableStatistic;
     // disable update for warm up iters.
     bool isWarmUpIter = mIterId <= mWarmUpUntilIter;
-    mUpdateWeightsEnabled = enableUpdateWeights && !isWarmUpIter;
+    bool fixedUpdateWeightsEnabled = enableUpdateWeights && !isWarmUpIter;
+
+    IterInfo iterInfo{iterId, enableStatistic, fixedUpdateWeightsEnabled};
+    mIterInfoQueue.push(iterInfo);
     mWorkerThreadCondition.notify_one();
 }

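Read together with the worker-side changes in the next hunk, this is a small single-producer/single-consumer handoff: instead of startIter() writing the per-iteration flags into shared members that the worker reads later (values the next startIter() could overwrite before the worker consumed them), each iteration is snapshotted into an IterInfo and pushed onto a queue, so the worker only ever sees complete, ordered snapshots. The sketch below shows that handoff in isolation; it is a simplified, hypothetical rendering (IterSnapshot, IterChannel and the rest are invented names), not the TensorRT-LLM implementation.

// Sketch of a per-iteration snapshot queue between a producer (startIter-like)
// and a consumer (workerThread-like). Build with -std=c++17.
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <queue>
#include <thread>

struct IterSnapshot
{
    long iterId = -1;
    bool statisticEnabled = true;
    bool updateWeightsEnabled = true;
};

class IterChannel
{
public:
    // Producer side: publish one complete snapshot per iteration.
    void push(IterSnapshot info)
    {
        {
            std::lock_guard<std::mutex> lock(mMutex);
            mQueue.push(info);
        }
        mCondition.notify_one();
    }

    // Ask the consumer to exit once everything queued so far is processed.
    void stop()
    {
        {
            std::lock_guard<std::mutex> lock(mMutex);
            mStopped = true;
        }
        mCondition.notify_one();
    }

    // Consumer side: returns false only when stopped *and* drained.
    bool pop(IterSnapshot& out)
    {
        std::unique_lock<std::mutex> lock(mMutex);
        mCondition.wait(lock, [this] { return !mQueue.empty() || mStopped; });
        if (mQueue.empty() && mStopped)
        {
            return false;
        }
        out = mQueue.front();
        mQueue.pop();
        return true;
    }

private:
    std::mutex mMutex;
    std::condition_variable mCondition;
    std::queue<IterSnapshot> mQueue;
    bool mStopped = false;
};

int main()
{
    IterChannel channel;
    std::thread worker(
        [&channel]
        {
            long lastIterId = -1;
            IterSnapshot info;
            while (channel.pop(info))
            {
                // Same invariant the real worker asserts: iteration ids arrive in order.
                if (info.iterId != lastIterId + 1)
                {
                    std::printf("jump detected: %ld -> %ld\n", lastIterId, info.iterId);
                }
                lastIterId = info.iterId;
                std::printf("iter %ld statistic=%d update=%d\n", info.iterId,
                    static_cast<int>(info.statisticEnabled), static_cast<int>(info.updateWeightsEnabled));
            }
        });

    for (long i = 0; i < 3; ++i)
    {
        bool warmUp = (i == 0); // e.g. a warm-up iteration disables weight updates
        channel.push(IterSnapshot{i, true, !warmUp});
    }
    channel.stop();
    worker.join();
    return 0;
}

The exit test in pop() mirrors the new worker loop: the consumer leaves only when the stop flag is set and the queue is empty, so iterations queued before shutdown are still processed.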
@@ -780,21 +789,24 @@ void MoeLoadBalancer::shutdown()
 void MoeLoadBalancer::workerThread()
 {
     TLLM_CUDA_CHECK(cudaSetDevice(mCudaDeviceId));
+    int64_t iterId = -1;
     while (true)
     {
-        int64_t iterId;
         bool iterUpdateWeightsEnabled, iterStatisticEnabled;
         {
             std::unique_lock<std::mutex> lock(mWorkerThreadMutex);
-            mWorkerThreadCondition.wait(lock, [this] { return mWaitIterId == mIterId || mWorkerThreadStopped; });
-            iterId = mIterId;
-            if (mWorkerThreadStopped)
+            mWorkerThreadCondition.wait(lock, [this] { return !mIterInfoQueue.empty() || mWorkerThreadStopped; });
+            if (mIterInfoQueue.empty() && mWorkerThreadStopped)
             {
                 break;
             }
-            mWaitIterId = mIterId + 1;
-            iterUpdateWeightsEnabled = mUpdateWeightsEnabled;
-            iterStatisticEnabled = mStatisticEnabled;
+            auto iterInfo = mIterInfoQueue.front();
+            mIterInfoQueue.pop();
+            TLLM_CHECK_WITH_INFO(iterInfo.iterId == iterId + 1, "Jump detected, iterId=%ld, but got next iterId=%ld",
+                iterId, iterInfo.iterId);
+            iterId = iterInfo.iterId;
+            iterUpdateWeightsEnabled = iterInfo.updateWeightsEnabled;
+            iterStatisticEnabled = iterInfo.statisticEnabled;
         }
         for (int layerId = 0; static_cast<size_t>(layerId) < mLayers.size(); ++layerId)
         {
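One context line in workerThread() worth a note is the leading TLLM_CUDA_CHECK(cudaSetDevice(mCudaDeviceId)): the CUDA runtime tracks the current device per host thread, so a newly spawned worker starts on the default device rather than on whatever the main thread selected and has to set the device itself. The toy program below only makes that per-thread behaviour visible; it requires the CUDA toolkit (build with nvcc) and is an illustration, not repository code.

// Demonstrates that the current CUDA device is a per-host-thread setting.
#include <cstdio>
#include <thread>
#include <cuda_runtime.h>

int main()
{
    int deviceCount = 0;
    if (cudaGetDeviceCount(&deviceCount) != cudaSuccess || deviceCount == 0)
    {
        std::printf("no CUDA device available\n");
        return 0;
    }
    int const wanted = deviceCount - 1; // pick the last device, just for illustration
    cudaSetDevice(wanted);

    std::thread worker(
        [wanted]
        {
            int current = -1;
            cudaGetDevice(&current); // a new host thread starts on the default device (0)
            std::printf("worker before cudaSetDevice: device %d\n", current);
            cudaSetDevice(wanted);
            cudaGetDevice(&current);
            std::printf("worker after cudaSetDevice: device %d\n", current);
        });
    worker.join();
    return 0;
}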
@@ -238,7 +238,6 @@ private:
     std::mutex mWorkerThreadMutex;
     std::condition_variable mWorkerThreadCondition;
     bool mWorkerThreadStopped = false;
-    int64_t mWaitIterId = 0;
     int64_t mWarmUpUntilIter = -1;

     // we use a separate thread to compute and update weights to avoid possible blocking for next layer due to slow
@@ -252,8 +251,16 @@ private:
     std::vector<std::shared_ptr<SingleLayerMoeLoadBalancer>> mLayers;

     int64_t mIterId = -1;
-    bool mStatisticEnabled = true;
-    bool mUpdateWeightsEnabled = true;
+
+    struct IterInfo
+    {
+        int64_t iterId = -1;
+        bool statisticEnabled = true;
+        bool updateWeightsEnabled = true;
+    };
+
+    std::queue<IterInfo> mIterInfoQueue;
+
     bool mModelFinalized = false;

     int mEpRank = 0;
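On the header side, IterInfo is a plain aggregate, so the IterInfo iterInfo{iterId, enableStatistic, fixedUpdateWeightsEnabled} initializer in startIter() fills the members positionally in declaration order (iterId, statisticEnabled, updateWeightsEnabled); because the struct carries default member initializers, that brace form relies on C++17 aggregate rules. A tiny standalone illustration (not from the repository):

#include <cstdint>
#include <cstdio>

struct IterInfo
{
    int64_t iterId = -1;
    bool statisticEnabled = true;
    bool updateWeightsEnabled = true;
};

int main()
{
    // Positional aggregate initialization: arguments map to members in declaration
    // order, so reordering the members would silently swap which flag gets which value.
    IterInfo info{7, true, false};
    std::printf("iterId=%lld statistic=%d update=%d\n", static_cast<long long>(info.iterId),
        static_cast<int>(info.statisticEnabled), static_cast<int>(info.updateWeightsEnabled));
    return 0;
}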