Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-16 07:53:55 +08:00
[None][chore] Reduce attention module repeated warnings. (#11335)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
This commit is contained in:
parent fe4c690b6c
commit af68c29d3d
@@ -134,7 +134,7 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
     if (worldConfig.isTensorParallel() || worldConfig.isContextParallel())
     {
         mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
-            mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
+            mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getRank()));
     }
     int kvFactor = 2;
     if (cacheManager->getCacheType() == kv_cache_manager::CacheType::kSELFKONLY)
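Note on the hunk above: `split(color, key)` presumably follows MPI_Comm_split-style semantics, where ranks passing the same color land in the same sub-communicator and are ordered by key. The sketch below is a minimal, self-contained simulation (not the CacheTransceiverComm API; the PP/TP/CP sizes are invented) showing why passing the global rank as the key still produces the same per-pipeline-stage group, while keeping the key unique once context parallelism is enabled.

```python
# Minimal simulation of MPI_Comm_split-style grouping (illustrative only;
# not the actual CacheTransceiverComm implementation).
from collections import defaultdict


def simulate_split(colors_and_keys):
    """colors_and_keys: list of (global_rank, color, key).
    Returns {color: [global_rank, ...]} with members ordered by key."""
    groups = defaultdict(list)
    for rank, color, key in colors_and_keys:
        groups[color].append((key, rank))
    return {c: [r for _, r in sorted(members)] for c, members in groups.items()}


# Hypothetical layout: rank = ppRank * (TP * CP) + tpRank * CP + cpRank
PP, TP, CP = 2, 2, 2
entries = []
for pp in range(PP):
    for tp in range(TP):
        for cp in range(CP):
            rank = pp * (TP * CP) + tp * CP + cp
            # color = ppRank, key = global rank (the behavior after this change)
            entries.append((rank, pp, rank))

print(simulate_split(entries))
# {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]} -- each pipeline stage's TP*CP ranks,
# ordered consistently; with key=tpRank, CP peers would share the same key.
```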
@@ -148,19 +148,19 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa

     if (mCacheState->getParallelConfig().mEnableAttentionDP)
     {
-        int DPSize = mCacheState->getParallelConfig().mDPsize;
+        int dpSize = mCacheState->getParallelConfig().mDPsize;

-        // DPRank is derived from the tensor parallel rank, which already accounts for CP.
+        // dpRank is derived from the tensor parallel rank, which already accounts for CP.
         // Layout: rank = ppRank * (TP * CP) + tpRank * CP + cpRank.
         // getTensorParallelRank() correctly extracts tpRank regardless of CP.
-        int DPRank = mCacheState->getParallelConfig().mDPrank;
+        int dpRank = mCacheState->getParallelConfig().mDPrank;
         // <PP,DP,TP,CP>
-        mGroupDataComm = std::make_shared<CacheTransceiverComm>(mGroupComm->split(DPRank, worldConfig.getRank()));
+        mGroupDataComm = std::make_shared<CacheTransceiverComm>(mGroupComm->split(dpRank, worldConfig.getRank()));
         if (worldConfig.isTensorParallel() || worldConfig.isContextParallel())
         {
-            // Group ranks with same (ppRank, DPRank) accounting for CP.
+            // Group ranks with same (ppRank, dpRank) accounting for CP.
             mGroupTPInDPComm = std::make_shared<CacheTransceiverComm>(
-                mGroupComm->split(worldConfig.getPipelineParallelRank() * DPSize + DPRank, worldConfig.getRank()));
+                mGroupComm->split(worldConfig.getPipelineParallelRank() * dpSize + dpRank, worldConfig.getRank()));
         }
     }
     bool isMLA = attentionType == executor::kv_cache::CacheState::AttentionType::kMLA;
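The layout comment above (rank = ppRank * (TP * CP) + tpRank * CP + cpRank) and the ppRank * dpSize + dpRank coloring can be checked with a small worked example. This is only a sketch: it assumes dpRank simply tracks the TP rank (the comment says it is "derived from the tensor parallel rank"), and the PP/TP/CP sizes are invented for illustration.

```python
# Worked example of the rank layout and the attention-DP split color
# (illustrative; sizes and the dp_rank = tp assumption are not from the source).
PP, TP, CP = 2, 4, 2
dp_size = TP  # assumption: with attention DP enabled, dpSize equals the TP size
for pp in range(PP):
    for tp in range(TP):
        for cp in range(CP):
            rank = pp * (TP * CP) + tp * CP + cp   # layout from the comment above
            dp_rank = tp                           # derived from the TP rank, CP-agnostic
            color = pp * dp_size + dp_rank         # mGroupTPInDPComm coloring
            print(f"rank={rank:2d} pp={pp} tp={tp} cp={cp} color={color}")
# Ranks sharing (pp, dp_rank) -- i.e. the CP peers of one TP slice -- get the
# same color and therefore end up in the same mGroupTPInDPComm group.
```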
@@ -336,13 +336,14 @@ class Attention(nn.Module):
                 key="sparse_attention_config")

             if config.sparse_attention_config.algorithm == "rocket":
-                logger.warning("disable rope_fusion for RocketKV.")
+                logger.warning_once("disable rope_fusion for RocketKV.",
+                                    key="disable_rope_fusion_for_rocketkv")
                 self.rope_fusion = False

         if self.rope_fusion and not attn_cls.support_fused_rope():
-            logger.warning(
-                "rope_fusion is true but the attention backend does not support it. Will disable rope_fusion."
-            )
+            logger.warning_once(
+                "rope_fusion is true but the attention backend does not support it. Will disable rope_fusion.",
+                key="disable_rope_fusion_for_non_supported_backend")
             self.rope_fusion = False
         # If rope_fusion is not specified, enable if the attention backend supports it.
         if self.rope_fusion is None:
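The switch from logger.warning to logger.warning_once with an explicit key is what actually reduces the repeated warnings this commit targets: a message is emitted only the first time its key is seen, instead of once per module instance. The sketch below is a hypothetical implementation of that pattern, not the real tensorrt_llm logger, whose signature and details may differ.

```python
# Hypothetical sketch of a "warn once per key" helper; the actual
# tensorrt_llm logger implementation may differ.
import logging


class OnceLogger:
    def __init__(self, name="demo"):
        self._logger = logging.getLogger(name)
        self._seen_keys = set()

    def warning(self, msg):
        self._logger.warning(msg)

    def warning_once(self, msg, *, key):
        # Emit the warning only the first time this key is used.
        if key in self._seen_keys:
            return
        self._seen_keys.add(key)
        self._logger.warning(msg)


logger = OnceLogger()
for _ in range(3):  # e.g. one Attention module constructed per layer
    logger.warning_once("disable rope_fusion for RocketKV.",
                        key="disable_rope_fusion_for_rocketkv")
# Only one warning is printed instead of one per layer.
```

Each call site in the diff uses a distinct key, so suppressing one repeated warning does not hide a different one.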
@@ -824,8 +825,9 @@ class MLA(nn.Module):
         # tensor parallel
         config = config or ModelConfig()
         if mapping_with_cp is not None:
-            logger.warning(
-                "[MLA::__init__] Overriding mapping with CP detected.")
+            logger.warning_once(
+                "[MLA::__init__] Overriding mapping with CP detected.",
+                key="mla_init_mapping_with_cp")
             self.mapping = mapping_with_cp
         else:
             self.mapping = config.mapping