Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][chore] optimize kv cache transfer for context TEP and gen DEP (#6657)
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
parent 3e41e6c077
commit ee471df07c
@@ -75,7 +75,6 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques
 bool CacheFormatter::needSendCache(
     CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx)
 {
-    // int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
     auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
     if (targetInfo.mDupHeadFactor <= 1)
     {
@@ -90,8 +89,9 @@ bool CacheFormatter::needSendCache(
             = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
         selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup;
     }
+    int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;

-    return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0;
+    return (destDPRank % targetInfo.mDupHeadFactor) == (selfTpRankInDpGroup % targetInfo.mDupHeadFactor);
 }

 void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig,
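Note: the hunk above replaces the fixed "duplicate rank 0 always sends" rule with a predicate keyed on the destination's DP rank. A standalone sketch of the new rule (hypothetical helper name, not the TRT-LLM API), assuming mDupHeadFactor context TP ranks hold identical KV heads:

// Sketch of the new sender predicate. The old rule fixed the sender at
// selfTpRankInDpGroup % factor == 0, so one duplicate served every generation
// DP rank; the new rule keys on the destination DP rank, so duplicates take turns.
#include <cstdio>

bool needSend(int selfTpRankInDpGroup, int destDPRank, int dupHeadFactor)
{
    return (destDPRank % dupHeadFactor) == (selfTpRankInDpGroup % dupHeadFactor);
}

int main()
{
    // Assumed shapes for illustration: 2 duplicate holders, 4 generation DP ranks.
    for (int destDPRank = 0; destDPRank < 4; ++destDPRank)
        for (int selfRank = 0; selfRank < 2; ++selfRank)
            if (needSend(selfRank, destDPRank, /*dupHeadFactor*/ 2))
                std::printf("generation DP rank %d <- context TP rank %d\n", destDPRank, selfRank);
    return 0;
}

With these shapes, generation DP ranks 0 and 2 pull from context rank 0 while ranks 1 and 3 pull from context rank 1, instead of rank 0 serving all four.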
@@ -128,11 +128,12 @@ std::vector<size_t> CacheFormatter::pickRecvConnections(
         return ret;
     }
     TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
+    int selfDPRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;

     std::vector<size_t> ret;
     for (int i = 0; i < targetInfo.mDomainTPSize; i++)
     {
-        if (i % targetInfo.mPeerDupHeadFactor == 0)
+        if ((i % targetInfo.mPeerDupHeadFactor) == (selfDPRank % targetInfo.mPeerDupHeadFactor))
         {
             for (int j = 0; j < targetInfo.mDomainPPSize; j++)
             {
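The receiver-side pick has to mirror the sender predicate, otherwise sender and receiver disagree on which duplicate carries the cache. A minimal sketch, assuming connection indices are laid out TP-major with mDomainPPSize consecutive entries per TP slot (the loop body inside the inner braces is not shown in this hunk, so the push_back below is illustrative):

// Keep only the peer TP slots whose duplicate residue matches this receiver's
// DP rank, mirroring the needSendCache change above.
#include <cstdio>
#include <vector>

std::vector<int> pickRecv(int domainTPSize, int domainPPSize, int peerDupHeadFactor, int selfDPRank)
{
    std::vector<int> ret;
    for (int i = 0; i < domainTPSize; ++i)
        if ((i % peerDupHeadFactor) == (selfDPRank % peerDupHeadFactor))
            for (int j = 0; j < domainPPSize; ++j)
                ret.push_back(i * domainPPSize + j); // layout assumption: TP-major
    return ret;
}

int main()
{
    for (int dp = 0; dp < 2; ++dp)
    {
        std::printf("DP rank %d picks:", dp);
        for (int idx : pickRecv(/*TP*/ 2, /*PP*/ 2, /*dupFactor*/ 2, dp))
            std::printf(" %d", idx);
        std::printf("\n");
    }
    return 0;
}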
@@ -45,10 +45,12 @@ std::vector<size_t> MLACacheFormatter::pickRecvConnections(
     auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
     TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
     std::vector<size_t> ret;
-    // targetInfo , mRanks [tpranks, dpranks]
+    // targetInfo , mRanks [tpranks, ppranks]
+    int dpRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;
+
     for (int i = 0; i < targetInfo.mDomainPPSize; i++)
     {
-        ret.push_back(i);
+        ret.push_back(i + (dpRank % (targetInfo.mDomainTPSize)) * targetInfo.mDomainPPSize);
     }
     return ret;
 }
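A toy illustration of the new MLA receive-index math, assuming the layout named in the updated comment ([tpranks, ppranks], i.e. mDomainPPSize consecutive connection indices per TP slot): each generation DP rank now starts from a different duplicate TP slot instead of every rank reading slot 0.

// index = ppSlot + tpSlot * domainPPSize, where the TP slot rotates with the
// receiver's DP rank. Shapes below are assumed for illustration.
#include <cstdio>

int main()
{
    int const domainTPSize = 2, domainPPSize = 2;
    for (int dpRank = 0; dpRank < 4; ++dpRank)
    {
        std::printf("generation DP rank %d picks:", dpRank);
        for (int i = 0; i < domainPPSize; ++i)
            std::printf(" %d", i + (dpRank % domainTPSize) * domainPPSize);
        std::printf("\n");
    }
    return 0;
}

This prints indices {0, 1} for even DP ranks and {2, 3} for odd ones, so the two duplicate context TP slots split the generation ranks between them.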
@@ -58,19 +60,24 @@ bool MLACacheFormatter::needSendCache(
 {
     int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;

+    int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
+        ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
+        : destConfig.getParallelConfig().mTensorParallelism;
+    int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;
+
     if (selfConfig.getParallelConfig().mEnableAttentionDP)
     {
         int selfTPNumInDPGroup
             = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
-        int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
-            ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
-            : destConfig.getParallelConfig().mTensorParallelism;
+
         int selfTPrankINDPGroup = selfTpRank % selfTPNumInDPGroup;
         if (selfTPNumInDPGroup <= destTPNumInDPGroup)
         {
             return true;
         }
-        return selfTPrankINDPGroup % (selfTPNumInDPGroup / destTPNumInDPGroup) == 0;
+
+        int dupHeadFactor = selfTPNumInDPGroup / destTPNumInDPGroup;
+        return selfTPrankINDPGroup % dupHeadFactor == destDPRank;
     }

     int destTPNum = destConfig.getParallelConfig().mEnableAttentionDP
@@ -81,7 +88,8 @@ bool MLACacheFormatter::needSendCache(
     {
         return true;
     }
-    return selfTpRank % (selfTPNum / destTPNum) == 0;
+    int dupHeadFactor = selfTPNum / destTPNum;
+    return selfTpRank % dupHeadFactor == destDPRank;
 }

 void MLACacheFormatter::format(TransferSession& session)
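A toy check of the updated MLA sender rule, under assumed shapes (contextTP = 4, genTP = 2, generation DP size 2, so dupHeadFactor = 2; not taken from the commit): previously every rank with selfTpRank % 2 == 0 (ranks 0 and 2) sent all the cache; now a context rank sends only to the generation DP rank matching its residue, spreading the transfer load across the duplicates.

#include <cstdio>

int main()
{
    int const selfTPNum = 4, destTPNum = 2, genDPSize = 2;
    int const dupHeadFactor = selfTPNum / destTPNum; // 2 duplicate holders per head
    for (int destDPRank = 0; destDPRank < genDPSize; ++destDPRank)
        for (int selfTpRank = 0; selfTpRank < selfTPNum; ++selfTpRank)
            if (selfTpRank % dupHeadFactor == destDPRank)
                std::printf("context TP rank %d -> generation DP rank %d\n", selfTpRank, destDPRank);
    return 0;
}

With these shapes, generation DP rank 0 receives from context ranks 0 and 2, and rank 1 from context ranks 1 and 3.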
@@ -1457,12 +1457,15 @@ TEST(targetTest, CacheStateNODP)

     verifyContext(
         /*contextRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);

+    verifyContext(
+        /*contextRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);
+
     verifyContext(
         /*contextRank*/ 2, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
     verifyContext(
         /*contextRank*/ 3, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);

     verifyContext(
         /*contextRank*/ 4, /*expectRanks*/ {2}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
     verifyContext(
@@ -1474,7 +1477,6 @@ TEST(targetTest, CacheStateNODP)

     contextTP = 2;
     genTP = 4;
-
     verifyContext(
         /*contextRank*/ 0, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectNeedSend*/ true);
     verifyContext(/*contextRank*/ 1, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2,
@@ -1564,13 +1566,13 @@ TEST(targetTest, CacheStateContextDP)
         /*expectNeedSend*/ true);
     verifyContext(
         /*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
-        /*expectNeedSend*/ true);
+        /*expectNeedSend*/ false);
     verifyContext(
         /*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
         /*expectNeedSend*/ false);
     verifyContext(
         /*contextRank*/ 1, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
-        /*expectNeedSend*/ false);
+        /*expectNeedSend*/ true);
     verifyContext(
         /*contextRank*/ 2, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
         /*expectNeedSend*/ false);
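Note: the ContextDP expectations swap for the (contextRank 0, generationRank 1) and (contextRank 1, generationRank 1) pairs because, under the new destDPRank-keyed predicate, generation rank 1 is served by context rank 1 rather than context rank 0; duplicate holders now split the generation DP ranks between them instead of one rank sending everything.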