Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[None][chore] optimize kv cache transfer for context TEP and gen DEP (#6657)
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
This commit is contained in:
parent 3e41e6c077
commit ee471df07c
@@ -75,7 +75,6 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques
 bool CacheFormatter::needSendCache(
     CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx)
 {
-    // int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
     auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
     if (targetInfo.mDupHeadFactor <= 1)
     {
@@ -90,8 +89,9 @@ bool CacheFormatter::needSendCache(
             = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
         selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup;
     }
+    int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;
 
-    return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0;
+    return (destDPRank % targetInfo.mDupHeadFactor) == (selfTpRankInDpGroup % targetInfo.mDupHeadFactor);
 }
 
 void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig,
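For readers skimming the diff, here is a minimal standalone sketch (not TensorRT-LLM code; the function below is a hypothetical stand-in and all shapes are made up) of what the new return condition in CacheFormatter::needSendCache does: when a duplicate-head factor is in play, sender duty is spread across generation DP ranks by matching remainders, instead of only the context ranks with selfTpRankInDpGroup % mDupHeadFactor == 0 ever sending.

// Standalone illustration: with a duplicate-head factor of 2, even context TP
// ranks now serve even generation DP ranks and odd context TP ranks serve odd
// generation DP ranks.
#include <cstdio>

// Hypothetical stand-in for the condition returned by CacheFormatter::needSendCache.
bool needSendCache(int selfTpRankInDpGroup, int destDPRank, int dupHeadFactor)
{
    return (destDPRank % dupHeadFactor) == (selfTpRankInDpGroup % dupHeadFactor);
}

int main()
{
    int const dupHeadFactor = 2; // assume the context TP group is twice the generation TP group
    for (int selfTpRank = 0; selfTpRank < 4; ++selfTpRank)
    {
        for (int destDPRank = 0; destDPRank < 2; ++destDPRank)
        {
            std::printf("ctx tp rank %d -> gen dp rank %d : %s\n", selfTpRank, destDPRank,
                needSendCache(selfTpRank, destDPRank, dupHeadFactor) ? "send" : "skip");
        }
    }
    return 0;
}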
@@ -128,11 +128,12 @@ std::vector<size_t> CacheFormatter::pickRecvConnections(
         return ret;
     }
     TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
+    int selfDPRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;
 
     std::vector<size_t> ret;
     for (int i = 0; i < targetInfo.mDomainTPSize; i++)
     {
-        if (i % targetInfo.mPeerDupHeadFactor == 0)
+        if ((i % targetInfo.mPeerDupHeadFactor) == (selfDPRank % targetInfo.mPeerDupHeadFactor))
        {
             for (int j = 0; j < targetInfo.mDomainPPSize; j++)
             {
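The receiver-side change mirrors the sender rule. A small sketch (illustrative only; the inner pipeline-parallel loop over mDomainPPSize is omitted and all names and values are hypothetical) of which TP-domain slots a generation DP rank now picks in CacheFormatter::pickRecvConnections:

// Standalone illustration: with peerDupHeadFactor = 2, gen DP rank 0 picks
// slots 0, 2, ... and gen DP rank 1 picks slots 1, 3, ..., mirroring the
// sender-side remainder match above.
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the selection loop in CacheFormatter::pickRecvConnections.
std::vector<int> pickSlots(int domainTPSize, int peerDupHeadFactor, int selfDPRank)
{
    std::vector<int> slots;
    for (int i = 0; i < domainTPSize; ++i)
    {
        if ((i % peerDupHeadFactor) == (selfDPRank % peerDupHeadFactor))
        {
            slots.push_back(i);
        }
    }
    return slots;
}

int main()
{
    for (int dpRank = 0; dpRank < 2; ++dpRank)
    {
        std::printf("gen dp rank %d picks TP slots:", dpRank);
        for (int s : pickSlots(/*domainTPSize=*/4, /*peerDupHeadFactor=*/2, dpRank))
        {
            std::printf(" %d", s);
        }
        std::printf("\n");
    }
    return 0;
}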
@@ -45,10 +45,12 @@ std::vector<size_t> MLACacheFormatter::pickRecvConnections(
     auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
     TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
     std::vector<size_t> ret;
-    // targetInfo , mRanks [tpranks, dpranks]
+    // targetInfo , mRanks [tpranks, ppranks]
+    int dpRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;
+
     for (int i = 0; i < targetInfo.mDomainPPSize; i++)
     {
-        ret.push_back(i);
+        ret.push_back(i + (dpRank % (targetInfo.mDomainTPSize)) * targetInfo.mDomainPPSize);
     }
     return ret;
 }
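A short illustration of the new index math in MLACacheFormatter::pickRecvConnections, assuming (as the updated comment suggests) that the flattened peer rank list groups mDomainPPSize pipeline ranks per tensor-parallel rank; the helper below is a hypothetical stand-in, not the library function:

// Standalone illustration: connections are assumed laid out as
// [tp0:pp0..ppN-1, tp1:pp0..ppN-1, ...]; a generation DP rank now picks the
// TP block selected by (dpRank % domainTPSize) instead of always block 0.
#include <cstdio>
#include <vector>

std::vector<int> pickConnections(int domainPPSize, int domainTPSize, int dpRank)
{
    std::vector<int> ret;
    for (int i = 0; i < domainPPSize; ++i)
    {
        ret.push_back(i + (dpRank % domainTPSize) * domainPPSize);
    }
    return ret;
}

int main()
{
    // Hypothetical shapes: 2 PP ranks and 2 TP ranks in the peer domain.
    for (int dpRank = 0; dpRank < 4; ++dpRank)
    {
        std::printf("gen dp rank %d picks connections:", dpRank);
        for (int c : pickConnections(/*domainPPSize=*/2, /*domainTPSize=*/2, dpRank))
        {
            std::printf(" %d", c);
        }
        std::printf("\n");
    }
    return 0;
}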
@@ -58,19 +60,24 @@ bool MLACacheFormatter::needSendCache(
 {
     int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
 
+    int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
+        ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
+        : destConfig.getParallelConfig().mTensorParallelism;
+    int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;
+
     if (selfConfig.getParallelConfig().mEnableAttentionDP)
     {
         int selfTPNumInDPGroup
             = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
-        int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
-            ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
-            : destConfig.getParallelConfig().mTensorParallelism;
         int selfTPrankINDPGroup = selfTpRank % selfTPNumInDPGroup;
         if (selfTPNumInDPGroup <= destTPNumInDPGroup)
         {
             return true;
         }
-        return selfTPrankINDPGroup % (selfTPNumInDPGroup / destTPNumInDPGroup) == 0;
+
+        int dupHeadFactor = selfTPNumInDPGroup / destTPNumInDPGroup;
+        return selfTPrankINDPGroup % dupHeadFactor == destDPRank;
     }
 
     int destTPNum = destConfig.getParallelConfig().mEnableAttentionDP
@@ -81,7 +88,8 @@ bool MLACacheFormatter::needSendCache(
     {
         return true;
     }
-    return selfTpRank % (selfTPNum / destTPNum) == 0;
+    int dupHeadFactor = selfTPNum / destTPNum;
+    return selfTpRank % dupHeadFactor == destDPRank;
 }
 
 void MLACacheFormatter::format(TransferSession& session)
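And a sketch of the effect of the MLA sender-side change: when the context TP group is dupHeadFactor times larger than the generation TP group, each duplicated context copy is now matched to a distinct generation DP rank rather than only the first copy ever sending. This is standalone toy code with hypothetical values, not the TensorRT-LLM implementation.

// Standalone illustration of the new MLA sender rule.
#include <cstdio>

// Hypothetical stand-in for the return expression in MLACacheFormatter::needSendCache.
bool needSendCache(int selfTpRankInDpGroup, int dupHeadFactor, int destDPRank)
{
    return selfTpRankInDpGroup % dupHeadFactor == destDPRank;
}

int main()
{
    int const dupHeadFactor = 4; // e.g. context TP 4 per DP group, generation TP 1 per DP group
    for (int selfTpRank = 0; selfTpRank < 4; ++selfTpRank)
    {
        for (int destDPRank = 0; destDPRank < 4; ++destDPRank)
        {
            if (needSendCache(selfTpRank, dupHeadFactor, destDPRank))
            {
                std::printf("ctx tp rank %d sends to gen dp rank %d\n", selfTpRank, destDPRank);
            }
        }
    }
    return 0;
}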
@@ -1457,12 +1457,15 @@ TEST(targetTest, CacheStateNODP)
 
     verifyContext(
         /*contextRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
+
     verifyContext(
         /*contextRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);
+
     verifyContext(
         /*contextRank*/ 2, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
     verifyContext(
         /*contextRank*/ 3, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);
+
     verifyContext(
         /*contextRank*/ 4, /*expectRanks*/ {2}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
     verifyContext(
@@ -1474,7 +1477,6 @@ TEST(targetTest, CacheStateNODP)
 
     contextTP = 2;
     genTP = 4;
-
     verifyContext(
         /*contextRank*/ 0, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectNeedSend*/ true);
     verifyContext(/*contextRank*/ 1, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2,
@@ -1564,13 +1566,13 @@ TEST(targetTest, CacheStateContextDP)
         /*expectNeedSend*/ true);
     verifyContext(
         /*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
-        /*expectNeedSend*/ true);
+        /*expectNeedSend*/ false);
     verifyContext(
         /*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
         /*expectNeedSend*/ false);
     verifyContext(
         /*contextRank*/ 1, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
-        /*expectNeedSend*/ false);
+        /*expectNeedSend*/ true);
     verifyContext(
         /*contextRank*/ 2, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
         /*expectNeedSend*/ false);