fix datatype check (#4606)
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Parent: 20c15fc04f
Commit: b60846b47d
@@ -599,6 +599,11 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest,
 [[nodiscard]] bool CacheFormatter::inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const
 {
+    if (selfConfig.getDataType() != destConfig.getDataType())
+    {
+        return false;
+    }
+
     std::unordered_set<SizeType32> setVecSelf{
         selfConfig.getModelConfig().mNbKvHeadsPerLayer.begin(), selfConfig.getModelConfig().mNbKvHeadsPerLayer.end()};
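The added guard short-circuits the support inquiry before any of the existing head-count comparisons run: if the two cache states disagree on KV-cache datatype, the transfer is reported as unsupported immediately. Below is a minimal standalone sketch of that pattern, using hypothetical simplified types (a stand-in CacheState, an invented DataType enum), not the real TensorRT-LLM headers.

#include <iostream>

enum class DataType { kFP16, kFP8, kBF16 };

// Simplified stand-in for the real CacheState (hypothetical).
struct CacheState
{
    DataType dataType;
    DataType getDataType() const { return dataType; }
};

bool inquireSupport(CacheState const& selfConfig, CacheState const& destConfig)
{
    if (selfConfig.getDataType() != destConfig.getDataType())
    {
        return false; // mismatched KV-cache dtypes cannot be exchanged directly
    }
    // ... the real implementation continues with head-count / layout checks ...
    return true;
}

int main()
{
    CacheState ctx{DataType::kFP8};
    CacheState gen{DataType::kFP16};
    std::cout << std::boolalpha << inquireSupport(ctx, gen) << '\n'; // prints: false
}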
@@ -618,6 +623,7 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest,
     {
         return false;
     }

     std::unordered_set<int> setVecDest{
         destConfig.getModelConfig().mNbKvHeadsPerLayer.begin(), destConfig.getModelConfig().mNbKvHeadsPerLayer.end()};
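Both sides of the comparison collect the per-layer KV-head counts into an std::unordered_set, which deduplicates them so the subsequent checks only see the distinct values. A small standalone sketch of that iterator-range construction, with hypothetical head-count values and no TensorRT-LLM types:

#include <iostream>
#include <unordered_set>
#include <vector>

int main()
{
    // Stand-in for ModelConfig::mNbKvHeadsPerLayer (hypothetical values).
    std::vector<int> nbKvHeadsPerLayer{8, 8, 8, 8};

    // Same iterator-range construction as setVecSelf / setVecDest in the diff:
    // duplicates collapse, leaving only the distinct head counts.
    std::unordered_set<int> distinctHeadCounts{nbKvHeadsPerLayer.begin(), nbKvHeadsPerLayer.end()};

    std::cout << "distinct head counts: " << distinctHeadCounts.size() << '\n'; // prints: 1
}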
@@ -67,7 +67,8 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
     auto requestId = info.getRequestId();
     TLLM_CHECK_WITH_INFO(
         mFormatter->inquireSupport(mSelfState.getCacheState().value(), info.getTransState().getCacheState().value()),
-        "Disagg server does not currently support these cacheState.");
+        "Disagg server does not currently support these cacheState, please check the cacheState of the context and gen "
+        "executors");
     auto peerRelativeRanks = executor::kv_cache::targetIRanks(info.getTransState().getCacheState().value(),
         mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx())
                                  .mIRanks;
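The replacement message is split across two adjacent string literals, which the compiler concatenates into a single string, so the failed check reports one continuous sentence. A small standalone illustration of that concatenation (the variable name is just for the example):

#include <cstdio>

int main()
{
    // Adjacent string literals are merged at compile time, exactly as in the
    // updated TLLM_CHECK_WITH_INFO message above.
    char const* msg
        = "Disagg server does not currently support these cacheState, please check the cacheState of the context and gen "
          "executors";
    std::puts(msg); // prints one continuous line
}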
@@ -469,6 +469,10 @@ void MLACacheFormatter::formatInput(LlmRequest const& llmRequest,
 [[nodiscard]] bool MLACacheFormatter::inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const
 {
+    if (selfConfig.getDataType() != destConfig.getDataType())
+    {
+        return false;
+    }
     if (selfConfig.getAttentionConfig().mAttentionType != CacheState::AttentionType::kMLA
         || destConfig.getAttentionConfig().mAttentionType != CacheState::AttentionType::kMLA)
     {
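MLACacheFormatter::inquireSupport now applies the same datatype guard ahead of its existing MLA-only attention-type check. The sketch below shows how the two guards compose; it again uses simplified stand-in types (hypothetical enums and structs), not the real CacheState API.

#include <iostream>

enum class DataType { kFP16, kFP8 };
enum class AttentionType { kMLA, kMHA };

// Simplified stand-ins for CacheState and its attention config (hypothetical).
struct AttentionConfig { AttentionType mAttentionType; };
struct CacheState
{
    DataType mDataType;
    AttentionConfig mAttention;
    DataType getDataType() const { return mDataType; }
    AttentionConfig const& getAttentionConfig() const { return mAttention; }
};

bool inquireSupportMla(CacheState const& selfConfig, CacheState const& destConfig)
{
    if (selfConfig.getDataType() != destConfig.getDataType())
    {
        return false; // new guard: dtypes must match on both sides
    }
    if (selfConfig.getAttentionConfig().mAttentionType != AttentionType::kMLA
        || destConfig.getAttentionConfig().mAttentionType != AttentionType::kMLA)
    {
        return false; // existing guard: both sides must be MLA caches
    }
    return true; // the real implementation continues with further layout checks
}

int main()
{
    CacheState ctx{DataType::kFP8, {AttentionType::kMLA}};
    CacheState gen{DataType::kFP8, {AttentionType::kMHA}};
    std::cout << std::boolalpha << inquireSupportMla(ctx, gen) << '\n'; // prints: false
}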