Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

commit 31f2ecd3cb (parent 4e456350c0)

    address comments from Jin, Chuang and Yuxian
@@ -133,8 +133,19 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
     if (worldConfig.isTensorParallel())
     {
-        mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
-            mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
+        if (worldConfig.isContextParallel())
+        {
+            // When CP is enabled, group ranks with same (ppRank, cpRank) to exclude both PP and CP.
+            auto const tpGroupId = worldConfig.getContextParallelRank()
+                + worldConfig.getContextParallelism() * worldConfig.getPipelineParallelRank();
+            mGroupTensorParaComm
+                = std::make_shared<CacheTransceiverComm>(mGroupComm->split(tpGroupId, worldConfig.getRank()));
+        }
+        else
+        {
+            mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
+                mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
+        }
     }

     int kvFactor = 2;
     if (cacheManager->getCacheType() == kv_cache_manager::CacheType::kSELFKONLY)
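To make the new grouping concrete, here is a minimal standalone sketch, not the TensorRT-LLM API. It assumes that split(color, key) behaves like MPI_Comm_split (all ranks passing the same color land in the same sub-communicator) and that ranks follow the layout stated in the comments of the next hunk, rank = ppRank * (TP * CP) + tpRank * CP + cpRank. The sizes PP = TP = CP = 2 are illustrative only. With the color tpGroupId = cpRank + CP * ppRank used in the added branch, each resulting group contains exactly the ranks that differ only in tpRank.

// Standalone sketch (assumed split semantics, illustrative sizes); not library code.
#include <cstdio>
#include <map>
#include <vector>

int main()
{
    int const ppSize = 2, tpSize = 2, cpSize = 2;
    std::map<int, std::vector<int>> groups; // color (tpGroupId) -> world ranks

    for (int ppRank = 0; ppRank < ppSize; ++ppRank)
        for (int tpRank = 0; tpRank < tpSize; ++tpRank)
            for (int cpRank = 0; cpRank < cpSize; ++cpRank)
            {
                // Assumed rank layout: rank = ppRank * (TP * CP) + tpRank * CP + cpRank.
                int const rank = ppRank * (tpSize * cpSize) + tpRank * cpSize + cpRank;
                // Same color as in the diff: one TP group per (ppRank, cpRank) pair.
                int const tpGroupId = cpRank + cpSize * ppRank;
                groups[tpGroupId].push_back(rank);
            }

    // Each printed group differs only in tpRank, i.e. PP and CP are excluded.
    for (auto const& [color, ranks] : groups)
    {
        std::printf("tpGroupId %d:", color);
        for (int r : ranks)
            std::printf(" %d", r);
        std::printf("\n");
    }
    return 0;
}

For these sizes the groups come out as {0,2}, {1,3}, {4,6}, {5,7}: every TP group excludes both the PP and the CP dimension, which is what the added isContextParallel() branch is for.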
@@ -155,7 +166,7 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
     // DPRank is derived from the tensor parallel rank, which already accounts for CP.
     // Layout: rank = ppRank * (TP * CP) + tpRank * CP + cpRank.
     // getTensorParallelRank() correctly extracts tpRank regardless of CP.
-    int DPRank = worldConfig.getTensorParallelRank() / TPSizeInDPGroup;
+    int DPRank = mCacheState->getParallelConfig().mDPrank;
     // <PP,DP,TP,CP>
     mGroupDataComm = std::make_shared<CacheTransceiverComm>(mGroupComm->split(DPRank, worldConfig.getRank()));
     if (worldConfig.isTensorParallel())
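For reference, the removed line computed the DP rank directly from the TP rank, while the replacement reads the already-stored mDPrank from the cache state's parallel config instead of recomputing it. A small sketch of the removed arithmetic, assuming TPSizeInDPGroup is the number of TP ranks per DP group (TP / DP); both the name and the concrete sizes below are illustrative, not taken from the codebase.

// Sketch of the DP grouping implied by the removed line
// "int DPRank = worldConfig.getTensorParallelRank() / TPSizeInDPGroup;".
#include <cstdio>

int main()
{
    int const tpSize = 4;                         // tensor-parallel ranks (illustrative)
    int const dpSize = 2;                         // data-parallel groups (assumed)
    int const tpSizeInDpGroup = tpSize / dpSize;  // assumed meaning of TPSizeInDPGroup
    for (int tpRank = 0; tpRank < tpSize; ++tpRank)
    {
        // tpRanks 0,1 land in DPRank 0 and tpRanks 2,3 in DPRank 1.
        std::printf("tpRank %d -> DPRank %d\n", tpRank, tpRank / tpSizeInDpGroup);
    }
    return 0;
}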
@@ -554,7 +554,7 @@ protected:

        // Rank formula must match targetIRanks: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank.
        mCpRank = mRankInInstance % mCpSize;
-       mTpRank = (mRankInInstance % (mTpSize * mCpSize)) / mCpSize;
+       mTpRank = (mRankInInstance / mCpSize) % mTpSize;
        mPpRank = mRankInInstance / (mTpSize * mCpSize);
        mContextRankSize = contextRanks;
        mGenRankSize = genRanks;
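The mTpRank rewrite in this test fixture is a pure simplification: under the rank layout named in the comment, (rank % (tpNum * cpNum)) / cpNum and (rank / cpNum) % tpNum yield the same value for every rank. A small self-contained check (sizes are arbitrary):

// Verifies that the removed and added tpRank formulas agree for every rank,
// given rank = ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank.
#include <cassert>
#include <cstdio>

int main()
{
    int const ppSize = 2, tpSize = 4, cpSize = 3;
    for (int rank = 0; rank < ppSize * tpSize * cpSize; ++rank)
    {
        int const oldTpRank = (rank % (tpSize * cpSize)) / cpSize; // removed line
        int const newTpRank = (rank / cpSize) % tpSize;            // added line
        assert(oldTpRank == newTpRank);
    }
    std::printf("both tpRank formulas agree\n");
    return 0;
}

Both expressions extract tpRank from the same layout; the new form simply avoids the intermediate modulo by (tpNum * cpNum).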
@@ -370,8 +370,8 @@ class ExecutorRequestQueue:
         num_active_tokens = sum(
             [req.py_orig_prompt_len for req in activate_requests])

-        # Note: We use tp_allgather even for CP assuming that all CP ranks a
-        # DP group have the same num_active_tokens and num_active_requests.
+        # Note: We use tp_allgather even for CP assuming that all CP ranks with the
+        # same dp_rank have the same num_active_tokens and num_active_requests.
         responses_list = self.dist.tp_allgather(
             [len(activate_requests), num_active_tokens])