fix datatype check (#4606)

Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
This commit is contained in:
Chuang Zhu 2025-05-24 08:36:17 +08:00 committed by GitHub
parent 20c15fc04f
commit b60846b47d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 12 additions and 1 deletion

View File

@@ -599,6 +599,11 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest,
[[nodiscard]] bool CacheFormatter::inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const
{
if (selfConfig.getDataType() != destConfig.getDataType())
{
return false;
}
std::unordered_set<SizeType32> setVecSelf{
selfConfig.getModelConfig().mNbKvHeadsPerLayer.begin(), selfConfig.getModelConfig().mNbKvHeadsPerLayer.end()};
@@ -618,6 +623,7 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest,
{
return false;
}
std::unordered_set<int> setVecDest{
destConfig.getModelConfig().mNbKvHeadsPerLayer.begin(), destConfig.getModelConfig().mNbKvHeadsPerLayer.end()};

View File

@@ -67,7 +67,8 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
auto requestId = info.getRequestId();
TLLM_CHECK_WITH_INFO(
mFormatter->inquireSupport(mSelfState.getCacheState().value(), info.getTransState().getCacheState().value()),
"Disagg server does not currently support these cacheState.");
"Disagg server does not currently support these cacheState, please check the cacheState of the context and gen "
"executors");
auto peerRelativeRanks = executor::kv_cache::targetIRanks(info.getTransState().getCacheState().value(),
mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx())
.mIRanks;

View File

@@ -469,6 +469,10 @@ void MLACacheFormatter::formatInput(LlmRequest const& llmRequest,
[[nodiscard]] bool MLACacheFormatter::inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const
{
if (selfConfig.getDataType() != destConfig.getDataType())
{
return false;
}
if (selfConfig.getAttentionConfig().mAttentionType != CacheState::AttentionType::kMLA
|| destConfig.getAttentionConfig().mAttentionType != CacheState::AttentionType::kMLA)
{