[None][chore] Reduce attention module repeated warnings. (#11335)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Yuxian Qiu 2026-02-10 08:58:21 +08:00 committed by GitHub
parent fe4c690b6c
commit af68c29d3d
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 13 deletions

@@ -134,7 +134,7 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
if (worldConfig.isTensorParallel() || worldConfig.isContextParallel())
{
mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
-mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
+mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getRank()));
}
int kvFactor = 2;
if (cacheManager->getCacheType() == kv_cache_manager::CacheType::kSELFKONLY)
@@ -148,19 +148,19 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
if (mCacheState->getParallelConfig().mEnableAttentionDP)
{
-int DPSize = mCacheState->getParallelConfig().mDPsize;
+int dpSize = mCacheState->getParallelConfig().mDPsize;
-// DPRank is derived from the tensor parallel rank, which already accounts for CP.
+// dpRank is derived from the tensor parallel rank, which already accounts for CP.
// Layout: rank = ppRank * (TP * CP) + tpRank * CP + cpRank.
// getTensorParallelRank() correctly extracts tpRank regardless of CP.
-int DPRank = mCacheState->getParallelConfig().mDPrank;
+int dpRank = mCacheState->getParallelConfig().mDPrank;
// <PP,DP,TP,CP>
-mGroupDataComm = std::make_shared<CacheTransceiverComm>(mGroupComm->split(DPRank, worldConfig.getRank()));
+mGroupDataComm = std::make_shared<CacheTransceiverComm>(mGroupComm->split(dpRank, worldConfig.getRank()));
if (worldConfig.isTensorParallel() || worldConfig.isContextParallel())
{
-// Group ranks with same (ppRank, DPRank) accounting for CP.
+// Group ranks with same (ppRank, dpRank) accounting for CP.
mGroupTPInDPComm = std::make_shared<CacheTransceiverComm>(
-mGroupComm->split(worldConfig.getPipelineParallelRank() * DPSize + DPRank, worldConfig.getRank()));
+mGroupComm->split(worldConfig.getPipelineParallelRank() * dpSize + dpRank, worldConfig.getRank()));
}
}
bool isMLA = attentionType == executor::kv_cache::CacheState::AttentionType::kMLA;
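
For reference, the rank layout in the comment above (rank = ppRank * (TP * CP) + tpRank * CP + cpRank) and the split color ppRank * dpSize + dpRank can be sanity-checked with a small standalone Python sketch. This is an illustration only, not repository code: the function names are invented, and it assumes dpRank equals the tensor-parallel rank when attention DP is enabled, as the comment above states.

# Illustrative sketch (hypothetical names, not repository code).
# Assumes the layout: rank = ppRank * (TP * CP) + tpRank * CP + cpRank.

def decompose_rank(rank: int, tp: int, cp: int):
    """Return (ppRank, tpRank, cpRank) under the assumed layout."""
    cp_rank = rank % cp
    tp_rank = (rank // cp) % tp
    pp_rank = rank // (tp * cp)
    return pp_rank, tp_rank, cp_rank

def tp_in_dp_color(rank: int, tp: int, cp: int, dp_size: int) -> int:
    """Color shared by ranks with the same (ppRank, dpRank), mirroring
    split(ppRank * dpSize + dpRank, rank) above; assumes dpRank == tpRank."""
    pp_rank, tp_rank, _ = decompose_rank(rank, tp, cp)
    return pp_rank * dp_size + tp_rank

if __name__ == "__main__":
    tp, cp, pp = 4, 2, 2
    for rank in range(pp * tp * cp):
        print(rank, decompose_rank(rank, tp, cp),
              tp_in_dp_color(rank, tp, cp, dp_size=tp))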

@@ -336,13 +336,14 @@ class Attention(nn.Module):
key="sparse_attention_config")
if config.sparse_attention_config.algorithm == "rocket":
logger.warning("disable rope_fusion for RocketKV.")
logger.warning_once("disable rope_fusion for RocketKV.",
key="disable_rope_fusion_for_rocketkv")
self.rope_fusion = False
if self.rope_fusion and not attn_cls.support_fused_rope():
-logger.warning(
-"rope_fusion is true but the attention backend does not support it. Will disable rope_fusion."
-)
+logger.warning_once(
+"rope_fusion is true but the attention backend does not support it. Will disable rope_fusion.",
+key="disable_rope_fusion_for_non_supported_backend")
self.rope_fusion = False
# If rope_fusion is not specified, enable if the attention backend supports it.
if self.rope_fusion is None:
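
The context above (unchanged by this commit) resolves rope_fusion from a tri-state value; a simplified Python restatement follows. It is only a sketch with invented names (resolve_rope_fusion and its arguments), not the Attention module's actual code.

# Hypothetical, simplified restatement of the rope_fusion resolution shown above.
from typing import Optional

def resolve_rope_fusion(rope_fusion: Optional[bool],
                        backend_supports_fused_rope: bool,
                        uses_rocket_sparse_attention: bool) -> bool:
    """Return the effective rope_fusion flag."""
    if uses_rocket_sparse_attention:
        # RocketKV forces fused RoPE off (warned once per process).
        rope_fusion = False
    if rope_fusion and not backend_supports_fused_rope:
        # Explicit True is downgraded when the backend cannot fuse RoPE.
        rope_fusion = False
    if rope_fusion is None:
        # Unspecified: enable only if the backend supports it.
        rope_fusion = backend_supports_fused_rope
    return rope_fusion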
@@ -824,8 +825,9 @@ class MLA(nn.Module):
# tensor parallel
config = config or ModelConfig()
if mapping_with_cp is not None:
-logger.warning(
-"[MLA::__init__] Overriding mapping with CP detected.")
+logger.warning_once(
+"[MLA::__init__] Overriding mapping with CP detected.",
+key="mla_init_mapping_with_cp")
self.mapping = mapping_with_cp
else:
self.mapping = config.mapping
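
Both files now call logger.warning_once(msg, key=...) instead of logger.warning(msg), so a message tied to a key is emitted once per process rather than once per layer or module instance. A minimal Python sketch of that dedup-by-key idea, assuming a plain logging-based logger (DedupLogger and its internals are illustrative, not TensorRT-LLM's logger implementation):

import logging

class DedupLogger:
    """Minimal illustration of a keyed warning_once: each key warns at most once."""

    def __init__(self, name: str = "example"):
        self._logger = logging.getLogger(name)
        self._seen_keys: set[str] = set()

    def warning(self, msg: str) -> None:
        self._logger.warning(msg)

    def warning_once(self, msg: str, *, key: str) -> None:
        # Repeated calls with the same key (e.g. one per attention layer)
        # log only on the first occurrence.
        if key in self._seen_keys:
            return
        self._seen_keys.add(key)
        self._logger.warning(msg)

logger = DedupLogger()
for _ in range(32):  # e.g. constructing 32 attention layers
    logger.warning_once("disable rope_fusion for RocketKV.",
                        key="disable_rope_fusion_for_rocketkv")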