TensorRT-LLMs/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp
Netanel Haber 3c52ac098f
feat: allocate minimal blocks per window size (#3028)
* implement variable window attention by breaking the block manager into window block managers per window size

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* revert isCyclic to be true if the min attention window is reached, not per window size

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* add explanatory comment to mCyclicThreshold

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* load correct gemma config

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* don't shadow inputLength in addSequence - it should remain the function scope input length between window size loop iterations

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* fix KVCacheManagerVariableWindowAttentionWithReuseTest for multiple window block managers

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* if TYPE_CHECKING

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* set temp_attention_window_inputs to None explicitly

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* pass dtype as well

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* test_gemma variable sliding window attention

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* allot a fraction of primary/secondaryBlocks to different window size heaps, depending on the window size's total contribution to the kvcache size (i.e., including all layers)

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* remove || mEnableBlockReuse which erroneously triggers beamsearch code for cyclic variable attention window code

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* turn off request delaying for MaxUtil

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* make comments better

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* windowSizesTotalSum using std::accumulate

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* fix error handling of forwardAsync: the cleanup code in its catch-all handler, which runs terminateRequest, can also fail and must be caught

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* fix comments

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* remove assert that kills disagg tests, since it isn't necessary

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* fix corrupted expression: 'isNewTask && (peftCacheManager ?' -> '(isNewTask && peftCacheManager) ?', which changed the boolean logic. Main is correct

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* add Gemma3 to SUPPORTED_HF_ARCHITECTURES

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* support Gemma3

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* finally fix test_gemma - always spread at least {} into generate_summary_cmd, never None

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* fix kvfactor field for deepseek

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* fix comment

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* fix gemma-3 entries in testlist to include vswa

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* only quantize gemma2 VSWA

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* remove misleading comment

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* fix test_gemma

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* in sendRequestInfo, fromOldAllocatedBlockIds->fromNewlyAllocatedBlockIds, like in main

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

* fix: disable KV cache reuse if using attention sink (#3021)

* fix: disable KV cache reuse if using attention sink

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* fix: disable KV cache reuse if sink bubble

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* add comment

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

---------

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

---------

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Co-authored-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
2025-04-17 16:04:57 +08:00
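
The commit message above mentions allotting a fraction of the primary/secondary blocks to each window size according to that window size's total contribution to the KV cache, with the total computed via std::accumulate. A minimal sketch of that idea follows. The function name, the map-based input, and the weighting rule (window size multiplied by the number of layers using it) are assumptions made for illustration and are not taken from the TensorRT-LLM block manager.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <numeric>

// Hypothetical helper: splits totalBlocks across window sizes in proportion to
// windowSize * numLayers. The weighting rule is an assumption for illustration.
std::map<std::int32_t, std::int64_t> allotBlocksPerWindow(
    std::map<std::int32_t, std::int32_t> const& layersPerWindow, std::int64_t totalBlocks)
{
    assert(totalBlocks >= 0);

    // Sum of every window size's contribution, counting each layer that uses it.
    auto const totalContribution = std::accumulate(layersPerWindow.begin(), layersPerWindow.end(), std::int64_t{0},
        [](std::int64_t sum, auto const& entry) { return sum + std::int64_t{entry.first} * entry.second; });

    std::map<std::int32_t, std::int64_t> blocksPerWindow;
    if (totalContribution == 0)
    {
        return blocksPerWindow; // nothing to allot
    }
    for (auto const& [windowSize, numLayers] : layersPerWindow)
    {
        auto const contribution = std::int64_t{windowSize} * numLayers;
        // Integer division rounds down, so a few blocks may remain unassigned.
        blocksPerWindow[windowSize] = totalBlocks * contribution / totalContribution;
    }
    return blocksPerWindow;
}
```

With five layers at window size 1024 and one layer at window size 4096, the contributions are 5120 and 4096, so 900 blocks would split into 500 and 400 under this assumed rule.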

208 lines
8.9 KiB
C++

/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h"

#include "tensorrt_llm/batch_manager/cacheFormatter.h"
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
namespace tensorrt_llm::batch_manager
{

DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
    executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<IOFormatter> formatter)
    : mManager{manager}
    , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
    , mFormatter(std::move(formatter))
    , mBufferManager{std::make_shared<runtime::CudaStream>()}
{
    TLLM_CHECK(mManager);
    TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
}

[[nodiscard]] RequestInfo DataSenderImpl::recvRequestInfo()
{
    using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
    Id id;
    auto const* connection = mManager->recvConnect(DataContext{kID_TAG}, &id, sizeof(id));
    TLLM_CHECK(id == Id::REQUEST_SEND);
    std::uint64_t infoSize{0};
    connection->recv(executor::kv_cache::DataContext{kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize));
    std::string serializedInfo;
    serializedInfo.resize(infoSize);
    connection->recv(executor::kv_cache::DataContext{kINFO_TAG}, serializedInfo.data(), infoSize);
    std::istringstream iss(serializedInfo);
    auto info = RequestInfo::deserialize(iss);
    auto requestId = info.getRequestId();
    TLLM_CHECK_WITH_INFO(
        mFormatter->inquireSupport(mSelfState.getCacheState().value(), info.getTransState().getCacheState().value()),
        "Disagg server does not currently support these cacheState.");
    auto peerRelativeRanks = executor::kv_cache::targetIRanks(info.getTransState().getCacheState().value(),
        mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx())
                                 .mIRanks;
    int peerIdx = std::distance(peerRelativeRanks.begin(),
        std::find(
            peerRelativeRanks.begin(), peerRelativeRanks.end(), info.getTransState().getCommState()->getSelfIdx()));
    {
        std::unique_lock<std::mutex> lk(mMtxForMap);
        auto it = mRequestToComms.find(requestId);
        if (it == mRequestToComms.end())
        {
            int recvExpectCount = peerRelativeRanks.size();
            {
                it = mRequestToComms.emplace(requestId, RequestMapInfo{}).first;
                it->second.resize(recvExpectCount);
            }
        }
        it->second[peerIdx] = {connection, info.getTransState()};
    }
    return info;
}

void DataSenderImpl::sendSync(LlmRequest const& llmRequest)
{
    std::vector<executor::kv_cache::Connection const*> connections;
    auto it = mRequestToComms.find(llmRequest.mRequestId);
    TLLM_CHECK(it != mRequestToComms.end());
    auto const& reqToComm = it->second;
    for (auto&& [connection, dataTransceiverState] : reqToComm)
    {
        connections.emplace_back(connection);
    }
    auto&& dataTransceiverState = reqToComm.at(0).second;
    mFormatter->formatOutput(llmRequest, std::move(connections), mSelfState.getCacheState().value(),
        mSelfState.getCommState().value().getSelfIdx(), dataTransceiverState.getCacheState().value(), mBufferManager);
}

[[nodiscard]] executor::kv_cache::CommState const& DataSenderImpl::getCommState() const
{
    return mSelfState.getCommState().value();
}

void DataSenderImpl::setCommState(executor::kv_cache::CommState commState)
{
    mSelfState.setCommState(std::move(commState));
}

[[nodiscard]] size_t DataSenderImpl::getCounterpartsCount(LlmRequest::RequestIdType requestId) const
{
    auto it = mRequestToComms.find(requestId);
    TLLM_CHECK(it != mRequestToComms.end());
    return it->second.size();
}
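
void DataSenderImpl::release(LlmRequest::RequestIdType requestId)
{
    // Take the lock before the lookup so that find and erase do not race with
    // recvRequestInfo, which modifies the map under the same mutex.
    std::unique_lock<std::mutex> lk(mMtxForMap);
    auto it = mRequestToComms.find(requestId);
    TLLM_CHECK(it != mRequestToComms.end());
    mRequestToComms.erase(it);
}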

DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manager,
    executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<IOFormatter> formatter)
    : mManager{manager}
    , mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
    , mFormatter(std::move(formatter))
{
    TLLM_CHECK(mManager);
    TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
    TLLM_CHECK(mFormatter);
}

void DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
{
    uint64_t requestId = llmRequest.getContextPhaseParams().value().getReqId();
    auto const& contextState = llmRequest.getDataTransceiverState();
    auto const& commState = contextState.getCommState().value();
    auto const& destCacheState = contextState.getCacheState().value();
    TLLM_CHECK_WITH_INFO(mFormatter->inquireSupport(mSelfState.getCacheState().value(), destCacheState),
        "Disagg server does not currently support these cacheState.");
    RequestInfo requestInfo(requestId, mSelfState);
    // TODO: remove IOFormatter and make CacheFormatter new base class
    auto* cacheFormatter = dynamic_cast<kv_cache_manager::CacheFormatter const*>(mFormatter.get());
    if (cacheFormatter != nullptr)
    {
        auto* cacheManager = cacheFormatter->getCacheManager();
        auto blockRange
            = kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
        requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState);
    }
    for (auto index : mFormatter->getCounterparts(
             mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx(), destCacheState))
    {
        auto const* connection = mManager->getConnections(commState).at(index);
        sendRequestInfo(connection, requestInfo);
    }
}

void DataReceiverImpl::receiveSync(LlmRequest const& llmRequest)
{
    auto const& contextState = llmRequest.getDataTransceiverState();
    auto const& commState = contextState.getCommState().value();
    auto const& destCacheState = contextState.getCacheState().value();
    std::vector<tensorrt_llm::executor::kv_cache::Connection const*> connections;
    for (auto index : mFormatter->getCounterparts(
             mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx(), destCacheState))
    {
        auto const* connection = mManager->getConnections(commState).at(index);
        connections.emplace_back(connection);
    }
    auto const& resource = getReceiveCacheResource(llmRequest);
    mFormatter->formatInput(llmRequest, std::move(connections), mSelfState.getCacheState().value(),
        mSelfState.getCommState().value().getSelfIdx(), destCacheState, resource->mBufferManager);
}

void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info)
{
    std::ostringstream oss;
    RequestInfo::serialize(info, oss);
    auto const& serializedInfo = oss.str();
    // Use a fixed-width type so the size field matches the std::uint64_t read in DataSenderImpl::recvRequestInfo.
    std::uint64_t const infoSize = serializedInfo.size();
    Id id{Id::REQUEST_SEND};
    connection->send(executor::kv_cache::DataContext{kID_TAG}, &id, sizeof(id));
    connection->send(executor::kv_cache::DataContext{kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize));
    connection->send(executor::kv_cache::DataContext{kINFO_TAG}, serializedInfo.data(), infoSize);
}

std::unique_ptr<DataReceiverImpl::ReceiveCacheResource> const& DataReceiverImpl::getReceiveCacheResource(
    LlmRequest const& llmRequest)
{
    std::scoped_lock<std::mutex> lock(mProcessIoResouceMutex);
    TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value());
    std::string processString = "default";
    if (common::getEnvRequestKVCacheConcurrent())
    {
        processString = llmRequest.getDataTransceiverState().getCommState()->toString();
    }
    if (mProcessToResources.find(processString) == mProcessToResources.end())
    {
        mProcessToResources.emplace(processString,
            std::make_unique<ReceiveCacheResource>(
                runtime::BufferManager{std::make_shared<runtime::CudaStream>()}, runtime::CudaEvent{}));
    }
    return mProcessToResources.at(processString);
}

} // namespace tensorrt_llm::batch_manager
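
The handshake between DataReceiverImpl::sendRequestInfo and DataSenderImpl::recvRequestInfo in the file above sends three messages in order: a message id, the payload size as a 64-bit integer, and then the serialized payload. The sketch below reproduces that framing over an in-memory byte buffer so the ordering can be followed in isolation. LoopbackChannel, MessageId, sendPayload, and recvPayload are hypothetical stand-ins written for this illustration. They are not part of the TensorRT-LLM connection API, and tags and data contexts are left out.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Hypothetical stand-in for a connection: bytes sent are appended to a buffer
// and read back in the same order.
class LoopbackChannel
{
public:
    void send(void const* data, std::size_t size)
    {
        auto const* bytes = static_cast<std::uint8_t const*>(data);
        mBytes.insert(mBytes.end(), bytes, bytes + size);
    }

    void recv(void* data, std::size_t size)
    {
        assert(mReadPos + size <= mBytes.size());
        std::memcpy(data, mBytes.data() + mReadPos, size);
        mReadPos += size;
    }

private:
    std::vector<std::uint8_t> mBytes;
    std::size_t mReadPos{0};
};

enum class MessageId : std::uint64_t
{
    kRequestSend = 1
};

// Sender side: id first, then a fixed-width size, then the payload bytes.
void sendPayload(LoopbackChannel& channel, std::string const& payload)
{
    MessageId const id{MessageId::kRequestSend};
    std::uint64_t const size = payload.size();
    channel.send(&id, sizeof(id));
    channel.send(&size, sizeof(size));
    channel.send(payload.data(), payload.size());
}

// Receiver side: read in the same order and size the buffer before the payload read.
std::string recvPayload(LoopbackChannel& channel)
{
    MessageId id{};
    channel.recv(&id, sizeof(id));
    assert(id == MessageId::kRequestSend);
    std::uint64_t size{0};
    channel.recv(&size, sizeof(size));
    std::string payload(size, '\0');
    channel.recv(payload.data(), size);
    return payload;
}
```

Sending the size with the same fixed-width type on both sides keeps the framing independent of how wide std::size_t is on either host.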