Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
address comments from Jin, Chuang and Yuxian
commit 31f2ecd3cb (parent 4e456350c0)
@@ -133,8 +133,19 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
     if (worldConfig.isTensorParallel())
     {
-        mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
-            mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
+        if (worldConfig.isContextParallel())
+        {
+            // When CP is enabled, group ranks with same (ppRank, cpRank) to exclude both PP and CP.
+            auto const tpGroupId = worldConfig.getContextParallelRank()
+                + worldConfig.getContextParallelism() * worldConfig.getPipelineParallelRank();
+            mGroupTensorParaComm
+                = std::make_shared<CacheTransceiverComm>(mGroupComm->split(tpGroupId, worldConfig.getRank()));
+        }
+        else
+        {
+            mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
+                mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
+        }
     }
     int kvFactor = 2;
     if (cacheManager->getCacheType() == kv_cache_manager::CacheType::kSELFKONLY)
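When CP is enabled, the split above colors ranks by a combined (ppRank, cpRank) id so that only TP peers land in the same communicator. A minimal Python sketch, using the rank layout documented later in this diff (rank = ppRank * (TP * CP) + tpRank * CP + cpRank) and made-up world sizes, to check that tpGroupId = cpRank + CP * ppRank collects exactly the TP ranks sharing a (ppRank, cpRank) pair:

```python
# Illustrative check only; sizes are examples, not taken from the repo.
from collections import defaultdict

pp_size, tp_size, cp_size = 2, 4, 2

groups = defaultdict(list)
for pp in range(pp_size):
    for tp in range(tp_size):
        for cp in range(cp_size):
            rank = pp * (tp_size * cp_size) + tp * cp_size + cp  # documented layout
            tp_group_id = cp + cp_size * pp                      # mirrors tpGroupId above
            groups[tp_group_id].append((pp, tp, cp, rank))

for gid, members in sorted(groups.items()):
    # Each color holds exactly the tp_size ranks that share one (ppRank, cpRank) pair.
    assert len(members) == tp_size
    assert len({(pp, cp) for pp, _tp, cp, _rank in members}) == 1
print("every tpGroupId color contains one full TP group")
```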
@@ -155,7 +166,7 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
     // DPRank is derived from the tensor parallel rank, which already accounts for CP.
     // Layout: rank = ppRank * (TP * CP) + tpRank * CP + cpRank.
     // getTensorParallelRank() correctly extracts tpRank regardless of CP.
-    int DPRank = worldConfig.getTensorParallelRank() / TPSizeInDPGroup;
+    int DPRank = mCacheState->getParallelConfig().mDPrank;
     // <PP,DP,TP,CP>
     mGroupDataComm = std::make_shared<CacheTransceiverComm>(mGroupComm->split(DPRank, worldConfig.getRank()));
     if (worldConfig.isTensorParallel())
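For reference, a small sketch (illustration only, with example sizes and a stand-in for TPSizeInDPGroup that are not taken from the repo) of the layout comment above: it decomposes a global rank into (ppRank, tpRank, cpRank) and derives a DP rank from tpRank the way the replaced line did; the new line instead reads the DP rank directly from the cache state's parallel config.

```python
# Sketch of the documented layout: rank = ppRank * (TP * CP) + tpRank * CP + cpRank.
tp_size, cp_size = 4, 2        # example sizes, not from the repo
tp_size_in_dp_group = 2        # hypothetical stand-in for TPSizeInDPGroup

def decompose(rank: int) -> tuple[int, int, int]:
    pp_rank = rank // (tp_size * cp_size)
    tp_rank = (rank // cp_size) % tp_size
    cp_rank = rank % cp_size
    return pp_rank, tp_rank, cp_rank

for rank in range(2 * tp_size * cp_size):     # two PP stages' worth of ranks
    pp_rank, tp_rank, cp_rank = decompose(rank)
    dp_rank = tp_rank // tp_size_in_dp_group  # the replaced derivation, shown for comparison
    print(rank, (pp_rank, tp_rank, cp_rank), dp_rank)
```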
@@ -554,7 +554,7 @@ protected:

        // Rank formula must match targetIRanks: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank.
        mCpRank = mRankInInstance % mCpSize;
-       mTpRank = (mRankInInstance % (mTpSize * mCpSize)) / mCpSize;
+       mTpRank = (mRankInInstance / mCpSize) % mTpSize;
        mPpRank = mRankInInstance / (mTpSize * mCpSize);
        mContextRankSize = contextRanks;
        mGenRankSize = genRanks;
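The old and new mTpRank expressions are algebraically equivalent under the rank formula in the comment, so this reads as a simplification rather than a behavioral fix. A quick brute-force check with arbitrary example sizes:

```python
# Equivalence check, illustration only; sizes are arbitrary examples.
pp_size, tp_size, cp_size = 3, 4, 2

for rank in range(pp_size * tp_size * cp_size):
    old_tp = (rank % (tp_size * cp_size)) // cp_size   # expression being removed
    new_tp = (rank // cp_size) % tp_size               # expression being added
    assert old_tp == new_tp
print("old and new mTpRank formulas agree for every rank")
```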
@@ -370,8 +370,8 @@ class ExecutorRequestQueue:
         num_active_tokens = sum(
             [req.py_orig_prompt_len for req in activate_requests])

-        # Note: We use tp_allgather even for CP assuming that all CP ranks a
-        # DP group have the same num_active_tokens and num_active_requests.
+        # Note: We use tp_allgather even for CP assuming that all CP ranks with the
+        # same dp_rank have the same num_active_tokens and num_active_requests.
         responses_list = self.dist.tp_allgather(
             [len(activate_requests), num_active_tokens])
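The reworded note states the assumption that makes tp_allgather safe here: CP replicas that share a dp_rank contribute identical counts. A toy sketch of that invariant; the gather ordering, cp_size/dp_size values, and the per-DP numbers are assumptions for illustration only, not taken from the repo:

```python
# Toy illustration of the stated assumption, not repo code.
cp_size, dp_size = 2, 2
per_dp_counts = {0: [3, 120], 1: [5, 200]}   # hypothetical [num_requests, num_tokens] per dp_rank

# Assumed gather order for this sketch: entries grouped by dp_rank, cp_size replicas each.
responses_list = [list(per_dp_counts[dp]) for dp in range(dp_size) for _ in range(cp_size)]

for dp in range(dp_size):
    replicas = responses_list[dp * cp_size:(dp + 1) * cp_size]
    # Every CP replica of a dp_rank reports identical counts, so any single
    # replica can stand in for its whole DP group without double counting.
    assert all(r == replicas[0] for r in replicas)

print([responses_list[dp * cp_size] for dp in range(dp_size)])
```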