Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Commit c7322d95d6: Merge 31f2ecd3cb into 6df2c8a074
@@ -133,8 +133,19 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
if (worldConfig.isTensorParallel())
{
mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
if (worldConfig.isContextParallel())
{
// When CP is enabled, group ranks with same (ppRank, cpRank) to exclude both PP and CP.
auto const tpGroupId = worldConfig.getContextParallelRank()
+ worldConfig.getContextParallelism() * worldConfig.getPipelineParallelRank();
mGroupTensorParaComm
= std::make_shared<CacheTransceiverComm>(mGroupComm->split(tpGroupId, worldConfig.getRank()));
}
else
{
mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
}
}
int kvFactor = 2;
if (cacheManager->getCacheType() == kv_cache_manager::CacheType::kSELFKONLY)
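For intuition, a minimal Python sketch of the split color computed above (tpGroupId = cpRank + CP * ppRank), using the rank layout stated later in this diff (rank = ppRank * (TP * CP) + tpRank * CP + cpRank). The helper below is illustrative only, not the CacheTransceiverComm API:

# Minimal sketch: ranks sharing (ppRank, cpRank) get the same color, so each
# resulting group varies only over the TP dimension.
def tp_group_color(pp_rank: int, cp_rank: int, cp_size: int) -> int:
    return cp_rank + cp_size * pp_rank

pp_size, tp_size, cp_size = 2, 2, 2  # illustrative sizes
groups = {}
for rank in range(pp_size * tp_size * cp_size):
    pp_rank = rank // (tp_size * cp_size)
    tp_rank = (rank // cp_size) % tp_size
    cp_rank = rank % cp_size
    groups.setdefault(tp_group_color(pp_rank, cp_rank, cp_size), []).append(rank)

assert groups == {0: [0, 2], 1: [1, 3], 2: [4, 6], 3: [5, 7]}  # each group differs only in tpRank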
@@ -151,17 +162,18 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
int TPSizeInDPGroup
= mCacheState->getParallelConfig().mTensorParallelism / mCacheState->getParallelConfig().mDPsize;
int DPSize = mCacheState->getParallelConfig().mDPsize;
int TPRankInDPGroup = worldConfig.getTensorParallelRank() % TPSizeInDPGroup;

int DPRank = (worldConfig.getRank() - TPSizeInDPGroup * DPSize * worldConfig.getPipelineParallelRank()
- TPRankInDPGroup)
/ TPSizeInDPGroup;
// <PP,DP,TP>
// DPRank is derived from the tensor parallel rank, which already accounts for CP.
// Layout: rank = ppRank * (TP * CP) + tpRank * CP + cpRank.
// getTensorParallelRank() correctly extracts tpRank regardless of CP.
int DPRank = mCacheState->getParallelConfig().mDPrank;
// <PP,DP,TP,CP>
mGroupDataComm = std::make_shared<CacheTransceiverComm>(mGroupComm->split(DPRank, worldConfig.getRank()));
if (worldConfig.isTensorParallel())
{
// Group ranks with same (ppRank, DPRank) accounting for CP.
mGroupTPInDPComm = std::make_shared<CacheTransceiverComm>(
mGroupComm->split(worldConfig.getRank() / TPSizeInDPGroup, worldConfig.getRank()));
mGroupComm->split(worldConfig.getPipelineParallelRank() * DPSize + DPRank, worldConfig.getRank()));
}
}
bool isMLA = attentionType == executor::kv_cache::CacheState::AttentionType::kMLA;
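A small worked sketch of the layout the comments above describe (illustrative code, not from the repo): with rank = ppRank * (TP * CP) + tpRank * CP + cpRank, the DP rank depends only on tpRank, so CP never enters the DP-rank computation.

def decompose(rank: int, tp_size: int, cp_size: int):
    pp_rank = rank // (tp_size * cp_size)
    tp_rank = (rank // cp_size) % tp_size
    cp_rank = rank % cp_size
    return pp_rank, tp_rank, cp_rank

# Illustrative sizes: TP=4 split into DPsize=2 groups (TPSizeInDPGroup=2), CP=2, PP=1.
tp_size, cp_size, dp_size = 4, 2, 2
tp_size_in_dp_group = tp_size // dp_size
pp_rank, tp_rank, cp_rank = decompose(5, tp_size, cp_size)
assert (pp_rank, tp_rank, cp_rank) == (0, 2, 1)
assert tp_rank // tp_size_in_dp_group == 1  # DP rank follows from tpRank alone, CP-agnostic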
@@ -552,9 +552,10 @@ protected:
mCpSize = genCp;
}

mTpRank = mRankInInstance % mTpSize;
// Rank formula must match targetIRanks: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank.
mCpRank = mRankInInstance % mCpSize;
mTpRank = (mRankInInstance / mCpSize) % mTpSize;
mPpRank = mRankInInstance / (mTpSize * mCpSize);
mCpRank = (mRankInInstance % (mTpSize * mCpSize)) / mTpSize;
mContextRankSize = contextRanks;
mGenRankSize = genRanks;
mContextTpSize = contextTp;
@@ -887,7 +888,16 @@ protected:
auto makeLlmRequestWithDP(SizeType32 length, LlmRequest::RequestIdType requestId, int contextDpRank)
{
constexpr SizeType32 maxNewTokens{1};
texec::Request request{VecTokens(length), maxNewTokens};
auto const tokensPerBlock = mContextCacheState->getModelConfig().mTokensPerBlock;

std::optional<CPMetaData> cpMetaData;
int seqLen = length;
if (mCpSize > 1)
{
cpMetaData.emplace(length, tokensPerBlock, mCpRank, mCpSize);
seqLen = cpMetaData.value().mSeqLenOnThisCPRank;
}
texec::Request request{VecTokens(seqLen, seqLen), maxNewTokens};

auto state = std::make_unique<texec::DataTransceiverState>();
state->setCommState(texec::kv_cache::CommState{*mContextCommState});
@@ -905,7 +915,6 @@ protected:
request.setContextPhaseParams(std::move(stats));
auto llmRequestPtr = std::make_unique<LlmRequest>(requestId, std::move(request));

std::optional<CPMetaData> cpMetaData;
return std::make_unique<WrappedLlmRequest>(std::move(llmRequestPtr), cpMetaData);
}
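CPMetaData's internals are not part of this diff; as a rough illustration of how a per-CP-rank sequence length such as mSeqLenOnThisCPRank could be derived, here is a Python sketch assuming KV blocks are dealt out round-robin across CP ranks (an assumption for illustration, not the actual implementation):

def seq_len_on_cp_rank(length: int, tokens_per_block: int, cp_rank: int, cp_size: int) -> int:
    # Assumed round-robin block ownership; the trailing partial block is trimmed
    # on the rank that owns it.
    num_blocks = (length + tokens_per_block - 1) // tokens_per_block
    blocks_here = len(range(cp_rank, num_blocks, cp_size))
    tokens_here = blocks_here * tokens_per_block
    if cp_rank == (num_blocks - 1) % cp_size and length % tokens_per_block:
        tokens_here -= tokens_per_block - (length % tokens_per_block)
    return tokens_here

# length=60, tokens_per_block=8, cp_size=2 -> 8 blocks: rank 0 holds 32 tokens, rank 1 holds 28.
assert [seq_len_on_cp_rank(60, 8, r, 2) for r in (0, 1)] == [32, 28]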
@@ -1428,6 +1437,27 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase)
{
GTEST_SKIP() << "Temporarily skipping cache transceiver tests with NIXL and MOONCAKE backend for CP.";
}
// Filter request lengths based on CP requirements.
// Each request must have at least one block per CP rank to be valid for CP tests.
std::vector<int> lenList = {60, 30, 60, 10};
if (genCp > 1)
{
std::vector<int> updatedLenList;
for (auto len : lenList)
{
if (len > tokensPerBlock * (genCp - 1))
{
updatedLenList.push_back(len);
}
}
if (updatedLenList.empty())
{
GTEST_SKIP() << "Skipping test because not even one request has one block per genCP rank. tokensPerBlock="
<< tokensPerBlock << ", genCp=" << genCp;
}
lenList = updatedLenList;
}
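As a quick sanity check of the filter above, using tokensPerBlock = 8 from the instantiations below (genCp = 4 is shown purely for illustration):

# len > tokensPerBlock * (genCp - 1) is equivalent to ceil(len / tokensPerBlock) >= genCp,
# i.e. the request spans at least one KV block per CP rank.
tokens_per_block = 8
len_list = [60, 30, 60, 10]
assert [l for l in len_list if l > tokens_per_block * (2 - 1)] == [60, 30, 60, 10]  # genCp = 2
assert [l for l in len_list if l > tokens_per_block * (4 - 1)] == [60, 30, 60]      # genCp = 4 drops 10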

setUpCommunicator(contextTp, contextPp, contextCp, genTp, genPp, genCp, isMLA, contextDP, generationDP);

if (mIsContext || mIsGeneration)
@@ -1438,7 +1468,7 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase)
setUpCacheTransceiver();
std::vector<std::shared_ptr<WrappedLlmRequest>> requests;
int requestId = 0;
for (auto len : {60, 30, 60, 10})
for (auto len : lenList)
{
requests.emplace_back(makeLlmRequestWithDP(len, requestId, requestId % contextTp));
requestId++;
@@ -1814,6 +1844,44 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1WithCPForMLA, AsymmetricalCacheTest,
/*generationDP*/ testing::Values(false),
/*isWindow*/ testing::Values(false), testing::Values(false), testing::Values(0), testing::Values(128)));

// Test cases where there's non-trivial TP and PP on the context side while non-trivial CP & DP on the gen side.
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithCPAndDPForMLA0, AsymmetricalCacheTestWithDP,
testing::Combine(/*contextTp*/ testing::Values(1, 2),
/*contextPp*/ testing::Values(1, 2),
/*contextCp*/ testing::Values(1),
/*genTp*/ testing::Values(2),
/*genPp*/ testing::Values(1),
/*genCp*/ testing::Values(2),
/*numLayers*/ testing::Values(4),
/*numHeads*/ testing::Values(1),
/*sizePerHead*/ testing::Values(4),
/*tokensPerBlock*/ testing::Values(8),
/*dataType*/ testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8),
/*kvFactor*/ testing::Values(1),
/*isMLA*/ testing::Values(true),
/*contextDP*/ testing::Values(false),
/*generationDP*/ testing::Values(true),
/*isWindow*/ testing::Values(false), testing::Values(false), testing::Values(0), testing::Values(128)));

// Test cases where there's non-trivial DP on the context side while non-trivial CP & DP on the gen side.
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithCPAndDPForMLA1, AsymmetricalCacheTestWithDP,
testing::Combine(/*contextTp*/ testing::Values(2, 4),
/*contextPp*/ testing::Values(1),
/*contextCp*/ testing::Values(1),
/*genTp*/ testing::Values(2),
/*genPp*/ testing::Values(1),
/*genCp*/ testing::Values(2),
/*numLayers*/ testing::Values(4),
/*numHeads*/ testing::Values(1),
/*sizePerHead*/ testing::Values(4),
/*tokensPerBlock*/ testing::Values(8),
/*dataType*/ testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8),
/*kvFactor*/ testing::Values(1),
/*isMLA*/ testing::Values(true),
/*contextDP*/ testing::Values(true),
/*generationDP*/ testing::Values(true),
/*isWindow*/ testing::Values(false), testing::Values(false), testing::Values(0), testing::Values(128)));

INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA1, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
@@ -2226,8 +2294,8 @@ TEST(targetTest, CacheStateContextDP)
auto const verifyContext = [&](int contextRank, int generationRank, std::vector<int> const& expectRanks,
int expectPPDomain, int expectTPDomain, bool expectNeedSend)
{
int contextDPRank = contextRank % contextTP;
int generationDPRank = generationRank % genTP;
int contextDPRank = (contextRank % (contextTP * contextCP)) / contextCP;
int generationDPRank = (generationRank % (genTP * genCP)) / genCP;
auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA
: texec::kv_cache::CacheState::AttentionType::kDEFAULT;

@@ -2239,12 +2307,12 @@ TEST(targetTest, CacheStateContextDP)
tokensPerBlock, genTP, genPP, genCP, genAttentionLayerNumPerPP, dataType, attentionType, kvFactor,
genEnableDP, generationDPRank, genTP};

auto const contextTragetInfo
auto const contextTargetInfo
= tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(genCache, contextCache, contextRank);

EXPECT_EQ(expectRanks, contextTragetInfo.mIRanks);
EXPECT_EQ(expectPPDomain, contextTragetInfo.mDomainPPSize);
EXPECT_EQ(expectTPDomain, contextTragetInfo.mDomainTPSize);
EXPECT_EQ(expectRanks, contextTargetInfo.mIRanks);
EXPECT_EQ(expectPPDomain, contextTargetInfo.mDomainPPSize);
EXPECT_EQ(expectTPDomain, contextTargetInfo.mDomainTPSize);
EXPECT_EQ(expectNeedSend, MLACacheFormatter::needSendCache(contextCache, genCache, contextRank));
};

@@ -2330,11 +2398,11 @@ TEST(targetTest, CacheStateContextDP)
contextTP = 1;
genTP = 2;

auto const verfiyGeneration = [&](int contextRank, int generationRank, std::vector<int> const& expectRanks,
auto const verifyGeneration = [&](int contextRank, int generationRank, std::vector<int> const& expectRanks,
int expectPPDomain, int expectTPDomain)
{
int contextDPRank = contextRank % contextTP;
int generationDPRank = generationRank % genTP;
int contextDPRank = (contextRank % (contextTP * contextCP)) / contextCP;
int generationDPRank = (generationRank % (genTP * genCP)) / genCP;
auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA
: texec::kv_cache::CacheState::AttentionType::kDEFAULT;

@@ -2346,17 +2414,17 @@ TEST(targetTest, CacheStateContextDP)
tokensPerBlock, genTP, genPP, genCP, genAttentionLayerNumPerPP, dataType, attentionType, kvFactor,
genEnableDP, generationDPRank, genTP};

auto const contextTragetInfo
auto const contextTargetInfo
= tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(contextCache, genCache, generationRank);

EXPECT_EQ(expectRanks, contextTragetInfo.mIRanks);
EXPECT_EQ(expectPPDomain, contextTragetInfo.mDomainPPSize);
EXPECT_EQ(expectTPDomain, contextTragetInfo.mDomainTPSize);
EXPECT_EQ(expectRanks, contextTargetInfo.mIRanks);
EXPECT_EQ(expectPPDomain, contextTargetInfo.mDomainPPSize);
EXPECT_EQ(expectTPDomain, contextTargetInfo.mDomainTPSize);
};

verfiyGeneration(
verifyGeneration(
/*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
verfiyGeneration(
verifyGeneration(
/*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);

contextTP = 1;
@@ -2366,9 +2434,9 @@ TEST(targetTest, CacheStateContextDP)
contextAttentionLayerNumPerPP = std::vector<SizeType32>(contextPP, numLayers / contextPP);
genAttentionLayerNumPerPP = std::vector<SizeType32>(genPP, numLayers / genPP);

verfiyGeneration(
verifyGeneration(
/*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
verfiyGeneration(
verifyGeneration(
/*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);

genEnableDP = false;
@@ -2381,8 +2449,8 @@ TEST(targetTest, CacheStateContextDP)
contextAttentionLayerNumPerPP = std::vector<SizeType32>(contextPP, numLayers / contextPP);
genAttentionLayerNumPerPP = std::vector<SizeType32>(genPP, numLayers / genPP);

verfiyGeneration(
verifyGeneration(
/*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
verfiyGeneration(
verifyGeneration(
/*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
}

@@ -1547,7 +1547,7 @@ class AutoTuner:
def _merge_cache_data(self, custom_op: str):
cache_data = self.profiling_cache.get_specific_custom_op(custom_op)
merged_cache_data = dict()
all_cache_data = self._dist.tp_allgather(obj=cache_data)
all_cache_data = self._dist.tp_cp_allgather(obj=cache_data)

for data in all_cache_data:
for key, value in data.items():

@@ -1,4 +1,3 @@
import copy
import math
import pickle # nosec B403
from abc import ABC, abstractmethod
@@ -136,6 +135,35 @@ class Distributed(ABC):
obj = self.cp_broadcast(obj, root=root, **kwargs)
return obj

@abstractmethod
def tp_allgather(self, obj):
pass

@abstractmethod
def cp_allgather(self, obj):
pass

def tp_cp_allgather(self, obj):
"""Allgather across both TP and CP dimensions.

First gathers within CP group, then across TP groups, returning
a flattened list with tp_size * cp_size entries.
"""
# Gather across CP dimension.
if self.cp_size > 1:
obj = self.cp_allgather(obj)
else:
obj = [obj] # Wrap to match cp_allgather output format.

# Gather across TP dimension.
if self.tp_size > 1:
obj = self.tp_allgather(obj)
else:
obj = [obj] # Wrap to match tp_allgather output format.

# Flatten: [[cp0, cp1], [cp0, cp1], ...] -> [tp0_cp0, tp0_cp1, tp1_cp0, ...]
return [entry for tp_group in obj for entry in tp_group]
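For reference, a small sketch of the flattening order tp_cp_allgather returns; the string payloads stand in for whatever each rank contributed:

# With tp_size = 2 and cp_size = 2: cp_allgather yields [cp0, cp1] on each rank,
# and tp_allgather of that list yields one entry per TP rank.
gathered = [["tp0_cp0", "tp0_cp1"], ["tp1_cp0", "tp1_cp1"]]
flat = [entry for tp_group in gathered for entry in tp_group]
assert flat == ["tp0_cp0", "tp0_cp1", "tp1_cp0", "tp1_cp1"]  # tp-major, cp-minor ordering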

def safe_broadcast(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
"""
@@ -363,24 +391,9 @@ class MPIDist(Distributed):
def __init__(self, mapping: Mapping):
super().__init__(mapping)
self.create_cp_comm()
# Repurpose CP ranks to TP for Helix so that the right comms are created.
mapping_with_cp = None
if self.mapping.has_cp_helix():
logger.info(
f"[MPIDist::__init__] Repurposing CP ranks to TP for Helix.")
mapping_with_cp = copy.deepcopy(self.mapping)
self.mapping = self.mapping.repurpose_helix_cp_to_tp()

self.create_tp_comm()
self.create_pp_comm()

# Restore the original mapping.
if mapping_with_cp is not None:
logger.info(
f"[MPIDist::__init__] Restoring original mapping undoing Helix manipulation."
)
self.mapping = mapping_with_cp

def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
comm = mpi_comm()
return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
@@ -758,6 +771,22 @@ class TorchDist(Distributed):
device=torch.device("cpu"))
return ret[0]

@log_op
def cp_allgather(self, obj):
if isinstance(obj, torch.Tensor):
output_list = [
torch.empty_like(obj)
for _ in range(self.mapping.cp_group_pg.size())
]
dist.all_gather(output_list, obj, group=self.mapping.cp_group_pg)
return output_list
else:
output_list = [None] * self.mapping.cp_group_pg.size()
dist.all_gather_object(output_list,
obj,
group=self.mapping.cp_group_pg)
return output_list

@log_op
def pp_allgather(self, obj):
if isinstance(obj, torch.Tensor):

@@ -1120,13 +1120,19 @@ class DeepseekV3DecoderLayer(DecoderLayer):
reduce_output=not self.enable_attention_dp
and self.mapping.tp_size > 1)
else:
# When enable_attention_dp is True, TP reduction is skipped since each DP rank
# works on different batch elements. However, with CP > 1, attention is split
# across CP ranks for the SAME batch element, so reduction is still needed
# within the CP group.
needs_tp_reduce = not self.enable_attention_dp and self.mapping.tp_size > 1
needs_cp_reduce = mapping_with_cp is not None and mapping_with_cp.has_cp_helix(
)
self.self_attn = DeepseekV3Attention(
model_config,
layer_idx=layer_idx_for_attention,
aux_stream=aux_stream_dict[AuxStreamType.Attention],
mapping_with_cp=mapping_with_cp,
reduce_output=not self.enable_attention_dp
and self.mapping.tp_size > 1)
reduce_output=needs_tp_reduce or needs_cp_reduce)

self.fusion_config = EagerFusionConfig()
self.enable_fusion = os.environ.get(
@@ -1192,10 +1198,15 @@ class DeepseekV3DecoderLayer(DecoderLayer):
eps=config.rms_norm_eps,
dtype=config.torch_dtype)

# When enable_attention_dp is True, we normally skip attention all-reduce since each
# DP rank works on different batch elements. However, with CP > 1, attention is split
# across CP ranks for the SAME batch element, so all-reduce is still needed.
has_cp = mapping_with_cp is not None and mapping_with_cp.cp_size > 1
can_skip_for_attention_dp = self.enable_attention_dp and not has_cp
self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
or self.fusion_config.PRE_MLP_FUSION
or self.mapping.tp_size == 1
or self.enable_attention_dp)
or can_skip_for_attention_dp)
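The two comments above reduce to a small decision rule; a minimal sketch with illustrative flag names that mirror the hunk (not a real API):

def attention_needs_reduce(enable_attention_dp: bool, tp_size: int, has_cp_helix: bool) -> bool:
    needs_tp_reduce = not enable_attention_dp and tp_size > 1  # classic TP all-reduce
    needs_cp_reduce = has_cp_helix                             # CP splits the same sequence across ranks
    return needs_tp_reduce or needs_cp_reduce

assert attention_needs_reduce(True, 4, False) is False  # pure attention DP: reduction can be skipped
assert attention_needs_reduce(True, 4, True) is True    # attention DP + Helix CP: still reduce within the CP group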

self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,

@@ -814,7 +814,9 @@ class MLA(nn.Module):
tp_size = self.mapping.tp_size
pp_size = self.mapping.pp_size
cp_size = self.mapping.cp_size
dp_size = 1
if self.mapping.enable_attention_dp:
dp_size = tp_size
tp_size = 1
if self.mapping.has_cp_ulysses():
raise NotImplementedError("MLA doesn't support CP Ulysses yet")
@@ -823,9 +825,9 @@ class MLA(nn.Module):
), f"CP type must be HELIX for MLA, but got {self.mapping.cp_config['cp_type']}."

mapping = Mapping(
world_size=tp_size * pp_size * cp_size,
world_size=pp_size * dp_size * tp_size * cp_size,
tp_size=tp_size,
pp_size=pp_size,
pp_size=pp_size * dp_size,
cp_size=cp_size,
cp_config=self.mapping.cp_config,
rank=self.mapping.rank,
@@ -924,9 +926,9 @@ class MLA(nn.Module):
)

mapping_o = Mapping(
world_size=tp_size * pp_size * cp_size,
world_size=pp_size * dp_size * tp_size * cp_size,
tp_size=tp_size * cp_size,
pp_size=pp_size,
pp_size=pp_size * dp_size,
cp_size=1,
rank=self.mapping.rank,
gpus_per_node=self.mapping.gpus_per_node,
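A concrete arithmetic check of the Mapping sizes above, with illustrative numbers (pp = 1, cp = 2, original tp = 4 under attention DP):

pp_size, tp_size, cp_size = 1, 4, 2
dp_size = 1
enable_attention_dp = True
if enable_attention_dp:
    dp_size, tp_size = tp_size, 1                   # fold TP into DP, as in the hunk above
world_size = pp_size * dp_size * tp_size * cp_size  # 1 * 4 * 1 * 2 = 8
assert world_size == 8
# Attention mapping: tp=1, pp=pp*dp=4, cp=2 covers 1 * 4 * 2 = 8 ranks.
assert tp_size * (pp_size * dp_size) * cp_size == world_size
# mapping_o folds CP into TP: tp=tp*cp=2, pp=pp*dp=4, cp=1 covers the same 8 ranks.
assert (tp_size * cp_size) * (pp_size * dp_size) * 1 == world_size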
@@ -247,7 +247,7 @@ class CUDAGraphRunner:
can_run_cuda_graph = batch.can_run_cuda_graph
batch_size = batch.batch_size
if self.enabled and self.config.enable_attention_dp and self.config.mapping.tp_size > 1:
all_can_graph_batch = self.config.dist.tp_allgather(
all_can_graph_batch = self.config.dist.tp_cp_allgather(
[can_run_cuda_graph, batch_size])
is_all_gen_only = all(all_can_graph[0]
for all_can_graph in all_can_graph_batch)
@@ -409,7 +409,7 @@ class CUDAGraphRunner:
new_batch_size = batch_size

if self.enabled and self.config.enable_attention_dp and self.config.mapping.tp_size > 1:
graph_batch_size = self.config.dist.tp_allgather(
graph_batch_size = self.config.dist.tp_cp_allgather(
[can_run_cuda_graph, batch_size])
all_can_graph = all(graph_batch[0]
for graph_batch in graph_batch_size)

@@ -369,13 +369,22 @@ class ExecutorRequestQueue:
def _fetch_new_requests_attention_dp(
self, activate_requests: List[LlmRequest]) -> List[LlmRequest]:
"""Handle attention DP request fetching with load balancing."""
# Get active request counts across all ranks
# Get active request counts across all ranks.
all_ranks_num_active_requests = []
all_ranks_num_active_tokens = []
num_active_tokens = sum(
[req.py_orig_prompt_len for req in activate_requests])

if self.dist.has_cp_helix:
num_active_tokens = sum(
[req.total_input_len_cp for req in activate_requests])
else:
num_active_tokens = sum(
[req.py_orig_prompt_len for req in activate_requests])

# Note: We use tp_allgather even for CP assuming that all CP ranks with the
# same dp_rank have the same num_active_tokens and num_active_requests.
responses_list = self.dist.tp_allgather(
[len(activate_requests), num_active_tokens])

for num_active_requests, num_active_tokens in responses_list:
all_ranks_num_active_requests.append(num_active_requests)
all_ranks_num_active_tokens.append(num_active_tokens)

@@ -1324,7 +1324,7 @@ class PyTorchModelEngine(ModelEngine):

def _get_all_rank_num_tokens(self, attn_metadata: AttentionMetadata):
if self.enable_attention_dp:
return list(self.dist.tp_allgather(attn_metadata.num_tokens))
return list(self.dist.tp_cp_allgather(attn_metadata.num_tokens))
return None

def _get_all_rank_ctx_requests(self, num_ctx_requests: int):
@@ -1369,7 +1369,7 @@ class PyTorchModelEngine(ModelEngine):
max(attn_all_rank_num_tokens)
<= max_captured_num_tokens)
all_ranks_can_run_piecewise_cuda_graph = list(
self.dist.tp_allgather(can_run_piecewise_cuda_graph))
self.dist.tp_cp_allgather(can_run_piecewise_cuda_graph))
if all(all_ranks_can_run_piecewise_cuda_graph):
padded_num_tokens = get_padded_piecewise_tokens(
max(attn_all_rank_num_tokens))
@@ -1536,7 +1536,7 @@ class PyTorchModelEngine(ModelEngine):
# Handle distributed spec metadata
if enable_attention_dp:
sequence_lengths = spec_metadata.seq_lens
all_rank_num_tokens = self.dist.tp_allgather(
all_rank_num_tokens = self.dist.tp_cp_allgather(
[spec_metadata.num_tokens,
len(sequence_lengths)])
spec_metadata.all_rank_num_tokens = [
@@ -2691,7 +2691,7 @@ class PyTorchModelEngine(ModelEngine):
inputs['spec_metadata'] = spec_metadata

if self.enable_attention_dp:
all_rank_num_tokens = self.dist.tp_allgather(
all_rank_num_tokens = self.dist.tp_cp_allgather(
[spec_metadata.num_tokens,
len(sequence_lengths)])

@@ -2856,7 +2856,7 @@ class PyTorchModelEngine(ModelEngine):
# support attention dp
if self.enable_attention_dp:
if spec_metadata is not None:
all_rank_num_tokens = self.dist.tp_allgather([
all_rank_num_tokens = self.dist.tp_cp_allgather([
attn_metadata.num_tokens, spec_metadata.num_tokens,
len(sequence_lengths)
])
@@ -2871,7 +2871,7 @@ class PyTorchModelEngine(ModelEngine):
spec_metadata.all_rank_num_tokens = spec_all_rank_num_tokens
spec_metadata.all_rank_num_seqs = all_rank_num_seqs
else:
all_rank_num_tokens = self.dist.tp_allgather(
all_rank_num_tokens = self.dist.tp_cp_allgather(
attn_metadata.num_tokens)
attn_metadata.all_rank_num_tokens = all_rank_num_tokens

@@ -1249,7 +1249,8 @@ class PyExecutor:
def _can_queue(self, scheduled_batch):

if self.enable_attention_dp:
tp_batch_sizes = self.dist.tp_allgather(scheduled_batch.batch_size)
tp_batch_sizes = self.dist.tp_cp_allgather(
scheduled_batch.batch_size)
can_queue = 0 not in tp_batch_sizes
else:
can_queue = scheduled_batch.batch_size > 0
@@ -1597,7 +1598,7 @@ class PyExecutor:
if self.enable_attention_dp:
local_can_forward = self.executor_request_queue.num_fetch_requests + \
len(scheduled_batch.generation_requests) >= self.benchmark_req_queues_size
all_can_forward = self.dist.tp_allgather(
all_can_forward = self.dist.tp_cp_allgather(
local_can_forward)
if all(all_can_forward):
can_forward = True
@@ -1970,6 +1971,8 @@ class PyExecutor:
num_scheduled_tokens = sum(
[len(req.get_tokens(0))
for req in context_requests]) + num_scheduled_generation_requests
# Note: We use tp_allgather instead of tp_cp_allgather because we want to
# balance the requests across DP ranks, not CP ranks within those DP ranks.
responses_list = self.dist.tp_allgather([
num_scheduled_context_requests, num_scheduled_generation_requests,
num_scheduled_tokens

@@ -871,10 +871,16 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@skip_pre_blackwell
@pytest.mark.skip_less_device(8)
@pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 1, 4), (1, 2, 2),
(2, 1, 2)],
ids=["pp1tp1cp4", "pp1tp2cp2", "pp2tp1cp2"])
@pytest.mark.parametrize(
"gen_pp,gen_tp,gen_cp,enable_attention_dp", [
(1, 1, 4, False),
(1, 2, 2, False),
(1, 2, 2, True),
(2, 1, 2, False),
],
ids=["pp1tp1cp4", "pp1tp2cp2", "pp1dp2cp2", "pp2tp1cp2"])
@pytest.mark.parametrize("cuda_graph_config", [
None,
{
@@ -892,7 +898,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
])
@pytest.mark.parametrize("comms_medium", ["fifo", "nccl"])
def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
gen_pp, gen_tp, gen_cp):
gen_pp, gen_tp, gen_cp, enable_attention_dp):
use_nccl_for_alltoall = comms_medium == "nccl"
gen_ep = gen_tp * gen_cp
kv_cache_config = {
@@ -932,6 +938,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
"backend": "UCX",
"max_tokens_in_buffer": 8192,
},
"enable_attention_dp": enable_attention_dp,
}
disaggregated_server_config = {
"hostname": "localhost",

@@ -281,11 +281,13 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp1tp1cp4]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp1dp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1dp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp1dp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1dp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]

@@ -74,6 +74,7 @@ l0_dgx_b200:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1dp2cp2] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (60)
@@ -104,6 +105,7 @@ l0_dgx_b200:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1dp2cp2] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (60)

@@ -360,20 +360,6 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_vswa_reuse_4gpus[two_m
accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=False] SKIP (https://nvbugs/5785206)
accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_prequantized SKIP (https://nvbugs/5785465)
accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_fp8 SKIP (https://nvbugs/5785485)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp1tp2cp2] SKIP (https://nvbugs/5787836)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] SKIP (https://nvbugs/5787836)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp1tp2cp2] SKIP (https://nvbugs/5787836)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] SKIP (https://nvbugs/5787836)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp1tp1cp4] SKIP (https://nvbugs/5787836)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] SKIP (https://nvbugs/5787836)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp1tp1cp4] SKIP (https://nvbugs/5787836)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] SKIP (https://nvbugs/5787836)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp2tp1cp2] SKIP (https://nvbugs/5787836)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding-pp2tp1cp2] SKIP (https://nvbugs/5787836)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] SKIP (https://nvbugs/5787836)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp2tp1cp2] SKIP (https://nvbugs/5787836)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding-pp2tp1cp2] SKIP (https://nvbugs/5787836)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] SKIP (https://nvbugs/5787836)
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5769815)
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=True-torch_compile=False] SKIP (https://nvbugs/5787892)
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=False-torch_compile=False] SKIP (https://nvbugs/5787892)