[None][chore] optimize kv cache transfer for context TEP and gen DEP (#6657)

Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Authored by Chuang Zhu on 2025-08-07 11:36:05 +08:00, committed by GitHub
parent 3e41e6c077
commit ee471df07c
3 changed files with 24 additions and 13 deletions


@@ -75,7 +75,6 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques
 bool CacheFormatter::needSendCache(
     CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx)
 {
-    // int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
     auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
     if (targetInfo.mDupHeadFactor <= 1)
     {
@@ -90,8 +89,9 @@ bool CacheFormatter::needSendCache(
             = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
         selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup;
     }
-    return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0;
+    int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;
+    return (destDPRank % targetInfo.mDupHeadFactor) == (selfTpRankInDpGroup % targetInfo.mDupHeadFactor);
 }

 void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig,
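
The heart of the change is in needSendCache above. When the context side holds mDupHeadFactor duplicate copies of each KV head, the old predicate made only rank 0 of every duplicate group send, so every generation DP rank pulled its cache from the same sender. The new predicate keys send duty to the destination's DP rank, rotating transfers across the duplicates. A minimal standalone sketch of the resulting assignment (the helper and loop below are illustrative, not the TRT-LLM API):

```cpp
#include <cstdio>

// Illustrative stand-in for the new predicate in CacheFormatter::needSendCache.
bool needSend(int selfTpRankInDpGroup, int destDPRank, int dupHeadFactor)
{
    // Old rule: selfTpRankInDpGroup % dupHeadFactor == 0 (rank 0 always sends).
    // New rule: duty rotates across the duplicate group by destination DP rank.
    return (destDPRank % dupHeadFactor) == (selfTpRankInDpGroup % dupHeadFactor);
}

int main()
{
    int const dupHeadFactor = 2; // context TP holds each KV head twice vs. generation
    for (int destDPRank = 0; destDPRank < 4; ++destDPRank)
        for (int selfRank = 0; selfRank < 2; ++selfRank)
            if (needSend(selfRank, destDPRank, dupHeadFactor))
                std::printf("gen DP rank %d <- context TP rank %d\n", destDPRank, selfRank);
}
```

With the old rule, all four generation DP ranks would read from context rank 0; with the new rule, ranks 0 and 1 each serve two of them.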
@@ -128,11 +128,12 @@ std::vector<size_t> CacheFormatter::pickRecvConnections(
         return ret;
     }
     TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
+    int selfDPRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;
     std::vector<size_t> ret;
     for (int i = 0; i < targetInfo.mDomainTPSize; i++)
     {
-        if (i % targetInfo.mPeerDupHeadFactor == 0)
+        if ((i % targetInfo.mPeerDupHeadFactor) == (selfDPRank % targetInfo.mPeerDupHeadFactor))
         {
             for (int j = 0; j < targetInfo.mDomainPPSize; j++)
             {
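
The receiver side has to agree with the sender's predicate, otherwise a generation rank would wait on a connection whose peer decided not to send. A simplified sketch of the selection loop above, assuming connections are flattened in [tp][pp] order (the actual layout of targetInfo.mIRanks is internal, so the flattening here is an assumption):

```cpp
#include <cstddef>
#include <vector>

// Simplified stand-in for CacheFormatter::pickRecvConnections: keep only the
// TP-domain column whose duplicate index matches this rank's DP slot.
std::vector<size_t> pickRecv(int domainTPSize, int domainPPSize, int peerDupHeadFactor, int selfDPRank)
{
    std::vector<size_t> ret;
    for (int i = 0; i < domainTPSize; i++)
    {
        if ((i % peerDupHeadFactor) == (selfDPRank % peerDupHeadFactor)) // mirrors needSendCache
        {
            for (int j = 0; j < domainPPSize; j++)
            {
                ret.push_back(static_cast<size_t>(i) * domainPPSize + j); // assumed [tp][pp] flattening
            }
        }
    }
    return ret;
}
```

For domainTPSize = 2 and peerDupHeadFactor = 2, DP rank 0 keeps column 0 and DP rank 1 keeps column 1, matching the sender-side rotation.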


@@ -45,10 +45,12 @@ std::vector<size_t> MLACacheFormatter::pickRecvConnections(
     auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
     TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
     std::vector<size_t> ret;
-    // targetInfo , mRanks [tpranks, dpranks]
+    // targetInfo , mRanks [tpranks, ppranks]
+    int dpRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;
     for (int i = 0; i < targetInfo.mDomainPPSize; i++)
     {
-        ret.push_back(i);
+        ret.push_back(i + (dpRank % (targetInfo.mDomainTPSize)) * targetInfo.mDomainPPSize);
     }
     return ret;
 }
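
MLA has a single latent KV head, so every context TP rank in the domain holds a complete copy and a generation rank needs exactly one TP column per PP stage. The old code always took column 0; the new index i + (dpRank % mDomainTPSize) * mDomainPPSize spreads generation DP ranks round-robin across columns, following the [tpranks, ppranks] grouping the corrected comment documents. A tiny sketch of the resulting picks (values are illustrative):

```cpp
#include <cstdio>

int main()
{
    int const domainTPSize = 2, domainPPSize = 2;
    for (int dpRank = 0; dpRank < 4; ++dpRank)
    {
        std::printf("gen DP rank %d picks indices:", dpRank);
        for (int i = 0; i < domainPPSize; i++)
            std::printf(" %d", i + (dpRank % domainTPSize) * domainPPSize);
        std::printf("\n"); // dpRank 0 -> {0,1}, 1 -> {2,3}, 2 -> {0,1}, ...
    }
}
```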
@@ -58,19 +60,24 @@ bool MLACacheFormatter::needSendCache(
 {
     int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
+    int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
+        ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
+        : destConfig.getParallelConfig().mTensorParallelism;
+    int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;
     if (selfConfig.getParallelConfig().mEnableAttentionDP)
     {
         int selfTPNumInDPGroup
             = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
-        int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
-            ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
-            : destConfig.getParallelConfig().mTensorParallelism;
         int selfTPrankINDPGroup = selfTpRank % selfTPNumInDPGroup;
         if (selfTPNumInDPGroup <= destTPNumInDPGroup)
         {
             return true;
         }
-        return selfTPrankINDPGroup % (selfTPNumInDPGroup / destTPNumInDPGroup) == 0;
+        int dupHeadFactor = selfTPNumInDPGroup / destTPNumInDPGroup;
+        return selfTPrankINDPGroup % dupHeadFactor == destDPRank;
     }

     int destTPNum = destConfig.getParallelConfig().mEnableAttentionDP
@@ -81,7 +88,8 @@ bool MLACacheFormatter::needSendCache(
     {
         return true;
     }
-    return selfTpRank % (selfTPNum / destTPNum) == 0;
+    int dupHeadFactor = selfTPNum / destTPNum;
+    return selfTpRank % dupHeadFactor == destDPRank;
 }

 void MLACacheFormatter::format(TransferSession& session)
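
The MLA send side applies the same rotation: dupHeadFactor counts how many context TP ranks hold the (single) latent head, and the rank whose position in the duplicate group equals the destination's DP rank takes the transfer. A standalone sketch, with values chosen so destDPRank < dupHeadFactor as in the accompanying tests (the loop and constants are illustrative):

```cpp
#include <cstdio>

int main()
{
    int const selfTPNum = 4, destTPNum = 1;
    int const dupHeadFactor = selfTPNum / destTPNum; // four copies of the MLA KV head
    for (int destDPRank = 0; destDPRank < dupHeadFactor; ++destDPRank)
        for (int selfTpRank = 0; selfTpRank < selfTPNum; ++selfTpRank)
            if (selfTpRank % dupHeadFactor == destDPRank) // the new return condition
                std::printf("gen DP rank %d <- context TP rank %d\n", destDPRank, selfTpRank);
}
```

Under the old `selfTpRank % dupHeadFactor == 0` rule, context rank 0 would have served all four destinations.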


@@ -1457,12 +1457,15 @@ TEST(targetTest, CacheStateNODP)
     verifyContext(
         /*contextRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
+    verifyContext(
+        /*contextRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);
     verifyContext(
         /*contextRank*/ 2, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
+    verifyContext(
+        /*contextRank*/ 3, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);
     verifyContext(
         /*contextRank*/ 4, /*expectRanks*/ {2}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
     verifyContext(
@@ -1474,7 +1477,6 @@ TEST(targetTest, CacheStateNODP)
     contextTP = 2;
     genTP = 4;
     verifyContext(
         /*contextRank*/ 0, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectNeedSend*/ true);
     verifyContext(/*contextRank*/ 1, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2,
@@ -1564,13 +1566,13 @@ TEST(targetTest, CacheStateContextDP)
         /*expectNeedSend*/ true);
     verifyContext(
         /*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
-        /*expectNeedSend*/ true);
+        /*expectNeedSend*/ false);
     verifyContext(
         /*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
         /*expectNeedSend*/ false);
     verifyContext(
         /*contextRank*/ 1, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
-        /*expectNeedSend*/ false);
+        /*expectNeedSend*/ true);
     verifyContext(
         /*contextRank*/ 2, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
         /*expectNeedSend*/ false);
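
The flipped expectations fall straight out of the new predicate: under the old rule context rank 0 sent to both generation ranks while context rank 1 sent to neither; now the duty alternates. A quick check of the four flipped cases, assuming a duplicate-group size of 2 in this configuration (the helper is illustrative, not the test's verifyContext):

```cpp
#include <cassert>

int main()
{
    // Mirrors needSendCache for dupHeadFactor == 2 (assumed for this config).
    auto needSend = [](int contextRank, int genDPRank) { return (genDPRank % 2) == (contextRank % 2); };
    assert(needSend(0, 0));  // contextRank 0 -> generationRank 0: still true
    assert(!needSend(0, 1)); // contextRank 0 -> generationRank 1: was true, now false
    assert(!needSend(1, 0)); // contextRank 1 -> generationRank 0: false
    assert(needSend(1, 1));  // contextRank 1 -> generationRank 1: was false, now true
}
```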