address comments from Jin, Chuang and Yuxian

Balaram Buddharaju 2026-01-13 03:28:51 +00:00
parent 4e456350c0
commit 31f2ecd3cb
3 changed files with 17 additions and 6 deletions


@@ -133,8 +133,19 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
     if (worldConfig.isTensorParallel())
     {
-        mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
-            mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
+        if (worldConfig.isContextParallel())
+        {
+            // When CP is enabled, group ranks with the same (ppRank, cpRank) to exclude both PP and CP.
+            auto const tpGroupId = worldConfig.getContextParallelRank()
+                + worldConfig.getContextParallelism() * worldConfig.getPipelineParallelRank();
+            mGroupTensorParaComm
+                = std::make_shared<CacheTransceiverComm>(mGroupComm->split(tpGroupId, worldConfig.getRank()));
+        }
+        else
+        {
+            mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
+                mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
+        }
     }
     int kvFactor = 2;
     if (cacheManager->getCacheType() == kv_cache_manager::CacheType::kSELFKONLY)
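A minimal sketch (Python; the layout rank = ppRank * (TP * CP) + tpRank * CP + cpRank is the one documented in the next hunk, and the world sizes are made up) of why coloring the split by cpRank + CP * ppRank leaves each group holding exactly the TP peers of a rank:

# Sketch: the split color cp + CP * pp (tpGroupId in the diff above) groups
# ranks that differ only in tpRank, under the documented rank layout.
PP, TP, CP = 2, 4, 2

groups = {}
for pp in range(PP):
    for tp in range(TP):
        for cp in range(CP):
            rank = pp * (TP * CP) + tp * CP + cp
            color = cp + CP * pp
            groups.setdefault(color, []).append(rank)

for color, ranks in sorted(groups.items()):
    assert len(ranks) == TP                           # one full TP group
    assert len({r % CP for r in ranks}) == 1          # all share cpRank
    assert len({r // (TP * CP) for r in ranks}) == 1  # all share ppRank
    print(color, ranks)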
@@ -155,7 +166,7 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
     // DPRank is derived from the tensor parallel rank, which already accounts for CP.
     // Layout: rank = ppRank * (TP * CP) + tpRank * CP + cpRank.
     // getTensorParallelRank() correctly extracts tpRank regardless of CP.
-    int DPRank = worldConfig.getTensorParallelRank() / TPSizeInDPGroup;
+    int DPRank = mCacheState->getParallelConfig().mDPrank;
     // <PP,DP,TP,CP>
     mGroupDataComm = std::make_shared<CacheTransceiverComm>(mGroupComm->split(DPRank, worldConfig.getRank()));
     if (worldConfig.isTensorParallel())
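A companion sketch (Python, hypothetical sizes, PP fixed to 1 to keep the grouping unambiguous) of the mGroupDataComm split keyed by DPRank. The commit stops re-deriving DPRank locally and reads it from the cache state's parallel config; the removed derivation tpRank / TPSizeInDPGroup is used below and is assumed to agree with mDPrank:

# Sketch: the data-comm split keyed by DPRank (hypothetical sizes).
# dp_rank below uses the removed local derivation, assumed equal to
# mCacheState->getParallelConfig().mDPrank.
TP, CP = 4, 2
TP_PER_DP = 2  # stands in for TPSizeInDPGroup

dp_groups = {}
for tp in range(TP):
    for cp in range(CP):
        rank = tp * CP + cp        # ppRank = 0 under the documented layout
        dp_rank = tp // TP_PER_DP  # the removed local derivation
        dp_groups.setdefault(dp_rank, []).append(rank)

for dp_rank, ranks in sorted(dp_groups.items()):
    assert len(ranks) == TP_PER_DP * CP  # every TP and CP member of the DP shard
    print(dp_rank, ranks)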


@@ -554,7 +554,7 @@ protected:
     // Rank formula must match targetIRanks: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank.
     mCpRank = mRankInInstance % mCpSize;
-    mTpRank = (mRankInInstance % (mTpSize * mCpSize)) / mCpSize;
+    mTpRank = (mRankInInstance / mCpSize) % mTpSize;
     mPpRank = mRankInInstance / (mTpSize * mCpSize);
     mContextRankSize = contextRanks;
     mGenRankSize = genRanks;
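A quick sketch (Python, arbitrary sizes) checking that the decode expressions round-trip the rank formula in the comment, and that the new mTpRank expression is equivalent to the one it replaces:

# Sketch: decode formulas round-trip the encode formula
# rank = ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank.
tp_size, cp_size, pp_size = 4, 2, 3

for pp in range(pp_size):
    for tp in range(tp_size):
        for cp in range(cp_size):
            rank = pp * (tp_size * cp_size) + tp * cp_size + cp
            assert rank % cp_size == cp                           # mCpRank
            assert (rank // cp_size) % tp_size == tp              # new mTpRank
            assert (rank % (tp_size * cp_size)) // cp_size == tp  # old mTpRank, equivalent
            assert rank // (tp_size * cp_size) == pp              # mPpRank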


@@ -370,8 +370,8 @@ class ExecutorRequestQueue:
         num_active_tokens = sum(
             [req.py_orig_prompt_len for req in activate_requests])
-        # Note: We use tp_allgather even for CP assuming that all CP ranks a
-        # DP group have the same num_active_tokens and num_active_requests.
+        # Note: We use tp_allgather even for CP assuming that all CP ranks with the
+        # same dp_rank have the same num_active_tokens and num_active_requests.
         responses_list = self.dist.tp_allgather(
             [len(activate_requests), num_active_tokens])
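A small sketch (Python; per_dp_stats, local_stats, and the gather emulation are hypothetical, not the real dist interface) of the invariant the reworded comment relies on: every CP rank sharing a dp_rank reports the same pair, so gathering over the combined TP and CP group is well defined per dp_rank:

# Sketch of the assumption behind reusing tp_allgather under CP:
# replicas of a dp_rank must contribute identical [num_requests, num_tokens].
per_dp_stats = {0: [3, 120], 1: [5, 200]}  # dp_rank -> [num_requests, num_tokens]
CP = 2

def local_stats(dp_rank, cp_rank):
    # The comment's assumption: the value does not depend on cp_rank.
    return per_dp_stats[dp_rank]

# Emulate the allgather: each (dp_rank, cp_rank) replica contributes one entry.
gathered = [local_stats(dp, cp) for dp in sorted(per_dp_stats) for cp in range(CP)]
for dp in per_dp_stats:
    entries = [local_stats(dp, cp) for cp in range(CP)]
    assert all(e == entries[0] for e in entries)
print(gathered)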