[None][feat] KV Cache Connector API (#7228)

Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
Signed-off-by: richardhuo-nv <rihuo@nvidia.com>
Co-authored-by: jthomson04 <jwillthomson19@gmail.com>
Co-authored-by: Iman Tabrizian <10105175+Tabrizian@users.noreply.github.com>
Co-authored-by: Sharan Chetlur <116769508+schetlur-nv@users.noreply.github.com>
Authored by Richard Huo on 2025-08-28 20:09:27 -07:00; committed by GitHub
parent 085dc19bfa
commit ce580ce4f5
31 changed files with 1945 additions and 93 deletions

View File

@ -0,0 +1,46 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/runtime/common.h"
#include <utility>
#include <vector>
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using RequestIdType = tensorrt_llm::batch_manager::LlmRequest::RequestIdType;
/// See tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py for details on the Connector API.
namespace tensorrt_llm::batch_manager::kv_connector
{
/// @brief The KV connector manager. This is passed into the C++ KV Cache Manager when adding sequences.
class KvCacheConnectorManager
{
public:
KvCacheConnectorManager() = default;
virtual ~KvCacheConnectorManager() = default;
/// @brief Handle the getNumNewMatchedTokens call inside the C++ KV Cache Manager.
/// @return The number of tokens that can be loaded from remote KV cache.
virtual SizeType32 getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) = 0;
};
} // namespace tensorrt_llm::batch_manager::kv_connector
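The only hook the C++ side needs from the connector is getNumNewMatchedTokens; the rest of the Connector API lives in Python and reaches this class through the pybind/nanobind trampolines added later in this change. As a minimal sketch (illustrative only; the subclass name is hypothetical), a Python override of the bound manager looks like:

from tensorrt_llm.bindings.internal.batch_manager import (
    KvCacheConnectorManager, LlmRequest)

class MyConnectorManager(KvCacheConnectorManager):
    def get_num_new_matched_tokens(self, request: LlmRequest,
                                   num_computed_tokens: int) -> int:
        # Report how many additional tokens a remote/offloaded KV cache can
        # supply beyond the num_computed_tokens already matched on device.
        return 0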

View File

@ -16,6 +16,7 @@
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"
#include "tensorrt_llm/batch_manager/kvCacheType.h"
#include "tensorrt_llm/batch_manager/llmRequest.h" // TODO forward declare
@ -538,7 +539,8 @@ public:
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse);
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager);
~WindowBlockManager();
@ -835,6 +837,8 @@ private:
bool mEnablePartialReuse;
// Whether partially matched blocks that are already in use should be copied and reused.
bool mCopyOnPartialReuse;
// The kv cache connector manager
std::shared_ptr<kv_connector::KvCacheConnectorManager> mKvCacheConnectorManager;
};
class BlockManager
@ -852,7 +856,8 @@ public:
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnPartialReuse = true);
bool copyOnPartialReuse = true,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);
BlockManager(BlockManager const&) = delete;
BlockManager& operator=(BlockManager const&) = delete;
@ -1287,6 +1292,7 @@ public:
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const
= 0;
[[nodiscard]] virtual runtime::ITensor::SharedPtr getUniquePrimaryPool() const = 0;
[[nodiscard]] virtual runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const = 0;
[[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0;
@ -1373,7 +1379,8 @@ public:
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnpartialReuse = true);
bool copyOnpartialReuse = true,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);
KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@ -1383,7 +1390,8 @@ public:
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnpartialReuse = true);
bool copyOnpartialReuse = true,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@ -1393,7 +1401,8 @@ public:
bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnpartialReuse = true);
bool copyOnpartialReuse = true,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@ -1624,6 +1633,7 @@ public:
std::vector<SizeType32> getNewlyAllocatedBlockIds(
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override;
runtime::ITensor::SharedPtr getUniquePrimaryPool() const override;
runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const override;
SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const override

View File

@ -504,7 +504,8 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
: mNumLayers{static_cast<SizeType32>(numKvHeadsPerLayer.size())}
, mTokensPerBlock{tokensPerBlock}
, mEventManager{std::move(eventManager)}
@ -513,6 +514,10 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
{
auto const uniqueWindowSizeToLayers
= BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, mNumLayers);
TLLM_CHECK_WITH_INFO(kvCacheConnectorManager == nullptr || uniqueWindowSizeToLayers.size() == 1,
"KV Cache Connector is not supported with multiple window sizes");
auto const numUniqueWindowSizes = static_cast<SizeType32>(uniqueWindowSizeToLayers.size());
mIsVariableWindow = numUniqueWindowSizes > 1;
@ -530,7 +535,7 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer,
sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream,
onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse,
copyOnPartialReuse);
copyOnPartialReuse, kvCacheConnectorManager);
}
auto const numAllPools = getNumPools();
@ -572,7 +577,8 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
: mDataType{dtype}
, mWindowSize{windowSize}
, mNumPrimaryBlocks{blocksInPrimaryPool}
@ -596,6 +602,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
, mTotalInputTokens{0.0}
, mEnablePartialReuse{enablePartialReuse}
, mCopyOnPartialReuse{copyOnPartialReuse}
, mKvCacheConnectorManager{std::move(kvCacheConnectorManager)}
{
std::map<SizeType32, SizeType32> numLayersPerPool;
@ -1188,9 +1195,18 @@ void WindowBlockManager::addSequence(
auto const prepopulatedPromptLen = loadOrAllocateBlocks(blockKeys, numContextBlocks, sequence, perBlockRetentions);
mReusedTokens += static_cast<double>(prepopulatedPromptLen);
mTotalInputTokens += static_cast<double>(uniqueTokens.size());
llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen, getTokensPerBlock());
TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d", llmRequest.mRequestId,
inputLength, prepopulatedPromptLen);
SizeType32 numConnectorMatchedTokens = 0;
// If we're using a KV cache connector, check if any additional blocks can be loaded.
if (mKvCacheConnectorManager && !llmRequest.isDummyRequest())
{
numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen);
}
llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock());
TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d",
llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens);
}
// There are two versions of BlockManager::addSequence function.
@ -1206,6 +1222,13 @@ void BlockManager::addSequence(
void WindowBlockManager::addSequence(
GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock)
{
if (mKvCacheConnectorManager)
{
TLLM_LOG_WARNING(
"KV Cache Connector specified when block reuse is disabled. The KV Cache Connector will be "
"ignored.");
}
auto const requestId = sequence.getRequestId();
auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector<BlockPtr>{});
TLLM_CHECK(emplaceDone);
@ -1618,12 +1641,13 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
SizeType32 sinkTokenLength, int64_t stream, std::optional<runtime::SizeType32> maxSequenceLength,
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
: KVCacheManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth,
maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
std::make_shared<runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream)), maxSequenceLength,
enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse,
copyOnPartialReuse)
copyOnPartialReuse, kvCacheConnectorManager)
{
}
@ -1634,7 +1658,8 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
: mMaxBeamWidth(maxBeamWidth)
, mDataType(dtype)
, mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end()))
@ -1644,7 +1669,7 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
, mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype,
mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager),
enablePartialReuse, copyOnPartialReuse)
enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
// disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
, mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
{
@ -1668,11 +1693,12 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
: KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority,
std::move(eventManager), enablePartialReuse, copyOnPartialReuse)
std::move(eventManager), enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
{
}
@ -2383,6 +2409,13 @@ std::vector<SizeType32> KVCacheManager::getNewlyAllocatedBlockIds(
return mBlockManager.getNewlyAllocatedBlockIds(getSequence(requestId), windowSize);
}
runtime::ITensor::SharedPtr KVCacheManager::getUniquePrimaryPool() const
{
TLLM_CHECK_WITH_INFO(mBlockManager.getWindowSizesMetadata().size() == 1,
"getUniquePrimaryPool is only supported for a single window size");
return mBlockManager.getPrimaryPool(0);
}
runtime::ITensor::SharedPtr KVCacheManager::getPrimaryPool(SizeType32 layer_idx) const
{
return mBlockManager.getPrimaryPool(mBlockManager.getLayerPoolIdx(layer_idx));
@ -2462,4 +2495,5 @@ SizeType32 KVCacheManager::calculateMaxBlockRequirements(SizeType32 inputLength,
auto const leftoverBlockCapacity = blockCapacity - outputBlockRequirements;
return std::min(outputLength + leftoverBlockCapacity * tokensPerBlock, inputLength + outputLength);
}
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
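The connector hook in WindowBlockManager::addSequence above is purely additive: whatever block reuse already matched on device is passed to getNumNewMatchedTokens as numComputedTokens, and the connector's answer extends the prepopulated prompt length reported back to the request. A small worked sketch with hypothetical numbers:

# Hypothetical numbers mirroring the addSequence logic above.
tokens_per_block = 32
prepopulated_prompt_len = 64        # matched by on-device block reuse
num_connector_matched_tokens = 96   # returned by getNumNewMatchedTokens(request, 64)
# The request is told 160 prompt tokens already have valid KV cache entries,
# so only the remaining prompt tokens are computed during the context phase.
assert prepopulated_prompt_len + num_connector_matched_tokens == 160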

View File

@ -7,6 +7,7 @@ set(SRCS
batch_manager/algorithms.cpp
batch_manager/bindings.cpp
batch_manager/cacheTransceiver.cpp
batch_manager/kvCacheConnector.cpp
batch_manager/kvCacheManager.cpp
batch_manager/llmRequest.cpp
executor/bindings.cpp

View File

@ -0,0 +1,48 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h"
#include <nanobind/trampoline.h>
#include <torch/extension.h>
namespace
{
using KvCacheConnectorManager = tensorrt_llm::batch_manager::kv_connector::KvCacheConnectorManager;
namespace tb = tensorrt_llm::batch_manager;
class PyKvCacheConnectorManager : KvCacheConnectorManager
{
public:
NB_TRAMPOLINE(KvCacheConnectorManager, 1);
SizeType32 getNumNewMatchedTokens(tb::LlmRequest const& request, SizeType32 numComputedTokens) override
{
NB_OVERRIDE_PURE_NAME("get_num_new_matched_tokens", getNumNewMatchedTokens, request, numComputedTokens);
}
};
} // namespace
void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(nb::module_& m)
{
nb::class_<tb::kv_connector::KvCacheConnectorManager, PyKvCacheConnectorManager>(m, "KvCacheConnectorManager")
.def(nb::init<>())
.def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnectorManager::getNumNewMatchedTokens,
nb::arg("request"), nb::arg("num_computed_tokens"));
}

View File

@ -0,0 +1,39 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
#include <nanobind/nanobind.h>
namespace nb = nanobind;
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class KVCacheManagerConnectorBindings
{
public:
static void initBindings(nb::module_& m);
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
namespace tensorrt_llm::pybind::batch_manager::kv_connector
{
using namespace tensorrt_llm::batch_manager::kv_connector;
} // namespace tensorrt_llm::pybind::batch_manager::kv_connector

View File

@ -39,6 +39,7 @@
#include <torch/extension.h>
namespace tb = tensorrt_llm::batch_manager;
namespace tbc = tensorrt_llm::batch_manager::kv_connector;
namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager;
namespace tr = tensorrt_llm::runtime;
namespace nb = nanobind;
@ -381,6 +382,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
auto pool_layer_idx = self.getPoolLayerIdx(layer_idx);
return pool.index({torch::indexing::Slice(), pool_layer_idx});
})
.def("get_unique_primary_pool", [](tbk::BaseKVCacheManager& self) { return self.getUniquePrimaryPool(); })
.def("get_block_offsets_of_batch",
[](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize,
SizeType32 beamWidth)
@ -446,12 +448,13 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
.value("SELFKONLY", tbk::CacheType::kSELFKONLY);
nb::class_<tbk::KVCacheManager, tbk::BaseKVCacheManager>(m, "KVCacheManager")
.def(nb::init<std::vector<SizeType32> const&, SizeType32, SizeType32,
std::map<SizeType32, std::tuple<SizeType32, SizeType32>> const&, SizeType32, SizeType32,
std::vector<SizeType32> const&, std::optional<tbk::TempAttentionWindowInputs> const&,
nvinfer1::DataType, SizeType32, int64_t, std::optional<runtime::SizeType32>, bool, bool,
tbk::CacheType, std::optional<tensorrt_llm::executor::RetentionPriority>,
std::shared_ptr<tbk::KVCacheEventManager>, bool, bool>(),
.def(
nb::init<std::vector<SizeType32> const&, SizeType32, SizeType32,
std::map<SizeType32, std::tuple<SizeType32, SizeType32>> const&, SizeType32, SizeType32,
std::vector<SizeType32> const&, std::optional<tbk::TempAttentionWindowInputs> const&,
nvinfer1::DataType, SizeType32, int64_t, std::optional<runtime::SizeType32>, bool, bool, tbk::CacheType,
std::optional<tensorrt_llm::executor::RetentionPriority>, std::shared_ptr<tbk::KVCacheEventManager>,
bool, bool, std::shared_ptr<tbc::KvCacheConnectorManager>>(),
nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"), nb::arg("tokens_per_block"),
nb::arg("blocks_per_window"), nb::arg("max_num_sequences"), nb::arg("max_beam_width"),
nb::arg("max_attention_window_vec"), nb::arg("temp_attention_window_inputs").none(), nb::arg("dtype"),
@ -459,7 +462,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
nb::arg("enable_block_reuse") = false, nb::arg("onboard_blocks") = true,
nb::arg("cache_type") = tbk::CacheType::kSELF, nb::arg("secondary_offload_min_priority") = std::nullopt,
nb::arg("event_manager") = nullptr, nb::arg("enable_partial_reuse") = true,
nb::arg("copy_on_partial_reuse") = true);
nb::arg("copy_on_partial_reuse") = true, nb::arg("kv_connector_manager") = nullptr);
}
void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m)

View File

@ -34,6 +34,7 @@
#include "tensorrt_llm/nanobind/batch_manager/algorithms.h"
#include "tensorrt_llm/nanobind/batch_manager/bindings.h"
#include "tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/nanobind/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/nanobind/batch_manager/llmRequest.h"
#include "tensorrt_llm/nanobind/executor/bindings.h"
@ -480,6 +481,8 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
tensorrt_llm::nanobind::runtime::initBindings(mInternalRuntime);
tensorrt_llm::nanobind::testing::initBindings(mInternalTesting);
tpb::initBindings(mInternalBatchManager);
tb::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(mInternalBatchManager);
tb::kv_cache_manager::KVCacheManagerBindings::initBindings(mInternalBatchManager);
tb::BasePeftCacheManagerBindings::initBindings(mInternalBatchManager);
tb::CacheTransceiverBindings::initBindings(mInternalBatchManager);

View File

@ -7,6 +7,7 @@ set(SRCS
batch_manager/algorithms.cpp
batch_manager/bindings.cpp
batch_manager/cacheTransceiver.cpp
batch_manager/kvCacheConnector.cpp
batch_manager/kvCacheManager.cpp
batch_manager/llmRequest.cpp
executor/bindings.cpp

View File

@ -19,6 +19,8 @@
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/batch_manager/rnnStateManager.h"

View File

@ -0,0 +1,47 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h"
namespace
{
using KvCacheConnectorManager = tensorrt_llm::batch_manager::kv_connector::KvCacheConnectorManager;
namespace tb = tensorrt_llm::batch_manager;
class PyKvCacheConnectorManager : public KvCacheConnectorManager, py::trampoline_self_life_support
{
public:
using KvCacheConnectorManager::KvCacheConnectorManager;
SizeType32 getNumNewMatchedTokens(tb::LlmRequest const& request, SizeType32 numComputedTokens) override
{
PYBIND11_OVERRIDE_PURE_NAME(SizeType32, KvCacheConnectorManager, "get_num_new_matched_tokens",
getNumNewMatchedTokens, request, numComputedTokens);
}
};
} // namespace
void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(py::module_& m)
{
py::class_<tb::kv_connector::KvCacheConnectorManager, PyKvCacheConnectorManager, py::smart_holder>(
m, "KvCacheConnectorManager")
.def(py::init<>())
.def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnectorManager::getNumNewMatchedTokens,
py::arg("request"), py::arg("num_computed_tokens"));
}

View File

@ -0,0 +1,39 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class KVCacheManagerConnectorBindings
{
public:
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
namespace tensorrt_llm::pybind::batch_manager::kv_connector
{
using namespace tensorrt_llm::batch_manager::kv_connector;
} // namespace tensorrt_llm::pybind::batch_manager::kv_connector

View File

@ -30,6 +30,7 @@
#include <torch/extension.h>
namespace tb = tensorrt_llm::batch_manager;
namespace tbc = tensorrt_llm::batch_manager::kv_connector;
namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager;
namespace tr = tensorrt_llm::runtime;
namespace py = pybind11;
@ -214,10 +215,16 @@ public:
std::deque<tensorrt_llm::executor::KVCacheEvent>, tbk::BaseKVCacheManager, getLatestEvents, timeout);
}
tensorrt_llm::runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const override
tensorrt_llm::runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 poolIdx) const override
{
PYBIND11_OVERLOAD_PURE(
tensorrt_llm::runtime::ITensor::SharedPtr, tbk::BaseKVCacheManager, getPrimaryPool, layer_idx);
tensorrt_llm::runtime::ITensor::SharedPtr, tbk::BaseKVCacheManager, getPrimaryPool, poolIdx);
}
tensorrt_llm::runtime::ITensor::SharedPtr getUniquePrimaryPool() const override
{
PYBIND11_OVERLOAD_PURE(
tensorrt_llm::runtime::ITensor::SharedPtr, tbk::BaseKVCacheManager, getUniquePrimaryPool);
}
SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const override
@ -377,6 +384,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
auto pool_layer_idx = self.getPoolLayerIdx(layer_idx);
return pool.index({torch::indexing::Slice(), pool_layer_idx});
})
.def("get_unique_primary_pool", [](tbk::BaseKVCacheManager& self) { return self.getUniquePrimaryPool(); })
.def("get_block_offsets_of_batch",
[](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize,
SizeType32 beamWidth)
@ -437,7 +445,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
std::vector<SizeType32> const&, std::optional<tbk::TempAttentionWindowInputs> const&,
nvinfer1::DataType, SizeType32, bool, int64_t, bool, bool, tbk::CacheType,
std::optional<tensorrt_llm::executor::RetentionPriority>, std::shared_ptr<tbk::KVCacheEventManager>,
bool, bool>(),
bool, bool, std::shared_ptr<tbc::KvCacheConnectorManager>>(),
py::arg("num_kv_heads_per_layer"), py::arg("size_per_head"), py::arg("tokens_per_block"),
py::arg("blocks_per_window"), py::arg("max_num_sequences"), py::arg("max_beam_width"),
py::arg("max_attention_window_vec"), py::arg("temp_attention_window_inputs"), py::arg("dtype"),
@ -445,7 +453,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
py::arg("enable_block_reuse") = false, py::arg("onboard_blocks") = true,
py::arg_v("cache_type", tbk::CacheType::kSELF, "bindings.internal.batch_manager.CacheType.SELF"),
py::arg("secondary_offload_min_priority") = std::nullopt, py::arg("event_manager") = nullptr,
py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true);
py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true,
py::arg("kv_connector_manager") = nullptr);
}
void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m)

View File

@ -28,6 +28,7 @@
#include "tensorrt_llm/pybind/batch_manager/algorithms.h"
#include "tensorrt_llm/pybind/batch_manager/bindings.h"
#include "tensorrt_llm/pybind/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/pybind/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/executor/bindings.h"
@ -468,6 +469,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
tensorrt_llm::pybind::runtime::initBindings(mInternalRuntime);
tensorrt_llm::pybind::testing::initBindings(mInternalTesting);
tpb::initBindings(mInternalBatchManager);
tb::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(mInternalBatchManager);
tb::kv_cache_manager::KVCacheManagerBindings::initBindings(mInternalBatchManager);
tb::BasePeftCacheManagerBindings::initBindings(mInternalBatchManager);
tb::CacheTransceiverBindings::initBindings(mInternalBatchManager);

View File

@ -43,15 +43,11 @@ public:
.deleter(
[ptr = std::move(tensor)](void* data) mutable
{
try
if (data != ptr->data())
{
TLLM_CHECK(data == ptr->data());
ptr.reset();
}
catch (std::exception const& e)
{
TLLM_LOG_EXCEPTION(e);
TLLM_LOG_WARNING("Torch tensor refers to deallocated memory.");
}
ptr.reset();
})
.make_tensor();
}

View File

@ -0,0 +1,248 @@
### :title KV Cache Connector
### :order 6
### :section Customization
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
import click
import torch
from tensorrt_llm import LLM, SamplingParams, logger
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
from tensorrt_llm.bindings.executor import ExecutorConfig
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
# This is a simple example of the use of the KV cache connector.
# It persists KV cache contents into a folder, and can load them back on subsequent runs.
# See tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py for details about the KV cache connector interface.
# NOTE: This example connector implementation is NOT suitable for production use.
CONNECTOR_CACHE_FOLDER_KEY = "CONNECTOR_CACHE_FOLDER"
@dataclass
class PersistentKvCacheConnectorMetadata:
load: list[tuple[str, int]] = field(default_factory=list)
save: list[tuple[str, int]] = field(default_factory=list)
class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):
def __init__(self, executor_config: ExecutorConfig):
super().__init__(executor_config)
self.kv_cache_tensor = None
def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
assert self.kv_cache_tensor is None, "KV cache tensor already registered"
self.kv_cache_tensor = kv_cache_tensor
def start_load_kv(self, stream: torch.cuda.Stream):
# Do all loads synchronously, and blockwise.
for path, block_id in self._metadata.load:
cpu_tensor = torch.load(path, map_location="cpu")
# Copy into the device block.
self.kv_cache_tensor[block_id].copy_(cpu_tensor, non_blocking=False)
def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
pass
def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
pass
def wait_for_save(self, stream: torch.cuda.Stream):
# Make sure the forward pass is complete before beginning our save.
stream.synchronize()
for path, block_id in self._metadata.save:
cpu_tensor = self.kv_cache_tensor[block_id].cpu()
# Don't write anything if this specific block already exists.
if Path(path).exists():
continue
# Do a blocking save to the file. This way, we only return once all saves are complete.
torch.save(cpu_tensor, path)
def get_finished(
self, finished_gen_req_ids: list[int],
started_loading_req_ids: list[int]) -> tuple[list[int], list[int]]:
return [], []
class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):
def __init__(self, executor_config: ExecutorConfig):
super().__init__(executor_config)
self.block_size = self._config.tokens_per_block
self.pending_loads = {}
self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
"./connector_cache")
os.makedirs(self.cache_folder, exist_ok=True)
def build_connector_meta(self, scheduler_output: SchedulerOutput):
# NOTE: This is a simplified implementation, and does not work with chunked prefill.
metadata = PersistentKvCacheConnectorMetadata()
for req in scheduler_output.new_requests:
# If we don't have any pending loads for this request, we can skip it.
if req.request_id not in self.pending_loads:
continue
num_computed_blocks = req.computed_position // self.block_size
block_ids = req.new_block_ids
pending_load = self.pending_loads[req.request_id]
for file_path, block_pos in zip(
pending_load, range(num_computed_blocks, len(block_ids))):
metadata.load.append((file_path, block_ids[block_pos]))
# Break up the remainder of the token sequence into chunks.
chunks = self._chunk_tokens(req.new_tokens)
# For each chunk that isn't already on device, and isn't in our connector cache, we need to save it.
for block_pos in range(num_computed_blocks + len(pending_load),
len(block_ids)):
if len(chunks[block_pos]) == self.block_size:
hashed_tokens = self._hash_tokens(chunks[block_pos])
file_path = self._file_path(hashed_tokens)
metadata.save.append((file_path, block_ids[block_pos]))
self.pending_loads = {}
return metadata
def _hash_tokens(self, tokens: list[int]) -> int:
return abs(hash(tuple(tokens)))
def _file_path(self, hash_value: int) -> Path:
return Path(self.cache_folder) / f"{hash_value}.pt"
def _chunk_tokens(self, tokens: list[int]) -> list[list[int]]:
return [
tokens[i:i + self.block_size]
for i in range(0, len(tokens), self.block_size)
]
def get_num_new_matched_tokens(
self, request: LlmRequest,
num_computed_tokens: int) -> tuple[int, bool]:
self.pending_loads[request.request_id] = []
# Don't bother with sequences with partial matches.
if (num_computed_tokens % self.block_size) != 0:
return 0, False
computed_blocks = num_computed_tokens // self.block_size
# Get all the tokens that don't have a cache hit on device.
remaining_tokens = request.get_tokens(0)[computed_blocks *
self.block_size:]
remaining_chunks = self._chunk_tokens(remaining_tokens)
# For each chunk, check if it exists in our cache.
for chunk in remaining_chunks:
# Only do full blocks.
if len(chunk) == self.block_size:
hashed_tokens = self._hash_tokens(chunk)
file_path = self._file_path(hashed_tokens)
# If we get a cache hit, we want to load it onto the device.
# Otherwise, we can stop looking.
if file_path.exists():
self.pending_loads[request.request_id].append(file_path)
else:
break
logger.info(
f"KV CONNECTOR: Matched {len(self.pending_loads[request.request_id])} blocks for request {request.request_id}"
)
return len(
self.pending_loads[request.request_id]) * self.block_size, False
def request_finished(self, request: LlmRequest,
cache_block_ids: list[int]) -> bool:
# We don't do any asynchronous saving, so always return False
return False
def update_state_after_alloc(self, request: LlmRequest,
block_ids: list[int]):
pass
@click.command()
@click.argument("model", type=str)
def main(model: str):
sys.path.append(os.path.join(
os.path.dirname(__file__),
"..",
))
this_module = __file__[__file__.rfind("/") + 1:__file__.rfind(".py")]
kv_connector_config = KvCacheConnectorConfig(
connector_module=this_module,
connector_scheduler_class="PersistentKvCacheConnectorLeader",
connector_worker_class="PersistentKvCacheConnectorWorker",
)
connector_cache_dir = TemporaryDirectory()
os.environ[CONNECTOR_CACHE_FOLDER_KEY] = connector_cache_dir.name
llm = LLM(model=model,
backend="pytorch",
cuda_graph_config=None,
kv_connector_config=kv_connector_config)
test_text = (
"Nvidia Corporation is an American technology company headquartered in Santa Clara, California."
"Founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem, it develops graphics processing units (GPUs), "
"system on a chips (SoCs), and application programming interfaces (APIs) for data science, high-performance computing, "
"and mobile and automotive applications. Tell me about the company.")
sampling_params = SamplingParams(max_tokens=32)
output = llm.generate([test_text], sampling_params)
text0 = output[0].outputs[0].text
print("First output: ", text0)
print("Loading new LLM instance...")
del llm
llm = LLM(model=model,
backend="pytorch",
cuda_graph_config=None,
kv_connector_config=kv_connector_config)
output = llm.generate([test_text], sampling_params)
text1 = output[0].outputs[0].text
print("Second output (using connector cache): ", text1)
assert text0 == text1
connector_cache_dir.cleanup()
if __name__ == "__main__":
main()

View File

@ -22,6 +22,7 @@ from ..speculative import get_num_extra_kv_tokens, get_spec_decoder
from .config import PyTorchConfig
from .config_utils import is_mla, is_nemotron_hybrid
from .guided_decoder import GuidedDecoder
from .kv_cache_connector import KvCacheConnectorManager
from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
from .llm_request import ExecutorResponse
from .mamba_cache_manager import MambaHybridCacheManager
@ -44,7 +45,8 @@ class KvCacheCreator:
def __init__(self, *, executor_config: ExecutorConfig,
model_engine: PyTorchModelEngine,
draft_model_engine: Optional[PyTorchModelEngine],
mapping: Mapping, net_max_seq_len: int):
mapping: Mapping, net_max_seq_len: int,
kv_connector_manager: Optional[KvCacheConnectorManager]):
self._executor_config = executor_config
self._model_engine = model_engine
self._draft_model_engine = draft_model_engine
@ -52,6 +54,7 @@ class KvCacheCreator:
self._max_kv_tokens_in = self._executor_config.kv_cache_config.max_tokens
self._dummy_reqs = self._create_dummy_context_requests(net_max_seq_len -
1)
self._kv_connector_manager = kv_connector_manager
@staticmethod
def _get_cache_size_per_token(model_config: ModelConfig,
@ -335,7 +338,9 @@ class KvCacheCreator:
# ---------------------------handle max_gpu_total_bytes---------------------------------
def _create_kv_cache_manager(
self, model_engine: PyTorchModelEngine) -> KVCacheManager:
self,
model_engine: PyTorchModelEngine,
estimating_kv_cache: bool = False) -> KVCacheManager:
executor_config = self._executor_config
mapping = self._mapping
assert model_engine.model.model_config.is_generation, "Only construct KV cache for generation models."
@ -377,12 +382,20 @@ class KvCacheCreator:
spec_config=spec_config,
max_beam_width=executor_config.max_beam_width,
is_draft=model_engine.is_draft_model,
kv_connector_manager=self._kv_connector_manager
if not estimating_kv_cache else None,
)
elif is_nemotron_hybrid(config):
if executor_config.max_beam_width > 1:
raise ValueError(
"MambaHybridCacheManager + beam search is not supported yet."
)
if not estimating_kv_cache and self._kv_connector_manager is not None:
raise NotImplementedError(
"Connector manager is not supported for MambaHybridCacheManager."
)
config = model_engine.model.model_config.pretrained_config
num_layers = config.hybrid_override_pattern.count("*")
layer_mask = [
@ -443,6 +456,8 @@ class KvCacheCreator:
model_config=binding_model_config,
max_beam_width=executor_config.max_beam_width,
is_draft=model_engine.is_draft_model,
kv_connector_manager=self._kv_connector_manager
if not estimating_kv_cache else None,
)
# KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config
if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:
@ -450,12 +465,21 @@ class KvCacheCreator:
return kv_cache_manager
def build_managers(self, resources: Dict) -> None:
def build_managers(self,
resources: Dict,
estimating_kv_cache: bool = False) -> None:
"""Construct KV caches for model and draft model (if applicable)."""
kv_cache_manager = self._create_kv_cache_manager(self._model_engine)
kv_cache_manager = self._create_kv_cache_manager(
self._model_engine, estimating_kv_cache)
if not estimating_kv_cache and self._kv_connector_manager is not None and self._draft_model_engine is not None:
raise NotImplementedError(
"Connector manager is not supported for draft model.")
draft_kv_cache_manager = self._create_kv_cache_manager(
self._draft_model_engine
self._draft_model_engine, estimating_kv_cache
) if self._draft_model_engine is not None else None
resources[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager
resources[
ResourceManagerType.DRAFT_KV_CACHE_MANAGER] = draft_kv_cache_manager
@ -472,20 +496,22 @@ class KvCacheCreator:
def create_py_executor_instance(
*,
dist,
resources,
mapping,
pytorch_backend_config,
executor_config,
ctx_chunk_config,
model_engine,
start_worker,
sampler,
drafter,
guided_decoder: Optional[GuidedDecoder] = None,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
*,
dist,
resources,
mapping,
pytorch_backend_config,
executor_config,
ctx_chunk_config,
model_engine,
start_worker,
sampler,
drafter,
guided_decoder: Optional[GuidedDecoder] = None,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
kv_connector_manager: Optional[KvCacheConnectorManager] = None
) -> PyExecutor:
kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
spec_config = model_engine.spec_config
@ -632,7 +658,8 @@ def create_py_executor_instance(
kv_cache_transceiver=kv_cache_transceiver,
guided_decoder=guided_decoder,
start_worker=start_worker,
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager)
def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,

View File

@ -0,0 +1,549 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains the primary interface for the KV Cache Connector.
The KV Cache Connector is a component that allows for remote KV cache access.
It is responsible for:
- Orchestrating the loading and saving of KV cache blocks.
- Managing asynchronous block tx/rx.
It can be used to provide functionality such as:
1. Disaggregated serving
2. KV offload/onboard
3. KV cache sharing
4. P2P KV cache transfer
etc.
The Connector API is split into two parts:
1. The scheduler, which is responsible for orchestration, and building metadata for the workers.
2. The worker, which performs and monitors transfers indicated by the scheduler's metadata.
To implement a custom KV connector, you need to implement both the scheduler and worker-side interfaces.
"""
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
Tuple)
import torch
from tensorrt_llm._utils import mpi_allgather, mpi_broadcast, mpi_rank
from tensorrt_llm.bindings import LlmRequestState
from tensorrt_llm.bindings.executor import ExecutorConfig
from tensorrt_llm.bindings.internal.batch_manager import \
KvCacheConnectorManager as KvCacheConnectorManagerCpp
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from .scheduler import ScheduledRequests
if TYPE_CHECKING:
from .resource_manager import KVCacheManager
# Used to store data for a single inflight request.
@dataclass
class RequestData:
# The request ID.
request_id: int
# The new tokens that were generated in the prior forward pass.
new_tokens: List[int]
# The new block IDs allocated in the prior forward pass.
new_block_ids: List[int]
# The position of the latest token with computed (valid) kv cache values.
computed_position: int
# A class to store some basic data regarding all inflight requests.
# This is used when calling `build_connector_meta` on the scheduler.
@dataclass
class SchedulerOutput:
# Requests being scheduled for the first time. Requests will show up in `new_requests` exactly once.
new_requests: List[RequestData] = field(default_factory=list)
# Requests being scheduled, that have already shown up in `new_requests`.
cached_requests: List[RequestData] = field(default_factory=list)
class KvCacheConnectorWorker(ABC):
def __init__(self, config: ExecutorConfig):
self._config = config
self._metadata = None
super().__init__()
def bind_connector_meta(self, metadata: object):
self._metadata = metadata
def get_connector_meta(self) -> object:
return self._metadata
def _clear_connector_meta(self):
self._metadata = None
@abstractmethod
def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
"""
Register the KV cache tensors to the worker.
This can be used for something like NIXL registration.
Args:
kv_cache_tensor: The contiguous KV cache tensor.
"""
@abstractmethod
def start_load_kv(self, stream: torch.cuda.Stream):
"""
Begin loading the KV cache in preparation for the next forward pass.
Specific blocks to transfer are indicated by the scheduler's metadata.
"""
@abstractmethod
def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
"""
Wait for a layer to finish being loaded before proceeding with the forward pass on the layer.
Note: This function is called immediately before the layer's work is enqueued into the stream.
Args:
layer_idx: The index of the layer to wait for.
stream: The stream the forward pass is being executed on.
"""
@abstractmethod
def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
"""
Begin saving the KV cache for a layer.
Note: This function is called immediately after the layer's work is enqueued into the stream.
Args:
layer_idx: The index of the layer to save.
stream: The stream the forward pass is being executed on.
"""
@abstractmethod
def wait_for_save(self, stream: torch.cuda.Stream):
"""
Block until all synchronous saving operations are complete. Called at the end of the forward pass.
"""
@abstractmethod
def get_finished(
self, finished_gen_req_ids: List[int],
started_loading_req_ids: List[int]) -> Tuple[List[int], List[int]]:
"""
Get the requests that have finished loading and saving.
Args:
finished_gen_req_ids: The IDs of the requests that have finished generating tokens, and are now asynchronously saving.
started_loading_req_ids: The IDs of the requests that have started asynchronously loading.
Returns:
The IDs of the requests that have finished saving.
The IDs of the requests that have finished loading.
Note: IDs may only be returned from this call after they've been provided in the `finished_gen_req_ids` and `started_loading_req_ids` arguments.
Additionally, the runtime will only take action based on these returned IDs once they've been returned by ALL workers. This allows some workers to take longer than others to complete the operations.
"""
class KvCacheConnectorScheduler(ABC):
def __init__(self, executor_config: ExecutorConfig):
self._config = executor_config
super().__init__()
@abstractmethod
def build_connector_meta(self, scheduler_output: SchedulerOutput):
"""
Build the metadata for the worker.
This is called by the KV Cache Manager when adding a sequence.
Args:
scheduler_output: The data for all inflight requests.
Returns:
The metadata for the workers.
"""
@abstractmethod
def get_num_new_matched_tokens(
self, request: LlmRequest,
num_computed_tokens: int) -> Tuple[int, bool]:
"""
Get the number of tokens that can be loaded from remote KV cache.
This does not include the tokens already matched on device (indicated by `num_computed_tokens`).
Args:
request: The request to get the number of tokens for.
num_computed_tokens: The number of tokens already matched on device.
Returns:
The number of tokens that can be loaded from remote KV cache.
Whether the tokens will be loaded asynchronously.
"""
@abstractmethod
def request_finished(self, request: LlmRequest,
cache_block_ids: List[int]) -> bool:
"""
Called when a request is finished generating tokens.
Args:
request: The request that finished generating tokens.
Returns:
Whether the request is performing asynchronous saving operations.
If true, this indicates that the kv cache manager should wait to deallocate the blocks until the saving has completed (determined by `get_finished` on the workers).
"""
@abstractmethod
def update_state_after_alloc(self, request: LlmRequest,
block_ids: List[int]):
"""
Called after blocks are allocated for a request (following get_num_new_matched_tokens) to provide the allocated block IDs to the scheduler.
Args:
request: The request that was allocated resources.
block_ids: The KV cache block IDs that were allocated.
"""
# An internal dataclass to handle async saving/loading requests.
@dataclass
class AsyncRequests:
saving: Dict[int, LlmRequest]
loading: Dict[int, LlmRequest]
def add_from(self, other: 'AsyncRequests'):
"""
Remove requests from the other `AsyncRequests` object, and add them to this one.
"""
self.saving.update(other.saving)
self.loading.update(other.loading)
other.saving = dict()
other.loading = dict()
def extract_by_id(self, saving_ids: List[int],
loading_ids: List[int]) -> 'AsyncRequests':
"""
Extract the requests with the given IDs from this `AsyncRequests` object.
Args:
saving_ids: The IDs of the requests to extract.
loading_ids: The IDs of the requests to extract.
"""
new_async_requests = AsyncRequests(dict(), dict())
for req_id in saving_ids:
new_async_requests.saving[req_id] = self.saving[req_id]
del self.saving[req_id]
for req_id in loading_ids:
new_async_requests.loading[req_id] = self.loading[req_id]
del self.loading[req_id]
return new_async_requests
@property
def saving_ids(self) -> Set[int]:
"""
Get the IDs of the requests that are being saved asynchronously.
"""
return set(self.saving.keys())
@property
def loading_ids(self) -> Set[int]:
"""
Get the IDs of the requests that are being loaded asynchronously.
"""
return set(self.loading.keys())
class KvCacheConnectorSchedulerOutputRequest:
def __init__(self):
self.block_ids = []
self.tokens = []
def update_and_build_data(self, req: LlmRequest,
kv_cache_manager: "KVCacheManager"):
block_ids = kv_cache_manager.get_cache_indices(req)
tokens = req.get_tokens(0)
new_block_ids = block_ids[len(self.block_ids):]
new_tokens = tokens[len(self.tokens):]
self.block_ids.extend(new_block_ids)
self.tokens.extend(new_tokens)
computed_position = len(
tokens
) - 1 if req.state != LlmRequestState.CONTEXT_INIT and req.state != LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS else req.context_current_position
return RequestData(req.request_id, new_tokens, new_block_ids,
computed_position)
class KvCacheConnectorSchedulerOutputManager:
def __init__(self):
self.requests = defaultdict(KvCacheConnectorSchedulerOutputRequest)
self.external_loads = dict()
def build_scheduler_output(self, scheduled_batch: ScheduledRequests,
new_async_requests: AsyncRequests,
kv_cache_manager: "KVCacheManager"):
scheduler_output = SchedulerOutput()
for req in scheduled_batch.context_requests:
if req.request_id in new_async_requests.loading_ids:
continue
is_new = req.request_id not in self.requests
request_data = self.requests[req.request_id].update_and_build_data(
req, kv_cache_manager)
# Don't include the connector matched tokens in the initial scheduler output.
if req.request_id in self.external_loads:
request_data.computed_position -= self.external_loads[
req.request_id]
if is_new:
scheduler_output.new_requests.append(request_data)
else:
scheduler_output.cached_requests.append(request_data)
for req in scheduled_batch.generation_requests:
request_data = self.requests[req.request_id].update_and_build_data(
req, kv_cache_manager)
scheduler_output.cached_requests.append(request_data)
self.external_loads = dict()
return scheduler_output
def record_new_matched_tokens(self, request: LlmRequest,
num_new_matched_tokens: int):
self.external_loads[request.request_id] = num_new_matched_tokens
class KvCacheConnectorManager(KvCacheConnectorManagerCpp):
"""
The KvCacheConnectorManager is used to manage connector-related state.
It has the following responsibilities:
1. Managing the state of async requests (both offload and onboard)
2. Handling MPI communication. We only run the leader on one rank, but need the results of the leader API on all ranks.
Note: This class is solely an implementation detail, and is not part of the connector interface itself.
When implementing a custom connector, you do not need to implement this class.
"""
def __init__(self, worker: KvCacheConnectorWorker,
scheduler: Optional[KvCacheConnectorScheduler]):
assert (scheduler is not None) == (
mpi_rank() == 0), "The scheduler may only exist on rank 0!"
super().__init__()
self.worker = worker
self.scheduler = scheduler
# Requests that haven't yet been passed into get_finished.
self.new_async_requests = AsyncRequests(dict(), dict())
# Requests that have been passed into get_finished, but haven't yet been returned.
self.pending_async_requests = AsyncRequests(dict(), dict())
# Requests that have been returned from get_finished locally, but haven't yet been returned by all workers.
self.local_finished_async_requests = AsyncRequests(dict(), dict())
# Requests that have finished loading asynchronously.
self.finished_async_loading_requests = dict()
self._scheduler_output = None
self.scheduler_output_manager = KvCacheConnectorSchedulerOutputManager()
def _run_on_leader(self, f: Callable[[], Any]) -> Any:
"""
Run a function on the leader rank, and broadcast the result to all other ranks.
"""
if self.scheduler is not None:
assert mpi_rank() == 0, "The scheduler may only exist on rank 0!"
res = f()
else:
res = None
return mpi_broadcast(res, root=0)
def get_num_new_matched_tokens(self, request: LlmRequest,
num_computed_tokens: int) -> int:
num_tokens, load_kv_async = self._run_on_leader(
lambda: self.scheduler.get_num_new_matched_tokens(
request, num_computed_tokens))
if num_tokens == 0 and load_kv_async:
raise RuntimeError(
"load_kv_async must be False when num_tokens is 0!")
# TODO(jthomson04): This part is a bit ugly.
# When the connector indicates that a request will be loaded asynchronously, we need to suspend its execution.
# This is problematic, since at the point when this function is called, the request has already been scheduled!
# Because of this, we need to remove it from our list of scheduled requests (see `take_scheduled_requests_pending_load`).
if load_kv_async:
self.new_async_requests.loading[request.request_id] = request
self.scheduler_output_manager.record_new_matched_tokens(
request, num_tokens)
return num_tokens
def should_add_sequence(self, request: LlmRequest) -> bool:
req_id = request.request_id
return req_id not in self.finished_async_loading_requests
def build_scheduler_output(self, scheduled_batch: ScheduledRequests,
kv_cache_manager: "KVCacheManager"):
self._scheduler_output = self.scheduler_output_manager.build_scheduler_output(
scheduled_batch, self.new_async_requests, kv_cache_manager)
def take_scheduled_requests_pending_load(
self, scheduled_requests: ScheduledRequests):
"""
Remove from the scheduled requests any context requests that are being loaded asynchronously.
This is done to prevent the runtime from attempting to load the KV cache for these requests.
Args:
scheduled_requests: The scheduled requests. Modified in place: context requests that are
being loaded asynchronously are removed from `context_requests`.
"""
allowed_context_requests = []
for req in scheduled_requests.context_requests:
# If this request is being loaded asynchronously, in addition to removing it from the list of scheduled requests,
# we also need to update its state.
if req.request_id in self.new_async_requests.loading.keys():
req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS
# Replace the request with the canonical request.
self.new_async_requests.loading[req.request_id] = req
else:
allowed_context_requests.append(req)
# Update the list of scheduled requests.
scheduled_requests.context_requests = allowed_context_requests
def handle_metadata(self) -> object:
metadata = self._run_on_leader(
lambda: self.scheduler.build_connector_meta(self._scheduler_output))
self._scheduler_output = None
self.worker.bind_connector_meta(metadata)
def request_finished(self, req: LlmRequest,
cache_block_ids: List[int]) -> bool:
"""
Called when a request is finished generating tokens.
Args:
req: The request that finished generating tokens.
cache_block_ids: The KV cache block indices allocated to the request.
Returns:
Whether the request is performing asynchronous saving operations. If true, we do not immediately call free_resources on the request.
"""
if req.request_id in self.finished_async_loading_requests:
del self.finished_async_loading_requests[req.request_id]
saving_async = self._run_on_leader(
lambda: self.scheduler.request_finished(req, cache_block_ids))
# This is similar to take_scheduled_requests_pending_load.
# We need to update the request's state to indicate that it's still being used, but isn't schedulable.
if saving_async:
self.new_async_requests.saving[req.request_id] = req
req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
return saving_async
def get_finished(self) -> List[LlmRequest]:
"""
Process requests that have finished loading and saving.
Returns:
The requests that have newly finished saving.
"""
started_loading_req_ids = list(self.new_async_requests.loading_ids)
finished_gen_req_ids = list(self.new_async_requests.saving_ids)
# Add the requests to our list of outstanding (still in progress) requests.
self.pending_async_requests.add_from(self.new_async_requests)
# Pass the newly started async save/load request ids into get_finished, and get back the ids that have finished saving and loading on this rank.
(finished_saving,
finished_loading) = self.worker.get_finished(finished_gen_req_ids,
started_loading_req_ids)
# Remove the requests from our pending list that have finished locally.
new_local_finished_async_requests = self.pending_async_requests.extract_by_id(
finished_saving, finished_loading)
# Add these requests to our list of locally finished requests.
self.local_finished_async_requests.add_from(
new_local_finished_async_requests)
# Allgather the locally finished ids across all workers.
finished_saving = list(self.local_finished_async_requests.saving_ids)
finished_loading = list(self.local_finished_async_requests.loading_ids)
all_results = mpi_allgather((finished_saving, finished_loading))
# Find only the requests that have been reported complete by all workers.
intersect_finished_saving = set.intersection(
*[set(res[0]) for res in all_results])
intersect_finished_loading = set.intersection(
*[set(res[1]) for res in all_results])
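# Example: with two ranks, if rank 0 reports request 42 as finished saving but
# rank 1 does not, the intersection is empty and request 42 stays in
# local_finished_async_requests until rank 1 also reports it (this is the
# rendezvous exercised by test_connector_manager_get_finished_allgather).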
# Remove these requests from our list of locally finished requests.
all_finished = self.local_finished_async_requests.extract_by_id(
intersect_finished_saving, intersect_finished_loading)
# For requests that have finished loading, move them back to the context state.
for id, req in all_finished.loading.items():
req.state = LlmRequestState.CONTEXT_INIT
self.finished_async_loading_requests[id] = req
# Return the requests that have finished saving.
# The execution loop will call _terminate_request on these requests.
return list(all_finished.saving.values())
def update_state_after_alloc(self, req: LlmRequest, block_ids: List[int]):
if self.scheduler is not None:
self.scheduler.update_state_after_alloc(req, block_ids)
def set_scheduler_output(self, scheduler_output: SchedulerOutput):
self._scheduler_output = scheduler_output
def layer_pre_hook(self, module, *args):
self.worker.wait_for_layer_load(module.layer_idx,
torch.cuda.current_stream())
def layer_post_hook(self, module, *args):
self.worker.save_kv_layer(module.layer_idx, torch.cuda.current_stream())
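# Illustrative sketch: a minimal no-op connector pair showing the surface that
# KvCacheConnectorManager drives. The method names and signatures below are
# inferred from the call sites above and from create_py_executor, which
# constructs both classes with the executor config; a real implementation
# would subclass the KvCacheConnectorWorker / KvCacheConnectorScheduler base
# classes referenced in the type hints.


class NoOpConnectorWorker:

    def __init__(self, executor_config):
        self._kv_pool = None
        self._metadata = None

    def register_kv_caches(self, kv_tensor):
        # Called once at startup with the unique primary KV pool tensor.
        self._kv_pool = kv_tensor

    def bind_connector_meta(self, metadata):
        # Metadata produced by the scheduler's build_connector_meta on rank 0.
        self._metadata = metadata

    def start_load_kv(self, stream):
        # Kick off asynchronous onboarding for the current batch.
        pass

    def wait_for_layer_load(self, layer_idx, stream):
        # Pre-forward hook: block until layer `layer_idx` has been onboarded.
        pass

    def save_kv_layer(self, layer_idx, stream):
        # Post-forward hook: enqueue the offload of layer `layer_idx`.
        pass

    def wait_for_save(self, stream):
        # Called once the forward pass has completed.
        pass

    def get_finished(self, finished_gen_req_ids, started_loading_req_ids):
        # Report per-rank completion; the manager intersects the results
        # across ranks via allgather. Completing everything immediately makes
        # the connector effectively synchronous.
        return finished_gen_req_ids, started_loading_req_ids


class NoOpConnectorScheduler:

    def __init__(self, executor_config):
        pass

    def get_num_new_matched_tokens(self, request, num_computed_tokens):
        # Returns (num_externally_matched_tokens, load_kv_async).
        return 0, False

    def update_state_after_alloc(self, request, block_ids):
        pass

    def build_connector_meta(self, scheduler_output):
        return None

    def request_finished(self, request, cache_block_ids):
        # Return True to keep the request's blocks alive while saving
        # asynchronously; False frees them immediately.
        return False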

View File

@ -36,10 +36,12 @@ from tensorrt_llm.runtime.generation import CUASSERT
from ..distributed import Distributed
from ..models.modeling_utils import DecoderModelForCausalLM
from ..modules.decoder_layer import DecoderLayer
from ..speculative.drafter import Drafter
from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
from .guided_decoder import GuidedDecoder
from .handle_logits import HandleLogits
from .kv_cache_connector import KvCacheConnectorManager
from .kv_cache_transceiver import KvCacheTransceiver
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
LlmResponse, get_draft_token_length)
@ -137,23 +139,25 @@ class BatchStatePP(BatchState):
class PyExecutor:
def __init__(self,
resource_manager,
scheduler: RequestScheduler,
model_engine: ModelEngine,
sampler: Sampler,
dist: Distributed,
max_num_sequences: int,
drafter: Optional[Drafter] = None,
disable_overlap_scheduler: bool = False,
max_input_len: int = 2048,
max_batch_size: int = 8,
max_beam_width: int = 1,
max_draft_len: int = 0,
kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
guided_decoder: Optional[GuidedDecoder] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
start_worker: bool = True):
def __init__(
self,
resource_manager,
scheduler: RequestScheduler,
model_engine: ModelEngine,
sampler: Sampler,
dist: Distributed,
max_num_sequences: int,
drafter: Optional[Drafter] = None,
disable_overlap_scheduler: bool = False,
max_input_len: int = 2048,
max_batch_size: int = 8,
max_beam_width: int = 1,
max_draft_len: int = 0,
kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
guided_decoder: Optional[GuidedDecoder] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
start_worker: bool = True,
kv_connector_manager: Optional[KvCacheConnectorManager] = None):
super(PyExecutor, self).__init__()
self.device_id = torch.cuda.current_device()
self.global_rank = global_mpi_rank()
@ -270,9 +274,42 @@ class PyExecutor:
self.worker_started = False
self.worker_lock = threading.Lock()
self.kv_connector_manager = kv_connector_manager
self._maybe_init_kv_connector_manager()
if start_worker:
self.start_worker()
def _maybe_init_kv_connector_manager(self):
if self.kv_connector_manager is not None:
if self.kv_cache_transceiver is not None:
raise NotImplementedError(
"KV Cache Connector is not supported with KvCacheTransceiver."
)
if self.dist.pp_size > 1:
raise NotImplementedError(
"KV Cache Connector is not supported with pipeline parallelism."
)
if self.kv_cache_manager is None:
raise ValueError(
"KV Cache Connector requires a KV Cache Manager.")
kv_tensor = self.kv_cache_manager.get_unique_primary_pool()
self.kv_connector_manager.worker.register_kv_caches(kv_tensor)
# For each of our layers, we need to register the pre/post hooks.
# These are used for methods like `wait_for_layer_load` and `save_kv_layer`.
for _name, module in self.model_engine.model.named_modules():
if isinstance(module, DecoderLayer):
module.register_forward_pre_hook(
self.kv_connector_manager.layer_pre_hook)
module.register_forward_hook(
self.kv_connector_manager.layer_post_hook)
def _event_loop_wrapper(self):
try:
with customized_gc_thresholds(
@ -920,6 +957,25 @@ class PyExecutor:
self.guided_decoder.build(scheduled_batch)
self.guided_decoder.execute(scheduled_batch, logits)
def _kv_connector_start_batch(self, scheduled_batch):
if self.kv_connector_manager:
self.kv_connector_manager.take_scheduled_requests_pending_load(
scheduled_batch)
self.kv_connector_manager.handle_metadata()
self.kv_connector_manager.worker.start_load_kv(
torch.cuda.current_stream())
def _kv_connector_terminate_requests(self):
if self.kv_connector_manager:
reqs_to_terminate = self.kv_connector_manager.get_finished()
for req in reqs_to_terminate:
self.resource_manager.free_resources(req)
def _kv_connector_wait_for_save(self):
if self.kv_connector_manager is not None:
self.kv_connector_manager.worker.wait_for_save(
torch.cuda.current_stream())
def _executor_loop(self):
torch.cuda.set_device(self.device_id)
# ensure the context is created, otherwise, some MPI calls will fail.
@ -950,12 +1006,17 @@ class PyExecutor:
# Return the first token to the client
self._handle_first_token_response(scheduled_batch)
self.resource_manager.prepare_resources(scheduled_batch)
if self.kv_cache_transceiver and self.guided_decoder:
self.guided_decoder.init_disagg_gen_requests(
scheduled_batch)
self._kv_connector_start_batch(scheduled_batch)
if scheduled_batch.batch_size > 0 or (
self.enable_attention_dp and self.dist.tp_size > 1):
if self.drafter is not None and self.use_spec_decode:
with request_context(
is_draft=True,
@ -1001,6 +1062,8 @@ class PyExecutor:
if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
self._terminate_ctx_finished_requests()
self._kv_connector_terminate_requests()
if self.enable_iter_perf_stats:
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
'num_ctx_tokens']
@ -1067,9 +1130,12 @@ class PyExecutor:
# For generation requests which have completed KV cache transfer
self._prepare_disagg_gen_transmission_complete(
scheduled_batch)
self.resource_manager.prepare_resources(scheduled_batch)
self._kv_connector_start_batch(scheduled_batch)
if scheduled_batch.batch_size > 0:
# The generation requests that do not have a batch_idx
# need to be at the front of the batch due to the assumptions
# made in model_engine.py::_forward_step. This is only important
@ -1125,6 +1191,8 @@ class PyExecutor:
if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
self._terminate_ctx_finished_requests()
self._kv_connector_terminate_requests()
def _process_previous_batch(self):
if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
for req in self.previous_batch.ctx_transmission_reqs:
@ -1456,6 +1524,9 @@ class PyExecutor:
outputs = forward(scheduled_requests, self.resource_manager,
new_tensors_device, gather_context_logits,
cache_indirection_buffer)
self._kv_connector_wait_for_save()
return outputs
except Exception as e:
traceback.print_exc()
@ -1579,7 +1650,21 @@ class PyExecutor:
self._enqueue_responses(error_responses.items())
def _terminate_request(self, request: LlmRequest):
self.resource_manager.free_resources(request)
if self.kv_connector_manager is None:
self.resource_manager.free_resources(request)
else:
# Only call request_finished on the connector if the request has already been added to the kv cache manager.
try:
cache_block_ids = self.kv_cache_manager.get_cache_indices(
request)
except IndexError:
# If the request has not yet been added to the kv cache manager,
# we still need to free resources corresponding to other resource managers.
self.resource_manager.free_resources(request)
else:
if not self.kv_connector_manager.request_finished(
request, cache_block_ids):
self.resource_manager.free_resources(request)
@nvtx_range("_handle_canceled_requests")
def _handle_canceled_requests(self):

View File

@ -1,5 +1,7 @@
import copy
import enum
import importlib
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import chain
@ -10,8 +12,11 @@ import torch
import tensorrt_llm
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig
from tensorrt_llm.bindings.executor import (CapacitySchedulerPolicy,
ContextChunkingPolicy,
ExecutorConfig)
from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.mapping import Mapping
@ -26,6 +31,7 @@ from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
from .config import LoadFormat, PyTorchConfig
from .config_utils import is_mla
from .guided_decoder import GuidedDecoder
from .kv_cache_connector import KvCacheConnectorManager
from .model_engine import PyTorchModelEngine
from .py_executor import PyExecutor
@ -206,7 +212,9 @@ def create_py_executor(
executor_config: ExecutorConfig,
checkpoint_dir: str = None,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
garbage_collection_gen0_threshold: Optional[int] = None,
kv_connector_config: Optional[KvCacheConnectorConfig] = None
) -> PyExecutor:
_mangle_executor_config(executor_config)
pytorch_backend_config = executor_config.pytorch_backend_config
@ -375,21 +383,70 @@ def create_py_executor(
pytorch_backend_config, mapping)
logger.info(f"Using Sampler: {type(sampler).__name__}")
if kv_connector_config is not None:
logger.info(
f"Initializing kv connector with config: {kv_connector_config}")
if pytorch_backend_config.use_cuda_graph:
raise NotImplementedError(
"CUDA graphs are not supported with KV connector hooks.")
if executor_config.scheduler_config.capacity_scheduler_policy != CapacitySchedulerPolicy.GUARANTEED_NO_EVICT:
raise NotImplementedError(
"KV connector is only supported with guaranteed no evict scheduler policy."
)
try:
module = importlib.import_module(
kv_connector_config.connector_module)
worker_cls = getattr(module,
kv_connector_config.connector_worker_class)
scheduler_cls = getattr(
module, kv_connector_config.connector_scheduler_class)
rank = tensorrt_llm.mpi_rank()
# Some connector API implementations may need to establish out-of-band communication between the scheduler and workers.
# In this case, the worker may be dependent on the scheduler, or vice-versa.
# To deal with cases like this, we instantiate them both concurrently.
with ThreadPoolExecutor(max_workers=2) as executor:
connector_worker_task = executor.submit(worker_cls,
executor_config)
if scheduler_cls is not None and rank == 0:
connector_scheduler_task = executor.submit(
scheduler_cls, executor_config)
connector_scheduler = connector_scheduler_task.result()
else:
connector_scheduler = None
connector_worker = connector_worker_task.result()
kv_connector_manager = KvCacheConnectorManager(
connector_worker, connector_scheduler)
except Exception as e:
logger.error(f"Error instantiating connector: {e}")
raise e
else:
kv_connector_manager = None
resources = {}
estimating_kv_cache = False
kv_cache_creator = None
if model_engine.model.model_config.is_generation:
#NOTE: non-generation models do not have kv cache
kv_cache_creator = KvCacheCreator(executor_config=executor_config,
model_engine=model_engine,
draft_model_engine=draft_model_engine,
mapping=mapping,
net_max_seq_len=net_max_seq_len)
kv_cache_creator = KvCacheCreator(
executor_config=executor_config,
model_engine=model_engine,
draft_model_engine=draft_model_engine,
mapping=mapping,
net_max_seq_len=net_max_seq_len,
kv_connector_manager=kv_connector_manager)
estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.INIT_KV_CACHE
if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
kv_cache_creator.build_managers(resources)
kv_cache_creator.build_managers(resources, estimating_kv_cache)
# Resource managers for speculative decoding
# For user-specified drafters, use extra_resource_managers in PyTorchBackend config
@ -425,6 +482,8 @@ def create_py_executor(
guided_decoder=guided_decoder,
lora_config=lora_config,
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager
if not estimating_kv_cache else None,
)
if estimating_kv_cache:
@ -441,7 +500,7 @@ def create_py_executor(
# create_kv_cache_manager above, which caps executor_config.max_seq_len. Restoring
# the original value before creating the final KV cache.
executor_config.max_seq_len = max_seq_len
kv_cache_creator.build_managers(resources)
kv_cache_creator.build_managers(resources, False)
for eng in [model_engine, draft_model_engine]:
if eng is None:
@ -468,6 +527,7 @@ def create_py_executor(
lora_config=lora_config,
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager,
)
_adjust_torch_mem_fraction(executor_config.pytorch_backend_config)

View File

@ -17,6 +17,7 @@ from tensorrt_llm.sampling_params import SamplingParams
from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range
from ...logger import logger
from ...mapping import CpType, Mapping
from .kv_cache_connector import KvCacheConnectorManager
from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig,
get_draft_token_length)
from .scheduler import ScheduledRequests
@ -161,6 +162,7 @@ class KVCacheManager(BaseResourceManager):
model_config: Optional[ModelConfig] = None,
max_beam_width: int = 1,
is_draft: bool = False,
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
) -> None:
self.mapping = mapping
self.dtype = dtype
@ -178,6 +180,8 @@ class KVCacheManager(BaseResourceManager):
for offset, idx in enumerate(self.pp_layers)
}
self.kv_connector_manager = kv_connector_manager
tp_size = mapping.tp_size
if mapping.enable_attention_dp:
tp_size = 1
@ -329,6 +333,7 @@ class KVCacheManager(BaseResourceManager):
'cache_type': kv_cache_type,
'enable_partial_reuse': kv_cache_config.enable_partial_reuse,
'copy_on_partial_reuse': kv_cache_config.copy_on_partial_reuse,
'kv_connector_manager': self.kv_connector_manager,
}
if self.event_buffer_max_size > 0:
if mapping.enable_attention_dp:
@ -413,7 +418,8 @@ class KVCacheManager(BaseResourceManager):
== self.mapping.cp_size - 1 else 0),
req_beam_width, req)
else:
if req.is_first_context_chunk:
if req.is_first_context_chunk and self._kv_connector_should_add_sequence(
req):
self.impl.add_sequence(req.py_request_id,
req.prompt_len, req_beam_width,
req)
@ -422,11 +428,24 @@ class KVCacheManager(BaseResourceManager):
for _ in range(get_draft_token_length(req)):
self.impl.add_token(req.py_request_id)
if self.kv_connector_manager is not None:
block_ids = self.get_cache_indices(req)
self.kv_connector_manager.update_state_after_alloc(
req, block_ids)
for req in generation_batch:
self.impl.add_token(req.py_request_id)
for _ in range(get_draft_token_length(req)):
self.impl.add_token(req.py_request_id)
if self.kv_connector_manager is not None:
self.kv_connector_manager.build_scheduler_output(
scheduled_batch, self)
def _kv_connector_should_add_sequence(self, request: LlmRequest) -> bool:
return self.kv_connector_manager is None or self.kv_connector_manager.should_add_sequence(
request)
def add_dummy_requests(
self,
request_ids: List[int],
@ -626,6 +645,9 @@ class KVCacheManager(BaseResourceManager):
self.head_dim,
)
def get_unique_primary_pool(self) -> torch.Tensor:
return self.impl.get_unique_primary_pool()
def get_block_ids_per_seq(self, request_ids: List[int]) -> torch.Tensor:
block_ids_per_seq = self.get_batch_cache_indices(request_ids)
block_ids_per_seq_tensors = [

View File

@ -21,7 +21,7 @@ from .._utils import mpi_world_size
from ..bindings import executor as tllm
from ..builder import Engine
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.llm_args import TorchLlmArgs
from ..llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs
from ..llmapi.llm_utils import KvCacheRetentionConfig
from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
need_spawn_mpi_workers)
@ -356,6 +356,7 @@ class GenerationExecutor(ABC):
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
lora_config: Optional[LoraConfig] = None,
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[TorchLlmArgs] = None,
@ -405,7 +406,8 @@ class GenerationExecutor(ABC):
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor)
is_llm_executor=is_llm_executor,
kv_connector_config=kv_connector_config)
# WAR: For the performance of gathering logits, we use single process worker
# for TP1 to avoid the large overhead of IPC.
@ -415,8 +417,10 @@ class GenerationExecutor(ABC):
logger.warning(
"Using single process worker for TP1, this may hurt streaming generation performance."
)
return GenerationExecutorWorker(**worker_kwargs,
is_llm_executor=is_llm_executor)
return GenerationExecutorWorker(
**worker_kwargs,
is_llm_executor=is_llm_executor,
kv_connector_config=kv_connector_config)
# For single-gpu case:
# Partition the workload to multiple process for streaming performance.
@ -428,7 +432,8 @@ class GenerationExecutor(ABC):
model_world_size=model_world_size,
mpi_session=None, # use mpi4py
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor)
is_llm_executor=is_llm_executor,
kv_connector_config=kv_connector_config)
else:
ctx = multiprocessing.get_context("spawn")
# The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot.
@ -439,7 +444,8 @@ class GenerationExecutor(ABC):
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor)
is_llm_executor=is_llm_executor,
kv_connector_config=kv_connector_config)
def wait_first_completed(
self, futures: List[GenerationResult]

View File

@ -12,6 +12,7 @@ import zmq.asyncio
from tensorrt_llm.logger import logger
from .._utils import customized_gc_thresholds, mpi_rank, nvtx_range_debug
from ..llmapi.llm_args import KvCacheConnectorConfig
from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession,
RemoteMpiCommSessionClient)
from ..llmapi.tracer import enable_llm_tracer, get_tracer, global_tracer
@ -45,6 +46,7 @@ class GenerationExecutorProxy(GenerationExecutor):
worker_cls: type = GenerationExecutorWorker,
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
) -> None:
postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
)
@ -93,7 +95,8 @@ class GenerationExecutorProxy(GenerationExecutor):
worker_kwargs = dict(**worker_kwargs,
worker_queues=self._setup_queues(),
postproc_worker_config=postproc_worker_config,
is_llm_executor=False)
is_llm_executor=False,
kv_connector_config=kv_connector_config)
if "log_level" not in worker_kwargs:
worker_kwargs["log_level"] = logger.level

View File

@ -18,7 +18,7 @@ from .._utils import (KVCacheEventSerializer, global_mpi_rank, global_mpi_size,
mpi_comm, mpi_rank, nvtx_range_debug)
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
from ..llmapi.llm_args import KvCacheConnectorConfig, PybindMirror, TorchLlmArgs
from ..llmapi.mpi_session import set_mpi_session_cpp
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
@ -61,6 +61,7 @@ class GenerationExecutorWorker(GenerationExecutor):
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
lora_config: Optional[LoraConfig] = None,
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[TorchLlmArgs] = None,
@ -87,6 +88,10 @@ class GenerationExecutorWorker(GenerationExecutor):
self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"
self.llm_args = llm_args
if not self._is_pytorch_backend and kv_connector_config is not None:
raise ValueError(
"KV connector config is only supported for PyTorch backend")
if global_mpi_size() > 1:
logger.set_rank(self.global_rank)
@ -127,6 +132,7 @@ class GenerationExecutorWorker(GenerationExecutor):
args["lora_config"] = lora_config
args[
"garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
args["kv_connector_config"] = kv_connector_config
elif executor_config.backend == "_autodeploy":
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
create_autodeploy_executor
@ -680,6 +686,7 @@ def worker_main(
is_llm_executor: Optional[
bool] = True, # whether it's the main executor instance
lora_config: Optional[LoraConfig] = None,
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[TorchLlmArgs] = None,
@ -809,6 +816,7 @@ def worker_main(
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
lora_config=lora_config,
kv_connector_config=kv_connector_config,
hf_model_dir=hf_model_dir,
tokenizer=tokenizer,
llm_args=llm_args)

View File

@ -983,6 +983,8 @@ class _TorchLLM(BaseLLM):
),
is_llm_executor=True,
lora_config=self.args.lora_config,
# Autodeploy does not support kv_connector_config
kv_connector_config=getattr(self.args, "kv_connector_config", None),
hf_model_dir=self._hf_model_dir,
tokenizer=self.tokenizer,
llm_args=self.args)

View File

@ -404,6 +404,21 @@ class DecodingBaseConfig(StrictBaseModel):
self.decoding_type.upper())
class KvCacheConnectorConfig(StrictBaseModel):
"""
Configuration for the KV Cache Connector.
"""
connector_module: str = Field(
...,
description=
"The import path to the connector module. It will be imported with `importlib.import_module`."
)
connector_scheduler_class: str = Field(
..., description="The class name of the scheduler within the module.")
connector_worker_class: str = Field(
..., description="The class name of the worker within the module.")
class MedusaDecodingConfig(DecodingBaseConfig):
medusa_choices: Optional[List[List[int]]] = None
num_medusa_heads: Optional[int] = None
@ -2302,6 +2317,11 @@ class TorchLlmArgs(BaseLlmArgs):
status="prototype",
)
kv_connector_config: Optional[KvCacheConnectorConfig] = Field(
default=None,
description="The config for KV cache connector.",
)
mm_encoder_only: bool = Field(
default=False,
description=

View File

@ -0,0 +1,348 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from unittest.mock import MagicMock, patch
import pytest
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, KvCacheConnectorConfig
from ..conftest import llm_models_root
@pytest.fixture(scope="function")
def model_with_connector():
with patch("tensorrt_llm._torch.pyexecutor.py_executor_creator.importlib"
) as importlib_mock:
mock_scheduler = MagicMock()
mock_worker = MagicMock()
importlib_mock.import_module.return_value.KvConnectorScheduler.return_value = mock_scheduler
importlib_mock.import_module.return_value.KvConnectorWorker.return_value = mock_worker
kv_connector_config = KvCacheConnectorConfig(
connector_module="",
connector_scheduler_class="KvConnectorScheduler",
connector_worker_class="KvConnectorWorker",
)
def model_fn(*args, **kwargs):
return LLM(
*args,
**kwargs,
model=f"{llm_models_root()}/Qwen2-0.5B",
backend="pytorch",
kv_connector_config=kv_connector_config,
cuda_graph_config=None,
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1),
)
yield model_fn, mock_scheduler, mock_worker
@pytest.fixture(scope="function")
def enforce_single_worker(monkeypatch):
monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1")
yield
@pytest.mark.threadleak(enabled=False)
@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
def test_connector_simple(enforce_single_worker, model_with_connector,
use_overlap_scheduler):
NUM_TOKENS = 8
model_fn, scheduler, worker = model_with_connector
model = model_fn(disable_overlap_scheduler=not use_overlap_scheduler, )
assert worker.register_kv_caches.call_count == 1
scheduler.get_num_new_matched_tokens.return_value = 0, False
worker.get_finished.return_value = [], []
sampling_params = SamplingParams(max_tokens=NUM_TOKENS, ignore_eos=True)
model.generate(["Hello, world"], sampling_params)
assert scheduler.update_state_after_alloc.call_count == 1
# Allocate 1 block.
assert len(scheduler.update_state_after_alloc.call_args.args[1]) == 1
# With the overlap scheduler, we generate one extra token.
assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int(
use_overlap_scheduler)
# We should have a single `SchedulerOutput` per forward pass.
for i, call in enumerate(scheduler.build_connector_meta.call_args_list):
scheduler_output = call[0][0]
if i == 0:
assert len(scheduler_output.new_requests) == 1
assert len(scheduler_output.cached_requests) == 0
elif i == 1 and use_overlap_scheduler:
assert len(scheduler_output.new_requests) == 0
assert len(scheduler_output.cached_requests) == 1
assert len(scheduler_output.cached_requests[0].new_tokens) == 0
else:
assert len(scheduler_output.new_requests) == 0
assert len(scheduler_output.cached_requests) == 1
assert len(scheduler_output.cached_requests[0].new_tokens) == 1
# We call `start_load_kv` once at the beginning of each forward pass.
assert worker.start_load_kv.call_count == NUM_TOKENS + int(
use_overlap_scheduler)
# Only called once when the request is received.
assert scheduler.get_num_new_matched_tokens.call_count == 1
num_layers = max(call.args[0]
for call in worker.wait_for_layer_load.call_args_list) + 1
# Called num_layers * num_forward_passes times.
assert worker.wait_for_layer_load.call_count == num_layers * (
NUM_TOKENS + int(use_overlap_scheduler))
assert worker.save_kv_layer.call_count == num_layers * (
NUM_TOKENS + int(use_overlap_scheduler))
for i, call in enumerate(worker.wait_for_layer_load.call_args_list):
assert call.args[0] == i % num_layers
for i, call in enumerate(worker.save_kv_layer.call_args_list):
assert call.args[0] == i % num_layers
assert worker.wait_for_save.call_count == NUM_TOKENS + int(
use_overlap_scheduler)
assert scheduler.request_finished.call_count == 1
assert len(scheduler.request_finished.call_args.args[1]) == 1
assert worker.get_finished.call_count == NUM_TOKENS + int(
use_overlap_scheduler)
@pytest.mark.threadleak(enabled=False)
@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
def test_connector_async_onboard(enforce_single_worker, model_with_connector,
use_overlap_scheduler):
NUM_TOKENS = 8
model_fn, scheduler, worker = model_with_connector
model = model_fn(disable_overlap_scheduler=not use_overlap_scheduler, )
assert worker.register_kv_caches.call_count == 1
scheduler.get_num_new_matched_tokens.return_value = 16, True
worker.get_finished.side_effect = lambda finished_gen, load_async: (
finished_gen, load_async)
model.generate([
"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
], SamplingParams(max_tokens=NUM_TOKENS, ignore_eos=True))
# Once for the initial poll, then once for each token. One extra token when using the overlap scheduler.
assert worker.get_finished.call_count == NUM_TOKENS + 1 + int(
use_overlap_scheduler)
# In the first iteration, there should be a single request id provided.
assert len(worker.get_finished.call_args_list[0].args[1]) == 1
@pytest.mark.threadleak(enabled=False)
@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
def test_connector_async_save(enforce_single_worker, model_with_connector,
use_overlap_scheduler):
NUM_TOKENS = 8
model_fn, scheduler, worker = model_with_connector
model = model_fn(disable_overlap_scheduler=not use_overlap_scheduler, )
assert worker.register_kv_caches.call_count == 1
scheduler.get_num_new_matched_tokens.return_value = 0, False
scheduler.request_finished.return_value = True
worker.get_finished.side_effect = lambda finished_gen, load_async: (
finished_gen, load_async)
sampling_params = SamplingParams(max_tokens=NUM_TOKENS, ignore_eos=True)
model.generate(["Hello, world"], sampling_params)
assert scheduler.request_finished.call_count == 1
assert len(scheduler.request_finished.call_args.args[1]) == 1
# On the last call to get_finished, we should be providing the async saving request. One extra token when using the overlap scheduler.
assert worker.get_finished.call_count == NUM_TOKENS + int(
use_overlap_scheduler)
for i, call in enumerate(worker.get_finished.call_args_list):
args = call.args
if i != len(worker.get_finished.call_args_list) - 1:
assert args == ([], [])
else:
assert len(args[0]) == 1
assert args[0][0] == scheduler.request_finished.call_args.args[
0].request_id
@pytest.mark.threadleak(enabled=False)
@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
def test_connector_scheduler_output(enforce_single_worker, model_with_connector,
use_overlap_scheduler):
NUM_INPUT_TOKENS = 48
NUM_TOKENS = 32
BLOCK_SIZE = 32
model_fn, scheduler, worker = model_with_connector
model = model_fn(disable_overlap_scheduler=not use_overlap_scheduler, )
assert worker.register_kv_caches.call_count == 1
scheduler.get_num_new_matched_tokens.return_value = 0, False
worker.get_finished.return_value = [], []
sampling_params = SamplingParams(max_tokens=32, ignore_eos=True)
model.generate([0] * NUM_INPUT_TOKENS, sampling_params)
assert scheduler.update_state_after_alloc.call_count == 1
assert len(
scheduler.update_state_after_alloc.call_args.args[1]) == math.ceil(
NUM_INPUT_TOKENS / BLOCK_SIZE)
# Additional token when using the overlap scheduler.
assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int(
use_overlap_scheduler)
for i, call in enumerate(scheduler.build_connector_meta.call_args_list):
sched_output = call.args[0]
if i == 0:
assert len(sched_output.new_requests) == 1
assert len(sched_output.cached_requests) == 0
request = sched_output.new_requests[0]
assert len(request.new_tokens) == NUM_INPUT_TOKENS
assert len(request.new_block_ids) == math.ceil(NUM_INPUT_TOKENS /
BLOCK_SIZE)
assert request.computed_position == 0
elif i == 1 and use_overlap_scheduler:
assert len(sched_output.new_requests) == 0
assert len(sched_output.cached_requests) == 1
assert len(sched_output.cached_requests[0].new_tokens) == 0
else:
assert len(sched_output.cached_requests) == 1
assert len(sched_output.new_requests) == 0
request = sched_output.cached_requests[0]
assert len(request.new_tokens) == 1
if (request.computed_position +
int(use_overlap_scheduler)) % BLOCK_SIZE == 0:
assert len(request.new_block_ids) == 1
else:
assert request.new_block_ids == []
scheduler.build_connector_meta.reset_mock()
scheduler.get_num_new_matched_tokens.return_value = 8, False
assert len(scheduler.request_finished.call_args.args[1]) == math.ceil(
(NUM_INPUT_TOKENS + NUM_TOKENS) / BLOCK_SIZE)
model.generate([1] * NUM_INPUT_TOKENS, sampling_params)
# The initial computed position should be 0, since we haven't yet onboarded any blocks.
assert scheduler.build_connector_meta.call_args_list[0].args[
0].new_requests[0].computed_position == 0
@pytest.mark.threadleak(enabled=False)
@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
def test_connector_scheduler_output_chunked_context(enforce_single_worker,
model_with_connector,
use_overlap_scheduler):
model_fn, scheduler, worker = model_with_connector
CHUNK_SIZE = 128
BLOCK_SIZE = 32
model = model_fn(disable_overlap_scheduler=not use_overlap_scheduler,
enable_chunked_prefill=True,
max_num_tokens=CHUNK_SIZE)
assert worker.register_kv_caches.call_count == 1
scheduler.get_num_new_matched_tokens.return_value = 0, False
worker.get_finished.return_value = [], []
sampling_params = SamplingParams(max_tokens=BLOCK_SIZE, ignore_eos=True)
model.generate([0] * (CHUNK_SIZE * 2), sampling_params)
assert scheduler.update_state_after_alloc.call_count == 1
assert len(
scheduler.update_state_after_alloc.call_args.args[1]) == math.ceil(
CHUNK_SIZE * 2 / BLOCK_SIZE)
for i, call in enumerate(scheduler.build_connector_meta.call_args_list):
sched_output = call.args[0]
if i == 0:
assert len(sched_output.new_requests) == 1
assert len(sched_output.cached_requests) == 0
req = sched_output.new_requests[0]
else:
assert len(sched_output.cached_requests) == 1
assert len(sched_output.new_requests) == 0
req = sched_output.cached_requests[0]
if i == 0:
# The first prefill chunk.
# All of the prefill tokens and all the blocks should be provided upfront.
assert req.computed_position == 0
assert len(req.new_tokens) == CHUNK_SIZE * 2
assert len(req.new_block_ids) == math.ceil(CHUNK_SIZE * 2 /
BLOCK_SIZE)
elif i == 1:
# The second prefill chunk.
assert req.computed_position == CHUNK_SIZE
assert len(req.new_tokens) == 0
assert len(req.new_block_ids) == 0
elif i == 2 and use_overlap_scheduler:
assert len(req.new_tokens) == 0
else:
assert len(req.new_tokens) == 1
assert len(scheduler.request_finished.call_args.args[1]) == math.ceil(
(CHUNK_SIZE * 2 + BLOCK_SIZE) / BLOCK_SIZE)

View File

@ -163,3 +163,12 @@ def test_llmapi_sampling(llm_root, engine_dir, llm_venv):
@pytest.mark.skip(reason="https://nvbugs/5365825")
def test_llmapi_runtime(llm_root, engine_dir, llm_venv):
_run_llmapi_example(llm_root, engine_dir, llm_venv, "llm_runtime.py")
@pytest.mark.parametrize("model", ["Qwen2-0.5B"])
def test_llmapi_kv_cache_connector(llm_root, llm_venv, model):
script_path = Path(
llm_root) / "examples" / "llm-api" / "llm_kv_cache_connector.py"
model_path = f"{llm_models_root()}/{model}"
venv_check_call(llm_venv, [str(script_path), model_path])

View File

@ -81,12 +81,23 @@ l0_a10:
- unittest/trt/model/test_mistral.py
- unittest/trt/model/test_llama.py
- test_e2e.py::test_gpt3_175b_1layers_build_only # 6 mins
- llmapi/test_llm_api_connector.py::test_connector_simple[True]
- llmapi/test_llm_api_connector.py::test_connector_simple[False]
- llmapi/test_llm_api_connector.py::test_connector_async_onboard[True]
- llmapi/test_llm_api_connector.py::test_connector_async_onboard[False]
- llmapi/test_llm_api_connector.py::test_connector_async_save[True]
- llmapi/test_llm_api_connector.py::test_connector_async_save[False]
- llmapi/test_llm_api_connector.py::test_connector_scheduler_output[True]
- llmapi/test_llm_api_connector.py::test_connector_scheduler_output[False]
- llmapi/test_llm_api_connector.py::test_connector_scheduler_output_chunked_context[True]
- llmapi/test_llm_api_connector.py::test_connector_scheduler_output_chunked_context[False]
- llmapi/test_llm_e2e.py::test_llmapi_load_engine_from_build_command[llama-llama-models/llama-7b-hf] # 5min
- llmapi/test_llm_e2e.py::test_llmapi_build_command_parameters_align[llama-llama-models-v2/TinyLlama-1.1B-Chat-v1.0]
- llmapi/test_llm_e2e.py::test_llmapi_load_engine_from_build_command_with_lora[llama-llama-models-v2/llama-v2-7b-hf]
- llmapi/test_llm_examples.py::test_llmapi_chat_example
- llmapi/test_llm_e2e.py::test_llmapi_exit
- llmapi/test_llm_examples.py::test_llmapi_server_example
- llmapi/test_llm_examples.py::test_llmapi_kv_cache_connector[Qwen2-0.5B]
- test_e2e.py::test_trtllm_serve_example
- test_e2e.py::test_openai_misc_example[trt]
- test_e2e.py::test_openai_completions_example[trt]

View File

@ -0,0 +1,170 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import sys
from unittest.mock import MagicMock
import cloudpickle
import mpi4py
import pytest
from tensorrt_llm import mpi_rank
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import \
KvCacheConnectorManager
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
cloudpickle.register_pickle_by_value(sys.modules[__name__])
mpi4py.MPI.pickle.__init__(
cloudpickle.dumps,
cloudpickle.loads,
pickle.HIGHEST_PROTOCOL,
)
def run_across_mpi(executor, fun, num_ranks):
return list(executor.starmap(fun, [() for i in range(num_ranks)]))
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
# TODO(jthomson04): I don't have the slightest idea why this test is leaking threads.
@pytest.mark.threadleak(enabled=False)
def test_connector_manager_get_finished_allgather(mpi_pool_executor):
def test():
worker = MagicMock()
if mpi_rank() == 0:
scheduler = MagicMock()
scheduler.request_finished.return_value = True
else:
scheduler = None
manager = KvCacheConnectorManager(worker, scheduler=scheduler)
req = MagicMock()
req.request_id = 42
manager.request_finished(req, [])
# To start, make both workers return nothing.
worker.get_finished.return_value = ([], [])
assert manager.get_finished() == []
assert worker.get_finished.call_count == 1
assert worker.get_finished.call_args[0] == ([42], [])
worker.get_finished.reset_mock()
# Now, only return the request id on one worker.
if mpi_rank() == 0:
worker.get_finished.return_value = ([42], [])
else:
worker.get_finished.return_value = ([], [])
# It should still return nothing, since rank 1 is still saving.
assert manager.get_finished() == []
assert worker.get_finished.call_count == 1
assert worker.get_finished.call_args[0] == ([], [])
# Now, also return it on worker 1.
if mpi_rank() == 0:
worker.get_finished.return_value = ([], [])
else:
worker.get_finished.return_value = ([42], [])
assert manager.get_finished() == [req]
run_across_mpi(mpi_pool_executor, test, 2)
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_connector_manager_num_matched_tokens(mpi_pool_executor):
def test():
worker = MagicMock()
if mpi_rank() == 0:
scheduler = MagicMock()
scheduler.get_num_new_matched_tokens.return_value = (16, True)
else:
scheduler = None
manager = KvCacheConnectorManager(worker, scheduler=scheduler)
req = MagicMock()
req.request_id = 42
assert manager.get_num_new_matched_tokens(req, 32) == 16
if mpi_rank() == 0:
assert scheduler.get_num_new_matched_tokens.call_count == 1
assert scheduler.get_num_new_matched_tokens.call_args[0] == (req,
32)
run_across_mpi(mpi_pool_executor, test, 2)
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_connector_manager_take_scheduled_requests(mpi_pool_executor):
def test():
worker = MagicMock()
if mpi_rank() == 0:
scheduler = MagicMock()
else:
scheduler = None
manager = KvCacheConnectorManager(worker, scheduler=scheduler)
scheduled_requests = ScheduledRequests()
req0 = MagicMock()
req0.request_id = 0
req1 = MagicMock()
req1.request_id = 1
if mpi_rank() == 0:
scheduler.get_num_new_matched_tokens.return_value = (16, True)
assert manager.get_num_new_matched_tokens(req0, 0) == 16
if mpi_rank() == 0:
assert scheduler.get_num_new_matched_tokens.call_count == 1
assert scheduler.get_num_new_matched_tokens.call_args[0] == (req0,
0)
scheduler.get_num_new_matched_tokens.reset_mock()
scheduler.get_num_new_matched_tokens.return_value = (32, False)
assert manager.get_num_new_matched_tokens(req1, 0) == 32
if mpi_rank() == 0:
assert scheduler.get_num_new_matched_tokens.call_count == 1
assert scheduler.get_num_new_matched_tokens.call_args[0] == (req1,
0)
scheduled_requests.context_requests = [req0, req1]
manager.take_scheduled_requests_pending_load(scheduled_requests)
assert scheduled_requests.context_requests == [req1]
run_across_mpi(mpi_pool_executor, test, 2)

View File

@ -167,6 +167,10 @@ methods:
annotation: Optional[tensorrt_llm.llmapi.llm_args.DecodingConfig]
default: null
status: deprecated
kv_connector_config:
annotation: Optional[tensorrt_llm.llmapi.llm_args.KvCacheConnectorConfig]
default: null
status: prototype
return_annotation: None
generate:
parameters: