From b60846b47d55debbbfa5c41bc6bee545a26d877c Mon Sep 17 00:00:00 2001 From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> Date: Sat, 24 May 2025 08:36:17 +0800 Subject: [PATCH] fix datatype check (#4606) Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> --- cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp | 6 ++++++ cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp | 3 ++- cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp | 4 ++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index 202cd7e3ad..a9b13f5669 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -599,6 +599,11 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest, [[nodiscard]] bool CacheFormatter::inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const { + if (selfConfig.getDataType() != destConfig.getDataType()) + { + return false; + } + std::unordered_set setVecSelf{ selfConfig.getModelConfig().mNbKvHeadsPerLayer.begin(), selfConfig.getModelConfig().mNbKvHeadsPerLayer.end()}; @@ -618,6 +623,7 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest, { return false; } + std::unordered_set setVecDest{ destConfig.getModelConfig().mNbKvHeadsPerLayer.begin(), destConfig.getModelConfig().mNbKvHeadsPerLayer.end()}; diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp index fcd96cb587..f13ecb4f02 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp @@ -67,7 +67,8 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager, auto requestId = info.getRequestId(); TLLM_CHECK_WITH_INFO( mFormatter->inquireSupport(mSelfState.getCacheState().value(), info.getTransState().getCacheState().value()), - "Disagg server does not currently support these cacheState."); + "Disagg server does not currently support these cacheState, please check the cacheState of the context and gen " + "executors"); auto peerRelativeRanks = executor::kv_cache::targetIRanks(info.getTransState().getCacheState().value(), mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx()) .mIRanks; diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index b7d44414f2..771e4c291c 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -469,6 +469,10 @@ void MLACacheFormatter::formatInput(LlmRequest const& llmRequest, [[nodiscard]] bool MLACacheFormatter::inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const { + if (selfConfig.getDataType() != destConfig.getDataType()) + { + return false; + } if (selfConfig.getAttentionConfig().mAttentionType != CacheState::AttentionType::kMLA || destConfig.getAttentionConfig().mAttentionType != CacheState::AttentionType::kMLA) {