diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
index e12cefca14..416f6ccb1e 100644
--- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
+++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
@@ -134,7 +134,7 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
     if (worldConfig.isTensorParallel() || worldConfig.isContextParallel())
     {
         mGroupTensorParaComm = std::make_shared(
-            mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
+            mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getRank()));
     }
     int kvFactor = 2;
     if (cacheManager->getCacheType() == kv_cache_manager::CacheType::kSELFKONLY)
@@ -148,19 +148,19 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
     if (mCacheState->getParallelConfig().mEnableAttentionDP)
     {
-        int DPSize = mCacheState->getParallelConfig().mDPsize;
+        int dpSize = mCacheState->getParallelConfig().mDPsize;
 
-        // DPRank is derived from the tensor parallel rank, which already accounts for CP.
+        // dpRank is derived from the tensor parallel rank, which already accounts for CP.
         // Layout: rank = ppRank * (TP * CP) + tpRank * CP + cpRank.
         // getTensorParallelRank() correctly extracts tpRank regardless of CP.
-        int DPRank = mCacheState->getParallelConfig().mDPrank;
+        int dpRank = mCacheState->getParallelConfig().mDPrank;
         //
-        mGroupDataComm = std::make_shared(mGroupComm->split(DPRank, worldConfig.getRank()));
+        mGroupDataComm = std::make_shared(mGroupComm->split(dpRank, worldConfig.getRank()));
         if (worldConfig.isTensorParallel() || worldConfig.isContextParallel())
         {
-            // Group ranks with same (ppRank, DPRank) accounting for CP.
+            // Group ranks with same (ppRank, dpRank) accounting for CP.
             mGroupTPInDPComm = std::make_shared(
-                mGroupComm->split(worldConfig.getPipelineParallelRank() * DPSize + DPRank, worldConfig.getRank()));
+                mGroupComm->split(worldConfig.getPipelineParallelRank() * dpSize + dpRank, worldConfig.getRank()));
         }
     }
 
     bool isMLA = attentionType == executor::kv_cache::CacheState::AttentionType::kMLA;
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index c0739b7c69..f944f50126 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -336,13 +336,14 @@ class Attention(nn.Module):
                 key="sparse_attention_config")
             if config.sparse_attention_config.algorithm == "rocket":
-                logger.warning("disable rope_fusion for RocketKV.")
+                logger.warning_once("disable rope_fusion for RocketKV.",
+                                    key="disable_rope_fusion_for_rocketkv")
                 self.rope_fusion = False
 
         if self.rope_fusion and not attn_cls.support_fused_rope():
-            logger.warning(
-                "rope_fusion is true but the attention backend does not support it. Will disable rope_fusion."
-            )
+            logger.warning_once(
+                "rope_fusion is true but the attention backend does not support it. Will disable rope_fusion.",
+                key="disable_rope_fusion_for_non_supported_backend")
             self.rope_fusion = False
 
         # If rope_fusion is not specified, enable if the attention backend supports it.
         if self.rope_fusion is None:
@@ -824,8 +825,9 @@ class MLA(nn.Module):
         # tensor parallel
         config = config or ModelConfig()
         if mapping_with_cp is not None:
-            logger.warning(
-                "[MLA::__init__] Overriding mapping with CP detected.")
+            logger.warning_once(
+                "[MLA::__init__] Overriding mapping with CP detected.",
+                key="mla_init_mapping_with_cp")
             self.mapping = mapping_with_cp
         else:
             self.mapping = config.mapping
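Note on the split colors in the `cacheTransceiver.cpp` hunks: the comments describe the rank layout `rank = ppRank * (TP * CP) + tpRank * CP + cpRank`, and the colors passed to `split()` are `ppRank` for the tensor/context-parallel group and `ppRank * dpSize + dpRank` for the TP-in-DP group. The sketch below is illustrative only and is not code from this repo: it decomposes global ranks with that formula and groups them by those colors, assuming (as the comment states) that the attention-DP rank follows the tensor-parallel rank. All helper names and the example configuration are assumptions.

```python
# Illustrative sketch (not repo code): reproduce the grouping implied by the
# split() colors above, under the layout rank = ppRank*(TP*CP) + tpRank*CP + cpRank.
from dataclasses import dataclass


@dataclass
class RankCoords:
    pp_rank: int
    tp_rank: int
    cp_rank: int


def decompose(rank: int, tp: int, cp: int) -> RankCoords:
    """Invert rank = ppRank * (TP * CP) + tpRank * CP + cpRank."""
    pp_rank, rem = divmod(rank, tp * cp)
    tp_rank, cp_rank = divmod(rem, cp)
    return RankCoords(pp_rank, tp_rank, cp_rank)


def split_colors(world_size: int, tp: int, cp: int, dp_size: int):
    """Group global ranks by the same colors the two split() calls use."""
    tp_cp_groups = {}
    tp_in_dp_groups = {}
    for rank in range(world_size):
        coords = decompose(rank, tp, cp)
        # Assumption (per the diff comment): the attention-DP rank follows the TP rank.
        dp_rank = coords.tp_rank
        # Color for the tensor/context-parallel communicator: the PP stage.
        tp_cp_groups.setdefault(coords.pp_rank, []).append(rank)
        # Color for the TP-in-DP communicator: ppRank * dpSize + dpRank.
        tp_in_dp_groups.setdefault(coords.pp_rank * dp_size + dp_rank, []).append(rank)
    return tp_cp_groups, tp_in_dp_groups


if __name__ == "__main__":
    # Hypothetical configuration: PP=2, TP=2, CP=2 -> 8 ranks; dpSize assumed equal to TP.
    tp_cp, tp_in_dp = split_colors(world_size=8, tp=2, cp=2, dp_size=2)
    print(tp_cp)     # {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]}
    print(tp_in_dp)  # {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7]}
```

With this hypothetical PP=2/TP=2/CP=2 setup, each PP stage forms one TP/CP group and each (ppRank, dpRank) pair forms a two-rank TP-in-DP group, which matches the grouping the comments describe when CP > 1.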
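The `attention.py` hunks switch repeated warnings to `logger.warning_once(..., key=...)`, since these call sites run once per layer during model construction. Below is a minimal sketch of the key-based deduplication idea; it is not the `tensorrt_llm` logger implementation, and the helper name and storage are assumptions.

```python
# Minimal sketch of key-based "warn once" deduplication (names are assumptions;
# not the actual tensorrt_llm logger).
import logging

_seen_warning_keys = set()
_logger = logging.getLogger("example")


def warning_once(msg: str, *, key: str) -> None:
    """Log a warning only the first time the given key is seen.

    Keying on an explicit string (rather than the message text) lets call
    sites with dynamic messages still deduplicate per logical event."""
    if key in _seen_warning_keys:
        return
    _seen_warning_keys.add(key)
    _logger.warning(msg)


if __name__ == "__main__":
    logging.basicConfig()
    # Constructing many Attention/MLA layers hits the same call site repeatedly,
    # but only the first call actually logs.
    for _ in range(32):
        warning_once("disable rope_fusion for RocketKV.",
                     key="disable_rope_fusion_for_rocketkv")
```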